File size: 15,077 Bytes
8f849bf
 
 
 
 
 
 
 
 
 
 
9159ecd
8f849bf
9159ecd
 
 
 
8f849bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9159ecd
 
8f849bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9159ecd
8f849bf
 
 
9159ecd
8f849bf
 
 
9159ecd
8f849bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9159ecd
8f849bf
 
 
 
9159ecd
8f849bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9159ecd
 
8f849bf
 
 
 
 
 
9159ecd
8f849bf
9159ecd
8f849bf
9159ecd
8f849bf
 
 
 
 
 
 
 
 
 
 
9159ecd
8f849bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9159ecd
8f849bf
 
 
 
 
 
9159ecd
8f849bf
 
 
 
9159ecd
8f849bf
 
 
 
 
 
 
9159ecd
 
8f849bf
 
 
 
 
 
 
 
 
 
9159ecd
8f849bf
 
 
 
 
 
 
9159ecd
8f849bf
 
 
9159ecd
 
 
 
8f849bf
 
9159ecd
 
 
 
8f849bf
9159ecd
8f849bf
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
import os
import sys
import torch
import base64
import io
import json
from PIL import Image
import svgwrite
import numpy as np
from diffusers import StableDiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer
import random
import math

class EndpointHandler:
    def __init__(self, path=""):
        """Pick a device and load the diffusion pipeline and CLIP text encoder.

        Each of the two loading steps runs in its own try/except. When a step
        fails, the error is printed and the attributes it would have set
        (``pipe``, or ``tokenizer`` and ``text_encoder``) are set to None
        instead of letting the exception propagate.

        Args:
            path: Accepted for the handler interface; not read by this method.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Stable Diffusion pipeline: half precision on GPU, full precision on CPU
        weights_dtype = torch.float16 if self.device == "cuda" else torch.float32
        try:
            pipeline = StableDiffusionPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                torch_dtype=weights_dtype,
                safety_checker=None,
                requires_safety_checker=False,
            )
            self.pipe = pipeline.to(self.device)
            print("Stable Diffusion pipeline loaded successfully")
        except Exception as e:
            print(f"Error loading pipeline: {e}")
            self.pipe = None

        # CLIP tokenizer and text encoder share one checkpoint name
        clip_checkpoint = "openai/clip-vit-base-patch32"
        try:
            self.tokenizer = CLIPTokenizer.from_pretrained(clip_checkpoint)
            self.text_encoder = CLIPTextModel.from_pretrained(clip_checkpoint).to(self.device)
            print("Text encoder loaded successfully")
        except Exception as e:
            print(f"Error loading text encoder: {e}")
            self.tokenizer = None
            self.text_encoder = None

    def __call__(self, data):
        """Generate SVG sketch from text prompt"""
        try:
            # Extract inputs
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})
            
            if isinstance(inputs, dict):
                prompt = inputs.get("prompt", inputs.get("text", ""))
            else:
                prompt = str(inputs)
            
            if not prompt:
                prompt = "a simple sketch"
            
            # Extract parameters
            num_paths = parameters.get("num_paths", 96)
            num_iter = parameters.get("num_iter", 500)
            guidance_scale = parameters.get("guidance_scale", 7.5)
            width = parameters.get("width", 224)
            height = parameters.get("height", 224)
            seed = parameters.get("seed", 42)
            
            # Set seed for reproducibility
            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)
            
            print(f"Generating SVG for prompt: '{prompt}' with {num_paths} paths")
            
            # Generate SVG
            svg_content = self.generate_svg_sketch(
                prompt, num_paths, num_iter, guidance_scale, width, height
            )
            
            # Convert SVG to base64 for transmission
            svg_base64 = base64.b64encode(svg_content.encode('utf-8')).decode('utf-8')
            
            # Return result
            result = {
                "svg": svg_content,
                "svg_base64": svg_base64,
                "prompt": prompt,
                "parameters": {
                    "num_paths": num_paths,
                    "num_iter": num_iter,
                    "guidance_scale": guidance_scale,
                    "width": width,
                    "height": height,
                    "seed": seed
                }
            }
            
            return result
            
        except Exception as e:
            print(f"Error in handler: {e}")
            # Return a simple fallback SVG
            fallback_svg = self.create_fallback_svg(prompt, width, height)
            return {
                "svg": fallback_svg,
                "svg_base64": base64.b64encode(fallback_svg.encode('utf-8')).decode('utf-8'),
                "prompt": prompt,
                "error": str(e)
            }

    def generate_svg_sketch(self, prompt, num_paths, num_iter, guidance_scale, width, height):
        """Generate SVG sketch using simplified DiffSketcher approach"""
        try:
            # Get text embeddings
            text_embeddings = self.get_text_embeddings(prompt)
            
            # Generate attention maps (simplified)
            attention_maps = self.generate_attention_maps(prompt, width, height)
            
            # Initialize SVG paths based on attention
            paths = self.initialize_paths_from_attention(attention_maps, num_paths, width, height)
            
            # Optimize paths (simplified version)
            optimized_paths = self.optimize_paths(paths, text_embeddings, num_iter, guidance_scale)
            
            # Create SVG
            svg_content = self.create_svg_from_paths(optimized_paths, width, height)
            
            return svg_content
            
        except Exception as e:
            print(f"Error in generate_svg_sketch: {e}")
            return self.create_fallback_svg(prompt, width, height)

    def get_text_embeddings(self, prompt):
        """Get text embeddings from CLIP"""
        if self.tokenizer is None or self.text_encoder is None:
            return None
            
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            
            with torch.no_grad():
                embeddings = self.text_encoder(**inputs).last_hidden_state
            
            return embeddings
        except Exception as e:
            print(f"Error getting text embeddings: {e}")
            return None

    def generate_attention_maps(self, prompt, width, height):
        """Generate simplified attention maps"""
        # Create attention maps based on prompt keywords
        attention_map = np.zeros((height, width))
        
        # Simple keyword-based attention
        keywords = prompt.lower().split()
        
        for i, keyword in enumerate(keywords[:5]):  # Limit to 5 keywords
            # Create attention region for each keyword
            center_x = (i + 1) * width // (len(keywords) + 1)
            center_y = height // 2
            
            # Create Gaussian-like attention
            y, x = np.ogrid[:height, :width]
            mask = ((x - center_x) ** 2 + (y - center_y) ** 2) < (min(width, height) // 4) ** 2
            attention_map[mask] += 1.0
        
        # Normalize
        if attention_map.max() > 0:
            attention_map = attention_map / attention_map.max()
        
        return attention_map

    def initialize_paths_from_attention(self, attention_map, num_paths, width, height):
        """Initialize SVG paths based on attention maps"""
        paths = []
        
        # Find high attention regions
        threshold = 0.3
        high_attention = attention_map > threshold
        
        if not np.any(high_attention):
            # Fallback: create random paths
            return self.create_random_paths(num_paths, width, height)
        
        # Get coordinates of high attention regions
        y_coords, x_coords = np.where(high_attention)
        
        for i in range(num_paths):
            if len(x_coords) > 0:
                # Sample random points from high attention regions
                idx = np.random.choice(len(x_coords), size=min(4, len(x_coords)), replace=False)
                path_points = [(x_coords[j], y_coords[j]) for j in idx]
                
                # Sort points to create a reasonable path
                path_points.sort(key=lambda p: p[0])
                
                paths.append(path_points)
            else:
                # Fallback to random path
                paths.append(self.create_single_random_path(width, height))
        
        return paths

    def create_random_paths(self, num_paths, width, height):
        """Create random paths as fallback"""
        paths = []
        for i in range(num_paths):
            paths.append(self.create_single_random_path(width, height))
        return paths

    def create_single_random_path(self, width, height):
        """Create a single random path"""
        num_points = random.randint(3, 6)
        points = []
        for _ in range(num_points):
            x = random.randint(0, width)
            y = random.randint(0, height)
            points.append((x, y))
        return points

    def optimize_paths(self, paths, text_embeddings, num_iter, guidance_scale):
        """Simplified path optimization"""
        # For now, just add some smoothing and variation
        optimized_paths = []
        
        for path in paths:
            if len(path) < 2:
                optimized_paths.append(path)
                continue
                
            # Add some smoothing
            smoothed_path = []
            for i in range(len(path)):
                if i == 0 or i == len(path) - 1:
                    smoothed_path.append(path[i])
                else:
                    # Simple smoothing
                    prev_x, prev_y = path[i-1]
                    curr_x, curr_y = path[i]
                    next_x, next_y = path[i+1]
                    
                    smooth_x = (prev_x + curr_x + next_x) / 3
                    smooth_y = (prev_y + curr_y + next_y) / 3
                    
                    smoothed_path.append((smooth_x, smooth_y))
            
            optimized_paths.append(smoothed_path)
        
        return optimized_paths

    def create_svg_from_paths(self, paths, width, height):
        """Render ``paths`` as polylines on a white canvas and return the SVG text.

        Paths with fewer than two points are skipped. Each drawn path gets a
        random stroke width in [0.5, 3.0] and a random dark colour (every RGB
        channel in 0..100).
        """
        canvas = svgwrite.Drawing(size=(width, height))

        # White background
        canvas.add(canvas.rect(insert=(0, 0), size=(width, height), fill='white'))

        for stroke_points in paths:
            if len(stroke_points) < 2:
                continue

            # "M x,y" for the first point, then one " L x,y" per remaining point
            head, tail = stroke_points[0], stroke_points[1:]
            path_data = f"M {head[0]},{head[1]}" + "".join(
                f" L {pt[0]},{pt[1]}" for pt in tail
            )

            # Width is drawn first, then the three colour channels in order
            thickness = random.uniform(0.5, 3.0)
            red, green, blue = [random.randint(0, 100) for _ in range(3)]

            canvas.add(canvas.path(
                d=path_data,
                stroke=f"rgb({red},{green},{blue})",
                stroke_width=thickness,
                fill='none',
                stroke_linecap='round',
                stroke_linejoin='round'
            ))

        return canvas.tostring()

    def create_fallback_svg(self, prompt, width=224, height=224):
        """Return a simple themed SVG chosen from substrings of ``prompt``.

        The lower-cased prompt is checked, in order, for mountain/landscape,
        house/building and flower/plant; the first match decides the drawing.
        With no match an abstract sketch derived from the prompt is drawn.
        """
        canvas = svgwrite.Drawing(size=(width, height))

        # White background
        canvas.add(canvas.rect(insert=(0, 0), size=(width, height), fill='white'))

        text = prompt.lower()

        # Ordered (keywords, drawer) table; the first matching row wins
        themed_drawers = (
            (('mountain', 'landscape'), self._add_mountain_sketch),
            (('house', 'building'), self._add_house_sketch),
            (('flower', 'plant'), self._add_flower_sketch),
        )
        for keywords, draw in themed_drawers:
            if any(keyword in text for keyword in keywords):
                draw(canvas, width, height)
                break
        else:
            self._add_abstract_sketch(canvas, width, height, prompt)

        return canvas.tostring()

    def _add_mountain_sketch(self, dwg, width, height):
        """Add mountain sketch to SVG"""
        # Mountain outline
        points = [(0, height*0.7)]
        for x in range(0, width, 20):
            y = height * 0.7 + 30 * math.sin(x * 0.02) + 15 * math.sin(x * 0.05)
            points.append((x, y))
        points.append((width, height))
        points.append((0, height))
        
        dwg.add(dwg.polygon(points, fill='lightgray', stroke='black', stroke_width=2))

    def _add_house_sketch(self, dwg, width, height):
        """Add house sketch to SVG"""
        # House base
        house_width = width * 0.6
        house_height = height * 0.4
        house_x = (width - house_width) / 2
        house_y = height * 0.4
        
        dwg.add(dwg.rect(
            insert=(house_x, house_y),
            size=(house_width, house_height),
            fill='lightblue',
            stroke='black',
            stroke_width=2
        ))
        
        # Roof
        roof_points = [
            (house_x, house_y),
            (house_x + house_width/2, house_y - house_height*0.3),
            (house_x + house_width, house_y)
        ]
        dwg.add(dwg.polygon(roof_points, fill='red', stroke='black', stroke_width=2))

    def _add_flower_sketch(self, dwg, width, height):
        """Add flower sketch to SVG"""
        center_x, center_y = width/2, height/2
        
        # Stem
        dwg.add(dwg.line(
            start=(center_x, center_y + 20),
            end=(center_x, height - 20),
            stroke='green',
            stroke_width=4
        ))
        
        # Petals
        for angle in range(0, 360, 45):
            x = center_x + 25 * math.cos(math.radians(angle))
            y = center_y + 25 * math.sin(math.radians(angle))
            dwg.add(dwg.circle(
                center=(x, y),
                r=8,
                fill='pink',
                stroke='red',
                stroke_width=1
            ))
        
        # Center
        dwg.add(dwg.circle(
            center=(center_x, center_y),
            r=8,
            fill='yellow',
            stroke='orange',
            stroke_width=2
        ))

    def _add_abstract_sketch(self, dwg, width, height, prompt):
        """Add abstract sketch to SVG"""
        # Create flowing lines based on prompt hash
        prompt_hash = hash(prompt) % 100
        
        for i in range(8):
            points = []
            start_x = (i * 30 + prompt_hash) % (width - 40) + 20
            start_y = (i * 25 + prompt_hash) % (height - 40) + 20
            
            for j in range(4):
                x = start_x + j * 25 + 15 * math.sin((i + j + prompt_hash) * 0.5)
                y = start_y + j * 20 + 15 * math.cos((i + j + prompt_hash) * 0.3)
                points.append((max(0, min(width, x)), max(0, min(height, y))))
            
            # Create path
            if len(points) > 1:
                path_str = f"M {points[0][0]},{points[0][1]}"
                for point in points[1:]:
                    path_str += f" L {point[0]},{point[1]}"
                
                color_val = (i * 30) % 200 + 50
                dwg.add(dwg.path(
                    d=path_str,
                    stroke=f"rgb({color_val},{color_val//2},{color_val//3})",
                    stroke_width=2,
                    fill='none',
                    stroke_linecap='round'
                ))