zukky's picture
Upload folder using huggingface_hub
96cc2fd verified
using System;
using System.Runtime.InteropServices;
using System.Text;
using System.Text.Json;
/// <summary>
/// P/Invoke declarations for the acestep_runtime native library.
/// </summary>
/// <remarks>
/// Callers in this file treat every <c>int</c> return value as "0 = success, non-zero = failure"
/// and fetch details from <c>ace_last_error</c> on failure.
/// Strings returned by the runtime are decoded with Marshal.PtrToStringUTF8, so string inputs
/// are marshalled as UTF-8 as well (the DllImport default on Windows would be the ANSI code page).
/// </remarks>
internal static class Native
{
    // No file extension on purpose: the .NET loader appends the platform suffix
    // (".dll" on Windows, ".so"/".dylib" with a "lib" prefix elsewhere), so this
    // still resolves "acestep_runtime.dll" on Windows while also working on Unix.
    private const string LibName = "acestep_runtime";

    /// <summary>Creates a runtime context from a UTF-8 JSON configuration string.</summary>
    /// <returns>An opaque context handle; callers treat IntPtr.Zero as failure.</returns>
    [DllImport(LibName, CallingConvention = CallingConvention.Cdecl)]
    public static extern IntPtr ace_create_context([MarshalAs(UnmanagedType.LPUTF8Str)] string config_json);

    /// <summary>Releases a context obtained from <c>ace_create_context</c>.</summary>
    [DllImport(LibName, CallingConvention = CallingConvention.Cdecl)]
    public static extern void ace_free_context(IntPtr ctx);

    /// <summary>
    /// Releases a string handed out by the runtime (this file uses it for the results of
    /// <c>ace_last_error</c> and the <c>out_json</c> of <c>ace_prepare_step_inputs</c>).
    /// </summary>
    [DllImport(LibName, CallingConvention = CallingConvention.Cdecl)]
    public static extern void ace_string_free(IntPtr ptr);

    /// <summary>
    /// Returns a UTF-8 message describing the most recent failure, or IntPtr.Zero when there is none.
    /// Callers in this file release the returned pointer with <c>ace_string_free</c>.
    /// </summary>
    [DllImport(LibName, CallingConvention = CallingConvention.Cdecl)]
    public static extern IntPtr ace_last_error();

    /// <summary>
    /// Prepares inputs for one diffusion step from a UTF-8 JSON state and an input tensor.
    /// On success (return 0) <paramref name="out_json"/> receives a UTF-8 JSON string that the
    /// caller must release with <c>ace_string_free</c>.
    /// </summary>
    /// <remarks>NOTE(review): presumably reads exactly <paramref name="in_tensor_len"/> floats — confirm.</remarks>
    [DllImport(LibName, CallingConvention = CallingConvention.Cdecl)]
    public static extern int ace_prepare_step_inputs(
        IntPtr ctx,
        [MarshalAs(UnmanagedType.LPUTF8Str)] string state_json,
        float[] in_tensor_ptr,
        UIntPtr in_tensor_len,
        out IntPtr out_json
    );

    /// <summary>
    /// Performs one scheduler update, writing the result into <paramref name="out_xt_ptr"/>.
    /// The regression test in this file expects <c>out = xt - dt * vt</c> element-wise.
    /// </summary>
    /// <remarks>
    /// NOTE(review): presumably reads/writes <paramref name="len"/> elements through every array,
    /// so each array must hold at least that many floats — confirm against the native side.
    /// </remarks>
    [DllImport(LibName, CallingConvention = CallingConvention.Cdecl)]
    public static extern int ace_scheduler_step(
        IntPtr ctx,
        float[] xt_ptr,
        float[] vt_ptr,
        UIntPtr len,
        float dt,
        [Out] float[] out_xt_ptr
    );

    /// <summary>
    /// Applies the context's token constraints to a logits vector, writing the masked logits into
    /// <paramref name="out_masked_logits_ptr"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): presumably reads/writes <paramref name="vocab_size"/> elements; the regression
    /// test expects the forced token to keep its logit and every other slot to be pushed to a
    /// large negative value — confirm against the native side.
    /// </remarks>
    [DllImport(LibName, CallingConvention = CallingConvention.Cdecl)]
    public static extern int ace_apply_lm_constraints(
        IntPtr ctx,
        float[] logits_ptr,
        UIntPtr vocab_size,
        [Out] float[] out_masked_logits_ptr
    );
}
/// <summary>
/// FFI regression harness for the acestep_runtime C ABI as seen from C#.
/// Creates one native context, then checks step-input preparation, a scheduler step,
/// and LM logit constraints, stopping at the first failure.
/// </summary>
/// <remarks>
/// Exit codes: 0 pass, 1 context creation failed, 2 prepare call failed,
/// 3 prepare output mismatch or unreadable, 4 scheduler call failed, 5 scheduler mismatch,
/// 6 constraint call failed, 7 forced-token mismatch, 8 mask mismatch.
/// </remarks>
public class Program
{
    // Context configuration; the constraint check relies on forced_token_id matching ForcedTokenId.
    private const string ContextConfigJson = "{\"seed\":42,\"blocked_token_ids\":[1,3],\"forced_token_id\":2}";

    // Scheduler state for the first of eight steps.
    private const string StepStateJson = "{\"shift\":3.0,\"inference_steps\":8,\"current_step\":0}";

    private const int ForcedTokenId = 2;

    // Every non-forced slot must come back at or below this value to count as masked.
    private const float MaskedLogitCeiling = -1e29f;

    /// <summary>
    /// Copies a runtime-owned UTF-8 string into managed memory and releases it with ace_string_free.
    /// </summary>
    /// <param name="ptr">Pointer returned by the runtime; IntPtr.Zero is allowed.</param>
    /// <param name="fallback">Value returned when <paramref name="ptr"/> is null or decodes to null.</param>
    /// <returns>The decoded string, or <paramref name="fallback"/>.</returns>
    private static string ConsumeNativeString(IntPtr ptr, string fallback)
    {
        if (ptr == IntPtr.Zero)
        {
            // Nothing was allocated, so there is nothing to free.
            return fallback;
        }
        try
        {
            return Marshal.PtrToStringUTF8(ptr) ?? fallback;
        }
        finally
        {
            Native.ace_string_free(ptr);
        }
    }

    /// <summary>Returns the runtime's last error message, or "unknown" when none is available.</summary>
    private static string LastError() => ConsumeNativeString(Native.ace_last_error(), "unknown");

    /// <summary>
    /// Absolute-tolerance float comparison: true when |a - b| &lt;= eps.
    /// NaN never compares nearly equal, so NaN results are reported as mismatches.
    /// </summary>
    private static bool NearlyEqual(float a, float b, float eps = 1e-7f)
    {
        return Math.Abs(a - b) <= eps;
    }

    /// <summary>Entry point: runs every check against one context and returns the first failure code.</summary>
    public static int Main(string[] args)
    {
        IntPtr ctx = Native.ace_create_context(ContextConfigJson);
        if (ctx == IntPtr.Zero)
        {
            Console.WriteLine($"create_context failed: {LastError()}");
            return 1;
        }
        try
        {
            Func<IntPtr, int>[] checks = { CheckPrepareStepInputs, CheckSchedulerStep, CheckLmConstraints };
            foreach (Func<IntPtr, int> check in checks)
            {
                int exitCode = check(ctx);
                if (exitCode != 0)
                {
                    return exitCode;
                }
            }
            Console.WriteLine("csharp ffi regression: PASS");
            return 0;
        }
        finally
        {
            Native.ace_free_context(ctx);
        }
    }

    /// <summary>
    /// Verifies ace_prepare_step_inputs reports timestep 1.0 and next_timestep ~0.669921875
    /// for the first step. Returns 0 on success, 2 if the call fails, 3 on mismatch or unreadable output.
    /// </summary>
    private static int CheckPrepareStepInputs(IntPtr ctx)
    {
        float[] inTensor = { 1f, 2f, 3f, 4f };
        int prepRc = Native.ace_prepare_step_inputs(
            ctx,
            StepStateJson,
            inTensor,
            (UIntPtr)inTensor.Length,
            out IntPtr outJson
        );
        if (prepRc != 0)
        {
            Console.WriteLine($"ace_prepare_step_inputs failed: {LastError()}");
            return 2;
        }
        string payload = ConsumeNativeString(outJson, "{}");

        float timestep;
        float nextTimestep;
        try
        {
            using JsonDocument doc = JsonDocument.Parse(payload);
            timestep = doc.RootElement.GetProperty("timestep").GetSingle();
            nextTimestep = doc.RootElement.GetProperty("next_timestep").GetSingle();
        }
        catch (Exception ex) when (ex is JsonException
                                   || ex is System.Collections.Generic.KeyNotFoundException
                                   || ex is InvalidOperationException
                                   || ex is FormatException)
        {
            // Malformed JSON, a missing field, or a non-numeric value: report instead of crashing.
            Console.WriteLine($"prepare payload unreadable: {ex.Message}");
            return 3;
        }

        if (!NearlyEqual(timestep, 1.0f) || !NearlyEqual(nextTimestep, 0.669921875f, 1e-6f))
        {
            Console.WriteLine($"prepare mismatch: t={timestep}, next={nextTimestep}");
            return 3;
        }
        return 0;
    }

    /// <summary>
    /// Verifies one ace_scheduler_step with dt = 0.5 yields xt - dt * vt element-wise.
    /// Returns 0 on success, 4 if the call fails, 5 on mismatch.
    /// </summary>
    private static int CheckSchedulerStep(IntPtr ctx)
    {
        float[] xt = { 1f, 1f, 1f, 1f };
        float[] vt = { 0.1f, 0.2f, 0.3f, 0.4f };
        float[] expectedXt = { 0.95f, 0.9f, 0.85f, 0.8f };

        // Derive the length and output size from the input so the native side can never
        // be told to write more elements than the managed buffer holds.
        float[] outXt = new float[xt.Length];
        int rc = Native.ace_scheduler_step(ctx, xt, vt, (UIntPtr)xt.Length, 0.5f, outXt);
        if (rc != 0)
        {
            Console.WriteLine($"ace_scheduler_step failed: {LastError()}");
            return 4;
        }
        for (int i = 0; i < expectedXt.Length; i++)
        {
            if (!NearlyEqual(outXt[i], expectedXt[i]))
            {
                Console.WriteLine($"scheduler mismatch at {i}: got={outXt[i]} expected={expectedXt[i]}");
                return 5;
            }
        }
        return 0;
    }

    /// <summary>
    /// Verifies ace_apply_lm_constraints keeps the forced token's logit and masks every other slot.
    /// Returns 0 on success, 6 if the call fails, 7 on forced-token mismatch, 8 on mask mismatch.
    /// </summary>
    private static int CheckLmConstraints(IntPtr ctx)
    {
        float[] logits = { 0f, 1f, 2f, 3f, 4f };
        float[] masked = new float[logits.Length];
        int lmRc = Native.ace_apply_lm_constraints(ctx, logits, (UIntPtr)logits.Length, masked);
        if (lmRc != 0)
        {
            Console.WriteLine($"ace_apply_lm_constraints failed: {LastError()}");
            return 6;
        }
        if (!NearlyEqual(masked[ForcedTokenId], logits[ForcedTokenId]))
        {
            Console.WriteLine($"forced token mismatch: {masked[ForcedTokenId]}");
            return 7;
        }
        for (int i = 0; i < masked.Length; i++)
        {
            if (i == ForcedTokenId)
            {
                continue;
            }
            // Written as !(x <= ceiling) rather than x > ceiling so a NaN slot is reported too.
            if (!(masked[i] <= MaskedLogitCeiling))
            {
                Console.WriteLine($"mask mismatch at {i}: {masked[i]}");
                return 8;
            }
        }
        return 0;
    }
}