"""Custom inference handler for Hugging Face Inference Endpoints.
This module exposes :class:`EndpointHandler`, the entrypoint used by the
Hugging Face serving stack when ``--task custom`` is selected. The handler
loads the exported Noesis decoder ONNX graph and accepts symbolic intent
vectors (``psi``) along with an optional ``slow_state`` memory tensor. The
outputs mirror the values produced by the training runtime:
* ``z_out`` – semantic embedding projected back into symbolic space.
* ``choice``, ``pain``, ``memory`` and ``quality`` – diagnostic scalars.
* ``slow_state`` – updated slow memory tensor suitable for recurrent usage.
The handler is intentionally lightweight so it can run without the rest of the
AletheiaEngine Python package being installed.
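
Example request body (a minimal sketch; ``psi`` may also be a plain text
prompt, which :class:`_TextEncoder` hashes into a vector)::

    {"inputs": {"psi": [0.1, -0.3, 0.7]}}

The response is a JSON object keyed by the ONNX output names listed above,
e.g. ``z_out``, ``choice`` and ``slow_state``.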
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import re
from typing import Any, Mapping, MutableMapping, Optional
import numpy as np
import onnxruntime as ort
_WORD_RE = re.compile(r"\w+", re.UNICODE)
class _TextEncoder:
"""Deterministic text → vector encoder.
The Hugging Face Inference Endpoints stack frequently passes user prompts
as strings via the ``inputs`` field. The Noesis decoder, however, expects a
symbolic vector (``psi``) as input. To provide a graceful fallback, the
handler converts text prompts on demand into a stable float32 vector by
hashing each token onto a hypersphere and mean-pooling. This mirrors the lightweight
``TextEncoder256`` implementation bundled with the full AletheiaEngine
package while avoiding a heavy import dependency inside the endpoint
container.
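
Example (illustrative only; the same prompt yields the same vector on every call)::

    enc = _TextEncoder(dim=256)
    vec = enc.encode("hello world")
    # vec.shape == (1, 256) and the row has unit L2 norm.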
"""
def __init__(self, dim: int) -> None:
self.dim = dim
@staticmethod
def _tokens(text: str) -> list[str]:
return [tok.lower() for tok in _WORD_RE.findall(text)]
@staticmethod
def _seed(tok: str) -> int:
# FNV-1a hash for determinism across processes/platforms.
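# (Python's built-in hash() is salted per process via PYTHONHASHSEED, so it is not suitable here.)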
value = 2166136261
for byte in tok.encode("utf-8"):
value ^= byte
value = (value * 16777619) & 0xFFFFFFFF
return int(value)
def encode(self, text: str) -> np.ndarray:
tokens = self._tokens(text)
if not tokens:
return np.zeros((1, self.dim), dtype=np.float32)
vecs = []
for tok in tokens:
rs = np.random.RandomState(self._seed(tok))
embedding = rs.normal(0.0, 1.0, size=(self.dim,)).astype(np.float32)
norm = float(np.linalg.norm(embedding)) or 1.0
vecs.append(embedding / norm)
stacked = np.stack(vecs, axis=0)
pooled = stacked.mean(axis=0, dtype=np.float32, keepdims=True)
pooled_norm = float(np.linalg.norm(pooled)) or 1.0
return pooled / pooled_norm
@dataclass(frozen=True)
class _ModelIO:
"""Snapshot of ONNX input and output metadata."""
inputs: tuple[ort.NodeArg, ...]
outputs: tuple[ort.NodeArg, ...]
class EndpointHandler:
"""Callable endpoint used by Hugging Face to drive inference."""
def __init__(self, path: str | None = None) -> None:
self.model_dir = Path(path or Path(__file__).parent)
self.session = self._load_session()
self.io = self._capture_io()
self.primary_input = self.io.inputs[0].name
self.slow_input = self._find_input("slow_state")
self._primary_dim = self._infer_primary_dim()
self._text_encoder = _TextEncoder(self._primary_dim)
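# Zero-filled defaults for any auxiliary graph inputs, so callers only need to supply psi (and optionally slow_state).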
self._defaults = {
node.name: self._zeros_like(node)
for node in self.io.inputs
if node.name not in {self.primary_input, self.slow_input}
}
if self.slow_input is not None:
self._slow_fallback = self._zeros_like(self._input_map[self.slow_input])
else:
self._slow_fallback = None
def _load_session(self) -> ort.InferenceSession:
"""Load the ONNX session, tolerating alternate filenames."""
preferred_names = ("model.onnx", "model_infer.onnx")
for name in preferred_names:
candidate = self.model_dir / name
if candidate.exists():
return ort.InferenceSession(str(candidate), providers=["CPUExecutionProvider"])
available = sorted(str(p.name) for p in self.model_dir.glob("*.onnx"))
if len(available) == 1:
# Fall back to the lone ONNX artefact if it has a non-standard name.
return ort.InferenceSession(str(self.model_dir / available[0]), providers=["CPUExecutionProvider"])
choices = ", ".join(available) or "<none>"
raise FileNotFoundError(
"Could not locate any of %s in %s (available: %s)"
% (", ".join(preferred_names), self.model_dir, choices)
)
@property
def _input_map(self) -> Mapping[str, ort.NodeArg]:
return {node.name: node for node in self.io.inputs}
def _capture_io(self) -> _ModelIO:
return _ModelIO(inputs=tuple(self.session.get_inputs()), outputs=tuple(self.session.get_outputs()))
def _find_input(self, target: str) -> Optional[str]:
target = target.lower()
for node in self.io.inputs:
if node.name.lower() == target:
return node.name
return None
def _infer_primary_dim(self) -> int:
node = self._input_map[self.primary_input]
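# Walk the shape from the trailing axis so dynamic dimensions (None or symbolic names) are skipped.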
for dim in reversed(node.shape):
if isinstance(dim, int) and dim > 0:
return dim
# Conservative default matching TextEncoder256.
return 256
@staticmethod
def _zeros_like(node: ort.NodeArg) -> np.ndarray:
shape: list[int] = []
for dim in node.shape:
if isinstance(dim, int) and dim > 0:
shape.append(dim)
else:
shape.append(1)
return np.zeros(shape, dtype=np.float32)
@staticmethod
def _coerce_array(value: Any, *, allow_empty: bool = False) -> np.ndarray:
array = np.asarray(value, dtype=np.float32)
if array.size == 0 and not allow_empty:
raise ValueError("Received an empty array; provide at least one value.")
if array.ndim == 1:
array = np.expand_dims(array, axis=0)
elif array.ndim > 2:
raise ValueError("Expected a 1D or batched 2D array; received shape %s" % (array.shape,))
return array
def _prepare_inputs(self, payload: Mapping[str, Any]) -> MutableMapping[str, np.ndarray]:
psi = payload.get("psi")
if psi is None:
psi = (
payload.get("vector")
or payload.get("psi_s")
or payload.get("inputs")
or payload.get("prompt")
or payload.get("text")
)
if psi is None:
raise KeyError("Payload must include a 'psi' field containing the symbolic vector.")
inputs: MutableMapping[str, np.ndarray] = {
self.primary_input: self._vector_from_payload(psi)
}
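# Recurrent memory: accept a few field aliases, otherwise fall back to a zero tensor.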
if self.slow_input is not None:
slow_value = payload.get("slow_state") or payload.get("slow") or payload.get("state")
if slow_value is None:
inputs[self.slow_input] = self._slow_fallback.copy()
else:
inputs[self.slow_input] = self._coerce_array(slow_value, allow_empty=True)
for name, default in self._defaults.items():
inputs[name] = default.copy()
return inputs
def _vector_from_payload(self, value: Any) -> np.ndarray:
if isinstance(value, str):
return self._text_encoder.encode(value)
if isinstance(value, (list, tuple)) and value and all(isinstance(v, str) for v in value):
return self._text_encoder.encode(" ".join(value))
return self._coerce_array(value)
@staticmethod
def _format_output(name: str, value: np.ndarray) -> Any:
value = np.asarray(value, dtype=np.float32)
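# Replace NaN/Inf with zeros so the response remains JSON-serialisable.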
value = np.nan_to_num(value, nan=0.0, posinf=0.0, neginf=0.0)
squeezed = np.squeeze(value)
if squeezed.ndim == 0:
return float(squeezed)
return squeezed.tolist()
def __call__(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
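# Inference Endpoints wrap the request body under "inputs"; accept both the wrapped and the bare payload.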
payload = data.get("inputs", data)
if not isinstance(payload, Mapping):
payload = {"psi": payload}
feed = self._prepare_inputs(payload)
outputs = self.session.run(None, feed)
result = {
node.name: self._format_output(node.name, value)
for node, value in zip(self.io.outputs, outputs)
}
return result
__all__ = ["EndpointHandler"]
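# Minimal local smoke test (a sketch, not part of the serving contract): it
# assumes an ONNX model file sits next to this script and feeds a plain text
# prompt, which the handler hashes into a psi vector via _TextEncoder.
if __name__ == "__main__":
    handler = EndpointHandler()
    response = handler({"inputs": "hello noesis"})
    print(sorted(response))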