// A seekable, streaming byte source backed by HTTP Range requests.
//
// The streaming ONNX parser uses this to:
// 1. Read protobuf wire bytes sequentially.
// 2. Skip over uninteresting regions (notably `raw_data` of initializers)
// without buffering them in RAM and, when the skip is large enough,
// without even pulling them from the network — we cancel the current
// stream and issue a fresh Range request that jumps over the region.
//
// Design goals:
// - Always know the current absolute offset in the file.
// - Pull more data from the network only when needed.
// - When a skip is small (< SKIP_RANGE_THRESHOLD), discard from the
// in-flight stream; when large, cancel and re-issue Range from the
// post-skip offset (saves bandwidth).
// - Report total bytes actually transferred so the UI can show savings.
// - Be cancellable from the outside (AbortController).
export interface ByteSourceOpts {
  /**
   * Skip length, in bytes, above which the source cancels the in-flight
   * stream and re-issues a Range request instead of reading and discarding.
   * When omitted, the source derives a value from the total file size.
   */
  skipRangeThreshold?: number;

  /** Abort signal forwarded to every request the source makes. */
  signal?: AbortSignal;

  /**
   * Progress hook, invoked each time a chunk arrives from the network with
   * the size of that chunk and the running total transferred so far.
   */
  onTransfer?: (deltaBytes: number, totalTransferred: number) => void;
}
// Skip-vs-stream break-even. Below this, streaming-skip is cheaper than
// paying a fresh Range request RTT (~30–200 ms on Hub CDN, throughput
// typically 10–30 MB/s).
//
// We pick the threshold adaptively from the total file size β see
// `pickAdaptiveThreshold()` below β because the right answer is very
// different between a 50 MB quantized GPT-2 (lots of ~350 KB int8 tensors,
// we *want* to Range-jump even small blobs) and a 5 GB fp16 Llama (we want
// to amortize fewer, bigger jumps).
//
// 1 MB is the safe ceiling: above that, paying a fresh RTT is always
// worthwhile because we'd otherwise download a megabyte of dead bytes.
/** Largest skip threshold ever used: 1 MB. */
const SKIP_RANGE_CEILING = 1024 * 1024;
/** Smallest skip threshold ever used: 32 KB. */
const SKIP_RANGE_FLOOR = 32 * 1024;

/**
 * Pick a per-file skip threshold from the total file size.
 *
 * Skips longer than the returned value are served by a fresh Range request;
 * shorter ones are read from the in-flight stream and discarded.
 *
 * What the formula encodes:
 *  - Unknown or invalid size (`<= 0` or `NaN`): the 1 MB ceiling.
 *  - Tiny files (< 50 MB, typical quantized BERT-class): the aggressive
 *    32 KB floor. Dozens of RTTs are acceptable because the file is small
 *    enough that each saved KB matters.
 *  - Medium (50 MB – 500 MB): a linear ramp from 32 KB at 50 MB up to 1 MB
 *    at 500 MB (for example, 528 KB at 275 MB).
 *  - Large (≥ 500 MB, fp16/fp32 LLaMA-class): the full 1 MB ceiling. Tensors
 *    are big enough that RTT cost is dwarfed by the bandwidth saved.
 *
 * NOTE(review): an earlier version of this comment said the medium band
 * interpolates "up to ~128 KB". The code has always ramped to 1 MB; confirm
 * which of the two was intended.
 *
 * @param totalSize Total size of the remote file in bytes.
 * @returns The threshold in bytes, an integer in [32 KB, 1 MB].
 */
export function pickAdaptiveThreshold(totalSize: number): number {
  const MB = 1024 * 1024;
  const rampStart = 50 * MB;
  const rampEnd = 500 * MB;

  // `!(x > 0)` rather than `x <= 0` so that NaN also lands here instead of
  // flowing through the arithmetic below and coming back out as NaN.
  if (!(totalSize > 0)) return SKIP_RANGE_CEILING;
  if (totalSize < rampStart) return SKIP_RANGE_FLOOR;
  if (totalSize >= rampEnd) return SKIP_RANGE_CEILING;

  // Linear ramp 32 KB → 1 MB over the 50 MB – 500 MB band.
  const t = (totalSize - rampStart) / (rampEnd - rampStart);
  return Math.round(SKIP_RANGE_FLOOR + t * (SKIP_RANGE_CEILING - SKIP_RANGE_FLOOR));
}
/**
 * Byte source over a remote file, fed by HTTP Range requests.
 *
 * Callers read sequentially with {@link ByteSource.readByte} /
 * {@link ByteSource.readBytes} and jump over regions with
 * {@link ByteSource.skipBytes}. Short skips are read from the open stream and
 * thrown away; skips longer than `skipRangeThreshold` cancel the stream so the
 * next read opens a new Range request at the post-skip offset.
 *
 * Internal invariant: between calls, `pos === bufferStart + bufferHead`.
 * Assigning `pos` from outside breaks that equality on purpose; the next
 * read or skip notices the mismatch and re-opens the stream at the new `pos`.
 */
export class ByteSource {
  readonly url: string;
  readonly totalSize: number;
  private readonly opts: Required<Omit<ByteSourceOpts, "signal" | "onTransfer">> & {
    signal?: AbortSignal;
    onTransfer?: ByteSourceOpts["onTransfer"];
  };
  /** Reader of the currently open Range response, or null when no stream is open. */
  private reader: ReadableStreamDefaultReader<Uint8Array> | null = null;
  /** Absolute offset where the current Range request is supposed to end (exclusive). */
  private streamEnd = 0;
  /** Chunks pulled from the reader but not yet fully consumed, oldest first. */
  private buffer: Uint8Array[] = [];
  /** Absolute offset of buffer[0][0]. */
  private bufferStart = 0;
  /** Offset within buffer[0] of the next unconsumed byte. */
  private bufferHead = 0;
  /** Total bytes currently held across all buffer chunks (including the consumed head prefix). */
  private bufferedBytes = 0;
  /** Absolute offset of the next byte to be returned to the caller. May be assigned to seek. */
  pos = 0;
  /** Total bytes received over the network so far. */
  bytesTransferred = 0;

  /**
   * @param url Resource to read.
   * @param totalSize Size of the resource in bytes; reads and skips are bounds-checked against it.
   * @param opts Optional threshold override, abort signal and progress hook.
   */
  constructor(url: string, totalSize: number, opts: ByteSourceOpts = {}) {
    this.url = url;
    this.totalSize = totalSize;
    this.opts = {
      skipRangeThreshold: opts.skipRangeThreshold ?? pickAdaptiveThreshold(totalSize),
      signal: opts.signal,
      onTransfer: opts.onTransfer,
    };
  }

  /** Exposed for the UI: what threshold are we actually using? */
  get skipRangeThreshold(): number {
    return this.opts.skipRangeThreshold;
  }

  /** Bytes available to read without pulling more from the network. */
  private get availableInBuffer(): number {
    return this.bufferedBytes - this.bufferHead;
  }

  /** True when `pos` no longer matches the buffer cursor, i.e. the caller assigned `pos`. */
  private get isDesynced(): boolean {
    return this.pos !== this.bufferStart + this.bufferHead;
  }

  /**
   * Open a Range stream starting at `from` (absolute offset) and ending at
   * `end` (exclusive, defaults to totalSize). Any open stream is cancelled
   * first. When the range is empty no request is made and no reader is set.
   *
   * @throws If the server answers with anything other than 206 (or 200 for a
   *   request starting at offset 0), or if the response has no body.
   */
  private async openRangeStream(from: number, end?: number): Promise<void> {
    if (this.reader) await this.cancelStream();
    const stop = end ?? this.totalSize;
    if (from >= stop) {
      this.streamEnd = stop;
      return;
    }
    const res = await fetch(this.url, {
      headers: { Range: `bytes=${from}-${stop - 1}` },
      signal: this.opts.signal,
    });
    // A plain 200 carries the whole file, which is only usable when we asked
    // to start at byte 0 anyway.
    const accepted = res.status === 206 || (res.status === 200 && from === 0);
    if (!accepted) {
      // Release the connection: an unread body would otherwise stay open.
      try {
        await res.body?.cancel();
      } catch {
        // ignored: best-effort cleanup, the status error below is what matters.
      }
      throw new Error(
        `Range request failed (status ${res.status}). The CDN may not support Range; falling back is not implemented.`,
      );
    }
    if (!res.body) throw new Error("Response has no body.");
    this.reader = res.body.getReader();
    this.streamEnd = stop;
  }

  /** Cancel the open stream, if any. Cancellation errors are swallowed. */
  private async cancelStream(): Promise<void> {
    if (!this.reader) return;
    try {
      await this.reader.cancel();
    } catch {
      // ignored: cancel failures are non-fatal.
    }
    this.reader = null;
  }

  /** Drop all buffered chunks and declare the (empty) buffer to start at `startPos`. */
  private resetBuffer(startPos: number): void {
    this.buffer = [];
    this.bufferedBytes = 0;
    this.bufferHead = 0;
    this.bufferStart = startPos;
  }

  /** Account for `byteCount` freshly received bytes and notify the progress hook. */
  private recordTransfer(byteCount: number): void {
    this.bytesTransferred += byteCount;
    this.opts.onTransfer?.(byteCount, this.bytesTransferred);
  }

  /**
   * Consume `count` already-buffered bytes, copying them into `out` when
   * given, and advance `pos` by the same amount. The caller must have made
   * sure that at least `count` bytes are buffered.
   */
  private consume(count: number, out?: Uint8Array): void {
    let consumed = 0;
    while (consumed < count) {
      const head = this.buffer[0];
      if (head === undefined) {
        throw new Error("Internal error: buffer underrun.");
      }
      const take = Math.min(head.byteLength - this.bufferHead, count - consumed);
      if (out) {
        out.set(head.subarray(this.bufferHead, this.bufferHead + take), consumed);
      }
      consumed += take;
      this.bufferHead += take;
      if (this.bufferHead >= head.byteLength) {
        // Head chunk exhausted: drop it and move the buffer origin past it.
        this.buffer.shift();
        this.bufferedBytes -= head.byteLength;
        this.bufferStart += head.byteLength;
        this.bufferHead = 0;
      }
    }
    this.pos += count;
  }

  /**
   * Make sure at least `n` bytes are available in the buffer, pulling from
   * the stream as needed.
   *
   * @throws If `pos + n` exceeds `totalSize`, or the stream ends early.
   */
  private async ensureAvailable(n: number): Promise<void> {
    if (this.pos + n > this.totalSize) {
      throw new Error(
        `Read past end of file: pos=${this.pos}, want=${n}, total=${this.totalSize}`,
      );
    }
    if (!this.reader || this.isDesynced) {
      // Stream is closed or seek happened — (re)open at current pos.
      await this.openRangeStream(this.pos);
      this.resetBuffer(this.pos);
    }
    while (this.availableInBuffer < n) {
      if (!this.reader) {
        // No stream: only recoverable when a bounded range ran out before EOF.
        const resumeAt = this.pos + this.availableInBuffer;
        if (resumeAt >= this.streamEnd && this.streamEnd < this.totalSize) {
          await this.openRangeStream(resumeAt);
          continue;
        }
        throw new Error("Stream ended unexpectedly before requested bytes arrived.");
      }
      const { done, value } = await this.reader.read();
      if (done) {
        this.reader = null;
        const resumeAt = this.pos + this.availableInBuffer;
        if (resumeAt >= this.streamEnd && this.streamEnd < this.totalSize) {
          await this.openRangeStream(resumeAt);
          continue;
        }
        // The loop condition guarantees we are still short of `n` here.
        throw new Error("Stream EOF before requested bytes arrived.");
      }
      if (!value || value.byteLength === 0) continue;
      this.buffer.push(value);
      this.bufferedBytes += value.byteLength;
      this.recordTransfer(value.byteLength);
    }
  }

  /**
   * Read exactly `n` bytes into a fresh Uint8Array (concatenating internal
   * chunks if needed). Returns an empty array when `n <= 0`.
   *
   * @throws If the read would run past `totalSize` or the stream ends early.
   */
  async readBytes(n: number): Promise<Uint8Array> {
    if (n <= 0) return new Uint8Array(0);
    await this.ensureAvailable(n);
    const out = new Uint8Array(n);
    this.consume(n, out);
    return out;
  }

  /**
   * Read a single byte (faster path for varint decoding): the network layer
   * is only touched when the buffer is empty or `pos` was reassigned.
   *
   * @throws If already at `totalSize` or the stream ends early.
   */
  async readByte(): Promise<number> {
    if (this.availableInBuffer === 0 || this.isDesynced) {
      await this.ensureAvailable(1);
    }
    const byte = this.buffer[0]?.[this.bufferHead];
    if (byte === undefined) {
      throw new Error("Internal error: buffer underrun.");
    }
    this.consume(1);
    return byte;
  }

  /**
   * Advance the read position by `n` bytes without exposing them to the caller.
   *  - Bytes already buffered are dropped first.
   *  - If what remains is <= skipRangeThreshold, those bytes are pulled from
   *    the current stream and discarded (cheap, no RTT).
   *  - If what remains is > skipRangeThreshold, the current stream is
   *    cancelled and `pos` jumps forward; the next read opens a new Range
   *    request there. This is how we save real bandwidth on big initializers.
   *
   * Does nothing when `n <= 0`.
   *
   * @throws If the skip would run past `totalSize` or the stream ends early.
   */
  async skipBytes(n: number): Promise<void> {
    if (n <= 0) return;
    if (this.pos + n > this.totalSize) {
      throw new Error(
        `Skip past end of file: pos=${this.pos}, skip=${n}, total=${this.totalSize}`,
      );
    }
    if (this.isDesynced) {
      // `pos` was reassigned: whatever is buffered or in flight belongs to
      // another offset and must not be counted towards this skip.
      await this.cancelStream();
      this.resetBuffer(this.pos);
    }

    let remaining = n;

    // First, eat what we already have buffered.
    const buffered = Math.min(this.availableInBuffer, remaining);
    if (buffered > 0) {
      this.consume(buffered);
      remaining -= buffered;
      if (remaining === 0) return;
    }

    if (remaining > this.opts.skipRangeThreshold) {
      // Big skip: cancel current stream, jump forward via new Range.
      await this.cancelStream();
      this.pos += remaining;
      this.resetBuffer(this.pos);
      return;
    }

    // Small skip: stream-discard.
    while (remaining > 0) {
      if (!this.reader) {
        await this.openRangeStream(this.pos);
        this.resetBuffer(this.pos);
        continue;
      }
      const { done, value } = await this.reader.read();
      if (done) {
        this.reader = null;
        if (this.streamEnd < this.totalSize) {
          await this.openRangeStream(this.pos);
          continue;
        }
        throw new Error("Stream EOF during skip.");
      }
      if (!value || value.byteLength === 0) continue;
      this.recordTransfer(value.byteLength);
      if (value.byteLength <= remaining) {
        // Entire chunk discarded. The buffer is empty at this point, so its
        // origin simply follows `pos`; leaving it behind would make a later
        // seek back into the skipped region go unnoticed.
        this.pos += value.byteLength;
        this.bufferStart = this.pos;
        remaining -= value.byteLength;
      } else {
        // Partial chunk: keep the unused tail in the buffer.
        const tail = value.subarray(remaining);
        this.pos += remaining;
        this.buffer = [tail];
        this.bufferedBytes = tail.byteLength;
        this.bufferHead = 0;
        this.bufferStart = this.pos;
        remaining = 0;
      }
    }
  }

  /** Cancel any pending network and free buffers. `pos` is left where it is. */
  async close(): Promise<void> {
    await this.cancelStream();
    this.resetBuffer(this.pos);
  }

  /**
   * Build a ByteSource after discovering the total length with a HEAD request.
   *
   * @throws If the HEAD request fails, the response has no positive finite
   *   Content-Length, or the server sends `Accept-Ranges: none`.
   */
  static async create(url: string, opts: ByteSourceOpts = {}): Promise<ByteSource> {
    const head = await fetch(url, { method: "HEAD", signal: opts.signal });
    if (!head.ok) throw new Error(`HEAD ${url} failed: ${head.status}`);
    const len = Number(head.headers.get("content-length") ?? 0);
    if (!Number.isFinite(len) || len <= 0) {
      throw new Error(`Server did not return a usable Content-Length for ${url}.`);
    }
    const acceptRanges = head.headers.get("accept-ranges");
    if (acceptRanges && acceptRanges.toLowerCase() === "none") {
      throw new Error(`Server says it does not support Range requests for ${url}.`);
    }
    return new ByteSource(url, len, opts);
  }
}
|