"""Interactive SAM segmentation demo.

Upload an image, then click on it to add positive/negative point prompts
(or set a box programmatically); each prompt re-runs SAM and shows the
mask blended over the image.
"""

import numpy as np
import cv2  # NOTE(review): unused in this file — kept in case other code relies on it
import gradio as gr
from segment_anything import sam_model_registry, SamPredictor
from PIL import Image  # NOTE(review): unused in this file

# ===== Load SAM model =====
# NOTE(review): checkpoint is named "sam2_huge.pth" but `segment_anything`'s
# registry is the SAM v1 API — MODEL_TYPE must match the checkpoint's
# architecture ("vit_h" for the huge ViT backbone). Confirm checkpoint/model pair.
CHECKPOINT_PATH = "sam2_huge.pth"  # Place your checkpoint in the repo
MODEL_TYPE = "vit_h"  # Adjust based on checkpoint
DEVICE = "cuda"  # or "cpu"

sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
sam.to(DEVICE)
predictor = SamPredictor(sam)

# ===== Mutable demo state (single-session; not safe for concurrent users) =====
points_state = []   # clicked [x, y] pixel coordinates
labels_state = []   # 1 = positive (foreground), 0 = negative (background)
box_state = []      # [x1, y1, x2, y2], or [] when no box prompt is set
# Fix: SamPredictor does not expose the raw image (`predictor.original_image`
# does not exist), so keep our own reference to the current image here.
current_image = None


def set_image(image):
    """Register a new image: reset all prompts and embed the image once."""
    global points_state, labels_state, box_state, current_image
    points_state, labels_state, box_state = [], [], []
    current_image = image
    # Fix: Gradio passes None when the image is cleared; set_image(None) would crash.
    if image is not None:
        predictor.set_image(image)
    return image


def add_point(x, y, label):
    """Append one point prompt (label: 1 positive / 0 negative) and re-segment."""
    points_state.append([x, y])
    labels_state.append(label)
    return run_prediction()


def set_box(x1, y1, x2, y2):
    """Replace the single box prompt with (x1, y1, x2, y2) and re-segment."""
    global box_state
    box_state = [x1, y1, x2, y2]
    return run_prediction()


def run_prediction():
    """Run SAM with the accumulated prompts; return the overlay image (or None)."""
    if current_image is None:
        return None
    # Fix: predict() cannot be called with no prompts at all — show the raw image.
    if not points_state and len(box_state) != 4:
        return current_image

    points_np = np.array(points_state) if points_state else None
    labels_np = np.array(labels_state) if labels_state else None
    box_np = np.array(box_state) if len(box_state) == 4 else None

    masks, _, _ = predictor.predict(
        point_coords=points_np,
        point_labels=labels_np,
        box=box_np[None, :] if box_np is not None else None,
        multimask_output=False,
    )
    # multimask_output=False -> masks has shape (1, H, W); take the single mask.
    return overlay_mask(current_image, masks[0])


def overlay_mask(image, mask, color=(0, 255, 0), alpha=0.5):
    """Alpha-blend `color` over `image` wherever `mask` is True.

    `image` is an HxWx3 uint8 array; `mask` an HxW boolean (or 0/1) array.
    Returns a new array; the input image is not modified.
    """
    mask = mask.astype(bool)  # robustness: accept 0/1 integer masks too
    overlay = image.copy()
    overlay[mask] = (
        overlay[mask] * (1 - alpha) + np.array(color) * alpha
    ).astype(np.uint8)
    return overlay


def _on_image_click(label_choice, evt: gr.SelectData):
    """Translate a Gradio image click into a point prompt.

    Fix: the original wired buttons to lambdas referencing undefined globals
    `latest_click_x` / `latest_click_y` (NameError on every click). Gradio
    delivers click coordinates via the image's `.select` event as
    `SelectData.index == (x, y)`.
    """
    x, y = evt.index[0], evt.index[1]
    return add_point(x, y, 1 if label_choice == "Positive" else 0)


# ===== Gradio Interface =====
with gr.Blocks() as demo:
    img_input = gr.Image(label="Upload Image (click to add points)", type="numpy")
    point_label = gr.Radio(
        ["Positive", "Negative"], value="Positive", label="Point type"
    )
    img_output = gr.Image(label="Segmentation Output", type="numpy")

    img_input.change(set_image, inputs=img_input, outputs=img_output)
    # Clicks on the image carry their own coordinates via SelectData.
    img_input.select(_on_image_click, inputs=point_label, outputs=img_output)

demo.launch()