"""
FlowMatchingTTS — adapts CosyVoice's Conditional Flow Matching pipeline
to the semacs-tts dataset (VQ codes → mel spectrogram).

Architecture mirrors MaskedDiffWithXvec from cosyvoice/flow/flow.py:
  codes → Embedding → causal Transformer → Linear → InterpolateRegulator
        → ConditionalCFM (ConditionalDecoder UNet1D) ← speaker emb
        → mel spectrogram loss
"""
| import math |
| import random |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.utils.checkpoint as ckpt |
| from omegaconf import DictConfig |
|
|
| from flow_matching.utils.cfm import ConditionalCFM |
| from flow_matching.utils.decoder import ConditionalDecoder |
| from flow_matching.utils.length_regulator import InterpolateRegulator |
| from flow_matching.utils.mask import make_pad_mask |
|
|
|
|
class SinusoidalPE(nn.Module):
    """Fixed sinusoidal positional encoding added to a batch-first sequence.

    A ``(1, max_len, d_model)`` table is computed once at construction and
    registered as the buffer ``pe``. Even channels hold sines and odd channels
    hold cosines of ``position * 10000 ** (-2i / d_model)``.

    Args:
        d_model: Size of the last (channel) dimension of the input. Both even
            and odd values are supported.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Longest sequence length the precomputed table covers.
    """

    def __init__(self, d_model: int, dropout: float = 0.0, max_len: int = 8192):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).float().unsqueeze(1)
        div = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = pos * div
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so drop the surplus frequency instead of failing on a
        # shape mismatch. For even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the positional table to ``x`` and apply dropout.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension equals ``d_model``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        table_len = self.pe.size(1)
        if seq_len > table_len:
            # Fail with an actionable message rather than an opaque
            # broadcasting error from the addition below.
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional-encoding "
                f"table size {table_len}; construct SinusoidalPE with a larger max_len"
            )
        return self.dropout(x + self.pe[:, :seq_len])
|
|
|
|
def _forward_layer(layer, x, attn_mask, pad_mask):
    """Run a single encoder layer, passing both masks by keyword.

    Kept at module level so it can be handed to
    ``torch.utils.checkpoint.checkpoint`` as a plain callable that takes only
    positional arguments.

    Args:
        layer: Callable accepting ``(x, src_mask=..., src_key_padding_mask=...)``.
        x: Layer input.
        attn_mask: Forwarded as ``src_mask`` (may be ``None``).
        pad_mask: Forwarded as ``src_key_padding_mask``.

    Returns:
        Whatever ``layer`` returns.
    """
    mask_kwargs = {
        'src_mask': attn_mask,
        'src_key_padding_mask': pad_mask,
    }
    return layer(x, **mask_kwargs)
|
|
|
|
class CodeEncoder(nn.Module):
    """Pre-norm Transformer encoder over already-embedded VQ codes.

    Positional encodings are added first, then the input passes through a
    stack of ``nn.TransformerEncoderLayer`` modules and a final ``LayerNorm``.
    When ``causal`` is true each position may only attend to itself and to
    earlier positions.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        num_layers: int,
        ffn_dim: int,
        dropout: float,
        causal: bool,
        grad_checkpoint: bool,
    ):
        """
        Args:
            hidden_dim: Model width (``d_model`` of every layer).
            num_heads: Number of attention heads per layer.
            num_layers: Number of stacked encoder layers.
            ffn_dim: Hidden width of each layer's feed-forward network.
            dropout: Dropout probability for the layers and positional encoding.
            causal: If true, apply an upper-triangular attention mask.
            grad_checkpoint: If true, checkpoint each layer while training.
        """
        super().__init__()
        self.causal = causal
        self.grad_checkpoint = grad_checkpoint
        self.pos_enc = SinusoidalPE(hidden_dim, dropout)

        stack = []
        for _ in range(num_layers):
            stack.append(
                nn.TransformerEncoderLayer(
                    d_model=hidden_dim,
                    nhead=num_heads,
                    dim_feedforward=ffn_dim,
                    dropout=dropout,
                    batch_first=True,
                    norm_first=True,
                )
            )
        self.layers = nn.ModuleList(stack)

        self.norm = nn.LayerNorm(hidden_dim)
        self._hidden_dim = hidden_dim

    def output_size(self) -> int:
        """Return the channel width of the encoder output."""
        return self._hidden_dim

    def forward(self, x: torch.Tensor, lengths: torch.Tensor):
        """Encode a padded batch of embedded codes.

        Args:
            x: (B, T, hidden_dim) — embedded codes.
            lengths: (B,) — valid code lengths per sample.

        Returns:
            Tuple of the encoded tensor (B, T, hidden_dim) and ``lengths``
            passed through unchanged.
        """
        _, seq_len, _ = x.shape
        x = self.pos_enc(x)

        # Forwarded as ``src_key_padding_mask``; presumably True at padded
        # positions — confirm against make_pad_mask.
        key_padding = make_pad_mask(lengths, seq_len)

        if self.causal:
            # True strictly above the diagonal: future positions are blocked.
            all_pairs = torch.ones(
                seq_len, seq_len, device=x.device, dtype=torch.bool
            )
            future_mask = torch.triu(all_pairs, diagonal=1)
        else:
            future_mask = None

        # Checkpointing only pays off when gradients are being computed.
        use_ckpt = self.grad_checkpoint and self.training
        for layer in self.layers:
            if use_ckpt:
                x = ckpt.checkpoint(
                    _forward_layer, layer, x, future_mask, key_padding,
                    use_reentrant=False,
                )
            else:
                x = _forward_layer(layer, x, future_mask, key_padding)

        return self.norm(x), lengths
|
|
|
|
class FlowMatchingTTS(nn.Module):
    """
    Full TTS flow matching model.

    Training input (keys of the ``batch`` dict given to ``forward``):
        codes      (B, num_q, T_codes)   VQ code indices
        code_lens  (B,)                  valid code lengths
        mel        (B, T_mel, n_mels)    target log-mel spectrogram
        mel_lens   (B,)                  valid mel frame lengths
        embedding  (B, spk_emb_dim)      pre-extracted speaker embedding

    Output:
        {'loss': scalar}   conditional flow matching loss
    """

    # Probability that a training sample gets no mel prompt at all.
    _PROMPT_DROP_PROB = 0.5
    # A sampled prompt covers at most this fraction of a sample's valid frames.
    _PROMPT_MAX_RATIO = 0.8

    def __init__(self, cfg):
        """Build all sub-modules from a config object.

        Fields read from ``cfg``:
            cfg.data.n_mels
            cfg.model.{hidden_dim, spk_emb_dim, codebook_size, num_heads,
                       num_layers, ffn_dim, dropout, causal, grad_checkpoint}
            cfg.cfm.{sigma_min, t_scheduler, training_cfg_rate,
                     inference_cfg_rate}
        """
        super().__init__()
        m = cfg.model
        data = cfg.data
        cfm = cfg.cfm

        self.n_mels = data.n_mels
        hidden_dim = m.hidden_dim
        spk_emb_dim = m.spk_emb_dim

        # One embedding table shared by every quantizer level; the per-level
        # embeddings are summed before the encoder.
        self.code_embedding = nn.Embedding(m.codebook_size, hidden_dim)

        # Transformer encoder over the summed code embeddings.
        self.encoder = CodeEncoder(
            hidden_dim=hidden_dim,
            num_heads=m.num_heads,
            num_layers=m.num_layers,
            ffn_dim=m.ffn_dim,
            dropout=m.dropout,
            causal=m.causal,
            grad_checkpoint=m.grad_checkpoint,
        )

        # Project encoder features down to the mel channel count.
        self.encoder_proj = nn.Linear(hidden_dim, self.n_mels)

        # Project the speaker embedding to the mel channel count.
        self.spk_embed_affine = nn.Linear(spk_emb_dim, self.n_mels)

        # Stretch code-rate features to the mel frame rate.
        self.length_regulator = InterpolateRegulator(
            channels=self.n_mels,
            sampling_ratios=(),
        )

        # Flow-matching decoder; solver and regression loss type are fixed
        # here, everything else comes from cfg.cfm.
        cfm_params = DictConfig({
            'sigma_min': cfm.sigma_min,
            'solver': 'euler',
            't_scheduler': cfm.t_scheduler,
            'training_cfg_rate': cfm.training_cfg_rate,
            'inference_cfg_rate': cfm.inference_cfg_rate,
            'reg_loss_type': 'l1',
        })
        estimator = ConditionalDecoder(
            # NOTE(review): presumably four n_mels-wide inputs (noisy sample,
            # encoder features, speaker vector, prompt) are concatenated along
            # channels — confirm against ConditionalDecoder.
            in_channels=4 * self.n_mels,
            out_channels=self.n_mels,
            channels=(256, 256),
            dropout=0.05,
            attention_head_dim=64,
            n_blocks=4,
            num_mid_blocks=12,
            num_heads=8,
            act_fn='gelu',
        )
        self.decoder = ConditionalCFM(
            in_channels=self.n_mels,
            cfm_params=cfm_params,
            n_spks=1,
            spk_emb_dim=self.n_mels,
            estimator=estimator,
        )

    def _project_speaker(self, embedding: torch.Tensor) -> torch.Tensor:
        """L2-normalise a speaker embedding and project it to ``n_mels`` channels."""
        return self.spk_embed_affine(F.normalize(embedding, dim=-1))

    def _encode_codes(self, codes: torch.Tensor, code_lens: torch.Tensor) -> torch.Tensor:
        """Embed VQ codes, sum over quantizer levels, encode and project to ``n_mels``."""
        x = self.code_embedding(codes).sum(dim=1)
        h, _ = self.encoder(x, code_lens)
        return self.encoder_proj(h)

    def _sample_prompt_conds(self, mel: torch.Tensor, mel_lens) -> torch.Tensor:
        """Build the channel-first prompt tensor: a random mel prefix per sample, else zeros."""
        conds = torch.zeros_like(mel)
        for i, n_frames in enumerate(mel_lens):
            if random.random() < self._PROMPT_DROP_PROB:
                continue
            prefix = random.randint(0, int(self._PROMPT_MAX_RATIO * n_frames))
            conds[i, :prefix] = mel[i, :prefix]
        return conds.transpose(1, 2)

    def forward(self, batch: dict, device) -> dict:
        """
        Same interface as cosyvoice MaskedDiffWithXvec: model(batch, device).

        Batch keys (added by Executor before this call):
            codes      (B, num_q, T_codes)
            code_lens  (B,)
            mel        (B, T_mel, n_mels)
            mel_lens   (B,)
            embedding  (B, spk_emb_dim)  speaker embedding (re-normalised here)

        Returns:
            ``{'loss': loss}`` where ``loss`` is the first value returned by
            ``self.decoder.compute_loss``.
        """
        # Read the lengths as Python ints before the device transfer so the
        # prompt-sampling loop does not have to read them back from `device`.
        mel_len_list = batch['mel_lens'].tolist()

        codes = batch['codes'].to(device)
        code_lens = batch['code_lens'].to(device)
        mel = batch['mel'].to(device)
        mel_lens = batch['mel_lens'].to(device)
        embedding = batch['embedding'].to(device)

        # Speaker conditioning vector, (B, n_mels).
        spk = self._project_speaker(embedding)

        # Code features at code rate, (B, T_codes, n_mels).
        h = self._encode_codes(codes, code_lens)

        # Resample to the mel frame rate.
        h, _ = self.length_regulator(h, mel_lens)

        # Random mel-prefix prompt, (B, n_mels, T_mel).
        conds = self._sample_prompt_conds(mel, mel_len_list)

        # The decoder works channel-first.
        mel_t = mel.transpose(1, 2).contiguous()
        h_t = h.transpose(1, 2).contiguous()

        # NOTE(review): when the padded mel length differs from the regulator
        # output length, the target and the prompt are resampled (nearest
        # neighbour) rather than cropped — confirm this is intended.
        if mel_t.shape[-1] != h_t.shape[-1]:
            mel_t = F.interpolate(mel_t, size=h_t.shape[-1], mode='nearest')
            conds = F.interpolate(conds, size=h_t.shape[-1], mode='nearest')

        # Inverted padding mask, cast to h's dtype and device.
        mask = (~make_pad_mask(mel_lens)).to(h)

        loss, _ = self.decoder.compute_loss(
            mel_t,
            mask.unsqueeze(1),
            h_t,
            spk,
            cond=conds,
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(
        self,
        codes: torch.Tensor,
        code_lens: torch.Tensor,
        prompt_mel: torch.Tensor,
        target_len: int,
        spk_emb: torch.Tensor,
        n_timesteps: int = 10,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """Generate a mel spectrogram for a single utterance.

        Args:
            codes: (1, num_q, T_codes) VQ code indices.
            code_lens: (1,) valid code length.
            prompt_mel: (1, T_prompt, n_mels) mel prompt; its first
                ``min(T_prompt, target_len)`` frames are used as conditioning.
                An empty prompt (or ``None``) disables prompting.
            target_len: Number of mel frames to generate.
            spk_emb: (1, spk_emb_dim) speaker embedding.
            n_timesteps: Forwarded to the decoder.
            temperature: Forwarded to the decoder.

        Returns:
            The decoder output; the generated mel, (1, n_mels, T_target).
        """
        spk = self._project_speaker(spk_emb)
        h = self._encode_codes(codes, code_lens)

        out_lens = torch.tensor([target_len], device=codes.device)
        h, _ = self.length_regulator(h, out_lens)
        h_t = h.transpose(1, 2).contiguous()

        conds = torch.zeros_like(h_t)
        if prompt_mel is not None and prompt_mel.shape[1] > 0:
            p = min(prompt_mel.shape[1], target_len)
            conds[:, :, :p] = prompt_mel[:, :p].transpose(1, 2)

        # Match the encoder output's dtype and device, as the training path
        # does, so a reduced-precision model does not receive a float32 mask.
        mask = torch.ones(
            1, 1, target_len, device=h_t.device, dtype=h_t.dtype
        )

        mel_out = self.decoder(
            mu=h_t,
            mask=mask,
            spks=spk,
            cond=conds,
            n_timesteps=n_timesteps,
            temperature=temperature,
        )
        return mel_out
|
|