File size: 14,710 Bytes
be99bcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
流式Mel特征处理器

用于实时音频流的Mel频谱特征提取,支持chunk-based处理。
支持配置CNN冗余以保证与离线处理的一致性。
"""

import logging
from typing import Dict
from typing import Optional
from typing import Tuple

import numpy as np
import torch

from .processing_audio_minicpma import MiniCPMAAudioProcessor

logger = logging.getLogger(__name__)


class StreamingMelProcessorExact:
    """Streaming mel processor that is strictly equivalent to offline extraction.

    Strategy:
    - Accumulate all audio seen so far into a buffer; after each append, run the
      same ``feature_extractor`` over the whole buffer to compute the full mel.
    - Only emit "stable" frames: frames whose center does not depend on future
      (right-side) context, i.e. ``center + n_fft // 2 <= len(buffer)``.
    - At end of stream, ``flush()`` emits the remaining frames, so the
      concatenation of everything emitted matches one offline pass exactly.

    Cost: every call re-extracts features over the accumulated buffer (could be
    optimized into an incremental computation if needed).
    """

    def __init__(
        self,
        feature_extractor: "MiniCPMAAudioProcessor",
        chunk_ms: int = 100,
        first_chunk_ms: Optional[int] = None,
        sample_rate: int = 16000,
        n_fft: int = 400,
        hop_length: int = 160,
        n_mels: int = 80,
        verbose: bool = False,
        cnn_redundancy_ms: int = 10,  # given in ms; typically 10 ms == 1 frame
        # --- sliding-window parameters (trigger mode) ---
        enable_sliding_window: bool = False,  # whether the sliding window is active
        slide_trigger_seconds: float = 30.0,  # buffer length (s) that triggers a slide
        slide_stride_seconds: float = 10.0,  # seconds dropped from the left per slide
    ):
        """Configure the processor.

        Args:
            feature_extractor: mel feature extractor shared with the offline path.
            chunk_ms: size of each streaming chunk in milliseconds.
            first_chunk_ms: size of the first chunk; defaults to ``chunk_ms``.
            sample_rate: audio sample rate in Hz.
            n_fft: STFT window length in samples.
            hop_length: STFT hop in samples (one mel frame per hop).
            n_mels: number of mel bins produced by the extractor.
            verbose: print per-chunk debug info when True.
            cnn_redundancy_ms: extra context (ms) emitted around each core span
                so a downstream CNN sees the same receptive field as offline.
            enable_sliding_window: enable trigger-mode buffer sliding.
            slide_trigger_seconds: buffer duration that triggers a slide.
            slide_stride_seconds: duration dropped from the left on each slide.
        """
        self.feature_extractor = feature_extractor
        self.chunk_ms = chunk_ms
        self.first_chunk_ms = first_chunk_ms if first_chunk_ms is not None else chunk_ms
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.verbose = verbose

        self.chunk_samples = int(round(chunk_ms * sample_rate / 1000))
        self.chunk_frames = self.chunk_samples // hop_length
        # Align the first chunk to a multiple of hop_length so frame boundaries
        # stay aligned with the offline framing.
        hop = self.hop_length
        raw_first_samples = int(round(self.first_chunk_ms * sample_rate / 1000))
        self.first_chunk_samples = max(hop, (raw_first_samples // hop) * hop)
        self.half_window = n_fft // 2  # right-side context a frame needs to be stable

        # Redundancy expressed in frames; with the 10 ms default this is 1 frame.
        self.cnn_redundancy_ms = cnn_redundancy_ms
        self.cnn_redundancy_samples = int(cnn_redundancy_ms * sample_rate / 1000)
        self.cnn_redundancy_frames = max(0, self.cnn_redundancy_samples // hop_length)

        # --- sliding-window configuration (trigger mode) ---
        self.enable_sliding_window = enable_sliding_window
        self.trigger_seconds = slide_trigger_seconds
        self.slide_seconds = slide_stride_seconds

        # --- global frame coordinates / left-drop bookkeeping ---
        self.left_samples_dropped = 0  # samples already dropped from the left
        self.base_T = 0  # global frame index corresponding to mel_full[:, :, 0]

        self.reset()

    def reset(self):
        """Reset all streaming state back to an empty stream."""
        self.buffer = np.zeros(0, dtype=np.float32)
        self.last_emitted_T = 0
        self.total_samples_processed = 0
        self.chunk_count = 0
        self.is_first = True
        self.left_samples_dropped = 0
        self.base_T = 0

    def get_chunk_size(self) -> int:
        """Return the expected chunk size in samples for the next call."""
        return self.first_chunk_samples if self.is_first else self.chunk_samples

    def get_expected_output_frames(self) -> int:
        raise NotImplementedError("get_expected_output_frames is not implemented")

    def _extract_full(self) -> torch.Tensor:
        """Run the feature extractor over the whole buffer; returns [1, n_mels, T].

        Raises:
            ValueError: if the buffer is shorter than ``n_fft``. Whisper-style
                STFT with center=True would fail to pad such input, and no
                stable frame could be emitted anyway.
        """
        if len(self.buffer) < self.n_fft:
            raise ValueError(f"buffer length is shorter than n_fft {len(self.buffer)} < {self.n_fft}")
        # For short buffers (< 5 s) use a fixed log floor; otherwise use a
        # dynamic-range normalization.
        if len(self.buffer) < 5 * self.sample_rate:
            # TODO: these values were picked empirically; run experiments to
            # choose the best setting (see the MiniCPMAAudioProcessor main).
            self.feature_extractor.set_spac_log_norm(log_floor_db=-10)
        else:
            self.feature_extractor.set_spac_log_norm(dynamic_range_db=8)
        feats = self.feature_extractor(
            self.buffer,
            sampling_rate=self.sample_rate,
            return_tensors="pt",
            padding=False,
        )
        return feats.input_features  # [1, 80, T]

    def _stable_frames_count(self) -> int:
        """Number of frames whose right-side context is fully inside the buffer.

        stable = floor((len(buffer) - half_window) / hop) + 1, clamped to >= 0.
        """
        L = int(self.buffer.shape[0])
        if L <= 0:
            return 0
        if L < self.half_window:
            return 0
        return max(0, (L - self.half_window) // self.hop_length + 1)

    def _maybe_slide_buffer(self):
        """Trigger-mode slide: once the buffer reaches the trigger threshold,
        drop a fixed-length window from the left and advance the global base."""
        if not self.enable_sliding_window:
            return

        sr = self.sample_rate
        hop = self.hop_length
        L = len(self.buffer)

        # Convert seconds to samples.
        trigger_samples = int(self.trigger_seconds * sr)
        stride_samples = int(self.slide_seconds * sr)

        # Nothing to do until the trigger threshold is reached.
        if L < trigger_samples:
            return

        # Drop a fixed stride from the left.
        drop = stride_samples

        # Keep a minimum amount of recent context so processing stays
        # continuous across the slide (at least 1 second).
        min_keep_seconds = 1.0
        min_keep_samples = int(min_keep_seconds * sr)

        # guard_samples is the minimum number of samples we must retain.
        guard_samples = min(min_keep_samples, L - drop)

        # Never cross the guard boundary, and align the drop to hop_length so
        # base_T advances by a whole number of frames.
        max_allowed_drop = max(0, L - guard_samples)
        drop = min(drop, max_allowed_drop)
        drop = (drop // hop) * hop

        if drop <= 0:
            return

        # Actually drop and update the global bookkeeping.
        self.buffer = self.buffer[drop:]
        self.left_samples_dropped += drop
        self.base_T += drop // hop

        if self.verbose:
            print(
                f"[Slide] Trigger模式: drop={drop/sr:.2f}s samples, base_T={self.base_T}, buffer_after={len(self.buffer)/sr:.2f}s"
            )

    def process(self, audio_chunk: np.ndarray, is_last_chunk: bool = False) -> Tuple[torch.Tensor, Dict]:
        """Append a chunk and emit any frames that have become stable.

        Args:
            audio_chunk: 1-D float audio samples at ``sample_rate``.
            is_last_chunk: force emission even if stability is not yet reached.

        Returns:
            (mel_output, info): mel_output is [1, n_mels, F] (possibly F == 0);
            info is a debug dict of counters in global frame coordinates.
        """
        self.chunk_count += 1
        # Append the chunk to the accumulated buffer.
        if len(self.buffer) == 0:
            self.buffer = audio_chunk.astype(np.float32, copy=True)
        else:
            self.buffer = np.concatenate([self.buffer, audio_chunk.astype(np.float32, copy=True)])

        # --- possibly slide the buffer (trigger mode) ---
        self._maybe_slide_buffer()

        # Full extraction over the current window.
        mel_full = self._extract_full()
        T_full = mel_full.shape[-1]  # local frame count of the current window
        stable_T = min(T_full, self._stable_frames_count())  # locally stable frames
        stable_T_global = self.base_T + stable_T  # mapped to global frame coordinates

        # Core frame span planned for this emission (global coordinates).
        core_start_g = self.last_emitted_T
        core_end_g = core_start_g + self.chunk_frames
        required_stable_g = core_end_g + self.cnn_redundancy_frames

        if self.verbose:
            print(
                f"[Exact] buffer_len={len(self.buffer)} samples, T_full(local)={T_full}, "
                f"stable_T(local)={stable_T}, base_T={self.base_T}, "
                f"stable_T(global)={stable_T_global}, last_emitted={self.last_emitted_T}"
            )

        if stable_T_global >= required_stable_g or is_last_chunk:
            # Emit the core span plus redundancy context on both sides.
            emit_start_g = max(0, core_start_g - self.cnn_redundancy_frames)
            emit_end_g = core_end_g + self.cnn_redundancy_frames

            # Global -> local index, clamped to the current window.
            emit_start = max(0, emit_start_g - self.base_T)
            emit_end = emit_end_g - self.base_T
            emit_start = max(0, min(emit_start, T_full))
            emit_end = max(emit_start, min(emit_end, T_full))

            mel_output = mel_full[:, :, emit_start:emit_end]
            self.last_emitted_T = core_end_g  # advance only the core pointer (global)
        else:
            mel_output = mel_full[:, :, 0:0]

        self.total_samples_processed += len(audio_chunk)
        self.is_first = False

        info = {
            "type": "exact_chunk",
            "chunk_number": self.chunk_count,
            "emitted_frames": mel_output.shape[-1],
            "stable_T": stable_T,
            "T_full": T_full,
            "base_T": self.base_T,
            "stable_T_global": stable_T_global,
            "buffer_len_samples": int(self.buffer.shape[0]),
            "left_samples_dropped": self.left_samples_dropped,
            "core_start": core_start_g,  # kept under the original key, global value
            "core_end": core_end_g,  # same as above
        }
        return mel_output, info

    def flush(self) -> torch.Tensor:
        """Emit all remaining frames at end of stream (global coordinates),
        guaranteeing exact equivalence with the offline result."""
        if len(self.buffer) == 0:
            # Use the configured mel-bin count (was hard-coded to 80).
            return torch.zeros(1, self.n_mels, 0)

        mel_full = self._extract_full()
        T_local = mel_full.shape[-1]
        T_global = self.base_T + T_local

        if self.last_emitted_T < T_global:
            start_l = max(0, self.last_emitted_T - self.base_T)
            tail = mel_full[:, :, start_l:]
            self.last_emitted_T = T_global
            if self.verbose:
                print(f"[Exact] flush {tail.shape[-1]} frames (T_global={T_global})")
            return tail
        return mel_full[:, :, 0:0]

    def get_config(self) -> Dict:
        """Return the static configuration as a plain dict."""
        return {
            "chunk_ms": self.chunk_ms,
            "first_chunk_ms": self.first_chunk_ms,
            "effective_first_chunk_ms": self.first_chunk_samples / self.sample_rate * 1000.0,
            "sample_rate": self.sample_rate,
            "n_fft": self.n_fft,
            "hop_length": self.hop_length,
            "cnn_redundancy_ms": self.cnn_redundancy_ms,
            "cnn_redundancy_frames": self.cnn_redundancy_frames,
            "enable_sliding_window": self.enable_sliding_window,
            "trigger_seconds": self.trigger_seconds,
            "slide_seconds": self.slide_seconds,
        }

    def get_state(self) -> Dict:
        """Return lightweight mutable state (no buffer contents)."""
        return {
            "chunk_count": self.chunk_count,
            "last_emitted_T": self.last_emitted_T,
            "total_samples_processed": self.total_samples_processed,
            "buffer_len": int(self.buffer.shape[0]),
            "base_T": self.base_T,
            "left_samples_dropped": self.left_samples_dropped,
        }

    def get_snapshot(self) -> Dict:
        """Take a full state snapshot (including the buffer) for speculative
        run-ahead recovery.

        Returns:
            A dict with the complete state, usable with ``restore_snapshot``.
        """
        buffer_copy = self.buffer.copy()
        snapshot = {
            "chunk_count": self.chunk_count,
            "last_emitted_T": self.last_emitted_T,
            "total_samples_processed": self.total_samples_processed,
            "buffer": buffer_copy,
            "base_T": self.base_T,
            "left_samples_dropped": self.left_samples_dropped,
            "is_first": self.is_first,
            # Save the feature_extractor's normalization state (critical for
            # deterministic mel extraction after restore).
            "fe_dynamic_log_norm": getattr(self.feature_extractor, "dynamic_log_norm", None),
            "fe_dynamic_range_db": getattr(self.feature_extractor, "dynamic_range_db", None),
            "fe_log_floor_db": getattr(self.feature_extractor, "log_floor_db", None),
        }
        logger.debug(
            "[MelProcessor] Created snapshot: chunk_count=%d, last_emitted_T=%d, "
            "buffer_len=%d, buffer_sum=%.6f, total_samples=%d",
            self.chunk_count,
            self.last_emitted_T,
            len(buffer_copy),
            float(buffer_copy.sum()) if len(buffer_copy) > 0 else 0.0,
            self.total_samples_processed,
        )
        return snapshot

    def restore_snapshot(self, snapshot: Dict) -> None:
        """Restore state from a snapshot.

        Args:
            snapshot: a dict previously returned by ``get_snapshot``.
        """
        # Remember the pre-restore state for the log message below.
        prev_state = {
            "chunk_count": self.chunk_count,
            "last_emitted_T": self.last_emitted_T,
            "buffer_len": len(self.buffer),
        }

        # Restore scalar state and a copy of the buffer.
        self.chunk_count = snapshot["chunk_count"]
        self.last_emitted_T = snapshot["last_emitted_T"]
        self.total_samples_processed = snapshot["total_samples_processed"]
        self.buffer = snapshot["buffer"].copy()
        self.base_T = snapshot["base_T"]
        self.left_samples_dropped = snapshot["left_samples_dropped"]
        self.is_first = snapshot["is_first"]

        # Restore the feature_extractor's normalization state (critical for
        # deterministic mel extraction after restore).
        if snapshot.get("fe_dynamic_log_norm") is not None:
            self.feature_extractor.dynamic_log_norm = snapshot["fe_dynamic_log_norm"]
        if snapshot.get("fe_dynamic_range_db") is not None:
            self.feature_extractor.dynamic_range_db = snapshot["fe_dynamic_range_db"]
        if snapshot.get("fe_log_floor_db") is not None:
            self.feature_extractor.log_floor_db = snapshot["fe_log_floor_db"]

        logger.info(
            "[MelProcessor] Restored snapshot: chunk_count %d->%d, last_emitted_T %d->%d, "
            "buffer_len %d->%d, total_samples=%d",
            prev_state["chunk_count"],
            self.chunk_count,
            prev_state["last_emitted_T"],
            self.last_emitted_T,
            prev_state["buffer_len"],
            len(self.buffer),
            self.total_samples_processed,
        )