Upload folder using huggingface_hub
- .gitattributes +1 -35
- .gitignore +1 -0
- LICENSE +21 -0
- README.md +329 -0
- configs/__init__.py +0 -0
- configs/configs.py +27 -0
- data/__init__.py +0 -0
- data/dataset.py +508 -0
- data/splitter.py +347 -0
- loss/__init__.py +0 -0
- loss/mae_loss.py +10 -0
- models/__init__.py +0 -0
- models/mae.py +177 -0
- notebooks/chexpert_mae.ipynb +0 -0
- requirements.txt +29 -0
- trainer/__init__.py +0 -0
- trainer/trainer.py +14 -0
- trainer/utils.py +348 -0
- training logs/mae/1/metrics.png +0 -0
- training logs/mae/101/metrics.png +0 -0
- training logs/mae/11/metrics.png +0 -0
- training logs/mae/21/metrics.png +0 -0
- training logs/mae/31/metrics.png +0 -0
- training logs/mae/41/metrics.png +0 -0
- training logs/mae/51/metrics.png +0 -0
- training logs/mae/61/metrics.png +0 -0
- training logs/mae/71/metrics.png +0 -0
- training logs/mae/81/metrics.png +0 -0
- training logs/mae/91/metrics.png +0 -0
- training logs/mae/history.json +1 -0
- training logs/mae/test_log.txt +0 -0
- training logs/mae/training_log.txt +0 -0
- training logs/mae/val_log.txt +0 -0
.gitattributes
CHANGED
@@ -1,35 +1 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+weights/*.pth filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
+*.pth
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Adel Elsayed
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
ADDED
@@ -0,0 +1,329 @@
+# Masked Autoencoder (MAE) for Medical Imaging
+
+A PyTorch implementation of Masked Autoencoder (MAE) for self-supervised learning on chest X-ray images, specifically designed for the CheXpert dataset.
+
+## 📋 Overview
+
+This project implements a Vision Transformer-based Masked Autoencoder that learns representations from chest X-ray images through self-supervised reconstruction. The model randomly masks 75% of image patches and learns to reconstruct the original image, enabling it to learn powerful visual representations without requiring labeled data.
+
+### Key Features
+
+- **Vision Transformer Architecture**: Encoder-decoder transformer with positional encodings
+- **Self-Supervised Learning**: Pre-training through masked image reconstruction
+- **Optimized for Medical Imaging**: Designed specifically for chest X-ray analysis
+- **Production-Ready Training Pipeline**:
+  - Mixed precision training (FP16) with gradient scaling
+  - Gradient accumulation support
+  - Learning rate warmup and cosine annealing
+  - Automatic checkpointing and resumption
+- **Efficient Data Loading**:
+  - Optimized ZIP file reader with LRU caching
+  - Class-balanced sampling with weighted random sampler
+  - Multi-worker data loading with persistent workers
+- **Comprehensive Logging**: Training/validation metrics tracking and visualization
+
+## 🏗️ Architecture
+
+### Masked Autoencoder Structure
+
+```
+Input Image (384×384)
+        ↓
+Patchify (16×16 patches → 576 patches)
+        ↓
+Random Masking (75% masked, 25% visible)
+        ↓
+┌─────────────────────────────────────┐
+│            MAE ENCODER              │
+│  - Linear patch embedding           │
+│  - Positional encoding (visible)    │
+│  - 12 Transformer blocks            │
+│  - 8 attention heads, 768 hidden    │
+└─────────────────────────────────────┘
+        ↓
+┌─────────────────────────────────────┐
+│            MAE DECODER              │
+│  - Learnable mask tokens            │
+│  - Positional encoding (all)        │
+│  - 8 Transformer blocks             │
+│  - 8 attention heads, 512 hidden    │
+│  - Pixel reconstruction head        │
+└─────────────────────────────────────┘
+        ↓
+Reconstructed Image
+        ↓
+MSE Loss (on masked patches only)
+```
+
+### Model Configuration
+
+| Parameter | Default Value | Description |
+|-----------|---------------|-------------|
+| Image Size | 384×384 | Input image resolution |
+| Patch Size | 16×16 | Size of each patch |
+| Mask Ratio | 0.75 | Fraction of patches to mask |
+| Encoder Depth | 12 layers | Number of transformer blocks |
+| Encoder Dim | 768 | Hidden dimension |
+| Encoder Heads | 8 | Number of attention heads |
+| Decoder Depth | 8 layers | Number of transformer blocks |
+| Decoder Dim | 512 | Hidden dimension |
+| Decoder Heads | 8 | Number of attention heads |
+| MLP Ratio | 4× | MLP expansion ratio (3072) |
+| Dropout | 0.25 | Dropout rate |
+
+## 🚀 Getting Started
+
+### Prerequisites
+
+- Python >= 3.8
+- CUDA-capable GPU (recommended)
+- 16GB+ RAM
+
+### Installation
+
+1. Clone the repository:
+```bash
+git clone https://github.com/adelelsayed/mae.git
+cd mae
+```
+
+2. Install dependencies:
+```bash
+pip install -r requirements.txt
+```
+
+### Dataset Preparation
+
+This project is configured for the **CheXpert dataset**. To use it:
+
+1. Download CheXpert-v1.0-small from [Stanford ML Group](https://stanfordmlgroup.github.io/competitions/chexpert/)
+2. Update paths in `configs/configs.py`:
+   - `root`: Base directory for your data
+   - `zip_path`: Path to zipped dataset (optional, for faster loading)
+   - `csv`: Path to training CSV
+   - `train_csv`, `val_csv`, `test_csv`: Split CSV files
+
+## 📊 Usage
+
+### Training
+
+Start training from scratch:
+```bash
+python trainer/trainer.py
+```
+
+The trainer will:
+- Automatically create checkpoint and log directories
+- Resume from the last checkpoint if available
+- Log training/validation metrics to text files
+- Save plots every 10 epochs
+- Save the best model based on validation loss
+
+### Training Configuration
+
+Edit `configs/configs.py` to customize training:
+
+```python
+mae_config = {
+    # Training hyperparameters
+    "lr": 1e-4,              # Learning rate
+    "warmup": 5,             # Warmup epochs
+    "weight_decay": 5e-4,    # AdamW weight decay
+    "num_epochs": 200,       # Total training epochs
+    "batch_size": 96,        # Batch size
+    "accumulation": 1,       # Gradient accumulation steps
+
+    # Model architecture
+    "mask_ratio": 0.75,      # Masking ratio
+    "encoder_depth": 12,     # Encoder layers
+    "decoder_depth": 8,      # Decoder layers
+
+    # Paths
+    "checkpoints": "/path/to/checkpoints",
+    "logdir": "/path/to/logs",
+    ...
+}
+```
+
+### Monitoring Training
+
+Training logs are saved in three files:
+- `training_log.txt`: Training metrics per epoch
+- `val_log.txt`: Validation metrics per epoch
+- `test_log.txt`: Test set evaluation results
+
+Metrics plots are saved every 10 epochs in `{logdir}/{epoch}/metrics.png`.
+
+### Evaluation
+
+The project includes a test method in the trainer. To evaluate:
+```python
+from trainer.utils import MAETrainer
+from configs.configs import mae_config
+
+trainer = MAETrainer(mae_config)
+trainer.test()
+```
+
+## 📁 Project Structure
+
+```
+mae/
+├── configs/
+│   ├── __init__.py
+│   └── configs.py          # Training configuration
+├── data/
+│   ├── __init__.py
+│   ├── dataset.py          # CheXpert dataset loader
+│   └── splitter.py         # Dataset splitting utilities
+├── loss/
+│   ├── __init__.py
+│   └── mae_loss.py         # MAE reconstruction loss
+├── models/
+│   ├── __init__.py
+│   └── mae.py              # MAE architecture
+├── trainer/
+│   ├── __init__.py
+│   ├── trainer.py          # Main training script
+│   └── utils.py            # Training utilities
+├── notebooks/
+│   └── chexpert_mae.ipynb  # Jupyter notebook for experiments
+├── training logs/          # Logged metrics and plots
+├── weights/                # Model checkpoints
+├── results/                # Evaluation results
+├── requirements.txt        # Python dependencies
+├── LICENSE                 # Project license
+└── README.md               # This file
+```
+
+## 🔧 Components
+
+### Dataset (`data/dataset.py`)
+
+- **OptimizedZipReader**: Fast ZIP file reading with LRU caching
+- **CheXpertDataset**: PyTorch dataset for CheXpert chest X-rays
+  - 14 pathology labels: No Finding, Cardiomegaly, Edema, Consolidation, etc.
+  - Albumentations-based augmentation pipeline
+  - Class-balanced sampling support
+  - Frontal/lateral view filtering
+
+### Model (`models/mae.py`)
+
+- **Patchify/Unpatchify**: Image-to-patch conversion utilities
+- **Random Masking**: Stochastic patch masking with restore indices
+- **PositionalEncoding**: Learnable position embeddings
+- **TransformerBlock**: Multi-head self-attention + MLP
+- **MAEEncoder**: Processes visible patches only
+- **MAEDecoder**: Reconstructs full image with mask tokens
+- **MaskedAutoEncoder**: Complete MAE model
+
+### Loss (`loss/mae_loss.py`)
+
+Mean Squared Error (MSE) computed only on masked patches:
+```python
+loss = ((pred - target) ** 2 * mask).sum() / mask.sum()
+```
+
+### Trainer (`trainer/utils.py`)
+
+- **MAETrainer**: Complete training pipeline
+  - Mixed precision training (AMP)
+  - Gradient clipping and accumulation
+  - Learning rate scheduling (warmup → cosine)
+  - Automatic checkpointing
+  - Multi-file logging (train/val/test)
+  - Live metric monitoring with tqdm
+  - Periodic metric visualization
+
+## 🎯 CheXpert Pathologies
+
+The dataset includes 14 chest X-ray findings:
+
+1. No Finding
+2. Enlarged Cardiomediastinum
+3. Cardiomegaly
+4. Lung Opacity
+5. Lung Lesion
+6. Edema
+7. Consolidation
+8. Pneumonia
+9. Atelectasis
+10. Pneumothorax
+11. Pleural Effusion
+12. Pleural Other
+13. Fracture
+14. Support Devices
+
+## 📈 Training Tips
+
+1. **Learning Rate**: Start with 1e-4 and use warmup for stability
+2. **Batch Size**: Maximize based on GPU memory (96 works well on 40GB GPUs)
+3. **Gradient Accumulation**: Use if batch size is limited by memory
+4. **Mixed Precision**: Enabled by default for faster training
+5. **Masking Ratio**: 75% is standard; higher ratios increase difficulty
+6. **Resume Training**: The model automatically resumes from the last checkpoint
+
+## 🔬 Use Cases
+
+### Pre-training for Downstream Tasks
+Use the trained encoder as a feature extractor:
+```python
+import torch
+from models.mae import MaskedAutoEncoder
+# Load pre-trained model
+mae = MaskedAutoEncoder()
+mae.load_state_dict(torch.load("best_mae.pth")["model"])
+
+# Use encoder for feature extraction
+encoder = mae.encoder
+features, _, _, _ = encoder(images)
+```
+
+### Fine-tuning on Classification
+Add a classification head to the encoder for supervised tasks.
+
+### Anomaly Detection
+Reconstruction error can indicate abnormalities in medical images.
+
+## 📊 Performance Optimization
+
+This implementation includes several optimizations:
+
+- **Efficient ZIP Reading**: Avoids extracting files to disk
+- **LRU Cache**: Keeps frequently accessed images in memory
+- **Persistent Workers**: Reduces data loading overhead
+- **Mixed Precision**: Up to 2× faster training with minimal quality loss
+- **Gradient Checkpointing**: Reduces memory usage (if enabled)
+- **CUDA Memory Management**: Proper cache clearing and synchronization
+
+## 🤝 Contributing
+
+Contributions are welcome! Please feel free to submit a Pull Request.
+
+## 📄 License
+
+This project is licensed under the MIT License; see the LICENSE file for details.
+
+## 📚 References
+
+1. **Masked Autoencoders Are Scalable Vision Learners**
+   He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022)
+   [arXiv:2111.06377](https://arxiv.org/abs/2111.06377)
+
+2. **CheXpert: A Large Chest Radiograph Dataset**
+   Irvin, J., et al. (2019)
+   [Stanford ML Group](https://stanfordmlgroup.github.io/competitions/chexpert/)
+
+## 🙏 Acknowledgments
+
+- Original MAE paper by Meta AI Research
+- CheXpert dataset by Stanford ML Group
+- PyTorch and Albumentations communities
+
+## 📧 Contact
+
+For questions or issues, please open an issue on GitHub or contact the maintainer.
+
+---
+
+**Note**: This is a research/educational implementation. For clinical applications, please ensure proper validation and regulatory compliance.
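The masked-MSE formula quoted in the README's Loss section can be exercised standalone. The sketch below is not the code from `loss/mae_loss.py` itself: the tensor shapes and the per-patch mean reduction are assumptions, and the repo's exact reduction may differ.

```python
import torch

def mae_loss(pred, target, mask):
    """Masked-patch MSE: average reconstruction error over masked patches only.

    pred, target: (B, N, D) reconstructed and original patch pixels.
    mask: (B, N), 1.0 where a patch was masked, 0.0 where it was visible.
    """
    per_patch = ((pred - target) ** 2).mean(dim=-1)  # (B, N) MSE per patch
    return (per_patch * mask).sum() / mask.sum()     # visible patches contribute nothing

# Toy check with a 75% random mask over 576 patches (24x24 grid at 384/16),
# D = 256 pixels per 16x16 grayscale patch.
B, N, D = 2, 576, 256
pred, target = torch.randn(B, N, D), torch.randn(B, N, D)
mask = (torch.rand(B, N) < 0.75).float()
print(mae_loss(pred, target, mask).item())
```

Dividing by `mask.sum()` makes this an average over masked patches only, so patches the encoder already saw carry no reconstruction signal.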
configs/__init__.py
ADDED
File without changes
configs/configs.py
ADDED
@@ -0,0 +1,27 @@
+import os
+import torch
+root = "/content/drive/MyDrive"
+mae_config = {
+    "lr": 1e-4,
+    "warmup": 5,
+    "weight_decay": 5e-4,
+    "num_epochs": 200,
+    "num_classes": 14,
+    "zip_path": os.path.join(root, "CheXpert-v1.0-small", "chexpert.zip"),
+    "resume": os.path.join(root, "CheXpert-v1.0-small", "maecheckpoints", "best_mae.pth"),
+    "logdir": os.path.join(root, "CheXpert-v1.0-small", "maelogs"),
+    "checkpoints": os.path.join(root, "CheXpert-v1.0-small", "maecheckpoints"),
+    "datadir": root,
+    "lmdb": os.path.join(root, "CheXpert-v1.0-small", "lmdb"),
+    "csv": os.path.join(root, "CheXpert-v1.0-small", "train.csv"),
+    "batch_size": 96,
+    "device": torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
+    "accumulation": 1,
+    "dirsToMake": [os.path.join(root, "CheXpert-v1.0-small", "maecheckpoints"), os.path.join(root, "CheXpert-v1.0-small", "maelogs")],
+    "train_csv": os.path.join(root, "CheXpert-v1.0-small", "train_ready.csv"),
+    "val_csv": os.path.join(root, "CheXpert-v1.0-small", "val_ready.csv"),
+    "test_csv": os.path.join(root, "CheXpert-v1.0-small", "test_ready.csv"),
+    "channels": 1, "mask_ratio": 0.75, "dropout": 0.25, "img_size": 384, "encoder_dim": 768,
+    "mlp_dim": 3072, "decoder_dim": 512, "encoder_depth": 12, "encoder_head": 8, "decoder_depth": 8,
+    "decoder_head": 8, "patch_size": 16,
+}
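The paths above are hard-coded for Google Drive on Colab. For a local run, one option is to copy the dict and override the path entries rather than editing `configs.py`; the sketch below is an assumption about usage, not code from the repo, and the paths are placeholders.

```python
from configs.configs import mae_config

# Copy-and-override for a local run; all paths here are hypothetical.
local_config = {
    **mae_config,
    "zip_path": "/data/chexpert/chexpert.zip",
    "checkpoints": "/data/chexpert/maecheckpoints",
    "logdir": "/data/chexpert/maelogs",
    "train_csv": "/data/chexpert/train_ready.csv",
    "val_csv": "/data/chexpert/val_ready.csv",
    "test_csv": "/data/chexpert/test_ready.csv",
    "dirsToMake": ["/data/chexpert/maecheckpoints", "/data/chexpert/maelogs"],
}
```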
data/__init__.py
ADDED
File without changes
data/dataset.py
ADDED
@@ -0,0 +1,508 @@
+# Standard library
+import os
+import io
+import zipfile
+import pickle
+from pathlib import Path
+
+# Data handling
+import pandas as pd
+import numpy as np
+
+# PyTorch
+import torch
+from torch.utils.data import Dataset
+
+# Image processing
+from PIL import Image
+import cv2
+
+# Augmentations
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+
+# Progress bar (for precompute_all_masks)
+from tqdm import tqdm
+
+class OptimizedZipReader:
+    """
+    Fast ZIP file reader with LRU caching
+    """
+    def __init__(self, zip_path, cache_size=1000):
+        """
+        Args:
+            zip_path: Path to ZIP file
+            cache_size: Number of images to cache in RAM
+        """
+        self.zip_path = zip_path
+        self.cache_size = cache_size
+        self._zip_file = None  # Will be lazily initialized
+        self._name_to_info = None
+
+        # Cache
+        self._cache = {}
+        self._cache_order = []
+        self._hits = 0
+        self._misses = 0
+
+    @property
+    def zip_file(self):
+        """Lazy initialization of ZIP file handle"""
+        if self._zip_file is None:
+            print(f"Opening ZIP file: {self.zip_path}")
+            self._zip_file = zipfile.ZipFile(self.zip_path, 'r', allowZip64=True)
+
+            # Build index on first access
+            print("Building ZIP index...")
+            self._name_to_info = {
+                info.filename: info
+                for info in self._zip_file.infolist()
+            }
+            print(f"✓ Indexed {len(self._name_to_info)} files")
+
+        return self._zip_file
+
+    def read_image(self, path):
+        """
+        Read image data with automatic caching
+
+        Returns: bytes (image file data)
+        """
+        # Check cache first
+        if path in self._cache:
+            self._hits += 1
+            return self._cache[path]
+
+        # Cache miss - read from ZIP (this triggers lazy initialization)
+        self._misses += 1
+        img_data = self.zip_file.read(path)  # Uses property getter
+
+        # Add to cache with LRU eviction
+        if len(self._cache) >= self.cache_size:
+            oldest = self._cache_order.pop(0)
+            del self._cache[oldest]
+
+        self._cache[path] = img_data
+        self._cache_order.append(path)
+
+        return img_data
+
+    def get_cache_stats(self):
+        """Return cache hit rate statistics"""
+        total = self._hits + self._misses
+        hit_rate = self._hits / total * 100 if total > 0 else 0
+        return {
+            'hits': self._hits,
+            'misses': self._misses,
+            'hit_rate': f"{hit_rate:.2f}%",
+            'cache_size': len(self._cache)
+        }
+
+    def close(self):
+        """Close ZIP file and clear cache"""
+        if self._zip_file is not None:
+            self._zip_file.close()
+            self._zip_file = None
+        self._cache.clear()
+        self._cache_order.clear()
+        self._name_to_info = None
+
+class CheXpertDataset(Dataset):
+    """
+    CheXpert Dataset class
+
+    Returns a dict per sample with a single-channel grayscale image tensor,
+    the 14 pathology labels, and per-image metadata. (An earlier 3-channel
+    (img, img*mask, mask) variant is currently disabled; the lung-mask
+    loading is commented out in __getitem__.)
+
+    Args:
+        csv_path (str): Path to the CSV file (train.csv or valid.csv)
+        root_dir (str): Root directory of the CheXpert dataset
+        image_size (int): Target image size (default: 384)
+        augment (bool): Whether to apply augmentations (default: False)
+        use_frontal_only (bool): If True, only use frontal view images (default: False)
+        fill_uncertain (str): How to handle uncertain labels: 'zeros', 'ones', 'ignore' (default: 'ignore')
+        lmdb_path (str): Optional LMDB path (currently unused; see self.env)
+        zip_path (str): Optional path to a ZIP of the dataset for in-memory reads
+        zip_cache_size (int): Number of images the ZIP reader caches in RAM
+        mask_dir (str): Directory for precomputed lung masks
+        domask (bool): If True, precompute masks into mask_dir at init
+    """
+
+    # 14 pathology classes in CheXpert
+    PATHOLOGIES = [
+        'No Finding',
+        'Enlarged Cardiomediastinum',
+        'Cardiomegaly',
+        'Lung Opacity',
+        'Lung Lesion',
+        'Edema',
+        'Consolidation',
+        'Pneumonia',
+        'Atelectasis',
+        'Pneumothorax',
+        'Pleural Effusion',
+        'Pleural Other',
+        'Fracture',
+        'Support Devices'
+    ]
+
+    def __init__(
+        self,
+        csv_path,
+        root_dir,
+        image_size=384,
+        augment=False,
+        use_frontal_only=False,
+        fill_uncertain='ignore',
+        lmdb_path=None,
+        zip_path=None,
+        zip_cache_size=1000,
+        mask_dir=None, domask=False
+    ):
+        self.root_dir = root_dir
+        self.image_size = image_size
+        self.augment = augment
+        self.fill_uncertain = fill_uncertain
+        self.env = None  # lmdb.open(lmdb_path, readonly=True, lock=False) if lmdb_path else None
+        self._zip_path = zip_path
+        self._zip_cache_size = zip_cache_size
+        self._zip_reader_instance = None
+
+        # Read CSV file
+        self.df = pd.read_csv(csv_path)
+        for pathology in self.PATHOLOGIES:
+            if pathology in self.df.columns:
+                self.df[pathology] = pd.to_numeric(self.df[pathology], errors='coerce')
+
+        # Filter for frontal views only if specified
+        if use_frontal_only:
+            self.df = self.df[self.df['Frontal/Lateral'] == 'Frontal'].reset_index(drop=True)
+
+        # Handle uncertain labels (-1 values)
+        self._process_uncertain_labels()
+
+        # Setup augmentations
+        self.train_transform = self._get_train_transforms()
+        self.val_transform = self._get_val_transforms()
+
+        print(f"Loaded {len(self.df)} images from {csv_path}")
+        print(f"Image size: {image_size}x{image_size}")
+        print(f"Augmentation: {augment}")
+        print(f"Uncertain labels filled with: {fill_uncertain}")
+
+        if mask_dir and domask:
+            self.precompute_all_masks(mask_dir)
+
+    # Run this ONCE before training
+    def precompute_all_masks(self, save_dir):
+        os.makedirs(save_dir, exist_ok=True)
+        for idx in tqdm(range(len(self))):
+            img_path = os.path.join(self.root_dir, self.df.iloc[idx]['Path'])
+            part_path = "/".join(self.df.iloc[idx]['Path'].split("/")[1:])
+            if self.zip_reader:
+                # Read image data from ZIP (no extraction!)
+                img_data = self.zip_reader.read_image(part_path)
+
+                # Open image from bytes in memory
+                image = Image.open(io.BytesIO(img_data)).convert('L')
+            else:
+                image = Image.open(img_path).convert('L')
+
+            image = np.array(image)
+
+            # NOTE: chexpert_medsam_mask is expected to be provided elsewhere;
+            # it is not defined or imported in this file.
+            mask = chexpert_medsam_mask(image)
+            mask_path = os.path.join(save_dir, "_".join(self.df.iloc[idx]['Path'].split("/")[-3:]).replace('.jpg', '_mask.pt'))
+            os.makedirs(os.path.dirname(mask_path), exist_ok=True)
+            torch.save(mask, mask_path)
+
+    @property
+    def zip_reader(self):
+        """
+        Lazy property getter for ZIP reader
+
+        The ZIP file is only opened when first accessed, not during __init__.
+        This is useful when:
+        - Creating multiple dataset objects but only using some
+        - Saving memory during dataset setup
+        - Working with multiprocessing (each worker creates its own)
+        """
+        if self._zip_reader_instance is None and self._zip_path is not None:
+            self._zip_reader_instance = OptimizedZipReader(
+                self._zip_path,
+                cache_size=self._zip_cache_size
+            )
+        return self._zip_reader_instance
+
+    def _load_and_cache_image(self, img_path, idx):
+        """
+        Load image with automatic resizing and caching.
+        If resized version exists, load it. Otherwise, resize, save, and load.
+
+        Args:
+            img_path (str): Original image path from CSV
+            idx (int): Index for tracking
+
+        Returns:
+            np.ndarray: Loaded image (grayscale)
+        """
+        # Create cache directory structure
+        cache_dir = Path(self.root_dir)  # / f"cache_{self.image_size}" (currently unused)
+
+        # Preserve the relative path structure in cache
+        path_parts = list(Path(img_path).parts)
+        path_parts[-1] = f"{self.image_size}_{path_parts[-1]}"
+        relative_path = Path(*path_parts)
+        cached_path = relative_path.with_suffix('.jpg')
+
+        # Check if cached version exists
+        if cached_path.exists():
+            # Load cached image
+            image = Image.open(cached_path).convert('L')
+            image = np.array(image)
+
+            # Verify it's the correct size
+            if image.shape[0] == self.image_size and image.shape[1] == self.image_size:
+                return image
+
+        # Cache doesn't exist or wrong size - load original
+        original_path = img_path
+        image = Image.open(original_path).convert('L')
+
+        # Check if original is already target size
+        width, height = image.size
+
+        if width == self.image_size and height == self.image_size:
+            # Already correct size, just convert to array
+            return np.array(image)
+
+        # Resize image
+        image_resized = image.resize(
+            (self.image_size, self.image_size),
+            Image.LANCZOS
+        )
+
+        # Save to cache
+        cached_path.parent.mkdir(parents=True, exist_ok=True)
+        image_resized.save(cached_path, 'JPEG', quality=95, optimize=True)
+
+        return np.array(image_resized)
+
+    def _process_uncertain_labels(self):
+        """Process uncertain labels (-1) based on the chosen strategy."""
+        for pathology in self.PATHOLOGIES:
+            if pathology in self.df.columns:
+                if self.fill_uncertain == 'zeros':
+                    # Map uncertain (-1) to negative (0)
+                    self.df[pathology] = self.df[pathology].replace(-1, 0)
+                elif self.fill_uncertain == 'ones':
+                    # Map uncertain (-1) to positive (1)
+                    self.df[pathology] = self.df[pathology].replace(-1, 1)
+                elif self.fill_uncertain == 'ignore':
+                    # Keep -1 as is (you'll need to handle this in loss function)
+                    pass
+
+                # Fill NaN with 0 (negative)
+                self.df[pathology] = self.df[pathology].fillna(0)
+
+    def _get_train_transforms(self):
+        """Get training augmentations suitable for chest X-rays."""
+        return A.Compose([
+            # Resize to target size
+            A.LongestMaxSize(max_size=self.image_size),
+            A.PadIfNeeded(self.image_size, self.image_size, border_mode=cv2.BORDER_CONSTANT, position='center'),
+
+            # Geometric augmentations (conservative for medical images)
+            A.HorizontalFlip(p=0.5),
+            A.Affine(
+                translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
+                scale=(0.9, 1.1),
+                rotate=(-10, 10),
+                fit_output=False,
+                p=0.5
+            ),
+
+            # Intensity augmentations
+            A.OneOf([
+                A.RandomBrightnessContrast(
+                    brightness_limit=0.2,
+                    contrast_limit=0.2,
+                    p=1.0
+                ),
+                A.RandomGamma(gamma_limit=(80, 120), p=1.0),
+                A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=1.0),
+            ], p=0.5),
+
+            # Add slight blur to simulate different imaging conditions
+            A.OneOf([
+                A.GaussianBlur(blur_limit=(3, 5), p=1.0),
+                A.MedianBlur(blur_limit=3, p=1.0),
+            ], p=0.2),
+
+            # Add noise
+            A.GaussNoise(p=0.2),
+
+            # Normalize to [-1, 1] (mean 0.5, std 0.5 on [0, 1] pixels)
+            A.Normalize(
+                mean=[0.5],
+                std=[0.5],
+                max_pixel_value=255.0
+            ),
+
+            ToTensorV2()
+        ])
+
+    def _get_val_transforms(self):
+        """Get validation/test transforms (no augmentation)."""
+        return A.Compose([
+            A.LongestMaxSize(max_size=self.image_size),
+            A.PadIfNeeded(self.image_size, self.image_size, border_mode=cv2.BORDER_CONSTANT, position='center'),
+            A.Normalize(
+                mean=[0.5],
+                std=[0.5],
+                max_pixel_value=255.0
+            ),
+            ToTensorV2()
+        ])
+
+    def __len__(self):
+        return len(self.df)
+
+    def __del__(self):
+        """Close the ZIP reader (if one was opened) when the dataset is collected."""
+        reader = getattr(self, '_zip_reader_instance', None)
+        if reader is not None:
+            reader.close()
+
+    def __getitem__(self, idx):
+        if self.env:
+            with self.env.begin() as txn:
+                # Retrieve serialized data
+                data = txn.get(str(idx).encode())
+                sample = pickle.loads(data)
+                return sample
+        else:
+            # Get image path
+            img_path = os.path.join(self.root_dir, self.df.iloc[idx]['Path'])
+            # image = self._load_and_cache_image(img_path, idx)
+            # image = Image.open(img_path).convert('L')  # Convert to grayscale
+
+            part_path = "/".join(self.df.iloc[idx]['Path'].split("/")[1:])
+            if self.zip_reader:
+                # Read image data from ZIP (no extraction!)
+                img_data = self.zip_reader.read_image(part_path)
+
+                # Open image from bytes in memory
+                image = Image.open(io.BytesIO(img_data)).convert('L')
+            else:
+                image = Image.open(img_path).convert('L')
+
+            image = np.array(image)
+
+            # Load pre-computed mask
+            # mask_path = os.path.join(self.mask_dir, "_".join(self.df.iloc[idx]['Path'].split("/")[-3:]).replace('.jpg', '_mask.pt'))
+            # masked_img = torch.load(mask_path)
+            # Apply transforms to BOTH image and mask together
+            if self.augment:
+                # Augmentation applies to both image and mask
+                transformed = self.train_transform(image=image)
+                image_transformed = transformed['image']  # (1, H, W) tensor, normalized
+                # masked_img = transformed['mask']  # (H, W) tensor
+            else:
+                transformed = self.val_transform(image=image)
+                image_transformed = transformed['image']  # (1, H, W) tensor, normalized
+                # masked_img = transformed['mask']
+
+            # Expand dimensions to match
+            image_1ch = image_transformed  # (1, H, W)
+            masked_img = image_transformed
+
+            # Get labels for all pathologies
+            labels = []
+            for pathology in self.PATHOLOGIES:
+                if pathology in self.df.columns:
+                    label = self.df.iloc[idx][pathology]
+                    labels.append(float(label) if not pd.isna(label) else 0.0)
+                else:
+                    labels.append(0.0)
+
+            labels = torch.tensor(labels, dtype=torch.float32)
+
+            # Get additional metadata
+            metadata = {
+                'patient_id': self.df.iloc[idx]['Path'].split('/')[2],  # Extract patient ID from path
+                'study_id': self.df.iloc[idx]['Path'].split('/')[3],    # Extract study ID from path
+                'view': self.df.iloc[idx]['Frontal/Lateral'],
+                'sex': self.df.iloc[idx]['Sex'] if 'Sex' in self.df.columns else 'Unknown',
+                'age': self.df.iloc[idx]['Age'] if 'Age' in self.df.columns else -1,
+                'path': self.df.iloc[idx]['Path']
+            }
+
+            return {
+                'image': image_1ch,
+                'labels': labels,
+                'metadata': metadata
+            }
+
+    def get_label_names(self):
+        """Return list of pathology label names."""
+        return self.PATHOLOGIES
+
+    def get_label_distribution(self):
+        """Get distribution of positive labels for each pathology."""
+        distribution = {}
+        for pathology in self.PATHOLOGIES:
+            if pathology in self.df.columns:
+                positive_count = (self.df[pathology] == 1.0).sum()
+                distribution[pathology] = {
+                    'positive': int(positive_count),
+                    'percentage': round(positive_count / len(self.df) * 100, 2)
+                }
+        return distribution
+
+    def get_class_weights(self):
+        """
+        OPTIMIZED: Vectorized class weights calculation
+        """
+        weights = []
+        for pathology in self.PATHOLOGIES:
+            if pathology in self.df.columns:
+                # Vectorized counting (much faster than iterating)
+                values = self.df[pathology].values
+                pos = np.sum(values == 1.0)
+                neg = np.sum(values == 0.0)
+                weight = neg / pos if pos > 0 else 1.0
+                weights.append(weight)
+        return torch.tensor(weights, dtype=torch.float32)
+
+    def get_sample_weights(self):
+        """
+        OPTIMIZED: Vectorized sample weights calculation
+
+        Performance: ~1000x faster than original
+        Original: 15-30 seconds for 200k samples
+        This: 0.01-0.05 seconds for 200k samples
+        """
+        # Get class weights as numpy array
+        class_weights = self.get_class_weights().numpy()
+
+        # Get all labels as numpy array in ONE vectorized operation
+        labels_array = self.df[self.PATHOLOGIES].values.astype(np.float32)
+
+        # Create weighted labels matrix: where label=1, use class_weight, else -inf
+        # Shape: (n_samples, n_classes)
+        weighted_labels = np.where(
+            labels_array == 1.0,
+            class_weights,
+            -np.inf  # Use -inf instead of 0 so max will only consider positive labels
+        )
+
+        # For each sample, find the maximum class weight of its positive labels
+        # If a sample has no positive labels, max will be -inf, which we'll replace with 1.0
+        sample_weights = np.max(weighted_labels, axis=1)
+        sample_weights = np.where(
+            np.isinf(sample_weights),
+            1.0,  # Samples with no positive labels get weight 1.0
+            sample_weights
+        )
+
+        return torch.tensor(sample_weights, dtype=torch.float32)
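`get_sample_weights()` pairs naturally with PyTorch's `WeightedRandomSampler` for the class-balanced sampling the README mentions. A minimal sketch, assuming hypothetical local paths and loader settings (the wiring below is not taken from `trainer/utils.py`):

```python
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from data.dataset import CheXpertDataset

# Hypothetical paths; substitute your own CheXpert locations.
train_ds = CheXpertDataset(
    csv_path="/data/chexpert/train_ready.csv",
    root_dir="/data/chexpert",
    image_size=384,
    augment=True,
    zip_path="/data/chexpert/chexpert.zip",
)

# Class-balanced sampling: samples whose rarest positive pathology is rarer
# overall get a proportionally higher draw probability.
weights = train_ds.get_sample_weights()
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

loader = DataLoader(train_ds, batch_size=96, sampler=sampler,
                    num_workers=4, persistent_workers=True, pin_memory=True)
batch = next(iter(loader))
print(batch['image'].shape, batch['labels'].shape)  # (96, 1, 384, 384), (96, 14)
```

Because each worker lazily opens its own `OptimizedZipReader`, multi-worker loading works without sharing a ZIP handle across processes.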
data/splitter.py
ADDED
@@ -0,0 +1,347 @@
+# Standard library
+import os
+from pathlib import Path
+
+# Data handling
+import pandas as pd
+import numpy as np
+
+# Machine learning
+from sklearn.model_selection import train_test_split
+
+class CheXpertDataSplitter:
+    """
+    Advanced stratified train-validation splitter for CheXpert dataset.
+    Handles:
+    - Patient-level splitting (prevents data leakage)
+    - Multi-label stratification
+    - Class imbalance awareness
+    - Study-level grouping
+    """
+
+    PATHOLOGIES = [
+        'No Finding',
+        'Enlarged Cardiomediastinum',
+        'Cardiomegaly',
+        'Lung Opacity',
+        'Lung Lesion',
+        'Edema',
+        'Consolidation',
+        'Pneumonia',
+        'Atelectasis',
+        'Pneumothorax',
+        'Pleural Effusion',
+        'Pleural Other',
+        'Fracture',
+        'Support Devices'
+    ]
+
+    def __init__(self, csv_path, val_size=0.15, test_size=0.15, random_state=42,
+                 use_frontal_only=True, fill_uncertain='zeros', root=None):
+        """
+        Initialize the splitter.
+
+        Args:
+            csv_path: Path to train.csv from CheXpert-small
+            val_size: Validation set proportion (default: 0.15)
+            test_size: Test set proportion (default: 0.15)
+            random_state: Random seed for reproducibility
+            use_frontal_only: Use only frontal view images
+            fill_uncertain: How to handle uncertain labels ('zeros', 'ones', 'ignore')
+        """
+        self.csv_path = csv_path
+        self.val_size = val_size
+        self.test_size = test_size
+        self.random_state = random_state
+        self.use_frontal_only = use_frontal_only
+        self.fill_uncertain = fill_uncertain
+        self.root = root
+
+        print("=" * 80)
+        print("CheXpert Data Splitter - Preventing Data Leakage & Class Bias")
+        print("=" * 80)
+
+    def load_and_preprocess(self):
+        """Load and preprocess the dataset."""
+        print("\n[1/5] Loading data...")
+        self.df = pd.read_csv(self.csv_path)
+        print(f"  Loaded {len(self.df)} images")
+
+        # self.df = self.df[self.df["Path"].apply(os.path.exists)]
+
+        # Filter for frontal views only
+        if self.use_frontal_only:
+            initial_count = len(self.df)
+            self.df = self.df[self.df['Frontal/Lateral'] == 'Frontal'].reset_index(drop=True)
+            print(f"  Filtered to frontal views: {len(self.df)} images ({initial_count - len(self.df)} removed)")
+
+        # Extract patient and study IDs from path
+        print("\n[2/5] Extracting patient and study IDs...")
+        self.df['patient_id'] = self.df['Path'].apply(lambda x: x.split('/')[2])
+        self.df['study_id'] = self.df['Path'].apply(lambda x: x.split('/')[3])
+
+        n_patients = self.df['patient_id'].nunique()
+        n_studies = self.df['study_id'].nunique()
+        print(f"  Unique patients: {n_patients}")
+        print(f"  Unique studies: {n_studies}")
+        print(f"  Images per patient (avg): {len(self.df) / n_patients:.2f}")
+
+        # Process uncertain labels
+        print("\n[3/5] Processing uncertain labels...")
+        self._process_uncertain_labels()
+
+        return self.df
+
+    def _process_uncertain_labels(self):
+        """Process uncertain labels (-1) based on the chosen strategy."""
+        for pathology in self.PATHOLOGIES:
+            if pathology in self.df.columns:
+                uncertain_count = (self.df[pathology] == -1).sum()
+
+                if self.fill_uncertain == 'zeros':
+                    self.df[pathology] = self.df[pathology].replace(-1, 0)
+                elif self.fill_uncertain == 'ones':
+                    self.df[pathology] = self.df[pathology].replace(-1, 1)
+                elif self.fill_uncertain == 'ignore':
+                    pass  # Keep -1 as is
+
+                # Fill NaN with 0
+                self.df[pathology] = self.df[pathology].fillna(0)
+
+        print(f"  Uncertain labels strategy: {self.fill_uncertain}")
+
+    def create_stratification_groups(self):
+        """
+        Create stratification groups based on multi-label combinations.
+        Uses patient-level aggregation to prevent data leakage.
+        """
+        print("\n[4/5] Creating stratification groups (patient-level)...")
+
+        # Group by patient and aggregate labels
+        patient_groups = self.df.groupby('patient_id').agg({
+            **{pathology: 'max' for pathology in self.PATHOLOGIES if pathology in self.df.columns},
+            'study_id': 'first',  # Keep one study_id for reference
+            'Sex': 'first',
+            'Age': 'first'
+        }).reset_index()
+
+        # Create label signature for each patient
+        # This is a binary string representing which conditions are present
+        def create_label_signature(row):
+            signature = []
+            for pathology in self.PATHOLOGIES:
+                if pathology in patient_groups.columns:
+                    signature.append(str(int(row[pathology])))
+            return ''.join(signature)
+
+        patient_groups['label_signature'] = patient_groups.apply(create_label_signature, axis=1)
+
+        # For rare combinations, group them together
+        signature_counts = patient_groups['label_signature'].value_counts()
+        rare_threshold = max(5, int(len(patient_groups) * 0.001))  # At least 5 or 0.1%
+
+        def get_stratification_group(signature):
+            if signature_counts[signature] < rare_threshold:
+                return 'RARE_COMBINATION'
+            return signature
+
+        patient_groups['stratification_group'] = patient_groups['label_signature'].apply(get_stratification_group)
+
+        # Print distribution statistics
+        print(f"\n  Patient-level label distribution:")
+        for pathology in self.PATHOLOGIES:
+            if pathology in patient_groups.columns:
+                positive_count = (patient_groups[pathology] == 1).sum()
+                percentage = positive_count / len(patient_groups) * 100
+                print(f"    {pathology:30s}: {positive_count:5d} ({percentage:5.2f}%)")
+
+        unique_groups = patient_groups['stratification_group'].nunique()
+        print(f"\n  Unique stratification groups: {unique_groups}")
+        print(f"  Rare combinations grouped: {(patient_groups['stratification_group'] == 'RARE_COMBINATION').sum()}")
+
+        return patient_groups
+
+    def perform_split(self, patient_groups):
+        """
+        Perform stratified train-validation-test split at patient level.
+        """
+        print("\n[5/5] Performing stratified patient-level split...")
+
+        stratification_labels = patient_groups['stratification_group'].values
+
+        # ---- train / (val+test) ----
+        train_patients, valtest_patients = train_test_split(
+            patient_groups['patient_id'].values,
+            test_size=self.val_size + self.test_size,
+            stratify=stratification_labels,
+            random_state=self.random_state
+        )
+
+        # ---- val / test from the remaining pool ----
+        remaining_labels = patient_groups.set_index('patient_id').loc[valtest_patients]['stratification_group'].values
+        val_patients, test_patients = train_test_split(
+            valtest_patients,
+            test_size=self.test_size / (self.val_size + self.test_size),  # proportion of the val+test pool
+            stratify=remaining_labels,
+            random_state=self.random_state
+        )
+
+        print(f"  Train patients: {len(train_patients)}")
+        print(f"  Val patients:   {len(val_patients)}")
+        print(f"  Test patients:  {len(test_patients)}")
+
+        # Split the full dataframe
+        train_df = self.df[self.df['patient_id'].isin(train_patients)].copy()
+        val_df = self.df[self.df['patient_id'].isin(val_patients)].copy()
+        test_df = self.df[self.df['patient_id'].isin(test_patients)].copy()
+
+        # ---- leakage check (train vs val vs test) ----
+        sets = [('train', train_df), ('val', val_df), ('test', test_df)]
+        for i, (name_i, df_i) in enumerate(sets):
+            for name_j, df_j in sets[i+1:]:
+                overlap = set(df_i['patient_id']).intersection(set(df_j['patient_id']))
+                if overlap:
+                    raise ValueError(f"Data leakage between {name_i} and {name_j}: {len(overlap)} patients overlap")
+        print("\n  No patient overlap – leakage prevented!")
+
+        return train_df, val_df, test_df
+
+    def run(self, output_dir='.', save_test=True):
+        self.load_and_preprocess()
+        patient_groups = self.create_stratification_groups()
+        train_df, val_df, test_df = self.perform_split(patient_groups)
+
+        self.verify_split_quality(train_df, val_df)
+        # optional: also verify train vs test (same function works with two dfs)
+        print("\n--- Train vs Test distribution check ---")
+        self.verify_split_quality(train_df, test_df)
+
+        train_path, val_path = self.save_splits(train_df, val_df, output_dir)
+        if save_test:
+            test_path = self.save_test_split(test_df, output_dir)
+        else:
+            test_path = None
+
+        print("\n" + "=" * 80)
+        print("Split Complete! (train / val / test)")
+        print("=" * 80)
+        return train_path, val_path, test_path
+
+    def save_test_split(self, test_df, output_dir):
+        output_dir = Path(output_dir)
+        output_dir.mkdir(exist_ok=True)
+        test_path = output_dir / 'test_ready.csv'
+
+        cols_to_drop = ['patient_id', 'study_id']
+        test_clean = test_df.drop(columns=[c for c in cols_to_drop if c in test_df.columns])
+        test_clean.to_csv(test_path, index=False)
+
+        print(f"Test set : {test_path} ({len(test_clean)} images)")
+        return test_path
+
+    def verify_split_quality(self, train_df, val_df):
+        """
+        Verify the quality of the split by comparing label distributions.
+        """
+        print("\n" + "=" * 80)
+        print("Split Quality Verification")
+        print("=" * 80)
+
+        print(f"\n{'Pathology':<30s} {'Train %':>10s} {'Val %':>10s} {'Difference':>12s}")
+        print("-" * 80)
+
+        max_diff = 0
+        for pathology in self.PATHOLOGIES:
+            if pathology in train_df.columns:
+                train_pos = (train_df[pathology] == 1).sum() / len(train_df) * 100
+                val_pos = (val_df[pathology] == 1).sum() / len(val_df) * 100
+                diff = abs(train_pos - val_pos)
+                max_diff = max(max_diff, diff)
+
+                print(f"{pathology:<30s} {train_pos:>9.2f}% {val_pos:>9.2f}% {diff:>11.2f}%")
+
+        print("-" * 80)
+        print(f"Maximum distribution difference: {max_diff:.2f}%")
+
+        if max_diff < 2.0:
+            print("✓ Excellent stratification (< 2% difference)")
+        elif max_diff < 5.0:
+            print("✓ Good stratification (< 5% difference)")
+        else:
+            print("⚠ Warning: Large distribution differences detected")
+
+        # Check for class imbalance
+        print("\n" + "=" * 80)
+        print("Class Imbalance Analysis (Train Set)")
+        print("=" * 80)
+
+        imbalance_ratios = []
+        for pathology in self.PATHOLOGIES:
+            if pathology in train_df.columns:
+                pos = (train_df[pathology] == 1).sum()
+                neg = (train_df[pathology] == 0).sum()
+                if pos > 0:
+                    ratio = neg / pos
+                    imbalance_ratios.append(ratio)
+                    severity = "Low" if ratio < 5 else "Medium" if ratio < 20 else "High"
+                    print(f"{pathology:<30s} Ratio: {ratio:>6.2f}:1 [{severity:>6s} imbalance]")
+
+        avg_imbalance = np.mean(imbalance_ratios)
+        print(f"\nAverage imbalance ratio: {avg_imbalance:.2f}:1")
+
+    def save_splits(self, train_df, val_df, output_dir='.'):
+        """Save train and validation splits to CSV files."""
+        output_dir = Path(output_dir)
+        output_dir.mkdir(exist_ok=True)
+
+        train_path = output_dir / 'train_ready.csv'
+        val_path = output_dir / 'val_ready.csv'
+
+        # Remove temporary columns used for splitting
+        columns_to_drop = ['patient_id', 'study_id']
+        train_df_clean = train_df.drop(columns=[col for col in columns_to_drop if col in train_df.columns])
+        val_df_clean = val_df.drop(columns=[col for col in columns_to_drop if col in val_df.columns])
+
+        train_df_clean.to_csv(train_path, index=False)
+        val_df_clean.to_csv(val_path, index=False)
+
+        print("\n" + "=" * 80)
+        print("Files Saved Successfully")
+        print("=" * 80)
+        print(f"Train set: {train_path} ({len(train_df_clean)} images)")
|
| 311 |
+
print(f"Val set: {val_path} ({len(val_df_clean)} images)")
|
| 312 |
+
|
| 313 |
+
return train_path, val_path
|
| 314 |
+
|
| 315 |
+
# Main execution
|
| 316 |
+
if __name__ == "__main__":
|
| 317 |
+
root = "/content/drive/MyDrive"
|
| 318 |
+
# Configuration
|
| 319 |
+
CHEXPERT_CSV = os.path.join(root,"CheXpert-v1.0-small","train.csv") # Adjust path as needed
|
| 320 |
+
OUTPUT_DIR = os.path.join(root,"CheXpert-v1.0-small")
|
| 321 |
+
VAL_SIZE = 0.15
|
| 322 |
+
RANDOM_STATE = 42
|
| 323 |
+
USE_FRONTAL_ONLY = True
|
| 324 |
+
FILL_UNCERTAIN = 'zeros' # Options: 'zeros', 'ones', 'ignore'
|
| 325 |
+
|
| 326 |
+
# Create splitter
|
| 327 |
+
splitter = CheXpertDataSplitter(
|
| 328 |
+
csv_path=CHEXPERT_CSV,
|
| 329 |
+
val_size=VAL_SIZE,test_size=VAL_SIZE,
|
| 330 |
+
random_state=RANDOM_STATE,
|
| 331 |
+
use_frontal_only=USE_FRONTAL_ONLY,
|
| 332 |
+
fill_uncertain=FILL_UNCERTAIN,
|
| 333 |
+
root=OUTPUT_DIR
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
# Run the split
|
| 337 |
+
if os.path.exists(os.path.join(root,"CheXpert-v1.0-small","train_ready.csv")) and os.path.exists(os.path.join(root,"CheXpert-v1.0-small","val_ready.csv")):
|
| 338 |
+
train_path=os.path.join(root,"CheXpert-v1.0-small","train_ready.csv")
|
| 339 |
+
val_path=os.path.join(root,"CheXpert-v1.0-small","val_ready.csv")
|
| 340 |
+
test_path=os.path.join(root,"CheXpert-v1.0-small","test_ready.csv")
|
| 341 |
+
else:
|
| 342 |
+
train_path, val_path,test_path = splitter.run(output_dir=OUTPUT_DIR)
|
| 343 |
+
|
| 344 |
+
print("\nYou can now use these files with your CheXpertDataset class:")
|
| 345 |
+
print(f" train_dataset = CheXpertDataset('{train_path}', root_dir='...', augment=True)")
|
| 346 |
+
print(f" val_dataset = CheXpertDataset('{val_path}', root_dir='...', augment=False)")
|
| 347 |
+
print(f" test_dataset = CheXpertDataset('{test_path}', root_dir='...', augment=False)")
|
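With val_size = test_size = 0.15 as configured above, the two-stage split yields 70/15/15 patient-level proportions: the first train_test_split reserves 30% of patients, and the second uses test_size = 0.15 / 0.30 = 0.5 to divide that pool evenly. A minimal standalone sketch of the arithmetic (using sklearn directly, with made-up patient ids):

import numpy as np
from sklearn.model_selection import train_test_split

val_size, test_size = 0.15, 0.15
patients = np.arange(1000)

# Stage 1: hold out val+test together (30% of patients)
train, valtest = train_test_split(patients, test_size=val_size + test_size, random_state=42)
# Stage 2: split the held-out pool; 0.15 / 0.30 = 0.5 of it becomes test
val, test = train_test_split(valtest, test_size=test_size / (val_size + test_size), random_state=42)

print(len(train), len(val), len(test))  # 700 150 150

Because both stages stratify on the same group labels, each pathology combination keeps roughly its overall prevalence in all three sets.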
loss/__init__.py ADDED
File without changes

loss/mae_loss.py ADDED
@@ -0,0 +1,10 @@
import torch
import torch.nn as nn

def mae_loss(pred, target, mask):
    # pred/target: (B, N, P); mask: (B, N) with 1 = masked patch
    B, N, P = pred.shape
    mask = mask.unsqueeze(-1).float()  # (B, N, 1)
    loss = (pred - target) ** 2
    loss = (loss * mask).sum() / mask.sum().clamp_min(1.0)
    return loss
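The loss averages the squared reconstruction error over masked patches only; since mask.sum() counts masked patches, the result is a patch-summed error per masked patch rather than a per-pixel mean. A quick smoke test with random tensors of assumed shape (B, N, P) = (2, 16, 64):

import torch
from loss.mae_loss import mae_loss

pred = torch.randn(2, 16, 64)              # reconstructed patches (B, N, P)
target = torch.randn(2, 16, 64)            # ground-truth patches
mask = (torch.rand(2, 16) > 0.25).float()  # roughly 75% of patches masked

# Hand-computed equivalent: sum of squared errors on masked patches,
# divided by the number of masked patches
manual = ((pred - target) ** 2 * mask.unsqueeze(-1)).sum() / mask.sum()
print(torch.allclose(mae_loss(pred, target, mask), manual))  # True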
models/__init__.py ADDED
File without changes

models/mae.py ADDED
@@ -0,0 +1,177 @@
import math
import torch
import torch.nn as nn

def patchify(x, patch_size=8):
    b, c, h, w = x.shape
    assert h % patch_size == 0 and w % patch_size == 0, "Image size must be divisible by patch_size"
    th = h // patch_size
    tw = w // patch_size

    out = x.reshape(b, c, th, patch_size, tw, patch_size)
    out = out.permute(0, 2, 4, 1, 3, 5).contiguous()
    out = out.view(b, th * tw, c * (patch_size ** 2))
    return out

def unpatchify(x, patch_size=8):
    b, z, p = x.shape
    c = p // (patch_size ** 2)
    th = int(math.sqrt(z))
    tw = th
    h = th * patch_size
    w = tw * patch_size
    x = x.view(b, th, tw, c, patch_size, patch_size)
    x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
    out = x.view(b, c, h, w)
    return out

def random_mask(x, mask_ratio=0.75):
    b, n, p = x.shape
    len_keep = int(n * (1 - mask_ratio))
    noise = torch.rand(b, n, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, p))
    mask = torch.ones(b, n, device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)
    return x_masked, mask, ids_restore, ids_keep

def mae_loss(pred, target, mask):
    # pred/target: (B, N, P); mask: (B, N) with 1 = masked patch
    B, N, P = pred.shape
    mask = mask.unsqueeze(-1).float()  # (B, N, 1)
    loss = (pred - target) ** 2
    loss = (loss * mask).sum() / mask.sum().clamp_min(1.0)
    return loss

class PositionalEncoding(nn.Module):
    def __init__(self, num_patches, hidden_dim=768):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.empty(1, num_patches, hidden_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x, visible_indices):
        # x: (B, len_keep, D); visible_indices: (B, len_keep)
        B, L, D = x.shape
        # expand table to (B, N, D)
        pos = self.pos_embed.expand(B, -1, -1)
        # build gather index (B, L, D)
        idx = visible_indices.unsqueeze(-1).expand(B, L, pos.size(-1))
        visible_pos = torch.gather(pos, 1, idx)  # (B, L, D)
        return x + visible_pos

class TransformerBlock(nn.Module):
    def __init__(self, hidden_dim, mlp_dim, num_heads, dropout):
        super().__init__()
        self.layernorm1 = nn.LayerNorm(hidden_dim)
        self.multihead = nn.MultiheadAttention(batch_first=True, embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout)
        self.layernorm2 = nn.LayerNorm(hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(mlp_dim, hidden_dim), nn.Dropout(dropout)
        )

    def forward(self, x):
        # Pre-norm attention with residual connection
        residual = x
        x = self.layernorm1(x)
        attn, _ = self.multihead(x, x, x)
        x = residual + attn
        # Pre-norm MLP with residual connection
        residual = x
        x = self.layernorm2(x)
        x = self.mlp(x)
        x = residual + x
        return x

class MAEEncoder(nn.Module):
    """
    patch_dim: flattened patch dimension, c * patch_size**2.
    The encoder only ever sees the non-masked fraction of the patches.
    """
    def __init__(self, patch_dim, num_patches=(384 // 8) ** 2, hidden_dim=768, mlp_dim=768 * 4,
                 num_heads=8, depth=12, dropout=0.25, mask_ratio=0.75, patch_size=8):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size
        self.patch_embed = nn.Linear(patch_dim, hidden_dim)
        self.pos_embed = PositionalEncoding(num_patches=num_patches, hidden_dim=hidden_dim)
        self.transformer = nn.ModuleList([
            TransformerBlock(hidden_dim=hidden_dim, mlp_dim=mlp_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(depth)
        ])

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x_in):
        x_p = patchify(x_in, self.patch_size)
        x_masked, mask, ids_restore, ids_keep = random_mask(x_p, self.mask_ratio)
        x = self.patch_embed(x_masked)
        x = self.pos_embed(x, ids_keep)
        for attn_layer in self.transformer:
            x = attn_layer(x)
        return x, mask, ids_keep, ids_restore

class MAEDecoder(nn.Module):
    def __init__(self, c, num_patches, patch_size, encoder_dim, decoder_dim, decoder_depth, mlp_dim, num_heads, dropout):
        super().__init__()
        self.num_patches = num_patches
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.mask_token = nn.Parameter(torch.empty(1, 1, decoder_dim))
        self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim)
        self.pos_embed = nn.Parameter(torch.empty(1, num_patches, decoder_dim))
        self.transformer = nn.ModuleList([
            TransformerBlock(hidden_dim=decoder_dim, mlp_dim=mlp_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(decoder_depth)
        ])
        self.layernorm = nn.LayerNorm(decoder_dim)
        self.pred = nn.Linear(decoder_dim, c * (patch_size ** 2))

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.mask_token, std=0.02)

    def forward(self, x, ids_keep, ids_restore):
        b, n, p = x.shape
        xdec = self.enc_to_dec(x)
        len_keep = xdec.size(1)
        num_patches = ids_restore.size(1)
        num_mask = num_patches - len_keep

        # Append mask tokens, then unshuffle back to the original patch order
        mask_token = self.mask_token.expand(b, num_mask, -1)
        x_ = torch.cat([xdec, mask_token], dim=1)
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, x_.size(-1)))
        x_ = x_ + self.pos_embed
        for block in self.transformer:
            x_ = block(x_)
        x_ = self.layernorm(x_)
        out = self.pred(x_)
        return out

class MaskedAutoEncoder(nn.Module):
    def __init__(self, c=1, mask_ratio=0.75, dropout=0.25, img_size=384, encoder_dim=768, mlp_dim=3072,
                 decoder_dim=512, encoder_depth=12, encoder_head=8, decoder_depth=8, decoder_head=8, patch_size=8):
        super().__init__()
        self.patch_size = patch_size
        self.encoder = MAEEncoder(patch_dim=c * (patch_size ** 2), num_patches=(img_size // patch_size) ** 2,
                                  hidden_dim=encoder_dim, mlp_dim=mlp_dim, num_heads=encoder_head,
                                  depth=encoder_depth, dropout=dropout, mask_ratio=mask_ratio, patch_size=patch_size)
        self.decoder = MAEDecoder(c, num_patches=(img_size // patch_size) ** 2, patch_size=patch_size,
                                  encoder_dim=encoder_dim, decoder_dim=decoder_dim, decoder_depth=decoder_depth,
                                  mlp_dim=mlp_dim, num_heads=decoder_head, dropout=dropout)

    def forward(self, x):
        b, c, h, w = x.shape
        encoded, mask, ids_keep, ids_restore = self.encoder(x)
        decoded = self.decoder(encoded, ids_keep, ids_restore)

        xpatched = patchify(x, self.patch_size)
        return xpatched, decoded, mask

    @staticmethod
    def testme():
        img = torch.rand(1, 1, 384, 384)
        mae = MaskedAutoEncoder()
        a, b, c = mae(img)
        print(a.shape)
        print(b.shape)
        print(c.shape)
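A shape sanity check under the defaults above (384x384 single-channel input, patch_size=8, hence 48^2 = 2304 patches of dimension 64): patchify/unpatchify round-trip exactly, and the model returns target patches, reconstructions, and the binary mask in patch order. A minimal sketch:

import torch
from models.mae import patchify, unpatchify, MaskedAutoEncoder

x = torch.rand(2, 1, 384, 384)
patches = patchify(x, patch_size=8)
print(patches.shape)                           # torch.Size([2, 2304, 64])
print(torch.equal(unpatchify(patches, 8), x))  # True: an exact round trip

mae = MaskedAutoEncoder()
target, pred, mask = mae(x)
print(target.shape, pred.shape, mask.shape)    # (2, 2304, 64), (2, 2304, 64), (2, 2304)
print(mask.sum(dim=1))                         # tensor([1728., 1728.]): 75% of 2304 patches masked

The training loss is then mae_loss evaluated on these three outputs, so gradients flow only through the 75% of patches the encoder never saw.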
notebooks/chexpert_mae.ipynb ADDED
The diff for this file is too large to render. See raw diff.
requirements.txt ADDED
@@ -0,0 +1,29 @@
# Core Deep Learning
torch>=2.0.0
torchvision>=0.15.0

# Data Processing
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0

# Image Processing
Pillow>=10.0.0
opencv-python>=4.8.0
albumentations>=1.3.1

# Visualization
matplotlib>=3.7.0
seaborn>=0.12.0

# Utilities
tqdm>=4.65.0

# Jupyter (optional - for notebooks)
jupyter>=1.0.0
ipykernel>=6.25.0
ipywidgets>=8.1.0

# Additional utilities (if needed)
# lmdb>=1.4.0  # Uncomment if using LMDB for caching
# tensorboard>=2.13.0  # Uncomment if using TensorBoard logging
trainer/__init__.py ADDED
File without changes

trainer/trainer.py ADDED
@@ -0,0 +1,14 @@
from trainer.utils import *
from configs.configs import root, mae_config

def main():
    try:
        print("Training mae")
        trainer = MAETrainer(mae_config)
        trainer.test()
    except Exception:
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()
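The mae_config dict itself lives in configs/configs.py and is not reproduced here. For orientation only, a placeholder sketch of the keys MAETrainer reads (key names are taken from the accesses in trainer/utils.py below; all values here are hypothetical, so consult configs/configs.py for the real ones):

import os
import torch

root = "/content/drive/MyDrive"  # hypothetical; matches the root used in splitter.py
mae_config = {
    # paths / logging
    "logdir": os.path.join(root, "training logs", "mae"),
    "dirsToMake": [os.path.join(root, "weights")],
    "resume": os.path.join(root, "weights", "mae.pth"),
    "zip_path": os.path.join(root, "chexpert.zip"),
    "datadir": root,
    "train_csv": os.path.join(root, "CheXpert-v1.0-small", "train_ready.csv"),
    "val_csv": os.path.join(root, "CheXpert-v1.0-small", "val_ready.csv"),
    "test_csv": os.path.join(root, "CheXpert-v1.0-small", "test_ready.csv"),
    # model (the MaskedAutoEncoder defaults)
    "channels": 1, "img_size": 384, "patch_size": 8,
    "mask_ratio": 0.75, "dropout": 0.25,
    "encoder_dim": 768, "mlp_dim": 3072, "decoder_dim": 512,
    "encoder_depth": 12, "encoder_head": 8,
    "decoder_depth": 8, "decoder_head": 8,
    # optimization
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "lr": 1.5e-4, "weight_decay": 0.05,
    "warmup": 10, "num_epochs": 100,
    "batch_size": 32, "accumulation": 4,
}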
trainer/utils.py ADDED
@@ -0,0 +1,348 @@
from data.dataset import CheXpertDataset
from loss.mae_loss import mae_loss
from models.mae import *
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import gc
import json
import os
import io
import sys

class TeeFile:
    """
    File-like object that writes to multiple streams (e.g., stdout and a file).
    Automatically handles string paths by opening them as files.

    Usage:
        # Works with both file objects and paths
        tee = TeeFile(sys.stdout, "/path/to/log.txt")
        print("Hello", file=tee)  # Writes to both stdout and the file
    """
    def __init__(self, *file_objects_or_paths):
        """
        Args:
            *file_objects_or_paths: Mix of file objects (like sys.stdout)
                and string paths to log files.
        """
        self.files = []
        self.opened_files = []  # Track files we opened so we can close them later

        for item in file_objects_or_paths:
            if isinstance(item, str):
                # It's a path string: open it as a file
                f = open(item, 'a', buffering=1)  # Append mode, line buffered
                self.files.append(f)
                self.opened_files.append(f)
            else:
                # It's already a file-like object (e.g., sys.stdout)
                self.files.append(item)

    def write(self, data):
        """Write data to all streams."""
        for f in self.files:
            try:
                f.write(data)
                f.flush()
            except Exception as e:
                # Handle a closed file gracefully
                print(f"Warning: Could not write to {f}: {e}", file=sys.stderr)

    def flush(self):
        """Flush all streams."""
        for f in self.files:
            try:
                f.flush()
            except Exception:
                pass

    def isatty(self):
        """Check if any stream is a terminal (for tqdm compatibility)."""
        return any(getattr(f, "isatty", lambda: False)() for f in self.files)

    def fileno(self):
        """Get a file descriptor from any real file-like stream."""
        for f in self.files:
            if hasattr(f, "fileno"):
                try:
                    return f.fileno()
                except Exception:
                    pass
        raise io.UnsupportedOperation("No fileno available")

    def close(self):
        """Close any files we opened."""
        for f in self.opened_files:
            try:
                f.close()
            except Exception:
                pass
        self.opened_files.clear()

    def __del__(self):
        """Cleanup on deletion."""
        self.close()

    def __enter__(self):
        """Context manager support."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager cleanup."""
        self.close()
        return False

class MAETrainer:
    def __init__(self, configs={}):
        self.configs = configs
        os.makedirs(configs["logdir"], exist_ok=True)
        log_path_train = os.path.join(configs["logdir"], "training_log.txt")
        log_path_val = os.path.join(configs["logdir"], "val_log.txt")
        log_path_test = os.path.join(configs["logdir"], "test_log.txt")
        self.traintee = TeeFile(sys.stdout, log_path_train)
        self.valtee = TeeFile(sys.stdout, log_path_val)
        self.testtee = TeeFile(sys.stdout, log_path_test)

        for d in self.configs["dirsToMake"]:
            os.makedirs(d, exist_ok=True)

        self.model = MaskedAutoEncoder(
            c=configs["channels"],
            mask_ratio=configs["mask_ratio"],
            dropout=configs["dropout"],
            img_size=configs["img_size"],
            encoder_dim=configs["encoder_dim"],
            mlp_dim=configs["mlp_dim"],
            decoder_dim=configs["decoder_dim"],
            encoder_depth=configs["encoder_depth"],
            encoder_head=configs["encoder_head"],
            decoder_depth=configs["decoder_depth"],
            decoder_head=configs["decoder_head"],
            patch_size=configs["patch_size"]
        ).to(configs["device"])

        self.criterion = mae_loss

        self.optimizer = torch.optim.AdamW(self.model.parameters(), configs["lr"], weight_decay=configs["weight_decay"])
        # Linear warmup followed by cosine decay; both schedulers count optimizer updates
        self.schedular1 = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=0.1, end_factor=1.0, total_iters=configs["warmup"])
        self.schedular2 = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=configs["num_epochs"] - configs["warmup"])
        self.schedular = torch.optim.lr_scheduler.SequentialLR(self.optimizer, schedulers=[self.schedular1, self.schedular2], milestones=[configs["warmup"]])
        self.scaler = torch.amp.GradScaler()

        self.train_dataset = CheXpertDataset(zip_path=configs["zip_path"], csv_path=configs["train_csv"], root_dir=configs["datadir"], augment=True, use_frontal_only=True)
        self.val_dataset = CheXpertDataset(zip_path=configs["zip_path"], csv_path=configs["val_csv"], root_dir=configs["datadir"], augment=False, use_frontal_only=True)
        self.class_Weights = self.train_dataset.get_class_weights().to(self.configs["device"])
        self.sample_Weights = self.train_dataset.get_sample_weights()
        self.sampler = torch.utils.data.WeightedRandomSampler(self.sample_Weights, num_samples=len(self.sample_Weights))
        self.trainloader = DataLoader(self.train_dataset, batch_size=configs["batch_size"], sampler=self.sampler, num_workers=8, pin_memory=True, persistent_workers=True)
        self.valloader = DataLoader(self.val_dataset, batch_size=configs["batch_size"], shuffle=False, num_workers=8, pin_memory=True, persistent_workers=True)
        self.history = {"train_loss": [], "val_loss": []}

        self.current_epoch = 0

        # Resume from a checkpoint if one exists
        if os.path.exists(self.configs["resume"]):
            loadedpickle = torch.load(self.configs["resume"], map_location=self.configs["device"])
            self.model.load_state_dict(loadedpickle["model"], strict=False)
            self.optimizer.load_state_dict(loadedpickle["optimizer"])
            self.schedular.load_state_dict(loadedpickle["schedular"])
            self.schedular1.load_state_dict(loadedpickle["schedular1"])
            self.schedular2.load_state_dict(loadedpickle["schedular2"])
            self.scaler.load_state_dict(loadedpickle["scaler"])
            self.current_epoch = loadedpickle["epoch"] + 1

        self.test_dataset = None
        self.testloader = None
        if configs.get("test_csv"):
            self.test_dataset = CheXpertDataset(
                zip_path=configs["zip_path"],
                csv_path=configs["test_csv"],
                root_dir=configs["datadir"],
                augment=False,
                use_frontal_only=True
            )
            self.testloader = DataLoader(
                self.test_dataset,
                batch_size=configs["batch_size"],
                shuffle=False,
                num_workers=8,
                pin_memory=True,
                persistent_workers=True
            )
            print(f"Test loader ready – {len(self.test_dataset)} images")

        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True

        # Memory allocator settings
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

        # Enable gradient checkpointing if the model supports it
        if hasattr(self.model, 'enable_gradient_checkpointing'):
            self.model.enable_gradient_checkpointing()

    @staticmethod
    def plot_training_metrics(metrics, epoch, figs_path):
        """
        Plot loss curves from training metrics.

        Args:
            metrics (dict): Dictionary containing a list per metric key:
                {"train_loss": [...], "val_loss": [...]}
            epoch (int): Current epoch number (used for the output directory)
            figs_path (str): Root directory for the saved figure
        """
        import matplotlib.pyplot as plt

        # Compute the common length across all series and slice to it,
        # so the plotted curves always have matching lengths
        keys = ["train_loss", "val_loss"]
        lengths = [len(metrics[k]) for k in keys if k in metrics]
        if not lengths:
            return
        n = min(lengths)
        m = {k: metrics[k][:n] for k in keys if k in metrics}
        epochs = list(range(1, n + 1))

        plt.figure(figsize=(14, 6))

        # ---- Loss subplot ----
        plt.subplot(1, 2, 1)
        plt.plot(epochs, m["train_loss"], label="Train Loss", marker='o')
        plt.plot(epochs, m["val_loss"], label="Val Loss", marker='s')
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training & Validation Loss")
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)

        plt.tight_layout()
        os.makedirs(os.path.join(figs_path, str(epoch)), exist_ok=True)
        plt.savefig(os.path.join(figs_path, str(epoch), "metrics.png"))
        plt.show()

    def train_epoch(self, epoch, looper):
        self.model.train()
        running_loss = 0.0
        current_loss = 0.0
        total_batches = len(self.trainloader)

        for batch_idx, data in looper:
            image = data['image'].to(self.configs["device"], non_blocking=True)

            with torch.autocast(device_type=self.configs["device"].type, dtype=torch.float16):
                img, preds, mask = self.model(image)
                # (reconstruction, target, mask); the squared error is symmetric in the first two
                loss = self.criterion(preds, img, mask)

            loss_back = loss / self.configs["accumulation"]
            running_loss += loss.item()

            if torch.isfinite(loss):
                self.scaler.scale(loss_back).backward()
            else:
                # Skip non-finite losses entirely
                self.optimizer.zero_grad(set_to_none=True)
                continue

            if (batch_idx + 1) % self.configs["accumulation"] == 0 or batch_idx == total_batches - 1:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.schedular.step()
                self.optimizer.zero_grad(set_to_none=True)

            # === LIVE METRICS (updated every 10 batches) ===
            current_loss = running_loss / (batch_idx + 1)
            if (batch_idx + 1) % 10 == 0:
                current_lr = self.optimizer.param_groups[0]['lr']
                looper.set_postfix({
                    "lr": f"{current_lr:.2e}",
                    "batch": f"{batch_idx}/{total_batches}",
                    "epoch": f"{epoch}/{self.configs['num_epochs']}",
                    "loss": f"{current_loss:.3f}",
                })

        return current_loss

    def validate(self, epoch, looper):
        self.model.eval()
        val_loss = 0.0
        current_loss = 0.0
        lenloader = len(self.valloader)
        with torch.no_grad():
            for batch_idx, data in looper:
                image = data["image"].to(self.configs["device"], non_blocking=True)

                with torch.autocast(device_type=self.configs["device"].type, dtype=torch.float16):
                    img, preds, mask = self.model(image)
                    loss = self.criterion(preds, img, mask)

                val_loss += loss.item()

                # === LIVE METRICS ===
                current_loss = val_loss / (batch_idx + 1)
                if (batch_idx + 1) % 10 == 0:
                    looper.set_postfix({
                        "epoch": f"{epoch}/{self.configs['num_epochs']}",
                        "batch": f"{batch_idx}/{lenloader}",
                        "loss": f"{current_loss:.3f}",
                    })

        return current_loss

    def train(self):
        for epoch in range(self.current_epoch, self.configs["num_epochs"]):
            trainlooper = tqdm(enumerate(self.trainloader), desc="training: ", leave=False, file=self.traintee)
            vallooper = tqdm(enumerate(self.valloader), desc="validating: ", leave=False, file=self.valtee)

            self.model.train()
            self.optimizer.zero_grad(set_to_none=True)

            running_loss = self.train_epoch(epoch, trainlooper)

            torch.cuda.synchronize()
            torch.cuda.empty_cache()

            val_loss = self.validate(epoch, vallooper)

            torch.cuda.synchronize()
            torch.cuda.empty_cache()

            gc.collect()

            # Save a checkpoint whenever the validation loss improves (including the first epoch)
            if not self.history["val_loss"] or val_loss < min(self.history["val_loss"]):
                checkpoint = {
                    "model": self.model.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "schedular": self.schedular.state_dict(),
                    "schedular1": self.schedular1.state_dict(),
                    "schedular2": self.schedular2.state_dict(),
                    "scaler": self.scaler.state_dict(),
                    "epoch": epoch
                }
                torch.save(checkpoint, self.configs["resume"])

            print(f"train loss {running_loss} val loss {val_loss}")

            self.history["train_loss"].append(float(running_loss))
            self.history["val_loss"].append(float(val_loss))

            if epoch % 10 == 0:
                # Persist this run's history (overwrites history.json) and plot the curves
                historyfile = os.path.join(self.configs["logdir"], "history.json")
                with open(historyfile, "w") as f:
                    json.dump(self.history, f)
                MAETrainer.plot_training_metrics(self.history, epoch + 1, self.configs["logdir"])

            self.current_epoch = epoch
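The best checkpoint written by train() can be reloaded outside the trainer, e.g. to reuse the pretrained encoder. A minimal sketch, assuming a hypothetical checkpoint path (whatever configs["resume"] points to) and a model built with the same config as training:

import torch
from models.mae import MaskedAutoEncoder

ckpt = torch.load("weights/mae.pth", map_location="cpu")  # hypothetical path
model = MaskedAutoEncoder()  # must be constructed with the same config as in training
model.load_state_dict(ckpt["model"])
model.eval()

with torch.no_grad():
    target, pred, mask = model(torch.rand(1, 1, 384, 384))
print(f"resumed from epoch {ckpt['epoch']}")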
training logs/mae/1/metrics.png ADDED
training logs/mae/101/metrics.png ADDED
training logs/mae/11/metrics.png ADDED
training logs/mae/21/metrics.png ADDED
training logs/mae/31/metrics.png ADDED
training logs/mae/41/metrics.png ADDED
training logs/mae/51/metrics.png ADDED
training logs/mae/61/metrics.png ADDED
training logs/mae/71/metrics.png ADDED
training logs/mae/81/metrics.png ADDED
training logs/mae/91/metrics.png ADDED
training logs/mae/history.json ADDED
@@ -0,0 +1 @@
{"train_loss": [61.882596965502664, 44.471386281221996, 34.87869473939301, 27.54925524051899, 24.282669028415476, 22.451558465855097, 21.147514946913635, 20.99175854891432, 20.213269732075354, 20.001737736117455, 19.91450946014842, 19.47891389648547, 19.18175616794162, 18.748555672895098, 19.04294045150921, 18.45393469068739, 18.838688672315264, 18.155403604131447, 18.37196905074581, 17.72255533279911, 18.3023255009805, 17.781992453742625, 18.199740923262837, 17.737673626834773, 18.033451236875255, 17.423879869522587, 17.789273495144315, 17.71034824754175, 17.673153098253366, 17.505879045301867, 17.242074561717263, 17.15336024376654, 16.90521138372387, 16.933767681053464, 17.133027586714768, 17.32885531353694, 17.239427022352867, 17.082445266340795, 17.198882104333585, 16.797237842655523, 16.80550524462081, 16.855286646060193, 16.820361117564648, 16.804262694492135, 16.60529948887432, 16.617685430109713, 16.49680861811484, 16.534883691247646, 16.66711660952551, 16.68186878833292, 16.830572265365216, 16.551616942156173, 16.58498664637193, 16.517176605552756, 16.586636129673238, 16.610006309153786, 16.467580378269208, 16.335460134390008, 16.45028829335312, 16.415938852508436, 16.657625508052046, 16.622210603119225, 16.319772068146737, 16.374398973253037, 16.252494745015245, 16.311780229520625, 16.422519305102714, 16.033588693133392, 16.024791444757934, 16.087340885422137, 16.039379536208287, 16.26473860928662, 16.34161920308212, 15.996231590804234, 16.295430011817633, 16.445986707407087, 16.343918548775402, 16.409462545251333, 16.49581729998298, 16.137155871921117, 16.05842663481244, 16.16612617533694, 16.27624153595244, 16.29507503646249, 16.29731023747434, 16.399930175316378, 16.08117872668851, 16.119801326464582, 16.0585214939596, 15.990199300977919, 16.033912498036592, 16.225505850050183, 16.006062768095283, 16.016956458553192, 15.915514986318499, 16.111989719978798, 15.927976318414066, 16.02773256541153, 15.936725686186103, 15.84361021441798, 15.960153004089136], "val_loss": [31.570541223418278, 22.954840935741945, 14.442683501893104, 9.339503538568946, 8.24503136790076, 6.798558349229173, 7.171370674209341, 6.004779509927744, 6.1726217491682185, 5.98896356120062, 5.595671336912238, 5.563971098079238, 5.132711923795681, 5.617544600337843, 4.961312991044054, 5.048685019990534, 5.025379383682808, 4.725488581134631, 4.978387813631482, 4.562931300793771, 4.850497535692893, 4.521047054335129, 4.649226747081921, 4.5786320276038595, 4.333143504355041, 4.4852356395848165, 4.235697840535363, 4.414844192935779, 4.24225838635847, 4.289312726239429, 4.266240818555965, 4.077501735021902, 4.319235841301192, 4.025318107731715, 4.189943225676831, 4.086437865349145, 4.028652789980867, 4.119319666263669, 3.907484631205714, 4.0297732788859015, 3.884794211466843, 4.035012978651991, 3.96724297437953, 3.852650080804413, 3.9237460725727273, 3.7834923750538367, 3.9779289901454584, 3.7871727095885928, 3.838316381967741, 3.852536988020735, 3.7242261183222265, 3.889924947605577, 3.681634605920988, 3.851548166370075, 3.699327661349528, 3.693325303321661, 3.7589161586127804, 3.61661579759414, 3.7395904396855553, 3.5934903867220958, 3.6865477134223, 3.6255838419511863, 3.5961770504416024, 3.650220192152004, 3.5313091785012687, 3.658342431153966, 3.5214638036746915, 3.5825591047736896, 3.535893441830759, 3.4996724699026722, 3.600268395636169, 3.4633193554672292, 3.5713670332962493, 3.4522710076202188, 3.5278821134092007, 3.5031054748649217, 3.426014715650945, 3.5055409483735347, 3.403536310227606, 3.5005639573664364, 
3.434894419983772, 3.439712012725019, 3.4542255861022544, 3.3616595402904523, 3.4835145457638457, 3.351115513481571, 3.4225067292337004, 3.3756711308742284, 3.367003402044607, 3.418841015064835, 3.3144125257219588, 3.4290886646093326, 3.2998156674280517, 3.368062291826521, 3.337226649851498, 3.2932771796799973, 3.3656438268300306, 3.266660438423537, 3.3932779888774074, 3.2590805264406426, 3.298969637119889]}
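history.json holds the per-epoch loss lists written by train(); over this run the train loss falls from roughly 61.9 to 16.0 and the val loss from roughly 31.6 to 3.3. The curves can be re-plotted without instantiating the trainer:

import json
import matplotlib.pyplot as plt

with open("training logs/mae/history.json") as f:
    history = json.load(f)

epochs = range(1, len(history["train_loss"]) + 1)
plt.plot(epochs, history["train_loss"], label="train loss")
plt.plot(epochs, history["val_loss"], label="val loss")
plt.xlabel("Epoch")
plt.ylabel("MAE loss")
plt.legend()
plt.show()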
training logs/mae/test_log.txt ADDED
File without changes

training logs/mae/training_log.txt ADDED
The diff for this file is too large to render. See raw diff.

training logs/mae/val_log.txt ADDED
The diff for this file is too large to render. See raw diff.