File size: 7,140 Bytes
bce29b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 |
import { HF_HEADER_X_BILL_TO, HF_HUB_URL } from "../config.js";
import { PACKAGE_NAME, PACKAGE_VERSION } from "../package.js";
import type { InferenceTask, Options, RequestArgs } from "../types.js";
import type { InferenceProviderModelMapping } from "./getInferenceProviderMapping.js";
import { getInferenceProviderMapping } from "./getInferenceProviderMapping.js";
import type { getProviderHelper } from "./getProviderHelper.js";
import { isUrl } from "./isUrl.js";
import { InferenceClientHubApiError, InferenceClientInputError } from "../errors.js";
/**
 * Module-level cache for the Hub's task definitions (task name → recommended models).
 * Lazy-loaded from huggingface.co/api/tasks the first time a default model is needed
 * (see loadDefaultModel / loadTaskInfo below); `null` until that first fetch.
 * Used to determine the default model to use when it's not user defined.
 */
let tasks: Record<string, { models: { id: string }[] }> | null = null;
/**
 * Prepare the URL and fetch options for an inference request.
 *
 * Async variant: resolves the model ID first — loading the task's default model
 * from the Hub when none is provided, or the provider mapping otherwise — then
 * delegates to the sync builder `makeRequestOptionsFromResolvedModel`.
 *
 * @throws InferenceClientInputError on invalid input combinations (endpointUrl with a
 * third-party provider, model passed as a URL, or neither model nor task given).
 */
export async function makeRequestOptions(
	args: RequestArgs & {
		data?: Blob | ArrayBuffer;
		stream?: boolean;
	},
	providerHelper: ReturnType<typeof getProviderHelper>,
	options?: Options & {
		/** In most cases (unless we pass a endpointUrl) we know the task */
		task?: InferenceTask;
	}
): Promise<{ url: string; info: RequestInit }> {
	const provider = providerHelper.provider;
	const task = options?.task;
	const maybeModel = args.model;

	// Validate inputs
	if (args.endpointUrl && provider !== "hf-inference") {
		throw new InferenceClientInputError(`Cannot use endpointUrl with a third-party provider.`);
	}
	if (maybeModel && isUrl(maybeModel)) {
		throw new InferenceClientInputError(`Model URLs are no longer supported. Use endpointUrl instead.`);
	}

	// A custom endpoint needs no model resolution (nor a default model for the task).
	if (args.endpointUrl) {
		return makeRequestOptionsFromResolvedModel(maybeModel ?? args.endpointUrl, providerHelper, args, undefined, options);
	}

	if (!maybeModel && !task) {
		throw new InferenceClientInputError("No model provided, and no task has been specified.");
	}
	// Guard above ensures task is defined whenever maybeModel is not.
	// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
	const hfModel = maybeModel ?? (await loadDefaultModel(task!));

	if (providerHelper.clientSideRoutingOnly && !maybeModel) {
		throw new InferenceClientInputError(`Provider ${provider} requires a model ID to be passed directly.`);
	}

	let inferenceProviderMapping: InferenceProviderModelMapping | null | undefined;
	if (providerHelper.clientSideRoutingOnly) {
		// Client-side-routed providers: trust the caller-supplied "<provider>/<id>" model
		// directly instead of querying the Hub for a mapping.
		inferenceProviderMapping = {
			// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
			providerId: removeProviderPrefix(maybeModel!, provider),
			// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
			hfModelId: maybeModel!,
			status: "live",
			// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
			task: task!,
		} satisfies InferenceProviderModelMapping;
	} else {
		inferenceProviderMapping = await getInferenceProviderMapping(
			{
				modelId: hfModel,
				// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
				task: task!,
				provider,
				accessToken: args.accessToken,
			},
			{ fetch: options?.fetch }
		);
	}
	if (!inferenceProviderMapping) {
		throw new InferenceClientInputError(
			`We have not been able to find inference provider information for model ${hfModel}.`
		);
	}

	// Delegate to the sync builder with the resolved provider model ID.
	return makeRequestOptionsFromResolvedModel(
		inferenceProviderMapping.providerId,
		providerHelper,
		args,
		inferenceProviderMapping,
		options
	);
}
/**
 * Build the request URL and `RequestInit` from an already-resolved model ID — for internal use only.
 * Sync variant: skips the model ID resolution step performed by `makeRequestOptions`.
 *
 * @throws InferenceClientInputError when an HF token is used with a closed-source provider.
 */
export function makeRequestOptionsFromResolvedModel(
	resolvedModel: string,
	providerHelper: ReturnType<typeof getProviderHelper>,
	args: RequestArgs & {
		data?: Blob | ArrayBuffer;
		stream?: boolean;
	},
	mapping: InferenceProviderModelMapping | undefined,
	options?: Options & {
		task?: InferenceTask;
	}
): { url: string; info: RequestInit } {
	const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
	// Destructured only to keep them out of remainingArgs; intentionally unused.
	void model;
	void maybeProvider;
	const provider = providerHelper.provider;
	const { includeCredentials, task, signal, billTo } = options ?? {};

	// Decide how the request authenticates.
	const authMethod = (() => {
		// Closed-source providers require their own access key (HF tokens cannot be routed).
		if (providerHelper.clientSideRoutingOnly && accessToken?.startsWith("hf_")) {
			throw new InferenceClientInputError(`Provider ${provider} is closed-source and does not support HF tokens.`);
		}
		if (accessToken) {
			return accessToken.startsWith("hf_") ? "hf-token" : "provider-key";
		}
		// An explicit accessToken (above) takes precedence over includeCredentials.
		return includeCredentials === "include" ? "credentials-include" : "none";
	})();

	// Build the URL (a custom endpoint, when present, wins over the resolved model).
	const url = providerHelper.makeUrl({
		authMethod,
		model: endpointUrl ?? resolvedModel,
		task,
	});

	// Build the headers; the second argument flags a binary (blob) payload.
	const headers = providerHelper.prepareHeaders({ accessToken, authMethod }, "data" in args && !!args.data);
	if (billTo) {
		headers[HF_HEADER_X_BILL_TO] = billTo;
	}
	// Advertise this client, e.g. "@huggingface/inference/3.1.3", followed by the
	// browser's user agent when running in one.
	const userAgentParts = [`${PACKAGE_NAME}/${PACKAGE_VERSION}`];
	if (typeof navigator !== "undefined" && navigator.userAgent !== undefined) {
		userAgentParts.push(navigator.userAgent);
	}
	headers["User-Agent"] = userAgentParts.join(" ");

	// Build the body from the remaining (provider-agnostic) arguments.
	const body = providerHelper.makeBody({
		args: remainingArgs as Record<string, unknown>,
		model: resolvedModel,
		task,
		mapping,
	});

	/**
	 * For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
	 */
	const credentials: RequestCredentials | undefined =
		typeof includeCredentials === "string"
			? (includeCredentials as RequestCredentials)
			: includeCredentials === true
			  ? "include"
			  : undefined;

	const info: RequestInit = {
		headers,
		method: "POST",
		body: body,
		...(credentials ? { credentials } : undefined),
		signal,
	};
	return { url, info };
}
async function loadDefaultModel(task: InferenceTask): Promise<string> {
if (!tasks) {
tasks = await loadTaskInfo();
}
const taskInfo = tasks[task];
if ((taskInfo?.models.length ?? 0) <= 0) {
throw new InferenceClientInputError(
`No default model defined for task ${task}, please define the model explicitly.`
);
}
return taskInfo.models[0].id;
}
/**
 * Fetch the task definitions index from the Hub API (`/api/tasks`).
 *
 * @throws InferenceClientHubApiError when the HTTP request does not succeed.
 */
async function loadTaskInfo(): Promise<Record<string, { models: { id: string }[] }>> {
	const url = `${HF_HUB_URL}/api/tasks`;
	const response = await fetch(url);
	if (!response.ok) {
		throw new InferenceClientHubApiError(
			"Failed to load tasks definitions from Hugging Face Hub.",
			{ url, method: "GET" },
			{
				requestId: response.headers.get("x-request-id") ?? "",
				status: response.status,
				body: await response.text(),
			}
		);
	}
	return await response.json();
}
/**
 * Strip the mandatory `"<provider>/"` prefix from a model ID.
 *
 * @throws InferenceClientInputError when the model ID does not carry the prefix.
 */
function removeProviderPrefix(model: string, provider: string): string {
	const prefix = `${provider}/`;
	if (model.startsWith(prefix)) {
		return model.slice(prefix.length);
	}
	throw new InferenceClientInputError(`Models from ${provider} must be prefixed by "${provider}/". Got "${model}".`);
}
|