Spaces:
Build error
Build error
File size: 715 Bytes
0e63e05 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
from src.configs.model_config import ModelConfig
import torch
def train(dataloader, model, loss_fn, optimizer, *, config=None):
    """Run one training pass of ``model`` over every batch in ``dataloader``.

    Each ``(X, y)`` batch is moved to ``config.device`` and run through
    ``model``. The loss ``loss_fn(pred, y)`` is back-propagated, then
    ``optimizer`` takes one step and its gradients are cleared. Every
    ``config.log_interval`` batches, the current loss and the number of
    samples processed so far are printed.

    Args:
        dataloader: Iterable of ``(X, y)`` batches that also exposes a
            ``dataset`` attribute supporting ``len()`` (used only for the
            progress message).
        model: Model to train; ``model.train()`` is called before the loop.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a single-element
            loss tensor.
        optimizer: Optimizer that updates ``model``'s parameters.
        config: Optional object with ``device`` and ``log_interval``
            attributes. When omitted, ``ModelConfig().get_config()`` is used.

    Returns:
        The loss of the final batch, as a Python ``float``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    if config is None:
        config = ModelConfig().get_config()
    size = len(dataloader.dataset)
    model.train()

    last_loss = None
    seen = 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(config.device), y.to(config.device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Count samples actually processed so a smaller final batch is
        # reported correctly (batch_index * len(X) would be wrong there).
        seen += len(X)
        # Keep a detached tensor rather than calling .item() every step,
        # which would force a device synchronisation on each batch.
        last_loss = loss.detach()
        if batch % config.log_interval == 0:
            print(f"loss: {last_loss.item():>7f} [{seen:>5d}/{size:>5d}]")

    if last_loss is None:
        raise ValueError("dataloader yielded no batches")
    return last_loss.item()