| import os |
| import argparse |
| import torch |
|
|
| from logger import utils |
| from data_loaders import get_data_loaders |
| from solver import train |
| from ddsp.vocoder import Sins, CombSub, CombSubFast |
| from ddsp.loss import RSSLoss |
|
|
|
|
def parse_args(args=None, namespace=None):
    """Read the training script's command-line options.

    Args:
        args: Argument strings to parse; forwarded unchanged to
            ``ArgumentParser.parse_args`` (``None`` lets argparse fall
            back to ``sys.argv``).
        namespace: Optional object to populate; forwarded unchanged to
            ``ArgumentParser.parse_args``.

    Returns:
        The populated namespace. Its ``config`` attribute holds the value
        of the required ``-c`` / ``--config`` option.
    """
    config_option = {
        "type": str,
        "required": True,
        "help": "path to the config file",
    }
    cli = argparse.ArgumentParser()
    cli.add_argument("-c", "--config", **config_option)
    parsed = cli.parse_args(args=args, namespace=namespace)
    return parsed
|
|
|
|
def _build_model(args):
    """Instantiate the vocoder selected by ``args.model.type``.

    Args:
        args: Loaded configuration object. Reads ``args.model.type`` and
            the ``args.data.*`` / ``args.model.*`` fields that the chosen
            constructor is called with.

    Returns:
        A ``Sins``, ``CombSub`` or ``CombSubFast`` instance.

    Raises:
        ValueError: If ``args.model.type`` is not one of ``'Sins'``,
            ``'CombSub'`` or ``'CombSubFast'``.
    """
    model_type = args.model.type

    if model_type == 'Sins':
        return Sins(
            sampling_rate=args.data.sampling_rate,
            block_size=args.data.block_size,
            n_harmonics=args.model.n_harmonics,
            n_mag_allpass=args.model.n_mag_allpass,
            n_mag_noise=args.model.n_mag_noise,
            n_unit=args.data.encoder_out_channels,
            n_spk=args.model.n_spk)

    if model_type == 'CombSub':
        return CombSub(
            sampling_rate=args.data.sampling_rate,
            block_size=args.data.block_size,
            n_mag_allpass=args.model.n_mag_allpass,
            n_mag_harmonic=args.model.n_mag_harmonic,
            n_mag_noise=args.model.n_mag_noise,
            n_unit=args.data.encoder_out_channels,
            n_spk=args.model.n_spk)

    if model_type == 'CombSubFast':
        return CombSubFast(
            sampling_rate=args.data.sampling_rate,
            block_size=args.data.block_size,
            n_unit=args.data.encoder_out_channels,
            n_spk=args.model.n_spk)

    raise ValueError(f" [x] Unknown Model: {args.model.type}")


if __name__ == '__main__':
    # Parse the command line and load the experiment configuration.
    cmd = parse_args()
    args = utils.load_config(cmd.config)
    print(' > config:', cmd.config)
    print(' > exp:', args.env.expdir)

    # Build the model described by the config (raises on an unknown type).
    model = _build_model(args)

    # Select the GPU *before* anything is created with device=args.device.
    # A bare 'cuda' device refers to the current CUDA device, so choosing
    # it afterwards would let the checkpoint load and the loss constructor
    # below target the default GPU first and only migrate later.
    if args.device == 'cuda':
        torch.cuda.set_device(args.env.gpu_id)

    # Create the optimizer, then resume model/optimizer state from the
    # experiment directory.
    # NOTE(review): presumably load_model returns step 0 and the untouched
    # objects when no checkpoint exists -- confirm in logger.utils.
    optimizer = torch.optim.AdamW(model.parameters())
    initial_global_step, model, optimizer = utils.load_model(
        args.env.expdir, model, optimizer, device=args.device)

    # The config's hyper-parameters override whatever the optimizer holds
    # after loading.
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.train.lr
        param_group['weight_decay'] = args.train.weight_decay

    # Training loss.
    loss_func = RSSLoss(
        args.loss.fft_min,
        args.loss.fft_max,
        args.loss.n_scale,
        device=args.device)

    # Move the model, every tensor held in the optimizer state, and the
    # loss module to the target device.
    model.to(args.device)
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(args.device)
    loss_func.to(args.device)

    # Data loaders for training and validation.
    loader_train, loader_valid = get_data_loaders(args, whole_audio=False)

    # Run the training loop.
    train(args, initial_global_step, model, optimizer, loss_func, loader_train, loader_valid)
| |
|
|