import torch
import numpy as np
import torch.nn.functional as F
class SLMAdversarialLoss(torch.nn.Module):
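    """Adversarial loss against a speech-language-model (SLM) discriminator.

    Synthesises speech from sampled (or ground-truth) style vectors using a
    fully differentiable duration model, then scores random clips of the
    output against real audio through ``wl`` (a wrapper exposing
    ``discriminator`` and ``generator`` losses, e.g. WavLM-based).

    Args:
        model: module bundle exposing ``bert``, ``bert_encoder``,
            ``predictor``, ``text_encoder`` and ``decoder``.
        wl: adversarial loss wrapper (``discriminator`` / ``generator``).
        sampler: diffusion sampler producing style vectors from BERT
            embeddings.
        min_len / max_len: clip-length bounds, in mel frames.
        batch_percentage: cap on the fraction of the batch turned into
            clips, to limit memory use.
        skip_update: compute the discriminator loss only every
            ``skip_update`` iterations.
        sig: std-dev of the Gaussian kernels used to turn durations into a
            soft alignment.
    """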
def __init__(
self,
model,
wl,
sampler,
min_len,
max_len,
batch_percentage=0.5,
skip_update=10,
sig=1.5,
):
super().__init__()
self.model = model
self.wl = wl
self.sampler = sampler
self.min_len = min_len
self.max_len = max_len
self.batch_percentage = batch_percentage
self.sig = sig
self.skip_update = skip_update
# ------------------------------------------------------------------ #
def forward(
self,
iters,
y_rec_gt,
y_rec_gt_pred,
waves,
mel_input_length,
ref_text,
ref_lengths,
use_ind,
s_trg,
ref_s=None,
):
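        """One SLM-adversarial training step.

        Synthesises audio from ``ref_text`` with sampled (or target) styles
        and differentiable durations, crops random clips, and scores them
        against ground-truth ``waves`` clips. Returns ``(d_loss, gen_loss,
        y_pred)`` with ``y_pred`` as a numpy array, or ``None`` when fewer
        than two utterances yield clips long enough. ``y_rec_gt`` and
        ``y_rec_gt_pred`` are accepted but unused in this variant.
        """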
# ---- full-width mask (matches ref_text.size(1)) ----------------
seq_len = ref_text.size(1)
text_mask = (
torch.arange(seq_len, device=ref_text.device)
.unsqueeze(0)
>= ref_lengths.unsqueeze(1)
) # shape [B, seq_len]
bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)
# ----- style / prosody sampling ---------------------------------
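        # with probability 0.5 (when ``use_ind`` allows it) reuse the target
        # style; otherwise draw one with a few diffusion steps, optionally
        # conditioned on a same-speaker reference ``ref_s``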
if use_ind and np.random.rand() < 0.5:
s_preds = s_trg
else:
num_steps = np.random.randint(3, 5)
noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device)
sampler_kwargs = dict(
noise=noise,
embedding=bert_dur,
embedding_scale=1,
embedding_mask_proba=0.1,
num_steps=num_steps,
)
if ref_s is not None:
sampler_kwargs["features"] = ref_s
s_preds = self.sampler(**sampler_kwargs).squeeze(1)
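        # the 256-dim style vector splits into a prosody half (s_dur) and an
        # acoustic half (s); the same split feeds F0Ntrain / decoder below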
s_dur, s = s_preds[:, 128:], s_preds[:, :128]
        # random alignment placeholder; width must match the *padded* token width
        rand_align = torch.randn(ref_text.size(0), seq_len, 2, device=ref_text.device)
d, _ = self.model.predictor(
d_en, s_dur, ref_lengths,
rand_align,
text_mask,
)
# ----- differentiable duration modelling -----------------------
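        # Each token's duration is a sum of sigmoids; a Gaussian of width
        # ``sig`` is placed at the token's cumulative end-point, and a grouped
        # conv1d with these kernels turns the raw predictions into a soft,
        # differentiable token-to-frame attention map.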
attn_preds, output_lengths = [], []
        for _s2s_pred, _len in zip(d, ref_lengths):
            _len = int(_len)
            _s2s_pred_org = _s2s_pred[:_len]              # drop padded tokens
            _s2s_pred_sig = torch.sigmoid(_s2s_pred_org)
            _dur_pred = _s2s_pred_sig.sum(dim=-1)         # frames per token
            # total output length, rounded to a whole frame count
            out_len = int(torch.round(_s2s_pred_sig.sum()).item())
            t = (
                torch.arange(out_len, device=ref_text.device)
                .unsqueeze(0)
                .expand(_len, out_len)
            )
            # Gaussian centres from cumulative durations (token midpoints)
            loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2
            h = torch.exp(-0.5 * (t - (out_len - loc.unsqueeze(-1))) ** 2 / (self.sig**2))
            out = F.conv1d(
                _s2s_pred_org.unsqueeze(0),
                h.unsqueeze(1),
                padding=h.size(-1) - 1,
                groups=_len,
            )[..., :out_len]
            # squeeze(0), not squeeze(): keep the token dim when _len == 1
            attn_preds.append(F.softmax(out.squeeze(0), dim=0))
            output_lengths.append(out_len)
max_len = max(output_lengths)
# ----- build full-width alignment matrix -----------------------
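        # scatter each sample's soft alignment into one padded
        # [B, T_text, T_mel] matrix so text encodings can be expanded to
        # frame rate with a single batched matmul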
with torch.no_grad():
t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)
        s2s_attn = torch.zeros(
            len(ref_lengths), seq_len, max_len, device=ref_text.device
        )
for bib, (attn, L) in enumerate(zip(attn_preds, output_lengths)):
s2s_attn[bib, : ref_lengths[bib], :L] = attn
asr_pred = t_en @ s2s_attn
_, p_pred = self.model.predictor(
d_en, s_dur, ref_lengths, s2s_attn, text_mask
)
# ----- clip extraction -----------------------------------------
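        # One random clip per utterance from prediction and ground truth.
        # The ``* 2 * 300`` indexing assumes alignment frames at half the mel
        # rate and a mel hop of 300 audio samples.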
mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
mel_len = min(mel_len, self.max_len // 2)
en, p_en, sp, wav = [], [], [], []
for bib, L_pred in enumerate(output_lengths):
L_gt = int(mel_input_length[bib].item() / 2)
if L_gt <= mel_len or L_pred <= mel_len:
continue
sp.append(s_preds[bib])
start = np.random.randint(0, L_pred - mel_len)
en.append(asr_pred[bib, :, start : start + mel_len])
p_en.append(p_pred[bib, :, start : start + mel_len])
start_gt = np.random.randint(0, L_gt - mel_len)
y = waves[bib][(start_gt * 2) * 300 : ((start_gt + mel_len) * 2) * 300]
wav.append(torch.from_numpy(y).to(ref_text.device))
if len(wav) >= self.batch_percentage * len(waves):
break
if len(sp) <= 1:
return None
sp = torch.stack(sp)
wav = torch.stack(wav).float()
en = torch.stack(en)
p_en = torch.stack(p_en)
F0_fake, N_fake = self.model.predictor.F0Ntrain(p_en, sp[:, 128:])
y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128])
# -------------- adversarial losses -----------------------------
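        # the discriminator update is throttled to every ``skip_update``
        # iterations; otherwise a zero placeholder is returned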
if (iters + 1) % self.skip_update == 0:
d_loss = self.wl.discriminator(wav.squeeze(), y_pred.detach().squeeze()).mean()
else:
d_loss = 0
gen_loss = self.wl.generator(y_pred.squeeze()).mean()
return d_loss, gen_loss, y_pred.detach().cpu().numpy()
# ------------------------------------------------------------------ #
def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
"""Classic length mask: 1 → PAD, 0 → real token."""
max_len = lengths.max()
mask = (
torch.arange(max_len, device=lengths.device)
.unsqueeze(0)
.expand(lengths.size(0), -1)
)
    return mask >= lengths.unsqueeze(1)
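# ------------------------------------------------------------------ #
# Illustrative usage (a sketch only; ``model``, ``wl`` and ``sampler`` are
# the pre-built bundles this module expects, and the clip bounds shown are
# plausible values, not prescribed ones):
#
#     slm_adv = SLMAdversarialLoss(model, wl, sampler,
#                                  min_len=400, max_len=500,
#                                  batch_percentage=0.5, skip_update=10)
#     out = slm_adv(iters, y_rec_gt, y_rec_gt_pred, waves,
#                   mel_input_length, ref_text, ref_lengths,
#                   use_ind=True, s_trg=s_trg, ref_s=ref_s)
#     if out is not None:
#         d_loss, gen_loss, y_pred = out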