# KurtLin's picture
# Update app.py
# 24fcab2 verified
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch
from segment_anything import sam_model_registry, SamPredictor
import gradio as gr
# One row per selectable backbone: (display name, registry key, checkpoint path).
# Both lookup tables below are derived from this single list so they cannot
# drift apart.
_SAM_VARIANTS = (
    ("ViT-base", "vit_b", "weights/sam_vit_b_01ec64.pth"),
    ("ViT-large", "vit_l", "weights/sam_vit_l_0b3195.pth"),
    ("ViT-huge", "vit_h", "weights/sam_vit_h_4b8939.pth"),
)

# Display name -> path of the weight file for that backbone.
sam_checkpoint = {label: path for label, _, path in _SAM_VARIANTS}

# Display name -> key used to look the backbone up in ``sam_model_registry``.
model_type = {label: arch for label, arch, _ in _SAM_VARIANTS}

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
def get_coords(evt: gr.SelectData):
    """Format the position carried by a select event as the text ``"x, y"``.

    Args:
        evt: Select event; only the first two entries of ``evt.index`` are
            read.

    Returns:
        The two entries joined by a comma and a space.
    """
    first = evt.index[0]
    second = evt.index[1]
    return "{}, {}".format(first, second)
# Most recently used SAM network, keyed by backbone name. Holding on to it
# avoids re-reading the checkpoint from disk on every click.
_SAM_CACHE = {}


def _load_sam(model_choice):
    """Return the SAM network for ``model_choice``, loading it on first use."""
    sam = _SAM_CACHE.get(model_choice)
    if sam is None:
        # Drop any other backbone first so at most one network stays resident.
        _SAM_CACHE.clear()
        build_sam = sam_model_registry[model_type[model_choice]]
        sam = build_sam(checkpoint=sam_checkpoint[model_choice])
        sam.to(device=device)
        _SAM_CACHE[model_choice] = sam
    return sam


def _paint_result(image, mask, x, y, half=10):
    """Return a copy of ``image`` with ``mask`` tinted and the click marked."""
    painted = image.copy()
    painted[mask, 0] = 255  # saturate channel 0 wherever the mask is set
    # Clamp the lower bounds: a negative slice start would wrap around and
    # leave the marker empty for clicks closer than ``half`` px to the edge.
    top, left = max(y - half, 0), max(x - half, 0)
    painted[top:y + half, left:x + half, 2] = 255  # channel 2 marks the click
    return painted


def inference(image, input_label, model_choice):
    """Segment the object under the clicked point with the chosen SAM backbone.

    Args:
        image: Input picture as an ``H x W x C`` array, or the dict produced
            by ``gr.ImageEditor`` (its ``"composite"`` entry is used, falling
            back to ``"background"``). A fourth (alpha) channel is dropped.
        input_label: Value of the coordinate label; its ``"label"`` entry
            must hold the click position as ``"x, y"``.
        model_choice: One of the keys of ``model_type`` / ``sam_checkpoint``.

    Returns:
        A flat tuple ``(score1, img1, score2, img2, score3, img3)``: for each
        candidate mask its score formatted as a string, followed by a copy of
        the image with the mask and the clicked point painted in.

    Raises:
        gr.Error: If the image, the clicked point or the model choice is
            missing.
    """
    if isinstance(image, dict):
        editor_value = image
        image = editor_value.get("composite")
        if image is None:
            image = editor_value.get("background")
    if image is None:
        raise gr.Error("Please provide an image first.")

    label_text = input_label["label"] if input_label else None
    if not label_text:
        raise gr.Error("Please click on the image to choose a point first.")
    if model_choice not in model_type:
        raise gr.Error("Please choose a model backbone first.")

    image = np.asarray(image)
    if image.ndim == 3 and image.shape[2] == 4:
        image = image[..., :3]  # the predictor is fed three channels only
    image = np.ascontiguousarray(image)

    parts = label_text.split(",")
    x, y = int(parts[0]), int(parts[1])

    # A fresh predictor per call keeps the per-image state private to this
    # request; only the (read-only) network weights are shared.
    predictor = SamPredictor(_load_sam(model_choice))
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[x, y]]),
        point_labels=np.array([1]),  # 1 marks the point as foreground
        multimask_output=True,
    )

    results = []
    # The UI has three output slots, so only the first three candidates are used.
    for mask, score in zip(masks[:3], scores[:3]):
        results.append(f"{score}")
        results.append(_paint_result(image, mask, x, y))
    return tuple(results)
# Build the demo page: an input image that records the clicked point, a model
# selector, a "Segment" button and three (score, image) result pairs.
my_app = gr.Blocks()
with my_app:
    gr.Markdown("Segment Anything Testing")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Column():
                with gr.Row():
                    img_source = gr.ImageEditor(label="Please select picture and click the part to segment",
                                                value='./images/truck.jpg', height=500, width=1000)
                with gr.Row():
                    coords = gr.Label(label="Image Coordinate")
                    # Start on a concrete backbone so pressing "Segment"
                    # without touching the dropdown still sends a valid choice.
                    model_choice = gr.Dropdown(['ViT-base', 'ViT-large', 'ViT-huge'], label='Model Backbone',
                                               value='ViT-base')
                with gr.Row():
                    infer = gr.Button("Segment")
                with gr.Row():
                    score1 = gr.Label(label="Mask 1 Confidence")
                with gr.Row():
                    img_output1 = gr.Image(label="Output Mask 1", height=500, width=1000)
                with gr.Row():
                    score2 = gr.Label(label="Mask 2 Confidence")
                with gr.Row():
                    img_output2 = gr.Image(label="Output Mask 2", height=500, width=1000)
                with gr.Row():
                    score3 = gr.Label(label="Mask 3 Confidence")
                with gr.Row():
                    img_output3 = gr.Image(label="Output Mask 3", height=500, width=1000)

    # Clicking the picture writes the clicked position into the label.
    img_source.select(get_coords, [], coords)
    # The button feeds picture, position and backbone to ``inference`` and
    # spreads its six return values over the result widgets, in this order.
    infer.click(
        inference,
        [
            img_source,
            coords,
            model_choice,
        ],
        [
            score1,
            img_output1,
            score2,
            img_output2,
            score3,
            img_output3,
        ],
    )

my_app.launch(debug=True)