File size: 8,758 Bytes
b15e31b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
# -*- coding: utf-8 -*-
"""

F0 (基频) 提取模块 - 支持多种提取方法

"""
import numpy as np
import torch
from typing import Optional, Literal

# F0 提取方法类型
F0Method = Literal["rmvpe", "pm", "harvest", "crepe", "hybrid"]


class F0Extractor:
    """Base class for F0 (fundamental frequency) extractors.

    Subclasses implement :meth:`extract` for a specific algorithm.
    """

    def __init__(self, sample_rate: int = 16000, hop_length: int = 160,
                 f0_min: float = 50, f0_max: float = 1100):
        """
        Args:
            sample_rate: Audio sample rate in Hz.
            hop_length: Hop size in samples between consecutive F0 frames.
            f0_min: Lowest F0 (Hz) considered during extraction.
            f0_max: Highest F0 (Hz) considered during extraction.
        """
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.f0_min = f0_min
        self.f0_max = f0_max

    def extract(self, audio: np.ndarray) -> np.ndarray:
        """Extract F0 from a mono audio signal; subclasses must implement this."""
        raise NotImplementedError


class PMExtractor(F0Extractor):
    """Parselmouth (Praat) based F0 extractor -- fast."""

    def extract(self, audio: np.ndarray) -> np.ndarray:
        """Run Praat's autocorrelation pitch tracker; unvoiced frames become NaN."""
        import parselmouth

        snd = parselmouth.Sound(audio, self.sample_rate)
        pitch = snd.to_pitch_ac(
            time_step=self.hop_length / self.sample_rate,
            voicing_threshold=0.6,
            pitch_floor=self.f0_min,
            pitch_ceiling=self.f0_max,
        )

        # Praat reports 0 Hz for unvoiced frames; map those to NaN.
        f0 = pitch.selected_array["frequency"]
        f0 = np.where(f0 == 0, np.nan, f0)
        return f0


class HarvestExtractor(F0Extractor):
    """PyWorld Harvest F0 extractor -- good quality."""

    def extract(self, audio: np.ndarray) -> np.ndarray:
        """Run WORLD's Harvest algorithm; the audio is cast to float64 first."""
        import pyworld

        samples = audio.astype(np.float64)
        f0, _ = pyworld.harvest(
            samples,
            self.sample_rate,
            f0_floor=self.f0_min,
            f0_ceil=self.f0_max,
            frame_period=self.hop_length / self.sample_rate * 1000,
        )
        return f0


class CrepeExtractor(F0Extractor):
    """TorchCrepe F0 extractor -- deep-learning based."""

    def __init__(self, sample_rate: int = 16000, hop_length: int = 160,
                 device: str = "cuda"):
        super().__init__(sample_rate, hop_length)
        # Device the CREPE network runs on (e.g. "cuda" or "cpu").
        self.device = device

    def extract(self, audio: np.ndarray) -> np.ndarray:
        """Predict F0 with the full CREPE model; periodicity output is discarded."""
        import torchcrepe

        waveform = torch.from_numpy(audio).float().unsqueeze(0).to(self.device)

        pitch, _ = torchcrepe.predict(
            waveform,
            self.sample_rate,
            self.hop_length,
            self.f0_min,
            self.f0_max,
            model="full",
            batch_size=512,
            device=self.device,
            return_periodicity=True,
        )

        return pitch.squeeze(0).cpu().numpy()


class RMVPEExtractor(F0Extractor):
    """RMVPE F0 extractor -- highest quality (recommended)."""

    def __init__(self, model_path: str, sample_rate: int = 16000,
                 hop_length: int = 160, device: str = "cuda"):
        super().__init__(sample_rate, hop_length)
        self.device = device
        self.model_path = model_path
        # Model is loaded lazily on first use; None means "not loaded yet".
        self.model = None

    def load_model(self):
        """Load the RMVPE model once; subsequent calls are no-ops."""
        if self.model is None:
            from models.rmvpe import RMVPE

            self.model = RMVPE(self.model_path, device=self.device)
            print(f"RMVPE 模型已加载: {self.device}")

    def extract(self, audio: np.ndarray) -> np.ndarray:
        self.load_model()

        # NOTE: RMVPE expects 16 kHz input audio.
        return self.model.infer_from_audio(audio, thred=0.01)


def get_f0_extractor(method: "F0Method", device: str = "cuda",
                     rmvpe_path: Optional[str] = None,
                     crepe_threshold: float = 0.05) -> "F0Extractor":
    """Build an F0 extractor for the requested method.

    Args:
        method: Extraction method ("rmvpe", "pm", "harvest", "crepe", "hybrid").
        device: Compute device for the neural extractors.
        rmvpe_path: Path to the RMVPE model checkpoint (required for the
            "rmvpe" and "hybrid" methods).
        crepe_threshold: CREPE confidence threshold (used by "hybrid" only).

    Returns:
        F0Extractor: A configured extractor instance.

    Raises:
        ValueError: If *method* is unknown, or a required model path is missing.
    """
    if method == "rmvpe":
        if rmvpe_path is None:
            raise ValueError("RMVPE 方法需要指定模型路径")
        return RMVPEExtractor(rmvpe_path, device=device)
    elif method == "hybrid":
        if rmvpe_path is None:
            raise ValueError("Hybrid 方法需要指定RMVPE模型路径")
        return HybridF0Extractor(rmvpe_path, device=device, crepe_threshold=crepe_threshold)
    elif method == "pm":
        return PMExtractor()
    elif method == "harvest":
        return HarvestExtractor()
    elif method == "crepe":
        return CrepeExtractor(device=device)
    else:
        raise ValueError(f"未知的 F0 提取方法: {method}")


class HybridF0Extractor(F0Extractor):
    """Hybrid F0 extractor -- RMVPE as the primary tracker, with CREPE
    substituted in regions where RMVPE looks unstable (high precision)."""

    def __init__(self, rmvpe_path: str, sample_rate: int = 16000,
                 hop_length: int = 160, device: str = "cuda",
                 crepe_threshold: float = 0.05):
        """
        Args:
            rmvpe_path: Path to the RMVPE model checkpoint.
            sample_rate: Audio sample rate in Hz.
            hop_length: Hop size in samples between F0 frames.
            device: Compute device for both neural trackers.
            crepe_threshold: Minimum CREPE periodicity (confidence) for a
                CREPE frame to override RMVPE.
        """
        super().__init__(sample_rate, hop_length)
        self.device = device
        self.rmvpe = RMVPEExtractor(rmvpe_path, sample_rate, hop_length, device)
        self.crepe = None  # lazily created; False marks "unavailable"
        self.crepe_threshold = crepe_threshold

    def _load_crepe(self):
        """Lazily create the CREPE extractor; mark it unavailable if torchcrepe
        is not installed."""
        if self.crepe is None:
            try:
                # Probe the import explicitly: CrepeExtractor.__init__ does not
                # import torchcrepe itself (the import lives in extract()), so
                # without this probe the ImportError fallback could never fire.
                import torchcrepe  # noqa: F401
                self.crepe = CrepeExtractor(self.sample_rate, self.hop_length, self.device)
            except ImportError:
                print("警告: torchcrepe未安装,混合F0将仅使用RMVPE")
                self.crepe = False

    def extract(self, audio: np.ndarray) -> np.ndarray:
        """Hybrid F0 extraction.

        1. Run RMVPE as the primary method (fast, stable).
        2. In regions where RMVPE looks unstable, substitute CREPE values.
        """
        # Primary pass: RMVPE.
        f0_rmvpe = self.rmvpe.extract(audio)

        # If CREPE is unavailable, return the RMVPE result as-is.
        self._load_crepe()
        if self.crepe is False:
            return f0_rmvpe

        # Secondary pass: CREPE F0 plus per-frame periodicity (confidence).
        import torchcrepe
        audio_tensor = torch.from_numpy(audio).float().unsqueeze(0).to(self.device)
        f0_crepe, confidence = torchcrepe.predict(
            audio_tensor,
            self.sample_rate,
            self.hop_length,
            self.f0_min,
            self.f0_max,
            model="full",
            batch_size=512,
            device=self.device,
            return_periodicity=True
        )
        f0_crepe = f0_crepe.squeeze(0).cpu().numpy()
        confidence = confidence.squeeze(0).cpu().numpy()

        # Align lengths (the two trackers may emit slightly different counts).
        min_len = min(len(f0_rmvpe), len(f0_crepe), len(confidence))
        f0_rmvpe = f0_rmvpe[:min_len]
        f0_crepe = f0_crepe[:min_len]
        confidence = confidence[:min_len]

        # Detect unstable RMVPE regions:
        # 1. Frame-to-frame jump larger than 3 semitones.
        semitone_diff = np.abs(12 * np.log2((f0_rmvpe + 1e-6) / (np.roll(f0_rmvpe, 1) + 1e-6)))
        semitone_diff[0] = 0  # np.roll wraps the last frame into slot 0
        unstable_jump = semitone_diff > 3.0

        # 2. CREPE is confident but RMVPE reports unvoiced (F0 == 0).
        unstable_unvoiced = (f0_rmvpe < 1e-3) & (confidence > self.crepe_threshold)

        # 3. RMVPE and CREPE disagree by > 2 semitones with high CREPE confidence.
        f0_ratio = (f0_crepe + 1e-6) / (f0_rmvpe + 1e-6)
        semitone_gap = np.abs(12 * np.log2(f0_ratio))
        unstable_diverge = (semitone_gap > 2.0) & (confidence > self.crepe_threshold * 1.5)

        # Union of all instability criteria.
        unstable_mask = unstable_jump | unstable_unvoiced | unstable_diverge

        # Dilate the mask by 2 frames on each side for smoother transitions.
        # Cast back to bool explicitly: if np.convolve promotes the dtype,
        # a non-bool mask would be misread as integer fancy indexing below.
        kernel = np.ones(5, dtype=bool)
        unstable_mask = np.convolve(unstable_mask, kernel, mode='same').astype(bool)

        # Blend: CREPE in unstable regions, RMVPE everywhere else.
        f0_hybrid = f0_rmvpe.copy()
        f0_hybrid[unstable_mask] = f0_crepe[unstable_mask]

        # Smooth mask boundaries with an equal-weight average of both trackers.
        for i in range(1, len(f0_hybrid) - 1):
            if unstable_mask[i] != unstable_mask[i-1]:
                w = 0.5
                f0_hybrid[i] = w * f0_rmvpe[i] + (1-w) * f0_crepe[i]

        return f0_hybrid


def shift_f0(f0: np.ndarray, semitones: float) -> np.ndarray:
    """Shift an F0 contour by a number of semitones.

    Args:
        f0: Input F0 contour (Hz).
        semitones: Semitones to shift (positive raises pitch, negative lowers).

    Returns:
        np.ndarray: The shifted F0 contour.
    """
    # One semitone is a factor of 2**(1/12) in frequency.
    return f0 * 2 ** (semitones / 12)