import argparse
import time
from collections import defaultdict

import torch
import torch.nn as nn

from wandb_utils import WandbLogger


def train(model, train_loader, val_loader, optimizer, criterion, scheduler,
          num_epochs, device='cuda', use_wandb=True, model_name="CustomCNN"):
    """
    Production-grade training pipeline with optional W&B integration.

    Args:
        model: PyTorch model to train
        train_loader: Training data loader
        val_loader: Validation data loader
        optimizer: Optimizer (AdamW/SGD)
        criterion: Loss function (CrossEntropyLoss)
        scheduler: Learning rate scheduler
        num_epochs: Number of training epochs
        device: Device to train on
        use_wandb: Whether to use W&B logging
        model_name: Name of the model for logging

    Returns:
        dict: Training history with losses and accuracies
    """
    model.to(device)
    best_val_acc = 0.0
    history = defaultdict(list)

    # Initialize W&B if enabled
    logger = None
    if use_wandb:
        logger = WandbLogger()
        config = {
            'epochs': num_epochs,
            'device': str(device),
            'optimizer': optimizer.__class__.__name__,
            'scheduler': scheduler.__class__.__name__,
            'learning_rate': optimizer.param_groups[0]['lr'],
        }
        logger.init_experiment(config, model, model_name)

    print(f"šŸš€ Training {model_name} on {device} for {num_epochs} epochs...")
    print("-" * 60)

    for epoch in range(num_epochs):
        epoch_start = time.time()

        # Training phase
        model.train()
        train_loss, train_acc = _train_epoch(model, train_loader, optimizer,
                                             criterion, device)

        # Validation phase
        model.eval()
        val_loss, val_acc, val_preds, val_targets = _validate_epoch(
            model, val_loader, criterion, device)

        # Update scheduler
        scheduler.step()

        # Record metrics in the returned history
        metrics = {
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'lr': optimizer.param_groups[0]['lr'],
        }
        for key, value in metrics.items():
            history[key].append(value)

        # W&B logging
        if logger:
            logger.log_metrics({
                'epoch': epoch,
                'train_loss': train_loss,
                'train_accuracy': train_acc,
                'val_loss': val_loss,
                'val_accuracy': val_acc,
                'learning_rate': optimizer.param_groups[0]['lr'],
            }, step=epoch)

            # Log confusion matrix every 20 epochs
            if (epoch + 1) % 20 == 0:
                logger.log_confusion_matrix(val_targets, val_preds, epoch)

        # Save best model
        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss,
            }
            torch.save(checkpoint, f'best_model_{model_name.lower()}.pth')

            # Only log best model checkpoint to W&B
            if logger:
                logger.log_model_checkpoint(
                    model, optimizer, epoch,
                    {'val_accuracy': val_acc, 'val_loss': val_loss},
                    is_best=True)

        # Print progress
        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch+1:3d}/{num_epochs} | "
              f"Train: {train_loss:.4f}/{train_acc:.2f}% | "
              f"Val: {val_loss:.4f}/{val_acc:.2f}% | "
              f"LR: {optimizer.param_groups[0]['lr']:.2e} | "
              f"Time: {epoch_time:.1f}s")

    if logger:
        logger.finish()

    print(f"\nšŸŽÆ Best validation accuracy: {best_val_acc:.2f}%")
    return dict(history)
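
# Minimal programmatic usage sketch for train() (illustrative only; assumes
# CIFAR-10-style train/val loaders are already built -- see main() below for
# the full CLI wiring using the same factories):
#
#   model = create_custom_cnn()
#   optimizer = create_optimizer(model, 'adamw', lr=1e-3)
#   scheduler = create_scheduler(optimizer, 'cosine', num_epochs=50)
#   history = train(model, train_loader, val_loader, optimizer,
#                   nn.CrossEntropyLoss(), scheduler, num_epochs=50,
#                   use_wandb=False)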

def _train_epoch(model, train_loader, optimizer, criterion, device):
    """Single training epoch with metrics."""
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # Standard forward/backward/update step
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Accumulate running loss and top-1 accuracy
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    return running_loss / len(train_loader), 100.0 * correct / total


def _validate_epoch(model, val_loader, criterion, device):
    """Single validation epoch with predictions."""
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # Collect predictions for the confusion matrix
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

    return (running_loss / len(val_loader), 100.0 * correct / total,
            all_preds, all_targets)


def create_optimizer(model, opt_type='adamw', lr=0.001, weight_decay=1e-4):
    """Create optimizer with best practices."""
    if opt_type.lower() == 'adamw':
        return torch.optim.AdamW(model.parameters(), lr=lr,
                                 weight_decay=weight_decay)
    elif opt_type.lower() == 'sgd':
        return torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9,
                               weight_decay=weight_decay)
    else:
        raise ValueError(f"Unsupported optimizer: {opt_type}")


def create_scheduler(optimizer, scheduler_type='cosine', num_epochs=50):
    """Create learning rate scheduler."""
    if scheduler_type.lower() == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                          T_max=num_epochs)
    elif scheduler_type.lower() == 'step':
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size=30,
                                               gamma=0.1)
    else:
        raise ValueError(f"Unsupported scheduler: {scheduler_type}")


# CLI interface for hyperparameter sweeps
def main():
    parser = argparse.ArgumentParser(description='Train CIFAR-10 models with W&B')
    parser.add_argument('--model', choices=['custom', 'resnet18'], default='custom')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--no-wandb', action='store_true', help='Disable W&B logging')
    parser.add_argument('--sweep', action='store_true', help='Run hyperparameter sweep')
    args = parser.parse_args()

    # Fall back to CPU when no GPU is available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Import models
    if args.model == 'custom':
        from models.custom_cnn import create_custom_cnn
        model = create_custom_cnn()
        model_name = "CustomCNN"
    else:
        from models.resnet18 import load_resnet18
        model = load_resnet18()
        model_name = "ResNet18"

    # Load data
    from utils.data_loader import get_cifar10_loaders
    train_loader, val_loader, test_loader = get_cifar10_loaders(
        batch_size=args.batch_size)

    # Create optimizer and scheduler
    optimizer = create_optimizer(model, lr=args.lr)
    scheduler = create_scheduler(optimizer, num_epochs=args.epochs)
    criterion = nn.CrossEntropyLoss()

    # Train model
    if args.sweep:
        from wandb_utils import create_hyperparameter_sweep, run_hyperparameter_sweep
        sweep_config = create_hyperparameter_sweep()

        def train_fn():
            # NOTE: a complete sweep setup would rebuild the model/optimizer
            # from the sweep-provided hyperparameters (e.g. wandb.config) on
            # each run; as written, the CLI arguments are reused unchanged.
            train(model, train_loader, val_loader, optimizer, criterion,
                  scheduler, args.epochs, device=device, model_name=model_name)

        run_hyperparameter_sweep(train_fn, sweep_config)
    else:
        train(model, train_loader, val_loader, optimizer, criterion, scheduler,
              args.epochs, device=device, use_wandb=not args.no_wandb,
              model_name=model_name)


if __name__ == "__main__":
    main()
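
# Example invocations (illustrative; assumes this module is saved as train.py
# and that the models/ and utils/ packages imported in main() are on the path):
#
#   python train.py --model custom --epochs 50 --lr 0.001
#   python train.py --model resnet18 --batch_size 256 --no-wandb
#   python train.py --sweep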