from time import time
from typing import Any, Callable, Dict, Optional

import torch
import torch.nn as nn
import torchmetrics
from torch.utils.data import DataLoader


def evaluate(model, metric, loss_fn, data_loader, device):
    """Evaluate model on a validation/test set."""
    metric.reset()
    model.eval()
    with torch.no_grad():
        total_loss = 0.0
        for X_batch, y_batch in data_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            yhat = model(X_batch)
            total_loss += loss_fn(yhat, y_batch).item()
            metric.update(yhat, y_batch)
    return total_loss / len(data_loader), metric.compute().item()


def train(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Callable,
    metric: torchmetrics.Metric,
    n_epochs: int,
    device: torch.device,
    train_loader: DataLoader,
    val_loader: DataLoader,
    scheduler: Optional[Any] = None,
    scheduler_monitor: str = "val_loss",
    epoch_callback: Optional[Callable] = None,
) -> Dict[str, list]:
    """Train a model with logging, validation, and optional scheduler."""
    train_logs = {
        "train_loss": [],
        "train_metric": [],
        "val_loss": [],
        "val_metric": [],
        "lr": [],
    }
    for epoch in range(n_epochs):
        # ON EPOCH START
        start_time = time()
        model.train()
        total_loss = 0.0
        metric.reset()
        if epoch_callback is not None:
            epoch_callback(model, epoch)

        # INNER LOOP
        for idx, (X_batch, y_batch) in enumerate(train_loader):
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            yhat = model(X_batch)
            loss = loss_fn(yhat, y_batch)
            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            metric.update(yhat, y_batch)
            # Running metric over all batches seen so far this epoch.
            train_metric = metric.compute().item()
            print(f"\r Epoch {epoch + 1}/{n_epochs}", end="")
            print(f", Step {idx + 1}/{len(train_loader)}", end="")
            print(f", train_loss: {total_loss / (idx + 1):.4f}", end="")
            print(f", train_metric: {train_metric:.4f}", end="")

        # LOGGING
        train_logs["train_loss"].append(total_loss / len(train_loader))
        train_logs["train_metric"].append(train_metric)
        eval_loss, eval_metric = evaluate(model, metric, loss_fn, val_loader, device)
        train_logs["val_loss"].append(eval_loss)
        train_logs["val_metric"].append(eval_metric)
        train_logs["lr"].append(optimizer.param_groups[0]["lr"])

        # LR SCHEDULE
        # Assumes a ReduceLROnPlateau-style scheduler whose step() takes the
        # monitored value; plain schedulers (StepLR, CosineAnnealingLR, ...)
        # would need to be stepped without an argument.
        if scheduler is not None:
            scheduler.step(eval_loss if scheduler_monitor == "val_loss" else eval_metric)

        print(f"\r Epoch {epoch + 1}/{n_epochs}", end="")
        print(f', train_loss: {train_logs["train_loss"][-1]:.4f}', end="")
        print(f', train_metric: {train_logs["train_metric"][-1]:.4f}', end="")
        print(f', val_loss: {train_logs["val_loss"][-1]:.4f}', end="")
        print(f', val_metric: {train_logs["val_metric"][-1]:.4f}', end="")
        if scheduler is not None:
            print(f', lr: {train_logs["lr"][-1]}', end="")
        print(f", epoch_time: {time() - start_time:.2f}s")
    return train_logs
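

# ---------------------------------------------------------------------------
# Usage sketch: a minimal, self-contained example of driving train() above.
# Everything here is illustrative -- the synthetic tensors, layer sizes, and
# hyperparameters are arbitrary placeholders, not recommended settings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Synthetic binary-classification data: 20 features, float targets in {0, 1}.
    X_train, y_train = torch.randn(512, 20), torch.randint(0, 2, (512, 1)).float()
    X_val, y_val = torch.randn(128, 20), torch.randint(0, 2, (128, 1)).float()
    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=32, shuffle=True)
    val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=32)

    # Tiny MLP emitting one logit per sample, matching BCEWithLogitsLoss below.
    model = nn.Sequential(nn.Linear(20, 16), nn.ReLU(), nn.Linear(16, 1)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # ReduceLROnPlateau matches the scheduler.step(monitored_value) call in train().
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)
    metric = torchmetrics.classification.BinaryAccuracy().to(device)

    logs = train(
        model,
        optimizer,
        nn.BCEWithLogitsLoss(),
        metric,
        n_epochs=5,
        device=device,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
    )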