# 🎯 Classification Training Guide

Complete guide for fine-tuning the BYOL pre-trained model for multi-label classification.

## πŸ“‹ Overview

After BYOL pre-training completes, you can fine-tune the model for classification using the `train_classification.py` script. This approach:

1. **Loads the BYOL checkpoint** with learned representations
2. **Freezes the backbone** initially (optional) to prevent overwriting good features
3. **Fine-tunes the classification head** with a higher learning rate
4. **Gradually unfreezes** the backbone for end-to-end fine-tuning
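
The two-phase schedule above can be sketched in plain PyTorch. This is a minimal illustration, not the actual classes from `train_byol_mammo.py`; the tiny `backbone` and `head` modules are stand-ins:

```python
import torch
import torch.nn as nn

# Stand-ins for the BYOL backbone and the new classification head
backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1),
                         nn.AdaptiveAvgPool2d(1), nn.Flatten())
head = nn.Linear(8, 10)  # 10 classes

# Phase 1: freeze the backbone so only the head trains
for p in backbone.parameters():
    p.requires_grad = False

# Two parameter groups with different learning rates
optimizer = torch.optim.AdamW([
    {"params": backbone.parameters(), "lr": 1e-5},  # lr_backbone
    {"params": head.parameters(),     "lr": 1e-3},  # lr_head
])

# Phase 2 (after freeze_backbone_epochs): unfreeze for end-to-end tuning
for p in backbone.parameters():
    p.requires_grad = True
```

Keeping the frozen parameters registered in the optimizer from the start means no optimizer rebuild is needed when the backbone thaws.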

## πŸ—‚οΈ Data Preparation

### CSV Format
Create train/validation CSV files with this format:

```csv
tile_path,mass,calcification,architectural_distortion,asymmetry,normal,benign,malignant,birads_2,birads_3,birads_4
patient1_tile_001.png,1,0,0,0,0,1,0,0,1,0
patient1_tile_002.png,0,1,0,0,0,0,1,0,0,1
patient2_tile_001.png,0,0,0,0,1,1,0,1,0,0
...
```

**Requirements:**
- `tile_path`: Relative path to tile image
- **Class columns**: Binary labels (0/1) for each class
- **Multi-label support**: An image may have several classes set to 1 simultaneously
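
A quick sanity check on the CSVs before training catches formatting problems early. A minimal sketch using pandas, with an inline CSV string standing in for your real `train_labels.csv`:

```python
import io
import pandas as pd

class_names = ["mass", "calcification", "normal"]  # subset for illustration

# In practice: df = pd.read_csv("train_labels.csv")
csv_text = """tile_path,mass,calcification,normal
patient1_tile_001.png,1,0,0
patient1_tile_002.png,0,1,0
patient2_tile_001.png,0,0,1
"""
df = pd.read_csv(io.StringIO(csv_text))

# Sanity checks before training
assert "tile_path" in df.columns, "missing tile_path column"
assert df[class_names].isin([0, 1]).all().all(), "labels must be binary 0/1"
print(df[class_names].sum())  # per-class positive counts
```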

### Directory Structure
```
your_project/
β”œβ”€β”€ tiles/                    # Directory containing tile images
β”‚   β”œβ”€β”€ patient1_tile_001.png
β”‚   β”œβ”€β”€ patient1_tile_002.png
β”‚   └── ...
β”œβ”€β”€ train_labels.csv          # Training labels
β”œβ”€β”€ val_labels.csv            # Validation labels
└── mammogram_byol_best.pth   # BYOL checkpoint
```

## πŸš€ Quick Start

### 1. Basic Classification Training

```bash
python train_classification.py \
    --byol_checkpoint ./mammogram_byol_best.pth \
    --train_csv ./train_labels.csv \
    --val_csv ./val_labels.csv \
    --tiles_dir ./tiles \
    --class_names mass calcification architectural_distortion asymmetry normal benign malignant birads_2 birads_3 birads_4 \
    --output_dir ./classification_results
```

### 2. With Custom Configuration

```bash
python train_classification.py \
    --byol_checkpoint ./mammogram_byol_best.pth \
    --train_csv ./train_labels.csv \
    --val_csv ./val_labels.csv \
    --tiles_dir ./tiles \
    --class_names mass calcification normal \
    --config ./classification_config.json \
    --output_dir ./classification_results \
    --wandb_project my-mammogram-classification
```

### 3. Quick Testing (Limited Dataset)

```bash
python train_classification.py \
    --byol_checkpoint ./mammogram_byol_best.pth \
    --train_csv ./train_labels.csv \
    --val_csv ./val_labels.csv \
    --tiles_dir ./tiles \
    --class_names mass calcification normal \
    --max_samples 1000 \
    --output_dir ./test_results
```

## βš™οΈ Configuration Options

### Key Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `batch_size` | 32 | Batch size for training |
| `epochs` | 50 | Number of training epochs |
| `lr_backbone` | 1e-5 | Learning rate for pre-trained backbone |
| `lr_head` | 1e-3 | Learning rate for classification head |
| `freeze_backbone_epochs` | 10 | Epochs to freeze backbone (0 = never freeze) |
| `label_smoothing` | 0.1 | Label smoothing for regularization |
| `gradient_clip` | 1.0 | Gradient clipping max norm |
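
Two of these parameters, `label_smoothing` and `gradient_clip`, can be illustrated directly. This sketch uses one common multi-label smoothing convention (pulling hard 0/1 targets toward 0.5); the script's exact implementation may differ:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)    # batch of 4, 10 classes
targets = torch.randint(0, 2, (4, 10)).float()

# Label smoothing for multi-label BCE: soften the hard 0/1 targets
smoothing = 0.1
smoothed = targets * (1 - smoothing) + 0.5 * smoothing

loss = nn.functional.binary_cross_entropy_with_logits(logits, smoothed)
loss.backward()

# Gradient clipping with max norm 1.0
torch.nn.utils.clip_grad_norm_([logits], max_norm=1.0)
```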

### Custom Configuration File

Create `my_config.json`:
```json
{
  "batch_size": 64,
  "epochs": 100,
  "lr_backbone": 5e-6,
  "lr_head": 2e-3,
  "freeze_backbone_epochs": 20,
  "label_smoothing": 0.2,
  "weight_decay": 1e-3
}
```
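
A config file like this is typically merged over the script's defaults so that user values win. A minimal sketch of that pattern (the JSON string below stands in for reading `my_config.json` from disk):

```python
import json

defaults = {
    "batch_size": 32, "epochs": 50, "lr_backbone": 1e-5,
    "lr_head": 1e-3, "freeze_backbone_epochs": 10,
    "label_smoothing": 0.1, "gradient_clip": 1.0,
}

# In practice: overrides = json.load(open("my_config.json"))
overrides = json.loads('{"batch_size": 64, "freeze_backbone_epochs": 20}')

config = {**defaults, **overrides}  # later dict wins on key collisions
```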

## πŸ“Š Expected Training Process

### Phase 1: Backbone Frozen (Epochs 1-10)
```
🧊 Epoch 1: Backbone frozen (training only classification head)
Epoch   1/50:
  Train Loss: 0.6234
  Val Loss:   0.5891
  Mean AUC:   0.7123
  Mean AP:    0.6894
  Exact Match: 0.4512
  βœ… New best model saved (AUC: 0.7123)
```

### Phase 2: End-to-End Fine-tuning (Epochs 11-50)
```
Epoch  15/50:
  Train Loss: 0.3456
  Val Loss:   0.3891
  Mean AUC:   0.8567
  Mean AP:    0.8234
  Exact Match: 0.6789
  βœ… New best model saved (AUC: 0.8567)
```

## πŸ” Making Predictions

### Single Image Inference

```bash
python inference_classification.py \
    --model_path ./classification_results/best_classification_model.pth \
    --image_path ./test_image.png \
    --threshold 0.5
```

**Output:**
```
πŸ“Έ Image 1: test_image.png
πŸ† Top prediction: mass (0.847)
πŸ“Š All probabilities:
   βœ… mass              : 0.847
   ❌ calcification     : 0.234
   ❌ normal            : 0.123
   ❌ architectural_distortion: 0.089
```

### Batch Inference

```bash
python inference_classification.py \
    --model_path ./classification_results/best_classification_model.pth \
    --images_dir ./test_images \
    --output_json ./predictions.json \
    --batch_size 64
```

### Programmatic Usage

```python
import torch
from PIL import Image
from train_byol_mammo import MammogramBYOL
from inference_classification import load_classification_model, create_inference_transform

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, class_names, config = load_classification_model(
    "./classification_results/best_classification_model.pth", device
)

# Make prediction
transform = create_inference_transform()
image = Image.open("test.png").convert('RGB')
input_tensor = transform(image).unsqueeze(0).to(device)

with torch.no_grad():
    logits = model.classify(input_tensor)
    probabilities = torch.sigmoid(logits).cpu().numpy()[0]

# Get results
for i, class_name in enumerate(class_names):
    print(f"{class_name}: {probabilities[i]:.3f}")
```

## πŸ“ˆ Monitoring Training

### Weights & Biases Integration

The script automatically logs to W&B:
- Training/validation loss curves
- Per-class AUC and Average Precision
- Learning rate schedules
- Model hyperparameters

### Metrics Explained

- **AUC (Area Under the ROC Curve)**: Measures ranking quality (0-1, higher is better)
- **AP (Average Precision)**: Summarizes the precision-recall curve (0-1, higher is better)
- **Exact Match Accuracy**: Percentage where ALL labels are predicted correctly
- **Per-Class Accuracy**: Binary accuracy for each individual class
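
These metrics can be reproduced with scikit-learn on a toy batch (the numbers below are illustrative, not model outputs):

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

# Toy multi-label batch: 4 samples, 2 classes
y_true = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
y_prob = np.array([[0.9, 0.2], [0.6, 0.8], [0.4, 0.6], [0.1, 0.3]])

mean_auc = roc_auc_score(y_true, y_prob, average="macro")
mean_ap = average_precision_score(y_true, y_prob, average="macro")

# Exact match: every label correct for a sample at threshold 0.5
y_pred = (y_prob >= 0.5).astype(int)
exact_match = (y_pred == y_true).all(axis=1).mean()
```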

## πŸ’Ύ Output Files

Training creates:
```
classification_results/
β”œβ”€β”€ best_classification_model.pth      # Best model by validation AUC
β”œβ”€β”€ final_classification_model.pth     # Final model after all epochs
β”œβ”€β”€ classification_epoch_10.pth        # Periodic checkpoints
β”œβ”€β”€ classification_epoch_20.pth
└── ...
```

Each checkpoint contains:
- Model state dict
- Optimizer state  
- Training configuration
- Class names
- Validation metrics
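
Saving and restoring such a checkpoint follows the standard `torch.save`/`torch.load` pattern. The key names below mirror the list above but are illustrative, not necessarily the script's exact keys:

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Linear(8, 3)
optimizer = torch.optim.AdamW(model.parameters())

checkpoint = {
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "config": {"batch_size": 32, "epochs": 50},
    "class_names": ["mass", "calcification", "normal"],
    "val_metrics": {"mean_auc": 0.86},
}

path = os.path.join(tempfile.mkdtemp(), "best_classification_model.pth")
torch.save(checkpoint, path)

restored = torch.load(path, map_location="cpu")
model.load_state_dict(restored["model_state_dict"])
```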

## πŸ› οΈ Advanced Usage

### Custom Loss Functions

For imbalanced datasets, modify the loss function:

```python
import torch
import torch.nn as nn

# df: training DataFrame; class_names: list of label columns
pos_counts = df[class_names].sum()
neg_counts = len(df) - pos_counts

# Weight each class's positives by its negative/positive ratio
pos_weight = torch.tensor((neg_counts / pos_counts).values,
                          dtype=torch.float32).to(device)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
```

### Transfer Learning Strategies

1. **Conservative**: Freeze backbone for many epochs, low backbone LR
   - `freeze_backbone_epochs = 20`
   - `lr_backbone = 1e-6`

2. **Aggressive**: Unfreeze early, higher backbone LR
   - `freeze_backbone_epochs = 5`  
   - `lr_backbone = 1e-4`

3. **Progressive**: Gradually unfreeze layers (requires code modification)

### Multi-GPU Training

For multiple GPUs, wrap the model:
```python
import torch
import torch.nn as nn

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
```

## ⚠️ Troubleshooting

### Common Issues

**Low Validation Performance:**
- Increase `freeze_backbone_epochs` to 15-20
- Reduce `lr_backbone` to 5e-6 or 1e-6
- Check for data leakage between train/val sets

**Overfitting:**
- Increase `label_smoothing` to 0.2-0.3
- Add more dropout (modify model architecture)
- Reduce learning rates
- Use early stopping

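If you add early stopping, a minimal tracker on validation AUC could look like this (the patience value and AUC series are illustrative):

```python
# Stop when validation AUC fails to improve for `patience` epochs
best_auc, patience, bad_epochs = 0.0, 5, 0
stopped_at = None

val_aucs = [0.70, 0.75, 0.74, 0.76, 0.755, 0.752, 0.751, 0.750, 0.749, 0.748]
for epoch, auc in enumerate(val_aucs, start=1):
    if auc > best_auc:
        best_auc, bad_epochs = auc, 0   # improvement: reset the counter
    else:
        bad_epochs += 1
        if bad_epochs >= patience:      # plateau: stop training
            stopped_at = epoch
            break
```
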
**Memory Issues:**
- Reduce `batch_size` to 16 or 8
- Reduce `num_workers` to 4
- Use gradient checkpointing (requires code modification)

**Class Imbalance:**
- Use `pos_weight` in loss function
- Focus on per-class AUC rather than accuracy
- Consider focal loss for extreme imbalance
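
Focal loss down-weights easy examples so training concentrates on hard, rare positives. A sketch of the standard binary focal loss (Lin et al.) adapted to multi-label logits; the hyperparameters below are typical defaults, not values from this repo:

```python
import torch
import torch.nn.functional as F

def focal_bce_with_logits(logits, targets, gamma=2.0, alpha=0.25):
    """Binary focal loss over multi-label logits (standard formulation)."""
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)            # prob of true label
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()     # down-weight easy terms

torch.manual_seed(0)
logits = torch.randn(4, 10)
targets = torch.randint(0, 2, (4, 10)).float()
loss = focal_bce_with_logits(logits, targets)
```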

## 🎯 Best Practices

1. **Start Conservative**: Use default settings first
2. **Monitor Per-Class Metrics**: Some classes may need special attention
3. **Validate Data**: Ensure no train/val overlap
4. **Checkpoint Often**: Training can be interrupted
5. **Use Multiple Runs**: Average results across random seeds
6. **Test Thoroughly**: Use held-out test set for final evaluation

## πŸ“š Complete Example

Here's a full workflow from BYOL training to classification:

```bash
# 1. Train BYOL (this takes 4-5 hours on A100)
python train_byol_mammo.py

# 2. Prepare classification data (create CSVs with labels)
# ... prepare train_labels.csv and val_labels.csv ...

# 3. Fine-tune for classification (1-2 hours)
python train_classification.py \
    --byol_checkpoint ./mammogram_byol_best.pth \
    --train_csv ./train_labels.csv \
    --val_csv ./val_labels.csv \
    --tiles_dir ./tiles \
    --class_names mass calcification architectural_distortion asymmetry normal \
    --output_dir ./classification_results

# 4. Run inference on new images
python inference_classification.py \
    --model_path ./classification_results/best_classification_model.pth \
    --images_dir ./new_patient_tiles \
    --output_json ./patient_predictions.json
```

This gives you a complete pipeline from self-supervised pre-training to production-ready classification! πŸš€