jree423 commited on
Commit
2ec5443
·
verified ·
1 Parent(s): 047a30a

Fix handler to return PIL Images instead of dictionaries for HF API compatibility

Browse files
Files changed (1) hide show
  1. handler.py +53 -4
handler.py CHANGED
@@ -78,13 +78,31 @@ class EndpointHandler:
78
  prompt, n_particle, num_iter, guidance_scale, width, height, style
79
  )
80
 
81
- return particles
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  except Exception as e:
84
  print(f"Error in handler: {e}")
85
- # Return fallback particles
86
- fallback_particles = self.create_fallback_particles(prompt, n_particle, width, height, style)
87
- return fallback_particles
 
 
 
88
 
89
  def generate_svg_particles(self, prompt, n_particle, num_iter, guidance_scale, width, height, style):
90
  """Generate multiple SVG particles using SVGDreamer approach"""
@@ -527,6 +545,37 @@ class EndpointHandler:
527
 
528
  return particles
529
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
  def create_fallback_svg(self, prompt, width, height, style):
531
  """Create simple fallback SVG"""
532
  dwg = svgwrite.Drawing(size=(width, height))
 
78
  prompt, n_particle, num_iter, guidance_scale, width, height, style
79
  )
80
 
81
+ # Convert the first particle to PIL Image for HF API compatibility
82
+ if particles and len(particles) > 0:
83
+ main_svg = particles[0]["svg"]
84
+ pil_image = self.svg_to_pil_image(main_svg, width, height)
85
+
86
+ # Store all particles data as metadata
87
+ pil_image.info['particles'] = json.dumps(particles)
88
+ pil_image.info['prompt'] = prompt
89
+ pil_image.info['style'] = style
90
+ pil_image.info['n_particle'] = str(n_particle)
91
+
92
+ return pil_image
93
+ else:
94
+ # Fallback
95
+ fallback_svg = self.create_fallback_svg(prompt, width, height, style)
96
+ return self.svg_to_pil_image(fallback_svg, width, height)
97
 
98
  except Exception as e:
99
  print(f"Error in handler: {e}")
100
+ # Return fallback image
101
+ fallback_svg = self.create_fallback_svg(prompt, width, height, style)
102
+ fallback_image = self.svg_to_pil_image(fallback_svg, width, height)
103
+ fallback_image.info['error'] = str(e)
104
+ fallback_image.info['prompt'] = prompt
105
+ return fallback_image
106
 
107
  def generate_svg_particles(self, prompt, n_particle, num_iter, guidance_scale, width, height, style):
108
  """Generate multiple SVG particles using SVGDreamer approach"""
 
545
 
546
  return particles
547
 
548
+ def svg_to_pil_image(self, svg_content, width, height):
549
+ """Convert SVG content to PIL Image"""
550
+ try:
551
+ import cairosvg
552
+ import io
553
+
554
+ # Convert SVG to PNG bytes
555
+ png_bytes = cairosvg.svg2png(
556
+ bytestring=svg_content.encode('utf-8'),
557
+ output_width=width,
558
+ output_height=height
559
+ )
560
+
561
+ # Convert to PIL Image
562
+ from PIL import Image
563
+ image = Image.open(io.BytesIO(png_bytes)).convert('RGB')
564
+ return image
565
+
566
+ except ImportError:
567
+ print("cairosvg not available, creating simple image representation")
568
+ # Fallback: create a simple image with text
569
+ from PIL import Image
570
+ image = Image.new('RGB', (width, height), 'white')
571
+ return image
572
+ except Exception as e:
573
+ print(f"Error converting SVG to image: {e}")
574
+ # Fallback: create a simple image
575
+ from PIL import Image
576
+ image = Image.new('RGB', (width, height), 'white')
577
+ return image
578
+
579
  def create_fallback_svg(self, prompt, width, height, style):
580
  """Create simple fallback SVG"""
581
  dwg = svgwrite.Drawing(size=(width, height))