"""Custom inference handler for Hugging Face Inference Endpoints. This module exposes :class:`EndpointHandler`, the entrypoint used by the Hugging Face serving stack when ``--task custom`` is selected. The handler loads the exported Noesis decoder ONNX graph and accepts symbolic intent vectors (``psi``) along with an optional ``slow_state`` memory tensor. The outputs mirror the values produced by the training runtime: * ``z_out`` – semantic embedding projected back into symbolic space. * ``choice``, ``pain``, ``memory`` and ``quality`` – diagnostic scalars. * ``slow_state`` – updated slow memory tensor suitable for recurrent usage. The handler is intentionally lightweight so it can run without the rest of the AletheiaEngine Python package being installed. """ from __future__ import annotations from dataclasses import dataclass from pathlib import Path import re from typing import Any, Mapping, MutableMapping, Optional import numpy as np import onnxruntime as ort _WORD_RE = re.compile(r"\w+", re.UNICODE) class _TextEncoder: """Deterministic text → vector encoder. The Hugging Face Inference Endpoints frequently pass user prompts as strings via the ``inputs`` field. The Noesis decoder, however, expects a symbolic vector (``psi``) as input. To provide a graceful fallback the handler lazily converts short text prompts into a stable float32 vector by hashing tokens onto a hypersphere. This mirrors the lightweight ``TextEncoder256`` implementation bundled with the full AletheiaEngine package while avoiding a heavy import dependency inside the endpoint container. """ def __init__(self, dim: int) -> None: self.dim = dim @staticmethod def _tokens(text: str) -> list[str]: return [tok.lower() for tok in _WORD_RE.findall(text)] @staticmethod def _seed(tok: str) -> int: # FNV-1a hash for determinism across processes/platforms. value = 2166136261 for byte in tok.encode("utf-8"): value ^= byte value = (value * 16777619) & 0xFFFFFFFF return int(value) def encode(self, text: str) -> np.ndarray: tokens = self._tokens(text) if not tokens: return np.zeros((1, self.dim), dtype=np.float32) vecs = [] for tok in tokens: rs = np.random.RandomState(self._seed(tok)) embedding = rs.normal(0.0, 1.0, size=(self.dim,)).astype(np.float32) norm = float(np.linalg.norm(embedding)) or 1.0 vecs.append(embedding / norm) stacked = np.stack(vecs, axis=0) pooled = stacked.mean(axis=0, dtype=np.float32, keepdims=True) pooled_norm = float(np.linalg.norm(pooled)) or 1.0 return pooled / pooled_norm @dataclass(frozen=True) class _ModelIO: """Snapshot of ONNX input and output metadata.""" inputs: tuple[ort.NodeArg, ...] outputs: tuple[ort.NodeArg, ...] 


class EndpointHandler:
    """Callable endpoint used by Hugging Face to drive inference."""

    def __init__(self, path: str | None = None) -> None:
        self.model_dir = Path(path or Path(__file__).parent)
        self.session = self._load_session()
        self.io = self._capture_io()
        self.primary_input = self.io.inputs[0].name
        self.slow_input = self._find_input("slow_state")
        self._primary_dim = self._infer_primary_dim()
        self._text_encoder = _TextEncoder(self._primary_dim)
        self._defaults = {
            node.name: self._zeros_like(node)
            for node in self.io.inputs
            if node.name not in {self.primary_input, self.slow_input}
        }
        if self.slow_input is not None:
            self._slow_fallback = self._zeros_like(self._input_map[self.slow_input])
        else:
            self._slow_fallback = None

    def _load_session(self) -> ort.InferenceSession:
        """Load the ONNX session, tolerating alternate filenames."""
        preferred_names = ("model.onnx", "model_infer.onnx")
        for name in preferred_names:
            candidate = self.model_dir / name
            if candidate.exists():
                return ort.InferenceSession(str(candidate), providers=["CPUExecutionProvider"])
        available = sorted(str(p.name) for p in self.model_dir.glob("*.onnx"))
        if len(available) == 1:
            # Fall back to the lone ONNX artefact if it has a non-standard name.
            return ort.InferenceSession(
                str(self.model_dir / available[0]), providers=["CPUExecutionProvider"]
            )
        choices = ", ".join(available) or "none"
        raise FileNotFoundError(
            "Could not locate any of %s in %s (available: %s)"
            % (", ".join(preferred_names), self.model_dir, choices)
        )

    @property
    def _input_map(self) -> Mapping[str, ort.NodeArg]:
        return {node.name: node for node in self.io.inputs}

    def _capture_io(self) -> _ModelIO:
        return _ModelIO(
            inputs=tuple(self.session.get_inputs()),
            outputs=tuple(self.session.get_outputs()),
        )

    def _find_input(self, target: str) -> Optional[str]:
        target = target.lower()
        for node in self.io.inputs:
            if node.name.lower() == target:
                return node.name
        return None

    def _infer_primary_dim(self) -> int:
        node = self._input_map[self.primary_input]
        for dim in reversed(node.shape):
            if isinstance(dim, int) and dim > 0:
                return dim
        # Conservative default matching TextEncoder256.
        return 256

    @staticmethod
    def _zeros_like(node: ort.NodeArg) -> np.ndarray:
        shape: list[int] = []
        for dim in node.shape:
            if isinstance(dim, int) and dim > 0:
                shape.append(dim)
            else:
                shape.append(1)
        return np.zeros(shape, dtype=np.float32)

    @staticmethod
    def _coerce_array(value: Any, *, allow_empty: bool = False) -> np.ndarray:
        array = np.asarray(value, dtype=np.float32)
        if array.size == 0 and not allow_empty:
            raise ValueError("Received an empty array; provide at least one value.")
        if array.ndim == 1:
            array = np.expand_dims(array, axis=0)
        elif array.ndim > 2:
            raise ValueError(
                "Expected a 1D or batched 2D array; received shape %s" % (array.shape,)
            )
        return array

    def _prepare_inputs(self, payload: Mapping[str, Any]) -> MutableMapping[str, np.ndarray]:
        psi = payload.get("psi")
        if psi is None:
            psi = (
                payload.get("vector")
                or payload.get("psi_s")
                or payload.get("inputs")
                or payload.get("prompt")
                or payload.get("text")
            )
        if psi is None:
            raise KeyError("Payload must include a 'psi' field containing the symbolic vector.")
        inputs: MutableMapping[str, np.ndarray] = {
            self.primary_input: self._vector_from_payload(psi)
        }
        if self.slow_input is not None:
            slow_value = payload.get("slow_state") or payload.get("slow") or payload.get("state")
            if slow_value is None:
                inputs[self.slow_input] = self._slow_fallback.copy()
            else:
                inputs[self.slow_input] = self._coerce_array(slow_value, allow_empty=True)
        for name, default in self._defaults.items():
            inputs[name] = default.copy()
        return inputs

    def _vector_from_payload(self, value: Any) -> np.ndarray:
        if isinstance(value, str):
            return self._text_encoder.encode(value)
        if isinstance(value, (list, tuple)) and value and all(isinstance(v, str) for v in value):
            return self._text_encoder.encode(" ".join(value))
        return self._coerce_array(value)

    @staticmethod
    def _format_output(name: str, value: np.ndarray) -> Any:
        value = np.asarray(value, dtype=np.float32)
        value = np.nan_to_num(value, nan=0.0, posinf=0.0, neginf=0.0)
        squeezed = np.squeeze(value)
        if squeezed.ndim == 0:
            return float(squeezed)
        return squeezed.tolist()

    def __call__(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
        payload = data.get("inputs", data)
        if not isinstance(payload, Mapping):
            payload = {"psi": payload}
        feed = self._prepare_inputs(payload)
        outputs = self.session.run(None, feed)
        result = {
            node.name: self._format_output(node.name, value)
            for node, value in zip(self.io.outputs, outputs)
        }
        return result


__all__ = ["EndpointHandler"]
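

if __name__ == "__main__":  # pragma: no cover
    # Minimal local smoke test, not part of the serving contract. It assumes an
    # exported ``model.onnx`` (or a single *.onnx artefact) sits next to this
    # file; the payloads below mirror the two shapes the endpoint accepts.
    handler = EndpointHandler()

    # A plain text prompt is routed through the hashing fallback encoder.
    print(handler({"inputs": "describe the current intent"}))

    # An explicit symbolic vector bypasses the text encoder entirely; the slow
    # memory tensor returned by one call can be threaded into the next for
    # recurrent usage.
    psi = [0.0] * handler._text_encoder.dim
    first = handler({"inputs": {"psi": psi}})
    second = handler({"inputs": {"psi": psi, "slow_state": first.get("slow_state")}})
    print(second)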