npv2k1's picture
feat: update
0e63e05 verified
raw
history blame contribute delete
715 Bytes
from src.configs.model_config import ModelConfig
import torch
def train(dataloader, model, loss_fn, optimizer, *, config=None):
    """Run one training pass over ``dataloader`` and return the last batch loss.

    Each ``(X, y)`` batch is moved to ``config.device``, fed through ``model``,
    scored with ``loss_fn``, back-propagated, and used for one ``optimizer``
    step. A progress line is printed every ``config.log_interval`` batches.

    Args:
        dataloader: Iterable of ``(X, y)`` batches that also exposes a sized
            ``.dataset`` attribute (used for the progress denominator).
        model: Module to train; ``model.train()`` is called before the loop.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor
            (``.backward()`` and ``.item()`` are called on it).
        optimizer: Optimizer whose ``step()``/``zero_grad()`` are called once
            per batch.
        config: Optional object providing ``device`` and ``log_interval``.
            When omitted, ``ModelConfig().get_config()`` is used, as before.

    Returns:
        The loss of the final batch as a Python ``float`` (always a float,
        regardless of whether the final batch was logged).

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    if config is None:
        config = ModelConfig().get_config()
    size = len(dataloader.dataset)
    model.train()

    seen = 0  # samples processed so far; exact even with a short final batch
    last_loss = None
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(config.device), y.to(config.device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Keep a detached tensor (drops the autograd graph) and defer .item()
        # to logging/return time so we don't force a device sync every batch.
        last_loss = loss.detach()
        seen += len(X)

        if batch % config.log_interval == 0:
            print(f"loss: {last_loss.item():>7f} [{seen:>5d}/{size:>5d}]")

    if last_loss is None:
        raise ValueError("dataloader yielded no batches; nothing to train on")
    return last_loss.item()