// Async protobuf wire-format reader on top of a `ByteSource`.
//
// We implement just what we need to parse ONNX:
//   - tag    : varint of (field_number << 3) | wire_type
//   - varint : LEB128, up to 10 bytes
//   - I32/I64: fixed little-endian
//   - LEN    : length-delimited, varint length followed by N bytes
//
// We expose helpers to read primitives and a `skipField(wireType)` helper
// that can stream-skip large LEN fields via the ByteSource's optimized
// skipBytes. This is the foundation for the ONNX raw_data optimization.
//
// Every read is bounds-checked against the reader's end position, so a
// corrupt length prefix inside a nested message fails loudly instead of
// silently consuming bytes that belong to the enclosing message.

import type { ByteSource } from "./byteSource";

export const WIRE_VARINT = 0;
export const WIRE_I64 = 1;
export const WIRE_LEN = 2;
export const WIRE_SGROUP = 3; // deprecated
export const WIRE_EGROUP = 4; // deprecated
export const WIRE_I32 = 5;

export interface ProtoTag {
  field: number;
  wire: number;
}

/** Decoded varint up to 64 bits. Values that fit in a JS number are returned as number, otherwise bigint. */
export type Varint = number | bigint;

/** Maximum encoded size of a 64-bit varint (ceil(64 / 7)). */
const MAX_VARINT_BYTES = 10;
/** 2^64, used to reinterpret an unsigned 64-bit varint as a signed int64. */
const TWO_POW_64 = 1n << 64n;
const MAX_SAFE_BIG = BigInt(Number.MAX_SAFE_INTEGER);
const MIN_SAFE_BIG = BigInt(Number.MIN_SAFE_INTEGER);
/** A valid tag is (field << 3) | wire with field <= 2^29 - 1, so it fits in a uint32. */
const MAX_TAG = 0xffffffff;
/** Shared decoder: one-shot `decode()` calls (no `stream` option) keep no state between calls. */
const UTF8_DECODER = new TextDecoder("utf-8");

export class ProtoReader {
  private readonly src: ByteSource;
  /** Absolute end position (exclusive). Reads that would cross it throw. */
  private readonly endPos: number;

  /**
   * @param src    Byte source to read from; its current position is the read cursor.
   * @param endPos Absolute end position (exclusive). Defaults to `src.totalSize`.
   */
  constructor(src: ByteSource, endPos?: number) {
    this.src = src;
    this.endPos = endPos ?? src.totalSize;
  }

  get pos(): number {
    return this.src.pos;
  }

  get end(): number {
    return this.endPos;
  }

  get hasMore(): boolean {
    return this.src.pos < this.endPos;
  }

  /**
   * Create a sub-reader whose end is `pos + length`.
   *
   * The sub-reader shares this reader's ByteSource, so reading from it also
   * advances this reader's position.
   *
   * @throws if `length` is negative or non-integer, or if the sub-range
   *         would extend past this reader's end.
   */
  subreader(length: number): ProtoReader {
    this.ensureAvailable(length);
    return new ProtoReader(this.src, this.src.pos + length);
  }

  /** Read an unsigned varint. Returns number if it fits, bigint otherwise. */
  async readVarint(): Promise<Varint> {
    // Fast path: stay in JS number-land for varints whose data fits in
    // 28 bits. We read up to 4 bytes here (4 × 7 = 28 bits), which is safe
    // because `<< 21` still produces a non-negative 32-bit signed int.
    // The fifth byte at shift=28 risks truncating the upper data bits
    // (JS bitwise ops are 32-bit), so we switch to BigInt from byte 5 on.
    let lo = 0;
    for (let i = 0; i < 4; i++) {
      this.ensureAvailable(1);
      const b = await this.src.readByte();
      lo = (lo | ((b & 0x7f) << (i * 7))) >>> 0;
      if ((b & 0x80) === 0) return lo;
    }

    // Slow path: BigInt for bytes 5..10. Values above MAX_SAFE_INTEGER stay
    // bigint. Signed int64 values (e.g. `int64 = -1`, a 10-byte varint with
    // all bits set) are reinterpreted by `readVarintNumber`.
    let big = BigInt(lo);
    let shift = 28n;
    for (let i = 4; i < MAX_VARINT_BYTES; i++) {
      this.ensureAvailable(1);
      const b = await this.src.readByte();
      big |= BigInt(b & 0x7f) << shift;
      shift += 7n;
      if ((b & 0x80) === 0) {
        // The 10th byte may only contribute bit 63; drop any higher bits the
        // way uint64 wrap-around does in reference protobuf decoders.
        big = BigInt.asUintN(64, big);
        return big <= MAX_SAFE_BIG ? Number(big) : big;
      }
    }
    throw new Error("Varint exceeds 10 bytes.");
  }

  /**
   * Read a varint into a number.
   *
   * Lengths and small unsigned counters fit easily in JS numbers. For int64
   * fields, protobuf encodes signed values via sign-extended unsigned varints:
   * for example, `int64 = -1` is the 10-byte varint `FF FF FF FF FF FF FF FF FF 01`,
   * which decodes as `2^64 - 1`. We handle that by reinterpreting any varint
   * larger than Number.MAX_SAFE_INTEGER as a signed int64 (subtract 2^64) and
   * returning it if the result fits in a JS number.
   *
   * This is correct for every ONNX field we read (dim_value, dims, ir_version,
   * model_version, opset versions, lengths) because none of them ever store an
   * unsigned magnitude above 2^53. Callers that need a length must still
   * reject negative results (see `readLength`).
   *
   * @throws if the value is outside the safe-integer range after reinterpretation.
   */
  async readVarintNumber(): Promise<number> {
    const v = await this.readVarint();
    if (typeof v === "number") return v;
    const signed = v > MAX_SAFE_BIG ? v - TWO_POW_64 : v;
    if (signed >= MIN_SAFE_BIG && signed <= MAX_SAFE_BIG) return Number(signed);
    throw new Error(`Varint out of safe integer range: ${v}`);
  }

  /**
   * Read a field tag.
   *
   * @throws if the tag does not fit in 32 bits or encodes field number 0,
   *         both of which are invalid protobuf.
   */
  async readTag(): Promise<ProtoTag> {
    const v = await this.readVarintNumber();
    // Reject instead of letting `>>>` silently truncate an oversized/negative tag.
    if (v < 0 || v > MAX_TAG || v >>> 3 === 0) {
      throw new Error(`Invalid protobuf tag ${v}.`);
    }
    return { field: v >>> 3, wire: v & 0x07 };
  }

  /** Read a LEN field's payload (varint length prefix followed by that many bytes). */
  async readLengthDelimited(): Promise<Uint8Array> {
    const len = await this.readLength();
    return this.src.readBytes(len);
  }

  /** Read a LEN field and decode its payload as UTF-8. */
  async readString(): Promise<string> {
    const bytes = await this.readLengthDelimited();
    return UTF8_DECODER.decode(bytes);
  }

  /** Read exactly `n` raw bytes. */
  async readBytes(n: number): Promise<Uint8Array> {
    this.ensureAvailable(n);
    return this.src.readBytes(n);
  }

  /** Read a single raw byte. */
  async readByte(): Promise<number> {
    this.ensureAvailable(1);
    return this.src.readByte();
  }

  /** Read a little-endian fixed32 as an UNSIGNED 32-bit integer. */
  async readI32LE(): Promise<number> {
    const b = await this.readBytes(4);
    return new DataView(b.buffer, b.byteOffset, 4).getUint32(0, true);
  }

  /** Read a little-endian fixed64 as an UNSIGNED 64-bit bigint. */
  async readI64LE(): Promise<bigint> {
    const b = await this.readBytes(8);
    return new DataView(b.buffer, b.byteOffset, 8).getBigUint64(0, true);
  }

  /** Read a little-endian IEEE-754 float32. */
  async readFloat32(): Promise<number> {
    const b = await this.readBytes(4);
    return new DataView(b.buffer, b.byteOffset, 4).getFloat32(0, true);
  }

  /**
   * Skip a single field of the given wire type. LEN payloads are skipped via
   * the ByteSource's `skipBytes` without being read.
   *
   * @throws for group wire types (3, 4), unknown wire types, or fields that
   *         would extend past this reader's end.
   */
  async skipField(wire: number): Promise<void> {
    switch (wire) {
      case WIRE_VARINT: {
        // Read until MSB cleared, max 10 bytes.
        for (let i = 0; i < MAX_VARINT_BYTES; i++) {
          this.ensureAvailable(1);
          const b = await this.src.readByte();
          if ((b & 0x80) === 0) return;
        }
        throw new Error("Varint runaway during skip.");
      }
      case WIRE_I64:
        await this.skip(8);
        return;
      case WIRE_I32:
        await this.skip(4);
        return;
      case WIRE_LEN: {
        const len = await this.readLength();
        await this.skip(len);
        return;
      }
      case WIRE_SGROUP:
      case WIRE_EGROUP:
        throw new Error("Group wire types are not supported (deprecated).");
      default:
        throw new Error(`Unknown wire type ${wire}.`);
    }
  }

  /** Convenience: skip everything until `endPos`. */
  async skipToEnd(): Promise<void> {
    if (this.src.pos < this.endPos) {
      await this.src.skipBytes(this.endPos - this.src.pos);
    }
  }

  /** Read a LEN prefix, rejecting negative lengths and lengths past `endPos`. */
  private async readLength(): Promise<number> {
    const len = await this.readVarintNumber();
    this.ensureAvailable(len);
    return len;
  }

  /** Bounds-checked `skipBytes`. */
  private async skip(n: number): Promise<void> {
    this.ensureAvailable(n);
    await this.src.skipBytes(n);
  }

  /**
   * Throw unless `n` is a non-negative safe integer and the next `n` bytes lie
   * before `endPos`. Called before every read so a failed check leaves the
   * position unchanged.
   */
  private ensureAvailable(n: number): void {
    if (!Number.isSafeInteger(n) || n < 0) {
      throw new Error(`Invalid byte count ${n}.`);
    }
    const pos = this.src.pos;
    if (pos + n > this.endPos) {
      throw new Error(
        `Read of ${n} byte(s) at offset ${pos} runs past end of message (${this.endPos}).`
      );
    }
  }
}