File size: 2,970 Bytes
df93d13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import torch.nn as nn
import torch.nn.functional as F

class CondEncoder(nn.Module):
    """Fuse frame-level features (PPG, HuBERT, F0) with a global speaker
    embedding into one conditioning sequence of a fixed target length.

    Each stream is linearly projected to ``cond_out_dim``, resampled along
    time to ``target_seq_len`` (linear interpolation), concatenated, and
    combined through a learned sigmoid gate.
    """

    def __init__(self, ppg_dim=1280, hubert_dim=256, f0_dim=1, spk_dim=256, cond_out_dim=1024):
        """
        Args:
            ppg_dim: channel dim of the PPG stream (e.g. Whisper features).
            hubert_dim: channel dim of the HuBERT stream.
            f0_dim: channel dim of the F0 contour (typically 1).
            spk_dim: dim of the global speaker embedding.
            cond_out_dim: output conditioning dim (D below).
        """
        super().__init__()
        # Per-feature projections into the shared conditioning space.
        self.ppg_proj = nn.Linear(ppg_dim, cond_out_dim)
        self.hubert_proj = nn.Linear(hubert_dim, cond_out_dim)
        self.spk_proj = nn.Linear(spk_dim, cond_out_dim)

        # Small MLP for the (near-scalar per frame) f0 contour — a single
        # Linear would be rank-limited for a 1-dim input.
        self.f0_proj = nn.Sequential(
            nn.Linear(f0_dim, 64),
            nn.GELU(),
            nn.Linear(64, cond_out_dim),
        )

        # Gated fusion over the four concatenated, time-aligned streams.
        self.gate = nn.Linear(cond_out_dim * 4, cond_out_dim * 4)
        self.combine = nn.Linear(cond_out_dim * 4, cond_out_dim)

        self.cond_out_dim = cond_out_dim

    @staticmethod
    def _resample(x, target_seq_len):
        """Linearly resample ``x`` (B, T, D) -> (B, target_seq_len, D).

        No-op (no transpose round-trip) when T already matches the target.
        """
        if x.shape[1] == target_seq_len:
            return x
        # F.interpolate expects channels-first (B, D, T).
        x = F.interpolate(
            x.transpose(1, 2),
            size=target_seq_len,
            mode='linear',
            align_corners=False,
        )
        return x.transpose(1, 2)

    def forward(self, ppg, hubert, f0, spk, target_seq_len):
        """
        Args:
            ppg: (B, T_ppg, ppg_dim)       - e.g. from Whisper ~50Hz
            hubert: (B, T_hubert, hubert_dim) - e.g. from Hubert ~50Hz
            f0: (B, T_f0, f0_dim)          - e.g. from Crepe ~100Hz
            spk: (B, spk_dim)              - 1D global embedding
            target_seq_len: int            - e.g. from Codec ~86Hz

        Returns:
            c: (B, target_seq_len, cond_out_dim)
        """
        # 1. Project each stream to D, then align all streams in time.
        ppg_r = self._resample(self.ppg_proj(ppg), target_seq_len)
        hubert_r = self._resample(self.hubert_proj(hubert), target_seq_len)
        f0_r = self._resample(self.f0_proj(f0), target_seq_len)

        # 2. Broadcast the global speaker embedding across all frames.
        spk_r = self.spk_proj(spk).unsqueeze(1).expand(-1, target_seq_len, -1)

        # 3. Learned gated fusion: element-wise sigmoid gate, then combine.
        stacked = torch.cat([ppg_r, hubert_r, f0_r, spk_r], dim=-1)  # (B, T, 4D)
        gated = stacked * torch.sigmoid(self.gate(stacked))
        return self.combine(gated)