File size: 15,043 Bytes
762eecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
# -*- coding: utf-8 -*-
"""
RMVPE model - high-quality F0 extraction.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional


class BiGRU(nn.Module):
    """Bidirectional GRU wrapper that returns only the output sequence."""

    def __init__(self, input_features: int, hidden_features: int, num_layers: int):
        super().__init__()
        self.gru = nn.GRU(
            input_features,
            hidden_features,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        # nn.GRU returns (output, h_n); the final hidden state is discarded.
        output, _ = self.gru(x)
        return output


class ConvBlockRes(nn.Module):
    """Residual block: two 3x3 conv/BN/ReLU stages plus an identity or 1x1 shortcut."""

    def __init__(self, in_channels: int, out_channels: int, momentum: float = 0.01,
                 force_shortcut: bool = False):
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*stages)

        # A 1x1 projection shortcut is created whenever the channel count
        # changes, or when explicitly requested by the caller.
        self.has_shortcut = in_channels != out_channels or force_shortcut
        if self.has_shortcut:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        residual = self.shortcut(x) if self.has_shortcut else x
        return self.conv(x) + residual


class EncoderBlock(nn.Module):
    """Encoder stage: a stack of residual conv blocks followed by average pooling."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 n_blocks: int, momentum: float = 0.01):
        super().__init__()
        blocks = [ConvBlockRes(in_channels, out_channels, momentum)]
        blocks += [ConvBlockRes(out_channels, out_channels, momentum)
                   for _ in range(n_blocks - 1)]
        self.conv = nn.ModuleList(blocks)
        self.pool = nn.AvgPool2d(kernel_size)

    def forward(self, x):
        for block in self.conv:
            x = block(x)
        # Return both the pooled output and the pre-pooled tensor; the latter
        # is consumed by the decoder as a skip connection.
        return self.pool(x), x


class Encoder(nn.Module):
    """RMVPE encoder: input BatchNorm followed by a chain of EncoderBlocks."""

    def __init__(self, in_channels: int, in_size: int, n_encoders: int,
                 kernel_size: int, n_blocks: int, out_channels: int = 16,
                 momentum: float = 0.01):
        super().__init__()

        self.n_encoders = n_encoders
        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.layers = nn.ModuleList()
        self.latent_channels = []

        # Channels double at every stage: out_channels, 2*out_channels, ...
        prev = in_channels
        for i in range(n_encoders):
            ch = out_channels * (2 ** i)
            self.layers.append(
                EncoderBlock(prev, ch, kernel_size, n_blocks, momentum)
            )
            self.latent_channels.append(ch)
            prev = ch

    def forward(self, x):
        x = self.bn(x)
        skips = []
        for layer in self.layers:
            x, skip = layer(x)
            skips.append(skip)
        return x, skips


class Intermediate(nn.Module):
    """Bottleneck: a sequence of IntermediateBlocks between encoder and decoder."""

    def __init__(self, in_channels: int, out_channels: int, n_inters: int,
                 n_blocks: int, momentum: float = 0.01):
        super().__init__()

        self.layers = nn.ModuleList()
        for i in range(n_inters):
            # Only the first stage changes the channel count (e.g. 256 -> 512)
            # and therefore forces a projection shortcut in its leading block.
            src = in_channels if i == 0 else out_channels
            self.layers.append(
                IntermediateBlock(src, out_channels, n_blocks, momentum,
                                  first_block_shortcut=(i == 0))
            )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class IntermediateBlock(nn.Module):
    """A stack of residual conv blocks used inside the bottleneck."""

    def __init__(self, in_channels: int, out_channels: int, n_blocks: int,
                 momentum: float = 0.01, first_block_shortcut: bool = False):
        super().__init__()
        # The leading block may force a 1x1 shortcut so its parameter layout
        # matches the pretrained checkpoint.
        blocks = [ConvBlockRes(in_channels, out_channels, momentum,
                               force_shortcut=first_block_shortcut)]
        blocks += [ConvBlockRes(out_channels, out_channels, momentum)
                   for _ in range(n_blocks - 1)]
        self.conv = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.conv:
            x = block(x)
        return x


class DecoderBlock(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concat, residual conv stack."""

    def __init__(self, in_channels: int, out_channels: int, stride: int,
                 n_blocks: int, momentum: float = 0.01):
        super().__init__()
        # conv1 upsamples: transposed conv + BatchNorm
        # (kernel_size=3, stride=stride, padding=1, output_padding=1).
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 3, stride,
                               padding=1, output_padding=1, bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
        )
        # conv2 is a residual stack; its leading block takes out_channels * 2
        # because the upsampled tensor is concatenated with the skip tensor.
        blocks = [ConvBlockRes(out_channels * 2, out_channels, momentum)]
        blocks += [ConvBlockRes(out_channels, out_channels, momentum)
                   for _ in range(n_blocks - 1)]
        self.conv2 = nn.ModuleList(blocks)

    def forward(self, x, concat_tensor):
        x = self.conv1(x)
        # Align spatial sizes with the skip tensor: pad x on the bottom/right
        # (negative values crop) so the channel concat is valid.
        pad_h = concat_tensor.size(2) - x.size(2)
        pad_w = concat_tensor.size(3) - x.size(3)
        if pad_h or pad_w:
            x = F.pad(x, [0, pad_w, 0, pad_h])
        x = torch.cat([x, concat_tensor], dim=1)
        for block in self.conv2:
            x = block(x)
        return x


class Decoder(nn.Module):
    """RMVPE decoder: a chain of DecoderBlocks consuming encoder skips in reverse."""

    def __init__(self, in_channels: int, n_decoders: int, stride: int,
                 n_blocks: int, out_channels: int = 16, momentum: float = 0.01):
        super().__init__()

        self.layers = nn.ModuleList()
        for i in range(n_decoders):
            out_ch = out_channels * (2 ** (n_decoders - 1 - i))
            # First stage takes the bottleneck output; later stages take the
            # previous stage's output, which has twice the current out_ch.
            in_ch = in_channels if i == 0 else out_channels * (2 ** (n_decoders - i))
            self.layers.append(DecoderBlock(in_ch, out_ch, stride, n_blocks, momentum))

    def forward(self, x, concat_tensors):
        # Pair stages with skip tensors from deepest to shallowest.
        for layer, skip in zip(self.layers, reversed(concat_tensors)):
            x = layer(x, skip)
        return x


class DeepUnet(nn.Module):
    """Deep U-Net: encoder, bottleneck, and decoder with skip connections."""

    def __init__(self, kernel_size: int, n_blocks: int, en_de_layers: int = 5,
                 inter_layers: int = 4, in_channels: int = 1, en_out_channels: int = 16):
        super().__init__()

        # Encoder output channels: en_out_channels * 2^(en_de_layers-1)
        # (16 * 16 = 256 with the defaults); the bottleneck doubles that to 512.
        enc_channels = en_out_channels * (2 ** (en_de_layers - 1))
        inter_channels = enc_channels * 2

        self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size,
                               n_blocks, en_out_channels)
        self.intermediate = Intermediate(enc_channels, inter_channels,
                                         inter_layers, n_blocks)
        self.decoder = Decoder(inter_channels, en_de_layers, kernel_size,
                               n_blocks, en_out_channels)

    def forward(self, x):
        x, skips = self.encoder(x)
        return self.decoder(self.intermediate(x), skips)


class E2E(nn.Module):
    """End-to-end RMVPE model: U-Net feature extractor plus a (Bi)GRU/linear head."""

    def __init__(self, n_blocks: int, n_gru: int, kernel_size: int,
                 en_de_layers: int = 5, inter_layers: int = 4,
                 in_channels: int = 1, en_out_channels: int = 16):
        super().__init__()

        self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers,
                             in_channels, en_out_channels)
        self.cnn = nn.Conv2d(en_out_channels, 3, 3, 1, 1)

        # Classification head over 360 pitch bins; the BiGRU variant is used
        # when n_gru > 0 (its 256 hidden units are bidirectional -> 512).
        if n_gru:
            head = [
                BiGRU(3 * 128, 256, n_gru),
                nn.Linear(512, 360),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            ]
        else:
            head = [
                nn.Linear(3 * 128, 360),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            ]
        self.fc = nn.Sequential(*head)

    def forward(self, mel):
        # Accept mel as [B, 128, T] or [B, 1, 128, T]; the official layout is
        # [B, 1, T, 128] (time along height, mel bins along width).
        if mel.dim() == 3:
            # [B, 128, T] -> [B, T, 128] -> [B, 1, T, 128]
            mel = mel.transpose(-1, -2).unsqueeze(1)
        elif mel.dim() == 4 and mel.shape[1] == 1:
            # [B, 1, 128, T] -> [B, 1, T, 128]
            mel = mel.transpose(-1, -2)

        x = self.cnn(self.unet(mel))       # (B, 3, T, 128)
        x = x.transpose(1, 2).flatten(-2)  # (B, T, 3 * 128)
        return self.fc(x)                  # (B, T, 360)


class MelSpectrogram(nn.Module):
    """Log-mel spectrogram front end matching the official RVC RMVPE settings."""

    def __init__(self, n_mel: int = 128, n_fft: int = 1024, win_size: int = 1024,
                 hop_length: int = 160, sample_rate: int = 16000,
                 fmin: int = 30, fmax: int = 8000):
        super().__init__()

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_size = win_size
        self.sample_rate = sample_rate
        self.n_mel = n_mel

        # Filterbank and Hann window live in buffers so .to(device) moves them.
        self.register_buffer(
            "mel_basis",
            self._mel_filterbank(sample_rate, n_fft, n_mel, fmin, fmax),
        )
        self.register_buffer("window", torch.hann_window(win_size))

    def _mel_filterbank(self, sr, n_fft, n_mels, fmin, fmax):
        """Build the mel filterbank as a float tensor."""
        import librosa
        # htk=True is required to match the official RVC RMVPE checkpoint.
        mel = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels,
                                  fmin=fmin, fmax=fmax, htk=True)
        return torch.from_numpy(mel).float()

    def forward(self, audio):
        stft = torch.stft(
            audio,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_size,
            window=self.window,
            center=True,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        # Power spectrum (squared magnitude), as in the official RMVPE.
        power = torch.abs(stft) ** 2
        # Mel projection followed by a clamped log.
        mel = torch.matmul(self.mel_basis, power)
        return torch.log(torch.clamp(mel, min=1e-5))


class RMVPE:
    """RMVPE F0 extractor wrapper.

    Loads a pretrained E2E checkpoint and exposes ``infer_from_audio``, which
    converts 16 kHz audio into an F0 (Hz) contour, one value per mel frame.
    """

    def __init__(self, model_path: str, device: str = "cuda"):
        """
        Args:
            model_path: Path to the pretrained RMVPE checkpoint (a state_dict).
            device: Torch device string, e.g. "cuda" or "cpu".
        """
        self.device = device

        # Build the E2E model with the official RMVPE hyperparameters.
        self.model = E2E(n_blocks=4, n_gru=1, kernel_size=2)
        ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
        self.model.load_state_dict(ckpt)
        self.model = self.model.to(device).eval()

        # Mel spectrogram front end.
        self.mel_extractor = MelSpectrogram().to(device)

        # Bin -> cents mapping (20 cents per bin), padded by 4 on each side so
        # the 9-bin averaging window around any peak stays in bounds.
        cents_mapping = 20 * np.arange(360) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))

    @torch.no_grad()
    def infer_from_audio(self, audio: np.ndarray, thred: float = 0.03) -> np.ndarray:
        """
        Extract F0 from audio.

        Args:
            audio: 16 kHz mono audio samples.
            thred: Salience (confidence) threshold; frames whose peak salience
                is at or below it yield F0 = 0 (unvoiced).

        Returns:
            np.ndarray: F0 contour in Hz, one value per mel frame.
        """
        audio = torch.from_numpy(audio).float().to(self.device)
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)

        # Mel spectrogram: [B, 128, T]
        mel = self.mel_extractor(audio)

        # Remember the unpadded frame count so padding can be stripped later.
        n_frames = mel.shape[-1]

        # Pad the time axis to a multiple of 32 (5 pooling stages, each /2).
        n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
        if n_pad > 0:
            mel = F.pad(mel, (0, n_pad), mode='constant', value=0)

        # E2E.forward handles the layout transpose internally.
        hidden = self.model(mel)

        # Drop the padded frames and the batch dimension.
        hidden = hidden[:, :n_frames, :]
        hidden = hidden.squeeze(0).cpu().numpy()

        return self._decode(hidden, thred)

    def _decode(self, hidden: np.ndarray, thred: float) -> np.ndarray:
        """Decode salience frames [T, 360] to F0 in Hz (official RVC algorithm)."""
        cents = self._to_local_average_cents(hidden, thred)

        # cents -> Hz; cents == 0 maps to exactly 10 Hz, which marks an
        # unvoiced frame and is therefore zeroed out.
        f0 = 10 * (2 ** (cents / 1200))
        f0[f0 == 10] = 0

        return f0

    def _to_local_average_cents(self, salience: np.ndarray, thred: float) -> np.ndarray:
        """Vectorized version of the official RVC ``to_local_average_cents``.

        For each frame, finds the argmax bin and returns the salience-weighted
        average of the cents values over the 9 bins centered on it. The
        original per-frame Python loop is replaced by a single
        ``np.take_along_axis`` gather with identical numerics.
        """
        # Step 1: peak bin per frame.
        center = np.argmax(salience, axis=1)  # [T]

        # Step 2: zero-pad the bin axis so the 9-bin window never goes out of
        # bounds, and shift the peak indices accordingly.
        salience = np.pad(salience, ((0, 0), (4, 4)))  # [T, 368]
        center += 4

        # Step 3: gather the 9-bin windows for all frames at once.
        offsets = (center - 4)[:, None] + np.arange(9)[None, :]       # [T, 9]
        win_salience = np.take_along_axis(salience, offsets, axis=1)  # [T, 9]
        win_cents = self.cents_mapping[offsets]                       # [T, 9]

        # Step 4: salience-weighted average of cents.
        product_sum = np.sum(win_salience * win_cents, axis=1)
        weight_sum = np.sum(win_salience, axis=1) + 1e-9
        cents = product_sum / weight_sum

        # Step 5: threshold on the per-frame peak salience (zero padding does
        # not change the maximum).
        maxx = np.max(salience, axis=1)
        cents[maxx <= thred] = 0

        return cents