"""Gradio demo: click a point on an image and segment it with Segment Anything.

The user picks an image, clicks the object of interest, chooses a SAM backbone
and presses "Segment". The three candidate masks returned by SAM are shown as
overlays together with their confidence scores.
"""

import functools

import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch
from segment_anything import sam_model_registry, SamPredictor
import gradio as gr

# Checkpoint file for each selectable backbone (paths relative to the CWD).
sam_checkpoint = {
    "ViT-base": "weights/sam_vit_b_01ec64.pth",
    "ViT-large": "weights/sam_vit_l_0b3195.pth",
    "ViT-huge": "weights/sam_vit_h_4b8939.pth",
}

# Key into ``sam_model_registry`` for each selectable backbone.
model_type = {
    "ViT-base": "vit_b",
    "ViT-large": "vit_l",
    "ViT-huge": "vit_h",
}

device = "cuda" if torch.cuda.is_available() else "cpu"

# Number of candidate masks the UI has output slots for.
_NUM_MASKS = 3


def get_coords(evt: gr.SelectData) -> str:
    """Format the clicked position as ``"<first>, <second>"``.

    ``inference`` later parses the first number as x (column) and the second
    as y (row).
    """
    return f"{evt.index[0]}, {evt.index[1]}"


@functools.lru_cache(maxsize=1)
def _load_sam(choice: str):
    """Build the SAM model for *choice* and move it to ``device``.

    Cached so repeated clicks with the same backbone do not reload the
    checkpoint from disk; only one model is kept to bound memory use.
    """
    sam = sam_model_registry[model_type[choice]](checkpoint=sam_checkpoint[choice])
    sam.to(device=device)
    return sam


def _as_rgb_array(image) -> np.ndarray:
    """Return *image* as an H x W x 3 array, unwrapping editor payloads.

    ``gr.ImageEditor`` delivers a dict rather than a bare array; the original
    picture is preferred over the composite so user scribbles are not
    segmented.  NOTE(review): key names follow the Gradio docs ("background",
    "composite"; "image" for the older sketch tool) - confirm against the
    installed Gradio version.
    """
    if isinstance(image, dict):
        payload = image
        image = None
        for key in ("background", "composite", "image"):
            if payload.get(key) is not None:
                image = payload[key]
                break
    if image is None:
        raise gr.Error("Please provide an image before segmenting.")

    pixels = np.asarray(image)
    if pixels.ndim == 3 and pixels.shape[2] == 4:
        # Drop the alpha channel; the code below indexes RGB channels only.
        pixels = pixels[..., :3]
    if pixels.ndim != 3 or pixels.shape[2] != 3:
        raise gr.Error("Expected a 3-channel RGB image.")
    return np.ascontiguousarray(pixels)


def _parse_point(input_label) -> tuple[int, int]:
    """Extract the ``(x, y)`` click position from the coordinate label value."""
    text = input_label.get("label") if isinstance(input_label, dict) else input_label
    if not text:
        raise gr.Error("Click a point on the image before segmenting.")
    parts = str(text).split(",")
    try:
        return int(parts[0]), int(parts[1])
    except (ValueError, IndexError) as err:
        raise gr.Error(f"Could not read the clicked coordinates: {text!r}") from err


def _render_overlay(image: np.ndarray, mask: np.ndarray, x: int, y: int,
                    half_size: int = 10) -> np.ndarray:
    """Return a copy of *image* with *mask* and the click marker painted in.

    Channel 0 is saturated inside the mask and channel 2 inside a square of
    side ``2 * half_size`` around the click.
    """
    overlay = image.copy()
    overlay[mask, 0] = 255
    # Clamp the start indices: a negative slice start would wrap around and
    # silently drop the marker for clicks near the top/left border.
    top = max(y - half_size, 0)
    left = max(x - half_size, 0)
    overlay[top:y + half_size, left:x + half_size, 2] = 255
    return overlay


def inference(image, input_label, model_choice):
    """Segment the object at the clicked point with the chosen SAM backbone.

    Args:
        image: Source picture, either an array or an image-editor payload dict.
        input_label: Value of the coordinate label (``{"label": "x, y"}``).
        model_choice: One of the keys of ``model_type``.

    Returns:
        A 6-tuple ``(score1, img1, score2, img2, score3, img3)`` where each
        score is the stringified SAM score and each image is the overlay for
        the corresponding candidate mask.

    Raises:
        gr.Error: If the image, the click position or the model is missing.
    """
    if model_choice not in model_type:
        raise gr.Error("Please choose a model backbone.")
    rgb = _as_rgb_array(image)
    x, y = _parse_point(input_label)

    # A fresh predictor per request keeps per-image state out of the shared,
    # cached model.
    predictor = SamPredictor(_load_sam(model_choice))
    predictor.set_image(rgb)

    masks, scores, _logits = predictor.predict(
        point_coords=np.array([[x, y]]),
        point_labels=np.array([1]),  # 1 marks the click as a foreground point
        multimask_output=True,
    )

    outputs = []
    for mask, score in zip(masks[:_NUM_MASKS], scores[:_NUM_MASKS]):
        outputs.append(f"{score}")
        outputs.append(_render_overlay(rgb, mask, x, y))
    return tuple(outputs)


my_app = gr.Blocks()
with my_app:
    gr.Markdown("Segment Anything Testing")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Column():
                with gr.Row():
                    img_source = gr.ImageEditor(
                        label="Please select picture and click the part to segment",
                        value='./images/truck.jpg',
                        height=500,
                        width=1000,
                    )
                with gr.Row():
                    coords = gr.Label(label="Image Coordinate")
                    model_choice = gr.Dropdown(
                        ['ViT-base', 'ViT-large', 'ViT-huge'],
                        label='Model Backbone',
                    )
                with gr.Row():
                    infer = gr.Button("Segment")
                with gr.Row():
                    score1 = gr.Label(label="Mask 1 Confidence")
                with gr.Row():
                    img_output1 = gr.Image(label="Output Mask 1", height=500, width=1000)
                with gr.Row():
                    score2 = gr.Label(label="Mask 2 Confidence")
                with gr.Row():
                    img_output2 = gr.Image(label="Output Mask 2", height=500, width=1000)
                with gr.Row():
                    score3 = gr.Label(label="Mask 3 Confidence")
                with gr.Row():
                    img_output3 = gr.Image(label="Output Mask 3", height=500, width=1000)

    # Event wiring must happen inside the Blocks context.
    img_source.select(get_coords, [], coords)
    infer.click(
        inference,
        [
            img_source,
            coords,
            model_choice,
        ],
        [
            score1,
            img_output1,
            score2,
            img_output2,
            score3,
            img_output3,
        ],
    )

my_app.launch(debug=True)