from typing import Dict, Any
from transformers import Sam3Processor, Sam3Model
import torch
from PIL import Image
import io
import base64
import numpy as np


class EndpointHandler:
    """Inference Endpoints handler for text-prompted segmentation with SAM3."""

    def __init__(self, path=""):
        """
        Initialize the SAM3 model and processor for text-prompted segmentation.

        Args:
            path: Path to local model files (if deploying with custom weights)
                or empty string to use the default facebook/sam3 model
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Use local path if provided, otherwise use the default model
        model_id = path if path else "facebook/sam3"
        self.processor = Sam3Processor.from_pretrained(model_id)
        self.model = Sam3Model.from_pretrained(model_id).to(self.device)
        self.model.eval()

    @staticmethod
    def _decode_image(image_b64: str) -> Image.Image:
        """Decode a base64 string (optionally a data URI) into an RGB PIL image."""
        if image_b64.startswith("data:"):
            # Accept "data:image/...;base64,<payload>" by keeping only the payload.
            image_b64 = image_b64.partition(",")[2]
        image_data = base64.b64decode(image_b64)
        # `convert` returns a new image, so the source can be closed right away.
        with Image.open(io.BytesIO(image_data)) as img:
            return img.convert("RGB")

    @staticmethod
    def _mask_to_png_b64(mask: torch.Tensor) -> str:
        """Encode a single binary mask tensor as a base64 PNG string (0/255 grayscale)."""
        mask_np = mask.cpu().numpy().astype(np.uint8) * 255
        # Pillow infers mode "L" from a 2-D uint8 array; passing `mode=` is deprecated.
        # NOTE(review): assumes each mask is 2-D (H, W) — confirm against the processor output.
        mask_img = Image.fromarray(mask_np)
        buffer = io.BytesIO()
        mask_img.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an image with a text prompt and return segmentation masks.

        Expected input format:
        {
            "inputs": {
                "image": "<base64-encoded image, optionally a data: URI>",
                "prompt": "text description of object to segment",  # e.g., "a red car"
                "threshold": 0.5,       # optional
                "mask_threshold": 0.5   # optional
            }
        }
        If no "inputs" key is present, ``data`` itself is read as the inputs object.

        Returns:
        {
            "masks": [...],      # List of binary masks as base64 encoded PNGs
            "boxes": [...],      # Bounding boxes in xyxy format
            "scores": [...],     # Confidence scores
            "num_objects": int   # Number of returned masks
        }
        or {"error": "<message>"} when the request is invalid.
        """
        # 1. Extract inputs (read, don't pop, so the caller's payload is not mutated)
        inputs = data.get("inputs", data)
        if not isinstance(inputs, dict):
            return {"error": "Expected 'inputs' to be an object with 'image' and 'prompt' fields."}

        image_b64 = inputs.get("image")
        text_prompt = inputs.get("prompt", None)

        if not image_b64:
            return {"error": "No image provided. Please provide a base64-encoded image."}

        if not text_prompt:
            return {"error": "No text prompt provided. Please provide a 'prompt' field."}

        # Optional parameters; JSON clients may send them as strings, so coerce explicitly.
        try:
            threshold = float(inputs.get("threshold", 0.5))
            mask_threshold = float(inputs.get("mask_threshold", 0.5))
        except (TypeError, ValueError):
            return {"error": "'threshold' and 'mask_threshold' must be numbers."}

        # 2. Decode image (request boundary: any decode failure becomes an error payload)
        try:
            image = self._decode_image(image_b64)
        except Exception as e:
            return {"error": f"Failed to decode image: {str(e)}"}

        # 3. Process inputs with text prompt
        processor_inputs = self.processor(
            images=image,
            text=text_prompt,
            return_tensors="pt"
        ).to(self.device)

        # 4. Run Inference
        with torch.no_grad():
            outputs = self.model(**processor_inputs)

        # 5. Post-process results; target_sizes rescales masks/boxes to the input resolution.
        results = self.processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=processor_inputs.get("original_sizes").tolist()
        )[0]  # single image per request

        # 6. Format response
        response = {
            "masks": [self._mask_to_png_b64(mask) for mask in results["masks"]],
            "boxes": [],
            "scores": [],
        }

        if "boxes" in results:
            response["boxes"] = results["boxes"].cpu().tolist()

        if "scores" in results:
            response["scores"] = results["scores"].cpu().tolist()

        response["num_objects"] = len(response["masks"])

        return response