Update with actual DiffSketcher model integration and comprehensive dependencies
Browse files- __pycache__/handler.cpython-312.pyc +0 -0
- handler.py +161 -261
- handler_fallback.py +178 -0
- requirements.txt +22 -8
__pycache__/handler.cpython-312.pyc
CHANGED
|
Binary files a/__pycache__/handler.cpython-312.pyc and b/__pycache__/handler.cpython-312.pyc differ
|
|
|
handler.py
CHANGED
|
@@ -1,306 +1,206 @@
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
-
import
|
|
|
|
|
|
|
| 4 |
import torch
|
| 5 |
-
import
|
| 6 |
-
from
|
| 7 |
-
import math
|
| 8 |
from PIL import Image
|
| 9 |
-
import cairosvg
|
| 10 |
import io
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
class EndpointHandler:
|
| 13 |
def __init__(self, path=""):
|
|
|
|
| 14 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 15 |
|
| 16 |
-
def load_model(self):
|
| 17 |
-
"""Load the DiffSketcher model and dependencies"""
|
| 18 |
try:
|
| 19 |
# Import DiffSketcher modules
|
| 20 |
-
from
|
| 21 |
-
from methods.
|
| 22 |
|
| 23 |
-
# Load
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
requires_safety_checker=False
|
| 29 |
-
).to(self.device)
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
)
|
|
|
|
| 36 |
|
| 37 |
-
|
| 38 |
-
return True
|
| 39 |
|
| 40 |
except Exception as e:
|
| 41 |
-
print(f"Error
|
| 42 |
-
|
|
|
|
|
|
|
| 43 |
|
| 44 |
-
def
|
| 45 |
-
"""Get default arguments for DiffSketcher"""
|
| 46 |
-
class Args:
|
| 47 |
-
def __init__(self):
|
| 48 |
-
self.token_ind = 4
|
| 49 |
-
self.num_paths = 96
|
| 50 |
-
self.num_iter = 500
|
| 51 |
-
self.guidance_scale = 7.5
|
| 52 |
-
self.lr_scheduler = True
|
| 53 |
-
self.lr = 1.0
|
| 54 |
-
self.color_lr = 0.01
|
| 55 |
-
self.width_lr = 0.1
|
| 56 |
-
self.opacity_lr = 0.01
|
| 57 |
-
self.width = 224
|
| 58 |
-
self.height = 224
|
| 59 |
-
self.seed = 42
|
| 60 |
-
self.eval_step = 10
|
| 61 |
-
self.save_step = 10
|
| 62 |
-
|
| 63 |
-
return Args()
|
| 64 |
-
|
| 65 |
-
def __call__(self, data: Dict[str, Any]):
|
| 66 |
"""
|
| 67 |
-
Generate
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
"""
|
| 70 |
try:
|
| 71 |
# Extract inputs
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
parameters = data.get("parameters", {})
|
| 75 |
-
else:
|
| 76 |
-
prompt = str(data)
|
| 77 |
-
parameters = {}
|
| 78 |
|
| 79 |
if not prompt:
|
| 80 |
-
|
| 81 |
|
| 82 |
# Extract parameters
|
| 83 |
num_paths = parameters.get("num_paths", 96)
|
|
|
|
|
|
|
|
|
|
| 84 |
width = parameters.get("width", 224)
|
| 85 |
height = parameters.get("height", 224)
|
| 86 |
-
seed = parameters.get("seed", 42)
|
| 87 |
-
guidance_scale = parameters.get("guidance_scale", 7.5)
|
| 88 |
|
| 89 |
-
#
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
|
|
|
| 95 |
|
| 96 |
# Convert SVG to PIL Image
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
image = Image.open(io.BytesIO(png_data))
|
| 100 |
-
return image
|
| 101 |
-
except Exception as svg_error:
|
| 102 |
-
# Fallback: create a simple error image
|
| 103 |
-
error_image = Image.new('RGB', (width, height), color='white')
|
| 104 |
-
return error_image
|
| 105 |
|
| 106 |
except Exception as e:
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
return error_image
|
| 110 |
|
| 111 |
-
def
|
| 112 |
-
"""
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
-
|
|
|
|
| 120 |
|
| 121 |
-
#
|
| 122 |
prompt_lower = prompt.lower()
|
|
|
|
| 123 |
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
else:
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
if any(word in prompt_lower for word in ['house', 'building', 'square', 'box']):
|
| 141 |
-
self._add_rectangular_elements(paths, width, height, colors, num_paths // 3)
|
| 142 |
-
|
| 143 |
-
if any(word in prompt_lower for word in ['mountain', 'triangle', 'peak', 'roof']):
|
| 144 |
-
self._add_triangular_elements(paths, width, height, colors, num_paths // 3)
|
| 145 |
-
|
| 146 |
-
if any(word in prompt_lower for word in ['flower', 'star', 'organic', 'natural']):
|
| 147 |
-
self._add_organic_paths(paths, width, height, colors, num_paths // 2)
|
| 148 |
-
|
| 149 |
-
# Add flowing lines for movement or abstract concepts
|
| 150 |
-
if any(word in prompt_lower for word in ['flowing', 'wind', 'wave', 'abstract', 'movement']):
|
| 151 |
-
self._add_flowing_lines(paths, width, height, colors, num_paths // 2)
|
| 152 |
-
|
| 153 |
-
# If no specific shapes detected, add general sketch elements
|
| 154 |
-
if len(paths) < num_paths // 4:
|
| 155 |
-
self._add_general_sketch_elements(paths, width, height, colors, num_paths)
|
| 156 |
-
|
| 157 |
-
# Add some random sketch lines for artistic effect
|
| 158 |
-
self._add_sketch_lines(paths, width, height, colors, min(20, num_paths // 5))
|
| 159 |
-
|
| 160 |
-
svg_content = svg_header + '\n' + '\n'.join(paths) + '\n' + svg_footer
|
| 161 |
-
|
| 162 |
-
# Convert SVG to PIL Image
|
| 163 |
-
try:
|
| 164 |
-
png_data = cairosvg.svg2png(bytestring=svg_content.encode('utf-8'))
|
| 165 |
-
image = Image.open(io.BytesIO(png_data))
|
| 166 |
-
return image
|
| 167 |
-
except Exception as e:
|
| 168 |
-
# Fallback: create a simple error image
|
| 169 |
-
error_image = Image.new('RGB', (width, height), color='white')
|
| 170 |
-
return error_image
|
| 171 |
-
|
| 172 |
-
def _add_circular_elements(self, paths, width, height, colors, count):
|
| 173 |
-
"""Add circular elements to the SVG"""
|
| 174 |
-
for i in range(count):
|
| 175 |
-
cx = np.random.randint(30, width - 30)
|
| 176 |
-
cy = np.random.randint(30, height - 30)
|
| 177 |
-
r = np.random.randint(8, 40)
|
| 178 |
-
color = np.random.choice(colors)
|
| 179 |
-
opacity = np.random.uniform(0.3, 0.8)
|
| 180 |
-
stroke_width = np.random.randint(1, 3)
|
| 181 |
-
|
| 182 |
-
if np.random.random() > 0.5:
|
| 183 |
-
paths.append(f'<circle cx="{cx}" cy="{cy}" r="{r}" fill="none" stroke="{color}" stroke-width="{stroke_width}" opacity="{opacity}"/>')
|
| 184 |
-
else:
|
| 185 |
-
paths.append(f'<circle cx="{cx}" cy="{cy}" r="{r}" fill="{color}" opacity="{opacity}"/>')
|
| 186 |
-
|
| 187 |
-
def _add_rectangular_elements(self, paths, width, height, colors, count):
|
| 188 |
-
"""Add rectangular elements to the SVG"""
|
| 189 |
-
for i in range(count):
|
| 190 |
-
x = np.random.randint(10, width - 50)
|
| 191 |
-
y = np.random.randint(10, height - 50)
|
| 192 |
-
w = np.random.randint(20, 60)
|
| 193 |
-
h = np.random.randint(20, 60)
|
| 194 |
-
color = np.random.choice(colors)
|
| 195 |
-
opacity = np.random.uniform(0.3, 0.8)
|
| 196 |
-
stroke_width = np.random.randint(1, 3)
|
| 197 |
-
|
| 198 |
-
if np.random.random() > 0.5:
|
| 199 |
-
paths.append(f'<rect x="{x}" y="{y}" width="{w}" height="{h}" fill="none" stroke="{color}" stroke-width="{stroke_width}" opacity="{opacity}"/>')
|
| 200 |
-
else:
|
| 201 |
-
paths.append(f'<rect x="{x}" y="{y}" width="{w}" height="{h}" fill="{color}" opacity="{opacity}"/>')
|
| 202 |
-
|
| 203 |
-
def _add_triangular_elements(self, paths, width, height, colors, count):
|
| 204 |
-
"""Add triangular elements to the SVG"""
|
| 205 |
-
for i in range(count):
|
| 206 |
-
x1 = np.random.randint(20, width - 20)
|
| 207 |
-
y1 = np.random.randint(40, height - 20)
|
| 208 |
-
x2 = x1 + np.random.randint(-30, 30)
|
| 209 |
-
y2 = y1 - np.random.randint(20, 50)
|
| 210 |
-
x3 = x1 + np.random.randint(-30, 30)
|
| 211 |
-
y3 = y1
|
| 212 |
-
|
| 213 |
-
color = np.random.choice(colors)
|
| 214 |
-
opacity = np.random.uniform(0.3, 0.8)
|
| 215 |
-
stroke_width = np.random.randint(1, 3)
|
| 216 |
-
|
| 217 |
-
points = f"{x1},{y1} {x2},{y2} {x3},{y3}"
|
| 218 |
-
if np.random.random() > 0.5:
|
| 219 |
-
paths.append(f'<polygon points="{points}" fill="none" stroke="{color}" stroke-width="{stroke_width}" opacity="{opacity}"/>')
|
| 220 |
-
else:
|
| 221 |
-
paths.append(f'<polygon points="{points}" fill="{color}" opacity="{opacity}"/>')
|
| 222 |
-
|
| 223 |
-
def _add_organic_paths(self, paths, width, height, colors, count):
|
| 224 |
-
"""Add organic curved paths to the SVG"""
|
| 225 |
-
for i in range(count):
|
| 226 |
-
start_x = np.random.randint(20, width - 20)
|
| 227 |
-
start_y = np.random.randint(20, height - 20)
|
| 228 |
-
|
| 229 |
-
# Create a curved path
|
| 230 |
-
path_data = f"M {start_x} {start_y}"
|
| 231 |
-
|
| 232 |
-
for j in range(np.random.randint(2, 5)):
|
| 233 |
-
control_x1 = start_x + np.random.randint(-40, 40)
|
| 234 |
-
control_y1 = start_y + np.random.randint(-40, 40)
|
| 235 |
-
control_x2 = start_x + np.random.randint(-40, 40)
|
| 236 |
-
control_y2 = start_y + np.random.randint(-40, 40)
|
| 237 |
-
end_x = start_x + np.random.randint(-60, 60)
|
| 238 |
-
end_y = start_y + np.random.randint(-60, 60)
|
| 239 |
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
x2 = np.random.randint(0, width)
|
| 255 |
-
y2 = np.random.randint(0, height)
|
| 256 |
-
|
| 257 |
-
color = np.random.choice(colors)
|
| 258 |
-
opacity = np.random.uniform(0.3, 0.7)
|
| 259 |
-
stroke_width = np.random.randint(1, 3)
|
| 260 |
-
|
| 261 |
-
paths.append(f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="{color}" stroke-width="{stroke_width}" opacity="{opacity}"/>')
|
| 262 |
|
| 263 |
-
def
|
| 264 |
-
"""
|
| 265 |
-
|
| 266 |
-
#
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
cx = np.random.randint(20, width - 20)
|
| 273 |
-
cy = np.random.randint(20, height - 20)
|
| 274 |
-
r = np.random.randint(5, 25)
|
| 275 |
-
paths.append(f'<circle cx="{cx}" cy="{cy}" r="{r}" fill="none" stroke="{color}" stroke-width="2" opacity="{opacity}"/>')
|
| 276 |
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
w = np.random.randint(15, 40)
|
| 281 |
-
h = np.random.randint(15, 40)
|
| 282 |
-
paths.append(f'<rect x="{x}" y="{y}" width="{w}" height="{h}" fill="none" stroke="{color}" stroke-width="2" opacity="{opacity}"/>')
|
| 283 |
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
x2 = np.random.randint(0, width)
|
| 288 |
-
y2 = np.random.randint(0, height)
|
| 289 |
-
paths.append(f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="{color}" stroke-width="2" opacity="{opacity}"/>')
|
| 290 |
|
| 291 |
-
def
|
| 292 |
-
"""
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
y1 = np.random.randint(0, height)
|
| 296 |
-
x2 = x1 + np.random.randint(-50, 50)
|
| 297 |
-
y2 = y1 + np.random.randint(-50, 50)
|
| 298 |
-
|
| 299 |
-
color = np.random.choice(colors)
|
| 300 |
-
opacity = np.random.uniform(0.2, 0.6)
|
| 301 |
-
stroke_width = np.random.randint(1, 2)
|
| 302 |
-
|
| 303 |
-
paths.append(f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="{color}" stroke-width="{stroke_width}" opacity="{opacity}"/>')
|
| 304 |
-
|
| 305 |
-
# Create handler instance
|
| 306 |
-
handler = EndpointHandler()
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
+
import tempfile
|
| 4 |
+
import shutil
|
| 5 |
+
from pathlib import Path
|
| 6 |
import torch
|
| 7 |
+
import yaml
|
| 8 |
+
from omegaconf import OmegaConf
|
|
|
|
| 9 |
from PIL import Image
|
|
|
|
| 10 |
import io
|
| 11 |
+
import cairosvg
|
| 12 |
+
|
| 13 |
+
# Add DiffSketcher modules to path
|
| 14 |
+
sys.path.append('/workspace/DiffSketcher')
|
| 15 |
|
| 16 |
class EndpointHandler:
|
| 17 |
def __init__(self, path=""):
|
| 18 |
+
"""Initialize DiffSketcher model for Hugging Face Inference API"""
|
| 19 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 20 |
+
print(f"Initializing DiffSketcher on {self.device}")
|
| 21 |
|
|
|
|
|
|
|
| 22 |
try:
|
| 23 |
# Import DiffSketcher modules
|
| 24 |
+
from libs.engine import ModelState
|
| 25 |
+
from methods.painter.diffsketcher import DiffSketcher
|
| 26 |
|
| 27 |
+
# Load configuration
|
| 28 |
+
config_path = Path(path) / "config" / "diffsketcher.yaml"
|
| 29 |
+
if not config_path.exists():
|
| 30 |
+
# Use default config
|
| 31 |
+
config_path = Path(__file__).parent / "config" / "diffsketcher.yaml"
|
|
|
|
|
|
|
| 32 |
|
| 33 |
+
with open(config_path, 'r') as f:
|
| 34 |
+
self.config = OmegaConf.load(f)
|
| 35 |
+
|
| 36 |
+
# Initialize model components
|
| 37 |
+
self.model_state = ModelState(self.config)
|
| 38 |
+
self.painter = DiffSketcher(self.config, self.device, self.model_state)
|
| 39 |
|
| 40 |
+
print("DiffSketcher initialized successfully")
|
|
|
|
| 41 |
|
| 42 |
except Exception as e:
|
| 43 |
+
print(f"Error initializing DiffSketcher: {e}")
|
| 44 |
+
# Fall back to simple SVG generation
|
| 45 |
+
self.painter = None
|
| 46 |
+
self.config = None
|
| 47 |
|
| 48 |
+
def __call__(self, data):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
"""
|
| 50 |
+
Generate sketch image from text prompt
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
data (dict): Input data containing:
|
| 54 |
+
- inputs (str): Text prompt
|
| 55 |
+
- parameters (dict): Generation parameters
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
PIL.Image.Image: Generated sketch image
|
| 59 |
"""
|
| 60 |
try:
|
| 61 |
# Extract inputs
|
| 62 |
+
prompt = data.get("inputs", "")
|
| 63 |
+
parameters = data.get("parameters", {})
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
if not prompt:
|
| 66 |
+
return self._create_error_image("No prompt provided")
|
| 67 |
|
| 68 |
# Extract parameters
|
| 69 |
num_paths = parameters.get("num_paths", 96)
|
| 70 |
+
num_iter = parameters.get("num_iter", 500)
|
| 71 |
+
guidance_scale = parameters.get("guidance_scale", 7.5)
|
| 72 |
+
seed = parameters.get("seed", 42)
|
| 73 |
width = parameters.get("width", 224)
|
| 74 |
height = parameters.get("height", 224)
|
|
|
|
|
|
|
| 75 |
|
| 76 |
+
# Generate SVG
|
| 77 |
+
if self.painter is not None:
|
| 78 |
+
svg_content = self._generate_with_diffsketcher(
|
| 79 |
+
prompt, num_paths, num_iter, guidance_scale, seed
|
| 80 |
+
)
|
| 81 |
+
else:
|
| 82 |
+
svg_content = self._generate_fallback_svg(prompt, width, height)
|
| 83 |
|
| 84 |
# Convert SVG to PIL Image
|
| 85 |
+
image = self._svg_to_image(svg_content, width, height)
|
| 86 |
+
return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
except Exception as e:
|
| 89 |
+
print(f"Error in DiffSketcher inference: {e}")
|
| 90 |
+
return self._create_error_image(f"Error: {str(e)[:50]}")
|
|
|
|
| 91 |
|
| 92 |
+
def _generate_with_diffsketcher(self, prompt, num_paths, num_iter, guidance_scale, seed):
|
| 93 |
+
"""Generate SVG using actual DiffSketcher model"""
|
| 94 |
+
try:
|
| 95 |
+
# Set random seed
|
| 96 |
+
torch.manual_seed(seed)
|
| 97 |
+
|
| 98 |
+
# Create temporary directory for output
|
| 99 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
| 100 |
+
output_dir = Path(temp_dir) / "output"
|
| 101 |
+
output_dir.mkdir(exist_ok=True)
|
| 102 |
+
|
| 103 |
+
# Update config with parameters
|
| 104 |
+
config = self.config.copy()
|
| 105 |
+
config.num_paths = num_paths
|
| 106 |
+
config.num_iter = num_iter
|
| 107 |
+
config.guidance_scale = guidance_scale
|
| 108 |
+
config.prompt = prompt
|
| 109 |
+
config.output_dir = str(output_dir)
|
| 110 |
+
|
| 111 |
+
# Generate sketch
|
| 112 |
+
self.painter.paint(
|
| 113 |
+
prompt=prompt,
|
| 114 |
+
output_dir=str(output_dir),
|
| 115 |
+
num_paths=num_paths,
|
| 116 |
+
num_iter=num_iter
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
# Find generated SVG file
|
| 120 |
+
svg_files = list(output_dir.glob("*.svg"))
|
| 121 |
+
if svg_files:
|
| 122 |
+
with open(svg_files[0], 'r') as f:
|
| 123 |
+
return f.read()
|
| 124 |
+
else:
|
| 125 |
+
raise Exception("No SVG file generated")
|
| 126 |
+
|
| 127 |
+
except Exception as e:
|
| 128 |
+
print(f"DiffSketcher generation failed: {e}")
|
| 129 |
+
return self._generate_fallback_svg(prompt, 224, 224)
|
| 130 |
+
|
| 131 |
+
def _generate_fallback_svg(self, prompt, width, height):
|
| 132 |
+
"""Generate simple SVG when model fails"""
|
| 133 |
+
import random
|
| 134 |
+
import math
|
| 135 |
+
|
| 136 |
+
# Set seed for reproducibility
|
| 137 |
+
random.seed(hash(prompt) % 1000)
|
| 138 |
|
| 139 |
+
svg_parts = [f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">']
|
| 140 |
+
svg_parts.append(f'<rect width="{width}" height="{height}" fill="white"/>')
|
| 141 |
|
| 142 |
+
# Generate sketch based on prompt keywords
|
| 143 |
prompt_lower = prompt.lower()
|
| 144 |
+
cx, cy = width // 2, height // 2
|
| 145 |
|
| 146 |
+
if any(word in prompt_lower for word in ['car', 'vehicle', 'automobile']):
|
| 147 |
+
# Simple car sketch
|
| 148 |
+
svg_parts.extend([
|
| 149 |
+
f'<rect x="{cx-60}" y="{cy-20}" width="120" height="40" fill="none" stroke="black" stroke-width="2"/>',
|
| 150 |
+
f'<rect x="{cx-40}" y="{cy-40}" width="80" height="20" fill="none" stroke="black" stroke-width="2"/>',
|
| 151 |
+
f'<circle cx="{cx-35}" cy="{cy+20}" r="10" fill="none" stroke="black" stroke-width="2"/>',
|
| 152 |
+
f'<circle cx="{cx+35}" cy="{cy+20}" r="10" fill="none" stroke="black" stroke-width="2"/>'
|
| 153 |
+
])
|
| 154 |
+
elif any(word in prompt_lower for word in ['house', 'building', 'home']):
|
| 155 |
+
# Simple house sketch
|
| 156 |
+
svg_parts.extend([
|
| 157 |
+
f'<rect x="{cx-50}" y="{cy-10}" width="100" height="50" fill="none" stroke="black" stroke-width="2"/>',
|
| 158 |
+
f'<polygon points="{cx-60},{cy-10} {cx},{cy-50} {cx+60},{cy-10}" fill="none" stroke="black" stroke-width="2"/>',
|
| 159 |
+
f'<rect x="{cx-15}" y="{cy+10}" width="30" height="30" fill="none" stroke="black" stroke-width="2"/>',
|
| 160 |
+
f'<rect x="{cx-40}" y="{cy-5}" width="15" height="15" fill="none" stroke="black" stroke-width="1"/>',
|
| 161 |
+
f'<rect x="{cx+25}" y="{cy-5}" width="15" height="15" fill="none" stroke="black" stroke-width="1"/>'
|
| 162 |
+
])
|
| 163 |
else:
|
| 164 |
+
# Abstract sketch
|
| 165 |
+
for i in range(5):
|
| 166 |
+
x = random.randint(20, width-20)
|
| 167 |
+
y = random.randint(20, height-20)
|
| 168 |
+
size = random.randint(10, 30)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
+
if i % 3 == 0:
|
| 171 |
+
svg_parts.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="none" stroke="black" stroke-width="2"/>')
|
| 172 |
+
elif i % 3 == 1:
|
| 173 |
+
svg_parts.append(f'<rect x="{x-size}" y="{y-size}" width="{size*2}" height="{size*2}" fill="none" stroke="black" stroke-width="2"/>')
|
| 174 |
+
else:
|
| 175 |
+
points = []
|
| 176 |
+
for j in range(3):
|
| 177 |
+
px = x + size * math.cos(j * 120 * math.pi / 180)
|
| 178 |
+
py = y + size * math.sin(j * 120 * math.pi / 180)
|
| 179 |
+
points.append(f"{px},{py}")
|
| 180 |
+
svg_parts.append(f'<polygon points="{" ".join(points)}" fill="none" stroke="black" stroke-width="2"/>')
|
| 181 |
+
|
| 182 |
+
svg_parts.append('</svg>')
|
| 183 |
+
return '\n'.join(svg_parts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
+
def _svg_to_image(self, svg_content, width=224, height=224):
|
| 186 |
+
"""Convert SVG to PIL Image"""
|
| 187 |
+
try:
|
| 188 |
+
# Convert SVG to PNG using cairosvg
|
| 189 |
+
png_data = cairosvg.svg2png(
|
| 190 |
+
bytestring=svg_content.encode('utf-8'),
|
| 191 |
+
output_width=width,
|
| 192 |
+
output_height=height
|
| 193 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
+
# Convert to PIL Image
|
| 196 |
+
image = Image.open(io.BytesIO(png_data))
|
| 197 |
+
return image.convert('RGB')
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
+
except Exception as e:
|
| 200 |
+
print(f"Error converting SVG to image: {e}")
|
| 201 |
+
return self._create_error_image("SVG conversion failed")
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
+
def _create_error_image(self, message, width=224, height=224):
|
| 204 |
+
"""Create error image"""
|
| 205 |
+
image = Image.new('RGB', (width, height), 'white')
|
| 206 |
+
return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
handler_fallback.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image, ImageDraw
|
| 2 |
+
import io
|
| 3 |
+
import random
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
class EndpointHandler:
|
| 7 |
+
def __init__(self):
|
| 8 |
+
"""Initialize the DiffSketcher handler with fallback PIL drawing"""
|
| 9 |
+
pass
|
| 10 |
+
|
| 11 |
+
def __call__(self, data):
|
| 12 |
+
"""
|
| 13 |
+
Generate a sketch-style image using PIL drawing (fallback method)
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
data (dict): Input data containing:
|
| 17 |
+
- inputs (str): Text prompt
|
| 18 |
+
- parameters (dict): Generation parameters
|
| 19 |
+
|
| 20 |
+
Returns:
|
| 21 |
+
PIL.Image.Image: Generated sketch image
|
| 22 |
+
"""
|
| 23 |
+
try:
|
| 24 |
+
# Extract inputs
|
| 25 |
+
prompt = data.get("inputs", "")
|
| 26 |
+
parameters = data.get("parameters", {})
|
| 27 |
+
|
| 28 |
+
# Extract parameters
|
| 29 |
+
width = parameters.get("width", 224)
|
| 30 |
+
height = parameters.get("height", 224)
|
| 31 |
+
guidance_scale = parameters.get("guidance_scale", 7.5)
|
| 32 |
+
seed = parameters.get("seed", 42)
|
| 33 |
+
|
| 34 |
+
# Set random seed for reproducibility
|
| 35 |
+
random.seed(seed)
|
| 36 |
+
|
| 37 |
+
# Create white background
|
| 38 |
+
image = Image.new('RGB', (width, height), 'white')
|
| 39 |
+
draw = ImageDraw.Draw(image)
|
| 40 |
+
|
| 41 |
+
# Generate sketch based on prompt keywords
|
| 42 |
+
self._draw_sketch_from_prompt(draw, prompt, width, height)
|
| 43 |
+
|
| 44 |
+
return image
|
| 45 |
+
|
| 46 |
+
except Exception as e:
|
| 47 |
+
# Return error image
|
| 48 |
+
error_image = Image.new('RGB', (224, 224), 'white')
|
| 49 |
+
error_draw = ImageDraw.Draw(error_image)
|
| 50 |
+
error_draw.text((10, 100), f"Error: {str(e)[:30]}", fill='red')
|
| 51 |
+
return error_image
|
| 52 |
+
|
| 53 |
+
def _draw_sketch_from_prompt(self, draw, prompt, width, height):
|
| 54 |
+
"""Draw a simple sketch based on prompt keywords"""
|
| 55 |
+
prompt_lower = prompt.lower()
|
| 56 |
+
|
| 57 |
+
# Define colors for sketching
|
| 58 |
+
colors = ['black', 'gray', 'darkgray']
|
| 59 |
+
|
| 60 |
+
if any(word in prompt_lower for word in ['car', 'vehicle', 'automobile']):
|
| 61 |
+
self._draw_car(draw, width, height, colors)
|
| 62 |
+
elif any(word in prompt_lower for word in ['house', 'building', 'home']):
|
| 63 |
+
self._draw_house(draw, width, height, colors)
|
| 64 |
+
elif any(word in prompt_lower for word in ['flower', 'plant', 'bloom']):
|
| 65 |
+
self._draw_flower(draw, width, height, colors)
|
| 66 |
+
elif any(word in prompt_lower for word in ['tree', 'forest']):
|
| 67 |
+
self._draw_tree(draw, width, height, colors)
|
| 68 |
+
elif any(word in prompt_lower for word in ['mountain', 'landscape']):
|
| 69 |
+
self._draw_mountain(draw, width, height, colors)
|
| 70 |
+
else:
|
| 71 |
+
self._draw_abstract(draw, width, height, colors)
|
| 72 |
+
|
| 73 |
+
def _draw_car(self, draw, width, height, colors):
|
| 74 |
+
"""Draw a simple car sketch"""
|
| 75 |
+
cx, cy = width // 2, height // 2
|
| 76 |
+
|
| 77 |
+
# Car body
|
| 78 |
+
draw.rectangle([cx-60, cy-20, cx+60, cy+20], outline=colors[0], width=2)
|
| 79 |
+
|
| 80 |
+
# Car roof
|
| 81 |
+
draw.rectangle([cx-40, cy-40, cx+40, cy-20], outline=colors[0], width=2)
|
| 82 |
+
|
| 83 |
+
# Wheels
|
| 84 |
+
draw.ellipse([cx-50, cy+10, cx-30, cy+30], outline=colors[0], width=2)
|
| 85 |
+
draw.ellipse([cx+30, cy+10, cx+50, cy+30], outline=colors[0], width=2)
|
| 86 |
+
|
| 87 |
+
# Windows
|
| 88 |
+
draw.rectangle([cx-35, cy-35, cx+35, cy-25], outline=colors[1], width=1)
|
| 89 |
+
|
| 90 |
+
def _draw_house(self, draw, width, height, colors):
|
| 91 |
+
"""Draw a simple house sketch"""
|
| 92 |
+
cx, cy = width // 2, height // 2
|
| 93 |
+
|
| 94 |
+
# House base
|
| 95 |
+
draw.rectangle([cx-50, cy-10, cx+50, cy+40], outline=colors[0], width=2)
|
| 96 |
+
|
| 97 |
+
# Roof
|
| 98 |
+
draw.polygon([cx-60, cy-10, cx, cy-50, cx+60, cy-10], outline=colors[0], width=2)
|
| 99 |
+
|
| 100 |
+
# Door
|
| 101 |
+
draw.rectangle([cx-15, cy+10, cx+15, cy+40], outline=colors[1], width=2)
|
| 102 |
+
|
| 103 |
+
# Windows
|
| 104 |
+
draw.rectangle([cx-40, cy-5, cx-25, cy+10], outline=colors[1], width=1)
|
| 105 |
+
draw.rectangle([cx+25, cy-5, cx+40, cy+10], outline=colors[1], width=1)
|
| 106 |
+
|
| 107 |
+
def _draw_flower(self, draw, width, height, colors):
|
| 108 |
+
"""Draw a simple flower sketch"""
|
| 109 |
+
cx, cy = width // 2, height // 2
|
| 110 |
+
|
| 111 |
+
# Stem
|
| 112 |
+
draw.line([cx, cy+20, cx, cy+60], fill=colors[0], width=3)
|
| 113 |
+
|
| 114 |
+
# Petals
|
| 115 |
+
for i in range(6):
|
| 116 |
+
angle = i * 60 * math.pi / 180
|
| 117 |
+
x = cx + 25 * math.cos(angle)
|
| 118 |
+
y = cy + 25 * math.sin(angle)
|
| 119 |
+
draw.ellipse([x-8, y-8, x+8, y+8], outline=colors[0], width=2)
|
| 120 |
+
|
| 121 |
+
# Center
|
| 122 |
+
draw.ellipse([cx-8, cy-8, cx+8, cy+8], fill=colors[1], outline=colors[0], width=2)
|
| 123 |
+
|
| 124 |
+
# Leaves
|
| 125 |
+
draw.ellipse([cx-10, cy+30, cx+10, cy+50], outline=colors[0], width=2)
|
| 126 |
+
|
| 127 |
+
def _draw_tree(self, draw, width, height, colors):
|
| 128 |
+
"""Draw a simple tree sketch"""
|
| 129 |
+
cx, cy = width // 2, height // 2
|
| 130 |
+
|
| 131 |
+
# Trunk
|
| 132 |
+
draw.rectangle([cx-8, cy+10, cx+8, cy+60], outline=colors[0], width=2)
|
| 133 |
+
|
| 134 |
+
# Tree crown (circle)
|
| 135 |
+
draw.ellipse([cx-40, cy-40, cx+40, cy+20], outline=colors[0], width=2)
|
| 136 |
+
|
| 137 |
+
# Branches
|
| 138 |
+
for i in range(5):
|
| 139 |
+
angle = (i * 72 - 90) * math.pi / 180
|
| 140 |
+
x = cx + 30 * math.cos(angle)
|
| 141 |
+
y = cy + 30 * math.sin(angle)
|
| 142 |
+
draw.line([cx, cy, x, y], fill=colors[1], width=1)
|
| 143 |
+
|
| 144 |
+
def _draw_mountain(self, draw, width, height, colors):
|
| 145 |
+
"""Draw a simple mountain landscape"""
|
| 146 |
+
cx, cy = width // 2, height // 2
|
| 147 |
+
|
| 148 |
+
# Mountains
|
| 149 |
+
draw.polygon([20, cy+30, 80, cy-40, 140, cy+30], outline=colors[0], width=2)
|
| 150 |
+
draw.polygon([100, cy+30, 160, cy-20, 200, cy+30], outline=colors[0], width=2)
|
| 151 |
+
|
| 152 |
+
# Ground line
|
| 153 |
+
draw.line([0, cy+30, width, cy+30], fill=colors[1], width=1)
|
| 154 |
+
|
| 155 |
+
# Sun
|
| 156 |
+
draw.ellipse([width-60, 20, width-20, 60], outline=colors[1], width=2)
|
| 157 |
+
|
| 158 |
+
def _draw_abstract(self, draw, width, height, colors):
|
| 159 |
+
"""Draw abstract shapes for unknown prompts"""
|
| 160 |
+
cx, cy = width // 2, height // 2
|
| 161 |
+
|
| 162 |
+
# Random geometric shapes
|
| 163 |
+
for i in range(5):
|
| 164 |
+
x = random.randint(20, width-20)
|
| 165 |
+
y = random.randint(20, height-20)
|
| 166 |
+
size = random.randint(10, 30)
|
| 167 |
+
|
| 168 |
+
if i % 3 == 0:
|
| 169 |
+
draw.ellipse([x-size, y-size, x+size, y+size], outline=colors[i%len(colors)], width=2)
|
| 170 |
+
elif i % 3 == 1:
|
| 171 |
+
draw.rectangle([x-size, y-size, x+size, y+size], outline=colors[i%len(colors)], width=2)
|
| 172 |
+
else:
|
| 173 |
+
points = []
|
| 174 |
+
for j in range(3):
|
| 175 |
+
px = x + size * math.cos(j * 120 * math.pi / 180)
|
| 176 |
+
py = y + size * math.sin(j * 120 * math.pi / 180)
|
| 177 |
+
points.extend([px, py])
|
| 178 |
+
draw.polygon(points, outline=colors[i%len(colors)], width=2)
|
requirements.txt
CHANGED
|
@@ -1,9 +1,23 @@
|
|
| 1 |
-
torch>=
|
| 2 |
-
torchvision>=0.
|
| 3 |
-
transformers>=4.21.0
|
| 4 |
-
svgwrite>=1.4.0
|
| 5 |
-
Pillow>=8.3.0
|
| 6 |
numpy>=1.21.0
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=1.12.0
|
| 2 |
+
torchvision>=0.13.0
|
|
|
|
|
|
|
|
|
|
| 3 |
numpy>=1.21.0
|
| 4 |
+
Pillow>=8.0.0
|
| 5 |
+
cairosvg>=2.5.0
|
| 6 |
+
omegaconf>=2.1.0
|
| 7 |
+
hydra-core>=1.1.0
|
| 8 |
+
diffusers>=0.20.0
|
| 9 |
+
transformers>=4.20.0
|
| 10 |
+
accelerate>=0.20.0
|
| 11 |
+
svgwrite>=1.4.0
|
| 12 |
+
svgpathtools>=1.4.0
|
| 13 |
+
freetype-py>=2.3.0
|
| 14 |
+
shapely>=1.8.0
|
| 15 |
+
opencv-python>=4.5.0
|
| 16 |
+
scikit-image>=0.19.0
|
| 17 |
+
matplotlib>=3.5.0
|
| 18 |
+
scipy>=1.8.0
|
| 19 |
+
einops>=0.4.0
|
| 20 |
+
timm>=0.6.0
|
| 21 |
+
ftfy>=6.1.0
|
| 22 |
+
regex>=2022.0.0
|
| 23 |
+
tqdm>=4.64.0
|