Spaces:
Build error
Build error
| from src.configs.model_config import ModelConfig | |
| import torch | |
def train(dataloader, model, loss_fn, optimizer, config=None):
    """Run one training epoch over ``dataloader`` and return the last batch loss.

    Args:
        dataloader: Iterable of ``(X, y)`` batches; ``len(dataloader.dataset)``
            is used as the total sample count in progress messages.
        model: Module to train; put into training mode via ``model.train()``.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: Optimizer over ``model``'s parameters.
        config: Object exposing ``device`` and ``log_interval``. When ``None``
            (the default) it is loaded from ``ModelConfig().get_config()``.

    Returns:
        The loss of the final batch as a Python ``float`` (always a float,
        regardless of whether that batch was logged), or ``nan`` when the
        dataloader yields no batches.
    """
    if config is None:
        config = ModelConfig().get_config()

    size = len(dataloader.dataset)
    seen = 0          # exact number of samples processed, correct for a short last batch
    last_loss = None  # detached tensor; converted to float once at the end, not every batch

    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(config.device), y.to(config.device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation. Clear gradients first so anything accumulated before
        # this call cannot leak into the first update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        seen += len(X)
        last_loss = loss.detach()

        if batch % config.log_interval == 0:
            print(f"loss: {last_loss.item():>7f} [{seen:>5d}/{size:>5d}]")

    return float("nan") if last_loss is None else last_loss.item()