File size: 3,599 Bytes
1834bc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from __future__ import annotations
import numpy as np
import torch
from PIL import Image
from transformers import Sam2Processor, Sam2Model

from segmenters import BaseSegmenter


class SAM2Segmenter(BaseSegmenter):
    """SAM2 wrapper around Hugging Face ``Sam2Model``.

    Segments (approximately) all objects in the image by prompting the model
    with a single full-image bounding box, then returns one boolean (H, W)
    mask formed by the union of all predicted masks.
    """

    def __init__(
        self,
        text_prompt: str | None = None,
        model_name: str = "facebook/sam2.1-hiera-large",
        device: str = "cuda",
        mask_threshold: float = 0.5,
    ) -> None:
        """
        Args:
            text_prompt: kept for compatibility with SAM3Segmenter, but unused.
            model_name: HF repo id for the SAM2 model, e.g. "facebook/sam2.1-hiera-large".
            device: "cuda" or "cpu".
            mask_threshold: pixel threshold for masks (after SAM2 post-processing).
        """
        super().__init__()

        # Fall back to CPU when CUDA is requested but not available.
        if torch.cuda.is_available() and device.startswith("cuda"):
            self.device = torch.device(device)
        else:
            self.device = torch.device("cpu")

        # BUG FIX: mask_threshold was accepted but never stored, so
        # get_object_mask() crashed with AttributeError on self.mask_threshold.
        self.mask_threshold = mask_threshold
        # Unused here; stored for parity with SAM3Segmenter's interface.
        self.text_prompt = text_prompt

        # Load SAM2 model + processor.
        self.model = Sam2Model.from_pretrained(model_name).to(self.device)
        self.model.eval()
        self.processor = Sam2Processor.from_pretrained(model_name)

    def get_object_mask(self, image: np.ndarray) -> np.ndarray:
        """Run SAM2 and return a single foreground mask.

        Steps:
            - Convert the input to a PIL RGB image if it is a numpy array.
            - Prompt SAM2 with one bounding box covering the whole image.
            - Post-process predicted masks back to the original resolution.
            - Threshold and union all masks into one boolean (H, W) array.

        Args:
            image: (H, W, 3) uint8-compatible numpy array (or a PIL image).

        Returns:
            Boolean numpy array of shape (H, W); True marks foreground.
        """
        # Ensure PIL image.
        if isinstance(image, np.ndarray):
            pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")
        else:
            pil_image = image

        W, H = pil_image.size  # PIL reports (width, height)

        # Full-image bounding box prompt: [x_min, y_min, x_max, y_max],
        # nested as (batch, boxes-per-image, 4).
        input_boxes = [[[0, 0, W, H]]]

        inputs = self.processor(
            images=pil_image,
            input_boxes=input_boxes,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            # multimask_output=False -> one mask per box.
            outputs = self.model(**inputs, multimask_output=False)

        # Post-process masks to the original image resolution.
        masks = self.processor.post_process_masks(
            outputs.pred_masks.cpu(),  # (B, num_masks, H', W')
            inputs["original_sizes"],
        )[0]

        # Depending on the transformers version the per-image result is
        # (num_masks, H, W) or (1, num_masks, H, W); normalize to 3-D.
        if masks.ndim == 4:
            masks = masks[0]

        if masks.ndim == 2:
            # Single mask already at (H, W).
            return (masks > self.mask_threshold).numpy().astype(bool)

        if masks.ndim != 3:
            # Failsafe: if the shape is unexpected, keep everything.
            return np.ones((H, W), dtype=bool)

        # (num_masks, H, W): threshold each mask, then union across masks.
        masks_bin = masks > self.mask_threshold
        combined = masks_bin.any(dim=0)  # (H, W)
        return combined.numpy().astype(bool)