| | import { config } from "$lib/server/config"; |
| | import { MessageUpdateType, type MessageUpdate } from "$lib/types/MessageUpdate"; |
| | import { getMcpServers } from "$lib/server/mcp/registry"; |
| | import { isValidUrl } from "$lib/server/urlSafety"; |
| | import { resetMcpToolsCache } from "$lib/server/mcp/tools"; |
| | import { getOpenAiToolsForMcp } from "$lib/server/mcp/tools"; |
| | import type { |
| | ChatCompletionChunk, |
| | ChatCompletionCreateParamsStreaming, |
| | ChatCompletionMessageParam, |
| | ChatCompletionMessageToolCall, |
| | } from "openai/resources/chat/completions"; |
| | import type { Stream } from "openai/streaming"; |
| | import { buildToolPreprompt } from "../utils/toolPrompt"; |
| | import type { EndpointMessage } from "../../endpoints/endpoints"; |
| | import { resolveRouterTarget } from "./routerResolution"; |
| | import { executeToolCalls, type NormalizedToolCall } from "./toolInvocation"; |
| | import { drainPool } from "$lib/server/mcp/clientPool"; |
| | import type { TextGenerationContext } from "../types"; |
| | import { hasAuthHeader, isStrictHfMcpLogin, hasNonEmptyToken } from "$lib/server/mcp/hf"; |
| | import { buildImageRefResolver } from "./fileRefs"; |
| | import { prepareMessagesWithFiles } from "$lib/server/textGeneration/utils/prepareFiles"; |
| | import { makeImageProcessor } from "$lib/server/endpoints/images"; |
| | import { logger } from "$lib/server/logger"; |
| |
|
/**
 * Subset of {@link TextGenerationContext} needed to drive the MCP tool flow,
 * plus the already-prepared endpoint messages for this conversation turn.
 */
export type RunMcpFlowContext = Pick<
	TextGenerationContext,
	"model" | "conv" | "assistant" | "forceMultimodal" | "forceTools" | "locals"
> & { messages: EndpointMessage[] };
| |
|
/**
 * Runs a tool-calling conversation turn against MCP (Model Context Protocol) servers.
 *
 * High-level flow:
 *  1. Resolve the effective MCP server list (config base + request-provided servers,
 *     optional name filter, URL-safety guard, optional HF token forwarding).
 *  2. Gate on whether the model supports tools (or `forceTools` is set).
 *  3. Resolve the actual target model via the router.
 *  4. Build OpenAI-format tool definitions from the MCP servers.
 *  5. Loop (max 10 iterations): stream a completion; if the model emits tool calls,
 *     execute them via MCP and feed results back; otherwise emit the final answer.
 *
 * Yields {@link MessageUpdate} events (stream tokens, router/provider metadata,
 * tool updates, final answer).
 *
 * @returns `true` when a final answer was emitted through this flow; `false` when
 *   the flow was skipped, failed, or was aborted — the caller is expected to fall
 *   back to the default endpoint in that case.
 */
export async function* runMcpFlow({
	model,
	conv,
	messages,
	assistant,
	forceMultimodal,
	forceTools,
	locals,
	preprompt,
	abortSignal,
}: RunMcpFlowContext & { preprompt?: string; abortSignal?: AbortSignal }): AsyncGenerator<
	MessageUpdate,
	boolean,
	undefined
> {
	// Start from the statically configured MCP servers.
	let servers = getMcpServers();
	// Logging is wrapped in try/catch throughout this function so a logger/serialization
	// failure can never break the generation flow (best-effort diagnostics only).
	try {
		logger.debug(
			{ baseServers: servers.map((s) => ({ name: s.name, url: s.url })), count: servers.length },
			"[mcp] base servers loaded"
		);
	} catch {}

	// Merge in request-scoped server selection carried on `locals.mcp` (shape is not
	// statically typed here, hence the structural cast). Any error in this block is
	// swallowed: a malformed request selection degrades to the configured base list.
	try {
		const reqMcp = (
			locals as unknown as {
				mcp?: {
					selectedServers?: Array<{ name: string; url: string; headers?: Record<string, string> }>;
					selectedServerNames?: string[];
				};
			}
		)?.mcp;
		const custom = Array.isArray(reqMcp?.selectedServers) ? reqMcp?.selectedServers : [];
		if (custom.length > 0) {
			// Request-provided servers may shadow configured ones, so the tool-list
			// cache must not serve stale entries.
			resetMcpToolsCache();
			// Merge by name; request-provided entries win over base entries with the
			// same name (later Map.set overwrites).
			const byName = new Map<
				string,
				{ name: string; url: string; headers?: Record<string, string> }
			>();
			for (const s of servers) byName.set(s.name, s);
			for (const s of custom) byName.set(s.name, s);
			servers = [...byName.values()];
			try {
				logger.debug(
					{
						customProvidedCount: custom.length,
						mergedServers: servers.map((s) => ({
							name: s.name,
							url: s.url,
							hasAuth: !!s.headers?.Authorization,
						})),
					},
					"[mcp] merged request-provided servers"
				);
			} catch {}
		}

		// Optional allow-list of server names: when present, keep only matching servers.
		const names = Array.isArray(reqMcp?.selectedServerNames)
			? reqMcp?.selectedServerNames
			: undefined;
		if (Array.isArray(names)) {
			const before = servers.map((s) => s.name);
			servers = servers.filter((s) => names.includes(s.name));
			try {
				logger.debug(
					{ selectedNames: names, before, after: servers.map((s) => s.name) },
					"[mcp] applied name selection"
				);
			} catch {}
		}
	} catch {
		// Deliberate best-effort: ignore malformed request-scoped MCP selection.
	}

	// Nothing selected → skip the MCP flow entirely.
	if (servers.length === 0) {
		logger.warn({}, "[mcp] no MCP servers selected after merge/name filter");
		return false;
	}

	// URL-safety guard: drop any server whose URL fails validation (or whose
	// validation throws). `rejected` works by object identity since filter keeps
	// the same references.
	{
		const before = servers.slice();
		servers = servers.filter((s) => {
			try {
				return isValidUrl(s.url);
			} catch {
				return false;
			}
		});
		try {
			const rejected = before.filter((b) => !servers.includes(b));
			if (rejected.length > 0) {
				logger.warn(
					{ rejected: rejected.map((r) => ({ name: r.name, url: r.url })) },
					"[mcp] rejected servers by URL safety"
				);
			}
		} catch {}
	}
	if (servers.length === 0) {
		logger.warn({}, "[mcp] all selected MCP servers rejected by URL safety guard");
		return false;
	}

	// Optionally forward the user's HF access token as an Authorization header, but
	// only to servers that pass the strict HF-MCP-login URL check and that do not
	// already carry their own auth header. Best-effort: any failure leaves the
	// server list unchanged.
	try {
		const shouldForward = config.MCP_FORWARD_HF_USER_TOKEN === "true";
		const userToken =
			(locals as unknown as { hfAccessToken?: string } | undefined)?.hfAccessToken ??
			(locals as unknown as { token?: string } | undefined)?.token;

		if (shouldForward && hasNonEmptyToken(userToken)) {
			const overlayApplied: string[] = [];
			servers = servers.map((s) => {
				try {
					if (isStrictHfMcpLogin(s.url) && !hasAuthHeader(s.headers)) {
						overlayApplied.push(s.name);
						return {
							...s,
							headers: { ...(s.headers ?? {}), Authorization: `Bearer ${userToken}` },
						};
					}
				} catch {
					// Keep the server untouched if the URL check throws.
				}
				return s;
			});
			if (overlayApplied.length > 0) {
				try {
					logger.debug({ overlayApplied }, "[mcp] forwarded HF token to servers");
				} catch {}
			}
		}
	} catch {
		// Deliberate best-effort: token forwarding must never break the flow.
	}
	logger.debug(
		{ count: servers.length, servers: servers.map((s) => s.name) },
		"[mcp] servers configured"
	);
	if (servers.length === 0) {
		return false;
	}

	// Tools gate: run MCP only when the model supports tools or tools are forced.
	// NOTE(review): if this block throws before the gate decision, the flow proceeds
	// as if tools were enabled — presumably intentional (fail-open), but worth confirming.
	try {
		const supportsTools = Boolean((model as unknown as { supportsTools?: boolean }).supportsTools);
		const toolsEnabled = Boolean(forceTools) || supportsTools;
		logger.debug(
			{
				model: model.id ?? model.name,
				supportsTools,
				forceTools: Boolean(forceTools),
				toolsEnabled,
			},
			"[mcp] tools gate evaluation"
		);
		if (!toolsEnabled) {
			logger.info(
				{ model: model.id ?? model.name },
				"[mcp] tools disabled for model; skipping MCP flow"
			);
			return false;
		}
	} catch {
		// Best-effort gate; errors fall through to the MCP flow.
	}

	// File/image plumbing: resolver for file references inside tool arguments, and
	// an image processor that normalizes inputs to JPEG within the given bounds.
	const resolveFileRef = buildImageRefResolver(messages);
	const imageProcessor = makeImageProcessor({
		supportedMimeTypes: ["image/png", "image/jpeg"],
		preferredMimeType: "image/jpeg",
		maxSizeInMB: 1,
		maxWidth: 1024,
		maxHeight: 1024,
	});

	// Whether any message in the turn carries an image attachment (drives routing).
	const hasImageInput = messages.some((msg) =>
		(msg.files ?? []).some(
			(file) => typeof file?.mime === "string" && file.mime.startsWith("image/")
		)
	);

	// Ask the router which concrete model should serve this turn and whether the
	// chosen candidate supports the MCP flow at all.
	const { runMcp, targetModel, candidateModelId, resolvedRoute } = await resolveRouterTarget({
		model,
		messages,
		conversationId: conv._id.toString(),
		hasImageInput,
		locals,
	});

	if (!runMcp) {
		logger.info(
			{ model: targetModel.id ?? targetModel.name, resolvedRoute },
			"[mcp] runMcp=false (routing chose non-tools candidate)"
		);
		return false;
	}

	// List tools across all servers and convert them to OpenAI tool definitions;
	// `mapping` ties each OpenAI tool name back to its MCP server/tool.
	const { tools: oaTools, mapping } = await getOpenAiToolsForMcp(servers, { signal: abortSignal });
	try {
		logger.info(
			{ toolCount: oaTools.length, toolNames: oaTools.map((t) => t.function.name) },
			"[mcp] openai tool defs built"
		);
	} catch {}
	if (oaTools.length === 0) {
		logger.warn({}, "[mcp] zero tools available after listing; skipping MCP flow");
		return false;
	}

	try {
		// Lazy-load the OpenAI SDK so it is only paid for when the MCP flow runs.
		const { OpenAI } = await import("openai");

		// Custom fetch that captures the first `x-inference-provider` response header
		// so it can be surfaced to the client as provider metadata.
		let providerHeader: string | undefined;
		const captureProviderFetch = async (
			input: RequestInfo | URL,
			init?: RequestInit
		): Promise<Response> => {
			const res = await fetch(input, init);
			const p = res.headers.get("x-inference-provider");
			if (p && !providerHeader) providerHeader = p;
			return res;
		};

		const openai = new OpenAI({
			// "sk-" is a placeholder so the SDK accepts a missing key; the real auth may
			// instead come via the per-request Authorization header below.
			apiKey: config.OPENAI_API_KEY || config.HF_TOKEN || "sk-",
			baseURL: config.OPENAI_BASE_URL,
			fetch: captureProviderFetch,
			defaultHeaders: {
				// Bill usage to the user's selected organization on HuggingChat.
				...(config.isHuggingChat && locals?.billingOrganization
					? { "X-HF-Bill-To": locals.billingOrganization }
					: {}),
			},
		});

		// Multimodal input is enabled when forced or when the target model supports it.
		const mmEnabled = (forceMultimodal ?? false) || targetModel.multimodal;
		logger.info(
			{
				targetModel: targetModel.id ?? targetModel.name,
				mmEnabled,
				route: resolvedRoute,
				candidateModelId,
				toolCount: oaTools.length,
				hasUserToken: Boolean((locals as unknown as { token?: string })?.token),
			},
			"[mcp] starting completion with tools"
		);
		let messagesOpenAI: ChatCompletionMessageParam[] = await prepareMessagesWithFiles(
			messages,
			imageProcessor,
			mmEnabled
		);
		// Compose the system prompt: tool-usage preprompt first, then the caller's
		// preprompt, joined by a blank line.
		const toolPreprompt = buildToolPreprompt(oaTools);
		const prepromptPieces: string[] = [];
		if (toolPreprompt.trim().length > 0) {
			prepromptPieces.push(toolPreprompt);
		}
		if (typeof preprompt === "string" && preprompt.trim().length > 0) {
			prepromptPieces.push(preprompt);
		}
		const mergedPreprompt = prepromptPieces.join("\n\n");
		const hasSystemMessage = messagesOpenAI.length > 0 && messagesOpenAI[0]?.role === "system";
		if (hasSystemMessage) {
			if (mergedPreprompt.length > 0) {
				// Prepend to an existing system message; non-string (multipart) system
				// content is dropped in favor of the merged preprompt.
				const existing = messagesOpenAI[0].content ?? "";
				const existingText = typeof existing === "string" ? existing : "";
				messagesOpenAI[0].content = mergedPreprompt + (existingText ? "\n\n" + existingText : "");
			}
		} else if (mergedPreprompt.length > 0) {
			messagesOpenAI = [{ role: "system", content: mergedPreprompt }, ...messagesOpenAI];
		}

		// Workaround for hf.space / gradio.app backends: demote the system message to
		// a user message — presumably those backends reject system-role messages
		// (NOTE(review): confirm this is still required).
		if (
			typeof config.OPENAI_BASE_URL === "string" &&
			config.OPENAI_BASE_URL.length > 0 &&
			(config.OPENAI_BASE_URL.includes("hf.space") ||
				config.OPENAI_BASE_URL.includes("gradio.app")) &&
			messagesOpenAI[0]?.role === "system"
		) {
			messagesOpenAI[0] = { ...messagesOpenAI[0], role: "user" };
		}

		// Generation parameters: assistant-level settings override model defaults.
		const parameters = { ...targetModel.parameters, ...assistant?.generateSettings } as Record<
			string,
			unknown
		>;
		// Accept the first of the common max-token parameter spellings.
		const maxTokens =
			(parameters?.max_tokens as number | undefined) ??
			(parameters?.max_new_tokens as number | undefined) ??
			(parameters?.max_completion_tokens as number | undefined);

		// `stop` may be a single string or an array of strings.
		const stopSequences =
			typeof parameters?.stop === "string"
				? parameters.stop
				: Array.isArray(parameters?.stop)
					? (parameters.stop as string[])
					: undefined;

		// Shared request body for every (streaming) completion in the loop below;
		// `repetition_penalty` is mapped onto OpenAI's `frequency_penalty` when the
		// former is set and the latter is not.
		const completionBase: Omit<ChatCompletionCreateParamsStreaming, "messages"> = {
			model: targetModel.id ?? targetModel.name,
			stream: true,
			temperature: typeof parameters?.temperature === "number" ? parameters.temperature : undefined,
			top_p: typeof parameters?.top_p === "number" ? parameters.top_p : undefined,
			frequency_penalty:
				typeof parameters?.frequency_penalty === "number"
					? parameters.frequency_penalty
					: typeof parameters?.repetition_penalty === "number"
						? parameters.repetition_penalty
						: undefined,
			presence_penalty:
				typeof parameters?.presence_penalty === "number" ? parameters.presence_penalty : undefined,
			stop: stopSequences,
			max_tokens: typeof maxTokens === "number" ? maxTokens : undefined,
			tools: oaTools,
			tool_choice: "auto",
		};

		// Keep only JSON-primitive values; anything else becomes undefined.
		const toPrimitive = (value: unknown) => {
			if (typeof value === "string" || typeof value === "number" || typeof value === "boolean") {
				return value;
			}
			return undefined;
		};

		// Parse a tool-call arguments string; malformed/empty JSON yields {}.
		const parseArgs = (raw: unknown): Record<string, unknown> => {
			if (typeof raw !== "string" || raw.trim().length === 0) return {};
			try {
				return JSON.parse(raw);
			} catch {
				return {};
			}
		};

		// Pass-through tool-output processor (no annotation/source extraction here).
		const processToolOutput = (
			text: string
		): {
			annotated: string;
			sources: { index: number; link: string }[];
		} => ({ annotated: text, sources: [] });

		let lastAssistantContent = "";
		let streamedContent = false;
		// Tracks whether a <think> block is currently open in the streamed output.
		// NOTE(review): this flag is NOT reset at the top of each loop iteration even
		// though `lastAssistantContent` is — after a tool-call turn that ended mid-think,
		// the next turn's reasoning would stream without an opening <think> tag. Verify
		// this is intended.
		let thinkOpen = false;

		// Surface the routing decision to the client before streaming starts.
		if (resolvedRoute && candidateModelId) {
			yield {
				type: MessageUpdateType.RouterMetadata,
				route: resolvedRoute,
				model: candidateModelId,
			};
			logger.debug(
				{ route: resolvedRoute, model: candidateModelId },
				"[mcp] router metadata emitted"
			);
		}

		// Main tool-followup loop: each iteration is one model completion. Tool calls
		// feed results back into `messagesOpenAI` and continue; a plain answer exits.
		// Capped at 10 iterations to prevent unbounded tool loops.
		for (let loop = 0; loop < 10; loop += 1) {
			lastAssistantContent = "";
			streamedContent = false;

			const completionRequest: ChatCompletionCreateParamsStreaming = {
				...completionBase,
				messages: messagesOpenAI,
			};

			const completionStream: Stream<ChatCompletionChunk> = await openai.chat.completions.create(
				completionRequest,
				{
					signal: abortSignal,
					headers: {
						"ChatUI-Conversation-ID": conv._id.toString(),
						"X-use-cache": "false",
						// Per-request user token takes effect alongside the client-level apiKey.
						...(locals?.token ? { Authorization: `Bearer ${locals.token}` } : {}),
					},
				}
			);

			// Emit provider metadata once captured by the custom fetch. Route/model are
			// intentionally empty here — this update only conveys the provider.
			if (providerHeader) {
				yield {
					type: MessageUpdateType.RouterMetadata,
					route: "",
					model: "",
					provider: providerHeader as unknown as import("@huggingface/inference").InferenceProvider,
				};
				logger.debug({ provider: providerHeader }, "[mcp] provider metadata emitted");
			}

			// Accumulates streamed tool-call fragments keyed by their delta index;
			// `arguments` is concatenated across chunks.
			const toolCallState: Record<number, { id?: string; name?: string; arguments: string }> = {};
			let firstToolDeltaLogged = false;
			let sawToolCall = false;
			let tokenCount = 0;
			for await (const chunk of completionStream) {
				const choice = chunk.choices?.[0];
				const delta = choice?.delta;
				if (!delta) continue;

				const chunkToolCalls = delta.tool_calls ?? [];
				if (chunkToolCalls.length > 0) {
					sawToolCall = true;
					for (const call of chunkToolCalls) {
						const toolCall = call as unknown as {
							index?: number;
							id?: string;
							function?: { name?: string; arguments?: string };
						};
						const index = toolCall.index ?? 0;
						const current = toolCallState[index] ?? { arguments: "" };
						if (toolCall.id) current.id = toolCall.id;
						if (toolCall.function?.name) current.name = toolCall.function.name;
						if (toolCall.function?.arguments) current.arguments += toolCall.function.arguments;
						toolCallState[index] = current;
					}
					// One-shot diagnostic: log the lowest-indexed call once per stream.
					if (!firstToolDeltaLogged) {
						try {
							const first =
								toolCallState[
									Object.keys(toolCallState)
										.map((k) => Number(k))
										.sort((a, b) => a - b)[0] ?? 0
								];
							logger.info(
								{ firstCallName: first?.name, hasId: Boolean(first?.id) },
								"[mcp] observed streamed tool_call delta"
							);
							firstToolDeltaLogged = true;
						} catch {}
					}
				}

				// Normalize delta content: plain string, or the concatenated `text` fields
				// of a multipart content array; anything else contributes nothing.
				const deltaContent = (() => {
					if (typeof delta.content === "string") return delta.content;
					const maybeParts = delta.content as unknown;
					if (Array.isArray(maybeParts)) {
						return maybeParts
							.map((part) =>
								typeof part === "object" &&
								part !== null &&
								"text" in part &&
								typeof (part as Record<string, unknown>).text === "string"
									? String((part as Record<string, unknown>).text)
									: ""
							)
							.join("");
					}
					return "";
				})();

				// Reasoning deltas arrive under either `reasoning` or `reasoning_content`
				// depending on the provider; neither field is part of the SDK's types.
				const deltaReasoning: string =
					typeof (delta as unknown as Record<string, unknown>)?.reasoning === "string"
						? ((delta as unknown as { reasoning?: string }).reasoning as string)
						: typeof (delta as unknown as Record<string, unknown>)?.reasoning_content === "string"
							? ((delta as unknown as { reasoning_content?: string }).reasoning_content as string)
							: "";

				// Wrap reasoning in <think>…</think>: open the tag on the first reasoning
				// delta, close it as soon as regular content resumes.
				let combined = "";
				if (deltaReasoning.trim().length > 0) {
					if (!thinkOpen) {
						combined += "<think>" + deltaReasoning;
						thinkOpen = true;
					} else {
						combined += deltaReasoning;
					}
				}

				if (deltaContent && deltaContent.length > 0) {
					if (thinkOpen) {
						combined += "</think>" + deltaContent;
						thinkOpen = false;
					} else {
						combined += deltaContent;
					}
				}

				if (combined.length > 0) {
					lastAssistantContent += combined;
					// Once a tool call has been seen, suppress token streaming — content is
					// still accumulated for the assistant tool message below.
					if (!sawToolCall) {
						streamedContent = true;
						yield { type: MessageUpdateType.Stream, token: combined };
						// Counts characters of streamed content, not model tokens.
						tokenCount += combined.length;
					}
				}
			}
			logger.info(
				{ sawToolCalls: Object.keys(toolCallState).length > 0, tokens: tokenCount, loop },
				"[mcp] completion stream closed"
			);

			if (Object.keys(toolCallState).length > 0) {
				// Some providers stream tool calls without ids; ids are required to pair
				// tool results with calls, so fall back to a non-streaming request.
				const missingId = Object.values(toolCallState).some((c) => c?.name && !c?.id);
				let calls: NormalizedToolCall[];
				if (missingId) {
					logger.debug(
						{ loop },
						"[mcp] missing tool_call id in stream; retrying non-stream to recover ids"
					);
					const nonStream = await openai.chat.completions.create(
						{ ...completionBase, messages: messagesOpenAI, stream: false },
						{
							signal: abortSignal,
							headers: {
								"ChatUI-Conversation-ID": conv._id.toString(),
								"X-use-cache": "false",
								...(locals?.token ? { Authorization: `Bearer ${locals.token}` } : {}),
							},
						}
					);
					const tc = nonStream.choices?.[0]?.message?.tool_calls ?? [];
					calls = tc.map((t) => ({
						id: t.id,
						name: t.function?.name ?? "",
						arguments: t.function?.arguments ?? "",
					}));
				} else {
					// Keep only fully-formed calls (both id and name present).
					calls = Object.values(toolCallState)
						.map((c) => (c?.id && c?.name ? c : undefined))
						.filter(Boolean)
						.map((c) => ({
							id: c?.id ?? "",
							name: c?.name ?? "",
							arguments: c?.arguments ?? "",
						})) as NormalizedToolCall[];
				}

				// Re-shape normalized calls into the OpenAI message format for history.
				const toolCalls: ChatCompletionMessageToolCall[] = calls.map((call) => ({
					id: call.id,
					type: "function",
					function: { name: call.name, arguments: call.arguments },
				}));

				// Strip <think> blocks (including an unterminated trailing one) from the
				// assistant content recorded alongside the tool calls.
				const assistantContentForToolMsg = lastAssistantContent.replace(
					/<think>[\s\S]*?(?:<\/think>|$)/g,
					""
				);
				const assistantToolMessage: ChatCompletionMessageParam = {
					role: "assistant",
					content: assistantContentForToolMsg,
					tool_calls: toolCalls,
				};

				// Execute the calls against their MCP servers; "update" events are
				// streamed to the client, the terminal summary event appends the
				// assistant + tool result messages to the conversation for the next loop.
				const exec = executeToolCalls({
					calls,
					mapping,
					servers,
					parseArgs,
					resolveFileRef,
					toPrimitive,
					processToolOutput,
					abortSignal,
				});
				let toolMsgCount = 0;
				let toolRunCount = 0;
				for await (const event of exec) {
					if (event.type === "update") {
						yield event.update;
					} else {
						messagesOpenAI = [
							...messagesOpenAI,
							assistantToolMessage,
							...(event.summary.toolMessages ?? []),
						];
						toolMsgCount = event.summary.toolMessages?.length ?? 0;
						toolRunCount = event.summary.toolRuns?.length ?? 0;
						logger.info(
							{ toolMsgCount, toolRunCount },
							"[mcp] tools executed; continuing loop for follow-up completion"
						);
					}
				}
				// Run a follow-up completion with the tool results in context.
				continue;
			}

			// No tool calls this turn: finalize. Close a dangling <think> block so the
			// emitted answer is well-formed.
			if (thinkOpen) {
				lastAssistantContent += "</think>";
				thinkOpen = false;
			}
			// If nothing was streamed incrementally (e.g. suppressed earlier), emit the
			// whole answer as a single stream token before the final answer.
			if (!streamedContent && lastAssistantContent.trim().length > 0) {
				yield { type: MessageUpdateType.Stream, token: lastAssistantContent };
			}
			yield {
				type: MessageUpdateType.FinalAnswer,
				text: lastAssistantContent,
				interrupted: false,
			};
			logger.info(
				{ length: lastAssistantContent.length, loop },
				"[mcp] final answer emitted (no tool_calls)"
			);
			return true;
		}
		logger.warn({}, "[mcp] exceeded tool-followup loops; falling back");
	} catch (err) {
		const msg = String(err ?? "");
		// Abort detection is partly by message sniffing because the SDK surfaces
		// aborts as differently-named errors depending on version/transport.
		const isAbort =
			(abortSignal && abortSignal.aborted) ||
			msg.includes("AbortError") ||
			msg.includes("APIUserAbortError") ||
			msg.includes("Request was aborted");
		if (isAbort) {
			// User-initiated abort: silent fallback, no warning.
			logger.debug({}, "[mcp] aborted by user");
			return false;
		}
		logger.warn({ err: msg }, "[mcp] flow failed, falling back to default endpoint");
	} finally {
		// Always release pooled MCP client connections, success or failure.
		await drainPool();
	}

	// Reaching here means no final answer was produced; caller should fall back.
	return false;
}
| |
|