#!/usr/bin/env python3 # License: CC-BY-NC-ND-4.0 # Created by: Patrick Lumbantobing, Vertox-AI # Copyright (c) 2026 Vertox-AI. All rights reserved. # # This work is licensed under the Creative Commons # Attribution-NonCommercial-NoDerivatives 4.0 International License. # To view a copy of this license, visit # http://creativecommons.org/licenses/by-nc-nd/4.0/ """ Utility functions for ASR preprocessing and model I/O. Adapted from: https://github.com/istupakov/onnx-asr/tree/main Provides: - Sample-rate validation helpers. - Typed NumPy array guards for common dtypes. - WAV reading utilities (mono-mixing, multi-width support). - Batch padding to a common length. - Log-softmax implementation. - Model file discovery helpers. """ from __future__ import annotations import wave from pathlib import Path from typing import Literal, Optional, TypeGuard, cast, get_args import numpy as np import numpy.typing as npt SampleRates = Literal[8_000, 11_025, 16_000, 22_050, 24_000, 32_000, 44_100, 48_000] def is_supported_sample_rate(sample_rate: int) -> TypeGuard[SampleRates]: """Return True if ``sample_rate`` is one of the supported ASR rates.""" return sample_rate in get_args(SampleRates) def is_float16_array(x: object) -> TypeGuard[npt.NDArray[np.float16]]: """Return True if ``x`` is a NumPy array with dtype float16.""" return isinstance(x, np.ndarray) and x.dtype == np.float16 def is_float32_array(x: object) -> TypeGuard[npt.NDArray[np.float32]]: """Return True if ``x`` is a NumPy array with dtype float32.""" return isinstance(x, np.ndarray) and x.dtype == np.float32 def is_int32_array(x: object) -> TypeGuard[npt.NDArray[np.int32]]: """Return True if ``x`` is a NumPy array with dtype int32.""" return isinstance(x, np.ndarray) and x.dtype == np.int32 def is_int64_array(x: object) -> TypeGuard[npt.NDArray[np.int64]]: """Return True if ``x`` is a NumPy array with dtype int64.""" return isinstance(x, np.ndarray) and x.dtype == np.int64 class ModelPathNotDirectoryError(NotADirectoryError): """Raised when a given model path is not a directory.""" def __init__(self, path: str | Path) -> None: super().__init__(f"The path '{path}' is not a directory.") class ModelFileNotFoundError(FileNotFoundError): """Raised when a required model file cannot be found in a directory.""" def __init__(self, filename: str | Path, path: str | Path) -> None: super().__init__(f"File '{filename}' not found in path '{path}'.") class MoreThanOneModelFileFoundError(Exception): """Raised when multiple candidate model files match a given pattern.""" def __init__(self, filename: str | Path, path: str | Path) -> None: super().__init__(f"Found more than one file '{filename}' in path '{path}'.") class SupportedOnlyMonoAudioError(ValueError): """Raised when a multi-channel waveform is provided where mono is required.""" def __init__(self) -> None: super().__init__("Supported only mono audio.") class WrongSampleRateError(ValueError): """Raised when a waveform sample rate is not supported.""" def __init__(self) -> None: super().__init__(f"Supported only {get_args(SampleRates)} sample rates.") class DifferentSampleRatesError(ValueError): """Raised when waveforms in a batch have different sample rates.""" def __init__(self) -> None: super().__init__("All sample rates in a batch must be the same.") def read_wav(filename: str) -> tuple[npt.NDArray[np.float32], int]: """ Read a PCM WAV file into a mono float32 NumPy array. The waveform is normalised to the range [-1, 1] (approximately) and multi-channel input is averaged down to mono. Parameters ---------- filename : Path to the WAV file. Returns ------- (np.ndarray, int) Tuple of ``(audio, sample_rate)`` where ``audio`` has shape ``(T,)``. """ with wave.open(filename, mode="rb") as f: data = f.readframes(f.getnframes()) zero_value = 0 if f.getsampwidth() == 1: # 8-bit unsigned PCM. buffer = np.frombuffer(data, dtype="u1") zero_value = 1 elif f.getsampwidth() == 3: # 24-bit PCM via 32-bit view. buffer = np.zeros((len(data) // 3, 4), dtype="V1") buffer[:, -3:] = np.frombuffer(data, dtype="V1").reshape(-1, f.getsampwidth()) buffer = buffer.view(dtype=" tuple[npt.NDArray[np.float32], npt.NDArray[np.int64], SampleRates]: """ Convert a list of waveforms or filenames into a padded batch array. Parameters ---------- waveforms : List of either mono float32 arrays (shape ``(T,)``) or filenames. numpy_sample_rate : Sample rate to associate with NumPy waveforms (ignored for filenames). Returns ------- (np.ndarray, np.ndarray, SampleRates) ``(batch_waveforms, lengths, sample_rate)`` where ``batch_waveforms`` has shape ``(B, T_max)``. """ results: list[npt.NDArray[np.float32]] = [] sample_rates: list[int | SampleRates | None] = [] for x in waveforms: if isinstance(x, str): waveform, sample_rate = read_wav(x) results.append(waveform) sample_rates.append(sample_rate) else: if x.ndim != 1: raise SupportedOnlyMonoAudioError results.append(x) sample_rates.append(numpy_sample_rate) if len(set(sample_rates)) > 1: raise DifferentSampleRatesError sr = sample_rates[0] if not isinstance(sr, int): # If everything came from NumPy arrays, sr is already a SampleRates. sr_int = int(sr) if sr is not None else None else: sr_int = sr if sr_int is not None and is_supported_sample_rate(sr_int): batch, lengths = pad_list(results) return batch, lengths, sr_int # type: ignore[return-value] raise WrongSampleRateError def pad_list( arrays: list[npt.NDArray[np.float32]], ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]: """ Pad a list of 1-D NumPy arrays to a common length. Parameters ---------- arrays : List of waveforms with shape ``(T_i,)``. Returns ------- (np.ndarray, np.ndarray) ``(batch, lengths)`` where ``batch`` has shape ``(B, T_max)`` and ``lengths`` holds original lengths. """ lens = np.array([array.shape[0] for array in arrays], dtype=np.int64) result = np.zeros((len(arrays), lens.max()), dtype=np.float32) for i, x in enumerate(arrays): result[i, : x.shape[0]] = x[: min(x.shape[0], result.shape[1])] return result, lens def log_softmax( logits: npt.NDArray[np.float32], axis: int | None = None, ) -> npt.NDArray[np.float32]: """ Compute the log-softmax of an array along the given axis. Parameters ---------- logits : Input array of unnormalised log-probabilities. axis : Axis to normalise over (default: last axis). Returns ------- np.ndarray Log-softmax of ``logits`` with the same shape and dtype float32. """ if axis is None: axis = -1 tmp = logits - np.max(logits, axis=axis, keepdims=True) tmp -= np.log(np.sum(np.exp(tmp), axis=axis, keepdims=True)) return cast(npt.NDArray[np.float32], tmp) def find_files(path: str | Path, files: dict[str, str]) -> dict[str, Path]: """ Resolve model-related filenames within a directory. Parameters ---------- path : Directory containing model files. files : Mapping from logical name (e.g., ``"encoder"``) to glob pattern (e.g., ``"encoder*.onnx"``). Returns ------- dict[str, Path] Mapping from logical name to resolved :class:`Path`. Raises ------ ModelPathNotDirectoryError If ``path`` is not a directory. ModelFileNotFoundError If no file matches a given pattern. MoreThanOneModelFileFoundError If multiple files match a given pattern. """ if not Path(path).is_dir(): raise ModelPathNotDirectoryError(path) # Optional config.json convenience. if Path(path, "config.json").exists(): files |= {"config": "config.json"} def find(filename: str) -> Path: matches = list(Path(path).glob(filename)) if len(matches) == 0: raise ModelFileNotFoundError(filename, path) if len(matches) > 1: raise MoreThanOneModelFileFoundError(filename, path) return matches[0] return {key: find(pattern) for key, pattern in files.items()}