# MidasMap: Automated Immunogold Particle Detection for TEM Synapse Images

---

## The Problem

Neuroscientists use **immunogold labeling** to visualize receptor proteins at synapses in transmission electron microscopy (TEM) images.

- **6nm gold beads** label AMPA receptors (panAMPA)
- **12nm gold beads** label NR1 (NMDA) receptors
- **18nm gold beads** label vGlut2 (vesicular glutamate transporter)

**Manual counting is slow and subjective.** Each image takes 30-60 minutes to annotate. With hundreds of synapses per experiment, this becomes a bottleneck.

### The Challenge
- Particles are **tiny** (4-10 pixel radius) on 2048x2115 images
- Particle-to-background contrast is only **11-39 intensity units** on a 0-255 scale
- Large dark vesicles look similar to gold particles to naive detectors
- Only **453 labeled particles** across 10 training images

---

## Previous Approaches (GoldDigger et al.)

| Approach | Result |
|----------|--------|
| CenterNet (initial attempt) | "Detection quality remained poor" |
| U-Net heatmap | Macro F1 = 0.005-0.017 |
| GoldDigger/cGAN | "No durable breakthrough" |
| Aggressive filtering | "FP dropped but TP dropped harder" |

**Core issues:** Previous attempts failed for three reasons:
1. Incorrect coordinate conversion (micron values treated as normalized coordinates)
2. Broken loss target (heatmap peaks not exactly 1.0, so the focal loss saw no positive pixels)
3. Overfitting to a fixed set of training patches

---

## MidasMap Architecture

```
Input: Raw TEM Image (any size)
         |
    [Sliding Window → 512x512 patches]
         |
    ResNet-50 Encoder (pretrained on CEM500K: 500K EM images)
         |
    BiFPN Neck (bidirectional feature pyramid, 2 rounds, 128ch)
         |
    Transposed Conv Decoder → stride-2 output
         |
    +------------------+
    |                  |
    Heatmap Head       Offset Head
    (2ch sigmoid)      (2ch regression)
    6nm channel        sub-pixel x,y
    12nm channel       correction
    |                  |
    +------------------+
         |
    Peak Extraction (max-pool NMS)
         |
    Cross-class NMS + Mask Filter
         |
    Output: [(x, y, class, confidence), ...]
```
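
The peak-extraction step in the diagram can be sketched in plain Python as follows. This is a minimal illustration, not the project's `heatmap.py`: the function names, the 3x3 window, and the 0.3 threshold are assumptions, and the offsets are assumed to be expressed in output-grid units. In PyTorch the same test is commonly written by comparing the heatmap with `torch.nn.functional.max_pool2d(heatmap, 3, stride=1, padding=1)`.

```python
def extract_peaks(heatmap, threshold=0.3, kernel=3):
    """Return (row, col, score) for every local maximum above `threshold`.

    A pixel is kept only if it equals the maximum of its kernel x kernel
    neighbourhood, which is what max-pool NMS computes.
    """
    h, w = len(heatmap), len(heatmap[0])
    r = kernel // 2
    peaks = []
    for y in range(h):
        for x in range(w):
            v = heatmap[y][x]
            if v < threshold:
                continue
            window_max = max(
                heatmap[yy][xx]
                for yy in range(max(0, y - r), min(h, y + r + 1))
                for xx in range(max(0, x - r), min(w, x + r + 1))
            )
            if v == window_max:
                peaks.append((y, x, v))
    return peaks


def to_image_coords(row, col, offset, stride=2):
    """Map an output-grid peak plus its (dx, dy) offset back to input pixels."""
    dx, dy = offset
    return (col + dx) * stride, (row + dy) * stride


# Toy single-class heatmap with two separated peaks.
toy = [
    [0, 0,   0,   0,   0],
    [0, 0.9, 0.5, 0,   0],
    [0, 0.4, 0.2, 0,   0],
    [0, 0,   0,   0.6, 0],
    [0, 0,   0,   0,   0],
]
peaks = extract_peaks(toy)
```

The neighbours of the 0.9 peak (0.5 and 0.4) pass the threshold but are suppressed because a larger value sits inside their window.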

### Key Design Decisions

**CEM500K Backbone:** ResNet-50 pretrained on 500,000 electron microscopy images via self-supervised learning. The backbone already understands EM structures (membranes, vesicles, organelles) before seeing any gold particles. This likely explains why the model reaches F1=0.93 in just 5 epochs.

**Stride-2 Output:** Standard CenterNet uses stride 4. At stride 4, a 6nm bead (4-6px radius) shrinks to a radius of roughly 1-1.5 output pixels, too small to detect reliably. At stride 2, the same bead has a radius of 2-3 output pixels, enough for Gaussian peak detection.
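
The arithmetic behind the stride choice is short enough to check directly. The helper name below is hypothetical; the radii and strides are the ones quoted above.

```python
def radius_on_output_grid(radius_px, stride):
    """Bead radius measured in output-map pixels for a given output stride."""
    return radius_px / stride


for stride in (4, 2):
    small = radius_on_output_grid(4, stride)
    large = radius_on_output_grid(6, stride)
    print(f"stride {stride}: radius {small:.1f}-{large:.1f} output pixels")

# A 512x512 input patch yields a 256x256 output map at stride 2.
output_size = 512 // 2
```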

**CornerNet Focal Loss:** With a positive:negative pixel ratio of 1:23,000, standard BCE would learn to predict all zeros. The focal loss weights positive pixels by `(1-p)^alpha` to focus on hard examples, and scales the penalty on negative pixels by `(1-gt)^beta` so that it shrinks near peaks.
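
A minimal pure-Python sketch of this loss is shown below, assuming the defaults from the CornerNet paper (`alpha=2`, `beta=4`) and normalization by the number of positive pixels. It is an illustration of the formula rather than the project's `loss.py`; the function name and the clamping constant `eps` are assumptions.

```python
import math


def cornernet_focal_loss(pred, gt, alpha=2.0, beta=4.0, eps=1e-6):
    """Penalty-reduced focal loss over one 2D heatmap channel.

    pred, gt: nested lists of equal shape with values in [0, 1].
    Positive pixels are exactly those where gt == 1.0.
    """
    pos_loss = 0.0
    neg_loss = 0.0
    num_pos = 0
    for p_row, g_row in zip(pred, gt):
        for p, g in zip(p_row, g_row):
            p = min(max(p, eps), 1.0 - eps)  # keep log() finite
            if g == 1.0:
                # Hard positives (small p) get the largest weight.
                pos_loss += (1.0 - p) ** alpha * math.log(p)
                num_pos += 1
            else:
                # (1 - g)^beta shrinks the penalty close to a peak.
                neg_loss += (1.0 - g) ** beta * p ** alpha * math.log(1.0 - p)
    return -(pos_loss + neg_loss) / max(num_pos, 1)


gt = [[0.0, 0.5, 0.0],
      [0.5, 1.0, 0.5],
      [0.0, 0.5, 0.0]]
confident = [[0.01, 0.01, 0.01],
             [0.01, 0.99, 0.01],
             [0.01, 0.01, 0.01]]
all_background = [[0.01] * 3 for _ in range(3)]

loss_good = cornernet_focal_loss(confident, gt)
loss_bad = cornernet_focal_loss(all_background, gt)
```

Predicting background everywhere is penalized heavily here because the single positive pixel dominates the sum, which is the behaviour plain BCE lacks at this class imbalance.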

**Raw Image Input:** No preprocessing. The CEM500K backbone was trained on raw EM images. Any heavy preprocessing (top-hat, CLAHE) creates a domain gap and hurts performance. The model learns to distinguish particles from vesicles through training data, not handcrafted filters.

---

## Training Strategy

### 3-Phase Training with Discriminative Learning Rates

| Phase | Epochs | What's Trainable | Learning Rate |
|-------|--------|-------------------|---------------|
| **1. Warm-up** | 40 | BiFPN + heads only | 1e-3 |
| **2. Deep unfreeze** | 40 | + layer3 + layer4 | 1e-5 to 5e-4 |
| **3. Full fine-tune** | 60 | All layers | 1e-6 to 2e-4 |

```
Loss Curve (final model):

Phase 1          Phase 2          Phase 3
|                |                |
1.4 |\           |                |
    | \          |                |
1.0 |  \         |                |
    |   ----     |                |
0.8 |       \    |                |
    |        \   |                |
0.6 |         \--+---             |
    |            |   \            |
0.4 |            |    \---        |
    |            |        \-------+---
0.2 |            |                |
    +---+---+----+---+---+----+---+---+--> Epoch
    0   10  20   40  50  60   80  100 140
```
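
The three-phase schedule in the table above can be written down as data plus a lookup. This is a sketch only: the group names (`bifpn`, `heads`, `layer3`, `layer4`, `all`) and the way the learning-rate range is stored are assumptions, and how rates are distributed across layers within each range is not specified here.

```python
# (phase name, number of epochs, trainable groups, (lowest LR, highest LR))
PHASES = [
    ("warm-up", 40, ("bifpn", "heads"), (1e-3, 1e-3)),
    ("deep unfreeze", 40, ("bifpn", "heads", "layer3", "layer4"), (1e-5, 5e-4)),
    ("full fine-tune", 60, ("all",), (1e-6, 2e-4)),
]

TOTAL_EPOCHS = sum(n for _, n, _, _ in PHASES)


def phase_for_epoch(epoch):
    """Return the phase tuple that a 0-based epoch index falls into."""
    start = 0
    for phase in PHASES:
        num_epochs = phase[1]
        if start <= epoch < start + num_epochs:
            return phase
        start += num_epochs
    raise ValueError(f"epoch {epoch} is outside the {TOTAL_EPOCHS}-epoch schedule")


name, _, trainable, lr_range = phase_for_epoch(45)
```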

### Data Augmentation
- Random 90-degree rotations (EM is rotation-invariant)
- Horizontal/vertical flips
- Conservative brightness/contrast jitter (±8%, to preserve the subtle particle signal)
- Gaussian noise (simulates shot noise)
- **Copy-paste augmentation**: real bead crops blended onto training patches
- **70% hard mining**: 70% of patches are centered on particles; the remaining 30% are sampled at random
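
The 70/30 sampling rule can be sketched as a function that picks the top-left corner of each training patch. This is an illustration under stated assumptions, not the project's `dataset.py`: the function name, the clamping at image borders, and the absence of any jitter around the particle are all hypothetical choices.

```python
import random


def sample_patch_origin(particles, image_w, image_h, patch=512,
                        p_particle=0.7, rng=random):
    """Return (x0, y0), the top-left corner of one training patch.

    particles: list of (x, y) particle centres in pixels.
    With probability p_particle the patch is centred on a random particle,
    otherwise it is placed uniformly at random.
    """
    max_x, max_y = image_w - patch, image_h - patch
    if particles and rng.random() < p_particle:
        cx, cy = rng.choice(particles)
        x0 = int(cx) - patch // 2
        y0 = int(cy) - patch // 2
    else:
        x0 = rng.randint(0, max_x)
        y0 = rng.randint(0, max_y)
    # Clamp so the patch always lies fully inside the image.
    return min(max(x0, 0), max_x), min(max(y0, 0), max_y)


rng = random.Random(0)
origins = [sample_patch_origin([(1000, 1000)], 2048, 2115, rng=rng)
           for _ in range(1000)]
centred_fraction = sum(o == (744, 744) for o in origins) / len(origins)
```

With a single particle at (1000, 1000), roughly 70% of the sampled origins land on (744, 744), the corner that centres the patch on it.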

### Overfitting Prevention
- **Unique patches every epoch**: the RNG is reseeded per sample, so each epoch draws fresh patches instead of replaying a fixed set
- **Early stopping**: patience=20 epochs, monitoring validation F1
- **Weight decay**: 1e-4 on all parameters
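
Early stopping on validation F1 with a patience counter can be sketched as below. The class name and interface are hypothetical; only the patience of 20 epochs and the monitored metric come from the list above.

```python
class EarlyStopping:
    """Stop when validation F1 has not improved for `patience` epochs."""

    def __init__(self, patience=20):
        self.patience = patience
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_f1):
        """Record one epoch's validation F1; return True to stop training."""
        if val_f1 > self.best:
            self.best = val_f1
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Short demonstration with patience=3 instead of 20.
stopper = EarlyStopping(patience=3)
decisions = [stopper.step(f1) for f1 in (0.50, 0.60, 0.55, 0.58, 0.59)]
```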

---

## Critical Bugs Found and Fixed

### Bug 1: Coordinate Conversion
**Problem:** CSV files labeled "XY in microns" were assumed to be normalized [0,1] coordinates. They were actual micron values.

**Effect:** All particle annotations were offset by 50-80 pixels from the real locations. The model was learning to detect particles where none existed.

**Fix:** Multiply the micron coordinates by 1790 px/micron (verified against the researcher's color overlay TIFs across 7 synapses).
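
In code, the fix amounts to a single scale factor. The sketch below uses the 1790 px/micron value stated above; the function and constant names are hypothetical.

```python
PX_PER_MICRON = 1790.0  # scale factor quoted in the fix above


def microns_to_pixels(x_um, y_um, px_per_micron=PX_PER_MICRON):
    """Convert annotation coordinates from microns to image pixels."""
    return x_um * px_per_micron, y_um * px_per_micron


def pixels_to_microns(x_px, y_px, px_per_micron=PX_PER_MICRON):
    """Inverse conversion, useful when exporting detections."""
    return x_px / px_per_micron, y_px / px_per_micron


x_px, y_px = microns_to_pixels(0.5, 1.0)
```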

### Bug 2: Heatmap Peak Values
**Problem:** Gaussian peaks were centered at float coordinates, producing peak values of 0.78-0.93 instead of exactly 1.0.

**Effect:** The CornerNet focal loss uses `pos_mask = (gt == 1.0)` to identify positive pixels. With no pixels at exactly 1.0, the model had **zero positive training signal**. It literally could not learn.

**Fix:** Center each Gaussian on an integer grid point, so the peak value is always exactly 1.0. Sub-pixel precision is handled by the offset regression head.
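
The difference between the two centerings is easy to reproduce. The sketch below is an assumption-laden illustration rather than the project's `heatmap.py`: it uses a small 16x16 grid, `sigma=1.0`, and floors the coordinate to pick the integer grid point (the convention in CenterNet), none of which is confirmed here.

```python
import math


def draw_gaussian(size, cx, cy, sigma):
    """Sample an unnormalized 2D Gaussian centred at (cx, cy) on a size x size grid."""
    return [
        [math.exp(-((x - cx) ** 2 + (y - cy) ** 2) / (2 * sigma ** 2))
         for x in range(size)]
        for y in range(size)
    ]


def heatmap_target(x_px, y_px, stride=2, size=16, sigma=1.0):
    """Build the heatmap and offset target for one particle.

    Returns (heatmap, (ix, iy), (dx, dy)): the Gaussian is centred on the
    integer grid point and the fractional remainder goes to the offset head.
    """
    fx, fy = x_px / stride, y_px / stride
    ix, iy = int(fx), int(fy)  # floor, valid for non-negative coordinates
    return draw_gaussian(size, ix, iy, sigma), (ix, iy), (fx - ix, fy - iy)


# Buggy variant: centre at the float location, so no grid point equals 1.0.
float_centred = draw_gaussian(16, 5.4, 7.3, 1.0)
float_peak = max(max(row) for row in float_centred)

# Fixed variant: the peak pixel is exactly 1.0 and the offset keeps the remainder.
heat, (ix, iy), (dx, dy) = heatmap_target(10.8, 14.6)
fixed_peak = heat[iy][ix]
```

With the float-centred variant, a mask such as `gt == 1.0` selects nothing, which is the failure described in Bug 2.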

### Bug 3: Overfitting on Fixed Patches
**Problem:** The dataset generated 200 random patches once at initialization. Every epoch replayed the same patches.

**Effect:** On fast CUDA GPUs, the model memorized all patches in ~17 epochs (loss crashed from 1.6 to 0.002). Validation F1 peaked at 0.66 and degraded.

**Fix:** Reseed RNG per `__getitem__` call so every patch is unique.
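
One way to get fresh patches on every call is sketched below. It is not the project's `dataset.py`: the class is a plain-Python stand-in for a PyTorch-style dataset, it returns only patch origins, and keying the seed on `(base_seed, epoch, index)` via a hypothetical `set_epoch` method is an assumption chosen so that runs stay reproducible.

```python
import random


class RandomPatchDataset:
    """Draws a new random patch origin on every __getitem__ call."""

    def __init__(self, image_w, image_h, patch=512, length=200, base_seed=0):
        self.max_x = image_w - patch
        self.max_y = image_h - patch
        self.length = length
        self.base_seed = base_seed
        self.epoch = 0

    def set_epoch(self, epoch):
        """Call once per epoch so the same index yields a different patch."""
        self.epoch = epoch

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # A fresh generator per call: nothing is cached at initialization.
        rng = random.Random(f"{self.base_seed}-{self.epoch}-{index}")
        return rng.randint(0, self.max_x), rng.randint(0, self.max_y)


ds = RandomPatchDataset(2048, 2115)
epoch0 = [ds[i] for i in range(10)]
ds.set_epoch(1)
epoch1 = [ds[i] for i in range(10)]
```

The buggy version corresponds to computing all 200 origins once in `__init__` and returning the stored value from `__getitem__`.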

---

## Results

### Leave-One-Image-Out Cross-Validation (10 folds, 5 seeds each)

| Fold | Avg F1 | Best F1 | Notes |
|------|--------|---------|-------|
| S27 | **0.990** | 0.994 | |
| S8 | **0.981** | 0.988 | |
| S25 | **0.972** | 0.977 | |
| S29 | **0.956** | 0.966 | |
| S1 | **0.930** | 0.940 | |
| S4 | **0.919** | 0.972 | |
| S22 | **0.907** | 0.938 | |
| S13 | **0.890** | 0.912 | |
| S7 | 0.799 | 1.000 | Only 3 particles (noisy metric) |
| S15 | 0.633 | 0.667 | Only 1 particle (noisy metric) |

**Mean F1 = 0.943** over the 8 folds with sufficient annotations (S7 and S15 are excluded because they contain only 3 and 1 particles)

### Per-class Performance (S1 fold, best threshold)

| Class | Precision | Recall | F1 |
|-------|-----------|--------|----|
| 6nm (AMPA) | 0.895 | **1.000** | **0.944** |
| 12nm (NR1) | 0.833 | **1.000** | **0.909** |

**100% recall** on both classes for this fold: every annotated particle is found, and the only errors are a few false positives.
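
The reported figures can be cross-checked with the standard F1 definition. The sketch below recomputes the per-class F1 from the rounded precision and recall in the table, and the mean over the eight well-annotated folds. The helper name is hypothetical, and small differences in the third decimal are expected because the inputs are already rounded.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


f1_6nm = f1_score(0.895, 1.000)   # table reports 0.944
f1_12nm = f1_score(0.833, 1.000)  # table reports 0.909

# Average F1 of the eight folds with sufficient annotations (S7, S15 excluded).
fold_f1 = [0.990, 0.981, 0.972, 0.956, 0.930, 0.919, 0.907, 0.890]
mean_f1 = sum(fold_f1) / len(fold_f1)
```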

### Generalization to Unseen Images

Tested on 15 completely unseen images from a different imaging session. Detections land correctly on particles with no manual tuning. The model successfully detects both 6nm and 12nm particles on:
- Wild-type (Wt2) samples
- Heterozygous (Het1) samples
- Different synapse regions (D1, E3, S1, S10, S12, S18)

---

## System Components

```
MidasMap/
  config/config.yaml        # All hyperparameters
  src/
    preprocessing.py        # Data loading (10 synapses, 453 particles)
    model.py                # CenterNet: ResNet-50 + BiFPN + heads (24.4M params)
    loss.py                 # CornerNet focal loss + offset regression
    heatmap.py              # GT generation + peak extraction + NMS
    dataset.py              # Patch sampling, augmentation, copy-paste
    postprocess.py          # Mask filter, cross-class NMS
    ensemble.py             # D4 TTA + sliding window inference
    evaluate.py             # Hungarian matching, F1/precision/recall
    visualize.py            # Overlay visualizations
  train.py                  # LOOCV training (--fold, --seed)
  train_final.py            # Final deployable model (all data)
  predict.py                # Inference on new images
  evaluate_loocv.py         # Full evaluation runner
  app.py                    # Gradio web dashboard
  slurm/                    # HPC job scripts
  tests/                    # 36 unit tests
```

---

## Dashboard

MidasMap includes a web-based dashboard (Gradio) for interactive use:

1. **Upload** any TEM image (.tif)
2. **Adjust** confidence threshold and NMS parameters
3. **View** detections overlaid on the image
4. **Inspect** per-class heatmaps
5. **Analyze** confidence distributions and spatial patterns
6. **Export** results as CSV (particle_id, x_px, y_px, x_um, y_um, class, confidence)

```
python app.py --checkpoint checkpoints/final/final_model.pth
# Opens at http://localhost:7860
```
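
For readers who want the same CSV layout outside the dashboard, the export step can be approximated with the standard library. This is a sketch, not the code in `app.py`: the function name, the 1-based `particle_id`, the class label strings, and the four-decimal rounding of micron values are assumptions. The column names and the 1790 px/micron scale come from this document.

```python
import csv
import io

PX_PER_MICRON = 1790.0
FIELDS = ["particle_id", "x_px", "y_px", "x_um", "y_um", "class", "confidence"]


def detections_to_csv(detections, px_per_micron=PX_PER_MICRON):
    """Serialize [(x_px, y_px, class_label, confidence), ...] to CSV text."""
    buf = io.StringIO()
    writer = csv.DictWriter(buf, fieldnames=FIELDS, lineterminator="\n")
    writer.writeheader()
    for particle_id, (x, y, label, conf) in enumerate(detections, start=1):
        writer.writerow({
            "particle_id": particle_id,
            "x_px": x,
            "y_px": y,
            "x_um": round(x / px_per_micron, 4),
            "y_um": round(y / px_per_micron, 4),
            "class": label,
            "confidence": conf,
        })
    return buf.getvalue()


csv_text = detections_to_csv([(895.0, 1790.0, "6nm", 0.93)])
```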

---

## Future Directions

1. **Spatial analytics**: distance to synaptic cleft, nearest-neighbor analysis, Ripley's K-function
2. **Size regression head**: predict actual bead diameter instead of binary classification
3. **18nm detection**: extend to vGlut2 particles (3-class model)
4. **Active learning**: flag low-confidence detections for human review
5. **Cross-protocol generalization**: fine-tune on cryo-EM or different staining protocols

---

## Technical Summary

- **Model**: CenterNet with CEM500K-pretrained ResNet-50, BiFPN neck, stride-2 output
- **Training**: 3-phase with discriminative LRs, 140 epochs, 453 particles / 10 images
- **Evaluation**: Leave-one-image-out CV, Hungarian matching, mean F1 = 0.943 (8 folds with sufficient annotations)
- **Inference**: Sliding window (512x512, 128px overlap), ~10s per image on GPU
- **Output**: Per-particle (x, y, class, confidence) with optional heatmap visualization
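
The sliding-window geometry quoted above (512x512 patches, 128px overlap) implies a step of 384px. A minimal sketch of the tile placement along one axis follows. The function name and the choice to add a final tile flush with the image border are assumptions, not a description of `ensemble.py`.

```python
def tile_origins(length, patch=512, overlap=128):
    """Start positions of tiles that cover `length` pixels along one axis."""
    if length <= patch:
        return [0]  # a single tile; smaller images would need padding
    step = patch - overlap
    origins = list(range(0, length - patch + 1, step))
    if origins[-1] != length - patch:
        # Add a last tile flush with the border so no pixels are missed.
        origins.append(length - patch)
    return origins


xs = tile_origins(2048)
ys = tile_origins(2115)
num_tiles = len(xs) * len(ys)
```

For a 2048x2115 image this gives 5 positions along one axis and 6 along the other, so 30 patches per image under these assumptions.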