File size: 4,506 Bytes
592adee
2f7cfdc
592adee
d05bd8d
592adee
 
d816a26
 
d05bd8d
f9b3f94
592adee
c78d04e
 
 
d05bd8d
 
2f7cfdc
 
d816a26
2f7cfdc
d816a26
 
c78d04e
d816a26
 
 
e0fb0e6
d816a26
c78d04e
d816a26
592adee
e52ad65
2f7cfdc
 
2f4ef92
e52ad65
2f7cfdc
e0fb0e6
2f4ef92
 
 
c78d04e
2f4ef92
 
 
 
 
e52ad65
2f4ef92
 
 
 
16a5f8c
e0fb0e6
 
2f4ef92
 
16a5f8c
e0fb0e6
d816a26
e0fb0e6
d816a26
 
 
e0fb0e6
f9b3f94
 
e0fb0e6
 
f9b3f94
 
 
 
 
e0fb0e6
f9b3f94
 
 
 
 
 
e0fb0e6
f9b3f94
 
d816a26
f9b3f94
e0fb0e6
16a5f8c
 
 
 
d816a26
e52ad65
 
06bd1fa
f9b3f94
 
e52ad65
f9b3f94
e0fb0e6
f9b3f94
e52ad65
f9b3f94
 
 
e0fb0e6
f9b3f94
e52ad65
f9b3f94
e52ad65
e0fb0e6
e52ad65
 
e0fb0e6
f9b3f94
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# handler.py

import io
import base64
import numpy as np
from PIL import Image
import torch
from transformers import SamModel, SamProcessor
from typing import Dict, List, Any
import torch.nn.functional as F

# Select the torch device once, at import time: the CUDA device when PyTorch
# reports a usable GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

class EndpointHandler:
    """Inference-endpoint handler that segments an image with SAM.

    The model is prompted with a single positive point at the image centre.
    The mask with the highest predicted IoU is returned as an 8-bit grayscale
    PIL image (255 = foreground, 0 = background) the same size as the input.
    If mask post-processing fails, the error is printed and a placeholder
    centred square is returned instead.
    """

    def __init__(self, path=""):
        """Load the SAM model and processor once at startup.

        Args:
            path: Location passed to ``from_pretrained``. If loading from it
                raises for any reason, ``facebook/sam-vit-base`` is loaded
                instead.
        """
        try:
            # Load the model and processor from the local path
            self.model = SamModel.from_pretrained(path).to(device)
            self.processor = SamProcessor.from_pretrained(path)
        except Exception as e:
            # Fallback to loading from a known SAM model if local loading fails
            print("Failed to load from local path: {}".format(e))
            print("Attempting to load from facebook/sam-vit-base")
            self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
            self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Segment the image in ``data["inputs"]`` (called on every request).

        Args:
            data: Request payload. ``data["inputs"]`` must be a PIL image or a
                base64-encoded image string. An optional ``data:...,`` URI
                header is stripped. The payload dict is not modified.

        Returns:
            ``[{'score': None, 'label': 'everything', 'mask': <PIL image>}]``,
            where the mask has the input image's width and height.

        Raises:
            ValueError: if ``inputs`` is missing or None.
            TypeError: if ``inputs`` is neither a PIL image nor a string.
        """
        img = self._decode_image(data.get("inputs"))
        width, height = img.size

        # Prompt SAM with one positive (label 1) point at the image centre.
        input_points = [[[width // 2, height // 2]]]
        input_labels = [[1]]
        model_inputs = self.processor(
            img, input_points=input_points, input_labels=input_labels, return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            outputs = self.model(**model_inputs)

        try:
            # pred_masks are low-resolution masks for the resized *and padded*
            # model input. Interpolating them directly to the original size
            # would stretch the padding into the mask. post_process_masks
            # upsamples, crops the padding, resizes to the original image size
            # and binarizes.
            masks = self.processor.image_processor.post_process_masks(
                outputs.pred_masks.cpu(),
                model_inputs["original_sizes"].cpu(),
                model_inputs["reshaped_input_sizes"].cpu(),
            )
            mask_binary = self._select_best_mask(masks, outputs.iou_scores.cpu())
        except Exception as e:
            # Deliberate best-effort behaviour: report the error and return a
            # placeholder instead of failing the request.
            print("Error processing masks: {}".format(e))
            mask_binary = self._placeholder_mask(height, width)

        output_img = Image.fromarray(mask_binary)
        return [{'score': None, 'label': 'everything', 'mask': output_img}]

    @staticmethod
    def _decode_image(inputs):
        """Turn the payload's ``inputs`` value into an RGB PIL image."""
        if inputs is None:
            raise ValueError("Missing 'inputs' key in the payload.")
        if isinstance(inputs, Image.Image):
            return inputs.convert("RGB")
        if isinstance(inputs, str):
            if inputs.startswith("data:"):
                # Drop a data-URI header such as "data:image/jpeg;base64,".
                inputs = inputs.split(",", 1)[1]
            image_bytes = base64.b64decode(inputs)
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")

    @staticmethod
    def _select_best_mask(masks, iou_scores):
        """Return the highest-IoU mask for image 0 / prompt 0 as a uint8 0/255 array.

        Args:
            masks: Per-image list of mask tensors, each expected to be shaped
                (point_batch, num_masks, H, W), as post_process_masks returns.
            iou_scores: Predicted IoU tensor shaped (batch, point_batch, num_masks).
        """
        best_idx = int(torch.argmax(iou_scores[0, 0]))
        best_mask = masks[0][0, best_idx]
        # "> 0" covers both binarized (bool) masks and raw logits thresholded
        # at 0.0.
        return (best_mask > 0).numpy().astype(np.uint8) * 255

    @staticmethod
    def _placeholder_mask(height, width):
        """Build a (height, width) uint8 mask with a centred square of 255s."""
        mask = np.zeros((height, width), dtype=np.uint8)
        center_y, center_x = height // 2, width // 2
        half = min(width, height) // 8
        mask[center_y - half:center_y + half, center_x - half:center_x + half] = 255
        return mask

def main(input_path="/Users/rp7/Downloads/test.jpeg", output_path="output.png"):
    """Run the handler locally on one image file and save the resulting mask.

    Mimics how a client would call the endpoint: the file is sent as a
    base64 data URI, and the mask of the first returned segment is written
    to ``output_path``.

    Args:
        input_path: Image file to segment.
        output_path: Destination for the mask image. The format is inferred
            by PIL from the file extension.
    """
    # 1. Prepare the payload with a base64-encoded image string. The handler
    #    strips the data-URI header, so the declared MIME type is informational.
    with open(input_path, "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")
    payload = {"inputs": "data:image/jpeg;base64,{}".format(img_b64)}

    # 2. Instantiate the handler. It returns a list of segment dicts, not an image.
    handler = EndpointHandler(path=".")
    result = handler(payload)

    # 3. Save the mask of the single returned segment.
    result[0]["mask"].save(output_path)
    print("Wrote mask to {}".format(output_path))


if __name__ == "__main__":
    main()