"""SAM3 segmenter: text-prompted instance segmentation via facebook/sam3."""

import logging
from typing import Optional, Sequence

import numpy as np
import torch
from PIL import Image
from transformers import Sam3Model, Sam3Processor

from .base import Segmenter, SegmentationResult

logger = logging.getLogger(__name__)


class SAM3Segmenter(Segmenter):
    """
    SAM3 (Segment Anything Model 3) segmenter.

    Performs text-prompted instance segmentation on images using the
    facebook/sam3 model from HuggingFace.
    """

    name = "sam3"
    supports_batch = True
    max_batch_size = 8

    def __init__(
        self,
        model_id: str = "facebook/sam3",
        device: Optional[str] = None,
        threshold: float = 0.5,
        mask_threshold: float = 0.5,
    ):
        """
        Initialize SAM3 segmenter.

        Args:
            model_id: HuggingFace model ID
            device: Device to run on (cuda/cpu), auto-detected if None
            threshold: Confidence threshold for filtering instances
            mask_threshold: Threshold for binarizing masks

        Raises:
            Exception: re-raised from HuggingFace if model/processor loading fails.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.threshold = threshold
        self.mask_threshold = mask_threshold

        logger.info("Loading SAM3 model %s on device %s", model_id, self.device)
        try:
            self.model = Sam3Model.from_pretrained(model_id).to(self.device)
            self.processor = Sam3Processor.from_pretrained(model_id)
            self.model.eval()
        except Exception:
            logger.exception("Failed to load SAM3 model")
            raise
        logger.info("SAM3 model loaded successfully")

    @staticmethod
    def _to_pil(frame: np.ndarray) -> Image.Image:
        """Convert an RGB numpy frame to a PIL image.

        uint8 frames are wrapped as-is; other dtypes are assumed to be in
        [0, 1] and scaled to uint8 (TODO confirm range against callers).
        """
        if frame.dtype == np.uint8:
            return Image.fromarray(frame)
        return Image.fromarray((frame * 255).astype(np.uint8))

    def _parse_single_result(self, results, frame_shape) -> SegmentationResult:
        """
        Convert one post-processed result dict into a SegmentationResult.

        Args:
            results: dict with optional "masks" (list of HxW tensors),
                "scores" and "boxes" tensors, as produced by
                ``post_process_instance_segmentation``.
            frame_shape: (H, W, ...) of the source frame, used to size the
                empty mask array when nothing was detected.

        Returns:
            SegmentationResult with masks stacked to (N, H, W); scores/boxes
            are None when absent from ``results``.
        """
        masks = results.get("masks", [])
        scores = results.get("scores", None)
        boxes = results.get("boxes", None)

        if len(masks) > 0:
            # Stack masks: list of (H, W) -> (N, H, W)
            masks_array = np.stack([m.cpu().numpy() for m in masks])
        else:
            # No objects detected: empty (0, H, W) boolean array
            masks_array = np.zeros((0, frame_shape[0], frame_shape[1]), dtype=bool)

        scores_array = scores.cpu().numpy() if scores is not None else None
        boxes_array = boxes.cpu().numpy() if boxes is not None else None

        return SegmentationResult(
            masks=masks_array,
            scores=scores_array,
            boxes=boxes_array,
        )

    def _expand_inputs_if_needed(self, inputs) -> None:
        """
        Expand vision inputs in place so the image batch matches the text batch.

        Handles:
          1. 1 image, N texts      -> expand 1 -> N
          2. N images, N*M texts   -> expand N -> N*M

        When expansion applies, vision features are computed once for the
        original images and repeated with ``repeat_interleave``, replacing
        ``pixel_values`` with ``vision_embeds`` (the two are mutually
        exclusive model inputs). Per-image metadata tensors are expanded to
        match. No-op when no expansion is needed.
        """
        pixel_values = inputs.get("pixel_values")
        input_ids = inputs.get("input_ids")
        if pixel_values is None or input_ids is None:
            return

        img_batch = pixel_values.shape[0]
        text_batch = input_ids.shape[0]

        # Decide the expansion factor; bail out early when batches already match
        # or the text batch is not an exact multiple of the image batch.
        if img_batch == 1 and text_batch > 1:
            expansion_factor = text_batch
        elif img_batch > 1 and text_batch > img_batch and text_batch % img_batch == 0:
            expansion_factor = text_batch // img_batch
        else:
            return

        logger.debug(
            "Expanding SAM3 vision inputs from %d to %d (factor %d) using embeddings reuse.",
            img_batch,
            text_batch,
            expansion_factor,
        )

        # 1. Compute vision embeddings once for the original images.
        with torch.no_grad():
            vision_outputs = self.model.get_vision_features(pixel_values=pixel_values)

        # 2. Expand every batch-first tensor (or list/tuple of tensors) held by
        #    the model-output container along dim 0.
        for key in list(vision_outputs.keys()):
            value = getattr(vision_outputs, key, None)
            if value is None:
                # Fall back to item access; skip keys the container cannot serve.
                try:
                    value = vision_outputs[key]
                except (KeyError, IndexError, TypeError):
                    continue

            new_value = None
            if isinstance(value, torch.Tensor):
                # Only expand along the batch dimension (dim 0).
                if value.shape[0] == img_batch:
                    new_value = value.repeat_interleave(expansion_factor, dim=0)
            elif isinstance(value, (list, tuple)):
                expanded = []
                any_expanded = False
                for v in value:
                    if isinstance(v, torch.Tensor) and v.shape[0] == img_batch:
                        expanded.append(v.repeat_interleave(expansion_factor, dim=0))
                        any_expanded = True
                    else:
                        expanded.append(v)
                if any_expanded:
                    # Preserve the container type (list vs tuple).
                    new_value = type(value)(expanded)

            if new_value is not None:
                # ModelOutput containers may support item assignment, attribute
                # assignment, or both — write through every channel that works.
                try:
                    vision_outputs[key] = new_value
                except (KeyError, TypeError):
                    pass
                if hasattr(vision_outputs, key):
                    setattr(vision_outputs, key, new_value)

        # 3. Swap pixel_values for precomputed embeddings for the model call.
        inputs["vision_embeds"] = vision_outputs
        del inputs["pixel_values"]  # Mutually exclusive with vision_embeds

        # 4. Expand per-image metadata tensors to the new batch size.
        # NOTE(review): transformers SAM processors usually name the second key
        # "reshaped_input_sizes" — confirm "reshape_input_sizes" is correct here.
        for meta_key in ("original_sizes", "reshape_input_sizes"):
            if meta_key in inputs and inputs[meta_key].shape[0] == img_batch:
                inputs[meta_key] = inputs[meta_key].repeat_interleave(
                    expansion_factor, dim=0
                )

    def predict(
        self, frame: np.ndarray, text_prompts: Optional[list] = None
    ) -> SegmentationResult:
        """
        Run SAM3 segmentation on a frame.

        Args:
            frame: Input image (HxWx3 numpy array in RGB)
            text_prompts: List of text prompts for segmentation; defaults
                to ["object"] when empty or None.

        Returns:
            SegmentationResult with instance masks. On post-processing
            failure an empty result is returned (best-effort); inference
            RuntimeErrors are logged and re-raised.
        """
        pil_image = self._to_pil(frame)

        if not text_prompts:
            text_prompts = ["object"]

        inputs = self.processor(
            images=pil_image, text=text_prompts, return_tensors="pt"
        ).to(self.device)

        # Align vision batch with text batch if needed.
        self._expand_inputs_if_needed(inputs)

        try:
            if "pixel_values" in inputs:
                logger.debug(
                    "SAM3 Input pixel_values shape: %s", inputs["pixel_values"].shape
                )
            with torch.no_grad():
                outputs = self.model(**inputs)
        except RuntimeError:
            logger.exception(
                "RuntimeError during SAM3 inference; input keys: %s", list(inputs.keys())
            )
            if "pixel_values" in inputs:
                logger.error("Pixel values shape: %s", inputs["pixel_values"].shape)
            # Re-raise to let user know
            raise

        try:
            results = self.processor.post_process_instance_segmentation(
                outputs,
                threshold=self.threshold,
                mask_threshold=self.mask_threshold,
                target_sizes=inputs.get("original_sizes").tolist(),
            )[0]
            return self._parse_single_result(results, frame.shape)
        except Exception:
            # Best-effort: a post-processing failure yields an empty result
            # rather than crashing the caller.
            logger.exception("SAM3 post-processing failed")
            return SegmentationResult(
                masks=np.zeros((0, frame.shape[0], frame.shape[1]), dtype=bool),
                scores=None,
                boxes=None,
            )

    def predict_batch(
        self, frames: Sequence[np.ndarray], text_prompts: Optional[list] = None
    ) -> Sequence[SegmentationResult]:
        """
        Run SAM3 segmentation on a batch of frames.

        Each prompt is applied to every frame; prompts are flattened as
        [img1_p1, img1_p2, img2_p1, img2_p2, ...] and the vision batch is
        expanded accordingly.

        Args:
            frames: Sequence of HxWx3 RGB numpy arrays.
            text_prompts: Prompts applied to each frame; defaults to ["object"].

        Returns:
            One SegmentationResult per frame. On post-processing failure,
            empty results are returned for all frames (best-effort).
        """
        pil_images = [self._to_pil(f) for f in frames]

        prompts = list(text_prompts) if text_prompts else ["object"]
        # Replicate the prompt list once per frame.
        flattened_prompts = prompts * len(frames)

        inputs = self.processor(
            images=pil_images, text=flattened_prompts, return_tensors="pt"
        ).to(self.device)

        # Align vision batch with the flattened text batch.
        self._expand_inputs_if_needed(inputs)

        with torch.no_grad():
            outputs = self.model(**inputs)

        try:
            results_list = self.processor.post_process_instance_segmentation(
                outputs,
                threshold=self.threshold,
                mask_threshold=self.mask_threshold,
                target_sizes=inputs.get("original_sizes").tolist(),
            )
            return [
                self._parse_single_result(r, f.shape)
                for r, f in zip(results_list, frames)
            ]
        except Exception:
            logger.exception("SAM3 batch post-processing failed")
            return [
                SegmentationResult(
                    masks=np.zeros((0, f.shape[0], f.shape[1]), dtype=bool),
                    scores=None,
                    boxes=None,
                )
                for f in frames
            ]