MOSS-TTS-Nano.AXERA / scripts /axe_session.py
HY-2012's picture
First commit
b3a7ca2 verified
Raw
History Blame Contribute Delete
3.05 kB
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
class _NamedTensor:
    """Small record holding a tensor's ``name``, ``shape`` and ``type``.

    ``shape`` is always a list (empty when unknown) and ``type`` is either a
    string or ``None``.
    """

    def __init__(
        self,
        name: str,
        shape: list[Any] | tuple[Any, ...] | None = None,
        dtype: Any = None,
    ) -> None:
        """Store the name and normalise the shape and dtype.

        Args:
            name: Tensor name, stored unchanged.
            shape: Dimensions; copied into a new list, or ``[]`` when ``None``.
            dtype: Element type; stored as ``str(dtype)``, or ``None`` when
                not given.
        """
        self.name = name

        # Always expose a list so readers never have to handle ``None``.
        if shape is None:
            self.shape = []
        else:
            self.shape = list(shape)

        # The type is kept as text; ``None`` marks it as unknown.
        if dtype is None:
            self.type = None
        else:
            self.type = str(dtype)

    def __repr__(self) -> str:
        """Return a debug string listing name, shape and type."""
        return f"_NamedTensor(name={self.name!r}, shape={self.shape!r}, type={self.type!r})"
def _as_named_tensor(value: Any, fallback_name: str | None = None) -> _NamedTensor:
    """Build a ``_NamedTensor`` from a dict or an attribute-bearing object.

    For a dict, the keys ``name``, ``shape``/``dims`` and ``dtype``/``type``
    are read. For any other object, the attributes of the same names are
    read instead.

    Args:
        value: Tensor description, either a dict or an object with attributes.
        fallback_name: Name used when ``value`` has no name, or its name is
            ``None``.

    Returns:
        A ``_NamedTensor`` whose name has been converted with ``str``.

    Raises:
        ValueError: If neither ``value`` nor ``fallback_name`` yields a name.
    """
    if isinstance(value, dict):
        name = value.get("name")
        # Truthiness test: a falsy "shape"/"dtype" entry defers to the
        # alternative key ("dims"/"type").
        shape = value.get("shape") or value.get("dims")
        dtype = value.get("dtype") or value.get("type")
    else:
        name = getattr(value, "name", None)
        shape = getattr(value, "shape", None)
        if shape is None:
            shape = getattr(value, "dims", None)
        dtype = getattr(value, "dtype", None)
        if dtype is None:
            dtype = getattr(value, "type", None)

    # The default of ``dict.get``/``getattr`` only covers a *missing* name.
    # A name that is present but ``None`` must fall back as well, otherwise
    # a usable ``fallback_name`` would be ignored and ValueError raised.
    if name is None:
        name = fallback_name
    if name is None:
        raise ValueError(f"Cannot infer tensor name from {value!r}")
    return _NamedTensor(str(name), shape, dtype)
class AxeSession:
    """Wrapper around ``axengine.InferenceSession`` with cached tensor info.

    Input and output descriptions are collected once at construction time.
    If the underlying session cannot provide them, the names passed by the
    caller are used instead.
    """

    def __init__(
        self,
        axmodel_path: str | Path,
        input_names: list[str],
        output_names: list[str],
    ) -> None:
        """Open the model and record its input/output descriptions.

        Args:
            axmodel_path: Location of the model file handed to axengine.
            input_names: Names used per position when the session reports a
                tensor without a name, and as the full input list when the
                session reports nothing or the query fails.
            output_names: Same role as ``input_names``, for the outputs.
        """
        # Imported here rather than at module level, so axengine is only
        # required once a session is actually created.
        import axengine as axe

        self._path = Path(axmodel_path)
        self._session = axe.InferenceSession(str(self._path))

        # The lambdas defer the attribute lookup on the session, so a
        # missing ``get_inputs``/``get_outputs`` is handled by the helper's
        # fallback path as well.
        self._input_infos = self._describe(
            lambda: self._session.get_inputs(), input_names
        )
        self._output_infos = self._describe(
            lambda: self._session.get_outputs(), output_names
        )

    @staticmethod
    def _describe(fetch: Any, fallback_names: list[str]) -> list[_NamedTensor]:
        """Return tensor infos from ``fetch()``, or name-only stubs on failure."""
        try:
            reported = fetch()
            if reported:
                described = []
                for position, tensor in enumerate(reported):
                    if position < len(fallback_names):
                        hint = fallback_names[position]
                    else:
                        hint = None
                    described.append(_as_named_tensor(tensor, hint))
                return described
        except Exception:
            # Best effort: any failure while querying or converting the
            # metadata falls through to the caller-supplied names below.
            pass
        return [_NamedTensor(label) for label in fallback_names]

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, np.ndarray],
        *args: Any,
        **kwargs: Any,
    ) -> list[np.ndarray]:
        """Run inference on ``input_feed`` and return the session's result.

        ``output_names`` and any extra positional/keyword arguments are
        accepted but not forwarded; the underlying session is always called
        with ``None`` as its first argument.
        """
        results = self._session.run(None, input_feed)
        return results

    def get_inputs(self) -> list[_NamedTensor]:
        """Return the input descriptions recorded at construction."""
        return self._input_infos

    def get_outputs(self) -> list[_NamedTensor]:
        """Return the output descriptions recorded at construction."""
        return self._output_infos

    def __repr__(self) -> str:
        """Return a short debug string containing the model file name."""
        return f"AxeSession({self._path.name!r})"