import { config } from "$lib/server/config";
import { MessageUpdateType, type MessageUpdate } from "$lib/types/MessageUpdate";
import { getMcpServers } from "$lib/server/mcp/registry";
import { isValidUrl } from "$lib/server/urlSafety";
import { resetMcpToolsCache, getOpenAiToolsForMcp } from "$lib/server/mcp/tools";
import type {
	ChatCompletionChunk,
	ChatCompletionCreateParamsStreaming,
	ChatCompletionMessageParam,
	ChatCompletionMessageToolCall,
} from "openai/resources/chat/completions";
import type { Stream } from "openai/streaming";
import { buildToolPreprompt } from "../utils/toolPrompt";
import type { EndpointMessage } from "../../endpoints/endpoints";
import { resolveRouterTarget } from "./routerResolution";
import { executeToolCalls, type NormalizedToolCall } from "./toolInvocation";
import { drainPool } from "$lib/server/mcp/clientPool";
import type { TextGenerationContext } from "../types";
import { hasAuthHeader, isStrictHfMcpLogin, hasNonEmptyToken } from "$lib/server/mcp/hf";
import { buildImageRefResolver } from "./fileRefs";
import { prepareMessagesWithFiles } from "$lib/server/textGeneration/utils/prepareFiles";
import { makeImageProcessor } from "$lib/server/endpoints/images";

export type RunMcpFlowContext = Pick<
	TextGenerationContext,
	"model" | "conv" | "assistant" | "forceMultimodal" | "forceTools" | "locals"
> & { messages: EndpointMessage[] };

export async function* runMcpFlow({
	model,
	conv,
	messages,
	assistant,
	forceMultimodal,
	forceTools,
	locals,
	preprompt,
	abortSignal,
}: RunMcpFlowContext & { preprompt?: string; abortSignal?: AbortSignal }): AsyncGenerator<
	MessageUpdate,
	boolean,
	undefined
> {
	// Start from env-configured servers
	let servers = getMcpServers();
	try {
		console.debug(
			{ baseServers: servers.map((s) => ({ name: s.name, url: s.url })), count: servers.length },
			"[mcp] base servers loaded"
		);
	} catch {}

	// Merge in request-provided custom servers (if any)
	try {
		const reqMcp = (
			locals as unknown as {
				mcp?: {
					selectedServers?: Array<{ name: string; url: string; headers?: Record<string, string> }>;
					selectedServerNames?: string[];
				};
			}
		)?.mcp;
		const custom = Array.isArray(reqMcp?.selectedServers) ? reqMcp?.selectedServers : [];
		if (custom.length > 0) {
			// Invalidate cached tool list when the set of servers changes at request-time
			resetMcpToolsCache();
			// Deduplicate by server name (request takes precedence)
			const byName = new Map<
				string,
				{ name: string; url: string; headers?: Record<string, string> }
			>();
			for (const s of servers) byName.set(s.name, s);
			for (const s of custom) byName.set(s.name, s);
			servers = [...byName.values()];
			try {
				console.debug(
					{
						customProvidedCount: custom.length,
						mergedServers: servers.map((s) => ({
							name: s.name,
							url: s.url,
							hasAuth: !!s.headers?.Authorization,
						})),
					},
					"[mcp] merged request-provided servers"
				);
			} catch {}
		}
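		// Illustrative shape of a request-provided MCP selection, matching the cast
		// above (server names, URL, and token are hypothetical):
		//   locals.mcp = {
		//     selectedServers: [
		//       { name: "my-server", url: "https://example.com/mcp", headers: { Authorization: "Bearer …" } },
		//     ],
		//     selectedServerNames: ["my-server"],
		//   };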
		// If the client specified a selection by name, filter to those
		const names = Array.isArray(reqMcp?.selectedServerNames)
			? reqMcp?.selectedServerNames
			: undefined;
		if (Array.isArray(names)) {
			const before = servers.map((s) => s.name);
			servers = servers.filter((s) => names.includes(s.name));
			try {
				console.debug(
					{ selectedNames: names, before, after: servers.map((s) => s.name) },
					"[mcp] applied name selection"
				);
			} catch {}
		}
	} catch {
		// ignore selection merge errors and proceed with env servers
	}

	// If selection/merge yielded no servers, bail early with clearer log
	if (servers.length === 0) {
		console.warn("[mcp] no MCP servers selected after merge/name filter");
		return false;
	}

	// Enforce server-side safety (public HTTPS only, no private ranges)
	{
		const before = servers.slice();
		servers = servers.filter((s) => {
			try {
				return isValidUrl(s.url);
			} catch {
				return false;
			}
		});
		try {
			const rejected = before.filter((b) => !servers.includes(b));
			if (rejected.length > 0) {
				console.warn(
					{ rejected: rejected.map((r) => ({ name: r.name, url: r.url })) },
					"[mcp] rejected servers by URL safety"
				);
			}
		} catch {}
	}
	if (servers.length === 0) {
		console.warn("[mcp] all selected MCP servers rejected by URL safety guard");
		return false;
	}

	// Optionally attach the logged-in user's HF token to the official HF MCP server only.
	// Never override an explicit Authorization header, and require token to look like an HF token.
	try {
		const shouldForward = config.MCP_FORWARD_HF_USER_TOKEN === "true";
		const userToken =
			(locals as unknown as { hfAccessToken?: string } | undefined)?.hfAccessToken ??
			(locals as unknown as { token?: string } | undefined)?.token;
		if (shouldForward && hasNonEmptyToken(userToken)) {
			const overlayApplied: string[] = [];
			servers = servers.map((s) => {
				try {
					if (isStrictHfMcpLogin(s.url) && !hasAuthHeader(s.headers)) {
						overlayApplied.push(s.name);
						return {
							...s,
							headers: { ...(s.headers ?? {}), Authorization: `Bearer ${userToken}` },
						};
					}
				} catch {
					// ignore URL parse errors and leave server unchanged
				}
				return s;
			});
			if (overlayApplied.length > 0) {
				try {
					console.debug({ overlayApplied }, "[mcp] forwarded HF token to servers");
				} catch {}
			}
		}
	} catch {
		// best-effort overlay; continue if anything goes wrong
	}

	console.debug(
		{ count: servers.length, servers: servers.map((s) => s.name) },
		"[mcp] servers configured"
	);
	if (servers.length === 0) {
		return false;
	}

	// Gate MCP flow based on model tool support (aggregated) with user override
	try {
		const supportsTools = Boolean((model as unknown as { supportsTools?: boolean }).supportsTools);
		const toolsEnabled = Boolean(forceTools) || supportsTools;
		console.debug(
			{
				model: model.id ?? model.name,
				supportsTools,
				forceTools: Boolean(forceTools),
				toolsEnabled,
			},
			"[mcp] tools gate evaluation"
		);
		if (!toolsEnabled) {
			console.info(
				{ model: model.id ?? model.name },
				"[mcp] tools disabled for model; skipping MCP flow"
			);
			return false;
		}
	} catch {
		// If anything goes wrong reading the flag, proceed (previous behavior)
	}

	const resolveFileRef = buildImageRefResolver(messages);
	const imageProcessor = makeImageProcessor({
		supportedMimeTypes: ["image/png", "image/jpeg"],
		preferredMimeType: "image/jpeg",
		maxSizeInMB: 1,
		maxWidth: 1024,
		maxHeight: 1024,
	});
	const hasImageInput = messages.some((msg) =>
		(msg.files ?? []).some(
			(file) => typeof file?.mime === "string" && file.mime.startsWith("image/")
		)
	);

	const { runMcp, targetModel, candidateModelId, resolvedRoute } = await resolveRouterTarget({
		model,
		messages,
		conversationId: conv._id.toString(),
		hasImageInput,
		locals,
	});
	if (!runMcp) {
		console.info(
			{ model: targetModel.id ?? targetModel.name, resolvedRoute },
			"[mcp] runMcp=false (routing chose non-tools candidate)"
		);
		return false;
	}
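	// Tool discovery: list tools from every surviving MCP server and convert them to
	// OpenAI function definitions. `mapping` records which server each generated tool
	// definition came from, which (as used below) lets executeToolCalls route each
	// call back to the right server.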
	const { tools: oaTools, mapping } = await getOpenAiToolsForMcp(servers, { signal: abortSignal });
	try {
		console.info(
			{ toolCount: oaTools.length, toolNames: oaTools.map((t) => t.function.name) },
			"[mcp] openai tool defs built"
		);
	} catch {}
	if (oaTools.length === 0) {
		console.warn("[mcp] zero tools available after listing; skipping MCP flow");
		return false;
	}

	try {
		const { OpenAI } = await import("openai");

		// Capture provider header (x-inference-provider) from the upstream OpenAI-compatible server.
		let providerHeader: string | undefined;
		const captureProviderFetch = async (
			input: RequestInfo | URL,
			init?: RequestInit
		): Promise<Response> => {
			const res = await fetch(input, init);
			const p = res.headers.get("x-inference-provider");
			if (p && !providerHeader) providerHeader = p;
			return res;
		};

		const openai = new OpenAI({
			apiKey: config.OPENAI_API_KEY || config.HF_TOKEN || "sk-",
			baseURL: config.OPENAI_BASE_URL,
			fetch: captureProviderFetch,
			defaultHeaders: {
				// Bill to organization if configured (HuggingChat only)
				...(config.isHuggingChat && locals?.billingOrganization
					? { "X-HF-Bill-To": locals.billingOrganization }
					: {}),
			},
		});

		const mmEnabled = (forceMultimodal ?? false) || targetModel.multimodal;
		console.info(
			{
				targetModel: targetModel.id ?? targetModel.name,
				mmEnabled,
				route: resolvedRoute,
				candidateModelId,
				toolCount: oaTools.length,
				hasUserToken: Boolean((locals as unknown as { token?: string })?.token),
			},
			"[mcp] starting completion with tools"
		);

		let messagesOpenAI: ChatCompletionMessageParam[] = await prepareMessagesWithFiles(
			messages,
			imageProcessor,
			mmEnabled
		);

		const toolPreprompt = buildToolPreprompt(oaTools);
		const prepromptPieces: string[] = [];
		if (toolPreprompt.trim().length > 0) {
			prepromptPieces.push(toolPreprompt);
		}
		if (typeof preprompt === "string" && preprompt.trim().length > 0) {
			prepromptPieces.push(preprompt);
		}
		const mergedPreprompt = prepromptPieces.join("\n\n");

		const hasSystemMessage = messagesOpenAI.length > 0 && messagesOpenAI[0]?.role === "system";
		if (hasSystemMessage) {
			if (mergedPreprompt.length > 0) {
				const existing = messagesOpenAI[0].content ?? "";
				const existingText = typeof existing === "string" ? existing : "";
				messagesOpenAI[0].content = mergedPreprompt + (existingText ? "\n\n" + existingText : "");
			}
		} else if (mergedPreprompt.length > 0) {
			messagesOpenAI = [{ role: "system", content: mergedPreprompt }, ...messagesOpenAI];
		}

		// Work around servers that reject `system` role
		if (
			typeof config.OPENAI_BASE_URL === "string" &&
			config.OPENAI_BASE_URL.length > 0 &&
			(config.OPENAI_BASE_URL.includes("hf.space") ||
				config.OPENAI_BASE_URL.includes("gradio.app")) &&
			messagesOpenAI[0]?.role === "system"
		) {
			messagesOpenAI[0] = { ...messagesOpenAI[0], role: "user" };
		}

		const parameters = { ...targetModel.parameters, ...assistant?.generateSettings } as Record<
			string,
			unknown
		>;
		const maxTokens =
			(parameters?.max_tokens as number | undefined) ??
			(parameters?.max_new_tokens as number | undefined) ??
			(parameters?.max_completion_tokens as number | undefined);
		const stopSequences =
			typeof parameters?.stop === "string"
				? parameters.stop
				: Array.isArray(parameters?.stop)
					? (parameters.stop as string[])
					: undefined;

		const completionBase: Omit<ChatCompletionCreateParamsStreaming, "messages"> = {
			model: targetModel.id ?? targetModel.name,
			stream: true,
			temperature: typeof parameters?.temperature === "number" ? parameters.temperature : undefined,
			top_p: typeof parameters?.top_p === "number" ? parameters.top_p : undefined,
			frequency_penalty:
				typeof parameters?.frequency_penalty === "number"
					? parameters.frequency_penalty
					: typeof parameters?.repetition_penalty === "number"
						? parameters.repetition_penalty
						: undefined,
			presence_penalty:
				typeof parameters?.presence_penalty === "number" ? parameters.presence_penalty : undefined,
			stop: stopSequences,
			max_tokens: typeof maxTokens === "number" ? maxTokens : undefined,
			tools: oaTools,
			tool_choice: "auto",
		};
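		// Note on parameter mapping above: model configs may use HF-style names
		// (max_new_tokens, repetition_penalty) rather than OpenAI ones, so max_tokens
		// falls back through max_new_tokens/max_completion_tokens, and a configured
		// repetition_penalty is forwarded in the OpenAI frequency_penalty field when
		// no explicit frequency_penalty is set.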
		const toPrimitive = (value: unknown) => {
			if (typeof value === "string" || typeof value === "number" || typeof value === "boolean") {
				return value;
			}
			return undefined;
		};
		const parseArgs = (raw: unknown): Record<string, unknown> => {
			if (typeof raw !== "string" || raw.trim().length === 0) return {};
			try {
				return JSON.parse(raw);
			} catch {
				return {};
			}
		};
		const processToolOutput = (
			text: string
		): {
			annotated: string;
			sources: { index: number; link: string }[];
		} => ({ annotated: text, sources: [] });

		let lastAssistantContent = "";
		let streamedContent = false;
		// Track whether we're inside a <think> block when the upstream streams
		// provider-specific reasoning tokens (e.g. `reasoning` or `reasoning_content`).
		let thinkOpen = false;

		if (resolvedRoute && candidateModelId) {
			yield {
				type: MessageUpdateType.RouterMetadata,
				route: resolvedRoute,
				model: candidateModelId,
			};
			console.debug(
				{ route: resolvedRoute, model: candidateModelId },
				"[mcp] router metadata emitted"
			);
		}

		for (let loop = 0; loop < 10; loop += 1) {
			lastAssistantContent = "";
			streamedContent = false;
			const completionRequest: ChatCompletionCreateParamsStreaming = {
				...completionBase,
				messages: messagesOpenAI,
			};
			const completionStream: Stream<ChatCompletionChunk> = await openai.chat.completions.create(
				completionRequest,
				{
					signal: abortSignal,
					headers: {
						"ChatUI-Conversation-ID": conv._id.toString(),
						"X-use-cache": "false",
						...(locals?.token ? { Authorization: `Bearer ${locals.token}` } : {}),
					},
				}
			);

			// If the provider header was exposed, notify the UI so it can render "via {provider}".
			if (providerHeader) {
				yield {
					type: MessageUpdateType.RouterMetadata,
					route: "",
					model: "",
					provider:
						providerHeader as unknown as import("@huggingface/inference").InferenceProvider,
				};
				console.debug({ provider: providerHeader }, "[mcp] provider metadata emitted");
			}

			const toolCallState: Record<number, { id?: string; name?: string; arguments: string }> = {};
			let firstToolDeltaLogged = false;
			let sawToolCall = false;
			let tokenCount = 0;

			for await (const chunk of completionStream) {
				const choice = chunk.choices?.[0];
				const delta = choice?.delta;
				if (!delta) continue;

				const chunkToolCalls = delta.tool_calls ?? [];
				if (chunkToolCalls.length > 0) {
					sawToolCall = true;
					for (const call of chunkToolCalls) {
						const toolCall = call as unknown as {
							index?: number;
							id?: string;
							function?: { name?: string; arguments?: string };
						};
						const index = toolCall.index ?? 0;
						const current = toolCallState[index] ?? { arguments: "" };
						if (toolCall.id) current.id = toolCall.id;
						if (toolCall.function?.name) current.name = toolCall.function.name;
						if (toolCall.function?.arguments) current.arguments += toolCall.function.arguments;
						toolCallState[index] = current;
					}
					if (!firstToolDeltaLogged) {
						try {
							const first =
								toolCallState[
									Object.keys(toolCallState)
										.map((k) => Number(k))
										.sort((a, b) => a - b)[0] ?? 0
								];
							console.info(
								{ firstCallName: first?.name, hasId: Boolean(first?.id) },
								"[mcp] observed streamed tool_call delta"
							);
							firstToolDeltaLogged = true;
						} catch {}
					}
				}
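				// Streamed tool calls typically arrive fragmented across chunks, e.g.
				// (hypothetical deltas):
				//   { index: 0, id: "call_abc", function: { name: "search" } }
				//   { index: 0, function: { arguments: "{\"query\":" } }
				//   { index: 0, function: { arguments: "\"mcp\"}" } }
				// which is why the arguments string is accumulated per index above.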
				const deltaContent = (() => {
					if (typeof delta.content === "string") return delta.content;
					const maybeParts = delta.content as unknown;
					if (Array.isArray(maybeParts)) {
						return maybeParts
							.map((part) =>
								typeof part === "object" &&
								part !== null &&
								"text" in part &&
								typeof (part as Record<string, unknown>).text === "string"
									? String((part as Record<string, unknown>).text)
									: ""
							)
							.join("");
					}
					return "";
				})();

				// Provider-dependent reasoning fields (e.g., `reasoning` or `reasoning_content`).
				const deltaReasoning: string =
					typeof (delta as unknown as Record<string, unknown>)?.reasoning === "string"
						? ((delta as unknown as { reasoning?: string }).reasoning as string)
						: typeof (delta as unknown as Record<string, unknown>)?.reasoning_content === "string"
							? ((delta as unknown as { reasoning_content?: string })
									.reasoning_content as string)
							: "";

				// Merge reasoning + content into a single combined token stream, mirroring
				// the OpenAI adapter so the UI can auto-detect <think> blocks.
				let combined = "";
				if (deltaReasoning.trim().length > 0) {
					if (!thinkOpen) {
						combined += "<think>" + deltaReasoning;
						thinkOpen = true;
					} else {
						combined += deltaReasoning;
					}
				}
				if (deltaContent && deltaContent.length > 0) {
					if (thinkOpen) {
						combined += "</think>" + deltaContent;
						thinkOpen = false;
					} else {
						combined += deltaContent;
					}
				}
				if (combined.length > 0) {
					lastAssistantContent += combined;
					if (!sawToolCall) {
						streamedContent = true;
						yield { type: MessageUpdateType.Stream, token: combined };
						tokenCount += combined.length;
					}
				}
			}

			console.info(
				{ sawToolCalls: Object.keys(toolCallState).length > 0, tokens: tokenCount, loop },
				"[mcp] completion stream closed"
			);

			if (Object.keys(toolCallState).length > 0) {
				// If any streamed call is missing an id, perform a quick non-stream retry
				// to recover full tool_calls with ids
				const missingId = Object.values(toolCallState).some((c) => c?.name && !c?.id);
				let calls: NormalizedToolCall[];
				if (missingId) {
					console.debug(
						{ loop },
						"[mcp] missing tool_call id in stream; retrying non-stream to recover ids"
					);
					const nonStream = await openai.chat.completions.create(
						{ ...completionBase, messages: messagesOpenAI, stream: false },
						{
							signal: abortSignal,
							headers: {
								"ChatUI-Conversation-ID": conv._id.toString(),
								"X-use-cache": "false",
								...(locals?.token ? { Authorization: `Bearer ${locals.token}` } : {}),
							},
						}
					);
					const tc = nonStream.choices?.[0]?.message?.tool_calls ?? [];
					calls = tc.map((t) => ({
						id: t.id,
						name: t.function?.name ?? "",
						arguments: t.function?.arguments ?? "",
					}));
				} else {
					calls = Object.values(toolCallState)
						.map((c) => (c?.id && c?.name ? c : undefined))
						.filter(Boolean)
						.map((c) => ({
							id: c?.id ?? "",
							name: c?.name ?? "",
							arguments: c?.arguments ?? "",
						})) as NormalizedToolCall[];
				}

				// Include the assistant message with tool_calls so the next round
				// sees both the calls and their outputs, matching MCP branch behavior.
				const toolCalls: ChatCompletionMessageToolCall[] = calls.map((call) => ({
					id: call.id,
					type: "function",
					function: { name: call.name, arguments: call.arguments },
				}));
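				// On the next loop iteration the conversation is replayed roughly as
				// (hypothetical sketch; tool message shape assumed to follow the OpenAI
				// `role: "tool"` convention):
				//   [{ role: "assistant", tool_calls: [...] },
				//    { role: "tool", tool_call_id: "call_abc", content: "<tool output>" }, ...]
				// so the model sees each call paired with its result before continuing.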
				// Avoid sending content back to the model alongside tool_calls
				// to prevent confusing follow-up reasoning. Strip any <think> blocks.
				const assistantContentForToolMsg = lastAssistantContent.replace(
					/<think>[\s\S]*?(?:<\/think>|$)/g,
					""
				);
				const assistantToolMessage: ChatCompletionMessageParam = {
					role: "assistant",
					content: assistantContentForToolMsg,
					tool_calls: toolCalls,
				};

				const exec = executeToolCalls({
					calls,
					mapping,
					servers,
					parseArgs,
					resolveFileRef,
					toPrimitive,
					processToolOutput,
					abortSignal,
				});
				let toolMsgCount = 0;
				let toolRunCount = 0;
				for await (const event of exec) {
					if (event.type === "update") {
						yield event.update;
					} else {
						messagesOpenAI = [
							...messagesOpenAI,
							assistantToolMessage,
							...(event.summary.toolMessages ?? []),
						];
						toolMsgCount = event.summary.toolMessages?.length ?? 0;
						toolRunCount = event.summary.toolRuns?.length ?? 0;
						console.info(
							{ toolMsgCount, toolRunCount },
							"[mcp] tools executed; continuing loop for follow-up completion"
						);
					}
				}

				// Continue loop: next iteration will use the tool messages to get the final content
				continue;
			}

			// No tool calls: finalize and return.
			// If a <think> block is still open, close it for the final output
			if (thinkOpen) {
				lastAssistantContent += "</think>";
				thinkOpen = false;
			}
			if (!streamedContent && lastAssistantContent.trim().length > 0) {
				yield { type: MessageUpdateType.Stream, token: lastAssistantContent };
			}
			yield {
				type: MessageUpdateType.FinalAnswer,
				text: lastAssistantContent,
				interrupted: false,
			};
			console.info(
				{ length: lastAssistantContent.length, loop },
				"[mcp] final answer emitted (no tool_calls)"
			);
			return true;
		}

		console.warn("[mcp] exceeded tool-followup loops; falling back");
	} catch (err) {
		const msg = String(err ?? "");
		const isAbort =
			(abortSignal && abortSignal.aborted) ||
			msg.includes("AbortError") ||
			msg.includes("APIUserAbortError") ||
			msg.includes("Request was aborted");
		if (isAbort) {
			// Expected on user stop; keep logs quiet and do not treat as an error
			console.debug("[mcp] aborted by user");
			return false;
		}
		console.warn({ err: msg }, "[mcp] flow failed, falling back to default endpoint");
	} finally {
		// Ensure MCP clients are closed after the turn
		await drainPool();
	}
	return false;
}