jree423 commited on
Commit
54f01ed
·
verified ·
1 Parent(s): 88dea80

Fix: Update handler.py with EndpointHandler class

Browse files
Files changed (1) hide show
  1. handler.py +46 -39
handler.py CHANGED
@@ -15,45 +15,52 @@ except ImportError:
15
  subprocess.check_call(["pip", "install", "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2"])
16
  import cairosvg
17
 
18
- def handler(inputs, **kwargs):
19
- """Simple handler function that returns a placeholder SVG"""
20
- try:
21
- # Extract the prompt
22
- if isinstance(inputs, dict) and "inputs" in inputs:
23
- if isinstance(inputs["inputs"], str):
24
- prompt = inputs["inputs"]
25
- elif isinstance(inputs["inputs"], dict) and "text" in inputs["inputs"]:
26
- prompt = inputs["inputs"]["text"]
 
 
 
 
 
 
 
 
 
27
  else:
28
  prompt = "No prompt provided"
29
- else:
30
- prompt = "No prompt provided"
31
-
32
- # Generate a placeholder SVG
33
- width, height = 512, 512
34
- svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
35
- <rect width="100%" height="100%" fill="#f0f0f0"/>
36
- <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle">{prompt}</text>
37
- </svg>"""
38
-
39
- # Convert SVG to base64
40
- svg_bytes = svg_content.encode("utf-8")
41
- svg_base64 = base64.b64encode(svg_bytes).decode("utf-8")
42
-
43
- # Convert SVG to PNG using cairosvg
44
- try:
45
- png_data = cairosvg.svg2png(bytestring=svg_bytes)
46
- png_base64 = base64.b64encode(png_data).decode("utf-8")
 
 
 
 
 
 
 
 
 
47
  except Exception as e:
48
- print(f"Error converting SVG to PNG: {e}")
49
- # Return a transparent 1x1 pixel PNG as fallback
50
- png_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
51
-
52
- # Return the results
53
- return {
54
- "svg": svg_content,
55
- "svg_base64": svg_base64,
56
- "png_base64": png_base64
57
- }
58
- except Exception as e:
59
- return {"error": str(e)}
 
15
  subprocess.check_call(["pip", "install", "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2"])
16
  import cairosvg
17
 
18
+ class EndpointHandler:
19
+ def __init__(self, model_dir):
20
+ """Initialize the handler with model directory"""
21
+ self.model_dir = model_dir
22
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ print(f"Initialized model on device: {self.device}")
24
+
25
+ def __call__(self, data):
26
+ """Handle a request to the model"""
27
+ try:
28
+ # Extract the prompt
29
+ if isinstance(data, dict) and "inputs" in data:
30
+ if isinstance(data["inputs"], str):
31
+ prompt = data["inputs"]
32
+ elif isinstance(data["inputs"], dict) and "text" in data["inputs"]:
33
+ prompt = data["inputs"]["text"]
34
+ else:
35
+ prompt = "No prompt provided"
36
  else:
37
  prompt = "No prompt provided"
38
+
39
+ # Generate a placeholder SVG
40
+ width, height = 512, 512
41
+ svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
42
+ <rect width="100%" height="100%" fill="#f0f0f0"/>
43
+ <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle">{prompt}</text>
44
+ </svg>"""
45
+
46
+ # Convert SVG to base64
47
+ svg_bytes = svg_content.encode("utf-8")
48
+ svg_base64 = base64.b64encode(svg_bytes).decode("utf-8")
49
+
50
+ # Convert SVG to PNG using cairosvg
51
+ try:
52
+ png_data = cairosvg.svg2png(bytestring=svg_bytes)
53
+ png_base64 = base64.b64encode(png_data).decode("utf-8")
54
+ except Exception as e:
55
+ print(f"Error converting SVG to PNG: {e}")
56
+ # Return a transparent 1x1 pixel PNG as fallback
57
+ png_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
58
+
59
+ # Return the results
60
+ return {
61
+ "svg": svg_content,
62
+ "svg_base64": svg_base64,
63
+ "png_base64": png_base64
64
+ }
65
  except Exception as e:
66
+ return {"error": str(e)}