|
|
import { HF_HEADER_X_BILL_TO, HF_HUB_URL } from "../config.js"; |
|
|
import { PACKAGE_NAME, PACKAGE_VERSION } from "../package.js"; |
|
|
import type { InferenceTask, Options, RequestArgs } from "../types.js"; |
|
|
import type { InferenceProviderModelMapping } from "./getInferenceProviderMapping.js"; |
|
|
import { getInferenceProviderMapping } from "./getInferenceProviderMapping.js"; |
|
|
import type { getProviderHelper } from "./getProviderHelper.js"; |
|
|
import { isUrl } from "./isUrl.js"; |
|
|
import { InferenceClientHubApiError, InferenceClientInputError } from "../errors.js"; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Module-level cache of the task definitions returned by `loadTaskInfo`.
 * Starts as `null`; `loadDefaultModel` fills it the first time it runs.
 */
let tasks: null | Record<string, { models: Array<{ id: string }> }> = null;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Builds the URL and `RequestInit` for an inference call, resolving the target model first.
 *
 * Resolution order:
 * 1. `args.endpointUrl` is set: the request is built right away, without a provider mapping.
 * 2. The provider helper is flagged `clientSideRoutingOnly`: the model ID must be passed
 *    explicitly with a `<provider>/` prefix, and the mapping is built locally.
 * 3. Otherwise: `args.model` (or the default model for `options.task`) is looked up through
 *    `getInferenceProviderMapping`.
 *
 * @param args - Request arguments; `model`, `endpointUrl` and `accessToken` are read here,
 *   everything is forwarded to `makeRequestOptionsFromResolvedModel`.
 * @param providerHelper - Helper for the selected provider.
 * @param options - Call options; `task` and `fetch` are read here.
 * @returns The request URL and the `RequestInit` to pass to `fetch`.
 * @throws InferenceClientInputError when `endpointUrl` is combined with a provider other than
 *   `hf-inference`, when `model` is a URL, when neither a model nor a task is given, when a
 *   client-side-routing provider gets no model ID, or when no provider mapping is found.
 *   Errors raised while loading the default model or the provider mapping propagate unchanged.
 */
export async function makeRequestOptions(
	args: RequestArgs & {
		data?: Blob | ArrayBuffer;
		stream?: boolean;
	},
	providerHelper: ReturnType<typeof getProviderHelper>,
	options?: Options & {
		/** In most cases (unless we pass an endpointUrl) we know the task */
		task?: InferenceTask;
	}
): Promise<{ url: string; info: RequestInit }> {
	const { model: maybeModel } = args;
	const provider = providerHelper.provider;
	const { task } = options ?? {};

	// Validate inputs
	if (args.endpointUrl && provider !== "hf-inference") {
		throw new InferenceClientInputError(`Cannot use endpointUrl with a third-party provider.`);
	}
	if (maybeModel && isUrl(maybeModel)) {
		throw new InferenceClientInputError(`Model URLs are no longer supported. Use endpointUrl instead.`);
	}

	if (args.endpointUrl) {
		// An explicit endpoint needs no model resolution and no provider mapping.
		return makeRequestOptionsFromResolvedModel(
			maybeModel ?? args.endpointUrl,
			providerHelper,
			args,
			undefined,
			options
		);
	}

	if (!maybeModel && !task) {
		throw new InferenceClientInputError("No model provided, and no task has been specified.");
	}

	if (providerHelper.clientSideRoutingOnly) {
		// Checked before any default-model lookup so that a missing model ID fails fast
		// instead of first fetching the task definitions from the Hub.
		if (!maybeModel) {
			throw new InferenceClientInputError(`Provider ${provider} requires a model ID to be passed directly.`);
		}
		const clientSideMapping = {
			providerId: removeProviderPrefix(maybeModel, provider),
			hfModelId: maybeModel,
			status: "live",
			// The guard above only guarantees a model OR a task, so `task` may be
			// undefined at runtime here despite the assertion.
			task: task!,
		} satisfies InferenceProviderModelMapping;

		return makeRequestOptionsFromResolvedModel(
			clientSideMapping.providerId,
			providerHelper,
			args,
			clientSideMapping,
			options
		);
	}

	// When no model was passed, the guard above guarantees `task` is set.
	const hfModel = maybeModel ?? (await loadDefaultModel(task!));

	const inferenceProviderMapping = await getInferenceProviderMapping(
		{
			modelId: hfModel,
			task: task!,
			provider,
			accessToken: args.accessToken,
		},
		{ fetch: options?.fetch }
	);
	if (!inferenceProviderMapping) {
		throw new InferenceClientInputError(
			`We have not been able to find inference provider information for model ${hfModel}.`
		);
	}

	return makeRequestOptionsFromResolvedModel(
		inferenceProviderMapping.providerId,
		providerHelper,
		args,
		inferenceProviderMapping,
		options
	);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Assembles the request URL and `RequestInit` once the target model is known.
 * Synchronous: all work is delegated to the provider helper's `makeUrl`,
 * `prepareHeaders` and `makeBody`.
 *
 * @param resolvedModel - Model identifier handed to `makeBody`; also used for the URL
 *   unless `args.endpointUrl` is set.
 * @param providerHelper - Helper for the selected provider.
 * @param args - Request arguments. `accessToken`, `endpointUrl`, `provider` and `model`
 *   are stripped; the remainder is passed to `makeBody`.
 * @param mapping - Provider mapping forwarded to `makeBody`, or `undefined`.
 * @param options - `includeCredentials`, `task`, `signal` and `billTo` are read from here.
 * @returns The request URL and the `RequestInit` (always a `POST`).
 * @throws InferenceClientInputError when the provider is `clientSideRoutingOnly` and the
 *   access token starts with `hf_`.
 */
export function makeRequestOptionsFromResolvedModel(
	resolvedModel: string,
	providerHelper: ReturnType<typeof getProviderHelper>,
	args: RequestArgs & {
		data?: Blob | ArrayBuffer;
		stream?: boolean;
	},
	mapping: InferenceProviderModelMapping | undefined,
	options?: Options & {
		task?: InferenceTask;
	}
): { url: string; info: RequestInit } {
	// `provider` and `model` are pulled out only so they do not end up in `payloadArgs`.
	const { accessToken, endpointUrl, provider: unusedProvider, model: unusedModel, ...payloadArgs } = args;
	void unusedModel;
	void unusedProvider;

	const { provider } = providerHelper;
	const { includeCredentials, task, signal, billTo } = options ?? {};

	if (providerHelper.clientSideRoutingOnly && accessToken?.startsWith("hf_")) {
		throw new InferenceClientInputError(`Provider ${provider} is closed-source and does not support HF tokens.`);
	}

	// NOTE(review): only the literal string "include" selects "credentials-include" here,
	// while `includeCredentials === true` is mapped to credentials "include" further down.
	// Confirm that difference is intended.
	let authMethod: "hf-token" | "provider-key" | "credentials-include" | "none";
	if (accessToken) {
		authMethod = accessToken.startsWith("hf_") ? "hf-token" : "provider-key";
	} else if (includeCredentials === "include") {
		authMethod = "credentials-include";
	} else {
		authMethod = "none";
	}

	// An explicit endpoint URL takes precedence over the resolved model for the URL only.
	const url = providerHelper.makeUrl({
		authMethod,
		model: endpointUrl ?? resolvedModel,
		task,
	});

	const hasData = "data" in args && !!args.data;
	const headers = providerHelper.prepareHeaders({ accessToken, authMethod }, hasData);
	if (billTo) {
		headers[HF_HEADER_X_BILL_TO] = billTo;
	}

	// User-Agent is "<package>/<version>", followed by the runtime's own agent when one exists.
	const agentParts: string[] = [`${PACKAGE_NAME}/${PACKAGE_VERSION}`];
	if (typeof navigator !== "undefined" && navigator.userAgent !== undefined) {
		agentParts.push(navigator.userAgent);
	}
	headers["User-Agent"] = agentParts.join(" ");

	const body = providerHelper.makeBody({
		args: payloadArgs as Record<string, unknown>,
		model: resolvedModel,
		task,
		mapping,
	});

	// A string is passed through as-is, `true` means "include", anything else sets nothing.
	const credentials: RequestCredentials | undefined =
		typeof includeCredentials === "string"
			? (includeCredentials as RequestCredentials)
			: includeCredentials === true
				? "include"
				: undefined;

	return {
		url,
		info: {
			headers,
			method: "POST",
			body,
			...(credentials ? { credentials } : {}),
			signal,
		},
	};
}
|
|
|
|
|
/**
 * Returns the ID of the first model listed for `task` in the Hub task definitions.
 * The definitions are fetched through `loadTaskInfo` on first use and then kept in the
 * module-level `tasks` cache.
 *
 * @param task - Task whose default model is wanted.
 * @returns The `id` of the first model entry for that task.
 * @throws InferenceClientInputError when the task has no entry or its model list is empty.
 */
async function loadDefaultModel(task: InferenceTask): Promise<string> {
	const taskTable = tasks || (await loadTaskInfo());
	tasks = taskTable;

	const entry = taskTable[task];
	const modelCount = entry?.models.length ?? 0;
	if (modelCount > 0) {
		return entry.models[0].id;
	}

	throw new InferenceClientInputError(
		`No default model defined for task ${task}, please define the model explicitly.`
	);
}
|
|
|
|
|
/**
 * Fetches the task definitions from `${HF_HUB_URL}/api/tasks` using the global `fetch`.
 *
 * @returns The parsed JSON response; it is not validated against the declared type.
 * @throws InferenceClientHubApiError when the response status is not OK; the error carries
 *   the `x-request-id` header (or an empty string), the status code and the response text.
 */
async function loadTaskInfo(): Promise<Record<string, { models: { id: string }[] }>> {
	const url = `${HF_HUB_URL}/api/tasks`;
	const res = await fetch(url);

	if (res.ok) {
		return await res.json();
	}

	const requestId = res.headers.get("x-request-id") ?? "";
	const body = await res.text();
	throw new InferenceClientHubApiError(
		"Failed to load tasks definitions from Hugging Face Hub.",
		{ url, method: "GET" },
		{ requestId, status: res.status, body }
	);
}
|
|
|
|
|
/**
 * Strips the leading `<provider>/` segment from a model ID.
 *
 * @param model - Model ID expected to start with `<provider>/`.
 * @param provider - Provider name that forms the prefix.
 * @returns Everything after the first `<provider>/`.
 * @throws InferenceClientInputError when `model` does not start with `<provider>/`.
 */
function removeProviderPrefix(model: string, provider: string): string {
	const prefix = `${provider}/`;
	if (model.startsWith(prefix)) {
		return model.slice(prefix.length);
	}
	throw new InferenceClientInputError(`Models from ${provider} must be prefixed by "${provider}/". Got "${model}".`);
}
|
|
|