cp500's picture
Upload js/src/ort.ts with huggingface_hub
b0225a3 verified
Raw
History Blame Contribute Delete
4.81 kB
/**
* Thin runtime adapter around onnxruntime-{web,node}.
*
* Why an adapter
* ──────────────
* ``onnxruntime-web`` and ``onnxruntime-node`` ship with subtly
* different ``InferenceSession.create`` signatures and execution
* provider names (``'wasm'``/``'webgpu'`` vs ``'cpu'``/``'cuda'``).
* Userland code shouldn't have to know which one is loaded — we pick
* at runtime based on whether ``window`` and ``document`` both exist, then re-export a
* unified surface.
*
* Both packages are *peer dependencies* (``peerDependenciesMeta``
* marks them optional) so users only install the one they need.
*/
import type * as OrtWeb from 'onnxruntime-web';
import type * as OrtNode from 'onnxruntime-node';
/** Module namespace of whichever ORT package ``loadOrt`` imported. */
type OrtModule = typeof OrtNode | typeof OrtWeb;
/**
 * Execution providers this adapter knows how to request:
 * ``'webgpu'``/``'wasm'`` on the web side, ``'cpu'``/``'cuda'`` on the Node side.
 */
type ExecutionProvider = 'cpu' | 'cuda' | 'wasm' | 'webgpu';
/**
 * Structural view of an ORT ``InferenceSession``, narrowed to the members
 * this package relies on plus ``release`` for cleanup.
 *
 * ``release`` is declared optional so that existing hand-written
 * implementations (mocks, wrappers) still satisfy the interface. Callers
 * that are done with a session should invoke ``await session.release?.()``
 * so the underlying runtime can free the model's resources.
 */
export interface OrtSession {
  inputNames: readonly string[];
  outputNames: readonly string[];
  run(
    feeds: Record<string, OrtTensor>,
  ): Promise<Record<string, OrtTensor>>;
  /** Release the session and the runtime resources it holds. */
  release?(): Promise<void>;
}
/**
 * Structural view of an ORT tensor. Tensors from either
 * ``onnxruntime-web`` or ``onnxruntime-node`` fit this shape, so they can
 * flow through the package without knowing which runtime produced them.
 */
export interface OrtTensor {
  /** Element type name, e.g. ``'float32'`` or ``'int64'``. */
  type: string;
  /** Tensor shape. */
  dims: readonly number[];
  /** Typed array holding the tensor's elements. */
  data: Float32Array | Int32Array | BigInt64Array | Uint8Array | Uint16Array;
}
let _ort: OrtModule | null = null;
/**
 * Load whichever ORT package is installed, caching it for later calls.
 *
 * Browser (``window`` and ``document`` both defined): requires
 * ``onnxruntime-web``; there is no fallback. Node: prefers
 * ``onnxruntime-node`` but will accept ``onnxruntime-web`` as a fallback
 * (slower but works).
 *
 * @returns The loaded ORT module namespace (cached after the first success).
 * @throws Error naming the package to install when no runtime can be
 *   imported. The first line of each underlying import failure is appended,
 *   so an installed-but-broken package is not misreported as simply missing.
 */
export async function loadOrt(): Promise<OrtModule> {
  if (_ort) return _ort;
  // Keep only the first line: module-resolution errors can carry a
  // multi-line require stack that would swamp the message.
  const reason = (err: unknown): string =>
    (err instanceof Error ? err.message : String(err)).split('\n')[0] ?? '';
  const isBrowser =
    typeof window !== 'undefined' && typeof document !== 'undefined';
  // Vite/webpack treat ``import('onnxruntime-web')`` as a static spec
  // and bundle the module if it's a real dep; with peer-deps it stays
  // dynamic and only loads when present. The ``catch`` blocks turn a
  // missing or broken peer dep into an actionable error: immediately in
  // the browser, and in Node only after trying the alternate runtime.
  if (isBrowser) {
    try {
      _ort = (await import('onnxruntime-web')) as unknown as OrtModule;
      return _ort;
    } catch (err) {
      throw new Error(
        '@cp500/infon-coref: onnxruntime-web is required in the ' +
          'browser. Install with: npm install onnxruntime-web' +
          ` (import failed: ${reason(err)})`,
      );
    }
  }
  // Node: try the native package first, remembering why it failed.
  let nodeErr: unknown;
  try {
    _ort = (await import('onnxruntime-node')) as unknown as OrtModule;
    return _ort;
  } catch (err) {
    nodeErr = err;
  }
  try {
    _ort = (await import('onnxruntime-web')) as unknown as OrtModule;
    return _ort;
  } catch (webErr) {
    throw new Error(
      '@cp500/infon-coref: onnxruntime-node (preferred) or ' +
        'onnxruntime-web is required. Install with: ' +
        'npm install onnxruntime-node' +
        ` (onnxruntime-node: ${reason(nodeErr)}; ` +
        `onnxruntime-web: ${reason(webErr)})`,
    );
  }
}
/**
 * Map a requested device onto a concrete execution provider.
 *
 * An explicit provider is returned unchanged. For ``'auto'``:
 *   - browser: ``'webgpu'`` when ``navigator.gpu.requestAdapter()`` yields
 *     an adapter, otherwise ``'wasm'`` (a throwing probe also means ``'wasm'``);
 *   - elsewhere (Node): ``'cpu'``.
 *
 * ``ort`` is accepted for call-site symmetry but is not consulted.
 */
async function resolveProvider(
  device: 'auto' | ExecutionProvider,
  ort: OrtModule,
): Promise<ExecutionProvider> {
  const inBrowser =
    typeof window !== 'undefined' && typeof document !== 'undefined';
  if (device !== 'auto') {
    return device;
  }
  if (!inBrowser) {
    // Node: CUDA needs explicit opt-in via ``device: 'cuda'``.
    return 'cpu';
  }
  // Probe WebGPU. Even when the API exists, the device may be
  // unreachable (older Macbooks, Firefox without flags, etc.).
  const gpu = (
    navigator as { gpu?: { requestAdapter(): Promise<unknown> } }
  ).gpu;
  if (typeof gpu === 'undefined') {
    return 'wasm';
  }
  let adapter: unknown = null;
  try {
    adapter = await gpu.requestAdapter();
  } catch {
    // A failing probe is treated exactly like "no adapter".
  }
  return adapter ? 'webgpu' : 'wasm';
}
/**
 * Helper: create an ONNX inference session with sane defaults.
 *
 * ``device`` selects the execution provider. ``'auto'`` resolves to WebGPU
 * (when an adapter is reachable) or WASM in the browser, and to CPU in
 * Node. When ``'auto'`` picks WebGPU, WASM is listed after it as a
 * fallback, so session creation does not fail just because this ORT build
 * or model cannot use the GPU. An explicitly requested provider is passed
 * on its own, so a failure there surfaces instead of being hidden.
 *
 * @param modelPath Model location or raw model bytes, forwarded to
 *   ``InferenceSession.create``.
 * @param device Execution provider to use, or ``'auto'`` (default).
 * @returns The created session, typed as the narrowed ``OrtSession``.
 * @throws Whatever ``loadOrt`` or ``InferenceSession.create`` throws.
 */
export async function createSession(
  modelPath: string | ArrayBuffer | Uint8Array,
  device: 'auto' | ExecutionProvider = 'auto',
): Promise<OrtSession> {
  const ort = await loadOrt();
  const provider = await resolveProvider(device, ort);
  // The adapter probe only proves a GPU exists, not that the loaded
  // runtime/model can use it, so keep WASM behind WebGPU for 'auto'.
  const executionProviders: ExecutionProvider[] =
    device === 'auto' && provider === 'webgpu'
      ? ['webgpu', 'wasm']
      : [provider];
  const session = await ort.InferenceSession.create(modelPath as never, {
    executionProviders,
    graphOptimizationLevel: 'all',
  });
  return session as unknown as OrtSession;
}
/**
 * Helper: build an ORT tensor from a typed array.
 *
 * The web/node packages share the ``new Tensor(type, data, dims)``
 * constructor, but it is reached through a cast because ``OrtModule`` is a
 * union of the two package types.
 *
 * Element types and the typed array each expects:
 *   ``float32`` → Float32Array, ``int64`` → BigInt64Array,
 *   ``float16`` → Uint16Array (raw half-precision bits),
 *   ``int32`` → Int32Array, ``uint8``/``bool`` → Uint8Array.
 *
 * @param type ORT element type name.
 * @param data Typed array holding the elements.
 * @param dims Tensor shape. A copy is handed to ORT so that later mutation
 *   of the caller's array cannot reshape the tensor.
 * @returns The constructed tensor.
 */
export async function makeTensor(
  type: 'float32' | 'int64' | 'float16' | 'int32' | 'uint8' | 'bool',
  data:
    | Float32Array
    | BigInt64Array
    | Uint16Array
    | Int32Array
    | Uint8Array,
  dims: readonly number[],
): Promise<OrtTensor> {
  const ort = await loadOrt();
  const { Tensor } = ort as unknown as {
    Tensor: new (...args: unknown[]) => OrtTensor;
  };
  return new Tensor(type, data, [...dims]);
}