| | import torch |
| | import numpy as np |
| | import torch.nn.functional as F |
| |
|
| |
|
class SLMAdversarialLoss(torch.nn.Module):
    """Adversarial loss driven by a speech-language-model discriminator.

    Synthesizes speech from predicted (or diffusion-sampled) style vectors and
    scores random crops of it against matching ground-truth waveform crops
    using ``wl`` (a wavLM-style discriminator/generator loss wrapper).

    Parameters
    ----------
    model : object
        Container exposing ``bert``, ``bert_encoder``, ``text_encoder``,
        ``predictor`` (with ``F0Ntrain``) and ``decoder`` submodules.
    wl : object
        Loss wrapper exposing ``discriminator(real, fake)`` and
        ``generator(fake)``.
    sampler : callable
        Diffusion sampler producing style vectors from noise + text embedding.
    min_len, max_len : int
        Bounds (in mel frames) for the random crop length.
    batch_percentage : float
        Fraction of the batch to keep for the adversarial pass.
    skip_update : int
        The discriminator loss is computed only every ``skip_update`` iters.
    sig : float
        Std-dev of the Gaussian window used to build the soft alignment.
    """

    def __init__(
        self,
        model,
        wl,
        sampler,
        min_len,
        max_len,
        batch_percentage=0.5,
        skip_update=10,
        sig=1.5,
    ):
        super().__init__()
        self.model = model
        self.wl = wl
        self.sampler = sampler

        self.min_len = min_len
        self.max_len = max_len
        self.batch_percentage = batch_percentage

        self.sig = sig
        self.skip_update = skip_update

    def forward(
        self,
        iters,
        y_rec_gt,
        y_rec_gt_pred,
        waves,
        mel_input_length,
        ref_text,
        ref_lengths,
        use_ind,
        s_trg,
        ref_s=None,
    ):
        """Return ``(d_loss, gen_loss, y_pred)`` or ``None`` if the batch is
        too small after length filtering.

        NOTE(review): ``y_rec_gt`` and ``y_rec_gt_pred`` are accepted but never
        used in this version — confirm against callers before removing.
        ``waves`` is assumed to be a sequence of 1-D numpy arrays at the
        decoder hop product of 600 samples per mel frame (the ``* 2 * 300``
        slicing below) — TODO confirm sample rate / hop.
        """
        # True at padded text positions (mask-out convention: 1 -> PAD).
        seq_len = ref_text.size(1)
        text_mask = (
            torch.arange(seq_len, device=ref_text.device)
            .unsqueeze(0)
            >= ref_lengths.unsqueeze(1)
        )

        # Text features: BERT over tokens, then project for the duration path.
        bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
        d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)

        # Style vectors: half the time (when allowed) use the ground-truth
        # target style, otherwise sample from the diffusion sampler with a
        # small random number of steps.
        if use_ind and np.random.rand() < 0.5:
            s_preds = s_trg
        else:
            num_steps = np.random.randint(3, 5)
            noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device)
            sampler_kwargs = dict(
                noise=noise,
                embedding=bert_dur,
                embedding_scale=1,
                embedding_mask_proba=0.1,
                num_steps=num_steps,
            )
            if ref_s is not None:
                sampler_kwargs["features"] = ref_s
            s_preds = self.sampler(**sampler_kwargs).squeeze(1)

        # Split the style vector: first 128 dims -> acoustic style `s`,
        # remainder -> prosody/duration style `s_dur`.
        s_dur, s = s_preds[:, 128:], s_preds[:, :128]

        # First predictor pass with a random alignment just to obtain the
        # duration logits `d`; the real alignment is built below.
        seq_len = ref_text.size(1)
        rand_align = torch.randn(ref_text.size(0), seq_len, 2, device=ref_text.device)

        d, _ = self.model.predictor(
            d_en, s_dur, ref_lengths,
            rand_align,
            text_mask,
        )

        # Build a differentiable soft monotonic alignment per utterance:
        # a Gaussian window (width self.sig) centered at each token's
        # cumulative duration midpoint, applied via a grouped conv1d.
        attn_preds, output_lengths = [], []
        for _s2s_pred, _len in zip(d, ref_lengths):
            _s2s_pred_org = _s2s_pred[: _len]
            _s2s_pred_sig = torch.sigmoid(_s2s_pred_org)
            _dur_pred = _s2s_pred_sig.sum(dim=-1)

            # l: total predicted output length (frames) for this utterance.
            l = int(torch.round(_s2s_pred_sig.sum()).item())
            t = torch.arange(l, device=ref_text.device).unsqueeze(0).expand(_len, l)
            loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2
            h = torch.exp(-0.5 * (t - (l - loc.unsqueeze(-1))) ** 2 / (self.sig**2))

            # One Gaussian kernel per token (groups=_len), cropped back to l.
            out = F.conv1d(
                _s2s_pred_org.unsqueeze(0),
                h.unsqueeze(1),
                padding=h.size(-1) - 1,
                groups=int(_len),
            )[..., :l]
            attn_preds.append(F.softmax(out.squeeze(), dim=0))
            output_lengths.append(l)

        max_len = max(output_lengths)

        # Text encoding is detached from the adversarial gradient path.
        with torch.no_grad():
            t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)

        # Pad the per-utterance alignments into one (B, text, frames) tensor.
        seq_len = ref_text.size(1)
        s2s_attn = torch.zeros(
            len(ref_lengths), seq_len, max_len, device=ref_text.device
        )
        for bib, (attn, L) in enumerate(zip(attn_preds, output_lengths)):
            s2s_attn[bib, : ref_lengths[bib], :L] = attn

        # Expand text features to frame rate through the soft alignment.
        asr_pred = t_en @ s2s_attn

        # Second predictor pass, now with the constructed alignment, to get
        # the frame-level prosody features.
        _, p_pred = self.model.predictor(
            d_en, s_dur, ref_lengths, s2s_attn, text_mask
        )

        # Crop length in half-frames, clamped to [min_len/2, max_len/2].
        mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
        mel_len = min(mel_len, self.max_len // 2)

        # Collect random crops of predicted features and matching (but
        # independently positioned) ground-truth waveform crops.
        en, p_en, sp, wav = [], [], [], []
        for bib, L_pred in enumerate(output_lengths):
            L_gt = int(mel_input_length[bib].item() / 2)
            if L_gt <= mel_len or L_pred <= mel_len:
                continue  # utterance too short for the crop length

            sp.append(s_preds[bib])

            start = np.random.randint(0, L_pred - mel_len)
            en.append(asr_pred[bib, :, start : start + mel_len])
            p_en.append(p_pred[bib, :, start : start + mel_len])

            # 1 half-frame = 2 mel frames = 2 * 300 waveform samples
            # (presumably hop length 300 — TODO confirm).
            start_gt = np.random.randint(0, L_gt - mel_len)
            y = waves[bib][(start_gt * 2) * 300 : ((start_gt + mel_len) * 2) * 300]
            wav.append(torch.from_numpy(y).to(ref_text.device))

            # Stop once the requested fraction of the batch is gathered.
            if len(wav) >= self.batch_percentage * len(waves):
                break

        # Need at least two samples to form a meaningful adversarial batch.
        if len(sp) <= 1:
            return None

        sp = torch.stack(sp)
        wav = torch.stack(wav).float()
        en = torch.stack(en)
        p_en = torch.stack(p_en)

        # Synthesize waveform from the cropped features and sampled styles.
        F0_fake, N_fake = self.model.predictor.F0Ntrain(p_en, sp[:, 128:])
        y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128])

        # Discriminator loss only every `skip_update` iterations; otherwise
        # a plain 0 (int) is returned in its place.
        if (iters + 1) % self.skip_update == 0:
            d_loss = self.wl.discriminator(wav.squeeze(), y_pred.detach().squeeze()).mean()
        else:
            d_loss = 0

        gen_loss = self.wl.generator(y_pred.squeeze()).mean()
        return d_loss, gen_loss, y_pred.detach().cpu().numpy()
| |
|
| |
|
| | |
def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Build a padding mask from sequence lengths.

    Each row i has ``lengths[i]`` leading ``False`` entries (real tokens)
    followed by ``True`` entries (padding), out to ``lengths.max()``.
    """
    positions = torch.arange(lengths.max(), device=lengths.device)
    grid = positions.unsqueeze(0).expand(lengths.size(0), -1)
    return grid >= lengths.unsqueeze(1)