Spaces:
Sleeping
Sleeping
| """ | |
| Training pipeline for Pneumonia classification. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from torch.optim import AdamW | |
| from torch.optim.lr_scheduler import ReduceLROnPlateau | |
| from torch.utils.data import DataLoader | |
| from pathlib import Path | |
| from typing import Dict, Optional, Tuple | |
| import time | |
| from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score | |
| from .config import ( | |
| STAGE1_EPOCHS, STAGE1_LR, STAGE2_EPOCHS, STAGE2_LR, | |
| WEIGHT_DECAY, SCHEDULER_PATIENCE, SCHEDULER_FACTOR, | |
| EARLY_STOP_PATIENCE, CHECKPOINT_PATH, MODEL_DIR | |
| ) | |
| from .model import PneumoniaClassifier, get_device | |
class EarlyStopping:
    """Track validation loss and signal when training should halt.

    An epoch counts as an improvement only when its loss is lower than the
    best loss seen so far by more than ``min_delta``. Once ``patience``
    consecutive calls pass without an improvement, ``should_stop`` is set to
    ``True`` and is never reset by this class.
    """

    def __init__(self, patience: int = 7, min_delta: float = 0.001):
        # Number of consecutive non-improving calls tolerated before stopping.
        self.patience = patience
        # Minimum decrease in loss required to count as an improvement.
        self.min_delta = min_delta
        # Consecutive non-improving calls since the last improvement.
        self.counter = 0
        # Lowest loss accepted as an improvement so far.
        self.best_loss = float('inf')
        # Latched stop flag returned by every call.
        self.should_stop = False

    def __call__(self, val_loss: float) -> bool:
        """Record ``val_loss`` and return whether training should stop."""
        improved = val_loss < self.best_loss - self.min_delta
        if improved:
            # New best: remember it and restart the patience window.
            self.best_loss = val_loss
            self.counter = 0
            return self.should_stop

        self.counter += 1
        if self.counter >= self.patience:
            self.should_stop = True
        return self.should_stop
def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device
) -> Tuple[float, float]:
    """Train ``model`` for one pass over ``loader``.

    Each batch is moved to ``device``, the labels are converted to a float
    column (``labels.float().unsqueeze(1)``) and one optimizer step is taken.
    Predictions are obtained by thresholding ``sigmoid(outputs)`` at 0.5.

    Args:
        model: Network to optimize; put into training mode by this function.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(average_loss, accuracy)``, both averaged over the samples that
        were actually processed in this epoch.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    num_correct = 0

    for images, labels in loader:
        images = images.to(device)
        labels = labels.float().unsqueeze(1).to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so that a smaller final batch
        # does not skew the epoch average.
        batch_size = images.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size

        preds = (torch.sigmoid(outputs) > 0.5).to(labels.dtype)
        num_correct += (preds == labels).sum().item()

    # Divide by the number of samples seen rather than len(loader.dataset):
    # the two differ when the loader drops the last batch or uses a sampler
    # whose length is not the dataset size.
    if num_samples == 0:
        raise ValueError("train_epoch: loader yielded no samples")

    return total_loss / num_samples, num_correct / num_samples
def validate(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: torch.device
) -> Dict[str, float]:
    """Evaluate ``model`` on ``loader`` without updating its weights.

    The model is switched to eval mode and gradients are disabled. Labels are
    converted to a float column (``labels.float().unsqueeze(1)``) before the
    loss is computed, and predictions are obtained by thresholding
    ``sigmoid(outputs)`` at 0.5.

    Args:
        model: Network to evaluate.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        Dict with keys ``'loss'``, ``'accuracy'``, ``'precision'``,
        ``'recall'`` and ``'f1'``. The loss is averaged over the samples
        actually processed; precision, recall and F1 are computed with
        ``zero_division=0``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.float().unsqueeze(1).to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Weight the batch loss by its size so that a smaller final
            # batch does not skew the average.
            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size

            preds = (torch.sigmoid(outputs) > 0.5).int()
            # Flatten to 1-D so the metric functions receive plain label
            # sequences instead of lists of one-element arrays.
            all_preds.extend(preds.reshape(-1).cpu().tolist())
            all_labels.extend(labels.reshape(-1).cpu().tolist())

    # Divide by the number of samples seen rather than len(loader.dataset):
    # the two differ when the loader drops the last batch or uses a sampler
    # whose length is not the dataset size.
    if num_samples == 0:
        raise ValueError("validate: loader yielded no samples")

    return {
        'loss': total_loss / num_samples,
        'accuracy': accuracy_score(all_labels, all_preds),
        'precision': precision_score(all_labels, all_preds, zero_division=0),
        'recall': recall_score(all_labels, all_preds, zero_division=0),
        'f1': f1_score(all_labels, all_preds, zero_division=0)
    }
def train(
    model: PneumoniaClassifier,
    train_loader: DataLoader,
    val_loader: DataLoader,
    pos_weight: torch.Tensor,
    epochs: int,
    lr: float,
    device: torch.device,
    stage: str = "stage1",
    use_wandb: bool = True,
    wandb_run = None
) -> Dict[str, list]:
    """Run the train/validate loop for one training stage.

    Sets up a ``BCEWithLogitsLoss`` weighted by ``pos_weight``, an ``AdamW``
    optimizer over the parameters that currently require gradients, a
    ``ReduceLROnPlateau`` scheduler driven by the validation loss, and an
    ``EarlyStopping`` tracker. After every epoch the metrics are printed,
    optionally sent to ``wandb_run``, and appended to the history. Whenever
    the validation loss is lower than the best seen during this call, a
    checkpoint is written to ``CHECKPOINT_PATH``.

    Args:
        model: Model to train.
        train_loader: Batches used for optimization.
        val_loader: Batches used for validation.
        pos_weight: Positive-class weight passed to the loss.
        epochs: Maximum number of epochs to run.
        lr: Initial learning rate for the optimizer.
        device: Device used for the loss weight and the batches.
        stage: Label used as a prefix in console output and logged keys.
        use_wandb: Logging happens only if this is true and ``wandb_run``
            is truthy.
        wandb_run: Object whose ``log(dict)`` method receives the metrics.

    Returns:
        Dict of per-epoch lists keyed by ``'train_loss'``, ``'val_loss'``,
        ``'val_acc'``, ``'val_f1'`` and ``'lr'``.
    """
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))

    # Only parameters that are not frozen are handed to the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable_params, lr=lr, weight_decay=WEIGHT_DECAY)
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',
        patience=SCHEDULER_PATIENCE,
        factor=SCHEDULER_FACTOR
    )
    stopper = EarlyStopping(patience=EARLY_STOP_PATIENCE)

    history = {
        key: [] for key in ('train_loss', 'val_loss', 'val_acc', 'val_f1', 'lr')
    }
    best_val_loss = float('inf')
    MODEL_DIR.mkdir(parents=True, exist_ok=True)

    for epoch_num in range(1, epochs + 1):
        started_at = time.time()

        train_loss, train_acc = train_epoch(
            model, train_loader, loss_fn, optimizer, device
        )
        val_metrics = validate(model, val_loader, loss_fn, device)
        val_loss = val_metrics['loss']

        # Read the learning rate before stepping the scheduler so the value
        # reported is the one this epoch was trained with.
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)

        elapsed = time.time() - started_at
        print(f"[{stage}] Epoch {epoch_num}/{epochs} ({elapsed:.1f}s) | "
              f"Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | "
              f"Val Acc: {val_metrics['accuracy']:.3f} | "
              f"Val F1: {val_metrics['f1']:.3f} | "
              f"LR: {current_lr:.2e}")

        if use_wandb and wandb_run:
            wandb_payload = {
                f"{stage}/train_loss": train_loss,
                f"{stage}/train_acc": train_acc,
            }
            # (logged suffix, key in val_metrics)
            logged_val_metrics = (
                ("loss", "loss"),
                ("acc", "accuracy"),
                ("precision", "precision"),
                ("recall", "recall"),
                ("f1", "f1"),
            )
            for suffix, metric_key in logged_val_metrics:
                wandb_payload[f"{stage}/val_{suffix}"] = val_metrics[metric_key]
            wandb_payload[f"{stage}/lr"] = current_lr
            wandb_payload["epoch"] = epoch_num
            wandb_run.log(wandb_payload)

        epoch_record = (
            ('train_loss', train_loss),
            ('val_loss', val_loss),
            ('val_acc', val_metrics['accuracy']),
            ('val_f1', val_metrics['f1']),
            ('lr', current_lr),
        )
        for key, value in epoch_record:
            history[key].append(value)

        # Keep a checkpoint of the lowest validation loss seen in this call.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint = {
                'epoch': epoch_num,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': best_val_loss,
                'val_metrics': val_metrics
            }
            torch.save(checkpoint, CHECKPOINT_PATH)
            print(f" -> Saved best model (val_loss: {best_val_loss:.4f})")

        if stopper(val_loss):
            print(f"Early stopping triggered at epoch {epoch_num}")
            break

    return history
def train_two_stage(
    model: PneumoniaClassifier,
    train_loader: DataLoader,
    val_loader: DataLoader,
    pos_weight: torch.Tensor,
    device: torch.device,
    use_wandb: bool = True,
    wandb_run = None
) -> Dict[str, list]:
    """Train in two consecutive stages and return the joined history.

    Stage 1 calls ``model.freeze_backbone()`` and trains for
    ``STAGE1_EPOCHS`` at ``STAGE1_LR``; stage 2 calls
    ``model.unfreeze_backbone()`` and trains for ``STAGE2_EPOCHS`` at
    ``STAGE2_LR``. Before each stage a banner and the trainable/total
    parameter counts reported by ``model.get_param_counts()`` are printed.

    Args:
        model: Model exposing ``freeze_backbone``, ``unfreeze_backbone`` and
            ``get_param_counts``.
        train_loader: Batches used for optimization in both stages.
        val_loader: Batches used for validation in both stages.
        pos_weight: Positive-class weight forwarded to ``train``.
        device: Device forwarded to ``train``.
        use_wandb: Forwarded to ``train``.
        wandb_run: Forwarded to ``train``.

    Returns:
        Dict with the same keys as the stage-1 history, where each list is
        the stage-1 values followed by the stage-2 values.
    """
    divider = "=" * 60

    def announce(banner: str) -> None:
        # Print the stage title framed by two divider lines.
        print("\n" + divider)
        print(banner)
        print(divider)

    def run_stage(stage_name: str, num_epochs: int, learning_rate: float):
        # Report how many parameters are trainable, then run the stage.
        trainable, total = model.get_param_counts()
        print(f"Trainable params: {trainable:,} / {total:,}")
        return train(
            model, train_loader, val_loader, pos_weight,
            epochs=num_epochs, lr=learning_rate, device=device,
            stage=stage_name, use_wandb=use_wandb, wandb_run=wandb_run
        )

    announce("STAGE 1: Training classifier (backbone frozen)")
    model.freeze_backbone()
    frozen_history = run_stage("stage1", STAGE1_EPOCHS, STAGE1_LR)

    announce("STAGE 2: Fine-tuning entire network")
    model.unfreeze_backbone()
    finetune_history = run_stage("stage2", STAGE2_EPOCHS, STAGE2_LR)

    # Concatenate per-metric lists, stage 1 first.
    return {
        key: values + finetune_history[key]
        for key, values in frozen_history.items()
    }