from __future__ import annotations

import ctypes
import os
import sys
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, List

import numpy as np
|
|
|
|
def _flush_stdio_buffers() -> None:
    """Best-effort flush of Python- and C-level stdout/stderr buffers.

    Called around the file-descriptor swap in ``_suppress_native_stdio`` so
    that text buffered *before* the redirect is not swallowed by it, and text
    buffered *during* it is not emitted after the real descriptors return.
    """
    for stream in (sys.stdout, sys.stderr):
        # The stream can be None (e.g. under pythonw) or an object without flush().
        flush = getattr(stream, "flush", None)
        if flush is not None:
            try:
                flush()
            except (OSError, ValueError):  # broken pipe / already-closed stream
                pass
    try:
        # fflush(NULL) flushes every C stdio output stream. This is where native
        # printf() output waits when stdout is not a terminal.
        ctypes.CDLL(None).fflush(None)
    except (OSError, AttributeError, TypeError):
        # No process-wide C runtime handle on this platform; nothing to flush.
        pass


@contextmanager
def _suppress_native_stdio():
    """Discard everything written to file descriptors 1 and 2 inside the block.

    Unlike ``contextlib.redirect_stdout``, this redirects the OS-level
    descriptors to ``os.devnull``, so output from native (C/C++) code is
    hidden too. Setting the ``ZIPVOICE_AXERA_VERBOSE`` environment variable to
    any non-empty value disables the suppression entirely.

    Suppression is best-effort: a descriptor that cannot be duplicated (for
    example because it is not open in a detached process) is left untouched
    instead of raising. Exceptions raised inside the block propagate after the
    original descriptors have been restored.
    """
    if os.environ.get("ZIPVOICE_AXERA_VERBOSE"):
        yield
        return

    # Push out anything already buffered so it reaches the real stdout/stderr.
    _flush_stdio_buffers()

    saved = []  # (target_fd, duplicate of the original target_fd)
    for target_fd in (1, 2):
        try:
            saved.append((target_fd, os.dup(target_fd)))
        except OSError:
            continue  # descriptor not open; nothing to silence or restore

    try:
        with open(os.devnull, "w") as devnull:
            for target_fd, _ in saved:
                os.dup2(devnull.fileno(), target_fd)
        # devnull's own fd can be closed now; fds 1/2 keep pointing at it.
        yield
    finally:
        # Drop output buffered during the block into devnull, not the terminal.
        _flush_stdio_buffers()
        for target_fd, saved_fd in saved:
            try:
                os.dup2(saved_fd, target_fd)
            finally:
                os.close(saved_fd)
|
|
# Optional dependency: when axengine cannot be loaded, this module must still
# import so that callers can check HAS_AXENGINE (AxeSession raises a clear
# RuntimeError in that case). The import runs inside _suppress_native_stdio,
# presumably because the native runtime prints on load.
try:
    with _suppress_native_stdio():
        import axengine
except (ImportError, OSError):
    # ImportError: the package is not installed.
    # OSError: raised by os.dup inside the suppression helper when fd 1/2 are
    # not open, and by ctypes/cffi-style loaders when a shared library cannot
    # be loaded. NOTE(review): confirm how axengine reports a missing runtime.
    HAS_AXENGINE = False
else:
    HAS_AXENGINE = True
|
|
|
|
class AxeSession:
    """Thin wrapper around ``axengine.InferenceSession`` for one axmodel.

    The native session is created once in ``__init__`` with fd-level
    stdout/stderr suppressed. ``run`` normalises whatever the engine returns
    (a dict, a list/tuple, or a single value) into ``{output_name: value}``.
    """

    def __init__(self, model_path: str | Path):
        """Open the model stored at *model_path*.

        Raises:
            RuntimeError: ``axengine`` could not be imported in this process.
            FileNotFoundError: *model_path* does not exist.
        """
        if not HAS_AXENGINE:
            raise RuntimeError("axengine not available; run this script on the AX650 board")
        self.path = Path(model_path)
        if not self.path.exists():
            raise FileNotFoundError(f"Model not found: {self.path}")
        with _suppress_native_stdio():
            self._session = axengine.InferenceSession(str(self.path))
        self._inputs = self._session.get_inputs()
        self._outputs = self._session.get_outputs()

    @property
    def input_names(self) -> List[str]:
        """Input names, in the order returned by ``get_inputs()`` (new list per call)."""
        return [item.name for item in self._inputs]

    @property
    def output_names(self) -> List[str]:
        """Output names, in the order returned by ``get_outputs()`` (new list per call)."""
        return [item.name for item in self._outputs]

    def run(self, feed_dict: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Run one inference and return the outputs keyed by output name.

        Args:
            feed_dict: Mapping from input name to array, passed to the engine
                unchanged.

        Returns:
            The engine's dict as-is, or sequence/single-value results mapped
            onto ``output_names``.

        Raises:
            RuntimeError: The engine returned a sequence whose length differs
                from the number of declared outputs, or a non-sequence value
                for a model with more than one output.
        """
        # NOTE(review): ``None`` presumably requests every output
        # (onnxruntime-style API) — confirm against the axengine docs.
        outputs = self._session.run(None, feed_dict)
        if isinstance(outputs, dict):
            return outputs

        names = self.output_names
        if isinstance(outputs, (list, tuple)):
            # zip() would silently drop values or names on a length mismatch.
            if len(outputs) != len(names):
                raise RuntimeError(
                    f"AXEngine returned {len(outputs)} outputs from {self.path.name}, "
                    f"expected {len(names)}: {names}"
                )
            return dict(zip(names, outputs))
        if len(names) == 1:
            # A bare value for a single-output model.
            return {names[0]: outputs}
        raise RuntimeError(f"Unexpected AXEngine output type from {self.path.name}: {type(outputs)}")
|
|