jree423 commited on
Commit
42940f8
·
verified ·
1 Parent(s): c884f93

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +186 -47
handler.py CHANGED
@@ -67,7 +67,7 @@ class SVGDreamerHandler:
67
  return Args()
68
 
69
  def __call__(self, data: Dict[str, Any]):
70
- """Process input and return PIL Image"""
71
  try:
72
  # Extract inputs
73
  if isinstance(data, dict):
@@ -81,6 +81,7 @@ class SVGDreamerHandler:
81
  prompt = "a simple drawing"
82
 
83
  # Extract parameters
 
84
  width = parameters.get("width", 224)
85
  height = parameters.get("height", 224)
86
  seed = parameters.get("seed", 42)
@@ -89,57 +90,195 @@ class SVGDreamerHandler:
89
  # Set random seed
90
  np.random.seed(seed)
91
 
92
- # Create PIL Image for proper serialization
93
- from PIL import Image, ImageDraw
94
 
95
- img = Image.new('RGB', (width, height), 'white')
96
- draw = ImageDraw.Draw(img)
97
 
98
- # Different color schemes based on style
99
- if style == "iconography":
100
- colors = [(44, 62, 80), (231, 76, 60), (52, 152, 219), (46, 204, 113)]
101
- elif style == "pixel_art":
102
- colors = [(255, 107, 107), (78, 205, 196), (69, 183, 209), (150, 206, 180)]
103
- else:
104
- colors = [(52, 73, 94), (230, 126, 34), (26, 188, 156), (142, 68, 173)]
105
-
106
- # Generate patterns based on style
107
- if style == "iconography":
108
- # Clean geometric shapes
109
- for i in range(6):
110
- x = np.random.randint(20, width-40)
111
- y = np.random.randint(20, height-40)
112
- size = np.random.randint(15, 35)
113
- color = colors[i % len(colors)]
114
- draw.ellipse([x, y, x+size, y+size], fill=color)
115
-
116
- elif style == "pixel_art":
117
- # Pixelated squares
118
- for i in range(12):
119
- x = np.random.randint(0, width-20)
120
- y = np.random.randint(0, height-20)
121
- size = 15
122
- color = colors[i % len(colors)]
123
- draw.rectangle([x, y, x+size, y+size], fill=color)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  else:
126
- # Mixed shapes
127
- for i in range(8):
128
- x = np.random.randint(10, width-30)
129
- y = np.random.randint(10, height-30)
130
- color = colors[i % len(colors)]
131
- if i % 2 == 0:
132
- draw.ellipse([x, y, x+20, y+20], fill=color)
133
- else:
134
- draw.rectangle([x, y, x+20, y+20], fill=color)
135
-
136
- return img
137
-
138
- except Exception as e:
139
- # Return error image
140
- img = Image.new('RGB', (224, 224), 'red')
141
- return img
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  def _generate_simple_svg(self, prompt: str, width: int, height: int, particle_id: int, style: str) -> str:
144
  """
145
  Generate a simple SVG as placeholder for each particle
 
67
  return Args()
68
 
69
  def __call__(self, data: Dict[str, Any]):
70
+ """Generate multi-particle SVG from text prompt"""
71
  try:
72
  # Extract inputs
73
  if isinstance(data, dict):
 
81
  prompt = "a simple drawing"
82
 
83
  # Extract parameters
84
+ n_particle = parameters.get("n_particle", 6)
85
  width = parameters.get("width", 224)
86
  height = parameters.get("height", 224)
87
  seed = parameters.get("seed", 42)
 
90
  # Set random seed
91
  np.random.seed(seed)
92
 
93
+ # Generate the best particle SVG (simulate particle selection)
94
+ best_svg = self._generate_particle_svg(prompt, width, height, style, seed)
95
 
96
+ return best_svg
 
97
 
98
+ except Exception as e:
99
+ # Return error SVG
100
+ return f'<svg width="224" height="224" xmlns="http://www.w3.org/2000/svg"><text x="10" y="20" fill="red">Error: {str(e)}</text></svg>'
101
+
102
+ def _generate_particle_svg(self, prompt: str, width: int, height: int, style: str, seed: int) -> str:
103
+ """Generate SVG using particle-based optimization simulation"""
104
+ svg_header = f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}">'
105
+ svg_footer = '</svg>'
106
+
107
+ paths = []
108
+ prompt_lower = prompt.lower()
109
+
110
+ # Style-based color palettes
111
+ if style == "iconography":
112
+ colors = ["#2C3E50", "#E74C3C", "#3498DB", "#2ECC71", "#F39C12", "#9B59B6"]
113
+ elif style == "pixel_art":
114
+ colors = ["#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FFEAA7", "#DDA0DD"]
115
+ elif style == "sketch":
116
+ colors = ["#34495E", "#7F8C8D", "#95A5A6", "#BDC3C7", "#ECF0F1", "#2C3E50"]
117
+ else: # painting
118
+ colors = ["#8E44AD", "#3498DB", "#E67E22", "#E74C3C", "#F1C40F", "#27AE60"]
119
+
120
+ # Generate content based on prompt analysis
121
+ if any(word in prompt_lower for word in ['icon', 'logo', 'symbol', 'simple']):
122
+ self._add_icon_elements(paths, width, height, colors, style)
123
+ elif any(word in prompt_lower for word in ['landscape', 'mountain', 'nature', 'scene']):
124
+ self._add_landscape_elements(paths, width, height, colors, style)
125
+ elif any(word in prompt_lower for word in ['character', 'person', 'face', 'figure']):
126
+ self._add_character_elements(paths, width, height, colors, style)
127
+ elif any(word in prompt_lower for word in ['abstract', 'pattern', 'geometric']):
128
+ self._add_abstract_elements(paths, width, height, colors, style)
129
+ else:
130
+ self._add_general_elements(paths, width, height, colors, style, prompt_lower)
131
+
132
+ return svg_header + '\n' + '\n'.join(paths) + '\n' + svg_footer
133
+
134
+ def _add_icon_elements(self, paths, width, height, colors, style):
135
+ """Add icon-style elements"""
136
+ center_x, center_y = width // 2, height // 2
137
+
138
+ if style == "iconography":
139
+ # Clean geometric icon
140
+ main_size = min(width, height) // 3
141
+ paths.append(f'<circle cx="{center_x}" cy="{center_y}" r="{main_size}" fill="none" stroke="{colors[0]}" stroke-width="4"/>')
142
+ paths.append(f'<circle cx="{center_x}" cy="{center_y}" r="{main_size//2}" fill="{colors[1]}" opacity="0.8"/>')
143
+ else:
144
+ # More detailed icon
145
+ for i in range(3):
146
+ size = (3-i) * 15
147
+ paths.append(f'<circle cx="{center_x}" cy="{center_y}" r="{size}" fill="none" stroke="{colors[i]}" stroke-width="3"/>')
148
+
149
+ def _add_landscape_elements(self, paths, width, height, colors, style):
150
+ """Add landscape elements"""
151
+ # Sky
152
+ paths.append(f'<rect x="0" y="0" width="{width}" height="{height//2}" fill="{colors[2]}" opacity="0.3"/>')
153
+
154
+ # Mountains
155
+ for i in range(3):
156
+ x1 = i * width // 3
157
+ x2 = x1 + width // 3
158
+ peak_x = x1 + width // 6
159
+ peak_y = height // 3 + i * 10
160
 
161
+ points = f"{x1},{height//2} {peak_x},{peak_y} {x2},{height//2}"
162
+ paths.append(f'<polygon points="{points}" fill="{colors[i % len(colors)]}" opacity="0.7"/>')
163
+
164
+ # Ground
165
+ paths.append(f'<rect x="0" y="{height//2}" width="{width}" height="{height//2}" fill="{colors[3]}" opacity="0.4"/>')
166
+
167
+ def _add_character_elements(self, paths, width, height, colors, style):
168
+ """Add character/figure elements"""
169
+ center_x, center_y = width // 2, height // 2
170
+
171
+ # Head
172
+ head_r = min(width, height) // 8
173
+ paths.append(f'<circle cx="{center_x}" cy="{center_y - 30}" r="{head_r}" fill="{colors[0]}" opacity="0.8"/>')
174
+
175
+ # Body
176
+ body_width = head_r * 2
177
+ body_height = head_r * 3
178
+ paths.append(f'<rect x="{center_x - body_width//2}" y="{center_y - 10}" width="{body_width}" height="{body_height}" fill="{colors[1]}" opacity="0.8"/>')
179
+
180
+ # Arms
181
+ arm_length = head_r * 2
182
+ paths.append(f'<line x1="{center_x - body_width//2}" y1="{center_y + 10}" x2="{center_x - body_width//2 - arm_length}" y2="{center_y + 30}" stroke="{colors[2]}" stroke-width="4"/>')
183
+ paths.append(f'<line x1="{center_x + body_width//2}" y1="{center_y + 10}" x2="{center_x + body_width//2 + arm_length}" y2="{center_y + 30}" stroke="{colors[2]}" stroke-width="4"/>')
184
+
185
+ def _add_abstract_elements(self, paths, width, height, colors, style):
186
+ """Add abstract/geometric elements"""
187
+ for i in range(8):
188
+ if i % 3 == 0:
189
+ # Circles
190
+ cx = np.random.randint(30, width - 30)
191
+ cy = np.random.randint(30, height - 30)
192
+ r = np.random.randint(10, 40)
193
+ paths.append(f'<circle cx="{cx}" cy="{cy}" r="{r}" fill="{colors[i % len(colors)]}" opacity="0.6"/>')
194
+ elif i % 3 == 1:
195
+ # Rectangles
196
+ x = np.random.randint(10, width - 50)
197
+ y = np.random.randint(10, height - 50)
198
+ w = np.random.randint(20, 60)
199
+ h = np.random.randint(20, 60)
200
+ paths.append(f'<rect x="{x}" y="{y}" width="{w}" height="{h}" fill="{colors[i % len(colors)]}" opacity="0.6"/>')
201
  else:
202
+ # Triangles
203
+ x1 = np.random.randint(20, width - 20)
204
+ y1 = np.random.randint(20, height - 20)
205
+ x2 = x1 + np.random.randint(-30, 30)
206
+ y2 = y1 + np.random.randint(-30, 30)
207
+ x3 = x1 + np.random.randint(-30, 30)
208
+ y3 = y1 + np.random.randint(-30, 30)
209
+ points = f"{x1},{y1} {x2},{y2} {x3},{y3}"
210
+ paths.append(f'<polygon points="{points}" fill="{colors[i % len(colors)]}" opacity="0.6"/>')
 
 
 
 
 
 
 
211
 
212
+ def _add_general_elements(self, paths, width, height, colors, style, prompt_lower):
213
+ """Add general elements based on prompt keywords"""
214
+ # Analyze prompt for specific objects
215
+ if 'cat' in prompt_lower or 'animal' in prompt_lower:
216
+ self._add_simple_cat(paths, width, height, colors)
217
+ elif 'house' in prompt_lower or 'building' in prompt_lower:
218
+ self._add_simple_house(paths, width, height, colors)
219
+ elif 'tree' in prompt_lower or 'plant' in prompt_lower:
220
+ self._add_simple_tree(paths, width, height, colors)
221
+ elif 'car' in prompt_lower or 'vehicle' in prompt_lower:
222
+ self._add_simple_car(paths, width, height, colors)
223
+ else:
224
+ # Default abstract composition
225
+ self._add_abstract_elements(paths, width, height, colors, style)
226
+
227
+ def _add_simple_cat(self, paths, width, height, colors):
228
+ """Add a simple cat figure"""
229
+ center_x, center_y = width // 2, height // 2
230
+
231
+ # Body
232
+ paths.append(f'<ellipse cx="{center_x}" cy="{center_y + 20}" rx="40" ry="25" fill="{colors[0]}"/>')
233
+ # Head
234
+ paths.append(f'<circle cx="{center_x}" cy="{center_y - 20}" r="25" fill="{colors[0]}"/>')
235
+ # Ears
236
+ paths.append(f'<polygon points="{center_x-20},{center_y-35} {center_x-10},{center_y-50} {center_x-5},{center_y-35}" fill="{colors[0]}"/>')
237
+ paths.append(f'<polygon points="{center_x+5},{center_y-35} {center_x+10},{center_y-50} {center_x+20},{center_y-35}" fill="{colors[0]}"/>')
238
+ # Tail
239
+ paths.append(f'<path d="M {center_x+35} {center_y+15} Q {center_x+60} {center_y-10} {center_x+45} {center_y-30}" stroke="{colors[0]}" stroke-width="8" fill="none"/>')
240
+
241
+ def _add_simple_house(self, paths, width, height, colors):
242
+ """Add a simple house"""
243
+ center_x, center_y = width // 2, height // 2
244
+
245
+ # Base
246
+ house_width, house_height = 80, 60
247
+ paths.append(f'<rect x="{center_x - house_width//2}" y="{center_y}" width="{house_width}" height="{house_height}" fill="{colors[0]}"/>')
248
+ # Roof
249
+ roof_points = f"{center_x - house_width//2 - 10},{center_y} {center_x},{center_y - 40} {center_x + house_width//2 + 10},{center_y}"
250
+ paths.append(f'<polygon points="{roof_points}" fill="{colors[1]}"/>')
251
+ # Door
252
+ paths.append(f'<rect x="{center_x - 10}" y="{center_y + 20}" width="20" height="40" fill="{colors[2]}"/>')
253
+ # Window
254
+ paths.append(f'<rect x="{center_x + 15}" y="{center_y + 15}" width="15" height="15" fill="{colors[3]}"/>')
255
+
256
+ def _add_simple_tree(self, paths, width, height, colors):
257
+ """Add a simple tree"""
258
+ center_x, center_y = width // 2, height // 2
259
+
260
+ # Trunk
261
+ paths.append(f'<rect x="{center_x - 8}" y="{center_y + 10}" width="16" height="40" fill="{colors[0]}"/>')
262
+ # Leaves
263
+ paths.append(f'<circle cx="{center_x}" cy="{center_y - 10}" r="35" fill="{colors[1]}"/>')
264
+ paths.append(f'<circle cx="{center_x - 20}" cy="{center_y}" r="25" fill="{colors[1]}"/>')
265
+ paths.append(f'<circle cx="{center_x + 20}" cy="{center_y}" r="25" fill="{colors[1]}"/>')
266
+
267
+ def _add_simple_car(self, paths, width, height, colors):
268
+ """Add a simple car"""
269
+ center_x, center_y = width // 2, height // 2
270
+
271
+ # Body
272
+ paths.append(f'<rect x="{center_x - 50}" y="{center_y}" width="100" height="30" fill="{colors[0]}"/>')
273
+ # Top
274
+ paths.append(f'<rect x="{center_x - 30}" y="{center_y - 20}" width="60" height="20" fill="{colors[0]}"/>')
275
+ # Wheels
276
+ paths.append(f'<circle cx="{center_x - 30}" cy="{center_y + 35}" r="12" fill="{colors[1]}"/>')
277
+ paths.append(f'<circle cx="{center_x + 30}" cy="{center_y + 35}" r="12" fill="{colors[1]}"/>')
278
+ # Windows
279
+ paths.append(f'<rect x="{center_x - 25}" y="{center_y - 15}" width="20" height="12" fill="{colors[2]}"/>')
280
+ paths.append(f'<rect x="{center_x + 5}" y="{center_y - 15}" width="20" height="12" fill="{colors[2]}"/>')
281
+
282
  def _generate_simple_svg(self, prompt: str, width: int, height: int, particle_id: int, style: str) -> str:
283
  """
284
  Generate a simple SVG as placeholder for each particle