|
|
import torch |
|
|
from transformers import SamModel, SamProcessor |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import cv2 as cv |
|
|
|
|
|
# Prefer the CUDA device when torch reports one is available; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the SAM model from the "facebook/sam-vit-base" checkpoint and move it to
# the selected device.
model = SamModel.from_pretrained("facebook/sam-vit-base")
model = model.to(device)

# Matching processor for the same checkpoint; it builds the model inputs and
# post-processes the predicted masks.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
|
|
|
|
|
""" |
|
|
Segmentor module: takes an image and a list of input points and generates candidate segmentation masks.
|
|
""" |
|
|
|
|
|
class Segmentor:
    """Point-prompted segmentation wrapper around a SAM model and its processor.

    The model, processor and device are supplied by the caller. This class only
    prepares the inputs, runs inference without gradients, and converts the
    results into plain NumPy masks with their scores.
    """

    def __init__(self, model, processor, device):
        """Store the collaborators used by :meth:`segment`.

        Args:
            model: Callable model, invoked as ``model(**inputs)``; its output
                must expose ``pred_masks`` and ``iou_scores``.
            processor: Object that builds the model inputs when called and
                provides ``post_process_masks``.
            device: Device the processed inputs are moved to before inference.
        """
        self.model = model
        self.processor = processor
        self.device = device

    @staticmethod
    def _load_rgb_image(image_input):
        """Return ``image_input`` (path, NumPy array or PIL image) as an RGB PIL image."""
        if isinstance(image_input, str):
            # Close the file handle deterministically instead of waiting for
            # garbage collection; convert() loads the pixels and returns an
            # image that no longer depends on the open file.
            with Image.open(image_input) as opened:
                return opened.convert("RGB")
        if isinstance(image_input, np.ndarray):
            # Arrays are treated as OpenCV images: BGR channel order, or a
            # single-channel grayscale image when the array is 2-D.
            if image_input.ndim == 2:
                return Image.fromarray(cv.cvtColor(image_input, cv.COLOR_GRAY2RGB))
            return Image.fromarray(cv.cvtColor(image_input, cv.COLOR_BGR2RGB))
        if isinstance(image_input, Image.Image):
            return image_input.convert("RGB")
        raise ValueError("image_input must be a path, numpy array, or PIL Image")

    def segment(self, image_input, input_points):
        """Generate candidate segmentation masks for the prompted points.

        Args:
            image_input: Image file path (``str``), OpenCV-style
                ``numpy.ndarray`` (BGR, or 2-D grayscale), or a PIL image.
            input_points: Non-empty iterable of ``(x, y)`` coordinates. Each
                coordinate is truncated to ``int`` and every point is given
                label ``1`` (a foreground point in SAM's prompt convention).

        Returns:
            A ``(masks, scores)`` tuple of two parallel lists. Each mask is a
            ``uint8`` array holding 0/1 values (squeezed of singleton axes);
            each score is the matching entry of the model's ``iou_scores``.

        Raises:
            ValueError: If ``input_points`` is empty or malformed, or if
                ``image_input`` has an unsupported type.
        """
        # Materialise the points once so any iterable works, and validate them
        # before paying for image decoding or inference.
        point_list = [[int(x), int(y)] for (x, y) in input_points]
        if not point_list:
            raise ValueError("input_points must contain at least one (x, y) point")

        image = self._load_rgb_image(image_input)

        # One image, one prompt made of all points: the extra list levels are
        # the nesting the processor expects around the point list.
        inputs = self.processor(
            images=image,
            input_points=[[point_list]],
            input_labels=[[[1] * len(point_list)]],
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Resize the predicted masks back to the original image size.
        processed = self.processor.post_process_masks(
            masks=outputs.pred_masks,
            reshaped_input_sizes=inputs["reshaped_input_sizes"],
            original_sizes=inputs["original_sizes"],
        )

        # Only a single image was submitted, so take the first result.
        masks = processed[0]
        scores = outputs.iou_scores.cpu().numpy()
        masks_np = masks.cpu().numpy() if hasattr(masks, "cpu") else np.asarray(masks)

        # Flatten (prompt, candidate) into two parallel lists; scores[0] is the
        # score block belonging to the single submitted image.
        flat_masks = []
        flat_scores = []
        for mask_group, score_group in zip(masks_np, scores[0]):
            for mask, score in zip(mask_group, score_group):
                flat_masks.append((np.squeeze(mask) > 0).astype(np.uint8))
                flat_scores.append(score)
        return flat_masks, flat_scores
|
|
|
|
|
|
|
|
def main():
    """Interactive demo: click points on an image, then pick one of the masks.

    Left-clicks on the "Input Image" window collect prompt points; pressing
    ``q`` ends the collection. Every candidate mask is then previewed in its
    own window, and the index typed on the console selects the mask that is
    applied to the image and written to ``masked_image.png``.
    """
    segmentor = Segmentor(model, processor, device)
    image_path = "redbull.jpg"

    # cv.imread signals failure by returning None rather than raising, so check
    # before the array is handed to imshow.
    img = cv.imread(image_path)
    if img is None:
        raise SystemExit(f"Could not read image: {image_path}")

    input_points = []

    def mouse_callback(event, x, y, flags, param):
        # Record each left-click as a prompt point.
        if event == cv.EVENT_LBUTTONDOWN:
            input_points.append([x, y])
            print(f"Point added: ({x}, {y})")

    cv.namedWindow("Input Image")
    cv.setMouseCallback("Input Image", mouse_callback)

    # Keep the window alive until the user presses 'q'.
    while True:
        cv.imshow("Input Image", img)
        if cv.waitKey(1) & 0xFF == ord('q'):
            break
    cv.destroyAllWindows()
    # NOTE(review): presumably gives the GUI loop a tick to actually close the
    # window on some platforms - kept from the original.
    cv.waitKey(1)

    if not input_points:
        print("No input points provided. Exiting.")
        return

    # Segment the array that was displayed instead of re-decoding the file:
    # the clicks, the masks and the previews then share one pixel grid
    # (cv.imread honours EXIF orientation, a separate PIL decode would not).
    masks, scores = segmentor.segment(img, input_points)

    print(f"Generated {len(masks)} candidate masks.")

    try:
        for i, (mask, score) in enumerate(zip(masks, scores)):
            masked_preview = cv.bitwise_and(img, img, mask=mask)
            cv.imshow(f"Candidate {i} (Score: {score:.4f})", masked_preview)
            print(f"Candidate {i}: Score {score:.4f}")

        print("Check the open windows for candidate masks.")
        cv.waitKey(100)

        # Only the int() conversion is expected to raise ValueError.
        try:
            selected_idx = int(input("Enter the index of the desired mask: "))
        except ValueError:
            print("Invalid input. Please enter a number.")
        else:
            if 0 <= selected_idx < len(masks):
                selected_mask = masks[selected_idx]
                masked_img = cv.bitwise_and(img, img, mask=selected_mask)
                # imwrite reports failure through its return value.
                if cv.imwrite("masked_image.png", masked_img):
                    print(f"Saved masked_image.png using candidate {selected_idx}")
                else:
                    print("Failed to write masked_image.png")
            else:
                print("Invalid index selected.")
    finally:
        # Close the candidate windows even if the prompt is interrupted.
        cv.destroyAllWindows()


if __name__ == "__main__":
    main()
|
|
|
|
|
|