# Noursine's picture
# Update app.py
# 44bc78d verified
import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor
# ===== Load SAM model =====
# NOTE(review): the checkpoint file is named "sam2_*" but it is loaded through
# the `segment_anything` registry -- confirm the checkpoint matches MODEL_TYPE.
CHECKPOINT_PATH = "sam2_huge.pth"  # Place your checkpoint in the repo
MODEL_TYPE = "vit_h"  # Adjust based on checkpoint
# Fall back to the CPU when no GPU is visible so the app still starts on
# CPU-only hosts instead of failing in `sam.to("cuda")`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
sam.to(DEVICE)
predictor = SamPredictor(sam)
# ===== State variables =====
# Prompts collected for the current image: clicked points, one label per
# point, and the box coordinates. Each starts out as its own empty list.
points_state, labels_state, box_state = [], [], []
def set_image(image):
    """Reset all prompts and make ``image`` the predictor's current image.

    Args:
        image: Image array from the upload widget, or ``None`` when the
            widget has been cleared.

    Returns:
        ``image`` unchanged so it can be shown in the output widget, or
        ``None`` when no image was supplied.
    """
    global points_state, labels_state, box_state
    points_state, labels_state, box_state = [], [], []
    if image is None:
        # Gradio passes None when the upload is cleared; there is nothing to
        # embed, and handing None to the predictor would raise.
        return None
    predictor.set_image(image)
    # SamPredictor does not keep the input pixels, so attach them to the
    # predictor; run_prediction() reads `predictor.original_image` to draw
    # the overlay.
    predictor.original_image = image
    return image
def add_point(x, y, label):
    """Store a point prompt at ``(x, y)`` with ``label`` and re-run prediction.

    Returns:
        Whatever ``run_prediction()`` returns for the updated prompts.
    """
    coordinate = [x, y]
    points_state.append(coordinate)
    labels_state.append(label)
    refreshed = run_prediction()
    return refreshed
def set_box(x1, y1, x2, y2):
    """Replace the stored box prompt with the given values and re-run prediction.

    Returns:
        Whatever ``run_prediction()`` returns for the updated prompts.
    """
    global box_state
    corners = [x1, y1, x2, y2]
    box_state = corners
    refreshed = run_prediction()
    return refreshed
def run_prediction():
    """Predict a mask from the stored prompts and blend it over the image.

    Returns:
        The image with the predicted mask overlaid, or ``None`` when there is
        no image to predict on.
    """
    # SamPredictor does not retain the input pixels itself, so the image is
    # only available if it was attached to the predictor as `original_image`.
    source_image = getattr(predictor, "original_image", None)
    if not predictor.is_image_set or source_image is None:
        # Predicting before an image is set raises inside the predictor;
        # return an empty result instead of crashing the UI callback.
        return None

    points_np = np.array(points_state) if points_state else None
    labels_np = np.array(labels_state) if labels_state else None
    box_np = np.array(box_state) if len(box_state) == 4 else None
    masks, _, _ = predictor.predict(
        point_coords=points_np,
        point_labels=labels_np,
        box=box_np[None, :] if box_np is not None else None,
        multimask_output=False,
    )
    # multimask_output=False was requested, so only the first mask is used.
    return overlay_mask(source_image, masks[0])
def overlay_mask(image, mask, color=(0, 255, 0), alpha=0.5):
    """Blend ``color`` into the pixels of ``image`` selected by ``mask``.

    Args:
        image: Image array; its last axis must broadcast against ``color``.
            It is copied, never modified in place.
        mask: Array selecting the pixels to tint. Any non-zero entry counts
            as selected, so both boolean and 0/1 integer masks work.
        color: Per-channel colour blended into the selected pixels.
        alpha: Weight of ``color`` in the blend; the original pixel keeps
            ``1 - alpha``.

    Returns:
        A copy of ``image`` with the selected pixels blended and cast to
        ``uint8``.
    """
    # Coerce to bool: an integer mask would otherwise be treated as an array
    # of row indices rather than a per-pixel selection.
    selected = np.asarray(mask, dtype=bool)
    overlay = image.copy()
    blended = overlay[selected] * (1 - alpha) + np.array(color) * alpha
    overlay[selected] = blended.astype(np.uint8)
    return overlay
# ===== Gradio Interface =====
# Coordinates of the most recent click on the input image; None until the
# user has clicked.
latest_click_x = None
latest_click_y = None


def _record_click(evt: gr.SelectData):
    """Remember where the user last clicked on the input image."""
    global latest_click_x, latest_click_y
    # For an Image component the select event's index is the (x, y) position.
    latest_click_x, latest_click_y = evt.index[0], evt.index[1]


def _add_latest_click(label):
    """Add the last recorded click as a point prompt with ``label``.

    Raises:
        gr.Error: If the user has not clicked on the image yet.
    """
    if latest_click_x is None or latest_click_y is None:
        raise gr.Error("Click on the uploaded image to choose a point first.")
    return add_point(latest_click_x, latest_click_y, label)


with gr.Blocks() as demo:
    img_input = gr.Image(label="Upload Image", type="numpy")
    img_output = gr.Image(label="Segmentation Output", type="numpy")
    img_input.change(set_image, inputs=img_input, outputs=img_output)
    # Clicking the input image records the coordinates the buttons use.
    img_input.select(_record_click)
    gr.Button("Add Positive Point").click(
        lambda: _add_latest_click(1), outputs=img_output
    )
    gr.Button("Add Negative Point").click(
        lambda: _add_latest_click(0), outputs=img_output
    )
demo.launch()