# allinone-DLL-ONNX / python /run_allinone_ffi.py
# zukky's picture
# Upload folder using huggingface_hub
# 3ec79eb verified
import argparse
import ctypes
import json
import os
from pathlib import Path
from typing import Dict, Tuple
import numpy as np
import onnxruntime as ort
def find_dll(path_arg: str | None) -> Path:
if path_arg:
return Path(path_arg)
env = os.environ.get("ALLINONE_DLL")
if env:
return Path(env)
candidates = [
Path("target") / "release" / "allinone_onnx_ffi.dll",
Path("target") / "debug" / "allinone_onnx_ffi.dll",
]
for cand in candidates:
if cand.is_file():
return cand
return candidates[-1]
def load_cfg(path: Path) -> Dict[str, object]:
    """Parse the UTF-8 encoded JSON file at ``path`` and return the decoded value."""
    with path.open("r", encoding="utf-8") as handle:
        return json.load(handle)
def load_dll(path: Path) -> ctypes.CDLL:
    """Load the shared library at ``path`` with ``ctypes.CDLL``.

    Raises:
        SystemExit: if ``path`` is not an existing regular file.
    """
    if path.is_file():
        return ctypes.CDLL(str(path))
    raise SystemExit(f"DLL not found: {path}")
def last_error(dll: ctypes.CDLL) -> str:
    """Return the message reported by the DLL's ``allinone_last_error``.

    The call yields a pointer to a NUL-terminated string. The bytes are
    copied out, the pointer is handed back to ``allinone_free_string``, and
    the copy is decoded as UTF-8 (invalid sequences are replaced). A null
    pointer yields the literal ``"unknown error"``.
    """
    fetch = dll.allinone_last_error
    # ctypes defaults to a C int result; declare a pointer-sized one instead.
    fetch.restype = ctypes.c_void_p
    address = fetch()
    if not address:
        return "unknown error"
    # string_at copies the bytes, so releasing the native string afterwards is safe.
    raw = ctypes.string_at(address)
    dll.allinone_free_string(ctypes.c_void_p(address))
    return raw.decode("utf-8", errors="replace")
def setup_ffi(dll: ctypes.CDLL) -> None:
    """Declare ctypes signatures for the DLL entry points.

    ``argtypes`` is set for every listed function. ``restype`` is set only
    for the config loader (``c_void_p``) and the three functions whose
    result is an ``int`` status code; the three ``allinone_free_*``
    functions keep whatever ``restype`` they already have.

    ``allinone_read_wav_mono`` is declared here although nothing in the code
    visible in this file calls it.
    """
    float_p = ctypes.POINTER(ctypes.c_float)
    size_p = ctypes.POINTER(ctypes.c_size_t)
    leave_restype = object()  # sentinel: do not touch the function's restype

    # name -> (argtypes, restype); insertion order is the order symbols are resolved.
    signatures = {
        "allinone_free_string": ([ctypes.c_void_p], leave_restype),
        "allinone_free_f32": ([float_p, ctypes.c_size_t], leave_restype),
        "allinone_free_config": ([ctypes.c_void_p], leave_restype),
        "allinone_load_config_json": ([ctypes.c_char_p], ctypes.c_void_p),
        "allinone_read_wav_mono": (
            [
                ctypes.c_char_p,
                ctypes.POINTER(float_p),
                size_p,
                ctypes.POINTER(ctypes.c_uint32),
            ],
            ctypes.c_int,
        ),
        "allinone_extract_spectrogram_from_stems": (
            [
                ctypes.c_char_p,
                ctypes.c_void_p,
                ctypes.POINTER(float_p),
                size_p,
                size_p,
                size_p,
            ],
            ctypes.c_int,
        ),
        "allinone_postprocess_logits": (
            [
                ctypes.c_void_p,
                float_p,
                float_p,
                float_p,
                float_p,
                ctypes.c_size_t,
                ctypes.c_size_t,
                ctypes.POINTER(ctypes.c_void_p),
                size_p,
            ],
            ctypes.c_int,
        ),
    }

    for symbol, (arg_types, result_type) in signatures.items():
        func = getattr(dll, symbol)
        func.argtypes = arg_types
        if result_type is not leave_restype:
            func.restype = result_type
def extract_spec(
    dll: ctypes.CDLL,
    stems_dir: Path,
    cfg_handle: ctypes.c_void_p,
) -> Tuple[np.ndarray, Tuple[int, int, int]]:
    """Call ``allinone_extract_spectrogram_from_stems`` and copy out its result.

    Args:
        dll: Library handle exposing the extraction and ``allinone_free_f32``
            functions.
        stems_dir: Directory passed to the DLL as a UTF-8 encoded path.
        cfg_handle: Opaque config handle forwarded unchanged to the DLL.

    Returns:
        ``(spec, (channels, frames, bins))`` where ``spec`` is a NumPy-owned
        float32 array of that shape, copied from the DLL's flat buffer.

    Raises:
        SystemExit: if the DLL returns a non-zero status (the message comes
            from ``last_error``), or if it reports success with a non-empty
            shape but a null buffer.
    """
    out_ptr = ctypes.POINTER(ctypes.c_float)()
    out_channels = ctypes.c_size_t()
    out_frames = ctypes.c_size_t()
    out_bins = ctypes.c_size_t()
    rc = dll.allinone_extract_spectrogram_from_stems(
        str(stems_dir).encode("utf-8"),
        cfg_handle,
        ctypes.byref(out_ptr),
        ctypes.byref(out_channels),
        ctypes.byref(out_frames),
        ctypes.byref(out_bins),
    )
    if rc != 0:
        raise SystemExit(last_error(dll))

    shape = (int(out_channels.value), int(out_frames.value), int(out_bins.value))
    length = shape[0] * shape[1] * shape[2]

    if not out_ptr:
        # Nothing to read or free; dereferencing a null pointer would fail.
        if length:
            raise SystemExit("Spectrogram extraction returned a null buffer")
        return np.zeros(shape, dtype=np.float32), shape

    try:
        # as_array is a view over DLL-owned memory; copy before releasing it.
        spec = np.ctypeslib.as_array(out_ptr, shape=(length,)).copy().reshape(shape)
    finally:
        # Release the native buffer even if the copy/reshape raised.
        dll.allinone_free_f32(out_ptr, length)
    return spec, shape
def run_onnx(model_path: Path, spec: np.ndarray, provider: str) -> Dict[str, np.ndarray]:
    """Run the ONNX model on ``spec`` and return its outputs keyed by output name.

    ``provider == "cuda"`` requests the CUDA execution provider with the CPU
    provider listed second; every other value selects the CPU provider only.
    ``spec`` is fed to the model's first declared input.
    """
    cpu_only = ["CPUExecutionProvider"]
    provider_table = {
        "cpu": cpu_only,
        "cuda": ["CUDAExecutionProvider", "CPUExecutionProvider"],
    }
    chosen = provider_table.get(provider, cpu_only)

    session = ort.InferenceSession(model_path.as_posix(), providers=chosen)
    feed = {session.get_inputs()[0].name: spec}
    results = session.run(None, feed)
    output_names = [meta.name for meta in session.get_outputs()]
    return dict(zip(output_names, results))
def squeeze_to_1d(arr: np.ndarray) -> np.ndarray:
    """Reduce ``arr`` to one dimension by dropping a single length-1 axis.

    A 1-D array is returned unchanged. A 2-D array of shape ``(1, N)`` yields
    its only row, and one of shape ``(N, 1)`` yields its only column (the row
    case is checked first, so ``(1, 1)`` yields the row).

    Raises:
        SystemExit: for any other shape.
    """
    if arr.ndim == 1:
        return arr
    if arr.ndim == 2:
        n_rows, n_cols = arr.shape
        if n_rows == 1:
            return arr[0]
        if n_cols == 1:
            return arr[:, 0]
    raise SystemExit(f"Unexpected shape for 1D logits: {arr.shape}")
def normalize_function_logits(arr: np.ndarray, frames: int, num_labels: int) -> np.ndarray:
    """Return the function logits as a C-contiguous ``(num_labels, frames)`` array.

    A leading batch axis of length 1 is dropped first. A ``(num_labels,
    frames)`` input is accepted as is and a ``(frames, num_labels)`` input is
    transposed. When ``frames == num_labels`` the two layouts cannot be told
    apart and the input is assumed to already be ``(num_labels, frames)``.

    The result is always C-contiguous. A bare ``arr.T`` is only a strided
    view over the original buffer, so taking its raw data pointer (as the
    FFI caller in this file does) would expose the untransposed layout.

    Raises:
        SystemExit: if the array is not 2-D after dropping the batch axis, or
            its shape matches neither expected layout.
    """
    if arr.ndim == 3 and arr.shape[0] == 1:
        arr = arr[0]
    if arr.ndim != 2:
        raise SystemExit(f"Unexpected shape for function logits: {arr.shape}")
    if arr.shape == (num_labels, frames):
        # No copy is made when the input is already C-contiguous.
        return np.ascontiguousarray(arr)
    if arr.shape == (frames, num_labels):
        # Materialise the transpose so memory order matches the reported shape.
        return np.ascontiguousarray(arr.T)
    raise SystemExit(f"Unexpected function logits shape: {arr.shape}")
def postprocess(
    dll: ctypes.CDLL,
    cfg_handle: ctypes.c_void_p,
    logits: Dict[str, np.ndarray],
    frames: int,
    num_labels: int,
) -> dict:
    """Pass the model logits to ``allinone_postprocess_logits`` and decode its JSON.

    Args:
        dll: Library handle exposing the postprocess and free-string functions.
        cfg_handle: Opaque config handle forwarded unchanged to the DLL.
        logits: Model outputs; the keys ``logits_beat``, ``logits_downbeat``,
            ``logits_section`` and ``logits_function`` are read.
        frames: Frame count passed to the DLL and used to orient the function
            logits.
        num_labels: Label count passed to the DLL and used to orient the
            function logits.

    Returns:
        The object decoded from the UTF-8 JSON buffer the DLL produced.

    Raises:
        SystemExit: if the DLL returns a non-zero status (message taken from
            ``last_error``) or reports success without a JSON buffer.

    NOTE(review): the lengths of the three 1-D logit arrays are not checked
    against ``frames``; presumably the DLL reads ``frames`` floats from each
    -- confirm against the native signature.
    """
    # Only raw data pointers cross the FFI boundary, so each array must be
    # float32 AND C-contiguous. astype(copy=False) alone keeps strided views
    # (e.g. a transposed array) as they are, which would expose the wrong
    # memory layout to native code.
    beat = np.ascontiguousarray(squeeze_to_1d(logits["logits_beat"]), dtype=np.float32)
    downbeat = np.ascontiguousarray(
        squeeze_to_1d(logits["logits_downbeat"]), dtype=np.float32
    )
    section = np.ascontiguousarray(
        squeeze_to_1d(logits["logits_section"]), dtype=np.float32
    )
    function = np.ascontiguousarray(
        normalize_function_logits(logits["logits_function"], frames, num_labels),
        dtype=np.float32,
    )

    float_p = ctypes.POINTER(ctypes.c_float)
    out_json_ptr = ctypes.c_void_p()
    out_json_len = ctypes.c_size_t()
    rc = dll.allinone_postprocess_logits(
        cfg_handle,
        beat.ctypes.data_as(float_p),
        downbeat.ctypes.data_as(float_p),
        section.ctypes.data_as(float_p),
        function.ctypes.data_as(float_p),
        frames,
        num_labels,
        ctypes.byref(out_json_ptr),
        ctypes.byref(out_json_len),
    )
    if rc != 0:
        raise SystemExit(last_error(dll))
    if not out_json_ptr:
        # Reading from a null pointer would crash the interpreter.
        raise SystemExit("allinone_postprocess_logits returned no JSON output")

    try:
        # string_at copies the bytes, so the native buffer can be released after.
        raw = ctypes.string_at(out_json_ptr, out_json_len.value)
    finally:
        dll.allinone_free_string(out_json_ptr)
    return json.loads(raw.decode("utf-8"))
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    ``--stems-dir``, ``--onnx-model`` and ``--config-json`` are required
    paths. ``--output-dir`` (also accepted as ``-output-dir``) defaults to
    ``output_dll``, ``--provider`` is ``cpu`` or ``cuda`` (default ``cpu``),
    and ``--dll`` / ``--analysis-json`` default to ``None``.
    """
    cli = argparse.ArgumentParser(
        description="Run allinone pipeline via Rust DLL + ONNXRuntime (Python)."
    )
    cli.add_argument("--dll", type=str, default=None, help="Path to allinone_onnx_ffi.dll")

    required_paths = (
        ("--stems-dir", "Stems directory"),
        ("--onnx-model", "ONNX model path"),
        ("--config-json", "Config JSON path"),
    )
    for flag, description in required_paths:
        cli.add_argument(flag, type=Path, required=True, help=description)

    cli.add_argument(
        "--output-dir",
        "-output-dir",
        dest="output_dir",
        type=Path,
        default=Path("output_dll"),
    )
    cli.add_argument("--analysis-json", type=Path, default=None)
    cli.add_argument("--provider", choices=["cpu", "cuda"], default="cpu")
    cli.add_argument("--save-spec", action="store_true", help="Save extracted spec as .npy")
    return cli.parse_args()
def main() -> None:
    """Run the pipeline: load the DLL, extract the spec, run ONNX, postprocess.

    Steps:
        1. Parse the command line, then locate and load the DLL.
        2. Read ``data.num_labels`` from the config JSON and have the DLL load
           the same file into a native config handle.
        3. Extract the spec from the stems directory (optionally saving it as
           ``spec.npy`` in the output directory).
        4. Run the ONNX model on the spec with a batch axis prepended.
        5. Postprocess the logits through the DLL and write the resulting
           analysis as indented JSON.

    Raises:
        SystemExit: if the DLL cannot load the config (message taken from
            ``last_error``); helper functions raise it for their own failures.
    """
    args = parse_args()
    dll = load_dll(find_dll(args.dll))
    setup_ffi(dll)

    cfg_data = load_cfg(args.config_json)
    num_labels = int(cfg_data["data"]["num_labels"])

    cfg_handle = dll.allinone_load_config_json(str(args.config_json).encode("utf-8"))
    if not cfg_handle:
        raise SystemExit(last_error(dll))

    try:
        spec, shape = extract_spec(dll, args.stems_dir, cfg_handle)
        if args.save_spec:
            args.output_dir.mkdir(parents=True, exist_ok=True)
            np.save(args.output_dir / "spec.npy", spec)

        # Prepend a batch axis of size 1 before feeding the model.
        spec = np.expand_dims(spec, axis=0).astype(np.float32, copy=False)
        logits = run_onnx(args.onnx_model, spec, args.provider)

        frames = shape[1]
        analysis = postprocess(dll, cfg_handle, logits, frames, num_labels)
    finally:
        # Release the native config handle on success and on every error path.
        dll.allinone_free_config(cfg_handle)

    args.output_dir.mkdir(parents=True, exist_ok=True)
    analysis_path = args.analysis_json or (args.output_dir / "analysis.json")
    # --analysis-json may point outside output_dir; make sure its folder exists.
    analysis_path.parent.mkdir(parents=True, exist_ok=True)
    analysis_path.write_text(json.dumps(analysis, indent=2), encoding="utf-8")
    print(f"Wrote analysis: {analysis_path}")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()