# CharRNN/src/trainer.py
from time import time
from typing import Optional, Callable, Dict, Any

import torch
import torch.nn as nn
import torchmetrics
from torch.utils.data import DataLoader


def evaluate(model, metric, loss_fn, data_loader, device):
"""Evaluate model on a validation/test set."""
metric.reset()
model.eval()
with torch.no_grad():
total_loss = 0.0
for X_batch, y_batch in data_loader:
X_batch, y_batch = X_batch.to(device), y_batch.to(device)
yhat = model(X_batch)
total_loss += loss_fn(yhat, y_batch).item()
metric.update(yhat, y_batch)
    # Average the per-batch losses and read out the aggregated metric.
    return total_loss / len(data_loader), metric.compute().item()
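
# Illustrative standalone call (a sketch, not part of this module: `vocab_size`
# and the surrounding objects are hypothetical stand-ins for a real setup):
#     acc = torchmetrics.Accuracy(task="multiclass", num_classes=vocab_size).to(device)
#     val_loss, val_acc = evaluate(model, acc, nn.CrossEntropyLoss(), val_loader, device)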


def train(
model: nn.Module,
optimizer: torch.optim.Optimizer,
loss_fn: Callable,
metric: torchmetrics.Metric,
n_epochs: int,
device: torch.device,
train_loader: DataLoader,
val_loader: DataLoader,
scheduler: Optional[Any] = None,
scheduler_monitor: str = "val_loss",
epoch_callback: Optional[Callable] = None
) -> Dict[str, list]:
"""Train a model with logging, validation, and optional scheduler."""
    train_logs = {"train_loss": [], "train_metric": [], "val_loss": [], "val_metric": [], "lr": []}
for epoch in range(n_epochs):
# ON EPOCH START
start_time = time()
model.train()
total_loss = 0.0
metric.reset()
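        # Optional per-epoch hook, e.g. to sample generated text or save a checkpoint.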
if epoch_callback is not None:
epoch_callback(model, epoch)
# INNER LOOP
for idx, (X_batch, y_batch) in enumerate(train_loader):
X_batch, y_batch = X_batch.to(device), y_batch.to(device)
yhat = model(X_batch)
loss = loss_fn(yhat, y_batch)
total_loss += loss.item()
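            # Standard update: clear stale gradients, backprop, apply the step.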
optimizer.zero_grad()
loss.backward()
optimizer.step()
metric.update(yhat, y_batch)
            train_metric = metric.compute().item()
            # Live progress line, overwritten in place each step.
            print(
                f"\r Epoch {epoch + 1}/{n_epochs}"
                f", Step {idx + 1}/{len(train_loader)}"
                f", train_loss: {total_loss / (idx + 1):.4f}"
                f", train_metric: {train_metric:.4f}",
                end="",
            )
# LOGGING
train_logs["train_loss"].append(total_loss / len(train_loader))
train_logs["train_metric"].append(train_metric)
eval_loss, eval_metric = evaluate(model, metric, loss_fn, val_loader, device)
train_logs["val_loss"].append(eval_loss)
train_logs["val_metric"].append(eval_metric)
train_logs["lr"].append(optimizer.param_groups[0]['lr'])
# LR SCHEDULE
        if scheduler is not None:
            # ReduceLROnPlateau expects the monitored value; other schedulers
            # step once per epoch with no argument.
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(eval_loss if scheduler_monitor == "val_loss" else eval_metric)
            else:
                scheduler.step()
print(f"\r Epoch {epoch + 1}/{n_epochs}", end="")
print(f", train_loss: {train_logs["train_loss"][-1]:.4f}", end="")
print(f", train_metric : {train_logs["train_metric"][-1]:.4f}", end="")
print(f', val_loss: {train_logs["val_loss"][-1]:.4f}', end="")
print(f', val_metric: {train_logs["val_metric"][-1]:.4f}', end="")
if scheduler is not None:
print(f", lr: {train_logs["lr"][-1]}", end="")
print(f', epoch_time: {time() - start_time:.2f}s')
return train_logs
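

# --- Minimal usage sketch (not part of the original module) ---
# A smoke test on synthetic data: a linear classifier, CrossEntropyLoss, and
# torchmetrics.Accuracy. All values below (dataset sizes, lr, patience) are
# illustrative assumptions, not settings used by the CharRNN project.
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 256 random 16-d feature vectors over 4 classes; 200 train / 56 val.
    X, y = torch.randn(256, 16), torch.randint(0, 4, (256,))
    train_loader = DataLoader(TensorDataset(X[:200], y[:200]), batch_size=32, shuffle=True)
    val_loader = DataLoader(TensorDataset(X[200:], y[200:]), batch_size=32)

    model = nn.Linear(16, 4).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1)
    metric = torchmetrics.Accuracy(task="multiclass", num_classes=4).to(device)

    logs = train(
        model, optimizer, nn.CrossEntropyLoss(), metric,
        n_epochs=2, device=device,
        train_loader=train_loader, val_loader=val_loader,
        scheduler=scheduler,
    )
    print(f"final val_loss: {logs['val_loss'][-1]:.4f}")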