# MidasMap: Automated Immunogold Particle Detection for TEM Synapse Images
---
## The Problem
Neuroscientists use **immunogold labeling** to visualize receptor proteins at synapses in transmission electron microscopy (TEM) images.
- **6nm gold beads** label AMPA receptors (panAMPA)
- **12nm gold beads** label NR1 (NMDA) receptors
- **18nm gold beads** label vGlut2 (vesicular glutamate transporter)
**Manual counting is slow and subjective.** Each image takes 30-60 minutes to annotate. With hundreds of synapses per experiment, this becomes a bottleneck.
### The Challenge
- Particles are **tiny** (4-10 px radius) in 2048x2115 images
- Particle-to-background contrast is only **11-39 intensity units** on a 0-255 scale
- Large dark vesicles look similar to gold particles to naive detectors
- Only **453 labeled particles** across 10 training images
---
## Previous Approaches (GoldDigger et al.)
| Approach | Result |
|----------|--------|
| CenterNet (initial attempt) | "Detection quality remained poor" |
| U-Net heatmap | Macro F1 = 0.005-0.017 |
| GoldDigger/cGAN | "No durable breakthrough" |
| Aggressive filtering | "FP dropped but TP dropped harder" |
**Core issue:** Previous systems failed due to:
1. Incorrect coordinate conversion (microns treated as normalized values)
2. Broken loss function (heatmap peaks not exactly 1.0)
3. Overfitting to fixed training patches
---
## MidasMap Architecture
```
              Input: Raw TEM Image (any size)
                             |
            [Sliding Window → 512x512 patches]
                             |
 ResNet-50 Encoder (pretrained on CEM500K: 500K EM images)
                             |
BiFPN Neck (bidirectional feature pyramid, 2 rounds, 128ch)
                             |
         Transposed Conv Decoder → stride-2 output
                             |
            +----------------+----------------+
            |                                 |
      Heatmap Head                      Offset Head
      (2ch sigmoid)                     (2ch regression)
      6nm channel                       sub-pixel x,y
      12nm channel                      correction
            |                                 |
            +----------------+----------------+
                             |
              Peak Extraction (max-pool NMS)
                             |
               Cross-class NMS + Mask Filter
                             |
         Output: [(x, y, class, confidence), ...]
```
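The sliding-window and cross-class NMS stages can be sketched in a few lines of plain Python. This is a minimal sketch, not the code in `ensemble.py` or `postprocess.py`: the 512 px tile and 128 px overlap come from the Technical Summary, while the function names and the 8 px suppression radius are hypothetical.

```python
def tile_origins(length, tile=512, overlap=128):
    """Start offsets along one axis so that overlapping tiles cover `length` pixels."""
    if length <= tile:
        return [0]
    step = tile - overlap
    starts = list(range(0, length - tile + 1, step))
    if starts[-1] != length - tile:
        starts.append(length - tile)  # last tile sits flush with the image edge
    return starts


def cross_class_nms(dets, radius=8.0):
    """Greedy NMS that ignores class labels.

    dets: list of (x, y, class_id, confidence) in image pixels.
    Keeps the most confident detection and drops any other detection, of either
    class, whose centre lies within `radius` px of one already kept.
    """
    kept = []
    for d in sorted(dets, key=lambda d: d[3], reverse=True):
        if all((d[0] - k[0]) ** 2 + (d[1] - k[1]) ** 2 >= radius ** 2 for k in kept):
            kept.append(d)
    return kept
```

For a 2048x2115 image, `tile_origins(2048)` gives `[0, 384, 768, 1152, 1536]` along one axis, and `tile_origins(2115)` adds an edge-flush tile at 1603 along the other. Suppressing across classes prevents one bead from being reported as both a 6nm and a 12nm particle.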
### Key Design Decisions
**CEM500K Backbone:** ResNet-50 pretrained on 500,000 electron microscopy images via self-supervised learning. The backbone already understands EM structures (membranes, vesicles, organelles) before seeing any gold particles. This is why the model reaches F1=0.93 in just 5 epochs.
**Stride-2 Output:** Standard CenterNet uses stride 4. At stride 4, a 6nm bead (4-6 px radius in the input image) shrinks to a radius of only 1-1.5 output pixels, too small to detect reliably. At stride 2, the same bead keeps a radius of 2-3 output pixels, which is enough for Gaussian peak detection.
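The stride argument is simple arithmetic. The sketch below assumes the nominal bead diameter (real beads vary) and reuses the 1790 px/micron calibration from the coordinate-conversion fix; the helper name is hypothetical.

```python
PX_PER_UM = 1790.0  # calibration from the coordinate-conversion fix (1790 px per micron)


def bead_radius_px(diameter_nm, stride=1, px_per_um=PX_PER_UM):
    """Radius of a gold bead in output-grid pixels at a given network stride.

    Uses the nominal bead diameter; actual bead sizes vary around it.
    """
    diameter_px = diameter_nm / 1000.0 * px_per_um  # nm -> um -> input pixels
    return diameter_px / 2.0 / stride


for stride in (1, 2, 4):
    print(f"stride {stride}: 6nm r={bead_radius_px(6, stride):.2f} px, "
          f"12nm r={bead_radius_px(12, stride):.2f} px")
```

Under these assumptions a 6nm bead has a radius of about 5.4 px in the input image, about 2.7 output pixels at stride 2 and about 1.3 at stride 4, in line with the ranges quoted above.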
**CornerNet Focal Loss:** With a positive:negative pixel ratio of 1:23,000, standard BCE is dominated by easy negatives and tends to learn to predict all zeros. The focal loss applies `(1-p)^alpha` weighting (and `p^alpha` on negatives) to focus on hard examples, plus a `(1-gt)^beta` penalty reduction for negative pixels near a peak.
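A NumPy sketch of this loss follows. It mirrors the standard CornerNet formulation; `alpha=2` and `beta=4` are the usual CornerNet defaults and are an assumption here, since the document does not state MidasMap's values. The function name is hypothetical and this is not the code in `loss.py`.

```python
import numpy as np


def cornernet_focal_loss(pred, gt, alpha=2.0, beta=4.0, eps=1e-6):
    """Penalty-reduced focal loss in the CornerNet/CenterNet style.

    pred: predicted heatmap with values in (0, 1).
    gt:   Gaussian target heatmap whose peaks are exactly 1.0.
    """
    pred = np.clip(pred, eps, 1.0 - eps)
    pos = gt == 1.0          # positives: only pixels that are exactly 1.0
    neg = ~pos
    # Positives: (1 - p)^alpha down-weights peaks the model already predicts confidently.
    pos_loss = -np.sum((1.0 - pred[pos]) ** alpha * np.log(pred[pos]))
    # Negatives: p^alpha focuses on confident false alarms, and (1 - gt)^beta
    # softens the penalty for pixels that sit close to a true peak.
    neg_loss = -np.sum(
        (1.0 - gt[neg]) ** beta * pred[neg] ** alpha * np.log(1.0 - pred[neg])
    )
    num_pos = max(int(pos.sum()), 1)  # normalise by the number of particles
    return (pos_loss + neg_loss) / num_pos
```

Because positives are selected with `gt == 1.0`, a target whose peaks never reach exactly 1.0 contributes no positive term at all, which is the failure described under Bug 2.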
**Raw Image Input:** No preprocessing. The CEM500K backbone was trained on raw EM images. Any heavy preprocessing (top-hat, CLAHE) creates a domain gap and hurts performance. The model learns to distinguish particles from vesicles through training data, not handcrafted filters.
---
## Training Strategy
### 3-Phase Training with Discriminative Learning Rates
| Phase | Epochs | What's Trainable | Learning Rate |
|-------|--------|-------------------|---------------|
| **1. Warm-up** | 40 | BiFPN + heads only | 1e-3 |
| **2. Deep unfreeze** | 40 | + layer3 + layer4 | 1e-5 to 5e-4 |
| **3. Full fine-tune** | 60 | All layers | 1e-6 to 2e-4 |
```
Loss Curve (final model):
       Phase 1      Phase 2      Phase 3
                 |            |
1.4 |\           |            |
    | \          |            |
1.0 |  \         |            |
    |   ----     |            |
0.8 |       \    |            |
    |        \   |            |
0.6 |         \--+---         |
    |            |   \        |
0.4 |            |    \---    |
    |            |        \---+----------
0.2 |            |            |
    +---+---+----+---+---+----+---+---+--> Epoch
    0   10  20   40  50  60   80  100 140
```
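The phase table can be expressed as a per-epoch lookup of learning rates per parameter group. This is a minimal sketch under stated assumptions: the grouping of the ResNet-50 into `stem_to_layer2`, `layer3`, `layer4` and `bifpn_and_heads` is hypothetical, and so is spreading each phase's LR range geometrically from the earliest trainable layers (lowest LR) to the neck and heads (highest LR). Only the epoch boundaries and the LR endpoints come from the table.

```python
# Hypothetical module groups, ordered shallow -> deep (not taken from the MidasMap code).
GROUPS = ["stem_to_layer2", "layer3", "layer4", "bifpn_and_heads"]

# (start_epoch, end_epoch_exclusive, trainable groups, (lowest_lr, highest_lr)),
# following the three-phase table.
PHASES = [
    (0, 40, ["bifpn_and_heads"], (1e-3, 1e-3)),
    (40, 80, ["layer3", "layer4", "bifpn_and_heads"], (1e-5, 5e-4)),
    (80, 140, GROUPS, (1e-6, 2e-4)),
]


def group_lrs(epoch):
    """Return {group: learning_rate} for the groups trainable at `epoch`.

    Frozen groups are omitted. Within a phase, LRs increase geometrically from the
    earliest trainable layers to the neck and heads (an assumed spacing).
    """
    for start, end, trainable, (lo, hi) in PHASES:
        if start <= epoch < end:
            if len(trainable) == 1:
                return {trainable[0]: hi}
            ratio = (hi / lo) ** (1.0 / (len(trainable) - 1))
            return {g: lo * ratio ** i for i, g in enumerate(trainable)}
    raise ValueError(f"epoch {epoch} is outside the 140-epoch schedule")


for epoch in (0, 40, 80):
    print(epoch, {g: f"{lr:.1e}" for g, lr in group_lrs(epoch).items()})
```

In PyTorch each entry would become one optimizer parameter group (a dict with `params` and `lr`), with the omitted groups frozen via `requires_grad=False`; the document does not say which optimizer is used.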
### Data Augmentation
- Random 90-degree rotations (EM images have no preferred orientation)
- Horizontal/vertical flips
- Conservative brightness/contrast jitter (±8%, to preserve the subtle particle signal)
- Gaussian noise (simulates shot noise)
- **Copy-paste augmentation**: real bead crops blended onto training patches
- **70% hard mining**: 70% of patches are centered on particles and 30% are sampled at random
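A minimal NumPy sketch of two of these steps: the 70/30 particle-centered vs random patch sampler, and a random 90-degree rotation with flips that moves the annotation coordinates together with the pixels. The function names are hypothetical, the 512 px patch size is taken from the sliding-window setting, and copy-paste blending, intensity jitter and noise are left out.

```python
import numpy as np

PATCH = 512


def sample_patch_origin(rng, particles_xy, image_hw, patch=PATCH, p_particle=0.7):
    """Top-left corner (x0, y0) of a training patch.

    With probability p_particle the patch is centered on a randomly chosen particle;
    otherwise it is placed uniformly at random. Either way it is clamped to the image.
    """
    h, w = image_hw
    if len(particles_xy) > 0 and rng.random() < p_particle:
        cx, cy = particles_xy[rng.integers(len(particles_xy))]
        x0, y0 = int(round(cx)) - patch // 2, int(round(cy)) - patch // 2
    else:
        x0 = int(rng.integers(0, w - patch + 1))
        y0 = int(rng.integers(0, h - patch + 1))
    return int(np.clip(x0, 0, w - patch)), int(np.clip(y0, 0, h - patch))


def d4_augment(rng, patch, xy):
    """Random 90-degree rotation plus random horizontal/vertical flips.

    patch: square (S, S) array; xy: (N, 2) array of (x, y) pixel coordinates.
    The points are transformed with the patch so the labels stay aligned.
    """
    size = patch.shape[0]
    xy = np.asarray(xy, dtype=float).copy()
    for _ in range(int(rng.integers(4))):
        # np.rot90 (counter-clockwise) maps pixel (x, y) to (y, size - 1 - x).
        patch = np.rot90(patch)
        xy = np.stack([xy[:, 1], size - 1 - xy[:, 0]], axis=1)
    if rng.random() < 0.5:  # horizontal flip
        patch = patch[:, ::-1]
        xy[:, 0] = size - 1 - xy[:, 0]
    if rng.random() < 0.5:  # vertical flip
        patch = patch[::-1, :]
        xy[:, 1] = size - 1 - xy[:, 1]
    return np.ascontiguousarray(patch), xy
```

Transforming the coordinates rather than re-deriving them from the augmented image keeps the heatmap targets exact for all eight orientations.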
### Overfitting Prevention
- **Unique patches every epoch**: RNG reseeded per sample so the model never sees the same patch twice
- **Early stopping**: patience=20 epochs, monitoring validation F1
- **Weight decay**: 1e-4 on all parameters
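The early-stopping rule needs only a little state. A minimal sketch, assuming "improvement" means a strictly higher validation F1; the class name and the `min_delta` knob are hypothetical.

```python
class EarlyStopping:
    """Signal a stop once validation F1 has not improved for `patience` epochs in a row."""

    def __init__(self, patience=20, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # hypothetical: minimum gain that counts as improvement
        self.best_f1 = float("-inf")
        self.best_epoch = -1
        self.epochs_without_improvement = 0

    def step(self, epoch, val_f1):
        """Record one epoch's validation F1; return True when training should stop."""
        if val_f1 > self.best_f1 + self.min_delta:
            self.best_f1 = val_f1
            self.best_epoch = epoch  # a real loop would checkpoint the model here
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience
```

With `patience=20`, a run whose F1 last improves at epoch 4 stops at epoch 24, and the checkpoint from epoch 4 is the one to keep.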
---
## Critical Bugs Found and Fixed
### Bug 1: Coordinate Conversion
**Problem:** CSV files labeled "XY in microns" were assumed to be normalized [0,1] coordinates. They were actual micron values.
**Effect:** All particle annotations were offset by 50-80 pixels from the real locations. The model was learning to detect particles where none existed.
**Fix:** Multiply by 1790 px/micron (verified against researcher's color overlay TIFs across 7 synapses).
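The fix amounts to a fixed scale factor. A minimal sketch with hypothetical helper names; the inverse is useful when reporting detections in microns. At 1790 px/micron one pixel is about 0.56 nm.

```python
PX_PER_UM = 1790.0  # calibration verified against the researcher's overlay TIFs


def um_to_px(x_um, y_um, px_per_um=PX_PER_UM):
    """Convert annotation coordinates from microns to image pixels."""
    return x_um * px_per_um, y_um * px_per_um


def px_to_um(x_px, y_px, px_per_um=PX_PER_UM):
    """Inverse conversion, e.g. for exporting detections in microns."""
    return x_px / px_per_um, y_px / px_per_um


x_px, y_px = um_to_px(0.5, 1.0)  # a point 0.5 um right and 1.0 um down from the origin
print(x_px, y_px)
```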
### Bug 2: Heatmap Peak Values
**Problem:** Gaussian peaks were centered at float coordinates, producing peak values of 0.78-0.93 instead of exactly 1.0.
**Effect:** The CornerNet focal loss uses `pos_mask = (gt == 1.0)` to identify positive pixels. With no pixel at exactly 1.0, the model received **zero positive training signal** and could not learn what a particle looks like.
**Fix:** Center Gaussians at the integer grid point (always produces 1.0). Sub-pixel precision is handled by the offset regression head.
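A minimal NumPy sketch of the fix and its inverse: each Gaussian is centered on the integer output cell so its peak is exactly 1.0, the fractional remainder becomes the offset target, and decoding (3x3 max-pool NMS plus offset) recovers the sub-pixel position. It uses a single class channel, stride 2 and an assumed `sigma=2.0`; the function names are hypothetical and this is not the code in `heatmap.py`.

```python
import numpy as np


def render_targets(points_xy, out_hw, stride=2, sigma=2.0):
    """Heatmap (H, W) and offset (2, H, W) targets on the stride-`stride` output grid."""
    h, w = out_hw
    heat = np.zeros((h, w))
    offset = np.zeros((2, h, w))
    ys, xs = np.mgrid[0:h, 0:w]
    for x, y in points_xy:
        fx, fy = x / stride, y / stride   # float position on the output grid
        cx, cy = int(fx), int(fy)         # integer cell (floor, since x, y >= 0)
        g = np.exp(-((xs - cx) ** 2 + (ys - cy) ** 2) / (2 * sigma ** 2))
        heat = np.maximum(heat, g)        # peak value is exp(0) == 1.0 exactly
        offset[:, cy, cx] = (fx - cx, fy - cy)  # sub-pixel remainder in [0, 1)
    return heat, offset


def max_pool_3x3(a):
    """3x3 max filter; padding with -inf keeps border pixels comparable."""
    h, w = a.shape
    p = np.pad(a, 1, mode="constant", constant_values=-np.inf)
    return np.max([p[dy:dy + h, dx:dx + w] for dy in range(3) for dx in range(3)], axis=0)


def decode(heat, offset, stride=2, threshold=0.3):
    """Keep local maxima above `threshold`; return rows of (x, y, score) in input pixels."""
    peaks = (heat == max_pool_3x3(heat)) & (heat >= threshold)
    cy, cx = np.nonzero(peaks)
    x = (cx + offset[0, cy, cx]) * stride
    y = (cy + offset[1, cy, cx]) * stride
    return np.stack([x, y, heat[cy, cx]], axis=1)
```

By contrast, a Gaussian centered at the float position (10.5, 20.7) would evaluate to below 1.0 at every grid cell, so `gt == 1.0` would select nothing.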
### Bug 3: Overfitting on Fixed Patches
**Problem:** The dataset generated 200 random patches once at initialization. Every epoch replayed the same patches.
**Effect:** On fast CUDA GPUs, the model memorized all patches in ~17 epochs (loss crashed from 1.6 to 0.002). Validation F1 peaked at 0.66 and degraded.
**Fix:** Reseed RNG per `__getitem__` call so every patch is unique.
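One way to implement per-sample reseeding is sketched below. It seeds a fresh NumPy generator from `(base_seed, epoch, index)`, so runs stay reproducible while every epoch draws different patches. `FreshPatchDataset`, `set_epoch` and this seeding scheme are hypothetical; seeding from fresh OS entropy would also give unique patches, at the cost of reproducibility.

```python
import numpy as np


class FreshPatchDataset:
    """Draws a new random patch on every __getitem__ call instead of caching patches."""

    def __init__(self, image, patch=512, length=200, base_seed=0):
        self.image = image
        self.patch = patch
        self.length = length
        self.base_seed = base_seed
        self.epoch = 0  # the training loop calls set_epoch() once per epoch

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # A generator seeded from (seed, epoch, index): same item within an epoch,
        # different patch location in every other epoch.
        rng = np.random.default_rng([self.base_seed, self.epoch, index])
        h, w = self.image.shape[:2]
        y0 = int(rng.integers(0, h - self.patch + 1))
        x0 = int(rng.integers(0, w - self.patch + 1))
        crop = self.image[y0:y0 + self.patch, x0:x0 + self.patch]
        return crop, (x0, y0)
```

The buggy version effectively ran this sampling once in `__init__` and replayed the same 200 crops every epoch.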
---
## Results
### Leave-One-Image-Out Cross-Validation (10 folds, 5 seeds each)
| Fold | Mean F1 (5 seeds) | Best single-seed F1 | Notes |
|------|--------|---------|-------|
| S27 | **0.990** | 0.994 | |
| S8 | **0.981** | 0.988 | |
| S25 | **0.972** | 0.977 | |
| S29 | **0.956** | 0.966 | |
| S1 | **0.930** | 0.940 | |
| S4 | **0.919** | 0.972 | |
| S22 | **0.907** | 0.938 | |
| S13 | **0.890** | 0.912 | |
| S7 | 0.799 | 1.000 | Only 3 particles (noisy metric) |
| S15 | 0.633 | 0.667 | Only 1 particle (noisy metric) |
**Mean F1 = 0.943** (average over the 8 folds with sufficient annotations; S7 and S15 excluded)
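Per-fold scores rest on one-to-one matching between detections and annotations. A minimal sketch using SciPy's `linear_sum_assignment`, which solves the same assignment problem as the Hungarian algorithm; the 10 px match radius and the function name are assumptions, not values taken from `evaluate.py`. Calling it once per class gives per-class scores.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def match_f1(pred_xy, gt_xy, max_dist=10.0):
    """Precision, recall and F1 after optimal one-to-one matching within `max_dist` px."""
    pred_xy = np.asarray(pred_xy, dtype=float).reshape(-1, 2)
    gt_xy = np.asarray(gt_xy, dtype=float).reshape(-1, 2)
    tp = 0
    if len(pred_xy) and len(gt_xy):
        dist = np.linalg.norm(pred_xy[:, None, :] - gt_xy[None, :, :], axis=-1)
        # Pairs farther than max_dist get a prohibitive cost, so the solver
        # maximises the number of valid matches; invalid pairs are discarded below.
        cost = np.where(dist <= max_dist, dist, 1e6)
        rows, cols = linear_sum_assignment(cost)
        tp = int(np.sum(dist[rows, cols] <= max_dist))
    fp = len(pred_xy) - tp
    fn = len(gt_xy) - tp
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1
```

The noisy S7 and S15 rows follow directly from this definition: with one or three annotated particles, a single miss or false positive moves F1 by a large step.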
### Per-class Performance (S1 fold, best threshold)
| Class | Precision | Recall | F1 |
|-------|-----------|--------|----|
| 6nm (AMPA) | 0.895 | **1.000** | **0.944** |
| 12nm (NR1) | 0.833 | **1.000** | **0.909** |
**100% recall** on both classes in this fold: every annotated particle is found, and the only errors are a few false positives.
### Generalization to Unseen Images
Tested on 15 completely unseen images from a different imaging session. Detections land correctly on particles with no manual tuning. The model successfully detects both 6nm and 12nm particles on:
- Wild-type (Wt2) samples
- Heterozygous (Het1) samples
- Different synapse regions (D1, E3, S1, S10, S12, S18)
---
## System Components
```
MidasMap/
config/config.yaml # All hyperparameters
src/
preprocessing.py # Data loading (10 synapses, 453 particles)
model.py # CenterNet: ResNet-50 + BiFPN + heads (24.4M params)
loss.py # CornerNet focal loss + offset regression
heatmap.py # GT generation + peak extraction + NMS
dataset.py # Patch sampling, augmentation, copy-paste
postprocess.py # Mask filter, cross-class NMS
ensemble.py # D4 TTA + sliding window inference
evaluate.py # Hungarian matching, F1/precision/recall
visualize.py # Overlay visualizations
train.py # LOOCV training (--fold, --seed)
train_final.py # Final deployable model (all data)
predict.py # Inference on new images
evaluate_loocv.py # Full evaluation runner
app.py # Gradio web dashboard
slurm/ # HPC job scripts
tests/ # 36 unit tests
```
---
## Dashboard
MidasMap includes a web-based dashboard (Gradio) for interactive use:
1. **Upload** any TEM image (.tif)
2. **Adjust** confidence threshold and NMS parameters
3. **View** detections overlaid on the image
4. **Inspect** per-class heatmaps
5. **Analyze** confidence distributions and spatial patterns
6. **Export** results as CSV (particle_id, x_px, y_px, x_um, y_um, class, confidence)
```
python app.py --checkpoint checkpoints/final/final_model.pth
# Opens at http://localhost:7860
```
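The export format can be sketched with Python's standard `csv` module. The column names come from the list above; converting pixels to microns with the 1790 px/micron calibration, numbering particles from 1, and the number formatting are assumptions, and `detections_to_csv` is a hypothetical helper rather than part of `app.py`.

```python
import csv
import io

PX_PER_UM = 1790.0  # calibration from the coordinate-conversion fix


def detections_to_csv(detections, px_per_um=PX_PER_UM):
    """Serialize (x_px, y_px, class_name, confidence) tuples using the dashboard's columns."""
    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow(["particle_id", "x_px", "y_px", "x_um", "y_um", "class", "confidence"])
    for i, (x, y, cls, conf) in enumerate(detections, start=1):
        writer.writerow([
            i,
            f"{x:.1f}", f"{y:.1f}",                              # pixel coordinates
            f"{x / px_per_um:.4f}", f"{y / px_per_um:.4f}",      # micron coordinates
            cls,
            f"{conf:.3f}",
        ])
    return buf.getvalue()


print(detections_to_csv([(1790.0, 895.0, "6nm", 0.91), (120.5, 40.0, "12nm", 0.77)]))
```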
---
## Future Directions
1. **Spatial analytics**: distance to synaptic cleft, nearest-neighbor analysis, Ripley's K-function
2. **Size regression head**: predict actual bead diameter instead of binary classification
3. **18nm detection**: extend to vGlut2 particles (3-class model)
4. **Active learning**: flag low-confidence detections for human review
5. **Cross-protocol generalization**: fine-tune on cryo-EM or different staining protocols
---
## Technical Summary
- **Model**: CenterNet with CEM500K-pretrained ResNet-50, BiFPN neck, stride-2 output
- **Training**: 3-phase with discriminative LRs, 140 epochs, 453 particles / 10 images
- **Evaluation**: Leave-one-image-out CV, Hungarian matching, F1 = 0.943
- **Inference**: Sliding window (512x512, 128px overlap), ~10s per image on GPU
- **Output**: Per-particle (x, y, class, confidence) with optional heatmap visualization