# Uploaded by V4ldeLund (Hugging Face Hub).
# Commit ca079c9 (verified): "Update code/comments from local workspace"
from __future__ import annotations
import numpy as np
import torch
from PIL import Image
import os
from transformers import Sam3Processor, Sam3Model
from segmenters import BaseSegmenter
class SAM3Segmenter(BaseSegmenter):
    """
    SAM3 wrapper that segments objects described by a text prompt.

    The model and processor are loaded once in ``__init__``.
    ``get_object_mask`` then returns a single boolean foreground mask per
    image: the union of every detected instance that passes the score
    threshold. When nothing usable is detected, the whole image is treated
    as foreground.
    """

    def __init__(
        self,
        text_prompt: str,
        model_name: str = "facebook/sam3",
        device: str = "cuda",
        score_threshold: float = 0.5,
        mask_threshold: float = 0.5,
    ):
        """
        Args:
            text_prompt: description of what to segment. Underscores are
                replaced by spaces (e.g. "metal_nut" -> "metal nut").
            model_name: HF repo id for the SAM3 model.
            device: requested device. A string starting with "cuda" is
                honoured only when CUDA is available; every other value
                (and "cuda" without CUDA) resolves to the CPU.
            score_threshold: min detection score to keep an instance.
            mask_threshold: pixel threshold for masks.
        """
        super().__init__()

        # Silently fall back to CPU instead of failing when CUDA is missing.
        use_cuda = torch.cuda.is_available() and device.startswith("cuda")
        self.device = torch.device(device if use_cuda else "cpu")

        # Dataset category names use underscores; the text encoder should
        # see natural language instead.
        self.text_prompt = text_prompt.replace("_", " ")
        self.score_threshold = score_threshold
        self.mask_threshold = mask_threshold

        # HF_TOKEN is optional; when unset, None is passed to from_pretrained.
        # NOTE(review): trust_remote_code=True permits running code shipped
        # with the model repo -- confirm it is actually required here.
        token = os.getenv("HF_TOKEN")
        self.model = Sam3Model.from_pretrained(
            model_name,
            token=token,
            trust_remote_code=True,
        ).to(self.device)
        self.model.eval()
        self.processor = Sam3Processor.from_pretrained(
            model_name,
            token=token,
            trust_remote_code=True,
        )

    @staticmethod
    def _as_rgb_image(image):
        """Normalize array / PIL inputs to an RGB PIL image."""
        if isinstance(image, np.ndarray):
            # NOTE: astype(np.uint8) truncates, so float images scaled to
            # [0, 1] must be rescaled to [0, 255] by the caller beforehand.
            return Image.fromarray(image.astype(np.uint8)).convert("RGB")
        if isinstance(image, Image.Image):
            # Same normalization as the ndarray path (handles L / RGBA / P).
            return image.convert("RGB")
        # Anything else is handed to the processor unchanged.
        return image

    @staticmethod
    def _full_foreground_mask(pil_image) -> np.ndarray:
        """Return an all-True (H, W) mask, i.e. keep the whole image."""
        # PIL reports size as (width, height).
        width, height = pil_image.size[0], pil_image.size[1]
        return np.ones((height, width), dtype=bool)

    def get_object_mask(self, image: np.ndarray) -> np.ndarray:
        """
        Run SAM3 on ``image`` and return a single foreground mask.

        Args:
            image: input image as a numpy array (cast to uint8) or a PIL
                image; both are converted to RGB before inference.

        Returns:
            Boolean array that is True wherever any kept instance mask is
            set. If no mask is returned, or no instance reaches
            ``score_threshold``, an all-True array of shape (H, W) is
            returned so that downstream code keeps the whole image.
        """
        pil_image = self._as_rgb_image(image)

        # The processor encodes the image together with the text prompt.
        inputs = self.processor(
            images=pil_image,
            text=self.text_prompt,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Masks are resized back to the input resolution. Prefer the sizes
        # recorded by the processor; fall back to the image itself if the
        # processor did not provide them.
        original_sizes = inputs.get("original_sizes")
        if original_sizes is not None:
            target_sizes = original_sizes.tolist()
        else:
            target_sizes = [[pil_image.size[1], pil_image.size[0]]]

        results = self.processor.post_process_instance_segmentation(
            outputs,
            threshold=self.score_threshold,
            mask_threshold=self.mask_threshold,
            target_sizes=target_sizes,
        )[0]
        masks = results.get("masks", None)
        scores = results.get("scores", None)

        # If SAM completely fails we keep everything.
        if masks is None or masks.numel() == 0:
            return self._full_foreground_mask(pil_image)

        if scores is not None:
            keep = scores >= self.score_threshold
            if not bool(keep.any()):
                return self._full_foreground_mask(pil_image)
            masks = masks[keep]

        # NOTE(review): both thresholds were already passed to the
        # post-processor above; re-applying them is presumably redundant
        # but is kept to preserve the existing behaviour.
        masks_bin = masks > self.mask_threshold

        # Union over the instance dimension -> one mask for the whole image.
        combined = masks_bin.any(dim=0)
        return combined.cpu().numpy().astype(bool)