# Page residue from the Hugging Face file viewer, kept as a comment so the
# module parses: author KurtLin, commit 712d80d ("Initial Commit"), 1.35 kB.
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch
from segment_anything import sam_model_registry, SamPredictor
from preprocess import show_mask, show_points, show_box
import gradio as gr
def get_coord_infer(evt: gr.SelectData):
    """Report where the user clicked on the image, as a two-element list.

    The handler is wired with no input components, so ``evt`` is supplied by
    Gradio through the ``gr.SelectData`` annotation; keep that annotation.
    For an image select event, Gradio reports the clicked position in
    ``evt.index``.

    Returns:
        A new list holding the first two entries of ``evt.index``.
    """
    position = evt.index
    horizontal, vertical = position[0], position[1]
    return [horizontal, vertical]
# sam_checkpoint = "weights/sam_vit_b_01ec64.pth"
# model_type = "vit_b"
# device = "cuda" if torch.cuda.is_available() else "cpu"
# sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# sam.to(device=device)
# predictor = SamPredictor(sam)
# Build the demo UI: the left column shows the source image and the clicked
# coordinate, the right column is reserved for the predicted mask.
# NOTE(review): the SAM predictor set-up above is commented out, so nothing
# writes to `img_output` yet — TODO wire the mask prediction in.
my_app = gr.Blocks()
with my_app:
    gr.Markdown("Segment Anything Testing")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Row():
                with gr.Column():
                    # Relative path: resolved against the directory the app
                    # is launched from.
                    img_source = gr.Image(
                        label="Please select picture.",
                        value='./images/truck.jpg',
                        shape=(768, 768),
                    )
                    # get_coord_infer returns a two-element list. gr.Label only
                    # accepts a str/int/float or a {label: confidence} dict and
                    # raises ValueError on a list, so the coordinate is shown in
                    # a Textbox, which renders the value with str().
                    coords = gr.Textbox(label="Image Coordinate.")
                with gr.Column():
                    img_output = gr.Image(label="Output Mask")
            # No input components: Gradio passes the click event to the
            # handler through its gr.SelectData-annotated parameter.
            # Listeners must be registered inside the Blocks context.
            img_source.select(get_coord_infer, [], coords)
            # Earlier draft wiring via a `set_point` button, which is not
            # defined anywhere in this file:
            # set_point.click(
            #     img_source.select(get_coord),
            #     [
            #         img_source
            #     ],
            #     [
            #         coords
            #     ]
            # )
my_app.launch(debug=True)