|
|
import os |
|
|
import sys |
|
|
import torch |
|
|
import base64 |
|
|
import json |
|
|
import numpy as np |
|
|
import svgwrite |
|
|
import random |
|
|
import math |
|
|
from diffusers import StableDiffusionPipeline |
|
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
|
from typing import List, Dict, Any, Tuple |
|
|
import io |
|
|
from PIL import Image |
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for DiffSketchEdit-style requests.

    Every response is an SVG drawn procedurally with ``svgwrite``: prompts are
    matched (by substring) against small keyword lists, and the matching
    hand-coded shapes (house, tree, car, face, flower, animal, or abstract
    circles) are added to a fresh white canvas.
    """

    def __init__(self, path=""):
        """Pick a device and try to load the Stable Diffusion and CLIP models.

        Args:
            path: Model directory passed in by the endpoint runtime. Not used;
                the models below are loaded from the Hub by id.

        Load failures are printed and leave the corresponding attributes set
        to ``None`` rather than raising.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # NOTE(review): self.pipe, self.tokenizer and self.text_encoder are not
        # read by any other method of this class -- every SVG is drawn
        # procedurally. Confirm they are still needed; loading them costs
        # start-up time and memory.
        try:
            self.pipe = StableDiffusionPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                # Half precision on GPU, full precision on CPU.
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                safety_checker=None,
                requires_safety_checker=False
            )
            self.pipe = self.pipe.to(self.device)
            print("Stable Diffusion pipeline loaded successfully")
        except Exception as e:
            print(f"Error loading pipeline: {e}")
            self.pipe = None

        try:
            self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
            self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
            self.text_encoder = self.text_encoder.to(self.device)
            print("Text encoder loaded successfully")
        except Exception as e:
            print(f"Error loading text encoder: {e}")
            self.tokenizer = None
            self.text_encoder = None

    def __call__(self, data):
        """Edit or generate a vector sketch for one inference request.

        Expected payload::

            {"inputs": <str | list | dict>, "parameters": {...}}

        * ``inputs`` as a dict may carry ``prompts`` (a list, or a single
          string), ``prompt``, ``edit_type`` and ``input_svg``.
        * ``inputs`` as a list/tuple is taken as the list of prompts; ``None``
          means "no prompt"; any other value becomes a single prompt via
          ``str()``.
        * ``parameters`` may carry ``width`` / ``height`` (default 224),
          ``seed`` (default 42) and, when not given in ``inputs``,
          ``edit_type`` / ``input_svg``.

        ``edit_type`` selects ``"replace"`` (needs two prompts), ``"refine"``,
        ``"reweight"`` or ``"generate"``; anything else -- including
        ``"replace"`` with fewer than two prompts -- falls back to refinement.

        Returns:
            The result dict of the selected edit method. On any error, a dict
            with a fallback SVG and an ``"error"`` message is returned instead
            of raising.
        """
        # Bound before the ``try`` so the error path can always build a
        # response, even when the payload itself is malformed.
        prompts = []
        edit_type = "refine"
        width, height = 224, 224

        try:
            if not isinstance(data, dict):
                raise TypeError(f"request payload must be a dict, got {type(data).__name__}")

            inputs = data.get("inputs", "")
            parameters = data.get("parameters") or {}

            if isinstance(inputs, dict):
                prompts = inputs.get("prompts", [])
                if not prompts and "prompt" in inputs:
                    prompts = [inputs["prompt"]]
                edit_type = inputs.get("edit_type", parameters.get("edit_type", "refine"))
                input_svg = inputs.get("input_svg", parameters.get("input_svg"))
            else:
                if isinstance(inputs, (list, tuple)):
                    prompts = list(inputs)
                elif inputs is None:
                    prompts = []
                else:
                    prompts = [inputs]
                edit_type = parameters.get("edit_type", "refine")
                input_svg = parameters.get("input_svg", None)

            if not prompts:
                prompts = ["a simple sketch"]
            elif not isinstance(prompts, (list, tuple)):
                # A bare string must not be indexed character by character.
                prompts = [prompts]
            # Downstream code calls .lower()/.split() on prompts.
            prompts = [str(p) for p in prompts]

            width = int(parameters.get("width", 224))
            height = int(parameters.get("height", 224))
            seed = int(parameters.get("seed", 42))

            # Seed every RNG the process may use so outputs are reproducible.
            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)

            print(f"Processing edit type: '{edit_type}' with prompts: {prompts}")

            if edit_type == "replace" and len(prompts) >= 2:
                return self.word_replacement_edit(prompts[0], prompts[1], width, height, input_svg)
            if edit_type == "reweight":
                return self.attention_reweighting_edit(prompts[0], width, height, input_svg)
            if edit_type == "generate":
                return self.simple_generation(prompts[0], width, height)
            # "refine", unknown edit types and under-specified "replace".
            return self.prompt_refinement_edit(prompts[0], width, height, input_svg)

        except Exception as e:
            print(f"Error in handler: {e}")
            prompt = str(prompts[0]) if prompts else "error"
            return self.create_error_result(prompt, edit_type, str(e), width, height)

    def word_replacement_edit(self, source_prompt: str, target_prompt: str, width: int, height: int, input_svg: str = None):
        """Draw the sketch for ``target_prompt`` as an edit of ``source_prompt``.

        Words are compared case-insensitively on whitespace boundaries. The
        reported ``added_words`` / ``removed_words`` keep first-occurrence
        order so responses are reproducible.

        Returns:
            Result dict with the SVG (raw and base64), both prompts and the
            added/removed word lists; an error result on failure.
        """
        try:
            print(f"Word replacement: '{source_prompt}' -> '{target_prompt}'")

            source_words = source_prompt.lower().split()
            target_words = target_prompt.lower().split()
            source_set, target_set = set(source_words), set(target_words)

            # dict.fromkeys de-duplicates while keeping order; iterating a set
            # of str would not be reproducible, since str hashing is salted
            # per process.
            added_words = [w for w in dict.fromkeys(target_words) if w not in source_set]
            removed_words = [w for w in dict.fromkeys(source_words) if w not in target_set]

            print(f"Added words: {added_words}, Removed words: {removed_words}")

            base_svg = input_svg if input_svg else self.generate_base_svg(source_prompt, width, height)

            edited_svg = self.apply_word_replacement(
                base_svg, source_prompt, target_prompt, added_words, removed_words, width, height
            )

            return {
                "svg": edited_svg,
                "svg_base64": base64.b64encode(edited_svg.encode('utf-8')).decode('utf-8'),
                "edit_type": "replace",
                "source_prompt": source_prompt,
                "target_prompt": target_prompt,
                "added_words": added_words,
                "removed_words": removed_words
            }

        except Exception as e:
            print(f"Error in word_replacement_edit: {e}")
            return self.create_error_result(source_prompt, "replace", str(e), width, height)

    def prompt_refinement_edit(self, prompt: str, width: int, height: int, input_svg: str = None):
        """Redraw the sketch for ``prompt`` at a detail level chosen from its wording.

        Returns:
            Result dict with the SVG (raw and base64) and the prompt; an error
            result on failure.
        """
        try:
            print(f"Prompt refinement for: '{prompt}'")

            base_svg = input_svg if input_svg else self.generate_base_svg(prompt, width, height)
            refined_svg = self.apply_refinement(base_svg, prompt, width, height)

            return {
                "svg": refined_svg,
                "svg_base64": base64.b64encode(refined_svg.encode('utf-8')).decode('utf-8'),
                "edit_type": "refine",
                "prompt": prompt
            }

        except Exception as e:
            print(f"Error in prompt_refinement_edit: {e}")
            return self.create_error_result(prompt, "refine", str(e), width, height)

    def attention_reweighting_edit(self, prompt: str, width: int, height: int, input_svg: str = None):
        """Draw ``prompt`` with ``(word:weight)`` annotations emphasizing/de-emphasizing parts.

        Returns:
            Result dict with the SVG (raw and base64), the original prompt, the
            annotation-free prompt and the parsed weights; an error result on
            failure.
        """
        try:
            print(f"Attention reweighting for: '{prompt}'")

            weighted_prompt, attention_weights = self.parse_attention_weights(prompt)

            base_svg = input_svg if input_svg else self.generate_base_svg(weighted_prompt, width, height)
            reweighted_svg = self.apply_attention_reweighting(
                base_svg, weighted_prompt, attention_weights, width, height
            )

            return {
                "svg": reweighted_svg,
                "svg_base64": base64.b64encode(reweighted_svg.encode('utf-8')).decode('utf-8'),
                "edit_type": "reweight",
                "prompt": prompt,
                "weighted_prompt": weighted_prompt,
                "attention_weights": attention_weights
            }

        except Exception as e:
            print(f"Error in attention_reweighting_edit: {e}")
            return self.create_error_result(prompt, "reweight", str(e), width, height)

    def simple_generation(self, prompt: str, width: int, height: int):
        """Draw a fresh sketch for ``prompt`` without any edit step.

        Returns:
            Result dict with the SVG (raw and base64) and the prompt; an error
            result on failure.
        """
        try:
            print(f"Simple generation for: '{prompt}'")

            svg_content = self.generate_base_svg(prompt, width, height)

            return {
                "svg": svg_content,
                "svg_base64": base64.b64encode(svg_content.encode('utf-8')).decode('utf-8'),
                "edit_type": "generate",
                "prompt": prompt
            }

        except Exception as e:
            print(f"Error in simple_generation: {e}")
            return self.create_error_result(prompt, "generate", str(e), width, height)

    def _new_canvas(self, width, height):
        """Return a new ``svgwrite`` drawing of ``width`` x ``height`` with a white background."""
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
        return dwg

    def generate_base_svg(self, prompt: str, width: int, height: int):
        """Draw the subject named in ``prompt`` (first keyword group that matches) and return SVG text."""
        dwg = self._new_canvas(width, height)
        prompt_lower = prompt.lower()

        # Substring matching: the first matching group wins.
        if any(word in prompt_lower for word in ['house', 'building', 'home']):
            self._add_house_elements(dwg, width, height)
        elif any(word in prompt_lower for word in ['tree', 'forest', 'nature']):
            self._add_tree_elements(dwg, width, height)
        elif any(word in prompt_lower for word in ['car', 'vehicle', 'transport']):
            self._add_car_elements(dwg, width, height)
        elif any(word in prompt_lower for word in ['face', 'person', 'portrait']):
            self._add_face_elements(dwg, width, height)
        elif any(word in prompt_lower for word in ['flower', 'plant', 'garden']):
            self._add_flower_elements(dwg, width, height)
        elif any(word in prompt_lower for word in ['cat', 'dog', 'animal']):
            self._add_animal_elements(dwg, width, height, prompt_lower)
        else:
            self._add_abstract_elements(dwg, width, height, prompt)

        return dwg.tostring()

    def apply_word_replacement(self, base_svg: str, source_prompt: str, target_prompt: str, added_words: Any, removed_words: Any, width: int, height: int):
        """Draw a new canvas reflecting the words added in ``target_prompt``.

        Args:
            base_svg: Accepted for interface compatibility.
                NOTE(review): never read -- the result is redrawn from scratch,
                so a caller-supplied input SVG does not influence the output.
            added_words: Iterable of added (lower-case) words. Lists/tuples are
                drawn in the given order; other iterables (e.g. sets) are sorted
                first so the SVG is reproducible.
            removed_words: Currently unused.

        Returns:
            SVG document text.
        """
        dwg = self._new_canvas(width, height)

        if isinstance(added_words, (list, tuple)):
            ordered_words = added_words
        else:
            ordered_words = sorted(added_words)

        # Objects already drawn from added words, so the scene pass below
        # does not draw the same object a second time.
        drawn_objects = set()

        for word in ordered_words:
            if word in ['red', 'blue', 'green', 'yellow', 'purple']:
                self._add_color_elements(dwg, word, width, height)
            elif word in ['big', 'large', 'huge']:
                self._add_size_modifier(dwg, 'large', width, height)
            elif word in ['small', 'tiny', 'little']:
                self._add_size_modifier(dwg, 'small', width, height)
            elif word in ['cat', 'dog', 'bird']:
                self._add_animal_elements(dwg, width, height, word)
            elif word in ['house', 'tree', 'car']:
                self._add_object_elements(dwg, word, width, height)
                drawn_objects.add(word)

        # Main subject of the target prompt (first matching group only).
        target_lower = target_prompt.lower()
        if any(word in target_lower for word in ['house', 'building']):
            if 'house' not in drawn_objects:
                self._add_house_elements(dwg, width, height)
        elif any(word in target_lower for word in ['tree', 'forest']):
            if 'tree' not in drawn_objects:
                self._add_tree_elements(dwg, width, height)
        elif any(word in target_lower for word in ['car', 'vehicle']):
            if 'car' not in drawn_objects:
                self._add_car_elements(dwg, width, height)

        return dwg.tostring()

    def apply_refinement(self, base_svg: str, prompt: str, width: int, height: int):
        """Draw ``prompt`` with detail chosen by the words detailed/complex or simple/minimal.

        ``base_svg`` is accepted for interface compatibility.
        NOTE(review): it is never read; the result is redrawn from scratch.
        """
        dwg = self._new_canvas(width, height)
        prompt_lower = prompt.lower()

        if 'detailed' in prompt_lower or 'complex' in prompt_lower:
            self._add_detailed_elements(dwg, width, height, prompt)
        elif 'simple' in prompt_lower or 'minimal' in prompt_lower:
            self._add_simple_elements(dwg, width, height, prompt)
        else:
            self._add_standard_elements(dwg, width, height, prompt)

        return dwg.tostring()

    def apply_attention_reweighting(self, base_svg: str, prompt: str, attention_weights: dict, width: int, height: int):
        """Draw ``prompt`` plus emphasis (weight > 1) / de-emphasis (weight < 1) marks.

        ``base_svg`` is accepted for interface compatibility.
        NOTE(review): it is never read; the result is redrawn from scratch.
        """
        dwg = self._new_canvas(width, height)

        for word, weight in attention_weights.items():
            if weight > 1.0:
                self._emphasize_element(dwg, word, weight, width, height)
            elif weight < 1.0:
                self._deemphasize_element(dwg, word, weight, width, height)
            # weight == 1.0 is neutral: nothing extra is drawn.

        self._add_standard_elements(dwg, width, height, prompt)

        return dwg.tostring()

    def parse_attention_weights(self, prompt: str) -> Tuple[str, dict]:
        """Extract ``(word:weight)`` / ``[word:weight]`` annotations from ``prompt``.

        Returns:
            ``(clean_prompt, weights)``. Each annotation whose weight parses as
            a float is replaced by its bare word, and ``weights`` maps the
            stripped word to that float (the last occurrence wins). Annotations
            with an unparsable weight (e.g. ``1.2.3``) are left untouched.
        """
        import re

        attention_weights = {}

        def _strip_annotation(match):
            word, weight_str = match.group(1), match.group(2)
            try:
                weight = float(weight_str)
            except ValueError:
                return match.group(0)
            attention_weights[word.strip()] = weight
            # A function replacement is inserted verbatim; a string
            # replacement would have its backslashes parsed as escapes.
            return word

        clean_prompt = re.sub(r'[\(\[]([^:\)\]]+):([0-9\.]+)[\)\]]', _strip_annotation, prompt)
        return clean_prompt.strip(), attention_weights

    def _add_house_elements(self, dwg, width, height):
        """Add a house outline (body, triangular roof, door) to ``dwg``."""
        house_width = width * 0.6
        house_height = height * 0.4
        house_x = (width - house_width) / 2
        house_y = height * 0.4

        # Body.
        dwg.add(dwg.rect(
            insert=(house_x, house_y),
            size=(house_width, house_height),
            fill='none',
            stroke='black',
            stroke_width=2
        ))

        # Roof: triangle sitting on the top edge of the body.
        roof_points = [
            (house_x, house_y),
            (house_x + house_width / 2, house_y - house_height * 0.3),
            (house_x + house_width, house_y)
        ]
        dwg.add(dwg.polygon(roof_points, fill='none', stroke='black', stroke_width=2))

        # Door: centered horizontally, resting on the bottom edge.
        door_width = house_width * 0.2
        door_height = house_height * 0.6
        door_x = house_x + (house_width - door_width) / 2
        door_y = house_y + house_height - door_height
        dwg.add(dwg.rect(
            insert=(door_x, door_y),
            size=(door_width, door_height),
            fill='none',
            stroke='black',
            stroke_width=2
        ))

    def _add_tree_elements(self, dwg, width, height):
        """Add a tree (rectangular trunk, circular crown) to ``dwg``."""
        center_x = width / 2
        center_y = height / 2

        trunk_width = 12
        trunk_height = height * 0.3
        dwg.add(dwg.rect(
            insert=(center_x - trunk_width / 2, center_y + 20),
            size=(trunk_width, trunk_height),
            fill='none',
            stroke='black',
            stroke_width=2
        ))

        crown_radius = width * 0.25
        dwg.add(dwg.circle(
            center=(center_x, center_y),
            r=crown_radius,
            fill='none',
            stroke='black',
            stroke_width=2
        ))

    def _add_car_elements(self, dwg, width, height):
        """Add a car (rounded body with two wheels) to ``dwg``."""
        car_width = width * 0.7
        car_height = height * 0.3
        car_x = (width - car_width) / 2
        car_y = (height - car_height) / 2

        dwg.add(dwg.rect(
            insert=(car_x, car_y),
            size=(car_width, car_height),
            fill='none',
            stroke='black',
            stroke_width=2,
            rx=5
        ))

        wheel_radius = car_height * 0.4
        wheel_y = car_y + car_height - wheel_radius / 2
        for wheel_position in (0.2, 0.8):
            dwg.add(dwg.circle(
                center=(car_x + car_width * wheel_position, wheel_y),
                r=wheel_radius,
                fill='none',
                stroke='black',
                stroke_width=2
            ))

    def _add_face_elements(self, dwg, width, height):
        """Add a simple face (outline, two filled eyes, smiling mouth) to ``dwg``."""
        center_x = width / 2
        center_y = height / 2
        face_radius = min(width, height) * 0.3

        dwg.add(dwg.circle(
            center=(center_x, center_y),
            r=face_radius,
            fill='none',
            stroke='black',
            stroke_width=2
        ))

        eye_offset = face_radius * 0.3
        eye_radius = face_radius * 0.1
        for direction in (-1, 1):
            dwg.add(dwg.circle(
                center=(center_x + direction * eye_offset, center_y - eye_offset),
                r=eye_radius,
                fill='black'
            ))

        # Mouth: quadratic Bezier whose control point sits below the corners.
        mouth_y = center_y + face_radius * 0.3
        dwg.add(dwg.path(
            d=f"M {center_x - face_radius*0.3},{mouth_y} Q {center_x},{mouth_y + face_radius*0.2} {center_x + face_radius*0.3},{mouth_y}",
            fill='none',
            stroke='black',
            stroke_width=2
        ))

    def _add_flower_elements(self, dwg, width, height):
        """Add a flower (green stem, eight red petals, yellow center) to ``dwg``."""
        center_x = width / 2
        center_y = height / 2

        dwg.add(dwg.line(
            start=(center_x, center_y + 20),
            end=(center_x, height - 20),
            stroke='green',
            stroke_width=4
        ))

        # Eight petals evenly spaced (every 45 degrees) on a radius-25 ring.
        petal_radius = 15
        for angle in range(0, 360, 45):
            x = center_x + 25 * math.cos(math.radians(angle))
            y = center_y + 25 * math.sin(math.radians(angle))
            dwg.add(dwg.circle(
                center=(x, y),
                r=petal_radius,
                fill='none',
                stroke='red',
                stroke_width=2
            ))

        dwg.add(dwg.circle(
            center=(center_x, center_y),
            r=8,
            fill='yellow',
            stroke='orange',
            stroke_width=2
        ))

    def _add_animal_elements(self, dwg, width, height, animal_type):
        """Add an animal outline chosen by substring of ``animal_type``.

        ``'cat'`` and ``'dog'`` get dedicated shapes, ``'bird'`` a body, head
        and beak; anything else gets a generic body and head so the canvas is
        never left blank.
        """
        center_x = width / 2
        center_y = height / 2

        if 'cat' in animal_type:
            # Body.
            dwg.add(dwg.ellipse(
                center=(center_x, center_y + 20),
                r=(30, 20),
                fill='none',
                stroke='black',
                stroke_width=2
            ))
            # Head.
            dwg.add(dwg.circle(
                center=(center_x, center_y - 20),
                r=25,
                fill='none',
                stroke='black',
                stroke_width=2
            ))
            # Ears.
            ear_points1 = [(center_x - 15, center_y - 35), (center_x - 5, center_y - 50), (center_x + 5, center_y - 35)]
            ear_points2 = [(center_x - 5, center_y - 35), (center_x + 5, center_y - 50), (center_x + 15, center_y - 35)]
            dwg.add(dwg.polygon(ear_points1, fill='none', stroke='black', stroke_width=2))
            dwg.add(dwg.polygon(ear_points2, fill='none', stroke='black', stroke_width=2))

        elif 'dog' in animal_type:
            # Body.
            dwg.add(dwg.ellipse(
                center=(center_x, center_y + 10),
                r=(40, 25),
                fill='none',
                stroke='black',
                stroke_width=2
            ))
            # Head.
            dwg.add(dwg.ellipse(
                center=(center_x, center_y - 25),
                r=(25, 20),
                fill='none',
                stroke='black',
                stroke_width=2
            ))

        elif 'bird' in animal_type:
            # Body, head to the upper right, and a triangular beak.
            dwg.add(dwg.ellipse(
                center=(center_x, center_y + 10),
                r=(30, 18),
                fill='none',
                stroke='black',
                stroke_width=2
            ))
            dwg.add(dwg.circle(
                center=(center_x + 28, center_y - 12),
                r=12,
                fill='none',
                stroke='black',
                stroke_width=2
            ))
            beak_points = [(center_x + 40, center_y - 15), (center_x + 52, center_y - 12), (center_x + 40, center_y - 9)]
            dwg.add(dwg.polygon(beak_points, fill='none', stroke='black', stroke_width=2))

        else:
            # Generic animal: body and head.
            dwg.add(dwg.ellipse(
                center=(center_x, center_y + 15),
                r=(35, 22),
                fill='none',
                stroke='black',
                stroke_width=2
            ))
            dwg.add(dwg.circle(
                center=(center_x, center_y - 22),
                r=20,
                fill='none',
                stroke='black',
                stroke_width=2
            ))

    def _add_color_elements(self, dwg, color, width, height):
        """Add a filled swatch circle in ``color`` near the top-right corner (black if unknown)."""
        color_map = {
            'red': '#FF0000',
            'blue': '#0000FF',
            'green': '#00FF00',
            'yellow': '#FFFF00',
            'purple': '#800080'
        }
        fill_color = color_map.get(color, '#000000')

        dwg.add(dwg.circle(
            center=(width * 0.8, height * 0.2),
            r=15,
            fill=fill_color,
            stroke='black',
            stroke_width=1
        ))

    def _add_size_modifier(self, dwg, size_type, width, height):
        """Add a dashed frame: near-full-canvas for ``'large'``, central 40% for ``'small'``."""
        if size_type == 'large':
            dwg.add(dwg.rect(
                insert=(10, 10),
                size=(width - 20, height - 20),
                fill='none',
                stroke='gray',
                stroke_width=3,
                stroke_dasharray='5,5'
            ))
        elif size_type == 'small':
            dwg.add(dwg.rect(
                insert=(width * 0.3, height * 0.3),
                size=(width * 0.4, height * 0.4),
                fill='none',
                stroke='gray',
                stroke_width=1,
                stroke_dasharray='2,2'
            ))

    def _add_object_elements(self, dwg, obj_type, width, height):
        """Draw ``'house'``, ``'tree'`` or ``'car'``; other values draw nothing."""
        if obj_type == 'house':
            self._add_house_elements(dwg, width, height)
        elif obj_type == 'tree':
            self._add_tree_elements(dwg, width, height)
        elif obj_type == 'car':
            self._add_car_elements(dwg, width, height)

    def _add_detailed_elements(self, dwg, width, height, prompt):
        """Scatter eight random outlined shapes (circle, square or triangle).

        Uses the module-level ``random`` state, which ``__call__`` seeds.
        """
        # Clamp the sampling ranges so canvases narrower/shorter than 60px do
        # not produce an empty randint range; larger canvases are unaffected.
        max_x = max(20, width - 40)
        max_y = max(20, height - 40)

        for _ in range(8):
            x = random.randint(20, max_x)
            y = random.randint(20, max_y)
            size = random.randint(10, 30)
            shape_type = random.choice(['circle', 'rect', 'polygon'])

            if shape_type == 'circle':
                dwg.add(dwg.circle(
                    center=(x, y),
                    r=size,
                    fill='none',
                    stroke='black',
                    stroke_width=1,
                    opacity=0.7
                ))
            elif shape_type == 'rect':
                dwg.add(dwg.rect(
                    insert=(x - size, y - size),
                    size=(size * 2, size * 2),
                    fill='none',
                    stroke='black',
                    stroke_width=1,
                    opacity=0.7
                ))
            else:
                # 'polygon': triangle inscribed in the same 2*size box.
                points = [(x, y - size), (x - size, y + size), (x + size, y + size)]
                dwg.add(dwg.polygon(
                    points,
                    fill='none',
                    stroke='black',
                    stroke_width=1,
                    opacity=0.7
                ))

    def _add_simple_elements(self, dwg, width, height, prompt):
        """Add a single centered circle for minimal prompts."""
        dwg.add(dwg.circle(
            center=(width / 2, height / 2),
            r=min(width, height) * 0.2,
            fill='none',
            stroke='black',
            stroke_width=2
        ))

    def _add_standard_elements(self, dwg, width, height, prompt):
        """Draw a house, tree or car if the prompt mentions one, else abstract circles."""
        prompt_lower = prompt.lower()

        if any(word in prompt_lower for word in ['house', 'building']):
            self._add_house_elements(dwg, width, height)
        elif any(word in prompt_lower for word in ['tree', 'forest']):
            self._add_tree_elements(dwg, width, height)
        elif any(word in prompt_lower for word in ['car', 'vehicle']):
            self._add_car_elements(dwg, width, height)
        else:
            self._add_abstract_elements(dwg, width, height, prompt)

    def _add_abstract_elements(self, dwg, width, height, prompt):
        """Add five circles whose positions are derived deterministically from ``prompt``."""
        import zlib

        # zlib.crc32 gives the same value in every process; the built-in
        # hash() of a str is salted per interpreter run.
        prompt_hash = zlib.crc32(str(prompt).encode('utf-8')) % 100

        # Guard the modulus so canvases of 40px or less don't divide by zero.
        span_x = max(1, width - 40)
        span_y = max(1, height - 40)

        for i in range(5):
            x = (i * 40 + prompt_hash) % span_x + 20
            y = (i * 35 + prompt_hash) % span_y + 20
            size = 15 + (i * 5) % 20

            dwg.add(dwg.circle(
                center=(x, y),
                r=size,
                fill='none',
                stroke='black',
                stroke_width=2,
                opacity=0.8
            ))

    def _emphasize_element(self, dwg, word, weight, width, height):
        """Draw a red, weight-scaled frame for emphasized house/building words.

        Other words are currently ignored. Matching is case-insensitive.
        """
        scale_factor = weight
        stroke_width = int(2 * scale_factor)

        if word.lower() in ['house', 'building']:
            house_size = min(width, height) * 0.4 * scale_factor
            house_x = (width - house_size) / 2
            house_y = (height - house_size) / 2

            dwg.add(dwg.rect(
                insert=(house_x, house_y),
                size=(house_size, house_size * 0.8),
                fill='none',
                stroke='red',
                stroke_width=stroke_width
            ))

    def _deemphasize_element(self, dwg, word, weight, width, height):
        """Draw a faint top band (opacity = weight) for de-emphasized background/sky words.

        Other words are currently ignored. Matching is case-insensitive.
        """
        scale_factor = weight
        stroke_width = max(1, int(2 * scale_factor))

        if word.lower() in ['background', 'sky']:
            dwg.add(dwg.rect(
                insert=(0, 0),
                size=(width, height * 0.3),
                fill='none',
                stroke='lightgray',
                stroke_width=stroke_width,
                opacity=scale_factor
            ))

    def create_error_result(self, prompt: str, edit_type: str, error: str, width: int, height: int):
        """Build the standard error response: fallback SVG plus the error message."""
        fallback_svg = self.create_fallback_svg(prompt, width, height)
        return {
            "svg": fallback_svg,
            "svg_base64": base64.b64encode(fallback_svg.encode('utf-8')).decode('utf-8'),
            "edit_type": edit_type,
            "prompt": prompt,
            "error": error
        }

    def create_fallback_svg(self, prompt: str, width: int, height: int):
        """Return a placeholder SVG showing a title and (up to 20 chars of) the prompt.

        SVG ``<text>`` does not render ``\\n`` as a line break, so the title and
        the prompt are separate text elements. ``prompt`` is coerced with
        ``str()`` so this never fails on a non-string prompt.
        """
        dwg = self._new_canvas(width, height)

        label = str(prompt)
        snippet = label[:20] + ("..." if len(label) > 20 else "")

        dwg.add(dwg.text(
            "DiffSketchEdit",
            insert=(width / 2, height / 2),
            text_anchor="middle",
            font_size="14",
            fill="black"
        ))
        dwg.add(dwg.text(
            snippet,
            insert=(width / 2, height / 2 + 18),
            text_anchor="middle",
            font_size="14",
            fill="black"
        ))

        return dwg.tostring()