Spaces:
Running
Running
| /* LocalAgent — in-browser tool calling on onnxruntime-web (WebGPU + WASM fallback). | |
| * | |
| * The transformer forward pass runs as an ONNX graph emitting `logits` and `hidden`. The GENERABLE | |
| * dispatch (route head -> dense two-tower selector -> pointer-copy args) is ported here from the | |
| * Python pipeline. Bundle contract (see localagent.inference.export): | |
| * model.fp16.onnx inputs: input_ids[int64, 1xT] outputs: logits[1,T,256], hidden[1,T,d] | |
| * dispatch_heads.json { route_head:{weight:[5][d],bias:[5],routes:[5],stop_index}, | |
| * dense_selector:{q_proj_weight:[p][d],q_proj_bias:[p],proj:p, | |
| * tool_matrix:[N][p],tool_names:[N],normalize_query} } | |
| * heads.json { pointer_head:{arg_idx,arg_emb,start_W,end_W}, ... } (args copy) | |
| * meta.json { d_model, markers:{...}, tools:[{name,args,schema}] } (50 tools) | |
| * | |
| * Selection is NOT a fixed-N classifier: the dense selector scores every tool by its description | |
| * embedding, so adding/removing a tool is adding/removing a tool_matrix row. | |
| */ | |
// Module-level runtime state; the nulls and the backend label are filled in by loadBundle().
const MODEL_URL = "model.fp16.onnx"; // ONNX graph passed to ort.InferenceSession.create
let BACKEND = "wasm"; // backend label shown in the UI badge and timing line
let SESSION = null; // inference session, null until loading succeeds
let HEADS = null; // parsed heads.json (pointer head used for argument copying)
let META = null; // parsed meta.json (d_model, markers, tool specs)
let DISPATCH = null; // parsed dispatch_heads.json (route head + dense selector)
// ---- bundle loading -------------------------------------------------------

// Fetch and parse one JSON bundle file. HTTP errors are reported with the file name and
// status instead of letting r.json() fail on an HTML error page with an opaque SyntaxError.
async function fetchJson(url) {
  const r = await fetch(url);
  if (!r.ok) throw new Error(`Failed to fetch ${url}: HTTP ${r.status}`);
  return r.json();
}

// Load heads.json / meta.json / dispatch_heads.json in parallel, then create the ONNX session.
// WebGPU is requested on its own: when several providers are listed together, onnxruntime-web
// can drop an unavailable one and continue on the next, which would leave BACKEND saying
// "webgpu" while the session actually runs on WASM. Trying WebGPU alone makes BACKEND
// reflect the provider the session was really created with.
// Rejects if a bundle file cannot be fetched/parsed or if the WASM session cannot be created.
async function loadBundle() {
  ort.env.wasm.wasmPaths = "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.20.1/dist/";
  [HEADS, META, DISPATCH] = await Promise.all([
    fetchJson("heads.json"),
    fetchJson("meta.json"),
    fetchJson("dispatch_heads.json"),
  ]);
  // navigator.gpu is absent in browsers without WebGPU; skip the doomed attempt there.
  if (typeof navigator !== "undefined" && navigator.gpu) {
    try {
      SESSION = await ort.InferenceSession.create(MODEL_URL, { executionProviders: ["webgpu"] });
      BACKEND = "webgpu";
      return;
    } catch (e) {
      console.warn("WebGPU unavailable, falling back to WASM:", e);
    }
  }
  SESSION = await ort.InferenceSession.create(MODEL_URL, { executionProviders: ["wasm"] });
  BACKEND = "wasm";
}
// ---- byte tokenizer (vocab 256) ------------------------------------------
// Markers are literal strings encoded as UTF-8 bytes — identical to the Python byte tokenizer.
const enc = new TextEncoder();

// UTF-8 bytes of `s` as a plain number array; each byte is one token id (0..255).
function bytesOf(s) {
  return Array.from(enc.encode(s));
}

// Literal text of the named special marker from meta.json (markers carry { text, ids }).
// A missing or malformed marker raises a descriptive Error instead of a bare TypeError
// (or silently rendering the string "undefined" into the prompt).
function mark(name) {
  const marker = META.markers[name];
  if (marker == null || typeof marker.text !== "string") {
    throw new Error(`meta.json is missing marker "${name}"`);
  }
  return marker.text;
}
// Render a user turn the way the model was trained: user marker, query, assistant marker,
// then for every completed step a tool call, its (possibly defaulted) response and a fresh
// assistant marker. Returns the UTF-8 byte ids of the whole string.
function renderContext(query, steps) {
  const history = steps || [];
  let text = `${mark("user")}${query}${mark("assistant")}`;
  for (const step of history) {
    const call = `${mark("tool_call_open")}${step.tool}(${JSON.stringify(step.args)})${mark("tool_call_close")}`;
    const reply = `${mark("tool")}${mark("tool_response_open")}${step.response || "ok"}${mark("tool_response_close")}`;
    text += `${call}${reply}${mark("assistant")}`;
  }
  return bytesOf(text);
}
// ---- model forward --------------------------------------------------------
// Run the ONNX graph on one sequence of byte ids (batch size 1, shape [1, T], int64) and
// resolve to the session's output map (expected to contain `logits` and `hidden`).
async function forward(ids) {
  const tokens = new BigInt64Array(ids.length);
  ids.forEach((id, pos) => {
    tokens[pos] = BigInt(id);
  });
  const feeds = { input_ids: new ort.Tensor("int64", tokens, [1, ids.length]) };
  return SESSION.run(feeds);
}
// ---- generable dispatch: route head -> dense two-tower selector ------------
// Slice the hidden vector of position T-1 (META.d_model values) out of the flat hidden data.
// Typed arrays yield a zero-copy view via subarray; other array-likes are copied.
function lastHidden(hiddenTensor, T) {
  const width = META.d_model;
  const flat = hiddenTensor.data;
  const start = width * (T - 1);
  const end = start + width;
  if (flat.subarray) return flat.subarray(start, end);
  return Array.from(flat).slice(start, end);
}
// Dense layer applied to one vector: out[r] = b[r] + sum_c W[r][c] * x[c].
// The bias is optional (falsy b means no bias); the result is a Float32Array of W.length.
function linrow(W, b, x) {
  const rows = W.length;
  const out = new Float32Array(rows);
  for (let r = 0; r < rows; r += 1) {
    const weights = W[r];
    let acc = b ? b[r] : 0;
    for (let c = 0; c < x.length; c += 1) {
      acc += weights[c] * x[c];
    }
    out[r] = acc;
  }
  return out;
}
// Index of the largest element of an array-like; the earliest index wins ties and an empty
// input yields 0.
function argmax(v) {
  return Array.prototype.reduce.call(v, (best, value, i) => (value > v[best] ? i : best), 0);
}
// Softmax probability of entry i of v, shifting by the maximum first so exp() cannot overflow.
function softmaxAt(v, i) {
  const values = [...v];
  const shift = values.reduce((m, x) => Math.max(m, x), -Infinity);
  const total = values.map((x) => Math.exp(x - shift)).reduce((s, w) => s + w, 0);
  return Math.exp(v[i] - shift) / total;
}
// Decide what to do from the hidden state at the last context position.
//   1. Route head (linear gate over R.routes): if it picks R.stop_index (the `text` route) the
//      model abstains / answers directly.
//   2. Otherwise the dense two-tower selector projects the state (q_proj, optionally
//      L2-normalised) and picks the tool_matrix row with the highest dot product.
// Returns { isStop: true, route, conf } on abstention, else { name, route, conf, isStop: false }.
function dispatchSelect(hiddenTensor, T) {
  const last = lastHidden(hiddenTensor, T);

  const R = DISPATCH.route_head;
  const rl = linrow(R.weight, R.bias, last);
  const ri = argmax(rl);
  if (ri === R.stop_index) return { isStop: true, route: R.routes[ri], conf: softmaxAt(rl, ri) };

  const S = DISPATCH.dense_selector;
  const q = linrow(S.q_proj_weight, S.q_proj_bias, last);
  if (S.normalize_query) {
    let n = 0;
    for (const x of q) n += x * x;
    n = Math.sqrt(n) || 1; // all-zero query: leave it unscaled rather than divide by 0
    for (let i = 0; i < q.length; i++) q[i] /= n;
  }
  // No tools in the bundle: there is nothing to dispatch to, so abstain instead of
  // returning name: undefined.
  if (!S.tool_names.length) return { isStop: true, route: R.routes[ri], conf: 0 };

  let bi = 0;
  let bs = -Infinity;
  for (let j = 0; j < S.tool_names.length; j++) {
    const Tj = S.tool_matrix[j];
    let a = 0;
    for (let i = 0; i < S.proj; i++) a += Tj[i] * q[i];
    if (a > bs) { bs = a; bi = j; }
  }
  // (bs + 1) / 2 presumably maps a cosine similarity in [-1, 1] onto [0, 1]. The score is only
  // bounded that way when q and the tool rows are unit-norm (not checked here), and the UI shows
  // conf as a percentage, so clamp to keep it in [0, 1].
  const conf = Math.min(1, Math.max(0, (bs + 1) / 2));
  return { name: S.tool_names[bi], route: R.routes[ri], conf, isStop: false };
}
// ---- argument grounding via the learned pointer head (port of pointer_head) ----
// q = arg_emb[arg_idx[arg]]; qs = start_W·q; qe = end_W·q
// start = argmax_t hidden[t]·qs; end = argmax_{t>=start} hidden[t]·qe; value = bytes[start..end]

// Matrix-vector product M·v with M given as an array of rows: out[i] = sum_j M[i][j] * v[j].
// The output has one entry per row of M and each row is dotted with all of v, so rectangular
// [o][k]·[k] weights work as well as square [d][d] ones (square inputs give the same result
// as before).
function matvec(M, v) {
  const rows = M.length;
  const cols = v.length;
  const out = new Float32Array(rows);
  for (let i = 0; i < rows; i++) {
    const Mi = M[i];
    let a = 0;
    for (let j = 0; j < cols; j++) a += Mi[j] * v[j];
    out[i] = a;
  }
  return out;
}
// Dot product of row t of the flat row-major buffer H (rows of width d) with the first d
// entries of q.
function dotAt(H, t, d, q) {
  const base = t * d;
  let sum = 0;
  for (let k = 0; k < d; k += 1) {
    sum += H[base + k] * q[k];
  }
  return sum;
}
// Decoder for copied argument spans. With the default `fatal: false` it never throws;
// malformed byte runs decode to U+FFFD instead.
const SPAN_DECODER = new TextDecoder();

// True for a UTF-8 continuation byte (10xxxxxx), i.e. a byte that cannot start a character.
function isUtf8Continuation(b) {
  return (b & 0xc0) === 0x80;
}

// Copy the value of argument `arg` out of the context with the learned pointer head.
//   arg: argument name, looked up in pointer_head.arg_idx
//   ids: context byte ids (the sequence that produced H)
//   H:   flat hidden-state data, T rows of META.d_model values
//   T:   number of positions to scan
// Returns the decoded span, or "" when the pointer head has no embedding for `arg`.
function pointerSpan(arg, ids, H, T) {
  const ph = HEADS.pointer_head;
  const d = META.d_model;
  // Own-property check: arg_idx is a plain JSON object, so names such as "constructor" or
  // "toString" would otherwise resolve to Object.prototype members instead of "unknown".
  if (!Object.prototype.hasOwnProperty.call(ph.arg_idx, arg)) return "";
  const ai = ph.arg_idx[arg];
  if (ai == null) return "";
  const qs = matvec(ph.start_W, ph.arg_emb[ai]);
  const qe = matvec(ph.end_W, ph.arg_emb[ai]);

  let s = 0;
  let sb = -Infinity;
  for (let t = 0; t < T; t++) {
    const v = dotAt(H, t, d, qs);
    if (v > sb) { sb = v; s = t; }
  }
  // The end pointer only searches t >= start, so the span is never reversed.
  let e = s;
  let eb = -Infinity;
  for (let t = s; t < T; t++) {
    const v = dotAt(H, t, d, qe);
    if (v > eb) { eb = v; e = t; }
  }
  // Pointers land on byte positions; widen the span to whole UTF-8 characters so a multi-byte
  // character at either edge is copied intact instead of decoding to U+FFFD.
  while (s > 0 && isUtf8Continuation(ids[s])) s--;
  while (e + 1 < ids.length && isUtf8Continuation(ids[e + 1])) e++;
  return SPAN_DECODER.decode(new Uint8Array(ids.slice(s, e + 1)));
}
// Fill in every argument listed in meta.json for `tool` by pointer-copying a span of the
// context bytes. Arguments without a pointer-head embedding get ""; unknown tools yield {}.
function groundArgs(tool, ids, hiddenTensor, T) {
  const grounded = {};
  const spec = (META.tools || []).find((candidate) => candidate.name === tool);
  if (!spec) return grounded;
  const flat = hiddenTensor.data;
  for (const name of spec.args || []) {
    const hasPointer = HEADS.pointer_head.arg_idx[name] != null;
    grounded[name] = hasPointer ? pointerSpan(name, ids, flat, T) : "";
  }
  return grounded;
}
// ---- single grounded call -------------------------------------------------
// One forward pass over the bare query (no prior steps). Resolves to either an abstention
// { abstain, route, conf, ms } or a call { tool, route, args, conf, ms }.
// `ms` covers the forward pass and tool selection; argument grounding happens after timing.
async function callOnce(query) {
  const ids = renderContext(query, []);
  const started = performance.now();
  const { hidden } = await forward(ids);
  const T = ids.length;
  const sel = dispatchSelect(hidden, T);
  const ms = performance.now() - started;
  if (!sel.isStop) {
    return { tool: sel.name, route: sel.route, args: groundArgs(sel.name, ids, hidden, T), conf: sel.conf, ms };
  }
  return { abstain: true, route: sel.route, conf: sel.conf, ms };
}
// ---- planner rollout ------------------------------------------------------
// Greedy multi-step plan. Each iteration re-renders the context with all steps taken so far,
// selects the next tool, grounds its args and records a simulated response. It stops when the
// route head abstains or after `maxSteps` calls. `ms` is the wall time of the whole rollout.
async function planRollout(query, maxSteps = 4) {
  const started = performance.now();
  const steps = [];
  for (let n = 0; n < maxSteps; n += 1) {
    const ids = renderContext(query, steps);
    const { hidden } = await forward(ids);
    const choice = dispatchSelect(hidden, ids.length);
    if (choice.isStop) break;
    const args = groundArgs(choice.name, ids, hidden, ids.length);
    steps.push({
      tool: choice.name,
      route: choice.route,
      args,
      conf: choice.conf,
      response: simResponse(choice.name, args),
    });
  }
  return { steps, ms: performance.now() - started };
}
// A compact simulated tool response so downstream steps have context. File-style tools echo
// their first argument (or "ok"), lookup-style tools answer "result: <first argument>", and
// every other tool answers "ok".
function simResponse(tool, args) {
  const firstArg = () => Object.values(args)[0];
  const echoesInput = /read_file|grep|list_dir|find/.test(tool);
  if (echoesInput) return firstArg() || "ok";
  const looksUp = /search|news|http|open_url|define/.test(tool);
  return looksUp ? "result: " + (firstArg() || "") : "ok";
}
// ---- UI -------------------------------------------------------------------
const $ = (id) => document.getElementById(id);

// Update the status area: CSS modifier class and message text, plus the backend badge
// (un-hidden and upper-cased) only when a backend name is supplied.
function setStatus(cls, text, backend) {
  $("status").className = `status ${cls}`;
  $("status-text").textContent = text;
  const badge = $("backend-badge");
  if (!backend) return;
  badge.hidden = false;
  badge.textContent = backend.toUpperCase();
}
// HTML entity for each character that is significant in element content or quoted attributes.
const HTML_ESCAPES = { "&": "&amp;", "<": "&lt;", ">": "&gt;", '"': "&quot;", "'": "&#39;" };

// Escape arbitrary text for safe interpolation into an innerHTML template.
function escapeHtml(value) {
  return String(value).replace(/[&<>"']/g, (c) => HTML_ESCAPES[c]);
}

// Build the DOM node for one result: an abstention line, or an (optionally numbered) tool call
// with its arguments pretty-printed as JSON. Tool names, routes and argument values are escaped
// because argument values are copied verbatim from the user's query by the pointer head.
function renderCall(step, idx) {
  const div = document.createElement("div");
  div.className = "call" + (step.abstain ? " abstain" : "");
  const conf = step.conf != null ? `<span class="conf">${(step.conf * 100).toFixed(0)}%</span>` : "";
  const route = step.route ? `<span class="route">${escapeHtml(step.route)}</span>` : "";
  if (step.abstain) {
    div.innerHTML = `${conf}${route}<span class="tool">— abstains (no tool needed)</span>`;
  } else {
    const ix = idx != null ? `<span class="step-index">${idx + 1}.</span>` : "";
    div.innerHTML = `${conf}${route}${ix}<span class="tool">${escapeHtml(step.tool)}</span>` +
      `<pre>${escapeHtml(JSON.stringify(step.args, null, 2))}</pre>`;
  }
  return div;
}
// True while a query is being processed. The Run button is disabled during a run, but the
// Cmd/Ctrl+Enter shortcut and the example chips call run() directly, so without this flag they
// could start overlapping runs that race to overwrite the result panel.
let runInFlight = false;

// Run the prompt either as a single grounded call or as a multi-step plan (per the plan-mode
// checkbox) and render the calls, timing and backend into #result. Errors are shown in the
// panel; the Run button is re-enabled afterwards in every case.
async function run() {
  const query = $("prompt").value.trim();
  if (!query || !SESSION || runInFlight) return;
  runInFlight = true;
  $("run").disabled = true;
  const res = $("result");
  res.hidden = false;
  res.innerHTML = '<div class="call"><span class="tool">…thinking</span></div>';
  try {
    if ($("plan-mode").checked) {
      const { steps, ms } = await planRollout(query);
      res.innerHTML = "";
      if (!steps.length) res.appendChild(renderCall({ abstain: true }));
      steps.forEach((s, i) => res.appendChild(renderCall(s, i)));
      const t = document.createElement("div");
      t.className = "timing";
      t.textContent = `${steps.length} step(s) · ${ms.toFixed(0)} ms · ${BACKEND}`;
      res.appendChild(t);
    } else {
      const out = await callOnce(query);
      res.innerHTML = "";
      res.appendChild(renderCall(out));
      const t = document.createElement("div");
      t.className = "timing";
      t.textContent = `${out.ms.toFixed(0)} ms · ${BACKEND}`;
      res.appendChild(t);
    }
  } catch (e) {
    // The error text is assigned via textContent rather than interpolated into the HTML
    // template, so it is always displayed literally and cannot inject markup.
    res.innerHTML = '<div class="call abstain"><span class="tool">error</span><pre></pre></div>';
    res.querySelector("pre").textContent = String(e);
  } finally {
    runInFlight = false;
    $("run").disabled = false;
  }
}
// Attach the event handlers:
//   - the Run button and Cmd/Ctrl+Enter in the prompt both trigger run();
//   - each .chip copies its text into the prompt, sets plan mode from data-plan="1" and runs.
function wireUI() {
  $("run").addEventListener("click", run);
  $("prompt").addEventListener("keydown", (event) => {
    const submitChord = event.key === "Enter" && (event.ctrlKey || event.metaKey);
    if (submitChord) run();
  });
  for (const chip of document.querySelectorAll(".chip")) {
    chip.addEventListener("click", () => {
      $("prompt").value = chip.textContent;
      $("plan-mode").checked = chip.dataset.plan === "1";
      run();
    });
  }
}
// Entry point. Handlers are wired immediately. Then the bundle is loaded and the Run button
// enabled; any failure along the way is logged and shown in the status area.
(async () => {
  wireUI();
  try {
    const loadingMessage = "Loading model… (first load downloads & caches the weights)";
    setStatus("loading", loadingMessage);
    await loadBundle();
    const readyMessage = "Model ready — runs locally in your browser.";
    setStatus("ready", readyMessage, BACKEND);
    $("run").disabled = false;
  } catch (err) {
    console.error(err);
    setStatus("error", "Failed to load the model bundle: " + err.message);
  }
})();