LJTSG commited on
Commit
78bc0e4
·
verified ·
1 Parent(s): 373a9bd

web-TTT: browser-native test-time-trainable memory toolkit

Browse files
Files changed (4) hide show
  1. README.md +71 -0
  2. demo/facts.json +10 -0
  3. demo/index.html +53 -0
  4. web-ttt.js +105 -0
README.md ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - web-ttt
5
+ - test-time-training
6
+ - in-browser
7
+ - webgpu
8
+ - associative-memory
9
+ - minilm
10
+ - transformers-js
11
+ library_name: transformers.js
12
+ ---
13
+
14
+ # web-TTT — a browser-native, test-time-TRAINABLE memory
15
+
16
+ A tiny method + toolkit for an associative memory that **learns from use, entirely in the browser** — no server, no API, no backend. To our knowledge, the first published browser-trainable **TTT (test-time training)** memory.
17
+
18
+ It is the memory layer behind a fully in-browser AI companion, extracted here as a clean, reusable method you can drop onto **any** corpus (it does **not** contain anyone's private data — just the process + a generic demo).
19
+
20
+ ## The idea
21
+
22
+ - **Embed** each fact with `Xenova/all-MiniLM-L6-v2` (via [transformers.js](https://github.com/huggingface/transformers.js)) → a 384-d vector.
23
+ - **Recall** a query as `score(fact) = cosine( normalize(W · embed(query)) , fact_vec )`, top-K — where **W** is a learnable 384×384 projection matrix (identity at first).
24
+ - **Train at test time:** clicking a result runs ~25 steps of cross-entropy gradient descent on **W**, pulling that query toward that fact. The matrix learns which queries should surface which memories — **on the user's device, in the browser.** `W` persists in `localStorage`, so the memory keeps what you taught it across sessions.
25
+
26
+ That's the whole thing: a frozen embedder + a small matrix you train by example, live, client-side. It runs on WebGPU/WASM and needs nothing but a browser.
27
+
28
+ ## Files
29
+
30
+ - `web-ttt.js` — the toolkit (ES module). `WebTTT` class: `init()`, `load(corpus)`, `recall(query, k)`, `teach(query, targetIndex)`, `exportW()/importW()/resetW()`.
31
+ - `demo/index.html` + `demo/facts.json` — a working demo on a generic corpus.
32
+
33
+ ## Quick start
34
+
35
+ ```js
36
+ import { WebTTT } from "./web-ttt.js";
37
+
38
+ const ttt = new WebTTT({ storageKey: "my_ttt" });
39
+ await ttt.init(); // loads MiniLM (CDN by default)
40
+ await ttt.load([
41
+ { key: "the sun", text: "The Sun is the star at the center of the Solar System." },
42
+ { key: "the moon", text: "The Moon is Earth's only natural satellite." },
43
+ ]);
44
+
45
+ const hits = await ttt.recall("what's at the center of the solar system?", 3);
46
+ // → [{ key: "the sun", score: 0.6x, index: 0, ... }, ...]
47
+
48
+ await ttt.teach("center of the solar system", 0); // reinforce: 25-step W update, persists
49
+ ```
50
+
51
+ Run the demo: serve the folder over `http://localhost` (WebGPU/transformers.js need a real origin, not `file://`) and open `demo/index.html` in Chrome/Edge/Brave.
52
+
53
+ ### Fully offline
54
+
55
+ Pass local asset paths to `init()`:
56
+
57
+ ```js
58
+ await ttt.init({
59
+ transformersUrl: "./vendor/transformers/transformers.min.js",
60
+ localModelPath: "./models/", // contains Xenova/all-MiniLM-L6-v2
61
+ wasmPaths: "./vendor/transformers/", // ort-wasm*.wasm
62
+ });
63
+ ```
64
+
65
+ ## Format / publishing your own TTT
66
+
67
+ A "trained TTT" is just: your corpus (`[{key, text}]`), optionally the precomputed 384-d `vec` per fact, and the learned `W` (384×384, from `exportW()`). Ship those three and any browser can load and keep training it. That's the publishable artifact — **the method, on your data, your choice.**
68
+
69
+ ## License
70
+
71
+ MIT. The embedder (`Xenova/all-MiniLM-L6-v2`) and transformers.js carry their own licenses.
demo/facts.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ { "key": "what is web-ttt", "text": "web-TTT is a browser-native, test-time-trainable associative memory: MiniLM embeddings plus a learnable projection matrix you train by example, in the browser, with no server." },
3
+ { "key": "how recall works", "text": "A query is embedded, projected through the learned matrix W, then matched by cosine similarity against each stored fact's embedding. The top matches are returned." },
4
+ { "key": "how training works", "text": "Click a fact to teach it: 25 steps of cross-entropy gradient descent run on W, pulling that query toward that fact. This is real test-time training, on the user's GPU/CPU." },
5
+ { "key": "where memory lives", "text": "Fact embeddings live in memory; the learned matrix W persists in the browser's localStorage. Nothing leaves the device." },
6
+ { "key": "the sun", "text": "The Sun is the star at the center of the Solar System, a near-perfect ball of hot plasma." },
7
+ { "key": "the moon", "text": "The Moon is Earth's only natural satellite and the brightest object in the night sky after the Sun." },
8
+ { "key": "photosynthesis", "text": "Photosynthesis is how plants convert light, water, and carbon dioxide into glucose and oxygen." },
9
+ { "key": "the ocean", "text": "Oceans cover about 71 percent of Earth's surface and hold the vast majority of its water." }
10
+ ]
demo/index.html ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="utf-8" />
5
+ <meta name="viewport" content="width=device-width, initial-scale=1" />
6
+ <title>web-TTT demo</title>
7
+ <style>
8
+ body { font: 15px/1.5 ui-sans-serif, system-ui, sans-serif; max-width: 760px; margin: 32px auto; padding: 0 16px; color: #1a1a1a; }
9
+ h1 { font-size: 20px; } code { background: #f0f0f0; padding: 1px 5px; border-radius: 4px; }
10
+ input { width: 70%; padding: 8px; } button { padding: 8px 12px; cursor: pointer; }
11
+ .hit { border: 1px solid #ddd; border-left: 4px solid #2f6f67; border-radius: 6px; padding: 8px 10px; margin: 8px 0; }
12
+ .hit b { display:block; } .hit .s { color:#888; font-size:12px; } .teach { color:#2f6f67; cursor:pointer; font-size:12px; }
13
+ #status { color:#888; }
14
+ </style>
15
+ </head>
16
+ <body>
17
+ <h1>web-TTT — browser-trainable memory</h1>
18
+ <p id="status">loading MiniLM…</p>
19
+ <div><input id="q" placeholder="ask something…" /> <button id="go" disabled>recall</button></div>
20
+ <div id="hits"></div>
21
+ <script type="module">
22
+ import { WebTTT } from "../web-ttt.js";
23
+ const status = document.getElementById("q.status") || document.getElementById("status");
24
+ const ttt = new WebTTT({ storageKey: "web_ttt_demo" });
25
+ let lastQuery = "";
26
+ (async () => {
27
+ await ttt.init(); // MiniLM from CDN
28
+ const facts = await (await fetch("./facts.json")).json();
29
+ await ttt.load(facts);
30
+ status.textContent = `ready — ${facts.length} facts embedded. Recall, then click "teach" on a result to reinforce it (trains W in your browser).`;
31
+ document.getElementById("go").disabled = false;
32
+ })();
33
+ async function run() {
34
+ lastQuery = document.getElementById("q").value.trim();
35
+ if (!lastQuery) return;
36
+ const hits = await ttt.recall(lastQuery, 5);
37
+ document.getElementById("hits").innerHTML = hits.map(h =>
38
+ `<div class="hit"><b>${h.key}</b><span class="s">score ${h.score.toFixed(3)}</span>
39
+ <p>${h.text}</p><span class="teach" data-i="${h.index}">teach ↩ (pull this query toward this memory)</span></div>`
40
+ ).join("");
41
+ }
42
+ document.getElementById("go").onclick = run;
43
+ document.getElementById("q").addEventListener("keydown", e => { if (e.key === "Enter") run(); });
44
+ document.getElementById("hits").addEventListener("click", async e => {
45
+ const t = e.target.closest(".teach"); if (!t) return;
46
+ status.textContent = "teaching…";
47
+ await ttt.teach(lastQuery, Number(t.dataset.i));
48
+ status.textContent = "learned — re-running recall.";
49
+ run();
50
+ });
51
+ </script>
52
+ </body>
53
+ </html>
web-ttt.js ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // web-ttt.js — a browser-native, test-time-TRAINABLE associative memory.
2
+ //
3
+ // MiniLM sentence embeddings + a learnable 384x384 projection matrix W.
4
+ // Recall: score(fact) = cosine( normalize(W · embed(query)) , fact_vec ) -> top-K.
5
+ // Train (test-time training): pull a query toward a target fact by cross-entropy
6
+ // gradient descent on W, ~25 steps, ENTIRELY IN THE BROWSER. W persists in localStorage.
7
+ // No server, no API, no backend. The memory learns from use, on the user's device.
8
+ //
9
+ // Usage:
10
+ // import { WebTTT } from "./web-ttt.js";
11
+ // const ttt = new WebTTT({ storageKey: "my_ttt" });
12
+ // await ttt.init(); // loads MiniLM (transformers.js)
13
+ // await ttt.load([{ key, text }, ...]); // embeds your corpus
14
+ // const hits = await ttt.recall("a query", 5);
15
+ // await ttt.teach("a query", hits[0].index); // 25-step W update, then persists
16
+ //
17
+ // Embeddings come from Xenova/all-MiniLM-L6-v2 via transformers.js. By default both
18
+ // load from CDN; pass localModelPath/transformersUrl to run fully offline.
19
+
20
+ const DIM = 384;
21
+ const TEMP = 0.1;
22
+
23
+ export class WebTTT {
24
+ constructor({ storageKey = "web_ttt" } = {}) {
25
+ this.facts = [];
26
+ this.W = null;
27
+ this.embed = null;
28
+ this.storageKey = storageKey;
29
+ }
30
+
31
+ async init({
32
+ transformersUrl = "https://cdn.jsdelivr.net/npm/@xenova/transformers@2.17.2",
33
+ model = "Xenova/all-MiniLM-L6-v2",
34
+ localModelPath = null,
35
+ wasmPaths = null,
36
+ } = {}) {
37
+ const mod = await import(transformersUrl);
38
+ const { pipeline, env } = mod;
39
+ if (localModelPath) { env.allowRemoteModels = false; env.localModelPath = localModelPath; }
40
+ if (wasmPaths) env.backends.onnx.wasm.wasmPaths = wasmPaths;
41
+ const ext = await pipeline("feature-extraction", model);
42
+ this.embed = async (t) =>
43
+ Array.from((await ext(t, { pooling: "mean", normalize: true })).data);
44
+ this.W = this._loadW() || this._eye();
45
+ return this;
46
+ }
47
+
48
+ // corpus: [{ key, text, vec? }]. If vec (384 floats) is provided it's used as-is.
49
+ async load(corpus) {
50
+ this.facts = [];
51
+ for (const f of corpus) {
52
+ const vec = Array.isArray(f.vec) && f.vec.length === DIM
53
+ ? f.vec
54
+ : await this.embed(`${f.key}. ${f.text || ""}`.slice(0, 512));
55
+ this.facts.push({ key: f.key, text: f.text || "", vec });
56
+ }
57
+ return this.facts.length;
58
+ }
59
+
60
+ async recall(query, k = 5) {
61
+ const q = this._l2(this._matvec(this.W, await this.embed(query)));
62
+ return this.facts
63
+ .map((f, i) => ({ key: f.key, text: f.text, index: i, score: this._dot(f.vec, q) }))
64
+ .sort((a, b) => b.score - a.score)
65
+ .slice(0, k);
66
+ }
67
+
68
+ // Reinforce: make `query` retrieve facts[targetIndex] more strongly. Real gradient descent on W.
69
+ async teach(query, targetIndex, steps = 25, lr = 0.3) {
70
+ if (targetIndex == null || targetIndex < 0 || targetIndex >= this.facts.length) return;
71
+ const qe = await this.embed(query);
72
+ for (let s = 0; s < steps; s++) this._teachStep(qe, targetIndex, lr);
73
+ this._saveW();
74
+ }
75
+
76
+ _teachStep(qe, target, lr) {
77
+ const qq = this._l2(this._matvec(this.W, qe));
78
+ const scores = this.facts.map((f) => this._dot(f.vec, qq));
79
+ const p = this._softmax(scores);
80
+ const g = new Float32Array(DIM);
81
+ for (let j = 0; j < this.facts.length; j++) {
82
+ const c = (p[j] - (j === target ? 1 : 0)) / TEMP;
83
+ const Kj = this.facts[j].vec;
84
+ for (let d = 0; d < DIM; d++) g[d] += c * Kj[d];
85
+ }
86
+ for (let r = 0; r < DIM; r++) {
87
+ const gr = lr * g[r];
88
+ if (gr === 0) continue;
89
+ const Wr = this.W[r];
90
+ for (let c = 0; c < DIM; c++) Wr[c] -= gr * qe[c];
91
+ }
92
+ }
93
+
94
+ exportW() { return this.W.map((r) => Array.from(r)); } // the learned weights
95
+ importW(arr) { if (arr && arr.length === DIM) { this.W = arr.map((r) => Float32Array.from(r)); this._saveW(); } }
96
+ resetW() { this.W = this._eye(); this._saveW(); }
97
+
98
+ _eye() { const W = []; for (let r = 0; r < DIM; r++) { const row = new Float32Array(DIM); row[r] = 1; W.push(row); } return W; }
99
+ _loadW() { try { const s = localStorage.getItem(this.storageKey); if (s) { const a = JSON.parse(s); if (a.length === DIM) return a.map((r) => Float32Array.from(r)); } } catch (e) {} return null; }
100
+ _saveW() { try { localStorage.setItem(this.storageKey, JSON.stringify(this.W.map((r) => Array.from(r, (x) => Math.round(x * 1e4) / 1e4)))); } catch (e) {} }
101
+ _dot(a, b) { let s = 0; for (let i = 0; i < DIM; i++) s += a[i] * b[i]; return s; }
102
+ _matvec(W, v) { const o = new Float32Array(DIM); for (let r = 0; r < DIM; r++) o[r] = this._dot(W[r], v); return o; }
103
+ _l2(v) { let s = 0; for (let i = 0; i < DIM; i++) s += v[i] * v[i]; s = Math.sqrt(s) || 1; const o = new Float32Array(DIM); for (let i = 0; i < DIM; i++) o[i] = v[i] / s; return o; }
104
+ _softmax(s) { let m = -Infinity; for (const x of s) if (x > m) m = x; let z = 0; const p = s.map((x) => { const e = Math.exp((x - m) / TEMP); z += e; return e; }); return p.map((x) => x / z); }
105
+ }