File size: 9,431 Bytes
3f22414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
// Unified session loader for inference engines. One entry point, two arms.
//
//   loadSession(bytes, intent, opts) -> { session, realizedBackend }
//
//   intent          'gpu' | 'cpu'                          — what the user wants
//   realizedBackend 'web-webgpu' | 'web-wasm'              — what actually ran
//                 | 'native-cpu' | 'native-coreml/...'     — (display only)
//                 | 'native-dml' | 'native-cuda' | ...
//
// Callers (the engines) shouldn't need to know whether we're running native
// (Electron + onnxruntime-node) or web (browser + ort-web). They pass intent
// in; they get a session-shaped object out plus a string label of what it's
// running on. No magic options-bag properties, no synthesised exceptions to
// drive control flow, no monkey-patching of ORT-Web's surface.

import { dispatchBackendEvent, shortenReason } from 'lib/backend-events';

// ───── Mode detection ──────────────────────────────────────────────────────
// `__nativeOrt` is installed by desktop/preload.cjs only when the renderer is
// running inside Electron *and* AITOOLS_NATIVE isn't explicitly disabled. Its
// presence is the entire signal — no `enabled` flag to interrogate.

/**
 * Report whether the native ORT bridge is present in this renderer.
 *
 * The check is purely presence-based: any truthy `globalThis.__nativeOrt`
 * counts as native mode; there is no separate flag to consult.
 *
 * @returns {boolean} true when `globalThis.__nativeOrt` is set to a truthy value.
 */
export function isNativeMode() {
  const bridge = globalThis.__nativeOrt;
  return Boolean(bridge);
}

// ───── Auto-disable on worker crash ────────────────────────────────────────
// If the native worker crashes mid-session, we trip a one-way switch so the
// rest of this page's loads fall through to the web path. Same semantics as
// the old inject.js `tripAutoDisable`, just lifted here.

// Latch flipped to true when a native worker crash is detected; while it is
// set, loads skip the native arm (a page reload starts with it cleared again).
let nativeAutoDisabled = false;

/**
 * Decide whether an error represents a dead or unavailable native worker.
 *
 * Matching is done on the error's `message` text, case-insensitively.
 *
 * @param {{ message?: string } | null | undefined} err  Value caught from a native call.
 * @returns {boolean} true only when a non-empty message matches a known crash phrase.
 */
function isWorkerCrash(err) {
  const message = err ? err.message : undefined;
  if (!message) return false;
  return /worker crashed|worker not available|native ort unavailable/i.test(message);
}

// ───── Wire tensor codec ───────────────────────────────────────────────────
// Mirrors WIRE_DTYPES in ort-worker.cjs.

// Tensor dtype name -> typed-array constructor used to view a wire payload.
// Any dtype not listed here is rejected when outputs are decoded.
const WIRE_DTYPES = {
  float32: Float32Array,
  // float16 values are carried in a Uint16Array (raw 16-bit storage).
  float16: Uint16Array,
  int32: Int32Array,
  int64: BigInt64Array,
  uint8: Uint8Array,
};

// ───── Public entry point ──────────────────────────────────────────────────

/**
 * Create an inference session for a model and report which backend it landed on.
 *
 * The native arm is tried first when the native bridge is present and has not
 * been auto-disabled. A native failure recognised as a worker crash disables
 * the native arm for later loads and retries this load on the web arm; any
 * other native failure is rethrown unchanged.
 *
 * @param {Uint8Array|ArrayBuffer} modelBytes  Serialized model.
 * @param {'gpu'|'cpu'} intent  Requested device class; anything else throws.
 * @param {{
 *   profile?: boolean,
 *   preferredOutputLocation?: 'gpu-buffer' | 'cpu',
 * }} [opts]
 *   Passed through to the web arm only; the native arm does not receive it.
 *   `preferredOutputLocation` is applied only on the web-webgpu attempt.
 * @returns {Promise<{ session: object, realizedBackend: string }>}
 * @throws {Error} When `intent` is neither 'gpu' nor 'cpu'.
 */
export async function loadSession(modelBytes, intent, opts = {}) {
  const intentIsKnown = intent === 'gpu' || intent === 'cpu';
  if (!intentIsKnown) {
    throw new Error(`loadSession: unknown intent ${JSON.stringify(intent)} (expected 'gpu' or 'cpu')`);
  }

  const nativeUsable = isNativeMode() && !nativeAutoDisabled;
  if (!nativeUsable) {
    return loadWeb(modelBytes, intent, opts);
  }

  try {
    return await loadNative(modelBytes, intent);
  } catch (e) {
    // Only a worker crash is recoverable here; everything else is the caller's problem.
    if (!isWorkerCrash(e)) throw e;
    nativeAutoDisabled = true;
    console.warn(`[backend] native ORT auto-disabled for this page session: ${e.message} Future loads use ORT-Web. Reload to retry native.`);
  }

  // Reached only after a worker crash: give this load a chance on the web arm.
  return loadWeb(modelBytes, intent, opts);
}

// ───── Native arm ──────────────────────────────────────────────────────────

// Monotonic counter used to build a distinct key for every native load.
let nativeSeq = 0;

/**
 * Load a model through the native bridge and wrap it as a session object.
 *
 * No backend events are dispatched from here; this function only logs the
 * outcome reported by the bridge.
 *
 * @param {Uint8Array|ArrayBuffer} modelBytes  Serialized model.
 * @param {'gpu'|'cpu'} intent  Forwarded to the bridge as `{ intent }`.
 * @returns {Promise<{ session: object, realizedBackend: string }>}
 *   `realizedBackend` is `native-` followed by the rung the bridge reports.
 */
async function loadNative(modelBytes, intent) {
  const buffer = toArrayBuffer(modelBytes);

  // Key = sequence number plus payload size, e.g. "m3_1048576".
  nativeSeq += 1;
  const key = `m${nativeSeq}_${buffer.byteLength}`;

  const meta = await globalThis.__nativeOrt.load(key, buffer, { intent });
  const { inputNames, outputNames, rung } = meta;
  console.log(`[backend] native session ${key}: ${inputNames.join(',')} -> ${outputNames.join(',')} via ${rung}`);

  const session = makeNativeSession(key, meta);
  return { session, realizedBackend: `native-${rung}` };
}

/**
 * Wrap a natively-loaded model in an object shaped like an ORT-Web
 * InferenceSession, so engines can drive either backend the same way.
 *
 * @param {string} key  Handle passed back to `__nativeOrt.run` / `__nativeOrt.release`.
 * @param {{ inputNames: string[], outputNames: string[] }} meta
 *   Metadata returned by the native load call.
 * @returns {{
 *   inputNames: string[], outputNames: string[],
 *   inputMetadata: object[], outputMetadata: object[],
 *   run: (feeds: object) => Promise<object>,
 *   release: () => Promise<void>,
 *   startProfiling: () => void, endProfiling: () => void,
 * }}
 */
function makeNativeSession(key, meta) {
  // ORT-Web sessions expose inputMetadata / outputMetadata as arrays of
  // {name, type, dimensions}; the engines self-correct from dims so a narrow
  // 'tensor(float)' placeholder is fine.
  const inputMetadata  = meta.inputNames.map(name => ({ name, type: 'tensor(float)', dimensions: [] }));
  const outputMetadata = meta.outputNames.map(name => ({ name, type: 'tensor(float)', dimensions: [] }));

  return {
    inputNames: meta.inputNames,
    outputNames: meta.outputNames,
    inputMetadata,
    outputMetadata,

    /**
     * Run inference through the native bridge.
     *
     * @param {Record<string, { type: string, dims: readonly number[], data: ArrayBufferView }>} feeds
     * @returns {Promise<Record<string, object>>} Output name -> `ort.Tensor`.
     * @throws {Error} If `ort.Tensor` is unavailable, a feed has no typed-array
     *   data, the bridge call fails, or an output dtype is not in WIRE_DTYPES.
     */
    async run(feeds /*, runOptions */) {
      // Outputs are wrapped in ort.Tensor, so verify it exists before paying
      // for an inference whose results could not be returned.
      const ortGlobal = globalThis.ort;
      if (!ortGlobal || typeof ortGlobal.Tensor !== 'function') {
        throw new Error('[backend] ort.Tensor is not available — cannot wrap native session outputs');
      }

      const wire = {};
      for (const [name, t] of Object.entries(feeds)) {
        const data = t.data;
        if (!ArrayBuffer.isView(data)) {
          throw new Error(`[backend] input "${name}" (type ${t.type}) has no typed-array data to send to the native worker`);
        }
        // Copy exactly the viewed bytes; the view may cover only part of its buffer.
        const ab = data.buffer.slice(data.byteOffset, data.byteOffset + data.byteLength);
        wire[name] = { type: t.type, dims: t.dims, data: ab };
      }

      let raw;
      try {
        raw = await globalThis.__nativeOrt.run(key, wire);
      } catch (e) {
        if (isWorkerCrash(e)) {
          // Native is now dead for this page. The caller's current run still
          // fails (we can't reload mid-session-run), but subsequent loadSession
          // calls will fall through to the web path via nativeAutoDisabled.
          nativeAutoDisabled = true;
          console.warn(`[backend] native ORT auto-disabled mid-run: ${e.message}`);
        }
        throw e;
      }

      const out = {};
      for (const [name, t] of Object.entries(raw)) {
        const Arr = WIRE_DTYPES[t.type];
        if (!Arr) throw new Error(`[backend] unsupported output tensor type: ${t.type}`);
        const tensor = new ortGlobal.Tensor(t.type, new Arr(t.data), t.dims);
        // Callers may call dispose(); supply a no-op when the Tensor lacks one.
        if (typeof tensor.dispose !== 'function') tensor.dispose = () => {};
        out[name] = tensor;
      }
      return out;
    },

    /** Release the native session. Best-effort: never throws, but logs failures. */
    async release() {
      try {
        await globalThis.__nativeOrt.release(key);
      } catch (e) {
        console.warn(`[backend] native session ${key} release failed: ${e && e.message}`);
      }
    },
    // Profiling is not supported on this wrapper; kept as no-ops for interface parity.
    startProfiling() {},
    endProfiling() {},
  };
}

// ───── Web arm ─────────────────────────────────────────────────────────────

/**
 * Create an ORT-Web session: WebGPU first when `intent` is 'gpu', with WASM as
 * the fallback (and the only attempt for 'cpu'). Backend events are dispatched
 * around each attempt.
 *
 * @param {Uint8Array|ArrayBuffer} modelBytes  Serialized model, passed to ORT as-is.
 * @param {'gpu'|'cpu'} intent
 * @param {{ profile?: boolean, preferredOutputLocation?: string }} [options]
 *   `preferredOutputLocation` is applied to the WebGPU attempt only.
 * @returns {Promise<{ session: object, realizedBackend: 'web-webgpu' | 'web-wasm' }>}
 * @throws {Error} If `globalThis.ort` is missing, or the WASM session cannot be created.
 */
async function loadWeb(modelBytes, intent, { profile = false, preferredOutputLocation } = {}) {
  const ort = globalThis.ort;
  if (!ort) {
    throw new Error('[backend] ort-web is not loaded — include vendor/onnxruntime-web/ort.all.min.js before using loadSession');
  }

  const wantsGpu = intent === 'gpu';

  // Global ORT environment, rewritten on every load.
  ort.env.wasm.wasmPaths = globalThis.__ORT_WASM_PATHS__
    || new URL('vendor/onnxruntime-web/', document.baseURI).toString();
  ort.env.wasm.numThreads = navigator.hardwareConcurrency || 4;

  // profilingMode is process-wide state, so it is explicitly switched off
  // whenever this load is not a profiled GPU load.
  if (ort.env.webgpu) {
    ort.env.webgpu.profilingMode = profile && wantsGpu ? 'default' : 'off';
  }

  // One options object is reused (and mutated) across both attempts.
  const sessionOpts = { graphOptimizationLevel: 'all' };
  if (profile) {
    sessionOpts.enableProfiling = true;
  }

  if (wantsGpu) {
    sessionOpts.executionProviders = [{ name: 'webgpu', preferredLayout: 'NCHW' }];
    if (preferredOutputLocation) {
      sessionOpts.preferredOutputLocation = preferredOutputLocation;
    }
    dispatchBackendEvent({ kind: 'attempt', backend: 'web-webgpu' });
    try {
      const gpuSession = await ort.InferenceSession.create(modelBytes, sessionOpts);
      dispatchBackendEvent({ kind: 'success', backend: 'web-webgpu' });
      return { session: gpuSession, realizedBackend: 'web-webgpu' };
    } catch (err) {
      console.warn(`[backend] WebGPU failed, falling back to WASM. Reason:`, err);
      dispatchBackendEvent({ kind: 'fallback', backend: 'web-webgpu', reason: shortenReason(err) });
      // The output-location hint only applies to WebGPU; drop it for the retry.
      delete sessionOpts.preferredOutputLocation;
    }
  }

  sessionOpts.executionProviders = ['wasm'];
  dispatchBackendEvent({ kind: 'attempt', backend: 'web-wasm' });
  const wasmSession = await ort.InferenceSession.create(modelBytes, sessionOpts);
  dispatchBackendEvent({ kind: 'success', backend: 'web-wasm' });
  return { session: wasmSession, realizedBackend: 'web-wasm' };
}

// ───── helpers ─────────────────────────────────────────────────────────────

/**
 * Normalise model bytes to a plain ArrayBuffer.
 *
 * An ArrayBuffer is returned as-is (no copy). Any ArrayBuffer view — a typed
 * array or DataView — yields a copy of exactly the bytes the view covers,
 * which matters when the view spans only part of a larger buffer.
 *
 * @param {ArrayBuffer|ArrayBufferView} modelBytes
 * @returns {ArrayBuffer}
 * @throws {Error} If `modelBytes` is neither an ArrayBuffer nor a view over one.
 */
function toArrayBuffer(modelBytes) {
  if (modelBytes instanceof ArrayBuffer) return modelBytes;
  if (ArrayBuffer.isView(modelBytes)) {
    const { buffer, byteOffset, byteLength } = modelBytes;
    return buffer.slice(byteOffset, byteOffset + byteLength);
  }

  // `typeof` alone reports "object" for null, arrays and class instances, so
  // name the actual kind of value to make the failure diagnosable.
  let got;
  if (modelBytes === null) {
    got = 'null';
  } else if (typeof modelBytes === 'object' && modelBytes.constructor && modelBytes.constructor.name) {
    got = modelBytes.constructor.name;
  } else {
    got = typeof modelBytes;
  }
  throw new Error(`[backend] modelBytes must be ArrayBuffer or a typed array/DataView, got ${got}`);
}