// Unified session loader for inference engines. One entry point, two arms.
//
// loadSession(bytes, intent, opts) -> { session, realizedBackend }
//
//   intent           'gpu' | 'cpu'                          — what the user wants
//   realizedBackend  'web-webgpu' | 'web-wasm'              — what actually ran
//                    | 'native-cpu' | 'native-coreml/...'   — (display only)
//                    | 'native-dml' | 'native-cuda' | ...
//
// Callers (the engines) shouldn't need to know whether we're running native
// (Electron + onnxruntime-node) or web (browser + ort-web). They pass intent
// in; they get a session-shaped object out plus a string label of what it's
// running on. No magic options-bag properties, no synthesised exceptions to
// drive control flow, no monkey-patching of ORT-Web's surface.
import { dispatchBackendEvent, shortenReason } from 'lib/backend-events';
// ───── Mode detection ──────────────────────────────────────────────────────
// `__nativeOrt` is installed by desktop/preload.cjs only when the renderer is
// running inside Electron *and* AITOOLS_NATIVE isn't explicitly disabled. Its
// presence is the entire signal — no `enabled` flag to interrogate.
/**
 * Report whether the native ORT bridge is present on this page.
 *
 * @returns {boolean} true when `globalThis.__nativeOrt` holds a truthy value.
 */
export function isNativeMode() {
  const bridge = globalThis.__nativeOrt;
  return Boolean(bridge);
}
// ───── Auto-disable on worker crash ────────────────────────────────────────
// If the native worker crashes mid-session, we trip a one-way switch so the
// rest of this page's loads fall through to the web path. Same semantics as
// the old inject.js `tripAutoDisable`, just lifted here.
// One-way latch: set to true when a native worker crash is observed. Nothing
// in the visible code ever sets it back to false.
let nativeAutoDisabled = false;

/**
 * Decide whether an error's message matches one of the known
 * "native worker is gone" phrasings (case-insensitive).
 *
 * @param {*} err - Anything that was thrown; only its `message` is inspected.
 * @returns {boolean} true when `err.message` is truthy and matches.
 */
function isWorkerCrash(err) {
  const text = err && err.message;
  if (!text) return false;
  return /worker crashed|worker not available|native ort unavailable/i.test(text);
}
// ───── Wire tensor codec ───────────────────────────────────────────────────
// Mirrors WIRE_DTYPES in ort-worker.cjs.
// Wire dtype name -> typed-array constructor used to view a tensor's raw bytes.
const WIRE_DTYPES = {
  float32: Float32Array,
  // float16 values are carried as raw 16-bit words.
  float16: Uint16Array,
  int32: Int32Array,
  int64: BigInt64Array,
  uint8: Uint8Array,
};
// ───── Public entry point ──────────────────────────────────────────────────
/**
 * Single entry point for creating an inference session.
 *
 * Tries the native arm first when the native bridge is present and has not
 * been auto-disabled; otherwise (or after a native worker crash) uses the
 * web arm.
 *
 * @param {Uint8Array|ArrayBuffer} modelBytes - Serialized model.
 * @param {'gpu'|'cpu'} intent - What the caller would like to run on.
 * @param {{
 *   profile?: boolean,
 *   preferredOutputLocation?: 'gpu-buffer' | 'cpu',
 * }} [opts]
 *   Passed through to the web arm only; the native arm receives just the
 *   bytes and the intent. `preferredOutputLocation` is applied on the
 *   web-webgpu path and dropped for web-wasm.
 * @returns {Promise<{ session: object, realizedBackend: string }>}
 * @throws {Error} When `intent` is neither 'gpu' nor 'cpu', or when the
 *   native arm fails with anything other than a worker crash.
 */
export async function loadSession(modelBytes, intent, opts = {}) {
  const intentIsKnown = intent === 'gpu' || intent === 'cpu';
  if (!intentIsKnown) {
    throw new Error(`loadSession: unknown intent ${JSON.stringify(intent)} (expected 'gpu' or 'cpu')`);
  }

  const nativeUsable = isNativeMode() && !nativeAutoDisabled;
  if (nativeUsable) {
    try {
      return await loadNative(modelBytes, intent);
    } catch (err) {
      // Only a worker crash is recoverable here; everything else propagates.
      if (!isWorkerCrash(err)) throw err;
      nativeAutoDisabled = true;
      console.warn(`[backend] native ORT auto-disabled for this page session: ${err.message} Future loads use ORT-Web. Reload to retry native.`);
      // Continue below so this very load still gets a chance on the web arm.
    }
  }

  return loadWeb(modelBytes, intent, opts);
}
// ───── Native arm ──────────────────────────────────────────────────────────
// Counter bumped once per native load; it makes each session key distinct.
let nativeSeq = 0;

/**
 * Native arm: pass the model to the host bridge and wrap the reply in a
 * session-shaped object.
 *
 * No backend events are dispatched from here; per the host design they
 * arrive over the model-event channel, reflecting what the worker actually
 * tried.
 *
 * @param {Uint8Array|ArrayBuffer} modelBytes - Serialized model.
 * @param {'gpu'|'cpu'} intent - Forwarded to the bridge as `{ intent }`.
 * @returns {Promise<{ session: object, realizedBackend: string }>}
 *   `realizedBackend` is `native-` followed by the rung the bridge reported.
 */
async function loadNative(modelBytes, intent) {
  const buffer = toArrayBuffer(modelBytes);
  nativeSeq += 1;
  const key = `m${nativeSeq}_${buffer.byteLength}`;

  const meta = await globalThis.__nativeOrt.load(key, buffer, { intent });
  const { inputNames, outputNames, rung } = meta;
  console.log(`[backend] native session ${key}: ${inputNames.join(',')} -> ${outputNames.join(',')} via ${rung}`);

  const session = makeNativeSession(key, meta);
  return { session, realizedBackend: `native-${rung}` };
}
/**
 * Wrap a native-worker session in an object shaped like an ORT-Web
 * InferenceSession, so callers can use either kind interchangeably.
 *
 * @param {string} key - Session key understood by `globalThis.__nativeOrt`.
 * @param {{ inputNames: string[], outputNames: string[] }} meta
 *   Names reported by the native load call.
 * @returns {object} Session facade exposing `inputNames`, `outputNames`,
 *   `inputMetadata`, `outputMetadata`, `run`, `release`, and no-op
 *   `startProfiling` / `endProfiling`.
 */
function makeNativeSession(key, meta) {
  // ORT-Web sessions expose inputMetadata / outputMetadata as arrays of
  // {name, type, dimensions}. Only names are known here, so type and
  // dimensions are placeholders; the engines are expected to self-correct
  // from actual tensor dims.
  const placeholderMeta = (name) => ({ name, type: 'tensor(float)', dimensions: [] });
  const inputMetadata = meta.inputNames.map(placeholderMeta);
  const outputMetadata = meta.outputNames.map(placeholderMeta);

  return {
    inputNames: meta.inputNames,
    outputNames: meta.outputNames,
    inputMetadata,
    outputMetadata,

    /**
     * Run inference through the native bridge.
     *
     * @param {Object<string, {type: string, dims: number[], data: ArrayBufferView}>} feeds
     * @returns {Promise<Object<string, object>>} Output name -> ort.Tensor.
     * @throws {Error} When `globalThis.ort.Tensor` is missing, when a feed
     *   has no typed-array data, when the bridge call fails, or when an
     *   output dtype is not in WIRE_DTYPES.
     */
    async run(feeds /*, runOptions */) {
      // Outputs are rebuilt as ort.Tensor instances, so check for the
      // constructor before spending time on a native inference whose result
      // could not be returned.
      const ortGlobal = globalThis.ort;
      if (!ortGlobal || typeof ortGlobal.Tensor !== 'function') {
        throw new Error('[backend] globalThis.ort.Tensor is not available — cannot build output tensors for a native session');
      }

      const wire = {};
      for (const [name, t] of Object.entries(feeds)) {
        const data = t.data;
        if (!ArrayBuffer.isView(data)) {
          throw new Error(`[backend] input tensor "${name}" has no typed-array data to send to the native worker (type: ${t.type})`);
        }
        // Copy exactly the viewed bytes; the view may be a window onto a
        // larger backing buffer.
        const ab = data.buffer.slice(data.byteOffset, data.byteOffset + data.byteLength);
        wire[name] = { type: t.type, dims: t.dims, data: ab };
      }

      let raw;
      try {
        raw = await globalThis.__nativeOrt.run(key, wire);
      } catch (e) {
        if (isWorkerCrash(e)) {
          // Native is now dead for this page. The caller's current run still
          // fails (we can't reload mid-session-run), but subsequent
          // loadSession calls will fall through to the web path via
          // nativeAutoDisabled.
          nativeAutoDisabled = true;
          console.warn(`[backend] native ORT auto-disabled mid-run: ${e.message}`);
        }
        throw e;
      }

      const out = {};
      for (const [name, t] of Object.entries(raw)) {
        const Arr = WIRE_DTYPES[t.type];
        if (!Arr) throw new Error(`[backend] unsupported output tensor type: ${t.type}`);
        const tensor = new ortGlobal.Tensor(t.type, new Arr(t.data), t.dims);
        // Give callers a dispose() to call even if this Tensor build lacks one.
        if (typeof tensor.dispose !== 'function') tensor.dispose = () => {};
        out[name] = tensor;
      }
      return out;
    },

    /**
     * Release the native session. Best-effort: a failure is logged and
     * never thrown to the caller.
     */
    async release() {
      try {
        await globalThis.__nativeOrt.release(key);
      } catch (e) {
        console.warn(`[backend] native session ${key} release failed: ${e && e.message ? e.message : e}`);
      }
    },

    // Profiling is not wired up for native sessions; kept for shape parity.
    startProfiling() {},
    endProfiling() {},
  };
}
// ───── Web arm ─────────────────────────────────────────────────────────────
/**
 * Web arm: create an ORT-Web session, preferring WebGPU when `intent` is
 * 'gpu' and falling back to WASM if WebGPU session creation fails.
 *
 * Side effects: writes `ort.env.wasm.wasmPaths`, `ort.env.wasm.numThreads`
 * and (when present) `ort.env.webgpu.profilingMode`, and dispatches
 * attempt / success / fallback backend events.
 *
 * @param {Uint8Array|ArrayBuffer} modelBytes - Serialized model.
 * @param {'gpu'|'cpu'} intent
 * @param {{ profile?: boolean, preferredOutputLocation?: 'gpu-buffer'|'cpu' }} [opts]
 *   `preferredOutputLocation` is applied to the WebGPU attempt only.
 * @returns {Promise<{ session: object, realizedBackend: 'web-webgpu'|'web-wasm' }>}
 * @throws {Error} When `globalThis.ort` is missing, or when the WASM session
 *   cannot be created.
 */
async function loadWeb(modelBytes, intent, { profile = false, preferredOutputLocation } = {}) {
  const ort = globalThis.ort;
  if (!ort) throw new Error('[backend] ort-web is not loaded — include vendor/onnxruntime-web/ort.all.min.js before using loadSession');

  ort.env.wasm.wasmPaths =
    globalThis.__ORT_WASM_PATHS__ ||
    new URL('vendor/onnxruntime-web/', document.baseURI).toString();
  ort.env.wasm.numThreads = navigator.hardwareConcurrency || 4;

  // ort.env.webgpu.profilingMode is global; clear when not profiling so a
  // prior run can't leave it stuck on.
  if (ort.env.webgpu) {
    ort.env.webgpu.profilingMode = (profile && intent === 'gpu') ? 'default' : 'off';
  }

  const sessionOpts = {
    graphOptimizationLevel: 'all',
    ...(profile && { enableProfiling: true }),
  };

  if (intent === 'gpu') {
    sessionOpts.executionProviders = [{ name: 'webgpu', preferredLayout: 'NCHW' }];
    if (preferredOutputLocation) sessionOpts.preferredOutputLocation = preferredOutputLocation;
    dispatchBackendEvent({ kind: 'attempt', backend: 'web-webgpu' });
    try {
      const session = await ort.InferenceSession.create(modelBytes, sessionOpts);
      dispatchBackendEvent({ kind: 'success', backend: 'web-webgpu' });
      return { session, realizedBackend: 'web-webgpu' };
    } catch (e) {
      console.warn(`[backend] WebGPU failed, falling back to WASM. Reason:`, e);
      dispatchBackendEvent({ kind: 'fallback', backend: 'web-webgpu', reason: shortenReason(e) });
      // Strip WebGPU-only opts before retrying on WASM.
      delete sessionOpts.preferredOutputLocation;
      // The session will run on WASM, so the global WebGPU profiling switch
      // set above must not stay on after this failed attempt.
      if (ort.env.webgpu) ort.env.webgpu.profilingMode = 'off';
    }
  }

  sessionOpts.executionProviders = ['wasm'];
  dispatchBackendEvent({ kind: 'attempt', backend: 'web-wasm' });
  const session = await ort.InferenceSession.create(modelBytes, sessionOpts);
  dispatchBackendEvent({ kind: 'success', backend: 'web-wasm' });
  return { session, realizedBackend: 'web-wasm' };
}
// ───── helpers ─────────────────────────────────────────────────────────────
/**
 * Normalise model bytes to a plain ArrayBuffer.
 *
 * An ArrayBuffer is returned as-is (no copy). Any ArrayBuffer view
 * (Uint8Array, other typed arrays, DataView) yields a fresh buffer holding
 * exactly the bytes the view covers, so a window onto a larger backing
 * buffer does not drag the rest along.
 *
 * @param {ArrayBuffer|ArrayBufferView} modelBytes
 * @returns {ArrayBuffer}
 * @throws {Error} When `modelBytes` is neither an ArrayBuffer nor a view.
 */
function toArrayBuffer(modelBytes) {
  if (modelBytes instanceof ArrayBuffer) return modelBytes;
  if (ArrayBuffer.isView(modelBytes)) {
    const { buffer, byteOffset, byteLength } = modelBytes;
    return buffer.slice(byteOffset, byteOffset + byteLength);
  }
  // `typeof` reports "object" for null, arrays and every other object; the
  // built-in tag ("Null", "Array", "String", ...) is more useful in the message.
  const got = Object.prototype.toString.call(modelBytes).slice(8, -1);
  throw new Error(`[backend] modelBytes must be an ArrayBuffer or a typed array (e.g. Uint8Array), got ${got}`);
}
|