tbuyuktanir commited on
Commit
95d468b
·
verified ·
1 Parent(s): efb3863

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -0
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import SamModel, SamProcessor
3
+ from PIL import Image, ImageDraw
4
+ import torch
5
+ import numpy as np
6
+
7
+ # Load SAM
8
+ processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
9
+ model = SamModel.from_pretrained("facebook/sam-vit-huge")
10
+ model.eval()
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ model.to(device)
13
+
14
+ # Global variables
15
+ global_state = {
16
+ "image": None,
17
+ "clicks": [], # List of tuples: (x, y, label) where label=1 fg, 0 bg
18
+ "bbox": None, # (x0, y0, x1, y1)
19
+ }
20
+
21
+ # Helper to apply mask overlay
22
+ def apply_mask_overlay(image: Image.Image, mask: np.ndarray, color=(255, 0, 0)) -> Image.Image:
23
+ mask_img = Image.fromarray(mask.astype(np.uint8) * 255).convert("L")
24
+ color_mask = Image.new("RGB", image.size, color)
25
+ mask_rgb = Image.composite(color_mask, image, mask_img)
26
+ blended = Image.blend(image, mask_rgb, alpha=0.5)
27
+ return blended
28
+
29
+ # Set image
30
+ def upload_image(img):
31
+ global_state["image"] = img
32
+ global_state["clicks"] = []
33
+ global_state["bbox"] = None
34
+ return "Image uploaded. Now click or draw a box.", img
35
+
36
+ # Handle point clicks
37
+ def on_click(evt: gr.SelectData):
38
+ if global_state["image"] is None:
39
+ return gr.update(), "Please upload an image first."
40
+
41
+ # Default to foreground click
42
+ global_state["clicks"].append((evt.index[0], evt.index[1], 1))
43
+ return gr.update(), f"Point added: ({evt.index[0]}, {evt.index[1]})"
44
+
45
+ # Handle bounding box input
46
+ def set_bbox(x0: int, y0: int, x1: int, y1: int):
47
+ global_state["bbox"] = (x0, y0, x1, y1)
48
+ return f"Bounding box set: ({x0}, {y0}, {x1}, {y1})"
49
+
50
+ # Run segmentation
51
+ def run_segmentation():
52
+ if global_state["image"] is None:
53
+ return None, "Please upload an image first."
54
+
55
+ image = global_state["image"]
56
+ inputs = processor(image, return_tensors="pt").to(device)
57
+
58
+ if global_state["clicks"]:
59
+ points = torch.tensor([[[x, y] for (x, y, l) in global_state["clicks"]]], device=device)
60
+ labels = torch.tensor([[l for (_, _, l) in global_state["clicks"]]], device=device)
61
+ inputs.update({"input_points": points, "input_labels": labels})
62
+
63
+ if global_state["bbox"]:
64
+ x0, y0, x1, y1 = global_state["bbox"]
65
+ box = torch.tensor([[[x0, y0, x1, y1]]], device=device)
66
+ inputs.update({"input_boxes": box})
67
+
68
+ with torch.no_grad():
69
+ outputs = model(**inputs, multimask_output=False)
70
+
71
+ masks = processor.image_processor.post_process_masks(
72
+ outputs.pred_masks.cpu(),
73
+ inputs["original_sizes"].cpu(),
74
+ inputs["reshaped_input_sizes"].cpu()
75
+ )[0]
76
+
77
+ final_mask = masks[0].numpy()
78
+ overlayed = apply_mask_overlay(image.convert("RGB"), final_mask)
79
+
80
+ return overlayed, "Segmentation complete."
81
+
82
+ # Reset
83
+ def reset_all():
84
+ global_state["image"] = None
85
+ global_state["clicks"] = []
86
+ global_state["bbox"] = None
87
+ return None, None, "State reset."
88
+
89
+ # Gradio UI
90
+ with gr.Blocks() as demo:
91
+ gr.Markdown("## 🔬 Interactive Pathology Segmentation with SAM")
92
+
93
+ with gr.Row():
94
+ with gr.Column():
95
+ image_input = gr.Image(type="pil", label="Upload Pathology Image")
96
+ upload_status = gr.Textbox(label="Status")
97
+ upload_btn = gr.Button("Upload & Set Image")
98
+ click_output = gr.Textbox(label="Click Info")
99
+ run_btn = gr.Button("Run Segmentation")
100
+ reset_btn = gr.Button("Reset")
101
+ bbox_coords = gr.Textbox(label="Manual BBox (x0,y0,x1,y1)")
102
+ set_bbox_btn = gr.Button("Set Bounding Box")
103
+ with gr.Column():
104
+ image_output = gr.Image(type="pil", label="Segmentation Output")
105
+
106
+ # Handlers
107
+ upload_btn.click(upload_image, inputs=image_input, outputs=[upload_status, image_output])
108
+ image_input.select(on_click, outputs=[image_output, click_output])
109
+ run_btn.click(run_segmentation, outputs=[image_output, upload_status])
110
+ reset_btn.click(reset_all, outputs=[image_output, click_output, upload_status])
111
+ set_bbox_btn.click(
112
+ lambda coords: set_bbox(*map(int, coords.split(","))),
113
+ inputs=bbox_coords,
114
+ outputs=click_output
115
+ )
116
+
117
+ demo.launch()