import ctypes from ctypes import POINTER, Structure, c_bool, c_char_p, c_float, c_int32, c_int64, c_size_t, c_uint8, c_uint32 from pathlib import Path import numpy as np class LiteAsrError(RuntimeError): pass class _F32Buffer(Structure): _fields_ = [ ("data", POINTER(c_float)), ("len", c_size_t), ("cap", c_size_t), ("rows", c_size_t), ("cols", c_size_t), ] class _I64Buffer(Structure): _fields_ = [ ("data", POINTER(c_int64)), ("len", c_size_t), ("cap", c_size_t), ] class _Utf8Buffer(Structure): _fields_ = [ ("data", POINTER(c_uint8)), ("len", c_size_t), ("cap", c_size_t), ] def _default_dll_candidates() -> list[Path]: return [ Path("target/debug/liteasr_ffi.dll"), Path("target/release/liteasr_ffi.dll"), ] class LiteAsrFfi: def __init__(self, dll_path: str | Path | None = None) -> None: if dll_path is None: for candidate in _default_dll_candidates(): if candidate.exists(): dll_path = candidate break if dll_path is None: raise LiteAsrError("liteasr_ffi.dll not found. Build with `cargo build`.") self._dll_path = Path(dll_path) if not self._dll_path.exists(): raise LiteAsrError(f"DLL not found: {self._dll_path}") self._lib = ctypes.CDLL(str(self._dll_path)) self._configure_signatures() @property def dll_path(self) -> Path: return self._dll_path def _configure_signatures(self) -> None: self._lib.liteasr_last_error_message.argtypes = [] self._lib.liteasr_last_error_message.restype = c_char_p self._lib.liteasr_preprocess_wav.argtypes = [ c_char_p, c_uint32, c_size_t, POINTER(_F32Buffer), ] self._lib.liteasr_preprocess_wav.restype = c_int32 self._lib.liteasr_build_prompt_ids.argtypes = [ c_char_p, c_char_p, c_char_p, c_bool, c_bool, c_bool, POINTER(_I64Buffer), ] self._lib.liteasr_build_prompt_ids.restype = c_int32 self._lib.liteasr_decode_tokens.argtypes = [ c_char_p, POINTER(c_int64), c_size_t, c_bool, POINTER(_Utf8Buffer), ] self._lib.liteasr_decode_tokens.restype = c_int32 self._lib.liteasr_apply_suppression.argtypes = [ POINTER(c_float), c_size_t, POINTER(c_int64), c_size_t, POINTER(c_int64), c_size_t, c_size_t, ] self._lib.liteasr_apply_suppression.restype = c_int32 self._lib.liteasr_free_f32_buffer.argtypes = [POINTER(_F32Buffer)] self._lib.liteasr_free_f32_buffer.restype = None self._lib.liteasr_free_i64_buffer.argtypes = [POINTER(_I64Buffer)] self._lib.liteasr_free_i64_buffer.restype = None self._lib.liteasr_free_utf8_buffer.argtypes = [POINTER(_Utf8Buffer)] self._lib.liteasr_free_utf8_buffer.restype = None def _raise_last_error(self, fallback: str) -> None: message = self._lib.liteasr_last_error_message() if message: text = message.decode("utf-8", errors="replace") raise LiteAsrError(text) raise LiteAsrError(fallback) def preprocess_wav(self, wav_path: str | Path, target_sr: int, n_mels: int) -> np.ndarray: out = _F32Buffer() rc = self._lib.liteasr_preprocess_wav( str(Path(wav_path)).encode("utf-8"), int(target_sr), int(n_mels), ctypes.byref(out), ) if rc != 0: self._raise_last_error("liteasr_preprocess_wav failed") try: flat = np.ctypeslib.as_array(out.data, shape=(out.len,)) matrix = flat.reshape((out.rows, out.cols)).copy() return matrix finally: self._lib.liteasr_free_f32_buffer(ctypes.byref(out)) def build_prompt_ids( self, tokenizer_json_path: str | Path, language: str, task: str, with_timestamps: bool, omit_language_token: bool, omit_notimestamps_token: bool, ) -> list[int]: out = _I64Buffer() rc = self._lib.liteasr_build_prompt_ids( str(Path(tokenizer_json_path)).encode("utf-8"), language.encode("utf-8"), task.encode("utf-8"), bool(with_timestamps), bool(omit_language_token), bool(omit_notimestamps_token), ctypes.byref(out), ) if rc != 0: self._raise_last_error("liteasr_build_prompt_ids failed") try: arr = np.ctypeslib.as_array(out.data, shape=(out.len,)) return [int(v) for v in arr.tolist()] finally: self._lib.liteasr_free_i64_buffer(ctypes.byref(out)) def decode_tokens( self, tokenizer_json_path: str | Path, token_ids: list[int], skip_special_tokens: bool = True, ) -> str: out = _Utf8Buffer() token_np = np.array(token_ids, dtype=np.int64) token_ptr = token_np.ctypes.data_as(POINTER(c_int64)) rc = self._lib.liteasr_decode_tokens( str(Path(tokenizer_json_path)).encode("utf-8"), token_ptr, int(token_np.shape[0]), bool(skip_special_tokens), ctypes.byref(out), ) if rc != 0: self._raise_last_error("liteasr_decode_tokens failed") try: data = ctypes.string_at(out.data, out.len) return data.decode("utf-8", errors="replace") finally: self._lib.liteasr_free_utf8_buffer(ctypes.byref(out)) def apply_suppression( self, logits: np.ndarray, suppress_ids: list[int], begin_suppress_ids: list[int], step: int, ) -> np.ndarray: if logits.dtype != np.float32 or not logits.flags["C_CONTIGUOUS"]: logits = np.ascontiguousarray(logits, dtype=np.float32) suppress_np = np.array(suppress_ids or [], dtype=np.int64) begin_np = np.array(begin_suppress_ids or [], dtype=np.int64) suppress_ptr = ( suppress_np.ctypes.data_as(POINTER(c_int64)) if suppress_np.size > 0 else ctypes.cast(0, POINTER(c_int64)) ) begin_ptr = ( begin_np.ctypes.data_as(POINTER(c_int64)) if begin_np.size > 0 else ctypes.cast(0, POINTER(c_int64)) ) rc = self._lib.liteasr_apply_suppression( logits.ctypes.data_as(POINTER(c_float)), int(logits.shape[0]), suppress_ptr, int(suppress_np.shape[0]), begin_ptr, int(begin_np.shape[0]), int(step), ) if rc != 0: self._raise_last_error("liteasr_apply_suppression failed") return logits