from src.configs.model_config import ModelConfig
import torch


def train(dataloader, model, loss_fn, optimizer, *, config=None):
    """Run one training pass over ``dataloader`` and return the last batch loss.

    For each ``(X, y)`` batch: move it to ``config.device``, compute
    ``loss_fn(model(X), y)``, backpropagate, step the optimizer and zero the
    gradients. Every ``config.log_interval`` batches, print the current loss
    and how many samples have been processed.

    Args:
        dataloader: Iterable of ``(X, y)`` batches that also exposes a
            ``.dataset`` attribute supporting ``len()`` (used for progress).
        model: Model to train; ``model.train()`` is called before the loop.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a single-element
            tensor (``.backward()`` and ``.item()`` are called on it).
        optimizer: Optimizer whose ``step()``/``zero_grad()`` update ``model``.
        config: Optional object with ``device`` and ``log_interval``
            attributes. Defaults to ``ModelConfig().get_config()``.

    Returns:
        float: Loss of the final batch as a Python float. The type is the same
        regardless of whether that batch was logged.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    if config is None:
        config = ModelConfig().get_config()
    size = len(dataloader.dataset)
    model.train()

    # Stays None only if the dataloader is empty; checked after the loop.
    loss = None
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(config.device), y.to(config.device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % config.log_interval == 0:
            # Use a separate name: rebinding `loss` to a float here would make
            # the returned value's type depend on which batch was logged last.
            loss_value, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss_value:>7f} [{current:>5d}/{size:>5d}]")

    if loss is None:
        raise ValueError("dataloader yielded no batches; nothing to train on")
    return loss.item()