|
|
import base64
import io
import json
import os
from xml.sax.saxutils import escape

import numpy as np
import torch
from PIL import Image
|
|
|
|
|
|
|
|
# cairosvg rasterises the generated SVG to PNG. If it is missing, attempt a
# best-effort install at import time and then import it again.
try:
    import cairosvg
except ImportError:
    import importlib
    import subprocess
    import sys

    print("Warning: cairosvg not found. Installing...")
    # Run pip through the current interpreter so the packages are installed
    # where *this* process imports from. A bare "pip" on PATH may belong to a
    # different Python installation.
    subprocess.check_call(
        [
            sys.executable, "-m", "pip", "install",
            "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2",
        ]
    )
    # The import system caches directory listings. Drop those caches so the
    # freshly installed package is found by the import below.
    importlib.invalidate_caches()
    import cairosvg
|
|
|
|
|
class EndpointHandler:
    """Request handler that renders the prompt text as a placeholder image.

    No inference is performed in this class: the prompt is drawn as text on
    a flat grey SVG canvas. The SVG is returned raw and base64-encoded,
    together with a base64 PNG rasterisation produced by cairosvg.
    """

    # Canvas size, in pixels, of the generated SVG.
    WIDTH = 512
    HEIGHT = 512
    # Text rendered when the request carries no usable prompt.
    DEFAULT_PROMPT = "No prompt provided"
    # Base64 of a 1x1 PNG. It is returned when SVG -> PNG conversion fails so
    # that the response always contains a "png_base64" field.
    FALLBACK_PNG_BASE64 = (
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
    )

    def __init__(self, model_dir):
        """Initialize the handler with model directory.

        Args:
            model_dir: directory passed in by the hosting runtime. It is
                stored on the instance and not otherwise used here.
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initialized model on device: {self.device}")

    @classmethod
    def _extract_prompt(cls, data):
        """Return the prompt from a request payload, or DEFAULT_PROMPT.

        Recognised shapes are ``{"inputs": "<text>"}`` and
        ``{"inputs": {"text": <value>}}``. Anything else yields the default.
        """
        inputs = data.get("inputs") if isinstance(data, dict) else None
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict) and "text" in inputs:
            return inputs["text"]
        return cls.DEFAULT_PROMPT

    @classmethod
    def _render_svg(cls, prompt):
        """Return SVG markup with *prompt* centred on a grey canvas."""
        # Escape &, < and > so user text cannot produce malformed XML or
        # inject extra elements into the returned SVG.
        text = escape(str(prompt))
        return (
            f'<svg width="{cls.WIDTH}" height="{cls.HEIGHT}" xmlns="http://www.w3.org/2000/svg">\n'
            '    <rect width="100%" height="100%" fill="#f0f0f0"/>\n'
            '    <text x="50%" y="50%" font-family="Arial" font-size="20" '
            f'text-anchor="middle">{text}</text>\n'
            "</svg>"
        )

    def _rasterise(self, svg_bytes):
        """Return the base64 PNG for *svg_bytes*, or FALLBACK_PNG_BASE64 on failure."""
        try:
            png_data = cairosvg.svg2png(bytestring=svg_bytes)
        except Exception as e:
            # Best-effort: a failed rasterisation should not fail the request.
            print(f"Error converting SVG to PNG: {e}")
            return self.FALLBACK_PNG_BASE64
        return base64.b64encode(png_data).decode("utf-8")

    def __call__(self, data):
        """Handle a request to the model.

        Args:
            data: request payload (see ``_extract_prompt`` for the shapes
                that are recognised).

        Returns:
            On success, a dict with three entries:
                "svg": the SVG markup.
                "svg_base64": the UTF-8 SVG, base64-encoded.
                "png_base64": the rasterised PNG, base64-encoded.
            On any error, ``{"error": <message>}``.
        """
        try:
            prompt = self._extract_prompt(data)
            svg_content = self._render_svg(prompt)
            svg_bytes = svg_content.encode("utf-8")
            return {
                "svg": svg_content,
                "svg_base64": base64.b64encode(svg_bytes).decode("utf-8"),
                "png_base64": self._rasterise(svg_bytes),
            }
        except Exception as e:
            # Endpoint boundary: report the failure to the client instead of raising.
            return {"error": str(e)}