# NOTE(review): the three lines below were extraction residue from a
# HuggingFace Spaces page header ("Spaces: Sleeping"), not source code;
# kept here as a comment so the file parses.
import logging
from typing import Optional

import numpy as np
import torch
from PIL import Image
from transformers import Sam3Model, Sam3Processor

from .base import SegmentationResult, Segmenter
class SAM3Segmenter(Segmenter):
    """
    SAM3 (Segment Anything Model 3) segmenter.

    Performs instance segmentation on images, driven by text prompts
    (a generic ``"object"`` prompt is used when none are supplied).
    Uses the facebook/sam3 model from HuggingFace.
    """

    name = "sam3"

    def __init__(
        self,
        model_id: str = "facebook/sam3",
        device: Optional[str] = None,
        threshold: float = 0.5,
        mask_threshold: float = 0.5,
    ):
        """
        Initialize SAM3 segmenter.

        Args:
            model_id: HuggingFace model ID
            device: Device to run on (cuda/cpu), auto-detected if None
            threshold: Confidence threshold for filtering instances
            mask_threshold: Threshold for binarizing masks

        Raises:
            Exception: re-raises whatever ``from_pretrained`` raised if the
                model or processor fails to load (after logging it).
        """
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.threshold = threshold
        self.mask_threshold = mask_threshold
        logging.info(
            "Loading SAM3 model %s on device %s", model_id, self.device
        )
        try:
            self.model = Sam3Model.from_pretrained(model_id).to(self.device)
            self.processor = Sam3Processor.from_pretrained(model_id)
            # Inference only: disable dropout / training-mode layers.
            self.model.eval()
        except Exception:
            logging.exception("Failed to load SAM3 model")
            raise
        logging.info("SAM3 model loaded successfully")

    @staticmethod
    def _to_pil(frame: np.ndarray) -> Image.Image:
        """
        Convert an RGB HxWx3 numpy frame to a PIL image.

        uint8 frames are used as-is. Non-uint8 frames are assumed to be
        floats in [0, 1] and rescaled to 0-255; values are clipped first
        so out-of-range floats cannot wrap around under the uint8 cast.
        """
        if frame.dtype == np.uint8:
            return Image.fromarray(frame)
        scaled = np.clip(frame * 255, 0, 255).astype(np.uint8)
        return Image.fromarray(scaled)

    @staticmethod
    def _empty_result(frame: np.ndarray) -> SegmentationResult:
        """Return a SegmentationResult with zero instances for this frame."""
        return SegmentationResult(
            masks=np.zeros((0, frame.shape[0], frame.shape[1]), dtype=bool),
            scores=None,
            boxes=None,
        )

    def predict(
        self, frame: np.ndarray, text_prompts: Optional[list] = None
    ) -> SegmentationResult:
        """
        Run SAM3 segmentation on a frame.

        Args:
            frame: Input image (HxWx3 numpy array in RGB)
            text_prompts: List of text prompts for segmentation; a default
                ["object"] prompt is used when None or empty.

        Returns:
            SegmentationResult with instance masks (empty result on
            post-processing failure — the error is logged, not raised).
        """
        pil_image = self._to_pil(frame)

        # Use default prompts if none provided (None or empty list).
        if not text_prompts:
            text_prompts = ["object"]

        # Process image with text prompts.
        inputs = self.processor(
            images=pil_image, text=text_prompts, return_tensors="pt"
        ).to(self.device)

        # Run inference without autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Post-process to get instance masks.
        try:
            # The processor normally records the pre-resize image size under
            # "original_sizes" (NOTE(review): key name per transformers SAM
            # processors — confirm for Sam3Processor). If it is absent, fall
            # back to the frame's own (H, W), which IS the original size,
            # instead of crashing on None.tolist().
            original_sizes = inputs.get("original_sizes")
            if original_sizes is not None:
                target_sizes = original_sizes.tolist()
            else:
                target_sizes = [(frame.shape[0], frame.shape[1])]

            results = self.processor.post_process_instance_segmentation(
                outputs,
                threshold=self.threshold,
                mask_threshold=self.mask_threshold,
                target_sizes=target_sizes,
            )[0]

            # Extract results; all three keys are optional in the output.
            masks = results.get("masks", [])
            scores = results.get("scores", None)
            boxes = results.get("boxes", None)

            if len(masks) > 0:
                # Stack masks: iterable of (H, W) tensors -> (N, H, W) array.
                masks_array = np.stack([m.cpu().numpy() for m in masks])
            else:
                # No objects detected.
                masks_array = np.zeros(
                    (0, frame.shape[0], frame.shape[1]), dtype=bool
                )

            scores_array = (
                scores.cpu().numpy() if scores is not None else None
            )
            boxes_array = (
                boxes.cpu().numpy() if boxes is not None else None
            )

            return SegmentationResult(
                masks=masks_array,
                scores=scores_array,
                boxes=boxes_array,
            )
        except Exception:
            # Deliberate best-effort: a post-processing failure yields an
            # empty detection set rather than aborting the caller's loop.
            logging.exception("SAM3 post-processing failed")
            return self._empty_result(frame)