| import argparse | |
| import ctypes | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, Tuple | |
| import numpy as np | |
| import onnxruntime as ort | |
| def find_dll(path_arg: str | None) -> Path: | |
| if path_arg: | |
| return Path(path_arg) | |
| env = os.environ.get("ALLINONE_DLL") | |
| if env: | |
| return Path(env) | |
| candidates = [ | |
| Path("target") / "release" / "allinone_onnx_ffi.dll", | |
| Path("target") / "debug" / "allinone_onnx_ffi.dll", | |
| ] | |
| for cand in candidates: | |
| if cand.is_file(): | |
| return cand | |
| return candidates[-1] | |
def load_cfg(path: Path) -> Dict[str, object]:
    """Parse the UTF-8 encoded JSON file at ``path`` and return the result."""
    with path.open("r", encoding="utf-8") as handle:
        return json.load(handle)
def load_dll(path: Path) -> ctypes.CDLL:
    """Load the shared library at ``path`` with ctypes.

    Raises:
        SystemExit: If ``path`` is not an existing regular file.
    """
    if path.is_file():
        return ctypes.CDLL(str(path))
    raise SystemExit(f"DLL not found: {path}")
def last_error(dll: ctypes.CDLL) -> str:
    """Fetch the library's most recent error message as text.

    Calls ``allinone_last_error``; a null result yields ``"unknown error"``.
    Otherwise the NUL-terminated bytes at the returned address are decoded as
    UTF-8 (invalid sequences replaced) and the address is handed back to
    ``allinone_free_string``.
    """
    fetch = dll.allinone_last_error
    # Keep the raw address: a c_char_p restype would convert to bytes and
    # lose the pointer that has to be passed to the free routine.
    fetch.restype = ctypes.c_void_p
    address = fetch()
    if address:
        text = ctypes.string_at(address).decode("utf-8", errors="replace")
        dll.allinone_free_string(ctypes.c_void_p(address))
        return text
    return "unknown error"
def setup_ffi(dll: ctypes.CDLL) -> None:
    """Declare ctypes signatures for the library functions this script uses.

    Sets ``argtypes`` on every listed symbol and ``restype`` on those that
    return a handle or a status code. The three ``allinone_free_*`` symbols
    get ``argtypes`` only; their ``restype`` is left at the ctypes default.
    """
    handle = ctypes.c_void_p
    c_path = ctypes.c_char_p
    size_t = ctypes.c_size_t
    status = ctypes.c_int
    f32_ptr = ctypes.POINTER(ctypes.c_float)
    f32_out = ctypes.POINTER(f32_ptr)
    size_out = ctypes.POINTER(size_t)

    # (symbol, argument types, return type). A return type of None means
    # "do not assign restype" for that symbol.
    declared = (
        ("allinone_free_string", [handle], None),
        ("allinone_free_f32", [f32_ptr, size_t], None),
        ("allinone_free_config", [handle], None),
        ("allinone_load_config_json", [c_path], handle),
        (
            "allinone_read_wav_mono",
            [c_path, f32_out, size_out, ctypes.POINTER(ctypes.c_uint32)],
            status,
        ),
        (
            "allinone_extract_spectrogram_from_stems",
            [c_path, handle, f32_out, size_out, size_out, size_out],
            status,
        ),
        (
            "allinone_postprocess_logits",
            [
                handle,
                f32_ptr,
                f32_ptr,
                f32_ptr,
                f32_ptr,
                size_t,
                size_t,
                ctypes.POINTER(handle),
                size_out,
            ],
            status,
        ),
    )
    for symbol, params, result in declared:
        func = getattr(dll, symbol)
        func.argtypes = params
        if result is not None:
            func.restype = result
def extract_spec(
    dll: ctypes.CDLL,
    stems_dir: Path,
    cfg_handle: ctypes.c_void_p,
) -> Tuple[np.ndarray, Tuple[int, int, int]]:
    """Run the library's spectrogram extraction and copy the result to NumPy.

    Calls ``allinone_extract_spectrogram_from_stems`` with the UTF-8 encoded
    ``stems_dir`` and ``cfg_handle``, then copies the
    ``channels * frames * bins`` floats it reports into a NumPy-owned array.
    The library buffer is handed back to ``allinone_free_f32`` before
    returning.

    Args:
        dll: Loaded library whose signatures were declared by ``setup_ffi``.
        stems_dir: Directory path passed through to the library.
        cfg_handle: Config handle from ``allinone_load_config_json``.

    Returns:
        ``(spec, (channels, frames, bins))`` where ``spec`` is a float32
        array of shape ``(channels, frames, bins)``.

    Raises:
        SystemExit: With the library's last error message when the call
            returns a non-zero status.
    """
    out_ptr = ctypes.POINTER(ctypes.c_float)()
    out_channels = ctypes.c_size_t()
    out_frames = ctypes.c_size_t()
    out_bins = ctypes.c_size_t()
    rc = dll.allinone_extract_spectrogram_from_stems(
        str(stems_dir).encode("utf-8"),
        cfg_handle,
        ctypes.byref(out_ptr),
        ctypes.byref(out_channels),
        ctypes.byref(out_frames),
        ctypes.byref(out_bins),
    )
    if rc != 0:
        raise SystemExit(last_error(dll))

    channels = int(out_channels.value)
    frames = int(out_frames.value)
    bins = int(out_bins.value)
    length = channels * frames * bins
    try:
        # as_array only wraps the library's memory; copy() detaches the data
        # so it stays valid after the buffer is released below.
        view = np.ctypeslib.as_array(out_ptr, shape=(length,))
        spec = view.copy().reshape(channels, frames, bins)
    finally:
        # Release the library buffer even if the copy fails (e.g.
        # MemoryError). A NULL pointer is never passed to the free routine.
        if out_ptr:
            dll.allinone_free_f32(out_ptr, length)
    return spec, (channels, frames, bins)
def run_onnx(model_path: Path, spec: np.ndarray, provider: str) -> Dict[str, np.ndarray]:
    """Run the ONNX model on ``spec`` and return its outputs keyed by name.

    ``provider == "cuda"`` requests the CUDA execution provider with the CPU
    provider listed after it; every other value uses the CPU provider only.
    ``spec`` is fed to the model's first declared input.
    """
    cpu_only = ["CPUExecutionProvider"]
    chosen = ["CUDAExecutionProvider", *cpu_only] if provider == "cuda" else cpu_only

    session = ort.InferenceSession(model_path.as_posix(), providers=chosen)
    first_input = session.get_inputs()[0]
    results = session.run(None, {first_input.name: spec})
    output_names = [meta.name for meta in session.get_outputs()]
    return dict(zip(output_names, results))
| def squeeze_to_1d(arr: np.ndarray) -> np.ndarray: | |
| if arr.ndim == 2 and arr.shape[0] == 1: | |
| return arr[0] | |
| if arr.ndim == 1: | |
| return arr | |
| if arr.ndim == 2 and arr.shape[1] == 1: | |
| return arr[:, 0] | |
| raise SystemExit(f"Unexpected shape for 1D logits: {arr.shape}") | |
def normalize_function_logits(arr: np.ndarray, frames: int, num_labels: int) -> np.ndarray:
    """Return the function logits with shape ``(num_labels, frames)``.

    A leading axis of length 1 on a 3-D input is dropped first. The remaining
    2-D array is returned unchanged when already ``(num_labels, frames)``,
    or as a transposed view when it is ``(frames, num_labels)``.

    Raises:
        SystemExit: If the array (after dropping the leading axis) is not
            2-D, or matches neither orientation.
    """
    has_unit_batch = arr.ndim == 3 and arr.shape[0] == 1
    matrix = arr[0] if has_unit_batch else arr
    if matrix.ndim != 2:
        raise SystemExit(f"Unexpected shape for function logits: {matrix.shape}")

    label_major = (num_labels, frames)
    if matrix.shape == label_major:
        return matrix
    if matrix.shape == label_major[::-1]:
        return matrix.T
    raise SystemExit(f"Unexpected function logits shape: {matrix.shape}")
def postprocess(
    dll: ctypes.CDLL,
    cfg_handle: ctypes.c_void_p,
    logits: Dict[str, np.ndarray],
    frames: int,
    num_labels: int,
) -> dict:
    """Send the model logits to the library's post-processing and parse its JSON.

    Reads ``logits_beat``, ``logits_downbeat``, ``logits_section`` and
    ``logits_function`` from ``logits``, normalizes their shapes, and passes
    them to ``allinone_postprocess_logits`` together with ``frames`` and
    ``num_labels``. The library writes a byte buffer and its length; the
    buffer is decoded as UTF-8 JSON and released with
    ``allinone_free_string``.

    Args:
        dll: Loaded library whose signatures were declared by ``setup_ffi``.
        cfg_handle: Config handle from ``allinone_load_config_json``.
        logits: Model outputs keyed by output name.
        frames: Number of frames passed to the library.
        num_labels: Number of function labels passed to the library.

    Returns:
        The parsed JSON document produced by the library.

    Raises:
        SystemExit: On an unexpected logits shape, or with the library's last
            error message when the call returns a non-zero status.
        KeyError: If one of the four expected outputs is missing.
    """
    float_ptr = ctypes.POINTER(ctypes.c_float)

    # The library receives only a base address per array, so strides are
    # invisible to it. normalize_function_logits may return a transposed
    # *view*; without forcing a C-contiguous buffer the library would read
    # the original, untransposed layout. ascontiguousarray copies only when
    # the dtype or layout actually requires it.
    beat = np.ascontiguousarray(squeeze_to_1d(logits["logits_beat"]), dtype=np.float32)
    downbeat = np.ascontiguousarray(
        squeeze_to_1d(logits["logits_downbeat"]), dtype=np.float32
    )
    section = np.ascontiguousarray(
        squeeze_to_1d(logits["logits_section"]), dtype=np.float32
    )
    function = np.ascontiguousarray(
        normalize_function_logits(logits["logits_function"], frames, num_labels),
        dtype=np.float32,
    )

    out_json_ptr = ctypes.c_void_p()
    out_json_len = ctypes.c_size_t()
    rc = dll.allinone_postprocess_logits(
        cfg_handle,
        beat.ctypes.data_as(float_ptr),
        downbeat.ctypes.data_as(float_ptr),
        section.ctypes.data_as(float_ptr),
        function.ctypes.data_as(float_ptr),
        frames,
        num_labels,
        ctypes.byref(out_json_ptr),
        ctypes.byref(out_json_len),
    )
    if rc != 0:
        raise SystemExit(last_error(dll))

    try:
        # Copy exactly out_json_len bytes; the length is supplied explicitly
        # rather than relying on a NUL terminator.
        raw = ctypes.string_at(out_json_ptr, out_json_len.value)
    finally:
        # Release the library buffer even if the copy above fails.
        dll.allinone_free_string(out_json_ptr)
    return json.loads(raw.decode("utf-8"))
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    ``--stems-dir``, ``--onnx-model`` and ``--config-json`` are required.
    ``--output-dir`` (also accepted as ``-output-dir``) defaults to
    ``output_dll`` and ``--provider`` to ``cpu``.
    """
    cli = argparse.ArgumentParser(
        description="Run allinone pipeline via Rust DLL + ONNXRuntime (Python)."
    )
    cli.add_argument("--dll", type=str, default=None, help="Path to allinone_onnx_ffi.dll")

    # The three mandatory inputs share the same option shape.
    required_paths = (
        ("--stems-dir", "Stems directory"),
        ("--onnx-model", "ONNX model path"),
        ("--config-json", "Config JSON path"),
    )
    for flag, description in required_paths:
        cli.add_argument(flag, type=Path, required=True, help=description)

    # The single-dash spelling is registered as an alias of --output-dir.
    cli.add_argument(
        "--output-dir",
        "-output-dir",
        dest="output_dir",
        type=Path,
        default=Path("output_dll"),
    )
    cli.add_argument("--analysis-json", type=Path, default=None)
    cli.add_argument("--provider", choices=["cpu", "cuda"], default="cpu")
    cli.add_argument("--save-spec", action="store_true", help="Save extracted spec as .npy")
    return cli.parse_args()
def main() -> None:
    """Run the full pipeline: extract spectrogram, infer, post-process, write JSON.

    Steps: locate and load the library, declare its signatures, read
    ``num_labels`` from the config JSON, load the config into the library,
    extract the spectrogram (optionally saving it as ``spec.npy``), run the
    ONNX model on it with a leading batch axis added, post-process the logits
    through the library, and write the resulting analysis as indented JSON.

    Raises:
        SystemExit: If the library is missing or any library call reports
            failure.
    """
    args = parse_args()
    dll = load_dll(find_dll(args.dll))
    setup_ffi(dll)

    cfg_data = load_cfg(args.config_json)
    num_labels = int(cfg_data["data"]["num_labels"])

    cfg_handle = dll.allinone_load_config_json(str(args.config_json).encode("utf-8"))
    if not cfg_handle:
        raise SystemExit(last_error(dll))

    try:
        spec, shape = extract_spec(dll, args.stems_dir, cfg_handle)
        if args.save_spec:
            args.output_dir.mkdir(parents=True, exist_ok=True)
            np.save(args.output_dir / "spec.npy", spec)

        # Add the leading batch axis before handing the array to the model.
        spec = np.expand_dims(spec, axis=0).astype(np.float32, copy=False)
        logits = run_onnx(args.onnx_model, spec, args.provider)
        frames = shape[1]
        analysis = postprocess(dll, cfg_handle, logits, frames, num_labels)
    finally:
        # Release the config handle on every exit path, including the
        # SystemExit raised by a failing pipeline step.
        dll.allinone_free_config(cfg_handle)

    args.output_dir.mkdir(parents=True, exist_ok=True)
    analysis_path = args.analysis_json or (args.output_dir / "analysis.json")
    # An explicit --analysis-json may point outside output_dir; make sure its
    # directory exists so the write below does not fail.
    analysis_path.parent.mkdir(parents=True, exist_ok=True)
    analysis_path.write_text(json.dumps(analysis, indent=2), encoding="utf-8")
    print(f"Wrote analysis: {analysis_path}")


if __name__ == "__main__":
    main()