| """SAM3-style promptable segmentation. |
| |
| This integrates a prompt-driven segmentation method into the existing |
| class-based segmentation interface. Instead of relying on a fixed class |
| vocabulary, it accepts natural-language prompts (e.g., "a red car", "the person"). |
| |
| Implementation approach (lightweight, no custom training): |
| - Use a text-conditioned detector (OWL-ViT) to propose bounding boxes from text. |
| - Use SAM (Segment Anything) to convert boxes into masks. |
| |
| Notes: |
| - This is not "SAM 3" in the sense of an official model release; it is a |
| prompt-to-mask pipeline exposed as a single segmenter named "sam3". |
| - If required dependencies/models are missing, this segmenter raises a clear |
| error message. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import List, Optional, Union |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
|
|
| from .base import BaseSegmenter |
| from model_cache import hf_cache_dir, ensure_default_checkpoint_dirs |
|
|
|
|
| @dataclass |
| class _SAM3Config: |
| detector_model: str = "google/owlvit-base-patch32" |
| sam_model: str = "facebook/sam-vit-base" |
| box_threshold: float = 0.02 |
| max_boxes: int = 5 |
|
|
|
|
class SAM3Segmenter(BaseSegmenter):
    """Prompt-driven segmentation via (text detector → SAM).

    Pipeline: an OWL-ViT zero-shot object detector turns free-form text
    prompts into bounding boxes, and SAM converts those boxes into masks.

    Use `target_classes` to pass natural language prompts:
      - `['car']`, `['a car']`, `['the person']`, etc.

    Returns a binary mask (H, W) with 1 for predicted ROI.
    """

    def __init__(
        self,
        device: str = "cuda",
        detector_model: str = _SAM3Config.detector_model,
        sam_model: str = _SAM3Config.sam_model,
        box_threshold: float = _SAM3Config.box_threshold,
        max_boxes: int = _SAM3Config.max_boxes,
        **kwargs,
    ):
        """Configure the segmenter; no models are loaded until `load_model()`.

        Args:
            device: Torch device string (e.g. "cuda", "cpu").
            detector_model: HF model id for the zero-shot text detector.
            sam_model: HF model id for the SAM checkpoint.
            box_threshold: Minimum detector confidence for a box to be kept.
            max_boxes: Maximum number of (highest-scoring) boxes sent to SAM.
            **kwargs: Forwarded to `BaseSegmenter.__init__`.
        """
        super().__init__(device=device, **kwargs)
        self.detector_model_name = detector_model
        self.sam_model_name = sam_model
        self.box_threshold = float(box_threshold)
        self.max_boxes = int(max_boxes)

        # Lazily populated by load_model().
        self._detector = None
        self._sam_model = None
        self._sam_processor = None

    def load_model(self):
        """Instantiate the detector pipeline and the SAM model/processor.

        Raises:
            ImportError: if `transformers` (with SAM support) is unavailable.
        """
        try:
            from transformers import pipeline, SamModel, SamProcessor
        except Exception as e:
            raise ImportError(
                "SAM3Segmenter requires `transformers` with SAM support. "
                "Try: pip install -U transformers"
            ) from e

        ensure_default_checkpoint_dirs()

        # The HF pipeline API takes a device index: 0 = first GPU, -1 = CPU.
        if self.device.startswith("cuda") and torch.cuda.is_available():
            pipeline_device = 0
        else:
            pipeline_device = -1

        self._detector = pipeline(
            task="zero-shot-object-detection",
            model=self.detector_model_name,
            device=pipeline_device,
        )

        cache_dir = str(hf_cache_dir())

        self._sam_processor = SamProcessor.from_pretrained(self.sam_model_name, cache_dir=cache_dir)
        self._sam_model = SamModel.from_pretrained(self.sam_model_name, cache_dir=cache_dir)
        self._sam_model = self._sam_model.to(self.device)
        self._sam_model.eval()

        # Expose the SAM model as `self.model` so BaseSegmenter bookkeeping
        # (e.g. ensure_loaded) sees a loaded model.
        self.model = self._sam_model

    def segment(
        self,
        image: Image.Image,
        target_classes: Optional[List[str]] = None,
        **kwargs,
    ) -> np.ndarray:
        """Segment `image` according to text prompts.

        Args:
            image: Input PIL image.
            target_classes: Natural-language prompts; falls back to ["object"]
                when empty or missing.
            **kwargs: Optional per-call overrides `box_threshold`, `max_boxes`.

        Returns:
            float32 array of shape (H, W) with 1.0 inside the predicted ROI
            (all zeros when nothing is detected above threshold).
        """
        self.ensure_loaded()

        prompts: List[str]
        if target_classes is None or len(target_classes) == 0:
            prompts = ["object"]
        else:
            # Drop empty/whitespace-only prompts.
            prompts = [str(p).strip() for p in target_classes if str(p).strip()]
            if not prompts:
                prompts = ["object"]

        box_threshold = float(kwargs.get("box_threshold", self.box_threshold))
        max_boxes = int(kwargs.get("max_boxes", self.max_boxes))

        detections = self._detector(image, candidate_labels=prompts)

        # Single-image pipelines may return a bare dict for one detection.
        if isinstance(detections, dict):
            detections = [detections]

        # Keep (score, box) pairs so we can rank before capping at max_boxes.
        scored_boxes: List[tuple] = []
        for det in detections:
            score = float(det.get("score", 0.0))
            if score < box_threshold:
                continue
            b = det.get("box") or {}
            xmin = float(b.get("xmin", 0.0))
            ymin = float(b.get("ymin", 0.0))
            xmax = float(b.get("xmax", 0.0))
            ymax = float(b.get("ymax", 0.0))
            # Clamp to non-negative coords and force a minimum 1px extent so
            # SAM never receives a degenerate box.
            xmin, ymin = max(0.0, xmin), max(0.0, ymin)
            xmax, ymax = max(xmin + 1.0, xmax), max(ymin + 1.0, ymax)
            scored_boxes.append((score, [xmin, ymin, xmax, ymax]))

        if not scored_boxes:
            return np.zeros((image.height, image.width), dtype=np.float32)

        # BUGFIX: pipeline output order is not guaranteed to be score-sorted
        # (especially when several candidate labels are queried), so sort by
        # confidence before truncating — otherwise the cap could keep weak
        # detections and drop strong ones.
        scored_boxes.sort(key=lambda sb: sb[0], reverse=True)
        boxes = [box for _, box in scored_boxes[:max_boxes]]

        inputs = self._sam_processor(
            image,
            input_boxes=[boxes],
            return_tensors="pt",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self._sam_model(**inputs)

        # Upscale the low-res mask logits back to the original image size.
        post = self._sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.detach().cpu(),
            inputs["original_sizes"].detach().cpu(),
            inputs["reshaped_input_sizes"].detach().cpu(),
        )

        masks0 = post[0]
        if isinstance(masks0, (list, tuple)):
            # Normalize list-of-tensors output into one stacked tensor.
            masks0 = torch.stack([m.squeeze(0) if m.ndim == 3 else m for m in masks0], dim=0)

        # NOTE(review): SAM emits multiple mask candidates per box; index 0 is
        # used here — selecting by `outputs.iou_scores` may give better masks.
        if masks0.ndim == 4:
            masks0 = masks0[:, 0]

        # Union of all per-box masks, as a float32 {0, 1} array.
        combined = (masks0 > 0.5).any(dim=0).to(torch.float32)
        return combined.numpy()

    def get_available_classes(self) -> Union[List[str], dict]:
        # Open vocabulary: there is no fixed class list.
        return []

    def get_default_classes(self) -> List[str]:
        return ["object"]
|
|