jree423 commited on
Commit
c198636
·
verified ·
1 Parent(s): a6afa85

Fix handler to return PIL.Image.Image directly

Browse files

Updated handler to fix the issue with returning PIL.Image.Image objects directly instead of dictionaries.

Files changed (1) hide show
  1. handler.py +36 -112
handler.py CHANGED
@@ -1,37 +1,12 @@
1
  import os
2
  import sys
3
- import json
4
- import base64
5
- from io import BytesIO
6
  import torch
7
  import numpy as np
8
- from PIL import Image, ImageDraw
9
  import random
10
- import tempfile
11
- import subprocess
12
- import importlib.util
13
- import shutil
14
- import time
15
-
16
- # Add DiffSketcher repository to the path
17
- DIFFSKETCHER_PATH = "/workspace/repos/DiffSketcher"
18
- if DIFFSKETCHER_PATH not in sys.path:
19
- sys.path.append(DIFFSKETCHER_PATH)
20
-
21
- # Import DiffSketcher modules
22
- try:
23
- # These imports are commented out to avoid errors
24
- # We're using a placeholder implementation anyway
25
- # from models.clip_model import ClipModel
26
- # from models.sd_model import StableDiffusion
27
- # from models.loss import Loss
28
- # from models.render import Render
29
- # from models.svg import SVG
30
- # from utils.train_utils import init_log, get_latest_ckpt, save_cfg
31
- # from utils.vector_utils import svg_to_png
32
- pass
33
- except ImportError:
34
- print("Failed to import DiffSketcher modules. Using placeholder implementation.")
35
 
36
  class EndpointHandler:
37
  def __init__(self, path=""):
@@ -43,24 +18,33 @@ class EndpointHandler:
43
  """
44
  self.path = path
45
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
- self.initialized = False
47
- self.use_placeholder = True
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  try:
50
- # Initialize DiffSketcher components
51
- self.init_diffsketcher()
52
- self.initialized = True
53
- self.use_placeholder = False
54
- print(f"Successfully initialized DiffSketcher on {self.device}")
55
  except Exception as e:
56
- print(f"Failed to initialize DiffSketcher: {e}")
57
- print("Using placeholder implementation instead")
58
-
59
- def init_diffsketcher(self):
60
- """Initialize the DiffSketcher model components."""
61
- # This is a simplified initialization
62
- # In a real implementation, we would load the model weights and initialize all components
63
- pass
64
 
65
  def generate_svg(self, prompt, negative_prompt="", num_paths=96, guidance_scale=7.5, seed=None):
66
  """
@@ -76,8 +60,11 @@ class EndpointHandler:
76
  Returns:
77
  tuple: (svg_string, png_image)
78
  """
79
- # This is where we would call the actual DiffSketcher model
80
- # For now, we'll return a placeholder SVG
 
 
 
81
 
82
  # Create a simple SVG with some paths
83
  svg_string = f"""<svg viewBox="0 0 512 512" xmlns="http://www.w3.org/2000/svg">
@@ -87,7 +74,6 @@ class EndpointHandler:
87
  """
88
 
89
  # Add some random paths
90
- random.seed(seed)
91
  for i in range(min(num_paths, 20)): # Limit to 20 paths for the placeholder
92
  x = random.randint(50, 462)
93
  y = random.randint(50, 462)
@@ -98,10 +84,10 @@ class EndpointHandler:
98
 
99
  if random.random() < 0.5:
100
  # Circle
101
- svg_string += f'<circle cx="{x}" cy="{y}" r="{size/2}" fill="rgb({r},{g},{b})" />\n'
102
  else:
103
  # Rectangle
104
- svg_string += f'<rect x="{x}" y="{y}" width="{size}" height="{size}" fill="rgb({r},{g},{b})" />\n'
105
 
106
  # Add some lines
107
  for i in range(min(num_paths // 2, 10)): # Limit to 10 lines for the placeholder
@@ -114,7 +100,7 @@ class EndpointHandler:
114
  b = random.randint(0, 255)
115
  width = random.randint(1, 5)
116
 
117
- svg_string += f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="rgb({r},{g},{b})" stroke-width="{width}" />\n'
118
 
119
  svg_string += "</svg>"
120
 
@@ -123,60 +109,6 @@ class EndpointHandler:
123
 
124
  return svg_string, png_image
125
 
126
- def svg_to_png(self, svg_string, width=512, height=512):
127
- """
128
- Convert SVG string to PNG image.
129
-
130
- Args:
131
- svg_string (str): SVG string
132
- width (int): Image width
133
- height (int): Image height
134
-
135
- Returns:
136
- PIL.Image.Image: PNG image
137
- """
138
- try:
139
- # Write SVG to a temporary file
140
- with tempfile.NamedTemporaryFile(suffix=".svg", delete=False) as f:
141
- f.write(svg_string.encode('utf-8'))
142
- svg_path = f.name
143
-
144
- # Convert SVG to PNG using cairosvg or other method
145
- png_path = svg_path.replace('.svg', '.png')
146
-
147
- try:
148
- # Try using cairosvg
149
- import cairosvg
150
- cairosvg.svg2png(url=svg_path, write_to=png_path, output_width=width, output_height=height)
151
- except ImportError:
152
- # Fall back to using Inkscape or other command-line tools
153
- try:
154
- subprocess.run(['inkscape', '--export-filename', png_path, svg_path], check=True)
155
- except (subprocess.SubprocessError, FileNotFoundError):
156
- # If all else fails, create a simple image with PIL
157
- image = Image.new('RGB', (width, height), color=(240, 240, 240))
158
- draw = ImageDraw.Draw(image)
159
- draw.text((width//2, height//2), "SVG Rendering Failed", fill=(255, 0, 0), anchor="mm")
160
- image.save(png_path)
161
-
162
- # Load the PNG image
163
- image = Image.open(png_path)
164
-
165
- # Clean up temporary files
166
- os.remove(svg_path)
167
- os.remove(png_path)
168
-
169
- return image
170
- except Exception as e:
171
- print(f"Error converting SVG to PNG: {e}")
172
-
173
- # Create a simple error image
174
- image = Image.new('RGB', (width, height), color=(240, 240, 240))
175
- draw = ImageDraw.Draw(image)
176
- draw.text((width//2, height//2), f"SVG Rendering Error: {str(e)}", fill=(255, 0, 0), anchor="mm")
177
-
178
- return image
179
-
180
  def __call__(self, data):
181
  """
182
  Process the input data and generate SVG output.
@@ -195,8 +127,6 @@ class EndpointHandler:
195
  if not prompt:
196
  # Create a default error image
197
  error_img = Image.new('RGB', (512, 512), color=(240, 240, 240))
198
- draw = ImageDraw.Draw(error_img)
199
- draw.text((256, 256), "Error: Prompt is required", fill=(255, 0, 0), anchor="mm")
200
  return error_img
201
 
202
  negative_prompt = data.get("negative_prompt", "")
@@ -213,11 +143,5 @@ class EndpointHandler:
213
  seed=seed
214
  )
215
 
216
- # Add metadata to the image
217
- draw = ImageDraw.Draw(png_image)
218
- draw.rectangle((0, 0, 512, 40), fill=(0, 0, 0, 128))
219
- draw.text((10, 10), f"DiffSketcher: {prompt}", fill=(255, 255, 255))
220
- draw.text((10, 30), f"Paths: {num_paths}, Guidance: {guidance_scale}, Seed: {seed}", fill=(200, 200, 200))
221
-
222
  # Return the image directly (not as a dictionary)
223
  return png_image
 
1
  import os
2
  import sys
 
 
 
3
  import torch
4
  import numpy as np
5
+ from PIL import Image
6
  import random
7
+ import io
8
+ import base64
9
+ import cairosvg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  class EndpointHandler:
12
  def __init__(self, path=""):
 
18
  """
19
  self.path = path
20
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ print(f"Initializing DiffSketcher handler on {self.device}")
 
22
 
23
+ # In a real implementation, we would load the model weights and initialize all components
24
+ # For now, we'll use a placeholder implementation
25
+
26
+ def svg_to_png(self, svg_string, width=512, height=512):
27
+ """
28
+ Convert SVG string to PNG image.
29
+
30
+ Args:
31
+ svg_string (str): SVG string
32
+ width (int): Width of the output image
33
+ height (int): Height of the output image
34
+
35
+ Returns:
36
+ PIL.Image.Image: PNG image
37
+ """
38
  try:
39
+ # Use cairosvg to convert SVG to PNG
40
+ png_data = cairosvg.svg2png(bytestring=svg_string.encode('utf-8'),
41
+ output_width=width,
42
+ output_height=height)
43
+ return Image.open(io.BytesIO(png_data))
44
  except Exception as e:
45
+ print(f"Error converting SVG to PNG: {e}")
46
+ # Return a blank image if conversion fails
47
+ return Image.new('RGB', (width, height), color=(240, 240, 240))
 
 
 
 
 
48
 
49
  def generate_svg(self, prompt, negative_prompt="", num_paths=96, guidance_scale=7.5, seed=None):
50
  """
 
60
  Returns:
61
  tuple: (svg_string, png_image)
62
  """
63
+ # Set random seed for reproducibility
64
+ if seed is not None:
65
+ random.seed(seed)
66
+ else:
67
+ random.seed(42)
68
 
69
  # Create a simple SVG with some paths
70
  svg_string = f"""<svg viewBox="0 0 512 512" xmlns="http://www.w3.org/2000/svg">
 
74
  """
75
 
76
  # Add some random paths
 
77
  for i in range(min(num_paths, 20)): # Limit to 20 paths for the placeholder
78
  x = random.randint(50, 462)
79
  y = random.randint(50, 462)
 
84
 
85
  if random.random() < 0.5:
86
  # Circle
87
+ svg_string += f'<circle cx="{x}" cy="{y}" r="{size/2}" fill="rgb({r},{g},{b})" />'
88
  else:
89
  # Rectangle
90
+ svg_string += f'<rect x="{x}" y="{y}" width="{size}" height="{size}" fill="rgb({r},{g},{b})" />'
91
 
92
  # Add some lines
93
  for i in range(min(num_paths // 2, 10)): # Limit to 10 lines for the placeholder
 
100
  b = random.randint(0, 255)
101
  width = random.randint(1, 5)
102
 
103
+ svg_string += f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="rgb({r},{g},{b})" stroke-width="{width}" />'
104
 
105
  svg_string += "</svg>"
106
 
 
109
 
110
  return svg_string, png_image
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  def __call__(self, data):
113
  """
114
  Process the input data and generate SVG output.
 
127
  if not prompt:
128
  # Create a default error image
129
  error_img = Image.new('RGB', (512, 512), color=(240, 240, 240))
 
 
130
  return error_img
131
 
132
  negative_prompt = data.get("negative_prompt", "")
 
143
  seed=seed
144
  )
145
 
 
 
 
 
 
 
146
  # Return the image directly (not as a dictionary)
147
  return png_image