jree423 commited on
Commit
55dc40f
·
verified ·
1 Parent(s): 5b99d31

Fix: Update handler.py with robust cairosvg import and PNG conversion

Browse files
Files changed (1) hide show
  1. handler.py +28 -3
handler.py CHANGED
@@ -5,8 +5,23 @@ import json
5
  import torch
6
  import numpy as np
7
  from PIL import Image
8
- import cairosvg
9
- from diffusers import StableDiffusionPipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  class ModelHandler:
12
  def __init__(self):
@@ -68,9 +83,19 @@ class ModelHandler:
68
  svg_bytes = inference_output.encode('utf-8')
69
  svg_base64 = base64.b64encode(svg_bytes).decode('utf-8')
70
 
 
 
 
 
 
 
 
 
 
71
  return {
72
  "svg": inference_output,
73
- "svg_base64": svg_base64
 
74
  }
75
 
76
  def handle(self, data):
 
5
  import torch
6
  import numpy as np
7
  from PIL import Image
8
+
9
+ # Safely import cairosvg with fallback
10
+ try:
11
+ import cairosvg
12
+ except ImportError:
13
+ print("Warning: cairosvg not found. Installing...")
14
+ import subprocess
15
+ subprocess.check_call(["pip", "install", "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2"])
16
+ import cairosvg
17
+
18
+ try:
19
+ from diffusers import StableDiffusionPipeline
20
+ except ImportError:
21
+ print("Warning: diffusers not found. Installing...")
22
+ import subprocess
23
+ subprocess.check_call(["pip", "install", "diffusers", "transformers", "accelerate"])
24
+ from diffusers import StableDiffusionPipeline
25
 
26
  class ModelHandler:
27
  def __init__(self):
 
83
  svg_bytes = inference_output.encode('utf-8')
84
  svg_base64 = base64.b64encode(svg_bytes).decode('utf-8')
85
 
86
+ # Convert SVG to PNG using cairosvg
87
+ try:
88
+ png_data = cairosvg.svg2png(bytestring=svg_bytes)
89
+ png_base64 = base64.b64encode(png_data).decode('utf-8')
90
+ except Exception as e:
91
+ print(f"Error converting SVG to PNG: {e}")
92
+ # Return a transparent 1x1 pixel PNG as fallback
93
+ png_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
94
+
95
  return {
96
  "svg": inference_output,
97
+ "svg_base64": svg_base64,
98
+ "png_base64": png_base64
99
  }
100
 
101
  def handle(self, data):