diffsketcher / handler.py
jree423's picture
Upload handler.py with huggingface_hub
853fd42 verified
raw
history blame
6.7 kB
import base64
import json
import math
from typing import Dict, Any
class EndpointHandler:
    """Request handler that renders a DiffSketcher-style SVG for a text prompt.

    No model is loaded or run. The SVG is assembled from fixed shape templates
    (cat, flower, house, or seeded pseudo-random outlines), chosen by
    keyword-matching the prompt. A gray caption containing the prompt is added.
    """

    # Subject keyword groups, checked in order. The first group with a
    # whole-word hit (optionally followed by "s", "es" or "ing") picks the
    # drawing routine. Word boundaries stop e.g. "petals" or "carpet" from
    # selecting the cat.
    _SUBJECT_KEYWORDS = (
        ("cat", ("cat", "animal", "pet")),
        ("flower", ("flower", "plant", "bloom")),
        ("house", ("house", "building", "home")),
    )

    def __init__(self, path=""):
        """Set up the handler.

        Args:
            path: Accepted for interface compatibility (presumably the model
                directory passed by the serving runtime); not used here.
        """
        print("DiffSketcher handler initialized")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate an SVG for one request and return it with metadata.

        Args:
            data: A dict with "inputs" (the prompt) and optional "parameters"
                ("num_paths", "width", "height"). Any non-dict value is
                stringified and used as the prompt with default parameters.

        Returns:
            On success, a dict with "svg_content", "svg_base64", "model",
            "prompt" and the effective "parameters". On failure, a dict with
            a single "error" message. Exceptions are never propagated.
        """
        try:
            if isinstance(data, dict):
                prompt = data.get("inputs", "")
                # An explicit "parameters": null behaves like an omitted key.
                parameters = data.get("parameters") or {}
            else:
                prompt = str(data)
                parameters = {}

            if not prompt:
                return {"error": "No prompt provided"}

            # Values may arrive as strings (e.g. "512"). Normalise them to int
            # so the range()/arithmetic below works and the echoed
            # parameters are clean.
            try:
                num_paths = int(parameters.get("num_paths", 16))
                width = int(parameters.get("width", 512))
                height = int(parameters.get("height", 512))
            except (TypeError, ValueError):
                return {
                    "error": "Invalid parameters: num_paths, width and height must be integers"
                }

            svg_content = self.generate_diffsketcher_svg(prompt, num_paths, width, height)
            svg_base64 = base64.b64encode(svg_content.encode("utf-8")).decode("utf-8")

            return {
                "svg_content": svg_content,
                "svg_base64": svg_base64,
                "model": "DiffSketcher",
                "prompt": prompt,
                "parameters": {
                    "num_paths": num_paths,
                    "width": width,
                    "height": height,
                },
            }
        except Exception as e:
            # Request boundary: report any failure as a payload instead of raising.
            return {"error": f"Generation failed: {e}"}

    def generate_diffsketcher_svg(self, prompt, num_paths, width, height):
        """Build the complete SVG document for *prompt*.

        Args:
            prompt: Free-text prompt. It selects the subject and is rendered
                (escaped) as a caption near the bottom-left corner.
            num_paths: Shape count for the abstract fallback (capped at 12).
            width: Canvas width in pixels.
            height: Canvas height in pixels.

        Returns:
            The SVG markup as a single string.
        """
        svg_parts = [
            f'<svg baseProfile="full" height="{height}px" version="1.1" width="{width}px" xmlns="http://www.w3.org/2000/svg">',
            '<rect fill="white" height="100%" width="100%" x="0" y="0" />',
        ]

        # Subjects are drawn around the canvas centre at fixed pixel sizes.
        center_x, center_y = width // 2, height // 2
        subject = self._match_subject(prompt)
        if subject == "cat":
            svg_parts.extend(self._draw_cat_sketch(center_x, center_y))
        elif subject == "flower":
            svg_parts.extend(self._draw_flower_sketch(center_x, center_y))
        elif subject == "house":
            svg_parts.extend(self._draw_house_sketch(center_x, center_y))
        else:
            svg_parts.extend(self._draw_abstract_sketch(center_x, center_y, num_paths))

        # The prompt is caller-controlled text. Escape it so "<", "&", etc.
        # cannot break the document or inject markup.
        svg_parts.append(
            f'<text fill="gray" font-size="12px" x="10" y="{height - 10}">'
            f"DiffSketcher: {self._escape_svg_text(prompt)}</text>"
        )
        svg_parts.append("</svg>")
        return "".join(svg_parts)

    @classmethod
    def _match_subject(cls, prompt):
        """Return the first subject whose keyword occurs as a whole word in
        *prompt* (case-insensitive), or None if no keyword matches."""
        import re

        text = prompt.lower()
        for subject, keywords in cls._SUBJECT_KEYWORDS:
            pattern = r"\b(?:" + "|".join(keywords) + r")(?:s|es|ing)?\b"
            if re.search(pattern, text):
                return subject
        return None

    @staticmethod
    def _escape_svg_text(text):
        """Return *text* made safe to embed as XML character data."""
        from xml.sax.saxutils import escape

        # XML 1.0 forbids C0 control characters other than tab/LF/CR, even
        # when escaped, so drop them rather than emit an unparsable document.
        cleaned = "".join(ch for ch in text if ch in "\t\n\r" or ord(ch) >= 0x20)
        return escape(cleaned)

    def _draw_cat_sketch(self, cx, cy):
        """Return SVG elements for a black-outline cat centred on (cx, cy):
        head, two ear triangles, eyes, pink nose, two whiskers and a body."""
        return [
            f'<circle cx="{cx}" cy="{cy-20}" r="60" fill="none" stroke="black" stroke-width="3" />',
            f'<polygon points="{cx-40},{cy-60} {cx-20},{cy-80} {cx-10},{cy-50}" fill="none" stroke="black" stroke-width="2" />',
            f'<polygon points="{cx+40},{cy-60} {cx+20},{cy-80} {cx+10},{cy-50}" fill="none" stroke="black" stroke-width="2" />',
            f'<circle cx="{cx-20}" cy="{cy-10}" r="8" fill="black" />',
            f'<circle cx="{cx+20}" cy="{cy-10}" r="8" fill="black" />',
            f'<polygon points="{cx-5},{cy+10} {cx+5},{cy+10} {cx},{cy+20}" fill="pink" />',
            f'<line x1="{cx-50}" y1="{cy}" x2="{cx-70}" y2="{cy-5}" stroke="black" stroke-width="1" />',
            f'<line x1="{cx+50}" y1="{cy}" x2="{cx+70}" y2="{cy-5}" stroke="black" stroke-width="1" />',
            f'<ellipse cx="{cx}" cy="{cy+80}" rx="40" ry="60" fill="none" stroke="black" stroke-width="3" />',
        ]

    def _draw_flower_sketch(self, cx, cy):
        """Return SVG elements for a flower centred on (cx, cy): eight petals
        every 45 degrees on a 50 px radius, a yellow centre, a stem and two
        leaves."""
        petals = []
        for i in range(8):
            angle = i * 45
            petal_x = cx + 50 * math.cos(math.radians(angle))
            petal_y = cy + 50 * math.sin(math.radians(angle))
            # Rotate each petal about its own centre so it points outward.
            petals.append(f'<ellipse cx="{petal_x}" cy="{petal_y}" rx="20" ry="35" fill="pink" stroke="red" stroke-width="2" transform="rotate({angle} {petal_x} {petal_y})" />')
        return petals + [
            f'<circle cx="{cx}" cy="{cy}" r="15" fill="yellow" stroke="orange" stroke-width="2" />',
            f'<line x1="{cx}" y1="{cy+15}" x2="{cx}" y2="{cy+120}" stroke="green" stroke-width="4" />',
            f'<ellipse cx="{cx-20}" cy="{cy+80}" rx="15" ry="25" fill="lightgreen" stroke="green" stroke-width="2" />',
            f'<ellipse cx="{cx+20}" cy="{cy+90}" rx="15" ry="25" fill="lightgreen" stroke="green" stroke-width="2" />',
        ]

    def _draw_house_sketch(self, cx, cy):
        """Return SVG elements for a house whose wall's top edge sits at y=cy:
        wall, triangular roof, door and two windows."""
        return [
            f'<rect x="{cx-50}" y="{cy}" width="100" height="60" fill="lightblue" stroke="blue" stroke-width="3" />',
            f'<polygon points="{cx-60},{cy} {cx},{cy-50} {cx+60},{cy}" fill="red" stroke="darkred" stroke-width="2" />',
            f'<rect x="{cx-15}" y="{cy+20}" width="30" height="40" fill="brown" />',
            f'<rect x="{cx-40}" y="{cy+15}" width="20" height="20" fill="lightblue" stroke="blue" stroke-width="2" />',
            f'<rect x="{cx+20}" y="{cy+15}" width="20" height="20" fill="lightblue" stroke="blue" stroke-width="2" />',
        ]

    def _draw_abstract_sketch(self, cx, cy, num_paths):
        """Return up to min(num_paths, 12) outline shapes (cycling circle,
        square, triangle) at pseudo-random offsets within +/-150 px of
        (cx, cy).

        A private RNG seeded with 42 keeps the output identical across calls
        without reseeding the process-wide ``random`` module.
        """
        import random

        rng = random.Random(42)
        shapes = []
        colors = ["red", "blue", "green", "orange", "purple", "pink", "yellow"]
        for i in range(min(num_paths, 12)):
            x = cx + rng.randint(-150, 150)
            y = cy + rng.randint(-150, 150)
            r = rng.randint(20, 60)
            color = rng.choice(colors)
            if i % 3 == 0:
                shapes.append(f'<circle cx="{x}" cy="{y}" r="{r}" fill="none" stroke="{color}" stroke-width="3" />')
            elif i % 3 == 1:
                shapes.append(f'<rect x="{x-r//2}" y="{y-r//2}" width="{r}" height="{r}" fill="none" stroke="{color}" stroke-width="2" />')
            else:
                points = f"{x},{y-r} {x+r},{y+r} {x-r},{y+r}"
                shapes.append(f'<polygon points="{points}" fill="none" stroke="{color}" stroke-width="2" />')
        return shapes