# MidasMap: Automated Immunogold Particle Detection for TEM Synapse Images --- ## The Problem Neuroscientists use **immunogold labeling** to visualize receptor proteins at synapses in transmission electron microscopy (TEM) images. - **6nm gold beads** label AMPA receptors (panAMPA) - **12nm gold beads** label NR1 (NMDA) receptors - **18nm gold beads** label vGlut2 (vesicular glutamate transporter) **Manual counting is slow and subjective.** Each image takes 30-60 minutes to annotate. With hundreds of synapses per experiment, this becomes a bottleneck. ### The Challenge - Particles are **tiny** (4-10 pixels radius) on 2048x2115 images - Contrast delta is only **11-39 intensity units** on a 0-255 scale - Large dark vesicles look similar to gold particles to naive detectors - Only **453 labeled particles** across 10 training images --- ## Previous Approaches (GoldDigger et al.) | Approach | Result | |----------|--------| | CenterNet (initial attempt) | "Detection quality remained poor" | | U-Net heatmap | Macro F1 = 0.005-0.017 | | GoldDigger/cGAN | "No durable breakthrough" | | Aggressive filtering | "FP dropped but TP dropped harder" | **Core issue:** Previous systems failed due to: 1. Incorrect coordinate conversion (microns treated as normalized values) 2. Broken loss function (heatmap peaks not exactly 1.0) 3. Overfitting to fixed training patches --- ## MidasMap Architecture ``` Input: Raw TEM Image (any size) | [Sliding Window → 512x512 patches] | ResNet-50 Encoder (pretrained on CEM500K: 500K EM images) | BiFPN Neck (bidirectional feature pyramid, 2 rounds, 128ch) | Transposed Conv Decoder → stride-2 output | +------------------+-------------------+ | | | Heatmap Head Offset Head (2ch sigmoid) (2ch regression) 6nm channel sub-pixel x,y 12nm channel correction | | +------------------+-------------------+ | Peak Extraction (max-pool NMS) | Cross-class NMS + Mask Filter | Output: [(x, y, class, confidence), ...] ``` ### Key Design Decisions **CEM500K Backbone:** ResNet-50 pretrained on 500,000 electron microscopy images via self-supervised learning. The backbone already understands EM structures (membranes, vesicles, organelles) before seeing any gold particles. This is why the model reaches F1=0.93 in just 5 epochs. **Stride-2 Output:** Standard CenterNet uses stride 4. At stride 4, a 6nm bead (4-6px radius) collapses to 1 pixel — too small to detect reliably. At stride 2, the same bead occupies 2-3 pixels, enough for Gaussian peak detection. **CornerNet Focal Loss:** With positive:negative pixel ratio of 1:23,000, standard BCE would learn to predict all zeros. The focal loss uses `(1-p)^alpha` weighting to focus on hard examples and `(1-gt)^beta` penalty reduction near peaks. **Raw Image Input:** No preprocessing. The CEM500K backbone was trained on raw EM images. Any heavy preprocessing (top-hat, CLAHE) creates a domain gap and hurts performance. The model learns to distinguish particles from vesicles through training data, not handcrafted filters. --- ## Training Strategy ### 3-Phase Training with Discriminative Learning Rates | Phase | Epochs | What's Trainable | Learning Rate | |-------|--------|-------------------|---------------| | **1. Warm-up** | 40 | BiFPN + heads only | 1e-3 | | **2. Deep unfreeze** | 40 | + layer3 + layer4 | 1e-5 to 5e-4 | | **3. Full fine-tune** | 60 | All layers | 1e-6 to 2e-4 | ``` Loss Curve (final model): Phase 1 Phase 2 Phase 3 | | | 1.4 |\ | | | \ | | 1.0 | \ | | | ---- | | 0.8 | \ | | | \ | | 0.6 | \--+--- | | | \ | 0.4 | | \--- | | | \-------+--- 0.2 | | | +---+---+----+---+---+----+---+---+--> Epoch 0 10 20 40 50 60 80 100 140 ``` ### Data Augmentation - Random 90-degree rotations (EM is rotation-invariant) - Horizontal/vertical flips - Conservative brightness/contrast (+-8% — preserves the subtle particle signal) - Gaussian noise (simulates shot noise) - **Copy-paste augmentation**: real bead crops blended onto training patches - **70% hard mining**: patches centered on particles, 30% random ### Overfitting Prevention - **Unique patches every epoch**: RNG reseeded per sample so the model never sees the same patch twice - **Early stopping**: patience=20 epochs, monitoring validation F1 - **Weight decay**: 1e-4 on all parameters --- ## Critical Bugs Found and Fixed ### Bug 1: Coordinate Conversion **Problem:** CSV files labeled "XY in microns" were assumed to be normalized [0,1] coordinates. They were actual micron values. **Effect:** All particle annotations were offset by 50-80 pixels from the real locations. The model was learning to detect particles where none existed. **Fix:** Multiply by 1790 px/micron (verified against researcher's color overlay TIFs across 7 synapses). ### Bug 2: Heatmap Peak Values **Problem:** Gaussian peaks were centered at float coordinates, producing peak values of 0.78-0.93 instead of exactly 1.0. **Effect:** The CornerNet focal loss uses `pos_mask = (gt == 1.0)` to identify positive pixels. With no pixels at exactly 1.0, the model had **zero positive training signal**. It literally could not learn. **Fix:** Center Gaussians at the integer grid point (always produces 1.0). Sub-pixel precision is handled by the offset regression head. ### Bug 3: Overfitting on Fixed Patches **Problem:** The dataset generated 200 random patches once at initialization. Every epoch replayed the same patches. **Effect:** On fast CUDA GPUs, the model memorized all patches in ~17 epochs (loss crashed from 1.6 to 0.002). Validation F1 peaked at 0.66 and degraded. **Fix:** Reseed RNG per `__getitem__` call so every patch is unique. --- ## Results ### Leave-One-Image-Out Cross-Validation (10 folds, 5 seeds each) | Fold | Avg F1 | Best F1 | Notes | |------|--------|---------|-------| | S27 | **0.990** | 0.994 | | | S8 | **0.981** | 0.988 | | | S25 | **0.972** | 0.977 | | | S29 | **0.956** | 0.966 | | | S1 | **0.930** | 0.940 | | | S4 | **0.919** | 0.972 | | | S22 | **0.907** | 0.938 | | | S13 | **0.890** | 0.912 | | | S7 | 0.799 | 1.000 | Only 3 particles (noisy metric) | | S15 | 0.633 | 0.667 | Only 1 particle (noisy metric) | **Mean F1 = 0.943** (8 folds with sufficient annotations) ### Per-class Performance (S1 fold, best threshold) | Class | Precision | Recall | F1 | |-------|-----------|--------|----| | 6nm (AMPA) | 0.895 | **1.000** | **0.944** | | 12nm (NR1) | 0.833 | **1.000** | **0.909** | **100% recall** on both classes — every particle is found. Only errors are a few false positives. ### Generalization to Unseen Images Tested on 15 completely unseen images from a different imaging session. Detections land correctly on particles with no manual tuning. The model successfully detects both 6nm and 12nm particles on: - Wild-type (Wt2) samples - Heterozygous (Het1) samples - Different synapse regions (D1, E3, S1, S10, S12, S18) --- ## System Components ``` MidasMap/ config/config.yaml # All hyperparameters src/ preprocessing.py # Data loading (10 synapses, 453 particles) model.py # CenterNet: ResNet-50 + BiFPN + heads (24.4M params) loss.py # CornerNet focal loss + offset regression heatmap.py # GT generation + peak extraction + NMS dataset.py # Patch sampling, augmentation, copy-paste postprocess.py # Mask filter, cross-class NMS ensemble.py # D4 TTA + sliding window inference evaluate.py # Hungarian matching, F1/precision/recall visualize.py # Overlay visualizations train.py # LOOCV training (--fold, --seed) train_final.py # Final deployable model (all data) predict.py # Inference on new images evaluate_loocv.py # Full evaluation runner app.py # Gradio web dashboard slurm/ # HPC job scripts tests/ # 36 unit tests ``` --- ## Dashboard MidasMap includes a web-based dashboard (Gradio) for interactive use: 1. **Upload** any TEM image (.tif) 2. **Adjust** confidence threshold and NMS parameters 3. **View** detections overlaid on the image 4. **Inspect** per-class heatmaps 5. **Analyze** confidence distributions and spatial patterns 6. **Export** results as CSV (particle_id, x_px, y_px, x_um, y_um, class, confidence) ``` python app.py --checkpoint checkpoints/final/final_model.pth # Opens at http://localhost:7860 ``` --- ## Future Directions 1. **Spatial analytics**: distance to synaptic cleft, nearest-neighbor analysis, Ripley's K-function 2. **Size regression head**: predict actual bead diameter instead of binary classification 3. **18nm detection**: extend to vGlut2 particles (3-class model) 4. **Active learning**: flag low-confidence detections for human review 5. **Cross-protocol generalization**: fine-tune on cryo-EM or different staining protocols --- ## Technical Summary - **Model**: CenterNet with CEM500K-pretrained ResNet-50, BiFPN neck, stride-2 output - **Training**: 3-phase with discriminative LRs, 140 epochs, 453 particles / 10 images - **Evaluation**: Leave-one-image-out CV, Hungarian matching, F1 = 0.943 - **Inference**: Sliding window (512x512, 128px overlap), ~10s per image on GPU - **Output**: Per-particle (x, y, class, confidence) with optional heatmap visualization