Yaz v0.0.1 — safetensors + loader + model card + paper (editable/auditable tiny KB model)
b14638e verified | """Semantic router for Yaz fact-atoms (the keystone of Yaz + Engram). | |
| Maps a prompt string -> a fact-atom id using FROZEN Engram (MiniLM, 384-d) embeddings, | |
| so paraphrases of the same fact route to the same atom. Two routers, both over the same | |
| frozen embeddings: | |
| - FROZEN centroid: atom key = mean embedding of a country's TRAIN-template prompt | |
| prefixes; route = nearest key by cosine. Zero learned params. | |
| - LEARNED linear: a trainable Linear(384 -> n_country) trained (CE) on the same | |
| TRAIN-template embeddings; route = argmax. Tests whether a learned projection over | |
| frozen embeddings generalizes to held-out phrasings. | |
| Country i (in `country_order`) owns fact-atom id i (matches train_gen's c2i ordering), | |
| so a router that returns country index returns the atom id directly. | |
| Leak-free: keys/head use ONLY train templates; held-out templates never enter fitting. | |
| MiniLM is frozen, never fine-tuned on these facts. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| _ST_MODEL = "sentence-transformers/all-MiniLM-L6-v2" | |
| class _StEmbedder: | |
| """Self-contained frozen embedder over sentence-transformers all-MiniLM-L6-v2. | |
| This is what the optional internal `engram` package wraps; using it directly means | |
| anyone can reproduce Yaz's routing/abstention numbers with `pip install | |
| sentence-transformers` and no local paths. Embeddings are L2-normalized, matching | |
| the original engram Embedder exactly (same model, normalize_embeddings=True). | |
| """ | |
| def __init__(self, model_name: str = _ST_MODEL): | |
| from sentence_transformers import SentenceTransformer # lazy import | |
| self._st = SentenceTransformer(model_name) | |
| try: | |
| self.dim = int(self._st.get_embedding_dimension()) # sentence-transformers >=5 | |
| except AttributeError: | |
| self.dim = int(self._st.get_sentence_embedding_dimension()) # older API | |
| self.mode = "st" | |
| def encode_one(self, text: str) -> np.ndarray: | |
| return self._st.encode([text], convert_to_numpy=True, | |
| normalize_embeddings=True).astype(np.float32)[0] | |
| def _make_embedder(prefer="auto"): | |
| """Prefer the internal Engram embedder if YAZ_EMBEDDER_PATH points at one; otherwise | |
| fall back to the bundled sentence-transformers embedder (identical 384-d MiniLM).""" | |
| ep = os.environ.get("YAZ_EMBEDDER_PATH", "") | |
| if ep: | |
| sys.path.insert(0, ep) | |
| try: | |
| from engram import Embedder # noqa: E402 | |
| return Embedder(prefer=prefer) | |
| except Exception: | |
| pass # fall through to the public sentence-transformers embedder | |
| return _StEmbedder() | |
| class SemanticRouter: | |
| def __init__(self, country_order, train_templates, prefer="auto"): | |
| self.country_order = list(country_order) # atom id == index here | |
| self.train_templates = list(train_templates) | |
| self.emb = _make_embedder(prefer) | |
| assert self.emb.mode == "st", f"need semantic embeddings, got mode={self.emb.mode}" | |
| self.dim = int(self.emb.dim) | |
| self._cache: dict[str, np.ndarray] = {} | |
| self.centroids: np.ndarray | None = None # (n, dim) unit-norm | |
| self.head: nn.Linear | None = None | |
| # ---- embedding (cached, deterministic) ---- | |
| def embed(self, prompt: str) -> np.ndarray: | |
| v = self._cache.get(prompt) | |
| if v is None: | |
| v = self.emb.encode_one(prompt).astype(np.float32) # L2-normed | |
| self._cache[prompt] = v | |
| return v | |
| def _train_matrix(self): | |
| """Returns (X, y): X=(n*T, dim) train-prefix embeddings, y=(n*T,) country idx.""" | |
| X, y = [], [] | |
| for ci, c in enumerate(self.country_order): | |
| for t in self.train_templates: | |
| X.append(self.embed(t.format(C=c))) | |
| y.append(ci) | |
| return np.stack(X), np.array(y, dtype=np.int64) | |
| # ---- frozen centroid router ---- | |
| def build_centroids(self): | |
| n = len(self.country_order) | |
| cent = np.zeros((n, self.dim), dtype=np.float32) | |
| for ci, c in enumerate(self.country_order): | |
| vs = np.stack([self.embed(t.format(C=c)) for t in self.train_templates]) | |
| m = vs.mean(0) | |
| cent[ci] = m / (np.linalg.norm(m) + 1e-8) | |
| self.centroids = cent | |
| return self | |
| def route_frozen(self, prompt: str) -> int: | |
| v = self.embed(prompt) # unit-norm | |
| return int((self.centroids @ v).argmax()) # nearest centroid by cosine | |
| # ---- learned linear router ---- | |
| def train_head(self, steps=400, lr=1e-2, seed=2026): | |
| torch.manual_seed(seed) | |
| X, y = self._train_matrix() | |
| Xt, yt = torch.from_numpy(X), torch.from_numpy(y) | |
| head = nn.Linear(self.dim, len(self.country_order)) | |
| opt = torch.optim.Adam(head.parameters(), lr=lr, weight_decay=1e-4) | |
| for _ in range(steps): | |
| opt.zero_grad() | |
| loss = F.cross_entropy(head(Xt), yt) | |
| loss.backward() | |
| opt.step() | |
| head.eval() | |
| self.head = head | |
| with torch.no_grad(): | |
| tr_acc = float((head(Xt).argmax(1) == yt).float().mean()) | |
| return tr_acc | |
| def route_learned(self, prompt: str) -> int: | |
| v = torch.from_numpy(self.embed(prompt))[None] | |
| with torch.no_grad(): | |
| return int(self.head(v).argmax(1).item()) | |