diffsketcher / handler.py
jree423's picture
Fix: Update handler.py with EndpointHandler class
54f01ed verified
raw
history blame
2.57 kB
import base64
import io
import json
import os
from xml.sax.saxutils import escape as xml_escape

import numpy as np
import torch
from PIL import Image
# Safely import cairosvg with fallback: if it is missing, install it (together
# with the packages listed below) at import time, then retry the import.
# NOTE(review): installing at runtime needs network access and a writable
# site-packages; declaring these in the endpoint's requirements is preferable.
try:
    import cairosvg
except ImportError:
    print("Warning: cairosvg not found. Installing...")
    import importlib
    import subprocess
    import sys

    # Invoke pip through the running interpreter so the packages land in the
    # environment this process imports from; a bare "pip" on PATH may belong
    # to a different Python installation.
    subprocess.check_call(
        [
            sys.executable, "-m", "pip", "install",
            "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2",
        ]
    )
    # Modules installed after interpreter start-up may be missed by the
    # import system's cached path finders until the caches are invalidated.
    importlib.invalidate_caches()
    import cairosvg
class EndpointHandler:
    """Custom inference-endpoint handler that returns a placeholder image.

    No model is run yet: every request renders a fixed-size SVG that displays
    the prompt text, and returns it as raw SVG markup, base64-encoded SVG and
    base64-encoded PNG.
    """

    # Base64 PNG returned when SVG-to-PNG conversion fails (described by the
    # original author as a transparent 1x1 pixel image).
    _FALLBACK_PNG_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="

    def __init__(self, model_dir):
        """Initialize the handler with model directory.

        Args:
            model_dir: Model directory path; stored on the instance but not
                otherwise used by this handler.
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initialized model on device: {self.device}")

    def __call__(self, data):
        """Handle a request to the model.

        Args:
            data: Request payload. The prompt is taken from ``data["inputs"]``
                when it is a string, or from ``data["inputs"]["text"]`` when
                ``inputs`` is a dict; otherwise the prompt is
                ``"No prompt provided"``.

        Returns:
            A dict with ``"svg"`` (SVG markup), ``"svg_base64"`` and
            ``"png_base64"``; or ``{"error": message}`` if an unexpected
            exception is raised while building the response.
        """
        try:
            prompt = self._extract_prompt(data)
            svg_content = self._render_svg(prompt)
            svg_bytes = svg_content.encode("utf-8")
            svg_base64 = base64.b64encode(svg_bytes).decode("utf-8")
            png_base64 = self._svg_to_png_base64(svg_bytes)
            return {
                "svg": svg_content,
                "svg_base64": svg_base64,
                "png_base64": png_base64,
            }
        except Exception as e:
            # Endpoint boundary: report the failure in the response instead of
            # propagating it to the serving framework.
            return {"error": str(e)}

    @staticmethod
    def _extract_prompt(data):
        """Return the prompt from a request payload, or a placeholder text."""
        if isinstance(data, dict) and "inputs" in data:
            inputs = data["inputs"]
            if isinstance(inputs, str):
                return inputs
            if isinstance(inputs, dict) and "text" in inputs:
                return inputs["text"]
        return "No prompt provided"

    @staticmethod
    def _render_svg(prompt, width=512, height=512):
        """Build a placeholder SVG showing *prompt* as centred text."""
        # Escape &, < and > so arbitrary prompt text cannot produce malformed
        # XML (e.g. "cats & dogs"); str() keeps non-string "text" values
        # rendering exactly as the f-string interpolation did.
        text = xml_escape(str(prompt))
        return (
            f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">\n'
            '<rect width="100%" height="100%" fill="#f0f0f0"/>\n'
            '<text x="50%" y="50%" font-family="Arial" font-size="20" '
            f'text-anchor="middle">{text}</text>\n'
            "</svg>"
        )

    @classmethod
    def _svg_to_png_base64(cls, svg_bytes):
        """Rasterize SVG bytes with cairosvg and return the PNG base64-encoded.

        Any conversion error is printed and the fallback PNG is returned, so
        the response always carries a PNG.
        """
        try:
            png_data = cairosvg.svg2png(bytestring=svg_bytes)
        except Exception as e:
            print(f"Error converting SVG to PNG: {e}")
            return cls._FALLBACK_PNG_BASE64
        return base64.b64encode(png_data).decode("utf-8")