| import matplotlib.pyplot as plt | |
| import json | |
| import torch | |
| import torchaudio | |
def configure_args(config, args):
    """Overlay command-line argument values onto the configuration mapping.

    ``config`` is updated in place. Most options are copied only when the
    corresponding attribute on ``args`` is not ``None``; the two keys
    ``load_pretrained`` and ``early_stopping`` are always copied, whatever
    their value (including ``None``).

    Args:
        config: Nested mapping indexed with the sections ``"general"``,
            ``"preprocess"`` and ``"train"``. ``config["train"]`` must
            already hold a ``"feature_loss"`` mapping if
            ``args.feature_loss_type`` is set.
        args: Object exposing every option below as an attribute
            (presumably an ``argparse.Namespace`` -- confirm with caller).

    Returns:
        The tuple ``(config, args)``: the same two objects that were
        passed in, with ``config`` updated.

    Raises:
        AttributeError: If ``args`` lacks one of the expected attributes.
    """
    # (config section, option names, converter applied to the value).
    # Path-like and general options are stored as strings; numeric options
    # are stored unchanged.
    optional_overrides = (
        (
            "general",
            ("stage", "corpus_type", "source_path", "aux_path",
             "preprocessed_path"),
            str,
        ),
        ("preprocess", ("n_train", "n_val", "n_test"), None),
        ("train", ("alpha", "beta", "learning_rate", "epoch"), None),
    )
    for section, keys, convert in optional_overrides:
        for key in keys:
            value = getattr(args, key)
            # Identity test, not ``!= None``: a value with a custom
            # ``__ne__`` (or an array-like one) must not be skipped or
            # raise when truth-tested.
            if value is not None:
                config[section][key] = value if convert is None else convert(value)

    # These two are copied unconditionally, so the value on ``args`` always
    # replaces whatever the config held.
    for key in ("load_pretrained", "early_stopping"):
        config["train"][key] = getattr(args, key)

    if args.feature_loss_type is not None:
        config["train"]["feature_loss"]["type"] = args.feature_loss_type

    pretrained_path = getattr(args, "pretrained_path")
    if pretrained_path is not None:
        config["train"]["pretrained_path"] = str(pretrained_path)

    return config, args