#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
流式Mel特征处理器
用于实时音频流的Mel频谱特征提取,支持chunk-based处理。
支持配置CNN冗余以保证与离线处理的一致性。
"""
import logging
from typing import Dict
from typing import Optional
from typing import Tuple
import numpy as np
import torch
from .processing_audio_minicpma import MiniCPMAAudioProcessor
logger = logging.getLogger(__name__)
class StreamingMelProcessorExact:
    """
    Streaming Mel processor that is strictly equivalent to offline extraction.

    Approach:
    - Accumulate the full audio history in a buffer; after each append, run the
      same feature_extractor over the whole buffer to recompute the mel.
    - Only emit "stable" frames: frames whose center does not depend on future
      (right-side) context, i.e. center + n_fft//2 <= current buffer length.
    - At end of stream, flush() emits the remaining frames, so the concatenated
      output matches a single offline pass over the full audio exactly.

    Cost: one full feature extraction over the accumulated buffer per call
    (could be optimized into an incremental computation if needed).
    """

    def __init__(
        self,
        feature_extractor: "MiniCPMAAudioProcessor",
        chunk_ms: int = 100,
        first_chunk_ms: Optional[int] = None,
        sample_rate: int = 16000,
        n_fft: int = 400,
        hop_length: int = 160,
        n_mels: int = 80,
        verbose: bool = False,
        cnn_redundancy_ms: int = 10,  # CNN redundancy given in ms (10ms is usually 1 frame)
        # --- Sliding-window parameters (trigger mode) ---
        enable_sliding_window: bool = False,  # whether the sliding window is active
        slide_trigger_seconds: float = 30.0,  # buffer length (seconds) that triggers a slide
        slide_stride_seconds: float = 10.0,  # how many seconds each slide moves the window
    ):
        self.feature_extractor = feature_extractor
        self.chunk_ms = chunk_ms
        self.first_chunk_ms = first_chunk_ms if first_chunk_ms is not None else chunk_ms
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.verbose = verbose
        self.chunk_samples = int(round(chunk_ms * sample_rate / 1000))
        self.chunk_frames = self.chunk_samples // hop_length
        # Align the first chunk to a multiple of hop_length so frame
        # boundaries stay aligned with the offline computation.
        hop = self.hop_length
        raw_first_samples = int(round(self.first_chunk_ms * sample_rate / 1000))
        aligned_first = max(hop, (raw_first_samples // hop) * hop)
        self.first_chunk_samples = aligned_first
        self.half_window = n_fft // 2  # right-side context each frame center requires
        # Redundancy expressed in frames; typically <= 1 frame (10ms -> 1 frame).
        self.cnn_redundancy_ms = cnn_redundancy_ms
        self.cnn_redundancy_samples = int(cnn_redundancy_ms * sample_rate / 1000)
        self.cnn_redundancy_frames = max(0, self.cnn_redundancy_samples // hop_length)
        # --- Sliding-window configuration (trigger mode) ---
        self.enable_sliding_window = enable_sliding_window
        self.trigger_seconds = slide_trigger_seconds
        self.slide_seconds = slide_stride_seconds
        # --- Offsets / baselines in global frame coordinates ---
        self.left_samples_dropped = 0  # samples already dropped from the left edge
        self.base_T = 0  # global frame index corresponding to mel_full[:, :, 0]
        self.reset()

    def reset(self):
        """Reset all streaming state back to the start of a fresh stream."""
        self.buffer = np.zeros(0, dtype=np.float32)
        self.last_emitted_T = 0
        self.total_samples_processed = 0
        self.chunk_count = 0
        self.is_first = True
        self.left_samples_dropped = 0
        self.base_T = 0

    def get_chunk_size(self) -> int:
        """Return how many samples the caller should feed next (first chunk may differ)."""
        return self.first_chunk_samples if self.is_first else self.chunk_samples

    def get_expected_output_frames(self) -> int:
        raise NotImplementedError("get_expected_output_frames is not implemented")

    def _extract_full(self) -> torch.Tensor:
        """Run the feature extractor over the entire current buffer.

        Raises:
            ValueError: if the buffer is shorter than n_fft. Whisper-style STFT
                with center=True pads by n_fft//2 on each side and fails when
                the pad exceeds the input length; there are no stable frames to
                emit in that case anyway, so we fail fast instead.
        """
        if len(self.buffer) < self.n_fft:
            raise ValueError(f"buffer length is shorter than n_fft {len(self.buffer)} < {self.n_fft}")
        # For buffers shorter than 5s use set_spac_log_norm(log_floor_db=-10);
        # longer buffers use set_spac_log_norm(dynamic_range_db=8).
        if len(self.buffer) < 5 * self.sample_rate:
            # TODO: these values were chosen empirically; run experiments to
            # pick the best setting (see the main() of MiniCPMAAudioProcessor).
            self.feature_extractor.set_spac_log_norm(log_floor_db=-10)
        else:
            self.feature_extractor.set_spac_log_norm(dynamic_range_db=8)
        feats = self.feature_extractor(
            self.buffer,
            sampling_rate=self.sample_rate,
            return_tensors="pt",
            padding=False,
        )
        return feats.input_features  # [1, 80, T]

    def _stable_frames_count(self) -> int:
        """Number of frames whose right context is fully inside the buffer.

        stable = floor((len(buffer) - half_window) / hop) + 1, clamped to >= 0.
        """
        L = int(self.buffer.shape[0])
        if L <= 0:
            return 0
        if L < self.half_window:
            return 0
        return max(0, (L - self.half_window) // self.hop_length + 1)

    def _maybe_slide_buffer(self):
        """Trigger-mode sliding window: once the buffer reaches the trigger
        threshold, slide it left by a fixed stride (hop-aligned)."""
        if not self.enable_sliding_window:
            return
        sr = self.sample_rate
        hop = self.hop_length
        L = len(self.buffer)
        # Convert seconds to samples.
        trigger_samples = int(self.trigger_seconds * sr)
        stride_samples = int(self.slide_seconds * sr)
        # Only slide once the trigger threshold is reached.
        if L < trigger_samples:
            return
        # Drop a fixed stride of samples per slide.
        drop = stride_samples
        # Never drop left context that future emissions still need; in trigger
        # mode we only protect the minimum necessary tail (e.g. the most recent
        # second of audio) to keep processing continuous.
        min_keep_seconds = 1.0  # keep at least 1s of data for continuity
        min_keep_samples = int(min_keep_seconds * sr)
        # guard_samples is the minimum number of samples we must retain.
        guard_samples = min(min_keep_samples, L - drop)
        # Clamp to the safe boundary and align the drop to hop_length so the
        # global frame baseline (base_T) stays an exact frame count.
        max_allowed_drop = max(0, L - guard_samples)
        drop = min(drop, max_allowed_drop)
        drop = (drop // hop) * hop
        if drop <= 0:
            return
        # Actually drop samples and shift the global baseline.
        self.buffer = self.buffer[drop:]
        self.left_samples_dropped += drop
        self.base_T += drop // hop
        if self.verbose:
            print(
                f"[Slide] Trigger模式: drop={drop/sr:.2f}s samples, base_T={self.base_T}, buffer_after={len(self.buffer)/sr:.2f}s"
            )

    def process(self, audio_chunk: np.ndarray, is_last_chunk: bool = False) -> Tuple[torch.Tensor, Dict]:
        """Append a chunk of audio and emit the mel frames that became stable.

        Args:
            audio_chunk: 1-D float audio samples at self.sample_rate.
            is_last_chunk: if True, emit frames even before the redundancy
                requirement is met (end of stream).

        Returns:
            (mel_output, info): mel_output is [1, n_mels, F] (possibly F == 0),
            info is a diagnostics dict in global frame coordinates.
        """
        self.chunk_count += 1
        # Append to the accumulation buffer.
        if len(self.buffer) == 0:
            self.buffer = audio_chunk.astype(np.float32, copy=True)
        else:
            self.buffer = np.concatenate([self.buffer, audio_chunk.astype(np.float32, copy=True)])
        # --- Sliding-window maintenance ---
        self._maybe_slide_buffer()
        # Full extraction over the current window.
        mel_full = self._extract_full()
        T_full = mel_full.shape[-1]  # local frame count of the current window
        stable_T = min(T_full, self._stable_frames_count())  # locally stable frames
        stable_T_global = self.base_T + stable_T  # map to global frame coordinates
        # Plan the core frames to emit this round (global coordinates).
        core_start_g = self.last_emitted_T
        core_end_g = core_start_g + self.chunk_frames
        required_stable_g = core_end_g + self.cnn_redundancy_frames
        if self.verbose:
            print(
                f"[Exact] buffer_len={len(self.buffer)} samples, T_full(local)={T_full}, "
                f"stable_T(local)={stable_T}, base_T={self.base_T}, "
                f"stable_T(global)={stable_T_global}, last_emitted={self.last_emitted_T}"
            )
        if stable_T_global >= required_stable_g or is_last_chunk:
            emit_start_g = max(0, core_start_g - self.cnn_redundancy_frames)
            emit_end_g = core_end_g + self.cnn_redundancy_frames
            # Global -> local index conversion, clamped into [0, T_full].
            emit_start = max(0, emit_start_g - self.base_T)
            emit_end = emit_end_g - self.base_T
            emit_start = max(0, min(emit_start, T_full))
            emit_end = max(emit_start, min(emit_end, T_full))
            mel_output = mel_full[:, :, emit_start:emit_end]
            self.last_emitted_T = core_end_g  # advance only the core pointer (global)
        else:
            mel_output = mel_full[:, :, 0:0]
        self.total_samples_processed += len(audio_chunk)
        self.is_first = False
        info = {
            "type": "exact_chunk",
            "chunk_number": self.chunk_count,
            "emitted_frames": mel_output.shape[-1],
            "stable_T": stable_T,
            "T_full": T_full,
            "base_T": self.base_T,
            "stable_T_global": stable_T_global,
            "buffer_len_samples": int(self.buffer.shape[0]),
            "left_samples_dropped": self.left_samples_dropped,
            "core_start": core_start_g,  # original field name kept; value is global
            "core_end": core_end_g,  # same as above
        }
        return mel_output, info

    def flush(self) -> torch.Tensor:
        """Call at end of stream: emit all remaining frames so the total output
        matches the offline computation (computed in global coordinates)."""
        if len(self.buffer) == 0:
            # FIX: use the configured mel count instead of a hard-coded 80.
            return torch.zeros(1, self.n_mels, 0)
        mel_full = self._extract_full()
        T_local = mel_full.shape[-1]
        T_global = self.base_T + T_local
        if self.last_emitted_T < T_global:
            start_l = max(0, self.last_emitted_T - self.base_T)
            tail = mel_full[:, :, start_l:]
            self.last_emitted_T = T_global
            if self.verbose:
                print(f"[Exact] flush {tail.shape[-1]} frames (T_global={T_global})")
            return tail
        return mel_full[:, :, 0:0]

    def get_config(self) -> Dict:
        """Return the static configuration as a plain dict."""
        return {
            "chunk_ms": self.chunk_ms,
            "first_chunk_ms": self.first_chunk_ms,
            "effective_first_chunk_ms": self.first_chunk_samples / self.sample_rate * 1000.0,
            "sample_rate": self.sample_rate,
            "n_fft": self.n_fft,
            "hop_length": self.hop_length,
            "cnn_redundancy_ms": self.cnn_redundancy_ms,
            "cnn_redundancy_frames": self.cnn_redundancy_frames,
            "enable_sliding_window": self.enable_sliding_window,
            "trigger_seconds": self.trigger_seconds,
            "slide_seconds": self.slide_seconds,
        }

    def get_state(self) -> Dict:
        """Return lightweight mutable state (no buffer copy) for monitoring."""
        return {
            "chunk_count": self.chunk_count,
            "last_emitted_T": self.last_emitted_T,
            "total_samples_processed": self.total_samples_processed,
            "buffer_len": int(self.buffer.shape[0]),
            "base_T": self.base_T,
            "left_samples_dropped": self.left_samples_dropped,
        }

    def get_snapshot(self) -> Dict:
        """Take a full state snapshot (including the buffer) for speculative
        decoding / rollback recovery.

        Returns:
            A dict with the full state; pass it to restore_snapshot to restore.
        """
        buffer_copy = self.buffer.copy()
        snapshot = {
            "chunk_count": self.chunk_count,
            "last_emitted_T": self.last_emitted_T,
            "total_samples_processed": self.total_samples_processed,
            "buffer": buffer_copy,
            "base_T": self.base_T,
            "left_samples_dropped": self.left_samples_dropped,
            "is_first": self.is_first,
            # Save the feature_extractor's normalization state (crucial for
            # deterministic mel extraction after restore).
            "fe_dynamic_log_norm": getattr(self.feature_extractor, "dynamic_log_norm", None),
            "fe_dynamic_range_db": getattr(self.feature_extractor, "dynamic_range_db", None),
            "fe_log_floor_db": getattr(self.feature_extractor, "log_floor_db", None),
        }
        logger.debug(
            "[MelProcessor] Created snapshot: chunk_count=%d, last_emitted_T=%d, "
            "buffer_len=%d, buffer_sum=%.6f, total_samples=%d",
            self.chunk_count,
            self.last_emitted_T,
            len(buffer_copy),
            float(buffer_copy.sum()) if len(buffer_copy) > 0 else 0.0,
            self.total_samples_processed,
        )
        return snapshot

    def restore_snapshot(self, snapshot: Dict) -> None:
        """Restore state from a snapshot.

        Args:
            snapshot: the dict returned by get_snapshot.
        """
        # Record the pre-restore state for the log message below.
        prev_state = {
            "chunk_count": self.chunk_count,
            "last_emitted_T": self.last_emitted_T,
            "buffer_len": len(self.buffer),
        }
        # Restore processor state.
        self.chunk_count = snapshot["chunk_count"]
        self.last_emitted_T = snapshot["last_emitted_T"]
        self.total_samples_processed = snapshot["total_samples_processed"]
        self.buffer = snapshot["buffer"].copy()  # defensive copy of the buffer
        self.base_T = snapshot["base_T"]
        self.left_samples_dropped = snapshot["left_samples_dropped"]
        self.is_first = snapshot["is_first"]
        # Restore the feature_extractor's normalization state (crucial for
        # deterministic mel extraction after restore).
        if snapshot.get("fe_dynamic_log_norm") is not None:
            self.feature_extractor.dynamic_log_norm = snapshot["fe_dynamic_log_norm"]
        if snapshot.get("fe_dynamic_range_db") is not None:
            self.feature_extractor.dynamic_range_db = snapshot["fe_dynamic_range_db"]
        if snapshot.get("fe_log_floor_db") is not None:
            self.feature_extractor.log_floor_db = snapshot["fe_log_floor_db"]
        logger.info(
            "[MelProcessor] Restored snapshot: chunk_count %d->%d, last_emitted_T %d->%d, "
            "buffer_len %d->%d, total_samples=%d",
            prev_state["chunk_count"],
            self.chunk_count,
            prev_state["last_emitted_T"],
            self.last_emitted_T,
            prev_state["buffer_len"],
            len(self.buffer),
            self.total_samples_processed,
        )
|