ZipVoice.AXERA / scripts /zipvoice_runtime.py
HY-2012's picture
First commit
ea47387 verified
Raw
History Blame Contribute Delete
2.11 kB
from __future__ import annotations
import os
from pathlib import Path
from typing import Dict, List
import numpy as np
from contextlib import contextmanager
def _flush_python_streams():
    """Flush Python's buffered ``sys.stdout``/``sys.stderr``, ignoring failures."""
    import sys  # local import: the module-level import block does not provide ``sys``

    for stream in (sys.stdout, sys.stderr):
        flush = getattr(stream, "flush", None)
        if flush is None:  # stream is None or an object without flush()
            continue
        try:
            flush()
        except (OSError, ValueError):
            # Deliberate best effort: a closed or broken stream must not
            # prevent the descriptor redirection/restoration around this call.
            pass


@contextmanager
def _suppress_native_stdio():
    """Temporarily redirect file descriptors 1 and 2 to ``os.devnull``.

    Redirection happens at the descriptor level (``os.dup2``), so it also
    silences output that bypasses ``sys.stdout``/``sys.stderr``. When the
    ``ZIPVOICE_AXERA_VERBOSE`` environment variable is set to a non-empty
    value the context manager does nothing.

    Python's own stream buffers are flushed before the descriptors are
    swapped and again before they are restored. Without that, text buffered
    before the block could be flushed into ``os.devnull`` and lost, and text
    buffered inside the block could show up after the block ends.

    Yields:
        None. The original descriptors are restored on exit, including when
        the body raises.
    """
    if os.environ.get("ZIPVOICE_AXERA_VERBOSE"):
        yield
        return

    # Push pending Python-level output to the real descriptors first.
    _flush_python_streams()

    stdout_fd = os.dup(1)
    try:
        stderr_fd = os.dup(2)
    except OSError:
        # Do not leak the first duplicate if the second one cannot be made.
        os.close(stdout_fd)
        raise

    try:
        with open(os.devnull, "w") as devnull:
            os.dup2(devnull.fileno(), 1)
            os.dup2(devnull.fileno(), 2)
            yield
    finally:
        # Descriptors 1/2 still point at devnull here: anything Python
        # buffered during the block is discarded with the rest of the output.
        _flush_python_streams()
        try:
            os.dup2(stdout_fd, 1)
            os.dup2(stderr_fd, 2)
        finally:
            # Always release the saved duplicates, even if restoring failed.
            os.close(stdout_fd)
            os.close(stderr_fd)
# Optional dependency: try to import the AXEngine runtime with native
# stdout/stderr silenced (unless ZIPVOICE_AXERA_VERBOSE is set).
# HAS_AXENGINE records whether the import worked so that AxeSession can
# report a clear error instead of failing with a NameError later.
try:
    with _suppress_native_stdio():
        import axengine
except ImportError:
    HAS_AXENGINE = False
else:
    HAS_AXENGINE = True
class AxeSession:
    """Thin wrapper around axengine.InferenceSession for one axmodel.

    The wrapper keeps the input/output descriptors reported by the session
    and normalises the result of :meth:`run` into an ``{output_name: value}``
    dictionary.
    """

    def __init__(self, model_path: str | Path):
        """Create an inference session for the model file at *model_path*.

        Args:
            model_path: Path to the model file to load.

        Raises:
            RuntimeError: If the ``axengine`` package could not be imported.
            FileNotFoundError: If *model_path* does not exist.
        """
        if not HAS_AXENGINE:
            raise RuntimeError("axengine not available; run this script on the AX650 board")
        self.path = Path(model_path)
        if not self.path.exists():
            raise FileNotFoundError(f"Model not found: {self.path}")
        # Session creation runs with descriptor-level stdout/stderr silenced
        # (unless ZIPVOICE_AXERA_VERBOSE is set).
        with _suppress_native_stdio():
            self._session = axengine.InferenceSession(str(self.path))
        self._inputs = self._session.get_inputs()
        self._outputs = self._session.get_outputs()

    @property
    def input_names(self) -> List[str]:
        """Names of the model inputs, in the order reported by the session."""
        return [item.name for item in self._inputs]

    @property
    def output_names(self) -> List[str]:
        """Names of the model outputs, in the order reported by the session."""
        return [item.name for item in self._outputs]

    def run(self, feed_dict: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Run inference and return the outputs keyed by output name.

        Args:
            feed_dict: Mapping passed to the session as its input feed.

        Returns:
            The session result as a dictionary. A dict result is returned
            unchanged; a list/tuple result is paired positionally with
            :attr:`output_names`; any other result is treated as the single
            output of a one-output model.

        Raises:
            RuntimeError: If a list/tuple result does not contain exactly one
                value per declared output, or if the result has an unexpected
                type for a model that does not have exactly one output.
        """
        # NOTE(review): ``None`` presumably requests every output
        # (onnxruntime-style API) - confirm against the axengine docs.
        outputs = self._session.run(None, feed_dict)
        if isinstance(outputs, dict):
            return outputs

        names = self.output_names
        if isinstance(outputs, (list, tuple)):
            # zip() would silently drop unmatched names/values; surface a
            # count mismatch here instead of as a confusing KeyError later.
            if len(outputs) != len(names):
                raise RuntimeError(
                    f"AXEngine returned {len(outputs)} outputs from {self.path.name}, "
                    f"expected {len(names)}"
                )
            return dict(zip(names, outputs))

        if len(names) == 1:
            return {names[0]: outputs}
        raise RuntimeError(f"Unexpected AXEngine output type from {self.path.name}: {type(outputs)}")