File size: 2,230 Bytes
96cc2fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from __future__ import annotations

import ctypes
import json
from pathlib import Path


# Repository root: the fourth ancestor of this script's resolved path
# (parents[0] is the script's own directory).
ROOT = Path(__file__).resolve().parents[3]
# Release build of the Rust runtime library under that root.
DLL = ROOT.joinpath("runtime_rust_dll", "target", "release", "acestep_runtime.dll")

# Loaded at import time; ctypes raises OSError if the library cannot be loaded.
lib = ctypes.CDLL(str(DLL))

# Temporary aliases used only while declaring signatures (deleted below so the
# module namespace is unchanged).
_c_float_p = ctypes.POINTER(ctypes.c_float)
_create = lib.ace_create_context
_prepare = lib.ace_prepare_step_inputs
_step = lib.ace_scheduler_step

# Context lifecycle: a UTF-8 JSON config goes in, an opaque handle comes out.
_create.argtypes = [ctypes.c_char_p]
_create.restype = ctypes.c_void_p
# restype of the two release functions is left at the ctypes default.
lib.ace_free_context.argtypes = [ctypes.c_void_p]
lib.ace_string_free.argtypes = [ctypes.c_void_p]

# Signature: (ctx, state_json, floats, count, out_string_ptr) -> int32 status.
_prepare.argtypes = [
    ctypes.c_void_p,                  # context handle
    ctypes.c_char_p,                  # JSON-encoded state
    _c_float_p,                       # input float buffer
    ctypes.c_size_t,                  # length of the input buffer
    ctypes.POINTER(ctypes.c_void_p),  # receives a string pointer
]
_prepare.restype = ctypes.c_int32

# Signature: (ctx, xt, vt, count, scalar, out) -> int32 status.
_step.argtypes = [
    ctypes.c_void_p,  # context handle
    _c_float_p,       # xt buffer
    _c_float_p,       # vt buffer
    ctypes.c_size_t,  # buffer length (main passes the array size)
    ctypes.c_float,   # scalar argument
    _c_float_p,       # output buffer
]
_step.restype = ctypes.c_int32

del _c_float_p, _create, _prepare, _step


def cstr(ptr: int) -> str:
    return ctypes.cast(ptr, ctypes.c_char_p).value.decode("utf-8")


def main() -> int:
    """Exercise the runtime library once end to end and print the results.

    Creates a context from a small JSON config and calls
    ``ace_prepare_step_inputs`` on a 4-element float buffer. It then calls
    ``ace_scheduler_step`` on two 4-element buffers and prints each result.
    The context is always freed before returning, and the string returned by
    the prepare call is always released once received.

    Returns:
        0, used by the ``__main__`` guard as the process exit status.

    Raises:
        RuntimeError: if the context cannot be created, if either call returns
            a non-zero status, or if the prepare call reports success without
            producing an output string.
    """
    cfg = {"seed": 42, "blocked_token_ids": [0]}
    ctx = lib.ace_create_context(json.dumps(cfg).encode("utf-8"))
    if not ctx:
        raise RuntimeError("failed to create context")

    try:
        state = {"shift": 3.0, "inference_steps": 8, "current_step": 0}
        in_buf = (ctypes.c_float * 4)(1.0, 2.0, 3.0, 4.0)
        out_json = ctypes.c_void_p()  # starts NULL; only the library sets it
        rc = lib.ace_prepare_step_inputs(
            ctx,
            json.dumps(state).encode("utf-8"),
            in_buf,
            len(in_buf),  # derived so it cannot drift from the buffer size
            ctypes.byref(out_json),
        )
        try:
            if rc != 0:
                raise RuntimeError(f"ace_prepare_step_inputs failed (rc={rc})")
            if not out_json.value:
                raise RuntimeError("ace_prepare_step_inputs returned a NULL string")
            print("prepare:", cstr(out_json.value))
        finally:
            # Release the library-allocated string even if decoding or
            # printing raised; skip NULL so nothing is freed twice or bogusly.
            if out_json.value:
                lib.ace_string_free(out_json)

        xt = (ctypes.c_float * 4)(1.0, 1.0, 1.0, 1.0)
        vt = (ctypes.c_float * 4)(0.1, 0.2, 0.3, 0.4)
        out = (ctypes.c_float * len(xt))()
        rc = lib.ace_scheduler_step(ctx, xt, vt, len(xt), ctypes.c_float(0.5), out)
        if rc != 0:
            raise RuntimeError(f"ace_scheduler_step failed (rc={rc})")
        print("scheduler:", list(out))
    finally:
        lib.ace_free_context(ctx)
    return 0


if __name__ == "__main__":
    status = main()
    # SystemExit carries main()'s return value out as the process exit code.
    raise SystemExit(status)