| # Masked Autoencoder (MAE) for Medical Imaging | |
| A PyTorch implementation of Masked Autoencoder (MAE) for self-supervised learning on chest X-ray images, specifically designed for the CheXpert dataset. | |
| ## Overview | |
| This project implements a Vision Transformer-based Masked Autoencoder that learns representations from chest X-ray images through self-supervised reconstruction. The model randomly masks 75% of image patches and learns to reconstruct the original image, enabling it to learn powerful visual representations without requiring labeled data. | |
| ### Key Features | |
| - **Vision Transformer Architecture**: Encoder-decoder transformer architecture with positional encodings | |
| - **Self-Supervised Learning**: Pre-training through masked image reconstruction | |
| - **Optimized for Medical Imaging**: Designed specifically for chest X-ray analysis | |
| - **Production-Ready Training Pipeline**: | |
|   - Mixed precision training (FP16) with gradient scaling | |
|   - Gradient accumulation support | |
|   - Learning rate warmup and cosine annealing | |
|   - Automatic checkpointing and resumption | |
| - **Efficient Data Loading**: | |
|   - Optimized ZIP file reader with LRU caching | |
|   - Class-balanced sampling with weighted random sampler | |
|   - Multi-worker data loading with persistent workers | |
| - **Comprehensive Logging**: Training/validation metrics tracking and visualization | |
| ## Architecture | |
| ### Masked Autoencoder Structure | |
| ``` | |
| Input Image (384×384) | |
|          ↓ | |
| Patchify (16×16 patches → 576 patches) | |
|          ↓ | |
| Random Masking (75% masked, 25% visible) | |
|          ↓ | |
| ┌─────────────────────────────────────┐ | |
| │  MAE ENCODER                        │ | |
| │  - Linear patch embedding           │ | |
| │  - Positional encoding (visible)    │ | |
| │  - 12 Transformer blocks            │ | |
| │  - 8 attention heads, 768 hidden    │ | |
| └─────────────────────────────────────┘ | |
|          ↓ | |
| ┌─────────────────────────────────────┐ | |
| │  MAE DECODER                        │ | |
| │  - Learnable mask tokens            │ | |
| │  - Positional encoding (all)        │ | |
| │  - 8 Transformer blocks             │ | |
| │  - 8 attention heads, 512 hidden    │ | |
| │  - Pixel reconstruction head        │ | |
| └─────────────────────────────────────┘ | |
|          ↓ | |
| Reconstructed Image | |
|          ↓ | |
| MSE Loss (on masked patches only) | |
| ``` | |
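The random masking step in the diagram can be sketched as follows. This is a minimal illustration of the argsort-based per-sample shuffling described in the MAE paper, not necessarily the code in `models/mae.py`; the function name `random_masking` and its return order are assumptions.

```python
import torch


def random_masking(patches: torch.Tensor, mask_ratio: float = 0.75):
    """Keep a random subset of patches for each sample.

    patches: tensor of shape (batch, num_patches, dim).
    Returns the visible patches, a binary mask in the original patch
    order (1 = masked, 0 = visible), and the indices that undo the shuffle.
    """
    batch, num_patches, dim = patches.shape
    num_keep = int(num_patches * (1 - mask_ratio))

    # Sorting random noise yields an independent permutation per sample.
    noise = torch.rand(batch, num_patches, device=patches.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation

    # The first `num_keep` shuffled positions are the visible patches.
    ids_keep = ids_shuffle[:, :num_keep]
    visible = torch.gather(patches, 1, ids_keep.unsqueeze(-1).expand(-1, -1, dim))

    # Build the mask in shuffled order, then map it back to the original order.
    mask = torch.ones(batch, num_patches, device=patches.device)
    mask[:, :num_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)
    return visible, mask, ids_restore


# 576 patches per image, as in the diagram; the embedding size is illustrative.
patches = torch.randn(2, 576, 768)
visible, mask, ids_restore = random_masking(patches)
print(visible.shape, int(mask[0].sum()))
```

The decoder can use `ids_restore` to place mask tokens and encoded patches back in their original positions before reconstruction.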
| ### Model Configuration | |
| | Parameter | Default Value | Description | | |
| |-----------|---------------|-------------| | |
| | Image Size | 384×384 | Input image resolution | | |
| | Patch Size | 16×16 | Size of each patch | | |
| | Mask Ratio | 0.75 | Fraction of patches to mask | | |
| | Encoder Depth | 12 layers | Number of transformer blocks | | |
| | Encoder Dim | 768 | Hidden dimension | | |
| | Encoder Heads | 8 | Number of attention heads | | |
| | Decoder Depth | 8 layers | Number of transformer blocks | | |
| | Decoder Dim | 512 | Hidden dimension | | |
| | Decoder Heads | 8 | Number of attention heads | | |
| | MLP Ratio | 4× | MLP expansion ratio (768 → 3072 in the encoder) | | |
| | Dropout | 0.25 | Dropout rate | | |
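The sizes implied by this table can be checked with a few lines of arithmetic. The channel count is an assumption (single-channel X-rays); the README does not state how many channels the model receives.

```python
# Values taken from the configuration table above.
image_size = 384
patch_size = 16
mask_ratio = 0.75
encoder_dim = 768
decoder_dim = 512
mlp_ratio = 4

# Assumption: grayscale input. Use 3 if images are loaded as RGB.
in_channels = 1

patches_per_side = image_size // patch_size
num_patches = patches_per_side ** 2
num_visible = int(num_patches * (1 - mask_ratio))
num_masked = num_patches - num_visible
values_per_patch = patch_size * patch_size * in_channels

print(f"patches per side:  {patches_per_side}")
print(f"total patches:     {num_patches}")
print(f"visible / masked:  {num_visible} / {num_masked}")
print(f"values per patch:  {values_per_patch}")
print(f"encoder MLP width: {encoder_dim * mlp_ratio}")
print(f"decoder MLP width: {decoder_dim * mlp_ratio}")
```

With these defaults the encoder sees only 144 of the 576 patches per image, which is where most of the compute saving of MAE pre-training comes from.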
| ## Getting Started | |
| ### Prerequisites | |
| - Python >= 3.8 | |
| - CUDA-capable GPU (recommended) | |
| - 16GB+ RAM | |
| ### Installation | |
| 1. Clone the repository: | |
| ```bash | |
| git clone https://github.com/adelelsayed/mae.git | |
| cd mae | |
| ``` | |
| 2. Install dependencies: | |
| ```bash | |
| pip install -r requirements.txt | |
| ``` | |
| ### Dataset Preparation | |
| This project is configured for the **CheXpert dataset**. To use it: | |
| 1. Download CheXpert-v1.0-small from [Stanford ML Group](https://stanfordmlgroup.github.io/competitions/chexpert/) | |
| 2. Update paths in `configs/configs.py`: | |
|    - `root`: Base directory for your data | |
|    - `zip_path`: Path to zipped dataset (optional, for faster loading) | |
|    - `csv`: Path to training CSV | |
|    - `train_csv`, `val_csv`, `test_csv`: Split CSV files | |
| ## Usage | |
| ### Training | |
| Start training from scratch: | |
| ```bash | |
| python trainer/trainer.py | |
| ``` | |
| The trainer will: | |
| - Automatically create checkpoint and log directories | |
| - Resume from the last checkpoint if available | |
| - Log training/validation metrics to text files | |
| - Save plots every 10 epochs | |
| - Save best model based on validation loss | |
| ### Training Configuration | |
| Edit `configs/configs.py` to customize training: | |
| ```python | |
| mae_config = { | |
|     # Training hyperparameters | |
|     "lr": 1e-4,             # Learning rate | |
|     "warmup": 5,            # Warmup epochs | |
|     "weight_decay": 5e-4,   # AdamW weight decay | |
|     "num_epochs": 200,      # Total training epochs | |
|     "batch_size": 96,       # Batch size | |
|     "accumulation": 1,      # Gradient accumulation steps | |
|     # Model architecture | |
|     "mask_ratio": 0.75,     # Masking ratio | |
|     "encoder_depth": 12,    # Encoder layers | |
|     "decoder_depth": 8,     # Decoder layers | |
|     # Paths | |
|     "checkpoints": "/path/to/checkpoints", | |
|     "logdir": "/path/to/logs", | |
|     # ... (remaining keys omitted) | |
| } | |
| ``` | |
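To see how `lr`, `warmup`, and `num_epochs` interact, here is a sketch of a per-epoch schedule with linear warmup followed by cosine annealing. The exact shape used by the trainer is not shown in this README, so the linear warmup, the per-epoch granularity, and the `min_lr` floor are assumptions; `lr_at_epoch` is a hypothetical helper.

```python
import math


def lr_at_epoch(epoch, base_lr=1e-4, warmup=5, num_epochs=200, min_lr=0.0):
    """Learning rate for a zero-based epoch index.

    Assumed shape: linear ramp over the first `warmup` epochs, then a
    half-cosine decay from `base_lr` towards `min_lr`.
    """
    if epoch < warmup:
        return base_lr * (epoch + 1) / warmup
    progress = (epoch - warmup) / max(1, num_epochs - warmup)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


for epoch in (0, 4, 5, 100, 199):
    print(f"epoch {epoch:3d}: lr = {lr_at_epoch(epoch):.2e}")
```

When gradient accumulation is used, the effective batch size is `batch_size * accumulation`, so the learning rate may need retuning if either value changes.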
| ### Monitoring Training | |
| Training logs are saved in three files: | |
| - `training_log.txt`: Training metrics per epoch | |
| - `val_log.txt`: Validation metrics per epoch | |
| - `test_log.txt`: Test set evaluation results | |
| Metric plots are saved every 10 epochs to `{logdir}/{epoch}/metrics.png`. | |
| ### Evaluation | |
| The project includes a test method in the trainer. To evaluate: | |
| ```python | |
| from trainer.utils import MAETrainer | |
| from configs.configs import mae_config | |
| trainer = MAETrainer(mae_config) | |
| trainer.test() | |
| ``` | |
| ## Project Structure | |
| ``` | |
| mae/ | |
| ├── configs/ | |
| │   ├── __init__.py | |
| │   └── configs.py          # Training configuration | |
| ├── data/ | |
| │   ├── __init__.py | |
| │   ├── dataset.py          # CheXpert dataset loader | |
| │   └── splitter.py         # Dataset splitting utilities | |
| ├── loss/ | |
| │   ├── __init__.py | |
| │   └── mae_loss.py         # MAE reconstruction loss | |
| ├── models/ | |
| │   ├── __init__.py | |
| │   └── mae.py              # MAE architecture | |
| ├── trainer/ | |
| │   ├── __init__.py | |
| │   ├── trainer.py          # Main training script | |
| │   └── utils.py            # Training utilities | |
| ├── notebooks/ | |
| │   └── chexpert_mae.ipynb  # Jupyter notebook for experiments | |
| ├── training logs/          # Logged metrics and plots | |
| ├── weights/                # Model checkpoints | |
| ├── results/                # Evaluation results | |
| ├── requirements.txt        # Python dependencies | |
| ├── LICENSE                 # Project license | |
| └── README.md               # This file | |
| ``` | |
| ## Components | |
| ### Dataset (`data/dataset.py`) | |
| - **OptimizedZipReader**: Fast ZIP file reading with LRU caching | |
| - **CheXpertDataset**: PyTorch dataset for CheXpert chest X-rays | |
|   - 14 pathology labels: No Finding, Cardiomegaly, Edema, Consolidation, etc. | |
|   - Albumentations-based augmentation pipeline | |
|   - Class-balanced sampling support | |
|   - Frontal/lateral view filtering | |
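The idea behind reading images straight from a ZIP archive with an LRU cache can be sketched with the standard library alone. `CachedZipReader` below is a hypothetical stand-in, not the repository's `OptimizedZipReader`; the lazy open and the cache size are assumptions made for illustration.

```python
import io
import zipfile
from functools import lru_cache


class CachedZipReader:
    """Hypothetical sketch: read archive members as bytes, caching recent ones."""

    def __init__(self, zip_path, cache_size=1024):
        self.zip_path = zip_path
        # Opened on first read rather than in __init__, so a reader created
        # before worker processes start does not share one open file handle.
        self._zip = None
        self._read_cached = lru_cache(maxsize=cache_size)(self._read)

    def _read(self, name):
        if self._zip is None:
            self._zip = zipfile.ZipFile(self.zip_path, "r")
        return self._zip.read(name)

    def read(self, name):
        """Return the raw bytes of `name`, served from the cache when possible."""
        return self._read_cached(name)

    def cache_info(self):
        return self._read_cached.cache_info()


# Demo with an in-memory archive; the member name and contents are placeholders.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    archive.writestr("patient00001/view1_frontal.jpg", b"fake image bytes")

reader = CachedZipReader(buffer, cache_size=2)
first = reader.read("patient00001/view1_frontal.jpg")
second = reader.read("patient00001/view1_frontal.jpg")
print(reader.cache_info())
```

The second read is served from memory, which matters when the weighted sampler draws the same minority-class images repeatedly within an epoch.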
| ### Model (`models/mae.py`) | |
| - **Patchify/Unpatchify**: Image-to-patch conversion utilities | |
| - **Random Masking**: Stochastic patch masking with restore indices | |
| - **PositionalEncoding**: Learnable position embeddings | |
| - **TransformerBlock**: Multi-head self-attention + MLP | |
| - **MAEEncoder**: Processes visible patches only | |
| - **MAEDecoder**: Reconstructs full image with mask tokens | |
| - **MaskedAutoEncoder**: Complete MAE model | |
| ### Loss (`loss/mae_loss.py`) | |
| Mean Squared Error (MSE) computed only on masked patches: | |
| ```python | |
| loss = ((pred - target) ** 2 * mask).sum() / mask.sum() | |
| ``` | |
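A slightly expanded sketch of the same loss makes the tensor shapes explicit. It assumes `pred` and `target` have shape `(batch, num_patches, values_per_patch)` and `mask` has shape `(batch, num_patches)` with 1 marking masked patches; the function name `masked_mse` and the per-patch averaging step are assumptions that follow the MAE paper rather than this repository's code.

```python
import torch


def masked_mse(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean squared error averaged over masked patches only."""
    # Average the squared error inside each patch: (B, N, D) -> (B, N).
    per_patch = ((pred - target) ** 2).mean(dim=-1)
    # Zero out visible patches and normalise by the number of masked ones.
    return (per_patch * mask).sum() / mask.sum().clamp(min=1)


# Tiny example: one image, four patches, three values per patch.
target = torch.zeros(1, 4, 3)
pred = target.clone()
pred[0, 0] += 2.0  # error on a visible patch, ignored by the loss
pred[0, 1] += 1.0  # error on a masked patch, counted by the loss
mask = torch.tensor([[0.0, 1.0, 1.0, 0.0]])

loss = masked_mse(pred, target, mask)
print(loss.item())
```

Errors on visible patches do not contribute, so the model is never rewarded for simply copying its input.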
| ### Trainer (`trainer/utils.py`) | |
| - **MAETrainer**: Complete training pipeline | |
|   - Mixed precision training (AMP) | |
|   - Gradient clipping and accumulation | |
|   - Learning rate scheduling (warmup → cosine) | |
|   - Automatic checkpointing | |
|   - Multi-file logging (train/val/test) | |
|   - Live metric monitoring with tqdm | |
|   - Periodic metric visualization | |
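How mixed precision, gradient clipping, and gradient accumulation fit together in one loop can be sketched as below. This is not `MAETrainer` itself: the model is a small stand-in layer, the clipping norm of 1.0 is an assumed value, and AMP is only enabled when a CUDA device is present.

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # FP16 autocast and loss scaling only on GPU
amp_dtype = torch.float16 if use_amp else torch.bfloat16

model = nn.Linear(8, 8).to(device)  # stand-in for the MAE model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-4)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

accumulation = 2      # micro-batches per optimizer step
max_grad_norm = 1.0   # assumed clipping threshold
batches = [torch.randn(4, 8) for _ in range(4)]  # placeholder data

optimizer_steps = 0
optimizer.zero_grad(set_to_none=True)
for step, x in enumerate(batches, start=1):
    x = x.to(device)
    with torch.autocast(device_type=device, dtype=amp_dtype, enabled=use_amp):
        # Divide so the accumulated gradient matches one large batch.
        loss = nn.functional.mse_loss(model(x), x) / accumulation
    scaler.scale(loss).backward()

    if step % accumulation == 0:
        scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        optimizer_steps += 1

print(f"optimizer steps: {optimizer_steps}")
```

Unscaling before clipping matters: clipping scaled gradients would apply the threshold to values inflated by the loss scale.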
| ## CheXpert Pathologies | |
| The dataset includes 14 chest X-ray findings: | |
| 1. No Finding | |
| 2. Enlarged Cardiomediastinum | |
| 3. Cardiomegaly | |
| 4. Lung Opacity | |
| 5. Lung Lesion | |
| 6. Edema | |
| 7. Consolidation | |
| 8. Pneumonia | |
| 9. Atelectasis | |
| 10. Pneumothorax | |
| 11. Pleural Effusion | |
| 12. Pleural Other | |
| 13. Fracture | |
| 14. Support Devices | |
| ## Training Tips | |
| 1. **Learning Rate**: Start with 1e-4, use warmup for stability | |
| 2. **Batch Size**: Maximize based on GPU memory (96 works well on 40GB GPUs) | |
| 3. **Gradient Accumulation**: Use if batch size is limited by memory | |
| 4. **Mixed Precision**: Enabled by default for faster training | |
| 5. **Masking Ratio**: 75% is the standard choice; higher ratios make reconstruction harder | |
| 6. **Resume Training**: Training resumes automatically from the last checkpoint | |
| ## Use Cases | |
| ### Pre-training for Downstream Tasks | |
| Use the trained encoder as a feature extractor: | |
| ```python | |
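| import torch | |
| from models.mae import MaskedAutoEncoder | |
| # Load pre-trained weights (the checkpoint stores them under the "model" key) | |
| mae = MaskedAutoEncoder() | |
| checkpoint = torch.load("best_mae.pth", map_location="cpu") | |
| mae.load_state_dict(checkpoint["model"]) | |
| mae.eval()  # disable dropout for feature extraction | |
| # Use the encoder for feature extraction; `images` is a batch of preprocessed X-rays | |
| encoder = mae.encoder | |
| with torch.no_grad(): | |
|     features, _, _, _ = encoder(images) | |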
| ``` | |
| ### Fine-tuning on Classification | |
| Add a classification head to the encoder for supervised tasks. | |
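One way to do this is to mean-pool the encoder's patch features and attach a linear layer with one logit per CheXpert label. The sketch below assumes the encoder returns features of shape `(batch, num_visible_patches, 768)`; `PooledClassificationHead` is a hypothetical name, and the random tensor stands in for real encoder output.

```python
import torch
from torch import nn


class PooledClassificationHead(nn.Module):
    """Hypothetical multi-label head on top of encoder patch features."""

    def __init__(self, embed_dim: int = 768, num_classes: int = 14):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # features: (batch, num_patches, embed_dim) -> one vector per image.
        pooled = features.mean(dim=1)
        return self.fc(self.norm(pooled))


head = PooledClassificationHead()
features = torch.randn(2, 144, 768)  # stand-in for encoder output

logits = head(features)
# The 14 findings are not mutually exclusive, so use a per-label sigmoid loss.
labels = torch.randint(0, 2, (2, 14)).float()
loss = nn.BCEWithLogitsLoss()(logits, labels)
print(logits.shape, loss.item())
```

For fine-tuning it is common to turn masking off so that the encoder sees every patch; whether `MAEEncoder` exposes such an option is not stated here.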
| ### Anomaly Detection | |
| Reconstruction error can indicate abnormalities in medical images. | |
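A simple way to turn reconstruction error into a flag is to fit a threshold on errors measured for images considered normal, then compare new images against it. The sketch below uses the mean plus `k` standard deviations; the rule, the value of `k`, and the error values are all illustrative assumptions, not results from this project.

```python
import statistics


def fit_threshold(normal_errors, k=3.0):
    """Threshold = mean + k * std of errors on reference (normal) images."""
    mean = statistics.fmean(normal_errors)
    std = statistics.pstdev(normal_errors)
    return mean + k * std


def flag_anomalies(errors, threshold):
    """Return True for each image whose error exceeds the threshold."""
    return [error > threshold for error in errors]


# Made-up per-image reconstruction errors, for illustration only.
normal_errors = [0.10, 0.12, 0.11, 0.09, 0.10]
threshold = fit_threshold(normal_errors)

new_errors = [0.10, 0.30]
flags = flag_anomalies(new_errors, threshold)
print(f"threshold = {threshold:.4f}, flags = {flags}")
```

Because MAE masks patches at random, averaging the error over several random masks per image should give a more stable score than a single pass.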
| ## Performance Optimization | |
| This implementation includes several optimizations: | |
| - **Efficient ZIP Reading**: Avoids extracting files to disk | |
| - **LRU Cache**: Keeps frequently accessed images in memory | |
| - **Persistent Workers**: Reduces data loading overhead | |
| - **Mixed Precision**: Up to roughly 2× faster training with minimal quality loss | |
| - **Gradient Checkpointing**: Reduces memory usage (if enabled) | |
| - **CUDA Memory Management**: Proper cache clearing and synchronization | |
| ## Contributing | |
| Contributions are welcome! Please feel free to submit a Pull Request. | |
| ## License | |
| This project is licensed under the terms specified in the LICENSE file. | |
| ## References | |
| 1. **Masked Autoencoders Are Scalable Vision Learners** | |
| He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022) | |
| [arXiv:2111.06377](https://arxiv.org/abs/2111.06377) | |
| 2. **CheXpert: A Large Chest Radiograph Dataset** | |
| Irvin, J., et al. (2019) | |
| [Stanford ML Group](https://stanfordmlgroup.github.io/competitions/chexpert/) | |
| ## Acknowledgments | |
| - Original MAE paper by Meta AI Research | |
| - CheXpert dataset by Stanford ML Group | |
| - PyTorch and Albumentations communities | |
| ## Contact | |
| For questions or issues, please open an issue on GitHub or contact the maintainer. | |
| --- | |
| **Note**: This is a research/educational implementation. For clinical applications, please ensure proper validation and regulatory compliance. | |