Spaces:
Sleeping
Sleeping
| // Async protobuf wire-format reader on top of a `ByteSource`. | |
| // | |
| // We implement just what we need to parse ONNX: | |
| // - tag : varint of (field_number << 3) | wire_type | |
| // - varint : LEB128, up to 10 bytes | |
| // - I32/I64: fixed little-endian | |
| // - LEN : length-delimited, varint length followed by N bytes | |
| // | |
| // We expose helpers to read primitives and a `skipField(wireType)` helper | |
| // that can stream-skip large LEN fields via the ByteSource's optimized | |
| // skipBytes. This is the foundation for the ONNX raw_data optimization. | |
| import { ByteSource } from "./byteSource"; | |
// Protobuf wire types: the value stored in the low 3 bits of every field tag.
export const WIRE_VARINT = 0, // LEB128 varint
  WIRE_I64 = 1, // 8-byte little-endian fixed-width value
  WIRE_LEN = 2, // varint length prefix followed by that many bytes
  WIRE_SGROUP = 3, // start group (deprecated)
  WIRE_EGROUP = 4, // end group (deprecated)
  WIRE_I32 = 5; // 4-byte little-endian fixed-width value
/** A decoded field key, split into its wire type and field number. */
export interface ProtoTag {
  /** Wire type: the low 3 bits of the tag varint. */
  wire: number;
  /** Field number: the tag varint shifted right by 3. */
  field: number;
}
/**
 * Result of decoding a varint of at most 64 bits: a plain `number` when the
 * value is exactly representable as a JS number, a `bigint` otherwise.
 */
export type Varint = bigint | number;
/**
 * Async reader for the protobuf wire format, layered on a `ByteSource`.
 *
 * Every read and skip is bounded by an absolute, exclusive end position
 * (`end`). An operation that would cross it rejects instead of silently
 * consuming bytes that belong to an enclosing message. Sub-readers created
 * with `subreader()` share the same underlying source, and therefore the same
 * position, as their parent.
 */
export class ProtoReader {
  /** Shared decoder; `decode()` without `{ stream: true }` keeps no state between calls. */
  private static readonly utf8 = new TextDecoder("utf-8");
  private static readonly maxSafeBig = BigInt(Number.MAX_SAFE_INTEGER);
  private static readonly minSafeBig = BigInt(Number.MIN_SAFE_INTEGER);
  private static readonly twoTo64 = 1n << 64n;

  private readonly src: ByteSource;
  /** Absolute end position (exclusive). Reads that would cross it throw. */
  private readonly endPos: number;

  /**
   * @param src    Byte source to read from. Its position is shared, not copied.
   * @param endPos Absolute exclusive end bound; defaults to `src.totalSize`.
   */
  constructor(src: ByteSource, endPos?: number) {
    this.src = src;
    this.endPos = endPos ?? src.totalSize;
  }

  /** Current absolute position of the underlying source. */
  get pos(): number {
    return this.src.pos;
  }

  /** Absolute exclusive end bound of this reader. */
  get end(): number {
    return this.endPos;
  }

  /** Whether any bytes remain before `end`. */
  get hasMore(): boolean {
    return this.src.pos < this.endPos;
  }

  /**
   * Create a sub-reader bounded to the next `length` bytes (ends at `pos + length`).
   *
   * @throws Error if `length` is not a non-negative integer, or if the range
   *   extends past this reader's end (a malformed nested length).
   */
  subreader(length: number): ProtoReader {
    this.ensureAvailable(length);
    return new ProtoReader(this.src, this.src.pos + length);
  }

  /**
   * Read an unsigned varint of up to 10 bytes, truncated to 64 bits.
   *
   * @returns a number when the value is <= Number.MAX_SAFE_INTEGER, otherwise a bigint.
   * @throws Error if the varint is longer than 10 bytes or crosses `end`.
   */
  async readVarint(): Promise<Varint> {
    // Fast path: 4 bytes carry 28 data bits. `(b & 0x7f) << 21` is still a
    // positive int32, so 32-bit bitwise ops are exact here.
    let lo = 0;
    for (let i = 0; i < 4; i++) {
      const b = await this.nextByte();
      lo = (lo | ((b & 0x7f) << (i * 7))) >>> 0;
      if ((b & 0x80) === 0) return lo;
    }
    // Slow path: bytes 5..10 (shifts 28..63) would overflow 32-bit bitwise
    // ops, so accumulate in BigInt.
    let big = BigInt(lo);
    for (let shift = 28n; shift <= 63n; shift += 7n) {
      const b = await this.nextByte();
      big |= BigInt(b & 0x7f) << shift;
      if ((b & 0x80) === 0) {
        // The 10th byte (shift 63) can carry bits above 2^64; drop them so the
        // result is a genuine 64-bit value. A maximal int64 = -1 then decodes
        // to exactly 2^64 - 1.
        big = BigInt.asUintN(64, big);
        return big <= ProtoReader.maxSafeBig ? Number(big) : big;
      }
    }
    throw new Error("Varint exceeds 10 bytes.");
  }

  /**
   * Read a varint into a number.
   *
   * Lengths and small unsigned counters fit easily in JS numbers. For int64
   * fields, protobuf encodes signed values as sign-extended unsigned varints:
   * for example, `int64 = -1` is the 10-byte varint `FF FF FF FF FF FF FF FF FF 01`,
   * which decodes as `2^64 - 1`. Any varint larger than Number.MAX_SAFE_INTEGER
   * is therefore reinterpreted as a signed int64 (subtract 2^64) and returned
   * if the result fits in a JS number.
   *
   * The original author notes this is correct for every ONNX field read
   * (dim_value, dims, ir_version, model_version, opset versions, lengths),
   * because none of them store an unsigned magnitude above 2^53.
   *
   * @throws Error if the value is outside the safe-integer range even after
   *   signed reinterpretation.
   */
  async readVarintNumber(): Promise<number> {
    const v = await this.readVarint();
    if (typeof v === "number") return v;
    const signed = v > ProtoReader.maxSafeBig ? v - ProtoReader.twoTo64 : v;
    if (signed >= ProtoReader.minSafeBig && signed <= ProtoReader.maxSafeBig) {
      return Number(signed);
    }
    throw new Error(`Varint out of safe integer range: ${v}`);
  }

  /**
   * Read a field tag and split it into field number and wire type.
   *
   * @throws Error if the tag varint is negative or larger than 32 bits.
   */
  async readTag(): Promise<ProtoTag> {
    const v = await this.readVarintNumber();
    // `>>>` and `&` operate on 32-bit integers, so larger or negative values
    // would wrap silently into a bogus field number.
    if (v < 0 || v > 0xffffffff) {
      throw new Error(`Invalid field tag: ${v}.`);
    }
    return { field: v >>> 3, wire: v & 0x07 };
  }

  /**
   * Read a LEN field's varint length prefix and then its payload.
   *
   * @throws Error if the length is negative or the payload would cross `end`.
   */
  async readLengthDelimited(): Promise<Uint8Array> {
    const len = await this.readLength();
    return this.src.readBytes(len);
  }

  /** Read a LEN field and decode its payload as UTF-8 (invalid sequences become U+FFFD). */
  async readString(): Promise<string> {
    const bytes = await this.readLengthDelimited();
    return ProtoReader.utf8.decode(bytes);
  }

  /** Read exactly `n` raw bytes; rejects if they would cross `end`. */
  async readBytes(n: number): Promise<Uint8Array> {
    this.ensureAvailable(n);
    return this.src.readBytes(n);
  }

  /** Read a single raw byte; rejects at `end`. */
  async readByte(): Promise<number> {
    return this.nextByte();
  }

  /** Read 4 little-endian bytes as an unsigned 32-bit integer. */
  async readI32LE(): Promise<number> {
    const b = await this.readBytes(4);
    return new DataView(b.buffer, b.byteOffset, 4).getUint32(0, true);
  }

  /** Read 8 little-endian bytes as an unsigned 64-bit bigint. */
  async readI64LE(): Promise<bigint> {
    const b = await this.readBytes(8);
    return new DataView(b.buffer, b.byteOffset, 8).getBigUint64(0, true);
  }

  /** Read 4 little-endian bytes as a float32. */
  async readFloat32(): Promise<number> {
    const b = await this.readBytes(4);
    return new DataView(b.buffer, b.byteOffset, 4).getFloat32(0, true);
  }

  /**
   * Skip a single field value of the given wire type (the tag must already
   * have been read).
   *
   * @throws Error for group wire types, unknown wire types, runaway varints,
   *   or skips that would cross `end`.
   */
  async skipField(wire: number): Promise<void> {
    switch (wire) {
      case WIRE_VARINT: {
        // Consume bytes until one has its continuation bit clear (max 10).
        for (let i = 0; i < 10; i++) {
          const b = await this.nextByte();
          if ((b & 0x80) === 0) return;
        }
        throw new Error("Varint runaway during skip.");
      }
      case WIRE_I64:
        await this.skipChecked(8);
        return;
      case WIRE_I32:
        await this.skipChecked(4);
        return;
      case WIRE_LEN:
        // Go through the source's skipBytes so large payloads are not
        // materialized in memory.
        await this.skipChecked(await this.readVarintNumber());
        return;
      case WIRE_SGROUP:
      case WIRE_EGROUP:
        throw new Error("Group wire types are not supported (deprecated).");
      default:
        throw new Error(`Unknown wire type ${wire}.`);
    }
  }

  /** Skip everything until `end`; a no-op if already at or past it. */
  async skipToEnd(): Promise<void> {
    if (this.src.pos < this.endPos) {
      await this.src.skipBytes(this.endPos - this.src.pos);
    }
  }

  /** Read one byte from the source after checking it lies before `end`. */
  private async nextByte(): Promise<number> {
    this.ensureAvailable(1);
    return this.src.readByte();
  }

  /** Read a LEN length prefix and check that the payload fits before `end`. */
  private async readLength(): Promise<number> {
    const len = await this.readVarintNumber();
    this.ensureAvailable(len);
    return len;
  }

  /** Bounds-check, then skip `n` bytes through the source's skipBytes. */
  private async skipChecked(n: number): Promise<void> {
    this.ensureAvailable(n);
    await this.src.skipBytes(n);
  }

  /**
   * Throw unless `n` is a non-negative safe integer and `n` bytes remain
   * before `end`. Guards against malformed (negative or oversized) lengths.
   */
  private ensureAvailable(n: number): void {
    if (!Number.isSafeInteger(n) || n < 0) {
      throw new Error(`Invalid byte count: ${n}.`);
    }
    const pos = this.src.pos;
    if (pos + n > this.endPos) {
      throw new Error(
        `Unexpected end of message: ${n} byte(s) requested at ${pos}, but the message ends at ${this.endPos}.`,
      );
    }
  }
}