from time import time
from typing import Any, Callable, Dict, Optional

import torch
import torch.nn as nn
import torchmetrics
from torch.utils.data import DataLoader


def evaluate(model, metric, loss_fn, data_loader, device):
    """Evaluate the model on a validation/test set.

    Returns the average per-batch loss and the computed metric value.
    """
    metric.reset()
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for X_batch, y_batch in data_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            yhat = model(X_batch)
            total_loss += loss_fn(yhat, y_batch).item()
            metric.update(yhat, y_batch)
    return total_loss / len(data_loader), metric.compute().item()
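
# A hedged usage sketch for `evaluate` (the names `model`, `test_loader`, and
# the metric choice below are illustrative assumptions, not part of this module):
#
#     acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(device)
#     test_loss, test_acc = evaluate(model, acc, nn.CrossEntropyLoss(),
#                                    test_loader, device)
#
# `train` below calls `evaluate` once per epoch on the validation loader.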


def train(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Callable,
    metric: torchmetrics.Metric,
    n_epochs: int,
    device: torch.device,
    train_loader: DataLoader,
    val_loader: DataLoader,
    scheduler: Optional[Any] = None,
    scheduler_monitor: str = "val_loss",
    epoch_callback: Optional[Callable] = None,
) -> Dict[str, list]:
    """Train a model with logging, per-epoch validation, and an optional LR scheduler."""
    train_logs = {"train_loss": [], "train_metric": [], "val_loss": [], "val_metric": [], "lr": []}

    for epoch in range(n_epochs):
        start_time = time()
        model.train()
        total_loss = 0.0
        metric.reset()
        # Optional per-epoch hook, e.g. for unfreezing layers or custom logging.
        if epoch_callback is not None:
            epoch_callback(model, epoch)

        for idx, (X_batch, y_batch) in enumerate(train_loader):
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            yhat = model(X_batch)
            loss = loss_fn(yhat, y_batch)
            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            metric.update(yhat, y_batch)
            # Running metric over all batches seen so far this epoch.
            train_metric = metric.compute().item()

            # In-place progress line, overwritten on every step via "\r".
            print(f"\rEpoch {epoch + 1}/{n_epochs}", end="")
            print(f", Step {idx + 1}/{len(train_loader)}", end="")
            print(f", train_loss: {total_loss / (idx + 1):.4f}", end="")
            print(f", train_metric: {train_metric:.4f}", end="")

        # Epoch-level logging: average train loss, final running metric, and
        # validation loss/metric from a full pass over val_loader.
        train_logs["train_loss"].append(total_loss / len(train_loader))
        train_logs["train_metric"].append(train_metric)
        eval_loss, eval_metric = evaluate(model, metric, loss_fn, val_loader, device)
        train_logs["val_loss"].append(eval_loss)
        train_logs["val_metric"].append(eval_metric)
        train_logs["lr"].append(optimizer.param_groups[0]["lr"])

        if scheduler is not None:
            # Assumes a metric-driven scheduler such as ReduceLROnPlateau;
            # schedulers stepped per epoch without a monitored value would
            # need scheduler.step() with no argument instead.
            scheduler.step(eval_loss if scheduler_monitor == "val_loss" else eval_metric)

        # Overwrite the progress line with the epoch summary. Note: indexing
        # train_logs with single quotes inside double-quoted f-strings avoids
        # a SyntaxError on Python versions before 3.12.
        print(f"\rEpoch {epoch + 1}/{n_epochs}", end="")
        print(f", train_loss: {train_logs['train_loss'][-1]:.4f}", end="")
        print(f", train_metric: {train_logs['train_metric'][-1]:.4f}", end="")
        print(f", val_loss: {train_logs['val_loss'][-1]:.4f}", end="")
        print(f", val_metric: {train_logs['val_metric'][-1]:.4f}", end="")
        if scheduler is not None:
            print(f", lr: {train_logs['lr'][-1]}", end="")
        print(f", epoch_time: {time() - start_time:.2f}s")

    return train_logs
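

# Minimal smoke test: a hedged sketch assuming a toy multiclass classification
# setup on synthetic data. The model, sizes, and hyperparameters below are
# illustrative assumptions, not part of the training utilities above.
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Synthetic 20-feature, 3-class dataset split into train/val.
    X = torch.randn(512, 20)
    y = torch.randint(0, 3, (512,))
    train_ds = TensorDataset(X[:400], y[:400])
    val_ds = TensorDataset(X[400:], y[400:])

    model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 3)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # ReduceLROnPlateau matches the metric-driven scheduler.step(value) call
    # in train(); mode="min" pairs with the default scheduler_monitor="val_loss".
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)

    logs = train(
        model=model,
        optimizer=optimizer,
        loss_fn=nn.CrossEntropyLoss(),
        metric=torchmetrics.Accuracy(task="multiclass", num_classes=3).to(device),
        n_epochs=3,
        device=device,
        train_loader=DataLoader(train_ds, batch_size=32, shuffle=True),
        val_loader=DataLoader(val_ds, batch_size=32),
        scheduler=scheduler,
    )
    print(f"Final val_metric: {logs['val_metric'][-1]:.4f}")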