File size: 5,407 Bytes
bce29b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.inferenceProviderMappingCache = void 0;
exports.fetchInferenceProviderMappingForModel = fetchInferenceProviderMappingForModel;
exports.getInferenceProviderMapping = getInferenceProviderMapping;
exports.resolveProvider = resolveProvider;
const config_js_1 = require("../config.js");
const consts_js_1 = require("../providers/consts.js");
const hf_inference_js_1 = require("../providers/hf-inference.js");
const typedInclude_js_1 = require("../utils/typedInclude.js");
const errors_js_1 = require("../errors.js");
// Module-level Map used as a cache of inference provider mappings, keyed by model id.
// fetchInferenceProviderMappingForModel looks a model up here before querying the Hub API.
exports.inferenceProviderMappingCache = new Map();
/**
 * Retrieve the inference provider mapping of a model from the Hub.
 *
 * The mapping is looked up in `inferenceProviderMappingCache` first. On a cache miss it is
 * requested from `${HF_HUB_URL}/api/models/<modelId>?expand[]=inferenceProviderMapping` and,
 * once successfully retrieved, stored in the cache so later calls for the same model id do
 * not hit the network again. Failed requests are never cached.
 *
 * @param {string} modelId - Model id, interpolated as-is into the Hub API URL.
 * @param {string} [accessToken] - Sent as `Authorization: Bearer <token>` only when it starts with "hf_".
 * @param {{ fetch?: Function }} [options] - `options.fetch`, when provided, is used instead of the global `fetch`.
 * @returns {Promise<object>} The `inferenceProviderMapping` value of the API payload.
 * @throws {InferenceClientHubApiError} When the response is not OK, is not valid JSON,
 *   or does not contain an `inferenceProviderMapping` field.
 */
async function fetchInferenceProviderMappingForModel(modelId, accessToken, options) {
    const cache = exports.inferenceProviderMappingCache;
    if (cache.has(modelId)) {
        return cache.get(modelId);
    }
    const url = `${config_js_1.HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`;
    const resp = await (options?.fetch ?? fetch)(url, {
        headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
    });
    // A response body can be consumed only once. Read it as text a single time and parse it
    // manually, so every error below can report the body without trying to re-read it.
    const rawBody = await resp.text();
    const parseBody = () => {
        try {
            return { valid: true, value: JSON.parse(rawBody) };
        }
        catch {
            return { valid: false, value: undefined };
        }
    };
    const hubApiError = (message, body) => new errors_js_1.InferenceClientHubApiError(message, { url, method: "GET" }, { requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body });
    const failurePrefix = `Failed to fetch inference provider mapping for model ${modelId}`;
    if (!resp.ok) {
        if (resp.headers.get("Content-Type")?.startsWith("application/json")) {
            const parsedError = parseBody();
            if (parsedError.valid) {
                const details = parsedError.value;
                if (details !== null && typeof details === "object" && typeof details.error === "string") {
                    throw hubApiError(`${failurePrefix}: ${details.error}`, details);
                }
                // JSON error payload without a string `error` field: the request still failed,
                // so report it instead of falling through to the success path.
                throw hubApiError(failurePrefix, details);
            }
        }
        throw hubApiError(failurePrefix, rawBody);
    }
    const parsedPayload = parseBody();
    if (!parsedPayload.valid) {
        throw hubApiError(`${failurePrefix}: malformed API response, invalid JSON`, rawBody);
    }
    const inferenceProviderMapping = parsedPayload.value?.inferenceProviderMapping;
    if (!inferenceProviderMapping) {
        throw hubApiError(`We have not been able to find inference provider information for model ${modelId}.`, rawBody);
    }
    // Populate the cache that is consulted at the top of this function.
    cache.set(modelId, inferenceProviderMapping);
    return inferenceProviderMapping;
}
/**
 * Resolve the provider-specific mapping entry for a model and task.
 *
 * A non-empty entry in HARDCODED_MODEL_INFERENCE_MAPPING for the provider/model pair is
 * returned directly. Otherwise the mapping is obtained through
 * fetchInferenceProviderMappingForModel and the entry for `params.provider` is validated
 * against the requested task.
 *
 * @param {{ provider: string, modelId: string, task: string, accessToken?: string }} params
 * @param {object} [options] - Forwarded unchanged to fetchInferenceProviderMappingForModel.
 * @returns {Promise<object|null>} The mapping entry (with `hfModelId` added when it comes
 *   from the fetched mapping), or `null` when the provider has no entry for the model.
 * @throws {InferenceClientInputError} When the entry's task is not compatible with `params.task`.
 */
async function getInferenceProviderMapping(params, options) {
    const { provider, modelId } = params;
    const hardcodedEntry = consts_js_1.HARDCODED_MODEL_INFERENCE_MAPPING[provider][modelId];
    if (hardcodedEntry) {
        return hardcodedEntry;
    }
    const mappingByProvider = await fetchInferenceProviderMappingForModel(modelId, params.accessToken, options);
    const entry = mappingByProvider[provider];
    if (!entry) {
        return null;
    }
    const { typedInclude } = typedInclude_js_1;
    // Only hf-inference treats the sentence-transformers tasks as interchangeable;
    // every other case accepts the requested task alone.
    let acceptedTasks = [params.task];
    if (provider === "hf-inference") {
        const sentenceTasks = hf_inference_js_1.EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS;
        if (typedInclude(sentenceTasks, params.task)) {
            acceptedTasks = sentenceTasks;
        }
    }
    const isTaskAccepted = typedInclude(acceptedTasks, entry.task);
    if (!isTaskAccepted) {
        throw new errors_js_1.InferenceClientInputError(`Model ${modelId} is not supported for task ${params.task} and provider ${provider}. Supported task: ${entry.task}.`);
    }
    if (entry.status === "staging") {
        console.warn(`Model ${modelId} is in staging mode for provider ${provider}. Meant for test purposes only.`);
    }
    return { ...entry, hfModelId: modelId };
}
/**
 * Decide which inference provider to use.
 *
 * - With an `endpointUrl`, the result is always "hf-inference"; also passing a provider is an error.
 * - A missing provider is treated as "auto" (a notice is logged).
 * - "auto" picks the first key of the mapping returned by fetchInferenceProviderMappingForModel.
 * - Any other provider value is returned unchanged.
 *
 * @param {string} [provider] - Requested provider, "auto", or nothing.
 * @param {string} [modelId] - Model id; required when the provider resolves to "auto".
 * @param {string} [endpointUrl] - Custom endpoint URL, mutually exclusive with `provider`.
 * @returns {Promise<string>} The resolved provider name.
 * @throws {InferenceClientInputError} On conflicting arguments, a missing model for "auto",
 *   or when the model's mapping has no provider.
 */
async function resolveProvider(provider, modelId, endpointUrl) {
    if (endpointUrl) {
        if (!provider) {
            /// Defaulting to hf-inference helpers / API
            return "hf-inference";
        }
        throw new errors_js_1.InferenceClientInputError("Specifying both endpointUrl and provider is not supported.");
    }
    let requested = provider;
    if (!requested) {
        console.log("Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.");
        requested = "auto";
    }
    if (requested !== "auto") {
        return requested;
    }
    if (!modelId) {
        throw new errors_js_1.InferenceClientInputError("Specifying a model is required when provider is 'auto'");
    }
    const mappingByProvider = await fetchInferenceProviderMappingForModel(modelId);
    const firstProvider = Object.keys(mappingByProvider)[0];
    if (!firstProvider) {
        throw new errors_js_1.InferenceClientInputError(`No Inference Provider available for model ${modelId}.`);
    }
    return firstProvider;
}
|