| import time | |
| import numpy as np | |
| import torch | |
| import tqdm | |
def train(model, device, train_loader, optimizer, epoch, log_interval):
    """Run one training epoch and report mean loss and accuracy.

    Args:
        model: module whose forward pass returns class logits and which
            exposes a ``loss(out, y)`` method (assumed to return a scalar
            tensor suitable for ``backward()`` — TODO confirm against the
            model class, which is defined outside this file).
        device: torch device the batches are moved to.
        train_loader: DataLoader yielding ``(x, y)`` batches; its
            ``.dataset`` length is used for the accuracy denominator.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: current epoch number (used only in the log line).
        log_interval: emit a progress line every ``log_interval`` batches.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)`` as Python floats.
    """
    model.train()
    losses = []
    correct = 0  # running count of correct predictions, kept as a plain int
    seen = 0     # samples processed so far; exact even when the last batch is short
    for batch_idx, (x, y) in enumerate(tqdm.tqdm(train_loader)):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = model.loss(out, y)
        # Accuracy bookkeeping must not build graph state.
        with torch.no_grad():
            pred = torch.argmax(out, dim=1)
            correct += torch.sum(pred == y).item()
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        seen += len(x)
        if batch_idx % log_interval == 0:
            # `seen` replaces the old `batch_idx * len(x)`, which under-/over-
            # counted whenever the final batch was smaller than the rest.
            print('{} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                time.ctime(time.time()),
                epoch, seen, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    accuracy_mean = (100. * correct) / len(train_loader.dataset)
    # `correct` is already a Python int, so no `.item()` is needed; the
    # returned types (float, float) match the original contract.
    return np.mean(losses), accuracy_mean
def test(model, device, test_loader, log_interval=None):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    Args:
        model: module whose forward pass returns class logits and which
            exposes a ``loss(out, y)`` method returning a scalar tensor
            (defined outside this file — TODO confirm).
        device: torch device the batches are moved to.
        test_loader: DataLoader yielding ``(x, y)`` batches.
        log_interval: if not ``None``, emit a progress line every
            ``log_interval`` batches; the default stays silent per batch.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)`` as Python floats.
    """
    model.eval()
    losses = []
    correct = 0  # plain int, so the summary prints "123/10000", not "tensor(123)/10000"
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(tqdm.tqdm(test_loader)):
            x, y = x.to(device), y.to(device)
            out = model(x)
            batch_loss = model.loss(out, y).item()
            losses.append(batch_loss)
            pred = torch.argmax(out, dim=1)
            correct += torch.sum(pred == y).item()
            if log_interval is not None and batch_idx % log_interval == 0:
                print('{} Test: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    time.ctime(time.time()),
                    batch_idx * len(x), len(test_loader.dataset),
                    100. * batch_idx / len(test_loader), batch_loss))
    test_loss = np.mean(losses)
    accuracy_mean = (100. * correct) / len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} , ({:.4f})%\n'.format(
        test_loss, correct, len(test_loader.dataset), accuracy_mean))
    # `accuracy_mean` is already a Python float; return types match the
    # original `(test_loss, accuracy_mean.item())` contract.
    return test_loss, accuracy_mean