# handler.py
import base64
import io
from typing import Any, Dict, List

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import SamModel, SamProcessor

# Run on GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    """Inference Endpoint handler for SAM (Segment Anything Model).

    Loads a SAM checkpoint once at startup; per request it segments the
    input image using a single center-point prompt and returns the
    highest-scoring mask as a base64-encoded PNG.
    """

    def __init__(self, path: str = ""):
        """Load the SAM model and processor.

        Args:
            path: Local directory containing the model weights. If loading
                fails, falls back to ``facebook/sam-vit-base`` from the hub.
        """
        try:
            # Load the model and processor from the local path.
            self.model = SamModel.from_pretrained(path).to(device)
            self.processor = SamProcessor.from_pretrained(path)
        except Exception as e:
            # Fallback keeps the endpoint usable when no local copy exists.
            print(f"Failed to load from local path: {e}")
            print("Attempting to load from facebook/sam-vit-base")
            self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
            self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

    @staticmethod
    def _decode_image(inputs: Any) -> Image.Image:
        """Accept raw bytes, a base64 string, or a data URL; return an RGB image.

        BUG FIX: the original passed ``inputs`` straight to ``io.BytesIO``,
        which crashes when the client (e.g. ``main()``) sends a base64
        data-URL string instead of raw bytes.
        """
        if isinstance(inputs, str):
            # Strip a "data:image/...;base64," prefix if present.
            if inputs.startswith("data:"):
                inputs = inputs.split(",", 1)[1]
            inputs = base64.b64decode(inputs)
        return Image.open(io.BytesIO(inputs)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Called on every HTTP request.

        Args:
            data: ``{"inputs": <image bytes | base64 str | data URL>,
                     "parameters": {...}}``; ``parameters`` is currently unused.

        Returns:
            ``[{"mask_png_base64": <base64-encoded PNG of the binary mask>}]``
        """
        inputs = data.pop("inputs", data)
        data.pop("parameters", {})  # reserved for future use

        img = self._decode_image(inputs)
        width, height = img.size  # PIL reports (width, height)

        # SAM requires input prompts; use one positive point at the center.
        input_points = [[[width // 2, height // 2]]]
        input_labels = [[1]]

        # BUG FIX: move the processor's tensors to the model's device —
        # the original left them on CPU, which fails on a CUDA machine.
        model_inputs = self.processor(
            img,
            input_points=input_points,
            input_labels=input_labels,
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            outputs = self.model(**model_inputs)

        try:
            original_height, original_width = model_inputs["original_sizes"][0].tolist()

            # pred_masks: (batch, point_batch, num_masks, H, W) — TODO confirm
            # against the installed transformers version; the original handled
            # both 4-D and 5-D shapes, so we keep that tolerance.
            pred_masks = outputs.pred_masks.cpu()
            iou_scores = outputs.iou_scores.cpu()[0]

            if pred_masks.ndim == 5:
                pred_masks = pred_masks.squeeze(1)  # drop the point-batch dim
            # BUG FIX: iou_scores[0] may still be 2-D (point_batch, num_masks);
            # flatten so argmax yields a valid mask index.
            iou_scores = iou_scores.reshape(-1)

            # Keep the mask SAM scores highest.
            best_mask_idx = torch.argmax(iou_scores)
            best_mask_tensor = pred_masks[0, best_mask_idx, :, :]  # (H, W)

            # Upscale the low-res mask back to the original image size; add
            # batch and channel dims for interpolate, then remove them.
            upscaled_mask = F.interpolate(
                best_mask_tensor.unsqueeze(0).unsqueeze(0).float(),
                size=(original_height, original_width),
                mode="bilinear",
                align_corners=False,
            ).squeeze()

            # Threshold logits at 0 to produce a binary 0/255 mask.
            mask_binary = (upscaled_mask > 0.0).numpy().astype(np.uint8) * 255
        except Exception as e:
            # Best-effort fallback: a centered square so callers always get a
            # valid PNG instead of a hard failure.
            print(f"Error processing masks: {e}")
            mask_binary = np.zeros((height, width), dtype=np.uint8)
            center_x, center_y = width // 2, height // 2
            size = min(width, height) // 8
            mask_binary[center_y - size:center_y + size,
                        center_x - size:center_x + size] = 255

        # Encode the mask as a base64 PNG for the JSON response.
        out = io.BytesIO()
        Image.fromarray(mask_binary).save(out, format="PNG")
        mask_base64 = base64.b64encode(out.getvalue()).decode("utf-8")

        # BUG FIX: return the serializable payload that main()/clients read
        # (the original returned a PIL image, so result["mask_png_base64"]
        # raised TypeError).
        return [{"mask_png_base64": mask_base64}]


def main():
    """Smoke test: segment a local image and write the mask to disk."""
    # Hardcoded input and output paths.
    input_path = "/Users/rp7/Downloads/test.jpeg"
    output_path = "output.jpg"

    # Read and base64-encode the input image as a data URL.
    with open(input_path, "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")
    data_url = f"data:image/jpeg;base64,{img_b64}"

    handler = EndpointHandler(path=".")
    result = handler({"inputs": data_url})[0]

    # Decode the returned mask and save it as a JPEG.
    mask_bytes = base64.b64decode(result["mask_png_base64"])
    mask_img = Image.open(io.BytesIO(mask_bytes)).convert("RGB")
    mask_img.save(output_path, format="JPEG")
    print(f"Wrote mask to {output_path}")


if __name__ == "__main__":
    main()