Deploy MidasMap Gradio app, src, requirements, checkpoint
Browse files- README.md +30 -5
- app.py +751 -0
- checkpoints/final/final_model.pth +3 -0
- requirements.txt +16 -0
- src/.DS_Store +0 -0
- src/__init__.py +1 -0
- src/__pycache__/__init__.cpython-311.pyc +0 -0
- src/__pycache__/ensemble.cpython-311.pyc +0 -0
- src/__pycache__/evaluate.cpython-311.pyc +0 -0
- src/__pycache__/heatmap.cpython-311.pyc +0 -0
- src/__pycache__/loss.cpython-311.pyc +0 -0
- src/__pycache__/model.cpython-311.pyc +0 -0
- src/dataset.py +438 -0
- src/ensemble.py +236 -0
- src/evaluate.py +203 -0
- src/heatmap.py +187 -0
- src/loss.py +137 -0
- src/model.py +382 -0
- src/postprocess.py +157 -0
- src/preprocessing.py +284 -0
- src/visualize.py +244 -0
README.md
CHANGED
|
@@ -1,13 +1,38 @@
|
|
| 1 |
---
|
| 2 |
title: MidasMap
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: MidasMap
|
| 3 |
+
emoji: 🔬
|
| 4 |
+
colorFrom: gray
|
| 5 |
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# MidasMap Space
|
| 14 |
+
|
| 15 |
+
This folder is a **template** for creating a [Hugging Face Space](https://huggingface.co/docs/hub/spaces-overview).
|
| 16 |
+
|
| 17 |
+
**Why not Vercel for the model?** Vercel serverless functions have strict size and time limits; they are not meant for PyTorch + a ~100MB checkpoint and multi-second GPU/CPU inference. **Host the Gradio app + weights on a Space** (CPU free tier or GPU upgrade).
|
| 18 |
+
|
| 19 |
+
## Create the Space
|
| 20 |
+
|
| 21 |
+
1. On Hugging Face: **New Space** → SDK **Gradio** → name e.g. `MidasMap`.
|
| 22 |
+
2. Clone the Space repo locally, or connect **GitHub** and set the Space root to this monorepo with **App file** pointing to the copied `app.py`.
|
| 23 |
+
3. Copy into the Space repository root:
|
| 24 |
+
- `app.py` from the main MidasMap repo (project root), **or** symlink / duplicate.
|
| 25 |
+
- `src/` (entire package)
|
| 26 |
+
- `requirements-space.txt` from this folder as **`requirements.txt`**
|
| 27 |
+
4. In Space **Settings → Repository secrets** (if needed): none required for public weights.
|
| 28 |
+
5. Ensure `checkpoints/final/final_model.pth` is present:
|
| 29 |
+
- Upload via **Files** tab, or
|
| 30 |
+
- Add a startup script to download from `AnikS22/MidasMap` on the Hub (see HF docs for `hf_hub_download`).
|
| 31 |
+
|
| 32 |
+
After the Space builds, point your **Vercel** site (`vercel-site`) at it:
|
| 33 |
+
|
| 34 |
+
`https://yoursite.vercel.app/?embed=https://huggingface.co/spaces/YOUR_USER/YOUR_SPACE`
|
| 35 |
+
|
| 36 |
+
---
|
| 37 |
+
|
| 38 |
+
Gradio app and model logic: [github.com/AnikS22/MidasMap](https://github.com/AnikS22/MidasMap)
|
app.py
ADDED
|
@@ -0,0 +1,751 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MidasMap — Immunogold particle analysis for FFRIL / TEM synapse imaging
|
| 3 |
+
|
| 4 |
+
Web UI for neuroscientists: calibrated coordinates (µm), receptor labels,
|
| 5 |
+
export for quantification, and clear interpretation of model limits.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python app.py
|
| 9 |
+
python app.py --checkpoint checkpoints/final/final_model.pth
|
| 10 |
+
python app.py --share
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import os
|
| 17 |
+
import tempfile
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
import gradio as gr
|
| 21 |
+
import gradio_client.utils as _gcu
|
| 22 |
+
|
| 23 |
+
# Pydantic v2 can emit JSON Schema with additionalProperties: true (bool);
|
| 24 |
+
# Gradio 4.4x gradio_client assumes a dict and crashes rendering "/".
|
| 25 |
+
_orig_json_type = _gcu._json_schema_to_python_type
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _json_schema_to_python_type_safe(schema, defs=None):
|
| 29 |
+
if schema is True or schema is False:
|
| 30 |
+
return "Any"
|
| 31 |
+
if not isinstance(schema, dict):
|
| 32 |
+
return "Any"
|
| 33 |
+
return _orig_json_type(schema, defs)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
_gcu._json_schema_to_python_type = _json_schema_to_python_type_safe
|
| 37 |
+
|
| 38 |
+
import matplotlib
|
| 39 |
+
|
| 40 |
+
matplotlib.use("Agg")
|
| 41 |
+
import matplotlib.patheffects as pe
|
| 42 |
+
import matplotlib.pyplot as plt
|
| 43 |
+
from matplotlib.patches import Patch
|
| 44 |
+
import numpy as np
|
| 45 |
+
import pandas as pd
|
| 46 |
+
import torch
|
| 47 |
+
import tifffile
|
| 48 |
+
|
| 49 |
+
from src.ensemble import sliding_window_inference
|
| 50 |
+
from src.heatmap import extract_peaks
|
| 51 |
+
from src.model import ImmunogoldCenterNet
|
| 52 |
+
from src.postprocess import cross_class_nms
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Calibration used for training / published metrics (change in UI if your scope differs)
|
| 56 |
+
DEFAULT_PX_PER_UM = 1790.0
|
| 57 |
+
|
| 58 |
+
plt.rcParams.update(
|
| 59 |
+
{
|
| 60 |
+
"figure.facecolor": "white",
|
| 61 |
+
"figure.dpi": 120,
|
| 62 |
+
"savefig.facecolor": "white",
|
| 63 |
+
"axes.facecolor": "#fafafa",
|
| 64 |
+
"axes.edgecolor": "#cbd5e1",
|
| 65 |
+
"axes.linewidth": 0.8,
|
| 66 |
+
"axes.labelcolor": "#1e293b",
|
| 67 |
+
"axes.titlecolor": "#0f172a",
|
| 68 |
+
"axes.grid": False,
|
| 69 |
+
"xtick.color": "#475569",
|
| 70 |
+
"ytick.color": "#475569",
|
| 71 |
+
"font.size": 10,
|
| 72 |
+
"axes.titlesize": 11,
|
| 73 |
+
"axes.labelsize": 10,
|
| 74 |
+
"legend.frameon": True,
|
| 75 |
+
"legend.framealpha": 0.92,
|
| 76 |
+
"legend.edgecolor": "#e2e8f0",
|
| 77 |
+
}
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
MODEL = None
|
| 82 |
+
DEVICE = None
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def load_model(checkpoint_path: str):
    """Build the network, restore checkpoint weights, and select a device.

    Sets the module globals ``MODEL`` (eval-mode network) and ``DEVICE``
    (cuda > mps > cpu, first available) as a side effect.

    Parameters
    ----------
    checkpoint_path : str
        Path to a ``.pth`` file containing a ``model_state_dict`` entry.
    """
    global MODEL, DEVICE
    # Prefer CUDA, then Apple MPS, then plain CPU.
    if torch.cuda.is_available():
        DEVICE = torch.device("cuda")
    elif torch.backends.mps.is_available():
        DEVICE = torch.device("mps")
    else:
        DEVICE = torch.device("cpu")
    MODEL = ImmunogoldCenterNet(bifpn_channels=128, bifpn_rounds=2)
    # weights_only=False performs a full pickle load — only safe because the
    # checkpoint is trusted local data shipped with the app.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    MODEL.load_state_dict(ckpt["model_state_dict"])
    MODEL.to(DEVICE)
    MODEL.eval()
    print(f"Model loaded from {checkpoint_path} on {DEVICE}")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _receptor_label(class_name: str) -> str:
|
| 103 |
+
return "AMPA receptor" if class_name == "6nm" else "NR1 (NMDA receptor)"
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _gold_nm(class_name: str) -> int:
|
| 107 |
+
return 6 if class_name == "6nm" else 12
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _pick_scale_bar_um(field_width_um: float) -> float:
|
| 111 |
+
"""Pick a readable scale bar (~15–30% of field width)."""
|
| 112 |
+
if field_width_um <= 0:
|
| 113 |
+
return 0.2
|
| 114 |
+
target = field_width_um * 0.22
|
| 115 |
+
candidates = (0.05, 0.1, 0.2, 0.25, 0.5, 1.0, 2.0, 5.0)
|
| 116 |
+
best = candidates[0]
|
| 117 |
+
for c in candidates:
|
| 118 |
+
if abs(c - target) < abs(best - target):
|
| 119 |
+
best = c
|
| 120 |
+
# Keep bar from dominating the field
|
| 121 |
+
while best > 0 and best / field_width_um > 0.45:
|
| 122 |
+
best = max(0.05, best / 2)
|
| 123 |
+
return float(best)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _draw_scale_bar_um(ax, w: int, h: int, px_per_um: float) -> None:
    """Draw an outlined scale bar with a µm label near the lower-left corner.

    Parameters
    ----------
    ax : matplotlib Axes displaying the image in pixel coordinates.
    w, h : int — image width and height in pixels.
    px_per_um : float — pixel calibration used to size the bar.
    """
    bar_um = _pick_scale_bar_um(max(w, h) / px_per_um)
    bar_px = bar_um * px_per_um
    margin = max(12, int(min(w, h) * 0.025))
    y = h - margin
    x_start = margin
    x_end = margin + bar_px
    # Thick white under-stroke then thin dark over-stroke so the bar stays
    # legible on both bright and dark regions of the micrograph.
    ax.plot([x_start, x_end], [y, y], color="white", linewidth=5,
            solid_capstyle="butt", clip_on=False)
    ax.plot([x_start, x_end], [y, y], color="#0f172a", linewidth=2,
            solid_capstyle="butt", clip_on=False)
    label = ax.text(
        (x_start + x_end) / 2,
        y - margin * 0.35,
        f"{bar_um:g} µm",
        ha="center",
        va="bottom",
        color="white",
        fontsize=9,
        fontweight="600",
    )
    # Stroke outline on the text for the same contrast reason as the bar.
    label.set_path_effects([pe.withStroke(linewidth=2.5, foreground="#0f172a")])
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _export_columns() -> list[str]:
|
| 149 |
+
return [
|
| 150 |
+
"particle_id",
|
| 151 |
+
"receptor",
|
| 152 |
+
"gold_diameter_nm",
|
| 153 |
+
"x_px",
|
| 154 |
+
"y_px",
|
| 155 |
+
"x_um",
|
| 156 |
+
"y_um",
|
| 157 |
+
"confidence",
|
| 158 |
+
"class_model",
|
| 159 |
+
"calibration_px_per_um",
|
| 160 |
+
]
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _empty_results_df() -> pd.DataFrame:
    """Return a zero-row DataFrame that carries the standard export columns."""
    schema = _export_columns()
    return pd.DataFrame(columns=schema)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _df_to_preview_html(df: pd.DataFrame) -> str:
|
| 168 |
+
if df is None or len(df) == 0:
|
| 169 |
+
return "<p class='mm-table-empty'><em>No particles above the current threshold.</em></p>"
|
| 170 |
+
return df.to_html(
|
| 171 |
+
classes=["mm-table"],
|
| 172 |
+
index=False,
|
| 173 |
+
border=0,
|
| 174 |
+
justify="left",
|
| 175 |
+
escape=True,
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def detect_particles(
    image_file,
    conf_threshold: float = 0.25,
    nms_6nm: int = 3,
    nms_12nm: int = 5,
    px_per_um: float = DEFAULT_PX_PER_UM,
    progress=gr.Progress(track_tqdm=False),
):
    """Run detection; returns figures, CSV path, table HTML, and summary HTML.

    Parameters
    ----------
    image_file : str | file-like | array-like
        Path string, uploaded file object (read via ``.name``), or raw array.
    conf_threshold : float
        Minimum peak confidence kept as a detection.
    nms_6nm, nms_12nm : int
        Per-class NMS kernel sizes passed to ``extract_peaks``.
    px_per_um : float
        Pixel calibration; invalid or non-positive values fall back to
        ``DEFAULT_PX_PER_UM``.
    progress : gr.Progress
        Gradio progress callback.

    Returns
    -------
    tuple
        (overlay RGB array, heatmap RGB array, stats RGB array, CSV file path,
        table HTML, summary HTML). On error paths the image/CSV slots are None.
    """
    empty_table = "<p class='mm-table-empty'><em>Run detection to populate the table.</em></p>"

    # Guard: model must have been loaded at startup (see load_model).
    if MODEL is None:
        msg = "<p class='mm-callout mm-callout-warn'>Model not loaded. Use <code>--checkpoint</code> with a valid <code>.pth</code> file.</p>"
        return None, None, None, None, empty_table, msg

    # Guard: nothing uploaded yet.
    if image_file is None:
        msg = "<p class='mm-callout'>Upload a micrograph, set calibration if needed, then run detection.</p>"
        return None, None, None, None, empty_table, msg

    # Sanitize calibration coming from a free-text UI field.
    try:
        px_per_um = float(px_per_um)
    except (TypeError, ValueError):
        px_per_um = DEFAULT_PX_PER_UM
    if px_per_um <= 0:
        px_per_um = DEFAULT_PX_PER_UM

    progress(0.05, desc="Loading image…")

    # Accept a path, a Gradio upload object (has .name), or a raw array.
    if isinstance(image_file, str):
        img = tifffile.imread(image_file)
    elif hasattr(image_file, "name"):
        img = tifffile.imread(image_file.name)
    else:
        img = np.array(image_file)

    # Collapse multi-channel/stack input to a single 2-D plane:
    # <=4 channels → take channel 0; otherwise treat axis 0 as a stack.
    if img.ndim == 3:
        img = img[:, :, 0] if img.shape[2] <= 4 else img[0]
    # NOTE(review): astype(np.uint8) truncates rather than rescales — 16-bit
    # TIFF input would be mangled here; presumably inputs are 8-bit. Confirm.
    img = img.astype(np.uint8)

    h, w = img.shape[:2]
    field_w_um = w / px_per_um
    field_h_um = h / px_per_um

    progress(0.15, desc="Neural network (sliding window)…")

    # Tile the full micrograph through the network; returns per-class
    # heatmaps and offset maps as numpy arrays.
    with torch.no_grad():
        hm_np, off_np = sliding_window_inference(
            MODEL,
            img,
            patch_size=512,
            overlap=128,
            device=DEVICE,
        )

    progress(0.72, desc="Peak extraction & NMS…")

    # Heatmap peaks → sub-pixel detections (stride 2 heatmap space),
    # with per-class NMS kernels, then cross-class deduplication.
    dets = extract_peaks(
        torch.from_numpy(hm_np),
        torch.from_numpy(off_np),
        stride=2,
        conf_threshold=conf_threshold,
        nms_kernel_sizes={"6nm": nms_6nm, "12nm": nms_12nm},
    )
    dets = cross_class_nms(dets, distance_threshold=8)

    # Per-class tallies and confidence lists for the stats panels.
    n_6nm = sum(1 for d in dets if d["class"] == "6nm")
    n_12nm = sum(1 for d in dets if d["class"] == "12nm")
    confs_6 = [d["conf"] for d in dets if d["class"] == "6nm"]
    confs_12 = [d["conf"] for d in dets if d["class"] == "12nm"]

    progress(0.78, desc="Rendering figures…")

    from skimage.transform import resize

    # Upsample the stride-2 heatmaps back to image resolution for overlay.
    hm6_up = resize(hm_np[0], (h, w), order=1)
    hm12_up = resize(hm_np[1], (h, w), order=1)

    # --- Overlay (publication-style legend + scale bar) ---
    fig_overlay, ax = plt.subplots(figsize=(11, 11))
    ax.imshow(img, cmap="gray", aspect="equal")
    for d in dets:
        # Cyan = 6 nm / AMPA, amber = 12 nm / NR1; marker radius in pixels.
        color = "#06b6d4" if d["class"] == "6nm" else "#ca8a04"
        radius = 7 if d["class"] == "6nm" else 12
        ax.add_patch(
            plt.Circle(
                (d["x"], d["y"]),
                radius,
                fill=False,
                edgecolor=color,
                linewidth=1.8,
            )
        )
    _draw_scale_bar_um(ax, w, h, px_per_um)
    ax.set_title(
        f"Immunogold detections · AMPA (6 nm): {n_6nm} · NR1 (12 nm): {n_12nm} · Total: {len(dets)}",
        fontsize=11,
        pad=12,
    )
    ax.axis("off")
    legend_elems = [
        Patch(facecolor="none", edgecolor="#06b6d4", linewidth=2, label="6 nm gold — AMPA receptor"),
        Patch(facecolor="none", edgecolor="#ca8a04", linewidth=2, label="12 nm gold — NR1 (NMDAR)"),
    ]
    ax.legend(
        handles=legend_elems,
        loc="upper right",
        fontsize=8.5,
        title="Label class",
        title_fontsize=9,
    )
    plt.tight_layout()
    # Rasterize the Agg canvas to an RGB array (Gradio displays numpy images).
    fig_overlay.canvas.draw()
    overlay_img = np.asarray(fig_overlay.canvas.renderer.buffer_rgba())[:, :, :3]
    plt.close(fig_overlay)

    # --- Heatmaps ---
    fig_hm, axes = plt.subplots(1, 2, figsize=(14, 6.2))
    axes[0].imshow(img, cmap="gray", aspect="equal")
    # vmax floor of 0.3 keeps near-empty heatmaps from blowing up to full scale.
    axes[0].imshow(hm6_up, cmap="magma", alpha=0.55, vmin=0, vmax=max(0.3, float(hm6_up.max())))
    axes[0].set_title(f"AMPA (6 nm) channel · n = {n_6nm}", fontsize=11)
    axes[0].axis("off")

    axes[1].imshow(img, cmap="gray", aspect="equal")
    axes[1].imshow(hm12_up, cmap="inferno", alpha=0.55, vmin=0, vmax=max(0.3, float(hm12_up.max())))
    axes[1].set_title(f"NR1 (12 nm) channel · n = {n_12nm}", fontsize=11)
    axes[1].axis("off")
    plt.tight_layout()
    fig_hm.canvas.draw()
    heatmap_img = np.asarray(fig_hm.canvas.renderer.buffer_rgba())[:, :, :3]
    plt.close(fig_hm)

    # --- Stats (µm where helpful) ---
    fig_stats, axes = plt.subplots(1, 3, figsize=(16, 4.8))
    # Panel 0: per-class confidence histograms with the threshold marked.
    if dets:
        if confs_6:
            axes[0].hist(confs_6, bins=18, alpha=0.75, color="#0891b2", label=f"AMPA (n={len(confs_6)})")
        if confs_12:
            axes[0].hist(confs_12, bins=18, alpha=0.75, color="#a16207", label=f"NR1 (n={len(confs_12)})")
        axes[0].axvline(conf_threshold, color="#be123c", linestyle="--", linewidth=1.2, label=f"Threshold = {conf_threshold:.2f}")
        axes[0].legend(fontsize=8)
    axes[0].set_xlabel("Confidence score")
    axes[0].set_ylabel("Count")
    axes[0].set_title("Score distribution")
    axes[0].spines["top"].set_visible(False)
    axes[0].spines["right"].set_visible(False)

    # Panel 1: particle positions in calibrated µm (y inverted to match image).
    if dets:
        xs_um = np.array([d["x"] for d in dets]) / px_per_um
        ys_um = np.array([d["y"] for d in dets]) / px_per_um
        colors = ["#0891b2" if d["class"] == "6nm" else "#a16207" for d in dets]
        axes[1].scatter(xs_um, ys_um, c=colors, s=22, alpha=0.75, edgecolors="none")
    axes[1].set_xlim(0, field_w_um)
    axes[1].set_ylim(field_h_um, 0)
    axes[1].set_xlabel("x (µm)")
    axes[1].set_ylabel("y (µm)")
    axes[1].set_title("Positions (image coordinates)")
    axes[1].set_aspect("equal")
    axes[1].spines["top"].set_visible(False)
    axes[1].spines["right"].set_visible(False)

    # Panel 2: text summary rendered as a matplotlib table.
    axes[2].axis("off")
    table_data = [
        ["Field of view", f"{field_w_um:.3f} × {field_h_um:.3f} µm"],
        ["Calibration", f"{px_per_um:.1f} px/µm"],
        ["AMPA (6 nm)", str(n_6nm)],
        ["NR1 (12 nm)", str(n_12nm)],
        ["Total particles", str(len(dets))],
        ["Score threshold", f"{conf_threshold:.2f}"],
        ["Mean score · AMPA", f"{float(np.mean(confs_6)):.3f}" if confs_6 else "—"],
        ["Mean score · NR1", f"{float(np.mean(confs_12)):.3f}" if confs_12 else "—"],
    ]
    tbl = axes[2].table(
        cellText=table_data,
        colLabels=["Quantity", "Value"],
        loc="center",
        cellLoc="left",
    )
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(10)
    tbl.scale(1.05, 1.65)
    # Bold + shade the header row only (row 0); `col` is unused by design.
    for (row, col), cell in tbl.get_celld().items():
        if row == 0:
            cell.set_text_props(fontweight="600")
            cell.set_facecolor("#e2e8f0")
    axes[2].set_title("Summary", fontsize=11, pad=12)
    plt.tight_layout()
    fig_stats.canvas.draw()
    stats_img = np.asarray(fig_stats.canvas.renderer.buffer_rgba())[:, :, :3]
    plt.close(fig_stats)

    # Build the export rows: pixel + calibrated µm coordinates per particle.
    rows = []
    for i, d in enumerate(dets):
        rows.append(
            {
                "particle_id": i + 1,
                "receptor": _receptor_label(d["class"]),
                "gold_diameter_nm": _gold_nm(d["class"]),
                "x_px": round(d["x"], 2),
                "y_px": round(d["y"], 2),
                "x_um": round(d["x"] / px_per_um, 5),
                "y_um": round(d["y"] / px_per_um, 5),
                "confidence": round(d["conf"], 4),
                "class_model": d["class"],
                "calibration_px_per_um": round(px_per_um, 4),
            }
        )
    df = pd.DataFrame(rows, columns=_export_columns()) if rows else _empty_results_df()

    # delete=False so Gradio can serve the file after we close the handle.
    # NOTE(review): temp CSVs are never cleaned up — acceptable for a Space,
    # but worth a periodic cleanup if run long-lived.
    csv_f = tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w", encoding="utf-8")
    df.to_csv(csv_f.name, index=False)
    csv_f.close()

    progress(1.0, desc="Done")

    # Areal density in particles/µm² (guarded against a degenerate field).
    density_note = ""
    if field_w_um > 0 and field_h_um > 0:
        area = field_w_um * field_h_um
        density_note = f"<span class='mm-density'>Areal density (all): {len(dets) / area:.2f} particles/µm² · AMPA: {n_6nm / area:.2f} · NR1: {n_12nm / area:.2f}</span>"

    # `density_note and f'...'` renders the wrapper div only when the note is
    # non-empty (empty string short-circuits to "" inside the f-string).
    summary = f"""<div class="mm-summary">
<div class="mm-stat"><span class="mm-stat-label">AMPA · 6 nm gold</span>
<span class="mm-stat-value mm-teal">{n_6nm}</span></div>
<div class="mm-stat"><span class="mm-stat-label">NR1 · 12 nm gold</span>
<span class="mm-stat-value mm-amber">{n_12nm}</span></div>
<div class="mm-stat"><span class="mm-stat-label">Total</span>
<span class="mm-stat-value">{len(dets)}</span></div>
<div class="mm-stat mm-stat-wide"><span class="mm-stat-label">Field & calibration</span>
<span class="mm-stat-meta">{field_w_um:.3f} × {field_h_um:.3f} µm · {px_per_um:.1f} px/µm · {DEVICE}</span></div>
{density_note and f'<div class="mm-stat mm-stat-wide">{density_note}</div>'}
</div>"""

    return overlay_img, heatmap_img, stats_img, csv_f.name, _df_to_preview_html(df), summary
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
MM_CSS = """
|
| 414 |
+
.gradio-container { max-width: 1320px !important; margin: auto !important; }
|
| 415 |
+
.mm-brand-bar {
|
| 416 |
+
display: flex; align-items: center; justify-content: space-between;
|
| 417 |
+
flex-wrap: wrap; gap: 0.75rem;
|
| 418 |
+
padding: 0.6rem 0 1.25rem;
|
| 419 |
+
border-bottom: 1px solid var(--border-color-primary);
|
| 420 |
+
margin-bottom: 1.25rem;
|
| 421 |
+
}
|
| 422 |
+
.mm-brand-bar span {
|
| 423 |
+
font-size: 0.72rem; letter-spacing: 0.14em; text-transform: uppercase;
|
| 424 |
+
color: var(--body-text-color-subdued); font-weight: 600;
|
| 425 |
+
}
|
| 426 |
+
.mm-hero {
|
| 427 |
+
padding: 1.5rem 1.35rem 1.35rem;
|
| 428 |
+
margin-bottom: 0.25rem;
|
| 429 |
+
border-radius: 10px;
|
| 430 |
+
background: linear-gradient(145deg, #0c4a6e22 0%, #0f172a 48%, #1e1b4b33 100%);
|
| 431 |
+
border: 1px solid #33415588;
|
| 432 |
+
}
|
| 433 |
+
.mm-hero h1 {
|
| 434 |
+
font-family: "Libre Baskerville", Georgia, serif;
|
| 435 |
+
font-weight: 700;
|
| 436 |
+
letter-spacing: -0.02em;
|
| 437 |
+
margin: 0 0 0.4rem 0;
|
| 438 |
+
font-size: 1.65rem;
|
| 439 |
+
color: #f1f5f9;
|
| 440 |
+
}
|
| 441 |
+
.mm-hero .mm-sub {
|
| 442 |
+
margin: 0 0 0.85rem 0;
|
| 443 |
+
color: #94a3b8;
|
| 444 |
+
font-size: 0.92rem;
|
| 445 |
+
line-height: 1.55;
|
| 446 |
+
max-width: 58ch;
|
| 447 |
+
}
|
| 448 |
+
.mm-badge-row { display: flex; flex-wrap: wrap; gap: 0.4rem; }
|
| 449 |
+
.mm-badge {
|
| 450 |
+
font-size: 0.65rem; text-transform: uppercase; letter-spacing: 0.07em;
|
| 451 |
+
padding: 0.2rem 0.5rem; border-radius: 4px;
|
| 452 |
+
background: #0e749033; color: #99f6e4; border: 1px solid #14b8a644;
|
| 453 |
+
}
|
| 454 |
+
.mm-layout { display: flex; gap: 1.25rem; align-items: flex-start; flex-wrap: wrap; }
|
| 455 |
+
.mm-sidebar {
|
| 456 |
+
flex: 1 1 280px; max-width: 340px;
|
| 457 |
+
padding: 1rem 1.1rem; border-radius: 10px;
|
| 458 |
+
border: 1px solid var(--border-color-primary);
|
| 459 |
+
background: var(--block-background-fill);
|
| 460 |
+
}
|
| 461 |
+
.mm-main { flex: 3 1 520px; min-width: 0; }
|
| 462 |
+
.mm-panel-title {
|
| 463 |
+
font-size: 0.7rem; text-transform: uppercase; letter-spacing: 0.1em;
|
| 464 |
+
color: var(--body-text-color-subdued); font-weight: 600; margin: 0 0 0.65rem 0;
|
| 465 |
+
}
|
| 466 |
+
.mm-callout {
|
| 467 |
+
margin: 0; padding: 0.75rem 0.9rem; border-radius: 8px;
|
| 468 |
+
background: #1e293b66; border: 1px solid var(--border-color-primary);
|
| 469 |
+
font-size: 0.88rem; line-height: 1.45; color: var(--body-text-color);
|
| 470 |
+
}
|
| 471 |
+
.mm-callout-warn { border-color: #f59e0b55; background: #78350f22; }
|
| 472 |
+
.mm-science {
|
| 473 |
+
margin-top: 1rem; font-size: 0.82rem; line-height: 1.5;
|
| 474 |
+
color: var(--body-text-color-subdued);
|
| 475 |
+
}
|
| 476 |
+
.mm-science h4 { margin: 0.5rem 0 0.35rem; font-size: 0.78rem; text-transform: uppercase; letter-spacing: 0.06em; color: #94a3b8; }
|
| 477 |
+
.mm-science ul { margin: 0.25rem 0 0 1rem; padding: 0; }
|
| 478 |
+
.mm-summary { display: flex; flex-wrap: wrap; gap: 0.65rem; margin: 0 0 1rem 0; }
|
| 479 |
+
.mm-stat {
|
| 480 |
+
flex: 1 1 118px; padding: 0.75rem 0.95rem; border-radius: 8px;
|
| 481 |
+
background: var(--block-background-fill);
|
| 482 |
+
border: 1px solid var(--border-color-primary);
|
| 483 |
+
}
|
| 484 |
+
.mm-stat-wide { flex: 1 1 100%; }
|
| 485 |
+
.mm-stat-label {
|
| 486 |
+
display: block; font-size: 0.68rem; text-transform: uppercase;
|
| 487 |
+
letter-spacing: 0.06em; opacity: 0.72; margin-bottom: 0.2rem;
|
| 488 |
+
}
|
| 489 |
+
.mm-stat-value { font-size: 1.4rem; font-weight: 700; font-variant-numeric: tabular-nums; letter-spacing: -0.02em; }
|
| 490 |
+
.mm-stat-value.mm-teal { color: #2dd4bf; }
|
| 491 |
+
.mm-stat-value.mm-amber { color: #fbbf24; }
|
| 492 |
+
.mm-stat-meta { font-size: 0.84rem; opacity: 0.92; line-height: 1.35; }
|
| 493 |
+
.mm-density { font-size: 0.84rem; opacity: 0.9; }
|
| 494 |
+
table.mm-table {
|
| 495 |
+
width: 100%; border-collapse: collapse; font-size: 0.82rem;
|
| 496 |
+
margin: 0.25rem 0 0.75rem 0;
|
| 497 |
+
}
|
| 498 |
+
table.mm-table th {
|
| 499 |
+
text-align: left; padding: 0.45rem 0.5rem;
|
| 500 |
+
border-bottom: 1px solid var(--border-color-primary);
|
| 501 |
+
color: var(--body-text-color-subdued); font-weight: 600;
|
| 502 |
+
}
|
| 503 |
+
table.mm-table td { padding: 0.35rem 0.5rem; border-bottom: 1px solid #33415544; }
|
| 504 |
+
.mm-table-empty { margin: 0.5rem 0; opacity: 0.75; font-size: 0.9rem; }
|
| 505 |
+
.mm-foot {
|
| 506 |
+
margin-top: 2rem; padding-top: 1rem;
|
| 507 |
+
border-top: 1px solid var(--border-color-primary);
|
| 508 |
+
font-size: 0.78rem; line-height: 1.45;
|
| 509 |
+
color: var(--body-text-color-subdued);
|
| 510 |
+
}
|
| 511 |
+
.mm-foot code { font-size: 0.76rem; }
|
| 512 |
+
"""
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def build_app():
    """Construct and return the MidasMap Gradio Blocks app.

    Layout: a sidebar column (file upload, calibration, confidence and NMS
    controls, interpretation notes) and a main column with four result tabs
    (overlay, heatmaps, quant summary, table/CSV export). The "Run detection"
    button is wired to ``detect_particles`` defined elsewhere in this file.

    Returns:
        gr.Blocks: the assembled (not yet launched) application.
    """
    # Teal-on-slate theme; the custom Color ramp matches the MM_CSS accents.
    theme = gr.themes.Soft(
        primary_hue=gr.themes.Color(
            c50="#f0fdfa",
            c100="#ccfbf1",
            c200="#99f6e4",
            c300="#5eead4",
            c400="#2dd4bf",
            c500="#14b8a6",
            c600="#0d9488",
            c700="#0f766e",
            c800="#115e59",
            c900="#134e4a",
            c950="#042f2e",
        ),
        neutral_hue=gr.themes.colors.slate,
        font=("Source Sans 3", "ui-sans-serif", "system-ui", "sans-serif"),
        font_mono=("IBM Plex Mono", "ui-monospace", "monospace"),
    ).set(
        body_background_fill_dark="*neutral_950",
        block_background_fill_dark="*neutral_900",
        border_color_primary="*neutral_700",
        button_primary_background_fill="*primary_600",
        button_primary_background_fill_hover="*primary_500",
        block_label_text_size="*text_sm",
    )

    # Extra <head> markup: load the web fonts referenced by the theme above.
    head = """
    <link href="https://fonts.googleapis.com/css2?family=Libre+Baskerville:wght@700&family=Source+Sans+3:wght@400;600;700&display=swap" rel="stylesheet">
    """

    with gr.Blocks(
        title="MidasMap — Immunogold analysis",
        theme=theme,
        css=MM_CSS,
        head=head,
    ) as app:
        # Static hero/branding banner.
        gr.HTML(
            """
            <div class="mm-brand-bar">
              <span>Quantitative EM · synapse immunogold</span>
              <span>Research use · validate critical counts manually</span>
            </div>
            <div class="mm-hero">
              <h1>MidasMap</h1>
              <p class="mm-sub">
                Automated particle picking for <strong>freeze-fracture replica immunolabeling (FFRIL)</strong> TEM:
                <strong>6 nm</strong> gold (AMPA receptors) and <strong>12 nm</strong> gold (NR1 / NMDA receptors).
                Coordinates export in <strong>µm</strong> for comparison to physiology and super-resolution data—set calibration to match your microscope.
              </p>
              <div class="mm-badge-row">
                <span class="mm-badge">FFRIL / TEM</span>
                <span class="mm-badge">CenterNet</span>
                <span class="mm-badge">CEM500K backbone</span>
                <span class="mm-badge">LOOCV F1 ≈ 0.94</span>
              </div>
            </div>
            """
        )

        with gr.Row(elem_classes=["mm-layout"]):
            # --- Left sidebar: inputs and tuning controls -----------------
            with gr.Column(elem_classes=["mm-sidebar"]):
                gr.HTML('<p class="mm-panel-title">Micrograph & calibration</p>')
                image_input = gr.File(
                    label="Upload image",
                    file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"],
                )
                # Pixel-to-µm calibration; used for CSV export downstream.
                px_per_um_in = gr.Number(
                    value=DEFAULT_PX_PER_UM,
                    label="Calibration (pixels per µm)",
                    info=f"Default {DEFAULT_PX_PER_UM:.0f} matches the published training set. "
                    "Update if your acquisition scale differs.",
                    minimum=1,
                    maximum=1e6,
                )
                conf_slider = gr.Slider(
                    minimum=0.05,
                    maximum=0.95,
                    value=0.25,
                    step=0.05,
                    label="Confidence threshold",
                    info="Higher → fewer, sharper peaks. Lower → recall with more false positives.",
                )
                # Per-class NMS kernel sizes (odd values only, hence step=2).
                with gr.Accordion("Advanced · non-max suppression", open=False):
                    nms_6nm = gr.Slider(
                        minimum=1,
                        maximum=9,
                        value=3,
                        step=2,
                        label="NMS · 6 nm channel",
                        info="Minimum spacing between AMPA peaks on the heatmap grid.",
                    )
                    nms_12nm = gr.Slider(
                        minimum=1,
                        maximum=9,
                        value=5,
                        step=2,
                        label="NMS · 12 nm channel",
                    )
                detect_btn = gr.Button("Run detection", variant="primary", size="lg")

                with gr.Accordion("For neuroscientists — interpretation", open=False):
                    gr.Markdown(
                        """
                        #### What the model outputs
                        - **Circles** mark predicted gold centers; **scores** are CNN confidences, not p-values.
                        - **AMPA** = 6 nm class; **NR1** = 12 nm class (NMDA receptor subunit). Verify ambiguous sites on the raw image.

                        #### When to trust it
                        - Trained on **10 FFRIL synapse images** (453 hand-placed particles). Expect best performance on **similar prep, contrast, and magnification**.
                        - **Always spot-check** counts used for publication, especially near membranes and dense clusters.

                        #### Coordinates & CSV
                        - **x, y** follow image pixel order (origin top-left). **µm** columns use your calibration above.
                        - CSV includes **receptor**, **gold diameter**, and **calibration** used for provenance.

                        #### Citation
                        Sahai, A. (2026). *MidasMap* (software). https://github.com/AnikS22/MidasMap
                        """
                    )

            # --- Main column: summary banner + result tabs ----------------
            with gr.Column(elem_classes=["mm-main"]):
                summary_md = gr.HTML(
                    value="<p class='mm-callout'>Upload a synapse micrograph to begin. Adjust calibration before export if your scale differs from the default.</p>"
                )
                with gr.Tabs():
                    with gr.Tab("Overlay"):
                        overlay_output = gr.Image(
                            label="Detections + scale bar",
                            type="numpy",
                            height=540,
                        )
                    with gr.Tab("Heatmaps"):
                        heatmap_output = gr.Image(
                            label="Class-specific maps",
                            type="numpy",
                            height=540,
                        )
                    with gr.Tab("Quant summary"):
                        stats_output = gr.Image(
                            label="Distributions & table",
                            type="numpy",
                            height=440,
                        )
                    with gr.Tab("Table & export"):
                        table_output = gr.HTML(
                            label="Detections (preview)",
                            value="<p class='mm-table-empty'><em>Results appear here after detection.</em></p>",
                        )
                        csv_output = gr.File(label="Download CSV")

        # Footer with provenance notes and model-weight pointers.
        gr.HTML(
            f"""
            <div class="mm-foot">
              <strong>Training context:</strong> LOOCV mean F1 ≈ 0.94 on eight well-annotated folds;
              raw grayscale input (avoid heavy filtering). Not a clinical device.
              Model weights: <code>checkpoints/final/final_model.pth</code> or
              <a href="https://huggingface.co/AnikS22/MidasMap" target="_blank" rel="noopener">Hugging Face</a>.
            </div>
            """
        )

        # Wire the single action: inputs → detect_particles → all result widgets.
        detect_btn.click(
            fn=detect_particles,
            inputs=[image_input, conf_slider, nms_6nm, nms_12nm, px_per_um_in],
            outputs=[
                overlay_output,
                heatmap_output,
                stats_output,
                csv_output,
                table_output,
                summary_md,
            ],
        )

    return app
|
| 691 |
+
|
| 692 |
+
|
| 693 |
+
def main():
    """CLI entry point for the MidasMap web dashboard.

    Parses command-line options, verifies the checkpoint exists, loads the
    model once, builds the Gradio app, and launches it. If Gradio's
    localhost accessibility check fails (common in sandboxed hosts), the
    app is relaunched once with a public share tunnel.
    """
    cli = argparse.ArgumentParser(description="MidasMap web dashboard")
    cli.add_argument(
        "--checkpoint",
        type=str,
        default="checkpoints/final/final_model.pth",
        help="Path to trained checkpoint (.pth)",
    )
    cli.add_argument("--share", action="store_true", help="Gradio public share link (use if localhost is blocked)")
    cli.add_argument(
        "--server-name",
        type=str,
        default=None,
        metavar="HOST",
        help='Bind address, e.g. 0.0.0.0 for LAN (default: 127.0.0.1)',
    )
    cli.add_argument("--port", type=int, default=7860)
    opts = cli.parse_args()

    # Environment override so a Space / container can force sharing.
    if os.environ.get("GRADIO_SHARE", "").lower() in ("1", "true", "yes"):
        opts.share = True

    checkpoint = Path(opts.checkpoint)
    if not checkpoint.is_file():
        raise SystemExit(
            f"Checkpoint not found: {checkpoint}\n"
            "Train with train_final.py or download from Hugging Face:\n"
            "  huggingface-cli download AnikS22/MidasMap checkpoints/final/final_model.pth "
            "--local-dir ."
        )

    # Load weights once up front, then serve.
    load_model(str(checkpoint))
    demo = build_app()
    launch_options = dict(
        share=opts.share,
        server_port=opts.port,
        server_name=opts.server_name,
        show_api=False,
        inbrowser=False,
    )

    try:
        demo.launch(**launch_options)
    except ValueError as err:
        # Retry with a share tunnel only when the failure is specifically the
        # localhost reachability check and sharing was not already requested.
        localhost_blocked = "localhost is not accessible" in str(err)
        share_forced = os.environ.get("GRADIO_SHARE", "").lower() in ("1", "true", "yes")
        if localhost_blocked and not launch_options.get("share") and not share_forced:
            print(
                "Localhost check failed in this environment; starting with share=True "
                "(Gradio tunnel). Use --share next time, or set GRADIO_SHARE=1."
            )
            build_app().launch(**{**launch_options, "share": True})
        else:
            raise
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
# Script entry point: only run the CLI when executed directly, not on import
# (e.g. when a Space imports app.py).
if __name__ == "__main__":
    main()
|
checkpoints/final/final_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:735d37e839019318cb0e4c7e40d99194abd59f57efd8594ca51602ce3451dfb6
|
| 3 |
+
size 98043418
|
requirements.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Pin versions compatible with HF Spaces (adjust if Space build fails).
|
| 2 |
+
# Keep this file as requirements.txt at the Space repo root.
|
| 3 |
+
numpy>=1.24,<2
|
| 4 |
+
torch>=2.0.0
|
| 5 |
+
torchvision>=0.15.0
|
| 6 |
+
scipy>=1.10.0
|
| 7 |
+
scikit-image>=0.21.0
|
| 8 |
+
matplotlib>=3.7.0
|
| 9 |
+
tifffile>=2023.4.0
|
| 10 |
+
pandas>=2.0.0
|
| 11 |
+
PyYAML>=6.0
|
| 12 |
+
albumentations>=1.3.0
|
| 13 |
+
opencv-python-headless>=4.7.0
|
| 14 |
+
gradio==4.44.1
|
| 15 |
+
huggingface_hub>=0.20.0,<0.25.0
|
| 16 |
+
tqdm>=4.65.0
|
src/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Immunogold particle detection system for TEM images."""
|
src/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (224 Bytes). View file
|
|
|
src/__pycache__/ensemble.cpython-311.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
src/__pycache__/evaluate.cpython-311.pyc
ADDED
|
Binary file (8.73 kB). View file
|
|
|
src/__pycache__/heatmap.cpython-311.pyc
ADDED
|
Binary file (7.64 kB). View file
|
|
|
src/__pycache__/loss.cpython-311.pyc
ADDED
|
Binary file (5.31 kB). View file
|
|
|
src/__pycache__/model.cpython-311.pyc
ADDED
|
Binary file (22.3 kB). View file
|
|
|
src/dataset.py
ADDED
|
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PyTorch Dataset for immunogold particle detection.
|
| 3 |
+
|
| 4 |
+
Implements patch-based training with:
|
| 5 |
+
- 70% hard mining (patches centered near particles)
|
| 6 |
+
- 30% random patches (background recognition)
|
| 7 |
+
- Copy-paste augmentation with Gaussian-blended bead bank
|
| 8 |
+
- Albumentations pipeline with keypoint co-transforms
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import random
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Dict, List, Optional, Tuple
|
| 14 |
+
|
| 15 |
+
import albumentations as A
|
| 16 |
+
import cv2
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
from torch.utils.data import Dataset
|
| 20 |
+
|
| 21 |
+
from src.heatmap import generate_heatmap_gt
|
| 22 |
+
from src.preprocessing import (
|
| 23 |
+
SynapseRecord,
|
| 24 |
+
load_all_annotations,
|
| 25 |
+
load_image,
|
| 26 |
+
load_mask,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
# Augmentation pipeline
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
|
| 34 |
+
def get_train_augmentation() -> A.Compose:
    """Build the training-time augmentation pipeline.

    Geometric transforms co-transform keypoints via ``KeypointParams``;
    intensity transforms are deliberately conservative (±0.08 on
    brightness/contrast). Cutout/Mixup/JPEG-artifact augmentations are
    intentionally excluded because they can erase or fabricate
    particle-like blobs.
    """
    geometric = [
        # EM micrographs have no preferred orientation.
        A.RandomRotate90(p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        # Small-angle rotation only, to limit interpolation artifacts.
        A.Rotate(
            limit=10,
            border_mode=cv2.BORDER_REFLECT_101,
            p=0.5,
        ),
        # Mild warp — approximates section-flatness variation.
        A.ElasticTransform(alpha=30, sigma=5, p=0.3),
    ]
    photometric = [
        # Far below the library defaults (0.2) on purpose.
        A.RandomBrightnessContrast(
            brightness_limit=0.08,
            contrast_limit=0.08,
            p=0.7,
        ),
        # Simulated EM shot noise.
        A.GaussNoise(p=0.5),
        # Slight defocus simulation.
        A.GaussianBlur(blur_limit=(3, 3), p=0.2),
    ]
    keypoint_cfg = A.KeypointParams(
        format="xy",
        remove_invisible=True,
        label_fields=["class_labels"],
    )
    return A.Compose(geometric + photometric, keypoint_params=keypoint_cfg)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_val_augmentation() -> A.Compose:
    """Return the validation pipeline: an identity transform that still
    carries keypoint parameters so keypoints pass through unchanged."""
    keypoint_cfg = A.KeypointParams(
        format="xy",
        remove_invisible=True,
        label_fields=["class_labels"],
    )
    return A.Compose([], keypoint_params=keypoint_cfg)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# ---------------------------------------------------------------------------
|
| 87 |
+
# Bead bank for copy-paste augmentation
|
| 88 |
+
# ---------------------------------------------------------------------------
|
| 89 |
+
|
| 90 |
+
class BeadBank:
    """
    Pre-extracted particle crops for copy-paste augmentation.

    Stores small patches centered on annotated particles from training
    images. During training, random beads are pasted onto patches to
    increase particle density and address class imbalance.

    NOTE(review): all blending below assumes single-channel uint8 images
    (grayscale TEM) — confirm if color inputs are ever possible.
    """

    def __init__(self):
        # Crops per gold-bead class: list of (crop pixels, half window size).
        self.crops: Dict[str, List[Tuple[np.ndarray, int]]] = {
            "6nm": [],
            "12nm": [],
        }
        # Square crop side length in pixels per class (12 nm beads get a
        # larger window).
        self.crop_sizes = {"6nm": 32, "12nm": 48}

    def extract_from_image(
        self,
        image: np.ndarray,
        annotations: Dict[str, np.ndarray],
    ):
        """Extract bead crops from a training image.

        Args:
            image: source micrograph (H, W).
            annotations: per-class arrays of (x, y) particle centers.
        """
        h, w = image.shape[:2]

        for cls, coords in annotations.items():
            crop_size = self.crop_sizes[cls]
            half = crop_size // 2

            for x, y in coords:
                xi, yi = int(round(x)), int(round(y))
                # Skip if too close to edge
                if yi - half < 0 or yi + half > h or xi - half < 0 or xi + half > w:
                    continue

                crop = image[yi - half : yi + half, xi - half : xi + half].copy()
                # Only keep exact-size grayscale crops.
                if crop.shape == (crop_size, crop_size):
                    self.crops[cls].append((crop, half))

    def paste_beads(
        self,
        image: np.ndarray,
        coords_6nm: List[Tuple[float, float]],
        coords_12nm: List[Tuple[float, float]],
        class_labels: List[str],
        mask: Optional[np.ndarray] = None,
        n_paste_per_class: int = 5,
        rng: Optional[np.random.Generator] = None,
    ) -> Tuple[np.ndarray, List[Tuple[float, float]], List[Tuple[float, float]], List[str]]:
        """
        Paste random beads onto image with Gaussian alpha blending.

        Args:
            image: patch to augment (copied; the input is not mutated).
            coords_6nm / coords_12nm: existing particle centers per class.
            class_labels: per-particle class labels (kept in sync with coords).
            mask: optional tissue mask; pastes outside it are skipped.
            n_paste_per_class: attempted pastes per class (attempts that land
                outside the mask or too close to existing particles are dropped).
            rng: optional NumPy Generator for reproducibility.

        Returns augmented image and updated coordinate lists.
        """
        if rng is None:
            rng = np.random.default_rng()

        image = image.copy()
        h, w = image.shape[:2]
        new_coords_6nm = list(coords_6nm)
        new_coords_12nm = list(coords_12nm)
        new_labels = list(class_labels)

        for cls in ["6nm", "12nm"]:
            if not self.crops[cls]:
                continue

            crop_size = self.crop_sizes[cls]
            half = crop_size // 2
            n_paste = min(n_paste_per_class, len(self.crops[cls]))

            for _ in range(n_paste):
                # Random paste location (within image bounds, 5px margin)
                px = rng.integers(half + 5, w - half - 5)
                py = rng.integers(half + 5, h - half - 5)

                # Skip if outside tissue mask
                if mask is not None:
                    if py >= mask.shape[0] or px >= mask.shape[1] or not mask[py, px]:
                        continue

                # Check minimum distance from existing particles (avoid overlap)
                too_close = False
                all_existing = new_coords_6nm + new_coords_12nm
                for ex, ey in all_existing:
                    if (ex - px) ** 2 + (ey - py) ** 2 < (half * 1.5) ** 2:
                        too_close = True
                        break
                if too_close:
                    continue

                # Select random crop
                crop, _ = self.crops[cls][rng.integers(len(self.crops[cls]))]

                # Gaussian alpha mask for soft blending (sigma tied to bead size)
                yy, xx = np.mgrid[:crop_size, :crop_size]
                center = crop_size / 2
                sigma = half * 0.7
                alpha = np.exp(-((xx - center) ** 2 + (yy - center) ** 2) / (2 * sigma ** 2))

                # Blend crop into the target region
                region = image[py - half : py + half, px - half : px + half]
                if region.shape != crop.shape:
                    continue
                blended = (alpha * crop + (1 - alpha) * region).astype(np.uint8)
                image[py - half : py + half, px - half : px + half] = blended

                # Add to annotations
                if cls == "6nm":
                    new_coords_6nm.append((float(px), float(py)))
                else:
                    new_coords_12nm.append((float(px), float(py)))
                new_labels.append(cls)

        return image, new_coords_6nm, new_coords_12nm, new_labels
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
# ---------------------------------------------------------------------------
|
| 207 |
+
# Dataset
|
| 208 |
+
# ---------------------------------------------------------------------------
|
| 209 |
+
|
| 210 |
+
class ImmunogoldDataset(Dataset):
    """
    Patch-based dataset for immunogold particle detection.

    Sampling strategy:
    - 70% of patches centered within 100px of a known particle (hard mining)
    - 30% of patches at random locations (background recognition)

    This ensures the model sees particles in nearly every batch despite
    particles occupying <0.1% of image area.
    """

    def __init__(
        self,
        records: List[SynapseRecord],
        fold_id: str,
        mode: str = "train",
        patch_size: int = 512,
        stride: int = 2,
        hard_mining_fraction: float = 0.7,
        copy_paste_per_class: int = 5,
        sigmas: Optional[Dict[str, float]] = None,
        samples_per_epoch: int = 200,
        seed: int = 42,
    ):
        """
        Args:
            records: all SynapseRecord entries
            fold_id: synapse_id to hold out (test set)
            mode: 'train' or 'val'
            patch_size: training patch size
            stride: model output stride
            hard_mining_fraction: fraction of patches near particles
            copy_paste_per_class: beads to paste per class
            sigmas: heatmap Gaussian sigmas per class
            samples_per_epoch: virtual epoch size
            seed: random seed
        """
        super().__init__()
        self.patch_size = patch_size
        self.stride = stride
        self.hard_mining_fraction = hard_mining_fraction
        # Copy-paste is a training-only augmentation — disabled for val/test.
        self.copy_paste_per_class = copy_paste_per_class if mode == "train" else 0
        self.sigmas = sigmas or {"6nm": 1.0, "12nm": 1.5}
        self.samples_per_epoch = samples_per_epoch
        self.mode = mode
        self._base_seed = seed
        self.rng = np.random.default_rng(seed)

        # Split records (LOOCV: one held-out synapse per fold)
        if mode == "train":
            self.records = [r for r in records if r.synapse_id != fold_id]
        elif mode == "val":
            self.records = [r for r in records if r.synapse_id == fold_id]
        else:
            self.records = records

        # Pre-load all images and annotations into memory (~4MB each × 10 = 40MB)
        self.images = {}
        self.masks = {}
        self.annotations = {}

        for record in self.records:
            sid = record.synapse_id
            self.images[sid] = load_image(record.image_path)
            if record.mask_path:
                self.masks[sid] = load_mask(record.mask_path)
            self.annotations[sid] = load_all_annotations(record, self.images[sid].shape)

        # Build particle index for hard mining
        self._build_particle_index()

        # Build bead bank for copy-paste (training images only)
        self.bead_bank = BeadBank()
        if mode == "train":
            for sid in self.images:
                self.bead_bank.extract_from_image(
                    self.images[sid], self.annotations[sid]
                )

        # Augmentation: full pipeline for train, identity for val
        if mode == "train":
            self.transform = get_train_augmentation()
        else:
            self.transform = get_val_augmentation()

    def _build_particle_index(self):
        """Build flat index of all particles for hard mining."""
        self.particle_list = []  # (synapse_id, x, y, class)
        for sid, annots in self.annotations.items():
            for cls in ["6nm", "12nm"]:
                for x, y in annots[cls]:
                    self.particle_list.append((sid, x, y, cls))

    @staticmethod
    def worker_init_fn(worker_id: int):
        """Re-seed RNG per DataLoader worker to avoid identical sequences."""
        import torch
        seed = torch.initial_seed() % (2**32) + worker_id
        np.random.seed(seed)

    def __len__(self) -> int:
        # Virtual epoch size; patches are sampled on the fly in __getitem__.
        return self.samples_per_epoch

    def __getitem__(self, idx: int) -> dict:
        # Reseed RNG using idx so each call produces a unique patch.
        # Without this, the same 200 patches repeat every epoch → instant overfitting.
        self.rng = np.random.default_rng(self._base_seed + idx + int(torch.initial_seed() % 100000))
        # NOTE(review): the triple-quoted string below follows a statement, so
        # it is a no-op string expression, NOT a docstring — doc tools will
        # ignore it. Consider moving it to the top of the method.
        """
        Sample a patch with ground truth heatmap.

        Returns dict with:
            'image': (1, patch_size, patch_size) float32 tensor
            'heatmap': (2, patch_size//stride, patch_size//stride) float32
            'offsets': (2, patch_size//stride, patch_size//stride) float32
            'offset_mask': (patch_size//stride, patch_size//stride) bool
            'conf_map': (2, patch_size//stride, patch_size//stride) float32
        """
        # Decide: hard or random patch
        do_hard = (self.rng.random() < self.hard_mining_fraction
                   and len(self.particle_list) > 0
                   and self.mode == "train")

        if do_hard:
            # Pick random particle, center patch on it with jitter
            pidx = self.rng.integers(len(self.particle_list))
            sid, px, py, _ = self.particle_list[pidx]
            # Jitter center up to 128px
            jitter = 128
            cx = int(px + self.rng.integers(-jitter, jitter + 1))
            cy = int(py + self.rng.integers(-jitter, jitter + 1))
        else:
            # Random image and location
            sid = list(self.images.keys())[
                self.rng.integers(len(self.images))
            ]
            h, w = self.images[sid].shape[:2]
            cx = self.rng.integers(self.patch_size // 2, w - self.patch_size // 2)
            cy = self.rng.integers(self.patch_size // 2, h - self.patch_size // 2)

        # Extract patch
        image = self.images[sid]
        h, w = image.shape[:2]
        half = self.patch_size // 2

        # Clamp to image bounds
        cx = max(half, min(w - half, cx))
        cy = max(half, min(h - half, cy))

        x0, x1 = cx - half, cx + half
        y0, y1 = cy - half, cy + half

        patch = image[y0:y1, x0:x1].copy()

        # Pad if needed (edge cases where the clamp still undershoots)
        if patch.shape[0] != self.patch_size or patch.shape[1] != self.patch_size:
            padded = np.zeros((self.patch_size, self.patch_size), dtype=np.uint8)
            ph, pw = patch.shape[:2]
            padded[:ph, :pw] = patch
            patch = padded

        # Get annotations within this patch (convert to patch-local coordinates)
        keypoints = []
        class_labels = []
        for cls in ["6nm", "12nm"]:
            for ax, ay in self.annotations[sid][cls]:
                # Convert to patch-local coords
                lx = ax - x0
                ly = ay - y0
                if 0 <= lx < self.patch_size and 0 <= ly < self.patch_size:
                    keypoints.append((lx, ly))
                    class_labels.append(cls)

        # Copy-paste augmentation (before geometric transforms)
        if self.copy_paste_per_class > 0 and self.mode == "train":
            local_6nm = [(x, y) for (x, y), c in zip(keypoints, class_labels) if c == "6nm"]
            local_12nm = [(x, y) for (x, y), c in zip(keypoints, class_labels) if c == "12nm"]
            mask_patch = None
            if sid in self.masks:
                mask_patch = self.masks[sid][y0:y1, x0:x1]

            patch, local_6nm, local_12nm, class_labels = self.bead_bank.paste_beads(
                patch, local_6nm, local_12nm, class_labels,
                mask=mask_patch,
                n_paste_per_class=self.copy_paste_per_class,
                rng=self.rng,
            )
            # Rebuild keypoints from updated coords
            keypoints = [(x, y) for x, y in local_6nm] + [(x, y) for x, y in local_12nm]
            class_labels = ["6nm"] * len(local_6nm) + ["12nm"] * len(local_12nm)

        # Apply augmentation (co-transforms keypoints)
        transformed = self.transform(
            image=patch,
            keypoints=keypoints,
            class_labels=class_labels,
        )
        patch_aug = transformed["image"]
        kp_aug = transformed["keypoints"]
        cl_aug = transformed["class_labels"]

        # Separate keypoints by class
        coords_6nm = np.array(
            [(x, y) for (x, y), c in zip(kp_aug, cl_aug) if c == "6nm"],
            dtype=np.float64,
        ).reshape(-1, 2)
        coords_12nm = np.array(
            [(x, y) for (x, y), c in zip(kp_aug, cl_aug) if c == "12nm"],
            dtype=np.float64,
        ).reshape(-1, 2)

        # Generate heatmap GT from TRANSFORMED coordinates (never warp heatmap)
        heatmap, offsets, offset_mask, conf_map = generate_heatmap_gt(
            coords_6nm, coords_12nm,
            self.patch_size, self.patch_size,
            sigmas=self.sigmas,
            stride=self.stride,
        )

        # Convert to tensors; normalize uint8 → [0, 1] float
        patch_tensor = torch.from_numpy(patch_aug).float().unsqueeze(0) / 255.0

        return {
            "image": patch_tensor,
            "heatmap": torch.from_numpy(heatmap),
            "offsets": torch.from_numpy(offsets),
            "offset_mask": torch.from_numpy(offset_mask),
            "conf_map": torch.from_numpy(conf_map),
        }
|
src/ensemble.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test-time augmentation (D4 dihedral group) and model ensemble averaging.
|
| 3 |
+
|
| 4 |
+
D4 TTA: 4 rotations x 2 reflections = 8 geometric views
|
| 5 |
+
+ 2 intensity variants = 10 total forward passes.
|
| 6 |
+
Gold beads are rotationally invariant — D4 TTA is maximally effective.
|
| 7 |
+
Expected F1 gain: +1-3% over single forward pass.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from typing import List, Optional
|
| 14 |
+
|
| 15 |
+
from src.model import ImmunogoldCenterNet
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def d4_tta_predict(
    model: "ImmunogoldCenterNet",
    image: np.ndarray,
    device: torch.device = torch.device("cpu"),
) -> tuple:
    """
    Test-time augmentation over the D4 dihedral group + intensity variants.

    Runs 8 geometric views (4 rotations x 2 reflections) plus 2 brightness
    variants (x0.9, x1.1) through the model, inverts each geometric
    transform on the outputs, and averages all 10 views.

    Args:
        model: trained CenterNet model
        image: (H, W) uint8 preprocessed image
        device: torch device

    Returns:
        averaged_heatmap: (2, H//2, W//2) numpy array
        averaged_offsets: (2, H//2, W//2) numpy array
    """
    model.eval()
    heatmaps = []
    offsets_list = []

    def _forward(img_np):
        """Run the model on one view; return (heatmap, offsets) cropped to the view's size."""
        # BUGFIX: pad amounts must be derived from the *view's* shape.
        # Rotated views of a non-square image swap H and W, so padding
        # (and output cropping) computed once from the un-rotated image
        # was wrong for k in {1, 3}.
        vh, vw = img_np.shape[:2]
        pad_h = (32 - vh % 32) % 32
        pad_w = (32 - vw % 32) % 32
        if pad_h > 0 or pad_w > 0:
            img_np = np.pad(img_np, ((0, pad_h), (0, pad_w)), mode="reflect")

        tensor = (
            torch.from_numpy(img_np)
            .float()
            .unsqueeze(0)
            .unsqueeze(0)  # (1, 1, H, W)
            / 255.0
        ).to(device)

        with torch.no_grad():
            hm, off = model(tensor)

        hm = hm.squeeze(0).cpu().numpy()   # (2, Hp/2, Wp/2)
        off = off.squeeze(0).cpu().numpy()  # (2, Hp/2, Wp/2)

        # Remove padding from the stride-2 outputs (view size, not original size)
        return hm[:, : vh // 2, : vw // 2], off[:, : vh // 2, : vw // 2]

    # D4 group: 4 rotations x 2 reflections = 8 geometric views
    for k in range(4):
        for flip in [False, True]:
            aug = np.rot90(image, k).copy()
            if flip:
                aug = np.fliplr(aug).copy()

            hm, off = _forward(aug)

            # Inverse transforms: undo the flip first (it was applied last),
            # then undo the rotation.
            if flip:
                hm = np.flip(hm, axis=2).copy()  # flip W axis
                off = np.flip(off, axis=2).copy()
                off[0] = -off[0]  # negate x offset for horizontal flip

            if k > 0:
                hm = np.rot90(hm, -k, axes=(1, 2)).copy()
                off = np.rot90(off, -k, axes=(1, 2)).copy()
                # Rotate the (dx, dy) offset vectors back as well
                if k == 1:  # 90° CCW undo
                    off = np.stack([-off[1], off[0]], axis=0)
                elif k == 2:  # 180°
                    off = np.stack([-off[0], -off[1]], axis=0)
                elif k == 3:  # 270° CCW undo
                    off = np.stack([off[1], -off[0]], axis=0)

            heatmaps.append(hm)
            offsets_list.append(off)

    # 2 intensity variants (identity geometry — no inverse needed)
    for factor in [0.9, 1.1]:
        aug = np.clip(image.astype(np.float32) * factor, 0, 255).astype(np.uint8)
        hm, off = _forward(aug)
        heatmaps.append(hm)
        offsets_list.append(off)

    # Average all 10 views
    avg_heatmap = np.mean(heatmaps, axis=0)
    avg_offsets = np.mean(offsets_list, axis=0)

    return avg_heatmap, avg_offsets
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def ensemble_predict(
    models: List[ImmunogoldCenterNet],
    image: np.ndarray,
    device: torch.device = torch.device("cpu"),
    use_tta: bool = True,
) -> tuple:
    """
    Average the heatmap/offset predictions of several trained models.

    Each model is moved to `device`, evaluated on `image`, and the
    per-model outputs are averaged element-wise. With `use_tta` enabled
    every model additionally receives full D4 test-time augmentation.

    Args:
        models: list of trained models (e.g., 5 seeds x 3 snapshots = 15)
        image: (H, W) uint8 preprocessed image
        device: torch device
        use_tta: whether to apply D4 TTA per model

    Returns:
        averaged_heatmap: (2, H/2, W/2) numpy array
        averaged_offsets: (2, H/2, W/2) numpy array
    """
    heatmap_stack = []
    offset_stack = []

    for net in models:
        net.eval()
        net.to(device)

        if use_tta:
            hm, off = d4_tta_predict(net, image, device)
        else:
            # Pad so both spatial dims are multiples of 32 (encoder requirement)
            h, w = image.shape[:2]
            pad_bottom = (32 - h % 32) % 32
            pad_right = (32 - w % 32) % 32
            padded = np.pad(image, ((0, pad_bottom), (0, pad_right)), mode="reflect")

            batch = (
                torch.from_numpy(padded).float().unsqueeze(0).unsqueeze(0) / 255.0
            ).to(device)

            with torch.no_grad():
                hm_pred, off_pred = net(batch)

            # Drop the padded region from the stride-2 outputs
            hm = hm_pred.squeeze(0).cpu().numpy()[:, : h // 2, : w // 2]
            off = off_pred.squeeze(0).cpu().numpy()[:, : h // 2, : w // 2]

        heatmap_stack.append(hm)
        offset_stack.append(off)

    return np.mean(heatmap_stack, axis=0), np.mean(offset_stack, axis=0)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def sliding_window_inference(
    model: "ImmunogoldCenterNet",
    image: np.ndarray,
    patch_size: int = 512,
    overlap: int = 128,
    device: torch.device = torch.device("cpu"),
) -> tuple:
    """
    Full-image inference via sliding window with overlap stitching.

    Tiles the image into overlapping patches, runs the model on each,
    and stitches heatmaps using max in overlap regions; offsets are
    averaged where tiles overlap.

    Args:
        model: trained model
        image: (H, W) uint8 preprocessed image
        patch_size: tile size
        overlap: overlap between adjacent tiles (pixels)
        device: torch device

    Returns:
        heatmap: (2, H//2, W//2) numpy array
        offsets: (2, H//2, W//2) numpy array
    """
    model.eval()
    h, w = image.shape[:2]
    stride_step = patch_size - overlap

    # Output dimensions at the model's stride (2)
    out_h = h // 2
    out_w = w // 2
    out_patch = patch_size // 2

    heatmap = np.zeros((2, out_h, out_w), dtype=np.float32)
    offsets = np.zeros((2, out_h, out_w), dtype=np.float32)
    count = np.zeros((out_h, out_w), dtype=np.float32)

    def _tile_starts(extent: int) -> list:
        """Window start positions covering [0, extent) with a flush final tile.

        BUGFIX: the original `range(0, extent - patch_size + 1, step)` left
        the right/bottom margin uncovered whenever (extent - patch_size) was
        not a multiple of the step, and produced NO tiles at all for images
        smaller than one patch.
        """
        if extent <= patch_size:
            return [0]
        starts = list(range(0, extent - patch_size + 1, stride_step))
        if starts[-1] != extent - patch_size:
            starts.append(extent - patch_size)  # flush against the far edge
        return starts

    for y0 in _tile_starts(h):
        for x0 in _tile_starts(w):
            patch = image[y0 : y0 + patch_size, x0 : x0 + patch_size]

            # Reflect-pad undersized edge patches up to the full tile size
            ph, pw = patch.shape[:2]
            if ph < patch_size or pw < patch_size:
                patch = np.pad(
                    patch,
                    ((0, patch_size - ph), (0, patch_size - pw)),
                    mode="reflect",
                )

            tensor = (
                torch.from_numpy(patch)
                .float()
                .unsqueeze(0)
                .unsqueeze(0)
                / 255.0
            ).to(device)

            with torch.no_grad():
                hm, off = model(tensor)

            hm_np = hm.squeeze(0).cpu().numpy()
            off_np = off.squeeze(0).cpu().numpy()

            # Destination window in output (stride-2) coordinates, clipped so
            # a flush tile starting at an odd pixel cannot overrun the arrays.
            oy0 = y0 // 2
            ox0 = x0 // 2
            oy1 = min(oy0 + out_patch, out_h)
            ox1 = min(ox0 + out_patch, out_w)
            hm_np = hm_np[:, : oy1 - oy0, : ox1 - ox0]
            off_np = off_np[:, : oy1 - oy0, : ox1 - ox0]

            # Max-stitch heatmap, average-stitch offsets
            heatmap[:, oy0:oy1, ox0:ox1] = np.maximum(
                heatmap[:, oy0:oy1, ox0:ox1], hm_np
            )
            offsets[:, oy0:oy1, ox0:ox1] += off_np
            count[oy0:oy1, ox0:ox1] += 1

    # Average offsets where counted (guard never-covered cells)
    count = np.maximum(count, 1)
    offsets /= count[np.newaxis, :, :]

    return heatmap, offsets
|
src/evaluate.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluation: Hungarian matching, per-class metrics, LOOCV runner.
|
| 3 |
+
|
| 4 |
+
Uses scipy linear_sum_assignment for optimal bipartite matching between
|
| 5 |
+
predictions and ground truth with class-specific match radii.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from scipy.optimize import linear_sum_assignment
|
| 10 |
+
from scipy.spatial.distance import cdist
|
| 11 |
+
from typing import Dict, List, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def compute_f1(tp: int, fp: int, fn: int, eps: float = 1e-6) -> Tuple[float, float, float]:
    """Return (f1, precision, recall) for the given TP/FP/FN counts.

    `eps` pads every denominator so zero-count inputs yield 0.0 rather
    than raising ZeroDivisionError.
    """
    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    return 2 * prec * rec / (prec + rec + eps), prec, rec
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def match_detections_to_gt(
    detections: List[dict],
    gt_coords_6nm: np.ndarray,
    gt_coords_12nm: np.ndarray,
    match_radii: Optional[Dict[str, float]] = None,
) -> Dict[str, dict]:
    """
    Optimal (Hungarian) matching of predictions against ground truth.

    A detection counts as a true positive only when:
      1. its Euclidean distance to a GT point is within the class radius, and
      2. the predicted class equals the GT class.

    Args:
        detections: list of {'x', 'y', 'class', 'conf'}
        gt_coords_6nm: (N, 2) array of (x, y) GT for 6nm
        gt_coords_12nm: (M, 2) array of (x, y) GT for 12nm
        match_radii: per-class match radius in pixels

    Returns:
        Dict with per-class and overall TP/FP/FN/F1/precision/recall,
        plus 'mean_f1' averaged over classes that have any GT.
    """
    radii = match_radii if match_radii is not None else {"6nm": 9.0, "12nm": 15.0}
    gt_lookup = {"6nm": gt_coords_6nm, "12nm": gt_coords_12nm}

    results: Dict[str, dict] = {}
    agg = {"tp": 0, "fp": 0, "fn": 0}

    for cls in ("6nm", "12nm"):
        preds = [d for d in detections if d["class"] == cls]
        gt = gt_lookup[cls]
        radius = radii[cls]

        if not preds and len(gt) == 0:
            # Nothing predicted, nothing expected: perfect by convention
            results[cls] = {"tp": 0, "fp": 0, "fn": 0,
                            "f1": 1.0, "precision": 1.0, "recall": 1.0}
            continue

        if not preds:
            # Every GT instance was missed
            results[cls] = {"tp": 0, "fp": 0, "fn": len(gt),
                            "f1": 0.0, "precision": 0.0, "recall": 0.0}
            agg["fn"] += len(gt)
            continue

        if len(gt) == 0:
            # Every prediction is spurious
            results[cls] = {"tp": 0, "fp": len(preds), "fn": 0,
                            "f1": 0.0, "precision": 0.0, "recall": 0.0}
            agg["fp"] += len(preds)
            continue

        # Pairwise distance matrix; out-of-radius pairs are forbidden
        # by assigning them a huge cost before the assignment step.
        pred_xy = np.array([[d["x"], d["y"]] for d in preds])
        dist = cdist(pred_xy, gt)
        assignment_cost = np.where(dist > radius, 1e6, dist)

        rows, cols = linear_sum_assignment(assignment_cost)
        tp = sum(1 for r, c in zip(rows, cols) if assignment_cost[r, c] <= radius)
        fp = len(preds) - tp
        fn = len(gt) - tp

        f1, prec, rec = compute_f1(tp, fp, fn)
        results[cls] = {"tp": tp, "fp": fp, "fn": fn,
                        "f1": f1, "precision": prec, "recall": rec}
        agg["tp"] += tp
        agg["fp"] += fp
        agg["fn"] += fn

    # Micro-averaged overall metrics
    f1_all, prec_all, rec_all = compute_f1(agg["tp"], agg["fp"], agg["fn"])
    results["overall"] = {"tp": agg["tp"], "fp": agg["fp"], "fn": agg["fn"],
                          "f1": f1_all, "precision": prec_all, "recall": rec_all}

    # Macro F1 over classes that actually have GT instances
    per_class = [results[c]["f1"] for c in ("6nm", "12nm")
                 if results[c]["fn"] + results[c]["tp"] > 0]
    results["mean_f1"] = np.mean(per_class) if per_class else 0.0

    return results
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def evaluate_fold(
    detections: List[dict],
    gt_annotations: Dict[str, np.ndarray],
    match_radii: Optional[Dict[str, float]] = None,
    has_6nm: bool = True,
) -> Dict[str, dict]:
    """
    Score one LOOCV fold's detections against its ground truth.

    Args:
        detections: model predictions ({'x', 'y', 'class', 'conf'} dicts)
        gt_annotations: {'6nm': Nx2, '12nm': Mx2} coordinate arrays
        match_radii: per-class match radii (matcher supplies defaults)
        has_6nm: whether this fold has 6nm GT (False for S7, S15)

    Returns:
        Metrics dict from `match_detections_to_gt`, with the 6nm entry
        annotated when that class has no annotations.
    """
    metrics = match_detections_to_gt(
        detections,
        gt_annotations.get("6nm", np.empty((0, 2))),
        gt_annotations.get("12nm", np.empty((0, 2))),
        match_radii,
    )

    if not has_6nm:
        metrics["6nm"]["note"] = "N/A (missing annotations)"

    return metrics
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def compute_average_precision(
    detections: List[dict],
    gt_coords: np.ndarray,
    match_radius: float,
) -> float:
    """
    Average Precision (PASCAL-VOC style) for a single class.

    Detections are ranked by confidence; each one greedily claims its
    nearest unclaimed GT point within `match_radius`. AP is the area
    under the resulting precision-recall curve (all-point summation).
    """
    if len(gt_coords) == 0:
        # No GT at all: a perfect score only if nothing was predicted
        return 0.0 if detections else 1.0

    ranked = sorted(detections, key=lambda d: d["conf"], reverse=True)

    hits = []         # 1 per TP, 0 per FP, in rank order
    claimed = set()   # GT indices already consumed by a higher-ranked det
    for det in ranked:
        delta = gt_coords - np.array([det["x"], det["y"]])
        d2 = np.sqrt(np.sum(delta ** 2, axis=1))
        nearest = np.argmin(d2)
        if d2[nearest] <= match_radius and nearest not in claimed:
            claimed.add(nearest)
            hits.append(1)
        else:
            hits.append(0)

    tp_cum = np.cumsum(hits)
    fp_cum = np.cumsum([1 - h for h in hits])
    precision = tp_cum / (tp_cum + fp_cum)
    recall = tp_cum / len(gt_coords)

    # Area under the PR curve: sum precision * delta-recall (recall starts at 0)
    ap = 0.0
    prev_recall = 0.0
    for p, r in zip(precision, recall):
        ap += p * (r - prev_recall)
        prev_recall = r

    return ap
|
src/heatmap.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Ground truth heatmap generation and peak extraction for CenterNet.
|
| 3 |
+
|
| 4 |
+
Generates Gaussian-splat heatmaps at stride-2 resolution with
|
| 5 |
+
class-specific sigma values calibrated to bead size.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from typing import Dict, List, Tuple, Optional
|
| 12 |
+
|
| 13 |
+
# Class index mapping
|
| 14 |
+
CLASS_IDX = {"6nm": 0, "12nm": 1}
CLASS_NAMES = ["6nm", "12nm"]
STRIDE = 2


def generate_heatmap_gt(
    coords_6nm: np.ndarray,
    coords_12nm: np.ndarray,
    image_h: int,
    image_w: int,
    sigmas: Optional[Dict[str, float]] = None,
    stride: int = STRIDE,
    confidence_6nm: Optional[np.ndarray] = None,
    confidence_12nm: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Build CenterNet ground truth: Gaussian heatmaps + sub-pixel offset maps.

    Each particle splats a Gaussian (truncated at ~3 sigma) centered on its
    INTEGER feature-space cell. The integer center must be exactly 1.0 —
    the CornerNet focal loss treats only gt == 1 pixels as positives, so a
    Gaussian centered on the float position (peak 0.78–0.93) would yield
    zero positive training signal. The fractional part of the position is
    stored in the offset maps instead.

    Args:
        coords_6nm: (N, 2) array of (x, y) in ORIGINAL pixel space
        coords_12nm: (M, 2) array of (x, y) in ORIGINAL pixel space
        image_h: original image height
        image_w: original image width
        sigmas: per-class Gaussian sigma in feature space
        stride: output stride (default 2)
        confidence_6nm: optional per-particle confidence weights
        confidence_12nm: optional per-particle confidence weights

    Returns:
        heatmap: (2, H//stride, W//stride) float32 in [0, 1]
        offsets: (2, H//stride, W//stride) float32 sub-pixel offsets
        offset_mask: (H//stride, W//stride) bool — True at particle centers
        conf_map: (2, H//stride, W//stride) float32 confidence weights
    """
    sigma_by_class = sigmas if sigmas is not None else {"6nm": 1.0, "12nm": 1.5}

    fh = image_h // stride
    fw = image_w // stride

    heatmap = np.zeros((2, fh, fw), dtype=np.float32)
    offsets = np.zeros((2, fh, fw), dtype=np.float32)
    offset_mask = np.zeros((fh, fw), dtype=bool)
    conf_map = np.ones((2, fh, fw), dtype=np.float32)

    # Flatten both classes into one (x, y, class, confidence) stream
    particles = []
    if len(coords_6nm) > 0:
        weights = confidence_6nm if confidence_6nm is not None else np.ones(len(coords_6nm))
        for i, (px, py) in enumerate(coords_6nm):
            particles.append((px, py, "6nm", weights[i]))
    if len(coords_12nm) > 0:
        weights = confidence_12nm if confidence_12nm is not None else np.ones(len(coords_12nm))
        for i, (px, py) in enumerate(coords_12nm):
            particles.append((px, py, "12nm", weights[i]))

    for px, py, cls, wgt in particles:
        ch = CLASS_IDX[cls]
        sigma = sigma_by_class[cls]

        # Float feature-space center and its nearest grid cell
        fx = px / stride
        fy = py / stride
        gx = int(round(fx))
        gy = int(round(fy))

        # Splat window, truncated at ~3 sigma and clipped to the map bounds
        radius = max(int(3 * sigma + 1), 2)
        y_lo = max(0, gy - radius)
        y_hi = min(fh, gy + radius + 1)
        x_lo = max(0, gx - radius)
        x_hi = min(fw, gx + radius + 1)
        if y_lo >= y_hi or x_lo >= x_hi:
            continue  # particle lies entirely off-map

        grid_y, grid_x = np.meshgrid(
            np.arange(y_lo, y_hi),
            np.arange(x_lo, x_hi),
            indexing="ij",
        )

        # Gaussian centered at the INTEGER cell (peak exactly == confidence),
        # scaled by the per-particle weight for pseudo-label support.
        splat = np.exp(
            -((grid_x - gx) ** 2 + (grid_y - gy) ** 2) / (2 * sigma ** 2)
        ) * wgt

        # Max-combine so overlapping particles keep their individual peaks
        heatmap[ch, y_lo:y_hi, x_lo:x_hi] = np.maximum(
            heatmap[ch, y_lo:y_hi, x_lo:x_hi], splat
        )

        # Sub-pixel offset + confidence recorded only at the center cell
        if 0 <= gy < fh and 0 <= gx < fw:
            offsets[0, gy, gx] = fx - gx
            offsets[1, gy, gx] = fy - gy
            offset_mask[gy, gx] = True
            conf_map[ch, gy, gx] = wgt

    return heatmap, offsets, offset_mask, conf_map
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def extract_peaks(
    heatmap: torch.Tensor,
    offset_map: torch.Tensor,
    stride: int = STRIDE,
    conf_threshold: float = 0.3,
    nms_kernel_sizes: Optional[Dict[str, int]] = None,
) -> List[dict]:
    """
    Decode a predicted heatmap into point detections via max-pool NMS.

    A cell is a peak when it equals its local max-pool response and exceeds
    `conf_threshold`; the sub-pixel position is recovered from `offset_map`
    and mapped back to input-pixel coordinates.

    Args:
        heatmap: (2, H/stride, W/stride) sigmoid-activated
        offset_map: (2, H/stride, W/stride) raw offset predictions
        stride: output stride
        conf_threshold: minimum confidence to keep
        nms_kernel_sizes: per-class NMS kernel sizes

    Returns:
        List of {'x': float, 'y': float, 'class': str, 'conf': float}
    """
    kernels = nms_kernel_sizes if nms_kernel_sizes is not None else {"6nm": 3, "12nm": 5}

    results: List[dict] = []

    for ch, name in enumerate(CLASS_NAMES):
        scores = heatmap[ch]
        k = kernels[name]

        # Max-pool NMS: a cell survives iff it equals its neighborhood max
        pooled = F.max_pool2d(
            scores.unsqueeze(0).unsqueeze(0),  # (1,1,H,W)
            kernel_size=k,
            stride=1,
            padding=k // 2,
        )
        is_peak = (pooled.squeeze() == scores) & (scores > conf_threshold)

        rows, cols = torch.where(is_peak)
        for row, col in zip(rows, cols):
            r = row.item()
            c = col.item()
            results.append({
                # Back to input space with the predicted sub-pixel offset
                "x": (c + offset_map[0, r, c].item()) * stride,
                "y": (r + offset_map[1, r, c].item()) * stride,
                "class": name,
                "conf": scores[r, c].item(),
            })

    return results
|
src/loss.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Loss functions for CenterNet immunogold detection.
|
| 3 |
+
|
| 4 |
+
Implements CornerNet penalty-reduced focal loss for sparse heatmaps
|
| 5 |
+
and smooth L1 offset regression loss.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def cornernet_focal_loss(
    pred: torch.Tensor,
    gt: torch.Tensor,
    alpha: int = 2,
    beta: int = 4,
    conf_weights: torch.Tensor = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Penalty-reduced focal loss (CornerNet) for sparse Gaussian heatmaps.

    With roughly one positive per 23,000 pixels, plain BCE collapses to
    all-zeros. Positives are the exact gt == 1 peaks; negatives near a
    peak are down-weighted by (1 - gt)^beta, and both terms are
    focal-modulated by the prediction confidence raised to `alpha`.

    Args:
        pred: (B, C, H, W) sigmoid-activated predictions in [0, 1]
        gt: (B, C, H, W) Gaussian heatmap targets in [0, 1]
        alpha: focal exponent on prediction confidence (default 2)
        beta: penalty-reduction exponent near GT peaks (default 4)
        conf_weights: optional (B, C, H, W) per-pixel confidence weights
            for pseudo-label support
        eps: numerical floor inside the logarithms

    Returns:
        Scalar loss, normalized by the number of positive locations.
    """
    positives = (gt == 1).float()
    negatives = (gt < 1).float()

    # Penalty reduction: (1 - gt)^beta → 0 near peaks, → 1 far away
    damping = (1 - gt).pow(beta)

    # Reward confident predictions at GT peaks
    term_pos = torch.log(pred.clamp(min=eps)) * (1 - pred).pow(alpha) * positives

    # Penalize confident predictions away from peaks
    term_neg = (
        torch.log((1 - pred).clamp(min=eps))
        * pred.pow(alpha)
        * damping
        * negatives
    )

    if conf_weights is not None:
        # Low-confidence pseudo-labels contribute less to both terms
        term_pos = term_pos * conf_weights
        term_neg = term_neg * conf_weights

    n_pos = positives.sum().clamp(min=1)
    return -(term_pos.sum() + term_neg.sum()) / n_pos
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def offset_loss(
    pred_offsets: torch.Tensor,
    gt_offsets: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    """
    Smooth-L1 regression on sub-pixel offsets, evaluated only at the
    integer particle-center cells selected by `mask`.

    Args:
        pred_offsets: (B, 2, H, W) predicted offsets
        gt_offsets: (B, 2, H, W) ground-truth offsets
        mask: (B, H, W) boolean — True at particle integer centers

    Returns:
        Scalar loss (zero with grad attached when no centers are present).
    """
    # Broadcast the (B, H, W) mask over both offset channels
    sel = mask.unsqueeze(1).expand_as(pred_offsets)

    if sel.sum() == 0:
        # No annotated centers in this batch — nothing to regress
        return torch.tensor(0.0, device=pred_offsets.device, requires_grad=True)

    return F.smooth_l1_loss(pred_offsets[sel], gt_offsets[sel], reduction="mean")
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def total_loss(
    heatmap_pred: torch.Tensor,
    heatmap_gt: torch.Tensor,
    offset_pred: torch.Tensor,
    offset_gt: torch.Tensor,
    offset_mask: torch.Tensor,
    lambda_offset: float = 1.0,
    focal_alpha: int = 2,
    focal_beta: int = 4,
    conf_weights: torch.Tensor = None,
) -> tuple:
    """Combined detection loss: heatmap focal loss plus weighted offset loss.

    Args:
        heatmap_pred: (B, 2, H, W) sigmoid heatmap predictions.
        heatmap_gt: (B, 2, H, W) Gaussian ground-truth heatmaps.
        offset_pred: (B, 2, H, W) predicted sub-pixel offsets.
        offset_gt: (B, 2, H, W) ground-truth offsets.
        offset_mask: (B, H, W) boolean mask of annotated centers.
        lambda_offset: multiplier for the offset term (default 1.0).
        focal_alpha: focal-loss alpha exponent.
        focal_beta: focal-loss beta exponent.
        conf_weights: optional per-pixel confidence weights (pseudo-labels).

    Returns:
        (total_loss_tensor, heatmap_loss_value, offset_loss_value) where the
        last two are Python floats for logging.
    """
    hm_term = cornernet_focal_loss(
        heatmap_pred,
        heatmap_gt,
        alpha=focal_alpha,
        beta=focal_beta,
        conf_weights=conf_weights,
    )
    off_term = offset_loss(offset_pred, offset_gt, offset_mask)

    combined = hm_term + lambda_offset * off_term
    return combined, hm_term.item(), off_term.item()
|
src/model.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CenterNet with CEM500K-pretrained ResNet-50 backbone for immunogold detection.
|
| 3 |
+
|
| 4 |
+
Architecture:
|
| 5 |
+
Input: 1ch grayscale, variable size (padded to multiple of 32)
|
| 6 |
+
Encoder: CEM500K ResNet-50 (pretrained), conv1 adapted for 1ch input
|
| 7 |
+
Neck: BiFPN (2 rounds, 128ch)
|
| 8 |
+
Decoder: Transposed conv → stride-2 output
|
| 9 |
+
Heads: Heatmap (2ch sigmoid), Offset (2ch)
|
| 10 |
+
Output: Stride-2 maps → (H/2, W/2) resolution
|
| 11 |
+
|
| 12 |
+
Output stride is 2, NOT 4 or 8. At stride 4, a 6nm bead (4-6px radius)
|
| 13 |
+
collapses to 1px in feature space — insufficient for detection.
|
| 14 |
+
At stride 2, same bead occupies 2-3px, enough for Gaussian peak extraction.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
import torchvision.models as models
|
| 22 |
+
from typing import List, Optional
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
# BiFPN: Bidirectional Feature Pyramid Network
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
|
| 29 |
+
class DepthwiseSeparableConv(nn.Module):
    """Depthwise 3x3 plus pointwise 1x1 convolution with BN + ReLU.

    The standard BiFPN building block: a per-channel spatial convolution
    followed by a 1x1 channel-mixing convolution — cheaper than one dense
    kxk convolution at the same receptive field.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3,
                 stride: int = 1, padding: int = 1):
        super().__init__()
        # groups=in_ch → one independent spatial filter per input channel.
        self.depthwise = nn.Conv2d(
            in_ch, in_ch, kernel_size, stride=stride,
            padding=padding, groups=in_ch, bias=False,
        )
        # 1x1 conv mixes channels; conv bias omitted, folded into BatchNorm.
        self.pointwise = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        spatial = self.depthwise(x)
        mixed = self.pointwise(spatial)
        return self.act(self.bn(mixed))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class BiFPNFusionNode(nn.Module):
    """Fuses several same-shape feature maps with fast normalized weights.

    Weights are kept non-negative via ReLU and normalized to (nearly) sum
    to one:

        w_i = relu(w_i) / (sum_j relu(w_j) + eps)

    The weighted sum is then refined by one depthwise-separable conv.
    """

    def __init__(self, channels: int, n_inputs: int = 2, eps: float = 1e-4):
        super().__init__()
        self.eps = eps
        # One learnable scalar per input, initialized to equal weighting.
        self.weights = nn.Parameter(torch.ones(n_inputs, dtype=torch.float32))
        self.conv = DepthwiseSeparableConv(channels, channels)

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        # Fast normalized fusion (no softmax — cheaper, per the BiFPN paper).
        raw = F.relu(self.weights)
        scale = raw / (raw.sum() + self.eps)

        blended = sum(s * feat for s, feat in zip(scale, inputs))
        return self.conv(blended)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class BiFPNLayer(nn.Module):
    """One BiFPN round: a top-down pass followed by a bottom-up pass.

    Operates on four pyramid levels: P2 (stride 4), P3 (stride 8),
    P4 (stride 16), P5 (stride 32), all already projected to the same
    channel count.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Top-down nodes: each fuses a level with the upsampled level above.
        self.td_p4 = BiFPNFusionNode(channels, n_inputs=2)
        self.td_p3 = BiFPNFusionNode(channels, n_inputs=2)
        self.td_p2 = BiFPNFusionNode(channels, n_inputs=2)

        # Bottom-up nodes: fuse original input, top-down result, and the
        # downsampled output of the level below.
        self.bu_p3 = BiFPNFusionNode(channels, n_inputs=3)  # p3 + p3_td + down(p2_out)
        self.bu_p4 = BiFPNFusionNode(channels, n_inputs=3)  # p4 + p4_td + down(p3_out)
        self.bu_p5 = BiFPNFusionNode(channels, n_inputs=2)  # p5 + down(p4_out)

    def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        Args:
            features: [P2, P3, P4, P5], same channel count, decreasing size.

        Returns:
            [P2_out, P3_out, P4_out, P5_out]
        """
        p2, p3, p4, p5 = features

        # --- Top-down pathway (nearest-neighbor upsampling) ---
        td4 = self.td_p4([p4, F.interpolate(p5, size=p4.shape[2:], mode="nearest")])
        td3 = self.td_p3([p3, F.interpolate(td4, size=p3.shape[2:], mode="nearest")])
        td2 = self.td_p2([p2, F.interpolate(td3, size=p2.shape[2:], mode="nearest")])

        # --- Bottom-up pathway (2x max-pool downsampling) ---
        out2 = td2
        out3 = self.bu_p3([p3, td3, F.max_pool2d(out2, kernel_size=2)])
        out4 = self.bu_p4([p4, td4, F.max_pool2d(out3, kernel_size=2)])
        out5 = self.bu_p5([p5, F.max_pool2d(out4, kernel_size=2)])

        return [out2, out3, out4, out5]
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class BiFPN(nn.Module):
    """Lateral 1x1 projections followed by several BiFPN fusion rounds."""

    def __init__(self, in_channels: List[int], out_channels: int = 128,
                 num_rounds: int = 2):
        super().__init__()
        # Project each backbone level to a common channel count.
        projections = []
        for in_ch in in_channels:
            projections.append(
                nn.Sequential(
                    nn.Conv2d(in_ch, out_channels, 1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                )
            )
        self.laterals = nn.ModuleList(projections)

        # Stacked bidirectional fusion rounds.
        self.rounds = nn.ModuleList(
            BiFPNLayer(out_channels) for _ in range(num_rounds)
        )

    def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
        feats = [proj(feat) for proj, feat in zip(self.laterals, features)]
        for fusion_round in self.rounds:
            feats = fusion_round(feats)
        return feats
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# ---------------------------------------------------------------------------
|
| 164 |
+
# Detection Heads
|
| 165 |
+
# ---------------------------------------------------------------------------
|
| 166 |
+
|
| 167 |
+
class HeatmapHead(nn.Module):
    """Per-class center heatmap head (sigmoid output) at stride-2 resolution."""

    def __init__(self, in_channels: int = 64, num_classes: int = 2):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, num_classes, kernel_size=1)

        # Focal-loss prior initialization: bias = -log((1 - pi) / pi) with
        # pi = 0.01, so the head starts out predicting ~1% foreground and
        # the network avoids a flood of false positives early in training.
        nn.init.constant_(self.conv2.bias, -math.log((1 - 0.01) / 0.01))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.relu(self.bn1(self.conv1(x)))
        return torch.sigmoid(self.conv2(hidden))
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class OffsetHead(nn.Module):
    """Regresses per-pixel (dx, dy) sub-pixel offsets (no activation)."""

    def __init__(self, in_channels: int = 64):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # Two output channels: dx and dy.
        self.conv2 = nn.Conv2d(64, 2, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.relu(self.bn1(self.conv1(x)))
        return self.conv2(hidden)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
# ---------------------------------------------------------------------------
|
| 202 |
+
# Full CenterNet Model
|
| 203 |
+
# ---------------------------------------------------------------------------
|
| 204 |
+
|
| 205 |
+
class ImmunogoldCenterNet(nn.Module):
    """
    CenterNet with CEM500K-pretrained ResNet-50 backbone.

    Detects 6nm and 12nm immunogold particles at stride-2 resolution.

    Pipeline: ResNet-50 encoder (1-channel conv1) → 4-level BiFPN neck →
    one transposed-conv upsample (stride 4 → stride 2) → heatmap + offset
    heads. See the module docstring for why the output stride is 2.
    """

    def __init__(
        self,
        pretrained_path: Optional[str] = None,
        bifpn_channels: int = 128,
        bifpn_rounds: int = 2,
        num_classes: int = 2,
    ):
        # pretrained_path: CEM500K MoCo checkpoint; when None, falls back
        # to ImageNet weights with conv1 mean-pooled to 1 channel.
        super().__init__()
        self.num_classes = num_classes

        # --- Encoder: ResNet-50 ---
        # Built without weights; weights are loaded below so conv1 can be
        # replaced first.
        backbone = models.resnet50(weights=None)
        # Adapt conv1 for 1-channel grayscale input
        backbone.conv1 = nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False,
        )

        # Load pretrained weights
        if pretrained_path:
            self._load_pretrained(backbone, pretrained_path)
        else:
            # Use ImageNet weights as fallback, adapting conv1
            # NOTE: downloads the torchvision checkpoint on first use.
            imagenet_backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
            state = imagenet_backbone.state_dict()
            # Mean-pool RGB conv1 weights → grayscale
            state["conv1.weight"] = state["conv1.weight"].mean(dim=1, keepdim=True)
            backbone.load_state_dict(state, strict=False)

        # Extract encoder stages
        self.stem = nn.Sequential(
            backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
        )
        self.layer1 = backbone.layer1  # 256ch, stride 4
        self.layer2 = backbone.layer2  # 512ch, stride 8
        self.layer3 = backbone.layer3  # 1024ch, stride 16
        self.layer4 = backbone.layer4  # 2048ch, stride 32

        # --- BiFPN Neck ---
        self.bifpn = BiFPN(
            in_channels=[256, 512, 1024, 2048],
            out_channels=bifpn_channels,
            num_rounds=bifpn_rounds,
        )

        # --- Decoder: upsample P2 (stride 4) → stride 2 ---
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(
                bifpn_channels, 64, kernel_size=4, stride=2, padding=1, bias=False,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        # --- Detection Heads (at stride-2 resolution) ---
        self.heatmap_head = HeatmapHead(64, num_classes)
        self.offset_head = OffsetHead(64)

    def _load_pretrained(self, backbone: nn.Module, path: str):
        """Load CEM500K MoCoV2 pretrained weights."""
        # weights_only=False: checkpoint may contain non-tensor objects —
        # only load checkpoints from a trusted source.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)

        state = {}
        # CEM500K uses MoCo format: keys prefixed with 'module.encoder_q.'
        src_state = ckpt.get("state_dict", ckpt)
        for k, v in src_state.items():
            # Strip MoCo prefix
            new_key = k
            for prefix in ["module.encoder_q.", "module.", "encoder_q."]:
                if new_key.startswith(prefix):
                    new_key = new_key[len(prefix):]
                    break
            state[new_key] = v

        # Adapt conv1: mean-pool 3ch RGB → 1ch grayscale
        if "conv1.weight" in state and state["conv1.weight"].shape[1] == 3:
            state["conv1.weight"] = state["conv1.weight"].mean(dim=1, keepdim=True)

        # Load with strict=False (head layers won't match)
        missing, unexpected = backbone.load_state_dict(state, strict=False)
        # Expected: fc.weight, fc.bias will be missing/unexpected
        print(f"CEM500K loaded: {len(state)} keys, "
              f"{len(missing)} missing, {len(unexpected)} unexpected")

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Args:
            x: (B, 1, H, W) grayscale input
               (H and W are assumed padded to a multiple of 32 by the
               caller — TODO confirm against the data pipeline)

        Returns:
            heatmap: (B, 2, H/2, W/2) sigmoid-activated class heatmaps
            offsets: (B, 2, H/2, W/2) sub-pixel offset predictions
        """
        # Encoder
        x0 = self.stem(x)  # stride 4
        p2 = self.layer1(x0)  # 256ch, stride 4
        p3 = self.layer2(p2)  # 512ch, stride 8
        p4 = self.layer3(p3)  # 1024ch, stride 16
        p5 = self.layer4(p4)  # 2048ch, stride 32

        # BiFPN neck
        features = self.bifpn([p2, p3, p4, p5])

        # Decoder: upsample P2 to stride 2 (features[0] is the P2 level)
        x_up = self.upsample(features[0])

        # Detection heads
        heatmap = self.heatmap_head(x_up)  # (B, 2, H/2, W/2)
        offsets = self.offset_head(x_up)  # (B, 2, H/2, W/2)

        return heatmap, offsets

    def freeze_encoder(self):
        """Freeze entire encoder (Phase 1 training)."""
        for module in [self.stem, self.layer1, self.layer2, self.layer3, self.layer4]:
            for param in module.parameters():
                param.requires_grad = False

    def unfreeze_deep_layers(self):
        """Unfreeze layer3 and layer4 (Phase 2 training)."""
        for module in [self.layer3, self.layer4]:
            for param in module.parameters():
                param.requires_grad = True

    def unfreeze_all(self):
        """Unfreeze all layers (Phase 3 training)."""
        for param in self.parameters():
            param.requires_grad = True

    def get_param_groups(self, phase: int, cfg: dict) -> list:
        """
        Get parameter groups with discriminative learning rates per phase.

        Args:
            phase: 1, 2, or 3
            cfg: training phase config from config.yaml
                 (phase 1 reads "lr"; phase 2 reads "lr_layer3",
                 "lr_layer4", "lr_decoder"; phase 3 additionally reads
                 "lr_stem", "lr_layer1", "lr_layer2")

        Returns:
            List of param group dicts for optimizer.
        """
        if phase == 1:
            # Only neck + heads trainable
            return [
                {"params": self.bifpn.parameters(), "lr": cfg["lr"]},
                {"params": self.upsample.parameters(), "lr": cfg["lr"]},
                {"params": self.heatmap_head.parameters(), "lr": cfg["lr"]},
                {"params": self.offset_head.parameters(), "lr": cfg["lr"]},
            ]
        elif phase == 2:
            # Shallow encoder stays at lr=0; deep encoder layers get small
            # discriminative rates, decoder/head groups a larger one.
            return [
                {"params": self.stem.parameters(), "lr": 0},
                {"params": self.layer1.parameters(), "lr": 0},
                {"params": self.layer2.parameters(), "lr": 0},
                {"params": self.layer3.parameters(), "lr": cfg["lr_layer3"]},
                {"params": self.layer4.parameters(), "lr": cfg["lr_layer4"]},
                {"params": self.bifpn.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.upsample.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.heatmap_head.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.offset_head.parameters(), "lr": cfg["lr_decoder"]},
            ]
        else:  # phase 3
            # Full fine-tuning with per-stage learning rates.
            return [
                {"params": self.stem.parameters(), "lr": cfg["lr_stem"]},
                {"params": self.layer1.parameters(), "lr": cfg["lr_layer1"]},
                {"params": self.layer2.parameters(), "lr": cfg["lr_layer2"]},
                {"params": self.layer3.parameters(), "lr": cfg["lr_layer3"]},
                {"params": self.layer4.parameters(), "lr": cfg["lr_layer4"]},
                {"params": self.bifpn.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.upsample.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.heatmap_head.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.offset_head.parameters(), "lr": cfg["lr_decoder"]},
            ]
|
src/postprocess.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Post-processing: structural mask filtering, cross-class NMS, threshold sweep.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from scipy.spatial.distance import cdist
|
| 7 |
+
from skimage.morphology import dilation, disk
|
| 8 |
+
from typing import Dict, List, Optional
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def apply_structural_mask_filter(
    detections: List[dict],
    mask: np.ndarray,
    margin_px: int = 5,
) -> List[dict]:
    """Drop detections that fall outside the biological tissue mask.

    Args:
        detections: list of {'x', 'y', 'class', 'conf'} dicts.
        mask: boolean (H, W) array, True inside tissue regions.
        margin_px: dilate the mask by this radius so particles sitting on
            a region boundary are not discarded.

    Returns:
        The subset of ``detections`` whose rounded centers land inside the
        (dilated) tissue mask. When ``mask`` is None, returns the input
        unchanged.
    """
    if mask is None:
        return detections

    # Grow the tissue region slightly to tolerate boundary particles.
    region = dilation(mask, disk(margin_px))

    kept = []
    for det in detections:
        col = int(round(det["x"]))
        row = int(round(det["y"]))
        in_bounds = 0 <= row < region.shape[0] and 0 <= col < region.shape[1]
        if in_bounds and region[row, col]:
            kept.append(det)
    return kept
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def cross_class_nms(
    detections: List[dict],
    distance_threshold: float = 8.0,
) -> List[dict]:
    """Suppress cross-class duplicates: keep the higher-confidence one.

    When a 6nm and a 12nm detection land within ``distance_threshold``
    pixels of each other (both heads fired on the same particle), the
    lower-confidence detection is removed. Same-class pairs are never
    suppressed here.
    """
    if len(detections) <= 1:
        return detections

    # Highest confidence first, so earlier entries suppress later ones.
    ordered = sorted(detections, key=lambda d: d["conf"], reverse=True)
    alive = [True] * len(ordered)

    pts = np.array([[d["x"], d["y"]] for d in ordered])

    for a in range(len(ordered)):
        if not alive[a]:
            continue
        for b in range(a + 1, len(ordered)):
            if not alive[b]:
                continue
            # Same-class overlaps are left alone.
            if ordered[a]["class"] == ordered[b]["class"]:
                continue
            gap = np.sqrt(
                (pts[a, 0] - pts[b, 0]) ** 2
                + (pts[a, 1] - pts[b, 1]) ** 2
            )
            if gap < distance_threshold:
                alive[b] = False  # lower-confidence detection loses

    return [d for d, ok in zip(ordered, alive) if ok]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def sweep_confidence_threshold(
    detections: List[dict],
    gt_coords: Dict[str, np.ndarray],
    match_radii: Dict[str, float],
    start: float = 0.05,
    stop: float = 0.95,
    step: float = 0.01,
) -> Dict[str, float]:
    """Find the per-class confidence threshold that maximizes F1.

    Args:
        detections: all detections (before thresholding).
        gt_coords: {'6nm': Nx2, '12nm': Mx2} ground-truth coordinates.
        match_radii: per-class match radii in pixels.
        start, stop, step: sweep range (stop is exclusive, as np.arange).

    Returns:
        {'6nm': thr, '12nm': thr} — best threshold per class (0.3 fallback
        when no threshold in the range produced an F1 score).
    """
    from src.evaluate import match_detections_to_gt, compute_f1

    candidate_thrs = np.arange(start, stop, step)
    best = {}

    for cls in ["6nm", "12nm"]:
        top_f1 = -1
        top_thr = 0.3  # sensible default if the sweep never scores
        gt = gt_coords[cls]

        for thr in candidate_thrs:
            selected = [
                d for d in detections if d["class"] == cls and d["conf"] >= thr
            ]
            # Nothing predicted and nothing annotated → F1 undefined, skip.
            if not selected and len(gt) == 0:
                continue

            pts = np.array([[d["x"], d["y"]] for d in selected]).reshape(-1, 2)

            if len(pts) == 0:
                tp, fp, fn = 0, 0, len(gt)
            elif len(gt) == 0:
                tp, fp, fn = 0, len(pts), 0
            else:
                tp, fp, fn = _simple_match(pts, gt, match_radii[cls])

            f1, _, _ = compute_f1(tp, fp, fn)
            if f1 > top_f1:
                top_f1 = f1
                top_thr = thr

        best[cls] = top_thr

    return best
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def _simple_match(
|
| 135 |
+
pred: np.ndarray, gt: np.ndarray, radius: float
|
| 136 |
+
) -> tuple:
|
| 137 |
+
"""Quick matching for threshold sweep (greedy, not Hungarian)."""
|
| 138 |
+
from scipy.spatial.distance import cdist
|
| 139 |
+
|
| 140 |
+
if len(pred) == 0 or len(gt) == 0:
|
| 141 |
+
return 0, len(pred), len(gt)
|
| 142 |
+
|
| 143 |
+
dists = cdist(pred, gt)
|
| 144 |
+
tp = 0
|
| 145 |
+
matched_gt = set()
|
| 146 |
+
|
| 147 |
+
# Greedy: match closest pairs first
|
| 148 |
+
for i in range(len(pred)):
|
| 149 |
+
min_j = np.argmin(dists[i])
|
| 150 |
+
if dists[i, min_j] <= radius and min_j not in matched_gt:
|
| 151 |
+
tp += 1
|
| 152 |
+
matched_gt.add(min_j)
|
| 153 |
+
dists[:, min_j] = np.inf
|
| 154 |
+
|
| 155 |
+
fp = len(pred) - tp
|
| 156 |
+
fn = len(gt) - tp
|
| 157 |
+
return tp, fp, fn
|
src/preprocessing.py
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data loading, annotation parsing, and preprocessing for immunogold TEM images.
|
| 3 |
+
|
| 4 |
+
The model receives raw images — the CEM500K backbone was pretrained on raw EM.
|
| 5 |
+
Top-hat preprocessing is only used by LodeStar (Stage 1).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Dict, List, Optional, Tuple
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pandas as pd
|
| 14 |
+
import tifffile
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ---------------------------------------------------------------------------
|
| 18 |
+
# Data registry: robust discovery of images, masks, and annotations
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
|
| 21 |
+
@dataclass
class SynapseRecord:
    """Metadata for one synapse sample."""
    # Synapse folder name under "analyzed synapses" (e.g. "S22").
    synapse_id: str
    # Main EM TIF image for this synapse.
    image_path: Path
    # Tissue mask TIF, or None when no mask file exists for this synapse.
    mask_path: Optional[Path]
    # All discovered CSV annotation files per particle size (may be empty).
    csv_6nm_paths: List[Path] = field(default_factory=list)
    csv_12nm_paths: List[Path] = field(default_factory=list)
    # Convenience flags: True when the matching CSV list is non-empty.
    has_6nm: bool = False
    has_12nm: bool = False
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def discover_synapse_data(root: str, synapse_ids: List[str]) -> List[SynapseRecord]:
    """
    Discover all TIF images, masks, and CSV annotations for each synapse.

    Handles naming inconsistencies:
    - S22: main image is S22_0003.tif, two Results folders
    - S25: 12nm CSV has no space ("Results12nm")
    - CSV patterns: "Results 6nm XY" vs "Results XY in microns 6nm"

    Raises:
        FileNotFoundError: when a synapse folder or its main image is missing.
    """
    root = Path(root)
    analyzed = root / "analyzed synapses"
    records = []

    for sid in synapse_ids:
        folder = analyzed / sid
        if not folder.exists():
            raise FileNotFoundError(f"Synapse folder not found: {folder}")

        # --- Main image: any TIF that is neither a mask nor a color overlay ---
        all_tifs = list(folder.glob("*.tif"))
        candidates = [
            t for t in all_tifs
            if "mask" not in t.stem.lower() and "color" not in t.stem.lower()
        ]
        if not candidates:
            raise FileNotFoundError(f"No main image found in {folder}")
        # If several candidates remain, the largest file is the main EM image.
        image_path = max(candidates, key=lambda t: t.stat().st_size)

        # --- Mask (optional) ---
        mask_tifs = [t for t in all_tifs if "mask" in t.stem.lower()]
        mask_path = None
        if mask_tifs:
            # Prefer a plain "mask.tif" over numbered variants like "mask 1.tif".
            exact = [t for t in mask_tifs if t.stem.lower().endswith("mask")]
            mask_path = exact[0] if exact else mask_tifs[0]

        # --- CSV annotations from every Results* subdirectory ---
        csv_6nm_paths: List[Path] = []
        csv_12nm_paths: List[Path] = []
        for rdir in sorted(folder.glob("Results*")):
            if not rdir.is_dir():
                continue
            for csv_file in rdir.glob("*.csv"):
                lowered = csv_file.name.lower()
                if "6nm" in lowered:
                    csv_6nm_paths.append(csv_file)
                elif "12nm" in lowered:
                    csv_12nm_paths.append(csv_file)

        records.append(SynapseRecord(
            synapse_id=sid,
            image_path=image_path,
            mask_path=mask_path,
            csv_6nm_paths=csv_6nm_paths,
            csv_12nm_paths=csv_12nm_paths,
            has_6nm=len(csv_6nm_paths) > 0,
            has_12nm=len(csv_12nm_paths) > 0,
        ))

    return records
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# ---------------------------------------------------------------------------
|
| 100 |
+
# Image I/O
|
| 101 |
+
# ---------------------------------------------------------------------------
|
| 102 |
+
|
| 103 |
+
def load_image(path: Path) -> np.ndarray:
    """
    Read a TIF file and return it as a 2-D grayscale uint8 array.

    Multi-channel inputs are reduced to one channel: for an (H, W, C)
    image with C <= 4 the first channel is kept (channels are identical
    in this dataset); otherwise the array is treated as channel-first
    and the first plane is taken.
    """
    arr = tifffile.imread(str(path))
    if arr.ndim == 3:
        # Channel-last RGB(A) vs. channel-first stack — keep one plane.
        if arr.shape[2] <= 4:
            arr = arr[:, :, 0]
        else:
            arr = arr[0]
    return arr.astype(np.uint8)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def load_mask(path: Path) -> np.ndarray:
    """
    Read a mask TIF and return a boolean tissue mask.

    Tissue/structural pixels are those with any channel value below 250
    (anything that is not pure white). Handles both grayscale and RGB
    mask files; returns True where tissue is present.
    """
    raw = tifffile.imread(str(path))
    if raw.ndim == 2:
        # Grayscale mask: non-white pixels are tissue.
        return raw < 250
    # RGB mask: tissue wherever at least one channel is not white.
    return np.any(raw < 250, axis=-1)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# ---------------------------------------------------------------------------
|
| 134 |
+
# Annotation loading and coordinate conversion
|
| 135 |
+
# ---------------------------------------------------------------------------
|
| 136 |
+
|
| 137 |
+
def load_annotations_csv(csv_path: Path) -> pd.DataFrame:
    """
    Load an annotation CSV and return its X/Y coordinate columns.

    The exported CSVs have columns [index, X, Y] where the header row
    carries a leading space (" ,X,Y"), so column names are stripped
    before selection. Per the MICRONS_TO_PIXELS note in this module,
    the X/Y values are in microns.

    Args:
        csv_path: Path to the annotation CSV.

    Returns:
        DataFrame with exactly the columns ["X", "Y"].

    Raises:
        ValueError: if the CSV lacks an X or Y column (clearer than the
            bare KeyError the selection would otherwise raise).
    """
    df = pd.read_csv(csv_path)
    # Header names carry stray whitespace (e.g. " ,X,Y") — normalize.
    df.columns = [c.strip() for c in df.columns]
    missing = {"X", "Y"} - set(df.columns)
    if missing:
        raise ValueError(f"{csv_path} is missing columns: {sorted(missing)}")
    # The unnamed index column is intentionally dropped here.
    return df[["X", "Y"]]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# Micron-to-pixel scale factor: consistent across all synapses (verified
# against researcher's color overlay TIFs). The CSV columns labeled "XY in
# microns" really ARE microns — multiply by this constant to get pixels.
# NOTE(review): implies ~0.56 nm/px imaging resolution — confirm against
# acquisition metadata if images from another microscope are added.
MICRONS_TO_PIXELS = 1790.0
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def load_all_annotations(
    record: "SynapseRecord", image_shape: Tuple[int, int]
) -> Dict[str, np.ndarray]:
    """
    Load and convert annotations for one synapse to pixel coordinates.

    CSV coordinates are in microns; they are multiplied by
    MICRONS_TO_PIXELS to convert to pixels. Coordinates from multiple
    CSVs of the same class are concatenated and near-duplicates
    (within 3 px, e.g. from S22 merged results) are removed.

    Args:
        record: SynapseRecord with CSV paths.
        image_shape: (height, width) of the corresponding image.

    Returns:
        Dictionary with keys '6nm' and '12nm', each an Nx2 array of
        (x, y) pixel coordinates ((0, 2) array when no annotations).

    Raises:
        ValueError: if converted coordinates fall outside the image
            bounds (plus a 10 px tolerance) — indicates a wrong scale
            factor or a mismatched image/CSV pair.
    """
    h, w = image_shape[:2]
    result = {"6nm": np.empty((0, 2), dtype=np.float64),
              "12nm": np.empty((0, 2), dtype=np.float64)}

    for cls, paths in [("6nm", record.csv_6nm_paths),
                       ("12nm", record.csv_12nm_paths)]:
        all_coords = []
        for csv_path in paths:
            df = load_annotations_csv(csv_path)
            if df.empty:
                # Empty CSV: nothing to convert (and .max() would fail).
                continue
            # Convert microns to pixels.
            px_x = df["X"].values * MICRONS_TO_PIXELS
            px_y = df["Y"].values * MICRONS_TO_PIXELS
            # Validate with explicit exceptions rather than `assert`,
            # which is silently stripped under `python -O`.
            if px_x.max() >= w + 10:
                raise ValueError(
                    f"X coords out of bounds ({px_x.max():.0f} > {w}) in {csv_path}")
            if px_y.max() >= h + 10:
                raise ValueError(
                    f"Y coords out of bounds ({px_y.max():.0f} > {h}) in {csv_path}")
            all_coords.append(np.stack([px_x, px_y], axis=1))

        if all_coords:
            coords = np.concatenate(all_coords, axis=0)
            # Deduplicate (for S22 merged results): remove within 3px
            if len(coords) > 1:
                coords = _deduplicate_coords(coords, min_dist=3.0)
            result[cls] = coords

    return result
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def _deduplicate_coords(
|
| 206 |
+
coords: np.ndarray, min_dist: float = 3.0
|
| 207 |
+
) -> np.ndarray:
|
| 208 |
+
"""Remove duplicate coordinates within min_dist pixels."""
|
| 209 |
+
from scipy.spatial.distance import cdist
|
| 210 |
+
|
| 211 |
+
if len(coords) <= 1:
|
| 212 |
+
return coords
|
| 213 |
+
dists = cdist(coords, coords)
|
| 214 |
+
np.fill_diagonal(dists, np.inf)
|
| 215 |
+
keep = np.ones(len(coords), dtype=bool)
|
| 216 |
+
for i in range(len(coords)):
|
| 217 |
+
if not keep[i]:
|
| 218 |
+
continue
|
| 219 |
+
# Mark later duplicates
|
| 220 |
+
for j in range(i + 1, len(coords)):
|
| 221 |
+
if keep[j] and dists[i, j] < min_dist:
|
| 222 |
+
keep[j] = False
|
| 223 |
+
return coords[keep]
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
# ---------------------------------------------------------------------------
|
| 227 |
+
# Preprocessing transforms
|
| 228 |
+
# ---------------------------------------------------------------------------
|
| 229 |
+
|
| 230 |
+
def preprocess_image(img: np.ndarray, bead_class: str,
                     tophat_radii: Optional[Dict[str, int]] = None,
                     clahe_clip_limit: float = 0.03,
                     clahe_kernel_size: int = 64) -> np.ndarray:
    """
    Top-hat + CLAHE preprocessing. Used ONLY by LodeStar (Stage 1).

    The CEM500K backbone expects raw EM images, so this transform is
    deliberately not applied during model training.
    """
    from skimage import exposure
    from skimage.morphology import disk, white_tophat

    radii = {"6nm": 8, "12nm": 12} if tophat_radii is None else tophat_radii

    # Invert so dark bead-like features become bright peaks for the
    # white top-hat filter (assumes beads are darker than background —
    # confirm against sample images).
    inverted = (255 - img).astype(np.float32)
    response = white_tophat(inverted, disk(radii[bead_class]))

    # Scale the top-hat response into [0, 1]; skip if uniformly zero.
    peak = response.max()
    normalized = response / peak if peak > 0 else response

    enhanced = exposure.equalize_adapthist(
        normalized,
        clip_limit=clahe_clip_limit,
        kernel_size=clahe_kernel_size,
    )
    return (enhanced * 255).astype(np.uint8)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
# ---------------------------------------------------------------------------
|
| 264 |
+
# Convenience: load everything for one synapse
|
| 265 |
+
# ---------------------------------------------------------------------------
|
| 266 |
+
|
| 267 |
+
def load_synapse(record: SynapseRecord) -> dict:
    """
    Load image, mask, and annotations for one synapse.

    Returns dict with keys: 'image', 'mask', 'annotations',
    'synapse_id', 'image_shape'
    """
    image = load_image(record.image_path)
    tissue_mask = None
    if record.mask_path:
        tissue_mask = load_mask(record.mask_path)
    ann = load_all_annotations(record, image.shape)

    return {
        "synapse_id": record.synapse_id,
        "image": image,
        "mask": tissue_mask,
        "annotations": ann,
        "image_shape": image.shape,
    }
|
src/visualize.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Visualization utilities for QC at every pipeline stage.
|
| 3 |
+
|
| 4 |
+
Generates overlay images showing predictions on raw EM images:
|
| 5 |
+
- Cyan circles for 6nm particles
|
| 6 |
+
- Yellow circles for 12nm particles
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import matplotlib
|
| 11 |
+
matplotlib.use("Agg")
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
import matplotlib.patches as mpatches
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Dict, List, Optional
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Color scheme: RGB tuples (0-255) keyed by particle class. The "*_pred"
# variants are slightly darker so predictions stay distinguishable from
# ground truth in overlays.
COLORS = {
    "6nm": (0, 255, 255),  # cyan
    "12nm": (255, 255, 0),  # yellow
    "6nm_pred": (0, 200, 200),  # darker cyan for predictions
    "12nm_pred": (200, 200, 0),  # darker yellow for predictions
}

# Circle radius (pixels) used when drawing each particle class.
RADII = {"6nm": 6, "12nm": 12}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def overlay_annotations(
    image: np.ndarray,
    annotations: Dict[str, np.ndarray],
    title: str = "",
    save_path: Optional[Path] = None,
    predictions: Optional[List[dict]] = None,
    figsize: tuple = (12, 12),
) -> plt.Figure:
    """
    Draw ground-truth annotations (and optional predictions) on an image.

    Ground truth is rendered as solid circles, predictions as dashed
    circles with a small confidence label beside each.

    Args:
        image: (H, W) grayscale image
        annotations: {'6nm': Nx2, '12nm': Mx2} pixel coordinates
        title: figure title
        save_path: if provided, save figure here
        predictions: optional list of {'x', 'y', 'class', 'conf'}
        figsize: figure size

    Returns:
        matplotlib Figure
    """
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.imshow(image, cmap="gray")

    # Ground truth: solid circles, one color per class.
    for cls, coords in annotations.items():
        if len(coords) == 0:
            continue
        rgb = np.array(COLORS[cls]) / 255.0
        r = RADII[cls]
        for x, y in coords:
            ax.add_patch(plt.Circle(
                (x, y), r, fill=False,
                edgecolor=rgb, linewidth=1.5,
            ))

    # Predictions: dashed circles plus a confidence label.
    if predictions:
        for det in predictions:
            cls = det["class"]
            rgb = np.array(COLORS.get(f"{cls}_pred", COLORS[cls])) / 255.0
            r = RADII[cls]
            ax.add_patch(plt.Circle(
                (det["x"], det["y"]), r, fill=False,
                edgecolor=rgb, linewidth=1.0, linestyle="--",
            ))
            ax.text(
                det["x"] + r + 2, det["y"],
                f'{det["conf"]:.2f}',
                color=rgb, fontsize=6,
            )

    # Legend with per-class counts.
    legend_elements = [
        mpatches.Patch(facecolor="none", edgecolor="cyan", label=f'6nm GT ({len(annotations.get("6nm", []))})', linewidth=1.5),
        mpatches.Patch(facecolor="none", edgecolor="yellow", label=f'12nm GT ({len(annotations.get("12nm", []))})', linewidth=1.5),
    ]
    if predictions:
        n_pred_6 = sum(1 for d in predictions if d["class"] == "6nm")
        n_pred_12 = sum(1 for d in predictions if d["class"] == "12nm")
        legend_elements.append(
            mpatches.Patch(facecolor="none", edgecolor="darkcyan", label=f"6nm pred ({n_pred_6})", linewidth=1.0))
        legend_elements.append(
            mpatches.Patch(facecolor="none", edgecolor="goldenrod", label=f"12nm pred ({n_pred_12})", linewidth=1.0))
    ax.legend(handles=legend_elements, loc="upper right", fontsize=8)

    ax.set_title(title, fontsize=10)
    ax.axis("off")

    if save_path:
        out = Path(save_path)
        out.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(str(out), dpi=150, bbox_inches="tight")
        plt.close(fig)

    return fig
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def plot_heatmap_overlay(
    image: np.ndarray,
    heatmap: np.ndarray,
    title: str = "",
    save_path: Optional[Path] = None,
) -> plt.Figure:
    """
    Overlay predicted heatmap on image for QC.

    Produces a 3-panel figure: raw image, 6nm heatmap overlay, and 12nm
    heatmap overlay. Heatmaps are bilinearly upsampled to the image
    resolution before overlaying.

    Args:
        image: (H, W) grayscale
        heatmap: (2, H', W') predicted heatmap — channel 0 is 6nm,
            channel 1 is 12nm (per the caller, typically half the
            image resolution)
        title: figure suptitle
        save_path: if provided, save the figure here (parent dirs are
            created) and close it

    Returns:
        matplotlib Figure
    """
    # Hoisted out of the per-class loop: the import is loop-invariant.
    from skimage.transform import resize

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    axes[0].imshow(image, cmap="gray")
    axes[0].set_title("Raw Image")
    axes[0].axis("off")

    # Upsample each class heatmap to image size for overlay.
    h, w = image.shape[:2]
    for idx, (cls, color) in enumerate([("6nm", "hot"), ("12nm", "cool")]):
        hm_up = resize(heatmap[idx], (h, w), order=1)  # order=1: bilinear
        axes[idx + 1].imshow(image, cmap="gray")
        axes[idx + 1].imshow(hm_up, cmap=color, alpha=0.5, vmin=0, vmax=1)
        axes[idx + 1].set_title(f"{cls} heatmap")
        axes[idx + 1].axis("off")

    fig.suptitle(title, fontsize=12)

    if save_path:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
        plt.close(fig)

    return fig
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def plot_training_curves(
    metrics: dict,
    save_path: Optional[Path] = None,
) -> plt.Figure:
    """Plot training loss and validation F1 over epochs (two panels)."""
    fig, (loss_ax, f1_ax) = plt.subplots(1, 2, figsize=(14, 5))

    epochs = range(1, len(metrics["train_loss"]) + 1)

    # Left panel: loss curves (validation curve is optional).
    loss_ax.plot(epochs, metrics["train_loss"], label="Train Loss")
    if "val_loss" in metrics:
        loss_ax.plot(epochs, metrics["val_loss"], label="Val Loss")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.set_title("Training Loss")
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Right panel: any F1 series present in `metrics`; the mean curve
    # is drawn thicker to stand out.
    for key, label, extra in [
        ("val_f1_6nm", "6nm F1", {}),
        ("val_f1_12nm", "12nm F1", {}),
        ("val_f1_mean", "Mean F1", {"linewidth": 2}),
    ]:
        if key in metrics:
            f1_ax.plot(epochs, metrics[key], label=label, **extra)
    f1_ax.set_xlabel("Epoch")
    f1_ax.set_ylabel("F1 Score")
    f1_ax.set_title("Validation F1")
    f1_ax.legend()
    f1_ax.grid(True, alpha=0.3)

    if save_path:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
        plt.close(fig)

    return fig
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def plot_precision_recall_curve(
    detections: List[dict],
    gt_coords: np.ndarray,
    match_radius: float,
    cls_name: str = "",
    save_path: Optional[Path] = None,
) -> plt.Figure:
    """
    Plot a precision-recall curve for one class.

    Detections are ranked by confidence (highest first); a detection is
    a true positive when its nearest ground-truth point is within
    ``match_radius`` pixels and not already claimed by a higher-ranked
    detection, otherwise a false positive.
    """
    ranked = sorted(detections, key=lambda d: d["conf"], reverse=True)

    hits = []
    claimed = set()  # ground-truth indices already matched

    for det in ranked:
        is_tp = 0
        if len(gt_coords) > 0:
            pt = np.array([det["x"], det["y"]])
            dists = np.sqrt(np.sum((gt_coords - pt) ** 2, axis=1))
            nearest = np.argmin(dists)
            if dists[nearest] <= match_radius and nearest not in claimed:
                is_tp = 1
                claimed.add(nearest)
        hits.append(is_tp)

    tp_cum = np.cumsum(hits)
    fp_cum = np.cumsum([1 - t for t in hits])
    n_gt = max(len(gt_coords), 1)  # avoid /0 when there is no ground truth

    precision = tp_cum / (tp_cum + fp_cum)
    recall = tp_cum / n_gt

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.plot(recall, precision, linewidth=2)
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title(f"PR Curve — {cls_name}")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3)

    if save_path:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
        plt.close(fig)

    return fig
|