MiniCPM-o-4.5-nvidia-FlagOS / processing_streaming_mel.py
YummyYum's picture
Upload folder using huggingface_hub
be99bcf verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
流式Mel特征处理器
用于实时音频流的Mel频谱特征提取,支持chunk-based处理。
支持配置CNN冗余以保证与离线处理的一致性。
"""
import logging
from typing import Dict
from typing import Optional
from typing import Tuple
import numpy as np
import torch
from .processing_audio_minicpma import MiniCPMAAudioProcessor
logger = logging.getLogger(__name__)
class StreamingMelProcessorExact:
"""
严格离线等价的流式Mel处理器。
思路:
- 累积全部历史音频到缓冲;每次新增后用同一个 feature_extractor 计算整段 mel。
- 只输出"已稳定"的帧:帧中心不依赖未来(右侧)上下文,即 center + n_fft//2 <= 当前缓冲长度。
- 结束时(flush)再输出最后一批帧,确保与离线全量计算完全一致。
代价:每次会对累积缓冲做一次特征提取(可按需优化为增量)。
"""
def __init__(
self,
feature_extractor: MiniCPMAAudioProcessor,
chunk_ms: int = 100,
first_chunk_ms: Optional[int] = None,
sample_rate: int = 16000,
n_fft: int = 400,
hop_length: int = 160,
n_mels: int = 80,
verbose: bool = False,
cnn_redundancy_ms: int = 10, # (以ms给定,通常10ms=1帧)
# --- 滑窗参数(Trigger模式) ---
enable_sliding_window: bool = False, # 是否启用滑窗
slide_trigger_seconds: float = 30.0, # 触发滑窗的缓冲区秒数阈值
slide_stride_seconds: float = 10.0, # 每次滑窗移动的秒数
):
self.feature_extractor = feature_extractor
self.chunk_ms = chunk_ms
self.first_chunk_ms = first_chunk_ms if first_chunk_ms is not None else chunk_ms
self.sample_rate = sample_rate
self.n_fft = n_fft
self.hop_length = hop_length
self.n_mels = n_mels
self.verbose = verbose
self.chunk_samples = int(round(chunk_ms * sample_rate / 1000))
self.chunk_frames = self.chunk_samples // hop_length
# 对齐到 hop_length 的整数倍,避免帧边界不齐
hop = self.hop_length
raw_first_samples = int(round(self.first_chunk_ms * sample_rate / 1000))
aligned_first = max(hop, (raw_first_samples // hop) * hop)
self.first_chunk_samples = aligned_first
self.half_window = n_fft // 2 # 需要的右侧上下文
# 冗余帧数(以帧为单位),<=1帧:10ms → 1帧
self.cnn_redundancy_ms = cnn_redundancy_ms
self.cnn_redundancy_samples = int(cnn_redundancy_ms * sample_rate / 1000)
self.cnn_redundancy_frames = max(0, self.cnn_redundancy_samples // hop_length)
# --- 滑窗配置(Trigger模式) ---
self.enable_sliding_window = enable_sliding_window
self.trigger_seconds = slide_trigger_seconds
self.slide_seconds = slide_stride_seconds
# --- 位移/基准(全局帧坐标) ---
self.left_samples_dropped = 0 # 已从左侧丢弃的样本数
self.base_T = 0 # 当前 mel_full[:, :, 0] 对应的"全局帧"下标
self.reset()
def reset(self):
self.buffer = np.zeros(0, dtype=np.float32)
self.last_emitted_T = 0
self.total_samples_processed = 0
self.chunk_count = 0
self.is_first = True
self.left_samples_dropped = 0
self.base_T = 0
def get_chunk_size(self) -> int:
return self.first_chunk_samples if self.is_first else self.chunk_samples
def get_expected_output_frames(self) -> int:
raise NotImplementedError("get_expected_output_frames is not implemented")
def _extract_full(self) -> torch.Tensor:
# 当缓冲长度小于 n_fft 时,Whisper 的内部 STFT 在 center=True 且 pad 模式下会报错
# (pad 大于输入长度)。此时本来也没有稳定帧可输出,所以直接返回空特征。
if len(self.buffer) < self.n_fft:
raise ValueError(f"buffer length is shorter than n_fft {len(self.buffer)} < {self.n_fft}")
# 如果 buffer 长度 小于 5s 的话,用 set_spac_log_norm(log_floor_db=-10) 或者 上一次缓存的结果
if len(self.buffer) < 5 * self.sample_rate:
# TODO: 这里最好的还是 做一些 实验选择 一个 最好的,现在这个 是通过 经验 选择的, 可以看 MiniCPMAAudioProcessor 的 main 实现
self.feature_extractor.set_spac_log_norm(log_floor_db=-10)
# 如果 buffer 长度 大于 5s 的话,用 set_spac_log_norm(dynamic_range_db=8)
else:
self.feature_extractor.set_spac_log_norm(dynamic_range_db=8)
feats = self.feature_extractor(
self.buffer,
sampling_rate=self.sample_rate,
return_tensors="pt",
padding=False,
)
return feats.input_features # [1, 80, T]
def _stable_frames_count(self) -> int:
# 已稳定帧数 = floor((len(buffer) - half_window) / hop) + 1,最小为0
L = int(self.buffer.shape[0])
if L <= 0:
return 0
if L < self.half_window:
return 0
return max(0, (L - self.half_window) // self.hop_length + 1)
def _maybe_slide_buffer(self):
"""Trigger模式滑窗:当缓冲区达到触发阈值时,滑动固定长度的窗口。"""
if not self.enable_sliding_window:
return
sr = self.sample_rate
hop = self.hop_length
L = len(self.buffer)
# 将秒数转换为样本数
trigger_samples = int(self.trigger_seconds * sr)
stride_samples = int(self.slide_seconds * sr)
# 检查是否达到触发阈值
if L < trigger_samples:
return
# 计算需要丢弃的样本数(固定滑动 stride_samples)
drop = stride_samples
# 不能丢掉后续发射还需要的左侧上下文
# 在trigger模式下,我们只需要保护最小必要的数据
# 即:确保不丢弃未来可能需要的帧
last_emitted_local = self.last_emitted_T - self.base_T
# 只保护必要的上下文(例如,最近的1秒数据)
min_keep_seconds = 1.0 # 保留至少1秒的数据以确保处理的连续性
min_keep_samples = int(min_keep_seconds * sr)
# guard_samples 是我们必须保留的最小样本数
guard_samples = min(min_keep_samples, L - drop)
# 限制:不得越过安全边界;并对齐 hop
max_allowed_drop = max(0, L - guard_samples)
drop = min(drop, max_allowed_drop)
drop = (drop // hop) * hop
if drop <= 0:
return
# 真正丢弃 & 更新基准
self.buffer = self.buffer[drop:]
self.left_samples_dropped += drop
self.base_T += drop // hop
if self.verbose:
print(
f"[Slide] Trigger模式: drop={drop/sr:.2f}s samples, base_T={self.base_T}, buffer_after={len(self.buffer)/sr:.2f}s"
)
def process(self, audio_chunk: np.ndarray, is_last_chunk: bool = False) -> Tuple[torch.Tensor, Dict]:
self.chunk_count += 1
# 追加到缓冲
if len(self.buffer) == 0:
self.buffer = audio_chunk.astype(np.float32, copy=True)
else:
self.buffer = np.concatenate([self.buffer, audio_chunk.astype(np.float32, copy=True)])
# --- 滑窗处理 ---
self._maybe_slide_buffer()
# 全量提取(针对当前窗口)
mel_full = self._extract_full()
T_full = mel_full.shape[-1] # 当前窗口的局部帧数
stable_T = min(T_full, self._stable_frames_count()) # 局部可稳定帧
stable_T_global = self.base_T + stable_T # 映射到全局帧坐标
# 计划本次发射的核心帧(全局坐标)
core_start_g = self.last_emitted_T
core_end_g = core_start_g + self.chunk_frames
required_stable_g = core_end_g + self.cnn_redundancy_frames
if self.verbose:
print(
f"[Exact] buffer_len={len(self.buffer)} samples, T_full(local)={T_full}, "
f"stable_T(local)={stable_T}, base_T={self.base_T}, "
f"stable_T(global)={stable_T_global}, last_emitted={self.last_emitted_T}"
)
if stable_T_global >= required_stable_g or is_last_chunk:
emit_start_g = max(0, core_start_g - self.cnn_redundancy_frames)
emit_end_g = core_end_g + self.cnn_redundancy_frames
# 全局 -> 局部索引
emit_start = max(0, emit_start_g - self.base_T)
emit_end = emit_end_g - self.base_T
emit_start = max(0, min(emit_start, T_full))
emit_end = max(emit_start, min(emit_end, T_full))
mel_output = mel_full[:, :, emit_start:emit_end]
self.last_emitted_T = core_end_g # 仅推进核心帧指针(全局)
else:
mel_output = mel_full[:, :, 0:0]
self.total_samples_processed += len(audio_chunk)
self.is_first = False
info = {
"type": "exact_chunk",
"chunk_number": self.chunk_count,
"emitted_frames": mel_output.shape[-1],
"stable_T": stable_T,
"T_full": T_full,
"base_T": self.base_T,
"stable_T_global": stable_T_global,
"buffer_len_samples": int(self.buffer.shape[0]),
"left_samples_dropped": self.left_samples_dropped,
"core_start": core_start_g, # 如果保留原字段名,这里用全局值
"core_end": core_end_g, # 同上
}
return mel_output, info
def flush(self) -> torch.Tensor:
"""在流结束时调用,输出剩余未发出的帧,保证与离线一致(按全局坐标计算)。"""
if len(self.buffer) == 0:
return torch.zeros(1, 80, 0)
mel_full = self._extract_full()
T_local = mel_full.shape[-1]
T_global = self.base_T + T_local
if self.last_emitted_T < T_global:
start_l = max(0, self.last_emitted_T - self.base_T)
tail = mel_full[:, :, start_l:]
self.last_emitted_T = T_global
if self.verbose:
print(f"[Exact] flush {tail.shape[-1]} frames (T_global={T_global})")
return tail
return mel_full[:, :, 0:0]
def get_config(self) -> Dict:
return {
"chunk_ms": self.chunk_ms,
"first_chunk_ms": self.first_chunk_ms,
"effective_first_chunk_ms": self.first_chunk_samples / self.sample_rate * 1000.0,
"sample_rate": self.sample_rate,
"n_fft": self.n_fft,
"hop_length": self.hop_length,
"cnn_redundancy_ms": self.cnn_redundancy_ms,
"cnn_redundancy_frames": self.cnn_redundancy_frames,
"enable_sliding_window": self.enable_sliding_window,
"trigger_seconds": self.trigger_seconds,
"slide_seconds": self.slide_seconds,
}
def get_state(self) -> Dict:
return {
"chunk_count": self.chunk_count,
"last_emitted_T": self.last_emitted_T,
"total_samples_processed": self.total_samples_processed,
"buffer_len": int(self.buffer.shape[0]),
"base_T": self.base_T,
"left_samples_dropped": self.left_samples_dropped,
}
def get_snapshot(self) -> Dict:
"""获取完整状态快照(包括 buffer),用于抢跑恢复
Returns:
包含完整状态的字典,可用于 restore_snapshot 恢复
"""
buffer_copy = self.buffer.copy()
snapshot = {
"chunk_count": self.chunk_count,
"last_emitted_T": self.last_emitted_T,
"total_samples_processed": self.total_samples_processed,
"buffer": buffer_copy,
"base_T": self.base_T,
"left_samples_dropped": self.left_samples_dropped,
"is_first": self.is_first,
# 保存 feature_extractor 的状态(关键:确保 mel 特征提取的确定性)
"fe_dynamic_log_norm": getattr(self.feature_extractor, "dynamic_log_norm", None),
"fe_dynamic_range_db": getattr(self.feature_extractor, "dynamic_range_db", None),
"fe_log_floor_db": getattr(self.feature_extractor, "log_floor_db", None),
}
logger.debug(
"[MelProcessor] Created snapshot: chunk_count=%d, last_emitted_T=%d, "
"buffer_len=%d, buffer_sum=%.6f, total_samples=%d",
self.chunk_count,
self.last_emitted_T,
len(buffer_copy),
float(buffer_copy.sum()) if len(buffer_copy) > 0 else 0.0,
self.total_samples_processed,
)
return snapshot
def restore_snapshot(self, snapshot: Dict) -> None:
"""从快照恢复状态
Args:
snapshot: 由 get_snapshot 返回的快照字典
"""
# 记录恢复前的状态
prev_state = {
"chunk_count": self.chunk_count,
"last_emitted_T": self.last_emitted_T,
"buffer_len": len(self.buffer),
}
# 恢复状态
self.chunk_count = snapshot["chunk_count"]
self.last_emitted_T = snapshot["last_emitted_T"]
self.total_samples_processed = snapshot["total_samples_processed"]
self.buffer = snapshot["buffer"].copy() # 复制 buffer
self.base_T = snapshot["base_T"]
self.left_samples_dropped = snapshot["left_samples_dropped"]
self.is_first = snapshot["is_first"]
# 恢复 feature_extractor 的状态(关键:确保 mel 特征提取的确定性)
if snapshot.get("fe_dynamic_log_norm") is not None:
self.feature_extractor.dynamic_log_norm = snapshot["fe_dynamic_log_norm"]
if snapshot.get("fe_dynamic_range_db") is not None:
self.feature_extractor.dynamic_range_db = snapshot["fe_dynamic_range_db"]
if snapshot.get("fe_log_floor_db") is not None:
self.feature_extractor.log_floor_db = snapshot["fe_log_floor_db"]
logger.info(
"[MelProcessor] Restored snapshot: chunk_count %d->%d, last_emitted_T %d->%d, "
"buffer_len %d->%d, total_samples=%d",
prev_state["chunk_count"],
self.chunk_count,
prev_state["last_emitted_T"],
self.last_emitted_T,
prev_state["buffer_len"],
len(self.buffer),
self.total_samples_processed,
)