import { randomUUID } from "crypto"; import { logger } from "../../logger"; import type { MessageUpdate } from "$lib/types/MessageUpdate"; import { MessageToolUpdateType, MessageUpdateType } from "$lib/types/MessageUpdate"; import { ToolResultStatus } from "$lib/types/Tool"; import type { ChatCompletionMessageParam } from "openai/resources/chat/completions"; import type { McpToolMapping } from "$lib/server/mcp/tools"; import type { McpServerConfig } from "$lib/server/mcp/httpClient"; import { callMcpTool, type McpToolTextResponse } from "$lib/server/mcp/httpClient"; import { getClient } from "$lib/server/mcp/clientPool"; import { attachFileRefsToArgs, type FileRefResolver } from "./fileRefs"; import type { Client } from "@modelcontextprotocol/sdk/client"; export type Primitive = string | number | boolean; export type ToolRun = { name: string; parameters: Record; output: string; }; export interface NormalizedToolCall { id: string; name: string; arguments: string; } export interface ExecuteToolCallsParams { calls: NormalizedToolCall[]; mapping: Record; servers: McpServerConfig[]; parseArgs: (raw: unknown) => Record; resolveFileRef?: FileRefResolver; toPrimitive: (value: unknown) => Primitive | undefined; processToolOutput: (text: string) => { annotated: string; sources: { index: number; link: string }[]; }; abortSignal?: AbortSignal; toolTimeoutMs?: number; } export interface ToolCallExecutionResult { toolMessages: ChatCompletionMessageParam[]; toolRuns: ToolRun[]; finalAnswer?: { text: string; interrupted: boolean }; } export type ToolExecutionEvent = | { type: "update"; update: MessageUpdate } | { type: "complete"; summary: ToolCallExecutionResult }; const serverMap = (servers: McpServerConfig[]): Map => { const map = new Map(); for (const server of servers) { if (server?.name) { map.set(server.name, server); } } return map; }; export async function* executeToolCalls({ calls, mapping, servers, parseArgs, resolveFileRef, toPrimitive, processToolOutput, abortSignal, toolTimeoutMs = 30_000, }: ExecuteToolCallsParams): AsyncGenerator { const toolMessages: ChatCompletionMessageParam[] = []; const toolRuns: ToolRun[] = []; const serverLookup = serverMap(servers); // Pre-emit call + ETA updates and prepare tasks type TaskResult = { index: number; output?: string; structured?: unknown; blocks?: unknown[]; error?: string; uuid: string; paramsClean: Record; }; const prepared = calls.map((call) => { const argsObj = parseArgs(call.arguments); const paramsClean: Record = {}; for (const [k, v] of Object.entries(argsObj ?? {})) { const prim = toPrimitive(v); if (prim !== undefined) paramsClean[k] = prim; } // Attach any resolved image payloads _after_ computing paramsClean so that // logging / status updates continue to show only the lightweight primitive // arguments (e.g. "image_1") while the full data: URLs or image blobs are // only sent to the MCP tool server. attachFileRefsToArgs(argsObj, resolveFileRef); return { call, argsObj, paramsClean, uuid: randomUUID() }; }); for (const p of prepared) { yield { type: "update", update: { type: MessageUpdateType.Tool, subtype: MessageToolUpdateType.Call, uuid: p.uuid, call: { name: p.call.name, parameters: p.paramsClean }, }, }; yield { type: "update", update: { type: MessageUpdateType.Tool, subtype: MessageToolUpdateType.ETA, uuid: p.uuid, eta: 10, }, }; } // Preload clients per distinct server used in this batch const distinctServerNames = Array.from( new Set(prepared.map((p) => mapping[p.call.name]?.server).filter(Boolean) as string[]) ); const clientMap = new Map(); await Promise.all( distinctServerNames.map(async (name) => { const cfg = serverLookup.get(name); if (!cfg) return; try { const client = await getClient(cfg, abortSignal); clientMap.set(name, client); } catch (e) { logger.warn({ server: name, err: String(e) }, "[mcp] failed to connect client"); } }) ); // Async queue to stream results in finish order function createQueue() { const items: T[] = []; const waiters: Array<(v: IteratorResult) => void> = []; let closed = false; return { push(item: T) { const waiter = waiters.shift(); if (waiter) waiter({ value: item, done: false }); else items.push(item); }, close() { closed = true; let waiter: ((v: IteratorResult) => void) | undefined; while ((waiter = waiters.shift())) { waiter({ value: undefined as unknown as T, done: true }); } }, async *iterator() { for (;;) { if (items.length) { const first = items.shift(); if (first !== undefined) yield first as T; continue; } if (closed) return; const value: IteratorResult = await new Promise((res) => waiters.push(res)); if (value.done) return; yield value.value as T; } }, }; } const q = createQueue(); const tasks = prepared.map(async (p, index) => { const mappingEntry = mapping[p.call.name]; if (!mappingEntry) { q.push({ index, error: `Unknown MCP function: ${p.call.name}`, uuid: p.uuid, paramsClean: p.paramsClean, }); return; } const serverCfg = serverLookup.get(mappingEntry.server); if (!serverCfg) { q.push({ index, error: `Unknown MCP server: ${mappingEntry.server}`, uuid: p.uuid, paramsClean: p.paramsClean, }); return; } const client = clientMap.get(mappingEntry.server); try { logger.debug( { server: mappingEntry.server, tool: mappingEntry.tool, parameters: p.paramsClean }, "[mcp] invoking tool" ); const toolResponse: McpToolTextResponse = await callMcpTool( serverCfg, mappingEntry.tool, p.argsObj, { client, signal: abortSignal, timeoutMs: toolTimeoutMs, } ); const { annotated } = processToolOutput(toolResponse.text ?? ""); logger.debug( { server: mappingEntry.server, tool: mappingEntry.tool }, "[mcp] tool call completed" ); q.push({ index, output: annotated, structured: toolResponse.structured, blocks: toolResponse.content, uuid: p.uuid, paramsClean: p.paramsClean, }); } catch (err) { const message = err instanceof Error ? err.message : String(err); logger.warn( { server: mappingEntry.server, tool: mappingEntry.tool, err: message }, "[mcp] tool call failed" ); q.push({ index, error: message, uuid: p.uuid, paramsClean: p.paramsClean }); } }); // kick off and stream as they finish Promise.allSettled(tasks).then(() => q.close()); const results: TaskResult[] = []; for await (const r of q.iterator()) { results.push(r); if (r.error) { yield { type: "update", update: { type: MessageUpdateType.Tool, subtype: MessageToolUpdateType.Error, uuid: r.uuid, message: r.error, }, }; } else { yield { type: "update", update: { type: MessageUpdateType.Tool, subtype: MessageToolUpdateType.Result, uuid: r.uuid, result: { status: ToolResultStatus.Success, call: { name: prepared[r.index].call.name, parameters: r.paramsClean }, outputs: [ { text: r.output ?? "", structured: r.structured, content: r.blocks, } as unknown as Record, ], display: true, }, }, }; } } // Collate outputs in original call order results.sort((a, b) => a.index - b.index); for (const r of results) { const name = prepared[r.index].call.name; const id = prepared[r.index].call.id; if (!r.error) { const output = r.output ?? ""; toolRuns.push({ name, parameters: r.paramsClean, output }); // For the LLM follow-up call, we keep only the textual output toolMessages.push({ role: "tool", tool_call_id: id, content: output }); } } yield { type: "complete", summary: { toolMessages, toolRuns } }; }