| | """ |
| | The main training script for training on synthetic data |
| | """ |
| |
|
| | import torch |
| | import torch.utils.data |
| | import torch.nn as nn |
| |
|
| | import argparse |
| | import json |
| | import os |
| | import multiprocessing |
| | import time |
| |
|
| | import numpy as np |
| | import src.utils as utils |
| | from src.training.tain_val import train_epoch, test_epoch |
| | import shutil |
| | import sys |
| |
|
| | import wandb |
| |
|
# Seed applied right before every validation pass (and inside the validation
# DataLoader workers) so validation randomness is the same for every epoch.
VAL_SEED: int = 0

# Index of the epoch currently being trained. train() overwrites it at the
# start of each epoch and seed_from_epoch() adds it to the base seed.
CURRENT_EPOCH: int = 0
| |
|
def seed_from_epoch(seed):
    """
    Re-seed the random generators (via ``utils.seed_all``) with ``seed`` shifted by the epoch.

    train() calls this at the start of every epoch, and the training
    DataLoader uses it as its ``worker_init_fn``. Each epoch therefore gets
    a different but reproducible seed.

    Args:
        seed: base seed (train() passes the ``--seed`` value).
    """
    # CURRENT_EPOCH is only read here, so no ``global`` statement is needed.
    epoch_seed = seed + CURRENT_EPOCH
    utils.seed_all(epoch_seed)
| |
|
def print_metrics(metrics: list):
    """
    Print the mean input SI-SDR, mean output SI-SDR and mean SI-SDR improvement.

    Args:
        metrics: sequence of per-example dicts, each providing numeric
            ``'input_si_sdr'`` and ``'si_sdr'`` entries.
    """
    input_sisdr = np.array([x['input_si_sdr'] for x in metrics])
    sisdr = np.array([x['si_sdr'] for x in metrics])

    # ``.3f`` = three decimals; the previous ``03f`` meant "width 3, six decimals".
    print(
        "Average Input SI-SDR: {:.3f}, Average Output SI-SDR: {:.3f}, Average SI-SDRi: {:.3f}".format(
            np.mean(input_sisdr), np.mean(sisdr), np.mean(sisdr - input_sisdr)
        )
    )
| |
|
| |
|
def train(args: argparse.Namespace):
    """
    Train the model described by the JSON config at ``args.config``.

    Builds the train/val datasets and DataLoaders and instantiates the
    training module named by ``params['pl_module']``. If a checkpoint exists
    in ``<run_dir>/checkpoints``, training resumes from it (``best.pt`` with
    ``--best``, otherwise ``last.pt``). Training is logged to wandb and runs
    train + validation for every epoch up to ``params['epochs']``, writing
    ``last.pt`` after each epoch.

    Args:
        args: parsed CLI arguments; reads ``config``, ``run_dir``, ``best``,
            ``seed``, ``use_nondeterministic_cudnn`` and, if present,
            ``project_name``.

    Raises:
        ValueError: if an epoch's average training loss is NaN.
        Any other error from data loading, training or checkpointing is
        propagated. Only KeyboardInterrupt stops training quietly.
    """
    # Rebound every epoch; seed_from_epoch() reads it.
    global CURRENT_EPOCH

    utils.seed_all(args.seed)
    # cuBLAS workspace setting that PyTorch's reproducibility notes require
    # for deterministic cuBLAS results.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = not args.use_nondeterministic_cudnn

    with open(args.config, 'rb') as f:
        params = json.load(f)

    data_train = utils.import_attr(params['train_dataset'])(**params['train_data_args'], split='train')
    data_val = utils.import_attr(params['val_dataset'])(**params['val_data_args'], split='val')

    # Use the GPU when there is one instead of crashing on CPU-only machines.
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print("Using device {}".format('cuda' if use_cuda else 'cpu'))

    num_workers = min(multiprocessing.cpu_count(), params['num_workers'])
    # Worker settings do not depend on the device, so they apply on CPU as well.
    # NOTE(review): every worker gets the same seed (worker_id is ignored). If
    # the dataset draws from global RNGs, workers may produce duplicate
    # samples; consider mixing in worker_id. TODO confirm against the datasets.
    # NOTE(review): lambdas are not picklable, so this relies on the 'fork'
    # multiprocessing start method (workers also inherit CURRENT_EPOCH that way).
    train_loader_kwargs = {
        'num_workers': num_workers,
        'worker_init_fn': lambda worker_id: seed_from_epoch(args.seed),
        'pin_memory': False,
    }
    train_loader = torch.utils.data.DataLoader(data_train,
                                               batch_size=params['batch_size'],
                                               shuffle=True,
                                               **train_loader_kwargs)

    # Validation workers always use VAL_SEED, so validation randomness does not
    # change between epochs. Copy the dict instead of mutating the train kwargs.
    val_loader_kwargs = dict(train_loader_kwargs,
                             worker_init_fn=lambda worker_id: utils.seed_all(VAL_SEED))
    test_loader = torch.utils.data.DataLoader(data_val,
                                              batch_size=params['eval_batch_size'],
                                              **val_loader_kwargs)

    hl_module = utils.import_attr(params['pl_module'])(**params['pl_module_args'])
    hl_module.model.to(device)

    run_name = os.path.basename(args.run_dir.rstrip('/'))
    checkpoints_dir = os.path.join(args.run_dir, 'checkpoints')
    os.makedirs(checkpoints_dir, exist_ok=True)

    # Keep a copy of the config with the run. When resuming with
    # --config <run_dir>/config.json the source and destination are the same
    # file and copyfile raises SameFileError; nothing needs copying then.
    try:
        shutil.copyfile(args.config, os.path.join(args.run_dir, 'config.json'))
    except shutil.SameFileError:
        pass

    best_path = os.path.join(checkpoints_dir, 'best.pt')
    state_path = os.path.join(checkpoints_dir, 'last.pt')
    if args.best and os.path.exists(best_path):
        print("load best state path .....")
        hl_module.load_state(best_path)
    elif os.path.exists(state_path):
        print("load state path .....")
        hl_module.load_state(state_path)

    start_epoch = hl_module.epoch

    # The config's "project_name" wins. Otherwise fall back to --project_name,
    # which was previously ignored. getattr keeps callers with older Namespaces working.
    project_name = params.get("project_name", getattr(args, "project_name", "AcousticBubble"))

    wandb_run = wandb.init(
        project=project_name,
        name=run_name,
        notes='Example of a note',
        tags=['speech', 'audio', 'embedded-systems']
    )

    # Only Ctrl-C is handled here. Any other exception (including the NaN
    # check) propagates, so the process exits non-zero instead of reporting
    # success after printing a traceback.
    try:
        for epoch in range(start_epoch, params['epochs']):
            CURRENT_EPOCH = epoch
            seed_from_epoch(args.seed)

            hl_module.on_epoch_start()

            current_lr = hl_module.get_current_lr()
            print("CURRENT learning rate: {:0.08f}".format(current_lr))

            print("[TRAINING]")
            t1 = time.perf_counter()
            train_loss = train_epoch(hl_module, train_loader, device)
            t2 = time.perf_counter()
            print(f"Train epoch time: {t2 - t1:.2f}s")

            print("\nTrain set: Average Loss: {:.4f}\n".format(train_loss))

            print()
            if np.isnan(train_loss):
                raise ValueError("Got NAN in training")

            # Same seed before every validation pass.
            utils.seed_all(VAL_SEED)

            print("[TESTING]")
            test_loss = test_epoch(hl_module, test_loader, device)
            print("\nTest set: Average Loss: {:.4f}\n".format(test_loss))

            hl_module.on_epoch_end(best_path, wandb_run)
            hl_module.dump_state(state_path)

            print()
            print("=" * 25, "FINISHED EPOCH", epoch, "=" * 25)
            print()
    except KeyboardInterrupt:
        print("Interrupted")
| |
|
def build_arg_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser for this training script.

    ``--config`` and ``--run_dir`` are required. Without them, train() would
    fail later with an opaque error when opening or joining ``None``.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--config', type=str, required=True,
                        help='Path to experiment config')
    parser.add_argument('--run_dir', type=str, required=True,
                        help='Path to experiment directory')
    parser.add_argument('--best', action='store_true',
                        help="load from best checkpoint instead of last checkpoint")
    parser.add_argument('--seed', type=int, default=10,
                        help='Random seed for reproducibility')
    # Implicit string concatenation instead of backslash continuations inside
    # the literal, which would embed the next line's indentation.
    parser.add_argument('--use_nondeterministic_cudnn',
                        action='store_true',
                        help="If using cuda, chooses whether or not to use "
                             "non-deterministic cuDNN algorithms. Training will be "
                             "faster, but the final results may differ slightly.")
    parser.add_argument('--project_name',
                        type=str,
                        default='AcousticBubble',
                        help='Project name that shows up on wandb')
    return parser


if __name__ == '__main__':
    train(build_arg_parser().parse_args())
| |
|