File size: 18,548 Bytes
d921913
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
#!/usr/bin/env python3
"""
train_classification.py

Fine-tune the BYOL pre-trained model for multi-label classification on mammogram tiles.
This script loads the BYOL checkpoint and trains only the classification head while
optionally fine-tuning the backbone with a lower learning rate.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import pandas as pd
import numpy as np
from pathlib import Path
from PIL import Image
import torchvision.transforms as T
from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
from tqdm import tqdm
import wandb
import argparse
from typing import Dict, List, Tuple
import json

# Import the BYOL model
from train_byol_mammo import MammogramBYOL

# Configuration
# Training device: first CUDA GPU if available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Side length (pixels) images are resized to by the transform pipelines.
TILE_SIZE = 512

# Default hyperparameters - can be overridden via command line
# (a JSON file passed via --config is merged over these defaults in main()).
DEFAULT_CONFIG = {
    'batch_size': 32,
    'num_workers': 8,              # DataLoader worker processes
    'epochs': 50,
    'lr_backbone': 1e-5,      # Lower LR for pre-trained backbone
    'lr_head': 1e-3,          # Higher LR for classification head  
    'weight_decay': 1e-4,
    'warmup_epochs': 5,            # NOTE(review): not consumed anywhere in this file — verify intent
    'freeze_backbone_epochs': 10,  # Freeze backbone for first N epochs
    'label_smoothing': 0.1,        # Applied to BCE targets during training
    'dropout_rate': 0.3,           # NOTE(review): ClassificationModel hard-codes 0.3; this key is unused here
    'gradient_clip': 1.0,          # Max grad-norm for clipping in train_epoch
}


class MammogramClassificationDataset(Dataset):
    """Multi-label classification dataset over pre-extracted mammogram tiles.

    Each CSV row names one tile image (relative to ``tiles_dir``) plus one
    0/1 column per class; the label vector follows the order of
    ``class_names``.
    """

    def __init__(self, csv_path: str, tiles_dir: str, class_names: List[str],
                 transform=None, max_samples: int = None):
        """
        Args:
            csv_path: CSV with a 'tile_path' column plus one column per class.
            tiles_dir: Root directory that tile paths are resolved against.
            class_names: Ordered class names; fixes the label-vector layout.
            transform: Optional image transform applied in __getitem__.
            max_samples: If truthy, keep only the first N rows (smoke tests).
        """
        self.tiles_dir = Path(tiles_dir)
        self.class_names = class_names
        self.num_classes = len(class_names)
        self.transform = transform

        frame = pd.read_csv(csv_path)
        self.df = frame.head(max_samples) if max_samples else frame

        print(f"πŸ“Š Loaded {len(self.df)} samples for classification training")
        print(f"🏷️  Classes: {class_names}")

        # Fail fast if the CSV is missing the path column or any label column.
        expected = ['tile_path', *class_names]
        missing_cols = [name for name in expected if name not in self.df.columns]
        if missing_cols:
            raise ValueError(f"Missing columns in CSV: {missing_cols}")

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]

        # Load the tile as 3-channel RGB so ImageNet-style backbones accept it.
        image = Image.open(self.tiles_dir / record['tile_path']).convert('RGB')
        if self.transform:
            image = self.transform(image)

        # Float multi-hot target vector, ordered to match class_names.
        target = torch.tensor([record[name] for name in self.class_names],
                              dtype=torch.float32)
        return image, target


def create_classification_transforms(tile_size: int, is_training: bool = True):
    """Build the torchvision pipeline for classification.

    Training pipelines add moderate flip/rotation/color augmentation; the
    validation pipeline is a deterministic resize. Both end with ToTensor and
    a symmetric [-1, 1] normalization.
    """
    shared_tail = [
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]

    if not is_training:
        # Validation: no augmentation, just resize + normalize.
        return T.Compose([T.Resize((tile_size, tile_size))] + shared_tail)

    # Training: moderate geometric + photometric augmentation.
    augmentation = [
        T.Resize((tile_size, tile_size)),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomVerticalFlip(p=0.2),
        T.RandomRotation(degrees=10, fill=0),
        T.ColorJitter(brightness=0.2, contrast=0.2),
    ]
    return T.Compose(augmentation + shared_tail)


class ClassificationModel(nn.Module):
    """Wraps a BYOL backbone with a small MLP head for multi-label logits.

    The head maps 2048-d backbone features -> hidden_dim -> num_classes.
    """

    def __init__(self, byol_model: MammogramBYOL, num_classes: int,
                 hidden_dim: int = 2048, dropout_rate: float = 0.3):
        """
        Args:
            byol_model: Pre-trained BYOL model exposing get_features().
            num_classes: Number of output logits (one per label).
            hidden_dim: Width of the hidden layer in the head.
            dropout_rate: Dropout probability in the head. Previously
                hard-coded to 0.3; now a parameter (same default) so
                config['dropout_rate'] can actually be honored by callers.
        """
        super().__init__()
        self.byol_model = byol_model

        # Classification head on top of the (optionally frozen) backbone.
        self.classification_head = nn.Sequential(
            nn.Linear(2048, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        """Return raw (pre-sigmoid) logits of shape (batch, num_classes)."""
        features = self.byol_model.get_features(x)
        return self.classification_head(features)

    def get_features(self, x):
        """Expose backbone features (e.g. for embedding extraction)."""
        return self.byol_model.get_features(x)


def load_byol_model(checkpoint_path: str, num_classes: int, device: torch.device):
    """Load a BYOL checkpoint and wrap it in a ClassificationModel.

    Rebuilds the ResNet-50 backbone with the same architecture used during
    BYOL pre-training, restores the BYOL weights, and attaches a fresh
    classification head.

    Args:
        checkpoint_path: Path to the BYOL .pth checkpoint.
        num_classes: Number of classification outputs.
        device: Device to place the model on.

    Returns:
        ClassificationModel on `device` with BYOL weights loaded.
    """
    print(f"πŸ“₯ Loading BYOL checkpoint: {checkpoint_path}")

    # NOTE: torch.load unpickles arbitrary objects — only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Create BYOL model with same architecture as training
    from torchvision import models
    resnet = models.resnet50(weights=None)  # Don't load ImageNet weights
    backbone = nn.Sequential(*list(resnet.children())[:-1])

    byol_model = MammogramBYOL(
        backbone=backbone,
        input_dim=2048,
        hidden_dim=4096, 
        proj_dim=256
    ).to(device)

    # Load BYOL weights
    byol_model.load_state_dict(checkpoint['model_state_dict'])

    # Create classification model
    model = ClassificationModel(byol_model, num_classes).to(device)

    print(f"βœ… Loaded BYOL model from epoch {checkpoint.get('epoch', 'unknown')}")
    # Bug fix: the original applied ':.4f' to checkpoint.get('loss', 'unknown'),
    # which raises ValueError when the fallback string is returned. Format the
    # loss defensively instead.
    try:
        loss_str = f"{checkpoint['loss']:.4f}"
    except (KeyError, TypeError, ValueError):
        loss_str = 'unknown'
    print(f"πŸ“Š BYOL training loss: {loss_str}")
    print(f"🎯 Added classification head: 2048 β†’ {2048} β†’ {num_classes}")

    return model


def calculate_metrics(predictions: np.ndarray, targets: np.ndarray, 
                     class_names: List[str]) -> Dict[str, float]:
    """Compute per-class and aggregate metrics for multi-label outputs.

    Args:
        predictions: (N, C) array of per-class probabilities.
        targets: (N, C) array of 0/1 ground-truth labels.
        class_names: Names for the C label columns.

    Returns:
        Dict with per-class AUC / average precision / accuracy, their means
        across classes, and exact-match (all-labels-correct) accuracy.
    """
    metrics = {}

    # Threshold probabilities at 0.5 for the accuracy-style metrics.
    hard_preds = (predictions > 0.5).astype(int)

    for col, class_name in enumerate(class_names):
        y_true = targets[:, col]
        y_score = predictions[:, col]
        try:
            metrics[f'auc_{class_name}'] = roc_auc_score(y_true, y_score)
            metrics[f'ap_{class_name}'] = average_precision_score(y_true, y_score)
            metrics[f'acc_{class_name}'] = accuracy_score(y_true, hard_preds[:, col])
        except ValueError:
            # Degenerate column (only one class present): AUC/AP undefined.
            metrics[f'auc_{class_name}'] = 0.0
            metrics[f'ap_{class_name}'] = 0.0
            metrics[f'acc_{class_name}'] = accuracy_score(y_true, hard_preds[:, col])

    # Macro averages across classes.
    for prefix in ('auc', 'ap', 'acc'):
        metrics[f'mean_{prefix}'] = np.mean(
            [metrics[f'{prefix}_{name}'] for name in class_names]
        )

    # Exact match accuracy: fraction of samples with every label correct.
    metrics['exact_match_acc'] = (hard_preds == targets).all(axis=1).mean()

    return metrics


def train_epoch(model: nn.Module, dataloader: DataLoader, criterion: nn.Module,
                optimizer: optim.Optimizer, scaler: GradScaler, epoch: int,
                config: dict, freeze_backbone: bool = False) -> Dict[str, float]:
    """Run one mixed-precision (AMP) training epoch.

    Args:
        model: ClassificationModel wrapping the BYOL backbone + head.
        dataloader: Training DataLoader yielding (images, labels) batches.
        criterion: Loss callable taking (logits, targets).
        optimizer: Optimizer over backbone + head parameter groups.
        scaler: GradScaler coordinating AMP loss scaling.
        epoch: 1-based epoch index (display only).
        config: Needs 'epochs' (display) and 'gradient_clip' (max grad norm).
        freeze_backbone: If True, disable gradients for the backbone so only
            the classification head trains this epoch.

    Returns:
        {'train_loss': mean loss over all batches}.
    """
    
    model.train()
    total_loss = 0.0
    num_batches = len(dataloader)
    
    # Freeze backbone if specified
    if freeze_backbone:
        for param in model.byol_model.backbone.parameters():
            param.requires_grad = False
        for param in model.byol_model.backbone_momentum.parameters():
            param.requires_grad = False
    else:
        # NOTE(review): only `backbone` is re-enabled here; `backbone_momentum`
        # stays frozen once frozen — presumably intentional since the momentum
        # encoder is EMA-updated, not gradient-trained. Confirm against the
        # BYOL training script.
        for param in model.byol_model.backbone.parameters():
            param.requires_grad = True
    
    pbar = tqdm(dataloader, desc=f"Epoch {epoch:3d}/{config['epochs']} [Train]", 
                ncols=100, leave=False)
    
    for batch_idx, (images, labels) in enumerate(pbar):
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        
        optimizer.zero_grad()
        
        # Forward under autocast so matmuls/convs run in reduced precision.
        with autocast():
            # Forward pass through classification model
            outputs = model(images)
            loss = criterion(outputs, labels)
        
        # Backward pass. Order matters: scale -> backward -> unscale -> clip
        # -> step -> update, so clipping sees true (unscaled) gradients.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), config['gradient_clip'])
        scaler.step(optimizer)
        scaler.update()
        
        total_loss += loss.item()
        
        # Update progress bar
        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Avg': f'{total_loss/(batch_idx+1):.4f}',
            'LR': f'{optimizer.param_groups[0]["lr"]:.2e}'
        })
    
    return {'train_loss': total_loss / num_batches}


def validate_epoch(model: nn.Module, dataloader: DataLoader, criterion: nn.Module,
                  class_names: List[str]) -> Dict[str, float]:
    """Evaluate the model over `dataloader` and return loss plus metrics.

    Runs under no_grad + autocast, collects sigmoid probabilities and
    targets for the whole set, then delegates to calculate_metrics().
    """
    model.eval()
    loss_sum = 0.0
    prob_batches = []
    target_batches = []

    with torch.no_grad():
        progress = tqdm(dataloader, desc="Validation", ncols=100, leave=False)

        for images, labels in progress:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            with autocast():
                logits = model(images)
                batch_loss = criterion(logits, labels)

            loss_sum += batch_loss.item()

            # Accumulate probabilities / targets on CPU for metric computation.
            prob_batches.append(torch.sigmoid(logits).cpu().numpy())
            target_batches.append(labels.cpu().numpy())

            progress.set_postfix({'Loss': f'{batch_loss.item():.4f}'})

    # Stack per-batch arrays into full (N, C) matrices and score them.
    metrics = calculate_metrics(
        np.concatenate(prob_batches, axis=0),
        np.concatenate(target_batches, axis=0),
        class_names,
    )
    metrics['val_loss'] = loss_sum / len(dataloader)

    return metrics


def main():
    """CLI entry point: fine-tune a BYOL checkpoint for multi-label tile classification.

    Parses arguments, builds datasets/loaders, sets up a two-group AdamW
    optimizer (low LR backbone, high LR head) with cosine annealing, and runs
    the train/validate loop, saving the best-AUC, periodic, and final
    checkpoints to the output directory.
    """
    parser = argparse.ArgumentParser(description='Fine-tune BYOL model for classification')
    parser.add_argument('--byol_checkpoint', type=str, required=True,
                       help='Path to BYOL checkpoint (.pth file)')
    parser.add_argument('--train_csv', type=str, required=True,
                       help='Path to training CSV file')
    parser.add_argument('--val_csv', type=str, required=True,
                       help='Path to validation CSV file') 
    parser.add_argument('--tiles_dir', type=str, required=True,
                       help='Directory containing tile images')
    parser.add_argument('--class_names', type=str, nargs='+', required=True,
                       help='List of class names (e.g., mass calcification normal)')
    parser.add_argument('--output_dir', type=str, default='./classification_results',
                       help='Output directory for checkpoints and logs')
    parser.add_argument('--config', type=str, default=None,
                       help='JSON config file (overrides defaults)')
    parser.add_argument('--wandb_project', type=str, default='mammogram-classification',
                       help='Weights & Biases project name')
    parser.add_argument('--max_samples', type=int, default=None,
                       help='Limit dataset size for testing')

    args = parser.parse_args()

    # Load configuration: a JSON file (if given) is merged over the defaults.
    config = DEFAULT_CONFIG.copy()
    if args.config:
        with open(args.config, 'r') as f:
            config.update(json.load(f))

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Initialize wandb; training proceeds without logging if it isn't set up.
    try:
        wandb.init(
            project=args.wandb_project,
            config=config,
            name=f"classification_fine_tune_{len(args.class_names)}classes"
        )
        wandb_enabled = True
    except Exception as e:
        print(f"⚠️  WandB not configured: {e}")
        wandb_enabled = False

    print("πŸ”¬ BYOL Classification Fine-Tuning")
    print("=" * 50)
    print(f"Device: {DEVICE}")
    print(f"Classes: {args.class_names}")
    print(f"Batch size: {config['batch_size']}")
    print(f"Epochs: {config['epochs']}")
    print(f"Output directory: {output_dir}")

    # Load model
    model = load_byol_model(args.byol_checkpoint, len(args.class_names), DEVICE)

    # Create datasets
    train_transform = create_classification_transforms(TILE_SIZE, is_training=True)
    val_transform = create_classification_transforms(TILE_SIZE, is_training=False)

    train_dataset = MammogramClassificationDataset(
        args.train_csv, args.tiles_dir, args.class_names, 
        train_transform, max_samples=args.max_samples
    )

    val_dataset = MammogramClassificationDataset(
        args.val_csv, args.tiles_dir, args.class_names,
        val_transform, max_samples=args.max_samples
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False, 
        num_workers=config['num_workers'],
        pin_memory=True
    )

    print(f"πŸ“Š Dataset sizes: Train={len(train_dataset)}, Val={len(val_dataset)}")

    # Setup loss for multi-label classification.
    # Bug fix: nn.BCEWithLogitsLoss has no `label_smoothing` keyword (that
    # argument exists only on CrossEntropyLoss), so the original constructor
    # call raised TypeError. Equivalent smoothing is applied to the targets
    # instead: y_smooth = y * (1 - eps) + eps / 2.
    pos_weight = None  # Could be calculated from class distribution if needed
    bce_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    smoothing = config['label_smoothing']

    def criterion(outputs, targets):
        """BCE-with-logits with label smoothing folded into the targets."""
        if smoothing > 0:
            targets = targets * (1.0 - smoothing) + 0.5 * smoothing
        return bce_loss(outputs, targets)

    # Different learning rates for backbone and classification head
    backbone_params = list(model.byol_model.backbone.parameters())
    head_params = list(model.classification_head.parameters())

    optimizer = optim.AdamW([
        {'params': backbone_params, 'lr': config['lr_backbone']},
        {'params': head_params, 'lr': config['lr_head']}
    ], weight_decay=config['weight_decay'])

    # Learning rate scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config['epochs'], eta_min=1e-6
    )

    # Mixed precision scaler
    scaler = GradScaler()

    def snapshot(epoch_num: int, metrics: Dict[str, float]) -> dict:
        """Assemble a full-resume checkpoint dict (model/optim/sched state)."""
        return {
            'epoch': epoch_num,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_metrics': metrics,
            'config': config,
            'class_names': args.class_names
        }

    # Training loop
    best_metric = 0.0
    val_metrics: Dict[str, float] = {}  # defined even if epochs == 0

    for epoch in range(1, config['epochs'] + 1):
        # Decide whether to freeze backbone for this epoch.
        freeze_backbone = epoch <= config['freeze_backbone_epochs']
        if freeze_backbone:
            print(f"🧊 Epoch {epoch}: Backbone frozen (training only classification head)")

        # Train
        train_metrics = train_epoch(
            model, train_loader, criterion, optimizer, scaler, 
            epoch, config, freeze_backbone
        )

        # Validate
        val_metrics = validate_epoch(model, val_loader, criterion, args.class_names)

        # Step scheduler
        scheduler.step()

        # Print metrics
        print(f"\nEpoch {epoch:3d}/{config['epochs']}:")
        print(f"  Train Loss: {train_metrics['train_loss']:.4f}")
        print(f"  Val Loss:   {val_metrics['val_loss']:.4f}")
        print(f"  Mean AUC:   {val_metrics['mean_auc']:.4f}")
        print(f"  Mean AP:    {val_metrics['mean_ap']:.4f}")
        print(f"  Exact Match: {val_metrics['exact_match_acc']:.4f}")

        # Log to wandb
        if wandb_enabled:
            log_dict = {**train_metrics, **val_metrics, 'epoch': epoch}
            wandb.log(log_dict)

        # Save best model (selection criterion: mean validation AUC).
        current_metric = val_metrics['mean_auc']
        if current_metric > best_metric:
            best_metric = current_metric
            torch.save(snapshot(epoch, val_metrics),
                       output_dir / 'best_classification_model.pth')
            print(f"  βœ… New best model saved (AUC: {best_metric:.4f})")

        # Save periodic checkpoints
        if epoch % 10 == 0:
            torch.save(snapshot(epoch, val_metrics),
                       output_dir / f'classification_epoch_{epoch}.pth')

    # Save final model
    torch.save(snapshot(config['epochs'], val_metrics),
               output_dir / 'final_classification_model.pth')

    print(f"\nπŸŽ‰ Classification training completed!")
    print(f"πŸ“Š Best validation AUC: {best_metric:.4f}")
    print(f"πŸ’Ύ Models saved to: {output_dir}")

    if wandb_enabled:
        wandb.finish()


# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()