| |
| |
| |
| |
|
|
| import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.23.0/dist/ort.webgpu.min.mjs"; |
| import { AutoTokenizer } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.8.1"; |
|
|
// Hugging Face repository that hosts every artifact this page downloads.
const REPO = "Reza2kn/LocateAnything-3B-ONNX-WebGPU-INT4";
const BASE = `https://huggingface.co/${REPO}/resolve/main`;

// Every model file lives under "<BASE>/onnx/"; build those URLs in one place.
const onnxAssetUrl = (fileName) => `${BASE}/onnx/${fileName}`;

// Vision encoder: graph file plus its external-weights file.
const VISION_DATA = "vision_mlp_int4.onnx.data";
const VISION_URL = onnxAssetUrl("vision_mlp_int4.onnx");
const VISION_DATA_URL = onnxAssetUrl(VISION_DATA);

// Language model: graph file plus its external-weights file.
const LANG_DATA = "language_tail_kv_int4.onnx.data";
const LANG_URL = onnxAssetUrl("language_tail_kv_int4.onnx");
const LANG_DATA_URL = onnxAssetUrl(LANG_DATA);

// Token-embedding table: packed values, scales, and a JSON description.
const EMB_PACKED_URL = onnxAssetUrl("embed_tokens_int4_packed.bin");
const EMB_SCALES_URL = onnxAssetUrl("embed_tokens_int4_scales.bin");
const EMB_META_URL = onnxAssetUrl("embed_tokens_int4_meta.json");
|
|
| |
// Token id whose prompt positions are filled with vision features instead of
// rows from the embedding table.
const IMG_CONTEXT = 151665;
// Token id that stops generation.
const IM_END = 151645;

// Dimensions used to shape the (initially empty) per-layer key/value cache.
const N_LAYERS = 36;
const KV_HEADS = 2;
const HEAD_DIM = 128;

// Image geometry: PATCH is the pixel side of one square patch, and
// MERGE x MERGE patches share one image token in the prompt.
const PATCH = 14;
const MERGE = 2;
// Images whose patch grid has more cells than this are scaled down first.
const IN_TOKEN_LIMIT = 256;

// Per-channel normalisation applied after mapping pixel bytes to [0, 1].
const MEAN = 0.5;
const STD = 0.5;
|
|
/** Shorthand for `document.getElementById`. */
function $(id) {
  return document.getElementById(id);
}
// Page element that receives a copy of every log line.
const logEl = $("log");

/**
 * Append one line to the on-page log, keep the log scrolled to its end, and
 * send the same message to the browser console.
 */
function log(m) {
  const line = m + "\n";
  logEl.textContent += line;
  logEl.scrollTop = logEl.scrollHeight;
  console.log(m);
}
/**
 * Update a status badge: set its text and reset its class list to "badge",
 * followed by the optional modifier class `cls`.
 */
function setBadge(el, text, cls) {
  el.textContent = text;
  if (cls) {
    el.className = "badge " + cls;
  } else {
    el.className = "badge";
  }
}
|
|
// Mutable module state; everything starts out empty.
let tokenizer = null;   // assigned by loadAll()
let visionSess = null;  // assigned by loadAll()
let langSess = null;    // assigned by loadAll()
let embPacked = null;   // assigned by loadAll()
let embScales = null;   // assigned by loadAll()
let embMeta = null;     // assigned by loadAll()
let curImage = null;    // assigned when an image finishes loading
// Holds the language session's output names under the key `lang`.
const out_names_cache = {};
|
|
| |
/**
 * Decode a 16-bit half-precision float, given as its raw bit pattern, into a
 * JavaScript number.
 *
 * @param {number} h - Bit pattern: 1 sign bit, 5 exponent bits, 10 fraction bits.
 * @returns {number} The decoded value (may be ±0, ±Infinity or NaN).
 */
function f16to32(h) {
  const sign = h & 0x8000 ? -1 : 1;
  const exponent = (h >> 10) & 0x1f;
  const fraction = h & 0x03ff;

  switch (exponent) {
    case 0:
      // Zero and subnormals: no implicit leading 1.
      return sign * Math.pow(2, -14) * (fraction / 1024);
    case 0x1f:
      // All-ones exponent encodes NaN (fraction set) or an infinity.
      return fraction ? NaN : sign * Infinity;
    default:
      return sign * Math.pow(2, exponent - 15) * (1 + fraction / 1024);
  }
}
|
|
| |
/**
 * Download `url` in one request and return its body as bytes.
 *
 * @param {string} url - Resource to fetch.
 * @param {string} label - Name used in the error message.
 * @returns {Promise<Uint8Array>} The response body.
 * @throws {Error} When the response status is not OK.
 */
async function fetchBuf(url, label) {
  const response = await fetch(url);
  if (response.ok) {
    const bytes = await response.arrayBuffer();
    return new Uint8Array(bytes);
  }
  throw new Error(`fetch ${label} failed: ${response.status}`);
}
|
|
/** Return a promise that resolves after roughly `ms` milliseconds. */
function sleep(ms) {
  return new Promise((resolve) => {
    setTimeout(resolve, ms);
  });
}
|
|
| |
/**
 * Fetch `url` and read the whole body, aborting the request if no progress is
 * made for `stallMs` milliseconds.
 *
 * The watchdog is armed before the request is sent, so a connection that
 * hangs before any response headers arrive is aborted as well, and it is
 * re-armed each time a body chunk is received.
 *
 * @param {string} url - Resource to fetch.
 * @param {object} opts - Extra `fetch` options (e.g. `headers`); `signal` is overridden.
 * @param {number} [stallMs=30000] - Longest tolerated gap without progress.
 * @returns {Promise<{buf: Uint8Array, headers: Headers}>} Body bytes and response headers.
 * @throws {Error} `status <code>` when the status is neither 200 nor 206, or
 *   whatever `fetch`/the body reader throws (including the abort error).
 */
async function fetchAbortable(url, opts, stallMs = 30000) {
  const ctrl = new AbortController();
  let timer;
  const armWatchdog = () => {
    clearTimeout(timer);
    timer = setTimeout(() => ctrl.abort(), stallMs);
  };

  armWatchdog(); // covers the wait for response headers too
  try {
    const r = await fetch(url, { ...opts, signal: ctrl.signal });
    if (!(r.status === 200 || r.status === 206)) throw new Error(`status ${r.status}`);

    const reader = r.body.getReader();
    const chunks = [];
    let got = 0;
    for (;;) {
      const { done, value } = await reader.read();
      if (done) break;
      armWatchdog(); // progress was made: restart the stall timer
      chunks.push(value);
      got += value.length;
    }

    // Stitch the chunks into one contiguous buffer.
    const buf = new Uint8Array(got);
    let o = 0;
    for (const c of chunks) {
      buf.set(c, o);
      o += c.length;
    }
    return { buf, headers: r.headers };
  } catch (err) {
    // Abort the request so an unread or half-read body does not linger.
    ctrl.abort();
    throw err;
  } finally {
    clearTimeout(timer);
  }
}
|
|
| |
| |
/**
 * Download `url` with HTTP Range requests of `chunk` bytes each, retrying
 * failed requests and logging progress.
 *
 * The first request asks for bytes `0..chunk-1` and learns the total size
 * from `Content-Range` (falling back to `Content-Length`, then to the number
 * of bytes received). If that response already holds the whole file it is
 * returned directly.
 *
 * @param {string} url - Resource to download.
 * @param {string} label - Name used in log lines.
 * @param {number} [chunk=50331648] - Bytes requested per range (48 MiB).
 * @returns {Promise<Uint8Array>} The full file contents.
 * @throws {Error} When the first request or any range still fails after five attempts.
 */
async function fetchBufProgress(url, label, chunk = 48 * 1024 * 1024) {
  const t = performance.now();
  const elapsed = () => ((performance.now() - t) / 1000).toFixed(1);

  let total = 0, first;
  for (let tr = 0; ; tr++) {
    try {
      first = await fetchAbortable(url, { headers: { Range: `bytes=0-${chunk - 1}` } });
      const cr = first.headers.get("content-range");
      total = cr ? +cr.split("/")[1] : (+first.headers.get("content-length") || first.buf.length);
      break;
    } catch (e) {
      if (tr >= 4) throw e;
      log(` ${label} init retry ${tr + 1}…`);
      await sleep(1200);
    }
  }
  if (!total || total <= first.buf.length) {
    log(` ${label} downloaded ${(first.buf.length/1e6|0)}MB in ${elapsed()}s`);
    return first.buf;
  }

  const buf = new Uint8Array(total);
  buf.set(first.buf, 0);
  let off = first.buf.length, lastPct = -1;
  while (off < total) {
    const end = Math.min(off + chunk, total) - 1;
    const want = end - off + 1;
    for (let tr = 0; ; tr++) {
      try {
        const { buf: part } = await fetchAbortable(url, { headers: { Range: `bytes=${off}-${end}` } });
        // An empty part would never advance `off` (endless loop) and an
        // over-long one would overflow `buf`; treat both as failed attempts.
        if (part.length === 0 || part.length > want) {
          throw new Error(`unexpected ${part.length}-byte response for bytes=${off}-${end}`);
        }
        buf.set(part, off);
        off += part.length;
        break;
      } catch (e) {
        if (tr >= 4) throw e;
        await sleep(1000);
      }
    }
    const pct = Math.floor((off / total) * 100);
    // Log roughly every 10 percentage points.
    if (pct >= lastPct + 10) { lastPct = pct; log(` ${label}: ${pct}% (${(off/1e6|0)}MB)`); }
  }
  log(` ${label} downloaded ${(total/1e6|0)}MB in ${elapsed()}s`);
  return buf;
}
|
|
/**
 * Download and initialise everything inference needs, updating the progress
 * bar and the "load" badge as it goes.
 *
 * On success it assigns the module-level `visionSess`, `langSess`,
 * `tokenizer`, `embMeta`, `embPacked` and `embScales`, records the language
 * session's output names in `out_names_cache.lang`, and enables the Run button.
 *
 * @returns {Promise<void>}
 * @throws {Error} When a download, a session creation, or the tokenizer load fails.
 */
async function loadAll() {
  setBadge($("load"), "loading…", "warn");
  const prog = $("prog");
  prog.style.display = "block";

  const sessOpts = { executionProviders: ["webgpu", "wasm"], graphOptimizationLevel: "all" };
  const secondsSince = (start) => ((performance.now() - start) / 1000).toFixed(1);
  let t;

  // Vision encoder: graph plus external weights.
  log("downloading vision model (INT4, ~250MB)…");
  const visGraph = await fetchBufProgress(VISION_URL, "vision graph");
  const visData = await fetchBufProgress(VISION_DATA_URL, "vision data");
  prog.value = 30;
  log("compiling vision session…");
  t = performance.now();
  visionSess = await ort.InferenceSession.create(visGraph, {
    ...sessOpts,
    externalData: [{ path: VISION_DATA, data: visData }],
  });
  log(`vision session ready in ${secondsSince(t)}s`);
  prog.value = 50;

  // Language model: external weights first, then the graph.
  log("downloading INT4 language model (~1.7GB)…");
  const langData = await fetchBufProgress(LANG_DATA_URL, "language data");
  const langGraph = await fetchBufProgress(LANG_URL, "language graph");
  prog.value = 80;
  log("compiling language session…");
  t = performance.now();
  langSess = await ort.InferenceSession.create(langGraph, {
    ...sessOpts,
    externalData: [{ path: LANG_DATA, data: langData }],
  });
  log(`language session ready in ${secondsSince(t)}s`);
  out_names_cache.lang = langSess.outputNames;
  prog.value = 90;

  // Tokenizer and the embedding table.
  log("loading tokenizer + INT4 embedding table…");
  tokenizer = await AutoTokenizer.from_pretrained(REPO);
  const metaResp = await fetch(EMB_META_URL);
  // Fail with the HTTP status instead of a confusing JSON parse error.
  if (!metaResp.ok) throw new Error(`fetch embed meta failed: ${metaResp.status}`);
  embMeta = await metaResp.json();
  embPacked = await fetchBufProgress(EMB_PACKED_URL, "embed packed");
  const scalesBytes = await fetchBufProgress(EMB_SCALES_URL, "embed scales");
  // The scales file is read as little-endian 16-bit floats and widened to
  // float32 once. Honour the view's own offset and length in case
  // `scalesBytes` does not start at byte 0 of its ArrayBuffer.
  const sv = new DataView(scalesBytes.buffer, scalesBytes.byteOffset, scalesBytes.byteLength);
  embScales = new Float32Array(Math.floor(scalesBytes.length / 2));
  for (let i = 0; i < embScales.length; i++) embScales[i] = f16to32(sv.getUint16(i * 2, true));
  log(`embedding: vocab=${embMeta.vocab} hidden=${embMeta.hidden} block=${embMeta.block_size}`);

  prog.value = 100;
  prog.style.display = "none";
  setBadge($("load"), "model ready", "ok");
  $("run").disabled = false;
  log("ready.");
}
|
|
| |
/**
 * Resize an image onto a canvas and flatten it into normalised patches.
 *
 * If the image's patch grid (floor(width / PATCH) x floor(height / PATCH))
 * has more than IN_TOKEN_LIMIT cells, both sides are scaled down first. The
 * target size is then rounded up to a multiple of MERGE * PATCH pixels.
 *
 * @param {HTMLImageElement} img - Loaded image (reads `naturalWidth`/`naturalHeight`).
 * @returns {{pv: Float32Array, gh: number, gw: number, nPatches: number, tw: number, th: number}}
 *   `pv` is laid out patch by patch (row-major over the grid), then channel
 *   (R, G, B), then row, then column; `gh` x `gw` is the patch grid and
 *   `tw` x `th` the canvas size in pixels.
 */
function preprocess(img) {
  let width = img.naturalWidth;
  let height = img.naturalHeight;

  // Shrink oversized images so the patch grid stays near the limit.
  const patchCount = Math.floor(width / PATCH) * Math.floor(height / PATCH);
  if (patchCount > IN_TOKEN_LIMIT) {
    const scale = Math.sqrt(IN_TOKEN_LIMIT / patchCount);
    width = Math.floor(width * scale);
    height = Math.floor(height * scale);
  }

  // Round each side up to a whole number of merged-patch cells.
  const cell = MERGE * PATCH;
  const tw = Math.ceil(width / cell) * cell;
  const th = Math.ceil(height / cell) * cell;

  const canvas = document.createElement("canvas");
  canvas.width = tw;
  canvas.height = th;
  const ctx = canvas.getContext("2d");
  ctx.imageSmoothingEnabled = true;
  ctx.imageSmoothingQuality = "high";
  ctx.drawImage(img, 0, 0, tw, th);
  const rgba = ctx.getImageData(0, 0, tw, th).data;

  const gh = th / PATCH;
  const gw = tw / PATCH;
  const nPatches = gh * gw;
  const pv = new Float32Array(nPatches * 3 * PATCH * PATCH);

  let write = 0;
  for (let patch = 0; patch < nPatches; patch++) {
    const top = Math.floor(patch / gw) * PATCH;
    const left = (patch % gw) * PATCH;
    for (let channel = 0; channel < 3; channel++) {
      for (let row = 0; row < PATCH; row++) {
        // Byte index of this row's first pixel, for the current channel.
        const rowStart = ((top + row) * tw + left) * 4 + channel;
        for (let col = 0; col < PATCH; col++) {
          const unit = rgba[rowStart + col * 4] / 255;
          pv[write++] = (unit - MEAN) / STD;
        }
      }
    }
  }
  return { pv, gh, gw, nPatches, tw, th };
}
|
|
| |
/**
 * Build the chat prompt for a detection query and tokenize it.
 *
 * The prompt contains one `<IMG_CONTEXT>` placeholder per image token, where
 * the image-token count is the patch count divided by MERGE * MERGE.
 *
 * @param {string} cat - Description of what to locate.
 * @param {number} gh - Patch-grid height.
 * @param {number} gw - Patch-grid width.
 * @returns {Promise<{ids: number[], N: number}>} Token ids as plain numbers
 *   and the number of image placeholders.
 */
async function buildInputIds(cat, gh, gw) {
  const N = Math.floor((gh * gw) / (MERGE * MERGE));
  const prompt = [
    `<|im_start|>system\nYou are a helpful assistant.\n<|im_end|>\n<|im_start|>user\n<image 1><img>`,
    "<IMG_CONTEXT>".repeat(N),
    `</img>Locate all the instances that matches the following description: ${cat}.<|im_end|>\n<|im_start|>assistant\n`,
  ].join("");
  const { input_ids: inputIds } = await tokenizer(prompt, { add_special_tokens: false });
  // Convert each token id to a plain number.
  const ids = Array.from(inputIds.data, Number);
  return { ids, N };
}
|
|
| |
/**
 * Dequantise one token's row of the packed 4-bit embedding table into `dst`.
 *
 * Each packed byte holds two values, the even column in the low nibble and
 * the odd column in the high nibble. A value is decoded as
 * `(nibble - zero_point) * scale`, with one scale per `block_size` columns.
 *
 * @param {number} tokenId - Row of the table to decode.
 * @param {Float32Array} dst - Destination buffer.
 * @param {number} off - Index in `dst` where the row starts.
 */
function gatherEmbed(tokenId, dst, off) {
  const {
    hidden,
    block_size: blockSize,
    n_groups: groupsPerRow,
    zero_point: zeroPoint,
  } = embMeta;
  const packedBase = tokenId * (hidden / 2);
  const scaleBase = tokenId * groupsPerRow;

  for (let col = 0; col < hidden; col += 2) {
    const packed = embPacked[packedBase + (col >> 1)];
    const evenScale = embScales[scaleBase + ((col / blockSize) | 0)];
    const oddScale = embScales[scaleBase + (((col + 1) / blockSize) | 0)];
    dst[off + col] = ((packed & 0x0f) - zeroPoint) * evenScale;
    dst[off + col + 1] = (((packed >> 4) & 0x0f) - zeroPoint) * oddScale;
  }
}
|
|
| |
/**
 * Run one detection on the current image: encode it, build the prompt, run
 * the language model with argmax decoding, then parse and draw the boxes.
 *
 * Reads the query from `#cat` (default "object") and the token budget from
 * `#mnt`, streams partial text into `#raw`, and keeps the Run button
 * disabled while working. The button is re-enabled on every exit path,
 * including errors.
 *
 * @returns {Promise<void>}
 * @throws {Error} When a session run fails or the vision encoder returns
 *   fewer feature rows than the prompt has image tokens.
 */
async function detect() {
  if (!curImage) { log("pick or upload an image first."); return; }
  const runBtn = $("run");
  runBtn.disabled = true;
  try {
    const cat = $("cat").value.trim() || "object";
    const maxNew = Number.parseInt($("mnt").value, 10);
    const t0 = performance.now();
    const secondsSince = (start) => ((performance.now() - start) / 1000).toFixed(1);
    const H = embMeta.hidden;

    // 1. Vision encoder.
    const { pv, gh, gw, nPatches } = preprocess(curImage);
    log(`image grid ${gh}x${gw} (${nPatches} patches)`);
    const vOut = await visionSess.run({
      pixel_values: new ort.Tensor("float32", pv, [nPatches, 3, PATCH, PATCH]),
      image_grid_hws: new ort.Tensor("int64", BigInt64Array.from([BigInt(gh), BigInt(gw)]), [1, 2]),
    });
    const visual = vOut[visionSess.outputNames[0]];
    log(`vision -> visual_features ${visual.dims.join("x")} (${secondsSince(t0)}s)`);

    // 2. Prompt token ids.
    const { ids, N } = await buildInputIds(cat, gh, gw);
    log(`prompt tokens: ${ids.length} (image tokens N=${N})`);

    // 3. Input embeddings: table rows for text tokens, vision features (in
    //    order) for every IMG_CONTEXT placeholder.
    const L = ids.length;
    const embeds = new Float32Array(L * H);
    const vdata = visual.data;
    let visIdx = 0;
    for (let i = 0; i < L; i++) {
      if (ids[i] !== IMG_CONTEXT) {
        gatherEmbed(ids[i], embeds, i * H);
        continue;
      }
      const start = visIdx * H;
      // Without this check a short feature buffer would be zero-filled silently.
      if (start + H > vdata.length) {
        throw new Error(
          `vision encoder produced ${Math.floor(vdata.length / H)} feature rows but the prompt needs at least ${visIdx + 1}`
        );
      }
      embeds.set(vdata.subarray(start, start + H), i * H);
      visIdx++;
    }

    // 4. Prefill: whole prompt in one run with an empty key/value cache.
    const feeds = {
      input_ids: new ort.Tensor("int64", BigInt64Array.from(ids, (x) => BigInt(x)), [1, L]),
      inputs_embeds: new ort.Tensor("float32", embeds, [1, L, H]),
      attention_mask: new ort.Tensor("int64", new BigInt64Array(L).fill(1n), [1, L]),
      position_ids: new ort.Tensor("int64", BigInt64Array.from(ids, (_, i) => BigInt(i)), [1, L]),
    };
    for (let i = 0; i < N_LAYERS; i++) {
      feeds[`past_key_${i}`] = new ort.Tensor("float32", new Float32Array(0), [1, KV_HEADS, 0, HEAD_DIM]);
      feeds[`past_value_${i}`] = new ort.Tensor("float32", new Float32Array(0), [1, KV_HEADS, 0, HEAD_DIM]);
    }
    let present = await langSess.run(feeds);
    const V = present["logits"].dims[2];
    let next = argmaxLast(present["logits"].data, V);
    const gen = [next];
    log(`prefill done (${secondsSince(t0)}s), decoding…`);

    // 5. Decode one token per run, feeding each run's present_* outputs back
    //    in as the next run's past_* inputs.
    let pastLen = L;
    const tDec = performance.now();
    for (let step = 0; step < maxNew - 1; step++) {
      if (next === IM_END) break;
      const emb1 = new Float32Array(H);
      gatherEmbed(next, emb1, 0);
      const f = {
        input_ids: new ort.Tensor("int64", BigInt64Array.from([BigInt(next)]), [1, 1]),
        inputs_embeds: new ort.Tensor("float32", emb1, [1, 1, H]),
        attention_mask: new ort.Tensor("int64", new BigInt64Array(pastLen + 1).fill(1n), [1, pastLen + 1]),
        position_ids: new ort.Tensor("int64", BigInt64Array.from([BigInt(pastLen)]), [1, 1]),
      };
      for (let i = 0; i < N_LAYERS; i++) {
        f[`past_key_${i}`] = present[`present_key_${i}`];
        f[`past_value_${i}`] = present[`present_value_${i}`];
      }
      present = await langSess.run(f);
      next = argmaxLast(present["logits"].data, V);
      gen.push(next);
      pastLen += 1;
      if (step % 8 === 0) {
        // Show partial output and yield to the event loop so the page can repaint.
        $("raw").textContent = tokenizer.decode(gen, { skip_special_tokens: false });
        await new Promise((r) => setTimeout(r));
      }
    }
    const decS = (performance.now() - tDec) / 1000;
    const text = tokenizer.decode(gen, { skip_special_tokens: false });
    $("raw").textContent = text;
    log(`decoded ${gen.length} tokens in ${decS.toFixed(1)}s (${(gen.length / decS).toFixed(2)} tok/s)`);

    // 6. Parse and draw.
    const dets = parseBoxes(text);
    log(`detections: ${dets.length}`);
    drawResult(dets);
  } finally {
    runBtn.disabled = false;
  }
}
|
|
/**
 * Return the index, within the last `V` entries of `arr`, of the largest
 * value. The first occurrence wins on ties.
 *
 * @param {ArrayLike<number>} arr - Flat array whose final `V` entries are scanned.
 * @param {number} V - Number of trailing entries to scan.
 * @returns {number} Index in `[0, V)`; 0 when nothing compares greater than -Infinity.
 */
function argmaxLast(arr, V) {
  const start = arr.length - V;
  let bestIdx = 0;
  let bestVal = -Infinity;
  let k = 0;
  while (k < V) {
    const candidate = arr[start + k];
    if (candidate > bestVal) {
      bestVal = candidate;
      bestIdx = k;
    }
    k += 1;
  }
  return bestIdx;
}
|
|
/**
 * Extract labelled boxes from generated text.
 *
 * Looks for `<ref>label</ref>` followed by one or more `<box>…</box>` groups
 * and gives each box that label. If that finds nothing, every `<box>` in the
 * text is taken with an empty label. A box is kept only when exactly four
 * numbers are found inside it. Tag matching is case-insensitive.
 *
 * @param {string} text - Generated text.
 * @returns {Array<{label: string, box: number[]}>} Parsed detections in text order.
 */
function parseBoxes(text) {
  const source = String(text);
  const found = [];

  // Append every well-formed box in `fragment` to `found` under `label`.
  const collectBoxes = (fragment, label) => {
    for (const hit of fragment.matchAll(/<box>(.*?)<\/box>/gis)) {
      const numbers = (hit[1].match(/-?\d+\.?\d*/g) || []).map(Number);
      if (numbers.length === 4) found.push({ label, box: numbers });
    }
  };

  for (const group of source.matchAll(/<ref>(.*?)<\/ref>((?:\s*<box>.*?<\/box>)+)/gis)) {
    collectBoxes(group[2], group[1].trim());
  }
  if (found.length === 0) collectBoxes(source, "");
  return found;
}
|
|
// Stroke/label colours, cycled by detection index.
const COLORS = ["#0891b2", "#dc2626", "#16a34a", "#2563eb", "#d97706", "#9333ea"];

/**
 * Redraw the current image on the `#cv` canvas at its natural size and
 * overlay each detection as a coloured rectangle with an optional label.
 *
 * Box coordinates are mapped from a 0–1000 range onto the image size.
 *
 * @param {Array<{label: string, box: number[]}>} dets - Boxes as [x1, y1, x2, y2].
 */
function drawResult(dets) {
  const canvas = $("cv");
  const ctx = canvas.getContext("2d");
  const W = curImage.naturalWidth;
  const Ht = curImage.naturalHeight;

  canvas.width = W;
  canvas.height = Ht;
  ctx.drawImage(curImage, 0, 0, W, Ht);

  // Scale stroke and font with the image width, with fixed minimums.
  ctx.lineWidth = Math.max(2, Math.round(W / 320));
  ctx.font = `${Math.max(12, Math.round(W / 45))}px sans-serif`;

  for (const [index, det] of dets.entries()) {
    const colour = COLORS[index % COLORS.length];
    const left = (det.box[0] * W) / 1000;
    const top = (det.box[1] * Ht) / 1000;
    const right = (det.box[2] * W) / 1000;
    const bottom = (det.box[3] * Ht) / 1000;

    ctx.strokeStyle = colour;
    ctx.fillStyle = colour;
    ctx.strokeRect(left, top, right - left, bottom - top);
    if (!det.label) continue;
    // Keep the label on the canvas when the box touches the top edge.
    ctx.fillText(det.label, left + 3, Math.max(top - 4, 12));
  }
}
|
|
| |
/**
 * Load `src` as the current image and mark `el` (if given) as the selected
 * sample thumbnail.
 *
 * When the image loads it becomes `curImage` and is drawn on the `#cv`
 * canvas at its natural size. A failed load is logged instead of being
 * ignored. If `src` is a `blob:` object URL it is revoked once the load has
 * finished (either way) so the blob is not kept alive by the URL.
 *
 * @param {string} src - Image URL.
 * @param {Element} [el] - Thumbnail to highlight with the "sel" class.
 */
function setImageFromSrc(src, el) {
  const img = new Image();
  img.crossOrigin = "anonymous";

  const releaseObjectUrl = () => {
    if (typeof src === "string" && src.startsWith("blob:")) URL.revokeObjectURL(src);
  };

  img.onload = () => {
    releaseObjectUrl();
    curImage = img;
    const cv = $("cv");
    const ctx = cv.getContext("2d");
    cv.width = img.naturalWidth;
    cv.height = img.naturalHeight;
    ctx.drawImage(img, 0, 0);
  };
  img.onerror = () => {
    releaseObjectUrl();
    log(`ERROR: could not load image ${src}`);
  };
  img.src = src;

  document.querySelectorAll(".samples img").forEach((s) => s.classList.remove("sel"));
  if (el) el.classList.add("sel");
}
|
|
/**
 * Fill the `#samples` container with one clickable thumbnail per bundled
 * sample image; clicking a thumbnail selects it as the current image.
 */
function initSamples() {
  const container = $("samples");
  for (const name of ["person", "book", "sweet", "ocr"]) {
    const thumb = document.createElement("img");
    thumb.src = `./assets/${name}.jpg`;
    thumb.onclick = () => setImageFromSrc(thumb.src, thumb);
    container.appendChild(thumb);
  }
}
|
|
/**
 * Wire up the page and start loading the models.
 *
 * Builds the sample strip, attaches the slider, file-input and Run-button
 * handlers, reports whether WebGPU is present, then loads everything and
 * selects the first sample image. Loading errors are shown in the "load"
 * badge and the log rather than thrown.
 */
async function main() {
  initSamples();

  $("mnt").oninput = () => {
    // Mirror the slider value into its label.
    return ($("mntv").textContent = $("mnt").value);
  };
  $("file").onchange = (e) => {
    const picked = e.target.files[0];
    if (picked) setImageFromSrc(URL.createObjectURL(picked));
  };
  $("run").onclick = async () => {
    try {
      await detect();
    } catch (err) {
      log("ERROR: " + err.message);
      console.error(err);
      $("run").disabled = false;
    }
  };

  if (navigator.gpu) {
    setBadge($("gpu"), "WebGPU ✓", "ok");
  } else {
    setBadge($("gpu"), "WebGPU not available", "err");
    log("WebGPU is not available in this browser. Use Chrome/Edge 121+ or Safari 18+ with WebGPU enabled.");
  }

  // Loading is attempted even without WebGPU.
  try {
    await loadAll();
    setImageFromSrc("./assets/person.jpg", document.querySelector(".samples img"));
  } catch (e) {
    setBadge($("load"), "load failed", "err");
    log("ERROR loading models: " + e.message);
    console.error(e);
  }
}
// Start the app. main() reports model-loading failures itself; this handler
// only catches what is thrown before that point (while the UI is being wired
// up), so the returned promise is never left as an unhandled rejection.
main().catch((err) => {
  console.error(err);
});
|
|