# Uploaded by zukky via huggingface_hub
# ("Upload folder using huggingface_hub", commit 96cc2fd, verified).
from __future__ import annotations

import ctypes
import json
import sys
from pathlib import Path
# Repository root: this script lives in a directory three levels below it.
ROOT = Path(__file__).resolve().parents[3]

# The runtime is presumably a Rust `cdylib` (cargo's target/release layout),
# whose file name follows each platform's convention. Windows keeps the
# original `acestep_runtime.dll`; other platforms get the `lib`-prefixed name.
if sys.platform == "win32":
    _LIB_FILENAME = "acestep_runtime.dll"
elif sys.platform == "darwin":
    _LIB_FILENAME = "libacestep_runtime.dylib"
else:
    _LIB_FILENAME = "libacestep_runtime.so"
DLL = ROOT / "runtime_rust_dll" / "target" / "release" / _LIB_FILENAME

# Loaded at import time; ctypes raises OSError if the library is missing.
lib = ctypes.CDLL(str(DLL))

# ace_create_context(config_json: const char*) -> opaque context pointer
# (main treats a NULL result as failure).
lib.ace_create_context.argtypes = [ctypes.c_char_p]
lib.ace_create_context.restype = ctypes.c_void_p

# The release functions' return values are never used, so declare them void
# instead of letting ctypes default to reading a C int.
lib.ace_free_context.argtypes = [ctypes.c_void_p]
lib.ace_free_context.restype = None
lib.ace_string_free.argtypes = [ctypes.c_void_p]
lib.ace_string_free.restype = None

# ace_prepare_step_inputs(ctx, state_json, floats, n_floats, out_json_ptr)
# -> int32 status; main treats 0 as success and reads a C string from
# out_json_ptr afterwards.
lib.ace_prepare_step_inputs.argtypes = [
    ctypes.c_void_p,
    ctypes.c_char_p,
    ctypes.POINTER(ctypes.c_float),
    ctypes.c_size_t,
    ctypes.POINTER(ctypes.c_void_p),
]
lib.ace_prepare_step_inputs.restype = ctypes.c_int32

# ace_scheduler_step(ctx, xt, vt, n, scalar, out) -> int32 status (0 = success).
# NOTE(review): the meaning of the float scalar (timestep? dt?) is not visible
# here — confirm against the runtime's header.
lib.ace_scheduler_step.argtypes = [
    ctypes.c_void_p,
    ctypes.POINTER(ctypes.c_float),
    ctypes.POINTER(ctypes.c_float),
    ctypes.c_size_t,
    ctypes.c_float,
    ctypes.POINTER(ctypes.c_float),
]
lib.ace_scheduler_step.restype = ctypes.c_int32
def cstr(ptr: int) -> str:
return ctypes.cast(ptr, ctypes.c_char_p).value.decode("utf-8")
def main() -> int:
    """Smoke-test the runtime: create a context, call both FFI entry points, print results.

    Returns:
        0 on success (used as the process exit status).

    Raises:
        RuntimeError: If the context cannot be created or either FFI call
            returns a non-zero status code.
    """
    cfg = {"seed": 42, "blocked_token_ids": [0]}
    ctx = lib.ace_create_context(json.dumps(cfg).encode("utf-8"))
    if not ctx:
        raise RuntimeError("failed to create context")
    try:
        state = {"shift": 3.0, "inference_steps": 8, "current_step": 0}
        in_buf = (ctypes.c_float * 4)(1.0, 2.0, 3.0, 4.0)
        out_json = ctypes.c_void_p()  # starts as NULL; the call may fill it
        rc = lib.ace_prepare_step_inputs(
            ctx,
            json.dumps(state).encode("utf-8"),
            in_buf,
            len(in_buf),
            ctypes.byref(out_json),
        )
        try:
            if rc != 0:
                raise RuntimeError(f"ace_prepare_step_inputs failed (rc={rc})")
            print("prepare:", cstr(out_json.value))
        finally:
            # Free the returned string even if decoding/printing raises.
            # NOTE(review): assumes any non-NULL pointer written to out_json is
            # owned by the caller, even on a failing rc — confirm with the runtime.
            if out_json.value:
                lib.ace_string_free(out_json)

        xt = (ctypes.c_float * 4)(1.0, 1.0, 1.0, 1.0)
        vt = (ctypes.c_float * 4)(0.1, 0.2, 0.3, 0.4)
        out = (ctypes.c_float * len(xt))()
        rc = lib.ace_scheduler_step(ctx, xt, vt, len(xt), ctypes.c_float(0.5), out)
        if rc != 0:
            raise RuntimeError(f"ace_scheduler_step failed (rc={rc})")
        print("scheduler:", list(out))
    finally:
        lib.ace_free_context(ctx)
    return 0
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == "__main__":
    exit_code = main()
    raise SystemExit(exit_code)