"""Custom inference handler for Hugging Face Inference Endpoints. This module exposes :class:`EndpointHandler`, the entrypoint used by the Hugging Face serving stack when ``--task custom`` is selected. The handler loads the exported Noesis decoder ONNX graph and accepts symbolic intent vectors (``psi``) along with an optional ``slow_state`` memory tensor. The outputs mirror the values produced by the training runtime: * ``z_out`` – semantic embedding projected back into symbolic space. * ``choice``, ``pain``, ``memory`` and ``quality`` – diagnostic scalars. * ``slow_state`` – updated slow memory tensor suitable for recurrent usage. The handler is intentionally lightweight so it can run without the rest of the AletheiaEngine Python package being installed. """ from __future__ import annotations from dataclasses import dataclass from pathlib import Path from typing import Any, Mapping, MutableMapping, Optional import numpy as np import onnxruntime as ort @dataclass(frozen=True) class _ModelIO: """Snapshot of ONNX input and output metadata.""" inputs: tuple[ort.NodeArg, ...] outputs: tuple[ort.NodeArg, ...] 
class EndpointHandler:
    """Callable endpoint used by Hugging Face to drive inference.

    Constructed once per replica with the directory containing
    ``model.onnx``; each call receives a JSON-decoded payload and returns
    one JSON-serialisable entry per ONNX graph output.
    """

    def __init__(self, path: str | None = None) -> None:
        """Load the ONNX session and precompute zero defaults for inputs.

        Parameters
        ----------
        path:
            Directory containing ``model.onnx``.  Defaults to this file's
            directory, where the endpoint repo files are placed.
        """
        self.model_dir = Path(path or Path(__file__).parent)
        self.session = self._load_session()
        self.io = self._capture_io()
        # NOTE(review): the first graph input is assumed to be the symbolic
        # vector ``psi`` — confirm against the export script.
        self.primary_input = self.io.inputs[0].name
        self.slow_input = self._find_input("slow_state")
        # Zero tensors for auxiliary inputs the caller never supplies.
        self._defaults = {
            node.name: self._zeros_like(node)
            for node in self.io.inputs
            if node.name not in {self.primary_input, self.slow_input}
        }
        if self.slow_input is not None:
            self._slow_fallback = self._zeros_like(self._input_map[self.slow_input])
        else:
            self._slow_fallback = None

    def _load_session(self) -> ort.InferenceSession:
        """Create a CPU-only inference session for ``model.onnx``.

        Raises
        ------
        FileNotFoundError
            If ``model.onnx`` is missing; the message lists any other
            ``*.onnx`` files found so a misnamed export is easy to spot.
        """
        model_path = self.model_dir / "model.onnx"
        if not model_path.exists():
            # Fixed: previously fell back to "" which rendered as
            # "(available: )"; "none" makes the message unambiguous.
            available = ", ".join(sorted(p.name for p in self.model_dir.glob("*.onnx"))) or "none"
            raise FileNotFoundError(
                "Could not locate 'model.onnx' in %s (available: %s)" % (self.model_dir, available)
            )
        return ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"])

    @property
    def _input_map(self) -> Mapping[str, ort.NodeArg]:
        """Name -> node mapping over the graph inputs (rebuilt per access)."""
        return {node.name: node for node in self.io.inputs}

    def _capture_io(self) -> _ModelIO:
        """Snapshot the session's input/output metadata once at startup."""
        return _ModelIO(
            inputs=tuple(self.session.get_inputs()),
            outputs=tuple(self.session.get_outputs()),
        )

    def _find_input(self, target: str) -> Optional[str]:
        """Return the graph input whose name equals *target* case-insensitively."""
        target = target.lower()
        for node in self.io.inputs:
            if node.name.lower() == target:
                return node.name
        return None

    @staticmethod
    def _zeros_like(node: ort.NodeArg) -> np.ndarray:
        """Build a float32 zero tensor for *node*, using 1 for dynamic dims."""
        # Symbolic/dynamic dimensions come through as strings or None.
        shape = [dim if isinstance(dim, int) and dim > 0 else 1 for dim in node.shape]
        return np.zeros(shape, dtype=np.float32)

    @staticmethod
    def _coerce_array(value: Any, *, allow_empty: bool = False) -> np.ndarray:
        """Convert *value* to a batched float32 array.

        1-D input gains a leading batch axis; 2-D input passes through.

        Raises
        ------
        ValueError
            If the array is empty (unless *allow_empty*) or has rank > 2.
        """
        array = np.asarray(value, dtype=np.float32)
        if array.size == 0 and not allow_empty:
            raise ValueError("Received an empty array; provide at least one value.")
        if array.ndim == 1:
            array = np.expand_dims(array, axis=0)
        elif array.ndim > 2:
            raise ValueError("Expected a 1D or batched 2D array; received shape %s" % (array.shape,))
        return array

    @staticmethod
    def _first_present(payload: Mapping[str, Any], keys: tuple[str, ...]) -> Any:
        """Return the first value among *keys* present in *payload* and not None.

        Fixed: the previous ``payload.get(a) or payload.get(b)`` chains
        skipped falsy-but-valid values (e.g. ``0``) and raised
        "truth value is ambiguous" for numpy-array values; membership is
        now tested with ``is not None`` instead of truthiness.
        """
        for key in keys:
            value = payload.get(key)
            if value is not None:
                return value
        return None

    def _prepare_inputs(self, payload: Mapping[str, Any]) -> MutableMapping[str, np.ndarray]:
        """Assemble the full ONNX feed dict from a request *payload*.

        ``psi`` (or an alias: ``vector``/``psi_s``/``inputs``) is mandatory;
        a missing or empty slow state falls back to the zero tensor; every
        remaining graph input receives its precomputed zero default.

        Raises
        ------
        KeyError
            If no ``psi`` field (or alias) is present.
        """
        psi = self._first_present(payload, ("psi", "vector", "psi_s", "inputs"))
        if psi is None:
            raise KeyError("Payload must include a 'psi' field containing the symbolic vector.")
        inputs: MutableMapping[str, np.ndarray] = {self.primary_input: self._coerce_array(psi)}
        if self.slow_input is not None:
            slow_value = self._first_present(payload, ("slow_state", "slow", "state"))
            if slow_value is None:
                inputs[self.slow_input] = self._slow_fallback.copy()
            else:
                coerced = self._coerce_array(slow_value, allow_empty=True)
                # An explicitly empty state still means "use the zero state",
                # matching the old behaviour for empty lists.
                inputs[self.slow_input] = coerced if coerced.size else self._slow_fallback.copy()
        for name, default in self._defaults.items():
            # Copy so a recurrent caller cannot mutate the cached defaults.
            inputs[name] = default.copy()
        return inputs

    @staticmethod
    def _format_output(name: str, value: np.ndarray) -> Any:
        """Convert one raw ONNX output into a JSON-serialisable value.

        Non-finite values are zeroed so the response is always valid JSON.
        *name* is currently unused but kept so callers can continue passing
        ``(node.name, value)`` pairs.
        """
        value = np.asarray(value, dtype=np.float32)
        value = np.nan_to_num(value, nan=0.0, posinf=0.0, neginf=0.0)
        squeezed = np.squeeze(value)
        if squeezed.ndim == 0:
            # Scalars become plain floats instead of 0-d lists.
            return float(squeezed)
        return squeezed.tolist()

    def __call__(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
        """Run one inference request and return named, serialisable outputs."""
        payload = data.get("inputs", data)
        if not isinstance(payload, Mapping):
            # A bare list/array is treated as the psi vector itself.
            payload = {"psi": payload}
        feed = self._prepare_inputs(payload)
        outputs = self.session.run(None, feed)
        return {
            node.name: self._format_output(node.name, value)
            for node, value in zip(self.io.outputs, outputs)
        }


__all__ = ["EndpointHandler"]