import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch
from segment_anything import sam_model_registry, SamPredictor
import gradio as gr
# Checkpoint file for each Segment Anything backbone offered in the UI dropdown.
sam_checkpoint = dict(
    zip(
        ("ViT-base", "ViT-large", "ViT-huge"),
        (
            "weights/sam_vit_b_01ec64.pth",
            "weights/sam_vit_l_0b3195.pth",
            "weights/sam_vit_h_4b8939.pth",
        ),
    )
)
# Registry key understood by segment_anything's sam_model_registry for each backbone.
model_type = dict(
    zip(
        ("ViT-base", "ViT-large", "ViT-huge"),
        ("vit_b", "vit_l", "vit_h"),
    )
)
# Run on GPU when torch can see one; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
def get_coords(evt: gr.SelectData):
    """Return the clicked pixel position of a select event as an "x, y" string."""
    col, row = evt.index[0], evt.index[1]
    return f"{col}, {row}"
def _overlay_mask(image, mask, x, y):
    """Return a copy of ``image`` with ``mask`` painted into channel 0 and a
    filled marker square on channel 2 around the clicked point ``(x, y)``.

    The slice lower bounds are clamped at 0: with the original unclamped
    ``y - 10`` a click within 10px of the top/left border produced a negative
    index, which wraps and draws the marker at the opposite edge.
    """
    out = image.copy()
    out[mask, 0] = 255  # highlight the segmented region
    out[max(0, y - 10):y + 10, max(0, x - 10):x + 10, 2] = 255  # click marker
    return out


def inference(image, input_label, model_choice):
    """Segment ``image`` at the clicked point using the chosen SAM backbone.

    Parameters
    ----------
    image :
        RGB image array; ``gr.ImageEditor`` delivers a dict in which case the
        flat ``'composite'`` array is used.
    input_label : dict
        Value of the coordinate ``gr.Label``, ``{'label': "x, y"}``.
    model_choice : str
        A key of ``sam_checkpoint`` / ``model_type`` ("ViT-base", ...).

    Returns
    -------
    tuple
        ``(score1, img1, score2, img2, score3, img3)`` — a confidence string
        and an annotated image for each of the three candidate masks.
    """
    # gr.ImageEditor passes {'background', 'layers', 'composite'} rather than
    # a bare array; SAM needs the composite image.  NOTE(review): confirm this
    # key against the installed gradio version.
    if isinstance(image, dict):
        image = image["composite"]

    sam = sam_model_registry[model_type[model_choice]](checkpoint=sam_checkpoint[model_choice])
    sam.to(device=device)
    predictor = SamPredictor(sam)
    predictor.set_image(image)

    # The label value is "x, y"; int() tolerates the surrounding whitespace.
    x_text, y_text = input_label['label'].split(',')[:2]
    x, y = int(x_text), int(y_text)

    point_coords = np.array([[x, y]])
    point_labels = np.array([1])  # 1 marks a foreground point
    masks, scores, logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True,  # always yields three candidate masks
    )

    # One (confidence, annotated image) pair per candidate mask, interleaved
    # to match the six output components wired up in the UI.
    results = []
    for mask, score in zip(masks, scores):
        results.append(f"{score}")
        results.append(_overlay_mask(image, mask, x, y))
    return tuple(results)
my_app = gr.Blocks()
with my_app:
    gr.Markdown("Segment Anything Testing")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Column():
                with gr.Row():
                    img_source = gr.ImageEditor(
                        label="Please select picture and click the part to segment",
                        value='./images/truck.jpg',
                        height=500,
                        width=1000,
                    )
                with gr.Row():
                    coords = gr.Label(label="Image Coordinate")
                    model_choice = gr.Dropdown(
                        ['ViT-base', 'ViT-large', 'ViT-huge'],
                        label='Model Backbone',
                    )
                with gr.Row():
                    infer = gr.Button("Segment")

                # One confidence label and one image view per candidate mask.
                mask_scores = []
                mask_images = []
                for idx in (1, 2, 3):
                    with gr.Row():
                        mask_scores.append(gr.Label(label=f"Mask {idx} Confidence"))
                    with gr.Row():
                        mask_images.append(
                            gr.Image(label=f"Output Mask {idx}", height=500, width=1000)
                        )

    # Clicking the image records the pixel coordinate in the label component.
    img_source.select(get_coords, [], coords)

    # Interleave (score, image) pairs to match inference's 6-tuple return.
    click_outputs = []
    for score_label, image_out in zip(mask_scores, mask_images):
        click_outputs.extend([score_label, image_out])
    infer.click(
        inference,
        [img_source, coords, model_choice],
        click_outputs,
    )
my_app.launch(debug=True)