File size: 6,364 Bytes
fc01079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
// Benchmark the streaming ONNX parser against a few real-world Hugging Face
// models, with no browser involved.
//
// For each model we:
//   1. Probe the Hub for the best ONNX export (external-data preferred).
//   2. Run the streaming parser. Measure bytes actually transferred over
//      the wire vs the full file size.
//   3. Print a summary of the parsed graph (ops, initializers, IO).
//
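// Usage:
//   npx tsx scripts/test-stream.ts                      # default model list
//   npx tsx scripts/test-stream.ts Xenova/distilbert-base-uncased
//   npx tsx scripts/test-stream.ts https://huggingface.co/<repo>/resolve/main/onnx/model.onnx
//
// Arguments may be Hub repo ids or direct http(s) URLs to an .onnx file.
//
// Note: this script imports the same lib code that runs in the browser. It
// relies only on fetch and ReadableStream, both of which are native in
// Node ≥ 18.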

import { resolveOnnxFile, bytesPretty } from "../src/lib/hub";
import { parseOnnxStream } from "../src/lib/onnxStream";

/** Location and size of the `.onnx` file a {@link Target} resolves to. */
interface ResolvedOnnx {
  repoId: string;
  filePath: string;
  url: string;
  sizeBytes: number;
  hasExternalData: boolean;
}

/**
 * A single benchmark target: a label for the report header plus a lazy
 * `resolve` callback that performs the lookup when invoked.
 */
interface Target {
  label: string;
  resolve: () => Promise<ResolvedOnnx>;
}

/**
 * Converts a CLI argument into a benchmark target.
 *
 * - `http://` / `https://` arguments are treated as a direct link to an
 *   `.onnx` file. `resolve()` sends a HEAD request to learn the file size and
 *   reports `hasExternalData: false` (this branch does not look for sidecar
 *   weight files).
 * - Any other argument is handed to `resolveOnnxFile` as a Hub repo id.
 *
 * Resolution is lazy: no network request is made until `resolve()` is called.
 *
 * @param arg Hub repo id (e.g. `Xenova/all-MiniLM-L6-v2`) or absolute URL.
 * @returns A target whose `resolve()` rejects if the HEAD request is not OK.
 */
function asTarget(arg: string): Target {
  const isUrl = arg.startsWith("http://") || arg.startsWith("https://");
  if (!isUrl) {
    return { label: arg, resolve: async () => resolveOnnxFile(arg) };
  }

  const url = arg;
  return {
    label: url,
    resolve: async () => {
      // Take the file name from the URL *path* so a query string such as
      // `?download=true` (or a `#fragment`) doesn't end up in it. Fall back to
      // the whole URL when the path has no final segment (e.g. trailing `/`).
      const lastSegment = new URL(url).pathname.split("/").pop() ?? "";
      const filePath = lastSegment || url;

      const head = await fetch(url, { method: "HEAD" });
      if (!head.ok) throw new Error(`HEAD ${url} failed: ${head.status}`);

      // Missing, empty or non-numeric headers count as "unknown" (0) instead
      // of propagating NaN into the size column and the percentages.
      const headerBytes = (name: string): number => {
        const value = Number(head.headers.get(name) ?? 0);
        return Number.isFinite(value) && value > 0 ? value : 0;
      };
      // Prefer x-linked-size when present (presumably the Hub's size header
      // for LFS-backed files — TODO confirm); otherwise use content-length.
      const sizeBytes = headerBytes("x-linked-size") || headerBytes("content-length");

      return { repoId: url, filePath, url, sizeBytes, hasExternalData: false };
    },
  };
}

/** Targets benchmarked, in order, when no command-line arguments are given. */
const DEFAULT_MODELS: string[] = [
  // Hub repo ids (resolved via resolveOnnxFile).
  "Xenova/all-MiniLM-L6-v2",
  "Xenova/distilbert-base-uncased",
  // A direct .onnx URL, exercising the HEAD-request path of asTarget.
  "https://huggingface.co/Xenova/distilbert-base-uncased/resolve/main/onnx/model.onnx",
  // Another Hub repo id.
  "onnx-community/Llama-3.2-1B",
];

/**
 * Formats `part / total` as a percentage with one decimal place
 * (e.g. `"25.0%"`).
 *
 * @returns `"n/a"` when `total` is falsy (0 or NaN) instead of dividing by it.
 */
function pct(part: number, total: number): string {
  if (total) {
    const ratio = part / total;
    return (ratio * 100).toFixed(1) + "%";
  }
  return "n/a";
}

/**
 * Tallies `items` by `key` and returns up to `n` `[key, count]` pairs,
 * most frequent first.
 *
 * Ties keep first-seen order: the Map preserves insertion order, and
 * `Array.prototype.sort` is stable.
 */
function topByCount<T>(items: T[], key: (t: T) => string, n: number): [string, number][] {
  const tally = new Map<string, number>();
  items.forEach((item) => {
    const k = key(item);
    tally.set(k, 1 + (tally.get(k) ?? 0));
  });
  const ranked = Array.from(tally);
  ranked.sort(([, countA], [, countB]) => countB - countA);
  return ranked.slice(0, n);
}

/**
 * Benchmarks a single target and prints a human-readable report to stdout.
 *
 * 1. Resolves `input` (Hub repo id or direct URL) to an `.onnx` file.
 * 2. Stream-parses it with `parseOnnxStream`. A progress line is redrawn at
 *    most once per MiB, and only when stdout is a TTY.
 * 3. Prints bytes transferred vs. file size, then a summary of the graph:
 *    opsets, IO/initializer/node counts, the most common op types, and a few
 *    sample initializers and scoped node names.
 *
 * Errors from resolving or parsing are not caught here; they propagate to
 * the caller. The progress line is cleared even when parsing fails.
 *
 * @param input Hub repo id (e.g. `Xenova/distilbert-base-uncased`) or an
 *   absolute http(s) URL to an `.onnx` file.
 */
async function benchOne(input: string): Promise<void> {
  console.log(`\n────────────────────────────────────────────────────────────`);
  console.log(`β–Ά ${input}`);
  console.log(`────────────────────────────────────────────────────────────`);

  const target = asTarget(input);
  const t0 = performance.now();
  const resolved = await target.resolve();
  const t1 = performance.now();
  console.log(`  resolved in ${(t1 - t0).toFixed(0)} ms`);
  console.log(`  path:            ${resolved.filePath}`);
  console.log(`  .onnx size:      ${bytesPretty(resolved.sizeBytes)}`);
  console.log(`  external data:   ${resolved.hasExternalData ? "yes (sidecar .onnx_data ignored)" : "no (weights inline)"}`);
  console.log(`  url:             ${resolved.url}`);

  // `\r`-based redraws only make sense on an interactive terminal; when stdout
  // is piped to a file they would litter the log with partial lines.
  const showProgress = Boolean(process.stdout.isTTY);
  let progressDrawn = false;
  let lastReported = 0;
  const onTransfer = (_delta: number, total: number) => {
    // Throttle redraws to one per MiB of newly transferred data.
    if (!showProgress || total - lastReported < 1024 * 1024) return;
    lastReported = total;
    progressDrawn = true;
    process.stdout.write(
      `\r  streaming…      ${bytesPretty(total)} / ${bytesPretty(resolved.sizeBytes)} ` +
        `(${pct(total, resolved.sizeBytes)})    `,
    );
  };
  const clearProgress = () => {
    if (progressDrawn) process.stdout.write("\r" + " ".repeat(80) + "\r");
  };

  const t2 = performance.now();
  // `finally` so a failed parse doesn't leave a half-drawn progress line behind.
  const { model, stats } = await parseOnnxStream(resolved.url, { onTransfer }).finally(clearProgress);
  const t3 = performance.now();

  // Transfer stats don't depend on the graph, so report them before checking it.
  console.log(`  parsed in ${(t3 - t2).toFixed(0)} ms`);
  console.log(`  bytes transferred: ${bytesPretty(stats.bytesTransferred)} / ${bytesPretty(stats.totalFileSize)}`);
  // Clamp once so the byte count and the percentage agree. If more bytes were
  // transferred than `totalFileSize`, report 0 saved, not a negative share.
  const saved = Math.max(0, stats.totalFileSize - stats.bytesTransferred);
  console.log(
    `  bandwidth saved:   ${bytesPretty(saved)} (${pct(saved, stats.totalFileSize)} of file)`,
  );

  const g = model.graph;
  if (!g) {
    console.error("  βœ— no graph in model");
    return;
  }

  console.log(`  ir_version:       ${model.ir_version ?? "?"}`);
  console.log(`  producer:         ${model.producer_name ?? "?"}`);
  console.log(`  opsets:           ${(model.opset_import ?? []).map((o) => `${o.domain || "ai.onnx"}@${o.version}`).join(", ")}`);
  console.log(`  graph name:       ${g.name ?? "(unnamed)"}`);
  console.log(`  inputs:           ${(g.input ?? []).length}`);
  console.log(`  outputs:          ${(g.output ?? []).length}`);
  console.log(`  initializers:     ${(g.initializer ?? []).length}`);
  console.log(`  value_info nodes: ${(g.value_info ?? []).length}`);
  console.log(`  nodes (ops):      ${(g.node ?? []).length}`);

  const opCounts = topByCount(g.node ?? [], (n) => n.op_type ?? "?", 8);
  console.log(`  top op_types:`);
  for (const [op, c] of opCounts) console.log(`    ${op.padEnd(20)} ${c}`);

  // A few sample initializers (showing we got metadata without bytes).
  const inits = g.initializer ?? [];
  if (inits.length) {
    console.log(`  sample initializers (first 3):`);
    for (const t of inits.slice(0, 3)) {
      const shape = (t.dims ?? []).join("Γ—") || "scalar";
      console.log(`    ${(t.name ?? "?").padEnd(50)} [${shape}]  dtype=${t.data_type ?? "?"}`);
    }
  }

  // A few sample nodes whose names contain "/" (scope paths), to confirm
  // node names survive parsing.
  const samples = (g.node ?? []).filter((n) => n.name && n.name.includes("/")).slice(0, 3);
  if (samples.length) {
    console.log(`  sample node names:`);
    for (const n of samples) console.log(`    ${n.op_type ?? "?"} :: ${n.name}`);
  }
}

/**
 * Benchmarks every target given on the command line, or `DEFAULT_MODELS`
 * when none are given. Targets run one at a time (each is awaited before the
 * next starts).
 *
 * A failing target is reported and skipped so the remaining targets still
 * run. If any target failed, a summary is printed and the process exit code
 * is set to 1 so CI and shell callers can detect it.
 */
async function main(): Promise<void> {
  const argv = process.argv.slice(2);
  const targets = argv.length ? argv : DEFAULT_MODELS;
  let failures = 0;
  for (const t of targets) {
    try {
      await benchOne(t);
    } catch (err) {
      failures++;
      console.error(`\nβœ— ${t} failed:`, err instanceof Error ? err.message : err);
    }
  }
  if (failures > 0) {
    console.error(`\n${failures} of ${targets.length} target(s) failed`);
    // exitCode rather than exit(): lets pending stdout/stderr writes flush.
    process.exitCode = 1;
  }
}

// Entry point. Per-target failures are caught inside main(); anything that
// still escapes it is unexpected, so print it and exit with status 1.
void main().then(undefined, (fatal: unknown) => {
  console.error(fatal);
  process.exit(1);
});