# PneumoniaAPI/src/train.py
"""
Training pipeline for Pneumonia classification.
"""
import time
from typing import Dict, Tuple

import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from .config import (
STAGE1_EPOCHS, STAGE1_LR, STAGE2_EPOCHS, STAGE2_LR,
WEIGHT_DECAY, SCHEDULER_PATIENCE, SCHEDULER_FACTOR,
EARLY_STOP_PATIENCE, CHECKPOINT_PATH, MODEL_DIR
)
from .model import PneumoniaClassifier, get_device


class EarlyStopping:
"""Early stopping to prevent overfitting."""
def __init__(self, patience: int = 7, min_delta: float = 0.001):
self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.best_loss = float('inf')
self.should_stop = False

    def __call__(self, val_loss: float) -> bool:
if val_loss < self.best_loss - self.min_delta:
self.best_loss = val_loss
self.counter = 0
else:
self.counter += 1
if self.counter >= self.patience:
self.should_stop = True
return self.should_stop
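

# EarlyStopping usage sketch (illustrative only; the training loop below
# wires it up automatically):
#
#     stopper = EarlyStopping(patience=2)
#     for val_loss in (0.50, 0.49, 0.495, 0.492):
#         if stopper(val_loss):
#             break  # fires after two epochs without a >= 0.001 improvement
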
def train_epoch(
model: nn.Module,
loader: DataLoader,
criterion: nn.Module,
optimizer: torch.optim.Optimizer,
device: torch.device
) -> Tuple[float, float]:
"""Train for one epoch."""
model.train()
total_loss = 0
all_preds, all_labels = [], []
for images, labels in loader:
images = images.to(device)
labels = labels.float().unsqueeze(1).to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
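        # Weight each batch's loss by its size so avg_loss below is a true
        # per-sample mean even when the final batch is smaller.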
total_loss += loss.item() * images.size(0)
        # Hard predictions at the 0.5 probability threshold; flatten to 1-D
        # arrays so the sklearn metrics below see plain label vectors.
        preds = (torch.sigmoid(outputs) > 0.5).int()
        all_preds.extend(preds.cpu().numpy().ravel())
        all_labels.extend(labels.cpu().numpy().ravel())
avg_loss = total_loss / len(loader.dataset)
accuracy = accuracy_score(all_labels, all_preds)
return avg_loss, accuracy


def validate(
model: nn.Module,
loader: DataLoader,
criterion: nn.Module,
device: torch.device
) -> Dict[str, float]:
"""Validate the model."""
model.eval()
total_loss = 0
all_preds, all_labels = [], []
with torch.no_grad():
for images, labels in loader:
images = images.to(device)
labels = labels.float().unsqueeze(1).to(device)
outputs = model(images)
loss = criterion(outputs, labels)
total_loss += loss.item() * images.size(0)
            # Same thresholding and flattening as in train_epoch
            preds = (torch.sigmoid(outputs) > 0.5).int()
            all_preds.extend(preds.cpu().numpy().ravel())
            all_labels.extend(labels.cpu().numpy().ravel())
avg_loss = total_loss / len(loader.dataset)
return {
'loss': avg_loss,
'accuracy': accuracy_score(all_labels, all_preds),
'precision': precision_score(all_labels, all_preds, zero_division=0),
'recall': recall_score(all_labels, all_preds, zero_division=0),
'f1': f1_score(all_labels, all_preds, zero_division=0)
}


def train(
model: PneumoniaClassifier,
train_loader: DataLoader,
val_loader: DataLoader,
pos_weight: torch.Tensor,
epochs: int,
lr: float,
device: torch.device,
stage: str = "stage1",
use_wandb: bool = True,
    wandb_run=None
) -> Dict[str, list]:
"""Training loop with validation."""
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))
optimizer = AdamW(
filter(lambda p: p.requires_grad, model.parameters()),
lr=lr,
weight_decay=WEIGHT_DECAY
)
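    # Shrink the LR by SCHEDULER_FACTOR once validation loss has not improved
    # for SCHEDULER_PATIENCE consecutive epochs.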
scheduler = ReduceLROnPlateau(
optimizer, mode='min',
patience=SCHEDULER_PATIENCE,
factor=SCHEDULER_FACTOR
)
early_stopping = EarlyStopping(patience=EARLY_STOP_PATIENCE)
history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'val_f1': [], 'lr': []}
best_val_loss = float('inf')
MODEL_DIR.mkdir(parents=True, exist_ok=True)
for epoch in range(epochs):
start = time.time()
# Train
train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
# Validate
val_metrics = validate(model, val_loader, criterion, device)
# Get current LR
current_lr = optimizer.param_groups[0]['lr']
# Update scheduler
scheduler.step(val_metrics['loss'])
# Log
elapsed = time.time() - start
print(f"[{stage}] Epoch {epoch+1}/{epochs} ({elapsed:.1f}s) | "
f"Train Loss: {train_loss:.4f} | "
f"Val Loss: {val_metrics['loss']:.4f} | "
f"Val Acc: {val_metrics['accuracy']:.3f} | "
f"Val F1: {val_metrics['f1']:.3f} | "
f"LR: {current_lr:.2e}")
# W&B logging
if use_wandb and wandb_run:
wandb_run.log({
f"{stage}/train_loss": train_loss,
f"{stage}/train_acc": train_acc,
f"{stage}/val_loss": val_metrics['loss'],
f"{stage}/val_acc": val_metrics['accuracy'],
f"{stage}/val_precision": val_metrics['precision'],
f"{stage}/val_recall": val_metrics['recall'],
f"{stage}/val_f1": val_metrics['f1'],
f"{stage}/lr": current_lr,
"epoch": epoch + 1
})
# Save history
history['train_loss'].append(train_loss)
history['val_loss'].append(val_metrics['loss'])
history['val_acc'].append(val_metrics['accuracy'])
history['val_f1'].append(val_metrics['f1'])
history['lr'].append(current_lr)
# Save best model
if val_metrics['loss'] < best_val_loss:
best_val_loss = val_metrics['loss']
torch.save({
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'val_loss': best_val_loss,
'val_metrics': val_metrics
}, CHECKPOINT_PATH)
print(f" -> Saved best model (val_loss: {best_val_loss:.4f})")
# Early stopping
if early_stopping(val_metrics['loss']):
print(f"Early stopping triggered at epoch {epoch+1}")
break
return history
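

# A common recipe for the `pos_weight` argument is the negative:positive
# sample ratio. A minimal sketch, assuming the training dataset exposes an
# integer `labels` sequence (a hypothetical attribute, not defined here):
#
#     import numpy as np
#     y = np.asarray(train_dataset.labels)
#     pos_weight = torch.tensor([(y == 0).sum() / (y == 1).sum()],
#                               dtype=torch.float32)
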
def train_two_stage(
model: PneumoniaClassifier,
train_loader: DataLoader,
val_loader: DataLoader,
pos_weight: torch.Tensor,
device: torch.device,
use_wandb: bool = True,
    wandb_run=None
) -> Dict[str, list]:
"""Two-stage training: frozen backbone then fine-tuning."""
# Stage 1: Train classifier only
print("\n" + "=" * 60)
print("STAGE 1: Training classifier (backbone frozen)")
print("=" * 60)
model.freeze_backbone()
trainable, total = model.get_param_counts()
print(f"Trainable params: {trainable:,} / {total:,}")
history1 = train(
model, train_loader, val_loader, pos_weight,
epochs=STAGE1_EPOCHS, lr=STAGE1_LR, device=device,
stage="stage1", use_wandb=use_wandb, wandb_run=wandb_run
)
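    # Note: train() resets its best-loss tracker on every call and writes to
    # the same CHECKPOINT_PATH, so the first stage-2 epoch always overwrites
    # the stage-1 best checkpoint, even at a worse validation loss.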
# Stage 2: Fine-tune entire network
print("\n" + "=" * 60)
print("STAGE 2: Fine-tuning entire network")
print("=" * 60)
model.unfreeze_backbone()
trainable, total = model.get_param_counts()
print(f"Trainable params: {trainable:,} / {total:,}")
history2 = train(
model, train_loader, val_loader, pos_weight,
epochs=STAGE2_EPOCHS, lr=STAGE2_LR, device=device,
stage="stage2", use_wandb=use_wandb, wandb_run=wandb_run
)
# Combine histories
history = {k: history1[k] + history2[k] for k in history1}
return history
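

# End-to-end wiring sketch (illustrative; `build_dataloaders` is an assumed
# helper, not part of this module):
#
#     from .data import build_dataloaders  # hypothetical import
#     device = get_device()
#     model = PneumoniaClassifier().to(device)
#     train_loader, val_loader = build_dataloaders()
#     pos_weight = torch.tensor([2.0])  # e.g. twice as many negatives
#     history = train_two_stage(model, train_loader, val_loader,
#                               pos_weight, device, use_wandb=False)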