"""Segment one object in an image with Segment Anything (SAM).

The script loads a SAM checkpoint, computes the image embedding, saves it to
``models/embedding.npy``, runs the mask decoder with a single hard-coded
foreground point prompt, and writes the resulting binary mask to ``mask.png``.

Usage:
    python <this script> --image path/to/image.jpg
"""

import argparse

import cv2
import numpy as np
from segment_anything import SamPredictor, sam_model_registry

# Hyperparameters
SAM_CHECKPOINT = "./models/sam_vit_h_4b8939.pth"
MODEL_TYPE = "vit_h"
DEVICE = "cpu"

# Output locations
EMBEDDING_PATH = "models/embedding.npy"
MASK_PATH = "mask.png"

# Point prompt in pixel coordinates (x, y); label 1 marks a foreground point.
PROMPT_POINT = (1300, 950)
PROMPT_LABEL = 1


def parse_args():
    """Parse the command line; ``-i/--image`` is the required input path."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--image", required=True, help="Path to the image")
    return parser.parse_args()


def main():
    """Run the SAM encoder and decoder on the image given on the command line.

    Raises:
        FileNotFoundError: if the input image cannot be read.
        OSError: if the mask image cannot be written.
    """
    args = parse_args()

    # Load model
    sam = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
    sam.to(device=DEVICE)
    predictor = SamPredictor(sam)

    # Preprocessing the image
    image = cv2.imread(args.image)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {args.image}")
    # OpenCV loads images as BGR, while set_image() assumes RGB by default;
    # without this conversion the model sees swapped red/blue channels.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image)

    # SAM Encoder for embedding
    embedding = predictor.get_image_embedding()
    # The embedding is a torch tensor; convert explicitly so saving also works
    # when DEVICE is not the CPU.
    np.save(EMBEDDING_PATH, embedding.detach().cpu().numpy())

    # SAM Decoder for segmentation
    input_point = np.array([PROMPT_POINT])
    input_label = np.array([PROMPT_LABEL])
    masks, _scores, _logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )

    # Save output
    # With multimask_output=False a single mask is returned; reshape it to
    # (H, W, 1) and scale the boolean values to 0 / 255.
    height, width = masks.shape[-2:]
    mask_image = (masks.reshape(height, width, 1) * 255).astype(np.uint8)
    if not cv2.imwrite(MASK_PATH, mask_image):
        raise OSError(f"Could not write mask image: {MASK_PATH}")


if __name__ == "__main__":
    main()