using System;
using System.Runtime.InteropServices;
using System.Text;
using System.Text.Json;

/// <summary>
/// P/Invoke declarations for the acestep runtime C ABI exercised by this regression program.
/// Strings passed in are marshalled as UTF-8 so they match how returned strings are decoded
/// (<see cref="Marshal.PtrToStringUTF8(IntPtr)"/>).
/// </summary>
internal static class Native {
    // Single source of truth for the native library name.
    // NOTE(review): on non-Windows hosts the .NET loader's probing may not locate a file named with
    // a ".dll" suffix — confirm the build output name, or consider the extension-less "acestep_runtime".
    private const string Lib = "acestep_runtime.dll";

    /// <summary>Creates a runtime context from a JSON config; returns IntPtr.Zero on failure.</summary>
    [DllImport(Lib, CallingConvention = CallingConvention.Cdecl)]
    public static extern IntPtr ace_create_context([MarshalAs(UnmanagedType.LPUTF8Str)] string config_json);

    /// <summary>Releases a context returned by <see cref="ace_create_context"/>.</summary>
    [DllImport(Lib, CallingConvention = CallingConvention.Cdecl)]
    public static extern void ace_free_context(IntPtr ctx);

    /// <summary>Releases a string allocated by the runtime.</summary>
    [DllImport(Lib, CallingConvention = CallingConvention.Cdecl)]
    public static extern void ace_string_free(IntPtr ptr);

    /// <summary>Returns the runtime's last error message as a UTF-8 string pointer (may be IntPtr.Zero).</summary>
    [DllImport(Lib, CallingConvention = CallingConvention.Cdecl)]
    public static extern IntPtr ace_last_error();

    /// <summary>
    /// Prepares step inputs from a JSON state and an input tensor. Returns 0 on success; on success
    /// <paramref name="out_json"/> receives a UTF-8 JSON string that the caller frees with
    /// <see cref="ace_string_free"/>.
    /// </summary>
    [DllImport(Lib, CallingConvention = CallingConvention.Cdecl)]
    public static extern int ace_prepare_step_inputs(
        IntPtr ctx,
        [MarshalAs(UnmanagedType.LPUTF8Str)] string state_json,
        float[] in_tensor_ptr,
        UIntPtr in_tensor_len,
        out IntPtr out_json
    );

    /// <summary>
    /// Runs one scheduler step over <paramref name="len"/> elements, writing into
    /// <paramref name="out_xt_ptr"/>. Returns 0 on success.
    /// </summary>
    [DllImport(Lib, CallingConvention = CallingConvention.Cdecl)]
    public static extern int ace_scheduler_step(
        IntPtr ctx,
        float[] xt_ptr,
        float[] vt_ptr,
        UIntPtr len,
        float dt,
        [Out] float[] out_xt_ptr
    );

    /// <summary>
    /// Applies the context's token constraints to <paramref name="vocab_size"/> logits, writing the
    /// result into <paramref name="out_masked_logits_ptr"/>. Returns 0 on success.
    /// </summary>
    [DllImport(Lib, CallingConvention = CallingConvention.Cdecl)]
    public static extern int ace_apply_lm_constraints(
        IntPtr ctx,
        float[] logits_ptr,
        UIntPtr vocab_size,
        [Out] float[] out_masked_logits_ptr
    );
}

/// <summary>
/// FFI regression driver: creates a runtime context, checks step-input preparation, one scheduler
/// step and LM constraint masking against pinned expected values, and reports via the exit code
/// (0 = pass, 1..8 = the first failing stage).
/// </summary>
public class Program {
    // Must match "forced_token_id" in the context config passed to ace_create_context in Main.
    private const int ForcedTokenId = 2;

    // Masked logits are expected to be at or below this value.
    private const float MaskedLogitCeiling = -1e29f;

    /// <summary>
    /// Reads the runtime's last error message and releases the native string.
    /// </summary>
    /// <returns>The decoded message, or "unknown" when no message is available.</returns>
    private static string LastError() {
        IntPtr p = Native.ace_last_error();
        if (p == IntPtr.Zero) {
            return "unknown";
        }
        try {
            return Marshal.PtrToStringUTF8(p) ?? "unknown";
        } finally {
            // NOTE(review): assumes ace_last_error hands back a caller-owned string; if the runtime
            // returns a borrowed pointer this free would be incorrect — confirm against its header.
            Native.ace_string_free(p);
        }
    }

    /// <summary>Absolute-tolerance float comparison; NaN never compares equal.</summary>
    private static bool NearlyEqual(float a, float b, float eps = 1e-7f) {
        return Math.Abs(a - b) <= eps;
    }

    /// <summary>
    /// Reads <paramref name="name"/> from a JSON object as a float.
    /// </summary>
    /// <returns>False when the element is not an object, the property is missing, or it is not a
    /// number representable as <see cref="float"/>.</returns>
    private static bool TryGetSingleProperty(JsonElement obj, string name, out float value) {
        value = 0f;
        // TryGetProperty throws on non-object elements, so check the kind first.
        return obj.ValueKind == JsonValueKind.Object
            && obj.TryGetProperty(name, out JsonElement prop)
            && prop.ValueKind == JsonValueKind.Number
            && prop.TryGetSingle(out value);
    }

    /// <summary>
    /// Checks ace_prepare_step_inputs against pinned timestep values.
    /// </summary>
    /// <returns>0 on success, 2 if the native call fails, 3 on a missing/invalid/mismatched payload.</returns>
    private static int CheckPrepareStepInputs(IntPtr ctx) {
        float[] inTensor = { 1f, 2f, 3f, 4f };
        int prepRc = Native.ace_prepare_step_inputs(
            ctx,
            "{\"shift\":3.0,\"inference_steps\":8,\"current_step\":0}",
            inTensor,
            (UIntPtr)inTensor.Length,
            out IntPtr outJson
        );
        if (prepRc != 0) {
            Console.WriteLine($"ace_prepare_step_inputs failed: {LastError()}");
            return 2;
        }

        string payload;
        try {
            payload = Marshal.PtrToStringUTF8(outJson) ?? "{}";
        } finally {
            // Always release the runtime-owned string, but never hand it a null pointer.
            if (outJson != IntPtr.Zero) {
                Native.ace_string_free(outJson);
            }
        }

        try {
            using JsonDocument doc = JsonDocument.Parse(payload);
            JsonElement root = doc.RootElement;
            if (!TryGetSingleProperty(root, "timestep", out float timestep)
                || !TryGetSingleProperty(root, "next_timestep", out float nextTimestep)) {
                Console.WriteLine($"prepare payload missing timestep/next_timestep: {payload}");
                return 3;
            }
            // Expected values for this state, pinned by the regression.
            if (!NearlyEqual(timestep, 1.0f) || !NearlyEqual(nextTimestep, 0.669921875f, 1e-6f)) {
                Console.WriteLine($"prepare mismatch: t={timestep}, next={nextTimestep}");
                return 3;
            }
        } catch (JsonException ex) {
            Console.WriteLine($"prepare payload is not valid JSON: {ex.Message}");
            return 3;
        }
        return 0;
    }

    /// <summary>
    /// Checks one ace_scheduler_step call against pinned outputs.
    /// </summary>
    /// <returns>0 on success, 4 if the native call fails, 5 on an output mismatch.</returns>
    private static int CheckSchedulerStep(IntPtr ctx) {
        float[] xt = { 1f, 1f, 1f, 1f };
        float[] vt = { 0.1f, 0.2f, 0.3f, 0.4f };
        float[] outXt = new float[xt.Length];
        int rc = Native.ace_scheduler_step(ctx, xt, vt, (UIntPtr)xt.Length, 0.5f, outXt);
        if (rc != 0) {
            Console.WriteLine($"ace_scheduler_step failed: {LastError()}");
            return 4;
        }

        // For these inputs the pinned values equal xt - dt * vt.
        float[] expectedXt = { 0.95f, 0.9f, 0.85f, 0.8f };
        for (int i = 0; i < expectedXt.Length; i++) {
            if (!NearlyEqual(outXt[i], expectedXt[i])) {
                Console.WriteLine($"scheduler mismatch at {i}: got={outXt[i]} expected={expectedXt[i]}");
                return 5;
            }
        }
        return 0;
    }

    /// <summary>
    /// Checks that ace_apply_lm_constraints keeps the forced token's logit and masks every other one.
    /// </summary>
    /// <returns>0 on success, 6 if the native call fails, 7/8 on a forced-token/mask mismatch.</returns>
    private static int CheckLmConstraints(IntPtr ctx) {
        float[] logits = { 0f, 1f, 2f, 3f, 4f };
        float[] masked = new float[logits.Length];
        int lmRc = Native.ace_apply_lm_constraints(ctx, logits, (UIntPtr)logits.Length, masked);
        if (lmRc != 0) {
            Console.WriteLine($"ace_apply_lm_constraints failed: {LastError()}");
            return 6;
        }

        if (!NearlyEqual(masked[ForcedTokenId], logits[ForcedTokenId])) {
            Console.WriteLine($"forced token mismatch: {masked[ForcedTokenId]}");
            return 7;
        }
        for (int i = 0; i < masked.Length; i++) {
            if (i == ForcedTokenId) {
                continue;
            }
            // Negated form so a NaN logit is reported as unmasked instead of silently passing.
            if (!(masked[i] <= MaskedLogitCeiling)) {
                Console.WriteLine($"mask mismatch at {i}: {masked[i]}");
                return 8;
            }
        }
        return 0;
    }

    /// <summary>Runs all checks in order, stopping at the first failure.</summary>
    /// <returns>0 when every check passes, otherwise the failing check's exit code.</returns>
    private static int RunChecks(IntPtr ctx) {
        int rc = CheckPrepareStepInputs(ctx);
        if (rc != 0) {
            return rc;
        }
        rc = CheckSchedulerStep(ctx);
        if (rc != 0) {
            return rc;
        }
        rc = CheckLmConstraints(ctx);
        if (rc != 0) {
            return rc;
        }
        Console.WriteLine("csharp ffi regression: PASS");
        return 0;
    }

    /// <summary>
    /// Entry point: creates the context, runs the checks, and always frees the context.
    /// </summary>
    /// <returns>Process exit code: 0 on pass, 1 if the context cannot be created, otherwise the
    /// first failing check's code.</returns>
    public static int Main(string[] args) {
        IntPtr ctx = Native.ace_create_context("{\"seed\":42,\"blocked_token_ids\":[1,3],\"forced_token_id\":2}");
        if (ctx == IntPtr.Zero) {
            Console.WriteLine($"create_context failed: {LastError()}");
            return 1;
        }
        try {
            return RunChecks(ctx);
        } finally {
            Native.ace_free_context(ctx);
        }
    }
}