| from typing import Dict, Any |
| from transformers import Sam3Processor, Sam3Model |
| import torch |
| from PIL import Image |
| import io |
| import base64 |
| import numpy as np |
|
|
|
|
class EndpointHandler:
    """Inference handler that segments objects in an image from a text prompt using SAM3."""

    def __init__(self, path: str = ""):
        """
        Initialize the SAM3 model and processor for text-prompted segmentation.

        Args:
            path: Path to local model files (if deploying with custom weights)
                or empty string to use the default facebook/sam3 model
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        model_id = path if path else "facebook/sam3"

        self.processor = Sam3Processor.from_pretrained(model_id)
        self.model = Sam3Model.from_pretrained(model_id).to(self.device)
        # Inference only: put the model in evaluation mode.
        self.model.eval()

    @staticmethod
    def _strip_data_uri(image_b64: Any) -> Any:
        """Return the base64 payload, dropping a leading ``data:<mime>;base64,`` prefix if present."""
        if isinstance(image_b64, str) and image_b64.startswith("data:"):
            _, _, payload = image_b64.partition(",")
            return payload
        return image_b64

    @staticmethod
    def _encode_mask(mask: Any) -> str:
        """Encode one mask tensor as a base64 PNG string.

        Assumes a binary (bool or 0/1) 2-D mask, as produced by the processor's
        ``mask_threshold`` step; foreground pixels become 255, background 0.
        """
        mask_np = mask.cpu().numpy().astype(np.uint8) * 255
        # Pillow infers mode "L" for a 2-D uint8 array; passing ``mode`` explicitly is deprecated.
        mask_img = Image.fromarray(mask_np)

        buffer = io.BytesIO()
        mask_img.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an image with a text prompt and return segmentation masks.

        Expected input format (the "inputs" wrapper is optional; a flat dict
        with the same keys is also accepted):
            {
                "inputs": {
                    "image": "<base64_encoded_image>",  # a "data:<mime>;base64," prefix is tolerated
                    "prompt": "text description of object to segment",  # e.g., "a red car"
                    "threshold": 0.5,       # optional; forwarded as `threshold` to post-processing
                    "mask_threshold": 0.5   # optional; forwarded as `mask_threshold`
                }
            }

        Returns:
            On success:
            {
                "masks": [...],       # List of binary masks as base64 encoded PNGs
                "boxes": [...],       # Bounding boxes in xyxy format
                "scores": [...],      # Confidence scores
                "num_objects": int    # Number of masks returned
            }
            On invalid input: {"error": "<message>"}

        Errors raised by the processor or the model are not caught here and
        propagate to the caller.
        """
        # Use get() rather than pop() so the caller's payload is left untouched.
        inputs = data.get("inputs", data)
        if not isinstance(inputs, dict):
            return {
                "error": "Invalid 'inputs': expected an object with 'image' and 'prompt' fields."
            }

        image_b64 = inputs.get("image")
        text_prompt = inputs.get("prompt", None)

        if not image_b64:
            return {"error": "No image provided. Please provide a base64-encoded image."}

        if not text_prompt:
            return {"error": "No text prompt provided. Please provide a 'prompt' field."}

        # JSON clients may send thresholds as strings; coerce them and reject invalid values early.
        try:
            threshold = float(inputs.get("threshold", 0.5))
            mask_threshold = float(inputs.get("mask_threshold", 0.5))
        except (TypeError, ValueError) as e:
            return {"error": f"Invalid threshold value: {e}"}

        # Request boundary: any decode/open failure is reported back to the client.
        try:
            image_data = base64.b64decode(self._strip_data_uri(image_b64))
            image = Image.open(io.BytesIO(image_data)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to decode image: {str(e)}"}

        processor_inputs = self.processor(
            images=image,
            text=text_prompt,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**processor_inputs)

        # Results for the single image in the batch; masks are resized to the original image size.
        results = self.processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=processor_inputs.get("original_sizes").tolist(),
        )[0]

        response = {
            "masks": [],
            "boxes": [],
            "scores": [],
        }

        masks = results.get("masks", [])
        if len(masks) > 0:
            response["masks"] = [self._encode_mask(mask) for mask in masks]

            if "boxes" in results:
                response["boxes"] = results["boxes"].cpu().tolist()

            if "scores" in results:
                response["scores"] = results["scores"].cpu().tolist()

        response["num_objects"] = len(response["masks"])

        return response
|
|