diffsketcher / handler.py
jree423's picture
Update: Improve handler implementation
8d52022 verified
raw
history blame
5.57 kB
import os
import io
import sys
import base64
import json
import torch
import numpy as np
from PIL import Image
import requests
import tempfile
import shutil
import subprocess
from pathlib import Path
# Add debug logging
def debug_log(message):
    """Write a ``DEBUG:``-prefixed line for *message* to stdout.

    The stream is flushed after every line so output appears immediately even
    when stdout is block-buffered (for example when it is not a terminal).
    """
    line = f"DEBUG: {message}"
    print(line)
    sys.stdout.flush()
debug_log("Starting handler initialization")

# Safely import cairosvg (used to rasterise SVG to PNG), installing it on demand.
# If it cannot be made importable, `cairosvg` is set to None instead of failing
# module import; EndpointHandler.__call__ catches conversion errors and falls
# back to a PIL-drawn placeholder image.
try:
    import cairosvg
    debug_log("Successfully imported cairosvg")
except ImportError:
    debug_log("cairosvg not found. Installing...")
    try:
        # Run pip through the current interpreter so the packages land in the
        # environment this process imports from; a bare "pip" on PATH may
        # belong to a different Python installation.
        subprocess.check_call(
            [
                sys.executable, "-m", "pip", "install",
                "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2",
            ]
        )
        import cairosvg
        debug_log("Installed and imported cairosvg")
    except (subprocess.CalledProcessError, ImportError, OSError) as e:
        # OSError: cairocffi raises it when the native Cairo library is missing,
        # which installing the Python packages cannot fix.
        debug_log(f"cairosvg unavailable, SVG rasterisation disabled: {e}")
        cairosvg = None
except OSError as e:
    # The Python package is present but its native Cairo library is not.
    debug_log(f"cairosvg unavailable, SVG rasterisation disabled: {e}")
    cairosvg = None
class EndpointHandler:
    """Inference-endpoint handler that turns a text prompt into a sketch image.

    The real model is not wired up yet. ``__init__`` only probes ``model_dir``
    for ``.pth`` weight files and logs what it finds. ``generate_svg`` emits a
    fixed placeholder drawing with the prompt as a caption. ``__call__``
    rasterises that SVG with cairosvg and always returns a PIL image, falling
    back to plain text images when something fails.
    """

    # Default raster size in pixels; also the default SVG canvas size.
    DEFAULT_SIZE = 512

    def __init__(self, model_dir):
        """Initialize the handler with model directory.

        Args:
            model_dir: Directory that is searched for ``checkpoint.pth``
                (or any other ``*.pth`` file). The weights are not loaded.

        Never raises. Any error is logged and the handler stays in placeholder
        mode (``self.use_model`` is False).
        """
        # Set attributes up front so the instance is consistent even if a
        # later step in the try block fails.
        self.model_dir = model_dir
        self.use_model = False
        self.device = torch.device("cpu")
        try:
            debug_log(f"Initializing handler with model_dir: {model_dir}")
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            debug_log(f"Using device: {self.device}")

            # Check if model weights exist
            weights_path = os.path.join(model_dir, "checkpoint.pth")
            if os.path.exists(weights_path):
                debug_log(f"Found model weights at {weights_path}")
                debug_log(f"Weights file size: {os.path.getsize(weights_path)} bytes")
            else:
                debug_log(f"Model weights not found at {weights_path}")
                # Report any other weight files; they are only logged, not loaded.
                for root, _dirs, files in os.walk(model_dir):
                    for name in files:
                        if name.endswith(".pth"):
                            debug_log(f"Found weights file: {os.path.join(root, name)}")

            # For now, we'll just use a placeholder implementation
            debug_log("Using placeholder implementation")
        except Exception as e:
            debug_log(f"Error in handler initialization: {e}")
            self._log_traceback()
            self.use_model = False

    def generate_svg(self, prompt, width=DEFAULT_SIZE, height=DEFAULT_SIZE):
        """Generate a placeholder SVG sketch captioned with *prompt*.

        Args:
            prompt: Caption text; any object is accepted and converted with
                ``str``. It is XML-escaped before being embedded.
            width, height: Canvas size in SVG user units.

        Returns:
            The SVG document as a string. The drawing itself does not depend
            on the prompt; only the caption does.
        """
        from xml.sax.saxutils import escape

        debug_log(f"Generating SVG for prompt: {prompt}")
        # The prompt is untrusted request text. Escape &, < and > so it cannot
        # produce malformed XML or inject extra SVG elements.
        label = escape(str(prompt))
        svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
  <rect width="100%" height="100%" fill="#ffffff"/>
  <g stroke="#000000" fill="none">
    <!-- Draw a simple sketch based on the prompt -->
    <circle cx="{width/2}" cy="{height/2}" r="{min(width, height)/4}" stroke-width="2"/>
    <ellipse cx="{width/2}" cy="{height/2}" rx="{width/3}" ry="{height/4}" stroke-width="1.5"/>
    <path d="M {width/4} {height/4} Q {width/2} {height/8} {3*width/4} {height/4}" stroke-width="2"/>
    <path d="M {width/4} {3*height/4} Q {width/2} {7*height/8} {3*width/4} {3*height/4}" stroke-width="2"/>
  </g>
  <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle" fill="#333333">{label}</text>
</svg>"""
        debug_log("Generated SVG content")
        return svg_content

    def __call__(self, data):
        """Handle a request to the model.

        Args:
            data: Request payload. The prompt is taken from ``data["inputs"]``
                when that is a string, or from ``data["inputs"]["text"]``.
                Otherwise "No prompt provided" is used.

        Returns:
            A PIL image in every case:
            - the rasterised sketch (RGB) on success;
            - a grey placeholder with the prompt if rasterisation fails;
            - a red image with the error message if anything else fails.
        """
        try:
            debug_log(f"Handling request: {data}")
            prompt = self._extract_prompt(data)
            debug_log(f"Extracted prompt: {prompt}")

            svg_content = self.generate_svg(prompt)

            # Convert SVG to PNG
            try:
                if cairosvg is None:
                    raise RuntimeError("cairosvg is not available")
                png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
                # Image.open is lazy. convert() forces decoding here, inside the
                # try, and yields RGB like the fallback images below.
                image = Image.open(io.BytesIO(png_data)).convert("RGB")
                debug_log("Generated image from SVG")
            except Exception as e:
                debug_log(f"Error converting SVG to PNG: {e}")
                self._log_traceback()
                image = self._text_image(prompt, background="#f0f0f0", ink="black")
                debug_log("Created placeholder image")

            debug_log("Returning image")
            return image
        except Exception as e:
            debug_log(f"Error in handler: {e}")
            self._log_traceback()
            image = self._text_image(f"Error: {str(e)}", background="#ff0000", ink="white")
            debug_log("Returning error image")
            return image

    @staticmethod
    def _extract_prompt(data):
        """Return the prompt carried by a request payload, or "No prompt provided"."""
        inputs = data.get("inputs") if isinstance(data, dict) else None
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict) and "text" in inputs:
            return inputs["text"]
        return "No prompt provided"

    def _text_image(self, text, background, ink):
        """Return a DEFAULT_SIZE square RGB image with *text* drawn at its centre."""
        from PIL import ImageDraw

        size = self.DEFAULT_SIZE
        image = Image.new("RGB", (size, size), color=background)
        draw = ImageDraw.Draw(image)
        # anchor="mm" centres the text on the point for TrueType fonts. Per the
        # Pillow docs, the anchor is ignored for bitmap fonts.
        draw.text((size // 2, size // 2), str(text), fill=ink, anchor="mm")
        return image

    @staticmethod
    def _log_traceback():
        """Log the traceback of the exception currently being handled."""
        import traceback

        debug_log(traceback.format_exc())