Text-to-Speech
ONNX
GGUF
speech-translation
streaming-speech-translation
speech
audio
speech-recognition
automatic-speech-recognition
streaming-asr
ASR
NeMo
ONNX
cache-aware ASR
FastConformer
RNNT
Parakeet
neural-machine-translation
NMT
gemma3
llama-cpp
GGUF
conversational
TTS
xtts
xttsv2
voice-clone
gpt2
hifigan
multilingual
vq
perceiver-encoder
websocket
#!/usr/bin/env python3
# License: CC-BY-NC-ND-4.0
# Created by: Patrick Lumbantobing, Vertox-AI
# Copyright (c) 2026 Vertox-AI. All rights reserved.
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-NoDerivatives 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-nd/4.0/
"""
Cache-aware streaming audio and feature buffers for Nemotron ASR.

Adapted from: https://github.com/NVIDIA-NeMo/NeMo/tree/main

Implements:

- :class:`CacheAwareStreamingAudioBuffer` for audio → feature chunks
  compatible with NeMo cache-aware encoders.
- :class:`CacheAwareStreamingASR` for encoder/decoder state management,
  hypothesis accumulation, and timestamped text output.
"""
from __future__ import annotations

import re
from collections.abc import Iterable
from typing import Generator, List, Optional

import numpy as np
import numpy.typing as npt

from src.asr.cache_aware_modules_config import (CacheAwareStreamingConfig,
                                                TimestampedResult)
from src.asr.utils import log_softmax

# Additive guard so np.log(0 + guard) stays finite when building all-"silence"
# log-mel padding frames. NOTE(review): 2**-24 presumably mirrors the
# preprocessor's own log_zero_guard_value — confirm against the NeMo config.
LOG_ZERO_GUARD_VALUE = float(2**-24)
class CacheAwareStreamingAudioBuffer:
    """
    Streaming audio and feature buffer for cache-aware ASR.

    Handles:

    - Chunking raw audio into overlapping frames for the preprocessor.
    - Dropping padded STFT frames after the first chunk.
    - Maintaining a feature buffer with pre-encode cache appended.

    Both ``process_*`` generators yield ``None`` as an explicit
    end-of-stream sentinel when called with ``last=True``.
    """

    def __init__(self, preprocessor, streaming_cfg: CacheAwareStreamingConfig) -> None:
        """
        Parameters
        ----------
        preprocessor :
            Callable that maps ``(waveforms, lengths)`` to
            ``(features, feature_lengths)``.
        streaming_cfg :
            Cache-aware streaming configuration.
        """
        self._preprocessor = preprocessor
        self._streaming_cfg = streaming_cfg
        # Raw-sample buffer of shape (1, samples); None until first append.
        self.audio_buffer: Optional[npt.NDArray[np.float32]] = None
        # Number of samples already consumed from the start of the stream.
        self.audio_step: int = 0
        # Log-mel feature buffer of shape (1, feats, frames); None until first append.
        self.features_buffer: Optional[npt.NDArray[np.float32]] = None
        # Samples needed for one preprocessor call (frames * samples-per-frame).
        self._audio_chunks_lens = np.array(
            [self._streaming_cfg.audio_chunk_frames * self._streaming_cfg.audio_frame_size],
            dtype=np.int64,
        )
        # Samples dropped from the buffer per chunk; smaller than the chunk
        # length, so consecutive chunks overlap.
        self._audio_frames_drops_lens = (
            self._streaming_cfg.audio_chunk_frames_drop * self._streaming_cfg.audio_frame_size
        )
        # Keep all but the final STFT frame of each chunk (it is recomputed
        # from the overlap on the next call).
        self._features_frames_takes_lens = self._streaming_cfg.audio_chunk_frames - 1
        # Index [1] selects the streaming (non-initial) sizes — assumes the
        # config stores (first, streaming) pairs; TODO confirm.
        self._chunk_size = self._streaming_cfg.chunk_size[1]
        self._shift_size = self._streaming_cfg.shift_size[1]
        self._pre_encode_cache_size = self._streaming_cfg.pre_encode_cache_size[1]
        # Encoder input length = pre-encode cache + fresh chunk frames.
        self._cache_chunk_size = self._pre_encode_cache_size + self._chunk_size
        self._features_chunk_lengths = np.array([self._cache_chunk_size], dtype=np.int64)
        # NOTE(review): _current_text is never read or written elsewhere in
        # this class — looks like a leftover from CacheAwareStreamingASR.
        self._current_text: str = ""
        # Initial pre-encode cache: log-mel "silence" (log of the guard value)
        # prepended before the very first feature chunk.
        self._first_cache_pre_encode = np.log(
            np.zeros(
                (1, self._streaming_cfg.input_features, self._pre_encode_cache_size),
                dtype=np.float32,
            )
            + LOG_ZERO_GUARD_VALUE
        )

    def len_audio_buffer(self) -> int:
        """Return current audio buffer length (samples)."""
        return int(self.audio_buffer.shape[-1]) if self.audio_buffer is not None else 0

    def len_features_buffer(self) -> int:
        """Return current feature buffer length (frames)."""
        return int(self.features_buffer.shape[-1]) if self.features_buffer is not None else 0

    def reset_buffers(self) -> None:
        """Reset both audio and feature buffers."""
        self.reset_audio_buffer()
        self.reset_features_buffer()

    def reset_audio_buffer(self) -> None:
        """Reset audio buffer and step counter."""
        self.audio_buffer = None
        self.audio_step = 0

    def reset_features_buffer(self) -> None:
        """Reset feature buffer."""
        self.features_buffer = None

    def append_audio_buffer(self, audio_signal: npt.NDArray[np.float32]) -> None:
        """Append new audio samples (shape ``(1, samples)``) to the buffer."""
        if self.audio_buffer is None:
            self.audio_buffer = audio_signal
        else:
            self.audio_buffer = np.concatenate((self.audio_buffer, audio_signal), axis=-1).astype(np.float32)

    def process_audio_buffer(
        self,
        last: bool = False,
    ) -> Generator[Optional[npt.NDArray[np.float32]], None, None]:
        """
        Convert buffered audio into feature chunks.

        Yields
        ------
        np.ndarray or None
            Feature chunks of shape ``(1, feats, frames)`` or ``None`` when
            no more chunks are available.
        """
        if self.audio_buffer is None:
            # Nothing buffered: emit the end-of-stream sentinel only on the
            # final call.
            if last:
                yield None
            return
        # Consume full chunks. Only _audio_frames_drops_lens samples are
        # dropped per iteration, so the buffer retains the overlap needed to
        # recompute boundary STFT frames next time.
        while self._audio_chunks_lens[0] <= self.audio_buffer.shape[-1]:
            audio_chunks = self.audio_buffer[:, : self._audio_chunks_lens[0]]
            audio_features, _ = self._preprocessor(audio_chunks, self._audio_chunks_lens)
            self.audio_buffer = self.audio_buffer[:, self._audio_frames_drops_lens :]
            if self.audio_step > 0:
                # Non-first chunk: skip the frames already emitted by the
                # previous (overlapping) chunk, and drop the padded last frame.
                audio_features = audio_features[
                    :,
                    :,
                    self._streaming_cfg.audio_chunk_frames_drop : self._features_frames_takes_lens,
                ]
            else:
                # First chunk: keep everything except the padded last frame.
                audio_features = audio_features[:, :, : self._features_frames_takes_lens]
            self.audio_step += self._audio_frames_drops_lens
            yield audio_features
        # End of stream: zero-pad the remaining partial chunk, flush it, then
        # reset and emit the sentinel.
        if last and self.audio_buffer is not None and self.audio_buffer.shape[-1] > 0:
            n_pad = self._audio_chunks_lens[0] - self.audio_buffer.shape[-1]
            zeros_pad = np.zeros((1, n_pad), dtype=np.float32)
            self.audio_buffer = np.concatenate((self.audio_buffer, zeros_pad), axis=-1).astype(np.float32)
            audio_chunks = self.audio_buffer[:, : self._audio_chunks_lens[0]]
            audio_features, _ = self._preprocessor(audio_chunks, self._audio_chunks_lens)
            self.audio_buffer = self.audio_buffer[:, self._audio_chunks_lens[0] :]
            if self.audio_step > 0:
                yield audio_features[:, :, self._streaming_cfg.audio_chunk_frames_drop :]
            else:
                yield audio_features
            self.reset_audio_buffer()
            yield None

    def append_audio_buffer_to_process_for_features(
        self,
        audio_signal: npt.NDArray[np.float32],
        last: bool = False,
    ) -> Generator[Optional[npt.NDArray[np.float32]], None, None]:
        """Append audio and immediately yield any ready feature chunks."""
        self.append_audio_buffer(audio_signal)
        return self.process_audio_buffer(last=last)

    def append_features_buffer(self, audio_features: npt.NDArray[np.float32]) -> None:
        """Append new feature frames, prepending the initial pre-encode cache if needed."""
        if self.features_buffer is None:
            # Very first features: prepend the all-"silence" pre-encode cache.
            self.features_buffer = np.concatenate((self._first_cache_pre_encode, audio_features), axis=-1).astype(
                np.float32
            )
        else:
            self.features_buffer = np.concatenate((self.features_buffer, audio_features), axis=-1).astype(np.float32)

    def process_features_buffer(
        self,
        last: bool = False,
    ) -> Generator[Optional[npt.NDArray[np.float32]], None, None]:
        """
        Convert feature buffer into encoder-ready feature chunks.

        Yields
        ------
        np.ndarray or None
            Feature chunks of shape ``(1, feats, cache_chunk_size)`` or
            ``None`` when no more chunks are available.
        """
        if self.features_buffer is None:
            if last:
                yield None
            return
        # Emit full encoder windows; advance by shift_size (< window size),
        # so the pre-encode-cache portion overlaps between windows.
        while self._cache_chunk_size <= self.features_buffer.shape[-1]:
            features_chunk = self.features_buffer[:, :, : self._cache_chunk_size]
            self.features_buffer = self.features_buffer[:, :, self._shift_size :]
            yield features_chunk
        # End of stream: pad the leftover frames with log-"silence" up to one
        # full window, flush, reset, then emit the sentinel.
        if last and self.features_buffer.shape[-1] > 0:
            n_pad = self._cache_chunk_size - self.features_buffer.shape[-1]
            zeros_pad = np.log(
                np.zeros(
                    (1, self.features_buffer.shape[1], n_pad),
                    dtype=np.float32,
                )
                + LOG_ZERO_GUARD_VALUE
            )
            features_chunk = np.concatenate((self.features_buffer, zeros_pad), axis=-1).astype(np.float32)
            # Slicing past the end yields an empty array; the buffer is fully
            # cleared by reset_features_buffer below anyway.
            self.features_buffer = self.features_buffer[:, :, self._cache_chunk_size :]
            yield features_chunk
            self.reset_features_buffer()
            yield None

    def append_features_buffer_to_process_for_features_chunk(
        self,
        audio_features: npt.NDArray[np.float32],
        last: bool = False,
    ) -> Generator[Optional[npt.NDArray[np.float32]], None, None]:
        """Append features and immediately yield any ready feature chunks."""
        self.append_features_buffer(audio_features)
        return self.process_features_buffer(last=last)
class CacheAwareStreamingASR:
    """
    Cache-aware streaming ASR wrapper around encoder/decoder ONNX models.

    Maintains encoder caches, decoder recurrent state, and an evolving
    hypothesis (tokens, timestamps, logprobs), producing incremental
    :class:`TimestampedResult` objects from feature chunks.
    """

    def __init__(
        self,
        asr_encoder,
        asr_decoder,
        vocab: List[str],
        blank_idx: int,
        streaming_cfg: CacheAwareStreamingConfig,
    ) -> None:
        """
        Parameters
        ----------
        asr_encoder :
            ONNX Runtime session for the cache-aware encoder.
        asr_decoder :
            ONNX Runtime session for the decoder/joint network.
        vocab :
            Mapping from token IDs to text pieces.
        blank_idx :
            Index of the blank label in the vocabulary.
        streaming_cfg :
            Cache-aware streaming configuration.
        """
        self._asr_encoder = asr_encoder
        self._asr_decoder = asr_decoder
        self._vocab = vocab
        self._vocab_size = len(self._vocab)
        self._blank_idx = blank_idx
        self._streaming_cfg = streaming_cfg
        # encoder cache
        self._cache_last_channel: npt.NDArray[np.float32] | None = None
        self._cache_last_time: npt.NDArray[np.float32] | None = None
        self._cache_last_channel_len: npt.NDArray[np.int64] | None = None
        self.set_init_encoder_cache()
        # encoder lengths (index [1] selects the streaming sizes, matching
        # CacheAwareStreamingAudioBuffer)
        self._chunk_size = self._streaming_cfg.chunk_size[1]
        self._pre_encode_cache_size = self._streaming_cfg.pre_encode_cache_size[1]
        self._cache_chunk_size = self._pre_encode_cache_size + self._chunk_size
        self._features_chunk_lengths = np.array([self._cache_chunk_size], dtype=np.int64)
        self._encoder_out_lengths = np.array(
            [self._streaming_cfg.valid_encoder_out_len],
            dtype=np.int64,
        )
        # decoder state
        self._prev_state: tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]] | None = None
        self._tokens: List[int] | None = None
        self._timestamps: List[int] | None = None
        self._logprobs: List[float] | None = None
        self._t_index: int | None = None
        self.set_init_decoder_state()
        self.set_init_decoder_vars()
        # Running transcript; used to compute per-step text deltas.
        self._current_text: str = ""
        # Whitespace normalizer for joined token pieces: strip leading space,
        # drop spaces not at a word boundary, collapse boundary spaces to one.
        self._DECODE_SPACE_PATTERN = re.compile(r"\A\s|\s\B|(\s)\b")

    def set_init_encoder_cache(self) -> None:
        """Initialise encoder caches to zeros."""
        # Built layers-first, then transposed to (batch=1, layers, cache, d_model)
        # to match the ONNX encoder's expected cache layout.
        self._cache_last_channel = np.zeros(
            (
                self._streaming_cfg.len_layers,
                1,
                self._streaming_cfg.last_channel_cache_size,
                self._streaming_cfg.d_model,
            ),
            dtype=np.float32,
        ).transpose(1, 0, 2, 3)
        self._cache_last_time = np.zeros(
            (
                self._streaming_cfg.len_layers,
                1,
                self._streaming_cfg.d_model,
                self._streaming_cfg.conv_context_size[0],
            ),
            dtype=np.float32,
        ).transpose(1, 0, 2, 3)
        self._cache_last_channel_len = np.zeros(1, dtype=np.int64)

    def set_init_decoder_state(self) -> None:
        """Initialise decoder hidden states to zeros based on input shapes."""
        # Read the (layers, batch, hidden) dims from the ONNX model's declared
        # input shapes; batch is forced to 1.
        shapes = {x.name: x.shape for x in self._asr_decoder.get_inputs()}
        self._prev_state = (
            np.zeros(
                shape=(shapes["input_states_1"][0], 1, shapes["input_states_1"][2]),
                dtype=np.float32,
            ),
            np.zeros(
                shape=(shapes["input_states_2"][0], 1, shapes["input_states_2"][2]),
                dtype=np.float32,
            ),
        )

    def set_init_decoder_vars(self) -> None:
        """Reset token, timestamp, logprob lists and time index."""
        self._tokens = []
        self._timestamps = []
        self._logprobs = []
        self._t_index = 0

    def reset_states(self) -> None:
        """Reset encoder cache, decoder state, and current text."""
        self.set_init_encoder_cache()
        self.set_init_decoder_state()
        self.set_init_decoder_vars()
        self._current_text = ""

    def process_encoder_step(
        self,
        features_chunk: npt.NDArray[np.float32],
    ) -> npt.NDArray[np.float32]:
        """
        Run one encoder step with cache-aware inputs.

        The three cache tensors are fed in and replaced with the "next"
        caches the encoder returns, carrying state across chunks.

        Returns
        -------
        encoder_out: ``(batch, time, dimension)``
        """
        assert self._features_chunk_lengths[0] == features_chunk.shape[-1]
        (
            encoder_out,
            encoder_out_lens,
            cache_last_channel_next,
            cache_last_time_next,
            cache_last_channel_next_len,
        ) = self._asr_encoder.run(
            [
                "outputs",
                "encoded_lengths",
                "cache_last_channel_next",
                "cache_last_time_next",
                "cache_last_channel_next_len",
            ],
            {
                "audio_signal": features_chunk,
                "length": self._features_chunk_lengths,
                "cache_last_channel": self._cache_last_channel,
                "cache_last_time": self._cache_last_time,
                "cache_last_channel_len": self._cache_last_channel_len,
            },
        )
        self._cache_last_channel = cache_last_channel_next
        self._cache_last_time = cache_last_time_next
        self._cache_last_channel_len = cache_last_channel_next_len
        # Encoder emits (batch, dim, time); transpose to (batch, time, dim).
        return encoder_out.transpose(0, 2, 1)

    def _decode_tokens(
        self, ids: Iterable[int], indices: Iterable[int] | None, logprobs: Iterable[float] | None
    ) -> TimestampedResult:
        """
        Decode token ids including timestamps, running text, and text delta.

        Returns
        -------
        TimestampedResult:
            contains running text, timestamps, all tokens, all logprobs, and text delta
        """
        tokens = [self._vocab[i] for i in ids]
        # Join pieces, then normalize whitespace: drop it unless it sits at a
        # word boundary, where it becomes a single space.
        text = re.sub(self._DECODE_SPACE_PATTERN, lambda x: " " if x.group(1) else "", "".join(tokens))
        # Delta relative to the previously emitted running text.
        n_added_chars = len(text) - len(self._current_text)
        added_text = text[-n_added_chars:] if n_added_chars > 0 else ""
        # Encoder-frame indices → sample-domain timestamps.
        timestamps = (
            None
            if indices is None
            else (
                self._streaming_cfg.window_step * self._streaming_cfg.subsampling_factor * np.asarray(indices)
            ).tolist()
        )
        return TimestampedResult(
            text, timestamps, tokens, None if logprobs is None else np.asarray(logprobs).tolist(), added_text
        )

    def process_decoder_step(self, encoder_out: npt.NDArray[np.float32]) -> tuple[Optional[str], Optional[str]]:
        """
        Run decoder steps with chunked encoder output.

        Greedy RNNT-style decoding: stay on the same encoder frame while
        non-blank tokens are emitted; advance to the next frame on blank or
        once ``max_tokens_per_step`` tokens have been produced for it.

        Returns
        -------
        text: string
            full transcript from the start
        added_text: string
            text delta
        """
        encodings = encoder_out[0]
        encodings_len = self._encoder_out_lengths[0]
        assert encodings_len == encodings.shape[0]
        step = 0
        emitted_tokens = 0
        while step < encodings_len:
            outputs, state1, state2 = self._asr_decoder.run(
                ["outputs", "output_states_1", "output_states_2"],
                {
                    "encoder_outputs": encodings[step : step + 1, :, None],
                    # Condition on the last emitted token (blank at start).
                    "targets": [[self._tokens[-1] if self._tokens else self._blank_idx]],
                    "target_length": [1],
                    "input_states_1": self._prev_state[0],
                    "input_states_2": self._prev_state[1],
                },
            )
            logits = outputs.squeeze()
            state = (state1, state2)
            assert logits.shape[-1] <= self._vocab_size
            token = logits.argmax()
            if token != self._blank_idx:
                # Prediction-network state advances only when a real label is
                # emitted; blank leaves it untouched.
                self._prev_state = state
                self._tokens.append(int(token))
                self._timestamps.append(self._t_index)
                emitted_tokens += 1
                self._logprobs.append(log_softmax(logits)[token])
            if token == self._blank_idx or emitted_tokens == self._streaming_cfg.max_tokens_per_step:
                # Move to the next encoder frame.
                self._t_index += 1
                emitted_tokens = 0
                step += 1
        if len(self._tokens) > 0:
            res = self._decode_tokens(self._tokens, self._timestamps, self._logprobs)
            self._current_text = res.text
            return res.text, res.added_text
        else:
            return None, None