Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| ein notation: | |
| b - batch | |
| n - sequence | |
| nt - text sequence | |
| nw - raw wave length | |
| d - dimension | |
| """ | |
| from __future__ import annotations | |
| from random import random | |
| import random as _random | |
| from typing import Callable, Dict, OrderedDict | |
| import math | |
| from pathlib import Path | |
| import torch | |
| import torch.nn.functional as F | |
| import torchaudio | |
| from torch import nn | |
| from torch.nn.utils.rnn import pad_sequence | |
| from torchdiffeq import odeint | |
| from lemas_tts.model.modules import MelSpec | |
| from lemas_tts.model.modules import MIEsitmator, AccentClassifier, grad_reverse | |
| from lemas_tts.model.backbones.ecapa_tdnn import ECAPA_TDNN | |
| from lemas_tts.model.backbones.prosody_encoder import ProsodyEncoder, extract_fbank_16k | |
| from lemas_tts.model.utils import ( | |
| default, | |
| exists, | |
| lens_to_mask, | |
| list_str_to_idx, | |
| list_str_to_tensor, | |
| mask_from_frac_lengths, | |
| ) | |
def clip_and_shuffle(mel, mel_len, sample_rate=24000, hop_length=256, ratio=None):
    """Build an accent-invariant conditioning mel by cropping and shuffling.

    A random segment of ``mel`` is cropped and cut into chunks of roughly one
    second. The chunks are shuffled, then randomly chosen chunks are appended
    until the result is ``mel_len`` frames long. This is an inference-time
    utility used by the accent GRL path.

    Args:
        mel: Mel spectrogram laid out as ``[n_mels, T]`` (frames on dim 1).
        mel_len: Number of frames to produce (normally ``T``).
        sample_rate: Audio sample rate, used to derive frames per second.
        hop_length: Mel hop length, used to derive frames per second.
        ratio: Fraction of ``mel_len`` to crop. When falsy (``None`` or 0), a
            random fraction in [0.25, 0.75] is used.

    Returns:
        Tensor ``[n_mels, mel_len]``. Every frame is copied from the first
        ``mel_len`` frames of ``mel``.
    """
    # ~94 frames / second for 24 kHz with hop 256; never allow 0.
    frames_per_second = max(1, int(sample_rate / hop_length))
    total_len = int(mel_len)
    if total_len <= 0:
        # Nothing to crop; torch.cat on an empty chunk list would raise.
        return mel[:, :0].clone()

    # ---- 1. Randomly crop 25%~75% of the original length (or ratio * length) ----
    if not ratio:
        seg_len = _random.randint(int(0.25 * total_len), int(0.75 * total_len))
    else:
        seg_len = int(total_len * ratio)
    # Keep the segment non-empty and inside the valid frames. A zero-length
    # segment (short input or tiny ratio) used to crash torch.cat below.
    seg_len = min(max(1, seg_len), total_len)
    start = _random.randint(0, total_len - seg_len)
    mel_seg = mel[:, start : start + seg_len]

    # ---- 2. Split into ~1-second chunks (the last one may be shorter) ----
    chunks = list(torch.split(mel_seg, frames_per_second, dim=1))

    # ---- 3. Shuffle chunk order ----
    _random.shuffle(chunks)
    shuffled_mel = torch.cat(chunks, dim=1)

    # ---- 4. Append random chunks until at least total_len frames exist ----
    filled = shuffled_mel.size(1)
    extra_chunks = []
    while filled < total_len:
        chunk = _random.choice(chunks)
        extra_chunks.append(chunk)
        filled += chunk.size(1)
    if extra_chunks:
        shuffled_mel = torch.cat([shuffled_mel, *extra_chunks], dim=1)

    # ---- 5. Trim to exactly mel_len ----
    shuffled_mel = shuffled_mel[:, :total_len]
    assert shuffled_mel.shape == mel.shape, f"shuffled_mel.shape != mel.shape: {shuffled_mel.shape} != {mel.shape}"
    return shuffled_mel
| class CFM(nn.Module): | |
| def __init__( | |
| self, | |
| transformer: nn.Module, | |
| sigma=0.0, | |
| odeint_kwargs: dict = dict( | |
| # atol = 1e-5, | |
| # rtol = 1e-5, | |
| method="euler" # 'midpoint' | |
| ), | |
| audio_drop_prob=0.3, | |
| text_drop_prob=0.1, | |
| num_channels=None, | |
| mel_spec_module: nn.Module | None = None, | |
| mel_spec_kwargs: dict = dict(), | |
| frac_lengths_mask: tuple[float, float] = (0.7, 1.0), | |
| vocab_char_map: dict[str:int] | None = None, | |
| use_ctc_loss: bool = False, | |
| use_spk_enc: bool = False, | |
| use_prosody_encoder: bool = False, | |
| prosody_cfg_path: str | None = None, | |
| prosody_ckpt_path: str | None = None, | |
| ): | |
| super().__init__() | |
| self.frac_lengths_mask = frac_lengths_mask | |
| # mel spec | |
| self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) | |
| num_channels = default(num_channels, self.mel_spec.n_mel_channels) | |
| self.num_channels = num_channels | |
| # classifier-free guidance | |
| self.audio_drop_prob = audio_drop_prob | |
| self.text_drop_prob = text_drop_prob | |
| # transformer | |
| self.transformer = transformer | |
| dim = transformer.dim | |
| self.dim = dim | |
| # conditional flow related | |
| self.sigma = sigma | |
| # sampling related | |
| self.odeint_kwargs = odeint_kwargs | |
| # vocab map for tokenization | |
| self.vocab_char_map = vocab_char_map | |
| # Prosody encoder (Pretssel ECAPA-TDNN) | |
| self.use_prosody_encoder = ( | |
| use_prosody_encoder and prosody_cfg_path is not None and prosody_ckpt_path is not None | |
| ) | |
| if self.use_prosody_encoder: | |
| cfg_path = Path(prosody_cfg_path) | |
| ckpt_path = Path(prosody_ckpt_path) | |
| self.prosody_encoder = ProsodyEncoder(cfg_path, ckpt_path, freeze=True) | |
| # 512-d prosody -> mel channel dimension | |
| self.prosody_to_mel = nn.Linear(512, self.num_channels) | |
| self.prosody_dropout = nn.Dropout(p=0.2) | |
| else: | |
| self.prosody_encoder = None | |
| # Speaker encoder | |
| self.use_spk_enc = use_spk_enc | |
| if use_spk_enc: | |
| self.speaker_encoder = ECAPA_TDNN( | |
| self.num_channels, | |
| self.dim, | |
| channels=[512, 512, 512, 512, 1536], | |
| kernel_sizes=[5, 3, 3, 3, 1], | |
| dilations=[1, 2, 3, 4, 1], | |
| attention_channels=128, | |
| res2net_scale=4, | |
| se_channels=128, | |
| global_context=True, | |
| batch_norm=True, | |
| ) | |
| # self.load_partial_weights(self.speaker_encoder, "/cto_labs/vistring/zhaozhiyuan/outputs/F5-TTS/pretrain/speaker.bin", device="cpu") | |
| self.use_ctc_loss = use_ctc_loss | |
| if use_ctc_loss: | |
| # print("vocab_char_map:", len(vocab_char_map)+1, "dim:", dim, "mel_spec_kwargs:",mel_spec_kwargs) | |
| self.ctc = MIEsitmator(len(self.vocab_char_map), self.num_channels, self.dim, dropout=self.text_drop_prob) | |
| self.accent_classifier = AccentClassifier(input_dim=self.num_channels, hidden_dim=self.dim, num_accents=12) | |
| self.accent_criterion = nn.CrossEntropyLoss() | |
| def load_partial_weights(self, model: nn.Module, | |
| ckpt_path: str, | |
| device="cpu", | |
| verbose=True) -> int: | |
| """ | |
| 仅加载形状匹配的参数,其余跳过。 | |
| 返回成功加载的参数数量。 | |
| """ | |
| state_dict = torch.load(ckpt_path, map_location=device) | |
| model_dict = model.state_dict() | |
| ok_count = 0 | |
| new_dict: OrderedDict[str, torch.Tensor] = OrderedDict() | |
| for k, v in state_dict.items(): | |
| if k in model_dict and v.shape == model_dict[k].shape: | |
| new_dict[k] = v | |
| ok_count += 1 | |
| else: | |
| if verbose: | |
| print(f"[SKIP] {k} ckpt:{v.shape} model:{model_dict[k].shape if k in model_dict else 'N/A'}") | |
| model_dict.update(new_dict) | |
| model.load_state_dict(model_dict) | |
| if verbose: | |
| print(f"=> 成功加载 {ok_count}/{len(state_dict)} 个参数") | |
| return ok_count | |
| def device(self): | |
| return next(self.parameters()).device | |
| def sample( | |
| self, | |
| cond: float["b n d"] | float["b nw"], # noqa: F722 | |
| text: int["b nt"] | list[str], # noqa: F722 | |
| duration: int | int["b"], # noqa: F821 | |
| *, | |
| lens: int["b"] | None = None, # noqa: F821 | |
| steps=32, | |
| cfg_strength=1.0, | |
| sway_sampling_coef=None, | |
| seed: int | None = None, | |
| max_duration=4096, | |
| vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 | |
| no_ref_audio=False, | |
| duplicate_test=False, | |
| t_inter=0.1, | |
| edit_mask=None, | |
| use_acc_grl = True, | |
| use_prosody_encoder = True, | |
| ref_ratio = 1, | |
| ): | |
| self.eval() | |
| # raw wave -> mel, keep a copy for prosody encoder if available | |
| raw_audio = None | |
| if cond.ndim == 2: | |
| raw_audio = cond.clone() # (B, nw) | |
| cond = self.mel_spec(cond) | |
| cond = cond.permute(0, 2, 1) | |
| assert cond.shape[-1] == self.num_channels | |
| cond = cond.to(next(self.parameters()).dtype) | |
| cond_mean = cond.mean(dim=1, keepdim=True) | |
| batch, cond_seq_len, device = *cond.shape[:2], cond.device | |
| if not exists(lens): | |
| lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) | |
| # optional global prosody conditioning at inference (one embedding per sample) | |
| prosody_mel_cond = None | |
| prosody_text_cond = None | |
| prosody_embeds = None | |
| if self.prosody_encoder is not None and raw_audio is not None and use_prosody_encoder: | |
| embeds = [] | |
| for b in range(batch): | |
| audio_b = raw_audio[b].unsqueeze(0) # (1, nw) | |
| src_sr = self.mel_spec.target_sample_rate | |
| if src_sr != 16_000: | |
| audio_16k = torchaudio.functional.resample( | |
| audio_b, src_sr, 16_000 | |
| ).squeeze(0) | |
| else: | |
| audio_16k = audio_b.squeeze(0) | |
| fbank = extract_fbank_16k(audio_16k) | |
| fbank = fbank.unsqueeze(0).to(device=device, dtype=cond.dtype) | |
| emb = self.prosody_encoder(fbank, padding_mask=None)[0] # (512,) | |
| embeds.append(emb) | |
| prosody_embeds = torch.stack(embeds, dim=0) # (B, 512) | |
| # broadcast along mel and text | |
| prosody_mel_cond = prosody_embeds[:, None, :].expand(-1, cond_seq_len, -1) | |
| if use_acc_grl: | |
| # rand_mel = clip_and_shuffle(cond.permute(0, 2, 1).squeeze(0), cond.shape[1]) | |
| # rand_mel = rand_mel.unsqueeze(0).permute(0, 2, 1) | |
| # assert rand_mel.shape == cond.shape, f"Shape diff: rand_mel.shape: {rand_mel.shape}, cond.shape: {cond.shape}" | |
| # cond_grl = grad_reverse(rand_mel, lambda_=1.0) | |
| if ref_ratio < 1: | |
| rand_mel = clip_and_shuffle(cond.permute(0, 2, 1).squeeze(0), cond.shape[1], ratio=ref_ratio) | |
| rand_mel = rand_mel.unsqueeze(0).permute(0, 2, 1) | |
| assert rand_mel.shape == cond.shape, f"Shape diff: rand_mel.shape: {rand_mel.shape}, cond.shape: {cond.shape}" | |
| cond_grl = grad_reverse(rand_mel, lambda_=1.0) | |
| else: | |
| cond_grl = grad_reverse(cond, lambda_=1.0) | |
| # print("cond:", cond.shape, cond.mean(), cond.max(), cond.min(), "rand_mel:", rand_mel.mean(), rand_mel.max(), rand_mel.min(), "cond_grl:", cond_grl.mean(), cond_grl.max(), cond_grl.min()) | |
| # text | |
| if isinstance(text, list): | |
| if exists(self.vocab_char_map): | |
| text = list_str_to_idx(text, self.vocab_char_map).to(device) | |
| else: | |
| text = list_str_to_tensor(text).to(device) | |
| assert text.shape[0] == batch | |
| # duration | |
| cond_mask = lens_to_mask(lens) | |
| if edit_mask is not None: | |
| cond_mask = cond_mask & edit_mask | |
| if isinstance(duration, int): | |
| duration = torch.full((batch,), duration, device=device, dtype=torch.long) | |
| duration = torch.maximum( | |
| torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration | |
| ) # duration at least text/audio prompt length plus one token, so something is generated | |
| # clamp and convert max_duration to python int for padding ops | |
| duration = duration.clamp(max=max_duration) | |
| max_duration = int(duration.amax().item()) | |
| # duplicate test corner for inner time step oberservation | |
| if duplicate_test: | |
| test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0) | |
| cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) | |
| if prosody_mel_cond is not None: | |
| prosody_mel_cond = F.pad( | |
| prosody_mel_cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0 | |
| ) | |
| prosody_mel_proj = self.prosody_to_mel(prosody_mel_cond) | |
| cond = cond + prosody_mel_proj | |
| if no_ref_audio: | |
| random_cond = torch.randn_like(cond) * 0.1 + cond_mean | |
| random_cond = random_cond / random_cond.mean(dim=1, keepdim=True) * cond_mean | |
| print("cond:", cond.mean(), cond.max(), cond.min(), "random_cond:", random_cond.mean(), random_cond.max(), random_cond.min(), "mean_cond:", cond_mean.shape) | |
| cond = random_cond | |
| cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) | |
| cond_mask = cond_mask.unsqueeze(-1) | |
| if use_acc_grl: | |
| cond_grl = F.pad(cond_grl, (0, 0, 0, max_duration - cond_seq_len), value=0.0) | |
| step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in | |
| if batch > 1: | |
| mask = lens_to_mask(duration) | |
| else: # save memory and speed up, as single inference need no mask currently | |
| mask = None | |
| # neural ode | |
| def compute_sway_max(steps: int, | |
| t_start: float = 0.0, | |
| dtype=torch.float32, | |
| min_ratio: float | None = None, | |
| safety_factor: float = 0.5) -> float: | |
| """ | |
| Compute a safe upper bound for sway_sampling_coef given steps and t_start. | |
| - steps: number of ODE steps | |
| - t_start: start time in [0,1) | |
| - dtype: torch dtype (for machine eps) | |
| - min_ratio: smallest distinguishable dt^p (if None, use conservative default) | |
| - safety_factor: scale down the theoretical maximum to be safe | |
| """ | |
| assert 0.0 <= t_start < 1.0 | |
| dt = (1.0 - t_start) / max(1, steps) | |
| eps = torch.finfo(dtype).eps | |
| if min_ratio is None: | |
| # conservative default: ~100 * eps (float32 -> ~1e-5) | |
| min_ratio = max(1e-9, 1e2 * float(eps)) | |
| if dt >= 0.9: | |
| p_max = 1.0 + 10.0 | |
| else: | |
| # solve dt^p >= min_ratio => p <= log(min_ratio)/log(dt) | |
| p_max = math.log(min_ratio) / math.log(dt) | |
| sway_max = max(0.0, p_max - 1.0) | |
| sway_max = sway_max * float(safety_factor) | |
| return torch.tensor(sway_max, device=device, dtype=dtype) | |
| # prepare text-side prosody conditioning if embeddings available | |
| if prosody_embeds is not None: | |
| text_len = text.shape[1] | |
| prosody_text_cond = prosody_embeds[:, None, :].expand(-1, text_len, -1) | |
| else: | |
| prosody_text_cond = None | |
| def fn(t, x): | |
| # at each step, conditioning is fixed | |
| # if use_spk_enc: | |
| # mix_cond = t * cond + (1-t) * spk_emb | |
| # step_cond = torch.where(cond_mask, mix_cond, torch.zeros_like(mix_cond)) | |
| if use_acc_grl: | |
| step_cond = torch.where(cond_mask, cond_grl, torch.zeros_like(cond_grl)) | |
| else: | |
| step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) | |
| # predict flow | |
| pred = self.transformer( | |
| x=x, | |
| cond=step_cond, | |
| text=text, | |
| time=t, | |
| mask=mask, | |
| drop_audio_cond=False, | |
| drop_text=False, | |
| cache=True, | |
| prosody_text=prosody_text_cond, | |
| ) | |
| if cfg_strength < 1e-5: | |
| return pred | |
| null_pred = self.transformer( | |
| x=x, | |
| cond=step_cond, | |
| text=text, | |
| time=t, | |
| mask=mask, | |
| drop_audio_cond=True, | |
| drop_text=True, | |
| cache=True, | |
| prosody_text=prosody_text_cond, | |
| ) | |
| # cfg_t = cfg_strength * torch.cos(0.5 * torch.pi * t) | |
| # cfg_t = cfg_strength * (1 - t) | |
| cfg_t = cfg_strength * ((1 - t) ** 2) | |
| # print("t:", t, "cfg_t:", cfg_t) | |
| res = pred + (pred - null_pred) * cfg_t | |
| # print("t:", t.item(), "\tres:", res.shape, res.mean().item(), res.max().item(), res.min().item(), "\tpred:", pred.mean().item(), pred.max().item(), pred.min().item(), "\tnull_pred:", null_pred.mean().item(), null_pred.max().item(), null_pred.min().item(), "\tcfg_t:", cfg_t.item()) | |
| res = res.clamp(-20, 20) | |
| return res | |
| # noise input | |
| # to make sure batch inference result is same with different batch size, and for sure single inference | |
| # still some difference maybe due to convolutional layers | |
| y0 = [] | |
| for dur in duration: | |
| if exists(seed): | |
| torch.manual_seed(seed) | |
| y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype)) | |
| y0 = pad_sequence(y0, padding_value=0, batch_first=True) | |
| t_start = 0 | |
| # duplicate test corner for inner time step oberservation | |
| if duplicate_test: | |
| t_start = t_inter | |
| y0 = (1 - t_start) * y0 + t_start * test_cond | |
| steps = int(steps * (1 - t_start)) | |
| t = torch.linspace(t_start, 1, int(steps + 1), device=self.device, dtype=step_cond.dtype) | |
| sway_max = compute_sway_max(steps, t_start=t_start, dtype=step_cond.dtype, min_ratio=1e-9, safety_factor=0.7) | |
| if sway_sampling_coef is not None: | |
| sway_sampling_coef = min(sway_max, sway_sampling_coef) | |
| # t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) | |
| t = t ** (1 + sway_sampling_coef) | |
| else: | |
| t = t ** (1 + sway_max) | |
| # print("t:",t, "sway_max:", sway_max, "sway_sampling_coef:", sway_sampling_coef) | |
| trajectory = odeint(fn, y0, t, **self.odeint_kwargs) | |
| self.transformer.clear_cache() | |
| sampled = trajectory[-1] | |
| out = sampled | |
| out = torch.where(cond_mask, cond, out) | |
| # out生成的部分,或者说pad补0的部分,单独计算mean, 然后和cond的mean做对齐(乘以系数,两个的均值要差不多) | |
| if no_ref_audio: | |
| out_mean = out[:,cond_seq_len:,:].mean(dim=1, keepdim=True) | |
| out[:,cond_seq_len:,:] = out[:,cond_seq_len:,:] - (out_mean - cond_mean) | |
| # print("out_mean:", out_mean.shape, out_mean.mean(), "cond_mean:", cond_mean.shape, cond_mean.mean(), "out:", out[:,cond_seq_len:,:].shape, out[:,cond_seq_len:,:].mean().item(), out[:,cond_seq_len:,:].max().item(), out[:,cond_seq_len:,:].min().item()) | |
| if exists(vocoder): | |
| out = out.permute(0, 2, 1) | |
| out = vocoder(out) | |
| # print("out:", out.shape, "trajectory:", trajectory.shape) | |
| return out, trajectory | |
| def info_nce_speaker(self, | |
| e_gt: torch.Tensor, | |
| e_pred: torch.Tensor, | |
| temperature: float = 0.1): | |
| """ | |
| InfoNCE loss for speaker encoder training. | |
| 同一条样本的 e_gt 与 e_pred 互为正例,其余均为负例。 | |
| Args: | |
| temperature: 温度缩放 τ | |
| Returns: | |
| loss: 标量 tensor,可 backward | |
| """ | |
| B = e_gt.size(0) | |
| # 2. L2 归一化 | |
| e_gt = F.normalize(e_gt, dim=1) | |
| e_pred = F.normalize(e_pred, dim=1) | |
| # 3. 计算 B×B 相似度矩阵(pred 对 gt) | |
| logits = torch.einsum('bd,cd->bc', e_pred, e_gt) / temperature # [B, B] | |
| # 4. 正例标签正好是对角线 | |
| labels = torch.arange(B, device=logits.device) | |
| # 5. InfoNCE = cross-entropy over in-batch negatives | |
| loss = F.cross_entropy(logits, labels) | |
| return loss | |
    def forward_old(
        self,
        batchs: Dict[str, torch.Tensor],
        *,
        noise_scheduler: str | None = None,
    ):
        """Older training step: flow matching + accent GRL + optional speaker / CTC losses.

        Reads ``batchs["mel"]``, ``"mel_lengths"``, ``"rand_mel"``, ``"text"`` and
        ``"langs"``. Unlike ``forward``, it also supports the speaker-encoder
        InfoNCE loss.

        Returns:
            ``(total_loss, ctc_scaled, accent_loss, num_selected, cond, pred)``.
            ``num_selected`` is the number of samples with ``t > 0.5``. Note the
            order differs from ``forward``, which returns accent before CTC.

        NOTE(review): ``noise_scheduler`` and ``_σ1`` (``self.sigma``) are unused,
        and ``spk_loss`` is computed but not returned.
        """
        inp = batchs["mel"].permute(0, 2, 1)
        lens = batchs["mel_lengths"]
        rand_mel = batchs["rand_mel"].permute(0, 2, 1)
        text = batchs["text"]
        # lengths taken before tokenization (assumes `text` holds per-sample sequences — TODO confirm)
        target_text_lengths = torch.tensor([len(x) for x in text], device=inp.device)
        langs = batchs["langs"]

        # handle raw wave
        if inp.ndim == 2:
            inp = self.mel_spec(inp)
            inp = inp.permute(0, 2, 1)
            assert inp.shape[-1] == self.num_channels

        batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma

        # handle text as string
        if isinstance(text, list):
            if exists(self.vocab_char_map):
                text = list_str_to_idx(text, self.vocab_char_map).to(device)
            else:
                text = list_str_to_tensor(text).to(device)
            assert text.shape[0] == batch

        # lens and mask
        if not exists(lens):
            lens = torch.full((batch,), seq_len, device=device)
        mask = lens_to_mask(lens, length=seq_len)  # useless here, as collate_fn will pad to max length in batch

        # get a random span to mask out for training conditionally
        frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
        rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)

        if exists(mask):
            rand_span_mask &= mask

        # mel is x1
        x1 = inp

        # x0 is gaussian noise
        x0 = torch.randn_like(x1)

        # time step
        time = torch.rand((batch,), dtype=dtype, device=self.device)
        # TODO. noise_scheduler

        # sample xt (φ_t(x) in the paper)
        t = time.unsqueeze(-1).unsqueeze(-1)
        φ = (1 - t) * x0 + t * x1
        flow = x1 - x0

        # conditioning: reference mel with the training span zeroed out
        cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)

        if self.use_spk_enc:
            # NOTE(review): the original comment said "use spk_emb with 50% probability",
            # but no such probability is applied — spk_emb is always used here, and it
            # replaces the masked-mel `cond` built above.
            spk_emb = self.speaker_encoder(rand_mel, lens)
            # global embedding [batch, 1, dim] -> expanded along time to x1's shape
            # (expand_as presumably requires the encoder output dim == num_channels — TODO confirm)
            spk_emb = spk_emb.unsqueeze(1).expand_as(x1)
            # apply the span mask to the speaker embedding
            cond = torch.where(rand_span_mask[..., None], torch.zeros_like(spk_emb), spk_emb)
            # blend with a random per-sample weight: rand_num * masked + (1 - rand_num) * full spk_emb
            rand_num = torch.rand((batch, 1, 1), dtype=dtype, device=self.device)
            cond = cond * rand_num + spk_emb * (1 - rand_num)
        cond_grl = grad_reverse(cond, lambda_=1.0)

        # transformer and cfg training with a drop rate;
        # dropping text also drops audio (fully unconditional sample)
        drop_audio_cond = random() < self.audio_drop_prob  # p_drop in voicebox paper
        if random() < self.text_drop_prob:  # p_uncond in voicebox paper
            drop_audio_cond = True
            drop_text_cond = True
        else:
            drop_text_cond = False

        # if want rigorously mask out padding, record in collate_fn in dataset.py, and pass in here
        # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
        pred = self.transformer(x=φ, cond=cond_grl, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text_cond)

        # flow matching loss, only on the masked span
        pred_clamp = pred.float().clamp(-20, 20)
        loss = F.mse_loss(pred_clamp, flow, reduction="none")
        loss = loss[rand_span_mask]  # [N_masked_frames, D]
        # replace NaN or very large (> 300) element losses with 300 before averaging
        loss = torch.where(torch.isnan(loss) | (loss > 300.0), 300.0, loss)
        loss = loss.mean()

        # accent classifier on the gradient-reversed conditioning, pooled over time
        accent_logits = self.accent_classifier(cond_grl)
        accent_logits_mean = accent_logits.mean(dim=1)
        lang_labels = langs.to(accent_logits.device).long()
        accent_loss = self.accent_criterion(accent_logits_mean, lang_labels)
        # guard against NaN / Inf in accent_loss
        if not torch.isfinite(accent_loss):
            accent_loss = torch.zeros_like(accent_loss, device=accent_loss.device)
        loss += 0.1 * accent_loss

        # auxiliary losses only use samples with t > 0.5, and only when more than 2 qualify
        valid_indices = torch.where(time > 0.5)[0]
        if valid_indices.size(0) > 2:
            # select the qualifying samples
            selected_gt = inp[valid_indices]
            selected_pred = pred[valid_indices]
            selected_text = text[valid_indices]
            selected_lens = lens[valid_indices]
            selected_target_lengths = target_text_lengths[valid_indices]

        if self.use_spk_enc and valid_indices.size(0) > 2:
            # speaker encoder loss (InfoNCE between ground-truth and predicted mels)
            e_gt = self.speaker_encoder(selected_gt, selected_lens)
            e_pred = self.speaker_encoder(selected_pred, selected_lens)
            spk_loss = self.info_nce_speaker(e_gt, e_pred)
            if not torch.isnan(spk_loss).any():  # and spk_loss.item() > 1e-6:
                loss = loss + spk_loss * 10.0
            else:
                spk_loss = torch.zeros_like(loss, device=loss.device, requires_grad=False)
        else:
            spk_loss = torch.zeros_like(loss, device=loss.device, requires_grad=False)

        # ctc loss
        if self.use_ctc_loss and valid_indices.size(0) > 2:
            # computed only on the samples with t > 0.5
            ctc_loss = self.ctc(
                decoder_outputs=selected_pred,
                target_phones=selected_text,
                decoder_lengths=selected_lens,
                target_lengths=selected_target_lengths,
            )
            # add the CTC loss only when it is not NaN and not ~0
            if not torch.isnan(ctc_loss).any() and ctc_loss.item() > 1e-6:
                # ctc_scaled = torch.tanh(torch.log1p(ctc_loss))
                ctc_scaled = ctc_loss
                loss = loss + 0.1 * ctc_scaled
            else:
                ctc_scaled = torch.zeros_like(loss, device=loss.device, requires_grad=False)
        else:
            ctc_scaled = torch.zeros_like(loss, device=loss.device, requires_grad=False)

        total_loss = loss  # base flow loss + accent / speaker / CTC terms added above
        # note: we intentionally do NOT add 0.0 * pred.sum() etc. here, to avoid
        # propagating NaNs from intermediate tensors into the loss scalar.
        return total_loss, ctc_scaled, accent_loss, len(valid_indices), cond, pred
| def forward(self, batchs: Dict[str, torch.Tensor], *, noise_scheduler: str | None = None): | |
| """ | |
| Simplified forward version for accent-invariant flow matching. | |
| Removes speaker encoder and CTC parts, keeps accent GRL. | |
| """ | |
| inp = batchs["mel"].permute(0, 2, 1) # [B, T_mel, D] | |
| lens = batchs["mel_lengths"] | |
| text = batchs["text"] | |
| langs = batchs["langs"] | |
| audio_16k_list = batchs.get("audio_16k", None) | |
| prosody_idx_list = batchs.get("prosody_idx", None) | |
| # # ---- 4. 随机截取并打乱 segment ---- | |
| # rand_mel = [clip_and_shuffle(spec, spec.shape[-1]) for spec in batchs["mel"]] | |
| # padded_rand_mel = [] | |
| # for spec in rand_mel: | |
| # padding = (0, batchs["mel"].shape[-1] - spec.size(-1)) | |
| # padded_spec = F.pad(spec, padding, value=0) | |
| # padded_rand_mel.append(padded_spec) | |
| # rand_mel = torch.stack(padded_rand_mel).permute(0, 2, 1) | |
| # assert rand_mel.shape == inp.shape, f"shape diff: rand_mel.shape: {rand_mel.shape}, inp.shape: {inp.shape}" | |
| if inp.ndim == 2: | |
| inp = self.mel_spec(inp).permute(0, 2, 1) | |
| assert inp.shape[-1] == self.num_channels | |
| batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, self.device | |
| # --- handle text | |
| if isinstance(text, list): | |
| if exists(self.vocab_char_map): | |
| text = list_str_to_idx(text, self.vocab_char_map).to(device) | |
| else: | |
| text = list_str_to_tensor(text).to(device) | |
| assert text.shape[0] == batch | |
| # print("text:", batchs["text"][0], text.shape, text[0], batchs["text_lengths"][0]) | |
| # --- prosody conditioning (compute embeddings per sub-utterance) | |
| prosody_mel_cond = None | |
| prosody_text_cond = None | |
| if ( | |
| self.prosody_encoder is not None | |
| and audio_16k_list is not None | |
| and prosody_idx_list is not None | |
| ): | |
| # prepare zero tensors for each sample | |
| T_mel = seq_len | |
| T_text = text.shape[1] | |
| prosody_mel_cond = torch.zeros(batch, T_mel, 512, device=device, dtype=dtype) | |
| prosody_text_cond = torch.zeros(batch, T_text, 512, device=device, dtype=dtype) | |
| # collect all segments, run encoder per segment | |
| seg_embeds: list[Tensor] = [] | |
| seg_meta: list[tuple[int, int, int, int, int, int]] = [] | |
| for b in range(batch): | |
| audio_b = audio_16k_list[b] | |
| idx_list = prosody_idx_list[b] | |
| if audio_b is None or idx_list is None: | |
| continue | |
| audio_b = audio_b.to(device=device, dtype=dtype) | |
| for seg in idx_list: | |
| text_start, text_end, mel_start, mel_end, audio_start, audio_end = seg | |
| # clamp audio indices | |
| audio_start = max(0, min(audio_start, audio_b.shape[0] - 1)) | |
| audio_end = max(audio_start + 1, min(audio_end, audio_b.shape[0])) | |
| audio_seg = audio_b[audio_start:audio_end] | |
| if audio_seg.numel() == 0: | |
| continue | |
| fbank = extract_fbank_16k(audio_seg) # (T_fbank, 80) | |
| fbank = fbank.unsqueeze(0).to(device=device, dtype=dtype) # (1, T_fbank, 80) | |
| with torch.no_grad(): | |
| emb = self.prosody_encoder(fbank, padding_mask=None)[0] # (512,) | |
| seg_embeds.append(emb) | |
| seg_meta.append( | |
| (b, text_start, text_end, mel_start, mel_end) | |
| ) | |
| if seg_embeds: | |
| seg_embeds_tensor = torch.stack(seg_embeds, dim=0) # (N_seg, 512) | |
| # scatter embeddings back to per-sample tensors | |
| for emb, meta in zip(seg_embeds_tensor, seg_meta): | |
| b, ts, te, ms, me = meta | |
| emb_exp = emb.to(device=device, dtype=dtype) | |
| prosody_mel_cond[b, ms:me, :] = emb_exp | |
| prosody_text_cond[b, ts:te, :] = emb_exp | |
| # dropout on prosody conditioning | |
| prosody_mel_cond = self.prosody_dropout(prosody_mel_cond) | |
| prosody_text_cond = self.prosody_dropout(prosody_text_cond) | |
| # --- mask & random span | |
| mask = lens_to_mask(lens, length=seq_len) | |
| frac_lengths = torch.zeros((batch,), device=device).float().uniform_(*self.frac_lengths_mask) | |
| rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) | |
| if exists(mask): | |
| rand_span_mask &= mask | |
| # --- flow setup | |
| x1 = inp | |
| x0 = torch.randn_like(x1) | |
| time = torch.rand((batch,), dtype=dtype, device=device) | |
| t = time[:, None, None] | |
| φ = (1 - t) * x0 + t * x1 | |
| flow = x1 - x0 | |
| # --- conditional input (masked mel) + optional prosody | |
| cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) # x1 # rand_mel | |
| if prosody_mel_cond is not None: | |
| prosody_mel_proj = self.prosody_to_mel(prosody_mel_cond) # (B, T_mel, num_channels) | |
| # if needed, pad/crop to seq_len | |
| if prosody_mel_proj.size(1) < seq_len: | |
| pad_len = seq_len - prosody_mel_proj.size(1) | |
| prosody_mel_proj = F.pad(prosody_mel_proj, (0, 0, 0, pad_len)) | |
| elif prosody_mel_proj.size(1) > seq_len: | |
| prosody_mel_proj = prosody_mel_proj[:, :seq_len, :] | |
| cond = cond + prosody_mel_proj | |
| # --- Gradient reversal: encourage accent-invariant cond | |
| cond_grl = grad_reverse(cond, lambda_=1.0) | |
| # # --- random drop condition for CFG-like robustness | |
| # drop_audio_cond = random() < self.audio_drop_prob | |
| # drop_text_cond = random() < self.text_drop_prob if not drop_audio_cond else True | |
| # safe per-batch random (tensor) | |
| rand_for_drop = torch.rand(1, device=device) | |
| drop_audio_cond = (rand_for_drop.item() < self.audio_drop_prob) | |
| rand_for_text = torch.rand(1, device=device) | |
| drop_text_cond = (rand_for_text.item() < self.text_drop_prob) | |
| # --- main prediction | |
| pred = self.transformer( | |
| x=φ, | |
| cond=cond_grl, | |
| text=text, | |
| time=time, | |
| drop_audio_cond=drop_audio_cond, | |
| drop_text=drop_text_cond, | |
| prosody_text=prosody_text_cond, | |
| ) | |
| # === FLOW LOSS (robust mask-weighted) === | |
| pred_clamp = pred.float().clamp(-20, 20) | |
| per_elem_loss = F.mse_loss(pred_clamp, flow, reduction="none") # [B, T, D] | |
| mask_exp = rand_span_mask.unsqueeze(-1).to(dtype=per_elem_loss.dtype) # [B, T, 1] | |
| masked_loss = per_elem_loss * mask_exp # zeros where mask False | |
| # total selected scalar (frames * dim) | |
| n_selected = mask_exp.sum() * per_elem_loss.size(-1) # scalar | |
| denom = torch.clamp(n_selected, min=1.0) | |
| loss_sum = masked_loss.sum() | |
| loss = loss_sum / denom | |
| # numeric safety | |
| loss = torch.where(torch.isnan(loss) | (loss > 300.0), torch.tensor(300.0, device=loss.device, dtype=loss.dtype), loss) | |
| # === ACCENT LOSS === | |
| accent_logits = self.accent_classifier(cond_grl) | |
| # pool across time -> [B, C] | |
| accent_logits_mean = accent_logits.mean(dim=1) | |
| lang_labels = langs.to(accent_logits_mean.device).long() | |
| accent_loss = self.accent_criterion(accent_logits_mean, lang_labels) | |
| # guard against NaN / Inf in accent_loss | |
| if not torch.isfinite(accent_loss): | |
| accent_loss = torch.zeros_like(accent_loss, device=accent_loss.device) | |
| base_loss = loss + 0.1 * accent_loss | |
| # === OPTIONAL CTC LOSS (robust, only on valid samples) === | |
| ctc_scaled = torch.tensor(0.0, device=device, dtype=dtype) | |
| if getattr(self, "use_ctc_loss", False) and getattr(self, "ctc", None) is not None: | |
| # select samples with larger t for CTC supervision (similar to forward_old) | |
| valid_indices = torch.where(time > 0.5)[0] | |
| if valid_indices.size(0) > 2: | |
| selected_pred = pred[valid_indices] | |
| selected_text = text[valid_indices] | |
| selected_lens = lens[valid_indices] | |
| # text was tokenized from list_str_to_idx, where padding is -1 | |
| selected_target_lengths = (selected_text != -1).sum(dim=-1) | |
| ctc_loss = self.ctc( | |
| decoder_outputs=selected_pred, | |
| target_phones=selected_text, | |
| decoder_lengths=selected_lens, | |
| target_lengths=selected_target_lengths, | |
| ) | |
| if torch.isfinite(ctc_loss) and ctc_loss.item() > 1e-6: | |
| ctc_scaled = ctc_loss | |
| base_loss = base_loss + 0.1 * ctc_scaled | |
| total_loss = base_loss | |
| # note: we intentionally do NOT add 0.0 * pred.sum() etc. here, to avoid | |
| # propagating NaNs from intermediate tensors into the loss scalar. | |
| return total_loss, accent_loss, ctc_scaled, cond, pred | |