Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
| import torch | |
| from segment_anything import sam_model_registry, SamPredictor | |
| import gradio as gr | |
# UI label of each SAM backbone -> (segment_anything registry key, checkpoint file).
_SAM_BACKBONES = {
    "ViT-base": ("vit_b", "weights/sam_vit_b_01ec64.pth"),
    "ViT-large": ("vit_l", "weights/sam_vit_l_0b3195.pth"),
    "ViT-huge": ("vit_h", "weights/sam_vit_h_4b8939.pth"),
}

# Checkpoint path for each backbone label shown in the dropdown.
sam_checkpoint = {label: ckpt for label, (_, ckpt) in _SAM_BACKBONES.items()}

# ``sam_model_registry`` key for each backbone label shown in the dropdown.
model_type = {label: key for label, (key, _) in _SAM_BACKBONES.items()}

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def get_coords(evt: gr.SelectData):
    """Format the clicked position from a Gradio select event as ``"x, y"``.

    The ``gr.SelectData`` annotation is what tells Gradio to inject the event
    payload, so it must stay. The first two entries of ``evt.index`` are
    rendered, comma-separated, as the text shown in the coordinate Label.
    """
    col, row = evt.index[0], evt.index[1]
    return f"{col}, {row}"
# Most recently used SAM predictor, keyed by backbone label. Building a SAM
# model from its checkpoint is slow and memory-hungry, so it is reused across
# clicks; only one backbone is kept so switching models does not pile up
# several checkpoints in (GPU) memory.
# NOTE(review): set_image() mutates the shared predictor, so this relies on
# Gradio running this event one request at a time — confirm queue settings.
_predictor_cache = {}


def _get_predictor(model_choice):
    """Return a SamPredictor for ``model_choice``, building it only on first use."""
    predictor = _predictor_cache.get(model_choice)
    if predictor is None:
        _predictor_cache.clear()  # release the previously loaded backbone
        sam = sam_model_registry[model_type[model_choice]](checkpoint=sam_checkpoint[model_choice])
        sam.to(device=device)
        predictor = SamPredictor(sam)
        _predictor_cache[model_choice] = predictor
    return predictor


def _parse_coords(coords_value):
    """Return integer ``(x, y)`` parsed from the coordinate Label's value.

    Depending on the Gradio version, a Label input arrives either as a dict
    with a ``'label'`` key or as the bare ``"x, y"`` string; both are accepted.

    Raises:
        ValueError: if no point has been clicked yet.
    """
    if isinstance(coords_value, dict):
        coords_value = coords_value.get("label")
    if not coords_value:
        raise ValueError("Click on the image to choose a point before segmenting.")
    x_str, y_str = str(coords_value).split(",")[:2]
    return int(x_str), int(y_str)


def _to_rgb(image):
    """Return the input image as an HxWx3 numpy array suitable for SamPredictor.

    Accepts a plain array (``gr.Image``) or the ``{'background', 'layers',
    'composite'}`` dict produced by ``gr.ImageEditor``; an alpha channel, if
    present, is dropped and a 2-D grayscale image is replicated to 3 channels.

    Raises:
        ValueError: if no image was provided.
    """
    if isinstance(image, dict):
        background = image.get("background")
        image = background if background is not None else image.get("composite")
    if image is None:
        raise ValueError("Please select an image first.")
    image = np.asarray(image)
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[-1] == 4:
        image = image[..., :3]
    return np.ascontiguousarray(image)


def _overlay(image, mask, x, y, half_size=10):
    """Return a copy of ``image`` with ``mask`` tinted in channel 0 and a square
    marker (side ``2 * half_size``) drawn in channel 2 at the clicked point."""
    out = image.copy()
    out[mask, 0] = 255
    # Clamp the start indices at 0: a negative start would wrap around to the
    # far edge and yield an empty slice, hiding the marker for clicks near the
    # top/left border.
    top, left = max(y - half_size, 0), max(x - half_size, 0)
    out[top:y + half_size, left:x + half_size, 2] = 255
    return out


def inference(image, input_label, model_choice):
    """Segment the object under the clicked point and render the three candidate masks.

    Args:
        image: value of the input image component (array, or ImageEditor dict).
        input_label: value of the coordinate Label, holding ``"x, y"``.
        model_choice: backbone label from the dropdown; ``None`` falls back
            to ``"ViT-base"``.

    Returns:
        A 6-tuple ``(score1, img1, score2, img2, score3, img3)`` where each
        score is the mask's predicted score formatted as a string and each
        image is the input with that mask and the click marker overlaid.
    """
    rgb = _to_rgb(image)
    x, y = _parse_coords(input_label)

    predictor = _get_predictor(model_choice or "ViT-base")
    predictor.set_image(rgb)
    masks, scores, _logits = predictor.predict(
        point_coords=np.array([[x, y]]),
        point_labels=np.array([1]),  # 1 marks the point as foreground
        multimask_output=True,  # ask SAM for several candidate masks
    )

    outputs = []
    for mask, score in zip(masks[:3], scores[:3]):
        outputs.append(f"{score}")
        outputs.append(_overlay(rgb, mask, x, y))
    return tuple(outputs)
my_app = gr.Blocks()
with my_app:
    gr.Markdown("Segment Anything Testing")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Column():
                with gr.Row():
                    # gr.Image documents its select event as reporting the
                    # clicked pixel in SelectData.index ([x, y]) and passes the
                    # image as a numpy array, which is what get_coords and
                    # inference expect. gr.ImageEditor instead passes a dict of
                    # layers and does not document pixel coordinates on select.
                    img_source = gr.Image(label="Please select picture and click the part to segment",
                                          value='./images/truck.jpg', height=500, width=1000)
                with gr.Row():
                    coords = gr.Label(label="Image Coordinate")
                    # A default backbone so "Segment" works without first
                    # touching the dropdown (otherwise inference receives None).
                    model_choice = gr.Dropdown(['ViT-base', 'ViT-large', 'ViT-huge'],
                                               value='ViT-base', label='Model Backbone')
                with gr.Row():
                    infer = gr.Button("Segment")
                with gr.Row():
                    score1 = gr.Label(label="Mask 1 Confidence")
                with gr.Row():
                    img_output1 = gr.Image(label="Output Mask 1", height=500, width=1000)
                with gr.Row():
                    score2 = gr.Label(label="Mask 2 Confidence")
                with gr.Row():
                    img_output2 = gr.Image(label="Output Mask 2", height=500, width=1000)
                with gr.Row():
                    score3 = gr.Label(label="Mask 3 Confidence")
                with gr.Row():
                    img_output3 = gr.Image(label="Output Mask 3", height=500, width=1000)

    # Event wiring must happen inside the Blocks context.
    img_source.select(get_coords, [], coords)
    infer.click(
        inference,
        [
            img_source,
            coords,
            model_choice,
        ],
        [
            score1,
            img_output1,
            score2,
            img_output2,
            score3,
            img_output3,
        ],
    )

my_app.launch(debug=True)