File size: 2,111 Bytes
ea47387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from __future__ import annotations

import os
import sys
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, List

import numpy as np


def _flush_std_streams():
    """Flush Python-level ``sys.stdout`` / ``sys.stderr`` buffers, best effort."""
    for stream in (sys.stdout, sys.stderr):
        flush = getattr(stream, "flush", None)
        if flush is None:
            # Stream may be None (e.g. no console attached); nothing to flush.
            continue
        try:
            flush()
        except (OSError, ValueError):
            # Deliberately best-effort: a closed or broken stream must not
            # prevent the descriptor redirection from happening.
            pass


@contextmanager
def _suppress_native_stdio():
    """Redirect file descriptors 1 and 2 to ``os.devnull`` for the ``with`` body.

    Redirecting the descriptors themselves (rather than ``sys.stdout``) also
    silences output written directly to fd 1/2, e.g. by native extension code.
    The original descriptors are restored on exit, even if the body raises.
    If the ``ZIPVOICE_AXERA_VERBOSE`` environment variable is set to a
    non-empty value, nothing is redirected.

    Raises:
        OSError: if fd 1 or fd 2 cannot be duplicated (e.g. it is closed).
    """
    if os.environ.get("ZIPVOICE_AXERA_VERBOSE"):
        yield
        return

    # Flush pending Python output so it is written to the real console now,
    # instead of being flushed into /dev/null while the redirect is active.
    _flush_std_streams()

    saved_stdout = os.dup(1)
    try:
        saved_stderr = os.dup(2)
    except BaseException:
        # Do not leak the stdout copy if duplicating stderr fails.
        os.close(saved_stdout)
        raise

    try:
        with open(os.devnull, "w") as devnull:
            os.dup2(devnull.fileno(), 1)
            os.dup2(devnull.fileno(), 2)
            yield
    finally:
        # Flush Python output produced inside the block while the descriptors
        # still point at /dev/null, so it is suppressed rather than leaking
        # out after the restore.
        _flush_std_streams()
        try:
            os.dup2(saved_stdout, 1)
            os.dup2(saved_stderr, 2)
        finally:
            os.close(saved_stdout)
            os.close(saved_stderr)

# Import the AXEngine runtime if present, silencing anything written to
# fd 1/2 during the import. AxeSession refuses to load models when the flag
# is False.
try:
    with _suppress_native_stdio():
        import axengine
except ImportError:
    # Expected off-board: axengine is only available on the AX650 target.
    HAS_AXENGINE = False
else:
    HAS_AXENGINE = True


class AxeSession:
    """Thin wrapper around ``axengine.InferenceSession`` for one axmodel.

    Loads the model once and exposes its input/output names plus a ``run``
    method that always returns a ``{output_name: array}`` dict, whether the
    underlying session returns a dict, a list/tuple, or a single array.
    """

    def __init__(self, model_path: str | Path):
        """Load the axmodel at *model_path*.

        Raises:
            RuntimeError: if ``axengine`` could not be imported.
            FileNotFoundError: if *model_path* does not exist.
        """
        if not HAS_AXENGINE:
            raise RuntimeError("axengine not available; run this script on the AX650 board")
        self.path = Path(model_path)
        if not self.path.exists():
            raise FileNotFoundError(f"Model not found: {self.path}")
        # Session creation runs native code; keep anything it writes to
        # fd 1/2 off the console.
        with _suppress_native_stdio():
            self._session = axengine.InferenceSession(str(self.path))
        self._inputs = self._session.get_inputs()
        self._outputs = self._session.get_outputs()

    @property
    def input_names(self) -> List[str]:
        """Names of the model inputs, in the order reported by the session."""
        return [item.name for item in self._inputs]

    @property
    def output_names(self) -> List[str]:
        """Names of the model outputs, in the order reported by the session."""
        return [item.name for item in self._outputs]

    def run(self, feed_dict: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Run inference and return the outputs keyed by output name.

        Args:
            feed_dict: mapping of input name to array, passed to the
                session's ``run`` unchanged.

        Returns:
            Dict mapping each output name to its array. A dict returned by
            the session is passed through as-is.

        Raises:
            RuntimeError: if the session returns a list/tuple whose length
                differs from the number of declared outputs, or a non-sequence
                result while the model declares more than one output.
        """
        outputs = self._session.run(None, feed_dict)
        if isinstance(outputs, dict):
            return outputs
        names = self.output_names
        if isinstance(outputs, (list, tuple)):
            # zip() would silently drop the surplus outputs or names on a
            # count mismatch, handing callers a partial result; fail loudly.
            if len(outputs) != len(names):
                raise RuntimeError(
                    f"AXEngine returned {len(outputs)} outputs from {self.path.name}, "
                    f"expected {len(names)}: {names}"
                )
            return dict(zip(names, outputs))
        if len(names) == 1:
            # A single bare result for a single-output model.
            return {names[0]: outputs}
        raise RuntimeError(f"Unexpected AXEngine output type from {self.path.name}: {type(outputs)}")