Masked Autoencoder (MAE) for Medical Imaging

A PyTorch implementation of Masked Autoencoder (MAE) for self-supervised learning on chest X-ray images, specifically designed for the CheXpert dataset.

📋 Overview

This project implements a Vision Transformer-based Masked Autoencoder that learns representations from chest X-ray images through self-supervised reconstruction. The model randomly masks 75% of image patches and learns to reconstruct the original image, enabling it to learn powerful visual representations without requiring labeled data.

Key Features

  • Vision Transformer Architecture: Encoder-decoder transformer with learnable positional encodings
  • Self-Supervised Learning: Pre-training through masked image reconstruction
  • Optimized for Medical Imaging: Designed specifically for chest X-ray analysis
  • Production-Ready Training Pipeline:
    • Mixed precision training (FP16) with gradient scaling
    • Gradient accumulation support
    • Learning rate warmup and cosine annealing
    • Automatic checkpointing and resumption
  • Efficient Data Loading:
    • Optimized ZIP file reader with LRU caching
    • Class-balanced sampling with weighted random sampler
    • Multi-worker data loading with persistent workers
  • Comprehensive Logging: Training/validation metrics tracking and visualization
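The training-pipeline features listed above (mixed precision with gradient scaling, accumulation, clipping) can be sketched as a minimal loop. This is illustrative only: `train_one_epoch` and its arguments are assumptions, not the repo's actual `MAETrainer` API.

```python
import torch

def train_one_epoch(model, loader, optimizer, accumulation=1, device="cpu"):
    # GradScaler and autocast become no-ops on CPU (enabled only for CUDA)
    scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
    model.train()
    optimizer.zero_grad()
    for step, (images,) in enumerate(loader):
        with torch.autocast(device_type=device, enabled=(device == "cuda")):
            loss = model(images).mean()              # placeholder objective
        # Divide so accumulated gradients match one large-batch step
        scaler.scale(loss / accumulation).backward()
        if (step + 1) % accumulation == 0:
            scaler.unscale_(optimizer)               # clip in true gradient scale
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
    return loss.item()
```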

πŸ—οΈ Architecture

Masked Autoencoder Structure

Input Image (384×384)
    ↓
Patchify (16×16 patches → 576 patches)
    ↓
Random Masking (75% masked, 25% visible)
    ↓
┌─────────────────────────────────────┐
│           MAE ENCODER               │
│  - Linear patch embedding           │
│  - Positional encoding (visible)    │
│  - 12 Transformer blocks            │
│  - 8 attention heads, 768 hidden    │
└─────────────────────────────────────┘
    ↓
┌─────────────────────────────────────┐
│           MAE DECODER               │
│  - Learnable mask tokens            │
│  - Positional encoding (all)        │
│  - 8 Transformer blocks             │
│  - 8 attention heads, 512 hidden    │
│  - Pixel reconstruction head        │
└─────────────────────────────────────┘
    ↓
Reconstructed Image
    ↓
MSE Loss (on masked patches only)
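The patch counts in the diagram follow directly from the image and patch sizes; a quick pure-Python check:

```python
# Verify the patch arithmetic: 384/16 = 24 patches per side, 24^2 = 576 total,
# of which 75% are masked. No dependencies beyond the standard library.
def patch_stats(image_size=384, patch_size=16, mask_ratio=0.75):
    patches_per_side = image_size // patch_size     # 384 // 16 = 24
    num_patches = patches_per_side ** 2             # 24 * 24 = 576
    num_masked = int(num_patches * mask_ratio)      # 432 masked
    return num_patches, num_masked, num_patches - num_masked

print(patch_stats())  # (576, 432, 144)
```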

Model Configuration

Parameter       Default Value   Description
Image Size      384×384         Input image resolution
Patch Size      16×16           Size of each patch
Mask Ratio      0.75            Fraction of patches to mask
Encoder Depth   12 layers       Number of transformer blocks
Encoder Dim     768             Hidden dimension
Encoder Heads   8               Number of attention heads
Decoder Depth   8 layers        Number of transformer blocks
Decoder Dim     512             Hidden dimension
Decoder Heads   8               Number of attention heads
MLP Ratio       4×              MLP expansion ratio (768 → 3072)
Dropout         0.25            Dropout rate

🚀 Getting Started

Prerequisites

  • Python >= 3.8
  • CUDA-capable GPU (recommended)
  • 16GB+ RAM

Installation

  1. Clone the repository:
     git clone https://github.com/adelelsayed/mae.git
     cd mae
  2. Install dependencies:
     pip install -r requirements.txt

Dataset Preparation

This project is configured for the CheXpert dataset. To use it:

  1. Download CheXpert-v1.0-small from Stanford ML Group
  2. Update paths in configs/configs.py:
    • root: Base directory for your data
    • zip_path: Path to zipped dataset (optional, for faster loading)
    • csv: Path to training CSV
    • train_csv, val_csv, test_csv: Split CSV files

📊 Usage

Training

Start training from scratch:

python trainer/trainer.py

The trainer will:

  • Automatically create checkpoint and log directories
  • Resume from the last checkpoint if available
  • Log training/validation metrics to text files
  • Save plots every 10 epochs
  • Save best model based on validation loss

Training Configuration

Edit configs/configs.py to customize training:

mae_config = {
    # Training hyperparameters
    "lr": 1e-4,                    # Learning rate
    "warmup": 5,                   # Warmup epochs
    "weight_decay": 5e-4,          # AdamW weight decay
    "num_epochs": 200,             # Total training epochs
    "batch_size": 96,              # Batch size
    "accumulation": 1,             # Gradient accumulation steps
    
    # Model architecture
    "mask_ratio": 0.75,            # Masking ratio
    "encoder_depth": 12,           # Encoder layers
    "decoder_depth": 8,            # Decoder layers
    
    # Paths
    "checkpoints": "/path/to/checkpoints",
    "logdir": "/path/to/logs",
    ...
}

Monitoring Training

Training logs are saved in three files:

  • training_log.txt: Training metrics per epoch
  • val_log.txt: Validation metrics per epoch
  • test_log.txt: Test set evaluation results

Metrics plots are saved every 10 epochs in {logdir}/{epoch}/metrics.png

Evaluation

The project includes a test method in the trainer. To evaluate:

from trainer.utils import MAETrainer
from configs.configs import mae_config

trainer = MAETrainer(mae_config)
trainer.test()

πŸ“ Project Structure

mae/
β”œβ”€β”€ configs/
β”‚   β”œβ”€β”€ __init__.py
β”‚   └── configs.py              # Training configuration
β”œβ”€β”€ data/
β”‚   β”œβ”€β”€ __init__.py
β”‚   β”œβ”€β”€ dataset.py              # CheXpert dataset loader
β”‚   └── splitter.py             # Dataset splitting utilities
β”œβ”€β”€ loss/
β”‚   β”œβ”€β”€ __init__.py
β”‚   └── mae_loss.py             # MAE reconstruction loss
β”œβ”€β”€ models/
β”‚   β”œβ”€β”€ __init__.py
β”‚   └── mae.py                  # MAE architecture
β”œβ”€β”€ trainer/
β”‚   β”œβ”€β”€ __init__.py
β”‚   β”œβ”€β”€ trainer.py              # Main training script
β”‚   └── utils.py                # Training utilities
β”œβ”€β”€ notebooks/
β”‚   └── chexpert_mae.ipynb      # Jupyter notebook for experiments
β”œβ”€β”€ training logs/              # Logged metrics and plots
β”œβ”€β”€ weights/                    # Model checkpoints
β”œβ”€β”€ results/                    # Evaluation results
β”œβ”€β”€ requirements.txt            # Python dependencies
β”œβ”€β”€ LICENSE                     # Project license
└── README.md                   # This file

🔧 Components

Dataset (data/dataset.py)

  • OptimizedZipReader: Fast ZIP file reading with LRU caching
  • CheXpertDataset: PyTorch dataset for CheXpert chest X-rays
    • 14 pathology labels: No Finding, Cardiomegaly, Edema, Consolidation, etc.
    • Albumentations-based augmentation pipeline
    • Class-balanced sampling support
    • Frontal/lateral view filtering

Model (models/mae.py)

  • Patchify/Unpatchify: Image-to-patch conversion utilities
  • Random Masking: Stochastic patch masking with restore indices
  • PositionalEncoding: Learnable position embeddings
  • TransformerBlock: Multi-head self-attention + MLP
  • MAEEncoder: Processes visible patches only
  • MAEDecoder: Reconstructs full image with mask tokens
  • MaskedAutoEncoder: Complete MAE model

Loss (loss/mae_loss.py)

Mean Squared Error (MSE) computed only on masked patches:

loss = ((pred - target) ** 2 * mask).sum() / mask.sum()
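A self-contained version of the formula above, assuming `pred` and `target` hold per-patch pixel values and `mask` is 1 on masked patches; the repo's exact tensor layout may differ.

```python
import torch

def mae_loss(pred, target, mask):
    """MSE averaged over masked patches only.

    pred, target: (batch, num_patches, patch_pixels)
    mask:         (batch, num_patches), 1 = masked, 0 = visible
    """
    per_patch = ((pred - target) ** 2).mean(dim=-1)   # MSE within each patch
    return (per_patch * mask).sum() / mask.sum()      # average over masked only
```

Because visible patches are excluded, the model gets no credit for trivially copying its unmasked input.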

Trainer (trainer/utils.py)

  • MAETrainer: Complete training pipeline
    • Mixed precision training (AMP)
    • Gradient clipping and accumulation
    • Learning rate scheduling (warmup β†’ cosine)
    • Automatic checkpointing
    • Multi-file logging (train/val/test)
    • Live metric monitoring with tqdm
    • Periodic metric visualization

🎯 CheXpert Pathologies

The dataset includes 14 chest X-ray findings:

  1. No Finding
  2. Enlarged Cardiomediastinum
  3. Cardiomegaly
  4. Lung Opacity
  5. Lung Lesion
  6. Edema
  7. Consolidation
  8. Pneumonia
  9. Atelectasis
  10. Pneumothorax
  11. Pleural Effusion
  12. Pleural Other
  13. Fracture
  14. Support Devices

📈 Training Tips

  1. Learning Rate: Start with 1e-4, use warmup for stability
  2. Batch Size: Maximize based on GPU memory (96 works well on 40GB GPUs)
  3. Gradient Accumulation: Use if batch size is limited by memory
  4. Mixed Precision: Enabled by default for faster training
  5. Masking Ratio: 75% is standard, higher ratios increase difficulty
  6. Resume Training: Model automatically resumes from last checkpoint

🔬 Use Cases

Pre-training for Downstream Tasks

Use the trained encoder as a feature extractor:

import torch

from models.mae import MaskedAutoEncoder

# Load pre-trained model
mae = MaskedAutoEncoder()
mae.load_state_dict(torch.load("best_mae.pth")["model"])

# Use encoder for feature extraction (returns tokens plus masking metadata)
encoder = mae.encoder
features, _, _, _ = encoder(images)

Fine-tuning on Classification

Add a classification head to the encoder for supervised tasks.
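One way to attach such a head is shown below; the mean-pooling, embedding size, and class names are illustrative assumptions, and the encoder's exact output layout may differ.

```python
import torch
import torch.nn as nn

class MAEClassifier(nn.Module):
    """Pre-trained MAE encoder + linear head for the 14 CheXpert labels."""

    def __init__(self, encoder, embed_dim=768, num_classes=14):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, images):
        tokens, *_ = self.encoder(images)      # (B, num_visible, embed_dim)
        pooled = tokens.mean(dim=1)            # mean-pool patch tokens
        return self.head(pooled)               # multi-label logits
```

Since CheXpert labels are multi-label, the logits would typically be trained with `nn.BCEWithLogitsLoss`.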

Anomaly Detection

Reconstruction error can indicate abnormalities in medical images.
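A hedged sketch of that idea: score an image by its masked-reconstruction error, averaged over several random maskings to reduce variance. The model's return signature `(pred, target, mask)` is assumed here, not taken from the repo.

```python
import torch

@torch.no_grad()
def anomaly_score(model, image, trials=4):
    """Average masked-patch MSE over several random maskings; higher = more anomalous."""
    errors = []
    for _ in range(trials):
        pred, target, mask = model(image.unsqueeze(0))   # assumed signature
        per_patch = ((pred - target) ** 2).mean(dim=-1)
        errors.append((per_patch * mask).sum() / mask.sum())
    return torch.stack(errors).mean().item()
```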

📊 Performance Optimization

This implementation includes several optimizations:

  • Efficient ZIP Reading: Avoids extracting files to disk
  • LRU Cache: Keeps frequently accessed images in memory
  • Persistent Workers: Reduces data loading overhead
  • Mixed Precision: 2× faster training with minimal quality loss
  • Gradient Checkpointing: Reduces memory usage (if enabled)
  • CUDA Memory Management: Proper cache clearing and synchronization

🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

📄 License

This project is licensed under the terms specified in the LICENSE file.

📚 References

  1. Masked Autoencoders Are Scalable Vision Learners
     He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022)
     arXiv:2111.06377

  2. CheXpert: A Large Chest Radiograph Dataset
     Irvin, J., et al. (2019)
     Stanford ML Group

🙏 Acknowledgments

  • Original MAE paper by Meta AI Research
  • CheXpert dataset by Stanford ML Group
  • PyTorch and Albumentations communities

📧 Contact

For questions or issues, please open an issue on GitHub or contact the maintainer.


Note: This is a research/educational implementation. For clinical applications, please ensure proper validation and regulatory compliance.
