Spaces:
Paused
Paused
| # %% | |
| import argparse | |
| import os | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from models.tacotron2.tacotron2_ms import Tacotron2MS | |
| from utils import get_config | |
| from utils.data import ArabDataset, text_mel_collate_fn | |
| from utils.logging import TBLogger | |
| from utils.training import batch_to_device, save_states_gan as save_states | |
| from models.common.loss import PatchDiscriminator, extract_chunks, calc_feature_match_loss | |
| from models.tacotron2.loss import Tacotron2Loss | |
| # %% | |
# Parse the config path from the CLI; fall back to the default path when
# argument parsing fails (e.g. when run as a notebook via the # %% cells,
# where argparse sees foreign argv entries and exits).
try:
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str,
                        default="configs/nawar_tc2_adv.yaml",
                        help="Path to yaml config file")
    args = parser.parse_args()
    config_path = args.config
except SystemExit:
    # argparse calls sys.exit() on unrecognized arguments; catch only that
    # instead of a bare `except`, which would also hide real bugs.
    config_path = './configs/nawar_tc2_adv.yaml'
| # %% | |
# Load the experiment configuration and pick the compute device.
config = get_config(config_path)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed all RNGs for reproducibility. `random_seed` may be the literal
# False (seeding disabled). NOTE(review): a seed of 0 also compares
# equal to False and is therefore skipped — presumably intentional,
# confirm against the config conventions.
if config.random_seed != False:
    import numpy as np  # local import: numpy only needed for seeding
    torch.manual_seed(config.random_seed)
    torch.cuda.manual_seed_all(config.random_seed)
    np.random.seed(config.random_seed)

# Make the checkpoint folder if nonexistent. exist_ok=True avoids a
# race between the isdir() check and makedirs().
if not os.path.isdir(config.checkpoint_dir):
    os.makedirs(os.path.abspath(config.checkpoint_dir), exist_ok=True)
    print(f"Created checkpoint_dir folder: {config.checkpoint_dir}")
| # datasets | |
# Training dataset; the test split (commented out) mirrors this setup.
train_dataset = ArabDataset(txtpath=config.train_labels,
                            wavpath=config.train_wavs_path,
                            label_pattern=config.label_pattern)
# test_dataset = ArabDataset(config.test_labels, config.test_wavs_path)

# Optional balanced sampling: a WeightedRandomSampler replaces plain
# shuffling (a DataLoader sampler and shuffle=True are mutually
# exclusive, hence the flag flips below).
sampler = None
shuffle = True
drop_last = True
if config.balanced_sampling:
    weights = torch.load(config.sampler_weights_file)
    sampler = torch.utils.data.WeightedRandomSampler(
        weights, len(weights), replacement=False)
    shuffle = False
    drop_last = False

# Dataloaders
train_loader = DataLoader(train_dataset,
                          batch_size=config.batch_size,
                          collate_fn=text_mel_collate_fn,
                          sampler=sampler,
                          shuffle=shuffle,
                          drop_last=drop_last)
# test_loader = DataLoader(test_dataset,
#                          batch_size=config.batch_size, drop_last=False,
#                          shuffle=False, collate_fn=text_mel_collate_fn)
| # %% Generator | |
# %% Generator: multi-speaker Tacotron2
model = Tacotron2MS(n_symbol=40, num_speakers=40).to(device)
model.decoder.decoder_max_step = config.decoder_max_step

optimizer = torch.optim.AdamW(model.parameters(),
                              lr=config.g_lr,
                              betas=(config.g_beta1, config.g_beta2),
                              weight_decay=config.weight_decay)
criterion = Tacotron2Loss(mel_loss_scale=1.0)

# %% Discriminator: patch-based critic over fixed-length mel chunks
critic = PatchDiscriminator(1, 32).to(device)
optimizer_d = torch.optim.AdamW(critic.parameters(),
                                lr=config.d_lr,
                                betas=(config.d_beta1, config.d_beta2),
                                weight_decay=config.weight_decay)

# Target length (in mel frames) of the chunks fed to the critic.
tar_len = 128
| # %% | |
# %%
# Resume from an existing checkpoint when one is configured.
n_epoch, n_iter = 0, 0
if config.restore_model != '':
    state_dicts = torch.load(config.restore_model)
    model.load_state_dict(state_dicts['model'])
    if 'model_d' in state_dicts:
        # strict=False: tolerate a critic whose architecture drifted
        # from the checkpointed one.
        critic.load_state_dict(state_dicts['model_d'], strict=False)
    if 'optim' in state_dicts:
        optimizer.load_state_dict(state_dicts['optim'])
    if 'optim_d' in state_dicts:
        optimizer_d.load_state_dict(state_dicts['optim_d'])
    n_epoch = state_dicts.get('epoch', n_epoch)
    n_iter = state_dicts.get('iter', n_iter)

# %%
# Tensorboard writer
writer = TBLogger(config.log_dir)
| # %% | |
def trunc_batch(batch, N):
    """Keep only the first *N* examples of every element of *batch*.

    Generalized from a hard-coded 5-tuple to a batch of any arity;
    returns a tuple with the same number of elements as *batch*.
    """
    return tuple(x[:N] for x in batch)
| # %% TRAINING LOOP | |
# %% TRAINING LOOP
model.train()
for epoch in range(n_epoch, config.epochs):
    print(f"Epoch: {epoch}")
    for batch in train_loader:
        # Truncate oversized batches to 6 examples to avoid running out
        # of memory. batch[-1] holds the mel output lengths; assumes the
        # collate fn sorts them descending so [0] is the max — TODO confirm.
        if batch[-1][0] > 2000:
            batch = trunc_batch(batch, 6)
        text_padded, input_lengths, mel_padded, gate_padded, \
            output_lengths = batch_to_device(batch, device)

        y_pred = model(text_padded, input_lengths,
                       mel_padded, output_lengths,
                       torch.zeros_like(output_lengths))
        mel_out, mel_out_postnet, gate_out, alignments = y_pred

        # Sample one random fixed-length mel chunk per batch item for
        # the critic; chunk length is capped by the shortest utterance.
        Nchunks = mel_out.size(0)
        tar_len_ = min(output_lengths.min().item(), tar_len)
        # FIX: allocate on `device` instead of .cuda(non_blocking=True)
        # so the CPU fallback chosen at startup still works.
        mel_ids = torch.randint(0, mel_out.size(0), (Nchunks,),
                                device=device)
        ofx_perc = torch.rand(Nchunks, device=device)
        out_lens = output_lengths[mel_ids]
        # Random chunk start offsets, clamped to [0, length - tar_len_].
        ofx = (ofx_perc * (out_lens + tar_len_) - tar_len_/2) \
            .clamp(out_lens*0, out_lens - tar_len_).long()

        chunks_org = extract_chunks(
            mel_padded, ofx, mel_ids, tar_len_)       # mel_padded: B F T
        chunks_gen = extract_chunks(
            mel_out_postnet, ofx, mel_ids, tar_len_)  # mel_out_postnet: B F T

        # Shift/scale chunks before the critic — assumed to normalize the
        # mel range roughly into [-1, 1]; TODO confirm against the
        # dataset's mel statistics.
        chunks_org_ = (chunks_org.unsqueeze(1) + 4.5) / 2.5
        chunks_gen_ = (chunks_gen.unsqueeze(1) + 4.5) / 2.5

        # DISCRIMINATOR step (least-squares GAN): real -> 1, fake -> 0.
        # Generated chunks are detached so only the critic updates here.
        d_org, fmaps_org = critic(chunks_org_.requires_grad_(True))
        d_gen, _ = critic(chunks_gen_.detach())
        loss_d = 0.5*(d_org - 1).square().mean() + 0.5*d_gen.square().mean()

        critic.zero_grad()
        loss_d.backward()
        optimizer_d.step()

        # GENERATOR step: Tacotron2 loss + adversarial score + feature
        # matching against the critic's intermediate activations.
        loss, meta = criterion(mel_out, mel_out_postnet, mel_padded,
                               gate_out, gate_padded)
        d_gen2, fmaps_gen = critic(chunks_gen_)
        loss_score = (d_gen2 - 1).square().mean()
        loss_fmatch = calc_feature_match_loss(fmaps_gen, fmaps_org)
        loss += config.gan_loss_weight * loss_score
        loss += config.feat_loss_weight * loss_fmatch

        optimizer.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), config.grad_clip_thresh)
        optimizer.step()

        # LOGGING
        meta['score'] = loss_score.clone().detach()
        meta['fmatch'] = loss_fmatch.clone().detach()
        meta['loss'] = loss.clone().detach()
        print(f"loss: {loss.item()}, grad_norm: {grad_norm.item()}")
        # NOTE(review): config.learning_rate is logged here, but the
        # optimizers were built with g_lr / d_lr — verify this attribute
        # exists in the yaml config and is the intended value to log.
        writer.add_training_data(meta, grad_norm.item(),
                                 config.learning_rate, n_iter)

        # Rolling checkpoint plus periodic numbered backups.
        if n_iter % config.n_save_states_iter == 0:
            save_states('states.pth', model, critic,
                        optimizer, optimizer_d, n_iter,
                        epoch, None, config)
        if n_iter % config.n_save_backup_iter == 0 and n_iter > 0:
            save_states(f'states_{n_iter}.pth', model, critic,
                        optimizer, optimizer_d, n_iter,
                        epoch, None, config)
        n_iter += 1

    # VALIDATE
    # val_loss = validate(model, test_loader, writer, device, n_iter)
    # print(f"Validation loss: {val_loss}")

    # End-of-epoch checkpoint.
    save_states('states.pth', model, critic,
                optimizer, optimizer_d, n_iter,
                epoch, None, config)
| # %% | |