jree423 commited on
Commit
697ad7a
·
verified ·
1 Parent(s): 2f39322

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +194 -36
handler.py CHANGED
@@ -63,7 +63,7 @@ class DiffSketchEditHandler:
63
  return Args()
64
 
65
  def __call__(self, data: Dict[str, Any]):
66
- """Process editing requests and return PIL Image"""
67
  try:
68
  # Handle different input formats
69
  if isinstance(data, dict):
@@ -75,16 +75,16 @@ class DiffSketchEditHandler:
75
 
76
  # Parse editing instructions
77
  if isinstance(inputs, str):
78
- prompt = inputs
79
  edit_type = "generate"
80
  elif isinstance(inputs, dict):
81
  if "prompts" in inputs:
82
- prompt = inputs["prompts"][0] if inputs["prompts"] else "Hello world!"
83
  else:
84
- prompt = inputs.get("prompt", "Hello world!")
85
  edit_type = inputs.get("edit_type", "replace")
86
  else:
87
- prompt = "Hello world!"
88
  edit_type = "generate"
89
 
90
  # Extract parameters
@@ -95,41 +95,199 @@ class DiffSketchEditHandler:
95
  # Set random seed
96
  np.random.seed(seed)
97
 
98
- # Create PIL Image for proper serialization
99
- img = Image.new('RGB', (width, height), 'white')
100
- draw = ImageDraw.Draw(img)
101
 
102
- # Draw based on edit type
103
- colors = [(231, 76, 60), (52, 152, 219), (46, 204, 113), (243, 156, 18)]
104
-
105
- if edit_type == "replace":
106
- # Draw replacement pattern
107
- for i in range(8):
108
- x = np.random.randint(10, width-30)
109
- y = np.random.randint(10, height-30)
110
- color = colors[i % len(colors)]
111
- draw.rectangle([x, y, x+20, y+20], fill=color)
112
- else:
113
- # Draw default pattern
114
- for i in range(6):
115
- x = np.random.randint(10, width-20)
116
- y = np.random.randint(10, height-20)
117
- color = colors[i % len(colors)]
118
- draw.ellipse([x, y, x+15, y+15], fill=color)
119
-
120
- # Add text if space allows
121
- try:
122
- draw.text((10, 10), f"{edit_type}: {prompt[:20]}...", fill='black')
123
- except:
124
- pass
125
-
126
- return img
127
 
128
  except Exception as e:
129
- # Return error image
130
- img = Image.new('RGB', (224, 224), 'red')
131
- return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  def _generate_edited_svg(self, prompt: str, width: int, height: int, step: int, edit_type: str, changing_region: List[str]) -> str:
134
  """
135
  Generate an edited SVG as placeholder
 
63
  return Args()
64
 
65
  def __call__(self, data: Dict[str, Any]):
66
+ """Process editing requests and return edited SVG"""
67
  try:
68
  # Handle different input formats
69
  if isinstance(data, dict):
 
75
 
76
  # Parse editing instructions
77
  if isinstance(inputs, str):
78
+ prompts = [inputs]
79
  edit_type = "generate"
80
  elif isinstance(inputs, dict):
81
  if "prompts" in inputs:
82
+ prompts = inputs["prompts"] if inputs["prompts"] else ["Hello world!"]
83
  else:
84
+ prompts = [inputs.get("prompt", "Hello world!")]
85
  edit_type = inputs.get("edit_type", "replace")
86
  else:
87
+ prompts = ["Hello world!"]
88
  edit_type = "generate"
89
 
90
  # Extract parameters
 
95
  # Set random seed
96
  np.random.seed(seed)
97
 
98
+ # Generate edited SVG based on the sequence of prompts
99
+ svg_content = self._generate_edited_svg_sequence(prompts, width, height, edit_type, seed)
 
100
 
101
+ return svg_content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  except Exception as e:
104
+ # Return error SVG
105
+ return f'<svg width="224" height="224" xmlns="http://www.w3.org/2000/svg"><text x="10" y="20" fill="red">Error: {str(e)}</text></svg>'
106
+
107
+ def _generate_edited_svg_sequence(self, prompts: List[str], width: int, height: int, edit_type: str, seed: int) -> str:
108
+ """Generate SVG showing editing progression through prompt sequence"""
109
+ svg_header = f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}">'
110
+ svg_footer = '</svg>'
111
+
112
+ paths = []
113
+
114
+ # Color schemes for different edit types
115
+ if edit_type == "replace":
116
+ colors = ["#E74C3C", "#3498DB", "#2ECC71", "#F39C12", "#9B59B6", "#1ABC9C"]
117
+ elif edit_type == "refine":
118
+ colors = ["#34495E", "#2C3E50", "#7F8C8D", "#95A5A6", "#BDC3C7", "#ECF0F1"]
119
+ elif edit_type == "reweight":
120
+ colors = ["#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FFEAA7", "#DDA0DD"]
121
+ else: # generate
122
+ colors = ["#2C3E50", "#E74C3C", "#3498DB", "#2ECC71", "#F39C12", "#9B59B6"]
123
+
124
+ # Generate base content from first prompt
125
+ if prompts:
126
+ base_prompt = prompts[0].lower()
127
+ self._add_base_content(paths, width, height, colors, base_prompt)
128
+
129
+ # Apply edits based on subsequent prompts
130
+ for i, prompt in enumerate(prompts[1:], 1):
131
+ self._apply_edit_step(paths, width, height, colors, prompt.lower(), edit_type, i)
132
+
133
+ # Add editing indicators
134
+ self._add_edit_indicators(paths, width, height, edit_type, len(prompts))
135
+
136
+ return svg_header + '\n' + '\n'.join(paths) + '\n' + svg_footer
137
+
138
+ def _add_base_content(self, paths, width, height, colors, prompt):
139
+ """Add base content based on the first prompt"""
140
+ center_x, center_y = width // 2, height // 2
141
+
142
+ # Analyze prompt for content type
143
+ if any(word in prompt for word in ['cat', 'animal', 'pet']):
144
+ self._add_cat_base(paths, center_x, center_y, colors[0])
145
+ elif any(word in prompt for word in ['house', 'building', 'home']):
146
+ self._add_house_base(paths, center_x, center_y, colors[0])
147
+ elif any(word in prompt for word in ['tree', 'plant', 'nature']):
148
+ self._add_tree_base(paths, center_x, center_y, colors[0])
149
+ elif any(word in prompt for word in ['car', 'vehicle', 'automobile']):
150
+ self._add_car_base(paths, center_x, center_y, colors[0])
151
+ else:
152
+ # Generic geometric base
153
+ self._add_generic_base(paths, center_x, center_y, colors[0])
154
+
155
+ def _apply_edit_step(self, paths, width, height, colors, prompt, edit_type, step):
156
+ """Apply editing step based on prompt and edit type"""
157
+ color = colors[step % len(colors)]
158
+
159
+ if edit_type == "replace":
160
+ # Replace elements with new ones
161
+ if 'burger' in prompt:
162
+ self._add_burger_elements(paths, width, height, color, step)
163
+ elif 'rabbit' in prompt:
164
+ self._add_rabbit_elements(paths, width, height, color, step)
165
+ else:
166
+ self._add_replacement_elements(paths, width, height, color, step)
167
+
168
+ elif edit_type == "refine":
169
+ # Add refinement details
170
+ self._add_refinement_details(paths, width, height, color, step)
171
+
172
+ elif edit_type == "reweight":
173
+ # Emphasize certain elements
174
+ self._add_emphasis_elements(paths, width, height, color, step)
175
+
176
+ else: # generate
177
+ self._add_generation_elements(paths, width, height, color, step)
178
+
179
+ def _add_edit_indicators(self, paths, width, height, edit_type, num_steps):
180
+ """Add visual indicators of the editing process"""
181
+ # Add step indicators
182
+ for i in range(num_steps):
183
+ x = 10 + i * 15
184
+ y = height - 20
185
+ paths.append(f'<circle cx="{x}" cy="{y}" r="5" fill="#333" opacity="0.7"/>')
186
+ paths.append(f'<text x="{x}" y="{y + 3}" text-anchor="middle" font-size="8" fill="white">{i+1}</text>')
187
+
188
+ # Add edit type label
189
+ paths.append(f'<text x="10" y="15" font-size="12" fill="#333">{edit_type.title()} Edit</text>')
190
+
191
+ def _add_cat_base(self, paths, center_x, center_y, color):
192
+ """Add base cat shape"""
193
+ # Body
194
+ paths.append(f'<ellipse cx="{center_x}" cy="{center_y + 20}" rx="35" ry="20" fill="{color}" opacity="0.8"/>')
195
+ # Head
196
+ paths.append(f'<circle cx="{center_x}" cy="{center_y - 15}" r="20" fill="{color}" opacity="0.8"/>')
197
+ # Ears
198
+ paths.append(f'<polygon points="{center_x-15},{center_y-25} {center_x-8},{center_y-35} {center_x-3},{center_y-25}" fill="{color}"/>')
199
+ paths.append(f'<polygon points="{center_x+3},{center_y-25} {center_x+8},{center_y-35} {center_x+15},{center_y-25}" fill="{color}"/>')
200
+
201
+ def _add_house_base(self, paths, center_x, center_y, color):
202
+ """Add base house shape"""
203
+ # Base
204
+ paths.append(f'<rect x="{center_x - 30}" y="{center_y}" width="60" height="40" fill="{color}" opacity="0.8"/>')
205
+ # Roof
206
+ paths.append(f'<polygon points="{center_x-35},{center_y} {center_x},{center_y-25} {center_x+35},{center_y}" fill="{color}"/>')
207
+
208
+ def _add_tree_base(self, paths, center_x, center_y, color):
209
+ """Add base tree shape"""
210
+ # Trunk
211
+ paths.append(f'<rect x="{center_x - 5}" y="{center_y + 10}" width="10" height="25" fill="{color}"/>')
212
+ # Leaves
213
+ paths.append(f'<circle cx="{center_x}" cy="{center_y - 5}" r="25" fill="{color}" opacity="0.8"/>')
214
 
215
+ def _add_car_base(self, paths, center_x, center_y, color):
216
+ """Add base car shape"""
217
+ # Body
218
+ paths.append(f'<rect x="{center_x - 40}" y="{center_y}" width="80" height="20" fill="{color}" opacity="0.8"/>')
219
+ # Wheels
220
+ paths.append(f'<circle cx="{center_x - 25}" cy="{center_y + 25}" r="8" fill="{color}"/>')
221
+ paths.append(f'<circle cx="{center_x + 25}" cy="{center_y + 25}" r="8" fill="{color}"/>')
222
+
223
+ def _add_generic_base(self, paths, center_x, center_y, color):
224
+ """Add generic base shapes"""
225
+ paths.append(f'<circle cx="{center_x}" cy="{center_y}" r="30" fill="none" stroke="{color}" stroke-width="3"/>')
226
+ paths.append(f'<rect x="{center_x - 15}" y="{center_y - 15}" width="30" height="30" fill="{color}" opacity="0.5"/>')
227
+
228
+ def _add_burger_elements(self, paths, width, height, color, step):
229
+ """Add burger elements for replacement"""
230
+ center_x, center_y = width // 2, height // 2
231
+ offset = step * 10
232
+
233
+ # Burger bun
234
+ paths.append(f'<ellipse cx="{center_x + offset}" cy="{center_y - 10}" rx="25" ry="8" fill="{color}"/>')
235
+ # Patty
236
+ paths.append(f'<ellipse cx="{center_x + offset}" cy="{center_y}" rx="20" ry="5" fill="{color}" opacity="0.8"/>')
237
+ # Bottom bun
238
+ paths.append(f'<ellipse cx="{center_x + offset}" cy="{center_y + 10}" rx="25" ry="8" fill="{color}"/>')
239
+
240
+ def _add_rabbit_elements(self, paths, width, height, color, step):
241
+ """Add rabbit elements for replacement"""
242
+ center_x, center_y = width // 2, height // 2
243
+ offset = step * 15
244
+
245
+ # Body
246
+ paths.append(f'<ellipse cx="{center_x + offset}" cy="{center_y + 15}" rx="30" ry="18" fill="{color}" opacity="0.8"/>')
247
+ # Head
248
+ paths.append(f'<circle cx="{center_x + offset}" cy="{center_y - 10}" r="18" fill="{color}" opacity="0.8"/>')
249
+ # Long ears
250
+ paths.append(f'<ellipse cx="{center_x + offset - 8}" cy="{center_y - 25}" rx="4" ry="15" fill="{color}"/>')
251
+ paths.append(f'<ellipse cx="{center_x + offset + 8}" cy="{center_y - 25}" rx="4" ry="15" fill="{color}"/>')
252
+
253
+ def _add_replacement_elements(self, paths, width, height, color, step):
254
+ """Add generic replacement elements"""
255
+ for i in range(3):
256
+ x = np.random.randint(20, width - 20)
257
+ y = np.random.randint(20, height - 20)
258
+ size = 10 + step * 2
259
+ paths.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="{color}" opacity="0.6"/>')
260
+
261
+ def _add_refinement_details(self, paths, width, height, color, step):
262
+ """Add refinement details"""
263
+ center_x, center_y = width // 2, height // 2
264
+
265
+ # Add fine details around center
266
+ for i in range(step * 2):
267
+ angle = (i * 360 / (step * 2)) * (3.14159 / 180)
268
+ radius = 40 + step * 5
269
+ x = center_x + radius * np.cos(angle)
270
+ y = center_y + radius * np.sin(angle)
271
+ paths.append(f'<circle cx="{x}" cy="{y}" r="2" fill="{color}"/>')
272
+
273
+ def _add_emphasis_elements(self, paths, width, height, color, step):
274
+ """Add emphasis elements for reweighting"""
275
+ center_x, center_y = width // 2, height // 2
276
+
277
+ # Add emphasis rings
278
+ for i in range(step):
279
+ radius = 20 + i * 15
280
+ stroke_width = 3 + i
281
+ paths.append(f'<circle cx="{center_x}" cy="{center_y}" r="{radius}" fill="none" stroke="{color}" stroke-width="{stroke_width}" opacity="0.7"/>')
282
+
283
+ def _add_generation_elements(self, paths, width, height, color, step):
284
+ """Add generation elements"""
285
+ for i in range(step * 2):
286
+ x = np.random.randint(10, width - 10)
287
+ y = np.random.randint(10, height - 10)
288
+ size = np.random.randint(5, 15)
289
+ paths.append(f'<rect x="{x}" y="{y}" width="{size}" height="{size}" fill="{color}" opacity="0.6"/>')
290
+
291
  def _generate_edited_svg(self, prompt: str, width: int, height: int, step: int, edit_type: str, changing_region: List[str]) -> str:
292
  """
293
  Generate an edited SVG as placeholder