import torch
import torch.nn as nn
import torch.nn.functional as F


class CondEncoder(nn.Module):
    """Fuse multi-rate conditioning features into one frame-aligned sequence.

    Projects PPG, Hubert, F0 and speaker features into a shared dimension,
    linearly resamples the time-varying streams to a common target length
    (e.g. the codec frame rate), broadcasts the global speaker embedding,
    and combines everything with a learned sigmoid-gated fusion.
    """

    def __init__(self, ppg_dim=1280, hubert_dim=256, f0_dim=1, spk_dim=256,
                 cond_out_dim=1024):
        super().__init__()
        # Per-feature linear projections into the shared conditioning space.
        self.ppg_proj = nn.Linear(ppg_dim, cond_out_dim)
        self.hubert_proj = nn.Linear(hubert_dim, cond_out_dim)
        self.spk_proj = nn.Linear(spk_dim, cond_out_dim)
        # Simple f0 embedding (continuous mapping through a small MLP).
        self.f0_proj = nn.Sequential(
            nn.Linear(f0_dim, 64),
            nn.GELU(),
            nn.Linear(64, cond_out_dim),
        )
        # Gated fusion over the concatenation of the four projected streams.
        self.gate = nn.Linear(cond_out_dim * 4, cond_out_dim * 4)
        self.combine = nn.Linear(cond_out_dim * 4, cond_out_dim)
        self.cond_out_dim = cond_out_dim

    @staticmethod
    def _resample_to(x, target_seq_len):
        """Linearly resample ``x`` (B, T, D) along time to ``target_seq_len``.

        Returns ``x`` unchanged when the length already matches, so the
        common aligned case costs nothing. ``F.interpolate`` expects
        (B, C, T), hence the transposes around the call.
        """
        if x.shape[1] == target_seq_len:
            return x
        x = x.transpose(1, 2)  # (B, D, T)
        x = F.interpolate(x, size=target_seq_len, mode='linear',
                          align_corners=False)
        return x.transpose(1, 2)  # (B, target_seq_len, D)

    def forward(self, ppg, hubert, f0, spk, target_seq_len):
        """
        ppg: (B, T_ppg, ppg_dim) - e.g. from Whisper ~50Hz
        hubert: (B, T_hubert, hubert_dim) - e.g. from Hubert ~50Hz
        f0: (B, T_f0, 1) - e.g. from Crepe ~100Hz
        spk: (B, spk_dim) - 1D Global embedding
        target_seq_len: int - e.g. from Codec ~86Hz

        Returns:
            c: (B, target_seq_len, cond_out_dim)
        """
        # 1. Project each stream into the shared conditioning dimension.
        ppg_h = self.ppg_proj(ppg)            # (B, T_ppg, D)
        hubert_h = self.hubert_proj(hubert)   # (B, T_hubert, D)
        f0_h = self.f0_proj(f0)               # (B, T_f0, D)

        # 2. Temporal resampling (linear interpolation) to the target length.
        ppg_r = self._resample_to(ppg_h, target_seq_len)
        hubert_r = self._resample_to(hubert_h, target_seq_len)
        f0_r = self._resample_to(f0_h, target_seq_len)

        # 3. Broadcast the global speaker embedding across all frames.
        spk_h = self.spk_proj(spk)                               # (B, D)
        spk_r = spk_h.unsqueeze(1).expand(-1, target_seq_len, -1)  # (B, T, D)

        # 4. Learned gated fusion: sigmoid gate modulates each channel of the
        #    concatenated features before the final combining projection.
        stacked = torch.cat([ppg_r, hubert_r, f0_r, spk_r], dim=-1)  # (B, T, 4D)
        gate_weights = torch.sigmoid(self.gate(stacked))
        gated = stacked * gate_weights
        c = self.combine(gated)
        return c