AnikS22 commited on
Commit
f2f04dc
·
verified ·
1 Parent(s): c81bc0d

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +309 -0
app.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MidasMap — Immunogold Particle Detection Dashboard
3
+
4
+ Upload a TEM image, get instant particle detections with heatmaps,
5
+ counts, confidence distributions, and exportable CSV results.
6
+
7
+ Usage:
8
+ python app.py
9
+ python app.py --checkpoint checkpoints/final/final_model.pth
10
+ python app.py --share # public link
11
+ """
12
+
13
+ import argparse
14
+ import io
15
+ import tempfile
16
+ from pathlib import Path
17
+
18
+ import gradio as gr
19
+ import matplotlib
20
+ matplotlib.use("Agg")
21
+ import matplotlib.pyplot as plt
22
+ import numpy as np
23
+ import pandas as pd
24
+ import torch
25
+ import tifffile
26
+
27
+ from src.ensemble import sliding_window_inference
28
+ from src.heatmap import extract_peaks
29
+ from src.model import ImmunogoldCenterNet
30
+ from src.postprocess import cross_class_nms
31
+
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # Global model (loaded once at startup)
35
+ # ---------------------------------------------------------------------------
36
+ MODEL = None
37
+ DEVICE = None
38
+
39
+
40
+ def load_model(checkpoint_path: str):
41
+ global MODEL, DEVICE
42
+ DEVICE = torch.device(
43
+ "cuda" if torch.cuda.is_available()
44
+ else "mps" if torch.backends.mps.is_available()
45
+ else "cpu"
46
+ )
47
+ MODEL = ImmunogoldCenterNet(bifpn_channels=128, bifpn_rounds=2)
48
+ ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
49
+ MODEL.load_state_dict(ckpt["model_state_dict"])
50
+ MODEL.to(DEVICE)
51
+ MODEL.eval()
52
+ print(f"Model loaded from {checkpoint_path} on {DEVICE}")
53
+
54
+
55
+ # ---------------------------------------------------------------------------
56
+ # Core detection function
57
+ # ---------------------------------------------------------------------------
58
+ def detect_particles(
59
+ image_file,
60
+ conf_threshold: float = 0.25,
61
+ nms_6nm: int = 3,
62
+ nms_12nm: int = 5,
63
+ ):
64
+ """Run detection on uploaded image. Returns visualization + data."""
65
+ if MODEL is None:
66
+ return None, None, None, "Model not loaded. Start app with --checkpoint"
67
+
68
+ # Load image
69
+ if isinstance(image_file, str):
70
+ img = tifffile.imread(image_file)
71
+ elif hasattr(image_file, "name"):
72
+ img = tifffile.imread(image_file.name)
73
+ else:
74
+ img = np.array(image_file)
75
+
76
+ if img.ndim == 3:
77
+ img = img[:, :, 0] if img.shape[2] <= 4 else img[0]
78
+ img = img.astype(np.uint8)
79
+
80
+ h, w = img.shape[:2]
81
+
82
+ # Run model
83
+ with torch.no_grad():
84
+ hm_np, off_np = sliding_window_inference(
85
+ MODEL, img, patch_size=512, overlap=128, device=DEVICE,
86
+ )
87
+
88
+ # Extract detections
89
+ dets = extract_peaks(
90
+ torch.from_numpy(hm_np), torch.from_numpy(off_np),
91
+ stride=2, conf_threshold=conf_threshold,
92
+ nms_kernel_sizes={"6nm": nms_6nm, "12nm": nms_12nm},
93
+ )
94
+ dets = cross_class_nms(dets, distance_threshold=8)
95
+
96
+ n_6nm = sum(1 for d in dets if d["class"] == "6nm")
97
+ n_12nm = sum(1 for d in dets if d["class"] == "12nm")
98
+
99
+ # --- Generate visualizations ---
100
+
101
+ # 1. Detection overlay
102
+ from skimage.transform import resize
103
+ hm6_up = resize(hm_np[0], (h, w), order=1)
104
+ hm12_up = resize(hm_np[1], (h, w), order=1)
105
+
106
+ fig_overlay, ax = plt.subplots(figsize=(12, 12))
107
+ ax.imshow(img, cmap="gray")
108
+ for d in dets:
109
+ color = "#00FFFF" if d["class"] == "6nm" else "#FFD700"
110
+ radius = 8 if d["class"] == "6nm" else 14
111
+ circle = plt.Circle(
112
+ (d["x"], d["y"]), radius, fill=False,
113
+ edgecolor=color, linewidth=1.5,
114
+ )
115
+ ax.add_patch(circle)
116
+ ax.set_title(
117
+ f"Detected: {n_6nm} 6nm (cyan) + {n_12nm} 12nm (yellow) = {len(dets)} total",
118
+ fontsize=14, pad=10,
119
+ )
120
+ ax.axis("off")
121
+ plt.tight_layout()
122
+
123
+ # Convert to numpy for Gradio
124
+ fig_overlay.canvas.draw()
125
+ overlay_img = np.array(fig_overlay.canvas.renderer.buffer_rgba())[:, :, :3]
126
+ plt.close(fig_overlay)
127
+
128
+ # 2. Heatmap visualization
129
+ fig_hm, axes = plt.subplots(1, 2, figsize=(16, 7))
130
+ axes[0].imshow(img, cmap="gray")
131
+ axes[0].imshow(hm6_up, cmap="hot", alpha=0.6, vmin=0, vmax=max(0.3, hm6_up.max()))
132
+ axes[0].set_title(f"6nm Heatmap ({n_6nm} particles)", fontsize=13)
133
+ axes[0].axis("off")
134
+
135
+ axes[1].imshow(img, cmap="gray")
136
+ axes[1].imshow(hm12_up, cmap="YlOrRd", alpha=0.6, vmin=0, vmax=max(0.3, hm12_up.max()))
137
+ axes[1].set_title(f"12nm Heatmap ({n_12nm} particles)", fontsize=13)
138
+ axes[1].axis("off")
139
+ plt.tight_layout()
140
+
141
+ fig_hm.canvas.draw()
142
+ heatmap_img = np.array(fig_hm.canvas.renderer.buffer_rgba())[:, :, :3]
143
+ plt.close(fig_hm)
144
+
145
+ # 3. Stats dashboard
146
+ fig_stats, axes = plt.subplots(1, 3, figsize=(18, 5))
147
+
148
+ # Confidence histogram
149
+ if dets:
150
+ confs_6 = [d["conf"] for d in dets if d["class"] == "6nm"]
151
+ confs_12 = [d["conf"] for d in dets if d["class"] == "12nm"]
152
+ if confs_6:
153
+ axes[0].hist(confs_6, bins=20, alpha=0.7, color="#00CCCC", label=f"6nm (n={len(confs_6)})")
154
+ if confs_12:
155
+ axes[0].hist(confs_12, bins=20, alpha=0.7, color="#CCB300", label=f"12nm (n={len(confs_12)})")
156
+ axes[0].axvline(conf_threshold, color="red", linestyle="--", label=f"Threshold={conf_threshold}")
157
+ axes[0].legend(fontsize=9)
158
+ axes[0].set_xlabel("Confidence")
159
+ axes[0].set_ylabel("Count")
160
+ axes[0].set_title("Detection Confidence Distribution")
161
+
162
+ # Spatial distribution
163
+ if dets:
164
+ xs = [d["x"] for d in dets]
165
+ ys = [d["y"] for d in dets]
166
+ colors = ["#00CCCC" if d["class"] == "6nm" else "#CCB300" for d in dets]
167
+ axes[1].scatter(xs, ys, c=colors, s=20, alpha=0.7)
168
+ axes[1].set_xlim(0, w)
169
+ axes[1].set_ylim(h, 0)
170
+ axes[1].set_xlabel("X (pixels)")
171
+ axes[1].set_ylabel("Y (pixels)")
172
+ axes[1].set_title("Spatial Distribution")
173
+ axes[1].set_aspect("equal")
174
+
175
+ # Summary table
176
+ axes[2].axis("off")
177
+ table_data = [
178
+ ["Image size", f"{w} x {h} px"],
179
+ ["Scale", "1790 px/\u00b5m"],
180
+ ["6nm (AMPA)", str(n_6nm)],
181
+ ["12nm (NR1)", str(n_12nm)],
182
+ ["Total", str(len(dets))],
183
+ ["Threshold", f"{conf_threshold:.2f}"],
184
+ ["Mean conf (6nm)", f"{np.mean(confs_6):.3f}" if confs_6 else "N/A"],
185
+ ["Mean conf (12nm)", f"{np.mean(confs_12):.3f}" if confs_12 else "N/A"],
186
+ ]
187
+ table = axes[2].table(
188
+ cellText=table_data, colLabels=["Metric", "Value"],
189
+ loc="center", cellLoc="left",
190
+ )
191
+ table.auto_set_font_size(False)
192
+ table.set_fontsize(11)
193
+ table.scale(1, 1.5)
194
+ axes[2].set_title("Detection Summary")
195
+ plt.tight_layout()
196
+
197
+ fig_stats.canvas.draw()
198
+ stats_img = np.array(fig_stats.canvas.renderer.buffer_rgba())[:, :, :3]
199
+ plt.close(fig_stats)
200
+
201
+ # 4. CSV export
202
+ df = pd.DataFrame([
203
+ {
204
+ "particle_id": i + 1,
205
+ "x_px": round(d["x"], 1),
206
+ "y_px": round(d["y"], 1),
207
+ "x_um": round(d["x"] / 1790, 4),
208
+ "y_um": round(d["y"] / 1790, 4),
209
+ "class": d["class"],
210
+ "confidence": round(d["conf"], 4),
211
+ }
212
+ for i, d in enumerate(dets)
213
+ ])
214
+
215
+ csv_path = tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w")
216
+ df.to_csv(csv_path.name, index=False)
217
+
218
+ summary = (
219
+ f"## Results\n"
220
+ f"- **6nm (AMPA)**: {n_6nm} particles\n"
221
+ f"- **12nm (NR1)**: {n_12nm} particles\n"
222
+ f"- **Total**: {len(dets)} particles\n"
223
+ f"- **Image**: {w}x{h} px\n"
224
+ )
225
+
226
+ return overlay_img, heatmap_img, stats_img, csv_path.name, summary
227
+
228
+
229
+ # ---------------------------------------------------------------------------
230
+ # Gradio UI
231
+ # ---------------------------------------------------------------------------
232
+ def build_app():
233
+ with gr.Blocks(title="MidasMap - Immunogold Particle Detection") as app:
234
+ gr.Markdown(
235
+ "# MidasMap\n"
236
+ "### Immunogold Particle Detection for TEM Synapse Images\n"
237
+ "Upload an EM image (.tif) to detect 6nm (AMPA) and 12nm (NR1) gold particles."
238
+ )
239
+
240
+ with gr.Row():
241
+ with gr.Column(scale=1):
242
+ image_input = gr.File(
243
+ label="Upload TEM Image (.tif)",
244
+ file_types=[".tif", ".tiff", ".png", ".jpg"],
245
+ )
246
+ conf_slider = gr.Slider(
247
+ minimum=0.05, maximum=0.95, value=0.25, step=0.05,
248
+ label="Confidence Threshold",
249
+ info="Lower = more detections (more FP), Higher = fewer but more certain",
250
+ )
251
+ nms_6nm = gr.Slider(
252
+ minimum=1, maximum=9, value=3, step=2,
253
+ label="NMS Kernel (6nm)",
254
+ info="Min distance between 6nm detections (pixels at stride 2)",
255
+ )
256
+ nms_12nm = gr.Slider(
257
+ minimum=1, maximum=9, value=5, step=2,
258
+ label="NMS Kernel (12nm)",
259
+ )
260
+ detect_btn = gr.Button("Detect Particles", variant="primary", size="lg")
261
+
262
+ with gr.Column(scale=2):
263
+ summary_md = gr.Markdown("Upload an image to begin.")
264
+
265
+ with gr.Tabs():
266
+ with gr.TabItem("Detection Overlay"):
267
+ overlay_output = gr.Image(label="Detected Particles")
268
+ with gr.TabItem("Heatmaps"):
269
+ heatmap_output = gr.Image(label="Class Heatmaps")
270
+ with gr.TabItem("Statistics"):
271
+ stats_output = gr.Image(label="Detection Statistics")
272
+ with gr.TabItem("Export"):
273
+ csv_output = gr.File(label="Download CSV Results")
274
+
275
+ detect_btn.click(
276
+ fn=detect_particles,
277
+ inputs=[image_input, conf_slider, nms_6nm, nms_12nm],
278
+ outputs=[overlay_output, heatmap_output, stats_output, csv_output, summary_md],
279
+ )
280
+
281
+ gr.Markdown(
282
+ "---\n"
283
+ "*MidasMap: CenterNet + CEM500K backbone, trained on 453 labeled particles "
284
+ "across 10 synapses. LOOCV F1 = 0.94.*"
285
+ )
286
+
287
+ return app
288
+
289
+
290
+ # ---------------------------------------------------------------------------
291
+ # Main
292
+ # ---------------------------------------------------------------------------
293
+ def main():
294
+ parser = argparse.ArgumentParser()
295
+ parser.add_argument(
296
+ "--checkpoint", default="checkpoints/local_S1_v2/best.pth",
297
+ help="Path to model checkpoint",
298
+ )
299
+ parser.add_argument("--share", action="store_true", help="Create public link")
300
+ parser.add_argument("--port", type=int, default=7860)
301
+ args = parser.parse_args()
302
+
303
+ load_model(args.checkpoint)
304
+ app = build_app()
305
+ app.launch(share=args.share, server_port=args.port)
306
+
307
+
308
+ if __name__ == "__main__":
309
+ main()