File size: 3,595 Bytes
131f195
 
 
 
 
 
146111c
d13628a
146111c
d13628a
146111c
131f195
146111c
 
 
 
131f195
 
146111c
 
 
 
131f195
 
d13628a
 
 
 
 
 
 
 
 
 
 
131f195
 
 
 
 
 
 
 
 
 
146111c
 
 
131f195
146111c
131f195
 
 
146111c
131f195
 
 
 
146111c
131f195
 
 
 
 
 
 
d13628a
 
 
8ad7724
 
d13628a
8ad7724
 
131f195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d13628a
131f195
d13628a
 
 
131f195
 
 
146111c
131f195
 
 
 
 
 
 
 
 
 
 
 
d13628a
131f195
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import numpy as np
import gradio as gr
import torch
import cv2
from segment_anything import SamPredictor, sam_model_registry

# Global variables
# Checkpoint paths. The dropdown choices are derived from the same constants
# used to load the models below, so the UI label always matches the checkpoint
# actually served (the old hard-coded list showed "medsam_vitb.pth" while the
# loaded file was "medsam_vitb_best.pth").
OFFICIAL_CHECKPOINT = "./models/sam_vit_b_01ec64.pth"
MEDSAM_CHECKPOINT = "./models/medsam_vitb_best.pth"
MODELS = [OFFICIAL_CHECKPOINT, MEDSAM_CHECKPOINT]
MODEL_TYPE = "vit_b"  # both checkpoints are ViT-B variants
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Model
## OFFICIAL SAM
SAM = sam_model_registry[MODEL_TYPE](checkpoint=OFFICIAL_CHECKPOINT)
SAM.to(device=DEVICE)
SAM_PREDICTOR = SamPredictor(SAM)
## MEDSAM
MEDSAM = sam_model_registry[MODEL_TYPE](checkpoint=MEDSAM_CHECKPOINT)
MEDSAM.to(device=DEVICE)
MEDSAM_PREDICTOR = SamPredictor(MEDSAM)


def load_model(model_choice: int) -> SamPredictor:
    """Return the predictor for the given dropdown index (0=SAM, 1=MedSAM).

    Raises:
        ValueError: if *model_choice* is neither 0 nor 1.
    """
    print("model_choice", model_choice)
    if model_choice not in (0, 1):
        raise ValueError("Model choice must be 0 or 1")
    return SAM_PREDICTOR if model_choice == 0 else MEDSAM_PREDICTOR


def draw_contour(image: np.ndarray, mask: np.ndarray) -> tuple:
    """Outline the mask's external contours on a copy of *image*.

    Args:
        image: source image; it is copied, never modified in place.
        mask: binary mask (nonzero pixels mark the object); cast to uint8
            for cv2.findContours.

    Returns:
        (contour_image, contours): the annotated copy and the contour
        sequence from ``cv2.findContours``.

    Note: the original annotation claimed ``-> np.ndarray`` but the function
    has always returned a 2-tuple; the annotation is corrected here.
    Color (0, 0, 255) with thickness 3 is red in BGR but blue in RGB
    (gradio images are RGB) — TODO confirm the intended color.
    """
    contour_image = image.copy()
    contours, _ = cv2.findContours(
        np.uint8(mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    cv2.drawContours(contour_image, contours, -1, (0, 0, 255), 3)
    return contour_image, contours


def inference(
    predictor: SamPredictor, image: np.ndarray, coord_y: int, coord_x: int
) -> np.ndarray:
    """Segment the object at the clicked point and return the outlined image.

    NOTE(review): despite the y/x parameter names, the caller passes
    ``evt.index[0]`` and ``evt.index[1]``, which gradio reports as (x, y);
    SAM's ``point_coords`` also expects (x, y), so the order is consistent —
    confirm before renaming anything.
    """
    predictor.set_image(image)

    click_point = np.array([[coord_y, coord_x]])
    click_label = np.array([1])  # 1 marks a foreground (positive) click
    mask, _, _ = predictor.predict(
        point_coords=click_point,
        point_labels=click_label,
        multimask_output=False,  # single best mask only
    )

    # Reshape (…, H, W) -> (H, W, 1) and scale to a 0/255 uint8 mask.
    mask_h, mask_w = mask.shape[-2:]
    mask_u8 = (mask.reshape(mask_h, mask_w, 1) * 255).astype(np.uint8)
    outlined, _ = draw_contour(image, mask_u8)
    return outlined


def extract_object_by_event(model_choice: int, image: np.ndarray, evt: gr.SelectData):
    """Segment the object under the user's click with the selected model."""
    predictor = load_model(model_choice)
    # gradio's SelectData.index is (x, y); forwarded positionally to inference.
    click_x, click_y = evt.index
    return inference(predictor, image, click_x, click_y)


def get_coords(evt: gr.SelectData):
    """Return the (x, y) pixel coordinates of a gradio image click."""
    x_coord, y_coord = evt.index[0], evt.index[1]
    return x_coord, y_coord


# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(
            """# Segment Anything!🚀
            The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. 
            More information can be found in [**Official Project**](https://segment-anything.com/).
            """
        )
    with gr.Row():
        # type="index" makes the dropdown emit the integer position of the
        # selected checkpoint path, which load_model() maps to a predictor.
        model_choice = gr.Dropdown(
            label="Select Model",
            choices=list(MODELS),  # plain copy; no comprehension needed
            type="index",
            interactive=True,
        )

    # Segment image
    with gr.Tab(label="SAM Inference"):
        # NOTE(review): Row.style() was removed in gradio 4.x; this only works
        # on gradio 3.x — confirm the pinned version before upgrading.
        with gr.Row().style(equal_height=True):
            with gr.Column(label="Input Image"):
                # input image
                input_image = gr.Image(type="numpy")

            with gr.Column(label="Output"):
                # output
                output = gr.Image(type="numpy")

    with gr.Row():
        coord_h = gr.Number(label="Mouse coords h")
        coord_w = gr.Number(label="Mouse coords w")

    # Both handlers fire on the same click: one segments, one echoes coords.
    input_image.select(extract_object_by_event, [model_choice, input_image], output)
    input_image.select(get_coords, None, [coord_h, coord_w])

# enable_queue was redundant with .queue() and is rejected by newer gradio
# releases, so it is dropped from launch().
demo.queue().launch(debug=True)