Yaz / yaz /semantic_router.py
TilelliLab's picture
Yaz v0.0.1 — safetensors + loader + model card + paper (editable/auditable tiny KB model)
b14638e verified
Raw
History Blame Contribute Delete
5.51 kB
"""Semantic router for Yaz fact-atoms (the keystone of Yaz + Engram).
Maps a prompt string -> a fact-atom id using FROZEN Engram (MiniLM, 384-d) embeddings,
so paraphrases of the same fact route to the same atom. Two routers, both over the same
frozen embeddings:
- FROZEN centroid: atom key = mean embedding of a country's TRAIN-template prompt
prefixes; route = nearest key by cosine. Zero learned params.
- LEARNED linear: a trainable Linear(384 -> n_country) trained (CE) on the same
TRAIN-template embeddings; route = argmax. Tests whether a learned projection over
frozen embeddings generalizes to held-out phrasings.
Country i (in `country_order`) owns fact-atom id i (matches train_gen's c2i ordering),
so a router that returns country index returns the atom id directly.
Leak-free: keys/head use ONLY train templates; held-out templates never enter fitting.
MiniLM is frozen, never fine-tuned on these facts.
"""
from __future__ import annotations
import os
import sys
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# Hugging Face model id of the sentence-transformers encoder that `_StEmbedder`
# loads by default (MiniLM, 384-d per the module docstring).
_ST_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
class _StEmbedder:
    """Frozen sentence embedder built directly on sentence-transformers.

    By default it loads all-MiniLM-L6-v2, the same model the optional internal
    `engram` package wraps, so Yaz's routing/abstention numbers can be reproduced
    with nothing more than `pip install sentence-transformers` (no local paths).
    Vectors are requested with normalize_embeddings=True, i.e. L2-normalized,
    matching the original engram Embedder (same model, same normalization).

    Attributes set by the constructor:
      _st   -- the underlying SentenceTransformer instance
      dim   -- embedding width reported by the model, as an int
      mode  -- always the string "st"
    """

    def __init__(self, model_name: str = _ST_MODEL):
        # Lazy import: sentence-transformers is only needed once an embedder
        # is actually constructed.
        from sentence_transformers import SentenceTransformer  # lazy import

        model = SentenceTransformer(model_name)
        self._st = model
        self.mode = "st"
        # The dimension accessor was renamed across releases; try the newer
        # name first and fall back to the older one.
        try:
            width = int(model.get_embedding_dimension())  # sentence-transformers >=5
        except AttributeError:
            width = int(model.get_sentence_embedding_dimension())  # older API
        self.dim = width

    def encode_one(self, text: str) -> np.ndarray:
        """Embed a single string and return its float32 vector.

        The text is encoded as a one-element batch; the first (only) row of the
        float32-cast result is returned.
        """
        batch = self._st.encode(
            [text],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        as_f32 = batch.astype(np.float32)
        return as_f32[0]
def _make_embedder(prefer="auto"):
    """Build the embedder used for routing.

    Prefer the internal Engram embedder if the YAZ_EMBEDDER_PATH environment
    variable points at one; otherwise fall back to the bundled
    sentence-transformers embedder (identical 384-d MiniLM).

    Args:
        prefer: forwarded as `Embedder(prefer=...)` to the engram embedder.
            Ignored by the sentence-transformers fallback.

    Returns:
        An `engram.Embedder` when YAZ_EMBEDDER_PATH is set and the import and
        construction succeed, else a `_StEmbedder`.

    Side effects:
        When YAZ_EMBEDDER_PATH is set, that directory is placed at the front of
        `sys.path` (exactly once, however often this function is called).
    """
    ep = os.environ.get("YAZ_EMBEDDER_PATH", "")
    if ep:
        # Keep the directory at the front of sys.path without accumulating a
        # duplicate entry on every call (one call happens per SemanticRouter).
        if ep in sys.path:
            sys.path.remove(ep)
        sys.path.insert(0, ep)
        try:
            from engram import Embedder  # noqa: E402

            return Embedder(prefer=prefer)
        except Exception:
            # Best-effort by design: still fall through to the public
            # sentence-transformers embedder, but say why the configured one
            # was not used instead of failing silently.
            import logging

            logging.getLogger(__name__).warning(
                "YAZ_EMBEDDER_PATH=%r is set but the engram Embedder could not be "
                "loaded; falling back to the sentence-transformers embedder",
                ep,
                exc_info=True,
            )
    return _StEmbedder()
class SemanticRouter:
    """Route a prompt string to a fact-atom id over frozen sentence embeddings.

    Country i in `country_order` owns atom id i, so both routers return an index
    into `country_order`. Two routers share the same cached embeddings:

    - frozen centroid: `build_centroids()` then `route_frozen(prompt)`
    - learned linear:  `train_head()` then `route_learned(prompt)`

    Only `train_templates` (format strings with a `{C}` placeholder for the
    country name) are used to fit centroids and the linear head.
    """

    def __init__(self, country_order, train_templates, prefer="auto"):
        """Store the label order and templates and create the embedder.

        Args:
            country_order: iterable of country names; position == atom id.
            train_templates: iterable of format strings containing `{C}`.
            prefer: forwarded to `_make_embedder`.

        Raises:
            AssertionError: if the embedder's `mode` is not "st".
        """
        self.country_order = list(country_order)  # atom id == index here
        self.train_templates = list(train_templates)
        self.emb = _make_embedder(prefer)
        # Explicit raise instead of `assert` so the check is not stripped under
        # `python -O`; AssertionError is kept so the exception type is unchanged.
        if self.emb.mode != "st":
            raise AssertionError(f"need semantic embeddings, got mode={self.emb.mode}")
        self.dim = int(self.emb.dim)
        self._cache: dict[str, np.ndarray] = {}  # prompt string -> embedding
        self.centroids: np.ndarray | None = None  # (n, dim) unit-norm
        self.head: nn.Linear | None = None  # set by train_head()

    # ---- embedding (cached, deterministic) ----
    def embed(self, prompt: str) -> np.ndarray:
        """Return the float32 embedding of `prompt`, memoized per prompt string.

        The returned array is the cached object itself, so callers should not
        mutate it in place.
        """
        v = self._cache.get(prompt)
        if v is None:
            v = self.emb.encode_one(prompt).astype(np.float32)  # L2-normed
            self._cache[prompt] = v
        return v

    def _train_matrix(self):
        """Returns (X, y): X=(n*T, dim) train-prefix embeddings, y=(n*T,) country idx.

        Rows are ordered country-major: all templates for country 0, then
        country 1, and so on. `y` is int64.
        """
        X, y = [], []
        for ci, c in enumerate(self.country_order):
            for t in self.train_templates:
                X.append(self.embed(t.format(C=c)))
                y.append(ci)
        return np.stack(X), np.array(y, dtype=np.int64)

    # ---- frozen centroid router ----
    def build_centroids(self):
        """Compute one unit-norm centroid per country and store it in `centroids`.

        Each centroid is the mean of that country's train-template embeddings,
        re-normalized (with a 1e-8 guard against a zero-length mean).

        Returns:
            self, so the call can be chained.
        """
        n = len(self.country_order)
        cent = np.zeros((n, self.dim), dtype=np.float32)
        for ci, c in enumerate(self.country_order):
            vs = np.stack([self.embed(t.format(C=c)) for t in self.train_templates])
            m = vs.mean(0)
            cent[ci] = m / (np.linalg.norm(m) + 1e-8)
        self.centroids = cent
        return self

    def route_frozen(self, prompt: str) -> int:
        """Return the index of the centroid with the largest dot product to `prompt`.

        With unit-norm centroids and a unit-norm embedding this is the nearest
        centroid by cosine similarity.

        Raises:
            RuntimeError: if `build_centroids()` has not been called yet.
        """
        if self.centroids is None:
            raise RuntimeError("centroids not built; call build_centroids() first")
        v = self.embed(prompt)  # unit-norm
        return int((self.centroids @ v).argmax())  # nearest centroid by cosine

    # ---- learned linear router ----
    def train_head(self, steps=400, lr=1e-2, seed=2026):
        """Fit a Linear(dim -> n_country) head with cross-entropy and store it in `head`.

        Training is full-batch over the train-template matrix using Adam with
        weight_decay=1e-4. The head is left in eval mode.

        Args:
            steps: number of optimizer steps.
            lr: Adam learning rate.
            seed: passed to `torch.manual_seed`; note this reseeds torch's
                global RNG as a side effect.

        Returns:
            Accuracy of the trained head on the training matrix, as a float.
        """
        torch.manual_seed(seed)
        X, y = self._train_matrix()
        Xt, yt = torch.from_numpy(X), torch.from_numpy(y)
        head = nn.Linear(self.dim, len(self.country_order))
        opt = torch.optim.Adam(head.parameters(), lr=lr, weight_decay=1e-4)
        for _ in range(steps):
            opt.zero_grad()
            loss = F.cross_entropy(head(Xt), yt)
            loss.backward()
            opt.step()
        head.eval()
        self.head = head
        with torch.no_grad():
            tr_acc = float((head(Xt).argmax(1) == yt).float().mean())
        return tr_acc

    def route_learned(self, prompt: str) -> int:
        """Return the argmax class of the trained linear head for `prompt`.

        Raises:
            RuntimeError: if `train_head()` has not been called yet.
        """
        if self.head is None:
            raise RuntimeError("head not trained; call train_head() first")
        v = torch.from_numpy(self.embed(prompt))[None]  # add batch dimension
        with torch.no_grad():
            return int(self.head(v).argmax(1).item())