--- |
|
|
language: |
|
|
- en |
|
|
tags: |
|
|
- medical-imaging |
|
|
- chest-xray |
|
|
- chexpert |
|
|
- multi-label-classification |
|
|
- mae |
|
|
- densenet |
|
|
- fpn |
|
|
- deep-learning |
|
|
- healthcare |
|
|
license: mit |
|
|
datasets: |
|
|
- chexpert |
|
|
library_name: pytorch |
|
|
pipeline_tag: image-classification |
|
|
model_name: CheXpert MAE-DenseNet-FPN
|
|
model_type: hybrid transformer-cnn |
|
|
metrics: |
|
|
- roc_auc |
|
|
--- |
|
|
|
|
|
# CheXpert MAE-DenseNet-FPN |
|
|
|
|
|
A deep learning framework for multi-label chest X-ray classification using a hybrid architecture combining **Masked Autoencoders (MAE)**, **DenseNet** with CBAM attention, and **Feature Pyramid Networks (FPN)** with bidirectional cross-attention fusion. |
|
|
|
|
|
## 🏗️ Architecture Overview |
|
|
|
|
|
This project implements a dual-branch feature-fusion architecture for medical image classification:
|
|
|
|
|
- **MAE Encoder**: Vision Transformer-based masked autoencoder for self-supervised feature extraction |
|
|
- **DenseNet-169**: Dense convolutional network with Channel and Spatial Attention (CBAM) |
|
|
- **Feature Pyramid Network**: Multi-scale feature extraction at 4 different resolutions |
|
|
- **Bidirectional Cross-Attention**: Fusion mechanism allowing MAE and DenseNet features to attend to each other |
|
|
- **Learned Logit Ensemble**: Weighted combination of 7 prediction heads with learnable per-head weights and temperature scaling (a minimal sketch follows the diagram below)
|
|
|
|
|
### Key Components |
|
|
|
|
|
``` |
|
|
Input Image (384×384) |
|
|
│ |
|
|
├─────────────────────────────┐ |
|
|
│ │ |
|
|
▼ ▼ |
|
|
MAE Encoder DenseNet-169 |
|
|
(ViT-based) (with CBAM) |
|
|
│ │ |
|
|
│ ┌───────────────────┤ |
|
|
│ │ │ |
|
|
│ FPN Pyramid Dense Features |
|
|
│ (P1-P4) (Multi-scale) |
|
|
│ │ │ |
|
|
└─────────┴───────────────────┘ |
|
|
│ |
|
|
Bidirectional Cross-Attention |
|
|
│ |
|
|
┌─────────┴──────────┐ |
|
|
│ │ |
|
|
MAE Head Dense Head + 4 FPN Heads |
|
|
│ │ |
|
|
└────────┬───────────┘ |
|
|
│ |
|
|
Learned Ensemble (7 heads) |
|
|
│ |
|
|
▼ |
|
|
14-class Predictions |
|
|
``` |
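
The diagram ends in a learned ensemble over seven prediction heads, as referenced in the component list above. The authoritative implementation is in `models/classifier.py`; the snippet below is only a minimal sketch of the idea, assuming each head emits raw logits of shape `(batch, num_classes)` and that the ensemble learns per-head weights plus a single temperature.

```python
import torch
import torch.nn as nn

class LearnedLogitEnsemble(nn.Module):
    """Minimal sketch: weighted average of per-head logits with a learnable temperature."""
    def __init__(self, num_heads: int = 7):
        super().__init__()
        self.head_weights = nn.Parameter(torch.zeros(num_heads))  # softmax-normalised below
        self.log_temperature = nn.Parameter(torch.zeros(1))       # temperature = exp(log_temperature)

    def forward(self, head_logits):
        # head_logits: list of (batch, num_classes) tensors, one per prediction head
        stacked = torch.stack(head_logits, dim=0)                  # (heads, batch, classes)
        weights = torch.softmax(self.head_weights, dim=0)          # convex combination over heads
        fused = torch.einsum("h,hbc->bc", weights, stacked)        # weighted sum of logits
        return fused / self.log_temperature.exp()                  # temperature-scaled logits
```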
|
|
|
|
|
## ✨ Features |
|
|
|
|
|
- **Hybrid Architecture**: Combines transformer-based and convolutional approaches |
|
|
- **Multi-scale Learning**: FPN extracts features at 4 different resolutions |
|
|
- **Advanced Fusion**: Bidirectional cross-attention between MAE and DenseNet features |
|
|
- **Optimized Training**: |
|
|
- Mixed precision training (FP16) |
|
|
- Gradient accumulation |
|
|
- Weighted sampling for class imbalance |
|
|
- Cosine annealing with linear warmup |
|
|
- Gradient checkpointing for memory efficiency |
|
|
- **Smart Data Loading**: |
|
|
- ZIP file reader with LRU caching (sketched after this list)
|
|
- On-the-fly augmentation using Albumentations |
|
|
- Multi-worker data loading with persistent workers |
|
|
- **Comprehensive Evaluation**: |
|
|
- Per-class AUC metrics |
|
|
- Optimal threshold computation per class |
|
|
- Macro and Micro AUC tracking |
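
The ZIP reader with LRU caching mentioned above lives in `data/dataset.py`; the fragment below is only a sketch of the idea, assuming images are addressed by their in-archive path and decoded with OpenCV.

```python
import zipfile
from functools import lru_cache

import cv2
import numpy as np

class ZipImageReader:
    """Sketch: read images straight from a ZIP archive, caching recently used ones."""
    def __init__(self, zip_path: str, cache_size: int = 2048):
        self.zip_path = zip_path
        self._zf = None
        # Bind an LRU cache to this instance so hot images skip decompression.
        self._read_cached = lru_cache(maxsize=cache_size)(self._read)

    def _zipfile(self) -> zipfile.ZipFile:
        if self._zf is None:   # open lazily so each DataLoader worker gets its own handle
            self._zf = zipfile.ZipFile(self.zip_path, "r")
        return self._zf

    def _read(self, name: str) -> np.ndarray:
        raw = self._zipfile().read(name)
        return cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_GRAYSCALE)

    def __call__(self, name: str) -> np.ndarray:
        return self._read_cached(name)
```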
|
|
|
|
|
## 📋 Requirements |
|
|
|
|
|
- Python 3.8+ |
|
|
- CUDA-capable GPU (recommended: 16GB+ VRAM) |
|
|
- CheXpert dataset |
|
|
|
|
|
## 🚀 Installation |
|
|
|
|
|
1. **Clone the repository** |
|
|
```bash |
|
|
git clone https://github.com/adelelsayed/chexpert-mae-densenet-fpn.git |
|
|
cd chexpert-mae-densenet-fpn |
|
|
``` |
|
|
|
|
|
2. **Create a virtual environment** |
|
|
```bash |
|
|
python -m venv venv |
|
|
source venv/bin/activate # On Windows: venv\Scripts\activate |
|
|
``` |
|
|
|
|
|
3. **Install dependencies** |
|
|
```bash |
|
|
pip install -r requirements.txt |
|
|
``` |
|
|
|
|
|
## 📊 Dataset Setup |
|
|
|
|
|
1. **Download CheXpert Dataset** |
|
|
- Visit: https://stanfordmlgroup.github.io/competitions/chexpert/ |
|
|
- Download CheXpert-v1.0-small |
|
|
|
|
|
2. **Prepare the dataset** |
|
|
```bash |
|
|
# Extract the dataset |
|
|
unzip CheXpert-v1.0-small.zip |
|
|
|
|
|
# Optionally, create a ZIP archive for faster loading |
|
|
cd CheXpert-v1.0-small |
|
|
zip -r chexpert.zip train/ valid/ |
|
|
``` |
|
|
|
|
|
3. **Update configuration** |
|
|
- Edit `configs/configs.py` |
|
|
- Update `root` variable to point to your dataset location |
|
|
- Update all paths accordingly |
|
|
|
|
|
## 🔧 Configuration |
|
|
|
|
|
Edit `configs/configs.py` to customize: |
|
|
|
|
|
```python |
|
|
# Example: Update paths |
|
|
root = "/path/to/your/data" |
|
|
|
|
|
mae_config = { |
|
|
"lr": 1e-4, |
|
|
"num_epochs": 200, |
|
|
"batch_size": 96, |
|
|
# ... other parameters |
|
|
} |
|
|
|
|
|
config = { |
|
|
"lr": 1e-4, |
|
|
"num_epochs": 200, |
|
|
"batch_size": 36, |
|
|
# ... other parameters |
|
|
} |
|
|
``` |
|
|
|
|
|
## 🎯 Training |
|
|
|
|
|
### Phase 1: Pre-train MAE |
|
|
|
|
|
```bash |
|
|
python trainer/trainer.py |
|
|
# When prompted, type: mae |
|
|
``` |
|
|
|
|
|
The MAE pre-training learns robust feature representations through masked image reconstruction. |
|
|
|
|
|
### Phase 2: Train Classifier |
|
|
|
|
|
```bash |
|
|
python trainer/trainer.py |
|
|
# When prompted, type: classifier |
|
|
``` |
|
|
|
|
|
This loads the pre-trained MAE encoder and trains the full classification pipeline. |
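
Concretely, the hand-off from pre-training to classification happens inside `trainer/trainer.py`. A rough, purely illustrative sketch of the idea follows; the attribute name `mae_encoder` and the checkpoint key are hypothetical, not the repo's actual names.

```python
import torch
import torch.nn as nn

def load_pretrained_encoder(classifier: nn.Module, ckpt_path: str) -> None:
    """Hypothetical sketch: copy pre-trained MAE encoder weights into the classifier branch."""
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state = checkpoint.get("model_state", checkpoint)   # key name is an assumption
    missing, unexpected = classifier.mae_encoder.load_state_dict(state, strict=False)
    print(f"Loaded encoder weights (missing={len(missing)}, unexpected={len(unexpected)})")
```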
|
|
|
|
|
### Training Configuration |
|
|
|
|
|
- **MAE Training**: |
|
|
- Batch size: 96 |
|
|
- Mask ratio: 0.75 (masks 75% of patches) |
|
|
- Reconstruction loss on masked patches |
|
|
|
|
|
- **Classifier Training**: |
|
|
- Batch size: 36 with gradient accumulation (8 steps) |
|
|
- Effective batch size: 288 |
|
|
- Asymmetric loss with class weights (sketched below)
|
|
- Per-class threshold optimization |
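
The asymmetric loss referenced above is implemented in `loss/assymetric.py`. For orientation, here is a minimal sketch of a generic asymmetric loss in the spirit of Ben-Baruch et al. (2020), with optional per-class weights; the repository's exact formulation and hyperparameters may differ.

```python
import torch
import torch.nn as nn

class AsymmetricLoss(nn.Module):
    """Sketch of a generic multi-label asymmetric loss; the repo's version may differ."""
    def __init__(self, gamma_pos: float = 0.0, gamma_neg: float = 4.0,
                 clip: float = 0.05, class_weights: torch.Tensor = None):
        super().__init__()
        self.gamma_pos, self.gamma_neg, self.clip = gamma_pos, gamma_neg, clip
        self.class_weights = class_weights  # shape (num_classes,) or None

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        p = torch.sigmoid(logits)
        p_neg = (p - self.clip).clamp(min=0)   # probability shifting for easy negatives
        loss_pos = targets * (1 - p).pow(self.gamma_pos) * torch.log(p.clamp(min=1e-8))
        loss_neg = (1 - targets) * p_neg.pow(self.gamma_neg) * torch.log((1 - p_neg).clamp(min=1e-8))
        loss = -(loss_pos + loss_neg)
        if self.class_weights is not None:
            loss = loss * self.class_weights.to(loss.device)
        return loss.mean()
```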
|
|
|
|
|
## 🧪 Testing |
|
|
|
|
|
```python |
|
|
from trainer.utils import Trainer |
|
|
from configs.configs import config |
|
|
|
|
|
# Initialize trainer |
|
|
trainer = Trainer(config) |
|
|
|
|
|
# Run evaluation on test set |
|
|
macro_auc, micro_auc, per_class = trainer.test( |
|
|
model_path="path/to/checkpoint.pth" |
|
|
) |
|
|
|
|
|
print(f"Macro AUC: {macro_auc:.4f}") |
|
|
print(f"Micro AUC: {micro_auc:.4f}") |
|
|
``` |
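
The per-class threshold optimization mentioned in the feature and training lists is handled by the repo's evaluation code. One common recipe, shown here only as an illustrative sketch, is to maximise Youden's J statistic (TPR - FPR) on each class's validation ROC curve:

```python
import numpy as np
from sklearn.metrics import roc_curve

def optimal_thresholds(y_true: np.ndarray, y_prob: np.ndarray) -> np.ndarray:
    """Sketch: pick the threshold maximising TPR - FPR (Youden's J) for each class.

    y_true, y_prob: arrays of shape (num_samples, num_classes).
    """
    thresholds = np.zeros(y_true.shape[1])
    for c in range(y_true.shape[1]):
        fpr, tpr, thr = roc_curve(y_true[:, c], y_prob[:, c])
        thresholds[c] = thr[np.argmax(tpr - fpr)]
    return thresholds
```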
|
|
|
|
|
## 📁 Project Structure |
|
|
|
|
|
``` |
|
|
chexpert-mae-densenet-fpn/ |
|
|
├── configs/ |
|
|
│ ├── __init__.py |
|
|
│ └── configs.py # Configuration parameters |
|
|
├── data/ |
|
|
│ ├── __init__.py |
|
|
│ ├── dataset.py # CheXpert dataset with ZIP caching |
|
|
│ └── splitter.py # Data splitting utilities |
|
|
├── loss/ |
|
|
│ ├── __init__.py |
|
|
│ └── assymetric.py # Asymmetric loss for imbalanced data |
|
|
├── models/ |
|
|
│ ├── __init__.py |
|
|
│ ├── mae.py # Masked Autoencoder implementation |
|
|
│ ├── densenet.py # DenseNet-169 with CBAM |
|
|
│ └── classifier.py # Full classification architecture |
|
|
├── trainer/ |
|
|
│ ├── __init__.py |
|
|
│ ├── trainer.py # Main training script |
|
|
│ ├── utils.py # Training utilities and loops |
|
|
│ └── test.py # Testing utilities |
|
|
├── notebooks/ |
|
|
│ ├── chexpert_mae.ipynb # MAE experiments |
|
|
│ └── chexpert_mae_mask_classifier.ipynb # Full pipeline experiments |
|
|
├── requirements.txt |
|
|
└── README.md |
|
|
``` |
|
|
|
|
|
## 📈 Model Architecture Details |
|
|
|
|
|
### MAE Encoder |
|
|
- **Patch size**: 16×16 |
|
|
- **Embedding dim**: 768 |
|
|
- **Depth**: 12 transformer blocks |
|
|
- **Heads**: 8 attention heads |
|
|
- **MLP ratio**: 4× |
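
With 16×16 patches on 384×384 inputs the encoder sees 24×24 = 576 tokens, of which 75% are masked during pre-training. Below is a minimal sketch of per-sample random patch masking in the style of He et al. (2021); the authoritative version is `models/mae.py`.

```python
import torch

def random_masking(tokens: torch.Tensor, mask_ratio: float = 0.75):
    """Sketch: keep a random subset of patch tokens per sample, as in MAE."""
    batch, num_tokens, dim = tokens.shape
    num_keep = int(num_tokens * (1 - mask_ratio))
    noise = torch.rand(batch, num_tokens, device=tokens.device)   # one score per token
    ids_shuffle = torch.argsort(noise, dim=1)                     # random permutation per sample
    ids_keep = ids_shuffle[:, :num_keep]
    visible = torch.gather(tokens, 1, ids_keep.unsqueeze(-1).expand(-1, -1, dim))
    mask = torch.ones(batch, num_tokens, device=tokens.device)    # 1 = masked, 0 = visible
    mask.scatter_(1, ids_keep, 0.0)
    return visible, mask, ids_keep
```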
|
|
|
|
|
### DenseNet-169 |
|
|
- **Growth rate (k)**: 64 |
|
|
- **Layers**: [6, 12, 24, 16] |
|
|
- **CBAM**: Channel + Spatial attention at each stage |
|
|
- **Dropout**: Progressive (0.05 → 0.1 → 0.1 → 0.1) |
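
CBAM (Woo et al., 2018) applies channel attention followed by spatial attention. A compact reference sketch of such a block, independent of the exact implementation in `models/densenet.py`:

```python
import torch
import torch.nn as nn

class CBAM(nn.Module):
    """Sketch of a CBAM block: channel attention, then spatial attention."""
    def __init__(self, channels: int, reduction: int = 16, spatial_kernel: int = 7):
        super().__init__()
        self.mlp = nn.Sequential(                 # shared MLP for avg- and max-pooled descriptors
            nn.Linear(channels, channels // reduction), nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
        )
        self.spatial = nn.Conv2d(2, 1, spatial_kernel, padding=spatial_kernel // 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, _, _ = x.shape
        avg = self.mlp(x.mean(dim=(2, 3)))        # channel attention from average pooling
        mx = self.mlp(x.amax(dim=(2, 3)))         # ... and from max pooling
        x = x * torch.sigmoid(avg + mx).view(b, c, 1, 1)
        s = torch.cat([x.mean(dim=1, keepdim=True),               # spatial attention from
                       x.amax(dim=1, keepdim=True)], dim=1)       # channel-wise statistics
        return x * torch.sigmoid(self.spatial(s))
```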
|
|
|
|
|
### Cross-Attention Fusion |
|
|
- **12 bidirectional cross-attention layers** |
|
|
- **Projection dim**: 512 |
|
|
- **Attention heads**: 8 |
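
Bidirectional here means each stream queries the other: MAE tokens attend to DenseNet tokens and vice versa. A minimal sketch of one such layer using `nn.MultiheadAttention`, assuming both streams have already been projected to the 512-dimensional space listed above:

```python
import torch
import torch.nn as nn

class BidirectionalCrossAttention(nn.Module):
    """Sketch: one layer of mutual cross-attention between two token streams."""
    def __init__(self, dim: int = 512, num_heads: int = 8):
        super().__init__()
        self.mae_to_dense = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.dense_to_mae = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm_mae = nn.LayerNorm(dim)
        self.norm_dense = nn.LayerNorm(dim)

    def forward(self, mae_tokens: torch.Tensor, dense_tokens: torch.Tensor):
        # MAE tokens query the DenseNet tokens ...
        mae_upd, _ = self.mae_to_dense(mae_tokens, dense_tokens, dense_tokens)
        # ... and DenseNet tokens query the MAE tokens, each with a residual connection.
        dense_upd, _ = self.dense_to_mae(dense_tokens, mae_tokens, mae_tokens)
        return self.norm_mae(mae_tokens + mae_upd), self.norm_dense(dense_tokens + dense_upd)
```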
|
|
|
|
|
### FPN |
|
|
- **Feature levels**: P1 (192×192), P2 (96×96), P3 (48×48), P4 (24×24) |
|
|
- **Channel unification**: 256 channels per level |
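
An FPN builds the P1-P4 pyramid by projecting each backbone stage to 256 channels with 1×1 lateral convolutions and adding a top-down pathway. An illustrative sketch follows; the backbone channel counts below are placeholders, not the repo's actual values.

```python
import torch.nn as nn
import torch.nn.functional as F

class SimpleFPN(nn.Module):
    """Sketch of an FPN top-down pathway unifying backbone stages to 256 channels."""
    def __init__(self, in_channels=(128, 256, 640, 1664), out_channels: int = 256):
        super().__init__()
        self.lateral = nn.ModuleList(nn.Conv2d(c, out_channels, 1) for c in in_channels)
        self.smooth = nn.ModuleList(nn.Conv2d(out_channels, out_channels, 3, padding=1)
                                    for _ in in_channels)

    def forward(self, features):
        # features: backbone maps, highest resolution first (e.g. P1 at 192x192 ... P4 at 24x24)
        laterals = [lat(f) for lat, f in zip(self.lateral, features)]
        for i in range(len(laterals) - 2, -1, -1):   # top-down: upsample coarser level and add
            laterals[i] = laterals[i] + F.interpolate(
                laterals[i + 1], size=laterals[i].shape[-2:], mode="nearest")
        return [smooth(p) for smooth, p in zip(self.smooth, laterals)]  # P1..P4
```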
|
|
|
|
|
## 🎓 CheXpert Labels |
|
|
|
|
|
The model predicts 14 pathologies: |
|
|
|
|
|
1. No Finding |
|
|
2. Enlarged Cardiomediastinum |
|
|
3. Cardiomegaly |
|
|
4. Lung Opacity |
|
|
5. Lung Lesion |
|
|
6. Edema |
|
|
7. Consolidation |
|
|
8. Pneumonia |
|
|
9. Atelectasis |
|
|
10. Pneumothorax |
|
|
11. Pleural Effusion |
|
|
12. Pleural Other |
|
|
13. Fracture |
|
|
14. Support Devices |
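
For convenience, the same 14 labels as a Python constant; make sure the order matches the label columns configured in `configs/configs.py`:

```python
CHEXPERT_LABELS = [
    "No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly", "Lung Opacity",
    "Lung Lesion", "Edema", "Consolidation", "Pneumonia", "Atelectasis",
    "Pneumothorax", "Pleural Effusion", "Pleural Other", "Fracture", "Support Devices",
]
```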
|
|
|
|
|
## 🔬 Data Augmentation |
|
|
|
|
|
Training augmentations (conservative for medical images; an illustrative pipeline follows this list):
|
|
- Horizontal flip (p=0.5) |
|
|
- Random affine (translation, scale, rotation ±10°) |
|
|
- Random brightness/contrast |
|
|
- CLAHE histogram equalization |
|
|
- Gaussian blur and noise |
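
An illustrative Albumentations pipeline along these lines; the probabilities and parameter ranges below are assumptions, and the dataset code in `data/dataset.py` defines the actual values.

```python
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Illustrative values only; see the repo's dataset code for the authoritative pipeline.
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Affine(translate_percent=0.05, scale=(0.9, 1.1), rotate=(-10, 10), p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.CLAHE(p=0.3),
    A.GaussianBlur(p=0.1),
    A.GaussNoise(p=0.1),
    A.Resize(384, 384),
    A.Normalize(mean=0.5, std=0.5),
    ToTensorV2(),
])
```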
|
|
|
|
|
## 💾 Checkpoints |
|
|
|
|
|
The training automatically saves: |
|
|
- **Best MAE checkpoint**: Based on validation reconstruction loss |
|
|
- **Best classifier checkpoint**: Based on validation AUC (macro/micro) |
|
|
- **Training history**: JSON file with all metrics |
|
|
- **Per-epoch metrics plots**: Loss and AUC curves |
|
|
|
|
|
## 📊 Monitoring |
|
|
|
|
|
Training logs are saved to: |
|
|
- `training_log.txt`: Training progress with live metrics |
|
|
- `val_log.txt`: Validation results |
|
|
- `test_log.txt`: Test evaluation results |
|
|
- `history.json`: All metrics across epochs |
|
|
- `metrics.png`: Visualization plots |
|
|
|
|
|
## ⚡ Performance Tips |
|
|
|
|
|
1. **Memory Optimization**: |
|
|
- Use gradient checkpointing (already enabled) |
|
|
- Reduce batch size if OOM occurs |
|
|
- Increase gradient accumulation steps |
|
|
|
|
|
2. **Speed Optimization**: |
|
|
- Use persistent workers (already enabled) |
|
|
- Enable cuDNN benchmark (already enabled) |
|
|
- Use ZIP caching for faster data loading |
|
|
|
|
|
3. **Training Stability**: |
|
|
- Gradient clipping at norm 1.0 |
|
|
- Mixed precision with dynamic loss scaling |
|
|
- Warmup learning rate schedule |
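
For reference, a condensed sketch of how these stability pieces (mixed precision with dynamic loss scaling, gradient accumulation, and clipping at norm 1.0) typically fit together in a PyTorch step, assuming `model`, `criterion`, `optimizer`, and `train_loader` are already built; `trainer/utils.py` contains the actual loop.

```python
import torch

scaler = torch.cuda.amp.GradScaler()          # dynamic loss scaling
accum_steps = 8                               # gradient accumulation steps

for step, (images, targets) in enumerate(train_loader):
    images, targets = images.cuda(), targets.cuda()
    with torch.cuda.amp.autocast():           # FP16 forward pass
        loss = criterion(model(images), targets) / accum_steps
    scaler.scale(loss).backward()
    if (step + 1) % accum_steps == 0:
        scaler.unscale_(optimizer)            # unscale before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
```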
|
|
|
|
|
## 🐛 Troubleshooting |
|
|
|
|
|
**Q: Out of memory errors?** |
|
|
- Reduce batch size in configs.py |
|
|
- Increase gradient accumulation steps |
|
|
- Enable gradient checkpointing |
|
|
|
|
|
**Q: Slow training?** |
|
|
- Check if ZIP caching is enabled |
|
|
- Verify persistent workers are active |
|
|
- Monitor GPU utilization |
|
|
|
|
|
**Q: Poor convergence?** |
|
|
- Ensure MAE is properly pre-trained first |
|
|
- Check learning rate and warmup settings |
|
|
- Verify class weights are computed correctly |
|
|
|
|
|
## 📚 Citation |
|
|
|
|
|
If you use this code in your research, please cite: |
|
|
|
|
|
```bibtex |
|
|
@misc{chexpert-mae-densenet-fpn, |
|
|
author = {Adel Elsayed},
|
|
title = {CheXpert Classification with MAE-DenseNet-FPN}, |
|
|
year = {2025}, |
|
|
publisher = {GitHub}, |
|
|
url = {https://github.com/adelelsayed/chexpert-mae-densenet-fpn} |
|
|
} |
|
|
``` |
|
|
|
|
|
## 🙏 Acknowledgments |
|
|
|
|
|
- **CheXpert Dataset**: Stanford ML Group |
|
|
- **Masked Autoencoders**: Meta AI Research (He et al., 2021) |
|
|
- **DenseNet**: Huang et al., 2017 |
|
|
- **CBAM**: Woo et al., 2018 |
|
|
- **Feature Pyramid Networks**: Lin et al., 2017 |
|
|
|
|
|
## 📄 License |
|
|
|
|
|
|
|
This project is licensed under the MIT License. |
|
|
|
|
|
|
|
|
## 📧 Contact |
|
|
|
|
|
Connect with the author on LinkedIn: https://www.linkedin.com/in/adel-elsayed-a5260246/
|
|
|
|
|
**Note**: This is a research project. For clinical use, please ensure proper validation and regulatory approval. |