// web-ttt / web-ttt.js
// LJTSG's picture
// web-TTT: browser-native test-time-trainable memory toolkit
// 78bc0e4 verified
// web-ttt.js — a browser-native, test-time-TRAINABLE associative memory.
//
// MiniLM sentence embeddings + a learnable 384x384 projection matrix W.
// Recall: score(fact) = cosine( normalize(W · embed(query)) , fact_vec ) -> top-K.
// Train (test-time training): pull a query toward a target fact by cross-entropy
// gradient descent on W, ~25 steps, ENTIRELY IN THE BROWSER. W persists in localStorage.
// No server, no API, no backend. The memory learns from use, on the user's device.
//
// Usage:
// import { WebTTT } from "./web-ttt.js";
// const ttt = new WebTTT({ storageKey: "my_ttt" });
// await ttt.init(); // loads MiniLM (transformers.js)
// await ttt.load([{ key, text }, ...]); // embeds your corpus
// const hits = await ttt.recall("a query", 5);
// await ttt.teach("a query", hits[0].index); // 25-step W update, then persists
//
// Embeddings come from Xenova/all-MiniLM-L6-v2 via transformers.js. By default both
// (the library and the model) load from CDN; pass localModelPath/transformersUrl to
// init() to run fully offline.
// Vector width used throughout this module: fact/query vectors have DIM entries
// and the projection matrix W is DIM x DIM.
const DIM = 384;
// Softmax temperature: scores are divided by TEMP before exponentiation, and the
// training gradient is scaled by 1 / TEMP.
const TEMP = 0.1;
/**
 * Associative memory with a trainable DIM x DIM projection matrix W.
 *
 * Facts are stored as vectors; a query is embedded, projected through W,
 * L2-normalized and scored against every fact by dot product. `teach()` updates
 * W so that a query scores a chosen fact higher, then saves W to localStorage.
 */
export class WebTTT {
  /**
   * @param {{ storageKey?: string }} [options] `storageKey` is the localStorage
   *   key under which W is saved and from which it is restored.
   */
  constructor({ storageKey = "web_ttt" } = {}) {
    this.facts = [];
    this.W = null; // set by init() (or importW / resetW)
    this.embed = null; // set by init(): async (text) => number[]
    this.storageKey = storageKey;
  }

  /**
   * Loads transformers.js and the feature-extraction pipeline, installs
   * `this.embed`, and restores W from localStorage (identity if none is stored).
   *
   * @param {object} [options]
   * @param {string} [options.transformersUrl] Module URL passed to dynamic import().
   * @param {string} [options.model] Model id passed to `pipeline`.
   * @param {string|null} [options.localModelPath] When set, remote models are
   *   disabled and models are resolved from this path.
   * @param {*} [options.wasmPaths] When set, assigned to `env.backends.onnx.wasm.wasmPaths`.
   * @returns {Promise<WebTTT>} this instance, for chaining.
   */
  async init({
    transformersUrl = "https://cdn.jsdelivr.net/npm/@xenova/transformers@2.17.2",
    model = "Xenova/all-MiniLM-L6-v2",
    localModelPath = null,
    wasmPaths = null,
  } = {}) {
    const { pipeline, env } = await import(transformersUrl);
    if (localModelPath) {
      env.allowRemoteModels = false;
      env.localModelPath = localModelPath;
    }
    if (wasmPaths) env.backends.onnx.wasm.wasmPaths = wasmPaths;
    const ext = await pipeline("feature-extraction", model);
    // Mean-pooled, normalized embedding copied into a plain array.
    this.embed = async (t) =>
      Array.from((await ext(t, { pooling: "mean", normalize: true })).data);
    this.W = this._loadW() || this._eye();
    return this;
  }

  /**
   * Replaces the stored facts with `corpus`.
   *
   * @param {Array<{ key: *, text?: string, vec?: ArrayLike<number> }>} corpus
   *   If `vec` is a plain array or Float32Array/Float64Array of DIM entries it
   *   is used as-is (no re-normalization); otherwise "key. text" (first 512
   *   characters) is embedded.
   * @returns {Promise<number>} number of facts now stored.
   */
  async load(corpus) {
    this.facts = [];
    for (const f of corpus) {
      const vec = this._isVec(f.vec)
        ? f.vec
        : await this.embed(`${f.key}. ${f.text || ""}`.slice(0, 512));
      this.facts.push({ key: f.key, text: f.text || "", vec });
    }
    return this.facts.length;
  }

  /**
   * Scores every fact against the projected query and returns the best `k`.
   *
   * @param {string} query
   * @param {number} [k=5]
   * @returns {Promise<Array<{ key: *, text: string, index: number, score: number }>>}
   *   Hits sorted by descending score; `index` is the position in `this.facts`.
   */
  async recall(query, k = 5) {
    const q = this._l2(this._matvec(this.W, await this.embed(query)));
    return this.facts
      .map((f, i) => ({ key: f.key, text: f.text, index: i, score: this._dot(f.vec, q) }))
      .sort((a, b) => b.score - a.score)
      .slice(0, k);
  }

  /**
   * Reinforce: make `query` retrieve facts[targetIndex] more strongly by running
   * `steps` updates of W, then persisting W.
   *
   * Does nothing (and does not persist) when `targetIndex` is not an in-range
   * integer. Numeric strings such as "2" are accepted and converted, so an
   * index read back from the DOM trains the intended fact.
   *
   * @param {string} query
   * @param {number|string} targetIndex index into `this.facts`.
   * @param {number} [steps=25]
   * @param {number} [lr=0.3] learning rate.
   * @returns {Promise<void>}
   */
  async teach(query, targetIndex, steps = 25, lr = 0.3) {
    const target = this._toIndex(targetIndex);
    if (target === null) return;
    const qe = await this.embed(query);
    for (let s = 0; s < steps; s++) this._teachStep(qe, target, lr);
    this._saveW();
  }

  /**
   * One in-place update of W for softmax cross-entropy against `target`.
   *
   * The gradient with respect to the normalized projection is propagated
   * straight to W as an outer product with `qe`; the Jacobian of the L2
   * normalization is not applied.
   *
   * @param {ArrayLike<number>} qe raw (unprojected) query embedding.
   * @param {number} target integer index of the fact to reinforce.
   * @param {number} lr learning rate.
   */
  _teachStep(qe, target, lr) {
    const qq = this._l2(this._matvec(this.W, qe));
    const scores = this.facts.map((f) => this._dot(f.vec, qq));
    const p = this._softmax(scores);
    // g = sum_j (p_j - onehot_j) / TEMP * fact_j
    const g = new Float32Array(DIM);
    for (let j = 0; j < this.facts.length; j++) {
      const coef = (p[j] - (j === target ? 1 : 0)) / TEMP;
      const Kj = this.facts[j].vec;
      for (let d = 0; d < DIM; d++) g[d] += coef * Kj[d];
    }
    // W -= lr * g (outer) qe, skipping rows whose gradient entry is zero.
    for (let r = 0; r < DIM; r++) {
      const gr = lr * g[r];
      if (gr === 0) continue;
      const Wr = this.W[r];
      for (let col = 0; col < DIM; col++) Wr[col] -= gr * qe[col];
    }
  }

  /** @returns {number[][]} a plain-array copy of the learned weights. */
  exportW() {
    return this.W.map((r) => Array.from(r));
  }

  /**
   * Replaces W with `arr` and persists it. Ignored unless `arr` is an array of
   * DIM rows that each hold DIM entries.
   *
   * @param {Array<ArrayLike<number>>} arr
   */
  importW(arr) {
    const W = this._parseW(arr);
    if (W) {
      this.W = W;
      this._saveW();
    }
  }

  /** Resets W to the identity matrix and persists it. */
  resetW() {
    this.W = this._eye();
    this._saveW();
  }

  /** @returns {Float32Array[]} a fresh DIM x DIM identity matrix. */
  _eye() {
    const W = [];
    for (let r = 0; r < DIM; r++) {
      const row = new Float32Array(DIM);
      row[r] = 1;
      W.push(row);
    }
    return W;
  }

  /**
   * Reads W from localStorage.
   *
   * @returns {Float32Array[]|null} the stored matrix, or null when nothing
   *   usable is stored.
   */
  _loadW() {
    try {
      const s = localStorage.getItem(this.storageKey);
      return s ? this._parseW(JSON.parse(s)) : null;
    } catch (e) {
      // Deliberate best-effort: unavailable storage or corrupt JSON means
      // "no saved weights", and init() falls back to the identity matrix.
      return null;
    }
  }

  /** Writes W to localStorage with entries rounded to 4 decimal places. */
  _saveW() {
    try {
      localStorage.setItem(
        this.storageKey,
        JSON.stringify(this.W.map((r) => Array.from(r, (x) => Math.round(x * 1e4) / 1e4))),
      );
    } catch (e) {
      // Deliberate best-effort: a failed write must not break recall/teach;
      // W stays valid in memory for this session.
    }
  }

  /**
   * Validates the shape of a candidate weight matrix and copies it.
   *
   * @param {*} a
   * @returns {Float32Array[]|null} DIM rows of DIM floats, or null if `a` is
   *   not an array of DIM rows that each have DIM entries.
   */
  _parseW(a) {
    if (!Array.isArray(a) || a.length !== DIM) return null;
    const rowsOk = a.every(
      (r) => (Array.isArray(r) || ArrayBuffer.isView(r)) && r.length === DIM,
    );
    return rowsOk ? a.map((r) => Float32Array.from(r)) : null;
  }

  /** @returns {boolean} true if `v` is an array / float typed array of DIM entries. */
  _isVec(v) {
    const isFloatList =
      Array.isArray(v) || v instanceof Float32Array || v instanceof Float64Array;
    return isFloatList && v.length === DIM;
  }

  /**
   * Normalizes a caller-supplied fact index.
   *
   * @param {*} value
   * @returns {number|null} an integer in [0, facts.length), or null if invalid.
   */
  _toIndex(value) {
    const n = typeof value === "string" && value.trim() !== "" ? Number(value) : value;
    return Number.isInteger(n) && n >= 0 && n < this.facts.length ? n : null;
  }

  /** Dot product over the first DIM entries of `a` and `b`. */
  _dot(a, b) {
    let s = 0;
    for (let i = 0; i < DIM; i++) s += a[i] * b[i];
    return s;
  }

  /** @returns {Float32Array} W · v. */
  _matvec(W, v) {
    const o = new Float32Array(DIM);
    for (let r = 0; r < DIM; r++) o[r] = this._dot(W[r], v);
    return o;
  }

  /** @returns {Float32Array} `v` scaled to unit length (a zero vector is returned unscaled). */
  _l2(v) {
    let s = 0;
    for (let i = 0; i < DIM; i++) s += v[i] * v[i];
    s = Math.sqrt(s) || 1; // avoid dividing by zero
    const o = new Float32Array(DIM);
    for (let i = 0; i < DIM; i++) o[i] = v[i] / s;
    return o;
  }

  /** Temperature softmax of `s`; the maximum is subtracted first for numerical stability. */
  _softmax(s) {
    let m = -Infinity;
    for (const x of s) if (x > m) m = x;
    let z = 0;
    const p = s.map((x) => {
      const e = Math.exp((x - m) / TEMP);
      z += e;
      return e;
    });
    return p.map((x) => x / z);
  }
}