File size: 4,429 Bytes
1cd928a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
from .pcmer import PCmer

# ================= Helper functions =================
def split_to_dict(tensor, tensor_splits):
    """Split ``tensor`` along its last axis into named chunks.

    Args:
        tensor: Tensor to split along ``dim=-1``.
        tensor_splits: Mapping of chunk name -> chunk width. Iteration order
            of the mapping decides the order of the chunks.

    Returns:
        A dict mapping each name in ``tensor_splits`` to the chunk produced
        by ``torch.split`` for the corresponding width.
    """
    # torch.split takes the per-chunk widths as a list, in mapping order.
    widths = list(tensor_splits.values())
    chunks = torch.split(tensor, widths, dim=-1)
    return dict(zip(tensor_splits.keys(), chunks))

# ================= Core components =================
class FiLM(nn.Module):
    """Feature-wise linear modulation layer.

    Two linear projections of a conditioning tensor produce a multiplicative
    scale and an additive bias that are applied to the input features.
    """

    def __init__(self, feature_dim):
        """Create the scale and bias projections.

        Args:
            feature_dim: Input and output width of both linear projections.
        """
        super().__init__()
        # Creation order (scale first, then bias) is kept so that seeded
        # initialization matches existing behavior.
        self.scale_fc = nn.Linear(feature_dim, feature_dim)
        self.bias_fc = nn.Linear(feature_dim, feature_dim)

    def forward(self, x, condition):
        """Return ``x * scale(condition) + bias(condition)``.

        Args:
            x: Features to modulate.
            condition: Conditioning tensor fed to both projections; its last
                axis must have ``feature_dim`` entries.

        Returns:
            The modulated features.
        """
        gamma = self.scale_fc(condition)
        beta = self.bias_fc(condition)
        modulated = x * gamma
        return modulated + beta

# ================= Main model =================
class Unit2ControlFacV5A(nn.Module):
    """Turn content units plus f0/phase/volume/speaker inputs into control tensors.

    A small convolution stack encodes ``units``. The result is modulated by a
    ``FiLM`` layer whose condition is a 1x1-conv fusion of four 256-wide
    features (f0 + timbre, style, phase, volume), then passed through a
    ``PCmer`` decoder and a ``LayerNorm``. A weight-normalized linear layer
    projects to ``sum(output_splits.values())`` channels, which are split
    into a dict keyed like ``output_splits``.
    """

    def __init__(
            self,
            input_channel,
            output_splits,
            use_pitch_aug=False,
            pcmer_norm=False):
        """Build all sub-modules.

        Args:
            input_channel: Channel count of ``units`` (``in_channels`` of the
                first convolution).
            output_splits: Mapping of output name -> width along the last axis.
            use_pitch_aug: When true, ``aug_shift_embed`` is created as
                ``Linear(1, 256, bias=False)``; otherwise it is ``None``.
            pcmer_norm: Forwarded unchanged to ``PCmer``.
        """
        super().__init__()
        self.output_splits = output_splits

        # Shared feature width used by every sub-module below.
        width = 256

        # 1. Timbre extractor: a plain MLP with a wider hidden layer.
        self.timbre_extractor = nn.Sequential(
            nn.Linear(width, 2 * width),
            nn.SiLU(),
            nn.Linear(2 * width, width),
        )

        # 2. Scalar-feature embeddings (each an MLP from 1 -> width).
        self.f0_embed = self._make_mlp(1, width)
        self.phase_embed = self._make_mlp(1, width)
        self.volume_embed = self._make_mlp(1, width)

        # 3. Fusion: four concatenated features (timbre_f0, style, phase,
        #    volume) are reduced back to `width` channels, then drive FiLM.
        self.fuse_conv = nn.Conv1d(in_channels=4 * width, out_channels=width, kernel_size=1)
        self.film = FiLM(width)

        # 4. Base convolution stack over the unit features.
        self.stack = nn.Sequential(
            nn.Conv1d(input_channel, width, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(4, width),
            nn.LeakyReLU(),
            nn.Conv1d(width, width, kernel_size=3, stride=1, padding=1),
        )

        # 5. Transformer-style decoder.
        self.decoder = PCmer(
            num_layers=3,
            num_heads=8,
            dim_model=width,
            dim_keys=width,
            dim_values=width,
            residual_dropout=0.1,
            attention_dropout=0.1,
            pcmer_norm=pcmer_norm)

        self.norm = nn.LayerNorm(width)

        # 6. Output projection; its width is the total of all split sizes.
        self.n_out = sum(output_splits.values())
        self.dense_out = weight_norm(nn.Linear(width, self.n_out))

        # NOTE(review): this embedding is registered here but forward() never
        # applies it - confirm whether that is intended.
        self.aug_shift_embed = nn.Linear(1, width, bias=False) if use_pitch_aug else None

    def _make_mlp(self, in_channels, out_channels):
        """Return a ``Linear -> SiLU -> Linear`` stack mapping in -> out -> out."""
        first = nn.Linear(in_channels, out_channels)
        second = nn.Linear(out_channels, out_channels)
        return nn.Sequential(first, nn.SiLU(), second)

    def forward(self, units, f0, phase, volume, spk, spk_id=None, aug_shift=None, is_infer=False):
        """Compute the control tensors.

        Args:
            units: Unit features; transposed to channels-first for the conv
                stack, so presumably ``(batch, frames, input_channel)``.
            f0: Pitch track; ``f0.shape[1]`` sets the frame count and the
                last axis feeds a ``Linear(1, 256)`` - presumably
                ``(batch, frames, 1)``.
            phase: Phase input, divided by pi before embedding; same layout
                as ``f0`` is assumed.
            volume: Volume input; same layout as ``f0`` is assumed.
            spk: Speaker embedding with 256 values per batch item.
            spk_id: Accepted but not read in this method.
            aug_shift: Accepted but not read in this method.
            is_infer: Accepted but not read in this method.

        Returns:
            A tuple ``(controls, x, timbre_embed_raw)``: the dict produced by
            ``split_to_dict``, the normalized decoder output, and the timbre
            embedding reshaped to ``(batch, 1, -1)``.
        """
        # 1. Base features: conv works channels-first, then crop to the
        #    number of frames given by f0.
        encoded = self.stack(units.transpose(1, 2)).transpose(1, 2)
        n_frame = f0.shape[1]
        encoded = encoded[:, :n_frame, :]
        batch_size = encoded.shape[0]

        # 2. Separate the speaker embedding into timbre and a residual
        #    "style" part (original note: mirrors the film_mlp mode where
        #    use_tfm defaults to True).
        timbre_embed_raw = self.timbre_extractor(spk).view(batch_size, 1, -1)
        style_embed_raw = spk.view(batch_size, 1, -1) - timbre_embed_raw

        # 3. Encode the per-frame scalar features.
        log_f0 = (1 + f0 / 700).log()
        timbre_f0 = self.f0_embed(log_f0) + timbre_embed_raw
        phase_feat = self.phase_embed(phase / np.pi)
        volume_feat = self.volume_embed(volume)

        # 4. Build the FiLM condition: the style embedding is broadcast over
        #    time, the four 256-wide features are concatenated, and a 1x1
        #    conv brings them back to 256 channels.
        style_feat = style_embed_raw.expand(-1, n_frame, -1)
        stacked = torch.cat([timbre_f0, style_feat, phase_feat, volume_feat], dim=-1)
        condition_style = self.fuse_conv(stacked.permute(0, 2, 1)).permute(0, 2, 1)

        # 5. Apply FiLM modulation.
        modulated = self.film(encoded, condition_style)

        # 6. Decode, normalize, project and split into named controls.
        x = self.norm(self.decoder(modulated))
        projected = self.dense_out(x)
        controls = split_to_dict(projected, self.output_splits)

        return controls, x, timbre_embed_raw