File size: 6,949 Bytes
bce29b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.makeRequestOptions = makeRequestOptions;
exports.makeRequestOptionsFromResolvedModel = makeRequestOptionsFromResolvedModel;
const config_js_1 = require("../config.js");
const package_js_1 = require("../package.js");
const getInferenceProviderMapping_js_1 = require("./getInferenceProviderMapping.js");
const isUrl_js_1 = require("./isUrl.js");
const errors_js_1 = require("../errors.js");
/**
 * Module-level cache of task definitions, lazy-loaded from
 * `${HF_HUB_URL}/api/tasks` on first use (see loadTaskInfo).
 * Used to determine the default model for a task when the caller
 * did not specify one. Stays `null` until the first lookup.
 */
let tasks = null;
/**
 * Helper that prepares request arguments.
 * This async version handles the model ID resolution step: it validates the
 * inputs, resolves a default model for the task when none is given, and looks
 * up (or builds, for client-side-routed providers) the provider mapping before
 * delegating to the sync helper.
 */
async function makeRequestOptions(args, providerHelper, options) {
    const { model: maybeModel } = args;
    const { provider } = providerHelper;
    const task = options?.task;
    // --- Input validation ---
    // A custom endpointUrl is only meaningful for the hf-inference provider.
    if (args.endpointUrl && provider !== "hf-inference") {
        throw new errors_js_1.InferenceClientInputError(`Cannot use endpointUrl with a third-party provider.`);
    }
    // Passing a URL as the model was a legacy escape hatch; reject it.
    if (maybeModel && (0, isUrl_js_1.isUrl)(maybeModel)) {
        throw new errors_js_1.InferenceClientInputError(`Model URLs are no longer supported. Use endpointUrl instead.`);
    }
    if (args.endpointUrl) {
        // A custom endpoint bypasses model resolution entirely: no need for
        // maybeModel, and no need to load a default model for the task.
        return makeRequestOptionsFromResolvedModel(maybeModel ?? args.endpointUrl, providerHelper, args, undefined, options);
    }
    if (!maybeModel && !task) {
        throw new errors_js_1.InferenceClientInputError("No model provided, and no task has been specified.");
    }
    // At this point either maybeModel is set, or task is set (checked above),
    // so resolving a default model is safe.
    const hfModel = maybeModel ?? (await loadDefaultModel(task));
    if (providerHelper.clientSideRoutingOnly && !maybeModel) {
        throw new errors_js_1.InferenceClientInputError(`Provider ${provider} requires a model ID to be passed directly.`);
    }
    let inferenceProviderMapping;
    if (providerHelper.clientSideRoutingOnly) {
        // Client-side routing: the mapping is derived locally from the
        // "<provider>/<model>" prefixed model ID; no Hub lookup is made.
        inferenceProviderMapping = {
            providerId: removeProviderPrefix(maybeModel, provider),
            hfModelId: maybeModel,
            status: "live",
            task: task,
        };
    }
    else {
        // Server-assisted routing: ask the Hub which provider-side model ID
        // corresponds to this HF model for the given task.
        inferenceProviderMapping = await (0, getInferenceProviderMapping_js_1.getInferenceProviderMapping)({
            modelId: hfModel,
            task: task,
            provider,
            accessToken: args.accessToken,
        }, { fetch: options?.fetch });
    }
    if (!inferenceProviderMapping) {
        throw new errors_js_1.InferenceClientInputError(`We have not been able to find inference provider information for model ${hfModel}.`);
    }
    // Delegate to the sync version now that the model is resolved.
    return makeRequestOptionsFromResolvedModel(inferenceProviderMapping.providerId, providerHelper, args, inferenceProviderMapping, options);
}
/**
 * Helper that prepares request arguments - for internal use only.
 * This sync version skips the model ID resolution step: it builds the target
 * URL, the headers (auth, billing, user-agent) and the request body from an
 * already-resolved model / provider mapping.
 */
function makeRequestOptionsFromResolvedModel(resolvedModel, providerHelper, args, mapping, options) {
    const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
    // `model` and the caller-supplied provider are intentionally discarded
    // here; only the resolved model and providerHelper.provider are used.
    void model;
    void maybeProvider;
    const provider = providerHelper.provider;
    const { includeCredentials, task, signal, billTo } = options ?? {};
    // --- Determine the authentication method ---
    // Closed-source (client-side routed) providers require their own key and
    // cannot be reached through HF routing with an hf_ token.
    if (providerHelper.clientSideRoutingOnly && accessToken && accessToken.startsWith("hf_")) {
        throw new errors_js_1.InferenceClientInputError(`Provider ${provider} is closed-source and does not support HF tokens.`);
    }
    let authMethod;
    if (accessToken) {
        // An explicit accessToken takes precedence over includeCredentials.
        authMethod = accessToken.startsWith("hf_") ? "hf-token" : "provider-key";
    }
    else if (includeCredentials === "include") {
        authMethod = "credentials-include";
    }
    else {
        authMethod = "none";
    }
    // --- Make URL ---
    const url = providerHelper.makeUrl({
        authMethod,
        model: endpointUrl ?? resolvedModel,
        task,
    });
    // --- Make headers ---
    const isBinaryPayload = "data" in args && !!args.data;
    const headers = providerHelper.prepareHeaders({ accessToken, authMethod }, isBinaryPayload);
    if (billTo) {
        headers[config_js_1.HF_HEADER_X_BILL_TO] = billTo;
    }
    // User-agent, e.g. "@huggingface/inference/3.1.3", optionally followed by
    // the browser's user agent when one is available.
    const userAgentParts = [`${package_js_1.PACKAGE_NAME}/${package_js_1.PACKAGE_VERSION}`];
    if (typeof navigator !== "undefined" && navigator.userAgent !== undefined) {
        userAgentParts.push(navigator.userAgent);
    }
    headers["User-Agent"] = userAgentParts.join(" ");
    // --- Make body ---
    const body = providerHelper.makeBody({
        args: remainingArgs,
        model: resolvedModel,
        task,
        mapping,
    });
    /**
     * For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
     */
    const credentials = typeof includeCredentials === "string"
        ? includeCredentials
        : includeCredentials === true
            ? "include"
            : undefined;
    const info = {
        headers,
        method: "POST",
        body,
        ...(credentials ? { credentials } : undefined),
        signal,
    };
    return { url, info };
}
/**
 * Resolves the default model ID for a task, fetching (and caching in the
 * module-level `tasks` variable) the task definitions on first use.
 * @throws InferenceClientInputError when the task has no default model.
 */
async function loadDefaultModel(task) {
    if (!tasks) {
        // First lookup: populate the cache from the Hub.
        tasks = await loadTaskInfo();
    }
    const taskInfo = tasks[task];
    const modelCount = taskInfo?.models.length ?? 0;
    if (modelCount <= 0) {
        throw new errors_js_1.InferenceClientInputError(`No default model defined for task ${task}, please define the model explicitly.`);
    }
    // The first recommended model is used as the default.
    return taskInfo.models[0].id;
}
/**
 * Fetches the task definitions from the Hugging Face Hub API.
 * @throws InferenceClientHubApiError when the HTTP request does not succeed.
 */
async function loadTaskInfo() {
    const url = `${config_js_1.HF_HUB_URL}/api/tasks`;
    const response = await fetch(url);
    if (response.ok) {
        return await response.json();
    }
    // Surface the request id / status / body so failures are debuggable.
    throw new errors_js_1.InferenceClientHubApiError("Failed to load tasks definitions from Hugging Face Hub.", { url, method: "GET" }, {
        requestId: response.headers.get("x-request-id") ?? "",
        status: response.status,
        body: await response.text(),
    });
}
/**
 * Strips the mandatory `<provider>/` prefix from a model ID.
 * @throws InferenceClientInputError when the prefix is missing.
 */
function removeProviderPrefix(model, provider) {
    const prefix = `${provider}/`;
    if (model.startsWith(prefix)) {
        return model.slice(prefix.length);
    }
    throw new errors_js_1.InferenceClientInputError(`Models from ${provider} must be prefixed by "${provider}/". Got "${model}".`);
}
|