Create handler.py
Browse files- handler.py +66 -0
handler.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import base64
|
| 3 |
+
from io import BytesIO
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from transparent_background import Remover
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class EndpointHandler:
|
| 9 |
+
def __init__(self):
|
| 10 |
+
# Initialize the remover model with desired settings.
|
| 11 |
+
self.remover = Remover(mode='fast')
|
| 12 |
+
# Warm up the model with a dummy image.
|
| 13 |
+
dummy = Image.new("RGB", (256, 256), "white")
|
| 14 |
+
_ = self.remover.process(dummy)
|
| 15 |
+
|
| 16 |
+
def __call__(self, request):
|
| 17 |
+
"""
|
| 18 |
+
Expects a dictionary (JSON) with keys:
|
| 19 |
+
- "images": a list of base64-encoded image strings (e.g. "data:image/png;base64,...")
|
| 20 |
+
- "output_type": one of "rgba", "map", "green", "blur", or "overlay"
|
| 21 |
+
- "threshold": a float (0.0 to 1.0)
|
| 22 |
+
- "reverse": a boolean flag
|
| 23 |
+
Returns a dictionary with:
|
| 24 |
+
- "images": list of processed images (base64-encoded with data URI prefix)
|
| 25 |
+
- "processing_times": a string with individual and total processing times
|
| 26 |
+
"""
|
| 27 |
+
# Get parameters from the request.
|
| 28 |
+
images_data = request.get("images", [])
|
| 29 |
+
output_type = request.get("output_type", "rgba")
|
| 30 |
+
threshold = request.get("threshold", 0.1)
|
| 31 |
+
reverse = request.get("reverse", False)
|
| 32 |
+
|
| 33 |
+
processed_results = []
|
| 34 |
+
times_list = []
|
| 35 |
+
|
| 36 |
+
global_start = time.time()
|
| 37 |
+
|
| 38 |
+
# Process up to 3 images
|
| 39 |
+
for idx, img_b64 in enumerate(images_data[:3]):
|
| 40 |
+
# Remove data URI prefix if present.
|
| 41 |
+
if img_b64.startswith("data:"):
|
| 42 |
+
img_b64 = img_b64.split(",")[1]
|
| 43 |
+
# Decode the image.
|
| 44 |
+
img_bytes = base64.b64decode(img_b64)
|
| 45 |
+
image = Image.open(BytesIO(img_bytes)).convert("RGB")
|
| 46 |
+
|
| 47 |
+
start_time = time.time()
|
| 48 |
+
result = self.remover.process(image, type=output_type, threshold=threshold, reverse=reverse)
|
| 49 |
+
elapsed = time.time() - start_time
|
| 50 |
+
times_list.append(f"Image {idx+1}: {elapsed:.2f} seconds")
|
| 51 |
+
|
| 52 |
+
# Convert the result to base64.
|
| 53 |
+
buffer = BytesIO()
|
| 54 |
+
result.save(buffer, format="PNG")
|
| 55 |
+
buffer.seek(0)
|
| 56 |
+
result_b64 = base64.b64encode(buffer.read()).decode("utf-8")
|
| 57 |
+
processed_results.append("data:image/png;base64," + result_b64)
|
| 58 |
+
|
| 59 |
+
total_time = time.time() - global_start
|
| 60 |
+
times_list.append(f"Total time: {total_time:.2f} seconds")
|
| 61 |
+
elapsed_str = "\n".join(times_list)
|
| 62 |
+
|
| 63 |
+
return {
|
| 64 |
+
"images": processed_results,
|
| 65 |
+
"processing_times": elapsed_str
|
| 66 |
+
}
|