| | |
| | |
| |
|
| | """ |
| | 流式Mel特征处理器 |
| | |
| | 用于实时音频流的Mel频谱特征提取,支持chunk-based处理。 |
| | 支持配置CNN冗余以保证与离线处理的一致性。 |
| | """ |
| |
|
| | import logging |
| | from typing import Dict |
| | from typing import Optional |
| | from typing import Tuple |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | from .processing_audio_minicpma import MiniCPMAAudioProcessor |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class StreamingMelProcessorExact:
    """Streaming Mel processor that is strictly equivalent to offline extraction.

    Strategy:
    - Accumulate all audio seen so far into a buffer; after each new chunk,
      recompute the mel spectrogram of the whole buffer with the same
      ``feature_extractor`` used for offline processing.
    - Emit only "stable" frames, i.e. frames whose center does not depend on
      future (right-hand) context: ``center + n_fft // 2 <= buffer length``.
    - At end of stream, ``flush()`` emits the remaining frames so the
      concatenation of everything emitted matches one offline pass exactly.

    Cost: one feature extraction over the accumulated buffer per chunk
    (could be optimized into an incremental scheme if needed).
    """

    def __init__(
        self,
        feature_extractor: "MiniCPMAAudioProcessor",
        chunk_ms: int = 100,
        first_chunk_ms: Optional[int] = None,
        sample_rate: int = 16000,
        n_fft: int = 400,
        hop_length: int = 160,
        n_mels: int = 80,
        verbose: bool = False,
        cnn_redundancy_ms: int = 10,
        enable_sliding_window: bool = False,
        slide_trigger_seconds: float = 30.0,
        slide_stride_seconds: float = 10.0,
    ):
        """
        Args:
            feature_extractor: Offline-compatible mel feature extractor.
            chunk_ms: Nominal chunk duration in milliseconds.
            first_chunk_ms: Duration of the first chunk; defaults to ``chunk_ms``.
            sample_rate: Audio sample rate in Hz.
            n_fft: STFT window size in samples.
            hop_length: STFT hop size in samples (one mel frame per hop).
            n_mels: Number of mel bins (used by ``flush`` for the empty output).
            verbose: Print per-chunk debug information when True.
            cnn_redundancy_ms: Extra left/right context (in ms) emitted around
                each chunk so a downstream CNN sees the same receptive field
                as in offline processing.
            enable_sliding_window: When True, drop old audio once the buffer
                exceeds ``slide_trigger_seconds`` to bound memory/compute.
            slide_trigger_seconds: Buffer length (seconds) that triggers a slide.
            slide_stride_seconds: Amount of audio (seconds) dropped per slide.
        """
        self.feature_extractor = feature_extractor
        self.chunk_ms = chunk_ms
        self.first_chunk_ms = first_chunk_ms if first_chunk_ms is not None else chunk_ms
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.verbose = verbose

        self.chunk_samples = int(round(chunk_ms * sample_rate / 1000))
        self.chunk_frames = self.chunk_samples // hop_length

        # Align the first chunk down to a whole number of hops (at least one
        # hop) so frame boundaries stay on the offline frame grid.
        hop = self.hop_length
        raw_first_samples = int(round(self.first_chunk_ms * sample_rate / 1000))
        self.first_chunk_samples = max(hop, (raw_first_samples // hop) * hop)
        self.half_window = n_fft // 2  # right context a frame center needs

        self.cnn_redundancy_ms = cnn_redundancy_ms
        self.cnn_redundancy_samples = int(cnn_redundancy_ms * sample_rate / 1000)
        self.cnn_redundancy_frames = max(0, self.cnn_redundancy_samples // hop_length)

        self.enable_sliding_window = enable_sliding_window
        self.trigger_seconds = slide_trigger_seconds
        self.slide_seconds = slide_stride_seconds

        # Global-coordinate bookkeeping for the sliding window: how many
        # samples/frames have been dropped from the left of the buffer.
        self.left_samples_dropped = 0
        self.base_T = 0

        self.reset()

    def reset(self):
        """Reset all streaming state to the beginning of a new stream."""
        self.buffer = np.zeros(0, dtype=np.float32)
        self.last_emitted_T = 0  # global index of the first not-yet-emitted frame
        self.total_samples_processed = 0
        self.chunk_count = 0
        self.is_first = True
        self.left_samples_dropped = 0
        self.base_T = 0

    def get_chunk_size(self) -> int:
        """Return the expected size (in samples) of the next input chunk."""
        return self.first_chunk_samples if self.is_first else self.chunk_samples

    def get_expected_output_frames(self) -> int:
        """Not supported by the exact processor (output size is data-dependent)."""
        raise NotImplementedError("get_expected_output_frames is not implemented")

    def _extract_full(self) -> torch.Tensor:
        """Run the offline feature extractor over the whole current buffer.

        Returns:
            Mel features tensor; last dimension is the frame axis.

        Raises:
            ValueError: If the buffer is shorter than one STFT window.
        """
        if len(self.buffer) < self.n_fft:
            raise ValueError(f"buffer length is shorter than n_fft {len(self.buffer)} < {self.n_fft}")

        # NOTE(review): short buffers (<5s) use a fixed log floor, longer ones
        # a dynamic range — presumably to stabilize normalization early in the
        # stream; exact semantics live in the extractor's set_spac_log_norm.
        if len(self.buffer) < 5 * self.sample_rate:
            self.feature_extractor.set_spac_log_norm(log_floor_db=-10)
        else:
            self.feature_extractor.set_spac_log_norm(dynamic_range_db=8)
        feats = self.feature_extractor(
            self.buffer,
            sampling_rate=self.sample_rate,
            return_tensors="pt",
            padding=False,
        )
        return feats.input_features

    def _stable_frames_count(self) -> int:
        """Number of frames whose centers need no future (right-side) samples."""
        L = int(self.buffer.shape[0])
        if L <= 0:
            return 0
        if L < self.half_window:
            return 0
        return max(0, (L - self.half_window) // self.hop_length + 1)

    def _maybe_slide_buffer(self):
        """Trigger-mode sliding window: once the buffer reaches the trigger
        threshold, drop a fixed stride of audio from its left edge.

        The drop is hop-aligned so ``base_T`` (frames dropped) stays exact,
        and at least ``min_keep_seconds`` of audio is always retained.
        """
        if not self.enable_sliding_window:
            return

        sr = self.sample_rate
        hop = self.hop_length
        L = len(self.buffer)

        trigger_samples = int(self.trigger_seconds * sr)
        stride_samples = int(self.slide_seconds * sr)

        if L < trigger_samples:
            return

        drop = stride_samples

        # Always keep at least this much audio in the buffer.
        min_keep_seconds = 1.0
        min_keep_samples = int(min_keep_seconds * sr)

        # Clamp with max(0, ...) so a degenerate config (stride > trigger)
        # cannot produce a negative guard.
        guard_samples = min(min_keep_samples, max(0, L - drop))

        max_allowed_drop = max(0, L - guard_samples)
        drop = min(drop, max_allowed_drop)
        drop = (drop // hop) * hop  # hop-align so base_T stays exact

        if drop <= 0:
            return

        self.buffer = self.buffer[drop:]
        self.left_samples_dropped += drop
        self.base_T += drop // hop

        if self.verbose:
            print(
                f"[Slide] Trigger模式: drop={drop/sr:.2f}s samples, base_T={self.base_T}, buffer_after={len(self.buffer)/sr:.2f}s"
            )

    def process(self, audio_chunk: np.ndarray, is_last_chunk: bool = False) -> Tuple[torch.Tensor, Dict]:
        """Ingest one audio chunk and emit any newly stabilized mel frames.

        Args:
            audio_chunk: 1-D float audio samples at ``self.sample_rate``.
            is_last_chunk: Emit even if not enough stable context accumulated.

        Returns:
            Tuple of (mel tensor slice over the frame axis, info dict). The
            slice may be empty when not enough stable frames exist yet.
        """
        self.chunk_count += 1
        if len(self.buffer) == 0:
            self.buffer = audio_chunk.astype(np.float32, copy=True)
        else:
            self.buffer = np.concatenate([self.buffer, audio_chunk.astype(np.float32, copy=True)])

        self._maybe_slide_buffer()

        mel_full = self._extract_full()
        T_full = mel_full.shape[-1]
        stable_T = min(T_full, self._stable_frames_count())
        stable_T_global = self.base_T + stable_T

        # Core frames for this chunk in global coordinates, plus the CNN
        # redundancy margin that must be stable before we emit.
        core_start_g = self.last_emitted_T
        core_end_g = core_start_g + self.chunk_frames
        required_stable_g = core_end_g + self.cnn_redundancy_frames

        if self.verbose:
            print(
                f"[Exact] buffer_len={len(self.buffer)} samples, T_full(local)={T_full}, "
                f"stable_T(local)={stable_T}, base_T={self.base_T}, "
                f"stable_T(global)={stable_T_global}, last_emitted={self.last_emitted_T}"
            )

        if stable_T_global >= required_stable_g or is_last_chunk:
            # Emit the core frames plus redundancy context on both sides,
            # converting global frame indices to local buffer indices.
            emit_start_g = max(0, core_start_g - self.cnn_redundancy_frames)
            emit_end_g = core_end_g + self.cnn_redundancy_frames

            emit_start = max(0, emit_start_g - self.base_T)
            emit_end = emit_end_g - self.base_T
            emit_start = max(0, min(emit_start, T_full))
            emit_end = max(emit_start, min(emit_end, T_full))

            mel_output = mel_full[:, :, emit_start:emit_end]
            self.last_emitted_T = core_end_g
        else:
            mel_output = mel_full[:, :, 0:0]  # empty slice: wait for more audio

        self.total_samples_processed += len(audio_chunk)
        self.is_first = False

        info = {
            "type": "exact_chunk",
            "chunk_number": self.chunk_count,
            "emitted_frames": mel_output.shape[-1],
            "stable_T": stable_T,
            "T_full": T_full,
            "base_T": self.base_T,
            "stable_T_global": stable_T_global,
            "buffer_len_samples": int(self.buffer.shape[0]),
            "left_samples_dropped": self.left_samples_dropped,
            "core_start": core_start_g,
            "core_end": core_end_g,
        }
        return mel_output, info

    def flush(self) -> torch.Tensor:
        """Emit all remaining frames at end of stream (global coordinates),
        guaranteeing exact equivalence with one offline pass."""
        if len(self.buffer) == 0:
            # Fix: honor the configured mel-bin count instead of hard-coding 80.
            return torch.zeros(1, self.n_mels, 0)

        mel_full = self._extract_full()
        T_local = mel_full.shape[-1]
        T_global = self.base_T + T_local

        if self.last_emitted_T < T_global:
            start_l = max(0, self.last_emitted_T - self.base_T)
            tail = mel_full[:, :, start_l:]
            self.last_emitted_T = T_global
            if self.verbose:
                print(f"[Exact] flush {tail.shape[-1]} frames (T_global={T_global})")
            return tail
        return mel_full[:, :, 0:0]

    def get_config(self) -> Dict:
        """Return the static configuration of this processor."""
        return {
            "chunk_ms": self.chunk_ms,
            "first_chunk_ms": self.first_chunk_ms,
            "effective_first_chunk_ms": self.first_chunk_samples / self.sample_rate * 1000.0,
            "sample_rate": self.sample_rate,
            "n_fft": self.n_fft,
            "hop_length": self.hop_length,
            "cnn_redundancy_ms": self.cnn_redundancy_ms,
            "cnn_redundancy_frames": self.cnn_redundancy_frames,
            "enable_sliding_window": self.enable_sliding_window,
            "trigger_seconds": self.trigger_seconds,
            "slide_seconds": self.slide_seconds,
        }

    def get_state(self) -> Dict:
        """Return lightweight mutable state (no buffer copy) for monitoring."""
        return {
            "chunk_count": self.chunk_count,
            "last_emitted_T": self.last_emitted_T,
            "total_samples_processed": self.total_samples_processed,
            "buffer_len": int(self.buffer.shape[0]),
            "base_T": self.base_T,
            "left_samples_dropped": self.left_samples_dropped,
        }

    def get_snapshot(self) -> Dict:
        """Take a full state snapshot (including the buffer) for speculative
        decoding / rollback.

        Returns:
            Dict restorable via ``restore_snapshot``. Also captures the
            feature extractor's log-norm settings when present, since
            ``_extract_full`` mutates them as the buffer grows.
        """
        buffer_copy = self.buffer.copy()
        snapshot = {
            "chunk_count": self.chunk_count,
            "last_emitted_T": self.last_emitted_T,
            "total_samples_processed": self.total_samples_processed,
            "buffer": buffer_copy,
            "base_T": self.base_T,
            "left_samples_dropped": self.left_samples_dropped,
            "is_first": self.is_first,
            "fe_dynamic_log_norm": getattr(self.feature_extractor, "dynamic_log_norm", None),
            "fe_dynamic_range_db": getattr(self.feature_extractor, "dynamic_range_db", None),
            "fe_log_floor_db": getattr(self.feature_extractor, "log_floor_db", None),
        }
        logger.debug(
            "[MelProcessor] Created snapshot: chunk_count=%d, last_emitted_T=%d, "
            "buffer_len=%d, buffer_sum=%.6f, total_samples=%d",
            self.chunk_count,
            self.last_emitted_T,
            len(buffer_copy),
            float(buffer_copy.sum()) if len(buffer_copy) > 0 else 0.0,
            self.total_samples_processed,
        )
        return snapshot

    def restore_snapshot(self, snapshot: Dict) -> None:
        """Restore state from a snapshot.

        Args:
            snapshot: Dict returned by ``get_snapshot``.
        """
        prev_state = {
            "chunk_count": self.chunk_count,
            "last_emitted_T": self.last_emitted_T,
            "buffer_len": len(self.buffer),
        }

        self.chunk_count = snapshot["chunk_count"]
        self.last_emitted_T = snapshot["last_emitted_T"]
        self.total_samples_processed = snapshot["total_samples_processed"]
        self.buffer = snapshot["buffer"].copy()
        self.base_T = snapshot["base_T"]
        self.left_samples_dropped = snapshot["left_samples_dropped"]
        self.is_first = snapshot["is_first"]

        # Restore the extractor's normalization knobs only when they were set.
        if snapshot.get("fe_dynamic_log_norm") is not None:
            self.feature_extractor.dynamic_log_norm = snapshot["fe_dynamic_log_norm"]
        if snapshot.get("fe_dynamic_range_db") is not None:
            self.feature_extractor.dynamic_range_db = snapshot["fe_dynamic_range_db"]
        if snapshot.get("fe_log_floor_db") is not None:
            self.feature_extractor.log_floor_db = snapshot["fe_log_floor_db"]

        logger.info(
            "[MelProcessor] Restored snapshot: chunk_count %d->%d, last_emitted_T %d->%d, "
            "buffer_len %d->%d, total_samples=%d",
            prev_state["chunk_count"],
            self.chunk_count,
            prev_state["last_emitted_T"],
            self.last_emitted_T,
            prev_state["buffer_len"],
            len(self.buffer),
            self.total_samples_processed,
        )
| |
|