Update handler to return PIL Images for Inference API compatibility
Browse files- config.json +1 -1
- handler.py +15 -3
config.json
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
"model_type": "svgdreamer",
|
| 4 |
"task": "text-to-svg",
|
| 5 |
"framework": "pytorch",
|
| 6 |
-
"pipeline_tag": "text-
|
| 7 |
"library_name": "diffusers",
|
| 8 |
"inference": {
|
| 9 |
"parameters": {
|
|
|
|
| 3 |
"model_type": "svgdreamer",
|
| 4 |
"task": "text-to-svg",
|
| 5 |
"framework": "pytorch",
|
| 6 |
+
"pipeline_tag": "text-to-image",
|
| 7 |
"library_name": "diffusers",
|
| 8 |
"inference": {
|
| 9 |
"parameters": {
|
handler.py
CHANGED
|
@@ -4,6 +4,9 @@ import json
|
|
| 4 |
import torch
|
| 5 |
import numpy as np
|
| 6 |
from typing import Dict, Any, List
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
# SVGDreamer handler for Hugging Face Inference API
|
| 9 |
|
|
@@ -88,11 +91,20 @@ class EndpointHandler:
|
|
| 88 |
# Generate the best particle SVG (simulate particle selection)
|
| 89 |
best_svg = self._generate_particle_svg(prompt, width, height, style, seed)
|
| 90 |
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
except Exception as e:
|
| 94 |
-
# Return error
|
| 95 |
-
|
|
|
|
| 96 |
|
| 97 |
def _generate_particle_svg(self, prompt: str, width: int, height: int, style: str, seed: int) -> str:
|
| 98 |
"""Generate SVG using particle-based optimization simulation"""
|
|
|
|
| 4 |
import torch
|
| 5 |
import numpy as np
|
| 6 |
from typing import Dict, Any, List
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import cairosvg
|
| 9 |
+
import io
|
| 10 |
|
| 11 |
# SVGDreamer handler for Hugging Face Inference API
|
| 12 |
|
|
|
|
| 91 |
# Generate the best particle SVG (simulate particle selection)
|
| 92 |
best_svg = self._generate_particle_svg(prompt, width, height, style, seed)
|
| 93 |
|
| 94 |
+
# Convert SVG to PIL Image
|
| 95 |
+
try:
|
| 96 |
+
png_data = cairosvg.svg2png(bytestring=best_svg.encode('utf-8'))
|
| 97 |
+
image = Image.open(io.BytesIO(png_data))
|
| 98 |
+
return image
|
| 99 |
+
except Exception as svg_error:
|
| 100 |
+
# Fallback: create a simple error image
|
| 101 |
+
error_image = Image.new('RGB', (width, height), color='white')
|
| 102 |
+
return error_image
|
| 103 |
|
| 104 |
except Exception as e:
|
| 105 |
+
# Return error image
|
| 106 |
+
error_image = Image.new('RGB', (224, 224), color='white')
|
| 107 |
+
return error_image
|
| 108 |
|
| 109 |
def _generate_particle_svg(self, prompt: str, width: int, height: int, style: str, seed: int) -> str:
|
| 110 |
"""Generate SVG using particle-based optimization simulation"""
|