File size: 2,039 Bytes
eb24b03
 
 
 
 
 
 
 
 
 
 
a7e6caa
2e68455
 
d9152d4
eb24b03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from typing import Dict, List, Any
from PIL import Image
import io
import base64
from rembg import remove, new_session

class EndpointHandler:
    """Custom Hugging Face Inference Endpoints handler that removes image backgrounds with rembg."""

    def __init__(self, path="", model_name: str = "u2netp"):
        """Create the rembg session once, when the endpoint starts.

        Args:
            path: Path passed in by the Inference Endpoints runtime. Unused
                here; accepted to match the handler interface.
            model_name: rembg model to load. Defaults to ``"u2netp"``; other
                names previously tried in this handler include ``"u2net"``
                and ``"birefnet-general-lite"``.
        """
        # Building the session up front keeps model loading out of the
        # per-request path.
        self.session = new_session(model_name)

    @staticmethod
    def _strip_data_uri(payload: str) -> str:
        """Return the base64 part of *payload*, dropping a ``data:<mime>;base64,`` prefix if present."""
        # Without this, b64decode (validate=False) would decode the prefix's
        # alphabet characters ("data", "image/png", "base64") as garbage
        # leading bytes and corrupt the image.
        if payload.startswith("data:"):
            _, _, payload = payload.partition(",")
        return payload

    @staticmethod
    def _open_image(raw):
        """Decode encoded image bytes (PNG, JPEG, ...) into an RGB PIL image."""
        return Image.open(io.BytesIO(raw)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Remove the background from the image in the request.

        Args:
            data: Request payload with keys:
                - "inputs": a PIL.Image, a base64 string (optionally prefixed
                  with ``data:<mime>;base64,``), or raw encoded image bytes.
                - "parameters": optional dict of keyword arguments forwarded
                  to ``rembg.remove`` (e.g. ``alpha_matting``, ``bgcolor``).
                  A missing or ``null`` value means no extra options.

        Returns:
            The result of ``rembg.remove`` (a PIL image with an alpha channel
            for PIL input), or ``{"error": "No inputs provided"}`` when
            "inputs" is missing or None.
        """
        inputs = data.get("inputs")
        if inputs is None:
            return {"error": "No inputs provided"}

        if isinstance(inputs, str):
            image = self._open_image(base64.b64decode(self._strip_data_uri(inputs)))
        elif isinstance(inputs, Image.Image):
            # Hugging Face usually passes a PIL object when the request
            # Content-Type is image/*.
            image = inputs
        else:
            # Fallback: treat anything else as raw encoded image bytes.
            image = self._open_image(inputs)

        # `or {}` also covers an explicit "parameters": null, which would
        # otherwise make `**params` raise TypeError.
        params = data.get("parameters") or {}

        # The Inference Endpoints runtime serializes a returned PIL image
        # into the HTTP response.
        return remove(image, session=self.session, **params)