from __future__ import annotations import ctypes import json from pathlib import Path ROOT = Path(__file__).resolve().parents[3] DLL = ROOT / "runtime_rust_dll" / "target" / "release" / "acestep_runtime.dll" lib = ctypes.CDLL(str(DLL)) lib.ace_create_context.argtypes = [ctypes.c_char_p] lib.ace_create_context.restype = ctypes.c_void_p lib.ace_free_context.argtypes = [ctypes.c_void_p] lib.ace_string_free.argtypes = [ctypes.c_void_p] lib.ace_prepare_step_inputs.argtypes = [ ctypes.c_void_p, ctypes.c_char_p, ctypes.POINTER(ctypes.c_float), ctypes.c_size_t, ctypes.POINTER(ctypes.c_void_p), ] lib.ace_prepare_step_inputs.restype = ctypes.c_int32 lib.ace_scheduler_step.argtypes = [ ctypes.c_void_p, ctypes.POINTER(ctypes.c_float), ctypes.POINTER(ctypes.c_float), ctypes.c_size_t, ctypes.c_float, ctypes.POINTER(ctypes.c_float), ] lib.ace_scheduler_step.restype = ctypes.c_int32 def cstr(ptr: int) -> str: return ctypes.cast(ptr, ctypes.c_char_p).value.decode("utf-8") def main() -> int: cfg = {"seed": 42, "blocked_token_ids": [0]} ctx = lib.ace_create_context(json.dumps(cfg).encode("utf-8")) if not ctx: raise RuntimeError("failed to create context") try: state = {"shift": 3.0, "inference_steps": 8, "current_step": 0} in_buf = (ctypes.c_float * 4)(1.0, 2.0, 3.0, 4.0) out_json = ctypes.c_void_p() rc = lib.ace_prepare_step_inputs( ctx, json.dumps(state).encode("utf-8"), in_buf, 4, ctypes.byref(out_json), ) if rc != 0: raise RuntimeError("ace_prepare_step_inputs failed") print("prepare:", cstr(out_json.value)) lib.ace_string_free(out_json) xt = (ctypes.c_float * 4)(1.0, 1.0, 1.0, 1.0) vt = (ctypes.c_float * 4)(0.1, 0.2, 0.3, 0.4) out = (ctypes.c_float * 4)() rc = lib.ace_scheduler_step(ctx, xt, vt, 4, ctypes.c_float(0.5), out) if rc != 0: raise RuntimeError("ace_scheduler_step failed") print("scheduler:", list(out)) finally: lib.ace_free_context(ctx) return 0 if __name__ == "__main__": raise SystemExit(main())