jree423 commited on
Commit
a32d774
·
verified ·
1 Parent(s): 9c00a8e

Update handler to return PIL Images for Inference API compatibility

Browse files
Files changed (2) hide show
  1. config.json +1 -1
  2. handler.py +15 -3
config.json CHANGED
@@ -3,7 +3,7 @@
3
  "model_type": "svgdreamer",
4
  "task": "text-to-svg",
5
  "framework": "pytorch",
6
- "pipeline_tag": "text-generation",
7
  "library_name": "diffusers",
8
  "inference": {
9
  "parameters": {
 
3
  "model_type": "svgdreamer",
4
  "task": "text-to-svg",
5
  "framework": "pytorch",
6
+ "pipeline_tag": "text-to-image",
7
  "library_name": "diffusers",
8
  "inference": {
9
  "parameters": {
handler.py CHANGED
@@ -4,6 +4,9 @@ import json
4
  import torch
5
  import numpy as np
6
  from typing import Dict, Any, List
 
 
 
7
 
8
  # SVGDreamer handler for Hugging Face Inference API
9
 
@@ -88,11 +91,20 @@ class EndpointHandler:
88
  # Generate the best particle SVG (simulate particle selection)
89
  best_svg = self._generate_particle_svg(prompt, width, height, style, seed)
90
 
91
- return best_svg
 
 
 
 
 
 
 
 
92
 
93
  except Exception as e:
94
- # Return error SVG
95
- return f'<svg width="224" height="224" xmlns="http://www.w3.org/2000/svg"><text x="10" y="20" fill="red">Error: {str(e)}</text></svg>'
 
96
 
97
  def _generate_particle_svg(self, prompt: str, width: int, height: int, style: str, seed: int) -> str:
98
  """Generate SVG using particle-based optimization simulation"""
 
4
  import torch
5
  import numpy as np
6
  from typing import Dict, Any, List
7
+ from PIL import Image
8
+ import cairosvg
9
+ import io
10
 
11
  # SVGDreamer handler for Hugging Face Inference API
12
 
 
91
  # Generate the best particle SVG (simulate particle selection)
92
  best_svg = self._generate_particle_svg(prompt, width, height, style, seed)
93
 
94
+ # Convert SVG to PIL Image
95
+ try:
96
+ png_data = cairosvg.svg2png(bytestring=best_svg.encode('utf-8'))
97
+ image = Image.open(io.BytesIO(png_data))
98
+ return image
99
+ except Exception as svg_error:
100
+ # Fallback: create a simple error image
101
+ error_image = Image.new('RGB', (width, height), color='white')
102
+ return error_image
103
 
104
  except Exception as e:
105
+ # Return error image
106
+ error_image = Image.new('RGB', (224, 224), color='white')
107
+ return error_image
108
 
109
  def _generate_particle_svg(self, prompt: str, width: int, height: int, style: str, seed: int) -> str:
110
  """Generate SVG using particle-based optimization simulation"""