File size: 3,651 Bytes
11f87cc
 
 
712d80d
 
11f87cc
 
00a3d3d
 
 
 
 
 
 
 
 
 
8896ee3
00a3d3d
8896ee3
 
 
 
00a3d3d
 
 
 
8896ee3
c94b67e
 
 
8896ee3
 
 
 
 
 
4068132
 
 
 
 
 
 
 
 
cdf1836
4068132
 
 
 
cdf1836
4068132
712d80d
 
11f87cc
 
 
 
 
ef16738
 
24fcab2
e814975
ef16738
8896ee3
4068132
ef16738
4d9e59a
ef16738
4068132
 
e814975
4068132
 
 
e814975
4068132
 
 
e814975
11f87cc
8896ee3
 
 
 
 
00a3d3d
 
8896ee3
 
4068132
 
 
 
 
 
8896ee3
 
11f87cc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from functools import lru_cache

import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch
from segment_anything import sam_model_registry, SamPredictor
import gradio as gr

# Checkpoint file on disk for each backbone label offered in the UI.
sam_checkpoint = {
    "ViT-base": "weights/sam_vit_b_01ec64.pth",
    "ViT-large": "weights/sam_vit_l_0b3195.pth",
    "ViT-huge": "weights/sam_vit_h_4b8939.pth",
}

# Key into segment_anything's sam_model_registry for each backbone label.
model_type = {
    "ViT-base": "vit_b",
    "ViT-large": "vit_l",
    "ViT-huge": "vit_h",
}

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def get_coords(evt: gr.SelectData):
    """Report where the user clicked on the image as an "x, y" string.

    Wired to the image component's ``select`` event; the returned text is
    shown in the coordinate Label, and ``inference`` later splits it on the
    comma to recover the point (first element used as x, second as y).
    """
    col = evt.index[0]
    row = evt.index[1]
    return f"{col}, {row}"

@lru_cache(maxsize=1)
def _load_predictor(model_choice):
    """Build (or reuse) a SamPredictor for the backbone label *model_choice*.

    Loading a SAM checkpoint dominates request time, so the most recently
    used backbone is cached. With maxsize=1 the cache keeps at most one
    model alive; switching backbones evicts the previous one.
    NOTE(review): the predictor is shared between requests and set_image()
    stores per-image state on it — this relies on Gradio running this event
    one request at a time (its default concurrency limit); confirm if that
    limit is ever raised.
    """
    sam = sam_model_registry[model_type[model_choice]](checkpoint=sam_checkpoint[model_choice])
    sam.to(device=device)
    return SamPredictor(sam)


def _to_rgb_array(image):
    """Return *image* as an HxWx3 array for SamPredictor.set_image.

    A plain array (as gr.Image delivers) is passed through. gr.ImageEditor in
    Gradio 4 delivers a dict with "background"/"layers"/"composite" keys and,
    by default, an alpha channel — confirm against the installed Gradio
    version. Both cases are reduced to three channels here.
    """
    if isinstance(image, dict):
        composite = image.get("composite")
        image = composite if composite is not None else image.get("background")
    if image is None:
        raise gr.Error("Please select an image first.")
    image = np.asarray(image)
    if image.ndim == 2:  # greyscale: replicate into three channels
        image = np.stack([image] * 3, axis=-1)
    return np.ascontiguousarray(image[..., :3])


def _parse_coords(coords):
    """Turn the coordinate Label value ("x, y" text from get_coords) into ints."""
    label = coords.get("label") if isinstance(coords, dict) else coords
    if not label:
        raise gr.Error("Please click on the image to choose a point first.")
    x_text, y_text = str(label).split(",")
    return int(x_text), int(y_text)


def _overlay(image, mask, x, y, half_size=10):
    """Return a copy of *image* with *mask* set to 255 in channel 0 and a
    square marker set to 255 in channel 2 around the point (x, y)."""
    out = image.copy()
    out[mask, 0] = 255
    # Clamp the start indices: a negative start would wrap to the far edge
    # and leave the marker empty for clicks near the top/left border.
    out[max(y - half_size, 0):y + half_size, max(x - half_size, 0):x + half_size, 2] = 255
    return out


def inference(image, input_label, model_choice):
    """Segment the object under the clicked point with SAM.

    Args:
        image: the picture from the image component (an array, or the dict
            an ImageEditor delivers).
        input_label: value of the coordinate Label; its ``label`` text is
            "x, y" as written by get_coords.
        model_choice: a backbone label, i.e. a key of ``sam_checkpoint``.

    Returns:
        A flat tuple (score1, img1, score2, img2, ...) with, for each mask
        SAM proposes (three when multimask_output=True), its score as a
        string and a copy of the image with that mask and the point drawn in.

    Raises:
        gr.Error: if no valid backbone, image or point has been chosen.
    """
    if model_choice not in sam_checkpoint:
        raise gr.Error("Please select a model backbone.")
    image = _to_rgb_array(image)
    x, y = _parse_coords(input_label)

    predictor = _load_predictor(model_choice)
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[x, y]]),
        point_labels=np.array([1]),  # 1 marks a foreground point in SAM
        multimask_output=True,
    )

    outputs = []
    for mask, score in zip(masks, scores):
        outputs.extend((f"{score}", _overlay(image, mask, x, y)))
    return tuple(outputs)


# Gradio UI: pick an image and a point, choose a SAM backbone, and show the
# three candidate masks SAM returns together with their scores.
my_app = gr.Blocks()
with my_app:
    gr.Markdown("Segment Anything Testing")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Column():
                with gr.Row():
                    img_source = gr.ImageEditor(
                        label="Please select picture and click the part to segment",
                        value='./images/truck.jpg',
                        height=500,
                        width=1000,
                    )
                with gr.Row():
                    coords = gr.Label(label="Image Coordinate")
                    # Choices come from the checkpoint table so the two can't drift.
                    # A default is set because an untouched dropdown would
                    # otherwise send None to inference().
                    model_choice = gr.Dropdown(
                        list(sam_checkpoint),
                        value="ViT-base",
                        label='Model Backbone',
                    )
                with gr.Row():
                    infer = gr.Button("Segment")
                with gr.Row():
                    score1 = gr.Label(label="Mask 1 Confidence")
                with gr.Row():
                    img_output1 = gr.Image(label="Output Mask 1", height=500, width=1000)
                with gr.Row():
                    score2 = gr.Label(label="Mask 2 Confidence")
                with gr.Row():
                    img_output2 = gr.Image(label="Output Mask 2", height=500, width=1000)
                with gr.Row():
                    score3 = gr.Label(label="Mask 3 Confidence")
                with gr.Row():
                    img_output3 = gr.Image(label="Output Mask 3", height=500, width=1000)

        # Clicking the image writes the "x, y" point into the coordinate Label.
        img_source.select(get_coords, [], coords)
        infer.click(
            inference,
            [
                img_source,
                coords,
                model_choice,
            ],
            [
                score1,
                img_output1,
                score2,
                img_output2,
                score3,
                img_output3,
            ],
        )

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    my_app.launch(debug=True)