"""Custom inference handler for Hugging Face Inference Endpoints.



This module exposes :class:`EndpointHandler`, the entrypoint used by the

Hugging Face serving stack when ``--task custom`` is selected.  The handler

loads the exported Noesis decoder ONNX graph and accepts symbolic intent

vectors (``psi``) along with an optional ``slow_state`` memory tensor.  The

outputs mirror the values produced by the training runtime:



* ``z_out`` – semantic embedding projected back into symbolic space.

* ``choice``, ``pain``, ``memory`` and ``quality`` – diagnostic scalars.

* ``slow_state`` – updated slow memory tensor suitable for recurrent usage.



The handler is intentionally lightweight so it can run without the rest of the

AletheiaEngine Python package being installed.

"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Mapping, MutableMapping, Optional

import numpy as np
import onnxruntime as ort


@dataclass(frozen=True)
class _ModelIO:
    """Snapshot of ONNX input and output metadata."""

    inputs: tuple[ort.NodeArg, ...]
    outputs: tuple[ort.NodeArg, ...]


class EndpointHandler:
    """Callable endpoint used by Hugging Face to drive inference."""

    def __init__(self, path: str | None = None) -> None:
        self.model_dir = Path(path or Path(__file__).parent)
        self.session = self._load_session()
        self.io = self._capture_io()

        self.primary_input = self.io.inputs[0].name
        self.slow_input = self._find_input("slow_state")
        self._defaults = {
            node.name: self._zeros_like(node)
            for node in self.io.inputs
            if node.name not in {self.primary_input, self.slow_input}
        }
        if self.slow_input is not None:
            self._slow_fallback = self._zeros_like(self._input_map[self.slow_input])
        else:
            self._slow_fallback = None

    def _load_session(self) -> ort.InferenceSession:
        model_path = self.model_dir / "model.onnx"
        if not model_path.exists():
            available = ", ".join(sorted(str(p.name) for p in self.model_dir.glob("*.onnx"))) or "<none>"
            raise FileNotFoundError(
                "Could not locate 'model.onnx' in %s (available: %s)" % (self.model_dir, available)
            )
        return ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"])

    @property
    def _input_map(self) -> Mapping[str, ort.NodeArg]:
        return {node.name: node for node in self.io.inputs}

    def _capture_io(self) -> _ModelIO:
        return _ModelIO(inputs=tuple(self.session.get_inputs()), outputs=tuple(self.session.get_outputs()))

    def _find_input(self, target: str) -> Optional[str]:
        target = target.lower()
        for node in self.io.inputs:
            if node.name.lower() == target:
                return node.name
        return None

    @staticmethod
    def _zeros_like(node: ort.NodeArg) -> np.ndarray:
        shape: list[int] = []
        for dim in node.shape:
            if isinstance(dim, int) and dim > 0:
                shape.append(dim)
            else:
                shape.append(1)
        return np.zeros(shape, dtype=np.float32)

    @staticmethod
    def _coerce_array(value: Any, *, allow_empty: bool = False) -> np.ndarray:
        array = np.asarray(value, dtype=np.float32)
        if array.size == 0 and not allow_empty:
            raise ValueError("Received an empty array; provide at least one value.")
        if array.ndim == 1:
            array = np.expand_dims(array, axis=0)
        elif array.ndim > 2:
            raise ValueError("Expected a 1D or batched 2D array; received shape %s" % (array.shape,))
        return array

    def _prepare_inputs(self, payload: Mapping[str, Any]) -> MutableMapping[str, np.ndarray]:
        psi = payload.get("psi")
        if psi is None:
            psi = payload.get("vector") or payload.get("psi_s") or payload.get("inputs")
        if psi is None:
            raise KeyError("Payload must include a 'psi' field containing the symbolic vector.")

        inputs: MutableMapping[str, np.ndarray] = {self.primary_input: self._coerce_array(psi)}

        if self.slow_input is not None:
            slow_value = payload.get("slow_state") or payload.get("slow") or payload.get("state")
            if slow_value is None:
                inputs[self.slow_input] = self._slow_fallback.copy()
            else:
                inputs[self.slow_input] = self._coerce_array(slow_value, allow_empty=True)

        for name, default in self._defaults.items():
            inputs[name] = default.copy()

        return inputs

    @staticmethod
    def _format_output(name: str, value: np.ndarray) -> Any:
        value = np.asarray(value, dtype=np.float32)
        value = np.nan_to_num(value, nan=0.0, posinf=0.0, neginf=0.0)
        squeezed = np.squeeze(value)
        if squeezed.ndim == 0:
            return float(squeezed)
        return squeezed.tolist()

    def __call__(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
        payload = data.get("inputs", data)
        if not isinstance(payload, Mapping):
            payload = {"psi": payload}

        feed = self._prepare_inputs(payload)
        outputs = self.session.run(None, feed)

        result = {
            node.name: self._format_output(node.name, value)
            for node, value in zip(self.io.outputs, outputs)
        }
        return result


__all__ = ["EndpointHandler"]