# Formatting black, isort, flake8 (pltobing, commit 0c397a9)
#!/usr/bin/env python3
# License: CC-BY-NC-ND-4.0
# Created by: Patrick Lumbantobing, Vertox-AI and upstream authors
# Copyright (c) 2025-2026 Vertox-AI. All rights reserved.
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-NoDerivatives 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-nd/4.0/
"""
Preprocessor and ASR modules for Nemotron cache-aware streaming ASR.
Adapted from: https://github.com/istupakov/onnx-asr/tree/main
Provides:
- :class:`Preprocessor` for waveform → feature extraction.
- :class:`NemoConformerRnnt` ONNX wrapper for NeMo Conformer RNN-T.
- :class:`Resampler` ONNX wrapper for waveform resampling.
- :class:`ASRModelPackage` to bundle ASR + resampler into a single API.
"""
from __future__ import annotations
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Generic, Literal, TypeVar, get_args
import numpy as np
import numpy.typing as npt
import onnxruntime as rt
from src.asr.modules_config import (ASRConfig, ASRRuntimeConfig,
PreprocessorRuntimeConfig)
from src.asr.onnx_utils import (OnnxSessionOptions, TensorRtOptions,
get_onnx_providers, update_onnx_providers)
from src.asr.utils import (SampleRates, find_files, is_float32_array,
is_int64_array)
# Module-wide logger; handler/level configuration is left to the application.
log = logging.getLogger(__name__)
# Generic type variables. R parameterises ASRModelPackage below; S is not
# referenced in this file — possibly reserved for subclasses/other modules (verify).
S = TypeVar("S")
R = TypeVar("R")
class Preprocessor:
    """
    ONNX-based ASR preprocessor (waveform → model features).

    Wraps an ONNX preprocessor model (e.g., Kaldi-style feature extractor)
    and provides a call interface that can optionally parallelise over
    batch items using a thread pool.

    Parameters
    ----------
    model_dir :
        Directory with preprocessor ONNX model(s).
    name :
        Preprocessor name (e.g., ``"kaldi"``); ``"identity"`` disables
        feature extraction entirely (waveforms pass through unchanged).
    runtime_config :
        Preprocessor runtime configuration (providers, concurrency).
    """

    def __init__(
        self,
        model_dir: str,
        name: str,
        runtime_config: PreprocessorRuntimeConfig,
    ) -> None:
        # Copy so pop() below does not mutate the caller's config mapping.
        onnx_options = runtime_config.copy()
        self._max_concurrent_workers: int = onnx_options.pop("max_concurrent_workers", 1)
        log.debug(f"Preprocessor name={name}")
        log.debug(f"Preprocessor runtime_config={runtime_config}")
        if name == "identity":
            # No ONNX session: __call__ returns its inputs untouched.
            self._preprocessor: rt.InferenceSession | None = None
        else:
            providers = get_onnx_providers(onnx_options)
            # Switch to the faster kaldi variant whenever a non-CPU provider
            # is available.
            if name == "kaldi" and providers and providers != ["CPUExecutionProvider"]:
                name = "kaldi_fast"
            filepath = Path(model_dir) / Path(name).with_suffix(".onnx")
            log.debug(f"Preprocessor providers={providers}")
            # Fix: previously logged the literal placeholder "(unknown)"
            # instead of the resolved model file name.
            log.debug(f"Preprocessor filename={filepath.name}")
            log.debug(f"Preprocessor filepath={filepath}")
            log.debug(f"Preprocessor shapes_fn={self._preprocessor_shapes}")
            self._preprocessor = rt.InferenceSession(
                filepath.read_bytes(),
                **TensorRtOptions.add_profile(onnx_options, self._preprocessor_shapes),
            )
        log.debug(f"Preprocessor loaded: {self._preprocessor}")

    @staticmethod
    def _get_excluded_providers() -> list[str]:
        """Providers to exclude for this preprocessor (e.g., GPU-only)."""
        return ["CUDAExecutionProvider"]

    def _preprocessor_shapes(self, waveform_len_ms: int, **kwargs: int) -> str:
        """
        Shape function for TensorRT profiles.

        Returns a shape string like
        ``"waveforms:{batch}x{len},waveforms_lens:{batch}"``; the waveform
        length assumes 16 samples per millisecond (16 kHz audio).
        """
        return "waveforms:{batch}x{len},waveforms_lens:{batch}".format(
            len=waveform_len_ms * 16,
            **kwargs,
        )

    def _preprocess(
        self,
        waveforms: npt.NDArray[np.float32],
        waveforms_lens: npt.NDArray[np.int64],
    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
        """Run the ONNX preprocessor; identity when no model is loaded."""
        if not self._preprocessor:
            return waveforms, waveforms_lens
        features, features_lens = self._preprocessor.run(
            ["features", "features_lens"],
            {"waveforms": waveforms, "waveforms_lens": waveforms_lens},
        )
        assert is_float32_array(features)
        assert is_int64_array(features_lens)
        return features, features_lens

    def __call__(
        self,
        waveforms: npt.NDArray[np.float32],
        waveforms_lens: npt.NDArray[np.int64],
    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
        """
        Convert batch waveforms to model features.

        For multi-item batches and when ``max_concurrent_workers > 1``,
        preprocessing is parallelised over batch items.
        """
        if (
            self._preprocessor is None
            or waveforms.shape[0] == 1
            or self._max_concurrent_workers == 1
        ):
            return self._preprocess(waveforms, waveforms_lens)
        # waveforms[:, None] yields one (1, len) slice per batch item, so each
        # worker processes a single-item batch.
        with ThreadPoolExecutor(max_workers=self._max_concurrent_workers) as executor:
            # Fix: Executor.map() has no ``strict`` keyword (that belongs to
            # zip()), so passing strict=True raised TypeError on this path.
            # Both iterables share the batch dimension, so lengths always match.
            features, features_lens = zip(
                *executor.map(
                    self._preprocess,
                    waveforms[:, None],
                    waveforms_lens[:, None],
                )
            )
        return np.concatenate(features, axis=0), np.concatenate(features_lens, axis=0)
class BaseASR(ABC):
    """Common base for ONNX ASR wrappers: config loading + preprocessor setup."""

    def __init__(
        self,
        model_dir: str,
        model_files: dict[str, Path],
        runtime_config: ASRRuntimeConfig,
    ) -> None:
        # Optional JSON model config; missing file means an empty config.
        config_path = model_files.get("config")
        if config_path is None:
            self.config: ASRConfig = {}
        else:
            with config_path.open("rt", encoding="utf-8") as f:
                self.config = json.load(f)
        self.runtime_config = runtime_config
        # The preprocessor name is supplied by the concrete subclass.
        self._preprocessor = Preprocessor(
            model_dir,
            self._preprocessor_name,
            runtime_config.preprocessor_config,
        )

    @staticmethod
    def _get_excluded_providers() -> list[str]:
        """Providers to exclude for ASR encoder/decoder (subclasses override)."""
        return []

    @staticmethod
    def _get_sample_rate() -> Literal[8_000, 16_000]:
        """Default ASR model sample rate in Hz."""
        return 16_000

    @property
    @abstractmethod
    def _preprocessor_name(self) -> str:
        """Name of the preprocessor ONNX model to load."""
        ...
class NemoConformerRnnt(BaseASR):
    """ONNX wrapper for a NeMo Conformer RNN-T ASR model."""

    def __init__(
        self,
        model_dir: str,
        runtime_config: ASRRuntimeConfig,
        quantization: str | None = None,
    ) -> None:
        model_files = find_files(model_dir, self._get_model_files(quantization))
        super().__init__(model_dir, model_files, runtime_config)
        if "vocab" in model_files:
            # Each vocab line is "<token> <id>"; "\u2581" (the SentencePiece
            # word-boundary marker) is mapped back to a plain space.
            vocab: dict[int, str] = {}
            with Path(model_files["vocab"]).open("rt", encoding="utf-8") as f:
                for line in f:
                    token, id_ = line.strip("\n").split(" ")
                    vocab[int(id_)] = token.replace("\u2581", " ")
            self._vocab = vocab
            self._vocab_size = len(vocab)
            # Record the blank-token index only when "<blk>" exists in vocab.
            for id_, token in vocab.items():
                if token == "<blk>":
                    self._blank_idx = id_
                    break
        log.debug(f"NemoConformerRnnt model_files={model_files}")
        log.debug(f"NemoConformerRnnt runtime_config={runtime_config}")
        log.debug(f"NemoConformerRnnt runtime_config.onnx_options={runtime_config.onnx_options}")
        encoder_options = TensorRtOptions.add_profile(
            runtime_config.onnx_options,
            self._encoder_shapes,
        )
        self._encoder = rt.InferenceSession(model_files["encoder"], **encoder_options)
        log.info("NemoConformerRnnt encoder loaded")
        self._decoder_joint = rt.InferenceSession(
            model_files["decoder_joint"],
            **runtime_config.onnx_options,
        )
        log.info("NemoConformerRnnt decoder_joint loaded")

    def _encoder_shapes(self, waveform_len_ms: int, **kwargs: int) -> str:
        """TensorRT profile shape string for the encoder (one frame per 10 ms)."""
        frame_count = waveform_len_ms // 10
        template = "audio_signal:{batch}x{features_size}x{len},length:{batch}"
        return template.format(len=frame_count, features_size=self._features_size, **kwargs)

    @property
    def _features_size(self) -> int:
        # Feature dimension; defaults to 128 when absent from the model config.
        return self.config.get("features_size", 128)

    @property
    def _preprocessor_name(self) -> str:
        return "NemoPreprocessor128"

    @property
    def _subsampling_factor(self) -> int:
        # Encoder time subsampling; defaults to 8 when absent from config.
        return self.config.get("subsampling_factor", 8)

    @staticmethod
    def _get_model_files(quantization: str | None = None) -> dict[str, str]:
        """Return the file-name patterns for this architecture's model files."""
        # NOTE(review): "?" looks like a single-char glob wildcard matching the
        # separator before the quantization tag — confirm against find_files().
        suffix = f"?{quantization}" if quantization else ""
        return {
            "encoder": f"encoder-model{suffix}.onnx",
            "decoder_joint": f"decoder_joint-model{suffix}.onnx",
            "vocab": "vocab.txt",
        }
class Resampler:
    """
    Waveform resampler implemented with ONNX Runtime.

    One ONNX session is loaded per supported source sample rate; calling the
    resampler converts batched waveforms to the configured target sample rate
    (8 or 16 kHz). Input already at the target rate passes through unchanged.
    """

    def __init__(
        self,
        model_dir: str,
        sample_rate: Literal[8_000, 16_000],
        onnx_options: OnnxSessionOptions,
    ) -> None:
        self._target_sample_rate = sample_rate
        self._preprocessors: dict[SampleRates, rt.InferenceSession] = {}
        base_dir = Path(model_dir)
        for src_rate in get_args(SampleRates):
            # The target rate itself needs no conversion model.
            if src_rate == sample_rate:
                continue
            model_file = base_dir / f"resample_{src_rate // 1000}_{sample_rate // 1000}.onnx"
            self._preprocessors[src_rate] = rt.InferenceSession(
                model_file.read_bytes(),
                **TensorRtOptions.add_profile(onnx_options, self._preprocessor_shapes),
            )

    @staticmethod
    def _get_excluded_providers() -> list[str]:
        """Providers to exclude for resampler sessions."""
        return TensorRtOptions.get_provider_names()

    def _preprocessor_shapes(self, waveform_len_ms: int, **kwargs: int) -> str:
        """Shape function for the resampler TensorRT profile (48 samples/ms)."""
        length_ms = kwargs.get("resampler_waveform_len_ms", waveform_len_ms)
        return "waveforms:{batch}x{len},waveforms_lens:{batch}".format(
            len=length_ms * 48,
            **kwargs,
        )

    def __call__(
        self,
        waveforms: npt.NDArray[np.float32],
        waveforms_lens: npt.NDArray[np.int64],
        sample_rate: SampleRates,
    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
        """
        Resample waveforms to the target sample rate.

        When ``sample_rate`` already matches the target, the input arrays
        are returned as-is.
        """
        if sample_rate != self._target_sample_rate:
            resampled, resampled_lens = self._preprocessors[sample_rate].run(
                ["resampled", "resampled_lens"],
                {"waveforms": waveforms, "waveforms_lens": waveforms_lens},
            )
            assert is_float32_array(resampled)
            assert is_int64_array(resampled_lens)
            return resampled, resampled_lens
        return waveforms, waveforms_lens
class ASRModelPackage(ABC, Generic[R]):
    """
    Bundle ASR model and resampler into a single package.

    The static :meth:`load_model` helper constructs both components from a
    directory of ONNX files and runtime options.
    """

    asr: BaseASR
    resampler: Resampler

    def __init__(self, asr: BaseASR, resampler: Resampler) -> None:
        self.asr = asr
        self.resampler = resampler

    @staticmethod
    def load_model(  # noqa: C901
        path: str | Path | None = None,
        quantization: str | None = None,
        sess_options: rt.SessionOptions | None = None,
        providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
        provider_options: Sequence[dict[Any, Any]] | None = None,
        asr_config: OnnxSessionOptions | None = None,
        preprocessor_config: PreprocessorRuntimeConfig | None = None,
        resampler_config: OnnxSessionOptions | None = None,
    ) -> "ASRModelPackage":
        """
        Load ASR model and resampler from a model directory.

        Parameters
        ----------
        path :
            Path to directory with model files.
            NOTE(review): a ``None`` path is forwarded to the loaders
            unchanged — confirm callers always supply a real directory.
        quantization :
            Model quantization (e.g., ``"int8"``) or ``None``.
        sess_options, providers, provider_options :
            Base ONNX Runtime session options.
        asr_config :
            ASR ONNX session config (overrides).
        preprocessor_config :
            Preprocessor ONNX and concurrency config.
        resampler_config :
            Resampler ONNX config.

        Returns
        -------
        ASRModelPackage
            Initialised ASR model and resampler bundle.
        """
        base_options: OnnxSessionOptions = {
            "sess_options": sess_options,
            "providers": providers or rt.get_available_providers(),
            "provider_options": provider_options,
        }
        if asr_config is None:
            asr_config = update_onnx_providers(
                base_options,
                excluded_providers=NemoConformerRnnt._get_excluded_providers(),
            )
        if preprocessor_config is None:
            # Disable reduced precision on TensorRT for the preprocessor —
            # presumably to keep feature extraction in full precision (verify).
            trt_overrides = {
                "TensorrtExecutionProvider": {
                    "trt_fp16_enable": False,
                    "trt_int8_enable": False,
                }
            }
            preprocessor_config = {
                **update_onnx_providers(
                    base_options,
                    new_options=trt_overrides,
                    excluded_providers=Preprocessor._get_excluded_providers(),
                ),
                "max_concurrent_workers": 1,
            }
        if resampler_config is None:
            resampler_config = update_onnx_providers(
                base_options,
                excluded_providers=Resampler._get_excluded_providers(),
            )
        runtime_config = ASRRuntimeConfig(asr_config, preprocessor_config)
        asr = NemoConformerRnnt(path, runtime_config, quantization)
        resampler = Resampler(path, NemoConformerRnnt._get_sample_rate(), resampler_config)
        return ASRModelPackage(asr, resampler)