# 🎯 Classification Training Guide
Complete guide for fine-tuning the BYOL pre-trained model for multi-label classification.
## 📋 Overview
After BYOL pre-training completes, you can fine-tune the model for classification with the `train_classification.py` script. The script:
1. **Loads the BYOL checkpoint** with the learned representations
2. **Freezes the backbone** for the first `freeze_backbone_epochs` epochs (optional; set it to 0 to skip) so the newly added head can be trained without overwriting the pre-trained features
3. **Trains the classification head** with a higher learning rate than the backbone
4. **Unfreezes the backbone** afterwards for end-to-end fine-tuning at a lower backbone learning rate
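The freeze/unfreeze rule can be sketched as a small pure-Python helper. This is an illustration, not code from `train_classification.py`: it assumes epochs are counted from 1 and that the backbone stays frozen while `epoch <= freeze_backbone_epochs`, which is consistent with the example logs in the training-process section (epochs 1-10 frozen with the default of 10). `backbone_trainable` and `describe_schedule` are hypothetical names.

```python
from typing import List


def backbone_trainable(epoch: int, freeze_backbone_epochs: int = 10) -> bool:
    """Return True if the backbone receives gradient updates in `epoch` (1-based).

    The head is trained in every epoch; the backbone only after the first
    `freeze_backbone_epochs` epochs. A value of 0 means it is never frozen.
    """
    if epoch < 1:
        raise ValueError("epochs are counted from 1")
    if freeze_backbone_epochs < 0:
        raise ValueError("freeze_backbone_epochs must be >= 0")
    return epoch > freeze_backbone_epochs


def describe_schedule(epochs: int = 50, freeze_backbone_epochs: int = 10) -> List[str]:
    """Summarize the two training phases as human-readable strings."""
    n_frozen = sum(not backbone_trainable(e, freeze_backbone_epochs) for e in range(1, epochs + 1))
    phases = []
    if n_frozen:
        phases.append(f"epochs 1-{n_frozen}: head only (backbone frozen)")
    if n_frozen < epochs:
        phases.append(f"epochs {n_frozen + 1}-{epochs}: backbone + head (end-to-end)")
    return phases


# In PyTorch, the flag would typically be applied to every backbone parameter:
#     for p in backbone.parameters():
#         p.requires_grad_(backbone_trainable(epoch, freeze_backbone_epochs))
print(describe_schedule())
```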
## 🗂️ Data Preparation
### CSV Format
Create train/validation CSV files with this format:
```csv
tile_path,mass,calcification,architectural_distortion,asymmetry,normal,benign,malignant,birads_2,birads_3,birads_4
patient1_tile_001.png,1,0,0,0,0,1,0,0,1,0
patient1_tile_002.png,0,1,0,0,0,0,1,0,0,1
patient2_tile_001.png,0,0,0,0,1,1,0,1,0,0
...
```
**Requirements:**
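- `tile_path`: path to the tile image, relative to the directory passed as `--tiles_dir`
- **Class columns**: one binary (0/1) column per class; each name passed to `--class_names` must match a column header
- **Multi-label support**: a row may have several class columns set to 1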
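A quick pre-flight check of a label file against these requirements might look like the following standard-library sketch. `validate_label_csv` is a hypothetical helper (not part of the repository), and it assumes the layout shown above: a `tile_path` column plus one 0/1 column per class.

```python
import csv
import io
from typing import Iterable, List, TextIO


def validate_label_csv(f: TextIO, class_names: Iterable[str]) -> List[str]:
    """Hypothetical checker: return a list of problems (empty list = none found)."""
    class_names = list(class_names)
    reader = csv.DictReader(f)
    header = reader.fieldnames or []
    problems = []
    if "tile_path" not in header:
        problems.append("missing 'tile_path' column")
    missing = [c for c in class_names if c not in header]
    if missing:
        problems.append(f"missing class columns: {missing}")
        return problems  # per-row checks would only repeat this error
    # Line numbers assume no quoted multi-line fields (line 1 is the header).
    for line_no, row in enumerate(reader, start=2):
        if not (row.get("tile_path") or "").strip():
            problems.append(f"line {line_no}: empty tile_path")
        for c in class_names:
            value = (row[c] or "").strip()
            if value not in ("0", "1"):
                problems.append(f"line {line_no}: {c}={value!r} is not 0/1")
    return problems


example = (
    "tile_path,mass,calcification,normal\n"
    "patient1_tile_001.png,1,0,0\n"
    "patient1_tile_002.png,0,1,1\n"  # several 1s in one row is fine (multi-label)
    "patient2_tile_001.png,0,2,0\n"  # invalid label value
)
print(validate_label_csv(io.StringIO(example), ["mass", "calcification", "normal"]))
```

For a real file, pass `open("train_labels.csv", newline="")` instead of the `StringIO`, together with the class names you plan to give `--class_names`.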
### Directory Structure
```
your_project/
├── tiles/                      # Directory containing tile images
│   ├── patient1_tile_001.png
│   ├── patient1_tile_002.png
│   └── ...
├── train_labels.csv            # Training labels
├── val_labels.csv              # Validation labels
└── mammogram_byol_best.pth     # BYOL checkpoint
```
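Since a wrong tiles directory or a typo in `tile_path` would otherwise only surface once data loading starts, it can be worth checking that every listed tile exists. This sketch assumes `tile_path` is relative to the directory passed as `--tiles_dir`; `find_missing_tiles` is a hypothetical helper.

```python
import csv
import tempfile
from pathlib import Path
from typing import List, Union

PathLike = Union[str, Path]


def find_missing_tiles(csv_path: PathLike, tiles_dir: PathLike) -> List[str]:
    """Return tile_path entries that do not exist as files under tiles_dir (hypothetical helper)."""
    root = Path(tiles_dir)
    with open(csv_path, newline="") as f:
        return [row["tile_path"] for row in csv.DictReader(f)
                if not (root / row["tile_path"]).is_file()]


# Self-contained demo: one tile exists, the other is missing.
with tempfile.TemporaryDirectory() as tmp:
    tiles = Path(tmp) / "tiles"
    tiles.mkdir()
    (tiles / "patient1_tile_001.png").touch()
    labels = Path(tmp) / "train_labels.csv"
    labels.write_text("tile_path,mass\npatient1_tile_001.png,1\npatient1_tile_002.png,0\n")
    print(find_missing_tiles(labels, tiles))  # → ['patient1_tile_002.png']
```

On your own data, call `find_missing_tiles("train_labels.csv", "tiles")` and repeat for `val_labels.csv`.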
## 🚀 Quick Start
### 1. Basic Classification Training
```bash
python train_classification.py \
--byol_checkpoint ./mammogram_byol_best.pth \
--train_csv ./train_labels.csv \
--val_csv ./val_labels.csv \
--tiles_dir ./tiles \
--class_names mass calcification architectural_distortion asymmetry normal benign malignant birads_2 birads_3 birads_4 \
--output_dir ./classification_results
```
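The values passed to `--class_names` need to line up with the class columns of your CSVs. One way to avoid typos is to read them from the header; this sketch assumes every column other than `tile_path` is a class column, and it only prints the resulting command (via `shlex.join`, Python 3.8+) rather than running it. `class_names_from_header` and `build_train_command` are hypothetical helpers.

```python
import csv
import io
import shlex
from typing import List, TextIO


def class_names_from_header(f: TextIO) -> List[str]:
    """Assumes every column except tile_path is a 0/1 class column."""
    header = next(csv.reader(f))
    return [name.strip() for name in header if name.strip() != "tile_path"]


def build_train_command(class_names: List[str]) -> List[str]:
    """The 'Basic Classification Training' command as an argument list."""
    return [
        "python", "train_classification.py",
        "--byol_checkpoint", "./mammogram_byol_best.pth",
        "--train_csv", "./train_labels.csv",
        "--val_csv", "./val_labels.csv",
        "--tiles_dir", "./tiles",
        "--class_names", *class_names,
        "--output_dir", "./classification_results",
    ]


names = class_names_from_header(io.StringIO("tile_path,mass,calcification,normal\n"))
print(shlex.join(build_train_command(names)))
```

With a real file, replace the `StringIO` with `open("train_labels.csv", newline="")`; the argument list can also be passed directly to `subprocess.run`.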
### 2. With Custom Configuration
```bash
python train_classification.py \
--byol_checkpoint ./mammogram_byol_best.pth \
--train_csv ./train_labels.csv \
--val_csv ./val_labels.csv \
--tiles_dir ./tiles \
--class_names mass calcification normal \
--config ./classification_config.json \
--output_dir ./classification_results \
--wandb_project my-mammogram-classification
```
### 3. Quick Testing (Limited Dataset)
```bash
python train_classification.py \
--byol_checkpoint ./mammogram_byol_best.pth \
--train_csv ./train_labels.csv \
--val_csv ./val_labels.csv \
--tiles_dir ./tiles \
--class_names mass calcification normal \
--max_samples 1000 \
--output_dir ./test_results
```
## ⚙️ Configuration Options
### Key Parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `batch_size` | 32 | Batch size for training |
| `epochs` | 50 | Number of training epochs |
| `lr_backbone` | 1e-5 | Learning rate for pre-trained backbone |
| `lr_head` | 1e-3 | Learning rate for classification head |
| `freeze_backbone_epochs` | 10 | Epochs to freeze backbone (0 = never freeze) |
| `label_smoothing` | 0.1 | Label smoothing for regularization |
| `gradient_clip` | 1.0 | Gradient clipping max norm |
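This guide does not show how `label_smoothing` and `gradient_clip` are implemented, so the following pure-Python sketch illustrates one common variant of each as an assumption: for independent binary labels, smoothing maps a target y to y·(1 - ε) + ε/2 (so 1 becomes 0.95 and 0 becomes 0.05 at the default ε = 0.1), and clipping rescales all gradients together when their combined L2 norm exceeds the limit, in the same way as PyTorch's `torch.nn.utils.clip_grad_norm_`.

```python
import math
from typing import List


def smooth_binary_targets(targets: List[int], epsilon: float = 0.1) -> List[float]:
    """Binary label smoothing (assumed variant): 1 -> 1 - eps/2, 0 -> eps/2."""
    if not 0.0 <= epsilon < 1.0:
        raise ValueError("epsilon must be in [0, 1)")
    return [y * (1.0 - epsilon) + 0.5 * epsilon for y in targets]


def clip_by_global_norm(grads: List[float], max_norm: float = 1.0) -> List[float]:
    """Rescale all gradients together so their combined L2 norm is at most ~max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-6))  # never scales gradients up
    return [g * scale for g in grads]


print(smooth_binary_targets([1, 0, 0, 1]))  # approximately [0.95, 0.05, 0.05, 0.95]
print(clip_by_global_norm([3.0, 4.0]))      # norm 5.0 -> rescaled to about [0.6, 0.8]
```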
### Custom Configuration File
Create a JSON file such as `classification_config.json` and pass it with `--config`:
```json
{
  "batch_size": 64,
  "epochs": 100,
  "lr_backbone": 5e-6,
  "lr_head": 2e-3,
  "freeze_backbone_epochs": 20,
  "label_smoothing": 0.2,
  "weight_decay": 1e-3
}
```
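How `--config` values are combined with the defaults is not documented here. A common pattern, assumed in this sketch, is that keys present in the JSON file override the defaults from the table and all other keys keep their defaults. `DEFAULTS` mirrors the Key Parameters table; `load_config` is a hypothetical helper.

```python
import json
from typing import Any, Dict

# Copied from the "Key Parameters" table; keys such as weight_decay are not listed there.
DEFAULTS: Dict[str, Any] = {
    "batch_size": 32,
    "epochs": 50,
    "lr_backbone": 1e-5,
    "lr_head": 1e-3,
    "freeze_backbone_epochs": 10,
    "label_smoothing": 0.1,
    "gradient_clip": 1.0,
}


def load_config(json_text: str) -> Dict[str, Any]:
    """Assumed merge rule: JSON keys override defaults; unlisted keys are kept as-is."""
    overrides = json.loads(json_text)
    if not isinstance(overrides, dict):
        raise ValueError("config file must contain a JSON object")
    unknown = sorted(set(overrides) - set(DEFAULTS))
    if unknown:
        print(f"note: keys not in the defaults table: {unknown}")
    return {**DEFAULTS, **overrides}


my_config = """{"batch_size": 64, "epochs": 100, "lr_backbone": 5e-6, "lr_head": 2e-3,
"freeze_backbone_epochs": 20, "label_smoothing": 0.2, "weight_decay": 1e-3}"""
config = load_config(my_config)
print(config["gradient_clip"], config["batch_size"])  # default kept, override applied
```

For a file on disk, pass `open("classification_config.json").read()` to `load_config`.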
## 📊 Expected Training Process
The logs below are illustrative and use the default settings; exact values depend on your data.
### Phase 1: Backbone Frozen (Epochs 1-10)
```
🧊 Epoch 1: Backbone frozen (training only classification head)
Epoch 1/50:
Train Loss: 0.6234
Val Loss: 0.5891
Mean AUC: 0.7123
Mean AP: 0.6894
Exact Match: 0.4512
✅ New best model saved (AUC: 0.7123)
```
### Phase 2: End-to-End Fine-tuning (Epochs 11-50)
```
Epoch 15/50:
Train Loss: 0.3456
Val Loss: 0.3891
Mean AUC: 0.8567
Mean AP: 0.8234
Exact Match: 0.6789
✅ New best model saved (AUC: 0.8567)
```
## 🔍 Making Predictions
### Single Image Inference
```bash
python inference_classification.py \
--model_path ./classification_results/best_classification_model.pth \
--image_path ./test_image.png \
--threshold 0.5
```
**Output:**
```
📸 Image 1: test_image.png
🏆 Top prediction: mass (0.847)
📊 All probabilities:
   ✅ mass                    : 0.847
   ❌ calcification           : 0.234
   ❌ normal                  : 0.123
   ❌ architectural_distortion: 0.089
```
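The ✅/❌ marks reflect multi-label thresholding: each class probability is compared with `--threshold` independently, so several classes (or none) can be positive at once, while the top prediction is simply the highest probability. A sketch of that logic, assuming a probability equal to the threshold counts as positive (the script's exact comparison is not documented here); `threshold_predictions` is a hypothetical helper.

```python
from typing import Dict, List, Tuple


def threshold_predictions(probs: Dict[str, float], threshold: float = 0.5) -> Tuple[str, List[str]]:
    """Return (highest-probability class, classes with probability >= threshold, sorted high to low)."""
    top = max(probs, key=probs.get)
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
    positives = [name for name, p in ranked if p >= threshold]
    return top, positives


probs = {"mass": 0.847, "calcification": 0.234, "normal": 0.123, "architectural_distortion": 0.089}
print(threshold_predictions(probs, threshold=0.5))  # ('mass', ['mass'])
print(threshold_predictions(probs, threshold=0.2))  # ('mass', ['mass', 'calcification'])
```

Lowering the threshold trades precision for recall per class, which is why a single global value of 0.5 is not always the best choice for rare findings.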
### Batch Inference
```bash
python inference_classification.py \
--model_path ./classification_results/best_classification_model.pth \
--images_dir ./test_images \
--output_json ./predictions.json \
--batch_size 64
```
### Programmatic Usage
```python
import torch
from PIL import Image
from train_byol_mammo import MammogramBYOL
from inference_classification import load_classification_model, create_inference_transform

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, class_names, config = load_classification_model(
    "./classification_results/best_classification_model.pth", device
)
model.eval()  # inference mode: disables dropout, uses BatchNorm running statistics

# Make prediction
transform = create_inference_transform()
image = Image.open("test.png").convert("RGB")
input_tensor = transform(image).unsqueeze(0).to(device)  # add a batch dimension
with torch.no_grad():
    logits = model.classify(input_tensor)
    # Multi-label: an independent sigmoid per class, not a softmax over classes
    probabilities = torch.sigmoid(logits).cpu().numpy()[0]

# Print per-class probabilities
for class_name, prob in zip(class_names, probabilities):
    print(f"{class_name}: {prob:.3f}")
```
## 📈 Monitoring Training
### Weights & Biases Integration
The script automatically logs to W&B:
- Training/validation loss curves
- Per-class AUC and Average Precision
- Learning rate schedules
- Model hyperparameters
### Metrics Explained
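- **AUC (area under the ROC curve)**: probability that a randomly chosen positive is scored higher than a randomly chosen negative (0.5 = chance, 1.0 = perfect ranking)
- **AP (Average Precision)**: summarizes the precision-recall curve (0-1, higher is better); often more informative than AUC when positives are rare
- **Exact Match Accuracy**: fraction of samples for which ALL labels are predicted correctly
- **Per-Class Accuracy**: binary accuracy for each individual class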
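To make these definitions concrete, the dependency-free sketch below computes exact match, per-class accuracy, and ROC AUC via its rank (Mann-Whitney) form, with ties counted as one half. It assumes probabilities are binarized at 0.5 for the accuracy metrics, which may differ from the training script; for real evaluations, scikit-learn's `roc_auc_score` and `average_precision_score` are the usual tools. In the toy data, class 1 is ranked perfectly (AUC 1.0) yet has only 75% accuracy at the 0.5 cut-off, which shows how AUC and accuracy can disagree.

```python
from typing import List, Sequence


def exact_match(y_true: Sequence[Sequence[int]], y_pred: Sequence[Sequence[int]]) -> float:
    """Fraction of samples whose whole label vector is predicted correctly."""
    return sum(list(t) == list(p) for t, p in zip(y_true, y_pred)) / len(y_true)


def per_class_accuracy(y_true: Sequence[Sequence[int]], y_pred: Sequence[Sequence[int]]) -> List[float]:
    """Binary accuracy computed separately for each class (column)."""
    n = len(y_true)
    return [sum(t[c] == p[c] for t, p in zip(y_true, y_pred)) / n for c in range(len(y_true[0]))]


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """P(random positive scores higher than random negative); ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC is undefined if a class has no positives or no negatives")
    wins = sum(1.0 if p > q else 0.5 if p == q else 0.0 for p in pos for q in neg)
    return wins / (len(pos) * len(neg))


y_true = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0]]
probs = [[0.9, 0.2, 0.1], [0.3, 0.7, 0.2], [0.2, 0.1, 0.8], [0.6, 0.4, 0.1]]
y_pred = [[int(p >= 0.5) for p in row] for row in probs]  # assumed 0.5 cut-off

print(exact_match(y_true, y_pred))         # 0.75: the last row misses class 1
print(per_class_accuracy(y_true, y_pred))  # [1.0, 0.75, 1.0]
print([roc_auc([r[c] for r in y_true], [r[c] for r in probs]) for c in range(3)])  # [1.0, 1.0, 1.0]
```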
## 💾 Output Files
Training creates:
```
classification_results/
├── best_classification_model.pth    # Best model by validation AUC
├── final_classification_model.pth   # Final model after all epochs
├── classification_epoch_10.pth      # Periodic checkpoints
├── classification_epoch_20.pth
└── ...
```
Each checkpoint contains:
- Model state dict
- Optimizer state
- Training configuration
- Class names
- Validation metrics
## 🛠️ Advanced Usage
### Custom Loss Functions
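For imbalanced datasets, you can weight the positive examples of each class by passing `pos_weight` to `BCEWithLogitsLoss` (this requires editing `train_classification.py`):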
```python
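import torch
import torch.nn as nn

# df: training-label DataFrame; class_names and device as in the training script
pos_counts = df[class_names].sum()
neg_counts = len(df) - pos_counts
# pos_weight[c] = negatives / positives; clip(lower=1) avoids division by zero
ratio = neg_counts / pos_counts.clip(lower=1)
pos_weight = torch.tensor(ratio.to_numpy(), dtype=torch.float32, device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)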
```
### Transfer Learning Strategies
1. **Conservative**: Freeze backbone for many epochs, low backbone LR
   - `freeze_backbone_epochs = 20`
   - `lr_backbone = 1e-6`
2. **Aggressive**: Unfreeze early, higher backbone LR
   - `freeze_backbone_epochs = 5`
   - `lr_backbone = 1e-4`
3. **Progressive**: Gradually unfreeze layers (requires code modification)
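Strategy 3 could follow a schedule like the one below, which unfreezes backbone blocks one at a time starting from the block closest to the head. This is a hypothetical sketch: the block names and the `unfreeze_every` parameter are placeholders (real layer names depend on the backbone used in `train_byol_mammo.py`), and in PyTorch the result would be applied by setting `requires_grad` on each block's parameters.

```python
from typing import List


def trainable_blocks(epoch: int, blocks: List[str], freeze_backbone_epochs: int = 10,
                     unfreeze_every: int = 2) -> List[str]:
    """Backbone blocks to train in `epoch` (1-based); `blocks` runs from input to output.

    After the frozen phase, one more block is unfrozen every `unfreeze_every`
    epochs, starting with the block closest to the classification head.
    """
    if unfreeze_every < 1:
        raise ValueError("unfreeze_every must be >= 1")
    if epoch <= freeze_backbone_epochs:
        return []
    n_unfrozen = 1 + (epoch - freeze_backbone_epochs - 1) // unfreeze_every
    return blocks[max(0, len(blocks) - n_unfrozen):]


# Hypothetical block names for a ResNet-style backbone.
blocks = ["stem", "layer1", "layer2", "layer3", "layer4"]
for epoch in (10, 11, 13, 15, 30):
    print(epoch, trainable_blocks(epoch, blocks))
```

Unfreezing from the top down keeps the most generic, early features fixed longest, which is the usual motivation for this strategy.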
### Multi-GPU Training
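For multiple GPUs, you can wrap the model in `nn.DataParallel`. Note that `DataParallel` only parallelizes `forward()`: custom methods such as `classify` are not available on the wrapper (they are reachable as `model.module.classify`, which runs on a single GPU), and the PyTorch documentation recommends `DistributedDataParallel` for multi-GPU training.
```python
import torch
import torch.nn as nn

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # `model`: the classification model created by the script
```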
## ⚠️ Troubleshooting
### Common Issues
**Low Validation Performance:**
- Increase `freeze_backbone_epochs` to 15-20
- Reduce `lr_backbone` to 5e-6 or 1e-6
- Check for data leakage between train/val sets
**Overfitting:**
- Increase `label_smoothing` to 0.2-0.3
- Add more dropout (modify model architecture)
- Reduce learning rates
- Use early stopping
**Memory Issues:**
- Reduce `batch_size` to 16 or 8
- Reduce `num_workers` to 4
- Use gradient checkpointing (requires code modification)
**Class Imbalance:**
- Use `pos_weight` in loss function
- Focus on per-class AUC rather than accuracy
- Consider focal loss for extreme imbalance
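For reference, binary focal loss multiplies the per-label BCE by (1 - p_t)^γ, where p_t is the predicted probability of the true label, so confidently correct examples contribute little. The pure-Python sketch below works on one label at a time and is not code from this repository; in a training loop you would use a tensor version such as `torchvision.ops.sigmoid_focal_loss` (if torchvision is installed), which uses the same defaults, α = 0.25 and γ = 2.

```python
import math
from typing import Optional


def bce_with_logit(logit: float, target: int) -> float:
    """Numerically stable binary cross-entropy for one logit and a 0/1 target."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def focal_loss(logit: float, target: int, gamma: float = 2.0, alpha: Optional[float] = 0.25) -> float:
    """alpha_t * (1 - p_t)**gamma * BCE; gamma=0 with alpha=None reduces to plain BCE."""
    ce = bce_with_logit(logit, target)
    p_t = math.exp(-ce)  # probability the model assigns to the true label
    loss = (1.0 - p_t) ** gamma * ce
    if alpha is not None:
        loss *= alpha if target == 1 else 1.0 - alpha
    return loss


# An easy positive (logit 4) is down-weighted far more than a hard one (logit -1).
for logit in (4.0, -1.0):
    print(logit, round(bce_with_logit(logit, 1), 4), round(focal_loss(logit, 1), 4))
```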
## 🎯 Best Practices
1. **Start Conservative**: Use default settings first
2. **Monitor Per-Class Metrics**: Some classes may need special attention
3. **Validate Data**: Ensure no train/val overlap
4. **Checkpoint Often**: Training can be interrupted
5. **Use Multiple Runs**: Average results across random seeds
6. **Test Thoroughly**: Use held-out test set for final evaluation
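For tile datasets, a train/val split can leak even when no file appears in both CSVs: tiles from the same patient in both sets tend to inflate validation scores. A standard-library sketch of a patient-level overlap check follows; it assumes tile names follow the `<patient_id>_tile_<n>.png` pattern from the example CSV, so adapt `patient_id` (a hypothetical helper) to your naming scheme.

```python
import csv
import io
import re
from typing import Set, TextIO

# Assumed naming scheme, e.g. "patient1_tile_001.png" -> "patient1".
TILE_NAME = re.compile(r"^(?P<patient>.+?)_tile_\d+\.\w+$")


def patient_id(tile_path: str) -> str:
    """Extract the patient ID from a tile file name (hypothetical naming scheme)."""
    name = tile_path.replace("\\", "/").rsplit("/", 1)[-1]
    match = TILE_NAME.match(name)
    if match is None:
        raise ValueError(f"unexpected tile name: {tile_path!r}")
    return match.group("patient")


def patients_in(f: TextIO) -> Set[str]:
    """Set of patient IDs referenced by a label CSV."""
    return {patient_id(row["tile_path"]) for row in csv.DictReader(f)}


train = io.StringIO("tile_path,mass\npatient1_tile_001.png,1\npatient2_tile_001.png,0\n")
val = io.StringIO("tile_path,mass\npatient2_tile_007.png,1\npatient3_tile_001.png,0\n")
overlap = patients_in(train) & patients_in(val)
print(sorted(overlap))  # → ['patient2']
```

If the overlap is non-empty, re-split by patient rather than by tile.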
## 📚 Complete Example
Here's a full workflow from BYOL training to classification:
```bash
# 1. Train BYOL (takes about 4-5 hours on an A100)
python train_byol_mammo.py
# 2. Prepare classification data (create CSVs with labels)
# ... prepare train_labels.csv and val_labels.csv ...
# 3. Fine-tune for classification (1-2 hours)
python train_classification.py \
--byol_checkpoint ./mammogram_byol_best.pth \
--train_csv ./train_labels.csv \
--val_csv ./val_labels.csv \
--tiles_dir ./tiles \
--class_names mass calcification architectural_distortion asymmetry normal \
--output_dir ./classification_results
# 4. Run inference on new images
python inference_classification.py \
--model_path ./classification_results/best_classification_model.pth \
--images_dir ./new_patient_tiles \
--output_json ./patient_predictions.json
```
This gives you a complete pipeline from self-supervised pre-training to multi-label classification of new tiles! 🚀