import os
import json
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d

# External dependency imports
from .diffusion import GaussianDiffusion, CFGDiffusion
from .wavenet import WaveNet, ControlWaveNet
from .ddsp.vocoder import CombSubFastFacV5A

# ================= Helper functions =================

def infonce_loss(spk_embeddings, spk_ids, temperature=0.1, supervised=True):
    """Compute the InfoNCE contrastive loss over speaker embeddings."""
    spk_embeddings = F.normalize(spk_embeddings, p=2, dim=1)
    similarity_matrix = torch.matmul(spk_embeddings, spk_embeddings.T) / temperature
    
    if supervised:
        # Same-speaker pairs (excluding self-pairs) are positives.
        mask = (spk_ids.unsqueeze(1) == spk_ids.unsqueeze(0)).float()
        pos_mask = mask - torch.diag(torch.ones(mask.shape[0], device=mask.device))
        neg_mask = 1 - mask
    else:
        # Unsupervised: each embedding is its own positive.
        pos_mask = torch.eye(spk_embeddings.size(0), device=spk_embeddings.device).bool()
        neg_mask = ~pos_mask
    
    # Push masked-out entries toward -inf so they vanish in the logsumexp.
    pos_mask_add = neg_mask * (-1000)
    neg_mask_add = pos_mask * (-1000)
    log_infonce_per_example = (similarity_matrix * pos_mask + pos_mask_add).logsumexp(-1) - \
                              (similarity_matrix * neg_mask + neg_mask_add).logsumexp(-1)
    return -torch.mean(log_infonce_per_example)
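
# A minimal usage sketch (hypothetical shapes and values, not from the training
# pipeline): four embeddings from two speakers; with supervised=True, the
# same-speaker off-diagonal pairs serve as positives.
#
#   emb = torch.randn(4, 256)
#   ids = torch.tensor([0, 0, 1, 1])
#   loss = infonce_loss(emb, ids, temperature=0.1, supervised=True)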

def get_f0_loss(spk, f0_pred_mu, f0_pred_var, f0_gt):
    """Compute the L1 loss between predicted and ground-truth F0 statistics.

    Note: `spk` is accepted for interface compatibility but is unused here.
    """
    # Mel-style log compression of the ground-truth F0 contour.
    f0_gt = torch.log(1 + f0_gt / 700)
    f0_gt_mu = f0_gt.mean(dim=1)
    f0_gt_var = f0_gt.var(dim=1)
    loss = F.l1_loss(f0_pred_mu, f0_gt_mu) + F.l1_loss(f0_pred_var, f0_gt_var)
    return loss

def adjust_f0(src_f0, src_log_f0_mean, src_log_f0_var, tar_log_f0_mean, tar_log_f0_var):
    """Shift the source F0 toward the target log-F0 distribution, rounded to whole semitones.

    Note: the variance arguments are accepted but currently unused.
    """
    # Convert the log-mean difference to semitones (12 semitones per octave).
    semitone_difference = 12 * (tar_log_f0_mean - src_log_f0_mean) / torch.log(torch.tensor(2.0))
    semitone_difference_rounded = torch.round(semitone_difference)
    adjustment_factor = torch.pow(2, semitone_difference_rounded / 12)
    adjusted_f0 = src_f0 * adjustment_factor
    return adjusted_f0, semitone_difference_rounded
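
# A minimal worked example (hypothetical values): if the target log-F0 mean
# exceeds the source mean by log(2), the rounded shift is 12 semitones (one
# octave), so a 220 Hz contour is moved to 440 Hz.
#
#   import math
#   src_f0 = torch.full((1, 100), 220.0)
#   shifted, key = adjust_f0(src_f0,
#                            src_log_f0_mean=torch.tensor(1.0), src_log_f0_var=torch.tensor(0.1),
#                            tar_log_f0_mean=torch.tensor(1.0 + math.log(2)), tar_log_f0_var=torch.tensor(0.1))
#   # shifted == 440.0 everywhere, key == 12.0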

# ================= Core component classes =================

class SpeakerClassifier(nn.Module):
    """Linear classifier mapping a speaker embedding to speaker logits."""

    def __init__(self, input_dim=256, num_speakers=100):
        super(SpeakerClassifier, self).__init__()
        self.fc = nn.Linear(input_dim, num_speakers)

    def forward(self, x):
        logits = self.fc(x)
        prob = F.softmax(logits, dim=-1)
        pred_label = torch.argmax(prob, dim=-1)
        return logits, pred_label

class F0Predictor(nn.Module):
    """MLP predicting the mean and variance of a speaker's log-F0 distribution."""

    def __init__(self, input_dim=256, hidden_dim=512, output_dim=1):
        super(F0Predictor, self).__init__()
        self.shared_layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
        )
        self.mean_branch = nn.Linear(hidden_dim, output_dim)
        self.var_branch = nn.Linear(hidden_dim, output_dim)

    def forward(self, spk):
        features = self.shared_layers(spk)
        f0_mean = self.mean_branch(features)
        f0_var = self.var_branch(features)
        return f0_mean, f0_var

# ================= Main model class =================

class HQ_SVC(torch.nn.Module):
    def __init__(self, hop_size, args):
        super(HQ_SVC, self).__init__()
        self.hop_size = hop_size  # keep the configured hop size instead of a hard-coded 512
        in_channels = 256
        num_mels = 128
        self.sampling_rate = 44100
        n_unit = 256
        pcmer_norm = False
        out_dims = 128
        n_layers = 20
        n_chans = 512
        n_hidden = 256

        self.guidance_scale = args.guidance_scale
        self.drop_rate = args.drop_rate
        self.use_pitch_aug = False
        self.use_tfm = args.use_tfm
        self.mode = args.mode
        self.use_mi_loss = args.use_mi_loss
        self.use_style_loss = args.use_style_loss
        self.conv_1 = Conv1d(in_channels, num_mels, 3, 1, padding=1)

        self.ddsp_model = CombSubFastFacV5A(self.sampling_rate, hop_size, n_unit, self.use_pitch_aug, self.use_tfm, pcmer_norm=pcmer_norm, mode=self.mode)
        if self.guidance_scale is not None and self.guidance_scale >= 0:
            # Classifier-free guidance: diffusion conditioned via ControlWaveNet.
            self.diff_model = CFGDiffusion(ControlWaveNet(out_dims, n_layers, n_chans, n_hidden), out_dims=out_dims, drop_rate=self.drop_rate)
        else:
            self.diff_model = GaussianDiffusion(WaveNet(out_dims, n_layers, n_chans, n_hidden))

        self.speaker_classifier = SpeakerClassifier(input_dim=256, num_speakers=100)
        self.ce = nn.CrossEntropyLoss()
        if 'pred_f0' in self.mode:
            self.f0_predictor = F0Predictor(input_dim=256, hidden_dim=512, output_dim=1)

    def forward(self, x, f0, volume, spk, spk_id=None, src_spk=None, gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=100, use_tqdm=True, aug_shift=False, vocoder=None, return_wav=False, use_ssim_loss=False, return_timbre=False):
        device = x.device
        if 'pred_f0' in self.mode and src_spk is not None and infer:
            tar_log_f0_mean, tar_log_f0_var = self.f0_predictor(self.ddsp_model.unit2ctrl.timbre_extractor(spk))
            src_log_f0_mean, src_log_f0_var = self.f0_predictor(self.ddsp_model.unit2ctrl.timbre_extractor(src_spk))
            f0, shift_key = adjust_f0(f0, src_log_f0_mean, src_log_f0_var, tar_log_f0_mean, tar_log_f0_var)
            print(f'shift key: {shift_key}')
        
        outputs = self.ddsp_model(x, f0, volume, spk, aug_shift=aug_shift, infer=infer)
        if 'adaln_mlp' in self.mode:
            ddsp_wav, hidden, timbre_f0, timbre, style = outputs
        else:
            ddsp_wav, hidden, timbre = outputs

        if return_timbre:
            return timbre
        
        ddsp_mel = vocoder.extract(ddsp_wav) if vocoder is not None else None
        if gt_spec is not None:
            gt_spec = gt_spec.permute(0, 2, 1)

        if not infer:
            ddsp_loss = F.mse_loss(ddsp_mel, gt_spec)
            if self.guidance_scale is not None and self.guidance_scale >= 0:
                diff_loss = self.diff_model(hidden, timbre_f0, gt_spec=gt_spec, k_step=k_step, infer=False, guidance_scale=self.guidance_scale)
            else:
                diff_loss = self.diff_model(hidden, gt_spec=gt_spec, k_step=k_step, infer=False)
            
            spk_loss = infonce_loss(timbre, spk_id.to(device), 0.1) if (spk_id is not None and 'infonce' in self.mode) else torch.tensor(0.).to(device)
            f0_loss = get_f0_loss(spk, *self.f0_predictor(timbre), f0.unsqueeze(-1)) if 'pred_f0' in self.mode else torch.tensor(0.).to(device)
            mi_loss = torch.tensor(0.).to(device)
            style_loss = torch.tensor(0.).to(device)

            return ddsp_loss, diff_loss, spk_loss, mi_loss, style_loss, f0_loss
        else:
            if gt_spec is not None:
                # Shallow diffusion: start the reverse process from the ground-truth mel.
                b, t, d = ddsp_mel.shape
                ddsp_mel = gt_spec[:, :t, :d]
            
            if k_step > 0:
                if self.guidance_scale is not None and self.guidance_scale >= 0:
                    mel = self.diff_model(hidden, timbre_f0, gt_spec=ddsp_mel, infer=True, infer_speedup=infer_speedup, method=method, k_step=k_step, use_tqdm=use_tqdm, guidance_scale=self.guidance_scale)
                else:
                    mel = self.diff_model(hidden, gt_spec=ddsp_mel, infer=True, infer_speedup=infer_speedup, method=method, k_step=k_step, use_tqdm=use_tqdm)
            else:
                mel = ddsp_mel
            
            return vocoder.infer(mel, f0) if return_wav else mel

# ================= Entry-point loading function =================

def load_hq_svc(mode='train', model_path=None, device='cuda', hop_size=512, args=None):
    generator = HQ_SVC(hop_size, args).to(device)
    if mode in ['infer', 'finetune']:
        if model_path is None:
            raise ValueError('model_path must be provided in infer/finetune mode')
        cp_dict = torch.load(model_path, map_location=device)
        generator.load_state_dict(cp_dict, strict=False)
        if mode == 'infer':
            generator.eval()
    else:
        generator.train()
    return generator
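
# ================= Usage sketch =================
# A minimal, hypothetical inference example. The `args` fields mirror those
# read in HQ_SVC.__init__; the values, checkpoint path, and input shapes are
# illustrative assumptions rather than the project's actual configuration.
#
#   from argparse import Namespace
#
#   args = Namespace(guidance_scale=1.0, drop_rate=0.2, use_tfm=False,
#                    mode='infonce', use_mi_loss=False, use_style_loss=False)
#   model = load_hq_svc(mode='infer', model_path='checkpoints/hq_svc.pt',
#                       device='cuda', hop_size=512, args=args)
#   # x: unit features, f0/volume: frame-level controls, spk: speaker embedding
#   mel = model(x, f0, volume, spk, infer=True, k_step=100, vocoder=vocoder)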