import type { WidgetType } from "@huggingface/tasks";
import { HF_HUB_URL } from "../config.js";
import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts.js";
import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference.js";
import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../types.js";
import { typedInclude } from "../utils/typedInclude.js";
import { InferenceClientHubApiError, InferenceClientInputError } from "../errors.js";
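/**
 * Process-wide cache of provider mappings, keyed by model id, so repeated
 * lookups for the same model avoid extra round-trips to the Hub API.
 */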
export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
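/**
 * Mapping from inference provider name to the provider-specific model info
 * returned by the Hub (`hfModelId` is implied by the lookup key).
 */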
export type InferenceProviderMapping = Partial<
  Record<InferenceProvider, Omit<InferenceProviderModelMapping, "hfModelId">>
>;
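/**
 * Provider-side description of a model: the id under which the provider
 * serves it, the task it is deployed for, and optional LoRA adapter info.
 */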
export interface InferenceProviderModelMapping {
  adapter?: string;
  adapterWeightsPath?: string;
  hfModelId: ModelId;
  providerId: string;
  status: "live" | "staging";
  task: WidgetType;
}
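/**
 * Fetches `?expand[]=inferenceProviderMapping` for a model from the Hub and
 * caches the result. Throws `InferenceClientHubApiError` when the API call
 * fails or the response carries no mapping.
 *
 * @example
 * // Illustrative usage; the model id is only an example:
 * const mapping = await fetchInferenceProviderMappingForModel("meta-llama/Llama-3.1-8B-Instruct");
 * console.log(Object.keys(mapping)); // e.g. ["hf-inference", "together", ...]
 */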
export async function fetchInferenceProviderMappingForModel(
  modelId: ModelId,
  accessToken?: string,
  options?: {
    fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
  }
): Promise<InferenceProviderMapping> {
  let inferenceProviderMapping: InferenceProviderMapping | null;
  if (inferenceProviderMappingCache.has(modelId)) {
    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
    inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
  } else {
    const url = `${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`;
    const resp = await (options?.fetch ?? fetch)(url, {
      headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
    });
    if (!resp.ok) {
      if (resp.headers.get("Content-Type")?.startsWith("application/json")) {
        const error = await resp.json();
        if ("error" in error && typeof error.error === "string") {
          throw new InferenceClientHubApiError(
            `Failed to fetch inference provider mapping for model ${modelId}: ${error.error}`,
            { url, method: "GET" },
            { requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: error }
          );
        }
        // A JSON error payload without an `error` field is still a failed request:
        // throw here rather than falling through to the payload parsing below.
        throw new InferenceClientHubApiError(
          `Failed to fetch inference provider mapping for model ${modelId}`,
          { url, method: "GET" },
          { requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: error }
        );
      }
      throw new InferenceClientHubApiError(
        `Failed to fetch inference provider mapping for model ${modelId}`,
        { url, method: "GET" },
        { requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
      );
    }
    // Read the body exactly once: calling resp.text() after resp.json() has
    // consumed the stream would itself throw and mask the real error.
    const rawBody = await resp.text();
    let payload: { inferenceProviderMapping?: InferenceProviderMapping } | null = null;
    try {
      payload = JSON.parse(rawBody);
    } catch {
      throw new InferenceClientHubApiError(
        `Failed to fetch inference provider mapping for model ${modelId}: malformed API response, invalid JSON`,
        { url, method: "GET" },
        { requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: rawBody }
      );
    }
    if (!payload?.inferenceProviderMapping) {
      throw new InferenceClientHubApiError(
        `We have not been able to find inference provider information for model ${modelId}.`,
        { url, method: "GET" },
        { requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: rawBody }
      );
    }
    inferenceProviderMapping = payload.inferenceProviderMapping;
    // Populate the cache so subsequent lookups for this model skip the network call.
    inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
  }
  return inferenceProviderMapping;
}
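/**
 * Resolves the mapping for a single (model, provider) pair: hardcoded
 * overrides win, then the Hub mapping is checked for task compatibility.
 * Returns `null` when the provider does not serve the model.
 *
 * @example
 * // Illustrative usage; model id, provider, and task are only examples:
 * const mapping = await getInferenceProviderMapping(
 *   { modelId: "meta-llama/Llama-3.1-8B-Instruct", provider: "together", task: "text-generation" },
 *   {}
 * );
 * if (mapping) console.log(mapping.providerId);
 */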
export async function getInferenceProviderMapping(
  params: {
    accessToken?: string;
    modelId: ModelId;
    provider: InferenceProvider;
    task: WidgetType;
  },
  options: {
    fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
  }
): Promise<InferenceProviderModelMapping | null> {
  // Optional chaining guards against a provider with no hardcoded entries,
  // and the single lookup avoids indexing the record twice.
  const hardcodedMapping = HARDCODED_MODEL_INFERENCE_MAPPING[params.provider]?.[params.modelId];
  if (hardcodedMapping) {
    return hardcodedMapping;
  }
  const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
    params.modelId,
    params.accessToken,
    options
  );
  const providerMapping = inferenceProviderMapping[params.provider];
  if (providerMapping) {
    const equivalentTasks =
      params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task)
        ? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS
        : [params.task];
    if (!typedInclude(equivalentTasks, providerMapping.task)) {
      throw new InferenceClientInputError(
        `Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
      );
    }
    if (providerMapping.status === "staging") {
      console.warn(
        `Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
      );
    }
    return { ...providerMapping, hfModelId: params.modelId };
  }
  return null;
}
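/**
 * Picks the provider to use for a call. An explicit provider wins; with a
 * custom `endpointUrl` the hf-inference code path is used; `"auto"` (the
 * default) selects the first provider listed for the model on the Hub.
 *
 * @example
 * // Illustrative usage; the model id is only an example:
 * const provider = await resolveProvider("auto", "meta-llama/Llama-3.1-8B-Instruct");
 */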
export async function resolveProvider(
  provider?: InferenceProviderOrPolicy,
  modelId?: string,
  endpointUrl?: string
): Promise<InferenceProvider> {
  if (endpointUrl) {
    if (provider) {
      throw new InferenceClientInputError("Specifying both endpointUrl and provider is not supported.");
    }
    // Defaulting to hf-inference helpers / API
    return "hf-inference";
  }
  if (!provider) {
    console.log(
      "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
    );
    provider = "auto";
  }
  if (provider === "auto") {
    if (!modelId) {
      throw new InferenceClientInputError("Specifying a model is required when provider is 'auto'");
    }
    const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
    provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider | undefined;
  }
  if (!provider) {
    throw new InferenceClientInputError(`No Inference Provider available for model ${modelId}.`);
  }
  return provider;
}