File size: 4,541 Bytes
b86ba38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
from transformers import SamModel, SamProcessor
from PIL import Image
import numpy as np
import cv2 as cv

# Prefer GPU when available; model weights and processed inputs both move here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Loaded at import time: downloads the SAM ViT-Base checkpoint on first run,
# which takes noticeable time and memory.
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

"""
Segmentor Module that takes in an image and input points to generate segmentation masks.
"""

class Segmentor:
    """Wraps a SAM model/processor pair to produce binary masks from point prompts."""

    def __init__(self, model, processor, device):
        self.model = model
        self.processor = processor
        self.device = device

    def segment(self, image_input, input_points):
        """Run SAM on *image_input* prompted by foreground *input_points*.

        Args:
            image_input: image file path, BGR numpy array (OpenCV), or PIL Image.
            input_points: non-empty sequence of (x, y) pixel coordinates; all
                points are treated as one foreground prompt group.

        Returns:
            (masks, scores): a flat list of HxW uint8 {0, 1} masks at the
            original image resolution, and the matching list of IoU scores.

        Raises:
            ValueError: if input_points is empty or image_input has an
                unsupported type.
        """
        # Fail fast: an empty prompt list would otherwise surface as an
        # obscure error inside the processor.
        if not input_points:
            raise ValueError("input_points must contain at least one (x, y) point")

        if isinstance(image_input, str):
            image = Image.open(image_input).convert("RGB")
        elif isinstance(image_input, np.ndarray):
            # OpenCV uses BGR, PIL uses RGB
            image = Image.fromarray(cv.cvtColor(image_input, cv.COLOR_BGR2RGB))
        elif isinstance(image_input, Image.Image):
            image = image_input.convert("RGB")
        else:
            raise ValueError("image_input must be a path, numpy array, or PIL Image")

        # SamProcessor expects [batch, point_batch, num_points, 2]; here we use
        # batch=1 and a single prompt group, with every point labeled 1 (foreground).
        points = [[[[int(x), int(y)] for (x, y) in input_points]]]
        labels = [[[1] * len(input_points)]]

        inputs = self.processor(
            images=image,
            input_points=points,
            input_labels=labels,
            return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Rescale the low-resolution predicted masks back to the original image size.
        processed = self.processor.post_process_masks(
            masks=outputs.pred_masks,
            reshaped_input_sizes=inputs["reshaped_input_sizes"],
            original_sizes=inputs["original_sizes"]
        )

        # processed is one entry per batch image; batch=1 here.
        masks = processed[0]  # [point_batch, num_masks, H, W] (or similar)
        scores = outputs.iou_scores.cpu().numpy()

        masks_np = masks.cpu().numpy() if hasattr(masks, "cpu") else np.asarray(masks)

        # Flatten (point_batch, num_masks) into parallel flat lists of 2D masks.
        flat_masks = []
        flat_scores = []
        for i, mask_group in enumerate(masks_np):
            score_group = scores[0][i]
            for j, m in enumerate(mask_group):
                m2d = np.squeeze(m)                            # drop singleton dims → HxW
                flat_masks.append((m2d > 0).astype(np.uint8))  # ensure binary 0/1
                flat_scores.append(score_group[j])
        return flat_masks, flat_scores

# Example usage: left-click foreground points on the image, press 'q' to finish.
if __name__ == "__main__":
    segmentor = Segmentor(model, processor, device)
    image_path = "redbull.jpg"

    img = cv.imread(image_path)
    # cv.imread returns None (no exception) on a missing/unreadable file;
    # fail loudly instead of crashing later inside imshow.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # Draw click markers on a copy so the mask previews built from `img`
    # stay free of annotation artifacts.
    display = img.copy()
    input_points = []

    def mouse_callback(event, x, y, flags, param):
        # Record each left-click as a foreground prompt and mark it on screen.
        if event == cv.EVENT_LBUTTONDOWN:
            input_points.append([x, y])
            cv.circle(display, (x, y), 4, (0, 0, 255), -1)
            print(f"Point added: ({x}, {y})")

    cv.namedWindow("Input Image")
    cv.setMouseCallback("Input Image", mouse_callback)

    while True:
        cv.imshow("Input Image", display)
        if cv.waitKey(1) & 0xFF == ord('q'):
            break
    cv.destroyAllWindows()
    cv.waitKey(1)  # give the GUI event loop a tick so the window actually closes

    if len(input_points) == 0:
        print("No input points provided. Exiting.")
    else:
        masks, scores = segmentor.segment(image_path, input_points)

        print(f"Generated {len(masks)} candidate masks.")

        # Display candidates
        for i, (mask, score) in enumerate(zip(masks, scores)):
            masked_preview = cv.bitwise_and(img, img, mask=mask)
            cv.imshow(f"Candidate {i} (Score: {score:.4f})", masked_preview)
            print(f"Candidate {i}: Score {score:.4f}")

        print("Check the open windows for candidate masks.")
        cv.waitKey(100) # Give time for windows to draw

        try:
            selected_idx = int(input("Enter the index of the desired mask: "))
            if 0 <= selected_idx < len(masks):
                selected_mask = masks[selected_idx]
                masked_img = cv.bitwise_and(img, img, mask=selected_mask)
                cv.imwrite("masked_image.png", masked_img)
                print(f"Saved masked_image.png using candidate {selected_idx}")
            else:
                print("Invalid index selected.")
        except ValueError:
            print("Invalid input. Please enter a number.")

        cv.destroyAllWindows()