File size: 8,916 Bytes
2a9a820 b1fcbcb 2a9a820 b1fcbcb 2a9a820 b1fcbcb 2a9a820 b1fcbcb 2a9a820 b1fcbcb 2a9a820 b1fcbcb 2a9a820 b1fcbcb 2a9a820 b1fcbcb 2a9a820 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 |
"""Custom inference handler for Hugging Face Inference Endpoints.
This module exposes :class:`EndpointHandler`, the entrypoint used by the
Hugging Face serving stack when ``--task custom`` is selected. The handler
loads the exported Noesis decoder ONNX graph and accepts symbolic intent
vectors (``psi``) along with an optional ``slow_state`` memory tensor. The
outputs mirror the values produced by the training runtime:
* ``z_out`` – semantic embedding projected back into symbolic space.
* ``choice``, ``pain``, ``memory`` and ``quality`` – diagnostic scalars.
* ``slow_state`` – updated slow memory tensor suitable for recurrent usage.
The handler is intentionally lightweight so it can run without the rest of the
AletheiaEngine Python package being installed.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import re
from typing import Any, Mapping, MutableMapping, Optional
import numpy as np
import onnxruntime as ort
# Unicode-aware word tokenizer shared by the text-prompt fallback encoder.
_WORD_RE = re.compile(r"\w+", re.UNICODE)
class _TextEncoder:
"""Deterministic text → vector encoder.
The Hugging Face Inference Endpoints frequently pass user prompts as
strings via the ``inputs`` field. The Noesis decoder, however, expects a
symbolic vector (``psi``) as input. To provide a graceful fallback the
handler lazily converts short text prompts into a stable float32 vector by
hashing tokens onto a hypersphere. This mirrors the lightweight
``TextEncoder256`` implementation bundled with the full AletheiaEngine
package while avoiding a heavy import dependency inside the endpoint
container.
"""
def __init__(self, dim: int) -> None:
self.dim = dim
@staticmethod
def _tokens(text: str) -> list[str]:
return [tok.lower() for tok in _WORD_RE.findall(text)]
@staticmethod
def _seed(tok: str) -> int:
# FNV-1a hash for determinism across processes/platforms.
value = 2166136261
for byte in tok.encode("utf-8"):
value ^= byte
value = (value * 16777619) & 0xFFFFFFFF
return int(value)
def encode(self, text: str) -> np.ndarray:
tokens = self._tokens(text)
if not tokens:
return np.zeros((1, self.dim), dtype=np.float32)
vecs = []
for tok in tokens:
rs = np.random.RandomState(self._seed(tok))
embedding = rs.normal(0.0, 1.0, size=(self.dim,)).astype(np.float32)
norm = float(np.linalg.norm(embedding)) or 1.0
vecs.append(embedding / norm)
stacked = np.stack(vecs, axis=0)
pooled = stacked.mean(axis=0, dtype=np.float32, keepdims=True)
pooled_norm = float(np.linalg.norm(pooled)) or 1.0
return pooled / pooled_norm
@dataclass(frozen=True)
class _ModelIO:
"""Snapshot of ONNX input and output metadata."""
inputs: tuple[ort.NodeArg, ...]
outputs: tuple[ort.NodeArg, ...]
class EndpointHandler:
    """Callable endpoint used by Hugging Face to drive inference.

    Maps a JSON payload onto the ONNX graph's inputs:

    * the primary (first) graph input receives the symbolic ``psi`` vector,
      or a vector derived from a text prompt via the module's text encoder;
    * a graph input named ``slow_state`` (if present) receives the recurrent
      memory tensor, defaulting to zeros when the caller omits it;
    * every other graph input is fed a zero tensor of compatible shape.
    """

    # Accepted payload aliases, checked in order of preference.
    _PSI_KEYS = ("psi", "vector", "psi_s", "inputs", "prompt", "text")
    _SLOW_KEYS = ("slow_state", "slow", "state")

    def __init__(self, path: str | None = None) -> None:
        """Load the ONNX session from *path* (defaults to this file's directory)."""
        self.model_dir = Path(path or Path(__file__).parent)
        self.session = self._load_session()
        self.io = self._capture_io()
        self.primary_input = self.io.inputs[0].name
        self.slow_input = self._find_input("slow_state")
        self._primary_dim = self._infer_primary_dim()
        self._text_encoder = _TextEncoder(self._primary_dim)
        # Pre-built zero tensors for auxiliary inputs the payload never names.
        self._defaults = {
            node.name: self._zeros_like(node)
            for node in self.io.inputs
            if node.name not in {self.primary_input, self.slow_input}
        }
        if self.slow_input is not None:
            self._slow_fallback = self._zeros_like(self._input_map[self.slow_input])
        else:
            self._slow_fallback = None

    def _load_session(self) -> ort.InferenceSession:
        """Load the ONNX session, tolerating alternate filenames.

        Raises:
            FileNotFoundError: when no preferred filename exists and the
                directory does not contain exactly one ``*.onnx`` artefact.
        """
        preferred_names = ("model.onnx", "model_infer.onnx")
        for name in preferred_names:
            candidate = self.model_dir / name
            if candidate.exists():
                return ort.InferenceSession(str(candidate), providers=["CPUExecutionProvider"])
        available = sorted(str(p.name) for p in self.model_dir.glob("*.onnx"))
        if len(available) == 1:
            # Fall back to the lone ONNX artefact if it has a non-standard name.
            return ort.InferenceSession(str(self.model_dir / available[0]), providers=["CPUExecutionProvider"])
        choices = ", ".join(available) or "<none>"
        raise FileNotFoundError(
            "Could not locate any of %s in %s (available: %s)"
            % (", ".join(preferred_names), self.model_dir, choices)
        )

    @property
    def _input_map(self) -> Mapping[str, ort.NodeArg]:
        """Name → metadata mapping for the graph inputs (rebuilt per access)."""
        return {node.name: node for node in self.io.inputs}

    def _capture_io(self) -> _ModelIO:
        """Snapshot the session's input/output metadata once, at start-up."""
        return _ModelIO(inputs=tuple(self.session.get_inputs()), outputs=tuple(self.session.get_outputs()))

    def _find_input(self, target: str) -> Optional[str]:
        """Return the graph input whose name matches *target* case-insensitively."""
        target = target.lower()
        for node in self.io.inputs:
            if node.name.lower() == target:
                return node.name
        return None

    def _infer_primary_dim(self) -> int:
        """Infer the feature dimension of the primary input.

        Scans the shape from the trailing axis backwards and returns the first
        concrete (positive integer) dimension, skipping symbolic axes such as
        a dynamic batch dimension.
        """
        node = self._input_map[self.primary_input]
        for dim in reversed(node.shape):
            if isinstance(dim, int) and dim > 0:
                return dim
        # Conservative default matching TextEncoder256.
        return 256

    @staticmethod
    def _zeros_like(node: ort.NodeArg) -> np.ndarray:
        """Build a float32 zero tensor for *node*, substituting 1 for symbolic axes."""
        shape: list[int] = []
        for dim in node.shape:
            if isinstance(dim, int) and dim > 0:
                shape.append(dim)
            else:
                shape.append(1)
        return np.zeros(shape, dtype=np.float32)

    @staticmethod
    def _coerce_array(value: Any, *, allow_empty: bool = False) -> np.ndarray:
        """Convert *value* to a batched float32 array.

        1D input is promoted to a single-row batch; 2D input passes through.

        Raises:
            ValueError: on empty input (unless *allow_empty*) or rank > 2.
        """
        array = np.asarray(value, dtype=np.float32)
        if array.size == 0 and not allow_empty:
            raise ValueError("Received an empty array; provide at least one value.")
        if array.ndim == 1:
            array = np.expand_dims(array, axis=0)
        elif array.ndim > 2:
            raise ValueError("Expected a 1D or batched 2D array; received shape %s" % (array.shape,))
        return array

    @staticmethod
    def _first_present(payload: Mapping[str, Any], keys: tuple[str, ...]) -> Any:
        """Return the first value among *keys* that is present and not None.

        Presence is tested with ``is not None`` rather than truthiness: the
        previous ``or``-chain raised ``ValueError`` when an alias held a
        multi-element numpy array (ambiguous truth value) and silently skipped
        legitimate falsy values such as ``0`` or ``""``.
        """
        for key in keys:
            value = payload.get(key)
            if value is not None:
                return value
        return None

    def _prepare_inputs(self, payload: Mapping[str, Any]) -> MutableMapping[str, np.ndarray]:
        """Assemble the complete feed dict for ``session.run``.

        Raises:
            KeyError: when no alias for the symbolic vector is present.
        """
        psi = self._first_present(payload, self._PSI_KEYS)
        if psi is None:
            raise KeyError("Payload must include a 'psi' field containing the symbolic vector.")
        inputs: MutableMapping[str, np.ndarray] = {
            self.primary_input: self._vector_from_payload(psi)
        }
        if self.slow_input is not None:
            slow_value = self._first_present(payload, self._SLOW_KEYS)
            if slow_value is None:
                inputs[self.slow_input] = self._slow_fallback.copy()
            else:
                inputs[self.slow_input] = self._coerce_array(slow_value, allow_empty=True)
        # Copy the cached zero tensors so callers cannot mutate the defaults.
        for name, default in self._defaults.items():
            inputs[name] = default.copy()
        return inputs

    def _vector_from_payload(self, value: Any) -> np.ndarray:
        """Turn a payload value into the primary input tensor.

        Strings (or lists of strings) are routed through the deterministic
        text encoder; anything else is coerced to a batched float32 array.
        """
        if isinstance(value, str):
            return self._text_encoder.encode(value)
        if isinstance(value, (list, tuple)) and value and all(isinstance(v, str) for v in value):
            return self._text_encoder.encode(" ".join(value))
        return self._coerce_array(value)

    @staticmethod
    def _format_output(name: str, value: np.ndarray) -> Any:
        """Convert one ONNX output to JSON-friendly Python values.

        NaN/±inf are clamped to 0.0 so the response always serialises;
        scalars become floats and everything else a (nested) list. *name* is
        kept for interface compatibility but is not currently used.
        """
        value = np.asarray(value, dtype=np.float32)
        value = np.nan_to_num(value, nan=0.0, posinf=0.0, neginf=0.0)
        squeezed = np.squeeze(value)
        if squeezed.ndim == 0:
            return float(squeezed)
        return squeezed.tolist()

    def __call__(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
        """Run one inference step for the Hugging Face serving stack."""
        payload = data.get("inputs", data)
        if not isinstance(payload, Mapping):
            # A bare vector/string is treated as the psi payload directly.
            payload = {"psi": payload}
        feed = self._prepare_inputs(payload)
        outputs = self.session.run(None, feed)
        return {
            node.name: self._format_output(node.name, value)
            for node, value in zip(self.io.outputs, outputs)
        }
# Public API: the Hugging Face serving stack only needs the handler class.
__all__ = ["EndpointHandler"]
|