File size: 3,566 Bytes
1834bc0
 
 
 
 
ca079c9
 
1834bc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca079c9
 
 
 
 
 
 
 
 
 
 
 
 
1834bc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108

from __future__ import annotations
import numpy as np
import torch
from PIL import Image
import os
from transformers import Sam3Processor, Sam3Model

from segmenters import BaseSegmenter


class SAM3Segmenter(BaseSegmenter):
    """Foreground segmenter backed by SAM3 and a text prompt.

    The object category (e.g. ``"metal_nut"``) is used as the text prompt.
    Every detected instance that passes the score threshold is merged into a
    single boolean foreground mask. When nothing usable is detected, the whole
    image is treated as foreground.
    """

    def __init__(
        self,
        text_prompt: str,
        model_name: str = "facebook/sam3",
        device: str = "cuda",
        score_threshold: float = 0.5,
        mask_threshold: float = 0.5,
    ):
        """Load the SAM3 model and processor.

        Args:
            text_prompt: Name of the object to segment. Underscores are
                replaced by spaces, so ``"metal_nut"`` becomes ``"metal nut"``.
            model_name: Hugging Face repo id of the SAM3 checkpoint.
            device: Requested torch device string. A ``"cuda..."`` device is
                used only when CUDA is available; any other value, or CUDA
                being unavailable, falls back to the CPU.
            score_threshold: Minimum detection score for an instance to be kept.
            mask_threshold: Threshold applied to per-pixel mask values.
        """
        super().__init__()

        # Only CUDA devices are honored; everything else runs on the CPU.
        if torch.cuda.is_available() and device.startswith("cuda"):
            self.device = torch.device(device)
        else:
            self.device = torch.device("cpu")

        # Category names use underscores ("metal_nut"); the text prompt should
        # read as natural language ("metal nut").
        self.text_prompt = text_prompt.replace("_", " ")
        self.score_threshold = score_threshold
        self.mask_threshold = mask_threshold

        # Optional access token from the environment (None when unset).
        token = os.getenv("HF_TOKEN")
        # NOTE(review): trust_remote_code allows the Hub repo to execute code.
        # Sam3Model/Sam3Processor are imported directly from transformers, so
        # this flag may be unnecessary — confirm and drop it if so.
        self.model = Sam3Model.from_pretrained(
            model_name,
            token=token,
            trust_remote_code=True,
        ).to(self.device)
        self.model.eval()
        self.processor = Sam3Processor.from_pretrained(
            model_name,
            token=token,
            trust_remote_code=True,
        )

    @staticmethod
    def _to_pil(image):
        """Return *image* as an RGB PIL image (non-PIL, non-array inputs pass through)."""
        if isinstance(image, np.ndarray):
            if image.dtype != np.uint8 and image.dtype.kind in "iuf":
                # Clip before casting: a bare astype(uint8) wraps out-of-range
                # integers and is undefined for out-of-range floats.
                image = np.clip(image, 0, 255)
            return Image.fromarray(image.astype(np.uint8)).convert("RGB")
        if isinstance(image, Image.Image):
            # Match the ndarray branch: the model always receives RGB.
            return image.convert("RGB")
        return image

    @staticmethod
    def _full_mask(pil_image) -> np.ndarray:
        """Return an all-True (H, W) mask matching the PIL image size."""
        width, height = pil_image.size
        return np.ones((height, width), dtype=bool)

    def get_object_mask(self, image: np.ndarray) -> np.ndarray:
        """Run SAM3 on *image* and return a single foreground mask.

        Args:
            image: A numpy image array (any dtype; non-uint8 values are
                clipped to [0, 255]) or a PIL image.

        Returns:
            Boolean array of shape (H, W): the union of all instance masks
            whose score is at least ``score_threshold``. If no instance
            survives, or the surviving masks contain no foreground pixel, an
            all-True mask is returned so downstream code keeps the whole image.
        """
        pil_image = self._to_pil(image)

        inputs = self.processor(
            images=pil_image,
            text=self.text_prompt,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Resize masks back to the original image resolution.
        original_sizes = inputs.get("original_sizes")
        if original_sizes is not None:
            target_sizes = original_sizes.tolist()
        else:
            width, height = pil_image.size
            target_sizes = [(height, width)]

        results = self.processor.post_process_instance_segmentation(
            outputs,
            threshold=self.score_threshold,
            mask_threshold=self.mask_threshold,
            target_sizes=target_sizes,
        )[0]

        masks = results.get("masks")
        scores = results.get("scores")

        # No detections at all: keep the whole image.
        if masks is None or masks.numel() == 0:
            return self._full_mask(pil_image)

        # Defensive re-filter by score (post-processing was given the same threshold).
        if scores is not None:
            keep = scores >= self.score_threshold
            if not keep.any():
                return self._full_mask(pil_image)
            masks = masks[keep]

        # Masks that are already boolean are used as-is; re-thresholding them
        # against a float would erase them when mask_threshold >= 1.
        if masks.dtype != torch.bool:
            masks = masks > self.mask_threshold
        combined = masks.any(dim=0)

        # Detections whose masks contain no foreground pixel count as a
        # failure too, consistent with the fallbacks above.
        if not combined.any():
            return self._full_mask(pil_image)

        return combined.cpu().numpy().astype(bool)