# handler.py
import base64
import io
from typing import Dict, List, Any

import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from transformers import SamModel, SamProcessor

# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class EndpointHandler():
    """Inference endpoint handler that segments an image with SAM.

    A single positive point prompt is placed at the image centre and the
    candidate mask with the highest predicted IoU is returned.
    """

    def __init__(self, path=""):
        """
        Called once at startup.
        Load the SAM model using Hugging Face Transformers.

        Args:
            path: Directory holding the model and processor files. If loading
                from it fails for any reason, "facebook/sam-vit-base" is
                loaded instead.
        """
        try:
            # Load the model and processor from the local path
            self.model = SamModel.from_pretrained(path).to(device)
            self.processor = SamProcessor.from_pretrained(path)
        except Exception as e:
            # Fallback to loading from a known SAM model if local loading fails
            print("Failed to load from local path: {}".format(e))
            print("Attempting to load from facebook/sam-vit-base")
            self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
            self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Called on every HTTP request.
        Handles both base64-encoded images and PIL images.

        Args:
            data: Request payload. ``data["inputs"]`` must be a PIL image or a
                base64 string (optionally prefixed with a ``data:`` URI
                header). The payload is not modified.

        Returns:
            A one-element list holding a dict with the keys ``'score'``
            (always ``None``), ``'label'`` (always ``'everything'``) and
            ``'mask'`` (a single-channel PIL image, 255 inside the mask and 0
            outside, with the same size as the input image).

        Raises:
            ValueError: If the payload has no ``'inputs'`` entry.
            TypeError: If ``'inputs'`` is neither a PIL image nor a string.
        """
        # 1. Parse and decode the input image
        raw_input = data.get("inputs")
        if raw_input is None:
            raise ValueError("Missing 'inputs' key in the payload.")
        img = self._load_image(raw_input)

        # 2. Prepare prompts and process the image
        width, height = img.size  # PIL reports (width, height)
        input_points = [[[width // 2, height // 2]]]
        input_labels = [[1]]
        model_inputs = self.processor(
            img,
            input_points=input_points,
            input_labels=input_labels,
            return_tensors="pt",
        ).to(device)

        # 3. Generate masks
        with torch.no_grad():
            outputs = self.model(**model_inputs)

        # 4. Process and select the best mask
        try:
            mask_binary = self._select_best_mask(outputs, model_inputs)
        except Exception as e:
            # Best effort: still answer the request with a placeholder mask
            # rather than failing it.
            print("Error processing masks: {}".format(e))
            mask_binary = self._fallback_mask(height, width)

        # 5. Create and return the output PIL Image
        output_img = Image.fromarray(mask_binary)
        return [{'score': None, 'label': 'everything', 'mask': output_img}]

    @staticmethod
    def _load_image(raw_input):
        """Return *raw_input* (PIL image or base64 string) as an RGB PIL image."""
        if isinstance(raw_input, Image.Image):
            return raw_input.convert("RGB")
        if isinstance(raw_input, str):
            if raw_input.startswith("data:"):
                # Drop the "data:<mime>;base64," header of a data URI.
                raw_input = raw_input.split(",", 1)[1]
            image_bytes = base64.b64decode(raw_input)
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")

    def _select_best_mask(self, outputs, model_inputs):
        """Return the highest-IoU mask as a uint8 (0/255) array at original size."""
        # The processor resizes the image and pads it to a square before the
        # model sees it, so the low-resolution masks cannot simply be stretched
        # to the original size: the padded region has to be cropped away first.
        # post_process_masks does the upscale -> crop -> resize -> threshold
        # (default threshold 0.0) sequence.
        masks = self.processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            model_inputs["original_sizes"].cpu(),
            model_inputs["reshaped_input_sizes"].cpu(),
        )
        # One image and one point prompt were submitted, hence the [0][0].
        # Per the Transformers SAM docs each list entry is laid out as
        # (point_batch, num_masks, height, width).
        candidates = masks[0][0]
        iou_scores = outputs.iou_scores.cpu()[0, 0]
        best_mask_idx = int(torch.argmax(iou_scores))
        return candidates[best_mask_idx].numpy().astype(np.uint8) * 255

    @staticmethod
    def _fallback_mask(height, width):
        """Return a placeholder mask: a filled square centred in the image."""
        mask_binary = np.zeros((height, width), dtype=np.uint8)
        center_x, center_y = width // 2, height // 2
        size = min(width, height) // 8
        mask_binary[center_y - size:center_y + size, center_x - size:center_x + size] = 255
        return mask_binary


def main(input_path="/Users/rp7/Downloads/test.jpeg", output_path="output.png"):
    """Show how a client would call the endpoint locally.

    Args:
        input_path: Image file to segment.
        output_path: Where the resulting mask image is written.
    """
    # 1. Prepare the payload with a base64-encoded image string
    with open(input_path, "rb") as f:
        img_bytes = f.read()
    img_b64 = base64.b64encode(img_bytes).decode("utf-8")
    payload = {"inputs": "data:image/jpeg;base64,{}".format(img_b64)}

    # 2. Instantiate handler and run it on the payload
    handler = EndpointHandler(path=".")
    results = handler(payload)

    # 3. Save the resulting image. The handler returns a list of result
    #    dicts, so the PIL image lives under the 'mask' key of the first one.
    results[0]["mask"].save(output_path)
    print("Wrote mask to {}".format(output_path))


if __name__ == "__main__":
    main()