AnikS22 committed · verified · Commit 86c24cb · Parent: cb40f1e

Deploy MidasMap Gradio app, src, requirements, checkpoint
README.md CHANGED
@@ -1,13 +1,38 @@
  ---
  title: MidasMap
- emoji: 🦀
- colorFrom: purple
  colorTo: blue
  sdk: gradio
- sdk_version: 6.10.0
  app_file: app.py
  pinned: false
- short_description: Detects Immunogold particles in EM images
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
  title: MidasMap
+ emoji: 🔬
+ colorFrom: gray
  colorTo: blue
  sdk: gradio
+ sdk_version: 4.44.1
  app_file: app.py
  pinned: false
+ license: mit
  ---

+ # MidasMap Space
+
+ This folder is a **template** for creating a [Hugging Face Space](https://huggingface.co/docs/hub/spaces-overview).
+
+ **Why not Vercel for the model?** Vercel serverless functions have strict size and execution-time limits; they are not suited to PyTorch plus a ~100 MB checkpoint and multi-second CPU/GPU inference. **Host the Gradio app and weights on a Space** (free CPU tier, or upgrade to GPU).
+
+ ## Create the Space
+
+ 1. On Hugging Face: **New Space** → SDK **Gradio** → name it e.g. `MidasMap`.
+ 2. Clone the Space repo locally, or connect **GitHub** and set the Space root to this monorepo with **App file** pointing to the copied `app.py`.
+ 3. Copy into the Space repository root:
+    - `app.py` from the main MidasMap repo (project root), **or** symlink / duplicate it.
+    - `src/` (the entire package)
+    - `requirements-space.txt` from this folder, renamed to **`requirements.txt`**
+ 4. In Space **Settings → Repository secrets**: none required for public weights.
+ 5. Ensure `checkpoints/final/final_model.pth` is present:
+    - Upload it via the **Files** tab, or
+    - Add a startup script that downloads it from `AnikS22/MidasMap` on the Hub (see the HF docs for `hf_hub_download`).
+
+ After the Space builds, point your **Vercel** site (`vercel-site`) at it:
+
+ `https://yoursite.vercel.app/?embed=https://huggingface.co/spaces/YOUR_USER/YOUR_SPACE`
+
+ ---
+
+ Gradio app and model logic: [github.com/AnikS22/MidasMap](https://github.com/AnikS22/MidasMap)
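The startup download mentioned in step 5 can be sketched with `hf_hub_download`. This is an editor's sketch, not part of this commit; it assumes the weights remain public under `AnikS22/MidasMap` at the path shown in the checkpoint file below.

```python
from pathlib import Path

CKPT_REL = "checkpoints/final/final_model.pth"


def ensure_checkpoint(repo_id: str = "AnikS22/MidasMap", rel_path: str = CKPT_REL) -> Path:
    """Return a local path to the checkpoint, downloading from the Hub if absent."""
    local = Path(rel_path)
    if local.is_file():
        return local  # already present (e.g. uploaded via the Files tab)
    # Deferred import: only needed when the file must actually be fetched
    from huggingface_hub import hf_hub_download

    return Path(hf_hub_download(repo_id=repo_id, filename=rel_path))
```

Calling `ensure_checkpoint()` before `load_model(...)` at Space startup covers both cases; `hf_hub_download` caches under `~/.cache/huggingface`, so restarts do not re-download.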
app.py ADDED
@@ -0,0 +1,751 @@
+ """
+ MidasMap — Immunogold particle analysis for FFRIL / TEM synapse imaging
+
+ Web UI for neuroscientists: calibrated coordinates (µm), receptor labels,
+ export for quantification, and clear interpretation of model limits.
+
+ Usage:
+     python app.py
+     python app.py --checkpoint checkpoints/final/final_model.pth
+     python app.py --share
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import os
+ import tempfile
+ from pathlib import Path
+
+ import gradio as gr
+ import gradio_client.utils as _gcu
+
+ # Pydantic v2 can emit JSON Schema with additionalProperties: true (bool);
+ # Gradio 4.4x gradio_client assumes a dict and crashes rendering "/".
+ _orig_json_type = _gcu._json_schema_to_python_type
+
+
+ def _json_schema_to_python_type_safe(schema, defs=None):
+     if schema is True or schema is False:
+         return "Any"
+     if not isinstance(schema, dict):
+         return "Any"
+     return _orig_json_type(schema, defs)
+
+
+ _gcu._json_schema_to_python_type = _json_schema_to_python_type_safe
+
+ import matplotlib
+
+ matplotlib.use("Agg")
+ import matplotlib.patheffects as pe
+ import matplotlib.pyplot as plt
+ from matplotlib.patches import Patch
+ import numpy as np
+ import pandas as pd
+ import torch
+ import tifffile
+
+ from src.ensemble import sliding_window_inference
+ from src.heatmap import extract_peaks
+ from src.model import ImmunogoldCenterNet
+ from src.postprocess import cross_class_nms
+
+
+ # Calibration used for training / published metrics (change in UI if your scope differs)
+ DEFAULT_PX_PER_UM = 1790.0
+
+ plt.rcParams.update(
+     {
+         "figure.facecolor": "white",
+         "figure.dpi": 120,
+         "savefig.facecolor": "white",
+         "axes.facecolor": "#fafafa",
+         "axes.edgecolor": "#cbd5e1",
+         "axes.linewidth": 0.8,
+         "axes.labelcolor": "#1e293b",
+         "axes.titlecolor": "#0f172a",
+         "axes.grid": False,
+         "xtick.color": "#475569",
+         "ytick.color": "#475569",
+         "font.size": 10,
+         "axes.titlesize": 11,
+         "axes.labelsize": 10,
+         "legend.frameon": True,
+         "legend.framealpha": 0.92,
+         "legend.edgecolor": "#e2e8f0",
+     }
+ )
+
+
+ MODEL = None
+ DEVICE = None
+
+
+ def load_model(checkpoint_path: str):
+     global MODEL, DEVICE
+     DEVICE = torch.device(
+         "cuda"
+         if torch.cuda.is_available()
+         else "mps"
+         if torch.backends.mps.is_available()
+         else "cpu"
+     )
+     MODEL = ImmunogoldCenterNet(bifpn_channels=128, bifpn_rounds=2)
+     ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
+     MODEL.load_state_dict(ckpt["model_state_dict"])
+     MODEL.to(DEVICE)
+     MODEL.eval()
+     print(f"Model loaded from {checkpoint_path} on {DEVICE}")
+
+
+ def _receptor_label(class_name: str) -> str:
+     return "AMPA receptor" if class_name == "6nm" else "NR1 (NMDA receptor)"
+
+
+ def _gold_nm(class_name: str) -> int:
+     return 6 if class_name == "6nm" else 12
+
+
+ def _pick_scale_bar_um(field_width_um: float) -> float:
+     """Pick a readable scale bar (~15–30% of field width)."""
+     if field_width_um <= 0:
+         return 0.2
+     target = field_width_um * 0.22
+     candidates = (0.05, 0.1, 0.2, 0.25, 0.5, 1.0, 2.0, 5.0)
+     best = candidates[0]
+     for c in candidates:
+         if abs(c - target) < abs(best - target):
+             best = c
+     # Keep bar from dominating the field
+     while best > 0 and best / field_width_um > 0.45:
+         best = max(0.05, best / 2)
+     return float(best)
+
+
+ def _draw_scale_bar_um(ax, w: int, h: int, px_per_um: float) -> None:
+     field_um = max(w, h) / px_per_um
+     bar_um = _pick_scale_bar_um(field_um)
+     bar_px = bar_um * px_per_um
+     margin = max(12, int(min(w, h) * 0.025))
+     y_line = h - margin
+     x0, x1 = margin, margin + bar_px
+     for lw, color in ((5, "white"), (2, "#0f172a")):
+         ax.plot([x0, x1], [y_line, y_line], color=color, linewidth=lw, solid_capstyle="butt", clip_on=False)
+     t = ax.text(
+         (x0 + x1) / 2,
+         y_line - margin * 0.35,
+         f"{bar_um:g} µm",
+         ha="center",
+         va="bottom",
+         color="white",
+         fontsize=9,
+         fontweight="600",
+     )
+     t.set_path_effects([pe.withStroke(linewidth=2.5, foreground="#0f172a")])
+
+
+ def _export_columns() -> list[str]:
+     return [
+         "particle_id",
+         "receptor",
+         "gold_diameter_nm",
+         "x_px",
+         "y_px",
+         "x_um",
+         "y_um",
+         "confidence",
+         "class_model",
+         "calibration_px_per_um",
+     ]
+
+
+ def _empty_results_df() -> pd.DataFrame:
+     return pd.DataFrame(columns=_export_columns())
+
+
+ def _df_to_preview_html(df: pd.DataFrame) -> str:
+     if df is None or len(df) == 0:
+         return "<p class='mm-table-empty'><em>No particles above the current threshold.</em></p>"
+     return df.to_html(
+         classes=["mm-table"],
+         index=False,
+         border=0,
+         justify="left",
+         escape=True,
+     )
+
+
+ def detect_particles(
+     image_file,
+     conf_threshold: float = 0.25,
+     nms_6nm: int = 3,
+     nms_12nm: int = 5,
+     px_per_um: float = DEFAULT_PX_PER_UM,
+     progress=gr.Progress(track_tqdm=False),
+ ):
+     """Run detection; returns figures, CSV path, table HTML, and summary HTML."""
+     empty_table = "<p class='mm-table-empty'><em>Run detection to populate the table.</em></p>"
+
+     if MODEL is None:
+         msg = "<p class='mm-callout mm-callout-warn'>Model not loaded. Use <code>--checkpoint</code> with a valid <code>.pth</code> file.</p>"
+         return None, None, None, None, empty_table, msg
+
+     if image_file is None:
+         msg = "<p class='mm-callout'>Upload a micrograph, set calibration if needed, then run detection.</p>"
+         return None, None, None, None, empty_table, msg
+
+     try:
+         px_per_um = float(px_per_um)
+     except (TypeError, ValueError):
+         px_per_um = DEFAULT_PX_PER_UM
+     if px_per_um <= 0:
+         px_per_um = DEFAULT_PX_PER_UM
+
+     progress(0.05, desc="Loading image…")
+
+     if isinstance(image_file, str):
+         img = tifffile.imread(image_file)
+     elif hasattr(image_file, "name"):
+         img = tifffile.imread(image_file.name)
+     else:
+         img = np.array(image_file)
+
+     if img.ndim == 3:
+         img = img[:, :, 0] if img.shape[2] <= 4 else img[0]
+     img = img.astype(np.uint8)
+
+     h, w = img.shape[:2]
+     field_w_um = w / px_per_um
+     field_h_um = h / px_per_um
+
+     progress(0.15, desc="Neural network (sliding window)…")
+
+     with torch.no_grad():
+         hm_np, off_np = sliding_window_inference(
+             MODEL,
+             img,
+             patch_size=512,
+             overlap=128,
+             device=DEVICE,
+         )
+
+     progress(0.72, desc="Peak extraction & NMS…")
+
+     dets = extract_peaks(
+         torch.from_numpy(hm_np),
+         torch.from_numpy(off_np),
+         stride=2,
+         conf_threshold=conf_threshold,
+         nms_kernel_sizes={"6nm": nms_6nm, "12nm": nms_12nm},
+     )
+     dets = cross_class_nms(dets, distance_threshold=8)
+
+     n_6nm = sum(1 for d in dets if d["class"] == "6nm")
+     n_12nm = sum(1 for d in dets if d["class"] == "12nm")
+     confs_6 = [d["conf"] for d in dets if d["class"] == "6nm"]
+     confs_12 = [d["conf"] for d in dets if d["class"] == "12nm"]
+
+     progress(0.78, desc="Rendering figures…")
+
+     from skimage.transform import resize
+
+     hm6_up = resize(hm_np[0], (h, w), order=1)
+     hm12_up = resize(hm_np[1], (h, w), order=1)
+
+     # --- Overlay (publication-style legend + scale bar) ---
+     fig_overlay, ax = plt.subplots(figsize=(11, 11))
+     ax.imshow(img, cmap="gray", aspect="equal")
+     for d in dets:
+         color = "#06b6d4" if d["class"] == "6nm" else "#ca8a04"
+         radius = 7 if d["class"] == "6nm" else 12
+         ax.add_patch(
+             plt.Circle(
+                 (d["x"], d["y"]),
+                 radius,
+                 fill=False,
+                 edgecolor=color,
+                 linewidth=1.8,
+             )
+         )
+     _draw_scale_bar_um(ax, w, h, px_per_um)
+     ax.set_title(
+         f"Immunogold detections · AMPA (6 nm): {n_6nm} · NR1 (12 nm): {n_12nm} · Total: {len(dets)}",
+         fontsize=11,
+         pad=12,
+     )
+     ax.axis("off")
+     legend_elems = [
+         Patch(facecolor="none", edgecolor="#06b6d4", linewidth=2, label="6 nm gold — AMPA receptor"),
+         Patch(facecolor="none", edgecolor="#ca8a04", linewidth=2, label="12 nm gold — NR1 (NMDAR)"),
+     ]
+     ax.legend(
+         handles=legend_elems,
+         loc="upper right",
+         fontsize=8.5,
+         title="Label class",
+         title_fontsize=9,
+     )
+     plt.tight_layout()
+     fig_overlay.canvas.draw()
+     overlay_img = np.asarray(fig_overlay.canvas.renderer.buffer_rgba())[:, :, :3]
+     plt.close(fig_overlay)
+
+     # --- Heatmaps ---
+     fig_hm, axes = plt.subplots(1, 2, figsize=(14, 6.2))
+     axes[0].imshow(img, cmap="gray", aspect="equal")
+     axes[0].imshow(hm6_up, cmap="magma", alpha=0.55, vmin=0, vmax=max(0.3, float(hm6_up.max())))
+     axes[0].set_title(f"AMPA (6 nm) channel · n = {n_6nm}", fontsize=11)
+     axes[0].axis("off")
+
+     axes[1].imshow(img, cmap="gray", aspect="equal")
+     axes[1].imshow(hm12_up, cmap="inferno", alpha=0.55, vmin=0, vmax=max(0.3, float(hm12_up.max())))
+     axes[1].set_title(f"NR1 (12 nm) channel · n = {n_12nm}", fontsize=11)
+     axes[1].axis("off")
+     plt.tight_layout()
+     fig_hm.canvas.draw()
+     heatmap_img = np.asarray(fig_hm.canvas.renderer.buffer_rgba())[:, :, :3]
+     plt.close(fig_hm)
+
+     # --- Stats (µm where helpful) ---
+     fig_stats, axes = plt.subplots(1, 3, figsize=(16, 4.8))
+     if dets:
+         if confs_6:
+             axes[0].hist(confs_6, bins=18, alpha=0.75, color="#0891b2", label=f"AMPA (n={len(confs_6)})")
+         if confs_12:
+             axes[0].hist(confs_12, bins=18, alpha=0.75, color="#a16207", label=f"NR1 (n={len(confs_12)})")
+         axes[0].axvline(conf_threshold, color="#be123c", linestyle="--", linewidth=1.2, label=f"Threshold = {conf_threshold:.2f}")
+         axes[0].legend(fontsize=8)
+     axes[0].set_xlabel("Confidence score")
+     axes[0].set_ylabel("Count")
+     axes[0].set_title("Score distribution")
+     axes[0].spines["top"].set_visible(False)
+     axes[0].spines["right"].set_visible(False)
+
+     if dets:
+         xs_um = np.array([d["x"] for d in dets]) / px_per_um
+         ys_um = np.array([d["y"] for d in dets]) / px_per_um
+         colors = ["#0891b2" if d["class"] == "6nm" else "#a16207" for d in dets]
+         axes[1].scatter(xs_um, ys_um, c=colors, s=22, alpha=0.75, edgecolors="none")
+     axes[1].set_xlim(0, field_w_um)
+     axes[1].set_ylim(field_h_um, 0)
+     axes[1].set_xlabel("x (µm)")
+     axes[1].set_ylabel("y (µm)")
+     axes[1].set_title("Positions (image coordinates)")
+     axes[1].set_aspect("equal")
+     axes[1].spines["top"].set_visible(False)
+     axes[1].spines["right"].set_visible(False)
+
+     axes[2].axis("off")
+     table_data = [
+         ["Field of view", f"{field_w_um:.3f} × {field_h_um:.3f} µm"],
+         ["Calibration", f"{px_per_um:.1f} px/µm"],
+         ["AMPA (6 nm)", str(n_6nm)],
+         ["NR1 (12 nm)", str(n_12nm)],
+         ["Total particles", str(len(dets))],
+         ["Score threshold", f"{conf_threshold:.2f}"],
+         ["Mean score · AMPA", f"{float(np.mean(confs_6)):.3f}" if confs_6 else "—"],
+         ["Mean score · NR1", f"{float(np.mean(confs_12)):.3f}" if confs_12 else "—"],
+     ]
+     tbl = axes[2].table(
+         cellText=table_data,
+         colLabels=["Quantity", "Value"],
+         loc="center",
+         cellLoc="left",
+     )
+     tbl.auto_set_font_size(False)
+     tbl.set_fontsize(10)
+     tbl.scale(1.05, 1.65)
+     for (row, col), cell in tbl.get_celld().items():
+         if row == 0:
+             cell.set_text_props(fontweight="600")
+             cell.set_facecolor("#e2e8f0")
+     axes[2].set_title("Summary", fontsize=11, pad=12)
+     plt.tight_layout()
+     fig_stats.canvas.draw()
+     stats_img = np.asarray(fig_stats.canvas.renderer.buffer_rgba())[:, :, :3]
+     plt.close(fig_stats)
+
+     rows = []
+     for i, d in enumerate(dets):
+         rows.append(
+             {
+                 "particle_id": i + 1,
+                 "receptor": _receptor_label(d["class"]),
+                 "gold_diameter_nm": _gold_nm(d["class"]),
+                 "x_px": round(d["x"], 2),
+                 "y_px": round(d["y"], 2),
+                 "x_um": round(d["x"] / px_per_um, 5),
+                 "y_um": round(d["y"] / px_per_um, 5),
+                 "confidence": round(d["conf"], 4),
+                 "class_model": d["class"],
+                 "calibration_px_per_um": round(px_per_um, 4),
+             }
+         )
+     df = pd.DataFrame(rows, columns=_export_columns()) if rows else _empty_results_df()
+
+     csv_f = tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w", encoding="utf-8")
+     df.to_csv(csv_f.name, index=False)
+     csv_f.close()
+
+     progress(1.0, desc="Done")
+
+     density_note = ""
+     if field_w_um > 0 and field_h_um > 0:
+         area = field_w_um * field_h_um
+         density_note = f"<span class='mm-density'>Areal density (all): {len(dets) / area:.2f} particles/µm² · AMPA: {n_6nm / area:.2f} · NR1: {n_12nm / area:.2f}</span>"
+
+     summary = f"""<div class="mm-summary">
+       <div class="mm-stat"><span class="mm-stat-label">AMPA · 6 nm gold</span>
+       <span class="mm-stat-value mm-teal">{n_6nm}</span></div>
+       <div class="mm-stat"><span class="mm-stat-label">NR1 · 12 nm gold</span>
+       <span class="mm-stat-value mm-amber">{n_12nm}</span></div>
+       <div class="mm-stat"><span class="mm-stat-label">Total</span>
+       <span class="mm-stat-value">{len(dets)}</span></div>
+       <div class="mm-stat mm-stat-wide"><span class="mm-stat-label">Field & calibration</span>
+       <span class="mm-stat-meta">{field_w_um:.3f} × {field_h_um:.3f} µm · {px_per_um:.1f} px/µm · {DEVICE}</span></div>
+       {density_note and f'<div class="mm-stat mm-stat-wide">{density_note}</div>'}
+     </div>"""
+
+     return overlay_img, heatmap_img, stats_img, csv_f.name, _df_to_preview_html(df), summary
+
+
+ MM_CSS = """
+ .gradio-container { max-width: 1320px !important; margin: auto !important; }
+ .mm-brand-bar {
+     display: flex; align-items: center; justify-content: space-between;
+     flex-wrap: wrap; gap: 0.75rem;
+     padding: 0.6rem 0 1.25rem;
+     border-bottom: 1px solid var(--border-color-primary);
+     margin-bottom: 1.25rem;
+ }
+ .mm-brand-bar span {
+     font-size: 0.72rem; letter-spacing: 0.14em; text-transform: uppercase;
+     color: var(--body-text-color-subdued); font-weight: 600;
+ }
+ .mm-hero {
+     padding: 1.5rem 1.35rem 1.35rem;
+     margin-bottom: 0.25rem;
+     border-radius: 10px;
+     background: linear-gradient(145deg, #0c4a6e22 0%, #0f172a 48%, #1e1b4b33 100%);
+     border: 1px solid #33415588;
+ }
+ .mm-hero h1 {
+     font-family: "Libre Baskerville", Georgia, serif;
+     font-weight: 700;
+     letter-spacing: -0.02em;
+     margin: 0 0 0.4rem 0;
+     font-size: 1.65rem;
+     color: #f1f5f9;
+ }
+ .mm-hero .mm-sub {
+     margin: 0 0 0.85rem 0;
+     color: #94a3b8;
+     font-size: 0.92rem;
+     line-height: 1.55;
+     max-width: 58ch;
+ }
+ .mm-badge-row { display: flex; flex-wrap: wrap; gap: 0.4rem; }
+ .mm-badge {
+     font-size: 0.65rem; text-transform: uppercase; letter-spacing: 0.07em;
+     padding: 0.2rem 0.5rem; border-radius: 4px;
+     background: #0e749033; color: #99f6e4; border: 1px solid #14b8a644;
+ }
+ .mm-layout { display: flex; gap: 1.25rem; align-items: flex-start; flex-wrap: wrap; }
+ .mm-sidebar {
+     flex: 1 1 280px; max-width: 340px;
+     padding: 1rem 1.1rem; border-radius: 10px;
+     border: 1px solid var(--border-color-primary);
+     background: var(--block-background-fill);
+ }
+ .mm-main { flex: 3 1 520px; min-width: 0; }
+ .mm-panel-title {
+     font-size: 0.7rem; text-transform: uppercase; letter-spacing: 0.1em;
+     color: var(--body-text-color-subdued); font-weight: 600; margin: 0 0 0.65rem 0;
+ }
+ .mm-callout {
+     margin: 0; padding: 0.75rem 0.9rem; border-radius: 8px;
+     background: #1e293b66; border: 1px solid var(--border-color-primary);
+     font-size: 0.88rem; line-height: 1.45; color: var(--body-text-color);
+ }
+ .mm-callout-warn { border-color: #f59e0b55; background: #78350f22; }
+ .mm-science {
+     margin-top: 1rem; font-size: 0.82rem; line-height: 1.5;
+     color: var(--body-text-color-subdued);
+ }
+ .mm-science h4 { margin: 0.5rem 0 0.35rem; font-size: 0.78rem; text-transform: uppercase; letter-spacing: 0.06em; color: #94a3b8; }
+ .mm-science ul { margin: 0.25rem 0 0 1rem; padding: 0; }
+ .mm-summary { display: flex; flex-wrap: wrap; gap: 0.65rem; margin: 0 0 1rem 0; }
+ .mm-stat {
+     flex: 1 1 118px; padding: 0.75rem 0.95rem; border-radius: 8px;
+     background: var(--block-background-fill);
+     border: 1px solid var(--border-color-primary);
+ }
+ .mm-stat-wide { flex: 1 1 100%; }
+ .mm-stat-label {
+     display: block; font-size: 0.68rem; text-transform: uppercase;
+     letter-spacing: 0.06em; opacity: 0.72; margin-bottom: 0.2rem;
+ }
+ .mm-stat-value { font-size: 1.4rem; font-weight: 700; font-variant-numeric: tabular-nums; letter-spacing: -0.02em; }
+ .mm-stat-value.mm-teal { color: #2dd4bf; }
+ .mm-stat-value.mm-amber { color: #fbbf24; }
+ .mm-stat-meta { font-size: 0.84rem; opacity: 0.92; line-height: 1.35; }
+ .mm-density { font-size: 0.84rem; opacity: 0.9; }
+ table.mm-table {
+     width: 100%; border-collapse: collapse; font-size: 0.82rem;
+     margin: 0.25rem 0 0.75rem 0;
+ }
+ table.mm-table th {
+     text-align: left; padding: 0.45rem 0.5rem;
+     border-bottom: 1px solid var(--border-color-primary);
+     color: var(--body-text-color-subdued); font-weight: 600;
+ }
+ table.mm-table td { padding: 0.35rem 0.5rem; border-bottom: 1px solid #33415544; }
+ .mm-table-empty { margin: 0.5rem 0; opacity: 0.75; font-size: 0.9rem; }
+ .mm-foot {
+     margin-top: 2rem; padding-top: 1rem;
+     border-top: 1px solid var(--border-color-primary);
+     font-size: 0.78rem; line-height: 1.45;
+     color: var(--body-text-color-subdued);
+ }
+ .mm-foot code { font-size: 0.76rem; }
+ """
+
+
+ def build_app():
+     theme = gr.themes.Soft(
+         primary_hue=gr.themes.Color(
+             c50="#f0fdfa",
+             c100="#ccfbf1",
+             c200="#99f6e4",
+             c300="#5eead4",
+             c400="#2dd4bf",
+             c500="#14b8a6",
+             c600="#0d9488",
+             c700="#0f766e",
+             c800="#115e59",
+             c900="#134e4a",
+             c950="#042f2e",
+         ),
+         neutral_hue=gr.themes.colors.slate,
+         font=("Source Sans 3", "ui-sans-serif", "system-ui", "sans-serif"),
+         font_mono=("IBM Plex Mono", "ui-monospace", "monospace"),
+     ).set(
+         body_background_fill_dark="*neutral_950",
+         block_background_fill_dark="*neutral_900",
+         border_color_primary="*neutral_700",
+         button_primary_background_fill="*primary_600",
+         button_primary_background_fill_hover="*primary_500",
+         block_label_text_size="*text_sm",
+     )
+
+     head = """
+     <link href="https://fonts.googleapis.com/css2?family=Libre+Baskerville:wght@700&family=Source+Sans+3:wght@400;600;700&display=swap" rel="stylesheet">
+     """
+
+     with gr.Blocks(
+         title="MidasMap — Immunogold analysis",
+         theme=theme,
+         css=MM_CSS,
+         head=head,
+     ) as app:
+         gr.HTML(
+             """
+             <div class="mm-brand-bar">
+               <span>Quantitative EM · synapse immunogold</span>
+               <span>Research use · validate critical counts manually</span>
+             </div>
+             <div class="mm-hero">
+               <h1>MidasMap</h1>
+               <p class="mm-sub">
+                 Automated particle picking for <strong>freeze-fracture replica immunolabeling (FFRIL)</strong> TEM:
+                 <strong>6 nm</strong> gold (AMPA receptors) and <strong>12 nm</strong> gold (NR1 / NMDA receptors).
+                 Coordinates export in <strong>µm</strong> for comparison to physiology and super-resolution data—set calibration to match your microscope.
+               </p>
+               <div class="mm-badge-row">
+                 <span class="mm-badge">FFRIL / TEM</span>
+                 <span class="mm-badge">CenterNet</span>
+                 <span class="mm-badge">CEM500K backbone</span>
+                 <span class="mm-badge">LOOCV F1 ≈ 0.94</span>
+               </div>
+             </div>
+             """
+         )
+
+         with gr.Row(elem_classes=["mm-layout"]):
+             with gr.Column(elem_classes=["mm-sidebar"]):
+                 gr.HTML('<p class="mm-panel-title">Micrograph & calibration</p>')
+                 image_input = gr.File(
+                     label="Upload image",
+                     file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"],
+                 )
+                 px_per_um_in = gr.Number(
+                     value=DEFAULT_PX_PER_UM,
+                     label="Calibration (pixels per µm)",
+                     info=f"Default {DEFAULT_PX_PER_UM:.0f} matches the published training set. "
+                     "Update if your acquisition scale differs.",
+                     minimum=1,
+                     maximum=1e6,
+                 )
+                 conf_slider = gr.Slider(
+                     minimum=0.05,
+                     maximum=0.95,
+                     value=0.25,
+                     step=0.05,
+                     label="Confidence threshold",
+                     info="Higher → fewer, sharper peaks. Lower → recall with more false positives.",
+                 )
+                 with gr.Accordion("Advanced · non-max suppression", open=False):
+                     nms_6nm = gr.Slider(
+                         minimum=1,
+                         maximum=9,
+                         value=3,
+                         step=2,
+                         label="NMS · 6 nm channel",
+                         info="Minimum spacing between AMPA peaks on the heatmap grid.",
+                     )
+                     nms_12nm = gr.Slider(
+                         minimum=1,
+                         maximum=9,
+                         value=5,
+                         step=2,
+                         label="NMS · 12 nm channel",
+                     )
+                 detect_btn = gr.Button("Run detection", variant="primary", size="lg")
+
+                 with gr.Accordion("For neuroscientists — interpretation", open=False):
+                     gr.Markdown(
+                         """
+                         #### What the model outputs
+                         - **Circles** mark predicted gold centers; **scores** are CNN confidences, not p-values.
+                         - **AMPA** = 6 nm class; **NR1** = 12 nm class (NMDA receptor subunit). Verify ambiguous sites on the raw image.
+
+                         #### When to trust it
+                         - Trained on **10 FFRIL synapse images** (453 hand-placed particles). Expect best performance on **similar prep, contrast, and magnification**.
+                         - **Always spot-check** counts used for publication, especially near membranes and dense clusters.
+
+                         #### Coordinates & CSV
+                         - **x, y** follow image pixel order (origin top-left). **µm** columns use your calibration above.
+                         - CSV includes **receptor**, **gold diameter**, and **calibration** used for provenance.
+
+                         #### Citation
+                         Sahai, A. (2026). *MidasMap* (software). https://github.com/AnikS22/MidasMap
+                         """
+                     )
+
+             with gr.Column(elem_classes=["mm-main"]):
+                 summary_md = gr.HTML(
+                     value="<p class='mm-callout'>Upload a synapse micrograph to begin. Adjust calibration before export if your scale differs from the default.</p>"
+                 )
+                 with gr.Tabs():
+                     with gr.Tab("Overlay"):
+                         overlay_output = gr.Image(
+                             label="Detections + scale bar",
+                             type="numpy",
+                             height=540,
+                         )
+                     with gr.Tab("Heatmaps"):
+                         heatmap_output = gr.Image(
+                             label="Class-specific maps",
+                             type="numpy",
+                             height=540,
+                         )
+                     with gr.Tab("Quant summary"):
+                         stats_output = gr.Image(
+                             label="Distributions & table",
+                             type="numpy",
+                             height=440,
+                         )
+                     with gr.Tab("Table & export"):
+                         table_output = gr.HTML(
+                             label="Detections (preview)",
+                             value="<p class='mm-table-empty'><em>Results appear here after detection.</em></p>",
+                         )
+                         csv_output = gr.File(label="Download CSV")
+
+         gr.HTML(
+             f"""
+             <div class="mm-foot">
+               <strong>Training context:</strong> LOOCV mean F1 ≈ 0.94 on eight well-annotated folds;
+               raw grayscale input (avoid heavy filtering). Not a clinical device.
+               Model weights: <code>checkpoints/final/final_model.pth</code> or
+               <a href="https://huggingface.co/AnikS22/MidasMap" target="_blank" rel="noopener">Hugging Face</a>.
+             </div>
+             """
+         )
+
+         detect_btn.click(
+             fn=detect_particles,
+             inputs=[image_input, conf_slider, nms_6nm, nms_12nm, px_per_um_in],
+             outputs=[
+                 overlay_output,
+                 heatmap_output,
+                 stats_output,
+                 csv_output,
+                 table_output,
+                 summary_md,
+             ],
+         )
+
+     return app
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="MidasMap web dashboard")
+     parser.add_argument(
+         "--checkpoint",
+         type=str,
+         default="checkpoints/final/final_model.pth",
+         help="Path to trained checkpoint (.pth)",
+     )
+     parser.add_argument("--share", action="store_true", help="Gradio public share link (use if localhost is blocked)")
+     parser.add_argument(
+         "--server-name",
+         type=str,
+         default=None,
+         metavar="HOST",
+         help="Bind address, e.g. 0.0.0.0 for LAN (default: 127.0.0.1)",
+     )
+     parser.add_argument("--port", type=int, default=7860)
+     args = parser.parse_args()
+
+     if os.environ.get("GRADIO_SHARE", "").lower() in ("1", "true", "yes"):
+         args.share = True
+
+     ckpt = Path(args.checkpoint)
+     if not ckpt.is_file():
+         raise SystemExit(
+             f"Checkpoint not found: {ckpt}\n"
+             "Train with train_final.py or download from Hugging Face:\n"
+             "  huggingface-cli download AnikS22/MidasMap checkpoints/final/final_model.pth "
+             "--local-dir ."
+         )
+
+     load_model(str(ckpt))
+     demo = build_app()
+     launch_kw = dict(
+         share=args.share,
+         server_port=args.port,
+         server_name=args.server_name,
+         show_api=False,
+         inbrowser=False,
+     )
+     try:
+         demo.launch(**launch_kw)
+     except ValueError as err:
+         if (
+             "localhost is not accessible" in str(err)
+             and not launch_kw.get("share")
+             and os.environ.get("GRADIO_SHARE", "").lower() not in ("1", "true", "yes")
+         ):
+             print(
+                 "Localhost check failed in this environment; starting with share=True "
+                 "(Gradio tunnel). Use --share next time, or set GRADIO_SHARE=1."
+             )
+             build_app().launch(**{**launch_kw, "share": True})
+         else:
+             raise
+
+
+ if __name__ == "__main__":
+     main()
checkpoints/final/final_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:735d37e839019318cb0e4c7e40d99194abd59f57efd8594ca51602ce3451dfb6
+ size 98043418
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ # Pin versions compatible with HF Spaces (adjust if the Space build fails).
+ # Rename to requirements.txt in the Space repo root.
+ numpy>=1.24,<2
+ torch>=2.0.0
+ torchvision>=0.15.0
+ scipy>=1.10.0
+ scikit-image>=0.21.0
+ matplotlib>=3.7.0
+ tifffile>=2023.4.0
+ pandas>=2.0.0
+ PyYAML>=6.0
+ albumentations>=1.3.0
+ opencv-python-headless>=4.7.0
+ gradio==4.44.1
+ huggingface_hub>=0.20.0,<0.25.0
+ tqdm>=4.65.0
src/.DS_Store ADDED
Binary file (6.15 kB).
 
src/__init__.py ADDED
@@ -0,0 +1 @@
+ """Immunogold particle detection system for TEM images."""
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (224 Bytes).
 
src/__pycache__/ensemble.cpython-311.pyc ADDED
Binary file (11.5 kB).
 
src/__pycache__/evaluate.cpython-311.pyc ADDED
Binary file (8.73 kB).
 
src/__pycache__/heatmap.cpython-311.pyc ADDED
Binary file (7.64 kB).
 
src/__pycache__/loss.cpython-311.pyc ADDED
Binary file (5.31 kB).
 
src/__pycache__/model.cpython-311.pyc ADDED
Binary file (22.3 kB).
 
src/dataset.py ADDED
@@ -0,0 +1,438 @@
1
+ """
2
+ PyTorch Dataset for immunogold particle detection.
3
+
4
+ Implements patch-based training with:
5
+ - 70% hard mining (patches centered near particles)
6
+ - 30% random patches (background recognition)
7
+ - Copy-paste augmentation with Gaussian-blended bead bank
8
+ - Albumentations pipeline with keypoint co-transforms
9
+ """
10
+
11
+ import random
12
+ from pathlib import Path
13
+ from typing import Dict, List, Optional, Tuple
14
+
15
+ import albumentations as A
16
+ import cv2
17
+ import numpy as np
18
+ import torch
19
+ from torch.utils.data import Dataset
20
+
21
+ from src.heatmap import generate_heatmap_gt
22
+ from src.preprocessing import (
23
+ SynapseRecord,
24
+ load_all_annotations,
25
+ load_image,
26
+ load_mask,
27
+ )
28
+
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # Augmentation pipeline
32
+ # ---------------------------------------------------------------------------
33
+
34
+ def get_train_augmentation() -> A.Compose:
35
+ """
36
+ Training augmentation pipeline.
37
+
38
+ Conservative intensity limits: contrast delta is only 11-39 units on uint8.
39
+ DO NOT use Cutout/Mixup/JPEG artifacts — they destroy or mimic particles.
40
+ """
41
+ return A.Compose(
42
+ [
43
+ # Geometric (co-transform keypoints)
44
+ A.RandomRotate90(p=1.0), # EM is rotation invariant
45
+ A.HorizontalFlip(p=0.5),
46
+ A.VerticalFlip(p=0.5),
47
+ # Only ±10° to avoid interpolation artifacts that destroy contrast
48
+ A.Rotate(
49
+ limit=10,
50
+ border_mode=cv2.BORDER_REFLECT_101,
51
+ p=0.5,
52
+ ),
53
+ # Mild elastic deformation (simulates section flatness variation)
54
+ A.ElasticTransform(alpha=30, sigma=5, p=0.3),
55
+ # Intensity (image only)
56
+ A.RandomBrightnessContrast(
57
+ brightness_limit=0.08, # NOT default 0.2
58
+ contrast_limit=0.08,
59
+ p=0.7,
60
+ ),
61
+ # EM shot noise simulation
62
+ A.GaussNoise(p=0.5),
63
+ # Mild blur — simulate slight defocus
64
+ A.GaussianBlur(blur_limit=(3, 3), p=0.2),
65
+ ],
66
+ keypoint_params=A.KeypointParams(
67
+ format="xy",
68
+ remove_invisible=True,
69
+ label_fields=["class_labels"],
70
+ ),
71
+ )
72
+
73
+
74
+ def get_val_augmentation() -> A.Compose:
75
+ """No augmentation for validation — identity transform."""
76
+ return A.Compose(
77
+ [],
78
+ keypoint_params=A.KeypointParams(
79
+ format="xy",
80
+ remove_invisible=True,
81
+ label_fields=["class_labels"],
82
+ ),
83
+ )
84
+
85
+
86
+ # ---------------------------------------------------------------------------
87
+ # Bead bank for copy-paste augmentation
88
+ # ---------------------------------------------------------------------------
89
+
90
+ class BeadBank:
91
+ """
92
+ Pre-extracted particle crops for copy-paste augmentation.
93
+
94
+ Stores small patches centered on annotated particles from training
95
+ images. During training, random beads are pasted onto patches to
96
+ increase particle density and address class imbalance.
97
+ """
98
+
99
+ def __init__(self):
100
+ self.crops: Dict[str, List[Tuple[np.ndarray, int]]] = {
101
+ "6nm": [],
102
+ "12nm": [],
103
+ }
104
+ self.crop_sizes = {"6nm": 32, "12nm": 48}
105
+
106
+ def extract_from_image(
107
+ self,
108
+ image: np.ndarray,
109
+ annotations: Dict[str, np.ndarray],
110
+ ):
111
+ """Extract bead crops from a training image."""
112
+ h, w = image.shape[:2]
113
+
114
+ for cls, coords in annotations.items():
115
+ crop_size = self.crop_sizes[cls]
116
+ half = crop_size // 2
117
+
118
+ for x, y in coords:
119
+ xi, yi = int(round(x)), int(round(y))
120
+ # Skip if too close to edge
121
+ if yi - half < 0 or yi + half > h or xi - half < 0 or xi + half > w:
122
+ continue
123
+
124
+ crop = image[yi - half : yi + half, xi - half : xi + half].copy()
125
+ if crop.shape == (crop_size, crop_size):
126
+ self.crops[cls].append((crop, half))
127
+
128
+ def paste_beads(
129
+ self,
130
+ image: np.ndarray,
131
+ coords_6nm: List[Tuple[float, float]],
132
+ coords_12nm: List[Tuple[float, float]],
133
+ class_labels: List[str],
134
+ mask: Optional[np.ndarray] = None,
135
+ n_paste_per_class: int = 5,
136
+ rng: Optional[np.random.Generator] = None,
137
+ ) -> Tuple[np.ndarray, List[Tuple[float, float]], List[Tuple[float, float]], List[str]]:
138
+ """
139
+ Paste random beads onto image with Gaussian alpha blending.
140
+
141
+ Returns augmented image and updated coordinate lists.
142
+ """
143
+ if rng is None:
144
+ rng = np.random.default_rng()
145
+
146
+ image = image.copy()
147
+ h, w = image.shape[:2]
148
+ new_coords_6nm = list(coords_6nm)
149
+ new_coords_12nm = list(coords_12nm)
150
+ new_labels = list(class_labels)
151
+
152
+ for cls in ["6nm", "12nm"]:
153
+ if not self.crops[cls]:
154
+ continue
155
+
156
+ crop_size = self.crop_sizes[cls]
157
+ half = crop_size // 2
158
+ n_paste = min(n_paste_per_class, len(self.crops[cls]))
159
+
160
+ for _ in range(n_paste):
161
+ # Random paste location (within image bounds)
162
+ px = rng.integers(half + 5, w - half - 5)
163
+ py = rng.integers(half + 5, h - half - 5)
164
+
165
+ # Skip if outside tissue mask
166
+ if mask is not None:
167
+ if py >= mask.shape[0] or px >= mask.shape[1] or not mask[py, px]:
168
+ continue
169
+
170
+ # Check minimum distance from existing particles (avoid overlap)
171
+ too_close = False
172
+ all_existing = new_coords_6nm + new_coords_12nm
173
+ for ex, ey in all_existing:
174
+ if (ex - px) ** 2 + (ey - py) ** 2 < (half * 1.5) ** 2:
175
+ too_close = True
176
+ break
177
+ if too_close:
178
+ continue
179
+
180
+ # Select random crop
181
+ crop, _ = self.crops[cls][rng.integers(len(self.crops[cls]))]
182
+
183
+ # Gaussian alpha mask for soft blending
184
+ yy, xx = np.mgrid[:crop_size, :crop_size]
185
+ center = crop_size / 2
186
+ sigma = half * 0.7
187
+ alpha = np.exp(-((xx - center) ** 2 + (yy - center) ** 2) / (2 * sigma ** 2))
188
+
189
+ # Blend
190
+ region = image[py - half : py + half, px - half : px + half]
191
+ if region.shape != crop.shape:
192
+ continue
193
+ blended = (alpha * crop + (1 - alpha) * region).astype(np.uint8)
194
+ image[py - half : py + half, px - half : px + half] = blended
195
+
196
+ # Add to annotations
197
+ if cls == "6nm":
198
+ new_coords_6nm.append((float(px), float(py)))
199
+ else:
200
+ new_coords_12nm.append((float(px), float(py)))
201
+ new_labels.append(cls)
202
+
203
+ return image, new_coords_6nm, new_coords_12nm, new_labels
204
+
205
+
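The Gaussian alpha blend in `paste_beads` can be sketched in isolation with NumPy: the pasted crop dominates at the patch center (alpha ≈ 1) and fades smoothly into the background toward the edges, avoiding hard paste seams that the model could learn as artifacts. The bead/background intensity values below are illustrative, not taken from real data:

```python
import numpy as np

def gaussian_alpha_blend(crop: np.ndarray, region: np.ndarray,
                         sigma_frac: float = 0.7) -> np.ndarray:
    """Blend `crop` onto `region` with a radial Gaussian alpha mask (soft edges)."""
    size = crop.shape[0]
    half = size // 2
    yy, xx = np.mgrid[:size, :size]
    center = size / 2
    sigma = half * sigma_frac
    alpha = np.exp(-((xx - center) ** 2 + (yy - center) ** 2) / (2 * sigma ** 2))
    return (alpha * crop + (1 - alpha) * region).astype(np.uint8)

bead = np.full((32, 32), 20, dtype=np.uint8)         # dark gold-bead crop
background = np.full((32, 32), 180, dtype=np.uint8)  # bright tissue patch
blended = gaussian_alpha_blend(bead, background)
```

At the center the blended pixel equals the bead value; at the corners it stays close to the background.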
206
+ # ---------------------------------------------------------------------------
207
+ # Dataset
208
+ # ---------------------------------------------------------------------------
209
+
210
+ class ImmunogoldDataset(Dataset):
211
+ """
212
+ Patch-based dataset for immunogold particle detection.
213
+
214
+ Sampling strategy:
215
+ - 70% of patches centered within 100px of a known particle (hard mining)
216
+ - 30% of patches at random locations (background recognition)
217
+
218
+ This ensures the model sees particles in nearly every batch despite
219
+ particles occupying <0.1% of image area.
220
+ """
221
+
222
+ def __init__(
223
+ self,
224
+ records: List[SynapseRecord],
225
+ fold_id: str,
226
+ mode: str = "train",
227
+ patch_size: int = 512,
228
+ stride: int = 2,
229
+ hard_mining_fraction: float = 0.7,
230
+ copy_paste_per_class: int = 5,
231
+ sigmas: Optional[Dict[str, float]] = None,
232
+ samples_per_epoch: int = 200,
233
+ seed: int = 42,
234
+ ):
235
+ """
236
+ Args:
237
+ records: all SynapseRecord entries
238
+ fold_id: synapse_id to hold out (test set)
239
+ mode: 'train' or 'val'
240
+ patch_size: training patch size
241
+ stride: model output stride
242
+ hard_mining_fraction: fraction of patches near particles
243
+ copy_paste_per_class: beads to paste per class
244
+ sigmas: heatmap Gaussian sigmas per class
245
+ samples_per_epoch: virtual epoch size
246
+ seed: random seed
247
+ """
248
+ super().__init__()
249
+ self.patch_size = patch_size
250
+ self.stride = stride
251
+ self.hard_mining_fraction = hard_mining_fraction
252
+ self.copy_paste_per_class = copy_paste_per_class if mode == "train" else 0
253
+ self.sigmas = sigmas or {"6nm": 1.0, "12nm": 1.5}
254
+ self.samples_per_epoch = samples_per_epoch
255
+ self.mode = mode
256
+ self._base_seed = seed
257
+ self.rng = np.random.default_rng(seed)
258
+
259
+ # Split records
260
+ if mode == "train":
261
+ self.records = [r for r in records if r.synapse_id != fold_id]
262
+ elif mode == "val":
263
+ self.records = [r for r in records if r.synapse_id == fold_id]
264
+ else:
265
+ self.records = records
266
+
267
+ # Pre-load all images and annotations into memory (~4MB each × 10 = 40MB)
268
+ self.images = {}
269
+ self.masks = {}
270
+ self.annotations = {}
271
+
272
+ for record in self.records:
273
+ sid = record.synapse_id
274
+ self.images[sid] = load_image(record.image_path)
275
+ if record.mask_path:
276
+ self.masks[sid] = load_mask(record.mask_path)
277
+ self.annotations[sid] = load_all_annotations(record, self.images[sid].shape)
278
+
279
+ # Build particle index for hard mining
280
+ self._build_particle_index()
281
+
282
+ # Build bead bank for copy-paste
283
+ self.bead_bank = BeadBank()
284
+ if mode == "train":
285
+ for sid in self.images:
286
+ self.bead_bank.extract_from_image(
287
+ self.images[sid], self.annotations[sid]
288
+ )
289
+
290
+ # Augmentation
291
+ if mode == "train":
292
+ self.transform = get_train_augmentation()
293
+ else:
294
+ self.transform = get_val_augmentation()
295
+
296
+ def _build_particle_index(self):
297
+ """Build flat index of all particles for hard mining."""
298
+ self.particle_list = [] # (synapse_id, x, y, class)
299
+ for sid, annots in self.annotations.items():
300
+ for cls in ["6nm", "12nm"]:
301
+ for x, y in annots[cls]:
302
+ self.particle_list.append((sid, x, y, cls))
303
+
304
+ @staticmethod
305
+ def worker_init_fn(worker_id: int):
306
+ """Re-seed RNG per DataLoader worker to avoid identical sequences."""
307
+ import torch
308
+ seed = (torch.initial_seed() + worker_id) % (2**32)
309
+ np.random.seed(seed)
310
+
311
+ def __len__(self) -> int:
312
+ return self.samples_per_epoch
313
+
314
+ def __getitem__(self, idx: int) -> dict:
315
+ """
316
+ Sample a patch with ground truth heatmap.
317
+
318
+ Returns dict with:
319
+ 'image': (1, patch_size, patch_size) float32 tensor
320
+ 'heatmap': (2, patch_size//stride, patch_size//stride) float32
321
+ 'offsets': (2, patch_size//stride, patch_size//stride) float32
322
+ 'offset_mask': (patch_size//stride, patch_size//stride) bool
323
+ 'conf_map': (2, patch_size//stride, patch_size//stride) float32
324
+ """
325
+ # Reseed the RNG from idx so each call produces a unique patch.
326
+ # Without this, the same 200 patches repeat every epoch → instant overfitting.
327
+ self.rng = np.random.default_rng(self._base_seed + idx + int(torch.initial_seed() % 100000))
328
+ # Decide: hard or random patch
329
+ do_hard = (self.rng.random() < self.hard_mining_fraction
330
+ and len(self.particle_list) > 0
331
+ and self.mode == "train")
332
+
333
+ if do_hard:
334
+ # Pick random particle, center patch on it with jitter
335
+ pidx = self.rng.integers(len(self.particle_list))
336
+ sid, px, py, _ = self.particle_list[pidx]
337
+ # Jitter center up to 128px
338
+ jitter = 128
339
+ cx = int(px + self.rng.integers(-jitter, jitter + 1))
340
+ cy = int(py + self.rng.integers(-jitter, jitter + 1))
341
+ else:
342
+ # Random image and location
343
+ sid = list(self.images.keys())[
344
+ self.rng.integers(len(self.images))
345
+ ]
346
+ h, w = self.images[sid].shape[:2]
347
+ cx = self.rng.integers(self.patch_size // 2, w - self.patch_size // 2)
348
+ cy = self.rng.integers(self.patch_size // 2, h - self.patch_size // 2)
349
+
350
+ # Extract patch
351
+ image = self.images[sid]
352
+ h, w = image.shape[:2]
353
+ half = self.patch_size // 2
354
+
355
+ # Clamp to image bounds
356
+ cx = max(half, min(w - half, cx))
357
+ cy = max(half, min(h - half, cy))
358
+
359
+ x0, x1 = cx - half, cx + half
360
+ y0, y1 = cy - half, cy + half
361
+
362
+ patch = image[y0:y1, x0:x1].copy()
363
+
364
+ # Pad if needed (edge cases)
365
+ if patch.shape[0] != self.patch_size or patch.shape[1] != self.patch_size:
366
+ padded = np.zeros((self.patch_size, self.patch_size), dtype=np.uint8)
367
+ ph, pw = patch.shape[:2]
368
+ padded[:ph, :pw] = patch
369
+ patch = padded
370
+
371
+ # Get annotations within this patch (convert to patch-local coordinates)
372
+ keypoints = []
373
+ class_labels = []
374
+ for cls in ["6nm", "12nm"]:
375
+ for ax, ay in self.annotations[sid][cls]:
376
+ # Convert to patch-local coords
377
+ lx = ax - x0
378
+ ly = ay - y0
379
+ if 0 <= lx < self.patch_size and 0 <= ly < self.patch_size:
380
+ keypoints.append((lx, ly))
381
+ class_labels.append(cls)
382
+
383
+ # Copy-paste augmentation (before geometric transforms)
384
+ if self.copy_paste_per_class > 0 and self.mode == "train":
385
+ local_6nm = [(x, y) for (x, y), c in zip(keypoints, class_labels) if c == "6nm"]
386
+ local_12nm = [(x, y) for (x, y), c in zip(keypoints, class_labels) if c == "12nm"]
387
+ mask_patch = None
388
+ if sid in self.masks:
389
+ mask_patch = self.masks[sid][y0:y1, x0:x1]
390
+
391
+ patch, local_6nm, local_12nm, class_labels = self.bead_bank.paste_beads(
392
+ patch, local_6nm, local_12nm, class_labels,
393
+ mask=mask_patch,
394
+ n_paste_per_class=self.copy_paste_per_class,
395
+ rng=self.rng,
396
+ )
397
+ # Rebuild keypoints from updated coords
398
+ keypoints = [(x, y) for x, y in local_6nm] + [(x, y) for x, y in local_12nm]
399
+ class_labels = ["6nm"] * len(local_6nm) + ["12nm"] * len(local_12nm)
400
+
401
+ # Apply augmentation (co-transforms keypoints)
402
+ transformed = self.transform(
403
+ image=patch,
404
+ keypoints=keypoints,
405
+ class_labels=class_labels,
406
+ )
407
+ patch_aug = transformed["image"]
408
+ kp_aug = transformed["keypoints"]
409
+ cl_aug = transformed["class_labels"]
410
+
411
+ # Separate keypoints by class
412
+ coords_6nm = np.array(
413
+ [(x, y) for (x, y), c in zip(kp_aug, cl_aug) if c == "6nm"],
414
+ dtype=np.float64,
415
+ ).reshape(-1, 2)
416
+ coords_12nm = np.array(
417
+ [(x, y) for (x, y), c in zip(kp_aug, cl_aug) if c == "12nm"],
418
+ dtype=np.float64,
419
+ ).reshape(-1, 2)
420
+
421
+ # Generate heatmap GT from TRANSFORMED coordinates (never warp heatmap)
422
+ heatmap, offsets, offset_mask, conf_map = generate_heatmap_gt(
423
+ coords_6nm, coords_12nm,
424
+ self.patch_size, self.patch_size,
425
+ sigmas=self.sigmas,
426
+ stride=self.stride,
427
+ )
428
+
429
+ # Convert to tensors
430
+ patch_tensor = torch.from_numpy(patch_aug).float().unsqueeze(0) / 255.0
431
+
432
+ return {
433
+ "image": patch_tensor,
434
+ "heatmap": torch.from_numpy(heatmap),
435
+ "offsets": torch.from_numpy(offsets),
436
+ "offset_mask": torch.from_numpy(offset_mask),
437
+ "conf_map": torch.from_numpy(conf_map),
438
+ }
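The global-to-patch-local coordinate conversion inside `__getitem__` boils down to a subtract-and-bounds-check; a minimal standalone version with illustrative values:

```python
# Keep only annotations that fall inside the patch window at (x0, y0),
# expressed in patch-local coordinates, as done in __getitem__.
x0, y0, patch_size = 100, 200, 512
annotations = [(150.0, 250.0), (900.0, 300.0)]  # (x, y) in full-image space
local = [
    (ax - x0, ay - y0)
    for ax, ay in annotations
    if 0 <= ax - x0 < patch_size and 0 <= ay - y0 < patch_size
]
```

Only the first annotation survives: the second lies outside the 512-pixel window in x.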
src/ensemble.py ADDED
@@ -0,0 +1,236 @@
1
+ """
2
+ Test-time augmentation (D4 dihedral group) and model ensemble averaging.
3
+
4
+ D4 TTA: 4 rotations x 2 reflections = 8 geometric views
5
+ + 2 intensity variants = 10 total forward passes.
6
+ Gold beads are rotationally invariant — D4 TTA is maximally effective.
7
+ Expected F1 gain: +1-3% over single forward pass.
8
+ """
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from typing import List, Optional
14
+
15
+ from src.model import ImmunogoldCenterNet
16
+
17
+
18
+ def d4_tta_predict(
19
+ model: ImmunogoldCenterNet,
20
+ image: np.ndarray,
21
+ device: torch.device = torch.device("cpu"),
22
+ ) -> tuple:
23
+ """
24
+ Test-time augmentation over D4 dihedral group + intensity variants.
25
+
26
+ Args:
27
+ model: trained CenterNet model
28
+ image: (H, W) uint8 preprocessed image
29
+ device: torch device
30
+
31
+ Returns:
32
+ averaged_heatmap: (2, H/2, W/2) numpy array
33
+ averaged_offsets: (2, H/2, W/2) numpy array
34
+ """
35
+ model.eval()
36
+ heatmaps = []
37
+ offsets_list = []
38
+
39
+ # Ensure image dimensions are divisible by 32 for the encoder
40
+ h, w = image.shape[:2]
41
+ pad_h = (32 - h % 32) % 32
42
+ pad_w = (32 - w % 32) % 32
43
+
44
+ def _forward(img_np):
45
+ """Run model on numpy image, return heatmap and offsets."""
46
+ # Pad to multiple of 32
47
+ if pad_h > 0 or pad_w > 0:
48
+ img_np = np.pad(img_np, ((0, pad_h), (0, pad_w)), mode="reflect")
49
+
50
+ tensor = (
51
+ torch.from_numpy(img_np)
52
+ .float()
53
+ .unsqueeze(0)
54
+ .unsqueeze(0) # (1, 1, H, W)
55
+ / 255.0
56
+ ).to(device)
57
+
58
+ with torch.no_grad():
59
+ hm, off = model(tensor)
60
+
61
+ hm = hm.squeeze(0).cpu().numpy() # (2, H/2, W/2)
62
+ off = off.squeeze(0).cpu().numpy() # (2, H/2, W/2)
63
+
64
+ # Remove padding from output
65
+ hm_h = h // 2
66
+ hm_w = w // 2
67
+ return hm[:, :hm_h, :hm_w], off[:, :hm_h, :hm_w]
68
+
69
+ # D4 group: 4 rotations x 2 reflections = 8 geometric views
70
+ for k in range(4):
71
+ for flip in [False, True]:
72
+ aug = np.rot90(image, k).copy()
73
+ if flip:
74
+ aug = np.fliplr(aug).copy()
75
+
76
+ hm, off = _forward(aug)
77
+
78
+ # Inverse transforms on heatmap and offsets
79
+ if flip:
80
+ hm = np.flip(hm, axis=2).copy() # flip W axis
81
+ off = np.flip(off, axis=2).copy()
82
+ off[0] = -off[0] # negate x offset for horizontal flip
83
+
84
+ if k > 0:
85
+ hm = np.rot90(hm, -k, axes=(1, 2)).copy()
86
+ off = np.rot90(off, -k, axes=(1, 2)).copy()
87
+ # Rotate offset vectors
88
+ if k == 1: # 90° CCW undo
89
+ off = np.stack([-off[1], off[0]], axis=0)
90
+ elif k == 2: # 180°
91
+ off = np.stack([-off[0], -off[1]], axis=0)
92
+ elif k == 3: # 270° CCW undo
93
+ off = np.stack([off[1], -off[0]], axis=0)
94
+
95
+ heatmaps.append(hm)
96
+ offsets_list.append(off)
97
+
98
+ # 2 intensity variants
99
+ for factor in [0.9, 1.1]:
100
+ aug = np.clip(image.astype(np.float32) * factor, 0, 255).astype(np.uint8)
101
+ hm, off = _forward(aug)
102
+ heatmaps.append(hm)
103
+ offsets_list.append(off)
104
+
105
+ # Average all views
106
+ avg_heatmap = np.mean(heatmaps, axis=0)
107
+ avg_offsets = np.mean(offsets_list, axis=0)
108
+
109
+ return avg_heatmap, avg_offsets
110
+
111
+
112
+ def ensemble_predict(
113
+ models: List[ImmunogoldCenterNet],
114
+ image: np.ndarray,
115
+ device: torch.device = torch.device("cpu"),
116
+ use_tta: bool = True,
117
+ ) -> tuple:
118
+ """
119
+ Ensemble prediction: average heatmaps from N models.
120
+
121
+ Args:
122
+ models: list of trained models (e.g., 5 seeds x 3 snapshots = 15)
123
+ image: (H, W) uint8 preprocessed image
124
+ device: torch device
125
+ use_tta: whether to apply D4 TTA per model
126
+
127
+ Returns:
128
+ averaged_heatmap: (2, H/2, W/2) numpy array
129
+ averaged_offsets: (2, H/2, W/2) numpy array
130
+ """
131
+ all_heatmaps = []
132
+ all_offsets = []
133
+
134
+ for model in models:
135
+ model.eval()
136
+ model.to(device)
137
+
138
+ if use_tta:
139
+ hm, off = d4_tta_predict(model, image, device)
140
+ else:
141
+ h, w = image.shape[:2]
142
+ pad_h = (32 - h % 32) % 32
143
+ pad_w = (32 - w % 32) % 32
144
+ img_padded = np.pad(image, ((0, pad_h), (0, pad_w)), mode="reflect")
145
+
146
+ tensor = (
147
+ torch.from_numpy(img_padded)
148
+ .float()
149
+ .unsqueeze(0)
150
+ .unsqueeze(0)
151
+ / 255.0
152
+ ).to(device)
153
+
154
+ with torch.no_grad():
155
+ hm_t, off_t = model(tensor)
156
+
157
+ hm = hm_t.squeeze(0).cpu().numpy()[:, : h // 2, : w // 2]
158
+ off = off_t.squeeze(0).cpu().numpy()[:, : h // 2, : w // 2]
159
+
160
+ all_heatmaps.append(hm)
161
+ all_offsets.append(off)
162
+
163
+ return np.mean(all_heatmaps, axis=0), np.mean(all_offsets, axis=0)
164
+
165
+
166
+ def sliding_window_inference(
167
+ model: ImmunogoldCenterNet,
168
+ image: np.ndarray,
169
+ patch_size: int = 512,
170
+ overlap: int = 128,
171
+ device: torch.device = torch.device("cpu"),
172
+ ) -> tuple:
173
+ """
174
+ Full-image inference via sliding window with overlap stitching.
175
+
176
+ Tiles the image into overlapping patches, runs the model on each,
177
+ and stitches heatmaps using max in overlap regions.
178
+
179
+ Args:
180
+ model: trained model
181
+ image: (H, W) uint8 preprocessed image
182
+ patch_size: tile size
183
+ overlap: overlap between tiles
184
+ device: torch device
185
+
186
+ Returns:
187
+ heatmap: (2, H/2, W/2) numpy array
188
+ offsets: (2, H/2, W/2) numpy array
189
+ """
190
+ model.eval()
191
+ h, w = image.shape[:2]
192
+ stride_step = patch_size - overlap
193
+
194
+ # Output dimensions at model stride
195
+ out_h = h // 2
196
+ out_w = w // 2
197
+ out_patch = patch_size // 2
198
+
199
+ heatmap = np.zeros((2, out_h, out_w), dtype=np.float32)
200
+ offsets = np.zeros((2, out_h, out_w), dtype=np.float32)
201
+ count = np.zeros((out_h, out_w), dtype=np.float32)
202
+
203
+ for y0 in range(0, h - patch_size + 1, stride_step):
204
+ for x0 in range(0, w - patch_size + 1, stride_step):
205
+ patch = image[y0 : y0 + patch_size, x0 : x0 + patch_size]
206
+ tensor = (
207
+ torch.from_numpy(patch)
208
+ .float()
209
+ .unsqueeze(0)
210
+ .unsqueeze(0)
211
+ / 255.0
212
+ ).to(device)
213
+
214
+ with torch.no_grad():
215
+ hm, off = model(tensor)
216
+
217
+ hm_np = hm.squeeze(0).cpu().numpy()
218
+ off_np = off.squeeze(0).cpu().numpy()
219
+
220
+ # Output coordinates
221
+ oy0 = y0 // 2
222
+ ox0 = x0 // 2
223
+
224
+ # Max-stitch heatmap, average-stitch offsets
225
+ heatmap[:, oy0 : oy0 + out_patch, ox0 : ox0 + out_patch] = np.maximum(
226
+ heatmap[:, oy0 : oy0 + out_patch, ox0 : ox0 + out_patch],
227
+ hm_np,
228
+ )
229
+ offsets[:, oy0 : oy0 + out_patch, ox0 : ox0 + out_patch] += off_np
230
+ count[oy0 : oy0 + out_patch, ox0 : ox0 + out_patch] += 1
231
+
232
+ # Average offsets where counted
233
+ count = np.maximum(count, 1)
234
+ offsets /= count[np.newaxis, :, :]
235
+
236
+ return heatmap, offsets
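The stitching rule in `sliding_window_inference` (max for heatmaps, average for offsets) can be shown on a toy canvas: a peak that falls in the overlap of two tiles is preserved by element-wise max rather than diluted by averaging. The tile values below are illustrative:

```python
import numpy as np

canvas = np.zeros((1, 4, 6), dtype=np.float32)
tile_a = np.full((1, 4, 4), 0.2, dtype=np.float32)
tile_a[0, 1, 1] = 0.9                       # a detection peak in tile A
tile_b = np.full((1, 4, 4), 0.3, dtype=np.float32)

# Tiles cover columns 0-3 and 2-5; overlap is columns 2-3.
canvas[:, :, 0:4] = np.maximum(canvas[:, :, 0:4], tile_a)
canvas[:, :, 2:6] = np.maximum(canvas[:, :, 2:6], tile_b)
```

The peak survives untouched, and the overlap keeps the larger of the two tile responses.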
src/evaluate.py ADDED
@@ -0,0 +1,203 @@
1
+ """
2
+ Evaluation: Hungarian matching, per-class metrics, LOOCV runner.
3
+
4
+ Uses scipy linear_sum_assignment for optimal bipartite matching between
5
+ predictions and ground truth with class-specific match radii.
6
+ """
7
+
8
+ import numpy as np
9
+ from scipy.optimize import linear_sum_assignment
10
+ from scipy.spatial.distance import cdist
11
+ from typing import Dict, List, Optional, Tuple
12
+
13
+
14
+ def compute_f1(tp: int, fp: int, fn: int, eps: float = 1e-6) -> Tuple[float, float, float]:
15
+ """Compute F1, precision, recall from TP/FP/FN counts."""
16
+ precision = tp / (tp + fp + eps)
17
+ recall = tp / (tp + fn + eps)
18
+ f1 = 2 * precision * recall / (precision + recall + eps)
19
+ return f1, precision, recall
20
+
21
+
22
+ def match_detections_to_gt(
23
+ detections: List[dict],
24
+ gt_coords_6nm: np.ndarray,
25
+ gt_coords_12nm: np.ndarray,
26
+ match_radii: Optional[Dict[str, float]] = None,
27
+ ) -> Dict[str, dict]:
28
+ """
29
+ Hungarian matching between predictions and ground truth.
30
+
31
+ A detection matches GT only if:
32
+ 1. Euclidean distance < match_radius[class]
33
+ 2. Predicted class == GT class
34
+
35
+ Args:
36
+ detections: list of {'x', 'y', 'class', 'conf'}
37
+ gt_coords_6nm: (N, 2) array of (x, y) GT for 6nm
38
+ gt_coords_12nm: (M, 2) array of (x, y) GT for 12nm
39
+ match_radii: per-class match radius in pixels
40
+
41
+ Returns:
42
+ Dict with per-class and overall TP/FP/FN/F1/precision/recall.
43
+ """
44
+ if match_radii is None:
45
+ match_radii = {"6nm": 9.0, "12nm": 15.0}
46
+
47
+ gt_by_class = {"6nm": gt_coords_6nm, "12nm": gt_coords_12nm}
48
+ results = {}
49
+
50
+ total_tp = 0
51
+ total_fp = 0
52
+ total_fn = 0
53
+
54
+ for cls in ["6nm", "12nm"]:
55
+ cls_dets = [d for d in detections if d["class"] == cls]
56
+ gt = gt_by_class[cls]
57
+ radius = match_radii[cls]
58
+
59
+ if len(cls_dets) == 0 and len(gt) == 0:
60
+ results[cls] = {
61
+ "tp": 0, "fp": 0, "fn": 0,
62
+ "f1": 1.0, "precision": 1.0, "recall": 1.0,
63
+ }
64
+ continue
65
+
66
+ if len(cls_dets) == 0:
67
+ results[cls] = {
68
+ "tp": 0, "fp": 0, "fn": len(gt),
69
+ "f1": 0.0, "precision": 0.0, "recall": 0.0,
70
+ }
71
+ total_fn += len(gt)
72
+ continue
73
+
74
+ if len(gt) == 0:
75
+ results[cls] = {
76
+ "tp": 0, "fp": len(cls_dets), "fn": 0,
77
+ "f1": 0.0, "precision": 0.0, "recall": 0.0,
78
+ }
79
+ total_fp += len(cls_dets)
80
+ continue
81
+
82
+ # Build cost matrix
83
+ pred_coords = np.array([[d["x"], d["y"]] for d in cls_dets])
84
+ cost = cdist(pred_coords, gt)
85
+
86
+ # Set costs beyond radius to a large value (forbid match)
87
+ cost_masked = cost.copy()
88
+ cost_masked[cost_masked > radius] = 1e6
89
+
90
+ # Hungarian matching
91
+ row_ind, col_ind = linear_sum_assignment(cost_masked)
92
+
93
+ # Count valid matches (within radius)
94
+ tp = sum(
95
+ 1 for r, c in zip(row_ind, col_ind)
96
+ if cost_masked[r, c] <= radius
97
+ )
98
+ fp = len(cls_dets) - tp
99
+ fn = len(gt) - tp
100
+
101
+ f1, prec, rec = compute_f1(tp, fp, fn)
102
+
103
+ results[cls] = {
104
+ "tp": tp, "fp": fp, "fn": fn,
105
+ "f1": f1, "precision": prec, "recall": rec,
106
+ }
107
+
108
+ total_tp += tp
109
+ total_fp += fp
110
+ total_fn += fn
111
+
112
+ # Overall
113
+ f1_overall, prec_overall, rec_overall = compute_f1(total_tp, total_fp, total_fn)
114
+ results["overall"] = {
115
+ "tp": total_tp, "fp": total_fp, "fn": total_fn,
116
+ "f1": f1_overall, "precision": prec_overall, "recall": rec_overall,
117
+ }
118
+
119
+ # Mean F1 across classes
120
+ class_f1s = [results[c]["f1"] for c in ["6nm", "12nm"] if results[c]["fn"] + results[c]["tp"] > 0]
121
+ results["mean_f1"] = np.mean(class_f1s) if class_f1s else 0.0
122
+
123
+ return results
124
+
125
+
126
+ def evaluate_fold(
127
+ detections: List[dict],
128
+ gt_annotations: Dict[str, np.ndarray],
129
+ match_radii: Optional[Dict[str, float]] = None,
130
+ has_6nm: bool = True,
131
+ ) -> Dict[str, dict]:
132
+ """
133
+ Evaluate detections for a single LOOCV fold.
134
+
135
+ Args:
136
+ detections: model predictions
137
+ gt_annotations: {'6nm': Nx2, '12nm': Mx2}
138
+ match_radii: per-class match radii
139
+ has_6nm: whether this fold has 6nm GT (False for S7, S15)
140
+
141
+ Returns:
142
+ Evaluation metrics dict.
143
+ """
144
+ gt_6nm = gt_annotations.get("6nm", np.empty((0, 2)))
145
+ gt_12nm = gt_annotations.get("12nm", np.empty((0, 2)))
146
+
147
+ results = match_detections_to_gt(detections, gt_6nm, gt_12nm, match_radii)
148
+
149
+ if not has_6nm:
150
+ results["6nm"]["note"] = "N/A (missing annotations)"
151
+
152
+ return results
153
+
154
+
155
+ def compute_average_precision(
156
+ detections: List[dict],
157
+ gt_coords: np.ndarray,
158
+ match_radius: float,
159
+ ) -> float:
160
+ """
161
+ Compute Average Precision (AP) for a single class.
162
+
163
+ Follows PASCAL VOC style: sort by confidence, compute precision-recall
164
+ curve, then compute area under curve.
165
+ """
166
+ if len(gt_coords) == 0:
167
+ return 0.0 if detections else 1.0
168
+
169
+ # Sort by confidence descending
170
+ sorted_dets = sorted(detections, key=lambda d: d["conf"], reverse=True)
171
+
172
+ tp_list = []
173
+ fp_list = []
174
+ matched_gt = set()
175
+
176
+ for det in sorted_dets:
177
+ det_coord = np.array([det["x"], det["y"]])
178
+ dists = np.sqrt(np.sum((gt_coords - det_coord) ** 2, axis=1))
179
+ min_idx = np.argmin(dists)
180
+
181
+ if dists[min_idx] <= match_radius and min_idx not in matched_gt:
182
+ tp_list.append(1)
183
+ fp_list.append(0)
184
+ matched_gt.add(min_idx)
185
+ else:
186
+ tp_list.append(0)
187
+ fp_list.append(1)
188
+
189
+ tp_cumsum = np.cumsum(tp_list)
190
+ fp_cumsum = np.cumsum(fp_list)
191
+
192
+ precision = tp_cumsum / (tp_cumsum + fp_cumsum)
193
+ recall = tp_cumsum / len(gt_coords)
194
+
195
+ # Compute AP using all-point interpolation
196
+ ap = 0.0
197
+ for i in range(len(precision)):
198
+ if i == 0:
199
+ ap += precision[i] * recall[i]
200
+ else:
201
+ ap += precision[i] * (recall[i] - recall[i - 1])
202
+
203
+ return ap
src/heatmap.py ADDED
@@ -0,0 +1,187 @@
1
+ """
2
+ Ground truth heatmap generation and peak extraction for CenterNet.
3
+
4
+ Generates Gaussian-splat heatmaps at stride-2 resolution with
5
+ class-specific sigma values calibrated to bead size.
6
+ """
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from typing import Dict, List, Tuple, Optional
12
+
13
+ # Class index mapping
14
+ CLASS_IDX = {"6nm": 0, "12nm": 1}
15
+ CLASS_NAMES = ["6nm", "12nm"]
16
+ STRIDE = 2
17
+
18
+
19
+ def generate_heatmap_gt(
+     coords_6nm: np.ndarray,
+     coords_12nm: np.ndarray,
+     image_h: int,
+     image_w: int,
+     sigmas: Optional[Dict[str, float]] = None,
+     stride: int = STRIDE,
+     confidence_6nm: Optional[np.ndarray] = None,
+     confidence_12nm: Optional[np.ndarray] = None,
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+     """
+     Generate CenterNet ground truth heatmaps and offset maps.
+
+     Args:
+         coords_6nm: (N, 2) array of (x, y) in ORIGINAL pixel space
+         coords_12nm: (M, 2) array of (x, y) in ORIGINAL pixel space
+         image_h: original image height
+         image_w: original image width
+         sigmas: per-class Gaussian sigma in feature space
+         stride: output stride (default 2)
+         confidence_6nm: optional per-particle confidence weights
+         confidence_12nm: optional per-particle confidence weights
+
+     Returns:
+         heatmap: (2, H//stride, W//stride) float32 in [0, 1]
+         offsets: (2, H//stride, W//stride) float32 sub-pixel offsets
+         offset_mask: (H//stride, W//stride) bool — True at particle centers
+         conf_map: (2, H//stride, W//stride) float32 confidence weights
+     """
+     if sigmas is None:
+         sigmas = {"6nm": 1.0, "12nm": 1.5}
+
+     h_feat = image_h // stride
+     w_feat = image_w // stride
+
+     heatmap = np.zeros((2, h_feat, w_feat), dtype=np.float32)
+     offsets = np.zeros((2, h_feat, w_feat), dtype=np.float32)
+     offset_mask = np.zeros((h_feat, w_feat), dtype=bool)
+     conf_map = np.ones((2, h_feat, w_feat), dtype=np.float32)
+
+     # Prepare coordinate lists with class labels and confidences
+     all_entries = []
+     if len(coords_6nm) > 0:
+         confs = confidence_6nm if confidence_6nm is not None else np.ones(len(coords_6nm))
+         for i, (x, y) in enumerate(coords_6nm):
+             all_entries.append((x, y, "6nm", confs[i]))
+     if len(coords_12nm) > 0:
+         confs = confidence_12nm if confidence_12nm is not None else np.ones(len(coords_12nm))
+         for i, (x, y) in enumerate(coords_12nm):
+             all_entries.append((x, y, "12nm", confs[i]))
+
+     for x, y, cls, conf in all_entries:
+         cidx = CLASS_IDX[cls]
+         sigma = sigmas[cls]
+
+         # Feature-space center (float)
+         cx_f = x / stride
+         cy_f = y / stride
+
+         # Integer grid center
+         cx_i = int(round(cx_f))
+         cy_i = int(round(cy_f))
+
+         # Sub-pixel offset
+         off_x = cx_f - cx_i
+         off_y = cy_f - cy_i
+
+         # Gaussian radius: truncate at 3 sigma
+         r = max(int(3 * sigma + 1), 2)
+
+         # Bounds-clipped grid
+         y0 = max(0, cy_i - r)
+         y1 = min(h_feat, cy_i + r + 1)
+         x0 = max(0, cx_i - r)
+         x1 = min(w_feat, cx_i + r + 1)
+
+         if y0 >= y1 or x0 >= x1:
+             continue
+
+         yy, xx = np.meshgrid(
+             np.arange(y0, y1),
+             np.arange(x0, x1),
+             indexing="ij",
+         )
+
+         # Gaussian centered at INTEGER center (not float)
+         # The integer center MUST be exactly 1.0 — the CornerNet focal loss
+         # uses pos_mask = (gt == 1.0) and treats everything else as negative.
+         # Centering the Gaussian at the float position produces peaks of 0.78-0.93
+         # which the loss sees as negatives → zero positive training signal.
+         gaussian = np.exp(
+             -((xx - cx_i) ** 2 + (yy - cy_i) ** 2) / (2 * sigma ** 2)
+         )
+
+         # Scale by confidence (for pseudo-label weighting)
+         gaussian = gaussian * conf
+
+         # Element-wise max (handles overlapping particles correctly)
+         heatmap[cidx, y0:y1, x0:x1] = np.maximum(
+             heatmap[cidx, y0:y1, x0:x1], gaussian
+         )
+
+         # Offset and confidence only at the integer center pixel
+         if 0 <= cy_i < h_feat and 0 <= cx_i < w_feat:
+             offsets[0, cy_i, cx_i] = off_x
+             offsets[1, cy_i, cx_i] = off_y
+             offset_mask[cy_i, cx_i] = True
+             conf_map[cidx, cy_i, cx_i] = conf
+
+     return heatmap, offsets, offset_mask, conf_map
+
+
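The unit-peak requirement called out in the comments is easy to verify in isolation. A standalone NumPy sketch (not the project's API) of the stamp-and-max step, with hypothetical centers and sigma:

```python
import numpy as np

def place_gaussian(hm, cx, cy, sigma):
    """Stamp a unit-peak Gaussian onto heatmap hm at integer center (cx, cy)."""
    h, w = hm.shape
    r = max(int(3 * sigma + 1), 2)
    y0, y1 = max(0, cy - r), min(h, cy + r + 1)
    x0, x1 = max(0, cx - r), min(w, cx + r + 1)
    yy, xx = np.meshgrid(np.arange(y0, y1), np.arange(x0, x1), indexing="ij")
    g = np.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2 * sigma ** 2))
    # max-combine so overlapping particles both keep exact 1.0 peaks
    hm[y0:y1, x0:x1] = np.maximum(hm[y0:y1, x0:x1], g)
    return hm

hm = np.zeros((32, 32), dtype=np.float32)
place_gaussian(hm, 10, 10, sigma=1.5)
place_gaussian(hm, 12, 10, sigma=1.5)  # overlapping neighbor two cells away
```

Both stamps retain an exact peak of 1.0 at their integer centers, which is what `pos_mask = (gt == 1)` in the focal loss depends on.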
+ def extract_peaks(
+     heatmap: torch.Tensor,
+     offset_map: torch.Tensor,
+     stride: int = STRIDE,
+     conf_threshold: float = 0.3,
+     nms_kernel_sizes: Optional[Dict[str, int]] = None,
+ ) -> List[dict]:
+     """
+     Extract detections from predicted heatmap via max-pool NMS.
+
+     Args:
+         heatmap: (2, H/stride, W/stride) sigmoid-activated
+         offset_map: (2, H/stride, W/stride) raw offset predictions
+         stride: output stride
+         conf_threshold: minimum confidence to keep
+         nms_kernel_sizes: per-class NMS kernel sizes
+
+     Returns:
+         List of {'x': float, 'y': float, 'class': str, 'conf': float}
+     """
+     if nms_kernel_sizes is None:
+         nms_kernel_sizes = {"6nm": 3, "12nm": 5}
+
+     detections = []
+
+     for cls_idx, cls_name in enumerate(CLASS_NAMES):
+         hm_cls = heatmap[cls_idx].unsqueeze(0).unsqueeze(0)  # (1,1,H,W)
+         kernel = nms_kernel_sizes[cls_name]
+
+         # Max-pool NMS
+         hmax = F.max_pool2d(
+             hm_cls, kernel_size=kernel, stride=1, padding=kernel // 2
+         )
+         peaks = (hmax.squeeze() == heatmap[cls_idx]) & (
+             heatmap[cls_idx] > conf_threshold
+         )
+
+         ys, xs = torch.where(peaks)
+         for y_idx, x_idx in zip(ys, xs):
+             y_i = y_idx.item()
+             x_i = x_idx.item()
+             conf = heatmap[cls_idx, y_i, x_i].item()
+             dx = offset_map[0, y_i, x_i].item()
+             dy = offset_map[1, y_i, x_i].item()
+
+             # Back to input space with sub-pixel offset
+             det_x = (x_i + dx) * stride
+             det_y = (y_i + dy) * stride
+
+             detections.append({
+                 "x": det_x,
+                 "y": det_y,
+                 "class": cls_name,
+                 "conf": conf,
+             })
+
+     return detections
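The final back-projection is plain arithmetic. A pure-Python check with hypothetical peak values:

```python
STRIDE = 2

def decode(x_idx, y_idx, dx, dy, stride=STRIDE):
    """Feature-grid index plus predicted sub-pixel offset -> input-space coords."""
    return (x_idx + dx) * stride, (y_idx + dy) * stride

# a peak at grid cell (x=50, y=40) with predicted offsets (+0.25, -0.5)
x, y = decode(50, 40, 0.25, -0.5)
```

At stride 2 the offsets can correct up to one input pixel per axis, which is why the offset head matters for sub-pixel localization of small beads.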
src/loss.py ADDED
@@ -0,0 +1,137 @@
+ """
+ Loss functions for CenterNet immunogold detection.
+
+ Implements CornerNet penalty-reduced focal loss for sparse heatmaps
+ and smooth L1 offset regression loss.
+ """
+
+ import torch
+ import torch.nn.functional as F
+
+
+ def cornernet_focal_loss(
+     pred: torch.Tensor,
+     gt: torch.Tensor,
+     alpha: int = 2,
+     beta: int = 4,
+     conf_weights: torch.Tensor = None,
+     eps: float = 1e-6,
+ ) -> torch.Tensor:
+     """
+     CornerNet penalty-reduced focal loss for sparse heatmaps.
+
+     The positive:negative pixel ratio is ~1:23,000 per channel.
+     Standard BCE would learn to predict all zeros. This loss
+     penalizes confident wrong predictions and down-weights easy,
+     already-correct ones via the (1-p)^alpha and p^alpha terms.
+
+     Args:
+         pred: (B, C, H, W) sigmoid-activated predictions in [0, 1]
+         gt: (B, C, H, W) Gaussian heatmap targets in [0, 1]
+         alpha: focal exponent for prediction confidence (default 2)
+         beta: penalty reduction exponent near GT peaks (default 4)
+         conf_weights: optional (B, C, H, W) per-pixel confidence weights
+             for pseudo-label weighting
+         eps: numerical stability
+
+     Returns:
+         Scalar loss, normalized by number of positive locations.
+     """
+     pos_mask = (gt == 1).float()
+     neg_mask = (gt < 1).float()
+
+     # Penalty reduction: pixels near particle centers get lower negative penalty
+     # (1 - gt)^beta → 0 near peaks, → 1 far from peaks
+     neg_weights = torch.pow(1 - gt, beta)
+
+     # Positive loss: encourage high confidence at GT peaks
+     pos_loss = torch.log(pred.clamp(min=eps)) * torch.pow(1 - pred, alpha) * pos_mask
+
+     # Negative loss: penalize high confidence away from GT peaks
+     neg_loss = (
+         torch.log((1 - pred).clamp(min=eps))
+         * torch.pow(pred, alpha)
+         * neg_weights
+         * neg_mask
+     )
+
+     # Apply confidence weighting if provided (for pseudo-label support)
+     if conf_weights is not None:
+         pos_loss = pos_loss * conf_weights
+         # Negative loss near pseudo-labels also scaled
+         neg_loss = neg_loss * conf_weights
+
+     num_pos = pos_mask.sum().clamp(min=1)
+     loss = -(pos_loss.sum() + neg_loss.sum()) / num_pos
+
+     return loss
+
+
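To see the penalty reduction at work, the same formula can be evaluated on a toy heatmap in NumPy (a standalone sketch, not the training code): a confident false positive far from any peak should cost far more than the same score on a Gaussian shoulder.

```python
import numpy as np

def focal_np(pred, gt, alpha=2, beta=4, eps=1e-6):
    """NumPy transcription of the CornerNet penalty-reduced focal loss."""
    pos = (gt == 1.0)
    neg_w = (1 - gt) ** beta  # penalty reduction: small near peaks
    pos_loss = np.log(np.clip(pred, eps, None)) * (1 - pred) ** alpha * pos
    neg_loss = np.log(np.clip(1 - pred, eps, None)) * pred ** alpha * neg_w * ~pos
    return -(pos_loss.sum() + neg_loss.sum()) / max(pos.sum(), 1)

gt = np.zeros((8, 8)); gt[4, 4] = 1.0; gt[4, 5] = 0.6   # 0.6 = Gaussian shoulder
pred_far = np.zeros((8, 8)); pred_far[0, 0] = 0.9        # confident FP far from peak
pred_near = np.zeros((8, 8)); pred_near[4, 5] = 0.9      # same score on the shoulder
```

The shoulder pixel carries a negative weight of (1 - 0.6)^4 = 0.0256, so the near-peak mistake is penalized roughly 40x less than the distant one.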
+ def offset_loss(
+     pred_offsets: torch.Tensor,
+     gt_offsets: torch.Tensor,
+     mask: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     Smooth L1 loss on sub-pixel offsets at annotated particle locations only.
+
+     Args:
+         pred_offsets: (B, 2, H, W) predicted offsets
+         gt_offsets: (B, 2, H, W) ground truth offsets
+         mask: (B, H, W) boolean — True at particle integer centers
+
+     Returns:
+         Scalar loss.
+     """
+     # Expand mask to match offset dimensions
+     mask_expanded = mask.unsqueeze(1).expand_as(pred_offsets)
+
+     if mask_expanded.sum() == 0:
+         return torch.tensor(0.0, device=pred_offsets.device, requires_grad=True)
+
+     loss = F.smooth_l1_loss(
+         pred_offsets[mask_expanded],
+         gt_offsets[mask_expanded],
+         reduction="mean",
+     )
+     return loss
+
+
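The masked smooth-L1 behavior can be illustrated without torch. This NumPy sketch (assumed toy values) shows that errors at unannotated pixels never contribute to the loss:

```python
import numpy as np

def smooth_l1(diff):
    """Elementwise smooth L1: 0.5*d^2 for |d| < 1, |d| - 0.5 otherwise."""
    a = np.abs(diff)
    return np.where(a < 1, 0.5 * a ** 2, a - 0.5)

pred = np.array([[0.3, -0.1], [2.0, 0.0]])
gt = np.array([[0.25, 0.0], [0.0, 0.0]])
mask = np.array([[True, True], [False, False]])  # only annotated centers count

loss = smooth_l1(pred[mask] - gt[mask]).mean()
```

The large 2.0 error in the bottom row is masked out, so the mean is taken over only the two annotated entries.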
+ def total_loss(
+     heatmap_pred: torch.Tensor,
+     heatmap_gt: torch.Tensor,
+     offset_pred: torch.Tensor,
+     offset_gt: torch.Tensor,
+     offset_mask: torch.Tensor,
+     lambda_offset: float = 1.0,
+     focal_alpha: int = 2,
+     focal_beta: int = 4,
+     conf_weights: torch.Tensor = None,
+ ) -> tuple:
+     """
+     Combined heatmap focal loss + offset regression loss.
+
+     Args:
+         heatmap_pred: (B, 2, H, W) sigmoid predictions
+         heatmap_gt: (B, 2, H, W) Gaussian GT
+         offset_pred: (B, 2, H, W) predicted offsets
+         offset_gt: (B, 2, H, W) GT offsets
+         offset_mask: (B, H, W) boolean mask
+         lambda_offset: weight for offset loss (default 1.0)
+         focal_alpha: focal loss alpha
+         focal_beta: focal loss beta
+         conf_weights: optional per-pixel confidence weights
+
+     Returns:
+         (total_loss, heatmap_loss_value, offset_loss_value)
+     """
+     l_hm = cornernet_focal_loss(
+         heatmap_pred, heatmap_gt,
+         alpha=focal_alpha, beta=focal_beta,
+         conf_weights=conf_weights,
+     )
+     l_off = offset_loss(offset_pred, offset_gt, offset_mask)
+
+     total = l_hm + lambda_offset * l_off
+
+     return total, l_hm.item(), l_off.item()
src/model.py ADDED
@@ -0,0 +1,382 @@
+ """
+ CenterNet with CEM500K-pretrained ResNet-50 backbone for immunogold detection.
+
+ Architecture:
+     Input: 1ch grayscale, variable size (padded to multiple of 32)
+     Encoder: CEM500K ResNet-50 (pretrained), conv1 adapted for 1ch input
+     Neck: BiFPN (2 rounds, 128ch)
+     Decoder: Transposed conv → stride-2 output
+     Heads: Heatmap (2ch sigmoid), Offset (2ch)
+     Output: Stride-2 maps → (H/2, W/2) resolution
+
+ Output stride is 2, NOT 4 or 8. At stride 4, a 6nm bead (4-6px radius)
+ collapses to 1px in feature space — insufficient for detection.
+ At stride 2, the same bead occupies 2-3px, enough for Gaussian peak extraction.
+ """
+
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as models
+ from typing import List, Optional
+
+
+ # ---------------------------------------------------------------------------
+ # BiFPN: Bidirectional Feature Pyramid Network
+ # ---------------------------------------------------------------------------
+
+ class DepthwiseSeparableConv(nn.Module):
+     """Depthwise separable convolution as used in BiFPN."""
+
+     def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3,
+                  stride: int = 1, padding: int = 1):
+         super().__init__()
+         self.depthwise = nn.Conv2d(
+             in_ch, in_ch, kernel_size, stride=stride,
+             padding=padding, groups=in_ch, bias=False,
+         )
+         self.pointwise = nn.Conv2d(in_ch, out_ch, 1, bias=False)
+         self.bn = nn.BatchNorm2d(out_ch)
+         self.act = nn.ReLU(inplace=True)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.act(self.bn(self.pointwise(self.depthwise(x))))
+
+
+ class BiFPNFusionNode(nn.Module):
+     """
+     Single BiFPN fusion node with fast normalized weighted fusion.
+
+         w_normalized = relu(w) / (sum(relu(w)) + eps)
+         output = conv(sum(w_i * input_i))
+     """
+
+     def __init__(self, channels: int, n_inputs: int = 2, eps: float = 1e-4):
+         super().__init__()
+         self.eps = eps
+         # Learnable fusion weights
+         self.weights = nn.Parameter(torch.ones(n_inputs, dtype=torch.float32))
+         self.conv = DepthwiseSeparableConv(channels, channels)
+
+     def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
+         # Fast normalized fusion
+         w = F.relu(self.weights)
+         w_norm = w / (w.sum() + self.eps)
+
+         fused = sum(w_i * inp for w_i, inp in zip(w_norm, inputs))
+         return self.conv(fused)
+
+
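The fast normalized fusion rule in the docstring can be checked numerically. A NumPy sketch with assumed weights (standalone, no torch):

```python
import numpy as np

def fast_normalized_fusion(inputs, w, eps=1e-4):
    """Softmax-free weighted fusion: relu(w) / (sum(relu(w)) + eps)."""
    w = np.maximum(w, 0.0)        # relu keeps weights non-negative
    w = w / (w.sum() + eps)       # weights now sum to (almost) 1
    return sum(wi * x for wi, x in zip(w, inputs))

a = np.full((4, 4), 2.0)
b = np.full((4, 4), 0.0)
out = fast_normalized_fusion([a, b], np.array([3.0, 1.0]))
```

With weights 3 and 1, input `a` contributes three quarters of the output, so each fused value is close to 1.5. The eps term only guards against a zero denominator when all weights relu to 0.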
+ class BiFPNLayer(nn.Module):
+     """
+     One round of BiFPN: top-down + bottom-up bidirectional fusion.
+
+     Input levels: P2 (stride 4), P3 (stride 8), P4 (stride 16), P5 (stride 32)
+     """
+
+     def __init__(self, channels: int):
+         super().__init__()
+         # Top-down fusion nodes (P5 → P4_td, P4_td+P3 → P3_td, P3_td+P2 → P2_td)
+         self.td_p4 = BiFPNFusionNode(channels, n_inputs=2)
+         self.td_p3 = BiFPNFusionNode(channels, n_inputs=2)
+         self.td_p2 = BiFPNFusionNode(channels, n_inputs=2)
+
+         # Bottom-up fusion nodes (combine top-down outputs with original)
+         self.bu_p3 = BiFPNFusionNode(channels, n_inputs=3)  # p3_orig + p3_td + p2_out
+         self.bu_p4 = BiFPNFusionNode(channels, n_inputs=3)  # p4_orig + p4_td + p3_out
+         self.bu_p5 = BiFPNFusionNode(channels, n_inputs=2)  # p5_orig + p4_out
+
+     def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
+         """
+         Args:
+             features: [P2, P3, P4, P5] at channels ch, with decreasing spatial dims
+
+         Returns:
+             [P2_out, P3_out, P4_out, P5_out]
+         """
+         p2, p3, p4, p5 = features
+
+         # --- Top-down pathway ---
+         # P5 → upscale → fuse with P4
+         p5_up = F.interpolate(p5, size=p4.shape[2:], mode="nearest")
+         p4_td = self.td_p4([p4, p5_up])
+
+         # P4_td → upscale → fuse with P3
+         p4_td_up = F.interpolate(p4_td, size=p3.shape[2:], mode="nearest")
+         p3_td = self.td_p3([p3, p4_td_up])
+
+         # P3_td → upscale → fuse with P2
+         p3_td_up = F.interpolate(p3_td, size=p2.shape[2:], mode="nearest")
+         p2_td = self.td_p2([p2, p3_td_up])
+
+         # --- Bottom-up pathway ---
+         p2_out = p2_td
+
+         # P2_out → downsample → fuse with P3_td and P3_orig
+         p2_down = F.max_pool2d(p2_out, kernel_size=2)
+         p3_out = self.bu_p3([p3, p3_td, p2_down])
+
+         # P3_out → downsample → fuse with P4_td and P4_orig
+         p3_down = F.max_pool2d(p3_out, kernel_size=2)
+         p4_out = self.bu_p4([p4, p4_td, p3_down])
+
+         # P4_out → downsample → fuse with P5_orig
+         p4_down = F.max_pool2d(p4_out, kernel_size=2)
+         p5_out = self.bu_p5([p5, p4_down])
+
+         return [p2_out, p3_out, p4_out, p5_out]
+
+
+ class BiFPN(nn.Module):
+     """Multi-round BiFPN with lateral projections."""
+
+     def __init__(self, in_channels: List[int], out_channels: int = 128,
+                  num_rounds: int = 2):
+         super().__init__()
+         # Lateral 1x1 projections to unify channel count
+         self.laterals = nn.ModuleList([
+             nn.Sequential(
+                 nn.Conv2d(in_ch, out_channels, 1, bias=False),
+                 nn.BatchNorm2d(out_channels),
+                 nn.ReLU(inplace=True),
+             )
+             for in_ch in in_channels
+         ])
+
+         # BiFPN rounds
+         self.rounds = nn.ModuleList([
+             BiFPNLayer(out_channels) for _ in range(num_rounds)
+         ])
+
+     def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
+         # Project to uniform channels
+         projected = [lat(feat) for lat, feat in zip(self.laterals, features)]
+
+         # Run BiFPN rounds
+         for bifpn_round in self.rounds:
+             projected = bifpn_round(projected)
+
+         return projected
+
+
+ # ---------------------------------------------------------------------------
+ # Detection Heads
+ # ---------------------------------------------------------------------------
+
+ class HeatmapHead(nn.Module):
+     """Heatmap prediction head at stride-2 resolution."""
+
+     def __init__(self, in_channels: int = 64, num_classes: int = 2):
+         super().__init__()
+         self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = nn.Conv2d(64, num_classes, kernel_size=1)
+
+         # Initialize final conv bias for focal loss: -log((1-pi)/pi) where pi=0.01
+         # This prevents the network from producing a high false-positive rate early
+         nn.init.constant_(self.conv2.bias, -math.log((1 - 0.01) / 0.01))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.relu(self.bn1(self.conv1(x)))
+         return torch.sigmoid(self.conv2(x))
+
+
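The bias initialization deserves a quick sanity check: with a foreground prior pi = 0.01, the chosen constant makes the initial sigmoid output equal to pi, so the freshly initialized head predicts "background" almost everywhere. Pure stdlib math:

```python
import math

pi = 0.01                           # desired initial foreground probability
bias = -math.log((1 - pi) / pi)     # the constant used for conv2.bias
p0 = 1.0 / (1.0 + math.exp(-bias))  # sigmoid(bias) = initial heatmap value
```

Since sigmoid(-log((1-pi)/pi)) = pi exactly, the network starts near the true positive rate instead of at 0.5, which keeps the early negative focal loss small.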
+ class OffsetHead(nn.Module):
+     """Sub-pixel offset regression head."""
+
+     def __init__(self, in_channels: int = 64):
+         super().__init__()
+         self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = nn.Conv2d(64, 2, kernel_size=1)  # dx, dy
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.relu(self.bn1(self.conv1(x)))
+         return self.conv2(x)
+
+
+ # ---------------------------------------------------------------------------
+ # Full CenterNet Model
+ # ---------------------------------------------------------------------------
+
+ class ImmunogoldCenterNet(nn.Module):
+     """
+     CenterNet with CEM500K-pretrained ResNet-50 backbone.
+
+     Detects 6nm and 12nm immunogold particles at stride-2 resolution.
+     """
+
+     def __init__(
+         self,
+         pretrained_path: Optional[str] = None,
+         bifpn_channels: int = 128,
+         bifpn_rounds: int = 2,
+         num_classes: int = 2,
+     ):
+         super().__init__()
+         self.num_classes = num_classes
+
+         # --- Encoder: ResNet-50 ---
+         backbone = models.resnet50(weights=None)
+         # Adapt conv1 for 1-channel grayscale input
+         backbone.conv1 = nn.Conv2d(
+             1, 64, kernel_size=7, stride=2, padding=3, bias=False,
+         )
+
+         # Load pretrained weights
+         if pretrained_path:
+             self._load_pretrained(backbone, pretrained_path)
+         else:
+             # Use ImageNet weights as fallback, adapting conv1
+             imagenet_backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
+             state = imagenet_backbone.state_dict()
+             # Mean-pool RGB conv1 weights → grayscale
+             state["conv1.weight"] = state["conv1.weight"].mean(dim=1, keepdim=True)
+             backbone.load_state_dict(state, strict=False)
+
+         # Extract encoder stages
+         self.stem = nn.Sequential(
+             backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
+         )
+         self.layer1 = backbone.layer1  # 256ch, stride 4
+         self.layer2 = backbone.layer2  # 512ch, stride 8
+         self.layer3 = backbone.layer3  # 1024ch, stride 16
+         self.layer4 = backbone.layer4  # 2048ch, stride 32
+
+         # --- BiFPN Neck ---
+         self.bifpn = BiFPN(
+             in_channels=[256, 512, 1024, 2048],
+             out_channels=bifpn_channels,
+             num_rounds=bifpn_rounds,
+         )
+
+         # --- Decoder: upsample P2 (stride 4) → stride 2 ---
+         self.upsample = nn.Sequential(
+             nn.ConvTranspose2d(
+                 bifpn_channels, 64, kernel_size=4, stride=2, padding=1, bias=False,
+             ),
+             nn.BatchNorm2d(64),
+             nn.ReLU(inplace=True),
+         )
+
+         # --- Detection Heads (at stride-2 resolution) ---
+         self.heatmap_head = HeatmapHead(64, num_classes)
+         self.offset_head = OffsetHead(64)
+
+     def _load_pretrained(self, backbone: nn.Module, path: str):
+         """Load CEM500K MoCoV2 pretrained weights."""
+         ckpt = torch.load(path, map_location="cpu", weights_only=False)
+
+         state = {}
+         # CEM500K uses MoCo format: keys prefixed with 'module.encoder_q.'
+         src_state = ckpt.get("state_dict", ckpt)
+         for k, v in src_state.items():
+             # Strip MoCo prefix
+             new_key = k
+             for prefix in ["module.encoder_q.", "module.", "encoder_q."]:
+                 if new_key.startswith(prefix):
+                     new_key = new_key[len(prefix):]
+                     break
+             state[new_key] = v
+
+         # Adapt conv1: mean-pool 3ch RGB → 1ch grayscale
+         if "conv1.weight" in state and state["conv1.weight"].shape[1] == 3:
+             state["conv1.weight"] = state["conv1.weight"].mean(dim=1, keepdim=True)
+
+         # Load with strict=False (head layers won't match)
+         missing, unexpected = backbone.load_state_dict(state, strict=False)
+         # Expected: fc.weight, fc.bias will be missing/unexpected
+         print(f"CEM500K loaded: {len(state)} keys, "
+               f"{len(missing)} missing, {len(unexpected)} unexpected")
+
+     def forward(self, x: torch.Tensor) -> tuple:
+         """
+         Args:
+             x: (B, 1, H, W) grayscale input
+
+         Returns:
+             heatmap: (B, 2, H/2, W/2) sigmoid-activated class heatmaps
+             offsets: (B, 2, H/2, W/2) sub-pixel offset predictions
+         """
+         # Encoder
+         x0 = self.stem(x)     # stride 4
+         p2 = self.layer1(x0)  # 256ch, stride 4
+         p3 = self.layer2(p2)  # 512ch, stride 8
+         p4 = self.layer3(p3)  # 1024ch, stride 16
+         p5 = self.layer4(p4)  # 2048ch, stride 32
+
+         # BiFPN neck
+         features = self.bifpn([p2, p3, p4, p5])
+
+         # Decoder: upsample P2 to stride 2
+         x_up = self.upsample(features[0])
+
+         # Detection heads
+         heatmap = self.heatmap_head(x_up)  # (B, 2, H/2, W/2)
+         offsets = self.offset_head(x_up)   # (B, 2, H/2, W/2)
+
+         return heatmap, offsets
+
+     def freeze_encoder(self):
+         """Freeze entire encoder (Phase 1 training)."""
+         for module in [self.stem, self.layer1, self.layer2, self.layer3, self.layer4]:
+             for param in module.parameters():
+                 param.requires_grad = False
+
+     def unfreeze_deep_layers(self):
+         """Unfreeze layer3 and layer4 (Phase 2 training)."""
+         for module in [self.layer3, self.layer4]:
+             for param in module.parameters():
+                 param.requires_grad = True
+
+     def unfreeze_all(self):
+         """Unfreeze all layers (Phase 3 training)."""
+         for param in self.parameters():
+             param.requires_grad = True
+
+     def get_param_groups(self, phase: int, cfg: dict) -> list:
+         """
+         Get parameter groups with discriminative learning rates per phase.
+
+         Args:
+             phase: 1, 2, or 3
+             cfg: training phase config from config.yaml
+
+         Returns:
+             List of param group dicts for optimizer.
+         """
+         if phase == 1:
+             # Only neck + heads trainable
+             return [
+                 {"params": self.bifpn.parameters(), "lr": cfg["lr"]},
+                 {"params": self.upsample.parameters(), "lr": cfg["lr"]},
+                 {"params": self.heatmap_head.parameters(), "lr": cfg["lr"]},
+                 {"params": self.offset_head.parameters(), "lr": cfg["lr"]},
+             ]
+         elif phase == 2:
+             return [
+                 {"params": self.stem.parameters(), "lr": 0},
+                 {"params": self.layer1.parameters(), "lr": 0},
+                 {"params": self.layer2.parameters(), "lr": 0},
+                 {"params": self.layer3.parameters(), "lr": cfg["lr_layer3"]},
+                 {"params": self.layer4.parameters(), "lr": cfg["lr_layer4"]},
+                 {"params": self.bifpn.parameters(), "lr": cfg["lr_decoder"]},
+                 {"params": self.upsample.parameters(), "lr": cfg["lr_decoder"]},
+                 {"params": self.heatmap_head.parameters(), "lr": cfg["lr_decoder"]},
+                 {"params": self.offset_head.parameters(), "lr": cfg["lr_decoder"]},
+             ]
+         else:  # phase 3
+             return [
+                 {"params": self.stem.parameters(), "lr": cfg["lr_stem"]},
+                 {"params": self.layer1.parameters(), "lr": cfg["lr_layer1"]},
+                 {"params": self.layer2.parameters(), "lr": cfg["lr_layer2"]},
+                 {"params": self.layer3.parameters(), "lr": cfg["lr_layer3"]},
+                 {"params": self.layer4.parameters(), "lr": cfg["lr_layer4"]},
+                 {"params": self.bifpn.parameters(), "lr": cfg["lr_decoder"]},
+                 {"params": self.upsample.parameters(), "lr": cfg["lr_decoder"]},
+                 {"params": self.heatmap_head.parameters(), "lr": cfg["lr_decoder"]},
+                 {"params": self.offset_head.parameters(), "lr": cfg["lr_decoder"]},
+             ]
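`get_param_groups` reads per-stage keys from the phase config. An illustrative set of values (hypothetical numbers; the real ones live in the project's config.yaml, which is not shown here) and a check that each phase supplies the keys the method dereferences:

```python
# Hypothetical phase configs matching the keys used by get_param_groups
phase_cfgs = {
    1: {"lr": 1e-3},
    2: {"lr_layer3": 1e-5, "lr_layer4": 5e-5, "lr_decoder": 5e-4},
    3: {"lr_stem": 1e-6, "lr_layer1": 2e-6, "lr_layer2": 5e-6,
        "lr_layer3": 1e-5, "lr_layer4": 5e-5, "lr_decoder": 1e-4},
}

# Keys each phase must provide, taken from the cfg[...] lookups above
required_keys = {
    1: {"lr"},
    2: {"lr_layer3", "lr_layer4", "lr_decoder"},
    3: {"lr_stem", "lr_layer1", "lr_layer2", "lr_layer3", "lr_layer4", "lr_decoder"},
}
ok = all(required_keys[p] <= set(phase_cfgs[p]) for p in (1, 2, 3))
```

A missing key raises `KeyError` inside `get_param_groups`, so validating configs up front like this is cheap insurance.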
src/postprocess.py ADDED
@@ -0,0 +1,157 @@
+ """
+ Post-processing: structural mask filtering, cross-class NMS, threshold sweep.
+ """
+
+ import numpy as np
+ from scipy.spatial.distance import cdist
+ from skimage.morphology import dilation, disk
+ from typing import Dict, List, Optional
+
+
+ def apply_structural_mask_filter(
+     detections: List[dict],
+     mask: np.ndarray,
+     margin_px: int = 5,
+ ) -> List[dict]:
+     """
+     Remove detections outside biological tissue regions.
+
+     Args:
+         detections: list of {'x', 'y', 'class', 'conf'}
+         mask: boolean array (H, W) where True = tissue region
+         margin_px: dilate mask by this many pixels
+
+     Returns:
+         Filtered detection list.
+     """
+     if mask is None:
+         return detections
+
+     # Dilate mask to allow particles at region boundaries
+     tissue = dilation(mask, disk(margin_px))
+
+     filtered = []
+     for det in detections:
+         xi, yi = int(round(det["x"])), int(round(det["y"]))
+         if (0 <= yi < tissue.shape[0] and
+                 0 <= xi < tissue.shape[1] and
+                 tissue[yi, xi]):
+             filtered.append(det)
+     return filtered
+
+
+ def cross_class_nms(
+     detections: List[dict],
+     distance_threshold: float = 8.0,
+ ) -> List[dict]:
+     """
+     When 6nm and 12nm detections overlap, keep the higher-confidence one.
+
+     This handles cases where both heads fire on the same particle.
+     """
+     if len(detections) <= 1:
+         return detections
+
+     # Sort by confidence descending
+     dets = sorted(detections, key=lambda d: d["conf"], reverse=True)
+     keep = [True] * len(dets)
+
+     coords = np.array([[d["x"], d["y"]] for d in dets])
+
+     for i in range(len(dets)):
+         if not keep[i]:
+             continue
+         for j in range(i + 1, len(dets)):
+             if not keep[j]:
+                 continue
+             # Only suppress across classes
+             if dets[i]["class"] == dets[j]["class"]:
+                 continue
+             dist = np.sqrt(
+                 (coords[i, 0] - coords[j, 0]) ** 2
+                 + (coords[i, 1] - coords[j, 1]) ** 2
+             )
+             if dist < distance_threshold:
+                 keep[j] = False  # Lower confidence suppressed
+
+     return [d for d, k in zip(dets, keep) if k]
+
+
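The suppression rule can be exercised with a minimal stdlib re-implementation (a standalone sketch with slightly simplified bookkeeping, not the module's function): two detections of different classes within the radius collapse to the higher-confidence one, while distant detections survive.

```python
import math

def cross_class_nms_min(dets, dist_thr=8.0):
    """Keep the higher-confidence detection when different classes overlap."""
    dets = sorted(dets, key=lambda d: d["conf"], reverse=True)
    keep = []
    for d in dets:
        clash = any(
            k["class"] != d["class"]
            and math.hypot(k["x"] - d["x"], k["y"] - d["y"]) < dist_thr
            for k in keep
        )
        if not clash:
            keep.append(d)
    return keep

dets = [
    {"x": 100.0, "y": 100.0, "class": "6nm", "conf": 0.4},
    {"x": 103.0, "y": 100.0, "class": "12nm", "conf": 0.9},  # 3 px from the 6nm hit
    {"x": 300.0, "y": 300.0, "class": "6nm", "conf": 0.5},
]
kept = cross_class_nms_min(dets)
```

The 0.4-confidence 6nm detection sits 3 px from a 0.9-confidence 12nm detection, so only the latter survives; the distant 0.5 detection is untouched.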
+ def sweep_confidence_threshold(
+     detections: List[dict],
+     gt_coords: Dict[str, np.ndarray],
+     match_radii: Dict[str, float],
+     start: float = 0.05,
+     stop: float = 0.95,
+     step: float = 0.01,
+ ) -> Dict[str, float]:
+     """
+     Sweep confidence thresholds to find the optimal per-class thresholds.
+
+     Args:
+         detections: all detections (before thresholding)
+         gt_coords: {'6nm': Nx2, '12nm': Mx2} ground truth
+         match_radii: per-class match radii in pixels
+         start, stop, step: sweep range
+
+     Returns:
+         Dict mapping class name to its best (F1-maximizing) threshold.
+     """
+     from src.evaluate import compute_f1
+
+     best_thresholds = {}
+     thresholds = np.arange(start, stop, step)
+
+     for cls in ["6nm", "12nm"]:
+         best_f1 = -1
+         best_thr = 0.3
+
+         for thr in thresholds:
+             cls_dets = [d for d in detections if d["class"] == cls and d["conf"] >= thr]
+             if not cls_dets and len(gt_coords[cls]) == 0:
+                 continue
+
+             pred_coords = np.array([[d["x"], d["y"]] for d in cls_dets]).reshape(-1, 2)
+             gt = gt_coords[cls]
+
+             if len(pred_coords) == 0:
+                 tp, fp, fn = 0, 0, len(gt)
+             elif len(gt) == 0:
+                 tp, fp, fn = 0, len(pred_coords), 0
+             else:
+                 tp, fp, fn = _simple_match(pred_coords, gt, match_radii[cls])
+
+             f1, _, _ = compute_f1(tp, fp, fn)
+             if f1 > best_f1:
+                 best_f1 = f1
+                 best_thr = thr
+
+         best_thresholds[cls] = best_thr
+
+     return best_thresholds
+
+
+ def _simple_match(
+     pred: np.ndarray, gt: np.ndarray, radius: float
+ ) -> tuple:
+     """Quick matching for threshold sweep (greedy, not Hungarian)."""
+     if len(pred) == 0 or len(gt) == 0:
+         return 0, len(pred), len(gt)
+
+     dists = cdist(pred, gt)
+     tp = 0
+     matched_gt = set()
+
+     # Greedy: match each prediction in turn to its closest unmatched GT
+     for i in range(len(pred)):
+         min_j = np.argmin(dists[i])
+         if dists[i, min_j] <= radius and min_j not in matched_gt:
+             tp += 1
+             matched_gt.add(min_j)
+             dists[:, min_j] = np.inf
+
+     fp = len(pred) - tp
+     fn = len(gt) - tp
+     return tp, fp, fn
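The greedy counting is easy to spot-check. A standalone NumPy sketch (distance matrix via broadcasting instead of `cdist`) with three predictions against two ground-truth points:

```python
import numpy as np

def greedy_match(pred, gt, radius):
    """Greedy nearest-neighbor matching; returns (tp, fp, fn)."""
    if len(pred) == 0 or len(gt) == 0:
        return 0, len(pred), len(gt)
    d = np.linalg.norm(pred[:, None, :] - gt[None, :, :], axis=-1)
    tp = 0
    for i in range(len(pred)):
        j = int(np.argmin(d[i]))
        if d[i, j] <= radius:
            tp += 1
            d[:, j] = np.inf  # each GT can be matched at most once
    return tp, len(pred) - tp, len(gt) - tp

pred = np.array([[10.0, 10.0], [11.0, 10.0], [50.0, 50.0]])
gt = np.array([[10.0, 10.0], [80.0, 80.0]])
tp, fp, fn = greedy_match(pred, gt, radius=3.0)
```

The first prediction claims the only nearby GT, so the 1-px-away duplicate becomes a false positive rather than a second true positive, which is exactly the behavior the threshold sweep relies on.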
src/preprocessing.py ADDED
@@ -0,0 +1,284 @@
1
+ """
2
+ Data loading, annotation parsing, and preprocessing for immunogold TEM images.
3
+
4
+ The model receives raw images — the CEM500K backbone was pretrained on raw EM.
5
+ Top-hat preprocessing is only used by LodeStar (Stage 1).
6
+ """
7
+
8
+ from dataclasses import dataclass, field
9
+ from pathlib import Path
10
+ from typing import Dict, List, Optional, Tuple
11
+
12
+ import numpy as np
13
+ import pandas as pd
14
+ import tifffile
15
+
16
+
17
+ # ---------------------------------------------------------------------------
18
+ # Data registry: robust discovery of images, masks, and annotations
19
+ # ---------------------------------------------------------------------------
20
+
21
+ @dataclass
22
+ class SynapseRecord:
23
+ """Metadata for one synapse sample."""
24
+ synapse_id: str
25
+ image_path: Path
26
+ mask_path: Optional[Path]
27
+ csv_6nm_paths: List[Path] = field(default_factory=list)
28
+ csv_12nm_paths: List[Path] = field(default_factory=list)
29
+ has_6nm: bool = False
30
+ has_12nm: bool = False
31
+
32
+
33
+ def discover_synapse_data(root: str, synapse_ids: List[str]) -> List[SynapseRecord]:
34
+ """
35
+ Discover all TIF images, masks, and CSV annotations for each synapse.
36
+
37
+ Handles naming inconsistencies:
38
+ - S22: main image is S22_0003.tif, two Results folders
39
+ - S25: 12nm CSV has no space ("Results12nm")
40
+ - CSV patterns: "Results 6nm XY" vs "Results XY in microns 6nm"
41
+ """
42
+ root = Path(root)
43
+ analyzed = root / "analyzed synapses"
44
+ records = []
45
+
46
+ for sid in synapse_ids:
47
+ folder = analyzed / sid
48
+ if not folder.exists():
49
+ raise FileNotFoundError(f"Synapse folder not found: {folder}")
50
+
51
+ # --- Find main image (TIF without 'mask' or 'color' in name) ---
52
+ all_tifs = list(folder.glob("*.tif"))
53
+ main_tifs = [
54
+ t for t in all_tifs
55
+ if "mask" not in t.stem.lower() and "color" not in t.stem.lower()
56
+ ]
57
+ if not main_tifs:
58
+ raise FileNotFoundError(f"No main image found in {folder}")
59
+ # Prefer the largest file (main EM image) if multiple found
60
+ image_path = max(main_tifs, key=lambda t: t.stat().st_size)
61
+
62
+ # --- Find mask ---
63
+ mask_tifs = [t for t in all_tifs if "mask" in t.stem.lower()]
64
+ mask_path = None
65
+ if mask_tifs:
66
+ # Prefer plain "mask.tif" over "mask 1.tif" / "mask 2.tif"
67
+ plain = [t for t in mask_tifs if t.stem.lower().endswith("mask")]
68
+ mask_path = plain[0] if plain else mask_tifs[0]
69
+
70
+ # --- Find CSVs across all Results* subdirectories ---
71
+ results_dirs = sorted(folder.glob("Results*"))
72
+ # Also check direct subdirs like "Results 1", "Results 2"
73
+ csv_6nm_paths = []
74
+ csv_12nm_paths = []
75
+
76
+ for rdir in results_dirs:
77
+ if rdir.is_dir():
78
+ for csv_file in rdir.glob("*.csv"):
79
+ name_lower = csv_file.name.lower()
80
+ if "6nm" in name_lower:
81
+ csv_6nm_paths.append(csv_file)
82
+ elif "12nm" in name_lower:
83
+ csv_12nm_paths.append(csv_file)
84
+
85
+ record = SynapseRecord(
86
+ synapse_id=sid,
87
+ image_path=image_path,
88
+ mask_path=mask_path,
89
+ csv_6nm_paths=csv_6nm_paths,
90
+ csv_12nm_paths=csv_12nm_paths,
91
+ has_6nm=len(csv_6nm_paths) > 0,
92
+ has_12nm=len(csv_12nm_paths) > 0,
93
+ )
94
+ records.append(record)
95
+
96
+ return records
97
+
98
+
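The discovery rule above (pick the largest TIF whose name contains neither "mask" nor "color") can be exercised on a throwaway folder; the filenames below are illustrative stand-ins for one synapse directory.

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp)
    # Hypothetical files mimicking one analyzed-synapse folder
    (folder / "S22_0003.tif").write_bytes(b"\x00" * 1000)        # main EM image (largest)
    (folder / "S22 mask.tif").write_bytes(b"\x00" * 10)
    (folder / "S22 color overlay.tif").write_bytes(b"\x00" * 500)

    all_tifs = list(folder.glob("*.tif"))
    main_tifs = [t for t in all_tifs
                 if "mask" not in t.stem.lower() and "color" not in t.stem.lower()]
    # Same tie-break as discover_synapse_data: largest file wins
    image_path = max(main_tifs, key=lambda t: t.stat().st_size)
    picked = image_path.name
```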
99
+ # ---------------------------------------------------------------------------
100
+ # Image I/O
101
+ # ---------------------------------------------------------------------------
102
+
103
+ def load_image(path: Path) -> np.ndarray:
104
+ """
105
+ Load a TIF image as grayscale uint8.
106
+
107
+ Handles:
108
+ - RGB images (take first channel)
109
+ - Palette-mode images
110
+ - Already-grayscale images
111
+ """
112
+ img = tifffile.imread(str(path))
113
+ if img.ndim == 3:
114
+ # RGB or multi-channel — take first channel (all channels identical in these images)
115
+ img = img[:, :, 0] if img.shape[2] <= 4 else img[0]
116
+ return img.astype(np.uint8)
117
+
118
+
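A minimal sketch of the channel heuristic in `load_image`: a small last axis (≤ 4) is treated as channels-last RGB, anything else as a channels-first stack. The arrays here are synthetic.

```python
import numpy as np

def to_gray(img: np.ndarray) -> np.ndarray:
    # Mirror of load_image's rule: small last axis -> channels-last,
    # take the first channel; otherwise take the first plane of a stack.
    if img.ndim == 3:
        img = img[:, :, 0] if img.shape[2] <= 4 else img[0]
    return img.astype(np.uint8)

rgb = np.dstack([np.full((4, 4), 7, np.uint8)] * 3)   # (4, 4, 3) channels-last
stack = np.ones((2, 6, 6), np.uint8)                  # (2, 6, 6) channels-first
```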
119
+ def load_mask(path: Path) -> np.ndarray:
120
+ """
121
+ Load mask TIF as binary array.
122
+
123
+ Mask is RGB where tissue regions have values < 250 in at least one channel.
124
+ Returns boolean array: True = tissue/structural region.
125
+ """
126
+ mask_rgb = tifffile.imread(str(path))
127
+ if mask_rgb.ndim == 2:
128
+ return mask_rgb < 250
129
+ # RGB mask: tissue where any channel is not white
130
+ return np.any(mask_rgb < 250, axis=-1)
131
+
132
+
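The mask threshold can be checked on a tiny synthetic RGB array: background pixels are pure white, and a pixel counts as tissue if any channel drops below 250.

```python
import numpy as np

mask_rgb = np.full((2, 2, 3), 255, np.uint8)   # all-white background
mask_rgb[0, 0] = (120, 255, 255)               # one dark channel -> tissue
tissue = np.any(mask_rgb < 250, axis=-1)       # True where any channel < 250
```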
133
+ # ---------------------------------------------------------------------------
134
+ # Annotation loading and coordinate conversion
135
+ # ---------------------------------------------------------------------------
136
+
137
+ def load_annotations_csv(csv_path: Path) -> pd.DataFrame:
138
+ """
139
+ Load annotation CSV with columns [index, X, Y].
140
+
141
+ CSV headers have leading space: " ,X,Y".
142
+ Coordinates are in microns despite filename variations; convert to
+ pixels with MICRONS_TO_PIXELS (defined below).
143
+ """
144
+ df = pd.read_csv(csv_path)
145
+ # Normalize column names (strip whitespace)
146
+ df.columns = [c.strip() for c in df.columns]
147
+ # Rename unnamed index column
148
+ if "" in df.columns:
149
+ df = df.rename(columns={"": "idx"})
150
+ return df[["X", "Y"]]
151
+
152
+
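The header quirk handled above (a leading-space `" ,X,Y"` header with an unnamed index column) can be reproduced with the standard-library `csv` module on an in-memory string; the values are made up.

```python
import csv
import io

raw = " ,X,Y\n1,0.123,0.456\n2,0.200,0.300\n"   # note the leading-space header
reader = csv.reader(io.StringIO(raw))
header = [c.strip() for c in next(reader)]       # '' for the unnamed index column
xi, yi = header.index("X"), header.index("Y")
coords = [(float(row[xi]), float(row[yi])) for row in reader]
```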
153
+ # Micron-to-pixel scale factor: consistent across all synapses (verified
154
+ # against researcher's color overlay TIFs). The CSV columns labeled "XY in
155
+ # microns" really ARE microns — multiply by this constant to get pixels.
156
+ MICRONS_TO_PIXELS = 1790.0
157
+
158
+
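The conversion and bounds check used in `load_all_annotations` amount to one multiply and one comparison; the image size and micron coordinates below are hypothetical.

```python
import numpy as np

MICRONS_TO_PIXELS = 1790.0                    # constant from the module above
h, w = 2048, 2048                             # hypothetical image size
xy_um = np.array([[0.5, 0.25], [1.0, 1.0]])   # annotation coords in microns
xy_px = xy_um * MICRONS_TO_PIXELS             # now in pixels
in_bounds = (xy_px[:, 0] < w + 10) & (xy_px[:, 1] < h + 10)
```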
159
+ def load_all_annotations(
160
+ record: SynapseRecord, image_shape: Tuple[int, int]
161
+ ) -> Dict[str, np.ndarray]:
162
+ """
163
+ Load and convert annotations for one synapse to pixel coordinates.
164
+
165
+ CSV coordinates are in microns (despite filename suggesting normalization).
166
+ Multiply by MICRONS_TO_PIXELS (1790 px/micron) to convert.
167
+
168
+ Args:
169
+ record: SynapseRecord with CSV paths.
170
+ image_shape: (height, width) of the corresponding image.
171
+
172
+ Returns:
173
+ Dictionary with keys '6nm' and '12nm', each containing
174
+ an Nx2 array of (x, y) pixel coordinates.
175
+ """
176
+ h, w = image_shape[:2]
177
+ result = {"6nm": np.empty((0, 2), dtype=np.float64),
178
+ "12nm": np.empty((0, 2), dtype=np.float64)}
179
+
180
+ for cls, paths in [("6nm", record.csv_6nm_paths),
181
+ ("12nm", record.csv_12nm_paths)]:
182
+ all_coords = []
183
+ for csv_path in paths:
184
+ df = load_annotations_csv(csv_path)
185
+ # Convert microns to pixels
186
+ px_x = df["X"].values * MICRONS_TO_PIXELS
187
+ px_y = df["Y"].values * MICRONS_TO_PIXELS
188
+ if len(px_x) == 0:
+ continue  # empty CSV: nothing to validate or append
+ # Validate: coords must fall within image bounds
189
+ assert px_x.max() < w + 10, \
190
+ f"X coords out of bounds ({px_x.max():.0f} > {w}) in {csv_path}"
191
+ assert px_y.max() < h + 10, \
192
+ f"Y coords out of bounds ({px_y.max():.0f} > {h}) in {csv_path}"
193
+ all_coords.append(np.stack([px_x, px_y], axis=1))
194
+
195
+ if all_coords:
196
+ coords = np.concatenate(all_coords, axis=0)
197
+ # Deduplicate (for S22 merged results): remove within 3px
198
+ if len(coords) > 1:
199
+ coords = _deduplicate_coords(coords, min_dist=3.0)
200
+ result[cls] = coords
201
+
202
+ return result
203
+
204
+
205
+ def _deduplicate_coords(
206
+ coords: np.ndarray, min_dist: float = 3.0
207
+ ) -> np.ndarray:
208
+ """Remove duplicate coordinates within min_dist pixels."""
209
+ from scipy.spatial.distance import cdist
210
+
211
+ if len(coords) <= 1:
212
+ return coords
213
+ dists = cdist(coords, coords)
214
+ np.fill_diagonal(dists, np.inf)
215
+ keep = np.ones(len(coords), dtype=bool)
216
+ for i in range(len(coords)):
217
+ if not keep[i]:
218
+ continue
219
+ # Mark later duplicates
220
+ for j in range(i + 1, len(coords)):
221
+ if keep[j] and dists[i, j] < min_dist:
222
+ keep[j] = False
223
+ return coords[keep]
224
+
225
+
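The greedy pass in `_deduplicate_coords` (keep the earlier point, drop any later point within `min_dist`) can be sketched with NumPy alone; the coordinates are toy values.

```python
import numpy as np

coords = np.array([[10.0, 10.0], [11.0, 10.5], [50.0, 50.0]])
min_dist = 3.0
keep = np.ones(len(coords), dtype=bool)
for i in range(len(coords)):
    if not keep[i]:
        continue
    for j in range(i + 1, len(coords)):
        if keep[j] and np.linalg.norm(coords[i] - coords[j]) < min_dist:
            keep[j] = False   # later point is a duplicate of an earlier one
deduped = coords[keep]
```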
226
+ # ---------------------------------------------------------------------------
227
+ # Preprocessing transforms
228
+ # ---------------------------------------------------------------------------
229
+
230
+ def preprocess_image(img: np.ndarray, bead_class: str,
231
+ tophat_radii: Optional[Dict[str, int]] = None,
232
+ clahe_clip_limit: float = 0.03,
233
+ clahe_kernel_size: int = 64) -> np.ndarray:
234
+ """
235
+ Top-hat + CLAHE preprocessing. Used ONLY by LodeStar (Stage 1).
236
+
237
+ Not used for model training — the CEM500K backbone expects raw EM images.
238
+ """
239
+ from skimage import exposure
240
+ from skimage.morphology import disk, white_tophat
241
+
242
+ if tophat_radii is None:
243
+ tophat_radii = {"6nm": 8, "12nm": 12}
244
+
245
+ img_inv = (255 - img).astype(np.float32)
246
+ radius = tophat_radii[bead_class]
247
+ tophat = white_tophat(img_inv, disk(radius))
248
+
249
+ tophat_max = tophat.max()
250
+ if tophat_max > 0:
251
+ tophat_norm = tophat / tophat_max
252
+ else:
253
+ tophat_norm = tophat
254
+
255
+ enhanced = exposure.equalize_adapthist(
256
+ tophat_norm,
257
+ clip_limit=clahe_clip_limit,
258
+ kernel_size=clahe_kernel_size,
259
+ )
260
+ return (enhanced * 255).astype(np.uint8)
261
+
262
+
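The white top-hat step can be illustrated without scikit-image: an opening (erosion then dilation) removes small bright spots from the background, and subtracting the opening from the image keeps only those spots. This pure-NumPy 3x3 version is a sketch, not the disk-shaped structuring element the pipeline actually uses.

```python
import numpy as np

def erode3(img):
    # 3x3 grayscale erosion via shifted minima (np.roll wraps at edges;
    # acceptable for a flat-background demo)
    out = img.copy()
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            out = np.minimum(out, np.roll(np.roll(img, dy, 0), dx, 1))
    return out

def dilate3(img):
    # 3x3 grayscale dilation via shifted maxima
    out = img.copy()
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            out = np.maximum(out, np.roll(np.roll(img, dy, 0), dx, 1))
    return out

img = np.full((9, 9), 10.0)           # flat background
img[4, 4] = 200.0                     # one bright "particle"
tophat = img - dilate3(erode3(img))   # opening deletes the spot; subtraction keeps it
```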
263
+ # ---------------------------------------------------------------------------
264
+ # Convenience: load everything for one synapse
265
+ # ---------------------------------------------------------------------------
266
+
267
+ def load_synapse(record: SynapseRecord) -> dict:
268
+ """
269
+ Load image, mask, and annotations for one synapse.
270
+
271
+ Returns dict with keys: 'image', 'mask', 'annotations',
272
+ 'synapse_id', 'image_shape'
273
+ """
274
+ img = load_image(record.image_path)
275
+ mask = load_mask(record.mask_path) if record.mask_path else None
276
+ annotations = load_all_annotations(record, img.shape)
277
+
278
+ return {
279
+ "synapse_id": record.synapse_id,
280
+ "image": img,
281
+ "mask": mask,
282
+ "annotations": annotations,
283
+ "image_shape": img.shape,
284
+ }
src/visualize.py ADDED
@@ -0,0 +1,244 @@
1
+ """
2
+ Visualization utilities for QC at every pipeline stage.
3
+
4
+ Generates overlay images showing predictions on raw EM images:
5
+ - Cyan circles for 6nm particles
6
+ - Yellow circles for 12nm particles
7
+ """
8
+
9
+ import numpy as np
10
+ import matplotlib
11
+ matplotlib.use("Agg")
12
+ import matplotlib.pyplot as plt
13
+ import matplotlib.patches as mpatches
14
+ from pathlib import Path
15
+ from typing import Dict, List, Optional
16
+
17
+
18
+ # Color scheme
19
+ COLORS = {
20
+ "6nm": (0, 255, 255), # cyan
21
+ "12nm": (255, 255, 0), # yellow
22
+ "6nm_pred": (0, 200, 200),
23
+ "12nm_pred": (200, 200, 0),
24
+ }
25
+
26
+ RADII = {"6nm": 6, "12nm": 12}
27
+
28
+
29
+ def overlay_annotations(
30
+ image: np.ndarray,
31
+ annotations: Dict[str, np.ndarray],
32
+ title: str = "",
33
+ save_path: Optional[Path] = None,
34
+ predictions: Optional[List[dict]] = None,
35
+ figsize: tuple = (12, 12),
36
+ ) -> plt.Figure:
37
+ """
38
+ Overlay ground truth annotations (and optional predictions) on image.
39
+
40
+ Args:
41
+ image: (H, W) grayscale image
42
+ annotations: {'6nm': Nx2, '12nm': Mx2} pixel coordinates
43
+ title: figure title
44
+ save_path: if provided, save figure here
45
+ predictions: optional list of {'x', 'y', 'class', 'conf'}
46
+ figsize: figure size
47
+
48
+ Returns:
49
+ matplotlib Figure
50
+ """
51
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
52
+ ax.imshow(image, cmap="gray")
53
+
54
+ # Ground truth circles (solid)
55
+ for cls, coords in annotations.items():
56
+ if len(coords) == 0:
57
+ continue
58
+ color_rgb = np.array(COLORS[cls]) / 255.0
59
+ radius = RADII[cls]
60
+ for x, y in coords:
61
+ circle = plt.Circle(
62
+ (x, y), radius, fill=False,
63
+ edgecolor=color_rgb, linewidth=1.5,
64
+ )
65
+ ax.add_patch(circle)
66
+
67
+ # Predictions (dashed)
68
+ if predictions:
69
+ for det in predictions:
70
+ cls = det["class"]
71
+ color_rgb = np.array(COLORS.get(f"{cls}_pred", COLORS[cls])) / 255.0
72
+ radius = RADII[cls]
73
+ circle = plt.Circle(
74
+ (det["x"], det["y"]), radius, fill=False,
75
+ edgecolor=color_rgb, linewidth=1.0, linestyle="--",
76
+ )
77
+ ax.add_patch(circle)
78
+ # Confidence label
79
+ ax.text(
80
+ det["x"] + radius + 2, det["y"],
81
+ f'{det["conf"]:.2f}',
82
+ color=color_rgb, fontsize=6,
83
+ )
84
+
85
+ # Legend
86
+ legend_elements = [
87
+ mpatches.Patch(facecolor="none", edgecolor="cyan", label=f'6nm GT ({len(annotations.get("6nm", []))})', linewidth=1.5),
88
+ mpatches.Patch(facecolor="none", edgecolor="yellow", label=f'12nm GT ({len(annotations.get("12nm", []))})', linewidth=1.5),
89
+ ]
90
+ if predictions:
91
+ n_pred_6 = sum(1 for d in predictions if d["class"] == "6nm")
92
+ n_pred_12 = sum(1 for d in predictions if d["class"] == "12nm")
93
+ legend_elements.extend([
94
+ mpatches.Patch(facecolor="none", edgecolor="darkcyan", label=f"6nm pred ({n_pred_6})", linewidth=1.0),
95
+ mpatches.Patch(facecolor="none", edgecolor="goldenrod", label=f"12nm pred ({n_pred_12})", linewidth=1.0),
96
+ ])
97
+ ax.legend(handles=legend_elements, loc="upper right", fontsize=8)
98
+
99
+ ax.set_title(title, fontsize=10)
100
+ ax.axis("off")
101
+
102
+ if save_path:
103
+ save_path = Path(save_path)
104
+ save_path.parent.mkdir(parents=True, exist_ok=True)
105
+ fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
106
+ plt.close(fig)
107
+
108
+ return fig
109
+
110
+
111
+ def plot_heatmap_overlay(
112
+ image: np.ndarray,
113
+ heatmap: np.ndarray,
114
+ title: str = "",
115
+ save_path: Optional[Path] = None,
116
+ ) -> plt.Figure:
117
+ """
118
+ Overlay predicted heatmap on image for QC.
119
+
120
+ Args:
121
+ image: (H, W) grayscale
122
+ heatmap: (2, H/2, W/2) predicted heatmap
123
+ """
124
+ fig, axes = plt.subplots(1, 3, figsize=(18, 6))
125
+
126
+ axes[0].imshow(image, cmap="gray")
127
+ axes[0].set_title("Raw Image")
128
+ axes[0].axis("off")
129
+
130
+ # Upsample heatmap to image size for overlay
131
+ h, w = image.shape[:2]
132
+
133
+ for idx, (cls, color) in enumerate([("6nm", "hot"), ("12nm", "cool")]):
134
+ hm = heatmap[idx]
135
+ # Resize to image dims
136
+ from skimage.transform import resize
137
+ hm_up = resize(hm, (h, w), order=1)
138
+
139
+ axes[idx + 1].imshow(image, cmap="gray")
140
+ axes[idx + 1].imshow(hm_up, cmap=color, alpha=0.5, vmin=0, vmax=1)
141
+ axes[idx + 1].set_title(f"{cls} heatmap")
142
+ axes[idx + 1].axis("off")
143
+
144
+ fig.suptitle(title, fontsize=12)
145
+
146
+ if save_path:
147
+ save_path = Path(save_path)
148
+ save_path.parent.mkdir(parents=True, exist_ok=True)
149
+ fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
150
+ plt.close(fig)
151
+
152
+ return fig
153
+
154
+
155
+ def plot_training_curves(
156
+ metrics: dict,
157
+ save_path: Optional[Path] = None,
158
+ ) -> plt.Figure:
159
+ """Plot training loss and F1 curves."""
160
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
161
+
162
+ epochs = range(1, len(metrics["train_loss"]) + 1)
163
+
164
+ # Loss
165
+ ax1.plot(epochs, metrics["train_loss"], label="Train Loss")
166
+ if "val_loss" in metrics:
167
+ ax1.plot(epochs, metrics["val_loss"], label="Val Loss")
168
+ ax1.set_xlabel("Epoch")
169
+ ax1.set_ylabel("Loss")
170
+ ax1.set_title("Training Loss")
171
+ ax1.legend()
172
+ ax1.grid(True, alpha=0.3)
173
+
174
+ # F1
175
+ if "val_f1_6nm" in metrics:
176
+ ax2.plot(epochs, metrics["val_f1_6nm"], label="6nm F1")
177
+ if "val_f1_12nm" in metrics:
178
+ ax2.plot(epochs, metrics["val_f1_12nm"], label="12nm F1")
179
+ if "val_f1_mean" in metrics:
180
+ ax2.plot(epochs, metrics["val_f1_mean"], label="Mean F1", linewidth=2)
181
+ ax2.set_xlabel("Epoch")
182
+ ax2.set_ylabel("F1 Score")
183
+ ax2.set_title("Validation F1")
184
+ ax2.legend()
185
+ ax2.grid(True, alpha=0.3)
186
+
187
+ if save_path:
188
+ save_path = Path(save_path)
189
+ save_path.parent.mkdir(parents=True, exist_ok=True)
190
+ fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
191
+ plt.close(fig)
192
+
193
+ return fig
194
+
195
+
196
+ def plot_precision_recall_curve(
197
+ detections: List[dict],
198
+ gt_coords: np.ndarray,
199
+ match_radius: float,
200
+ cls_name: str = "",
201
+ save_path: Optional[Path] = None,
202
+ ) -> plt.Figure:
203
+ """Plot precision-recall curve for one class."""
204
+ sorted_dets = sorted(detections, key=lambda d: d["conf"], reverse=True)
205
+
206
+ tp_list = []
207
+ matched_gt = set()
208
+
209
+ for det in sorted_dets:
210
+ det_coord = np.array([det["x"], det["y"]])
211
+ if len(gt_coords) > 0:
212
+ dists = np.sqrt(np.sum((gt_coords - det_coord) ** 2, axis=1))
213
+ min_idx = np.argmin(dists)
214
+ if dists[min_idx] <= match_radius and min_idx not in matched_gt:
215
+ tp_list.append(1)
216
+ matched_gt.add(min_idx)
217
+ else:
218
+ tp_list.append(0)
219
+ else:
220
+ tp_list.append(0)
221
+
222
+ tp_cumsum = np.cumsum(tp_list)
223
+ fp_cumsum = np.cumsum([1 - t for t in tp_list])
224
+ n_gt = max(len(gt_coords), 1)
225
+
226
+ precision = tp_cumsum / (tp_cumsum + fp_cumsum)
227
+ recall = tp_cumsum / n_gt
228
+
229
+ fig, ax = plt.subplots(figsize=(6, 6))
230
+ ax.plot(recall, precision, linewidth=2)
231
+ ax.set_xlabel("Recall")
232
+ ax.set_ylabel("Precision")
233
+ ax.set_title(f"PR Curve — {cls_name}")
234
+ ax.set_xlim(0, 1)
235
+ ax.set_ylim(0, 1)
236
+ ax.grid(True, alpha=0.3)
237
+
238
+ if save_path:
239
+ save_path = Path(save_path)
240
+ save_path.parent.mkdir(parents=True, exist_ok=True)
241
+ fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
242
+ plt.close(fig)
243
+
244
+ return fig
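The greedy matching inside `plot_precision_recall_curve` (each detection, in confidence order, claims the nearest unmatched ground-truth point within `match_radius`) can be sketched with toy numbers:

```python
import numpy as np

gt = np.array([[10.0, 10.0], [50.0, 50.0]])
dets = [                                      # already sorted by confidence
    {"x": 11.0, "y": 10.0, "conf": 0.9},      # within radius of gt[0] -> TP
    {"x": 12.0, "y": 10.0, "conf": 0.8},      # gt[0] already claimed -> FP
    {"x": 90.0, "y": 90.0, "conf": 0.7},      # far from all gt -> FP
]
match_radius = 5.0

tp_list, matched = [], set()
for d in dets:
    dists = np.sqrt(((gt - np.array([d["x"], d["y"]])) ** 2).sum(axis=1))
    i = int(np.argmin(dists))
    hit = dists[i] <= match_radius and i not in matched
    tp_list.append(int(hit))
    if hit:
        matched.add(i)

tp_c = np.cumsum(tp_list)
fp_c = np.cumsum([1 - t for t in tp_list])
precision = tp_c / (tp_c + fp_c)   # [1.0, 0.5, 0.33...]
recall = tp_c / len(gt)            # [0.5, 0.5, 0.5]
```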