gnai-creator committed on
Commit
5f1dd80
·
verified ·
1 Parent(s): 93fb87b

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +274 -233
handler.py CHANGED
@@ -1,233 +1,274 @@
1
- """Custom inference handler for Hugging Face Inference Endpoints.
2
-
3
- This module exposes :class:`EndpointHandler`, the entrypoint used by the
4
- Hugging Face serving stack when ``--task custom`` is selected. The handler
5
- loads the exported Noesis decoder ONNX graph and accepts symbolic intent
6
- vectors (``psi``) along with an optional ``slow_state`` memory tensor. The
7
- outputs mirror the values produced by the training runtime:
8
-
9
- * ``z_out`` – semantic embedding projected back into symbolic space.
10
- * ``choice``, ``pain``, ``memory`` and ``quality`` – diagnostic scalars.
11
- * ``slow_state`` – updated slow memory tensor suitable for recurrent usage.
12
-
13
- The handler is intentionally lightweight so it can run without the rest of the
14
- AletheiaEngine Python package being installed.
15
- """
16
-
17
- from __future__ import annotations
18
-
19
- from dataclasses import dataclass
20
- from pathlib import Path
21
- import re
22
- from typing import Any, Mapping, MutableMapping, Optional
23
-
24
- import numpy as np
25
- import onnxruntime as ort
26
-
27
-
28
- _WORD_RE = re.compile(r"\w+", re.UNICODE)
29
-
30
-
31
- class _TextEncoder:
32
- """Deterministic text → vector encoder.
33
-
34
- The Hugging Face Inference Endpoints frequently pass user prompts as
35
- strings via the ``inputs`` field. The Noesis decoder, however, expects a
36
- symbolic vector (``psi``) as input. To provide a graceful fallback the
37
- handler lazily converts short text prompts into a stable float32 vector by
38
- hashing tokens onto a hypersphere. This mirrors the lightweight
39
- ``TextEncoder256`` implementation bundled with the full AletheiaEngine
40
- package while avoiding a heavy import dependency inside the endpoint
41
- container.
42
- """
43
-
44
- def __init__(self, dim: int) -> None:
45
- self.dim = dim
46
-
47
- @staticmethod
48
- def _tokens(text: str) -> list[str]:
49
- return [tok.lower() for tok in _WORD_RE.findall(text)]
50
-
51
- @staticmethod
52
- def _seed(tok: str) -> int:
53
- # FNV-1a hash for determinism across processes/platforms.
54
- value = 2166136261
55
- for byte in tok.encode("utf-8"):
56
- value ^= byte
57
- value = (value * 16777619) & 0xFFFFFFFF
58
- return int(value)
59
-
60
- def encode(self, text: str) -> np.ndarray:
61
- tokens = self._tokens(text)
62
- if not tokens:
63
- return np.zeros((1, self.dim), dtype=np.float32)
64
-
65
- vecs = []
66
- for tok in tokens:
67
- rs = np.random.RandomState(self._seed(tok))
68
- embedding = rs.normal(0.0, 1.0, size=(self.dim,)).astype(np.float32)
69
- norm = float(np.linalg.norm(embedding)) or 1.0
70
- vecs.append(embedding / norm)
71
-
72
- stacked = np.stack(vecs, axis=0)
73
- pooled = stacked.mean(axis=0, dtype=np.float32, keepdims=True)
74
- pooled_norm = float(np.linalg.norm(pooled)) or 1.0
75
- return pooled / pooled_norm
76
-
77
-
78
- @dataclass(frozen=True)
79
- class _ModelIO:
80
- """Snapshot of ONNX input and output metadata."""
81
-
82
- inputs: tuple[ort.NodeArg, ...]
83
- outputs: tuple[ort.NodeArg, ...]
84
-
85
-
86
- class EndpointHandler:
87
- """Callable endpoint used by Hugging Face to drive inference."""
88
-
89
- def __init__(self, path: str | None = None) -> None:
90
- self.model_dir = Path(path or Path(__file__).parent)
91
- self.session = self._load_session()
92
- self.io = self._capture_io()
93
-
94
- self.primary_input = self.io.inputs[0].name
95
- self.slow_input = self._find_input("slow_state")
96
- self._primary_dim = self._infer_primary_dim()
97
- self._text_encoder = _TextEncoder(self._primary_dim)
98
- self._defaults = {
99
- node.name: self._zeros_like(node)
100
- for node in self.io.inputs
101
- if node.name not in {self.primary_input, self.slow_input}
102
- }
103
- if self.slow_input is not None:
104
- self._slow_fallback = self._zeros_like(self._input_map[self.slow_input])
105
- else:
106
- self._slow_fallback = None
107
-
108
- def _load_session(self) -> ort.InferenceSession:
109
- """Load the ONNX session, tolerating alternate filenames."""
110
-
111
- preferred_names = ("model.onnx", "model_infer.onnx")
112
- for name in preferred_names:
113
- candidate = self.model_dir / name
114
- if candidate.exists():
115
- return ort.InferenceSession(str(candidate), providers=["CPUExecutionProvider"])
116
-
117
- available = sorted(str(p.name) for p in self.model_dir.glob("*.onnx"))
118
- if len(available) == 1:
119
- # Fall back to the lone ONNX artefact if it has a non-standard name.
120
- return ort.InferenceSession(str(self.model_dir / available[0]), providers=["CPUExecutionProvider"])
121
-
122
- choices = ", ".join(available) or "<none>"
123
- raise FileNotFoundError(
124
- "Could not locate any of %s in %s (available: %s)"
125
- % (", ".join(preferred_names), self.model_dir, choices)
126
- )
127
-
128
- @property
129
- def _input_map(self) -> Mapping[str, ort.NodeArg]:
130
- return {node.name: node for node in self.io.inputs}
131
-
132
- def _capture_io(self) -> _ModelIO:
133
- return _ModelIO(inputs=tuple(self.session.get_inputs()), outputs=tuple(self.session.get_outputs()))
134
-
135
- def _find_input(self, target: str) -> Optional[str]:
136
- target = target.lower()
137
- for node in self.io.inputs:
138
- if node.name.lower() == target:
139
- return node.name
140
- return None
141
-
142
- def _infer_primary_dim(self) -> int:
143
- node = self._input_map[self.primary_input]
144
- for dim in reversed(node.shape):
145
- if isinstance(dim, int) and dim > 0:
146
- return dim
147
- # Conservative default matching TextEncoder256.
148
- return 256
149
-
150
- @staticmethod
151
- def _zeros_like(node: ort.NodeArg) -> np.ndarray:
152
- shape: list[int] = []
153
- for dim in node.shape:
154
- if isinstance(dim, int) and dim > 0:
155
- shape.append(dim)
156
- else:
157
- shape.append(1)
158
- return np.zeros(shape, dtype=np.float32)
159
-
160
- @staticmethod
161
- def _coerce_array(value: Any, *, allow_empty: bool = False) -> np.ndarray:
162
- array = np.asarray(value, dtype=np.float32)
163
- if array.size == 0 and not allow_empty:
164
- raise ValueError("Received an empty array; provide at least one value.")
165
- if array.ndim == 1:
166
- array = np.expand_dims(array, axis=0)
167
- elif array.ndim > 2:
168
- raise ValueError("Expected a 1D or batched 2D array; received shape %s" % (array.shape,))
169
- return array
170
-
171
- def _prepare_inputs(self, payload: Mapping[str, Any]) -> MutableMapping[str, np.ndarray]:
172
- psi = payload.get("psi")
173
- if psi is None:
174
- psi = (
175
- payload.get("vector")
176
- or payload.get("psi_s")
177
- or payload.get("inputs")
178
- or payload.get("prompt")
179
- or payload.get("text")
180
- )
181
- if psi is None:
182
- raise KeyError("Payload must include a 'psi' field containing the symbolic vector.")
183
-
184
- inputs: MutableMapping[str, np.ndarray] = {
185
- self.primary_input: self._vector_from_payload(psi)
186
- }
187
-
188
- if self.slow_input is not None:
189
- slow_value = payload.get("slow_state") or payload.get("slow") or payload.get("state")
190
- if slow_value is None:
191
- inputs[self.slow_input] = self._slow_fallback.copy()
192
- else:
193
- inputs[self.slow_input] = self._coerce_array(slow_value, allow_empty=True)
194
-
195
- for name, default in self._defaults.items():
196
- inputs[name] = default.copy()
197
-
198
- return inputs
199
-
200
- def _vector_from_payload(self, value: Any) -> np.ndarray:
201
- if isinstance(value, str):
202
- return self._text_encoder.encode(value)
203
-
204
- if isinstance(value, (list, tuple)) and value and all(isinstance(v, str) for v in value):
205
- return self._text_encoder.encode(" ".join(value))
206
-
207
- return self._coerce_array(value)
208
-
209
- @staticmethod
210
- def _format_output(name: str, value: np.ndarray) -> Any:
211
- value = np.asarray(value, dtype=np.float32)
212
- value = np.nan_to_num(value, nan=0.0, posinf=0.0, neginf=0.0)
213
- squeezed = np.squeeze(value)
214
- if squeezed.ndim == 0:
215
- return float(squeezed)
216
- return squeezed.tolist()
217
-
218
- def __call__(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
219
- payload = data.get("inputs", data)
220
- if not isinstance(payload, Mapping):
221
- payload = {"psi": payload}
222
-
223
- feed = self._prepare_inputs(payload)
224
- outputs = self.session.run(None, feed)
225
-
226
- result = {
227
- node.name: self._format_output(node.name, value)
228
- for node, value in zip(self.io.outputs, outputs)
229
- }
230
- return result
231
-
232
-
233
- __all__ = ["EndpointHandler"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Custom inference handler for Hugging Face Inference Endpoints.
2
+
3
+ This module exposes :class:`EndpointHandler`, the entrypoint used by the
4
+ Hugging Face serving stack when ``--task custom`` is selected. The handler
5
+ loads the exported Noesis decoder ONNX graph and accepts symbolic intent
6
+ vectors (``psi``) along with an optional ``slow_state`` memory tensor. The
7
+ outputs mirror the values produced by the training runtime:
8
+
9
+ * ``z_out`` – semantic embedding projected back into symbolic space.
10
+ * ``choice``, ``pain``, ``memory`` and ``quality`` – diagnostic scalars.
11
+ * ``slow_state`` – updated slow memory tensor suitable for recurrent usage.
12
+
13
+ The handler is intentionally lightweight so it can run without the rest of the
14
+ AletheiaEngine Python package being installed.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import importlib
20
+ import importlib.util
21
+ from dataclasses import dataclass
22
+ from pathlib import Path
23
+ import re
24
+ from typing import Any, Mapping, MutableMapping, Optional
25
+
26
+ import numpy as np
27
+
28
+
29
+ _WORD_RE = re.compile(r"\w+", re.UNICODE)
30
+
31
+
32
class _TextEncoder:
    """Hash-based deterministic text → unit-vector encoder.

    Endpoint payloads frequently arrive as plain strings, while the Noesis
    decoder graph expects a symbolic float32 vector (``psi``).  Each token is
    mapped onto the hypersphere via a Gaussian draw seeded with an FNV-1a
    hash of the token, so identical text always yields an identical vector on
    every process and platform.  This mirrors the lightweight
    ``TextEncoder256`` bundled with the full AletheiaEngine package without
    pulling that heavy dependency into the endpoint container.
    """

    def __init__(self, dim: int) -> None:
        self.dim = dim

    @staticmethod
    def _tokens(text: str) -> list[str]:
        """Lower-cased word tokens extracted from *text*."""
        return [word.lower() for word in _WORD_RE.findall(text)]

    @staticmethod
    def _seed(tok: str) -> int:
        """32-bit FNV-1a hash of *tok*; deterministic across processes/platforms."""
        acc = 2166136261
        for byte in tok.encode("utf-8"):
            acc = ((acc ^ byte) * 16777619) & 0xFFFFFFFF
        return int(acc)

    def encode(self, text: str) -> np.ndarray:
        """Encode *text* as a ``(1, dim)`` float32 unit vector.

        Empty or token-free input maps to the all-zeros vector.
        """
        tokens = self._tokens(text)
        if not tokens:
            return np.zeros((1, self.dim), dtype=np.float32)

        unit_rows = []
        for tok in tokens:
            gen = np.random.RandomState(self._seed(tok))
            row = gen.normal(0.0, 1.0, size=(self.dim,)).astype(np.float32)
            # `or 1.0` guards against a (theoretical) zero-norm draw.
            unit_rows.append(row / (float(np.linalg.norm(row)) or 1.0))

        # Mean-pool the per-token unit vectors, then re-normalise the result.
        pooled = np.stack(unit_rows, axis=0).mean(axis=0, dtype=np.float32, keepdims=True)
        return pooled / (float(np.linalg.norm(pooled)) or 1.0)
77
+
78
+
79
@dataclass(frozen=True)
class _ModelIO:
    """Snapshot of ONNX input and output metadata.

    Captured once at session-load time; frozen so the node lists cannot be
    mutated after the handler starts serving.
    """

    # Graph-input node descriptors (onnxruntime NodeArg-like objects exposing
    # `.name`, `.shape` and `.type`), in declaration order.
    inputs: tuple[Any, ...]
    # Graph-output node descriptors, in declaration order; output values from
    # `session.run` are zipped against these by position.
    outputs: tuple[Any, ...]
85
+
86
+
87
class EndpointHandler:
    """Callable endpoint used by Hugging Face to drive inference.

    Follows the custom-handler protocol: ``__call__`` receives a JSON-like
    mapping (usually ``{"inputs": ...}``) and returns a mapping of named,
    JSON-serialisable outputs produced by the Noesis decoder ONNX graph.
    """

    def __init__(self, path: str | None = None) -> None:
        """Load the ONNX session from *path* (default: this file's directory)."""
        self.model_dir = Path(path or Path(__file__).parent)
        self.session = self._load_session()
        self.io = self._capture_io()

        # By convention the first graph input carries the symbolic vector psi.
        self.primary_input = self.io.inputs[0].name
        self.slow_input = self._find_input("slow_state")
        self._primary_dim = self._infer_primary_dim()
        self._text_encoder = _TextEncoder(self._primary_dim)
        # Zero-filled tensors for auxiliary inputs callers never supply.
        self._defaults = {}
        for node in self.io.inputs:
            if node.name in {self.primary_input, self.slow_input}:
                continue
            self._defaults[node.name] = self._zeros_like(node)
        if self.slow_input is not None:
            self._slow_fallback = self._zeros_like(self._input_map[self.slow_input])
        else:
            self._slow_fallback = None

    def _load_session(self):
        """Load the ONNX session, tolerating alternate filenames.

        Tries the conventional names first, then falls back to a single
        arbitrarily-named ``*.onnx`` artefact; raises ``FileNotFoundError``
        listing what was found otherwise.
        """
        ort = self._import_onnxruntime()
        preferred_names = ("model.onnx", "model_infer.onnx")
        for name in preferred_names:
            candidate = self.model_dir / name
            if candidate.exists():
                return ort.InferenceSession(str(candidate), providers=["CPUExecutionProvider"])

        available = sorted(str(p.name) for p in self.model_dir.glob("*.onnx"))
        if len(available) == 1:
            # Fall back to the lone ONNX artefact if it has a non-standard name.
            return ort.InferenceSession(str(self.model_dir / available[0]), providers=["CPUExecutionProvider"])

        choices = ", ".join(available) or "<none>"
        raise FileNotFoundError(
            "Could not locate any of %s in %s (available: %s)"
            % (", ".join(preferred_names), self.model_dir, choices)
        )

    @staticmethod
    def _import_onnxruntime():
        """Import :mod:`onnxruntime`, providing a helpful error if unavailable."""
        spec = importlib.util.find_spec("onnxruntime")
        if spec is None:
            raise ModuleNotFoundError(
                "onnxruntime is required to load Noesis decoder ONNX graphs. "
                "Install it with 'pip install onnxruntime'."
            )
        return importlib.import_module("onnxruntime")

    @property
    def _input_map(self) -> Mapping[str, Any]:
        """Name → node mapping for the graph inputs."""
        return {node.name: node for node in self.io.inputs}

    def _capture_io(self) -> _ModelIO:
        """Snapshot the session's input/output metadata once, up front."""
        return _ModelIO(inputs=tuple(self.session.get_inputs()), outputs=tuple(self.session.get_outputs()))

    def _find_input(self, target: str) -> Optional[str]:
        """Return the graph input whose name matches *target* case-insensitively."""
        target = target.lower()
        for node in self.io.inputs:
            if node.name.lower() == target:
                return node.name
        return None

    def _infer_primary_dim(self) -> int:
        """Infer the feature dimension of the primary input from its shape.

        Symbolic (non-integer) dimensions are skipped; the last concrete
        dimension wins, matching a (batch, features) layout.
        """
        node = self._input_map[self.primary_input]
        for dim in reversed(node.shape):
            if isinstance(dim, int) and dim > 0:
                return dim
        # Conservative default matching TextEncoder256.
        return 256

    @staticmethod
    def _onnx_type_to_numpy(type_str: str | None) -> np.dtype:
        """Map an ONNX tensor type string to the matching NumPy dtype.

        Unknown or unset types default to ``float32``, the graph's common case.
        """
        mapping = {
            "tensor(float)": np.float32,
            "tensor(float16)": np.float16,
            "tensor(double)": np.float64,
            "tensor(int64)": np.int64,
            "tensor(int32)": np.int32,
            "tensor(int16)": np.int16,
            "tensor(int8)": np.int8,
            "tensor(uint8)": np.uint8,
            "tensor(bool)": np.bool_,
        }
        return mapping.get(type_str, np.float32)

    def _dtype_for(self, node: Any) -> np.dtype:
        """Dtype expected by *node*, derived from its ONNX type string."""
        return self._onnx_type_to_numpy(getattr(node, "type", None))

    def _zeros_like(self, node: Any) -> np.ndarray:
        """Zero tensor shaped like *node*, with symbolic dims collapsed to 1."""
        shape: list[int] = []
        for dim in node.shape:
            if isinstance(dim, int) and dim > 0:
                shape.append(dim)
            else:
                shape.append(1)
        return np.zeros(shape, dtype=self._dtype_for(node))

    @staticmethod
    def _first_present(payload: Mapping[str, Any], keys: tuple[str, ...]) -> Any:
        """Return the first value among *keys* that is present and not None.

        Deliberately uses ``is None`` rather than an ``or`` chain: ``or``
        raises ``ValueError`` on NumPy arrays (their truth value is
        ambiguous) and silently skips legitimate falsy payloads such as
        ``0`` or an all-zero list.
        """
        for key in keys:
            value = payload.get(key)
            if value is not None:
                return value
        return None

    def _coerce_array(self, value: Any, *, node: Any, allow_empty: bool = False) -> np.ndarray:
        """Convert *value* into a batched 2D array of the node's dtype.

        Raises ``ValueError`` for empty input (unless *allow_empty*) and for
        arrays with more than two dimensions.
        """
        dtype = self._dtype_for(node)
        array = np.asarray(value, dtype=dtype)
        if array.size == 0 and not allow_empty:
            raise ValueError("Received an empty array; provide at least one value.")
        if array.ndim == 1:
            # Promote a bare vector to a batch of one.
            array = np.expand_dims(array, axis=0)
        elif array.ndim > 2:
            raise ValueError("Expected a 1D or batched 2D array; received shape %s" % (array.shape,))
        if array.dtype != dtype:
            array = array.astype(dtype, copy=False)
        return array

    def _prepare_inputs(self, payload: Mapping[str, Any]) -> MutableMapping[str, np.ndarray]:
        """Build the full ONNX feed dict from a user payload.

        Accepts several aliases for the psi vector and the slow state; any
        remaining graph inputs are filled with cached zero tensors.
        """
        psi = self._first_present(payload, ("psi", "vector", "psi_s", "inputs", "prompt", "text"))
        if psi is None:
            raise KeyError("Payload must include a 'psi' field containing the symbolic vector.")

        primary_node = self._input_map[self.primary_input]
        inputs: MutableMapping[str, np.ndarray] = {
            self.primary_input: self._vector_from_payload(psi, node=primary_node)
        }

        if self.slow_input is not None:
            slow_value = self._first_present(payload, ("slow_state", "slow", "state"))
            if slow_value is None:
                inputs[self.slow_input] = self._slow_fallback.copy()
            else:
                inputs[self.slow_input] = self._coerce_array(
                    slow_value,
                    node=self._input_map[self.slow_input],
                    allow_empty=True,
                )

        for name, default in self._defaults.items():
            # Copy so a run can never mutate the cached defaults.
            inputs[name] = default.copy()

        return inputs

    def _vector_from_payload(self, value: Any, *, node: Any) -> np.ndarray:
        """Turn a payload value (text, list of strings, or numbers) into psi."""
        if isinstance(value, str):
            return self._coerce_array(self._text_encoder.encode(value), node=node)

        if isinstance(value, (list, tuple)) and value and all(isinstance(v, str) for v in value):
            return self._coerce_array(self._text_encoder.encode(" ".join(value)), node=node)

        return self._coerce_array(value, node=node)

    @staticmethod
    def _format_output(name: str, value: np.ndarray) -> Any:
        """Make one model output JSON-serialisable (a float or nested lists).

        NaN/inf values are mapped to 0.0 so the response is always valid JSON.
        """
        value = np.asarray(value, dtype=np.float32)
        value = np.nan_to_num(value, nan=0.0, posinf=0.0, neginf=0.0)
        squeezed = np.squeeze(value)
        if squeezed.ndim == 0:
            return float(squeezed)
        return squeezed.tolist()

    def __call__(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
        """Run one inference step and return named, serialisable outputs."""
        # Guard the .get call: the serving stack may hand us a bare list/str.
        payload = data.get("inputs", data) if isinstance(data, Mapping) else data
        if not isinstance(payload, Mapping):
            payload = {"psi": payload}

        feed = self._prepare_inputs(payload)
        outputs = self.session.run(None, feed)

        return {
            node.name: self._format_output(node.name, value)
            for node, value in zip(self.io.outputs, outputs)
        }
272
+
273
+
274
+ __all__ = ["EndpointHandler"]