Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import gradio as gr | |
| import torch | |
| import cv2 | |
| from segment_anything import SamPredictor, sam_model_registry | |
# Global variables
OFFICIAL_CHECKPOINT = "./models/sam_vit_b_01ec64.pth"
MEDSAM_CHECKPOINT = "./models/medsam_vitb_best.pth"
# Dropdown choices, in the order load_model indexes them (0 = SAM, 1 = MedSAM).
# Built from the checkpoint constants so the label shown in the UI is always the
# file that is actually loaded.
# NOTE(review): this list used to say "./models/medsam_vitb.pth" while the model
# was loaded from MEDSAM_CHECKPOINT; confirm which file really ships in ./models.
MODELS = [OFFICIAL_CHECKPOINT, MEDSAM_CHECKPOINT]
MODEL_TYPE = "vit_b"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Model
def _build_predictor(checkpoint: str):
    """Build a ``MODEL_TYPE`` SAM model from ``checkpoint`` and wrap it in a predictor.

    The weights are loaded here with ``map_location="cpu"`` instead of being
    passed as ``checkpoint=`` to the registry builder. A checkpoint saved from a
    GPU can therefore be loaded on a CPU-only host. The model is then moved to
    ``DEVICE``.
    NOTE(review): upstream segment_anything's builder calls ``torch.load``
    without ``map_location``. Confirm for the pinned version.

    Args:
        checkpoint: Path to a state-dict file for the ``MODEL_TYPE`` architecture.

    Returns:
        ``(model, predictor)``: the model on ``DEVICE`` and a ``SamPredictor`` over it.
    """
    model = sam_model_registry[MODEL_TYPE](checkpoint=None)
    with open(checkpoint, "rb") as f:
        state_dict = torch.load(f, map_location="cpu")
    model.load_state_dict(state_dict)
    model.to(device=DEVICE)
    return model, SamPredictor(model)


## OFFICIAL SAM
SAM, SAM_PREDICTOR = _build_predictor(OFFICIAL_CHECKPOINT)
## MEDSAM
MEDSAM, MEDSAM_PREDICTOR = _build_predictor(MEDSAM_CHECKPOINT)
def load_model(model_choice: int) -> SamPredictor:
    """Return the preloaded predictor for a dropdown index.

    Args:
        model_choice: 0 selects the official SAM predictor, 1 selects MedSAM.

    Raises:
        ValueError: if ``model_choice`` is neither 0 nor 1.
    """
    print("model_choice", model_choice)
    # Reject anything that is not a known index before choosing a predictor.
    if model_choice not in (0, 1):
        raise ValueError("Model choice must be 0 or 1")
    return MEDSAM_PREDICTOR if model_choice == 1 else SAM_PREDICTOR
def draw_contour(image: np.ndarray, mask: np.ndarray) -> tuple:
    """Draw the outer contours of ``mask`` on a copy of ``image``.

    Args:
        image: Image to annotate. It is copied and never modified in place.
        mask: Mask whose non-zero pixels are foreground. A trailing singleton
            channel axis, as in ``(H, W, 1)``, is accepted.

    Returns:
        ``(contour_image, contours)``: the annotated copy and the external
        contours found by ``cv2.findContours``.
    """
    contour_image = image.copy()
    # findContours needs a single-channel 8-bit image, so drop a trailing
    # length-1 channel axis if one is present.
    mask_u8 = np.uint8(mask)
    if mask_u8.ndim == 3 and mask_u8.shape[-1] == 1:
        mask_u8 = mask_u8[..., 0]
    # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
    # (contours, hierarchy). Index -2 selects the contours in both.
    contours = cv2.findContours(
        mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )[-2]
    # Colour (0, 0, 255) is blue if ``image`` is RGB and red if it is BGR.
    # Line thickness is 3 px.
    cv2.drawContours(contour_image, contours, -1, (0, 0, 255), 3)
    return contour_image, contours
def inference(
    predictor: SamPredictor, image: np.ndarray, coord_y: int, coord_x: int
) -> np.ndarray:
    """Segment the object under one prompt point and outline it.

    Args:
        predictor: SAM predictor that embeds ``image`` and predicts a single mask.
        image: Image passed to ``predictor.set_image``.
        coord_y: First component of the prompt point.
        coord_x: Second component of the prompt point.

    Returns:
        A copy of ``image`` with the predicted mask's outer contour drawn on it.
    """
    # NOTE(review): the point is sent to SAM as [coord_y, coord_x].
    # SamPredictor documents point_coords in (X, Y) order, and the caller
    # passes gradio's click index in order. That index appears to be [x, y],
    # so the parameter names look swapped relative to what they carry.
    # Confirm before renaming.
    prompt_points = np.array([[coord_y, coord_x]])
    prompt_labels = np.array([1])  # per SamPredictor docs, 1 = foreground point

    predictor.set_image(image)
    masks, _scores, _logits = predictor.predict(
        point_coords=prompt_points,
        point_labels=prompt_labels,
        multimask_output=False,
    )

    height, width = masks.shape[-2:]
    binary_mask = (masks.reshape(height, width, 1) * 255).astype(np.uint8)
    outlined, _found = draw_contour(image, binary_mask)
    return outlined
def extract_object_by_event(model_choice: int, image: np.ndarray, evt: gr.SelectData):
    """Segment the clicked object with the model selected in the dropdown.

    The two components of ``evt.index`` are forwarded, in order, as the
    prompt point for ``inference``.
    """
    # Resolve the predictor first, so an invalid choice fails before the
    # click is unpacked.
    predictor = load_model(model_choice)
    # NOTE(review): gradio appears to report image clicks as [x, y]
    # (column, row). Verify for the installed version.
    click_x, click_y = evt.index
    return inference(predictor, image, click_x, click_y)
def get_coords(evt: gr.SelectData):
    """Return the first two components of a gradio select event's index as a pair."""
    index = evt.index
    return index[0], index[1]
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(
            """# Segment Anything!🚀
The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image.
More information can be found in [**Official Project**](https://segment-anything.com/).
"""
        )
    with gr.Row():
        # Select the model. With type="index" the handler receives 0 or 1, which
        # is what load_model expects. Without a default, a click before any
        # selection would pass None and load_model would raise.
        model_choice = gr.Dropdown(
            label="Select Model",
            choices=list(MODELS),
            value=MODELS[0],
            type="index",
            interactive=True,
        )
    # Segment image
    with gr.Tab(label="SAM Inference"):
        # Columns take no `label` argument (a TypeError in Gradio 4), so the
        # labels live on the images. `Row.style()` was removed in Gradio 4, so
        # equal_height is passed to the constructor instead.
        with gr.Row(equal_height=True):
            with gr.Column():
                input_image = gr.Image(type="numpy", label="Input Image")
            with gr.Column():
                output = gr.Image(type="numpy", label="Output")
        with gr.Row():
            coord_h = gr.Number(label="Mouse coords h")
            coord_w = gr.Number(label="Mouse coords w")
    input_image.select(extract_object_by_event, [model_choice, input_image], output)
    # get_coords returns (index[0], index[1]). Gradio image clicks appear to be
    # [x, y] (x is along the width), so x goes to the "w" field and y to the
    # "h" field. NOTE(review): confirm the index order for the installed
    # Gradio version.
    input_image.select(get_coords, None, [coord_w, coord_h])
# .queue() already enables queuing. launch(enable_queue=...) is redundant here
# and was removed in Gradio 4.
demo.queue().launch(debug=True)