// hf-model-viewer / src/lib/protoReader.ts
// Last commit: fc01079 (verified) by tfrere (HF Staff) — "Deploy hf-model-viewer 2026-05-22T16:59:58Z"
// Async protobuf wire-format reader on top of a `ByteSource`.
//
// We implement just what we need to parse ONNX:
// - tag : varint of (field_number << 3) | wire_type
// - varint : LEB128, up to 10 bytes
// - I32/I64: fixed little-endian
// - LEN : length-delimited, varint length followed by N bytes
//
// We expose helpers to read primitives and a `skipField(wireType)` helper
// that can stream-skip large LEN fields via the ByteSource's optimized
// skipBytes. This is the foundation for the ONNX raw_data optimization.
import { ByteSource } from "./byteSource";
// Protobuf wire types: the low 3 bits of every tag (see `ProtoReader.readTag`).
/** Varint: LEB128-encoded integer, up to 10 bytes. */
export const WIRE_VARINT = 0;
/** I64: 8-byte fixed-width little-endian value. */
export const WIRE_I64 = 1;
/** LEN: varint byte length followed by that many payload bytes. */
export const WIRE_LEN = 2;
/** Start-group marker (deprecated); `ProtoReader.skipField` rejects it. */
export const WIRE_SGROUP = 3; // deprecated
/** End-group marker (deprecated); `ProtoReader.skipField` rejects it. */
export const WIRE_EGROUP = 4; // deprecated
/** I32: 4-byte fixed-width little-endian value. */
export const WIRE_I32 = 5;
/** A decoded field tag, split into its field number and wire type. */
export interface ProtoTag {
/** Field number: the tag value shifted right by 3. */
field: number;
/** Wire type: the low 3 bits of the tag (0-7); anything other than a WIRE_* constant means malformed input. */
wire: number;
}
/**
 * Decoded varint of up to 64 bits. Values up to Number.MAX_SAFE_INTEGER are
 * returned as a number; larger values are returned as a bigint.
 */
export type Varint = number | bigint;
/** 2^64: subtracted from sign-extended int64 varints to recover the negative value. */
const TWO_POW_64 = 1n << 64n;
const MAX_SAFE_BIG = BigInt(Number.MAX_SAFE_INTEGER);
const MIN_SAFE_BIG = BigInt(Number.MIN_SAFE_INTEGER);
/** Largest tag value: protobuf tags are uint32 on the wire. */
const MAX_TAG = 0xffffffff;
/**
 * Shared UTF-8 decoder. `decode()` called without `{ stream: true }` keeps no
 * state between calls, so one instance can serve every `readString()`.
 */
const UTF8_DECODER = new TextDecoder("utf-8");

export class ProtoReader {
  private readonly src: ByteSource;
  /** Absolute end position (exclusive). Reads that would cross it throw EOM. */
  private readonly endPos: number;

  /**
   * @param src    Byte source; it is shared with any sub-readers created from this reader.
   * @param endPos Absolute end offset of this message; defaults to `src.totalSize`.
   */
  constructor(src: ByteSource, endPos?: number) {
    this.src = src;
    this.endPos = endPos ?? src.totalSize;
  }

  /** Current absolute position in the underlying source. */
  get pos(): number {
    return this.src.pos;
  }

  /** Absolute end position (exclusive) of this reader's message. */
  get end(): number {
    return this.endPos;
  }

  /** True while the source position is before `end`. */
  get hasMore(): boolean {
    return this.src.pos < this.endPos;
  }

  /**
   * Create a sub-reader over the next `length` bytes, ending at `pos + length`.
   * The sub-reader shares this reader's source, so reading from it also
   * advances this reader's position.
   *
   * @throws RangeError if `length` is negative or not an integer.
   * @throws Error if the sub-message would extend past this reader's end.
   */
  subreader(length: number): ProtoReader {
    this.checkAvailable(length);
    return new ProtoReader(this.src, this.src.pos + length);
  }

  /**
   * Read an unsigned varint (LEB128, at most 10 bytes).
   *
   * @returns A number when the value is at most Number.MAX_SAFE_INTEGER,
   *          otherwise a bigint holding the value truncated to 64 bits.
   * @throws Error if the varint is longer than 10 bytes or crosses `end`.
   */
  async readVarint(): Promise<Varint> {
    // Fast path: stay in JS number-land for varints whose data fits in
    // ~28 bits. We read up to 4 bytes here (4 × 7 = 28 bits), which is
    // safe because `<< 21` still produces a positive 32-bit signed int.
    // The fifth byte at shift=28 risks truncating the upper data bits
    // (JS bitwise ops are 32-bit), so we switch to BigInt from byte 5 on.
    let lo = 0;
    for (let i = 0; i < 4; i++) {
      this.checkAvailable(1);
      const b = await this.src.readByte();
      lo |= (b & 0x7f) << (i * 7);
      lo = lo >>> 0;
      if ((b & 0x80) === 0) return lo;
    }
    // Slow path: BigInt for bytes 5..10. `int64 = -1` (10-byte varint with
    // all bits set) is handled by `readVarintNumber`, which re-interprets
    // values above MAX_SAFE_INTEGER as signed int64.
    let big = BigInt(lo);
    let shift = 28n;
    for (let i = 0; i < 6; i++) {
      this.checkAvailable(1);
      const b = await this.src.readByte();
      big |= BigInt(b & 0x7f) << shift;
      shift += 7n;
      if ((b & 0x80) === 0) {
        // The 10th byte lands at shift 63, so stray high bits in malformed
        // input would push the value past 2^64; keep only the low 64 bits.
        big = BigInt.asUintN(64, big);
        return big <= MAX_SAFE_BIG ? Number(big) : big;
      }
    }
    throw new Error("Varint exceeds 10 bytes.");
  }

  /**
   * Read a varint into a number.
   *
   * Lengths and small unsigned counters fit easily in JS numbers. For int64
   * fields, protobuf encodes signed values as sign-extended unsigned varints:
   * for example, `int64 = -1` is the 10-byte varint `FF FF FF FF FF FF FF FF FF 01`,
   * which decodes as `2^64 - 1`. We handle that by reinterpreting any varint
   * larger than Number.MAX_SAFE_INTEGER as a signed int64 (subtract 2^64) and
   * returning it if the result fits in a JS number.
   *
   * This is correct for every ONNX field we read (dim_value, dims, ir_version,
   * model_version, opset versions, lengths) because none of them ever store an
   * unsigned magnitude above 2^53. Callers that need a non-negative value
   * (lengths, tags) must validate the result themselves.
   *
   * @throws Error if the value is outside the safe-integer range after reinterpretation.
   */
  async readVarintNumber(): Promise<number> {
    const v = await this.readVarint();
    if (typeof v === "number") return v;
    const signed = v > MAX_SAFE_BIG ? v - TWO_POW_64 : v;
    if (signed >= MIN_SAFE_BIG && signed <= MAX_SAFE_BIG) return Number(signed);
    throw new Error(`Varint out of safe integer range: ${v}`);
  }

  /**
   * Read a field tag and split it into field number and wire type.
   *
   * @throws Error if the tag value is not a valid uint32 (malformed input).
   */
  async readTag(): Promise<ProtoTag> {
    const v = await this.readVarintNumber();
    // `>>> 3` and `& 7` are only meaningful for values in [0, 2^32).
    if (v < 0 || v > MAX_TAG) throw new Error(`Invalid tag varint ${v}.`);
    return { field: v >>> 3, wire: v & 0x07 };
  }

  /**
   * Read a varint length `n`, then `n` payload bytes.
   *
   * @throws RangeError if the encoded length is negative.
   * @throws Error if the payload would extend past `end`.
   */
  async readLengthDelimited(): Promise<Uint8Array> {
    const len = await this.readVarintNumber();
    this.checkAvailable(len);
    return this.src.readBytes(len);
  }

  /** Read a length-delimited field and decode it as UTF-8. */
  async readString(): Promise<string> {
    return UTF8_DECODER.decode(await this.readLengthDelimited());
  }

  /** Read `n` raw bytes; throws if they would extend past `end`. */
  async readBytes(n: number): Promise<Uint8Array> {
    this.checkAvailable(n);
    return this.src.readBytes(n);
  }

  /** Read a single raw byte; throws if already at `end`. */
  async readByte(): Promise<number> {
    this.checkAvailable(1);
    return this.src.readByte();
  }

  /**
   * Read a 4-byte little-endian value (wire type I32) as an UNSIGNED 32-bit
   * integer. To decode `sfixed32`, reinterpret the result with `v | 0`.
   */
  async readI32LE(): Promise<number> {
    const b = await this.readBytes(4);
    return new DataView(b.buffer, b.byteOffset, 4).getUint32(0, true);
  }

  /**
   * Read an 8-byte little-endian value (wire type I64) as an UNSIGNED 64-bit
   * bigint. To decode `sfixed64`, reinterpret the result with `BigInt.asIntN(64, v)`.
   */
  async readI64LE(): Promise<bigint> {
    const b = await this.readBytes(8);
    return new DataView(b.buffer, b.byteOffset, 8).getBigUint64(0, true);
  }

  /** Read a 4-byte little-endian IEEE-754 single-precision float. */
  async readFloat32(): Promise<number> {
    const b = await this.readBytes(4);
    return new DataView(b.buffer, b.byteOffset, 4).getFloat32(0, true);
  }

  /**
   * Skip the payload of a field whose tag has already been read. LEN payloads
   * are skipped through `ByteSource.skipBytes`, so large blobs are not read.
   *
   * @throws Error for group wire types, unknown wire types, or payloads that
   *         extend past `end`.
   */
  async skipField(wire: number): Promise<void> {
    switch (wire) {
      case WIRE_VARINT: {
        // Read until MSB cleared, max 10 bytes.
        for (let i = 0; i < 10; i++) {
          this.checkAvailable(1);
          const b = await this.src.readByte();
          if ((b & 0x80) === 0) return;
        }
        throw new Error("Varint runaway during skip.");
      }
      case WIRE_I64:
        this.checkAvailable(8);
        await this.src.skipBytes(8);
        return;
      case WIRE_I32:
        this.checkAvailable(4);
        await this.src.skipBytes(4);
        return;
      case WIRE_LEN: {
        const len = await this.readVarintNumber();
        this.checkAvailable(len);
        await this.src.skipBytes(len);
        return;
      }
      case WIRE_SGROUP:
      case WIRE_EGROUP:
        throw new Error("Group wire types are not supported (deprecated).");
      default:
        throw new Error(`Unknown wire type ${wire}.`);
    }
  }

  /** Convenience: skip everything until `end` (no-op if already there). */
  async skipToEnd(): Promise<void> {
    if (this.src.pos < this.endPos) {
      await this.src.skipBytes(this.endPos - this.src.pos);
    }
  }

  /**
   * Enforce this reader's message boundary before consuming `n` bytes.
   * Guards against corrupt lengths reading into sibling or parent data.
   *
   * @throws RangeError if `n` is negative or not an integer.
   * @throws Error if `pos + n` would exceed `end`.
   */
  private checkAvailable(n: number): void {
    if (!Number.isInteger(n) || n < 0) {
      throw new RangeError(`Invalid byte count ${n}.`);
    }
    if (this.src.pos + n > this.endPos) {
      throw new Error(
        `Read of ${n} byte(s) at offset ${this.src.pos} runs past end of message at ${this.endPos} (EOM).`
      );
    }
  }
}