Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| """ | |
| The lecture materials for Lecture 1: Dataset Prototyping and Visualization | |
| """ | |
| import click | |
| import torch | |
| import torch.nn as nn | |
| from torch.optim import Adam | |
| from tqdm import trange | |
| from cv4e_lecture13 import dataset, model, utils | |
| log = None | |
def inference(cfg, dataloader, net, optimizer, criterion, update):
    '''
    Run one full pass (train or validation) over `dataloader`.

    When `update` is True, each batch is backpropagated and the optimizer
    steps; when False, the pass runs without gradient tracking.

    Args:
        cfg: config mapping; only 'device' is read here.
        dataloader: iterable of (data, labels) batches; must support len().
        net: the model to run.
        optimizer: optimizer over `net`'s parameters (untouched when
            `update` is False).
        criterion: loss function, called as criterion(prediction, labels).
        update: True for a training pass, False for evaluation.

    Returns:
        Tuple (mean_loss, mean_accuracy), each averaged over the batches.
    '''
    device = cfg.get('device')
    # Side-effecting mode switch as a plain statement (the original used a
    # conditional *expression* purely for its side effect).
    if update:
        net.train()
    else:
        net.eval()
    type_str = 'Train' if update else 'Val'
    total_loss, total_accuracy = 0.0, 0.0
    num_batches = len(dataloader)
    prog = trange(num_batches)
    # Scope gradient tracking to this pass. The original called
    # torch.set_grad_enabled(update) as a bare statement, permanently
    # flipping the global autograd mode for the caller; the context
    # manager restores the previous mode on exit.
    with torch.set_grad_enabled(update):
        for index, (data, labels) in enumerate(dataloader):
            data, labels = data.to(device), labels.to(device)
            prediction = net(data)
            # criterion(...) yields the batch *loss*; the original bound it
            # to a variable named `gradient`, which was misleading.
            batch_loss = criterion(prediction, labels)
            if update:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
            # running statistics for the progress bar
            total_loss += batch_loss.item()
            label_ = torch.argmax(prediction, dim=1)
            total_accuracy += torch.mean((label_ == labels).float()).item()
            prog.set_description(
                '[{:s}] Loss: {:.2f}; Acc: {:.2f}%'.format(
                    type_str,
                    total_loss / (index + 1),
                    100.0 * total_accuracy / (index + 1),
                )
            )
            prog.update(1)
    prog.close()
    return total_loss / num_batches, total_accuracy / num_batches
@click.command()
@click.option(
    '--config',
    'config',
    required=True,
    help='Path to the configuration file (parsed by utils.init_config).',
)
def lecture(config):
    """
    CLI entry point: train an MNIST classifier.

    Loads configuration, data, and (possibly resumed) model state via the
    course utility modules, then alternates training and validation epochs
    until `max_epochs` is reached or the validation loss stops improving
    (zero-patience early stopping).

    The click decorators are the fix for the original runtime error: `click`
    was imported but unused, `lecture(config)` required an argument, and the
    `__main__` guard called `lecture()` with none — a TypeError. As a click
    command, the bare call parses `--config` from sys.argv.

    Args:
        config: path to the configuration file (supplied by click).
    """
    global log
    log = utils.init_logging()
    cfg = utils.init_config(config, log)
    # Seed all RNGs before any data loading / weight init, for reproducibility.
    utils.init_seed(cfg.get('seed', None))
    ################################################################################
    # Load MNIST and the model (model.load also returns the resume epoch
    # and the best validation loss seen so far).
    train, test = dataset.load(cfg)
    net, epoch, best_loss = model.load(cfg)
    optimizer = Adam(
        net.parameters(),
        lr=cfg.get('learning_rate'),
        weight_decay=cfg.get('weight_decay'),
    )
    criterion = nn.CrossEntropyLoss()
    epochs = cfg.get('max_epochs')
    while epoch < epochs:
        log.info(f'Epoch {epoch}/{epochs}')
        loss_train, accuracy_train = inference(
            cfg, train, net, optimizer, criterion, update=True
        )
        loss_test, accuracy_test = inference(
            cfg, test, net, optimizer, criterion, update=False
        )
        # combine stats and save
        # NOTE(review): key naming is inconsistent ('loss_val' vs
        # 'accuracy_test'). Kept verbatim because whatever consumes the
        # saved stats may rely on these exact keys — confirm before renaming.
        stats = {
            'loss_train': loss_train,
            'loss_val': loss_test,
            'accuracy_train': accuracy_train,
            'accuracy_test': accuracy_test,
        }
        best = loss_test < best_loss
        net.save(cfg, epoch, stats, best=best)
        if not best:
            # Zero-patience early stopping: quit the first time the
            # validation loss fails to improve.
            log.warning('Stopping early')
            break
        best_loss = loss_test
        epoch += 1
if __name__ == '__main__':
    # Common boiler-plating needed to run the code from the command line as `python lecture.py` or `./lecture.py`
    # This if condition will be False if the file is imported
    # NOTE(review): `lecture` is defined as `lecture(config)`; a bare
    # `lecture()` only works if `lecture` is a click command that parses
    # --config from sys.argv — otherwise this raises TypeError. Confirm the
    # decorators are present on `lecture`.
    lecture()