---
license: mit
pipeline_tag: image-feature-extraction
---

# Masked Autoencoder (MAE) for Medical Imaging

A PyTorch implementation of a Masked Autoencoder (MAE) for self-supervised learning on chest X-ray images, specifically designed for the CheXpert dataset.

## Overview

This project implements a Vision Transformer-based Masked Autoencoder that learns representations from chest X-ray images through self-supervised reconstruction. The model randomly masks 75% of the image patches and learns to reconstruct the original image, enabling it to learn powerful visual representations without requiring labeled data.

### Key Features

- **Vision Transformer Architecture**: Encoder-decoder transformer architecture with positional encodings
- **Self-Supervised Learning**: Pre-training through masked image reconstruction
- **Optimized for Medical Imaging**: Designed specifically for chest X-ray analysis
- **Production-Ready Training Pipeline**:
  - Mixed precision training (FP16) with gradient scaling
  - Gradient accumulation support
  - Learning rate warmup and cosine annealing
  - Automatic checkpointing and resumption
- **Efficient Data Loading**:
  - Optimized ZIP file reader with LRU caching
  - Class-balanced sampling with a weighted random sampler
  - Multi-worker data loading with persistent workers
- **Comprehensive Logging**: Training/validation metrics tracking and visualization

## Architecture

### Masked Autoencoder Structure

```
Input Image (384×384)
        ↓
Patchify (16×16 patches → 576 patches)
        ↓
Random Masking (75% masked, 25% visible)
        ↓
┌──────────────────────────────────────┐
│             MAE ENCODER              │
│  - Linear patch embedding            │
│  - Positional encoding (visible)     │
│  - 12 Transformer blocks             │
│  - 8 attention heads, 768 hidden     │
└──────────────────────────────────────┘
        ↓
┌──────────────────────────────────────┐
│             MAE DECODER              │
│  - Learnable mask tokens             │
│  - Positional encoding (all)         │
│  - 8 Transformer blocks              │
│  - 8 attention heads, 512 hidden     │
│  - Pixel reconstruction head         │
└──────────────────────────────────────┘
        ↓
Reconstructed Image
        ↓
MSE Loss (on masked patches only)
```

### Model Configuration

| Parameter | Default Value | Description |
|-----------|---------------|-------------|
| Image Size | 384×384 | Input image resolution |
| Patch Size | 16×16 | Size of each patch |
| Mask Ratio | 0.75 | Fraction of patches to mask |
| Encoder Depth | 12 layers | Number of transformer blocks |
| Encoder Dim | 768 | Hidden dimension |
| Encoder Heads | 8 | Number of attention heads |
| Decoder Depth | 8 layers | Number of transformer blocks |
| Decoder Dim | 512 | Hidden dimension |
| Decoder Heads | 8 | Number of attention heads |
| MLP Ratio | 4× | MLP expansion ratio (3072) |
| Dropout | 0.25 | Dropout rate |
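The patch and token counts follow directly from these defaults; as a quick sanity check (plain Python arithmetic, independent of the repository code):

```python
# Patch-grid arithmetic implied by the configuration above.
image_size = 384
patch_size = 16
mask_ratio = 0.75

patches_per_side = image_size // patch_size   # 384 / 16 = 24
num_patches = patches_per_side ** 2           # 24 * 24 = 576
num_masked = int(num_patches * mask_ratio)    # 432 patches hidden from the encoder
num_visible = num_patches - num_masked        # 144 patches the encoder processes
```

Because the encoder only ever sees the 144 visible patches, pre-training is substantially cheaper than running the full 576-token sequence through all 12 encoder layers.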

## Getting Started

### Prerequisites

- Python >= 3.8
- CUDA-capable GPU (recommended)
- 16GB+ RAM

### Installation

1. Clone the repository:
```bash
git clone https://github.com/adelelsayed/mae.git
cd mae
```

2. Install dependencies:
```bash
pip install -r requirements.txt
```
### Dataset Preparation

This project is configured for the **CheXpert dataset**. To use it:

1. Download CheXpert-v1.0-small from the [Stanford ML Group](https://stanfordmlgroup.github.io/competitions/chexpert/)
2. Update the paths in `configs/configs.py`:
   - `root`: Base directory for your data
   - `zip_path`: Path to the zipped dataset (optional, for faster loading)
   - `csv`: Path to the training CSV
   - `train_csv`, `val_csv`, `test_csv`: Split CSV files

## Usage

### Training

Start training from scratch:
```bash
python trainer/trainer.py
```

The trainer will:
- Automatically create checkpoint and log directories
- Resume from the last checkpoint if available
- Log training/validation metrics to text files
- Save plots every 10 epochs
- Save the best model based on validation loss
### Training Configuration

Edit `configs/configs.py` to customize training:

```python
mae_config = {
    # Training hyperparameters
    "lr": 1e-4,              # Learning rate
    "warmup": 5,             # Warmup epochs
    "weight_decay": 5e-4,    # AdamW weight decay
    "num_epochs": 200,       # Total training epochs
    "batch_size": 96,        # Batch size
    "accumulation": 1,       # Gradient accumulation steps

    # Model architecture
    "mask_ratio": 0.75,      # Masking ratio
    "encoder_depth": 12,     # Encoder layers
    "decoder_depth": 8,      # Decoder layers

    # Paths
    "checkpoints": "/path/to/checkpoints",
    "logdir": "/path/to/logs",
    ...
}
```
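The warmup-then-cosine schedule can be sketched as a standalone function; this mirrors the shape the schedule takes under the defaults above, not necessarily the trainer's exact implementation:

```python
import math

def lr_at(epoch, base_lr=1e-4, warmup=5, num_epochs=200):
    """Linear warmup to base_lr over `warmup` epochs, then cosine decay toward zero."""
    if epoch < warmup:
        return base_lr * (epoch + 1) / warmup
    progress = (epoch - warmup) / (num_epochs - warmup)  # 0 → 1 over the decay phase
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```

The warmup phase keeps early updates small while the AdamW moment estimates stabilize; the cosine tail smoothly anneals the rate so the final epochs fine-tune rather than overshoot.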

### Monitoring Training

Training logs are saved in three files:
- `training_log.txt`: Training metrics per epoch
- `val_log.txt`: Validation metrics per epoch
- `test_log.txt`: Test set evaluation results

Metric plots are saved every 10 epochs to `{logdir}/{epoch}/metrics.png`.

### Evaluation

The trainer includes a test method. To evaluate:
```python
from trainer.utils import MAETrainer
from configs.configs import mae_config

trainer = MAETrainer(mae_config)
trainer.test()
```
## Project Structure

```
mae/
├── configs/
│   ├── __init__.py
│   └── configs.py          # Training configuration
├── data/
│   ├── __init__.py
│   ├── dataset.py          # CheXpert dataset loader
│   └── splitter.py         # Dataset splitting utilities
├── loss/
│   ├── __init__.py
│   └── mae_loss.py         # MAE reconstruction loss
├── models/
│   ├── __init__.py
│   └── mae.py              # MAE architecture
├── trainer/
│   ├── __init__.py
│   ├── trainer.py          # Main training script
│   └── utils.py            # Training utilities
├── notebooks/
│   └── chexpert_mae.ipynb  # Jupyter notebook for experiments
├── training logs/          # Logged metrics and plots
├── weights/                # Model checkpoints
├── results/                # Evaluation results
├── requirements.txt        # Python dependencies
├── LICENSE                 # Project license
└── README.md               # This file
```

## Components

### Dataset (`data/dataset.py`)

- **OptimizedZipReader**: Fast ZIP file reading with LRU caching
- **CheXpertDataset**: PyTorch dataset for CheXpert chest X-rays
  - 14 pathology labels: No Finding, Cardiomegaly, Edema, Consolidation, etc.
  - Albumentations-based augmentation pipeline
  - Class-balanced sampling support
  - Frontal/lateral view filtering

### Model (`models/mae.py`)

- **Patchify/Unpatchify**: Image-to-patch conversion utilities
- **Random Masking**: Stochastic patch masking with restore indices
- **PositionalEncoding**: Learnable position embeddings
- **TransformerBlock**: Multi-head self-attention + MLP
- **MAEEncoder**: Processes visible patches only
- **MAEDecoder**: Reconstructs the full image with mask tokens
- **MaskedAutoEncoder**: Complete MAE model
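Random masking with restore indices is typically implemented by argsorting per-patch noise, as in the original MAE paper. A NumPy sketch of the idea (the project's own version lives in `models/mae.py` and may differ in detail):

```python
import numpy as np

def random_masking(patches, mask_ratio=0.75, rng=None):
    """Split patches into visible and masked, MAE-style.

    patches: (num_patches, dim) array.
    Returns (visible, mask, ids_restore) where mask[i] == 1 means patch i was masked.
    """
    if rng is None:
        rng = np.random.default_rng(0)
    n = patches.shape[0]
    num_keep = int(n * (1 - mask_ratio))

    noise = rng.random(n)                  # one random value per patch
    ids_shuffle = np.argsort(noise)        # a random permutation of patch indices
    ids_restore = np.argsort(ids_shuffle)  # inverse permutation, used by the decoder

    visible = patches[ids_shuffle[:num_keep]]  # what the encoder sees

    mask = np.ones(n)          # 1 = masked, 0 = visible (in shuffled order)
    mask[:num_keep] = 0
    mask = mask[ids_restore]   # scatter back to the original patch order
    return visible, mask, ids_restore
```

The decoder uses `ids_restore` to interleave encoder outputs with learnable mask tokens back into the original patch order before reconstruction.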

### Loss (`loss/mae_loss.py`)

Mean squared error (MSE) computed only on masked patches:
```python
loss = ((pred - target) ** 2 * mask).sum() / mask.sum()
```
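Written out as a runnable sketch (NumPy here for simplicity; `pred` and `target` are per-patch pixel values, `mask` is 1 for masked patches, and the per-patch pixel mean is taken before averaging over masked patches):

```python
import numpy as np

def masked_mse(pred, target, mask):
    """MSE averaged over masked patches only.

    pred, target: (num_patches, patch_dim); mask: (num_patches,), 1 = masked.
    """
    per_patch = ((pred - target) ** 2).mean(axis=-1)   # mean over pixels in each patch
    return float((per_patch * mask).sum() / mask.sum())
```

Restricting the loss to masked patches is what makes the task non-trivial: the model gets no credit for trivially copying the 25% of patches it can already see.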

### Trainer (`trainer/utils.py`)

- **MAETrainer**: Complete training pipeline
  - Mixed precision training (AMP)
  - Gradient clipping and accumulation
  - Learning rate scheduling (warmup → cosine)
  - Automatic checkpointing
  - Multi-file logging (train/val/test)
  - Live metric monitoring with tqdm
  - Periodic metric visualization

## CheXpert Pathologies

The dataset includes 14 chest X-ray findings:

1. No Finding
2. Enlarged Cardiomediastinum
3. Cardiomegaly
4. Lung Opacity
5. Lung Lesion
6. Edema
7. Consolidation
8. Pneumonia
9. Atelectasis
10. Pneumothorax
11. Pleural Effusion
12. Pleural Other
13. Fracture
14. Support Devices

## Training Tips

1. **Learning Rate**: Start with 1e-4 and use warmup for stability
2. **Batch Size**: Maximize based on GPU memory (96 works well on 40GB GPUs)
3. **Gradient Accumulation**: Use it if batch size is limited by memory
4. **Mixed Precision**: Enabled by default for faster training
5. **Masking Ratio**: 75% is standard; higher ratios increase task difficulty
6. **Resume Training**: The trainer automatically resumes from the last checkpoint

## Use Cases

### Pre-training for Downstream Tasks

Use the trained encoder as a feature extractor:
```python
import torch

from models.mae import MaskedAutoEncoder

# Load the pre-trained model
mae = MaskedAutoEncoder()
mae.load_state_dict(torch.load("best_mae.pth")["model"])

# Use the encoder for feature extraction
encoder = mae.encoder
features, _, _, _ = encoder(images)
```

### Fine-tuning on Classification

Add a classification head to the encoder for supervised tasks.

### Anomaly Detection

Reconstruction error can indicate abnormalities in medical images.
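One simple way to operationalize this is to score each image by its reconstruction error and flag outliers. A minimal sketch (`mae_reconstruct` in the commented usage is a hypothetical helper wrapping the model's forward pass, not part of this repository):

```python
import numpy as np

def anomaly_score(image, reconstruction):
    """Per-image reconstruction error; higher means more unusual for the model."""
    return float(((image - reconstruction) ** 2).mean())

# Hypothetical usage with a trained model:
#   scores = [anomaly_score(x, mae_reconstruct(x)) for x in images]
#   threshold = np.percentile(scores, 95)  # flag the top 5% as candidates
```

Since the model is trained mostly on typical anatomy, images it reconstructs poorly are candidates for abnormality; thresholds should be validated per dataset.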

## Performance Optimization

This implementation includes several optimizations:

- **Efficient ZIP Reading**: Avoids extracting files to disk
- **LRU Cache**: Keeps frequently accessed images in memory
- **Persistent Workers**: Reduce data-loading overhead
- **Mixed Precision**: Roughly 2× faster training with minimal quality loss
- **Gradient Checkpointing**: Reduces memory usage (if enabled)
- **CUDA Memory Management**: Proper cache clearing and synchronization

## Contributing

Contributions are welcome! Please feel free to submit a pull request.

## License

This project is licensed under the MIT License; see the LICENSE file for details.

## References

1. **Masked Autoencoders Are Scalable Vision Learners**
   He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022)
   [arXiv:2111.06377](https://arxiv.org/abs/2111.06377)

2. **CheXpert: A Large Chest Radiograph Dataset**
   Irvin, J., et al. (2019)
   [Stanford ML Group](https://stanfordmlgroup.github.io/competitions/chexpert/)

## Acknowledgments

- Original MAE paper by Meta AI Research
- CheXpert dataset by Stanford ML Group
- PyTorch and Albumentations communities

## Contact

For questions or issues, please open an issue on GitHub or contact the maintainer.

---

**Note**: This is a research/educational implementation. For clinical applications, please ensure proper validation and regulatory compliance.