# Ace-Step1.5-ONNX-DLL: examples/ffi/python/onnxruntime_dll_demo.py
# Uploaded by zukky via huggingface_hub ("Upload folder using huggingface_hub"), verified commit 96cc2fd.
from __future__ import annotations
import argparse
import ctypes
import json
from pathlib import Path
from typing import Dict, List
import numpy as np
import onnxruntime as ort
import soundfile as sf
from runtime_py_ort.case_schema import CaseSpec
from runtime_py_ort.scheduler import resolve_timesteps
def _session(path: Path, provider: str) -> ort.InferenceSession:
    """Open an ONNX Runtime inference session for the model file at *path*.

    ``provider == "cuda"`` requests ``CUDAExecutionProvider`` with
    ``CPUExecutionProvider`` as the secondary provider. If constructing that
    session raises, a CPU-only session is created instead. Any other
    *provider* value yields a CPU-only session, and creation errors propagate
    unchanged.
    """
    model_file = str(path)
    cpu_only = ["CPUExecutionProvider"]
    wants_cuda = provider == "cuda"
    requested = ["CUDAExecutionProvider", "CPUExecutionProvider"] if wants_cuda else cpu_only
    try:
        return ort.InferenceSession(model_file, providers=requested)
    except Exception:
        if not wants_cuda:
            raise
        # Creating the CUDA-backed session raised: retry with the CPU provider only.
        return ort.InferenceSession(model_file, providers=cpu_only)
def _load_contract(onnx_dir: Path) -> Dict[str, Dict[str, List[str]]]:
    """Load ``io_contract_core.json`` from *onnx_dir* and return the parsed JSON.

    Raises:
        FileNotFoundError: if the contract file does not exist.
    """
    contract_file = onnx_dir / "io_contract_core.json"
    if contract_file.exists():
        raw_text = contract_file.read_text(encoding="utf-8")
        return json.loads(raw_text)
    raise FileNotFoundError(f"missing contract: {contract_file}")
def _crop_vt(vt: np.ndarray, orig_len: int) -> np.ndarray:
    """Return *vt* truncated to *orig_len* along axis 1, as a float32 copy.

    When axis 1 already has length *orig_len*, no slicing is done. Otherwise
    the array is indexed as 3-D, ``vt[:, :orig_len, :]``.
    """
    trimmed = vt if vt.shape[1] == orig_len else vt[:, :orig_len, :]
    return trimmed.astype(np.float32)
def _dll_bind(lib: ctypes.CDLL) -> None:
    """Declare the ctypes ``argtypes``/``restype`` of every runtime DLL export used here.

    The functions returning ``c_int32`` report status. The callers in this
    script treat 0 as success. Symbols are bound in table order, so a missing
    export raises ``AttributeError`` at the first absent name.
    """
    f32_ptr = ctypes.POINTER(ctypes.c_float)
    i64_ptr = ctypes.POINTER(ctypes.c_int64)
    void_p = ctypes.c_void_p
    void_pp = ctypes.POINTER(ctypes.c_void_p)
    size_t = ctypes.c_size_t
    status = ctypes.c_int32

    # symbol name -> (argument types, return type)
    signatures = {
        "ace_create_context": ([ctypes.c_char_p], void_p),
        "ace_free_context": ([void_p], None),
        "ace_string_free": ([void_p], None),
        "ace_last_error": ([], void_p),
        "ace_prepare_step_inputs": ([void_p, ctypes.c_char_p, f32_ptr, size_t, void_pp], status),
        "ace_scheduler_step": ([void_p, f32_ptr, f32_ptr, size_t, ctypes.c_float, f32_ptr], status),
        "ace_apply_lm_constraints": ([void_p, f32_ptr, size_t, f32_ptr], status),
        "ace_finalize_metadata": ([void_p, i64_ptr, size_t, void_pp], status),
    }
    for symbol, (arg_types, result_type) in signatures.items():
        exported = getattr(lib, symbol)
        exported.argtypes = arg_types
        exported.restype = result_type
def _dll_last_error(lib: ctypes.CDLL) -> str:
    """Fetch the DLL's last error message as text, or ``"unknown"`` if there is none.

    A non-null pointer from ``ace_last_error`` is read as a NUL-terminated byte
    string and decoded as UTF-8 (invalid bytes are replaced). The pointer is
    always released with ``ace_string_free``, even if decoding fails.
    """
    err_ptr = lib.ace_last_error()
    if err_ptr:
        try:
            message = ctypes.cast(err_ptr, ctypes.c_char_p).value
            if not message:
                message = b"unknown"
            return message.decode("utf-8", errors="replace")
        finally:
            lib.ace_string_free(err_ptr)
    return "unknown"
def _dll_prepare_step(lib: ctypes.CDLL, ctx: int, shift: float, steps: int, step: int, xt: np.ndarray) -> None:
    """Notify the DLL of the current denoising step via ``ace_prepare_step_inputs``.

    Sends a UTF-8 JSON payload ``{"shift", "inference_steps", "current_step"}``
    together with *xt* flattened to a contiguous float32 buffer. On success the
    DLL-returned string is freed immediately; its content is not used here.

    Raises:
        RuntimeError: when the call returns a non-zero status, with the DLL's
            last error message.
    """
    request = json.dumps(
        {"shift": float(shift), "inference_steps": int(steps), "current_step": int(step)}
    ).encode("utf-8")
    flat = np.ascontiguousarray(xt.reshape(-1), dtype=np.float32)
    response = ctypes.c_void_p()
    status = lib.ace_prepare_step_inputs(
        ctx,
        request,
        flat.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        ctypes.c_size_t(flat.size),
        ctypes.byref(response),
    )
    if status == 0:
        lib.ace_string_free(response)
        return
    raise RuntimeError(f"ace_prepare_step_inputs failed: {_dll_last_error(lib)}")
def _dll_scheduler_step(lib: ctypes.CDLL, ctx: int, xt: np.ndarray, vt: np.ndarray, dt: float) -> np.ndarray:
    """Advance *xt* by one scheduler step computed in the DLL (``ace_scheduler_step``).

    *xt* and *vt* are flattened into contiguous float32 buffers and passed to
    the DLL with a single shared element count. The DLL writes into a fresh
    float32 buffer, which is returned reshaped to ``xt.shape``.

    Raises:
        ValueError: if *xt* and *vt* do not have the same number of elements.
            The native call receives only one length, so a shorter *vt* would
            make it read past the end of the buffer.
        RuntimeError: when the DLL returns a non-zero status.
    """
    x = np.ascontiguousarray(xt.reshape(-1), dtype=np.float32)
    v = np.ascontiguousarray(vt.reshape(-1), dtype=np.float32)
    if v.size != x.size:
        # Guard the FFI boundary: a length mismatch would be undefined behaviour in native code.
        raise ValueError(
            f"ace_scheduler_step size mismatch: xt has {x.size} elements, vt has {v.size}"
        )
    out = np.empty_like(x)
    rc = lib.ace_scheduler_step(
        ctx,
        x.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        v.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        ctypes.c_size_t(x.size),
        ctypes.c_float(dt),
        out.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
    )
    if rc != 0:
        raise RuntimeError(f"ace_scheduler_step failed: {_dll_last_error(lib)}")
    return out.reshape(xt.shape)
def _run_condition(
    sess: ort.InferenceSession,
    arr: Dict[str, np.ndarray],
    contract_inputs: List[str],
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return ``(encoder_hidden_states, encoder_attention_mask, context_latents)`` as float32.

    If *arr* already contains all three tensors, they are returned directly
    and *sess* is not used. Otherwise every name in *contract_inputs* is read
    from *arr* and fed to *sess*. ``refer_audio_order_mask`` is cast to int64,
    ``is_covers`` to bool, and everything else to float32. The three tensors
    are then picked from the session outputs by name.

    Raises:
        ValueError: if a name in *contract_inputs* is missing from *arr*.
    """
    wanted = ("encoder_hidden_states", "encoder_attention_mask", "context_latents")
    if all(key in arr for key in wanted):
        produced = arr
    else:
        # Inputs that are not float32; all other condition inputs default to float32.
        dtype_overrides = {"refer_audio_order_mask": np.int64, "is_covers": bool}
        feeds: Dict[str, np.ndarray] = {}
        for input_name in contract_inputs:
            if input_name not in arr:
                raise ValueError(f"inputs npz missing condition input: {input_name}")
            target_dtype = dtype_overrides.get(input_name, np.float32)
            feeds[input_name] = arr[input_name].astype(target_dtype, copy=False)
        output_names = [meta.name for meta in sess.get_outputs()]
        produced = dict(zip(output_names, sess.run(output_names, feeds)))
    return tuple(produced[key].astype(np.float32) for key in wanted)
def main() -> int:
    """Run one ACE-Step case with ONNX Runtime models plus the Rust runtime DLL.

    Steps, as implemented below:

    1. Load the case spec and its input tensors. ``--inputs-npz`` defaults to
       ``fixtures/tensors/<case_id>.npz``.
    2. Create a DLL context seeded from ``case.seed``. Obtain the condition
       tensors from the condition encoder, or reuse them from the npz if they
       are precomputed.
    3. Denoise over the resolved timesteps. Each step calls the DiT: the
       prefill/decode KV-cache pair when both ``dit_prefill_kv.onnx`` and
       ``dit_decode_kv.onnx`` exist, else ``dit_decoder.onnx``. ``xt`` is then
       advanced with ``ace_scheduler_step``.
    4. Decode the final latents with ``vae_decoder.onnx``. Write a 48 kHz
       float WAV and an npz of intermediates.
    5. Call ``ace_apply_lm_constraints`` / ``ace_finalize_metadata`` on toy
       inputs to demonstrate those APIs. Failures there are reported as
       warnings on stderr rather than raised.

    Returns 0 on success; other errors propagate as exceptions. Once created,
    the DLL context is freed in a ``finally`` block.
    """
    import sys  # only used for the stderr warnings of the LM-API demo below

    parser = argparse.ArgumentParser(description="Python ONNXRuntime + Rust DLL demo")
    parser.add_argument("--case", type=Path, required=True)
    parser.add_argument("--inputs-npz", type=Path)
    parser.add_argument("--onnx-dir", type=Path, default=Path("artifacts/onnx_runtime"))
    parser.add_argument("--dll", type=Path, default=Path("runtime_rust_dll/target/release/acestep_runtime.dll"))
    parser.add_argument("--provider", choices=["cpu", "cuda"], default="cpu")
    parser.add_argument("--out-wav", type=Path, default=Path("reports/listening/demo/python_ort_dll.wav"))
    parser.add_argument("--out-npz", type=Path, default=Path("reports/listening/demo/python_ort_dll.npz"))
    args = parser.parse_args()

    case = CaseSpec.from_path(args.case)
    inputs_npz = args.inputs_npz or Path(f"fixtures/tensors/{case.case_id}.npz")
    # np.load on an .npz returns an NpzFile that keeps the archive open; read every
    # array out eagerly and close it instead of leaking the file handle.
    with np.load(str(inputs_npz), allow_pickle=False) as npz:
        arr = {k: npz[k] for k in npz.files}
    onnx_dir = args.onnx_dir
    contract = _load_contract(onnx_dir)

    lib = ctypes.CDLL(str(args.dll.resolve()))
    _dll_bind(lib)
    ctx = lib.ace_create_context(json.dumps({"seed": int(case.seed)}).encode("utf-8"))
    if not ctx:
        raise RuntimeError(f"ace_create_context failed: {_dll_last_error(lib)}")
    try:
        condition_sess = _session(onnx_dir / "condition_encoder.onnx", args.provider)
        encoder_hidden_states, encoder_attention_mask, context_latents = _run_condition(
            condition_sess,
            arr,
            contract.get("inputs", {}).get("condition_encoder", []),
        )

        src_latents = arr["src_latents"].astype(np.float32)
        # Start from the fixture's first recorded xt when present, otherwise from src_latents.
        xt = arr["xt_steps"][0].astype(np.float32).copy() if "xt_steps" in arr else src_latents.copy()
        latent_masks = arr.get("latent_masks")
        attention_mask = (
            latent_masks.astype(np.float32)
            if latent_masks is not None
            else np.ones((xt.shape[0], xt.shape[1]), dtype=np.float32)
        )

        # DiT inputs are zero-padded along axis 1 to an even length, and vt is cropped
        # back to orig_len afterwards. NOTE(review): presumably an even-length
        # requirement of the exported DiT — confirm.
        orig_len = int(xt.shape[1])
        pad_len = (-orig_len) % 2
        context_latents_padded = (
            np.pad(context_latents, ((0, 0), (0, pad_len), (0, 0)), mode="constant") if pad_len else context_latents
        )
        attention_mask_padded = (
            np.pad(attention_mask, ((0, 0), (0, pad_len)), mode="constant") if pad_len else attention_mask
        )
        timesteps = resolve_timesteps(case.shift, None, max_steps=max(1, int(case.inference_steps)))

        has_kv = (onnx_dir / "dit_prefill_kv.onnx").exists() and (onnx_dir / "dit_decode_kv.onnx").exists()
        if has_kv:
            dit_prefill = _session(onnx_dir / "dit_prefill_kv.onnx", args.provider)
            dit_decode = _session(onnx_dir / "dit_decode_kv.onnx", args.provider)
            prefill_inputs = [x.name for x in dit_prefill.get_inputs()]
            prefill_outputs = [x.name for x in dit_prefill.get_outputs()]
            decode_inputs = [x.name for x in dit_decode.get_inputs()]
            decode_outputs = [x.name for x in dit_decode.get_outputs()]
            cache_map: Dict[str, np.ndarray] = {}
        else:
            dit = _session(onnx_dir / "dit_decoder.onnx", args.provider)
            dit_inputs = {x.name for x in dit.get_inputs()}

        xt_steps: List[np.ndarray] = []
        vt_steps: List[np.ndarray] = []
        for idx, t in enumerate(timesteps):
            _dll_prepare_step(lib, ctx, case.shift, len(timesteps), idx, xt)
            t_vec = np.full((xt.shape[0],), t, dtype=np.float32)
            xt_in = np.pad(xt, ((0, 0), (0, pad_len), (0, 0)), mode="constant") if pad_len else xt
            if has_kv:
                base = {
                    "hidden_states": xt_in,
                    "timestep": t_vec,
                    "timestep_r": t_vec,
                    "attention_mask": attention_mask_padded,
                    "encoder_hidden_states": encoder_hidden_states,
                    "encoder_attention_mask": encoder_attention_mask,
                    "context_latents": context_latents_padded,
                }
                if idx == 0:
                    # First step runs the prefill graph, feeding only the inputs it declares.
                    feeds = {k: base[k] for k in prefill_inputs if k in base}
                    out = dit_prefill.run(prefill_outputs, feeds)
                    out_map = dict(zip(prefill_outputs, out))
                else:
                    # Later steps: each past_<name> input is fed from the previous
                    # step's present_<name> output.
                    feeds = {}
                    for name in decode_inputs:
                        if name.startswith("past_"):
                            present_name = "present_" + name[len("past_") :]
                            if present_name not in cache_map:
                                raise ValueError(f"missing cache for {present_name}")
                            feeds[name] = cache_map[present_name]
                        elif name in base:
                            feeds[name] = base[name]
                    out = dit_decode.run(decode_outputs, feeds)
                    out_map = dict(zip(decode_outputs, out))
                vt = _crop_vt(out_map["vt"], orig_len)
                cache_map = {k: v.astype(np.float32) for k, v in out_map.items() if k.startswith("present_")}
            else:
                feeds = {
                    "hidden_states": xt_in,
                    "timestep": t_vec,
                    "timestep_r": t_vec,
                    "encoder_hidden_states": encoder_hidden_states,
                    "context_latents": context_latents_padded,
                }
                # Feed the mask inputs only if this export declares them.
                if "attention_mask" in dit_inputs:
                    feeds["attention_mask"] = attention_mask_padded
                if "encoder_attention_mask" in dit_inputs:
                    feeds["encoder_attention_mask"] = encoder_attention_mask
                vt = _crop_vt(dit.run(["vt"], feeds)[0], orig_len)
            xt_steps.append(xt.copy())
            vt_steps.append(vt.copy())
            # Step size to the next timestep; the final step uses dt = t (a step to t = 0).
            dt = float(t) if idx == len(timesteps) - 1 else float(t - timesteps[idx + 1])
            xt = _dll_scheduler_step(lib, ctx, xt, vt, dt)
        pred_latents = xt

        vae = _session(onnx_dir / "vae_decoder.onnx", args.provider)
        # Swap the last two axes before VAE decoding. NOTE(review): assumes latents are
        # (batch, len, channels) and the VAE expects (batch, channels, len) — confirm.
        latents = np.transpose(pred_latents, (0, 2, 1)).astype(np.float32)
        audio = vae.run(["audio"], {"latents": latents})[0].astype(np.float32)
        audio_0 = audio[0]
        args.out_wav.parent.mkdir(parents=True, exist_ok=True)
        # soundfile expects (frames, channels); the transpose assumes audio_0 is
        # (channels, samples) — confirm against the VAE export.
        sf.write(str(args.out_wav), audio_0.T, 48_000, subtype="FLOAT")
        args.out_npz.parent.mkdir(parents=True, exist_ok=True)
        np.savez_compressed(
            str(args.out_npz),
            pred_latents=pred_latents.astype(np.float32),
            xt_steps=np.asarray(xt_steps, dtype=np.float32),
            vt_steps=np.asarray(vt_steps, dtype=np.float32),
            encoder_hidden_states=encoder_hidden_states.astype(np.float32),
            encoder_attention_mask=encoder_attention_mask.astype(np.float32),
            context_latents=context_latents.astype(np.float32),
        )

        # Demonstrate LM constraint / metadata APIs for simple_mode use.
        # Toy inputs: these calls only show the calling convention, so failures
        # are reported as warnings rather than aborting the run.
        logits = np.array([0.0, 1.0, 2.0, 3.0], dtype=np.float32)
        masked = np.empty_like(logits)
        rc = lib.ace_apply_lm_constraints(
            ctx,
            logits.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
            ctypes.c_size_t(logits.size),
            masked.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        )
        if rc != 0:
            print(f"warning: ace_apply_lm_constraints failed: {_dll_last_error(lib)}", file=sys.stderr)
        token_ids = np.array([1, 2, 3], dtype=np.int64)
        out_json = ctypes.c_void_p()
        rc = lib.ace_finalize_metadata(
            ctx,
            token_ids.ctypes.data_as(ctypes.POINTER(ctypes.c_int64)),
            ctypes.c_size_t(token_ids.size),
            ctypes.byref(out_json),
        )
        if rc == 0:
            lib.ace_string_free(out_json)
        else:
            print(f"warning: ace_finalize_metadata failed: {_dll_last_error(lib)}", file=sys.stderr)

        print(f"Wrote wav: {args.out_wav}")
        print(f"Wrote npz: {args.out_npz}")
        return 0
    finally:
        lib.ace_free_context(ctx)
if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)