|
|
import { Template } from "@huggingface/jinja"; |
|
|
import { |
|
|
type InferenceSnippet, |
|
|
type InferenceSnippetLanguage, |
|
|
type ModelDataMinimal, |
|
|
getModelInputSnippet, |
|
|
inferenceSnippetLanguages, |
|
|
} from "@huggingface/tasks"; |
|
|
import type { PipelineType, WidgetType } from "@huggingface/tasks"; |
|
|
import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingface/tasks"; |
|
|
import type { InferenceProviderModelMapping } from "../lib/getInferenceProviderMapping.js"; |
|
|
import { getProviderHelper } from "../lib/getProviderHelper.js"; |
|
|
import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions.js"; |
|
|
import type { InferenceProviderOrPolicy, InferenceTask, RequestArgs } from "../types.js"; |
|
|
import { templates } from "./templates.exported.js"; |
|
|
|
|
|
export type InferenceSnippetOptions = { |
|
|
streaming?: boolean; |
|
|
billTo?: string; |
|
|
accessToken?: string; |
|
|
directRequest?: boolean; |
|
|
endpointUrl?: string; |
|
|
} & Record<string, unknown>; |
|
|
|
|
|
const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const; |
|
|
const JS_CLIENTS = ["fetch", "huggingface.js", "openai"] as const; |
|
|
const SH_CLIENTS = ["curl"] as const; |
|
|
|
|
|
type Client = (typeof SH_CLIENTS)[number] | (typeof PYTHON_CLIENTS)[number] | (typeof JS_CLIENTS)[number]; |
|
|
|
|
|
const CLIENTS: Record<InferenceSnippetLanguage, Client[]> = { |
|
|
js: [...JS_CLIENTS], |
|
|
python: [...PYTHON_CLIENTS], |
|
|
sh: [...SH_CLIENTS], |
|
|
}; |
|
|
|
|
|
const CLIENTS_AUTO_POLICY: Partial<Record<InferenceSnippetLanguage, Client[]>> = { |
|
|
js: ["huggingface.js"], |
|
|
python: ["huggingface_hub"], |
|
|
}; |
|
|
|
|
|
type InputPreparationFn = (model: ModelDataMinimal, opts?: Record<string, unknown>) => object; |
|
|
interface TemplateParams { |
|
|
accessToken?: string; |
|
|
authorizationHeader?: string; |
|
|
baseUrl?: string; |
|
|
fullUrl?: string; |
|
|
inputs?: object; |
|
|
providerInputs?: object; |
|
|
model?: ModelDataMinimal; |
|
|
provider?: InferenceProviderOrPolicy; |
|
|
providerModelId?: string; |
|
|
billTo?: string; |
|
|
methodName?: string; |
|
|
importBase64?: boolean; |
|
|
importJson?: boolean; |
|
|
endpointUrl?: string; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
const hasTemplate = (language: InferenceSnippetLanguage, client: Client, templateName: string): boolean => |
|
|
templates[language]?.[client]?.[templateName] !== undefined; |
|
|
|
|
|
const loadTemplate = ( |
|
|
language: InferenceSnippetLanguage, |
|
|
client: Client, |
|
|
templateName: string |
|
|
): ((data: TemplateParams) => string) => { |
|
|
const template = templates[language]?.[client]?.[templateName]; |
|
|
if (!template) { |
|
|
throw new Error(`Template not found: ${language}/${client}/${templateName}`); |
|
|
} |
|
|
return (data: TemplateParams) => new Template(template).render({ ...data }); |
|
|
}; |
|
|
|
|
|
const snippetImportPythonInferenceClient = loadTemplate("python", "huggingface_hub", "importInferenceClient"); |
|
|
const snippetImportRequests = loadTemplate("python", "requests", "importRequests"); |
|
|
|
|
|
|
|
|
|
|
|
const HF_PYTHON_METHODS: Partial<Record<WidgetType, string>> = { |
|
|
"audio-classification": "audio_classification", |
|
|
"audio-to-audio": "audio_to_audio", |
|
|
"automatic-speech-recognition": "automatic_speech_recognition", |
|
|
"document-question-answering": "document_question_answering", |
|
|
"feature-extraction": "feature_extraction", |
|
|
"fill-mask": "fill_mask", |
|
|
"image-classification": "image_classification", |
|
|
"image-segmentation": "image_segmentation", |
|
|
"image-to-image": "image_to_image", |
|
|
"image-to-text": "image_to_text", |
|
|
"object-detection": "object_detection", |
|
|
"question-answering": "question_answering", |
|
|
"sentence-similarity": "sentence_similarity", |
|
|
summarization: "summarization", |
|
|
"table-question-answering": "table_question_answering", |
|
|
"tabular-classification": "tabular_classification", |
|
|
"tabular-regression": "tabular_regression", |
|
|
"text-classification": "text_classification", |
|
|
"text-generation": "text_generation", |
|
|
"text-to-image": "text_to_image", |
|
|
"text-to-speech": "text_to_speech", |
|
|
"text-to-video": "text_to_video", |
|
|
"token-classification": "token_classification", |
|
|
translation: "translation", |
|
|
"visual-question-answering": "visual_question_answering", |
|
|
"zero-shot-classification": "zero_shot_classification", |
|
|
"zero-shot-image-classification": "zero_shot_image_classification", |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
const HF_JS_METHODS: Partial<Record<WidgetType, string>> = { |
|
|
"automatic-speech-recognition": "automaticSpeechRecognition", |
|
|
"feature-extraction": "featureExtraction", |
|
|
"fill-mask": "fillMask", |
|
|
"image-classification": "imageClassification", |
|
|
"question-answering": "questionAnswering", |
|
|
"sentence-similarity": "sentenceSimilarity", |
|
|
summarization: "summarization", |
|
|
"table-question-answering": "tableQuestionAnswering", |
|
|
"text-classification": "textClassification", |
|
|
"text-generation": "textGeneration", |
|
|
"text2text-generation": "textGeneration", |
|
|
"token-classification": "tokenClassification", |
|
|
"text-to-speech": "textToSpeech", |
|
|
translation: "translation", |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
const ACCESS_TOKEN_ROUTING_PLACEHOLDER = "hf_token_placeholder"; |
|
|
const ACCESS_TOKEN_DIRECT_REQUEST_PLACEHOLDER = "not_hf_token_placeholder"; |
|
|
|
|
|
|
|
|
const snippetGenerator = (templateName: string, inputPreparationFn?: InputPreparationFn) => { |
|
|
return ( |
|
|
model: ModelDataMinimal, |
|
|
provider: InferenceProviderOrPolicy, |
|
|
inferenceProviderMapping?: InferenceProviderModelMapping, |
|
|
opts?: InferenceSnippetOptions |
|
|
): InferenceSnippet[] => { |
|
|
const providerModelId = inferenceProviderMapping?.providerId ?? model.id; |
|
|
|
|
|
let task = model.pipeline_tag as InferenceTask; |
|
|
if ( |
|
|
model.pipeline_tag && |
|
|
["text-generation", "image-text-to-text"].includes(model.pipeline_tag) && |
|
|
model.tags.includes("conversational") |
|
|
) { |
|
|
templateName = opts?.streaming ? "conversationalStream" : "conversational"; |
|
|
inputPreparationFn = prepareConversationalInput; |
|
|
task = "conversational"; |
|
|
} |
|
|
let providerHelper: ReturnType<typeof getProviderHelper>; |
|
|
try { |
|
|
providerHelper = getProviderHelper(provider, task); |
|
|
} catch (e) { |
|
|
console.error(`Failed to get provider helper for ${provider} (${task})`, e); |
|
|
return []; |
|
|
} |
|
|
|
|
|
const placeholder = opts?.directRequest |
|
|
? ACCESS_TOKEN_DIRECT_REQUEST_PLACEHOLDER |
|
|
: ACCESS_TOKEN_ROUTING_PLACEHOLDER; |
|
|
const accessTokenOrPlaceholder = opts?.accessToken ?? placeholder; |
|
|
|
|
|
|
|
|
const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) }; |
|
|
const request = makeRequestOptionsFromResolvedModel( |
|
|
providerModelId, |
|
|
providerHelper, |
|
|
{ |
|
|
accessToken: accessTokenOrPlaceholder, |
|
|
provider, |
|
|
endpointUrl: opts?.endpointUrl, |
|
|
...inputs, |
|
|
} as RequestArgs, |
|
|
inferenceProviderMapping, |
|
|
{ |
|
|
task, |
|
|
billTo: opts?.billTo, |
|
|
} |
|
|
); |
|
|
|
|
|
|
|
|
|
|
|
let providerInputs = inputs; |
|
|
const bodyAsObj = request.info.body; |
|
|
if (typeof bodyAsObj === "string") { |
|
|
try { |
|
|
providerInputs = JSON.parse(bodyAsObj); |
|
|
} catch (e) { |
|
|
console.error("Failed to parse body as JSON", e); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
const params: TemplateParams = { |
|
|
accessToken: accessTokenOrPlaceholder, |
|
|
authorizationHeader: (request.info.headers as Record<string, string>)?.Authorization, |
|
|
baseUrl: removeSuffix(request.url, "/chat/completions"), |
|
|
fullUrl: request.url, |
|
|
inputs: { |
|
|
asObj: inputs, |
|
|
asCurlString: formatBody(inputs, "curl"), |
|
|
asJsonString: formatBody(inputs, "json"), |
|
|
asPythonString: formatBody(inputs, "python"), |
|
|
asTsString: formatBody(inputs, "ts"), |
|
|
}, |
|
|
providerInputs: { |
|
|
asObj: providerInputs, |
|
|
asCurlString: formatBody(providerInputs, "curl"), |
|
|
asJsonString: formatBody(providerInputs, "json"), |
|
|
asPythonString: formatBody(providerInputs, "python"), |
|
|
asTsString: formatBody(providerInputs, "ts"), |
|
|
}, |
|
|
model, |
|
|
provider, |
|
|
providerModelId: providerModelId ?? model.id, |
|
|
billTo: opts?.billTo, |
|
|
endpointUrl: opts?.endpointUrl, |
|
|
}; |
|
|
|
|
|
|
|
|
const clients = provider === "auto" ? CLIENTS_AUTO_POLICY : CLIENTS; |
|
|
return inferenceSnippetLanguages |
|
|
.map((language) => { |
|
|
const langClients = clients[language] ?? []; |
|
|
return langClients |
|
|
.map((client) => { |
|
|
if (!hasTemplate(language, client, templateName)) { |
|
|
return; |
|
|
} |
|
|
const template = loadTemplate(language, client, templateName); |
|
|
if (client === "huggingface_hub" && templateName.includes("basic")) { |
|
|
if (!(model.pipeline_tag && model.pipeline_tag in HF_PYTHON_METHODS)) { |
|
|
return; |
|
|
} |
|
|
params["methodName"] = HF_PYTHON_METHODS[model.pipeline_tag]; |
|
|
} |
|
|
|
|
|
if (client === "huggingface.js" && templateName.includes("basic")) { |
|
|
if (!(model.pipeline_tag && model.pipeline_tag in HF_JS_METHODS)) { |
|
|
return; |
|
|
} |
|
|
params["methodName"] = HF_JS_METHODS[model.pipeline_tag]; |
|
|
} |
|
|
|
|
|
|
|
|
let snippet = template(params).trim(); |
|
|
if (!snippet) { |
|
|
return; |
|
|
} |
|
|
|
|
|
|
|
|
if (client === "huggingface_hub") { |
|
|
const importSection = snippetImportPythonInferenceClient({ ...params }); |
|
|
snippet = `${importSection}\n\n${snippet}`; |
|
|
} else if (client === "requests") { |
|
|
const importSection = snippetImportRequests({ |
|
|
...params, |
|
|
importBase64: snippet.includes("base64"), |
|
|
importJson: snippet.includes("json."), |
|
|
}); |
|
|
snippet = `${importSection}\n\n${snippet}`; |
|
|
} |
|
|
|
|
|
|
|
|
if (snippet.includes(placeholder)) { |
|
|
snippet = replaceAccessTokenPlaceholder( |
|
|
opts?.directRequest, |
|
|
placeholder, |
|
|
snippet, |
|
|
language, |
|
|
provider, |
|
|
opts?.endpointUrl |
|
|
); |
|
|
} |
|
|
|
|
|
|
|
|
return { language, client: client as string, content: snippet }; |
|
|
}) |
|
|
.filter((snippet): snippet is InferenceSnippet => snippet !== undefined); |
|
|
}) |
|
|
.flat(); |
|
|
}; |
|
|
}; |
|
|
|
|
|
const prepareDocumentQuestionAnsweringInput = (model: ModelDataMinimal): object => { |
|
|
return JSON.parse(getModelInputSnippet(model) as string); |
|
|
}; |
|
|
|
|
|
const prepareImageToImageInput = (model: ModelDataMinimal): object => { |
|
|
const data = JSON.parse(getModelInputSnippet(model) as string); |
|
|
return { inputs: data.image, parameters: { prompt: data.prompt } }; |
|
|
}; |
|
|
|
|
|
const prepareConversationalInput = ( |
|
|
model: ModelDataMinimal, |
|
|
opts?: { |
|
|
streaming?: boolean; |
|
|
messages?: ChatCompletionInputMessage[]; |
|
|
temperature?: GenerationParameters["temperature"]; |
|
|
max_tokens?: GenerationParameters["max_new_tokens"]; |
|
|
top_p?: GenerationParameters["top_p"]; |
|
|
} |
|
|
): object => { |
|
|
return { |
|
|
messages: opts?.messages ?? getModelInputSnippet(model), |
|
|
...(opts?.temperature ? { temperature: opts?.temperature } : undefined), |
|
|
...(opts?.max_tokens ? { max_tokens: opts?.max_tokens } : undefined), |
|
|
...(opts?.top_p ? { top_p: opts?.top_p } : undefined), |
|
|
}; |
|
|
}; |
|
|
|
|
|
const prepareQuestionAnsweringInput = (model: ModelDataMinimal): object => { |
|
|
const data = JSON.parse(getModelInputSnippet(model) as string); |
|
|
return { question: data.question, context: data.context }; |
|
|
}; |
|
|
|
|
|
const prepareTableQuestionAnsweringInput = (model: ModelDataMinimal): object => { |
|
|
const data = JSON.parse(getModelInputSnippet(model) as string); |
|
|
return { query: data.query, table: JSON.stringify(data.table) }; |
|
|
}; |
|
|
|
|
|
const snippets: Partial< |
|
|
Record< |
|
|
PipelineType, |
|
|
( |
|
|
model: ModelDataMinimal, |
|
|
provider: InferenceProviderOrPolicy, |
|
|
inferenceProviderMapping?: InferenceProviderModelMapping, |
|
|
opts?: InferenceSnippetOptions |
|
|
) => InferenceSnippet[] |
|
|
> |
|
|
> = { |
|
|
"audio-classification": snippetGenerator("basicAudio"), |
|
|
"audio-to-audio": snippetGenerator("basicAudio"), |
|
|
"automatic-speech-recognition": snippetGenerator("basicAudio"), |
|
|
"document-question-answering": snippetGenerator("documentQuestionAnswering", prepareDocumentQuestionAnsweringInput), |
|
|
"feature-extraction": snippetGenerator("basic"), |
|
|
"fill-mask": snippetGenerator("basic"), |
|
|
"image-classification": snippetGenerator("basicImage"), |
|
|
"image-segmentation": snippetGenerator("basicImage"), |
|
|
"image-text-to-text": snippetGenerator("conversational"), |
|
|
"image-to-image": snippetGenerator("imageToImage", prepareImageToImageInput), |
|
|
"image-to-text": snippetGenerator("basicImage"), |
|
|
"object-detection": snippetGenerator("basicImage"), |
|
|
"question-answering": snippetGenerator("questionAnswering", prepareQuestionAnsweringInput), |
|
|
"sentence-similarity": snippetGenerator("basic"), |
|
|
summarization: snippetGenerator("basic"), |
|
|
"tabular-classification": snippetGenerator("tabular"), |
|
|
"tabular-regression": snippetGenerator("tabular"), |
|
|
"table-question-answering": snippetGenerator("tableQuestionAnswering", prepareTableQuestionAnsweringInput), |
|
|
"text-classification": snippetGenerator("basic"), |
|
|
"text-generation": snippetGenerator("basic"), |
|
|
"text-to-audio": snippetGenerator("textToAudio"), |
|
|
"text-to-image": snippetGenerator("textToImage"), |
|
|
"text-to-speech": snippetGenerator("textToSpeech"), |
|
|
"text-to-video": snippetGenerator("textToVideo"), |
|
|
"text2text-generation": snippetGenerator("basic"), |
|
|
"token-classification": snippetGenerator("basic"), |
|
|
translation: snippetGenerator("basic"), |
|
|
"zero-shot-classification": snippetGenerator("zeroShotClassification"), |
|
|
"zero-shot-image-classification": snippetGenerator("zeroShotImageClassification"), |
|
|
}; |
|
|
|
|
|
export function getInferenceSnippets( |
|
|
model: ModelDataMinimal, |
|
|
provider: InferenceProviderOrPolicy, |
|
|
inferenceProviderMapping?: InferenceProviderModelMapping, |
|
|
opts?: Record<string, unknown> |
|
|
): InferenceSnippet[] { |
|
|
return model.pipeline_tag && model.pipeline_tag in snippets |
|
|
? snippets[model.pipeline_tag]?.(model, provider, inferenceProviderMapping, opts) ?? [] |
|
|
: []; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
function formatBody(obj: object, format: "curl" | "json" | "python" | "ts"): string { |
|
|
switch (format) { |
|
|
case "curl": |
|
|
return indentString(formatBody(obj, "json")); |
|
|
|
|
|
case "json": |
|
|
|
|
|
return JSON.stringify(obj, null, 4).split("\n").slice(1, -1).join("\n"); |
|
|
|
|
|
case "python": |
|
|
return indentString( |
|
|
Object.entries(obj) |
|
|
.map(([key, value]) => { |
|
|
const formattedValue = JSON.stringify(value, null, 4).replace(/"/g, '"'); |
|
|
return `${key}=${formattedValue},`; |
|
|
}) |
|
|
.join("\n") |
|
|
); |
|
|
|
|
|
case "ts": |
|
|
|
|
|
return formatTsObject(obj).split("\n").slice(1, -1).join("\n"); |
|
|
|
|
|
default: |
|
|
throw new Error(`Unsupported format: ${format}`); |
|
|
} |
|
|
} |
|
|
|
|
|
function formatTsObject(obj: unknown, depth?: number): string { |
|
|
depth = depth ?? 0; |
|
|
|
|
|
|
|
|
if (typeof obj !== "object" || obj === null) { |
|
|
return JSON.stringify(obj); |
|
|
} |
|
|
|
|
|
|
|
|
if (Array.isArray(obj)) { |
|
|
const items = obj |
|
|
.map((item) => { |
|
|
const formatted = formatTsObject(item, depth + 1); |
|
|
return `${" ".repeat(4 * (depth + 1))}${formatted},`; |
|
|
}) |
|
|
.join("\n"); |
|
|
return `[\n${items}\n${" ".repeat(4 * depth)}]`; |
|
|
} |
|
|
|
|
|
|
|
|
const entries = Object.entries(obj); |
|
|
const lines = entries |
|
|
.map(([key, value]) => { |
|
|
const formattedValue = formatTsObject(value, depth + 1); |
|
|
const keyStr = /^[a-zA-Z_$][a-zA-Z0-9_$]*$/.test(key) ? key : `"${key}"`; |
|
|
return `${" ".repeat(4 * (depth + 1))}${keyStr}: ${formattedValue},`; |
|
|
}) |
|
|
.join("\n"); |
|
|
return `{\n${lines}\n${" ".repeat(4 * depth)}}`; |
|
|
} |
|
|
|
|
|
function indentString(str: string): string { |
|
|
return str |
|
|
.split("\n") |
|
|
.map((line) => " ".repeat(4) + line) |
|
|
.join("\n"); |
|
|
} |
|
|
|
|
|
function removeSuffix(str: string, suffix: string) { |
|
|
return str.endsWith(suffix) ? str.slice(0, -suffix.length) : str; |
|
|
} |
|
|
|
|
|
function replaceAccessTokenPlaceholder( |
|
|
directRequest: boolean | undefined, |
|
|
placeholder: string, |
|
|
snippet: string, |
|
|
language: InferenceSnippetLanguage, |
|
|
provider: InferenceProviderOrPolicy, |
|
|
endpointUrl?: string |
|
|
): string { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const useHfToken = |
|
|
!endpointUrl && |
|
|
(provider == "hf-inference" || |
|
|
(!directRequest && |
|
|
(!snippet.includes("https://") || |
|
|
snippet.includes("https://router.huggingface.co")))); |
|
|
const accessTokenEnvVar = useHfToken |
|
|
? "HF_TOKEN" |
|
|
: endpointUrl |
|
|
? "API_TOKEN" |
|
|
: provider.toUpperCase().replace("-", "_") + "_API_KEY"; |
|
|
|
|
|
|
|
|
if (language === "sh") { |
|
|
snippet = snippet.replace( |
|
|
`'Authorization: Bearer ${placeholder}'`, |
|
|
`"Authorization: Bearer $${accessTokenEnvVar}"` |
|
|
); |
|
|
} else if (language === "python") { |
|
|
snippet = "import os\n" + snippet; |
|
|
snippet = snippet.replace( |
|
|
`"${placeholder}"`, |
|
|
`os.environ["${accessTokenEnvVar}"]` |
|
|
); |
|
|
snippet = snippet.replace( |
|
|
`"Bearer ${placeholder}"`, |
|
|
`f"Bearer {os.environ['${accessTokenEnvVar}']}"` |
|
|
); |
|
|
snippet = snippet.replace( |
|
|
`"Key ${placeholder}"`, |
|
|
`f"Key {os.environ['${accessTokenEnvVar}']}"` |
|
|
); |
|
|
snippet = snippet.replace( |
|
|
`"X-Key ${placeholder}"`, |
|
|
`f"X-Key {os.environ['${accessTokenEnvVar}']}"` |
|
|
); |
|
|
} else if (language === "js") { |
|
|
snippet = snippet.replace( |
|
|
`"${placeholder}"`, |
|
|
`process.env.${accessTokenEnvVar}` |
|
|
); |
|
|
snippet = snippet.replace( |
|
|
`Authorization: "Bearer ${placeholder}",`, |
|
|
`Authorization: \`Bearer $\{process.env.${accessTokenEnvVar}}\`,` |
|
|
); |
|
|
snippet = snippet.replace( |
|
|
`Authorization: "Key ${placeholder}",`, |
|
|
`Authorization: \`Key $\{process.env.${accessTokenEnvVar}}\`,` |
|
|
); |
|
|
snippet = snippet.replace( |
|
|
`Authorization: "X-Key ${placeholder}",`, |
|
|
`Authorization: \`X-Key $\{process.env.${accessTokenEnvVar}}\`,` |
|
|
); |
|
|
} |
|
|
return snippet; |
|
|
} |
|
|
|