Spaces:
Running
Running
| import { z } from "zod"; | |
| import type { Endpoint, TextGenerationStreamOutputWithToolsAndWebSources } from "../endpoints"; | |
| import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images"; | |
| import { INFERENCE_PROVIDERS, InferenceClient } from "@huggingface/inference"; | |
| import { config } from "$lib/server/config"; | |
| import type { Tool, ToolCall } from "$lib/types/Tool"; | |
| import type { ChatCompletionStreamOutput } from "@huggingface/tasks"; | |
| import type { FunctionDefinition } from "openai/resources/index.mjs"; | |
| import type { ChatCompletionTool, FunctionParameters } from "openai/resources/index.mjs"; | |
| import { logger } from "$lib/server/logger"; | |
| import type { MessageFile } from "$lib/types/Message"; | |
| import { v4 as uuidv4 } from "uuid"; | |
| import type { Conversation } from "$lib/types/Conversation"; | |
| import { downloadFile } from "$lib/server/files/downloadFile"; | |
| import { jsonrepair } from "jsonrepair"; | |
// A single incremental tool-call fragment from a streamed chat-completion
// chunk; fragments sharing the same `index` are later accumulated into one
// complete call (see convertStream below).
type DeltaToolCall = NonNullable<
	ChatCompletionStreamOutput["choices"][number]["delta"]["tool_calls"]
>[number];
| function createChatCompletionToolsArray(tools: Tool[] | undefined): ChatCompletionTool[] { | |
| const toolChoices = [] as ChatCompletionTool[]; | |
| if (tools === undefined) { | |
| return toolChoices; | |
| } | |
| for (const t of tools) { | |
| const requiredProperties = [] as string[]; | |
| const properties = {} as Record<string, unknown>; | |
| for (const idx in t.inputs) { | |
| const parameterDefinition = t.inputs[idx]; | |
| const parameter = {} as Record<string, unknown>; | |
| switch (parameterDefinition.type) { | |
| case "str": | |
| parameter.type = "string"; | |
| break; | |
| case "float": | |
| case "int": | |
| parameter.type = "number"; | |
| break; | |
| case "bool": | |
| parameter.type = "boolean"; | |
| break; | |
| case "file": | |
| throw new Error("File type's currently not supported"); | |
| default: | |
| throw new Error(`Unknown tool IO type: ${t}`); | |
| } | |
| if ("description" in parameterDefinition) { | |
| parameter.description = parameterDefinition.description; | |
| } | |
| if (parameterDefinition.paramType == "required") { | |
| requiredProperties.push(t.inputs[idx].name); | |
| } | |
| properties[t.inputs[idx].name] = parameter; | |
| } | |
| const functionParameters: FunctionParameters = { | |
| type: "object", | |
| ...(requiredProperties.length > 0 ? { required: requiredProperties } : {}), | |
| properties, | |
| }; | |
| const functionDefinition: FunctionDefinition = { | |
| name: t.name, | |
| description: t.description, | |
| parameters: functionParameters, | |
| }; | |
| const toolDefinition: ChatCompletionTool = { | |
| type: "function", | |
| function: functionDefinition, | |
| }; | |
| toolChoices.push(toolDefinition); | |
| } | |
| return toolChoices; | |
| } | |
// Configuration schema for the "inference-client" endpoint type.
export const endpointInferenceClientParametersSchema = z.object({
	type: z.literal("inference-client"),
	// Positive integer weight — presumably used for weighted endpoint
	// selection elsewhere in the app; TODO confirm against the caller.
	weight: z.number().int().positive().default(1),
	// The chat-ui model config object; intentionally unvalidated here
	// (fields like id/name/parameters/systemRoleSupported are read below).
	model: z.any(),
	// Named Hugging Face inference provider; mutually exclusive with baseURL
	// (enforced at parse time in endpointInferenceClient).
	provider: z.enum(INFERENCE_PROVIDERS).optional(),
	// Overrides model.id / model.name as the model identifier sent upstream.
	modelName: z.string().optional(),
	// Direct endpoint URL; mutually exclusive with provider.
	baseURL: z.string().optional(),
	// Image-preprocessing options for multimodal input: accepted input
	// formats, preferred output format (webp), and 4096x4096 size cap with
	// no file-size limit.
	multimodal: z
		.object({
			image: createImageProcessorOptionsValidator({
				supportedMimeTypes: [
					"image/png",
					"image/jpeg",
					"image/webp",
					"image/avif",
					"image/tiff",
					"image/gif",
				],
				preferredMimeType: "image/webp",
				maxSizeInMB: Infinity,
				maxWidth: 4096,
				maxHeight: 4096,
			}),
		})
		.default({}),
});
/**
 * Build an Endpoint backed by the Hugging Face InferenceClient.
 *
 * Streams chat completions either through a named inference provider or a
 * fixed `baseURL` (mutually exclusive), preprocessing image attachments into
 * data-URL content parts and translating tool calls to/from the
 * chat-completion wire format.
 */
export async function endpointInferenceClient(
	input: z.input<typeof endpointInferenceClientParametersSchema>
): Promise<Endpoint> {
	const { model, provider, modelName, baseURL, multimodal } =
		endpointInferenceClientParametersSchema.parse(input);
	if (!!provider && !!baseURL) {
		throw new Error("provider and baseURL cannot both be provided");
	}
	// With a baseURL the client targets that endpoint directly; otherwise the
	// request is routed via the provider selected in the payload below.
	const client = baseURL
		? new InferenceClient(config.HF_TOKEN, { endpointUrl: baseURL })
		: new InferenceClient(config.HF_TOKEN);

	const imageProcessor = multimodal.image ? makeImageProcessor(multimodal.image) : undefined;

	// Convert message attachments into `image_url` content parts (base64 data
	// URLs). Non-image files are silently dropped; "hash" files are first
	// fetched from storage via downloadFile.
	async function prepareFiles(files: MessageFile[], conversationId?: Conversation["_id"]) {
		if (!imageProcessor) {
			return [];
		}
		const processedFiles = await Promise.all(
			files
				.filter((file) => file.mime.startsWith("image/"))
				.map(async (file) => {
					if (file.type === "hash" && conversationId) {
						file = await downloadFile(file.value, conversationId);
					}
					return imageProcessor(file);
				})
		);
		return processedFiles.map((file) => ({
			type: "image_url" as const,
			image_url: {
				url: `data:${file.mime};base64,${file.image.toString("base64")}`,
			},
		}));
	}

	return async ({ messages, generateSettings, tools, toolResults, preprompt, conversationId }) => {
		/* eslint-disable @typescript-eslint/no-explicit-any */
		// Map chat-ui messages to chat-completion format: image parts first,
		// then a single text part per message.
		let messagesArray = (await Promise.all(
			messages.map(async (message) => {
				return {
					role: message.from,
					content: [
						...(await prepareFiles(message.files ?? [], conversationId)),
						{ type: "text" as const, text: message.content },
					],
				};
			})
		)) as any[];

		// System-prompt handling: demote a leading system message to "user"
		// when the model can't take a system role; otherwise ensure a system
		// message carrying the preprompt sits at the front.
		// NOTE(review): when systemRoleSupported is false and the first message
		// is NOT system, neither branch runs and the preprompt is dropped —
		// confirm this is intended.
		// NOTE(review): the else-if reads messagesArray[0].role without a
		// length guard; an empty `messages` array would throw here.
		if (
			!model.systemRoleSupported &&
			messagesArray.length > 0 &&
			messagesArray[0]?.role === "system"
		) {
			messagesArray[0].role = "user";
		} else if (messagesArray[0].role !== "system") {
			messagesArray.unshift({
				role: "system",
				content: preprompt ?? "",
			});
		}

		// Replay earlier tool invocations: one assistant message carrying the
		// tool_calls, then one result message per call ("tool" role, or "user"
		// for models without system-role support).
		if (toolResults && toolResults.length > 0) {
			messagesArray = [
				...messagesArray,
				{
					role: "assistant",
					content: [
						{
							type: "text" as const,
							text: "",
						},
					],
					tool_calls: toolResults.map((toolResult) => ({
						type: "function",
						function: {
							name: toolResult.call.name,
							arguments: JSON.stringify(toolResult.call.parameters),
						},
						// Fall back to a random id when the call carries none.
						id: toolResult?.call?.toolId || uuidv4(),
					})),
				},
				...toolResults.map((toolResult) => ({
					role: model.systemRoleSupported ? "tool" : "user",
					content: [
						{
							type: "text" as const,
							text: JSON.stringify(toolResult),
						},
					],
					// NOTE(review): a missing toolId generates a FRESH uuid here,
					// which won't match the id generated for the assistant
					// message above — confirm providers tolerate the mismatch.
					tool_call_id: toolResult?.call?.toolId || uuidv4(),
				})),
			];
		}

		// Merge consecutive same-role messages into one: non-text parts are
		// concatenated, all text parts are joined with "\n" into a single
		// trailing text part, and files/tool_calls arrays are unioned.
		messagesArray = messagesArray.reduce((acc: typeof messagesArray, current) => {
			if (acc.length === 0 || current.role !== acc[acc.length - 1].role) {
				acc.push(current);
			} else {
				const prevMessage = acc[acc.length - 1];
				prevMessage.content = [
					...prevMessage.content.filter((item: any) => item.type !== "text"),
					...current.content.filter((item: any) => item.type !== "text"),
					{
						type: "text" as const,
						text: [
							...prevMessage.content.filter((item: any) => item.type === "text"),
							...current.content.filter((item: any) => item.type === "text"),
						]
							.map((item: any) => item.text)
							.join("\n")
							// Drop the leading newline left by an empty first text part.
							.replace(/^\n/, ""),
					},
				];
				prevMessage.files = [...(prevMessage?.files ?? []), ...(current?.files ?? [])];
				prevMessage.tool_calls = [
					...(prevMessage?.tool_calls ?? []),
					...(current?.tool_calls ?? []),
				];
			}
			return acc;
		}, []);

		const toolCallChoices = createChatCompletionToolsArray(tools);

		const stream = client.chatCompletionStream(
			{
				...model.parameters,
				...generateSettings,
				model: modelName ?? model.id ?? model.name,
				// provider and baseURL are mutually exclusive (validated above);
				// default to "hf-inference" when neither a provider nor a
				// baseURL is configured.
				provider: baseURL ? undefined : provider || ("hf-inference" as const),
				messages: messagesArray,
				...(toolCallChoices.length > 0 ? { tools: toolCallChoices, tool_choice: "auto" } : {}),
				// NOTE(review): toolResults is not a standard chat-completion
				// request field — presumably ignored upstream; verify.
				toolResults,
			},
			{
				// Tag outgoing requests with the conversation id and disable
				// provider-side caching.
				fetch: async (url, options) => {
					return fetch(url, {
						...options,
						headers: {
							...options?.headers,
							"X-Use-Cache": "false",
							"ChatUI-Conversation-ID": conversationId?.toString() ?? "",
						},
					});
				},
			}
		);

		let tokenId = 0;
		let generated_text = "";
		// Streamed tool-call fragments, accumulated by their `index`.
		const finalToolCalls: DeltaToolCall[] = [];

		// Adapt the chat-completion stream to the text-generation stream shape
		// used by the rest of chat-ui: one token event per content delta, then
		// a final `special` token carrying the full text and mapped tool calls.
		async function* convertStream(): AsyncGenerator<
			TextGenerationStreamOutputWithToolsAndWebSources,
			void,
			void
		> {
			for await (const chunk of stream) {
				const token = chunk.choices?.[0]?.delta?.content || "";
				generated_text += token;

				// Accumulate partial tool calls: the first fragment for an
				// index seeds the entry; later fragments append their
				// `arguments` text to it.
				const toolCalls = chunk.choices?.[0]?.delta?.tool_calls ?? [];
				for (const toolCall of toolCalls) {
					const index = toolCall.index ?? 0;
					if (!finalToolCalls[index]) {
						finalToolCalls[index] = toolCall;
					} else {
						if (finalToolCalls[index].function.arguments === undefined) {
							finalToolCalls[index].function.arguments = "";
						}
						if (toolCall.function.arguments) {
							finalToolCalls[index].function.arguments += toolCall.function.arguments;
						}
					}
				}

				yield {
					token: {
						id: tokenId++,
						text: token,
						logprob: 0,
						special: false,
					},
					details: null,
					generated_text: null,
				};
			}

			// Map accumulated fragments into ToolCall objects; jsonrepair
			// fixes truncated/invalid JSON argument strings before parsing.
			let mappedToolCalls: ToolCall[] | undefined;
			try {
				if (finalToolCalls.length === 0) {
					mappedToolCalls = undefined;
				} else {
					// Ensure finalToolCalls is an array
					const toolCallsArray = Array.isArray(finalToolCalls) ? finalToolCalls : [finalToolCalls];
					mappedToolCalls = toolCallsArray.map((tc) => ({
						id: tc.id,
						name: tc.function.name ?? "",
						parameters: JSON.parse(jsonrepair(tc.function.arguments || "{}")),
					}));
				}
			} catch (e) {
				// Best-effort: log and emit the final token without tool calls.
				logger.error(e, "error mapping tool calls");
			}

			yield {
				token: {
					id: tokenId++,
					text: "",
					logprob: 0,
					special: true,
					toolCalls: mappedToolCalls,
				},
				generated_text,
				details: null,
			};
		}

		return convertStream();
	};
}