File size: 5,293 Bytes
96cc2fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
using System;
using System.Runtime.InteropServices;
using System.Text;
using System.Text.Json;

/// <summary>
/// P/Invoke declarations for the C-ABI entry points exported by
/// <c>acestep_runtime.dll</c>. All entry points use the cdecl calling convention.
/// </summary>
/// <remarks>
/// Managed strings are marshalled explicitly as NUL-terminated UTF-8. The default
/// <c>DllImport</c> string marshalling uses the ANSI code page on Windows, which
/// would corrupt non-ASCII JSON; callers in this file decode strings coming back
/// from the library as UTF-8, so the outbound direction is made to match.
/// NOTE(review): the native side is assumed to expect UTF-8 input — confirm
/// against the runtime's header/ABI documentation.
/// </remarks>
internal static class Native
{
    // Single source of truth for the native library name used by every import.
    private const string Library = "acestep_runtime.dll";

    /// <summary>Creates a context from a JSON configuration string.</summary>
    /// <param name="config_json">Configuration document, passed as UTF-8.</param>
    /// <returns>
    /// An opaque context handle. Callers in this file treat <see cref="IntPtr.Zero"/>
    /// as failure and release a non-zero handle with <see cref="ace_free_context"/>.
    /// </returns>
    [DllImport(Library, CallingConvention = CallingConvention.Cdecl)]
    public static extern IntPtr ace_create_context(
        [MarshalAs(UnmanagedType.LPUTF8Str)] string config_json);

    /// <summary>Releases a handle obtained from <see cref="ace_create_context"/>.</summary>
    [DllImport(Library, CallingConvention = CallingConvention.Cdecl)]
    public static extern void ace_free_context(IntPtr ctx);

    /// <summary>Releases a string pointer that was handed out by the library.</summary>
    [DllImport(Library, CallingConvention = CallingConvention.Cdecl)]
    public static extern void ace_string_free(IntPtr ptr);

    /// <summary>
    /// Returns a pointer to the library's most recent error message, or
    /// <see cref="IntPtr.Zero"/> when there is none to report.
    /// </summary>
    [DllImport(Library, CallingConvention = CallingConvention.Cdecl)]
    public static extern IntPtr ace_last_error();

    /// <summary>Prepares the inputs for one step and returns a JSON result string.</summary>
    /// <param name="ctx">Context handle.</param>
    /// <param name="state_json">Step state document, passed as UTF-8.</param>
    /// <param name="in_tensor_ptr">Input tensor values.</param>
    /// <param name="in_tensor_len">Number of elements in <paramref name="in_tensor_ptr"/>.</param>
    /// <param name="out_json">
    /// Receives a pointer to a native string; callers in this file decode it as
    /// UTF-8 and release it with <see cref="ace_string_free"/>.
    /// </param>
    /// <returns>A status code; callers in this file treat 0 as success.</returns>
    [DllImport(Library, CallingConvention = CallingConvention.Cdecl)]
    public static extern int ace_prepare_step_inputs(
        IntPtr ctx,
        [MarshalAs(UnmanagedType.LPUTF8Str)] string state_json,
        [In] float[] in_tensor_ptr,
        UIntPtr in_tensor_len,
        out IntPtr out_json
    );

    /// <summary>Runs one scheduler step over <paramref name="len"/> elements.</summary>
    /// <param name="ctx">Context handle.</param>
    /// <param name="xt_ptr">Input buffer of <paramref name="len"/> floats.</param>
    /// <param name="vt_ptr">Input buffer of <paramref name="len"/> floats.</param>
    /// <param name="len">Element count shared by all three buffers.</param>
    /// <param name="dt">Step size passed through to the native routine.</param>
    /// <param name="out_xt_ptr">Caller-allocated buffer that receives the result.</param>
    /// <returns>A status code; callers in this file treat 0 as success.</returns>
    [DllImport(Library, CallingConvention = CallingConvention.Cdecl)]
    public static extern int ace_scheduler_step(
        IntPtr ctx,
        [In] float[] xt_ptr,
        [In] float[] vt_ptr,
        UIntPtr len,
        float dt,
        [Out] float[] out_xt_ptr
    );

    /// <summary>Applies the context's token constraints to a logits vector.</summary>
    /// <param name="ctx">Context handle.</param>
    /// <param name="logits_ptr">Input logits, <paramref name="vocab_size"/> floats.</param>
    /// <param name="vocab_size">Element count of both buffers.</param>
    /// <param name="out_masked_logits_ptr">Caller-allocated buffer that receives the result.</param>
    /// <returns>A status code; callers in this file treat 0 as success.</returns>
    [DllImport(Library, CallingConvention = CallingConvention.Cdecl)]
    public static extern int ace_apply_lm_constraints(
        IntPtr ctx,
        [In] float[] logits_ptr,
        UIntPtr vocab_size,
        [Out] float[] out_masked_logits_ptr
    );
}

/// <summary>
/// Console regression check that exercises the <c>acestep_runtime</c> native
/// entry points through <see cref="Native"/> and compares their outputs with
/// expected values fixed in this file.
/// </summary>
public class Program
{
    /// <summary>
    /// Decodes a native UTF-8 string and releases it with
    /// <see cref="Native.ace_string_free"/>, even if decoding throws.
    /// </summary>
    /// <param name="ptr">Pointer handed out by the native library; may be zero.</param>
    /// <param name="fallback">Value returned when there is no string to decode.</param>
    /// <returns>The decoded text, or <paramref name="fallback"/>.</returns>
    private static string TakeNativeString(IntPtr ptr, string fallback)
    {
        if (ptr == IntPtr.Zero)
        {
            // Nothing was allocated, so there is nothing to free either.
            return fallback;
        }
        try
        {
            return Marshal.PtrToStringUTF8(ptr) ?? fallback;
        }
        finally
        {
            Native.ace_string_free(ptr);
        }
    }

    /// <summary>
    /// Fetches the native library's last error message, or <c>"unknown"</c> when
    /// the library reports none.
    /// </summary>
    private static string LastError()
    {
        return TakeNativeString(Native.ace_last_error(), "unknown");
    }

    /// <summary>
    /// Absolute-difference comparison: true when |a - b| is at most <paramref name="eps"/>.
    /// A NaN operand always yields false.
    /// </summary>
    /// <remarks>
    /// The default tolerance is on the order of one float ULP for magnitudes near 1,
    /// so checks using it are close to bit-exact.
    /// </remarks>
    private static bool NearlyEqual(float a, float b, float eps = 1e-7f)
    {
        return Math.Abs(a - b) <= eps;
    }

    /// <summary>
    /// Runs the three checks in order (step-input preparation, scheduler step,
    /// LM constraint masking) and stops at the first failure.
    /// </summary>
    /// <param name="args">Unused.</param>
    /// <returns>
    /// 0 on success; 1 if the context could not be created; 2/3 for a failed call
    /// or mismatching result of <c>ace_prepare_step_inputs</c>; 4/5 likewise for
    /// <c>ace_scheduler_step</c>; 6 for a failed <c>ace_apply_lm_constraints</c>
    /// call, 7 if the forced token's logit changed, 8 if another token was not masked.
    /// </returns>
    public static int Main(string[] args)
    {
        IntPtr ctx = Native.ace_create_context("{\"seed\":42,\"blocked_token_ids\":[1,3],\"forced_token_id\":2}");
        if (ctx == IntPtr.Zero)
        {
            Console.WriteLine($"create_context failed: {LastError()}");
            return 1;
        }
        try
        {
            // --- 1. Step-input preparation -------------------------------------
            float[] inTensor = { 1f, 2f, 3f, 4f };
            int prepRc = Native.ace_prepare_step_inputs(
                ctx,
                "{\"shift\":3.0,\"inference_steps\":8,\"current_step\":0}",
                inTensor,
                (UIntPtr)inTensor.Length,
                out IntPtr outJson
            );
            if (prepRc != 0)
            {
                Console.WriteLine($"ace_prepare_step_inputs failed: {LastError()}");
                return 2;
            }
            // The returned string is owned by the native side: copy it, then free it.
            string payload = TakeNativeString(outJson, "{}");
            using JsonDocument doc = JsonDocument.Parse(payload);
            float timestep = doc.RootElement.GetProperty("timestep").GetSingle();
            float nextTimestep = doc.RootElement.GetProperty("next_timestep").GetSingle();
            if (!NearlyEqual(timestep, 1.0f) || !NearlyEqual(nextTimestep, 0.669921875f, 1e-6f))
            {
                Console.WriteLine($"prepare mismatch: t={timestep}, next={nextTimestep}");
                return 3;
            }

            // --- 2. Scheduler step ---------------------------------------------
            float[] xt = { 1f, 1f, 1f, 1f };
            float[] vt = { 0.1f, 0.2f, 0.3f, 0.4f };
            // Size the output and the length argument from the input so the three
            // cannot drift apart if the fixture is edited.
            float[] outXt = new float[xt.Length];
            int rc = Native.ace_scheduler_step(ctx, xt, vt, (UIntPtr)xt.Length, 0.5f, outXt);
            if (rc != 0)
            {
                Console.WriteLine($"ace_scheduler_step failed: {LastError()}");
                return 4;
            }
            float[] expectedXt = { 0.95f, 0.9f, 0.85f, 0.8f };
            for (int i = 0; i < expectedXt.Length; i++)
            {
                if (!NearlyEqual(outXt[i], expectedXt[i]))
                {
                    Console.WriteLine($"scheduler mismatch at {i}: got={outXt[i]} expected={expectedXt[i]}");
                    return 5;
                }
            }

            // --- 3. LM constraint masking --------------------------------------
            // Must match "forced_token_id" in the context configuration above.
            const int forcedTokenId = 2;
            float[] logits = { 0f, 1f, 2f, 3f, 4f };
            float[] masked = new float[logits.Length];
            int lmRc = Native.ace_apply_lm_constraints(ctx, logits, (UIntPtr)logits.Length, masked);
            if (lmRc != 0)
            {
                Console.WriteLine($"ace_apply_lm_constraints failed: {LastError()}");
                return 6;
            }
            // The forced token is expected to keep its original logit.
            if (!NearlyEqual(masked[forcedTokenId], logits[forcedTokenId]))
            {
                Console.WriteLine($"forced token mismatch: {masked[forcedTokenId]}");
                return 7;
            }
            for (int i = 0; i < masked.Length; i++)
            {
                if (i == forcedTokenId)
                {
                    continue;
                }
                // Written as a negated "<=" so that NaN is reported as a mismatch;
                // "masked[i] > threshold" is false for NaN and would let it through.
                if (!(masked[i] <= -1e29f))
                {
                    Console.WriteLine($"mask mismatch at {i}: {masked[i]}");
                    return 8;
                }
            }
            Console.WriteLine("csharp ffi regression: PASS");
            return 0;
        }
        finally
        {
            // Runs on every exit path above, including the early returns.
            Native.ace_free_context(ctx);
        }
    }
}