| | import { ToolResultStatus, type ToolCall, type Tool, type ToolResult } from "$lib/types/Tool"; |
| | import { v4 as uuidV4 } from "uuid"; |
| | import { getCallMethod, toolFromConfigs, type BackendToolContext } from "../tools"; |
| | import { |
| | MessageToolUpdateType, |
| | MessageUpdateType, |
| | type MessageUpdate, |
| | } from "$lib/types/MessageUpdate"; |
| | import type { TextGenerationContext } from "./types"; |
| |
|
| | import directlyAnswer from "../tools/directlyAnswer"; |
| | import websearch from "../tools/web/search"; |
| | import { z } from "zod"; |
| | import { logger } from "../logger"; |
| | import { extractJson, toolHasName } from "../tools/utils"; |
| | import { mergeAsyncGenerators } from "$lib/utils/mergeAsyncGenerators"; |
| | import { MetricsServer } from "../metrics"; |
| | import { stringifyError } from "$lib/utils/stringifyError"; |
| | import { collections } from "../database"; |
| | import { ObjectId } from "mongodb"; |
| | import type { Message } from "$lib/types/Message"; |
| | import type { Assistant } from "$lib/types/Assistant"; |
| | import { assistantHasWebSearch } from "./assistant"; |
| |
|
/**
 * Resolves the tools available for a generation request.
 *
 * When an assistant is given, its own tool list replaces the user's preferences.
 * An assistant without any tools only gets `directlyAnswer`, plus `websearch`
 * when the assistant has web search enabled.
 *
 * @param toolsPreference - ids (as strings) of the tools the user has enabled.
 * @param assistant - the assistant driving the conversation, if any.
 * @returns the config tools that match the preferences, followed by the
 *   community tools loaded from the database whose ids appear in the preferences.
 */
export async function getTools(
	toolsPreference: Array<string>,
	assistant: Pick<Assistant, "rag" | "tools"> | undefined
): Promise<Tool[]> {
	let preferences = toolsPreference;

	if (assistant) {
		if (assistant.tools?.length) {
			// Copy before appending. Pushing onto `assistant.tools` directly would
			// mutate the caller's assistant object and accumulate duplicate
			// websearch ids across calls.
			preferences = [...assistant.tools];

			if (assistantHasWebSearch(assistant)) {
				const websearchId = websearch._id.toString();
				if (!preferences.includes(websearchId)) {
					preferences.push(websearchId);
				}
			}
		} else {
			if (assistantHasWebSearch(assistant)) {
				return [websearch, directlyAnswer];
			}
			return [directlyAnswer];
		}
	}

	// Config tools: outside of assistants, locked on-by-default tools are always
	// active. Otherwise a tool is active only when its id is in the preferences.
	const activeConfigTools = toolFromConfigs.filter((el) => {
		if (el.isLocked && el.isOnByDefault && !assistant) return true;
		return preferences.includes(el._id.toString());
	});

	// Community tools: look the preferred ids up in the database. Ids that are
	// not valid ObjectIds are skipped rather than letting `new ObjectId` throw
	// and fail the whole request.
	const communityToolIds = preferences
		.filter((id) => ObjectId.isValid(id))
		.map((id) => new ObjectId(id));

	const communityToolDocs = await collections.tools
		.find({ _id: { $in: communityToolIds } })
		.toArray();
	const activeCommunityTools = communityToolDocs.map((doc) => ({
		...doc,
		call: getCallMethod(doc),
	}));

	return [...activeConfigTools, ...activeCommunityTools];
}
| |
|
/**
 * Executes a single tool call requested by the model and streams progress updates.
 *
 * A `Call` update is yielded before the tool runs, followed by either a `Result`
 * or an `Error` update.
 *
 * @param ctx - backend context forwarded to the tool implementation.
 * @param tools - the tools available for this request.
 * @param call - the tool call parsed from the model output.
 * @returns the final `ToolResult`. Returns `undefined` when the model picked
 *   `directlyAnswer`, because there is nothing to execute in that case.
 */
async function* callTool(
	ctx: BackendToolContext,
	tools: Tool[],
	call: ToolCall
): AsyncGenerator<MessageUpdate, ToolResult | undefined, undefined> {
	const uuid = uuidV4();

	const tool = tools.find((el) => toolHasName(call.name, el));
	if (!tool) {
		return { call, status: ToolResultStatus.Error, message: `Could not find tool "${call.name}"` };
	}

	// `directlyAnswer` is a sentinel meaning "no tool needed": nothing to run.
	if (toolHasName(directlyAnswer.name, tool)) return;

	const startTime = Date.now();
	MetricsServer.getMetrics().tool.toolUseCount.inc({ tool: call.name });

	yield {
		type: MessageUpdateType.Tool,
		subtype: MessageToolUpdateType.Call,
		uuid,
		call,
	};

	try {
		const toolResult = yield* tool.call(call.parameters, ctx, uuid);

		yield {
			type: MessageUpdateType.Tool,
			subtype: MessageToolUpdateType.Result,
			uuid,
			result: { ...toolResult, call, status: ToolResultStatus.Success },
		};

		MetricsServer.getMetrics().tool.toolUseDuration.observe(
			{ tool: call.name },
			Date.now() - startTime
		);

		// Usage-count bookkeeping is best-effort. A database failure here must not
		// turn an already-reported successful result into an error.
		try {
			await collections.tools.findOneAndUpdate({ _id: tool._id }, { $inc: { useCount: 1 } });
		} catch (dbError) {
			logger.warn(
				{ error: dbError },
				`Failed to increment use count for tool ${call.name}. ${stringifyError(dbError)}`
			);
		}

		return { ...toolResult, call, status: ToolResultStatus.Success };
	} catch (error) {
		MetricsServer.getMetrics().tool.toolUseCountError.inc({ tool: call.name });
		logger.error(error, `Failed while running tool ${call.name}. ${stringifyError(error)}`);

		const message =
			"An error occurred while calling the tool " + call.name + ": " + stringifyError(error);

		yield {
			type: MessageUpdateType.Tool,
			subtype: MessageToolUpdateType.Error,
			uuid,
			message,
		};

		return {
			call,
			status: ToolResultStatus.Error,
			message,
		};
	}
}
| |
|
/**
 * Builds the name under which a message attachment is exposed to tools:
 * `{input|output}_{messageIndex}_{fileIndex}.{extension}`.
 *
 * Both the file listing and the per-message annotations use this helper, so the
 * two can never disagree on a filename.
 */
function toolFileName(
	from: Message["from"],
	messageIndex: number,
	fileIndex: number,
	originalName: string | undefined
): string {
	const prefix = from === "user" ? "input" : "output";
	const extension = originalName?.split(".").pop()?.toLowerCase();
	return `${prefix}_${messageIndex}_${fileIndex}.${extension}`;
}

/**
 * Asks the model which tools to call for the current conversation, then runs
 * those calls concurrently and streams their updates.
 *
 * @param ctx - text generation context (endpoint, conversation, messages, etc.).
 * @param tools - the tools the model may choose from.
 * @param preprompt - optional system preprompt forwarded to the endpoint.
 * @returns the results of every executed tool call. Calls that resolve to
 *   `undefined` (for example `directlyAnswer`) are omitted.
 */
export async function* runTools(
	ctx: TextGenerationContext,
	tools: Tool[],
	preprompt?: string
): AsyncGenerator<MessageUpdate, ToolResult[], undefined> {
	const { endpoint, conv, messages, assistant, ip, username } = ctx;
	const calls: ToolCall[] = [];

	const pickToolStartTime = Date.now();

	// Every file attached anywhere in the conversation, under its tool-facing name.
	const files = messages.flatMap((message, msgIdx) =>
		(message.files ?? []).map((file, fileIdx) =>
			toolFileName(message.from, msgIdx, fileIdx, file?.name)
		)
	);

	// Append the tool-facing names of each message's attachments to that message.
	const annotatedMessages = messages.map((message, msgIdx) => {
		let content = message.content;

		if (message.files && message.files.length > 0) {
			content +=
				"\n\nAdded files: \n - " +
				message.files
					.map((file, fileIdx) => toolFileName(message.from, msgIdx, fileIdx, file?.name))
					.join("\n - ");
		}

		return {
			...message,
			content,
		} satisfies Message;
	});

	const fileMsg = {
		id: crypto.randomUUID(),
		from: "system",
		content:
			"Here is the list of available filenames that can be used as input for tools. Use the filenames that are in this list. \n The filename structure is as follows : {input for user|output for tool}_{message index in the conversation}_{file index in the list of files}.{file extension} \n - " +
			files.join("\n - ") +
			"\n\n\n",
	} satisfies Message;

	// When there are files, inject the listing right after the first message so the
	// model knows which filenames it may pass to tools.
	const formattedMessages: Message[] = files.length
		? [...annotatedMessages.slice(0, 1), fileMsg, ...annotatedMessages.slice(1)]
		: annotatedMessages;

	// File inputs are exposed to the model as plain strings (filenames).
	const mappedTools = tools.map((tool) => ({
		...tool,
		inputs: tool.inputs.map((input) => ({
			...input,
			type: input.type === "file" ? "str" : input.type,
		})),
	}));

	// Low temperature: tool selection should be as deterministic as possible.
	for await (const output of await endpoint({
		messages: formattedMessages,
		preprompt,
		generateSettings: { temperature: 0.1, ...assistant?.generateSettings },
		tools: mappedTools,
		conversationId: conv._id,
	})) {
		// Native tool calls are streamed by endpoints that support them.
		if (output.token.toolCalls) {
			calls.push(...output.token.toolCalls);
			continue;
		}

		// Otherwise try to extract JSON-formatted tool calls from the final text.
		if (output.generated_text) {
			try {
				const rawCalls = await extractJson(output.generated_text);
				const newCalls = rawCalls
					.map((call) => externalToToolCall(call, tools))
					.filter((call): call is ToolCall => call !== undefined);

				calls.push(...newCalls);
			} catch (e) {
				logger.warn({ rawCall: output.generated_text, error: e }, "Error while parsing tool calls");
			}
		}
	}

	MetricsServer.getMetrics().tool.timeToChooseTools.observe(
		{ model: conv.model },
		Date.now() - pickToolStartTime
	);

	const toolContext: BackendToolContext = { conv, messages, preprompt, assistant, ip, username };
	const toolResults: (ToolResult | undefined)[] = yield* mergeAsyncGenerators(
		calls.map((call) => callTool(toolContext, tools, call))
	);
	return toolResults.filter((result): result is ToolResult => result !== undefined);
}
| |
|
/**
 * Converts a tool call in an arbitrary external/model format into a `ToolCall`
 * that targets one of the available tools.
 *
 * Missing optional parameters are filled in from the tool's declared defaults.
 *
 * @param call - the raw call object, of unknown shape.
 * @param tools - the tools the call may refer to.
 * @returns the normalized call. Returns `undefined` (with a debug log for the
 *   last two cases) when the input is not a recognizable call, names a tool that
 *   is not available, or omits a required parameter.
 */
export function externalToToolCall(call: unknown, tools: Tool[]): ToolCall | undefined {
	// Only plain objects can describe a call.
	if (!isValidCallObject(call)) {
		return undefined;
	}

	const parsedCall = parseExternalCall(call);
	if (!parsedCall) return undefined;

	const tool = tools.find((candidate) => toolHasName(parsedCall.tool_name, candidate));
	if (!tool) {
		logger.debug(
			`Model requested tool that does not exist: "${parsedCall.tool_name}". Skipping tool...`
		);
		return undefined;
	}

	const parametersWithDefaults: Record<string, string> = {};

	for (const input of tool.inputs) {
		const value = parsedCall.parameters[input.name];

		// Required parameters must be present. `null` counts as missing, matching
		// the `??=` default handling that optional parameters get below.
		if (input.paramType === "required") {
			if (value == null) {
				logger.debug(
					`Model requested tool "${parsedCall.tool_name}" but was missing required parameter "${input.name}". Skipping tool...`
				);
				return undefined;
			}
			parametersWithDefaults[input.name] = value;
			continue;
		}

		// Non-required parameters are passed through as given.
		parametersWithDefaults[input.name] = value;

		// Optional parameters fall back to their declared default.
		if (input.paramType === "optional") {
			parametersWithDefaults[input.name] ??= input.default.toString();
		}
	}

	return {
		name: parsedCall.tool_name,
		parameters: parametersWithDefaults,
	};
}
| |
|
| | |
/**
 * Type guard for a non-null, non-array object whose properties can be read by key.
 *
 * Arrays are rejected because they can never describe a single tool call, and
 * accepting them would make the `Record<string, unknown>` predicate misleading.
 */
function isValidCallObject(call: unknown): call is Record<string, unknown> {
	return typeof call === "object" && call !== null && !Array.isArray(call);
}
| |
|
/**
 * Normalizes the different shapes models use for tool calls into
 * `{ tool_name, parameters }`.
 *
 * Accepted shapes:
 * - `{ function: { _name, ...params } }`: `_name` becomes the tool name and the
 *   remaining keys become the parameters.
 * - an object with the name under `tool_name` or `name`, and the arguments under
 *   `parameters`, `arguments` or `parameter_definitions`. When several keys are
 *   present, the later one in each list wins.
 *
 * @returns the normalized call. Returns `undefined` when the result does not have
 *   a string name and an object of parameters. A malformed call is skipped rather
 *   than thrown, so one bad call cannot discard the rest of a batch.
 */
function parseExternalCall(callObj: Record<string, unknown>) {
	let toolCall = callObj;
	if (
		"function" in callObj &&
		isValidCallObject(callObj.function) &&
		"_name" in callObj.function
	) {
		const { _name, ...functionParameters } = callObj.function;
		toolCall = {
			tool_name: _name,
			parameters: functionParameters,
		};
	}

	const nameFields = ["tool_name", "name"] as const;
	const parametersFields = ["parameters", "arguments", "parameter_definitions"] as const;

	const groupedCall = {
		tool_name: "" as string,
		parameters: undefined as Record<string, string> | undefined,
	};

	// The types are only asserted here; the schema below validates them.
	for (const name of nameFields) {
		if (toolCall[name]) {
			groupedCall.tool_name = toolCall[name] as string;
		}
	}

	for (const name of parametersFields) {
		if (toolCall[name]) {
			groupedCall.parameters = toolCall[name] as Record<string, string>;
		}
	}

	const parsed = z
		.object({
			tool_name: z.string(),
			parameters: z.record(z.any()),
		})
		.safeParse(groupedCall);

	if (!parsed.success) {
		logger.debug(
			{ call: callObj, error: parsed.error },
			"Could not parse external tool call. Skipping tool..."
		);
		return undefined;
	}

	return parsed.data;
}
| |
|