File size: 2,111 Bytes
ea47387 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 | from __future__ import annotations
import os
from pathlib import Path
from typing import Dict, List
import numpy as np
from contextlib import contextmanager
@contextmanager
def _suppress_native_stdio():
if os.environ.get("ZIPVOICE_AXERA_VERBOSE"):
yield
return
stdout_fd = os.dup(1)
stderr_fd = os.dup(2)
try:
with open(os.devnull, "w") as devnull:
os.dup2(devnull.fileno(), 1)
os.dup2(devnull.fileno(), 2)
yield
finally:
os.dup2(stdout_fd, 1)
os.dup2(stderr_fd, 2)
os.close(stdout_fd)
os.close(stderr_fd)
# Probe for the optional AXEngine runtime. The import runs with fds 1/2
# silenced in case the package prints from native code while loading.
# HAS_AXENGINE records whether the import succeeded so callers can check it
# before touching `axengine`.
try:
    with _suppress_native_stdio():
        import axengine
except ImportError:
    HAS_AXENGINE = False
else:
    HAS_AXENGINE = True
class AxeSession:
    """Thin wrapper around axengine.InferenceSession for one axmodel.

    The model is loaded at construction time with native stdout/stderr
    suppressed. ``run`` always returns a ``{output_name: value}`` dict,
    whatever container shape the engine hands back.
    """

    def __init__(self, model_path: str | Path):
        """Load the axmodel at *model_path*.

        Args:
            model_path: path to the ``.axmodel`` file.

        Raises:
            RuntimeError: if ``axengine`` could not be imported.
            FileNotFoundError: if *model_path* does not exist.
        """
        if not HAS_AXENGINE:
            raise RuntimeError("axengine not available; run this script on the AX650 board")
        self.path = Path(model_path)
        if not self.path.exists():
            raise FileNotFoundError(f"Model not found: {self.path}")
        with _suppress_native_stdio():
            self._session = axengine.InferenceSession(str(self.path))
        # Input/output descriptors as reported by the session; only `.name` is used here.
        self._inputs = self._session.get_inputs()
        self._outputs = self._session.get_outputs()

    @property
    def input_names(self) -> List[str]:
        """Names of the model inputs, in the order the session reports them."""
        return [item.name for item in self._inputs]

    @property
    def output_names(self) -> List[str]:
        """Names of the model outputs, in the order the session reports them."""
        return [item.name for item in self._outputs]

    def run(self, feed_dict: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Run inference and return the outputs keyed by output name.

        Args:
            feed_dict: mapping of input name to array. It is passed unchanged
                to ``InferenceSession.run``, with ``None`` requested as the
                output-name list.

        Returns:
            A dict mapping output name to value:

            - a dict returned by the engine is passed through as-is;
            - a list/tuple is paired positionally with ``output_names``;
            - a bare value is accepted only for single-output models.

        Raises:
            RuntimeError: if the engine returns a list/tuple whose length
                differs from the number of declared outputs, or a result of
                an unexpected type.
        """
        outputs = self._session.run(None, feed_dict)
        if isinstance(outputs, dict):
            return outputs
        names = self.output_names
        if isinstance(outputs, (list, tuple)):
            # zip() would silently drop surplus values or leave declared
            # outputs missing. A count mismatch means the model and the
            # wrapper disagree, so fail loudly instead.
            if len(outputs) != len(names):
                raise RuntimeError(
                    f"AXEngine returned {len(outputs)} outputs from {self.path.name}, "
                    f"expected {len(names)}: {names}"
                )
            return dict(zip(names, outputs))
        if len(names) == 1:
            return {names[0]: outputs}
        raise RuntimeError(f"Unexpected AXEngine output type from {self.path.name}: {type(outputs)}")
|