# CNN_Benchmark/train.py
import argparse
import time
from collections import defaultdict

import torch
import torch.nn as nn

from wandb_utils import WandbLogger

def train(model, train_loader, val_loader, optimizer, criterion, scheduler,
          num_epochs, device='cuda', use_wandb=True, model_name="CustomCNN"):
    """
    Production-grade training pipeline with optional W&B integration.

    Args:
        model: PyTorch model to train
        train_loader: Training data loader
        val_loader: Validation data loader
        optimizer: Optimizer (AdamW/SGD)
        criterion: Loss function (CrossEntropyLoss)
        scheduler: Learning rate scheduler
        num_epochs: Number of training epochs
        device: Device to train on
        use_wandb: Whether to use W&B logging
        model_name: Name of the model for logging

    Returns:
        dict: Training history with losses and accuracies
    """
    model.to(device)
    best_val_acc = 0.0
    history = defaultdict(list)

    # Initialize W&B if enabled
    logger = None
    if use_wandb:
        logger = WandbLogger()
        config = {
            'epochs': num_epochs,
            'device': str(device),
            'optimizer': optimizer.__class__.__name__,
            'scheduler': scheduler.__class__.__name__,
            'learning_rate': optimizer.param_groups[0]['lr']
        }
        logger.init_experiment(config, model, model_name)

    print(f"🚀 Training {model_name} on {device} for {num_epochs} epochs...")
    print("-" * 60)
    for epoch in range(num_epochs):
        epoch_start = time.time()

        # Training phase
        model.train()
        train_loss, train_acc = _train_epoch(model, train_loader, optimizer, criterion, device)

        # Validation phase
        model.eval()
        val_loss, val_acc, val_preds, val_targets = _validate_epoch(model, val_loader, criterion, device)

        # Update scheduler
        scheduler.step()

        # Log metrics
        metrics = {
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'lr': optimizer.param_groups[0]['lr']
        }
        for key, value in metrics.items():
            history[key].append(value)

        # W&B logging
        if logger:
            logger.log_metrics({
                'epoch': epoch,
                'train_loss': train_loss,
                'train_accuracy': train_acc,
                'val_loss': val_loss,
                'val_accuracy': val_acc,
                'learning_rate': optimizer.param_groups[0]['lr']
            }, step=epoch)

            # Log confusion matrix every 20 epochs
            if (epoch + 1) % 20 == 0:
                logger.log_confusion_matrix(val_targets, val_preds, epoch)

        # Save best model
        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss
            }
            torch.save(checkpoint, f'best_model_{model_name.lower()}.pth')

            # Only log the best model checkpoint to W&B
            if logger:
                logger.log_model_checkpoint(model, optimizer, epoch,
                                            {'val_accuracy': val_acc, 'val_loss': val_loss},
                                            is_best=True)

        # Print progress
        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch+1:3d}/{num_epochs} | "
              f"Train: {train_loss:.4f}/{train_acc:.2f}% | "
              f"Val: {val_loss:.4f}/{val_acc:.2f}% | "
              f"LR: {optimizer.param_groups[0]['lr']:.2e} | "
              f"Time: {epoch_time:.1f}s")

    if logger:
        logger.finish()

    print(f"\n🎯 Best validation accuracy: {best_val_acc:.2f}%")
    return dict(history)
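

# Minimal usage sketch (a hedged example, not part of this repo's API):
# `my_model`, `train_loader`, and `val_loader` are placeholder names for a
# PyTorch model and CIFAR-10-style data loaders; the factories are defined below.
#
#   optimizer = create_optimizer(my_model, 'adamw', lr=1e-3)
#   scheduler = create_scheduler(optimizer, 'cosine', num_epochs=50)
#   history = train(my_model, train_loader, val_loader, optimizer,
#                   nn.CrossEntropyLoss(), scheduler, num_epochs=50,
#                   use_wandb=False)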

def _train_epoch(model, train_loader, optimizer, criterion, device):
    """Single training epoch with metrics."""
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    return running_loss / len(train_loader), 100.0 * correct / total

def _validate_epoch(model, val_loader, criterion, device):
    """Single validation epoch with predictions."""
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

    return (running_loss / len(val_loader), 100.0 * correct / total,
            all_preds, all_targets)
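

# Note: both loops above assume `criterion` is nn.CrossEntropyLoss, which
# expects raw logits; accuracy comes from outputs.max(1) (argmax over the
# class dimension), so no softmax is applied anywhere.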

def create_optimizer(model, opt_type='adamw', lr=0.001, weight_decay=1e-4):
    """Create optimizer with best practices."""
    if opt_type.lower() == 'adamw':
        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    elif opt_type.lower() == 'sgd':
        return torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    else:
        raise ValueError(f"Unsupported optimizer: {opt_type}")

def create_scheduler(optimizer, scheduler_type='cosine', num_epochs=50):
    """Create learning rate scheduler."""
    if scheduler_type.lower() == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    elif scheduler_type.lower() == 'step':
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    else:
        raise ValueError(f"Unsupported scheduler: {scheduler_type}")


# CLI interface for hyperparameter sweeps
def main():
    parser = argparse.ArgumentParser(description='Train CIFAR-10 models with W&B')
    parser.add_argument('--model', choices=['custom', 'resnet18'], default='custom')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--no-wandb', action='store_true', help='Disable W&B logging')
    parser.add_argument('--sweep', action='store_true', help='Run hyperparameter sweep')
    args = parser.parse_args()

    # Import models
    if args.model == 'custom':
        from models.custom_cnn import create_custom_cnn
        model = create_custom_cnn()
        model_name = "CustomCNN"
    else:
        from models.resnet18 import load_resnet18
        model = load_resnet18()
        model_name = "ResNet18"

    # Load data
    from utils.data_loader import get_cifar10_loaders
    train_loader, val_loader, test_loader = get_cifar10_loaders(batch_size=args.batch_size)

    # Create optimizer and scheduler
    optimizer = create_optimizer(model, lr=args.lr)
    scheduler = create_scheduler(optimizer, num_epochs=args.epochs)
    criterion = nn.CrossEntropyLoss()

    # Fall back to CPU when CUDA is unavailable
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Train model
    if args.sweep:
        from wandb_utils import create_hyperparameter_sweep, run_hyperparameter_sweep
        sweep_config = create_hyperparameter_sweep()

        def train_fn():
            # W&B will set hyperparameters
            train(model, train_loader, val_loader, optimizer, criterion,
                  scheduler, args.epochs, device=device, model_name=model_name)

        run_hyperparameter_sweep(train_fn, sweep_config)
    else:
        train(model, train_loader, val_loader, optimizer, criterion, scheduler,
              args.epochs, device=device, use_wandb=not args.no_wandb,
              model_name=model_name)


if __name__ == "__main__":
    main()
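
# Example invocations (a sketch; assumes the models/ and utils/ packages
# imported above exist alongside this script):
#
#   python train.py --model custom --epochs 50
#   python train.py --model resnet18 --lr 0.01 --no-wandb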