File size: 8,758 Bytes
# -*- coding: utf-8 -*-
"""
F0 (基频) 提取模块 - 支持多种提取方法
"""
import numpy as np
import torch
from typing import Optional, Literal
# Names of the supported F0 extraction methods; get_f0_extractor dispatches
# on these strings to pick the matching extractor class.
F0Method = Literal["rmvpe", "pm", "harvest", "crepe", "hybrid"]
class F0Extractor:
    """Common base for all F0 (fundamental frequency) extractors.

    Stores the framing parameters and the pitch search range shared by the
    concrete extractors. Subclasses override :meth:`extract`.
    """

    def __init__(self, sample_rate: int = 16000, hop_length: int = 160):
        # Framing: one F0 value is produced per ``hop_length`` samples.
        self.sample_rate, self.hop_length = sample_rate, hop_length
        # Pitch search range in Hz, passed to each backend.
        self.f0_min, self.f0_max = 50, 1100

    def extract(self, audio: np.ndarray) -> np.ndarray:
        """Return the F0 track for ``audio``; must be provided by a subclass."""
        raise NotImplementedError
class PMExtractor(F0Extractor):
    """F0 extractor based on Parselmouth (Praat autocorrelation) - fast.

    Frames for which Praat reports 0 Hz are returned as NaN.
    NOTE(review): the other extractors in this module use 0 for unvoiced
    frames; confirm callers handle NaN from this backend.
    """

    def extract(self, audio: np.ndarray) -> np.ndarray:
        import parselmouth

        pitch_track = parselmouth.Sound(audio, self.sample_rate).to_pitch_ac(
            # Praat's analysis step is in seconds; one step per hop.
            time_step=self.hop_length / self.sample_rate,
            voicing_threshold=0.6,
            pitch_floor=self.f0_min,
            pitch_ceiling=self.f0_max,
        )
        frequencies = pitch_track.selected_array["frequency"]
        frequencies[frequencies == 0] = np.nan
        return frequencies
class HarvestExtractor(F0Extractor):
    """F0 extractor based on PyWorld's Harvest algorithm - good quality."""

    def extract(self, audio: np.ndarray) -> np.ndarray:
        import pyworld

        # pyworld expects float64 samples and a frame period in milliseconds.
        samples = audio.astype(np.float64)
        frame_period_ms = self.hop_length / self.sample_rate * 1000
        # harvest returns (f0, time_axis); only the F0 track is needed.
        return pyworld.harvest(
            samples,
            self.sample_rate,
            f0_floor=self.f0_min,
            f0_ceil=self.f0_max,
            frame_period=frame_period_ms,
        )[0]
class CrepeExtractor(F0Extractor):
    """F0 extractor based on torchcrepe (neural network).

    The periodicity returned by torchcrepe is discarded, so no voicing
    decision is applied here.
    NOTE(review): frames are presumably never set to 0; confirm callers do
    not rely on 0 meaning "unvoiced" for this backend.
    """

    def __init__(self, sample_rate: int = 16000, hop_length: int = 160,
                 device: str = "cuda"):
        super().__init__(sample_rate, hop_length)
        self.device = device

    def extract(self, audio: np.ndarray) -> np.ndarray:
        import torchcrepe

        # torchcrepe wants a float32 tensor with a leading batch dimension.
        batch = torch.from_numpy(audio).float().unsqueeze(0).to(self.device)
        pitch, _periodicity = torchcrepe.predict(
            batch,
            self.sample_rate,
            self.hop_length,
            self.f0_min,
            self.f0_max,
            model="full",
            batch_size=512,
            device=self.device,
            return_periodicity=True,
        )
        return pitch.squeeze(0).cpu().numpy()
class RMVPEExtractor(F0Extractor):
    """F0 extractor based on RMVPE - highest quality (recommended).

    The model is loaded lazily on the first call to :meth:`extract`.
    """

    def __init__(self, model_path: str, sample_rate: int = 16000,
                 hop_length: int = 160, device: str = "cuda"):
        super().__init__(sample_rate, hop_length)
        self.device = device
        self.model = None  # populated by load_model()
        self.model_path = model_path

    def load_model(self):
        """Load the RMVPE model from ``self.model_path`` unless already loaded."""
        if self.model is None:
            from models.rmvpe import RMVPE

            self.model = RMVPE(self.model_path, device=self.device)
            print(f"RMVPE 模型已加载: {self.device}")

    def extract(self, audio: np.ndarray) -> np.ndarray:
        """Run RMVPE on ``audio`` and return its F0 track.

        RMVPE expects 16 kHz input; ``self.sample_rate`` is not checked here.
        """
        self.load_model()
        return self.model.infer_from_audio(audio, thred=0.01)
def get_f0_extractor(method: F0Method, device: str = "cuda",
                     rmvpe_path: Optional[str] = None,
                     crepe_threshold: float = 0.05) -> F0Extractor:
    """Build the F0 extractor for ``method``.

    Args:
        method: one of "rmvpe", "pm", "harvest", "crepe", "hybrid".
        device: compute device passed to the "rmvpe", "crepe" and "hybrid"
            extractors; ignored by "pm" and "harvest".
        rmvpe_path: path to the RMVPE model; required for "rmvpe" and
            "hybrid", unused otherwise.
        crepe_threshold: CREPE confidence threshold, used only by "hybrid".

    Returns:
        F0Extractor: a new extractor instance using the classes' default
        sample rate and hop length.

    Raises:
        ValueError: if ``method`` is unknown, or if ``rmvpe_path`` is missing
            for a method that needs it.
    """
    if method == "rmvpe":
        if rmvpe_path is None:
            raise ValueError("RMVPE 方法需要指定模型路径")
        return RMVPEExtractor(rmvpe_path, device=device)
    if method == "hybrid":
        if rmvpe_path is None:
            raise ValueError("Hybrid 方法需要指定RMVPE模型路径")
        return HybridF0Extractor(rmvpe_path, device=device,
                                 crepe_threshold=crepe_threshold)
    if method == "pm":
        return PMExtractor()
    if method == "harvest":
        return HarvestExtractor()
    if method == "crepe":
        return CrepeExtractor(device=device)
    raise ValueError(f"未知的 F0 提取方法: {method}")
class HybridF0Extractor(F0Extractor):
    """Hybrid F0 extractor: RMVPE is the primary track, CREPE patches weak spots.

    RMVPE runs on every call. If torchcrepe can be imported, CREPE also runs,
    and its estimate replaces RMVPE in frames flagged as unstable (see
    ``_unstable_mask``). Otherwise the RMVPE track is returned unchanged.
    """

    # RMVPE values below this (Hz) are treated as unvoiced.
    _UNVOICED_HZ = 1e-3
    # Frame-to-frame RMVPE jump (semitones) regarded as suspicious.
    _JUMP_SEMITONES = 3.0
    # RMVPE/CREPE disagreement (semitones) regarded as suspicious.
    _DIVERGE_SEMITONES = 2.0
    # Unstable regions are widened by this many frames on each side.
    _DILATE_FRAMES = 2
    # Weight of RMVPE when blending at RMVPE/CREPE switch points.
    _BLEND_WEIGHT = 0.5

    def __init__(self, rmvpe_path: str, sample_rate: int = 16000,
                 hop_length: int = 160, device: str = "cuda",
                 crepe_threshold: float = 0.05):
        super().__init__(sample_rate, hop_length)
        self.device = device
        self.rmvpe = RMVPEExtractor(rmvpe_path, sample_rate, hop_length, device)
        # None = not probed yet, False = torchcrepe unavailable,
        # otherwise a CrepeExtractor instance.
        self.crepe = None
        self.crepe_threshold = crepe_threshold

    def _load_crepe(self):
        """Probe torchcrepe once; set ``self.crepe`` to an extractor or False."""
        if self.crepe is not None:
            return
        try:
            # CrepeExtractor only imports torchcrepe inside extract(), so its
            # constructor never raises ImportError. Probe the import here so
            # the RMVPE-only fallback actually triggers.
            import torchcrepe  # noqa: F401
        except ImportError:
            print("警告: torchcrepe未安装,混合F0将仅使用RMVPE")
            self.crepe = False
        else:
            self.crepe = CrepeExtractor(self.sample_rate, self.hop_length, self.device)

    def _predict_crepe(self, audio: np.ndarray):
        """Run torchcrepe; return (f0, periodicity) as 1-D numpy arrays."""
        import torchcrepe

        audio_tensor = torch.from_numpy(audio).float().unsqueeze(0).to(self.device)
        f0_crepe, confidence = torchcrepe.predict(
            audio_tensor,
            self.sample_rate,
            self.hop_length,
            self.f0_min,
            self.f0_max,
            model="full",
            batch_size=512,
            device=self.device,
            return_periodicity=True,
        )
        return f0_crepe.squeeze(0).cpu().numpy(), confidence.squeeze(0).cpu().numpy()

    @staticmethod
    def _dilate(mask: np.ndarray, radius: int) -> np.ndarray:
        """Grow True runs by ``radius`` frames per side; output length == len(mask)."""
        if radius <= 0 or len(mask) == 0:
            return mask.copy()
        # np.convolve(mode='same') returns max(M, N) samples, which breaks
        # for masks shorter than the kernel; a shifted OR keeps the length.
        padded = np.pad(mask, radius)  # pads with False
        dilated = np.zeros_like(mask)
        for shift in range(2 * radius + 1):
            dilated |= padded[shift:shift + len(mask)]
        return dilated

    def _unstable_mask(self, f0_rmvpe: np.ndarray, f0_crepe: np.ndarray,
                       confidence: np.ndarray) -> np.ndarray:
        """Flag frames where the RMVPE estimate looks unreliable (before dilation)."""
        unvoiced = f0_rmvpe < self._UNVOICED_HZ
        voiced = ~unvoiced

        # 1. Large frame-to-frame RMVPE jump. Only judge voiced->voiced steps;
        #    a voicing onset/offset (0 <-> f0) is not a pitch jump.
        step = np.zeros(len(f0_rmvpe), dtype=np.float64)
        step[1:] = np.abs(12 * np.log2((f0_rmvpe[1:] + 1e-6) / (f0_rmvpe[:-1] + 1e-6)))
        both_voiced = np.zeros(len(f0_rmvpe), dtype=bool)
        both_voiced[1:] = voiced[1:] & voiced[:-1]
        unstable_jump = both_voiced & (step > self._JUMP_SEMITONES)

        # 2. RMVPE says unvoiced but CREPE is confident.
        unstable_unvoiced = unvoiced & (confidence > self.crepe_threshold)

        # 3. RMVPE and CREPE disagree strongly while CREPE is fairly confident.
        semitone_gap = np.abs(12 * np.log2((f0_crepe + 1e-6) / (f0_rmvpe + 1e-6)))
        unstable_diverge = ((semitone_gap > self._DIVERGE_SEMITONES)
                            & (confidence > self.crepe_threshold * 1.5))

        return unstable_jump | unstable_unvoiced | unstable_diverge

    def extract(self, audio: np.ndarray) -> np.ndarray:
        """Extract F0 with RMVPE, replacing unstable frames with CREPE.

        Args:
            audio: waveform at ``self.sample_rate`` (RMVPE expects 16 kHz);
                presumably 1-D, since a batch axis is added for CREPE.

        Returns:
            np.ndarray: per-frame F0 in Hz. When CREPE is unavailable this is
            the RMVPE track as-is. Otherwise both tracks are truncated to the
            shorter length before mixing. Frames RMVPE marks unvoiced
            (< 1e-3 Hz) stay as RMVPE reported them unless CREPE's confidence
            exceeds ``crepe_threshold``.
        """
        f0_rmvpe = self.rmvpe.extract(audio)

        self._load_crepe()
        if self.crepe is False:
            return f0_rmvpe

        f0_crepe, confidence = self._predict_crepe(audio)

        # Align the two tracks (their frame counts may differ slightly).
        n = min(len(f0_rmvpe), len(f0_crepe), len(confidence))
        f0_rmvpe, f0_crepe, confidence = f0_rmvpe[:n], f0_crepe[:n], confidence[:n]
        if n == 0:
            return f0_rmvpe.copy()

        unstable = self._dilate(
            self._unstable_mask(f0_rmvpe, f0_crepe, confidence), self._DILATE_FRAMES
        )
        # CREPE's pitch is not voicing-gated, so do not let it voice a frame
        # RMVPE considers unvoiced unless CREPE is actually confident there.
        rmvpe_unvoiced = f0_rmvpe < self._UNVOICED_HZ
        use_crepe = unstable & ~(rmvpe_unvoiced & (confidence <= self.crepe_threshold))

        f0_hybrid = f0_rmvpe.copy()
        f0_hybrid[use_crepe] = f0_crepe[use_crepe]

        # Soften switch points between the two sources. Only blend when both
        # estimates are voiced: averaging with 0 Hz would create a spurious
        # half-pitch frame.
        switch = np.zeros(n, dtype=bool)
        switch[1:] = use_crepe[1:] != use_crepe[:-1]
        blend = switch & ~rmvpe_unvoiced & (f0_crepe >= self._UNVOICED_HZ)
        w = self._BLEND_WEIGHT
        f0_hybrid[blend] = w * f0_rmvpe[blend] + (1 - w) * f0_crepe[blend]
        return f0_hybrid
def shift_f0(f0: np.ndarray, semitones: float) -> np.ndarray:
    """Transpose an F0 contour by a number of semitones.

    Every value is scaled by 2 ** (semitones / 12), so 0 stays 0 and NaN
    stays NaN.

    Args:
        f0: original F0 values in Hz.
        semitones: shift amount; positive raises the pitch, negative lowers it.

    Returns:
        np.ndarray: the shifted F0 values.
    """
    return f0 * 2 ** (semitones / 12)
|