"""Regression check for the ``acestep_runtime`` DLL's C ABI, driven via ctypes.

The script loads the DLL, creates a context from a small JSON config and then
checks three entry points against hard-coded expectations:
``ace_prepare_step_inputs``, ``ace_scheduler_step`` and
``ace_apply_lm_constraints``.

Set the ``ACESTEP_RUNTIME_DLL`` environment variable to test a DLL that lives
somewhere other than the default release build location.
"""

from __future__ import annotations

import ctypes
import json
import math
import os
from pathlib import Path


def _repo_root(script: Path, levels_up: int = 3) -> Path:
    """Return the ancestor ``levels_up`` directories above *script*'s folder.

    The index is clamped to the top-most ancestor so that a script placed in a
    shallow directory can still be imported (and can still honour the
    ``ACESTEP_RUNTIME_DLL`` override) instead of dying with ``IndexError``.
    """
    resolved = script.resolve()
    ancestors = resolved.parents
    if not ancestors:
        return resolved
    return ancestors[min(levels_up, len(ancestors) - 1)]


ROOT = _repo_root(Path(__file__))
DEFAULT_DLL = ROOT / "runtime_rust_dll" / "target" / "release" / "acestep_runtime.dll"
DLL_PATH = Path(os.environ.get("ACESTEP_RUNTIME_DLL", str(DEFAULT_DLL)))


def _expect(cond: bool, msg: str) -> None:
    """Raise ``RuntimeError(msg)`` unless *cond* is true.

    Used instead of ``assert`` so the checks still run under ``python -O``.
    """
    if not cond:
        raise RuntimeError(msg)


def _decode_json_ptr(lib: ctypes.CDLL, ptr: int | None) -> dict:
    """Decode the UTF-8 JSON C string at *ptr* and release it.

    Args:
        lib: Loaded DLL; its ``ace_string_free`` is called on *ptr*.
        ptr: Raw address of a NUL-terminated string returned by the DLL.

    Returns:
        The parsed JSON document.

    Raises:
        RuntimeError: If *ptr* is null.
    """
    # A null pointer owns nothing; fail before the ``finally`` below so that
    # null is never handed to ``ace_string_free``.
    _expect(bool(ptr), "null json pointer")
    try:
        raw = ctypes.cast(ptr, ctypes.c_char_p).value
        _expect(raw is not None, "null json pointer")
        return json.loads(raw.decode("utf-8"))
    finally:
        lib.ace_string_free(ptr)


def _last_error(lib: ctypes.CDLL) -> str:
    """Return the DLL's last error message, or ``"unknown"`` if there is none.

    A non-null pointer from ``ace_last_error`` is released with
    ``ace_string_free`` after it has been copied into a Python string.
    NOTE(review): this assumes ``ace_last_error`` hands ownership of the
    string to the caller -- confirm against the DLL's contract.
    """
    ptr = lib.ace_last_error()
    if not ptr:
        return "unknown"
    try:
        raw = ctypes.cast(ptr, ctypes.c_char_p).value
        return (raw or b"unknown").decode("utf-8", errors="replace")
    finally:
        lib.ace_string_free(ptr)


def _expect_ok(lib: ctypes.CDLL, rc: int, call: str) -> None:
    """Raise ``RuntimeError`` when *rc* is non-zero.

    The error text is only fetched on failure, so a successful call does not
    trigger an extra ``ace_last_error`` / ``ace_string_free`` round trip.
    """
    if rc != 0:
        raise RuntimeError(f"{call} failed: {_last_error(lib)}")


def _bind_signatures(lib: ctypes.CDLL) -> None:
    """Declare ``argtypes`` / ``restype`` for every DLL function used here."""
    float_p = ctypes.POINTER(ctypes.c_float)

    lib.ace_create_context.argtypes = [ctypes.c_char_p]
    lib.ace_create_context.restype = ctypes.c_void_p
    lib.ace_free_context.argtypes = [ctypes.c_void_p]
    lib.ace_free_context.restype = None

    # Strings come back as ``c_void_p`` rather than ``c_char_p``: ctypes would
    # convert a ``c_char_p`` result to ``bytes`` and drop the address that
    # ``ace_string_free`` needs.
    lib.ace_string_free.argtypes = [ctypes.c_void_p]
    lib.ace_string_free.restype = None
    lib.ace_last_error.argtypes = []
    lib.ace_last_error.restype = ctypes.c_void_p

    lib.ace_prepare_step_inputs.argtypes = [
        ctypes.c_void_p,
        ctypes.c_char_p,
        float_p,
        ctypes.c_size_t,
        ctypes.POINTER(ctypes.c_void_p),
    ]
    lib.ace_prepare_step_inputs.restype = ctypes.c_int32

    lib.ace_scheduler_step.argtypes = [
        ctypes.c_void_p,
        float_p,
        float_p,
        ctypes.c_size_t,
        ctypes.c_float,
        float_p,
    ]
    lib.ace_scheduler_step.restype = ctypes.c_int32

    lib.ace_apply_lm_constraints.argtypes = [
        ctypes.c_void_p,
        float_p,
        ctypes.c_size_t,
        float_p,
    ]
    lib.ace_apply_lm_constraints.restype = ctypes.c_int32


def _check_prepare_step_inputs(lib: ctypes.CDLL, ctx: int) -> None:
    """Check the JSON payload produced by ``ace_prepare_step_inputs``."""
    state = {"shift": 3.0, "inference_steps": 8, "current_step": 0}
    in_buf = (ctypes.c_float * 4)(1.0, 2.0, 3.0, 4.0)
    out_json = ctypes.c_void_p()
    rc = lib.ace_prepare_step_inputs(
        ctx,
        json.dumps(state).encode("utf-8"),
        in_buf,
        4,
        ctypes.byref(out_json),
    )
    _expect_ok(lib, rc, "ace_prepare_step_inputs")

    payload = _decode_json_ptr(lib, out_json.value)
    _expect(payload["seed"] == 42, "seed mismatch")
    _expect(payload["inference_steps"] == 8, "inference_steps mismatch")
    _expect(abs(payload["timestep"] - 1.0) < 1e-7, "timestep mismatch")
    # NOTE(review): 0.875 is 1 - 1/8 and the exponent equals ``shift``;
    # presumably that is the schedule formula -- confirm against the DLL.
    _expect(
        abs(payload["next_timestep"] - (0.875 ** 3)) < 1e-7,
        "next_timestep mismatch",
    )


def _check_scheduler_step(lib: ctypes.CDLL, ctx: int) -> None:
    """Check ``ace_scheduler_step`` against precomputed outputs."""
    xt = (ctypes.c_float * 4)(1.0, 1.0, 1.0, 1.0)
    vt = (ctypes.c_float * 4)(0.1, 0.2, 0.3, 0.4)
    out = (ctypes.c_float * 4)()
    rc = lib.ace_scheduler_step(ctx, xt, vt, 4, ctypes.c_float(0.5), out)
    _expect_ok(lib, rc, "ace_scheduler_step")

    # The expected values equal xt - 0.5 * vt for the inputs above.
    expected = [0.95, 0.9, 0.85, 0.8]
    for got, exp in zip(out, expected):
        _expect(
            math.isclose(got, exp, rel_tol=0.0, abs_tol=1e-7),
            f"scheduler mismatch: got={got}, exp={exp}",
        )


def _check_lm_constraints(lib: ctypes.CDLL, ctx: int) -> None:
    """Check that only token 2 survives ``ace_apply_lm_constraints``.

    The context is created with ``forced_token_id`` 2, so index 2 must keep
    its logit and every other index must be pushed below -1e29.
    """
    logits = (ctypes.c_float * 5)(0.0, 1.0, 2.0, 3.0, 4.0)
    masked = (ctypes.c_float * 5)()
    rc = lib.ace_apply_lm_constraints(ctx, logits, 5, masked)
    _expect_ok(lib, rc, "ace_apply_lm_constraints")

    _expect(abs(masked[2] - 2.0) < 1e-7, "forced token mismatch")
    for i, value in enumerate(masked):
        if i != 2:
            _expect(value < -1e29, f"token {i} should be masked, got={value}")


def main() -> int:
    """Run the FFI regression checks against the DLL at ``DLL_PATH``.

    Returns:
        ``0`` when every check passes.

    Raises:
        RuntimeError: If any check fails or the DLL reports an error.
        OSError: If ``ctypes`` cannot load the DLL.
    """
    lib = ctypes.CDLL(str(DLL_PATH))
    _bind_signatures(lib)

    cfg = {"seed": 42, "blocked_token_ids": [1, 3], "forced_token_id": 2}
    ctx = lib.ace_create_context(json.dumps(cfg).encode("utf-8"))
    _expect(bool(ctx), "ace_create_context failed")
    try:
        _check_prepare_step_inputs(lib, ctx)
        _check_scheduler_step(lib, ctx)
        _check_lm_constraints(lib, ctx)
    finally:
        # Free the context even when a check raises.
        lib.ace_free_context(ctx)

    print("python ffi regression: PASS")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())