"""Hugging Face Inference Endpoint handler that produces rule-based SVG sketches.

The handler inspects keywords in the text prompt(s) and draws simple line-art
shapes with ``svgwrite``.
"""

import base64
import io
import json
import math
import os
import random
import re
import sys
import zlib
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import svgwrite
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer

# Canvas size and RNG seed used when the request does not specify them.
_DEFAULT_SIZE = 224
_DEFAULT_SEED = 42

# Matches "(word:weight)" and "[word:weight]" attention notations.
_WEIGHT_PATTERN = re.compile(r'[\(\[]([^:\)\]]+):([0-9\.]+)[\)\]]')


class EndpointHandler:
    """Callable request handler that returns SVG sketches for text prompts.

    Supported ``edit_type`` values are ``"replace"``, ``"refine"``,
    ``"reweight"`` and ``"generate"``; anything else falls back to refinement.
    """

    def __init__(self, path=""):
        """Initialize DiffSketchEdit handler for Hugging Face Inference API.

        Loads a Stable Diffusion pipeline and a CLIP tokenizer/text encoder.
        If either load fails, the corresponding attributes are set to ``None``
        and the error is printed instead of raised.

        NOTE(review): ``path`` is not read, and none of the methods in this
        file use ``self.pipe``, ``self.tokenizer`` or ``self.text_encoder``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Initialize Stable Diffusion pipeline
        try:
            self.pipe = StableDiffusionPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                safety_checker=None,
                requires_safety_checker=False
            )
            self.pipe = self.pipe.to(self.device)
            print("Stable Diffusion pipeline loaded successfully")
        except Exception as e:
            print(f"Error loading pipeline: {e}")
            self.pipe = None

        # Initialize tokenizer and text encoder
        try:
            self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
            self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
            self.text_encoder = self.text_encoder.to(self.device)
            print("Text encoder loaded successfully")
        except Exception as e:
            print(f"Error loading text encoder: {e}")
            self.tokenizer = None
            self.text_encoder = None

    def __call__(self, data):
        """Edit vector sketches based on text prompts.

        Args:
            data: Request mapping. ``data["inputs"]`` is either a prompt
                string or a dict with ``prompts`` (or ``prompt``),
                ``edit_type`` and ``input_svg``. ``data["parameters"]`` may
                carry ``width``, ``height`` and ``seed``, and, for string
                inputs, ``edit_type`` and ``input_svg``.

        Returns:
            The result dict of the selected edit method. If anything raises,
            a fallback dict with keys ``svg``, ``svg_base64``, ``edit_type``
            and ``error`` is returned instead.
        """
        # Bind defaults before the try block so the except handler can always
        # build a fallback response, even when parsing fails on the first line.
        prompts: List[str] = []
        edit_type = "refine"
        width = height = _DEFAULT_SIZE

        try:
            # Extract inputs
            inputs = data.get("inputs", "")
            parameters = data.get("parameters") or {}

            # Handle different input formats
            if isinstance(inputs, dict):
                raw_prompts = inputs.get("prompts") or []
                if isinstance(raw_prompts, str):
                    # A bare string would otherwise be indexed per character.
                    raw_prompts = [raw_prompts]
                if not raw_prompts and "prompt" in inputs:
                    raw_prompts = [inputs["prompt"]]
                edit_type = inputs.get("edit_type", "refine")
                input_svg = inputs.get("input_svg", None)
            else:
                # Simple string input
                raw_prompts = [str(inputs)]
                edit_type = parameters.get("edit_type", "refine")
                input_svg = parameters.get("input_svg", None)

            prompts = [str(p) for p in raw_prompts] or ["a simple sketch"]

            # Extract parameters; only rebind width/height once both are valid
            # so the fallback path never sees a half-validated size.
            req_width = int(parameters.get("width", _DEFAULT_SIZE))
            req_height = int(parameters.get("height", _DEFAULT_SIZE))
            if req_width <= 0 or req_height <= 0:
                raise ValueError("width and height must be positive integers")
            width, height = req_width, req_height
            seed = parameters.get("seed", _DEFAULT_SEED)

            # Set seed for reproducibility
            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)

            print(f"Processing edit type: '{edit_type}' with prompts: {prompts}")

            # Process based on edit type
            if edit_type == "replace" and len(prompts) >= 2:
                result = self.word_replacement_edit(prompts[0], prompts[1], width, height, input_svg)
            elif edit_type == "refine":
                result = self.prompt_refinement_edit(prompts[0], width, height, input_svg)
            elif edit_type == "reweight":
                result = self.attention_reweighting_edit(prompts[0], width, height, input_svg)
            elif edit_type == "generate":
                result = self.simple_generation(prompts[0], width, height)
            else:
                # Default to refinement
                result = self.prompt_refinement_edit(prompts[0], width, height, input_svg)

            return result

        except Exception as e:
            print(f"Error in handler: {e}")
            # Return fallback result
            fallback_svg = self.create_fallback_svg(prompts[0] if prompts else "error", width, height)
            return {
                "svg": fallback_svg,
                "svg_base64": self._encode_svg(fallback_svg),
                "edit_type": edit_type,
                "error": str(e)
            }

    @staticmethod
    def _encode_svg(svg: str) -> str:
        """Return the base64 text encoding of a UTF-8 SVG string."""
        return base64.b64encode(svg.encode('utf-8')).decode('utf-8')

    @staticmethod
    def _new_canvas(width, height):
        """Create a drawing of the given size with a white background rect."""
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
        return dwg

    @staticmethod
    def _mentions(text: str, keywords) -> bool:
        """Return True if any keyword occurs in ``text`` as a substring."""
        return any(keyword in text for keyword in keywords)

    def word_replacement_edit(self, source_prompt: str, target_prompt: str, width: int, height: int,
                              input_svg: Optional[str] = None):
        """Perform word replacement editing.

        Computes the whitespace-split, lower-cased words added to and removed
        from ``source_prompt`` to reach ``target_prompt`` and renders an SVG
        for the target.

        Returns:
            Dict with ``svg``, ``svg_base64``, ``edit_type``,
            ``source_prompt``, ``target_prompt``, ``added_words`` and
            ``removed_words`` (both sorted lists); on failure, the dict from
            ``create_error_result``.
        """
        try:
            print(f"Word replacement: '{source_prompt}' -> '{target_prompt}'")

            # Analyze the difference between prompts
            source_words = set(source_prompt.lower().split())
            target_words = set(target_prompt.lower().split())

            added_words = target_words - source_words
            removed_words = source_words - target_words

            print(f"Added words: {added_words}, Removed words: {removed_words}")

            # Generate base SVG from source prompt
            if input_svg:
                base_svg = input_svg
            else:
                base_svg = self.generate_base_svg(source_prompt, width, height)

            # Apply word replacement transformations
            edited_svg = self.apply_word_replacement(base_svg, source_prompt, target_prompt,
                                                     added_words, removed_words, width, height)

            return {
                "svg": edited_svg,
                "svg_base64": self._encode_svg(edited_svg),
                "edit_type": "replace",
                "source_prompt": source_prompt,
                "target_prompt": target_prompt,
                # Sorted so the response does not depend on set iteration order.
                "added_words": sorted(added_words),
                "removed_words": sorted(removed_words)
            }

        except Exception as e:
            print(f"Error in word_replacement_edit: {e}")
            return self.create_error_result(source_prompt, "replace", str(e), width, height)

    def prompt_refinement_edit(self, prompt: str, width: int, height: int, input_svg: Optional[str] = None):
        """Perform prompt refinement editing.

        Returns:
            Dict with ``svg``, ``svg_base64``, ``edit_type`` and ``prompt``;
            on failure, the dict from ``create_error_result``.
        """
        try:
            print(f"Prompt refinement for: '{prompt}'")

            # Generate or use base SVG
            if input_svg:
                base_svg = input_svg
            else:
                base_svg = self.generate_base_svg(prompt, width, height)

            # Apply refinement based on prompt analysis
            refined_svg = self.apply_refinement(base_svg, prompt, width, height)

            return {
                "svg": refined_svg,
                "svg_base64": self._encode_svg(refined_svg),
                "edit_type": "refine",
                "prompt": prompt
            }

        except Exception as e:
            print(f"Error in prompt_refinement_edit: {e}")
            return self.create_error_result(prompt, "refine", str(e), width, height)

    def attention_reweighting_edit(self, prompt: str, width: int, height: int, input_svg: Optional[str] = None):
        """Perform attention reweighting editing.

        Weight notations such as ``(cat:1.5)`` or ``[dog:0.8]`` are parsed out
        of ``prompt`` before rendering.

        Returns:
            Dict with ``svg``, ``svg_base64``, ``edit_type``, ``prompt``,
            ``weighted_prompt`` and ``attention_weights``; on failure, the
            dict from ``create_error_result``.
        """
        try:
            print(f"Attention reweighting for: '{prompt}'")

            # Parse attention weights from prompt (e.g., "(cat:1.5)" or "[dog:0.8]")
            weighted_prompt, attention_weights = self.parse_attention_weights(prompt)

            # Generate or use base SVG
            if input_svg:
                base_svg = input_svg
            else:
                base_svg = self.generate_base_svg(weighted_prompt, width, height)

            # Apply attention reweighting
            reweighted_svg = self.apply_attention_reweighting(base_svg, weighted_prompt,
                                                              attention_weights, width, height)

            return {
                "svg": reweighted_svg,
                "svg_base64": self._encode_svg(reweighted_svg),
                "edit_type": "reweight",
                "prompt": prompt,
                "weighted_prompt": weighted_prompt,
                "attention_weights": attention_weights
            }

        except Exception as e:
            print(f"Error in attention_reweighting_edit: {e}")
            return self.create_error_result(prompt, "reweight", str(e), width, height)

    def simple_generation(self, prompt: str, width: int, height: int):
        """Perform simple SVG generation.

        Returns:
            Dict with ``svg``, ``svg_base64``, ``edit_type`` and ``prompt``;
            on failure, the dict from ``create_error_result``.
        """
        try:
            print(f"Simple generation for: '{prompt}'")

            svg_content = self.generate_base_svg(prompt, width, height)

            return {
                "svg": svg_content,
                "svg_base64": self._encode_svg(svg_content),
                "edit_type": "generate",
                "prompt": prompt
            }

        except Exception as e:
            print(f"Error in simple_generation: {e}")
            return self.create_error_result(prompt, "generate", str(e), width, height)

    def generate_base_svg(self, prompt: str, width: int, height: int):
        """Generate base SVG from prompt.

        The first keyword group found in the lower-cased prompt selects the
        drawing; matching is by substring, so e.g. ``"scary"`` matches
        ``"car"``. Prompts matching no group get abstract circles.
        """
        dwg = self._new_canvas(width, height)

        # Analyze prompt to determine content
        prompt_lower = prompt.lower()

        if self._mentions(prompt_lower, ('house', 'building', 'home')):
            self._add_house_elements(dwg, width, height)
        elif self._mentions(prompt_lower, ('tree', 'forest', 'nature')):
            self._add_tree_elements(dwg, width, height)
        elif self._mentions(prompt_lower, ('car', 'vehicle', 'transport')):
            self._add_car_elements(dwg, width, height)
        elif self._mentions(prompt_lower, ('face', 'person', 'portrait')):
            self._add_face_elements(dwg, width, height)
        elif self._mentions(prompt_lower, ('flower', 'plant', 'garden')):
            self._add_flower_elements(dwg, width, height)
        elif self._mentions(prompt_lower, ('cat', 'dog', 'animal')):
            self._add_animal_elements(dwg, width, height, prompt_lower)
        else:
            self._add_abstract_elements(dwg, width, height, prompt)

        return dwg.tostring()

    def apply_word_replacement(self, base_svg: str, source_prompt: str, target_prompt: str,
                               added_words: set, removed_words: set, width: int, height: int):
        """Apply word replacement transformations to SVG.

        NOTE: ``base_svg``, ``source_prompt`` and ``removed_words`` are
        accepted but not read; the result is drawn on a fresh canvas from
        ``added_words`` and ``target_prompt`` only.
        """
        dwg = self._new_canvas(width, height)

        # Objects already drawn for an added word, so the target-prompt pass
        # below does not draw the same object a second time.
        drawn_objects = set()

        # Sorted so element order is stable regardless of set iteration order.
        for word in sorted(added_words):
            if word in ('red', 'blue', 'green', 'yellow', 'purple'):
                self._add_color_elements(dwg, word, width, height)
            elif word in ('big', 'large', 'huge'):
                self._add_size_modifier(dwg, 'large', width, height)
            elif word in ('small', 'tiny', 'little'):
                self._add_size_modifier(dwg, 'small', width, height)
            elif word in ('cat', 'dog', 'bird'):
                self._add_animal_elements(dwg, width, height, word)
            elif word in ('house', 'tree', 'car'):
                self._add_object_elements(dwg, word, width, height)
                drawn_objects.add(word)

        # Apply transformations based on target prompt
        target_lower = target_prompt.lower()
        if self._mentions(target_lower, ('house', 'building')):
            target_object = 'house'
        elif self._mentions(target_lower, ('tree', 'forest')):
            target_object = 'tree'
        elif self._mentions(target_lower, ('car', 'vehicle')):
            target_object = 'car'
        else:
            target_object = None

        if target_object is not None and target_object not in drawn_objects:
            self._add_object_elements(dwg, target_object, width, height)

        return dwg.tostring()

    def apply_refinement(self, base_svg: str, prompt: str, width: int, height: int):
        """Apply refinement to existing SVG.

        NOTE: ``base_svg`` is accepted but not read; the result is drawn on a
        fresh canvas according to keywords in ``prompt``.
        """
        dwg = self._new_canvas(width, height)

        prompt_lower = prompt.lower()

        # Add refined details based on prompt
        if 'detailed' in prompt_lower or 'complex' in prompt_lower:
            self._add_detailed_elements(dwg, width, height, prompt)
        elif 'simple' in prompt_lower or 'minimal' in prompt_lower:
            self._add_simple_elements(dwg, width, height, prompt)
        else:
            # Default refinement
            self._add_standard_elements(dwg, width, height, prompt)

        return dwg.tostring()

    def apply_attention_reweighting(self, base_svg: str, prompt: str, attention_weights: dict,
                                    width: int, height: int):
        """Apply attention reweighting to SVG elements.

        Words weighted above 1.0 are emphasized, below 1.0 de-emphasized, and
        exactly 1.0 ignored; the standard elements for ``prompt`` are drawn
        afterwards. ``base_svg`` is accepted but not read.
        """
        dwg = self._new_canvas(width, height)

        # Apply weighted emphasis to different elements
        for word, weight in attention_weights.items():
            if weight > 1.0:
                # Emphasize this element
                self._emphasize_element(dwg, word, weight, width, height)
            elif weight < 1.0:
                # De-emphasize this element
                self._deemphasize_element(dwg, word, weight, width, height)

        # Add base elements
        self._add_standard_elements(dwg, width, height, prompt)

        return dwg.tostring()

    def parse_attention_weights(self, prompt: str) -> Tuple[str, dict]:
        """Parse attention weights from prompt.

        Recognizes ``(word:weight)`` and ``[word:weight]``. Each notation
        whose weight parses as a float is replaced by the bare word and
        recorded; notations with an unparsable weight (e.g. ``1.2.3``) are
        left in the prompt untouched.

        Returns:
            ``(clean_prompt, attention_weights)`` where ``clean_prompt`` is
            stripped of surrounding whitespace and ``attention_weights`` maps
            each stripped word to its float weight.
        """
        attention_weights = {}

        def _strip_notation(match):
            word, weight_str = match.group(1), match.group(2)
            try:
                weight = float(weight_str)
            except ValueError:
                return match.group(0)
            attention_weights[word.strip()] = weight
            # Returned from a callback, so backslashes in the word are taken
            # literally instead of being parsed as a replacement template.
            return word

        clean_prompt = _WEIGHT_PATTERN.sub(_strip_notation, prompt)
        return clean_prompt.strip(), attention_weights

    def _add_house_elements(self, dwg, width, height):
        """Add house elements (base, roof, door) to SVG."""
        house_width = width * 0.6
        house_height = height * 0.4
        house_x = (width - house_width) / 2
        house_y = height * 0.4

        # House base
        dwg.add(dwg.rect(
            insert=(house_x, house_y),
            size=(house_width, house_height),
            fill='none',
            stroke='black',
            stroke_width=2
        ))

        # Roof
        roof_points = [
            (house_x, house_y),
            (house_x + house_width/2, house_y - house_height*0.3),
            (house_x + house_width, house_y)
        ]
        dwg.add(dwg.polygon(roof_points, fill='none', stroke='black', stroke_width=2))

        # Door, horizontally centered and resting on the bottom edge
        door_width = house_width * 0.2
        door_height = house_height * 0.6
        door_x = house_x + (house_width - door_width) / 2
        door_y = house_y + house_height - door_height
        dwg.add(dwg.rect(
            insert=(door_x, door_y),
            size=(door_width, door_height),
            fill='none',
            stroke='black',
            stroke_width=2
        ))

    def _add_tree_elements(self, dwg, width, height):
        """Add tree elements (trunk, crown) to SVG."""
        center_x = width / 2
        center_y = height / 2

        # Trunk
        trunk_width = 12
        trunk_height = height * 0.3
        dwg.add(dwg.rect(
            insert=(center_x - trunk_width/2, center_y + 20),
            size=(trunk_width, trunk_height),
            fill='none',
            stroke='black',
            stroke_width=2
        ))

        # Crown
        crown_radius = width * 0.25
        dwg.add(dwg.circle(
            center=(center_x, center_y),
            r=crown_radius,
            fill='none',
            stroke='black',
            stroke_width=2
        ))

    def _add_car_elements(self, dwg, width, height):
        """Add car elements (body, two wheels) to SVG."""
        car_width = width * 0.7
        car_height = height * 0.3
        car_x = (width - car_width) / 2
        car_y = (height - car_height) / 2

        # Car body
        dwg.add(dwg.rect(
            insert=(car_x, car_y),
            size=(car_width, car_height),
            fill='none',
            stroke='black',
            stroke_width=2,
            rx=5
        ))

        # Wheels
        wheel_radius = car_height * 0.4
        wheel_y = car_y + car_height - wheel_radius/2
        for wheel_fraction in (0.2, 0.8):
            dwg.add(dwg.circle(
                center=(car_x + car_width * wheel_fraction, wheel_y),
                r=wheel_radius,
                fill='none',
                stroke='black',
                stroke_width=2
            ))

    def _add_face_elements(self, dwg, width, height):
        """Add face elements (outline, eyes, mouth) to SVG."""
        center_x = width / 2
        center_y = height / 2
        face_radius = min(width, height) * 0.3

        # Face outline
        dwg.add(dwg.circle(
            center=(center_x, center_y),
            r=face_radius,
            fill='none',
            stroke='black',
            stroke_width=2
        ))

        # Eyes
        eye_offset = face_radius * 0.3
        eye_radius = face_radius * 0.1
        dwg.add(dwg.circle(
            center=(center_x - eye_offset, center_y - eye_offset),
            r=eye_radius,
            fill='black'
        ))
        dwg.add(dwg.circle(
            center=(center_x + eye_offset, center_y - eye_offset),
            r=eye_radius,
            fill='black'
        ))

        # Mouth: quadratic curve bulging downwards
        mouth_y = center_y + face_radius * 0.3
        dwg.add(dwg.path(
            d=f"M {center_x - face_radius*0.3},{mouth_y} Q {center_x},{mouth_y + face_radius*0.2} {center_x + face_radius*0.3},{mouth_y}",
            fill='none',
            stroke='black',
            stroke_width=2
        ))

    def _add_flower_elements(self, dwg, width, height):
        """Add flower elements (stem, eight petals, center) to SVG."""
        center_x = width / 2
        center_y = height / 2

        # Stem
        dwg.add(dwg.line(
            start=(center_x, center_y + 20),
            end=(center_x, height - 20),
            stroke='green',
            stroke_width=4
        ))

        # Petals, one every 45 degrees on a radius-25 ring
        petal_radius = 15
        for angle in range(0, 360, 45):
            x = center_x + 25 * math.cos(math.radians(angle))
            y = center_y + 25 * math.sin(math.radians(angle))
            dwg.add(dwg.circle(
                center=(x, y),
                r=petal_radius,
                fill='none',
                stroke='red',
                stroke_width=2
            ))

        # Center
        dwg.add(dwg.circle(
            center=(center_x, center_y),
            r=8,
            fill='yellow',
            stroke='orange',
            stroke_width=2
        ))

    def _add_animal_elements(self, dwg, width, height, animal_type):
        """Add animal elements to SVG.

        Draws a cat if ``"cat"`` occurs in ``animal_type``, otherwise a dog
        if ``"dog"`` occurs in it.

        NOTE(review): any other value (callers in this file also pass
        ``"bird"`` and prompts that only mention ``"animal"``) draws nothing.
        """
        center_x = width / 2
        center_y = height / 2

        if 'cat' in animal_type:
            # Cat body
            dwg.add(dwg.ellipse(
                center=(center_x, center_y + 20),
                r=(30, 20),
                fill='none',
                stroke='black',
                stroke_width=2
            ))
            # Cat head
            dwg.add(dwg.circle(
                center=(center_x, center_y - 20),
                r=25,
                fill='none',
                stroke='black',
                stroke_width=2
            ))
            # Cat ears
            ear_points1 = [(center_x - 15, center_y - 35), (center_x - 5, center_y - 50), (center_x + 5, center_y - 35)]
            ear_points2 = [(center_x - 5, center_y - 35), (center_x + 5, center_y - 50), (center_x + 15, center_y - 35)]
            dwg.add(dwg.polygon(ear_points1, fill='none', stroke='black', stroke_width=2))
            dwg.add(dwg.polygon(ear_points2, fill='none', stroke='black', stroke_width=2))

        elif 'dog' in animal_type:
            # Dog body
            dwg.add(dwg.ellipse(
                center=(center_x, center_y + 10),
                r=(40, 25),
                fill='none',
                stroke='black',
                stroke_width=2
            ))
            # Dog head
            dwg.add(dwg.ellipse(
                center=(center_x, center_y - 25),
                r=(25, 20),
                fill='none',
                stroke='black',
                stroke_width=2
            ))

    def _add_color_elements(self, dwg, color, width, height):
        """Add a colored accent circle; unknown colors fall back to black."""
        color_map = {
            'red': '#FF0000',
            'blue': '#0000FF',
            'green': '#00FF00',
            'yellow': '#FFFF00',
            'purple': '#800080'
        }

        fill_color = color_map.get(color, '#000000')

        # Add a colored accent element
        dwg.add(dwg.circle(
            center=(width * 0.8, height * 0.2),
            r=15,
            fill=fill_color,
            stroke='black',
            stroke_width=1
        ))

    def _add_size_modifier(self, dwg, size_type, width, height):
        """Add a dashed frame indicating a ``'large'`` or ``'small'`` modifier.

        Other ``size_type`` values draw nothing.
        """
        if size_type == 'large':
            # Add larger elements
            dwg.add(dwg.rect(
                insert=(10, 10),
                size=(width-20, height-20),
                fill='none',
                stroke='gray',
                stroke_width=3,
                stroke_dasharray='5,5'
            ))
        elif size_type == 'small':
            # Add smaller elements
            dwg.add(dwg.rect(
                insert=(width*0.3, height*0.3),
                size=(width*0.4, height*0.4),
                fill='none',
                stroke='gray',
                stroke_width=1,
                stroke_dasharray='2,2'
            ))

    def _add_object_elements(self, dwg, obj_type, width, height):
        """Draw the ``'house'``, ``'tree'`` or ``'car'`` named by ``obj_type``.

        Other values draw nothing.
        """
        if obj_type == 'house':
            self._add_house_elements(dwg, width, height)
        elif obj_type == 'tree':
            self._add_tree_elements(dwg, width, height)
        elif obj_type == 'car':
            self._add_car_elements(dwg, width, height)

    def _add_detailed_elements(self, dwg, width, height, prompt):
        """Add detailed elements for complex prompts.

        Draws eight randomly placed circles, squares or triangles using the
        module-level ``random`` state. ``prompt`` is not read.
        """
        # Clamp the upper bounds so small canvases do not hand randint an
        # empty range (it raises ValueError when high < low).
        max_x = max(20, int(width) - 40)
        max_y = max(20, int(height) - 40)

        # Add multiple overlapping shapes for complexity
        for _ in range(8):
            x = random.randint(20, max_x)
            y = random.randint(20, max_y)
            size = random.randint(10, 30)

            shape_type = random.choice(['circle', 'rect', 'polygon'])

            if shape_type == 'circle':
                dwg.add(dwg.circle(
                    center=(x, y),
                    r=size,
                    fill='none',
                    stroke='black',
                    stroke_width=1,
                    opacity=0.7
                ))
            elif shape_type == 'rect':
                dwg.add(dwg.rect(
                    insert=(x-size, y-size),
                    size=(size*2, size*2),
                    fill='none',
                    stroke='black',
                    stroke_width=1,
                    opacity=0.7
                ))
            else:
                # 'polygon': an upright triangle in the same bounding box as
                # the square above.
                dwg.add(dwg.polygon(
                    [(x, y-size), (x-size, y+size), (x+size, y+size)],
                    fill='none',
                    stroke='black',
                    stroke_width=1,
                    opacity=0.7
                ))

    def _add_simple_elements(self, dwg, width, height, prompt):
        """Add simple elements for minimal prompts: one centered circle.

        ``prompt`` is not read.
        """
        # Add just a few basic shapes
        center_x = width / 2
        center_y = height / 2

        dwg.add(dwg.circle(
            center=(center_x, center_y),
            r=min(width, height) * 0.2,
            fill='none',
            stroke='black',
            stroke_width=2
        ))

    def _add_standard_elements(self, dwg, width, height, prompt):
        """Add standard elements based on prompt.

        Draws a house, tree or car for the first matching keyword group
        (substring match), otherwise abstract circles.
        """
        prompt_lower = prompt.lower()

        if self._mentions(prompt_lower, ('house', 'building')):
            self._add_house_elements(dwg, width, height)
        elif self._mentions(prompt_lower, ('tree', 'forest')):
            self._add_tree_elements(dwg, width, height)
        elif self._mentions(prompt_lower, ('car', 'vehicle')):
            self._add_car_elements(dwg, width, height)
        else:
            self._add_abstract_elements(dwg, width, height, prompt)

    def _add_abstract_elements(self, dwg, width, height, prompt):
        """Add abstract elements based on prompt.

        Draws five circles whose positions are derived from a checksum of the
        prompt text.
        """
        # crc32 is stable across interpreter runs; the built-in hash() of a
        # str is salted per process, which made the layout unreproducible.
        prompt_hash = zlib.crc32(prompt.encode('utf-8')) % 100

        # Guard the modulus so canvases of 40px or less do not divide by zero.
        span_x = max(1, width - 40)
        span_y = max(1, height - 40)

        for i in range(5):
            x = (i * 40 + prompt_hash) % span_x + 20
            y = (i * 35 + prompt_hash) % span_y + 20
            size = 15 + (i * 5) % 20

            dwg.add(dwg.circle(
                center=(x, y),
                r=size,
                fill='none',
                stroke='black',
                stroke_width=2,
                opacity=0.8
            ))

    def _emphasize_element(self, dwg, word, weight, width, height):
        """Emphasize an element based on attention weight.

        Only ``'house'`` and ``'building'`` are handled: a red rectangle
        scaled by ``weight`` is drawn. Other words draw nothing.
        """
        # Make elements larger and more prominent
        scale_factor = weight
        stroke_width = int(2 * scale_factor)

        if word in ('house', 'building'):
            # Emphasized house
            house_size = min(width, height) * 0.4 * scale_factor
            house_x = (width - house_size) / 2
            house_y = (height - house_size) / 2

            dwg.add(dwg.rect(
                insert=(house_x, house_y),
                size=(house_size, house_size * 0.8),
                fill='none',
                stroke='red',
                stroke_width=stroke_width
            ))

    def _deemphasize_element(self, dwg, word, weight, width, height):
        """De-emphasize an element based on attention weight.

        Only ``'background'`` and ``'sky'`` are handled: a light-gray band
        over the top 30% of the canvas is drawn with ``weight`` as opacity.
        Other words draw nothing.
        """
        # Make elements smaller and less prominent
        scale_factor = weight
        stroke_width = max(1, int(2 * scale_factor))

        if word in ('background', 'sky'):
            # De-emphasized background elements
            dwg.add(dwg.rect(
                insert=(0, 0),
                size=(width, height * 0.3),
                fill='none',
                stroke='lightgray',
                stroke_width=stroke_width,
                opacity=scale_factor
            ))

    def create_error_result(self, prompt: str, edit_type: str, error: str, width: int, height: int):
        """Create error result with fallback SVG.

        Returns:
            Dict with ``svg``, ``svg_base64``, ``edit_type``, ``prompt`` and
            ``error``.
        """
        fallback_svg = self.create_fallback_svg(prompt, width, height)
        return {
            "svg": fallback_svg,
            "svg_base64": self._encode_svg(fallback_svg),
            "edit_type": edit_type,
            "prompt": prompt,
            "error": error
        }

    def create_fallback_svg(self, prompt: str, width: int, height: int):
        """Create simple fallback SVG: a white canvas with centered text.

        The text shows the first 20 characters of ``prompt`` followed by
        an ellipsis.
        """
        dwg = self._new_canvas(width, height)

        # Simple centered text
        dwg.add(dwg.text(
            f"DiffSketchEdit\n{prompt[:20]}...",
            insert=(width/2, height/2),
            text_anchor="middle",
            font_size="14",
            fill="black"
        ))

        return dwg.tostring()