"""Run the allinone pipeline via the Rust FFI DLL and ONNX Runtime.

Stages, in order:

1. Load the config through the DLL (``allinone_load_config_json``).
2. Extract a spectrogram from a stems directory through the DLL.
3. Run the ONNX model on that spectrogram with ONNX Runtime.
4. Hand the model outputs back to the DLL for post-processing, which returns
   a JSON document that is written to disk.
"""

import argparse
import ctypes
import json
import os
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import onnxruntime as ort


def find_dll(path_arg: str | None) -> Path:
    """Resolve the path of the FFI DLL.

    Precedence: the explicit ``path_arg``, then the ``ALLINONE_DLL``
    environment variable, then the first existing file among the release and
    debug build outputs under ``target/``. When nothing exists the debug path
    is returned so the caller can report a concrete missing path.
    """
    if path_arg:
        return Path(path_arg)

    env = os.environ.get("ALLINONE_DLL")
    if env:
        return Path(env)

    candidates = [
        Path("target") / "release" / "allinone_onnx_ffi.dll",
        Path("target") / "debug" / "allinone_onnx_ffi.dll",
    ]
    return next((cand for cand in candidates if cand.is_file()), candidates[-1])


def load_cfg(path: Path) -> Dict[str, object]:
    """Parse the UTF-8 JSON config file at ``path``."""
    return json.loads(path.read_text(encoding="utf-8"))


def load_dll(path: Path) -> ctypes.CDLL:
    """Load the DLL at ``path``.

    Raises:
        SystemExit: if ``path`` is not an existing file.
    """
    if not path.is_file():
        raise SystemExit(f"DLL not found: {path}")
    return ctypes.CDLL(str(path))


def last_error(dll: ctypes.CDLL) -> str:
    """Fetch the DLL's last error message.

    The message pointer returned by ``allinone_last_error`` is released with
    ``allinone_free_string`` after it has been copied into a Python string.
    Returns ``"unknown error"`` when the DLL reports a NULL pointer.
    """
    # c_void_p (not c_char_p) so we keep the raw address and can free it.
    dll.allinone_last_error.restype = ctypes.c_void_p
    ptr = dll.allinone_last_error()
    if not ptr:
        return "unknown error"
    msg = ctypes.string_at(ptr).decode("utf-8", errors="replace")
    dll.allinone_free_string(ctypes.c_void_p(ptr))
    return msg


def setup_ffi(dll: ctypes.CDLL) -> None:
    """Declare argument and return types for the DLL functions used here."""
    f32_ptr = ctypes.POINTER(ctypes.c_float)
    size_ptr = ctypes.POINTER(ctypes.c_size_t)

    # Release functions. restype=None marks them as returning nothing instead
    # of ctypes' default c_int.
    # NOTE(review): assumes these are declared as returning void on the Rust
    # side - confirm against the FFI header.
    dll.allinone_free_string.argtypes = [ctypes.c_void_p]
    dll.allinone_free_string.restype = None
    dll.allinone_free_f32.argtypes = [f32_ptr, ctypes.c_size_t]
    dll.allinone_free_f32.restype = None
    dll.allinone_free_config.argtypes = [ctypes.c_void_p]
    dll.allinone_free_config.restype = None

    dll.allinone_load_config_json.argtypes = [ctypes.c_char_p]
    dll.allinone_load_config_json.restype = ctypes.c_void_p

    dll.allinone_read_wav_mono.argtypes = [
        ctypes.c_char_p,
        ctypes.POINTER(f32_ptr),
        size_ptr,
        ctypes.POINTER(ctypes.c_uint32),
    ]
    dll.allinone_read_wav_mono.restype = ctypes.c_int

    dll.allinone_extract_spectrogram_from_stems.argtypes = [
        ctypes.c_char_p,
        ctypes.c_void_p,
        ctypes.POINTER(f32_ptr),
        size_ptr,
        size_ptr,
        size_ptr,
    ]
    dll.allinone_extract_spectrogram_from_stems.restype = ctypes.c_int

    dll.allinone_postprocess_logits.argtypes = [
        ctypes.c_void_p,
        f32_ptr,
        f32_ptr,
        f32_ptr,
        f32_ptr,
        ctypes.c_size_t,
        ctypes.c_size_t,
        ctypes.POINTER(ctypes.c_void_p),
        size_ptr,
    ]
    dll.allinone_postprocess_logits.restype = ctypes.c_int


def extract_spec(
    dll: ctypes.CDLL,
    stems_dir: Path,
    cfg_handle: ctypes.c_void_p,
) -> Tuple[np.ndarray, Tuple[int, int, int]]:
    """Extract the spectrogram for ``stems_dir`` through the DLL.

    Returns:
        ``(spec, (channels, frames, bins))`` where ``spec`` is a NumPy-owned
        copy of the native buffer reshaped to ``(channels, frames, bins)``.

    Raises:
        SystemExit: if the DLL call returns a non-zero status, or reports a
            non-empty result with a NULL buffer.
    """
    out_ptr = ctypes.POINTER(ctypes.c_float)()
    out_channels = ctypes.c_size_t()
    out_frames = ctypes.c_size_t()
    out_bins = ctypes.c_size_t()

    rc = dll.allinone_extract_spectrogram_from_stems(
        str(stems_dir).encode("utf-8"),
        cfg_handle,
        ctypes.byref(out_ptr),
        ctypes.byref(out_channels),
        ctypes.byref(out_frames),
        ctypes.byref(out_bins),
    )
    if rc != 0:
        raise SystemExit(last_error(dll))

    channels = int(out_channels.value)
    frames = int(out_frames.value)
    bins = int(out_bins.value)
    length = channels * frames * bins

    # Copying from a NULL pointer would crash the interpreter.
    if length and not out_ptr:
        raise SystemExit("spectrogram extraction returned a null buffer")

    try:
        # Copy before freeing: as_array only wraps the native memory.
        native_view = np.ctypeslib.as_array(out_ptr, shape=(length,))
        spec = native_view.copy().reshape(channels, frames, bins)
    finally:
        # Release the native buffer even if the copy/reshape fails.
        dll.allinone_free_f32(out_ptr, length)
    return spec, (channels, frames, bins)


def run_onnx(model_path: Path, spec: np.ndarray, provider: str) -> Dict[str, np.ndarray]:
    """Run the ONNX model on ``spec`` and return its outputs keyed by name.

    ``provider == "cuda"`` requests the CUDA execution provider with CPU as a
    fallback; any other value uses the CPU provider only. ``spec`` is fed to
    the model's first input.
    """
    if provider == "cuda":
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]

    session = ort.InferenceSession(model_path.as_posix(), providers=providers)
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: spec})
    return {o.name: v for o, v in zip(session.get_outputs(), outputs)}


def squeeze_to_1d(arr: np.ndarray) -> np.ndarray:
    """Reduce a ``(1, N)``, ``(N,)`` or ``(N, 1)`` array to shape ``(N,)``.

    Raises:
        SystemExit: for any other shape.
    """
    if arr.ndim == 2 and arr.shape[0] == 1:
        return arr[0]
    if arr.ndim == 1:
        return arr
    if arr.ndim == 2 and arr.shape[1] == 1:
        return arr[:, 0]
    raise SystemExit(f"Unexpected shape for 1D logits: {arr.shape}")


def normalize_function_logits(arr: np.ndarray, frames: int, num_labels: int) -> np.ndarray:
    """Return the function logits with shape ``(num_labels, frames)``.

    A leading axis of size 1 is dropped. A ``(frames, num_labels)`` input is
    transposed; the transpose is a non-contiguous view, so callers that need a
    raw buffer must make it contiguous first.

    Raises:
        SystemExit: if the array is not 2-D after dropping the leading axis,
            or if its shape matches neither layout.
    """
    if arr.ndim == 3 and arr.shape[0] == 1:
        arr = arr[0]
    if arr.ndim != 2:
        raise SystemExit(f"Unexpected shape for function logits: {arr.shape}")
    if arr.shape == (num_labels, frames):
        return arr
    if arr.shape == (frames, num_labels):
        return arr.T
    raise SystemExit(f"Unexpected function logits shape: {arr.shape}")


def _as_f32_buffer(arr: np.ndarray) -> np.ndarray:
    """Return ``arr`` as a C-contiguous float32 array, copying only if needed."""
    return np.ascontiguousarray(arr, dtype=np.float32)


def _f32_ptr(arr: np.ndarray) -> "ctypes.POINTER(ctypes.c_float)":
    """Return a ``float*`` to the first element of ``arr``."""
    return arr.ctypes.data_as(ctypes.POINTER(ctypes.c_float))


def postprocess(
    dll: ctypes.CDLL,
    cfg_handle: ctypes.c_void_p,
    logits: Dict[str, np.ndarray],
    frames: int,
    num_labels: int,
) -> dict:
    """Post-process the model outputs through the DLL and return parsed JSON.

    Reads ``logits_beat``, ``logits_downbeat``, ``logits_section`` and
    ``logits_function`` from ``logits``.

    Raises:
        SystemExit: if a logits array has an unexpected shape, holds fewer
            than ``frames`` values, or the DLL call returns a non-zero status.
    """
    # The DLL only receives a base pointer, so every buffer must be
    # C-contiguous float32. This matters for the function logits: a transposed
    # view shares memory with the un-transposed array, and passing its data
    # pointer directly would expose the wrong layout.
    per_frame = []
    for key in ("logits_beat", "logits_downbeat", "logits_section"):
        values = _as_f32_buffer(squeeze_to_1d(logits[key]))
        if values.size < frames:
            # Prevent an out-of-bounds read on the native side.
            raise SystemExit(
                f"{key} has {values.size} values but {frames} frames are required"
            )
        per_frame.append(values)
    beat, downbeat, section = per_frame

    function = _as_f32_buffer(
        normalize_function_logits(logits["logits_function"], frames, num_labels)
    )

    out_json_ptr = ctypes.c_void_p()
    out_json_len = ctypes.c_size_t()
    rc = dll.allinone_postprocess_logits(
        cfg_handle,
        _f32_ptr(beat),
        _f32_ptr(downbeat),
        _f32_ptr(section),
        _f32_ptr(function),
        frames,
        num_labels,
        ctypes.byref(out_json_ptr),
        ctypes.byref(out_json_len),
    )
    if rc != 0:
        raise SystemExit(last_error(dll))

    # Copy the bytes out before handing the native string back to the DLL.
    raw = ctypes.string_at(out_json_ptr, out_json_len.value)
    dll.allinone_free_string(out_json_ptr)
    return json.loads(raw.decode("utf-8"))


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Run allinone pipeline via Rust DLL + ONNXRuntime (Python)."
    )
    parser.add_argument("--dll", type=str, default=None, help="Path to allinone_onnx_ffi.dll")
    parser.add_argument("--stems-dir", type=Path, required=True, help="Stems directory")
    parser.add_argument("--onnx-model", type=Path, required=True, help="ONNX model path")
    parser.add_argument("--config-json", type=Path, required=True, help="Config JSON path")
    parser.add_argument(
        "--output-dir",
        "-output-dir",
        dest="output_dir",
        type=Path,
        default=Path("output_dll"),
    )
    parser.add_argument("--analysis-json", type=Path, default=None)
    parser.add_argument("--provider", choices=["cpu", "cuda"], default="cpu")
    parser.add_argument("--save-spec", action="store_true", help="Save extracted spec as .npy")
    return parser.parse_args()


def main() -> None:
    """Run the pipeline end to end and write the analysis JSON."""
    args = parse_args()

    dll = load_dll(find_dll(args.dll))
    setup_ffi(dll)

    cfg_data = load_cfg(args.config_json)
    num_labels = int(cfg_data["data"]["num_labels"])

    cfg_handle = dll.allinone_load_config_json(str(args.config_json).encode("utf-8"))
    if not cfg_handle:
        raise SystemExit(last_error(dll))

    try:
        spec, shape = extract_spec(dll, args.stems_dir, cfg_handle)
        if args.save_spec:
            args.output_dir.mkdir(parents=True, exist_ok=True)
            np.save(args.output_dir / "spec.npy", spec)

        # Add a leading axis of size 1 before feeding the model.
        spec = np.expand_dims(spec, axis=0).astype(np.float32, copy=False)
        logits = run_onnx(args.onnx_model, spec, args.provider)

        frames = shape[1]
        analysis = postprocess(dll, cfg_handle, logits, frames, num_labels)
    finally:
        # Release the native config handle even when a stage fails.
        dll.allinone_free_config(cfg_handle)

    args.output_dir.mkdir(parents=True, exist_ok=True)
    analysis_path = args.analysis_json or (args.output_dir / "analysis.json")
    analysis_path.write_text(json.dumps(analysis, indent=2), encoding="utf-8")
    print(f"Wrote analysis: {analysis_path}")


if __name__ == "__main__":
    main()