#!/usr/bin/env python3 # License: CC-BY-NC-ND-4.0 # Created by: Patrick Lumbantobing, Vertox-AI and upstream authors # Copyright (c) 2025-2026 Vertox-AI. All rights reserved. # # This work is licensed under the Creative Commons # Attribution-NonCommercial-NoDerivatives 4.0 International License. # To view a copy of this license, visit # http://creativecommons.org/licenses/by-nc-nd/4.0/ """ Preprocessor and ASR modules for Nemotron cache-aware streaming ASR. Adapted from: https://github.com/istupakov/onnx-asr/tree/main Provides: - :class:`Preprocessor` for waveform → feature extraction. - :class:`NemoConformerRnnt` ONNX wrapper for NeMo Conformer RNN-T. - :class:`Resampler` ONNX wrapper for waveform resampling. - :class:`ASRModelPackage` to bundle ASR + resampler into a single API. """ from __future__ import annotations import json import logging from abc import ABC, abstractmethod from collections.abc import Sequence from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Any, Generic, Literal, TypeVar, get_args import numpy as np import numpy.typing as npt import onnxruntime as rt from src.asr.modules_config import (ASRConfig, ASRRuntimeConfig, PreprocessorRuntimeConfig) from src.asr.onnx_utils import (OnnxSessionOptions, TensorRtOptions, get_onnx_providers, update_onnx_providers) from src.asr.utils import (SampleRates, find_files, is_float32_array, is_int64_array) log = logging.getLogger(__name__) S = TypeVar("S") R = TypeVar("R") class Preprocessor: """ ONNX-based ASR preprocessor (waveform → model features). Wraps an ONNX preprocessor model (e.g., Kaldi-style feature extractor) and provides a call interface that can optionally parallelise over batch items using a thread pool. Parameters ---------- model_dir : Directory with preprocessor ONNX model(s). name : Preprocessor name (e.g., ``"kaldi"``). runtime_config : Preprocessor runtime configuration (providers, concurrency). """ def __init__( self, model_dir: str, name: str, runtime_config: PreprocessorRuntimeConfig, ) -> None: onnx_options = runtime_config.copy() self._max_concurrent_workers: int = onnx_options.pop("max_concurrent_workers", 1) log.debug(f"Preprocessor name={name}") log.debug(f"Preprocessor runtime_config={runtime_config}") if name == "identity": self._preprocessor: rt.InferenceSession | None = None else: providers = get_onnx_providers(onnx_options) if name == "kaldi" and providers and providers != ["CPUExecutionProvider"]: name = "kaldi_fast" dir_path = Path(model_dir) filename = str(Path(name).with_suffix(".onnx")) filepath = dir_path / filename log.debug(f"Preprocessor providers={providers}") log.debug(f"Preprocessor filename={filename}") log.debug(f"Preprocessor filepath={filepath}") log.debug(f"Preprocessor shapes_fn={self._preprocessor_shapes}") self._preprocessor = rt.InferenceSession( filepath.read_bytes(), **TensorRtOptions.add_profile(onnx_options, self._preprocessor_shapes), ) log.debug(f"Preprocessor loaded: {self._preprocessor}") @staticmethod def _get_excluded_providers() -> list[str]: """Providers to exclude for this preprocessor (e.g., GPU-only).""" return ["CUDAExecutionProvider"] def _preprocessor_shapes(self, waveform_len_ms: int, **kwargs: int) -> str: """ Shape function for TensorRT profiles. Returns a shape string like ``"waveforms:{batch}x{len},waveforms_lens:{batch}"``. """ return "waveforms:{batch}x{len},waveforms_lens:{batch}".format( len=waveform_len_ms * 16, **kwargs, ) def _preprocess( self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64], ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]: if not self._preprocessor: return waveforms, waveforms_lens features, features_lens = self._preprocessor.run( ["features", "features_lens"], {"waveforms": waveforms, "waveforms_lens": waveforms_lens}, ) assert is_float32_array(features) assert is_int64_array(features_lens) return features, features_lens def __call__( self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64], ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]: """ Convert batch waveforms to model features. For multi-item batches and when ``max_concurrent_workers > 1``, preprocessing is parallelised over batch items. """ if self._preprocessor is None or waveforms.shape[0] == 1 or self._max_concurrent_workers == 1: return self._preprocess(waveforms, waveforms_lens) with ThreadPoolExecutor(max_workers=self._max_concurrent_workers) as executor: features, features_lens = zip( *executor.map( self._preprocess, waveforms[:, None], waveforms_lens[:, None], strict=True, ) ) return np.concatenate(features, axis=0), np.concatenate(features_lens, axis=0) class BaseASR(ABC): """Base ONNX ASR wrapper.""" def __init__( self, model_dir: str, model_files: dict[str, Path], runtime_config: ASRRuntimeConfig, ) -> None: if "config" in model_files: with model_files["config"].open("rt", encoding="utf-8") as f: self.config: ASRConfig = json.load(f) else: self.config = {} self.runtime_config = runtime_config self._preprocessor = Preprocessor( model_dir, self._preprocessor_name, runtime_config.preprocessor_config, ) @staticmethod def _get_excluded_providers() -> list[str]: """Providers to exclude for ASR encoder/decoder (override per subclass).""" return [] @staticmethod def _get_sample_rate() -> Literal[8_000, 16_000]: """Default ASR model sample rate (Hz).""" return 16_000 @property @abstractmethod def _preprocessor_name(self) -> str: """Return the name of the preprocessor ONNX model.""" ... class NemoConformerRnnt(BaseASR): """ONNX wrapper for a NeMo Conformer RNN-T ASR model.""" def __init__( self, model_dir: str, runtime_config: ASRRuntimeConfig, quantization: str | None = None, ) -> None: model_files = find_files(model_dir, self._get_model_files(quantization)) super().__init__(model_dir, model_files, runtime_config) if "vocab" in model_files: with Path(model_files["vocab"]).open("rt", encoding="utf-8") as f: self._vocab = { int(id_): token.replace("\u2581", " ") for token, id_ in (line.strip("\n").split(" ") for line in f) } self._vocab_size = len(self._vocab) blank_idx = next( (id_ for id_, token in self._vocab.items() if token == ""), None, ) if blank_idx is not None: self._blank_idx = blank_idx log.debug(f"NemoConformerRnnt model_files={model_files}") log.debug(f"NemoConformerRnnt runtime_config={runtime_config}") log.debug(f"NemoConformerRnnt runtime_config.onnx_options={runtime_config.onnx_options}") self._encoder = rt.InferenceSession( model_files["encoder"], **TensorRtOptions.add_profile( runtime_config.onnx_options, self._encoder_shapes, ), ) log.info("NemoConformerRnnt encoder loaded") self._decoder_joint = rt.InferenceSession( model_files["decoder_joint"], **runtime_config.onnx_options, ) log.info("NemoConformerRnnt decoder_joint loaded") def _encoder_shapes(self, waveform_len_ms: int, **kwargs: int) -> str: return "audio_signal:{batch}x{features_size}x{len},length:{batch}".format( len=waveform_len_ms // 10, features_size=self._features_size, **kwargs ) @property def _features_size(self) -> int: return self.config.get("features_size", 128) @property def _preprocessor_name(self) -> str: return "NemoPreprocessor128" @property def _subsampling_factor(self) -> int: return self.config.get("subsampling_factor", 8) @staticmethod def _get_model_files(quantization: str | None = None) -> dict[str, str]: suffix = "?" + quantization if quantization else "" return { "encoder": f"encoder-model{suffix}.onnx", "decoder_joint": f"decoder_joint-model{suffix}.onnx", "vocab": "vocab.txt", } class Resampler: """ Waveform resampler implemented with ONNX Runtime. Loads per-source-frequency ONNX resampler models and exposes a call interface that resamples waveforms to the target sample rate (8 or 16 kHz). """ def __init__( self, model_dir: str, sample_rate: Literal[8_000, 16_000], onnx_options: OnnxSessionOptions, ) -> None: self._target_sample_rate = sample_rate self._preprocessors: dict[SampleRates, rt.InferenceSession] = {} dir_path = Path(model_dir) for orig_freq in get_args(SampleRates): if orig_freq == sample_rate: continue self._preprocessors[orig_freq] = rt.InferenceSession( (dir_path / f"resample_{orig_freq // 1000}_{sample_rate // 1000}.onnx").read_bytes(), **TensorRtOptions.add_profile(onnx_options, self._preprocessor_shapes), ) @staticmethod def _get_excluded_providers() -> list[str]: """Providers to exclude for resampler sessions.""" return TensorRtOptions.get_provider_names() def _preprocessor_shapes(self, waveform_len_ms: int, **kwargs: int) -> str: """Shape function for resampler TensorRT profile.""" return "waveforms:{batch}x{len},waveforms_lens:{batch}".format( len=kwargs.get("resampler_waveform_len_ms", waveform_len_ms) * 48, **kwargs, ) def __call__( self, waveforms: npt.NDArray[np.float32], waveforms_lens: npt.NDArray[np.int64], sample_rate: SampleRates, ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]: """ Resample waveforms to the target sample rate. If ``sample_rate`` already matches the target, the input is returned unchanged. """ if sample_rate == self._target_sample_rate: return waveforms, waveforms_lens resampled, resampled_lens = self._preprocessors[sample_rate].run( ["resampled", "resampled_lens"], {"waveforms": waveforms, "waveforms_lens": waveforms_lens}, ) assert is_float32_array(resampled) assert is_int64_array(resampled_lens) return resampled, resampled_lens class ASRModelPackage(ABC, Generic[R]): """ Bundle ASR model and resampler into a single package. Provides a static :meth:`load_model` helper to construct both from a directory of ONNX files and runtime options. """ asr: BaseASR resampler: Resampler def __init__(self, asr: BaseASR, resampler: Resampler) -> None: self.asr = asr self.resampler = resampler @staticmethod def load_model( # noqa: C901 path: str | Path | None = None, quantization: str | None = None, sess_options: rt.SessionOptions | None = None, providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None, provider_options: Sequence[dict[Any, Any]] | None = None, asr_config: OnnxSessionOptions | None = None, preprocessor_config: PreprocessorRuntimeConfig | None = None, resampler_config: OnnxSessionOptions | None = None, ) -> "ASRModelPackage": """ Load ASR model and resampler from a model directory. Parameters ---------- path : Path to directory with model files. quantization : Model quantization (e.g., ``"int8"``) or ``None``. sess_options, providers, provider_options : Base ONNX Runtime session options. asr_config : ASR ONNX session config (overrides). preprocessor_config : Preprocessor ONNX and concurrency config. resampler_config : Resampler ONNX config. Returns ------- ASRModelPackage Initialised ASR model and resampler bundle. """ default_onnx_config: OnnxSessionOptions = { "sess_options": sess_options, "providers": providers or rt.get_available_providers(), "provider_options": provider_options, } if asr_config is None: asr_config = update_onnx_providers( default_onnx_config, excluded_providers=NemoConformerRnnt._get_excluded_providers(), ) if preprocessor_config is None: preprocessor_config = { **update_onnx_providers( default_onnx_config, new_options={ "TensorrtExecutionProvider": { "trt_fp16_enable": False, "trt_int8_enable": False, } }, excluded_providers=Preprocessor._get_excluded_providers(), ), "max_concurrent_workers": 1, } if resampler_config is None: resampler_config = update_onnx_providers( default_onnx_config, excluded_providers=Resampler._get_excluded_providers(), ) return ASRModelPackage( NemoConformerRnnt(path, ASRRuntimeConfig(asr_config, preprocessor_config), quantization), Resampler(path, NemoConformerRnnt._get_sample_rate(), resampler_config), )