"""
FlowMatchingTTS – adapts CosyVoice's Conditional Flow Matching pipeline
to the semacs-tts dataset (VQ codes → mel spectrogram).

Architecture mirrors MaskedDiffWithXvec from cosyvoice/flow/flow.py:

    codes → Embedding → causal Transformer → Linear → InterpolateRegulator
                                                              ↘
                                ConditionalCFM (ConditionalDecoder UNet1D)
                                    ← speaker emb
                                    → mel spectrogram loss
"""

import math
import random
import sys
import os

# Make the repository root importable so that `flow_matching.*` resolves when
# this file is loaded directly; this must run before the project imports below.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as ckpt
from omegaconf import DictConfig

from flow_matching.utils.cfm import ConditionalCFM
from flow_matching.utils.decoder import ConditionalDecoder
from flow_matching.utils.length_regulator import InterpolateRegulator
from flow_matching.utils.mask import make_pad_mask


class SinusoidalPE(nn.Module):
    """Fixed sinusoidal positional encoding added to a (B, T, d_model) input.

    The table is precomputed for ``max_len`` positions and stored as a
    buffer named ``pe``; inputs longer than ``max_len`` are not supported.

    Args:
        d_model: Feature dimension of the input (even or odd).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model: int, dropout: float = 0.0, max_len: int = 8192):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).float().unsqueeze(1)
        div = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(pos * div)
        # An odd d_model has one fewer cosine column than sine column, so the
        # frequency vector is trimmed to avoid a shape mismatch. For an even
        # d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))   # (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the first ``x.size(1)`` positional rows to ``x``, then apply dropout."""
        return self.dropout(x + self.pe[:, :x.size(1)])


def _forward_layer(layer, x, attn_mask, pad_mask):
    """Run one encoder layer with positional arguments only.

    Module-level helper so ``torch.utils.checkpoint`` can be handed a plain
    callable plus positional arguments.
    """
    return layer(x, src_mask=attn_mask, src_key_padding_mask=pad_mask)


class CodeEncoder(nn.Module):
    """Causal Transformer encoder that operates on pre-embedded VQ codes.

    Args:
        hidden_dim: Model width (``d_model`` of every layer).
        num_heads: Attention heads per layer.
        num_layers: Number of stacked ``nn.TransformerEncoderLayer`` modules.
        ffn_dim: Feed-forward inner dimension.
        dropout: Dropout used by the positional encoding and every layer.
        causal: If true, an upper-triangular mask blocks attention to
            future positions.
        grad_checkpoint: If true, layers are gradient-checkpointed while the
            module is in training mode.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        num_layers: int,
        ffn_dim: int,
        dropout: float,
        causal: bool,
        grad_checkpoint: bool,
    ):
        super().__init__()
        self.causal = causal
        self.grad_checkpoint = grad_checkpoint
        self.pos_enc = SinusoidalPE(hidden_dim, dropout)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=ffn_dim,
                dropout=dropout,
                batch_first=True,
                norm_first=True,   # pre-LN for training stability
            )
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(hidden_dim)
        self._hidden_dim = hidden_dim

    def output_size(self) -> int:
        """Return the feature dimension of the encoder output."""
        return self._hidden_dim

    def forward(self, x: torch.Tensor, lengths: torch.Tensor):
        """
        x:       (B, T, hidden_dim) – embedded codes
        lengths: (B,) – valid code lengths per sample
        Returns: (B, T, hidden_dim), lengths
        """
        _, T, _ = x.shape
        x = self.pos_enc(x)

        pad_mask = make_pad_mask(lengths, T)   # (B, T), True = padded

        attn_mask = None
        if self.causal:
            # True above the diagonal = position may not attend to the future.
            attn_mask = torch.triu(
                torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
            )

        use_ckpt = self.grad_checkpoint and self.training
        for layer in self.layers:
            if use_ckpt:
                x = ckpt.checkpoint(
                    _forward_layer, layer, x, attn_mask, pad_mask,
                    use_reentrant=False,
                )
            else:
                x = _forward_layer(layer, x, attn_mask, pad_mask)

        return self.norm(x), lengths


class FlowMatchingTTS(nn.Module):
    """
    Full TTS flow matching model.

    Training input:
        codes      (B, num_q, T_codes)   VQ code indices
        code_lens  (B,)                  valid code lengths
        mel        (B, T_mel, n_mels)    target log-mel spectrogram
        mel_lens   (B,)                  valid mel frame lengths
        spk_emb    (B, 192)              CAM++ speaker embedding (pre-extracted)

    Output:
        {'loss': scalar}  conditional flow matching loss

    Args:
        cfg: Config object read via attribute access; this constructor uses
            ``cfg.model``, ``cfg.data`` and ``cfg.cfm``.
    """

    def __init__(self, cfg):
        super().__init__()
        m = cfg.model
        data = cfg.data
        cfm = cfg.cfm

        self.n_mels = data.n_mels          # 100
        hidden_dim = m.hidden_dim          # 768
        spk_emb_dim = m.spk_emb_dim        # 192

        # ── code embedding: sum across quantizers ────────────────────────
        self.code_embedding = nn.Embedding(m.codebook_size, hidden_dim)

        # ── causal transformer encoder ───────────────────────────────────
        self.encoder = CodeEncoder(
            hidden_dim=hidden_dim,
            num_heads=m.num_heads,
            num_layers=m.num_layers,
            ffn_dim=m.ffn_dim,
            dropout=m.dropout,
            causal=m.causal,
            grad_checkpoint=m.grad_checkpoint,
        )

        # ── project encoder output to mel dimension ──────────────────────
        self.encoder_proj = nn.Linear(hidden_dim, self.n_mels)

        # ── speaker embedding: 192 → n_mels ─────────────────────────────
        self.spk_embed_affine = nn.Linear(spk_emb_dim, self.n_mels)

        # ── length regulator: upsample codes to mel frame rate ───────────
        self.length_regulator = InterpolateRegulator(
            channels=self.n_mels,
            sampling_ratios=(),
        )

        # ── conditional flow matching decoder ────────────────────────────
        # ConditionalDecoder input = concat(x, mu, spks_t, cond) = 4 × n_mels
        cfm_params = DictConfig({
            'sigma_min': cfm.sigma_min,
            'solver': 'euler',
            't_scheduler': cfm.t_scheduler,
            'training_cfg_rate': cfm.training_cfg_rate,
            'inference_cfg_rate': cfm.inference_cfg_rate,
            'reg_loss_type': 'l1',
        })
        estimator = ConditionalDecoder(
            in_channels=4 * self.n_mels,   # x + mu + spks_expanded + cond
            out_channels=self.n_mels,
            channels=(256, 256),
            dropout=0.05,
            attention_head_dim=64,
            n_blocks=4,
            num_mid_blocks=12,
            num_heads=8,
            act_fn='gelu',
        )
        self.decoder = ConditionalCFM(
            in_channels=self.n_mels,
            cfm_params=cfm_params,
            n_spks=1,
            spk_emb_dim=self.n_mels,       # already projected to n_mels
            estimator=estimator,
        )

    # ── forward (training) ───────────────────────────────────────────────────

    def forward(self, batch: dict, device) -> dict:
        """
        Same interface as cosyvoice MaskedDiffWithXvec: model(batch, device).

        Batch keys (added by Executor before this call):
            codes      (B, num_q, T_codes)
            code_lens  (B,)
            mel        (B, T_mel, n_mels)
            mel_lens   (B,)
            embedding  (B, 192)   L2-normalised CAM++ speaker embedding

        Returns:
            ``{'loss': loss}`` where ``loss`` is the first value returned by
            ``self.decoder.compute_loss``.
        """
        codes = batch['codes'].to(device)
        code_lens = batch['code_lens'].to(device)
        mel = batch['mel'].to(device)
        mel_lens = batch['mel_lens'].to(device)
        embedding = batch['embedding'].to(device)   # (B, 192)

        # Speaker projection
        spk = F.normalize(embedding, dim=-1)
        spk = self.spk_embed_affine(spk)             # (B, n_mels)

        # Code embedding: sum over quantizer axis
        x = self.code_embedding(codes)               # (B, num_q, T, hidden_dim)
        x = x.sum(dim=1)                             # (B, T_codes, hidden_dim)

        # Encode
        h, _ = self.encoder(x, code_lens)            # (B, T_codes, hidden_dim)
        h = self.encoder_proj(h)                     # (B, T_codes, n_mels)

        # Upsample to mel frame rate
        h, _ = self.length_regulator(h, mel_lens)    # (B, T_mel, n_mels)

        # Build conditioning: random-length mel prefix (50 % chance per sample)
        conds = torch.zeros_like(mel)
        for i, j in enumerate(mel_lens.tolist()):
            if random.random() < 0.5:
                continue
            idx = random.randint(0, int(0.8 * j))
            conds[i, :idx] = mel[i, :idx]
        conds = conds.transpose(1, 2)                # (B, n_mels, T_mel)

        # Transpose to (B, n_mels, T) for the CFM decoder
        mel_t = mel.transpose(1, 2).contiguous()     # (B, n_mels, T_mel)
        h_t = h.transpose(1, 2).contiguous()         # (B, n_mels, T_mel)

        # Safety alignment (no-op when length_regulator works correctly)
        if mel_t.shape[-1] != h_t.shape[-1]:
            mel_t = F.interpolate(mel_t, size=h_t.shape[-1], mode='nearest')
            conds = F.interpolate(conds, size=h_t.shape[-1], mode='nearest')

        # Pin the mask width to the decoder input length so it always lines up
        # with h_t / mel_t / conds, including after the alignment above.
        mask = (~make_pad_mask(mel_lens, h_t.shape[-1])).to(h_t)   # (B, T_mel)

        loss, _ = self.decoder.compute_loss(
            mel_t,
            mask.unsqueeze(1),
            h_t,
            spk,
            cond=conds,
        )
        return {'loss': loss}

    # ── inference ────────────────────────────────────────────────────────────

    @torch.inference_mode()
    def inference(
        self,
        codes: torch.Tensor,       # (1, num_q, T_codes)
        code_lens: torch.Tensor,   # (1,)
        prompt_mel: torch.Tensor,  # (1, T_prompt, n_mels) or empty
        target_len: int,           # desired output mel frames
        spk_emb: torch.Tensor,     # (1, 192)
        n_timesteps: int = 10,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """Returns generated mel: (1, n_mels, T_target)

        Inputs are used as given (no device transfer happens here), so the
        caller must place them on the model's device. When ``prompt_mel`` has
        at least one frame, its first ``min(T_prompt, target_len)`` frames are
        copied into the conditioning tensor; the rest stays zero.
        """
        spk = F.normalize(spk_emb, dim=-1)
        spk = self.spk_embed_affine(spk)             # (1, n_mels)

        x = self.code_embedding(codes).sum(dim=1)    # (1, T_codes, hidden_dim)
        h, _ = self.encoder(x, code_lens)
        h = self.encoder_proj(h)                     # (1, T_codes, n_mels)

        out_lens = torch.tensor([target_len], device=codes.device)
        h, _ = self.length_regulator(h, out_lens)    # (1, target_len, n_mels)
        h_t = h.transpose(1, 2).contiguous()         # (1, n_mels, target_len)

        conds = torch.zeros_like(h_t)
        if prompt_mel.shape[1] > 0:
            p = min(prompt_mel.shape[1], target_len)
            conds[:, :, :p] = prompt_mel[:, :p].transpose(1, 2)

        # All-ones mask derived from h_t so its dtype, device and length match
        # the decoder input (e.g. when the model runs in half precision).
        mask = torch.ones_like(h_t[:, :1, :])        # (1, 1, target_len)

        mel_out = self.decoder(
            mu=h_t,
            mask=mask,
            spks=spk,
            cond=conds,
            n_timesteps=n_timesteps,
            temperature=temperature,
        )
        return mel_out   # (1, n_mels, target_len)