Spaces:
Build error
Build error
import torch

from src.configs.model_config import ModelConfig
def test(dataloader, model, loss_fn, *, device=None):
    """Evaluate ``model`` over every batch of ``dataloader`` and report metrics.

    Switches the model to eval mode (and leaves it there), runs the forward
    passes under ``torch.no_grad()``, prints accuracy and average loss, and
    returns them.

    Args:
        dataloader: Iterable of ``(X, y)`` batches (presumably a
            ``torch.utils.data.DataLoader`` — confirm against callers).
            ``y`` is compared with ``pred.argmax(1)``, so it is assumed to
            hold class indices.
        model: Module exposing ``eval()`` and mapping a batch ``X`` to
            per-class predictions.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        device: Device each batch is moved to. When ``None`` (the default)
            it is read from ``ModelConfig().get_config().device``, which
            matches the previous behaviour.

    Returns:
        Tuple ``(accuracy, avg_loss)``:
        - ``accuracy`` is the fraction of evaluated samples classified
          correctly.
        - ``avg_loss`` is the mean of the per-batch loss values.

    Raises:
        ValueError: If ``dataloader`` yields no samples. Without this check
            the divisions below would fail with a bare ZeroDivisionError.
    """
    if device is None:
        # Only consult the project config when the caller did not choose a device.
        device = ModelConfig().get_config().device

    model.eval()
    test_loss, correct = 0.0, 0
    num_batches, size = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).sum().item()
            # Count what was actually evaluated instead of trusting
            # len(dataloader) / len(dataloader.dataset). A sampler or
            # drop_last=True makes those differ from what is iterated, and an
            # iterable-style dataset may not define __len__ at all.
            num_batches += 1
            size += len(y)

    if size == 0:
        raise ValueError("dataloader yielded no samples; cannot compute test metrics")

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return correct, test_loss