/**
 * Thin runtime adapter around onnxruntime-{web,node}.
 *
 * Why an adapter
 * ──────────────
 * ``onnxruntime-web`` and ``onnxruntime-node`` ship with subtly
 * different ``InferenceSession.create`` signatures and execution
 * provider names (``'wasm'``/``'webgpu'`` vs ``'cpu'``/``'cuda'``).
 * Userland code shouldn't have to know which one is loaded — we pick
 * at runtime based on whether ``window`` exists, then re-export a
 * unified surface.
 *
 * Both packages are *peer dependencies* (``peerDependenciesMeta``
 * marks them optional) so users only install the one they need.
 */

import type * as OrtWeb from 'onnxruntime-web';
import type * as OrtNode from 'onnxruntime-node';

type OrtModule = typeof OrtWeb | typeof OrtNode;
type ExecutionProvider = 'webgpu' | 'wasm' | 'cpu' | 'cuda';

/** Re-exported ORT session interface, narrowed to what we use. */
export interface OrtSession {
  inputNames: readonly string[];
  outputNames: readonly string[];
  run(
    feeds: Record<string, OrtTensor>,
  ): Promise<Record<string, OrtTensor>>;
}

/** ORT tensor — matches both ``onnxruntime-web`` and
 *  ``onnxruntime-node`` tensors at the structural level. */
export interface OrtTensor {
  type: string;
  data:
    | Float32Array
    | Int32Array
    | BigInt64Array
    | Uint8Array
    | Uint16Array;
  dims: readonly number[];
}

/** Module-level cache so the dynamic import is resolved at most once
 *  per successful load. */
let _ort: OrtModule | null = null;

/** True when both ``window`` and ``document`` are defined. */
function isBrowserRuntime(): boolean {
  return typeof window !== 'undefined' && typeof document !== 'undefined';
}

/**
 * Build the "runtime missing" error while keeping the underlying import
 * failure reachable through ``cause``.
 *
 * ``cause`` is defined by hand rather than through the two-argument
 * ``Error`` constructor so the file still compiles against pre-ES2022
 * lib typings.
 */
function runtimeLoadError(message: string, cause: unknown): Error {
  const err = new Error(message);
  Object.defineProperty(err, 'cause', {
    value: cause,
    writable: true,
    configurable: true,
    enumerable: false,
  });
  return err;
}

/**
 * Load whichever ORT package is installed. Browser prefers
 * ``onnxruntime-web``; Node prefers ``onnxruntime-node`` but will
 * accept ``-web`` as a fallback (slower but works).
 *
 * @returns The loaded ORT module; repeated calls return the cached one.
 * @throws Error if no runtime can be imported. The import failure is
 *   attached as ``cause`` (in Node: the ``onnxruntime-node`` failure,
 *   since that is the preferred runtime).
 */
export async function loadOrt(): Promise<OrtModule> {
  if (_ort) return _ort;

  // Vite/webpack treat ``import('onnxruntime-web')`` as a static spec
  // and bundle the module if it's a real dep; with peer-deps it stays
  // dynamic and only loads when present. ``catch`` handles the
  // missing-peer-dep case and (in Node) falls back to the alternate
  // runtime.
  if (isBrowserRuntime()) {
    try {
      _ort = (await import('onnxruntime-web')) as unknown as OrtModule;
      return _ort;
    } catch (webError: unknown) {
      throw runtimeLoadError(
        '@cp500/infon-coref: onnxruntime-web is required in the ' +
          'browser. Install with: npm install onnxruntime-web',
        webError,
      );
    }
  }

  // Node.
  try {
    _ort = (await import('onnxruntime-node')) as unknown as OrtModule;
    return _ort;
  } catch (nodeError: unknown) {
    try {
      _ort = (await import('onnxruntime-web')) as unknown as OrtModule;
      return _ort;
    } catch {
      throw runtimeLoadError(
        '@cp500/infon-coref: onnxruntime-node (preferred) or ' +
          'onnxruntime-web is required. Install with: ' +
          'npm install onnxruntime-node',
        nodeError,
      );
    }
  }
}

/**
 * Resolve ``'auto'`` to a concrete EP for the current runtime.
 *
 * @param device Explicit provider, or ``'auto'`` to pick one.
 * @param _ortModule Currently unused; kept so the call signature is
 *   unchanged.
 * @returns ``device`` unchanged when explicit; otherwise ``'webgpu'`` or
 *   ``'wasm'`` in the browser and ``'cpu'`` elsewhere.
 */
async function resolveProvider(
  device: 'auto' | ExecutionProvider,
  _ortModule: OrtModule,
): Promise<ExecutionProvider> {
  if (device !== 'auto') return device;

  if (isBrowserRuntime()) {
    // Probe WebGPU. Even when the API exists, the device may be
    // unreachable (older Macbooks, Firefox without flags, etc.).
    const gpu = (
      navigator as { gpu?: { requestAdapter(): Promise<unknown> } }
    ).gpu;
    if (typeof gpu !== 'undefined') {
      try {
        const adapter = await gpu.requestAdapter();
        if (adapter) return 'webgpu';
      } catch {
        /* fall through */
      }
    }
    return 'wasm';
  }

  // Node: default CPU (CUDA needs explicit opt-in via ``device: 'cuda'``).
  return 'cpu';
}

/**
 * Helper: create an ONNX inference session with sane defaults.
 *
 * @param modelPath Model location or in-memory model bytes, forwarded
 *   as-is to ``InferenceSession.create``.
 * @param device Execution provider, or ``'auto'`` (default) to let
 *   ``resolveProvider`` choose.
 * @returns The created session, typed as the narrowed ``OrtSession``.
 * @throws Error from ``loadOrt`` when no runtime is installed; any
 *   rejection from ``InferenceSession.create`` propagates unchanged.
 */
export async function createSession(
  modelPath: string | ArrayBuffer | Uint8Array,
  device: 'auto' | ExecutionProvider = 'auto',
): Promise<OrtSession> {
  const ort = await loadOrt();
  const provider = await resolveProvider(device, ort);
  // ``as never``: the web and node ``create`` overloads differ, so the
  // union type has no single signature accepting all three input kinds.
  const session = await ort.InferenceSession.create(modelPath as never, {
    executionProviders: [provider],
    graphOptimizationLevel: 'all',
  });
  return session as unknown as OrtSession;
}

/**
 * Helper: build an ORT tensor from a typed array.
 *
 * The web/node packages share a constructor signature but TypeScript
 * doesn't see it because we don't statically import them.
 *
 * @param type ORT element type name.
 * @param data Backing typed array, passed to the ``Tensor`` constructor
 *   unchanged.
 * @param dims Tensor shape.
 * @returns The constructed tensor.
 * @throws Error from ``loadOrt`` when no runtime is installed.
 */
export async function makeTensor(
  type: 'float32' | 'int64' | 'float16',
  data:
    | Float32Array
    | BigInt64Array
    | Uint16Array,
  dims: readonly number[],
): Promise<OrtTensor> {
  const ort = await loadOrt();
  // ``Tensor`` constructor: new ort.Tensor(type, data, dims).
  const TensorCtor = (
    ort as unknown as { Tensor: new (...args: unknown[]) => OrtTensor }
  ).Tensor;
  return new TensorCtor(type, data, dims as number[]);
}