# MidasMap: Automated Immunogold Particle Detection for TEM Synapse Images

---

## The Problem

Neuroscientists use **immunogold labeling** to visualize receptor proteins at synapses in transmission electron microscopy (TEM) images.

- **6nm gold beads** label AMPA receptors (panAMPA)
- **12nm gold beads** label NR1 (NMDA) receptors
- **18nm gold beads** label vGlut2 (vesicular glutamate transporter)

**Manual counting is slow and subjective.** Each image takes 30-60 minutes to annotate. With hundreds of synapses per experiment, annotation becomes a bottleneck.

### The Challenge
- Particles are **tiny** (4-10 pixel radius) on 2048x2115 images
- The contrast delta is only **11-39 intensity units** on a 0-255 scale
- Large dark vesicles look similar to gold particles to naive detectors
- Only **453 labeled particles** exist across 10 training images

---

## Previous Approaches (GoldDigger et al.)

| Approach | Result |
|----------|--------|
| CenterNet (initial attempt) | "Detection quality remained poor" |
| U-Net heatmap | Macro F1 = 0.005-0.017 |
| GoldDigger/cGAN | "No durable breakthrough" |
| Aggressive filtering | "FP dropped but TP dropped harder" |

**Core issue:** Previous systems failed for three reasons:
1. Incorrect coordinate conversion (microns treated as normalized values)
2. A broken loss function (heatmap peaks not exactly 1.0)
3. Overfitting to fixed training patches

---

## MidasMap Architecture

```
Input: Raw TEM Image (any size)
                 |
[Sliding Window → 512x512 patches]
                 |
ResNet-50 Encoder (pretrained on CEM500K: 500K EM images)
                 |
BiFPN Neck (bidirectional feature pyramid, 2 rounds, 128ch)
                 |
Transposed Conv Decoder → stride-2 output
                 |
        +--------+--------+
        |                 |
  Heatmap Head       Offset Head
  (2ch sigmoid)      (2ch regression)
  6nm channel        sub-pixel x,y
  12nm channel       correction
        |                 |
        +--------+--------+
                 |
Peak Extraction (max-pool NMS)
                 |
Cross-class NMS + Mask Filter
                 |
Output: [(x, y, class, confidence), ...]
```
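
The "Peak Extraction (max-pool NMS)" step in the pipeline above can be sketched as follows. This is a NumPy/SciPy illustration of the max-pool NMS idea, not the project's `heatmap.py`; the threshold and window size are assumed values:

```python
import numpy as np
from scipy.ndimage import maximum_filter

def extract_peaks(heatmap, threshold=0.3, window=3):
    """Keep pixels that are both above threshold and the local maximum
    within a small window (the max-pool NMS trick from CenterNet)."""
    local_max = maximum_filter(heatmap, size=window, mode="constant")
    keep = (heatmap == local_max) & (heatmap >= threshold)
    ys, xs = np.nonzero(keep)
    scores = heatmap[ys, xs]
    order = np.argsort(-scores)  # highest confidence first
    return [(int(xs[i]), int(ys[i]), float(scores[i])) for i in order]

# Tiny demo: one clear peak at (x=2, y=1); its weaker neighbor is suppressed
hm = np.zeros((4, 4))
hm[1, 2] = 0.9
hm[1, 1] = 0.4  # falls inside the 3x3 window of the 0.9 peak
print(extract_peaks(hm))  # [(2, 1, 0.9)]
```

Because the suppression is a single `maximum_filter` comparison, this runs on the whole stride-2 heatmap at once instead of looping over candidate pixels.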

### Key Design Decisions

**CEM500K Backbone:** ResNet-50 pretrained on 500,000 electron microscopy images via self-supervised learning. The backbone already understands EM structures (membranes, vesicles, organelles) before seeing any gold particles. This is why the model reaches F1 = 0.93 in just 5 epochs.

**Stride-2 Output:** Standard CenterNet uses stride 4. At stride 4, a 6nm bead (4-6 px radius) collapses to 1 pixel — too small to detect reliably. At stride 2, the same bead occupies 2-3 pixels, enough for Gaussian peak detection.

**CornerNet Focal Loss:** With a positive:negative pixel ratio of 1:23,000, standard BCE would learn to predict all zeros. The focal loss uses `(1-p)^alpha` weighting to focus on hard examples and `(1-gt)^beta` penalty reduction near peaks.
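
The loss described above can be sketched in NumPy as below. The `alpha=2`, `beta=4` values are the standard CornerNet defaults, assumed here rather than taken from the project's `loss.py`:

```python
import numpy as np

def cornernet_focal_loss(pred, gt, alpha=2, beta=4, eps=1e-12):
    """CornerNet-style focal loss for Gaussian heatmaps.
    Positive pixels are exactly gt == 1.0; everywhere else the
    (1 - gt)^beta factor down-weights the penalty near peaks."""
    pos = gt == 1.0
    pos_loss = -((1 - pred[pos]) ** alpha) * np.log(pred[pos] + eps)
    neg_loss = -((1 - gt[~pos]) ** beta) * (pred[~pos] ** alpha) \
               * np.log(1 - pred[~pos] + eps)
    n_pos = max(pos.sum(), 1)  # normalize by the number of peaks
    return (pos_loss.sum() + neg_loss.sum()) / n_pos

# A confident hit at the peak pixel should cost far less than a miss
gt = np.array([[1.0, 0.6], [0.1, 0.0]])
good = np.array([[0.95, 0.5], [0.1, 0.01]])
bad = np.array([[0.05, 0.5], [0.1, 0.01]])
print(cornernet_focal_loss(good, gt) < cornernet_focal_loss(bad, gt))  # True
```

Note how the `pos = gt == 1.0` mask is exactly the mechanism that Bug 2 below broke: if no ground-truth pixel is exactly 1.0, `pos_loss` is empty and the model gets no positive signal.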

**Raw Image Input:** No preprocessing. The CEM500K backbone was trained on raw EM images; any heavy preprocessing (top-hat, CLAHE) creates a domain gap and hurts performance. The model learns to distinguish particles from vesicles through training data, not handcrafted filters.

---

## Training Strategy

### 3-Phase Training with Discriminative Learning Rates

| Phase | Epochs | What's Trainable | Learning Rate |
|-------|--------|------------------|---------------|
| **1. Warm-up** | 40 | BiFPN + heads only | 1e-3 |
| **2. Deep unfreeze** | 40 | + layer3 + layer4 | 1e-5 to 5e-4 |
| **3. Full fine-tune** | 60 | All layers | 1e-6 to 2e-4 |
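
Discriminative learning rates of this kind are usually implemented as per-parameter-group rates in the optimizer. The sketch below assigns the table's rates by parameter name; the name-matching rules and the intermediate phase-3 backbone rate are illustrative assumptions, not the project's `train.py`:

```python
def discriminative_lrs(param_names, phase):
    """Map each parameter name to a learning rate following the
    3-phase schedule in the table above (lr = 0.0 means frozen)."""
    groups = []
    for name in param_names:
        deep = name.startswith(("layer3", "layer4"))          # late backbone
        early = name.startswith(("conv1", "layer1", "layer2"))  # early backbone
        if phase == 1:       # warm-up: BiFPN + heads only
            lr = 0.0 if (deep or early) else 1e-3
        elif phase == 2:     # unfreeze layer3/layer4 at a tiny LR
            lr = 0.0 if early else (1e-5 if deep else 5e-4)
        else:                # full fine-tune, smallest LR at the input end
            lr = 1e-6 if early else (2e-5 if deep else 2e-4)
        groups.append((name, lr))
    return groups

names = ["conv1.weight", "layer3.0.conv1.weight", "bifpn.p3.weight"]
for name, lr in discriminative_lrs(names, phase=2):
    print(name, lr)
```

In PyTorch these `(name, lr)` pairs would become optimizer parameter groups, so a single optimizer step applies a different rate at each depth.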

```
Loss Curve (final model):

      Phase 1        Phase 2        Phase 3
         |              |              |
1.4 |\   |              |              |
    | \  |              |              |
1.0 |  \ |              |              |
    |   ----            |              |
0.8 |    |  \           |              |
    |    |   \          |              |
0.6 |    |    \--+---   |              |
    |    |       |   \  |              |
0.4 |    |       |    \---             |
    |    |       |      |  \-------+---
0.2 |    |       |      |          |
    +---+---+----+---+---+----+---+---+--> Epoch
    0  10  20   40  50  60   80  100  140
```

### Data Augmentation
- Random 90-degree rotations (EM is rotation-invariant)
- Horizontal/vertical flips
- Conservative brightness/contrast jitter (±8% — preserves the subtle particle signal)
- Gaussian noise (simulates shot noise)
- **Copy-paste augmentation**: real bead crops blended onto training patches
- **70% hard mining**: patches centered on particles, 30% sampled at random
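
The copy-paste augmentation in the list above can be sketched like this. The Gaussian alpha-blend and all names are illustrative assumptions, not the project's `dataset.py` (which may blend differently):

```python
import numpy as np

def paste_bead(patch, bead, y, x, sigma=2.0):
    """Blend a real bead crop onto a training patch at (y, x) using a
    Gaussian alpha mask, so the paste has no hard seam for the model
    to latch onto."""
    h, w = bead.shape
    yy, xx = np.mgrid[0:h, 0:w]
    cy, cx = (h - 1) / 2, (w - 1) / 2
    alpha = np.exp(-((yy - cy) ** 2 + (xx - cx) ** 2) / (2 * sigma ** 2))
    region = patch[y:y + h, x:x + w]
    patch[y:y + h, x:x + w] = alpha * bead + (1 - alpha) * region
    return patch

patch = np.full((64, 64), 0.5)   # uniform "tissue" background
bead = np.full((9, 9), 0.1)      # gold beads image darker than tissue
out = paste_bead(patch.copy(), bead, y=20, x=30)
assert out[24, 34] < 0.2              # bead center dominates the blend
assert abs(out[0, 0] - 0.5) < 1e-9    # pixels far from the paste untouched
```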

### Overfitting Prevention
- **Unique patches every epoch**: the RNG is reseeded per sample so the model never sees the same patch twice
- **Early stopping**: patience = 20 epochs, monitoring validation F1
- **Weight decay**: 1e-4 on all parameters

---

## Critical Bugs Found and Fixed

### Bug 1: Coordinate Conversion
**Problem:** CSV files labeled "XY in microns" were assumed to contain normalized [0,1] coordinates. They were actual micron values.

**Effect:** All particle annotations were offset by 50-80 pixels from their real locations. The model was learning to detect particles where none existed.

**Fix:** Multiply by 1790 px/micron (verified against the researcher's color-overlay TIFs across 7 synapses).
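
The fix is a one-line conversion; a minimal sketch using the 1790 px/micron calibration stated above (the function name is illustrative):

```python
PX_PER_MICRON = 1790  # calibrated against the color-overlay TIFs

def microns_to_pixels(x_um, y_um, px_per_micron=PX_PER_MICRON):
    """Convert annotation coordinates from microns to pixel indices.
    The buggy version instead scaled by image width/height, treating
    the micron values as normalized [0, 1] coordinates."""
    return round(x_um * px_per_micron), round(y_um * px_per_micron)

# A particle annotated at (0.5 um, 0.25 um):
print(microns_to_pixels(0.5, 0.25))  # (895, 448)
```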

### Bug 2: Heatmap Peak Values
**Problem:** Gaussian peaks were centered at float coordinates, producing peak values of 0.78-0.93 instead of exactly 1.0.

**Effect:** The CornerNet focal loss uses `pos_mask = (gt == 1.0)` to identify positive pixels. With no pixels at exactly 1.0, the model had **zero positive training signal**. It literally could not learn.

**Fix:** Center each Gaussian at the nearest integer grid point (which always produces a peak of exactly 1.0). Sub-pixel precision is handled by the offset regression head.
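
A minimal sketch of the fixed ground-truth generator, illustrating both halves of the fix (integer-centered peak plus the fractional remainder as the offset head's target); names and `sigma` are assumptions, not `heatmap.py`:

```python
import numpy as np

def draw_gaussian(heatmap, x, y, sigma=1.5):
    """Stamp a Gaussian at the nearest integer grid point, so the peak
    pixel is exactly 1.0 and `gt == 1.0` finds it. Returns the sub-pixel
    remainder, which becomes the offset regression target."""
    cx, cy = int(round(x)), int(round(y))
    h, w = heatmap.shape
    yy, xx = np.mgrid[0:h, 0:w]
    g = np.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2 * sigma ** 2))
    np.maximum(heatmap, g, out=heatmap)  # overlapping peaks: keep the max
    return (x - cx, y - cy)

hm = np.zeros((8, 8))
offset = draw_gaussian(hm, x=3.4, y=5.7)
assert hm.max() == 1.0     # the focal loss now sees a positive pixel
assert hm[6, 3] == 1.0     # peak at the rounded grid point (y=6, x=3)
print(offset)              # approximately (0.4, -0.3)
```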

### Bug 3: Overfitting on Fixed Patches
**Problem:** The dataset generated 200 random patches once at initialization, and every epoch replayed the same patches.

**Effect:** On fast CUDA GPUs, the model memorized all patches within ~17 epochs (the loss crashed from 1.6 to 0.002). Validation F1 peaked at 0.66 and then degraded.

**Fix:** Reseed the RNG on every `__getitem__` call so every patch is unique.
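
One way to implement that reseeding, sketched below: a fresh generator is built per `__getitem__` call from an (epoch, index) pair, so the same index yields a different crop every epoch while staying reproducible. The class and attribute names are illustrative, not the project's `dataset.py`:

```python
import numpy as np

class PatchDataset:
    """Minimal sketch of the per-sample reseeding fix."""
    def __init__(self, image, patch=4):
        self.image, self.patch, self.epoch = image, patch, 0

    def __getitem__(self, i):
        # Fresh generator per call: unique per epoch, reproducible per (epoch, i)
        rng = np.random.default_rng((self.epoch, i))
        h, w = self.image.shape
        y = rng.integers(0, h - self.patch)
        x = rng.integers(0, w - self.patch)
        return self.image[y:y + self.patch, x:x + self.patch]

img = np.arange(64 * 64, dtype=float).reshape(64, 64)
ds = PatchDataset(img)
print(ds[0].shape)  # (4, 4)
```

Incrementing `ds.epoch` in the training loop is what breaks the replay: the buggy version effectively used a fixed seed forever.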

---

## Results

### Leave-One-Image-Out Cross-Validation (10 folds, 5 seeds each)

| Fold | Avg F1 | Best F1 | Notes |
|------|--------|---------|-------|
| S27 | **0.990** | 0.994 | |
| S8 | **0.981** | 0.988 | |
| S25 | **0.972** | 0.977 | |
| S29 | **0.956** | 0.966 | |
| S1 | **0.930** | 0.940 | |
| S4 | **0.919** | 0.972 | |
| S22 | **0.907** | 0.938 | |
| S13 | **0.890** | 0.912 | |
| S7 | 0.799 | 1.000 | Only 3 particles (noisy metric) |
| S15 | 0.633 | 0.667 | Only 1 particle (noisy metric) |

**Mean F1 = 0.943** (8 folds with sufficient annotations)

### Per-class Performance (S1 fold, best threshold)

| Class | Precision | Recall | F1 |
|-------|-----------|--------|-----|
| 6nm (AMPA) | 0.895 | **1.000** | **0.944** |
| 12nm (NR1) | 0.833 | **1.000** | **0.909** |

**100% recall** on both classes — every particle is found. The only errors are a few false positives.
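
These precision/recall numbers come from matching predictions to ground truth one-to-one; `evaluate.py` uses Hungarian matching, which can be sketched with SciPy's `linear_sum_assignment` (the 20 px match radius is an assumption, not the project's value):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def match_f1(preds, gts, max_dist=20.0):
    """F1 from a one-to-one Hungarian matching of predicted to
    ground-truth particle centers; pairs farther apart than max_dist
    do not count as true positives."""
    if not preds or not gts:
        return 0.0
    cost = np.linalg.norm(np.array(preds)[:, None] - np.array(gts)[None], axis=2)
    rows, cols = linear_sum_assignment(cost)  # minimize total distance
    tp = int(np.sum(cost[rows, cols] <= max_dist))
    precision = tp / len(preds)
    recall = tp / len(gts)
    return 2 * precision * recall / (precision + recall + 1e-12)

gts = [(100, 100), (200, 50)]
preds = [(102, 99), (400, 400)]        # one hit, one false positive
print(round(match_f1(preds, gts), 3))  # 0.5: precision 0.5, recall 0.5
```

Hungarian matching matters here because greedy nearest-neighbor matching can double-count a ground-truth particle when detections are dense.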

### Generalization to Unseen Images

Tested on 15 completely unseen images from a different imaging session. Detections land correctly on particles with no manual tuning. The model successfully detects both 6nm and 12nm particles on:
- Wild-type (Wt2) samples
- Heterozygous (Het1) samples
- Different synapse regions (D1, E3, S1, S10, S12, S18)

---

## System Components

```
MidasMap/
  config/config.yaml     # All hyperparameters
  src/
    preprocessing.py     # Data loading (10 synapses, 453 particles)
    model.py             # CenterNet: ResNet-50 + BiFPN + heads (24.4M params)
    loss.py              # CornerNet focal loss + offset regression
    heatmap.py           # GT generation + peak extraction + NMS
    dataset.py           # Patch sampling, augmentation, copy-paste
    postprocess.py       # Mask filter, cross-class NMS
    ensemble.py          # D4 TTA + sliding window inference
    evaluate.py          # Hungarian matching, F1/precision/recall
    visualize.py         # Overlay visualizations
  train.py               # LOOCV training (--fold, --seed)
  train_final.py         # Final deployable model (all data)
  predict.py             # Inference on new images
  evaluate_loocv.py      # Full evaluation runner
  app.py                 # Gradio web dashboard
  slurm/                 # HPC job scripts
  tests/                 # 36 unit tests
```

---

## Dashboard

MidasMap includes a web-based dashboard (Gradio) for interactive use:

1. **Upload** any TEM image (.tif)
2. **Adjust** the confidence threshold and NMS parameters
3. **View** detections overlaid on the image
4. **Inspect** per-class heatmaps
5. **Analyze** confidence distributions and spatial patterns
6. **Export** results as CSV (particle_id, x_px, y_px, x_um, y_um, class, confidence)

```
python app.py --checkpoint checkpoints/final/final_model.pth
# Opens at http://localhost:7860
```

---

## Future Directions

1. **Spatial analytics**: distance to the synaptic cleft, nearest-neighbor analysis, Ripley's K-function
2. **Size regression head**: predict the actual bead diameter instead of a binary classification
3. **18nm detection**: extend to vGlut2 particles (3-class model)
4. **Active learning**: flag low-confidence detections for human review
5. **Cross-protocol generalization**: fine-tune on cryo-EM or different staining protocols

---

## Technical Summary

- **Model**: CenterNet with a CEM500K-pretrained ResNet-50, BiFPN neck, stride-2 output
- **Training**: 3-phase with discriminative LRs, 140 epochs, 453 particles / 10 images
- **Evaluation**: Leave-one-image-out CV, Hungarian matching, F1 = 0.943
- **Inference**: Sliding window (512x512, 128px overlap), ~10s per image on GPU
- **Output**: Per-particle (x, y, class, confidence) with optional heatmap visualization
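
The sliding-window tiling from the inference bullet can be sketched as below: 512x512 patches with 128 px overlap, with the final tile on each axis clamped to the image edge so no border pixels are dropped. This is an illustration, not the project's `ensemble.py`:

```python
def tile_coords(height, width, patch=512, overlap=128):
    """Top-left (y, x) corners for sliding-window inference. Assumes the
    image is at least patch-sized on both axes."""
    stride = patch - overlap
    def axis(n):
        coords = list(range(0, n - patch + 1, stride))
        if coords[-1] + patch < n:   # add a clamped tile at the far edge
            coords.append(n - patch)
        return coords
    return [(y, x) for y in axis(height) for x in axis(width)]

# A 2048x2115 TEM image (the size quoted in "The Challenge"):
tiles = tile_coords(2048, 2115)
print(len(tiles))  # 30
```

Overlapping tiles mean a particle near a patch border is seen whole in at least one patch; duplicate detections in the overlap band are then merged by the cross-class NMS stage.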