| """ |
| ein notation: |
| b - batch |
| n - sequence |
| nt - text sequence |
| nw - raw wave length |
| d - dimension |
| """ |
|
|
| from __future__ import annotations |
|
|
import random as _random
from collections import OrderedDict
from typing import Callable, Dict
| import math |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
| import torchaudio |
from torch import Tensor, nn
| from torch.nn.utils.rnn import pad_sequence |
| from torchdiffeq import odeint |
|
|
from lemas_tts.model.modules import AccentClassifier, MelSpec, MIEsitmator, grad_reverse
| from lemas_tts.model.backbones.ecapa_tdnn import ECAPA_TDNN |
| from lemas_tts.model.backbones.prosody_encoder import ProsodyEncoder, extract_fbank_16k |
| from lemas_tts.model.utils import ( |
| default, |
| exists, |
| lens_to_mask, |
| list_str_to_idx, |
| list_str_to_tensor, |
| mask_from_frac_lengths, |
| ) |
|
|
|
|
| def clip_and_shuffle(mel, mel_len, sample_rate=24000, hop_length=256, ratio=None): |
| """ |
| Randomly clip a mel-spectrogram segment and shuffle 1-second chunks to |
| create an accent-invariant conditioning segment. |
| |
| This is a inference-time utility used by the accent GRL path. |
| |
| Args: |
| mel: [n_mels, T] |
| mel_len: int, original mel length (T) |
| """ |
| frames_per_second = int(sample_rate / hop_length) |
|
|
| |
| total_len = mel_len |
| if not ratio: |
| seg_len = _random.randint(int(0.25 * total_len), int(0.75 * total_len)) |
| else: |
| seg_len = int(total_len * ratio) |
| start = _random.randint(0, max(0, total_len - seg_len)) |
| mel_seg = mel[:, start : start + seg_len] |
|
|
| |
| n_chunks = (mel_seg.size(1) + frames_per_second - 1) // frames_per_second |
| chunks = [] |
| for i in range(n_chunks): |
| chunk = mel_seg[:, i * frames_per_second : (i + 1) * frames_per_second] |
| chunks.append(chunk) |
|
|
| |
| _random.shuffle(chunks) |
| shuffled_mel = torch.cat(chunks, dim=1) |
|
|
| |
| if shuffled_mel.size(1) < total_len: |
| repeat_chunks = [] |
| while sum(c.size(1) for c in repeat_chunks) < total_len: |
| repeat_chunks.append(_random.choice(chunks)) |
| shuffled_mel = torch.cat([shuffled_mel] + repeat_chunks, dim=1) |
|
|
| |
| shuffled_mel = shuffled_mel[:, :total_len] |
| assert shuffled_mel.shape == mel.shape, f"shuffled_mel.shape != mel.shape: {shuffled_mel.shape} != {mel.shape}" |
|
|
| return shuffled_mel |
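
# A minimal usage sketch (shapes assumed from the docstring above): feed a
# single-utterance mel of shape [n_mels, T] and get back an equally long,
# chunk-shuffled copy.
#
#     mel = torch.randn(100, 720)                      # hypothetical [n_mels, T]
#     shuffled = clip_and_shuffle(mel, mel.size(1), ratio=0.5)
#     assert shuffled.shape == mel.shape               # order scrambled, length kept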
|
|
| class CFM(nn.Module): |
| def __init__( |
| self, |
| transformer: nn.Module, |
| sigma=0.0, |
        odeint_kwargs: dict = dict(method="euler"),
| audio_drop_prob=0.3, |
| text_drop_prob=0.1, |
| num_channels=None, |
| mel_spec_module: nn.Module | None = None, |
| mel_spec_kwargs: dict = dict(), |
| frac_lengths_mask: tuple[float, float] = (0.7, 1.0), |
        vocab_char_map: dict[str, int] | None = None,
| use_ctc_loss: bool = False, |
| use_spk_enc: bool = False, |
| use_prosody_encoder: bool = False, |
| prosody_cfg_path: str | None = None, |
| prosody_ckpt_path: str | None = None, |
| ): |
| super().__init__() |
|
|
| self.frac_lengths_mask = frac_lengths_mask |
|
|
| |
| self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) |
| num_channels = default(num_channels, self.mel_spec.n_mel_channels) |
| self.num_channels = num_channels |
|
|
| |
| self.audio_drop_prob = audio_drop_prob |
| self.text_drop_prob = text_drop_prob |
|
|
| |
| self.transformer = transformer |
| dim = transformer.dim |
| self.dim = dim |
|
|
| |
| self.sigma = sigma |
|
|
| |
| self.odeint_kwargs = odeint_kwargs |
|
|
| |
| self.vocab_char_map = vocab_char_map |
|
|
| |
| self.use_prosody_encoder = ( |
| use_prosody_encoder and prosody_cfg_path is not None and prosody_ckpt_path is not None |
| ) |
| if self.use_prosody_encoder: |
| cfg_path = Path(prosody_cfg_path) |
| ckpt_path = Path(prosody_ckpt_path) |
| self.prosody_encoder = ProsodyEncoder(cfg_path, ckpt_path, freeze=True) |
| |
| self.prosody_to_mel = nn.Linear(512, self.num_channels) |
| self.prosody_dropout = nn.Dropout(p=0.2) |
| else: |
| self.prosody_encoder = None |
| |
| |
| self.use_spk_enc = use_spk_enc |
| if use_spk_enc: |
| self.speaker_encoder = ECAPA_TDNN( |
| self.num_channels, |
| self.dim, |
| channels=[512, 512, 512, 512, 1536], |
| kernel_sizes=[5, 3, 3, 3, 1], |
| dilations=[1, 2, 3, 4, 1], |
| attention_channels=128, |
| res2net_scale=4, |
| se_channels=128, |
| global_context=True, |
| batch_norm=True, |
| ) |
| |
|
|
        self.use_ctc_loss = use_ctc_loss
        if use_ctc_loss:
            assert exists(self.vocab_char_map), "use_ctc_loss requires a vocab_char_map"
            self.ctc = MIEsitmator(len(self.vocab_char_map), self.num_channels, self.dim, dropout=self.text_drop_prob)
|
|
| self.accent_classifier = AccentClassifier(input_dim=self.num_channels, hidden_dim=self.dim, num_accents=12) |
| self.accent_criterion = nn.CrossEntropyLoss() |
|
|
| def load_partial_weights(self, model: nn.Module, |
| ckpt_path: str, |
| device="cpu", |
| verbose=True) -> int: |
| """ |
| 仅加载形状匹配的参数,其余跳过。 |
| 返回成功加载的参数数量。 |
| """ |
| state_dict = torch.load(ckpt_path, map_location=device) |
| model_dict = model.state_dict() |
|
|
| ok_count = 0 |
| new_dict: OrderedDict[str, torch.Tensor] = OrderedDict() |
|
|
| for k, v in state_dict.items(): |
| if k in model_dict and v.shape == model_dict[k].shape: |
| new_dict[k] = v |
| ok_count += 1 |
| else: |
| if verbose: |
| print(f"[SKIP] {k} ckpt:{v.shape} model:{model_dict[k].shape if k in model_dict else 'N/A'}") |
|
|
| model_dict.update(new_dict) |
| model.load_state_dict(model_dict) |
| if verbose: |
| print(f"=> 成功加载 {ok_count}/{len(state_dict)} 个参数") |
| return ok_count |
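
    # A minimal sketch of partial loading (the checkpoint path and transformer
    # are hypothetical):
    #
    #     cfm = CFM(transformer=my_transformer)
    #     n_ok = cfm.load_partial_weights(cfm, "checkpoints/pretrained.pt")
    #     print(n_ok)  # number of shape-matched tensors copied over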
|
|
| @property |
| def device(self): |
| return next(self.parameters()).device |
|
|
| @torch.no_grad() |
| def sample( |
| self, |
| cond: float["b n d"] | float["b nw"], |
| text: int["b nt"] | list[str], |
| duration: int | int["b"], |
| *, |
| lens: int["b"] | None = None, |
| steps=32, |
| cfg_strength=1.0, |
| sway_sampling_coef=None, |
| seed: int | None = None, |
| max_duration=4096, |
| vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, |
| no_ref_audio=False, |
| duplicate_test=False, |
| t_inter=0.1, |
| edit_mask=None, |
        use_acc_grl=True,
        use_prosody_encoder=True,
        ref_ratio=1.0,
| ): |
| self.eval() |
|
|
| |
| raw_audio = None |
| if cond.ndim == 2: |
| raw_audio = cond.clone() |
| cond = self.mel_spec(cond) |
| cond = cond.permute(0, 2, 1) |
| assert cond.shape[-1] == self.num_channels |
|
|
| cond = cond.to(next(self.parameters()).dtype) |
| cond_mean = cond.mean(dim=1, keepdim=True) |
| batch, cond_seq_len, device = *cond.shape[:2], cond.device |
| if not exists(lens): |
| lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) |
|
|
| |
| prosody_mel_cond = None |
| prosody_text_cond = None |
| prosody_embeds = None |
| if self.prosody_encoder is not None and raw_audio is not None and use_prosody_encoder: |
| embeds = [] |
| for b in range(batch): |
| audio_b = raw_audio[b].unsqueeze(0) |
| src_sr = self.mel_spec.target_sample_rate |
| if src_sr != 16_000: |
| audio_16k = torchaudio.functional.resample( |
| audio_b, src_sr, 16_000 |
| ).squeeze(0) |
| else: |
| audio_16k = audio_b.squeeze(0) |
| fbank = extract_fbank_16k(audio_16k) |
| fbank = fbank.unsqueeze(0).to(device=device, dtype=cond.dtype) |
| emb = self.prosody_encoder(fbank, padding_mask=None)[0] |
| embeds.append(emb) |
| prosody_embeds = torch.stack(embeds, dim=0) |
| |
| prosody_mel_cond = prosody_embeds[:, None, :].expand(-1, cond_seq_len, -1) |
|
|
        if use_acc_grl:
            # grad_reverse is an identity op in the forward pass; at inference it
            # only matters that cond_grl carries the (optionally shuffled) condition.
            if ref_ratio < 1:
                # clip_and_shuffle operates on a single [n_mels, T] mel, so this
                # path assumes batch size 1.
                rand_mel = clip_and_shuffle(cond.permute(0, 2, 1).squeeze(0), cond.shape[1], ratio=ref_ratio)
                rand_mel = rand_mel.unsqueeze(0).permute(0, 2, 1)
                assert rand_mel.shape == cond.shape, f"Shape diff: rand_mel.shape: {rand_mel.shape}, cond.shape: {cond.shape}"
                cond_grl = grad_reverse(rand_mel, lambda_=1.0)
            else:
                cond_grl = grad_reverse(cond, lambda_=1.0)
|
| if isinstance(text, list): |
| if exists(self.vocab_char_map): |
| text = list_str_to_idx(text, self.vocab_char_map).to(device) |
| else: |
| text = list_str_to_tensor(text).to(device) |
| assert text.shape[0] == batch |
|
|
| |
|
|
| cond_mask = lens_to_mask(lens) |
| if edit_mask is not None: |
| cond_mask = cond_mask & edit_mask |
|
|
| if isinstance(duration, int): |
| duration = torch.full((batch,), duration, device=device, dtype=torch.long) |
|
|
| duration = torch.maximum( |
| torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration |
| ) |
| |
| duration = duration.clamp(max=max_duration) |
| max_duration = int(duration.amax().item()) |
|
|
| |
| if duplicate_test: |
| test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0) |
|
|
| cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) |
|
|
| if prosody_mel_cond is not None: |
| prosody_mel_cond = F.pad( |
| prosody_mel_cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0 |
| ) |
| prosody_mel_proj = self.prosody_to_mel(prosody_mel_cond) |
| cond = cond + prosody_mel_proj |
|
|
        if no_ref_audio:
            # replace the reference with low-variance noise whose per-channel mean
            # is renormalized to match the original condition, keeping downstream
            # statistics stable
            random_cond = torch.randn_like(cond) * 0.1 + cond_mean
            random_cond = random_cond / random_cond.mean(dim=1, keepdim=True) * cond_mean
            cond = random_cond
| |
| cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) |
| cond_mask = cond_mask.unsqueeze(-1) |
|
|
| if use_acc_grl: |
| cond_grl = F.pad(cond_grl, (0, 0, 0, max_duration - cond_seq_len), value=0.0) |
|
|
|
|
| step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) |
| |
|
|
| if batch > 1: |
| mask = lens_to_mask(duration) |
| else: |
| mask = None |
|
|
| |
|
|
| def compute_sway_max(steps: int, |
| t_start: float = 0.0, |
| dtype=torch.float32, |
| min_ratio: float | None = None, |
| safety_factor: float = 0.5) -> float: |
| """ |
| Compute a safe upper bound for sway_sampling_coef given steps and t_start. |
| |
| - steps: number of ODE steps |
| - t_start: start time in [0,1) |
| - dtype: torch dtype (for machine eps) |
| - min_ratio: smallest distinguishable dt^p (if None, use conservative default) |
| - safety_factor: scale down the theoretical maximum to be safe |
| """ |
| assert 0.0 <= t_start < 1.0 |
| dt = (1.0 - t_start) / max(1, steps) |
| eps = torch.finfo(dtype).eps |
|
|
| if min_ratio is None: |
| |
| min_ratio = max(1e-9, 1e2 * float(eps)) |
|
|
            if dt >= 0.9:
                p_max = 11.0  # dt close to 1: the log-ratio bound degenerates, use a generous cap
| else: |
| |
| p_max = math.log(min_ratio) / math.log(dt) |
|
|
| sway_max = max(0.0, p_max - 1.0) |
| sway_max = sway_max * float(safety_factor) |
| return torch.tensor(sway_max, device=device, dtype=dtype) |
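
        # Worked example of the bound above (assumed values): with steps=32 and
        # t_start=0, dt = 1/32, so for min_ratio=1e-9,
        #   p_max = ln(1e-9) / ln(1/32) ≈ 20.72 / 3.47 ≈ 5.98,
        # giving sway_max ≈ (5.98 - 1) * 0.7 ≈ 3.49 after the safety factor.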
|
|
| |
| if prosody_embeds is not None: |
| text_len = text.shape[1] |
| prosody_text_cond = prosody_embeds[:, None, :].expand(-1, text_len, -1) |
| else: |
| prosody_text_cond = None |
|
|
| def fn(t, x): |
| |
| |
| |
| |
| if use_acc_grl: |
| step_cond = torch.where(cond_mask, cond_grl, torch.zeros_like(cond_grl)) |
| else: |
| step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) |
| |
| |
| pred = self.transformer( |
| x=x, |
| cond=step_cond, |
| text=text, |
| time=t, |
| mask=mask, |
| drop_audio_cond=False, |
| drop_text=False, |
| cache=True, |
| prosody_text=prosody_text_cond, |
| ) |
| if cfg_strength < 1e-5: |
| return pred |
|
|
| null_pred = self.transformer( |
| x=x, |
| cond=step_cond, |
| text=text, |
| time=t, |
| mask=mask, |
| drop_audio_cond=True, |
| drop_text=True, |
| cache=True, |
| prosody_text=prosody_text_cond, |
| ) |
| |
| |
            # classifier-free guidance with a time-decaying scale: strongest early,
            # vanishing as t -> 1
            cfg_t = cfg_strength * ((1 - t) ** 2)
            res = pred + (pred - null_pred) * cfg_t
            # clamp to keep the ODE state in a numerically safe range
            res = res.clamp(-20, 20)
            return res
|
|
| |
| |
| |
| y0 = [] |
| for dur in duration: |
| if exists(seed): |
| torch.manual_seed(seed) |
| y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype)) |
| y0 = pad_sequence(y0, padding_value=0, batch_first=True) |
|
|
| t_start = 0 |
|
|
| |
| if duplicate_test: |
| t_start = t_inter |
| y0 = (1 - t_start) * y0 + t_start * test_cond |
| steps = int(steps * (1 - t_start)) |
|
|
| t = torch.linspace(t_start, 1, int(steps + 1), device=self.device, dtype=step_cond.dtype) |
|
|
        sway_max = compute_sway_max(steps, t_start=t_start, dtype=step_cond.dtype, min_ratio=1e-9, safety_factor=0.7)
        if sway_sampling_coef is not None:
            sway_sampling_coef = min(sway_max, sway_sampling_coef)
            # warp the time grid: t -> t ** (1 + s) front-loads small steps near t = 0
            t = t ** (1 + sway_sampling_coef)
        else:
            t = t ** (1 + sway_max)
| |
| |
| trajectory = odeint(fn, y0, t, **self.odeint_kwargs) |
| self.transformer.clear_cache() |
|
|
| sampled = trajectory[-1] |
| out = sampled |
| out = torch.where(cond_mask, cond, out) |
|
|
| |
| if no_ref_audio: |
| out_mean = out[:,cond_seq_len:,:].mean(dim=1, keepdim=True) |
| out[:,cond_seq_len:,:] = out[:,cond_seq_len:,:] - (out_mean - cond_mean) |
| |
|
|
| if exists(vocoder): |
| out = out.permute(0, 2, 1) |
| out = vocoder(out) |
| |
| return out, trajectory |
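
    # A minimal inference sketch (names and shapes are assumptions, not part of
    # this module): condition on a reference waveform and generate `duration`
    # mel frames.
    #
    #     cfm = CFM(transformer=my_transformer, vocab_char_map=my_vocab)
    #     ref_wave = torch.randn(1, 24000 * 3)          # hypothetical 3 s at 24 kHz
    #     mel, _ = cfm.sample(ref_wave, ["hello world"], duration=800, steps=32)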
|
|
|
|
| def info_nce_speaker(self, |
| e_gt: torch.Tensor, |
| e_pred: torch.Tensor, |
| temperature: float = 0.1): |
| """ |
| InfoNCE loss for speaker encoder training. |
| 同一条样本的 e_gt 与 e_pred 互为正例,其余均为负例。 |
| |
| Args: |
| temperature: 温度缩放 τ |
| |
| Returns: |
| loss: 标量 tensor,可 backward |
| """ |
| B = e_gt.size(0) |
| |
| e_gt = F.normalize(e_gt, dim=1) |
| e_pred = F.normalize(e_pred, dim=1) |
|
|
| |
| logits = torch.einsum('bd,cd->bc', e_pred, e_gt) / temperature |
|
|
| |
| labels = torch.arange(B, device=logits.device) |
|
|
| |
| loss = F.cross_entropy(logits, labels) |
| return loss |
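
    # A minimal sketch of the loss above (a batch of 4 hypothetical 192-dim
    # embeddings; near-identical pairs should give a low loss):
    #
    #     e_gt = torch.randn(4, 192)
    #     e_pred = e_gt + 0.01 * torch.randn(4, 192)
    #     loss = cfm.info_nce_speaker(e_gt, e_pred, temperature=0.1)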
|
|
|
|
    def forward_old(
        self,
        batchs: Dict[str, torch.Tensor],
        *,
        noise_scheduler: str | None = None,
    ):
|
|
| inp = batchs["mel"].permute(0, 2, 1) |
| lens = batchs["mel_lengths"] |
|
|
| rand_mel = batchs["rand_mel"].permute(0, 2, 1) |
|
|
| text = batchs["text"] |
| target_text_lengths = torch.tensor([len(x) for x in text], device=inp.device) |
|
|
| langs = batchs["langs"] |
|
|
| |
|
|
| |
| if inp.ndim == 2: |
| inp = self.mel_spec(inp) |
| inp = inp.permute(0, 2, 1) |
| assert inp.shape[-1] == self.num_channels |
|
|
| batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma |
| |
|
|
| |
| if isinstance(text, list): |
| if exists(self.vocab_char_map): |
| text = list_str_to_idx(text, self.vocab_char_map).to(device) |
| else: |
| text = list_str_to_tensor(text).to(device) |
| assert text.shape[0] == batch |
|
|
| |
| if not exists(lens): |
| lens = torch.full((batch,), seq_len, device=device) |
|
|
| mask = lens_to_mask(lens, length=seq_len) |
|
|
| |
| frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask) |
| rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) |
|
|
| if exists(mask): |
| rand_span_mask &= mask |
|
|
| |
| x1 = inp |
|
|
| |
| x0 = torch.randn_like(x1) |
|
|
| |
| time = torch.rand((batch,), dtype=dtype, device=self.device) |
| |
|
|
| |
| t = time.unsqueeze(-1).unsqueeze(-1) |
| φ = (1 - t) * x0 + t * x1 |
| flow = x1 - x0 |
|
|
| |
| cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) |
|
|
| |
        if self.use_spk_enc:
            spk_emb = self.speaker_encoder(rand_mel, lens)
            spk_emb = spk_emb.unsqueeze(1).expand_as(x1)
            # masked speaker embedding as the condition, randomly blended with
            # the full embedding per sample
            cond = torch.where(rand_span_mask[..., None], torch.zeros_like(spk_emb), spk_emb)
            rand_num = torch.rand((batch, 1, 1), dtype=dtype, device=self.device)
            cond = cond * rand_num + spk_emb * (1 - rand_num)

        cond_grl = grad_reverse(cond, lambda_=1.0)
| |
| |
|
|
| |
| |
| |
        # CFG-style condition dropout: dropping text also drops the audio condition
        drop_audio_cond = _random.random() < self.audio_drop_prob
        if _random.random() < self.text_drop_prob:
            drop_audio_cond = True
            drop_text_cond = True
        else:
            drop_text_cond = False
|
|
| |
| |
| |
| pred = self.transformer(x=φ, cond=cond_grl, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text_cond) |
|
|
| |
| pred_clamp = pred.float().clamp(-20, 20) |
| loss = F.mse_loss(pred_clamp, flow, reduction="none") |
| loss = loss[rand_span_mask] |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| loss = torch.where(torch.isnan(loss) | (loss > 300.0), 300.0, loss) |
| loss = loss.mean() |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| accent_logits = self.accent_classifier(cond_grl) |
| accent_logits_mean = accent_logits.mean(dim=1) |
| lang_labels = langs.to(accent_logits.device).long() |
| |
| accent_loss = self.accent_criterion(accent_logits_mean, lang_labels) |
| |
| if not torch.isfinite(accent_loss): |
| accent_loss = torch.zeros_like(accent_loss, device=accent_loss.device) |
| |
| loss += 0.1 * accent_loss |
|
|
| valid_indices = torch.where(time > 0.5)[0] |
| |
| if valid_indices.size(0) > 2: |
| |
| selected_gt = inp[valid_indices] |
| selected_pred = pred[valid_indices] |
| selected_text = text[valid_indices] |
| selected_lens = lens[valid_indices] |
| selected_target_lengths = target_text_lengths[valid_indices] |
| |
|
|
| if self.use_spk_enc and valid_indices.size(0) > 2: |
| |
| e_gt = self.speaker_encoder(selected_gt, selected_lens) |
| e_pred = self.speaker_encoder(selected_pred, selected_lens) |
| spk_loss = self.info_nce_speaker(e_gt, e_pred) |
| if not torch.isnan(spk_loss).any(): |
| loss = loss + spk_loss * 10.0 |
| else: |
| spk_loss = torch.zeros_like(loss, device=loss.device, requires_grad=False) |
| else: |
| spk_loss = torch.zeros_like(loss, device=loss.device, requires_grad=False) |
| |
|
|
| |
| if self.use_ctc_loss and valid_indices.size(0) > 2: |
| |
| ctc_loss = self.ctc( |
| decoder_outputs=selected_pred, |
| target_phones=selected_text, |
| decoder_lengths=selected_lens, |
| target_lengths=selected_target_lengths, |
| ) |
| |
| |
| if not torch.isnan(ctc_loss).any() and ctc_loss.item() > 1e-6: |
| |
| ctc_scaled = ctc_loss |
| loss = loss + 0.1 * ctc_scaled |
| else: |
| ctc_scaled = torch.zeros_like(loss, device=loss.device, requires_grad=False) |
| |
| else: |
| ctc_scaled = torch.zeros_like(loss, device=loss.device, requires_grad=False) |
|
|
|
|
| |
| total_loss = loss |
| |
| |
|
|
| return total_loss, ctc_scaled, accent_loss, len(valid_indices), cond, pred |
|
|
|
|
| def forward(self, batchs: Dict[str, torch.Tensor], *, noise_scheduler: str | None = None): |
| """ |
| Simplified forward version for accent-invariant flow matching. |
| Removes speaker encoder and CTC parts, keeps accent GRL. |
| """ |
| inp = batchs["mel"].permute(0, 2, 1) |
| lens = batchs["mel_lengths"] |
| text = batchs["text"] |
| langs = batchs["langs"] |
| audio_16k_list = batchs.get("audio_16k", None) |
| prosody_idx_list = batchs.get("prosody_idx", None) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| if inp.ndim == 2: |
| inp = self.mel_spec(inp).permute(0, 2, 1) |
| assert inp.shape[-1] == self.num_channels |
|
|
| batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, self.device |
|
|
| |
| if isinstance(text, list): |
| if exists(self.vocab_char_map): |
| text = list_str_to_idx(text, self.vocab_char_map).to(device) |
| else: |
| text = list_str_to_tensor(text).to(device) |
| assert text.shape[0] == batch |
| |
| |
| prosody_mel_cond = None |
| prosody_text_cond = None |
| if ( |
| self.prosody_encoder is not None |
| and audio_16k_list is not None |
| and prosody_idx_list is not None |
| ): |
| |
| T_mel = seq_len |
| T_text = text.shape[1] |
| prosody_mel_cond = torch.zeros(batch, T_mel, 512, device=device, dtype=dtype) |
| prosody_text_cond = torch.zeros(batch, T_text, 512, device=device, dtype=dtype) |
|
|
| |
            seg_embeds: list[Tensor] = []
            seg_meta: list[tuple[int, int, int, int, int]] = []
| for b in range(batch): |
| audio_b = audio_16k_list[b] |
| idx_list = prosody_idx_list[b] |
| if audio_b is None or idx_list is None: |
| continue |
| audio_b = audio_b.to(device=device, dtype=dtype) |
| for seg in idx_list: |
| text_start, text_end, mel_start, mel_end, audio_start, audio_end = seg |
| |
| audio_start = max(0, min(audio_start, audio_b.shape[0] - 1)) |
| audio_end = max(audio_start + 1, min(audio_end, audio_b.shape[0])) |
| audio_seg = audio_b[audio_start:audio_end] |
| if audio_seg.numel() == 0: |
| continue |
| fbank = extract_fbank_16k(audio_seg) |
| fbank = fbank.unsqueeze(0).to(device=device, dtype=dtype) |
| with torch.no_grad(): |
| emb = self.prosody_encoder(fbank, padding_mask=None)[0] |
| seg_embeds.append(emb) |
| seg_meta.append( |
| (b, text_start, text_end, mel_start, mel_end) |
| ) |
|
|
| if seg_embeds: |
| seg_embeds_tensor = torch.stack(seg_embeds, dim=0) |
| |
| for emb, meta in zip(seg_embeds_tensor, seg_meta): |
| b, ts, te, ms, me = meta |
| emb_exp = emb.to(device=device, dtype=dtype) |
| prosody_mel_cond[b, ms:me, :] = emb_exp |
| prosody_text_cond[b, ts:te, :] = emb_exp |
|
|
| |
| prosody_mel_cond = self.prosody_dropout(prosody_mel_cond) |
| prosody_text_cond = self.prosody_dropout(prosody_text_cond) |
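
            # At this point prosody_mel_cond / prosody_text_cond hold one 512-dim
            # segment embedding broadcast over that segment's mel frames and text
            # tokens, zeros elsewhere; dropout regularizes the conditioning.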
|
|
| |
| mask = lens_to_mask(lens, length=seq_len) |
| frac_lengths = torch.zeros((batch,), device=device).float().uniform_(*self.frac_lengths_mask) |
| rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) |
| if exists(mask): |
| rand_span_mask &= mask |
|
|
| |
| x1 = inp |
| x0 = torch.randn_like(x1) |
| time = torch.rand((batch,), dtype=dtype, device=device) |
| t = time[:, None, None] |
| φ = (1 - t) * x0 + t * x1 |
| flow = x1 - x0 |
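        # Conditional flow matching: along the linear path φ_t = (1 - t)·x0 + t·x1,
        # the target velocity field is dφ_t/dt = x1 - x0, which the transformer
        # is trained to regress below.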
|
|
| |
| cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) |
| if prosody_mel_cond is not None: |
| prosody_mel_proj = self.prosody_to_mel(prosody_mel_cond) |
| |
| if prosody_mel_proj.size(1) < seq_len: |
| pad_len = seq_len - prosody_mel_proj.size(1) |
| prosody_mel_proj = F.pad(prosody_mel_proj, (0, 0, 0, pad_len)) |
| elif prosody_mel_proj.size(1) > seq_len: |
| prosody_mel_proj = prosody_mel_proj[:, :seq_len, :] |
| cond = cond + prosody_mel_proj |
| |
| |
| cond_grl = grad_reverse(cond, lambda_=1.0) |
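        # Gradient Reversal Layer: identity in the forward pass, negated gradient
        # in the backward pass, so minimizing the accent classifier's loss pushes
        # the condition toward accent-invariant features.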
|
|
| |
| |
| |
|
|
| |
        # independent audio/text condition dropout for classifier-free guidance training
        drop_audio_cond = torch.rand(1, device=device).item() < self.audio_drop_prob
        drop_text_cond = torch.rand(1, device=device).item() < self.text_drop_prob
|
|
| |
| pred = self.transformer( |
| x=φ, |
| cond=cond_grl, |
| text=text, |
| time=time, |
| drop_audio_cond=drop_audio_cond, |
| drop_text=drop_text_cond, |
| prosody_text=prosody_text_cond, |
| ) |
|
|
| |
| pred_clamp = pred.float().clamp(-20, 20) |
| per_elem_loss = F.mse_loss(pred_clamp, flow, reduction="none") |
|
|
| mask_exp = rand_span_mask.unsqueeze(-1).to(dtype=per_elem_loss.dtype) |
| masked_loss = per_elem_loss * mask_exp |
|
|
| |
        # normalize by the number of mel elements actually selected by the span mask
        n_selected = mask_exp.sum() * per_elem_loss.size(-1)
        denom = torch.clamp(n_selected, min=1.0)
|
|
| loss_sum = masked_loss.sum() |
| loss = loss_sum / denom |
| |
| loss = torch.where(torch.isnan(loss) | (loss > 300.0), torch.tensor(300.0, device=loss.device, dtype=loss.dtype), loss) |
|
|
| |
| accent_logits = self.accent_classifier(cond_grl) |
| |
| accent_logits_mean = accent_logits.mean(dim=1) |
| lang_labels = langs.to(accent_logits_mean.device).long() |
| accent_loss = self.accent_criterion(accent_logits_mean, lang_labels) |
| |
| if not torch.isfinite(accent_loss): |
| accent_loss = torch.zeros_like(accent_loss, device=accent_loss.device) |
|
|
| base_loss = loss + 0.1 * accent_loss |
|
|
| |
| ctc_scaled = torch.tensor(0.0, device=device, dtype=dtype) |
| if getattr(self, "use_ctc_loss", False) and getattr(self, "ctc", None) is not None: |
| |
| valid_indices = torch.where(time > 0.5)[0] |
| if valid_indices.size(0) > 2: |
| selected_pred = pred[valid_indices] |
| selected_text = text[valid_indices] |
| selected_lens = lens[valid_indices] |
| |
| selected_target_lengths = (selected_text != -1).sum(dim=-1) |
|
|
| ctc_loss = self.ctc( |
| decoder_outputs=selected_pred, |
| target_phones=selected_text, |
| decoder_lengths=selected_lens, |
| target_lengths=selected_target_lengths, |
| ) |
| if torch.isfinite(ctc_loss) and ctc_loss.item() > 1e-6: |
| ctc_scaled = ctc_loss |
| base_loss = base_loss + 0.1 * ctc_scaled |
|
|
| total_loss = base_loss |
|
|
| |
| |
|
|
| return total_loss, accent_loss, ctc_scaled, cond, pred |
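
# A minimal training-step sketch (the batch keys follow forward() above; the
# shapes and the pre-built `cfm` with a vocab_char_map are assumptions):
#
#     batch = {
#         "mel": torch.randn(2, 100, 720),          # [b, n_mels, T]
#         "mel_lengths": torch.tensor([720, 640]),  # frames per utterance
#         "text": ["hello", "world"],
#         "langs": torch.tensor([0, 3]),            # accent/language labels
#     }
#     total_loss, accent_loss, ctc_scaled, cond, pred = cfm(batch)
#     total_loss.backward()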
|
|