# NOTE: removed web-scrape artifacts that preceded the source (a file-size
# banner, a commit hash, and a copied line-number gutter) — they were not
# part of the program and the first two lines were not valid Python.
import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor

# ===== Load SAM model =====
# NOTE(review): the checkpoint name says "sam2", but `segment_anything`
# (sam_model_registry / SamPredictor) is the original SAM API — confirm the
# checkpoint actually matches the "vit_h" SAM architecture, or the load fails.
CHECKPOINT_PATH = "sam2_huge.pth"  # Place your checkpoint in the repo
MODEL_TYPE = "vit_h"  # Registry key; must match the checkpoint's architecture
# Fix: was hard-coded to "cuda", which crashes on CPU-only machines.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
sam.to(DEVICE)
predictor = SamPredictor(sam)  # wraps the model; callbacks below use this

# ===== State variables =====
# Prompt state shared by the Gradio callbacks. Module-level globals make this
# a single-session demo only — not safe for concurrent users.
points_state = []   # clicked point prompts, each [x, y] in image pixels
labels_state = []   # per-point labels: 1 = positive (fg), 0 = negative (bg)
box_state = []      # [x1, y1, x2, y2] box prompt; empty list when unset
image_state = None  # most recently uploaded image (numpy array), or None

def set_image(image):
    """Reset all prompt state, embed a newly uploaded image, and echo it back.

    Keeps a module-level copy of the image because SamPredictor does not
    retain the original pixels, only the computed embedding.
    """
    global points_state, labels_state, box_state, image_state
    points_state, labels_state, box_state = [], [], []
    image_state = image
    predictor.set_image(image)
    return image

def add_point(x, y, label):
    """Append one point prompt (label: 1 positive / 0 negative), re-segment."""
    points_state.append([x, y])
    labels_state.append(label)
    return run_prediction()

def set_box(x1, y1, x2, y2):
    """Replace the box prompt with (x1, y1, x2, y2) and re-segment."""
    global box_state
    box_state = [x1, y1, x2, y2]
    return run_prediction()

def run_prediction():
    """Run SAM with the accumulated prompts; return the green-mask overlay.

    Returns None when no image has been uploaded yet, and the raw image when
    no prompts exist (predict() requires at least one point or box).
    """
    if image_state is None:
        return None

    points_np = np.array(points_state) if points_state else None
    labels_np = np.array(labels_state) if labels_state else None
    box_np = np.array(box_state) if len(box_state) == 4 else None

    # Calling predict() with every prompt set to None raises — bail out early.
    if points_np is None and box_np is None:
        return image_state

    masks, _, _ = predictor.predict(
        point_coords=points_np,
        point_labels=labels_np,
        box=box_np[None, :] if box_np is not None else None,  # (1, 4) batch dim
        multimask_output=False,
    )

    # Fix: the original read `predictor.original_image`, an attribute
    # SamPredictor does not expose — use our own stored copy instead.
    return overlay_mask(image_state, masks[0])

def overlay_mask(image, mask, color=(0, 255, 0), alpha=0.5):
    """Blend `color` into `image` wherever `mask` is True.

    Non-destructive: works on a copy of `image` and returns it as uint8.
    `alpha` is the weight of the tint (0 = untouched, 1 = solid color).
    """
    blended = image.copy()
    tint = np.asarray(color, dtype=float) * alpha
    blended[mask] = (blended[mask] * (1.0 - alpha) + tint).astype(np.uint8)
    return blended

# ===== Gradio Interface =====
with gr.Blocks() as demo:
    img_input = gr.Image(label="Upload Image", type="numpy")
    img_output = gr.Image(label="Segmentation Output", type="numpy")
    # Which label the next click contributes (SAM: 1 = foreground, 0 = background).
    point_mode = gr.Radio(
        choices=["Positive", "Negative"], value="Positive", label="Point type"
    )

    # Uploading a new image resets all prompt state.
    img_input.change(set_image, inputs=img_input, outputs=img_output)

    def _on_image_click(mode, evt: gr.SelectData):
        """Turn a click on the uploaded image into a SAM point prompt.

        Fix: the original buttons referenced undefined globals
        `latest_click_x` / `latest_click_y` (NameError on click). Gradio
        delivers click coordinates through the image `select` event's
        SelectData payload (`evt.index` is the (x, y) pixel position).
        """
        x, y = evt.index
        return add_point(x, y, 1 if mode == "Positive" else 0)

    img_input.select(_on_image_click, inputs=point_mode, outputs=img_output)

if __name__ == "__main__":
    demo.launch()