#!/usr/bin/env python3 # License: CC-BY-NC-ND-4.0 # Created by: Patrick Lumbantobing, Vertox-AI # Copyright (c) 2026 Vertox-AI. All rights reserved. # # This work is licensed under the Creative Commons # Attribution-NonCommercial-NoDerivatives 4.0 International License. # To view a copy of this license, visit # http://creativecommons.org/licenses/by-nc-nd/4.0/ """ Cache-aware streaming audio and feature buffers for Nemotron ASR. Adapted from: https://github.com/NVIDIA-NeMo/NeMo/tree/main Implements: - :class:`CacheAwareStreamingAudioBuffer` for audio → feature chunks compatible with NeMo cache-aware encoders. - :class:`CacheAwareStreamingASR` for encoder/decoder state management, hypothesis accumulation, and timestamped text output. """ from __future__ import annotations import re from collections.abc import Iterable from typing import Generator, List, Optional import numpy as np import numpy.typing as npt from src.asr.cache_aware_modules_config import (CacheAwareStreamingConfig, TimestampedResult) from src.asr.utils import log_softmax LOG_ZERO_GUARD_VALUE = float(2**-24) class CacheAwareStreamingAudioBuffer: """ Streaming audio and feature buffer for cache-aware ASR. Handles: - Chunking raw audio into overlapping frames for the preprocessor. - Dropping padded STFT frames after the first chunk. - Maintaining a feature buffer with pre-encode cache appended. """ def __init__(self, preprocessor, streaming_cfg: CacheAwareStreamingConfig) -> None: """ Parameters ---------- preprocessor : Callable that maps ``(waveforms, lengths)`` to ``(features, feature_lengths)``. streaming_cfg : Cache-aware streaming configuration. """ self._preprocessor = preprocessor self._streaming_cfg = streaming_cfg self.audio_buffer: Optional[npt.NDArray[np.float32]] = None self.audio_step: int = 0 self.features_buffer: Optional[npt.NDArray[np.float32]] = None self._audio_chunks_lens = np.array( [self._streaming_cfg.audio_chunk_frames * self._streaming_cfg.audio_frame_size], dtype=np.int64, ) self._audio_frames_drops_lens = ( self._streaming_cfg.audio_chunk_frames_drop * self._streaming_cfg.audio_frame_size ) self._features_frames_takes_lens = self._streaming_cfg.audio_chunk_frames - 1 self._chunk_size = self._streaming_cfg.chunk_size[1] self._shift_size = self._streaming_cfg.shift_size[1] self._pre_encode_cache_size = self._streaming_cfg.pre_encode_cache_size[1] self._cache_chunk_size = self._pre_encode_cache_size + self._chunk_size self._features_chunk_lengths = np.array([self._cache_chunk_size], dtype=np.int64) self._current_text: str = "" self._first_cache_pre_encode = np.log( np.zeros( (1, self._streaming_cfg.input_features, self._pre_encode_cache_size), dtype=np.float32, ) + LOG_ZERO_GUARD_VALUE ) def len_audio_buffer(self) -> int: """Return current audio buffer length (samples).""" return int(self.audio_buffer.shape[-1]) if self.audio_buffer is not None else 0 def len_features_buffer(self) -> int: """Return current feature buffer length (frames).""" return int(self.features_buffer.shape[-1]) if self.features_buffer is not None else 0 def reset_buffers(self) -> None: """Reset both audio and feature buffers.""" self.reset_audio_buffer() self.reset_features_buffer() def reset_audio_buffer(self) -> None: """Reset audio buffer and step counter.""" self.audio_buffer = None self.audio_step = 0 def reset_features_buffer(self) -> None: """Reset feature buffer.""" self.features_buffer = None def append_audio_buffer(self, audio_signal: npt.NDArray[np.float32]) -> None: """Append new audio samples to the buffer.""" if self.audio_buffer is None: self.audio_buffer = audio_signal else: self.audio_buffer = np.concatenate((self.audio_buffer, audio_signal), axis=-1).astype(np.float32) def process_audio_buffer( self, last: bool = False, ) -> Generator[Optional[npt.NDArray[np.float32]], None, None]: """ Convert buffered audio into feature chunks. Yields ------ np.ndarray or None Feature chunks of shape ``(1, feats, frames)`` or ``None`` when no more chunks are available. """ if self.audio_buffer is None: if last: yield None return while self._audio_chunks_lens[0] <= self.audio_buffer.shape[-1]: audio_chunks = self.audio_buffer[:, : self._audio_chunks_lens[0]] audio_features, _ = self._preprocessor(audio_chunks, self._audio_chunks_lens) self.audio_buffer = self.audio_buffer[:, self._audio_frames_drops_lens :] if self.audio_step > 0: audio_features = audio_features[ :, :, self._streaming_cfg.audio_chunk_frames_drop : self._features_frames_takes_lens, ] else: audio_features = audio_features[:, :, : self._features_frames_takes_lens] self.audio_step += self._audio_frames_drops_lens yield audio_features if last and self.audio_buffer is not None and self.audio_buffer.shape[-1] > 0: n_pad = self._audio_chunks_lens[0] - self.audio_buffer.shape[-1] zeros_pad = np.zeros((1, n_pad), dtype=np.float32) self.audio_buffer = np.concatenate((self.audio_buffer, zeros_pad), axis=-1).astype(np.float32) audio_chunks = self.audio_buffer[:, : self._audio_chunks_lens[0]] audio_features, _ = self._preprocessor(audio_chunks, self._audio_chunks_lens) self.audio_buffer = self.audio_buffer[:, self._audio_chunks_lens[0] :] if self.audio_step > 0: yield audio_features[:, :, self._streaming_cfg.audio_chunk_frames_drop :] else: yield audio_features self.reset_audio_buffer() yield None def append_audio_buffer_to_process_for_features( self, audio_signal: npt.NDArray[np.float32], last: bool = False, ) -> Generator[Optional[npt.NDArray[np.float32]], None, None]: """Append audio and immediately yield any ready feature chunks.""" self.append_audio_buffer(audio_signal) return self.process_audio_buffer(last=last) def append_features_buffer(self, audio_features: npt.NDArray[np.float32]) -> None: """Append new feature frames, preprending initial pre-encode cache if needed.""" if self.features_buffer is None: self.features_buffer = np.concatenate((self._first_cache_pre_encode, audio_features), axis=-1).astype( np.float32 ) else: self.features_buffer = np.concatenate((self.features_buffer, audio_features), axis=-1).astype(np.float32) def process_features_buffer( self, last: bool = False, ) -> Generator[Optional[npt.NDArray[np.float32]], None, None]: """ Convert feature buffer into encoder-ready feature chunks. Yields ------ np.ndarray or None Feature chunks of shape ``(1, feats, cache_chunk_size)`` or ``None`` when no more chunks are available. """ if self.features_buffer is None: if last: yield None return while self._cache_chunk_size <= self.features_buffer.shape[-1]: features_chunk = self.features_buffer[:, :, : self._cache_chunk_size] self.features_buffer = self.features_buffer[:, :, self._shift_size :] yield features_chunk if last and self.features_buffer.shape[-1] > 0: n_pad = self._cache_chunk_size - self.features_buffer.shape[-1] zeros_pad = np.log( np.zeros( (1, self.features_buffer.shape[1], n_pad), dtype=np.float32, ) + LOG_ZERO_GUARD_VALUE ) features_chunk = np.concatenate((self.features_buffer, zeros_pad), axis=-1).astype(np.float32) self.features_buffer = self.features_buffer[:, :, self._cache_chunk_size :] yield features_chunk self.reset_features_buffer() yield None def append_features_buffer_to_process_for_features_chunk( self, audio_features: npt.NDArray[np.float32], last: bool = False, ) -> Generator[Optional[npt.NDArray[np.float32]], None, None]: """Append features and immediately yield any ready feature chunks.""" self.append_features_buffer(audio_features) return self.process_features_buffer(last=last) class CacheAwareStreamingASR: """ Cache-aware streaming ASR wrapper around encoder/decoder ONNX models. Maintains encoder caches, decoder recurrent state, and an evolving hypothesis (tokens, timestamps, logprobs), producing incremental :class:`TimestampedResult` objects from feature chunks. """ def __init__( self, asr_encoder, asr_decoder, vocab: List[int], blank_idx: int, streaming_cfg: CacheAwareStreamingConfig, ) -> None: """ Parameters ---------- asr_encoder : ONNX Runtime session for the cache-aware encoder. asr_decoder : ONNX Runtime session for the decoder/joint network. vocab : Mapping from token IDs to text pieces. blank_idx : Index of the blank label in the vocabulary. streaming_cfg : Cache-aware streaming configuration. """ self._asr_encoder = asr_encoder self._asr_decoder = asr_decoder self._vocab = vocab self._vocab_size = len(self._vocab) self._blank_idx = blank_idx self._streaming_cfg = streaming_cfg # encoder cache self._cache_last_channel: npt.NDArray[np.float32] | None = None self._cache_last_time: npt.NDArray[np.float32] | None = None self._cache_last_channel_len: npt.NDArray[np.int64] | None = None self.set_init_encoder_cache() # encoder lengths self._chunk_size = self._streaming_cfg.chunk_size[1] self._pre_encode_cache_size = self._streaming_cfg.pre_encode_cache_size[1] self._cache_chunk_size = self._pre_encode_cache_size + self._chunk_size self._features_chunk_lengths = np.array([self._cache_chunk_size], dtype=np.int64) self._encoder_out_lengths = np.array( [self._streaming_cfg.valid_encoder_out_len], dtype=np.int64, ) # decoder state self._prev_state: tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]] | None = None self._tokens: List[int] | None = None self._timestamps: List[int] | None = None self._logprobs: List[float] | None = None self._t_index: int | None = None self.set_init_decoder_state() self.set_init_decoder_vars() self._current_text: str = "" self._DECODE_SPACE_PATTERN = re.compile(r"\A\s|\s\B|(\s)\b") def set_init_encoder_cache(self) -> None: """Initialise encoder caches to zeros.""" self._cache_last_channel = np.zeros( ( self._streaming_cfg.len_layers, 1, self._streaming_cfg.last_channel_cache_size, self._streaming_cfg.d_model, ), dtype=np.float32, ).transpose(1, 0, 2, 3) self._cache_last_time = np.zeros( ( self._streaming_cfg.len_layers, 1, self._streaming_cfg.d_model, self._streaming_cfg.conv_context_size[0], ), dtype=np.float32, ).transpose(1, 0, 2, 3) self._cache_last_channel_len = np.zeros(1, dtype=np.int64) def set_init_decoder_state(self) -> None: """Initialise decoder hidden states to zeros based on input shapes.""" shapes = {x.name: x.shape for x in self._asr_decoder.get_inputs()} self._prev_state = ( np.zeros( shape=(shapes["input_states_1"][0], 1, shapes["input_states_1"][2]), dtype=np.float32, ), np.zeros( shape=(shapes["input_states_2"][0], 1, shapes["input_states_2"][2]), dtype=np.float32, ), ) def set_init_decoder_vars(self) -> None: """Reset token, timestamp, logprob lists and time index.""" self._tokens = [] self._timestamps = [] self._logprobs = [] self._t_index = 0 def reset_states(self) -> None: """Reset encoder cache, decoder state, and current text.""" self.set_init_encoder_cache() self.set_init_decoder_state() self.set_init_decoder_vars() self._current_text = "" def process_encoder_step( self, features_chunk: npt.NDArray[np.float32], ) -> npt.NDArray[np.float32]: """ Run one encoder step with cache-aware inputs. Returns ------- encoder_out: ``(batch, time, dimension)`` """ assert self._features_chunk_lengths[0] == features_chunk.shape[-1] ( encoder_out, encoder_out_lens, cache_last_channel_next, cache_last_time_next, cache_last_channel_next_len, ) = self._asr_encoder.run( [ "outputs", "encoded_lengths", "cache_last_channel_next", "cache_last_time_next", "cache_last_channel_next_len", ], { "audio_signal": features_chunk, "length": self._features_chunk_lengths, "cache_last_channel": self._cache_last_channel, "cache_last_time": self._cache_last_time, "cache_last_channel_len": self._cache_last_channel_len, }, ) self._cache_last_channel = cache_last_channel_next self._cache_last_time = cache_last_time_next self._cache_last_channel_len = cache_last_channel_next_len return encoder_out.transpose(0, 2, 1) def _decode_tokens( self, ids: Iterable[int], indices: Iterable[int] | None, logprobs: Iterable[float] | None ) -> TimestampedResult: """ Decode token ids including timestamps, running text, and text delta. Returns ------- TimestampedResult: contains running text, timestamps, all tokens, all logprobs, and text delta """ tokens = [self._vocab[i] for i in ids] text = re.sub(self._DECODE_SPACE_PATTERN, lambda x: " " if x.group(1) else "", "".join(tokens)) n_added_chars = len(text) - len(self._current_text) added_text = text[-n_added_chars:] if n_added_chars > 0 else "" timestamps = ( None if indices is None else ( self._streaming_cfg.window_step * self._streaming_cfg.subsampling_factor * np.asarray(indices) ).tolist() ) return TimestampedResult( text, timestamps, tokens, None if logprobs is None else np.asarray(logprobs).tolist(), added_text ) def process_decoder_step(self, encoder_out): """ Run decoder steps with chunked encoder output. Returns ------- text: string full transcript from the start added_text: string text delta """ encodings = encoder_out[0] encodings_len = self._encoder_out_lengths[0] assert encodings_len == encodings.shape[0] step = 0 emitted_tokens = 0 while step < encodings_len: outputs, state1, state2 = self._asr_decoder.run( ["outputs", "output_states_1", "output_states_2"], { "encoder_outputs": encodings[step : step + 1, :, None], "targets": [[self._tokens[-1] if self._tokens else self._blank_idx]], "target_length": [1], "input_states_1": self._prev_state[0], "input_states_2": self._prev_state[1], }, ) logits = outputs.squeeze() state = (state1, state2) assert logits.shape[-1] <= self._vocab_size token = logits.argmax() if token != self._blank_idx: self._prev_state = state self._tokens.append(int(token)) self._timestamps.append(self._t_index) emitted_tokens += 1 self._logprobs.append(log_softmax(logits)[token]) if token == self._blank_idx or emitted_tokens == self._streaming_cfg.max_tokens_per_step: self._t_index += 1 emitted_tokens = 0 step += 1 if len(self._tokens) > 0: res = self._decode_tokens(self._tokens, self._timestamps, self._logprobs) self._current_text = res.text return res.text, res.added_text else: return None, None