"""Python ONNXRuntime + Rust DLL demo.

Runs the condition encoder, DiT and VAE decoder ONNX graphs through
ONNX Runtime while delegating per-step preparation and the scheduler update
to a native library loaded with ``ctypes``. Results are written to a WAV file
and an ``.npz`` archive.
"""

from __future__ import annotations

import argparse
import ctypes
import json
import sys
from pathlib import Path
from typing import Dict, List

import numpy as np
import onnxruntime as ort
import soundfile as sf

from runtime_py_ort.case_schema import CaseSpec
from runtime_py_ort.scheduler import resolve_timesteps

# Sample rate passed to ``soundfile.write`` for the decoded audio.
_OUTPUT_SAMPLE_RATE = 48_000


def _session(path: Path, provider: str) -> ort.InferenceSession:
    """Create an inference session for ``path`` on the requested provider.

    Args:
        path: Location of the ONNX model file.
        provider: ``"cuda"`` to prefer the CUDA execution provider; any other
            value selects the CPU provider only.

    Returns:
        The created ``ort.InferenceSession``.

    Raises:
        Exception: Whatever ONNX Runtime raises when a CPU-only session cannot
            be created. A failing CUDA session is retried on CPU instead.
    """
    providers = ["CPUExecutionProvider"]
    if provider == "cuda":
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    try:
        return ort.InferenceSession(str(path), providers=providers)
    except Exception as exc:  # broad on purpose: best-effort CUDA -> CPU fallback
        if provider != "cuda":
            raise
        print(
            f"warning: CUDA session for {path} failed ({exc}); falling back to CPU",
            file=sys.stderr,
        )
        return ort.InferenceSession(str(path), providers=["CPUExecutionProvider"])


def _load_contract(onnx_dir: Path) -> Dict[str, Dict[str, List[str]]]:
    """Load ``io_contract_core.json`` from ``onnx_dir``.

    Raises:
        FileNotFoundError: If the contract file does not exist.
    """
    contract_path = onnx_dir / "io_contract_core.json"
    if not contract_path.exists():
        raise FileNotFoundError(f"missing contract: {contract_path}")
    return json.loads(contract_path.read_text(encoding="utf-8"))


def _crop_vt(vt: np.ndarray, orig_len: int) -> np.ndarray:
    """Return ``vt`` as float32, cropped along axis 1 to ``orig_len`` entries."""
    if vt.shape[1] == orig_len:
        return vt.astype(np.float32)
    return vt[:, :orig_len, :].astype(np.float32)


def _dll_bind(lib: ctypes.CDLL) -> None:
    """Declare ``argtypes``/``restype`` for every native function used here.

    Pointer-returning functions use ``c_void_p`` so the raw address is kept
    and can be handed back to ``ace_string_free`` / ``ace_free_context``.
    """
    float_ptr = ctypes.POINTER(ctypes.c_float)
    void_ptr_ptr = ctypes.POINTER(ctypes.c_void_p)

    lib.ace_create_context.argtypes = [ctypes.c_char_p]
    lib.ace_create_context.restype = ctypes.c_void_p

    lib.ace_free_context.argtypes = [ctypes.c_void_p]
    lib.ace_free_context.restype = None

    lib.ace_string_free.argtypes = [ctypes.c_void_p]
    lib.ace_string_free.restype = None

    lib.ace_last_error.argtypes = []
    lib.ace_last_error.restype = ctypes.c_void_p

    lib.ace_prepare_step_inputs.argtypes = [
        ctypes.c_void_p,
        ctypes.c_char_p,
        float_ptr,
        ctypes.c_size_t,
        void_ptr_ptr,
    ]
    lib.ace_prepare_step_inputs.restype = ctypes.c_int32

    lib.ace_scheduler_step.argtypes = [
        ctypes.c_void_p,
        float_ptr,
        float_ptr,
        ctypes.c_size_t,
        ctypes.c_float,
        float_ptr,
    ]
    lib.ace_scheduler_step.restype = ctypes.c_int32

    lib.ace_apply_lm_constraints.argtypes = [
        ctypes.c_void_p,
        float_ptr,
        ctypes.c_size_t,
        float_ptr,
    ]
    lib.ace_apply_lm_constraints.restype = ctypes.c_int32

    lib.ace_finalize_metadata.argtypes = [
        ctypes.c_void_p,
        ctypes.POINTER(ctypes.c_int64),
        ctypes.c_size_t,
        void_ptr_ptr,
    ]
    lib.ace_finalize_metadata.restype = ctypes.c_int32


def _dll_last_error(lib: ctypes.CDLL) -> str:
    """Fetch the library's last error message, or ``"unknown"`` if none.

    The string returned by ``ace_last_error`` is released with
    ``ace_string_free`` after it has been copied into a Python ``str``.
    """
    ptr = lib.ace_last_error()
    if not ptr:
        return "unknown"
    try:
        raw = ctypes.cast(ptr, ctypes.c_char_p).value
        return (raw or b"unknown").decode("utf-8", errors="replace")
    finally:
        lib.ace_string_free(ptr)


def _dll_prepare_step(
    lib: ctypes.CDLL,
    ctx: int,
    shift: float,
    steps: int,
    step: int,
    xt: np.ndarray,
) -> None:
    """Call ``ace_prepare_step_inputs`` for one step and discard its output.

    Args:
        lib: Library previously configured with ``_dll_bind``.
        ctx: Context handle returned by ``ace_create_context``.
        shift: Sent as ``"shift"`` in the JSON payload.
        steps: Sent as ``"inference_steps"``.
        step: Sent as ``"current_step"``.
        xt: Current latent; passed flattened as contiguous float32.

    Raises:
        RuntimeError: If the native call returns a non-zero status.
    """
    payload = {
        "shift": float(shift),
        "inference_steps": int(steps),
        "current_step": int(step),
    }
    out_json = ctypes.c_void_p()
    buf = np.ascontiguousarray(xt.reshape(-1), dtype=np.float32)
    rc = lib.ace_prepare_step_inputs(
        ctx,
        json.dumps(payload).encode("utf-8"),
        buf.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        ctypes.c_size_t(buf.size),
        ctypes.byref(out_json),
    )
    if rc != 0:
        raise RuntimeError(f"ace_prepare_step_inputs failed: {_dll_last_error(lib)}")
    # The returned JSON is not used here; release it, but never hand a null
    # pointer to the native free routine.
    if out_json.value:
        lib.ace_string_free(out_json)


def _dll_scheduler_step(
    lib: ctypes.CDLL,
    ctx: int,
    xt: np.ndarray,
    vt: np.ndarray,
    dt: float,
) -> np.ndarray:
    """Run ``ace_scheduler_step`` and return its output shaped like ``xt``.

    ``xt`` and ``vt`` are flattened to contiguous float32 buffers; the element
    count passed to the library is taken from ``xt``.

    Raises:
        RuntimeError: If the native call returns a non-zero status.
    """
    x = np.ascontiguousarray(xt.reshape(-1), dtype=np.float32)
    v = np.ascontiguousarray(vt.reshape(-1), dtype=np.float32)
    out = np.empty_like(x)
    float_ptr = ctypes.POINTER(ctypes.c_float)
    rc = lib.ace_scheduler_step(
        ctx,
        x.ctypes.data_as(float_ptr),
        v.ctypes.data_as(float_ptr),
        ctypes.c_size_t(x.size),
        ctypes.c_float(dt),
        out.ctypes.data_as(float_ptr),
    )
    if rc != 0:
        raise RuntimeError(f"ace_scheduler_step failed: {_dll_last_error(lib)}")
    return out.reshape(xt.shape)


def _run_condition(
    sess: ort.InferenceSession,
    arr: Dict[str, np.ndarray],
    contract_inputs: List[str],
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return ``(encoder_hidden_states, encoder_attention_mask, context_latents)``.

    If ``arr`` already holds all three tensors they are returned directly
    (cast to float32) and ``sess`` is not run. Otherwise the session is fed
    the inputs named in ``contract_inputs`` taken from ``arr``.

    Raises:
        ValueError: If a name in ``contract_inputs`` is missing from ``arr``.
    """
    precomputed = ("encoder_hidden_states", "encoder_attention_mask", "context_latents")
    if all(k in arr for k in precomputed):
        return (
            arr["encoder_hidden_states"].astype(np.float32),
            arr["encoder_attention_mask"].astype(np.float32),
            arr["context_latents"].astype(np.float32),
        )

    feeds: Dict[str, np.ndarray] = {}
    for name in contract_inputs:
        if name not in arr:
            raise ValueError(f"inputs npz missing condition input: {name}")
        value = arr[name]
        # Two inputs are fed with non-float dtypes; everything else is float32.
        if name == "refer_audio_order_mask":
            feeds[name] = value.astype(np.int64, copy=False)
        elif name == "is_covers":
            feeds[name] = value.astype(bool, copy=False)
        else:
            feeds[name] = value.astype(np.float32, copy=False)

    out_names = [x.name for x in sess.get_outputs()]
    out = sess.run(out_names, feeds)
    out_map = dict(zip(out_names, out))
    return (
        out_map["encoder_hidden_states"].astype(np.float32),
        out_map["encoder_attention_mask"].astype(np.float32),
        out_map["context_latents"].astype(np.float32),
    )


def main() -> int:
    """Run the demo end to end and return the process exit code (``0``).

    Raises:
        RuntimeError: If the native context cannot be created or a required
            native call fails.
    """
    parser = argparse.ArgumentParser(description="Python ONNXRuntime + Rust DLL demo")
    parser.add_argument("--case", type=Path, required=True)
    parser.add_argument("--inputs-npz", type=Path)
    parser.add_argument("--onnx-dir", type=Path, default=Path("artifacts/onnx_runtime"))
    parser.add_argument(
        "--dll",
        type=Path,
        default=Path("runtime_rust_dll/target/release/acestep_runtime.dll"),
    )
    parser.add_argument("--provider", choices=["cpu", "cuda"], default="cpu")
    parser.add_argument(
        "--out-wav", type=Path, default=Path("reports/listening/demo/python_ort_dll.wav")
    )
    parser.add_argument(
        "--out-npz", type=Path, default=Path("reports/listening/demo/python_ort_dll.npz")
    )
    args = parser.parse_args()

    case = CaseSpec.from_path(args.case)
    inputs_npz = args.inputs_npz or Path(f"fixtures/tensors/{case.case_id}.npz")
    # Read every array eagerly so the archive handle can be closed right away.
    with np.load(str(inputs_npz), allow_pickle=False) as npz:
        arr = {k: npz[k] for k in npz.files}

    onnx_dir = args.onnx_dir
    contract = _load_contract(onnx_dir)

    lib = ctypes.CDLL(str(args.dll.resolve()))
    _dll_bind(lib)
    ctx = lib.ace_create_context(json.dumps({"seed": int(case.seed)}).encode("utf-8"))
    if not ctx:
        raise RuntimeError(f"ace_create_context failed: {_dll_last_error(lib)}")

    try:
        condition_sess = _session(onnx_dir / "condition_encoder.onnx", args.provider)
        encoder_hidden_states, encoder_attention_mask, context_latents = _run_condition(
            condition_sess,
            arr,
            contract.get("inputs", {}).get("condition_encoder", []),
        )

        src_latents = arr["src_latents"].astype(np.float32)
        if "xt_steps" in arr:
            xt = arr["xt_steps"][0].astype(np.float32).copy()
        else:
            xt = src_latents.copy()

        latent_masks = arr.get("latent_masks")
        if latent_masks is not None:
            attention_mask = latent_masks.astype(np.float32)
        else:
            attention_mask = np.ones((xt.shape[0], xt.shape[1]), dtype=np.float32)

        # Axis 1 is zero-padded to an even length before being fed to the DiT
        # graphs; ``vt`` is cropped back to ``orig_len`` afterwards.
        orig_len = int(xt.shape[1])
        pad_len = (-orig_len) % 2
        if pad_len:
            context_latents_padded = np.pad(
                context_latents, ((0, 0), (0, pad_len), (0, 0)), mode="constant"
            )
            attention_mask_padded = np.pad(
                attention_mask, ((0, 0), (0, pad_len)), mode="constant"
            )
        else:
            context_latents_padded = context_latents
            attention_mask_padded = attention_mask

        timesteps = resolve_timesteps(
            case.shift, None, max_steps=max(1, int(case.inference_steps))
        )

        # Use the prefill/decode KV graphs only when both files are present.
        has_kv = (onnx_dir / "dit_prefill_kv.onnx").exists() and (
            onnx_dir / "dit_decode_kv.onnx"
        ).exists()
        if has_kv:
            dit_prefill = _session(onnx_dir / "dit_prefill_kv.onnx", args.provider)
            dit_decode = _session(onnx_dir / "dit_decode_kv.onnx", args.provider)
            prefill_inputs = [x.name for x in dit_prefill.get_inputs()]
            prefill_outputs = [x.name for x in dit_prefill.get_outputs()]
            decode_inputs = [x.name for x in dit_decode.get_inputs()]
            decode_outputs = [x.name for x in dit_decode.get_outputs()]
            cache_map: Dict[str, np.ndarray] = {}
        else:
            dit = _session(onnx_dir / "dit_decoder.onnx", args.provider)
            dit_inputs = {x.name for x in dit.get_inputs()}

        xt_steps: List[np.ndarray] = []
        vt_steps: List[np.ndarray] = []
        for idx, t in enumerate(timesteps):
            _dll_prepare_step(lib, ctx, case.shift, len(timesteps), idx, xt)
            t_vec = np.full((xt.shape[0],), t, dtype=np.float32)
            if pad_len:
                xt_in = np.pad(xt, ((0, 0), (0, pad_len), (0, 0)), mode="constant")
            else:
                xt_in = xt

            if has_kv:
                base = {
                    "hidden_states": xt_in,
                    "timestep": t_vec,
                    "timestep_r": t_vec,
                    "attention_mask": attention_mask_padded,
                    "encoder_hidden_states": encoder_hidden_states,
                    "encoder_attention_mask": encoder_attention_mask,
                    "context_latents": context_latents_padded,
                }
                if idx == 0:
                    # First step: prefill graph, fed only the inputs it declares.
                    feeds = {k: base[k] for k in prefill_inputs if k in base}
                    out = dit_prefill.run(prefill_outputs, feeds)
                    out_map = dict(zip(prefill_outputs, out))
                else:
                    # Later steps: each "past_*" input is taken from the
                    # matching "present_*" output of the previous step.
                    feeds = {}
                    for name in decode_inputs:
                        if name.startswith("past_"):
                            present_name = "present_" + name[len("past_") :]
                            if present_name not in cache_map:
                                raise ValueError(f"missing cache for {present_name}")
                            feeds[name] = cache_map[present_name]
                        elif name in base:
                            feeds[name] = base[name]
                    out = dit_decode.run(decode_outputs, feeds)
                    out_map = dict(zip(decode_outputs, out))
                vt = _crop_vt(out_map["vt"], orig_len)
                cache_map = {
                    k: v.astype(np.float32)
                    for k, v in out_map.items()
                    if k.startswith("present_")
                }
            else:
                feeds = {
                    "hidden_states": xt_in,
                    "timestep": t_vec,
                    "timestep_r": t_vec,
                    "encoder_hidden_states": encoder_hidden_states,
                    "context_latents": context_latents_padded,
                }
                # The mask inputs are fed only if the graph declares them.
                if "attention_mask" in dit_inputs:
                    feeds["attention_mask"] = attention_mask_padded
                if "encoder_attention_mask" in dit_inputs:
                    feeds["encoder_attention_mask"] = encoder_attention_mask
                vt = _crop_vt(dit.run(["vt"], feeds)[0], orig_len)

            xt_steps.append(xt.copy())
            vt_steps.append(vt.copy())

            # The last step uses the full remaining timestep value as dt;
            # earlier steps use the gap to the next timestep.
            if idx == len(timesteps) - 1:
                dt = float(t)
            else:
                dt = float(t - timesteps[idx + 1])
            xt = _dll_scheduler_step(lib, ctx, xt, vt, dt)

        pred_latents = xt
        vae = _session(onnx_dir / "vae_decoder.onnx", args.provider)
        latents = np.transpose(pred_latents, (0, 2, 1)).astype(np.float32)
        audio = vae.run(["audio"], {"latents": latents})[0].astype(np.float32)
        audio_0 = audio[0]

        args.out_wav.parent.mkdir(parents=True, exist_ok=True)
        sf.write(str(args.out_wav), audio_0.T, _OUTPUT_SAMPLE_RATE, subtype="FLOAT")

        args.out_npz.parent.mkdir(parents=True, exist_ok=True)
        np.savez_compressed(
            str(args.out_npz),
            pred_latents=pred_latents.astype(np.float32),
            xt_steps=np.asarray(xt_steps, dtype=np.float32),
            vt_steps=np.asarray(vt_steps, dtype=np.float32),
            encoder_hidden_states=encoder_hidden_states.astype(np.float32),
            encoder_attention_mask=encoder_attention_mask.astype(np.float32),
            context_latents=context_latents.astype(np.float32),
        )

        # Demonstrate LM constraint / metadata APIs for simple_mode use.
        # These calls are best-effort: a failure is reported, not raised.
        logits = np.array([0.0, 1.0, 2.0, 3.0], dtype=np.float32)
        masked = np.empty_like(logits)
        rc = lib.ace_apply_lm_constraints(
            ctx,
            logits.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
            ctypes.c_size_t(logits.size),
            masked.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        )
        # NOTE(review): assumes 0 means success here, as for the other
        # c_int32-returning entry points -- confirm against the library.
        if rc != 0:
            print(
                f"warning: ace_apply_lm_constraints returned {rc}: {_dll_last_error(lib)}",
                file=sys.stderr,
            )

        token_ids = np.array([1, 2, 3], dtype=np.int64)
        out_json = ctypes.c_void_p()
        rc = lib.ace_finalize_metadata(
            ctx,
            token_ids.ctypes.data_as(ctypes.POINTER(ctypes.c_int64)),
            ctypes.c_size_t(token_ids.size),
            ctypes.byref(out_json),
        )
        if rc == 0:
            if out_json.value:
                lib.ace_string_free(out_json)
        else:
            print(
                f"warning: ace_finalize_metadata returned {rc}: {_dll_last_error(lib)}",
                file=sys.stderr,
            )

        print(f"Wrote wav: {args.out_wav}")
        print(f"Wrote npz: {args.out_npz}")
        return 0
    finally:
        # Always release the native context, even when a step above failed.
        lib.ace_free_context(ctx)


if __name__ == "__main__":
    raise SystemExit(main())