jree423 commited on
Commit
d87b721
·
verified ·
1 Parent(s): ab51276

Fix: Update handler.py with simplified implementation

Browse files
Files changed (1) hide show
  1. handler.py +25 -87
handler.py CHANGED
@@ -15,107 +15,45 @@ except ImportError:
15
  subprocess.check_call(["pip", "install", "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2"])
16
  import cairosvg
17
 
18
- try:
19
- from diffusers import StableDiffusionPipeline
20
- except ImportError:
21
- print("Warning: diffusers not found. Installing...")
22
- import subprocess
23
- subprocess.check_call(["pip", "install", "diffusers", "transformers", "accelerate"])
24
- from diffusers import StableDiffusionPipeline
25
-
26
- class ModelHandler:
27
- def __init__(self):
28
- self.initialized = False
29
- self.model = None
30
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
31
-
32
- def initialize(self, model_dir):
33
- """Initialize the model"""
34
- self.model = StableDiffusionPipeline.from_pretrained(
35
- model_dir,
36
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
37
- ).to(self.device)
38
- self.initialized = True
39
- return self.initialized
40
-
41
- def preprocess(self, data):
42
- """Preprocess the input data"""
43
- inputs = data.get("inputs", {})
44
-
45
- if isinstance(inputs, str):
46
- # Text-to-image case
47
- prompt = inputs
48
- image = None
49
- else:
50
- # Image-to-image case
51
- prompt = inputs.get("text", "")
52
- image_b64 = inputs.get("image", None)
53
-
54
- if image_b64:
55
- image_data = base64.b64decode(image_b64)
56
- image = Image.open(io.BytesIO(image_data))
57
  else:
58
- image = None
59
-
60
- return {"prompt": prompt, "image": image}
61
-
62
- def inference(self, inputs):
63
- """Run inference with the model"""
64
- prompt = inputs["prompt"]
65
- image = inputs["image"]
66
-
67
- # Generate image
68
- if image is None:
69
- # Text-to-image generation
70
- result = self.model(prompt).images[0]
71
  else:
72
- # Image-to-image generation
73
- result = self.model(prompt, image=image).images[0]
74
 
75
- # Convert to SVG (placeholder - actual conversion would depend on the specific model)
76
- svg_content = f'<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512"><text x="10" y="20">Generated from: {prompt}</text></svg>'
 
 
 
 
77
 
78
- return svg_content
79
-
80
- def postprocess(self, inference_output):
81
- """Postprocess the model output"""
82
- # Convert SVG to base64 for response
83
- svg_bytes = inference_output.encode('utf-8')
84
- svg_base64 = base64.b64encode(svg_bytes).decode('utf-8')
85
 
86
  # Convert SVG to PNG using cairosvg
87
  try:
88
  png_data = cairosvg.svg2png(bytestring=svg_bytes)
89
- png_base64 = base64.b64encode(png_data).decode('utf-8')
90
  except Exception as e:
91
  print(f"Error converting SVG to PNG: {e}")
92
  # Return a transparent 1x1 pixel PNG as fallback
93
  png_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
94
 
 
95
  return {
96
- "svg": inference_output,
97
  "svg_base64": svg_base64,
98
  "png_base64": png_base64
99
  }
100
-
101
- def handle(self, data):
102
- """Handle a request to the model"""
103
- try:
104
- if not self.initialized:
105
- self.initialize("model")
106
-
107
- if data is None:
108
- return {"error": "No input data provided"}
109
-
110
- # Preprocess
111
- inputs = self.preprocess(data)
112
-
113
- # Inference
114
- outputs = self.inference(inputs)
115
-
116
- # Postprocess
117
- processed_outputs = self.postprocess(outputs)
118
-
119
- return processed_outputs
120
- except Exception as e:
121
- return {"error": str(e)}
 
15
  subprocess.check_call(["pip", "install", "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2"])
16
  import cairosvg
17
 
18
+ def handler(inputs, **kwargs):
19
+ """Simple handler function that returns a placeholder SVG"""
20
+ try:
21
+ # Extract the prompt
22
+ if isinstance(inputs, dict) and "inputs" in inputs:
23
+ if isinstance(inputs["inputs"], str):
24
+ prompt = inputs["inputs"]
25
+ elif isinstance(inputs["inputs"], dict) and "text" in inputs["inputs"]:
26
+ prompt = inputs["inputs"]["text"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  else:
28
+ prompt = "No prompt provided"
 
 
 
 
 
 
 
 
 
 
 
 
29
  else:
30
+ prompt = "No prompt provided"
 
31
 
32
+ # Generate a placeholder SVG
33
+ width, height = 512, 512
34
+ svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
35
+ <rect width="100%" height="100%" fill="#f0f0f0"/>
36
+ <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle">{prompt}</text>
37
+ </svg>"""
38
 
39
+ # Convert SVG to base64
40
+ svg_bytes = svg_content.encode("utf-8")
41
+ svg_base64 = base64.b64encode(svg_bytes).decode("utf-8")
 
 
 
 
42
 
43
  # Convert SVG to PNG using cairosvg
44
  try:
45
  png_data = cairosvg.svg2png(bytestring=svg_bytes)
46
+ png_base64 = base64.b64encode(png_data).decode("utf-8")
47
  except Exception as e:
48
  print(f"Error converting SVG to PNG: {e}")
49
  # Return a transparent 1x1 pixel PNG as fallback
50
  png_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
51
 
52
+ # Return the results
53
  return {
54
+ "svg": svg_content,
55
  "svg_base64": svg_base64,
56
  "png_base64": png_base64
57
  }
58
+ except Exception as e:
59
+ return {"error": str(e)}