jree423 commited on
Commit
b0efdb8
·
verified ·
1 Parent(s): 70410f6

Fix DiffSketchEdit handler to properly implement text-based vector sketch editing

Browse files
Files changed (1) hide show
  1. handler.py +694 -257
handler.py CHANGED
@@ -1,307 +1,744 @@
1
- from PIL import Image, ImageDraw, ImageFilter
2
- import math
 
 
 
 
 
3
  import random
 
 
 
 
4
  import io
5
- import base64
6
 
7
  class EndpointHandler:
8
  def __init__(self, path=""):
9
  """Initialize DiffSketchEdit handler for Hugging Face Inference API"""
10
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def __call__(self, data):
13
- """Edit sketch based on text prompt"""
14
- # Extract prompt
15
- inputs = data.get("inputs", "")
16
- if isinstance(inputs, dict):
17
- prompt = inputs.get("prompt", inputs.get("text", ""))
18
- input_image_data = inputs.get("input_image", None)
19
- edit_type = inputs.get("edit_type", "refine")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  else:
21
- prompt = str(inputs)
22
- input_image_data = data.get("input_image", None)
23
- edit_type = data.get("edit_type", "refine")
 
 
 
 
 
 
24
 
25
- if not prompt:
26
- prompt = "edit sketch"
 
 
 
 
 
 
 
 
 
 
27
 
28
- # Handle input image if provided
29
- input_image = None
30
- if input_image_data:
31
- try:
32
- if isinstance(input_image_data, str):
33
- # Base64 encoded image
34
- img_data = base64.b64decode(input_image_data)
35
- input_image = Image.open(io.BytesIO(img_data)).convert('RGB')
36
- input_image = input_image.resize((224, 224))
37
- except Exception:
38
- input_image = None
39
-
40
- # Generate/edit sketch
41
- image = self.edit_sketch(prompt, input_image, edit_type)
42
-
43
- # Return PIL Image directly for HF Inference API
44
- return image
45
 
46
- def edit_sketch(self, prompt, input_image=None, edit_type="refine"):
47
- """Edit or create a sketch based on the prompt"""
 
 
48
 
49
- if input_image is None:
50
- # Create initial sketch
51
- img = self._create_initial_sketch(prompt)
52
- else:
53
- img = input_image.copy()
54
-
55
- # Apply editing based on type
56
- if edit_type == "refine":
57
- img = self._refine_sketch(img, prompt)
58
- elif edit_type == "style_transfer":
59
- img = self._apply_style_transfer(img, prompt)
60
- elif edit_type == "add_details":
61
- img = self._add_details(img, prompt)
62
- elif edit_type == "color":
63
- img = self._apply_coloring(img, prompt)
64
  else:
65
  # Default refinement
66
- img = self._refine_sketch(img, prompt)
67
 
68
- return img
69
 
70
- def _create_initial_sketch(self, prompt):
71
- """Create initial sketch based on prompt"""
72
- img = Image.new('RGB', (224, 224), 'white')
73
- draw = ImageDraw.Draw(img)
74
 
75
- prompt_lower = prompt.lower()
 
 
 
 
 
 
 
76
 
77
- if any(word in prompt_lower for word in ['house', 'building', 'home']):
78
- self._draw_house_sketch(draw)
79
- elif any(word in prompt_lower for word in ['tree', 'plant', 'nature']):
80
- self._draw_tree_sketch(draw)
81
- elif any(word in prompt_lower for word in ['face', 'portrait', 'person']):
82
- self._draw_face_sketch(draw)
83
- elif any(word in prompt_lower for word in ['car', 'vehicle']):
84
- self._draw_car_sketch(draw)
85
- elif any(word in prompt_lower for word in ['flower', 'bloom']):
86
- self._draw_flower_sketch(draw)
87
- else:
88
- self._draw_abstract_sketch(draw, prompt)
89
 
90
- return img
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- def _draw_house_sketch(self, draw):
93
- """Draw basic house sketch"""
94
- # Base
95
- draw.rectangle([50, 120, 174, 180], outline='black', width=2)
 
 
 
 
 
 
 
 
 
 
 
 
96
  # Roof
97
- draw.polygon([(50, 120), (112, 80), (174, 120)], outline='black', width=2)
 
 
 
 
 
 
98
  # Door
99
- draw.rectangle([100, 150, 124, 180], outline='black', width=2)
100
- # Windows
101
- draw.rectangle([70, 140, 90, 160], outline='black', width=2)
102
- draw.rectangle([134, 140, 154, 160], outline='black', width=2)
 
 
 
 
 
 
 
 
103
 
104
- def _draw_tree_sketch(self, draw):
105
- """Draw basic tree sketch"""
 
 
 
106
  # Trunk
107
- draw.rectangle([105, 140, 119, 200], fill='brown', outline='black', width=2)
 
 
 
 
 
 
 
 
 
108
  # Crown
109
- draw.ellipse([70, 80, 154, 150], outline='green', width=3)
110
- # Branches
111
- for angle in range(0, 360, 60):
112
- x = 112 + 25 * math.cos(math.radians(angle))
113
- y = 115 + 25 * math.sin(math.radians(angle))
114
- draw.line([112, 115, x, y], fill='brown', width=2)
 
 
115
 
116
- def _draw_face_sketch(self, draw):
117
- """Draw basic face sketch"""
118
- center = (112, 112)
119
- # Head outline
120
- draw.ellipse([center[0]-40, center[1]-50, center[0]+40, center[1]+30],
121
- outline='black', width=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  # Eyes
123
- draw.ellipse([center[0]-20, center[1]-20, center[0]-10, center[1]-10],
124
- outline='black', width=2)
125
- draw.ellipse([center[0]+10, center[1]-20, center[0]+20, center[1]-10],
126
- outline='black', width=2)
127
- # Nose
128
- draw.line([center[0], center[1]-5, center[0]-3, center[1]+5], fill='black', width=2)
 
 
 
 
 
 
 
 
129
  # Mouth
130
- draw.arc([center[0]-15, center[1]+5, center[0]+15, center[1]+20], 0, 180, fill='black', width=2)
 
 
 
 
 
 
131
 
132
- def _draw_car_sketch(self, draw):
133
- """Draw basic car sketch"""
134
- # Body
135
- draw.rectangle([50, 120, 174, 160], outline='black', width=2)
136
- # Roof
137
- draw.rectangle([70, 100, 154, 120], outline='black', width=2)
138
- # Wheels
139
- draw.ellipse([60, 150, 80, 170], outline='black', width=2)
140
- draw.ellipse([144, 150, 164, 170], outline='black', width=2)
141
- # Windows
142
- draw.rectangle([75, 105, 100, 115], outline='black', width=1)
143
- draw.rectangle([124, 105, 149, 115], outline='black', width=1)
144
-
145
- def _draw_flower_sketch(self, draw):
146
- """Draw basic flower sketch"""
147
- center = (112, 112)
148
  # Stem
149
- draw.line([center[0], center[1]+20, center[0], 200], fill='green', width=4)
 
 
 
 
 
 
150
  # Petals
 
151
  for angle in range(0, 360, 45):
152
- x = center[0] + 25 * math.cos(math.radians(angle))
153
- y = center[1] + 25 * math.sin(math.radians(angle))
154
- draw.ellipse([x-8, y-15, x+8, y+5], outline='black', width=2)
 
 
 
 
 
 
 
155
  # Center
156
- draw.ellipse([center[0]-8, center[1]-8, center[0]+8, center[1]+8],
157
- outline='black', width=2)
 
 
 
 
 
158
 
159
- def _draw_abstract_sketch(self, draw, prompt):
160
- """Draw abstract sketch"""
161
- prompt_hash = hash(prompt) % 100
 
162
 
163
- for i in range(5):
164
- x1 = (i * 40 + prompt_hash) % 180 + 22
165
- y1 = (i * 30 + prompt_hash) % 160 + 32
166
- x2 = x1 + 40 + (i * 10) % 30
167
- y2 = y1 + 60 + (i * 15) % 40
168
- draw.ellipse([x1, y1, x2, y2], outline='black', width=2)
169
-
170
- def _refine_sketch(self, img, prompt):
171
- """Refine the sketch with enhanced details"""
172
- # Apply sharpening
173
- img = img.filter(ImageFilter.SHARPEN)
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
- draw = ImageDraw.Draw(img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
- # Add refinement details based on prompt
178
- prompt_lower = prompt.lower()
179
 
180
- if "house" in prompt_lower:
181
- # Add roof tiles
182
- for y in range(80, 120, 5):
183
- for x in range(50 + (y-80)//2, 174 - (y-80)//2, 10):
184
- draw.line([x, y, x+5, y], fill='gray', width=1)
185
- # Add door handle
186
- draw.ellipse([120, 164, 122, 166], fill='black')
187
-
188
- elif "tree" in prompt_lower:
189
- # Add leaf details
190
- for i in range(10):
191
- x = 70 + random.randint(0, 84)
192
- y = 80 + random.randint(0, 70)
193
- draw.ellipse([x-2, y-2, x+2, y+2], fill='green')
194
-
195
- elif "face" in prompt_lower:
196
- # Add facial details
197
- center = (112, 112)
198
- # Eyebrows
199
- draw.arc([center[0]-25, center[1]-35, center[0]-5, center[1]-25], 0, 180, fill='black', width=2)
200
- draw.arc([center[0]+5, center[1]-35, center[0]+25, center[1]-25], 0, 180, fill='black', width=2)
201
- # Pupils
202
- draw.ellipse([center[0]-17, center[1]-17, center[0]-13, center[1]-13], fill='black')
203
- draw.ellipse([center[0]+13, center[1]-17, center[0]+17, center[1]-13], fill='black')
204
-
205
- return img
206
 
207
- def _apply_style_transfer(self, img, prompt):
208
- """Apply style transfer effects"""
209
- prompt_lower = prompt.lower()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
- if "cartoon" in prompt_lower:
212
- # Cartoon style - enhance colors and add outlines
213
- img = img.filter(ImageFilter.EDGE_ENHANCE_MORE)
214
- # Increase saturation
215
- from PIL import ImageEnhance
216
- enhancer = ImageEnhance.Color(img)
217
- img = enhancer.enhance(1.5)
218
-
219
- elif "realistic" in prompt_lower:
220
- # Realistic style - add shading and texture
221
- img = img.filter(ImageFilter.GaussianBlur(radius=0.5))
222
- # Add subtle noise for texture
223
- import numpy as np
224
- pixels = np.array(img)
225
- noise = np.random.normal(0, 3, pixels.shape)
226
- pixels = np.clip(pixels + noise, 0, 255).astype(np.uint8)
227
- img = Image.fromarray(pixels)
228
-
229
- elif "watercolor" in prompt_lower:
230
- # Watercolor style - soft edges and bleeding
231
- img = img.filter(ImageFilter.GaussianBlur(radius=1.0))
232
- from PIL import ImageEnhance
233
- enhancer = ImageEnhance.Color(img)
234
- img = enhancer.enhance(0.8)
235
-
236
- return img
237
 
238
- def _add_details(self, img, prompt):
239
- """Add contextual details to the sketch"""
240
- draw = ImageDraw.Draw(img)
241
  prompt_lower = prompt.lower()
242
 
243
- if "landscape" in prompt_lower:
244
- # Add clouds
245
- for i in range(3):
246
- x = 40 + i * 60
247
- y = 30 + i * 10
248
- draw.ellipse([x, y, x+30, y+15], fill='lightgray', outline='gray')
249
- # Add birds
250
- for i in range(2):
251
- x = 60 + i * 80
252
- y = 50 + i * 5
253
- draw.arc([x, y, x+10, y+5], 0, 180, fill='black', width=1)
254
-
255
- elif "portrait" in prompt_lower:
256
- # Add hair details
257
- center = (112, 112)
258
- for i in range(15):
259
- x = center[0] + random.randint(-50, 50)
260
- y = center[1] - 60 + random.randint(0, 30)
261
- draw.line([x, y, x + random.randint(-5, 5), y + random.randint(10, 30)],
262
- fill='brown', width=1)
263
-
264
- elif "building" in prompt_lower:
265
- # Add architectural details
266
- # Window frames
267
- draw.rectangle([69, 139, 91, 161], outline='black', width=1)
268
- draw.rectangle([133, 139, 155, 161], outline='black', width=1)
269
- # Chimney
270
- draw.rectangle([140, 70, 150, 90], outline='black', width=2)
271
-
272
- return img
273
 
274
- def _apply_coloring(self, img, prompt):
275
- """Apply coloring to the sketch"""
276
- # Convert to RGBA for transparency effects
277
- img = img.convert('RGBA')
278
-
279
- # Create color overlay
280
- color_overlay = Image.new('RGBA', img.size, (0, 0, 0, 0))
281
- draw = ImageDraw.Draw(color_overlay)
282
-
283
- prompt_lower = prompt.lower()
284
 
285
- if "sunset" in prompt_lower:
286
- # Sunset colors
287
- draw.rectangle([0, 0, 224, 112], fill=(255, 200, 100, 50))
288
- draw.rectangle([0, 112, 224, 224], fill=(255, 150, 50, 30))
289
-
290
- elif "nature" in prompt_lower:
291
- # Nature colors
292
- draw.rectangle([0, 140, 224, 224], fill=(100, 200, 100, 40)) # Green ground
293
- draw.rectangle([0, 0, 224, 140], fill=(150, 200, 255, 30)) # Blue sky
294
 
295
- elif "warm" in prompt_lower:
296
- # Warm color palette
297
- draw.rectangle([0, 0, 224, 224], fill=(255, 200, 150, 25))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
 
299
- elif "cool" in prompt_lower:
300
- # Cool color palette
301
- draw.rectangle([0, 0, 224, 224], fill=(150, 200, 255, 25))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
 
303
- # Blend with original
304
- img = Image.alpha_composite(img, color_overlay)
305
- img = img.convert('RGB')
 
 
 
 
 
306
 
307
- return img
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import base64
5
+ import json
6
+ import numpy as np
7
+ import svgwrite
8
  import random
9
+ import math
10
+ from diffusers import StableDiffusionPipeline
11
+ from transformers import CLIPTextModel, CLIPTokenizer
12
+ from typing import List, Dict, Any, Tuple
13
  import io
14
+ from PIL import Image
15
 
16
  class EndpointHandler:
17
  def __init__(self, path=""):
18
  """Initialize DiffSketchEdit handler for Hugging Face Inference API"""
19
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ print(f"Using device: {self.device}")
21
+
22
+ # Initialize Stable Diffusion pipeline
23
+ try:
24
+ self.pipe = StableDiffusionPipeline.from_pretrained(
25
+ "runwayml/stable-diffusion-v1-5",
26
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
27
+ safety_checker=None,
28
+ requires_safety_checker=False
29
+ )
30
+ self.pipe = self.pipe.to(self.device)
31
+ print("Stable Diffusion pipeline loaded successfully")
32
+ except Exception as e:
33
+ print(f"Error loading pipeline: {e}")
34
+ self.pipe = None
35
+
36
+ # Initialize tokenizer and text encoder
37
+ try:
38
+ self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
39
+ self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
40
+ self.text_encoder = self.text_encoder.to(self.device)
41
+ print("Text encoder loaded successfully")
42
+ except Exception as e:
43
+ print(f"Error loading text encoder: {e}")
44
+ self.tokenizer = None
45
+ self.text_encoder = None
46
 
47
  def __call__(self, data):
48
+ """Edit vector sketches based on text prompts"""
49
+ try:
50
+ # Extract inputs
51
+ inputs = data.get("inputs", "")
52
+ parameters = data.get("parameters", {})
53
+
54
+ # Handle different input formats
55
+ if isinstance(inputs, dict):
56
+ prompts = inputs.get("prompts", [])
57
+ if not prompts and "prompt" in inputs:
58
+ prompts = [inputs["prompt"]]
59
+ edit_type = inputs.get("edit_type", "refine")
60
+ input_svg = inputs.get("input_svg", None)
61
+ else:
62
+ # Simple string input
63
+ prompts = [str(inputs)]
64
+ edit_type = parameters.get("edit_type", "refine")
65
+ input_svg = parameters.get("input_svg", None)
66
+
67
+ if not prompts:
68
+ prompts = ["a simple sketch"]
69
+
70
+ # Extract parameters
71
+ width = parameters.get("width", 224)
72
+ height = parameters.get("height", 224)
73
+ seed = parameters.get("seed", 42)
74
+
75
+ # Set seed for reproducibility
76
+ torch.manual_seed(seed)
77
+ np.random.seed(seed)
78
+ random.seed(seed)
79
+
80
+ print(f"Processing edit type: '{edit_type}' with prompts: {prompts}")
81
+
82
+ # Process based on edit type
83
+ if edit_type == "replace" and len(prompts) >= 2:
84
+ result = self.word_replacement_edit(prompts[0], prompts[1], width, height, input_svg)
85
+ elif edit_type == "refine":
86
+ result = self.prompt_refinement_edit(prompts[0], width, height, input_svg)
87
+ elif edit_type == "reweight":
88
+ result = self.attention_reweighting_edit(prompts[0], width, height, input_svg)
89
+ elif edit_type == "generate":
90
+ result = self.simple_generation(prompts[0], width, height)
91
+ else:
92
+ # Default to refinement
93
+ result = self.prompt_refinement_edit(prompts[0], width, height, input_svg)
94
+
95
+ return result
96
+
97
+ except Exception as e:
98
+ print(f"Error in handler: {e}")
99
+ # Return fallback result
100
+ fallback_svg = self.create_fallback_svg(prompts[0] if prompts else "error", width, height)
101
+ return {
102
+ "svg": fallback_svg,
103
+ "svg_base64": base64.b64encode(fallback_svg.encode('utf-8')).decode('utf-8'),
104
+ "edit_type": edit_type,
105
+ "error": str(e)
106
+ }
107
+
108
+ def word_replacement_edit(self, source_prompt: str, target_prompt: str, width: int, height: int, input_svg: str = None):
109
+ """Perform word replacement editing"""
110
+ try:
111
+ print(f"Word replacement: '{source_prompt}' -> '{target_prompt}'")
112
+
113
+ # Analyze the difference between prompts
114
+ source_words = set(source_prompt.lower().split())
115
+ target_words = set(target_prompt.lower().split())
116
+
117
+ added_words = target_words - source_words
118
+ removed_words = source_words - target_words
119
+
120
+ print(f"Added words: {added_words}, Removed words: {removed_words}")
121
+
122
+ # Generate base SVG from source prompt
123
+ if input_svg:
124
+ base_svg = input_svg
125
+ else:
126
+ base_svg = self.generate_base_svg(source_prompt, width, height)
127
+
128
+ # Apply word replacement transformations
129
+ edited_svg = self.apply_word_replacement(base_svg, source_prompt, target_prompt, added_words, removed_words, width, height)
130
+
131
+ return {
132
+ "svg": edited_svg,
133
+ "svg_base64": base64.b64encode(edited_svg.encode('utf-8')).decode('utf-8'),
134
+ "edit_type": "replace",
135
+ "source_prompt": source_prompt,
136
+ "target_prompt": target_prompt,
137
+ "added_words": list(added_words),
138
+ "removed_words": list(removed_words)
139
+ }
140
+
141
+ except Exception as e:
142
+ print(f"Error in word_replacement_edit: {e}")
143
+ return self.create_error_result(source_prompt, "replace", str(e), width, height)
144
+
145
+ def prompt_refinement_edit(self, prompt: str, width: int, height: int, input_svg: str = None):
146
+ """Perform prompt refinement editing"""
147
+ try:
148
+ print(f"Prompt refinement for: '{prompt}'")
149
+
150
+ # Generate or use base SVG
151
+ if input_svg:
152
+ base_svg = input_svg
153
+ else:
154
+ base_svg = self.generate_base_svg(prompt, width, height)
155
+
156
+ # Apply refinement based on prompt analysis
157
+ refined_svg = self.apply_refinement(base_svg, prompt, width, height)
158
+
159
+ return {
160
+ "svg": refined_svg,
161
+ "svg_base64": base64.b64encode(refined_svg.encode('utf-8')).decode('utf-8'),
162
+ "edit_type": "refine",
163
+ "prompt": prompt
164
+ }
165
+
166
+ except Exception as e:
167
+ print(f"Error in prompt_refinement_edit: {e}")
168
+ return self.create_error_result(prompt, "refine", str(e), width, height)
169
+
170
+ def attention_reweighting_edit(self, prompt: str, width: int, height: int, input_svg: str = None):
171
+ """Perform attention reweighting editing"""
172
+ try:
173
+ print(f"Attention reweighting for: '{prompt}'")
174
+
175
+ # Parse attention weights from prompt (e.g., "(cat:1.5)" or "[dog:0.8]")
176
+ weighted_prompt, attention_weights = self.parse_attention_weights(prompt)
177
+
178
+ # Generate or use base SVG
179
+ if input_svg:
180
+ base_svg = input_svg
181
+ else:
182
+ base_svg = self.generate_base_svg(weighted_prompt, width, height)
183
+
184
+ # Apply attention reweighting
185
+ reweighted_svg = self.apply_attention_reweighting(base_svg, weighted_prompt, attention_weights, width, height)
186
+
187
+ return {
188
+ "svg": reweighted_svg,
189
+ "svg_base64": base64.b64encode(reweighted_svg.encode('utf-8')).decode('utf-8'),
190
+ "edit_type": "reweight",
191
+ "prompt": prompt,
192
+ "weighted_prompt": weighted_prompt,
193
+ "attention_weights": attention_weights
194
+ }
195
+
196
+ except Exception as e:
197
+ print(f"Error in attention_reweighting_edit: {e}")
198
+ return self.create_error_result(prompt, "reweight", str(e), width, height)
199
+
200
+ def simple_generation(self, prompt: str, width: int, height: int):
201
+ """Perform simple SVG generation"""
202
+ try:
203
+ print(f"Simple generation for: '{prompt}'")
204
+
205
+ svg_content = self.generate_base_svg(prompt, width, height)
206
+
207
+ return {
208
+ "svg": svg_content,
209
+ "svg_base64": base64.b64encode(svg_content.encode('utf-8')).decode('utf-8'),
210
+ "edit_type": "generate",
211
+ "prompt": prompt
212
+ }
213
+
214
+ except Exception as e:
215
+ print(f"Error in simple_generation: {e}")
216
+ return self.create_error_result(prompt, "generate", str(e), width, height)
217
+
218
+ def generate_base_svg(self, prompt: str, width: int, height: int):
219
+ """Generate base SVG from prompt"""
220
+ dwg = svgwrite.Drawing(size=(width, height))
221
+ dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
222
+
223
+ # Analyze prompt to determine content
224
+ prompt_lower = prompt.lower()
225
+
226
+ if any(word in prompt_lower for word in ['house', 'building', 'home']):
227
+ self._add_house_elements(dwg, width, height)
228
+ elif any(word in prompt_lower for word in ['tree', 'forest', 'nature']):
229
+ self._add_tree_elements(dwg, width, height)
230
+ elif any(word in prompt_lower for word in ['car', 'vehicle', 'transport']):
231
+ self._add_car_elements(dwg, width, height)
232
+ elif any(word in prompt_lower for word in ['face', 'person', 'portrait']):
233
+ self._add_face_elements(dwg, width, height)
234
+ elif any(word in prompt_lower for word in ['flower', 'plant', 'garden']):
235
+ self._add_flower_elements(dwg, width, height)
236
+ elif any(word in prompt_lower for word in ['cat', 'dog', 'animal']):
237
+ self._add_animal_elements(dwg, width, height, prompt_lower)
238
  else:
239
+ self._add_abstract_elements(dwg, width, height, prompt)
240
+
241
+ return dwg.tostring()
242
+
243
+ def apply_word_replacement(self, base_svg: str, source_prompt: str, target_prompt: str, added_words: set, removed_words: set, width: int, height: int):
244
+ """Apply word replacement transformations to SVG"""
245
+ # Parse the base SVG and modify based on word changes
246
+ dwg = svgwrite.Drawing(size=(width, height))
247
+ dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
248
 
249
+ # Analyze what needs to change
250
+ for word in added_words:
251
+ if word in ['red', 'blue', 'green', 'yellow', 'purple']:
252
+ self._add_color_elements(dwg, word, width, height)
253
+ elif word in ['big', 'large', 'huge']:
254
+ self._add_size_modifier(dwg, 'large', width, height)
255
+ elif word in ['small', 'tiny', 'little']:
256
+ self._add_size_modifier(dwg, 'small', width, height)
257
+ elif word in ['cat', 'dog', 'bird']:
258
+ self._add_animal_elements(dwg, width, height, word)
259
+ elif word in ['house', 'tree', 'car']:
260
+ self._add_object_elements(dwg, word, width, height)
261
 
262
+ # Apply transformations based on target prompt
263
+ target_lower = target_prompt.lower()
264
+ if any(word in target_lower for word in ['house', 'building']):
265
+ self._add_house_elements(dwg, width, height)
266
+ elif any(word in target_lower for word in ['tree', 'forest']):
267
+ self._add_tree_elements(dwg, width, height)
268
+ elif any(word in target_lower for word in ['car', 'vehicle']):
269
+ self._add_car_elements(dwg, width, height)
270
+
271
+ return dwg.tostring()
 
 
 
 
 
 
 
272
 
273
+ def apply_refinement(self, base_svg: str, prompt: str, width: int, height: int):
274
+ """Apply refinement to existing SVG"""
275
+ dwg = svgwrite.Drawing(size=(width, height))
276
+ dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
277
 
278
+ prompt_lower = prompt.lower()
279
+
280
+ # Add refined details based on prompt
281
+ if 'detailed' in prompt_lower or 'complex' in prompt_lower:
282
+ self._add_detailed_elements(dwg, width, height, prompt)
283
+ elif 'simple' in prompt_lower or 'minimal' in prompt_lower:
284
+ self._add_simple_elements(dwg, width, height, prompt)
 
 
 
 
 
 
 
 
285
  else:
286
  # Default refinement
287
+ self._add_standard_elements(dwg, width, height, prompt)
288
 
289
+ return dwg.tostring()
290
 
291
+ def apply_attention_reweighting(self, base_svg: str, prompt: str, attention_weights: dict, width: int, height: int):
292
+ """Apply attention reweighting to SVG elements"""
293
+ dwg = svgwrite.Drawing(size=(width, height))
294
+ dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
295
 
296
+ # Apply weighted emphasis to different elements
297
+ for word, weight in attention_weights.items():
298
+ if weight > 1.0:
299
+ # Emphasize this element
300
+ self._emphasize_element(dwg, word, weight, width, height)
301
+ elif weight < 1.0:
302
+ # De-emphasize this element
303
+ self._deemphasize_element(dwg, word, weight, width, height)
304
 
305
+ # Add base elements
306
+ self._add_standard_elements(dwg, width, height, prompt)
307
+
308
+ return dwg.tostring()
309
+
310
+ def parse_attention_weights(self, prompt: str) -> Tuple[str, dict]:
311
+ """Parse attention weights from prompt"""
312
+ import re
313
+
314
+ # Pattern for (word:weight) and [word:weight]
315
+ pattern = r'[\(\[]([^:\)\]]+):([0-9\.]+)[\)\]]'
316
+ matches = re.findall(pattern, prompt)
317
 
318
+ attention_weights = {}
319
+ clean_prompt = prompt
320
+
321
+ for word, weight_str in matches:
322
+ try:
323
+ weight = float(weight_str)
324
+ attention_weights[word.strip()] = weight
325
+ # Remove the weight notation from prompt
326
+ clean_prompt = re.sub(rf'[\(\[]{re.escape(word)}:{re.escape(weight_str)}[\)\]]', word, clean_prompt)
327
+ except ValueError:
328
+ continue
329
+
330
+ return clean_prompt.strip(), attention_weights
331
 
332
+ def _add_house_elements(self, dwg, width, height):
333
+ """Add house elements to SVG"""
334
+ house_width = width * 0.6
335
+ house_height = height * 0.4
336
+ house_x = (width - house_width) / 2
337
+ house_y = height * 0.4
338
+
339
+ # House base
340
+ dwg.add(dwg.rect(
341
+ insert=(house_x, house_y),
342
+ size=(house_width, house_height),
343
+ fill='none',
344
+ stroke='black',
345
+ stroke_width=2
346
+ ))
347
+
348
  # Roof
349
+ roof_points = [
350
+ (house_x, house_y),
351
+ (house_x + house_width/2, house_y - house_height*0.3),
352
+ (house_x + house_width, house_y)
353
+ ]
354
+ dwg.add(dwg.polygon(roof_points, fill='none', stroke='black', stroke_width=2))
355
+
356
  # Door
357
+ door_width = house_width * 0.2
358
+ door_height = house_height * 0.6
359
+ door_x = house_x + (house_width - door_width) / 2
360
+ door_y = house_y + house_height - door_height
361
+
362
+ dwg.add(dwg.rect(
363
+ insert=(door_x, door_y),
364
+ size=(door_width, door_height),
365
+ fill='none',
366
+ stroke='black',
367
+ stroke_width=2
368
+ ))
369
 
370
+ def _add_tree_elements(self, dwg, width, height):
371
+ """Add tree elements to SVG"""
372
+ center_x = width / 2
373
+ center_y = height / 2
374
+
375
  # Trunk
376
+ trunk_width = 12
377
+ trunk_height = height * 0.3
378
+ dwg.add(dwg.rect(
379
+ insert=(center_x - trunk_width/2, center_y + 20),
380
+ size=(trunk_width, trunk_height),
381
+ fill='none',
382
+ stroke='black',
383
+ stroke_width=2
384
+ ))
385
+
386
  # Crown
387
+ crown_radius = width * 0.25
388
+ dwg.add(dwg.circle(
389
+ center=(center_x, center_y),
390
+ r=crown_radius,
391
+ fill='none',
392
+ stroke='black',
393
+ stroke_width=2
394
+ ))
395
 
396
+ def _add_car_elements(self, dwg, width, height):
397
+ """Add car elements to SVG"""
398
+ car_width = width * 0.7
399
+ car_height = height * 0.3
400
+ car_x = (width - car_width) / 2
401
+ car_y = (height - car_height) / 2
402
+
403
+ # Car body
404
+ dwg.add(dwg.rect(
405
+ insert=(car_x, car_y),
406
+ size=(car_width, car_height),
407
+ fill='none',
408
+ stroke='black',
409
+ stroke_width=2,
410
+ rx=5
411
+ ))
412
+
413
+ # Wheels
414
+ wheel_radius = car_height * 0.4
415
+ wheel_y = car_y + car_height - wheel_radius/2
416
+
417
+ dwg.add(dwg.circle(
418
+ center=(car_x + car_width * 0.2, wheel_y),
419
+ r=wheel_radius,
420
+ fill='none',
421
+ stroke='black',
422
+ stroke_width=2
423
+ ))
424
+ dwg.add(dwg.circle(
425
+ center=(car_x + car_width * 0.8, wheel_y),
426
+ r=wheel_radius,
427
+ fill='none',
428
+ stroke='black',
429
+ stroke_width=2
430
+ ))
431
+
432
+ def _add_face_elements(self, dwg, width, height):
433
+ """Add face elements to SVG"""
434
+ center_x = width / 2
435
+ center_y = height / 2
436
+ face_radius = min(width, height) * 0.3
437
+
438
+ # Face outline
439
+ dwg.add(dwg.circle(
440
+ center=(center_x, center_y),
441
+ r=face_radius,
442
+ fill='none',
443
+ stroke='black',
444
+ stroke_width=2
445
+ ))
446
+
447
  # Eyes
448
+ eye_offset = face_radius * 0.3
449
+ eye_radius = face_radius * 0.1
450
+
451
+ dwg.add(dwg.circle(
452
+ center=(center_x - eye_offset, center_y - eye_offset),
453
+ r=eye_radius,
454
+ fill='black'
455
+ ))
456
+ dwg.add(dwg.circle(
457
+ center=(center_x + eye_offset, center_y - eye_offset),
458
+ r=eye_radius,
459
+ fill='black'
460
+ ))
461
+
462
  # Mouth
463
+ mouth_y = center_y + face_radius * 0.3
464
+ dwg.add(dwg.path(
465
+ d=f"M {center_x - face_radius*0.3},{mouth_y} Q {center_x},{mouth_y + face_radius*0.2} {center_x + face_radius*0.3},{mouth_y}",
466
+ fill='none',
467
+ stroke='black',
468
+ stroke_width=2
469
+ ))
470
 
471
+ def _add_flower_elements(self, dwg, width, height):
472
+ """Add flower elements to SVG"""
473
+ center_x = width / 2
474
+ center_y = height / 2
475
+
 
 
 
 
 
 
 
 
 
 
 
476
  # Stem
477
+ dwg.add(dwg.line(
478
+ start=(center_x, center_y + 20),
479
+ end=(center_x, height - 20),
480
+ stroke='green',
481
+ stroke_width=4
482
+ ))
483
+
484
  # Petals
485
+ petal_radius = 15
486
  for angle in range(0, 360, 45):
487
+ x = center_x + 25 * math.cos(math.radians(angle))
488
+ y = center_y + 25 * math.sin(math.radians(angle))
489
+ dwg.add(dwg.circle(
490
+ center=(x, y),
491
+ r=petal_radius,
492
+ fill='none',
493
+ stroke='red',
494
+ stroke_width=2
495
+ ))
496
+
497
  # Center
498
+ dwg.add(dwg.circle(
499
+ center=(center_x, center_y),
500
+ r=8,
501
+ fill='yellow',
502
+ stroke='orange',
503
+ stroke_width=2
504
+ ))
505
 
506
+ def _add_animal_elements(self, dwg, width, height, animal_type):
507
+ """Add animal elements to SVG"""
508
+ center_x = width / 2
509
+ center_y = height / 2
510
 
511
+ if 'cat' in animal_type:
512
+ # Cat body
513
+ dwg.add(dwg.ellipse(
514
+ center=(center_x, center_y + 20),
515
+ r=(30, 20),
516
+ fill='none',
517
+ stroke='black',
518
+ stroke_width=2
519
+ ))
520
+
521
+ # Cat head
522
+ dwg.add(dwg.circle(
523
+ center=(center_x, center_y - 20),
524
+ r=25,
525
+ fill='none',
526
+ stroke='black',
527
+ stroke_width=2
528
+ ))
529
+
530
+ # Cat ears
531
+ ear_points1 = [(center_x - 15, center_y - 35), (center_x - 5, center_y - 50), (center_x + 5, center_y - 35)]
532
+ ear_points2 = [(center_x - 5, center_y - 35), (center_x + 5, center_y - 50), (center_x + 15, center_y - 35)]
533
+ dwg.add(dwg.polygon(ear_points1, fill='none', stroke='black', stroke_width=2))
534
+ dwg.add(dwg.polygon(ear_points2, fill='none', stroke='black', stroke_width=2))
535
 
536
+ elif 'dog' in animal_type:
537
+ # Dog body
538
+ dwg.add(dwg.ellipse(
539
+ center=(center_x, center_y + 10),
540
+ r=(40, 25),
541
+ fill='none',
542
+ stroke='black',
543
+ stroke_width=2
544
+ ))
545
+
546
+ # Dog head
547
+ dwg.add(dwg.ellipse(
548
+ center=(center_x, center_y - 25),
549
+ r=(25, 20),
550
+ fill='none',
551
+ stroke='black',
552
+ stroke_width=2
553
+ ))
554
+
555
+ def _add_color_elements(self, dwg, color, width, height):
556
+ """Add color-specific elements"""
557
+ color_map = {
558
+ 'red': '#FF0000',
559
+ 'blue': '#0000FF',
560
+ 'green': '#00FF00',
561
+ 'yellow': '#FFFF00',
562
+ 'purple': '#800080'
563
+ }
564
 
565
+ fill_color = color_map.get(color, '#000000')
 
566
 
567
+ # Add a colored accent element
568
+ dwg.add(dwg.circle(
569
+ center=(width * 0.8, height * 0.2),
570
+ r=15,
571
+ fill=fill_color,
572
+ stroke='black',
573
+ stroke_width=1
574
+ ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
575
 
576
+ def _add_size_modifier(self, dwg, size_type, width, height):
577
+ """Add size modification indicators"""
578
+ if size_type == 'large':
579
+ # Add larger elements
580
+ dwg.add(dwg.rect(
581
+ insert=(10, 10),
582
+ size=(width-20, height-20),
583
+ fill='none',
584
+ stroke='gray',
585
+ stroke_width=3,
586
+ stroke_dasharray='5,5'
587
+ ))
588
+ elif size_type == 'small':
589
+ # Add smaller elements
590
+ dwg.add(dwg.rect(
591
+ insert=(width*0.3, height*0.3),
592
+ size=(width*0.4, height*0.4),
593
+ fill='none',
594
+ stroke='gray',
595
+ stroke_width=1,
596
+ stroke_dasharray='2,2'
597
+ ))
598
+
599
+ def _add_object_elements(self, dwg, obj_type, width, height):
600
+ """Add specific object elements"""
601
+ if obj_type == 'house':
602
+ self._add_house_elements(dwg, width, height)
603
+ elif obj_type == 'tree':
604
+ self._add_tree_elements(dwg, width, height)
605
+ elif obj_type == 'car':
606
+ self._add_car_elements(dwg, width, height)
607
+
608
+ def _add_detailed_elements(self, dwg, width, height, prompt):
609
+ """Add detailed elements for complex prompts"""
610
+ # Add multiple overlapping shapes for complexity
611
+ for i in range(8):
612
+ x = random.randint(20, width-40)
613
+ y = random.randint(20, height-40)
614
+ size = random.randint(10, 30)
615
+
616
+ shape_type = random.choice(['circle', 'rect', 'polygon'])
617
+
618
+ if shape_type == 'circle':
619
+ dwg.add(dwg.circle(
620
+ center=(x, y),
621
+ r=size,
622
+ fill='none',
623
+ stroke='black',
624
+ stroke_width=1,
625
+ opacity=0.7
626
+ ))
627
+ elif shape_type == 'rect':
628
+ dwg.add(dwg.rect(
629
+ insert=(x-size, y-size),
630
+ size=(size*2, size*2),
631
+ fill='none',
632
+ stroke='black',
633
+ stroke_width=1,
634
+ opacity=0.7
635
+ ))
636
+
637
+ def _add_simple_elements(self, dwg, width, height, prompt):
638
+ """Add simple elements for minimal prompts"""
639
+ # Add just a few basic shapes
640
+ center_x = width / 2
641
+ center_y = height / 2
642
 
643
+ dwg.add(dwg.circle(
644
+ center=(center_x, center_y),
645
+ r=min(width, height) * 0.2,
646
+ fill='none',
647
+ stroke='black',
648
+ stroke_width=2
649
+ ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
650
 
651
+ def _add_standard_elements(self, dwg, width, height, prompt):
652
+ """Add standard elements based on prompt"""
 
653
  prompt_lower = prompt.lower()
654
 
655
+ if any(word in prompt_lower for word in ['house', 'building']):
656
+ self._add_house_elements(dwg, width, height)
657
+ elif any(word in prompt_lower for word in ['tree', 'forest']):
658
+ self._add_tree_elements(dwg, width, height)
659
+ elif any(word in prompt_lower for word in ['car', 'vehicle']):
660
+ self._add_car_elements(dwg, width, height)
661
+ else:
662
+ self._add_abstract_elements(dwg, width, height, prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
663
 
664
+ def _add_abstract_elements(self, dwg, width, height, prompt):
665
+ """Add abstract elements based on prompt"""
666
+ prompt_hash = hash(prompt) % 100
 
 
 
 
 
 
 
667
 
668
+ for i in range(5):
669
+ x = (i * 40 + prompt_hash) % (width - 40) + 20
670
+ y = (i * 35 + prompt_hash) % (height - 40) + 20
671
+ size = 15 + (i * 5) % 20
 
 
 
 
 
672
 
673
+ dwg.add(dwg.circle(
674
+ center=(x, y),
675
+ r=size,
676
+ fill='none',
677
+ stroke='black',
678
+ stroke_width=2,
679
+ opacity=0.8
680
+ ))
681
+
682
+ def _emphasize_element(self, dwg, word, weight, width, height):
683
+ """Emphasize an element based on attention weight"""
684
+ # Make elements larger and more prominent
685
+ scale_factor = weight
686
+ stroke_width = int(2 * scale_factor)
687
+
688
+ if word in ['house', 'building']:
689
+ # Emphasized house
690
+ house_size = min(width, height) * 0.4 * scale_factor
691
+ house_x = (width - house_size) / 2
692
+ house_y = (height - house_size) / 2
693
 
694
+ dwg.add(dwg.rect(
695
+ insert=(house_x, house_y),
696
+ size=(house_size, house_size * 0.8),
697
+ fill='none',
698
+ stroke='red',
699
+ stroke_width=stroke_width
700
+ ))
701
+
702
+ def _deemphasize_element(self, dwg, word, weight, width, height):
703
+ """De-emphasize an element based on attention weight"""
704
+ # Make elements smaller and less prominent
705
+ scale_factor = weight
706
+ stroke_width = max(1, int(2 * scale_factor))
707
+
708
+ if word in ['background', 'sky']:
709
+ # De-emphasized background elements
710
+ dwg.add(dwg.rect(
711
+ insert=(0, 0),
712
+ size=(width, height * 0.3),
713
+ fill='none',
714
+ stroke='lightgray',
715
+ stroke_width=stroke_width,
716
+ opacity=scale_factor
717
+ ))
718
+
719
+ def create_error_result(self, prompt: str, edit_type: str, error: str, width: int, height: int):
720
+ """Create error result with fallback SVG"""
721
+ fallback_svg = self.create_fallback_svg(prompt, width, height)
722
+ return {
723
+ "svg": fallback_svg,
724
+ "svg_base64": base64.b64encode(fallback_svg.encode('utf-8')).decode('utf-8'),
725
+ "edit_type": edit_type,
726
+ "prompt": prompt,
727
+ "error": error
728
+ }
729
+
730
+ def create_fallback_svg(self, prompt: str, width: int, height: int):
731
+ """Create simple fallback SVG"""
732
+ dwg = svgwrite.Drawing(size=(width, height))
733
+ dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
734
 
735
+ # Simple centered text
736
+ dwg.add(dwg.text(
737
+ f"DiffSketchEdit\n{prompt[:20]}...",
738
+ insert=(width/2, height/2),
739
+ text_anchor="middle",
740
+ font_size="14",
741
+ fill="black"
742
+ ))
743
 
744
+ return dwg.tostring()