Reza2kn committed on
Commit
9b70c76
·
verified ·
1 Parent(s): 2531e27

INT4 vision (251MB) + chunked Range downloads (fix 1.7GB stall); verified end-to-end in Chrome WebGPU

Browse files
Files changed (1) hide show
  1. app.js +94 -19
app.js CHANGED
@@ -8,7 +8,9 @@ import { AutoTokenizer } from "https://cdn.jsdelivr.net/npm/@huggingface/transfo
8
 
9
  const REPO = "Reza2kn/LocateAnything-3B-ONNX-WebGPU-INT4";
10
  const BASE = `https://huggingface.co/${REPO}/resolve/main`;
11
- const VISION_URL = `${BASE}/onnx/vision_mlp.onnx`;
 
 
12
  const LANG_URL = `${BASE}/onnx/language_tail_kv_int4.onnx`;
13
  const LANG_DATA = "language_tail_kv_int4.onnx.data";
14
  const LANG_DATA_URL = `${BASE}/onnx/${LANG_DATA}`;
@@ -48,37 +50,110 @@ async function fetchBuf(url, label) {
48
  return new Uint8Array(await r.arrayBuffer());
49
  }
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  async function loadAll() {
52
  setBadge($("load"), "loading…", "warn");
53
  $("prog").style.display = "block";
54
  // webgpu first; wasm as fallback so unsupported ops/devices degrade instead of hard-failing.
55
  const sessOpts = { executionProviders: ["webgpu", "wasm"], graphOptimizationLevel: "all" };
 
56
 
57
- log("loading tokenizer…");
58
- tokenizer = await AutoTokenizer.from_pretrained(REPO);
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- log("loading embedding INT4 table…");
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  embMeta = await (await fetch(EMB_META_URL)).json();
62
- embPacked = await fetchBuf(EMB_PACKED_URL, "embed packed"); // uint8 [vocab, hidden/2]
63
- const scalesBytes = await fetchBuf(EMB_SCALES_URL, "embed scales"); // fp16
64
  const sv = new DataView(scalesBytes.buffer);
65
  embScales = new Float32Array(scalesBytes.length / 2);
66
  for (let i = 0; i < embScales.length; i++) embScales[i] = f16to32(sv.getUint16(i * 2, true));
67
  log(`embedding: vocab=${embMeta.vocab} hidden=${embMeta.hidden} block=${embMeta.block_size}`);
68
- $("prog").value = 20;
69
 
70
- log("loading vision model (~1.7GB)…");
71
- visionSess = await ort.InferenceSession.create(VISION_URL, sessOpts);
72
- $("prog").value = 50;
73
-
74
- log("loading INT4 language model (~1.7GB + data)…");
75
- const langData = await fetchBuf(LANG_DATA_URL, "language data");
76
- $("prog").value = 85;
77
- langSess = await ort.InferenceSession.create(LANG_URL, {
78
- ...sessOpts,
79
- externalData: [{ path: LANG_DATA, data: langData }],
80
- });
81
- out_names_cache.lang = langSess.outputNames;
82
  $("prog").value = 100;
83
  $("prog").style.display = "none";
84
  setBadge($("load"), "model ready", "ok");
 
8
 
9
  const REPO = "Reza2kn/LocateAnything-3B-ONNX-WebGPU-INT4";
10
  const BASE = `https://huggingface.co/${REPO}/resolve/main`;
11
+ const VISION_URL = `${BASE}/onnx/vision_mlp_int4.onnx`; // INT4 (~250MB) — fp32 1.73GB stalls on download
12
+ const VISION_DATA = "vision_mlp_int4.onnx.data";
13
+ const VISION_DATA_URL = `${BASE}/onnx/${VISION_DATA}`;
14
  const LANG_URL = `${BASE}/onnx/language_tail_kv_int4.onnx`;
15
  const LANG_DATA = "language_tail_kv_int4.onnx.data";
16
  const LANG_DATA_URL = `${BASE}/onnx/${LANG_DATA}`;
 
50
  return new Uint8Array(await r.arrayBuffer());
51
  }
52
 
53
/**
 * Pause an async flow for a fixed delay.
 * @param {number} ms - Delay in milliseconds.
 * @returns {Promise<void>} Promise that resolves (with no value) once the delay has elapsed.
 */
const sleep = (ms) =>
  new Promise((resolve) => {
    setTimeout(resolve, ms);
  });
54
+
55
/**
 * Fetch `url` fully into memory, guarded by a stall watchdog.
 *
 * The watchdog is armed BEFORE the request is issued and re-armed whenever
 * progress is made (response headers received, each body chunk read). If no
 * progress happens within `stallMs`, the request is aborted and the pending
 * `fetch`/`read` rejects, so a connection that hangs while connecting,
 * redirecting or mid-body can never block the caller forever.
 *
 * @param {string} url - Resource to download.
 * @param {object} [opts] - Extra `fetch` options (e.g. a `Range` header). Any
 *   `signal` in `opts` is replaced by the watchdog's own abort signal.
 * @param {number} [stallMs=30000] - Maximum time without progress, in milliseconds.
 * @returns {Promise<{buf: Uint8Array, headers: Headers}>} The complete body and
 *   the response headers.
 * @throws {Error} `status <code>` when the response status is neither 200 nor 206;
 *   the abort/network error when the request fails or the watchdog fires.
 */
async function fetchAbortable(url, opts, stallMs = 30000) {
  const ctrl = new AbortController();
  let timer;
  // (Re)start the no-progress countdown.
  const armWatchdog = () => {
    clearTimeout(timer);
    timer = setTimeout(() => ctrl.abort(), stallMs);
  };

  armWatchdog(); // cover the connect / redirect / header phase as well
  try {
    const r = await fetch(url, { ...opts, signal: ctrl.signal });
    if (!(r.status === 200 || r.status === 206)) {
      ctrl.abort(); // do not leave the unread body holding the connection open
      throw new Error(`status ${r.status}`);
    }
    armWatchdog(); // headers arrived: that counts as progress

    const reader = r.body.getReader();
    const chunks = [];
    let got = 0;
    for (;;) {
      const { done, value } = await reader.read();
      if (done) break;
      armWatchdog();
      chunks.push(value);
      got += value.length;
    }

    // Concatenate the received pieces into one contiguous buffer.
    const buf = new Uint8Array(got);
    let o = 0;
    for (const c of chunks) {
      buf.set(c, o);
      o += c.length;
    }
    return { buf, headers: r.headers };
  } finally {
    clearTimeout(timer); // never leave a live timer behind, on success or failure
  }
}
75
+
76
/**
 * Chunked Range download: the file is fetched in small pieces, each retried
 * independently, so no single long-lived connection can stall the whole file.
 * HF CDN supports range requests.
 *
 * The total size is taken from the first response's `Content-Range` header
 * (falling back to `Content-Length`, then to the number of bytes received).
 * When the total is unknown/unparsable or the first response already holds the
 * whole file, that first buffer is returned as-is.
 *
 * @param {string} url - Resource to download.
 * @param {string} label - Human-readable name used in log lines.
 * @param {number} [chunk=50331648] - Size of each range request in bytes (48 MiB).
 * @returns {Promise<Uint8Array>} The complete file contents.
 * @throws {Error} When the first request fails 5 times, or any later range
 *   fails 5 times (network error, empty response, or more bytes than the file holds).
 */
async function fetchBufProgress(url, label, chunk = 48 * 1024 * 1024) {
  const t = performance.now();
  const elapsed = () => ((performance.now() - t) / 1000).toFixed(1);

  // Discover the total size via the first range request's Content-Range.
  let total = 0;
  let first;
  for (let tr = 0; ; tr++) {
    try {
      first = await fetchAbortable(url, { headers: { Range: `bytes=0-${chunk - 1}` } });
      const cr = first.headers.get("content-range");
      total = cr ? +cr.split("/")[1] : (+first.headers.get("content-length") || first.buf.length);
      break;
    } catch (e) {
      if (tr >= 4) throw e;
      log(` ${label} init retry ${tr + 1}…`);
      await sleep(1200);
    }
  }

  if (!total || total <= first.buf.length) { // small file, already done
    log(` ${label} downloaded ${(first.buf.length/1e6|0)}MB in ${elapsed()}s`);
    return first.buf;
  }

  const buf = new Uint8Array(total);
  buf.set(first.buf, 0);
  let off = first.buf.length;
  let lastPct = -1;
  while (off < total) {
    const end = Math.min(off + chunk, total) - 1;
    // Up to 5 attempts per range; the last failure is propagated to the caller.
    for (let tr = 0; ; tr++) {
      try {
        const { buf: part } = await fetchAbortable(url, { headers: { Range: `bytes=${off}-${end}` } });
        if (part.length === 0) {
          // An empty body would never advance `off` and would loop forever.
          throw new Error(`${label}: empty response for bytes ${off}-${end}`);
        }
        if (off + part.length > total) {
          // E.g. the server ignored the Range header and resent the whole file.
          throw new Error(`${label}: got ${part.length} bytes for bytes ${off}-${end} (exceeds file size ${total})`);
        }
        buf.set(part, off);
        off += part.length; // a short read is fine: the next request resumes from here
        break;
      } catch (e) {
        if (tr === 4) throw e;
        await sleep(1000);
      }
    }
    const pct = Math.floor((off / total) * 100);
    if (pct >= lastPct + 10) { // log roughly every 10%
      lastPct = pct;
      log(` ${label}: ${pct}% (${(off/1e6|0)}MB)`);
    }
  }
  log(` ${label} downloaded ${(total/1e6|0)}MB in ${elapsed()}s`);
  return buf;
}
112
+
113
  async function loadAll() {
114
  setBadge($("load"), "loading…", "warn");
115
  $("prog").style.display = "block";
116
  // webgpu first; wasm as fallback so unsupported ops/devices degrade instead of hard-failing.
117
  const sessOpts = { executionProviders: ["webgpu", "wasm"], graphOptimizationLevel: "all" };
118
+ let t;
119
 
120
+ // Create the ONNX sessions FIRST (before transformers.js), fetching buffers ourselves so we
121
+ // can see download vs. compile timing and avoid ort's internal URL fetch hanging on redirects.
122
+ log("downloading vision model (INT4, ~250MB)…");
123
+ const visGraph = await fetchBufProgress(VISION_URL, "vision graph");
124
+ const visData = await fetchBufProgress(VISION_DATA_URL, "vision data");
125
+ $("prog").value = 30;
126
+ log("compiling vision session…"); t = performance.now();
127
+ visionSess = await ort.InferenceSession.create(visGraph, {
128
+ ...sessOpts,
129
+ externalData: [{ path: VISION_DATA, data: visData }],
130
+ });
131
+ log(`vision session ready in ${((performance.now()-t)/1000).toFixed(1)}s`);
132
+ $("prog").value = 50;
133
 
134
+ log("downloading INT4 language model (~1.7GB)…");
135
+ const langData = await fetchBufProgress(LANG_DATA_URL, "language data");
136
+ const langGraph = await fetchBufProgress(LANG_URL, "language graph");
137
+ $("prog").value = 80;
138
+ log("compiling language session…"); t = performance.now();
139
+ langSess = await ort.InferenceSession.create(langGraph, {
140
+ ...sessOpts,
141
+ externalData: [{ path: LANG_DATA, data: langData }],
142
+ });
143
+ log(`language session ready in ${((performance.now()-t)/1000).toFixed(1)}s`);
144
+ out_names_cache.lang = langSess.outputNames;
145
+ $("prog").value = 90;
146
+
147
+ log("loading tokenizer + INT4 embedding table…");
148
+ tokenizer = await AutoTokenizer.from_pretrained(REPO);
149
  embMeta = await (await fetch(EMB_META_URL)).json();
150
+ embPacked = await fetchBufProgress(EMB_PACKED_URL, "embed packed"); // uint8 [vocab, hidden/2]
151
+ const scalesBytes = await fetchBufProgress(EMB_SCALES_URL, "embed scales"); // fp16
152
  const sv = new DataView(scalesBytes.buffer);
153
  embScales = new Float32Array(scalesBytes.length / 2);
154
  for (let i = 0; i < embScales.length; i++) embScales[i] = f16to32(sv.getUint16(i * 2, true));
155
  log(`embedding: vocab=${embMeta.vocab} hidden=${embMeta.hidden} block=${embMeta.block_size}`);
 
156
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  $("prog").value = 100;
158
  $("prog").style.display = "none";
159
  setBadge($("load"), "model ready", "ok");