# cfm_svc/models/cond_encoder.py
# Author: Hector Li
# Initial commit for Hugging Face (commit df93d13)
import torch
import torch.nn as nn
import torch.nn.functional as F
class CondEncoder(nn.Module):
    """Fuse per-frame acoustic features and a global speaker embedding
    into one conditioning sequence.

    Each stream (PPG, HuBERT, F0) is linearly projected to
    ``cond_out_dim``, resampled in time to ``target_seq_len`` via linear
    interpolation, concatenated with the broadcast speaker embedding,
    and merged through a learned sigmoid gate plus a final linear layer.
    """

    def __init__(self, ppg_dim=1280, hubert_dim=256, f0_dim=1, spk_dim=256, cond_out_dim=1024):
        """
        Args:
            ppg_dim: channel dim of the PPG stream (e.g. from Whisper).
            hubert_dim: channel dim of the HuBERT stream.
            f0_dim: channel dim of the F0 stream (usually 1).
            spk_dim: dim of the global speaker embedding.
            cond_out_dim: per-frame output conditioning dim.
        """
        super().__init__()
        # Per-feature projections into a common conditioning space.
        self.ppg_proj = nn.Linear(ppg_dim, cond_out_dim)
        self.hubert_proj = nn.Linear(hubert_dim, cond_out_dim)
        self.spk_proj = nn.Linear(spk_dim, cond_out_dim)
        # Small MLP embedding for the continuous F0 contour.
        self.f0_proj = nn.Sequential(
            nn.Linear(f0_dim, 64),
            nn.GELU(),
            nn.Linear(64, cond_out_dim),
        )
        # Gated fusion over the four concatenated streams.
        self.gate = nn.Linear(cond_out_dim * 4, cond_out_dim * 4)
        self.combine = nn.Linear(cond_out_dim * 4, cond_out_dim)
        self.cond_out_dim = cond_out_dim

    @staticmethod
    def _resample(x, target_seq_len):
        """Linearly resample ``(B, T, D)`` features to ``(B, target_seq_len, D)``.

        No-op (returns ``x`` unchanged) when the length already matches,
        avoiding the double transpose the original inline code performed.
        """
        if x.shape[1] == target_seq_len:
            return x
        # F.interpolate expects (B, C, T), so transpose around the call.
        resampled = F.interpolate(
            x.transpose(1, 2),
            size=target_seq_len,
            mode='linear',
            align_corners=False,
        )
        return resampled.transpose(1, 2)

    def forward(self, ppg, hubert, f0, spk, target_seq_len):
        """
        Args:
            ppg: (B, T_ppg, ppg_dim) - e.g. from Whisper ~50Hz
            hubert: (B, T_hubert, hubert_dim) - e.g. from Hubert ~50Hz
            f0: (B, T_f0, 1) - e.g. from Crepe ~100Hz
            spk: (B, spk_dim) - 1D global embedding
            target_seq_len: int - e.g. from Codec ~86Hz
        Returns:
            c: (B, target_seq_len, cond_out_dim)
        """
        # 1. Project each stream, then 2. resample to the target length.
        ppg_r = self._resample(self.ppg_proj(ppg), target_seq_len)
        hubert_r = self._resample(self.hubert_proj(hubert), target_seq_len)
        f0_r = self._resample(self.f0_proj(f0), target_seq_len)
        # 3. Broadcast the global speaker embedding across time.
        spk_r = self.spk_proj(spk).unsqueeze(1).expand(-1, target_seq_len, -1)
        # 4. Learned gated fusion: sigmoid gate modulates the stacked
        #    features before the final linear combination.
        stacked = torch.cat([ppg_r, hubert_r, f0_r, spk_r], dim=-1)  # (B, T, 4D)
        gated = stacked * torch.sigmoid(self.gate(stacked))
        c = self.combine(gated)
        return c