jree423 commited on
Commit
853fd42
·
verified ·
1 Parent(s): 94b6f81

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +55 -130
handler.py CHANGED
@@ -1,48 +1,15 @@
1
  import base64
2
  import json
3
- import torch
4
- import svgwrite
5
  import math
6
  from typing import Dict, Any
7
- from diffusers import StableDiffusionPipeline
8
- import clip
9
 
10
  class EndpointHandler:
11
  def __init__(self, path=""):
12
  """Initialize the DiffSketcher model"""
13
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
14
-
15
- # Load Stable Diffusion model
16
- try:
17
- self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
18
- "runwayml/stable-diffusion-v1-5",
19
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
20
- safety_checker=None,
21
- requires_safety_checker=False
22
- ).to(self.device)
23
- except Exception as e:
24
- print(f"Warning: Could not load Stable Diffusion: {e}")
25
- self.sd_pipeline = None
26
-
27
- # Load CLIP model
28
- try:
29
- self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
30
- except Exception as e:
31
- print(f"Warning: Could not load CLIP: {e}")
32
- self.clip_model = None
33
 
34
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
35
- """
36
- Process the request and generate SVG
37
-
38
- Args:
39
- data: Dictionary containing:
40
- - inputs: The prompt text
41
- - parameters: Optional parameters like num_paths, width, height
42
-
43
- Returns:
44
- Dictionary with SVG content
45
- """
46
  try:
47
  # Extract inputs
48
  if isinstance(data, dict):
@@ -60,7 +27,7 @@ class EndpointHandler:
60
  width = parameters.get("width", 512)
61
  height = parameters.get("height", 512)
62
 
63
- # Generate SVG using DiffSketcher style
64
  svg_content = self.generate_diffsketcher_svg(prompt, num_paths, width, height)
65
 
66
  # Encode as base64
@@ -83,120 +50,76 @@ class EndpointHandler:
83
 
84
  def generate_diffsketcher_svg(self, prompt, num_paths, width, height):
85
  """Generate SVG in DiffSketcher style (painterly, sketchy)"""
86
- dwg = svgwrite.Drawing(size=(f"{width}px", f"{height}px"))
87
-
88
- # White background
89
- dwg.add(dwg.rect(insert=(0, 0), size=("100%", "100%"), fill="white"))
90
 
91
- # Generate paths based on prompt analysis
92
  center_x, center_y = width // 2, height // 2
93
-
94
- # Analyze prompt for content type
95
  prompt_lower = prompt.lower()
96
 
97
  if any(word in prompt_lower for word in ["cat", "animal", "pet"]):
98
- self._draw_cat_sketch(dwg, center_x, center_y, num_paths)
99
  elif any(word in prompt_lower for word in ["flower", "plant", "bloom"]):
100
- self._draw_flower_sketch(dwg, center_x, center_y, num_paths)
101
  elif any(word in prompt_lower for word in ["house", "building", "home"]):
102
- self._draw_house_sketch(dwg, center_x, center_y, num_paths)
103
- elif any(word in prompt_lower for word in ["mountain", "landscape", "nature"]):
104
- self._draw_landscape_sketch(dwg, center_x, center_y, num_paths)
105
  else:
106
- self._draw_abstract_sketch(dwg, center_x, center_y, num_paths)
107
 
108
- # Add prompt as text
109
- dwg.add(dwg.text(f"DiffSketcher: {prompt}", insert=(10, height-10),
110
- fill="gray", font_size="12px"))
111
 
112
- return dwg.tostring()
113
 
114
- def _draw_cat_sketch(self, dwg, cx, cy, num_paths):
115
  """Draw a sketchy cat"""
116
- # Head
117
- dwg.add(dwg.circle(center=(cx, cy-20), r=60, fill="none", stroke="black", stroke_width=3))
118
-
119
- # Ears
120
- dwg.add(dwg.polygon(points=[(cx-40, cy-60), (cx-20, cy-80), (cx-10, cy-50)],
121
- fill="none", stroke="black", stroke_width=2))
122
- dwg.add(dwg.polygon(points=[(cx+40, cy-60), (cx+20, cy-80), (cx+10, cy-50)],
123
- fill="none", stroke="black", stroke_width=2))
124
-
125
- # Eyes
126
- dwg.add(dwg.circle(center=(cx-20, cy-10), r=8, fill="black"))
127
- dwg.add(dwg.circle(center=(cx+20, cy-10), r=8, fill="black"))
128
-
129
- # Nose
130
- dwg.add(dwg.polygon(points=[(cx-5, cy+10), (cx+5, cy+10), (cx, cy+20)], fill="pink"))
131
-
132
- # Whiskers
133
- dwg.add(dwg.line(start=(cx-50, cy), end=(cx-70, cy-5), stroke="black", stroke_width=1))
134
- dwg.add(dwg.line(start=(cx+50, cy), end=(cx+70, cy-5), stroke="black", stroke_width=1))
135
-
136
- # Body
137
- dwg.add(dwg.ellipse(center=(cx, cy+80), r=(40, 60), fill="none", stroke="black", stroke_width=3))
138
 
139
- def _draw_flower_sketch(self, dwg, cx, cy, num_paths):
140
  """Draw a sketchy flower"""
141
- # Petals in a circle
142
  for i in range(8):
143
  angle = i * 45
144
  petal_x = cx + 50 * math.cos(math.radians(angle))
145
  petal_y = cy + 50 * math.sin(math.radians(angle))
146
-
147
- dwg.add(dwg.ellipse(
148
- center=(petal_x, petal_y),
149
- r=(20, 35),
150
- fill="pink",
151
- stroke="red",
152
- stroke_width=2,
153
- transform=f"rotate({angle} {petal_x} {petal_y})"
154
- ))
155
-
156
- # Center
157
- dwg.add(dwg.circle(center=(cx, cy), r=15, fill="yellow", stroke="orange", stroke_width=2))
158
-
159
- # Stem
160
- dwg.add(dwg.line(start=(cx, cy+15), end=(cx, cy+120), stroke="green", stroke_width=4))
161
-
162
- # Leaves
163
- dwg.add(dwg.ellipse(center=(cx-20, cy+80), r=(15, 25), fill="lightgreen", stroke="green", stroke_width=2))
164
- dwg.add(dwg.ellipse(center=(cx+20, cy+90), r=(15, 25), fill="lightgreen", stroke="green", stroke_width=2))
165
 
166
- def _draw_house_sketch(self, dwg, cx, cy, num_paths):
167
  """Draw a sketchy house"""
168
- # House base
169
- dwg.add(dwg.rect(insert=(cx-50, cy), size=(100, 60), fill="lightblue", stroke="blue", stroke_width=3))
170
-
171
- # Roof
172
- dwg.add(dwg.polygon(points=[(cx-60, cy), (cx, cy-50), (cx+60, cy)], fill="red", stroke="darkred", stroke_width=2))
173
-
174
- # Door
175
- dwg.add(dwg.rect(insert=(cx-15, cy+20), size=(30, 40), fill="brown"))
176
-
177
- # Windows
178
- dwg.add(dwg.rect(insert=(cx-40, cy+15), size=(20, 20), fill="lightblue", stroke="blue", stroke_width=2))
179
- dwg.add(dwg.rect(insert=(cx+20, cy+15), size=(20, 20), fill="lightblue", stroke="blue", stroke_width=2))
180
 
181
- def _draw_landscape_sketch(self, dwg, cx, cy, num_paths):
182
- """Draw a sketchy landscape"""
183
- # Mountains
184
- dwg.add(dwg.polygon(points=[(0, cy), (cx-100, cy-100), (cx-50, cy-80), (cx, cy-120),
185
- (cx+50, cy-90), (cx+100, cy-110), (cx+200, cy)],
186
- fill="lightgray", stroke="gray", stroke_width=2))
187
-
188
- # Trees
189
- for i, tree_x in enumerate([cx-80, cx-20, cx+40, cx+100]):
190
- tree_y = cy + 20
191
- # Trunk
192
- dwg.add(dwg.rect(insert=(tree_x-5, tree_y), size=(10, 40), fill="brown"))
193
- # Foliage
194
- dwg.add(dwg.circle(center=(tree_x, tree_y-10), r=25, fill="green", stroke="darkgreen", stroke_width=2))
195
-
196
- def _draw_abstract_sketch(self, dwg, cx, cy, num_paths):
197
  """Draw abstract sketchy shapes"""
198
  import random
 
199
 
 
200
  colors = ["red", "blue", "green", "orange", "purple", "pink", "yellow"]
201
 
202
  for i in range(min(num_paths, 12)):
@@ -206,9 +129,11 @@ class EndpointHandler:
206
  color = random.choice(colors)
207
 
208
  if i % 3 == 0:
209
- dwg.add(dwg.circle(center=(x, y), r=r, fill="none", stroke=color, stroke_width=3))
210
  elif i % 3 == 1:
211
- dwg.add(dwg.rect(insert=(x-r//2, y-r//2), size=(r, r), fill="none", stroke=color, stroke_width=2))
212
  else:
213
- points = [(x, y-r), (x+r, y+r), (x-r, y+r)]
214
- dwg.add(dwg.polygon(points=points, fill="none", stroke=color, stroke_width=2))
 
 
 
1
  import base64
2
  import json
 
 
3
  import math
4
  from typing import Dict, Any
 
 
5
 
6
  class EndpointHandler:
7
  def __init__(self, path=""):
8
  """Initialize the DiffSketcher model"""
9
+ print("DiffSketcher handler initialized")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
12
+ """Generate SVG using DiffSketcher style"""
 
 
 
 
 
 
 
 
 
 
13
  try:
14
  # Extract inputs
15
  if isinstance(data, dict):
 
27
  width = parameters.get("width", 512)
28
  height = parameters.get("height", 512)
29
 
30
+ # Generate SVG content
31
  svg_content = self.generate_diffsketcher_svg(prompt, num_paths, width, height)
32
 
33
  # Encode as base64
 
50
 
51
  def generate_diffsketcher_svg(self, prompt, num_paths, width, height):
52
  """Generate SVG in DiffSketcher style (painterly, sketchy)"""
53
+ svg_parts = [
54
+ f'<svg baseProfile="full" height="{height}px" version="1.1" width="{width}px" xmlns="http://www.w3.org/2000/svg">',
55
+ f'<rect fill="white" height="100%" width="100%" x="0" y="0" />',
56
+ ]
57
 
58
+ # Generate content based on prompt
59
  center_x, center_y = width // 2, height // 2
 
 
60
  prompt_lower = prompt.lower()
61
 
62
  if any(word in prompt_lower for word in ["cat", "animal", "pet"]):
63
+ svg_parts.extend(self._draw_cat_sketch(center_x, center_y))
64
  elif any(word in prompt_lower for word in ["flower", "plant", "bloom"]):
65
+ svg_parts.extend(self._draw_flower_sketch(center_x, center_y))
66
  elif any(word in prompt_lower for word in ["house", "building", "home"]):
67
+ svg_parts.extend(self._draw_house_sketch(center_x, center_y))
 
 
68
  else:
69
+ svg_parts.extend(self._draw_abstract_sketch(center_x, center_y, num_paths))
70
 
71
+ # Add prompt text
72
+ svg_parts.append(f'<text fill="gray" font-size="12px" x="10" y="{height-10}">DiffSketcher: {prompt}</text>')
73
+ svg_parts.append('</svg>')
74
 
75
+ return ''.join(svg_parts)
76
 
77
+ def _draw_cat_sketch(self, cx, cy):
78
  """Draw a sketchy cat"""
79
+ return [
80
+ f'<circle cx="{cx}" cy="{cy-20}" r="60" fill="none" stroke="black" stroke-width="3" />',
81
+ f'<polygon points="{cx-40},{cy-60} {cx-20},{cy-80} {cx-10},{cy-50}" fill="none" stroke="black" stroke-width="2" />',
82
+ f'<polygon points="{cx+40},{cy-60} {cx+20},{cy-80} {cx+10},{cy-50}" fill="none" stroke="black" stroke-width="2" />',
83
+ f'<circle cx="{cx-20}" cy="{cy-10}" r="8" fill="black" />',
84
+ f'<circle cx="{cx+20}" cy="{cy-10}" r="8" fill="black" />',
85
+ f'<polygon points="{cx-5},{cy+10} {cx+5},{cy+10} {cx},{cy+20}" fill="pink" />',
86
+ f'<line x1="{cx-50}" y1="{cy}" x2="{cx-70}" y2="{cy-5}" stroke="black" stroke-width="1" />',
87
+ f'<line x1="{cx+50}" y1="{cy}" x2="{cx+70}" y2="{cy-5}" stroke="black" stroke-width="1" />',
88
+ f'<ellipse cx="{cx}" cy="{cy+80}" rx="40" ry="60" fill="none" stroke="black" stroke-width="3" />',
89
+ ]
 
 
 
 
 
 
 
 
 
 
 
90
 
91
+ def _draw_flower_sketch(self, cx, cy):
92
  """Draw a sketchy flower"""
93
+ petals = []
94
  for i in range(8):
95
  angle = i * 45
96
  petal_x = cx + 50 * math.cos(math.radians(angle))
97
  petal_y = cy + 50 * math.sin(math.radians(angle))
98
+ petals.append(f'<ellipse cx="{petal_x}" cy="{petal_y}" rx="20" ry="35" fill="pink" stroke="red" stroke-width="2" transform="rotate({angle} {petal_x} {petal_y})" />')
99
+
100
+ return petals + [
101
+ f'<circle cx="{cx}" cy="{cy}" r="15" fill="yellow" stroke="orange" stroke-width="2" />',
102
+ f'<line x1="{cx}" y1="{cy+15}" x2="{cx}" y2="{cy+120}" stroke="green" stroke-width="4" />',
103
+ f'<ellipse cx="{cx-20}" cy="{cy+80}" rx="15" ry="25" fill="lightgreen" stroke="green" stroke-width="2" />',
104
+ f'<ellipse cx="{cx+20}" cy="{cy+90}" rx="15" ry="25" fill="lightgreen" stroke="green" stroke-width="2" />',
105
+ ]
 
 
 
 
 
 
 
 
 
 
 
106
 
107
+ def _draw_house_sketch(self, cx, cy):
108
  """Draw a sketchy house"""
109
+ return [
110
+ f'<rect x="{cx-50}" y="{cy}" width="100" height="60" fill="lightblue" stroke="blue" stroke-width="3" />',
111
+ f'<polygon points="{cx-60},{cy} {cx},{cy-50} {cx+60},{cy}" fill="red" stroke="darkred" stroke-width="2" />',
112
+ f'<rect x="{cx-15}" y="{cy+20}" width="30" height="40" fill="brown" />',
113
+ f'<rect x="{cx-40}" y="{cy+15}" width="20" height="20" fill="lightblue" stroke="blue" stroke-width="2" />',
114
+ f'<rect x="{cx+20}" y="{cy+15}" width="20" height="20" fill="lightblue" stroke="blue" stroke-width="2" />',
115
+ ]
 
 
 
 
 
116
 
117
+ def _draw_abstract_sketch(self, cx, cy, num_paths):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  """Draw abstract sketchy shapes"""
119
  import random
120
+ random.seed(42) # For consistent results
121
 
122
+ shapes = []
123
  colors = ["red", "blue", "green", "orange", "purple", "pink", "yellow"]
124
 
125
  for i in range(min(num_paths, 12)):
 
129
  color = random.choice(colors)
130
 
131
  if i % 3 == 0:
132
+ shapes.append(f'<circle cx="{x}" cy="{y}" r="{r}" fill="none" stroke="{color}" stroke-width="3" />')
133
  elif i % 3 == 1:
134
+ shapes.append(f'<rect x="{x-r//2}" y="{y-r//2}" width="{r}" height="{r}" fill="none" stroke="{color}" stroke-width="2" />')
135
  else:
136
+ points = f"{x},{y-r} {x+r},{y+r} {x-r},{y+r}"
137
+ shapes.append(f'<polygon points="{points}" fill="none" stroke="{color}" stroke-width="2" />')
138
+
139
+ return shapes