|
|
import base64
import io
import json
import math
import os
import random
import sys
import zlib

import numpy as np
import svgwrite
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
|
|
|
|
|
class EndpointHandler:
    """Inference handler that turns a text prompt into an SVG line sketch.

    The "DiffSketcher" pipeline implemented here is a simplified stand-in:
    attention maps are synthesised from the prompt's words, stroke paths are
    sampled from those maps, lightly smoothed and serialised with svgwrite.
    """

    # Prompt used when the request carries no usable text.
    DEFAULT_PROMPT = "a simple sketch"

    # Request parameters and their defaults; each default's type is also the
    # type an incoming value is coerced to.
    DEFAULT_PARAMETERS = {
        "num_paths": 96,
        "num_iter": 500,
        "guidance_scale": 7.5,
        "width": 224,
        "height": 224,
        "seed": 42,
    }

    # Only the first few prompt words contribute an attention blob.
    MAX_KEYWORDS = 5

    def __init__(self, path=""):
        """Initialize DiffSketcher handler for Hugging Face Inference API.

        Loading failures are reported on stdout and leave the corresponding
        attribute set to ``None`` instead of raising.

        Args:
            path: Accepted for interface compatibility; not read here.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # NOTE(review): no method of this class reads self.pipe; confirm it
        # is needed before paying its load time and memory cost.
        try:
            self.pipe = StableDiffusionPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                safety_checker=None,
                requires_safety_checker=False,
            )
            self.pipe = self.pipe.to(self.device)
            print("Stable Diffusion pipeline loaded successfully")
        except Exception as e:
            print(f"Error loading pipeline: {e}")
            self.pipe = None

        try:
            self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
            self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
            self.text_encoder = self.text_encoder.to(self.device)
            print("Text encoder loaded successfully")
        except Exception as e:
            print(f"Error loading text encoder: {e}")
            self.tokenizer = None
            self.text_encoder = None

    def __call__(self, data):
        """Generate SVG sketch from text prompt.

        Args:
            data: Request payload. Expected to be a dict with ``"inputs"``
                (a string, or a dict holding ``"prompt"`` / ``"text"``) and an
                optional ``"parameters"`` dict (see ``DEFAULT_PARAMETERS``).

        Returns:
            dict with ``"svg"``, ``"svg_base64"``, ``"prompt"`` and
            ``"parameters"`` on success. If anything raises, a dict with a
            fallback ``"svg"``, ``"svg_base64"``, ``"prompt"`` and ``"error"``.
        """
        # Bound before the try block so the except branch can always build a
        # fallback, even when the payload itself is what failed to parse.
        prompt = self.DEFAULT_PROMPT
        width = self.DEFAULT_PARAMETERS["width"]
        height = self.DEFAULT_PARAMETERS["height"]

        try:
            prompt = self._extract_prompt(data)
            params = self._extract_parameters(data)
            width, height = params["width"], params["height"]

            seed = params["seed"]
            torch.manual_seed(seed)
            # NumPy only accepts seeds in [0, 2**32); fold others into range.
            np.random.seed(seed % (2 ** 32))
            random.seed(seed)

            num_paths = params["num_paths"]
            print(f"Generating SVG for prompt: '{prompt}' with {num_paths} paths")

            svg_content = self.generate_svg_sketch(
                prompt,
                num_paths,
                params["num_iter"],
                params["guidance_scale"],
                width,
                height,
            )

            return {
                "svg": svg_content,
                "svg_base64": base64.b64encode(svg_content.encode('utf-8')).decode('utf-8'),
                "prompt": prompt,
                "parameters": dict(params),
            }

        except Exception as e:
            print(f"Error in handler: {e}")

            fallback_svg = self.create_fallback_svg(prompt, width, height)
            return {
                "svg": fallback_svg,
                "svg_base64": base64.b64encode(fallback_svg.encode('utf-8')).decode('utf-8'),
                "prompt": prompt,
                "error": str(e),
            }

    @classmethod
    def _extract_prompt(cls, data):
        """Return the prompt text from a request payload, or the default prompt."""
        inputs = data.get("inputs", "") if isinstance(data, dict) else data
        if isinstance(inputs, dict):
            inputs = inputs.get("prompt", inputs.get("text", ""))
        if inputs is None:
            return cls.DEFAULT_PROMPT
        # Coerce so later ``.lower()`` / ``.split()`` calls cannot fail.
        return str(inputs) or cls.DEFAULT_PROMPT

    @classmethod
    def _extract_parameters(cls, data):
        """Return generation parameters with defaults applied and types coerced.

        Raises:
            ValueError: If a value cannot be converted, or ``width`` /
                ``height`` is smaller than 1.
            TypeError: If a value has a type that cannot be converted.
        """
        raw = data.get("parameters") if isinstance(data, dict) else None
        if not isinstance(raw, dict):
            raw = {}

        params = {}
        for name, default in cls.DEFAULT_PARAMETERS.items():
            value = raw.get(name)
            if value is None:
                value = default
            params[name] = type(default)(value)

        if params["width"] < 1 or params["height"] < 1:
            raise ValueError("width and height must be positive integers")
        return params

    @staticmethod
    def _prompt_seed(prompt):
        """Return a value in [0, 100) derived from *prompt*, stable across processes."""
        # The builtin hash() of a str is salted per interpreter process, so it
        # would give a different drawing for the same prompt after a restart.
        return zlib.crc32(str(prompt).encode("utf-8")) % 100

    @staticmethod
    def _polyline_path_data(points):
        """Return SVG path data ("M x,y L x,y ...") through *points*."""
        first, *rest = points
        segments = [f"M {first[0]},{first[1]}"]
        segments.extend(f"L {px},{py}" for px, py in rest)
        return " ".join(segments)

    def generate_svg_sketch(self, prompt, num_paths, num_iter, guidance_scale, width, height):
        """Generate SVG sketch using simplified DiffSketcher approach.

        Runs the stages in order: text embeddings, attention maps, path
        initialisation, path optimisation, SVG serialisation. Any exception
        is reported on stdout and replaced by the fallback SVG.

        Returns:
            The SVG document as a string.
        """
        try:
            # May be None when the text encoder is unavailable.
            text_embeddings = self.get_text_embeddings(prompt)
            attention_maps = self.generate_attention_maps(prompt, width, height)
            paths = self.initialize_paths_from_attention(attention_maps, num_paths, width, height)
            optimized_paths = self.optimize_paths(paths, text_embeddings, num_iter, guidance_scale)
            return self.create_svg_from_paths(optimized_paths, width, height)
        except Exception as e:
            print(f"Error in generate_svg_sketch: {e}")
            return self.create_fallback_svg(prompt, width, height)

    def get_text_embeddings(self, prompt):
        """Get text embeddings from CLIP.

        Returns:
            The text encoder's ``last_hidden_state`` for *prompt*, or ``None``
            when the tokenizer/encoder is not loaded or encoding fails.
        """
        if self.tokenizer is None or self.text_encoder is None:
            return None

        try:
            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                embeddings = self.text_encoder(**inputs).last_hidden_state

            return embeddings
        except Exception as e:
            print(f"Error getting text embeddings: {e}")
            return None

    def generate_attention_maps(self, prompt, width, height):
        """Generate simplified attention maps.

        Places one circular blob per prompt word (at most ``MAX_KEYWORDS``),
        evenly spaced along the horizontal centre line, then scales the map
        so its maximum is 1.

        Returns:
            A ``(height, width)`` float array; all zeros when the prompt has
            no words or the blob radius rounds down to zero.
        """
        attention_map = np.zeros((height, width))

        keywords = prompt.lower().split()[:self.MAX_KEYWORDS]
        if not keywords or attention_map.size == 0:
            return attention_map

        # Loop invariants: pixel grid, vertical centre and squared radius.
        y, x = np.ogrid[:height, :width]
        center_y = height // 2
        radius_sq = (min(width, height) // 4) ** 2

        # Space the centres by the number of blobs actually drawn, so a long
        # prompt does not squeeze every blob into the left part of the canvas.
        slots = len(keywords) + 1
        for i in range(len(keywords)):
            center_x = (i + 1) * width // slots
            mask = ((x - center_x) ** 2 + (y - center_y) ** 2) < radius_sq
            attention_map[mask] += 1.0

        peak = attention_map.max()
        if peak > 0:
            attention_map = attention_map / peak

        return attention_map

    def initialize_paths_from_attention(self, attention_map, num_paths, width, height):
        """Initialize SVG paths based on attention maps.

        Each path is up to four distinct pixels sampled (via ``np.random``)
        from the region where attention exceeds 0.3, ordered left to right.
        Falls back to fully random paths when fewer than two such pixels
        exist, because a one-point path cannot be drawn.

        Returns:
            A list of ``num_paths`` paths, each a list of ``(x, y)`` tuples.
        """
        threshold = 0.3
        high_attention = attention_map > threshold

        y_coords, x_coords = np.where(high_attention)
        candidates = len(x_coords)
        if candidates < 2:
            return self.create_random_paths(num_paths, width, height)

        sample_size = min(4, candidates)
        paths = []
        for _ in range(num_paths):
            idx = np.random.choice(candidates, size=sample_size, replace=False)
            # Plain ints keep NumPy scalar types out of the path data.
            path_points = [(int(x_coords[j]), int(y_coords[j])) for j in idx]
            path_points.sort(key=lambda p: p[0])
            paths.append(path_points)

        return paths

    def create_random_paths(self, num_paths, width, height):
        """Create random paths as fallback."""
        return [self.create_single_random_path(width, height) for _ in range(num_paths)]

    def create_single_random_path(self, width, height):
        """Create a single random path.

        Returns:
            3 to 6 ``(x, y)`` integer points with ``0 <= x <= width`` and
            ``0 <= y <= height``, drawn from the ``random`` module.
        """
        num_points = random.randint(3, 6)
        # x is drawn before y for each point (tuple items evaluate in order).
        return [
            (random.randint(0, width), random.randint(0, height))
            for _ in range(num_points)
        ]

    def optimize_paths(self, paths, text_embeddings, num_iter, guidance_scale):
        """Simplified path optimization.

        Replaces every interior point by the mean of itself and its two
        neighbours in the *original* path; end points are kept unchanged.
        ``text_embeddings``, ``num_iter`` and ``guidance_scale`` are accepted
        for interface compatibility but not used.

        Returns:
            A new list of paths; paths with fewer than two points are passed
            through as-is.
        """
        optimized_paths = []

        for path in paths:
            if len(path) < 2:
                optimized_paths.append(path)
                continue

            last = len(path) - 1
            smoothed_path = []
            for i, point in enumerate(path):
                if i == 0 or i == last:
                    smoothed_path.append(point)
                    continue

                prev_x, prev_y = path[i - 1]
                curr_x, curr_y = point
                next_x, next_y = path[i + 1]
                smoothed_path.append((
                    (prev_x + curr_x + next_x) / 3,
                    (prev_y + curr_y + next_y) / 3,
                ))

            optimized_paths.append(smoothed_path)

        return optimized_paths

    def create_svg_from_paths(self, paths, width, height):
        """Create SVG content from optimized paths.

        Draws each path with at least two points as an unfilled polyline on a
        white background, with a random stroke width in [0.5, 3.0] and a
        random dark colour (each channel in 0..100).

        Returns:
            The SVG document as a string.
        """
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        for path in paths:
            if len(path) < 2:
                continue

            path_str = self._polyline_path_data(path)

            stroke_width = random.uniform(0.5, 3.0)
            stroke_color = f"rgb({random.randint(0, 100)},{random.randint(0, 100)},{random.randint(0, 100)})"

            dwg.add(dwg.path(
                d=path_str,
                stroke=stroke_color,
                stroke_width=stroke_width,
                fill='none',
                stroke_linecap='round',
                stroke_linejoin='round',
            ))

        return dwg.tostring()

    def create_fallback_svg(self, prompt, width=224, height=224):
        """Create a simple fallback SVG.

        Picks a canned drawing from substrings of the lower-cased prompt
        (mountain/landscape, house/building, flower/plant) and otherwise
        draws prompt-dependent abstract strokes.

        Returns:
            The SVG document as a string.
        """
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        # str() keeps this last-resort path alive for non-string prompts.
        prompt_lower = str(prompt).lower()

        if any(word in prompt_lower for word in ['mountain', 'landscape']):
            self._add_mountain_sketch(dwg, width, height)
        elif any(word in prompt_lower for word in ['house', 'building']):
            self._add_house_sketch(dwg, width, height)
        elif any(word in prompt_lower for word in ['flower', 'plant']):
            self._add_flower_sketch(dwg, width, height)
        else:
            self._add_abstract_sketch(dwg, width, height, prompt)

        return dwg.tostring()

    def _add_mountain_sketch(self, dwg, width, height):
        """Add mountain sketch to SVG: a wavy ridge closed along the bottom edge."""
        points = [(0, height * 0.7)]
        for x in range(0, width, 20):
            y = height * 0.7 + 30 * math.sin(x * 0.02) + 15 * math.sin(x * 0.05)
            points.append((x, y))
        # Close the outline through the two bottom corners.
        points.append((width, height))
        points.append((0, height))

        dwg.add(dwg.polygon(points, fill='lightgray', stroke='black', stroke_width=2))

    def _add_house_sketch(self, dwg, width, height):
        """Add house sketch to SVG: a centred rectangle with a triangular roof."""
        house_width = width * 0.6
        house_height = height * 0.4
        house_x = (width - house_width) / 2
        house_y = height * 0.4

        dwg.add(dwg.rect(
            insert=(house_x, house_y),
            size=(house_width, house_height),
            fill='lightblue',
            stroke='black',
            stroke_width=2,
        ))

        roof_points = [
            (house_x, house_y),
            (house_x + house_width / 2, house_y - house_height * 0.3),
            (house_x + house_width, house_y),
        ]
        dwg.add(dwg.polygon(roof_points, fill='red', stroke='black', stroke_width=2))

    def _add_flower_sketch(self, dwg, width, height):
        """Add flower sketch to SVG: a stem, eight petals and a centre disc."""
        center_x, center_y = width / 2, height / 2

        dwg.add(dwg.line(
            start=(center_x, center_y + 20),
            end=(center_x, height - 20),
            stroke='green',
            stroke_width=4,
        ))

        # Petals every 45 degrees on a circle of radius 25 around the centre.
        for angle in range(0, 360, 45):
            x = center_x + 25 * math.cos(math.radians(angle))
            y = center_y + 25 * math.sin(math.radians(angle))
            dwg.add(dwg.circle(
                center=(x, y),
                r=8,
                fill='pink',
                stroke='red',
                stroke_width=1,
            ))

        dwg.add(dwg.circle(
            center=(center_x, center_y),
            r=8,
            fill='yellow',
            stroke='orange',
            stroke_width=2,
        ))

    def _add_abstract_sketch(self, dwg, width, height, prompt):
        """Add abstract sketch to SVG: eight 4-point strokes derived from the prompt."""
        prompt_hash = self._prompt_seed(prompt)

        # Keep the modulus positive so small canvases cannot divide by zero.
        span_x = max(1, width - 40)
        span_y = max(1, height - 40)

        for i in range(8):
            start_x = (i * 30 + prompt_hash) % span_x + 20
            start_y = (i * 25 + prompt_hash) % span_y + 20

            points = []
            for j in range(4):
                x = start_x + j * 25 + 15 * math.sin((i + j + prompt_hash) * 0.5)
                y = start_y + j * 20 + 15 * math.cos((i + j + prompt_hash) * 0.3)
                # Clamp each point to the canvas.
                points.append((max(0, min(width, x)), max(0, min(height, y))))

            color_val = (i * 30) % 200 + 50
            dwg.add(dwg.path(
                d=self._polyline_path_data(points),
                stroke=f"rgb({color_val},{color_val//2},{color_val//3})",
                stroke_width=2,
                fill='none',
                stroke_linecap='round',
            ))