# diffsketcher/handler.py
# Source commit 6d9283b (verified) by jree423:
#   "Add: diffsketcher handler.py with original implementation"
import os
import io
import sys
import torch
import numpy as np
from PIL import Image
import traceback
# Add debug logging
def debug_log(message):
    """Print *message* to stdout with a ``DEBUG: `` prefix and flush at once.

    Flushing on every call keeps the output visible immediately even when
    stdout is buffered.
    """
    print(f"DEBUG: {message}", flush=True)
debug_log("Starting handler initialization")

# Import cairosvg, installing it on the fly when it is missing.
try:
    import cairosvg
    debug_log("Successfully imported cairosvg")
except ImportError:
    debug_log("cairosvg not found. Installing...")
    import importlib
    import subprocess

    # Run pip through the interpreter executing this module so the packages
    # land in the environment we import from; a bare "pip" found on PATH may
    # belong to a different Python installation.
    subprocess.check_call(
        [
            sys.executable, "-m", "pip", "install",
            "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2",
        ]
    )
    # The package was created after the interpreter started, so drop the
    # import finders' cached directory listings before retrying the import.
    importlib.invalidate_caches()
    import cairosvg
    debug_log("Installed and imported cairosvg")
# Make the DiffSketcher sources importable.
sys.path.append('/code/diffsketcher')

# The model classes are optional at import time: a failed import is logged
# and not re-raised, which leaves the three class names undefined.
try:
    from models.clip_model import ClipModel
    from models.diffusion_model import DiffusionModel
    from models.sketch_model import SketchModel
except ImportError as import_error:
    debug_log(f"Error importing DiffSketcher models: {import_error}")
    debug_log(traceback.format_exc())
else:
    debug_log("Successfully imported DiffSketcher models")
class EndpointHandler:
    """Endpoint handler that renders a text prompt as a sketch image.

    The handler tries to build the DiffSketcher models at construction time.
    When that fails for any reason it stays usable and serves a placeholder
    drawing instead, so a request always yields an image.
    """

    def __init__(self, model_dir):
        """Initialize the handler with model directory.

        Args:
            model_dir: Directory that may contain ``checkpoint.pth``.

        Never raises for model problems: any failure is logged and leaves
        ``self.use_model`` set to ``False``.
        """
        # Set the defaults first so the attributes exist no matter where
        # initialization fails below.
        self.use_model = False
        self.model_dir = model_dir
        try:
            debug_log(f"Initializing handler with model_dir: {model_dir}")
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            debug_log(f"Using device: {self.device}")
        except Exception as e:
            debug_log(f"Error in handler initialization: {e}")
            debug_log(traceback.format_exc())
            return

        try:
            self.clip_model = ClipModel(device=self.device)
            self.diffusion_model = DiffusionModel(device=self.device)
            self.sketch_model = SketchModel(device=self.device)

            # Load checkpoint if available
            weights_path = os.path.join(model_dir, "checkpoint.pth")
            if os.path.exists(weights_path):
                debug_log(f"Loading checkpoint from {weights_path}")
                # NOTE(review): torch.load unpickles the file, which can run
                # arbitrary code. Only point model_dir at trusted checkpoints
                # (consider weights_only=True if the torch version allows it).
                checkpoint = torch.load(weights_path, map_location=self.device)
                self.sketch_model.load_state_dict(checkpoint['sketch_model'])
                debug_log("Successfully loaded checkpoint")
            else:
                debug_log(f"Checkpoint not found at {weights_path}, using model without pre-trained weights")
            # Reached only when every step above succeeded.
            self.use_model = True
        except Exception as e:
            debug_log(f"Error initializing model: {e}")
            debug_log(traceback.format_exc())

    def generate_svg(self, prompt, width=512, height=512):
        """Generate an SVG from a text prompt.

        Args:
            prompt: Text describing the sketch.
            width: Canvas width passed to the sketch model / placeholder.
            height: Canvas height passed to the sketch model / placeholder.

        Returns:
            Whatever ``sketch_model.generate`` returns (presumably SVG markup
            as a ``str`` -- the caller encodes it as UTF-8), or the
            placeholder SVG string when the model is unavailable or raises.
        """
        debug_log(f"Generating SVG for prompt: {prompt}")

        if not self.use_model:
            debug_log("Model not initialized, using placeholder")
            return self._generate_placeholder_svg(prompt, width, height)

        try:
            debug_log("Using initialized model")
            # Generate SVG using DiffSketcher
            text_features = self.clip_model.encode_text(prompt)
            latent = self.diffusion_model.generate(text_features)
            svg_data = self.sketch_model.generate(latent, num_paths=20, width=width, height=height)
            debug_log("Generated SVG using DiffSketcher")
            return svg_data
        except Exception as e:
            # Best effort: a model failure degrades to the placeholder.
            debug_log(f"Error generating SVG with model: {e}")
            debug_log(traceback.format_exc())
            return self._generate_placeholder_svg(prompt, width, height)

    def _generate_placeholder_svg(self, prompt, width=512, height=512):
        """Generate a placeholder SVG with the prompt drawn over fixed shapes.

        Args:
            prompt: Text shown in the middle of the drawing.
            width: SVG width.
            height: SVG height.

        Returns:
            The SVG document as a string.
        """
        from xml.sax.saxutils import escape

        debug_log(f"Generating placeholder SVG for prompt: {prompt}")

        # The prompt is request data: escape XML metacharacters so text such
        # as "cats & dogs" or "<b>" cannot produce malformed SVG or inject
        # markup into the document.
        safe_prompt = escape(str(prompt))

        # Create a more interesting placeholder that looks like a sketch
        svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
<rect width="100%" height="100%" fill="#ffffff"/>
<g stroke="#000000" fill="none">
<!-- Draw a simple sketch based on the prompt -->
<circle cx="{width/2}" cy="{height/2}" r="{min(width, height)/4}" stroke-width="2"/>
<ellipse cx="{width/2}" cy="{height/2}" rx="{width/3}" ry="{height/4}" stroke-width="1.5"/>
<path d="M {width/4} {height/4} Q {width/2} {height/8} {3*width/4} {height/4}" stroke-width="2"/>
<path d="M {width/4} {3*height/4} Q {width/2} {7*height/8} {3*width/4} {3*height/4}" stroke-width="2"/>
</g>
<text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle" fill="#333333">{safe_prompt}</text>
</svg>"""
        debug_log("Generated placeholder SVG")
        return svg_content

    @staticmethod
    def _extract_prompt(data):
        """Return the prompt found in a request payload, or a default text.

        Accepts ``{"inputs": "<prompt>"}`` and ``{"inputs": {"text": ...}}``;
        any other shape yields ``"No prompt provided"``.
        """
        inputs = data.get("inputs") if isinstance(data, dict) else None
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict) and "text" in inputs:
            return inputs["text"]
        return "No prompt provided"

    @staticmethod
    def _text_image(text, background, fill, size=(512, 512)):
        """Build a solid-color RGB image with *text* drawn at its center."""
        from PIL import ImageDraw

        image = Image.new("RGB", size, color=background)
        draw = ImageDraw.Draw(image)
        draw.text((size[0] // 2, size[1] // 2), str(text), fill=fill, anchor="mm")
        return image

    def __call__(self, data):
        """Handle a request to the model.

        Args:
            data: Request payload; see ``_extract_prompt`` for accepted shapes.

        Returns:
            A PIL image: the rendered SVG, a grey placeholder image when the
            SVG cannot be rasterized, or a red error image when anything else
            in the request fails.
        """
        try:
            debug_log(f"Handling request: {data}")

            prompt = self._extract_prompt(data)
            debug_log(f"Extracted prompt: {prompt}")

            # Generate SVG
            svg_content = self.generate_svg(prompt)

            # Convert SVG to PNG
            try:
                png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
                image = Image.open(io.BytesIO(png_data))
                debug_log("Generated image from SVG")
            except Exception as e:
                debug_log(f"Error converting SVG to PNG: {e}")
                debug_log(traceback.format_exc())
                # Create a simple placeholder image
                image = self._text_image(prompt, background="#f0f0f0", fill="black")
                debug_log("Created placeholder image")

            # Return the PIL Image directly
            debug_log("Returning image")
            return image
        except Exception as e:
            debug_log(f"Error in handler: {e}")
            debug_log(traceback.format_exc())
            # Return a simple error image
            image = self._text_image(f"Error: {str(e)}", background="#ff0000", fill="white")
            debug_log("Returning error image")
            return image