# LiteASR-ONNX-DLL / ffi_python / liteasr_ffi.py
# Uploaded by zukky using huggingface_hub (commit 27a58dc, verified).
import ctypes
from ctypes import POINTER, Structure, c_bool, c_char_p, c_float, c_int32, c_int64, c_size_t, c_uint8, c_uint32
from pathlib import Path
import numpy as np
class LiteAsrError(RuntimeError):
    """Error raised by this module when the DLL is missing or a native call fails."""
class _F32Buffer(ctypes.Structure):
    """ctypes layout of the float32 output buffer filled by the native library.

    ``data``/``len``/``cap`` describe the raw allocation; ``rows``/``cols``
    carry the 2-D shape that callers reshape the flat data into.
    """

    # Field order and types define the in-memory layout shared with the DLL;
    # do not reorder.
    _fields_ = [
        ("data", ctypes.POINTER(ctypes.c_float)),
        ("len", ctypes.c_size_t),
        ("cap", ctypes.c_size_t),
        ("rows", ctypes.c_size_t),
        ("cols", ctypes.c_size_t),
    ]
class _I64Buffer(ctypes.Structure):
    """ctypes layout of the int64 output buffer filled by the native library."""

    # Field order and types define the in-memory layout shared with the DLL;
    # do not reorder.
    _fields_ = [
        ("data", ctypes.POINTER(ctypes.c_int64)),
        ("len", ctypes.c_size_t),
        ("cap", ctypes.c_size_t),
    ]
class _Utf8Buffer(ctypes.Structure):
    """ctypes layout of the byte output buffer filled by the native library.

    Callers in this module read ``len`` bytes from ``data`` and decode them
    as UTF-8.
    """

    # Field order and types define the in-memory layout shared with the DLL;
    # do not reorder.
    _fields_ = [
        ("data", ctypes.POINTER(ctypes.c_uint8)),
        ("len", ctypes.c_size_t),
        ("cap", ctypes.c_size_t),
    ]
def _default_dll_candidates() -> list[Path]:
    """Return the relative DLL locations probed when no explicit path is given.

    The debug build is listed before the release build, so it wins when both
    exist. Paths are relative to the current working directory.
    """
    locations = (
        "target/debug/liteasr_ffi.dll",
        "target/release/liteasr_ffi.dll",
    )
    return [Path(location) for location in locations]
class LiteAsrFfi:
def __init__(self, dll_path: str | Path | None = None) -> None:
if dll_path is None:
for candidate in _default_dll_candidates():
if candidate.exists():
dll_path = candidate
break
if dll_path is None:
raise LiteAsrError("liteasr_ffi.dll not found. Build with `cargo build`.")
self._dll_path = Path(dll_path)
if not self._dll_path.exists():
raise LiteAsrError(f"DLL not found: {self._dll_path}")
self._lib = ctypes.CDLL(str(self._dll_path))
self._configure_signatures()
@property
def dll_path(self) -> Path:
return self._dll_path
def _configure_signatures(self) -> None:
self._lib.liteasr_last_error_message.argtypes = []
self._lib.liteasr_last_error_message.restype = c_char_p
self._lib.liteasr_preprocess_wav.argtypes = [
c_char_p,
c_uint32,
c_size_t,
POINTER(_F32Buffer),
]
self._lib.liteasr_preprocess_wav.restype = c_int32
self._lib.liteasr_build_prompt_ids.argtypes = [
c_char_p,
c_char_p,
c_char_p,
c_bool,
c_bool,
c_bool,
POINTER(_I64Buffer),
]
self._lib.liteasr_build_prompt_ids.restype = c_int32
self._lib.liteasr_decode_tokens.argtypes = [
c_char_p,
POINTER(c_int64),
c_size_t,
c_bool,
POINTER(_Utf8Buffer),
]
self._lib.liteasr_decode_tokens.restype = c_int32
self._lib.liteasr_apply_suppression.argtypes = [
POINTER(c_float),
c_size_t,
POINTER(c_int64),
c_size_t,
POINTER(c_int64),
c_size_t,
c_size_t,
]
self._lib.liteasr_apply_suppression.restype = c_int32
self._lib.liteasr_free_f32_buffer.argtypes = [POINTER(_F32Buffer)]
self._lib.liteasr_free_f32_buffer.restype = None
self._lib.liteasr_free_i64_buffer.argtypes = [POINTER(_I64Buffer)]
self._lib.liteasr_free_i64_buffer.restype = None
self._lib.liteasr_free_utf8_buffer.argtypes = [POINTER(_Utf8Buffer)]
self._lib.liteasr_free_utf8_buffer.restype = None
def _raise_last_error(self, fallback: str) -> None:
message = self._lib.liteasr_last_error_message()
if message:
text = message.decode("utf-8", errors="replace")
raise LiteAsrError(text)
raise LiteAsrError(fallback)
def preprocess_wav(self, wav_path: str | Path, target_sr: int, n_mels: int) -> np.ndarray:
out = _F32Buffer()
rc = self._lib.liteasr_preprocess_wav(
str(Path(wav_path)).encode("utf-8"),
int(target_sr),
int(n_mels),
ctypes.byref(out),
)
if rc != 0:
self._raise_last_error("liteasr_preprocess_wav failed")
try:
flat = np.ctypeslib.as_array(out.data, shape=(out.len,))
matrix = flat.reshape((out.rows, out.cols)).copy()
return matrix
finally:
self._lib.liteasr_free_f32_buffer(ctypes.byref(out))
def build_prompt_ids(
self,
tokenizer_json_path: str | Path,
language: str,
task: str,
with_timestamps: bool,
omit_language_token: bool,
omit_notimestamps_token: bool,
) -> list[int]:
out = _I64Buffer()
rc = self._lib.liteasr_build_prompt_ids(
str(Path(tokenizer_json_path)).encode("utf-8"),
language.encode("utf-8"),
task.encode("utf-8"),
bool(with_timestamps),
bool(omit_language_token),
bool(omit_notimestamps_token),
ctypes.byref(out),
)
if rc != 0:
self._raise_last_error("liteasr_build_prompt_ids failed")
try:
arr = np.ctypeslib.as_array(out.data, shape=(out.len,))
return [int(v) for v in arr.tolist()]
finally:
self._lib.liteasr_free_i64_buffer(ctypes.byref(out))
def decode_tokens(
self,
tokenizer_json_path: str | Path,
token_ids: list[int],
skip_special_tokens: bool = True,
) -> str:
out = _Utf8Buffer()
token_np = np.array(token_ids, dtype=np.int64)
token_ptr = token_np.ctypes.data_as(POINTER(c_int64))
rc = self._lib.liteasr_decode_tokens(
str(Path(tokenizer_json_path)).encode("utf-8"),
token_ptr,
int(token_np.shape[0]),
bool(skip_special_tokens),
ctypes.byref(out),
)
if rc != 0:
self._raise_last_error("liteasr_decode_tokens failed")
try:
data = ctypes.string_at(out.data, out.len)
return data.decode("utf-8", errors="replace")
finally:
self._lib.liteasr_free_utf8_buffer(ctypes.byref(out))
def apply_suppression(
self,
logits: np.ndarray,
suppress_ids: list[int],
begin_suppress_ids: list[int],
step: int,
) -> np.ndarray:
if logits.dtype != np.float32 or not logits.flags["C_CONTIGUOUS"]:
logits = np.ascontiguousarray(logits, dtype=np.float32)
suppress_np = np.array(suppress_ids or [], dtype=np.int64)
begin_np = np.array(begin_suppress_ids or [], dtype=np.int64)
suppress_ptr = (
suppress_np.ctypes.data_as(POINTER(c_int64))
if suppress_np.size > 0
else ctypes.cast(0, POINTER(c_int64))
)
begin_ptr = (
begin_np.ctypes.data_as(POINTER(c_int64))
if begin_np.size > 0
else ctypes.cast(0, POINTER(c_int64))
)
rc = self._lib.liteasr_apply_suppression(
logits.ctypes.data_as(POINTER(c_float)),
int(logits.shape[0]),
suppress_ptr,
int(suppress_np.shape[0]),
begin_ptr,
int(begin_np.shape[0]),
int(step),
)
if rc != 0:
self._raise_last_error("liteasr_apply_suppression failed")
return logits