pltobing's picture
Formatting black, isort, flake8
0c397a9
#!/usr/bin/env python3
# License: CC-BY-NC-ND-4.0
# Created by: Patrick Lumbantobing, Vertox-AI
# Copyright (c) 2026 Vertox-AI. All rights reserved.
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-NoDerivatives 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-nd/4.0/
"""
Utility functions for ASR preprocessing and model I/O.
Adapted from: https://github.com/istupakov/onnx-asr/tree/main
Provides:
- Sample-rate validation helpers.
- Typed NumPy array guards for common dtypes.
- WAV reading utilities (mono-mixing, multi-width support).
- Batch padding to a common length.
- Log-softmax implementation.
- Model file discovery helpers.
"""
from __future__ import annotations
import wave
from pathlib import Path
from typing import Literal, Optional, TypeGuard, cast, get_args
import numpy as np
import numpy.typing as npt
SampleRates = Literal[8_000, 11_025, 16_000, 22_050, 24_000, 32_000, 44_100, 48_000]
def is_supported_sample_rate(sample_rate: int) -> TypeGuard[SampleRates]:
"""Return True if ``sample_rate`` is one of the supported ASR rates."""
return sample_rate in get_args(SampleRates)
def is_float16_array(x: object) -> TypeGuard[npt.NDArray[np.float16]]:
"""Return True if ``x`` is a NumPy array with dtype float16."""
return isinstance(x, np.ndarray) and x.dtype == np.float16
def is_float32_array(x: object) -> TypeGuard[npt.NDArray[np.float32]]:
"""Return True if ``x`` is a NumPy array with dtype float32."""
return isinstance(x, np.ndarray) and x.dtype == np.float32
def is_int32_array(x: object) -> TypeGuard[npt.NDArray[np.int32]]:
"""Return True if ``x`` is a NumPy array with dtype int32."""
return isinstance(x, np.ndarray) and x.dtype == np.int32
def is_int64_array(x: object) -> TypeGuard[npt.NDArray[np.int64]]:
"""Return True if ``x`` is a NumPy array with dtype int64."""
return isinstance(x, np.ndarray) and x.dtype == np.int64
class ModelPathNotDirectoryError(NotADirectoryError):
    """Error for a model location that is not a directory."""

    def __init__(self, path: str | Path) -> None:
        message = f"The path '{path}' is not a directory."
        super().__init__(message)
class ModelFileNotFoundError(FileNotFoundError):
    """Error for a model directory that has no file matching a requested name."""

    def __init__(self, filename: str | Path, path: str | Path) -> None:
        message = f"File '{filename}' not found in path '{path}'."
        super().__init__(message)
class MoreThanOneModelFileFoundError(Exception):
    """Error for a file pattern that matches several files in a model directory."""

    def __init__(self, filename: str | Path, path: str | Path) -> None:
        message = f"Found more than one file '{filename}' in path '{path}'."
        super().__init__(message)
class SupportedOnlyMonoAudioError(ValueError):
    """Error for a waveform that has more than one channel where mono is needed."""

    def __init__(self) -> None:
        message = "Supported only mono audio."
        super().__init__(message)
class WrongSampleRateError(ValueError):
    """Error for a sample rate that is not one of the ``SampleRates`` values."""

    def __init__(self) -> None:
        supported_rates = get_args(SampleRates)
        super().__init__(f"Supported only {supported_rates} sample rates.")
class DifferentSampleRatesError(ValueError):
    """Error for a batch whose waveforms do not share one sample rate."""

    def __init__(self) -> None:
        message = "All sample rates in a batch must be the same."
        super().__init__(message)
def _pcm_bytes_to_float32(data: bytes, sample_width: int) -> npt.NDArray[np.float32]:
    """
    Decode little-endian integer PCM bytes into a flat float32 array.

    Samples are divided by ``2 ** (bits - 1)``, so full-scale input lands
    approximately in [-1, 1].
    """
    offset = 0
    if sample_width == 1:
        # 8-bit PCM is unsigned: after dividing by 128, subtracting 1 maps
        # the byte value 128 to 0.0.
        samples = np.frombuffer(data, dtype="u1")
        offset = 1
    elif sample_width == 3:
        # 24-bit PCM: copy each 3-byte sample into the upper three bytes of
        # a 4-byte little-endian slot, then view as int32. This keeps the
        # sign bit in place (the value is scaled by 256, matched by the
        # 2**31 divisor below).
        padded = np.zeros((len(data) // 3, 4), dtype="V1")
        padded[:, 1:] = np.frombuffer(data, dtype="V1").reshape(-1, 3)
        samples = padded.view(dtype="<i4")
    else:
        # 16-bit or 32-bit signed PCM.
        samples = np.frombuffer(data, dtype=f"<i{sample_width}")
    max_value = 2 ** (8 * samples.itemsize - 1)
    return samples.astype(np.float32) / max_value - offset


def read_wav(filename: str) -> tuple[npt.NDArray[np.float32], int]:
    """
    Read a PCM WAV file into a mono float32 NumPy array.

    The waveform is normalised to approximately [-1, 1]. Multi-channel
    input is averaged down to mono. If the file is truncated (fewer sample
    bytes than the header announces, possibly ending mid-frame), only the
    complete frames that are present are returned.

    Parameters
    ----------
    filename :
        Path to the WAV file (anything accepted by :func:`wave.open`).

    Returns
    -------
    (np.ndarray, int)
        Tuple of ``(audio, sample_rate)`` where ``audio`` has shape ``(T,)``.
    """
    with wave.open(filename, mode="rb") as f:
        n_channels = f.getnchannels()
        sample_width = f.getsampwidth()
        sample_rate = f.getframerate()
        data = f.readframes(f.getnframes())

    # ``readframes`` can return fewer bytes than ``getnframes()`` implies
    # for a truncated file; drop any trailing partial frame so decoding and
    # reshaping below cannot fail on a ragged buffer.
    frame_size = n_channels * sample_width
    data = data[: len(data) - len(data) % frame_size]

    # ``-1`` derives the frame count from the data actually read rather than
    # trusting the header's frame count.
    audio = _pcm_bytes_to_float32(data, sample_width).reshape(-1, n_channels)
    if n_channels == 1:
        return audio[:, 0], sample_rate
    # Multi-channel: simple average to mono.
    return audio.mean(axis=1, dtype=np.float32), sample_rate
def read_wav_files(
    waveforms: list[npt.NDArray[np.float32] | str],
    numpy_sample_rate: Optional[SampleRates] = None,
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64], SampleRates]:
    """
    Convert a list of waveforms or filenames into a padded batch array.

    Filenames are decoded with :func:`read_wav`, which mixes to mono.
    NumPy waveforms must already be 1-D.

    Parameters
    ----------
    waveforms :
        Non-empty list of either mono float32 arrays (shape ``(T,)``) or
        WAV filenames.
    numpy_sample_rate :
        Sample rate to associate with NumPy waveforms (ignored for filenames).
        If left as ``None`` while NumPy waveforms are passed, a
        ``WrongSampleRateError`` or ``DifferentSampleRatesError`` is raised.

    Returns
    -------
    (np.ndarray, np.ndarray, SampleRates)
        ``(batch_waveforms, lengths, sample_rate)`` where
        ``batch_waveforms`` has shape ``(B, T_max)``.

    Raises
    ------
    ValueError
        If ``waveforms`` is empty.
    SupportedOnlyMonoAudioError
        If a NumPy waveform is not 1-D.
    DifferentSampleRatesError
        If the inputs do not share a single sample rate.
    WrongSampleRateError
        If the shared sample rate is missing or not a supported rate.
    """
    if not waveforms:
        # An empty batch has no sample rate to report; fail with a clear
        # message instead of an opaque indexing/reduction error.
        raise ValueError("Expected at least one waveform or filename.")

    results: list[npt.NDArray[np.float32]] = []
    sample_rates: set[int | None] = set()
    for x in waveforms:
        if isinstance(x, str):
            waveform, file_rate = read_wav(x)
            results.append(waveform)
            sample_rates.add(file_rate)
        else:
            if x.ndim != 1:
                raise SupportedOnlyMonoAudioError
            results.append(x)
            sample_rates.add(numpy_sample_rate)

    if len(sample_rates) > 1:
        raise DifferentSampleRatesError

    (sr,) = sample_rates
    # Convert int-like values (e.g. a NumPy integer a caller may pass) to a
    # plain ``int`` so the TypeGuard below can narrow it to ``SampleRates``.
    sr_int = None if sr is None else int(sr)
    if sr_int is not None and is_supported_sample_rate(sr_int):
        batch, lengths = pad_list(results)
        return batch, lengths, sr_int
    raise WrongSampleRateError
def pad_list(
    arrays: list[npt.NDArray[np.float32]],
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int64]]:
    """
    Pad a list of 1-D NumPy arrays to a common length.

    Shorter arrays are right-padded with zeros. An empty list yields a
    ``(0, 0)`` batch and an empty ``lengths`` array.

    Parameters
    ----------
    arrays :
        List of waveforms with shape ``(T_i,)``.

    Returns
    -------
    (np.ndarray, np.ndarray)
        ``(batch, lengths)`` where ``batch`` is float32 with shape
        ``(B, T_max)`` and ``lengths`` (int64) holds the original lengths.
    """
    lens = np.array([array.shape[0] for array in arrays], dtype=np.int64)
    # ``initial=0`` keeps the reduction defined when ``arrays`` is empty.
    result = np.zeros((len(arrays), lens.max(initial=0)), dtype=np.float32)
    for i, array in enumerate(arrays):
        # Each row is at least as long as its source array, so no clipping
        # of the source is needed.
        result[i, : array.shape[0]] = array
    return result, lens
def log_softmax(
    logits: npt.NDArray[np.float32],
    axis: int | None = None,
) -> npt.NDArray[np.float32]:
    """
    Return ``log(softmax(logits))`` along one axis, computed stably.

    The per-slice maximum is subtracted before exponentiating, so large
    logits do not overflow in ``exp``.

    Parameters
    ----------
    logits :
        Array of unnormalised log-probabilities.
    axis :
        Axis along which to normalise; ``None`` selects the last axis.

    Returns
    -------
    np.ndarray
        Array with the same shape as ``logits`` holding log-probabilities
        (float32 input gives float32 output).
    """
    reduce_axis = -1 if axis is None else axis
    shifted = logits - np.max(logits, axis=reduce_axis, keepdims=True)
    log_partition = np.log(np.sum(np.exp(shifted), axis=reduce_axis, keepdims=True))
    shifted -= log_partition
    return cast(npt.NDArray[np.float32], shifted)
def find_files(path: str | Path, files: dict[str, str]) -> dict[str, Path]:
    """
    Resolve model-related filenames within a directory.

    If ``path`` contains a ``config.json``, a ``"config"`` entry pointing to
    it is added to the result (overriding any ``"config"`` pattern given in
    ``files``). The caller's ``files`` mapping is not modified.

    Parameters
    ----------
    path :
        Directory containing model files.
    files :
        Mapping from logical name (e.g., ``"encoder"``) to glob pattern
        (e.g., ``"encoder*.onnx"``).

    Returns
    -------
    dict[str, Path]
        Mapping from logical name to resolved :class:`Path`.

    Raises
    ------
    ModelPathNotDirectoryError
        If ``path`` is not a directory.
    ModelFileNotFoundError
        If no file matches a given pattern.
    MoreThanOneModelFileFoundError
        If multiple files match a given pattern.
    """
    directory = Path(path)
    if not directory.is_dir():
        raise ModelPathNotDirectoryError(path)

    # Optional config.json convenience. Work on a copy: updating ``files``
    # in place (``files |= ...``) would leak the "config" entry into the
    # caller's dict, so later calls reusing that dict on a directory without
    # config.json would fail.
    patterns = dict(files)
    if (directory / "config.json").exists():
        patterns["config"] = "config.json"

    def find(pattern: str) -> Path:
        """Return the single file in ``directory`` matching ``pattern``."""
        matches = list(directory.glob(pattern))
        if not matches:
            raise ModelFileNotFoundError(pattern, path)
        if len(matches) > 1:
            raise MoreThanOneModelFileFoundError(pattern, path)
        return matches[0]

    return {key: find(pattern) for key, pattern in patterns.items()}