"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
from random import random
import random as _random
from typing import Callable, Dict
from collections import OrderedDict
import math
from pathlib import Path
import torch
import torch.nn.functional as F
import torchaudio
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from torchdiffeq import odeint
from lemas_tts.model.modules import MelSpec
from lemas_tts.model.modules import MIEsitmator, AccentClassifier, grad_reverse
from lemas_tts.model.backbones.ecapa_tdnn import ECAPA_TDNN
from lemas_tts.model.backbones.prosody_encoder import ProsodyEncoder, extract_fbank_16k
from lemas_tts.model.utils import (
default,
exists,
lens_to_mask,
list_str_to_idx,
list_str_to_tensor,
mask_from_frac_lengths,
)
def clip_and_shuffle(mel, mel_len, sample_rate=24000, hop_length=256, ratio=None):
"""
Randomly clip a mel-spectrogram segment and shuffle 1-second chunks to
create an accent-invariant conditioning segment.
This is an inference-time utility used by the accent GRL path.
Args:
mel: [n_mels, T]
mel_len: int, original mel length (T)
ratio: optional fixed crop ratio in (0, 1]; if None, a random 25%-75% crop is used
"""
frames_per_second = int(sample_rate / hop_length)  # 24000 / 256 = 93.75 -> 93 frames per second
# ---- 1. Randomly crop 25%~75% of the original length (or ratio * length) ----
total_len = mel_len
if ratio is None:
seg_len = _random.randint(int(0.25 * total_len), int(0.75 * total_len))
else:
seg_len = int(total_len * ratio)
start = _random.randint(0, max(0, total_len - seg_len))
mel_seg = mel[:, start : start + seg_len]
# ---- 2. Split into ~1-second chunks ----
n_chunks = (mel_seg.size(1) + frames_per_second - 1) // frames_per_second
chunks = []
for i in range(n_chunks):
chunk = mel_seg[:, i * frames_per_second : (i + 1) * frames_per_second]
chunks.append(chunk)
# ---- 3. Shuffle chunk order ----
_random.shuffle(chunks)
shuffled_mel = torch.cat(chunks, dim=1)
# ---- 4. Repeat random chunks until reaching original length ----
if shuffled_mel.size(1) < total_len:
repeat_chunks = []
while sum(c.size(1) for c in repeat_chunks) < total_len:
repeat_chunks.append(_random.choice(chunks))
shuffled_mel = torch.cat([shuffled_mel] + repeat_chunks, dim=1)
# ---- 5. Trim to exactly mel_len ----
shuffled_mel = shuffled_mel[:, :total_len]
assert shuffled_mel.shape == mel.shape, f"shuffled_mel.shape != mel.shape: {shuffled_mel.shape} != {mel.shape}"
return shuffled_mel
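# A minimal usage sketch (hypothetical shapes): the output keeps the input shape
# but holds a cropped, chunk-shuffled, repeated copy of the content.
#   mel = torch.randn(80, 400)                 # [n_mels, T]
#   seg = clip_and_shuffle(mel, mel_len=400)   # seg.shape == (80, 400)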
class CFM(nn.Module):
def __init__(
self,
transformer: nn.Module,
sigma=0.0,
odeint_kwargs: dict = dict(
# atol = 1e-5,
# rtol = 1e-5,
method="euler" # 'midpoint'
),
audio_drop_prob=0.3,
text_drop_prob=0.1,
num_channels=None,
mel_spec_module: nn.Module | None = None,
mel_spec_kwargs: dict = dict(),
frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
vocab_char_map: dict[str, int] | None = None,
use_ctc_loss: bool = False,
use_spk_enc: bool = False,
use_prosody_encoder: bool = False,
prosody_cfg_path: str | None = None,
prosody_ckpt_path: str | None = None,
):
super().__init__()
self.frac_lengths_mask = frac_lengths_mask
# mel spec
self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
num_channels = default(num_channels, self.mel_spec.n_mel_channels)
self.num_channels = num_channels
# classifier-free guidance
self.audio_drop_prob = audio_drop_prob
self.text_drop_prob = text_drop_prob
# transformer
self.transformer = transformer
dim = transformer.dim
self.dim = dim
# conditional flow related
self.sigma = sigma
# sampling related
self.odeint_kwargs = odeint_kwargs
# vocab map for tokenization
self.vocab_char_map = vocab_char_map
# Prosody encoder (Pretssel ECAPA-TDNN)
self.use_prosody_encoder = (
use_prosody_encoder and prosody_cfg_path is not None and prosody_ckpt_path is not None
)
if self.use_prosody_encoder:
cfg_path = Path(prosody_cfg_path)
ckpt_path = Path(prosody_ckpt_path)
self.prosody_encoder = ProsodyEncoder(cfg_path, ckpt_path, freeze=True)
# 512-d prosody -> mel channel dimension
self.prosody_to_mel = nn.Linear(512, self.num_channels)
self.prosody_dropout = nn.Dropout(p=0.2)
else:
self.prosody_encoder = None
# Speaker encoder
self.use_spk_enc = use_spk_enc
if use_spk_enc:
self.speaker_encoder = ECAPA_TDNN(
self.num_channels,
self.dim,
channels=[512, 512, 512, 512, 1536],
kernel_sizes=[5, 3, 3, 3, 1],
dilations=[1, 2, 3, 4, 1],
attention_channels=128,
res2net_scale=4,
se_channels=128,
global_context=True,
batch_norm=True,
)
# self.load_partial_weights(self.speaker_encoder, "/cto_labs/vistring/zhaozhiyuan/outputs/F5-TTS/pretrain/speaker.bin", device="cpu")
self.use_ctc_loss = use_ctc_loss
if use_ctc_loss:
# print("vocab_char_map:", len(vocab_char_map)+1, "dim:", dim, "mel_spec_kwargs:",mel_spec_kwargs)
self.ctc = MIEsitmator(len(self.vocab_char_map), self.num_channels, self.dim, dropout=self.text_drop_prob)
self.accent_classifier = AccentClassifier(input_dim=self.num_channels, hidden_dim=self.dim, num_accents=12)
self.accent_criterion = nn.CrossEntropyLoss()
def load_partial_weights(self, model: nn.Module,
ckpt_path: str,
device="cpu",
verbose=True) -> int:
"""
仅加载形状匹配的参数,其余跳过。
返回成功加载的参数数量。
"""
state_dict = torch.load(ckpt_path, map_location=device)
model_dict = model.state_dict()
ok_count = 0
new_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
for k, v in state_dict.items():
if k in model_dict and v.shape == model_dict[k].shape:
new_dict[k] = v
ok_count += 1
else:
if verbose:
print(f"[SKIP] {k} ckpt:{v.shape} model:{model_dict[k].shape if k in model_dict else 'N/A'}")
model_dict.update(new_dict)
model.load_state_dict(model_dict)
if verbose:
print(f"=> 成功加载 {ok_count}/{len(state_dict)} 个参数")
return ok_count
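# A minimal usage sketch (the checkpoint path is a placeholder, not a real file):
#   n_loaded = self.load_partial_weights(self.speaker_encoder, "speaker.bin", device="cpu")
#   # prints one [SKIP] line per shape-mismatched key, then a summary line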
@property
def device(self):
return next(self.parameters()).device
@torch.no_grad()
def sample(
self,
cond: float["b n d"] | float["b nw"], # noqa: F722
text: int["b nt"] | list[str], # noqa: F722
duration: int | int["b"], # noqa: F821
*,
lens: int["b"] | None = None, # noqa: F821
steps=32,
cfg_strength=1.0,
sway_sampling_coef=None,
seed: int | None = None,
max_duration=4096,
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
no_ref_audio=False,
duplicate_test=False,
t_inter=0.1,
edit_mask=None,
use_acc_grl=True,
use_prosody_encoder=True,
ref_ratio=1,
):
self.eval()
# raw wave -> mel, keep a copy for prosody encoder if available
raw_audio = None
if cond.ndim == 2:
raw_audio = cond.clone() # (B, nw)
cond = self.mel_spec(cond)
cond = cond.permute(0, 2, 1)
assert cond.shape[-1] == self.num_channels
cond = cond.to(next(self.parameters()).dtype)
cond_mean = cond.mean(dim=1, keepdim=True)
batch, cond_seq_len, device = *cond.shape[:2], cond.device
if not exists(lens):
lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
# optional global prosody conditioning at inference (one embedding per sample)
prosody_mel_cond = None
prosody_text_cond = None
prosody_embeds = None
if self.prosody_encoder is not None and raw_audio is not None and use_prosody_encoder:
embeds = []
for b in range(batch):
audio_b = raw_audio[b].unsqueeze(0) # (1, nw)
src_sr = self.mel_spec.target_sample_rate
if src_sr != 16_000:
audio_16k = torchaudio.functional.resample(
audio_b, src_sr, 16_000
).squeeze(0)
else:
audio_16k = audio_b.squeeze(0)
fbank = extract_fbank_16k(audio_16k)
fbank = fbank.unsqueeze(0).to(device=device, dtype=cond.dtype)
emb = self.prosody_encoder(fbank, padding_mask=None)[0] # (512,)
embeds.append(emb)
prosody_embeds = torch.stack(embeds, dim=0) # (B, 512)
# broadcast along mel and text
prosody_mel_cond = prosody_embeds[:, None, :].expand(-1, cond_seq_len, -1)
if use_acc_grl:
if ref_ratio < 1:
rand_mel = clip_and_shuffle(cond.permute(0, 2, 1).squeeze(0), cond.shape[1], ratio=ref_ratio)
rand_mel = rand_mel.unsqueeze(0).permute(0, 2, 1)
assert rand_mel.shape == cond.shape, f"Shape diff: rand_mel.shape: {rand_mel.shape}, cond.shape: {cond.shape}"
cond_grl = grad_reverse(rand_mel, lambda_=1.0)
else:
cond_grl = grad_reverse(cond, lambda_=1.0)
# print("cond:", cond.shape, cond.mean(), cond.max(), cond.min(), "rand_mel:", rand_mel.mean(), rand_mel.max(), rand_mel.min(), "cond_grl:", cond_grl.mean(), cond_grl.max(), cond_grl.min())
# text
if isinstance(text, list):
if exists(self.vocab_char_map):
text = list_str_to_idx(text, self.vocab_char_map).to(device)
else:
text = list_str_to_tensor(text).to(device)
assert text.shape[0] == batch
# duration
cond_mask = lens_to_mask(lens)
if edit_mask is not None:
cond_mask = cond_mask & edit_mask
if isinstance(duration, int):
duration = torch.full((batch,), duration, device=device, dtype=torch.long)
duration = torch.maximum(
torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration
) # duration at least text/audio prompt length plus one token, so something is generated
# clamp and convert max_duration to python int for padding ops
duration = duration.clamp(max=max_duration)
max_duration = int(duration.amax().item())
# duplicate test corner for inner time step observation
if duplicate_test:
test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
if prosody_mel_cond is not None:
prosody_mel_cond = F.pad(
prosody_mel_cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0
)
prosody_mel_proj = self.prosody_to_mel(prosody_mel_cond)
cond = cond + prosody_mel_proj
if no_ref_audio:
random_cond = torch.randn_like(cond) * 0.1 + cond_mean
random_cond = random_cond / random_cond.mean(dim=1, keepdim=True) * cond_mean
print("cond:", cond.mean(), cond.max(), cond.min(), "random_cond:", random_cond.mean(), random_cond.max(), random_cond.min(), "mean_cond:", cond_mean.shape)
cond = random_cond
cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
cond_mask = cond_mask.unsqueeze(-1)
if use_acc_grl:
cond_grl = F.pad(cond_grl, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
if batch > 1:
mask = lens_to_mask(duration)
else: # save memory and speed up, as single inference need no mask currently
mask = None
# neural ode
def compute_sway_max(steps: int,
t_start: float = 0.0,
dtype=torch.float32,
min_ratio: float | None = None,
safety_factor: float = 0.5) -> float:
"""
Compute a safe upper bound for sway_sampling_coef given steps and t_start.
- steps: number of ODE steps
- t_start: start time in [0,1)
- dtype: torch dtype (for machine eps)
- min_ratio: smallest distinguishable dt^p (if None, use conservative default)
- safety_factor: scale down the theoretical maximum to be safe
"""
assert 0.0 <= t_start < 1.0
dt = (1.0 - t_start) / max(1, steps)
eps = torch.finfo(dtype).eps
if min_ratio is None:
# conservative default: ~100 * eps (float32 -> ~1e-5)
min_ratio = max(1e-9, 1e2 * float(eps))
if dt >= 0.9:
p_max = 11.0  # log(dt) is ~0 here; cap the exponent instead of dividing by it
else:
# solve dt^p >= min_ratio => p <= log(min_ratio)/log(dt)
p_max = math.log(min_ratio) / math.log(dt)
sway_max = max(0.0, p_max - 1.0)
sway_max = sway_max * float(safety_factor)
return torch.tensor(sway_max, device=device, dtype=dtype)
# prepare text-side prosody conditioning if embeddings available
if prosody_embeds is not None:
text_len = text.shape[1]
prosody_text_cond = prosody_embeds[:, None, :].expand(-1, text_len, -1)
else:
prosody_text_cond = None
def fn(t, x):
# at each step, conditioning is fixed
# if use_spk_enc:
# mix_cond = t * cond + (1-t) * spk_emb
# step_cond = torch.where(cond_mask, mix_cond, torch.zeros_like(mix_cond))
if use_acc_grl:
step_cond = torch.where(cond_mask, cond_grl, torch.zeros_like(cond_grl))
else:
step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
# predict flow
pred = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t,
mask=mask,
drop_audio_cond=False,
drop_text=False,
cache=True,
prosody_text=prosody_text_cond,
)
if cfg_strength < 1e-5:
return pred
null_pred = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t,
mask=mask,
drop_audio_cond=True,
drop_text=True,
cache=True,
prosody_text=prosody_text_cond,
)
# cfg_t = cfg_strength * torch.cos(0.5 * torch.pi * t)
# cfg_t = cfg_strength * (1 - t)
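# decaying guidance: cfg_t falls quadratically from cfg_strength at t = 0 to 0 at t = 1,
# so classifier-free guidance mainly shapes the early part of the trajectory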
cfg_t = cfg_strength * ((1 - t) ** 2)
# print("t:", t, "cfg_t:", cfg_t)
res = pred + (pred - null_pred) * cfg_t
# print("t:", t.item(), "\tres:", res.shape, res.mean().item(), res.max().item(), res.min().item(), "\tpred:", pred.mean().item(), pred.max().item(), pred.min().item(), "\tnull_pred:", null_pred.mean().item(), null_pred.max().item(), null_pred.min().item(), "\tcfg_t:", cfg_t.item())
res = res.clamp(-20, 20)
return res
# noise input
# seed per sample so batched inference matches single-sample inference regardless of batch size
# (some difference may remain due to convolutional layers)
y0 = []
for dur in duration:
if exists(seed):
torch.manual_seed(seed)
y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
y0 = pad_sequence(y0, padding_value=0, batch_first=True)
t_start = 0
# duplicate test corner for inner time step observation
if duplicate_test:
t_start = t_inter
y0 = (1 - t_start) * y0 + t_start * test_cond
steps = int(steps * (1 - t_start))
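# sway sampling via a power warp: t -> t ** (1 + s) with s >= 0 maps the uniform grid
# to one denser near t = 0, spending more ODE steps early; compute_sway_max caps s so
# that adjacent warped time steps stay distinguishable at the working float precision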
t = torch.linspace(t_start, 1, int(steps + 1), device=self.device, dtype=step_cond.dtype)
sway_max = compute_sway_max(steps, t_start=t_start, dtype=step_cond.dtype, min_ratio=1e-9, safety_factor=0.7)
if sway_sampling_coef is not None:
sway_sampling_coef = min(sway_max, sway_sampling_coef)
# t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
t = t ** (1 + sway_sampling_coef)
else:
t = t ** (1 + sway_max)
# print("t:",t, "sway_max:", sway_max, "sway_sampling_coef:", sway_sampling_coef)
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
self.transformer.clear_cache()
sampled = trajectory[-1]
out = sampled
out = torch.where(cond_mask, cond, out)
# for the generated part of out (i.e. the zero-padded region), compute its mean separately and align it with the mean of cond (scale so the two means roughly match)
if no_ref_audio:
out_mean = out[:,cond_seq_len:,:].mean(dim=1, keepdim=True)
out[:,cond_seq_len:,:] = out[:,cond_seq_len:,:] - (out_mean - cond_mean)
# print("out_mean:", out_mean.shape, out_mean.mean(), "cond_mean:", cond_mean.shape, cond_mean.mean(), "out:", out[:,cond_seq_len:,:].shape, out[:,cond_seq_len:,:].mean().item(), out[:,cond_seq_len:,:].max().item(), out[:,cond_seq_len:,:].min().item())
if exists(vocoder):
out = out.permute(0, 2, 1)
out = vocoder(out)
# print("out:", out.shape, "trajectory:", trajectory.shape)
return out, trajectory
def info_nce_speaker(self,
e_gt: torch.Tensor,
e_pred: torch.Tensor,
temperature: float = 0.1):
"""
InfoNCE loss for speaker encoder training.
For each sample, e_gt and e_pred form the positive pair; all other in-batch pairs are negatives.
Args:
temperature: softmax temperature τ
Returns:
loss: scalar tensor (differentiable)
"""
B = e_gt.size(0)
# 1. L2-normalize the embeddings
e_gt = F.normalize(e_gt, dim=1)
e_pred = F.normalize(e_pred, dim=1)
# 2. B×B similarity matrix (pred rows vs. gt columns)
logits = torch.einsum('bd,cd->bc', e_pred, e_gt) / temperature # [B, B]
# 3. positive pairs lie on the diagonal
labels = torch.arange(B, device=logits.device)
# 4. InfoNCE = cross-entropy over in-batch negatives
loss = F.cross_entropy(logits, labels)
return loss
def forward_old(
self,
batchs: Dict[str, torch.Tensor],
# inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
# text: int["b nt"] | list[str], # noqa: F722
*,
# lens: int["b"] | None = None, # noqa: F821
noise_scheduler: str | None = None,
):
inp = batchs["mel"].permute(0, 2, 1)
lens = batchs["mel_lengths"]
rand_mel = batchs["rand_mel"].permute(0, 2, 1)
text = batchs["text"]
target_text_lengths = torch.tensor([len(x) for x in text], device=inp.device)
langs = batchs["langs"]
# print("inp:", inp.shape, "rand_mel:", rand_mel.shape, "lens:", lens, "target_text_lengths:", target_text_lengths, "langs:", langs)
# handle raw wave
if inp.ndim == 2:
inp = self.mel_spec(inp)
inp = inp.permute(0, 2, 1)
assert inp.shape[-1] == self.num_channels
batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
# print("inp_shape:", inp.shape, inp.max(), inp.min(), "dtype:", dtype, "device:", device, "σ1:", _σ1)
# handle text as string
if isinstance(text, list):
if exists(self.vocab_char_map):
text = list_str_to_idx(text, self.vocab_char_map).to(device)
else:
text = list_str_to_tensor(text).to(device)
assert text.shape[0] == batch
# lens and mask
if not exists(lens):
lens = torch.full((batch,), seq_len, device=device)
mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
# get a random span to mask out for training conditionally
frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
if exists(mask):
rand_span_mask &= mask
# mel is x1
x1 = inp
# x0 is gaussian noise
x0 = torch.randn_like(x1)
# time step
time = torch.rand((batch,), dtype=dtype, device=self.device)
# TODO. noise_scheduler
# sample xt (φ_t(x) in the paper)
t = time.unsqueeze(-1).unsqueeze(-1)
φ = (1 - t) * x0 + t * x1
flow = x1 - x0
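# with sigma = 0 this is the straight-line (rectified) conditional path
# φ_t = (1 - t) * x0 + t * x1, whose target velocity is the constant x1 - x0;
# the masked MSE further below is the standard conditional flow matching loss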
# cond = torch.where(rand_span_mask[..., None], torch.zeros_like(rand_mel), rand_mel)
cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
# print("seq_len:", seq_len, "lens:", lens)
if self.use_spk_enc:  # blend a speaker embedding into the condition (random mixing ratio below)
spk_emb = self.speaker_encoder(rand_mel, lens)
# global_emb: [batch, 1, dim] -> expanded to [batch, seq_len, dim]
spk_emb = spk_emb.unsqueeze(1).expand_as(x1)
# print("spk_emb_shape:", spk_emb.shape)
# apply the span mask
cond = torch.where(rand_span_mask[..., None], torch.zeros_like(spk_emb), spk_emb)
# assert cond.shape[0] == batch, "speaker encoder batch size mismatch"
# print("x1.shape:", x1.shape, "cond_shape:", cond.shape)
# draw a random ratio r and mix the masked cond with spk_emb: cond * r + spk_emb * (1 - r)
rand_num = torch.rand((batch, 1, 1), dtype=dtype, device=self.device)
cond = cond * rand_num + spk_emb * (1 - rand_num)
cond_grl = grad_reverse(cond, lambda_=1.0)
# print("inp_shape:", inp.shape, "rand_span_mask:", rand_span_mask.shape)
# transformer and cfg training with a drop rate
drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
if random() < self.text_drop_prob: # p_uncond in voicebox paper
drop_audio_cond = True
drop_text_cond = True
else:
drop_text_cond = False
# print("drop_audio_cond:", drop_audio_cond, "drop_text_cond:", drop_text_cond)
# to rigorously mask out padding, record it in collate_fn in dataset.py and pass it in here;
# adding the mask uses more memory, so the batch sampler threshold for long sequences would also need to be scaled down
pred = self.transformer(x=φ, cond=cond_grl, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text_cond)
# flow matching loss
pred_clamp = pred.float().clamp(-20, 20)
loss = F.mse_loss(pred_clamp, flow, reduction="none")
loss = loss[rand_span_mask] # [N]
# # # 1. global clipping: values > 2 or NaN -> 0 (applied globally)
# print("mse loss shape:", loss.shape, "loss max:", loss.max(), "loss min:", loss.min(), target_text_lengths[0])
# # 2. percentage of non-NaN values
# valid_mask = ~torch.isnan(loss)
# total_count = loss.numel()  # total element count (all dims)
# valid_count = valid_mask.sum().item()  # non-NaN element count
# valid_percentage = (valid_count / total_count) * 100
# print(f"mse loss: total_count: {total_count}", f"valid_count: {valid_count}", f"valid_percentage: {valid_percentage:.2f}%")
# valid_loss = loss[~torch.isnan(loss)]
loss = torch.where(torch.isnan(loss) | (loss > 300.0), 300.0, loss)
loss = loss.mean()
# loss = torch.tanh(torch.log1p(loss.mean()))  # log scaling
# if len(valid_loss) > 0:
# clipped_loss = torch.clamp(valid_loss, max=150)
# loss = torch.tanh(torch.log1p(clipped_loss.mean()))  # log scaling
# else:
# loss = torch.tensor(0.0, device=pred.device)
accent_logits = self.accent_classifier(cond_grl)
accent_logits_mean = accent_logits.mean(dim=1)
lang_labels = langs.to(accent_logits.device).long()
# print("langs:", lang_labels, "accent_logits:", accent_logits.shape, "accent_logits_mean:", accent_logits_mean.shape)
accent_loss = self.accent_criterion(accent_logits_mean, lang_labels)
# guard against NaN / Inf in accent_loss
if not torch.isfinite(accent_loss):
accent_loss = torch.zeros_like(accent_loss, device=accent_loss.device)
# accent_loss = torch.zeros_like(loss, device=loss.device, requires_grad=True)
loss += 0.1 * accent_loss
valid_indices = torch.where(time > 0.5)[0]
# print("torch.where(time > 0.5):", valid_indices, torch.where(time > 0.5))
if valid_indices.size(0) > 2:
# dynamically select the samples that meet the condition
selected_gt = inp[valid_indices]
selected_pred = pred[valid_indices]
selected_text = text[valid_indices]
selected_lens = lens[valid_indices]
selected_target_lengths = target_text_lengths[valid_indices]
# print("pred:", selected_pred.shape, "valid_indices:", valid_indices, "lens:", selected_lens, "target_lengths:", selected_target_lengths)
if self.use_spk_enc and valid_indices.size(0) > 2:
# speaker encoder loss
e_gt = self.speaker_encoder(selected_gt, selected_lens)
e_pred = self.speaker_encoder(selected_pred, selected_lens)
spk_loss = self.info_nce_speaker(e_gt, e_pred)
if not torch.isnan(spk_loss).any(): # and spk_loss.item() > 1e-6:
loss = loss + spk_loss * 10.0
else:
spk_loss = torch.zeros_like(loss, device=loss.device, requires_grad=False)
else:
spk_loss = torch.zeros_like(loss, device=loss.device, requires_grad=False)
# print("spk_loss:", spk_loss)
# ctc loss
if self.use_ctc_loss and valid_indices.size(0) > 2:
# compute the CTC loss only when t > 0.5
ctc_loss = self.ctc(
decoder_outputs=selected_pred,
target_phones=selected_text,
decoder_lengths=selected_lens,
target_lengths=selected_target_lengths,
)
# print("loss:", loss, "ctc_loss:", ctc_loss, "time: ", time.shape, time[valid_indices].mean())
# add the CTC loss only if it contains no NaN
if not torch.isnan(ctc_loss).any() and ctc_loss.item() > 1e-6:
# ctc_scaled = torch.tanh(torch.log1p(ctc_loss))
ctc_scaled = ctc_loss
loss = loss + 0.1 * ctc_scaled
else:
ctc_scaled = torch.zeros_like(loss, device=loss.device, requires_grad=False)
# print("loss:", loss, "ctc_scaled:", ctc_scaled)
else:
ctc_scaled = torch.zeros_like(loss, device=loss.device, requires_grad=False)
# before finalizing the total loss
total_loss = loss # base flow loss + others you added
# note: we intentionally do NOT add 0.0 * pred.sum() etc. here, to avoid
# propagating NaNs from intermediate tensors into the loss scalar.
return total_loss, ctc_scaled, accent_loss, len(valid_indices), cond, pred # accent_loss,
def forward(self, batchs: Dict[str, torch.Tensor], *, noise_scheduler: str | None = None):
"""
Simplified forward version for accent-invariant flow matching.
Removes speaker encoder and CTC parts, keeps accent GRL.
"""
inp = batchs["mel"].permute(0, 2, 1) # [B, T_mel, D]
lens = batchs["mel_lengths"]
text = batchs["text"]
langs = batchs["langs"]
audio_16k_list = batchs.get("audio_16k", None)
prosody_idx_list = batchs.get("prosody_idx", None)
# # ---- 4. Randomly crop and shuffle segments ----
# rand_mel = [clip_and_shuffle(spec, spec.shape[-1]) for spec in batchs["mel"]]
# padded_rand_mel = []
# for spec in rand_mel:
# padding = (0, batchs["mel"].shape[-1] - spec.size(-1))
# padded_spec = F.pad(spec, padding, value=0)
# padded_rand_mel.append(padded_spec)
# rand_mel = torch.stack(padded_rand_mel).permute(0, 2, 1)
# assert rand_mel.shape == inp.shape, f"shape diff: rand_mel.shape: {rand_mel.shape}, inp.shape: {inp.shape}"
if inp.ndim == 2:
inp = self.mel_spec(inp).permute(0, 2, 1)
assert inp.shape[-1] == self.num_channels
batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, self.device
# --- handle text
if isinstance(text, list):
if exists(self.vocab_char_map):
text = list_str_to_idx(text, self.vocab_char_map).to(device)
else:
text = list_str_to_tensor(text).to(device)
assert text.shape[0] == batch
# print("text:", batchs["text"][0], text.shape, text[0], batchs["text_lengths"][0])
# --- prosody conditioning (compute embeddings per sub-utterance)
prosody_mel_cond = None
prosody_text_cond = None
if (
self.prosody_encoder is not None
and audio_16k_list is not None
and prosody_idx_list is not None
):
# prepare zero tensors for each sample
T_mel = seq_len
T_text = text.shape[1]
prosody_mel_cond = torch.zeros(batch, T_mel, 512, device=device, dtype=dtype)
prosody_text_cond = torch.zeros(batch, T_text, 512, device=device, dtype=dtype)
# collect all segments, run encoder per segment
seg_embeds: list[Tensor] = []
seg_meta: list[tuple[int, int, int, int, int]] = []  # (batch idx, text start/end, mel start/end)
for b in range(batch):
audio_b = audio_16k_list[b]
idx_list = prosody_idx_list[b]
if audio_b is None or idx_list is None:
continue
audio_b = audio_b.to(device=device, dtype=dtype)
for seg in idx_list:
text_start, text_end, mel_start, mel_end, audio_start, audio_end = seg
# clamp audio indices
audio_start = max(0, min(audio_start, audio_b.shape[0] - 1))
audio_end = max(audio_start + 1, min(audio_end, audio_b.shape[0]))
audio_seg = audio_b[audio_start:audio_end]
if audio_seg.numel() == 0:
continue
fbank = extract_fbank_16k(audio_seg) # (T_fbank, 80)
fbank = fbank.unsqueeze(0).to(device=device, dtype=dtype) # (1, T_fbank, 80)
with torch.no_grad():
emb = self.prosody_encoder(fbank, padding_mask=None)[0] # (512,)
seg_embeds.append(emb)
seg_meta.append(
(b, text_start, text_end, mel_start, mel_end)
)
if seg_embeds:
seg_embeds_tensor = torch.stack(seg_embeds, dim=0) # (N_seg, 512)
# scatter embeddings back to per-sample tensors
for emb, meta in zip(seg_embeds_tensor, seg_meta):
b, ts, te, ms, me = meta
emb_exp = emb.to(device=device, dtype=dtype)
prosody_mel_cond[b, ms:me, :] = emb_exp
prosody_text_cond[b, ts:te, :] = emb_exp
# dropout on prosody conditioning
prosody_mel_cond = self.prosody_dropout(prosody_mel_cond)
prosody_text_cond = self.prosody_dropout(prosody_text_cond)
# --- mask & random span
mask = lens_to_mask(lens, length=seq_len)
frac_lengths = torch.zeros((batch,), device=device).float().uniform_(*self.frac_lengths_mask)
rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
if exists(mask):
rand_span_mask &= mask
# --- flow setup
x1 = inp
x0 = torch.randn_like(x1)
time = torch.rand((batch,), dtype=dtype, device=device)
t = time[:, None, None]
φ = (1 - t) * x0 + t * x1
flow = x1 - x0
# --- conditional input (masked mel) + optional prosody
cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) # x1 # rand_mel
if prosody_mel_cond is not None:
prosody_mel_proj = self.prosody_to_mel(prosody_mel_cond) # (B, T_mel, num_channels)
# if needed, pad/crop to seq_len
if prosody_mel_proj.size(1) < seq_len:
pad_len = seq_len - prosody_mel_proj.size(1)
prosody_mel_proj = F.pad(prosody_mel_proj, (0, 0, 0, pad_len))
elif prosody_mel_proj.size(1) > seq_len:
prosody_mel_proj = prosody_mel_proj[:, :seq_len, :]
cond = cond + prosody_mel_proj
# --- Gradient reversal: encourage accent-invariant cond
cond_grl = grad_reverse(cond, lambda_=1.0)
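# assuming the usual gradient-reversal-layer semantics for grad_reverse (identity in the
# forward pass, gradient scaled by -lambda_ in the backward pass), the accent classifier
# further below learns to predict the accent while the reversed gradient pushes cond
# toward accent-invariance (DANN-style adversarial training)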
# # --- random drop condition for CFG-like robustness
# drop_audio_cond = random() < self.audio_drop_prob
# drop_text_cond = random() < self.text_drop_prob if not drop_audio_cond else True
# safe per-batch random (tensor)
rand_for_drop = torch.rand(1, device=device)
drop_audio_cond = (rand_for_drop.item() < self.audio_drop_prob)
rand_for_text = torch.rand(1, device=device)
drop_text_cond = (rand_for_text.item() < self.text_drop_prob)
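# note: audio and text conditions are dropped independently here, whereas forward_old
# forces drop_audio_cond = True whenever the text condition is dropped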
# --- main prediction
pred = self.transformer(
x=φ,
cond=cond_grl,
text=text,
time=time,
drop_audio_cond=drop_audio_cond,
drop_text=drop_text_cond,
prosody_text=prosody_text_cond,
)
# === FLOW LOSS (robust mask-weighted) ===
pred_clamp = pred.float().clamp(-20, 20)
per_elem_loss = F.mse_loss(pred_clamp, flow, reduction="none") # [B, T, D]
mask_exp = rand_span_mask.unsqueeze(-1).to(dtype=per_elem_loss.dtype) # [B, T, 1]
masked_loss = per_elem_loss * mask_exp # zeros where mask False
# total selected scalar (frames * dim)
n_selected = mask_exp.sum() * per_elem_loss.size(-1) # scalar
denom = torch.clamp(n_selected, min=1.0)
loss_sum = masked_loss.sum()
loss = loss_sum / denom
# numeric safety
loss = torch.where(torch.isnan(loss) | (loss > 300.0), torch.tensor(300.0, device=loss.device, dtype=loss.dtype), loss)
# === ACCENT LOSS ===
accent_logits = self.accent_classifier(cond_grl)
# pool across time -> [B, C]
accent_logits_mean = accent_logits.mean(dim=1)
lang_labels = langs.to(accent_logits_mean.device).long()
accent_loss = self.accent_criterion(accent_logits_mean, lang_labels)
# guard against NaN / Inf in accent_loss
if not torch.isfinite(accent_loss):
accent_loss = torch.zeros_like(accent_loss, device=accent_loss.device)
base_loss = loss + 0.1 * accent_loss
# === OPTIONAL CTC LOSS (robust, only on valid samples) ===
ctc_scaled = torch.tensor(0.0, device=device, dtype=dtype)
if getattr(self, "use_ctc_loss", False) and getattr(self, "ctc", None) is not None:
# select samples with larger t for CTC supervision (similar to forward_old)
valid_indices = torch.where(time > 0.5)[0]
if valid_indices.size(0) > 2:
selected_pred = pred[valid_indices]
selected_text = text[valid_indices]
selected_lens = lens[valid_indices]
# text was tokenized from list_str_to_idx, where padding is -1
selected_target_lengths = (selected_text != -1).sum(dim=-1)
ctc_loss = self.ctc(
decoder_outputs=selected_pred,
target_phones=selected_text,
decoder_lengths=selected_lens,
target_lengths=selected_target_lengths,
)
if torch.isfinite(ctc_loss) and ctc_loss.item() > 1e-6:
ctc_scaled = ctc_loss
base_loss = base_loss + 0.1 * ctc_scaled
total_loss = base_loss
# note: we intentionally do NOT add 0.0 * pred.sum() etc. here, to avoid
# propagating NaNs from intermediate tensors into the loss scalar.
return total_loss, accent_loss, ctc_scaled, cond, pred