# sam-inference / app.py
# Hugging Face Spaces page header (author: Bono93, commit d13628a,
# "fix: medsam model") — residue from the web capture, kept as a comment.
import numpy as np
import gradio as gr
import torch
import cv2
from segment_anything import SamPredictor, sam_model_registry
# Global variables
# Checkpoint paths offered in the UI dropdown (index 0 = official SAM,
# index 1 = MedSAM).
# NOTE(review): MODELS[1] ("medsam_vitb.pth") does not match
# MEDSAM_CHECKPOINT ("medsam_vitb_best.pth") used below — the dropdown label
# and the checkpoint actually loaded differ; confirm which file is intended.
MODELS = ["./models/sam_vit_b_01ec64.pth", "./models/medsam_vitb.pth"]
OFFICIAL_CHECKPOINT = "./models/sam_vit_b_01ec64.pth"
MEDSAM_CHECKPOINT = "./models/medsam_vitb_best.pth"
# Both checkpoints use the ViT-B image-encoder variant of SAM.
MODEL_TYPE = "vit_b"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Model
## OFFICIAL SAM — loaded once at import time and reused for every request.
SAM = sam_model_registry[MODEL_TYPE](checkpoint=OFFICIAL_CHECKPOINT)
SAM.to(device=DEVICE)
SAM_PREDICTOR = SamPredictor(SAM)
## MEDSAM — same ViT-B architecture with different (presumably medical-domain
## fine-tuned) weights.
MEDSAM = sam_model_registry[MODEL_TYPE](checkpoint=MEDSAM_CHECKPOINT)
MEDSAM.to(device=DEVICE)
MEDSAM_PREDICTOR = SamPredictor(MEDSAM)
def load_model(model_choice: int) -> SamPredictor:
    """Map a dropdown index to its preloaded predictor (0 = SAM, 1 = MedSAM).

    Raises:
        ValueError: if ``model_choice`` is neither 0 nor 1.
    """
    print("model_choice", model_choice)
    # Guard clause: reject anything outside the two known indices up front.
    if model_choice not in (0, 1):
        raise ValueError("Model choice must be 0 or 1")
    # Both predictors were built at module import; just hand back a reference.
    return SAM_PREDICTOR if model_choice == 0 else MEDSAM_PREDICTOR
def draw_contour(image: np.ndarray, mask: np.ndarray) -> tuple[np.ndarray, tuple]:
    """Outline the external contours of ``mask`` on a copy of ``image``.

    Fix: the original return annotation claimed ``-> np.ndarray`` but the
    function returns a 2-tuple ``(annotated_image, contours)``.

    Args:
        image: Image array of shape (H, W, 3) — assumed uint8; TODO confirm.
        mask: Binary mask; cast to uint8 before contour extraction.

    Returns:
        Tuple of (copy of ``image`` with contours drawn, the contours
        returned by ``cv2.findContours``).
    """
    # Work on a copy so the caller's image is left untouched.
    contour_image = image.copy()
    # RETR_EXTERNAL keeps only outermost contours; CHAIN_APPROX_SIMPLE
    # compresses straight segments to their endpoints.
    contours, _ = cv2.findContours(
        np.uint8(mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    # NOTE(review): (0, 0, 255) is red in BGR but blue in RGB; gradio supplies
    # RGB arrays — confirm the intended outline color.
    cv2.drawContours(contour_image, contours, -1, (0, 0, 255), 3)
    return contour_image, contours
def inference(
    predictor: SamPredictor, image: np.ndarray, coord_y: int, coord_x: int
) -> np.ndarray:
    """Segment the object at a clicked point and outline it on the image.

    NOTE(review): despite the parameter names, callers pass gradio's
    ``evt.index`` (x, y) straight through, so the prompt sent to the
    predictor is ``[coord_y, coord_x]`` in that same order — verify this
    matches the predictor's expected point convention.

    Returns:
        A copy of ``image`` with the predicted mask's contour drawn on it.
    """
    predictor.set_image(image)
    # Single positive (foreground) click prompt.
    point = np.array([[coord_y, coord_x]])
    label = np.array([1])
    mask, _, _ = predictor.predict(
        point_coords=point,
        point_labels=label,
        multimask_output=False,
    )
    # Reshape the (..., H, W) mask to (H, W, 1) and scale to {0, 255} uint8
    # so OpenCV can extract contours from it.
    height, width = mask.shape[-2:]
    binary = (mask.reshape(height, width, 1) * 255).astype(np.uint8)
    annotated, _ = draw_contour(image, binary)
    return annotated
def extract_object_by_event(model_choice: int, image: np.ndarray, evt: gr.SelectData):
    """Gradio select-event handler: segment the object at the clicked pixel."""
    predictor = load_model(model_choice)
    # evt.index holds the click position; forward the pair to inference in
    # the same order it arrives.
    x, y = evt.index
    return inference(predictor, image, x, y)
def get_coords(evt: gr.SelectData):
    """Return the two click coordinates from a gradio select event."""
    first, second = evt.index[0], evt.index[1]
    return first, second
# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(
            """# Segment Anything!🚀
The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image.
More information can be found in [**Official Project**](https://segment-anything.com/).
"""
        )
    with gr.Row():
        # select model
        # type="index" makes the dropdown emit the selected position (0 or 1),
        # matching load_model's contract.
        model_choice = gr.Dropdown(
            label="Select Model",
            choices=[m for m in MODELS],
            type="index",
            interactive=True,
        )
    # Segment image
    with gr.Tab(label="SAM Inference"):
        # NOTE(review): Row().style(...) belongs to the gradio 3.x API and was
        # removed in 4.x — confirm the pinned gradio version.
        with gr.Row().style(equal_height=True):
            with gr.Column(label="Input Image"):
                # input image
                input_image = gr.Image(type="numpy")
            with gr.Column(label="Output"):
                # output: the input image with the predicted mask outlined
                output = gr.Image(type="numpy")
        with gr.Row():
            coord_h = gr.Number(label="Mouse coords h")
            coord_w = gr.Number(label="Mouse coords w")
    # Two handlers fire on the same click: one runs segmentation, the other
    # echoes the raw click coordinates into the number boxes.
    input_image.select(extract_object_by_event, [model_choice, input_image], output)
    input_image.select(get_coords, None, [coord_h, coord_w])

# NOTE(review): enable_queue in launch() is redundant once .queue() is called
# and was deprecated/removed in later gradio versions — verify compatibility.
demo.queue().launch(debug=True, enable_queue=True)