from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np


class _NamedTensor:
    """Minimal tensor descriptor exposing ``name``, ``shape`` and ``type``.

    ``shape`` is always a list (empty when no shape was supplied) and ``type``
    is ``str(dtype)``, or ``None`` when no dtype was supplied.
    """

    def __init__(
        self,
        name: str,
        shape: list[Any] | tuple[Any, ...] | None = None,
        dtype: Any = None,
    ) -> None:
        self.name = name
        self.shape = [] if shape is None else list(shape)
        self.type = None if dtype is None else str(dtype)

    def __repr__(self) -> str:
        return f"_NamedTensor(name={self.name!r}, shape={self.shape!r}, type={self.type!r})"


def _first_present(value: Any, *keys: str) -> Any:
    """Return the first non-``None`` entry found under *keys* on *value*.

    Dicts are read with ``value.get(key)``; any other object with
    ``getattr(value, key, None)``. Returns ``None`` if no key yields a value.
    An empty (but non-``None``) entry such as ``[]`` counts as present, so a
    scalar's empty shape is not mistaken for a missing one.
    """
    is_mapping = isinstance(value, dict)
    for key in keys:
        found = value.get(key) if is_mapping else getattr(value, key, None)
        if found is not None:
            return found
    return None


def _as_named_tensor(value: Any, fallback_name: str | None = None) -> _NamedTensor:
    """Normalise a backend tensor description into a :class:`_NamedTensor`.

    *value* may be a dict or any object. The shape is taken from ``shape``
    (falling back to ``dims``) and the dtype from ``dtype`` (falling back to
    ``type``). If *value* has no name, or its name is ``None``,
    *fallback_name* is used.

    Raises:
        ValueError: if neither *value* nor *fallback_name* provides a name.
    """
    name = _first_present(value, "name")
    if name is None:
        name = fallback_name
    if name is None:
        raise ValueError(f"Cannot infer tensor name from {value!r}")
    return _NamedTensor(
        str(name),
        _first_present(value, "shape", "dims"),
        _first_present(value, "dtype", "type"),
    )


def _describe_tensors(
    session: Any,
    method_name: str,
    expected_names: list[str],
) -> list[_NamedTensor]:
    """Describe a session's tensors, falling back to *expected_names*.

    Calls ``session.<method_name>()`` and normalises each returned entry with
    :func:`_as_named_tensor`. The caller-supplied name at the same position,
    if any, serves as that entry's fallback name. If the call returns an
    empty result, or anything along the way raises, plain descriptors built
    from *expected_names* are returned instead.
    """
    try:
        reported = getattr(session, method_name)()
        if reported:
            return [
                _as_named_tensor(
                    entry,
                    expected_names[i] if i < len(expected_names) else None,
                )
                for i, entry in enumerate(reported)
            ]
    except Exception:
        # Deliberately best-effort: a backend that cannot describe its tensors
        # (missing method, unexpected entry format, unnamed entry) should not
        # prevent the session from being constructed.
        pass
    return [_NamedTensor(n) for n in expected_names]


class AxeSession:
    """Inference session backed by ``axengine.InferenceSession``.

    Provides ``run`` / ``get_inputs`` / ``get_outputs`` with onnxruntime-style
    signatures. NOTE(review): presumably this lets callers use it wherever an
    onnxruntime session is expected; confirm against callers.

    Args:
        axmodel_path: Model file path, passed as a string to
            ``axengine.InferenceSession``.
        input_names: Expected input names. They are used as fallback names when
            the session cannot report its own inputs, or reports unnamed ones.
        output_names: Expected output names, used the same way for outputs.
    """

    def __init__(
        self,
        axmodel_path: str | Path,
        input_names: list[str],
        output_names: list[str],
    ) -> None:
        # Imported here rather than at module level, so importing this module
        # does not require axengine to be installed.
        import axengine as axe

        self._path = Path(axmodel_path)
        self._session = axe.InferenceSession(str(self._path))
        self._input_infos = _describe_tensors(self._session, "get_inputs", input_names)
        self._output_infos = _describe_tensors(self._session, "get_outputs", output_names)

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, np.ndarray],
        *args: Any,
        **kwargs: Any,
    ) -> list[np.ndarray]:
        """Run inference on *input_feed* and return the requested outputs.

        The backend is always asked for every output. When *output_names* is
        ``None`` or empty, its result is returned unchanged. Otherwise only the
        named outputs are returned, in the order requested, looked up against
        the names reported by :meth:`get_outputs`. Extra positional and keyword
        arguments (such as run options) are accepted for signature
        compatibility and ignored.

        Raises:
            ValueError: if a requested output name is not a known output.
        """
        results = self._session.run(None, input_feed)
        if not output_names:
            return results

        positions: dict[str, int] = {}
        for idx, info in enumerate(self._output_infos):
            # If names repeat, the first occurrence wins.
            positions.setdefault(info.name, idx)

        unknown = [name for name in output_names if name not in positions]
        if unknown:
            raise ValueError(
                f"Unknown output name(s) {unknown!r}; available: {list(positions)!r}"
            )
        return [results[positions[name]] for name in output_names]

    def get_inputs(self) -> list[_NamedTensor]:
        """Return the input descriptors resolved at construction time."""
        return self._input_infos

    def get_outputs(self) -> list[_NamedTensor]:
        """Return the output descriptors resolved at construction time."""
        return self._output_infos

    def __repr__(self) -> str:
        return f"AxeSession({self._path.name!r})"