Text-to-Speech
ONNX
GGUF
speech-translation
streaming-speech-translation
speech
audio
speech-recognition
automatic-speech-recognition
streaming-asr
ASR
NeMo
ONNX
cache-aware ASR
FastConformer
RNNT
Parakeet
neural-machine-translation
NMT
gemma3
llama-cpp
GGUF
conversational
TTS
xtts
xttsv2
voice-clone
gpt2
hifigan
multilingual
vq
perceiver-encoder
websocket
#!/usr/bin/env python3
# License: CC-BY-NC-ND-4.0
# Created by: Patrick Lumbantobing, Vertox-AI and upstream authors
# Copyright (c) 2025-2026 Vertox-AI. All rights reserved.
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-NoDerivatives 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-nd/4.0/
"""
Preprocessor and ASR modules for Nemotron cache-aware streaming ASR.

Adapted from: https://github.com/istupakov/onnx-asr/tree/main

Provides:
- :class:`Preprocessor` for waveform → feature extraction.
- :class:`NemoConformerRnnt` ONNX wrapper for NeMo Conformer RNN-T.
- :class:`Resampler` ONNX wrapper for waveform resampling.
- :class:`ASRModelPackage` to bundle ASR + resampler into a single API.
"""
from __future__ import annotations

import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Generic, Literal, TypeVar, get_args

import numpy as np
import numpy.typing as npt
import onnxruntime as rt

from src.asr.modules_config import (ASRConfig, ASRRuntimeConfig,
                                    PreprocessorRuntimeConfig)
from src.asr.onnx_utils import (OnnxSessionOptions, TensorRtOptions,
                                get_onnx_providers, update_onnx_providers)
from src.asr.utils import (SampleRates, find_files, is_float32_array,
                           is_int64_array)

log = logging.getLogger(__name__)

# Generic type variables for typed subclasses of the ABCs below.
S = TypeVar("S")
R = TypeVar("R")
class Preprocessor:
    """
    ONNX-based ASR preprocessor (waveform → model features).

    Wraps an ONNX preprocessor model (e.g., Kaldi-style feature extractor)
    and provides a call interface that can optionally parallelise over
    batch items using a thread pool.

    Parameters
    ----------
    model_dir :
        Directory with preprocessor ONNX model(s).
    name :
        Preprocessor name (e.g., ``"kaldi"``). ``"identity"`` disables
        feature extraction: waveforms pass through unchanged.
    runtime_config :
        Preprocessor runtime configuration (providers, concurrency).
    """

    def __init__(
        self,
        model_dir: str,
        name: str,
        runtime_config: PreprocessorRuntimeConfig,
    ) -> None:
        # Work on a copy: ``max_concurrent_workers`` is consumed here and
        # must not leak into the ONNX session options.
        onnx_options = runtime_config.copy()
        self._max_concurrent_workers: int = onnx_options.pop("max_concurrent_workers", 1)
        log.debug(f"Preprocessor name={name}")
        log.debug(f"Preprocessor runtime_config={runtime_config}")
        if name == "identity":
            self._preprocessor: rt.InferenceSession | None = None
        else:
            providers = get_onnx_providers(onnx_options)
            # Prefer the "kaldi_fast" variant when a non-CPU provider is active.
            if name == "kaldi" and providers and providers != ["CPUExecutionProvider"]:
                name = "kaldi_fast"
            dir_path = Path(model_dir)
            filename = str(Path(name).with_suffix(".onnx"))
            filepath = dir_path / filename
            log.debug(f"Preprocessor providers={providers}")
            # Fixed: the original logged the literal text "(unknown)" instead
            # of interpolating the computed filename.
            log.debug(f"Preprocessor filename={filename}")
            log.debug(f"Preprocessor filepath={filepath}")
            log.debug(f"Preprocessor shapes_fn={self._preprocessor_shapes}")
            self._preprocessor = rt.InferenceSession(
                filepath.read_bytes(),
                **TensorRtOptions.add_profile(onnx_options, self._preprocessor_shapes),
            )
            log.debug(f"Preprocessor loaded: {self._preprocessor}")

    @staticmethod
    def _get_excluded_providers() -> list[str]:
        """Providers to exclude for this preprocessor (e.g., GPU-only)."""
        return ["CUDAExecutionProvider"]

    def _preprocessor_shapes(self, waveform_len_ms: int, **kwargs: int) -> str:
        """
        Shape function for TensorRT profiles.

        Returns a shape string like
        ``"waveforms:{batch}x{len},waveforms_lens:{batch}"``, where ``len``
        is the sample count for ``waveform_len_ms`` at 16 kHz.
        """
        return "waveforms:{batch}x{len},waveforms_lens:{batch}".format(
            len=waveform_len_ms * 16,  # 16 samples per millisecond at 16 kHz
            **kwargs,
        )

    def _preprocess(
        self,
        waveforms: npt.NDArray[np.float32],
        waveforms_lens: npt.NDArray[np.int64],
    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
        """Run the ONNX preprocessor on one (sub-)batch; identity when disabled."""
        if not self._preprocessor:
            return waveforms, waveforms_lens
        features, features_lens = self._preprocessor.run(
            ["features", "features_lens"],
            {"waveforms": waveforms, "waveforms_lens": waveforms_lens},
        )
        assert is_float32_array(features)
        assert is_int64_array(features_lens)
        return features, features_lens

    def __call__(
        self,
        waveforms: npt.NDArray[np.float32],
        waveforms_lens: npt.NDArray[np.int64],
    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
        """
        Convert batch waveforms to model features.

        For multi-item batches and when ``max_concurrent_workers > 1``,
        preprocessing is parallelised over batch items.
        """
        if self._preprocessor is None or waveforms.shape[0] == 1 or self._max_concurrent_workers == 1:
            return self._preprocess(waveforms, waveforms_lens)
        with ThreadPoolExecutor(max_workers=self._max_concurrent_workers) as executor:
            # Each batch item becomes its own batch of one
            # (``[:, None]`` keeps a leading batch dimension of 1).
            # Fixed: Executor.map() accepts no ``strict`` keyword (that is a
            # zip() parameter); passing it raised TypeError in the original.
            features, features_lens = zip(
                *executor.map(
                    self._preprocess,
                    waveforms[:, None],
                    waveforms_lens[:, None],
                )
            )
        return np.concatenate(features, axis=0), np.concatenate(features_lens, axis=0)
class BaseASR(ABC):
    """
    Base ONNX ASR wrapper.

    Loads the optional JSON model config and constructs the waveform
    preprocessor; subclasses load the actual encoder/decoder sessions.
    """

    def __init__(
        self,
        model_dir: str,
        model_files: dict[str, Path],
        runtime_config: ASRRuntimeConfig,
    ) -> None:
        if "config" in model_files:
            with model_files["config"].open("rt", encoding="utf-8") as f:
                self.config: ASRConfig = json.load(f)
        else:
            self.config = {}
        self.runtime_config = runtime_config
        # Fixed: the original passed the bound method object itself (not its
        # return value) as the preprocessor name, so Path(name) would fail.
        self._preprocessor = Preprocessor(
            model_dir,
            self._preprocessor_name(),
            runtime_config.preprocessor_config,
        )

    @staticmethod
    def _get_excluded_providers() -> list[str]:
        """Providers to exclude for ASR encoder/decoder (override per subclass)."""
        return []

    @staticmethod
    def _get_sample_rate() -> Literal[8_000, 16_000]:
        """Default ASR model sample rate (Hz)."""
        return 16_000

    @abstractmethod
    def _preprocessor_name(self) -> str:
        """Return the name of the preprocessor ONNX model."""
        ...
class NemoConformerRnnt(BaseASR):
    """ONNX wrapper for a NeMo Conformer RNN-T ASR model."""

    def __init__(
        self,
        model_dir: str,
        runtime_config: ASRRuntimeConfig,
        quantization: str | None = None,
    ) -> None:
        # Fixed: _get_model_files is now a @staticmethod; the original bound
        # instance call passed ``self`` into its single ``quantization``
        # parameter and raised TypeError (2 args to a 1-arg function).
        model_files = find_files(model_dir, self._get_model_files(quantization))
        super().__init__(model_dir, model_files, runtime_config)
        if "vocab" in model_files:
            with Path(model_files["vocab"]).open("rt", encoding="utf-8") as f:
                # Each vocab line is "<token> <id>"; U+2581 (lower one-eighth
                # block) is the SentencePiece word-boundary marker -> space.
                self._vocab = {
                    int(id_): token.replace("\u2581", " ")
                    for token, id_ in (line.strip("\n").split(" ") for line in f)
                }
            self._vocab_size = len(self._vocab)
            blank_idx = next(
                (id_ for id_, token in self._vocab.items() if token == "<blk>"),
                None,
            )
            # NOTE(review): _blank_idx stays unset when "<blk>" is absent
            # (and _vocab/_vocab_size when no vocab file exists); later
            # attribute access would raise AttributeError — confirm callers.
            if blank_idx is not None:
                self._blank_idx = blank_idx
        log.debug(f"NemoConformerRnnt model_files={model_files}")
        log.debug(f"NemoConformerRnnt runtime_config={runtime_config}")
        log.debug(f"NemoConformerRnnt runtime_config.onnx_options={runtime_config.onnx_options}")
        self._encoder = rt.InferenceSession(
            model_files["encoder"],
            **TensorRtOptions.add_profile(
                runtime_config.onnx_options,
                self._encoder_shapes,
            ),
        )
        log.info("NemoConformerRnnt encoder loaded")
        self._decoder_joint = rt.InferenceSession(
            model_files["decoder_joint"],
            **runtime_config.onnx_options,
        )
        log.info("NemoConformerRnnt decoder_joint loaded")

    def _encoder_shapes(self, waveform_len_ms: int, **kwargs: int) -> str:
        """TensorRT profile shape string for the encoder inputs."""
        # Fixed: call _features_size(); the original interpolated the bound
        # method object (its repr) into the shape string.
        return "audio_signal:{batch}x{features_size}x{len},length:{batch}".format(
            len=waveform_len_ms // 10,  # one feature frame per 10 ms
            features_size=self._features_size(),
            **kwargs,
        )

    def _features_size(self) -> int:
        """Number of mel feature bins (config override, default 128)."""
        return self.config.get("features_size", 128)

    def _preprocessor_name(self) -> str:
        """Name of the preprocessor ONNX model."""
        return "NemoPreprocessor128"

    def _subsampling_factor(self) -> int:
        """Encoder time-subsampling factor (config override, default 8)."""
        return self.config.get("subsampling_factor", 8)

    @staticmethod
    def _get_model_files(quantization: str | None = None) -> dict[str, str]:
        """
        Glob patterns for the required model files.

        ``?`` is a single-character glob wildcard so either ``.`` or ``-``
        may separate the quantization suffix (e.g. ``encoder-model.int8.onnx``).
        """
        suffix = "?" + quantization if quantization else ""
        return {
            "encoder": f"encoder-model{suffix}.onnx",
            "decoder_joint": f"decoder_joint-model{suffix}.onnx",
            "vocab": "vocab.txt",
        }
class Resampler:
    """
    Waveform resampler implemented with ONNX Runtime.

    Loads per-source-frequency ONNX resampler models and exposes
    a call interface that resamples waveforms to the target
    sample rate (8 or 16 kHz).
    """

    def __init__(
        self,
        model_dir: str,
        sample_rate: Literal[8_000, 16_000],
        onnx_options: OnnxSessionOptions,
    ) -> None:
        self._target_sample_rate = sample_rate
        # One session per supported source rate (the target rate needs none).
        self._preprocessors: dict[SampleRates, rt.InferenceSession] = {}
        dir_path = Path(model_dir)
        for orig_freq in get_args(SampleRates):
            if orig_freq == sample_rate:
                continue
            self._preprocessors[orig_freq] = rt.InferenceSession(
                (dir_path / f"resample_{orig_freq // 1000}_{sample_rate // 1000}.onnx").read_bytes(),
                **TensorRtOptions.add_profile(onnx_options, self._preprocessor_shapes),
            )

    @staticmethod
    def _get_excluded_providers() -> list[str]:
        """Providers to exclude for resampler sessions."""
        return TensorRtOptions.get_provider_names()

    def _preprocessor_shapes(self, waveform_len_ms: int, **kwargs: int) -> str:
        """Shape function for the resampler TensorRT profile."""
        # 48 samples per ms covers the highest supported source rate (48 kHz);
        # an explicit resampler_waveform_len_ms kwarg overrides the default.
        return "waveforms:{batch}x{len},waveforms_lens:{batch}".format(
            len=kwargs.get("resampler_waveform_len_ms", waveform_len_ms) * 48,
            **kwargs,
        )

    def __call__(
        self,
        waveforms: npt.NDArray[np.float32],
        waveforms_lens: npt.NDArray[np.int64],
        sample_rate: SampleRates,
    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
        """
        Resample waveforms to the target sample rate.

        If ``sample_rate`` already matches the target, the input is returned
        unchanged.
        """
        if sample_rate == self._target_sample_rate:
            return waveforms, waveforms_lens
        resampled, resampled_lens = self._preprocessors[sample_rate].run(
            ["resampled", "resampled_lens"],
            {"waveforms": waveforms, "waveforms_lens": waveforms_lens},
        )
        assert is_float32_array(resampled)
        assert is_int64_array(resampled_lens)
        return resampled, resampled_lens
class ASRModelPackage(ABC, Generic[R]):
    """
    Bundle ASR model and resampler into a single package.

    Provides a static :meth:`load_model` helper to construct both from
    a directory of ONNX files and runtime options.
    """

    asr: BaseASR
    resampler: Resampler

    def __init__(self, asr: BaseASR, resampler: Resampler) -> None:
        self.asr = asr
        self.resampler = resampler

    # Fixed: the docstring advertises a *static* helper, but the decorator
    # was missing — an instance call would have bound the instance to
    # ``path``. Class-level calls keep working unchanged.
    @staticmethod
    def load_model(  # noqa: C901
        path: str | Path | None = None,
        quantization: str | None = None,
        sess_options: rt.SessionOptions | None = None,
        providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
        provider_options: Sequence[dict[Any, Any]] | None = None,
        asr_config: OnnxSessionOptions | None = None,
        preprocessor_config: PreprocessorRuntimeConfig | None = None,
        resampler_config: OnnxSessionOptions | None = None,
    ) -> "ASRModelPackage":
        """
        Load ASR model and resampler from a model directory.

        Parameters
        ----------
        path :
            Path to directory with model files.
        quantization :
            Model quantization (e.g., ``"int8"``) or ``None``.
        sess_options, providers, provider_options :
            Base ONNX Runtime session options.
        asr_config :
            ASR ONNX session config (overrides).
        preprocessor_config :
            Preprocessor ONNX and concurrency config.
        resampler_config :
            Resampler ONNX config.

        Returns
        -------
        ASRModelPackage
            Initialised ASR model and resampler bundle.
        """
        # Base options shared by all three components; each component then
        # filters out the providers it cannot use.
        default_onnx_config: OnnxSessionOptions = {
            "sess_options": sess_options,
            "providers": providers or rt.get_available_providers(),
            "provider_options": provider_options,
        }
        if asr_config is None:
            asr_config = update_onnx_providers(
                default_onnx_config,
                excluded_providers=NemoConformerRnnt._get_excluded_providers(),
            )
        if preprocessor_config is None:
            # Feature extraction runs in full precision: disable TensorRT
            # fp16/int8 for the preprocessor sessions.
            preprocessor_config = {
                **update_onnx_providers(
                    default_onnx_config,
                    new_options={
                        "TensorrtExecutionProvider": {
                            "trt_fp16_enable": False,
                            "trt_int8_enable": False,
                        }
                    },
                    excluded_providers=Preprocessor._get_excluded_providers(),
                ),
                "max_concurrent_workers": 1,
            }
        if resampler_config is None:
            resampler_config = update_onnx_providers(
                default_onnx_config,
                excluded_providers=Resampler._get_excluded_providers(),
            )
        return ASRModelPackage(
            NemoConformerRnnt(path, ASRRuntimeConfig(asr_config, preprocessor_config), quantization),
            Resampler(path, NemoConformerRnnt._get_sample_rate(), resampler_config),
        )