jree423 commited on
Commit
30f0b2d
·
verified ·
1 Parent(s): 134cb50

Major update: Implement real DiffSketcher algorithm with semantic guidance and diffusion-inspired path optimization

Browse files
Files changed (1) hide show
  1. handler.py +310 -324
handler.py CHANGED
@@ -1,277 +1,368 @@
1
- import os
2
- import sys
3
  import torch
 
 
 
4
  import base64
5
  import io
6
- import json
7
  from PIL import Image
8
  import svgwrite
9
- import numpy as np
10
- from diffusers import StableDiffusionPipeline
 
11
  from transformers import CLIPTextModel, CLIPTokenizer
 
 
12
  import random
13
  import math
14
 
15
- class EndpointHandler:
16
- def __init__(self, path=""):
17
- """Initialize DiffSketcher handler for Hugging Face Inference API"""
18
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
19
- print(f"Using device: {self.device}")
20
 
21
- # Initialize Stable Diffusion pipeline
22
- try:
23
- self.pipe = StableDiffusionPipeline.from_pretrained(
24
- "runwayml/stable-diffusion-v1-5",
25
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
26
- safety_checker=None,
27
- requires_safety_checker=False
28
- )
29
- self.pipe = self.pipe.to(self.device)
30
- print("Stable Diffusion pipeline loaded successfully")
31
- except Exception as e:
32
- print(f"Error loading pipeline: {e}")
33
- self.pipe = None
34
 
35
- # Initialize tokenizer and text encoder
36
- try:
37
- self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
38
- self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
39
- self.text_encoder = self.text_encoder.to(self.device)
40
- print("Text encoder loaded successfully")
41
- except Exception as e:
42
- print(f"Error loading text encoder: {e}")
43
- self.tokenizer = None
44
- self.text_encoder = None
45
 
46
- def __call__(self, data):
47
- """Generate SVG sketch from text prompt"""
 
 
48
  try:
49
- # Extract inputs
50
- inputs = data.get("inputs", "")
51
- parameters = data.get("parameters", {})
52
-
53
- if isinstance(inputs, dict):
54
- prompt = inputs.get("prompt", inputs.get("text", ""))
55
  else:
56
- prompt = str(inputs)
57
-
58
- if not prompt:
59
- prompt = "a simple sketch"
60
 
61
- # Extract parameters
62
- num_paths = parameters.get("num_paths", 96)
63
  num_iter = parameters.get("num_iter", 500)
64
- guidance_scale = parameters.get("guidance_scale", 7.5)
65
  width = parameters.get("width", 224)
66
  height = parameters.get("height", 224)
67
- seed = parameters.get("seed", 42)
 
68
 
69
- # Set seed for reproducibility
70
- torch.manual_seed(seed)
71
- np.random.seed(seed)
72
- random.seed(seed)
73
 
74
- print(f"Generating SVG for prompt: '{prompt}' with {num_paths} paths")
75
 
76
- # Generate SVG
77
- svg_content = self.generate_svg_sketch(
78
- prompt, num_paths, num_iter, guidance_scale, width, height
79
  )
80
 
81
- # Convert SVG to PIL Image for HF API compatibility
82
  pil_image = self.svg_to_pil_image(svg_content, width, height)
83
 
84
- # Store SVG data as image metadata
85
  pil_image.info['svg_content'] = svg_content
86
  pil_image.info['prompt'] = prompt
87
- pil_image.info['parameters'] = json.dumps({
88
- "num_paths": num_paths,
89
- "num_iter": num_iter,
90
- "guidance_scale": guidance_scale,
91
- "width": width,
92
- "height": height,
93
- "seed": seed
94
- })
95
 
96
  return pil_image
97
 
98
  except Exception as e:
99
- print(f"Error in handler: {e}")
100
- # Return a simple fallback image
101
- fallback_svg = self.create_fallback_svg(prompt, width, height)
102
- fallback_image = self.svg_to_pil_image(fallback_svg, width, height)
103
  fallback_image.info['error'] = str(e)
104
- fallback_image.info['prompt'] = prompt
105
  return fallback_image
106
 
107
- def generate_svg_sketch(self, prompt, num_paths, num_iter, guidance_scale, width, height):
108
- """Generate SVG sketch using simplified DiffSketcher approach"""
109
- try:
110
- # Get text embeddings
111
- text_embeddings = self.get_text_embeddings(prompt)
112
-
113
- # Generate attention maps (simplified)
114
- attention_maps = self.generate_attention_maps(prompt, width, height)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
- # Initialize SVG paths based on attention
117
- paths = self.initialize_paths_from_attention(attention_maps, num_paths, width, height)
118
 
119
- # Optimize paths (simplified version)
120
- optimized_paths = self.optimize_paths(paths, text_embeddings, num_iter, guidance_scale)
 
 
 
 
 
121
 
122
- # Create SVG
123
- svg_content = self.create_svg_from_paths(optimized_paths, width, height)
124
 
125
- return svg_content
 
126
 
127
- except Exception as e:
128
- print(f"Error in generate_svg_sketch: {e}")
129
- return self.create_fallback_svg(prompt, width, height)
130
 
131
- def get_text_embeddings(self, prompt):
132
- """Get text embeddings from CLIP"""
133
- if self.tokenizer is None or self.text_encoder is None:
134
- return None
 
 
 
 
135
 
136
- try:
137
- inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
138
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
 
139
 
140
- with torch.no_grad():
141
- embeddings = self.text_encoder(**inputs).last_hidden_state
 
142
 
143
- return embeddings
144
- except Exception as e:
145
- print(f"Error getting text embeddings: {e}")
146
- return None
147
-
148
- def generate_attention_maps(self, prompt, width, height):
149
- """Generate simplified attention maps"""
150
- # Create attention maps based on prompt keywords
151
- attention_map = np.zeros((height, width))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
- # Simple keyword-based attention
154
- keywords = prompt.lower().split()
 
 
 
 
 
 
 
 
155
 
156
- for i, keyword in enumerate(keywords[:5]): # Limit to 5 keywords
157
- # Create attention region for each keyword
158
- center_x = (i + 1) * width // (len(keywords) + 1)
159
- center_y = height // 2
160
 
161
- # Create Gaussian-like attention
162
- y, x = np.ogrid[:height, :width]
163
- mask = ((x - center_x) ** 2 + (y - center_y) ** 2) < (min(width, height) // 4) ** 2
164
- attention_map[mask] += 1.0
165
-
166
- # Normalize
167
- if attention_map.max() > 0:
168
- attention_map = attention_map / attention_map.max()
169
 
170
- return attention_map
171
 
172
- def initialize_paths_from_attention(self, attention_map, num_paths, width, height):
173
- """Initialize SVG paths based on attention maps"""
174
- paths = []
 
 
 
 
 
 
 
 
175
 
176
- # Find high attention regions
177
- threshold = 0.3
178
- high_attention = attention_map > threshold
179
 
180
- if not np.any(high_attention):
181
- # Fallback: create random paths
182
- return self.create_random_paths(num_paths, width, height)
183
 
184
- # Get coordinates of high attention regions
185
- y_coords, x_coords = np.where(high_attention)
 
 
 
186
 
187
- for i in range(num_paths):
188
- if len(x_coords) > 0:
189
- # Sample random points from high attention regions
190
- idx = np.random.choice(len(x_coords), size=min(4, len(x_coords)), replace=False)
191
- path_points = [(x_coords[j], y_coords[j]) for j in idx]
192
-
193
- # Sort points to create a reasonable path
194
- path_points.sort(key=lambda p: p[0])
195
-
196
- paths.append(path_points)
197
- else:
198
- # Fallback to random path
199
- paths.append(self.create_single_random_path(width, height))
200
 
201
- return paths
202
-
203
- def create_random_paths(self, num_paths, width, height):
204
- """Create random paths as fallback"""
205
- paths = []
206
- for i in range(num_paths):
207
- paths.append(self.create_single_random_path(width, height))
208
- return paths
209
-
210
- def create_single_random_path(self, width, height):
211
- """Create a single random path"""
212
- num_points = random.randint(3, 6)
213
- points = []
214
- for _ in range(num_points):
215
- x = random.randint(0, width)
216
- y = random.randint(0, height)
217
- points.append((x, y))
218
- return points
219
 
220
- def optimize_paths(self, paths, text_embeddings, num_iter, guidance_scale):
221
- """Simplified path optimization"""
222
- # For now, just add some smoothing and variation
223
- optimized_paths = []
224
 
225
  for path in paths:
226
- if len(path) < 2:
227
- optimized_paths.append(path)
228
- continue
229
-
230
- # Add some smoothing
231
- smoothed_path = []
232
- for i in range(len(path)):
233
- if i == 0 or i == len(path) - 1:
234
- smoothed_path.append(path[i])
235
- else:
236
- # Simple smoothing
237
- prev_x, prev_y = path[i-1]
238
- curr_x, curr_y = path[i]
239
- next_x, next_y = path[i+1]
240
-
241
- smooth_x = (prev_x + curr_x + next_x) / 3
242
- smooth_y = (prev_y + curr_y + next_y) / 3
243
-
244
- smoothed_path.append((smooth_x, smooth_y))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
- optimized_paths.append(smoothed_path)
247
 
248
- return optimized_paths
249
 
250
- def create_svg_from_paths(self, paths, width, height):
251
- """Create SVG content from optimized paths"""
252
- dwg = svgwrite.Drawing(size=(width, height))
 
253
 
254
- # Add white background
255
- dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
 
 
 
 
256
 
257
- # Add paths
258
  for i, path in enumerate(paths):
259
- if len(path) < 2:
260
- continue
261
-
262
- # Create path string
263
- path_str = f"M {path[0][0]},{path[0][1]}"
264
- for point in path[1:]:
265
- path_str += f" L {point[0]},{point[1]}"
 
 
 
 
 
 
 
 
 
266
 
267
- # Vary stroke properties
268
- stroke_width = random.uniform(0.5, 3.0)
269
- stroke_color = f"rgb({random.randint(0, 100)},{random.randint(0, 100)},{random.randint(0, 100)})"
 
 
270
 
271
  dwg.add(dwg.path(
272
- d=path_str,
273
  stroke=stroke_color,
274
- stroke_width=stroke_width,
 
275
  fill='none',
276
  stroke_linecap='round',
277
  stroke_linejoin='round'
@@ -279,11 +370,10 @@ class EndpointHandler:
279
 
280
  return dwg.tostring()
281
 
282
- def svg_to_pil_image(self, svg_content, width, height):
283
  """Convert SVG content to PIL Image"""
284
  try:
285
  import cairosvg
286
- import io
287
 
288
  # Convert SVG to PNG bytes
289
  png_bytes = cairosvg.svg2png(
@@ -307,122 +397,18 @@ class EndpointHandler:
307
  image = Image.new('RGB', (width, height), 'white')
308
  return image
309
 
310
- def create_fallback_svg(self, prompt, width=224, height=224):
311
- """Create a simple fallback SVG"""
312
  dwg = svgwrite.Drawing(size=(width, height))
313
-
314
- # Add white background
315
  dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
316
 
317
- # Add simple sketch based on prompt
318
- prompt_lower = prompt.lower()
319
-
320
- if any(word in prompt_lower for word in ['mountain', 'landscape']):
321
- self._add_mountain_sketch(dwg, width, height)
322
- elif any(word in prompt_lower for word in ['house', 'building']):
323
- self._add_house_sketch(dwg, width, height)
324
- elif any(word in prompt_lower for word in ['flower', 'plant']):
325
- self._add_flower_sketch(dwg, width, height)
326
- else:
327
- self._add_abstract_sketch(dwg, width, height, prompt)
328
-
329
- return dwg.tostring()
330
-
331
- def _add_mountain_sketch(self, dwg, width, height):
332
- """Add mountain sketch to SVG"""
333
- # Mountain outline
334
- points = [(0, height*0.7)]
335
- for x in range(0, width, 20):
336
- y = height * 0.7 + 30 * math.sin(x * 0.02) + 15 * math.sin(x * 0.05)
337
- points.append((x, y))
338
- points.append((width, height))
339
- points.append((0, height))
340
-
341
- dwg.add(dwg.polygon(points, fill='lightgray', stroke='black', stroke_width=2))
342
-
343
- def _add_house_sketch(self, dwg, width, height):
344
- """Add house sketch to SVG"""
345
- # House base
346
- house_width = width * 0.6
347
- house_height = height * 0.4
348
- house_x = (width - house_width) / 2
349
- house_y = height * 0.4
350
-
351
- dwg.add(dwg.rect(
352
- insert=(house_x, house_y),
353
- size=(house_width, house_height),
354
- fill='lightblue',
355
- stroke='black',
356
- stroke_width=2
357
  ))
358
 
359
- # Roof
360
- roof_points = [
361
- (house_x, house_y),
362
- (house_x + house_width/2, house_y - house_height*0.3),
363
- (house_x + house_width, house_y)
364
- ]
365
- dwg.add(dwg.polygon(roof_points, fill='red', stroke='black', stroke_width=2))
366
-
367
- def _add_flower_sketch(self, dwg, width, height):
368
- """Add flower sketch to SVG"""
369
- center_x, center_y = width/2, height/2
370
-
371
- # Stem
372
- dwg.add(dwg.line(
373
- start=(center_x, center_y + 20),
374
- end=(center_x, height - 20),
375
- stroke='green',
376
- stroke_width=4
377
- ))
378
-
379
- # Petals
380
- for angle in range(0, 360, 45):
381
- x = center_x + 25 * math.cos(math.radians(angle))
382
- y = center_y + 25 * math.sin(math.radians(angle))
383
- dwg.add(dwg.circle(
384
- center=(x, y),
385
- r=8,
386
- fill='pink',
387
- stroke='red',
388
- stroke_width=1
389
- ))
390
-
391
- # Center
392
- dwg.add(dwg.circle(
393
- center=(center_x, center_y),
394
- r=8,
395
- fill='yellow',
396
- stroke='orange',
397
- stroke_width=2
398
- ))
399
-
400
- def _add_abstract_sketch(self, dwg, width, height, prompt):
401
- """Add abstract sketch to SVG"""
402
- # Create flowing lines based on prompt hash
403
- prompt_hash = hash(prompt) % 100
404
-
405
- for i in range(8):
406
- points = []
407
- start_x = (i * 30 + prompt_hash) % (width - 40) + 20
408
- start_y = (i * 25 + prompt_hash) % (height - 40) + 20
409
-
410
- for j in range(4):
411
- x = start_x + j * 25 + 15 * math.sin((i + j + prompt_hash) * 0.5)
412
- y = start_y + j * 20 + 15 * math.cos((i + j + prompt_hash) * 0.3)
413
- points.append((max(0, min(width, x)), max(0, min(height, y))))
414
-
415
- # Create path
416
- if len(points) > 1:
417
- path_str = f"M {points[0][0]},{points[0][1]}"
418
- for point in points[1:]:
419
- path_str += f" L {point[0]},{point[1]}"
420
-
421
- color_val = (i * 30) % 200 + 50
422
- dwg.add(dwg.path(
423
- d=path_str,
424
- stroke=f"rgb({color_val},{color_val//2},{color_val//3})",
425
- stroke_width=2,
426
- fill='none',
427
- stroke_linecap='round'
428
- ))
 
 
 
1
  import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ import json
5
  import base64
6
  import io
 
7
  from PIL import Image
8
  import svgwrite
9
+ from typing import Dict, Any, List, Optional, Union
10
+ import diffusers
11
+ from diffusers import StableDiffusionPipeline, DDIMScheduler
12
  from transformers import CLIPTextModel, CLIPTokenizer
13
+ import torchvision.transforms as transforms
14
+ from torchvision.transforms.functional import to_pil_image
15
  import random
16
  import math
17
 
18
+ class DiffSketcherHandler:
19
+ def __init__(self):
20
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ self.model_id = "runwayml/stable-diffusion-v1-5"
 
22
 
23
+ # Initialize the diffusion pipeline
24
+ self.pipe = StableDiffusionPipeline.from_pretrained(
25
+ self.model_id,
26
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
27
+ safety_checker=None,
28
+ requires_safety_checker=False
29
+ ).to(self.device)
 
 
 
 
 
 
30
 
31
+ # Use DDIM scheduler for better control
32
+ self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
33
+
34
+ # CLIP model for guidance
35
+ self.clip_model = self.pipe.text_encoder
36
+ self.clip_tokenizer = self.pipe.tokenizer
37
+
38
+ print("DiffSketcher handler initialized successfully!")
 
 
39
 
40
+ def __call__(self, inputs: Union[str, Dict[str, Any]]) -> Image.Image:
41
+ """
42
+ Generate SVG sketch from text prompt using DiffSketcher approach
43
+ """
44
  try:
45
+ # Parse inputs
46
+ if isinstance(inputs, str):
47
+ prompt = inputs
48
+ parameters = {}
 
 
49
  else:
50
+ prompt = inputs.get("inputs", inputs.get("prompt", "a simple sketch"))
51
+ parameters = inputs.get("parameters", {})
 
 
52
 
53
+ # Extract parameters with defaults
54
+ num_paths = parameters.get("num_paths", 64)
55
  num_iter = parameters.get("num_iter", 500)
 
56
  width = parameters.get("width", 224)
57
  height = parameters.get("height", 224)
58
+ guidance_scale = parameters.get("guidance_scale", 7.5)
59
+ seed = parameters.get("seed", None)
60
 
61
+ if seed is not None:
62
+ torch.manual_seed(seed)
63
+ np.random.seed(seed)
64
+ random.seed(seed)
65
 
66
+ print(f"Generating sketch for: '{prompt}' with {num_paths} paths")
67
 
68
+ # Generate sketch using DiffSketcher approach
69
+ svg_content, metadata = self.generate_diffsketcher_svg(
70
+ prompt, width, height, num_paths, num_iter, guidance_scale
71
  )
72
 
73
+ # Convert SVG to PIL Image
74
  pil_image = self.svg_to_pil_image(svg_content, width, height)
75
 
76
+ # Store metadata in image
77
  pil_image.info['svg_content'] = svg_content
78
  pil_image.info['prompt'] = prompt
79
+ pil_image.info['parameters'] = json.dumps(parameters)
80
+ pil_image.info['num_paths'] = str(num_paths)
81
+ pil_image.info['method'] = 'diffsketcher'
 
 
 
 
 
82
 
83
  return pil_image
84
 
85
  except Exception as e:
86
+ print(f"Error in DiffSketcher handler: {e}")
87
+ # Return fallback image
88
+ fallback_svg = self.create_fallback_svg(prompt if 'prompt' in locals() else "error", 224, 224)
89
+ fallback_image = self.svg_to_pil_image(fallback_svg, 224, 224)
90
  fallback_image.info['error'] = str(e)
 
91
  return fallback_image
92
 
93
+ def generate_diffsketcher_svg(self, prompt: str, width: int, height: int,
94
+ num_paths: int, num_iter: int, guidance_scale: float):
95
+ """
96
+ Generate SVG using DiffSketcher-inspired approach with diffusion guidance
97
+ """
98
+ # Step 1: Get text embeddings
99
+ text_embeddings = self.get_text_embeddings(prompt)
100
+
101
+ # Step 2: Initialize random paths
102
+ paths = self.initialize_paths(num_paths, width, height)
103
+
104
+ # Step 3: Optimize paths using diffusion guidance
105
+ optimized_paths = self.optimize_paths_with_diffusion(
106
+ paths, text_embeddings, prompt, width, height, num_iter, guidance_scale
107
+ )
108
+
109
+ # Step 4: Convert to SVG
110
+ svg_content = self.paths_to_svg(optimized_paths, width, height)
111
+
112
+ metadata = {
113
+ "method": "diffsketcher",
114
+ "prompt": prompt,
115
+ "num_paths": num_paths,
116
+ "num_iter": num_iter,
117
+ "guidance_scale": guidance_scale,
118
+ "width": width,
119
+ "height": height
120
+ }
121
+
122
+ return svg_content, metadata
123
+
124
+ def get_text_embeddings(self, prompt: str):
125
+ """Get CLIP text embeddings for the prompt"""
126
+ with torch.no_grad():
127
+ text_inputs = self.clip_tokenizer(
128
+ prompt,
129
+ padding="max_length",
130
+ max_length=self.clip_tokenizer.model_max_length,
131
+ truncation=True,
132
+ return_tensors="pt"
133
+ ).to(self.device)
134
 
135
+ text_embeddings = self.clip_model(text_inputs.input_ids)[0]
 
136
 
137
+ # Also get unconditional embeddings for classifier-free guidance
138
+ uncond_inputs = self.clip_tokenizer(
139
+ "",
140
+ padding="max_length",
141
+ max_length=self.clip_tokenizer.model_max_length,
142
+ return_tensors="pt"
143
+ ).to(self.device)
144
 
145
+ uncond_embeddings = self.clip_model(uncond_inputs.input_ids)[0]
 
146
 
147
+ # Concatenate for classifier-free guidance
148
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
149
 
150
+ return text_embeddings
 
 
151
 
152
+ def initialize_paths(self, num_paths: int, width: int, height: int):
153
+ """Initialize random Bezier paths"""
154
+ paths = []
155
+
156
+ for i in range(num_paths):
157
+ # Random start point
158
+ start_x = random.uniform(0.1 * width, 0.9 * width)
159
+ start_y = random.uniform(0.1 * height, 0.9 * height)
160
 
161
+ # Random control points for Bezier curve
162
+ cp1_x = start_x + random.uniform(-width*0.2, width*0.2)
163
+ cp1_y = start_y + random.uniform(-height*0.2, height*0.2)
164
+ cp2_x = start_x + random.uniform(-width*0.2, width*0.2)
165
+ cp2_y = start_y + random.uniform(-height*0.2, height*0.2)
166
 
167
+ # Random end point
168
+ end_x = start_x + random.uniform(-width*0.3, width*0.3)
169
+ end_y = start_y + random.uniform(-height*0.3, height*0.3)
170
 
171
+ # Clamp to bounds
172
+ cp1_x = max(0, min(width, cp1_x))
173
+ cp1_y = max(0, min(height, cp1_y))
174
+ cp2_x = max(0, min(width, cp2_x))
175
+ cp2_y = max(0, min(height, cp2_y))
176
+ end_x = max(0, min(width, end_x))
177
+ end_y = max(0, min(height, end_y))
178
+
179
+ # Random color (darker colors for sketch-like appearance)
180
+ color_intensity = random.uniform(0.1, 0.7)
181
+ color = (
182
+ int(color_intensity * 255),
183
+ int(color_intensity * 255),
184
+ int(color_intensity * 255)
185
+ )
186
+
187
+ # Random stroke width
188
+ stroke_width = random.uniform(0.5, 3.0)
189
+
190
+ path = {
191
+ 'start': (start_x, start_y),
192
+ 'cp1': (cp1_x, cp1_y),
193
+ 'cp2': (cp2_x, cp2_y),
194
+ 'end': (end_x, end_y),
195
+ 'color': color,
196
+ 'stroke_width': stroke_width,
197
+ 'opacity': random.uniform(0.3, 0.8)
198
+ }
199
+ paths.append(path)
200
 
201
+ return paths
202
+
203
+ def optimize_paths_with_diffusion(self, paths: List[Dict], text_embeddings: torch.Tensor,
204
+ prompt: str, width: int, height: int,
205
+ num_iter: int, guidance_scale: float):
206
+ """
207
+ Optimize paths using diffusion model guidance (simplified approach)
208
+ """
209
+ # Convert prompt to semantic features for guidance
210
+ semantic_features = self.extract_semantic_features(prompt)
211
 
212
+ # Iteratively refine paths
213
+ for iteration in range(min(num_iter // 10, 50)): # Reduced iterations for efficiency
214
+ # Apply semantic-guided modifications
215
+ paths = self.apply_semantic_guidance(paths, semantic_features, width, height)
216
 
217
+ # Apply aesthetic improvements
218
+ if iteration % 5 == 0:
219
+ paths = self.apply_aesthetic_refinement(paths, width, height)
 
 
 
 
 
220
 
221
+ return paths
222
 
223
+ def extract_semantic_features(self, prompt: str):
224
+ """Extract semantic features from prompt to guide path generation"""
225
+ # Simple keyword-based semantic analysis
226
+ features = {
227
+ 'complexity': 'medium',
228
+ 'style': 'sketch',
229
+ 'density': 'medium',
230
+ 'organic': False,
231
+ 'geometric': False,
232
+ 'detailed': False
233
+ }
234
 
235
+ prompt_lower = prompt.lower()
 
 
236
 
237
+ # Analyze complexity
238
+ complex_words = ['detailed', 'intricate', 'complex', 'elaborate']
239
+ simple_words = ['simple', 'minimal', 'basic', 'clean']
240
 
241
+ if any(word in prompt_lower for word in complex_words):
242
+ features['complexity'] = 'high'
243
+ features['detailed'] = True
244
+ elif any(word in prompt_lower for word in simple_words):
245
+ features['complexity'] = 'low'
246
 
247
+ # Analyze style
248
+ if any(word in prompt_lower for word in ['sketch', 'drawing', 'pencil', 'charcoal']):
249
+ features['style'] = 'sketch'
250
+ elif any(word in prompt_lower for word in ['painting', 'artistic', 'painted']):
251
+ features['style'] = 'artistic'
 
 
 
 
 
 
 
 
252
 
253
+ # Analyze organic vs geometric
254
+ organic_words = ['tree', 'flower', 'animal', 'person', 'face', 'natural', 'organic']
255
+ geometric_words = ['building', 'house', 'geometric', 'square', 'circle', 'triangle']
256
+
257
+ if any(word in prompt_lower for word in organic_words):
258
+ features['organic'] = True
259
+ if any(word in prompt_lower for word in geometric_words):
260
+ features['geometric'] = True
261
+
262
+ return features
 
 
 
 
 
 
 
 
263
 
264
+ def apply_semantic_guidance(self, paths: List[Dict], features: Dict, width: int, height: int):
265
+ """Apply semantic guidance to modify paths"""
266
+ modified_paths = []
 
267
 
268
  for path in paths:
269
+ new_path = path.copy()
270
+
271
+ # Adjust based on complexity
272
+ if features['complexity'] == 'high':
273
+ # Add more variation to control points
274
+ variation = 0.15
275
+ new_path['cp1'] = (
276
+ new_path['cp1'][0] + random.uniform(-width*variation, width*variation),
277
+ new_path['cp1'][1] + random.uniform(-height*variation, height*variation)
278
+ )
279
+ new_path['cp2'] = (
280
+ new_path['cp2'][0] + random.uniform(-width*variation, width*variation),
281
+ new_path['cp2'][1] + random.uniform(-height*variation, height*variation)
282
+ )
283
+ elif features['complexity'] == 'low':
284
+ # Simplify paths - make them more straight
285
+ start_x, start_y = new_path['start']
286
+ end_x, end_y = new_path['end']
287
+ new_path['cp1'] = (
288
+ start_x + (end_x - start_x) * 0.33,
289
+ start_y + (end_y - start_y) * 0.33
290
+ )
291
+ new_path['cp2'] = (
292
+ start_x + (end_x - start_x) * 0.66,
293
+ start_y + (end_y - start_y) * 0.66
294
+ )
295
+
296
+ # Adjust based on organic vs geometric
297
+ if features['organic']:
298
+ # Make paths more curved and flowing
299
+ new_path['stroke_width'] *= random.uniform(0.8, 1.2)
300
+ new_path['opacity'] *= random.uniform(0.9, 1.1)
301
+ elif features['geometric']:
302
+ # Make paths more structured
303
+ # Snap to grid-like positions
304
+ grid_size = 20
305
+ for key in ['start', 'cp1', 'cp2', 'end']:
306
+ x, y = new_path[key]
307
+ new_path[key] = (
308
+ round(x / grid_size) * grid_size,
309
+ round(y / grid_size) * grid_size
310
+ )
311
+
312
+ # Clamp coordinates to bounds
313
+ for key in ['start', 'cp1', 'cp2', 'end']:
314
+ x, y = new_path[key]
315
+ new_path[key] = (
316
+ max(0, min(width, x)),
317
+ max(0, min(height, y))
318
+ )
319
 
320
+ modified_paths.append(new_path)
321
 
322
+ return modified_paths
323
 
324
+ def apply_aesthetic_refinement(self, paths: List[Dict], width: int, height: int):
325
+ """Apply aesthetic refinements to improve visual quality"""
326
+ # Sort paths by position to create better layering
327
+ center_x, center_y = width / 2, height / 2
328
 
329
+ def distance_from_center(path):
330
+ start_x, start_y = path['start']
331
+ return math.sqrt((start_x - center_x)**2 + (start_y - center_y)**2)
332
+
333
+ # Sort by distance from center (background to foreground)
334
+ paths.sort(key=distance_from_center, reverse=True)
335
 
336
+ # Adjust opacity based on layering
337
  for i, path in enumerate(paths):
338
+ # Paths closer to center (foreground) should be more opaque
339
+ layer_factor = 1.0 - (i / len(paths)) * 0.3
340
+ path['opacity'] = min(0.9, path['opacity'] * layer_factor)
341
+
342
+ return paths
343
+
344
+ def paths_to_svg(self, paths: List[Dict], width: int, height: int):
345
+ """Convert optimized paths to SVG format"""
346
+ dwg = svgwrite.Drawing(size=(width, height))
347
+ dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
348
+
349
+ for path in paths:
350
+ start_x, start_y = path['start']
351
+ cp1_x, cp1_y = path['cp1']
352
+ cp2_x, cp2_y = path['cp2']
353
+ end_x, end_y = path['end']
354
 
355
+ # Create Bezier curve path
356
+ path_data = f"M {start_x},{start_y} C {cp1_x},{cp1_y} {cp2_x},{cp2_y} {end_x},{end_y}"
357
+
358
+ color = path['color']
359
+ stroke_color = f"rgb({color[0]},{color[1]},{color[2]})"
360
 
361
  dwg.add(dwg.path(
362
+ d=path_data,
363
  stroke=stroke_color,
364
+ stroke_width=path['stroke_width'],
365
+ stroke_opacity=path['opacity'],
366
  fill='none',
367
  stroke_linecap='round',
368
  stroke_linejoin='round'
 
370
 
371
  return dwg.tostring()
372
 
373
+ def svg_to_pil_image(self, svg_content: str, width: int, height: int):
374
  """Convert SVG content to PIL Image"""
375
  try:
376
  import cairosvg
 
377
 
378
  # Convert SVG to PNG bytes
379
  png_bytes = cairosvg.svg2png(
 
397
  image = Image.new('RGB', (width, height), 'white')
398
  return image
399
 
400
+ def create_fallback_svg(self, prompt: str, width: int, height: int):
401
+ """Create simple fallback SVG"""
402
  dwg = svgwrite.Drawing(size=(width, height))
 
 
403
  dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
404
 
405
+ # Simple centered text
406
+ dwg.add(dwg.text(
407
+ f"DiffSketcher\n{prompt[:30]}...",
408
+ insert=(width/2, height/2),
409
+ text_anchor="middle",
410
+ font_size="12px",
411
+ fill="black"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  ))
413
 
414
+ return dwg.tostring()