"""
RMVPE model - high-quality F0 (pitch) extraction.
"""
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from typing import Optional |
| |
|
| |
|
class BiGRU(nn.Module):
    """Bidirectional GRU wrapper that returns only the per-step outputs."""

    def __init__(self, input_features: int, hidden_features: int, num_layers: int):
        super().__init__()
        # batch_first: inputs are (batch, time, features); bidirectional
        # doubles the output feature dimension to 2 * hidden_features.
        self.gru = nn.GRU(
            input_features,
            hidden_features,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        # nn.GRU returns (output, h_n); discard the final hidden state.
        output, _ = self.gru(x)
        return output
| |
|
| |
|
class ConvBlockRes(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm + ReLU layers with a skip.

    When the channel count changes (or force_shortcut is set) the skip path
    goes through a 1x1 projection; otherwise the input is added back directly.
    """

    def __init__(self, in_channels: int, out_channels: int, momentum: float = 0.01,
                 force_shortcut: bool = False):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        # A projection is mandatory when shapes differ; force_shortcut lets a
        # caller request one even for matching channel counts.
        self.has_shortcut = (in_channels != out_channels) or force_shortcut
        if self.has_shortcut:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        residual = self.shortcut(x) if self.has_shortcut else x
        return self.conv(x) + residual
| |
|
| |
|
class EncoderBlock(nn.Module):
    """Encoder stage: a stack of residual conv blocks followed by avg-pooling.

    forward returns both the pooled tensor (input to the next stage) and the
    pre-pooling activation (kept as the skip connection for the decoder).
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 n_blocks: int, momentum: float = 0.01):
        super().__init__()
        blocks = [ConvBlockRes(in_channels, out_channels, momentum)]
        blocks.extend(
            ConvBlockRes(out_channels, out_channels, momentum)
            for _ in range(n_blocks - 1)
        )
        self.conv = nn.ModuleList(blocks)
        self.pool = nn.AvgPool2d(kernel_size)

    def forward(self, x):
        for block in self.conv:
            x = block(x)
        # (downsampled output, skip tensor)
        return self.pool(x), x
| |
|
| |
|
class Encoder(nn.Module):
    """RMVPE encoder: input batch-norm then a chain of EncoderBlocks.

    Channel width starts at out_channels and doubles at every stage.
    """

    def __init__(self, in_channels: int, in_size: int, n_encoders: int,
                 kernel_size: int, n_blocks: int, out_channels: int = 16,
                 momentum: float = 0.01):
        super().__init__()
        # in_size is accepted for interface compatibility; it is not used here.
        self.n_encoders = n_encoders
        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.layers = nn.ModuleList()
        self.latent_channels = []

        current_in = in_channels
        for i in range(n_encoders):
            width = out_channels * (2 ** i)
            self.layers.append(
                EncoderBlock(current_in, width, kernel_size, n_blocks, momentum)
            )
            self.latent_channels.append(width)
            current_in = width

    def forward(self, x):
        x = self.bn(x)
        # Collect per-stage skip tensors (shallow -> deep) for the decoder.
        concat_tensors = []
        for layer in self.layers:
            x, skip = layer(x)
            concat_tensors.append(skip)
        return x, concat_tensors
| |
|
| |
|
class Intermediate(nn.Module):
    """Bottleneck between encoder and decoder: n_inters IntermediateBlocks."""

    def __init__(self, in_channels: int, out_channels: int, n_inters: int,
                 n_blocks: int, momentum: float = 0.01):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(n_inters):
            # Only the first block changes channel count, so only it needs a
            # projection shortcut on its leading ConvBlockRes.
            src = in_channels if i == 0 else out_channels
            self.layers.append(
                IntermediateBlock(src, out_channels, n_blocks, momentum,
                                  first_block_shortcut=(i == 0))
            )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
| |
|
| |
|
class IntermediateBlock(nn.Module):
    """A stack of n_blocks residual conv blocks at a fixed channel width."""

    def __init__(self, in_channels: int, out_channels: int, n_blocks: int,
                 momentum: float = 0.01, first_block_shortcut: bool = False):
        super().__init__()
        # The leading block may carry a forced 1x1 shortcut (see Intermediate).
        layers = [ConvBlockRes(in_channels, out_channels, momentum,
                               force_shortcut=first_block_shortcut)]
        layers += [ConvBlockRes(out_channels, out_channels, momentum)
                   for _ in range(n_blocks - 1)]
        self.conv = nn.ModuleList(layers)

    def forward(self, x):
        for block in self.conv:
            x = block(x)
        return x
| |
|
| |
|
class DecoderBlock(nn.Module):
    """Decoder stage: transposed-conv upsample, skip concat, residual convs."""

    def __init__(self, in_channels: int, out_channels: int, stride: int,
                 n_blocks: int, momentum: float = 0.01):
        super().__init__()
        # Upsample by `stride`; output_padding=1 compensates for the size
        # lost to the encoder's pooling.
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 3, stride,
                               padding=1, output_padding=1, bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
        )
        # The first block consumes the skip concat (2x channels); the rest
        # keep the width constant.
        convs = [ConvBlockRes(out_channels * 2, out_channels, momentum)]
        convs += [ConvBlockRes(out_channels, out_channels, momentum)
                  for _ in range(n_blocks - 1)]
        self.conv2 = nn.ModuleList(convs)

    def forward(self, x, concat_tensor):
        x = self.conv1(x)
        # Pad (or crop, via negative pad) right/bottom so x matches the skip.
        diff_h = concat_tensor.size(2) - x.size(2)
        diff_w = concat_tensor.size(3) - x.size(3)
        if diff_h != 0 or diff_w != 0:
            x = F.pad(x, [0, diff_w, 0, diff_h])
        x = torch.cat([x, concat_tensor], dim=1)
        for block in self.conv2:
            x = block(x)
        return x
| |
|
| |
|
class Decoder(nn.Module):
    """RMVPE decoder: mirrors the encoder, halving channels at each stage."""

    def __init__(self, in_channels: int, n_decoders: int, stride: int,
                 n_blocks: int, out_channels: int = 16, momentum: float = 0.01):
        super().__init__()
        self.layers = nn.ModuleList()
        current_in = in_channels
        for i in range(n_decoders):
            out_ch = out_channels * (2 ** (n_decoders - 1 - i))
            self.layers.append(
                DecoderBlock(current_in, out_ch, stride, n_blocks, momentum)
            )
            # Next stage consumes this stage's output width.
            current_in = out_ch

    def forward(self, x, concat_tensors):
        # Skip tensors arrive shallow -> deep; consume them deepest first.
        for i, layer in enumerate(self.layers):
            x = layer(x, concat_tensors[-1 - i])
        return x
| |
|
| |
|
class DeepUnet(nn.Module):
    """Deep U-Net: encoder -> intermediate bottleneck -> decoder with skips."""

    def __init__(self, kernel_size: int, n_blocks: int, en_de_layers: int = 5,
                 inter_layers: int = 4, in_channels: int = 1, en_out_channels: int = 16):
        super().__init__()
        # Deepest encoder width; the bottleneck doubles it once more.
        encoder_out_channels = en_out_channels * (2 ** (en_de_layers - 1))
        intermediate_out_channels = encoder_out_channels * 2

        self.encoder = Encoder(
            in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
        )
        self.intermediate = Intermediate(
            encoder_out_channels,
            intermediate_out_channels,
            inter_layers, n_blocks
        )
        self.decoder = Decoder(
            intermediate_out_channels,
            en_de_layers, kernel_size, n_blocks, en_out_channels
        )

    def forward(self, x):
        encoded, skips = self.encoder(x)
        bottleneck = self.intermediate(encoded)
        return self.decoder(bottleneck, skips)
| |
|
| |
|
class E2E(nn.Module):
    """End-to-end RMVPE model: mel -> U-Net -> conv head -> (Bi)GRU/linear.

    Outputs a sigmoid salience map with 360 pitch bins per frame.
    """

    def __init__(self, n_blocks: int, n_gru: int, kernel_size: int,
                 en_de_layers: int = 5, inter_layers: int = 4,
                 in_channels: int = 1, en_out_channels: int = 16):
        super().__init__()
        self.unet = DeepUnet(
            kernel_size, n_blocks, en_de_layers, inter_layers,
            in_channels, en_out_channels
        )
        self.cnn = nn.Conv2d(en_out_channels, 3, 3, 1, 1)

        # Head input is 3 channels x 128 mel bins flattened per frame.
        if n_gru:
            head = [BiGRU(3 * 128, 256, n_gru), nn.Linear(512, 360),
                    nn.Dropout(0.25), nn.Sigmoid()]
        else:
            head = [nn.Linear(3 * 128, 360), nn.Dropout(0.25), nn.Sigmoid()]
        self.fc = nn.Sequential(*head)

    def forward(self, mel):
        # Normalize the input layout to (batch, 1, time, n_mel).
        if mel.dim() == 3:
            # (B, n_mel, T) -> (B, 1, T, n_mel)
            mel = mel.transpose(-1, -2).unsqueeze(1)
        elif mel.dim() == 4 and mel.shape[1] == 1:
            # assumes (B, 1, n_mel, T) -> (B, 1, T, n_mel); TODO confirm caller layout
            mel = mel.transpose(-1, -2)

        feat = self.cnn(self.unet(mel))
        # (B, 3, T, 128) -> (B, T, 3*128) for the recurrent/linear head.
        feat = feat.transpose(1, 2).flatten(-2)
        return self.fc(feat)
| |
|
| |
|
class MelSpectrogram(nn.Module):
    """Log-mel spectrogram front end (HTK-style mel filterbank, power STFT)."""

    def __init__(self, n_mel: int = 128, n_fft: int = 1024, win_size: int = 1024,
                 hop_length: int = 160, sample_rate: int = 16000,
                 fmin: int = 30, fmax: int = 8000):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_size = win_size
        self.sample_rate = sample_rate
        self.n_mel = n_mel

        # Registered as buffers so .to(device) moves them with the module.
        mel_basis = self._mel_filterbank(sample_rate, n_fft, n_mel, fmin, fmax)
        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("window", torch.hann_window(win_size))

    def _mel_filterbank(self, sr, n_fft, n_mels, fmin, fmax):
        """Build the HTK-style mel filterbank as a float32 tensor."""
        import librosa

        filters = librosa.filters.mel(
            sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=True
        )
        return torch.from_numpy(filters).float()

    def forward(self, audio):
        # Center-padded STFT (reflect mode) -> power spectrogram.
        spec = torch.stft(
            audio,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_size,
            window=self.window,
            center=True,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True
        )
        power = torch.abs(spec) ** 2

        # Project onto mel bands; clamp before log to avoid log(0).
        mel = torch.matmul(self.mel_basis, power)
        return torch.log(torch.clamp(mel, min=1e-5))
| |
|
| |
|
class RMVPE:
    """Wrapper around the E2E model for F0 extraction from 16 kHz audio."""

    def __init__(self, model_path: str, device: str = "cuda"):
        # Device string passed straight to torch (.to / tensor placement).
        self.device = device

        # Fixed hyperparameters matching the published RMVPE checkpoint.
        # NOTE(review): weights_only=False allows arbitrary pickled objects —
        # only load checkpoints from trusted sources.
        self.model = E2E(n_blocks=4, n_gru=1, kernel_size=2)
        ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
        self.model.load_state_dict(ckpt)
        self.model = self.model.to(device).eval()

        # Mel front end with default (16 kHz / 128-bin) settings.
        self.mel_extractor = MelSpectrogram().to(device)

        # Cents value per salience bin (20 cents apart), padded by 4 on each
        # side so a 9-bin window around any argmax stays in bounds.
        cents_mapping = 20 * np.arange(360) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))

    @torch.no_grad()
    def infer_from_audio(self, audio: np.ndarray, thred: float = 0.03) -> np.ndarray:
        """
        Extract F0 from audio.

        Args:
            audio: 16 kHz audio samples (1-D float array).
            thred: confidence threshold; frames whose peak salience is at or
                below this value get F0 = 0.

        Returns:
            np.ndarray: F0 sequence in Hz, one value per mel frame.
        """
        audio = torch.from_numpy(audio).float().to(self.device)
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)

        # (B, n_mel, T) log-mel spectrogram.
        mel = self.mel_extractor(audio)

        n_frames = mel.shape[-1]

        # Pad the time axis up to a multiple of 32 so the 5-level U-Net
        # (2^5 = 32 downsampling) sees evenly divisible input.
        n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
        if n_pad > 0:
            mel = F.pad(mel, (0, n_pad), mode='constant', value=0)

        # (B, T_padded, 360) salience map.
        hidden = self.model(mel)

        # Drop the padded frames and move to numpy for decoding.
        hidden = hidden[:, :n_frames, :]
        hidden = hidden.squeeze(0).cpu().numpy()

        f0 = self._decode(hidden, thred)

        return f0

    def _decode(self, hidden: np.ndarray, thred: float) -> np.ndarray:
        """Decode hidden salience into F0 (Hz) using the official RVC algorithm."""
        cents = self._to_local_average_cents(hidden, thred)

        # Cents -> Hz (reference 10 Hz at 0 cents). A frame with cents == 0
        # maps to exactly 10 Hz, which marks "unvoiced" and is zeroed out.
        f0 = 10 * (2 ** (cents / 1200))
        f0[f0 == 10] = 0

        return f0

    def _to_local_average_cents(self, salience: np.ndarray, thred: float) -> np.ndarray:
        """Official RVC to_local_average_cents: weighted mean around the peak.

        For each frame, take a 9-bin window centered on the argmax and compute
        the salience-weighted average of the corresponding cents values.
        """
        # Per-frame peak bin before padding.
        center = np.argmax(salience, axis=1)

        # Pad bins by 4 on each side (matches self.cents_mapping padding) so
        # the window never runs off the ends; shift centers accordingly.
        salience = np.pad(salience, ((0, 0), (4, 4)))
        center += 4

        todo_salience = []
        todo_cents_mapping = []
        starts = center - 4
        ends = center + 5

        # Gather the 9-bin window per frame (windows differ per frame, so a
        # simple vectorized slice is not possible here).
        for idx in range(salience.shape[0]):
            todo_salience.append(salience[idx, starts[idx]:ends[idx]])
            todo_cents_mapping.append(self.cents_mapping[starts[idx]:ends[idx]])

        todo_salience = np.array(todo_salience)
        todo_cents_mapping = np.array(todo_cents_mapping)

        # Salience-weighted average of cents; 1e-9 guards against a zero
        # weight sum (all-zero window).
        product_sum = np.sum(todo_salience * todo_cents_mapping, axis=1)
        weight_sum = np.sum(todo_salience, axis=1) + 1e-9
        cents = product_sum / weight_sum

        # Low-confidence frames (peak <= thred) are marked unvoiced.
        maxx = np.max(salience, axis=1)
        cents[maxx <= thred] = 0

        return cents
| |
|