| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class CondEncoder(nn.Module):
    """Fuse multi-rate conditioning streams into one aligned conditioning sequence.

    Takes PPG features (e.g. Whisper, ~50 Hz), Hubert features (~50 Hz), an F0
    contour (e.g. Crepe, ~100 Hz) and a global speaker embedding, projects each
    into a shared space, resamples the time-varying streams to a common target
    length (e.g. codec frame rate, ~86 Hz), then gates and fuses them.

    Args:
        ppg_dim: channel dim of the PPG stream.
        hubert_dim: channel dim of the Hubert stream.
        f0_dim: channel dim of the F0 stream (typically 1).
        spk_dim: dim of the global speaker embedding.
        cond_out_dim: dim of the fused conditioning output.
    """

    def __init__(self, ppg_dim=1280, hubert_dim=256, f0_dim=1, spk_dim=256, cond_out_dim=1024):
        super().__init__()

        # Per-stream linear projections into the shared conditioning space.
        self.ppg_proj = nn.Linear(ppg_dim, cond_out_dim)
        self.hubert_proj = nn.Linear(hubert_dim, cond_out_dim)
        self.spk_proj = nn.Linear(spk_dim, cond_out_dim)

        # F0 is (near-)scalar per frame; lift it through a small MLP so it can
        # occupy the same-width slot as the other streams.
        self.f0_proj = nn.Sequential(
            nn.Linear(f0_dim, 64),
            nn.GELU(),
            nn.Linear(64, cond_out_dim),
        )

        # Elementwise sigmoid gate over the concatenated streams, followed by
        # a linear fusion back down to cond_out_dim.
        self.gate = nn.Linear(cond_out_dim * 4, cond_out_dim * 4)
        self.combine = nn.Linear(cond_out_dim * 4, cond_out_dim)

        self.cond_out_dim = cond_out_dim

    @staticmethod
    def _resample(x, target_seq_len):
        """Linearly resample (B, T, C) -> (B, target_seq_len, C) along time.

        Returns the input unchanged when T already equals target_seq_len
        (also skipping the transpose round trip the general path needs).
        """
        if x.shape[1] == target_seq_len:
            return x
        # F.interpolate's 1-D 'linear' mode expects channels-first (B, C, T).
        return F.interpolate(
            x.transpose(1, 2), size=target_seq_len, mode='linear', align_corners=False
        ).transpose(1, 2)

    def forward(self, ppg, hubert, f0, spk, target_seq_len):
        """
        ppg: (B, T_ppg, ppg_dim) - e.g. from Whisper ~50Hz
        hubert: (B, T_hubert, hubert_dim) - e.g. from Hubert ~50Hz
        f0: (B, T_f0, 1) - e.g. from Crepe ~100Hz
        spk: (B, spk_dim) - 1D Global embedding
        target_seq_len: int - e.g. from Codec ~86Hz

        Returns:
            c: (B, target_seq_len, cond_out_dim)
        """
        # Project each time-varying stream, then align all to the target rate.
        ppg_r = self._resample(self.ppg_proj(ppg), target_seq_len)
        hubert_r = self._resample(self.hubert_proj(hubert), target_seq_len)
        f0_r = self._resample(self.f0_proj(f0), target_seq_len)

        # Broadcast the global speaker embedding across all frames.
        spk_r = self.spk_proj(spk).unsqueeze(1).expand(-1, target_seq_len, -1)

        stacked = torch.cat([ppg_r, hubert_r, f0_r, spk_r], dim=-1)

        # Per-channel sigmoid gating lets the model softly re-weight each
        # stream's contribution before the final fusion.
        gated = stacked * torch.sigmoid(self.gate(stacked))

        return self.combine(gated)
|
|