/**
* @file Handler file for choosing the correct version of ONNX Runtime, based on the environment.
* Ideally, we could import the `onnxruntime-web` and `onnxruntime-node` packages only when needed,
* but dynamic imports don't seem to work with the current webpack version and/or configuration.
* This is possibly due to the experimental nature of top-level await statements.
* So, we just import both packages, and use the appropriate one based on the environment:
* - When running in node, we use `onnxruntime-node`.
* - When running in the browser, we use `onnxruntime-web` (`onnxruntime-node` is not bundled).
*
* This module is not directly exported, but can be accessed through the environment variables:
* ```javascript
* import { env } from '@huggingface/transformers';
* console.log(env.backends.onnx);
* ```
*
* @module backends/onnx
*/
import { env, apis } from '../env.js';
// NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`.
// In either case, we select the default export if it exists, otherwise we use the named export.
import * as ONNX_NODE from 'onnxruntime-node';
import * as ONNX_WEB from 'onnxruntime-web';
export { Tensor } from 'onnxruntime-common';
/**
* @typedef {import('onnxruntime-common').InferenceSession.ExecutionProviderConfig} ONNXExecutionProviders
*/
/** @type {Record<import("../utils/devices.js").DeviceType, ONNXExecutionProviders>} */
const DEVICE_TO_EXECUTION_PROVIDER_MAPPING = Object.freeze({
    auto: null, // Auto-detect based on device and environment
    gpu: null, // Auto-detect GPU
    cpu: 'cpu', // CPU
    wasm: 'wasm', // WebAssembly
    webgpu: 'webgpu', // WebGPU
    cuda: 'cuda', // CUDA
    dml: 'dml', // DirectML
    webnn: { name: 'webnn', deviceType: 'cpu' }, // WebNN (default)
    'webnn-npu': { name: 'webnn', deviceType: 'npu' }, // WebNN NPU
    'webnn-gpu': { name: 'webnn', deviceType: 'gpu' }, // WebNN GPU
    'webnn-cpu': { name: 'webnn', deviceType: 'cpu' }, // WebNN CPU
});
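// For illustration only (not executed): looking up a device string in the mapping above yields the
// execution-provider entry that is later passed to onnxruntime, e.g.:
//
//   DEVICE_TO_EXECUTION_PROVIDER_MAPPING['webgpu'];    // 'webgpu'
//   DEVICE_TO_EXECUTION_PROVIDER_MAPPING['webnn-gpu']; // { name: 'webnn', deviceType: 'gpu' }
//   DEVICE_TO_EXECUTION_PROVIDER_MAPPING['auto'];      // null (resolved by deviceToExecutionProviders below)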
/**
* The list of supported devices, sorted by priority/performance.
* @type {import("../utils/devices.js").DeviceType[]}
*/
const supportedDevices = [];
/** @type {ONNXExecutionProviders[]} */
let defaultDevices;
let ONNX;
const ORT_SYMBOL = Symbol.for('onnxruntime');
if (ORT_SYMBOL in globalThis) {
    // If the JS runtime exposes its own ONNX runtime, use it
    ONNX = globalThis[ORT_SYMBOL];
} else if (apis.IS_NODE_ENV) {
    ONNX = ONNX_NODE.default ?? ONNX_NODE;
    // Updated as of ONNX Runtime 1.20.1
    // The following table lists the supported versions of ONNX Runtime Node.js binding provided with pre-built binaries.
    // | EPs/Platforms | Windows x64 | Windows arm64 | Linux x64       | Linux arm64 | MacOS x64 | MacOS arm64 |
    // | ------------- | ----------- | ------------- | --------------- | ----------- | --------- | ----------- |
    // | CPU           | ✔️          | ✔️            | ✔️              | ✔️          | ✔️        | ✔️          |
    // | DirectML      | ✔️          | ✔️            | ❌              | ❌          | ❌        | ❌          |
    // | CUDA          | ❌          | ❌            | ✔️ (CUDA v11.8) | ❌          | ❌        | ❌          |
    switch (process.platform) {
        case 'win32': // Windows x64 and Windows arm64
            supportedDevices.push('dml');
            break;
        case 'linux': // Linux x64 and Linux arm64
            if (process.arch === 'x64') {
                supportedDevices.push('cuda');
            }
            break;
        case 'darwin': // MacOS x64 and MacOS arm64
            break;
    }
    supportedDevices.push('cpu');
    defaultDevices = ['cpu'];
} else {
    ONNX = ONNX_WEB;
    if (apis.IS_WEBNN_AVAILABLE) {
        // TODO: Only push supported providers (depending on available hardware)
        supportedDevices.push('webnn-npu', 'webnn-gpu', 'webnn-cpu', 'webnn');
    }
    if (apis.IS_WEBGPU_AVAILABLE) {
        supportedDevices.push('webgpu');
    }
    supportedDevices.push('wasm');
    defaultDevices = ['wasm'];
}
// @ts-ignore
const InferenceSession = ONNX.InferenceSession;
/**
* Map a device to the execution providers to use for the given device.
* @param {import("../utils/devices.js").DeviceType|"auto"|null} [device=null] (Optional) The device to run the inference on.
* @returns {ONNXExecutionProviders[]} The execution providers to use for the given device.
*/
export function deviceToExecutionProviders(device = null) {
    // Use the default execution providers if the user hasn't specified anything
    if (!device) return defaultDevices;
    // Handle overloaded cases
    switch (device) {
        case "auto":
            return supportedDevices;
        case "gpu":
            return supportedDevices.filter(x =>
                ["webgpu", "cuda", "dml", "webnn-gpu"].includes(x),
            );
    }
    if (supportedDevices.includes(device)) {
        return [DEVICE_TO_EXECUTION_PROVIDER_MAPPING[device] ?? device];
    }
    throw new Error(`Unsupported device: "${device}". Should be one of: ${supportedDevices.join(', ')}.`);
}
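// Example usage (illustrative only; the exact results depend on the devices detected above):
//
//   // In a browser with WebGPU available (and no WebNN):
//   deviceToExecutionProviders('auto');   // ['webgpu', 'wasm']
//   deviceToExecutionProviders('webgpu'); // ['webgpu']
//
//   // In Node.js on Linux x64:
//   deviceToExecutionProviders('gpu');    // ['cuda']
//   deviceToExecutionProviders();         // ['cpu'] (the default)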
/**
* To prevent multiple calls to `initWasm()`, we store the first call in a Promise
* that is resolved when the first InferenceSession is created. Subsequent calls
* will wait for this Promise to resolve before creating their own InferenceSession.
* @type {Promise<any>|null}
*/
let wasmInitPromise = null;
/**
* Create an ONNX inference session.
* @param {Uint8Array|string} buffer_or_path The ONNX model buffer or path.
* @param {import('onnxruntime-common').InferenceSession.SessionOptions} session_options ONNX inference session options.
* @param {Object} session_config ONNX inference session configuration.
* @returns {Promise<import('onnxruntime-common').InferenceSession & { config: Object}>} The ONNX inference session.
*/
export async function createInferenceSession(buffer_or_path, session_options, session_config) {
    if (wasmInitPromise) {
        // A previous session has already initialized the WASM runtime,
        // so we wait for it to resolve before creating this new session.
        await wasmInitPromise;
    }
    const sessionPromise = InferenceSession.create(buffer_or_path, session_options);
    wasmInitPromise ??= sessionPromise;
    const session = await sessionPromise;
    session.config = session_config;
    return session;
}
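// Example usage (a sketch; `modelBuffer` and the option values shown are assumptions, not values used
// elsewhere in this file):
//
//   const session = await createInferenceSession(
//       modelBuffer,                      // Uint8Array containing the ONNX model (or a path, in Node.js)
//       { executionProviders: ['wasm'] }, // standard onnxruntime session options
//       { device: 'wasm' },               // arbitrary metadata, attached to the session as `session.config`
//   );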
/**
* Currently, Transformers.js doesn't support simultaneous execution of sessions in WASM/WebGPU.
* For this reason, we need to chain the inference calls (otherwise we get "Error: Session already started").
* @type {Promise<any>}
*/
let webInferenceChain = Promise.resolve();
const IS_WEB_ENV = apis.IS_BROWSER_ENV || apis.IS_WEBWORKER_ENV;
/**
* Run an inference session.
* @param {import('onnxruntime-common').InferenceSession} session The ONNX inference session.
* @param {Record<string, import('onnxruntime-common').Tensor>} ortFeed The input tensors.
* @returns {Promise<Record<string, import('onnxruntime-common').Tensor>>} The output tensors.
*/
export async function runInferenceSession(session, ortFeed) {
    const run = () => session.run(ortFeed);
    const output = await (IS_WEB_ENV
        ? (webInferenceChain = webInferenceChain.then(run))
        : run()
    );
    return output;
}
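// Example usage (a sketch; the input name, dtype, and shape are model-specific assumptions):
//
//   import { Tensor } from 'onnxruntime-common';
//   const outputs = await runInferenceSession(session, {
//       input_ids: new Tensor('int64', new BigInt64Array([1n, 2n, 3n]), [1, 3]),
//   });
//   console.log(outputs.logits?.dims);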
/**
* Check if an object is an ONNX tensor.
* @param {any} x The object to check
* @returns {boolean} Whether the object is an ONNX tensor.
*/
export function isONNXTensor(x) {
    return x instanceof ONNX.Tensor;
}
/** @type {import('onnxruntime-common').Env} */
// @ts-ignore
const ONNX_ENV = ONNX?.env;
if (ONNX_ENV?.wasm) {
    // Initialize wasm backend with suitable default settings.
    // (Optional) Set path to wasm files. This will override the default path search behavior of onnxruntime-web.
    // By default, we only do this if we are not in a service worker and the wasmPaths are not already set.
    if (
        // @ts-ignore Cannot find name 'ServiceWorkerGlobalScope'.ts(2304)
        !(typeof ServiceWorkerGlobalScope !== 'undefined' && self instanceof ServiceWorkerGlobalScope)
        && !ONNX_ENV.wasm.wasmPaths
    ) {
        ONNX_ENV.wasm.wasmPaths = `https://cdn.jsdelivr.net/npm/@huggingface/transformers@${env.version}/dist/`;
    }
    // TODO: Add support for loading WASM files from cached buffer when we upgrade to onnxruntime-web@1.19.0
    // https://github.com/microsoft/onnxruntime/pull/21534
    // Users may wish to proxy the WASM backend to prevent the UI from freezing;
    // however, this is not necessary when using WebGPU, so we default to false.
    ONNX_ENV.wasm.proxy = false;
}
if (ONNX_ENV?.webgpu) {
    ONNX_ENV.webgpu.powerPreference = 'high-performance';
}
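// These defaults can be overridden by users, typically before the first model is loaded, e.g. (a sketch;
// the path shown is a placeholder):
//
//   import { env } from '@huggingface/transformers';
//   env.backends.onnx.wasm.wasmPaths = '/path/to/ort-wasm-files/';  // self-host the .wasm binaries
//   env.backends.onnx.wasm.proxy = true;                            // run the WASM backend in a worker thread
//   env.backends.onnx.wasm.numThreads = 1;                          // disable multi-threading if needed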
/**
* Check if ONNX's WASM backend is being proxied.
* @returns {boolean} Whether ONNX's WASM backend is being proxied.
*/
export function isONNXProxy() {
    // TODO: Update this when allowing non-WASM backends.
    return ONNX_ENV?.wasm?.proxy;
}
// Expose ONNX environment variables to `env.backends.onnx`
env.backends.onnx = ONNX_ENV;