File size: 2,622 Bytes
aa1c1e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import cv2


class SamWrapper:
    """
    Wrapper for Segment Anything Model (SAM).
    Handles both automatic mask generation and guided segmentation via bounding boxes.
    """

    # Keyword arguments passed to SamAutomaticMaskGenerator for the automatic
    # (box-free) path. Individual entries can be overridden per instance via the
    # ``generator_kwargs`` constructor argument; this dict itself is never mutated.
    DEFAULT_GENERATOR_KWARGS = {
        "points_per_side": 12,
        "pred_iou_thresh": 0.92,
        "stability_score_thresh": 0.95,
        "min_mask_region_area": 1500,
        "box_nms_thresh": 0.3,
    }

    def __init__(self, model_type="vit_b", checkpoint_path=None, device=None,
                 generator_kwargs=None):
        """
        Initialize the SAM model.

        :param model_type: Key into ``sam_model_registry`` selecting the SAM backbone
            (e.g., 'vit_b', 'vit_l', 'vit_h')
        :param checkpoint_path: Path to the .pth checkpoint file, passed to the
            registry builder as ``checkpoint``
        :param device: 'cuda' or 'cpu' (any value accepted by ``Module.to``);
            if None, uses 'cuda' when available and 'cpu' otherwise
        :param generator_kwargs: Optional dict of SamAutomaticMaskGenerator keyword
            arguments that override ``DEFAULT_GENERATOR_KWARGS``
        """
        # Previously the argument was overwritten with "cpu", so an explicit
        # device was ignored and auto-detection never happened.
        self.device = self._resolve_device(device)
        self.model = sam_model_registry[model_type](checkpoint=checkpoint_path)
        self.model.to(self.device)

        # Merge into a fresh dict so the class-level defaults stay untouched.
        generator_settings = {**self.DEFAULT_GENERATOR_KWARGS, **(generator_kwargs or {})}
        self.automatic_generator = SamAutomaticMaskGenerator(
            model=self.model,
            **generator_settings
        )
        self.predictor = SamPredictor(self.model)

    @staticmethod
    def _resolve_device(device):
        """Return *device* unchanged, or 'cuda'/'cpu' by availability when it is None."""
        if device is not None:
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

    def generate_masks(self, image, boxes=None):
        """
        Generate segmentation masks for the given image.

        :param image: Input image as NumPy array (BGR)
        :param boxes: Optional bounding boxes [x1, y1, x2, y2] in ``image`` pixel
            coordinates (list, NumPy array or tensor). If None, the automatic mask
            generator segments the whole image.
        :return: List of binary masks (NumPy arrays): one per box when boxes are
            given, an empty list when ``boxes`` is empty
        """
        # No prompts means nothing to segment: skip the costly image embedding and
        # avoid sending a zero-length batch through the mask decoder.
        if boxes is not None and len(boxes) == 0:
            return []

        # Input is BGR (OpenCV convention); the SAM calls below are given RGB.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if boxes is None:
            masks = self.automatic_generator.generate(image_rgb)
            return [mask['segmentation'] for mask in masks]

        # Set the image once; its embedding is shared by every box prompt below.
        self.predictor.set_image(image_rgb)

        # Convert boxes to a tensor on the model's device and map them from
        # original-image coordinates into the predictor's input frame.
        # as_tensor avoids a redundant copy (and a warning) for tensor input.
        transformed_boxes = self.predictor.transform.apply_boxes_torch(
            torch.as_tensor(boxes, dtype=torch.float32, device=self.device),
            image.shape[:2]
        )

        # Predict masks for all boxes at once (one mask per box).
        masks, _, _ = self.predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False
        )

        # Each per-box entry holds a single mask channel (multimask_output=False).
        return [m[0].cpu().numpy() for m in masks]

    def predict_with_box(self, image, box):
        """
        Predict a single segmentation mask for the given box.

        :param image: Input image (BGR)
        :param box: One bounding box [x1, y1, x2, y2]
        :return: Binary mask (NumPy array), or None if no mask was produced
        """
        masks = self.generate_masks(image, boxes=[box])
        return masks[0] if masks else None