mason369 commited on
Commit
762eecb
·
verified ·
1 Parent(s): b6f9c90

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. models/__init__.py +7 -0
  2. models/rmvpe.py +439 -0
  3. models/synthesizer.py +853 -0
models/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 模型定义模块
4
+ """
5
+ from .rmvpe import RMVPE
6
+
7
+ __all__ = ["RMVPE"]
models/rmvpe.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ RMVPE 模型 - 用于高质量 F0 提取
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ from typing import Optional
10
+
11
+
12
class BiGRU(nn.Module):
    """Bidirectional GRU wrapper that returns only the output sequence."""

    def __init__(self, input_features: int, hidden_features: int, num_layers: int):
        super().__init__()
        self.gru = nn.GRU(
            input_features,
            hidden_features,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        # nn.GRU returns (output, h_n); callers only need the output.
        output, _ = self.gru(x)
        return output
27
+
28
+
29
class ConvBlockRes(nn.Module):
    """Residual block of two 3x3 Conv2d + BatchNorm + ReLU stages."""

    def __init__(self, in_channels: int, out_channels: int, momentum: float = 0.01,
                 force_shortcut: bool = False):
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*stages)

        # A 1x1 shortcut conv is required whenever the channel count changes;
        # force_shortcut creates one even for matching channels (needed to
        # mirror a pretrained checkpoint's layout).
        self.has_shortcut = in_channels != out_channels or force_shortcut
        if self.has_shortcut:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        identity = self.shortcut(x) if self.has_shortcut else x
        return self.conv(x) + identity
56
+
57
+
58
class EncoderBlock(nn.Module):
    """n_blocks residual conv blocks followed by average pooling."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 n_blocks: int, momentum: float = 0.01):
        super().__init__()
        blocks = [ConvBlockRes(in_channels, out_channels, momentum)]
        blocks.extend(
            ConvBlockRes(out_channels, out_channels, momentum)
            for _ in range(n_blocks - 1)
        )
        self.conv = nn.ModuleList(blocks)
        self.pool = nn.AvgPool2d(kernel_size)

    def forward(self, x):
        for conv_block in self.conv:
            x = conv_block(x)
        # The pre-pool activations are returned as well, for the decoder's
        # skip connections.
        return self.pool(x), x
75
+
76
+
77
class Encoder(nn.Module):
    """RMVPE encoder: input BatchNorm followed by a stack of EncoderBlocks.

    Each stage doubles the channel count (out_channels * 2**i) and pools the
    spatial dims; pre-pool activations are collected for skip connections.
    """

    def __init__(self, in_channels: int, in_size: int, n_encoders: int,
                 kernel_size: int, n_blocks: int, out_channels: int = 16,
                 momentum: float = 0.01):
        super().__init__()
        self.n_encoders = n_encoders
        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.layers = nn.ModuleList()
        self.latent_channels = []

        channels_in = in_channels
        for stage in range(n_encoders):
            channels_out = out_channels * (2 ** stage)
            self.layers.append(
                EncoderBlock(channels_in, channels_out, kernel_size,
                             n_blocks, momentum)
            )
            self.latent_channels.append(channels_out)
            channels_in = channels_out

    def forward(self, x):
        x = self.bn(x)
        skips = []
        for encoder_block in self.layers:
            x, skip = encoder_block(x)
            skips.append(skip)
        return x, skips
109
+
110
+
111
class Intermediate(nn.Module):
    """Stack of IntermediateBlocks between the encoder and decoder."""

    def __init__(self, in_channels: int, out_channels: int, n_inters: int,
                 n_blocks: int, momentum: float = 0.01):
        super().__init__()
        blocks = []
        for i in range(n_inters):
            # Only the first block changes the channel count (e.g. 256 -> 512),
            # and it forces a shortcut conv in its first residual unit.
            source_channels = in_channels if i == 0 else out_channels
            blocks.append(
                IntermediateBlock(source_channels, out_channels, n_blocks,
                                  momentum, first_block_shortcut=(i == 0))
            )
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.layers:
            x = block(x)
        return x
135
+
136
+
137
class IntermediateBlock(nn.Module):
    """A run of ConvBlockRes units; the first may force a shortcut conv."""

    def __init__(self, in_channels: int, out_channels: int, n_blocks: int,
                 momentum: float = 0.01, first_block_shortcut: bool = False):
        super().__init__()
        units = [ConvBlockRes(in_channels, out_channels, momentum,
                              force_shortcut=first_block_shortcut)]
        units += [ConvBlockRes(out_channels, out_channels, momentum)
                  for _ in range(n_blocks - 1)]
        self.conv = nn.ModuleList(units)

    def forward(self, x):
        for unit in self.conv:
            x = unit(x)
        return x
153
+
154
+
155
class DecoderBlock(nn.Module):
    """Upsampling decoder block: transposed conv, skip concat, residual convs."""

    def __init__(self, in_channels: int, out_channels: int, stride: int,
                 n_blocks: int, momentum: float = 0.01):
        super().__init__()
        # Transposed conv upsamples spatially (by `stride`) and reduces channels.
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 3, stride,
                               padding=1, output_padding=1, bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
        )
        # After concatenating the skip tensor the channel count doubles, so
        # the first residual block maps 2*out_channels -> out_channels.
        blocks = [ConvBlockRes(out_channels * 2, out_channels, momentum)]
        blocks += [ConvBlockRes(out_channels, out_channels, momentum)
                   for _ in range(n_blocks - 1)]
        self.conv2 = nn.ModuleList(blocks)

    def forward(self, x, concat_tensor):
        x = self.conv1(x)
        # If the upsampled tensor came out smaller than the skip tensor
        # (odd input sizes), pad it so the concat dimensions match.
        pad_h = concat_tensor.size(2) - x.size(2)
        pad_w = concat_tensor.size(3) - x.size(3)
        if pad_h != 0 or pad_w != 0:
            x = F.pad(x, [0, pad_w, 0, pad_h])
        x = torch.cat([x, concat_tensor], dim=1)
        for block in self.conv2:
            x = block(x)
        return x
186
+
187
+
188
class Decoder(nn.Module):
    """RMVPE decoder: mirrors the encoder, consuming skips deepest-first."""

    def __init__(self, in_channels: int, n_decoders: int, stride: int,
                 n_blocks: int, out_channels: int = 16, momentum: float = 0.01):
        super().__init__()
        self.layers = nn.ModuleList()
        current_in = in_channels
        for i in range(n_decoders):
            # Channel count halves at each stage, ending at out_channels.
            current_out = out_channels * (2 ** (n_decoders - 1 - i))
            self.layers.append(
                DecoderBlock(current_in, current_out, stride, n_blocks, momentum)
            )
            current_in = current_out

    def forward(self, x, concat_tensors):
        # Skip tensors were collected shallow-to-deep; walk them backwards.
        for depth, block in enumerate(self.layers):
            x = block(x, concat_tensors[-1 - depth])
        return x
207
+
208
+
209
class DeepUnet(nn.Module):
    """Deep U-Net backbone used by the RMVPE pitch estimator."""

    def __init__(self, kernel_size: int, n_blocks: int, en_de_layers: int = 5,
                 inter_layers: int = 4, in_channels: int = 1, en_out_channels: int = 16):
        super().__init__()
        # With defaults the encoder tops out at 16 * 2**4 = 256 channels and
        # the intermediate stage doubles that to 512.
        enc_out = en_out_channels * (2 ** (en_de_layers - 1))
        inter_out = enc_out * 2

        self.encoder = Encoder(
            in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
        )
        self.intermediate = Intermediate(enc_out, inter_out, inter_layers, n_blocks)
        self.decoder = Decoder(
            inter_out, en_de_layers, kernel_size, n_blocks, en_out_channels
        )

    def forward(self, x):
        encoded, skips = self.encoder(x)
        bottleneck = self.intermediate(encoded)
        return self.decoder(bottleneck, skips)
239
+
240
+
241
class E2E(nn.Module):
    """End-to-end RMVPE model: U-Net + conv head + (Bi)GRU/linear classifier.

    Produces a 360-bin sigmoid salience map per frame.
    """

    def __init__(self, n_blocks: int, n_gru: int, kernel_size: int,
                 en_de_layers: int = 5, inter_layers: int = 4,
                 in_channels: int = 1, en_out_channels: int = 16):
        super().__init__()
        self.unet = DeepUnet(
            kernel_size, n_blocks, en_de_layers, inter_layers,
            in_channels, en_out_channels
        )
        self.cnn = nn.Conv2d(en_out_channels, 3, 3, 1, 1)

        if n_gru:
            head = [
                BiGRU(3 * 128, 256, n_gru),
                nn.Linear(512, 360),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            ]
        else:
            head = [
                nn.Linear(3 * 128, 360),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            ]
        self.fc = nn.Sequential(*head)

    def forward(self, mel):
        # Accept [B, 128, T] or [B, 1, 128, T]; the reference implementation
        # expects [B, 1, T, 128] (time on height, mel bins on width).
        if mel.dim() == 3:
            mel = mel.transpose(-1, -2).unsqueeze(1)
        elif mel.dim() == 4 and mel.shape[1] == 1:
            mel = mel.transpose(-1, -2)

        features = self.cnn(self.unet(mel))              # (B, 3, T, 128)
        features = features.transpose(1, 2).flatten(-2)  # (B, T, 384)
        return self.fc(features)
286
+
287
+
288
class MelSpectrogram(nn.Module):
    """Log-mel spectrogram front end matching the official RMVPE settings."""

    def __init__(self, n_mel: int = 128, n_fft: int = 1024, win_size: int = 1024,
                 hop_length: int = 160, sample_rate: int = 16000,
                 fmin: int = 30, fmax: int = 8000):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_size = win_size
        self.sample_rate = sample_rate
        self.n_mel = n_mel

        # Registered as buffers so they follow .to(device) without gradients.
        self.register_buffer(
            "mel_basis", self._mel_filterbank(sample_rate, n_fft, n_mel, fmin, fmax)
        )
        self.register_buffer("window", torch.hann_window(win_size))

    def _mel_filterbank(self, sr, n_fft, n_mels, fmin, fmax):
        """Build the mel filterbank (htk=True to match official RVC RMVPE)."""
        import librosa

        filterbank = librosa.filters.mel(
            sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=True
        )
        return torch.from_numpy(filterbank).float()

    def forward(self, audio):
        spec = torch.stft(
            audio,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_size,
            window=self.window,
            center=True,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        # Power spectrum (squared magnitude), as in the official RMVPE.
        power = torch.abs(spec) ** 2

        # Project onto mel bands; clamp before log to avoid -inf on silence.
        mel = torch.matmul(self.mel_basis, power)
        return torch.log(torch.clamp(mel, min=1e-5))
336
+
337
+
338
class RMVPE:
    """Wrapper around the RMVPE pitch (F0) estimator.

    Loads a pretrained E2E checkpoint, extracts log-mel features from 16 kHz
    audio, and decodes the model's 360-bin salience output into an F0 curve
    using the official RVC local-average-cents algorithm.
    """

    def __init__(self, model_path: str, device: str = "cuda"):
        # Args:
        #   model_path: path to the RMVPE checkpoint (a plain state_dict).
        #   device: torch device string for the model and feature extractor.
        self.device = device

        # Load the pretrained E2E model.
        # NOTE(review): weights_only=False unpickles arbitrary objects —
        # only load checkpoints from a trusted source.
        self.model = E2E(n_blocks=4, n_gru=1, kernel_size=2)
        ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
        self.model.load_state_dict(ckpt)
        self.model = self.model.to(device).eval()

        # Mel-spectrogram front end.
        self.mel_extractor = MelSpectrogram().to(device)

        # Cents value of each of the 360 salience bins (20-cent spacing),
        # padded by 4 on each side for the 9-bin averaging window used below.
        cents_mapping = 20 * np.arange(360) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))

    @torch.no_grad()
    def infer_from_audio(self, audio: np.ndarray, thred: float = 0.03) -> np.ndarray:
        """
        Extract F0 from audio.

        Args:
            audio: 16 kHz mono audio samples; a 1-D array gets a batch dim.
            thred: salience threshold below which a frame is unvoiced.

        Returns:
            np.ndarray: per-frame F0 in Hz (0 for unvoiced frames).
        """
        # To tensor on the target device.
        audio = torch.from_numpy(audio).float().to(self.device)
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)

        # Log-mel features: [B, 128, T]
        mel = self.mel_extractor(audio)

        # Remember the original frame count before padding.
        n_frames = mel.shape[-1]

        # Pad T up to a multiple of 32 (5 pooling stages, each halving T).
        n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
        if n_pad > 0:
            mel = F.pad(mel, (0, n_pad), mode='constant', value=0)

        # Model inference — E2E.forward handles the layout transpose.
        hidden = self.model(mel)

        # Drop the padded frames; keep only the original length.
        hidden = hidden[:, :n_frames, :]
        hidden = hidden.squeeze(0).cpu().numpy()

        # Decode salience into F0.
        f0 = self._decode(hidden, thred)

        return f0

    def _decode(self, hidden: np.ndarray, thred: float) -> np.ndarray:
        """Decode hidden salience [T, 360] into F0 (official RVC algorithm)."""
        # Salience -> cents via local weighted averaging around the peak.
        cents = self._to_local_average_cents(hidden, thred)

        # cents -> Hz; 10 Hz is the reference frequency at 0 cents.
        f0 = 10 * (2 ** (cents / 1200))
        f0[f0 == 10] = 0  # cents == 0 marks unvoiced; map back to 0 Hz

        return f0

    def _to_local_average_cents(self, salience: np.ndarray, thred: float) -> np.ndarray:
        """Official RVC `to_local_average_cents`: salience-weighted average of
        cents in a 9-bin window around each frame's peak bin."""
        # Step 1: per-frame peak bin.
        center = np.argmax(salience, axis=1)  # [T]

        # Step 2: pad bins so the 9-wide window never leaves the array.
        salience = np.pad(salience, ((0, 0), (4, 4)))  # [T, 368]
        center += 4  # shift indices into the padded array

        # Step 3: gather the 9-bin window around each peak.
        todo_salience = []
        todo_cents_mapping = []
        starts = center - 4
        ends = center + 5

        for idx in range(salience.shape[0]):
            todo_salience.append(salience[idx, starts[idx]:ends[idx]])
            todo_cents_mapping.append(self.cents_mapping[starts[idx]:ends[idx]])

        todo_salience = np.array(todo_salience)  # [T, 9]
        todo_cents_mapping = np.array(todo_cents_mapping)  # [T, 9]

        # Step 4: salience-weighted average of the window's cents values.
        product_sum = np.sum(todo_salience * todo_cents_mapping, axis=1)
        weight_sum = np.sum(todo_salience, axis=1) + 1e-9
        cents = product_sum / weight_sum

        # Step 5: frames whose peak salience is at or below the threshold
        # are marked unvoiced (cents = 0).
        maxx = np.max(salience, axis=1)
        cents[maxx <= thred] = 0

        return cents
models/synthesizer.py ADDED
@@ -0,0 +1,853 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ RVC v2 合成器模型定义
4
+ """
5
+ import math
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from typing import Optional, Tuple
10
+ import numpy as np
11
+
12
+
13
class LayerNorm(nn.Module):
    """LayerNorm over the channel axis of a channels-first [B, C, T] tensor."""

    def __init__(self, channels: int, eps: float = 1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        # F.layer_norm normalizes trailing dims, so move C last, normalize,
        # then move it back.
        y = x.transpose(1, -1)  # [B, T, C]
        y = F.layer_norm(y, (self.channels,), self.gamma, self.beta, self.eps)
        return y.transpose(1, -1)  # [B, C, T]
28
+
29
+
30
class MultiHeadAttention(nn.Module):
    """Multi-head attention with optional relative positional embeddings,
    proximal bias and block-local masking (VITS-style).

    Args:
        channels: input channel count (must be divisible by n_heads).
        out_channels: output channel count after the final 1x1 conv.
        n_heads: number of attention heads.
        p_dropout: dropout on the attention weights.
        window_size: if set, adds learned relative-position embeddings over
            a +/- window_size range (self-attention only).
        heads_share: share relative embeddings across heads.
        block_length: if set, restricts attention to a band of this width.
        proximal_bias: add a bias favoring nearby positions.
        proximal_init: initialize key projection identical to query projection.
    """

    def __init__(self, channels: int, out_channels: int, n_heads: int,
                 p_dropout: float = 0.0, window_size: Optional[int] = None,
                 heads_share: bool = True, block_length: Optional[int] = None,
                 proximal_bias: bool = False, proximal_init: bool = False):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None  # last attention map, stored by forward()

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            # Learned relative-position embeddings over [-window, +window].
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels ** -0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            # Start with K == Q so initial attention is proximity-driven.
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        # x: query source [B, C, T_t]; c: key/value source [B, C, T_s].
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # query, key, value: [B, C, T]; reshape to [B, heads, T, k_channels].
        b, d, t_s = key.size()
        t_t = query.size(2)

        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        # Scaled dot-product scores: [B, heads, T_t, T_s].
        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))

        if self.window_size is not None:
            assert t_s == t_t, "Relative attention only for self-attention"
            # Add relative-position logits (converted to absolute indexing).
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local

        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias only for self-attention"
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)

        if mask is not None:
            # -1e4 (not -inf) keeps softmax numerically safe in fp16.
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert t_s == t_t, "Block length only for self-attention"
                # Band mask: allow only positions within +/- block_length.
                block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
                scores = scores.masked_fill(block_mask == 0, -1e4)

        p_attn = F.softmax(scores, dim=-1)
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)

        if self.window_size is not None:
            # Add the relative-position contribution to the values as well.
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)

        # Back to channels-first: [B, C, T_t].
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        # x: [B, heads, T, 2T-1]; y: [heads_rel, 2T-1, k_channels]
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        # x: [B, heads, T, k_channels]; y: [heads_rel, 2T-1, k_channels]
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        # Pad or slice the learned [-window, +window] table so it covers
        # exactly the 2*length-1 relative offsets needed for this length.
        max_relative_position = 2 * self.window_size + 1
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                (0, 0, pad_length, pad_length, 0, 0)
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        # [B, heads, T, 2T-1] -> [B, heads, T, T] via pad/reshape shifting.
        batch, heads, length, _ = x.size()
        x = F.pad(x, (0, 1, 0, 0, 0, 0, 0, 0))
        x_flat = x.view(batch, heads, length * 2 * length)
        x_flat = F.pad(x_flat, (0, length - 1, 0, 0, 0, 0))
        x_final = x_flat.view(batch, heads, length + 1, 2 * length - 1)[:, :, :length, length - 1:]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        # [B, heads, T, T] -> [B, heads, T, 2T-1] (inverse of the above).
        batch, heads, length, _ = x.size()
        x = F.pad(x, (0, length - 1, 0, 0, 0, 0, 0, 0))
        x_flat = x.view(batch, heads, length ** 2 + length * (length - 1))
        x_flat = F.pad(x_flat, (length, 0, 0, 0, 0, 0))
        x_final = x_flat.view(batch, heads, length, 2 * length)[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        # Bias of -log(1 + |i - j|): favors attending to nearby positions.
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
170
+
171
+
172
class FFN(nn.Module):
    """Two-layer 1-D conv feed-forward network with same or causal padding."""

    def __init__(self, in_channels: int, out_channels: int, filter_channels: int,
                 kernel_size: int, p_dropout: float = 0.0, activation: str = None,
                 causal: bool = False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        # Causal padding keeps each output frame from seeing future frames.
        self.padding = self._causal_padding if causal else self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        h = self.conv_1(self.padding(x))
        if self.activation == "gelu":
            # Sigmoid-based GELU approximation (matches reference impl).
            h = h * torch.sigmoid(1.702 * h)
        else:
            h = torch.relu(h)
        h = self.drop(h)
        h = self.conv_2(self.padding(h))
        return h * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1:
            return x
        return F.pad(x, (self.kernel_size - 1, 0, 0, 0, 0, 0))

    def _same_padding(self, x):
        if self.kernel_size == 1:
            return x
        left = (self.kernel_size - 1) // 2
        right = self.kernel_size // 2
        return F.pad(x, (left, right, 0, 0, 0, 0))
219
+
220
+
221
class Encoder(nn.Module):
    """Transformer encoder: self-attention + FFN, each with residual + LayerNorm."""

    def __init__(self, hidden_channels: int, filter_channels: int, n_heads: int,
                 n_layers: int, kernel_size: int = 1, p_dropout: float = 0.0,
                 window_size: int = 10):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()

        for _ in range(n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels, hidden_channels, n_heads,
                    p_dropout=p_dropout, window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(hidden_channels, hidden_channels, filter_channels,
                    kernel_size, p_dropout=p_dropout)
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        # Pairwise attention mask: [B, 1, T, T].
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        layers = zip(self.attn_layers, self.norm_layers_1,
                     self.ffn_layers, self.norm_layers_2)
        for attn, norm1, ffn, norm2 in layers:
            residual = self.drop(attn(x, x, attn_mask))
            x = norm1(x + residual)

            residual = self.drop(ffn(x, x_mask))
            x = norm2(x + residual)
        x = x * x_mask
        return x
269
+
270
+
271
class TextEncoder(nn.Module):
    """Text encoder for RVC - encodes phone and pitch embeddings.

    Projects 768-dim HuBERT phone features (plus an optional coarse-pitch
    embedding) through a relative-attention Transformer and outputs the
    mean and log-variance of the latent distribution.
    """

    def __init__(self, out_channels: int, hidden_channels: int, filter_channels: int,
                 n_heads: int, n_layers: int, kernel_size: int, p_dropout: float,
                 f0: bool = True):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.f0 = f0

        # Phone embedding: Linear projection from 768-dim HuBERT features
        self.emb_phone = nn.Linear(768, hidden_channels)

        # Pitch embedding (only if f0 is enabled)
        if f0:
            self.emb_pitch = nn.Embedding(256, hidden_channels)

        # Transformer encoder
        self.encoder = Encoder(
            hidden_channels, filter_channels, n_heads, n_layers,
            kernel_size, p_dropout
        )

        # Output projection to mean and log-variance
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, phone, pitch, lengths):
        """
        Args:
            phone: [B, 768, T] phone features from HuBERT (channels first)
            pitch: [B, T] pitch indices (0-255), or None when f0 is disabled
            lengths: [B] sequence lengths

        Returns:
            m: [B, out_channels, T] mean
            logs: [B, out_channels, T] log-variance
            x_mask: [B, 1, T] mask
        """
        import logging
        log = logging.getLogger(__name__)
        # FIX: the debug f-strings below call .item()/.max()/.min() on
        # tensors, which forces device->host synchronization on every
        # forward pass even when DEBUG logging is disabled. Evaluate them
        # only when DEBUG is actually enabled.
        debug = log.isEnabledFor(logging.DEBUG)

        if debug:
            log.debug(f"[TextEncoder] 输入 phone: shape={phone.shape}")
            log.debug(f"[TextEncoder] 输入 pitch: shape={pitch.shape}, max={pitch.max().item()}, min={pitch.min().item()}")
            log.debug(f"[TextEncoder] 输入 lengths: {lengths}")

        # Transpose phone from [B, C, T] to [B, T, C] for linear layer
        phone = phone.transpose(1, 2)  # [B, T, 768]
        if debug:
            log.debug(f"[TextEncoder] 转置后 phone: shape={phone.shape}")

        # Create mask from per-sample lengths
        x_mask = torch.unsqueeze(
            self._sequence_mask(lengths, phone.size(1)), 1
        ).to(phone.dtype)
        if debug:
            log.debug(f"[TextEncoder] x_mask: shape={x_mask.shape}, sum={x_mask.sum().item()}")

        # Phone embedding
        x = self.emb_phone(phone)  # [B, T, hidden_channels]
        if debug:
            log.debug(f"[TextEncoder] emb_phone 输出: shape={x.shape}, max={x.abs().max().item():.4f}, mean={x.abs().mean().item():.4f}")

        # Add pitch embedding if enabled
        if self.f0 and pitch is not None:
            # Clamp pitch indices into the embedding table's valid range
            pitch_clamped = torch.clamp(pitch, 0, 255)
            pitch_emb = self.emb_pitch(pitch_clamped)
            if debug:
                log.debug(f"[TextEncoder] emb_pitch 输出: shape={pitch_emb.shape}, max={pitch_emb.abs().max().item():.4f}")
            x = x + pitch_emb

        # Transpose for conv layers: [B, hidden_channels, T]
        x = x.transpose(1, 2)
        if debug:
            log.debug(f"[TextEncoder] 转置后 x: shape={x.shape}")

        # Apply mask, then run the Transformer encoder
        x = x * x_mask
        x = self.encoder(x, x_mask)
        if debug:
            log.debug(f"[TextEncoder] Transformer 输出: shape={x.shape}, max={x.abs().max().item():.4f}, mean={x.abs().mean().item():.4f}")

        # Project to mean and log-variance
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        if debug:
            log.debug(f"[TextEncoder] 最终输出 m: shape={m.shape}, max={m.abs().max().item():.4f}")
            log.debug(f"[TextEncoder] 最终输出 logs: shape={logs.shape}, max={logs.max().item():.4f}, min={logs.min().item():.4f}")

        return m, logs, x_mask

    def _sequence_mask(self, length, max_length=None):
        """Boolean mask [B, max_length]; True where position < length."""
        if max_length is None:
            max_length = length.max()
        positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
        return positions.unsqueeze(0) < length.unsqueeze(1)
368
+
369
+
370
class ResidualCouplingBlock(nn.Module):
    """Normalizing-flow stack of n_flows (ResidualCouplingLayer, Flip) pairs.

    Args:
        channels: flow channel count (split in half by each coupling layer).
        hidden_channels: WN hidden size inside each coupling layer.
        kernel_size / dilation_rate / n_layers: WN configuration.
        n_flows: number of coupling+flip pairs.
        gin_channels: conditioning channel count (0 disables conditioning).
    """

    def __init__(self, channels: int, hidden_channels: int, kernel_size: int,
                 dilation_rate: int, n_layers: int, n_flows: int = 4,
                 gin_channels: int = 0):
        super().__init__()
        self.flows = nn.ModuleList()

        for _ in range(n_flows):
            self.flows.append(
                ResidualCouplingLayer(
                    channels, hidden_channels, kernel_size,
                    dilation_rate, n_layers, gin_channels=gin_channels
                )
            )
            self.flows.append(Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                out = flow(x, x_mask, g=g, reverse=reverse)
                # BUG FIX: Flip.forward may return a bare tensor while
                # ResidualCouplingLayer returns (x, logdet). The previous
                # `x, _ = flow(...)` unpacked a bare tensor along its batch
                # dimension, silently corrupting the flow output. Accept
                # both return shapes.
                x = out[0] if isinstance(out, tuple) else out
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x
396
+
397
+
398
class ResidualCouplingLayer(nn.Module):
    """Mean-only affine coupling layer over half the channels (VITS-style)."""

    def __init__(self, channels: int, hidden_channels: int, kernel_size: int,
                 dilation_rate: int, n_layers: int, mean_only: bool = True,
                 gin_channels: int = 0):
        super().__init__()
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels)
        self.post = nn.Conv1d(hidden_channels, self.half_channels, 1)
        # Zero-init the output conv so the flow starts as the identity.
        nn.init.zeros_(self.post.weight)
        nn.init.zeros_(self.post.bias)

    def forward(self, x, x_mask, g=None, reverse=False):
        first_half, second_half = torch.split(x, [self.half_channels] * 2, dim=1)
        hidden = self.pre(first_half) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        m = self.post(hidden) * x_mask

        if reverse:
            second_half = (second_half - m) * x_mask
            return torch.cat([first_half, second_half], dim=1)

        second_half = m + second_half * x_mask
        return torch.cat([first_half, second_half], dim=1), None
429
+
430
+
431
class Flip(nn.Module):
    """Reverses the channel dimension between coupling layers.

    Flipping lets successive coupling layers transform alternating channel
    halves.  A flip is its own inverse, so one op serves both directions.
    """

    def forward(self, x, *args, reverse=False, **kwargs):
        x = torch.flip(x, [1])
        if not reverse:
            # Fix: in the forward (training) direction callers unpack
            # ``x, _ = flow(...)``; returning a bare tensor there gets
            # tuple-unpacked along the batch dimension.  Return a
            # (tensor, logdet) pair instead — logdet of a flip is 0,
            # reported as None to match ResidualCouplingLayer.
            return x, None
        return x
437
+
438
+
439
class WN(nn.Module):
    """WaveNet-style conditioning network (weight-normalized convolutions).

    Gated dilated convolutions with residual and skip connections.  An
    optional global conditioning tensor ``g`` (e.g. a speaker embedding) is
    projected once by ``cond_layer`` and injected into every layer.
    """

    def __init__(self, hidden_channels: int, kernel_size: int,
                 dilation_rate: int, n_layers: int, gin_channels: int = 0,
                 p_dropout: float = 0):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_channels = hidden_channels
        self.gin_channels = gin_channels

        self.in_layers = nn.ModuleList()
        self.res_skip_layers = nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels > 0:
            # One 1x1 conv emits per-layer conditioning slices (2*H each).
            self.cond_layer = nn.utils.weight_norm(
                nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
            )

        for i in range(n_layers):
            dilation = dilation_rate ** i
            padding = (kernel_size * dilation - dilation) // 2
            self.in_layers.append(
                nn.utils.weight_norm(
                    nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size,
                              dilation=dilation, padding=padding)
                )
            )
            # First n-1 layers output 2*H (residual + skip); the last layer
            # outputs H (skip only).
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels
            self.res_skip_layers.append(
                nn.utils.weight_norm(
                    nn.Conv1d(hidden_channels, res_skip_channels, 1)
                )
            )

    def forward(self, x, x_mask, g=None):
        """Encode ``x`` [B, H, T] under mask ``x_mask`` [B, 1, T]."""
        output = torch.zeros_like(x)

        # Fix: only honor g when the conditioning path was actually built.
        # Previously the per-layer loop guarded only on ``g is not None``
        # and would add a raw, un-projected g when gin_channels == 0.
        if g is not None and self.gin_channels > 0:
            g = self.cond_layer(g)
        else:
            g = None

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
                x_in = x_in + g_l

            # Gated activation: tanh * sigmoid over the two channel halves.
            acts = torch.tanh(x_in[:, :self.hidden_channels]) * torch.sigmoid(x_in[:, self.hidden_channels:])
            acts = self.drop(acts)
            res_skip = self.res_skip_layers[i](acts)

            if i < self.n_layers - 1:
                # Residual feeds the next layer; skip accumulates into output.
                x = (x + res_skip[:, :self.hidden_channels]) * x_mask
                output = output + res_skip[:, self.hidden_channels:]
            else:
                # Last layer: entire result is a skip contribution (the
                # original also updated x here, but it was never read again).
                output = output + res_skip

        return output * x_mask
506
+
507
+
508
class PosteriorEncoder(nn.Module):
    """Posterior encoder: spectrogram -> latent Gaussian sample (z, m, logs).

    A 1x1 conv projects the input, a WaveNet stack encodes it, and a final
    1x1 conv emits mean and log-std, from which ``z`` is drawn via the
    reparameterization trick.
    """

    def __init__(self, in_channels: int, out_channels: int, hidden_channels: int,
                 kernel_size: int, dilation_rate: int, n_layers: int,
                 gin_channels: int = 0):
        super().__init__()
        self.out_channels = out_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels)
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """Encode ``x`` [B, C, T]; returns (z, m, logs, x_mask)."""
        mask = self._sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype)

        hidden = self.pre(x) * mask
        hidden = self.enc(hidden, mask, g=g)
        stats = self.proj(hidden) * mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        # Reparameterization: z = m + sigma * eps, masked to valid frames.
        z = (m + torch.randn_like(m) * torch.exp(logs)) * mask
        return z, m, logs, mask

    def _sequence_mask(self, length, max_length=None):
        """Boolean mask [B, T], True at valid positions."""
        limit = length.max() if max_length is None else max_length
        steps = torch.arange(limit, dtype=length.dtype, device=length.device)
        return steps[None, :] < length[:, None]
538
+
539
+
540
class Generator(nn.Module):
    """NSF-HiFi-GAN generator (weight-normalized).

    Decodes latent features into a waveform.  A SourceModuleHnNSF builds an
    f0-driven harmonic excitation which is downsampled by ``noise_convs`` and
    added at every upsampling stage of a HiFi-GAN-style stack of transposed
    convolutions and multi-kernel residual blocks.
    """

    def __init__(self, initial_channel: int, resblock_kernel_sizes: list,
                 resblock_dilation_sizes: list, upsample_rates: list,
                 upsample_initial_channel: int, upsample_kernel_sizes: list,
                 gin_channels: int = 0, sr: int = 40000, is_half: bool = False):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.sr = sr
        self.is_half = is_half  # NOTE(review): stored but never read in this class

        # Total upsampling factor (frames -> audio samples).
        self.upp = int(np.prod(upsample_rates))

        self.conv_pre = nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, 3)

        # NSF source module: harmonic excitation from f0.
        self.m_source = SourceModuleHnNSF(sample_rate=sr, harmonic_num=0)

        # Convs that bring the excitation down to each stage's resolution.
        self.noise_convs = nn.ModuleList()

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(
                nn.utils.weight_norm(
                    nn.ConvTranspose1d(
                        upsample_initial_channel // (2 ** i),
                        c_cur,
                        k, u, (k - u) // 2
                    )
                )
            )
            # Excitation conv: stride equals the product of the remaining
            # upsample rates, so the excitation lines up with this stage.
            if i + 1 < len(upsample_rates):
                stride_f0 = int(np.prod(upsample_rates[i + 1:]))
                self.noise_convs.append(
                    nn.Conv1d(1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2)
                )
            else:
                self.noise_convs.append(nn.Conv1d(1, c_cur, kernel_size=1))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d))

        self.conv_post = nn.Conv1d(ch, 1, 7, 1, 3, bias=False)

        if gin_channels > 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, f0, g=None):
        """Decode latents into audio.

        Args:
            x: latent features [B, initial_channel, T_frames].
            f0: per-frame fundamental frequency [B, T_frames].
            g: optional global conditioning [B, gin_channels, 1];
               requires gin_channels > 0 at construction (``self.cond``).

        Returns:
            Waveform in [-1, 1] (tanh output), shape [B, 1, T_frames * upp].
        """
        import logging
        log = logging.getLogger(__name__)

        log.debug(f"[Generator] 输入 x: shape={x.shape}, max={x.abs().max().item():.4f}, mean={x.abs().mean().item():.4f}")
        log.debug(f"[Generator] 输入 f0: shape={f0.shape}, max={f0.max().item():.1f}, min={f0.min().item():.1f}")
        if g is not None:
            log.debug(f"[Generator] 输入 g: shape={g.shape}, max={g.abs().max().item():.4f}")

        # Build the NSF excitation signal from f0.
        har_source, _, _ = self.m_source(f0, self.upp)
        har_source = har_source.transpose(1, 2)  # [B, 1, T*upp]
        log.debug(f"[Generator] NSF har_source: shape={har_source.shape}, max={har_source.abs().max().item():.4f}")

        x = self.conv_pre(x)
        log.debug(f"[Generator] conv_pre 输出: shape={x.shape}, max={x.abs().max().item():.4f}")

        if g is not None:
            x = x + self.cond(g)
            log.debug(f"[Generator] 加入条件后: max={x.abs().max().item():.4f}")

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, 0.1)
            x = self.ups[i](x)

            # Inject the stage-aligned harmonic excitation.
            x_source = self.noise_convs[i](har_source)
            x = x + x_source

            # Average the outputs of this stage's multi-kernel resblocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
            log.debug(f"[Generator] 上采样层 {i}: shape={x.shape}, max={x.abs().max().item():.4f}")

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        log.debug(f"[Generator] conv_post 输出: shape={x.shape}, max={x.abs().max().item():.4f}")
        x = torch.tanh(x)
        log.debug(f"[Generator] tanh 输出: shape={x.shape}, max={x.abs().max().item():.4f}")

        return x

    def remove_weight_norm(self):
        """Strip weight normalization (call before inference/export)."""
        for l in self.ups:
            nn.utils.remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
647
+
648
+
649
class ResBlock(nn.Module):
    """Weight-normalized residual block (HiFi-GAN style).

    Pairs of convolutions — one dilated, one plain — applied with leaky-ReLU
    activations and a residual connection per pair.
    """

    def __init__(self, channels: int, kernel_size: int = 3, dilation: tuple = (1, 3, 5)):
        super().__init__()

        def _wn_conv(d):
            # "same" padding for a dilated conv of this kernel size.
            pad = (kernel_size * d - d) // 2
            return nn.utils.weight_norm(
                nn.Conv1d(channels, channels, kernel_size, 1, pad, dilation=d)
            )

        # Dilated convs (one per dilation factor) ...
        self.convs1 = nn.ModuleList([_wn_conv(d) for d in dilation])
        # ... each followed by an undilated conv.
        self.convs2 = nn.ModuleList([_wn_conv(1) for _ in dilation])

    def forward(self, x):
        for dilated, plain in zip(self.convs1, self.convs2):
            residual = x
            x = F.leaky_relu(x, 0.1)
            x = dilated(x)
            x = F.leaky_relu(x, 0.1)
            x = plain(x)
            x = x + residual
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from every conv in the block."""
        for conv in (*self.convs1, *self.convs2):
            nn.utils.remove_weight_norm(conv)
683
+
684
+
685
class SineGenerator(nn.Module):
    """Sine excitation generator — the core NSF building block.

    Emits a sine wave that follows the upsampled f0 contour in voiced
    regions, and low-level Gaussian noise where f0 is (near) zero.
    """

    def __init__(self, sample_rate: int, harmonic_num: int = 0,
                 sine_amp: float = 0.1, noise_std: float = 0.003,
                 voiced_threshold: float = 10):
        super().__init__()
        self.sample_rate = sample_rate
        self.harmonic_num = harmonic_num
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.voiced_threshold = voiced_threshold
        self.dim = harmonic_num + 1

    def forward(self, f0: torch.Tensor, upp: int):
        """Generate the sine excitation.

        Args:
            f0: fundamental-frequency tensor, shape [B, T].
            upp: temporal upsampling factor (frames -> samples).

        Returns:
            Excitation signal of shape [B, T*upp, 1].
        """
        with torch.no_grad():
            # Nearest-neighbour upsample of the f0 track to sample level.
            f0_frames = f0.unsqueeze(1)  # [B, 1, T]
            f0_samples = F.interpolate(f0_frames, scale_factor=upp, mode='nearest')
            f0_samples = f0_samples.transpose(1, 2)  # [B, T*upp, 1]

            # Integrate normalized frequency into a wrapped phase, then sine.
            normalized = f0_samples / self.sample_rate
            phase = torch.cumsum(normalized, dim=1) % 1
            sine = torch.sin(2 * np.pi * phase) * self.sine_amp

            # Unvoiced (f0 below threshold) regions get Gaussian noise.
            voiced = (f0_samples > self.voiced_threshold).float()
            noise = torch.randn_like(sine) * self.noise_std
            return sine * voiced + noise * (1 - voiced)
727
+
728
+
729
class SourceModuleHnNSF(nn.Module):
    """Harmonic-plus-noise source module.

    Merges the sine harmonics into a single excitation channel via a linear
    layer + tanh, and produces a parallel Gaussian-noise branch.
    """

    def __init__(self, sample_rate: int, harmonic_num: int = 0,
                 sine_amp: float = 0.1, noise_std: float = 0.003,
                 add_noise_std: float = 0.003):
        super().__init__()
        # Fix: add_noise_std was accepted but ignored — the noise branch
        # hard-coded 0.003.  Storing and using it keeps the default behavior
        # identical while making the parameter effective.
        self.add_noise_std = add_noise_std
        self.sine_generator = SineGenerator(
            sample_rate, harmonic_num, sine_amp, noise_std
        )
        self.l_linear = nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = nn.Tanh()

    def forward(self, f0: torch.Tensor, upp: int):
        """Return (harmonic_source, noise_source, None).

        The trailing None keeps the 3-tuple interface the Generator expects.
        """
        sine = self.sine_generator(f0, upp)  # [B, T*upp, 1]
        sine = self.l_tanh(self.l_linear(sine))
        noise = torch.randn_like(sine) * self.add_noise_std
        return sine, noise, None
747
+
748
+
749
class SynthesizerTrnMs768NSFsid(nn.Module):
    """RVC v2 synthesizer (768-dim HuBERT features + NSF vocoder + speaker ID).

    Pipeline: ``enc_p`` (TextEncoder) turns HuBERT features and coarse pitch
    into a prior Gaussian (mean + log-variance); ``flow`` (residual coupling
    block) maps between prior and decoder latents; ``dec`` (NSF-HiFiGAN
    Generator) renders the waveform from the latent and the continuous f0
    track; ``emb_g`` provides the per-speaker conditioning vector.
    """

    def __init__(self, spec_channels: int, segment_size: int,
                 inter_channels: int, hidden_channels: int, filter_channels: int,
                 n_heads: int, n_layers: int, kernel_size: int, p_dropout: float,
                 resblock: str, resblock_kernel_sizes: list,
                 resblock_dilation_sizes: list, upsample_rates: list,
                 upsample_initial_channel: int, upsample_kernel_sizes: list,
                 spk_embed_dim: int, gin_channels: int, sr: int):
        super().__init__()

        # Hyperparameters are kept as attributes (checkpoint/config echo).
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        self.sr = sr

        # Text/feature encoder (TextEncoder used instead of PosteriorEncoder).
        self.enc_p = TextEncoder(
            inter_channels, hidden_channels, filter_channels,
            n_heads, n_layers, kernel_size, p_dropout, f0=True
        )

        # Decoder/generator (NSF-HiFiGAN; contains its own m_source).
        self.dec = Generator(
            inter_channels, resblock_kernel_sizes, resblock_dilation_sizes,
            upsample_rates, upsample_initial_channel, upsample_kernel_sizes,
            gin_channels, sr=sr
        )

        # Normalizing flow between prior and decoder latents.
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )

        # Speaker embedding table.
        self.emb_g = nn.Embedding(spk_embed_dim, gin_channels)

    def forward(self, phone, phone_lengths, pitch, nsff0, sid, skip_head=0, return_length=0):
        """Forward pass: features + pitch + speaker -> waveform.

        NOTE(review): ``skip_head`` and ``return_length`` are accepted but
        unused here — presumably kept for interface compatibility; confirm
        against callers.  Unlike a typical VITS training forward, no segment
        slicing is performed (``segment_size`` is never applied).
        """
        g = self.emb_g(sid).unsqueeze(-1)

        # TextEncoder returns mean and log-variance of the prior.
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)

        # Sample outside the encoder (noise scale 0.66666).
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask

        # Forward flow.
        z = self.flow(z_p, x_mask, g=g)

        # Render audio (f0 drives the NSF excitation inside the Generator).
        o = self.dec(z, nsff0, g=g)

        return o

    def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=1.0):
        """Inference: features + pitch + speaker -> (waveform, mask).

        NOTE(review): ``rate`` is accepted but unused — upstream RVC uses it
        to trim the head of the latent before decoding; confirm whether any
        caller relies on that behavior.
        """
        import logging
        log = logging.getLogger(__name__)

        log.debug(f"[infer] 输入 phone: shape={phone.shape}, dtype={phone.dtype}")
        log.debug(f"[infer] 输入 phone 统计: max={phone.abs().max().item():.4f}, mean={phone.abs().mean().item():.4f}")
        log.debug(f"[infer] 输入 phone_lengths: {phone_lengths}")
        log.debug(f"[infer] 输入 pitch: shape={pitch.shape}, max={pitch.max().item()}, min={pitch.min().item()}")
        log.debug(f"[infer] 输入 nsff0: shape={nsff0.shape}, max={nsff0.max().item():.1f}, min={nsff0.min().item():.1f}")
        log.debug(f"[infer] 输入 sid: {sid}")

        g = self.emb_g(sid).unsqueeze(-1)
        log.debug(f"[infer] 说话人嵌入 g: shape={g.shape}, max={g.abs().max().item():.4f}")

        # TextEncoder returns mean and log-variance of the prior.
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        log.debug(f"[infer] TextEncoder 输出:")
        log.debug(f"[infer] m_p: shape={m_p.shape}, max={m_p.abs().max().item():.4f}, mean={m_p.abs().mean().item():.4f}")
        log.debug(f"[infer] logs_p: shape={logs_p.shape}, max={logs_p.max().item():.4f}, min={logs_p.min().item():.4f}")
        log.debug(f"[infer] x_mask: shape={x_mask.shape}, sum={x_mask.sum().item()}")

        # Sample outside the encoder (reduced noise scale for stability).
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
        log.debug(f"[infer] 采样后 z_p: shape={z_p.shape}, max={z_p.abs().max().item():.4f}, mean={z_p.abs().mean().item():.4f}")

        # Inverse flow.
        z = self.flow(z_p, x_mask, g=g, reverse=True)
        log.debug(f"[infer] Flow 输出 z: shape={z.shape}, max={z.abs().max().item():.4f}, mean={z.abs().mean().item():.4f}")

        # Render audio (Generator builds the NSF excitation from f0 internally).
        o = self.dec(z * x_mask, nsff0, g=g)
        log.debug(f"[infer] Generator 输出 o: shape={o.shape}, max={o.abs().max().item():.4f}, mean={o.abs().mean().item():.4f}")

        return o, x_mask