# NOTE: upload metadata from the original Hugging Face commit
# (raheebhassan, 398659b, "Add code + LFS attributes") — kept as a comment so
# the module remains importable.
"""SAM3-style promptable segmentation.
This integrates a prompt-driven segmentation method into the existing
class-based segmentation interface. Instead of relying on a fixed class
vocabulary, it accepts natural-language prompts (e.g., "a red car", "the person").
Implementation approach (lightweight, no custom training):
- Use a text-conditioned detector (OWL-ViT) to propose bounding boxes from text.
- Use SAM (Segment Anything) to convert boxes into masks.
Notes:
- This is not "SAM 3" in the sense of an official model release; it is a
prompt-to-mask pipeline exposed as a single segmenter named "sam3".
- If required dependencies/models are missing, this segmenter raises a clear
error message.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import torch
from PIL import Image
from .base import BaseSegmenter
from model_cache import hf_cache_dir, ensure_default_checkpoint_dirs
@dataclass
class _SAM3Config:
detector_model: str = "google/owlvit-base-patch32"
sam_model: str = "facebook/sam-vit-base"
box_threshold: float = 0.02
max_boxes: int = 5
class SAM3Segmenter(BaseSegmenter):
    """Prompt-driven segmentation via (text detector → SAM).

    Use `target_classes` to pass natural language prompts:
      - `['car']`, `['a car']`, `['the person']`, etc.

    Returns a binary mask (H, W) with 1 for predicted ROI.
    """

    def __init__(
        self,
        device: str = "cuda",
        detector_model: str = _SAM3Config.detector_model,
        sam_model: str = _SAM3Config.sam_model,
        box_threshold: float = _SAM3Config.box_threshold,
        max_boxes: int = _SAM3Config.max_boxes,
        **kwargs,
    ):
        """Configure the pipeline; models are loaded lazily by `load_model()`.

        Args:
            device: Torch device string, e.g. "cuda", "cuda:1", or "cpu".
            detector_model: HF id of the zero-shot (text-conditioned) detector.
            sam_model: HF id of the SAM checkpoint used to turn boxes into masks.
            box_threshold: Minimum detector score for a box to be kept.
            max_boxes: Maximum number of boxes forwarded to SAM.
            **kwargs: Forwarded to `BaseSegmenter`.
        """
        super().__init__(device=device, **kwargs)
        self.detector_model_name = detector_model
        self.sam_model_name = sam_model
        self.box_threshold = float(box_threshold)
        self.max_boxes = int(max_boxes)
        # Populated by load_model(); None until then.
        self._detector = None
        self._sam_model = None
        self._sam_processor = None

    def load_model(self):
        """Load the detector pipeline and the SAM model/processor onto `self.device`.

        Raises:
            ImportError: if `transformers` (with SAM support) is not installed.
        """
        try:
            from transformers import pipeline, SamModel, SamProcessor
        except Exception as e:  # pragma: no cover
            raise ImportError(
                "SAM3Segmenter requires `transformers` with SAM support. "
                "Try: pip install -U transformers"
            ) from e
        # Make sure any HF downloads (including pipeline internals) land under `checkpoints/`.
        ensure_default_checkpoint_dirs()
        # HF pipelines take a CUDA ordinal (or -1 for CPU). Parse the ordinal
        # out of strings like "cuda:1" so multi-GPU device selection is honored
        # (previously any "cuda*" device silently mapped to GPU 0).
        if self.device.startswith("cuda") and torch.cuda.is_available():
            _, _, ordinal = self.device.partition(":")
            pipeline_device = int(ordinal) if ordinal else 0
        else:
            pipeline_device = -1
        self._detector = pipeline(
            task="zero-shot-object-detection",
            model=self.detector_model_name,
            device=pipeline_device,
        )
        cache_dir = str(hf_cache_dir())
        self._sam_processor = SamProcessor.from_pretrained(self.sam_model_name, cache_dir=cache_dir)
        self._sam_model = SamModel.from_pretrained(self.sam_model_name, cache_dir=cache_dir)
        self._sam_model = self._sam_model.to(self.device)
        self._sam_model.eval()
        # Keep BaseSegmenter.model set for consistency
        self.model = self._sam_model

    def segment(
        self,
        image: Image.Image,
        target_classes: Optional[List[str]] = None,
        **kwargs,
    ) -> np.ndarray:
        """Segment regions of `image` matching the given text prompts.

        Args:
            image: Input PIL image.
            target_classes: Free-form text prompts; falls back to ["object"].
            **kwargs: Optional per-call overrides `box_threshold` and `max_boxes`.

        Returns:
            float32 array of shape (H, W) with 1.0 inside predicted regions;
            all zeros when no detection clears the threshold.
        """
        self.ensure_loaded()
        prompts = self._normalize_prompts(target_classes)
        box_threshold = float(kwargs.get("box_threshold", self.box_threshold))
        max_boxes = int(kwargs.get("max_boxes", self.max_boxes))
        boxes = self._detect_boxes(image, prompts, box_threshold, max_boxes)
        if not boxes:
            return np.zeros((image.height, image.width), dtype=np.float32)
        return self._boxes_to_mask(image, boxes)

    @staticmethod
    def _normalize_prompts(target_classes: Optional[List[str]]) -> List[str]:
        """Coerce user-supplied classes into non-empty text prompts ("object" fallback)."""
        if not target_classes:
            return ["object"]
        # Treat provided "classes" as free-form text prompts.
        prompts = [str(p).strip() for p in target_classes if str(p).strip()]
        return prompts or ["object"]

    def _detect_boxes(
        self,
        image: Image.Image,
        prompts: List[str],
        box_threshold: float,
        max_boxes: int,
    ) -> List[List[float]]:
        """Run the text-conditioned detector; return up to `max_boxes` boxes.

        Boxes are `[xmin, ymin, xmax, ymax]` in pixel coordinates, clamped to
        the image and guaranteed at least 1px wide/tall.
        """
        detections = self._detector(image, candidate_labels=prompts)
        # HF pipeline may return dict (single) or list
        if isinstance(detections, dict):
            detections = [detections]
        # Order by confidence explicitly (the pipeline already sorts, but do
        # not rely on it) so truncation keeps the highest-scoring boxes.
        detections = sorted(
            detections, key=lambda d: float(d.get("score", 0.0)), reverse=True
        )
        boxes: List[List[float]] = []
        for det in detections:
            if float(det.get("score", 0.0)) < box_threshold:
                continue
            b = det.get("box") or {}
            xmin = float(b.get("xmin", 0.0))
            ymin = float(b.get("ymin", 0.0))
            xmax = float(b.get("xmax", 0.0))
            ymax = float(b.get("ymax", 0.0))
            # Sanity clamp: non-negative, ≥1px extent, and inside the image
            # (out-of-bounds detector boxes previously leaked through to SAM).
            xmin, ymin = max(0.0, xmin), max(0.0, ymin)
            xmax = min(float(image.width), max(xmin + 1.0, xmax))
            ymax = min(float(image.height), max(ymin + 1.0, ymax))
            boxes.append([xmin, ymin, xmax, ymax])
            if len(boxes) >= max_boxes:
                break
        return boxes

    def _boxes_to_mask(self, image: Image.Image, boxes: List[List[float]]) -> np.ndarray:
        """Prompt SAM with `boxes` and union the resulting masks into one (H, W) mask."""
        # SAM expects a batch; provide one image with N boxes
        inputs = self._sam_processor(
            image,
            input_boxes=[boxes],
            return_tensors="pt",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self._sam_model(**inputs)
        # Post-process masks back to original image size.
        # Returns list (batch) of tensors: [num_boxes, H, W]
        post = self._sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.detach().cpu(),
            inputs["original_sizes"].detach().cpu(),
            inputs["reshaped_input_sizes"].detach().cpu(),
        )
        masks0 = post[0]
        if isinstance(masks0, (list, tuple)):
            # Defensive: some versions may nest
            masks0 = torch.stack(
                [m.squeeze(0) if m.ndim == 3 else m for m in masks0], dim=0
            )
        # masks0: [num_boxes, H, W] or [num_boxes, 1, H, W] — keep first mask per box
        if masks0.ndim == 4:
            masks0 = masks0[:, 0]
        # Union across boxes → single binary float32 mask.
        combined = (masks0 > 0.5).any(dim=0).to(torch.float32)
        return combined.numpy()

    def get_available_classes(self) -> Union[List[str], dict]:
        """Prompt-based model: there is no fixed class vocabulary."""
        return []

    def get_default_classes(self) -> List[str]:
        """Fallback prompt used when the caller supplies none."""
        return ["object"]