bluemellophone's picture
Fix image
cd605d0 unverified
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
The lecture materials for Lecture 1: Dataset Prototyping and Visualization
"""
import click
import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import trange
from cv4e_lecture13 import dataset, model, utils
# Module-level logger handle. It stays None until lecture() assigns it
# (via `global log`) from utils.init_logging().
log = None
def inference(cfg, dataloader, net, optimizer, criterion, update):
    """
    Run one full pass over ``dataloader``; train ``net`` if ``update`` is True,
    otherwise only evaluate it.

    Args:
        cfg: config mapping; only ``cfg.get('device')`` is read here.
        dataloader: sized iterable of ``(data, labels)`` batches.
        net: the ``nn.Module`` being trained / evaluated.
        optimizer: stepped once per batch when ``update`` is True; unused otherwise.
        criterion: loss function, called as ``criterion(prediction, labels)``.
        update: True for a training pass (gradients + optimizer steps),
            False for a gradient-free evaluation pass.

    Returns:
        ``(loss, accuracy)`` averaged over every sample seen in the pass,
        with accuracy as a fraction in [0, 1]. The loss average weights each
        batch by its size; this assumes ``criterion`` returns a per-batch mean
        (the default reduction for ``nn.CrossEntropyLoss``).

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    device = cfg.get('device')
    type_str = 'Train' if update else 'Val'
    # nn.Module.train(False) is the same as net.eval(): toggles dropout/batchnorm mode.
    net.train(update)

    loss_sum, correct, seen = 0.0, 0, 0
    # Scope grad mode to this pass instead of flipping global autograd state,
    # and let the progress bar close even if a batch raises.
    with torch.set_grad_enabled(update), trange(len(dataloader)) as prog:
        for data, labels in dataloader:
            data, labels = data.to(device), labels.to(device)
            prediction = net(data)
            batch_loss = criterion(prediction, labels)

            if update:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

            # Weight by batch size so a short final batch does not skew the averages.
            batch_size = labels.size(0)
            loss_sum += batch_loss.item() * batch_size
            predicted = torch.argmax(prediction, dim=1)
            correct += (predicted == labels).sum().item()
            seen += batch_size

            denom = max(seen, 1)
            prog.set_description(
                '[{:s}] Loss: {:.2f}; Acc: {:.2f}%'.format(
                    type_str, loss_sum / denom, 100.0 * correct / denom
                )
            )
            prog.update(1)

    if seen == 0:
        raise ValueError(f'{type_str} dataloader yielded no samples')
    return loss_sum / seen, correct / seen
@click.command()
@click.option(
    '--config', help='Path to config file', default='configs/mnist_resnet18.yaml'
)
def lecture(config):
    """
    Train and validate the model described by the config file, with early stopping.
    """
    # NOTE: click uses the docstring above as this command's --help text.
    global log
    log = utils.init_logging()
    cfg = utils.init_config(config, log)

    # init random number generator seed (set at the start, before data/model creation)
    utils.init_seed(cfg.get('seed', None))

    ################################################################################
    # Build the train / held-out data loaders described by the config
    train, test = dataset.load(cfg)
    # model.load yields the network, the epoch to start from and the best
    # validation loss so far (presumably restored from a checkpoint when one
    # exists -- confirm in cv4e_lecture13.model).
    net, epoch, best_loss = model.load(cfg)

    optimizer = Adam(
        net.parameters(),
        lr=cfg.get('learning_rate'),
        weight_decay=cfg.get('weight_decay'),
    )
    criterion = nn.CrossEntropyLoss()

    epochs = cfg.get('max_epochs')
    if epochs is None:
        # Without this check the loop condition fails with an opaque TypeError.
        raise ValueError("config is missing the required key 'max_epochs'")
    # Consecutive non-improving epochs tolerated before stopping; the default
    # of 0 stops on the first epoch whose validation loss does not improve.
    patience = cfg.get('patience', 0)
    stale_epochs = 0

    while epoch < epochs:
        log.info(f'Epoch {epoch}/{epochs}')
        loss_train, accuracy_train = inference(
            cfg, train, net, optimizer, criterion, update=True
        )
        loss_test, accuracy_test = inference(
            cfg, test, net, optimizer, criterion, update=False
        )

        # combine stats and save. Key names ('loss_val' vs 'accuracy_test') are
        # left unchanged because net.save may persist them for other readers.
        stats = {
            'loss_train': loss_train,
            'loss_val': loss_test,
            'accuracy_train': accuracy_train,
            'accuracy_test': accuracy_test,
        }
        best = loss_test < best_loss
        net.save(cfg, epoch, stats, best=best)

        if best:
            best_loss = loss_test
            stale_epochs = 0
        else:
            stale_epochs += 1
            if stale_epochs > patience:
                log.warning('Stopping early')
                break
        epoch += 1
if __name__ == '__main__':
    # Script entry point (`python lecture.py` or `./lecture.py`); importing this
    # module skips it. click parses the command-line arguments inside lecture().
    lecture()