Commit
Β·
d921913
1
Parent(s):
c4b099c
π₯ Add BYOL Mammogram Classification Model
Browse files- Self-supervised BYOL pre-training for mammogram analysis
- ResNet50 backbone with medical-optimized augmentations
- Aggressive background rejection and intelligent tissue segmentation
- A100 GPU optimized training with mixed precision
- Complete model checkpoints: best and final weights
- Classification fine-tuning pipeline with inference script
- Comprehensive model card and usage documentation
Key Features:
β
Medical-grade tile extraction (512x512px)
β
Multi-level background filtering
β
BYOL self-supervised learning
β
Ready for downstream classification tasks
β
Clinical-safe augmentation strategy
Model weights: 528MB total (best + final checkpoints)
Training: 100 epochs on high-quality breast tissue tiles
- CLASSIFICATION_GUIDE.md +330 -0
- README.md +210 -0
- classification_config.json +13 -0
- inference_classification.py +288 -0
- mammogram_byol_best.pth +3 -0
- mammogram_byol_final.pth +3 -0
- requirements.txt +12 -0
- train_byol_mammo.py +785 -0
- train_classification.py +517 -0
CLASSIFICATION_GUIDE.md
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# π― Classification Training Guide
|
| 2 |
+
|
| 3 |
+
Complete guide for fine-tuning the BYOL pre-trained model for multi-label classification.
|
| 4 |
+
|
| 5 |
+
## π Overview
|
| 6 |
+
|
| 7 |
+
After BYOL pre-training completes, you can fine-tune the model for classification using the `train_classification.py` script. This approach:
|
| 8 |
+
|
| 9 |
+
1. **Loads the BYOL checkpoint** with learned representations
|
| 10 |
+
2. **Freezes the backbone** initially (optional) to prevent overwriting good features
|
| 11 |
+
3. **Fine-tunes the classification head** with a higher learning rate
|
| 12 |
+
4. **Gradually unfreezes** the backbone for end-to-end fine-tuning
|
| 13 |
+
|
| 14 |
+
## ποΈ Data Preparation
|
| 15 |
+
|
| 16 |
+
### CSV Format
|
| 17 |
+
Create train/validation CSV files with this format:
|
| 18 |
+
|
| 19 |
+
```csv
|
| 20 |
+
tile_path,mass,calcification,architectural_distortion,asymmetry,normal,benign,malignant,birads_2,birads_3,birads_4
|
| 21 |
+
patient1_tile_001.png,1,0,0,0,0,1,0,0,1,0
|
| 22 |
+
patient1_tile_002.png,0,1,0,0,0,0,1,0,0,1
|
| 23 |
+
patient2_tile_001.png,0,0,0,0,1,1,0,1,0,0
|
| 24 |
+
...
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
**Requirements:**
|
| 28 |
+
- `tile_path`: Relative path to tile image
|
| 29 |
+
- **Class columns**: Binary labels (0/1) for each class
|
| 30 |
+
- **Multi-label support**: Each image can have multiple classes = 1
|
| 31 |
+
|
| 32 |
+
### Directory Structure
|
| 33 |
+
```
|
| 34 |
+
your_project/
|
| 35 |
+
βββ tiles/ # Directory containing tile images
|
| 36 |
+
β βββ patient1_tile_001.png
|
| 37 |
+
β βββ patient1_tile_002.png
|
| 38 |
+
β βββ ...
|
| 39 |
+
βββ train_labels.csv # Training labels
|
| 40 |
+
βββ val_labels.csv # Validation labels
|
| 41 |
+
βββ mammogram_byol_best.pth # BYOL checkpoint
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
## π Quick Start
|
| 45 |
+
|
| 46 |
+
### 1. Basic Classification Training
|
| 47 |
+
|
| 48 |
+
```bash
|
| 49 |
+
python train_classification.py \
|
| 50 |
+
--byol_checkpoint ./mammogram_byol_best.pth \
|
| 51 |
+
--train_csv ./train_labels.csv \
|
| 52 |
+
--val_csv ./val_labels.csv \
|
| 53 |
+
--tiles_dir ./tiles \
|
| 54 |
+
--class_names mass calcification architectural_distortion asymmetry normal benign malignant birads_2 birads_3 birads_4 \
|
| 55 |
+
--output_dir ./classification_results
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
### 2. With Custom Configuration
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
python train_classification.py \
|
| 62 |
+
--byol_checkpoint ./mammogram_byol_best.pth \
|
| 63 |
+
--train_csv ./train_labels.csv \
|
| 64 |
+
--val_csv ./val_labels.csv \
|
| 65 |
+
--tiles_dir ./tiles \
|
| 66 |
+
--class_names mass calcification normal \
|
| 67 |
+
--config ./classification_config.json \
|
| 68 |
+
--output_dir ./classification_results \
|
| 69 |
+
--wandb_project my-mammogram-classification
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
### 3. Quick Testing (Limited Dataset)
|
| 73 |
+
|
| 74 |
+
```bash
|
| 75 |
+
python train_classification.py \
|
| 76 |
+
--byol_checkpoint ./mammogram_byol_best.pth \
|
| 77 |
+
--train_csv ./train_labels.csv \
|
| 78 |
+
--val_csv ./val_labels.csv \
|
| 79 |
+
--tiles_dir ./tiles \
|
| 80 |
+
--class_names mass calcification normal \
|
| 81 |
+
--max_samples 1000 \
|
| 82 |
+
--output_dir ./test_results
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
## βοΈ Configuration Options
|
| 86 |
+
|
| 87 |
+
### Key Parameters
|
| 88 |
+
|
| 89 |
+
| Parameter | Default | Description |
|
| 90 |
+
|-----------|---------|-------------|
|
| 91 |
+
| `batch_size` | 32 | Batch size for training |
|
| 92 |
+
| `epochs` | 50 | Number of training epochs |
|
| 93 |
+
| `lr_backbone` | 1e-5 | Learning rate for pre-trained backbone |
|
| 94 |
+
| `lr_head` | 1e-3 | Learning rate for classification head |
|
| 95 |
+
| `freeze_backbone_epochs` | 10 | Epochs to freeze backbone (0 = never freeze) |
|
| 96 |
+
| `label_smoothing` | 0.1 | Label smoothing for regularization |
|
| 97 |
+
| `gradient_clip` | 1.0 | Gradient clipping max norm |
|
| 98 |
+
|
| 99 |
+
### Custom Configuration File
|
| 100 |
+
|
| 101 |
+
Create `my_config.json`:
|
| 102 |
+
```json
|
| 103 |
+
{
|
| 104 |
+
"batch_size": 64,
|
| 105 |
+
"epochs": 100,
|
| 106 |
+
"lr_backbone": 5e-6,
|
| 107 |
+
"lr_head": 2e-3,
|
| 108 |
+
"freeze_backbone_epochs": 20,
|
| 109 |
+
"label_smoothing": 0.2,
|
| 110 |
+
"weight_decay": 1e-3
|
| 111 |
+
}
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
## π Expected Training Process
|
| 115 |
+
|
| 116 |
+
### Phase 1: Backbone Frozen (Epochs 1-10)
|
| 117 |
+
```
|
| 118 |
+
π§ Epoch 1: Backbone frozen (training only classification head)
|
| 119 |
+
Epoch 1/50:
|
| 120 |
+
Train Loss: 0.6234
|
| 121 |
+
Val Loss: 0.5891
|
| 122 |
+
Mean AUC: 0.7123
|
| 123 |
+
Mean AP: 0.6894
|
| 124 |
+
Exact Match: 0.4512
|
| 125 |
+
β
New best model saved (AUC: 0.7123)
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
### Phase 2: End-to-End Fine-tuning (Epochs 11-50)
|
| 129 |
+
```
|
| 130 |
+
Epoch 15/50:
|
| 131 |
+
Train Loss: 0.3456
|
| 132 |
+
Val Loss: 0.3891
|
| 133 |
+
Mean AUC: 0.8567
|
| 134 |
+
Mean AP: 0.8234
|
| 135 |
+
Exact Match: 0.6789
|
| 136 |
+
β
New best model saved (AUC: 0.8567)
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
## π Making Predictions
|
| 140 |
+
|
| 141 |
+
### Single Image Inference
|
| 142 |
+
|
| 143 |
+
```bash
|
| 144 |
+
python inference_classification.py \
|
| 145 |
+
--model_path ./classification_results/best_classification_model.pth \
|
| 146 |
+
--image_path ./test_image.png \
|
| 147 |
+
--threshold 0.5
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
**Output:**
|
| 151 |
+
```
|
| 152 |
+
πΈ Image 1: test_image.png
|
| 153 |
+
π Top prediction: mass (0.847)
|
| 154 |
+
π All probabilities:
|
| 155 |
+
β
mass : 0.847
|
| 156 |
+
β calcification : 0.234
|
| 157 |
+
β normal : 0.123
|
| 158 |
+
β architectural_distortion: 0.089
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
### Batch Inference
|
| 162 |
+
|
| 163 |
+
```bash
|
| 164 |
+
python inference_classification.py \
|
| 165 |
+
--model_path ./classification_results/best_classification_model.pth \
|
| 166 |
+
--images_dir ./test_images \
|
| 167 |
+
--output_json ./predictions.json \
|
| 168 |
+
--batch_size 64
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
### Programmatic Usage
|
| 172 |
+
|
| 173 |
+
```python
|
| 174 |
+
import torch
|
| 175 |
+
from train_byol_mammo import MammogramBYOL
|
| 176 |
+
from inference_classification import load_classification_model, create_inference_transform
|
| 177 |
+
|
| 178 |
+
# Load model
|
| 179 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 180 |
+
model, class_names, config = load_classification_model(
|
| 181 |
+
"./classification_results/best_classification_model.pth", device
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# Make prediction
|
| 185 |
+
transform = create_inference_transform()
|
| 186 |
+
image = Image.open("test.png").convert('RGB')
|
| 187 |
+
input_tensor = transform(image).unsqueeze(0).to(device)
|
| 188 |
+
|
| 189 |
+
with torch.no_grad():
|
| 190 |
+
logits = model.classify(input_tensor)
|
| 191 |
+
probabilities = torch.sigmoid(logits).cpu().numpy()[0]
|
| 192 |
+
|
| 193 |
+
# Get results
|
| 194 |
+
for i, class_name in enumerate(class_names):
|
| 195 |
+
print(f"{class_name}: {probabilities[i]:.3f}")
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
## π Monitoring Training
|
| 199 |
+
|
| 200 |
+
### Weights & Biases Integration
|
| 201 |
+
|
| 202 |
+
The script automatically logs to W&B:
|
| 203 |
+
- Training/validation loss curves
|
| 204 |
+
- Per-class AUC and Average Precision
|
| 205 |
+
- Learning rate schedules
|
| 206 |
+
- Model hyperparameters
|
| 207 |
+
|
| 208 |
+
### Metrics Explained
|
| 209 |
+
|
| 210 |
+
- **AUC (Area Under Curve)**: Measures ranking quality (0-1, higher better)
|
| 211 |
+
- **AP (Average Precision)**: Summarizes precision-recall curve (0-1, higher better)
|
| 212 |
+
- **Exact Match Accuracy**: Percentage where ALL labels are predicted correctly
|
| 213 |
+
- **Per-Class Accuracy**: Binary accuracy for each individual class
|
| 214 |
+
|
| 215 |
+
## πΎ Output Files
|
| 216 |
+
|
| 217 |
+
Training creates:
|
| 218 |
+
```
|
| 219 |
+
classification_results/
|
| 220 |
+
βββ best_classification_model.pth # Best model by validation AUC
|
| 221 |
+
βββ final_classification_model.pth # Final model after all epochs
|
| 222 |
+
βββ classification_epoch_10.pth # Periodic checkpoints
|
| 223 |
+
βββ classification_epoch_20.pth
|
| 224 |
+
βββ ...
|
| 225 |
+
```
|
| 226 |
+
|
| 227 |
+
Each checkpoint contains:
|
| 228 |
+
- Model state dict
|
| 229 |
+
- Optimizer state
|
| 230 |
+
- Training configuration
|
| 231 |
+
- Class names
|
| 232 |
+
- Validation metrics
|
| 233 |
+
|
| 234 |
+
## π οΈ Advanced Usage
|
| 235 |
+
|
| 236 |
+
### Custom Loss Functions
|
| 237 |
+
|
| 238 |
+
For imbalanced datasets, modify the loss function:
|
| 239 |
+
|
| 240 |
+
```python
|
| 241 |
+
# Calculate positive weights for each class
|
| 242 |
+
pos_counts = df[class_names].sum()
|
| 243 |
+
neg_counts = len(df) - pos_counts
|
| 244 |
+
pos_weight = torch.tensor(neg_counts / pos_counts).to(device)
|
| 245 |
+
|
| 246 |
+
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
|
| 247 |
+
```
|
| 248 |
+
|
| 249 |
+
### Transfer Learning Strategies
|
| 250 |
+
|
| 251 |
+
1. **Conservative**: Freeze backbone for many epochs, low backbone LR
|
| 252 |
+
- `freeze_backbone_epochs = 20`
|
| 253 |
+
- `lr_backbone = 1e-6`
|
| 254 |
+
|
| 255 |
+
2. **Aggressive**: Unfreeze early, higher backbone LR
|
| 256 |
+
- `freeze_backbone_epochs = 5`
|
| 257 |
+
- `lr_backbone = 1e-4`
|
| 258 |
+
|
| 259 |
+
3. **Progressive**: Gradually unfreeze layers (requires code modification)
|
| 260 |
+
|
| 261 |
+
### Multi-GPU Training
|
| 262 |
+
|
| 263 |
+
For multiple GPUs, wrap the model:
|
| 264 |
+
```python
|
| 265 |
+
if torch.cuda.device_count() > 1:
|
| 266 |
+
model = nn.DataParallel(model)
|
| 267 |
+
```
|
| 268 |
+
|
| 269 |
+
## β οΈ Troubleshooting
|
| 270 |
+
|
| 271 |
+
### Common Issues
|
| 272 |
+
|
| 273 |
+
**Low Validation Performance:**
|
| 274 |
+
- Increase `freeze_backbone_epochs` to 15-20
|
| 275 |
+
- Reduce `lr_backbone` to 5e-6 or 1e-6
|
| 276 |
+
- Check for data leakage between train/val sets
|
| 277 |
+
|
| 278 |
+
**Overfitting:**
|
| 279 |
+
- Increase `label_smoothing` to 0.2-0.3
|
| 280 |
+
- Add more dropout (modify model architecture)
|
| 281 |
+
- Reduce learning rates
|
| 282 |
+
- Use early stopping
|
| 283 |
+
|
| 284 |
+
**Memory Issues:**
|
| 285 |
+
- Reduce `batch_size` to 16 or 8
|
| 286 |
+
- Reduce `num_workers` to 4
|
| 287 |
+
- Use gradient checkpointing (requires code modification)
|
| 288 |
+
|
| 289 |
+
**Class Imbalance:**
|
| 290 |
+
- Use `pos_weight` in loss function
|
| 291 |
+
- Focus on per-class AUC rather than accuracy
|
| 292 |
+
- Consider focal loss for extreme imbalance
|
| 293 |
+
|
| 294 |
+
## π― Best Practices
|
| 295 |
+
|
| 296 |
+
1. **Start Conservative**: Use default settings first
|
| 297 |
+
2. **Monitor Per-Class Metrics**: Some classes may need special attention
|
| 298 |
+
3. **Validate Data**: Ensure no train/val overlap
|
| 299 |
+
4. **Checkpoint Often**: Training can be interrupted
|
| 300 |
+
5. **Use Multiple Runs**: Average results across random seeds
|
| 301 |
+
6. **Test Thoroughly**: Use held-out test set for final evaluation
|
| 302 |
+
|
| 303 |
+
## π Complete Example
|
| 304 |
+
|
| 305 |
+
Here's a full workflow from BYOL training to classification:
|
| 306 |
+
|
| 307 |
+
```bash
|
| 308 |
+
# 1. Train BYOL (this takes 4-5 hours on A100)
|
| 309 |
+
python train_byol_mammo.py
|
| 310 |
+
|
| 311 |
+
# 2. Prepare classification data (create CSVs with labels)
|
| 312 |
+
# ... prepare train_labels.csv and val_labels.csv ...
|
| 313 |
+
|
| 314 |
+
# 3. Fine-tune for classification (1-2 hours)
|
| 315 |
+
python train_classification.py \
|
| 316 |
+
--byol_checkpoint ./mammogram_byol_best.pth \
|
| 317 |
+
--train_csv ./train_labels.csv \
|
| 318 |
+
--val_csv ./val_labels.csv \
|
| 319 |
+
--tiles_dir ./tiles \
|
| 320 |
+
--class_names mass calcification architectural_distortion asymmetry normal \
|
| 321 |
+
--output_dir ./classification_results
|
| 322 |
+
|
| 323 |
+
# 4. Run inference on new images
|
| 324 |
+
python inference_classification.py \
|
| 325 |
+
--model_path ./classification_results/best_classification_model.pth \
|
| 326 |
+
--images_dir ./new_patient_tiles \
|
| 327 |
+
--output_json ./patient_predictions.json
|
| 328 |
+
```
|
| 329 |
+
|
| 330 |
+
This gives you a complete pipeline from self-supervised pre-training to production-ready classification! π
|
README.md
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
library_name: pytorch
|
| 6 |
+
tags:
|
| 7 |
+
- medical-imaging
|
| 8 |
+
- mammography
|
| 9 |
+
- self-supervised-learning
|
| 10 |
+
- byol
|
| 11 |
+
- breast-cancer
|
| 12 |
+
- computer-vision
|
| 13 |
+
- resnet50
|
| 14 |
+
pipeline_tag: image-classification
|
| 15 |
+
datasets:
|
| 16 |
+
- mammogram-breast-tissue-tiles
|
| 17 |
+
metrics:
|
| 18 |
+
- accuracy
|
| 19 |
+
- precision
|
| 20 |
+
- recall
|
| 21 |
+
- f1
|
| 22 |
+
base_model:
|
| 23 |
+
- microsoft/resnet-50
|
| 24 |
+
---
|
| 25 |
+
|
| 26 |
+
# BYOL Mammogram Classification Model
|
| 27 |
+
|
| 28 |
+
A self-supervised learning model for mammogram analysis using Bootstrap Your Own Latent (BYOL) pre-training with ResNet50 backbone.
|
| 29 |
+
|
| 30 |
+
## Model Description
|
| 31 |
+
|
| 32 |
+
This model implements BYOL (Bootstrap Your Own Latent) self-supervised pre-training on mammogram breast tissue tiles, followed by fine-tuning for classification tasks. The model is designed specifically for medical imaging applications with aggressive background rejection and intelligent tissue segmentation.
|
| 33 |
+
|
| 34 |
+
### Key Features
|
| 35 |
+
|
| 36 |
+
- **Self-supervised pre-training**: Uses BYOL to learn meaningful representations from unlabeled mammogram data
|
| 37 |
+
- **Aggressive background rejection**: Multi-level filtering eliminates empty space and background tiles
|
| 38 |
+
- **Medical-optimized augmentations**: Preserves anatomical details while providing effective augmentation
|
| 39 |
+
- **High-quality tile extraction**: Intelligent breast tissue segmentation with frequency-based selection
|
| 40 |
+
- **A100 GPU optimized**: Mixed precision training with advanced optimizations
|
| 41 |
+
|
| 42 |
+
## Model Architecture
|
| 43 |
+
|
| 44 |
+
- **Backbone**: ResNet50 (ImageNet pre-trained β BYOL fine-tuned)
|
| 45 |
+
- **Input dimension**: 2048 (ResNet50 features)
|
| 46 |
+
- **Hidden dimension**: 4096
|
| 47 |
+
- **Projection dimension**: 256
|
| 48 |
+
- **Tile size**: 512x512 pixels
|
| 49 |
+
- **Input format**: RGB (grayscale mammograms converted to RGB)
|
| 50 |
+
|
| 51 |
+
## Training Details
|
| 52 |
+
|
| 53 |
+
### BYOL Pre-training
|
| 54 |
+
- **Epochs**: 100
|
| 55 |
+
- **Batch size**: 32 (A100 optimized)
|
| 56 |
+
- **Learning rate**: 2e-3 with warmup
|
| 57 |
+
- **Optimizer**: AdamW with cosine annealing
|
| 58 |
+
- **Mixed precision**: Enabled for A100 optimization
|
| 59 |
+
- **Momentum updates**: Per-step momentum scheduling (0.996 β 1.0)
|
| 60 |
+
|
| 61 |
+
### Data Processing
|
| 62 |
+
- **Tile extraction**: 512x512 pixels with 50% overlap
|
| 63 |
+
- **Background rejection**: Multiple criteria including intensity, frequency energy, and tissue ratio
|
| 64 |
+
- **Minimum breast ratio**: 15% (increased from typical 30%)
|
| 65 |
+
- **Minimum frequency energy**: 0.03 (aggressive threshold)
|
| 66 |
+
- **Augmentations**: Medical-safe rotations, flips, color jittering, perspective transforms
|
| 67 |
+
|
| 68 |
+
## Usage
|
| 69 |
+
|
| 70 |
+
### Loading the Model
|
| 71 |
+
|
| 72 |
+
```python
|
| 73 |
+
import torch
|
| 74 |
+
from train_byol_mammo import MammogramBYOL
|
| 75 |
+
from torchvision import models
|
| 76 |
+
import torch.nn as nn
|
| 77 |
+
|
| 78 |
+
# Load the pre-trained BYOL model
|
| 79 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 80 |
+
|
| 81 |
+
# Create ResNet50 backbone
|
| 82 |
+
resnet = models.resnet50(weights=None)
|
| 83 |
+
backbone = nn.Sequential(*list(resnet.children())[:-1])
|
| 84 |
+
|
| 85 |
+
# Initialize BYOL model
|
| 86 |
+
model = MammogramBYOL(
|
| 87 |
+
backbone=backbone,
|
| 88 |
+
input_dim=2048,
|
| 89 |
+
hidden_dim=4096,
|
| 90 |
+
proj_dim=256
|
| 91 |
+
).to(device)
|
| 92 |
+
|
| 93 |
+
# Load pre-trained weights
|
| 94 |
+
checkpoint = torch.load('mammogram_byol_best.pth', map_location=device)
|
| 95 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
| 96 |
+
model.eval()
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
### Feature Extraction
|
| 100 |
+
|
| 101 |
+
```python
|
| 102 |
+
# Extract features from mammogram tiles
|
| 103 |
+
def extract_features(image_tensor):
|
| 104 |
+
with torch.no_grad():
|
| 105 |
+
features = model.get_features(image_tensor)
|
| 106 |
+
return features
|
| 107 |
+
|
| 108 |
+
# Example usage
|
| 109 |
+
image = torch.randn(1, 3, 512, 512).to(device) # Example input
|
| 110 |
+
features = extract_features(image) # Returns 2048-dim features
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
### Classification Fine-tuning
|
| 114 |
+
|
| 115 |
+
Use the provided `train_classification.py` script for downstream classification tasks:
|
| 116 |
+
|
| 117 |
+
```bash
|
| 118 |
+
python train_classification.py \
|
| 119 |
+
--byol_checkpoint ./mammogram_byol_best.pth \
|
| 120 |
+
--train_csv ./train_labels.csv \
|
| 121 |
+
--val_csv ./val_labels.csv \
|
| 122 |
+
--tiles_dir ./tiles/ \
|
| 123 |
+
--output_dir ./classification_results/
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
## File Structure
|
| 127 |
+
|
| 128 |
+
```
|
| 129 |
+
BYOL_Mammogram/
|
| 130 |
+
βββ mammogram_byol_best.pth # Best BYOL checkpoint
|
| 131 |
+
βββ mammogram_byol_final.pth # Final BYOL checkpoint
|
| 132 |
+
βββ train_byol_mammo.py # BYOL pre-training script
|
| 133 |
+
βββ train_classification.py # Classification fine-tuning
|
| 134 |
+
βββ inference_classification.py # Inference script
|
| 135 |
+
βββ classification_config.json # Classification configuration
|
| 136 |
+
βββ CLASSIFICATION_GUIDE.md # Detailed training guide
|
| 137 |
+
βββ requirements.txt # Dependencies
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
## Performance
|
| 141 |
+
|
| 142 |
+
### Pre-training Results
|
| 143 |
+
- **Dataset**: High-quality breast tissue tiles with aggressive background rejection
|
| 144 |
+
- **Efficiency**: ~15-20% tile selection rate (quality over quantity)
|
| 145 |
+
- **Background contamination**: 0% (eliminated during extraction)
|
| 146 |
+
- **Training time**: ~100 epochs on A100 GPU
|
| 147 |
+
|
| 148 |
+
### Key Metrics
|
| 149 |
+
- **Average breast tissue per tile**: >15%
|
| 150 |
+
- **Average frequency energy**: >0.03
|
| 151 |
+
- **Tile quality**: Medical-grade with preserved anatomical details
|
| 152 |
+
|
| 153 |
+
## Technical Specifications
|
| 154 |
+
|
| 155 |
+
### Hardware Requirements
|
| 156 |
+
- **GPU**: A100 (40GB/80GB) recommended
|
| 157 |
+
- **Memory**: 35-40GB GPU memory for training
|
| 158 |
+
- **CPU**: 16+ cores for data loading
|
| 159 |
+
|
| 160 |
+
### Dependencies
|
| 161 |
+
```
|
| 162 |
+
torch>=2.0.0
|
| 163 |
+
torchvision>=0.15.0
|
| 164 |
+
lightly>=1.4.0
|
| 165 |
+
opencv-python>=4.8.0
|
| 166 |
+
scipy>=1.10.0
|
| 167 |
+
numpy>=1.24.0
|
| 168 |
+
Pillow>=9.5.0
|
| 169 |
+
tqdm>=4.65.0
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
## Medical Imaging Considerations
|
| 173 |
+
|
| 174 |
+
### Data Safety
|
| 175 |
+
- **Augmentation strategy**: Preserves medical accuracy while providing diversity
|
| 176 |
+
- **Background rejection**: Prevents training on non-diagnostic regions
|
| 177 |
+
- **Tissue focus**: Emphasizes clinically relevant breast tissue areas
|
| 178 |
+
|
| 179 |
+
### Clinical Applications
|
| 180 |
+
- **Screening support**: Potential for computer-aided detection
|
| 181 |
+
- **Research tool**: Feature extraction for medical AI research
|
| 182 |
+
- **Educational**: Understanding mammogram image analysis
|
| 183 |
+
|
| 184 |
+
## Limitations
|
| 185 |
+
|
| 186 |
+
- **Domain specific**: Trained specifically on mammogram data
|
| 187 |
+
- **Preprocessing required**: Requires proper tissue segmentation
|
| 188 |
+
- **Computational intensive**: Large model requiring substantial GPU resources
|
| 189 |
+
- **Medical supervision**: Requires clinical validation for any medical application
|
| 190 |
+
|
| 191 |
+
## Citation
|
| 192 |
+
|
| 193 |
+
If you use this model in your research, please cite:
|
| 194 |
+
|
| 195 |
+
```bibtex
|
| 196 |
+
@model{byol_mammogram_2024,
|
| 197 |
+
title={BYOL Mammogram Classification Model},
|
| 198 |
+
author={PranayPalem},
|
| 199 |
+
year={2024},
|
| 200 |
+
url={https://huggingface.co/PranayPalem/BYOL_Mammogram}
|
| 201 |
+
}
|
| 202 |
+
```
|
| 203 |
+
|
| 204 |
+
## License
|
| 205 |
+
|
| 206 |
+
MIT License - See LICENSE file for details.
|
| 207 |
+
|
| 208 |
+
## Disclaimer
|
| 209 |
+
|
| 210 |
+
This model is for research purposes only and should not be used for clinical diagnosis without proper validation and medical supervision. Always consult healthcare professionals for medical decisions.
|
classification_config.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"batch_size": 32,
|
| 3 |
+
"num_workers": 8,
|
| 4 |
+
"epochs": 50,
|
| 5 |
+
"lr_backbone": 1e-5,
|
| 6 |
+
"lr_head": 1e-3,
|
| 7 |
+
"weight_decay": 1e-4,
|
| 8 |
+
"warmup_epochs": 5,
|
| 9 |
+
"freeze_backbone_epochs": 10,
|
| 10 |
+
"label_smoothing": 0.1,
|
| 11 |
+
"dropout_rate": 0.3,
|
| 12 |
+
"gradient_clip": 1.0
|
| 13 |
+
}
|
inference_classification.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
inference_classification.py
|
| 4 |
+
|
| 5 |
+
Inference script for the fine-tuned BYOL classification model.
|
| 6 |
+
Demonstrates how to load the trained model and make predictions on new images.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import torchvision.transforms as T
|
| 13 |
+
import numpy as np
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
import argparse
|
| 16 |
+
from typing import List, Dict
|
| 17 |
+
import json
|
| 18 |
+
|
| 19 |
+
from train_byol_mammo import MammogramBYOL
|
| 20 |
+
from train_classification import ClassificationModel
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def load_classification_model(checkpoint_path: str, device: torch.device):
|
| 24 |
+
"""Load the fine-tuned classification model."""
|
| 25 |
+
|
| 26 |
+
print(f"π₯ Loading classification model: {checkpoint_path}")
|
| 27 |
+
|
| 28 |
+
# Load checkpoint
|
| 29 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 30 |
+
|
| 31 |
+
# Get configuration
|
| 32 |
+
config = checkpoint.get('config', {})
|
| 33 |
+
class_names = checkpoint['class_names']
|
| 34 |
+
num_classes = len(class_names)
|
| 35 |
+
|
| 36 |
+
# Create BYOL model
|
| 37 |
+
from torchvision import models
|
| 38 |
+
resnet = models.resnet50(weights=None)
|
| 39 |
+
backbone = nn.Sequential(*list(resnet.children())[:-1])
|
| 40 |
+
|
| 41 |
+
byol_model = MammogramBYOL(
|
| 42 |
+
backbone=backbone,
|
| 43 |
+
input_dim=2048,
|
| 44 |
+
hidden_dim=4096,
|
| 45 |
+
proj_dim=256
|
| 46 |
+
).to(device)
|
| 47 |
+
|
| 48 |
+
# Create classification model
|
| 49 |
+
model = ClassificationModel(byol_model, num_classes).to(device)
|
| 50 |
+
|
| 51 |
+
# Load weights
|
| 52 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
| 53 |
+
model.eval()
|
| 54 |
+
|
| 55 |
+
# Get metrics from checkpoint
|
| 56 |
+
val_metrics = checkpoint.get('val_metrics', {})
|
| 57 |
+
epoch = checkpoint.get('epoch', 'unknown')
|
| 58 |
+
|
| 59 |
+
print(f"β
Loaded model from epoch {epoch}")
|
| 60 |
+
print(f"π Classes: {class_names}")
|
| 61 |
+
if 'mean_auc' in val_metrics:
|
| 62 |
+
print(f"π― Validation AUC: {val_metrics['mean_auc']:.4f}")
|
| 63 |
+
|
| 64 |
+
return model, class_names, config
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def create_inference_transform(tile_size: int = 512):
|
| 68 |
+
"""Create transforms for inference (no augmentation)."""
|
| 69 |
+
return T.Compose([
|
| 70 |
+
T.Resize((tile_size, tile_size)),
|
| 71 |
+
T.ToTensor(),
|
| 72 |
+
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
| 73 |
+
])
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def predict_single_image(model: nn.Module, image_path: str, transform,
|
| 77 |
+
class_names: List[str], device: torch.device,
|
| 78 |
+
threshold: float = 0.5) -> Dict:
|
| 79 |
+
"""Make prediction on a single image."""
|
| 80 |
+
|
| 81 |
+
# Load and preprocess image
|
| 82 |
+
image = Image.open(image_path).convert('RGB')
|
| 83 |
+
input_tensor = transform(image).unsqueeze(0).to(device)
|
| 84 |
+
|
| 85 |
+
# Make prediction
|
| 86 |
+
with torch.no_grad():
|
| 87 |
+
logits = model(input_tensor)
|
| 88 |
+
probabilities = torch.sigmoid(logits).cpu().numpy()[0]
|
| 89 |
+
|
| 90 |
+
# Create results
|
| 91 |
+
results = {
|
| 92 |
+
'image_path': str(image_path),
|
| 93 |
+
'predictions': {},
|
| 94 |
+
'binary_predictions': {},
|
| 95 |
+
'max_class': None,
|
| 96 |
+
'max_probability': 0.0
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
max_prob = 0.0
|
| 100 |
+
max_class = None
|
| 101 |
+
|
| 102 |
+
for i, class_name in enumerate(class_names):
|
| 103 |
+
prob = float(probabilities[i])
|
| 104 |
+
binary_pred = prob > threshold
|
| 105 |
+
|
| 106 |
+
results['predictions'][class_name] = prob
|
| 107 |
+
results['binary_predictions'][class_name] = binary_pred
|
| 108 |
+
|
| 109 |
+
if prob > max_prob:
|
| 110 |
+
max_prob = prob
|
| 111 |
+
max_class = class_name
|
| 112 |
+
|
| 113 |
+
results['max_class'] = max_class
|
| 114 |
+
results['max_probability'] = max_prob
|
| 115 |
+
|
| 116 |
+
return results
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def predict_batch(model: nn.Module, image_paths: List[str], transform,
|
| 120 |
+
class_names: List[str], device: torch.device,
|
| 121 |
+
batch_size: int = 32, threshold: float = 0.5) -> List[Dict]:
|
| 122 |
+
"""Make predictions on a batch of images efficiently."""
|
| 123 |
+
|
| 124 |
+
results = []
|
| 125 |
+
|
| 126 |
+
for i in range(0, len(image_paths), batch_size):
|
| 127 |
+
batch_paths = image_paths[i:i + batch_size]
|
| 128 |
+
|
| 129 |
+
# Load and preprocess batch
|
| 130 |
+
batch_tensors = []
|
| 131 |
+
for path in batch_paths:
|
| 132 |
+
image = Image.open(path).convert('RGB')
|
| 133 |
+
tensor = transform(image)
|
| 134 |
+
batch_tensors.append(tensor)
|
| 135 |
+
|
| 136 |
+
batch_input = torch.stack(batch_tensors).to(device)
|
| 137 |
+
|
| 138 |
+
# Make predictions
|
| 139 |
+
with torch.no_grad():
|
| 140 |
+
logits = model(batch_input)
|
| 141 |
+
probabilities = torch.sigmoid(logits).cpu().numpy()
|
| 142 |
+
|
| 143 |
+
# Process results
|
| 144 |
+
for j, path in enumerate(batch_paths):
|
| 145 |
+
probs = probabilities[j]
|
| 146 |
+
|
| 147 |
+
result = {
|
| 148 |
+
'image_path': str(path),
|
| 149 |
+
'predictions': {},
|
| 150 |
+
'binary_predictions': {},
|
| 151 |
+
'max_class': None,
|
| 152 |
+
'max_probability': 0.0
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
max_prob = 0.0
|
| 156 |
+
max_class = None
|
| 157 |
+
|
| 158 |
+
for k, class_name in enumerate(class_names):
|
| 159 |
+
prob = float(probs[k])
|
| 160 |
+
binary_pred = prob > threshold
|
| 161 |
+
|
| 162 |
+
result['predictions'][class_name] = prob
|
| 163 |
+
result['binary_predictions'][class_name] = binary_pred
|
| 164 |
+
|
| 165 |
+
if prob > max_prob:
|
| 166 |
+
max_prob = prob
|
| 167 |
+
max_class = class_name
|
| 168 |
+
|
| 169 |
+
result['max_class'] = max_class
|
| 170 |
+
result['max_probability'] = max_prob
|
| 171 |
+
|
| 172 |
+
results.append(result)
|
| 173 |
+
|
| 174 |
+
return results
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def print_prediction_results(results: List[Dict], top_k: int = 5):
|
| 178 |
+
"""Print prediction results in a nice format."""
|
| 179 |
+
|
| 180 |
+
for i, result in enumerate(results[:top_k]):
|
| 181 |
+
print(f"\nπΈ Image {i+1}: {Path(result['image_path']).name}")
|
| 182 |
+
print(f"π Top prediction: {result['max_class']} ({result['max_probability']:.3f})")
|
| 183 |
+
|
| 184 |
+
print("π All probabilities:")
|
| 185 |
+
# Sort by probability
|
| 186 |
+
sorted_preds = sorted(result['predictions'].items(),
|
| 187 |
+
key=lambda x: x[1], reverse=True)
|
| 188 |
+
|
| 189 |
+
for class_name, prob in sorted_preds:
|
| 190 |
+
binary = "β
" if result['binary_predictions'][class_name] else "β"
|
| 191 |
+
print(f" {binary} {class_name:15s}: {prob:.3f}")
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def main():
|
| 195 |
+
parser = argparse.ArgumentParser(description='Inference with fine-tuned BYOL classification model')
|
| 196 |
+
parser.add_argument('--model_path', type=str, required=True,
|
| 197 |
+
help='Path to fine-tuned classification model (.pth file)')
|
| 198 |
+
parser.add_argument('--image_path', type=str, default=None,
|
| 199 |
+
help='Path to single image for inference')
|
| 200 |
+
parser.add_argument('--images_dir', type=str, default=None,
|
| 201 |
+
help='Directory containing images for batch inference')
|
| 202 |
+
parser.add_argument('--output_json', type=str, default=None,
|
| 203 |
+
help='Save results to JSON file')
|
| 204 |
+
parser.add_argument('--threshold', type=float, default=0.5,
|
| 205 |
+
help='Classification threshold (default: 0.5)')
|
| 206 |
+
parser.add_argument('--batch_size', type=int, default=32,
|
| 207 |
+
help='Batch size for inference')
|
| 208 |
+
parser.add_argument('--tile_size', type=int, default=512,
|
| 209 |
+
help='Input tile size')
|
| 210 |
+
|
| 211 |
+
args = parser.parse_args()
|
| 212 |
+
|
| 213 |
+
# Validate arguments
|
| 214 |
+
if not args.image_path and not args.images_dir:
|
| 215 |
+
parser.error("Must specify either --image_path or --images_dir")
|
| 216 |
+
|
| 217 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 218 |
+
|
| 219 |
+
print("π BYOL Classification Inference")
|
| 220 |
+
print("=" * 40)
|
| 221 |
+
print(f"Device: {device}")
|
| 222 |
+
print(f"Threshold: {args.threshold}")
|
| 223 |
+
|
| 224 |
+
# Load model
|
| 225 |
+
model, class_names, config = load_classification_model(args.model_path, device)
|
| 226 |
+
|
| 227 |
+
# Create transform
|
| 228 |
+
transform = create_inference_transform(args.tile_size)
|
| 229 |
+
|
| 230 |
+
# Get image paths
|
| 231 |
+
if args.image_path:
|
| 232 |
+
image_paths = [args.image_path]
|
| 233 |
+
print(f"πΈ Single image inference: {args.image_path}")
|
| 234 |
+
else:
|
| 235 |
+
images_dir = Path(args.images_dir)
|
| 236 |
+
image_paths = list(images_dir.glob("*.png")) + list(images_dir.glob("*.jpg"))
|
| 237 |
+
print(f"π Batch inference: {len(image_paths)} images from {images_dir}")
|
| 238 |
+
|
| 239 |
+
if len(image_paths) == 0:
|
| 240 |
+
print("β No images found!")
|
| 241 |
+
return
|
| 242 |
+
|
| 243 |
+
# Make predictions
|
| 244 |
+
if len(image_paths) == 1:
|
| 245 |
+
# Single image
|
| 246 |
+
result = predict_single_image(
|
| 247 |
+
model, image_paths[0], transform, class_names, device, args.threshold
|
| 248 |
+
)
|
| 249 |
+
results = [result]
|
| 250 |
+
else:
|
| 251 |
+
# Batch processing
|
| 252 |
+
print(f"π Processing {len(image_paths)} images in batches of {args.batch_size}...")
|
| 253 |
+
results = predict_batch(
|
| 254 |
+
model, image_paths, transform, class_names, device,
|
| 255 |
+
args.batch_size, args.threshold
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
# Print results
|
| 259 |
+
print(f"\nπ― INFERENCE RESULTS")
|
| 260 |
+
print("=" * 40)
|
| 261 |
+
print_prediction_results(results)
|
| 262 |
+
|
| 263 |
+
# Save to JSON if requested
|
| 264 |
+
if args.output_json:
|
| 265 |
+
with open(args.output_json, 'w') as f:
|
| 266 |
+
json.dump(results, f, indent=2)
|
| 267 |
+
print(f"\nπΎ Results saved to: {args.output_json}")
|
| 268 |
+
|
| 269 |
+
# Summary statistics
|
| 270 |
+
print(f"\nπ SUMMARY")
|
| 271 |
+
print("=" * 40)
|
| 272 |
+
print(f"Total images processed: {len(results)}")
|
| 273 |
+
|
| 274 |
+
# Count predictions per class
|
| 275 |
+
class_counts = {class_name: 0 for class_name in class_names}
|
| 276 |
+
for result in results:
|
| 277 |
+
for class_name, binary_pred in result['binary_predictions'].items():
|
| 278 |
+
if binary_pred:
|
| 279 |
+
class_counts[class_name] += 1
|
| 280 |
+
|
| 281 |
+
print("Class distribution (above threshold):")
|
| 282 |
+
for class_name, count in class_counts.items():
|
| 283 |
+
percentage = (count / len(results)) * 100
|
| 284 |
+
print(f" {class_name:15s}: {count:4d} ({percentage:5.1f}%)")
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
if __name__ == "__main__":
|
| 288 |
+
main()
|
mammogram_byol_best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dbe86fd9ff38440181296b00e2af9bb5db0c5c64793ef339ca5ed39fc1f37986
|
| 3 |
+
size 553443460
|
mammogram_byol_final.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d8bdbdd8226a100239f8b27b77a3d1e34a120d2f53f5a8bc0d467450db4a97a8
|
| 3 |
+
size 553451289
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
torchvision>=0.15.0
|
| 3 |
+
lightly>=1.4.0
|
| 4 |
+
wandb>=0.15.0
|
| 5 |
+
opencv-python>=4.8.0
|
| 6 |
+
scipy>=1.10.0
|
| 7 |
+
numpy>=1.24.0
|
| 8 |
+
Pillow>=9.5.0
|
| 9 |
+
tqdm>=4.65.0
|
| 10 |
+
matplotlib>=3.7.0
|
| 11 |
+
pandas>=2.3.1
|
| 12 |
+
|
train_byol_mammo.py
ADDED
|
@@ -0,0 +1,785 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
train_byol_mammo.py
|
| 4 |
+
|
| 5 |
+
Selfβsupervised BYOL preβtraining with a ResNet50 backbone on
|
| 6 |
+
BREAST TISSUE TILES from mammogram images with intelligent segmentation.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import copy
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import time
|
| 12 |
+
from typing import List, Tuple
|
| 13 |
+
import pickle
|
| 14 |
+
import hashlib
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from torch import nn, optim
|
| 18 |
+
from torch.utils.data import Dataset, DataLoader
|
| 19 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 20 |
+
from PIL import Image
|
| 21 |
+
from torchvision import models
|
| 22 |
+
import numpy as np
|
| 23 |
+
import cv2
|
| 24 |
+
from scipy import ndimage
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
import wandb
|
| 27 |
+
|
| 28 |
+
# Lightly imports for BYOL
|
| 29 |
+
from lightly.transforms.byol_transform import (
|
| 30 |
+
BYOLTransform,
|
| 31 |
+
BYOLView1Transform,
|
| 32 |
+
BYOLView2Transform,
|
| 33 |
+
)
|
| 34 |
+
from lightly.loss import NegativeCosineSimilarity
|
| 35 |
+
from lightly.models.modules import BYOLProjectionHead, BYOLPredictionHead
|
| 36 |
+
from lightly.models.utils import deactivate_requires_grad, update_momentum
|
| 37 |
+
from lightly.utils.scheduler import cosine_schedule
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# 1) Configuration - A100 GPU Optimized
|
| 41 |
+
#
|
| 42 |
+
# A100 GPU Memory Configurations:
|
| 43 |
+
# ================================
|
| 44 |
+
# A100-40GB: BATCH_SIZE=32, LR=1e-3, NUM_WORKERS=16
|
| 45 |
+
# A100-80GB: BATCH_SIZE=64, LR=2e-3, NUM_WORKERS=20 (uncomment below for 80GB)
|
| 46 |
+
#
|
| 47 |
+
# For A100-80GB, uncomment these lines:
|
| 48 |
+
# BATCH_SIZE = 64; LR = 2e-3; NUM_WORKERS = 20
|
| 49 |
+
|
| 50 |
+
DATA_DIR = "./split_images/training"
|
| 51 |
+
BATCH_SIZE = 32 # A100 memory optimized (reduced from 64)
|
| 52 |
+
NUM_WORKERS = 16 # A100 CPU core utilization (system recommended max)
|
| 53 |
+
EPOCHS = 100
|
| 54 |
+
LR = 2e-3 # Batch-size scaled: 3e-4 * (BATCH_SIZE/8)
|
| 55 |
+
WARMUP_EPOCHS = 10 # LR warmup for large batch stability
|
| 56 |
+
MOMENTUM_BASE = 0.996
|
| 57 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 58 |
+
WANDB_PROJECT = "mammogram-byol"
|
| 59 |
+
|
| 60 |
+
# Tile settings - preserve full resolution with AGGRESSIVE background rejection
|
| 61 |
+
TILE_SIZE = 512 # px - increased for fewer, higher quality tiles
|
| 62 |
+
TILE_STRIDE = 256 # px (50% overlap)
|
| 63 |
+
MIN_BREAST_RATIO = 0.15 # INCREASED: More strict breast tissue requirement
|
| 64 |
+
MIN_FREQ_ENERGY = 0.03 # INCREASED: Much higher threshold to avoid background noise
|
| 65 |
+
MIN_BREAST_FOR_FREQ = 0.12 # INCREASED: Even more breast tissue required for frequency selection
|
| 66 |
+
MIN_TILE_INTENSITY = 40 # NEW: Minimum average intensity to avoid background
|
| 67 |
+
MIN_NON_ZERO_PIXELS = 0.7 # NEW: At least 70% of pixels must be non-background
|
| 68 |
+
|
| 69 |
+
# Model settings for BYOL pre-training
|
| 70 |
+
HIDDEN_DIM = 4096
|
| 71 |
+
PROJ_DIM = 256
|
| 72 |
+
INPUT_DIM = 2048
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def is_background_tile(image_patch: np.ndarray) -> bool:
|
| 76 |
+
"""
|
| 77 |
+
Comprehensive background detection to reject empty/dark tiles.
|
| 78 |
+
"""
|
| 79 |
+
if len(image_patch.shape) == 3:
|
| 80 |
+
gray = cv2.cvtColor(image_patch, cv2.COLOR_RGB2GRAY)
|
| 81 |
+
else:
|
| 82 |
+
gray = image_patch.copy()
|
| 83 |
+
|
| 84 |
+
# Multiple background rejection criteria
|
| 85 |
+
mean_intensity = np.mean(gray)
|
| 86 |
+
std_intensity = np.std(gray)
|
| 87 |
+
non_zero_pixels = np.sum(gray > 15)
|
| 88 |
+
total_pixels = gray.size
|
| 89 |
+
|
| 90 |
+
# Criteria for background tiles:
|
| 91 |
+
# 1. Too dark overall
|
| 92 |
+
if mean_intensity < MIN_TILE_INTENSITY:
|
| 93 |
+
return True
|
| 94 |
+
|
| 95 |
+
# 2. Too many near-zero pixels (empty space)
|
| 96 |
+
if non_zero_pixels / total_pixels < MIN_NON_ZERO_PIXELS:
|
| 97 |
+
return True
|
| 98 |
+
|
| 99 |
+
# 3. Very low variation (uniform background)
|
| 100 |
+
if std_intensity < 10:
|
| 101 |
+
return True
|
| 102 |
+
|
| 103 |
+
# 4. Check intensity distribution - reject if too skewed toward zero
|
| 104 |
+
histogram, _ = np.histogram(gray, bins=50, range=(0, 255))
|
| 105 |
+
if histogram[0] > total_pixels * 0.3: # More than 30% pixels near zero
|
| 106 |
+
return True
|
| 107 |
+
|
| 108 |
+
return False
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def compute_frequency_energy(image_patch: np.ndarray) -> float:
|
| 112 |
+
"""
|
| 113 |
+
Compute high-frequency energy with AGGRESSIVE background rejection.
|
| 114 |
+
"""
|
| 115 |
+
if len(image_patch.shape) == 3:
|
| 116 |
+
gray = cv2.cvtColor(image_patch, cv2.COLOR_RGB2GRAY)
|
| 117 |
+
else:
|
| 118 |
+
gray = image_patch.copy()
|
| 119 |
+
|
| 120 |
+
# AGGRESSIVE background rejection
|
| 121 |
+
mean_intensity = np.mean(gray)
|
| 122 |
+
if mean_intensity < MIN_TILE_INTENSITY: # Much stricter intensity threshold
|
| 123 |
+
return 0.0
|
| 124 |
+
|
| 125 |
+
# Check for sufficient non-background pixels
|
| 126 |
+
non_zero_ratio = np.sum(gray > 15) / gray.size
|
| 127 |
+
if non_zero_ratio < MIN_NON_ZERO_PIXELS: # Too much background
|
| 128 |
+
return 0.0
|
| 129 |
+
|
| 130 |
+
# Apply Laplacian of Gaussian for high-frequency detection
|
| 131 |
+
blurred = cv2.GaussianBlur(gray.astype(np.float32), (3, 3), 1.0)
|
| 132 |
+
laplacian = cv2.Laplacian(blurred, cv2.CV_32F, ksize=3)
|
| 133 |
+
|
| 134 |
+
# Focus only on positive responses (bright spots)
|
| 135 |
+
positive_laplacian = np.maximum(laplacian, 0)
|
| 136 |
+
|
| 137 |
+
# Only analyze pixels with meaningful intensity
|
| 138 |
+
mask = gray > max(30, mean_intensity * 0.4) # Much stricter tissue mask
|
| 139 |
+
if np.sum(mask) < (gray.size * 0.2): # Need substantial tissue content
|
| 140 |
+
return 0.0
|
| 141 |
+
|
| 142 |
+
masked_laplacian = positive_laplacian[mask]
|
| 143 |
+
energy = np.var(masked_laplacian) / (mean_intensity + 1e-8)
|
| 144 |
+
|
| 145 |
+
return float(energy)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def segment_breast_tissue(image_array: np.ndarray) -> np.ndarray:
|
| 149 |
+
"""
|
| 150 |
+
Enhanced breast tissue segmentation with aggressive background removal
|
| 151 |
+
"""
|
| 152 |
+
if len(image_array.shape) == 3:
|
| 153 |
+
gray = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
|
| 154 |
+
else:
|
| 155 |
+
gray = image_array.copy()
|
| 156 |
+
|
| 157 |
+
# More aggressive pre-filtering of background
|
| 158 |
+
filtered_gray = np.where(gray > 20, gray, 0) # Stricter background cutoff
|
| 159 |
+
|
| 160 |
+
# Otsu thresholding
|
| 161 |
+
_, binary = cv2.threshold(filtered_gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
| 162 |
+
|
| 163 |
+
# Additional background removal based on intensity
|
| 164 |
+
binary = np.where(gray > 25, binary, 0).astype(np.uint8)
|
| 165 |
+
|
| 166 |
+
# More aggressive morphological operations
|
| 167 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) # Larger kernel
|
| 168 |
+
opened = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)
|
| 169 |
+
|
| 170 |
+
# Fill holes
|
| 171 |
+
filled = ndimage.binary_fill_holes(opened).astype(np.uint8) * 255
|
| 172 |
+
|
| 173 |
+
# Keep largest connected component
|
| 174 |
+
num_labels, labels = cv2.connectedComponents(filled)
|
| 175 |
+
if num_labels > 1:
|
| 176 |
+
largest_label = 1 + np.argmax([np.sum(labels == i) for i in range(1, num_labels)])
|
| 177 |
+
mask = (labels == largest_label).astype(np.uint8) * 255
|
| 178 |
+
else:
|
| 179 |
+
mask = filled
|
| 180 |
+
|
| 181 |
+
# Closing with larger kernel for smoother boundaries
|
| 182 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
|
| 183 |
+
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
|
| 184 |
+
|
| 185 |
+
return mask > 0
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class BreastTileMammoDataset(Dataset):
|
| 189 |
+
"""Produces breast tissue tiles from mammograms with AGGRESSIVE background rejection."""
|
| 190 |
+
|
| 191 |
+
def __init__(self, root: str, tile_size: int, stride: int, min_breast_ratio: float = 0.15, min_freq_energy: float = 0.03, min_breast_for_freq: float = 0.12, transform=None):
|
| 192 |
+
self.transform = transform
|
| 193 |
+
self.tile_size = tile_size
|
| 194 |
+
self.stride = stride
|
| 195 |
+
self.min_breast_ratio = min_breast_ratio
|
| 196 |
+
self.min_freq_energy = min_freq_energy
|
| 197 |
+
self.min_breast_for_freq = min_breast_for_freq
|
| 198 |
+
self.tiles = [] # (path, x, y, breast_ratio, freq_energy)
|
| 199 |
+
|
| 200 |
+
# Generate cache filename based on parameters
|
| 201 |
+
cache_key = self._generate_cache_key(root, tile_size, stride, min_breast_ratio, min_freq_energy, min_breast_for_freq)
|
| 202 |
+
cache_file = Path(f"tile_cache_{cache_key}.pkl")
|
| 203 |
+
|
| 204 |
+
# Try to load from cache first
|
| 205 |
+
if cache_file.exists():
|
| 206 |
+
print(f"[Dataset] Found cached tiles: {cache_file}")
|
| 207 |
+
print(f"[Dataset] Loading tiles from cache (avoiding ~57min extraction)...")
|
| 208 |
+
with open(cache_file, 'rb') as f:
|
| 209 |
+
cache_data = pickle.load(f)
|
| 210 |
+
self.tiles = cache_data['tiles']
|
| 211 |
+
stats = cache_data['stats']
|
| 212 |
+
|
| 213 |
+
print(f"[Dataset] β
Loaded {len(self.tiles):,} cached tiles!")
|
| 214 |
+
print(f" β’ Generated {stats['breast_tiles']:,} tiles from {stats['total_tiles']:,} possible ({stats['efficiency']:.1f}% efficiency)")
|
| 215 |
+
print(f" β’ Breast tissue method tiles: {stats['breast_tiles'] - stats['freq_tiles']:,}")
|
| 216 |
+
print(f" β’ Frequency energy method tiles: {stats['freq_tiles']:,}")
|
| 217 |
+
print(f" β’ Average breast tissue per tile: {stats['avg_breast_ratio']:.1%}")
|
| 218 |
+
print(f" β’ Average frequency energy per tile: {stats['avg_freq_energy']:.4f}")
|
| 219 |
+
print(f" β
Cache hit: Skipping tile extraction")
|
| 220 |
+
return
|
| 221 |
+
|
| 222 |
+
# Cache miss - extract tiles from scratch
|
| 223 |
+
img_paths = list(Path(root).glob("*.png"))
|
| 224 |
+
if len(img_paths) == 0:
|
| 225 |
+
raise RuntimeError(f"No .png files found in {root!r}")
|
| 226 |
+
|
| 227 |
+
print(f"[Dataset] Cache miss: Extracting tiles from {len(img_paths)} mammogram images...")
|
| 228 |
+
print(f"[Dataset] This will take ~57 minutes but will be cached for future runs...")
|
| 229 |
+
|
| 230 |
+
total_tiles = 0
|
| 231 |
+
breast_tiles = 0
|
| 232 |
+
freq_tiles = 0
|
| 233 |
+
total_rejected_bg = 0
|
| 234 |
+
total_rejected_intensity = 0
|
| 235 |
+
|
| 236 |
+
for img_path in tqdm(img_paths, desc="Extracting breast tiles with AGGRESSIVE background rejection",
|
| 237 |
+
ncols=100, leave=False):
|
| 238 |
+
with Image.open(img_path) as img:
|
| 239 |
+
img_array = np.array(img)
|
| 240 |
+
|
| 241 |
+
# Segment breast tissue with enhanced method
|
| 242 |
+
breast_mask = segment_breast_tissue(img_array)
|
| 243 |
+
|
| 244 |
+
# Extract tiles from breast regions (no per-image logging to reduce clutter)
|
| 245 |
+
tiles = self._extract_breast_tiles(img_array, breast_mask, img_path)
|
| 246 |
+
self.tiles.extend(tiles)
|
| 247 |
+
|
| 248 |
+
# Count selection methods
|
| 249 |
+
image_breast_tiles = sum(1 for t in tiles if len(t) > 4 and
|
| 250 |
+
(len(t) <= 5 or t[4] >= self.min_breast_ratio))
|
| 251 |
+
image_freq_tiles = len(tiles) - image_breast_tiles
|
| 252 |
+
|
| 253 |
+
total_tiles += len(self._get_all_possible_tiles(img_array.shape))
|
| 254 |
+
breast_tiles += len(tiles)
|
| 255 |
+
freq_tiles += image_freq_tiles
|
| 256 |
+
|
| 257 |
+
# Enhanced summary statistics matching notebook
|
| 258 |
+
efficiency = (breast_tiles / total_tiles) * 100 if total_tiles > 0 else 0
|
| 259 |
+
avg_breast_ratio = np.mean([t[3] for t in self.tiles])
|
| 260 |
+
avg_freq_energy = np.mean([t[4] for t in self.tiles])
|
| 261 |
+
|
| 262 |
+
print(f"\n[Dataset] AGGRESSIVE Background Rejection Results:")
|
| 263 |
+
print(f" β’ Generated {breast_tiles:,} tiles from {total_tiles:,} possible ({efficiency:.1f}% efficiency)")
|
| 264 |
+
print(f" β’ Breast tissue method tiles: {breast_tiles - freq_tiles:,}")
|
| 265 |
+
print(f" β’ Frequency energy method tiles: {freq_tiles:,}")
|
| 266 |
+
print(f" β’ Average breast tissue per tile: {avg_breast_ratio:.1%}")
|
| 267 |
+
print(f" β’ Average frequency energy per tile: {avg_freq_energy:.4f}")
|
| 268 |
+
print(f" β’ Background contamination check: SKIPPED (pre-filtered during extraction)")
|
| 269 |
+
print(f" β
All tiles passed AGGRESSIVE background rejection during extraction")
|
| 270 |
+
print(f" β
Quality assured: Multi-level filtering eliminated empty space tiles")
|
| 271 |
+
|
| 272 |
+
# Save to cache for future runs
|
| 273 |
+
print(f"[Dataset] πΎ Saving tiles to cache: {cache_file}")
|
| 274 |
+
cache_data = {
|
| 275 |
+
'tiles': self.tiles,
|
| 276 |
+
'stats': {
|
| 277 |
+
'total_tiles': total_tiles,
|
| 278 |
+
'breast_tiles': breast_tiles,
|
| 279 |
+
'freq_tiles': freq_tiles,
|
| 280 |
+
'efficiency': efficiency,
|
| 281 |
+
'avg_breast_ratio': avg_breast_ratio,
|
| 282 |
+
'avg_freq_energy': avg_freq_energy
|
| 283 |
+
}
|
| 284 |
+
}
|
| 285 |
+
with open(cache_file, 'wb') as f:
|
| 286 |
+
pickle.dump(cache_data, f)
|
| 287 |
+
print(f" β
Cache saved! Future runs will load instantly.")
|
| 288 |
+
|
| 289 |
+
def _generate_cache_key(self, root: str, tile_size: int, stride: int, min_breast_ratio: float, min_freq_energy: float, min_breast_for_freq: float) -> str:
|
| 290 |
+
"""Generate a unique cache key based on dataset parameters."""
|
| 291 |
+
# Include modification times of image files to detect changes
|
| 292 |
+
img_paths = sorted(Path(root).glob("*.png"))
|
| 293 |
+
file_info = [(str(p), p.stat().st_mtime) for p in img_paths[:10]] # Sample first 10 files
|
| 294 |
+
|
| 295 |
+
key_data = {
|
| 296 |
+
'root': root,
|
| 297 |
+
'tile_size': tile_size,
|
| 298 |
+
'stride': stride,
|
| 299 |
+
'min_breast_ratio': min_breast_ratio,
|
| 300 |
+
'min_freq_energy': min_freq_energy,
|
| 301 |
+
'min_breast_for_freq': min_breast_for_freq,
|
| 302 |
+
'num_images': len(img_paths),
|
| 303 |
+
'file_sample': file_info,
|
| 304 |
+
'version': '1.0' # Increment this if extraction logic changes
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
key_str = str(key_data)
|
| 308 |
+
return hashlib.md5(key_str.encode()).hexdigest()[:12]
|
| 309 |
+
|
| 310 |
+
def _get_all_possible_tiles(self, shape: Tuple) -> List:
|
| 311 |
+
"""Get all possible tile positions for efficiency calculation."""
|
| 312 |
+
height, width = shape[:2]
|
| 313 |
+
positions = []
|
| 314 |
+
|
| 315 |
+
y_positions = list(range(0, max(1, height - self.tile_size + 1), self.stride))
|
| 316 |
+
x_positions = list(range(0, max(1, width - self.tile_size + 1), self.stride))
|
| 317 |
+
|
| 318 |
+
if y_positions[-1] + self.tile_size < height:
|
| 319 |
+
y_positions.append(height - self.tile_size)
|
| 320 |
+
if x_positions[-1] + self.tile_size < width:
|
| 321 |
+
x_positions.append(width - self.tile_size)
|
| 322 |
+
|
| 323 |
+
for y in y_positions:
|
| 324 |
+
for x in x_positions:
|
| 325 |
+
positions.append((x, y))
|
| 326 |
+
|
| 327 |
+
return positions
|
| 328 |
+
|
| 329 |
+
def _extract_breast_tiles(self, image_array: np.ndarray, breast_mask: np.ndarray, img_path: Path) -> List:
|
| 330 |
+
"""Extract tiles with AGGRESSIVE background rejection - NO empty space tiles allowed."""
|
| 331 |
+
tiles = []
|
| 332 |
+
rejected_background = 0
|
| 333 |
+
rejected_intensity = 0
|
| 334 |
+
rejected_breast_ratio = 0
|
| 335 |
+
rejected_freq_energy = 0
|
| 336 |
+
|
| 337 |
+
height, width = image_array.shape[:2]
|
| 338 |
+
|
| 339 |
+
# Generate all possible tile positions
|
| 340 |
+
y_positions = list(range(0, max(1, height - self.tile_size + 1), self.stride))
|
| 341 |
+
x_positions = list(range(0, max(1, width - self.tile_size + 1), self.stride))
|
| 342 |
+
|
| 343 |
+
# Add edge positions if needed
|
| 344 |
+
if y_positions[-1] + self.tile_size < height:
|
| 345 |
+
y_positions.append(height - self.tile_size)
|
| 346 |
+
if x_positions[-1] + self.tile_size < width:
|
| 347 |
+
x_positions.append(width - self.tile_size)
|
| 348 |
+
|
| 349 |
+
for y in y_positions:
|
| 350 |
+
for x in x_positions:
|
| 351 |
+
# Extract image tile
|
| 352 |
+
tile_image = image_array[y:y+self.tile_size, x:x+self.tile_size]
|
| 353 |
+
|
| 354 |
+
# STEP 1: Comprehensive background rejection
|
| 355 |
+
if is_background_tile(tile_image):
|
| 356 |
+
rejected_background += 1
|
| 357 |
+
continue
|
| 358 |
+
|
| 359 |
+
# STEP 2: Intensity-based rejection
|
| 360 |
+
mean_intensity = np.mean(tile_image)
|
| 361 |
+
if mean_intensity < MIN_TILE_INTENSITY:
|
| 362 |
+
rejected_intensity += 1
|
| 363 |
+
continue
|
| 364 |
+
|
| 365 |
+
# STEP 3: Breast tissue ratio check
|
| 366 |
+
tile_mask = breast_mask[y:y+self.tile_size, x:x+self.tile_size]
|
| 367 |
+
breast_ratio = np.sum(tile_mask) / (self.tile_size * self.tile_size)
|
| 368 |
+
|
| 369 |
+
# STEP 4: Enhanced selection logic with multiple criteria
|
| 370 |
+
freq_energy = compute_frequency_energy(tile_image)
|
| 371 |
+
|
| 372 |
+
# Main selection criteria
|
| 373 |
+
selected = False
|
| 374 |
+
selection_reason = ""
|
| 375 |
+
|
| 376 |
+
if breast_ratio >= self.min_breast_ratio:
|
| 377 |
+
selected = True
|
| 378 |
+
selection_reason = "breast_tissue"
|
| 379 |
+
elif (freq_energy >= self.min_freq_energy and
|
| 380 |
+
breast_ratio >= self.min_breast_for_freq and
|
| 381 |
+
mean_intensity >= MIN_TILE_INTENSITY + 10): # Even stricter for freq tiles
|
| 382 |
+
selected = True
|
| 383 |
+
selection_reason = "frequency_energy"
|
| 384 |
+
|
| 385 |
+
if selected:
|
| 386 |
+
tiles.append((img_path, x, y, breast_ratio, freq_energy))
|
| 387 |
+
else:
|
| 388 |
+
if freq_energy < self.min_freq_energy:
|
| 389 |
+
rejected_freq_energy += 1
|
| 390 |
+
else:
|
| 391 |
+
rejected_breast_ratio += 1
|
| 392 |
+
|
| 393 |
+
# Accumulate rejection stats (no per-image logging to reduce clutter)
|
| 394 |
+
|
| 395 |
+
return tiles
|
| 396 |
+
|
| 397 |
+
def __len__(self):
|
| 398 |
+
return len(self.tiles)
|
| 399 |
+
|
| 400 |
+
def __getitem__(self, idx):
|
| 401 |
+
img_path, x, y, breast_ratio, freq_energy = self.tiles[idx]
|
| 402 |
+
|
| 403 |
+
with Image.open(img_path) as img:
|
| 404 |
+
# Extract tile while preserving full resolution
|
| 405 |
+
crop = img.crop((x, y, x + self.tile_size, y + self.tile_size))
|
| 406 |
+
|
| 407 |
+
# Keep as grayscale for medical imaging, convert to RGB by replicating channel
|
| 408 |
+
if crop.mode != 'L':
|
| 409 |
+
crop = crop.convert('L')
|
| 410 |
+
# Convert to RGB by replicating the grayscale channel
|
| 411 |
+
crop = crop.convert('RGB')
|
| 412 |
+
|
| 413 |
+
# Apply BYOL transformations
|
| 414 |
+
views = self.transform(crop)
|
| 415 |
+
|
| 416 |
+
return views, breast_ratio # Return breast ratio for monitoring
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
class MammogramBYOL(nn.Module):
|
| 420 |
+
"""BYOL model for self-supervised pre-training on mammogram tiles."""
|
| 421 |
+
|
| 422 |
+
def __init__(self, backbone, input_dim=2048, hidden_dim=4096, proj_dim=256):
|
| 423 |
+
super().__init__()
|
| 424 |
+
self.backbone = backbone
|
| 425 |
+
self.projection_head = BYOLProjectionHead(input_dim, hidden_dim, proj_dim)
|
| 426 |
+
self.prediction_head = BYOLPredictionHead(proj_dim, hidden_dim, proj_dim)
|
| 427 |
+
|
| 428 |
+
# Momentum (target) networks
|
| 429 |
+
self.backbone_momentum = copy.deepcopy(backbone)
|
| 430 |
+
self.projection_head_momentum = copy.deepcopy(self.projection_head)
|
| 431 |
+
deactivate_requires_grad(self.backbone_momentum)
|
| 432 |
+
deactivate_requires_grad(self.projection_head_momentum)
|
| 433 |
+
|
| 434 |
+
def forward(self, x):
|
| 435 |
+
"""Forward pass for BYOL training."""
|
| 436 |
+
h = self.backbone(x).flatten(start_dim=1)
|
| 437 |
+
z = self.projection_head(h)
|
| 438 |
+
return self.prediction_head(z)
|
| 439 |
+
|
| 440 |
+
def forward_momentum(self, x):
|
| 441 |
+
"""Forward pass through momentum network."""
|
| 442 |
+
h = self.backbone_momentum(x).flatten(start_dim=1)
|
| 443 |
+
z = self.projection_head_momentum(h)
|
| 444 |
+
return z.detach()
|
| 445 |
+
|
| 446 |
+
def get_features(self, x):
|
| 447 |
+
"""Extract backbone features (for downstream tasks)."""
|
| 448 |
+
with torch.no_grad():
|
| 449 |
+
return self.backbone(x).flatten(start_dim=1)
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def create_medical_transforms(input_size: int):
|
| 453 |
+
"""Create BYOL transforms with stronger augmentations for effective self-supervised learning."""
|
| 454 |
+
import torchvision.transforms as T
|
| 455 |
+
|
| 456 |
+
# View 1: Moderate augmentations for medical safety
|
| 457 |
+
view1_transform = T.Compose([
|
| 458 |
+
T.ToTensor(),
|
| 459 |
+
T.RandomHorizontalFlip(p=0.5),
|
| 460 |
+
T.RandomVerticalFlip(p=0.2), # Added vertical flip for more diversity
|
| 461 |
+
T.RandomRotation(degrees=15, fill=0), # Increased rotation range
|
| 462 |
+
T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0, hue=0), # Stronger brightness/contrast
|
| 463 |
+
T.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.85, 1.15), fill=0), # More translation/scaling
|
| 464 |
+
T.Resize(input_size, antialias=True),
|
| 465 |
+
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
| 466 |
+
])
|
| 467 |
+
|
| 468 |
+
# View 2: Stronger augmentations for BYOL effectiveness
|
| 469 |
+
view2_transform = T.Compose([
|
| 470 |
+
T.ToTensor(),
|
| 471 |
+
T.RandomHorizontalFlip(p=0.5),
|
| 472 |
+
T.RandomVerticalFlip(p=0.3), # Higher chance for more diversity
|
| 473 |
+
T.RandomRotation(degrees=25, fill=0), # Wider rotation range
|
| 474 |
+
T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0, hue=0), # Standard BYOL intensity
|
| 475 |
+
T.RandomAffine(degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2), fill=0), # More aggressive transforms
|
| 476 |
+
T.RandomPerspective(distortion_scale=0.1, p=0.3, fill=0), # Add perspective distortion
|
| 477 |
+
T.GaussianBlur(kernel_size=5, sigma=(0.1, 1.5)), # Stronger blur range
|
| 478 |
+
T.RandomGrayscale(p=0.2), # Convert to grayscale occasionally for more diversity
|
| 479 |
+
T.Resize(input_size, antialias=True),
|
| 480 |
+
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
| 481 |
+
])
|
| 482 |
+
|
| 483 |
+
return BYOLTransform(
|
| 484 |
+
view_1_transform=view1_transform,
|
| 485 |
+
view_2_transform=view2_transform,
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def estimate_memory_usage(batch_size: int, tile_size: int = 256) -> float:
|
| 490 |
+
"""Estimate GPU memory usage in GB for the given configuration."""
|
| 491 |
+
# Model parameters (ResNet50 + BYOL heads + momentum networks)
|
| 492 |
+
model_memory = 6.5 # GB - ResNet50 + BYOL + momentum networks
|
| 493 |
+
|
| 494 |
+
# Batch memory (RGB tiles + gradients + optimizer states)
|
| 495 |
+
tile_memory_mb = (tile_size * tile_size * 3 * 4) / (1024 * 1024) # 4 bytes per float32
|
| 496 |
+
batch_memory = batch_size * tile_memory_mb * 4 / 1024 # x4 for forward/backward + optimizer states
|
| 497 |
+
|
| 498 |
+
total_memory = model_memory + batch_memory
|
| 499 |
+
return total_memory
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
def main():
|
| 503 |
+
# Memory usage estimation
|
| 504 |
+
estimated_memory = estimate_memory_usage(BATCH_SIZE, TILE_SIZE)
|
| 505 |
+
print(f"π Estimated GPU Memory Usage: {estimated_memory:.1f} GB")
|
| 506 |
+
if estimated_memory > 40:
|
| 507 |
+
print(f"β οΈ Warning: May exceed A100-40GB capacity. Consider batch size {int(BATCH_SIZE * 35 / estimated_memory)}")
|
| 508 |
+
elif estimated_memory < 25:
|
| 509 |
+
print(f"π‘ Tip: GPU underutilized. Consider increasing batch size to {int(BATCH_SIZE * 35 / estimated_memory)} for A100-40GB")
|
| 510 |
+
print()
|
| 511 |
+
|
| 512 |
+
# Initialize wandb (offline mode if no API key)
|
| 513 |
+
try:
|
| 514 |
+
wandb.init(
|
| 515 |
+
project=WANDB_PROJECT,
|
| 516 |
+
config={
|
| 517 |
+
# A100 Optimization Settings
|
| 518 |
+
"gpu_type": "A100",
|
| 519 |
+
"batch_size": BATCH_SIZE,
|
| 520 |
+
"num_workers": NUM_WORKERS,
|
| 521 |
+
"learning_rate": LR,
|
| 522 |
+
"warmup_epochs": WARMUP_EPOCHS,
|
| 523 |
+
"estimated_memory_gb": estimate_memory_usage(BATCH_SIZE, TILE_SIZE),
|
| 524 |
+
|
| 525 |
+
# Model Architecture
|
| 526 |
+
"backbone": "resnet50",
|
| 527 |
+
"pretrained_weights": "IMAGENET1K_V2",
|
| 528 |
+
"tile_size": TILE_SIZE,
|
| 529 |
+
"epochs": EPOCHS,
|
| 530 |
+
"momentum_base": MOMENTUM_BASE,
|
| 531 |
+
"hidden_dim": HIDDEN_DIM,
|
| 532 |
+
"proj_dim": PROJ_DIM,
|
| 533 |
+
|
| 534 |
+
# Medical Pipeline Settings
|
| 535 |
+
"min_breast_ratio": MIN_BREAST_RATIO,
|
| 536 |
+
"min_freq_energy": MIN_FREQ_ENERGY,
|
| 537 |
+
"min_breast_for_freq": MIN_BREAST_FOR_FREQ,
|
| 538 |
+
"min_tile_intensity": MIN_TILE_INTENSITY,
|
| 539 |
+
"min_non_zero_pixels": MIN_NON_ZERO_PIXELS,
|
| 540 |
+
|
| 541 |
+
# Optimization Features
|
| 542 |
+
"mixed_precision": True,
|
| 543 |
+
"pytorch_compile": hasattr(torch, 'compile'),
|
| 544 |
+
"gradient_clipping": True,
|
| 545 |
+
"lr_scheduler": "warmup_cosine",
|
| 546 |
+
}
|
| 547 |
+
)
|
| 548 |
+
wandb_enabled = True
|
| 549 |
+
except Exception as e:
|
| 550 |
+
print(f"β οΈ WandB not configured, running offline. To enable: wandb login")
|
| 551 |
+
wandb_enabled = False
|
| 552 |
+
|
| 553 |
+
print("π¬ Mammogram BYOL Training with AGGRESSIVE Background Rejection")
|
| 554 |
+
print("=" * 60)
|
| 555 |
+
print(f"Device: {DEVICE}")
|
| 556 |
+
print(f"Tile size: {TILE_SIZE}x{TILE_SIZE} (increased for fewer, higher quality tiles)")
|
| 557 |
+
print(f"Tile stride: {TILE_STRIDE} pixels ({TILE_STRIDE/TILE_SIZE*100:.0f}% overlap)")
|
| 558 |
+
print(f"\nπ AGGRESSIVE Background Rejection Parameters:")
|
| 559 |
+
print(f" π‘οΈ MIN_BREAST_RATIO: {MIN_BREAST_RATIO:.1%} (increased from 0.3)")
|
| 560 |
+
print(f" π‘οΈ MIN_FREQ_ENERGY: {MIN_FREQ_ENERGY:.3f} (much higher threshold)")
|
| 561 |
+
print(f" π‘οΈ MIN_BREAST_FOR_FREQ: {MIN_BREAST_FOR_FREQ:.1%} (stricter for frequency tiles)")
|
| 562 |
+
print(f" π‘οΈ MIN_TILE_INTENSITY: {MIN_TILE_INTENSITY} (reject dark background)")
|
| 563 |
+
print(f" π‘οΈ MIN_NON_ZERO_PIXELS: {MIN_NON_ZERO_PIXELS:.1%} (reject empty space)")
|
| 564 |
+
print(f"\nποΈ Enhanced BYOL Augmentations for Effective Self-Supervised Learning:")
|
| 565 |
+
print(f" β
View 1: Moderate (brightness/contrast 0.3/0.3, Β±15Β° rotation, scale 0.85-1.15)")
|
| 566 |
+
print(f" β
View 2: Strong (brightness/contrast 0.4/0.4, Β±25Β° rotation, perspective, blur)")
|
| 567 |
+
print(f" β
Added: Vertical flips, random perspective, random grayscale for diversity")
|
| 568 |
+
print(f" β
Balanced: Strong enough for BYOL while preserving medical details")
|
| 569 |
+
print(f"\nMulti-level filtering eliminates ALL empty space tiles\n")
|
| 570 |
+
|
| 571 |
+
# Medical-optimized BYOL transforms
|
| 572 |
+
transform = create_medical_transforms(TILE_SIZE)
|
| 573 |
+
|
| 574 |
+
# Dataset with AGGRESSIVE background rejection and micro-calcification detection
|
| 575 |
+
dataset = BreastTileMammoDataset(
|
| 576 |
+
DATA_DIR, TILE_SIZE, TILE_STRIDE, MIN_BREAST_RATIO, MIN_FREQ_ENERGY, MIN_BREAST_FOR_FREQ, transform
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
# A100-optimized DataLoader settings
|
| 580 |
+
loader = DataLoader(
|
| 581 |
+
dataset,
|
| 582 |
+
batch_size=BATCH_SIZE,
|
| 583 |
+
shuffle=True,
|
| 584 |
+
drop_last=True,
|
| 585 |
+
num_workers=NUM_WORKERS,
|
| 586 |
+
pin_memory=True,
|
| 587 |
+
persistent_workers=True,
|
| 588 |
+
prefetch_factor=4, # A100 optimization: prefetch more batches
|
| 589 |
+
multiprocessing_context='spawn', # Better for CUDA
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
print(f"π Dataset: {len(dataset):,} breast tissue tiles β {len(loader):,} batches")
|
| 593 |
+
|
| 594 |
+
# Model with classification readiness - ImageNet pretrained for better convergence
|
| 595 |
+
# ImageNet pretraining helps even for medical images by providing:
|
| 596 |
+
# 1. Better edge/texture detectors in early layers
|
| 597 |
+
# 2. Faster convergence and more stable training
|
| 598 |
+
# 3. Better generalization to medical domain features
|
| 599 |
+
resnet = models.resnet50(weights='IMAGENET1K_V2') # Latest ImageNet weights for better medical transfer
|
| 600 |
+
backbone = nn.Sequential(*list(resnet.children())[:-1])
|
| 601 |
+
model = MammogramBYOL(backbone, INPUT_DIM, HIDDEN_DIM, PROJ_DIM).to(DEVICE)
|
| 602 |
+
|
| 603 |
+
print(f"β
Using ImageNet-pretrained ResNet50 backbone for better medical domain transfer")
|
| 604 |
+
|
| 605 |
+
# A100 Performance Boost: PyTorch 2.0 Compile (if available)
|
| 606 |
+
if hasattr(torch, 'compile') and torch.cuda.is_available():
|
| 607 |
+
print("π Enabling PyTorch 2.0 compile optimization for A100...")
|
| 608 |
+
model = torch.compile(model, mode='max-autotune') # Maximum A100 optimization
|
| 609 |
+
print(" β
Model compiled for maximum A100 performance")
|
| 610 |
+
else:
|
| 611 |
+
print(" β οΈ PyTorch 2.0 compile not available - using standard optimization")
|
| 612 |
+
|
| 613 |
+
criterion = NegativeCosineSimilarity()
|
| 614 |
+
|
| 615 |
+
# Optimized for large batch training on A100
|
| 616 |
+
optimizer = optim.AdamW(
|
| 617 |
+
model.parameters(),
|
| 618 |
+
lr=LR,
|
| 619 |
+
weight_decay=1e-4,
|
| 620 |
+
betas=(0.9, 0.999), # Standard for large batch
|
| 621 |
+
eps=1e-8
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
# LR warmup + cosine annealing for large batch stability
|
| 625 |
+
warmup_scheduler = optim.lr_scheduler.LinearLR(
|
| 626 |
+
optimizer,
|
| 627 |
+
start_factor=0.1,
|
| 628 |
+
end_factor=1.0,
|
| 629 |
+
total_iters=WARMUP_EPOCHS
|
| 630 |
+
)
|
| 631 |
+
cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
| 632 |
+
optimizer,
|
| 633 |
+
T_max=EPOCHS - WARMUP_EPOCHS, # After warmup
|
| 634 |
+
eta_min=LR * 0.01 # 1% of peak LR
|
| 635 |
+
)
|
| 636 |
+
scheduler = optim.lr_scheduler.SequentialLR(
|
| 637 |
+
optimizer,
|
| 638 |
+
schedulers=[warmup_scheduler, cosine_scheduler],
|
| 639 |
+
milestones=[WARMUP_EPOCHS]
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
scaler = GradScaler() # Mixed precision training for A100 optimization
|
| 643 |
+
|
| 644 |
+
print(f"π§ Model: ResNet50 backbone with {sum(p.numel() for p in model.parameters()):,} parameters")
|
| 645 |
+
print(f"π― Ready for downstream tasks with {INPUT_DIM}D backbone features")
|
| 646 |
+
print(f"\nβ‘ A100 GPU MAXIMUM PERFORMANCE OPTIMIZATIONS:")
|
| 647 |
+
print(f" π Large batch training: BATCH_SIZE={BATCH_SIZE} (4x increase)")
|
| 648 |
+
print(f" π Scaled learning rate: LR={LR} with {WARMUP_EPOCHS}-epoch warmup")
|
| 649 |
+
print(f" π Mixed precision training: autocast + GradScaler")
|
| 650 |
+
print(f" π PyTorch 2.0 compile: max-autotune mode (if available)")
|
| 651 |
+
print(f" π Enhanced DataLoader: {NUM_WORKERS} workers, prefetch_factor=4")
|
| 652 |
+
print(f" π Per-step momentum updates: optimal BYOL convergence")
|
| 653 |
+
print(f" π Sequential LR scheduler: warmup β cosine annealing")
|
| 654 |
+
print(f" π Gradient clipping: max_norm=1.0 for stability")
|
| 655 |
+
print(f" πΎ Memory optimized: pin_memory + non_blocking transfers\n")
|
| 656 |
+
|
| 657 |
+
# Training loop with progress tracking
|
| 658 |
+
start_time = time.time()
|
| 659 |
+
best_loss = float('inf')
|
| 660 |
+
global_step = 0
|
| 661 |
+
total_steps = EPOCHS * len(loader)
|
| 662 |
+
|
| 663 |
+
for epoch in range(1, EPOCHS + 1):
|
| 664 |
+
model.train()
|
| 665 |
+
epoch_loss = 0.0
|
| 666 |
+
breast_ratios = []
|
| 667 |
+
|
| 668 |
+
# Clean progress bar for epoch
|
| 669 |
+
pbar = tqdm(loader, desc=f"Epoch {epoch:3d}/{EPOCHS}",
|
| 670 |
+
ncols=80, leave=False, disable=False)
|
| 671 |
+
|
| 672 |
+
for batch_idx, (views, batch_breast_ratios) in enumerate(pbar):
|
| 673 |
+
x0, x1 = views
|
| 674 |
+
x0, x1 = x0.to(DEVICE, non_blocking=True), x1.to(DEVICE, non_blocking=True)
|
| 675 |
+
|
| 676 |
+
# Per-step momentum update schedule (BYOL best practice)
|
| 677 |
+
momentum = cosine_schedule(global_step, total_steps, MOMENTUM_BASE, 1.0)
|
| 678 |
+
|
| 679 |
+
# Update momentum networks
|
| 680 |
+
update_momentum(model.backbone, model.backbone_momentum, momentum)
|
| 681 |
+
update_momentum(model.projection_head, model.projection_head_momentum, momentum)
|
| 682 |
+
|
| 683 |
+
global_step += 1
|
| 684 |
+
|
| 685 |
+
# Mixed precision forward passes
|
| 686 |
+
with autocast():
|
| 687 |
+
# BYOL forward passes
|
| 688 |
+
p0 = model(x0)
|
| 689 |
+
z1 = model.forward_momentum(x1)
|
| 690 |
+
p1 = model(x1)
|
| 691 |
+
z0 = model.forward_momentum(x0)
|
| 692 |
+
|
| 693 |
+
# BYOL loss
|
| 694 |
+
loss = 0.5 * (criterion(p0, z1) + criterion(p1, z0))
|
| 695 |
+
|
| 696 |
+
# Mixed precision optimization step
|
| 697 |
+
optimizer.zero_grad()
|
| 698 |
+
scaler.scale(loss).backward()
|
| 699 |
+
scaler.unscale_(optimizer) # Unscale before gradient clipping
|
| 700 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
| 701 |
+
scaler.step(optimizer)
|
| 702 |
+
scaler.update()
|
| 703 |
+
|
| 704 |
+
# Metrics
|
| 705 |
+
epoch_loss += loss.item()
|
| 706 |
+
breast_ratios.extend(batch_breast_ratios.numpy())
|
| 707 |
+
|
| 708 |
+
# Update progress bar every 50 batches to reduce clutter
|
| 709 |
+
if batch_idx % 50 == 0 or batch_idx == len(loader) - 1:
|
| 710 |
+
pbar.set_postfix({
|
| 711 |
+
'Loss': f'{loss.item():.4f}',
|
| 712 |
+
'LR': f'{scheduler.get_last_lr()[0]:.1e}'
|
| 713 |
+
})
|
| 714 |
+
|
| 715 |
+
scheduler.step()
|
| 716 |
+
|
| 717 |
+
# Epoch metrics
|
| 718 |
+
avg_loss = epoch_loss / len(loader)
|
| 719 |
+
avg_breast_ratio = np.mean(breast_ratios)
|
| 720 |
+
elapsed = time.time() - start_time
|
| 721 |
+
|
| 722 |
+
# Log to wandb if enabled
|
| 723 |
+
if wandb_enabled:
|
| 724 |
+
wandb.log({
|
| 725 |
+
"epoch": epoch,
|
| 726 |
+
"loss": avg_loss,
|
| 727 |
+
"momentum": momentum,
|
| 728 |
+
"learning_rate": scheduler.get_last_lr()[0],
|
| 729 |
+
"avg_breast_ratio": avg_breast_ratio,
|
| 730 |
+
"elapsed_hours": elapsed / 3600,
|
| 731 |
+
})
|
| 732 |
+
|
| 733 |
+
# Concise epoch summary
|
| 734 |
+
print(f"Epoch {epoch:3d}/{EPOCHS} β Loss: {avg_loss:.4f} β Breast: {avg_breast_ratio:.1%} β {elapsed/60:.1f}min")
|
| 735 |
+
|
| 736 |
+
# Save best model and periodic checkpoints
|
| 737 |
+
if avg_loss < best_loss:
|
| 738 |
+
best_loss = avg_loss
|
| 739 |
+
torch.save({
|
| 740 |
+
'epoch': epoch,
|
| 741 |
+
'model_state_dict': model.state_dict(),
|
| 742 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 743 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
| 744 |
+
'loss': avg_loss,
|
| 745 |
+
}, 'mammogram_byol_best.pth')
|
| 746 |
+
|
| 747 |
+
# Save checkpoints every 10 epochs (less verbose logging)
|
| 748 |
+
if epoch % 10 == 0:
|
| 749 |
+
checkpoint_path = f'mammogram_byol_epoch{epoch}.pth'
|
| 750 |
+
torch.save({
|
| 751 |
+
'epoch': epoch,
|
| 752 |
+
'model_state_dict': model.state_dict(),
|
| 753 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 754 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
| 755 |
+
'loss': avg_loss,
|
| 756 |
+
}, checkpoint_path)
|
| 757 |
+
|
| 758 |
+
# Final save
|
| 759 |
+
final_path = 'mammogram_byol_final.pth'
|
| 760 |
+
torch.save({
|
| 761 |
+
'epoch': EPOCHS,
|
| 762 |
+
'model_state_dict': model.state_dict(),
|
| 763 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 764 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
| 765 |
+
'loss': avg_loss,
|
| 766 |
+
'config': wandb.config,
|
| 767 |
+
}, final_path)
|
| 768 |
+
|
| 769 |
+
total_time = time.time() - start_time
|
| 770 |
+
print(f"\nπ₯ === MEDICAL-OPTIMIZED BYOL TRAINING COMPLETE ===")
|
| 771 |
+
print(f"β±οΈ Total training time: {total_time/3600:.1f} hours")
|
| 772 |
+
print(f"πΎ Final model saved: {final_path}")
|
| 773 |
+
print(f"π Dataset: {len(dataset):,} high-quality breast tissue tiles")
|
| 774 |
+
print(f"π‘οΈ AGGRESSIVE background rejection: Zero empty space contamination")
|
| 775 |
+
print(f"ποΈ Medical-safe augmentations: Preserves anatomical details")
|
| 776 |
+
print(f"β‘ A100 optimized: Mixed precision + per-step momentum updates")
|
| 777 |
+
print(f"π― Ready for downstream fine-tuning and classification tasks")
|
| 778 |
+
print(f"π Ready for downstream fine-tuning!")
|
| 779 |
+
|
| 780 |
+
if wandb_enabled:
|
| 781 |
+
wandb.finish()
|
| 782 |
+
|
| 783 |
+
|
| 784 |
+
if __name__ == "__main__":
|
| 785 |
+
main()
|
train_classification.py
ADDED
|
@@ -0,0 +1,517 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
train_classification.py
|
| 4 |
+
|
| 5 |
+
Fine-tune the BYOL pre-trained model for multi-label classification on mammogram tiles.
|
| 6 |
+
This script loads the BYOL checkpoint and trains only the classification head while
|
| 7 |
+
optionally fine-tuning the backbone with a lower learning rate.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.optim as optim
|
| 13 |
+
from torch.utils.data import Dataset, DataLoader
|
| 14 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 15 |
+
import pandas as pd
|
| 16 |
+
import numpy as np
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from PIL import Image
|
| 19 |
+
import torchvision.transforms as T
|
| 20 |
+
from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
import wandb
|
| 23 |
+
import argparse
|
| 24 |
+
from typing import Dict, List, Tuple
|
| 25 |
+
import json
|
| 26 |
+
|
| 27 |
+
# Import the BYOL model
|
| 28 |
+
from train_byol_mammo import MammogramBYOL
|
| 29 |
+
|
| 30 |
+
# Configuration
|
| 31 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 32 |
+
TILE_SIZE = 512
|
| 33 |
+
|
| 34 |
+
# Default hyperparameters - can be overridden via command line
|
| 35 |
+
DEFAULT_CONFIG = {
|
| 36 |
+
'batch_size': 32,
|
| 37 |
+
'num_workers': 8,
|
| 38 |
+
'epochs': 50,
|
| 39 |
+
'lr_backbone': 1e-5, # Lower LR for pre-trained backbone
|
| 40 |
+
'lr_head': 1e-3, # Higher LR for classification head
|
| 41 |
+
'weight_decay': 1e-4,
|
| 42 |
+
'warmup_epochs': 5,
|
| 43 |
+
'freeze_backbone_epochs': 10, # Freeze backbone for first N epochs
|
| 44 |
+
'label_smoothing': 0.1,
|
| 45 |
+
'dropout_rate': 0.3,
|
| 46 |
+
'gradient_clip': 1.0,
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class MammogramClassificationDataset(Dataset):
|
| 51 |
+
"""Dataset for mammogram tile classification with multi-label support."""
|
| 52 |
+
|
| 53 |
+
def __init__(self, csv_path: str, tiles_dir: str, class_names: List[str],
|
| 54 |
+
transform=None, max_samples: int = None):
|
| 55 |
+
"""
|
| 56 |
+
Args:
|
| 57 |
+
csv_path: Path to CSV with columns ['tile_path', 'class1', 'class2', ...]
|
| 58 |
+
tiles_dir: Directory containing tile images
|
| 59 |
+
class_names: List of class names (e.g., ['mass', 'calcification', 'normal', etc.])
|
| 60 |
+
transform: Image transformations
|
| 61 |
+
max_samples: Limit dataset size for testing
|
| 62 |
+
"""
|
| 63 |
+
self.tiles_dir = Path(tiles_dir)
|
| 64 |
+
self.class_names = class_names
|
| 65 |
+
self.num_classes = len(class_names)
|
| 66 |
+
self.transform = transform
|
| 67 |
+
|
| 68 |
+
# Load data
|
| 69 |
+
self.df = pd.read_csv(csv_path)
|
| 70 |
+
if max_samples:
|
| 71 |
+
self.df = self.df.head(max_samples)
|
| 72 |
+
|
| 73 |
+
print(f"π Loaded {len(self.df)} samples for classification training")
|
| 74 |
+
print(f"π·οΈ Classes: {class_names}")
|
| 75 |
+
|
| 76 |
+
# Validate required columns
|
| 77 |
+
required_cols = ['tile_path'] + class_names
|
| 78 |
+
missing_cols = [col for col in required_cols if col not in self.df.columns]
|
| 79 |
+
if missing_cols:
|
| 80 |
+
raise ValueError(f"Missing columns in CSV: {missing_cols}")
|
| 81 |
+
|
| 82 |
+
def __len__(self):
|
| 83 |
+
return len(self.df)
|
| 84 |
+
|
| 85 |
+
def __getitem__(self, idx):
|
| 86 |
+
row = self.df.iloc[idx]
|
| 87 |
+
|
| 88 |
+
# Load image
|
| 89 |
+
img_path = self.tiles_dir / row['tile_path']
|
| 90 |
+
image = Image.open(img_path).convert('RGB')
|
| 91 |
+
|
| 92 |
+
if self.transform:
|
| 93 |
+
image = self.transform(image)
|
| 94 |
+
|
| 95 |
+
# Get multi-label targets
|
| 96 |
+
labels = torch.tensor([row[class_name] for class_name in self.class_names],
|
| 97 |
+
dtype=torch.float32)
|
| 98 |
+
|
| 99 |
+
return image, labels
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def create_classification_transforms(tile_size: int, is_training: bool = True):
|
| 103 |
+
"""Create transforms for classification training."""
|
| 104 |
+
|
| 105 |
+
if is_training:
|
| 106 |
+
# Training transforms - moderate augmentation
|
| 107 |
+
transform = T.Compose([
|
| 108 |
+
T.Resize((tile_size, tile_size)),
|
| 109 |
+
T.RandomHorizontalFlip(p=0.5),
|
| 110 |
+
T.RandomVerticalFlip(p=0.2),
|
| 111 |
+
T.RandomRotation(degrees=10, fill=0),
|
| 112 |
+
T.ColorJitter(brightness=0.2, contrast=0.2),
|
| 113 |
+
T.ToTensor(),
|
| 114 |
+
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
| 115 |
+
])
|
| 116 |
+
else:
|
| 117 |
+
# Validation transforms - no augmentation
|
| 118 |
+
transform = T.Compose([
|
| 119 |
+
T.Resize((tile_size, tile_size)),
|
| 120 |
+
T.ToTensor(),
|
| 121 |
+
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
| 122 |
+
])
|
| 123 |
+
|
| 124 |
+
return transform
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class ClassificationModel(nn.Module):
|
| 128 |
+
"""Classification model that wraps BYOL backbone with classification head."""
|
| 129 |
+
|
| 130 |
+
def __init__(self, byol_model: MammogramBYOL, num_classes: int, hidden_dim: int = 2048):
|
| 131 |
+
super().__init__()
|
| 132 |
+
self.byol_model = byol_model
|
| 133 |
+
|
| 134 |
+
# Create classification head
|
| 135 |
+
self.classification_head = nn.Sequential(
|
| 136 |
+
nn.Linear(2048, hidden_dim),
|
| 137 |
+
nn.ReLU(),
|
| 138 |
+
nn.Dropout(0.3),
|
| 139 |
+
nn.Linear(hidden_dim, num_classes)
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
def forward(self, x):
|
| 143 |
+
"""Forward pass for classification."""
|
| 144 |
+
features = self.byol_model.get_features(x)
|
| 145 |
+
return self.classification_head(features)
|
| 146 |
+
|
| 147 |
+
def get_features(self, x):
|
| 148 |
+
"""Get backbone features."""
|
| 149 |
+
return self.byol_model.get_features(x)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def load_byol_model(checkpoint_path: str, num_classes: int, device: torch.device):
|
| 153 |
+
"""Load BYOL pre-trained model and prepare for classification."""
|
| 154 |
+
|
| 155 |
+
print(f"π₯ Loading BYOL checkpoint: {checkpoint_path}")
|
| 156 |
+
|
| 157 |
+
# Load checkpoint
|
| 158 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 159 |
+
|
| 160 |
+
# Create BYOL model with same architecture as training
|
| 161 |
+
from torchvision import models
|
| 162 |
+
resnet = models.resnet50(weights=None) # Don't load ImageNet weights
|
| 163 |
+
backbone = nn.Sequential(*list(resnet.children())[:-1])
|
| 164 |
+
|
| 165 |
+
byol_model = MammogramBYOL(
|
| 166 |
+
backbone=backbone,
|
| 167 |
+
input_dim=2048,
|
| 168 |
+
hidden_dim=4096,
|
| 169 |
+
proj_dim=256
|
| 170 |
+
).to(device)
|
| 171 |
+
|
| 172 |
+
# Load BYOL weights
|
| 173 |
+
byol_model.load_state_dict(checkpoint['model_state_dict'])
|
| 174 |
+
|
| 175 |
+
# Create classification model
|
| 176 |
+
model = ClassificationModel(byol_model, num_classes).to(device)
|
| 177 |
+
|
| 178 |
+
print(f"β
Loaded BYOL model from epoch {checkpoint.get('epoch', 'unknown')}")
|
| 179 |
+
print(f"π BYOL training loss: {checkpoint.get('loss', 'unknown'):.4f}")
|
| 180 |
+
print(f"π― Added classification head: 2048 β {2048} β {num_classes}")
|
| 181 |
+
|
| 182 |
+
return model
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def calculate_metrics(predictions: np.ndarray, targets: np.ndarray,
|
| 186 |
+
class_names: List[str]) -> Dict[str, float]:
|
| 187 |
+
"""Calculate comprehensive metrics for multi-label classification."""
|
| 188 |
+
|
| 189 |
+
metrics = {}
|
| 190 |
+
|
| 191 |
+
# Convert probabilities to binary predictions
|
| 192 |
+
pred_binary = (predictions > 0.5).astype(int)
|
| 193 |
+
|
| 194 |
+
# Per-class metrics
|
| 195 |
+
for i, class_name in enumerate(class_names):
|
| 196 |
+
try:
|
| 197 |
+
# AUC-ROC per class
|
| 198 |
+
auc = roc_auc_score(targets[:, i], predictions[:, i])
|
| 199 |
+
metrics[f'auc_{class_name}'] = auc
|
| 200 |
+
|
| 201 |
+
# Average Precision per class
|
| 202 |
+
ap = average_precision_score(targets[:, i], predictions[:, i])
|
| 203 |
+
metrics[f'ap_{class_name}'] = ap
|
| 204 |
+
|
| 205 |
+
# Accuracy per class
|
| 206 |
+
acc = accuracy_score(targets[:, i], pred_binary[:, i])
|
| 207 |
+
metrics[f'acc_{class_name}'] = acc
|
| 208 |
+
|
| 209 |
+
except ValueError:
|
| 210 |
+
# Handle case where all samples are negative for this class
|
| 211 |
+
metrics[f'auc_{class_name}'] = 0.0
|
| 212 |
+
metrics[f'ap_{class_name}'] = 0.0
|
| 213 |
+
metrics[f'acc_{class_name}'] = accuracy_score(targets[:, i], pred_binary[:, i])
|
| 214 |
+
|
| 215 |
+
# Overall metrics
|
| 216 |
+
metrics['mean_auc'] = np.mean([metrics[f'auc_{class_name}'] for class_name in class_names])
|
| 217 |
+
metrics['mean_ap'] = np.mean([metrics[f'ap_{class_name}'] for class_name in class_names])
|
| 218 |
+
metrics['mean_acc'] = np.mean([metrics[f'acc_{class_name}'] for class_name in class_names])
|
| 219 |
+
|
| 220 |
+
# Exact match accuracy (all labels correct)
|
| 221 |
+
exact_match = np.all(pred_binary == targets, axis=1).mean()
|
| 222 |
+
metrics['exact_match_acc'] = exact_match
|
| 223 |
+
|
| 224 |
+
return metrics
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def train_epoch(model: nn.Module, dataloader: DataLoader, criterion: nn.Module,
|
| 228 |
+
optimizer: optim.Optimizer, scaler: GradScaler, epoch: int,
|
| 229 |
+
config: dict, freeze_backbone: bool = False) -> Dict[str, float]:
|
| 230 |
+
"""Train for one epoch."""
|
| 231 |
+
|
| 232 |
+
model.train()
|
| 233 |
+
total_loss = 0.0
|
| 234 |
+
num_batches = len(dataloader)
|
| 235 |
+
|
| 236 |
+
# Freeze backbone if specified
|
| 237 |
+
if freeze_backbone:
|
| 238 |
+
for param in model.byol_model.backbone.parameters():
|
| 239 |
+
param.requires_grad = False
|
| 240 |
+
for param in model.byol_model.backbone_momentum.parameters():
|
| 241 |
+
param.requires_grad = False
|
| 242 |
+
else:
|
| 243 |
+
for param in model.byol_model.backbone.parameters():
|
| 244 |
+
param.requires_grad = True
|
| 245 |
+
|
| 246 |
+
pbar = tqdm(dataloader, desc=f"Epoch {epoch:3d}/{config['epochs']} [Train]",
|
| 247 |
+
ncols=100, leave=False)
|
| 248 |
+
|
| 249 |
+
for batch_idx, (images, labels) in enumerate(pbar):
|
| 250 |
+
images, labels = images.to(DEVICE), labels.to(DEVICE)
|
| 251 |
+
|
| 252 |
+
optimizer.zero_grad()
|
| 253 |
+
|
| 254 |
+
with autocast():
|
| 255 |
+
# Forward pass through classification model
|
| 256 |
+
outputs = model(images)
|
| 257 |
+
loss = criterion(outputs, labels)
|
| 258 |
+
|
| 259 |
+
# Backward pass
|
| 260 |
+
scaler.scale(loss).backward()
|
| 261 |
+
scaler.unscale_(optimizer)
|
| 262 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), config['gradient_clip'])
|
| 263 |
+
scaler.step(optimizer)
|
| 264 |
+
scaler.update()
|
| 265 |
+
|
| 266 |
+
total_loss += loss.item()
|
| 267 |
+
|
| 268 |
+
# Update progress bar
|
| 269 |
+
pbar.set_postfix({
|
| 270 |
+
'Loss': f'{loss.item():.4f}',
|
| 271 |
+
'Avg': f'{total_loss/(batch_idx+1):.4f}',
|
| 272 |
+
'LR': f'{optimizer.param_groups[0]["lr"]:.2e}'
|
| 273 |
+
})
|
| 274 |
+
|
| 275 |
+
return {'train_loss': total_loss / num_batches}
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def validate_epoch(model: nn.Module, dataloader: DataLoader, criterion: nn.Module,
|
| 279 |
+
class_names: List[str]) -> Dict[str, float]:
|
| 280 |
+
"""Validate for one epoch."""
|
| 281 |
+
|
| 282 |
+
model.eval()
|
| 283 |
+
total_loss = 0.0
|
| 284 |
+
all_predictions = []
|
| 285 |
+
all_targets = []
|
| 286 |
+
|
| 287 |
+
with torch.no_grad():
|
| 288 |
+
pbar = tqdm(dataloader, desc="Validation", ncols=100, leave=False)
|
| 289 |
+
|
| 290 |
+
for images, labels in pbar:
|
| 291 |
+
images, labels = images.to(DEVICE), labels.to(DEVICE)
|
| 292 |
+
|
| 293 |
+
with autocast():
|
| 294 |
+
outputs = model(images)
|
| 295 |
+
loss = criterion(outputs, labels)
|
| 296 |
+
|
| 297 |
+
total_loss += loss.item()
|
| 298 |
+
|
| 299 |
+
# Convert outputs to probabilities
|
| 300 |
+
probs = torch.sigmoid(outputs)
|
| 301 |
+
|
| 302 |
+
all_predictions.append(probs.cpu().numpy())
|
| 303 |
+
all_targets.append(labels.cpu().numpy())
|
| 304 |
+
|
| 305 |
+
pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
|
| 306 |
+
|
| 307 |
+
# Concatenate all predictions and targets
|
| 308 |
+
predictions = np.concatenate(all_predictions, axis=0)
|
| 309 |
+
targets = np.concatenate(all_targets, axis=0)
|
| 310 |
+
|
| 311 |
+
# Calculate metrics
|
| 312 |
+
metrics = calculate_metrics(predictions, targets, class_names)
|
| 313 |
+
metrics['val_loss'] = total_loss / len(dataloader)
|
| 314 |
+
|
| 315 |
+
return metrics
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def main():
|
| 319 |
+
parser = argparse.ArgumentParser(description='Fine-tune BYOL model for classification')
|
| 320 |
+
parser.add_argument('--byol_checkpoint', type=str, required=True,
|
| 321 |
+
help='Path to BYOL checkpoint (.pth file)')
|
| 322 |
+
parser.add_argument('--train_csv', type=str, required=True,
|
| 323 |
+
help='Path to training CSV file')
|
| 324 |
+
parser.add_argument('--val_csv', type=str, required=True,
|
| 325 |
+
help='Path to validation CSV file')
|
| 326 |
+
parser.add_argument('--tiles_dir', type=str, required=True,
|
| 327 |
+
help='Directory containing tile images')
|
| 328 |
+
parser.add_argument('--class_names', type=str, nargs='+', required=True,
|
| 329 |
+
help='List of class names (e.g., mass calcification normal)')
|
| 330 |
+
parser.add_argument('--output_dir', type=str, default='./classification_results',
|
| 331 |
+
help='Output directory for checkpoints and logs')
|
| 332 |
+
parser.add_argument('--config', type=str, default=None,
|
| 333 |
+
help='JSON config file (overrides defaults)')
|
| 334 |
+
parser.add_argument('--wandb_project', type=str, default='mammogram-classification',
|
| 335 |
+
help='Weights & Biases project name')
|
| 336 |
+
parser.add_argument('--max_samples', type=int, default=None,
|
| 337 |
+
help='Limit dataset size for testing')
|
| 338 |
+
|
| 339 |
+
args = parser.parse_args()
|
| 340 |
+
|
| 341 |
+
# Load configuration
|
| 342 |
+
config = DEFAULT_CONFIG.copy()
|
| 343 |
+
if args.config:
|
| 344 |
+
with open(args.config, 'r') as f:
|
| 345 |
+
config.update(json.load(f))
|
| 346 |
+
|
| 347 |
+
# Create output directory
|
| 348 |
+
output_dir = Path(args.output_dir)
|
| 349 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 350 |
+
|
| 351 |
+
# Initialize wandb
|
| 352 |
+
try:
|
| 353 |
+
wandb.init(
|
| 354 |
+
project=args.wandb_project,
|
| 355 |
+
config=config,
|
| 356 |
+
name=f"classification_fine_tune_{len(args.class_names)}classes"
|
| 357 |
+
)
|
| 358 |
+
wandb_enabled = True
|
| 359 |
+
except Exception as e:
|
| 360 |
+
print(f"β οΈ WandB not configured: {e}")
|
| 361 |
+
wandb_enabled = False
|
| 362 |
+
|
| 363 |
+
print("π¬ BYOL Classification Fine-Tuning")
|
| 364 |
+
print("=" * 50)
|
| 365 |
+
print(f"Device: {DEVICE}")
|
| 366 |
+
print(f"Classes: {args.class_names}")
|
| 367 |
+
print(f"Batch size: {config['batch_size']}")
|
| 368 |
+
print(f"Epochs: {config['epochs']}")
|
| 369 |
+
print(f"Output directory: {output_dir}")
|
| 370 |
+
|
| 371 |
+
# Load model
|
| 372 |
+
model = load_byol_model(args.byol_checkpoint, len(args.class_names), DEVICE)
|
| 373 |
+
|
| 374 |
+
# Create datasets
|
| 375 |
+
train_transform = create_classification_transforms(TILE_SIZE, is_training=True)
|
| 376 |
+
val_transform = create_classification_transforms(TILE_SIZE, is_training=False)
|
| 377 |
+
|
| 378 |
+
train_dataset = MammogramClassificationDataset(
|
| 379 |
+
args.train_csv, args.tiles_dir, args.class_names,
|
| 380 |
+
train_transform, max_samples=args.max_samples
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
val_dataset = MammogramClassificationDataset(
|
| 384 |
+
args.val_csv, args.tiles_dir, args.class_names,
|
| 385 |
+
val_transform, max_samples=args.max_samples
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
# Create data loaders
|
| 389 |
+
train_loader = DataLoader(
|
| 390 |
+
train_dataset,
|
| 391 |
+
batch_size=config['batch_size'],
|
| 392 |
+
shuffle=True,
|
| 393 |
+
num_workers=config['num_workers'],
|
| 394 |
+
pin_memory=True,
|
| 395 |
+
drop_last=True
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
val_loader = DataLoader(
|
| 399 |
+
val_dataset,
|
| 400 |
+
batch_size=config['batch_size'],
|
| 401 |
+
shuffle=False,
|
| 402 |
+
num_workers=config['num_workers'],
|
| 403 |
+
pin_memory=True
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
print(f"π Dataset sizes: Train={len(train_dataset)}, Val={len(val_dataset)}")
|
| 407 |
+
|
| 408 |
+
# Setup loss and optimizer
|
| 409 |
+
# Use BCEWithLogitsLoss for multi-label classification
|
| 410 |
+
pos_weight = None # Could be calculated from class distribution if needed
|
| 411 |
+
criterion = nn.BCEWithLogitsLoss(
|
| 412 |
+
pos_weight=pos_weight,
|
| 413 |
+
label_smoothing=config['label_smoothing']
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
# Different learning rates for backbone and classification head
|
| 417 |
+
backbone_params = list(model.byol_model.backbone.parameters())
|
| 418 |
+
head_params = list(model.classification_head.parameters())
|
| 419 |
+
|
| 420 |
+
optimizer = optim.AdamW([
|
| 421 |
+
{'params': backbone_params, 'lr': config['lr_backbone']},
|
| 422 |
+
{'params': head_params, 'lr': config['lr_head']}
|
| 423 |
+
], weight_decay=config['weight_decay'])
|
| 424 |
+
|
| 425 |
+
# Learning rate scheduler
|
| 426 |
+
scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
| 427 |
+
optimizer, T_max=config['epochs'], eta_min=1e-6
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
# Mixed precision scaler
|
| 431 |
+
scaler = GradScaler()
|
| 432 |
+
|
| 433 |
+
# Training loop
|
| 434 |
+
best_metric = 0.0
|
| 435 |
+
|
| 436 |
+
for epoch in range(1, config['epochs'] + 1):
|
| 437 |
+
# Decide whether to freeze backbone
|
| 438 |
+
freeze_backbone = epoch <= config['freeze_backbone_epochs']
|
| 439 |
+
if freeze_backbone:
|
| 440 |
+
print(f"π§ Epoch {epoch}: Backbone frozen (training only classification head)")
|
| 441 |
+
|
| 442 |
+
# Train
|
| 443 |
+
train_metrics = train_epoch(
|
| 444 |
+
model, train_loader, criterion, optimizer, scaler,
|
| 445 |
+
epoch, config, freeze_backbone
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
# Validate
|
| 449 |
+
val_metrics = validate_epoch(model, val_loader, criterion, args.class_names)
|
| 450 |
+
|
| 451 |
+
# Step scheduler
|
| 452 |
+
scheduler.step()
|
| 453 |
+
|
| 454 |
+
# Print metrics
|
| 455 |
+
print(f"\nEpoch {epoch:3d}/{config['epochs']}:")
|
| 456 |
+
print(f" Train Loss: {train_metrics['train_loss']:.4f}")
|
| 457 |
+
print(f" Val Loss: {val_metrics['val_loss']:.4f}")
|
| 458 |
+
print(f" Mean AUC: {val_metrics['mean_auc']:.4f}")
|
| 459 |
+
print(f" Mean AP: {val_metrics['mean_ap']:.4f}")
|
| 460 |
+
print(f" Exact Match: {val_metrics['exact_match_acc']:.4f}")
|
| 461 |
+
|
| 462 |
+
# Log to wandb
|
| 463 |
+
if wandb_enabled:
|
| 464 |
+
log_dict = {**train_metrics, **val_metrics, 'epoch': epoch}
|
| 465 |
+
wandb.log(log_dict)
|
| 466 |
+
|
| 467 |
+
# Save best model
|
| 468 |
+
current_metric = val_metrics['mean_auc']
|
| 469 |
+
if current_metric > best_metric:
|
| 470 |
+
best_metric = current_metric
|
| 471 |
+
checkpoint = {
|
| 472 |
+
'epoch': epoch,
|
| 473 |
+
'model_state_dict': model.state_dict(),
|
| 474 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 475 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
| 476 |
+
'val_metrics': val_metrics,
|
| 477 |
+
'config': config,
|
| 478 |
+
'class_names': args.class_names
|
| 479 |
+
}
|
| 480 |
+
torch.save(checkpoint, output_dir / 'best_classification_model.pth')
|
| 481 |
+
print(f" β
New best model saved (AUC: {best_metric:.4f})")
|
| 482 |
+
|
| 483 |
+
# Save periodic checkpoints
|
| 484 |
+
if epoch % 10 == 0:
|
| 485 |
+
checkpoint = {
|
| 486 |
+
'epoch': epoch,
|
| 487 |
+
'model_state_dict': model.state_dict(),
|
| 488 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 489 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
| 490 |
+
'val_metrics': val_metrics,
|
| 491 |
+
'config': config,
|
| 492 |
+
'class_names': args.class_names
|
| 493 |
+
}
|
| 494 |
+
torch.save(checkpoint, output_dir / f'classification_epoch_{epoch}.pth')
|
| 495 |
+
|
| 496 |
+
# Save final model
|
| 497 |
+
final_checkpoint = {
|
| 498 |
+
'epoch': config['epochs'],
|
| 499 |
+
'model_state_dict': model.state_dict(),
|
| 500 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 501 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
| 502 |
+
'val_metrics': val_metrics,
|
| 503 |
+
'config': config,
|
| 504 |
+
'class_names': args.class_names
|
| 505 |
+
}
|
| 506 |
+
torch.save(final_checkpoint, output_dir / 'final_classification_model.pth')
|
| 507 |
+
|
| 508 |
+
print(f"\nπ Classification training completed!")
|
| 509 |
+
print(f"π Best validation AUC: {best_metric:.4f}")
|
| 510 |
+
print(f"πΎ Models saved to: {output_dir}")
|
| 511 |
+
|
| 512 |
+
if wandb_enabled:
|
| 513 |
+
wandb.finish()
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
if __name__ == "__main__":
|
| 517 |
+
main()
|