# diffsketchedit/handler.py (committed by jree423)
# Commit b0efdb8 (verified): "Fix DiffSketchEdit handler to properly implement
# text-based vector sketch editing"
# File size: 28.1 kB
import base64
import io
import json
import math
import os
import random
import re
import sys
import zlib
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import svgwrite
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for DiffSketchEdit-style requests.

    Every response is a rule-based SVG drawn with ``svgwrite``. Prompts are
    scanned for subject keywords (house, tree, car, face, flower, cat/dog/
    animal, ...) and matching outline shapes are drawn. Four edit modes are
    dispatched from ``__call__``: ``replace``, ``refine``, ``reweight`` and
    ``generate``.
    """

    # Matches "(word:1.5)" or "[word:0.8]" attention-weight annotations.
    _WEIGHT_PATTERN = re.compile(r'[\(\[]([^:\)\]]+):([0-9\.]+)[\)\]]')

    def __init__(self, path=""):
        """Load the diffusion pipeline and CLIP text encoder.

        Load failures are printed and the corresponding attributes are set to
        ``None`` rather than raised, so the handler can still start.

        NOTE(review): no method in this class reads ``self.pipe``,
        ``self.tokenizer`` or ``self.text_encoder``; all sketches come from
        keyword rules. The models are still loaded and kept as attributes in
        case external code relies on them.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Stable Diffusion pipeline (fp16 on GPU, fp32 on CPU).
        try:
            self.pipe = StableDiffusionPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                safety_checker=None,
                requires_safety_checker=False,
            )
            self.pipe = self.pipe.to(self.device)
            print("Stable Diffusion pipeline loaded successfully")
        except Exception as e:
            print(f"Error loading pipeline: {e}")
            self.pipe = None

        # CLIP tokenizer and text encoder.
        try:
            self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
            self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
            self.text_encoder = self.text_encoder.to(self.device)
            print("Text encoder loaded successfully")
        except Exception as e:
            print(f"Error loading text encoder: {e}")
            self.tokenizer = None
            self.text_encoder = None

    def __call__(self, data):
        """Edit (or generate) a vector sketch from text prompts.

        Accepted payloads:
          * ``{"inputs": "<prompt>", "parameters": {...}}``. Here
            ``edit_type`` and ``input_svg`` are read from ``parameters``.
          * ``{"inputs": {"prompts": [...] | "prompt": "...", "edit_type": ...,
            "input_svg": ...}, "parameters": {...}}``.
          * Any non-dict payload, which is treated as the ``inputs`` value.

        ``parameters`` may also carry ``width``/``height`` (default 224) and
        ``seed`` (default 42). ``replace`` needs two prompts (source, target).
        Unknown edit types, and ``replace`` with fewer than two prompts, fall
        back to ``refine``.

        Returns:
            A dict with at least ``svg``, ``svg_base64`` and ``edit_type``.
            On failure it also carries ``prompt`` and ``error``.
        """
        # Bind defaults before the try block so the error path can always build
        # a fallback response, even when the request itself is malformed.
        prompts = []
        edit_type = "refine"
        width, height = 224, 224
        try:
            if not isinstance(data, dict):
                data = {"inputs": data}
            inputs = data.get("inputs", "")
            parameters = data.get("parameters") or {}

            if isinstance(inputs, dict):
                prompts = inputs.get("prompts") or []
                if isinstance(prompts, str):
                    # A single string must not be iterated char by char.
                    prompts = [prompts]
                if not prompts and "prompt" in inputs:
                    prompts = [inputs["prompt"]]
                edit_type = inputs.get("edit_type", "refine")
                input_svg = inputs.get("input_svg", None)
            else:
                prompts = [str(inputs)]
                edit_type = parameters.get("edit_type", "refine")
                input_svg = parameters.get("input_svg", None)

            prompts = [str(p) for p in prompts] or ["a simple sketch"]

            # Query parameters often arrive as strings; normalise them.
            width = int(parameters.get("width", 224))
            height = int(parameters.get("height", 224))
            seed = int(parameters.get("seed", 42))

            # Seed every RNG the drawing helpers might touch.
            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)

            print(f"Processing edit type: '{edit_type}' with prompts: {prompts}")

            if edit_type == "replace" and len(prompts) >= 2:
                return self.word_replacement_edit(prompts[0], prompts[1], width, height, input_svg)
            if edit_type == "reweight":
                return self.attention_reweighting_edit(prompts[0], width, height, input_svg)
            if edit_type == "generate":
                return self.simple_generation(prompts[0], width, height)
            # "refine", unknown edit types, and "replace" with a single prompt.
            return self.prompt_refinement_edit(prompts[0], width, height, input_svg)
        except Exception as e:
            print(f"Error in handler: {e}")
            return self.create_error_result(
                prompts[0] if prompts else "error", edit_type, str(e), width, height
            )

    def word_replacement_edit(self, source_prompt: str, target_prompt: str, width: int, height: int, input_svg: Optional[str] = None):
        """Re-draw a sketch of *source_prompt* so that it depicts *target_prompt*.

        Prompts are compared as sets of lower-cased, whitespace-separated
        tokens. ``added_words``/``removed_words`` are returned sorted so the
        response is reproducible (set iteration order of strings differs
        between processes).
        """
        try:
            print(f"Word replacement: '{source_prompt}' -> '{target_prompt}'")
            source_words = set(source_prompt.lower().split())
            target_words = set(target_prompt.lower().split())
            added_words = target_words - source_words
            removed_words = source_words - target_words
            print(f"Added words: {sorted(added_words)}, Removed words: {sorted(removed_words)}")

            base_svg = input_svg if input_svg else self.generate_base_svg(source_prompt, width, height)
            edited_svg = self.apply_word_replacement(
                base_svg, source_prompt, target_prompt, added_words, removed_words, width, height
            )
            return self._result(
                edited_svg,
                "replace",
                source_prompt=source_prompt,
                target_prompt=target_prompt,
                added_words=sorted(added_words),
                removed_words=sorted(removed_words),
            )
        except Exception as e:
            print(f"Error in word_replacement_edit: {e}")
            return self.create_error_result(source_prompt, "replace", str(e), width, height)

    def prompt_refinement_edit(self, prompt: str, width: int, height: int, input_svg: Optional[str] = None):
        """Produce a refined sketch for *prompt* (detailed / simple / standard)."""
        try:
            print(f"Prompt refinement for: '{prompt}'")
            base_svg = input_svg if input_svg else self.generate_base_svg(prompt, width, height)
            refined_svg = self.apply_refinement(base_svg, prompt, width, height)
            return self._result(refined_svg, "refine", prompt=prompt)
        except Exception as e:
            print(f"Error in prompt_refinement_edit: {e}")
            return self.create_error_result(prompt, "refine", str(e), width, height)

    def attention_reweighting_edit(self, prompt: str, width: int, height: int, input_svg: Optional[str] = None):
        """Draw a sketch where "(word:w)" / "[word:w]" annotations scale emphasis."""
        try:
            print(f"Attention reweighting for: '{prompt}'")
            weighted_prompt, attention_weights = self.parse_attention_weights(prompt)
            base_svg = input_svg if input_svg else self.generate_base_svg(weighted_prompt, width, height)
            reweighted_svg = self.apply_attention_reweighting(
                base_svg, weighted_prompt, attention_weights, width, height
            )
            return self._result(
                reweighted_svg,
                "reweight",
                prompt=prompt,
                weighted_prompt=weighted_prompt,
                attention_weights=attention_weights,
            )
        except Exception as e:
            print(f"Error in attention_reweighting_edit: {e}")
            return self.create_error_result(prompt, "reweight", str(e), width, height)

    def simple_generation(self, prompt: str, width: int, height: int):
        """Generate a sketch for *prompt* without any editing step."""
        try:
            print(f"Simple generation for: '{prompt}'")
            svg_content = self.generate_base_svg(prompt, width, height)
            return self._result(svg_content, "generate", prompt=prompt)
        except Exception as e:
            print(f"Error in simple_generation: {e}")
            return self.create_error_result(prompt, "generate", str(e), width, height)

    def generate_base_svg(self, prompt: str, width: int, height: int):
        """Draw an outline sketch for the first subject keyword found in *prompt*.

        Categories are tried in a fixed priority order: house, tree, car, face,
        flower, animal. Prompts with no known subject get abstract circles.
        Keywords must appear as whole words, so "cartoon" does not trigger
        "car".
        """
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        prompt_lower = prompt.lower()
        if self._mentions(prompt_lower, ['house', 'building', 'home']):
            self._add_house_elements(dwg, width, height)
        elif self._mentions(prompt_lower, ['tree', 'forest', 'nature']):
            self._add_tree_elements(dwg, width, height)
        elif self._mentions(prompt_lower, ['car', 'vehicle', 'transport']):
            self._add_car_elements(dwg, width, height)
        elif self._mentions(prompt_lower, ['face', 'person', 'portrait']):
            self._add_face_elements(dwg, width, height)
        elif self._mentions(prompt_lower, ['flower', 'plant', 'garden']):
            self._add_flower_elements(dwg, width, height)
        elif self._mentions(prompt_lower, ['cat', 'dog', 'animal']):
            self._add_animal_elements(dwg, width, height, prompt_lower)
        else:
            self._add_abstract_elements(dwg, width, height, prompt)
        return dwg.tostring()

    def apply_word_replacement(self, base_svg: str, source_prompt: str, target_prompt: str, added_words: set, removed_words: set, width: int, height: int):
        """Draw a fresh sketch reflecting the words added by a replacement edit.

        Each added modifier word (colour, size, animal, object) contributes one
        element. Then the target prompt's main subject (house/tree/car) is
        drawn, unless an added word already drew it.

        NOTE(review): ``base_svg``, ``source_prompt`` and ``removed_words`` are
        not read, so a caller-supplied input SVG does not appear in the output.
        Confirm whether edits are meant to start from it.
        """
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        drawn_objects = set()
        # Sorted so element order (and therefore the SVG text) is stable across
        # processes; raw set iteration order of strings is not.
        for word in sorted(added_words):
            if word in ['red', 'blue', 'green', 'yellow', 'purple']:
                self._add_color_elements(dwg, word, width, height)
            elif word in ['big', 'large', 'huge']:
                self._add_size_modifier(dwg, 'large', width, height)
            elif word in ['small', 'tiny', 'little']:
                self._add_size_modifier(dwg, 'small', width, height)
            elif word in ['cat', 'dog', 'bird']:
                self._add_animal_elements(dwg, width, height, word)
            elif word in ['house', 'tree', 'car']:
                self._add_object_elements(dwg, word, width, height)
                drawn_objects.add(word)

        target_lower = target_prompt.lower()
        if self._mentions(target_lower, ['house', 'building']):
            subject = 'house'
        elif self._mentions(target_lower, ['tree', 'forest']):
            subject = 'tree'
        elif self._mentions(target_lower, ['car', 'vehicle']):
            subject = 'car'
        else:
            subject = None
        # Without this check an added "house"/"tree"/"car" was drawn twice as
        # identical stacked shapes.
        if subject is not None and subject not in drawn_objects:
            self._add_object_elements(dwg, subject, width, height)
        return dwg.tostring()

    def apply_refinement(self, base_svg: str, prompt: str, width: int, height: int):
        """Draw a refined sketch whose complexity follows the prompt's adjectives.

        "detailed"/"complex" gives random overlapping shapes, and
        "simple"/"minimal" gives a single circle. Any other prompt gets the
        standard subject sketch.

        NOTE(review): ``base_svg`` is not read; the refined sketch is drawn
        from scratch.
        """
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        prompt_lower = prompt.lower()
        # Substring tests on purpose: "minimalist" should still count as minimal.
        if 'detailed' in prompt_lower or 'complex' in prompt_lower:
            self._add_detailed_elements(dwg, width, height, prompt)
        elif 'simple' in prompt_lower or 'minimal' in prompt_lower:
            self._add_simple_elements(dwg, width, height, prompt)
        else:
            self._add_standard_elements(dwg, width, height, prompt)
        return dwg.tostring()

    def apply_attention_reweighting(self, base_svg: str, prompt: str, attention_weights: dict, width: int, height: int):
        """Draw emphasis (weight > 1) and de-emphasis (weight < 1) markers, then the base subject.

        NOTE(review): ``base_svg`` is not read; the sketch is drawn from
        scratch.
        """
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        for word, weight in attention_weights.items():
            if weight > 1.0:
                self._emphasize_element(dwg, word, weight, width, height)
            elif weight < 1.0:
                self._deemphasize_element(dwg, word, weight, width, height)

        self._add_standard_elements(dwg, width, height, prompt)
        return dwg.tostring()

    def parse_attention_weights(self, prompt: str) -> Tuple[str, dict]:
        """Strip "(word:weight)" / "[word:weight]" annotations from *prompt*.

        Returns:
            ``(clean_prompt, weights)``. In ``clean_prompt`` each annotation is
            replaced by its bare word. ``weights`` maps each stripped word to
            its float weight (the last occurrence wins). Annotations whose
            weight is not a valid float (e.g. "1.2.3") are left untouched.
        """
        attention_weights = {}

        def _replace(match):
            word, weight_str = match.group(1), match.group(2)
            try:
                weight = float(weight_str)
            except ValueError:
                return match.group(0)
            attention_weights[word.strip()] = weight
            # Returned from a callable, so backslashes in the word are not
            # interpreted as regex replacement escapes.
            return word

        clean_prompt = self._WEIGHT_PATTERN.sub(_replace, prompt)
        return clean_prompt.strip(), attention_weights

    @staticmethod
    def _mentions(text: str, keywords) -> bool:
        """Return True if any of *keywords* appears in *text* as a whole word.

        A plural "s"/"es" suffix is accepted ("houses" matches "house"). Keywords
        embedded in other words do not match ("cartoon" is not "car",
        "education" is not "cat").
        """
        return any(
            re.search(rf"\b{re.escape(keyword)}(?:s|es)?\b", text, flags=re.IGNORECASE)
            for keyword in keywords
        )

    @staticmethod
    def _result(svg: str, edit_type: str, **extra) -> Dict[str, Any]:
        """Build a response dict: ``svg``, ``svg_base64``, ``edit_type``, then *extra*."""
        return {
            "svg": svg,
            "svg_base64": base64.b64encode(svg.encode('utf-8')).decode('utf-8'),
            "edit_type": edit_type,
            **extra,
        }

    def _add_house_elements(self, dwg, width, height):
        """Draw a house outline: body rectangle, triangular roof and a door."""
        house_width = width * 0.6
        house_height = height * 0.4
        house_x = (width - house_width) / 2
        house_y = height * 0.4

        dwg.add(dwg.rect(
            insert=(house_x, house_y),
            size=(house_width, house_height),
            fill='none',
            stroke='black',
            stroke_width=2,
        ))

        roof_points = [
            (house_x, house_y),
            (house_x + house_width / 2, house_y - house_height * 0.3),
            (house_x + house_width, house_y),
        ]
        dwg.add(dwg.polygon(roof_points, fill='none', stroke='black', stroke_width=2))

        # Door sits centred on the bottom edge of the body.
        door_width = house_width * 0.2
        door_height = house_height * 0.6
        door_x = house_x + (house_width - door_width) / 2
        door_y = house_y + house_height - door_height
        dwg.add(dwg.rect(
            insert=(door_x, door_y),
            size=(door_width, door_height),
            fill='none',
            stroke='black',
            stroke_width=2,
        ))

    def _add_tree_elements(self, dwg, width, height):
        """Draw a tree outline: rectangular trunk and circular crown."""
        center_x = width / 2
        center_y = height / 2

        trunk_width = 12
        trunk_height = height * 0.3
        dwg.add(dwg.rect(
            insert=(center_x - trunk_width / 2, center_y + 20),
            size=(trunk_width, trunk_height),
            fill='none',
            stroke='black',
            stroke_width=2,
        ))

        crown_radius = width * 0.25
        dwg.add(dwg.circle(
            center=(center_x, center_y),
            r=crown_radius,
            fill='none',
            stroke='black',
            stroke_width=2,
        ))

    def _add_car_elements(self, dwg, width, height):
        """Draw a car outline: rounded body rectangle and two wheels."""
        car_width = width * 0.7
        car_height = height * 0.3
        car_x = (width - car_width) / 2
        car_y = (height - car_height) / 2

        dwg.add(dwg.rect(
            insert=(car_x, car_y),
            size=(car_width, car_height),
            fill='none',
            stroke='black',
            stroke_width=2,
            rx=5,
        ))

        wheel_radius = car_height * 0.4
        wheel_y = car_y + car_height - wheel_radius / 2
        for wheel_fraction in (0.2, 0.8):
            dwg.add(dwg.circle(
                center=(car_x + car_width * wheel_fraction, wheel_y),
                r=wheel_radius,
                fill='none',
                stroke='black',
                stroke_width=2,
            ))

    def _add_face_elements(self, dwg, width, height):
        """Draw a smiley face: outline circle, two filled eyes and a curved mouth."""
        center_x = width / 2
        center_y = height / 2
        face_radius = min(width, height) * 0.3

        dwg.add(dwg.circle(
            center=(center_x, center_y),
            r=face_radius,
            fill='none',
            stroke='black',
            stroke_width=2,
        ))

        eye_offset = face_radius * 0.3
        eye_radius = face_radius * 0.1
        for eye_x in (center_x - eye_offset, center_x + eye_offset):
            dwg.add(dwg.circle(
                center=(eye_x, center_y - eye_offset),
                r=eye_radius,
                fill='black',
            ))

        # Quadratic Bezier whose control point sits below the endpoints (a smile).
        mouth_y = center_y + face_radius * 0.3
        dwg.add(dwg.path(
            d=f"M {center_x - face_radius*0.3},{mouth_y} Q {center_x},{mouth_y + face_radius*0.2} {center_x + face_radius*0.3},{mouth_y}",
            fill='none',
            stroke='black',
            stroke_width=2,
        ))

    def _add_flower_elements(self, dwg, width, height):
        """Draw a flower: green stem, eight red petal circles and a yellow centre."""
        center_x = width / 2
        center_y = height / 2

        dwg.add(dwg.line(
            start=(center_x, center_y + 20),
            end=(center_x, height - 20),
            stroke='green',
            stroke_width=4,
        ))

        petal_radius = 15
        for angle in range(0, 360, 45):
            x = center_x + 25 * math.cos(math.radians(angle))
            y = center_y + 25 * math.sin(math.radians(angle))
            dwg.add(dwg.circle(
                center=(x, y),
                r=petal_radius,
                fill='none',
                stroke='red',
                stroke_width=2,
            ))

        dwg.add(dwg.circle(
            center=(center_x, center_y),
            r=8,
            fill='yellow',
            stroke='orange',
            stroke_width=2,
        ))

    def _add_animal_elements(self, dwg, width, height, animal_type):
        """Draw an animal outline chosen from *animal_type* (a word or a prompt).

        A cat takes precedence over a dog. Any other animal keyword (e.g.
        "bird", "animal") gets a generic body-and-head outline.
        """
        center_x = width / 2
        center_y = height / 2

        if self._mentions(animal_type, ['cat']):
            dwg.add(dwg.ellipse(
                center=(center_x, center_y + 20),
                r=(30, 20),
                fill='none',
                stroke='black',
                stroke_width=2,
            ))
            dwg.add(dwg.circle(
                center=(center_x, center_y - 20),
                r=25,
                fill='none',
                stroke='black',
                stroke_width=2,
            ))
            ear_points1 = [(center_x - 15, center_y - 35), (center_x - 5, center_y - 50), (center_x + 5, center_y - 35)]
            ear_points2 = [(center_x - 5, center_y - 35), (center_x + 5, center_y - 50), (center_x + 15, center_y - 35)]
            dwg.add(dwg.polygon(ear_points1, fill='none', stroke='black', stroke_width=2))
            dwg.add(dwg.polygon(ear_points2, fill='none', stroke='black', stroke_width=2))
        elif self._mentions(animal_type, ['dog']):
            dwg.add(dwg.ellipse(
                center=(center_x, center_y + 10),
                r=(40, 25),
                fill='none',
                stroke='black',
                stroke_width=2,
            ))
            dwg.add(dwg.ellipse(
                center=(center_x, center_y - 25),
                r=(25, 20),
                fill='none',
                stroke='black',
                stroke_width=2,
            ))
        else:
            # Previously "animal"/"bird" matched upstream but drew nothing,
            # leaving a blank canvas.
            dwg.add(dwg.ellipse(
                center=(center_x, center_y + 15),
                r=(35, 22),
                fill='none',
                stroke='black',
                stroke_width=2,
            ))
            dwg.add(dwg.circle(
                center=(center_x + 35, center_y - 15),
                r=15,
                fill='none',
                stroke='black',
                stroke_width=2,
            ))

    def _add_color_elements(self, dwg, color, width, height):
        """Draw a filled accent circle in *color* near the top-right corner."""
        color_map = {
            'red': '#FF0000',
            'blue': '#0000FF',
            'green': '#00FF00',
            'yellow': '#FFFF00',
            'purple': '#800080',
        }
        fill_color = color_map.get(color, '#000000')
        dwg.add(dwg.circle(
            center=(width * 0.8, height * 0.2),
            r=15,
            fill=fill_color,
            stroke='black',
            stroke_width=1,
        ))

    def _add_size_modifier(self, dwg, size_type, width, height):
        """Draw a dashed frame: near-full canvas for 'large', centred 40% box for 'small'."""
        if size_type == 'large':
            dwg.add(dwg.rect(
                insert=(10, 10),
                size=(width - 20, height - 20),
                fill='none',
                stroke='gray',
                stroke_width=3,
                stroke_dasharray='5,5',
            ))
        elif size_type == 'small':
            dwg.add(dwg.rect(
                insert=(width * 0.3, height * 0.3),
                size=(width * 0.4, height * 0.4),
                fill='none',
                stroke='gray',
                stroke_width=1,
                stroke_dasharray='2,2',
            ))

    def _add_object_elements(self, dwg, obj_type, width, height):
        """Dispatch 'house' / 'tree' / 'car' to the matching drawing helper."""
        if obj_type == 'house':
            self._add_house_elements(dwg, width, height)
        elif obj_type == 'tree':
            self._add_tree_elements(dwg, width, height)
        elif obj_type == 'car':
            self._add_car_elements(dwg, width, height)

    def _add_detailed_elements(self, dwg, width, height, prompt):
        """Scatter eight random outline shapes (circle, square or triangle).

        Positions and sizes come from the ``random`` module, so the result
        depends on the seed set in ``__call__``.
        """
        # Keep randint ranges valid on small canvases (width/height < 60).
        max_x = max(20, width - 40)
        max_y = max(20, height - 40)
        for _ in range(8):
            x = random.randint(20, max_x)
            y = random.randint(20, max_y)
            size = random.randint(10, 30)
            shape_type = random.choice(['circle', 'rect', 'polygon'])
            if shape_type == 'circle':
                dwg.add(dwg.circle(
                    center=(x, y),
                    r=size,
                    fill='none',
                    stroke='black',
                    stroke_width=1,
                    opacity=0.7,
                ))
            elif shape_type == 'rect':
                dwg.add(dwg.rect(
                    insert=(x - size, y - size),
                    size=(size * 2, size * 2),
                    fill='none',
                    stroke='black',
                    stroke_width=1,
                    opacity=0.7,
                ))
            else:
                # 'polygon' was previously chosen but never drawn.
                dwg.add(dwg.polygon(
                    [(x, y - size), (x - size, y + size), (x + size, y + size)],
                    fill='none',
                    stroke='black',
                    stroke_width=1,
                    opacity=0.7,
                ))

    def _add_simple_elements(self, dwg, width, height, prompt):
        """Draw a single centred outline circle for minimal prompts."""
        dwg.add(dwg.circle(
            center=(width / 2, height / 2),
            r=min(width, height) * 0.2,
            fill='none',
            stroke='black',
            stroke_width=2,
        ))

    def _add_standard_elements(self, dwg, width, height, prompt):
        """Draw the house/tree/car subject named in *prompt*, else abstract circles."""
        prompt_lower = prompt.lower()
        if self._mentions(prompt_lower, ['house', 'building']):
            self._add_house_elements(dwg, width, height)
        elif self._mentions(prompt_lower, ['tree', 'forest']):
            self._add_tree_elements(dwg, width, height)
        elif self._mentions(prompt_lower, ['car', 'vehicle']):
            self._add_car_elements(dwg, width, height)
        else:
            self._add_abstract_elements(dwg, width, height, prompt)

    def _add_abstract_elements(self, dwg, width, height, prompt):
        """Draw five outline circles whose placement is derived from *prompt*.

        The layout offset is a CRC32 of the prompt rather than ``hash()``.
        Python salts ``str`` hashes per process, so ``hash()`` made the same
        prompt and seed produce different sketches after a restart.
        """
        offset = zlib.crc32(str(prompt).encode('utf-8')) % 100
        # Guard the modulus so canvases of 40px or less cannot divide by zero.
        span_x = max(1, width - 40)
        span_y = max(1, height - 40)
        for i in range(5):
            x = (i * 40 + offset) % span_x + 20
            y = (i * 35 + offset) % span_y + 20
            size = 15 + (i * 5) % 20
            dwg.add(dwg.circle(
                center=(x, y),
                r=size,
                fill='none',
                stroke='black',
                stroke_width=2,
                opacity=0.8,
            ))

    def _emphasize_element(self, dwg, word, weight, width, height):
        """Draw an enlarged red house frame, scaled by *weight*, when *word* names a house.

        Other words are currently ignored.
        """
        scale_factor = weight
        stroke_width = int(2 * scale_factor)
        # Case-insensitive so "(House:1.5)" is honoured as well.
        if self._mentions(word, ['house', 'building']):
            house_size = min(width, height) * 0.4 * scale_factor
            house_x = (width - house_size) / 2
            house_y = (height - house_size) / 2
            dwg.add(dwg.rect(
                insert=(house_x, house_y),
                size=(house_size, house_size * 0.8),
                fill='none',
                stroke='red',
                stroke_width=stroke_width,
            ))

    def _deemphasize_element(self, dwg, word, weight, width, height):
        """Draw a faint top band (opacity = *weight*) when *word* names the background/sky.

        Other words are currently ignored.
        """
        scale_factor = weight
        stroke_width = max(1, int(2 * scale_factor))
        if self._mentions(word, ['background', 'sky']):
            dwg.add(dwg.rect(
                insert=(0, 0),
                size=(width, height * 0.3),
                fill='none',
                stroke='lightgray',
                stroke_width=stroke_width,
                opacity=scale_factor,
            ))

    def create_error_result(self, prompt: str, edit_type: str, error: str, width: int, height: int):
        """Build an error response that still carries a renderable fallback SVG."""
        fallback_svg = self.create_fallback_svg(prompt, width, height)
        return self._result(fallback_svg, edit_type, prompt=prompt, error=error)

    def create_fallback_svg(self, prompt: str, width: int, height: int):
        """Render a placeholder SVG with the title and (truncated) prompt on two lines."""
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        text = str(prompt)
        snippet = text if len(text) <= 20 else f"{text[:20]}..."
        # SVG <text> does not break lines on "\n", so each line gets its own
        # element, offset vertically around the centre.
        for label, dy in (("DiffSketchEdit", -8), (snippet, 10)):
            dwg.add(dwg.text(
                label,
                insert=(width / 2, height / 2 + dy),
                text_anchor="middle",
                font_size="14",
                fill="black",
            ))
        return dwg.tostring()