sam-reconstruction / segment.py
devsheroubi's picture
Upload segment.py with huggingface_hub
b86ba38 verified
import torch
from transformers import SamModel, SamProcessor
from PIL import Image
import numpy as np
import cv2 as cv
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The "facebook/sam-vit-base" checkpoint and its matching processor are loaded
# once, when this module is imported; the model is moved to `device` right away.
model = SamModel.from_pretrained("facebook/sam-vit-base")
model = model.to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
class Segmentor:
    """Segmentor that takes an image and input points and generates segmentation masks.

    Wraps a Hugging Face SAM model/processor pair. All supplied points form a
    single prompt, and every candidate mask the model returns for that prompt
    is handed back with its predicted IoU score so the caller can choose one.
    """

    def __init__(self, model, processor, device):
        """Store the model, its processor, and the device inputs are moved to.

        Args:
            model: SAM model (e.g. ``transformers.SamModel``); presumably already
                placed on ``device`` by the caller.
            processor: Processor matching ``model`` (e.g. ``SamProcessor``).
            device: ``torch.device`` (or device string) the model inputs go to.
        """
        self.model = model
        self.processor = processor
        self.device = device

    def segment(self, image_input, input_points, input_labels=None):
        """Return every candidate mask SAM predicts for ``input_points``.

        Args:
            image_input: Image path string, OpenCV-style ``np.ndarray`` (BGR,
                BGRA, or single-channel grayscale), or ``PIL.Image.Image``.
            input_points: Non-empty iterable of ``(x, y)`` pixel coordinates;
                coordinates are truncated to ``int``.
            input_labels: Optional per-point labels (SAM convention: 1 =
                foreground, 0 = background). Defaults to all-foreground.

        Returns:
            ``(masks, scores)`` where ``masks`` is a list of 2-D ``uint8`` 0/1
            arrays resized to the original image size and ``scores`` holds the
            matching predicted IoU values, in the same order.

        Raises:
            ValueError: if ``image_input`` has an unsupported type, no points
                are given, or ``input_labels`` does not match ``input_points``.
        """
        # Validate the prompt before touching the (possibly expensive) image/model.
        points, labels = self._build_prompts(input_points, input_labels)
        image = self._to_rgb_pil(image_input)

        inputs = self.processor(
            images=image,
            input_points=points,
            input_labels=labels,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Upsample low-res masks back to the original image size. Tensors are
        # moved to CPU first, matching the usage shown in the HF SAM docs.
        processed = self.processor.post_process_masks(
            masks=outputs.pred_masks.cpu(),
            original_sizes=inputs["original_sizes"].cpu(),
            reshaped_input_sizes=inputs["reshaped_input_sizes"].cpu(),
        )
        # One image per call, so only the first entry of the per-image list is used.
        return self._flatten_candidates(processed[0], outputs.iou_scores)

    @staticmethod
    def _build_prompts(input_points, input_labels=None):
        """Nest points/labels as one image with one prompt group of N points."""
        points = [[int(x), int(y)] for x, y in input_points]
        if not points:
            raise ValueError("input_points must contain at least one (x, y) point")
        if input_labels is None:
            labels = [1] * len(points)
        else:
            labels = [int(label) for label in input_labels]
            if len(labels) != len(points):
                raise ValueError("input_labels must have exactly one label per input point")
        # [image][prompt group][point] -> all points jointly prompt one mask set.
        return [[points]], [[labels]]

    @staticmethod
    def _to_rgb_pil(image_input):
        """Convert a path, OpenCV array, or PIL image to an RGB ``PIL.Image``."""
        if isinstance(image_input, str):
            # Context manager closes the file handle; convert() returns a new image.
            with Image.open(image_input) as img:
                return img.convert("RGB")
        if isinstance(image_input, np.ndarray):
            # OpenCV arrays are BGR(A) or grayscale; PIL/SAM expect RGB.
            channels = 1 if image_input.ndim == 2 else image_input.shape[2]
            code = {1: cv.COLOR_GRAY2RGB, 4: cv.COLOR_BGRA2RGB}.get(channels, cv.COLOR_BGR2RGB)
            return Image.fromarray(cv.cvtColor(image_input, code))
        if isinstance(image_input, Image.Image):
            return image_input.convert("RGB")
        raise ValueError("image_input must be a path, numpy array, or PIL Image")

    @staticmethod
    def _flatten_candidates(masks, iou_scores):
        """Flatten per-image masks and IoU scores into parallel lists.

        ``masks`` is expected as ``(prompt_groups, num_masks, H, W)`` and
        ``iou_scores`` as ``(batch, prompt_groups, num_masks)``; both may be
        torch tensors or array-likes.
        """
        masks_np = masks.cpu().numpy() if hasattr(masks, "cpu") else np.asarray(masks)
        scores_np = iou_scores.cpu().numpy() if hasattr(iou_scores, "cpu") else np.asarray(iou_scores)

        flat_masks = []
        flat_scores = []
        # scores_np[0] selects the single image in the batch.
        for group_masks, group_scores in zip(masks_np, scores_np[0]):
            for j, mask in enumerate(group_masks):
                # Keep exactly the last two (H, W) axes; np.squeeze would also
                # drop a legitimate size-1 height or width.
                mask2d = mask.reshape(mask.shape[-2:])
                flat_masks.append((mask2d > 0).astype(np.uint8))
                flat_scores.append(group_scores[j])
        return flat_masks, flat_scores
# Example usage: click prompt points on an image, pick one of SAM's candidate
# masks, and save the masked image.
if __name__ == "__main__":
    segmentor = Segmentor(model, processor, device)
    image_path = "redbull.jpg"
    window_name = "Input Image"

    # cv.imread signals a missing/unreadable file by returning None, not raising.
    img = cv.imread(image_path)
    if img is None:
        raise SystemExit(f"Could not read image: {image_path}")

    # Collect prompt points from left-clicks; each click is also drawn on a
    # preview copy so the user can see what was selected. Press 'q' to finish.
    input_points = []
    preview = img.copy()

    def mouse_callback(event, x, y, flags, param):
        if event == cv.EVENT_LBUTTONDOWN:
            input_points.append([x, y])
            cv.circle(preview, (x, y), 5, (0, 0, 255), -1)
            print(f"Point added: ({x}, {y})")

    cv.namedWindow(window_name)
    cv.setMouseCallback(window_name, mouse_callback)
    while True:
        cv.imshow(window_name, preview)
        if cv.waitKey(1) & 0xFF == ord('q'):
            break
        # Also stop when the window is closed via its title-bar button;
        # otherwise the loop would keep waiting for a 'q' that never comes.
        if cv.getWindowProperty(window_name, cv.WND_PROP_VISIBLE) < 1:
            break
    cv.destroyAllWindows()
    cv.waitKey(1)

    if not input_points:
        print("No input points provided. Exiting.")
    else:
        # Segment the exact array that was displayed and clicked on, rather than
        # re-decoding image_path with PIL: the two decoders can disagree (e.g. on
        # EXIF orientation), which would misalign clicks and mask sizes.
        masks, scores = segmentor.segment(img, input_points)
        print(f"Generated {len(masks)} candidate masks.")

        # Display candidates
        for i, (mask, score) in enumerate(zip(masks, scores)):
            masked_preview = cv.bitwise_and(img, img, mask=mask)
            cv.imshow(f"Candidate {i} (Score: {score:.4f})", masked_preview)
            print(f"Candidate {i}: Score {score:.4f}")
        print("Check the open windows for candidate masks.")
        cv.waitKey(100)  # Give time for windows to draw

        # Only the int() conversion can raise ValueError; keep the try narrow.
        try:
            selected_idx = int(input("Enter the index of the desired mask: "))
        except ValueError:
            print("Invalid input. Please enter a number.")
        else:
            if 0 <= selected_idx < len(masks):
                selected_mask = masks[selected_idx]
                masked_img = cv.bitwise_and(img, img, mask=selected_mask)
                cv.imwrite("masked_image.png", masked_img)
                print(f"Saved masked_image.png using candidate {selected_idx}")
            else:
                print("Invalid index selected.")
        cv.destroyAllWindows()