| """ |
| The main training script for training on synthetic data |
| """ |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import os |
| import tqdm |
|
|
|
|
| def to_device(batch, device): |
| if type(batch) == torch.Tensor: |
| return batch.to(device) |
| elif type(batch) == dict: |
| for k in batch: |
| batch[k] = to_device(batch[k], device) |
| return batch |
| elif type(batch) in [list, tuple]: |
| batch = [to_device(x, device) for x in batch] |
| return batch |
| else: |
| return batch |
|
|
| def test_epoch(hl_module, test_loader, device) -> float: |
| """ |
| Evaluate the network. |
| """ |
| hl_module.eval() |
| |
| test_loss = 0 |
| num_elements = 0 |
|
|
| num_batches = len(test_loader) |
| pbar = tqdm.tqdm(total=num_batches) |
|
|
| with torch.no_grad(): |
| for batch_idx, batch in enumerate(test_loader): |
| batch = to_device(batch, device) |
| |
| loss, B = hl_module.validation_step(batch, batch_idx) |
| |
| test_loss += (loss.item() * B) |
| num_elements += B |
|
|
| pbar.set_postfix(loss='%.05f'%(loss.item()) ) |
| pbar.update() |
|
|
| return test_loss / num_elements |
|
|
| def train_epoch(hl_module, train_loader, device) -> float: |
| """ |
| Train a single epoch. |
| """ |
| |
| hl_module.train() |
| |
| |
| train_loss = 0 |
| num_elements = 0 |
|
|
| num_batches = len(train_loader) |
| pbar = tqdm.tqdm(total=num_batches) |
| |
| for batch_idx, batch in enumerate(train_loader): |
| batch = to_device(batch, device) |
|
|
| |
| hl_module.reset_grad() |
| |
| |
| loss, B = hl_module.training_step(batch, batch_idx) |
|
|
| |
| loss.backward(retain_graph=False) |
| hl_module.backprop() |
|
|
| |
| loss = loss.detach() |
| train_loss += (loss.item() * B) |
| num_elements += B |
| |
| |
| |
| pbar.set_postfix(loss='%.05f'%(loss.item()) ) |
| pbar.update() |
|
|
| return train_loss / num_elements |