# Masked Autoencoder (MAE) for Medical Imaging
A PyTorch implementation of Masked Autoencoder (MAE) for self-supervised learning on chest X-ray images, specifically designed for the CheXpert dataset.
## Overview
This project implements a Vision Transformer-based Masked Autoencoder that learns representations from chest X-ray images through self-supervised reconstruction. The model randomly masks 75% of image patches and learns to reconstruct the original image, enabling it to learn powerful visual representations without requiring labeled data.
### Key Features
- Vision Transformer Architecture: Encoder-decoder transformer architecture with positional encodings
- Self-Supervised Learning: Pre-training through masked image reconstruction
- Optimized for Medical Imaging: Designed specifically for chest X-ray analysis
- Production-Ready Training Pipeline:
  - Mixed precision training (FP16) with gradient scaling
  - Gradient accumulation support
  - Learning rate warmup and cosine annealing
  - Automatic checkpointing and resumption
- Efficient Data Loading:
  - Optimized ZIP file reader with LRU caching
  - Class-balanced sampling with a weighted random sampler
  - Multi-worker data loading with persistent workers
- Comprehensive Logging: Training/validation metrics tracking and visualization
## Architecture

### Masked Autoencoder Structure
```
Input Image (384×384)
                  ↓
Patchify (16×16 patches → 576 patches)
                  ↓
Random Masking (75% masked, 25% visible)
                  ↓
┌──────────────────────────────────────┐
│             MAE ENCODER              │
│  - Linear patch embedding            │
│  - Positional encoding (visible)     │
│  - 12 Transformer blocks             │
│  - 8 attention heads, 768 hidden     │
└──────────────────────────────────────┘
                  ↓
┌──────────────────────────────────────┐
│             MAE DECODER              │
│  - Learnable mask tokens             │
│  - Positional encoding (all)         │
│  - 8 Transformer blocks              │
│  - 8 attention heads, 512 hidden     │
│  - Pixel reconstruction head         │
└──────────────────────────────────────┘
                  ↓
Reconstructed Image
                  ↓
MSE Loss (on masked patches only)
```
### Model Configuration
| Parameter | Default Value | Description |
|---|---|---|
| Image Size | 384×384 | Input image resolution |
| Patch Size | 16×16 | Size of each patch |
| Mask Ratio | 0.75 | Fraction of patches to mask |
| Encoder Depth | 12 layers | Number of transformer blocks |
| Encoder Dim | 768 | Hidden dimension |
| Encoder Heads | 8 | Number of attention heads |
| Decoder Depth | 8 layers | Number of transformer blocks |
| Decoder Dim | 512 | Hidden dimension |
| Decoder Heads | 8 | Number of attention heads |
| MLP Ratio | 4× | MLP expansion ratio (3072) |
| Dropout | 0.25 | Dropout rate |
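The patch arithmetic in the table (a 384×384 image cut into 16×16 patches gives a 24×24 grid, hence 576 patches) can be sketched with plain reshapes. This is an illustrative NumPy version, not the project's `models/mae.py` implementation:

```python
import numpy as np

def patchify(img, patch=16):
    """Split a (C, H, W) image into (num_patches, patch*patch*C) flat patches."""
    c, h, w = img.shape
    gh, gw = h // patch, w // patch                  # patch grid: 24 x 24 for 384 / 16
    x = img.reshape(c, gh, patch, gw, patch)
    x = x.transpose(1, 3, 2, 4, 0)                   # (gh, gw, patch, patch, C)
    return x.reshape(gh * gw, patch * patch * c)

def unpatchify(patches, patch, shape):
    """Inverse of patchify: rebuild the (C, H, W) image from flat patches."""
    c, h, w = shape
    gh, gw = h // patch, w // patch
    x = patches.reshape(gh, gw, patch, patch, c)
    x = x.transpose(4, 0, 2, 1, 3)                   # back to (C, gh, patch, gw, patch)
    return x.reshape(c, h, w)
```

The round trip is exact, which is what makes pixel-space MSE reconstruction well defined.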
## Getting Started

### Prerequisites
- Python >= 3.8
- CUDA-capable GPU (recommended)
- 16GB+ RAM

### Installation
1. Clone the repository:
```shell
git clone https://github.com/adelelsayed/mae.git
cd mae
```
2. Install the dependencies:
```shell
pip install -r requirements.txt
```
### Dataset Preparation
This project is configured for the CheXpert dataset. To use it:
1. Download `CheXpert-v1.0-small` from the Stanford ML Group
2. Update the paths in `configs/configs.py`:
   - `root`: Base directory for your data
   - `zip_path`: Path to the zipped dataset (optional, for faster loading)
   - `csv`: Path to the training CSV
   - `train_csv`, `val_csv`, `test_csv`: Split CSV files
## Usage

### Training
Start training from scratch:
```shell
python trainer/trainer.py
```
The trainer will:
- Automatically create checkpoint and log directories
- Resume from the last checkpoint if one is available
- Log training/validation metrics to text files
- Save metric plots every 10 epochs
- Save the best model based on validation loss
### Training Configuration
Edit `configs/configs.py` to customize training:
```python
mae_config = {
    # Training hyperparameters
    "lr": 1e-4,            # Learning rate
    "warmup": 5,           # Warmup epochs
    "weight_decay": 5e-4,  # AdamW weight decay
    "num_epochs": 200,     # Total training epochs
    "batch_size": 96,      # Batch size
    "accumulation": 1,     # Gradient accumulation steps

    # Model architecture
    "mask_ratio": 0.75,    # Masking ratio
    "encoder_depth": 12,   # Encoder layers
    "decoder_depth": 8,    # Decoder layers

    # Paths
    "checkpoints": "/path/to/checkpoints",
    "logdir": "/path/to/logs",
    ...
}
```
### Monitoring Training
Training logs are saved to three files:
- `training_log.txt`: Training metrics per epoch
- `val_log.txt`: Validation metrics per epoch
- `test_log.txt`: Test set evaluation results

Metric plots are saved every 10 epochs to `{logdir}/{epoch}/metrics.png`.
### Evaluation
The trainer includes a test method. To evaluate:
```python
from trainer.utils import MAETrainer
from configs.configs import mae_config

trainer = MAETrainer(mae_config)
trainer.test()
```
## Project Structure
```
mae/
├── configs/
│   ├── __init__.py
│   └── configs.py            # Training configuration
├── data/
│   ├── __init__.py
│   ├── dataset.py            # CheXpert dataset loader
│   └── splitter.py           # Dataset splitting utilities
├── loss/
│   ├── __init__.py
│   └── mae_loss.py           # MAE reconstruction loss
├── models/
│   ├── __init__.py
│   └── mae.py                # MAE architecture
├── trainer/
│   ├── __init__.py
│   ├── trainer.py            # Main training script
│   └── utils.py              # Training utilities
├── notebooks/
│   └── chexpert_mae.ipynb    # Jupyter notebook for experiments
├── training logs/            # Logged metrics and plots
├── weights/                  # Model checkpoints
├── results/                  # Evaluation results
├── requirements.txt          # Python dependencies
├── LICENSE                   # Project license
└── README.md                 # This file
```
## Components

### Dataset (`data/dataset.py`)
- OptimizedZipReader: Fast ZIP file reading with LRU caching
- CheXpertDataset: PyTorch dataset for CheXpert chest X-rays
- 14 pathology labels: No Finding, Cardiomegaly, Edema, Consolidation, etc.
- Albumentations-based augmentation pipeline
- Class-balanced sampling support
- Frontal/lateral view filtering
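The idea behind `OptimizedZipReader` (one persistent archive handle plus an LRU cache of recently read entries, so images are served without extracting the dataset to disk) can be approximated with the standard library alone. The class below is an illustrative sketch, not the project's actual API; the name and cache size are assumptions:

```python
import zipfile
from functools import lru_cache

class ZipImageReader:
    """Read entries from a ZIP archive without extracting, caching recent reads."""
    def __init__(self, zip_path, cache_size=1024):
        self._zf = zipfile.ZipFile(zip_path, "r")                  # one handle, reused
        self._read = lru_cache(maxsize=cache_size)(self._zf.read)  # LRU over raw bytes

    def read(self, name):
        return self._read(name)  # bytes, e.g. JPEG-encoded; decode downstream
```

One caveat for multi-worker DataLoaders: a `ZipFile` handle is not safe to share across processes, so each worker should open its own reader (for instance lazily on first `__getitem__`).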
### Model (`models/mae.py`)
- Patchify/Unpatchify: Image-to-patch conversion utilities
- Random Masking: Stochastic patch masking with restore indices
- PositionalEncoding: Learnable position embeddings
- TransformerBlock: Multi-head self-attention + MLP
- MAEEncoder: Processes visible patches only
- MAEDecoder: Reconstructs full image with mask tokens
- MaskedAutoEncoder: Complete MAE model
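The index-shuffle trick behind MAE-style random masking (draw a random permutation, keep the first 25% of patches, and remember the inverse permutation so the decoder can un-shuffle) can be sketched in NumPy. The function below is illustrative, not the project's implementation:

```python
import numpy as np

def random_masking(patches, mask_ratio=0.75, rng=None):
    """Keep a random subset of patches; return kept patches, mask, restore indices."""
    if rng is None:
        rng = np.random.default_rng()
    n = patches.shape[0]
    n_keep = int(n * (1 - mask_ratio))
    ids_shuffle = rng.permutation(n)        # random order of patch indices
    ids_restore = np.argsort(ids_shuffle)   # inverse permutation, used by the decoder
    kept = patches[ids_shuffle[:n_keep]]    # visible patches only, in shuffled order
    mask = np.ones(n)                       # 1 = masked, 0 = visible
    mask[ids_shuffle[:n_keep]] = 0.0
    return kept, mask, ids_restore
```

For the defaults above (576 patches, 75% masked), the encoder therefore only ever sees 144 patches, which is where most of MAE's pre-training speedup comes from.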
### Loss (`loss/mae_loss.py`)
Mean squared error (MSE) computed only on masked patches:
```python
loss = ((pred - target) ** 2 * mask).sum() / mask.sum()
```
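A quick numeric check of that formula with illustrative values (the per-patch mask broadcasts over the pixel dimension, so the visible patch contributes nothing):

```python
import numpy as np

pred   = np.array([[1.0, 1.0], [0.0, 0.0], [2.0, 2.0]])  # 3 patches, 2 pixels each
target = np.zeros_like(pred)
mask   = np.array([[1.0], [1.0], [0.0]])                 # patch 2 is visible -> excluded

loss = ((pred - target) ** 2 * mask).sum() / mask.sum()  # (2 + 0) / 2 = 1.0
```

Note that the visible patch's large squared error (8.0) would dominate the loss if it were not masked out; normalizing by `mask.sum()` keeps the loss an average over masked patches only.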
### Trainer (`trainer/utils.py`)
- MAETrainer: Complete training pipeline
- Mixed precision training (AMP)
- Gradient clipping and accumulation
- Learning rate scheduling (warmup → cosine)
- Automatic checkpointing
- Multi-file logging (train/val/test)
- Live metric monitoring with tqdm
- Periodic metric visualization
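The warmup → cosine schedule has a simple closed form. The defaults below match the config above (lr 1e-4, 5 warmup epochs, 200 total epochs); the exact scheduler in `trainer/utils.py` may differ in details such as the minimum learning rate:

```python
import math

def lr_at(epoch, base_lr=1e-4, warmup=5, total=200, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay toward min_lr."""
    if epoch < warmup:
        return base_lr * (epoch + 1) / warmup            # linear ramp over warmup epochs
    t = (epoch - warmup) / max(1, total - warmup)        # cosine progress in [0, 1]
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))
```

The ramp avoids large unstable updates while the randomly initialized decoder produces noisy gradients; the cosine tail lets the reconstruction loss settle smoothly.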
## CheXpert Pathologies
The dataset includes 14 chest X-ray findings:
- No Finding
- Enlarged Cardiomediastinum
- Cardiomegaly
- Lung Opacity
- Lung Lesion
- Edema
- Consolidation
- Pneumonia
- Atelectasis
- Pneumothorax
- Pleural Effusion
- Pleural Other
- Fracture
- Support Devices
## Training Tips
- Learning Rate: Start with 1e-4, use warmup for stability
- Batch Size: Maximize based on GPU memory (96 works well on 40GB GPUs)
- Gradient Accumulation: Use if batch size is limited by memory
- Mixed Precision: Enabled by default for faster training
- Masking Ratio: 75% is standard, higher ratios increase difficulty
- Resume Training: Model automatically resumes from last checkpoint
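The mixed precision, accumulation, and clipping tips above combine into one training-step pattern. The loop below is a generic sketch of that standard recipe, not the project's `MAETrainer` code; the MSE loss and clip norm of 1.0 are illustrative choices:

```python
import torch
from torch import nn

def train_epoch(model, batches, opt, accumulation=1, use_amp=torch.cuda.is_available()):
    """One pass over `batches` with AMP, gradient accumulation, and clipping."""
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    opt.zero_grad()
    for i, (x, y) in enumerate(batches):
        with torch.cuda.amp.autocast(enabled=use_amp):
            loss = nn.functional.mse_loss(model(x), y) / accumulation  # average over micro-batches
        scaler.scale(loss).backward()        # gradients accumulate across micro-batches
        if (i + 1) % accumulation == 0:
            scaler.unscale_(opt)             # unscale before clipping real gradient norms
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad()
```

Dividing the loss by `accumulation` makes N accumulated micro-batches behave like one batch N times larger, so the effective batch size grows without extra GPU memory.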
## Use Cases

### Pre-training for Downstream Tasks
Use the trained encoder as a feature extractor:
```python
import torch

from models.mae import MaskedAutoEncoder

# Load the pre-trained model
mae = MaskedAutoEncoder()
mae.load_state_dict(torch.load("best_mae.pth")["model"])

# Use the encoder for feature extraction
encoder = mae.encoder
features, _, _, _ = encoder(images)
```
### Fine-tuning on Classification
Add a classification head to the encoder for supervised tasks.
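One hedged way to sketch that wrapper, assuming the encoder produces per-patch features of dimension 768 (the encoder dim in the configuration table). The mean-pooling and head shape are assumptions, and the real encoder's call signature (which also returns mask and restore indices) may need adapting:

```python
import torch
from torch import nn

class XrayClassifier(nn.Module):
    """Pre-trained MAE encoder plus a mean-pooled linear head (illustrative sketch)."""
    def __init__(self, encoder, embed_dim=768, num_classes=14):
        super().__init__()
        self.encoder = encoder              # frozen or fine-tuned, your choice
        self.head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes),
        )

    def forward(self, x):
        feats = self.encoder(x)             # assumed shape (B, num_patches, embed_dim)
        return self.head(feats.mean(dim=1))  # mean-pool over patches, then classify
```

For the 14 CheXpert findings this is a multi-label problem, so `BCEWithLogitsLoss` over the 14 logits is the natural training objective.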
### Anomaly Detection
Reconstruction error can indicate abnormalities in medical images.
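A minimal sketch of that idea, assuming per-image reconstruction MSEs have already been computed: flag scans whose error exceeds a high percentile of errors measured on normal scans. The 95th-percentile threshold is illustrative, not a validated clinical cutoff:

```python
import numpy as np

def flag_anomalies(errors, normal_errors, pct=95.0):
    """Flag images whose reconstruction MSE exceeds the pct-th percentile of normal scans."""
    threshold = np.percentile(normal_errors, pct)  # calibrated on known-normal images
    return errors > threshold                      # boolean mask over `errors`
```

The intuition: the MAE learns to reconstruct typical chest anatomy, so unusual structures it never learned to model tend to reconstruct poorly and land above the threshold.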
## Performance Optimization
This implementation includes several optimizations:
- Efficient ZIP Reading: Avoids extracting files to disk
- LRU Cache: Keeps frequently accessed images in memory
- Persistent Workers: Reduces data loading overhead
- Mixed Precision: 2× faster training with minimal quality loss
- Gradient Checkpointing: Reduces memory usage (if enabled)
- CUDA Memory Management: Proper cache clearing and synchronization
## Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
## License
This project is licensed under the terms specified in the LICENSE file.
## References
1. He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022). Masked Autoencoders Are Scalable Vision Learners. arXiv:2111.06377
2. Irvin, J., et al. (2019). CheXpert: A Large Chest Radiograph Dataset. Stanford ML Group
## Acknowledgments
- Original MAE paper by Meta AI Research
- CheXpert dataset by Stanford ML Group
- PyTorch and Albumentations communities
## Contact
For questions or issues, please open an issue on GitHub or contact the maintainer.
Note: This is a research/educational implementation. For clinical applications, please ensure proper validation and regulatory compliance.