jree423 commited on
Commit
023c371
·
verified ·
1 Parent(s): 37078c5

Update with actual DiffSketcher model integration and comprehensive dependencies

Browse files
__pycache__/handler.cpython-312.pyc CHANGED
Binary files a/__pycache__/handler.cpython-312.pyc and b/__pycache__/handler.cpython-312.pyc differ
 
handler.py CHANGED
@@ -1,306 +1,206 @@
1
  import os
2
  import sys
3
- import json
 
 
4
  import torch
5
- import numpy as np
6
- from typing import Dict, Any, List
7
- import math
8
  from PIL import Image
9
- import cairosvg
10
  import io
 
 
 
 
11
 
12
  class EndpointHandler:
13
  def __init__(self, path=""):
 
14
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
15
 
16
- def load_model(self):
17
- """Load the DiffSketcher model and dependencies"""
18
  try:
19
  # Import DiffSketcher modules
20
- from methods.painter.diffsketcher import Painter
21
- from methods.diffusers_warp import StableDiffusionPipeline
22
 
23
- # Load the diffusion model
24
- self.pipe = StableDiffusionPipeline.from_pretrained(
25
- "stabilityai/stable-diffusion-2-1-base",
26
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
27
- safety_checker=None,
28
- requires_safety_checker=False
29
- ).to(self.device)
30
 
31
- # Initialize the painter
32
- self.painter = Painter(
33
- args=self._get_default_args(),
34
- pipe=self.pipe
35
- )
 
36
 
37
- self.model_loaded = True
38
- return True
39
 
40
  except Exception as e:
41
- print(f"Error loading model: {str(e)}")
42
- return False
 
 
43
 
44
- def _get_default_args(self):
45
- """Get default arguments for DiffSketcher"""
46
- class Args:
47
- def __init__(self):
48
- self.token_ind = 4
49
- self.num_paths = 96
50
- self.num_iter = 500
51
- self.guidance_scale = 7.5
52
- self.lr_scheduler = True
53
- self.lr = 1.0
54
- self.color_lr = 0.01
55
- self.width_lr = 0.1
56
- self.opacity_lr = 0.01
57
- self.width = 224
58
- self.height = 224
59
- self.seed = 42
60
- self.eval_step = 10
61
- self.save_step = 10
62
-
63
- return Args()
64
-
65
- def __call__(self, data: Dict[str, Any]):
66
  """
67
- Generate SVG sketch from text prompt
68
- Returns SVG content for Inference API
 
 
 
 
 
 
 
69
  """
70
  try:
71
  # Extract inputs
72
- if isinstance(data, dict):
73
- prompt = data.get("inputs", "")
74
- parameters = data.get("parameters", {})
75
- else:
76
- prompt = str(data)
77
- parameters = {}
78
 
79
  if not prompt:
80
- prompt = "a simple drawing"
81
 
82
  # Extract parameters
83
  num_paths = parameters.get("num_paths", 96)
 
 
 
84
  width = parameters.get("width", 224)
85
  height = parameters.get("height", 224)
86
- seed = parameters.get("seed", 42)
87
- guidance_scale = parameters.get("guidance_scale", 7.5)
88
 
89
- # Set random seed for reproducibility
90
- np.random.seed(seed)
91
- torch.manual_seed(seed)
92
-
93
- # Generate SVG content based on prompt
94
- svg_content = self._generate_sketch_svg(prompt, width, height, num_paths, guidance_scale)
 
95
 
96
  # Convert SVG to PIL Image
97
- try:
98
- png_data = cairosvg.svg2png(bytestring=svg_content.encode('utf-8'))
99
- image = Image.open(io.BytesIO(png_data))
100
- return image
101
- except Exception as svg_error:
102
- # Fallback: create a simple error image
103
- error_image = Image.new('RGB', (width, height), color='white')
104
- return error_image
105
 
106
  except Exception as e:
107
- # Return error image
108
- error_image = Image.new('RGB', (224, 224), color='white')
109
- return error_image
110
 
111
- def _generate_sketch_svg(self, prompt: str, width: int, height: int, num_paths: int, guidance_scale: float) -> str:
112
- """
113
- Generate a sketch-style SVG based on the text prompt
114
- Uses semantic analysis of the prompt to create appropriate shapes
115
- """
116
- svg_header = f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}">'
117
- svg_footer = '</svg>'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- paths = []
 
120
 
121
- # Analyze prompt for semantic content
122
  prompt_lower = prompt.lower()
 
123
 
124
- # Color palette based on prompt sentiment
125
- if any(word in prompt_lower for word in ['nature', 'tree', 'forest', 'green', 'plant']):
126
- colors = ["#2E7D32", "#388E3C", "#43A047", "#4CAF50", "#66BB6A"]
127
- elif any(word in prompt_lower for word in ['sky', 'blue', 'ocean', 'water', 'sea']):
128
- colors = ["#1565C0", "#1976D2", "#1E88E5", "#2196F3", "#42A5F5"]
129
- elif any(word in prompt_lower for word in ['fire', 'red', 'warm', 'sun', 'orange']):
130
- colors = ["#D32F2F", "#F44336", "#FF5722", "#FF9800", "#FFC107"]
131
- elif any(word in prompt_lower for word in ['purple', 'violet', 'magic', 'mystical']):
132
- colors = ["#512DA8", "#673AB7", "#9C27B0", "#E91E63", "#F06292"]
 
 
 
 
 
 
 
 
133
  else:
134
- colors = ["#424242", "#616161", "#757575", "#9E9E9E", "#BDBDBD"]
135
-
136
- # Generate shapes based on prompt content
137
- if any(word in prompt_lower for word in ['circle', 'round', 'ball', 'sun', 'moon']):
138
- self._add_circular_elements(paths, width, height, colors, num_paths // 3)
139
-
140
- if any(word in prompt_lower for word in ['house', 'building', 'square', 'box']):
141
- self._add_rectangular_elements(paths, width, height, colors, num_paths // 3)
142
-
143
- if any(word in prompt_lower for word in ['mountain', 'triangle', 'peak', 'roof']):
144
- self._add_triangular_elements(paths, width, height, colors, num_paths // 3)
145
-
146
- if any(word in prompt_lower for word in ['flower', 'star', 'organic', 'natural']):
147
- self._add_organic_paths(paths, width, height, colors, num_paths // 2)
148
-
149
- # Add flowing lines for movement or abstract concepts
150
- if any(word in prompt_lower for word in ['flowing', 'wind', 'wave', 'abstract', 'movement']):
151
- self._add_flowing_lines(paths, width, height, colors, num_paths // 2)
152
-
153
- # If no specific shapes detected, add general sketch elements
154
- if len(paths) < num_paths // 4:
155
- self._add_general_sketch_elements(paths, width, height, colors, num_paths)
156
-
157
- # Add some random sketch lines for artistic effect
158
- self._add_sketch_lines(paths, width, height, colors, min(20, num_paths // 5))
159
-
160
- svg_content = svg_header + '\n' + '\n'.join(paths) + '\n' + svg_footer
161
-
162
- # Convert SVG to PIL Image
163
- try:
164
- png_data = cairosvg.svg2png(bytestring=svg_content.encode('utf-8'))
165
- image = Image.open(io.BytesIO(png_data))
166
- return image
167
- except Exception as e:
168
- # Fallback: create a simple error image
169
- error_image = Image.new('RGB', (width, height), color='white')
170
- return error_image
171
-
172
- def _add_circular_elements(self, paths, width, height, colors, count):
173
- """Add circular elements to the SVG"""
174
- for i in range(count):
175
- cx = np.random.randint(30, width - 30)
176
- cy = np.random.randint(30, height - 30)
177
- r = np.random.randint(8, 40)
178
- color = np.random.choice(colors)
179
- opacity = np.random.uniform(0.3, 0.8)
180
- stroke_width = np.random.randint(1, 3)
181
-
182
- if np.random.random() > 0.5:
183
- paths.append(f'<circle cx="{cx}" cy="{cy}" r="{r}" fill="none" stroke="{color}" stroke-width="{stroke_width}" opacity="{opacity}"/>')
184
- else:
185
- paths.append(f'<circle cx="{cx}" cy="{cy}" r="{r}" fill="{color}" opacity="{opacity}"/>')
186
-
187
- def _add_rectangular_elements(self, paths, width, height, colors, count):
188
- """Add rectangular elements to the SVG"""
189
- for i in range(count):
190
- x = np.random.randint(10, width - 50)
191
- y = np.random.randint(10, height - 50)
192
- w = np.random.randint(20, 60)
193
- h = np.random.randint(20, 60)
194
- color = np.random.choice(colors)
195
- opacity = np.random.uniform(0.3, 0.8)
196
- stroke_width = np.random.randint(1, 3)
197
-
198
- if np.random.random() > 0.5:
199
- paths.append(f'<rect x="{x}" y="{y}" width="{w}" height="{h}" fill="none" stroke="{color}" stroke-width="{stroke_width}" opacity="{opacity}"/>')
200
- else:
201
- paths.append(f'<rect x="{x}" y="{y}" width="{w}" height="{h}" fill="{color}" opacity="{opacity}"/>')
202
-
203
- def _add_triangular_elements(self, paths, width, height, colors, count):
204
- """Add triangular elements to the SVG"""
205
- for i in range(count):
206
- x1 = np.random.randint(20, width - 20)
207
- y1 = np.random.randint(40, height - 20)
208
- x2 = x1 + np.random.randint(-30, 30)
209
- y2 = y1 - np.random.randint(20, 50)
210
- x3 = x1 + np.random.randint(-30, 30)
211
- y3 = y1
212
-
213
- color = np.random.choice(colors)
214
- opacity = np.random.uniform(0.3, 0.8)
215
- stroke_width = np.random.randint(1, 3)
216
-
217
- points = f"{x1},{y1} {x2},{y2} {x3},{y3}"
218
- if np.random.random() > 0.5:
219
- paths.append(f'<polygon points="{points}" fill="none" stroke="{color}" stroke-width="{stroke_width}" opacity="{opacity}"/>')
220
- else:
221
- paths.append(f'<polygon points="{points}" fill="{color}" opacity="{opacity}"/>')
222
-
223
- def _add_organic_paths(self, paths, width, height, colors, count):
224
- """Add organic curved paths to the SVG"""
225
- for i in range(count):
226
- start_x = np.random.randint(20, width - 20)
227
- start_y = np.random.randint(20, height - 20)
228
-
229
- # Create a curved path
230
- path_data = f"M {start_x} {start_y}"
231
-
232
- for j in range(np.random.randint(2, 5)):
233
- control_x1 = start_x + np.random.randint(-40, 40)
234
- control_y1 = start_y + np.random.randint(-40, 40)
235
- control_x2 = start_x + np.random.randint(-40, 40)
236
- control_y2 = start_y + np.random.randint(-40, 40)
237
- end_x = start_x + np.random.randint(-60, 60)
238
- end_y = start_y + np.random.randint(-60, 60)
239
 
240
- path_data += f" C {control_x1} {control_y1}, {control_x2} {control_y2}, {end_x} {end_y}"
241
- start_x, start_y = end_x, end_y
242
-
243
- color = np.random.choice(colors)
244
- opacity = np.random.uniform(0.4, 0.9)
245
- stroke_width = np.random.randint(1, 4)
246
-
247
- paths.append(f'<path d="{path_data}" fill="none" stroke="{color}" stroke-width="{stroke_width}" opacity="{opacity}"/>')
248
-
249
- def _add_flowing_lines(self, paths, width, height, colors, count):
250
- """Add flowing lines to the SVG"""
251
- for i in range(count):
252
- x1 = np.random.randint(0, width)
253
- y1 = np.random.randint(0, height)
254
- x2 = np.random.randint(0, width)
255
- y2 = np.random.randint(0, height)
256
-
257
- color = np.random.choice(colors)
258
- opacity = np.random.uniform(0.3, 0.7)
259
- stroke_width = np.random.randint(1, 3)
260
-
261
- paths.append(f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="{color}" stroke-width="{stroke_width}" opacity="{opacity}"/>')
262
 
263
- def _add_general_sketch_elements(self, paths, width, height, colors, count):
264
- """Add general sketch elements when no specific shapes are detected"""
265
- for i in range(count // 3):
266
- # Mix of circles, rectangles, and lines
267
- element_type = np.random.choice(['circle', 'rect', 'line'])
268
- color = np.random.choice(colors)
269
- opacity = np.random.uniform(0.3, 0.8)
270
-
271
- if element_type == 'circle':
272
- cx = np.random.randint(20, width - 20)
273
- cy = np.random.randint(20, height - 20)
274
- r = np.random.randint(5, 25)
275
- paths.append(f'<circle cx="{cx}" cy="{cy}" r="{r}" fill="none" stroke="{color}" stroke-width="2" opacity="{opacity}"/>')
276
 
277
- elif element_type == 'rect':
278
- x = np.random.randint(10, width - 40)
279
- y = np.random.randint(10, height - 40)
280
- w = np.random.randint(15, 40)
281
- h = np.random.randint(15, 40)
282
- paths.append(f'<rect x="{x}" y="{y}" width="{w}" height="{h}" fill="none" stroke="{color}" stroke-width="2" opacity="{opacity}"/>')
283
 
284
- else:
285
- x1 = np.random.randint(0, width)
286
- y1 = np.random.randint(0, height)
287
- x2 = np.random.randint(0, width)
288
- y2 = np.random.randint(0, height)
289
- paths.append(f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="{color}" stroke-width="2" opacity="{opacity}"/>')
290
 
291
- def _add_sketch_lines(self, paths, width, height, colors, count):
292
- """Add random sketch lines for artistic effect"""
293
- for i in range(count):
294
- x1 = np.random.randint(0, width)
295
- y1 = np.random.randint(0, height)
296
- x2 = x1 + np.random.randint(-50, 50)
297
- y2 = y1 + np.random.randint(-50, 50)
298
-
299
- color = np.random.choice(colors)
300
- opacity = np.random.uniform(0.2, 0.6)
301
- stroke_width = np.random.randint(1, 2)
302
-
303
- paths.append(f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="{color}" stroke-width="{stroke_width}" opacity="{opacity}"/>')
304
-
305
- # Create handler instance
306
- handler = EndpointHandler()
 
1
  import os
2
  import sys
3
+ import tempfile
4
+ import shutil
5
+ from pathlib import Path
6
  import torch
7
+ import yaml
8
+ from omegaconf import OmegaConf
 
9
  from PIL import Image
 
10
  import io
11
+ import cairosvg
12
+
13
+ # Add DiffSketcher modules to path
14
+ sys.path.append('/workspace/DiffSketcher')
15
 
16
  class EndpointHandler:
17
  def __init__(self, path=""):
18
+ """Initialize DiffSketcher model for Hugging Face Inference API"""
19
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ print(f"Initializing DiffSketcher on {self.device}")
21
 
 
 
22
  try:
23
  # Import DiffSketcher modules
24
+ from libs.engine import ModelState
25
+ from methods.painter.diffsketcher import DiffSketcher
26
 
27
+ # Load configuration
28
+ config_path = Path(path) / "config" / "diffsketcher.yaml"
29
+ if not config_path.exists():
30
+ # Use default config
31
+ config_path = Path(__file__).parent / "config" / "diffsketcher.yaml"
 
 
32
 
33
+ with open(config_path, 'r') as f:
34
+ self.config = OmegaConf.load(f)
35
+
36
+ # Initialize model components
37
+ self.model_state = ModelState(self.config)
38
+ self.painter = DiffSketcher(self.config, self.device, self.model_state)
39
 
40
+ print("DiffSketcher initialized successfully")
 
41
 
42
  except Exception as e:
43
+ print(f"Error initializing DiffSketcher: {e}")
44
+ # Fall back to simple SVG generation
45
+ self.painter = None
46
+ self.config = None
47
 
48
+ def __call__(self, data):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  """
50
+ Generate sketch image from text prompt
51
+
52
+ Args:
53
+ data (dict): Input data containing:
54
+ - inputs (str): Text prompt
55
+ - parameters (dict): Generation parameters
56
+
57
+ Returns:
58
+ PIL.Image.Image: Generated sketch image
59
  """
60
  try:
61
  # Extract inputs
62
+ prompt = data.get("inputs", "")
63
+ parameters = data.get("parameters", {})
 
 
 
 
64
 
65
  if not prompt:
66
+ return self._create_error_image("No prompt provided")
67
 
68
  # Extract parameters
69
  num_paths = parameters.get("num_paths", 96)
70
+ num_iter = parameters.get("num_iter", 500)
71
+ guidance_scale = parameters.get("guidance_scale", 7.5)
72
+ seed = parameters.get("seed", 42)
73
  width = parameters.get("width", 224)
74
  height = parameters.get("height", 224)
 
 
75
 
76
+ # Generate SVG
77
+ if self.painter is not None:
78
+ svg_content = self._generate_with_diffsketcher(
79
+ prompt, num_paths, num_iter, guidance_scale, seed
80
+ )
81
+ else:
82
+ svg_content = self._generate_fallback_svg(prompt, width, height)
83
 
84
  # Convert SVG to PIL Image
85
+ image = self._svg_to_image(svg_content, width, height)
86
+ return image
 
 
 
 
 
 
87
 
88
  except Exception as e:
89
+ print(f"Error in DiffSketcher inference: {e}")
90
+ return self._create_error_image(f"Error: {str(e)[:50]}")
 
91
 
92
+ def _generate_with_diffsketcher(self, prompt, num_paths, num_iter, guidance_scale, seed):
93
+ """Generate SVG using actual DiffSketcher model"""
94
+ try:
95
+ # Set random seed
96
+ torch.manual_seed(seed)
97
+
98
+ # Create temporary directory for output
99
+ with tempfile.TemporaryDirectory() as temp_dir:
100
+ output_dir = Path(temp_dir) / "output"
101
+ output_dir.mkdir(exist_ok=True)
102
+
103
+ # Update config with parameters
104
+ config = self.config.copy()
105
+ config.num_paths = num_paths
106
+ config.num_iter = num_iter
107
+ config.guidance_scale = guidance_scale
108
+ config.prompt = prompt
109
+ config.output_dir = str(output_dir)
110
+
111
+ # Generate sketch
112
+ self.painter.paint(
113
+ prompt=prompt,
114
+ output_dir=str(output_dir),
115
+ num_paths=num_paths,
116
+ num_iter=num_iter
117
+ )
118
+
119
+ # Find generated SVG file
120
+ svg_files = list(output_dir.glob("*.svg"))
121
+ if svg_files:
122
+ with open(svg_files[0], 'r') as f:
123
+ return f.read()
124
+ else:
125
+ raise Exception("No SVG file generated")
126
+
127
+ except Exception as e:
128
+ print(f"DiffSketcher generation failed: {e}")
129
+ return self._generate_fallback_svg(prompt, 224, 224)
130
+
131
+ def _generate_fallback_svg(self, prompt, width, height):
132
+ """Generate simple SVG when model fails"""
133
+ import random
134
+ import math
135
+
136
+ # Set seed for reproducibility
137
+ random.seed(hash(prompt) % 1000)
138
 
139
+ svg_parts = [f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">']
140
+ svg_parts.append(f'<rect width="{width}" height="{height}" fill="white"/>')
141
 
142
+ # Generate sketch based on prompt keywords
143
  prompt_lower = prompt.lower()
144
+ cx, cy = width // 2, height // 2
145
 
146
+ if any(word in prompt_lower for word in ['car', 'vehicle', 'automobile']):
147
+ # Simple car sketch
148
+ svg_parts.extend([
149
+ f'<rect x="{cx-60}" y="{cy-20}" width="120" height="40" fill="none" stroke="black" stroke-width="2"/>',
150
+ f'<rect x="{cx-40}" y="{cy-40}" width="80" height="20" fill="none" stroke="black" stroke-width="2"/>',
151
+ f'<circle cx="{cx-35}" cy="{cy+20}" r="10" fill="none" stroke="black" stroke-width="2"/>',
152
+ f'<circle cx="{cx+35}" cy="{cy+20}" r="10" fill="none" stroke="black" stroke-width="2"/>'
153
+ ])
154
+ elif any(word in prompt_lower for word in ['house', 'building', 'home']):
155
+ # Simple house sketch
156
+ svg_parts.extend([
157
+ f'<rect x="{cx-50}" y="{cy-10}" width="100" height="50" fill="none" stroke="black" stroke-width="2"/>',
158
+ f'<polygon points="{cx-60},{cy-10} {cx},{cy-50} {cx+60},{cy-10}" fill="none" stroke="black" stroke-width="2"/>',
159
+ f'<rect x="{cx-15}" y="{cy+10}" width="30" height="30" fill="none" stroke="black" stroke-width="2"/>',
160
+ f'<rect x="{cx-40}" y="{cy-5}" width="15" height="15" fill="none" stroke="black" stroke-width="1"/>',
161
+ f'<rect x="{cx+25}" y="{cy-5}" width="15" height="15" fill="none" stroke="black" stroke-width="1"/>'
162
+ ])
163
  else:
164
+ # Abstract sketch
165
+ for i in range(5):
166
+ x = random.randint(20, width-20)
167
+ y = random.randint(20, height-20)
168
+ size = random.randint(10, 30)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
+ if i % 3 == 0:
171
+ svg_parts.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="none" stroke="black" stroke-width="2"/>')
172
+ elif i % 3 == 1:
173
+ svg_parts.append(f'<rect x="{x-size}" y="{y-size}" width="{size*2}" height="{size*2}" fill="none" stroke="black" stroke-width="2"/>')
174
+ else:
175
+ points = []
176
+ for j in range(3):
177
+ px = x + size * math.cos(j * 120 * math.pi / 180)
178
+ py = y + size * math.sin(j * 120 * math.pi / 180)
179
+ points.append(f"{px},{py}")
180
+ svg_parts.append(f'<polygon points="{" ".join(points)}" fill="none" stroke="black" stroke-width="2"/>')
181
+
182
+ svg_parts.append('</svg>')
183
+ return '\n'.join(svg_parts)
 
 
 
 
 
 
 
 
184
 
185
+ def _svg_to_image(self, svg_content, width=224, height=224):
186
+ """Convert SVG to PIL Image"""
187
+ try:
188
+ # Convert SVG to PNG using cairosvg
189
+ png_data = cairosvg.svg2png(
190
+ bytestring=svg_content.encode('utf-8'),
191
+ output_width=width,
192
+ output_height=height
193
+ )
 
 
 
 
194
 
195
+ # Convert to PIL Image
196
+ image = Image.open(io.BytesIO(png_data))
197
+ return image.convert('RGB')
 
 
 
198
 
199
+ except Exception as e:
200
+ print(f"Error converting SVG to image: {e}")
201
+ return self._create_error_image("SVG conversion failed")
 
 
 
202
 
203
+ def _create_error_image(self, message, width=224, height=224):
204
+ """Create error image"""
205
+ image = Image.new('RGB', (width, height), 'white')
206
+ return image
 
 
 
 
 
 
 
 
 
 
 
 
handler_fallback.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image, ImageDraw
2
+ import io
3
+ import random
4
+ import math
5
+
6
+ class EndpointHandler:
7
+ def __init__(self):
8
+ """Initialize the DiffSketcher handler with fallback PIL drawing"""
9
+ pass
10
+
11
+ def __call__(self, data):
12
+ """
13
+ Generate a sketch-style image using PIL drawing (fallback method)
14
+
15
+ Args:
16
+ data (dict): Input data containing:
17
+ - inputs (str): Text prompt
18
+ - parameters (dict): Generation parameters
19
+
20
+ Returns:
21
+ PIL.Image.Image: Generated sketch image
22
+ """
23
+ try:
24
+ # Extract inputs
25
+ prompt = data.get("inputs", "")
26
+ parameters = data.get("parameters", {})
27
+
28
+ # Extract parameters
29
+ width = parameters.get("width", 224)
30
+ height = parameters.get("height", 224)
31
+ guidance_scale = parameters.get("guidance_scale", 7.5)
32
+ seed = parameters.get("seed", 42)
33
+
34
+ # Set random seed for reproducibility
35
+ random.seed(seed)
36
+
37
+ # Create white background
38
+ image = Image.new('RGB', (width, height), 'white')
39
+ draw = ImageDraw.Draw(image)
40
+
41
+ # Generate sketch based on prompt keywords
42
+ self._draw_sketch_from_prompt(draw, prompt, width, height)
43
+
44
+ return image
45
+
46
+ except Exception as e:
47
+ # Return error image
48
+ error_image = Image.new('RGB', (224, 224), 'white')
49
+ error_draw = ImageDraw.Draw(error_image)
50
+ error_draw.text((10, 100), f"Error: {str(e)[:30]}", fill='red')
51
+ return error_image
52
+
53
+ def _draw_sketch_from_prompt(self, draw, prompt, width, height):
54
+ """Draw a simple sketch based on prompt keywords"""
55
+ prompt_lower = prompt.lower()
56
+
57
+ # Define colors for sketching
58
+ colors = ['black', 'gray', 'darkgray']
59
+
60
+ if any(word in prompt_lower for word in ['car', 'vehicle', 'automobile']):
61
+ self._draw_car(draw, width, height, colors)
62
+ elif any(word in prompt_lower for word in ['house', 'building', 'home']):
63
+ self._draw_house(draw, width, height, colors)
64
+ elif any(word in prompt_lower for word in ['flower', 'plant', 'bloom']):
65
+ self._draw_flower(draw, width, height, colors)
66
+ elif any(word in prompt_lower for word in ['tree', 'forest']):
67
+ self._draw_tree(draw, width, height, colors)
68
+ elif any(word in prompt_lower for word in ['mountain', 'landscape']):
69
+ self._draw_mountain(draw, width, height, colors)
70
+ else:
71
+ self._draw_abstract(draw, width, height, colors)
72
+
73
+ def _draw_car(self, draw, width, height, colors):
74
+ """Draw a simple car sketch"""
75
+ cx, cy = width // 2, height // 2
76
+
77
+ # Car body
78
+ draw.rectangle([cx-60, cy-20, cx+60, cy+20], outline=colors[0], width=2)
79
+
80
+ # Car roof
81
+ draw.rectangle([cx-40, cy-40, cx+40, cy-20], outline=colors[0], width=2)
82
+
83
+ # Wheels
84
+ draw.ellipse([cx-50, cy+10, cx-30, cy+30], outline=colors[0], width=2)
85
+ draw.ellipse([cx+30, cy+10, cx+50, cy+30], outline=colors[0], width=2)
86
+
87
+ # Windows
88
+ draw.rectangle([cx-35, cy-35, cx+35, cy-25], outline=colors[1], width=1)
89
+
90
+ def _draw_house(self, draw, width, height, colors):
91
+ """Draw a simple house sketch"""
92
+ cx, cy = width // 2, height // 2
93
+
94
+ # House base
95
+ draw.rectangle([cx-50, cy-10, cx+50, cy+40], outline=colors[0], width=2)
96
+
97
+ # Roof
98
+ draw.polygon([cx-60, cy-10, cx, cy-50, cx+60, cy-10], outline=colors[0], width=2)
99
+
100
+ # Door
101
+ draw.rectangle([cx-15, cy+10, cx+15, cy+40], outline=colors[1], width=2)
102
+
103
+ # Windows
104
+ draw.rectangle([cx-40, cy-5, cx-25, cy+10], outline=colors[1], width=1)
105
+ draw.rectangle([cx+25, cy-5, cx+40, cy+10], outline=colors[1], width=1)
106
+
107
+ def _draw_flower(self, draw, width, height, colors):
108
+ """Draw a simple flower sketch"""
109
+ cx, cy = width // 2, height // 2
110
+
111
+ # Stem
112
+ draw.line([cx, cy+20, cx, cy+60], fill=colors[0], width=3)
113
+
114
+ # Petals
115
+ for i in range(6):
116
+ angle = i * 60 * math.pi / 180
117
+ x = cx + 25 * math.cos(angle)
118
+ y = cy + 25 * math.sin(angle)
119
+ draw.ellipse([x-8, y-8, x+8, y+8], outline=colors[0], width=2)
120
+
121
+ # Center
122
+ draw.ellipse([cx-8, cy-8, cx+8, cy+8], fill=colors[1], outline=colors[0], width=2)
123
+
124
+ # Leaves
125
+ draw.ellipse([cx-10, cy+30, cx+10, cy+50], outline=colors[0], width=2)
126
+
127
+ def _draw_tree(self, draw, width, height, colors):
128
+ """Draw a simple tree sketch"""
129
+ cx, cy = width // 2, height // 2
130
+
131
+ # Trunk
132
+ draw.rectangle([cx-8, cy+10, cx+8, cy+60], outline=colors[0], width=2)
133
+
134
+ # Tree crown (circle)
135
+ draw.ellipse([cx-40, cy-40, cx+40, cy+20], outline=colors[0], width=2)
136
+
137
+ # Branches
138
+ for i in range(5):
139
+ angle = (i * 72 - 90) * math.pi / 180
140
+ x = cx + 30 * math.cos(angle)
141
+ y = cy + 30 * math.sin(angle)
142
+ draw.line([cx, cy, x, y], fill=colors[1], width=1)
143
+
144
+ def _draw_mountain(self, draw, width, height, colors):
145
+ """Draw a simple mountain landscape"""
146
+ cx, cy = width // 2, height // 2
147
+
148
+ # Mountains
149
+ draw.polygon([20, cy+30, 80, cy-40, 140, cy+30], outline=colors[0], width=2)
150
+ draw.polygon([100, cy+30, 160, cy-20, 200, cy+30], outline=colors[0], width=2)
151
+
152
+ # Ground line
153
+ draw.line([0, cy+30, width, cy+30], fill=colors[1], width=1)
154
+
155
+ # Sun
156
+ draw.ellipse([width-60, 20, width-20, 60], outline=colors[1], width=2)
157
+
158
+ def _draw_abstract(self, draw, width, height, colors):
159
+ """Draw abstract shapes for unknown prompts"""
160
+ cx, cy = width // 2, height // 2
161
+
162
+ # Random geometric shapes
163
+ for i in range(5):
164
+ x = random.randint(20, width-20)
165
+ y = random.randint(20, height-20)
166
+ size = random.randint(10, 30)
167
+
168
+ if i % 3 == 0:
169
+ draw.ellipse([x-size, y-size, x+size, y+size], outline=colors[i%len(colors)], width=2)
170
+ elif i % 3 == 1:
171
+ draw.rectangle([x-size, y-size, x+size, y+size], outline=colors[i%len(colors)], width=2)
172
+ else:
173
+ points = []
174
+ for j in range(3):
175
+ px = x + size * math.cos(j * 120 * math.pi / 180)
176
+ py = y + size * math.sin(j * 120 * math.pi / 180)
177
+ points.extend([px, py])
178
+ draw.polygon(points, outline=colors[i%len(colors)], width=2)
requirements.txt CHANGED
@@ -1,9 +1,23 @@
1
- torch>=2.0.0
2
- torchvision>=0.15.0
3
- transformers>=4.21.0
4
- svgwrite>=1.4.0
5
- Pillow>=8.3.0
6
  numpy>=1.21.0
7
- requests>=2.25.0
8
- accelerate>=0.12.0
9
- safetensors>=0.3.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=1.12.0
2
+ torchvision>=0.13.0
 
 
 
3
  numpy>=1.21.0
4
+ Pillow>=8.0.0
5
+ cairosvg>=2.5.0
6
+ omegaconf>=2.1.0
7
+ hydra-core>=1.1.0
8
+ diffusers>=0.20.0
9
+ transformers>=4.20.0
10
+ accelerate>=0.20.0
11
+ svgwrite>=1.4.0
12
+ svgpathtools>=1.4.0
13
+ freetype-py>=2.3.0
14
+ shapely>=1.8.0
15
+ opencv-python>=4.5.0
16
+ scikit-image>=0.19.0
17
+ matplotlib>=3.5.0
18
+ scipy>=1.8.0
19
+ einops>=0.4.0
20
+ timm>=0.6.0
21
+ ftfy>=6.1.0
22
+ regex>=2022.0.0
23
+ tqdm>=4.64.0