# cfm_svc/models/svc_cond_adapter.py
# Author: Hector Li — "Initial commit for Hugging Face" (commit df93d13)
"""
SVCCondAdapter: replaces F5-TTS's text conditioning pathway with SVC features.
F5-TTS text path: char_tokens (B, T) → embed + ConvNeXt → (B, T_mel, text_dim)
SVC replacement: PPG/HuBERT/F0 (B, T_feat, D) → project → (B, T_mel, text_dim)
The output shape matches F5-TTS's text_dim so the DiT sees no change.
Default text_dim=512 for F5-TTS Base (model_dim=1024).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class SVCCondAdapter(nn.Module):
def __init__(
self,
ppg_dim: int = 1280,
hubert_dim: int = 256,
f0_dim: int = 1,
spk_dim: int = 256,
out_dim: int = 512, # must match F5-TTS text_dim
feat_sr: float = 50.0, # Hz — PPG/HuBERT input frame rate
mel_sr: float = 93.75, # Hz — F5-TTS mel frame rate (24000/256)
):
super().__init__()
self.feat_sr = feat_sr
self.mel_sr = mel_sr
feat_in = ppg_dim + hubert_dim + f0_dim
self.content_proj = nn.Sequential(
nn.Linear(feat_in, out_dim * 2),
nn.SiLU(),
nn.Linear(out_dim * 2, out_dim),
)
# Speaker embedding broadcast-added to every frame
self.spk_proj = nn.Linear(spk_dim, out_dim)
# Small-scale init on output layers so the adapter starts near-zero
# (DiT initially sees near-zero conditioning, preserving pretrained state)
# but gradients can still flow. Pure zero-init kills gradients entirely
# because ∂output/∂W = 0 when W = 0.
nn.init.normal_(self.content_proj[-1].weight, std=0.01)
nn.init.zeros_(self.content_proj[-1].bias)
nn.init.normal_(self.spk_proj.weight, std=0.01)
nn.init.zeros_(self.spk_proj.bias)
def forward(
self,
ppg: torch.Tensor, # (B, T_feat, ppg_dim)
hubert: torch.Tensor, # (B, T_feat, hubert_dim)
f0: torch.Tensor, # (B, T_feat, 1)
spk: torch.Tensor, # (B, spk_dim)
target_len: int, # number of mel frames to produce
) -> torch.Tensor: # (B, target_len, out_dim)
feat = torch.cat([ppg, hubert, f0], dim=-1) # (B, T_feat, feat_in)
# Resample from feature frame rate to mel frame rate
feat = feat.transpose(1, 2) # (B, feat_in, T_feat)
feat = F.interpolate(feat, size=target_len, mode="linear", align_corners=False)
feat = feat.transpose(1, 2) # (B, target_len, feat_in)
out = self.content_proj(feat) # (B, target_len, out_dim)
out = out + self.spk_proj(spk).unsqueeze(1) # add speaker (broadcast)
return out