# rembg/handler.py — custom handler for a Hugging Face Inference Endpoint.
# Originally authored by radenmuaz (revision 2e68455).
from typing import Dict, List, Any
from PIL import Image
import io
import base64
from rembg import remove, new_session
class EndpointHandler:
    """Hugging Face Inference Endpoint handler that removes image backgrounds
    using rembg.

    The rembg session is created once at startup so per-request latency stays
    low; ``__call__`` accepts a PIL image, a base64 string, or raw bytes.
    """

    def __init__(self, path: str = ""):
        """Preload the rembg model session.

        Args:
            path: Model directory supplied by the endpoint runtime. Unused
                here — rembg downloads and caches its own weights.
        """
        # 'u2netp' is the lightweight variant of the default 'u2net' model.
        # Alternatives previously considered: "u2net", "birefnet-general-lite".
        model_name = "u2netp"
        self.session = new_session(model_name)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Remove the background from the input image.

        Args:
            data:
                - "inputs": a PIL.Image (the runtime passes one for image/*
                  request bodies), a base64-encoded string (bare or data-URI),
                  or raw image bytes.
                - "parameters": optional dict of rembg.remove keyword options
                  (e.g. alpha_matting, bgcolor).

        Returns:
            An RGBA PIL.Image with the background removed; the endpoint
            runtime serializes it into the HTTP response. On missing or
            undecodable input, a dict of the form {"error": "..."} instead.
        """
        inputs = data.get("inputs")
        if inputs is None:
            return {"error": "No inputs provided"}

        try:
            if isinstance(inputs, Image.Image):
                image = inputs
            elif isinstance(inputs, str):
                # Accept both a bare base64 payload and a full data URI
                # ("data:image/png;base64,...."): keep only the part after
                # the first comma when a data-URI prefix is present.
                payload = inputs.split(",", 1)[-1] if inputs.startswith("data:") else inputs
                image = Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB")
            else:
                # Fallback for raw bytes / bytearray request bodies.
                image = Image.open(io.BytesIO(inputs)).convert("RGB")
        except Exception as exc:
            # Boundary handler: surface decode failures (binascii.Error,
            # PIL.UnidentifiedImageError, ...) as a structured error instead
            # of an opaque 500, matching the "No inputs" error shape above.
            return {"error": f"Could not decode input image: {exc}"}

        # Optional rembg keyword arguments. "or {}" also guards against an
        # explicit `"parameters": null` in the JSON body, which .get() with
        # a default would pass through as None and break the ** expansion.
        params = data.get("parameters") or {}

        # rembg.remove returns a PIL image with an alpha channel (RGBA).
        return remove(image, session=self.session, **params)