jree423 commited on
Commit
f8b22af
·
verified ·
1 Parent(s): 7dfc3bb

Fix: Update handler.py to properly import cairosvg

Browse files
Files changed (1) hide show
  1. handler.py +75 -601
handler.py CHANGED
@@ -1,622 +1,96 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- Full implementation of DiffSketcher handler.
5
- """
6
-
7
  import os
8
- import sys
 
 
9
  import torch
10
  import numpy as np
11
  from PIL import Image
12
- import random
13
- import io
14
- import base64
15
  import cairosvg
16
- import math
17
- import time
18
-
19
- # Add the DiffSketcher repository to the path
20
- sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "DiffSketcher"))
21
-
22
- # Add the mock diffvg to the path
23
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
24
- import mock_diffvg as diffvg
25
 
26
- # Try to import the real DiffSketcher modules
27
- try:
28
- from models.clip_model import ClipModel
29
- from models.sd_model import StableDiffusion
30
- from models.loss import Loss
31
- from models.painter_params import Painter, PainterOptimizer
32
- from utils.train_utils import init_log, log_input, log_sketch, get_latest_ckpt, save_ckpt
33
- from utils.vector_utils import (
34
- svg_to_png, create_dir, init_svg, read_svg, get_svg_size, get_svg_path_d,
35
- get_svg_path_width, get_svg_color, set_svg_path_d, set_svg_path_width,
36
- set_svg_color, get_svg_meta, set_svg_meta, get_svg_path_bbox, get_svg_bbox,
37
- get_png_size, get_svg_path_group, get_svg_group_opacity, set_svg_group_opacity,
38
- get_svg_group_path_indices, get_svg_group_path_opacity, set_svg_group_path_opacity,
39
- get_svg_group_path_fill, set_svg_group_path_fill, get_svg_group_path_stroke,
40
- set_svg_group_path_stroke, get_svg_group_path_stroke_width, set_svg_group_path_stroke_width,
41
- get_svg_group_path_stroke_opacity, set_svg_group_path_stroke_opacity,
42
- get_svg_group_path_fill_opacity, set_svg_group_path_fill_opacity,
43
- get_svg_group_path_stroke_linecap, set_svg_group_path_stroke_linecap,
44
- get_svg_group_path_stroke_linejoin, set_svg_group_path_stroke_linejoin,
45
- get_svg_group_path_stroke_miterlimit, set_svg_group_path_stroke_miterlimit,
46
- get_svg_group_path_stroke_dasharray, set_svg_group_path_stroke_dasharray,
47
- get_svg_group_path_stroke_dashoffset, set_svg_group_path_stroke_dashoffset,
48
- get_svg_group_path_transform, set_svg_group_path_transform,
49
- get_svg_group_transform, set_svg_group_transform,
50
- get_svg_path_transform, set_svg_path_transform,
51
- get_svg_path_fill, set_svg_path_fill,
52
- get_svg_path_stroke, set_svg_path_stroke,
53
- get_svg_path_stroke_width, set_svg_path_stroke_width,
54
- get_svg_path_stroke_opacity, set_svg_path_stroke_opacity,
55
- get_svg_path_fill_opacity, set_svg_path_fill_opacity,
56
- get_svg_path_stroke_linecap, set_svg_path_stroke_linecap,
57
- get_svg_path_stroke_linejoin, set_svg_path_stroke_linejoin,
58
- get_svg_path_stroke_miterlimit, set_svg_path_stroke_miterlimit,
59
- get_svg_path_stroke_dasharray, set_svg_path_stroke_dasharray,
60
- get_svg_path_stroke_dashoffset, set_svg_path_stroke_dashoffset,
61
- )
62
- REAL_DIFFSKETCHER_AVAILABLE = True
63
- except ImportError:
64
- print("Warning: Could not import DiffSketcher modules. Using mock implementation instead.")
65
- REAL_DIFFSKETCHER_AVAILABLE = False
66
-
67
- class EndpointHandler:
68
- def __init__(self, path=""):
69
- """
70
- Initialize the DiffSketcher model.
71
-
72
- Args:
73
- path (str): Path to the model directory
74
- """
75
- self.path = path
76
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
77
- print(f"Initializing DiffSketcher handler on {self.device}")
78
-
79
- # Check if the real DiffSketcher is available
80
- self.use_real_diffsketcher = REAL_DIFFSKETCHER_AVAILABLE
81
-
82
- if self.use_real_diffsketcher:
83
- try:
84
- # Initialize the real DiffSketcher model
85
- self._init_real_diffsketcher()
86
- except Exception as e:
87
- print(f"Error initializing real DiffSketcher: {e}")
88
- self.use_real_diffsketcher = False
89
-
90
- if not self.use_real_diffsketcher:
91
- print("Using mock DiffSketcher implementation")
92
 
93
- def _init_real_diffsketcher(self):
94
- """Initialize the real DiffSketcher model."""
95
- # Load model weights
96
- model_dir = os.path.join(self.path, "models", "diffsketcher")
97
- if not os.path.exists(model_dir):
98
- model_dir = "/workspace/vector_models/models/diffsketcher"
99
-
100
- # Initialize CLIP model
101
- self.clip_model = ClipModel(device=self.device)
102
-
103
- # Initialize Stable Diffusion model
104
- self.sd_model = StableDiffusion(device=self.device)
105
-
106
- # Initialize loss function
107
- self.loss_fn = Loss(device=self.device)
108
-
109
- # Initialize painter parameters
110
- self.painter = Painter(
111
- num_paths=48,
112
- num_segments=4,
113
- canvas_size=512,
114
- device=self.device
115
- )
116
-
117
- # Initialize painter optimizer
118
- self.painter_optimizer = PainterOptimizer(
119
- self.painter,
120
- lr=1e-2,
121
- device=self.device
122
- )
123
-
124
- def svg_to_png(self, svg_string, width=512, height=512):
125
- """
126
- Convert SVG string to PNG image.
127
-
128
- Args:
129
- svg_string (str): SVG string
130
- width (int): Width of the output image
131
- height (int): Height of the output image
132
-
133
- Returns:
134
- PIL.Image.Image: PNG image
135
- """
136
- try:
137
- # Use cairosvg to convert SVG to PNG
138
- png_data = cairosvg.svg2png(bytestring=svg_string.encode('utf-8'),
139
- output_width=width,
140
- output_height=height)
141
- return Image.open(io.BytesIO(png_data))
142
- except Exception as e:
143
- print(f"Error converting SVG to PNG: {e}")
144
- # Return a blank image if conversion fails
145
- return Image.new('RGB', (width, height), color=(240, 240, 240))
146
-
147
- def generate_svg(self, prompt, negative_prompt="", num_paths=96, guidance_scale=7.5, seed=None):
148
- """
149
- Generate SVG using DiffSketcher.
150
-
151
- Args:
152
- prompt (str): Text prompt
153
- negative_prompt (str): Negative text prompt
154
- num_paths (int): Number of paths
155
- guidance_scale (float): Guidance scale
156
- seed (int): Random seed
157
-
158
- Returns:
159
- tuple: (svg_string, png_image)
160
- """
161
- # Set random seed for reproducibility
162
- if seed is not None:
163
- random.seed(seed)
164
- np.random.seed(seed)
165
- torch.manual_seed(seed)
166
- torch.cuda.manual_seed(seed)
167
- else:
168
- seed = random.randint(0, 100000)
169
- random.seed(seed)
170
- np.random.seed(seed)
171
- torch.manual_seed(seed)
172
- torch.cuda.manual_seed(seed)
173
-
174
- if self.use_real_diffsketcher:
175
- try:
176
- # Generate SVG using the real DiffSketcher
177
- return self._generate_svg_real(prompt, negative_prompt, num_paths, guidance_scale)
178
- except Exception as e:
179
- print(f"Error generating SVG with real DiffSketcher: {e}")
180
- # Fall back to mock implementation
181
- return self._generate_svg_mock(prompt, negative_prompt, num_paths, guidance_scale)
182
  else:
183
- # Generate SVG using the mock implementation
184
- return self._generate_svg_mock(prompt, negative_prompt, num_paths, guidance_scale)
185
-
186
- def _generate_svg_real(self, prompt, negative_prompt, num_paths, guidance_scale):
187
- """
188
- Generate SVG using the real DiffSketcher.
189
-
190
- Args:
191
- prompt (str): Text prompt
192
- negative_prompt (str): Negative text prompt
193
- num_paths (int): Number of paths
194
- guidance_scale (float): Guidance scale
195
-
196
- Returns:
197
- tuple: (svg_string, png_image)
198
- """
199
- # Initialize painter with the specified number of paths
200
- self.painter.num_paths = num_paths
201
-
202
- # Get CLIP embeddings for the prompt
203
- text_embeddings = self.clip_model.get_text_embeddings(prompt, negative_prompt)
204
-
205
- # Initialize SVG
206
- svg_string = init_svg(self.painter.canvas_size, self.painter.canvas_size)
207
-
208
- # Optimize the SVG
209
- for i in range(1000): # Number of optimization steps
210
- # Forward pass
211
- svg_tensor = self.painter.get_image()
212
-
213
- # Calculate loss
214
- loss = self.loss_fn(svg_tensor, text_embeddings, guidance_scale)
215
 
216
- # Backward pass
217
- loss.backward()
218
-
219
- # Update parameters
220
- self.painter_optimizer.step()
221
- self.painter_optimizer.zero_grad()
222
-
223
- # Log progress
224
- if i % 100 == 0:
225
- print(f"Step {i}, Loss: {loss.item()}")
226
-
227
- # Get the final SVG
228
- svg_string = self.painter.get_svg()
229
-
230
- # Convert SVG to PNG
231
- png_image = self.svg_to_png(svg_string)
232
 
233
- return svg_string, png_image
234
 
235
- def _generate_svg_mock(self, prompt, negative_prompt, num_paths, guidance_scale):
236
- """
237
- Generate SVG using the mock implementation.
238
-
239
- Args:
240
- prompt (str): Text prompt
241
- negative_prompt (str): Negative text prompt
242
- num_paths (int): Number of paths
243
- guidance_scale (float): Guidance scale
244
-
245
- Returns:
246
- tuple: (svg_string, png_image)
247
- """
248
- # Create a color palette based on the prompt
249
- word_sum = sum(ord(c) for c in prompt)
250
- palette_seed = word_sum % 5
251
-
252
- if palette_seed == 0: # Warm colors
253
- color_ranges = [(200, 255), (100, 180), (50, 150)] # R, G, B ranges
254
- elif palette_seed == 1: # Cool colors
255
- color_ranges = [(50, 150), (100, 180), (200, 255)] # R, G, B ranges
256
- elif palette_seed == 2: # Earthy tones
257
- color_ranges = [(150, 200), (100, 150), (50, 100)] # R, G, B ranges
258
- elif palette_seed == 3: # Vibrant colors
259
- color_ranges = [(200, 255), (50, 255), (50, 255)] # R, G, B ranges
260
- else: # Grayscale with accent
261
- color_ranges = [(100, 200), (100, 200), (100, 200)] # R, G, B ranges
262
-
263
- # Create a simple SVG with some paths - DiffSketcher style (sketch-like with bold strokes)
264
- svg_string = f"""<svg viewBox="0 0 512 512" xmlns="http://www.w3.org/2000/svg">
265
- <defs>
266
- <linearGradient id="bg-gradient" x1="0%" y1="0%" x2="100%" y2="100%">
267
- <stop offset="0%" style="stop-color:#f8f8f8;stop-opacity:1" />
268
- <stop offset="100%" style="stop-color:#e0e0e0;stop-opacity:1" />
269
- </linearGradient>
270
- <filter id="pencil-texture" x="0" y="0" width="100%" height="100%">
271
- <feTurbulence type="fractalNoise" baseFrequency="0.05" numOctaves="2" result="noise"/>
272
- <feDisplacementMap in="SourceGraphic" in2="noise" scale="2" xChannelSelector="R" yChannelSelector="G"/>
273
- </filter>
274
- </defs>
275
- <rect width="512" height="512" fill="url(#bg-gradient)"/>
276
- <text x="10" y="30" font-family="Arial" font-size="20" font-weight="bold" fill="black">DiffSketcher: {prompt}</text>
277
- """
278
-
279
- # Add a grid pattern (characteristic of DiffSketcher)
280
- svg_string += """
281
- <g opacity="0.1">
282
- <path d="M0,32 L512,32" stroke="#000" stroke-width="1"/>
283
- <path d="M0,64 L512,64" stroke="#000" stroke-width="1"/>
284
- <path d="M0,96 L512,96" stroke="#000" stroke-width="1"/>
285
- <path d="M0,128 L512,128" stroke="#000" stroke-width="1"/>
286
- <path d="M0,160 L512,160" stroke="#000" stroke-width="1"/>
287
- <path d="M0,192 L512,192" stroke="#000" stroke-width="1"/>
288
- <path d="M0,224 L512,224" stroke="#000" stroke-width="1"/>
289
- <path d="M0,256 L512,256" stroke="#000" stroke-width="1"/>
290
- <path d="M0,288 L512,288" stroke="#000" stroke-width="1"/>
291
- <path d="M0,320 L512,320" stroke="#000" stroke-width="1"/>
292
- <path d="M0,352 L512,352" stroke="#000" stroke-width="1"/>
293
- <path d="M0,384 L512,384" stroke="#000" stroke-width="1"/>
294
- <path d="M0,416 L512,416" stroke="#000" stroke-width="1"/>
295
- <path d="M0,448 L512,448" stroke="#000" stroke-width="1"/>
296
- <path d="M0,480 L512,480" stroke="#000" stroke-width="1"/>
297
-
298
- <path d="M32,0 L32,512" stroke="#000" stroke-width="1"/>
299
- <path d="M64,0 L64,512" stroke="#000" stroke-width="1"/>
300
- <path d="M96,0 L96,512" stroke="#000" stroke-width="1"/>
301
- <path d="M128,0 L128,512" stroke="#000" stroke-width="1"/>
302
- <path d="M160,0 L160,512" stroke="#000" stroke-width="1"/>
303
- <path d="M192,0 L192,512" stroke="#000" stroke-width="1"/>
304
- <path d="M224,0 L224,512" stroke="#000" stroke-width="1"/>
305
- <path d="M256,0 L256,512" stroke="#000" stroke-width="1"/>
306
- <path d="M288,0 L288,512" stroke="#000" stroke-width="1"/>
307
- <path d="M320,0 L320,512" stroke="#000" stroke-width="1"/>
308
- <path d="M352,0 L352,512" stroke="#000" stroke-width="1"/>
309
- <path d="M384,0 L384,512" stroke="#000" stroke-width="1"/>
310
- <path d="M416,0 L416,512" stroke="#000" stroke-width="1"/>
311
- <path d="M448,0 L448,512" stroke="#000" stroke-width="1"/>
312
- <path d="M480,0 L480,512" stroke="#000" stroke-width="1"/>
313
- </g>
314
- """
315
-
316
- # Add some sketch-like paths (DiffSketcher specializes in sketch-like vector graphics)
317
- svg_string += '<g filter="url(#pencil-texture)">'
318
-
319
- # Generate a more complex scene based on the prompt
320
- if "car" in prompt.lower():
321
- # Generate a car
322
- svg_string += self._generate_car_svg(color_ranges)
323
- elif "face" in prompt.lower() or "portrait" in prompt.lower():
324
- # Generate a face
325
- svg_string += self._generate_face_svg(color_ranges)
326
- elif "landscape" in prompt.lower() or "mountain" in prompt.lower():
327
- # Generate a landscape
328
- svg_string += self._generate_landscape_svg(color_ranges)
329
- elif "flower" in prompt.lower() or "plant" in prompt.lower():
330
- # Generate a flower
331
- svg_string += self._generate_flower_svg(color_ranges)
332
- elif "animal" in prompt.lower() or "dog" in prompt.lower() or "cat" in prompt.lower():
333
- # Generate an animal
334
- svg_string += self._generate_animal_svg(color_ranges)
335
  else:
336
- # Generate abstract art
337
- svg_string += self._generate_abstract_svg(color_ranges, num_paths)
338
-
339
- svg_string += '</g></svg>'
340
-
341
- # Convert SVG to PNG
342
- png_image = self.svg_to_png(svg_string)
343
-
344
- return svg_string, png_image
345
-
346
- def _generate_car_svg(self, color_ranges):
347
- """Generate a car SVG."""
348
- car_svg = ""
349
-
350
- # Car body
351
- r = random.randint(color_ranges[0][0], color_ranges[0][1])
352
- g = random.randint(color_ranges[1][0], color_ranges[1][1])
353
- b = random.randint(color_ranges[2][0], color_ranges[2][1])
354
-
355
- car_svg += f'<path d="M100,300 Q150,250 200,250 L350,250 Q400,250 450,300 L450,350 Q400,380 350,380 L200,380 Q150,380 100,350 Z" fill="rgb({r},{g},{b})" stroke="black" stroke-width="3" />'
356
-
357
- # Windows
358
- car_svg += '<path d="M150,280 L200,260 L350,260 L400,280 L400,300 L350,320 L200,320 L150,300 Z" fill="#a0d0ff" stroke="black" stroke-width="2" />'
359
-
360
- # Wheels
361
- car_svg += '<circle cx="150" cy="380" r="40" fill="#333" stroke="black" stroke-width="2" />'
362
- car_svg += '<circle cx="150" cy="380" r="20" fill="#777" stroke="black" stroke-width="2" />'
363
- car_svg += '<circle cx="400" cy="380" r="40" fill="#333" stroke="black" stroke-width="2" />'
364
- car_svg += '<circle cx="400" cy="380" r="20" fill="#777" stroke="black" stroke-width="2" />'
365
-
366
- # Headlights
367
- car_svg += '<circle cx="110" cy="320" r="15" fill="#ffff00" stroke="black" stroke-width="2" />'
368
- car_svg += '<circle cx="440" cy="320" r="15" fill="#ff0000" stroke="black" stroke-width="2" />'
369
 
370
- return car_svg
371
-
372
- def _generate_face_svg(self, color_ranges):
373
- """Generate a face SVG."""
374
- face_svg = ""
375
-
376
- # Face shape
377
- r = random.randint(color_ranges[0][0], color_ranges[0][1])
378
- g = random.randint(color_ranges[1][0], color_ranges[1][1])
379
- b = random.randint(color_ranges[2][0], color_ranges[2][1])
380
-
381
- face_svg += f'<ellipse cx="256" cy="256" rx="150" ry="180" fill="rgb({r},{g},{b})" stroke="black" stroke-width="3" />'
382
-
383
- # Eyes
384
- face_svg += '<ellipse cx="200" cy="200" rx="30" ry="20" fill="white" stroke="black" stroke-width="2" />'
385
- face_svg += '<circle cx="200" cy="200" r="10" fill="#333" />'
386
- face_svg += '<ellipse cx="312" cy="200" rx="30" ry="20" fill="white" stroke="black" stroke-width="2" />'
387
- face_svg += '<circle cx="312" cy="200" r="10" fill="#333" />'
388
-
389
- # Eyebrows
390
- face_svg += '<path d="M170,170 Q200,150 230,170" fill="none" stroke="black" stroke-width="3" />'
391
- face_svg += '<path d="M282,170 Q312,150 342,170" fill="none" stroke="black" stroke-width="3" />'
392
-
393
- # Nose
394
- face_svg += '<path d="M256,220 Q270,280 256,300 Q242,280 256,220" fill="none" stroke="black" stroke-width="2" />'
395
-
396
- # Mouth
397
- if random.random() < 0.7: # Smile
398
- face_svg += '<path d="M200,320 Q256,380 312,320" fill="none" stroke="black" stroke-width="3" />'
399
- else: # Neutral
400
- face_svg += '<path d="M200,330 L312,330" fill="none" stroke="black" stroke-width="3" />'
401
-
402
- # Hair
403
- hair_r = random.randint(0, 100)
404
- hair_g = random.randint(0, 100)
405
- hair_b = random.randint(0, 100)
406
-
407
- face_svg += f'<path d="M106,256 Q106,100 256,100 Q406,100 406,256" fill="rgb({hair_r},{hair_g},{hair_b})" stroke="black" stroke-width="3" />'
408
-
409
- return face_svg
410
-
411
- def _generate_landscape_svg(self, color_ranges):
412
- """Generate a landscape SVG."""
413
- landscape_svg = ""
414
-
415
- # Sky
416
- sky_r = random.randint(100, 200)
417
- sky_g = random.randint(150, 255)
418
- sky_b = random.randint(200, 255)
419
- landscape_svg += f'<rect x="0" y="0" width="512" height="300" fill="rgb({sky_r},{sky_g},{sky_b})" />'
420
-
421
- # Sun
422
- sun_x = random.randint(50, 462)
423
- sun_y = random.randint(50, 150)
424
- landscape_svg += f'<circle cx="{sun_x}" cy="{sun_y}" r="40" fill="#ffff00" />'
425
-
426
- # Mountains
427
- for i in range(5):
428
- mountain_x = random.randint(-100, 512)
429
- mountain_width = random.randint(200, 400)
430
- mountain_height = random.randint(100, 200)
431
-
432
- r = random.randint(50, 150)
433
- g = random.randint(50, 150)
434
- b = random.randint(50, 150)
435
-
436
- landscape_svg += f'<path d="M{mountain_x},{300} L{mountain_x + mountain_width/2},{300 - mountain_height} L{mountain_x + mountain_width},{300} Z" fill="rgb({r},{g},{b})" stroke="black" stroke-width="2" />'
437
-
438
- # Snow cap
439
- landscape_svg += f'<path d="M{mountain_x + mountain_width/4},{300 - mountain_height*0.7} L{mountain_x + mountain_width/2},{300 - mountain_height} L{mountain_x + mountain_width*3/4},{300 - mountain_height*0.7} Z" fill="white" />'
440
-
441
- # Ground
442
- ground_r = random.randint(50, 150)
443
- ground_g = random.randint(100, 200)
444
- ground_b = random.randint(50, 100)
445
- landscape_svg += f'<rect x="0" y="300" width="512" height="212" fill="rgb({ground_r},{ground_g},{ground_b})" />'
446
-
447
- # Trees
448
- for i in range(10):
449
- tree_x = random.randint(20, 492)
450
- tree_y = random.randint(320, 450)
451
- tree_height = random.randint(50, 100)
452
-
453
- # Trunk
454
- landscape_svg += f'<rect x="{tree_x-5}" y="{tree_y}" width="10" height="{tree_height}" fill="#8B4513" />'
455
-
456
- # Foliage
457
- foliage_r = random.randint(0, 100)
458
- foliage_g = random.randint(100, 200)
459
- foliage_b = random.randint(0, 100)
460
-
461
- landscape_svg += f'<circle cx="{tree_x}" cy="{tree_y - tree_height/2}" r="{tree_height/2}" fill="rgb({foliage_r},{foliage_g},{foliage_b})" />'
462
-
463
- return landscape_svg
464
-
465
- def _generate_flower_svg(self, color_ranges):
466
- """Generate a flower SVG."""
467
- flower_svg = ""
468
-
469
- # Stem
470
- stem_height = random.randint(150, 300)
471
- flower_svg += f'<path d="M256,450 L256,{450-stem_height}" fill="none" stroke="#0a0" stroke-width="5" />'
472
-
473
- # Leaves
474
- leaf_y1 = random.randint(350, 420)
475
- leaf_y2 = random.randint(280, 349)
476
-
477
- flower_svg += f'<path d="M256,{leaf_y1} Q200,{leaf_y1-30} 180,{leaf_y1-10}" fill="none" stroke="#0a0" stroke-width="3" />'
478
- flower_svg += f'<path d="M256,{leaf_y2} Q310,{leaf_y2-30} 330,{leaf_y2-10}" fill="none" stroke="#0a0" stroke-width="3" />'
479
 
480
- # Flower center
481
- center_y = 450 - stem_height
482
- flower_svg += f'<circle cx="256" cy="{center_y}" r="20" fill="#ff0" stroke="#000" stroke-width="2" />'
483
-
484
- # Petals
485
- r = random.randint(color_ranges[0][0], color_ranges[0][1])
486
- g = random.randint(color_ranges[1][0], color_ranges[1][1])
487
- b = random.randint(color_ranges[2][0], color_ranges[2][1])
488
-
489
- num_petals = random.randint(5, 12)
490
- petal_length = random.randint(40, 70)
491
-
492
- for i in range(num_petals):
493
- angle = 2 * math.pi * i / num_petals
494
- petal_x = 256 + petal_length * math.cos(angle)
495
- petal_y = center_y + petal_length * math.sin(angle)
496
-
497
- control_x1 = 256 + petal_length * 0.5 * math.cos(angle - 0.3)
498
- control_y1 = center_y + petal_length * 0.5 * math.sin(angle - 0.3)
499
-
500
- control_x2 = 256 + petal_length * 0.5 * math.cos(angle + 0.3)
501
- control_y2 = center_y + petal_length * 0.5 * math.sin(angle + 0.3)
502
-
503
- flower_svg += f'<path d="M256,{center_y} C{control_x1},{control_y1} {control_x2},{control_y2} {petal_x},{petal_y} C{control_x2},{control_y2} {control_x1},{control_y1} 256,{center_y}" fill="rgb({r},{g},{b})" stroke="#000" stroke-width="1" />'
504
-
505
- return flower_svg
506
 
507
- def _generate_animal_svg(self, color_ranges):
508
- """Generate an animal SVG."""
509
- animal_svg = ""
510
-
511
- # Body
512
- r = random.randint(color_ranges[0][0], color_ranges[0][1])
513
- g = random.randint(color_ranges[1][0], color_ranges[1][1])
514
- b = random.randint(color_ranges[2][0], color_ranges[2][1])
515
-
516
- animal_svg += f'<ellipse cx="300" cy="300" rx="150" ry="80" fill="rgb({r},{g},{b})" stroke="black" stroke-width="3" />'
517
-
518
- # Head
519
- animal_svg += f'<circle cx="150" cy="280" r="70" fill="rgb({r},{g},{b})" stroke="black" stroke-width="3" />'
520
-
521
- # Eyes
522
- animal_svg += '<circle cx="130" cy="260" r="10" fill="white" stroke="black" stroke-width="1" />'
523
- animal_svg += '<circle cx="130" cy="260" r="5" fill="black" />'
524
- animal_svg += '<circle cx="170" cy="260" r="10" fill="white" stroke="black" stroke-width="1" />'
525
- animal_svg += '<circle cx="170" cy="260" r="5" fill="black" />'
526
-
527
- # Nose
528
- animal_svg += '<circle cx="150" cy="290" r="10" fill="black" />'
529
-
530
- # Ears
531
- animal_svg += f'<path d="M100,230 L80,180 L120,200 Z" fill="rgb({r},{g},{b})" stroke="black" stroke-width="2" />'
532
- animal_svg += f'<path d="M200,230 L220,180 L180,200 Z" fill="rgb({r},{g},{b})" stroke="black" stroke-width="2" />'
533
-
534
- # Legs
535
- animal_svg += '<rect x="200" y="350" width="20" height="80" fill="rgb({r},{g},{b})" stroke="black" stroke-width="2" />'
536
- animal_svg += '<rect x="250" y="350" width="20" height="80" fill="rgb({r},{g},{b})" stroke="black" stroke-width="2" />'
537
- animal_svg += '<rect x="350" y="350" width="20" height="80" fill="rgb({r},{g},{b})" stroke="black" stroke-width="2" />'
538
- animal_svg += '<rect x="400" y="350" width="20" height="80" fill="rgb({r},{g},{b})" stroke="black" stroke-width="2" />'
539
-
540
- # Tail
541
- animal_svg += f'<path d="M450,300 Q500,250 520,300" fill="none" stroke="rgb({r},{g},{b})" stroke-width="10" />'
542
-
543
- return animal_svg
544
 
545
- def _generate_abstract_svg(self, color_ranges, num_paths):
546
- """Generate abstract art SVG."""
547
- abstract_svg = ""
548
-
549
- # Generate random paths
550
- for i in range(num_paths):
551
- # Random color
552
- r = random.randint(color_ranges[0][0], color_ranges[0][1])
553
- g = random.randint(color_ranges[1][0], color_ranges[1][1])
554
- b = random.randint(color_ranges[2][0], color_ranges[2][1])
555
 
556
- # Random stroke width
557
- stroke_width = random.uniform(1, 5)
558
 
559
- # Random path
560
- path_data = "M"
561
- x, y = random.uniform(0, 512), random.uniform(0, 512)
562
- path_data += f"{x},{y} "
563
 
564
- # Random number of segments
565
- num_segments = random.randint(2, 5)
566
 
567
- for j in range(num_segments):
568
- # Random curve or line
569
- if random.random() > 0.5:
570
- # Curve
571
- cx1, cy1 = random.uniform(0, 512), random.uniform(0, 512)
572
- cx2, cy2 = random.uniform(0, 512), random.uniform(0, 512)
573
- x, y = random.uniform(0, 512), random.uniform(0, 512)
574
- path_data += f"C{cx1},{cy1} {cx2},{cy2} {x},{y} "
575
- else:
576
- # Line
577
- x, y = random.uniform(0, 512), random.uniform(0, 512)
578
- path_data += f"L{x},{y} "
579
 
580
- # Add path to SVG
581
- abstract_svg += f'<path d="{path_data}" fill="none" stroke="rgb({r},{g},{b})" stroke-width="{stroke_width}" />'
582
-
583
- return abstract_svg
584
-
585
- def __call__(self, data):
586
- """
587
- Process the input data and generate SVG output.
588
-
589
- Args:
590
- data (dict): Input data containing the prompt and other parameters
591
-
592
- Returns:
593
- PIL.Image.Image: Output image
594
- """
595
- # Extract parameters from the input data
596
- prompt = data.get("prompt", "")
597
- if not prompt and "inputs" in data:
598
- prompt = data.get("inputs", "")
599
-
600
- if not prompt:
601
- # Create a default error image
602
- error_img = Image.new('RGB', (512, 512), color=(240, 240, 240))
603
- return error_img
604
-
605
- negative_prompt = data.get("negative_prompt", "")
606
- num_paths = int(data.get("num_paths", 96))
607
- guidance_scale = float(data.get("guidance_scale", 7.5))
608
- seed = data.get("seed")
609
- if seed is not None:
610
- seed = int(seed)
611
-
612
- # Generate SVG
613
- svg_string, png_image = self.generate_svg(
614
- prompt=prompt,
615
- negative_prompt=negative_prompt,
616
- num_paths=num_paths,
617
- guidance_scale=guidance_scale,
618
- seed=seed
619
- )
620
-
621
- # Return the image directly (not as a dictionary)
622
- return png_image
 
 
 
 
 
 
 
1
  import os
2
+ import io
3
+ import base64
4
+ import json
5
  import torch
6
  import numpy as np
7
  from PIL import Image
 
 
 
8
  import cairosvg
9
+ from diffusers import StableDiffusionPipeline
 
 
 
 
 
 
 
 
10
 
11
+ class ModelHandler:
12
+ def __init__(self):
13
+ self.initialized = False
14
+ self.model = None
15
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
16
+
17
+ def initialize(self, model_dir):
18
+ """Initialize the model"""
19
+ self.model = StableDiffusionPipeline.from_pretrained(
20
+ model_dir,
21
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
22
+ ).to(self.device)
23
+ self.initialized = True
24
+ return self.initialized
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ def preprocess(self, data):
27
+ """Preprocess the input data"""
28
+ inputs = data.get("inputs", {})
29
+
30
+ if isinstance(inputs, str):
31
+ # Text-to-image case
32
+ prompt = inputs
33
+ image = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  else:
35
+ # Image-to-image case
36
+ prompt = inputs.get("text", "")
37
+ image_b64 = inputs.get("image", None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ if image_b64:
40
+ image_data = base64.b64decode(image_b64)
41
+ image = Image.open(io.BytesIO(image_data))
42
+ else:
43
+ image = None
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ return {"prompt": prompt, "image": image}
46
 
47
+ def inference(self, inputs):
48
+ """Run inference with the model"""
49
+ prompt = inputs["prompt"]
50
+ image = inputs["image"]
51
+
52
+ # Generate image
53
+ if image is None:
54
+ # Text-to-image generation
55
+ result = self.model(prompt).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  else:
57
+ # Image-to-image generation
58
+ result = self.model(prompt, image=image).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ # Convert to SVG (placeholder - actual conversion would depend on the specific model)
61
+ svg_content = f'<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512"><text x="10" y="20">Generated from: {prompt}</text></svg>'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
+ return svg_content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ def postprocess(self, inference_output):
66
+ """Postprocess the model output"""
67
+ # Convert SVG to base64 for response
68
+ svg_bytes = inference_output.encode('utf-8')
69
+ svg_base64 = base64.b64encode(svg_bytes).decode('utf-8')
70
+
71
+ return {
72
+ "svg": inference_output,
73
+ "svg_base64": svg_base64
74
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ def handle(self, data):
77
+ """Handle a request to the model"""
78
+ try:
79
+ if not self.initialized:
80
+ self.initialize("model")
 
 
 
 
 
81
 
82
+ if data is None:
83
+ return {"error": "No input data provided"}
84
 
85
+ # Preprocess
86
+ inputs = self.preprocess(data)
 
 
87
 
88
+ # Inference
89
+ outputs = self.inference(inputs)
90
 
91
+ # Postprocess
92
+ processed_outputs = self.postprocess(outputs)
 
 
 
 
 
 
 
 
 
 
93
 
94
+ return processed_outputs
95
+ except Exception as e:
96
+ return {"error": str(e)}