# ISR/models/segmenters/sam3.py
# Author: Zhen Ye
import logging
from typing import Optional
import numpy as np
import torch
from PIL import Image
from transformers import Sam3Model, Sam3Processor
from .base import Segmenter, SegmentationResult
class SAM3Segmenter(Segmenter):
    """
    SAM3 (Segment Anything Model 3) segmenter.

    Performs text-promptable instance segmentation on images using the
    facebook/sam3 checkpoint from HuggingFace. When no prompts are given,
    a generic "object" prompt is used.
    """

    name = "sam3"

    def __init__(
        self,
        model_id: str = "facebook/sam3",
        device: Optional[str] = None,
        threshold: float = 0.5,
        mask_threshold: float = 0.5,
    ):
        """
        Initialize the SAM3 segmenter and load model weights.

        Args:
            model_id: HuggingFace model ID.
            device: Device to run on ("cuda"/"cpu"); auto-detected if None.
            threshold: Confidence threshold for filtering instances.
            mask_threshold: Threshold for binarizing predicted masks.

        Raises:
            Exception: re-raised from transformers when the model or
                processor cannot be loaded (logged first).
        """
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.threshold = threshold
        self.mask_threshold = mask_threshold

        logging.info(
            "Loading SAM3 model %s on device %s", model_id, self.device
        )
        try:
            self.model = Sam3Model.from_pretrained(model_id).to(self.device)
            self.processor = Sam3Processor.from_pretrained(model_id)
            # Inference only: disable dropout / training-mode layers.
            self.model.eval()
        except Exception:
            logging.exception("Failed to load SAM3 model")
            raise
        logging.info("SAM3 model loaded successfully")

    @staticmethod
    def _ensure_uint8(frame: np.ndarray) -> np.ndarray:
        """
        Return ``frame`` as a uint8 array suitable for ``Image.fromarray``.

        uint8 input is passed through unchanged. Float input is assumed to
        be normalized to [0, 1] and is clipped before scaling to 0-255:
        without the clip, out-of-range values would wrap modulo 256 when
        cast to uint8 and corrupt the image.
        """
        if frame.dtype == np.uint8:
            return frame
        return (np.clip(frame, 0.0, 1.0) * 255.0).astype(np.uint8)

    @staticmethod
    def _empty_result(frame: np.ndarray) -> SegmentationResult:
        """Build a zero-instance SegmentationResult matching ``frame``'s size."""
        return SegmentationResult(
            masks=np.zeros((0, frame.shape[0], frame.shape[1]), dtype=bool),
            scores=None,
            boxes=None,
        )

    def predict(self, frame: np.ndarray, text_prompts: Optional[list] = None) -> SegmentationResult:
        """
        Run SAM3 segmentation on a single frame.

        Args:
            frame: Input image (HxWx3 numpy array in RGB). uint8 is used
                as-is; float input is assumed normalized to [0, 1].
            text_prompts: Text prompts guiding segmentation; defaults to
                ["object"] when None or empty.

        Returns:
            SegmentationResult with instance masks stacked as (N, H, W),
            plus per-instance scores and boxes when available. Returns an
            empty (zero-instance) result if post-processing fails.
        """
        pil_image = Image.fromarray(self._ensure_uint8(frame))

        # Fall back to a generic prompt so the model always has text input.
        if not text_prompts:
            text_prompts = ["object"]

        inputs = self.processor(
            images=pil_image, text=text_prompts, return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Post-processing is best-effort: degrade to "no detections"
        # rather than crash the caller on unexpected model output.
        try:
            results = self.processor.post_process_instance_segmentation(
                outputs,
                threshold=self.threshold,
                mask_threshold=self.mask_threshold,
                target_sizes=inputs.get("original_sizes").tolist(),
            )[0]

            masks = results.get("masks", [])
            scores = results.get("scores", None)
            boxes = results.get("boxes", None)

            if len(masks) > 0:
                # Stack per-instance (H, W) masks into one (N, H, W) array.
                masks_array = np.stack([m.cpu().numpy() for m in masks])
            else:
                # No objects detected above threshold.
                masks_array = np.zeros(
                    (0, frame.shape[0], frame.shape[1]), dtype=bool
                )

            return SegmentationResult(
                masks=masks_array,
                scores=scores.cpu().numpy() if scores is not None else None,
                boxes=boxes.cpu().numpy() if boxes is not None else None,
            )
        except Exception:
            logging.exception("SAM3 post-processing failed")
            return self._empty_result(frame)