import type { WidgetType } from "@huggingface/tasks";
import { HF_HUB_URL } from "../config.js";
import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts.js";
import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference.js";
import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../types.js";
import { typedInclude } from "../utils/typedInclude.js";
import { InferenceClientHubApiError, InferenceClientInputError } from "../errors.js";

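/**
 * Module-level cache of provider mappings, keyed by model id, so repeated
 * lookups for the same model skip the Hub API round-trip.
 */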
export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();

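/** Per-provider entries for a given model, as returned by the Hub API. */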
export type InferenceProviderMapping = Partial<
	Record<InferenceProvider, Omit<InferenceProviderModelMapping, "hfModelId">>
>;

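/**
 * A single provider's entry for a model: the provider-side model id, the task
 * it is served for, its deployment status, and optional adapter information.
 */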
export interface InferenceProviderModelMapping {
	adapter?: string;
	adapterWeightsPath?: string;
	hfModelId: ModelId;
	providerId: string;
	status: "live" | "staging";
	task: WidgetType;
}

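/**
 * Fetch (and cache) the inference provider mapping for a model from the Hub
 * API, via `GET /api/models/{modelId}?expand[]=inferenceProviderMapping`.
 * Throws an `InferenceClientHubApiError` on HTTP errors or malformed responses.
 */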
export async function fetchInferenceProviderMappingForModel(
	modelId: ModelId,
	accessToken?: string,
	options?: {
		fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
	}
): Promise<InferenceProviderMapping> {
	let inferenceProviderMapping: InferenceProviderMapping | null;
	if (inferenceProviderMappingCache.has(modelId)) {
		// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
		inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
	} else {
		const url = `${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`;
		const resp = await (options?.fetch ?? fetch)(url, {
			headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
		});
		if (!resp.ok) {
			if (resp.headers.get("Content-Type")?.startsWith("application/json")) {
				const error = await resp.json();
				if ("error" in error && typeof error.error === "string") {
					throw new InferenceClientHubApiError(
						`Failed to fetch inference provider mapping for model ${modelId}: ${error.error}`,
						{ url, method: "GET" },
						{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: error }
					);
				}
				// The error payload is JSON but has no string `error` field. Throw here:
				// the body has already been consumed, so it cannot be re-read below.
				throw new InferenceClientHubApiError(
					`Failed to fetch inference provider mapping for model ${modelId}`,
					{ url, method: "GET" },
					{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: error }
				);
			} else {
				throw new InferenceClientHubApiError(
					`Failed to fetch inference provider mapping for model ${modelId}`,
					{ url, method: "GET" },
					{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
				);
			}
		}
		// Read the body once: a fetch Response body can only be consumed a single
		// time, so parse JSON from the text rather than calling .json() and .text().
		const rawBody = await resp.text();
		let payload: { inferenceProviderMapping?: InferenceProviderMapping } | null = null;
		try {
			payload = JSON.parse(rawBody);
		} catch {
			throw new InferenceClientHubApiError(
				`Failed to fetch inference provider mapping for model ${modelId}: malformed API response, invalid JSON`,
				{ url, method: "GET" },
				{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: rawBody }
			);
		}
		if (!payload?.inferenceProviderMapping) {
			throw new InferenceClientHubApiError(
				`We have not been able to find inference provider information for model ${modelId}.`,
				{ url, method: "GET" },
				{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: rawBody }
			);
		}
		inferenceProviderMapping = payload.inferenceProviderMapping;
	}
	return inferenceProviderMapping;
}

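/**
 * Resolve the provider-specific mapping for a model and task. Hardcoded
 * overrides take precedence over the mapping fetched from the Hub. Returns
 * `null` when the provider does not serve the model at all, and throws an
 * `InferenceClientInputError` when it serves it for a different task.
 */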
export async function getInferenceProviderMapping(
	params: {
		accessToken?: string;
		modelId: ModelId;
		provider: InferenceProvider;
		task: WidgetType;
	},
	options: {
		fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
	}
): Promise<InferenceProviderModelMapping | null> {
	if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
		return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
	}
	const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
		params.modelId,
		params.accessToken,
		options
	);
	const providerMapping = inferenceProviderMapping[params.provider];
	if (providerMapping) {
		const equivalentTasks =
			params.provider === "hf-inference" && typedInclude(EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS, params.task)
				? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS
				: [params.task];
		if (!typedInclude(equivalentTasks, providerMapping.task)) {
			throw new InferenceClientInputError(
				`Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
			);
		}
		if (providerMapping.status === "staging") {
			console.warn(
				`Model ${params.modelId} is in staging mode for provider ${params.provider}. Meant for test purposes only.`
			);
		}
		return { ...providerMapping, hfModelId: params.modelId };
	}
	return null;
}

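/**
 * Determine which provider a request should be routed to. An explicitly
 * passed provider wins; with an `endpointUrl`, the hf-inference code path is
 * used; otherwise "auto" picks the first provider serving the model, in the
 * order set at https://hf.co/settings/inference-providers.
 *
 * @example
 * // Resolves to the first provider serving this model (model id shown as an example):
 * const provider = await resolveProvider("auto", "meta-llama/Llama-3.1-8B-Instruct");
 */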
export async function resolveProvider(
	provider?: InferenceProviderOrPolicy,
	modelId?: string,
	endpointUrl?: string
): Promise<InferenceProvider> {
	if (endpointUrl) {
		if (provider) {
			throw new InferenceClientInputError("Specifying both endpointUrl and provider is not supported.");
		}
		/// Defaulting to hf-inference helpers / API
		return "hf-inference";
	}
	if (!provider) {
		console.log(
			"Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
		);
		provider = "auto";
	}
	if (provider === "auto") {
		if (!modelId) {
			throw new InferenceClientInputError("Specifying a model is required when provider is 'auto'");
		}
		const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
		provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider | undefined;
	}
	if (!provider) {
		throw new InferenceClientInputError(`No Inference Provider available for model ${modelId}.`);
	}
	return provider;
}