Spaces:
Running
on
Zero
Running
on
Zero
| from __future__ import annotations | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| import os | |
| from transformers import Sam3Processor, Sam3Model | |
| from segmenters import BaseSegmenter | |
class SAM3Segmenter(BaseSegmenter):
    """Text-prompted foreground segmenter backed by a Hugging Face SAM3 model.

    The object type to segment is fixed at construction time as a free-text
    prompt. ``get_object_mask`` then returns, for one image, a single boolean
    mask that is the union of every instance the model keeps.
    """

    def __init__(
        self,
        text_prompt: str,
        model_name: str = "facebook/sam3",
        device: str = "cuda",
        score_threshold: float = 0.5,
        mask_threshold: float = 0.5,
    ):
        """Load the SAM3 model and its processor.

        Args:
            text_prompt: Description of what to segment. Underscores are
                replaced by spaces, so ``"metal_nut"`` is prompted as
                ``"metal nut"``.
            model_name: Hugging Face repo id of the SAM3 checkpoint.
            device: Requested torch device (e.g. ``"cuda"``, ``"cuda:1"``,
                ``"cpu"``). A CUDA device is only used when CUDA is
                available; every other case runs on CPU.
            score_threshold: Minimum detection score for an instance to be
                kept.
            mask_threshold: Pixel threshold used to binarise masks.
        """
        super().__init__()

        use_cuda = device.startswith("cuda") and torch.cuda.is_available()
        self.device = torch.device(device if use_cuda else "cpu")

        # Dataset category names often use underscores ("metal_nut"); the
        # model is prompted with natural language ("metal nut").
        self.text_prompt = text_prompt.replace("_", " ")
        self.score_threshold = score_threshold
        self.mask_threshold = mask_threshold

        # The access token is read from the HF_TOKEN environment variable
        # (None when the variable is unset).
        token = os.getenv("HF_TOKEN")

        # NOTE(review): trust_remote_code=True permits execution of code
        # shipped in the model repo. Sam3Model / Sam3Processor are imported
        # from transformers itself, so this flag is probably unnecessary;
        # confirm and drop it, especially if model_name can be user-supplied.
        self.model = Sam3Model.from_pretrained(
            model_name,
            token=token,
            trust_remote_code=True,
        ).to(self.device)
        self.model.eval()
        self.processor = Sam3Processor.from_pretrained(
            model_name,
            token=token,
            trust_remote_code=True,
        )

    @staticmethod
    def _to_pil(image):
        """Convert an ndarray or PIL image to an RGB PIL image."""
        if isinstance(image, np.ndarray):
            # NOTE(review): plain uint8 cast, so values are expected to be in
            # 0-255 already; a float image in [0, 1] would come out black.
            # Confirm what callers pass.
            return Image.fromarray(image.astype(np.uint8)).convert("RGB")
        if isinstance(image, Image.Image):
            # Grayscale / RGBA / palette images are normalised to RGB, the
            # same as the ndarray path.
            return image.convert("RGB")
        # Any other input type is handed to the processor unchanged.
        return image

    @staticmethod
    def _keep_everything(pil_image) -> np.ndarray:
        """Return an all-True ``(H, W)`` mask matching ``pil_image``."""
        width, height = pil_image.size  # PIL reports (width, height)
        return np.ones((height, width), dtype=bool)

    def get_object_mask(self, image: np.ndarray) -> np.ndarray:
        """Run SAM3 on one image and return a single foreground mask.

        Args:
            image: Image as a NumPy array (cast to ``uint8`` and converted
                to RGB) or as a PIL image (converted to RGB).

        Returns:
            Boolean NumPy array: the union, over the first axis of the
            post-processed masks, of every instance whose score is at least
            ``score_threshold``. If the model returns no masks, or no
            instance passes the score threshold, an all-True ``(H, W)``
            array is returned so that nothing is masked out.
        """
        pil_image = self._to_pil(image)

        # Pre-process the image together with the text prompt.
        inputs = self.processor(
            images=pil_image,
            text=self.text_prompt,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Masks are resized back to the original image size. Fall back to the
        # PIL size if the processor did not report "original_sizes".
        original_sizes = inputs.get("original_sizes")
        if original_sizes is not None:
            target_sizes = original_sizes.tolist()
        else:
            width, height = pil_image.size
            target_sizes = [[height, width]]

        results = self.processor.post_process_instance_segmentation(
            outputs,
            threshold=self.score_threshold,
            mask_threshold=self.mask_threshold,
            target_sizes=target_sizes,
        )[0]
        masks = results.get("masks", None)
        scores = results.get("scores", None)

        # If SAM3 finds nothing at all, keep the whole image.
        if masks is None or masks.numel() == 0:
            return self._keep_everything(pil_image)

        if scores is not None:
            keep = scores >= self.score_threshold
            if not bool(keep.any()):
                return self._keep_everything(pil_image)
            masks = masks[keep]

        # mask_threshold was already passed to the post-processing step, so
        # integer/bool masks are final: re-comparing them with the threshold
        # would wipe the mask for thresholds >= 1 (or fill it for negative
        # ones). Only floating-point masks are thresholded here.
        if torch.is_floating_point(masks):
            masks_bin = masks > self.mask_threshold
        else:
            masks_bin = masks != 0

        # Union of all kept instances -> one foreground mask.
        combined = masks_bin.any(dim=0)
        return combined.cpu().numpy().astype(bool)