devsheroubi commited on
Commit
cbdb75f
·
verified ·
1 Parent(s): d9cda46

Upload ./pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pipeline.py +127 -0
pipeline.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import SamModel, SamProcessor
3
+ from PIL import Image
4
+ import numpy as np
5
+ import cv2 as cv
6
+
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+ model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
9
+ processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
10
+
11
+ """
12
+ Segmentor Module that takes in an image and input points to generate segmentation masks.
13
+ """
14
+
15
+ class Segmentor:
16
+ def __init__(self, model, processor, device):
17
+ self.model = model
18
+ self.processor = processor
19
+ self.device = device
20
+
21
+ def segment(self, image_input, input_points):
22
+ if isinstance(image_input, str):
23
+ image = Image.open(image_input).convert("RGB")
24
+ elif isinstance(image_input, np.ndarray):
25
+ # OpenCV uses BGR, PIL uses RGB
26
+ image = Image.fromarray(cv.cvtColor(image_input, cv.COLOR_BGR2RGB))
27
+ elif isinstance(image_input, Image.Image):
28
+ image = image_input.convert("RGB")
29
+ else:
30
+ raise ValueError("image_input must be a path, numpy array, or PIL Image")
31
+
32
+ points = [[[ [int(x), int(y)] for (x, y) in input_points ]]]
33
+ labels = [[[1] * len(input_points)]]
34
+
35
+ inputs = self.processor(
36
+ images=image,
37
+ input_points=points,
38
+ input_labels=labels,
39
+ return_tensors="pt"
40
+ ).to(self.device)
41
+
42
+ with torch.no_grad():
43
+ outputs = self.model(**inputs)
44
+
45
+ pred_masks = outputs.pred_masks
46
+ iou_scores = outputs.iou_scores
47
+
48
+ # Convert to original image size
49
+ processed = self.processor.post_process_masks(
50
+ masks=pred_masks,
51
+ reshaped_input_sizes=inputs["reshaped_input_sizes"],
52
+ original_sizes=inputs["original_sizes"]
53
+ )
54
+
55
+ # processed is a list per batch; we have batch=1
56
+ masks = processed[0] # shape: [point_batch, num_masks, H, W] or similar
57
+ scores = iou_scores.cpu().numpy()
58
+
59
+ # Normalize to a flat list of 2D uint8 masks
60
+ flat_masks = []
61
+ flat_scores = []
62
+ masks_np = masks.cpu().numpy() if hasattr(masks, "cpu") else np.array(masks)
63
+
64
+ for i, mask_group in enumerate(np.array(masks_np)):
65
+ score_group = scores[0][i]
66
+ for j, m in enumerate(np.array(mask_group)):
67
+ m2d = np.squeeze(m) # remove singleton dims → HxW
68
+ m2d = (m2d > 0).astype(np.uint8) # ensure binary 0/1
69
+ flat_masks.append(m2d)
70
+ flat_scores.append(score_group[j])
71
+ return flat_masks, flat_scores
72
+
73
+ # Example usage
74
+ if __name__ == "__main__":
75
+ segmentor = Segmentor(model, processor, device)
76
+ image_path = "redbull.jpg"
77
+
78
+ # get input from user input using cv2
79
+ input_points = []
80
+
81
+ def mouse_callback(event, x, y, flags, param):
82
+ if event == cv.EVENT_LBUTTONDOWN:
83
+ input_points.append([x, y])
84
+ print(f"Point added: ({x}, {y})")
85
+
86
+ cv.namedWindow("Input Image")
87
+ cv.setMouseCallback("Input Image", mouse_callback)
88
+ img = cv.imread(image_path)
89
+
90
+ while True:
91
+ cv.imshow("Input Image", img)
92
+ if cv.waitKey(1) & 0xFF == ord('q'):
93
+ break
94
+ cv.destroyAllWindows()
95
+ cv.waitKey(1)
96
+
97
+ if len(input_points) == 0:
98
+ print("No input points provided. Exiting.")
99
+ else:
100
+ masks, scores = segmentor.segment(image_path, input_points)
101
+
102
+ print(f"Generated {len(masks)} candidate masks.")
103
+
104
+ # Display candidates
105
+ for i, (mask, score) in enumerate(zip(masks, scores)):
106
+ masked_preview = cv.bitwise_and(img, img, mask=mask)
107
+ cv.imshow(f"Candidate {i} (Score: {score:.4f})", masked_preview)
108
+ print(f"Candidate {i}: Score {score:.4f}")
109
+
110
+ print("Check the open windows for candidate masks.")
111
+ cv.waitKey(100) # Give time for windows to draw
112
+
113
+ try:
114
+ selected_idx = int(input("Enter the index of the desired mask: "))
115
+ if 0 <= selected_idx < len(masks):
116
+ selected_mask = masks[selected_idx]
117
+ masked_img = cv.bitwise_and(img, img, mask=selected_mask)
118
+ cv.imwrite("masked_image.png", masked_img)
119
+ print(f"Saved masked_image.png using candidate {selected_idx}")
120
+ else:
121
+ print("Invalid index selected.")
122
+ except ValueError:
123
+ print("Invalid input. Please enter a number.")
124
+
125
+ cv.destroyAllWindows()
126
+
127
+