adelelsayed1991 committed
Commit abd02e7 · verified · 1 Parent(s): 243a764

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .gitattributes +2 -0
  2. LICENSE +21 -0
  3. README.md +353 -3
  4. configs/__init__.py +0 -0
  5. configs/__pycache__/__init__.cpython-313.pyc +0 -0
  6. configs/__pycache__/configs.cpython-313.pyc +0 -0
  7. configs/configs.py +55 -0
  8. data/__init__.py +0 -0
  9. data/__pycache__/__init__.cpython-313.pyc +0 -0
  10. data/__pycache__/__init__.cpython-314.pyc +0 -0
  11. data/__pycache__/dataset.cpython-313.pyc +0 -0
  12. data/__pycache__/dataset.cpython-314.pyc +0 -0
  13. data/dataset.py +508 -0
  14. data/splitter.py +347 -0
  15. gitignore.txt +61 -0
  16. loss/__init__.py +0 -0
  17. loss/__pycache__/__init__.cpython-313.pyc +0 -0
  18. loss/__pycache__/assymetric.cpython-313.pyc +0 -0
  19. loss/assymetric.py +59 -0
  20. models/__init__.py +0 -0
  21. models/__pycache__/__init__.cpython-313.pyc +0 -0
  22. models/__pycache__/classifier.cpython-313.pyc +0 -0
  23. models/__pycache__/densenet.cpython-313.pyc +0 -0
  24. models/__pycache__/mae.cpython-313.pyc +0 -0
  25. models/classifier.py +323 -0
  26. models/densenet.py +157 -0
  27. models/mae.py +177 -0
  28. notebooks/chexpert_mae.ipynb +0 -0
  29. notebooks/chexpert_mae_mask_classifier.ipynb +0 -0
  30. requirements.txt +29 -0
  31. results/test-results.docx +0 -0
  32. trainer/__init__.py +0 -0
  33. trainer/__pycache__/__init__.cpython-313.pyc +0 -0
  34. trainer/__pycache__/__init__.cpython-314.pyc +0 -0
  35. trainer/__pycache__/trainer.cpython-313.pyc +0 -0
  36. trainer/__pycache__/trainer.cpython-314.pyc +0 -0
  37. trainer/__pycache__/utils.cpython-313.pyc +0 -0
  38. trainer/test.py +15 -0
  39. trainer/trainer.py +19 -0
  40. trainer/utils.py +837 -0
  41. training logs/classifier/1/metrics.png +0 -0
  42. training logs/classifier/11/metrics.png +3 -0
  43. training logs/classifier/Events.docx +3 -0
  44. training logs/classifier/history.json +1 -0
  45. training logs/classifier/test_log.txt +0 -0
  46. training logs/classifier/training_log.txt +0 -0
  47. training logs/classifier/val_log.txt +0 -0
  48. training logs/mae/1/metrics.png +0 -0
  49. training logs/mae/101/metrics.png +0 -0
  50. training logs/mae/11/metrics.png +0 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ training[[:space:]]logs/classifier/11/metrics.png filter=lfs diff=lfs merge=lfs -text
+ training[[:space:]]logs/classifier/Events.docx filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Adel Elsayed
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,353 @@
- ---
- license: mit
- ---
+ # CheXpert MAE-DenseNet-FPN
+
+ A deep learning framework for multi-label chest X-ray classification using a hybrid architecture combining **Masked Autoencoders (MAE)**, **DenseNet** with CBAM attention, and **Feature Pyramid Networks (FPN)** with bidirectional cross-attention fusion.
+
+ ## 🏗️ Architecture Overview
+
+ This project implements a novel multi-modal fusion architecture for medical image classification:
+
+ - **MAE Encoder**: Vision Transformer-based masked autoencoder for self-supervised feature extraction
+ - **DenseNet-169**: Dense convolutional network with Channel and Spatial Attention (CBAM)
+ - **Feature Pyramid Network**: Multi-scale feature extraction at 4 different resolutions
+ - **Bidirectional Cross-Attention**: Fusion mechanism allowing MAE and DenseNet features to attend to each other
+ - **Learned Logit Ensemble**: Intelligent combination of 7 prediction heads with learnable temperature scaling
+
+ ### Key Components
+
+ ```
+ Input Image (384×384)
+         │
+         ├────────────────────────────┐
+         ▼                            ▼
+    MAE Encoder                 DenseNet-169
+    (ViT-based)                 (with CBAM)
+         │                            │
+         │              ┌─────────────┤
+         │              ▼             ▼
+         │         FPN Pyramid   Dense Features
+         │          (P1–P4)      (multi-scale)
+         │              │             │
+         └──────────────┴─────────────┘
+                        │
+                        ▼
+         Bidirectional Cross-Attention
+                        │
+           ┌────────────┴────────────┐
+           ▼                         ▼
+       MAE Head          Dense Head + 4 FPN Heads
+           │                         │
+           └────────────┬────────────┘
+                        ▼
+            Learned Ensemble (7 heads)
+                        ▼
+             14-class Predictions
+ ```
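The learned logit ensemble at the bottom of the diagram can be illustrated with a small sketch. This is illustrative only: `head_scores` and `temperature` stand in for the learnable parameters the model trains end-to-end; the actual implementation lives in `models/classifier.py` and may differ in parametrization.

```python
import numpy as np

def ensemble_logits(head_logits, head_scores, temperature=1.5):
    """Fuse per-head logits via softmax weights, then temperature-scale.
    In the model, both `head_scores` and `temperature` would be learnable
    parameters (names here are illustrative, not the repository's)."""
    w = np.exp(head_scores - head_scores.max())
    w /= w.sum()                                  # softmax over the heads
    fused = np.tensordot(w, head_logits, axes=1)  # (batch, num_classes)
    return fused / temperature                    # temperature scaling

# 7 prediction heads, batch of 4, 14 classes
logits = np.random.default_rng(0).normal(size=(7, 4, 14))
fused = ensemble_logits(logits, np.zeros(7))
print(fused.shape)  # (4, 14)
```

With all head scores equal (zeros), the fusion reduces to a plain average of the head logits divided by the temperature; training then moves the weights toward the more reliable heads.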
+
+ ## ✨ Features
+
+ - **Hybrid Architecture**: Combines transformer-based and convolutional approaches
+ - **Multi-scale Learning**: FPN extracts features at 4 different resolutions
+ - **Advanced Fusion**: Bidirectional cross-attention between MAE and DenseNet features
+ - **Optimized Training**:
+   - Mixed precision training (FP16)
+   - Gradient accumulation
+   - Weighted sampling for class imbalance
+   - Cosine annealing with linear warmup
+   - Gradient checkpointing for memory efficiency
+ - **Smart Data Loading**:
+   - ZIP file reader with LRU caching
+   - On-the-fly augmentation using Albumentations
+   - Multi-worker data loading with persistent workers
+ - **Comprehensive Evaluation**:
+   - Per-class AUC metrics
+   - Optimal threshold computation per class
+   - Macro and Micro AUC tracking
+
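The cosine-annealing-with-linear-warmup schedule listed above can be written in a few lines. This is a sketch of the standard formula; the warmup length and base learning rate here are the README's example values, not necessarily the trainer's exact implementation.

```python
import math

def lr_at(step, total_steps, warmup_steps, base_lr=1e-4):
    """Linear warmup to base_lr, then cosine decay toward zero (sketch)."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps               # linear warmup
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay

print(lr_at(0, 100, 10))   # small warmup LR
print(lr_at(10, 100, 10))  # peak LR = base_lr
```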
+ ## 📋 Requirements
+
+ - Python 3.8+
+ - CUDA-capable GPU (recommended: 16GB+ VRAM)
+ - CheXpert dataset
+
+ ## 🚀 Installation
+
+ 1. **Clone the repository**
+    ```bash
+    git clone https://github.com/adelelsayed/chexpert-mae-densenet-fpn.git
+    cd chexpert-mae-densenet-fpn
+    ```
+
+ 2. **Create a virtual environment**
+    ```bash
+    python -m venv venv
+    source venv/bin/activate  # On Windows: venv\Scripts\activate
+    ```
+
+ 3. **Install dependencies**
+    ```bash
+    pip install -r requirements.txt
+    ```
+
+ ## 📊 Dataset Setup
+
+ 1. **Download the CheXpert dataset**
+    - Visit: https://stanfordmlgroup.github.io/competitions/chexpert/
+    - Download CheXpert-v1.0-small
+
+ 2. **Prepare the dataset**
+    ```bash
+    # Extract the dataset
+    unzip CheXpert-v1.0-small.zip
+
+    # Optionally, create a ZIP archive for faster loading
+    cd CheXpert-v1.0-small
+    zip -r chexpert.zip train/ valid/
+    ```
+
+ 3. **Update the configuration**
+    - Edit `configs/configs.py`
+    - Update the `root` variable to point to your dataset location
+    - Update all other paths accordingly
+
+ ## 🔧 Configuration
+
+ Edit `configs/configs.py` to customize:
+
+ ```python
+ # Example: update paths
+ root = "/path/to/your/data"
+
+ mae_config = {
+     "lr": 1e-4,
+     "num_epochs": 200,
+     "batch_size": 96,
+     # ... other parameters
+ }
+
+ config = {
+     "lr": 1e-4,
+     "num_epochs": 200,
+     "batch_size": 36,
+     # ... other parameters
+ }
+ ```
+
+ ## 🎯 Training
+
+ ### Phase 1: Pre-train MAE
+
+ ```bash
+ python trainer/trainer.py
+ # When prompted, type: mae
+ ```
+
+ MAE pre-training learns robust feature representations through masked image reconstruction.
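Masked image reconstruction hides most patch tokens from the encoder and trains a decoder to reconstruct them. A minimal sketch of MAE-style random masking (shapes follow the 384×384 image / 16×16-patch configuration; the real model does this per batch on GPU tensors):

```python
import numpy as np

def random_masking(patches, mask_ratio=0.75, seed=0):
    """Keep a random subset of patch tokens, as in MAE pre-training.
    patches: (num_patches, dim). Returns (visible, kept_idx, mask) where
    mask[i] == 1 means patch i is hidden from the encoder."""
    n = patches.shape[0]
    n_keep = int(n * (1 - mask_ratio))
    perm = np.random.default_rng(seed).permutation(n)
    kept_idx = np.sort(perm[:n_keep])
    mask = np.ones(n, dtype=np.int64)
    mask[kept_idx] = 0
    return patches[kept_idx], kept_idx, mask

patches = np.zeros((576, 768))         # (384/16)^2 = 576 patches, embed dim 768
visible, kept_idx, mask = random_masking(patches)
print(visible.shape, int(mask.sum()))  # (144, 768) 432
```

With a 0.75 mask ratio, only 144 of 576 patches reach the encoder, which is what makes MAE pre-training cheap relative to full-sequence ViT training.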
+
+ ### Phase 2: Train Classifier
+
+ ```bash
+ python trainer/trainer.py
+ # When prompted, type: classifier
+ ```
+
+ This loads the pre-trained MAE encoder and trains the full classification pipeline.
+
+ ### Training Configuration
+
+ - **MAE training**:
+   - Batch size: 96
+   - Mask ratio: 0.75 (75% of patches are masked)
+   - Reconstruction loss computed on masked patches
+
+ - **Classifier training**:
+   - Batch size: 36 with gradient accumulation (8 steps)
+   - Effective batch size: 288
+   - Asymmetric loss with class weights
+   - Per-class threshold optimization
+
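The asymmetric loss used for classifier training (see `loss/assymetric.py`) focuses the gradient away from abundant easy negatives more aggressively than from positives. A minimal numpy sketch in the spirit of Ridnik et al. (2021); the repository's hyperparameters and its handling of uncertain (-1) labels may differ:

```python
import numpy as np

def asymmetric_loss(logits, targets, gamma_pos=0.0, gamma_neg=4.0, clip=0.05):
    """Asymmetric multi-label loss: negatives get a stronger focusing
    exponent plus probability shifting, so easy negatives contribute
    almost nothing. (Sketch, not the repository's exact implementation.)"""
    p = 1.0 / (1.0 + np.exp(-logits))         # sigmoid probabilities
    p_shifted = np.clip(p - clip, 0.0, 1.0)   # probability shift for negatives
    pos = targets * ((1 - p) ** gamma_pos) * np.log(np.clip(p, 1e-8, 1.0))
    neg = (1 - targets) * (p_shifted ** gamma_neg) * np.log(np.clip(1 - p_shifted, 1e-8, 1.0))
    return -float(np.mean(pos + neg))
```

Confidently wrong positives are penalized far more than confidently correct ones, while a confident easy negative contributes essentially zero loss.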
+ ## 🧪 Testing
+
+ ```python
+ from trainer.utils import Trainer
+ from configs.configs import config
+
+ # Initialize trainer
+ trainer = Trainer(config)
+
+ # Run evaluation on the test set
+ macro_auc, micro_auc, per_class = trainer.test(
+     model_path="path/to/checkpoint.pth"
+ )
+
+ print(f"Macro AUC: {macro_auc:.4f}")
+ print(f"Micro AUC: {micro_auc:.4f}")
+ ```
+
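The per-class optimal thresholds computed during evaluation can be found by a simple sweep over candidate cutoffs. A sketch using Youden's J statistic as the criterion (the repository may optimize a different statistic, e.g. F1):

```python
import numpy as np

def best_threshold(y_true, y_prob, grid=None):
    """Sweep candidate cutoffs for one class and keep the one that
    maximizes Youden's J = TPR - FPR. (Illustrative criterion only.)"""
    if grid is None:
        grid = np.linspace(0.05, 0.95, 19)
    best_t, best_j = 0.5, -np.inf
    for t in grid:
        pred = y_prob >= t
        tp = np.sum(pred & (y_true == 1)); fn = np.sum(~pred & (y_true == 1))
        fp = np.sum(pred & (y_true == 0)); tn = np.sum(~pred & (y_true == 0))
        tpr = tp / max(tp + fn, 1)
        fpr = fp / max(fp + tn, 1)
        if tpr - fpr > best_j:
            best_j, best_t = tpr - fpr, t
    return best_t

y_true = np.array([0, 0, 1, 1])
y_prob = np.array([0.1, 0.2, 0.7, 0.9])  # toy scores for one pathology
print(best_threshold(y_true, y_prob))
```

Running this per pathology yields 14 class-specific thresholds instead of a single global 0.5 cutoff, which matters under CheXpert's heavy class imbalance.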
+ ## 📁 Project Structure
+
+ ```
+ chexpert-mae-densenet-fpn/
+ ├── configs/
+ │   ├── __init__.py
+ │   └── configs.py          # Configuration parameters
+ ├── data/
+ │   ├── __init__.py
+ │   ├── dataset.py          # CheXpert dataset with ZIP caching
+ │   └── splitter.py         # Data splitting utilities
+ ├── loss/
+ │   ├── __init__.py
+ │   └── assymetric.py       # Asymmetric loss for imbalanced data
+ ├── models/
+ │   ├── __init__.py
+ │   ├── mae.py              # Masked Autoencoder implementation
+ │   ├── densenet.py         # DenseNet-169 with CBAM
+ │   └── classifier.py       # Full classification architecture
+ ├── trainer/
+ │   ├── __init__.py
+ │   ├── trainer.py          # Main training script
+ │   ├── utils.py            # Training utilities and loops
+ │   └── test.py             # Testing utilities
+ ├── notebooks/
+ │   ├── chexpert_mae.ipynb                  # MAE experiments
+ │   └── chexpert_mae_mask_classifier.ipynb  # Full pipeline experiments
+ ├── requirements.txt
+ └── README.md
+ ```
+
+ ## 📈 Model Architecture Details
+
+ ### MAE Encoder
+ - **Patch size**: 16×16
+ - **Embedding dim**: 768
+ - **Depth**: 12 transformer blocks
+ - **Heads**: 8 attention heads
+ - **MLP ratio**: 4×
+
+ ### DenseNet-169
+ - **Growth rate (k)**: 64
+ - **Layers**: [6, 12, 24, 16]
+ - **CBAM**: Channel + spatial attention at each stage
+ - **Dropout**: Progressive (0.05 → 0.1 → 0.1 → 0.1)
+
+ ### Cross-Attention Fusion
+ - **12 bidirectional cross-attention layers**
+ - **Projection dim**: 512
+ - **Attention heads**: 8
+
+ ### FPN
+ - **Feature levels**: P1 (192×192), P2 (96×96), P3 (48×48), P4 (24×24)
+ - **Channel unification**: 256 channels per level
+
+ ## 🎓 CheXpert Labels
+
+ The model predicts 14 pathologies:
+
+ 1. No Finding
+ 2. Enlarged Cardiomediastinum
+ 3. Cardiomegaly
+ 4. Lung Opacity
+ 5. Lung Lesion
+ 6. Edema
+ 7. Consolidation
+ 8. Pneumonia
+ 9. Atelectasis
+ 10. Pneumothorax
+ 11. Pleural Effusion
+ 12. Pleural Other
+ 13. Fracture
+ 14. Support Devices
+
+ ## 🔬 Data Augmentation
+
+ Training augmentations (conservative for medical images):
+ - Horizontal flip (p=0.5)
+ - Random affine (translation, scale, rotation ±10°)
+ - Random brightness/contrast
+ - CLAHE histogram equalization
+ - Gaussian blur and noise
+
+ ## 💾 Checkpoints
+
+ Training automatically saves:
+ - **Best MAE checkpoint**: Based on validation reconstruction loss
+ - **Best classifier checkpoint**: Based on validation AUC (macro/micro)
+ - **Training history**: JSON file with all metrics
+ - **Per-epoch metrics plots**: Loss and AUC curves
+
+ ## 📊 Monitoring
+
+ Training logs are saved to:
+ - `training_log.txt`: Training progress with live metrics
+ - `val_log.txt`: Validation results
+ - `test_log.txt`: Test evaluation results
+ - `history.json`: All metrics across epochs
+ - `metrics.png`: Visualization plots
+
+ ## ⚡ Performance Tips
+
+ 1. **Memory optimization**:
+    - Use gradient checkpointing (already enabled)
+    - Reduce the batch size if OOM occurs
+    - Increase gradient accumulation steps
+
+ 2. **Speed optimization**:
+    - Use persistent workers (already enabled)
+    - Enable cuDNN benchmark (already enabled)
+    - Use ZIP caching for faster data loading
+
+ 3. **Training stability**:
+    - Gradient clipping at norm 1.0
+    - Mixed precision with dynamic loss scaling
+    - Warmup learning rate schedule
+
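How gradient accumulation and norm clipping interact can be sketched in plain numpy. This illustrates the loop logic only — averaging micro-batch gradients before a single clipped update — not the actual trainer, which uses PyTorch autograd, AMP, and its own optimizer:

```python
import numpy as np

def accumulate_and_step(param, micro_grads, lr=1e-4, clip_norm=1.0):
    """One optimizer step built from several micro-batch gradients:
    average the accumulated gradients, clip to a global norm, then
    apply a plain SGD update. (Illustrative sketch only.)"""
    g = np.mean(micro_grads, axis=0)   # gradient of the effective batch
    norm = np.linalg.norm(g)
    if norm > clip_norm:
        g = g * (clip_norm / norm)     # gradient clipping at norm 1.0
    return param - lr * g

# 8 accumulation steps of micro-batch 36 -> effective batch 288
grads = np.full((8, 3), 100.0)         # deliberately huge gradients
new_param = accumulate_and_step(np.zeros(3), grads)
print(np.linalg.norm(new_param))       # clipped step size equals lr
```

Clipping after averaging (rather than per micro-batch) is what keeps the effective-batch update bounded even when one micro-batch produces an outlier gradient.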
+ ## 🐛 Troubleshooting
+
+ **Q: Out-of-memory errors?**
+ - Reduce the batch size in `configs/configs.py`
+ - Increase gradient accumulation steps
+ - Enable gradient checkpointing
+
+ **Q: Slow training?**
+ - Check that ZIP caching is enabled
+ - Verify that persistent workers are active
+ - Monitor GPU utilization
+
+ **Q: Poor convergence?**
+ - Ensure the MAE is properly pre-trained first
+ - Check the learning rate and warmup settings
+ - Verify that class weights are computed correctly
+
+ ## 📚 Citation
+
+ If you use this code in your research, please cite:
+
+ ```bibtex
+ @misc{chexpert-mae-densenet-fpn,
+   author = {Adel Elsayed},
+   title = {CheXpert Classification with MAE-DenseNet-FPN},
+   year = {2025},
+   publisher = {GitHub},
+   url = {https://github.com/adelelsayed/chexpert-mae-densenet-fpn}
+ }
+ ```
+
+ ## 🙏 Acknowledgments
+
+ - **CheXpert Dataset**: Stanford ML Group
+ - **Masked Autoencoders**: Meta AI Research (He et al., 2021)
+ - **DenseNet**: Huang et al., 2017
+ - **CBAM**: Woo et al., 2018
+ - **Feature Pyramid Networks**: Lin et al., 2017
+
+ ## 📄 License
+
+ This project is licensed under the MIT License.
+
+ ## 📧 Contact
+
+ https://www.linkedin.com/in/adel-elsayed-a5260246/
+
+ **Note**: This is a research project. For clinical use, please ensure proper validation and regulatory approval.
configs/__init__.py ADDED
File without changes
configs/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (152 Bytes).
configs/__pycache__/configs.cpython-313.pyc ADDED
Binary file (4.44 kB).
configs/configs.py ADDED
@@ -0,0 +1,55 @@
+ import os
+ import torch
+ root = "/content/drive/MyDrive"
+ mae_config = {
+     "lr": 1e-4,
+     "warmup": 5,
+     "weight_decay": 5e-4,
+     "num_epochs": 200,
+     "num_classes": 14,
+     "zip_path": os.path.join(root, "CheXpert-v1.0-small", "chexpert.zip"),
+     "resume": os.path.join(root, "CheXpert-v1.0-small", "maecheckpoints", "best_mae.pth"),
+     "logdir": os.path.join(root, "CheXpert-v1.0-small", "maelogs"),
+     "checkpoints": os.path.join(root, "CheXpert-v1.0-small", "maecheckpoints"),
+     "datadir": root,
+     "lmdb": os.path.join(root, "CheXpert-v1.0-small", "lmdb"),
+     "csv": os.path.join(root, "CheXpert-v1.0-small", "train.csv"),
+     "batch_size": 96,
+     "device": torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
+     "accumulation": 1,
+     "dirsToMake": [os.path.join(root, "CheXpert-v1.0-small", "maecheckpoints"), os.path.join(root, "CheXpert-v1.0-small", "maelogs")],
+     "train_csv": os.path.join(root, "CheXpert-v1.0-small", "train_ready.csv"),
+     "val_csv": os.path.join(root, "CheXpert-v1.0-small", "val_ready.csv"),
+     "test_csv": os.path.join(root, "CheXpert-v1.0-small", "test_ready.csv"),
+     "channels": 1, "mask_ratio": 0.75, "dropout": 0.25, "img_size": 384, "encoder_dim": 768,
+     "mlp_dim": 3072, "decoder_dim": 512, "encoder_depth": 12, "encoder_head": 8, "decoder_depth": 8,
+     "decoder_head": 8, "patch_size": 16
+ }
+ config = {
+     "lr": 1e-4,
+     "warmup": 10,
+     "weight_decay": 5e-4,
+     "num_epochs": 200,
+     "num_classes": 14,
+     "zip_path": os.path.join(root, "CheXpert-v1.0-small", "chexpert.zip"),
+     "backbone": os.path.join(root, "CheXpert-v1.0-small", "maecheckpoints", "best_mae.pth"),
+     "densebackbone": os.path.join(root, "CheXpert-v1.0-small", "checkpoints", "No Eca with masking best_dense.pth"),
+     "resume": os.path.join(root, "CheXpert-v1.0-small", "maecheckpoints", "fpn", "best_mae_classifier.pth"),
+     "logdir": os.path.join(root, "CheXpert-v1.0-small", "maelogs", "fpn", "classifier"),
+     "checkpoints": os.path.join(root, "CheXpert-v1.0-small", "maecheckpoints"),
+     "datadir": root,
+     "lmdb": os.path.join(root, "CheXpert-v1.0-small", "lmdb"),
+     "csv": os.path.join(root, "CheXpert-v1.0-small", "train.csv"),
+     "batch_size": 36,
+     "device": torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
+     "accumulation": 8,
+     "maskdir": os.path.join(root, "CheXpert-v1.0-small", "fpn", "mask"),
+     "dirsToMake": [os.path.join(root, "CheXpert-v1.0-small", "maecheckpoints", "fpn"), os.path.join(root, "CheXpert-v1.0-small", "maelogs", "fpn", "classifier"), os.path.join(root, "CheXpert-v1.0-small", "fpn", "mask")],
+     "train_csv": os.path.join(root, "CheXpert-v1.0-small", "train_ready.csv"),
+     "val_csv": os.path.join(root, "CheXpert-v1.0-small", "val_ready.csv"),
+     "test_csv": os.path.join(root, "CheXpert-v1.0-small", "test_ready.csv"),
+     "channels": 1, "mask_ratio": 0, "dropout": 0.25, "img_size": 384, "encoder_dim": 768,
+     "mlp_dim": 3072, "decoder_dim": 512, "encoder_depth": 12, "encoder_head": 8, "decoder_depth": 8,
+     "decoder_head": 8, "patch_size": 16
+ }
data/__init__.py ADDED
File without changes
data/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (149 Bytes).
data/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (151 Bytes).
data/__pycache__/dataset.cpython-313.pyc ADDED
Binary file (21.6 kB).
data/__pycache__/dataset.cpython-314.pyc ADDED
Binary file (22.3 kB).
data/dataset.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Standard library
2
+ import os
3
+ import io
4
+ import zipfile
5
+ import pickle
6
+ from pathlib import Path
7
+
8
+ # Data handling
9
+ import pandas as pd
10
+ import numpy as np
11
+
12
+ # PyTorch
13
+ import torch
14
+ from torch.utils.data import Dataset
15
+
16
+ # Image processing
17
+ from PIL import Image
18
+ import cv2
19
+
20
+ # Augmentations
21
+ import albumentations as A
22
+ from albumentations.pytorch import ToTensorV2
23
+
24
+ # Progress bar (for precompute_all_masks)
25
+ from tqdm import tqdm
26
+
27
+ class OptimizedZipReader:
28
+ """
29
+ Fast ZIP file reader with LRU caching
30
+ """
31
+ def __init__(self, zip_path, cache_size=1000):
32
+ """
33
+ Args:
34
+ zip_path: Path to ZIP file
35
+ cache_size: Number of images to cache in RAM
36
+ """
37
+ self.zip_path = zip_path
38
+ self.cache_size = cache_size
39
+ self._zip_file = None # Will be lazily initialized
40
+ self._name_to_info = None
41
+
42
+ # Cache
43
+ self._cache = {}
44
+ self._cache_order = []
45
+ self._hits = 0
46
+ self._misses = 0
47
+
48
+ @property
49
+ def zip_file(self):
50
+ """Lazy initialization of ZIP file handle"""
51
+ if self._zip_file is None:
52
+ print(f"Opening ZIP file: {self.zip_path}")
53
+ self._zip_file = zipfile.ZipFile(self.zip_path, 'r', allowZip64=True)
54
+
55
+ # Build index on first access
56
+ print("Building ZIP index...")
57
+ self._name_to_info = {
58
+ info.filename: info
59
+ for info in self._zip_file.infolist()
60
+ }
61
+ print(f"✓ Indexed {len(self._name_to_info)} files")
62
+
63
+ return self._zip_file
64
+
65
+ def read_image(self, path):
66
+ """
67
+ Read image data with automatic caching
68
+
69
+ Returns: bytes (image file data)
70
+ """
71
+ # Check cache first
72
+ if path in self._cache:
73
+ self._hits += 1
74
+ return self._cache[path]
75
+
76
+ # Cache miss - read from ZIP (this triggers lazy initialization)
77
+ self._misses += 1
78
+ img_data = self.zip_file.read(path) # Uses property getter
79
+
80
+ # Add to cache with LRU eviction
81
+ if len(self._cache) >= self.cache_size:
82
+ oldest = self._cache_order.pop(0)
83
+ del self._cache[oldest]
84
+
85
+ self._cache[path] = img_data
86
+ self._cache_order.append(path)
87
+
88
+ return img_data
89
+
90
+ def get_cache_stats(self):
91
+ """Return cache hit rate statistics"""
92
+ total = self._hits + self._misses
93
+ hit_rate = self._hits / total * 100 if total > 0 else 0
94
+ return {
95
+ 'hits': self._hits,
96
+ 'misses': self._misses,
97
+ 'hit_rate': f"{hit_rate:.2f}%",
98
+ 'cache_size': len(self._cache)
99
+ }
100
+
101
+ def close(self):
102
+ """Close ZIP file and clear cache"""
103
+ if self._zip_file is not None:
104
+ self._zip_file.close()
105
+ self._zip_file = None
106
+ self._cache.clear()
107
+ self._cache_order.clear()
108
+ self._name_to_info = None
109
+
110
+ class CheXpertDataset(Dataset):
111
+ """
112
+ CheXpert Dataset class
113
+
114
+ NEW: Returns 3-channel images: (img, img*mask, mask)
115
+ - Channel 0: Original grayscale image
116
+ - Channel 1: Masked image (lung region only)
117
+ - Channel 2: Binary lung mask
118
+
119
+ Args:
120
+ csv_path (str): Path to the CSV file (train.csv or valid.csv)
121
+ root_dir (str): Root directory of the CheXpert dataset
122
+ image_size (int): Target image size (default: 384)
123
+ augment (bool): Whether to apply augmentations (default: False)
124
+ use_frontal_only (bool): If True, only use frontal view images (default: True)
125
+ fill_uncertain (str): How to handle uncertain labels: 'zeros', 'ones', 'ignore' (default: 'zeros')
126
+ """
127
+
128
+ # 14 pathology classes in CheXpert
129
+ PATHOLOGIES = [
130
+ 'No Finding',
131
+ 'Enlarged Cardiomediastinum',
132
+ 'Cardiomegaly',
133
+ 'Lung Opacity',
134
+ 'Lung Lesion',
135
+ 'Edema',
136
+ 'Consolidation',
137
+ 'Pneumonia',
138
+ 'Atelectasis',
139
+ 'Pneumothorax',
140
+ 'Pleural Effusion',
141
+ 'Pleural Other',
142
+ 'Fracture',
143
+ 'Support Devices'
144
+ ]
145
+
146
+ def __init__(
147
+ self,
148
+ csv_path,
149
+ root_dir,
150
+ image_size=384,
151
+ augment=False,
152
+ use_frontal_only=False,
153
+ fill_uncertain='ignore',
154
+ lmdb_path=None,
155
+ zip_path=None,
156
+ zip_cache_size=1000,
157
+ mask_dir=None, domask=False
158
+ ):
159
+ self.root_dir = root_dir
160
+ self.image_size = image_size
161
+ self.augment = augment
162
+ self.fill_uncertain = fill_uncertain
163
+ self.env =None #lmdb.open(lmdb_path, readonly=True, lock=False) if lmdb_path else None
164
+ self._zip_path = zip_path
165
+ self._zip_cache_size = zip_cache_size
166
+ self._zip_reader_instance = None
167
+
168
+
169
+ # Read CSV file
170
+ self.df = pd.read_csv(csv_path)
171
+ for pathology in self.PATHOLOGIES:
172
+ if pathology in self.df.columns:
173
+ self.df[pathology] = pd.to_numeric(self.df[pathology], errors='coerce')
174
+
175
+ # Filter for frontal views only if specified
176
+ if use_frontal_only:
177
+ self.df = self.df[self.df['Frontal/Lateral'] == 'Frontal'].reset_index(drop=True)
178
+
179
+ # Handle uncertain labels (-1 values)
180
+ self._process_uncertain_labels()
181
+
182
+ # Setup augmentations
183
+ self.train_transform = self._get_train_transforms()
184
+ self.val_transform = self._get_val_transforms()
185
+
186
+ print(f"Loaded {len(self.df)} images from {csv_path}")
187
+ print(f"Image size: {image_size}x{image_size}")
188
+ print(f"Augmentation: {augment}")
189
+ print(f"Uncertain labels filled with: {fill_uncertain}")
190
+
191
+ if mask_dir and domask:
192
+ self.precompute_all_masks(mask_dir)
193
+
194
+ # Run this ONCE before training
195
+ def precompute_all_masks(self, save_dir):
196
+ os.makedirs(save_dir, exist_ok=True)
197
+ for idx in tqdm(range(len(self))):
198
+ img_path = os.path.join(self.root_dir,self.df.iloc[idx]['Path'])
199
+ part_path="/".join(self.df.iloc[idx]['Path'].split("/")[1:])
200
+ if self.zip_reader:
201
+ # Read image data from ZIP (no extraction!)
202
+ img_data = self.zip_reader.read_image(part_path)
203
+
204
+ # Open image from bytes in memory
205
+ image = Image.open(io.BytesIO(img_data)).convert('L')
206
+ else:
207
+ image = Image.open(img_path).convert('L')
208
+
209
+ image = np.array(image)
210
+
211
+ mask = chexpert_medsam_mask(image)
212
+ mask_path = os.path.join(save_dir, "_".join(self.df.iloc[idx]['Path'].split("/")[-3:]).replace('.jpg', '_mask.pt'))
213
+ os.makedirs(os.path.dirname(mask_path), exist_ok=True)
214
+ torch.save(mask, mask_path)
215
+ @property
216
+ def zip_reader(self):
217
+ """
218
+ Lazy property getter for ZIP reader
219
+
220
+ The ZIP file is only opened when first accessed, not during __init__.
221
+ This is useful when:
222
+ - Creating multiple dataset objects but only using some
223
+ - Saving memory during dataset setup
224
+ - Working with multiprocessing (each worker creates its own)
225
+ """
226
+ if self._zip_reader_instance is None and self._zip_path is not None:
227
+ self._zip_reader_instance = OptimizedZipReader(
228
+ self._zip_path,
229
+ cache_size=self._zip_cache_size
230
+ )
231
+ return self._zip_reader_instance
232
+
233
+ def _load_and_cache_image(self, img_path, idx):
234
+ """
235
+ Load image with automatic resizing and caching.
236
+ If resized version exists, load it. Otherwise, resize, save, and load.
237
+
238
+ Args:
239
+ img_path (str): Original image path from CSV
240
+ idx (int): Index for tracking
241
+
242
+ Returns:
243
+ np.ndarray: Loaded image (grayscale)
244
+ """
245
+ # Create cache directory structure
246
+ cache_dir = Path(self.root_dir) #/ f"cache_{self.image_size}"
247
+
248
+ # Preserve the relative path structure in cache
249
+ path_parts = list(Path(img_path).parts)
250
+ path_parts[-1]=f"{self.image_size}_{path_parts[-1]}"
251
+ relative_path = Path(*path_parts)
252
+ cached_path =relative_path.with_suffix('.jpg')
253
+
254
+ # Check if cached version exists
255
+ if cached_path.exists():
256
+ # Load cached image
257
+ image = Image.open(cached_path).convert('L')
258
+ image = np.array(image)
259
+
260
+ # Verify it's the correct size
261
+ if image.shape[0] == self.image_size and image.shape[1] == self.image_size:
262
+ return image
263
+
264
+ # Cache doesn't exist or wrong size - load original
265
+ original_path = img_path
266
+ image = Image.open(original_path).convert('L')
267
+
268
+ # Check if original is already target size
269
+ width, height = image.size
270
+
271
+ if width == self.image_size and height == self.image_size:
272
+ # Already correct size, just convert to array
273
+ return np.array(image)
274
+
275
+ # Resize image
276
+ image_resized = image.resize(
277
+ (self.image_size, self.image_size),
278
+ Image.LANCZOS
279
+ )
280
+
281
+ # Save to cache
282
+ cached_path.parent.mkdir(parents=True, exist_ok=True)
283
+ image_resized.save(cached_path, 'JPEG', quality=95, optimize=True)
284
+
285
+ return np.array(image_resized)
286
+
287
+ def _process_uncertain_labels(self):
288
+ """Process uncertain labels (-1) based on the chosen strategy."""
289
+ for pathology in self.PATHOLOGIES:
290
+ if pathology in self.df.columns:
291
+ if self.fill_uncertain == 'zeros':
292
+ # Map uncertain (-1) to negative (0)
293
+ self.df[pathology] = self.df[pathology].replace(-1, 0)
294
+ elif self.fill_uncertain == 'ones':
295
+ # Map uncertain (-1) to positive (1)
296
+ self.df[pathology] = self.df[pathology].replace(-1, 1)
297
+ elif self.fill_uncertain == 'ignore':
298
+ # Keep -1 as is (you'll need to handle this in loss function)
299
+ pass
300
+
301
+ # Fill NaN with 0 (negative)
302
+ self.df[pathology] = self.df[pathology].fillna(0)
303
+
304
+ def _get_train_transforms(self):
305
+ """Get training augmentations suitable for chest X-rays."""
306
+ import cv2
307
+ return A.Compose([
308
+ # Resize to target size
309
+ A.LongestMaxSize(max_size=self.image_size),
310
+ A.PadIfNeeded(self.image_size, self.image_size, border_mode=cv2.BORDER_CONSTANT, position='center'),
311
+
312
+ # Geometric augmentations (conservative for medical images)
313
+ A.HorizontalFlip(p=0.5),
314
+ A.Affine(
315
+ translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
316
+ scale=(0.9, 1.1),
317
+ rotate=(-10, 10),
318
+ fit_output=False,
319
+ p=0.5
320
+ ),
321
+
322
+ # Intensity augmentations
323
+ A.OneOf([
324
+ A.RandomBrightnessContrast(
325
+ brightness_limit=0.2,
326
+ contrast_limit=0.2,
327
+ p=1.0
328
+ ),
329
+ A.RandomGamma(gamma_limit=(80, 120), p=1.0),
330
+ A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=1.0),
331
+ ], p=0.5),
332
+
333
+ # Add slight blur to simulate different imaging conditions
334
+ A.OneOf([
335
+ A.GaussianBlur(blur_limit=(3, 5), p=1.0),
336
+ A.MedianBlur(blur_limit=3, p=1.0),
337
+ ], p=0.2),
338
+
339
+ # Add noise
340
+ A.GaussNoise(p=0.2),
341
+
342
+ # Normalize to [0, 1]
343
+ A.Normalize(
344
+ mean=[0.5],
345
+ std=[0.5],
346
+ max_pixel_value=255.0
347
+ ),
348
+
349
+ ToTensorV2()
350
+ ])
351
+
352
+ def _get_val_transforms(self):
353
+ """Get validation/test transforms (no augmentation)."""
354
+ return A.Compose([
355
+ A.LongestMaxSize(max_size=self.image_size),
356
+ A.PadIfNeeded(self.image_size, self.image_size, border_mode=cv2.BORDER_CONSTANT, position='center'),
357
+ A.Normalize(
358
+ mean=[0.5],
359
+ std=[0.5],
360
+ max_pixel_value=255.0
361
+ ),
362
+ ToTensorV2()
363
+ ])
364
+
365
+ def __len__(self):
366
+ return len(self.df)
367
+
368
+ def __del__(self):
369
+ """Close ZIP when done"""
370
+ if hasattr(self, 'zip_reader'):
371
+ self.zip_reader.close()
372
+
373
+ def __getitem__(self, idx):
374
+ if self.env:
375
+ with self.env.begin() as txn:
376
+ # Retrieve serialized data
377
+ data = txn.get(str(idx).encode())
378
+ sample = pickle.loads(data)
379
+ return sample
380
+ else:
381
+ # Get image path
382
+ img_path = os.path.join(self.root_dir,self.df.iloc[idx]['Path'])
383
+ #image = self._load_and_cache_image(img_path, idx)
384
+ # Load image
385
+ #image = Image.open(img_path).convert('L') # Convert to grayscale
386
+
387
+ part_path="/".join(self.df.iloc[idx]['Path'].split("/")[1:])
388
+ if self.zip_reader:
389
+ # Read image data from ZIP (no extraction!)
390
+ img_data = self.zip_reader.read_image(part_path)
391
+
392
+ # Open image from bytes in memory
393
+ image = Image.open(io.BytesIO(img_data)).convert('L')
394
+ else:
395
+ image = Image.open(img_path).convert('L')
396
+
397
+ image = np.array(image)
398
+
399
+
400
+ # Load pre-computed mask
401
+ #mask_path = os.path.join(self.mask_dir, "_".join(self.df.iloc[idx]['Path'].split("/")[-3:]).replace('.jpg', '_mask.pt'))
402
+ #masked_img = torch.load(mask_path)
403
+ # Apply transforms to BOTH image and mask together
404
+ if self.augment:
405
+ # Augmentation applies to both image and mask
406
+ transformed = self.train_transform(image=image)
407
+ image_transformed = transformed['image'] # (1, H, W) tensor, normalized
408
+ #masked_img=transformed['mask']
409
+ # (H, W) tensor
410
+ else:
411
+ transformed = self.val_transform(image=image)
412
+ image_transformed = transformed['image'] # (1, H, W) tensor, normalized
413
+ #masked_img=transformed['mask']
414
+
415
+ # Alias the transformed image (the mask branch is currently disabled)
416
+ image_1ch = image_transformed # (1, H, W)
417
+ masked_img = image_transformed
418
+
419
+ # Get labels for all pathologies
420
+ labels = []
421
+ for pathology in self.PATHOLOGIES:
422
+ if pathology in self.df.columns:
423
+ label = self.df.iloc[idx][pathology]
424
+ labels.append(float(label) if not pd.isna(label) else 0.0)
425
+ else:
426
+ labels.append(0.0)
427
+
428
+ labels = torch.tensor(labels, dtype=torch.float32)
429
+
430
+ # Get additional metadata
431
+ metadata = {
432
+ 'patient_id': self.df.iloc[idx]['Path'].split('/')[2], # Extract patient ID from path
433
+ 'study_id': self.df.iloc[idx]['Path'].split('/')[3], # Extract study ID from path
434
+ 'view': self.df.iloc[idx]['Frontal/Lateral'],
435
+ 'sex': self.df.iloc[idx]['Sex'] if 'Sex' in self.df.columns else 'Unknown',
436
+ 'age': self.df.iloc[idx]['Age'] if 'Age' in self.df.columns else -1,
437
+ 'path': self.df.iloc[idx]['Path']
438
+ }
439
+
440
+ return {
441
+ 'image': image_1ch,
442
+ 'labels': labels,
443
+ 'metadata': metadata
444
+ }
445
+
446
+ def get_label_names(self):
447
+ """Return list of pathology label names."""
448
+ return self.PATHOLOGIES
449
+
450
+ def get_label_distribution(self):
451
+ """Get distribution of positive labels for each pathology."""
452
+ distribution = {}
453
+ for pathology in self.PATHOLOGIES:
454
+ if pathology in self.df.columns:
455
+ positive_count = (self.df[pathology] == 1.0).sum()
456
+ distribution[pathology] = {
457
+ 'positive': int(positive_count),
458
+ 'percentage': round(positive_count / len(self.df) * 100, 2)
459
+ }
460
+ return distribution
461
+
462
+ def get_class_weights(self):
463
+ """
464
+ OPTIMIZED: Vectorized class weights calculation
465
+ """
466
+ weights = []
467
+ for pathology in self.PATHOLOGIES:
468
+ if pathology in self.df.columns:
469
+ # Vectorized counting (much faster than iterating)
470
+ values = self.df[pathology].values
471
+ pos = np.sum(values == 1.0)
472
+ neg = np.sum(values == 0.0)
473
+ weight = neg / pos if pos > 0 else 1.0
474
+ weights.append(weight)
475
+ return torch.tensor(weights, dtype=torch.float32)
476
+
477
+ def get_sample_weights(self):
478
+ """
479
+ OPTIMIZED: Vectorized sample weights calculation
480
+
481
+ Performance: ~1000x faster than original
482
+ Original: 15-30 seconds for 200k samples
483
+ This: 0.01-0.05 seconds for 200k samples
484
+ """
485
+ # Get class weights as numpy array
486
+ class_weights = self.get_class_weights().numpy()
487
+
488
+ # Get all labels as numpy array in ONE vectorized operation
489
+ labels_array = self.df[self.PATHOLOGIES].values.astype(np.float32)
490
+
491
+ # Create weighted labels matrix: where label=1, use class_weight, else -inf
492
+ # Shape: (n_samples, n_classes)
493
+ weighted_labels = np.where(
494
+ labels_array == 1.0,
495
+ class_weights,
496
+ -np.inf # Use -inf instead of 0 so max will only consider positive labels
497
+ )
498
+
499
+ # For each sample, find the maximum class weight of its positive labels
500
+ # If a sample has no positive labels, max will be -inf, which we'll replace with 1.0
501
+ sample_weights = np.max(weighted_labels, axis=1)
502
+ sample_weights = np.where(
503
+ np.isinf(sample_weights),
504
+ 1.0, # Samples with no positive labels get weight 1.0
505
+ sample_weights
506
+ )
507
+
508
+ return torch.tensor(sample_weights, dtype=torch.float32)
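The `-inf` masking trick in `get_sample_weights` can be sketched standalone on toy data (numpy only; the class-weight rule mirrors `get_class_weights`, everything else here is hypothetical):

```python
import numpy as np

# Hypothetical toy multi-label matrix: 4 samples x 3 classes
labels = np.array([
    [1, 0, 0],
    [0, 1, 1],
    [0, 0, 0],   # no positive labels at all
    [1, 0, 1],
], dtype=np.float32)

# Per-class neg/pos ratio as class weight (same rule as get_class_weights)
pos = labels.sum(axis=0)
neg = len(labels) - pos
class_weights = np.where(pos > 0, neg / np.maximum(pos, 1), 1.0)

# Where label == 1 use the class weight, else -inf so max ignores negatives
weighted = np.where(labels == 1.0, class_weights, -np.inf)
sample_weights = weighted.max(axis=1)
# Samples with no positive labels fall back to weight 1.0
sample_weights = np.where(np.isinf(sample_weights), 1.0, sample_weights)
print(sample_weights)  # [1. 3. 1. 1.]
```

Per-sample weights like these are typically handed to `torch.utils.data.WeightedRandomSampler` so that samples carrying rare positive labels are drawn more often.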
data/splitter.py ADDED
@@ -0,0 +1,347 @@
1
+ # Standard library
2
+ import os
3
+ from pathlib import Path
4
+
5
+ # Data handling
6
+ import pandas as pd
7
+ import numpy as np
8
+
9
+ # Machine learning
10
+ from sklearn.model_selection import train_test_split
11
+
12
+ class CheXpertDataSplitter:
13
+ """
14
+ Advanced stratified train-validation splitter for CheXpert dataset.
15
+ Handles:
16
+ - Patient-level splitting (prevents data leakage)
17
+ - Multi-label stratification
18
+ - Class imbalance awareness
19
+ - Study-level grouping
20
+ """
21
+
22
+ PATHOLOGIES = [
23
+ 'No Finding',
24
+ 'Enlarged Cardiomediastinum',
25
+ 'Cardiomegaly',
26
+ 'Lung Opacity',
27
+ 'Lung Lesion',
28
+ 'Edema',
29
+ 'Consolidation',
30
+ 'Pneumonia',
31
+ 'Atelectasis',
32
+ 'Pneumothorax',
33
+ 'Pleural Effusion',
34
+ 'Pleural Other',
35
+ 'Fracture',
36
+ 'Support Devices'
37
+ ]
38
+
39
+ def __init__(self, csv_path, val_size=0.15,test_size=0.15, random_state=42,
40
+ use_frontal_only=True, fill_uncertain='zeros',root=None):
41
+ """
42
+ Initialize the splitter.
43
+
44
+ Args:
45
+ csv_path: Path to train.csv from CheXpert-small
46
+ val_size: Validation set proportion (default: 0.15)
+ test_size: Test set proportion (default: 0.15)
47
+ random_state: Random seed for reproducibility
48
+ use_frontal_only: Use only frontal view images
49
+ fill_uncertain: How to handle uncertain labels ('zeros', 'ones', 'ignore')
+ root: Optional dataset root directory (stored for reference)
50
+ """
51
+ self.csv_path = csv_path
52
+ self.val_size = val_size
53
+ self.test_size = test_size
54
+ self.random_state = random_state
55
+ self.use_frontal_only = use_frontal_only
56
+ self.fill_uncertain = fill_uncertain
57
+ self.root=root
58
+
59
+ print("=" * 80)
60
+ print("CheXpert Data Splitter - Preventing Data Leakage & Class Bias")
61
+ print("=" * 80)
62
+
63
+ def load_and_preprocess(self):
64
+ """Load and preprocess the dataset."""
65
+ print("\n[1/5] Loading data...")
66
+ self.df = pd.read_csv(self.csv_path)
67
+ print(f" Loaded {len(self.df)} images")
68
+
69
+ #self.df=self.df[self.df["Path"].apply(os.path.exists)]
70
+
71
+ # Filter for frontal views only
72
+ if self.use_frontal_only:
73
+ initial_count = len(self.df)
74
+ self.df = self.df[self.df['Frontal/Lateral'] == 'Frontal'].reset_index(drop=True)
75
+ print(f" Filtered to frontal views: {len(self.df)} images ({initial_count - len(self.df)} removed)")
76
+
77
+ # Extract patient and study IDs from path
78
+ print("\n[2/5] Extracting patient and study IDs...")
79
+ self.df['patient_id'] = self.df['Path'].apply(lambda x: x.split('/')[2])
80
+ self.df['study_id'] = self.df['Path'].apply(lambda x: x.split('/')[3])
81
+
82
+ n_patients = self.df['patient_id'].nunique()
83
+ n_studies = self.df['study_id'].nunique()
84
+ print(f" Unique patients: {n_patients}")
85
+ print(f" Unique studies: {n_studies}")
86
+ print(f" Images per patient (avg): {len(self.df) / n_patients:.2f}")
87
+
88
+ # Process uncertain labels
89
+ print("\n[3/5] Processing uncertain labels...")
90
+ self._process_uncertain_labels()
91
+
92
+ return self.df
93
+
94
+ def _process_uncertain_labels(self):
95
+ """Process uncertain labels (-1) based on the chosen strategy."""
96
+ for pathology in self.PATHOLOGIES:
97
+ if pathology in self.df.columns:
98
+ uncertain_count = (self.df[pathology] == -1).sum()
99
+
100
+ if self.fill_uncertain == 'zeros':
101
+ self.df[pathology] = self.df[pathology].replace(-1, 0)
102
+ elif self.fill_uncertain == 'ones':
103
+ self.df[pathology] = self.df[pathology].replace(-1, 1)
104
+ elif self.fill_uncertain == 'ignore':
105
+ pass # Keep -1 as is
106
+
107
+ # Fill NaN with 0
108
+ self.df[pathology] = self.df[pathology].fillna(0)
109
+
110
+ print(f" Uncertain labels strategy: {self.fill_uncertain}")
111
+
112
+ def create_stratification_groups(self):
113
+ """
114
+ Create stratification groups based on multi-label combinations.
115
+ Uses patient-level aggregation to prevent data leakage.
116
+ """
117
+ print("\n[4/5] Creating stratification groups (patient-level)...")
118
+
119
+ # Group by patient and aggregate labels
120
+ patient_groups = self.df.groupby('patient_id').agg({
121
+ **{pathology: 'max' for pathology in self.PATHOLOGIES if pathology in self.df.columns},
122
+ 'study_id': 'first', # Keep one study_id for reference
123
+ 'Sex': 'first',
124
+ 'Age': 'first'
125
+ }).reset_index()
126
+
127
+ # Create label signature for each patient
128
+ # This is a binary string representing which conditions are present
129
+ def create_label_signature(row):
130
+ signature = []
131
+ for pathology in self.PATHOLOGIES:
132
+ if pathology in patient_groups.columns:
133
+ signature.append(str(int(row[pathology])))
134
+ return ''.join(signature)
135
+
136
+ patient_groups['label_signature'] = patient_groups.apply(create_label_signature, axis=1)
137
+
138
+ # For rare combinations, group them together
139
+ signature_counts = patient_groups['label_signature'].value_counts()
140
+ rare_threshold = max(5, int(len(patient_groups) * 0.001)) # At least 5 or 0.1%
141
+
142
+ def get_stratification_group(signature):
143
+ if signature_counts[signature] < rare_threshold:
144
+ return 'RARE_COMBINATION'
145
+ return signature
146
+
147
+ patient_groups['stratification_group'] = patient_groups['label_signature'].apply(get_stratification_group)
148
+
149
+ # Print distribution statistics
150
+ print(f"\n Patient-level label distribution:")
151
+ for pathology in self.PATHOLOGIES:
152
+ if pathology in patient_groups.columns:
153
+ positive_count = (patient_groups[pathology] == 1).sum()
154
+ percentage = positive_count / len(patient_groups) * 100
155
+ print(f" {pathology:30s}: {positive_count:5d} ({percentage:5.2f}%)")
156
+
157
+ unique_groups = patient_groups['stratification_group'].nunique()
158
+ print(f"\n Unique stratification groups: {unique_groups}")
159
+ print(f" Rare combinations grouped: {(patient_groups['stratification_group'] == 'RARE_COMBINATION').sum()}")
160
+
161
+ return patient_groups
162
+
163
+ def perform_split(self, patient_groups):
164
+ """
165
+ Perform stratified train-validation-test split at patient level.
166
+ """
167
+ print("\n[5/5] Performing stratified patient-level split...")
168
+
169
+ stratification_labels = patient_groups['stratification_group'].values
170
+
171
+ # ---- train / (val+test) ----
172
+ train_patients, valtest_patients = train_test_split(
173
+ patient_groups['patient_id'].values,
174
+ test_size=self.val_size + self.test_size,  # combined val+test fraction
175
+ stratify=stratification_labels,
176
+ random_state=self.random_state
177
+ )
178
+
179
+ # ---- val / test from the remaining pool ----
180
+ remaining_labels = patient_groups.set_index('patient_id').loc[valtest_patients]['stratification_group'].values
181
+ val_patients, test_patients = train_test_split(
182
+ valtest_patients,
183
+ test_size=self.test_size / (self.val_size + self.test_size), # <-- proportion of the val+test pool
184
+ stratify=remaining_labels,
185
+ random_state=self.random_state
186
+ )
187
+
188
+ print(f" Train patients: {len(train_patients)}")
189
+ print(f" Val patients: {len(val_patients)}")
190
+ print(f" Test patients: {len(test_patients)}")
191
+
192
+ # Split the full dataframe
193
+ train_df = self.df[self.df['patient_id'].isin(train_patients)].copy()
194
+ val_df = self.df[self.df['patient_id'].isin(val_patients)].copy()
195
+ test_df = self.df[self.df['patient_id'].isin(test_patients)].copy()
196
+
197
+ # ---- leakage check (train vs val vs test) ----
198
+ sets = [('train', train_df), ('val', val_df), ('test', test_df)]
199
+ for i, (name_i, df_i) in enumerate(sets):
200
+ for j, (name_j, df_j) in enumerate(sets[i+1:]):
201
+ overlap = set(df_i['patient_id']).intersection(set(df_j['patient_id']))
202
+ if overlap:
203
+ raise ValueError(f"Data leakage between {name_i} and {name_j}: {len(overlap)} patients overlap")
204
+ print("\n No patient overlap – leakage prevented!")
205
+
206
+ return train_df, val_df, test_df
207
+
208
+ def run(self, output_dir='.', save_test=True):
209
+ self.load_and_preprocess()
210
+ patient_groups = self.create_stratification_groups()
211
+ train_df, val_df, test_df = self.perform_split(patient_groups)
212
+
213
+ self.verify_split_quality(train_df, val_df)
214
+ # optional: also verify train vs test (same function works with two dfs)
215
+ print("\n--- Train vs Test distribution check ---")
216
+ self.verify_split_quality(train_df, test_df)
217
+
218
+ train_path, val_path = self.save_splits(train_df, val_df, output_dir)
219
+ if save_test:
220
+ test_path = self.save_test_split(test_df, output_dir)
221
+ else:
222
+ test_path = None
223
+
224
+ print("\n" + "="*80)
225
+ print("Split Complete! (train / val / test)")
226
+ print("="*80)
227
+ return train_path, val_path, test_path
228
+
229
+ def save_test_split(self, test_df, output_dir):
230
+ output_dir = Path(output_dir)
231
+ output_dir.mkdir(exist_ok=True)
232
+ test_path = output_dir / 'test_ready.csv'
233
+
234
+ cols_to_drop = ['patient_id', 'study_id']
235
+ test_clean = test_df.drop(columns=[c for c in cols_to_drop if c in test_df.columns])
236
+ test_clean.to_csv(test_path, index=False)
237
+
238
+ print(f"Test set : {test_path} ({len(test_clean)} images)")
239
+ return test_path
240
+
241
+ def verify_split_quality(self, train_df, val_df):
242
+ """
243
+ Verify the quality of the split by comparing label distributions.
244
+ """
245
+ print("\n" + "=" * 80)
246
+ print("Split Quality Verification")
247
+ print("=" * 80)
248
+
249
+ print(f"\n{'Pathology':<30s} {'Train %':>10s} {'Val %':>10s} {'Difference':>12s}")
250
+ print("-" * 80)
251
+
252
+ max_diff = 0
253
+ for pathology in self.PATHOLOGIES:
254
+ if pathology in train_df.columns:
255
+ train_pos = (train_df[pathology] == 1).sum() / len(train_df) * 100
256
+ val_pos = (val_df[pathology] == 1).sum() / len(val_df) * 100
257
+ diff = abs(train_pos - val_pos)
258
+ max_diff = max(max_diff, diff)
259
+
260
+ print(f"{pathology:<30s} {train_pos:>9.2f}% {val_pos:>9.2f}% {diff:>11.2f}%")
261
+
262
+ print("-" * 80)
263
+ print(f"Maximum distribution difference: {max_diff:.2f}%")
264
+
265
+ if max_diff < 2.0:
266
+ print("✓ Excellent stratification (< 2% difference)")
267
+ elif max_diff < 5.0:
268
+ print("✓ Good stratification (< 5% difference)")
269
+ else:
270
+ print("⚠ Warning: Large distribution differences detected")
271
+
272
+ # Check for class imbalance
273
+ print("\n" + "=" * 80)
274
+ print("Class Imbalance Analysis (Train Set)")
275
+ print("=" * 80)
276
+
277
+ imbalance_ratios = []
278
+ for pathology in self.PATHOLOGIES:
279
+ if pathology in train_df.columns:
280
+ pos = (train_df[pathology] == 1).sum()
281
+ neg = (train_df[pathology] == 0).sum()
282
+ if pos > 0:
283
+ ratio = neg / pos
284
+ imbalance_ratios.append(ratio)
285
+ severity = "Low" if ratio < 5 else "Medium" if ratio < 20 else "High"
286
+ print(f"{pathology:<30s} Ratio: {ratio:>6.2f}:1 [{severity:>6s} imbalance]")
287
+
288
+ avg_imbalance = np.mean(imbalance_ratios)
289
+ print(f"\nAverage imbalance ratio: {avg_imbalance:.2f}:1")
290
+
291
+ def save_splits(self, train_df, val_df, output_dir='.'):
292
+ """Save train and validation splits to CSV files."""
293
+ output_dir = Path(output_dir)
294
+ output_dir.mkdir(exist_ok=True)
295
+
296
+ train_path = output_dir / 'train_ready.csv'
297
+ val_path = output_dir / 'val_ready.csv'
298
+
299
+ # Remove temporary columns used for splitting
300
+ columns_to_drop = ['patient_id', 'study_id']
301
+ train_df_clean = train_df.drop(columns=[col for col in columns_to_drop if col in train_df.columns])
302
+ val_df_clean = val_df.drop(columns=[col for col in columns_to_drop if col in val_df.columns])
303
+
304
+ train_df_clean.to_csv(train_path, index=False)
305
+ val_df_clean.to_csv(val_path, index=False)
306
+
307
+ print("\n" + "=" * 80)
308
+ print("Files Saved Successfully")
309
+ print("=" * 80)
310
+ print(f"Train set: {train_path} ({len(train_df_clean)} images)")
311
+ print(f"Val set: {val_path} ({len(val_df_clean)} images)")
312
+
313
+ return train_path, val_path
314
+
315
+ # Main execution
316
+ if __name__ == "__main__":
317
+ root = "/content/drive/MyDrive"
318
+ # Configuration
319
+ CHEXPERT_CSV = os.path.join(root,"CheXpert-v1.0-small","train.csv") # Adjust path as needed
320
+ OUTPUT_DIR = os.path.join(root,"CheXpert-v1.0-small")
321
+ VAL_SIZE = 0.15
322
+ RANDOM_STATE = 42
323
+ USE_FRONTAL_ONLY = True
324
+ FILL_UNCERTAIN = 'zeros' # Options: 'zeros', 'ones', 'ignore'
325
+
326
+ # Create splitter
327
+ splitter = CheXpertDataSplitter(
328
+ csv_path=CHEXPERT_CSV,
329
+ val_size=VAL_SIZE,test_size=VAL_SIZE,
330
+ random_state=RANDOM_STATE,
331
+ use_frontal_only=USE_FRONTAL_ONLY,
332
+ fill_uncertain=FILL_UNCERTAIN,
333
+ root=OUTPUT_DIR
334
+ )
335
+
336
+ # Run the split
337
+ if os.path.exists(os.path.join(root,"CheXpert-v1.0-small","train_ready.csv")) and os.path.exists(os.path.join(root,"CheXpert-v1.0-small","val_ready.csv")):
338
+ train_path=os.path.join(root,"CheXpert-v1.0-small","train_ready.csv")
339
+ val_path=os.path.join(root,"CheXpert-v1.0-small","val_ready.csv")
340
+ test_path=os.path.join(root,"CheXpert-v1.0-small","test_ready.csv")
341
+ else:
342
+ train_path, val_path,test_path = splitter.run(output_dir=OUTPUT_DIR)
343
+
344
+ print("\nYou can now use these files with your CheXpertDataset class:")
345
+ print(f" train_dataset = CheXpertDataset('{train_path}', root_dir='...', augment=True)")
346
+ print(f" val_dataset = CheXpertDataset('{val_path}', root_dir='...', augment=False)")
347
+ print(f" test_dataset = CheXpertDataset('{test_path}', root_dir='...', augment=False)")
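`perform_split` above uses a two-stage stratified split: first carve off a combined val+test pool, then split that pool so the final proportions come out right. A minimal sketch with hypothetical patient IDs (sklearn only):

```python
from sklearn.model_selection import train_test_split

# Hypothetical patients with a balanced binary stratification label
patients = [f"patient{i:03d}" for i in range(100)]
strata = [i % 2 for i in range(100)]

VAL_SIZE, TEST_SIZE = 0.15, 0.15

# Stage 1: train vs. combined val+test pool (stratified)
train_p, pool_p, _, pool_s = train_test_split(
    patients, strata,
    test_size=VAL_SIZE + TEST_SIZE,
    stratify=strata, random_state=42,
)

# Stage 2: split the pool; the test fraction is relative to the pool size,
# hence the division by (VAL_SIZE + TEST_SIZE)
val_p, test_p = train_test_split(
    pool_p,
    test_size=TEST_SIZE / (VAL_SIZE + TEST_SIZE),
    stratify=pool_s, random_state=42,
)

print(len(train_p), len(val_p), len(test_p))  # 70 15 15
```

Because the unit being split is the patient ID, all images of a patient land in exactly one of the three sets, which is what prevents leakage.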
gitignore.txt ADDED
@@ -0,0 +1,61 @@
1
2
+ # Python
3
+ __pycache__/
4
+ *.py[cod]
5
+ *$py.class
6
+ *.so
7
+ .Python
8
+ env/
9
+ venv/
10
+ ENV/
11
+ .venv
12
+
13
+ # Jupyter Notebook
14
+ .ipynb_checkpoints
15
+ *.ipynb_checkpoints/
16
+
17
+ # PyTorch
18
+ *.ckpt
19
+ *.pth
20
+ weights/
21
+ runs/
22
+ lightning_logs/
23
+
24
+ # Data files (usually too large for GitHub)
25
+ *.csv
26
+ *.h5
27
+ *.hdf5
28
+ *.npy
29
+ *.npz
30
+ *.pkl
31
+ *.pickle
32
+ *.dcm
33
+ *.nii
34
+ *.nii.gz
35
+
36
+ # Models (often too large)
37
+ *.h5
38
+ *.pb
39
+ *.onnx
40
+ saved_models/
41
+
42
+ # IDE
43
+ .vscode/
44
+ .idea/
45
+ *.swp
46
+ *.swo
47
+
48
+ # OS
49
+ .DS_Store
50
+ Thumbs.db
51
+
52
+ # Environment variables
53
+ .env
54
+ .env.local
55
+
56
+ # Logs
57
+ *.log
58
+ logs/
59
+
60
+ # Weights & Biases (if you use it)
61
+ wandb/
loss/__init__.py ADDED
File without changes
loss/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (149 Bytes). View file
 
loss/__pycache__/assymetric.cpython-313.pyc ADDED
Binary file (2.98 kB). View file
 
loss/assymetric.py ADDED
@@ -0,0 +1,59 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class AsymmetricLoss(nn.Module):
5
+ def __init__(self, gamma_neg=2, gamma_pos=1, clip=0.05, eps=1e-8, class_weights=None):
6
+ super().__init__()
7
+ self.gamma_neg = gamma_neg
8
+ self.gamma_pos = gamma_pos
9
+ self.clip = clip
10
+ self.eps = eps
11
+ if class_weights is not None:
12
+ self.register_buffer('class_weights', class_weights)
13
+ else:
14
+ self.class_weights = None
15
+
16
+ def forward(self, predictions, targets):
17
+ """
18
+ FIXED VERSION with better numerical stability
19
+ predictions: (B, 14) - sigmoid outputs (already applied!)
20
+ targets: (B, 14) - binary labels
21
+ """
22
+ try:
23
+ # CRITICAL FIX: Better clamping range
24
+ predictions = torch.clamp(predictions, min=self.eps, max=1 - self.eps)
25
+
26
+ # ===== POSITIVE SAMPLES =====
27
+ predictions_pos = torch.clamp(predictions - self.clip, min=self.eps)
28
+ focal_weight_pos = (1 - predictions_pos) ** self.gamma_pos
29
+
30
+ # FIX: Add small epsilon to prevent log(0)
31
+ loss_pos = targets * focal_weight_pos * torch.log(predictions_pos + self.eps)
32
+
33
+ # ===== NEGATIVE SAMPLES =====
34
+ focal_weight_neg = predictions ** self.gamma_neg
35
+
36
+ # FIX: Add small epsilon to prevent log(0)
37
+ loss_neg = (1 - targets) * focal_weight_neg * torch.log(1 - predictions + self.eps)
38
+
39
+ # ===== COMBINE =====
40
+ loss = -(loss_pos + loss_neg)
41
+
42
+ # Apply per-class weights
43
+ if self.class_weights is not None:
44
+ loss = loss * self.class_weights
45
+
46
+ # Average across batch and classes
47
+ loss = torch.mean(loss)
48
+
49
+ # CRITICAL: Check for NaN and return safe value
50
+ if torch.isnan(loss) or torch.isinf(loss):
51
+ raise ValueError("Loss is NaN or Inf")
52
+ except ValueError as e:
53
+ print("⚠️ WARNING: NaN/Inf detected in loss, returning safe value")
54
+ print(e)
55
+ print("predictions:", predictions)
56
+ print("targets:", targets)
57
+ import traceback
58
+ traceback.print_exc()
59
+ return torch.tensor(0.0, device=predictions.device, requires_grad=True)
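The asymmetry in `AsymmetricLoss` — easy negatives damped by `p ** gamma_neg`, positives shifted down by `clip` before their focal term — reduces to a few scalar operations. A numpy sketch of the same formula (not the module above, and omitting the per-class weights):

```python
import numpy as np

def asl(p, y, gamma_pos=1, gamma_neg=2, clip=0.05, eps=1e-8):
    """Asymmetric loss on an already-sigmoided probability p with binary label y."""
    p = np.clip(p, eps, 1 - eps)
    p_pos = np.clip(p - clip, eps, None)  # probability shifting for positives
    loss_pos = y * (1 - p_pos) ** gamma_pos * np.log(p_pos + eps)
    loss_neg = (1 - y) * p ** gamma_neg * np.log(1 - p + eps)
    return -(loss_pos + loss_neg)

# An easy negative (p=0.1, y=0): the focal factor p**2 shrinks its loss
# far below the plain BCE term -log(1 - p)
print(asl(0.1, 0), -np.log(1 - 0.1))
```

With `gamma_neg > gamma_pos`, the abundant confident negatives contribute far less than under plain BCE, which is the motivation for using it on heavily imbalanced CheXpert labels.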
models/__init__.py ADDED
File without changes
models/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (151 Bytes). View file
 
models/__pycache__/classifier.cpython-313.pyc ADDED
Binary file (17.6 kB). View file
 
models/__pycache__/densenet.cpython-313.pyc ADDED
Binary file (13.4 kB). View file
 
models/__pycache__/mae.cpython-313.pyc ADDED
Binary file (13.6 kB). View file
 
models/classifier.py ADDED
@@ -0,0 +1,323 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+ from models.mae import MaskedAutoEncoder
7
+ from models.densenet import DenseNet
8
+
9
+ class AttentionPool(nn.Module):
10
+ def __init__(self, dim=768, embed_dim=2048, num_heads=8):
11
+ super().__init__()
12
+ self.query = nn.Parameter(torch.randn(1, 1, dim))
13
+ self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
14
+ self.proj = nn.Linear(dim, embed_dim)
15
+
16
+ def forward(self, x): # x: (B, 576, 768)
17
+ B = x.size(0)
18
+ q = self.query.expand(B, -1, -1) # (B, 1, 768)
19
+ attn_out, _ = self.attn(q, x, x) # (B, 1, 768)
20
+ return self.proj(attn_out.squeeze(1)) # (B, 2048)
21
+
22
+ class CrossAttentionBlock(nn.Module):
23
+ """
24
+ Cross-attention: Query tokens attend to Key/Value tokens from another modality.
25
+ """
26
+ def __init__(self, dim_q, dim_kv, num_heads=8, dropout=0.1, proj_dim=None):
27
+ super().__init__()
28
+ self.proj_dim = proj_dim or dim_q
29
+ self.num_heads = num_heads
30
+ self.head_dim = self.proj_dim // num_heads
31
+ self.scale = self.head_dim ** -0.5
32
+
33
+ self.q_proj = nn.Linear(dim_q, self.proj_dim)
34
+ self.k_proj = nn.Linear(dim_kv, self.proj_dim)
35
+ self.v_proj = nn.Linear(dim_kv, self.proj_dim)
36
+ self.out_proj = nn.Linear(self.proj_dim, dim_q)
37
+
38
+ self.dropout = nn.Dropout(dropout)
39
+ self.norm_q = nn.LayerNorm(dim_q)
40
+ self.norm_kv = nn.LayerNorm(dim_kv)
41
+
42
+ def forward(self, query, key_value):
43
+ B, N_q, _ = query.shape
44
+ N_kv = key_value.shape[1]
45
+
46
+ q = self.norm_q(query)
47
+ kv = self.norm_kv(key_value)
48
+
49
+ Q = self.q_proj(q).view(B, N_q, self.num_heads, self.head_dim).transpose(1, 2)
50
+ K = self.k_proj(kv).view(B, N_kv, self.num_heads, self.head_dim).transpose(1, 2)
51
+ V = self.v_proj(kv).view(B, N_kv, self.num_heads, self.head_dim).transpose(1, 2)
52
+
53
+ attn = (Q @ K.transpose(-2, -1)) * self.scale
54
+ attn = F.softmax(attn, dim=-1)
55
+ attn = self.dropout(attn)
56
+
57
+ out = (attn @ V).transpose(1, 2).reshape(B, N_q, self.proj_dim)
58
+ out = self.out_proj(out)
59
+
60
+ return query + self.dropout(out)
61
+
62
+
63
+ class BidirectionalCrossAttention(nn.Module):
64
+ """
65
+ Bidirectional: MAE attends to DenseNet AND DenseNet attends to MAE.
66
+ """
67
+ def __init__(self, mae_dim=768, dense_dim=2048, num_heads=8, dropout=0.1, proj_dim=512):
68
+ super().__init__()
69
+
70
+ # MAE queries DenseNet
71
+ self.mae_cross = CrossAttentionBlock(mae_dim, dense_dim, num_heads, dropout, proj_dim)
72
+ # DenseNet queries MAE
73
+ self.dense_cross = CrossAttentionBlock(dense_dim, mae_dim, num_heads, dropout, proj_dim)
74
+
75
+ # FFN blocks
76
+ self.mae_ffn = nn.Sequential(
77
+ nn.LayerNorm(mae_dim),
78
+ nn.Linear(mae_dim, mae_dim * 4),
79
+ nn.GELU(),
80
+ nn.Dropout(dropout),
81
+ nn.Linear(mae_dim * 4, mae_dim),
82
+ nn.Dropout(dropout)
83
+ )
84
+ self.dense_ffn = nn.Sequential(
85
+ nn.LayerNorm(dense_dim),
86
+ nn.Linear(dense_dim, dense_dim * 2),
87
+ nn.GELU(),
88
+ nn.Dropout(dropout),
89
+ nn.Linear(dense_dim * 2, dense_dim),
90
+ nn.Dropout(dropout)
91
+ )
92
+
93
+ def forward(self, mae_tokens, dense_tokens):
94
+ # Cross attention
95
+ mae_out = self.mae_cross(mae_tokens, dense_tokens)
96
+ dense_out = self.dense_cross(dense_tokens, mae_tokens)
97
+
98
+ # FFN with residual
99
+ mae_out = mae_out + self.mae_ffn(mae_out)
100
+ dense_out = dense_out + self.dense_ffn(dense_out)
101
+
102
+ return mae_out, dense_out
103
+ class LearnedLogitEnsemble(nn.Module):
104
+ def __init__(self, num_heads=7, num_classes=14, temperature_init=1.0, use_gate=False):
105
+ super().__init__()
106
+ self.num_classes = num_classes
107
+ self.num_heads = num_heads
108
+
109
+ # 1. Per-head temperature (very important!)
110
+ self.log_temps = nn.Parameter(torch.ones(num_heads) * math.log(temperature_init))
111
+
112
+ # 2. Learned head weights via tiny gating network (best version)
113
+ # Input = concatenated logits (or probs) → predicts soft weights
114
+ gate_input_dim = num_classes * num_heads # concatenating raw logits works best
115
+ self.use_gate = use_gate
116
+
117
+ if use_gate:
118
+ self.gate = nn.Sequential(
119
+ nn.Linear(gate_input_dim, 256),
120
+ nn.GELU(),
121
+ nn.LayerNorm(256),
122
+ nn.Dropout(0.1),
123
+ nn.Linear(256, num_heads),
124
+ )
125
+ else:
126
+ # Simpler: just learn fixed weights + L2 regularization later
127
+ self.raw_weights = nn.Parameter(torch.ones(num_heads))
128
+
129
+ def forward(self, logits_list):
130
+ """
131
+ logits_list: list/tuple of 7 tensors, each (B, 14)
132
+ """
133
+ B = logits_list[0].size(0)
134
+ device = logits_list[0].device
135
+
136
+ # Step 1: Temperature scaling per head
137
+ scaled_logits = []
138
+ for i, logits in enumerate(logits_list):
139
+ T = torch.exp(self.log_temps[i]) # >0 guaranteed
140
+ scaled_logits.append(logits / (T + 1e-8))
141
+
142
+ # Stack → (B, num_heads, num_classes)
143
+ stacked = torch.stack(scaled_logits, dim=1) # (B, 7, 14)
144
+
145
+ if self.use_gate:
146
+ # Step 2: Dynamic gating (sample-wise & class-wise aware)
147
+ gate_in = stacked.flatten(1) # (B, 7*14)
148
+ raw_gate = self.gate(gate_in) # (B, 7)
149
+ weights = torch.softmax(raw_gate, dim=-1).unsqueeze(-1) # (B,7,1)
150
+ else:
151
+ # Step 2: Fixed learned weights (still strong!)
152
+ weights = torch.softmax(self.raw_weights, dim=0) # (7,)
153
+ weights = weights.view(1, self.num_heads, 1).to(device) # (1,7,1)
154
+
155
+ # Step 3: Weighted average in logit space
156
+ fused_logits = (stacked * weights).sum(dim=1) # (B, 14)
157
+
158
+ return fused_logits
159
+ class XRAYClassifier(nn.Module):
160
+ def __init__(self, num_classes=14, c=1, mask_ratio=0, dropout=0.25, img_size=384,
161
+ encoder_dim=768, mlp_dim=3072, decoder_dim=512, encoder_depth=12,
162
+ encoder_head=8, decoder_depth=8, decoder_head=8, patch_size=8):
163
+ super().__init__()
164
+
165
+ # ---- MAE branch (frozen) ----
166
+ self.mae = MaskedAutoEncoder(
167
+ c=c, mask_ratio=0, dropout=dropout, img_size=img_size,
168
+ encoder_dim=encoder_dim, mlp_dim=mlp_dim, decoder_dim=decoder_dim,
169
+ encoder_depth=encoder_depth, encoder_head=encoder_head,
170
+ decoder_depth=decoder_depth, decoder_head=decoder_head, patch_size=patch_size
171
+ )
172
+ for p in self.mae.parameters():
173
+            p.requires_grad = False
+
+        self.token_ln = nn.LayerNorm(encoder_dim)
+        self.attn_selfpool_mae = AttentionPool(encoder_dim, 1024)
+
+        # ---- DenseNet branch (pretrained separately) ----
+        # If the DenseNet supports 1 channel, set c=1 and drop the input duplication in forward().
+        self.dense = DenseNet(c=2, k=64, num_classes=num_classes)
+
+        self.dn_feat_dim = 2048
+
+        # ---- Cross-attention fusion ----
+        self.cross_attn_layers = nn.ModuleList([
+            BidirectionalCrossAttention(
+                mae_dim=encoder_dim,         # 768
+                dense_dim=self.dn_feat_dim,  # 2048
+                num_heads=8,
+                dropout=0.1,
+                proj_dim=512
+            )
+            for _ in range(12)
+        ])
+
+        self.attn_pool_mae = AttentionPool(encoder_dim, 1024)
+
+        self.classifier_mae = nn.Sequential(
+            nn.Linear(1024, 512),
+            nn.GELU(),
+            nn.Dropout(0.1),
+            nn.Linear(512, num_classes),
+        )
+
+        self.attn_pool_dense = AttentionPool(self.dn_feat_dim, 1024)
+
+        self.classifier_attn = nn.Sequential(
+            nn.Linear(2048, 1024),
+            nn.GELU(),
+            nn.Dropout(0.2),
+            nn.Linear(1024, 512),
+            nn.GELU(),
+            nn.Dropout(0.1),
+            nn.Linear(512, num_classes),
+        )
+
+        # ---- FPN laterals (1x1 convs project each DenseNet stage to 256 channels) ----
+        self.lateral5 = nn.Conv2d(2048, 256, kernel_size=1)  # feat4: 2048 ch, 24x24
+        self.lateral4 = nn.Conv2d(2048, 256, kernel_size=1)  # feat3: 2048 ch, 48x48
+        self.lateral3 = nn.Conv2d(1024, 256, kernel_size=1)  # feat2: 1024 ch, 96x96
+        self.lateral2 = nn.Conv2d(512, 256, kernel_size=1)   # feat1: 512 ch, 192x192
+        self.output5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+        self.output4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+        self.output3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+        self.output2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
+
+        self._classify_out5 = nn.Linear(256, num_classes)
+        self._classify_out4 = nn.Linear(256, num_classes)
+        self._classify_out3 = nn.Linear(256, num_classes)
+        self._classify_out2 = nn.Linear(256, num_classes)
+
+        self.learned_logit_ensemble = LearnedLogitEnsemble(num_classes=num_classes)
+
+    def forward(self, x):
+        mae_tokens, _, _, _ = self.mae.encoder(x)
+        mae_tokens = self.token_ln(mae_tokens)
+        doublex = torch.cat([x, x], dim=1)  # [B, 2, 384, 384]
+
+        # ---- DenseNet path: extract multi-scale features ----
+        xdense = self.dense.initialconv(doublex)  # [B, 128, 192, 192]
+
+        # Layer 1 + CBAM (taken before the transition so the FPN sees full resolution)
+        feat1 = self.dense.layer1(xdense)
+        feat1 = self.dense.dropout1(feat1)
+        feat1 = self.dense.eca1(feat1)      # [B, 512, 192, 192]
+        xdense1 = self.dense.trans1(feat1)  # [B, 256, 96, 96]
+
+        # Layer 2 + CBAM (before transition)
+        feat2 = self.dense.layer2(xdense1)
+        feat2 = self.dense.dropout2(feat2)
+        feat2 = self.dense.eca2(feat2)      # [B, 1024, 96, 96]
+        xdense2 = self.dense.trans2(feat2)  # [B, 512, 48, 48]
+
+        # Layer 3 + CBAM (before transition)
+        feat3 = self.dense.layer3(xdense2)
+        feat3 = self.dense.dropout3(feat3)
+        feat3 = self.dense.eca3(feat3)      # [B, 2048, 48, 48]
+        xdense3 = self.dense.trans3(feat3)  # [B, 1024, 24, 24]
+
+        # Layer 4 (no transition)
+        feat4 = self.dense.layer4(xdense3)
+        feat4 = self.dense.dropout4(feat4)
+        feat4 = self.dense.eca4(feat4)      # [B, 2048, 24, 24]
+        xdense4 = feat4
+
+        # Global pooling for the DenseNet classifier head
+        xdense_pooled = self.dense.global_average_pool(xdense4)
+        xdense_pooled = xdense_pooled.view(xdense_pooled.size(0), -1)
+        xdense_pooled = self.dense.dropout(xdense_pooled)
+        classifier_xdense = self.dense.classifier(xdense_pooled)
+
+        # Dense tokens for cross-attention
+        dense_tokens = xdense4.flatten(2).transpose(1, 2)  # [B, 576, 2048]
+
+        # ---- FPN over the multi-scale features ----
+        c4 = self.lateral5(feat4)  # [B, 2048, 24, 24]  -> [B, 256, 24, 24]
+        c3 = self.lateral4(feat3)  # [B, 2048, 48, 48]  -> [B, 256, 48, 48]
+        c2 = self.lateral3(feat2)  # [B, 1024, 96, 96]  -> [B, 256, 96, 96]
+        c1 = self.lateral2(feat1)  # [B, 512, 192, 192] -> [B, 256, 192, 192]
+
+        # Top-down pathway: upsample, add the matching lateral, smooth with 3x3 conv
+        p4 = self.output5(c4)                      # 24x24
+        p3 = self.output4(self.upsample(p4) + c3)  # 48x48
+        p2 = self.output3(self.upsample(p3) + c2)  # 96x96
+        p1 = self.output2(self.upsample(p2) + c1)  # 192x192
+
+        # Per-scale classification heads (global average over spatial dims)
+        out4 = self._classify_out5(p4.mean([2, 3]))
+        out3 = self._classify_out4(p3.mean([2, 3]))
+        out2 = self._classify_out3(p2.mean([2, 3]))
+        out1 = self._classify_out2(p1.mean([2, 3]))
+
+        # ---- MAE path ----
+        mae_tokens_pooled = self.attn_selfpool_mae(mae_tokens)
+        classifier_mae = self.classifier_mae(mae_tokens_pooled)
+
+        # ---- Cross attention (feed each layer's output into the next;
+        #      previously each layer re-read the raw tokens and discarded prior outputs) ----
+        mae_cross, dense_cross = mae_tokens, dense_tokens
+        for cross_layer in self.cross_attn_layers:
+            mae_cross, dense_cross = cross_layer(mae_cross, dense_cross)
+
+        mae_cross = self.attn_pool_mae(mae_cross)
+        dense_cross = self.attn_pool_dense(dense_cross)
+        out = torch.cat([mae_cross, dense_cross], dim=1)
+        classifier_attn = self.classifier_attn(out)
+
+        # ---- Ensemble ----
+        merged_classifier = self.learned_logit_ensemble([
+            classifier_mae,
+            classifier_xdense,
+            classifier_attn,
+            out4, out3, out2, out1  # 7 heads
+        ])
+
+        return merged_classifier
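As a quick sanity check on the dimensions the fusion code above relies on, the token counts and fused width can be derived with plain arithmetic. This is a stdlib-only sketch; the 384×384 input, 8×8 MAE patches, and the 16× DenseNet downsampling are taken from the shape comments in the code.

```python
# Shape bookkeeping for the fusion classifier, assuming a 384x384 input,
# 8x8 MAE patches, and a DenseNet final stride of 16 (384 -> 24).
img_size, patch_size = 384, 8

mae_tokens = (img_size // patch_size) ** 2   # tokens seen by the MAE branch
dense_tokens = (img_size // 16) ** 2         # 24*24 spatial positions, flattened
fused_dim = 1024 + 1024                      # attn_pool_mae + attn_pool_dense outputs

print(mae_tokens, dense_tokens, fused_dim)   # matches [B, 2304, 768], [B, 576, 2048], Linear(2048, ...)
```

The fused width of 2048 is exactly the input size of `classifier_attn`, and 576 matches the `[B, 576, 2048]` comment on `dense_tokens`.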
models/densenet.py ADDED
@@ -0,0 +1,157 @@
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+class ChannelAttention(nn.Module):
+    def __init__(self, channels, reduction=16):
+        super().__init__()
+        self.conv1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, bias=False)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, bias=False)
+        self.sigmoid = nn.Sigmoid()
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.maxpool = nn.AdaptiveMaxPool2d((1, 1))
+
+    def forward(self, x):
+        identity = x
+        avgpool = self.avgpool(x)
+        maxpool = self.maxpool(x)
+        avgpool = self.relu(self.conv1(avgpool))
+        maxpool = self.relu(self.conv1(maxpool))
+        avgpool = self.conv2(avgpool)
+        maxpool = self.conv2(maxpool)
+        out = self.sigmoid(avgpool + maxpool)
+        return identity * out
+
+class SpatialAttention(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)
+
+    def forward(self, x):
+        max_pool = torch.max(x, dim=1, keepdim=True)[0]
+        avg_pool = torch.mean(x, dim=1, keepdim=True)
+        attention = torch.cat([max_pool, avg_pool], dim=1)
+        attention = torch.sigmoid(self.conv(attention))
+        return x * attention
+
+class CBAM(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.ca = ChannelAttention(channels)
+        self.sa = SpatialAttention()
+
+    def forward(self, x):
+        x = self.ca(x)
+        x = self.sa(x)
+        return x
+
+class InitialConv(nn.Module):
+    def __init__(self, input_channel=1, k=64):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels=input_channel, out_channels=2 * k, kernel_size=7, stride=1, padding=3)  # [B, 1, 384, 384] -> [B, 128, 384, 384]
+        self.bn = nn.BatchNorm2d(num_features=2 * k)
+        self.relu = nn.ReLU(inplace=True)
+        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # 384 -> 192; output [B, 128, 192, 192]
+
+    def forward(self, x):
+        return self.pool(self.relu(self.bn(self.conv(x))))
+
+class DenseLayer(nn.Module):
+    def __init__(self, c, k=64):
+        super().__init__()
+        self.bn1 = nn.BatchNorm2d(num_features=c)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.conv1x1 = nn.Conv2d(c, 4 * k, kernel_size=1)
+        self.bn2 = nn.BatchNorm2d(num_features=4 * k)
+        self.relu2 = nn.ReLU(inplace=True)
+        self.conv3x3 = nn.Conv2d(4 * k, k, kernel_size=3, padding=1)
+
+    def forward(self, x):
+        identity = x
+        x = self.conv1x1(self.relu1(self.bn1(x)))
+        x = self.conv3x3(self.relu2(self.bn2(x)))
+        return torch.cat([identity, x], dim=1)
+
+class DenseBlock(nn.Module):
+    def __init__(self, c, k=64, layer_len=6):
+        super().__init__()
+        self.blks = nn.ModuleList()
+        current_c = c
+        for _ in range(layer_len):
+            self.blks.append(DenseLayer(current_c, k))
+            current_c += k
+
+    def forward(self, x):
+        # Gradient checkpointing trades recomputation for memory on these wide feature maps.
+        for layer in self.blks:
+            x = checkpoint(layer, x, use_reentrant=False)
+        return x
+
+class Transition(nn.Module):
+    def __init__(self, inchannels, down_factor=0.5):
+        super().__init__()
+        self.bn = nn.BatchNorm2d(num_features=inchannels)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv1x1 = nn.Conv2d(in_channels=inchannels, out_channels=int(down_factor * inchannels), kernel_size=1)
+        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)
+
+    def forward(self, x):
+        return self.avgpool(self.conv1x1(self.relu(self.bn(x))))
+
+class DenseNet(nn.Module):
+    def __init__(self, c=2, k=64, num_classes=14):
+        super().__init__()
+        self.initialconv = InitialConv(input_channel=c, k=k)        # output [B, 128, 192, 192]
+        self.layer1 = DenseBlock(c=128, k=k, layer_len=6)           # 128 + 6*64   -> [B, 512, 192, 192]
+        self.dropout1 = nn.Dropout(p=0.05)
+        self.eca1 = CBAM(512)
+        self.trans1 = Transition(inchannels=512, down_factor=0.5)   # output [B, 256, 96, 96]
+        self.layer2 = DenseBlock(c=256, k=k, layer_len=12)          # 256 + 12*64  -> [B, 1024, 96, 96]
+        self.dropout2 = nn.Dropout(p=0.1)
+        self.eca2 = CBAM(1024)
+        self.trans2 = Transition(inchannels=1024, down_factor=0.5)  # output [B, 512, 48, 48]
+        self.layer3 = DenseBlock(c=512, k=k, layer_len=24)          # 512 + 24*64  -> [B, 2048, 48, 48]
+        self.dropout3 = nn.Dropout(p=0.1)
+        self.eca3 = CBAM(2048)
+        self.trans3 = Transition(inchannels=2048, down_factor=0.5)  # output [B, 1024, 24, 24]
+        self.layer4 = DenseBlock(c=1024, k=k, layer_len=16)         # 1024 + 16*64 -> [B, 2048, 24, 24]
+        self.dropout4 = nn.Dropout(p=0.1)
+        self.eca4 = CBAM(2048)
+        self.global_average_pool = nn.AdaptiveAvgPool2d((1, 1))     # output [B, 2048, 1, 1]
+        self.classifier = nn.Sequential(
+            nn.Linear(2048, 1024),
+            nn.BatchNorm1d(1024),
+            nn.ReLU(),
+            nn.Dropout(0.1),
+            nn.Linear(1024, 512),
+            nn.BatchNorm1d(512),
+            nn.ReLU(),
+            nn.Dropout(0.1),
+            nn.Linear(512, 256),
+            nn.BatchNorm1d(256),
+            nn.ReLU(),
+            nn.Dropout(0.1),
+            nn.Linear(256, num_classes)
+        )
+        self.dropout = nn.Dropout(p=0.2)
+        for lay in self.classifier:
+            if isinstance(lay, nn.Linear):
+                nn.init.xavier_uniform_(lay.weight, gain=1.0)
+                nn.init.constant_(lay.bias, 0.0)
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+
+    def forward(self, x):
+        x = self.initialconv(x)
+        x = self.trans1(self.eca1(self.dropout1(self.layer1(x))))
+        x = self.trans2(self.eca2(self.dropout2(self.layer2(x))))
+        x = self.trans3(self.eca3(self.dropout3(self.layer3(x))))
+        x = self.eca4(self.dropout4(self.layer4(x)))
+        x = self.global_average_pool(x)
+        x = x.view(x.size(0), -1)
+        x = self.dropout(x)
+        x = self.classifier(x)
+        return x
+
+    @staticmethod
+    def testme():
+        model = DenseNet()
+        sample = torch.randn(2, 2, 384, 384)
+        out = model(sample)
+        print(out.shape)
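The channel counts noted in the `DenseNet.__init__` comments follow from two rules: a `DenseBlock` grows channels by `layer_len * k` through concatenation, and each `Transition` halves them (`down_factor=0.5`). A stdlib-only sketch of that arithmetic, using the stage configuration from the code:

```python
# Channel bookkeeping for the DenseNet above (growth rate k=64,
# stages of 6/12/24/16 layers, transitions after the first three stages).
k = 64
c = 2 * k                    # InitialConv output: 128 channels
stages = [6, 12, 24, 16]     # layer_len per DenseBlock
channels = []
for i, layer_len in enumerate(stages):
    c = c + layer_len * k    # DenseBlock: concatenation grows channels
    channels.append(c)
    if i < 3:                # trans1..trans3 (no transition after layer4)
        c = c // 2           # down_factor=0.5
print(channels)              # per-block output channels: [512, 1024, 2048, 2048]
```

These are exactly the widths the CBAM modules (`eca1..eca4`) and the FPN laterals in the classifier are built around.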
models/mae.py ADDED
@@ -0,0 +1,177 @@
+import torch
+import torch.nn as nn
+import math
+
+def patchify(x, patch_size=8):
+    b, c, h, w = x.shape
+    assert h % patch_size == 0 and w % patch_size == 0, "Image size must be divisible by patch_size"
+    th = h // patch_size
+    tw = w // patch_size
+    out = x.reshape(b, c, th, patch_size, tw, patch_size)
+    out = out.permute(0, 2, 4, 1, 3, 5).contiguous()
+    out = out.view(b, th * tw, c * (patch_size ** 2))
+    return out
+
+def unpatchify(x, patch_size=8):
+    # Assumes a square patch grid (th == tw).
+    b, z, p = x.shape
+    c = p // (patch_size ** 2)
+    th = int(math.sqrt(z))
+    tw = th
+    h = th * patch_size
+    w = tw * patch_size
+    x = x.view(b, th, tw, c, patch_size, patch_size)
+    x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
+    out = x.view(b, c, h, w)
+    return out
+
+def random_mask(x, mask_ratio=0.75):
+    b, n, p = x.shape
+    len_keep = int(n * (1 - mask_ratio))
+    noise = torch.rand(b, n, device=x.device)
+    ids_shuffle = torch.argsort(noise, dim=1)
+    ids_restore = torch.argsort(ids_shuffle, dim=1)
+    ids_keep = ids_shuffle[:, :len_keep]
+    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, p))
+    mask = torch.ones(b, n, device=x.device)
+    mask[:, :len_keep] = 0
+    mask = torch.gather(mask, dim=1, index=ids_restore)
+    return x_masked, mask, ids_restore, ids_keep
+
+def mae_loss(pred, target, mask):
+    # pred/target: (B, N, P); mask: (B, N) with 1 = masked
+    B, N, P = pred.shape
+    mask = mask.unsqueeze(-1).float()  # (B, N, 1)
+    loss = (pred - target) ** 2
+    loss = (loss * mask).sum() / mask.sum().clamp_min(1.0)
+    return loss
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, num_patches, hidden_dim=768):
+        super().__init__()
+        self.pos_embed = nn.Parameter(torch.empty(1, num_patches, hidden_dim))
+        nn.init.trunc_normal_(self.pos_embed, std=0.02)
+
+    def forward(self, x, visible_indices):
+        # x: (B, len_keep, D); visible_indices: (B, len_keep)
+        B, L, D = x.shape
+        pos = self.pos_embed.expand(B, -1, -1)                           # (B, N, D)
+        idx = visible_indices.unsqueeze(-1).expand(B, L, pos.size(-1))   # (B, L, D)
+        visible_pos = torch.gather(pos, 1, idx)                          # (B, L, D)
+        return x + visible_pos
+
+class TransformerBlock(nn.Module):
+    def __init__(self, hidden_dim, mlp_dim, num_heads, dropout):
+        super().__init__()
+        self.layernorm1 = nn.LayerNorm(hidden_dim)
+        self.multihead = nn.MultiheadAttention(batch_first=True, embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout)
+        self.layernorm2 = nn.LayerNorm(hidden_dim)
+        self.mlp = nn.Sequential(
+            nn.Linear(hidden_dim, mlp_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(mlp_dim, hidden_dim), nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        residual = x
+        x = self.layernorm1(x)
+        attn, _ = self.multihead(x, x, x)
+        x = residual + attn
+        residual = x
+        x = self.layernorm2(x)
+        x = self.mlp(x)
+        x = residual + x
+        return x
+
+class MAEEncoder(nn.Module):
+    """
+    patch_dim: flattened patch dimension, c * patch_size**2.
+    Only the visible (non-masked) patches are embedded and passed through the transformer.
+    """
+    def __init__(self, patch_dim, num_patches=(384 // 8) ** 2, hidden_dim=768, mlp_dim=768 * 4, num_heads=8, depth=12, dropout=0.25, mask_ratio=0.75, patch_size=8):
+        super().__init__()
+        self.mask_ratio = mask_ratio
+        self.patch_size = patch_size
+        self.patch_embed = nn.Linear(patch_dim, hidden_dim)
+        self.pos_embed = PositionalEncoding(num_patches=num_patches, hidden_dim=hidden_dim)
+        self.transformer = nn.ModuleList([
+            TransformerBlock(hidden_dim=hidden_dim, mlp_dim=mlp_dim, num_heads=num_heads, dropout=dropout)
+            for _ in range(depth)
+        ])
+        self._init_weights()
+
+    def _init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def forward(self, x_in):
+        x_p = patchify(x_in, self.patch_size)
+        x_masked, mask, ids_restore, ids_keep = random_mask(x_p, self.mask_ratio)
+        x = self.patch_embed(x_masked)
+        x = self.pos_embed(x, ids_keep)
+        for attn_layer in self.transformer:
+            x = attn_layer(x)
+        return x, mask, ids_keep, ids_restore
+
+class MAEDecoder(nn.Module):
+    def __init__(self, c, num_patches, patch_size, encoder_dim, decoder_dim, decoder_depth, mlp_dim, num_heads, dropout):
+        super().__init__()
+        self.num_patches = num_patches
+        self.encoder_dim = encoder_dim
+        self.decoder_dim = decoder_dim
+        self.mask_token = nn.Parameter(torch.empty(1, 1, decoder_dim))
+        self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim)
+        self.pos_embed = nn.Parameter(torch.empty(1, num_patches, decoder_dim))
+        self.transformer = nn.ModuleList([
+            TransformerBlock(hidden_dim=decoder_dim, mlp_dim=mlp_dim, num_heads=num_heads, dropout=dropout)
+            for _ in range(decoder_depth)
+        ])
+        self.layernorm = nn.LayerNorm(decoder_dim)
+        self.pred = nn.Linear(decoder_dim, c * (patch_size ** 2))
+        self._init_weights()
+
+    def _init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+        nn.init.trunc_normal_(self.pos_embed, std=0.02)
+        nn.init.trunc_normal_(self.mask_token, std=0.02)
+
+    def forward(self, x, ids_keep, ids_restore):
+        b, n, p = x.shape
+        xdec = self.enc_to_dec(x)
+        len_keep = xdec.size(1)
+        num_patches = ids_restore.size(1)
+        num_mask = num_patches - len_keep
+
+        # Append mask tokens, then unshuffle back to the original patch order.
+        mask_token = self.mask_token.expand(b, num_mask, -1)
+        x_ = torch.cat([xdec, mask_token], dim=1)
+        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, x_.size(-1)))
+        x_ = x_ + self.pos_embed
+        for block in self.transformer:
+            x_ = block(x_)
+        x_ = self.layernorm(x_)
+        out = self.pred(x_)
+        return out
+
+class MaskedAutoEncoder(nn.Module):
+    def __init__(self, c=1, mask_ratio=0.75, dropout=0.25, img_size=384, encoder_dim=768, mlp_dim=3072, decoder_dim=512, encoder_depth=12, encoder_head=8, decoder_depth=8, decoder_head=8, patch_size=8):
+        super().__init__()
+        self.patch_size = patch_size
+        self.encoder = MAEEncoder(patch_dim=c * (patch_size ** 2), num_patches=(img_size // patch_size) ** 2,
+                                  hidden_dim=encoder_dim, mlp_dim=mlp_dim, num_heads=encoder_head,
+                                  depth=encoder_depth, dropout=dropout, mask_ratio=mask_ratio, patch_size=patch_size)
+        self.decoder = MAEDecoder(c, num_patches=(img_size // patch_size) ** 2, patch_size=patch_size,
+                                  encoder_dim=encoder_dim, decoder_dim=decoder_dim, decoder_depth=decoder_depth,
+                                  mlp_dim=mlp_dim, num_heads=decoder_head, dropout=dropout)
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+        encoded, mask, ids_keep, ids_restore = self.encoder(x)
+        decoded = self.decoder(encoded, ids_keep, ids_restore)
+        xpatched = patchify(x, self.patch_size)
+        return xpatched, decoded, mask
+
+    @staticmethod
+    def testme():
+        img = torch.rand(1, 1, 384, 384)
+        mae = MaskedAutoEncoder()
+        a, b, c = mae(img)
+        print(a.shape)
+        print(b.shape)
+        print(c.shape)
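The sequence lengths flowing through the MAE above are fixed by three numbers: the image size, the patch size, and the mask ratio. A stdlib-only sketch of that bookkeeping, using the module defaults (`img_size=384`, `patch_size=8`, `c=1`, `mask_ratio=0.75`):

```python
# Token bookkeeping for the MAE, using the module defaults.
img_size, patch_size, c, mask_ratio = 384, 8, 1, 0.75

num_patches = (img_size // patch_size) ** 2     # sequence length after patchify
patch_dim = c * patch_size ** 2                 # flattened pixels per patch
len_keep = int(num_patches * (1 - mask_ratio))  # visible tokens the encoder sees

print(num_patches, patch_dim, len_keep)         # 2304, 64, 576
```

So the encoder processes only 576 of the 2304 patch tokens, and the decoder restores the other 1728 positions with the learned mask token before reconstruction.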
notebooks/chexpert_mae.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/chexpert_mae_mask_classifier.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,29 @@
+# Core Deep Learning
+torch>=2.0.0
+torchvision>=0.15.0
+
+# Data Processing
+numpy>=1.24.0
+pandas>=2.0.0
+scikit-learn>=1.3.0
+
+# Image Processing
+Pillow>=10.0.0
+opencv-python>=4.8.0
+albumentations>=1.3.1
+
+# Visualization
+matplotlib>=3.7.0
+seaborn>=0.12.0
+
+# Utilities
+tqdm>=4.65.0
+
+# Jupyter (optional - for notebooks)
+jupyter>=1.0.0
+ipykernel>=6.25.0
+ipywidgets>=8.1.0
+
+# Additional utilities (if needed)
+# lmdb>=1.4.0  # Uncomment if using LMDB for caching
+# tensorboard>=2.13.0  # Uncomment if using TensorBoard logging
results/test-results.docx ADDED
Binary file (7.54 kB). View file
 
trainer/__init__.py ADDED
File without changes
trainer/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (152 Bytes). View file
 
trainer/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (154 Bytes). View file
 
trainer/__pycache__/trainer.cpython-313.pyc ADDED
Binary file (1.01 kB). View file
 
trainer/__pycache__/trainer.cpython-314.pyc ADDED
Binary file (713 Bytes). View file
 
trainer/__pycache__/utils.cpython-313.pyc ADDED
Binary file (48.4 kB). View file
 
trainer/test.py ADDED
@@ -0,0 +1,15 @@
+from .utils import Trainer
+from configs.configs import root, config
+
+
+def main():
+    print("Testing classifier")
+    try:
+        tester = Trainer(config)
+        tester.test(model_path=config["resume"])
+    except Exception:
+        import traceback
+        traceback.print_exc()
+
+
+if __name__ == "__main__":
+    main()
trainer/trainer.py ADDED
@@ -0,0 +1,19 @@
+from .utils import *
+from configs.configs import root, config, mae_config
+
+def main():
+    try:
+        decision = input("train mae or classifier? ")
+        if decision == "mae":
+            print("Training mae")
+            trainer = MAETrainer(mae_config)
+            trainer.train()
+        elif decision == "classifier":
+            print("Training classifier")
+            trainer = Trainer(config)
+            trainer.train()
+    except Exception:
+        import traceback
+        traceback.print_exc()
+
+if __name__ == "__main__":
+    main()
trainer/utils.py ADDED
@@ -0,0 +1,837 @@
+from data.dataset import CheXpertDataset
+from loss.assymetric import AsymmetricLoss
+from models.mae import *
+from models.densenet import *
+from models.classifier import *
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+import gc
+import json
+import os
+import io
+import sys
+from sklearn.metrics import roc_auc_score, confusion_matrix
+
+class TeeFile:
+    """
+    File-like object that writes to multiple streams (e.g., stdout and a file).
+    String arguments are treated as paths and opened as files automatically.
+
+    Usage:
+        tee = TeeFile(sys.stdout, "/path/to/log.txt")
+        print("Hello", file=tee)  # Writes to both stdout and the file
+    """
+    def __init__(self, *file_objects_or_paths):
+        """
+        Args:
+            *file_objects_or_paths: Mix of file objects (like sys.stdout)
+                                    and string paths to log files
+        """
+        self.files = []
+        self.opened_files = []  # Track files we opened so we can close them later
+
+        for item in file_objects_or_paths:
+            if isinstance(item, str):
+                # It's a path string - open it as a file
+                f = open(item, 'a', buffering=1)  # Append mode, line buffered
+                self.files.append(f)
+                self.opened_files.append(f)
+            else:
+                # It's already a file-like object (e.g., sys.stdout)
+                self.files.append(item)
+
+    def write(self, data):
+        """Write data to all streams."""
+        for f in self.files:
+            try:
+                f.write(data)
+                f.flush()
+            except Exception as e:
+                # Handle closed files gracefully
+                print(f"Warning: Could not write to {f}: {e}", file=sys.stderr)
+
+    def flush(self):
+        """Flush all streams."""
+        for f in self.files:
+            try:
+                f.flush()
+            except Exception:
+                pass
+
+    def isatty(self):
+        """Check if any stream is a terminal (for tqdm compatibility)."""
+        return any(getattr(f, "isatty", lambda: False)() for f in self.files)
+
+    def fileno(self):
+        """Get a file descriptor from the first stream that has one."""
+        for f in self.files:
+            if hasattr(f, "fileno"):
+                try:
+                    return f.fileno()
+                except Exception:
+                    pass
+        raise io.UnsupportedOperation("No fileno available")
+
+    def close(self):
+        """Close any files we opened."""
+        for f in self.opened_files:
+            try:
+                f.close()
+            except Exception:
+                pass
+        self.opened_files.clear()
+
+    def __del__(self):
+        """Cleanup on deletion."""
+        self.close()
+
+    def __enter__(self):
+        """Context manager support."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Context manager cleanup."""
+        self.close()
+        return False
+
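The pattern `TeeFile` implements can be exercised without touching stdout or the filesystem by fanning writes out to in-memory streams. A minimal stdlib sketch of the same idea (`MiniTee` is a hypothetical stripped-down stand-in, not part of the repo):

```python
import io

class MiniTee:
    """Minimal version of the tee pattern: one write fans out to every stream."""
    def __init__(self, *streams):
        self.streams = streams
    def write(self, data):
        for s in self.streams:
            s.write(data)
    def flush(self):
        for s in self.streams:
            s.flush()

a, b = io.StringIO(), io.StringIO()
tee = MiniTee(a, b)
print("epoch 1: loss 0.42", file=tee)
print(a.getvalue() == b.getvalue())  # True: both streams received the same line
```

`TeeFile` layers path handling, `isatty`/`fileno` shims for tqdm, and cleanup on top of this core, which is why the trainers can pass it directly as the `file=` argument of `tqdm`.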
+class MAETrainer:
+    def __init__(self, configs={}):
+        self.configs = configs
+        os.makedirs(configs["logdir"], exist_ok=True)
+        log_path_train = os.path.join(configs["logdir"], "training_log.txt")
+        log_path_val = os.path.join(configs["logdir"], "val_log.txt")
+        log_path_test = os.path.join(configs["logdir"], "test_log.txt")
+        self.traintee = TeeFile(sys.stdout, log_path_train)
+        self.valtee = TeeFile(sys.stdout, log_path_val)
+        self.testtee = TeeFile(sys.stdout, log_path_test)
+
+        for dir in self.configs["dirsToMake"]:
+            os.makedirs(dir, exist_ok=True)
+
+        self.model = MaskedAutoEncoder(
+            c=configs["channels"],
+            mask_ratio=configs["mask_ratio"],
+            dropout=configs["dropout"],
+            img_size=configs["img_size"],
+            encoder_dim=configs["encoder_dim"],
+            mlp_dim=configs["mlp_dim"],
+            decoder_dim=configs["decoder_dim"],
+            encoder_depth=configs["encoder_depth"],
+            encoder_head=configs["encoder_head"],
+            decoder_depth=configs["decoder_depth"],
+            decoder_head=configs["decoder_head"],
+            patch_size=configs["patch_size"]
+        ).to(configs["device"])
+
+        self.criterion = mae_loss
+
+        # Linear warmup followed by cosine annealing.
+        self.optimizer = torch.optim.AdamW(self.model.parameters(), configs["lr"], weight_decay=configs["weight_decay"])
+        self.schedular1 = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=0.1, end_factor=1.0, total_iters=configs["warmup"])
+        self.schedular2 = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=configs["num_epochs"] - configs["warmup"])
+        self.schedular = torch.optim.lr_scheduler.SequentialLR(self.optimizer, schedulers=[self.schedular1, self.schedular2], milestones=[configs["warmup"]])
+        self.scaler = torch.amp.GradScaler()
+
+        self.train_dataset = CheXpertDataset(zip_path=configs["zip_path"], csv_path=configs["train_csv"], root_dir=configs["datadir"], augment=True, use_frontal_only=True)
+        self.val_dataset = CheXpertDataset(zip_path=configs["zip_path"], csv_path=configs["val_csv"], root_dir=configs["datadir"], augment=False, use_frontal_only=True)
+        self.class_Weights = self.train_dataset.get_class_weights().to(self.configs["device"])
+        self.sample_Weights = self.train_dataset.get_sample_weights()
+        self.sampler = torch.utils.data.WeightedRandomSampler(self.sample_Weights, num_samples=len(self.sample_Weights))
+        self.trainloader = DataLoader(self.train_dataset, batch_size=configs["batch_size"], sampler=self.sampler, num_workers=8, pin_memory=True, persistent_workers=True)
+        self.valloader = DataLoader(self.val_dataset, batch_size=configs["batch_size"], shuffle=False, num_workers=8, pin_memory=True, persistent_workers=True)
+        self.history = {"train_loss": [], "val_loss": []}
+
+        self.current_epoch = 0
+
+        if os.path.exists(self.configs["resume"]):
+            loadedpickle = torch.load(self.configs["resume"], map_location=self.configs["device"])
+            self.model.load_state_dict(loadedpickle["model"], strict=False)
+            self.optimizer.load_state_dict(loadedpickle["optimizer"])
+            self.schedular.load_state_dict(loadedpickle["schedular"])
+            self.schedular1.load_state_dict(loadedpickle["schedular1"])
+            self.schedular2.load_state_dict(loadedpickle["schedular2"])
+            self.scaler.load_state_dict(loadedpickle["scaler"])
+            self.current_epoch = loadedpickle["epoch"] + 1
+
+        self.test_dataset = None
+        self.testloader = None
+        if configs.get("test_csv"):
+            self.test_dataset = CheXpertDataset(
+                zip_path=configs["zip_path"],
+                csv_path=configs["test_csv"],
+                root_dir=configs["datadir"],
+                augment=False,
+                use_frontal_only=True
+            )
+            self.testloader = DataLoader(
+                self.test_dataset,
+                batch_size=configs["batch_size"],
+                shuffle=False,
+                num_workers=8,
+                pin_memory=True,
+                persistent_workers=True
+            )
+            print(f"Test loader ready – {len(self.test_dataset)} images")
+
+        torch.backends.cudnn.benchmark = True
+        torch.backends.cudnn.enabled = True
+
+        # Memory allocator setting (note: only takes effect if set before CUDA is initialized)
+        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+
+        # Enable gradient checkpointing if the model supports it
+        if hasattr(self.model, 'enable_gradient_checkpointing'):
+            self.model.enable_gradient_checkpointing()
+
+    @staticmethod
+    def plot_training_metrics(metrics, epoch, figs_path):
+        """
+        Plot loss curves from training metrics.
+
+        Args:
+            metrics (dict): Dictionary containing a list per metric key:
+                {
+                    "train_loss": [...],
+                    "val_loss": [...]
+                }
+            epoch (int): Current epoch number (used for the output directory)
+            figs_path (str): Root directory for saved figures
+        """
+        import matplotlib.pyplot as plt
+
+        # Compute the common length across all series so the axes line up
+        # (the lists can diverge in length after a resume).
+        keys = ["train_loss", "val_loss"]
+        lengths = [len(metrics[k]) for k in keys if k in metrics]
+        if not lengths:
+            return
+        n = min(lengths)
+
+        # Slice everything to the same length
+        m = {k: metrics[k][:n] for k in keys if k in metrics}
+        epochs = list(range(1, n + 1))
+
+        plt.figure(figsize=(14, 6))
+
+        # ---- Loss subplot ----
+        plt.subplot(1, 2, 1)
+        plt.plot(epochs, m["train_loss"], label="Train Loss", marker='o')
+        plt.plot(epochs, m["val_loss"], label="Val Loss", marker='s')
+        plt.xlabel("Epoch")
+        plt.ylabel("Loss")
+        plt.title("Training & Validation Loss")
+        plt.legend()
+        plt.grid(True, linestyle='--', alpha=0.6)
+
+        plt.tight_layout()
+        os.makedirs(os.path.join(figs_path, str(epoch)), exist_ok=True)
+        plt.savefig(os.path.join(figs_path, str(epoch), "metrics.png"))
+        plt.show()
+
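The truncate-to-common-length guard in `plot_training_metrics` can be seen in isolation: every series is sliced to the shortest length before plotting, so a partially-appended metric (e.g. after an interrupted epoch) cannot desynchronize the x-axis. A stdlib-only sketch with made-up loss values:

```python
# Align metric series of unequal length before plotting,
# as the guard in plot_training_metrics does.
metrics = {"train_loss": [0.9, 0.7, 0.6], "val_loss": [0.95, 0.8]}
keys = ["train_loss", "val_loss"]

n = min(len(metrics[k]) for k in keys)     # common length across series
m = {k: metrics[k][:n] for k in keys}      # slice everything to length n
epochs = list(range(1, n + 1))             # x-axis matches the sliced series

print(n, m["train_loss"], epochs)          # 2 [0.9, 0.7] [1, 2]
```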
+    def train_epoch(self, epoch, looper):
+        self.model.train()
+        running_loss = 0.0
+        current_loss = 0
+        total_batches = len(self.trainloader)
+
+        for batch_idx, data in looper:
+            image = data['image'].to(self.configs["device"], non_blocking=True)
+            target = data['labels'].to(self.configs["device"], non_blocking=True)
+
+            with torch.autocast(device_type=self.configs["device"].type, dtype=torch.float16):
+                img, preds, mask = self.model(image)
+                loss = self.criterion(img, preds, mask)
+
+            loss_back = loss / self.configs["accumulation"]
+            running_loss += loss.item()
+
+            if torch.isfinite(loss):
+                self.scaler.scale(loss_back).backward()
+            else:
+                # Skip non-finite losses to keep AMP training stable
+                self.optimizer.zero_grad(set_to_none=True)
+                continue
+
+            if (batch_idx + 1) % self.configs["accumulation"] == 0 or batch_idx == total_batches - 1:
+                self.scaler.unscale_(self.optimizer)
+                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
+                self.schedular.step()
+                self.optimizer.zero_grad(set_to_none=True)
+
+            # === Live metrics (every 10 batches) ===
+            current_loss = running_loss / (batch_idx + 1)
+            if (batch_idx + 1) % 10 == 0:
+                current_lr = self.optimizer.param_groups[0]['lr']
+                looper.set_postfix({
+                    "lr": f"{current_lr:.2e}", "batch": f"{batch_idx}/{total_batches}",
+                    "epoch": f"{epoch}/{self.configs['num_epochs']}",
+                    "loss": f"{current_loss:.3f}",
+                })
+
+        return current_loss
+
+    def validate(self, epoch, looper):
+        self.model.eval()
+        val_loss = 0.0
+        lenloader = len(self.valloader)
+        current_loss = 0
+        with torch.no_grad():
+            for batch_idx, data in looper:
+                image = data["image"].to(self.configs["device"], non_blocking=True)
+                target = data["labels"].to(self.configs["device"], non_blocking=True)
+
+                with torch.autocast(device_type=self.configs["device"].type, dtype=torch.float16):
+                    img, preds, mask = self.model(image)
+                    loss = self.criterion(img, preds, mask)
+
+                val_loss += loss.item()
+
+                # === Live metrics ===
+                current_loss = val_loss / (batch_idx + 1)
+                if (batch_idx + 1) % 10 == 0:
+                    looper.set_postfix({
+                        "epoch": f"{epoch}/{self.configs['num_epochs']}",
+                        "batch": f"{batch_idx}/{lenloader}",
+                        "loss": f"{current_loss:.3f}",
+                    })
+
+        return current_loss
+
308
+ def train(self):
309
+
310
+ for epoch in range(self.current_epoch,self.configs["num_epochs"]):
311
+ trainlooper=tqdm(enumerate(self.trainloader),desc="training: ", leave=False,file=self.traintee)
312
+ vallooper=tqdm(enumerate(self.valloader),desc="validating: ",leave=False,file=self.valtee)
313
+
314
+
315
+ self.model.train()
316
+ self.optimizer.zero_grad(set_to_none=True)
317
+
318
+ running_loss=self.train_epoch(epoch,trainlooper)
319
+
320
+ torch.cuda.synchronize()
321
+ torch.cuda.empty_cache()
322
+
323
+ val_loss=self.validate(epoch,vallooper)
324
+
325
+ torch.cuda.synchronize()
326
+ torch.cuda.empty_cache()
327
+
328
+ gc.collect()
329
+
330
+ if (self.history["val_loss"] and (val_loss<min(self.history["val_loss"]))) :
331
+ checkpoint={"model":self.model.state_dict(),"optimizer":self.optimizer.state_dict(),"schedular":self.schedular.state_dict(),"schedular1":self.schedular1.state_dict(),"schedular2":self.schedular2.state_dict(),"scaler":self.scaler.state_dict(),"epoch":epoch}
332
+ torch.save(checkpoint, self.configs["resume"])
333
+
334
+ print(f"train loss {running_loss} val loss {val_loss}")
335
+
336
+ self.history["train_loss"].append(float(running_loss))
337
+ self.history["val_loss"].append(float(val_loss))
338
+
339
+ if epoch%10==0:
340
+ historyfile=os.path.join(self.configs["logdir"],"history.json")
341
+ if os.path.exists(historyfile):
342
+ with open(historyfile,"r") as f:
343
+ history=json.load(f)
344
+ history["train_loss"]+=self.history["train_loss"]
345
+ history["val_loss"]+=self.history["val_loss"]
346
+ with open(historyfile,"w") as f:
347
+ json.dump(self.history,f)
348
+ f.close()
349
+ MAETrainer.plot_training_metrics(self.history,epoch+1,self.configs["logdir"])
350
+
351
+ self.current_epoch=epoch
352
+
+
+class Trainer:
+    def __init__(self, configs={}):
+        self.configs = configs
+        os.makedirs(configs["logdir"], exist_ok=True)
+        log_path_train = os.path.join(configs["logdir"], "training_log.txt")
+        log_path_val = os.path.join(configs["logdir"], "val_log.txt")
+        log_path_test = os.path.join(configs["logdir"], "test_log.txt")
+        self.traintee = TeeFile(sys.stdout, log_path_train)
+        self.valtee = TeeFile(sys.stdout, log_path_val)
+        self.testtee = TeeFile(sys.stdout, log_path_test)
+
+        for d in self.configs["dirsToMake"]:
+            os.makedirs(d, exist_ok=True)
+
+        self.model = XRAYClassifier(
+            c=configs["channels"],
+            num_classes=configs["num_classes"],
+            mask_ratio=configs["mask_ratio"],
+            dropout=configs["dropout"],
+            img_size=configs["img_size"],
+            encoder_dim=configs["encoder_dim"],
+            mlp_dim=configs["mlp_dim"],
+            decoder_dim=configs["decoder_dim"],
+            encoder_depth=configs["encoder_depth"],
+            encoder_head=configs["encoder_head"],
+            decoder_depth=configs["decoder_depth"],
+            decoder_head=configs["decoder_head"],
+            patch_size=configs["patch_size"]
+        ).to(configs["device"])
+
+        self.optimizer = torch.optim.AdamW(self.model.parameters(), configs["lr"], weight_decay=configs["weight_decay"])
+        self.schedular1 = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=0.1, end_factor=1.0, total_iters=configs["warmup"])
+        self.schedular2 = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=configs["num_epochs"] - configs["warmup"])
+        self.schedular = torch.optim.lr_scheduler.SequentialLR(self.optimizer, schedulers=[self.schedular1, self.schedular2], milestones=[configs["warmup"]])
+        self.scaler = torch.amp.GradScaler()
+
+        self.train_dataset = CheXpertDataset(zip_path=configs["zip_path"], csv_path=configs["train_csv"], root_dir=configs["datadir"], augment=True, use_frontal_only=True, mask_dir=configs["maskdir"])
+        self.val_dataset = CheXpertDataset(zip_path=configs["zip_path"], csv_path=configs["val_csv"], root_dir=configs["datadir"], augment=False, use_frontal_only=True, mask_dir=configs["maskdir"])
+
+        self.class_Weights = self.train_dataset.get_class_weights().to(self.configs["device"])
+        self.sample_Weights = self.train_dataset.get_sample_weights()
+        self.sampler = torch.utils.data.WeightedRandomSampler(self.sample_Weights, num_samples=len(self.sample_Weights))
+        self.trainloader = DataLoader(self.train_dataset, batch_size=configs["batch_size"], sampler=self.sampler, num_workers=0, pin_memory=True, persistent_workers=False)
+        self.valloader = DataLoader(self.val_dataset, batch_size=configs["batch_size"], shuffle=False, num_workers=0, pin_memory=True, persistent_workers=False)
+        self.criterion = AsymmetricLoss(class_weights=self.class_Weights).to(self.configs["device"])
+        self.history = {"train_loss": [], "val_loss": [], "train_macro_auc": [], "val_macro_auc": [], "train_micro_auc": [], "val_micro_auc": []}
+        historyfile = os.path.join(self.configs["logdir"], "history.json")
+        if os.path.exists(historyfile):
+            with open(historyfile, 'r') as hf:
+                self.history = json.load(hf)
+        self.current_epoch = 0
+
+        self.optimal_thresholds = [0.5] * 14
+
+        if os.path.exists(self.configs["resume"]):
+            ckpt = torch.load(self.configs["resume"], map_location=self.configs["device"], weights_only=False)
+            self.model.load_state_dict(ckpt["model"], strict=False)
+            self.optimizer.load_state_dict(ckpt["optimizer"])
+            self.schedular.load_state_dict(ckpt["schedular"])
+            self.schedular1.load_state_dict(ckpt["schedular1"])
+            self.schedular2.load_state_dict(ckpt["schedular2"])
+            self.scaler.load_state_dict(ckpt["scaler"])
+            self.current_epoch = ckpt.get("epoch", -1) + 1
+            # fall back to the defaults if the checkpoint has no thresholds
+            self.optimal_thresholds = ckpt.get("thresholds") or [0.5] * 14
+        else:
+            # Load the pretrained MAE backbone only
+            bb = torch.load(self.configs["backbone"], map_location=self.configs["device"], weights_only=False)
+
+            # Strip a 'module.' prefix (DataParallel checkpoints) if present
+            state = bb["model"]
+            if any(k.startswith("module.") for k in state.keys()):
+                from collections import OrderedDict
+                state = OrderedDict((k.replace("module.", "", 1), v) for k, v in state.items())
+
+            missing, unexpected = self.model.mae.load_state_dict(state, strict=False)
+            print("loaded backbone")
+            if missing: print(f"Missing keys: {len(missing)} (showing first 5): {missing[:5]}")
+            if unexpected: print(f"Unexpected keys: {len(unexpected)} (first 5): {unexpected[:5]}")
+
+            # Freeze the backbone for warmup
+            for p in self.model.mae.parameters():
+                p.requires_grad = False
+
+            if os.path.exists(self.configs["densebackbone"]):
+                densebb = torch.load(self.configs["densebackbone"], map_location=self.configs["device"])
+                densestate = densebb["model"]
+                # bug fix: check and strip the prefix on densestate (previously used `state`)
+                if any(k.startswith("module.") for k in densestate.keys()):
+                    from collections import OrderedDict
+                    densestate = OrderedDict((k.replace("module.", "", 1), v) for k, v in densestate.items())
+                densemissing, denseunexpected = self.model.dense.load_state_dict(densestate, strict=False)
+                print("loaded dense backbone")
+                if densemissing: print(f"Missing keys: {len(densemissing)} (showing first 5): {densemissing[:5]}")
+                if denseunexpected: print(f"Unexpected keys: {len(denseunexpected)} (first 5): {denseunexpected[:5]}")
+
+        self.test_dataset = None
+        self.testloader = None
+        if configs.get("test_csv"):
+            self.test_dataset = CheXpertDataset(
+                zip_path=configs["zip_path"],
+                csv_path=configs["test_csv"],
+                root_dir=configs["datadir"],
+                augment=False,
+                use_frontal_only=True
+            )
+            self.testloader = DataLoader(
+                self.test_dataset,
+                batch_size=configs["batch_size"],
+                shuffle=False,
+                num_workers=0,
+                pin_memory=True,
+                persistent_workers=False
+            )
+            print(f"Test loader ready – {len(self.test_dataset)} images")
+
+        torch.backends.cudnn.benchmark = True
+        torch.backends.cudnn.enabled = True
+
+        # Note: this only takes effect if set before the CUDA context is created,
+        # so it is best moved to the top of the entry script.
+        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+
+        # Enable gradient checkpointing if the model supports it
+        if hasattr(self.model, 'enable_gradient_checkpointing'):
+            self.model.enable_gradient_checkpointing()
+
+    @staticmethod
+    def plot_training_metrics(metrics, epoch, figs_path):
+        """
+        Plot loss and AUC curves from training metrics.
+
+        Args:
+            metrics (dict): Dictionary containing a list for each metric key:
+                {
+                    "train_loss": [...], "val_loss": [...],
+                    "train_macro_auc": [...], "val_macro_auc": [...],
+                    "train_micro_auc": [...], "val_micro_auc": [...]
+                }
+            epoch (int): Current epoch number (used for the output directory)
+        """
+        import matplotlib.pyplot as plt
+
+        # Compute the common length across all series
+        keys = ["train_loss", "val_loss", "train_macro_auc", "val_macro_auc", "train_micro_auc", "val_micro_auc"]
+        lengths = [len(metrics[k]) for k in keys if k in metrics]
+        if not lengths:
+            return
+        n = min(lengths)
+
+        # Slice everything to the same length, then plot from the sliced dict
+        # (plotting from `metrics` directly could mismatch the x-axis length)
+        m = {k: metrics[k][:n] for k in keys if k in metrics}
+        epochs = list(range(1, n + 1))
+
+        plt.figure(figsize=(14, 6))
+
+        # ---- Loss subplot ----
+        plt.subplot(1, 2, 1)
+        plt.plot(epochs, m["train_loss"], label="Train Loss", marker='o')
+        plt.plot(epochs, m["val_loss"], label="Val Loss", marker='s')
+        plt.xlabel("Epoch")
+        plt.ylabel("Loss")
+        plt.title("Training & Validation Loss")
+        plt.legend()
+        plt.grid(True, linestyle='--', alpha=0.6)
+
+        # ---- AUC subplot ----
+        plt.subplot(1, 2, 2)
+        plt.plot(epochs, m["train_macro_auc"], label="Train Macro AUC", marker='o')
+        plt.plot(epochs, m["val_macro_auc"], label="Val Macro AUC", marker='s')
+        plt.plot(epochs, m["train_micro_auc"], label="Train Micro AUC", marker='^')
+        plt.plot(epochs, m["val_micro_auc"], label="Val Micro AUC", marker='v')
+        plt.xlabel("Epoch")
+        plt.ylabel("AUC")
+        plt.title("Training & Validation AUC (Macro/Micro)")
+        plt.legend()
+        plt.grid(True, linestyle='--', alpha=0.6)
+
+        plt.tight_layout()
+        os.makedirs(os.path.join(figs_path, str(epoch)), exist_ok=True)
+        plt.savefig(os.path.join(figs_path, str(epoch), "metrics.png"))
+        plt.show()
+
+    def train_epoch(self, epoch, looper):
+        self.model.train()
+        running_loss = 0.0
+        all_preds = []
+        all_targets = []
+        total_batches = len(self.trainloader)
+
+        for batch_idx, data in looper:
+            image = data['image'].to(self.configs["device"], non_blocking=True)
+            target = data['labels'].to(self.configs["device"], non_blocking=True)
+
+            # AMP is intentionally disabled here; the classifier trains in full precision.
+            logits = self.model(image)
+            preds = torch.sigmoid(logits.float())
+            loss = self.criterion(preds, target)
+
+            loss_back = loss / self.configs["accumulation"]
+            running_loss += loss.item()
+
+            if torch.isfinite(loss):
+                loss_back.backward()
+            else:
+                # skip non-finite batches entirely
+                self.optimizer.zero_grad(set_to_none=True)
+                continue
+
+            if (batch_idx + 1) % self.configs["accumulation"] == 0 or batch_idx == total_batches - 1:
+                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+                self.optimizer.step()
+                self.optimizer.zero_grad(set_to_none=True)
+
+            # Store for AUC
+            all_preds.append(preds.detach().cpu())
+            all_targets.append(target.detach().cpu())
+
+            # === LIVE METRICS (every 500 batches) ===
+            current_loss = running_loss / (batch_idx + 1)
+            if (batch_idx + 1) % 500 == 0 and all_preds:
+                preds_np = torch.cat(all_preds).numpy()
+                targets_np = torch.cat(all_targets).numpy()
+                macro_auc = roc_auc_score(targets_np, preds_np, average='macro')
+                micro_auc = roc_auc_score(targets_np, preds_np, average='micro')
+
+                current_lr = self.optimizer.param_groups[0]['lr']
+                looper.set_postfix({
+                    "lr": f"{current_lr:.2e}", "batch": f"{batch_idx}/{total_batches}",
+                    "epoch": f"{epoch}/{self.configs['num_epochs']}",
+                    "loss": f"{current_loss:.3f}",
+                    "macro": f"{macro_auc:.3f}",
+                    "micro": f"{micro_auc:.3f}",
+                })
+
+        # === FINAL FULL EPOCH METRICS ===
+        preds_full = torch.cat(all_preds).numpy()
+        targets_full = torch.cat(all_targets).numpy()
+        final_loss = running_loss / total_batches
+        final_macro_auc = roc_auc_score(targets_full, preds_full, average='macro')
+        final_micro_auc = roc_auc_score(targets_full, preds_full, average='micro')
+
+        del all_preds, all_targets, preds_full, targets_full
+
+        return final_loss, final_macro_auc, final_micro_auc
+
+    def validate(self, epoch, looper):
+        self.model.eval()
+        val_loss = 0.0
+        all_preds = []
+        all_targets = []
+        lenloader = len(self.valloader)
+
+        with torch.no_grad():
+            for batch_idx, data in looper:
+                image = data["image"].to(self.configs["device"], non_blocking=True)
+                target = data["labels"].to(self.configs["device"], non_blocking=True)
+
+                logits = self.model(image)
+                preds = torch.sigmoid(logits.float())
+                loss = self.criterion(preds, target)
+
+                val_loss += loss.item()
+
+                all_preds.append(preds.detach().cpu())
+                all_targets.append(target.detach().cpu())
+
+                # === LIVE METRICS (every 200 batches) ===
+                current_loss = val_loss / (batch_idx + 1)
+                if (batch_idx + 1) % 200 == 0 and all_preds:
+                    preds_np = torch.cat(all_preds).numpy()
+                    targets_np = torch.cat(all_targets).numpy()
+                    macro_auc = roc_auc_score(targets_np, preds_np, average='macro')
+                    micro_auc = roc_auc_score(targets_np, preds_np, average='micro')
+                    looper.set_postfix({
+                        "epoch": f"{epoch}/{self.configs['num_epochs']}",
+                        "batch": f"{batch_idx}/{lenloader}",
+                        "loss": f"{current_loss:.3f}",
+                        "macro": f"{macro_auc:.3f}",
+                        "micro": f"{micro_auc:.3f}",
+                    })
+
+        # === FINAL FULL VALIDATION METRICS ===
+        preds_full = torch.cat(all_preds).numpy()
+        targets_full = torch.cat(all_targets).numpy()
+        num_classes = 14
+        new_thresholds = [0.5] * num_classes  # default
+
+        # Per-class threshold search maximising Youden's J (sensitivity + specificity - 1)
+        for class_idx in range(num_classes):
+            if targets_full[:, class_idx].sum() == 0:
+                # no positive samples, keep the default 0.5
+                continue
+
+            thresholds = np.arange(0.1, 0.9, 0.02)
+            best_score = -1
+            best_threshold = 0.5
+
+            for threshold in thresholds:
+                preds_bin = (preds_full[:, class_idx] >= threshold).astype(int)
+                # labels=[0, 1] guarantees a 2x2 matrix even when all
+                # predictions fall on one side of the threshold
+                tn, fp, fn, tp = confusion_matrix(
+                    targets_full[:, class_idx].astype(int),
+                    preds_bin,
+                    labels=[0, 1]
+                ).ravel()
+                sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
+                specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
+                score = sensitivity + specificity - 1
+
+                if score > best_score:
+                    best_score = score
+                    best_threshold = threshold
+
+            new_thresholds[class_idx] = best_threshold
+
+        self.optimal_thresholds = new_thresholds
+
+        final_loss = val_loss / lenloader
+        final_macro_auc = roc_auc_score(targets_full, preds_full, average='macro')
+        final_micro_auc = roc_auc_score(targets_full, preds_full, average='micro')
+
+        del all_preds, all_targets, preds_full, targets_full
+
+        return final_loss, final_macro_auc, final_micro_auc
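The per-class threshold search above picks, for each pathology, the cutoff that maximises Youden's J statistic (sensitivity + specificity − 1). A minimal, framework-free sketch of the same idea, with toy scores and labels invented purely for illustration:

```python
def youden_threshold(scores, labels, grid=None):
    """Pick the probability cutoff that maximises sensitivity + specificity - 1."""
    if grid is None:
        grid = [0.1 + 0.02 * i for i in range(40)]  # mirrors np.arange(0.1, 0.9, 0.02)
    best_tau, best_j = 0.5, -1.0
    for tau in grid:
        tp = sum(1 for s, y in zip(scores, labels) if s >= tau and y == 1)
        fn = sum(1 for s, y in zip(scores, labels) if s < tau and y == 1)
        tn = sum(1 for s, y in zip(scores, labels) if s < tau and y == 0)
        fp = sum(1 for s, y in zip(scores, labels) if s >= tau and y == 0)
        sens = tp / (tp + fn) if (tp + fn) else 0.0
        spec = tn / (tn + fp) if (tn + fp) else 0.0
        j = sens + spec - 1
        if j > best_j:
            best_j, best_tau = j, tau
    return best_tau

# toy data: positives score high, negatives low, so the chosen cutoff
# lands in the gap between the two groups
scores = [0.9, 0.8, 0.75, 0.3, 0.2, 0.1]
labels = [1, 1, 1, 0, 0, 0]
tau = youden_threshold(scores, labels)
```

This is an illustrative sketch only; the trainer itself uses `sklearn.metrics.confusion_matrix` over the full validation predictions per class.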
+
+    def train(self):
+        for epoch in range(self.current_epoch, self.configs["num_epochs"]):
+            trainlooper = tqdm(enumerate(self.trainloader), desc="training: ", leave=True, file=self.traintee)
+            vallooper = tqdm(enumerate(self.valloader), desc="validating: ", leave=True, file=self.valtee)
+
+            self.model.train()
+            self.optimizer.zero_grad(set_to_none=True)
+
+            running_loss, macro_auc, micro_auc = self.train_epoch(epoch, trainlooper)
+
+            # step the epoch-level scheduler after the optimizer has stepped
+            self.schedular.step()
+
+            torch.cuda.synchronize()
+            torch.cuda.empty_cache()
+
+            val_loss, val_macro_auc, val_micro_auc = self.validate(epoch, vallooper)
+
+            torch.cuda.synchronize()
+            torch.cuda.empty_cache()
+            gc.collect()
+
+            # save whenever either AUC improves (including the first epoch)
+            best_macro = max(self.history["val_macro_auc"], default=float("-inf"))
+            best_micro = max(self.history["val_micro_auc"], default=float("-inf"))
+            if val_macro_auc > best_macro or val_micro_auc > best_micro:
+                checkpoint = {"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict(),
+                              "schedular": self.schedular.state_dict(), "schedular1": self.schedular1.state_dict(),
+                              "schedular2": self.schedular2.state_dict(), "scaler": self.scaler.state_dict(),
+                              "epoch": epoch, "thresholds": self.optimal_thresholds}
+                torch.save(checkpoint, self.configs["resume"])
+
+            print(f"epoch {epoch} train loss {running_loss} val loss {val_loss} val_macro_auc {val_macro_auc} val_micro_auc {val_micro_auc} train_macro_auc {macro_auc} train_micro_auc {micro_auc}")
+
+            self.history["train_loss"].append(float(running_loss))
+            self.history["val_loss"].append(float(val_loss))
+            self.history["train_macro_auc"].append(float(macro_auc))
+            self.history["val_macro_auc"].append(float(val_macro_auc))
+            self.history["train_micro_auc"].append(float(micro_auc))
+            self.history["val_micro_auc"].append(float(val_micro_auc))
+
+            # self.history is seeded from history.json at init and appended to
+            # every epoch, so it already holds the full history; overwrite the file.
+            historyfile = os.path.join(self.configs["logdir"], "history.json")
+            with open(historyfile, "w") as f:
+                json.dump(self.history, f)
+
+            if epoch % 10 == 0:
+                Trainer.plot_training_metrics(self.history, epoch + 1, self.configs["logdir"])
+
+            self.current_epoch = epoch
+
+    def test(self, model_path=None, return_preds=False):
+        """
+        Run a complete test evaluation.
+        If `model_path` is given, load that checkpoint first.
+        Returns (macro_auc, micro_auc, per_class_auc_dict), plus (preds, targets) if requested.
+        """
+        if model_path:
+            ckpt = torch.load(model_path, map_location=self.configs["device"])
+            self.model.load_state_dict(ckpt["model"])
+            print(f"Loaded checkpoint {model_path}")
+
+        if self.testloader is None:
+            raise RuntimeError("No test loader – provide `test_csv` in config")
+
+        self.model.eval()
+        all_preds, all_targets = [], []
+
+        test_loss = 0.0
+        looper = tqdm(enumerate(self.testloader), total=len(self.testloader),
+                      desc="Testing ", file=self.testtee)
+
+        with torch.inference_mode():
+            for batch_idx, data in looper:
+                img = data['image'].to(self.configs["device"], non_blocking=True)
+                tgt = data['labels'].to(self.configs["device"], non_blocking=True)
+
+                logits = self.model(img)
+                if self.optimal_thresholds:
+                    # class-wise thresholds in probability space: self.optimal_thresholds[c] = tau_c
+                    taus = torch.tensor(self.optimal_thresholds, device=logits.device).view(1, -1)
+
+                    # convert thresholds from probability to logit space
+                    margins = torch.log(taus / (1 - taus))  # shape [1, C]
+
+                    # shift logits by the margin so the decision
+                    # sigmoid(logit) >= tau_c becomes (shifted logit) >= 0
+                    logits = logits - margins
+                probs = torch.sigmoid(logits)
+                loss = self.criterion(probs, tgt)
+                test_loss += loss.item()
+
+                all_preds.append(probs.cpu())
+                all_targets.append(tgt.cpu())
+
+                # live stats
+                cur_loss = test_loss / (batch_idx + 1)
+                if all_preds:
+                    p = torch.cat(all_preds).numpy()
+                    t = torch.cat(all_targets).numpy()
+                    macro = roc_auc_score(t, p, average='macro')
+                    micro = roc_auc_score(t, p, average='micro')
+                else:
+                    macro = micro = 0.0
+                looper.set_postfix(loss=f"{cur_loss:.4f}",
+                                   macro=f"{macro:.4f}",
+                                   micro=f"{micro:.4f}")
+
+        # ---- final metrics ----
+        preds = torch.cat(all_preds).numpy()
+        targets = torch.cat(all_targets).numpy()
+        final_loss = test_loss / len(self.testloader)
+        macro_auc = roc_auc_score(targets, preds, average='macro')
+        micro_auc = roc_auc_score(targets, preds, average='micro')
+
+        # per-class AUC
+        per_class = {}
+        for i, name in enumerate(self.train_dataset.get_label_names()):
+            if targets[:, i].sum() > 0:  # AUC is undefined without positives
+                per_class[name] = roc_auc_score(targets[:, i], preds[:, i])
+            else:
+                per_class[name] = float('nan')
+
+        # ---- pretty table ----
+        print("\n" + "=" * 80)
+        print(f"TEST RESULTS (loss={final_loss:.4f})")
+        print("=" * 80)
+        print(f"{'Pathology':<30} {'AUC':>8}")
+        print("-" * 40)
+        for name, auc in per_class.items():
+            print(f"{name:<30} {auc:>8.4f}" if not np.isnan(auc) else f"{name:<30} {'N/A':>8}")
+        print("-" * 40)
+        print(f"{'Macro AUC':<30} {macro_auc:>8.4f}")
+        print(f"{'Micro AUC':<30} {micro_auc:>8.4f}")
+        print("=" * 80)
+
+        if return_preds:
+            return macro_auc, micro_auc, per_class, (preds, targets)
+        return macro_auc, micro_auc, per_class
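The margin shift in `test` rests on a small identity: thresholding a probability at tau is the same as thresholding the raw logit at logit(tau) = log(tau / (1 − tau)), so subtracting that margin moves every class's decision boundary back to 0.5. A self-contained check of the equivalence (pure Python, no torch; the sample logits and thresholds are invented for illustration):

```python
import math

def sigmoid(x):
    # logistic function: maps a logit to a probability
    return 1.0 / (1.0 + math.exp(-x))

def logit(p):
    # inverse of sigmoid: maps a probability in (0, 1) to a logit
    return math.log(p / (1.0 - p))

# For any raw logit z and threshold tau in (0, 1):
#   sigmoid(z) >= tau  <=>  z - logit(tau) >= 0  <=>  sigmoid(z - logit(tau)) >= 0.5
for z in [-3.0, -0.5, 0.0, 0.4, 2.0]:
    for tau in [0.2, 0.5, 0.7]:
        direct = sigmoid(z) >= tau
        shifted = sigmoid(z - logit(tau)) >= 0.5
        assert direct == shifted
```

This is why the shifted logits can be passed straight to `torch.sigmoid` and compared against a uniform 0.5 downstream, even though each class has its own tuned threshold.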
training logs/classifier/1/metrics.png ADDED
training logs/classifier/11/metrics.png ADDED

training logs/classifier/Events.docx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4c683f158db1946053b66c0a5769962a8d51bad98058490cfa4aedba5582f45d
+ size 1680108
training logs/classifier/history.json ADDED
@@ -0,0 +1 @@
 
 
+ {"train_loss": [2.451549026515934, 2.324592100605649, 2.2527450666496867, 2.2051324946319886, 2.159476125092837, 2.1111394616786736, 2.057536503908261, 1.9841906148749953, 1.9176961825764776, 1.8619107825900996, 1.7461218646035648, 1.6598046678827294, 0.14859689984745142, 0.14480384502754282, 0.1413286671667666, 0.13848335853005814], "val_loss": [1.4677110827817452, 1.414012653402026, 1.3714177432875805, 1.2308028351501579, 1.3173101589027862, 1.327703529975834, 1.3617598478416082, 1.3191542980376254, 1.2529189027799352, 1.5297267407960213, 1.5704542597879037, 1.768142464009117, 0.1340647471810548, 0.13415449232644355, 0.13759738845883238, 0.13678586165817191], "train_macro_auc": [0.5915813508746322, 0.6911077814868211, 0.7180677573787114, 0.7330628389066268, 0.7446141059156149, 0.7555600431027648, 0.7611192999386304, 0.768528268418801, 0.7723525292111644, 0.7753779303902529, 0.7868138491223837, 0.7944627804675177, 0.7901445091432617, 0.8022003143300758, 0.815119789621127, 0.8244893337730623], "val_macro_auc": [0.6697306681955048, 0.7073059079688858, 0.7307953052152676, 0.7387044904612824, 0.7454194006038561, 0.7500732498762482, 0.7486698915393023, 0.7534324811456612, 0.7528406138149186, 0.7499375597482462, 0.7467580969017149, 0.7441320142787182, 0.7530519050114038, 0.7533548440220946, 0.749456570265707, 0.7485324814519589], "train_micro_auc": [0.7215386383447403, 0.7813286014278719, 0.7966446631205264, 0.8059628819522804, 0.8134203450690528, 0.8203666769222709, 0.8240765555633897, 0.8301711223679948, 0.833068249934892, 0.8355527627001551, 0.8440327424520534, 0.8495436085077932, 0.8464713300561717, 0.8555554974362551, 0.8642513599137882, 0.8715251218243751], "val_micro_auc": [0.830096540126698, 0.8465147824166648, 0.85730031114022, 0.8612904284897909, 0.8638291274014742, 0.8621590439819964, 0.8642181574740424, 0.8699097719546053, 0.8705562840474825, 0.8696004009592124, 0.8691057071199912, 0.867407803196298, 0.869385951883187, 0.8707534276030608, 0.8618755289952147, 0.8671620953710033]}
training logs/classifier/test_log.txt ADDED
The diff for this file is too large to render. See raw diff
 
training logs/classifier/training_log.txt ADDED
The diff for this file is too large to render. See raw diff
 
training logs/classifier/val_log.txt ADDED
The diff for this file is too large to render. See raw diff
 
training logs/mae/1/metrics.png ADDED
training logs/mae/101/metrics.png ADDED
training logs/mae/11/metrics.png ADDED