| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
|
|
|
|
| class _NamedTensor: |
| def __init__( |
| self, |
| name: str, |
| shape: list[Any] | tuple[Any, ...] | None = None, |
| dtype: Any = None, |
| ) -> None: |
| self.name = name |
| self.shape = list(shape) if shape is not None else [] |
| self.type = str(dtype) if dtype is not None else None |
|
|
| def __repr__(self) -> str: |
| return f"_NamedTensor(name={self.name!r}, shape={self.shape!r}, type={self.type!r})" |
|
|
|
|
def _as_named_tensor(value: Any, fallback_name: str | None = None) -> _NamedTensor:
    """Coerce a tensor description into a :class:`_NamedTensor`.

    ``value`` may be a ``dict`` (read with ``.get``) or any other object
    (read with ``getattr``). Fields are looked up in this priority order:

    * name:  ``name``, then ``fallback_name``
    * shape: ``shape``, then ``dims``
    * dtype: ``dtype``, then ``type``

    A field counts as missing only when it is absent or ``None``. Falsy but
    meaningful values are kept as-is, for example an empty shape ``()`` or a
    NumPy array. Both input kinds follow the same rule.

    Args:
        value: Mapping or object describing a tensor.
        fallback_name: Name to use when ``value`` supplies none.

    Returns:
        A new ``_NamedTensor`` whose name is ``str(name)``.

    Raises:
        ValueError: If neither ``value`` nor ``fallback_name`` yields a name.
    """
    if isinstance(value, dict):
        def lookup(key: str) -> Any:
            return value.get(key)
    else:
        def lookup(key: str) -> Any:
            return getattr(value, key, None)

    # Use explicit ``is None`` checks rather than ``or``. ``or`` raises on an
    # ndarray, and it drops falsy values such as ``()`` or a dtype object whose
    # __len__ is 0.
    name = lookup("name")
    if name is None:
        name = fallback_name

    shape = lookup("shape")
    if shape is None:
        shape = lookup("dims")

    dtype = lookup("dtype")
    if dtype is None:
        dtype = lookup("type")

    if name is None:
        raise ValueError(f"Cannot infer tensor name from {value!r}")
    return _NamedTensor(str(name), shape, dtype)
|
|
|
|
class AxeSession:
    """Wrap ``axengine.InferenceSession`` behind get_inputs/get_outputs/run methods.

    Tensor metadata is read from the axengine session when it is available.
    Otherwise it falls back to the names passed to the constructor.
    """

    def __init__(
        self,
        axmodel_path: str | Path,
        input_names: list[str],
        output_names: list[str],
    ) -> None:
        """Open ``axmodel_path`` with axengine and collect tensor metadata.

        Args:
            axmodel_path: Path to the compiled model file.
            input_names: Fallback input names. They are matched by position to
                reported inputs that lack a name, or used wholesale when the
                session reports no inputs or the query fails.
            output_names: Same as ``input_names``, for outputs.
        """
        # The import is deferred to construction time. Presumably this lets the
        # module be imported without axengine installed; confirm.
        import axengine as axe

        self._path = Path(axmodel_path)
        self._session = axe.InferenceSession(str(self._path))
        self._input_infos = self._describe_tensors(self._session, "get_inputs", input_names)
        self._output_infos = self._describe_tensors(self._session, "get_outputs", output_names)

    @staticmethod
    def _describe_tensors(
        session: Any,
        getter: str,
        names: list[str],
    ) -> list[_NamedTensor]:
        """Build descriptors from ``session.<getter>()``, falling back to ``names``.

        The ``i``-th entry of ``names`` is the fallback name for the ``i``-th
        reported tensor. If the getter is missing, raises, or returns an empty
        result, one bare descriptor per entry of ``names`` is returned instead.
        Descriptor conversion failures inside ``_as_named_tensor`` also trigger
        this fallback.
        """
        try:
            # The attribute lookup stays inside the try, so a session without
            # the getter also falls back.
            reported = getattr(session, getter)()
            if reported:
                return [
                    _as_named_tensor(item, names[i] if i < len(names) else None)
                    for i, item in enumerate(reported)
                ]
        except Exception:
            # Deliberate best effort: any failure degrades to the
            # caller-supplied names instead of aborting construction.
            pass
        return [_NamedTensor(n) for n in names]

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, np.ndarray],
        *args: Any,
        **kwargs: Any,
    ) -> list[np.ndarray]:
        """Run inference on ``input_feed`` and return every model output.

        ``output_names``, ``*args`` and ``**kwargs`` are accepted but ignored.
        The underlying session is always called with ``None`` as the output
        selection, so callers get all outputs in the order axengine returns
        them, regardless of which names they pass.
        """
        return self._session.run(None, input_feed)

    def get_inputs(self) -> list[_NamedTensor]:
        """Return the input descriptors collected at construction."""
        return self._input_infos

    def get_outputs(self) -> list[_NamedTensor]:
        """Return the output descriptors collected at construction."""
        return self._output_infos

    def __repr__(self) -> str:
        return f"AxeSession({self._path.name!r})"
|
|