AnikS22 commited on
Commit
d8e0547
·
verified ·
1 Parent(s): cf82a19

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +191 -0
README.md ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - immunogold
5
+ - particle-detection
6
+ - electron-microscopy
7
+ - TEM
8
+ - neuroscience
9
+ - CenterNet
10
+ - CEM500K
11
+ - synapse
12
+ datasets:
13
+ - custom
14
+ metrics:
15
+ - f1
16
+ model-index:
17
+ - name: MidasMap
18
+ results:
19
+ - task:
20
+ type: object-detection
21
+ name: Immunogold Particle Detection
22
+ metrics:
23
+ - type: f1
24
+ value: 0.943
25
+ name: LOOCV Mean F1 (8 annotated folds)
26
+ ---
27
+
28
+ # MidasMap: Immunogold Particle Detection for TEM Synapse Images
29
+
30
+ MidasMap automatically detects **6nm** (AMPA receptor) and **12nm** (NR1/NMDA receptor) immunogold particles in freeze-fracture replica immunolabeling (FFRIL) transmission electron microscopy images.
31
+
32
+ ## Performance
33
+
34
+ | Metric | Value |
35
+ |--------|-------|
36
+ | **LOOCV Mean F1** | **0.943** (8 folds with sufficient annotations) |
37
+ | 6nm (AMPA) F1 | 0.944 (100% recall) |
38
+ | 12nm (NR1) F1 | 0.909 (100% recall) |
39
+ | Parameters | 24.4M |
40
+ | Inference | ~10s per image (GPU) |
41
+
42
+ Validated on 453 labeled particles across 10 synapse images via leave-one-image-out cross-validation with 5 random seeds per fold.
43
+
44
+ ## Quick Start
45
+
46
+ ```python
47
+ import torch
48
+ from src.model import ImmunogoldCenterNet
49
+ from src.ensemble import sliding_window_inference
50
+ from src.heatmap import extract_peaks
51
+ from src.postprocess import cross_class_nms
52
+ import tifffile
53
+
54
+ # Load model
55
+ model = ImmunogoldCenterNet(bifpn_channels=128, bifpn_rounds=2)
56
+ ckpt = torch.load("checkpoints/final/final_model.pth", map_location="cpu")
57
+ model.load_state_dict(ckpt["model_state_dict"])
58
+ model.eval()
59
+
60
+ # Run on any TEM image
61
+ img = tifffile.imread("your_image.tif")
62
+ if img.ndim == 3:
63
+ img = img[:, :, 0]
64
+
65
+ with torch.no_grad():
66
+ hm, off = sliding_window_inference(model, img, patch_size=512, overlap=128)
67
+
68
+ dets = extract_peaks(torch.from_numpy(hm), torch.from_numpy(off),
69
+ stride=2, conf_threshold=0.25)
70
+ dets = cross_class_nms(dets, 8)
71
+
72
+ for d in dets:
73
+ print(f"{d['class']} at ({d['x']:.1f}, {d['y']:.1f}) conf={d['conf']:.3f}")
74
+ ```
75
+
76
+ ## Web Dashboard
77
+
78
+ ```bash
79
+ pip install gradio
80
+ python app.py --checkpoint checkpoints/final/final_model.pth
81
+ # Opens at http://localhost:7860
82
+ ```
83
+
84
+ Upload TIF images, adjust confidence threshold, view heatmaps, and export CSV results.
85
+
86
+ ## Architecture
87
+
88
+ ```
89
+ Raw TEM Image (any size)
90
+ |
91
+ [Sliding window: 512x512, 128px overlap]
92
+ |
93
+ ResNet-50 (CEM500K pretrained on 500K EM images)
94
+ |
95
+ BiFPN (bidirectional feature pyramid, 2 rounds, 128ch)
96
+ |
97
+ Transposed Conv → stride-2 output (H/2 x W/2)
98
+ |
99
+ +--Heatmap Head (2ch sigmoid: 6nm + 12nm)
100
+ +--Offset Head (2ch: sub-pixel x,y correction)
101
+ |
102
+ Peak extraction (max-pool NMS) → detections
103
+ ```
104
+
105
+ ### Key Design Choices
106
+
107
+ - **CEM500K backbone**: Pretrained on 500,000 electron microscopy images. Reaches F1=0.93 in just 5 training epochs because it already understands EM structures.
108
+ - **Stride-2 output**: Standard CenterNet uses stride 4, but 6nm beads (4-6px radius) collapse to 1px at that resolution. Stride 2 preserves 2-3px per bead.
109
+ - **CornerNet focal loss**: Handles the extreme class imbalance (positive:negative pixel ratio ~1:23,000).
110
+ - **Raw image input**: No preprocessing — CEM500K was trained on raw EM, so any heavy filtering creates a domain gap.
111
+
112
+ ## Training
113
+
114
+ ### 3-Phase Strategy
115
+ 1. **Phase 1** (40 epochs): Freeze encoder, train BiFPN + heads at lr=1e-3
116
+ 2. **Phase 2** (40 epochs): Unfreeze layer3+4 at lr=1e-5 to 5e-4
117
+ 3. **Phase 3** (60 epochs): Full fine-tune with discriminative LRs (1e-6 to 2e-4)
118
+
119
+ ### Data Augmentation
120
+ - Random 90-degree rotations, flips
121
+ - Conservative brightness/contrast (+-8%)
122
+ - Gaussian noise, mild blur
123
+ - Copy-paste: real bead crops blended onto training patches
124
+ - 70% hard mining (patches centered on particles)
125
+
126
+ ### Overfitting Prevention
127
+ - RNG reseeded per sample (unique patches every epoch)
128
+ - Early stopping (patience=20, monitoring val F1)
129
+ - Weight decay 1e-4
130
+
131
+ ### Train Final Model
132
+ ```bash
133
+ python train_final.py --config config/config.yaml --device cuda:0
134
+ ```
135
+
136
+ ### HPC (SLURM)
137
+ ```bash
138
+ sbatch slurm/05_train_final.sh
139
+ ```
140
+
141
+ ## LOOCV Results (per fold)
142
+
143
+ | Fold | Avg F1 | Best F1 | # Particles |
144
+ |------|--------|---------|-------------|
145
+ | S27 | 0.990 | 0.994 | 45 |
146
+ | S8 | 0.981 | 0.988 | 70 |
147
+ | S25 | 0.972 | 0.977 | 41 |
148
+ | S29 | 0.956 | 0.966 | 36 |
149
+ | S1 | 0.930 | 0.940 | 22 |
150
+ | S4 | 0.919 | 0.972 | 113 |
151
+ | S22 | 0.907 | 0.938 | 102 |
152
+ | S13 | 0.890 | 0.912 | 20 |
153
+ | S7* | 0.799 | 1.000 | 3 |
154
+ | S15* | 0.633 | 0.667 | 1 |
155
+
156
+ *S7 and S15 have insufficient annotations for reliable evaluation (3 and 1 particles respectively).
157
+
158
+ ## Dataset
159
+
160
+ - 10 FFRIL synapse images (2048x2115 pixels)
161
+ - 403 labeled 6nm particles (AMPA receptors)
162
+ - 50 labeled 12nm particles (NR1 receptors)
163
+ - Annotations in microns, converted at 1790 px/micron
164
+
165
+ ## Critical Implementation Notes
166
+
167
+ 1. **Coordinate conversion**: CSV "XY in microns" values are actual microns, not normalized coordinates. Multiply by 1790 to get pixels.
168
+ 2. **Heatmap peaks**: Must be exactly 1.0 at integer grid centers. The CornerNet focal loss uses `pos_mask = (gt == 1.0)`.
169
+ 3. **Patch diversity**: RNG must be reseeded per `__getitem__` call to prevent memorizing fixed patches.
170
+
171
+ ## Citation
172
+
173
+ If you use MidasMap in your research, please cite:
174
+
175
+ ```bibtex
176
+ @software{midasmap2026,
177
+ title={MidasMap: Automated Immunogold Particle Detection for TEM Synapse Images},
178
+ author={Sahai, Anik},
179
+ year={2026},
180
+ url={https://github.com/AnikS22/MidasMap}
181
+ }
182
+ ```
183
+
184
+ ## Dependencies
185
+
186
+ - PyTorch >= 2.0
187
+ - torchvision
188
+ - albumentations
189
+ - scikit-image
190
+ - tifffile
191
+ - CEM500K weights (download: `python scripts/download_cem500k.py`)