| | import type { Tool } from "$lib/types/Tool"; |
| | import { extractJson } from "./utils"; |
| | import { externalToToolCall } from "../textGeneration/tools"; |
| | import { logger } from "../logger"; |
| | import type { Endpoint, EndpointMessage } from "../endpoints/endpoints"; |
| |
|
/** Options accepted by {@link getToolOutput}. */
interface GetToolOutputOptions {
  /** Conversation messages forwarded unchanged to the endpoint. */
  messages: EndpointMessage[];
  /** The tool to call; it is passed to the endpoint as the only entry of `tools`. */
  tool: Tool;
  /** Optional preprompt; an "Only use tool <name>." instruction is appended to it. */
  preprompt?: string;
  /** Endpoint invoked to produce the stream of generation outputs. */
  endpoint: Endpoint;
  /** Generation settings forwarded to the endpoint; the caller defaults to `{ max_new_tokens: 64 }`. */
  generateSettings?: {
    max_new_tokens?: number;
    [key: string]: unknown;
  };
}
| |
|
/**
 * Asks the endpoint to call a single tool and returns the first parameter value of the
 * resulting call when that value is a string.
 *
 * Tool calls are collected from two places on each streamed output: the structured
 * `token.toolCalls` field, and JSON extracted from `generated_text` that is converted
 * with `externalToToolCall`. Streaming stops as soon as at least one call was collected.
 *
 * @param messages - Conversation messages forwarded to the endpoint.
 * @param preprompt - Optional preprompt; "Only use tool <name>." is appended to it.
 * @param tool - The only tool offered to the endpoint, and the one whose call is read.
 * @param endpoint - Endpoint producing the stream of generation outputs.
 * @param generateSettings - Generation settings; defaults to `{ max_new_tokens: 64 }`.
 * @returns The first parameter value of the call named `tool.name` if it is a string;
 *   `undefined` when no such call was produced, the value is not a string, or an error
 *   occurred (errors are logged as warnings, never thrown).
 */
export async function getToolOutput<T = string>({
  messages,
  preprompt,
  tool,
  endpoint,
  generateSettings = { max_new_tokens: 64 },
}: GetToolOutputOptions): Promise<T | undefined> {
  try {
    const stream = await endpoint({
      messages,
      // `preprompt` is optional: fall back to "" so that the literal text "undefined"
      // is never concatenated into the prompt.
      preprompt: `${preprompt ?? ""}\n\n Only use tool ${tool.name}.`,
      tools: [tool],
      generateSettings,
    });

    const calls = [];

    for await (const output of stream) {
      if (output.token.toolCalls) {
        calls.push(...output.token.toolCalls);
      }
      if (output.generated_text) {
        const parsed = await extractJson(output.generated_text);
        const extractedCalls = parsed
          .map((call) => externalToToolCall(call, [tool]))
          .filter((call) => call !== undefined);
        calls.push(...extractedCalls);
      }

      // One collected call is enough; stop consuming the stream early.
      if (calls.length > 0) {
        break;
      }
    }

    // `find` on an empty array yields undefined, so no separate length check is needed.
    const toolCall = calls.find((call) => call.name === tool.name);
    if (toolCall?.parameters) {
      // Only the first parameter is considered, and only when it is a string.
      const firstParamValue = Object.values(toolCall.parameters)[0];
      if (typeof firstParamValue === "string") {
        return firstParamValue as T;
      }
    }

    return undefined;
  } catch (error) {
    // Best-effort helper: failures are logged and reported to the caller as "no output".
    logger.warn(error, "Error getting tool output");
    return undefined;
  }
}
| |
|