# handler.py
"""Hugging Face Inference Endpoints handler for SAM (Segment Anything) mask generation."""

import base64
import io
from typing import Any, Dict, List

import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Called once at startup. Load the SAM model using Hugging Face Transformers.

        Falls back to the public ``facebook/sam-vit-base`` checkpoint when the
        local ``path`` cannot be loaded (e.g. missing or incomplete weights).
        """
        try:
            # Load the model and processor from the local path
            self.model = SamModel.from_pretrained(path)
            self.processor = SamProcessor.from_pretrained(path)
        except Exception as e:
            # Fallback to loading from a known SAM model if local loading fails
            print(f"Failed to load from local path: {e}")
            print("Attempting to load from facebook/sam-vit-base")
            self.model = SamModel.from_pretrained("facebook/sam-vit-base")
            self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        # Inference-only handler: use GPU when present and disable dropout etc.
        # (the original left the model on CPU in train mode).
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _decode_image(data: Dict[str, Any]) -> bytes:
        """Extract and base64-decode the image payload from the request dict.

        Accepts any of:
          * ``{"inputs": "<base64>"}``
          * ``{"inputs": {"image": "<base64>"}}``
          * ``{"image": "<base64>"}``

        Raises:
            ValueError: when no recognizable image field is present.
        """
        if "inputs" in data:
            if isinstance(data["inputs"], str):
                # Base64 encoded image
                return base64.b64decode(data["inputs"])
            if isinstance(data["inputs"], dict) and "image" in data["inputs"]:
                # Nested structure with image field
                return base64.b64decode(data["inputs"]["image"])
            raise ValueError("Invalid input format. Expected base64 encoded image string.")
        if "image" in data:
            # Direct image field
            return base64.b64decode(data["image"])
        # NOTE: the original source split this literal across a physical line
        # break (a SyntaxError); it is now a single one-line message.
        raise ValueError(
            "No image found in request. Expected 'inputs' or 'image' field "
            "with base64 encoded image."
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Called on every HTTP request. Expecting base64 encoded image in the
        'inputs' field or 'image' field.

        Returns:
            A one-element list: ``[{"mask_png_base64": <str>, "num_masks": <int>}]``
            where the PNG is the first predicted binary mask (0/255, mode "L").
        """
        image_bytes = self._decode_image(data)

        # Decode to RGB; SAM's processor expects a 3-channel PIL image.
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Prepare inputs and run a prompt-free forward pass (no point/box prompts).
        inputs = self.processor(img, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Upscale the low-res predicted masks back to the original image size;
        # post_process_masks returns one tensor per image, we sent exactly one.
        masks = self.processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0]

        # Take the first prediction; SAM can emit several candidate masks, so
        # keep reducing to the first channel until we have a 2-D (H, W) array
        # (a bare .squeeze() could leave a 3-D tensor and break fromarray).
        mask = masks[0].squeeze()
        while mask.ndim > 2:
            mask = mask[0]
        mask_binary = (mask.numpy() > 0.0).astype(np.uint8) * 255

        # Encode the binary mask as a base64 PNG for the JSON response.
        out = io.BytesIO()
        Image.fromarray(mask_binary).save(out, format="PNG")
        mask_base64 = base64.b64encode(out.getvalue()).decode("utf-8")

        # Return in the expected format
        return [{"mask_png_base64": mask_base64, "num_masks": len(masks)}]