// LocateAnything-3B — fully in-browser WebGPU detection.
// Pipeline: preprocess image -> vision ONNX -> visual_features; build prompt input_ids;
// INT4 embedding gather + visual splice -> inputs_embeds; KV-cache autoregressive decode
// with the INT4 language graph; parse <ref>/<box> -> draw. No server inference.
import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.23.0/dist/ort.webgpu.min.mjs";
import { AutoTokenizer } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.8.1";

const REPO = "Reza2kn/LocateAnything-3B-ONNX-WebGPU-INT4";
const BASE = `https://huggingface.co/${REPO}/resolve/main`;
const VISION_URL = `${BASE}/onnx/vision_mlp_int4.onnx`; // INT4 (~250MB) — fp32 1.73GB stalls on download
const VISION_DATA = "vision_mlp_int4.onnx.data";
const VISION_DATA_URL = `${BASE}/onnx/${VISION_DATA}`;
const LANG_URL = `${BASE}/onnx/language_tail_kv_int4.onnx`;
const LANG_DATA = "language_tail_kv_int4.onnx.data";
const LANG_DATA_URL = `${BASE}/onnx/${LANG_DATA}`;
const EMB_PACKED_URL = `${BASE}/onnx/embed_tokens_int4_packed.bin`;
const EMB_SCALES_URL = `${BASE}/onnx/embed_tokens_int4_scales.bin`;
const EMB_META_URL = `${BASE}/onnx/embed_tokens_int4_meta.json`;

// token ids / model constants (from config + generate_utils)
const IMG_CONTEXT = 151665; // image placeholder token (image_token_index)
const IM_END = 151645; // stop token
const N_LAYERS = 36;
const KV_HEADS = 2;
const HEAD_DIM = 128;
const PATCH = 14;
const MERGE = 2;
const IN_TOKEN_LIMIT = 256;
const MEAN = 0.5;
const STD = 0.5;

const $ = (id) => document.getElementById(id);
const logEl = $("log");

/** Append a line to the on-page log (kept scrolled to the bottom) and mirror it to the console. */
function log(m) {
  logEl.textContent += m + "\n";
  logEl.scrollTop = logEl.scrollHeight;
  console.log(m);
}

/** Set a status badge's text and CSS class ("badge" plus an optional modifier class). */
function setBadge(el, text, cls) {
  el.textContent = text;
  el.className = "badge" + (cls ? " " + cls : "");
}

// Model state, assigned by loadAll(); curImage is the currently selected image.
let tokenizer = null;
let visionSess = null;
let langSess = null;
let embPacked = null;
let embScales = null;
let embMeta = null;
let curImage = null; // HTMLImageElement
const out_names_cache = {};

// ---------- fp16 -> fp32 ----------
/** Decode one IEEE-754 half-precision value, given as its uint16 bit pattern, to a JS number. */
function f16to32(h) {
  const sign = h & 0x8000 ? -1 : 1;
  const exp = (h & 0x7c00) >> 10;
  const frac = h & 0x03ff;
  if (exp === 0) return sign * Math.pow(2, -14) * (frac / 1024); // zero / subnormal
  if (exp === 0x1f) return frac ? NaN : sign * Infinity;
  return sign * Math.pow(2, exp - 15) * (1 + frac / 1024);
}

// ---------- model loading ----------
/**
 * Fetch a whole resource into memory.
 * @param {string} url
 * @param {string} label - name used in the error message
 * @returns {Promise<Uint8Array>}
 * @throws {Error} when the response status is not OK
 */
async function fetchBuf(url, label) {
  const r = await fetch(url);
  if (!r.ok) throw new Error(`fetch ${label} failed: ${r.status}`);
  return new Uint8Array(await r.arrayBuffer());
}

const sleep = (ms) => new Promise((r) => setTimeout(r, ms));

/**
 * fetch() guarded by a stall watchdog: the request is aborted when no response headers or no
 * body bytes arrive for `stallMs` milliseconds. Only 200 and 206 responses are accepted.
 * @param {string} url
 * @param {RequestInit} [opts] - extra fetch options (any `signal` in here is replaced)
 * @param {number} [stallMs=30000] - maximum silence before aborting
 * @returns {Promise<{buf: Uint8Array, headers: Headers}>} the full body and the response headers
 * @throws {Error} on a non-200/206 status, a stall, or a network error
 */
async function fetchAbortable(url, opts, stallMs = 30000) {
  const ctrl = new AbortController();
  let timer;
  const rearm = () => {
    clearTimeout(timer);
    timer = setTimeout(() => ctrl.abort(new Error(`stalled: no data for ${stallMs}ms`)), stallMs);
  };
  let completed = false;
  // Armed before fetch() so a request whose response headers never arrive is aborted too.
  rearm();
  try {
    const r = await fetch(url, { ...opts, signal: ctrl.signal });
    if (!(r.status === 200 || r.status === 206)) throw new Error(`status ${r.status}`);
    const reader = r.body.getReader();
    const chunks = [];
    let got = 0;
    for (;;) {
      const { done, value } = await reader.read();
      if (done) break;
      rearm();
      chunks.push(value);
      got += value.length;
    }
    const buf = new Uint8Array(got);
    let o = 0;
    for (const c of chunks) {
      buf.set(c, o);
      o += c.length;
    }
    completed = true;
    return { buf, headers: r.headers };
  } finally {
    clearTimeout(timer);
    // On any early exit (bad status, read error), abort to release the connection and unread body.
    if (!completed) ctrl.abort();
  }
}

// Chunked Range download: small pieces (retried independently) so no single long-lived
// connection can stall the whole file. HF CDN supports range requests.
/**
 * Download `url` into memory using sequential HTTP Range requests of `chunk` bytes.
 * Every request runs under fetchAbortable's stall watchdog and is retried on its own, so one
 * stuck connection costs at most one piece rather than the whole file. Logs roughly every 10%.
 * @param {string} url - resource to fetch
 * @param {string} label - human-readable name used in log lines
 * @param {number} [chunk] - bytes per Range request (default 48 MiB)
 * @returns {Promise<Uint8Array>} the complete file
 * @throws {Error} after 5 failed attempts at the first request or at any later piece
 */
async function fetchBufProgress(url, label, chunk = 48 * 1024 * 1024) {
  const t = performance.now();
  const secs = () => ((performance.now() - t) / 1000).toFixed(1);

  // The first request also discovers the total size from Content-Range ("bytes 0-N/TOTAL").
  let total = 0;
  let first;
  for (let tr = 0; ; tr++) {
    try {
      first = await fetchAbortable(url, { headers: { Range: `bytes=0-${chunk - 1}` } });
      const cr = first.headers.get("content-range");
      total = cr ? +cr.split("/")[1] : (+first.headers.get("content-length") || first.buf.length);
      break;
    } catch (e) {
      if (tr >= 4) throw e;
      log(` ${label} init retry ${tr + 1}…`);
      await sleep(1200);
    }
  }
  // Everything arrived in one response: small file, unknown size, or a 200 that ignored Range.
  if (!total || total <= first.buf.length) {
    log(` ${label} downloaded ${(first.buf.length / 1e6) | 0}MB in ${secs()}s`);
    return first.buf;
  }

  const buf = new Uint8Array(total);
  buf.set(first.buf, 0);
  let off = first.buf.length;
  let lastPct = -1;
  while (off < total) {
    const end = Math.min(off + chunk, total) - 1;
    const want = end - off + 1;
    for (let tr = 0; ; tr++) {
      try {
        const { buf: part } = await fetchAbortable(url, { headers: { Range: `bytes=${off}-${end}` } });
        // An empty body would never advance `off` (endless loop); an oversized one means the
        // server ignored Range, so the bytes would be written at the wrong offset.
        if (part.length === 0 || part.length > want) {
          throw new Error(`${label}: got ${part.length} bytes for a ${want}-byte range`);
        }
        buf.set(part, off);
        off += part.length;
        break;
      } catch (e) {
        if (tr >= 4) throw e;
        await sleep(1000);
      }
    }
    const pct = Math.floor((off / total) * 100);
    if (pct >= lastPct + 10) {
      lastPct = pct;
      log(` ${label}: ${pct}% (${(off / 1e6) | 0}MB)`);
    }
  }
  log(` ${label} downloaded ${(total / 1e6) | 0}MB in ${secs()}s`);
  return buf;
}

/**
 * Download and compile both ONNX sessions, then load the tokenizer and the INT4 embedding table.
 * Fills the module-level model state and enables the Run button.
 */
async function loadAll() {
  setBadge($("load"), "loading…", "warn");
  $("prog").style.display = "block";
  // webgpu first; wasm as fallback so unsupported ops/devices degrade instead of hard-failing.
  const sessOpts = { executionProviders: ["webgpu", "wasm"], graphOptimizationLevel: "all" };
  const secsSince = (from) => ((performance.now() - from) / 1000).toFixed(1);
  let t;

  // Create the ONNX sessions FIRST (before transformers.js), fetching buffers ourselves so we
  // can see download vs. compile timing and avoid ort's internal URL fetch hanging on redirects.
  log("downloading vision model (INT4, ~250MB)…");
  const visGraph = await fetchBufProgress(VISION_URL, "vision graph");
  const visData = await fetchBufProgress(VISION_DATA_URL, "vision data");
  $("prog").value = 30;
  log("compiling vision session…");
  t = performance.now();
  visionSess = await ort.InferenceSession.create(visGraph, {
    ...sessOpts,
    externalData: [{ path: VISION_DATA, data: visData }],
  });
  log(`vision session ready in ${secsSince(t)}s`);
  $("prog").value = 50;

  log("downloading INT4 language model (~1.7GB)…");
  const langData = await fetchBufProgress(LANG_DATA_URL, "language data");
  const langGraph = await fetchBufProgress(LANG_URL, "language graph");
  $("prog").value = 80;
  log("compiling language session…");
  t = performance.now();
  langSess = await ort.InferenceSession.create(langGraph, {
    ...sessOpts,
    externalData: [{ path: LANG_DATA, data: langData }],
  });
  log(`language session ready in ${secsSince(t)}s`);
  out_names_cache.lang = langSess.outputNames;
  $("prog").value = 90;

  log("loading tokenizer + INT4 embedding table…");
  tokenizer = await AutoTokenizer.from_pretrained(REPO);
  // fetchBuf checks the HTTP status; a bare fetch().json() would try to parse an error page.
  embMeta = JSON.parse(new TextDecoder().decode(await fetchBuf(EMB_META_URL, "embed meta")));
  embPacked = await fetchBufProgress(EMB_PACKED_URL, "embed packed"); // uint8 [vocab, hidden/2]
  const scalesBytes = await fetchBufProgress(EMB_SCALES_URL, "embed scales"); // fp16
  // Respect byteOffset/byteLength in case the view does not start at the buffer's origin.
  const sv = new DataView(scalesBytes.buffer, scalesBytes.byteOffset, scalesBytes.byteLength);
  embScales = new Float32Array(scalesBytes.byteLength >> 1);
  for (let i = 0; i < embScales.length; i++) embScales[i] = f16to32(sv.getUint16(i * 2, true));
  log(`embedding: vocab=${embMeta.vocab} hidden=${embMeta.hidden} block=${embMeta.block_size}`);

  $("prog").value = 100;
  $("prog").style.display = "none";
  setBadge($("load"), "model ready", "ok");
  $("run").disabled = false;
  log("ready.");
}

// ---------- image preprocessing (mirrors LocateAnythingImageProcessor) ----------
/**
 * Resize the image so its whole-patch count is at most IN_TOKEN_LIMIT. Round each side up to a
 * multiple of MERGE*PATCH, then cut it into normalized PATCHxPATCH RGB patches.
 * The image is stretched (not letterboxed) onto the rounded canvas.
 * @param {HTMLImageElement} img
 * @returns {{pv: Float32Array, gh: number, gw: number, nPatches: number, tw: number, th: number}}
 *   pv is [nPatches, 3, PATCH, PATCH], patches in row-major order (gh outer, gw inner).
 */
function preprocess(img) {
  let w = img.naturalWidth;
  let h = img.naturalHeight;
  const wholePatches = (a, b) => Math.floor(a / PATCH) * Math.floor(b / PATCH);
  // rescale to <= in_token_limit patches
  if (wholePatches(w, h) > IN_TOKEN_LIMIT) {
    const scale = Math.sqrt(IN_TOKEN_LIMIT / wholePatches(w, h));
    w = Math.floor(w * scale);
    h = Math.floor(h * scale);
  }
  // round up to a multiple of MERGE*PATCH; at least one unit so extreme aspect ratios that
  // floor a side to 0 do not produce a zero-size canvas (getImageData would throw).
  const unit = MERGE * PATCH;
  const tw = Math.max(unit, Math.ceil(w / unit) * unit);
  const th = Math.max(unit, Math.ceil(h / unit) * unit);
  const c = document.createElement("canvas");
  c.width = tw;
  c.height = th;
  const cx = c.getContext("2d");
  cx.imageSmoothingEnabled = true;
  cx.imageSmoothingQuality = "high";
  cx.drawImage(img, 0, 0, tw, th);
  const data = cx.getImageData(0, 0, tw, th).data; // RGBA, row-major
  const gh = th / PATCH;
  const gw = tw / PATCH;
  const nPatches = gh * gw;
  const pv = new Float32Array(nPatches * 3 * PATCH * PATCH);
  let p = 0;
  for (let py = 0; py < gh; py++) {
    for (let px = 0; px < gw; px++) {
      for (let ch = 0; ch < 3; ch++) {
        for (let yy = 0; yy < PATCH; yy++) {
          for (let xx = 0; xx < PATCH; xx++) {
            const sx = px * PATCH + xx;
            const sy = py * PATCH + yy;
            const v = data[(sy * tw + sx) * 4 + ch] / 255;
            pv[p++] = (v - MEAN) / STD; // -> 2v-1
          }
        }
      }
    }
  }
  return { pv, gh, gw, nPatches, tw, th };
}

// ---------- prompt build ----------
/**
 * Build prompt token ids with N = gh*gw/MERGE^2 image-placeholder tokens (IMG_CONTEXT).
 * The placeholder ids are spliced in directly between the separately tokenized prefix and
 * suffix, so this does not depend on the placeholder's text form. Tokenizers split around
 * special tokens, so the result matches tokenizing the joined string.
 * NOTE(review): if the reference processor wraps the image tokens in start/end markers,
 * their ids must be added around the placeholder run here.
 * @returns {Promise<{ids: number[], N: number}>}
 */
async function buildInputIds(cat, gh, gw) {
  const N = Math.floor((gh * gw) / (MERGE * MERGE));
  const prefix = `<|im_start|>system\nYou are a helpful assistant.\n<|im_end|>\n<|im_start|>user\n`;
  const suffix = `Locate all the instances that matches the following description: ${cat}.<|im_end|>\n<|im_start|>assistant\n`;
  const encode = async (s) => {
    const enc = await tokenizer(s, { add_special_tokens: false });
    return Array.from(enc.input_ids.data, (x) => Number(x));
  };
  const ids = [...(await encode(prefix)), ...new Array(N).fill(IMG_CONTEXT), ...(await encode(suffix))];
  return { ids, N };
}

// ---------- INT4 embedding gather for one token -> Float32Array(hidden) ----------
/**
 * Dequantize row `tokenId` of the INT4 embedding table into dst[off .. off + hidden).
 * Each packed byte holds two 4-bit values: low nibble = even column, high nibble = odd column.
 * value = (q - zero_point) * scale[tokenId, column / block_size].
 */
function gatherEmbed(tokenId, dst, off) {
  const H = embMeta.hidden;
  const B = embMeta.block_size;
  const NG = embMeta.n_groups;
  const ZP = embMeta.zero_point;
  const packedRow = tokenId * (H / 2);
  const scaleRow = tokenId * NG;
  for (let j = 0; j < H; j += 2) {
    const byte = embPacked[packedRow + (j >> 1)];
    const lo = byte & 0x0f;
    const hi = (byte >> 4) & 0x0f;
    const g0 = (j / B) | 0;
    const g1 = ((j + 1) / B) | 0;
    dst[off + j] = (lo - ZP) * embScales[scaleRow + g0];
    dst[off + j + 1] = (hi - ZP) * embScales[scaleRow + g1];
  }
}

// ---------- main detection ----------
/** Zero-length past_key_i / past_value_i tensors for every layer (the prefill step has no cache). */
function emptyPastKV() {
  const feeds = {};
  for (let i = 0; i < N_LAYERS; i++) {
    feeds[`past_key_${i}`] = new ort.Tensor("float32", new Float32Array(0), [1, KV_HEADS, 0, HEAD_DIM]);
    feeds[`past_value_${i}`] = new ort.Tensor("float32", new Float32Array(0), [1, KV_HEADS, 0, HEAD_DIM]);
  }
  return feeds;
}

/**
 * Build the prompt's inputs_embeds (flat [ids.length * H]). Visual feature rows are copied, in
 * order, into every IMG_CONTEXT position; all other tokens come from the INT4 embedding table.
 * @throws {Error} when the number of IMG_CONTEXT tokens differs from the number of visual rows
 *   (otherwise image embeddings would silently be left zero or misaligned).
 */
function buildPromptEmbeds(ids, vdata, H) {
  const nVis = Math.floor(vdata.length / H);
  const nImg = ids.reduce((n, id) => n + (id === IMG_CONTEXT ? 1 : 0), 0);
  if (nImg !== nVis) {
    throw new Error(`prompt has ${nImg} image tokens but vision produced ${nVis} feature rows`);
  }
  const embeds = new Float32Array(ids.length * H);
  let visIdx = 0;
  for (let i = 0; i < ids.length; i++) {
    if (ids[i] === IMG_CONTEXT) {
      embeds.set(vdata.subarray(visIdx * H, (visIdx + 1) * H), i * H);
      visIdx++;
    } else {
      gatherEmbed(ids[i], embeds, i * H);
    }
  }
  return embeds;
}

/**
 * Run detection on curImage for the category typed in #cat. Pipeline: vision encoder, prompt
 * embeds, KV-cache prefill, greedy decode up to #mnt tokens (stopping at IM_END), then parse
 * and draw. The Run button is disabled for the duration and always re-enabled.
 */
async function detect() {
  if (!curImage) {
    log("pick or upload an image first.");
    return;
  }
  $("run").disabled = true;
  try {
    const cat = $("cat").value.trim() || "object";
    const maxNew = parseInt($("mnt").value, 10);
    const t0 = performance.now();
    const secsSince = (from) => ((performance.now() - from) / 1000).toFixed(1);

    // 1) preprocess + vision
    const { pv, gh, gw, nPatches } = preprocess(curImage);
    log(`image grid ${gh}x${gw} (${nPatches} patches)`);
    const vOut = await visionSess.run({
      pixel_values: new ort.Tensor("float32", pv, [nPatches, 3, PATCH, PATCH]),
      image_grid_hws: new ort.Tensor("int64", BigInt64Array.from([BigInt(gh), BigInt(gw)]), [1, 2]),
    });
    const visual = vOut[visionSess.outputNames[0]]; // [N,2048] float32
    const H = embMeta.hidden;
    log(`vision -> visual_features ${visual.dims.join("x")} (${secsSince(t0)}s)`);

    // 2) prompt
    const { ids, N } = await buildInputIds(cat, gh, gw);
    log(`prompt tokens: ${ids.length} (image tokens N=${N})`);

    // 3) prompt inputs_embeds = INT4 gather + visual splice at IMG_CONTEXT positions
    const L = ids.length;
    const embeds = buildPromptEmbeds(ids, visual.data, H);

    // 4) KV-cache prefill over the whole prompt
    let res = await langSess.run({
      input_ids: new ort.Tensor("int64", BigInt64Array.from(ids, (x) => BigInt(x)), [1, L]),
      inputs_embeds: new ort.Tensor("float32", embeds, [1, L, H]),
      attention_mask: new ort.Tensor("int64", new BigInt64Array(L).fill(1n), [1, L]),
      position_ids: new ort.Tensor("int64", BigInt64Array.from(ids, (_, i) => BigInt(i)), [1, L]),
      ...emptyPastKV(),
    });
    const V = res["logits"].dims[2];
    let next = argmaxLast(res["logits"].data, V);
    const gen = [next];
    log(`prefill done (${secsSince(t0)}s), decoding…`);

    // 5) greedy decode, one token per step; each step's present_* becomes the next step's past_*
    let pastLen = L;
    const tDec = performance.now();
    for (let step = 0; step < maxNew - 1; step++) {
      if (next === IM_END) break;
      const emb1 = new Float32Array(H);
      gatherEmbed(next, emb1, 0);
      const feeds = {
        input_ids: new ort.Tensor("int64", BigInt64Array.from([BigInt(next)]), [1, 1]),
        inputs_embeds: new ort.Tensor("float32", emb1, [1, 1, H]),
        attention_mask: new ort.Tensor("int64", new BigInt64Array(pastLen + 1).fill(1n), [1, pastLen + 1]),
        position_ids: new ort.Tensor("int64", BigInt64Array.from([BigInt(pastLen)]), [1, 1]),
      };
      for (let i = 0; i < N_LAYERS; i++) {
        feeds[`past_key_${i}`] = res[`present_key_${i}`];
        feeds[`past_value_${i}`] = res[`present_value_${i}`];
      }
      res = await langSess.run(feeds);
      next = argmaxLast(res["logits"].data, V);
      gen.push(next);
      pastLen += 1;
      if (step % 8 === 0) {
        $("raw").textContent = tokenizer.decode(gen, { skip_special_tokens: false });
        await new Promise((r) => setTimeout(r)); // yield so the page can repaint
      }
    }
    const decS = (performance.now() - tDec) / 1000;
    const text = tokenizer.decode(gen, { skip_special_tokens: false });
    $("raw").textContent = text;
    log(`decoded ${gen.length} tokens in ${decS.toFixed(1)}s (${(gen.length / decS).toFixed(2)} tok/s)`);

    // 6) parse + draw
    const dets = parseBoxes(text);
    log(`detections: ${dets.length}`);
    drawResult(dets);
  } finally {
    $("run").disabled = false;
  }
}

/** Index of the largest value in the last length-V row of `arr` (logits [1, T, V]). */
function argmaxLast(arr, V) {
  const base = arr.length - V;
  let best = 0;
  let bv = -Infinity;
  for (let i = 0; i < V; i++) {
    const v = arr[base + i];
    if (v > bv) {
      bv = v;
      best = i;
    }
  }
  return best;
}

// Grounding output: <ref>label</ref> followed by one or more <box>…</box> groups whose numbers
// are read as x1,y1,x2,y2 quads (drawResult treats them as 0–1000 normalized coordinates).
// Hoisted globals are only used through matchAll/match, which never leave lastIndex dirty.
const REF_BOX_RE = /<ref>(.*?)<\/ref>((?:\s*<box>.*?<\/box>)+)/gis;
const BOX_RE = /<box>(.*?)<\/box>/gis;
const NUM_RE = /-?\d+\.?\d*/g;

/** Split the numbers inside one <box> into 4-number boxes; [] unless the count is a multiple of 4. */
function boxesIn(s) {
  const nums = (s.match(NUM_RE) || []).map(Number);
  if (nums.length === 0 || nums.length % 4 !== 0) return [];
  const boxes = [];
  for (let i = 0; i < nums.length; i += 4) boxes.push(nums.slice(i, i + 4));
  return boxes;
}

/**
 * Extract detections from the decoded text.
 * @returns {{label: string, box: number[]}[]} labelled boxes; if none are found, unlabelled
 *   bare <box> entries instead.
 */
function parseBoxes(text) {
  const out = [];
  for (const [, rawLabel, boxesText] of text.matchAll(REF_BOX_RE)) {
    const label = rawLabel.trim();
    for (const [, inner] of boxesText.matchAll(BOX_RE)) {
      for (const box of boxesIn(inner)) out.push({ label, box });
    }
  }
  // bare boxes without a preceding ref
  if (out.length === 0) {
    for (const [, inner] of text.matchAll(BOX_RE)) {
      for (const box of boxesIn(inner)) out.push({ label: "", box });
    }
  }
  return out;
}

const COLORS = ["#0891b2", "#dc2626", "#16a34a", "#2563eb", "#d97706", "#9333ea"];

/** Redraw curImage on #cv with each detection's box (0–1000 coords scaled to pixels) and label. */
function drawResult(dets) {
  const cv = $("cv");
  const ctx = cv.getContext("2d");
  const W = curImage.naturalWidth;
  const Ht = curImage.naturalHeight;
  cv.width = W;
  cv.height = Ht;
  ctx.drawImage(curImage, 0, 0, W, Ht);
  ctx.lineWidth = Math.max(2, Math.round(W / 320));
  ctx.font = `${Math.max(12, Math.round(W / 45))}px sans-serif`;
  dets.forEach((d, i) => {
    const col = COLORS[i % COLORS.length];
    const [x1, y1, x2, y2] = d.box;
    const rx1 = (x1 * W) / 1000;
    const ry1 = (y1 * Ht) / 1000;
    const rx2 = (x2 * W) / 1000;
    const ry2 = (y2 * Ht) / 1000;
    ctx.strokeStyle = col;
    ctx.fillStyle = col;
    ctx.strokeRect(rx1, ry1, rx2 - rx1, ry2 - ry1);
    if (d.label) {
      ctx.fillText(d.label, rx1 + 3, Math.max(ry1 - 4, 12));
    }
  });
}

// ---------- wiring ----------
// Most recently requested image; a slower, earlier load that finishes later is ignored.
let pendingImage = null;

/**
 * Load `src` as the current image and paint it on #cv; highlight sample element `el`, if given.
 * blob: URLs (from file uploads) are revoked once loading finishes, so they do not leak.
 */
function setImageFromSrc(src, el) {
  const img = new Image();
  pendingImage = img;
  img.crossOrigin = "anonymous";
  const releaseBlob = () => {
    if (src.startsWith("blob:")) URL.revokeObjectURL(src);
  };
  img.onload = () => {
    releaseBlob();
    if (pendingImage !== img) return; // superseded by a newer selection
    curImage = img;
    const cv = $("cv");
    const ctx = cv.getContext("2d");
    cv.width = img.naturalWidth;
    cv.height = img.naturalHeight;
    ctx.drawImage(img, 0, 0);
  };
  img.onerror = () => {
    releaseBlob();
    log(`failed to load image: ${src}`);
  };
  img.src = src;
  document.querySelectorAll(".samples img").forEach((s) => s.classList.remove("sel"));
  if (el) el.classList.add("sel");
}

/** Populate #samples with clickable thumbnails from ./assets. */
function initSamples() {
  const names = ["person", "book", "sweet", "ocr"];
  const wrap = $("samples");
  names.forEach((n) => {
    const im = document.createElement("img");
    im.src = `./assets/${n}.jpg`;
    im.onclick = () => setImageFromSrc(im.src, im);
    wrap.appendChild(im);
  });
}

/** Wire up the UI, report WebGPU availability, load all models, then preselect a sample. */
async function main() {
  initSamples();
  $("mnt").oninput = () => ($("mntv").textContent = $("mnt").value);
  $("file").onchange = (e) => {
    const f = e.target.files[0];
    if (f) setImageFromSrc(URL.createObjectURL(f));
  };
  $("run").onclick = () =>
    detect().catch((err) => {
      log("ERROR: " + err.message);
      console.error(err);
      $("run").disabled = false;
    });
  if (!navigator.gpu) {
    setBadge($("gpu"), "WebGPU not available", "err");
    log("WebGPU is not available in this browser. Use Chrome/Edge 121+ or Safari 18+ with WebGPU enabled.");
  } else {
    setBadge($("gpu"), "WebGPU ✓", "ok");
  }
  try {
    await loadAll();
    // preselect first sample
    setImageFromSrc("./assets/person.jpg", document.querySelector(".samples img"));
  } catch (e) {
    setBadge($("load"), "load failed", "err");
    log("ERROR loading models: " + e.message);
    console.error(e);
  }
}

main();