tbuyuktanir commited on
Commit
fadee54
·
verified ·
1 Parent(s): 74641c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -116
app.py CHANGED
@@ -1,117 +1,126 @@
1
- import gradio as gr
2
- from transformers import SamModel, SamProcessor
3
- from PIL import Image, ImageDraw
4
- import torch
5
- import numpy as np
6
-
7
- # Load SAM
8
- processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
9
- model = SamModel.from_pretrained("facebook/sam-vit-huge")
10
- model.eval()
11
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
- model.to(device)
13
-
14
- # Global variables
15
- global_state = {
16
- "image": None,
17
- "clicks": [], # List of tuples: (x, y, label) where label=1 fg, 0 bg
18
- "bbox": None, # (x0, y0, x1, y1)
19
- }
20
-
21
- # Helper to apply mask overlay
22
- def apply_mask_overlay(image: Image.Image, mask: np.ndarray, color=(255, 0, 0)) -> Image.Image:
23
- mask_img = Image.fromarray(mask.astype(np.uint8) * 255).convert("L")
24
- color_mask = Image.new("RGB", image.size, color)
25
- mask_rgb = Image.composite(color_mask, image, mask_img)
26
- blended = Image.blend(image, mask_rgb, alpha=0.5)
27
- return blended
28
-
29
- # Set image
30
- def upload_image(img):
31
- global_state["image"] = img
32
- global_state["clicks"] = []
33
- global_state["bbox"] = None
34
- return "Image uploaded. Now click or draw a box.", img
35
-
36
- # Handle point clicks
37
- def on_click(evt: gr.SelectData):
38
- if global_state["image"] is None:
39
- return gr.update(), "Please upload an image first."
40
-
41
- # Default to foreground click
42
- global_state["clicks"].append((evt.index[0], evt.index[1], 1))
43
- return gr.update(), f"Point added: ({evt.index[0]}, {evt.index[1]})"
44
-
45
- # Handle bounding box input
46
- def set_bbox(x0: int, y0: int, x1: int, y1: int):
47
- global_state["bbox"] = (x0, y0, x1, y1)
48
- return f"Bounding box set: ({x0}, {y0}, {x1}, {y1})"
49
-
50
- # Run segmentation
51
- def run_segmentation():
52
- if global_state["image"] is None:
53
- return None, "Please upload an image first."
54
-
55
- image = global_state["image"]
56
- inputs = processor(image, return_tensors="pt").to(device)
57
-
58
- if global_state["clicks"]:
59
- points = torch.tensor([[[x, y] for (x, y, l) in global_state["clicks"]]], device=device)
60
- labels = torch.tensor([[l for (_, _, l) in global_state["clicks"]]], device=device)
61
- inputs.update({"input_points": points, "input_labels": labels})
62
-
63
- if global_state["bbox"]:
64
- x0, y0, x1, y1 = global_state["bbox"]
65
- box = torch.tensor([[[x0, y0, x1, y1]]], device=device)
66
- inputs.update({"input_boxes": box})
67
-
68
- with torch.no_grad():
69
- outputs = model(**inputs, multimask_output=False)
70
-
71
- masks = processor.image_processor.post_process_masks(
72
- outputs.pred_masks.cpu(),
73
- inputs["original_sizes"].cpu(),
74
- inputs["reshaped_input_sizes"].cpu()
75
- )[0]
76
-
77
- final_mask = masks[0].numpy()
78
- overlayed = apply_mask_overlay(image.convert("RGB"), final_mask)
79
-
80
- return overlayed, "Segmentation complete."
81
-
82
- # Reset
83
- def reset_all():
84
- global_state["image"] = None
85
- global_state["clicks"] = []
86
- global_state["bbox"] = None
87
- return None, None, "State reset."
88
-
89
- # Gradio UI
90
- with gr.Blocks() as demo:
91
- gr.Markdown("## 🔬 Interactive Pathology Segmentation with SAM")
92
-
93
- with gr.Row():
94
- with gr.Column():
95
- image_input = gr.Image(type="pil", label="Upload Pathology Image")
96
- upload_status = gr.Textbox(label="Status")
97
- upload_btn = gr.Button("Upload & Set Image")
98
- click_output = gr.Textbox(label="Click Info")
99
- run_btn = gr.Button("Run Segmentation")
100
- reset_btn = gr.Button("Reset")
101
- bbox_coords = gr.Textbox(label="Manual BBox (x0,y0,x1,y1)")
102
- set_bbox_btn = gr.Button("Set Bounding Box")
103
- with gr.Column():
104
- image_output = gr.Image(type="pil", label="Segmentation Output")
105
-
106
- # Handlers
107
- upload_btn.click(upload_image, inputs=image_input, outputs=[upload_status, image_output])
108
- image_input.select(on_click, outputs=[image_output, click_output])
109
- run_btn.click(run_segmentation, outputs=[image_output, upload_status])
110
- reset_btn.click(reset_all, outputs=[image_output, click_output, upload_status])
111
- set_bbox_btn.click(
112
- lambda coords: set_bbox(*map(int, coords.split(","))),
113
- inputs=bbox_coords,
114
- outputs=click_output
115
- )
116
-
 
 
 
 
 
 
 
 
 
117
  demo.launch()
 
1
+ import gradio as gr
2
+ from transformers import SamModel, SamProcessor
3
+ from PIL import Image, ImageDraw
4
+ import torch
5
+ import numpy as np
6
+
7
+ # Load SAM
8
+ processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
9
+ model = SamModel.from_pretrained("facebook/sam-vit-huge")
10
+ model.eval()
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ model.to(device)
13
+
14
+ # Global variables
15
+ global_state = {
16
+ "image": None,
17
+ "clicks": [], # List of tuples: (x, y, label) where label=1 fg, 0 bg
18
+ "bbox": None, # (x0, y0, x1, y1)
19
+ }
20
+
21
+ # Helper to apply mask overlay
22
+ def apply_mask_overlay(image: Image.Image, mask: np.ndarray, color=(255, 0, 0)) -> Image.Image:
23
+ mask_img = Image.fromarray(mask.astype(np.uint8) * 255).convert("L")
24
+ color_mask = Image.new("RGB", image.size, color)
25
+ mask_rgb = Image.composite(color_mask, image, mask_img)
26
+ blended = Image.blend(image, mask_rgb, alpha=0.5)
27
+ return blended
28
+
29
+ # Set image
30
+ def upload_image(img):
31
+ global_state["image"] = img
32
+ global_state["clicks"] = []
33
+ global_state["bbox"] = None
34
+ return "Image uploaded. Now click or draw a box.", img
35
+
36
+ # Handle point clicks
37
+ def on_click(evt: gr.SelectData):
38
+ if global_state["image"] is None:
39
+ return gr.update(), "Please upload an image first."
40
+
41
+ # Default to foreground click
42
+ global_state["clicks"].append((evt.index[0], evt.index[1], 1))
43
+ return gr.update(), f"Point added: ({evt.index[0]}, {evt.index[1]})"
44
+
45
+ # Handle bounding box input
46
+ def set_bbox(x0: int, y0: int, x1: int, y1: int):
47
+ global_state["bbox"] = (x0, y0, x1, y1)
48
+ return f"Bounding box set: ({x0}, {y0}, {x1}, {y1})"
49
+
50
+ # Run segmentation
51
+ def run_segmentation():
52
+ if global_state["image"] is None:
53
+ return None, "Please upload an image first."
54
+
55
+ image = global_state["image"]
56
+ inputs = processor(image, return_tensors="pt").to(device)
57
+
58
+ if global_state["clicks"]:
59
+ points = torch.tensor([[[x, y] for (x, y, l) in global_state["clicks"]]], device=device)
60
+ labels = torch.tensor([[l for (_, _, l) in global_state["clicks"]]], device=device)
61
+ inputs.update({"input_points": points, "input_labels": labels})
62
+
63
+ if global_state["bbox"]:
64
+ x0, y0, x1, y1 = global_state["bbox"]
65
+ box = torch.tensor([[[x0, y0, x1, y1]]], device=device)
66
+ inputs.update({"input_boxes": box})
67
+
68
+ with torch.no_grad():
69
+ outputs = model(**inputs, multimask_output=False)
70
+
71
+ masks = processor.image_processor.post_process_masks(
72
+ outputs.pred_masks.cpu(),
73
+ inputs["original_sizes"].cpu(),
74
+ inputs["reshaped_input_sizes"].cpu()
75
+ )[0]
76
+
77
+ final_mask = masks[0].numpy()
78
+ overlayed = apply_mask_overlay(image.convert("RGB"), final_mask)
79
+
80
+ return overlayed, "Segmentation complete."
81
+
82
+ # Reset
83
+ def reset_all():
84
+ global_state["image"] = None
85
+ global_state["clicks"] = []
86
+ global_state["bbox"] = None
87
+ return None, None, "State reset."
88
+
89
+ # Gradio UI
90
+ with gr.Blocks() as demo:
91
+ gr.Markdown("Interactive Pathology Segmentation with SAM")
92
+
93
+ with gr.Row():
94
+ with gr.Column():
95
+ image_input = gr.Image(type="pil", label="Upload Pathology Image")
96
+ upload_status = gr.Textbox(label="Status")
97
+ upload_btn = gr.Button("Upload & Set Image")
98
+ click_output = gr.Textbox(label="Click Info")
99
+ run_btn = gr.Button("Run Segmentation")
100
+ reset_btn = gr.Button("Reset")
101
+ bbox_coords = gr.Textbox(label="Manual BBox (x0,y0,x1,y1)")
102
+ set_bbox_btn = gr.Button("Set Bounding Box")
103
+ examples = gr.Examples(
104
+ examples=[
105
+ ["https://www.webpathology.com/_next/image?url=https%3A%2F%2Fd3cyex60hhnlth.cloudfront.net%2Ffit-in%2F650x650%2Ffilters%3Aformat(webp)%2Fcase%2Fdetail_images%2Fc354_detail.jpg&w=750&q=75"],
106
+ ["https://www.webpathology.com/_next/image?url=https%3A%2F%2Fd3cyex60hhnlth.cloudfront.net%2Ffit-in%2F650x650%2Ffilters%3Aformat(webp)%2Fcase%2Fdetail_images%2Fc354_detail.jpg&w=750&q=75"],
107
+ ["https://www.webpathology.com/_next/image?url=https%3A%2F%2Fd3cyex60hhnlth.cloudfront.net%2Ffit-in%2F650x650%2Ffilters%3Aformat(webp)%2Fcase%2Fdetail_images%2Fc354_detail.jpg&w=750&q=75"]
108
+ ],
109
+ inputs=[image_input],
110
+ label="Example Pathology Images"
111
+ )
112
+ with gr.Column():
113
+ image_output = gr.Image(type="pil", label="Segmentation Output")
114
+
115
+ # Handlers
116
+ upload_btn.click(upload_image, inputs=image_input, outputs=[upload_status, image_output])
117
+ image_input.select(on_click, outputs=[image_output, click_output])
118
+ run_btn.click(run_segmentation, outputs=[image_output, upload_status])
119
+ reset_btn.click(reset_all, outputs=[image_output, click_output, upload_status])
120
+ set_bbox_btn.click(
121
+ lambda coords: set_bbox(*map(int, coords.split(","))),
122
+ inputs=bbox_coords,
123
+ outputs=click_output
124
+ )
125
+
126
  demo.launch()