| import torch | |
| import numpy as np | |
| import torch.nn.functional as F | |
| class SLMAdversarialLoss(torch.nn.Module): | |
| def __init__(self, model, wl, sampler, min_len, max_len, batch_percentage=0.5, skip_update=10, sig=1.5): | |
| super(SLMAdversarialLoss, self).__init__() | |
| self.model = model | |
| self.wl = wl | |
| self.sampler = sampler | |
| self.min_len = min_len | |
| self.max_len = max_len | |
| self.batch_percentage = batch_percentage | |
| self.sig = sig | |
| self.skip_update = skip_update | |
| def forward(self, iters, y_rec_gt, y_rec_gt_pred, waves, mel_input_length, ref_text, ref_lengths, use_ind, s_trg, ref_s=None): | |
| text_mask = length_to_mask(ref_lengths).to(ref_text.device) | |
| bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int()) | |
| d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2) | |
| if use_ind and np.random.rand() < 0.5: | |
| s_preds = s_trg | |
| else: | |
| num_steps = np.random.randint(3, 5) | |
| if ref_s is not None: | |
| s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device), | |
| embedding=bert_dur, | |
| embedding_scale=1, | |
| features=ref_s, # reference from the same speaker as the embedding | |
| embedding_mask_proba=0.1, | |
| num_steps=num_steps).squeeze(1) | |
| else: | |
| s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device), | |
| embedding=bert_dur, | |
| embedding_scale=1, | |
| embedding_mask_proba=0.1, | |
| num_steps=num_steps).squeeze(1) | |
| s_dur = s_preds[:, 128:] | |
| s = s_preds[:, :128] | |
| d, _ = self.model.predictor(d_en, s_dur, | |
| ref_lengths, | |
| torch.randn(ref_lengths.shape[0], ref_lengths.max(), 2).to(ref_text.device), | |
| text_mask) | |
| bib = 0 | |
| output_lengths = [] | |
| attn_preds = [] | |
| # differentiable duration modeling | |
| for _s2s_pred, _text_length in zip(d, ref_lengths): | |
| _s2s_pred_org = _s2s_pred[:_text_length, :] | |
| _s2s_pred = torch.sigmoid(_s2s_pred_org) | |
| _dur_pred = _s2s_pred.sum(axis=-1) | |
| l = int(torch.round(_s2s_pred.sum()).item()) | |
| t = torch.arange(0, l).expand(l) | |
| t = torch.arange(0, l).unsqueeze(0).expand((len(_s2s_pred), l)).to(ref_text.device) | |
| loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2 | |
| h = torch.exp(-0.5 * torch.square(t - (l - loc.unsqueeze(-1))) / (self.sig)**2) | |
| out = torch.nn.functional.conv1d(_s2s_pred_org.unsqueeze(0), | |
| h.unsqueeze(1), | |
| padding=h.shape[-1] - 1, groups=int(_text_length))[..., :l] | |
| attn_preds.append(F.softmax(out.squeeze(), dim=0)) | |
| output_lengths.append(l) | |
| max_len = max(output_lengths) | |
| with torch.no_grad(): | |
| t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask) | |
| s2s_attn = torch.zeros(len(ref_lengths), int(ref_lengths.max()), max_len).to(ref_text.device) | |
| for bib in range(len(output_lengths)): | |
| s2s_attn[bib, :ref_lengths[bib], :output_lengths[bib]] = attn_preds[bib] | |
| asr_pred = t_en @ s2s_attn | |
| _, p_pred = self.model.predictor(d_en, s_dur, | |
| ref_lengths, | |
| s2s_attn, | |
| text_mask) | |
| mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2) | |
| mel_len = min(mel_len, self.max_len // 2) | |
| # get clips | |
| en = [] | |
| p_en = [] | |
| sp = [] | |
| F0_fakes = [] | |
| N_fakes = [] | |
| wav = [] | |
| for bib in range(len(output_lengths)): | |
| mel_length_pred = output_lengths[bib] | |
| mel_length_gt = int(mel_input_length[bib].item() / 2) | |
| if mel_length_gt <= mel_len or mel_length_pred <= mel_len: | |
| continue | |
| sp.append(s_preds[bib]) | |
| random_start = np.random.randint(0, mel_length_pred - mel_len) | |
| en.append(asr_pred[bib, :, random_start:random_start+mel_len]) | |
| p_en.append(p_pred[bib, :, random_start:random_start+mel_len]) | |
| # get ground truth clips | |
| random_start = np.random.randint(0, mel_length_gt - mel_len) | |
| y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300] | |
| wav.append(torch.from_numpy(y).to(ref_text.device)) | |
| if len(wav) >= self.batch_percentage * len(waves): # prevent OOM due to longer lengths | |
| break | |
| if len(sp) <= 1: | |
| return None | |
| sp = torch.stack(sp) | |
| wav = torch.stack(wav).float() | |
| en = torch.stack(en) | |
| p_en = torch.stack(p_en) | |
| F0_fake, N_fake = self.model.predictor.F0Ntrain(p_en, sp[:, 128:]) | |
| y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128]) | |
| # discriminator loss | |
| if (iters + 1) % self.skip_update == 0: | |
| if np.random.randint(0, 2) == 0: | |
| wav = y_rec_gt_pred | |
| use_rec = True | |
| else: | |
| use_rec = False | |
| crop_size = min(wav.size(-1), y_pred.size(-1)) | |
| if use_rec: # use reconstructed (shorter lengths), do length invariant regularization | |
| if wav.size(-1) > y_pred.size(-1): | |
| real_GP = wav[:, : , :crop_size] | |
| out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze()) | |
| out_org = self.wl.discriminator_forward(wav.detach().squeeze()) | |
| loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)]) | |
| if np.random.randint(0, 2) == 0: | |
| d_loss = self.wl.discriminator(real_GP.detach().squeeze(), y_pred.detach().squeeze()).mean() | |
| else: | |
| d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean() | |
| else: | |
| real_GP = y_pred[:, : , :crop_size] | |
| out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze()) | |
| out_org = self.wl.discriminator_forward(y_pred.detach().squeeze()) | |
| loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)]) | |
| if np.random.randint(0, 2) == 0: | |
| d_loss = self.wl.discriminator(wav.detach().squeeze(), real_GP.detach().squeeze()).mean() | |
| else: | |
| d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean() | |
| # regularization (ignore length variation) | |
| d_loss += loss_reg | |
| out_gt = self.wl.discriminator_forward(y_rec_gt.detach().squeeze()) | |
| out_rec = self.wl.discriminator_forward(y_rec_gt_pred.detach().squeeze()) | |
| # regularization (ignore reconstruction artifacts) | |
| d_loss += F.l1_loss(out_gt, out_rec) | |
| else: | |
| d_loss = self.wl.discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean() | |
| else: | |
| d_loss = 0 | |
| # generator loss | |
| gen_loss = self.wl.generator(y_pred.squeeze()) | |
| gen_loss = gen_loss.mean() | |
| return d_loss, gen_loss, y_pred.detach().cpu().numpy() | |
| def length_to_mask(lengths): | |
| mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) | |
| mask = torch.gt(mask+1, lengths.unsqueeze(1)) | |
| return mask | |
| # import torch | |
| # import numpy as np | |
| # import torch.nn.functional as F | |
| # from accelerate import Accelerator, DistributedDataParallelKwargs | |
| # from accelerate.utils import tqdm, ProjectConfiguration | |
| # class SLMAdversarialLoss(torch.nn.Module): | |
| # def __init__(self, model, wl, sampler, min_len, max_len, batch_percentage=0.5, skip_update=10, sig=1.5): | |
| # super(SLMAdversarialLoss, self).__init__() | |
| # self.model = model | |
| # self.wl = wl | |
| # self.sampler = sampler | |
| # self.min_len = min_len | |
| # self.max_len = max_len | |
| # self.batch_percentage = batch_percentage | |
| # self.sig = sig | |
| # self.skip_update = skip_update | |
| # def forward(self, iters, accelerator, y_rec_gt, y_rec_gt_pred, waves, mel_input_length, ref_text, ref_lengths, use_ind, s_trg, ref_s=None): | |
| # text_mask = length_to_mask(ref_lengths).to(ref_text.device) | |
| # bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int()) | |
| # d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2) | |
| # if use_ind and np.random.rand() < 0.5: | |
| # s_preds = s_trg | |
| # else: | |
| # num_steps = np.random.randint(3, 5) | |
| # if ref_s is not None: | |
| # s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device), | |
| # embedding=bert_dur, | |
| # embedding_scale=1, | |
| # features=ref_s, # reference from the same speaker as the embedding | |
| # embedding_mask_proba=0.1, | |
| # num_steps=num_steps).squeeze(1) | |
| # else: | |
| # s_preds = self.sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device), | |
| # embedding=bert_dur, | |
| # embedding_scale=1, | |
| # embedding_mask_proba=0.1, | |
| # num_steps=num_steps).squeeze(1) | |
| # s_dur = s_preds[:, 128:] | |
| # s = s_preds[:, :128] | |
| # d, _ = self.model.predictor(d_en, s_dur, | |
| # ref_lengths, | |
| # torch.randn(ref_lengths.shape[0], ref_lengths.max(), 2).to(ref_text.device), | |
| # text_mask) | |
| # bib = 0 | |
| # output_lengths = [] | |
| # attn_preds = [] | |
| # # differentiable duration modeling | |
| # for _s2s_pred, _text_length in zip(d, ref_lengths): | |
| # _s2s_pred_org = _s2s_pred[:_text_length, :] | |
| # _s2s_pred = torch.sigmoid(_s2s_pred_org) | |
| # _dur_pred = _s2s_pred.sum(axis=-1) | |
| # l = int(torch.round(_s2s_pred.sum()).item()) | |
| # t = torch.arange(0, l).expand(l) | |
| # t = torch.arange(0, l).unsqueeze(0).expand((len(_s2s_pred), l)).to(ref_text.device) | |
| # loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2 | |
| # h = torch.exp(-0.5 * torch.square(t - (l - loc.unsqueeze(-1))) / (self.sig)**2) | |
| # out = torch.nn.functional.conv1d(_s2s_pred_org.unsqueeze(0), | |
| # h.unsqueeze(1), | |
| # padding=h.shape[-1] - 1, groups=int(_text_length))[..., :l] | |
| # attn_preds.append(F.softmax(out.squeeze(), dim=0)) | |
| # output_lengths.append(l) | |
| # max_len = max(output_lengths) | |
| # with torch.no_grad(): | |
| # t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask) | |
| # s2s_attn = torch.zeros(len(ref_lengths), int(ref_lengths.max()), max_len).to(ref_text.device) | |
| # for bib in range(len(output_lengths)): | |
| # s2s_attn[bib, :ref_lengths[bib], :output_lengths[bib]] = attn_preds[bib] | |
| # asr_pred = t_en @ s2s_attn | |
| # _, p_pred = self.model.predictor(d_en, s_dur, | |
| # ref_lengths, | |
| # s2s_attn, | |
| # text_mask) | |
| # mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2) | |
| # mel_len = min(mel_len, self.max_len // 2) | |
| # # get clips | |
| # en = [] | |
| # p_en = [] | |
| # sp = [] | |
| # F0_fakes = [] | |
| # N_fakes = [] | |
| # wav = [] | |
| # for bib in range(len(output_lengths)): | |
| # mel_length_pred = output_lengths[bib] | |
| # mel_length_gt = int(mel_input_length[bib].item() / 2) | |
| # if mel_length_gt <= mel_len or mel_length_pred <= mel_len: | |
| # continue | |
| # sp.append(s_preds[bib]) | |
| # random_start = np.random.randint(0, mel_length_pred - mel_len) | |
| # en.append(asr_pred[bib, :, random_start:random_start+mel_len]) | |
| # p_en.append(p_pred[bib, :, random_start:random_start+mel_len]) | |
| # # get ground truth clips | |
| # random_start = np.random.randint(0, mel_length_gt - mel_len) | |
| # y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300] | |
| # wav.append(torch.from_numpy(y).to(ref_text.device)) | |
| # if len(wav) >= self.batch_percentage * len(waves): # prevent OOM due to longer lengths | |
| # break | |
| # # global_min_batch = accelerator.gather(torch.tensor([len(wav)], device=ref_text.device)).min().item() | |
| # # if global_min_batch <= 1: | |
| # # raise ValueError("skip slmadv") | |
| # if len(sp) <= 1: | |
| # return None | |
| # sp = torch.stack(sp) | |
| # wav = torch.stack(wav).float() | |
| # en = torch.stack(en) | |
| # p_en = torch.stack(p_en) | |
| # F0_fake, N_fake = self.model.predictor(texts=p_en, style=sp[:, 128:], f0=True) | |
| # y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128]) | |
| # # discriminator loss | |
| # if (iters + 1) % self.skip_update == 0: | |
| # if np.random.randint(0, 2) == 0: | |
| # wav = y_rec_gt_pred | |
| # use_rec = True | |
| # else: | |
| # use_rec = False | |
| # crop_size = min(wav.size(-1), y_pred.size(-1)) | |
| # if use_rec: # use reconstructed (shorter lengths), do length invariant regularization | |
| # if wav.size(-1) > y_pred.size(-1): | |
| # real_GP = wav[:, : , :crop_size] | |
| # out_crop = self.wl(wav = real_GP.detach().squeeze(),y_rec=None, discriminator_forward=True) | |
| # out_org = self.wl(wav = wav.detach().squeeze(),y_rec=None, discriminator_forward=True) | |
| # loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)]) | |
| # if np.random.randint(0, 2) == 0: | |
| # d_loss = self.wl(wav = real_GP.detach().squeeze(),y_rec= y_pred.detach().squeeze(), discriminator=True).mean() | |
| # else: | |
| # d_loss = self.wl(wav = wav.detach().squeeze(), y_rec = y_pred.detach().squeeze(), discriminator=True).mean() | |
| # else: | |
| # real_GP = y_pred[:, : , :crop_size] | |
| # out_crop = self.wl(wav = real_GP.detach().squeeze(), y_rec=None, discriminator_forward=True) | |
| # out_org = self.wl(wav = y_pred.detach().squeeze(),y_rec=None, discriminator_forward=True) | |
| # loss_reg = F.l1_loss(out_crop, out_org[..., :out_crop.size(-1)]) | |
| # if np.random.randint(0, 2) == 0: | |
| # d_loss = self.wl(wav = wav.detach().squeeze(), y_rec = real_GP.detach().squeeze(), discriminator=True ).mean() | |
| # else: | |
| # d_loss = self.wl(wav = wav.detach().squeeze(), y_rec = y_pred.detach().squeeze(), discriminator=True).mean() | |
| # # regularization (ignore length variation) | |
| # d_loss += loss_reg | |
| # out_gt = self.wl(wav = y_rec_gt.detach().squeeze(),y_rec=None, discriminator_forward=True) | |
| # out_rec = self.wl(wav = y_rec_gt_pred.detach().squeeze(), y_rec=None, discriminator_forward=True) | |
| # # regularization (ignore reconstruction artifacts) | |
| # d_loss += F.l1_loss(out_gt, out_rec) | |
| # else: | |
| # d_loss = self.wl(wav = wav.detach().squeeze(),y_rec= y_pred.detach().squeeze(), discriminator=True).mean() | |
| # else: | |
| # d_loss = 0 | |
| # # generator loss | |
| # gen_loss = self.wl(wav = None, y_rec = y_pred.squeeze(), generator=True) | |
| # gen_loss = gen_loss.mean() | |
| # return d_loss, gen_loss, y_pred.detach().cpu().numpy() | |
| # def length_to_mask(lengths): | |
| # mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) | |
| # mask = torch.gt(mask+1, lengths.unsqueeze(1)) | |
| # return mask | |