Fix handler to return PIL Images instead of dictionaries for HF API compatibility
Browse files- handler.py +73 -29
handler.py
CHANGED
|
@@ -81,29 +81,37 @@ class EndpointHandler:
|
|
| 81 |
|
| 82 |
# Process based on edit type
|
| 83 |
if edit_type == "replace" and len(prompts) >= 2:
|
| 84 |
-
|
| 85 |
elif edit_type == "refine":
|
| 86 |
-
|
| 87 |
elif edit_type == "reweight":
|
| 88 |
-
|
| 89 |
elif edit_type == "generate":
|
| 90 |
-
|
| 91 |
else:
|
| 92 |
# Default to refinement
|
| 93 |
-
|
| 94 |
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
except Exception as e:
|
| 98 |
print(f"Error in handler: {e}")
|
| 99 |
-
# Return fallback
|
| 100 |
fallback_svg = self.create_fallback_svg(prompts[0] if prompts else "error", width, height)
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
"error": str(e)
|
| 106 |
-
}
|
| 107 |
|
| 108 |
def word_replacement_edit(self, source_prompt: str, target_prompt: str, width: int, height: int, input_svg: str = None):
|
| 109 |
"""Perform word replacement editing"""
|
|
@@ -128,9 +136,7 @@ class EndpointHandler:
|
|
| 128 |
# Apply word replacement transformations
|
| 129 |
edited_svg = self.apply_word_replacement(base_svg, source_prompt, target_prompt, added_words, removed_words, width, height)
|
| 130 |
|
| 131 |
-
|
| 132 |
-
"svg": edited_svg,
|
| 133 |
-
"svg_base64": base64.b64encode(edited_svg.encode('utf-8')).decode('utf-8'),
|
| 134 |
"edit_type": "replace",
|
| 135 |
"source_prompt": source_prompt,
|
| 136 |
"target_prompt": target_prompt,
|
|
@@ -138,9 +144,13 @@ class EndpointHandler:
|
|
| 138 |
"removed_words": list(removed_words)
|
| 139 |
}
|
| 140 |
|
|
|
|
|
|
|
| 141 |
except Exception as e:
|
| 142 |
print(f"Error in word_replacement_edit: {e}")
|
| 143 |
-
|
|
|
|
|
|
|
| 144 |
|
| 145 |
def prompt_refinement_edit(self, prompt: str, width: int, height: int, input_svg: str = None):
|
| 146 |
"""Perform prompt refinement editing"""
|
|
@@ -156,16 +166,18 @@ class EndpointHandler:
|
|
| 156 |
# Apply refinement based on prompt analysis
|
| 157 |
refined_svg = self.apply_refinement(base_svg, prompt, width, height)
|
| 158 |
|
| 159 |
-
|
| 160 |
-
"svg": refined_svg,
|
| 161 |
-
"svg_base64": base64.b64encode(refined_svg.encode('utf-8')).decode('utf-8'),
|
| 162 |
"edit_type": "refine",
|
| 163 |
"prompt": prompt
|
| 164 |
}
|
| 165 |
|
|
|
|
|
|
|
| 166 |
except Exception as e:
|
| 167 |
print(f"Error in prompt_refinement_edit: {e}")
|
| 168 |
-
|
|
|
|
|
|
|
| 169 |
|
| 170 |
def attention_reweighting_edit(self, prompt: str, width: int, height: int, input_svg: str = None):
|
| 171 |
"""Perform attention reweighting editing"""
|
|
@@ -184,18 +196,20 @@ class EndpointHandler:
|
|
| 184 |
# Apply attention reweighting
|
| 185 |
reweighted_svg = self.apply_attention_reweighting(base_svg, weighted_prompt, attention_weights, width, height)
|
| 186 |
|
| 187 |
-
|
| 188 |
-
"svg": reweighted_svg,
|
| 189 |
-
"svg_base64": base64.b64encode(reweighted_svg.encode('utf-8')).decode('utf-8'),
|
| 190 |
"edit_type": "reweight",
|
| 191 |
"prompt": prompt,
|
| 192 |
"weighted_prompt": weighted_prompt,
|
| 193 |
"attention_weights": attention_weights
|
| 194 |
}
|
| 195 |
|
|
|
|
|
|
|
| 196 |
except Exception as e:
|
| 197 |
print(f"Error in attention_reweighting_edit: {e}")
|
| 198 |
-
|
|
|
|
|
|
|
| 199 |
|
| 200 |
def simple_generation(self, prompt: str, width: int, height: int):
|
| 201 |
"""Perform simple SVG generation"""
|
|
@@ -204,16 +218,18 @@ class EndpointHandler:
|
|
| 204 |
|
| 205 |
svg_content = self.generate_base_svg(prompt, width, height)
|
| 206 |
|
| 207 |
-
|
| 208 |
-
"svg": svg_content,
|
| 209 |
-
"svg_base64": base64.b64encode(svg_content.encode('utf-8')).decode('utf-8'),
|
| 210 |
"edit_type": "generate",
|
| 211 |
"prompt": prompt
|
| 212 |
}
|
| 213 |
|
|
|
|
|
|
|
| 214 |
except Exception as e:
|
| 215 |
print(f"Error in simple_generation: {e}")
|
| 216 |
-
|
|
|
|
|
|
|
| 217 |
|
| 218 |
def generate_base_svg(self, prompt: str, width: int, height: int):
|
| 219 |
"""Generate base SVG from prompt"""
|
|
@@ -727,6 +743,34 @@ class EndpointHandler:
|
|
| 727 |
"error": error
|
| 728 |
}
|
| 729 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 730 |
def create_fallback_svg(self, prompt: str, width: int, height: int):
|
| 731 |
"""Create simple fallback SVG"""
|
| 732 |
dwg = svgwrite.Drawing(size=(width, height))
|
|
|
|
| 81 |
|
| 82 |
# Process based on edit type
|
| 83 |
if edit_type == "replace" and len(prompts) >= 2:
|
| 84 |
+
svg_content, metadata = self.word_replacement_edit(prompts[0], prompts[1], width, height, input_svg)
|
| 85 |
elif edit_type == "refine":
|
| 86 |
+
svg_content, metadata = self.prompt_refinement_edit(prompts[0], width, height, input_svg)
|
| 87 |
elif edit_type == "reweight":
|
| 88 |
+
svg_content, metadata = self.attention_reweighting_edit(prompts[0], width, height, input_svg)
|
| 89 |
elif edit_type == "generate":
|
| 90 |
+
svg_content, metadata = self.simple_generation(prompts[0], width, height)
|
| 91 |
else:
|
| 92 |
# Default to refinement
|
| 93 |
+
svg_content, metadata = self.prompt_refinement_edit(prompts[0], width, height, input_svg)
|
| 94 |
|
| 95 |
+
# Convert SVG to PIL Image for HF API compatibility
|
| 96 |
+
pil_image = self.svg_to_pil_image(svg_content, width, height)
|
| 97 |
+
|
| 98 |
+
# Store metadata
|
| 99 |
+
for key, value in metadata.items():
|
| 100 |
+
if isinstance(value, (dict, list)):
|
| 101 |
+
pil_image.info[key] = json.dumps(value)
|
| 102 |
+
else:
|
| 103 |
+
pil_image.info[key] = str(value)
|
| 104 |
+
|
| 105 |
+
return pil_image
|
| 106 |
|
| 107 |
except Exception as e:
|
| 108 |
print(f"Error in handler: {e}")
|
| 109 |
+
# Return fallback image
|
| 110 |
fallback_svg = self.create_fallback_svg(prompts[0] if prompts else "error", width, height)
|
| 111 |
+
fallback_image = self.svg_to_pil_image(fallback_svg, width, height)
|
| 112 |
+
fallback_image.info['error'] = str(e)
|
| 113 |
+
fallback_image.info['edit_type'] = edit_type
|
| 114 |
+
return fallback_image
|
|
|
|
|
|
|
| 115 |
|
| 116 |
def word_replacement_edit(self, source_prompt: str, target_prompt: str, width: int, height: int, input_svg: str = None):
|
| 117 |
"""Perform word replacement editing"""
|
|
|
|
| 136 |
# Apply word replacement transformations
|
| 137 |
edited_svg = self.apply_word_replacement(base_svg, source_prompt, target_prompt, added_words, removed_words, width, height)
|
| 138 |
|
| 139 |
+
metadata = {
|
|
|
|
|
|
|
| 140 |
"edit_type": "replace",
|
| 141 |
"source_prompt": source_prompt,
|
| 142 |
"target_prompt": target_prompt,
|
|
|
|
| 144 |
"removed_words": list(removed_words)
|
| 145 |
}
|
| 146 |
|
| 147 |
+
return edited_svg, metadata
|
| 148 |
+
|
| 149 |
except Exception as e:
|
| 150 |
print(f"Error in word_replacement_edit: {e}")
|
| 151 |
+
fallback_svg = self.create_fallback_svg(source_prompt, width, height)
|
| 152 |
+
metadata = {"edit_type": "replace", "error": str(e)}
|
| 153 |
+
return fallback_svg, metadata
|
| 154 |
|
| 155 |
def prompt_refinement_edit(self, prompt: str, width: int, height: int, input_svg: str = None):
|
| 156 |
"""Perform prompt refinement editing"""
|
|
|
|
| 166 |
# Apply refinement based on prompt analysis
|
| 167 |
refined_svg = self.apply_refinement(base_svg, prompt, width, height)
|
| 168 |
|
| 169 |
+
metadata = {
|
|
|
|
|
|
|
| 170 |
"edit_type": "refine",
|
| 171 |
"prompt": prompt
|
| 172 |
}
|
| 173 |
|
| 174 |
+
return refined_svg, metadata
|
| 175 |
+
|
| 176 |
except Exception as e:
|
| 177 |
print(f"Error in prompt_refinement_edit: {e}")
|
| 178 |
+
fallback_svg = self.create_fallback_svg(prompt, width, height)
|
| 179 |
+
metadata = {"edit_type": "refine", "error": str(e)}
|
| 180 |
+
return fallback_svg, metadata
|
| 181 |
|
| 182 |
def attention_reweighting_edit(self, prompt: str, width: int, height: int, input_svg: str = None):
|
| 183 |
"""Perform attention reweighting editing"""
|
|
|
|
| 196 |
# Apply attention reweighting
|
| 197 |
reweighted_svg = self.apply_attention_reweighting(base_svg, weighted_prompt, attention_weights, width, height)
|
| 198 |
|
| 199 |
+
metadata = {
|
|
|
|
|
|
|
| 200 |
"edit_type": "reweight",
|
| 201 |
"prompt": prompt,
|
| 202 |
"weighted_prompt": weighted_prompt,
|
| 203 |
"attention_weights": attention_weights
|
| 204 |
}
|
| 205 |
|
| 206 |
+
return reweighted_svg, metadata
|
| 207 |
+
|
| 208 |
except Exception as e:
|
| 209 |
print(f"Error in attention_reweighting_edit: {e}")
|
| 210 |
+
fallback_svg = self.create_fallback_svg(prompt, width, height)
|
| 211 |
+
metadata = {"edit_type": "reweight", "error": str(e)}
|
| 212 |
+
return fallback_svg, metadata
|
| 213 |
|
| 214 |
def simple_generation(self, prompt: str, width: int, height: int):
|
| 215 |
"""Perform simple SVG generation"""
|
|
|
|
| 218 |
|
| 219 |
svg_content = self.generate_base_svg(prompt, width, height)
|
| 220 |
|
| 221 |
+
metadata = {
|
|
|
|
|
|
|
| 222 |
"edit_type": "generate",
|
| 223 |
"prompt": prompt
|
| 224 |
}
|
| 225 |
|
| 226 |
+
return svg_content, metadata
|
| 227 |
+
|
| 228 |
except Exception as e:
|
| 229 |
print(f"Error in simple_generation: {e}")
|
| 230 |
+
fallback_svg = self.create_fallback_svg(prompt, width, height)
|
| 231 |
+
metadata = {"edit_type": "generate", "error": str(e)}
|
| 232 |
+
return fallback_svg, metadata
|
| 233 |
|
| 234 |
def generate_base_svg(self, prompt: str, width: int, height: int):
|
| 235 |
"""Generate base SVG from prompt"""
|
|
|
|
| 743 |
"error": error
|
| 744 |
}
|
| 745 |
|
| 746 |
+
def svg_to_pil_image(self, svg_content, width, height):
|
| 747 |
+
"""Convert SVG content to PIL Image"""
|
| 748 |
+
try:
|
| 749 |
+
import cairosvg
|
| 750 |
+
import io
|
| 751 |
+
|
| 752 |
+
# Convert SVG to PNG bytes
|
| 753 |
+
png_bytes = cairosvg.svg2png(
|
| 754 |
+
bytestring=svg_content.encode('utf-8'),
|
| 755 |
+
output_width=width,
|
| 756 |
+
output_height=height
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
# Convert to PIL Image
|
| 760 |
+
image = Image.open(io.BytesIO(png_bytes)).convert('RGB')
|
| 761 |
+
return image
|
| 762 |
+
|
| 763 |
+
except ImportError:
|
| 764 |
+
print("cairosvg not available, creating simple image representation")
|
| 765 |
+
# Fallback: create a simple image with text
|
| 766 |
+
image = Image.new('RGB', (width, height), 'white')
|
| 767 |
+
return image
|
| 768 |
+
except Exception as e:
|
| 769 |
+
print(f"Error converting SVG to image: {e}")
|
| 770 |
+
# Fallback: create a simple image
|
| 771 |
+
image = Image.new('RGB', (width, height), 'white')
|
| 772 |
+
return image
|
| 773 |
+
|
| 774 |
def create_fallback_svg(self, prompt: str, width: int, height: int):
|
| 775 |
"""Create simple fallback SVG"""
|
| 776 |
dwg = svgwrite.Drawing(size=(width, height))
|