|
|
import ctypes |
|
|
from ctypes import POINTER, Structure, c_bool, c_char_p, c_float, c_int32, c_int64, c_size_t, c_uint8, c_uint32 |
|
|
from pathlib import Path |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
|
class LiteAsrError(RuntimeError):
    """Error raised by the LiteASR FFI bindings.

    Raised when the native library file cannot be located, or when a native
    call reports failure through a non-zero return code.
    """
|
|
|
|
|
|
|
|
class _F32Buffer(Structure):
    """ctypes view of the float32 output buffer filled by the native library.

    The Python side reads ``len`` floats starting at ``data`` and reshapes
    them to ``(rows, cols)``; the buffer is handed back to the library with
    ``liteasr_free_f32_buffer`` once it has been copied.
    """

    # NOTE(review): field order and types presumably mirror the native struct
    # definition exactly -- confirm against the library before changing them.
    _fields_ = [
        ("data", POINTER(c_float)),  # first element of the flat float32 data
        ("len", c_size_t),  # number of float32 elements readable from `data`
        # Not read by this module; presumably the allocation capacity the
        # library needs in order to free the buffer -- confirm.
        ("cap", c_size_t),
        ("rows", c_size_t),  # first dimension used when reshaping the data
        ("cols", c_size_t),  # second dimension used when reshaping the data
    ]
|
|
|
|
|
|
|
|
class _I64Buffer(Structure):
    """ctypes view of the int64 output buffer filled by the native library.

    The Python side reads ``len`` int64 values starting at ``data`` and then
    returns the buffer to the library with ``liteasr_free_i64_buffer``.
    """

    # NOTE(review): field order and types presumably mirror the native struct
    # definition exactly -- confirm against the library before changing them.
    _fields_ = [
        ("data", POINTER(c_int64)),  # first element of the int64 data
        ("len", c_size_t),  # number of int64 elements readable from `data`
        # Not read by this module; presumably the allocation capacity the
        # library needs in order to free the buffer -- confirm.
        ("cap", c_size_t),
    ]
|
|
|
|
|
|
|
|
class _Utf8Buffer(Structure):
    """ctypes view of the byte output buffer filled by the native library.

    The Python side reads ``len`` bytes starting at ``data``, decodes them as
    UTF-8, and then returns the buffer to the library with
    ``liteasr_free_utf8_buffer``.
    """

    # NOTE(review): field order and types presumably mirror the native struct
    # definition exactly -- confirm against the library before changing them.
    _fields_ = [
        ("data", POINTER(c_uint8)),  # first byte of the encoded text
        ("len", c_size_t),  # number of bytes readable from `data`
        # Not read by this module; presumably the allocation capacity the
        # library needs in order to free the buffer -- confirm.
        ("cap", c_size_t),
    ]
|
|
|
|
|
|
|
|
def _default_dll_candidates() -> list[Path]:
    """Return the library locations probed when no explicit path is given.

    The paths are relative, so they are resolved against the current working
    directory by whoever checks them. The debug build is listed before the
    release build, so it wins when both exist.
    """
    return [
        Path("target/debug/liteasr_ffi.dll"),
        Path("target/release/liteasr_ffi.dll"),
    ]
|
|
|
|
|
|
|
|
class LiteAsrFfi: |
|
|
def __init__(self, dll_path: str | Path | None = None) -> None: |
|
|
if dll_path is None: |
|
|
for candidate in _default_dll_candidates(): |
|
|
if candidate.exists(): |
|
|
dll_path = candidate |
|
|
break |
|
|
if dll_path is None: |
|
|
raise LiteAsrError("liteasr_ffi.dll not found. Build with `cargo build`.") |
|
|
|
|
|
self._dll_path = Path(dll_path) |
|
|
if not self._dll_path.exists(): |
|
|
raise LiteAsrError(f"DLL not found: {self._dll_path}") |
|
|
|
|
|
self._lib = ctypes.CDLL(str(self._dll_path)) |
|
|
self._configure_signatures() |
|
|
|
|
|
@property |
|
|
def dll_path(self) -> Path: |
|
|
return self._dll_path |
|
|
|
|
|
def _configure_signatures(self) -> None: |
|
|
self._lib.liteasr_last_error_message.argtypes = [] |
|
|
self._lib.liteasr_last_error_message.restype = c_char_p |
|
|
|
|
|
self._lib.liteasr_preprocess_wav.argtypes = [ |
|
|
c_char_p, |
|
|
c_uint32, |
|
|
c_size_t, |
|
|
POINTER(_F32Buffer), |
|
|
] |
|
|
self._lib.liteasr_preprocess_wav.restype = c_int32 |
|
|
|
|
|
self._lib.liteasr_build_prompt_ids.argtypes = [ |
|
|
c_char_p, |
|
|
c_char_p, |
|
|
c_char_p, |
|
|
c_bool, |
|
|
c_bool, |
|
|
c_bool, |
|
|
POINTER(_I64Buffer), |
|
|
] |
|
|
self._lib.liteasr_build_prompt_ids.restype = c_int32 |
|
|
|
|
|
self._lib.liteasr_decode_tokens.argtypes = [ |
|
|
c_char_p, |
|
|
POINTER(c_int64), |
|
|
c_size_t, |
|
|
c_bool, |
|
|
POINTER(_Utf8Buffer), |
|
|
] |
|
|
self._lib.liteasr_decode_tokens.restype = c_int32 |
|
|
|
|
|
self._lib.liteasr_apply_suppression.argtypes = [ |
|
|
POINTER(c_float), |
|
|
c_size_t, |
|
|
POINTER(c_int64), |
|
|
c_size_t, |
|
|
POINTER(c_int64), |
|
|
c_size_t, |
|
|
c_size_t, |
|
|
] |
|
|
self._lib.liteasr_apply_suppression.restype = c_int32 |
|
|
|
|
|
self._lib.liteasr_free_f32_buffer.argtypes = [POINTER(_F32Buffer)] |
|
|
self._lib.liteasr_free_f32_buffer.restype = None |
|
|
self._lib.liteasr_free_i64_buffer.argtypes = [POINTER(_I64Buffer)] |
|
|
self._lib.liteasr_free_i64_buffer.restype = None |
|
|
self._lib.liteasr_free_utf8_buffer.argtypes = [POINTER(_Utf8Buffer)] |
|
|
self._lib.liteasr_free_utf8_buffer.restype = None |
|
|
|
|
|
def _raise_last_error(self, fallback: str) -> None: |
|
|
message = self._lib.liteasr_last_error_message() |
|
|
if message: |
|
|
text = message.decode("utf-8", errors="replace") |
|
|
raise LiteAsrError(text) |
|
|
raise LiteAsrError(fallback) |
|
|
|
|
|
def preprocess_wav(self, wav_path: str | Path, target_sr: int, n_mels: int) -> np.ndarray: |
|
|
out = _F32Buffer() |
|
|
rc = self._lib.liteasr_preprocess_wav( |
|
|
str(Path(wav_path)).encode("utf-8"), |
|
|
int(target_sr), |
|
|
int(n_mels), |
|
|
ctypes.byref(out), |
|
|
) |
|
|
if rc != 0: |
|
|
self._raise_last_error("liteasr_preprocess_wav failed") |
|
|
|
|
|
try: |
|
|
flat = np.ctypeslib.as_array(out.data, shape=(out.len,)) |
|
|
matrix = flat.reshape((out.rows, out.cols)).copy() |
|
|
return matrix |
|
|
finally: |
|
|
self._lib.liteasr_free_f32_buffer(ctypes.byref(out)) |
|
|
|
|
|
def build_prompt_ids( |
|
|
self, |
|
|
tokenizer_json_path: str | Path, |
|
|
language: str, |
|
|
task: str, |
|
|
with_timestamps: bool, |
|
|
omit_language_token: bool, |
|
|
omit_notimestamps_token: bool, |
|
|
) -> list[int]: |
|
|
out = _I64Buffer() |
|
|
rc = self._lib.liteasr_build_prompt_ids( |
|
|
str(Path(tokenizer_json_path)).encode("utf-8"), |
|
|
language.encode("utf-8"), |
|
|
task.encode("utf-8"), |
|
|
bool(with_timestamps), |
|
|
bool(omit_language_token), |
|
|
bool(omit_notimestamps_token), |
|
|
ctypes.byref(out), |
|
|
) |
|
|
if rc != 0: |
|
|
self._raise_last_error("liteasr_build_prompt_ids failed") |
|
|
|
|
|
try: |
|
|
arr = np.ctypeslib.as_array(out.data, shape=(out.len,)) |
|
|
return [int(v) for v in arr.tolist()] |
|
|
finally: |
|
|
self._lib.liteasr_free_i64_buffer(ctypes.byref(out)) |
|
|
|
|
|
def decode_tokens( |
|
|
self, |
|
|
tokenizer_json_path: str | Path, |
|
|
token_ids: list[int], |
|
|
skip_special_tokens: bool = True, |
|
|
) -> str: |
|
|
out = _Utf8Buffer() |
|
|
token_np = np.array(token_ids, dtype=np.int64) |
|
|
token_ptr = token_np.ctypes.data_as(POINTER(c_int64)) |
|
|
rc = self._lib.liteasr_decode_tokens( |
|
|
str(Path(tokenizer_json_path)).encode("utf-8"), |
|
|
token_ptr, |
|
|
int(token_np.shape[0]), |
|
|
bool(skip_special_tokens), |
|
|
ctypes.byref(out), |
|
|
) |
|
|
if rc != 0: |
|
|
self._raise_last_error("liteasr_decode_tokens failed") |
|
|
|
|
|
try: |
|
|
data = ctypes.string_at(out.data, out.len) |
|
|
return data.decode("utf-8", errors="replace") |
|
|
finally: |
|
|
self._lib.liteasr_free_utf8_buffer(ctypes.byref(out)) |
|
|
|
|
|
def apply_suppression( |
|
|
self, |
|
|
logits: np.ndarray, |
|
|
suppress_ids: list[int], |
|
|
begin_suppress_ids: list[int], |
|
|
step: int, |
|
|
) -> np.ndarray: |
|
|
if logits.dtype != np.float32 or not logits.flags["C_CONTIGUOUS"]: |
|
|
logits = np.ascontiguousarray(logits, dtype=np.float32) |
|
|
|
|
|
suppress_np = np.array(suppress_ids or [], dtype=np.int64) |
|
|
begin_np = np.array(begin_suppress_ids or [], dtype=np.int64) |
|
|
|
|
|
suppress_ptr = ( |
|
|
suppress_np.ctypes.data_as(POINTER(c_int64)) |
|
|
if suppress_np.size > 0 |
|
|
else ctypes.cast(0, POINTER(c_int64)) |
|
|
) |
|
|
begin_ptr = ( |
|
|
begin_np.ctypes.data_as(POINTER(c_int64)) |
|
|
if begin_np.size > 0 |
|
|
else ctypes.cast(0, POINTER(c_int64)) |
|
|
) |
|
|
|
|
|
rc = self._lib.liteasr_apply_suppression( |
|
|
logits.ctypes.data_as(POINTER(c_float)), |
|
|
int(logits.shape[0]), |
|
|
suppress_ptr, |
|
|
int(suppress_np.shape[0]), |
|
|
begin_ptr, |
|
|
int(begin_np.shape[0]), |
|
|
int(step), |
|
|
) |
|
|
if rc != 0: |
|
|
self._raise_last_error("liteasr_apply_suppression failed") |
|
|
return logits |
|
|
|