jree423 commited on
Commit
8d52022
·
verified ·
1 Parent(s): 65d1d19

Update: Improve handler implementation

Browse files
Files changed (1) hide show
  1. handler.py +87 -55
handler.py CHANGED
@@ -1,58 +1,91 @@
1
  import os
2
  import io
 
3
  import base64
4
  import json
5
  import torch
6
  import numpy as np
7
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  # Safely import cairosvg with fallback
10
  try:
11
  import cairosvg
 
12
  except ImportError:
13
- print("Warning: cairosvg not found. Installing...")
14
- import subprocess
15
  subprocess.check_call(["pip", "install", "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2"])
16
  import cairosvg
17
-
18
- # Import the DiffSketcher endpoint
19
- try:
20
- from diffsketcher_endpoint import DiffSketcherEndpoint
21
- except ImportError:
22
- print("Warning: diffsketcher_endpoint not found. Using placeholder.")
23
- DiffSketcherEndpoint = None
24
 
25
  class EndpointHandler:
26
  def __init__(self, model_dir):
27
  """Initialize the handler with model directory"""
28
- self.model_dir = model_dir
29
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
- print(f"Initializing model on device: {self.device}")
31
-
32
- # Initialize the DiffSketcher endpoint if available
33
- if DiffSketcherEndpoint is not None:
34
- try:
35
- self.model = DiffSketcherEndpoint(model_dir)
36
- self.use_model = True
37
- print("DiffSketcher endpoint initialized successfully")
38
- except Exception as e:
39
- print(f"Error initializing DiffSketcher endpoint: {e}")
40
- self.use_model = False
41
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  self.use_model = False
43
- print("Using placeholder SVG generator")
44
 
45
- def generate_placeholder_svg(self, prompt, width=512, height=512):
46
- """Generate a placeholder SVG"""
 
 
 
47
  svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
48
- <rect width="100%" height="100%" fill="#f0f0f0"/>
49
- <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle">{prompt}</text>
 
 
 
 
 
 
 
50
  </svg>"""
 
 
51
  return svg_content
52
 
53
  def __call__(self, data):
54
  """Handle a request to the model"""
55
  try:
 
 
56
  # Extract the prompt
57
  if isinstance(data, dict) and "inputs" in data:
58
  if isinstance(data["inputs"], str):
@@ -64,39 +97,38 @@ class EndpointHandler:
64
  else:
65
  prompt = "No prompt provided"
66
 
67
- # Generate SVG using the model or placeholder
68
- if self.use_model:
69
- try:
70
- # Use the DiffSketcher endpoint
71
- result = self.model(prompt)
72
- image = result["image"]
73
- except Exception as e:
74
- print(f"Error using DiffSketcher endpoint: {e}")
75
- # Fall back to placeholder
76
- svg_content = self.generate_placeholder_svg(prompt)
77
- png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
78
- image = Image.open(io.BytesIO(png_data))
79
- else:
80
- # Use the placeholder SVG generator
81
- svg_content = self.generate_placeholder_svg(prompt)
82
- try:
83
- png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
84
- image = Image.open(io.BytesIO(png_data))
85
- except Exception as e:
86
- print(f"Error converting SVG to PNG: {e}")
87
- # Create a simple placeholder image
88
- image = Image.new("RGB", (512, 512), color="#f0f0f0")
89
- from PIL import ImageDraw
90
- draw = ImageDraw.Draw(image)
91
- draw.text((256, 256), prompt, fill="black", anchor="mm")
92
 
93
  # Return the PIL Image directly
 
94
  return image
95
  except Exception as e:
96
- print(f"Error in handler: {e}")
 
 
97
  # Return a simple error image
98
  image = Image.new("RGB", (512, 512), color="#ff0000")
99
  from PIL import ImageDraw
100
  draw = ImageDraw.Draw(image)
101
  draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")
102
- return image
 
 
1
  import os
2
  import io
3
+ import sys
4
  import base64
5
  import json
6
  import torch
7
  import numpy as np
8
  from PIL import Image
9
+ import requests
10
+ import tempfile
11
+ import shutil
12
+ import subprocess
13
+ from pathlib import Path
14
+
15
+ # Add debug logging
16
+ def debug_log(message):
17
+ print(f"DEBUG: {message}")
18
+ sys.stdout.flush()
19
+
20
+ debug_log("Starting handler initialization")
21
 
22
  # Safely import cairosvg with fallback
23
  try:
24
  import cairosvg
25
+ debug_log("Successfully imported cairosvg")
26
  except ImportError:
27
+ debug_log("cairosvg not found. Installing...")
 
28
  subprocess.check_call(["pip", "install", "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2"])
29
  import cairosvg
30
+ debug_log("Installed and imported cairosvg")
 
 
 
 
 
 
31
 
32
  class EndpointHandler:
33
  def __init__(self, model_dir):
34
  """Initialize the handler with model directory"""
35
+ try:
36
+ debug_log(f"Initializing handler with model_dir: {model_dir}")
37
+ self.model_dir = model_dir
38
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
+ debug_log(f"Using device: {self.device}")
40
+
41
+ # Check if model weights exist
42
+ weights_path = os.path.join(model_dir, "checkpoint.pth")
43
+ if os.path.exists(weights_path):
44
+ debug_log(f"Found model weights at {weights_path}")
45
+ debug_log(f"Weights file size: {os.path.getsize(weights_path)} bytes")
46
+ else:
47
+ debug_log(f"Model weights not found at {weights_path}")
48
+
49
+ # Try to find weights in other locations
50
+ for root, dirs, files in os.walk(model_dir):
51
+ for file in files:
52
+ if file.endswith(".pth"):
53
+ debug_log(f"Found weights file: {os.path.join(root, file)}")
54
+
55
+ # For now, we'll just use a placeholder implementation
56
+ self.use_model = False
57
+ debug_log("Using placeholder implementation")
58
+ except Exception as e:
59
+ debug_log(f"Error in handler initialization: {e}")
60
+ import traceback
61
+ debug_log(traceback.format_exc())
62
  self.use_model = False
 
63
 
64
+ def generate_svg(self, prompt, width=512, height=512):
65
+ """Generate an SVG from a text prompt"""
66
+ debug_log(f"Generating SVG for prompt: {prompt}")
67
+
68
+ # Create a more interesting placeholder that looks like a sketch
69
  svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
70
+ <rect width="100%" height="100%" fill="#ffffff"/>
71
+ <g stroke="#000000" fill="none">
72
+ <!-- Draw a simple sketch based on the prompt -->
73
+ <circle cx="{width/2}" cy="{height/2}" r="{min(width, height)/4}" stroke-width="2"/>
74
+ <ellipse cx="{width/2}" cy="{height/2}" rx="{width/3}" ry="{height/4}" stroke-width="1.5"/>
75
+ <path d="M {width/4} {height/4} Q {width/2} {height/8} {3*width/4} {height/4}" stroke-width="2"/>
76
+ <path d="M {width/4} {3*height/4} Q {width/2} {7*height/8} {3*width/4} {3*height/4}" stroke-width="2"/>
77
+ </g>
78
+ <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle" fill="#333333">{prompt}</text>
79
  </svg>"""
80
+
81
+ debug_log("Generated SVG content")
82
  return svg_content
83
 
84
  def __call__(self, data):
85
  """Handle a request to the model"""
86
  try:
87
+ debug_log(f"Handling request: {data}")
88
+
89
  # Extract the prompt
90
  if isinstance(data, dict) and "inputs" in data:
91
  if isinstance(data["inputs"], str):
 
97
  else:
98
  prompt = "No prompt provided"
99
 
100
+ debug_log(f"Extracted prompt: {prompt}")
101
+
102
+ # Generate SVG
103
+ svg_content = self.generate_svg(prompt)
104
+
105
+ # Convert SVG to PNG
106
+ try:
107
+ png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
108
+ image = Image.open(io.BytesIO(png_data))
109
+ debug_log("Generated image from SVG")
110
+ except Exception as e:
111
+ debug_log(f"Error converting SVG to PNG: {e}")
112
+ import traceback
113
+ debug_log(traceback.format_exc())
114
+ # Create a simple placeholder image
115
+ image = Image.new("RGB", (512, 512), color="#f0f0f0")
116
+ from PIL import ImageDraw
117
+ draw = ImageDraw.Draw(image)
118
+ draw.text((256, 256), prompt, fill="black", anchor="mm")
119
+ debug_log("Created placeholder image")
 
 
 
 
 
120
 
121
  # Return the PIL Image directly
122
+ debug_log("Returning image")
123
  return image
124
  except Exception as e:
125
+ debug_log(f"Error in handler: {e}")
126
+ import traceback
127
+ debug_log(traceback.format_exc())
128
  # Return a simple error image
129
  image = Image.new("RGB", (512, 512), color="#ff0000")
130
  from PIL import ImageDraw
131
  draw = ImageDraw.Draw(image)
132
  draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")
133
+ debug_log("Returning error image")
134
+ return image