import base64
import json
import math
from typing import Dict, Any
class EndpointHandler:
    """Request handler that turns a text prompt into a DiffSketcher-style SVG.

    No model is loaded: the SVG is assembled from the hand-written
    ``_draw_*_sketch`` template helpers below.

    NOTE(review): in this copy of the file the SVG markup appears to have been
    stripped -- every SVG f-string in the ``_draw_*_sketch`` helpers is empty,
    and the body of ``generate_diffsketcher_svg`` was garbled (a list opened
    with ``[`` and closed with ``)``), so the module could not be imported.
    The helpers are kept with their empty literals; restore the markup from
    version control.
    """

    def __init__(self, path=""):
        """Set up the handler.

        ``path`` is accepted for callers that pass one but is not used;
        nothing is loaded here.
        """
        print("DiffSketcher handler initialized")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request and return the generated SVG.

        Args:
            data: Either a dict ``{"inputs": <prompt>, "parameters": {...}}``
                or any other value, which is converted with ``str()`` and used
                as the prompt. Recognised parameters are ``num_paths``
                (default 16, must be >= 0) and ``width`` / ``height``
                (default 512, must be >= 1); each is coerced to ``int``.

        Returns:
            On success, a dict with ``svg_content``, ``svg_base64`` (the
            UTF-8 bytes of the SVG, base64-encoded), ``model``, ``prompt`` and
            the effective ``parameters``. On any failure, ``{"error": msg}``;
            exceptions are never propagated to the caller.
        """
        try:
            # Extract the prompt and the optional parameters object.
            if isinstance(data, dict):
                prompt = data.get("inputs", "")
                # ``or {}`` also covers an explicit ``"parameters": null``,
                # which would otherwise fail on ``None.get``.
                parameters = data.get("parameters") or {}
            else:
                prompt = str(data)
                parameters = {}
            if not prompt:
                return {"error": "No prompt provided"}
            if not isinstance(parameters, dict):
                return {"error": "Invalid 'parameters': expected an object"}

            # JSON clients often send numbers as strings; normalise and bound
            # them up front so a bad value yields a clear message instead of a
            # failure deep inside SVG generation.
            try:
                num_paths = self._as_int(
                    parameters.get("num_paths", 16), "num_paths", minimum=0
                )
                width = self._as_int(parameters.get("width", 512), "width", minimum=1)
                height = self._as_int(
                    parameters.get("height", 512), "height", minimum=1
                )
            except ValueError as err:
                return {"error": str(err)}

            svg_content = self.generate_diffsketcher_svg(
                prompt, num_paths, width, height
            )
            svg_base64 = base64.b64encode(svg_content.encode('utf-8')).decode('utf-8')
            return {
                "svg_content": svg_content,
                "svg_base64": svg_base64,
                "model": "DiffSketcher",
                "prompt": prompt,
                "parameters": {
                    "num_paths": num_paths,
                    "width": width,
                    "height": height,
                },
            }
        except Exception as e:
            # Top-level request boundary: report any failure to the client
            # rather than raising out of the handler.
            return {"error": f"Generation failed: {e}"}

    @staticmethod
    def _as_int(value, name, minimum):
        """Coerce request parameter ``name`` to ``int`` and enforce ``>= minimum``.

        Raises:
            ValueError: if ``value`` is not integer-like or is below ``minimum``.
        """
        try:
            number = int(value)
        except (TypeError, ValueError, OverflowError):
            number = None
        if number is None or number < minimum:
            raise ValueError(
                f"Invalid parameter '{name}': expected an integer >= {minimum}, "
                f"got {value!r}"
            )
        return number

    def generate_diffsketcher_svg(self, prompt, num_paths, width, height):
        """Return an SVG document string in DiffSketcher style (painterly, sketchy).

        Opens an ``<svg>`` element of ``width`` x ``height``, appends the
        elements of one ``_draw_*_sketch`` helper chosen by keyword in the
        prompt (the abstract sketch -- the only helper taking ``num_paths`` --
        is the fallback), anchored at the canvas centre, and closes the element.

        NOTE(review): the original body of this method is garbled in this copy
        of the file (its list literal was closed with ``)`` and its f-string was
        empty). This body is a reconstruction from the visible skeleton
        (``svg_parts`` list joined at the end) and the helpers' names and
        signatures -- confirm the root attributes and the keyword dispatch
        against version control.
        """
        cx, cy = width // 2, height // 2
        svg_parts = [
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" '
            f'height="{height}" viewBox="0 0 {width} {height}">'
        ]
        lowered = str(prompt).lower()
        if "cat" in lowered:
            svg_parts.extend(self._draw_cat_sketch(cx, cy))
        elif "flower" in lowered:
            svg_parts.extend(self._draw_flower_sketch(cx, cy))
        elif "house" in lowered:
            svg_parts.extend(self._draw_house_sketch(cx, cy))
        else:
            svg_parts.extend(self._draw_abstract_sketch(cx, cy, num_paths))
        svg_parts.append('</svg>')
        return ''.join(svg_parts)

    def _draw_cat_sketch(self, cx, cy):
        """Return the nine SVG element strings of a sketchy cat.

        NOTE(review): all nine literals are empty in this copy and ``cx`` /
        ``cy`` are unused by the visible code; the markup that presumably
        interpolated them appears to have been stripped.
        """
        return [
            f'',
            f'',
            f'',
            f'',
            f'',
            f'',
            f'',
            f'',
            f'',
        ]

    def _draw_flower_sketch(self, cx, cy):
        """Return the SVG element strings of a sketchy flower.

        Eight petal entries are produced for angles 0, 45, ..., 315 degrees,
        each with a position 50 units from (cx, cy), followed by four more
        entries.

        NOTE(review): every literal is empty in this copy, so ``petal_x`` /
        ``petal_y`` are computed but unused by the visible code -- presumably
        the stripped petal markup interpolated them.
        """
        petals = []
        for i in range(8):
            angle = i * 45
            petal_x = cx + 50 * math.cos(math.radians(angle))
            petal_y = cy + 50 * math.sin(math.radians(angle))
            petals.append(f'')
        return petals + [
            f'',
            f'',
            f'',
            f'',
        ]

    def _draw_house_sketch(self, cx, cy):
        """Return the five SVG element strings of a sketchy house.

        NOTE(review): all five literals are empty in this copy and ``cx`` /
        ``cy`` are unused by the visible code; the markup appears stripped.
        """
        return [
            f'',
            f'',
            f'',
            f'',
            f'',
        ]

    def _draw_abstract_sketch(self, cx, cy, num_paths, seed=42):
        """Return pseudo-random abstract shape strings scattered around (cx, cy).

        Produces ``min(num_paths, 12)`` entries (none when ``num_paths <= 0``).
        Each entry draws an offset in [-150, 150] on both axes, a size ``r`` in
        [20, 60] and a colour; the shape kind cycles with ``i % 3``, and every
        third entry gets a three-vertex ``points`` string.

        A private ``random.Random(seed)`` is used, so output is reproducible for
        a given ``seed`` without reseeding the process-wide ``random`` generator
        that other code may share. The default seed gives the same draw
        sequence as ``random.seed(42)`` followed by module-level calls.

        NOTE(review): every shape literal is empty in this copy, so the drawn
        values are unused by the visible code; the markup appears stripped.
        """
        import random

        rng = random.Random(seed)
        shapes = []
        colors = ["red", "blue", "green", "orange", "purple", "pink", "yellow"]
        for i in range(min(num_paths, 12)):
            x = cx + rng.randint(-150, 150)
            y = cy + rng.randint(-150, 150)
            r = rng.randint(20, 60)
            color = rng.choice(colors)
            if i % 3 == 0:
                shapes.append(f'')
            elif i % 3 == 1:
                shapes.append(f'')
            else:
                points = f"{x},{y-r} {x+r},{y+r} {x-r},{y+r}"
                shapes.append(f'')
        return shapes