// Reza2kn's picture
// INT4 vision (251MB) + chunked Range downloads (fix 1.7GB stall); verified end-to-end in Chrome WebGPU
// 9b70c76 verified
// LocateAnything-3B — fully in-browser WebGPU detection.
// Pipeline: preprocess image -> vision ONNX -> visual_features; build prompt input_ids;
// INT4 embedding gather + visual splice -> inputs_embeds; KV-cache autoregressive decode
// with the INT4 language graph; parse <ref>/<box> -> draw. No server inference.
import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.23.0/dist/ort.webgpu.min.mjs";
import { AutoTokenizer } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.8.1";
// Hugging Face repository that hosts every artifact fetched below.
const REPO = "Reza2kn/LocateAnything-3B-ONNX-WebGPU-INT4";
const BASE = `https://huggingface.co/${REPO}/resolve/main`;

// External-data file names; the same names are passed to ort as `externalData[].path`.
const VISION_DATA = "vision_mlp_int4.onnx.data";
const LANG_DATA = "language_tail_kv_int4.onnx.data";

// Vision tower: INT4 (~250MB) — fp32 1.73GB stalls on download.
const VISION_URL = `${BASE}/onnx/vision_mlp_int4.onnx`;
const VISION_DATA_URL = `${BASE}/onnx/${VISION_DATA}`;

// Language graph (KV-cache decode) and its weights.
const LANG_URL = `${BASE}/onnx/language_tail_kv_int4.onnx`;
const LANG_DATA_URL = `${BASE}/onnx/${LANG_DATA}`;

// INT4 embedding table: packed nibbles, fp16 scales, and JSON metadata.
const EMB_PACKED_URL = `${BASE}/onnx/embed_tokens_int4_packed.bin`;
const EMB_SCALES_URL = `${BASE}/onnx/embed_tokens_int4_scales.bin`;
const EMB_META_URL = `${BASE}/onnx/embed_tokens_int4_meta.json`;

// Token ids / model constants (from config + generate_utils).
const IMG_CONTEXT = 151665; // image placeholder token (image_token_index)
const IM_END = 151645; // stop token

// KV-cache geometry used to build the empty `past_*` tensors.
const N_LAYERS = 36;
const KV_HEADS = 2;
const HEAD_DIM = 128;

// Image preprocessing parameters.
const PATCH = 14;
const MERGE = 2;
const IN_TOKEN_LIMIT = 256;
const MEAN = 0.5;
const STD = 0.5;
// Shorthand for document.getElementById.
const $ = (id) => document.getElementById(id);

// Log output element, looked up once when the module is evaluated.
const logEl = $("log");

/**
 * Appends one line to the on-page log, keeps the log scrolled to the bottom,
 * and mirrors the message to the browser console.
 * @param {string} m - message text (a newline is appended).
 */
function log(m) {
  logEl.textContent += m + "\n";
  logEl.scrollTop = logEl.scrollHeight;
  console.log(m);
}
/**
 * Sets a status badge's text and replaces its class attribute with "badge"
 * plus an optional modifier class.
 * @param {{textContent: string, className: string}} el - badge element.
 * @param {string} text - text to display.
 * @param {string} [cls] - modifier class; omitted or empty leaves just "badge".
 */
function setBadge(el, text, cls) {
  el.textContent = text;
  const classNames = ["badge"];
  if (cls) {
    classNames.push(cls);
  }
  el.className = classNames.join(" ");
}
// Module state, assigned by loadAll() and read by the detection pipeline.
let tokenizer = null;
let visionSess = null;
let langSess = null;

// INT4 embedding table pieces: packed bytes, decoded scales, JSON metadata.
let embPacked = null;
let embScales = null;
let embMeta = null;

// Currently selected image (HTMLImageElement), set once it has loaded.
let curImage = null;

// loadAll() stores the language session's output names under `.lang`.
const out_names_cache = {};
// ---------- fp16 -> fp32 ----------
/**
 * Decodes one IEEE 754 half-precision value, given as its 16-bit pattern,
 * into a JavaScript number.
 * @param {number} h - raw 16 bits (sign:1, exponent:5, fraction:10).
 * @returns {number} decoded value; NaN and ±Infinity are preserved.
 */
function f16to32(h) {
  const sign = (h & 0x8000) !== 0 ? -1 : 1;
  const exponent = (h >> 10) & 0x1f;
  const fraction = h & 0x03ff;

  // All-ones exponent: infinity when the fraction is zero, NaN otherwise.
  if (exponent === 0x1f) {
    if (fraction !== 0) {
      return NaN;
    }
    return sign < 0 ? -Infinity : Infinity;
  }

  // Zero exponent: subnormal (or signed zero), no implicit leading 1.
  if (exponent === 0) {
    return sign * Math.pow(2, -14) * (fraction / 1024);
  }

  // Normal number: implicit leading 1, exponent bias 15.
  return sign * Math.pow(2, exponent - 15) * (1 + fraction / 1024);
}
// ---------- model loading ----------
/**
 * Downloads `url` in one request and returns the body as bytes.
 * @param {string} url - resource to fetch.
 * @param {string} label - name used in the error message.
 * @returns {Promise<Uint8Array>} the full response body.
 * @throws {Error} when the response's `ok` flag is false.
 */
async function fetchBuf(url, label) {
  const response = await fetch(url);
  if (response.ok) {
    const bytes = await response.arrayBuffer();
    return new Uint8Array(bytes);
  }
  throw new Error(`fetch ${label} failed: ${response.status}`);
}
// Resolves (with no value) after roughly `ms` milliseconds.
const sleep = (ms) =>
  new Promise((resolve) => {
    setTimeout(resolve, ms);
  });
/**
 * fetch() with a stall watchdog: the request is aborted if `stallMs` elapses
 * without progress. Progress means either the response headers arriving or a
 * body chunk being read.
 *
 * @param {string} url - resource to fetch.
 * @param {object} [opts] - extra fetch options (e.g. a Range header); any
 *   `signal` in here is replaced by the watchdog's own signal.
 * @param {number} [stallMs=30000] - maximum silence tolerated, in milliseconds.
 * @returns {Promise<{buf: Uint8Array, headers: Headers}>} full body and response headers.
 * @throws {Error} `status N` for any status other than 200/206, or the abort
 *   error when the watchdog fires.
 */
async function fetchAbortable(url, opts, stallMs = 30000) {
  const ctrl = new AbortController();
  const onStall = () => ctrl.abort(new Error(`stalled: no data for ${stallMs}ms`));
  // Arm the watchdog BEFORE issuing the request so a connection that hangs
  // while waiting for response headers is aborted too, not only a stalled body.
  let timer = setTimeout(onStall, stallMs);
  try {
    const r = await fetch(url, { ...opts, signal: ctrl.signal });
    if (!(r.status === 200 || r.status === 206)) {
      // Nobody will read this body; abort so the connection is released.
      ctrl.abort();
      throw new Error(`status ${r.status}`);
    }
    const reader = r.body.getReader();
    const chunks = [];
    let got = 0;
    for (;;) {
      const { done, value } = await reader.read();
      if (done) break;
      // Data arrived: restart the silence timer.
      clearTimeout(timer);
      timer = setTimeout(onStall, stallMs);
      chunks.push(value);
      got += value.length;
    }
    // Concatenate the streamed chunks into one contiguous buffer.
    const buf = new Uint8Array(got);
    let o = 0;
    for (const c of chunks) {
      buf.set(c, o);
      o += c.length;
    }
    return { buf, headers: r.headers };
  } finally {
    clearTimeout(timer);
  }
}
/**
 * Chunked Range download: the file is fetched in pieces of `chunk` bytes, each
 * retried independently, so no single long-lived connection can stall the
 * whole file. Progress is written to the log in roughly 10% steps.
 *
 * @param {string} url - resource to download (server must honor Range requests
 *   for files larger than one chunk).
 * @param {string} label - name used in log lines and error messages.
 * @param {number} [chunk=50331648] - bytes requested per range (48 MiB).
 * @returns {Promise<Uint8Array>} the complete file.
 * @throws {Error} after 5 failed attempts at the first request or at any later range.
 */
async function fetchBufProgress(url, label, chunk = 48 * 1024 * 1024) {
  const t = performance.now();
  const elapsed = () => ((performance.now() - t) / 1000).toFixed(1);

  // Discover the total size via the first range request's Content-Range
  // ("bytes a-b/total"); fall back to Content-Length, then to the body size.
  let total = 0;
  let first;
  for (let tr = 0; ; tr++) {
    try {
      first = await fetchAbortable(url, { headers: { Range: `bytes=0-${chunk - 1}` } });
      const cr = first.headers.get("content-range");
      total = cr ? +cr.split("/")[1] : (+first.headers.get("content-length") || first.buf.length);
      break;
    } catch (e) {
      if (tr >= 4) throw e;
      log(` ${label} init retry ${tr + 1}…`);
      await sleep(1200);
    }
  }
  if (!total || total <= first.buf.length) { // small file, already done
    log(` ${label} downloaded ${(first.buf.length/1e6|0)}MB in ${elapsed()}s`);
    return first.buf;
  }

  const buf = new Uint8Array(total);
  buf.set(first.buf, 0);
  let off = first.buf.length;
  let lastPct = -1;
  while (off < total) {
    const end = Math.min(off + chunk, total) - 1;
    const want = end - off + 1;
    for (let tr = 0; ; tr++) {
      try {
        const { buf: part } = await fetchAbortable(url, { headers: { Range: `bytes=${off}-${end}` } });
        // An empty part would leave `off` unchanged and spin this loop forever;
        // an oversized part means the server ignored the Range header. Treat
        // both as failed attempts so they go through the bounded retry path.
        if (part.length === 0 || part.length > want) {
          throw new Error(`${label}: bad range response (${part.length} bytes for ${off}-${end})`);
        }
        // A short part is fine: the next request resumes from the new offset.
        buf.set(part, off);
        off += part.length;
        break;
      } catch (e) {
        if (tr >= 4) throw e;
        await sleep(1000);
      }
    }
    const pct = Math.floor((off / total) * 100);
    if (pct >= lastPct + 10) {
      lastPct = pct;
      log(` ${label}: ${pct}% (${(off/1e6|0)}MB)`);
    }
  }
  log(` ${label} downloaded ${(total/1e6|0)}MB in ${elapsed()}s`);
  return buf;
}
/**
 * Downloads every model artifact and initializes the module state:
 * `visionSess`, `langSess`, `tokenizer`, `embMeta`, `embPacked`, `embScales`.
 * Updates the #load badge and #prog progress bar along the way and enables
 * the #run button once everything is ready.
 *
 * @returns {Promise<void>}
 * @throws {Error} if any download, session creation, or tokenizer load fails.
 */
async function loadAll() {
  setBadge($("load"), "loading…", "warn");
  $("prog").style.display = "block";
  // webgpu first; wasm as fallback so unsupported ops/devices degrade instead of hard-failing.
  const sessOpts = { executionProviders: ["webgpu", "wasm"], graphOptimizationLevel: "all" };
  let t;

  // Create the ONNX sessions FIRST (before transformers.js), fetching buffers ourselves so we
  // can see download vs. compile timing and avoid ort's internal URL fetch hanging on redirects.
  log("downloading vision model (INT4, ~250MB)…");
  const visGraph = await fetchBufProgress(VISION_URL, "vision graph");
  const visData = await fetchBufProgress(VISION_DATA_URL, "vision data");
  $("prog").value = 30;
  log("compiling vision session…"); t = performance.now();
  visionSess = await ort.InferenceSession.create(visGraph, {
    ...sessOpts,
    externalData: [{ path: VISION_DATA, data: visData }],
  });
  log(`vision session ready in ${((performance.now()-t)/1000).toFixed(1)}s`);
  $("prog").value = 50;

  log("downloading INT4 language model (~1.7GB)…");
  const langData = await fetchBufProgress(LANG_DATA_URL, "language data");
  const langGraph = await fetchBufProgress(LANG_URL, "language graph");
  $("prog").value = 80;
  log("compiling language session…"); t = performance.now();
  langSess = await ort.InferenceSession.create(langGraph, {
    ...sessOpts,
    externalData: [{ path: LANG_DATA, data: langData }],
  });
  log(`language session ready in ${((performance.now()-t)/1000).toFixed(1)}s`);
  out_names_cache.lang = langSess.outputNames;
  $("prog").value = 90;

  log("loading tokenizer + INT4 embedding table…");
  tokenizer = await AutoTokenizer.from_pretrained(REPO);
  // fetch() resolves on HTTP errors too; check the status so a 404/5xx surfaces
  // as a clear message instead of an opaque JSON parse failure.
  const metaResp = await fetch(EMB_META_URL);
  if (!metaResp.ok) throw new Error(`fetch embed meta failed: ${metaResp.status}`);
  embMeta = await metaResp.json();
  embPacked = await fetchBufProgress(EMB_PACKED_URL, "embed packed"); // uint8 [vocab, hidden/2]
  const scalesBytes = await fetchBufProgress(EMB_SCALES_URL, "embed scales"); // fp16
  // Decode the little-endian fp16 scales into a Float32Array. Honor the view's
  // own offset/length in case it does not start at byte 0 of its ArrayBuffer.
  const sv = new DataView(scalesBytes.buffer, scalesBytes.byteOffset, scalesBytes.byteLength);
  embScales = new Float32Array(Math.floor(scalesBytes.length / 2));
  for (let i = 0; i < embScales.length; i++) embScales[i] = f16to32(sv.getUint16(i * 2, true));
  log(`embedding: vocab=${embMeta.vocab} hidden=${embMeta.hidden} block=${embMeta.block_size}`);

  $("prog").value = 100;
  $("prog").style.display = "none";
  setBadge($("load"), "model ready", "ok");
  $("run").disabled = false;
  log("ready.");
}
// ---------- image preprocessing (mirrors LocateAnythingImageProcessor) ----------
/**
 * Resizes an image onto a canvas whose sides are multiples of MERGE*PATCH and
 * converts it to normalized, patch-major pixel data.
 *
 * @param {{naturalWidth: number, naturalHeight: number}} img - loaded image.
 * @returns {{pv: Float32Array, gh: number, gw: number, nPatches: number, tw: number, th: number}}
 *   `pv` is laid out as [nPatches, 3, PATCH, PATCH] with patches in row-major
 *   order; `gh`/`gw` are the patch grid size; `tw`/`th` are the canvas size.
 */
function preprocess(img) {
  let w = img.naturalWidth;
  let h = img.naturalHeight;

  // Shrink so the whole-patch count is at most IN_TOKEN_LIMIT.
  const wholePatches = Math.floor(w / PATCH) * Math.floor(h / PATCH);
  if (wholePatches > IN_TOKEN_LIMIT) {
    const scale = Math.sqrt(IN_TOKEN_LIMIT / wholePatches);
    w = Math.floor(w * scale);
    h = Math.floor(h * scale);
  }

  // Round each side up to a multiple of MERGE*PATCH; the image is drawn
  // stretched to that size.
  const unit = MERGE * PATCH;
  const tw = Math.ceil(w / unit) * unit;
  const th = Math.ceil(h / unit) * unit;

  const canvas = document.createElement("canvas");
  canvas.width = tw;
  canvas.height = th;
  const ctx = canvas.getContext("2d");
  ctx.imageSmoothingEnabled = true;
  ctx.imageSmoothingQuality = "high";
  ctx.drawImage(img, 0, 0, tw, th);
  const { data } = ctx.getImageData(0, 0, tw, th); // RGBA, row-major

  const gh = th / PATCH;
  const gw = tw / PATCH;
  const nPatches = gh * gw;

  // pixel_values [nPatches, 3, 14, 14], patch order row-major (gh outer, gw inner)
  const pv = new Float32Array(nPatches * 3 * PATCH * PATCH);
  let write = 0;
  for (let patch = 0; patch < nPatches; patch++) {
    const top = Math.floor(patch / gw) * PATCH;
    const left = (patch % gw) * PATCH;
    for (let ch = 0; ch < 3; ch++) {
      for (let yy = 0; yy < PATCH; yy++) {
        // Byte index of this channel at the first pixel of the patch row.
        const rowStart = ((top + yy) * tw + left) * 4 + ch;
        for (let xx = 0; xx < PATCH; xx++) {
          const v = data[rowStart + xx * 4] / 255;
          pv[write++] = (v - MEAN) / STD; // -> 2v-1
        }
      }
    }
  }
  return { pv, gh, gw, nPatches, tw, th };
}
// ---------- prompt build ----------
/**
 * Builds the chat prompt for one detection query and tokenizes it.
 *
 * @param {string} cat - category / description to locate.
 * @param {number} gh - patch grid height.
 * @param {number} gw - patch grid width.
 * @returns {Promise<{ids: number[], N: number}>} token ids as plain numbers and
 *   N, the count of <IMG_CONTEXT> placeholders written into the prompt
 *   (floor(gh*gw / MERGE^2)).
 */
async function buildInputIds(cat, gh, gw) {
  const N = Math.floor((gh * gw) / (MERGE * MERGE));
  const promptParts = [
    "<|im_start|>system\nYou are a helpful assistant.\n<|im_end|>\n<|im_start|>user\n<image 1><img>",
    "<IMG_CONTEXT>".repeat(N),
    `</img>Locate all the instances that matches the following description: ${cat}.<|im_end|>\n<|im_start|>assistant\n`,
  ];
  const encoded = await tokenizer(promptParts.join(""), { add_special_tokens: false });
  // The tokenizer's id values are converted to plain numbers one by one.
  const ids = [];
  for (const raw of encoded.input_ids.data) {
    ids.push(Number(raw));
  }
  return { ids, N };
}
// ---------- INT4 embedding gather for one token -> Float32Array(hidden) ----------
/**
 * Dequantizes one token's row of the INT4 embedding table into `dst`.
 * Each packed byte holds two values (low nibble first); each value is
 * `(nibble - zero_point) * scale`, with one scale per `block_size` columns.
 *
 * @param {number} tokenId - row index into the embedding table.
 * @param {Float32Array} dst - destination buffer.
 * @param {number} off - index in `dst` where the row starts.
 */
function gatherEmbed(tokenId, dst, off) {
  const { hidden: H, block_size: B, n_groups: NG, zero_point: ZP } = embMeta;
  const byteBase = tokenId * (H / 2);
  const scaleBase = tokenId * NG;
  for (let pair = 0, col = 0; col < H; pair++, col += 2) {
    const packed = embPacked[byteBase + pair];
    const lowNibble = packed & 0x0f;
    const highNibble = (packed >> 4) & 0x0f;
    const lowScale = embScales[scaleBase + ((col / B) | 0)];
    const highScale = embScales[scaleBase + (((col + 1) / B) | 0)];
    dst[off + col] = (lowNibble - ZP) * lowScale;
    dst[off + col + 1] = (highNibble - ZP) * highScale;
  }
}
// ---------- main detection ----------
/**
 * Runs one detection pass on `curImage` for the description typed into #cat:
 * vision encoder -> prompt embedding with visual features spliced in ->
 * prefill -> greedy KV-cache decode (up to #mnt tokens) -> parse boxes -> draw.
 * The raw decoded text is shown in #raw. The #run button is disabled while the
 * pass is running and re-enabled on every exit path.
 *
 * @returns {Promise<void>}
 * @throws {Error} whatever the ONNX sessions, tokenizer, or canvas calls throw.
 */
async function detect() {
  if (!curImage) { log("pick or upload an image first."); return; }
  $("run").disabled = true;
  try {
    const cat = $("cat").value.trim() || "object";
    const maxNew = Number.parseInt($("mnt").value, 10);
    const t0 = performance.now();

    // 1) preprocess + vision
    const { pv, gh, gw, nPatches } = preprocess(curImage);
    log(`image grid ${gh}x${gw} (${nPatches} patches)`);
    const pvT = new ort.Tensor("float32", pv, [nPatches, 3, PATCH, PATCH]);
    const ghT = new ort.Tensor("int64", BigInt64Array.from([BigInt(gh), BigInt(gw)]), [1, 2]);
    const vOut = await visionSess.run({ pixel_values: pvT, image_grid_hws: ghT });
    const visual = vOut[visionSess.outputNames[0]]; // [N,2048] float32
    const H = embMeta.hidden;
    log(`vision -> visual_features ${visual.dims.join("x")} (${((performance.now() - t0) / 1000).toFixed(1)}s)`);

    // 2) prompt
    const { ids, N } = await buildInputIds(cat, gh, gw);
    log(`prompt tokens: ${ids.length} (image tokens N=${N})`);

    // 3) prompt inputs_embeds = INT4 gather + visual splice at IMG_CONTEXT positions
    const L = ids.length;
    const embeds = new Float32Array(L * H);
    let visIdx = 0;
    const vdata = visual.data;
    for (let i = 0; i < L; i++) {
      if (ids[i] === IMG_CONTEXT) {
        embeds.set(vdata.subarray(visIdx * H, (visIdx + 1) * H), i * H);
        visIdx++;
      } else {
        gatherEmbed(ids[i], embeds, i * H);
      }
    }
    // If the number of placeholder tokens does not match the number of visual
    // feature rows, the splice above silently leaves zeros or drops features;
    // surface that in the log rather than decoding garbage without a trace.
    if (visIdx * H !== vdata.length) {
      log(`WARNING: spliced ${visIdx} image tokens but vision output holds ${vdata.length / H} rows of width ${H}`);
    }

    // 4) KV-cache prefill (empty past tensors with sequence length 0)
    const idsBig = BigInt64Array.from(ids.map((x) => BigInt(x)));
    const mkEmptyPast = () => {
      const f = {};
      for (let i = 0; i < N_LAYERS; i++) {
        f[`past_key_${i}`] = new ort.Tensor("float32", new Float32Array(0), [1, KV_HEADS, 0, HEAD_DIM]);
        f[`past_value_${i}`] = new ort.Tensor("float32", new Float32Array(0), [1, KV_HEADS, 0, HEAD_DIM]);
      }
      return f;
    };
    const feeds = {
      input_ids: new ort.Tensor("int64", idsBig, [1, L]),
      inputs_embeds: new ort.Tensor("float32", embeds, [1, L, H]),
      attention_mask: new ort.Tensor("int64", BigInt64Array.from(new Array(L).fill(1n)), [1, L]),
      position_ids: new ort.Tensor("int64", BigInt64Array.from(ids.map((_, i) => BigInt(i))), [1, L]),
      ...mkEmptyPast(),
    };
    let res = await langSess.run(feeds);
    let present = res;
    const logits = res["logits"];
    const V = logits.dims[2]; // vocabulary size = last logits dimension
    let next = argmaxLast(logits.data, V);
    const gen = [next];
    log(`prefill done (${((performance.now() - t0) / 1000).toFixed(1)}s), decoding…`);

    // 5) decode loop with KV cache: feed one token per step, reuse present_* as past_*
    let pastLen = L;
    const tDec = performance.now();
    for (let step = 0; step < maxNew - 1; step++) {
      if (next === IM_END) break;
      const emb1 = new Float32Array(H);
      gatherEmbed(next, emb1, 0);
      const f = {
        input_ids: new ort.Tensor("int64", BigInt64Array.from([BigInt(next)]), [1, 1]),
        inputs_embeds: new ort.Tensor("float32", emb1, [1, 1, H]),
        attention_mask: new ort.Tensor("int64", BigInt64Array.from(new Array(pastLen + 1).fill(1n)), [1, pastLen + 1]),
        position_ids: new ort.Tensor("int64", BigInt64Array.from([BigInt(pastLen)]), [1, 1]),
      };
      for (let i = 0; i < N_LAYERS; i++) {
        f[`past_key_${i}`] = present[`present_key_${i}`];
        f[`past_value_${i}`] = present[`present_value_${i}`];
      }
      res = await langSess.run(f);
      present = res;
      next = argmaxLast(res["logits"].data, V);
      gen.push(next);
      pastLen += 1;
      // Every 8 steps: show partial text and yield to the event loop so the page can repaint.
      if (step % 8 === 0) { $("raw").textContent = tokenizer.decode(gen, { skip_special_tokens: false }); await new Promise(r => setTimeout(r)); }
    }
    const decS = (performance.now() - tDec) / 1000;
    const text = tokenizer.decode(gen, { skip_special_tokens: false });
    $("raw").textContent = text;
    log(`decoded ${gen.length} tokens in ${decS.toFixed(1)}s (${(gen.length / decS).toFixed(2)} tok/s)`);

    // 6) parse + draw
    const dets = parseBoxes(text);
    log(`detections: ${dets.length}`);
    drawResult(dets);
  } finally {
    // Re-enable the button on success and on failure alike.
    $("run").disabled = false;
  }
}
/**
 * Index of the largest value among the final `V` entries of `arr`
 * (i.e. the last row when `arr` is a flattened [..., V] array).
 * The first occurrence wins on ties; returns 0 when no entry exceeds -Infinity.
 *
 * @param {ArrayLike<number>} arr - flattened logits.
 * @param {number} V - row width.
 * @returns {number} index in [0, V) within the last row.
 */
function argmaxLast(arr, V) {
  const rowStart = arr.length - V;
  let bestIndex = 0;
  let bestValue = -Infinity;
  for (let k = 0; k < V; k++) {
    const candidate = arr[rowStart + k];
    if (candidate > bestValue) {
      bestValue = candidate;
      bestIndex = k;
    }
  }
  return bestIndex;
}
/**
 * Extracts detections from model output of the form
 * `<ref>label</ref><box>x1 y1 x2 y2</box>...`.
 * A box is kept only when it contains exactly four numbers. If no labelled
 * box is found, every `<box>` in the text is parsed with an empty label.
 *
 * @param {string} text - decoded model output.
 * @returns {Array<{label: string, box: number[]}>} detections in order of appearance.
 */
function parseBoxes(text) {
  const source = String(text);

  // Pushes one entry onto `sink` for each <box>…</box> in `segment` holding exactly 4 numbers.
  const harvest = (segment, label, sink) => {
    for (const hit of segment.matchAll(/<box>(.*?)<\/box>/gis)) {
      const coords = (hit[1].match(/-?\d+\.?\d*/g) || []).map(Number);
      if (coords.length === 4) {
        sink.push({ label, box: coords });
      }
    }
  };

  const found = [];
  // Labelled groups: a <ref> followed by one or more boxes.
  for (const group of source.matchAll(/<ref>(.*?)<\/ref>((?:\s*<box>.*?<\/box>)+)/gis)) {
    harvest(group[2], group[1].trim(), found);
  }
  // bare boxes without a preceding ref
  if (found.length === 0) {
    harvest(source, "", found);
  }
  return found;
}
// Stroke/label colors, cycled by detection index.
const COLORS = ["#0891b2", "#dc2626", "#16a34a", "#2563eb", "#d97706", "#9333ea"];

/**
 * Redraws `curImage` on the #cv canvas at its natural size and overlays one
 * rectangle (plus label, when present) per detection. Box coordinates are
 * scaled from a 0..1000 range to image pixels.
 *
 * @param {Array<{label: string, box: number[]}>} dets - detections to draw.
 */
function drawResult(dets) {
  const cv = $("cv");
  const ctx = cv.getContext("2d");
  const W = curImage.naturalWidth;
  const Ht = curImage.naturalHeight;
  cv.width = W;
  cv.height = Ht;
  ctx.drawImage(curImage, 0, 0, W, Ht);
  // Stroke width and font size grow with image width.
  ctx.lineWidth = Math.max(2, Math.round(W / 320));
  ctx.font = `${Math.max(12, Math.round(W / 45))}px sans-serif`;

  const toPxX = (v) => (v * W) / 1000;
  const toPxY = (v) => (v * Ht) / 1000;

  for (const [i, det] of dets.entries()) {
    const color = COLORS[i % COLORS.length];
    const [x1, y1, x2, y2] = det.box;
    const left = toPxX(x1);
    const top = toPxY(y1);
    const right = toPxX(x2);
    const bottom = toPxY(y2);
    ctx.strokeStyle = color;
    ctx.fillStyle = color;
    ctx.strokeRect(left, top, right - left, bottom - top);
    if (det.label) {
      // Keep the label inside the canvas when the box touches the top edge.
      ctx.fillText(det.label, left + 3, Math.max(top - 4, 12));
    }
  }
}
// ---------- wiring ----------
/**
 * Starts loading `src` as the current image. Once it has loaded, `curImage`
 * is replaced and the image is drawn on the #cv canvas at natural size.
 * Sample thumbnails lose their "sel" class immediately; `el`, if given, gets it.
 *
 * @param {string} src - image URL.
 * @param {Element} [el] - thumbnail element to mark as selected.
 */
function setImageFromSrc(src, el) {
  const img = new Image();
  img.crossOrigin = "anonymous";
  img.onload = () => {
    curImage = img;
    const canvas = $("cv");
    const ctx = canvas.getContext("2d");
    canvas.width = img.naturalWidth;
    canvas.height = img.naturalHeight;
    ctx.drawImage(img, 0, 0);
  };
  img.src = src;

  for (const thumb of document.querySelectorAll(".samples img")) {
    thumb.classList.remove("sel");
  }
  if (el) {
    el.classList.add("sel");
  }
}
/**
 * Populates #samples with one clickable thumbnail per bundled sample image
 * (`./assets/<name>.jpg`). Clicking a thumbnail selects it as the current image.
 */
function initSamples() {
  const container = $("samples");
  for (const name of ["person", "book", "sweet", "ocr"]) {
    const thumb = document.createElement("img");
    thumb.src = `./assets/${name}.jpg`;
    // Read thumb.src at click time, exactly as the element reports it.
    thumb.onclick = () => setImageFromSrc(thumb.src, thumb);
    container.appendChild(thumb);
  }
}
/**
 * Page entry point: builds the sample strip, wires the slider, file input and
 * Run button, reports WebGPU availability, then loads the models and
 * preselects the first sample image.
 *
 * @returns {Promise<void>} model-loading errors are caught and logged here.
 */
async function main() {
  initSamples();
  $("mnt").oninput = () => $("mntv").textContent = $("mnt").value;

  // Object URL of the most recent upload. Each createObjectURL pins its Blob
  // until revoked, so the previous one is released when a new file is picked.
  let uploadUrl = null;
  $("file").onchange = (e) => {
    const f = e.target.files[0];
    if (!f) return;
    if (uploadUrl) URL.revokeObjectURL(uploadUrl);
    uploadUrl = URL.createObjectURL(f);
    setImageFromSrc(uploadUrl);
  };

  $("run").onclick = () => detect().catch(err => { log("ERROR: " + err.message); console.error(err); $("run").disabled = false; });

  if (!navigator.gpu) {
    setBadge($("gpu"), "WebGPU not available", "err");
    log("WebGPU is not available in this browser. Use Chrome/Edge 121+ or Safari 18+ with WebGPU enabled.");
  } else {
    setBadge($("gpu"), "WebGPU ✓", "ok");
  }

  try {
    await loadAll();
    // preselect first sample
    setImageFromSrc("./assets/person.jpg", document.querySelector(".samples img"));
  } catch (e) {
    setBadge($("load"), "load failed", "err");
    log("ERROR loading models: " + e.message);
    console.error(e);
  }
}
// Errors thrown during the wiring above would otherwise be an unhandled rejection.
main().catch((err) => { console.error(err); });