# Author: Zhen Ye
# perf: Tune batch sizes and queues for A10 GPUs
# Commit: a5f8d15
import logging
from typing import Optional, Sequence
import numpy as np
import torch
from PIL import Image
from transformers import Sam3Model, Sam3Processor
from .base import Segmenter, SegmentationResult
class SAM3Segmenter(Segmenter):
    """
    SAM3 (Segment Anything Model 3) segmenter.

    Performs text-prompted instance segmentation on images (defaults to
    the generic prompt "object" when none is supplied).
    Uses facebook/sam3 model from HuggingFace.
    """

    name = "sam3"
    # Batch capability flags consumed by the calling pipeline.
    supports_batch = True
    max_batch_size = 8

    def __init__(
        self,
        model_id: str = "facebook/sam3",
        device: Optional[str] = None,
        threshold: float = 0.5,
        mask_threshold: float = 0.5,
    ):
        """
        Initialize SAM3 segmenter.

        Args:
            model_id: HuggingFace model ID
            device: Device to run on (cuda/cpu), auto-detected if None
            threshold: Confidence threshold for filtering instances
            mask_threshold: Threshold for binarizing masks

        Raises:
            Exception: re-raised (after logging) if model/processor loading fails.
        """
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.threshold = threshold
        self.mask_threshold = mask_threshold
        logging.info(
            "Loading SAM3 model %s on device %s", model_id, self.device
        )
        try:
            self.model = Sam3Model.from_pretrained(model_id).to(self.device)
            self.processor = Sam3Processor.from_pretrained(model_id)
            self.model.eval()
        except Exception:
            logging.exception("Failed to load SAM3 model")
            raise
        logging.info("SAM3 model loaded successfully")

    @staticmethod
    def _to_pil(frame: np.ndarray) -> Image.Image:
        """
        Convert an RGB HxWx3 numpy frame to a PIL image.

        Non-uint8 frames are assumed to be floats normalized to [0, 1]
        (values outside that range will wrap on the uint8 cast) --
        TODO confirm against callers.
        """
        if frame.dtype == np.uint8:
            return Image.fromarray(frame)
        return Image.fromarray((frame * 255).astype(np.uint8))

    def _parse_single_result(self, results, frame_shape) -> SegmentationResult:
        """
        Convert one post-processed result dict into a SegmentationResult.

        Args:
            results: dict with optional keys "masks" (iterable of (H, W)
                tensors), "scores" and "boxes" (tensors or None).
            frame_shape: original frame shape; used to size the empty
                mask array when nothing was detected.

        Returns:
            SegmentationResult with masks as an (N, H, W) numpy array
            and scores/boxes as numpy arrays or None.
        """
        masks = results.get("masks", [])
        scores = results.get("scores", None)
        boxes = results.get("boxes", None)
        if len(masks) > 0:
            # Stack masks: list of (H, W) -> (N, H, W)
            masks_array = np.stack([m.cpu().numpy() for m in masks])
        else:
            # No objects detected: keep the (0, H, W) shape contract.
            masks_array = np.zeros(
                (0, frame_shape[0], frame_shape[1]), dtype=bool
            )
        scores_array = scores.cpu().numpy() if scores is not None else None
        boxes_array = boxes.cpu().numpy() if boxes is not None else None
        return SegmentationResult(
            masks=masks_array,
            scores=scores_array,
            boxes=boxes_array,
        )

    def _expand_inputs_if_needed(self, inputs):
        """
        Expand vision inputs in place so the image batch matches the text batch.

        Handles:
            1. 1 image, N texts    (expand 1 -> N)
            2. N images, N*M texts (expand N -> N*M)

        To avoid re-encoding the same image once per prompt, the vision
        backbone is run once and its outputs are repeated along dim 0;
        "pixel_values" is then replaced by "vision_embeds" in `inputs`
        (the two are mutually exclusive for the model call).
        """
        pixel_values = inputs.get("pixel_values")
        input_ids = inputs.get("input_ids")
        if pixel_values is None or input_ids is None:
            return
        img_batch = pixel_values.shape[0]
        text_batch = input_ids.shape[0]
        # 0 acts as the "no expansion needed" sentinel.
        expansion_factor = 0
        if img_batch == 1 and text_batch > 1:
            expansion_factor = text_batch
        elif img_batch > 1 and text_batch > img_batch and text_batch % img_batch == 0:
            expansion_factor = text_batch // img_batch
        if expansion_factor == 0:
            return
        logging.debug(
            "Expanding SAM3 vision inputs from %d to %d (factor %d) using embeddings reuse.",
            img_batch, text_batch, expansion_factor,
        )
        # 1. Compute vision embeddings once for the original images.
        with torch.no_grad():
            vision_outputs = self.model.get_vision_features(
                pixel_values=pixel_values
            )
        # 2. Expand every batched tensor in the output container along dim 0.
        for key in list(vision_outputs.keys()):
            value = getattr(vision_outputs, key, None)
            if value is None:
                # ModelOutput containers also support item access; fall back.
                try:
                    value = vision_outputs[key]
                except (KeyError, IndexError, TypeError):
                    continue
            new_value = None
            if isinstance(value, torch.Tensor):
                # Only expand the batch dimension (dim 0).
                if value.shape[0] == img_batch:
                    new_value = value.repeat_interleave(expansion_factor, dim=0)
            elif isinstance(value, (list, tuple)):
                expanded = []
                any_expanded = False
                for v in value:
                    if isinstance(v, torch.Tensor) and v.shape[0] == img_batch:
                        expanded.append(v.repeat_interleave(expansion_factor, dim=0))
                        any_expanded = True
                    else:
                        expanded.append(v)
                if any_expanded:
                    # Preserve the container type (list vs tuple).
                    new_value = type(value)(expanded)
            if new_value is not None:
                # Update via item access where supported ...
                try:
                    vision_outputs[key] = new_value
                except (KeyError, TypeError):
                    pass
                # ... and via the attribute if it exists.
                if hasattr(vision_outputs, key):
                    setattr(vision_outputs, key, new_value)
        # 3. Swap pixel_values for the precomputed embeddings.
        inputs["vision_embeds"] = vision_outputs
        del inputs["pixel_values"]  # Mutually exclusive with vision_embeds
        # 4. Expand per-image metadata to match the new batch size.
        for meta_key in ("original_sizes", "reshape_input_sizes"):
            if meta_key in inputs and inputs[meta_key].shape[0] == img_batch:
                inputs[meta_key] = inputs[meta_key].repeat_interleave(
                    expansion_factor, dim=0
                )

    def predict(self, frame: np.ndarray, text_prompts: Optional[list] = None) -> SegmentationResult:
        """
        Run SAM3 segmentation on a frame.

        Args:
            frame: Input image (HxWx3 numpy array in RGB)
            text_prompts: List of text prompts for segmentation;
                defaults to ["object"] when empty or None.

        Returns:
            SegmentationResult with instance masks (empty result on
            post-processing failure).

        Raises:
            RuntimeError: re-raised from model inference after logging
                diagnostic input shapes.
        """
        pil_image = self._to_pil(frame)
        if not text_prompts:
            text_prompts = ["object"]
        # Process image with text prompts
        inputs = self.processor(
            images=pil_image, text=text_prompts, return_tensors="pt"
        ).to(self.device)
        # Reuse vision embeddings when one image is paired with many prompts.
        self._expand_inputs_if_needed(inputs)
        try:
            if "pixel_values" in inputs:
                logging.debug(
                    "SAM3 Input pixel_values shape: %s",
                    inputs["pixel_values"].shape,
                )
            with torch.no_grad():
                outputs = self.model(**inputs)
        except RuntimeError:
            logging.exception("RuntimeError during SAM3 inference")
            logging.error("Input keys: %s", inputs.keys())
            if "pixel_values" in inputs:
                logging.error(
                    "Pixel values shape: %s", inputs["pixel_values"].shape
                )
            # Re-raise to let user know
            raise
        # Post-process to get instance masks
        try:
            results = self.processor.post_process_instance_segmentation(
                outputs,
                threshold=self.threshold,
                mask_threshold=self.mask_threshold,
                target_sizes=inputs.get("original_sizes").tolist(),
            )[0]
            return self._parse_single_result(results, frame.shape)
        except Exception:
            logging.exception("SAM3 post-processing failed")
            # Best-effort: return an empty result rather than crash the caller.
            return SegmentationResult(
                masks=np.zeros((0, frame.shape[0], frame.shape[1]), dtype=bool),
                scores=None,
                boxes=None,
            )

    def predict_batch(self, frames: Sequence[np.ndarray], text_prompts: Optional[list] = None) -> Sequence[SegmentationResult]:
        """
        Run SAM3 segmentation on a batch of frames with shared prompts.

        Every prompt is applied to every frame. Prompts are flattened as
        [img1_p1, img1_p2, img2_p1, img2_p2, ...] to match the
        processor's pairing convention.

        Args:
            frames: sequence of HxWx3 RGB numpy arrays.
            text_prompts: prompts applied to each frame; defaults to
                ["object"] when empty or None.

        Returns:
            One SegmentationResult per frame (empty results on
            post-processing failure).
        """
        pil_images = [self._to_pil(f) for f in frames]
        prompts = text_prompts or ["object"]
        # Flatten prompts for all images: [img1_p1, img1_p2, img2_p1, img2_p2, ...]
        flattened_prompts = []
        for _ in frames:
            flattened_prompts.extend(prompts)
        inputs = self.processor(
            images=pil_images, text=flattened_prompts, return_tensors="pt"
        ).to(self.device)
        # Reuse vision embeddings so each image is encoded only once.
        self._expand_inputs_if_needed(inputs)
        with torch.no_grad():
            outputs = self.model(**inputs)
        try:
            results_list = self.processor.post_process_instance_segmentation(
                outputs,
                threshold=self.threshold,
                mask_threshold=self.mask_threshold,
                target_sizes=inputs.get("original_sizes").tolist(),
            )
            return [
                self._parse_single_result(r, f.shape)
                for r, f in zip(results_list, frames)
            ]
        except Exception:
            logging.exception("SAM3 batch post-processing failed")
            # Best-effort: one empty result per frame rather than a crash.
            return [
                SegmentationResult(
                    masks=np.zeros((0, f.shape[0], f.shape[1]), dtype=bool),
                    scores=None,
                    boxes=None,
                )
                for f in frames
            ]