Spaces:
Sleeping
Sleeping
| import logging | |
| from typing import Optional, Sequence | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from transformers import Sam3Model, Sam3Processor | |
| from .base import Segmenter, SegmentationResult | |
class SAM3Segmenter(Segmenter):
    """
    SAM3 (Segment Anything Model 3) segmenter.

    Performs text-prompted instance segmentation on images.
    Uses the facebook/sam3 model from HuggingFace.
    """

    name = "sam3"
    # Capability attributes grouped with `name` so the Segmenter contract
    # is visible in one place.
    supports_batch = True
    max_batch_size = 8

    def __init__(
        self,
        model_id: str = "facebook/sam3",
        device: Optional[str] = None,
        threshold: float = 0.5,
        mask_threshold: float = 0.5,
    ):
        """
        Initialize SAM3 segmenter.

        Args:
            model_id: HuggingFace model ID
            device: Device to run on (cuda/cpu), auto-detected if None
            threshold: Confidence threshold for filtering instances
            mask_threshold: Threshold for binarizing masks

        Raises:
            Exception: re-raised (after logging) when model/processor
                loading fails.
        """
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.threshold = threshold
        self.mask_threshold = mask_threshold
        logging.info(
            "Loading SAM3 model %s on device %s", model_id, self.device
        )
        try:
            self.model = Sam3Model.from_pretrained(model_id).to(self.device)
            self.processor = Sam3Processor.from_pretrained(model_id)
            self.model.eval()
        except Exception:
            logging.exception("Failed to load SAM3 model")
            raise
        logging.info("SAM3 model loaded successfully")

    @staticmethod
    def _to_pil(frame: np.ndarray) -> "Image.Image":
        """Convert an RGB numpy frame to a PIL image.

        Non-uint8 frames are assumed to hold floats in [0, 1] and are
        scaled to uint8 -- NOTE(review): confirm this range with callers.
        """
        if frame.dtype == np.uint8:
            return Image.fromarray(frame)
        return Image.fromarray((frame * 255).astype(np.uint8))

    @staticmethod
    def _empty_result(frame_shape) -> SegmentationResult:
        """Build a zero-instance SegmentationResult sized for one frame."""
        return SegmentationResult(
            masks=np.zeros((0, frame_shape[0], frame_shape[1]), dtype=bool),
            scores=None,
            boxes=None,
        )

    def _parse_single_result(self, results, frame_shape) -> SegmentationResult:
        """
        Convert one post-processed result dict into a SegmentationResult.

        Args:
            results: dict with optional "masks" (iterable of (H, W)
                tensors), "scores" and "boxes" tensors, as produced by
                the processor's instance-segmentation post-processing.
            frame_shape: (H, W, ...) shape of the source frame, used to
                size the empty mask array when nothing was detected.
        """
        masks = results.get("masks", [])
        scores = results.get("scores", None)
        boxes = results.get("boxes", None)

        if len(masks) > 0:
            # Stack masks: list of (H, W) -> (N, H, W)
            masks_array = np.stack([m.cpu().numpy() for m in masks])
        else:
            # No objects detected: keep a well-shaped empty array.
            masks_array = np.zeros(
                (0, frame_shape[0], frame_shape[1]), dtype=bool
            )

        scores_array = scores.cpu().numpy() if scores is not None else None
        boxes_array = boxes.cpu().numpy() if boxes is not None else None
        return SegmentationResult(
            masks=masks_array,
            scores=scores_array,
            boxes=boxes_array,
        )

    def _expand_inputs_if_needed(self, inputs):
        """
        Expand vision inputs in-place so the image batch matches the text batch.

        Handles:
            1. 1 image, N texts (expand 1 -> N)
            2. N images, N*M texts (expand N -> N*M)

        The image encoder runs once on the original images; the resulting
        embeddings are repeated instead of re-encoding duplicated pixels.
        """
        pixel_values = inputs.get("pixel_values")
        input_ids = inputs.get("input_ids")
        if pixel_values is None or input_ids is None:
            return

        img_batch = pixel_values.shape[0]
        text_batch = input_ids.shape[0]

        # Decide whether (and by what factor) expansion is needed.
        if img_batch == 1 and text_batch > 1:
            expansion_factor = text_batch
        elif img_batch > 1 and text_batch > img_batch and text_batch % img_batch == 0:
            expansion_factor = text_batch // img_batch
        else:
            return

        logging.debug(
            "Expanding SAM3 vision inputs from %d to %d (factor %d) using embeddings reuse.",
            img_batch,
            text_batch,
            expansion_factor,
        )

        # 1. Compute vision embeddings once for the original images.
        with torch.no_grad():
            vision_outputs = self.model.get_vision_features(
                pixel_values=pixel_values
            )

        # 2. Repeat every batch-first tensor in the vision outputs.
        for key in list(vision_outputs.keys()):
            value = getattr(vision_outputs, key, None)
            if value is None:
                # Fall back to item access for dict-like ModelOutput keys.
                try:
                    value = vision_outputs[key]
                except Exception:
                    continue

            new_value = None
            if isinstance(value, torch.Tensor):
                # Only expand the batch dimension (dim 0).
                if value.shape[0] == img_batch:
                    new_value = value.repeat_interleave(expansion_factor, dim=0)
            elif isinstance(value, (list, tuple)):
                expanded = []
                any_expanded = False
                for v in value:
                    if isinstance(v, torch.Tensor) and v.shape[0] == img_batch:
                        expanded.append(
                            v.repeat_interleave(expansion_factor, dim=0)
                        )
                        any_expanded = True
                    else:
                        expanded.append(v)
                if any_expanded:
                    # Preserve the container type (list vs tuple).
                    new_value = type(value)(expanded)

            if new_value is not None:
                # Best-effort: some ModelOutput containers reject __setitem__.
                try:
                    vision_outputs[key] = new_value
                except Exception:
                    pass
                if hasattr(vision_outputs, key):
                    setattr(vision_outputs, key, new_value)

        # 3. Swap pixel_values for precomputed embeddings
        #    (mutually exclusive model inputs).
        inputs["vision_embeds"] = vision_outputs
        del inputs["pixel_values"]

        # 4. Expand per-image metadata to match the text batch.
        # NOTE(review): HF SAM processors spell this key
        # "reshaped_input_sizes"; both spellings are checked in case SAM3
        # differs -- confirm against the installed transformers version.
        for size_key in (
            "original_sizes",
            "reshape_input_sizes",
            "reshaped_input_sizes",
        ):
            sizes = inputs.get(size_key)
            if sizes is not None and sizes.shape[0] == img_batch:
                inputs[size_key] = sizes.repeat_interleave(
                    expansion_factor, dim=0
                )

    def predict(self, frame: np.ndarray, text_prompts: Optional[list] = None) -> SegmentationResult:
        """
        Run SAM3 segmentation on a frame.

        Args:
            frame: Input image (HxWx3 numpy array in RGB)
            text_prompts: List of text prompts for segmentation; defaults
                to ["object"] when None or empty.

        Returns:
            SegmentationResult with instance masks (empty result if
            post-processing fails).

        Raises:
            RuntimeError: propagated from model inference after logging
                diagnostic input shapes.
        """
        pil_image = self._to_pil(frame)

        # Use default prompts if none provided
        if not text_prompts:
            text_prompts = ["object"]

        # Process image with text prompts
        inputs = self.processor(
            images=pil_image, text=text_prompts, return_tensors="pt"
        ).to(self.device)

        # Handle batch expansion (1 image vs N prompts)
        self._expand_inputs_if_needed(inputs)

        # Run inference
        try:
            if "pixel_values" in inputs:
                logging.debug(
                    "SAM3 Input pixel_values shape: %s",
                    inputs["pixel_values"].shape,
                )
            with torch.no_grad():
                outputs = self.model(**inputs)
        except RuntimeError as e:
            logging.error("RuntimeError during SAM3 inference: %s", e)
            logging.error("Input keys: %s", inputs.keys())
            if "pixel_values" in inputs:
                logging.error(
                    "Pixel values shape: %s", inputs["pixel_values"].shape
                )
            # Re-raise to let user know
            raise

        # Post-process to get instance masks; any failure (including a
        # missing "original_sizes" entry) degrades to an empty result.
        try:
            results = self.processor.post_process_instance_segmentation(
                outputs,
                threshold=self.threshold,
                mask_threshold=self.mask_threshold,
                target_sizes=inputs.get("original_sizes").tolist(),
            )[0]
            return self._parse_single_result(results, frame.shape)
        except Exception:
            logging.exception("SAM3 post-processing failed")
            return self._empty_result(frame.shape)

    def predict_batch(self, frames: Sequence[np.ndarray], text_prompts: Optional[list] = None) -> Sequence[SegmentationResult]:
        """
        Run SAM3 segmentation on a batch of frames.

        Args:
            frames: Sequence of HxWx3 RGB numpy arrays.
            text_prompts: Prompts applied to every frame; defaults to
                ["object"] when None or empty.

        Returns:
            One SegmentationResult per frame (all empty if batch
            post-processing fails).
        """
        pil_images = [self._to_pil(f) for f in frames]
        prompts = text_prompts or ["object"]

        # Flatten prompts for all images: [img1_p1, img1_p2, img2_p1, img2_p2, ...]
        flattened_prompts = [p for _ in frames for p in prompts]

        inputs = self.processor(
            images=pil_images, text=flattened_prompts, return_tensors="pt"
        ).to(self.device)

        # Handle batch expansion (N images vs N*M prompts)
        self._expand_inputs_if_needed(inputs)

        with torch.no_grad():
            outputs = self.model(**inputs)

        try:
            results_list = self.processor.post_process_instance_segmentation(
                outputs,
                threshold=self.threshold,
                mask_threshold=self.mask_threshold,
                target_sizes=inputs.get("original_sizes").tolist(),
            )
            return [
                self._parse_single_result(r, f.shape)
                for r, f in zip(results_list, frames)
            ]
        except Exception:
            logging.exception("SAM3 batch post-processing failed")
            return [self._empty_result(f.shape) for f in frames]