AI-RVC / models /rmvpe.py
mason369's picture
Upload folder using huggingface_hub
762eecb verified
# -*- coding: utf-8 -*-
"""
RMVPE 模型 - 用于高质量 F0 提取
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional
class BiGRU(nn.Module):
    """Bidirectional, batch-first GRU that returns only the per-step outputs.

    Because the GRU is bidirectional, the last dimension of the output is
    ``2 * hidden_features``. The final hidden state is discarded.
    """

    def __init__(self, input_features: int, hidden_features: int, num_layers: int):
        super().__init__()
        # The attribute name ``gru`` is part of the checkpoint key layout.
        self.gru = nn.GRU(
            input_size=input_features,
            hidden_size=hidden_features,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        # nn.GRU returns (output, h_n); only the output sequence is used.
        outputs, _final_state = self.gru(x)
        return outputs
class ConvBlockRes(nn.Module):
    """Residual block: two (3x3 conv -> BatchNorm -> ReLU) stages plus a skip path.

    The skip path is the identity when ``in_channels == out_channels`` and
    ``force_shortcut`` is false. Otherwise it is a 1x1 convolution that
    projects the input to ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int, momentum: float = 0.01,
                 force_shortcut: bool = False):
        super().__init__()
        # Layer order (indices 0..5 of ``conv``) must match checkpoint keys.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        needs_projection = bool(force_shortcut or in_channels != out_channels)
        self.has_shortcut = needs_projection
        if needs_projection:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        residual = self.shortcut(x) if self.has_shortcut else x
        return self.conv(x) + residual
class EncoderBlock(nn.Module):
    """Encoder stage: ``n_blocks`` residual blocks followed by average pooling.

    ``forward`` returns ``(pooled, pre_pool)``. The un-pooled tensor is
    returned as well so the decoder can use it as a skip connection.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 n_blocks: int, momentum: float = 0.01):
        super().__init__()
        # Only the first block changes the channel count; at least one block is always built.
        input_widths = [in_channels] + [out_channels] * (n_blocks - 1)
        self.conv = nn.ModuleList(
            ConvBlockRes(width, out_channels, momentum) for width in input_widths
        )
        self.pool = nn.AvgPool2d(kernel_size)

    def forward(self, x):
        for residual_block in self.conv:
            x = residual_block(x)
        pre_pool = x
        return self.pool(pre_pool), pre_pool
class Encoder(nn.Module):
    """RMVPE encoder: an input BatchNorm followed by ``n_encoders`` EncoderBlocks.

    Stage ``i`` produces ``out_channels * 2**i`` channels. ``forward``
    returns the final pooled tensor and the list of pre-pool skip tensors,
    ordered from the first (shallowest) stage to the last. ``in_size`` is
    accepted for signature compatibility but is not used here.
    """

    def __init__(self, in_channels: int, in_size: int, n_encoders: int,
                 kernel_size: int, n_blocks: int, out_channels: int = 16,
                 momentum: float = 0.01):
        super().__init__()
        self.n_encoders = n_encoders
        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.layers = nn.ModuleList()
        self.latent_channels = []
        previous_width = in_channels
        for stage in range(n_encoders):
            width = out_channels * (2 ** stage)
            self.layers.append(
                EncoderBlock(previous_width, width, kernel_size, n_blocks, momentum)
            )
            self.latent_channels.append(width)
            previous_width = width

    def forward(self, x):
        x = self.bn(x)
        skips = []
        for stage in self.layers:
            x, pre_pool = stage(x)
            skips.append(pre_pool)
        return x, skips
class Intermediate(nn.Module):
    """Bottleneck between encoder and decoder: ``n_inters`` IntermediateBlocks.

    The first block maps ``in_channels -> out_channels`` (256 -> 512 in the
    default DeepUnet configuration), with its first residual block forced
    to use a 1x1 projection shortcut. Every later block keeps
    ``out_channels`` and has no forced shortcut.
    """

    def __init__(self, in_channels: int, out_channels: int, n_inters: int,
                 n_blocks: int, momentum: float = 0.01):
        super().__init__()
        self.layers = nn.ModuleList(
            IntermediateBlock(
                in_channels if index == 0 else out_channels,
                out_channels,
                n_blocks,
                momentum,
                first_block_shortcut=(index == 0),
            )
            for index in range(n_inters)
        )

    def forward(self, x):
        for stage in self.layers:
            x = stage(x)
        return x
class IntermediateBlock(nn.Module):
    """A stack of ``n_blocks`` residual blocks with no pooling.

    When ``first_block_shortcut`` is true, the first residual block always
    gets a 1x1 projection shortcut, even if its input and output widths match.
    """

    def __init__(self, in_channels: int, out_channels: int, n_blocks: int,
                 momentum: float = 0.01, first_block_shortcut: bool = False):
        super().__init__()
        head = ConvBlockRes(in_channels, out_channels, momentum,
                            force_shortcut=first_block_shortcut)
        tail = [ConvBlockRes(out_channels, out_channels, momentum)
                for _ in range(n_blocks - 1)]
        self.conv = nn.ModuleList([head, *tail])

    def forward(self, x):
        for residual_block in self.conv:
            x = residual_block(x)
        return x
class DecoderBlock(nn.Module):
    """Decoder stage: upsample, concatenate the encoder skip tensor, refine.

    ``conv1`` upsamples ``x`` with a transposed convolution followed by
    BatchNorm and ReLU. The result is concatenated with the matching encoder
    skip tensor along the channel axis, giving ``2 * out_channels``
    channels. ``conv2`` (``n_blocks`` residual blocks) then brings the
    channel count back to ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int,
                 n_blocks: int, momentum: float = 0.01):
        super().__init__()
        # Transposed conv (kernel 3, padding 1, output_padding 1) -> BatchNorm -> ReLU.
        # The reference RVC RMVPE decoder block ends this stage with a ReLU;
        # without it the pretrained weights see un-rectified activations.
        # ReLU has no parameters, so checkpoint keys (conv1.0.*, conv1.1.*)
        # are unchanged.
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 3, stride, padding=1,
                               output_padding=1, bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        # The first block consumes the concatenation (2 * out_channels);
        # the remaining blocks keep out_channels.
        self.conv2 = nn.ModuleList()
        self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
        for _ in range(n_blocks - 1):
            self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))

    def forward(self, x, concat_tensor):
        x = self.conv1(x)
        # Align x's height/width with the skip tensor by padding at the
        # bottom/right edges. For negative differences, F.pad trims instead.
        diff_h = concat_tensor.size(2) - x.size(2)
        diff_w = concat_tensor.size(3) - x.size(3)
        if diff_h != 0 or diff_w != 0:
            x = F.pad(x, [0, diff_w, 0, diff_h])
        x = torch.cat([x, concat_tensor], dim=1)
        for block in self.conv2:
            x = block(x)
        return x
class Decoder(nn.Module):
    """RMVPE decoder: ``n_decoders`` DecoderBlocks that halve the width each stage.

    Stage ``i`` outputs ``out_channels * 2**(n_decoders - 1 - i)`` channels.
    The first stage reads ``in_channels``; each later stage reads the
    previous stage's output width.
    """

    def __init__(self, in_channels: int, n_decoders: int, stride: int,
                 n_blocks: int, out_channels: int = 16, momentum: float = 0.01):
        super().__init__()
        self.layers = nn.ModuleList()
        incoming = in_channels
        for level in reversed(range(n_decoders)):
            width = out_channels * (2 ** level)
            self.layers.append(DecoderBlock(incoming, width, stride, n_blocks, momentum))
            incoming = width

    def forward(self, x, concat_tensors):
        # The first decoder stage pairs with the last (deepest) encoder skip
        # tensor; ~depth == -1 - depth.
        for depth, stage in enumerate(self.layers):
            x = stage(x, concat_tensors[~depth])
        return x
class DeepUnet(nn.Module):
    """U-Net backbone: Encoder -> Intermediate -> Decoder with skip connections.

    With the defaults (``en_out_channels=16``, ``en_de_layers=5``), the
    deepest encoder stage has 16 * 2**4 = 256 channels and the bottleneck
    widens that to 512. The decoder then mirrors the encoder back down.
    """

    def __init__(self, kernel_size: int, n_blocks: int, en_de_layers: int = 5,
                 inter_layers: int = 4, in_channels: int = 1, en_out_channels: int = 16):
        super().__init__()
        deepest_width = en_out_channels * 2 ** (en_de_layers - 1)
        bottleneck_width = 2 * deepest_width
        self.encoder = Encoder(
            in_channels=in_channels,
            in_size=128,
            n_encoders=en_de_layers,
            kernel_size=kernel_size,
            n_blocks=n_blocks,
            out_channels=en_out_channels,
        )
        self.intermediate = Intermediate(
            in_channels=deepest_width,
            out_channels=bottleneck_width,
            n_inters=inter_layers,
            n_blocks=n_blocks,
        )
        # The encoder's pooling kernel size doubles as the decoder's upsampling stride.
        self.decoder = Decoder(
            in_channels=bottleneck_width,
            n_decoders=en_de_layers,
            stride=kernel_size,
            n_blocks=n_blocks,
            out_channels=en_out_channels,
        )

    def forward(self, x):
        encoded, skips = self.encoder(x)
        return self.decoder(self.intermediate(encoded), skips)
class E2E(nn.Module):
    """End-to-end RMVPE network: mel spectrogram -> per-frame pitch salience.

    ``cnn`` projects the U-Net output to 3 channels. These are flattened to
    ``3 * 128 = 384`` features per frame. ``fc`` (an optional BiGRU, then
    Linear -> Dropout -> Sigmoid) maps them to 360 salience values per frame.
    """

    def __init__(self, n_blocks: int, n_gru: int, kernel_size: int,
                 en_de_layers: int = 5, inter_layers: int = 4,
                 in_channels: int = 1, en_out_channels: int = 16):
        super().__init__()
        self.unet = DeepUnet(
            kernel_size, n_blocks, en_de_layers, inter_layers,
            in_channels, en_out_channels
        )
        self.cnn = nn.Conv2d(en_out_channels, 3, 3, 1, 1)
        # Sequential indices must match checkpoint keys: with a GRU the
        # Linear is fc.1, without one it is fc.0.
        if n_gru:
            projection = [BiGRU(3 * 128, 256, n_gru), nn.Linear(512, 360)]
        else:
            projection = [nn.Linear(3 * 128, 360)]
        self.fc = nn.Sequential(*projection, nn.Dropout(0.25), nn.Sigmoid())

    def forward(self, mel):
        """Map a mel spectrogram to salience of shape ``[B, T, 360]``.

        Accepts ``[B, 128, T]`` or ``[B, 1, 128, T]``. Both are rearranged to
        ``[B, 1, T, 128]`` (time on the height axis, mel bins on the width
        axis). Any other input is passed to the U-Net unchanged.
        """
        if mel.dim() == 3:
            mel = mel.unsqueeze(1)                        # [B, 1, 128, T]
        if mel.dim() == 4 and mel.shape[1] == 1:
            mel = mel.transpose(-1, -2)                   # [B, 1, T, 128]
        features = self.cnn(self.unet(mel))               # [B, 3, T, 128]
        features = features.transpose(1, 2).flatten(-2)   # [B, T, 384]
        return self.fc(features)
class MelSpectrogram(nn.Module):
    """Log-mel spectrogram front end for RMVPE.

    Defaults: 128 mel bands, 1024-point FFT and Hann window, hop of 160
    samples, 16 kHz audio, 30-8000 Hz filter range.
    """

    def __init__(self, n_mel: int = 128, n_fft: int = 1024, win_size: int = 1024,
                 hop_length: int = 160, sample_rate: int = 16000,
                 fmin: int = 30, fmax: int = 8000):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_size = win_size
        self.sample_rate = sample_rate
        self.n_mel = n_mel
        # The mel filterbank and analysis window are registered as buffers
        # so that .to(device) moves them along with the module.
        mel_basis = self._mel_filterbank(sample_rate, n_fft, n_mel, fmin, fmax)
        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("window", torch.hann_window(win_size))

    def _mel_filterbank(self, sr, n_fft, n_mels, fmin, fmax):
        """Build the mel filterbank with librosa and return it as a float32 tensor."""
        import librosa
        # htk=True (HTK mel scale) is required to match the official RVC RMVPE features.
        mel = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=True)
        return torch.from_numpy(mel).float()

    def forward(self, audio):
        """Return ``log(clamp(mel_basis @ |STFT(audio)|, min=1e-5))``.

        ``audio`` is a waveform tensor (``[B, N]`` or ``[N]``). The result
        has shape ``[B, n_mel, frames]`` (or ``[n_mel, frames]``).
        """
        spec = torch.stft(
            audio,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_size,
            window=self.window,
            center=True,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True
        )
        # RMVPE was trained on log-mel of the STFT *magnitude*: the reference
        # RVC implementation uses sqrt(re^2 + im^2), not the power spectrum.
        # Squaring here changes the scale of the features the network sees.
        magnitude = spec.abs()
        mel = torch.matmul(self.mel_basis, magnitude)
        return torch.log(torch.clamp(mel, min=1e-5))
class RMVPE:
    """RMVPE F0 extractor: loads an E2E checkpoint and converts 16 kHz audio to F0 in Hz.

    Pipeline: waveform -> log-mel (MelSpectrogram) -> E2E salience
    ``[T, 360]`` -> local weighted average of cents around each frame's peak
    bin -> Hz. Frames whose peak salience is ``<= thred`` get F0 0.
    """

    def __init__(self, model_path: str, device: str = "cuda"):
        self.device = device
        self.model = E2E(n_blocks=4, n_gru=1, kernel_size=2)
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects, i.e. run code embedded in the file. Only load trusted
        # checkpoints. If the checkpoint is a plain state_dict,
        # weights_only=True should also work; verify before switching.
        ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
        self.model.load_state_dict(ckpt)
        self.model = self.model.to(device).eval()
        self.mel_extractor = MelSpectrogram().to(device)
        # Cents value of each of the 360 salience bins (20-cent spacing),
        # zero-padded by 4 on each side to line up with the padded salience.
        cents_mapping = 20 * np.arange(360) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))

    @torch.no_grad()
    def infer_from_audio(self, audio: np.ndarray, thred: float = 0.03) -> np.ndarray:
        """
        Extract an F0 contour from audio.

        Args:
            audio: 16 kHz mono waveform of shape ``[N]`` (a ``[1, N]`` array also works).
            thred: Salience threshold. Frames whose peak salience is ``<= thred`` get F0 0.
        Returns:
            np.ndarray: F0 in Hz, one value per mel frame.
        """
        # ascontiguousarray also accepts negative-stride views (e.g. audio[::-1]),
        # which torch.from_numpy rejects.
        samples = np.ascontiguousarray(audio, dtype=np.float32)
        wav = torch.from_numpy(samples).to(self.device)
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)
        mel = self.mel_extractor(wav)  # [B, 128, T]
        n_frames = mel.shape[-1]
        # The U-Net pools 5 times by 2 along time, so pad T up to a multiple of 32.
        n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
        if n_pad > 0:
            mel = F.pad(mel, (0, n_pad), mode='constant', value=0)
        # E2E.forward handles the transpose; padded frames are dropped afterwards.
        hidden = self.model(mel)[:, :n_frames, :]
        hidden = hidden.squeeze(0).cpu().numpy()
        return self._decode(hidden, thred)

    def _decode(self, hidden: np.ndarray, thred: float) -> np.ndarray:
        """Convert salience ``[T, 360]`` to F0 in Hz (0 for unvoiced frames)."""
        cents = self._to_local_average_cents(hidden, thred)
        f0 = 10 * (2 ** (cents / 1200))
        # cents == 0 (unvoiced) maps to exactly 10 Hz; report it as 0.
        f0[f0 == 10] = 0
        return f0

    def _to_local_average_cents(self, salience: np.ndarray, thred: float) -> np.ndarray:
        """Salience-weighted mean of cents over the 9 bins centred on each frame's peak.

        Args:
            salience: Array of shape ``[T, 360]``. ``T`` may be 0.
            thred: Frames whose maximum salience is ``<= thred`` get 0.
        Returns:
            np.ndarray: Cents per frame, shape ``[T]``.
        """
        center = np.argmax(salience, axis=1)              # [T], unpadded bin index
        padded = np.pad(salience, ((0, 0), (4, 4)))       # [T, 368]
        # In padded coordinates the window [peak - 4, peak + 4] starts at
        # index ``center`` and spans 9 bins.
        window_idx = center[:, None] + np.arange(9)       # [T, 9]
        window = np.take_along_axis(padded, window_idx, axis=1)
        window_cents = self.cents_mapping[window_idx]
        product_sum = np.sum(window * window_cents, axis=1)
        weight_sum = np.sum(window, axis=1) + 1e-9        # epsilon avoids 0/0
        cents = product_sum / weight_sum
        maxx = np.max(padded, axis=1)
        cents[maxx <= thred] = 0
        return cents