# Source: sunf/flow_matching/model.py — uploaded by anhtunguyen98 via
# huggingface_hub (commit 4698bfc, verified).
"""
FlowMatchingTTS – adapts CosyVoice's Conditional Flow Matching pipeline
to the semacs-tts dataset (VQ codes → mel spectrogram).
Architecture mirrors MaskedDiffWithXvec from cosyvoice/flow/flow.py:
    codes → Embedding → causal Transformer → Linear → InterpolateRegulator
        ↘ ConditionalCFM (ConditionalDecoder UNet1D) ← speaker emb
            → mel spectrogram loss
"""
import math
import random
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as ckpt
from omegaconf import DictConfig
from flow_matching.utils.cfm import ConditionalCFM
from flow_matching.utils.decoder import ConditionalDecoder
from flow_matching.utils.length_regulator import InterpolateRegulator
from flow_matching.utils.mask import make_pad_mask
class SinusoidalPE(nn.Module):
    """Additive sinusoidal positional encoding.

    Precomputes a ``(1, max_len, d_model)`` table holding sines on even
    channels and cosines on odd channels. ``forward`` adds the first ``T``
    rows of the table to its input and then applies dropout.

    Args:
        d_model: channel dimension of the input. Odd values are supported.
        dropout: dropout probability applied after the addition.
        max_len: longest sequence length covered by the precomputed table.
    """

    def __init__(self, d_model: int, dropout: float = 0.0, max_len: int = 8192):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).float().unsqueeze(1)  # (max_len, 1)
        div = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        pe[:, 0::2] = torch.sin(pos * div)
        # There are only floor(d_model / 2) odd channels. For odd d_model that
        # is one fewer than len(div), so trim div to avoid a shape mismatch.
        # For even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        # Registered buffer: moves with .to(device) and is saved in state_dict.
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe)`` for ``x`` of shape (B, T, d_model), T <= max_len."""
        return self.dropout(x + self.pe[:, :x.size(1)])
def _forward_layer(layer, x, attn_mask, pad_mask):
return layer(x, src_mask=attn_mask, src_key_padding_mask=pad_mask)
class CodeEncoder(nn.Module):
    """Stack of pre-LN Transformer encoder layers over already-embedded VQ codes.

    Adds sinusoidal positions, optionally restricts each position to attend
    only to itself and earlier positions (``causal=True``), ignores padded keys
    derived from ``lengths``, and finishes with a LayerNorm. When
    ``grad_checkpoint`` is set, each layer is checkpointed during training only.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        num_layers: int,
        ffn_dim: int,
        dropout: float,
        causal: bool,
        grad_checkpoint: bool,
    ):
        super().__init__()
        self.causal = causal
        self.grad_checkpoint = grad_checkpoint
        self.pos_enc = SinusoidalPE(hidden_dim, dropout)

        def make_layer() -> nn.TransformerEncoderLayer:
            # norm_first=True gives pre-LN layers, chosen for training stability.
            return nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=ffn_dim,
                dropout=dropout,
                batch_first=True,
                norm_first=True,
            )

        self.layers = nn.ModuleList([make_layer() for _ in range(num_layers)])
        self.norm = nn.LayerNorm(hidden_dim)
        self._hidden_dim = hidden_dim

    def output_size(self) -> int:
        """Return the feature dimension of the encoder output."""
        return self._hidden_dim

    def forward(self, x: torch.Tensor, lengths: torch.Tensor):
        """Encode embedded codes.

        Args:
            x: (B, T, hidden_dim) embedded codes.
            lengths: (B,) number of valid frames per sample.

        Returns:
            ``(out, lengths)``: ``out`` is (B, T, hidden_dim) after the final
            LayerNorm; ``lengths`` is returned unchanged.
        """
        _, seq_len, _ = x.shape
        hidden = self.pos_enc(x)
        key_padding = make_pad_mask(lengths, seq_len)  # (B, T), True = padded
        # Boolean mask, True strictly above the diagonal: position t is blocked
        # from attending to any later position.
        future_mask = (
            torch.ones(seq_len, seq_len, device=hidden.device, dtype=torch.bool).triu(diagonal=1)
            if self.causal
            else None
        )
        checkpointing = self.grad_checkpoint and self.training
        for enc_layer in self.layers:
            if not checkpointing:
                hidden = _forward_layer(enc_layer, hidden, future_mask, key_padding)
                continue
            # Trade compute for memory: activations are recomputed in backward.
            hidden = ckpt.checkpoint(
                _forward_layer, enc_layer, hidden, future_mask, key_padding,
                use_reentrant=False,
            )
        return self.norm(hidden), lengths
class FlowMatchingTTS(nn.Module):
    """
    Full TTS flow matching model.

    Training input (keys of the ``batch`` dict given to ``forward``):
        codes      (B, num_q, T_codes)  VQ code indices
        code_lens  (B,)                 valid code lengths
        mel        (B, T_mel, n_mels)   target log-mel spectrogram
        mel_lens   (B,)                 valid mel frame lengths
        embedding  (B, spk_emb_dim)     pre-extracted CAM++ speaker embedding
                                        (spk_emb_dim e.g. 192)
    Output:
        {'loss': scalar}  conditional flow matching loss
    """

    def __init__(self, cfg):
        """Build all sub-modules from ``cfg``.

        Reads ``cfg.data.n_mels``; ``cfg.model.{hidden_dim, spk_emb_dim,
        codebook_size, num_heads, num_layers, ffn_dim, dropout, causal,
        grad_checkpoint}``; and ``cfg.cfm.{sigma_min, t_scheduler,
        training_cfg_rate, inference_cfg_rate}``. The ConditionalDecoder
        hyper-parameters below are fixed in code.
        """
        super().__init__()
        m = cfg.model
        data = cfg.data
        cfm = cfg.cfm
        self.n_mels = data.n_mels          # e.g. 100
        hidden_dim = m.hidden_dim          # e.g. 768
        spk_emb_dim = m.spk_emb_dim        # e.g. 192
        # ── code embedding: sum across quantizers ────────────────────────
        # NOTE(review): one table is shared by every quantizer, so summing
        # cannot tell quantizers apart unless codes are already offset per
        # quantizer upstream — confirm against the data pipeline.
        self.code_embedding = nn.Embedding(m.codebook_size, hidden_dim)
        # ── causal transformer encoder ───────────────────────────────────
        self.encoder = CodeEncoder(
            hidden_dim=hidden_dim,
            num_heads=m.num_heads,
            num_layers=m.num_layers,
            ffn_dim=m.ffn_dim,
            dropout=m.dropout,
            causal=m.causal,
            grad_checkpoint=m.grad_checkpoint,
        )
        # ── project encoder output to mel dimension ──────────────────────
        self.encoder_proj = nn.Linear(hidden_dim, self.n_mels)
        # ── speaker embedding: spk_emb_dim → n_mels ─────────────────────
        self.spk_embed_affine = nn.Linear(spk_emb_dim, self.n_mels)
        # ── length regulator: upsample codes to mel frame rate ───────────
        self.length_regulator = InterpolateRegulator(
            channels=self.n_mels,
            sampling_ratios=(),
        )
        # ── conditional flow matching decoder ────────────────────────────
        # ConditionalDecoder input = concat(x, mu, spks_t, cond) = 4 × n_mels
        cfm_params = DictConfig({
            'sigma_min': cfm.sigma_min,
            'solver': 'euler',
            't_scheduler': cfm.t_scheduler,
            'training_cfg_rate': cfm.training_cfg_rate,
            'inference_cfg_rate': cfm.inference_cfg_rate,
            'reg_loss_type': 'l1',
        })
        estimator = ConditionalDecoder(
            in_channels=4 * self.n_mels,  # x + mu + spks_expanded + cond
            out_channels=self.n_mels,
            channels=(256, 256),
            dropout=0.05,
            attention_head_dim=64,
            n_blocks=4,
            num_mid_blocks=12,
            num_heads=8,
            act_fn='gelu',
        )
        self.decoder = ConditionalCFM(
            in_channels=self.n_mels,
            cfm_params=cfm_params,
            n_spks=1,
            spk_emb_dim=self.n_mels,  # already projected to n_mels
            estimator=estimator,
        )

    @staticmethod
    def _match_time_length(x: torch.Tensor, length: int) -> torch.Tensor:
        """Return ``x`` (B, C, T) with its last (time) axis made exactly ``length``.

        * T == length: ``x`` is returned as-is.
        * T >  length: cropped to the first ``length`` frames. Resampling here
          would shift valid frames relative to the loss mask; cropping only
          drops trailing frames. This assumes ``length`` equals max(mel_lens),
          so those frames are padding — confirm with the length regulator.
        * T <  length: nearest-neighbour interpolation, the previous
          behaviour. This case means ``mel`` and ``mel_lens`` disagree.
        """
        cur_len = x.shape[-1]
        if cur_len == length:
            return x
        if cur_len > length:
            return x[..., :length].contiguous()
        return F.interpolate(x, size=length, mode='nearest')

    # ── forward (training) ───────────────────────────────────────────────────
    def forward(self, batch: dict, device) -> dict:
        """
        Same interface as cosyvoice MaskedDiffWithXvec: model(batch, device).

        Batch keys (added by Executor before this call):
            codes      (B, num_q, T_codes)
            code_lens  (B,)
            mel        (B, T_mel, n_mels)
            mel_lens   (B,)
            embedding  (B, spk_emb_dim) CAM++ speaker embedding; it is
                       L2-normalised here before projection.

        Returns:
            {'loss': loss} as returned by ``self.decoder.compute_loss``.
        """
        codes = batch['codes'].to(device)
        code_lens = batch['code_lens'].to(device)
        mel = batch['mel'].to(device)
        mel_lens = batch['mel_lens'].to(device)
        embedding = batch['embedding'].to(device)  # (B, spk_emb_dim)

        # Speaker projection
        spk = F.normalize(embedding, dim=-1)
        spk = self.spk_embed_affine(spk)  # (B, n_mels)

        # Code embedding: sum over quantizer axis
        x = self.code_embedding(codes)  # (B, num_q, T, hidden_dim)
        x = x.sum(dim=1)                # (B, T_codes, hidden_dim)

        # Encode
        h, _ = self.encoder(x, code_lens)  # (B, T_codes, hidden_dim)
        h = self.encoder_proj(h)           # (B, T_codes, n_mels)

        # Upsample to mel frame rate
        h, _ = self.length_regulator(h, mel_lens)  # (B, T_mel, n_mels)

        # Build conditioning: for each sample, with 50 % probability copy a
        # random-length mel prefix of up to 80 % of its valid frames.
        conds = torch.zeros_like(mel)
        for i, j in enumerate(mel_lens.tolist()):
            if random.random() < 0.5:
                continue
            idx = random.randint(0, int(0.8 * j))
            conds[i, :idx] = mel[i, :idx]
        conds = conds.transpose(1, 2)  # (B, n_mels, T_mel)

        # Transpose to (B, n_mels, T) for the CFM decoder
        mel_t = mel.transpose(1, 2).contiguous()  # (B, n_mels, T_mel)
        h_t = h.transpose(1, 2).contiguous()      # (B, n_mels, T_mel)

        # Safety alignment (no-op when length_regulator output already matches)
        target_len = h_t.shape[-1]
        mel_t = self._match_time_length(mel_t, target_len)
        conds = self._match_time_length(conds, target_len)

        mask = (~make_pad_mask(mel_lens)).to(h)  # (B, T_mel)
        loss, _ = self.decoder.compute_loss(
            mel_t,
            mask.unsqueeze(1),
            h_t,
            spk,
            cond=conds,
        )
        return {'loss': loss}

    # ── inference ────────────────────────────────────────────────────────────
    @torch.inference_mode()
    def inference(
        self,
        codes: torch.Tensor,       # (1, num_q, T_codes)
        code_lens: torch.Tensor,   # (1,)
        prompt_mel: torch.Tensor,  # (1, T_prompt, n_mels) or empty
        target_len: int,           # desired output mel frames
        spk_emb: torch.Tensor,     # (1, spk_emb_dim)
        n_timesteps: int = 10,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """Generate a mel spectrogram for a single utterance.

        Args:
            codes: (1, num_q, T_codes) VQ code indices.
            code_lens: (1,) valid code length.
            prompt_mel: (1, T_prompt, n_mels) prompt mel. When T_prompt is 0 no
                prompt is used; otherwise its first min(T_prompt, target_len)
                frames become the decoder's conditioning prefix.
            target_len: number of mel frames to generate.
            spk_emb: (1, spk_emb_dim) speaker embedding; L2-normalised here.
            n_timesteps: ODE steps forwarded to the CFM decoder.
            temperature: noise temperature forwarded to the CFM decoder.

        Returns:
            The decoder output, annotated as (1, n_mels, target_len).
        """
        spk = F.normalize(spk_emb, dim=-1)
        spk = self.spk_embed_affine(spk)  # (1, n_mels)

        x = self.code_embedding(codes).sum(dim=1)  # (1, T_codes, hidden_dim)
        h, _ = self.encoder(x, code_lens)
        h = self.encoder_proj(h)  # (1, T_codes, n_mels)

        out_lens = torch.tensor([target_len], device=codes.device)
        h, _ = self.length_regulator(h, out_lens)  # (1, target_len, n_mels)
        h_t = h.transpose(1, 2).contiguous()       # (1, n_mels, target_len)

        conds = torch.zeros_like(h_t)
        if prompt_mel.shape[1] > 0:
            p = min(prompt_mel.shape[1], target_len)
            conds[:, :, :p] = prompt_mel[:, :p].transpose(1, 2)

        # Build the mask with the decoder input's dtype/device, mirroring the
        # ``.to(h)`` used in training, so fp16/bf16 inference does not mix in a
        # float32 mask.
        mask = torch.ones(1, 1, target_len, device=h_t.device, dtype=h_t.dtype)
        mel_out = self.decoder(
            mu=h_t,
            mask=mask,
            spks=spk,
            cond=conds,
            n_timesteps=n_timesteps,
            temperature=temperature,
        )
        return mel_out  # (1, n_mels, target_len)