from __future__ import annotations import os from pathlib import Path from typing import Dict, List import numpy as np from contextlib import contextmanager @contextmanager def _suppress_native_stdio(): if os.environ.get("ZIPVOICE_AXERA_VERBOSE"): yield return stdout_fd = os.dup(1) stderr_fd = os.dup(2) try: with open(os.devnull, "w") as devnull: os.dup2(devnull.fileno(), 1) os.dup2(devnull.fileno(), 2) yield finally: os.dup2(stdout_fd, 1) os.dup2(stderr_fd, 2) os.close(stdout_fd) os.close(stderr_fd) try: with _suppress_native_stdio(): import axengine HAS_AXENGINE = True except ImportError: HAS_AXENGINE = False class AxeSession: """Thin wrapper around axengine.InferenceSession for one axmodel.""" def __init__(self, model_path: str | Path): if not HAS_AXENGINE: raise RuntimeError("axengine not available; run this script on the AX650 board") self.path = Path(model_path) if not self.path.exists(): raise FileNotFoundError(f"Model not found: {self.path}") with _suppress_native_stdio(): self._session = axengine.InferenceSession(str(self.path)) self._inputs = self._session.get_inputs() self._outputs = self._session.get_outputs() @property def input_names(self) -> List[str]: return [item.name for item in self._inputs] @property def output_names(self) -> List[str]: return [item.name for item in self._outputs] def run(self, feed_dict: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: outputs = self._session.run(None, feed_dict) if isinstance(outputs, dict): return outputs if isinstance(outputs, (list, tuple)): return {name: value for name, value in zip(self.output_names, outputs)} if len(self.output_names) == 1: return {self.output_names[0]: outputs} raise RuntimeError(f"Unexpected AXEngine output type from {self.path.name}: {type(outputs)}")