File size: 5,085 Bytes
592adee
2f7cfdc
592adee
d05bd8d
592adee
 
d816a26
 
d05bd8d
f9b3f94
592adee
c78d04e
 
 
d05bd8d
 
2f7cfdc
 
d816a26
2f7cfdc
d816a26
 
c78d04e
d816a26
 
 
 
 
c78d04e
d816a26
592adee
c78d04e
2f7cfdc
 
c78d04e
 
 
2f7cfdc
c78d04e
 
 
064d94e
c78d04e
 
064d94e
2f7cfdc
16a5f8c
 
 
 
 
 
 
 
 
d816a26
 
 
 
 
f9b3f94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d816a26
f9b3f94
 
 
 
16a5f8c
 
 
 
d816a26
d05bd8d
2f7cfdc
d816a26
2f7cfdc
d05bd8d
592adee
c78d04e
 
 
 
 
 
d05bd8d
c78d04e
f9b3f94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# handler.py

import io
import base64
import numpy as np
from PIL import Image
import torch
from transformers import SamModel, SamProcessor
from typing import Dict, List, Any
import torch.nn.functional as F

# set device
# Prefer GPU when available; the model is moved here in EndpointHandler.__init__
# and input tensors must be placed on the same device before inference.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class EndpointHandler():
    def __init__(self, path=""):
        """
        Called once at startup.

        Load the SAM model and processor using Hugging Face Transformers.

        Args:
            path: Local directory containing the model weights. If loading
                fails, falls back to the public ``facebook/sam-vit-base``
                checkpoint so the endpoint can still come up.
        """
        try:
            # Load the model and processor from the local path
            self.model = SamModel.from_pretrained(path).to(device)
            self.processor = SamProcessor.from_pretrained(path)
        except Exception as e:
            # Fallback to loading from a known SAM model if local loading fails
            print(f"Failed to load from local path: {e}")
            print("Attempting to load from facebook/sam-vit-base")
            self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
            self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

    @staticmethod
    def _decode_image_payload(payload):
        """
        Normalize the request payload to raw image bytes.

        Accepts raw bytes, a bare base64 string, or a ``data:`...`;base64,``
        data URL (which is what the bundled ``main()`` sends).

        Args:
            payload: bytes or str image payload.

        Returns:
            bytes: decoded image bytes.
        """
        if isinstance(payload, str):
            # Strip a "data:image/...;base64," prefix when present.
            if payload.startswith("data:"):
                payload = payload.split(",", 1)[1]
            return base64.b64decode(payload)
        return payload

    def __call__(self, data: Any) -> Any:
        """
        Called on every HTTP request.

        Args:
            data: dict with "inputs" (image bytes, base64 string, or data URL)
                and an optional "parameters" dict (currently unused, accepted
                for interface compatibility).

        Returns:
            A one-element list ``[{"mask_png_base64": <base64-encoded PNG>}]``
            containing the binary segmentation mask for the image center.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})  # accepted but unused for now

        # BUG FIX: callers may send a base64 string / data URL rather than raw
        # bytes; io.BytesIO requires bytes, so decode first.
        image_bytes = self._decode_image_payload(inputs)
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # PIL's .size is (width, height)
        width, height = img.size

        # SAM requires input prompts, so generate a single positive point at
        # the image center for automatic segmentation.
        input_points = [[[width // 2, height // 2]]]
        input_labels = [[1]]  # 1 == positive prompt

        model_inputs = self.processor(
            img,
            input_points=input_points,
            input_labels=input_labels,
            return_tensors="pt",
        )
        # BUG FIX: move every tensor to the model's device; otherwise CUDA
        # deployments fail with a CPU/GPU device mismatch.
        model_inputs = {
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in model_inputs.items()
        }

        # Generate masks without tracking gradients.
        with torch.no_grad():
            outputs = self.model(**model_inputs)

        try:
            # Original image size recorded by the processor.
            original_height, original_width = model_inputs["original_sizes"][0].tolist()

            # pred_masks: (batch, [num_prompts,] num_masks, H, W)
            pred_masks = outputs.pred_masks.cpu()
            iou_scores = outputs.iou_scores.cpu()[0]

            # The model might return 4D or 5D tensors. Squeeze the prompt dim.
            if pred_masks.ndim == 5:
                pred_masks = pred_masks.squeeze(1)

            # Keep only the highest-IoU mask.
            best_mask_idx = torch.argmax(iou_scores)
            best_mask_tensor = pred_masks[0, best_mask_idx, :, :]  # (H, W)

            # Upscale the low-res mask logits back to the original image size;
            # add batch and channel dims for F.interpolate, then strip them.
            upscaled_mask = F.interpolate(
                best_mask_tensor.unsqueeze(0).unsqueeze(0).float(),
                size=(original_height, original_width),
                mode='bilinear',
                align_corners=False,
            ).squeeze()

            # Threshold logits at 0 to get a binary 0/255 uint8 mask.
            mask_binary = (upscaled_mask > 0.0).numpy().astype(np.uint8) * 255

        except Exception as e:
            print(f"Error processing masks: {e}")
            # Fallback: a centered square so the endpoint still returns a mask.
            mask_binary = np.zeros((height, width), dtype=np.uint8)
            center_x, center_y = width // 2, height // 2
            size = min(width, height) // 8
            mask_binary[center_y - size:center_y + size,
                        center_x - size:center_x + size] = 255

        # Encode the mask as a base64 PNG.
        out = io.BytesIO()
        Image.fromarray(mask_binary).save(out, format="PNG")
        mask_base64 = base64.b64encode(out.getvalue()).decode('utf-8')

        # BUG FIX: return the JSON-serializable format that main()/clients
        # expect instead of a raw PIL Image.
        return [{"mask_png_base64": mask_base64}]

def main(input_path="/Users/rp7/Downloads/test.jpeg", output_path="output.jpg"):
    """
    CLI smoke test: segment a local image with the handler and save the mask.

    Args:
        input_path: Path to the source JPEG image. Default preserves the
            original hard-coded path for backward compatibility.
        output_path: Where the decoded mask is written as a JPEG.
    """
    # Read and base64-encode the input image as a data URL, mimicking an
    # HTTP client payload.
    with open(input_path, "rb") as f:
        img_bytes = f.read()
    img_b64 = base64.b64encode(img_bytes).decode("utf-8")
    data_url = f"data:image/jpeg;base64,{img_b64}"

    handler = EndpointHandler(path=".")
    result = handler({"inputs": data_url})[0]

    # Decode the returned base64 PNG mask and save it as JPEG.
    mask_bytes = base64.b64decode(result["mask_png_base64"])
    mask_img = Image.open(io.BytesIO(mask_bytes)).convert("RGB")
    mask_img.save(output_path, format="JPEG")
    print(f"Wrote mask to {output_path}")

# Allow running this module directly as a local smoke test.
if __name__ == "__main__":
    main()