---
language:
- en
tags:
- medical-imaging
- chest-xray
- chexpert
- multi-label-classification
- mae
- densenet
- fpn
- deep-learning
- healthcare
license: mit
datasets:
- chexpert
library_name: pytorch
pipeline_tag: image-classification
model_name: CheXpert MAE–DenseNet–FPN
model_type: hybrid transformer-cnn
metrics:
- roc_auc
---

# CheXpert MAE-DenseNet-FPN

A deep learning framework for multi-label chest X-ray classification using a hybrid architecture combining **Masked Autoencoders (MAE)**, **DenseNet** with CBAM attention, and **Feature Pyramid Networks (FPN)** with bidirectional cross-attention fusion.

## 🏗️ Architecture Overview

This project implements a novel multi-modal fusion architecture for medical image classification:

- **MAE Encoder**: Vision Transformer-based masked autoencoder for self-supervised feature extraction
- **DenseNet-169**: Dense convolutional network with Channel and Spatial Attention (CBAM)
- **Feature Pyramid Network**: Multi-scale feature extraction at 4 different resolutions
- **Bidirectional Cross-Attention**: Fusion mechanism allowing MAE and DenseNet features to attend to each other
- **Learned Logit Ensemble**: Intelligent combination of 7 prediction heads with learnable temperature scaling

### Key Components

```
        Input Image (384×384)
                 │
        ┌────────┴─────────┐
        │                  │
        ▼                  ▼
   MAE Encoder        DenseNet-169
   (ViT-based)        (with CBAM)
        │                  │
        │         ┌────────┤
        │         │        │
        │    FPN Pyramid   Dense Features
        │     (P1-P4)      (Multi-scale)
        │         │        │
        └─────────┴───┬────┘
                      │
      Bidirectional Cross-Attention
                      │
         ┌────────────┴────────────┐
         │                         │
     MAE Head       Dense Head + 4 FPN Heads
         │                         │
         └────────────┬────────────┘
                      │
        Learned Ensemble (7 heads)
                      │
                      ▼
          14-class Predictions
```
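The learned ensemble stage can be sketched in a few lines of PyTorch. This is an illustrative sketch, not the repository's actual `models/classifier.py` code: the module name, the zero (i.e. uniform) weight initialization, and the single shared temperature are assumptions.

```python
import torch
import torch.nn as nn


class LearnedLogitEnsemble(nn.Module):
    """Fuse the logits of several prediction heads with softmax-normalized
    learnable weights, then apply learnable temperature scaling."""

    def __init__(self, num_heads: int = 7):
        super().__init__()
        # Zero-init => uniform head weights and temperature T = 1 at the start
        self.head_weights = nn.Parameter(torch.zeros(num_heads))
        self.log_temperature = nn.Parameter(torch.zeros(1))

    def forward(self, head_logits: torch.Tensor) -> torch.Tensor:
        # head_logits: (num_heads, batch, num_classes)
        w = torch.softmax(self.head_weights, dim=0)        # convex combination
        fused = torch.einsum("h,hbc->bc", w, head_logits)  # weighted sum over heads
        return fused / self.log_temperature.exp()          # temperature scaling


# Example: seven heads, batch of 4, 14 CheXpert labels
ensemble = LearnedLogitEnsemble(num_heads=7)
out = ensemble(torch.randn(7, 4, 14))
print(out.shape)  # torch.Size([4, 14])
```

Applying a sigmoid to the fused output yields per-label probabilities for the 14 classes; because the weights and temperature are `nn.Parameter`s, they are trained jointly with the rest of the network.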
## ✨ Features

- **Hybrid Architecture**: Combines transformer-based and convolutional approaches
- **Multi-scale Learning**: FPN extracts features at 4 different resolutions
- **Advanced Fusion**: Bidirectional cross-attention between MAE and DenseNet features
- **Optimized Training**:
  - Mixed precision training (FP16)
  - Gradient accumulation
  - Weighted sampling for class imbalance
  - Cosine annealing with linear warmup
  - Gradient checkpointing for memory efficiency
- **Smart Data Loading**:
  - ZIP file reader with LRU caching
  - On-the-fly augmentation using Albumentations
  - Multi-worker data loading with persistent workers
- **Comprehensive Evaluation**:
  - Per-class AUC metrics
  - Optimal threshold computation per class
  - Macro and micro AUC tracking

## 📋 Requirements

- Python 3.8+
- CUDA-capable GPU (recommended: 16 GB+ VRAM)
- CheXpert dataset

## 🚀 Installation

1. **Clone the repository**

   ```bash
   git clone https://github.com/adelelsayed/chexpert-mae-densenet-fpn.git
   cd chexpert-mae-densenet-fpn
   ```

2. **Create a virtual environment**

   ```bash
   python -m venv venv
   source venv/bin/activate  # On Windows: venv\Scripts\activate
   ```

3. **Install dependencies**

   ```bash
   pip install -r requirements.txt
   ```

## 📊 Dataset Setup

1. **Download the CheXpert dataset**
   - Visit: https://stanfordmlgroup.github.io/competitions/chexpert/
   - Download CheXpert-v1.0-small

2. **Prepare the dataset**

   ```bash
   # Extract the dataset
   unzip CheXpert-v1.0-small.zip

   # Optionally, create a ZIP archive for faster loading
   cd CheXpert-v1.0-small
   zip -r chexpert.zip train/ valid/
   ```

3. **Update the configuration**
   - Edit `configs/configs.py`
   - Update the `root` variable to point to your dataset location
   - Update all other paths accordingly

## 🔧 Configuration

Edit `configs/configs.py` to customize:

```python
# Example: update paths
root = "/path/to/your/data"

mae_config = {
    "lr": 1e-4,
    "num_epochs": 200,
    "batch_size": 96,
    # ... other parameters
}

config = {
    "lr": 1e-4,
    "num_epochs": 200,
    "batch_size": 36,
    # ... other parameters
}
```

## 🎯 Training

### Phase 1: Pre-train MAE

```bash
python trainer/trainer.py
# When prompted, type: mae
```

MAE pre-training learns robust feature representations through masked image reconstruction.

### Phase 2: Train the Classifier

```bash
python trainer/trainer.py
# When prompted, type: classifier
```

This loads the pre-trained MAE encoder and trains the full classification pipeline.

### Training Configuration

- **MAE Training**:
  - Batch size: 96
  - Mask ratio: 0.75 (masks 75% of patches)
  - Reconstruction loss on masked patches
- **Classifier Training**:
  - Batch size: 36 with gradient accumulation (8 steps)
  - Effective batch size: 288
  - Asymmetric loss with class weights
  - Per-class threshold optimization

## 🧪 Testing

```python
from trainer.utils import Trainer
from configs.configs import config

# Initialize the trainer
trainer = Trainer(config)

# Run evaluation on the test set
macro_auc, micro_auc, per_class = trainer.test(
    model_path="path/to/checkpoint.pth"
)

print(f"Macro AUC: {macro_auc:.4f}")
print(f"Micro AUC: {micro_auc:.4f}")
```

## 📁 Project Structure

```
chexpert-mae-densenet-fpn/
├── configs/
│   ├── __init__.py
│   └── configs.py                           # Configuration parameters
├── data/
│   ├── __init__.py
│   ├── dataset.py                           # CheXpert dataset with ZIP caching
│   └── splitter.py                          # Data splitting utilities
├── loss/
│   ├── __init__.py
│   └── assymetric.py                        # Asymmetric loss for imbalanced data
├── models/
│   ├── __init__.py
│   ├── mae.py                               # Masked Autoencoder implementation
│   ├── densenet.py                          # DenseNet-169 with CBAM
│   └── classifier.py                        # Full classification architecture
├── trainer/
│   ├── __init__.py
│   ├── trainer.py                           # Main training script
│   ├── utils.py                             # Training utilities and loops
│   └── test.py                              # Testing utilities
├── notebooks/
│   ├── chexpert_mae.ipynb                   # MAE experiments
│   └── chexpert_mae_mask_classifier.ipynb   # Full pipeline experiments
├── requirements.txt
└── README.md
```

## 📈 Model Architecture Details

### MAE Encoder

- **Patch size**: 16×16
- **Embedding dim**: 768
- **Depth**: 12 transformer blocks
- **Heads**: 8 attention heads
- **MLP ratio**: 4×

### DenseNet-169

- **Growth rate (k)**: 64
- **Layers**: [6, 12, 24, 16]
- **CBAM**: Channel + spatial attention at each stage
- **Dropout**: Progressive (0.05 → 0.1 → 0.1 → 0.1)

### Cross-Attention Fusion

- **12 bidirectional cross-attention layers**
- **Projection dim**: 512
- **Attention heads**: 8

### FPN

- **Feature levels**: P1 (192×192), P2 (96×96), P3 (48×48), P4 (24×24)
- **Channel unification**: 256 channels per level

## 🎓 CheXpert Labels

The model predicts 14 pathologies:

1. No Finding
2. Enlarged Cardiomediastinum
3. Cardiomegaly
4. Lung Opacity
5. Lung Lesion
6. Edema
7. Consolidation
8. Pneumonia
9. Atelectasis
10. Pneumothorax
11. Pleural Effusion
12. Pleural Other
13. Fracture
14. Support Devices

## 🔬 Data Augmentation

Training augmentations (kept conservative for medical images):

- Horizontal flip (p=0.5)
- Random affine (translation, scale, rotation ±10°)
- Random brightness/contrast
- CLAHE histogram equalization
- Gaussian blur and noise

## 💾 Checkpoints

Training automatically saves:

- **Best MAE checkpoint**: Based on validation reconstruction loss
- **Best classifier checkpoint**: Based on validation AUC (macro/micro)
- **Training history**: JSON file with all metrics
- **Per-epoch metric plots**: Loss and AUC curves

## 📊 Monitoring

Training logs are saved to:

- `training_log.txt`: Training progress with live metrics
- `val_log.txt`: Validation results
- `test_log.txt`: Test evaluation results
- `history.json`: All metrics across epochs
- `metrics.png`: Visualization plots

## ⚡ Performance Tips

1. **Memory Optimization**:
   - Use gradient checkpointing (already enabled)
   - Reduce the batch size if OOM occurs
   - Increase gradient accumulation steps
2. **Speed Optimization**:
   - Use persistent workers (already enabled)
   - Enable cuDNN benchmark (already enabled)
   - Use ZIP caching for faster data loading
3. **Training Stability**:
   - Gradient clipping at norm 1.0
   - Mixed precision with dynamic loss scaling
   - Warmup learning rate schedule

## 🐛 Troubleshooting

**Q: Out of memory errors?**

- Reduce the batch size in `configs/configs.py`
- Increase gradient accumulation steps
- Enable gradient checkpointing

**Q: Slow training?**

- Check that ZIP caching is enabled
- Verify that persistent workers are active
- Monitor GPU utilization

**Q: Poor convergence?**

- Ensure MAE is properly pre-trained first
- Check the learning rate and warmup settings
- Verify that class weights are computed correctly

## 📚 Citation

If you use this code in your research, please cite:

```bibtex
@misc{chexpert-mae-densenet-fpn,
  author    = {Adel Elsayed},
  title     = {CheXpert Classification with MAE-DenseNet-FPN},
  year      = {2025},
  publisher = {GitHub},
  url       = {https://github.com/adelelsayed/chexpert-mae-densenet-fpn}
}
```

## 🙏 Acknowledgments

- **CheXpert Dataset**: Stanford ML Group
- **Masked Autoencoders**: Meta AI Research (He et al., 2021)
- **DenseNet**: Huang et al., 2017
- **CBAM**: Woo et al., 2018
- **Feature Pyramid Networks**: Lin et al., 2017

## 📄 License

This project is licensed under the MIT License.

## 📧 Contact

https://www.linkedin.com/in/adel-elsayed-a5260246/

**Note**: This is a research project. For clinical use, please ensure proper validation and regulatory approval.