Noursine commited on
Commit
44bc78d
·
verified ·
1 Parent(s): 25086f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py CHANGED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import gradio as gr
4
+ from segment_anything import sam_model_registry, SamPredictor
5
+ from PIL import Image
6
+
7
+ # ===== Load SAM2 model =====
8
+ CHECKPOINT_PATH = "sam2_huge.pth" # Place your checkpoint in the repo
9
+ MODEL_TYPE = "vit_h" # Adjust based on checkpoint
10
+ DEVICE = "cuda" # or "cpu"
11
+
12
+ sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
13
+ sam.to(DEVICE)
14
+ predictor = SamPredictor(sam)
15
+
16
+ # ===== State variables =====
17
+ points_state = []
18
+ labels_state = []
19
+ box_state = []
20
+
21
+ def set_image(image):
22
+ global points_state, labels_state, box_state
23
+ points_state, labels_state, box_state = [], [], []
24
+ predictor.set_image(image)
25
+ return image
26
+
27
+ def add_point(x, y, label):
28
+ points_state.append([x, y])
29
+ labels_state.append(label)
30
+ return run_prediction()
31
+
32
+ def set_box(x1, y1, x2, y2):
33
+ global box_state
34
+ box_state = [x1, y1, x2, y2]
35
+ return run_prediction()
36
+
37
+ def run_prediction():
38
+ points_np = np.array(points_state) if points_state else None
39
+ labels_np = np.array(labels_state) if labels_state else None
40
+ box_np = np.array(box_state) if len(box_state) == 4 else None
41
+
42
+ masks, _, _ = predictor.predict(
43
+ point_coords=points_np,
44
+ point_labels=labels_np,
45
+ box=box_np[None, :] if box_np is not None else None,
46
+ multimask_output=False
47
+ )
48
+
49
+ mask = masks[0]
50
+ overlay = overlay_mask(predictor.original_image, mask)
51
+ return overlay
52
+
53
+ def overlay_mask(image, mask, color=(0, 255, 0), alpha=0.5):
54
+ overlay = image.copy()
55
+ overlay[mask] = (overlay[mask] * (1 - alpha) + np.array(color) * alpha).astype(np.uint8)
56
+ return overlay
57
+
58
+ # ===== Gradio Interface =====
59
+ with gr.Blocks() as demo:
60
+ img_input = gr.Image(label="Upload Image", type="numpy")
61
+ img_output = gr.Image(label="Segmentation Output", type="numpy")
62
+
63
+ img_input.change(set_image, inputs=img_input, outputs=img_output)
64
+
65
+ # JS events will trigger these
66
+ gr.Button("Add Positive Point").click(
67
+ lambda: add_point(latest_click_x, latest_click_y, 1), outputs=img_output
68
+ )
69
+ gr.Button("Add Negative Point").click(
70
+ lambda: add_point(latest_click_x, latest_click_y, 0), outputs=img_output
71
+ )
72
+
73
+ demo.launch()