File size: 5,407 Bytes
bce29b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"use strict";
// CommonJS build: Hub helpers for resolving which inference provider serves a model.
Object.defineProperty(exports, "__esModule", { value: true });
exports.inferenceProviderMappingCache = void 0;
exports.fetchInferenceProviderMappingForModel = fetchInferenceProviderMappingForModel;
exports.getInferenceProviderMapping = getInferenceProviderMapping;
exports.resolveProvider = resolveProvider;
const config_js_1 = require("../config.js");
const consts_js_1 = require("../providers/consts.js");
const hf_inference_js_1 = require("../providers/hf-inference.js");
const typedInclude_js_1 = require("../utils/typedInclude.js");
const errors_js_1 = require("../errors.js");
// Exported cache of provider mappings keyed by modelId, consulted before
// hitting the Hub API. Exported so callers can inspect or pre-populate it.
exports.inferenceProviderMappingCache = new Map();
/**
 * Fetch the inference-provider mapping for a model from the Hub API,
 * using `inferenceProviderMappingCache` to avoid repeated round-trips.
 *
 * @param {string} modelId - Hub model id (e.g. "org/model").
 * @param {string} [accessToken] - Only sent as a Bearer header when it looks like a HF token ("hf_" prefix).
 * @param {{fetch?: typeof fetch}} [options] - Optional custom fetch implementation.
 * @returns {Promise<object>} The model's `inferenceProviderMapping` payload.
 * @throws {InferenceClientHubApiError} On HTTP errors, invalid JSON, or a payload without a mapping.
 */
async function fetchInferenceProviderMappingForModel(modelId, accessToken, options) {
    // Serve from cache when possible.
    if (exports.inferenceProviderMappingCache.has(modelId)) {
        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
        return exports.inferenceProviderMappingCache.get(modelId);
    }
    const url = `${config_js_1.HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`;
    const resp = await (options?.fetch ?? fetch)(url, {
        headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
    });
    if (!resp.ok) {
        if (resp.headers.get("Content-Type")?.startsWith("application/json")) {
            const error = await resp.json();
            if ("error" in error && typeof error.error === "string") {
                throw new errors_js_1.InferenceClientHubApiError(`Failed to fetch inference provider mapping for model ${modelId}: ${error.error}`, { url, method: "GET" }, { requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: error });
            }
            // FIX: previously this fell through without throwing, and the code
            // below called resp.json() on an already-consumed body, surfacing a
            // misleading "malformed API response" error. Throw here with the
            // parsed error body instead.
            throw new errors_js_1.InferenceClientHubApiError(`Failed to fetch inference provider mapping for model ${modelId}`, { url, method: "GET" }, { requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: error });
        }
        throw new errors_js_1.InferenceClientHubApiError(`Failed to fetch inference provider mapping for model ${modelId}`, { url, method: "GET" }, { requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() });
    }
    let payload = null;
    try {
        payload = await resp.json();
    }
    catch {
        throw new errors_js_1.InferenceClientHubApiError(`Failed to fetch inference provider mapping for model ${modelId}: malformed API response, invalid JSON`, { url, method: "GET" }, { requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() });
    }
    if (!payload?.inferenceProviderMapping) {
        throw new errors_js_1.InferenceClientHubApiError(`We have not been able to find inference provider information for model ${modelId}.`, { url, method: "GET" }, { requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() });
    }
    // FIX: the cache was read above but never populated here, so (within this
    // module) every call re-fetched. Store successful results for reuse.
    exports.inferenceProviderMappingCache.set(modelId, payload.inferenceProviderMapping);
    return payload.inferenceProviderMapping;
}
/**
 * Resolve the provider-specific mapping entry for a (model, provider, task)
 * triple, preferring hardcoded overrides, then the Hub API.
 *
 * @param {{provider: string, modelId: string, accessToken?: string, task: string}} params
 * @param {{fetch?: typeof fetch}} [options] - Optional custom fetch implementation.
 * @returns {Promise<object|null>} The mapping augmented with `hfModelId`, or null
 *   when the provider does not serve this model.
 * @throws {InferenceClientInputError} When the provider serves the model but for a different task.
 */
async function getInferenceProviderMapping(params, options) {
    // Hardcoded overrides take precedence over the Hub API. FIX: use optional
    // chaining — an unknown provider key previously threw a TypeError here
    // instead of falling through to the API lookup.
    const hardcodedMapping = consts_js_1.HARDCODED_MODEL_INFERENCE_MAPPING[params.provider]?.[params.modelId];
    if (hardcodedMapping) {
        return hardcodedMapping;
    }
    const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(params.modelId, params.accessToken, options);
    const providerMapping = inferenceProviderMapping[params.provider];
    if (!providerMapping) {
        return null;
    }
    // hf-inference treats some sentence-transformers tasks as interchangeable,
    // so accept any of the equivalent tasks in that case.
    const equivalentTasks = params.provider === "hf-inference" && (0, typedInclude_js_1.typedInclude)(hf_inference_js_1.EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task)
        ? hf_inference_js_1.EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS
        : [params.task];
    if (!(0, typedInclude_js_1.typedInclude)(equivalentTasks, providerMapping.task)) {
        throw new errors_js_1.InferenceClientInputError(`Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`);
    }
    if (providerMapping.status === "staging") {
        console.warn(`Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`);
    }
    return { ...providerMapping, hfModelId: params.modelId };
}
/**
 * Decide which inference provider to use for a request.
 *
 * Rules, in order:
 *  - An explicit `endpointUrl` forces the "hf-inference" helpers (and is
 *    incompatible with an explicit `provider`).
 *  - A missing provider defaults to "auto".
 *  - "auto" picks the first provider listed in the model's mapping (the Hub
 *    returns providers sorted by the user's preference order).
 *
 * @param {string} [provider] - Explicit provider name, or undefined.
 * @param {string} [modelId] - Required when resolving "auto".
 * @param {string} [endpointUrl] - Custom endpoint URL, if any.
 * @returns {Promise<string>} The resolved provider name.
 * @throws {InferenceClientInputError} On conflicting arguments, a missing
 *   model for "auto", or when no provider serves the model.
 */
async function resolveProvider(provider, modelId, endpointUrl) {
    if (endpointUrl) {
        if (provider) {
            throw new errors_js_1.InferenceClientInputError("Specifying both endpointUrl and provider is not supported.");
        }
        /// Defaulting to hf-inference helpers / API
        return "hf-inference";
    }
    if (!provider) {
        console.log("Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.");
        provider = "auto";
    }
    // Anything other than "auto" is already a concrete provider choice.
    if (provider !== "auto") {
        return provider;
    }
    if (!modelId) {
        throw new errors_js_1.InferenceClientInputError("Specifying a model is required when provider is 'auto'");
    }
    const mapping = await fetchInferenceProviderMappingForModel(modelId);
    const [firstProvider] = Object.keys(mapping);
    if (!firstProvider) {
        throw new errors_js_1.InferenceClientInputError(`No Inference Provider available for model ${modelId}.`);
    }
    return firstProvider;
}