Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +2 -0
- LICENSE +21 -0
- README.md +353 -3
- configs/__init__.py +0 -0
- configs/__pycache__/__init__.cpython-313.pyc +0 -0
- configs/__pycache__/configs.cpython-313.pyc +0 -0
- configs/configs.py +55 -0
- data/__init__.py +0 -0
- data/__pycache__/__init__.cpython-313.pyc +0 -0
- data/__pycache__/__init__.cpython-314.pyc +0 -0
- data/__pycache__/dataset.cpython-313.pyc +0 -0
- data/__pycache__/dataset.cpython-314.pyc +0 -0
- data/dataset.py +508 -0
- data/splitter.py +347 -0
- gitignore.txt +61 -0
- loss/__init__.py +0 -0
- loss/__pycache__/__init__.cpython-313.pyc +0 -0
- loss/__pycache__/assymetric.cpython-313.pyc +0 -0
- loss/assymetric.py +59 -0
- models/__init__.py +0 -0
- models/__pycache__/__init__.cpython-313.pyc +0 -0
- models/__pycache__/classifier.cpython-313.pyc +0 -0
- models/__pycache__/densenet.cpython-313.pyc +0 -0
- models/__pycache__/mae.cpython-313.pyc +0 -0
- models/classifier.py +323 -0
- models/densenet.py +157 -0
- models/mae.py +177 -0
- notebooks/chexpert_mae.ipynb +0 -0
- notebooks/chexpert_mae_mask_classifier.ipynb +0 -0
- requirements.txt +29 -0
- results/test-results.docx +0 -0
- trainer/__init__.py +0 -0
- trainer/__pycache__/__init__.cpython-313.pyc +0 -0
- trainer/__pycache__/__init__.cpython-314.pyc +0 -0
- trainer/__pycache__/trainer.cpython-313.pyc +0 -0
- trainer/__pycache__/trainer.cpython-314.pyc +0 -0
- trainer/__pycache__/utils.cpython-313.pyc +0 -0
- trainer/test.py +15 -0
- trainer/trainer.py +19 -0
- trainer/utils.py +837 -0
- training logs/classifier/1/metrics.png +0 -0
- training logs/classifier/11/metrics.png +3 -0
- training logs/classifier/Events.docx +3 -0
- training logs/classifier/history.json +1 -0
- training logs/classifier/test_log.txt +0 -0
- training logs/classifier/training_log.txt +0 -0
- training logs/classifier/val_log.txt +0 -0
- training logs/mae/1/metrics.png +0 -0
- training logs/mae/101/metrics.png +0 -0
- training logs/mae/11/metrics.png +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
training[[:space:]]logs/classifier/11/metrics.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
training[[:space:]]logs/classifier/Events.docx filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Adel Elsayed
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,3 +1,353 @@
|
|
| 1 |
-
--
|
| 2 |
-
|
| 3 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CheXpert MAE-DenseNet-FPN
|
| 2 |
+
|
| 3 |
+
A deep learning framework for multi-label chest X-ray classification using a hybrid architecture combining **Masked Autoencoders (MAE)**, **DenseNet** with CBAM attention, and **Feature Pyramid Networks (FPN)** with bidirectional cross-attention fusion.
|
| 4 |
+
|
| 5 |
+
## 🏗️ Architecture Overview
|
| 6 |
+
|
| 7 |
+
This project implements a novel multi-modal fusion architecture for medical image classification:
|
| 8 |
+
|
| 9 |
+
- **MAE Encoder**: Vision Transformer-based masked autoencoder for self-supervised feature extraction
|
| 10 |
+
- **DenseNet-169**: Dense convolutional network with Channel and Spatial Attention (CBAM)
|
| 11 |
+
- **Feature Pyramid Network**: Multi-scale feature extraction at 4 different resolutions
|
| 12 |
+
- **Bidirectional Cross-Attention**: Fusion mechanism allowing MAE and DenseNet features to attend to each other
|
| 13 |
+
- **Learned Logit Ensemble**: Intelligent combination of 7 prediction heads with learnable temperature scaling
|
| 14 |
+
|
| 15 |
+
### Key Components
|
| 16 |
+
|
| 17 |
+
```
|
| 18 |
+
Input Image (384×384)
|
| 19 |
+
│
|
| 20 |
+
├─────────────────────────────┐
|
| 21 |
+
│ │
|
| 22 |
+
▼ ▼
|
| 23 |
+
MAE Encoder DenseNet-169
|
| 24 |
+
(ViT-based) (with CBAM)
|
| 25 |
+
│ │
|
| 26 |
+
│ ┌───────────────────┤
|
| 27 |
+
│ │ │
|
| 28 |
+
│ FPN Pyramid Dense Features
|
| 29 |
+
│ (P1-P4) (Multi-scale)
|
| 30 |
+
│ │ │
|
| 31 |
+
└─────────┴───────────────────┘
|
| 32 |
+
│
|
| 33 |
+
Bidirectional Cross-Attention
|
| 34 |
+
│
|
| 35 |
+
┌─────────┴──────────┐
|
| 36 |
+
│ │
|
| 37 |
+
MAE Head Dense Head + 4 FPN Heads
|
| 38 |
+
│ │
|
| 39 |
+
└────────┬───────────┘
|
| 40 |
+
│
|
| 41 |
+
Learned Ensemble (7 heads)
|
| 42 |
+
│
|
| 43 |
+
▼
|
| 44 |
+
14-class Predictions
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
## ✨ Features
|
| 48 |
+
|
| 49 |
+
- **Hybrid Architecture**: Combines transformer-based and convolutional approaches
|
| 50 |
+
- **Multi-scale Learning**: FPN extracts features at 4 different resolutions
|
| 51 |
+
- **Advanced Fusion**: Bidirectional cross-attention between MAE and DenseNet features
|
| 52 |
+
- **Optimized Training**:
|
| 53 |
+
- Mixed precision training (FP16)
|
| 54 |
+
- Gradient accumulation
|
| 55 |
+
- Weighted sampling for class imbalance
|
| 56 |
+
- Cosine annealing with linear warmup
|
| 57 |
+
- Gradient checkpointing for memory efficiency
|
| 58 |
+
- **Smart Data Loading**:
|
| 59 |
+
- ZIP file reader with LRU caching
|
| 60 |
+
- On-the-fly augmentation using Albumentations
|
| 61 |
+
- Multi-worker data loading with persistent workers
|
| 62 |
+
- **Comprehensive Evaluation**:
|
| 63 |
+
- Per-class AUC metrics
|
| 64 |
+
- Optimal threshold computation per class
|
| 65 |
+
- Macro and Micro AUC tracking
|
| 66 |
+
|
| 67 |
+
## 📋 Requirements
|
| 68 |
+
|
| 69 |
+
- Python 3.8+
|
| 70 |
+
- CUDA-capable GPU (recommended: 16GB+ VRAM)
|
| 71 |
+
- CheXpert dataset
|
| 72 |
+
|
| 73 |
+
## 🚀 Installation
|
| 74 |
+
|
| 75 |
+
1. **Clone the repository**
|
| 76 |
+
```bash
|
| 77 |
+
git clone https://github.com/adelelsayed/chexpert-mae-densenet-fpn.git
|
| 78 |
+
cd chexpert-mae-densenet-fpn
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
2. **Create a virtual environment**
|
| 82 |
+
```bash
|
| 83 |
+
python -m venv venv
|
| 84 |
+
source venv/bin/activate # On Windows: venv\Scripts\activate
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
3. **Install dependencies**
|
| 88 |
+
```bash
|
| 89 |
+
pip install -r requirements.txt
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
## 📊 Dataset Setup
|
| 93 |
+
|
| 94 |
+
1. **Download CheXpert Dataset**
|
| 95 |
+
- Visit: https://stanfordmlgroup.github.io/competitions/chexpert/
|
| 96 |
+
- Download CheXpert-v1.0-small
|
| 97 |
+
|
| 98 |
+
2. **Prepare the dataset**
|
| 99 |
+
```bash
|
| 100 |
+
# Extract the dataset
|
| 101 |
+
unzip CheXpert-v1.0-small.zip
|
| 102 |
+
|
| 103 |
+
# Optionally, create a ZIP archive for faster loading
|
| 104 |
+
cd CheXpert-v1.0-small
|
| 105 |
+
zip -r chexpert.zip train/ valid/
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
3. **Update configuration**
|
| 109 |
+
- Edit `configs/configs.py`
|
| 110 |
+
- Update `root` variable to point to your dataset location
|
| 111 |
+
- Update all paths accordingly
|
| 112 |
+
|
| 113 |
+
## 🔧 Configuration
|
| 114 |
+
|
| 115 |
+
Edit `configs/configs.py` to customize:
|
| 116 |
+
|
| 117 |
+
```python
|
| 118 |
+
# Example: Update paths
|
| 119 |
+
root = "/path/to/your/data"
|
| 120 |
+
|
| 121 |
+
mae_config = {
|
| 122 |
+
"lr": 1e-4,
|
| 123 |
+
"num_epochs": 200,
|
| 124 |
+
"batch_size": 96,
|
| 125 |
+
# ... other parameters
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
config = {
|
| 129 |
+
"lr": 1e-4,
|
| 130 |
+
"num_epochs": 200,
|
| 131 |
+
"batch_size": 36,
|
| 132 |
+
# ... other parameters
|
| 133 |
+
}
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
## 🎯 Training
|
| 137 |
+
|
| 138 |
+
### Phase 1: Pre-train MAE
|
| 139 |
+
|
| 140 |
+
```bash
|
| 141 |
+
python trainer/trainer.py
|
| 142 |
+
# When prompted, type: mae
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
The MAE pre-training learns robust feature representations through masked image reconstruction.
|
| 146 |
+
|
| 147 |
+
### Phase 2: Train Classifier
|
| 148 |
+
|
| 149 |
+
```bash
|
| 150 |
+
python trainer/trainer.py
|
| 151 |
+
# When prompted, type: classifier
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
This loads the pre-trained MAE encoder and trains the full classification pipeline.
|
| 155 |
+
|
| 156 |
+
### Training Configuration
|
| 157 |
+
|
| 158 |
+
- **MAE Training**:
|
| 159 |
+
- Batch size: 96
|
| 160 |
+
- Mask ratio: 0.75 (masks 75% of patches)
|
| 161 |
+
- Reconstruction loss on masked patches
|
| 162 |
+
|
| 163 |
+
- **Classifier Training**:
|
| 164 |
+
- Batch size: 36 with gradient accumulation (8 steps)
|
| 165 |
+
- Effective batch size: 288
|
| 166 |
+
- Asymmetric loss with class weights
|
| 167 |
+
- Per-class threshold optimization
|
| 168 |
+
|
| 169 |
+
## 🧪 Testing
|
| 170 |
+
|
| 171 |
+
```python
|
| 172 |
+
from trainer.utils import Trainer
|
| 173 |
+
from configs.configs import config
|
| 174 |
+
|
| 175 |
+
# Initialize trainer
|
| 176 |
+
trainer = Trainer(config)
|
| 177 |
+
|
| 178 |
+
# Run evaluation on test set
|
| 179 |
+
macro_auc, micro_auc, per_class = trainer.test(
|
| 180 |
+
model_path="path/to/checkpoint.pth"
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
print(f"Macro AUC: {macro_auc:.4f}")
|
| 184 |
+
print(f"Micro AUC: {micro_auc:.4f}")
|
| 185 |
+
```
|
| 186 |
+
|
| 187 |
+
## 📁 Project Structure
|
| 188 |
+
|
| 189 |
+
```
|
| 190 |
+
chexpert-mae-densenet-fpn/
|
| 191 |
+
├── configs/
|
| 192 |
+
│ ├── __init__.py
|
| 193 |
+
│ └── configs.py # Configuration parameters
|
| 194 |
+
├── data/
|
| 195 |
+
│ ├── __init__.py
|
| 196 |
+
│ ├── dataset.py # CheXpert dataset with ZIP caching
|
| 197 |
+
│ └── splitter.py # Data splitting utilities
|
| 198 |
+
├── loss/
|
| 199 |
+
│ ├── __init__.py
|
| 200 |
+
│ └── assymetric.py # Asymmetric loss for imbalanced data
|
| 201 |
+
├── models/
|
| 202 |
+
│ ├── __init__.py
|
| 203 |
+
│ ├── mae.py # Masked Autoencoder implementation
|
| 204 |
+
│ ├── densenet.py # DenseNet-169 with CBAM
|
| 205 |
+
│ └── classifier.py # Full classification architecture
|
| 206 |
+
├── trainer/
|
| 207 |
+
│ ├── __init__.py
|
| 208 |
+
│ ├── trainer.py # Main training script
|
| 209 |
+
│ ├── utils.py # Training utilities and loops
|
| 210 |
+
│ └── test.py # Testing utilities
|
| 211 |
+
├── notebooks/
|
| 212 |
+
│ ├── chexpert_mae.ipynb # MAE experiments
|
| 213 |
+
│ └── chexpert_mae_mask_classifier.ipynb # Full pipeline experiments
|
| 214 |
+
├── requirements.txt
|
| 215 |
+
└── README.md
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
## 📈 Model Architecture Details
|
| 219 |
+
|
| 220 |
+
### MAE Encoder
|
| 221 |
+
- **Patch size**: 16×16
|
| 222 |
+
- **Embedding dim**: 768
|
| 223 |
+
- **Depth**: 12 transformer blocks
|
| 224 |
+
- **Heads**: 8 attention heads
|
| 225 |
+
- **MLP ratio**: 4×
|
| 226 |
+
|
| 227 |
+
### DenseNet-169
|
| 228 |
+
- **Growth rate (k)**: 64
|
| 229 |
+
- **Layers**: [6, 12, 24, 16]
|
| 230 |
+
- **CBAM**: Channel + Spatial attention at each stage
|
| 231 |
+
- **Dropout**: Progressive (0.05 → 0.1 → 0.1 → 0.1)
|
| 232 |
+
|
| 233 |
+
### Cross-Attention Fusion
|
| 234 |
+
- **12 bidirectional cross-attention layers**
|
| 235 |
+
- **Projection dim**: 512
|
| 236 |
+
- **Attention heads**: 8
|
| 237 |
+
|
| 238 |
+
### FPN
|
| 239 |
+
- **Feature levels**: P1 (192×192), P2 (96×96), P3 (48×48), P4 (24×24)
|
| 240 |
+
- **Channel unification**: 256 channels per level
|
| 241 |
+
|
| 242 |
+
## 🎓 CheXpert Labels
|
| 243 |
+
|
| 244 |
+
The model predicts 14 pathologies:
|
| 245 |
+
|
| 246 |
+
1. No Finding
|
| 247 |
+
2. Enlarged Cardiomediastinum
|
| 248 |
+
3. Cardiomegaly
|
| 249 |
+
4. Lung Opacity
|
| 250 |
+
5. Lung Lesion
|
| 251 |
+
6. Edema
|
| 252 |
+
7. Consolidation
|
| 253 |
+
8. Pneumonia
|
| 254 |
+
9. Atelectasis
|
| 255 |
+
10. Pneumothorax
|
| 256 |
+
11. Pleural Effusion
|
| 257 |
+
12. Pleural Other
|
| 258 |
+
13. Fracture
|
| 259 |
+
14. Support Devices
|
| 260 |
+
|
| 261 |
+
## 🔬 Data Augmentation
|
| 262 |
+
|
| 263 |
+
Training augmentations (conservative for medical images):
|
| 264 |
+
- Horizontal flip (p=0.5)
|
| 265 |
+
- Random affine (translation, scale, rotation ±10°)
|
| 266 |
+
- Random brightness/contrast
|
| 267 |
+
- CLAHE histogram equalization
|
| 268 |
+
- Gaussian blur and noise
|
| 269 |
+
|
| 270 |
+
## 💾 Checkpoints
|
| 271 |
+
|
| 272 |
+
The training automatically saves:
|
| 273 |
+
- **Best MAE checkpoint**: Based on validation reconstruction loss
|
| 274 |
+
- **Best classifier checkpoint**: Based on validation AUC (macro/micro)
|
| 275 |
+
- **Training history**: JSON file with all metrics
|
| 276 |
+
- **Per-epoch metrics plots**: Loss and AUC curves
|
| 277 |
+
|
| 278 |
+
## 📊 Monitoring
|
| 279 |
+
|
| 280 |
+
Training logs are saved to:
|
| 281 |
+
- `training_log.txt`: Training progress with live metrics
|
| 282 |
+
- `val_log.txt`: Validation results
|
| 283 |
+
- `test_log.txt`: Test evaluation results
|
| 284 |
+
- `history.json`: All metrics across epochs
|
| 285 |
+
- `metrics.png`: Visualization plots
|
| 286 |
+
|
| 287 |
+
## ⚡ Performance Tips
|
| 288 |
+
|
| 289 |
+
1. **Memory Optimization**:
|
| 290 |
+
- Use gradient checkpointing (already enabled)
|
| 291 |
+
- Reduce batch size if OOM occurs
|
| 292 |
+
- Increase gradient accumulation steps
|
| 293 |
+
|
| 294 |
+
2. **Speed Optimization**:
|
| 295 |
+
- Use persistent workers (already enabled)
|
| 296 |
+
- Enable cuDNN benchmark (already enabled)
|
| 297 |
+
- Use ZIP caching for faster data loading
|
| 298 |
+
|
| 299 |
+
3. **Training Stability**:
|
| 300 |
+
- Gradient clipping at norm 1.0
|
| 301 |
+
- Mixed precision with dynamic loss scaling
|
| 302 |
+
- Warmup learning rate schedule
|
| 303 |
+
|
| 304 |
+
## 🐛 Troubleshooting
|
| 305 |
+
|
| 306 |
+
**Q: Out of memory errors?**
|
| 307 |
+
- Reduce batch size in configs.py
|
| 308 |
+
- Increase gradient accumulation steps
|
| 309 |
+
- Enable gradient checkpointing
|
| 310 |
+
|
| 311 |
+
**Q: Slow training?**
|
| 312 |
+
- Check if ZIP caching is enabled
|
| 313 |
+
- Verify persistent workers are active
|
| 314 |
+
- Monitor GPU utilization
|
| 315 |
+
|
| 316 |
+
**Q: Poor convergence?**
|
| 317 |
+
- Ensure MAE is properly pre-trained first
|
| 318 |
+
- Check learning rate and warmup settings
|
| 319 |
+
- Verify class weights are computed correctly
|
| 320 |
+
|
| 321 |
+
## 📚 Citation
|
| 322 |
+
|
| 323 |
+
If you use this code in your research, please cite:
|
| 324 |
+
|
| 325 |
+
```bibtex
|
| 326 |
+
@misc{chexpert-mae-densenet-fpn,
|
| 327 |
+
author = {Adel Elsayed},
|
| 328 |
+
title = {CheXpert Classification with MAE-DenseNet-FPN},
|
| 329 |
+
year = {2025},
|
| 330 |
+
publisher = {GitHub},
|
| 331 |
+
url = {https://github.com/adelelsayed/chexpert-mae-densenet-fpn}
|
| 332 |
+
}
|
| 333 |
+
```
|
| 334 |
+
|
| 335 |
+
## 🙏 Acknowledgments
|
| 336 |
+
|
| 337 |
+
- **CheXpert Dataset**: Stanford ML Group
|
| 338 |
+
- **Masked Autoencoders**: Meta AI Research (He et al., 2021)
|
| 339 |
+
- **DenseNet**: Huang et al., 2017
|
| 340 |
+
- **CBAM**: Woo et al., 2018
|
| 341 |
+
- **Feature Pyramid Networks**: Lin et al., 2017
|
| 342 |
+
|
| 343 |
+
## 📄 License
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
This project is licensed under the MIT License.
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
## 📧 Contact
|
| 350 |
+
|
| 351 |
+
https://www.linkedin.com/in/adel-elsayed-a5260246/
|
| 352 |
+
|
| 353 |
+
**Note**: This is a research project. For clinical use, please ensure proper validation and regulatory approval.
|
configs/__init__.py
ADDED
|
File without changes
|
configs/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (152 Bytes). View file
|
|
|
configs/__pycache__/configs.cpython-313.pyc
ADDED
|
Binary file (4.44 kB). View file
|
|
|
configs/configs.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration constants for the CheXpert MAE/classifier pipeline.

``mae_config`` drives Phase 1 (masked-autoencoder pre-training);
``config`` drives Phase 2 (multi-label classifier training).
Update ``root`` to point at the directory containing CheXpert-v1.0-small.
"""
import os

import torch

# Base directory that contains the CheXpert-v1.0-small folder.
root = "/content/drive/MyDrive"

# Built once so every path below stays consistent with ``root``.
_dataset_dir = os.path.join(root, "CheXpert-v1.0-small")

# Single shared device selection for both phases.
_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---------------------------------------------------------------------------
# Phase 1: MAE pre-training
# ---------------------------------------------------------------------------
mae_config = {
    # Optimisation
    "lr": 1e-4,
    "warmup": 5,
    "weight_decay": 5e-4,
    "num_epochs": 200,
    "num_classes": 14,
    "batch_size": 96,
    "accumulation": 1,          # no gradient accumulation during pre-training
    "device": _device,
    # Paths
    "zip_path": os.path.join(_dataset_dir, "chexpert.zip"),
    "resume": os.path.join(_dataset_dir, "maecheckpoints", "best_mae.pth"),
    "logdir": os.path.join(_dataset_dir, "maelogs"),
    "checkpoints": os.path.join(_dataset_dir, "maecheckpoints"),
    "datadir": root,
    "lmdb": os.path.join(_dataset_dir, "lmdb"),
    "csv": os.path.join(_dataset_dir, "train.csv"),
    "dirsToMake": [
        os.path.join(_dataset_dir, "maecheckpoints"),
        os.path.join(_dataset_dir, "maelogs"),
    ],
    "train_csv": os.path.join(_dataset_dir, "train_ready.csv"),
    "val_csv": os.path.join(_dataset_dir, "val_ready.csv"),
    "test_csv": os.path.join(_dataset_dir, "test_ready.csv"),
    # Model hyper-parameters
    "channels": 1,
    "mask_ratio": 0.75,         # MAE masks 75% of patches during pre-training
    "dropout": 0.25,
    "img_size": 384,
    "encoder_dim": 768,
    "mlp_dim": 3072,
    "decoder_dim": 512,
    "encoder_depth": 12,
    "encoder_head": 8,
    "decoder_depth": 8,
    "decoder_head": 8,
    "patch_size": 16,
}

# ---------------------------------------------------------------------------
# Phase 2: classifier training (loads the pre-trained MAE backbone)
# ---------------------------------------------------------------------------
config = {
    # Optimisation
    "lr": 1e-4,
    "warmup": 10,
    "weight_decay": 5e-4,
    "num_epochs": 200,
    "num_classes": 14,
    "batch_size": 36,
    "accumulation": 8,          # effective batch size: 36 * 8 = 288
    "device": _device,
    # Paths
    "zip_path": os.path.join(_dataset_dir, "chexpert.zip"),
    "backbone": os.path.join(_dataset_dir, "maecheckpoints", "best_mae.pth"),
    "densebackbone": os.path.join(
        _dataset_dir, "checkpoints", "No Eca with masking best_dense.pth"
    ),
    "resume": os.path.join(
        _dataset_dir, "maecheckpoints", "fpn", "best_mae_classifier.pth"
    ),
    "logdir": os.path.join(_dataset_dir, "maelogs", "fpn", "classifier"),
    "checkpoints": os.path.join(_dataset_dir, "maecheckpoints"),
    "datadir": root,
    "lmdb": os.path.join(_dataset_dir, "lmdb"),
    "csv": os.path.join(_dataset_dir, "train.csv"),
    "maskdir": os.path.join(_dataset_dir, "fpn", "mask"),
    "dirsToMake": [
        os.path.join(_dataset_dir, "maecheckpoints", "fpn"),
        os.path.join(_dataset_dir, "maelogs", "fpn", "classifier"),
        os.path.join(_dataset_dir, "fpn", "mask"),
    ],
    "train_csv": os.path.join(_dataset_dir, "train_ready.csv"),
    "val_csv": os.path.join(_dataset_dir, "val_ready.csv"),
    "test_csv": os.path.join(_dataset_dir, "test_ready.csv"),
    # Model hyper-parameters
    "channels": 1,
    "mask_ratio": 0,            # no patch masking when training the classifier
    "dropout": 0.25,
    "img_size": 384,
    "encoder_dim": 768,
    "mlp_dim": 3072,
    "decoder_dim": 512,
    "encoder_depth": 12,
    "encoder_head": 8,
    "decoder_depth": 8,
    "decoder_head": 8,
    "patch_size": 16,
}
|
data/__init__.py
ADDED
|
File without changes
|
data/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (149 Bytes). View file
|
|
|
data/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (151 Bytes). View file
|
|
|
data/__pycache__/dataset.cpython-313.pyc
ADDED
|
Binary file (21.6 kB). View file
|
|
|
data/__pycache__/dataset.cpython-314.pyc
ADDED
|
Binary file (22.3 kB). View file
|
|
|
data/dataset.py
ADDED
|
@@ -0,0 +1,508 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Standard library
|
| 2 |
+
import os
|
| 3 |
+
import io
|
| 4 |
+
import zipfile
|
| 5 |
+
import pickle
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
# Data handling
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
# PyTorch
|
| 13 |
+
import torch
|
| 14 |
+
from torch.utils.data import Dataset
|
| 15 |
+
|
| 16 |
+
# Image processing
|
| 17 |
+
from PIL import Image
|
| 18 |
+
import cv2
|
| 19 |
+
|
| 20 |
+
# Augmentations
|
| 21 |
+
import albumentations as A
|
| 22 |
+
from albumentations.pytorch import ToTensorV2
|
| 23 |
+
|
| 24 |
+
# Progress bar (for precompute_all_masks)
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
|
| 27 |
+
class OptimizedZipReader:
    """
    Fast ZIP file reader with LRU caching of raw image bytes.

    The ZIP handle is opened lazily on first read, so an instance created in
    the parent process can be used inside DataLoader worker processes: each
    worker opens its own handle after fork/spawn instead of sharing one.
    """
    def __init__(self, zip_path, cache_size=1000):
        """
        Args:
            zip_path: Path to ZIP file
            cache_size: Number of images to cache in RAM
        """
        self.zip_path = zip_path
        self.cache_size = cache_size
        self._zip_file = None  # Will be lazily initialized
        self._name_to_info = None  # filename -> ZipInfo, built on first open

        # LRU cache: _cache maps path -> bytes; _cache_order tracks recency,
        # oldest entry first.
        self._cache = {}
        self._cache_order = []
        self._hits = 0
        self._misses = 0

    @property
    def zip_file(self):
        """Lazy initialization of ZIP file handle (builds the name index once)"""
        if self._zip_file is None:
            print(f"Opening ZIP file: {self.zip_path}")
            self._zip_file = zipfile.ZipFile(self.zip_path, 'r', allowZip64=True)

            # Build index on first access
            print("Building ZIP index...")
            self._name_to_info = {
                info.filename: info
                for info in self._zip_file.infolist()
            }
            print(f"✓ Indexed {len(self._name_to_info)} files")

        return self._zip_file

    def read_image(self, path):
        """
        Read image data with automatic caching

        Returns: bytes (image file data)
        """
        # Cache hit: promote the entry to most-recently-used.
        # BUG FIX: the previous version never refreshed recency on a hit, so
        # eviction was FIFO rather than the LRU the class advertises.
        if path in self._cache:
            self._hits += 1
            self._cache_order.remove(path)
            self._cache_order.append(path)
            return self._cache[path]

        # Cache miss - read from ZIP (this triggers lazy initialization)
        self._misses += 1
        img_data = self.zip_file.read(path)  # Uses property getter

        # Evict the least-recently-used entry if the cache is full
        if len(self._cache) >= self.cache_size:
            oldest = self._cache_order.pop(0)
            del self._cache[oldest]

        self._cache[path] = img_data
        self._cache_order.append(path)

        return img_data

    def get_cache_stats(self):
        """Return cache hit rate statistics"""
        total = self._hits + self._misses
        hit_rate = self._hits / total * 100 if total > 0 else 0
        return {
            'hits': self._hits,
            'misses': self._misses,
            'hit_rate': f"{hit_rate:.2f}%",
            'cache_size': len(self._cache)
        }

    def close(self):
        """Close ZIP file and clear cache"""
        if self._zip_file is not None:
            self._zip_file.close()
            self._zip_file = None
        self._cache.clear()
        self._cache_order.clear()
        self._name_to_info = None
|
| 109 |
+
|
| 110 |
+
class CheXpertDataset(Dataset):
    """
    CheXpert chest X-ray dataset for multi-label pathology classification.

    Each sample is a dict with:
        - 'image':    (1, H, W) float tensor, grayscale, normalized with
                      mean=0.5 / std=0.5 (so values roughly in [-1, 1])
        - 'labels':   (14,) float tensor over ``PATHOLOGIES``
        - 'metadata': dict with patient/study IDs, view, sex, age and path

    NOTE(review): an earlier revision advertised 3-channel output
    (img, img*mask, mask); the current code returns the single-channel
    image only — the mask branches are commented out.

    Args:
        csv_path (str): Path to the CSV file (train.csv or valid.csv).
        root_dir (str): Root directory of the CheXpert dataset.
        image_size (int): Target square image size (default: 384).
        augment (bool): Apply training augmentations (default: False).
        use_frontal_only (bool): Keep only frontal-view rows (default: False).
        fill_uncertain (str): Strategy for uncertain (-1) labels:
            'zeros', 'ones' or 'ignore' (default: 'ignore').
        lmdb_path: Unused; LMDB support is currently disabled.
        zip_path (str): Optional ZIP archive to read images from directly.
        zip_cache_size (int): LRU cache size for the lazy ZIP reader.
        mask_dir (str): Directory in which to store precomputed lung masks.
        domask (bool): If True (and mask_dir given), precompute masks on init.
    """

    # The 14 pathology classes labelled in CheXpert, canonical order.
    PATHOLOGIES = [
        'No Finding',
        'Enlarged Cardiomediastinum',
        'Cardiomegaly',
        'Lung Opacity',
        'Lung Lesion',
        'Edema',
        'Consolidation',
        'Pneumonia',
        'Atelectasis',
        'Pneumothorax',
        'Pleural Effusion',
        'Pleural Other',
        'Fracture',
        'Support Devices'
    ]

    def __init__(
        self,
        csv_path,
        root_dir,
        image_size=384,
        augment=False,
        use_frontal_only=False,
        fill_uncertain='ignore',
        lmdb_path=None,
        zip_path=None,
        zip_cache_size=1000,
        mask_dir=None, domask=False
    ):
        self.root_dir = root_dir
        self.image_size = image_size
        self.augment = augment
        self.fill_uncertain = fill_uncertain
        # LMDB support intentionally disabled for now.
        self.env = None  # lmdb.open(lmdb_path, readonly=True, lock=False) if lmdb_path else None
        self._zip_path = zip_path
        self._zip_cache_size = zip_cache_size
        # Created lazily by the ``zip_reader`` property (multiprocessing-safe).
        self._zip_reader_instance = None

        # Read CSV; coerce label columns to numeric (bad cells become NaN).
        self.df = pd.read_csv(csv_path)
        for pathology in self.PATHOLOGIES:
            if pathology in self.df.columns:
                self.df[pathology] = pd.to_numeric(self.df[pathology], errors='coerce')

        # Filter for frontal views only if specified
        if use_frontal_only:
            self.df = self.df[self.df['Frontal/Lateral'] == 'Frontal'].reset_index(drop=True)

        # Handle uncertain labels (-1 values)
        self._process_uncertain_labels()

        # Setup augmentations
        self.train_transform = self._get_train_transforms()
        self.val_transform = self._get_val_transforms()

        print(f"Loaded {len(self.df)} images from {csv_path}")
        print(f"Image size: {image_size}x{image_size}")
        print(f"Augmentation: {augment}")
        print(f"Uncertain labels filled with: {fill_uncertain}")

        if mask_dir and domask:
            self.precompute_all_masks(mask_dir)

    def precompute_all_masks(self, save_dir):
        """Precompute and save a lung mask for every image.

        Run this ONCE before training; masks are written as ``.pt`` files
        named by flattening the last three path components.
        """
        os.makedirs(save_dir, exist_ok=True)
        for idx in tqdm(range(len(self))):
            rel_path = self.df.iloc[idx]['Path']
            img_path = os.path.join(self.root_dir, rel_path)
            # Path inside the ZIP omits the leading dataset folder.
            part_path = "/".join(rel_path.split("/")[1:])
            if self.zip_reader:
                # Read image bytes straight from the ZIP (no extraction).
                img_data = self.zip_reader.read_image(part_path)
                image = Image.open(io.BytesIO(img_data)).convert('L')
            else:
                image = Image.open(img_path).convert('L')

            image = np.array(image)

            mask = chexpert_medsam_mask(image)
            mask_path = os.path.join(
                save_dir,
                "_".join(rel_path.split("/")[-3:]).replace('.jpg', '_mask.pt')
            )
            os.makedirs(os.path.dirname(mask_path), exist_ok=True)
            torch.save(mask, mask_path)

    @property
    def zip_reader(self):
        """
        Lazy property getter for the ZIP reader.

        The ZIP file is only opened on first access, not during __init__.
        Useful when:
        - Creating multiple dataset objects but only using some
        - Saving memory during dataset setup
        - Working with multiprocessing (each worker creates its own)

        Returns None when no ``zip_path`` was configured.
        """
        if self._zip_reader_instance is None and self._zip_path is not None:
            self._zip_reader_instance = OptimizedZipReader(
                self._zip_path,
                cache_size=self._zip_cache_size
            )
        return self._zip_reader_instance

    def _load_and_cache_image(self, img_path, idx):
        """
        Load an image with automatic resizing and on-disk caching.

        If a resized copy (prefixed ``{image_size}_``) already exists it is
        loaded; otherwise the original is resized, saved, and returned.
        NOTE(review): ``cached_path`` is a relative path, so the cache lands
        relative to the current working directory — confirm that is intended.

        Args:
            img_path (str): Original image path from the CSV.
            idx (int): Sample index (currently unused, kept for tracking).

        Returns:
            np.ndarray: Loaded grayscale image.
        """
        cache_dir = Path(self.root_dir)  # / f"cache_{self.image_size}"

        # Preserve the relative path structure, prefixing the file name
        # with the target size so different sizes can coexist.
        path_parts = list(Path(img_path).parts)
        path_parts[-1] = f"{self.image_size}_{path_parts[-1]}"
        relative_path = Path(*path_parts)
        cached_path = relative_path.with_suffix('.jpg')

        # Fast path: cached version exists and has the right size.
        if cached_path.exists():
            image = Image.open(cached_path).convert('L')
            image = np.array(image)
            if image.shape[0] == self.image_size and image.shape[1] == self.image_size:
                return image

        # Cache missing or wrong size - load the original.
        image = Image.open(img_path).convert('L')

        width, height = image.size
        if width == self.image_size and height == self.image_size:
            # Already correct size, just convert to array.
            return np.array(image)

        # Resize with a high-quality filter and persist to the cache.
        image_resized = image.resize(
            (self.image_size, self.image_size),
            Image.LANCZOS
        )
        cached_path.parent.mkdir(parents=True, exist_ok=True)
        image_resized.save(cached_path, 'JPEG', quality=95, optimize=True)

        return np.array(image_resized)

    def _process_uncertain_labels(self):
        """Process uncertain labels (-1) based on the chosen strategy.

        'zeros' maps -1 -> 0, 'ones' maps -1 -> 1, 'ignore' keeps -1
        (the loss function must then mask those entries). NaNs always
        become 0 (negative).
        """
        for pathology in self.PATHOLOGIES:
            if pathology in self.df.columns:
                if self.fill_uncertain == 'zeros':
                    self.df[pathology] = self.df[pathology].replace(-1, 0)
                elif self.fill_uncertain == 'ones':
                    self.df[pathology] = self.df[pathology].replace(-1, 1)
                elif self.fill_uncertain == 'ignore':
                    pass  # handled downstream in the loss function

                self.df[pathology] = self.df[pathology].fillna(0)

    def _get_train_transforms(self):
        """Get training augmentations suitable for chest X-rays."""
        import cv2
        return A.Compose([
            # Resize longest side, then pad to a centered square.
            A.LongestMaxSize(max_size=self.image_size),
            A.PadIfNeeded(self.image_size, self.image_size, border_mode=cv2.BORDER_CONSTANT, position='center'),

            # Geometric augmentations (conservative for medical images)
            A.HorizontalFlip(p=0.5),
            A.Affine(
                translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
                scale=(0.9, 1.1),
                rotate=(-10, 10),
                fit_output=False,
                p=0.5
            ),

            # Intensity augmentations
            A.OneOf([
                A.RandomBrightnessContrast(
                    brightness_limit=0.2,
                    contrast_limit=0.2,
                    p=1.0
                ),
                A.RandomGamma(gamma_limit=(80, 120), p=1.0),
                A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=1.0),
            ], p=0.5),

            # Slight blur to simulate different imaging conditions
            A.OneOf([
                A.GaussianBlur(blur_limit=(3, 5), p=1.0),
                A.MedianBlur(blur_limit=3, p=1.0),
            ], p=0.2),

            # Add noise
            A.GaussNoise(p=0.2),

            # Normalize (mean=0.5, std=0.5 on a 0-255 grayscale input)
            A.Normalize(
                mean=[0.5],
                std=[0.5],
                max_pixel_value=255.0
            ),

            ToTensorV2()
        ])

    def _get_val_transforms(self):
        """Get validation/test transforms (no augmentation)."""
        # BUG FIX: the train variant imports cv2 locally but this method
        # referenced cv2 without any import; without a module-level import
        # this raised NameError. Import it here as well.
        import cv2
        return A.Compose([
            A.LongestMaxSize(max_size=self.image_size),
            A.PadIfNeeded(self.image_size, self.image_size, border_mode=cv2.BORDER_CONSTANT, position='center'),
            A.Normalize(
                mean=[0.5],
                std=[0.5],
                max_pixel_value=255.0
            ),
            ToTensorV2()
        ])

    def __len__(self):
        """Number of samples (CSV rows after filtering)."""
        return len(self.df)

    def __del__(self):
        """Close the ZIP reader on garbage collection, if one was ever opened."""
        # BUG FIX: the original called self.zip_reader.close(). The lazy
        # property returns None when no zip_path was configured (raising
        # AttributeError during GC) and would even OPEN the ZIP just to
        # close it. Touch the raw instance attribute instead.
        reader = getattr(self, '_zip_reader_instance', None)
        if reader is not None:
            reader.close()

    def __getitem__(self, idx):
        """Return one sample as {'image', 'labels', 'metadata'}."""
        if self.env:
            # LMDB fast path (currently disabled; self.env is always None).
            with self.env.begin() as txn:
                data = txn.get(str(idx).encode())
                sample = pickle.loads(data)
                return sample
        else:
            row = self.df.iloc[idx]
            img_path = os.path.join(self.root_dir, row['Path'])

            # Path inside the ZIP omits the leading dataset folder.
            part_path = "/".join(row['Path'].split("/")[1:])
            if self.zip_reader:
                # Read image bytes straight from the ZIP (no extraction).
                img_data = self.zip_reader.read_image(part_path)
                image = Image.open(io.BytesIO(img_data)).convert('L')
            else:
                image = Image.open(img_path).convert('L')

            image = np.array(image)

            # Apply transforms; result is a (1, H, W) normalized tensor.
            if self.augment:
                transformed = self.train_transform(image=image)
            else:
                transformed = self.val_transform(image=image)
            image_1ch = transformed['image']

            # Collect labels for all 14 pathologies (missing columns -> 0.0).
            labels = []
            for pathology in self.PATHOLOGIES:
                if pathology in self.df.columns:
                    label = row[pathology]
                    labels.append(float(label) if not pd.isna(label) else 0.0)
                else:
                    labels.append(0.0)

            labels = torch.tensor(labels, dtype=torch.float32)

            # Additional metadata. NOTE(review): the [2]/[3] indices assume
            # paths like "CheXpert-v1.0-small/train/patientX/studyY/view.jpg"
            # — confirm against the CSV layout.
            metadata = {
                'patient_id': row['Path'].split('/')[2],
                'study_id': row['Path'].split('/')[3],
                'view': row['Frontal/Lateral'],
                'sex': row['Sex'] if 'Sex' in self.df.columns else 'Unknown',
                'age': row['Age'] if 'Age' in self.df.columns else -1,
                'path': row['Path']
            }

            return {
                'image': image_1ch,
                'labels': labels,
                'metadata': metadata
            }

    def get_label_names(self):
        """Return list of pathology label names."""
        return self.PATHOLOGIES

    def get_label_distribution(self):
        """Get count and percentage of positive labels for each pathology."""
        distribution = {}
        for pathology in self.PATHOLOGIES:
            if pathology in self.df.columns:
                positive_count = (self.df[pathology] == 1.0).sum()
                distribution[pathology] = {
                    'positive': int(positive_count),
                    'percentage': round(positive_count / len(self.df) * 100, 2)
                }
        return distribution

    def get_class_weights(self):
        """
        Vectorized per-class positive weights (neg/pos ratio).

        NOTE(review): classes missing from the CSV are skipped, so the
        returned length can be < 14 — callers must align indices.
        """
        weights = []
        for pathology in self.PATHOLOGIES:
            if pathology in self.df.columns:
                values = self.df[pathology].values
                pos = np.sum(values == 1.0)
                neg = np.sum(values == 0.0)
                # Classes with no positives fall back to weight 1.0.
                weight = neg / pos if pos > 0 else 1.0
                weights.append(weight)
        return torch.tensor(weights, dtype=torch.float32)

    def get_sample_weights(self):
        """
        Vectorized per-sample weights for a WeightedRandomSampler.

        Each sample gets the maximum class weight among its positive
        labels; samples with no positive label get weight 1.0.
        Roughly 1000x faster than a per-row Python loop.
        """
        class_weights = self.get_class_weights().numpy()

        # All labels as one (n_samples, n_classes) float matrix.
        labels_array = self.df[self.PATHOLOGIES].values.astype(np.float32)

        # Where label==1 use the class weight; -inf elsewhere so the row
        # max only considers positive labels.
        weighted_labels = np.where(
            labels_array == 1.0,
            class_weights,
            -np.inf
        )

        sample_weights = np.max(weighted_labels, axis=1)
        sample_weights = np.where(
            np.isinf(sample_weights),
            1.0,  # samples with no positive labels get weight 1.0
            sample_weights
        )

        return torch.tensor(sample_weights, dtype=torch.float32)
|
data/splitter.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Standard library
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
# Data handling
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
# Machine learning
|
| 10 |
+
from sklearn.model_selection import train_test_split
|
| 11 |
+
|
| 12 |
+
class CheXpertDataSplitter:
    """
    Advanced stratified train/val/test splitter for the CheXpert dataset.

    Handles:
    - Patient-level splitting (prevents data leakage)
    - Multi-label stratification via per-patient label signatures
    - Class imbalance reporting
    - Study-level grouping
    """

    # The 14 pathology classes labelled in CheXpert, canonical order.
    PATHOLOGIES = [
        'No Finding',
        'Enlarged Cardiomediastinum',
        'Cardiomegaly',
        'Lung Opacity',
        'Lung Lesion',
        'Edema',
        'Consolidation',
        'Pneumonia',
        'Atelectasis',
        'Pneumothorax',
        'Pleural Effusion',
        'Pleural Other',
        'Fracture',
        'Support Devices'
    ]

    def __init__(self, csv_path, val_size=0.15, test_size=0.15, random_state=42,
                 use_frontal_only=True, fill_uncertain='zeros', root=None):
        """
        Initialize the splitter.

        Args:
            csv_path: Path to train.csv from CheXpert-small.
            val_size: Validation set proportion (default: 0.15).
            test_size: Test set proportion (default: 0.15).
            random_state: Random seed for reproducibility.
            use_frontal_only: Use only frontal view images.
            fill_uncertain: How to handle uncertain labels
                ('zeros', 'ones', 'ignore').
            root: Optional dataset root directory (stored, not used here).
        """
        self.csv_path = csv_path
        self.val_size = val_size
        self.test_size = test_size
        self.random_state = random_state
        self.use_frontal_only = use_frontal_only
        self.fill_uncertain = fill_uncertain
        self.root = root

        print("=" * 80)
        print("CheXpert Data Splitter - Preventing Data Leakage & Class Bias")
        print("=" * 80)

    def load_and_preprocess(self):
        """Load the CSV, optionally filter to frontal views, derive IDs."""
        print("\n[1/5] Loading data...")
        self.df = pd.read_csv(self.csv_path)
        print(f"  Loaded {len(self.df)} images")

        # self.df = self.df[self.df["Path"].apply(os.path.exists)]

        if self.use_frontal_only:
            initial_count = len(self.df)
            self.df = self.df[self.df['Frontal/Lateral'] == 'Frontal'].reset_index(drop=True)
            print(f"  Filtered to frontal views: {len(self.df)} images ({initial_count - len(self.df)} removed)")

        # NOTE(review): the [2]/[3] indices assume paths like
        # "CheXpert-v1.0-small/train/patientX/studyY/view.jpg".
        print("\n[2/5] Extracting patient and study IDs...")
        self.df['patient_id'] = self.df['Path'].apply(lambda x: x.split('/')[2])
        self.df['study_id'] = self.df['Path'].apply(lambda x: x.split('/')[3])

        n_patients = self.df['patient_id'].nunique()
        n_studies = self.df['study_id'].nunique()
        print(f"  Unique patients: {n_patients}")
        print(f"  Unique studies: {n_studies}")
        print(f"  Images per patient (avg): {len(self.df) / n_patients:.2f}")

        print("\n[3/5] Processing uncertain labels...")
        self._process_uncertain_labels()

        return self.df

    def _process_uncertain_labels(self):
        """Map uncertain labels (-1) per the chosen strategy; NaN -> 0."""
        for pathology in self.PATHOLOGIES:
            if pathology in self.df.columns:
                uncertain_count = (self.df[pathology] == -1).sum()

                if self.fill_uncertain == 'zeros':
                    self.df[pathology] = self.df[pathology].replace(-1, 0)
                elif self.fill_uncertain == 'ones':
                    self.df[pathology] = self.df[pathology].replace(-1, 1)
                elif self.fill_uncertain == 'ignore':
                    pass  # Keep -1 as is

                self.df[pathology] = self.df[pathology].fillna(0)

        print(f"  Uncertain labels strategy: {self.fill_uncertain}")

    def create_stratification_groups(self):
        """
        Build patient-level stratification groups from multi-label
        combinations; rare combinations are merged into one bucket so
        stratified splitting does not fail on singleton classes.
        """
        print("\n[4/5] Creating stratification groups (patient-level)...")

        # Aggregate to one row per patient (a pathology is positive if it
        # is positive in ANY of the patient's images).
        patient_groups = self.df.groupby('patient_id').agg({
            **{pathology: 'max' for pathology in self.PATHOLOGIES if pathology in self.df.columns},
            'study_id': 'first',  # keep one study_id for reference
            'Sex': 'first',
            'Age': 'first'
        }).reset_index()

        # Binary-string signature of which conditions are present.
        def create_label_signature(row):
            signature = []
            for pathology in self.PATHOLOGIES:
                if pathology in patient_groups.columns:
                    signature.append(str(int(row[pathology])))
            return ''.join(signature)

        patient_groups['label_signature'] = patient_groups.apply(create_label_signature, axis=1)

        # Merge rare combinations so every stratum has enough members.
        signature_counts = patient_groups['label_signature'].value_counts()
        rare_threshold = max(5, int(len(patient_groups) * 0.001))  # at least 5 or 0.1%

        def get_stratification_group(signature):
            if signature_counts[signature] < rare_threshold:
                return 'RARE_COMBINATION'
            return signature

        patient_groups['stratification_group'] = patient_groups['label_signature'].apply(get_stratification_group)

        print(f"\n  Patient-level label distribution:")
        for pathology in self.PATHOLOGIES:
            if pathology in patient_groups.columns:
                positive_count = (patient_groups[pathology] == 1).sum()
                percentage = positive_count / len(patient_groups) * 100
                print(f"    {pathology:30s}: {positive_count:5d} ({percentage:5.2f}%)")

        unique_groups = patient_groups['stratification_group'].nunique()
        print(f"\n  Unique stratification groups: {unique_groups}")
        print(f"  Rare combinations grouped: {(patient_groups['stratification_group'] == 'RARE_COMBINATION').sum()}")

        return patient_groups

    def perform_split(self, patient_groups):
        """
        Stratified patient-level train/val/test split.

        Raises:
            ValueError: if any patient appears in more than one split.
        """
        print("\n[5/5] Performing stratified patient-level split...")

        stratification_labels = patient_groups['stratification_group'].values

        # ---- train / (val+test) ----
        train_patients, valtest_patients = train_test_split(
            patient_groups['patient_id'].values,
            test_size=self.val_size + self.test_size,
            stratify=stratification_labels,
            random_state=self.random_state
        )

        # ---- val / test from the remaining pool ----
        remaining_labels = patient_groups.set_index('patient_id').loc[valtest_patients]['stratification_group'].values
        val_patients, test_patients = train_test_split(
            valtest_patients,
            # proportion of the val+test pool, not of the whole dataset
            test_size=self.test_size / (self.val_size + self.test_size),
            stratify=remaining_labels,
            random_state=self.random_state
        )

        print(f"  Train patients: {len(train_patients)}")
        print(f"  Val patients: {len(val_patients)}")
        print(f"  Test patients: {len(test_patients)}")

        # Split the full (image-level) dataframe by patient membership.
        train_df = self.df[self.df['patient_id'].isin(train_patients)].copy()
        val_df = self.df[self.df['patient_id'].isin(val_patients)].copy()
        test_df = self.df[self.df['patient_id'].isin(test_patients)].copy()

        # ---- leakage check: no patient may appear in two splits ----
        sets = [('train', train_df), ('val', val_df), ('test', test_df)]
        for i, (name_i, df_i) in enumerate(sets):
            for name_j, df_j in sets[i + 1:]:
                overlap = set(df_i['patient_id']) & set(df_j['patient_id'])
                if overlap:
                    raise ValueError(f"Data leakage between {name_i} and {name_j}: {len(overlap)} patients overlap")
        print("\n  No patient overlap – leakage prevented!")

        return train_df, val_df, test_df

    def run(self, output_dir='.', save_test=True):
        """End-to-end pipeline: load, stratify, split, verify, save."""
        self.load_and_preprocess()
        patient_groups = self.create_stratification_groups()
        train_df, val_df, test_df = self.perform_split(patient_groups)

        self.verify_split_quality(train_df, val_df)
        print("\n--- Train vs Test distribution check ---")
        self.verify_split_quality(train_df, test_df)

        train_path, val_path = self.save_splits(train_df, val_df, output_dir)
        if save_test:
            test_path = self.save_test_split(test_df, output_dir)
        else:
            test_path = None

        print("\n" + "=" * 80)
        print("Split Complete! (train / val / test)")
        print("=" * 80)
        return train_path, val_path, test_path

    def save_test_split(self, test_df, output_dir):
        """Write the test split (minus helper columns) to test_ready.csv."""
        output_dir = Path(output_dir)
        output_dir.mkdir(exist_ok=True)
        test_path = output_dir / 'test_ready.csv'

        cols_to_drop = ['patient_id', 'study_id']
        test_clean = test_df.drop(columns=[c for c in cols_to_drop if c in test_df.columns])
        test_clean.to_csv(test_path, index=False)

        print(f"Test set : {test_path} ({len(test_clean)} images)")
        return test_path

    def verify_split_quality(self, train_df, val_df):
        """
        Compare per-pathology positive rates between two splits and
        report the worst gap plus per-class imbalance ratios.
        """
        print("\n" + "=" * 80)
        print("Split Quality Verification")
        print("=" * 80)

        print(f"\n{'Pathology':<30s} {'Train %':>10s} {'Val %':>10s} {'Difference':>12s}")
        print("-" * 80)

        max_diff = 0
        for pathology in self.PATHOLOGIES:
            if pathology in train_df.columns:
                train_pos = (train_df[pathology] == 1).sum() / len(train_df) * 100
                val_pos = (val_df[pathology] == 1).sum() / len(val_df) * 100
                diff = abs(train_pos - val_pos)
                max_diff = max(max_diff, diff)

                print(f"{pathology:<30s} {train_pos:>9.2f}% {val_pos:>9.2f}% {diff:>11.2f}%")

        print("-" * 80)
        print(f"Maximum distribution difference: {max_diff:.2f}%")

        if max_diff < 2.0:
            print("✓ Excellent stratification (< 2% difference)")
        elif max_diff < 5.0:
            print("✓ Good stratification (< 5% difference)")
        else:
            print("⚠ Warning: Large distribution differences detected")

        print("\n" + "=" * 80)
        print("Class Imbalance Analysis (Train Set)")
        print("=" * 80)

        imbalance_ratios = []
        for pathology in self.PATHOLOGIES:
            if pathology in train_df.columns:
                pos = (train_df[pathology] == 1).sum()
                neg = (train_df[pathology] == 0).sum()
                if pos > 0:
                    ratio = neg / pos
                    imbalance_ratios.append(ratio)
                    severity = "Low" if ratio < 5 else "Medium" if ratio < 20 else "High"
                    print(f"{pathology:<30s} Ratio: {ratio:>6.2f}:1 [{severity:>6s} imbalance]")

        # ROBUSTNESS: np.mean([]) warns and yields NaN; only report when
        # at least one class had positives.
        if imbalance_ratios:
            avg_imbalance = np.mean(imbalance_ratios)
            print(f"\nAverage imbalance ratio: {avg_imbalance:.2f}:1")

    def save_splits(self, train_df, val_df, output_dir='.'):
        """Save train and validation splits to CSV files."""
        output_dir = Path(output_dir)
        output_dir.mkdir(exist_ok=True)

        train_path = output_dir / 'train_ready.csv'
        val_path = output_dir / 'val_ready.csv'

        # Remove temporary columns used for splitting.
        columns_to_drop = ['patient_id', 'study_id']
        train_df_clean = train_df.drop(columns=[col for col in columns_to_drop if col in train_df.columns])
        val_df_clean = val_df.drop(columns=[col for col in columns_to_drop if col in val_df.columns])

        train_df_clean.to_csv(train_path, index=False)
        val_df_clean.to_csv(val_path, index=False)

        print("\n" + "=" * 80)
        print("Files Saved Successfully")
        print("=" * 80)
        print(f"Train set: {train_path} ({len(train_df_clean)} images)")
        print(f"Val set:   {val_path} ({len(val_df_clean)} images)")

        return train_path, val_path
|
| 314 |
+
|
| 315 |
+
# Main execution
|
| 316 |
+
if __name__ == "__main__":
    root = "/content/drive/MyDrive"
    # Configuration
    CHEXPERT_CSV = os.path.join(root, "CheXpert-v1.0-small", "train.csv")  # Adjust path as needed
    OUTPUT_DIR = os.path.join(root, "CheXpert-v1.0-small")
    VAL_SIZE = 0.15
    RANDOM_STATE = 42
    USE_FRONTAL_ONLY = True
    FILL_UNCERTAIN = 'zeros'  # Options: 'zeros', 'ones', 'ignore'

    # Create splitter
    splitter = CheXpertDataSplitter(
        csv_path=CHEXPERT_CSV,
        val_size=VAL_SIZE, test_size=VAL_SIZE,
        random_state=RANDOM_STATE,
        use_frontal_only=USE_FRONTAL_ONLY,
        fill_uncertain=FILL_UNCERTAIN,
        root=OUTPUT_DIR
    )

    # Reuse existing splits only when ALL three files are present.
    # BUG FIX: the original checked only train_ready.csv and val_ready.csv,
    # then pointed test_path at a possibly nonexistent test_ready.csv.
    split_paths = {
        name: os.path.join(OUTPUT_DIR, f"{name}_ready.csv")
        for name in ("train", "val", "test")
    }
    if all(os.path.exists(p) for p in split_paths.values()):
        train_path = split_paths["train"]
        val_path = split_paths["val"]
        test_path = split_paths["test"]
    else:
        train_path, val_path, test_path = splitter.run(output_dir=OUTPUT_DIR)

    print("\nYou can now use these files with your CheXpertDataset class:")
    print(f"  train_dataset = CheXpertDataset('{train_path}', root_dir='...', augment=True)")
    print(f"  val_dataset = CheXpertDataset('{val_path}', root_dir='...', augment=False)")
    print(f"  test_dataset = CheXpertDataset('{test_path}', root_dir='...', augment=False)")
|
gitignore.txt
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Git ignore rules for this project
|
| 2 |
+
# Python
|
| 3 |
+
__pycache__/
|
| 4 |
+
*.py[cod]
|
| 5 |
+
*$py.class
|
| 6 |
+
*.so
|
| 7 |
+
.Python
|
| 8 |
+
env/
|
| 9 |
+
venv/
|
| 10 |
+
ENV/
|
| 11 |
+
.venv
|
| 12 |
+
|
| 13 |
+
# Jupyter Notebook
|
| 14 |
+
.ipynb_checkpoints
|
| 15 |
+
*.ipynb_checkpoints/
|
| 16 |
+
|
| 17 |
+
# PyTorch
|
| 18 |
+
*.ckpt
|
| 19 |
+
*.pth
|
| 20 |
+
weights/
|
| 21 |
+
runs/
|
| 22 |
+
lightning_logs/
|
| 23 |
+
|
| 24 |
+
# Data files (usually too large for GitHub)
|
| 25 |
+
*.csv
|
| 26 |
+
*.h5
|
| 27 |
+
*.hdf5
|
| 28 |
+
*.npy
|
| 29 |
+
*.npz
|
| 30 |
+
*.pkl
|
| 31 |
+
*.pickle
|
| 32 |
+
*.dcm
|
| 33 |
+
*.nii
|
| 34 |
+
*.nii.gz
|
| 35 |
+
|
| 36 |
+
# Models (often too large)
|
| 37 |
+
*.h5
|
| 38 |
+
*.pb
|
| 39 |
+
*.onnx
|
| 40 |
+
saved_models/
|
| 41 |
+
|
| 42 |
+
# IDE
|
| 43 |
+
.vscode/
|
| 44 |
+
.idea/
|
| 45 |
+
*.swp
|
| 46 |
+
*.swo
|
| 47 |
+
|
| 48 |
+
# OS
|
| 49 |
+
.DS_Store
|
| 50 |
+
Thumbs.db
|
| 51 |
+
|
| 52 |
+
# Environment variables
|
| 53 |
+
.env
|
| 54 |
+
.env.local
|
| 55 |
+
|
| 56 |
+
# Logs
|
| 57 |
+
*.log
|
| 58 |
+
logs/
|
| 59 |
+
|
| 60 |
+
# Weights & Biases (if you use it)
|
| 61 |
+
wandb/
|
loss/__init__.py
ADDED
|
File without changes
|
loss/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (149 Bytes). View file
|
|
|
loss/__pycache__/assymetric.cpython-313.pyc
ADDED
|
Binary file (2.98 kB). View file
|
|
|
loss/assymetric.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class AsymmetricLoss(nn.Module):
    """Asymmetric focal loss (ASL) for multi-label classification.

    Negatives receive a stronger focusing exponent (gamma_neg) than
    positives (gamma_pos), and the positive term uses probability
    shifting ('clip') to down-weight easy samples.

    Args:
        gamma_neg: focusing exponent applied to negative targets.
        gamma_pos: focusing exponent applied to positive targets.
        clip: probability margin subtracted on the positive term.
        eps: numerical-stability epsilon for clamping and log().
        class_weights: optional (num_classes,) per-class weight tensor;
            registered as a buffer so it follows the module's device.
    """

    def __init__(self, gamma_neg=2, gamma_pos=1, clip=0.05, eps=1e-8, class_weights=None):
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.eps = eps
        if class_weights is not None:
            self.register_buffer('class_weights', class_weights)
        else:
            self.class_weights = None

    def forward(self, predictions, targets):
        """Compute the scalar ASL loss.

        Args:
            predictions: (B, C) sigmoid outputs (probabilities, already in [0, 1]).
            targets: (B, C) binary labels.

        Returns:
            Scalar loss tensor. On NaN/Inf, logs diagnostics and returns a
            zero loss so training can continue (note: the zero tensor is
            detached from the graph, so that batch contributes no gradient).
        """
        try:
            # Clamp away from exact 0/1 so the logs below stay finite.
            predictions = torch.clamp(predictions, min=self.eps, max=1 - self.eps)

            # ===== POSITIVE SAMPLES =====
            # Probability shifting: p -> max(p - clip, eps).
            predictions_pos = torch.clamp(predictions - self.clip, min=self.eps)
            focal_weight_pos = (1 - predictions_pos) ** self.gamma_pos
            loss_pos = targets * focal_weight_pos * torch.log(predictions_pos + self.eps)

            # ===== NEGATIVE SAMPLES =====
            focal_weight_neg = predictions ** self.gamma_neg
            loss_neg = (1 - targets) * focal_weight_neg * torch.log(1 - predictions + self.eps)

            # ===== COMBINE =====
            loss = -(loss_pos + loss_neg)

            # Optional per-class weighting.
            if self.class_weights is not None:
                loss = loss * self.class_weights

            # Average across batch and classes.
            loss = torch.mean(loss)

            if torch.isnan(loss) or torch.isinf(loss):
                raise ValueError("Loss is NaN or Inf")
            # BUG FIX: the original forward never returned on the success
            # path, so it returned None for every normal batch.
            return loss
        except ValueError as e:
            print("⚠️ WARNING: NaN/Inf detected in loss, returning safe value")
            print(e)
            print("predictions:", predictions)
            print("targets:", targets)
            import traceback
            traceback.print_exc()
            return torch.tensor(0.0, device=predictions.device, requires_grad=True)
|
models/__init__.py
ADDED
|
File without changes
|
models/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (151 Bytes). View file
|
|
|
models/__pycache__/classifier.cpython-313.pyc
ADDED
|
Binary file (17.6 kB). View file
|
|
|
models/__pycache__/densenet.cpython-313.pyc
ADDED
|
Binary file (13.4 kB). View file
|
|
|
models/__pycache__/mae.cpython-313.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
models/classifier.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
from models.mae import MaskedAutoEncoder
|
| 7 |
+
from models.densenet import DenseNet
|
| 8 |
+
|
| 9 |
+
class AttentionPool(nn.Module):
    """Pool a token sequence into one vector via a single learned query.

    The learnable query attends over all input tokens; the attended
    vector is then projected from `dim` up to `embed_dim`.
    """

    def __init__(self, dim=768, embed_dim=2048, num_heads=8):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, dim))
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
        self.proj = nn.Linear(dim, embed_dim)

    def forward(self, x):
        """x: (B, N, dim) -> (B, embed_dim)."""
        batch = x.size(0)
        # Broadcast the single learned query across the batch.
        queries = self.query.expand(batch, -1, -1)
        pooled, _ = self.attn(queries, x, x)
        # Drop the singleton query axis, then project up.
        return self.proj(pooled[:, 0])
|
| 21 |
+
|
| 22 |
+
class CrossAttentionBlock(nn.Module):
    """
    Cross-attention block: query tokens attend to key/value tokens from
    another modality, with a residual connection around the attended output.
    """

    def __init__(self, dim_q, dim_kv, num_heads=8, dropout=0.1, proj_dim=None):
        super().__init__()
        self.proj_dim = proj_dim or dim_q
        self.num_heads = num_heads
        self.head_dim = self.proj_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(dim_q, self.proj_dim)
        self.k_proj = nn.Linear(dim_kv, self.proj_dim)
        self.v_proj = nn.Linear(dim_kv, self.proj_dim)
        self.out_proj = nn.Linear(self.proj_dim, dim_q)

        self.dropout = nn.Dropout(dropout)
        self.norm_q = nn.LayerNorm(dim_q)
        self.norm_kv = nn.LayerNorm(dim_kv)

    def forward(self, query, key_value):
        """query: (B, N_q, dim_q); key_value: (B, N_kv, dim_kv) -> (B, N_q, dim_q)."""
        batch, n_queries, _ = query.shape
        n_keys = key_value.shape[1]

        normed_q = self.norm_q(query)
        normed_kv = self.norm_kv(key_value)

        def split_heads(t, n_tokens):
            # (B, N, proj_dim) -> (B, heads, N, head_dim)
            return t.view(batch, n_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        Q = split_heads(self.q_proj(normed_q), n_queries)
        K = split_heads(self.k_proj(normed_kv), n_keys)
        V = split_heads(self.v_proj(normed_kv), n_keys)

        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        weights = self.dropout(F.softmax(scores, dim=-1))

        attended = torch.matmul(weights, V).transpose(1, 2).reshape(batch, n_queries, self.proj_dim)
        attended = self.out_proj(attended)

        # Residual: add the (dropped-out) attended output onto the raw query.
        return query + self.dropout(attended)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class BidirectionalCrossAttention(nn.Module):
    """
    Symmetric fusion layer: the MAE token stream attends to the DenseNet
    token stream AND vice versa, each followed by a residual FFN.
    """

    def __init__(self, mae_dim=768, dense_dim=2048, num_heads=8, dropout=0.1, proj_dim=512):
        super().__init__()

        # MAE tokens query DenseNet tokens.
        self.mae_cross = CrossAttentionBlock(mae_dim, dense_dim, num_heads, dropout, proj_dim)
        # DenseNet tokens query MAE tokens.
        self.dense_cross = CrossAttentionBlock(dense_dim, mae_dim, num_heads, dropout, proj_dim)

        def make_ffn(width, expansion):
            # Pre-norm MLP; expansion differs per branch to bound parameter count.
            return nn.Sequential(
                nn.LayerNorm(width),
                nn.Linear(width, width * expansion),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(width * expansion, width),
                nn.Dropout(dropout),
            )

        self.mae_ffn = make_ffn(mae_dim, 4)
        self.dense_ffn = make_ffn(dense_dim, 2)

    def forward(self, mae_tokens, dense_tokens):
        """Return the two refined token streams (same shapes as the inputs)."""
        refined_mae = self.mae_cross(mae_tokens, dense_tokens)
        refined_dense = self.dense_cross(dense_tokens, mae_tokens)

        # Feed-forward with residual connections.
        refined_mae = refined_mae + self.mae_ffn(refined_mae)
        refined_dense = refined_dense + self.dense_ffn(refined_dense)

        return refined_mae, refined_dense
|
| 103 |
+
class LearnedLogitEnsemble(nn.Module):
    """Fuse per-head logits via learned temperatures and learned head weights.

    Each head's logits are temperature-scaled (one learned temperature per
    head), then combined as a convex combination in logit space. The weights
    come either from a small gating MLP conditioned on the concatenated
    logits (use_gate=True) or from a fixed learned weight vector.

    Args:
        num_heads: number of classifier heads being fused.
        num_classes: classes per head.
        temperature_init: initial temperature for every head.
        use_gate: if True, predict per-sample head weights with a gating MLP.
    """

    def __init__(self, num_heads=7, num_classes=14, temperature_init=1.0, use_gate=False):
        super().__init__()
        self.num_classes = num_classes
        self.num_heads = num_heads
        self.use_gate = use_gate

        # Per-head log-temperature; exp() keeps the effective temperature positive.
        self.log_temps = nn.Parameter(torch.ones(num_heads) * math.log(temperature_init))

        if use_gate:
            # Sample-wise gating: concatenated scaled logits -> soft head weights.
            self.gate = nn.Sequential(
                nn.Linear(num_classes * num_heads, 256),
                nn.GELU(),
                nn.LayerNorm(256),
                nn.Dropout(0.1),
                nn.Linear(256, num_heads),
            )
        else:
            # Fixed learned weights shared across samples.
            self.raw_weights = nn.Parameter(torch.ones(num_heads))

    def forward(self, logits_list):
        """
        Args:
            logits_list: sequence of num_heads tensors, each (B, num_classes).
        Returns:
            Fused logits of shape (B, num_classes).
        """
        # Step 1: per-head temperature scaling.
        scaled_logits = [
            logits / (torch.exp(self.log_temps[i]) + 1e-8)
            for i, logits in enumerate(logits_list)
        ]

        # (B, num_heads, num_classes)
        stacked = torch.stack(scaled_logits, dim=1)

        if self.use_gate:
            # Step 2a: dynamic per-sample gating.
            raw_gate = self.gate(stacked.flatten(1))                 # (B, H)
            weights = torch.softmax(raw_gate, dim=-1).unsqueeze(-1)  # (B, H, 1)
        else:
            # Step 2b: fixed learned weights. As a parameter they already live
            # on the module's device, so no per-forward .to() is needed.
            weights = torch.softmax(self.raw_weights, dim=0).view(1, self.num_heads, 1)

        # Step 3: weighted average in logit space.
        return (stacked * weights).sum(dim=1)
|
| 159 |
+
class XRAYClassifier(nn.Module):
    """Dual-branch chest X-ray multi-label classifier.

    Combines a frozen MAE ViT encoder with a trainable DenseNet CNN:
      * per-branch heads (attention-pooled MAE head, DenseNet global head),
      * a stack of bidirectional cross-attention fusion layers -> fused head,
      * an FPN over the DenseNet feature maps with one head per pyramid level,
      * a LearnedLogitEnsemble that fuses all 7 heads into the final logits.
    """

    def __init__(self, num_classes=14, c=1, mask_ratio=0, dropout=0.25, img_size=384,
                 encoder_dim=768, mlp_dim=3072, decoder_dim=512, encoder_depth=12,
                 encoder_head=8, decoder_depth=8, decoder_head=8, patch_size=8):
        super().__init__()

        # ---- MAE branch (frozen) ----
        # mask_ratio is forced to 0 here: the encoder sees every patch at inference.
        self.mae = MaskedAutoEncoder(
            c=c, mask_ratio=0, dropout=dropout, img_size=img_size,
            encoder_dim=encoder_dim, mlp_dim=mlp_dim, decoder_dim=decoder_dim,
            encoder_depth=encoder_depth, encoder_head=encoder_head,
            decoder_depth=decoder_depth, decoder_head=decoder_head, patch_size=patch_size
        )
        for p in self.mae.parameters():
            p.requires_grad = False

        self.token_ln = nn.LayerNorm(encoder_dim)
        self.attn_selfpool_mae = AttentionPool(encoder_dim, 1024)

        # ---- DenseNet branch ----
        # The DenseNet stem expects 2 channels; the single-channel input is
        # duplicated in forward(). Use c=1 there if the backbone supports it.
        self.dense = DenseNet(c=2, k=64, num_classes=num_classes)
        self.dn_feat_dim = 2048

        # ---- Cross-Attention Fusion ----
        self.cross_attn_layers = nn.ModuleList([
            BidirectionalCrossAttention(
                mae_dim=encoder_dim,          # 768
                dense_dim=self.dn_feat_dim,   # 2048
                num_heads=8,
                dropout=0.1,
                proj_dim=512
            )
            for _ in range(12)
        ])

        self.attn_pool_mae = AttentionPool(encoder_dim, 1024)

        # Head over the attention-pooled MAE tokens.
        self.classifier_mae = nn.Sequential(
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        )

        self.attn_pool_dense = AttentionPool(self.dn_feat_dim, 1024)

        # Head over the concatenated cross-attended MAE + DenseNet pools (1024+1024).
        self.classifier_attn = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        )

        # ---- FPN laterals (channel counts match the pre-transition features) ----
        self.lateral5 = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0)  # feat4: 2048
        self.lateral4 = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0)  # feat3: 2048
        self.lateral3 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)  # feat2: 1024
        self.lateral2 = nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0)   # feat1: 512
        self.output5 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.output4 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.output3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.output2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # One classification head per pyramid level.
        self._classify_out5 = nn.Linear(256, num_classes)
        self._classify_out4 = nn.Linear(256, num_classes)
        self._classify_out3 = nn.Linear(256, num_classes)
        self._classify_out2 = nn.Linear(256, num_classes)

        self.learned_logit_ensemble = LearnedLogitEnsemble(num_classes=num_classes)

    def forward(self, x):
        """x: (B, 1, H, W) grayscale images -> (B, num_classes) fused logits."""
        # ---- MAE tokens (frozen encoder) ----
        mae_tokens, _, _, _ = self.mae.encoder(x)
        mae_tokens = self.token_ln(mae_tokens)

        # Duplicate the grayscale channel for the 2-channel DenseNet stem.
        doublex = torch.cat([x, x], dim=1)

        # ---- DenseNet path: multi-scale features (ECA/CBAM before transitions) ----
        xdense = self.dense.initialconv(doublex)

        feat1 = self.dense.layer1(xdense)
        feat1 = self.dense.dropout1(feat1)
        feat1 = self.dense.eca1(feat1)
        xdense1 = self.dense.trans1(feat1)

        feat2 = self.dense.layer2(xdense1)
        feat2 = self.dense.dropout2(feat2)
        feat2 = self.dense.eca2(feat2)
        xdense2 = self.dense.trans2(feat2)

        feat3 = self.dense.layer3(xdense2)
        feat3 = self.dense.dropout3(feat3)
        feat3 = self.dense.eca3(feat3)
        xdense3 = self.dense.trans3(feat3)

        feat4 = self.dense.layer4(xdense3)
        feat4 = self.dense.dropout4(feat4)
        feat4 = self.dense.eca4(feat4)
        xdense4 = feat4

        # DenseNet global head.
        xdense_pooled = self.dense.global_average_pool(xdense4)
        xdense_pooled = xdense_pooled.view(xdense_pooled.size(0), -1)
        xdense_pooled = self.dense.dropout(xdense_pooled)
        classifier_xdense = self.dense.classifier(xdense_pooled)

        # Deepest feature map as a token sequence for cross-attention.
        dense_tokens = xdense4.flatten(2).transpose(1, 2)  # (B, H*W, 2048)

        # ---- FPN top-down pathway ----
        c4 = self.lateral5(feat4)
        c3 = self.lateral4(feat3)
        c2 = self.lateral3(feat2)
        c1 = self.lateral2(feat1)

        p4 = self.output5(c4)
        p3 = self.output4(self.upsample(p4) + c3)
        p2 = self.output3(self.upsample(p3) + c2)
        p1 = self.output2(self.upsample(p2) + c1)

        out4 = self._classify_out5(p4.mean([2, 3]))
        out3 = self._classify_out4(p3.mean([2, 3]))
        out2 = self._classify_out3(p2.mean([2, 3]))
        out1 = self._classify_out2(p1.mean([2, 3]))

        # ---- MAE head ----
        mae_tokens_pooled = self.attn_selfpool_mae(mae_tokens)
        classifier_mae = self.classifier_mae(mae_tokens_pooled)

        # ---- Cross attention ----
        # BUG FIX: chain the 12 fusion layers. The original loop passed the
        # unmodified (mae_tokens, dense_tokens) into every layer, so only the
        # last layer's output was used and layers 0-10 were dead weight that
        # never received gradient.
        mae_cross, dense_cross = mae_tokens, dense_tokens
        for cross_layer in self.cross_attn_layers:
            mae_cross, dense_cross = cross_layer(mae_cross, dense_cross)

        mae_cross = self.attn_pool_mae(mae_cross)
        dense_cross = self.attn_pool_dense(dense_cross)
        classifier_attn = self.classifier_attn(torch.cat([mae_cross, dense_cross], dim=1))

        # ---- Ensemble of all 7 heads ----
        return self.learned_logit_ensemble([
            classifier_mae,
            classifier_xdense,
            classifier_attn,
            out4, out3, out2, out1,
        ])
|
models/densenet.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.utils.checkpoint import checkpoint
|
| 4 |
+
|
| 5 |
+
class ChannelAttention(nn.Module):
    """Channel attention: gate channels by a shared bottleneck MLP applied
    to both average- and max-pooled spatial descriptors (CBAM style)."""

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.maxpool = nn.AdaptiveMaxPool2d((1, 1))

    def forward(self, x):
        # The same squeeze-excite MLP processes both pooled descriptors.
        def excite(descriptor):
            return self.conv2(self.relu(self.conv1(descriptor)))

        gate = self.sigmoid(excite(self.avgpool(x)) + excite(self.maxpool(x)))
        return x * gate
|
| 24 |
+
|
| 25 |
+
class SpatialAttention(nn.Module):
    """Spatial attention: gate each location by a 7x7 conv over the
    channel-wise max and mean maps (CBAM style)."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)

    def forward(self, x):
        channel_max = x.max(dim=1, keepdim=True).values
        channel_avg = x.mean(dim=1, keepdim=True)
        pooled = torch.cat([channel_max, channel_avg], dim=1)  # (B, 2, H, W)
        gate = torch.sigmoid(self.conv(pooled))                # (B, 1, H, W)
        return x * gate
|
| 36 |
+
|
| 37 |
+
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention followed by
    spatial attention."""

    def __init__(self, channels):
        super().__init__()
        self.ca = ChannelAttention(channels)
        self.sa = SpatialAttention()

    def forward(self, x):
        # Sequential refinement: channels first, then spatial locations.
        return self.sa(self.ca(x))
|
| 47 |
+
|
| 48 |
+
class InitialConv(nn.Module):
    """Network stem: 7x7 conv (stride 1) -> BN -> ReLU -> 3x3 max-pool
    (stride 2), halving spatial resolution and producing 2*k channels."""

    def __init__(self, input_channel=1, k=64):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=input_channel, out_channels=2 * k,
                              kernel_size=7, stride=1, padding=3)
        self.bn = nn.BatchNorm2d(num_features=2 * k)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        return self.pool(out)
|
| 57 |
+
|
| 58 |
+
class DenseLayer(nn.Module):
    """DenseNet bottleneck layer: BN-ReLU-1x1 (to 4k) then BN-ReLU-3x3
    (to k), with the new features concatenated onto the input channels."""

    def __init__(self, c, k=64):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(num_features=c)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1x1 = nn.Conv2d(c, 4 * k, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(num_features=4 * k)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3x3 = nn.Conv2d(4 * k, k, kernel_size=3, padding=1)

    def forward(self, x):
        new_features = self.conv1x1(self.relu1(self.bn1(x)))
        new_features = self.conv3x3(self.relu2(self.bn2(new_features)))
        # Dense connectivity: append the k new channels to the input.
        return torch.cat([x, new_features], dim=1)
|
| 72 |
+
|
| 73 |
+
class DenseBlock(nn.Module):
    """A stack of DenseLayers; channel count grows by k per layer.
    Each layer is gradient-checkpointed to trade compute for memory."""

    def __init__(self, c, k=64, layer_len=6):
        super().__init__()
        self.blks = nn.ModuleList(
            DenseLayer(c + i * k, k) for i in range(layer_len)
        )

    def forward(self, x):
        for layer in self.blks:
            # Recompute activations during backward instead of storing them.
            x = checkpoint(layer, x, use_reentrant=False)
        return x
|
| 84 |
+
|
| 85 |
+
class Transition(nn.Module):
    """DenseNet transition: BN-ReLU-1x1 conv channel compression (by
    down_factor), then 2x2 average pooling to halve spatial resolution."""

    def __init__(self, inchannels, down_factor=0.5):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features=inchannels)
        self.relu = nn.ReLU(inplace=True)
        self.conv1x1 = nn.Conv2d(in_channels=inchannels,
                                 out_channels=int(down_factor * inchannels),
                                 kernel_size=1)
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        out = self.bn(x)
        out = self.relu(out)
        out = self.conv1x1(out)
        return self.avgpool(out)
|
| 94 |
+
class DenseNet(nn.Module):
    """DenseNet-style backbone with CBAM attention after each dense block.

    Pipeline: stem -> (DenseBlock -> Dropout -> CBAM -> Transition) x3 ->
    DenseBlock -> Dropout -> CBAM -> global average pool -> MLP classifier.
    The shape comments below assume the default c=2, k=64, 384x384 input.
    NOTE(review): the per-block channel counts (128/256/512/1024) are
    hard-coded to match k=64; other k values would break — confirm intended.
    """
    def __init__(self,c=2,k=64,num_classes=14):
        super().__init__()
        self.initialconv=InitialConv(input_channel=c,k=k) #output B,128,192,192
        self.layer1=DenseBlock(c=128,k=k,layer_len=6) #output B,inchannels+(layer_len*k),192,192 i.e # B,512,192,192
        self.dropout1 = nn.Dropout(p=0.05)
        self.eca1=CBAM(512)
        self.trans1=Transition(inchannels=512,down_factor=0.5) #output B,256,96,96
        self.layer2=DenseBlock(c=256,k=k,layer_len=12) #output B,inchannels+(layer_len*k),96,96 i.e # B,1024,96,96
        self.dropout2 = nn.Dropout(p=0.1)
        self.eca2=CBAM(1024)
        self.trans2=Transition(inchannels=1024,down_factor=0.5) #output B,512,48,48
        self.layer3=DenseBlock(c=512,k=k,layer_len=24) #output B,inchannels+(layer_len*k),48,48 i.e # B,2048,48,48
        self.dropout3 = nn.Dropout(p=0.1)
        self.eca3=CBAM(2048)
        self.trans3=Transition(inchannels=2048,down_factor=0.5) #output B,1024,24,24
        self.layer4=DenseBlock(c=1024,k=k,layer_len=16) #output B,inchannels+(layer_len*k),24,24 i.e # B,2048,24,24
        self.dropout4 = nn.Dropout(p=0.1)
        self.eca4=CBAM(2048)
        self.global_average_pool= nn.AdaptiveAvgPool2d((1,1)) #output B,2048,1,1
        # MLP head: 2048 -> 1024 -> 512 -> 256 -> num_classes with BN+ReLU+Dropout.
        self.classifier = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )
        # Dropout applied to the pooled features before the classifier head.
        self.dropout = nn.Dropout(p=0.2)
        # Xavier init for the classifier's Linear layers.
        for lay in self.classifier:
            if isinstance(lay, nn.Linear):
                nn.init.xavier_uniform_(lay.weight, gain=1.0)
                nn.init.constant_(lay.bias, 0.0)
        # Kaiming init for every Conv2d in the network (runs over all submodules).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self,x):
        """x: (B, c, H, W) -> logits (B, num_classes)."""
        x=self.initialconv(x)
        x=self.trans1(self.eca1(self.dropout1(self.layer1(x))))
        x=self.trans2(self.eca2(self.dropout2(self.layer2(x))))
        x=self.trans3(self.eca3(self.dropout3(self.layer3(x))))
        x=self.eca4(self.dropout4(self.layer4(x)))
        #x1=self.attn(x)
        x=self.global_average_pool(x)
        x=x.view(x.size(0),-1)
        #x=torch.cat([x1,x2],dim=1)
        x=self.dropout(x)
        x=self.classifier(x)
        return x

    @staticmethod
    def testme():
        # Smoke test: default model on a 2-sample batch; prints the logit shape.
        model=DenseNet()
        sample=torch.randn(2,2,384,384)
        out=model(sample)
        print(out.shape)
|
models/mae.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
def patchify(x, patch_size=8):
    """Split an image batch into flattened non-overlapping patches.

    Args:
        x: (B, C, H, W) tensor; H and W must be divisible by patch_size.
    Returns:
        (B, num_patches, C * patch_size**2) tensor, patches in row-major order.
    """
    batch, channels, height, width = x.shape
    rows = height // patch_size
    cols = width // patch_size
    assert height % patch_size == 0 and width % patch_size == 0, "Image size must be divisible by patch_size"

    # (B, C, rows, P, cols, P) -> (B, rows, cols, C, P, P) -> flatten per patch.
    patches = x.reshape(batch, channels, rows, patch_size, cols, patch_size)
    patches = patches.permute(0, 2, 4, 1, 3, 5).contiguous()
    return patches.view(batch, rows * cols, channels * (patch_size ** 2))
| 14 |
+
def unpatchify(x, patch_size=8):
    """Inverse of patchify: reassemble flattened patches into images.

    Args:
        x: (B, N, C * patch_size**2) patch tokens; N must be a perfect square.
    Returns:
        (B, C, H, W) tensor with H = W = sqrt(N) * patch_size.
    """
    batch, num_patches, patch_dim = x.shape
    channels = patch_dim // (patch_size ** 2)
    side = int(math.sqrt(num_patches))  # assumes a square patch grid
    height = side * patch_size
    width = side * patch_size

    grid = x.view(batch, side, side, channels, patch_size, patch_size)
    grid = grid.permute(0, 3, 1, 4, 2, 5).contiguous()
    return grid.view(batch, channels, height, width)
| 25 |
+
def random_mask(x, mask_ratio=0.75):
    """Per-sample random masking of patch tokens (MAE style).

    Args:
        x: (B, N, P) patch tokens.
        mask_ratio: fraction of tokens to drop.
    Returns:
        x_masked: (B, len_keep, P) kept tokens.
        mask: (B, N) binary mask in original token order (1 = masked/dropped).
        ids_restore: (B, N) indices that undo the shuffle.
        ids_keep: (B, len_keep) original indices of the kept tokens.
    """
    batch, num_tokens, token_dim = x.shape
    len_keep = int(num_tokens * (1 - mask_ratio))

    # Random per-token scores; argsort yields a random permutation per sample.
    noise = torch.rand(batch, num_tokens).to(x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    ids_keep = ids_shuffle[:, :len_keep]
    gather_index = ids_keep.unsqueeze(-1).expand(-1, -1, token_dim)
    x_masked = torch.gather(x, dim=1, index=gather_index).to(x.device)

    # Build the 0=kept / 1=masked flags in shuffled order, then unshuffle.
    mask = torch.ones(batch, num_tokens).to(x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore).to(x.device)

    return x_masked, mask, ids_restore, ids_keep
|
| 37 |
+
|
| 38 |
+
def mae_loss(pred, target, mask):
    """MSE reconstruction loss over masked patches only.

    Args:
        pred/target: (B, N, P) patch tensors.
        mask: (B, N) with 1 = masked (counted), 0 = visible (ignored).
    Returns:
        Squared error summed over masked entries, normalized by the number
        of masked tokens (clamped to >= 1 to avoid division by zero).
    """
    mask = mask.unsqueeze(-1).float()  # (B, N, 1), broadcasts over patch dim
    squared_err = (pred - target).pow(2)
    return (squared_err * mask).sum() / mask.sum().clamp_min(1.0)
|
| 45 |
+
|
| 46 |
+
class PositionalEncoding(nn.Module):
    """Learnable per-patch positional embeddings gathered for visible patches.

    Holds a (1, num_patches, hidden_dim) table; ``forward`` adds the rows
    belonging to the visible (unmasked) patches to the input tokens.
    """

    def __init__(self, num_patches, hidden_dim=768):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.empty(1, num_patches, hidden_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x, visible_indices):
        """Add positional embeddings for the kept patches.

        Args:
            x: (B, L, D) tokens of the visible patches.
            visible_indices: (B, L) original indices of those patches.

        Returns:
            (B, L, D) tokens with their positional embeddings added.
        """
        batch, kept, _ = x.shape
        table = self.pos_embed.expand(batch, -1, -1)  # (B, N, D)
        gather_idx = visible_indices.unsqueeze(-1).expand(batch, kept, table.size(-1))
        return x + torch.gather(table, 1, gather_idx)
|
| 60 |
+
|
| 61 |
+
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Self-attention and an MLP sub-layer, each applied as
    ``x = x + sublayer(LayerNorm(x))``.
    """

    def __init__(self, hidden_dim, mlp_dim, num_heads, dropout):
        super().__init__()
        self.layernorm1 = nn.LayerNorm(hidden_dim)
        self.multihead = nn.MultiheadAttention(
            batch_first=True, embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout
        )
        self.layernorm2 = nn.LayerNorm(hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, hidden_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Attention sub-layer (pre-norm + residual).
        normed = self.layernorm1(x)
        attn_out, _ = self.multihead(normed, normed, normed)
        x = x + attn_out
        # Feed-forward sub-layer (pre-norm + residual).
        return x + self.mlp(self.layernorm2(x))
|
| 82 |
+
|
| 83 |
+
class MAEEncoder(nn.Module):
    """MAE encoder: patchify, randomly keep a subset of patches, embed them,
    add learned positions, and run a transformer stack over the visible tokens.

    ``patch_dim`` is the flattened patch length, i.e. C * patch_size**2.
    """

    def __init__(self, patch_dim, num_patches=(384//4)**2, hidden_dim=768,
                 mlp_dim=768*4, num_heads=8, depth=12, dropout=0.25,
                 mask_ratio=0.75, patch_size=8):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size
        self.patch_embed = nn.Linear(patch_dim, hidden_dim)
        self.pos_embed = PositionalEncoding(num_patches=num_patches, hidden_dim=hidden_dim)
        self.transformer = nn.ModuleList(
            [TransformerBlock(hidden_dim=hidden_dim, mlp_dim=mlp_dim,
                              num_heads=num_heads, dropout=dropout)
             for _ in range(depth)]
        )
        self._init_weights()

    def _init_weights(self):
        # Truncated-normal init for every linear layer; zero biases.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x_in):
        """Encode the visible patches of a batch of images.

        Returns:
            tokens: (B, len_keep, hidden_dim) encoded visible patches.
            mask: (B, N) binary mask (1 = masked) in original patch order.
            ids_keep: (B, len_keep) original indices of kept patches.
            ids_restore: (B, N) permutation undoing the shuffle (for decoder).
        """
        patches = patchify(x_in, self.patch_size)
        visible, mask, ids_restore, ids_keep = random_mask(patches, self.mask_ratio)
        tokens = self.patch_embed(visible)
        tokens = self.pos_embed(tokens, ids_keep)
        for block in self.transformer:
            tokens = block(tokens)
        return tokens, mask, ids_keep, ids_restore
|
| 111 |
+
|
| 112 |
+
class MAEDecoder(nn.Module):
    """MAE decoder: project encoder tokens to the decoder width, re-insert a
    shared mask token at every masked position, restore original patch order,
    add positional embeddings, run a (shallower) transformer stack, and
    predict the pixel values of every patch.
    """
    def __init__(self,c,num_patches,patch_size,encoder_dim,decoder_dim,decoder_depth,mlp_dim,num_heads,dropout):
        super().__init__()
        self.num_patches=num_patches
        self.encoder_dim=encoder_dim
        self.decoder_dim=decoder_dim
        # Single learnable token standing in for every masked patch.
        self.mask_token=nn.Parameter(torch.empty(1,1,decoder_dim))
        # Width adapter from the encoder's hidden size to the decoder's.
        self.enc_to_dec=nn.Linear(encoder_dim,decoder_dim)
        self.pos_embed=nn.Parameter(torch.empty(1,num_patches,decoder_dim))
        self.transformer=nn.ModuleList([TransformerBlock(hidden_dim=decoder_dim,mlp_dim=mlp_dim,num_heads=num_heads,dropout=dropout)
                                        for _ in range(decoder_depth)])
        self.layernorm=nn.LayerNorm(decoder_dim)
        # Per-patch reconstruction head: decoder_dim -> C * patch_size**2.
        self.pred=nn.Linear(decoder_dim,c*(patch_size**2))

        self._init_weights()
    def _init_weights(self):
        # Truncated-normal init for linear layers; zero biases; then the
        # non-Linear parameters (pos_embed, mask_token) explicitly.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.mask_token, std=0.02)
    def forward(self,x,ids_keep,ids_restore):
        """Reconstruct all patches from the visible-token encoding.

        Args:
            x: (B, len_keep, encoder_dim) encoder output for visible patches.
            ids_keep: (B, len_keep) indices of the visible patches (not used
                here; kept for interface symmetry with the encoder).
            ids_restore: (B, N) permutation restoring original patch order.

        Returns:
            (B, N, C * patch_size**2) predicted patch pixels.
        """
        b,n,p=x.shape
        xdec=self.enc_to_dec(x)
        len_keep=xdec.size(1)
        num_patches=ids_restore.size(1)
        num_mask=num_patches-len_keep

        # Append mask tokens after the visible tokens (still in shuffled
        # order), then gather with ids_restore so every token ends up at its
        # original patch position before positions are added.
        mask_token=self.mask_token.expand(b,num_mask,-1)
        x_=torch.cat([xdec,mask_token],dim=1)
        x_=torch.gather(x_,dim=1,index=ids_restore.unsqueeze(-1).expand(-1,-1,x_.size(-1)))
        x_=x_+self.pos_embed
        for block in self.transformer:x_=block(x_)
        x_=self.layernorm(x_)
        out=self.pred(x_)
        return out
|
| 150 |
+
|
| 151 |
+
class MaskedAutoEncoder(nn.Module):
    """Full masked autoencoder: an encoder over the visible patches plus a
    decoder that reconstructs every patch of the input image."""

    def __init__(self, c=1, mask_ratio=0.75, dropout=0.25, img_size=384,
                 encoder_dim=768, mlp_dim=3072, decoder_dim=512,
                 encoder_depth=12, encoder_head=8, decoder_depth=8,
                 decoder_head=8, patch_size=8):
        super().__init__()
        self.patch_size = patch_size
        grid_patches = (img_size // patch_size) ** 2
        self.encoder = MAEEncoder(
            patch_dim=c * (patch_size ** 2),
            num_patches=grid_patches,
            hidden_dim=encoder_dim,
            mlp_dim=mlp_dim,
            num_heads=encoder_head,
            depth=encoder_depth,
            dropout=dropout,
            mask_ratio=mask_ratio,
            patch_size=patch_size,
        )
        self.decoder = MAEDecoder(
            c,
            num_patches=grid_patches,
            patch_size=patch_size,
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
            decoder_depth=decoder_depth,
            mlp_dim=mlp_dim,
            num_heads=decoder_head,
            dropout=dropout,
        )

    def forward(self, x):
        """Run one masked-reconstruction pass.

        Returns:
            (target_patches, predicted_patches, mask) — all in original patch
            order, ready for mae_loss.
        """
        encoded, mask, ids_keep, ids_restore = self.encoder(x)
        reconstructed = self.decoder(encoded, ids_keep, ids_restore)
        targets = patchify(x, self.patch_size)
        return targets, reconstructed, mask

    @staticmethod
    def testme():
        """Smoke test: run a random image through a default model, print shapes."""
        img = torch.rand(1, 1, 384, 384)
        mae = MaskedAutoEncoder()
        targets, preds, mask = mae(img)
        print(targets.shape)
        print(preds.shape)
        print(mask.shape)
|
notebooks/chexpert_mae.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/chexpert_mae_mask_classifier.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core Deep Learning
|
| 2 |
+
torch>=2.0.0
|
| 3 |
+
torchvision>=0.15.0
|
| 4 |
+
|
| 5 |
+
# Data Processing
|
| 6 |
+
numpy>=1.24.0
|
| 7 |
+
pandas>=2.0.0
|
| 8 |
+
scikit-learn>=1.3.0
|
| 9 |
+
|
| 10 |
+
# Image Processing
|
| 11 |
+
Pillow>=10.0.0
|
| 12 |
+
opencv-python>=4.8.0
|
| 13 |
+
albumentations>=1.3.1
|
| 14 |
+
|
| 15 |
+
# Visualization
|
| 16 |
+
matplotlib>=3.7.0
|
| 17 |
+
seaborn>=0.12.0
|
| 18 |
+
|
| 19 |
+
# Utilities
|
| 20 |
+
tqdm>=4.65.0
|
| 21 |
+
|
| 22 |
+
# Jupyter (optional - for notebooks)
|
| 23 |
+
jupyter>=1.0.0
|
| 24 |
+
ipykernel>=6.25.0
|
| 25 |
+
ipywidgets>=8.1.0
|
| 26 |
+
|
| 27 |
+
# Additional utilities (if needed)
|
| 28 |
+
# lmdb>=1.4.0 # Uncomment if using LMDB for caching
|
| 29 |
+
# tensorboard>=2.13.0 # Uncomment if using TensorBoard logging
|
results/test-results.docx
ADDED
|
Binary file (7.54 kB). View file
|
|
|
trainer/__init__.py
ADDED
|
File without changes
|
trainer/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (152 Bytes). View file
|
|
|
trainer/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (154 Bytes). View file
|
|
|
trainer/__pycache__/trainer.cpython-313.pyc
ADDED
|
Binary file (1.01 kB). View file
|
|
|
trainer/__pycache__/trainer.cpython-314.pyc
ADDED
|
Binary file (713 Bytes). View file
|
|
|
trainer/__pycache__/utils.cpython-313.pyc
ADDED
|
Binary file (48.4 kB). View file
|
|
|
trainer/test.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .utils import Trainer
|
| 2 |
+
from configs.configs import root,config
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def main():
    """Evaluate the classifier using the checkpoint path in config["resume"].

    Errors are printed with a full traceback rather than crashing, so the
    script can be run unattended.
    """
    print("Testing classifier")
    try:
        tester = Trainer(config)
        tester.test(model_path=config["resume"])
    # FIX: was a bare `except:`, which also swallowed SystemExit and
    # KeyboardInterrupt; catch only real errors.
    except Exception:
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
|
trainer/trainer.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .utils import *
|
| 2 |
+
from configs.configs import root,config,mae_config
|
| 3 |
+
|
| 4 |
+
def main():
    """Interactively choose and run MAE pretraining or classifier training."""
    try:
        # Normalize the answer so "MAE", " mae " etc. all work.
        decision = input("train mae or classifier? ").strip().lower()
        if decision == "mae":
            print("Training mae")
            trainer = MAETrainer(mae_config)
            trainer.train()
        # FIX: was a second independent `if`; elif makes the two modes
        # explicitly mutually exclusive.
        elif decision == "classifier":
            print("Training classifier")
            trainer = Trainer(config)
            trainer.train()
        else:
            print(f"Unknown option: {decision!r} (expected 'mae' or 'classifier')")
    # FIX: was a bare `except:`, which also swallowed SystemExit and
    # KeyboardInterrupt; catch only real errors.
    except Exception:
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
|
trainer/utils.py
ADDED
|
@@ -0,0 +1,837 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
import io
import json
import os
import sys

from sklearn.metrics import roc_auc_score, confusion_matrix
from torch.utils.data import DataLoader
from tqdm import tqdm

from data.dataset import CheXpertDataset
from loss.assymetric import AsymmetricLoss
from models.mae import *
from models.densenet import *
from models.classifier import *
|
| 12 |
+
|
| 13 |
+
class TeeFile:
    """
    File-like object that writes to multiple streams (e.g., stdout and a file).
    String arguments are treated as paths and opened in line-buffered append
    mode; already-open file objects are used as-is and never closed by us.

    Usage:
        tee = TeeFile(sys.stdout, "/path/to/log.txt")
        print("Hello", file=tee)  # Writes to both stdout and the file
    """

    def __init__(self, *file_objects_or_paths):
        """
        Args:
            *file_objects_or_paths: Mix of file objects (like sys.stdout)
                                    or string paths to log files.
        """
        self.files = []
        self.opened_files = []  # Files we opened ourselves, to close later.

        for item in file_objects_or_paths:
            if isinstance(item, str):
                # It's a path string - open it as a file (append, line buffered).
                f = open(item, 'a', buffering=1)
                self.files.append(f)
                self.opened_files.append(f)
            else:
                # It's already a file-like object (e.g., sys.stdout).
                self.files.append(item)

    def write(self, data):
        """Write data to all streams; a failing stream is warned about, not fatal."""
        for f in self.files:
            try:
                f.write(data)
                f.flush()
            except Exception as e:
                # Handle closed/broken streams gracefully.
                print(f"Warning: Could not write to {f}: {e}", file=sys.stderr)

    def flush(self):
        """Flush all streams, ignoring streams that cannot be flushed."""
        for f in self.files:
            try:
                f.flush()
            # FIX: was a bare `except:`; don't swallow SystemExit/KeyboardInterrupt.
            except Exception:
                pass

    def isatty(self):
        """Check if any stream is a terminal (for tqdm compatibility)."""
        return any(getattr(f, "isatty", lambda: False)() for f in self.files)

    def fileno(self):
        """Get the first available file descriptor among the streams."""
        for f in self.files:
            if hasattr(f, "fileno"):
                try:
                    return f.fileno()
                except Exception:
                    pass
        raise io.UnsupportedOperation("No fileno available")

    def close(self):
        """Close only the files this object opened itself."""
        for f in self.opened_files:
            try:
                f.close()
            # FIX: was a bare `except:`.
            except Exception:
                pass
        self.opened_files.clear()

    def __del__(self):
        """Best-effort cleanup on garbage collection."""
        self.close()

    def __enter__(self):
        """Context manager support."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager cleanup; never suppresses exceptions."""
        self.close()
        return False
| 95 |
+
|
| 96 |
+
class MAETrainer:
    """Training harness for the MaskedAutoEncoder.

    Wires together the model, AdamW optimizer, linear-warmup + cosine LR
    schedule, AMP grad scaler, and the CheXpert train/val(/test) loaders,
    then drives the epoch loop with best-validation-loss checkpointing.
    """

    def __init__(self, configs={}):
        self.configs = configs

        # Per-phase log files, mirrored to stdout via TeeFile.
        os.makedirs(configs["logdir"], exist_ok=True)
        log_path_train = os.path.join(configs["logdir"], "training_log.txt")
        log_path_val = os.path.join(configs["logdir"], "val_log.txt")
        log_path_test = os.path.join(configs["logdir"], "test_log.txt")
        self.traintee = TeeFile(sys.stdout, log_path_train)
        self.valtee = TeeFile(sys.stdout, log_path_val)
        self.testtee = TeeFile(sys.stdout, log_path_test)

        for dir in self.configs["dirsToMake"]:
            os.makedirs(dir, exist_ok=True)

        self.model = MaskedAutoEncoder(
            c=configs["channels"],
            mask_ratio=configs["mask_ratio"],
            dropout=configs["dropout"],
            img_size=configs["img_size"],
            encoder_dim=configs["encoder_dim"],
            mlp_dim=configs["mlp_dim"],
            decoder_dim=configs["decoder_dim"],
            encoder_depth=configs["encoder_depth"],
            encoder_head=configs["encoder_head"],
            decoder_depth=configs["decoder_depth"],
            decoder_head=configs["decoder_head"],
            patch_size=configs["patch_size"]
        ).to(configs["device"])

        self.criterion = mae_loss

        # Optimizer + warmup->cosine schedule + AMP scaler.
        self.optimizer = torch.optim.AdamW(self.model.parameters(), configs["lr"], weight_decay=configs["weight_decay"])
        self.schedular1 = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=0.1, end_factor=1.0, total_iters=configs["warmup"])
        self.schedular2 = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=configs["num_epochs"] - configs["warmup"])
        self.schedular = torch.optim.lr_scheduler.SequentialLR(self.optimizer, schedulers=[self.schedular1, self.schedular2], milestones=[configs["warmup"]])
        self.scaler = torch.amp.GradScaler()

        # Data: weighted random sampling on the training split only.
        self.train_dataset = CheXpertDataset(zip_path=configs["zip_path"], csv_path=configs["train_csv"], root_dir=configs["datadir"], augment=True, use_frontal_only=True)
        self.val_dataset = CheXpertDataset(zip_path=configs["zip_path"], csv_path=configs["val_csv"], root_dir=configs["datadir"], augment=False, use_frontal_only=True)
        self.class_Weights = self.train_dataset.get_class_weights().to(self.configs["device"])
        self.sample_Weights = self.train_dataset.get_sample_weights()
        self.sampler = torch.utils.data.WeightedRandomSampler(self.sample_Weights, num_samples=len(self.sample_Weights))
        self.trainloader = DataLoader(self.train_dataset, batch_size=configs["batch_size"], sampler=self.sampler, num_workers=8, pin_memory=True, persistent_workers=True)
        self.valloader = DataLoader(self.val_dataset, batch_size=configs["batch_size"], shuffle=False, num_workers=8, pin_memory=True, persistent_workers=True)

        # FIX: resume the loss history from disk (mirrors the classifier
        # Trainer class). The original started with an empty history every
        # run and then tried to merge with the file at save time but dumped
        # the unmerged dict, silently discarding prior-run history.
        self.history = {"train_loss": [], "val_loss": []}
        historyfile = os.path.join(configs["logdir"], "history.json")
        if os.path.exists(historyfile):
            with open(historyfile, "r") as f:
                self.history = json.load(f)

        self.current_epoch = 0

        if os.path.exists(self.configs["resume"]):
            loadedpickle = torch.load(self.configs["resume"], map_location=self.configs["device"])
            self.model.load_state_dict(loadedpickle["model"], strict=False)
            self.optimizer.load_state_dict(loadedpickle["optimizer"])
            self.schedular.load_state_dict(loadedpickle["schedular"])
            self.schedular1.load_state_dict(loadedpickle["schedular1"])
            self.schedular2.load_state_dict(loadedpickle["schedular2"])
            self.scaler.load_state_dict(loadedpickle["scaler"])
            self.current_epoch = loadedpickle["epoch"] + 1

        # Optional held-out test loader.
        self.test_dataset = None
        self.testloader = None
        if configs.get("test_csv"):
            self.test_dataset = CheXpertDataset(
                zip_path=configs["zip_path"],
                csv_path=configs["test_csv"],
                root_dir=configs["datadir"],
                augment=False,
                use_frontal_only=True
            )
            self.testloader = DataLoader(
                self.test_dataset,
                batch_size=configs["batch_size"],
                shuffle=False,
                num_workers=8,
                pin_memory=True,
                persistent_workers=True
            )
            print(f"Test loader ready – {len(self.test_dataset)} images")

        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True

        # Let the CUDA caching allocator grow segments instead of OOMing.
        # NOTE(review): env var must be set before CUDA initialization to
        # take effect — model was already moved to device above; confirm.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

        # Enable gradient checkpointing if the model supports it.
        if hasattr(self.model, 'enable_gradient_checkpointing'):
            self.model.enable_gradient_checkpointing()

    @staticmethod
    def plot_training_metrics(metrics, epoch, figs_path):
        """Plot train/val loss curves and save them under figs_path/<epoch>/.

        Args:
            metrics (dict): {"train_loss": [...], "val_loss": [...]}.
            epoch (int): current epoch number (names the output subdirectory).
            figs_path (str): root directory for figures.
        """
        import matplotlib.pyplot as plt

        # Align both series to a common length.
        keys = ["train_loss", "val_loss"]
        lengths = [len(metrics[k]) for k in keys if k in metrics]
        if not lengths:
            return
        n = min(lengths)
        m = {k: metrics[k][:n] for k in keys if k in metrics}
        epochs = list(range(1, n + 1))

        plt.figure(figsize=(14, 6))

        # ---- Loss subplot ----
        plt.subplot(1, 2, 1)
        # FIX: plot the length-aligned slices in `m`; the original built `m`
        # but then plotted the raw lists, which raises when lengths differ.
        plt.plot(epochs, m["train_loss"], label="Train Loss", marker='o')
        plt.plot(epochs, m["val_loss"], label="Val Loss", marker='s')
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training & Validation Loss")
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)

        plt.tight_layout()
        os.makedirs(os.path.join(figs_path, str(epoch)), exist_ok=True)
        plt.savefig(os.path.join(figs_path, str(epoch), "metrics.png"))
        plt.show()

    def train_epoch(self, epoch, looper):
        """Run one training epoch with AMP and gradient accumulation.

        Args:
            epoch: current epoch index (for the progress display only).
            looper: tqdm-wrapped ``enumerate(self.trainloader)``.

        Returns:
            Mean training loss over the finite-loss batches of the epoch.
        """
        self.model.train()
        running_loss = 0.0
        current_loss = 0
        total_batches = len(self.trainloader)

        for batch_idx, data in looper:
            image = data['image'].to(self.configs["device"], non_blocking=True)

            with torch.autocast(device_type=self.configs["device"].type, dtype=torch.float16):
                img, preds, mask = self.model(image)
                loss = self.criterion(img, preds, mask)

            if torch.isfinite(loss):
                # FIX: accumulate the loss only when it is finite; the
                # original added loss.item() unconditionally, so a single
                # NaN/Inf batch poisoned the running average forever.
                running_loss += loss.item()
                self.scaler.scale(loss / self.configs["accumulation"]).backward()
            else:
                # Skip the bad batch and drop any partially accumulated grads.
                self.optimizer.zero_grad(set_to_none=True)
                continue

            if (batch_idx + 1) % self.configs["accumulation"] == 0 or batch_idx == total_batches - 1:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                # NOTE(review): scheduler is stepped once per optimizer step,
                # but its milestones/T_max are configured in epochs — confirm
                # which unit is intended.
                self.schedular.step()
                self.optimizer.zero_grad(set_to_none=True)

            # === LIVE METRICS (every batch) ===
            current_loss = running_loss / (batch_idx + 1)
            if (batch_idx + 1) % 10 == 0:
                current_lr = self.optimizer.param_groups[0]['lr']
                looper.set_postfix({
                    "lr": f"{current_lr:.2e}", "batch": f"{batch_idx}/{total_batches}",
                    "epoch": f"{epoch}/{self.configs['num_epochs']}",
                    "loss": f"{current_loss:.3f}",
                })

        return current_loss

    def validate(self, epoch, looper):
        """Run one validation pass and return the mean validation loss."""
        self.model.eval()
        val_loss = 0.0
        lenloader = len(self.valloader)
        current_loss = 0
        with torch.no_grad():
            for batch_idx, data in looper:
                image = data["image"].to(self.configs["device"], non_blocking=True)

                with torch.autocast(device_type=self.configs["device"].type, dtype=torch.float16):
                    img, preds, mask = self.model(image)
                    loss = self.criterion(img, preds, mask)

                val_loss += loss.item()

                # === LIVE METRICS ===
                current_loss = val_loss / (batch_idx + 1)
                if (batch_idx + 1) % 10 == 0:
                    looper.set_postfix({
                        "epoch": f"{epoch}/{self.configs['num_epochs']}",
                        "batch": f"{batch_idx}/{lenloader}",
                        "loss": f"{current_loss:.3f}",
                    })

        return current_loss

    def train(self):
        """Run the epoch loop from self.current_epoch to num_epochs.

        Saves a checkpoint to configs["resume"] whenever validation loss
        improves, and persists/plots the loss history every 10 epochs.
        """
        for epoch in range(self.current_epoch, self.configs["num_epochs"]):
            trainlooper = tqdm(enumerate(self.trainloader), desc="training: ", leave=False, file=self.traintee)
            vallooper = tqdm(enumerate(self.valloader), desc="validating: ", leave=False, file=self.valtee)

            self.model.train()
            self.optimizer.zero_grad(set_to_none=True)

            running_loss = self.train_epoch(epoch, trainlooper)

            torch.cuda.synchronize()
            torch.cuda.empty_cache()

            val_loss = self.validate(epoch, vallooper)

            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            gc.collect()

            # FIX: also checkpoint on the very first epoch. The original
            # required a non-empty history AND an improvement, so with a
            # fresh history no checkpoint was ever written until a later
            # epoch happened to beat an earlier (unsaved) one.
            if (not self.history["val_loss"]) or val_loss < min(self.history["val_loss"]):
                checkpoint = {
                    "model": self.model.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "schedular": self.schedular.state_dict(),
                    "schedular1": self.schedular1.state_dict(),
                    "schedular2": self.schedular2.state_dict(),
                    "scaler": self.scaler.state_dict(),
                    "epoch": epoch,
                }
                torch.save(checkpoint, self.configs["resume"])

            print(f"train loss {running_loss} val loss {val_loss}")

            self.history["train_loss"].append(float(running_loss))
            self.history["val_loss"].append(float(val_loss))

            if epoch % 10 == 0:
                # self.history already includes any prior-run history (loaded
                # in __init__), so dumping it directly is lossless; the old
                # read-merge-then-dump-unmerged logic dropped prior runs.
                historyfile = os.path.join(self.configs["logdir"], "history.json")
                with open(historyfile, "w") as f:
                    json.dump(self.history, f)
                MAETrainer.plot_training_metrics(self.history, epoch + 1, self.configs["logdir"])

            # Point past the last completed epoch so a second train() call
            # resumes instead of repeating it.
            self.current_epoch = epoch + 1
|
| 352 |
+
|
| 353 |
+
class Trainer:
|
| 354 |
+
def __init__(self,configs={}):
|
| 355 |
+
|
| 356 |
+
self.configs=configs
|
| 357 |
+
os.makedirs(configs["logdir"],exist_ok=True)
|
| 358 |
+
log_path_train = os.path.join(configs["logdir"], "training_log.txt")
|
| 359 |
+
log_path_val = os.path.join(configs["logdir"], "val_log.txt")
|
| 360 |
+
log_path_test = os.path.join(configs["logdir"], "test_log.txt")
|
| 361 |
+
#self.log_file = open(log_path, 'w', buffering=1)
|
| 362 |
+
self.traintee = TeeFile(sys.stdout, log_path_train)
|
| 363 |
+
self.valtee = TeeFile(sys.stdout, log_path_val)
|
| 364 |
+
self.testtee = TeeFile(sys.stdout, log_path_test)
|
| 365 |
+
|
| 366 |
+
for dir in self.configs["dirsToMake"]: os.makedirs(dir,exist_ok=True)
|
| 367 |
+
|
| 368 |
+
self.model=XRAYClassifier(
|
| 369 |
+
c=configs["channels"],
|
| 370 |
+
num_classes=configs["num_classes"],
|
| 371 |
+
mask_ratio=configs["mask_ratio"],
|
| 372 |
+
dropout=configs["dropout"],
|
| 373 |
+
img_size=configs["img_size"],
|
| 374 |
+
encoder_dim=configs["encoder_dim"],
|
| 375 |
+
mlp_dim=configs["mlp_dim"],
|
| 376 |
+
decoder_dim=configs["decoder_dim"],
|
| 377 |
+
encoder_depth=configs["encoder_depth"],
|
| 378 |
+
encoder_head=configs["encoder_head"],
|
| 379 |
+
decoder_depth=configs["decoder_depth"],
|
| 380 |
+
decoder_head=configs["decoder_head"],
|
| 381 |
+
patch_size=configs["patch_size"]
|
| 382 |
+
).to(configs["device"])
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
self.optimizer=torch.optim.AdamW(self.model.parameters(),configs["lr"], weight_decay=configs["weight_decay"])
|
| 387 |
+
self.schedular1=torch.optim.lr_scheduler.LinearLR(self.optimizer,start_factor=0.1,end_factor=1.0,total_iters=configs["warmup"])
|
| 388 |
+
self.schedular2=torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer,T_max=configs["num_epochs"]-configs["warmup"])
|
| 389 |
+
self.schedular=torch.optim.lr_scheduler.SequentialLR (self.optimizer,schedulers=[self.schedular1,self.schedular2],milestones=[configs["warmup"]])
|
| 390 |
+
self.scaler=torch.amp.GradScaler()
|
| 391 |
+
|
| 392 |
+
self.train_dataset= CheXpertDataset(zip_path=configs["zip_path"],csv_path=configs["train_csv"],root_dir=configs["datadir"],augment=True,use_frontal_only=True,mask_dir=configs["maskdir"])
|
| 393 |
+
self.val_dataset= CheXpertDataset(zip_path=configs["zip_path"],csv_path=configs["val_csv"],root_dir=configs["datadir"],augment=False,use_frontal_only=True,mask_dir=configs["maskdir"] )
|
| 394 |
+
|
| 395 |
+
self.class_Weights=self.train_dataset.get_class_weights().to(self.configs["device"])
|
| 396 |
+
self.sample_Weights=self.train_dataset.get_sample_weights()
|
| 397 |
+
self.sampler=torch.utils.data.WeightedRandomSampler(self.sample_Weights,num_samples=len(self.sample_Weights))
|
| 398 |
+
self.trainloader=DataLoader(self.train_dataset,batch_size=configs["batch_size"],sampler=self.sampler,num_workers=0,pin_memory=True,persistent_workers=False)
|
| 399 |
+
self.valloader=DataLoader(self.val_dataset,batch_size=configs["batch_size"],shuffle=False,num_workers=0,pin_memory=True,persistent_workers=False)
|
| 400 |
+
self.criterion=AsymmetricLoss(class_weights=self.class_Weights).to(self.configs["device"])
|
| 401 |
+
self.history={"train_loss":[],"val_loss":[],"train_macro_auc":[],"val_macro_auc":[],"train_micro_auc":[],"val_micro_auc":[]}
|
| 402 |
+
if os.path.exists(os.path.join(self.configs["logdir"],"history.json")):
|
| 403 |
+
with open(os.path.join(self.configs["logdir"],"history.json"),'r') as hf:
|
| 404 |
+
self.history=json.load(hf)
|
| 405 |
+
hf.close()
|
| 406 |
+
self.current_epoch=0
|
| 407 |
+
|
| 408 |
+
self.optimal_thresholds =[0.5]*14
|
| 409 |
+
|
| 410 |
+
if os.path.exists(self.configs["resume"]):
|
| 411 |
+
ckpt = torch.load(self.configs["resume"], map_location=self.configs["device"],weights_only=False)
|
| 412 |
+
self.model.load_state_dict(ckpt["model"], strict=False)
|
| 413 |
+
self.optimizer.load_state_dict(ckpt["optimizer"])
|
| 414 |
+
self.schedular.load_state_dict(ckpt["schedular"])
|
| 415 |
+
self.schedular1.load_state_dict(ckpt["schedular1"])
|
| 416 |
+
self.schedular2.load_state_dict(ckpt["schedular2"])
|
| 417 |
+
self.scaler.load_state_dict(ckpt["scaler"])
|
| 418 |
+
self.current_epoch = ckpt.get("epoch", -1) + 1
|
| 419 |
+
self.optimal_thresholds =ckpt.get("thresholds")
|
| 420 |
+
else:
|
| 421 |
+
# Load MAE backbone only (pretrained)
|
| 422 |
+
bb = torch.load(self.configs["backbone"], map_location=self.configs["device"],weights_only=False)
|
| 423 |
+
|
| 424 |
+
# Optional: strip 'module.' if present
|
| 425 |
+
state = bb["model"]
|
| 426 |
+
if any(k.startswith("module.") for k in state.keys()):
|
| 427 |
+
from collections import OrderedDict
|
| 428 |
+
state = OrderedDict((k.replace("module.", "", 1), v) for k, v in state.items())
|
| 429 |
+
|
| 430 |
+
missing, unexpected = self.model.mae.load_state_dict(state, strict=False)
|
| 431 |
+
print("loaded backbone")
|
| 432 |
+
if missing: print(f"Missing keys: {len(missing)} (showing first 5): {missing[:5]}")
|
| 433 |
+
if unexpected: print(f"Unexpected keys: {len(unexpected)} (first 5): {unexpected[:5]}")
|
| 434 |
+
|
| 435 |
+
# (Optional) freeze backbone for warmup
|
| 436 |
+
for p in self.model.mae.parameters():
|
| 437 |
+
p.requires_grad = False
|
| 438 |
+
if os.path.exists(self.configs["densebackbone"]):
|
| 439 |
+
densebb=torch.load(self.configs["densebackbone"], map_location=self.configs["device"])
|
| 440 |
+
densestate = densebb["model"]
|
| 441 |
+
if any(k.startswith("module.") for k in state.keys()):
|
| 442 |
+
from collections import OrderedDict
|
| 443 |
+
state = OrderedDict((k.replace("module.", "", 1), v) for k, v in densestate.items())
|
| 444 |
+
densemissing, denseunexpected = self.model.dense.load_state_dict(densestate, strict=False)
|
| 445 |
+
print("loaded dense backbone")
|
| 446 |
+
if densemissing: print(f"Missing keys: {len(densemissing)} (showing first 5): {densemissing[:5]}")
|
| 447 |
+
if denseunexpected: print(f"Unexpected keys: {len(denseunexpected)} (first 5): {denseunexpected[:5]}")
|
| 448 |
+
|
| 449 |
+
self.test_dataset = None
|
| 450 |
+
self.testloader = None
|
| 451 |
+
if configs.get("test_csv"):
|
| 452 |
+
self.test_dataset = CheXpertDataset(
|
| 453 |
+
zip_path=configs["zip_path"],
|
| 454 |
+
csv_path=configs["test_csv"],
|
| 455 |
+
root_dir=configs["datadir"],
|
| 456 |
+
augment=False,
|
| 457 |
+
use_frontal_only=True
|
| 458 |
+
)
|
| 459 |
+
self.testloader = DataLoader(
|
| 460 |
+
self.test_dataset,
|
| 461 |
+
batch_size=configs["batch_size"],
|
| 462 |
+
shuffle=False,
|
| 463 |
+
num_workers=0,
|
| 464 |
+
pin_memory=True,
|
| 465 |
+
persistent_workers=False
|
| 466 |
+
)
|
| 467 |
+
print(f"Test loader ready – {len(self.test_dataset)} images")
|
| 468 |
+
|
| 469 |
+
torch.backends.cudnn.benchmark = True
|
| 470 |
+
torch.backends.cudnn.enabled = True
|
| 471 |
+
|
| 472 |
+
# FIX: Set memory allocator settings
|
| 473 |
+
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
| 474 |
+
|
| 475 |
+
# FIX: Enable gradient checkpointing if model supports it
|
| 476 |
+
if hasattr(self.model, 'enable_gradient_checkpointing'):
|
| 477 |
+
self.model.enable_gradient_checkpointing()
|
| 478 |
+
@staticmethod
def plot_training_metrics(metrics, epoch, figs_path):
    """
    Plot loss and AUC curves from training metrics and save the figure.

    Args:
        metrics (dict): Per-epoch series under the keys
            "train_loss", "val_loss", "train_macro_auc", "val_macro_auc",
            "train_micro_auc", "val_micro_auc". Series may have unequal
            lengths; all are truncated to the shortest before plotting.
        epoch (int): Current epoch number (names the output sub-folder).
        figs_path (str): Directory under which "<epoch>/metrics.png" is saved.
    """
    keys = ["train_loss", "val_loss", "train_macro_auc", "val_macro_auc",
            "train_micro_auc", "val_micro_auc"]
    lengths = [len(metrics[k]) for k in keys if k in metrics]
    if not lengths:
        return

    # Heavy import deferred until we know there is something to plot.
    import matplotlib.pyplot as plt

    # Slice every series to a common length so mismatched lists cannot
    # crash matplotlib (x and y must be the same length).
    n = min(lengths)
    m = {k: metrics[k][:n] for k in keys if k in metrics}
    epochs = list(range(1, n + 1))

    plt.figure(figsize=(14, 6))

    # ---- Loss subplot ----
    plt.subplot(1, 2, 1)
    # FIX: plot the sliced series `m`, not the raw `metrics` — the
    # original sliced into `m` and then ignored it, so unequal series
    # lengths raised a ValueError inside matplotlib.
    plt.plot(epochs, m["train_loss"], label="Train Loss", marker='o')
    plt.plot(epochs, m["val_loss"], label="Val Loss", marker='s')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training & Validation Loss")
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)

    # ---- AUC subplot ----
    plt.subplot(1, 2, 2)
    plt.plot(epochs, m["train_macro_auc"], label="Train Macro AUC", marker='o')
    plt.plot(epochs, m["val_macro_auc"], label="Val Macro AUC", marker='s')
    plt.plot(epochs, m["train_micro_auc"], label="Train Micro AUC", marker='^')
    plt.plot(epochs, m["val_micro_auc"], label="Val Micro AUC", marker='v')
    plt.xlabel("Epoch")
    plt.ylabel("AUC")
    plt.title("Training & Validation AUC (Macro/Micro)")
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)

    plt.tight_layout()
    os.makedirs(os.path.join(figs_path, str(epoch)), exist_ok=True)
    plt.savefig(os.path.join(figs_path, str(epoch), "metrics.png"))
    plt.show()
| 538 |
+
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
def train_epoch(self, epoch, looper):
    """
    Run one training epoch with gradient accumulation.

    Args:
        epoch (int): Current epoch index (progress display only).
        looper: enumerate(...)-style iterable over self.trainloader,
            typically a tqdm wrapper (must support set_postfix).

    Returns:
        tuple: (mean_loss, macro_auc, micro_auc) over the epoch, where
        mean_loss averages only the batches with a finite loss.
    """
    self.model.train()
    running_loss = 0.0
    finite_batches = 0  # batches that actually contributed to the loss
    all_preds = []
    all_targets = []

    total_batches = len(self.trainloader)

    for batch_idx, data in looper:
        image = data['image'].to(self.configs["device"], non_blocking=True)
        target = data['labels'].to(self.configs["device"], non_blocking=True)

        logits = self.model(image)
        # Criterion is applied to probabilities (see __init__'s
        # AsymmetricLoss), so sigmoid comes first.
        preds = torch.sigmoid(logits.float())
        loss = self.criterion(preds, target)

        # FIX: accumulate the loss ONLY when it is finite. Previously
        # `running_loss += loss.item()` ran before the isfinite check,
        # so a single NaN/Inf batch poisoned the reported loss for the
        # rest of the epoch even though its gradient was skipped.
        if torch.isfinite(loss):
            running_loss += loss.item()
            finite_batches += 1
            (loss / self.configs["accumulation"]).backward()
        else:
            self.optimizer.zero_grad(set_to_none=True)
            continue

        # Step every `accumulation` batches (and on the final batch).
        if (batch_idx + 1) % self.configs["accumulation"] == 0 or batch_idx == total_batches - 1:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            self.optimizer.zero_grad(set_to_none=True)

        # Store for AUC
        all_preds.append(preds.detach().cpu())
        all_targets.append(target.detach().cpu())

        # === LIVE METRICS (every 500 batches) ===
        current_loss = running_loss / max(finite_batches, 1)
        if (batch_idx + 1) % 500 == 0 and all_preds:
            preds_np = torch.cat(all_preds).numpy()
            targets_np = torch.cat(all_targets).numpy()
            try:
                macro_auc = roc_auc_score(targets_np, preds_np, average='macro')
                micro_auc = roc_auc_score(targets_np, preds_np, average='micro')
                current_lr = self.optimizer.param_groups[0]['lr']
                looper.set_postfix({
                    "lr": f"{current_lr:.2e}", "batch": f"{batch_idx}/{total_batches}",
                    "epoch": f"{epoch}/{self.configs['num_epochs']}",
                    "loss": f"{current_loss:.3f}",
                    "macro": f"{macro_auc:.3f}",
                    "micro": f"{micro_auc:.3f}"
                })
            except ValueError:
                # Some class has only one label value so far; AUC is
                # undefined — skip this progress update, keep training.
                pass

    # === FINAL FULL EPOCH METRICS ===
    preds_full = torch.cat(all_preds).numpy()
    targets_full = torch.cat(all_targets).numpy()
    final_loss = running_loss / max(finite_batches, 1)
    final_macro_auc = roc_auc_score(targets_full, preds_full, average='macro')
    final_micro_auc = roc_auc_score(targets_full, preds_full, average='micro')

    # Release large buffers before returning to keep peak memory down.
    del all_preds, all_targets, preds_full, targets_full

    return final_loss, final_macro_auc, final_micro_auc
|
| 610 |
+
|
| 611 |
+
def validate(self, epoch, looper):
    """
    Run one validation pass and tune per-class decision thresholds.

    Args:
        epoch (int): Current epoch index (progress display only).
        looper: enumerate(...)-style iterable over self.valloader,
            typically a tqdm wrapper (must support set_postfix).

    Returns:
        tuple: (mean_loss, macro_auc, micro_auc) over the validation set.

    Side effects:
        Replaces self.optimal_thresholds with the per-class threshold
        maximizing Youden's J (sensitivity + specificity - 1) on this
        validation set; classes with no positives keep 0.5.
    """
    self.model.eval()
    val_loss = 0.0
    all_preds = []
    all_targets = []
    lenloader = len(self.valloader)

    with torch.no_grad():
        for batch_idx, data in looper:
            image = data["image"].to(self.configs["device"], non_blocking=True)
            target = data["labels"].to(self.configs["device"], non_blocking=True)

            logits = self.model(image)
            preds = torch.sigmoid(logits.float())
            loss = self.criterion(preds, target)

            val_loss += loss.item()
            all_preds.append(preds.detach().cpu())
            all_targets.append(target.detach().cpu())

            # === LIVE METRICS (every 200 batches) ===
            current_loss = val_loss / (batch_idx + 1)
            if (batch_idx + 1) % 200 == 0 and all_preds:
                preds_np = torch.cat(all_preds).numpy()
                targets_np = torch.cat(all_targets).numpy()
                try:
                    macro_auc = roc_auc_score(targets_np, preds_np, average='macro')
                    micro_auc = roc_auc_score(targets_np, preds_np, average='micro')
                    looper.set_postfix({
                        "epoch": f"{epoch}/{self.configs['num_epochs']}",
                        "batch": f"{batch_idx}/{lenloader}",
                        "loss": f"{current_loss:.3f}",
                        "macro": f"{macro_auc:.3f}",
                        "micro": f"{micro_auc:.3f}"
                    })
                except ValueError:
                    # AUC undefined while some class has a single label
                    # value; skip the progress update.
                    pass

    # === FINAL FULL VALIDATION METRICS ===
    preds_full = torch.cat(all_preds).numpy()
    targets_full = torch.cat(all_targets).numpy()
    # Derive the class count from the data instead of hard-coding 14.
    num_classes = targets_full.shape[1]
    new_thresholds = [0.5] * num_classes  # default

    for class_idx in range(num_classes):
        if targets_full[:, class_idx].sum() == 0:
            # No positive samples: keep the default 0.5.
            continue

        best_score = -1.0
        best_threshold = 0.5
        for threshold in np.arange(0.1, 0.9, 0.02):
            preds_bin = (preds_full[:, class_idx] >= threshold).astype(int)
            # FIX: labels=[0, 1] forces a 2x2 matrix even when preds_bin
            # is all zeros or all ones; without it confusion_matrix
            # returns a 1x1 array and the 4-way unpack raised ValueError.
            tn, fp, fn, tp = confusion_matrix(
                targets_full[:, class_idx].astype(int),
                preds_bin,
                labels=[0, 1]
            ).ravel()
            sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
            specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
            score = sensitivity + specificity - 1  # Youden's J statistic

            if score > best_score:
                best_score = score
                best_threshold = threshold

        new_thresholds[class_idx] = best_threshold

    self.optimal_thresholds = new_thresholds

    final_loss = val_loss / lenloader
    final_macro_auc = roc_auc_score(targets_full, preds_full, average='macro')
    final_micro_auc = roc_auc_score(targets_full, preds_full, average='micro')

    del all_preds, all_targets, preds_full, targets_full

    return final_loss, final_macro_auc, final_micro_auc
|
| 693 |
+
def train(self):
    """
    Main training loop.

    Runs epochs from self.current_epoch to configs["num_epochs"],
    validating after each epoch, checkpointing whenever a validation AUC
    improves on its best so far, and persisting metric history to
    logdir/history.json.
    """
    for epoch in range(self.current_epoch, self.configs["num_epochs"]):
        trainlooper = tqdm(enumerate(self.trainloader), desc="training: ", leave=True, file=self.traintee)
        vallooper = tqdm(enumerate(self.valloader), desc="validating: ", leave=True, file=self.valtee)

        self.model.train()
        # NOTE(review): scheduler steps BEFORE the epoch's optimizer
        # steps, matching the original code — confirm this ordering
        # matches the intended LR schedule.
        self.schedular.step()
        self.optimizer.zero_grad(set_to_none=True)

        running_loss, macro_auc, micro_auc = self.train_epoch(epoch, trainlooper)

        torch.cuda.synchronize()
        torch.cuda.empty_cache()

        val_loss, val_macro_auc, val_micro_auc = self.validate(epoch, vallooper)

        torch.cuda.synchronize()
        torch.cuda.empty_cache()

        gc.collect()

        # FIX: also checkpoint when the history is still empty —
        # previously the very first epoch could never be saved because
        # the empty-list guard short-circuited the comparison.
        improved = (
            not self.history["val_macro_auc"]
            or val_macro_auc > max(self.history["val_macro_auc"])
            or val_micro_auc > max(self.history["val_micro_auc"])
        )
        if improved:
            checkpoint = {
                "model": self.model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "schedular": self.schedular.state_dict(),
                "schedular1": self.schedular1.state_dict(),
                "schedular2": self.schedular2.state_dict(),
                "scaler": self.scaler.state_dict(),
                "epoch": epoch,
                "thresholds": self.optimal_thresholds,
            }
            torch.save(checkpoint, self.configs["resume"])

        print(f"epoch {epoch} train loss {running_loss} val loss {val_loss} val_macro_auc {val_macro_auc} val_micro_auc {val_micro_auc} train_macro_auc {macro_auc} train_micro_auc {micro_auc}")

        self.history["train_loss"].append(float(running_loss))
        self.history["val_loss"].append(float(val_loss))
        self.history["train_macro_auc"].append(float(macro_auc))
        self.history["val_macro_auc"].append(float(val_macro_auc))
        self.history["train_micro_auc"].append(float(micro_auc))
        self.history["val_micro_auc"].append(float(val_micro_auc))

        # FIX: self.history is already seeded from history.json in
        # __init__, so the old code — which re-read the file, appended
        # self.history into it (duplicating every entry and merging only
        # 4 of the 6 keys), then dumped self.history anyway — was dead
        # and misleading. Simply overwrite the file with the
        # authoritative in-memory history.
        historyfile = os.path.join(self.configs["logdir"], "history.json")
        with open(historyfile, "w") as f:
            json.dump(self.history, f)

        if epoch % 10 == 0:
            Trainer.plot_training_metrics(self.history, epoch + 1, self.configs["logdir"])

        # FIX: the next epoch to run is epoch + 1, matching the resume
        # logic in __init__ (ckpt["epoch"] + 1); the old assignment of
        # `epoch` would re-run the last completed epoch.
        self.current_epoch = epoch + 1
|
| 747 |
+
def test(self, model_path=None, return_preds=False):
    """
    Run a complete test evaluation.

    Args:
        model_path (str, optional): Checkpoint to load before evaluating.
        return_preds (bool): If True, also return the raw (preds, targets)
            arrays alongside the metrics.

    Returns:
        (macro_auc, micro_auc, per_class_auc_dict), or with
        return_preds=True: (macro_auc, micro_auc, per_class_auc_dict,
        (preds, targets)).

    Raises:
        RuntimeError: If no test loader was configured (`test_csv` missing).
    """
    if model_path:
        ckpt = torch.load(model_path, map_location=self.configs["device"])
        self.model.load_state_dict(ckpt["model"])
        print(f"Loaded checkpoint {model_path}")

    if self.testloader is None:
        raise RuntimeError("No test loader – provide `test_csv` in config")

    self.model.eval()
    all_preds, all_targets = [], []

    test_loss = 0.0
    looper = tqdm(enumerate(self.testloader), total=len(self.testloader),
                  desc="Testing ", file=self.testtee)

    with torch.inference_mode():
        for batch_idx, data in looper:
            img = data['image'].to(self.configs["device"], non_blocking=True)
            tgt = data['labels'].to(self.configs["device"], non_blocking=True)

            logits = self.model(img)
            if self.optimal_thresholds:
                # Shift logits so each class's decision boundary sits at
                # its tuned probability threshold tau_c:
                # sigmoid(logit - logit(tau_c)) crosses 0.5 exactly at
                # p == tau_c.
                taus = torch.tensor(self.optimal_thresholds, device=logits.device).view(1, -1)
                margins = torch.log(taus / (1 - taus))  # prob -> logit space
                logits = logits - margins
            probs = torch.sigmoid(logits)
            loss = self.criterion(probs, tgt)
            test_loss += loss.item()

            all_preds.append(probs.cpu())
            all_targets.append(tgt.cpu())

            # FIX: the original recomputed AUC over ALL accumulated
            # predictions on EVERY batch (quadratic cost over the run)
            # and crashed with a ValueError whenever an early target
            # column contained a single class. Update periodically and
            # guard the degenerate case instead.
            if (batch_idx + 1) % 200 == 0:
                cur_loss = test_loss / (batch_idx + 1)
                p = torch.cat(all_preds).numpy()
                t = torch.cat(all_targets).numpy()
                try:
                    macro = roc_auc_score(t, p, average='macro')
                    micro = roc_auc_score(t, p, average='micro')
                except ValueError:
                    macro = micro = float('nan')
                looper.set_postfix(loss=f"{cur_loss:.4f}",
                                   macro=f"{macro:.4f}",
                                   micro=f"{micro:.4f}")

    # ---- final metrics ----
    preds = torch.cat(all_preds).numpy()
    targets = torch.cat(all_targets).numpy()
    final_loss = test_loss / len(self.testloader)
    macro_auc = roc_auc_score(targets, preds, average='macro')
    micro_auc = roc_auc_score(targets, preds, average='micro')

    # Per-class AUC; NaN marks classes with no positives in the test set
    # (AUC undefined there).
    per_class = {}
    for i, name in enumerate(self.train_dataset.get_label_names()):
        if targets[:, i].sum() > 0:
            per_class[name] = roc_auc_score(targets[:, i], preds[:, i])
        else:
            per_class[name] = float('nan')

    # ---- pretty table ----
    print("\n" + "="*80)
    print(f"TEST RESULTS (loss={final_loss:.4f})")
    print("="*80)
    print(f"{'Pathology':<30} {'AUC':>8}")
    print("-"*40)
    for name, auc in per_class.items():
        print(f"{name:<30} {auc:>8.4f}" if not np.isnan(auc) else f"{name:<30} {'N/A':>8}")
    print("-"*40)
    print(f"{'Macro AUC':<30} {macro_auc:>8.4f}")
    print(f"{'Micro AUC':<30} {micro_auc:>8.4f}")
    print("="*80)

    if return_preds:
        return macro_auc, micro_auc, per_class, (preds, targets)
    return macro_auc, micro_auc, per_class
|
training logs/classifier/1/metrics.png
ADDED
|
training logs/classifier/11/metrics.png
ADDED
|
Git LFS Details
|
training logs/classifier/Events.docx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4c683f158db1946053b66c0a5769962a8d51bad98058490cfa4aedba5582f45d
|
| 3 |
+
size 1680108
|
training logs/classifier/history.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"train_loss": [2.451549026515934, 2.324592100605649, 2.2527450666496867, 2.2051324946319886, 2.159476125092837, 2.1111394616786736, 2.057536503908261, 1.9841906148749953, 1.9176961825764776, 1.8619107825900996, 1.7461218646035648, 1.6598046678827294, 0.14859689984745142, 0.14480384502754282, 0.1413286671667666, 0.13848335853005814], "val_loss": [1.4677110827817452, 1.414012653402026, 1.3714177432875805, 1.2308028351501579, 1.3173101589027862, 1.327703529975834, 1.3617598478416082, 1.3191542980376254, 1.2529189027799352, 1.5297267407960213, 1.5704542597879037, 1.768142464009117, 0.1340647471810548, 0.13415449232644355, 0.13759738845883238, 0.13678586165817191], "train_macro_auc": [0.5915813508746322, 0.6911077814868211, 0.7180677573787114, 0.7330628389066268, 0.7446141059156149, 0.7555600431027648, 0.7611192999386304, 0.768528268418801, 0.7723525292111644, 0.7753779303902529, 0.7868138491223837, 0.7944627804675177, 0.7901445091432617, 0.8022003143300758, 0.815119789621127, 0.8244893337730623], "val_macro_auc": [0.6697306681955048, 0.7073059079688858, 0.7307953052152676, 0.7387044904612824, 0.7454194006038561, 0.7500732498762482, 0.7486698915393023, 0.7534324811456612, 0.7528406138149186, 0.7499375597482462, 0.7467580969017149, 0.7441320142787182, 0.7530519050114038, 0.7533548440220946, 0.749456570265707, 0.7485324814519589], "train_micro_auc": [0.7215386383447403, 0.7813286014278719, 0.7966446631205264, 0.8059628819522804, 0.8134203450690528, 0.8203666769222709, 0.8240765555633897, 0.8301711223679948, 0.833068249934892, 0.8355527627001551, 0.8440327424520534, 0.8495436085077932, 0.8464713300561717, 0.8555554974362551, 0.8642513599137882, 0.8715251218243751], "val_micro_auc": [0.830096540126698, 0.8465147824166648, 0.85730031114022, 0.8612904284897909, 0.8638291274014742, 0.8621590439819964, 0.8642181574740424, 0.8699097719546053, 0.8705562840474825, 0.8696004009592124, 0.8691057071199912, 0.867407803196298, 0.869385951883187, 0.8707534276030608, 0.8618755289952147, 
0.8671620953710033]}
|
training logs/classifier/test_log.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
training logs/classifier/training_log.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
training logs/classifier/val_log.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
training logs/mae/1/metrics.png
ADDED
|
training logs/mae/101/metrics.png
ADDED
|
training logs/mae/11/metrics.png
ADDED
|