// Unified session loader for inference engines. One entry point, two arms.
//
//   loadSession(bytes, intent, opts) -> { session, realizedBackend }
//
//   intent           'gpu' | 'cpu'                        — what the user wants
//   realizedBackend  'web-webgpu' | 'web-wasm'            — what actually ran
//                    | 'native-cpu' | 'native-coreml/...' — (display only)
//                    | 'native-dml' | 'native-cuda' | ...
//
// Callers (the engines) shouldn't need to know whether we're running native
// (Electron + onnxruntime-node) or web (browser + ort-web). They pass intent
// in; they get a session-shaped object out plus a string label of what it's
// running on. No magic options-bag properties, no synthesised exceptions to
// drive control flow, no monkey-patching of ORT-Web's surface.

import { dispatchBackendEvent, shortenReason } from 'lib/backend-events';

// ───── Mode detection ──────────────────────────────────────────────────────
// `__nativeOrt` is installed by desktop/preload.cjs only when the renderer is
// running inside Electron *and* AITOOLS_NATIVE isn't explicitly disabled. Its
// presence is the entire signal — no `enabled` flag to interrogate.

/**
 * Report whether the native ORT bridge is present on the global object.
 *
 * @returns {boolean} true when `globalThis.__nativeOrt` is truthy.
 */
export function isNativeMode() {
  return Boolean(globalThis.__nativeOrt);
}

// ───── Auto-disable on worker crash ────────────────────────────────────────
// If the native worker crashes mid-session, we trip a one-way switch so the
// rest of this page's loads fall through to the web path. Same semantics as
// the old inject.js `tripAutoDisable`, just lifted here.
let nativeAutoDisabled = false;

/**
 * Decide whether an error's message marks the native worker as gone.
 *
 * @param {*} err - Anything caught from a native call; may be nullish.
 * @returns {boolean} true when `err.message` is non-empty and matches one of
 *   the known worker-failure phrases (case-insensitive).
 */
function isWorkerCrash(err) {
  if (!err) return false;
  const text = err.message;
  if (!text) return false;
  return /worker crashed|worker not available|native ort unavailable/i.test(text);
}

// ───── Wire tensor codec ───────────────────────────────────────────────────
// Mirrors WIRE_DTYPES in ort-worker.cjs.
// Typed-array constructor used to view each wire dtype's ArrayBuffer payload.
// float16 payloads are viewed as Uint16Array.
const WIRE_DTYPES = {
  float32: Float32Array,
  float16: Uint16Array,
  int32: Int32Array,
  int64: BigInt64Array,
  uint8: Uint8Array,
};

// ───── Public entry point ──────────────────────────────────────────────────

/**
 * Load a model and return a session-shaped object plus the realized backend.
 *
 * @param {Uint8Array|ArrayBuffer} modelBytes
 * @param {'gpu'|'cpu'} intent
 * @param {{
 *   profile?: boolean,
 *   preferredOutputLocation?: 'gpu-buffer' | 'cpu',
 * }} [opts]
 *   `preferredOutputLocation` is forwarded only on the web-webgpu path; it
 *   enables zero-readback output tensors for the upscaler's GPU fast path.
 *   Ignored on web-wasm and on native.
 * @returns {Promise<{ session: object, realizedBackend: string }>}
 * @throws {Error} when `intent` is neither 'gpu' nor 'cpu', or when the
 *   selected arm fails for a reason other than a native worker crash.
 */
export async function loadSession(modelBytes, intent, opts = {}) {
  if (intent !== 'gpu' && intent !== 'cpu') {
    throw new Error(`loadSession: unknown intent ${JSON.stringify(intent)} (expected 'gpu' or 'cpu')`);
  }

  if (isNativeMode() && !nativeAutoDisabled) {
    try {
      // `return await` (not bare `return`) so rejections land in this catch.
      return await loadNative(modelBytes, intent);
    } catch (e) {
      if (isWorkerCrash(e)) {
        nativeAutoDisabled = true;
        console.warn(`[backend] native ORT auto-disabled for this page session: ${e.message} Future loads use ORT-Web. Reload to retry native.`);
        // Fall through to web path so this load still has a chance.
      } else {
        throw e;
      }
    }
  }

  return loadWeb(modelBytes, intent, opts);
}

// ───── Native arm ──────────────────────────────────────────────────────────

// Monotonic counter used to build a unique key per native session.
let nativeSeq = 0;

/**
 * Load the model through the `globalThis.__nativeOrt` bridge.
 *
 * @param {Uint8Array|ArrayBuffer} modelBytes
 * @param {'gpu'|'cpu'} intent
 * @returns {Promise<{ session: object, realizedBackend: string }>}
 *   `realizedBackend` is `native-` followed by the rung reported by the bridge.
 */
async function loadNative(modelBytes, intent) {
  const transferable = toArrayBuffer(modelBytes);
  const key = `m${++nativeSeq}_${transferable.byteLength}`;

  // Host emits attempt/fallback/skipped events via the model-event channel
  // (forwarded to backend-events by desktop/inject.js). We don't synthesise
  // an attempt here — let the worker's actual rung outcomes speak.
  const meta = await globalThis.__nativeOrt.load(key, transferable, { intent });
  console.log(`[backend] native session ${key}: ${meta.inputNames.join(',')} -> ${meta.outputNames.join(',')} via ${meta.rung}`);

  return {
    session: makeNativeSession(key, meta),
    realizedBackend: `native-${meta.rung}`,
  };
}

/**
 * Build a session-shaped facade over a native session identified by `key`.
 *
 * @param {string} key - Identifier passed to the bridge's `run` / `release`.
 * @param {{ inputNames: string[], outputNames: string[] }} meta
 * @returns {object} Object exposing `inputNames`, `outputNames`,
 *   `inputMetadata`, `outputMetadata`, `run`, `release`, and no-op
 *   `startProfiling` / `endProfiling`.
 */
function makeNativeSession(key, meta) {
  // ORT-Web sessions expose inputMetadata / outputMetadata as arrays of
  // {name, type, dimensions}; the engines self-correct from dims so a narrow
  // 'tensor(float)' placeholder is fine.
  const inputMetadata = meta.inputNames.map(name => ({ name, type: 'tensor(float)', dimensions: [] }));
  const outputMetadata = meta.outputNames.map(name => ({ name, type: 'tensor(float)', dimensions: [] }));

  return {
    inputNames: meta.inputNames,
    outputNames: meta.outputNames,
    inputMetadata,
    outputMetadata,

    /**
     * Run inference: copy each feed's bytes to the wire, call the bridge,
     * and wrap every output in an `ort.Tensor`.
     *
     * @param {Object<string, {type: string, dims: number[], data: ArrayBufferView}>} feeds
     * @returns {Promise<Object<string, object>>} outputs keyed by name.
     * @throws {Error} when ort-web's Tensor is unavailable, when the bridge
     *   rejects, or when an output uses a dtype missing from WIRE_DTYPES.
     */
    async run(feeds /*, runOptions */) {
      // Outputs are wrapped in ort.Tensor, so verify it exists *before*
      // dispatching: otherwise the inference runs and its result is lost to
      // an opaque TypeError.
      const ortGlobal = globalThis.ort;
      if (!ortGlobal || typeof ortGlobal.Tensor !== 'function') {
        throw new Error('[backend] ort-web is not loaded — native sessions need globalThis.ort.Tensor to wrap outputs');
      }

      const wire = {};
      for (const [name, t] of Object.entries(feeds)) {
        const data = t.data;
        // Copy exactly the bytes this view covers (it may be a sub-view).
        const ab = data.buffer.slice(data.byteOffset, data.byteOffset + data.byteLength);
        wire[name] = { type: t.type, dims: t.dims, data: ab };
      }

      let raw;
      try {
        raw = await globalThis.__nativeOrt.run(key, wire);
      } catch (e) {
        if (isWorkerCrash(e)) {
          // Native is now dead for this page. The caller's current run still
          // fails (we can't reload mid-session-run), but subsequent loadSession
          // calls will fall through to the web path via nativeAutoDisabled.
          nativeAutoDisabled = true;
          console.warn(`[backend] native ORT auto-disabled mid-run: ${e.message}`);
        }
        throw e;
      }

      const out = {};
      for (const [name, t] of Object.entries(raw)) {
        const Arr = WIRE_DTYPES[t.type];
        if (!Arr) throw new Error(`[backend] unsupported output tensor type: ${t.type}`);
        const tensor = new ortGlobal.Tensor(t.type, new Arr(t.data), t.dims);
        // Give callers a dispose() to call even if this Tensor lacks one.
        if (typeof tensor.dispose !== 'function') tensor.dispose = () => {};
        out[name] = tensor;
      }
      return out;
    },

    /** Best-effort release of the native session; bridge errors are ignored. */
    async release() {
      try { await globalThis.__nativeOrt.release(key); } catch {}
    },

    startProfiling() {},
    endProfiling() {},
  };
}

// ───── Web arm ─────────────────────────────────────────────────────────────

/**
 * Create an ORT-Web session: WebGPU first when `intent` is 'gpu', with a
 * WASM fallback; WASM directly otherwise.
 *
 * @param {Uint8Array|ArrayBuffer} modelBytes
 * @param {'gpu'|'cpu'} intent
 * @param {{ profile?: boolean, preferredOutputLocation?: 'gpu-buffer' | 'cpu' }} [opts]
 * @returns {Promise<{ session: object, realizedBackend: 'web-webgpu' | 'web-wasm' }>}
 * @throws {Error} when `globalThis.ort` is missing or WASM session creation fails.
 */
async function loadWeb(modelBytes, intent, { profile = false, preferredOutputLocation } = {}) {
  const ort = globalThis.ort;
  if (!ort) throw new Error('[backend] ort-web is not loaded — include vendor/onnxruntime-web/ort.all.min.js before using loadSession');

  ort.env.wasm.wasmPaths = globalThis.__ORT_WASM_PATHS__
    || new URL('vendor/onnxruntime-web/', document.baseURI).toString();
  ort.env.wasm.numThreads = navigator.hardwareConcurrency || 4;

  // ort.env.webgpu.profilingMode is global; clear when not profiling so a
  // prior run can't leave it stuck on.
  if (ort.env.webgpu) {
    ort.env.webgpu.profilingMode = (profile && intent === 'gpu') ? 'default' : 'off';
  }

  const sessionOpts = {
    graphOptimizationLevel: 'all',
    ...(profile && { enableProfiling: true }),
  };

  if (intent === 'gpu') {
    sessionOpts.executionProviders = [{ name: 'webgpu', preferredLayout: 'NCHW' }];
    if (preferredOutputLocation) sessionOpts.preferredOutputLocation = preferredOutputLocation;
    dispatchBackendEvent({ kind: 'attempt', backend: 'web-webgpu' });
    try {
      const session = await ort.InferenceSession.create(modelBytes, sessionOpts);
      dispatchBackendEvent({ kind: 'success', backend: 'web-webgpu' });
      return { session, realizedBackend: 'web-webgpu' };
    } catch (e) {
      console.warn(`[backend] WebGPU failed, falling back to WASM. Reason:`, e);
      dispatchBackendEvent({ kind: 'fallback', backend: 'web-webgpu', reason: shortenReason(e) });
      // Strip WebGPU-only opts before retrying on WASM.
      delete sessionOpts.preferredOutputLocation;
      // This load is no longer on WebGPU, so don't leave the global WebGPU
      // profiling switch enabled on its behalf.
      if (ort.env.webgpu) ort.env.webgpu.profilingMode = 'off';
    }
  }

  sessionOpts.executionProviders = ['wasm'];
  dispatchBackendEvent({ kind: 'attempt', backend: 'web-wasm' });
  const session = await ort.InferenceSession.create(modelBytes, sessionOpts);
  dispatchBackendEvent({ kind: 'success', backend: 'web-wasm' });
  return { session, realizedBackend: 'web-wasm' };
}

// ───── helpers ─────────────────────────────────────────────────────────────

/**
 * Normalise model bytes to an ArrayBuffer.
 *
 * @param {Uint8Array|ArrayBuffer} modelBytes
 * @returns {ArrayBuffer} The same object for an ArrayBuffer input; for a
 *   Uint8Array, a copy of exactly the bytes the view covers.
 * @throws {Error} for any other input type.
 */
function toArrayBuffer(modelBytes) {
  if (modelBytes instanceof ArrayBuffer) return modelBytes;
  if (modelBytes instanceof Uint8Array) {
    return modelBytes.buffer.slice(modelBytes.byteOffset, modelBytes.byteOffset + modelBytes.byteLength);
  }
  throw new Error(`[backend] modelBytes must be ArrayBuffer or Uint8Array, got ${typeof modelBytes}`);
}