jree423 committed on
Commit
317aa01
·
verified ·
1 Parent(s): 3ace519

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +185 -140
handler.py CHANGED
@@ -1,169 +1,214 @@
1
- import os
2
- import sys
3
- import torch
4
  import base64
5
- import io
6
- from PIL import Image
7
- import tempfile
8
- import shutil
9
- from typing import Dict, Any, List
10
  import json
11
-
12
- # Try to import cairosvg for SVG to PNG conversion
13
- try:
14
- import cairosvg
15
- CAIROSVG_AVAILABLE = True
16
- except ImportError:
17
- CAIROSVG_AVAILABLE = False
18
-
19
- # Add current directory to path for imports
20
- current_dir = os.path.dirname(os.path.abspath(__file__))
21
- sys.path.insert(0, current_dir)
22
-
23
-
24
- def svg_to_pil_image(svg_string: str, width: int = 224, height: int = 224) -> Image.Image:
25
- """Convert SVG string to PIL Image"""
26
- try:
27
- if CAIROSVG_AVAILABLE:
28
- # Convert SVG to PNG bytes using cairosvg
29
- png_bytes = cairosvg.svg2png(bytestring=svg_string.encode('utf-8'),
30
- output_width=width, output_height=height)
31
- # Convert PNG bytes to PIL Image
32
- return Image.open(io.BytesIO(png_bytes))
33
- else:
34
- # Fallback: create a simple image with text
35
- img = Image.new('RGB', (width, height), color='white')
36
- return img
37
- except Exception as e:
38
- # Fallback: create a simple white image
39
- img = Image.new('RGB', (width, height), color='white')
40
- return img
41
-
42
- try:
43
- import pydiffvg
44
- from diffusers import StableDiffusionPipeline
45
- from omegaconf import OmegaConf
46
- DEPENDENCIES_AVAILABLE = True
47
- except ImportError as e:
48
- print(f"Warning: Some dependencies not available: {e}")
49
- DEPENDENCIES_AVAILABLE = False
50
-
51
 
52
  class EndpointHandler:
53
  def __init__(self, path=""):
54
- """
55
- Initialize the handler for DiffSketcher model.
56
- """
57
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
-
59
- if not DEPENDENCIES_AVAILABLE:
60
- print("Warning: Dependencies not available, handler will return mock responses")
61
- return
62
-
63
- # Create a minimal config
64
- self.cfg = OmegaConf.create({
65
- 'method': 'diffsketcher',
66
- 'num_paths': 96,
67
- 'num_iter': 500,
68
- 'token_ind': 4,
69
- 'guidance_scale': 7.5,
70
- 'diffuser': {
71
- 'model_id': 'stabilityai/stable-diffusion-2-1-base',
72
- 'download': True
73
- },
74
- 'painter': {
75
- 'canvas_size': 224,
76
- 'lr_scheduler': True,
77
- 'lr': 0.01
78
- }
79
- })
80
 
81
- # Initialize the diffusion pipeline
82
  try:
83
- self.pipe = StableDiffusionPipeline.from_pretrained(
84
- self.cfg.diffuser.model_id,
85
- torch_dtype=torch.float32,
86
  safety_checker=None,
87
  requires_safety_checker=False
88
  ).to(self.device)
89
  except Exception as e:
90
- print(f"Warning: Could not load diffusion model: {e}")
91
- self.pipe = None
92
 
93
- # Set up pydiffvg
94
  try:
95
- pydiffvg.set_print_timing(False)
96
- pydiffvg.set_device(self.device)
97
  except Exception as e:
98
- print(f"Warning: Could not initialize pydiffvg: {e}")
99
-
100
- def __call__(self, data: Dict[str, Any]) -> Image.Image:
 
101
  """
102
- Process the input data and return the generated SVG as PIL Image.
103
 
104
  Args:
105
  data: Dictionary containing:
106
- - inputs: Text prompt for SVG generation
107
- - parameters: Optional parameters like num_paths, num_iter, etc.
108
 
109
  Returns:
110
- PIL Image of the generated SVG
111
  """
112
  try:
113
  # Extract inputs
114
- prompt = data.get("inputs", "")
115
- if not prompt:
116
- # Return a white image with error text
117
- img = Image.new('RGB', (224, 224), color='white')
118
- return img
 
119
 
120
- # If dependencies aren't available, return a mock response
121
- if not DEPENDENCIES_AVAILABLE:
122
- mock_svg = f'''<svg width="224" height="224" xmlns="http://www.w3.org/2000/svg">
123
- <rect width="224" height="224" fill="white"/>
124
- <text x="112" y="112" text-anchor="middle" font-family="Arial" font-size="12" fill="black">
125
- Mock SVG for: {prompt}
126
- </text>
127
- </svg>'''
128
- return svg_to_pil_image(mock_svg, 224, 224)
129
 
130
  # Extract parameters
131
- parameters = data.get("parameters", {})
132
- num_paths = parameters.get("num_paths", self.cfg.num_paths)
133
- num_iter = parameters.get("num_iter", self.cfg.num_iter)
134
- token_ind = parameters.get("token_ind", self.cfg.token_ind)
135
- guidance_scale = parameters.get("guidance_scale", self.cfg.guidance_scale)
136
- canvas_size = parameters.get("canvas_size", self.cfg.painter.canvas_size)
137
 
138
- # For now, return a simple SVG since the full implementation requires
139
- # the complete DiffSketcher pipeline which is complex to set up
140
- simple_svg = f'''<svg width="{canvas_size}" height="{canvas_size}" xmlns="http://www.w3.org/2000/svg">
141
- <rect width="{canvas_size}" height="{canvas_size}" fill="white"/>
142
- <circle cx="{canvas_size//2}" cy="{canvas_size//2}" r="{canvas_size//4}"
143
- fill="none" stroke="black" stroke-width="2"/>
144
- <text x="{canvas_size//2}" y="{canvas_size//2}" text-anchor="middle"
145
- font-family="Arial" font-size="14" fill="black">
146
- {prompt[:20]}...
147
- </text>
148
- </svg>'''
 
 
 
149
 
150
- return svg_to_pil_image(simple_svg, canvas_size, canvas_size)
151
-
152
  except Exception as e:
153
- # Return a white image on error
154
- img = Image.new('RGB', (224, 224), color='white')
155
- return img
156
-
157
-
158
- # For testing
159
- if __name__ == "__main__":
160
- handler = EndpointHandler()
161
- test_data = {
162
- "inputs": "a beautiful mountain landscape",
163
- "parameters": {
164
- "num_paths": 48,
165
- "num_iter": 100
166
- }
167
- }
168
- result = handler(test_data)
169
- print(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import base64
 
 
 
 
 
2
  import json
3
+ import torch
4
+ import svgwrite
5
+ import math
6
+ from typing import Dict, Any
7
+ from diffusers import StableDiffusionPipeline
8
+ import clip
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
class EndpointHandler:
    """Endpoint handler that answers a text prompt with a template-based SVG sketch.

    The handler loads a Stable Diffusion pipeline and a CLIP model at start-up,
    but none of the drawing code in this class reads them: the SVG is assembled
    from fixed shape templates selected by keywords found in the prompt.
    """

    def __init__(self, path=""):
        """Load the optional models and record the compute device.

        Args:
            path: Accepted for interface compatibility; not read by this class.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load Stable Diffusion model (best effort: a failure only disables it).
        try:
            self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                safety_checker=None,
                requires_safety_checker=False,
            ).to(self.device)
        except Exception as e:
            print(f"Warning: Could not load Stable Diffusion: {e}")
            self.sd_pipeline = None

        # Pre-assign both CLIP attributes so that a failed load never leaves
        # `clip_preprocess` undefined on the instance.
        self.clip_model = None
        self.clip_preprocess = None
        try:
            self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
        except Exception as e:
            print(f"Warning: Could not load CLIP: {e}")

    @staticmethod
    def _positive_int(name, value):
        """Return ``value`` as an ``int`` >= 1 or raise ``ValueError`` naming ``name``.

        Integer-valued floats and numeric strings are accepted; booleans,
        fractional floats and anything ``int()`` rejects are refused.
        """
        message = f"Parameter '{name}' must be a positive integer, got {value!r}"
        # bool is a subclass of int, so it has to be rejected explicitly.
        if isinstance(value, bool):
            raise ValueError(message)
        if isinstance(value, float) and not value.is_integer():
            raise ValueError(message)
        try:
            number = int(value)
        except (TypeError, ValueError, OverflowError):
            raise ValueError(message) from None
        if number < 1:
            raise ValueError(message)
        return number

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process the request and generate an SVG.

        Args:
            data: Dictionary containing:
                - inputs: The prompt text
                - parameters: Optional parameters ``num_paths``, ``width``, ``height``
              Any non-dict value is converted with ``str()`` and used as the prompt.

        Returns:
            On success, a dictionary with ``svg_content``, ``svg_base64``,
            ``model``, ``prompt`` and the effective ``parameters``.
            On failure, a dictionary with a single ``error`` message.
        """
        try:
            # Extract inputs
            if isinstance(data, dict):
                prompt = data.get("inputs", "")
                parameters = data.get("parameters", {})
            else:
                prompt = "" if data is None else str(data)
                parameters = {}

            # Normalise the prompt: a missing/None value means "no prompt",
            # any other non-string value is rendered as text.
            if prompt is None:
                prompt = ""
            elif not isinstance(prompt, str):
                prompt = str(prompt)
            if not prompt.strip():
                return {"error": "No prompt provided"}

            # `"parameters": null` (or any non-dict) falls back to the defaults.
            if not isinstance(parameters, dict):
                parameters = {}

            # Extract and validate parameters before any drawing work starts.
            try:
                num_paths = self._positive_int("num_paths", parameters.get("num_paths", 16))
                width = self._positive_int("width", parameters.get("width", 512))
                height = self._positive_int("height", parameters.get("height", 512))
            except ValueError as exc:
                return {"error": f"Invalid parameters: {exc}"}

            # Generate SVG using DiffSketcher style
            svg_content = self.generate_diffsketcher_svg(prompt, num_paths, width, height)

            # Encode as base64
            svg_base64 = base64.b64encode(svg_content.encode('utf-8')).decode('utf-8')

            return {
                "svg_content": svg_content,
                "svg_base64": svg_base64,
                "model": "DiffSketcher",
                "prompt": prompt,
                "parameters": {
                    "num_paths": num_paths,
                    "width": width,
                    "height": height,
                },
            }

        except Exception as e:
            # Endpoint boundary: report the failure in the payload instead of raising.
            return {"error": f"Generation failed: {str(e)}"}

    def generate_diffsketcher_svg(self, prompt, num_paths, width, height):
        """Build the SVG document for ``prompt`` and return it as a string.

        The sketch template is chosen by the first keyword group that occurs in
        the lower-cased prompt; the abstract template is the fallback. The
        prompt is also written as a caption near the bottom-left corner.
        """
        dwg = svgwrite.Drawing(size=(f"{width}px", f"{height}px"))

        # White background
        dwg.add(dwg.rect(insert=(0, 0), size=("100%", "100%"), fill="white"))

        center_x, center_y = width // 2, height // 2
        prompt_lower = prompt.lower()

        # Ordered dispatch table: the first matching group wins.
        # NOTE: keywords are matched as substrings, so e.g. "education"
        # contains "cat" and selects the cat template.
        sketchers = (
            (("cat", "animal", "pet"), self._draw_cat_sketch),
            (("flower", "plant", "bloom"), self._draw_flower_sketch),
            (("house", "building", "home"), self._draw_house_sketch),
            (("mountain", "landscape", "nature"), self._draw_landscape_sketch),
        )
        draw = self._draw_abstract_sketch
        for keywords, candidate in sketchers:
            if any(keyword in prompt_lower for keyword in keywords):
                draw = candidate
                break
        draw(dwg, center_x, center_y, num_paths)

        # Add prompt as text
        dwg.add(dwg.text(f"DiffSketcher: {prompt}", insert=(10, height - 10),
                         fill="gray", font_size="12px"))

        return dwg.tostring()

    def _draw_cat_sketch(self, dwg, cx, cy, num_paths):
        """Draw a cat (head, ears, eyes, nose, whiskers, body) around ``(cx, cy)``.

        ``num_paths`` is accepted for a uniform drawer signature and is not used.
        """
        # Head
        dwg.add(dwg.circle(center=(cx, cy - 20), r=60, fill="none", stroke="black", stroke_width=3))

        # Ears
        dwg.add(dwg.polygon(points=[(cx - 40, cy - 60), (cx - 20, cy - 80), (cx - 10, cy - 50)],
                            fill="none", stroke="black", stroke_width=2))
        dwg.add(dwg.polygon(points=[(cx + 40, cy - 60), (cx + 20, cy - 80), (cx + 10, cy - 50)],
                            fill="none", stroke="black", stroke_width=2))

        # Eyes
        dwg.add(dwg.circle(center=(cx - 20, cy - 10), r=8, fill="black"))
        dwg.add(dwg.circle(center=(cx + 20, cy - 10), r=8, fill="black"))

        # Nose
        dwg.add(dwg.polygon(points=[(cx - 5, cy + 10), (cx + 5, cy + 10), (cx, cy + 20)], fill="pink"))

        # Whiskers
        dwg.add(dwg.line(start=(cx - 50, cy), end=(cx - 70, cy - 5), stroke="black", stroke_width=1))
        dwg.add(dwg.line(start=(cx + 50, cy), end=(cx + 70, cy - 5), stroke="black", stroke_width=1))

        # Body
        dwg.add(dwg.ellipse(center=(cx, cy + 80), r=(40, 60), fill="none", stroke="black", stroke_width=3))

    def _draw_flower_sketch(self, dwg, cx, cy, num_paths):
        """Draw a flower (eight petals, centre, stem, two leaves) around ``(cx, cy)``.

        ``num_paths`` is accepted for a uniform drawer signature and is not used.
        """
        # Petals in a circle, one every 45 degrees, each rotated about its own centre.
        for angle in range(0, 360, 45):
            petal_x = cx + 50 * math.cos(math.radians(angle))
            petal_y = cy + 50 * math.sin(math.radians(angle))

            dwg.add(dwg.ellipse(
                center=(petal_x, petal_y),
                r=(20, 35),
                fill="pink",
                stroke="red",
                stroke_width=2,
                transform=f"rotate({angle} {petal_x} {petal_y})",
            ))

        # Center
        dwg.add(dwg.circle(center=(cx, cy), r=15, fill="yellow", stroke="orange", stroke_width=2))

        # Stem
        dwg.add(dwg.line(start=(cx, cy + 15), end=(cx, cy + 120), stroke="green", stroke_width=4))

        # Leaves
        dwg.add(dwg.ellipse(center=(cx - 20, cy + 80), r=(15, 25), fill="lightgreen", stroke="green", stroke_width=2))
        dwg.add(dwg.ellipse(center=(cx + 20, cy + 90), r=(15, 25), fill="lightgreen", stroke="green", stroke_width=2))

    def _draw_house_sketch(self, dwg, cx, cy, num_paths):
        """Draw a house (base, roof, door, two windows) around ``(cx, cy)``.

        ``num_paths`` is accepted for a uniform drawer signature and is not used.
        """
        # House base
        dwg.add(dwg.rect(insert=(cx - 50, cy), size=(100, 60), fill="lightblue", stroke="blue", stroke_width=3))

        # Roof
        dwg.add(dwg.polygon(points=[(cx - 60, cy), (cx, cy - 50), (cx + 60, cy)],
                            fill="red", stroke="darkred", stroke_width=2))

        # Door
        dwg.add(dwg.rect(insert=(cx - 15, cy + 20), size=(30, 40), fill="brown"))

        # Windows
        dwg.add(dwg.rect(insert=(cx - 40, cy + 15), size=(20, 20), fill="lightblue", stroke="blue", stroke_width=2))
        dwg.add(dwg.rect(insert=(cx + 20, cy + 15), size=(20, 20), fill="lightblue", stroke="blue", stroke_width=2))

    def _draw_landscape_sketch(self, dwg, cx, cy, num_paths):
        """Draw a landscape (mountain ridge plus four trees) around ``(cx, cy)``.

        ``num_paths`` is accepted for a uniform drawer signature and is not used.
        """
        # Mountains
        dwg.add(dwg.polygon(points=[(0, cy), (cx - 100, cy - 100), (cx - 50, cy - 80), (cx, cy - 120),
                                    (cx + 50, cy - 90), (cx + 100, cy - 110), (cx + 200, cy)],
                            fill="lightgray", stroke="gray", stroke_width=2))

        # Trees
        tree_y = cy + 20
        for tree_x in (cx - 80, cx - 20, cx + 40, cx + 100):
            # Trunk
            dwg.add(dwg.rect(insert=(tree_x - 5, tree_y), size=(10, 40), fill="brown"))
            # Foliage
            dwg.add(dwg.circle(center=(tree_x, tree_y - 10), r=25, fill="green", stroke="darkgreen", stroke_width=2))

    def _draw_abstract_sketch(self, dwg, cx, cy, num_paths):
        """Draw up to twelve randomly placed outline shapes around ``(cx, cy)``.

        Shapes cycle circle / square / triangle; position, size and colour come
        from the unseeded ``random`` module, so the output differs between calls.
        """
        import random

        colors = ["red", "blue", "green", "orange", "purple", "pink", "yellow"]

        # At most 12 shapes are drawn regardless of how many paths were requested.
        for i in range(min(num_paths, 12)):
            x = cx + random.randint(-150, 150)
            y = cy + random.randint(-150, 150)
            r = random.randint(20, 60)
            color = random.choice(colors)

            if i % 3 == 0:
                dwg.add(dwg.circle(center=(x, y), r=r, fill="none", stroke=color, stroke_width=3))
            elif i % 3 == 1:
                dwg.add(dwg.rect(insert=(x - r // 2, y - r // 2), size=(r, r),
                                 fill="none", stroke=color, stroke_width=2))
            else:
                points = [(x, y - r), (x + r, y + r), (x - r, y + r)]
                dwg.add(dwg.polygon(points=points, fill="none", stroke=color, stroke_width=2))