File size: 4,331 Bytes
eab4d9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""

models/segmenter.py

--------------------

SAM (Segment Anything Model) wrapper.



Given a scene image and one or more bounding boxes (from the detector),

produces precise pixel-level masks for each detected object.

"""

import os
import sys
from typing import List, Tuple, Optional

import numpy as np
import torch
from PIL import Image

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import DEVICE, SAM_CHECKPOINT, SAM_MODEL_TYPE, MASK_DILATION_PX
from image_utils import load_image, save_image, show_mask, show_box, dilate_mask_with_sam_prediction, dilate_mask, combine_masks


class SAMSegmenter:
    """
    Wraps SAM to convert box or point prompts into fine-grained masks.

    The underlying ``SamPredictor`` is created lazily on first use, so
    constructing a ``SAMSegmenter`` does not load any weights.
    """

    def __init__(self) -> None:
        # Filled in by _load() the first time a segment_* method needs it.
        self._predictor = None

    def _load(self) -> None:
        """
        Build the SAM predictor if needed and place its model on ``DEVICE``.

        Raises:
            RuntimeError: If the ``segment_anything`` package cannot be imported.
        """
        if self._predictor is not None:
            # Predictor already exists; only (re)assert the device placement.
            self._predictor.model.to(DEVICE)
            return

        print("  Loading SAM (this may take a moment) ...")
        # Only the import is guarded: an ImportError raised while reading the
        # checkpoint should not be reported as "package not installed".
        try:
            from segment_anything import sam_model_registry, SamPredictor
        except ImportError as e:
            raise RuntimeError(
                "segment-anything is not installed.\n"
                "Run: pip install git+https://github.com/facebookresearch/segment-anything.git\n"
                f"Original error: {e}"
            ) from e

        sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
        sam = sam.to(DEVICE)
        self._predictor = SamPredictor(sam)

    @staticmethod
    def _to_rgb_array(image_pil) -> np.ndarray:
        """Return the image as a NumPy array, converting non-RGB PIL images to RGB."""
        # SamPredictor.set_image expects a 3-channel image; RGBA / L / P
        # inputs would otherwise reach it with the wrong channel layout.
        if isinstance(image_pil, Image.Image) and image_pil.mode != "RGB":
            image_pil = image_pil.convert("RGB")
        return np.array(image_pil)

    @staticmethod
    def _best_mask(masks: np.ndarray, scores: np.ndarray) -> np.ndarray:
        """Pick the highest-scoring candidate and scale it to a 0/255 uint8 mask."""
        best_idx = scores.argmax()
        return masks[best_idx].astype(np.uint8) * 255

    def segment_boxes(
        self,
        image_pil: Image.Image,
        boxes: List[Tuple[int, int, int, int]],
        dilation_px: int = MASK_DILATION_PX,
    ) -> np.ndarray:
        """
        For each box, run SAM and return the combined binary mask (HW, uint8).

        Args:
            image_pil:   The scene image. Non-RGB PIL images are converted
                         to RGB before being handed to SAM.
            boxes:       List of (x1, y1, x2, y2) in absolute pixels.
            dilation_px: Passed to ``dilate_mask`` for the final mask when
                         greater than zero (covers edges).

        Returns:
            Combined mask (255 = object, 0 = background). An all-zero mask
            of the image's height and width is returned when ``boxes`` is
            empty; in that case the model is neither loaded nor run.
        """
        img_np = self._to_rgb_array(image_pil)
        h, w = img_np.shape[:2]

        boxes = list(boxes)
        if not boxes:
            # Nothing to segment: skip model loading and the image embedding.
            return np.zeros((h, w), dtype=np.uint8)

        self._load()
        self._predictor.set_image(img_np)

        individual_masks = []
        for x1, y1, x2, y2 in boxes:
            box_np = np.array([[x1, y1, x2, y2]], dtype=np.float32)
            # With multimask_output=True SAM proposes several candidate
            # masks per prompt, each with a score; keep the best-scoring one.
            masks, scores, _ = self._predictor.predict(
                box=box_np,
                multimask_output=True,
            )
            individual_masks.append(self._best_mask(masks, scores))

        combined = combine_masks(individual_masks)
        if dilation_px > 0:
            combined = dilate_mask(combined, dilation_px)

        pct = 100 * (combined > 0).sum() / (h * w)
        print(f"    SAM mask covers {pct:.1f}% of the image")
        return combined

    def segment_points(
        self,
        image_pil: Image.Image,
        points: List[Tuple[int, int]],
        point_labels: Optional[List[int]] = None,
        dilation_px: int = MASK_DILATION_PX,
    ) -> np.ndarray:
        """
        Segment using point prompts (1 = foreground, 0 = background).

        Args:
            image_pil:    The scene image. Non-RGB PIL images are converted
                          to RGB before being handed to SAM.
            points:       List of (x, y) prompt coordinates.
            point_labels: One label per point. Falls back to all-foreground
                          when None or empty.
            dilation_px:  Passed to ``dilate_mask`` for the final mask when
                          greater than zero.

        Returns:
            Mask of the best-scoring candidate (255 = object,
            0 = background). An all-zero mask of the image's height and
            width is returned when ``points`` is empty.

        Raises:
            ValueError: If ``point_labels`` is given but its length differs
                from the number of points.
        """
        img_np = self._to_rgb_array(image_pil)
        h, w = img_np.shape[:2]

        points = list(points)
        if not points:
            # Mirror segment_boxes: no prompts means an empty mask.
            return np.zeros((h, w), dtype=np.uint8)

        # Explicit None/len check so NumPy arrays of labels are accepted too.
        if point_labels is None or len(point_labels) == 0:
            point_labels = [1] * len(points)
        elif len(point_labels) != len(points):
            raise ValueError(
                f"point_labels has {len(point_labels)} entries "
                f"but {len(points)} points were given"
            )

        self._load()
        self._predictor.set_image(img_np)

        pts_np = np.array(points, dtype=np.float32)
        labels_np = np.array(point_labels, dtype=np.int32)

        masks, scores, _ = self._predictor.predict(
            point_coords=pts_np,
            point_labels=labels_np,
            multimask_output=True,
        )
        best_mask = self._best_mask(masks, scores)

        if dilation_px > 0:
            best_mask = dilate_mask(best_mask, dilation_px)

        return best_mask