import io
import math
import os
import random
import shutil
import sys
import tempfile
import textwrap
import zlib
from pathlib import Path

import cairosvg
import torch
import yaml
from omegaconf import OmegaConf
from PIL import Image, ImageDraw

# Add DiffSketcher modules to path
sys.path.append('/workspace/DiffSketcher')

# Shared outline style for the shapes drawn by the fallback sketch generator.
_STROKE = 'fill="none" stroke="black" stroke-width="2"'


class EndpointHandler:
    """Text-to-sketch handler that returns DiffSketcher output as a PIL image.

    If the DiffSketcher modules or its config cannot be loaded, the handler
    still serves requests using a simple keyword-driven SVG sketch generator.
    """

    def __init__(self, path=""):
        """Initialize DiffSketcher model for Hugging Face Inference API.

        Args:
            path (str): Model directory. ``<path>/config/diffsketcher.yaml`` is
                used if it exists, otherwise the config next to this file.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing DiffSketcher on {self.device}")

        try:
            # Import DiffSketcher modules (importable via the sys.path entry above)
            from libs.engine import ModelState
            from methods.painter.diffsketcher import DiffSketcher

            # Load configuration
            config_path = Path(path) / "config" / "diffsketcher.yaml"
            if not config_path.exists():
                # Use default config
                config_path = Path(__file__).parent / "config" / "diffsketcher.yaml"

            with open(config_path, 'r', encoding='utf-8') as f:
                self.config = OmegaConf.load(f)

            # Initialize model components
            self.model_state = ModelState(self.config)
            self.painter = DiffSketcher(self.config, self.device, self.model_state)
            print("DiffSketcher initialized successfully")

        except Exception as e:
            print(f"Error initializing DiffSketcher: {e}")
            # Fall back to simple SVG generation
            self.painter = None
            self.config = None
            self.model_state = None

    def __call__(self, data):
        """
        Generate sketch image from text prompt

        Args:
            data (dict): Input data containing:
                - inputs (str): Text prompt
                - parameters (dict): Optional generation parameters
                  (num_paths, num_iter, guidance_scale, seed, width, height)

        Returns:
            PIL.Image.Image: Generated RGB sketch image. Failures never raise;
            an error image (at the requested size when known) is returned.
        """
        # Defaults are bound before parsing so the error path can always size its image.
        width, height = 224, 224
        try:
            # Extract inputs; an explicit JSON null for "parameters" means "none given".
            prompt = data.get("inputs", "")
            parameters = data.get("parameters") or {}

            # Extract parameters, coercing JSON values that may arrive as strings
            num_paths = int(parameters.get("num_paths", 96))
            num_iter = int(parameters.get("num_iter", 500))
            guidance_scale = float(parameters.get("guidance_scale", 7.5))
            seed = int(parameters.get("seed", 42))
            req_width = int(parameters.get("width", 224))
            req_height = int(parameters.get("height", 224))
            if req_width < 1 or req_height < 1:
                raise ValueError("width and height must be positive")
            width, height = req_width, req_height

            if not prompt:
                return self._create_error_image("No prompt provided", width, height)
            if not isinstance(prompt, str):
                return self._create_error_image("Prompt must be a string", width, height)

            # Generate SVG
            if self.painter is not None:
                svg_content = self._generate_with_diffsketcher(
                    prompt, num_paths, num_iter, guidance_scale, seed,
                    width=width, height=height,
                )
            else:
                svg_content = self._generate_fallback_svg(prompt, width, height, seed=seed)

            # Convert SVG to PIL Image
            return self._svg_to_image(svg_content, width, height)

        except Exception as e:
            print(f"Error in DiffSketcher inference: {e}")
            return self._create_error_image(f"Error: {str(e)[:50]}", width, height)

    def _generate_with_diffsketcher(self, prompt, num_paths, num_iter, guidance_scale,
                                    seed, width=224, height=224):
        """Generate SVG using actual DiffSketcher model.

        Falls back to ``_generate_fallback_svg`` at ``width`` x ``height`` when
        painting raises or writes no ``.svg`` file.

        Returns:
            str: SVG document text.
        """
        try:
            # Set torch's global random seed
            torch.manual_seed(seed)

            # Temporary output directory; deleted when the with-block exits,
            # so the SVG must be read inside it.
            with tempfile.TemporaryDirectory() as temp_dir:
                output_dir = Path(temp_dir) / "output"
                output_dir.mkdir(exist_ok=True)

                # NOTE(review): guidance_scale is accepted but not forwarded;
                # confirm whether DiffSketcher.paint supports it.
                self.painter.paint(
                    prompt=prompt,
                    output_dir=str(output_dir),
                    num_paths=num_paths,
                    num_iter=num_iter,
                )

                # Sorted so the pick is deterministic if several SVGs are written.
                svg_files = sorted(output_dir.glob("*.svg"))
                if not svg_files:
                    raise FileNotFoundError("No SVG file generated")
                return svg_files[0].read_text(encoding='utf-8')

        except Exception as e:
            print(f"DiffSketcher generation failed: {e}")
            return self._generate_fallback_svg(prompt, width, height, seed=seed)

    def _generate_fallback_svg(self, prompt, width, height, seed=None):
        """Generate simple SVG when model fails.

        Draws a car or house outline when the prompt mentions one, otherwise
        five pseudo-random shapes, on a white background.

        Args:
            prompt (str): Text prompt, matched case-insensitively against keywords.
            width (int): Canvas width in pixels.
            height (int): Canvas height in pixels.
            seed (int | None): Optional extra seed combined with the prompt.

        Returns:
            str: A complete standalone SVG document.
        """
        # Stable seed: the built-in hash() of a str is randomized per process,
        # so it cannot give reproducible sketches across restarts. A private
        # Random instance avoids reseeding the global `random` module.
        key = prompt if seed is None else f"{seed}:{prompt}"
        rng = random.Random(zlib.crc32(key.encode('utf-8')))

        svg_parts = [
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" '
            f'height="{height}" viewBox="0 0 {width} {height}">',
            f'<rect width="{width}" height="{height}" fill="white"/>',
        ]

        # Generate sketch based on prompt keywords
        prompt_lower = prompt.lower()
        cx, cy = width / 2, height / 2
        unit = min(width, height)  # shapes scale with the canvas

        if any(word in prompt_lower for word in ('car', 'vehicle', 'automobile')):
            # Simple car sketch: body, cabin, two wheels
            body_w, body_h = unit * 0.6, unit * 0.18
            left, top = cx - body_w / 2, cy - body_h / 2
            wheel_y = top + body_h
            wheel_r = unit * 0.07
            svg_parts.extend([
                f'<rect x="{left:.1f}" y="{top:.1f}" width="{body_w:.1f}" '
                f'height="{body_h:.1f}" rx="{unit * 0.03:.1f}" {_STROKE}/>',
                f'<path d="M {left + body_w * 0.2:.1f} {top:.1f} '
                f'L {left + body_w * 0.32:.1f} {top - body_h:.1f} '
                f'L {left + body_w * 0.68:.1f} {top - body_h:.1f} '
                f'L {left + body_w * 0.8:.1f} {top:.1f}" {_STROKE}/>',
                f'<circle cx="{left + body_w * 0.22:.1f}" cy="{wheel_y:.1f}" '
                f'r="{wheel_r:.1f}" {_STROKE}/>',
                f'<circle cx="{left + body_w * 0.78:.1f}" cy="{wheel_y:.1f}" '
                f'r="{wheel_r:.1f}" {_STROKE}/>',
            ])
        elif any(word in prompt_lower for word in ('house', 'building', 'home')):
            # Simple house sketch: walls, roof, door, two windows
            house_w, house_h = unit * 0.5, unit * 0.35
            left, top = cx - house_w / 2, cy - house_h / 4
            bottom = top + house_h
            eave = unit * 0.04
            win_w, win_h = house_w * 0.2, house_h * 0.25
            win_y = top + house_h * 0.2
            svg_parts.extend([
                f'<rect x="{left:.1f}" y="{top:.1f}" width="{house_w:.1f}" '
                f'height="{house_h:.1f}" {_STROKE}/>',
                f'<polygon points="{left - eave:.1f},{top:.1f} '
                f'{cx:.1f},{top - house_h * 0.7:.1f} '
                f'{left + house_w + eave:.1f},{top:.1f}" {_STROKE}/>',
                f'<rect x="{cx - house_w * 0.1:.1f}" y="{bottom - house_h * 0.5:.1f}" '
                f'width="{house_w * 0.2:.1f}" height="{house_h * 0.5:.1f}" {_STROKE}/>',
                f'<rect x="{left + house_w * 0.1:.1f}" y="{win_y:.1f}" '
                f'width="{win_w:.1f}" height="{win_h:.1f}" {_STROKE}/>',
                f'<rect x="{left + house_w * 0.7:.1f}" y="{win_y:.1f}" '
                f'width="{win_w:.1f}" height="{win_h:.1f}" {_STROKE}/>',
            ])
        else:
            # Abstract sketch: five shapes cycling circle / square / triangle.
            # Margins shrink on tiny canvases so randint() always has a valid range.
            margin_x = min(20, width // 2)
            margin_y = min(20, height // 2)
            for i in range(5):
                x = rng.randint(margin_x, width - margin_x)
                y = rng.randint(margin_y, height - margin_y)
                size = rng.randint(10, 30)
                if i % 3 == 0:
                    svg_parts.append(f'<circle cx="{x}" cy="{y}" r="{size}" {_STROKE}/>')
                elif i % 3 == 1:
                    svg_parts.append(
                        f'<rect x="{x - size}" y="{y - size}" width="{2 * size}" '
                        f'height="{2 * size}" {_STROKE}/>'
                    )
                else:
                    points = " ".join(
                        f"{x + size * math.cos(math.radians(120 * j)):.1f},"
                        f"{y + size * math.sin(math.radians(120 * j)):.1f}"
                        for j in range(3)
                    )
                    svg_parts.append(f'<polygon points="{points}" {_STROKE}/>')

        svg_parts.append('</svg>')
        return '\n'.join(svg_parts)

    def _svg_to_image(self, svg_content, width=224, height=224):
        """Convert SVG to PIL Image.

        The render is composited onto white before dropping alpha. Areas the
        SVG leaves transparent would otherwise become black under a plain
        ``convert('RGB')``.

        Returns:
            PIL.Image.Image: RGB image of ``width`` x ``height``, or an error
            image if rendering fails.
        """
        try:
            # Convert SVG to PNG using cairosvg
            png_data = cairosvg.svg2png(
                bytestring=svg_content.encode('utf-8'),
                output_width=width,
                output_height=height,
            )
            with Image.open(io.BytesIO(png_data)) as rendered:
                rgba = rendered.convert('RGBA')
            background = Image.new('RGBA', rgba.size, 'white')
            return Image.alpha_composite(background, rgba).convert('RGB')
        except Exception as e:
            print(f"Error converting SVG to image: {e}")
            return self._create_error_image("SVG conversion failed", width, height)

    def _create_error_image(self, message, width=224, height=224):
        """Create a white error image with ``message`` drawn on it.

        Drawing the text is best-effort. If it fails (for example, on characters
        the default font cannot render), a blank white image is returned.
        """
        image = Image.new('RGB', (width, height), 'white')
        try:
            draw = ImageDraw.Draw(image)
            # Rough line length, assuming the default font is about 6 px per glyph.
            wrapped = textwrap.fill(message, width=max(1, (width - 10) // 6))
            draw.multiline_text((5, 5), wrapped, fill='black')
        except Exception as e:
            print(f"Could not draw error message: {e}")
        return image