"""Gradio demo: click a point on an image to segment it with SAM or MedSAM."""

import numpy as np
import gradio as gr
import torch
import cv2
from segment_anything import SamPredictor, sam_model_registry

# --- Configuration -----------------------------------------------------------
# Dropdown entries; the dropdown's *index* (0 or 1) selects the model.
MODELS = ["./models/sam_vit_b_01ec64.pth", "./models/medsam_vitb.pth"]
OFFICIAL_CHECKPOINT = "./models/sam_vit_b_01ec64.pth"
# NOTE(review): the dropdown entry above says "medsam_vitb.pth" but the
# checkpoint actually loaded is "medsam_vitb_best.pth" — confirm which file
# is intended so the UI label matches the loaded weights.
MEDSAM_CHECKPOINT = "./models/medsam_vitb_best.pth"
MODEL_TYPE = "vit_b"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Models (loaded once, at import time) ------------------------------------
## OFFICIAL SAM
SAM = sam_model_registry[MODEL_TYPE](checkpoint=OFFICIAL_CHECKPOINT)
SAM.to(device=DEVICE)
SAM_PREDICTOR = SamPredictor(SAM)

## MEDSAM
MEDSAM = sam_model_registry[MODEL_TYPE](checkpoint=MEDSAM_CHECKPOINT)
MEDSAM.to(device=DEVICE)
MEDSAM_PREDICTOR = SamPredictor(MEDSAM)


def load_model(model_choice: int) -> SamPredictor:
    """Return the predictor for the selected dropdown index.

    Args:
        model_choice: 0 for official SAM, 1 for MedSAM (dropdown ``type="index"``).

    Returns:
        The pre-built :class:`SamPredictor` for that model.

    Raises:
        ValueError: if ``model_choice`` is neither 0 nor 1.
    """
    print("model_choice", model_choice)
    if model_choice == 0:
        return SAM_PREDICTOR
    if model_choice == 1:
        return MEDSAM_PREDICTOR
    raise ValueError("Model choice must be 0 or 1")


def draw_contour(image: np.ndarray, mask: np.ndarray):
    """Draw the external contours of ``mask`` on a copy of ``image``.

    Args:
        image: source image; left unmodified (a copy is annotated).
        mask: binary mask; cast to uint8 for OpenCV.

    Returns:
        Tuple ``(contour_image, contours)``: the annotated copy and the raw
        contour list from ``cv2.findContours``.
    """
    # Fixed: the original annotated the return as ``np.ndarray`` although a
    # 2-tuple is returned.
    contour_image = image.copy()
    contours, _ = cv2.findContours(
        np.uint8(mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    # NOTE(review): gradio delivers RGB arrays, so (0, 0, 255) renders *blue*
    # here (it would be red only under OpenCV's BGR convention) — confirm the
    # intended contour color. Thickness 3; -1 draws all contours.
    cv2.drawContours(contour_image, contours, -1, (0, 0, 255), 3)
    return contour_image, contours


def inference(
    predictor: SamPredictor, image: np.ndarray, coord_y: int, coord_x: int
) -> np.ndarray:
    """Segment the object at a clicked point and return the contour overlay.

    Args:
        predictor: SAM predictor to run.
        image: input image (H, W, 3).
        coord_y, coord_x: click position as produced by gradio's
            ``evt.index`` — i.e. actually (x, y) pixel order despite the
            parameter names. SAM's ``point_coords`` expects (x, y), so the
            values are forwarded unchanged.

    Returns:
        The input image with the predicted mask's contour drawn on it.
    """
    predictor.set_image(image)
    input_point = np.array([[coord_y, coord_x]])
    input_label = np.array([1])  # 1 marks a foreground (positive) point
    mask, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,  # single best mask only
    )
    # Collapse the leading mask dimension to (H, W, 1) and scale to 0/255
    # so OpenCV can consume it.
    h, w = mask.shape[-2:]
    mask = mask.reshape(h, w, 1)
    mask = (mask * 255).astype(np.uint8)
    contour_image, _ = draw_contour(image, mask)
    return contour_image


def extract_object_by_event(model_choice: int, image: np.ndarray, evt: gr.SelectData):
    """Gradio handler: segment the object under a mouse click.

    ``evt.index`` is (x, y); the misnamed locals are kept only to preserve
    the original call order into :func:`inference`, which forwards them
    as-is to SAM's (x, y) ``point_coords``.
    """
    predictor = load_model(model_choice)
    click_h, click_w = evt.index
    return inference(predictor, image, click_h, click_w)


def get_coords(evt: gr.SelectData):
    """Return the click position from a gradio select event.

    NOTE(review): ``evt.index`` is (x, y), but these values feed Number
    boxes labelled "h" and "w" respectively — the labels appear swapped
    relative to the values; confirm and fix the labels if so.
    """
    return evt.index[0], evt.index[1]


with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(
            """# Segment Anything!🚀
            The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. More information can be found in [**Official Project**](https://segment-anything.com/).
            """
        )
    with gr.Row():
        # Model selector; type="index" makes the callback receive 0 or 1.
        model_choice = gr.Dropdown(
            label="Select Model",
            choices=list(MODELS),
            type="index",
            interactive=True,
        )
    # Segment image
    with gr.Tab(label="SAM Inference"):
        # ``.style`` is the gradio 3.x API (removed in 4.x).
        with gr.Row().style(equal_height=True):
            with gr.Column(label="Input Image"):
                input_image = gr.Image(type="numpy")
            with gr.Column(label="Output"):
                output = gr.Image(type="numpy")
    with gr.Row():
        coord_h = gr.Number(label="Mouse coords h")
        coord_w = gr.Number(label="Mouse coords w")

    # Both handlers fire on the same click: one segments, one displays coords.
    input_image.select(extract_object_by_event, [model_choice, input_image], output)
    input_image.select(get_coords, None, [coord_h, coord_w])

# Fixed: dropped the redundant ``enable_queue=True`` launch kwarg — the queue
# is already enabled by the explicit ``.queue()`` call, and the kwarg is
# deprecated (removed in gradio 4.x).
demo.queue().launch(debug=True)