| # Classification Training Guide | |
| Complete guide for fine-tuning the BYOL pre-trained model for multi-label classification. | |
| ## Overview | |
| After BYOL pre-training completes, you can fine-tune the model for classification using the `train_classification.py` script. This approach: | |
| 1. **Loads the BYOL checkpoint** with learned representations | |
| 2. **Freezes the backbone** initially (optional) to prevent overwriting good features | |
| 3. **Fine-tunes the classification head** with a higher learning rate | |
| 4. **Unfreezes the backbone** after `freeze_backbone_epochs` for end-to-end fine-tuning at a lower learning rate | |
| ## Data Preparation | |
| ### CSV Format | |
| Create train/validation CSV files with this format: | |
| ```csv | |
| tile_path,mass,calcification,architectural_distortion,asymmetry,normal,benign,malignant,birads_2,birads_3,birads_4 | |
| patient1_tile_001.png,1,0,0,0,0,1,0,0,1,0 | |
| patient1_tile_002.png,0,1,0,0,0,0,1,0,0,1 | |
| patient2_tile_001.png,0,0,0,0,1,1,0,1,0,0 | |
| ... | |
| ``` | |
| **Requirements:** | |
| - `tile_path`: Relative path to tile image | |
| - **Class columns**: Binary labels (0/1) for each class | |
| - **Multi-label support**: Each image can have multiple classes = 1 | |
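
Before training, it can help to check that a label file meets the requirements above. The following is a minimal sketch using only the standard library; the helper name `validate_label_csv` is hypothetical and is not part of `train_classification.py`.

```python
import csv
import io


def validate_label_csv(handle, class_names):
    """Return a list of human-readable problems found in a label CSV.

    Hypothetical helper for illustration only.
    """
    reader = csv.DictReader(handle)
    header = reader.fieldnames or []
    missing = [c for c in ["tile_path", *class_names] if c not in header]
    if missing:
        return [f"missing columns: {missing}"]

    problems = []
    for line_no, row in enumerate(reader, start=2):  # line 1 is the header
        if not row["tile_path"]:
            problems.append(f"line {line_no}: empty tile_path")
        for name in class_names:
            if row[name] not in ("0", "1"):
                problems.append(f"line {line_no}: {name}={row[name]!r} is not 0/1")
    return problems


# Small in-memory example; the second data row has an invalid label value.
sample = io.StringIO(
    "tile_path,mass,calcification,normal\n"
    "patient1_tile_001.png,1,0,0\n"
    "patient1_tile_002.png,0,2,0\n"
)
problems = validate_label_csv(sample, ["mass", "calcification", "normal"])
print(problems)
```

For a real file, pass `open("train_labels.csv", newline="")` instead of the in-memory sample.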
| ### Directory Structure | |
| ``` | |
| your_project/ | |
| ├── tiles/                      # Directory containing tile images | |
| │   ├── patient1_tile_001.png | |
| │   ├── patient1_tile_002.png | |
| │   └── ... | |
| ├── train_labels.csv            # Training labels | |
| ├── val_labels.csv              # Validation labels | |
| └── mammogram_byol_best.pth     # BYOL checkpoint | |
| ``` | |
| ## Quick Start | |
| ### 1. Basic Classification Training | |
| ```bash | |
| python train_classification.py \ | |
| --byol_checkpoint ./mammogram_byol_best.pth \ | |
| --train_csv ./train_labels.csv \ | |
| --val_csv ./val_labels.csv \ | |
| --tiles_dir ./tiles \ | |
| --class_names mass calcification architectural_distortion asymmetry normal benign malignant birads_2 birads_3 birads_4 \ | |
| --output_dir ./classification_results | |
| ``` | |
| ### 2. With Custom Configuration | |
| ```bash | |
| python train_classification.py \ | |
| --byol_checkpoint ./mammogram_byol_best.pth \ | |
| --train_csv ./train_labels.csv \ | |
| --val_csv ./val_labels.csv \ | |
| --tiles_dir ./tiles \ | |
| --class_names mass calcification normal \ | |
| --config ./classification_config.json \ | |
| --output_dir ./classification_results \ | |
| --wandb_project my-mammogram-classification | |
| ``` | |
| ### 3. Quick Testing (Limited Dataset) | |
| ```bash | |
| python train_classification.py \ | |
| --byol_checkpoint ./mammogram_byol_best.pth \ | |
| --train_csv ./train_labels.csv \ | |
| --val_csv ./val_labels.csv \ | |
| --tiles_dir ./tiles \ | |
| --class_names mass calcification normal \ | |
| --max_samples 1000 \ | |
| --output_dir ./test_results | |
| ``` | |
| ## Configuration Options | |
| ### Key Parameters | |
| | Parameter | Default | Description | | |
| |-----------|---------|-------------| | |
| | `batch_size` | 32 | Batch size for training | | |
| | `epochs` | 50 | Number of training epochs | | |
| | `lr_backbone` | 1e-5 | Learning rate for pre-trained backbone | | |
| | `lr_head` | 1e-3 | Learning rate for classification head | | |
| | `freeze_backbone_epochs` | 10 | Epochs to freeze backbone (0 = never freeze) | | |
| | `label_smoothing` | 0.1 | Label smoothing for regularization | | |
| | `gradient_clip` | 1.0 | Gradient clipping max norm | | |
| ### Custom Configuration File | |
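| Create a JSON file, for example `classification_config.json` (the file passed to `--config` in Quick Start), containing the parameters you want to override: | |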
| ```json | |
| { | |
| "batch_size": 64, | |
| "epochs": 100, | |
| "lr_backbone": 5e-6, | |
| "lr_head": 2e-3, | |
| "freeze_backbone_epochs": 20, | |
| "label_smoothing": 0.2, | |
| "weight_decay": 1e-3 | |
| } | |
| ``` | |
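
A configuration file like this typically only needs the keys you want to change. The sketch below shows one way such overrides could be layered on top of the defaults from the Key Parameters table. It is an assumption about the behavior, not the actual loading code in `train_classification.py`, and `merge_config` is a hypothetical name.

```python
import json

# Defaults copied from the Key Parameters table.
DEFAULTS = {
    "batch_size": 32,
    "epochs": 50,
    "lr_backbone": 1e-5,
    "lr_head": 1e-3,
    "freeze_backbone_epochs": 10,
    "label_smoothing": 0.1,
    "gradient_clip": 1.0,
}


def merge_config(json_text, defaults=DEFAULTS):
    """Overlay user-supplied JSON values on top of the defaults (hypothetical helper)."""
    overrides = json.loads(json_text)
    return {**defaults, **overrides}


config = merge_config('{"batch_size": 64, "lr_backbone": 5e-6, "weight_decay": 1e-3}')
print(config["batch_size"], config["epochs"], config["weight_decay"])
```

Keys that are not overridden (here `epochs`) keep their default values, and extra keys such as `weight_decay` pass through unchanged.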
| ## Expected Training Process | |
| ### Phase 1: Backbone Frozen (Epochs 1-10) | |
| ``` | |
| Epoch 1: Backbone frozen (training only classification head) | |
| Epoch 1/50: | |
| Train Loss: 0.6234 | |
| Val Loss: 0.5891 | |
| Mean AUC: 0.7123 | |
| Mean AP: 0.6894 | |
| Exact Match: 0.4512 | |
| New best model saved (AUC: 0.7123) | |
| ``` | |
| ### Phase 2: End-to-End Fine-tuning (Epochs 11-50) | |
| ``` | |
| Epoch 15/50: | |
| Train Loss: 0.3456 | |
| Val Loss: 0.3891 | |
| Mean AUC: 0.8567 | |
| Mean AP: 0.8234 | |
| Exact Match: 0.6789 | |
| New best model saved (AUC: 0.8567) | |
| ``` | |
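
The two phases can be summarized as a small schedule: which parameter groups train at a given epoch, and at what learning rate. This is a sketch of the idea under the default settings, assuming 1-based epochs as in the logs above; `trainable_groups` is a hypothetical function, not one taken from the training script.

```python
def trainable_groups(epoch, freeze_backbone_epochs=10, lr_backbone=1e-5, lr_head=1e-3):
    """Map a 1-based epoch number to the learning rate of each trainable group.

    Hypothetical helper: the head always trains; the backbone joins once
    the freeze period is over.
    """
    groups = {"head": lr_head}
    if epoch > freeze_backbone_epochs:
        groups["backbone"] = lr_backbone
    return groups


for epoch in (1, 10, 11):
    print(epoch, trainable_groups(epoch))
```

With `freeze_backbone_epochs = 0`, the backbone trains from the first epoch, which matches the "0 = never freeze" convention in the parameter table.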
| ## Making Predictions | |
| ### Single Image Inference | |
| ```bash | |
| python inference_classification.py \ | |
| --model_path ./classification_results/best_classification_model.pth \ | |
| --image_path ./test_image.png \ | |
| --threshold 0.5 | |
| ``` | |
| **Output:** | |
| ``` | |
| Image 1: test_image.png | |
| Top prediction: mass (0.847) | |
| All probabilities: | |
|   mass                    : 0.847 | |
|   calcification           : 0.234 | |
|   normal                  : 0.123 | |
|   architectural_distortion: 0.089 | |
| ``` | |
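
The `--threshold` value turns per-class probabilities into a set of predicted labels. A minimal sketch of that step, using the probabilities from the sample output above, might look like this; `decode_predictions` is a hypothetical name, and whether the script compares with `>=` or `>` is an assumption.

```python
def decode_predictions(probabilities, class_names, threshold=0.5):
    """Rank classes by probability and keep those at or above the threshold.

    Hypothetical helper for illustration only.
    """
    ranked = sorted(zip(class_names, probabilities), key=lambda kv: kv[1], reverse=True)
    positive = [name for name, prob in ranked if prob >= threshold]
    return {"top": ranked[0], "positive": positive, "ranked": ranked}


class_names = ["mass", "calcification", "normal", "architectural_distortion"]
probs = [0.847, 0.234, 0.123, 0.089]
result = decode_predictions(probs, class_names, threshold=0.5)
print(result["top"], result["positive"])
```

Because the task is multi-label, the positive set can contain several classes or none at all; lowering the threshold trades precision for recall.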
| ### Batch Inference | |
| ```bash | |
| python inference_classification.py \ | |
| --model_path ./classification_results/best_classification_model.pth \ | |
| --images_dir ./test_images \ | |
| --output_json ./predictions.json \ | |
| --batch_size 64 | |
| ``` | |
| ### Programmatic Usage | |
| ```python | |
| import torch | |
| from PIL import Image | |
| from train_byol_mammo import MammogramBYOL | |
| from inference_classification import load_classification_model, create_inference_transform | |
| # Load model | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| model, class_names, config = load_classification_model( | |
|     "./classification_results/best_classification_model.pth", device | |
| ) | |
| # Make prediction | |
| transform = create_inference_transform() | |
| image = Image.open("test.png").convert('RGB') | |
| input_tensor = transform(image).unsqueeze(0).to(device) | |
| with torch.no_grad(): | |
|     logits = model.classify(input_tensor) | |
|     probabilities = torch.sigmoid(logits).cpu().numpy()[0] | |
| # Get results | |
| for i, class_name in enumerate(class_names): | |
|     print(f"{class_name}: {probabilities[i]:.3f}") | |
| ``` | |
| ## Monitoring Training | |
| ### Weights & Biases Integration | |
| The script automatically logs to W&B: | |
| - Training/validation loss curves | |
| - Per-class AUC and Average Precision | |
| - Learning rate schedules | |
| - Model hyperparameters | |
| ### Metrics Explained | |
| - **AUC (Area Under the ROC Curve)**: Measures how well positives are ranked above negatives (0-1, higher is better; 0.5 is chance level) | |
| - **AP (Average Precision)**: Summarizes the precision-recall curve (0-1, higher is better) | |
| - **Exact Match Accuracy**: Fraction of samples for which ALL labels are predicted correctly | |
| - **Per-Class Accuracy**: Binary accuracy for each individual class | |
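
The two accuracy metrics differ in strictness, which a small worked example makes concrete. The sketch below computes both from already-thresholded 0/1 predictions in plain Python; the function names are hypothetical and the training script may compute these differently.

```python
def exact_match_accuracy(y_true, y_pred):
    """Fraction of samples whose full label vector is predicted correctly."""
    matches = sum(1 for t, p in zip(y_true, y_pred) if list(t) == list(p))
    return matches / len(y_true)


def per_class_accuracy(y_true, y_pred, class_names):
    """Binary accuracy computed independently for each class column."""
    n = len(y_true)
    return {
        name: sum(t[j] == p[j] for t, p in zip(y_true, y_pred)) / n
        for j, name in enumerate(class_names)
    }


names = ["mass", "calcification", "normal"]
y_true = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0]]
y_pred = [[1, 0, 0], [0, 1, 1], [0, 0, 1], [1, 0, 0]]

exact = exact_match_accuracy(y_true, y_pred)
per_class = per_class_accuracy(y_true, y_pred, names)
print(exact, per_class)
```

Here two of four samples have one wrong label each, so exact match is 0.5 even though every per-class accuracy is 0.75 or higher. This is why exact match is usually the lowest number in the training logs.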
| ## Output Files | |
| Training creates: | |
| ``` | |
| classification_results/ | |
| ├── best_classification_model.pth    # Best model by validation AUC | |
| ├── final_classification_model.pth   # Final model after all epochs | |
| ├── classification_epoch_10.pth      # Periodic checkpoints | |
| ├── classification_epoch_20.pth | |
| └── ... | |
| ``` | |
| Each checkpoint contains: | |
| - Model state dict | |
| - Optimizer state | |
| - Training configuration | |
| - Class names | |
| - Validation metrics | |
| ## Advanced Usage | |
| ### Custom Loss Functions | |
| For imbalanced datasets, modify the loss function: | |
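| ```python | |
| # df: training labels DataFrame (e.g. loaded from train_labels.csv) | |
| # Calculate positive weights for each class (negatives / positives) | |
| pos_counts = df[class_names].sum() | |
| neg_counts = len(df) - pos_counts | |
| # Use float32 to match the logits; a class with no positives would give inf here | |
| pos_weight = torch.tensor((neg_counts / pos_counts).values, dtype=torch.float32, device=device) | |
| criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight) | |
| ``` | |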
| ### Transfer Learning Strategies | |
| 1. **Conservative**: Freeze backbone for many epochs, low backbone LR | |
| - `freeze_backbone_epochs = 20` | |
| - `lr_backbone = 1e-6` | |
| 2. **Aggressive**: Unfreeze early, higher backbone LR | |
| - `freeze_backbone_epochs = 5` | |
| - `lr_backbone = 1e-4` | |
| 3. **Progressive**: Gradually unfreeze layers (requires code modification) | |
| ### Multi-GPU Training | |
| For multiple GPUs, wrap the model: | |
| ```python | |
| if torch.cuda.device_count() > 1: | |
|     model = nn.DataParallel(model) | |
| ``` | |
| ## Troubleshooting | |
| ### Common Issues | |
| **Low Validation Performance:** | |
| - Increase `freeze_backbone_epochs` to 15-20 | |
| - Reduce `lr_backbone` to 5e-6 or 1e-6 | |
| - Check for data leakage between train/val sets | |
| **Overfitting:** | |
| - Increase `label_smoothing` to 0.2-0.3 | |
| - Add more dropout (modify model architecture) | |
| - Reduce learning rates | |
| - Use early stopping | |
| **Memory Issues:** | |
| - Reduce `batch_size` to 16 or 8 | |
| - Reduce `num_workers` to 4 | |
| - Use gradient checkpointing (requires code modification) | |
| **Class Imbalance:** | |
| - Use `pos_weight` in loss function | |
| - Focus on per-class AUC rather than accuracy | |
| - Consider focal loss for extreme imbalance | |
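
For reference, the focal loss mentioned above down-weights easy examples by a factor of `(1 - p_t) ** gamma`. The following is a per-element sketch in plain Python to show the arithmetic; it is not numerically hardened for extreme logits, and in training you would write the tensor equivalent in PyTorch. The function name and default values are illustrative assumptions.

```python
import math


def binary_focal_loss(logit, target, gamma=2.0, alpha=None):
    """Focal loss for a single binary label (illustrative, not numerically hardened).

    p_t is the predicted probability of the true class. With gamma = 0 and
    alpha = None this reduces to ordinary binary cross-entropy.
    """
    p = 1.0 / (1.0 + math.exp(-logit))
    p_t = p if target == 1 else 1.0 - p
    loss = -((1.0 - p_t) ** gamma) * math.log(p_t)
    if alpha is not None:
        alpha_t = alpha if target == 1 else 1.0 - alpha
        loss *= alpha_t
    return loss


easy_bce = binary_focal_loss(3.0, 1, gamma=0.0)   # confident and correct
easy_focal = binary_focal_loss(3.0, 1, gamma=2.0)
hard_bce = binary_focal_loss(-3.0, 1, gamma=0.0)  # confident and wrong
hard_focal = binary_focal_loss(-3.0, 1, gamma=2.0)
print(easy_focal / easy_bce, hard_focal / hard_bce)
```

The printed ratios show that the easy example is scaled down far more than the hard one, which shifts the training signal toward rare, misclassified positives.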
| ## Best Practices | |
| 1. **Start Conservative**: Use default settings first | |
| 2. **Monitor Per-Class Metrics**: Some classes may need special attention | |
| 3. **Validate Data**: Ensure no train/val overlap | |
| 4. **Checkpoint Often**: Training can be interrupted | |
| 5. **Use Multiple Runs**: Average results across random seeds | |
| 6. **Test Thoroughly**: Use held-out test set for final evaluation | |
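
For tiled images, checking train/val overlap is most meaningful at the patient level, since tiles from the same patient in both sets can leak information. The sketch below assumes file names follow the `patient1_tile_001.png` pattern used in this guide; both function names are hypothetical.

```python
def patient_id(tile_path):
    """Extract the patient prefix from a tile path.

    Assumption: file names look like "patient1_tile_001.png".
    """
    file_name = tile_path.replace("\\", "/").rsplit("/", 1)[-1]
    return file_name.split("_tile_")[0]


def overlapping_patients(train_paths, val_paths):
    """Return the sorted patient IDs that appear in both splits."""
    train_ids = {patient_id(p) for p in train_paths}
    val_ids = {patient_id(p) for p in val_paths}
    return sorted(train_ids & val_ids)


train_paths = ["patient1_tile_001.png", "patient2_tile_001.png"]
val_paths = ["tiles/patient2_tile_007.png", "patient3_tile_001.png"]
leaks = overlapping_patients(train_paths, val_paths)
print(leaks)
```

An empty result means no patient contributes tiles to both splits under this naming assumption.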
| ## Complete Example | |
| Here's a full workflow from BYOL training to classification: | |
| ```bash | |
| # 1. Train BYOL (this takes 4-5 hours on A100) | |
| python train_byol_mammo.py | |
| # 2. Prepare classification data (create CSVs with labels) | |
| # ... prepare train_labels.csv and val_labels.csv ... | |
| # 3. Fine-tune for classification (1-2 hours) | |
| python train_classification.py \ | |
| --byol_checkpoint ./mammogram_byol_best.pth \ | |
| --train_csv ./train_labels.csv \ | |
| --val_csv ./val_labels.csv \ | |
| --tiles_dir ./tiles \ | |
| --class_names mass calcification architectural_distortion asymmetry normal \ | |
| --output_dir ./classification_results | |
| # 4. Run inference on new images | |
| python inference_classification.py \ | |
| --model_path ./classification_results/best_classification_model.pth \ | |
| --images_dir ./new_patient_tiles \ | |
| --output_json ./patient_predictions.json | |
| ``` | |
| This gives you a complete pipeline from self-supervised pre-training to classification inference. | |