|
|
""" |
|
|
The main training script for training on synthetic data |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.utils.data |
|
|
import torch.nn as nn |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import os |
|
|
import multiprocessing |
|
|
import time |
|
|
|
|
|
import numpy as np |
|
|
import src.utils as utils |
|
|
from src.training.tain_val import train_epoch, test_epoch |
|
|
import shutil |
|
|
import sys |
|
|
|
|
|
import wandb |
|
|
|
|
|
# Fixed seed used for every validation pass so the val batches (and any
# random augmentation in val workers) are identical across epochs.
VAL_SEED = 0

# Epoch counter updated by train() at the top of each epoch;
# seed_from_epoch() reads it to derive a distinct per-epoch training seed.
CURRENT_EPOCH = 0
|
|
|
|
|
def seed_from_epoch(seed):
    """Seed all RNGs with `seed` offset by the epoch currently in progress.

    Reads the module-level CURRENT_EPOCH counter so each epoch gets a
    distinct but reproducible random stream.
    """
    utils.seed_all(seed + CURRENT_EPOCH)
|
|
|
|
|
def print_metrics(metrics: list):
    """Print mean input SI-SDR, mean output SI-SDR and mean SI-SDR improvement.

    Args:
        metrics: list of dicts, each carrying the keys 'input_si_sdr'
            and 'si_sdr' (one entry per evaluated sample).
    """
    in_sdr = np.array([m['input_si_sdr'] for m in metrics])
    out_sdr = np.array([m['si_sdr'] for m in metrics])
    improvement = out_sdr - in_sdr
    print("Average Input SI-SDR: {:03f}, Average Output SI-SDR: {:03f}, Average SI-SDRi: {:03f}".format(
        np.mean(in_sdr), np.mean(out_sdr), np.mean(improvement)))
|
|
|
|
|
|
|
|
def train(args: argparse.Namespace):
    """Train the model described by the JSON config at `args.config`.

    Loads the train/val datasets, builds data loaders, restores the last
    (or best, with --best) checkpoint from `args.run_dir`, then runs the
    train/eval loop up to `params['epochs']`, logging to Weights & Biases.

    Args:
        args: parsed command-line arguments (config, run_dir, best, seed,
            use_nondeterministic_cudnn, project_name).
    """
    global CURRENT_EPOCH

    # Seed everything up front; the cuBLAS workspace config is required for
    # deterministic cuBLAS kernels on CUDA.
    utils.seed_all(args.seed)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    if torch.cuda.is_available():
        # Deterministic cuDNN unless the user explicitly opts out for speed.
        torch.backends.cudnn.deterministic = not args.use_nondeterministic_cudnn

    with open(args.config, 'rb') as f:
        params = json.load(f)

    # Dataset classes are resolved dynamically from dotted paths in the config.
    data_train = utils.import_attr(params['train_dataset'])(**params['train_data_args'], split='train')
    data_val = utils.import_attr(params['val_dataset'])(**params['val_data_args'], split='val')

    # BUGFIX: was hard-coded to True, which made torch.device('cuda') fail
    # on CPU-only machines even though availability was checked above.
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print("Using device {}".format('cuda' if use_cuda else 'cpu'))

    num_workers = min(multiprocessing.cpu_count(), params['num_workers'])
    kwargs = {
        'num_workers': num_workers,
        # Re-seed each worker from the current epoch so data augmentation is
        # reproducible yet differs across epochs.
        'worker_init_fn': lambda x: seed_from_epoch(args.seed),
        'pin_memory': False
    } if use_cuda else {}

    train_loader = torch.utils.data.DataLoader(data_train,
                                               batch_size=params['batch_size'],
                                               shuffle=True,
                                               **kwargs)

    # Validation workers always use the fixed VAL_SEED so the evaluation
    # stream is identical every epoch.
    kwargs['worker_init_fn'] = lambda x: utils.seed_all(VAL_SEED)
    test_loader = torch.utils.data.DataLoader(data_val,
                                              batch_size=params['eval_batch_size'],
                                              **kwargs)

    # The high-level module wraps the model plus optimizer/scheduler state.
    hl_module = utils.import_attr(params['pl_module'])(**params['pl_module_args'])
    hl_module.model.to(device)

    run_name = os.path.basename(args.run_dir.rstrip('/'))
    checkpoints_dir = os.path.join(args.run_dir, 'checkpoints')
    os.makedirs(checkpoints_dir, exist_ok=True)

    # Snapshot the config alongside the run for reproducibility.
    shutil.copyfile(args.config, os.path.join(args.run_dir, 'config.json'))

    best_path = os.path.join(checkpoints_dir, 'best.pt')
    state_path = os.path.join(checkpoints_dir, 'last.pt')
    if args.best and os.path.exists(best_path):
        print("load best state path .....")
        hl_module.load_state(best_path)
    elif os.path.exists(state_path):
        print("load state path .....")
        hl_module.load_state(state_path)

    # Resume from wherever the restored checkpoint left off.
    start_epoch = hl_module.epoch

    # Prefer the config's project name; fall back to the CLI flag (which
    # itself defaults to 'AcousticBubble'). BUGFIX: --project_name was
    # previously parsed but silently ignored.
    project_name = params.get("project_name",
                              getattr(args, "project_name", "AcousticBubble"))

    wandb_run = wandb.init(
        project=project_name,
        name=run_name,
        notes='Example of a note',
        tags=['speech', 'audio', 'embedded-systems']
    )

    try:
        for epoch in range(start_epoch, params['epochs']):
            CURRENT_EPOCH = epoch
            seed_from_epoch(args.seed)

            hl_module.on_epoch_start()

            current_lr = hl_module.get_current_lr()
            print("CURRENT learning rate: {:0.08f}".format(current_lr))

            print("[TRAINING]")

            t1 = time.time()
            train_loss = train_epoch(hl_module, train_loader, device)
            t2 = time.time()
            print(f"Train epoch time: {t2 - t1:02f}s")

            print("\nTrain set: Average Loss: {:.4f}\n".format(train_loss))

            print()
            if np.isnan(train_loss):
                raise ValueError("Got NAN in training")
            # Fixed seed before evaluation so validation is deterministic.
            utils.seed_all(VAL_SEED)

            print("[TESTING]")

            test_loss = test_epoch(hl_module, test_loader, device)

            print("\nTest set: Average Loss: {:.4f}\n".format(test_loss))

            # on_epoch_end may save the best checkpoint and log to wandb;
            # dump_state always saves the latest state for resumption.
            hl_module.on_epoch_end(best_path, wandb_run)
            hl_module.dump_state(state_path)

            print()
            print("=" * 25, "FINISHED EPOCH", epoch, "=" * 25)
            print()
    except KeyboardInterrupt:
        print("Interrupted")
    except Exception as _:
        # Top-level boundary: report the failure without crashing the caller.
        import traceback
        traceback.print_exc()
    finally:
        # BUGFIX: the wandb run was never closed; finish it even on
        # error or interrupt so logs are flushed.
        wandb_run.finish()
|
|
|
|
|
# Command-line entry point: parse experiment arguments and launch training.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--config', type=str,
                        help='Path to experiment config')

    parser.add_argument('--run_dir', type=str,
                        help='Path to experiment directory')

    parser.add_argument('--best', action='store_true',
                        help="load from best checkpoint instead of last checkpoint")

    parser.add_argument('--seed', type=int, default=10,
                        help='Random seed for reproducibility')
    parser.add_argument('--use_nondeterministic_cudnn',
                        action='store_true',
                        help="If using cuda, chooses whether or not to use \
                        non-deterministic cudDNN algorithms. Training will be\
                        faster, but the final results may differ slighty.")

    parser.add_argument('--project_name',
                        type=str,
                        default='AcousticBubble',
                        help='Project name that shows up on wandb')

    train(parser.parse_args())
|
|
|