import os
import time
import pickle
import datetime
import itertools

import numpy as np
import torch
import torch.nn.functional as F

from onmt_modules.misc import sequence_mask
from model_autopst import Generator_1 as Predictor
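# Generator_1 is AutoPST's encoder-decoder network; in this script it is
# trained to predict spectrogram frames plus a stop gate from cepstral
# features, synchronized code counts, and a speaker embedding.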
class Solver(object):

    def __init__(self, data_loader, config, hparams):
        """Initialize configurations."""
        self.data_loader = data_loader
        self.hparams = hparams
        self.gate_threshold = hparams.gate_threshold

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:{}'.format(config.device_id) if self.use_cuda else 'cpu')

        self.num_iters = config.num_iters
        self.log_step = config.log_step

        # Build the model
        self.build_model()
    def build_model(self):
        self.P = Predictor(self.hparams)
        self.optimizer = torch.optim.Adam(self.P.parameters(), lr=0.0001, betas=(0.9, 0.999))
        self.P.to(self.device)
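        # BCEWithLogitsLoss fuses a sigmoid with binary cross-entropy, so the
        # stop-gate head can emit raw logits and remain numerically stable.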
        self.BCELoss = torch.nn.BCEWithLogitsLoss().to(self.device)
    def train(self):
        # Set data loader
        data_loader = self.data_loader
        data_iter = iter(data_loader)

        # Print logs in specified order
        keys = ['P/loss_tx2sp', 'P/loss_stop_sp']

        # Start training.
        print('Start training...')
        start_time = time.time()

        for i in range(self.num_iters):

            try:
                sp_real, cep_real, cd_real, _, num_rep_sync, len_real, _, len_short_sync, spk_emb = next(data_iter)
            except StopIteration:
                # The loader is exhausted; restart it and fetch a fresh batch.
                data_iter = iter(data_loader)
                sp_real, cep_real, cd_real, _, num_rep_sync, len_real, _, len_short_sync, spk_emb = next(data_iter)
            sp_real = sp_real.to(self.device)
            cep_real = cep_real.to(self.device)
            cd_real = cd_real.to(self.device)
            len_real = len_real.to(self.device)
            spk_emb = spk_emb.to(self.device)
            num_rep_sync = num_rep_sync.to(self.device)
            len_short_sync = len_short_sync.to(self.device)
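            # sequence_mask(lengths, max_len) is True at valid time steps, so
            # ~sequence_mask is True exactly at padded frames; mask_sp_real
            # marks padding and doubles as the stop-gate target further down.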
            # real spect masks
            mask_sp_real = ~sequence_mask(len_real, sp_real.size(1))
            mask_long = (~mask_sp_real).float()

            len_real_mask = torch.min(len_real + 10,
                                      torch.full_like(len_real, sp_real.size(1)))
            loss_tx2sp_mask = sequence_mask(len_real_mask, sp_real.size(1)).float().unsqueeze(-1)
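            # The reconstruction mask extends 10 frames past the true length
            # (capped at the padded size), so the decoder is also supervised on
            # the frames just after the utterance ends, presumably to help it
            # learn to terminate cleanly.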
            # text input masks
            codes_mask = sequence_mask(len_short_sync, num_rep_sync.size(1)).float()
            # =================================================================================== #
            #                                     2. Train                                         #
            # =================================================================================== #

            self.P = self.P.train()

            sp_real_sft = torch.zeros_like(sp_real)
            sp_real_sft[:, 1:, :] = sp_real[:, :-1, :]
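            # sp_real_sft is the target spectrogram shifted right by one frame
            # (an all-zero frame as the start symbol), i.e. teacher forcing:
            # at step t the decoder sees the ground-truth frame t-1.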
            spect_pred, stop_pred_sp = self.P(cep_real.transpose(2, 1),
                                              mask_long,
                                              codes_mask,
                                              num_rep_sync,
                                              len_short_sync + 1,
                                              sp_real_sft.transpose(1, 0),
                                              len_real + 1,
                                              spk_emb)
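            # The decoder outputs are time-major, judging by the permute and
            # transpose below: (T, B, dim) for spect_pred and (T, B, 1) for the
            # stop logits, versus batch-major targets. The stop target is
            # mask_sp_real, which is 1 at padded frames, so the gate learns to
            # fire once the utterance is over.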
            loss_tx2sp = (F.mse_loss(spect_pred.permute(1, 0, 2), sp_real, reduction='none')
                          * loss_tx2sp_mask).sum() / loss_tx2sp_mask.sum()
            loss_stop_sp = self.BCELoss(stop_pred_sp.squeeze(-1).t(), mask_sp_real.float())

            loss_total = loss_tx2sp + loss_stop_sp
            # Backward and optimize
            self.optimizer.zero_grad()
            loss_total.backward()
            self.optimizer.step()

            # Logging
            loss = {}
            loss['P/loss_tx2sp'] = loss_tx2sp.item()
            loss['P/loss_stop_sp'] = loss_stop_sp.item()
            # =================================================================================== #
            #                                 4. Miscellaneous                                     #
            # =================================================================================== #

            # Print out training information
            if (i+1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
                for tag in keys:
                    log += ", {}: {:.8f}".format(tag, loss[tag])
                print(log)
            # Save model checkpoints.
            if (i+1) % 10000 == 0:
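                # Make sure the checkpoint directory exists before saving (a
                # small safeguard; the original script assumes ./assets exists).
                os.makedirs('./assets', exist_ok=True)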
                torch.save({'model': self.P.state_dict(),
                            'optimizer': self.optimizer.state_dict()},
                           f'./assets/{i+1}-A.ckpt')
                print('Saved model checkpoints into assets ...')