File size: 12,713 Bytes
fc01079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
// A seekable, streaming byte source backed by HTTP Range requests.
//
// The streaming ONNX parser uses this to:
//   1. Read protobuf wire bytes sequentially.
//   2. Skip over uninteresting regions (notably `raw_data` of initializers)
//      without buffering them in RAM and, when the skip is large enough,
//      without even pulling them from the network — we cancel the current
//      stream and issue a fresh Range request that jumps over the region.
//
// Design goals:
//   - Always know the current absolute offset in the file.
//   - Pull more data from the network only when needed.
//   - When a skip is small (<= skipRangeThreshold), discard from the
//     in-flight stream; when large, cancel and re-issue Range from the
//     post-skip offset (saves bandwidth).
//   - Report total bytes actually transferred so the UI can show savings.
//   - Be cancellable from the outside (AbortController).

/** Optional tuning knobs and hooks accepted by `ByteSource`. */
export interface ByteSourceOpts {
  /**
   * Abort signal covering the whole source; it is forwarded to every
   * network request the source issues.
   */
  signal?: AbortSignal;
  /**
   * Skip size, in bytes, above which the source cancels the current stream
   * and re-issues a Range request instead of reading and discarding.
   */
  skipRangeThreshold?: number;
  /**
   * Progress hook invoked whenever more data arrives.
   * `deltaBytes` is the size of the chunk just received and
   * `totalTransferred` the running total so far.
   */
  onTransfer?: (deltaBytes: number, totalTransferred: number) => void;
}

// Skip-vs-stream break-even. Below this, streaming-skip is cheaper than
// paying a fresh Range request RTT (~30–200 ms on Hub CDN, throughput
// typically 10–30 MB/s).
//
// We pick the threshold adaptively from the total file size — see
// `pickAdaptiveThreshold()` below — because the right answer is very
// different between a 50 MB quantized GPT-2 (lots of ~350 KB int8 tensors,
// we *want* to Range-jump even small blobs) and a 5 GB fp16 Llama (we want
// to amortize fewer, bigger jumps).
//
// 1 MB is the safe ceiling: above that, paying a fresh RTT is always
// worthwhile because we'd otherwise download a megabyte of dead bytes.
const SKIP_RANGE_CEILING = 1024 * 1024; // 1 MB
const SKIP_RANGE_FLOOR = 32 * 1024; // 32 KB

/**
 * Pick a per-file skip threshold from the total file size.
 *
 * Rules of thumb the formula encodes:
 *  - Tiny files (< 50 MB, typical quantized BERT-class): aggressive 32 KB
 *    floor. We accept dozens of RTTs because the total file is small enough
 *    that each saved KB matters and TLS reuse keeps each RTT cheap.
 *  - Medium (50 MB – 500 MB, quantized GPT-2 / DistilBERT / ViT class):
 *    linear ramp from the 32 KB floor (at 50 MB) up to the 1 MB ceiling
 *    (at 500 MB), so the threshold grows with the file.
 *  - Large (≥ 500 MB, fp16/fp32 LLaMA-class): full 1 MB ceiling. Tensors
 *    are big enough that RTT cost is dwarfed by the bandwidth save.
 *
 * @param totalSize Total file size in bytes.
 * @returns The skip threshold in bytes, always an integer within
 *   [SKIP_RANGE_FLOOR, SKIP_RANGE_CEILING]. An unknown or unusable size
 *   (zero, negative, or NaN) yields the ceiling.
 */
export function pickAdaptiveThreshold(totalSize: number): number {
  const MB = 1024 * 1024;
  const rampStart = 50 * MB;
  const rampEnd = 500 * MB;

  // `!(x > 0)` is true for zero, negatives AND NaN. A NaN size must not leak
  // through as a NaN threshold: every `n > NaN` comparison is false, which
  // would silently disable Range-jumping altogether.
  if (!(totalSize > 0)) return SKIP_RANGE_CEILING;
  if (totalSize < rampStart) return SKIP_RANGE_FLOOR;
  if (totalSize >= rampEnd) return SKIP_RANGE_CEILING;

  // Linear ramp 32 KB → 1 MB over the 50 MB → 500 MB band.
  const t = (totalSize - rampStart) / (rampEnd - rampStart);
  return Math.round(SKIP_RANGE_FLOOR + t * (SKIP_RANGE_CEILING - SKIP_RANGE_FLOOR));
}

/**
 * Forward-streaming, seekable reader over a remote file, backed by HTTP Range
 * requests issued through the global `fetch`.
 *
 * The source keeps one Range stream open at a time and a queue of chunks
 * that were received but not yet consumed. Callers read (`readBytes`,
 * `readByte`) or skip (`skipBytes`) relative to `pos`; assigning `pos`
 * directly acts as a seek and is picked up on the next call.
 *
 * Internal invariant (while the buffer is "in sync"):
 *   - `pos === bufferStart + bufferHead`
 *   - the queued chunks are contiguous and start at `bufferStart`
 *   - the next byte the open stream delivers sits at
 *     `bufferStart + bufferedBytes`
 */
export class ByteSource {
  /** URL every Range request is sent to. */
  readonly url: string;
  /** Total size of the remote file in bytes. */
  readonly totalSize: number;
  private opts: Required<Omit<ByteSourceOpts, "signal" | "onTransfer">> & {
    signal?: AbortSignal;
    onTransfer?: ByteSourceOpts["onTransfer"];
  };

  /** Reader of the currently open Range response, or null when no stream is open. */
  private reader: ReadableStreamDefaultReader<Uint8Array> | null = null;
  /** Absolute offset where the current Range request is supposed to end (exclusive). */
  private streamEnd = 0;

  /** Queue of chunks pulled from the reader but not yet fully consumed. */
  private buffer: Uint8Array[] = [];
  /** Absolute offset of buffer[0][0]. */
  private bufferStart = 0;
  /** Offset within buffer[0] of the next unconsumed byte. */
  private bufferHead = 0;
  /** Total bytes held across all buffer chunks (including the consumed part of buffer[0]). */
  private bufferedBytes = 0;

  /** Absolute offset of the next byte to be returned to the caller. */
  pos = 0;

  /** Total bytes received over the network so far. */
  bytesTransferred = 0;

  /**
   * @param url URL of the remote file.
   * @param totalSize Size of the remote file in bytes.
   * @param opts Optional threshold override, abort signal and progress hook.
   *   When no threshold is given it is derived from `totalSize`.
   */
  constructor(url: string, totalSize: number, opts: ByteSourceOpts = {}) {
    this.url = url;
    this.totalSize = totalSize;
    this.opts = {
      skipRangeThreshold:
        opts.skipRangeThreshold ?? pickAdaptiveThreshold(totalSize),
      signal: opts.signal,
      onTransfer: opts.onTransfer,
    };
  }

  /** Exposed for the UI: what threshold are we actually using? */
  get skipRangeThreshold(): number {
    return this.opts.skipRangeThreshold;
  }

  /** Bytes available to read without pulling more from the network. */
  private get availableInBuffer(): number {
    return this.bufferedBytes - this.bufferHead;
  }

  /**
   * True while the buffer bookkeeping still describes `pos`. It turns false
   * when a caller assigned `pos` directly (a seek in either direction).
   */
  private get bufferInSync(): boolean {
    return this.pos === this.bufferStart + this.bufferHead;
  }

  /**
   * Open a Range stream starting at `from` (absolute offset), ending at `end`
   * (exclusive, defaults to totalSize). Any stream already open is cancelled
   * first. When the requested range is empty no request is made.
   *
   * @throws Error when the server answers with anything other than 206
   *   (or 200 for a request starting at offset 0), or without a body.
   */
  private async openRangeStream(from: number, end?: number): Promise<void> {
    if (this.reader) await this.cancelStream();
    const stop = end ?? this.totalSize;
    if (from >= stop) {
      this.streamEnd = stop;
      return;
    }
    const rangeHeader = `bytes=${from}-${stop - 1}`;
    const res = await fetch(this.url, {
      headers: { Range: rangeHeader },
      signal: this.opts.signal,
    });
    // A plain 200 means the server ignored the Range header and is sending
    // the whole file; that is only usable when we asked to start at byte 0.
    if (!(res.status === 206 || (res.status === 200 && from === 0))) {
      // Release the unwanted body so the connection is not held open
      // (best effort: a failing cancel must not mask the real error).
      void res.body?.cancel().catch(() => undefined);
      throw new Error(
        `Range request failed (status ${res.status}). The CDN may not support Range; falling back is not implemented.`,
      );
    }
    if (!res.body) throw new Error("Response has no body.");
    this.reader = res.body.getReader();
    this.streamEnd = stop;
  }

  /** Cancel the open stream, if any. Cancellation errors are ignored. */
  private async cancelStream(): Promise<void> {
    if (!this.reader) return;
    try {
      await this.reader.cancel();
    } catch {
      // ignored: cancel failures are non-fatal.
    }
    this.reader = null;
  }

  /**
   * Make sure at least `n` bytes are available in the buffer, pulling from
   * the stream as needed.
   *
   * @throws Error when `pos + n` lies past the end of the file, or when the
   *   stream ends before the requested bytes arrived.
   */
  private async ensureAvailable(n: number): Promise<void> {
    if (this.pos + n > this.totalSize) {
      throw new Error(
        `Read past end of file: pos=${this.pos}, want=${n}, total=${this.totalSize}`,
      );
    }
    if (!this.reader || !this.bufferInSync) {
      // Stream is closed or a seek happened (in either direction): the
      // buffered bytes no longer describe `pos`, so (re)open at current pos.
      await this.openRangeStream(this.pos);
      this.resetBuffer(this.pos);
    }
    while (this.availableInBuffer < n) {
      if (!this.reader) {
        // Stream ended but we expected more. If it was a bounded range that
        // ran to its end, continue with a new range right after it.
        if (
          this.pos + this.availableInBuffer >= this.streamEnd &&
          this.streamEnd < this.totalSize
        ) {
          await this.openRangeStream(this.pos + this.availableInBuffer);
          continue;
        }
        throw new Error("Stream ended unexpectedly before requested bytes arrived.");
      }
      const { done, value } = await this.reader.read();
      if (done) {
        this.reader = null;
        if (this.pos + this.availableInBuffer >= this.streamEnd && this.streamEnd < this.totalSize) {
          await this.openRangeStream(this.pos + this.availableInBuffer);
          continue;
        }
        if (this.availableInBuffer < n) {
          throw new Error("Stream EOF before requested bytes arrived.");
        }
        break;
      }
      if (!value || value.byteLength === 0) continue;
      this.buffer.push(value);
      this.bufferedBytes += value.byteLength;
      this.bytesTransferred += value.byteLength;
      this.opts.onTransfer?.(value.byteLength, this.bytesTransferred);
    }
  }

  /** Drop all buffered chunks and declare the (empty) buffer to start at `startPos`. */
  private resetBuffer(startPos: number): void {
    this.buffer = [];
    this.bufferedBytes = 0;
    this.bufferHead = 0;
    this.bufferStart = startPos;
  }

  /**
   * Consume `n` already-buffered bytes and advance `pos` by `n`, optionally
   * copying them into `out` (starting at index 0). The caller guarantees
   * `n <= availableInBuffer`.
   */
  private consume(n: number, out?: Uint8Array): void {
    let consumed = 0;
    while (consumed < n) {
      const head = this.buffer[0];
      const take = Math.min(head.byteLength - this.bufferHead, n - consumed);
      out?.set(head.subarray(this.bufferHead, this.bufferHead + take), consumed);
      consumed += take;
      this.bufferHead += take;
      if (this.bufferHead >= head.byteLength) {
        // Head chunk exhausted: drop it and move the buffer origin forward.
        this.buffer.shift();
        this.bufferedBytes -= head.byteLength;
        this.bufferStart += head.byteLength;
        this.bufferHead = 0;
      }
    }
    this.pos += n;
  }

  /**
   * Move `pos` to `target` without transferring anything: cancel the open
   * stream (if any) and empty the buffer. The next read opens a fresh Range
   * request at `target`.
   */
  private async jumpTo(target: number): Promise<void> {
    await this.cancelStream();
    this.pos = target;
    this.resetBuffer(target);
  }

  /**
   * Read exactly `n` bytes into a fresh Uint8Array (concatenating internal
   * chunks if needed). Returns an empty array when `n <= 0`.
   *
   * @throws Error when the read would go past the end of the file or the
   *   stream ends early.
   */
  async readBytes(n: number): Promise<Uint8Array> {
    if (n <= 0) return new Uint8Array(0);
    await this.ensureAvailable(n);
    const out = new Uint8Array(n);
    this.consume(n, out);
    return out;
  }

  /**
   * Read a single byte (faster path for varint decoding).
   *
   * @throws Error when `pos` is at the end of the file or the stream ends early.
   */
  async readByte(): Promise<number> {
    // The fast path is only valid while the buffer still describes `pos`;
    // after an external seek the buffered bytes belong to another offset.
    if (this.availableInBuffer === 0 || !this.bufferInSync) {
      await this.ensureAvailable(1);
    }
    const head = this.buffer[0];
    const b = head[this.bufferHead++];
    if (this.bufferHead >= head.byteLength) {
      this.buffer.shift();
      this.bufferedBytes -= head.byteLength;
      this.bufferStart += head.byteLength;
      this.bufferHead = 0;
    }
    this.pos += 1;
    return b;
  }

  /**
   * Advance the read position by `n` bytes without exposing them to the caller.
   * Bytes already buffered are dropped first; for the remainder:
   * - If a stream is open and n <= skipRangeThreshold, we pull/discard those
   *   bytes from the current stream (cheap, no RTT).
   * - If n > skipRangeThreshold, we cancel the current stream and jump
   *   forward; the next read issues a new Range request. This is how we save
   *   real bandwidth on big initializers.
   * - If no stream is open at all, we always jump: the next read has to open
   *   a request anyway, so discarding would only download dead bytes.
   *
   * Does nothing when `n <= 0`.
   *
   * @throws Error when the skip would go past the end of the file, or when
   *   the stream ends while bytes are still being discarded.
   */
  async skipBytes(n: number): Promise<void> {
    if (n <= 0) return;
    if (this.pos + n > this.totalSize) {
      throw new Error(
        `Skip past end of file: pos=${this.pos}, skip=${n}, total=${this.totalSize}`,
      );
    }

    if (!this.bufferInSync) {
      // External seek: buffered bytes and the in-flight stream belong to a
      // different offset and must not be counted against this skip.
      await this.cancelStream();
      this.resetBuffer(this.pos);
    }

    // First, eat what we already have buffered.
    const buffered = Math.min(this.availableInBuffer, n);
    if (buffered > 0) {
      this.consume(buffered);
      n -= buffered;
      if (n === 0) return;
    }

    if (!this.reader || n > this.opts.skipRangeThreshold) {
      // Big skip (or nothing in flight): drop the stream and jump forward.
      await this.jumpTo(this.pos + n);
      return;
    }

    // Small skip: stream-discard.
    while (n > 0) {
      const reader = this.reader;
      if (!reader) {
        // The stream went away mid-skip; jump over whatever is left.
        await this.jumpTo(this.pos + n);
        return;
      }
      const { done, value } = await reader.read();
      if (done) {
        this.reader = null;
        // A bounded range that ran to its end: handled by the jump above on
        // the next iteration. Otherwise the body was shorter than promised.
        if (this.streamEnd < this.totalSize) continue;
        throw new Error("Stream EOF during skip.");
      }
      if (!value || value.byteLength === 0) continue;
      this.bytesTransferred += value.byteLength;
      this.opts.onTransfer?.(value.byteLength, this.bytesTransferred);
      if (value.byteLength <= n) {
        // Entire chunk discarded. The buffer is empty here, so its origin
        // simply follows `pos` (keeps the seek check accurate).
        this.pos += value.byteLength;
        this.bufferStart = this.pos;
        n -= value.byteLength;
      } else {
        // Partial chunk: keep the unused tail in the buffer.
        const tail = value.subarray(n);
        this.pos += n;
        this.resetBuffer(this.pos);
        this.buffer = [tail];
        this.bufferedBytes = tail.byteLength;
        n = 0;
      }
    }
  }

  /** Cancel any pending network and free buffers. `pos` is left unchanged. */
  async close(): Promise<void> {
    await this.cancelStream();
    this.resetBuffer(this.pos);
  }

  /**
   * Build a ByteSource for `url`, discovering the total length via a HEAD
   * request.
   *
   * @throws Error when the HEAD request fails, when no positive finite
   *   Content-Length is returned, or when the server advertises
   *   `Accept-Ranges: none`.
   */
  static async create(url: string, opts: ByteSourceOpts = {}): Promise<ByteSource> {
    const head = await fetch(url, { method: "HEAD", signal: opts.signal });
    if (!head.ok) throw new Error(`HEAD ${url} failed: ${head.status}`);
    const len = Number(head.headers.get("content-length") ?? 0);
    if (!Number.isFinite(len) || len <= 0) {
      throw new Error(`Server did not return a usable Content-Length for ${url}.`);
    }
    const acceptRanges = head.headers.get("accept-ranges");
    if (acceptRanges && acceptRanges.toLowerCase() === "none") {
      throw new Error(`Server says it does not support Range requests for ${url}.`);
    }
    return new ByteSource(url, len, opts);
  }
}