adelelsayed1991 committed on
Commit
5ffe2e2
·
verified ·
1 Parent(s): 3655e02

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -1,35 +1 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ weights/*.pth filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
*.pth
LICENSE ADDED
MIT License

Copyright (c) 2025 Adel Elsayed

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md ADDED
# Masked Autoencoder (MAE) for Medical Imaging

A PyTorch implementation of Masked Autoencoder (MAE) for self-supervised learning on chest X-ray images, specifically designed for the CheXpert dataset.

## 📋 Overview

This project implements a Vision Transformer-based Masked Autoencoder that learns representations from chest X-ray images through self-supervised reconstruction. The model randomly masks 75% of image patches and learns to reconstruct the original image, enabling it to learn powerful visual representations without requiring labeled data.

### Key Features

- **Vision Transformer Architecture**: Encoder-decoder transformer with positional encodings
- **Self-Supervised Learning**: Pre-training through masked image reconstruction
- **Optimized for Medical Imaging**: Designed specifically for chest X-ray analysis
- **Production-Ready Training Pipeline**:
  - Mixed precision training (FP16) with gradient scaling
  - Gradient accumulation support
  - Learning rate warmup and cosine annealing
  - Automatic checkpointing and resumption
- **Efficient Data Loading**:
  - Optimized ZIP file reader with LRU caching
  - Class-balanced sampling with weighted random sampler
  - Multi-worker data loading with persistent workers
- **Comprehensive Logging**: Training/validation metrics tracking and visualization

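The class-balanced sampling listed above is typically wired up with PyTorch's `WeightedRandomSampler`. A minimal sketch with a toy dataset (the repository's own `CheXpertDataset.get_sample_weights()` would supply the real per-sample weights):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy stand-in for CheXpertDataset: 10 samples, two of them "rare".
data = torch.arange(10).float().unsqueeze(1)
dataset = TensorDataset(data)

# Per-sample weights: rare samples (indices 8, 9) get 4x the draw probability.
weights = torch.ones(10)
weights[8:] = 4.0

sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for (batch,) in loader:
    print(batch.squeeze(1).tolist())
```

With `replacement=True`, rare samples appear more often per epoch; epoch length stays `num_samples`.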
## 🏗️ Architecture

### Masked Autoencoder Structure

```
Input Image (384×384)
          ↓
Patchify (16×16 patches → 576 patches)
          ↓
Random Masking (75% masked, 25% visible)
          ↓
┌─────────────────────────────────────┐
│            MAE ENCODER              │
│  - Linear patch embedding           │
│  - Positional encoding (visible)    │
│  - 12 Transformer blocks            │
│  - 8 attention heads, 768 hidden    │
└─────────────────────────────────────┘
          ↓
┌─────────────────────────────────────┐
│            MAE DECODER              │
│  - Learnable mask tokens            │
│  - Positional encoding (all)        │
│  - 8 Transformer blocks             │
│  - 8 attention heads, 512 hidden    │
│  - Pixel reconstruction head        │
└─────────────────────────────────────┘
          ↓
Reconstructed Image
          ↓
MSE Loss (on masked patches only)
```

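The random-masking step in the diagram follows the standard MAE recipe: score every patch with random noise, keep the lowest-scoring 25%, and remember the permutation so the decoder can restore patch order. A minimal sketch (not the exact code from `models/mae.py`):

```python
import torch

def random_masking(patches: torch.Tensor, mask_ratio: float = 0.75):
    """Standard MAE-style masking: keep a random subset of patches.

    patches: (B, N, D) patch embeddings.
    Returns kept patches, the binary mask (1 = masked), and restore indices.
    """
    B, N, D = patches.shape
    n_keep = int(N * (1 - mask_ratio))

    noise = torch.rand(B, N)                   # one random score per patch
    ids_shuffle = noise.argsort(dim=1)         # lowest scores are kept
    ids_restore = ids_shuffle.argsort(dim=1)   # inverse permutation

    ids_keep = ids_shuffle[:, :n_keep]
    kept = torch.gather(patches, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

    mask = torch.ones(B, N)
    mask[:, :n_keep] = 0                       # 0 = kept, 1 = masked
    mask = torch.gather(mask, 1, ids_restore)  # back to original patch order
    return kept, mask, ids_restore

x = torch.randn(2, 576, 768)                   # 576 patches from a 384x384 image
kept, mask, ids_restore = random_masking(x)    # kept: (2, 144, 768); 432 of 576 masked
```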
### Model Configuration

| Parameter | Default Value | Description |
|-----------|---------------|-------------|
| Image Size | 384×384 | Input image resolution |
| Patch Size | 16×16 | Size of each patch |
| Mask Ratio | 0.75 | Fraction of patches to mask |
| Encoder Depth | 12 layers | Number of transformer blocks |
| Encoder Dim | 768 | Hidden dimension |
| Encoder Heads | 8 | Number of attention heads |
| Decoder Depth | 8 layers | Number of transformer blocks |
| Decoder Dim | 512 | Hidden dimension |
| Decoder Heads | 8 | Number of attention heads |
| MLP Ratio | 4× | MLP expansion ratio (3072) |
| Dropout | 0.25 | Dropout rate |

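The patchify step maps a 384×384 grayscale image into the 576 patch tokens above (each 16×16×1 = 256 pixels, before the linear embedding to 768). A minimal sketch using plain tensor reshapes; the repository's own `models/mae.py` may differ in detail:

```python
import torch

def patchify(imgs: torch.Tensor, patch_size: int = 16) -> torch.Tensor:
    """(B, C, H, W) -> (B, num_patches, patch_size*patch_size*C)."""
    B, C, H, W = imgs.shape
    assert H % patch_size == 0 and W % patch_size == 0
    h, w = H // patch_size, W // patch_size
    x = imgs.reshape(B, C, h, patch_size, w, patch_size)
    x = x.permute(0, 2, 4, 3, 5, 1)          # (B, h, w, p, p, C)
    return x.reshape(B, h * w, patch_size * patch_size * C)

imgs = torch.randn(2, 1, 384, 384)           # grayscale chest X-rays
patches = patchify(imgs)
print(patches.shape)                         # torch.Size([2, 576, 256])
```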
## 🚀 Getting Started

### Prerequisites

- Python >= 3.8
- CUDA-capable GPU (recommended)
- 16GB+ RAM

### Installation

1. Clone the repository:
```bash
git clone https://github.com/adelelsayed/mae.git
cd mae
```

2. Install dependencies:
```bash
pip install -r requirements.txt
```

### Dataset Preparation

This project is configured for the **CheXpert dataset**. To use it:

1. Download CheXpert-v1.0-small from [Stanford ML Group](https://stanfordmlgroup.github.io/competitions/chexpert/)
2. Update paths in `configs/configs.py`:
   - `root`: Base directory for your data
   - `zip_path`: Path to zipped dataset (optional, for faster loading)
   - `csv`: Path to training CSV
   - `train_csv`, `val_csv`, `test_csv`: Split CSV files

## 📊 Usage

### Training

Start training from scratch:
```bash
python trainer/trainer.py
```

The trainer will:
- Automatically create checkpoint and log directories
- Resume from the last checkpoint if available
- Log training/validation metrics to text files
- Save plots every 10 epochs
- Save the best model based on validation loss

### Training Configuration

Edit `configs/configs.py` to customize training:

```python
mae_config = {
    # Training hyperparameters
    "lr": 1e-4,              # Learning rate
    "warmup": 5,             # Warmup epochs
    "weight_decay": 5e-4,    # AdamW weight decay
    "num_epochs": 200,       # Total training epochs
    "batch_size": 96,        # Batch size
    "accumulation": 1,       # Gradient accumulation steps

    # Model architecture
    "mask_ratio": 0.75,      # Masking ratio
    "encoder_depth": 12,     # Encoder layers
    "decoder_depth": 8,      # Decoder layers

    # Paths
    "checkpoints": "/path/to/checkpoints",
    "logdir": "/path/to/logs",
    ...
}
```

### Monitoring Training

Training logs are saved in three files:
- `training_log.txt`: Training metrics per epoch
- `val_log.txt`: Validation metrics per epoch
- `test_log.txt`: Test set evaluation results

Metrics plots are saved every 10 epochs in `{logdir}/{epoch}/metrics.png`.

### Evaluation

The project includes a test method in the trainer. To evaluate:
```python
from trainer.utils import MAETrainer
from configs.configs import mae_config

trainer = MAETrainer(mae_config)
trainer.test()
```

## 📁 Project Structure

```
mae/
├── configs/
│   ├── __init__.py
│   └── configs.py           # Training configuration
├── data/
│   ├── __init__.py
│   ├── dataset.py           # CheXpert dataset loader
│   └── splitter.py          # Dataset splitting utilities
├── loss/
│   ├── __init__.py
│   └── mae_loss.py          # MAE reconstruction loss
├── models/
│   ├── __init__.py
│   └── mae.py               # MAE architecture
├── trainer/
│   ├── __init__.py
│   ├── trainer.py           # Main training script
│   └── utils.py             # Training utilities
├── notebooks/
│   └── chexpert_mae.ipynb   # Jupyter notebook for experiments
├── training logs/           # Logged metrics and plots
├── weights/                 # Model checkpoints
├── results/                 # Evaluation results
├── requirements.txt         # Python dependencies
├── LICENSE                  # Project license
└── README.md                # This file
```

## 🔧 Components

### Dataset (`data/dataset.py`)

- **OptimizedZipReader**: Fast ZIP file reading with LRU caching
- **CheXpertDataset**: PyTorch dataset for CheXpert chest X-rays
  - 14 pathology labels: No Finding, Cardiomegaly, Edema, Consolidation, etc.
  - Albumentations-based augmentation pipeline
  - Class-balanced sampling support
  - Frontal/lateral view filtering

### Model (`models/mae.py`)

- **Patchify/Unpatchify**: Image-to-patch conversion utilities
- **Random Masking**: Stochastic patch masking with restore indices
- **PositionalEncoding**: Learnable position embeddings
- **TransformerBlock**: Multi-head self-attention + MLP
- **MAEEncoder**: Processes visible patches only
- **MAEDecoder**: Reconstructs the full image with mask tokens
- **MaskedAutoEncoder**: Complete MAE model

### Loss (`loss/mae_loss.py`)

Mean Squared Error (MSE) computed only on masked patches:
```python
loss = ((pred - target) ** 2 * mask).sum() / mask.sum()
```

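Spelled out as a function, one common variant first averages the squared error within each patch, then averages over masked patches only (the repository's `loss/mae_loss.py` may normalize slightly differently):

```python
import torch

def mae_loss(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """MSE over masked patches only.

    pred, target: (B, N, D) patch predictions and patchified ground truth.
    mask: (B, N) with 1 for masked patches, 0 for visible ones.
    """
    per_patch = ((pred - target) ** 2).mean(dim=-1)  # (B, N): mean over pixels in a patch
    return (per_patch * mask).sum() / mask.sum()

pred = torch.zeros(2, 576, 256)
target = torch.ones(2, 576, 256)
mask = torch.ones(2, 576)
print(mae_loss(pred, target, mask).item())           # 1.0
```

Because visible patches are multiplied by 0 in the numerator and excluded from the denominator, the encoder gets no direct reconstruction signal for patches it already saw.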
### Trainer (`trainer/utils.py`)

- **MAETrainer**: Complete training pipeline
  - Mixed precision training (AMP)
  - Gradient clipping and accumulation
  - Learning rate scheduling (warmup → cosine)
  - Automatic checkpointing
  - Multi-file logging (train/val/test)
  - Live metric monitoring with tqdm
  - Periodic metric visualization

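The warmup → cosine schedule can be reproduced with a `LambdaLR`. A sketch using the config's defaults (`lr=1e-4`, `warmup=5`, 200 epochs); the actual trainer may implement it differently:

```python
import math
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-4)

warmup_epochs, total_epochs = 5, 200

def lr_lambda(epoch: int) -> float:
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs           # linear warmup
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * (1 + math.cos(math.pi * progress))  # cosine decay toward 0

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

lrs = []
for epoch in range(total_epochs):
    # ... one epoch of optimizer.step() calls would go here ...
    lrs.append(optimizer.param_groups[0]["lr"])
    scheduler.step()
print(lrs[0], max(lrs), lrs[-1])
```

The LR ramps linearly to 1e-4 over the first 5 epochs, then follows a half-cosine down to near zero at epoch 200.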
## 🎯 CheXpert Pathologies

The dataset includes 14 chest X-ray findings:

1. No Finding
2. Enlarged Cardiomediastinum
3. Cardiomegaly
4. Lung Opacity
5. Lung Lesion
6. Edema
7. Consolidation
8. Pneumonia
9. Atelectasis
10. Pneumothorax
11. Pleural Effusion
12. Pleural Other
13. Fracture
14. Support Devices

## 📈 Training Tips

1. **Learning Rate**: Start with 1e-4 and use warmup for stability
2. **Batch Size**: Maximize based on GPU memory (96 works well on 40GB GPUs)
3. **Gradient Accumulation**: Use if batch size is limited by memory
4. **Mixed Precision**: Enabled by default for faster training
5. **Masking Ratio**: 75% is standard; higher ratios increase difficulty
6. **Resume Training**: The model automatically resumes from the last checkpoint

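Tips 2–4 combine in a single loop in practice. A minimal sketch of AMP plus gradient accumulation (not the repository's exact trainer loop; the toy model and data here are placeholders):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
accumulation = 4  # effective batch = batch_size * accumulation

batches = [(torch.randn(8, 16), torch.randn(8, 1)) for _ in range(8)]

optimizer.zero_grad(set_to_none=True)
for step, (x, y) in enumerate(batches):
    x, y = x.to(device), y.to(device)
    with torch.autocast(device_type=device, enabled=(device == "cuda")):
        loss = torch.nn.functional.mse_loss(model(x), y)
    # Divide by `accumulation` so accumulated grads average over micro-batches
    scaler.scale(loss / accumulation).backward()
    if (step + 1) % accumulation == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
```

On CPU the `enabled` flags turn both autocast and the scaler into no-ops, so the same loop runs everywhere.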
## 🔬 Use Cases

### Pre-training for Downstream Tasks
Use the trained encoder as a feature extractor:
```python
import torch
from models.mae import MaskedAutoEncoder

# Load pre-trained model
mae = MaskedAutoEncoder()
mae.load_state_dict(torch.load("best_mae.pth")["model"])

# Use encoder for feature extraction
encoder = mae.encoder
features, _, _, _ = encoder(images)
```

### Fine-tuning on Classification
Add a classification head to the encoder for supervised tasks.

### Anomaly Detection
Reconstruction error can indicate abnormalities in medical images.

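For anomaly detection, the masked-patch MSE can be kept per image instead of averaged over the batch, giving one score per X-ray; higher scores suggest images the model reconstructs poorly. A sketch assuming patchified model outputs (shapes follow this project's 576 patches of 256 pixels):

```python
import torch

def reconstruction_error(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Per-image MSE on masked patches; higher values flag unusual images.

    pred, target: (B, N, D) patch predictions and patchified ground truth.
    mask: (B, N) with 1 for masked patches.
    """
    per_patch = ((pred - target) ** 2).mean(dim=-1)          # (B, N)
    return (per_patch * mask).sum(dim=1) / mask.sum(dim=1)   # (B,)

pred = torch.zeros(3, 576, 256)
target = torch.zeros(3, 576, 256)
target[2] += 1.0                                             # one "anomalous" image
mask = torch.ones(3, 576)
scores = reconstruction_error(pred, target, mask)
print(scores.tolist())                                       # [0.0, 0.0, 1.0]
```

A threshold on these scores (calibrated on a held-out set of normal studies) turns the pre-trained MAE into a simple screening signal.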
## 📊 Performance Optimization

This implementation includes several optimizations:

- **Efficient ZIP Reading**: Avoids extracting files to disk
- **LRU Cache**: Keeps frequently accessed images in memory
- **Persistent Workers**: Reduces data loading overhead
- **Mixed Precision**: Up to 2× faster training with minimal quality loss
- **Gradient Checkpointing**: Reduces memory usage (if enabled)
- **CUDA Memory Management**: Proper cache clearing and synchronization

## 🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

## 📄 License

This project is licensed under the MIT License; see the LICENSE file for details.

## 📚 References

1. **Masked Autoencoders Are Scalable Vision Learners**
   He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022)
   [arXiv:2111.06377](https://arxiv.org/abs/2111.06377)

2. **CheXpert: A Large Chest Radiograph Dataset**
   Irvin, J., et al. (2019)
   [Stanford ML Group](https://stanfordmlgroup.github.io/competitions/chexpert/)

## 🙏 Acknowledgments

- Original MAE paper by Meta AI Research
- CheXpert dataset by Stanford ML Group
- PyTorch and Albumentations communities

## 📧 Contact

For questions or issues, please open an issue on GitHub or contact the maintainer.

---

**Note**: This is a research/educational implementation. For clinical applications, please ensure proper validation and regulatory compliance.
configs/__init__.py ADDED
File without changes
configs/configs.py ADDED
import os
import torch

root = "/content/drive/MyDrive"
mae_config = {
    "lr": 1e-4,
    "warmup": 5,
    "weight_decay": 5e-4,
    "num_epochs": 200,
    "num_classes": 14,
    "zip_path": os.path.join(root, "CheXpert-v1.0-small", "chexpert.zip"),
    "resume": os.path.join(root, "CheXpert-v1.0-small", "maecheckpoints", "best_mae.pth"),
    "logdir": os.path.join(root, "CheXpert-v1.0-small", "maelogs"),
    "checkpoints": os.path.join(root, "CheXpert-v1.0-small", "maecheckpoints"),
    "datadir": root,
    "lmdb": os.path.join(root, "CheXpert-v1.0-small", "lmdb"),
    "csv": os.path.join(root, "CheXpert-v1.0-small", "train.csv"),
    "batch_size": 96,
    "device": torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
    "accumulation": 1,
    "dirsToMake": [os.path.join(root, "CheXpert-v1.0-small", "maecheckpoints"), os.path.join(root, "CheXpert-v1.0-small", "maelogs")],
    "train_csv": os.path.join(root, "CheXpert-v1.0-small", "train_ready.csv"),
    "val_csv": os.path.join(root, "CheXpert-v1.0-small", "val_ready.csv"),
    "test_csv": os.path.join(root, "CheXpert-v1.0-small", "test_ready.csv"),
    "channels": 1, "mask_ratio": 0.75, "dropout": 0.25, "img_size": 384, "encoder_dim": 768,
    "mlp_dim": 3072, "decoder_dim": 512, "encoder_depth": 12, "encoder_head": 8, "decoder_depth": 8,
    "decoder_head": 8, "patch_size": 16
}
data/__init__.py ADDED
File without changes
data/dataset.py ADDED
# Standard library
import os
import io
import zipfile
import pickle
from pathlib import Path

# Data handling
import pandas as pd
import numpy as np

# PyTorch
import torch
from torch.utils.data import Dataset

# Image processing
from PIL import Image
import cv2

# Augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Progress bar (for precompute_all_masks)
from tqdm import tqdm


class OptimizedZipReader:
    """
    Fast ZIP file reader with LRU caching
    """
    def __init__(self, zip_path, cache_size=1000):
        """
        Args:
            zip_path: Path to ZIP file
            cache_size: Number of images to cache in RAM
        """
        self.zip_path = zip_path
        self.cache_size = cache_size
        self._zip_file = None  # Will be lazily initialized
        self._name_to_info = None

        # Cache
        self._cache = {}
        self._cache_order = []
        self._hits = 0
        self._misses = 0

    @property
    def zip_file(self):
        """Lazy initialization of ZIP file handle"""
        if self._zip_file is None:
            print(f"Opening ZIP file: {self.zip_path}")
            self._zip_file = zipfile.ZipFile(self.zip_path, 'r', allowZip64=True)

            # Build index on first access
            print("Building ZIP index...")
            self._name_to_info = {
                info.filename: info
                for info in self._zip_file.infolist()
            }
            print(f"✓ Indexed {len(self._name_to_info)} files")

        return self._zip_file

    def read_image(self, path):
        """
        Read image data with automatic caching

        Returns: bytes (image file data)
        """
        # Check cache first
        if path in self._cache:
            self._hits += 1
            # Refresh recency so eviction is true LRU rather than FIFO
            self._cache_order.remove(path)
            self._cache_order.append(path)
            return self._cache[path]

        # Cache miss - read from ZIP (this triggers lazy initialization)
        self._misses += 1
        img_data = self.zip_file.read(path)  # Uses property getter

        # Add to cache with LRU eviction
        if len(self._cache) >= self.cache_size:
            oldest = self._cache_order.pop(0)
            del self._cache[oldest]

        self._cache[path] = img_data
        self._cache_order.append(path)

        return img_data

    def get_cache_stats(self):
        """Return cache hit rate statistics"""
        total = self._hits + self._misses
        hit_rate = self._hits / total * 100 if total > 0 else 0
        return {
            'hits': self._hits,
            'misses': self._misses,
            'hit_rate': f"{hit_rate:.2f}%",
            'cache_size': len(self._cache)
        }

    def close(self):
        """Close ZIP file and clear cache"""
        if self._zip_file is not None:
            self._zip_file.close()
            self._zip_file = None
        self._cache.clear()
        self._cache_order.clear()
        self._name_to_info = None

+ class CheXpertDataset(Dataset):
111
+ """
112
+ CheXpert Dataset class
113
+
114
+ NEW: Returns 3-channel images: (img, img*mask, mask)
115
+ - Channel 0: Original grayscale image
116
+ - Channel 1: Masked image (lung region only)
117
+ - Channel 2: Binary lung mask
118
+
119
+ Args:
120
+ csv_path (str): Path to the CSV file (train.csv or valid.csv)
121
+ root_dir (str): Root directory of the CheXpert dataset
122
+ image_size (int): Target image size (default: 384)
123
+ augment (bool): Whether to apply augmentations (default: False)
124
+ use_frontal_only (bool): If True, only use frontal view images (default: True)
125
+ fill_uncertain (str): How to handle uncertain labels: 'zeros', 'ones', 'ignore' (default: 'zeros')
126
+ """
127
+
128
+ # 14 pathology classes in CheXpert
129
+ PATHOLOGIES = [
130
+ 'No Finding',
131
+ 'Enlarged Cardiomediastinum',
132
+ 'Cardiomegaly',
133
+ 'Lung Opacity',
134
+ 'Lung Lesion',
135
+ 'Edema',
136
+ 'Consolidation',
137
+ 'Pneumonia',
138
+ 'Atelectasis',
139
+ 'Pneumothorax',
140
+ 'Pleural Effusion',
141
+ 'Pleural Other',
142
+ 'Fracture',
143
+ 'Support Devices'
144
+ ]
145
+
146
+ def __init__(
147
+ self,
148
+ csv_path,
149
+ root_dir,
150
+ image_size=384,
151
+ augment=False,
152
+ use_frontal_only=False,
153
+ fill_uncertain='ignore',
154
+ lmdb_path=None,
155
+ zip_path=None,
156
+ zip_cache_size=1000,
157
+ mask_dir=None, domask=False
158
+ ):
159
+ self.root_dir = root_dir
160
+ self.image_size = image_size
161
+ self.augment = augment
162
+ self.fill_uncertain = fill_uncertain
163
+ self.env =None #lmdb.open(lmdb_path, readonly=True, lock=False) if lmdb_path else None
164
+ self._zip_path = zip_path
165
+ self._zip_cache_size = zip_cache_size
166
+ self._zip_reader_instance = None
167
+
168
+
169
+ # Read CSV file
170
+ self.df = pd.read_csv(csv_path)
171
+ for pathology in self.PATHOLOGIES:
172
+ if pathology in self.df.columns:
173
+ self.df[pathology] = pd.to_numeric(self.df[pathology], errors='coerce')
174
+
175
+ # Filter for frontal views only if specified
176
+ if use_frontal_only:
177
+ self.df = self.df[self.df['Frontal/Lateral'] == 'Frontal'].reset_index(drop=True)
178
+
179
+ # Handle uncertain labels (-1 values)
180
+ self._process_uncertain_labels()
181
+
182
+ # Setup augmentations
183
+ self.train_transform = self._get_train_transforms()
184
+ self.val_transform = self._get_val_transforms()
185
+
186
+ print(f"Loaded {len(self.df)} images from {csv_path}")
187
+ print(f"Image size: {image_size}x{image_size}")
188
+ print(f"Augmentation: {augment}")
189
+ print(f"Uncertain labels filled with: {fill_uncertain}")
190
+
191
+ if mask_dir and domask:
192
+ self.precompute_all_masks(mask_dir)
193
+
194
+ # Run this ONCE before training
195
+ def precompute_all_masks(self, save_dir):
196
+ os.makedirs(save_dir, exist_ok=True)
197
+ for idx in tqdm(range(len(self))):
198
+ img_path = os.path.join(self.root_dir,self.df.iloc[idx]['Path'])
199
+ part_path="/".join(self.df.iloc[idx]['Path'].split("/")[1:])
200
+ if self.zip_reader:
201
+ # Read image data from ZIP (no extraction!)
202
+ img_data = self.zip_reader.read_image(part_path)
203
+
204
+ # Open image from bytes in memory
205
+ image = Image.open(io.BytesIO(img_data)).convert('L')
206
+ else:
207
+ image = Image.open(img_path).convert('L')
208
+
209
+ image = np.array(image)
210
+
211
+ mask = chexpert_medsam_mask(image)
212
+ mask_path = os.path.join(save_dir, "_".join(self.df.iloc[idx]['Path'].split("/")[-3:]).replace('.jpg', '_mask.pt'))
213
+ os.makedirs(os.path.dirname(mask_path), exist_ok=True)
214
+ torch.save(mask, mask_path)
215
+ @property
216
+ def zip_reader(self):
217
+ """
218
+ Lazy property getter for ZIP reader
219
+
220
+ The ZIP file is only opened when first accessed, not during __init__.
221
+ This is useful when:
222
+ - Creating multiple dataset objects but only using some
223
+ - Saving memory during dataset setup
224
+ - Working with multiprocessing (each worker creates its own)
225
+ """
226
+ if self._zip_reader_instance is None and self._zip_path is not None:
227
+ self._zip_reader_instance = OptimizedZipReader(
228
+ self._zip_path,
229
+ cache_size=self._zip_cache_size
230
+ )
231
+ return self._zip_reader_instance
232
+
233
+ def _load_and_cache_image(self, img_path, idx):
234
+ """
235
+ Load image with automatic resizing and caching.
236
+ If resized version exists, load it. Otherwise, resize, save, and load.
237
+
238
+ Args:
239
+ img_path (str): Original image path from CSV
240
+ idx (int): Index for tracking
241
+
242
+ Returns:
243
+ np.ndarray: Loaded image (grayscale)
244
+ """
245
+ # Create cache directory structure
246
+ cache_dir = Path(self.root_dir) #/ f"cache_{self.image_size}"
247
+
248
+ # Preserve the relative path structure in cache
249
+ path_parts = list(Path(img_path).parts)
250
+ path_parts[-1]=f"{self.image_size}_{path_parts[-1]}"
251
+ relative_path = Path(*path_parts)
252
+ cached_path =relative_path.with_suffix('.jpg')
253
+
254
+ # Check if cached version exists
255
+ if cached_path.exists():
256
+ # Load cached image
257
+ image = Image.open(cached_path).convert('L')
258
+ image = np.array(image)
259
+
260
+ # Verify it's the correct size
261
+ if image.shape[0] == self.image_size and image.shape[1] == self.image_size:
262
+ return image
263
+
264
+ # Cache doesn't exist or wrong size - load original
265
+ original_path = img_path
266
+ image = Image.open(original_path).convert('L')
267
+
268
+ # Check if original is already target size
269
+ width, height = image.size
270
+
271
+ if width == self.image_size and height == self.image_size:
272
+ # Already correct size, just convert to array
273
+ return np.array(image)
274
+
275
+ # Resize image
276
+ image_resized = image.resize(
277
+ (self.image_size, self.image_size),
278
+ Image.LANCZOS
279
+ )
280
+
281
+ # Save to cache
282
+ cached_path.parent.mkdir(parents=True, exist_ok=True)
283
+ image_resized.save(cached_path, 'JPEG', quality=95, optimize=True)
284
+
285
+ return np.array(image_resized)
286
+
287
+ def _process_uncertain_labels(self):
288
+ """Process uncertain labels (-1) based on the chosen strategy."""
289
+ for pathology in self.PATHOLOGIES:
290
+ if pathology in self.df.columns:
291
+ if self.fill_uncertain == 'zeros':
292
+ # Map uncertain (-1) to negative (0)
293
+ self.df[pathology] = self.df[pathology].replace(-1, 0)
294
+ elif self.fill_uncertain == 'ones':
295
+ # Map uncertain (-1) to positive (1)
296
+ self.df[pathology] = self.df[pathology].replace(-1, 1)
297
+ elif self.fill_uncertain == 'ignore':
298
+ # Keep -1 as is (you'll need to handle this in loss function)
299
+ pass
300
+
301
+ # Fill NaN with 0 (negative)
302
+ self.df[pathology] = self.df[pathology].fillna(0)
303
+
304
+ def _get_train_transforms(self):
305
+ """Get training augmentations suitable for chest X-rays."""
306
+ import cv2
307
+ return A.Compose([
308
+ # Resize to target size
309
+ A.LongestMaxSize(max_size=self.image_size),
310
+ A.PadIfNeeded(self.image_size, self.image_size, border_mode=cv2.BORDER_CONSTANT, position='center'),
311
+
312
+ # Geometric augmentations (conservative for medical images)
313
+ A.HorizontalFlip(p=0.5),
314
+ A.Affine(
315
+ translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
316
+ scale=(0.9, 1.1),
317
+ rotate=(-10, 10),
318
+ fit_output=False,
319
+ p=0.5
320
+ ),
321
+
322
+ # Intensity augmentations
323
+ A.OneOf([
324
+ A.RandomBrightnessContrast(
325
+ brightness_limit=0.2,
326
+ contrast_limit=0.2,
327
+ p=1.0
328
+ ),
329
+ A.RandomGamma(gamma_limit=(80, 120), p=1.0),
330
+ A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=1.0),
331
+ ], p=0.5),
332
+
333
+ # Add slight blur to simulate different imaging conditions
334
+ A.OneOf([
335
+ A.GaussianBlur(blur_limit=(3, 5), p=1.0),
336
+ A.MedianBlur(blur_limit=3, p=1.0),
337
+ ], p=0.2),
338
+
339
+ # Add noise
340
+ A.GaussNoise(p=0.2),
341
+
342
+ # Normalize to [0, 1]
343
+ A.Normalize(
344
+ mean=[0.5],
345
+ std=[0.5],
346
+ max_pixel_value=255.0
347
+ ),
348
+
349
+ ToTensorV2()
350
+ ])
351
+
352
+ def _get_val_transforms(self):
353
+ """Get validation/test transforms (no augmentation)."""
354
+ return A.Compose([
355
+ A.LongestMaxSize(max_size=self.image_size),
356
+ A.PadIfNeeded(self.image_size, self.image_size, border_mode=cv2.BORDER_CONSTANT, position='center'),
357
+ A.Normalize(
358
+ mean=[0.5],
359
+ std=[0.5],
360
+ max_pixel_value=255.0
361
+ ),
362
+ ToTensorV2()
363
+ ])
364
+
365
+ def __len__(self):
366
+ return len(self.df)
367
+
368
+ def __del__(self):
369
+ """Close ZIP when done"""
370
+ if hasattr(self, 'zip_reader'):
371
+ self.zip_reader.close()
372
+
373
+ def __getitem__(self, idx):
374
+ if self.env:
375
+ with self.env.begin() as txn:
376
+ # Retrieve serialized data
377
+ data = txn.get(str(idx).encode())
378
+ sample = pickle.loads(data)
379
+ return sample
380
+ else:
381
+ # Get image path
382
+ img_path = os.path.join(self.root_dir,self.df.iloc[idx]['Path'])
383
+ #image = self._load_and_cache_image(img_path, idx)
384
+ # Load image
385
+ #image = Image.open(img_path).convert('L') # Convert to grayscale
386
+
387
+ part_path="/".join(self.df.iloc[idx]['Path'].split("/")[1:])
388
+ if self.zip_reader:
389
+ # Read image data from ZIP (no extraction!)
390
+ img_data = self.zip_reader.read_image(part_path)
391
+
392
+ # Open image from bytes in memory
393
+ image = Image.open(io.BytesIO(img_data)).convert('L')
394
+ else:
395
+ image = Image.open(img_path).convert('L')
396
+
397
+ image = np.array(image)
398
+
399
+
400
+ # Load pre-computed mask
401
+ #mask_path = os.path.join(self.mask_dir, "_".join(self.df.iloc[idx]['Path'].split("/")[-3:]).replace('.jpg', '_mask.pt'))
402
+ #masked_img = torch.load(mask_path)
403
+ # Apply transforms to BOTH image and mask together
404
+ if self.augment:
405
+ # Augmentation applies to both image and mask
406
+ transformed = self.train_transform(image=image)
407
+ image_transformed = transformed['image'] # (1, H, W) tensor, normalized
408
+ #masked_img=transformed['mask']
409
+ # (H, W) tensor
410
+ else:
411
+ transformed = self.val_transform(image=image)
412
+ image_transformed = transformed['image'] # (1, H, W) tensor, normalized
413
+ #masked_img=transformed['mask']
414
+
415
+ # Expand dimensions to match
416
+ image_1ch = image_transformed # (1, H, W)
417
+ masked_img = image_transformed
418
+
419
+ # Get labels for all pathologies
420
+ labels = []
421
+ for pathology in self.PATHOLOGIES:
422
+ if pathology in self.df.columns:
423
+ label = self.df.iloc[idx][pathology]
424
+ labels.append(float(label) if not pd.isna(label) else 0.0)
425
+ else:
426
+ labels.append(0.0)
427
+
428
+ labels = torch.tensor(labels, dtype=torch.float32)
429
+
430
+ # Get additional metadata
431
+ metadata = {
432
+ 'patient_id': self.df.iloc[idx]['Path'].split('/')[2], # Extract patient ID from path
433
+ 'study_id': self.df.iloc[idx]['Path'].split('/')[3], # Extract study ID from path
434
+ 'view': self.df.iloc[idx]['Frontal/Lateral'],
435
+ 'sex': self.df.iloc[idx]['Sex'] if 'Sex' in self.df.columns else 'Unknown',
436
+ 'age': self.df.iloc[idx]['Age'] if 'Age' in self.df.columns else -1,
437
+ 'path': self.df.iloc[idx]['Path']
438
+ }
439
+
440
+ return {
441
+ 'image': image_1ch,
442
+ 'labels': labels,
443
+ 'metadata': metadata
444
+ }
445
+
446
+ def get_label_names(self):
447
+ """Return list of pathology label names."""
448
+ return self.PATHOLOGIES
449
+
450
+ def get_label_distribution(self):
451
+ """Get distribution of positive labels for each pathology."""
452
+ distribution = {}
453
+ for pathology in self.PATHOLOGIES:
454
+ if pathology in self.df.columns:
455
+ positive_count = (self.df[pathology] == 1.0).sum()
456
+ distribution[pathology] = {
457
+ 'positive': int(positive_count),
458
+ 'percentage': round(positive_count / len(self.df) * 100, 2)
459
+ }
460
+ return distribution
461
+
462
+ def get_class_weights(self):
463
+ """
464
+ OPTIMIZED: Vectorized class weights calculation
465
+ """
466
+ weights = []
467
+ for pathology in self.PATHOLOGIES:
468
+ if pathology in self.df.columns:
469
+ # Vectorized counting (much faster than iterating)
470
+ values = self.df[pathology].values
471
+ pos = np.sum(values == 1.0)
472
+ neg = np.sum(values == 0.0)
473
+ weight = neg / pos if pos > 0 else 1.0
474
+ weights.append(weight)
475
+ return torch.tensor(weights, dtype=torch.float32)
476
+
477
+ def get_sample_weights(self):
478
+ """
479
+ OPTIMIZED: Vectorized sample weights calculation
480
+
481
+ Performance: ~1000x faster than original
482
+ Original: 15-30 seconds for 200k samples
483
+ This: 0.01-0.05 seconds for 200k samples
484
+ """
485
+ # Get class weights as numpy array
486
+ class_weights = self.get_class_weights().numpy()
487
+
488
+ # Get all labels as numpy array in ONE vectorized operation
489
+ labels_array = self.df[self.PATHOLOGIES].values.astype(np.float32)
490
+
491
+ # Create weighted labels matrix: where label=1, use class_weight, else -inf
492
+ # Shape: (n_samples, n_classes)
493
+ weighted_labels = np.where(
494
+ labels_array == 1.0,
495
+ class_weights,
496
+ -np.inf # Use -inf instead of 0 so max will only consider positive labels
497
+ )
498
+
499
+ # For each sample, find the maximum class weight of its positive labels
500
+ # If a sample has no positive labels, max will be -inf, which we'll replace with 1.0
501
+ sample_weights = np.max(weighted_labels, axis=1)
502
+ sample_weights = np.where(
503
+ np.isinf(sample_weights),
504
+ 1.0, # Samples with no positive labels get weight 1.0
505
+ sample_weights
506
+ )
507
+
508
+ return torch.tensor(sample_weights, dtype=torch.float32)
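The row-wise-max rule in `get_sample_weights` can be checked on a toy array. This is a minimal standalone sketch of the same logic; the weights and labels below are illustrative, not taken from the dataset.

```python
import numpy as np

# Toy illustration of the vectorized sample-weight rule above:
# each sample gets the largest class weight among its positive labels,
# and samples with no positive labels fall back to 1.0.
class_weights = np.array([2.0, 10.0, 5.0], dtype=np.float32)
labels = np.array([
    [1, 0, 1],   # positives in classes 0 and 2 -> max(2.0, 5.0) = 5.0
    [0, 1, 0],   # positive in class 1          -> 10.0
    [0, 0, 0],   # no positives                 -> fallback 1.0
], dtype=np.float32)

weighted = np.where(labels == 1.0, class_weights, -np.inf)
sample_weights = np.max(weighted, axis=1)
sample_weights = np.where(np.isinf(sample_weights), 1.0, sample_weights)
print(sample_weights.tolist())  # [5.0, 10.0, 1.0]
```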
data/splitter.py ADDED
@@ -0,0 +1,347 @@
+ # Standard library
+ import os
+ from pathlib import Path
+
+ # Data handling
+ import pandas as pd
+ import numpy as np
+
+ # Machine learning
+ from sklearn.model_selection import train_test_split
+
+ class CheXpertDataSplitter:
+     """
+     Stratified train/validation/test splitter for the CheXpert dataset.
+
+     Handles:
+     - Patient-level splitting (prevents data leakage)
+     - Multi-label stratification
+     - Class-imbalance awareness
+     - Study-level grouping
+     """
+
+     PATHOLOGIES = [
+         'No Finding',
+         'Enlarged Cardiomediastinum',
+         'Cardiomegaly',
+         'Lung Opacity',
+         'Lung Lesion',
+         'Edema',
+         'Consolidation',
+         'Pneumonia',
+         'Atelectasis',
+         'Pneumothorax',
+         'Pleural Effusion',
+         'Pleural Other',
+         'Fracture',
+         'Support Devices'
+     ]
+
+     def __init__(self, csv_path, val_size=0.15, test_size=0.15, random_state=42,
+                  use_frontal_only=True, fill_uncertain='zeros', root=None):
+         """
+         Initialize the splitter.
+
+         Args:
+             csv_path: Path to train.csv from CheXpert-small
+             val_size: Validation set proportion (default: 0.15)
+             test_size: Test set proportion (default: 0.15)
+             random_state: Random seed for reproducibility
+             use_frontal_only: Use only frontal-view images
+             fill_uncertain: How to handle uncertain labels ('zeros', 'ones', 'ignore')
+             root: Optional dataset root directory
+         """
+         self.csv_path = csv_path
+         self.val_size = val_size
+         self.test_size = test_size
+         self.random_state = random_state
+         self.use_frontal_only = use_frontal_only
+         self.fill_uncertain = fill_uncertain
+         self.root = root
+
+         print("=" * 80)
+         print("CheXpert Data Splitter - Preventing Data Leakage & Class Bias")
+         print("=" * 80)
+
+     def load_and_preprocess(self):
+         """Load and preprocess the dataset."""
+         print("\n[1/5] Loading data...")
+         self.df = pd.read_csv(self.csv_path)
+         print(f"   Loaded {len(self.df)} images")
+
+         #self.df = self.df[self.df["Path"].apply(os.path.exists)]
+
+         # Filter for frontal views only
+         if self.use_frontal_only:
+             initial_count = len(self.df)
+             self.df = self.df[self.df['Frontal/Lateral'] == 'Frontal'].reset_index(drop=True)
+             print(f"   Filtered to frontal views: {len(self.df)} images ({initial_count - len(self.df)} removed)")
+
+         # Extract patient and study IDs from the path
+         print("\n[2/5] Extracting patient and study IDs...")
+         self.df['patient_id'] = self.df['Path'].apply(lambda x: x.split('/')[2])
+         self.df['study_id'] = self.df['Path'].apply(lambda x: x.split('/')[3])
+
+         n_patients = self.df['patient_id'].nunique()
+         n_studies = self.df['study_id'].nunique()
+         print(f"   Unique patients: {n_patients}")
+         print(f"   Unique studies: {n_studies}")
+         print(f"   Images per patient (avg): {len(self.df) / n_patients:.2f}")
+
+         # Process uncertain labels
+         print("\n[3/5] Processing uncertain labels...")
+         self._process_uncertain_labels()
+
+         return self.df
+
+     def _process_uncertain_labels(self):
+         """Process uncertain labels (-1) according to the chosen strategy."""
+         for pathology in self.PATHOLOGIES:
+             if pathology in self.df.columns:
+                 if self.fill_uncertain == 'zeros':
+                     self.df[pathology] = self.df[pathology].replace(-1, 0)
+                 elif self.fill_uncertain == 'ones':
+                     self.df[pathology] = self.df[pathology].replace(-1, 1)
+                 elif self.fill_uncertain == 'ignore':
+                     pass  # Keep -1 as is
+
+                 # Fill NaN with 0
+                 self.df[pathology] = self.df[pathology].fillna(0)
+
+         print(f"   Uncertain labels strategy: {self.fill_uncertain}")
+
+     def create_stratification_groups(self):
+         """
+         Create stratification groups based on multi-label combinations.
+         Uses patient-level aggregation to prevent data leakage.
+         """
+         print("\n[4/5] Creating stratification groups (patient-level)...")
+
+         # Group by patient and aggregate labels
+         patient_groups = self.df.groupby('patient_id').agg({
+             **{pathology: 'max' for pathology in self.PATHOLOGIES if pathology in self.df.columns},
+             'study_id': 'first',  # Keep one study_id for reference
+             'Sex': 'first',
+             'Age': 'first'
+         }).reset_index()
+
+         # Create a label signature for each patient:
+         # a binary string encoding which conditions are present.
+         def create_label_signature(row):
+             signature = []
+             for pathology in self.PATHOLOGIES:
+                 if pathology in patient_groups.columns:
+                     signature.append(str(int(row[pathology])))
+             return ''.join(signature)
+
+         patient_groups['label_signature'] = patient_groups.apply(create_label_signature, axis=1)
+
+         # Group rare label combinations together
+         signature_counts = patient_groups['label_signature'].value_counts()
+         rare_threshold = max(5, int(len(patient_groups) * 0.001))  # At least 5 patients or 0.1%
+
+         def get_stratification_group(signature):
+             if signature_counts[signature] < rare_threshold:
+                 return 'RARE_COMBINATION'
+             return signature
+
+         patient_groups['stratification_group'] = patient_groups['label_signature'].apply(get_stratification_group)
+
+         # Print distribution statistics
+         print("\n   Patient-level label distribution:")
+         for pathology in self.PATHOLOGIES:
+             if pathology in patient_groups.columns:
+                 positive_count = (patient_groups[pathology] == 1).sum()
+                 percentage = positive_count / len(patient_groups) * 100
+                 print(f"     {pathology:30s}: {positive_count:5d} ({percentage:5.2f}%)")
+
+         unique_groups = patient_groups['stratification_group'].nunique()
+         print(f"\n   Unique stratification groups: {unique_groups}")
+         print(f"   Rare combinations grouped: {(patient_groups['stratification_group'] == 'RARE_COMBINATION').sum()}")
+
+         return patient_groups
+
+     def perform_split(self, patient_groups):
+         """Perform a stratified train/validation/test split at the patient level."""
+         print("\n[5/5] Performing stratified patient-level split...")
+
+         stratification_labels = patient_groups['stratification_group'].values
+
+         # ---- train / (val + test) ----
+         train_patients, valtest_patients = train_test_split(
+             patient_groups['patient_id'].values,
+             test_size=self.val_size + self.test_size,
+             stratify=stratification_labels,
+             random_state=self.random_state
+         )
+
+         # ---- val / test from the remaining pool ----
+         remaining_labels = patient_groups.set_index('patient_id').loc[valtest_patients]['stratification_group'].values
+         val_patients, test_patients = train_test_split(
+             valtest_patients,
+             test_size=self.test_size / (self.val_size + self.test_size),  # proportion of the val+test pool
+             stratify=remaining_labels,
+             random_state=self.random_state
+         )
+
+         print(f"   Train patients: {len(train_patients)}")
+         print(f"   Val patients:   {len(val_patients)}")
+         print(f"   Test patients:  {len(test_patients)}")
+
+         # Split the full dataframe
+         train_df = self.df[self.df['patient_id'].isin(train_patients)].copy()
+         val_df = self.df[self.df['patient_id'].isin(val_patients)].copy()
+         test_df = self.df[self.df['patient_id'].isin(test_patients)].copy()
+
+         # ---- leakage check (train vs val vs test) ----
+         sets = [('train', train_df), ('val', val_df), ('test', test_df)]
+         for i, (name_i, df_i) in enumerate(sets):
+             for name_j, df_j in sets[i + 1:]:
+                 overlap = set(df_i['patient_id']).intersection(set(df_j['patient_id']))
+                 if overlap:
+                     raise ValueError(f"Data leakage between {name_i} and {name_j}: {len(overlap)} patients overlap")
+         print("\n   No patient overlap – leakage prevented!")
+
+         return train_df, val_df, test_df
+
+     def run(self, output_dir='.', save_test=True):
+         self.load_and_preprocess()
+         patient_groups = self.create_stratification_groups()
+         train_df, val_df, test_df = self.perform_split(patient_groups)
+
+         self.verify_split_quality(train_df, val_df)
+         # Optional: also verify train vs test (the same function works for any two dataframes)
+         print("\n--- Train vs Test distribution check ---")
+         self.verify_split_quality(train_df, test_df)
+
+         train_path, val_path = self.save_splits(train_df, val_df, output_dir)
+         if save_test:
+             test_path = self.save_test_split(test_df, output_dir)
+         else:
+             test_path = None
+
+         print("\n" + "=" * 80)
+         print("Split Complete! (train / val / test)")
+         print("=" * 80)
+         return train_path, val_path, test_path
+
+     def save_test_split(self, test_df, output_dir):
+         output_dir = Path(output_dir)
+         output_dir.mkdir(exist_ok=True)
+         test_path = output_dir / 'test_ready.csv'
+
+         cols_to_drop = ['patient_id', 'study_id']
+         test_clean = test_df.drop(columns=[c for c in cols_to_drop if c in test_df.columns])
+         test_clean.to_csv(test_path, index=False)
+
+         print(f"Test set : {test_path} ({len(test_clean)} images)")
+         return test_path
+
+     def verify_split_quality(self, train_df, val_df):
+         """Verify the quality of the split by comparing label distributions."""
+         print("\n" + "=" * 80)
+         print("Split Quality Verification")
+         print("=" * 80)
+
+         print(f"\n{'Pathology':<30s} {'Train %':>10s} {'Val %':>10s} {'Difference':>12s}")
+         print("-" * 80)
+
+         max_diff = 0
+         for pathology in self.PATHOLOGIES:
+             if pathology in train_df.columns:
+                 train_pos = (train_df[pathology] == 1).sum() / len(train_df) * 100
+                 val_pos = (val_df[pathology] == 1).sum() / len(val_df) * 100
+                 diff = abs(train_pos - val_pos)
+                 max_diff = max(max_diff, diff)
+
+                 print(f"{pathology:<30s} {train_pos:>9.2f}% {val_pos:>9.2f}% {diff:>11.2f}%")
+
+         print("-" * 80)
+         print(f"Maximum distribution difference: {max_diff:.2f}%")
+
+         if max_diff < 2.0:
+             print("✓ Excellent stratification (< 2% difference)")
+         elif max_diff < 5.0:
+             print("✓ Good stratification (< 5% difference)")
+         else:
+             print("⚠ Warning: Large distribution differences detected")
+
+         # Check for class imbalance
+         print("\n" + "=" * 80)
+         print("Class Imbalance Analysis (Train Set)")
+         print("=" * 80)
+
+         imbalance_ratios = []
+         for pathology in self.PATHOLOGIES:
+             if pathology in train_df.columns:
+                 pos = (train_df[pathology] == 1).sum()
+                 neg = (train_df[pathology] == 0).sum()
+                 if pos > 0:
+                     ratio = neg / pos
+                     imbalance_ratios.append(ratio)
+                     severity = "Low" if ratio < 5 else "Medium" if ratio < 20 else "High"
+                     print(f"{pathology:<30s} Ratio: {ratio:>6.2f}:1 [{severity:>6s} imbalance]")
+
+         if imbalance_ratios:
+             avg_imbalance = np.mean(imbalance_ratios)
+             print(f"\nAverage imbalance ratio: {avg_imbalance:.2f}:1")
+
+     def save_splits(self, train_df, val_df, output_dir='.'):
+         """Save train and validation splits to CSV files."""
+         output_dir = Path(output_dir)
+         output_dir.mkdir(exist_ok=True)
+
+         train_path = output_dir / 'train_ready.csv'
+         val_path = output_dir / 'val_ready.csv'
+
+         # Remove temporary columns used for splitting
+         columns_to_drop = ['patient_id', 'study_id']
+         train_df_clean = train_df.drop(columns=[col for col in columns_to_drop if col in train_df.columns])
+         val_df_clean = val_df.drop(columns=[col for col in columns_to_drop if col in val_df.columns])
+
+         train_df_clean.to_csv(train_path, index=False)
+         val_df_clean.to_csv(val_path, index=False)
+
+         print("\n" + "=" * 80)
+         print("Files Saved Successfully")
+         print("=" * 80)
+         print(f"Train set: {train_path} ({len(train_df_clean)} images)")
+         print(f"Val set:   {val_path} ({len(val_df_clean)} images)")
+
+         return train_path, val_path
+
+ # Main execution
+ if __name__ == "__main__":
+     root = "/content/drive/MyDrive"
+     # Configuration
+     CHEXPERT_CSV = os.path.join(root, "CheXpert-v1.0-small", "train.csv")  # Adjust path as needed
+     OUTPUT_DIR = os.path.join(root, "CheXpert-v1.0-small")
+     VAL_SIZE = 0.15
+     RANDOM_STATE = 42
+     USE_FRONTAL_ONLY = True
+     FILL_UNCERTAIN = 'zeros'  # Options: 'zeros', 'ones', 'ignore'
+
+     # Create the splitter
+     splitter = CheXpertDataSplitter(
+         csv_path=CHEXPERT_CSV,
+         val_size=VAL_SIZE,
+         test_size=VAL_SIZE,
+         random_state=RANDOM_STATE,
+         use_frontal_only=USE_FRONTAL_ONLY,
+         fill_uncertain=FILL_UNCERTAIN,
+         root=OUTPUT_DIR
+     )
+
+     # Run the split (reuse existing files if present)
+     if os.path.exists(os.path.join(root, "CheXpert-v1.0-small", "train_ready.csv")) and os.path.exists(os.path.join(root, "CheXpert-v1.0-small", "val_ready.csv")):
+         train_path = os.path.join(root, "CheXpert-v1.0-small", "train_ready.csv")
+         val_path = os.path.join(root, "CheXpert-v1.0-small", "val_ready.csv")
+         test_path = os.path.join(root, "CheXpert-v1.0-small", "test_ready.csv")
+     else:
+         train_path, val_path, test_path = splitter.run(output_dir=OUTPUT_DIR)
+
+     print("\nYou can now use these files with your CheXpertDataset class:")
+     print(f"  train_dataset = CheXpertDataset('{train_path}', root_dir='...', augment=True)")
+     print(f"  val_dataset   = CheXpertDataset('{val_path}', root_dir='...', augment=False)")
+     print(f"  test_dataset  = CheXpertDataset('{test_path}', root_dir='...', augment=False)")
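The two-stage split in `perform_split` first carves off `val_size + test_size` of the patients, then splits that pool with `test_size / (val_size + test_size)` as the second-stage fraction. The arithmetic behind that choice can be checked with plain numbers (the patient count here is made up for illustration):

```python
# Numeric check of the two-stage split proportions used in perform_split:
# carving off val+test first, then splitting the pool by
# test_size / (val_size + test_size), reproduces the requested fractions.
n = 1000
val_size, test_size = 0.15, 0.15

pool = n * (val_size + test_size)                   # patients reserved for val+test
test = pool * (test_size / (val_size + test_size))  # second-stage test fraction
val = pool - test                                   # remainder of the pool

print(int(n - pool), int(val), int(test))           # 700 150 150
```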
loss/__init__.py ADDED
File without changes
loss/mae_loss.py ADDED
@@ -0,0 +1,10 @@
+ import torch
+ import torch.nn as nn
+
+ def mae_loss(pred, target, mask):
+     # pred/target: (B, N, P), mask: (B, N) with 1 = masked
+     mask = mask.unsqueeze(-1).float()  # (B, N, 1)
+     loss = (pred - target) ** 2
+     loss = (loss * mask).sum() / mask.sum().clamp_min(1.0)
+     return loss
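One property worth checking is that only masked patches contribute to the loss. The sketch below uses a standalone copy of `mae_loss` (so it runs on its own) with toy tensors; note the normalizer is the number of masked *patches*, so each masked patch contributes the sum of its per-element squared errors:

```python
import torch

# Standalone copy of the mae_loss above, to sanity-check that
# unmasked patches do not contribute to the loss.
def mae_loss(pred, target, mask):
    mask = mask.unsqueeze(-1).float()   # (B, N, 1)
    loss = (pred - target) ** 2         # (B, N, P)
    return (loss * mask).sum() / mask.sum().clamp_min(1.0)

pred = torch.zeros(1, 2, 4)
target = torch.ones(1, 2, 4)            # squared error of 1 per element

only_first = torch.tensor([[1.0, 0.0]])  # mask only patch 0
loss = mae_loss(pred, target, only_first)
# 4 error elements on the masked patch / 1 masked patch = 4.0
print(loss.item())
```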
models/__init__.py ADDED
File without changes
models/mae.py ADDED
@@ -0,0 +1,177 @@
+ import math
+
+ import torch
+ import torch.nn as nn
+
+ def patchify(x, patch_size=8):
+     b, c, h, w = x.shape
+     assert h % patch_size == 0 and w % patch_size == 0, "Image size must be divisible by patch_size"
+     th = h // patch_size
+     tw = w // patch_size
+
+     out = x.reshape(b, c, th, patch_size, tw, patch_size)
+     out = out.permute(0, 2, 4, 1, 3, 5).contiguous()
+     out = out.view(b, th * tw, c * (patch_size ** 2))
+     return out
+
+ def unpatchify(x, patch_size=8):
+     b, z, p = x.shape
+     c = p // (patch_size ** 2)
+     th = int(math.sqrt(z))
+     tw = th
+     h = th * patch_size
+     w = tw * patch_size
+     x = x.view(b, th, tw, c, patch_size, patch_size)
+     x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
+     out = x.view(b, c, h, w)
+     return out
+
+ def random_mask(x, mask_ratio=0.75):
+     b, n, p = x.shape
+     len_keep = int(n * (1 - mask_ratio))
+     noise = torch.rand(b, n, device=x.device)
+     ids_shuffle = torch.argsort(noise, dim=1)
+     ids_restore = torch.argsort(ids_shuffle, dim=1)
+     ids_keep = ids_shuffle[:, :len_keep]
+     x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, p))
+     mask = torch.ones(b, n, device=x.device)
+     mask[:, :len_keep] = 0
+     mask = torch.gather(mask, dim=1, index=ids_restore)
+     return x_masked, mask, ids_restore, ids_keep
+
+ def mae_loss(pred, target, mask):
+     # pred/target: (B, N, P), mask: (B, N) with 1 = masked
+     mask = mask.unsqueeze(-1).float()  # (B, N, 1)
+     loss = (pred - target) ** 2
+     loss = (loss * mask).sum() / mask.sum().clamp_min(1.0)
+     return loss
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, num_patches, hidden_dim=768):
+         super().__init__()
+         self.pos_embed = nn.Parameter(torch.empty(1, num_patches, hidden_dim))
+         nn.init.trunc_normal_(self.pos_embed, std=0.02)
+
+     def forward(self, x, visible_indices):
+         # x: (B, len_keep, D); visible_indices: (B, len_keep)
+         B, L, D = x.shape
+         # Expand the embedding table to (B, N, D)
+         pos = self.pos_embed.expand(B, -1, -1)
+         # Build the gather index (B, L, D)
+         idx = visible_indices.unsqueeze(-1).expand(B, L, pos.size(-1))
+         visible_pos = torch.gather(pos, 1, idx)  # (B, L, D)
+         return x + visible_pos
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, hidden_dim, mlp_dim, num_heads, dropout):
+         super().__init__()
+         self.layernorm1 = nn.LayerNorm(hidden_dim)
+         self.multihead = nn.MultiheadAttention(batch_first=True, embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout)
+         self.layernorm2 = nn.LayerNorm(hidden_dim)
+         self.mlp = nn.Sequential(
+             nn.Linear(hidden_dim, mlp_dim), nn.GELU(), nn.Dropout(dropout),
+             nn.Linear(mlp_dim, hidden_dim), nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         # Pre-norm attention with residual connection
+         residual = x
+         x = self.layernorm1(x)
+         attn, _ = self.multihead(x, x, x)
+         x = residual + attn
+         # Pre-norm MLP with residual connection
+         residual = x
+         x = self.layernorm2(x)
+         x = self.mlp(x)
+         x = residual + x
+         return x
+
+ class MAEEncoder(nn.Module):
+     """
+     Encoder that sees only the visible (non-masked) patches.
+     patch_dim is the flattened patch dimension: c * patch_size**2.
+     """
+     def __init__(self, patch_dim, num_patches=(384 // 8) ** 2, hidden_dim=768, mlp_dim=768 * 4,
+                  num_heads=8, depth=12, dropout=0.25, mask_ratio=0.75, patch_size=8):
+         super().__init__()
+         self.mask_ratio = mask_ratio
+         self.patch_size = patch_size
+         self.patch_embed = nn.Linear(patch_dim, hidden_dim)
+         self.pos_embed = PositionalEncoding(num_patches=num_patches, hidden_dim=hidden_dim)
+         self.transformer = nn.ModuleList([
+             TransformerBlock(hidden_dim=hidden_dim, mlp_dim=mlp_dim, num_heads=num_heads, dropout=dropout)
+             for _ in range(depth)
+         ])
+
+         self._init_weights()
+
+     def _init_weights(self):
+         for m in self.modules():
+             if isinstance(m, nn.Linear):
+                 nn.init.trunc_normal_(m.weight, std=0.02)
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+
+     def forward(self, x_in):
+         x_p = patchify(x_in, self.patch_size)
+         x_masked, mask, ids_restore, ids_keep = random_mask(x_p, self.mask_ratio)
+         x = self.patch_embed(x_masked)
+         x = self.pos_embed(x, ids_keep)
+         for attn_layer in self.transformer:
+             x = attn_layer(x)
+         return x, mask, ids_keep, ids_restore
+
+ class MAEDecoder(nn.Module):
+     def __init__(self, c, num_patches, patch_size, encoder_dim, decoder_dim, decoder_depth, mlp_dim, num_heads, dropout):
+         super().__init__()
+         self.num_patches = num_patches
+         self.encoder_dim = encoder_dim
+         self.decoder_dim = decoder_dim
+         self.mask_token = nn.Parameter(torch.empty(1, 1, decoder_dim))
+         self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim)
+         self.pos_embed = nn.Parameter(torch.empty(1, num_patches, decoder_dim))
+         self.transformer = nn.ModuleList([
+             TransformerBlock(hidden_dim=decoder_dim, mlp_dim=mlp_dim, num_heads=num_heads, dropout=dropout)
+             for _ in range(decoder_depth)
+         ])
+         self.layernorm = nn.LayerNorm(decoder_dim)
+         self.pred = nn.Linear(decoder_dim, c * (patch_size ** 2))
+
+         self._init_weights()
+
+     def _init_weights(self):
+         for m in self.modules():
+             if isinstance(m, nn.Linear):
+                 nn.init.trunc_normal_(m.weight, std=0.02)
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+         nn.init.trunc_normal_(self.pos_embed, std=0.02)
+         nn.init.trunc_normal_(self.mask_token, std=0.02)
+
+     def forward(self, x, ids_keep, ids_restore):
+         b, n, p = x.shape
+         xdec = self.enc_to_dec(x)
+         len_keep = xdec.size(1)
+         num_patches = ids_restore.size(1)
+         num_mask = num_patches - len_keep
+
+         # Append mask tokens, then unshuffle back to the original patch order
+         mask_token = self.mask_token.expand(b, num_mask, -1)
+         x_ = torch.cat([xdec, mask_token], dim=1)
+         x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, x_.size(-1)))
+         x_ = x_ + self.pos_embed
+         for block in self.transformer:
+             x_ = block(x_)
+         x_ = self.layernorm(x_)
+         out = self.pred(x_)
+         return out
+
+ class MaskedAutoEncoder(nn.Module):
+     def __init__(self, c=1, mask_ratio=0.75, dropout=0.25, img_size=384, encoder_dim=768, mlp_dim=3072,
+                  decoder_dim=512, encoder_depth=12, encoder_head=8, decoder_depth=8, decoder_head=8, patch_size=8):
+         super().__init__()
+         self.patch_size = patch_size
+         self.encoder = MAEEncoder(patch_dim=c * (patch_size ** 2), num_patches=(img_size // patch_size) ** 2,
+                                   hidden_dim=encoder_dim, mlp_dim=mlp_dim, num_heads=encoder_head,
+                                   depth=encoder_depth, dropout=dropout, mask_ratio=mask_ratio, patch_size=patch_size)
+         self.decoder = MAEDecoder(c, num_patches=(img_size // patch_size) ** 2, patch_size=patch_size,
+                                   encoder_dim=encoder_dim, decoder_dim=decoder_dim, decoder_depth=decoder_depth,
+                                   mlp_dim=mlp_dim, num_heads=decoder_head, dropout=dropout)
+
+     def forward(self, x):
+         b, c, h, w = x.shape
+         encoded, mask, ids_keep, ids_restore = self.encoder(x)
+         decoded = self.decoder(encoded, ids_keep, ids_restore)
+
+         xpatched = patchify(x, self.patch_size)
+         return xpatched, decoded, mask
+
+     @staticmethod
+     def testme():
+         img = torch.rand(1, 1, 384, 384)
+         mae = MaskedAutoEncoder()
+         a, b, c = mae(img)
+         print(a.shape)
+         print(b.shape)
+         print(c.shape)
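`patchify` and `unpatchify` are meant to be exact inverses, which is easy to verify on a small tensor. The sketch below uses standalone copies of the two functions above so it runs on its own:

```python
import math
import torch

# Standalone copies of the patchify/unpatchify pair above, used to check
# that they are exact inverses of each other.
def patchify(x, patch_size=8):
    b, c, h, w = x.shape
    th, tw = h // patch_size, w // patch_size
    out = x.reshape(b, c, th, patch_size, tw, patch_size)
    out = out.permute(0, 2, 4, 1, 3, 5).contiguous()
    return out.view(b, th * tw, c * patch_size ** 2)

def unpatchify(x, patch_size=8):
    b, z, p = x.shape
    c = p // patch_size ** 2
    th = int(math.sqrt(z))
    x = x.view(b, th, th, c, patch_size, patch_size)
    x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
    return x.view(b, c, th * patch_size, th * patch_size)

img = torch.arange(2 * 1 * 16 * 16, dtype=torch.float32).view(2, 1, 16, 16)
patches = patchify(img, patch_size=8)
print(patches.shape)                             # torch.Size([2, 4, 64])
print(torch.equal(unpatchify(patches, 8), img))  # True
```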
notebooks/chexpert_mae.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,29 @@
+ # Core Deep Learning
+ torch>=2.0.0
+ torchvision>=0.15.0
+
+ # Data Processing
+ numpy>=1.24.0
+ pandas>=2.0.0
+ scikit-learn>=1.3.0
+
+ # Image Processing
+ Pillow>=10.0.0
+ opencv-python>=4.8.0
+ albumentations>=1.3.1
+
+ # Visualization
+ matplotlib>=3.7.0
+ seaborn>=0.12.0
+
+ # Utilities
+ tqdm>=4.65.0
+
+ # Jupyter (optional - for notebooks)
+ jupyter>=1.0.0
+ ipykernel>=6.25.0
+ ipywidgets>=8.1.0
+
+ # Additional utilities (if needed)
+ # lmdb>=1.4.0  # Uncomment if using LMDB for caching
+ # tensorboard>=2.13.0  # Uncomment if using TensorBoard logging
trainer/__init__.py ADDED
File without changes
trainer/trainer.py ADDED
@@ -0,0 +1,14 @@
+ from trainer.utils import *
+ from configs.configs import root, mae_config
+
+ def main():
+     try:
+         print("Training MAE")
+         trainer = MAETrainer(mae_config)
+         trainer.test()
+     except Exception:
+         import traceback
+         traceback.print_exc()
+
+ if __name__ == "__main__":
+     main()
trainer/utils.py ADDED
@@ -0,0 +1,348 @@
1
+ from data.dataset import CheXpertDataset
2
+ from loss.mae_loss import mae_loss
3
+ from models.mae import *
4
+ from torch.utils.data import DataLoader
5
+ import json
6
+ import os
7
+ import io
8
+ import sys
9
+
10
+ class TeeFile:
11
+ """
12
+ File-like object that writes to multiple streams (e.g., stdout and a file)
13
+ Automatically handles string paths by opening them as files.
14
+
15
+ Usage:
16
+ # This now works with both file objects and paths
17
+ tee = TeeFile(sys.stdout, "/path/to/log.txt")
18
+ print("Hello", file=tee) # Writes to both stdout and the file
19
+ """
20
+ def __init__(self, *file_objects_or_paths):
21
+ """
22
+ Args:
23
+ *file_objects_or_paths: Mix of file objects (like sys.stdout)
24
+ or string paths to log files
25
+ """
26
+ self.files = []
27
+ self.opened_files = [] # Track files we opened so we can close them later
28
+
29
+ for item in file_objects_or_paths:
30
+ if isinstance(item, str):
31
+ # It's a path string - open it as a file
32
+ f = open(item, 'a', buffering=1) # Append mode, line buffered
33
+ self.files.append(f)
34
+ self.opened_files.append(f)
35
+ else:
36
+ # It's already a file-like object (e.g., sys.stdout)
37
+ self.files.append(item)
38
+
39
+ def write(self, data):
40
+ """Write data to all streams"""
41
+ for f in self.files:
42
+ try:
43
+ f.write(data)
44
+ f.flush()
45
+ except Exception as e:
46
+ # Handle closed file gracefully
47
+ print(f"Warning: Could not write to {f}: {e}", file=sys.stderr)
48
+
49
+ def flush(self):
50
+ """Flush all streams"""
51
+ for f in self.files:
52
+ try:
53
+ f.flush()
54
+ except:
55
+ pass
56
+
57
+ def isatty(self):
58
+ """Check if any stream is a terminal (for tqdm compatibility)"""
59
+ return any(getattr(f, "isatty", lambda: False)() for f in self.files)
60
+
61
+ def fileno(self):
62
+ """Get file descriptor from any real file-like stream"""
63
+ for f in self.files:
64
+ if hasattr(f, "fileno"):
65
+ try:
66
+ return f.fileno()
67
+ except Exception:
68
+ pass
69
+ raise io.UnsupportedOperation("No fileno available")
70
+
71
+ def close(self):
72
+ """Close any files we opened"""
73
+ for f in self.opened_files:
74
+ try:
75
+ f.close()
76
+ except:
77
+ pass
78
+ self.opened_files.clear()
79
+
80
+ def __del__(self):
81
+ """Cleanup on deletion"""
82
+ self.close()
83
+
84
+ def __enter__(self):
85
+ """Context manager support"""
86
+ return self
87
+
88
+ def __exit__(self, exc_type, exc_val, exc_tb):
89
+ """Context manager cleanup"""
90
+ self.close()
91
+ return False
92
+
93
+ class MAETrainer:
94
+ def __init__(self,configs={}):
95
+
96
+ self.configs=configs
97
+ os.makedirs(configs["logdir"],exist_ok=True)
98
+ log_path_train = os.path.join(configs["logdir"], "training_log.txt")
99
+ log_path_val = os.path.join(configs["logdir"], "val_log.txt")
100
+ log_path_test = os.path.join(configs["logdir"], "test_log.txt")
101
+ #self.log_file = open(log_path, 'w', buffering=1)
102
+ self.traintee = TeeFile(sys.stdout, log_path_train)
103
+ self.valtee = TeeFile(sys.stdout, log_path_val)
104
+ self.testtee = TeeFile(sys.stdout, log_path_test)
105
+
106
+ for dir in self.configs["dirsToMake"]: os.makedirs(dir,exist_ok=True)
107
+
108
+ self.model=MaskedAutoEncoder(
109
+ c=configs["channels"],
110
+ mask_ratio=configs["mask_ratio"],
111
+ dropout=configs["dropout"],
112
+ img_size=configs["img_size"],
113
+ encoder_dim=configs["encoder_dim"],
114
+ mlp_dim=configs["mlp_dim"],
115
+ decoder_dim=configs["decoder_dim"],
116
+ encoder_depth=configs["encoder_depth"],
117
+ encoder_head=configs["encoder_head"],
118
+ decoder_depth=configs["decoder_depth"],
119
+ decoder_head=configs["decoder_head"],
120
+ patch_size=configs["patch_size"]
121
+ ).to(configs["device"])
122
+
123
+ self.criterion=mae_loss
124
+
125
+ self.optimizer=torch.optim.AdamW(self.model.parameters(),configs["lr"], weight_decay=configs["weight_decay"])
126
+ self.schedular1=torch.optim.lr_scheduler.LinearLR(self.optimizer,start_factor=0.1,end_factor=1.0,total_iters=configs["warmup"])
127
+ self.schedular2=torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer,T_max=configs["num_epochs"]-configs["warmup"])
128
+ self.schedular=torch.optim.lr_scheduler.SequentialLR (self.optimizer,schedulers=[self.schedular1,self.schedular2],milestones=[configs["warmup"]])
129
+ self.scaler=torch.amp.GradScaler()
130
+
131
+ self.train_dataset= CheXpertDataset(zip_path=configs["zip_path"],csv_path=configs["train_csv"],root_dir=configs["datadir"],augment=True,use_frontal_only=True)
132
+ self.val_dataset= CheXpertDataset(zip_path=configs["zip_path"],csv_path=configs["val_csv"],root_dir=configs["datadir"],augment=False,use_frontal_only=True )
133
+ self.class_Weights=self.train_dataset.get_class_weights().to(self.configs["device"])
134
+ self.sample_Weights=self.train_dataset.get_sample_weights()
135
+ self.sampler=torch.utils.data.WeightedRandomSampler(self.sample_Weights,num_samples=len(self.sample_Weights))
136
+ self.trainloader=DataLoader(self.train_dataset,batch_size=configs["batch_size"],sampler=self.sampler,num_workers=8,pin_memory=True,persistent_workers=True)
137
+ self.valloader=DataLoader(self.val_dataset,batch_size=configs["batch_size"],shuffle=False,num_workers=8,pin_memory=True,persistent_workers=True)
138
+ self.history={"train_loss":[],"val_loss":[]}
139
+
140
+ self.current_epoch=0
141
+
+ if os.path.exists(self.configs["resume"]):
+ loadedpickle = torch.load(self.configs["resume"], map_location=self.configs["device"])
+ self.model.load_state_dict(loadedpickle["model"], strict=False)
+ self.optimizer.load_state_dict(loadedpickle["optimizer"])
+ self.schedular.load_state_dict(loadedpickle["schedular"])
+ self.schedular1.load_state_dict(loadedpickle["schedular1"])
+ self.schedular2.load_state_dict(loadedpickle["schedular2"])
+ self.scaler.load_state_dict(loadedpickle["scaler"])
+ self.current_epoch = loadedpickle["epoch"] + 1
+ # Restore the loss history from the previous run so best-checkpoint
+ # tracking and the saved history.json stay continuous across resumes
+ historyfile = os.path.join(self.configs["logdir"], "history.json")
+ if os.path.exists(historyfile):
+ with open(historyfile, "r") as f:
+ self.history = json.load(f)
+
+ self.test_dataset = None
+ self.testloader = None
+ if configs.get("test_csv"):
+ self.test_dataset = CheXpertDataset(
+ zip_path=configs["zip_path"],
+ csv_path=configs["test_csv"],
+ root_dir=configs["datadir"],
+ augment=False,
+ use_frontal_only=True
+ )
+ self.testloader = DataLoader(
+ self.test_dataset,
+ batch_size=configs["batch_size"],
+ shuffle=False,
+ num_workers=8,
+ pin_memory=True,
+ persistent_workers=True
+ )
+ print(f"Test loader ready – {len(self.test_dataset)} images")
+
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cudnn.enabled = True
+
+ # Reduce CUDA memory fragmentation (note: this env var only takes effect
+ # if set before the first CUDA allocation in the process)
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+
+ # Enable gradient checkpointing if the model supports it
+ if hasattr(self.model, 'enable_gradient_checkpointing'):
+ self.model.enable_gradient_checkpointing()
+ @staticmethod
+ def plot_training_metrics(metrics, epoch, figs_path):
+ """
+ Plot training and validation loss curves from training metrics.
+
+ Args:
+ metrics (dict): Dictionary containing a list for each metric key:
+ {
+ "train_loss": [...],
+ "val_loss": [...]
+ }
+ epoch (int): Current epoch number (used for the output subdirectory)
+ figs_path (str): Directory under which the figure is saved
+ """
+ import matplotlib.pyplot as plt
+ # Compute the common length across all series
+ keys = ["train_loss", "val_loss"]
+ lengths = [len(metrics[k]) for k in keys if k in metrics]
+ if not lengths:
+ return
+ n = min(lengths)
+
+ # Slice everything to the same length
+ m = {k: metrics[k][:n] for k in keys if k in metrics}
+ epochs = list(range(1, n + 1))
+
+ plt.figure(figsize=(14, 6))
+
+ # ---- Loss subplot ----
+ plt.subplot(1, 2, 1)
+ plt.plot(epochs, m["train_loss"], label="Train Loss", marker='o')
+ plt.plot(epochs, m["val_loss"], label="Val Loss", marker='s')
+ plt.xlabel("Epoch")
+ plt.ylabel("Loss")
+ plt.title("Training & Validation Loss")
+ plt.legend()
+ plt.grid(True, linestyle='--', alpha=0.6)
+
+ plt.tight_layout()
+ os.makedirs(os.path.join(figs_path, str(epoch)), exist_ok=True)
+ plt.savefig(os.path.join(figs_path, str(epoch), "metrics.png"))
+ plt.show()
+
+ def train_epoch(self, epoch, looper):
+ self.model.train()
+ running_loss = 0.0
+ current_loss = 0.0
+ total_batches = len(self.trainloader)
+
+ for batch_idx, data in looper:
+ image = data['image'].to(self.configs["device"], non_blocking=True)
+
+ with torch.autocast(device_type=self.configs["device"].type, dtype=torch.float16):
+ img, preds, mask = self.model(image)
+ loss = self.criterion(img, preds, mask)
+
+ loss_back = loss / self.configs["accumulation"]
+
+ # Skip non-finite losses so they neither reach the optimizer
+ # nor poison the running average
+ if torch.isfinite(loss):
+ running_loss += loss.item()
+ self.scaler.scale(loss_back).backward()
+ else:
+ self.optimizer.zero_grad(set_to_none=True)
+ continue
+
+ if (batch_idx + 1) % self.configs["accumulation"] == 0 or batch_idx == total_batches - 1:
+ self.scaler.unscale_(self.optimizer)
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ self.schedular.step()
+ self.optimizer.zero_grad(set_to_none=True)
+
+ # === LIVE METRICS (updated every 10 batches) ===
+ current_loss = running_loss / (batch_idx + 1)
+ if (batch_idx + 1) % 10 == 0:
+ current_lr = self.optimizer.param_groups[0]['lr']
+ looper.set_postfix({
+ "lr": f"{current_lr:.2e}", "batch": f"{batch_idx}/{total_batches}",
+ "epoch": f"{epoch}/{self.configs['num_epochs']}",
+ "loss": f"{current_loss:.3f}",
+ })
+
+ return current_loss
+ def validate(self, epoch, looper):
+ self.model.eval()
+ val_loss = 0.0
+ current_loss = 0.0
+ lenloader = len(self.valloader)
+ with torch.no_grad():
+ for batch_idx, data in looper:
+ image = data["image"].to(self.configs["device"], non_blocking=True)
+
+ with torch.autocast(device_type=self.configs["device"].type, dtype=torch.float16):
+ img, preds, mask = self.model(image)
+ loss = self.criterion(img, preds, mask)
+
+ val_loss += loss.item()
+
+ # === LIVE METRICS (updated every 10 batches) ===
+ current_loss = val_loss / (batch_idx + 1)
+ if (batch_idx + 1) % 10 == 0:
+ looper.set_postfix({
+ "epoch": f"{epoch}/{self.configs['num_epochs']}",
+ "batch": f"{batch_idx}/{lenloader}",
+ "loss": f"{current_loss:.3f}",
+ })
+
+ return current_loss
+ def train(self):
+ for epoch in range(self.current_epoch, self.configs["num_epochs"]):
+ trainlooper = tqdm(enumerate(self.trainloader), desc="training: ", leave=False, file=self.traintee)
+ vallooper = tqdm(enumerate(self.valloader), desc="validating: ", leave=False, file=self.valtee)
+
+ self.model.train()
+ self.optimizer.zero_grad(set_to_none=True)
+
+ running_loss = self.train_epoch(epoch, trainlooper)
+
+ torch.cuda.synchronize()
+ torch.cuda.empty_cache()
+
+ val_loss = self.validate(epoch, vallooper)
+
+ torch.cuda.synchronize()
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ # Save a checkpoint on the first epoch and whenever validation loss improves
+ if not self.history["val_loss"] or val_loss < min(self.history["val_loss"]):
+ checkpoint = {"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "schedular": self.schedular.state_dict(), "schedular1": self.schedular1.state_dict(), "schedular2": self.schedular2.state_dict(), "scaler": self.scaler.state_dict(), "epoch": epoch}
+ torch.save(checkpoint, self.configs["resume"])
+
+ print(f"train loss {running_loss:.4f} val loss {val_loss:.4f}")
+
+ self.history["train_loss"].append(float(running_loss))
+ self.history["val_loss"].append(float(val_loss))
+
+ if epoch % 10 == 0:
+ historyfile = os.path.join(self.configs["logdir"], "history.json")
+ with open(historyfile, "w") as f:
+ json.dump(self.history, f)
+ MAETrainer.plot_training_metrics(self.history, epoch + 1, self.configs["logdir"])
+
+ self.current_epoch = epoch
training logs/mae/1/metrics.png ADDED
training logs/mae/101/metrics.png ADDED
training logs/mae/11/metrics.png ADDED
training logs/mae/21/metrics.png ADDED
training logs/mae/31/metrics.png ADDED
training logs/mae/41/metrics.png ADDED
training logs/mae/51/metrics.png ADDED
training logs/mae/61/metrics.png ADDED
training logs/mae/71/metrics.png ADDED
training logs/mae/81/metrics.png ADDED
training logs/mae/91/metrics.png ADDED
training logs/mae/history.json ADDED
@@ -0,0 +1 @@
 
 
+ {"train_loss": [61.882596965502664, 44.471386281221996, 34.87869473939301, 27.54925524051899, 24.282669028415476, 22.451558465855097, 21.147514946913635, 20.99175854891432, 20.213269732075354, 20.001737736117455, 19.91450946014842, 19.47891389648547, 19.18175616794162, 18.748555672895098, 19.04294045150921, 18.45393469068739, 18.838688672315264, 18.155403604131447, 18.37196905074581, 17.72255533279911, 18.3023255009805, 17.781992453742625, 18.199740923262837, 17.737673626834773, 18.033451236875255, 17.423879869522587, 17.789273495144315, 17.71034824754175, 17.673153098253366, 17.505879045301867, 17.242074561717263, 17.15336024376654, 16.90521138372387, 16.933767681053464, 17.133027586714768, 17.32885531353694, 17.239427022352867, 17.082445266340795, 17.198882104333585, 16.797237842655523, 16.80550524462081, 16.855286646060193, 16.820361117564648, 16.804262694492135, 16.60529948887432, 16.617685430109713, 16.49680861811484, 16.534883691247646, 16.66711660952551, 16.68186878833292, 16.830572265365216, 16.551616942156173, 16.58498664637193, 16.517176605552756, 16.586636129673238, 16.610006309153786, 16.467580378269208, 16.335460134390008, 16.45028829335312, 16.415938852508436, 16.657625508052046, 16.622210603119225, 16.319772068146737, 16.374398973253037, 16.252494745015245, 16.311780229520625, 16.422519305102714, 16.033588693133392, 16.024791444757934, 16.087340885422137, 16.039379536208287, 16.26473860928662, 16.34161920308212, 15.996231590804234, 16.295430011817633, 16.445986707407087, 16.343918548775402, 16.409462545251333, 16.49581729998298, 16.137155871921117, 16.05842663481244, 16.16612617533694, 16.27624153595244, 16.29507503646249, 16.29731023747434, 16.399930175316378, 16.08117872668851, 16.119801326464582, 16.0585214939596, 15.990199300977919, 16.033912498036592, 16.225505850050183, 16.006062768095283, 16.016956458553192, 15.915514986318499, 16.111989719978798, 15.927976318414066, 16.02773256541153, 15.936725686186103, 15.84361021441798, 
15.960153004089136], "val_loss": [31.570541223418278, 22.954840935741945, 14.442683501893104, 9.339503538568946, 8.24503136790076, 6.798558349229173, 7.171370674209341, 6.004779509927744, 6.1726217491682185, 5.98896356120062, 5.595671336912238, 5.563971098079238, 5.132711923795681, 5.617544600337843, 4.961312991044054, 5.048685019990534, 5.025379383682808, 4.725488581134631, 4.978387813631482, 4.562931300793771, 4.850497535692893, 4.521047054335129, 4.649226747081921, 4.5786320276038595, 4.333143504355041, 4.4852356395848165, 4.235697840535363, 4.414844192935779, 4.24225838635847, 4.289312726239429, 4.266240818555965, 4.077501735021902, 4.319235841301192, 4.025318107731715, 4.189943225676831, 4.086437865349145, 4.028652789980867, 4.119319666263669, 3.907484631205714, 4.0297732788859015, 3.884794211466843, 4.035012978651991, 3.96724297437953, 3.852650080804413, 3.9237460725727273, 3.7834923750538367, 3.9779289901454584, 3.7871727095885928, 3.838316381967741, 3.852536988020735, 3.7242261183222265, 3.889924947605577, 3.681634605920988, 3.851548166370075, 3.699327661349528, 3.693325303321661, 3.7589161586127804, 3.61661579759414, 3.7395904396855553, 3.5934903867220958, 3.6865477134223, 3.6255838419511863, 3.5961770504416024, 3.650220192152004, 3.5313091785012687, 3.658342431153966, 3.5214638036746915, 3.5825591047736896, 3.535893441830759, 3.4996724699026722, 3.600268395636169, 3.4633193554672292, 3.5713670332962493, 3.4522710076202188, 3.5278821134092007, 3.5031054748649217, 3.426014715650945, 3.5055409483735347, 3.403536310227606, 3.5005639573664364, 3.434894419983772, 3.439712012725019, 3.4542255861022544, 3.3616595402904523, 3.4835145457638457, 3.351115513481571, 3.4225067292337004, 3.3756711308742284, 3.367003402044607, 3.418841015064835, 3.3144125257219588, 3.4290886646093326, 3.2998156674280517, 3.368062291826521, 3.337226649851498, 3.2932771796799973, 3.3656438268300306, 3.266660438423537, 3.3932779888774074, 3.2590805264406426, 3.298969637119889]}
training logs/mae/test_log.txt ADDED
File without changes
training logs/mae/training_log.txt ADDED
The diff for this file is too large to render. See raw diff
 
training logs/mae/val_log.txt ADDED
The diff for this file is too large to render. See raw diff