| | import { z } from "zod"; |
| | import type { Endpoint, TextGenerationStreamOutputWithToolsAndWebSources } from "../endpoints"; |
| | import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images"; |
| | import { INFERENCE_PROVIDERS, InferenceClient } from "@huggingface/inference"; |
| | import { config } from "$lib/server/config"; |
| | import type { Tool, ToolCall } from "$lib/types/Tool"; |
| | import type { ChatCompletionStreamOutput } from "@huggingface/tasks"; |
| | import type { FunctionDefinition } from "openai/resources/index.mjs"; |
| | import type { ChatCompletionTool, FunctionParameters } from "openai/resources/index.mjs"; |
| | import { logger } from "$lib/server/logger"; |
| | import type { MessageFile } from "$lib/types/Message"; |
| | import { v4 as uuidv4 } from "uuid"; |
| | import type { Conversation } from "$lib/types/Conversation"; |
| | import { downloadFile } from "$lib/server/files/downloadFile"; |
| | import { jsonrepair } from "jsonrepair"; |
// One streamed tool-call fragment as emitted per-chunk by the chat completion
// stream; `function.arguments` arrives incrementally and is accumulated below.
type DeltaToolCall = NonNullable<
	ChatCompletionStreamOutput["choices"][number]["delta"]["tool_calls"]
>[number];
| |
|
| | function createChatCompletionToolsArray(tools: Tool[] | undefined): ChatCompletionTool[] { |
| | const toolChoices = [] as ChatCompletionTool[]; |
| | if (tools === undefined) { |
| | return toolChoices; |
| | } |
| |
|
| | for (const t of tools) { |
| | const requiredProperties = [] as string[]; |
| |
|
| | const properties = {} as Record<string, unknown>; |
| | for (const idx in t.inputs) { |
| | const parameterDefinition = t.inputs[idx]; |
| |
|
| | const parameter = {} as Record<string, unknown>; |
| | switch (parameterDefinition.type) { |
| | case "str": |
| | parameter.type = "string"; |
| | break; |
| | case "float": |
| | case "int": |
| | parameter.type = "number"; |
| | break; |
| | case "bool": |
| | parameter.type = "boolean"; |
| | break; |
| | case "file": |
| | throw new Error("File type's currently not supported"); |
| | default: |
| | throw new Error(`Unknown tool IO type: ${t}`); |
| | } |
| |
|
| | if ("description" in parameterDefinition) { |
| | parameter.description = parameterDefinition.description; |
| | } |
| |
|
| | if (parameterDefinition.paramType == "required") { |
| | requiredProperties.push(t.inputs[idx].name); |
| | } |
| |
|
| | properties[t.inputs[idx].name] = parameter; |
| | } |
| |
|
| | const functionParameters: FunctionParameters = { |
| | type: "object", |
| | ...(requiredProperties.length > 0 ? { required: requiredProperties } : {}), |
| | properties, |
| | }; |
| |
|
| | const functionDefinition: FunctionDefinition = { |
| | name: t.name, |
| | description: t.description, |
| | parameters: functionParameters, |
| | }; |
| |
|
| | const toolDefinition: ChatCompletionTool = { |
| | type: "function", |
| | function: functionDefinition, |
| | }; |
| |
|
| | toolChoices.push(toolDefinition); |
| | } |
| |
|
| | return toolChoices; |
| | } |
| |
|
/**
 * Configuration schema for the "inference-client" endpoint type.
 * NOTE: `provider` and `baseURL` are mutually exclusive, but that is enforced
 * at runtime in endpointInferenceClient, not by this schema.
 */
export const endpointInferenceClientParametersSchema = z.object({
	type: z.literal("inference-client"),
	// Relative weight when several endpoints serve the same model.
	weight: z.number().int().positive().default(1),
	// Full model config object; intentionally unvalidated here (z.any()).
	model: z.any(),
	// Named Hugging Face inference provider — omit when using a custom baseURL.
	provider: z.enum(INFERENCE_PROVIDERS).optional(),
	// Overrides model.id / model.name as the model identifier sent upstream.
	modelName: z.string().optional(),
	// Custom OpenAI-compatible endpoint URL — omit when using a provider.
	baseURL: z.string().optional(),
	// Image-attachment processing options; defaults apply when omitted.
	multimodal: z
		.object({
			image: createImageProcessorOptionsValidator({
				supportedMimeTypes: [
					"image/png",
					"image/jpeg",
					"image/webp",
					"image/avif",
					"image/tiff",
					"image/gif",
				],
				preferredMimeType: "image/webp",
				maxSizeInMB: Infinity,
				maxWidth: 4096,
				maxHeight: 4096,
			}),
		})
		.default({}),
});
| |
|
| | export async function endpointInferenceClient( |
| | input: z.input<typeof endpointInferenceClientParametersSchema> |
| | ): Promise<Endpoint> { |
| | const { model, provider, modelName, baseURL, multimodal } = |
| | endpointInferenceClientParametersSchema.parse(input); |
| |
|
| | if (!!provider && !!baseURL) { |
| | throw new Error("provider and baseURL cannot both be provided"); |
| | } |
| | const client = baseURL |
| | ? new InferenceClient(config.HF_TOKEN, { endpointUrl: baseURL }) |
| | : new InferenceClient(config.HF_TOKEN); |
| |
|
| | const imageProcessor = multimodal.image ? makeImageProcessor(multimodal.image) : undefined; |
| |
|
| | async function prepareFiles(files: MessageFile[], conversationId?: Conversation["_id"]) { |
| | if (!imageProcessor) { |
| | return []; |
| | } |
| | const processedFiles = await Promise.all( |
| | files |
| | .filter((file) => file.mime.startsWith("image/")) |
| | .map(async (file) => { |
| | if (file.type === "hash" && conversationId) { |
| | file = await downloadFile(file.value, conversationId); |
| | } |
| | return imageProcessor(file); |
| | }) |
| | ); |
| | return processedFiles.map((file) => ({ |
| | type: "image_url" as const, |
| | image_url: { |
| | url: `data:${file.mime};base64,${file.image.toString("base64")}`, |
| | }, |
| | })); |
| | } |
| |
|
| | return async ({ messages, generateSettings, tools, toolResults, preprompt, conversationId }) => { |
| | |
| | let messagesArray = (await Promise.all( |
| | messages.map(async (message) => { |
| | return { |
| | role: message.from, |
| | content: [ |
| | ...(await prepareFiles(message.files ?? [], conversationId)), |
| | { type: "text" as const, text: message.content }, |
| | ], |
| | }; |
| | }) |
| | )) as any[]; |
| |
|
| | if ( |
| | !model.systemRoleSupported && |
| | messagesArray.length > 0 && |
| | messagesArray[0]?.role === "system" |
| | ) { |
| | messagesArray[0].role = "user"; |
| | } else if (messagesArray[0].role !== "system") { |
| | messagesArray.unshift({ |
| | role: "system", |
| | content: preprompt ?? "", |
| | }); |
| | } |
| |
|
| | if (toolResults && toolResults.length > 0) { |
| | messagesArray = [ |
| | ...messagesArray, |
| | { |
| | role: "assistant", |
| | content: [ |
| | { |
| | type: "text" as const, |
| | text: "", |
| | }, |
| | ], |
| | tool_calls: toolResults.map((toolResult) => ({ |
| | type: "function", |
| | function: { |
| | name: toolResult.call.name, |
| | arguments: JSON.stringify(toolResult.call.parameters), |
| | }, |
| | id: toolResult?.call?.toolId || uuidv4(), |
| | })), |
| | }, |
| | ...toolResults.map((toolResult) => ({ |
| | role: model.systemRoleSupported ? "tool" : "user", |
| | content: [ |
| | { |
| | type: "text" as const, |
| | text: JSON.stringify(toolResult), |
| | }, |
| | ], |
| | tool_call_id: toolResult?.call?.toolId || uuidv4(), |
| | })), |
| | ]; |
| | } |
| |
|
| | messagesArray = messagesArray.reduce((acc: typeof messagesArray, current) => { |
| | if (acc.length === 0 || current.role !== acc[acc.length - 1].role) { |
| | acc.push(current); |
| | } else { |
| | const prevMessage = acc[acc.length - 1]; |
| |
|
| | prevMessage.content = [ |
| | ...prevMessage.content.filter((item: any) => item.type !== "text"), |
| | ...current.content.filter((item: any) => item.type !== "text"), |
| | { |
| | type: "text" as const, |
| | text: [ |
| | ...prevMessage.content.filter((item: any) => item.type === "text"), |
| | ...current.content.filter((item: any) => item.type === "text"), |
| | ] |
| | .map((item: any) => item.text) |
| | .join("\n") |
| | .replace(/^\n/, ""), |
| | }, |
| | ]; |
| |
|
| | prevMessage.files = [...(prevMessage?.files ?? []), ...(current?.files ?? [])]; |
| |
|
| | prevMessage.tool_calls = [ |
| | ...(prevMessage?.tool_calls ?? []), |
| | ...(current?.tool_calls ?? []), |
| | ]; |
| | } |
| | return acc; |
| | }, []); |
| | const toolCallChoices = createChatCompletionToolsArray(tools); |
| | const stream = client.chatCompletionStream( |
| | { |
| | ...model.parameters, |
| | ...generateSettings, |
| | model: modelName ?? model.id ?? model.name, |
| | provider: baseURL ? undefined : provider || ("hf-inference" as const), |
| | messages: messagesArray, |
| | ...(toolCallChoices.length > 0 ? { tools: toolCallChoices, tool_choice: "auto" } : {}), |
| | toolResults, |
| | }, |
| | { |
| | fetch: async (url, options) => { |
| | return fetch(url, { |
| | ...options, |
| | headers: { |
| | ...options?.headers, |
| | "X-Use-Cache": "false", |
| | "ChatUI-Conversation-ID": conversationId?.toString() ?? "", |
| | }, |
| | }); |
| | }, |
| | } |
| | ); |
| |
|
| | let tokenId = 0; |
| | let generated_text = ""; |
| | const finalToolCalls: DeltaToolCall[] = []; |
| |
|
| | async function* convertStream(): AsyncGenerator< |
| | TextGenerationStreamOutputWithToolsAndWebSources, |
| | void, |
| | void |
| | > { |
| | for await (const chunk of stream) { |
| | const token = chunk.choices?.[0]?.delta?.content || ""; |
| |
|
| | generated_text += token; |
| |
|
| | const toolCalls = chunk.choices?.[0]?.delta?.tool_calls ?? []; |
| |
|
| | for (const toolCall of toolCalls) { |
| | const index = toolCall.index ?? 0; |
| |
|
| | if (!finalToolCalls[index]) { |
| | finalToolCalls[index] = toolCall; |
| | } else { |
| | if (finalToolCalls[index].function.arguments === undefined) { |
| | finalToolCalls[index].function.arguments = ""; |
| | } |
| | if (toolCall.function.arguments) { |
| | finalToolCalls[index].function.arguments += toolCall.function.arguments; |
| | } |
| | } |
| | } |
| |
|
| | yield { |
| | token: { |
| | id: tokenId++, |
| | text: token, |
| | logprob: 0, |
| | special: false, |
| | }, |
| | details: null, |
| | generated_text: null, |
| | }; |
| | } |
| |
|
| | let mappedToolCalls: ToolCall[] | undefined; |
| | try { |
| | if (finalToolCalls.length === 0) { |
| | mappedToolCalls = undefined; |
| | } else { |
| | |
| | const toolCallsArray = Array.isArray(finalToolCalls) ? finalToolCalls : [finalToolCalls]; |
| |
|
| | mappedToolCalls = toolCallsArray.map((tc) => ({ |
| | id: tc.id, |
| | name: tc.function.name ?? "", |
| | parameters: JSON.parse(jsonrepair(tc.function.arguments || "{}")), |
| | })); |
| | } |
| | } catch (e) { |
| | logger.error(e, "error mapping tool calls"); |
| | } |
| |
|
| | yield { |
| | token: { |
| | id: tokenId++, |
| | text: "", |
| | logprob: 0, |
| | special: true, |
| | toolCalls: mappedToolCalls, |
| | }, |
| | generated_text, |
| | details: null, |
| | }; |
| | } |
| |
|
| | return convertStream(); |
| | }; |
| | } |
| |
|