"""Inference endpoint handler that turns a text prompt into a sketch image.

The handler tries to use the DiffSketcher models found under
``/code/diffsketcher``; whenever a step fails it logs the error and falls
back to a placeholder so that a PIL image is still returned to the caller.
"""

import io
import os
import subprocess
import sys
import traceback
from xml.sax.saxutils import escape

import numpy as np
import torch
from PIL import Image, ImageDraw


def debug_log(message):
    """Print *message* with a ``DEBUG:`` prefix and flush stdout immediately."""
    print(f"DEBUG: {message}", flush=True)


debug_log("Starting handler initialization")

# Safely import cairosvg with fallback: install it on demand when missing.
try:
    import cairosvg
    debug_log("Successfully imported cairosvg")
except ImportError:
    debug_log("cairosvg not found. Installing...")
    # Run pip through the current interpreter so the packages land in the
    # environment this process actually imports from; a bare "pip" on PATH
    # may belong to a different Python installation.
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2",
    ])
    import cairosvg
    debug_log("Installed and imported cairosvg")

# Add the model directory to the path
sys.path.append('/code/diffsketcher')

# Try to import the model. A failure is only logged here; the handler's
# constructor then fails to build the models and disables model use.
try:
    from models.clip_model import ClipModel
    from models.diffusion_model import DiffusionModel
    from models.sketch_model import SketchModel
    debug_log("Successfully imported DiffSketcher models")
except ImportError as e:
    debug_log(f"Error importing DiffSketcher models: {e}")
    debug_log(traceback.format_exc())


class EndpointHandler:
    """Callable request handler: prompt in, PIL image out."""

    def __init__(self, model_dir):
        """Initialize the handler with model directory.

        Args:
            model_dir: Directory searched for ``checkpoint.pth``.

        Never raises: any failure is logged and leaves ``self.use_model``
        set to ``False`` so that requests are served with placeholders.
        """
        self.use_model = False

        try:
            debug_log(f"Initializing handler with model_dir: {model_dir}")
            self.model_dir = model_dir
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            debug_log(f"Using device: {self.device}")
        except Exception as e:
            debug_log(f"Error in handler initialization: {e}")
            debug_log(traceback.format_exc())
            return

        try:
            self._load_models(model_dir)
        except Exception as e:
            # Also reached when the module-level model import failed and the
            # model class names are therefore undefined.
            debug_log(f"Error initializing model: {e}")
            debug_log(traceback.format_exc())
        else:
            self.use_model = True

    def _load_models(self, model_dir):
        """Build the three models and load the sketch checkpoint if present."""
        self.clip_model = ClipModel(device=self.device)
        self.diffusion_model = DiffusionModel(device=self.device)
        self.sketch_model = SketchModel(device=self.device)

        # Load checkpoint if available
        weights_path = os.path.join(model_dir, "checkpoint.pth")
        if not os.path.exists(weights_path):
            debug_log(f"Checkpoint not found at {weights_path}, using model without pre-trained weights")
            return

        debug_log(f"Loading checkpoint from {weights_path}")
        # SECURITY NOTE(review): torch.load unpickles the file, which can run
        # arbitrary code. Only point model_dir at trusted checkpoints.
        checkpoint = torch.load(weights_path, map_location=self.device)
        self.sketch_model.load_state_dict(checkpoint['sketch_model'])
        debug_log("Successfully loaded checkpoint")

    def generate_svg(self, prompt, width=512, height=512):
        """Generate an SVG from a text prompt.

        Args:
            prompt: Text describing the sketch.
            width: Requested SVG width.
            height: Requested SVG height.

        Returns:
            The sketch model's output when the models are usable, otherwise
            the placeholder SVG string.
        """
        debug_log(f"Generating SVG for prompt: {prompt}")

        if not self.use_model:
            debug_log("Model not initialized, using placeholder")
            return self._generate_placeholder_svg(prompt, width, height)

        try:
            debug_log("Using initialized model")
            # Generate SVG using DiffSketcher
            text_features = self.clip_model.encode_text(prompt)
            latent = self.diffusion_model.generate(text_features)
            svg_data = self.sketch_model.generate(latent, num_paths=20, width=width, height=height)
            debug_log("Generated SVG using DiffSketcher")
            return svg_data
        except Exception as e:
            debug_log(f"Error generating SVG with model: {e}")
            debug_log(traceback.format_exc())
            return self._generate_placeholder_svg(prompt, width, height)

    def _generate_placeholder_svg(self, prompt, width=512, height=512):
        """Generate a placeholder SVG.

        Returns a complete, well-formed SVG document of the requested size
        that shows a few sketch-like strokes and the prompt text.
        """
        debug_log(f"Generating placeholder SVG for prompt: {prompt}")

        # The prompt is user input: escape &, < and > so it cannot break the
        # XML or inject markup into the document.
        safe_prompt = escape(str(prompt))

        # Stroke coordinates are derived from the canvas size so the drawing
        # scales with width/height.
        left, right = width * 0.15, width * 0.85
        mid_x, mid_y = width / 2, height / 2
        wave_y = height * 0.7
        wave_amp = height * 0.08

        # A bare prompt string is not XML, so the placeholder has to be a full
        # <svg> document for the SVG-to-PNG step to be able to parse it.
        svg_content = (
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" '
            f'viewBox="0 0 {width} {height}">\n'
            f'  <rect width="100%" height="100%" fill="#f8f8f8"/>\n'
            f'  <rect x="{width * 0.05}" y="{height * 0.05}" width="{width * 0.9}" '
            f'height="{height * 0.9}" fill="none" stroke="#333333" stroke-width="2" '
            f'stroke-dasharray="8 6"/>\n'
            f'  <path d="M {left} {wave_y} Q {(left + mid_x) / 2} {wave_y - wave_amp} '
            f'{mid_x} {wave_y} T {right} {wave_y}" fill="none" stroke="#333333" '
            f'stroke-width="2" stroke-linecap="round"/>\n'
            f'  <text x="{mid_x}" y="{mid_y}" font-family="sans-serif" font-size="20" '
            f'text-anchor="middle" fill="#333333">{safe_prompt}</text>\n'
            f'</svg>'
        )

        debug_log("Generated placeholder SVG")
        return svg_content

    @staticmethod
    def _extract_prompt(data):
        """Pull the prompt out of a request payload.

        Accepts ``{"inputs": "<text>"}`` or ``{"inputs": {"text": ...}}``;
        anything else yields the string ``"No prompt provided"``.
        """
        if isinstance(data, dict):
            inputs = data.get("inputs")
            if isinstance(inputs, str):
                return inputs
            if isinstance(inputs, dict) and "text" in inputs:
                return inputs["text"]
        return "No prompt provided"

    @staticmethod
    def _text_image(message, background, fill):
        """Return a 512x512 RGB image with *message* drawn at its centre.

        Drawing is best-effort: this helper runs inside error handlers, so a
        failure to render the text (for example a font that cannot encode the
        message) is logged and the plain coloured image is returned instead
        of raising.
        """
        image = Image.new("RGB", (512, 512), color=background)
        try:
            ImageDraw.Draw(image).text((256, 256), str(message), fill=fill, anchor="mm")
        except Exception as e:
            debug_log(f"Could not draw text on fallback image: {e}")
        return image

    def __call__(self, data):
        """Handle a request to the model.

        Args:
            data: Request payload; see ``_extract_prompt`` for accepted shapes.

        Returns:
            A PIL image: the rasterized SVG, a grey placeholder when the
            SVG-to-PNG conversion fails, or a red error image when anything
            else goes wrong.
        """
        try:
            debug_log(f"Handling request: {data}")

            # Extract the prompt
            prompt = self._extract_prompt(data)
            debug_log(f"Extracted prompt: {prompt}")

            # Generate SVG
            svg_content = self.generate_svg(prompt)

            # Convert SVG to PNG
            try:
                png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
                image = Image.open(io.BytesIO(png_data))
                debug_log("Generated image from SVG")
            except Exception as e:
                debug_log(f"Error converting SVG to PNG: {e}")
                debug_log(traceback.format_exc())
                # Create a simple placeholder image
                image = self._text_image(prompt, background="#f0f0f0", fill="black")
                debug_log("Created placeholder image")

            # Return the PIL Image directly
            debug_log("Returning image")
            return image
        except Exception as e:
            debug_log(f"Error in handler: {e}")
            debug_log(traceback.format_exc())
            # Return a simple error image
            image = self._text_image(f"Error: {str(e)}", background="#ff0000", fill="white")
            debug_log("Returning error image")
            return image