Spaces:
Sleeping
Sleeping
// Benchmark the streaming ONNX parser against a few real-world Hugging Face
// models, with no browser involved.
//
// For each model we:
//   1. Probe the Hub for the best ONNX export (external-data preferred), or,
//      when given a direct URL, issue a HEAD request to learn the file size.
//   2. Run the streaming parser and measure the bytes actually transferred
//      over the wire vs the full file size.
//   3. Print a summary of the parsed graph (ops, initializers, IO).
//
// Usage:
//   npx tsx scripts/test-stream.ts
//   npx tsx scripts/test-stream.ts Xenova/distilbert-base-uncased
//   npx tsx scripts/test-stream.ts https://huggingface.co/<repo>/resolve/main/onnx/model.onnx
//
// Note: this script imports the same lib code that runs in the browser. It
// only relies on fetch + ReadableStream, which are both native in Node ≥ 18.
| import { resolveOnnxFile, bytesPretty } from "../src/lib/hub"; | |
| import { parseOnnxStream } from "../src/lib/onnxStream"; | |
/** Location and layout of one `.onnx` file, as returned by `Target.resolve`. */
interface ResolvedTarget {
  /** Hub repo id, or the URL itself when the target was given as a URL. */
  repoId: string;
  /** Path of the `.onnx` file; for URL targets, the file name taken from the URL. */
  filePath: string;
  /** URL the streaming parser fetches. */
  url: string;
  /** Size of the `.onnx` file in bytes; may be 0 when no size header was available. */
  sizeBytes: number;
  /** Whether the model keeps its weights in a separate external-data file. */
  hasExternalData: boolean;
}

/** One benchmark input: a display label plus a lazy resolver. */
interface Target {
  /** Text shown in the report header (the raw CLI argument). */
  label: string;
  /** Look up the file's URL, size and layout; may perform network requests. */
  resolve: () => Promise<ResolvedTarget>;
}
/**
 * Read a byte-count header as a number.
 *
 * Returns 0 when the header is absent, non-numeric, or not positive, so two
 * candidates can be combined with `||` without a stray `NaN` leaking through.
 */
function headerBytes(headers: Headers, name: string): number {
  const value = Number(headers.get(name));
  return Number.isFinite(value) && value > 0 ? value : 0;
}

/**
 * Last non-empty path segment of `url`, ignoring any query string, fragment,
 * or trailing slash. Falls back to `url` itself when there is no segment.
 */
function fileNameFromUrl(url: string): string {
  let path = url;
  try {
    path = new URL(url).pathname;
  } catch {
    // Not parseable as a URL: split the raw string instead.
  }
  const segments = path.split("/").filter((s) => s.length > 0);
  return segments[segments.length - 1] ?? url;
}

/**
 * Turn a CLI argument into a benchmark `Target`.
 *
 * - An `http://` / `https://` argument is treated as a direct link to an
 *   `.onnx` file. Its size comes from a HEAD request, preferring the
 *   `x-linked-size` header over `content-length` when both are present.
 * - Anything else is passed to `resolveOnnxFile` as a Hub repo id.
 *
 * Resolution is lazy: nothing is fetched until `resolve()` is called.
 *
 * @param arg Hub repo id or direct `http(s)://` URL.
 * @returns A target whose `resolve()` rejects with an `Error` if the HEAD
 *   request returns a non-OK status.
 */
function asTarget(arg: string): Target {
  if (!/^https?:\/\//.test(arg)) {
    return {
      label: arg,
      resolve: async () => resolveOnnxFile(arg),
    };
  }

  const url = arg;
  return {
    label: url,
    resolve: async () => {
      const head = await fetch(url, { method: "HEAD" });
      if (!head.ok) throw new Error(`HEAD ${url} failed: ${head.status}`);
      const size = headerBytes(head.headers, "x-linked-size") || headerBytes(head.headers, "content-length");
      return {
        repoId: url,
        filePath: fileNameFromUrl(url),
        url,
        sizeBytes: size,
        // A bare URL gives no view of sibling files, so a `.onnx_data`
        // sidecar cannot be detected here; report weights as inline.
        hasExternalData: false,
      };
    },
  };
}
/**
 * Targets benchmarked when the script runs without arguments.
 *
 * Each entry is either a Hub repo id (resolved via `resolveOnnxFile`) or a
 * direct `https://` link to an `.onnx` file (sized with a HEAD request).
 */
const DEFAULT_MODELS: string[] = [
  // Hub repo ids.
  "Xenova/all-MiniLM-L6-v2",
  "Xenova/distilbert-base-uncased",
  // Direct file URL: exercises the URL branch of target resolution.
  "https://huggingface.co/Xenova/distilbert-base-uncased/resolve/main/onnx/model.onnx",
  // Another Hub repo id.
  "onnx-community/Llama-3.2-1B",
];
/**
 * Format `part` as a percentage of `total`, with one decimal place
 * (e.g. `"12.5%"`).
 *
 * Returns `"n/a"` when `total` is falsy (0 or NaN), avoiding a division by
 * zero. Negative `part` values are formatted as-is.
 */
function pct(part: number, total: number): string {
  const ratio = part / total;
  return total ? `${(ratio * 100).toFixed(1)}%` : "n/a";
}
/**
 * Tally `items` by `key` and return the `n` most frequent keys with their
 * counts, highest first.
 *
 * Ties keep the order in which each key was first seen, because the sort
 * is stable and the Map preserves insertion order.
 */
function topByCount<T>(items: T[], key: (t: T) => string, n: number): [string, number][] {
  const tally = items.reduce((acc, item) => {
    const bucket = key(item);
    return acc.set(bucket, (acc.get(bucket) ?? 0) + 1);
  }, new Map<string, number>());
  const ranked = Array.from(tally);
  ranked.sort(([, left], [, right]) => right - left);
  return ranked.slice(0, n);
}
/**
 * Benchmark a single target and print a human-readable report.
 *
 * The target (a Hub repo id or direct URL) is resolved and then
 * stream-parsed with `parseOnnxStream`. The report covers timings, bytes
 * transferred versus total file size, and model metadata. It also summarises
 * the graph: IO, initializer and node counts, the most frequent op types,
 * and a few sample initializers and node names.
 *
 * @param input Hub repo id or `http(s)://` URL of an `.onnx` file.
 * @throws Propagates errors from target resolution and from
 *   `parseOnnxStream`; `main` catches them per target.
 */
async function benchOne(input: string): Promise<void> {
  const rule = "─".repeat(60);
  console.log(`\n${rule}`);
  console.log(`▶ ${input}`);
  console.log(rule);

  const target = asTarget(input);
  const t0 = performance.now();
  const resolved = await target.resolve();
  const t1 = performance.now();
  console.log(` resolved in ${(t1 - t0).toFixed(0)} ms`);
  console.log(` path: ${resolved.filePath}`);
  console.log(` .onnx size: ${bytesPretty(resolved.sizeBytes)}`);
  console.log(` external data: ${resolved.hasExternalData ? "yes (sidecar .onnx_data ignored)" : "no (weights inline)"}`);
  console.log(` url: ${resolved.url}`);

  // The progress line is redrawn in place with "\r", which only renders
  // correctly on an interactive terminal. When stdout is piped it would
  // litter the captured output, so it is suppressed there.
  const showProgress = Boolean(process.stdout.isTTY);
  let progressShown = false;
  let lastReported = 0;
  const onTransfer = (_delta: number, total: number) => {
    // Throttle redraws to at most one per MiB of growth in `total`.
    if (!showProgress || total - lastReported < 1024 * 1024) return;
    lastReported = total;
    progressShown = true;
    process.stdout.write(
      `\r streaming… ${bytesPretty(total)} / ${bytesPretty(resolved.sizeBytes)} ` +
        `(${pct(total, resolved.sizeBytes)}) `,
    );
  };

  const t2 = performance.now();
  let parsed: Awaited<ReturnType<typeof parseOnnxStream>>;
  try {
    parsed = await parseOnnxStream(resolved.url, { onTransfer });
  } finally {
    // Blank out the progress line even when parsing fails, so the error
    // that `main` prints does not land on top of it.
    if (progressShown) process.stdout.write("\r" + " ".repeat(80) + "\r");
  }
  const t3 = performance.now();
  const { model, stats } = parsed;

  // Transfer numbers are meaningful even if the model has no graph, so they
  // are reported before the graph is inspected.
  console.log(` parsed in ${(t3 - t2).toFixed(0)} ms`);
  console.log(` bytes transferred: ${bytesPretty(stats.bytesTransferred)} / ${bytesPretty(stats.totalFileSize)}`);
  // Clamp once so the byte count and the percentage agree. A negative raw
  // difference means more bytes arrived than the reported file size.
  const saved = Math.max(0, stats.totalFileSize - stats.bytesTransferred);
  console.log(` bandwidth saved: ${bytesPretty(saved)} (${pct(saved, stats.totalFileSize)} of file)`);

  const g = model.graph;
  if (!g) {
    console.error(" ✗ no graph in model");
    return;
  }

  console.log(` ir_version: ${model.ir_version ?? "?"}`);
  console.log(` producer: ${model.producer_name ?? "?"}`);
  // The ONNX default operator domain is the empty string, hence `||` here.
  console.log(` opsets: ${(model.opset_import ?? []).map((o) => `${o.domain || "ai.onnx"}@${o.version}`).join(", ")}`);
  console.log(` graph name: ${g.name ?? "(unnamed)"}`);
  console.log(` inputs: ${(g.input ?? []).length}`);
  console.log(` outputs: ${(g.output ?? []).length}`);
  console.log(` initializers: ${(g.initializer ?? []).length}`);
  console.log(` value_info nodes: ${(g.value_info ?? []).length}`);
  console.log(` nodes (ops): ${(g.node ?? []).length}`);

  const opCounts = topByCount(g.node ?? [], (n) => n.op_type ?? "?", 8);
  console.log(` top op_types:`);
  for (const [op, c] of opCounts) console.log(` ${op.padEnd(20)} ${c}`);

  // A few sample initializers (showing we got metadata without bytes).
  const inits = g.initializer ?? [];
  if (inits.length) {
    console.log(` sample initializers (first 3):`);
    for (const t of inits.slice(0, 3)) {
      const shape = (t.dims ?? []).join("×") || "scalar";
      console.log(` ${(t.name ?? "?").padEnd(50)} [${shape}] dtype=${t.data_type ?? "?"}`);
    }
  }

  // A few sample nodes with full scope path (to confirm we kept the names).
  const samples = (g.node ?? []).filter((n) => n.name && n.name.includes("/")).slice(0, 3);
  if (samples.length) {
    console.log(` sample node names:`);
    for (const n of samples) console.log(` ${n.op_type ?? "?"} :: ${n.name}`);
  }
}
/**
 * Benchmark every target named on the command line, or `DEFAULT_MODELS`
 * when no arguments are given.
 *
 * A failure on one target is reported and the loop moves on to the next.
 * When any target fails, the process exit code is set to 1 so that scripted
 * or CI runs notice the failure.
 */
async function main(): Promise<void> {
  const argv = process.argv.slice(2);
  const targets = argv.length ? argv : DEFAULT_MODELS;
  const failed: string[] = [];
  // Sequential on purpose: timings and the in-place progress line assume
  // one download at a time.
  for (const t of targets) {
    try {
      await benchOne(t);
    } catch (err: unknown) {
      failed.push(t);
      console.error(`\n✗ ${t} failed:`, err instanceof Error ? err.message : err);
    }
  }
  if (failed.length) {
    console.error(`\n${failed.length}/${targets.length} target(s) failed: ${failed.join(", ")}`);
    process.exitCode = 1;
  }
}
/** Print an error that escaped `main` and terminate with exit status 1. */
function exitWithError(reason: unknown): void {
  console.error(reason);
  process.exit(1);
}

main().catch(exitWithError);