| | |
| | |
| | |
| | |
| | |
| |
|
| | import { Client } from '@modelcontextprotocol/sdk/client/index.js'; |
| | import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; |
| | import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; |
| | import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; |
| | import { parse } from 'shell-quote'; |
| | import { MCPServerConfig } from '../config/config.js'; |
| | import { DiscoveredMCPTool } from './mcp-tool.js'; |
| | import { |
| | CallableTool, |
| | FunctionDeclaration, |
| | mcpToTool, |
| | Schema, |
| | } from '@google/genai'; |
| | import { ToolRegistry } from './tool-registry.js'; |
| |
|
/**
 * Fallback timeout in milliseconds (ten minutes). Used for connecting and for
 * tool calls whenever an MCP server config does not set `timeout`.
 */
export const MCP_DEFAULT_TIMEOUT_MSEC = 1000 * 60 * 10;
| |
|
| | |
| | |
| | |
/**
 * Connection state of a single MCP server as tracked by this module.
 */
export enum MCPServerStatus {
  /** No live connection: never attempted, failed to connect, or closed. */
  DISCONNECTED = 'disconnected',
  /** A connection attempt has started but has not finished yet. */
  CONNECTING = 'connecting',
  /** The client's `connect()` call completed successfully. */
  CONNECTED = 'connected',
}
| |
|
| | |
| | |
| | |
/**
 * Progress of the overall discovery pass performed by `discoverMcpTools`.
 */
export enum MCPDiscoveryState {
  /** Discovery has not been run yet (the initial value). */
  NOT_STARTED = 'not_started',
  /** Discovery is currently running. */
  IN_PROGRESS = 'in_progress',
  /** Discovery has finished, whether it succeeded or threw. */
  COMPLETED = 'completed',
}
| |
|
| | |
| | |
| | |
/** Last recorded status per MCP server name. Private; exposed only through the getters below. */
const mcpServerStatusesInternal = new Map<string, MCPServerStatus>();
| |
|
| | |
| | |
| | |
/** Progress of the current or most recent discovery pass; begins as NOT_STARTED. */
let mcpDiscoveryState = MCPDiscoveryState.NOT_STARTED;
| |
|
| | |
| | |
| | |
/** Callback that receives a server's name together with its newly recorded status. */
type StatusChangeListener = (serverName: string, status: MCPServerStatus) => void;

/** Registered subscribers, notified in registration order on every status update. */
const statusChangeListeners: Array<StatusChangeListener> = [];
| |
|
| | |
| | |
| | |
/**
 * Subscribes `listener` to MCP server status changes.
 *
 * The listener is invoked synchronously whenever a status is recorded.
 * Registering the same function twice makes it fire twice.
 *
 * @param listener - Callback receiving the server name and its new status.
 */
export function addMCPStatusChangeListener(
  listener: StatusChangeListener,
): void {
  const subscribers = statusChangeListeners;
  subscribers.push(listener);
}
| |
|
| | |
| | |
| | |
/**
 * Unsubscribes `listener`.
 *
 * Only the first matching registration is removed. Passing a listener that
 * was never registered is a no-op.
 *
 * @param listener - The exact function previously passed to
 *   `addMCPStatusChangeListener`.
 */
export function removeMCPStatusChangeListener(
  listener: StatusChangeListener,
): void {
  const position = statusChangeListeners.indexOf(listener);
  if (position < 0) {
    return;
  }
  statusChangeListeners.splice(position, 1);
}
| |
|
| | |
| | |
| | |
/**
 * Records `status` as the current status of `serverName` and notifies every
 * registered status-change listener, in registration order.
 *
 * Listeners are called over a snapshot of the listener list. A listener that
 * unsubscribes itself (or subscribes another) during notification therefore
 * cannot cause a later listener to be skipped. Iterating the live array with
 * `splice` happening underneath would skip the next element.
 *
 * @param serverName - Config key of the MCP server.
 * @param status - The status to record.
 */
function updateMCPServerStatus(
  serverName: string,
  status: MCPServerStatus,
): void {
  mcpServerStatusesInternal.set(serverName, status);

  const listenersSnapshot = [...statusChangeListeners];
  for (const listener of listenersSnapshot) {
    listener(serverName, status);
  }
}
| |
|
| | |
| | |
| | |
/**
 * Looks up the last recorded status of an MCP server.
 *
 * @param serverName - Config key of the MCP server.
 * @returns The recorded status, or DISCONNECTED if none has been recorded.
 */
export function getMCPServerStatus(serverName: string): MCPServerStatus {
  const recorded = mcpServerStatusesInternal.get(serverName);
  return recorded ? recorded : MCPServerStatus.DISCONNECTED;
}
| |
|
| | |
| | |
| | |
/**
 * Returns a copy of every recorded server status. Mutating the returned map
 * does not affect this module's internal state.
 */
export function getAllMCPServerStatuses(): Map<string, MCPServerStatus> {
  const snapshot = new Map<string, MCPServerStatus>();
  for (const [serverName, status] of mcpServerStatusesInternal) {
    snapshot.set(serverName, status);
  }
  return snapshot;
}
| |
|
| | |
| | |
| | |
/**
 * Reports how far discovery has progressed: NOT_STARTED until
 * `discoverMcpTools` is first called, IN_PROGRESS while it runs, and
 * COMPLETED afterwards.
 */
export function getMCPDiscoveryState(): MCPDiscoveryState {
  const currentState = mcpDiscoveryState;
  return currentState;
}
| |
|
/**
 * Connects to every configured MCP server in parallel and registers the tools
 * each one exposes in `toolRegistry`.
 *
 * When `mcpServerCommand` is set, it is shell-parsed (with `process.env`
 * supplied for variable expansion) and added as a stdio server under the key
 * `'mcp'`. Note that this mutates the caller's `mcpServers` object.
 *
 * The module discovery state is IN_PROGRESS while this runs. It becomes
 * COMPLETED only after every server attempt has settled, whether or not any
 * of them failed.
 *
 * @param mcpServers - Server configs keyed by server name.
 * @param mcpServerCommand - Optional command line for an extra stdio server.
 * @param toolRegistry - Registry that receives the discovered tools.
 * @throws Error if `mcpServerCommand` does not parse into a plain, non-empty
 *   argument list; otherwise rethrows the first error raised by a per-server
 *   discovery, after all servers have settled.
 */
export async function discoverMcpTools(
  mcpServers: Record<string, MCPServerConfig>,
  mcpServerCommand: string | undefined,
  toolRegistry: ToolRegistry,
): Promise<void> {
  mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;

  try {
    if (mcpServerCommand) {
      const cmd = mcpServerCommand;
      const args = parse(cmd, process.env);
      // shell-quote can yield non-string entries (e.g. for shell operators),
      // and a whitespace-only command yields no entries at all. Neither can be
      // launched as a stdio server.
      if (args.length === 0 || args.some((arg) => typeof arg !== 'string')) {
        throw new Error('failed to parse mcpServerCommand: ' + cmd);
      }
      const [command, ...commandArgs] = args as string[];
      // Use the generic server name 'mcp'.
      mcpServers['mcp'] = {
        command,
        args: commandArgs,
      };
    }

    // allSettled (not all) so that one failing server cannot flip the state
    // to COMPLETED while the others are still connecting.
    const results = await Promise.allSettled(
      Object.entries(mcpServers).map(([mcpServerName, mcpServerConfig]) =>
        connectAndDiscover(mcpServerName, mcpServerConfig, toolRegistry),
      ),
    );
    const firstFailure = results.find(
      (result): result is PromiseRejectedResult => result.status === 'rejected',
    );
    if (firstFailure) {
      throw firstFailure.reason;
    }
  } finally {
    mcpDiscoveryState = MCPDiscoveryState.COMPLETED;
  }
}
| |
|
/**
 * Connects to a single MCP server, lists its tools, and registers each one in
 * `toolRegistry` as a `DiscoveredMCPTool`.
 *
 * Status transitions:
 * - CONNECTING on entry.
 * - CONNECTED once the client `connect()` succeeds.
 * - DISCONNECTED if any of the following happens:
 *   - the config is unusable;
 *   - connecting fails;
 *   - tool listing or registration fails;
 *   - the server yields no registrable tools.
 *
 * In the last two cases the transport is closed. Connection and discovery
 * failures are logged rather than thrown.
 *
 * @param mcpServerName - Config key of the server. Used in log messages and
 *   as the prefix for tool names that collide with an already-registered tool.
 * @param mcpServerConfig - Transport settings plus optional `timeout` and `trust`.
 * @param toolRegistry - Registry that receives the discovered tools.
 */
async function connectAndDiscover(
  mcpServerName: string,
  mcpServerConfig: MCPServerConfig,
  toolRegistry: ToolRegistry,
): Promise<void> {
  updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);

  const transport = createMCPTransport(mcpServerName, mcpServerConfig);
  if (!transport) {
    updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
    return;
  }

  const mcpClient = new Client({
    name: 'gemini-cli-mcp-client',
    version: '0.0.1',
  });

  // Force the configured (or default) timeout onto every tool call made
  // through this client, overriding any timeout supplied by the caller.
  if ('callTool' in mcpClient) {
    const origCallTool = mcpClient.callTool.bind(mcpClient);
    mcpClient.callTool = function (params, resultSchema, options) {
      return origCallTool(params, resultSchema, {
        ...options,
        timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
      });
    };
  }

  try {
    await mcpClient.connect(transport, {
      timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
    });
    updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED);
  } catch (error) {
    // Only these fields are logged. `env` and `args` are left out
    // (presumably because they may carry secrets).
    const safeConfig = {
      command: mcpServerConfig.command,
      url: mcpServerConfig.url,
      httpUrl: mcpServerConfig.httpUrl,
      cwd: mcpServerConfig.cwd,
      timeout: mcpServerConfig.timeout,
      trust: mcpServerConfig.trust,
    };

    let errorString =
      `failed to start or connect to MCP server '${mcpServerName}' ` +
      `${JSON.stringify(safeConfig)}; \n${error}`;
    if (process.env.SANDBOX) {
      errorString += `\nMake sure it is available in the sandbox`;
    }
    console.error(errorString);
    updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
    return;
  }

  mcpClient.onerror = (error) => {
    console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
    updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
  };

  if (transport instanceof StdioClientTransport && transport.stderr) {
    transport.stderr.on('data', (data) => {
      const stderrStr = data.toString();
      // Skip stderr chunks containing '] INFO' (presumably routine server
      // logging); everything else goes to the debug log.
      if (!stderrStr.includes('] INFO')) {
        console.debug(`MCP STDERR (${mcpServerName}):`, stderrStr);
      }
    });
  }

  try {
    const mcpCallableTool: CallableTool = mcpToTool(mcpClient);
    const discoveredToolFunctions = await mcpCallableTool.tool();

    if (
      !discoveredToolFunctions ||
      !Array.isArray(discoveredToolFunctions.functionDeclarations)
    ) {
      console.error(
        `MCP server '${mcpServerName}' did not return valid tool function declarations. Skipping.`,
      );
      await closeMCPTransport(mcpServerName, transport);
      return;
    }

    for (const funcDecl of discoveredToolFunctions.functionDeclarations) {
      if (!funcDecl.name) {
        console.warn(
          `Discovered a function declaration without a name from MCP server '${mcpServerName}'. Skipping.`,
        );
        continue;
      }

      const toolNameForModel = toModelToolName(
        funcDecl.name,
        mcpServerName,
        toolRegistry,
      );

      // Mutates funcDecl.parameters in place.
      sanatizeParameters(funcDecl.parameters);

      // Shallow-copy the schema; fall back to an empty object schema when the
      // server declared no parameters.
      const parameterSchema: Record<string, unknown> =
        funcDecl.parameters && typeof funcDecl.parameters === 'object'
          ? { ...(funcDecl.parameters as FunctionDeclaration) }
          : { type: 'object', properties: {} };

      toolRegistry.registerTool(
        new DiscoveredMCPTool(
          mcpCallableTool,
          mcpServerName,
          toolNameForModel,
          funcDecl.description ?? '',
          parameterSchema,
          funcDecl.name,
          mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
          mcpServerConfig.trust,
        ),
      );
    }
  } catch (error) {
    console.error(
      `Failed to list or register tools for MCP server '${mcpServerName}': ${error}`,
    );
    // NOTE(review): tools registered before the failure remain in the
    // registry even though their transport is closed here.
    await closeMCPTransport(mcpServerName, transport);
    // Already closed; do not fall through to the "no tools" close below.
    return;
  }

  if (toolRegistry.getToolsByServer(mcpServerName).length === 0) {
    console.log(
      `No tools registered from MCP server '${mcpServerName}'. Closing connection.`,
    );
    await closeMCPTransport(mcpServerName, transport);
  }
}

/** The transport kinds `connectAndDiscover` can construct. */
type MCPTransport =
  | StdioClientTransport
  | SSEClientTransport
  | StreamableHTTPClientTransport;

/**
 * Builds the transport selected by `mcpServerConfig`. It tries `httpUrl`
 * (Streamable HTTP) first, then `url` (SSE), then `command` (stdio).
 *
 * Logs the problem and returns `undefined` in two cases: none of the three
 * fields is set, or construction throws (e.g. `new URL` rejecting a
 * malformed URL).
 */
function createMCPTransport(
  mcpServerName: string,
  mcpServerConfig: MCPServerConfig,
): MCPTransport | undefined {
  try {
    if (mcpServerConfig.httpUrl) {
      return new StreamableHTTPClientTransport(
        new URL(mcpServerConfig.httpUrl),
      );
    }
    if (mcpServerConfig.url) {
      return new SSEClientTransport(new URL(mcpServerConfig.url));
    }
    if (mcpServerConfig.command) {
      return new StdioClientTransport({
        command: mcpServerConfig.command,
        args: mcpServerConfig.args || [],
        env: {
          ...process.env,
          ...(mcpServerConfig.env || {}),
        } as Record<string, string>,
        cwd: mcpServerConfig.cwd,
        stderr: 'pipe',
      });
    }
  } catch (error) {
    console.error(
      `MCP server '${mcpServerName}' has invalid configuration: ${error}. Skipping.`,
    );
    return undefined;
  }
  console.error(
    `MCP server '${mcpServerName}' has invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio). Skipping.`,
  );
  return undefined;
}

/**
 * Closes `transport` and marks the server DISCONNECTED. A failing `close()` is
 * logged instead of thrown, so it cannot reject the shared discovery pass in
 * `discoverMcpTools`.
 */
async function closeMCPTransport(
  mcpServerName: string,
  transport: MCPTransport,
): Promise<void> {
  try {
    await transport.close();
  } catch (error) {
    console.error(
      `Failed to close transport for MCP server '${mcpServerName}': ${error}`,
    );
  }
  updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
}

/**
 * Derives the name under which an MCP tool is exposed to the model.
 *
 * 1. Characters outside `[a-zA-Z0-9_.-]` become '_'.
 * 2. If that name is already registered, it is prefixed with
 *    `<mcpServerName>__`. The prefixed form is sanitized as well, because
 *    server names come from user config and may contain any character.
 * 3. Names longer than 63 characters keep their first 28 and last 32
 *    characters joined by '___'.
 */
function toModelToolName(
  toolName: string,
  mcpServerName: string,
  toolRegistry: ToolRegistry,
): string {
  const sanitize = (name: string): string =>
    name.replace(/[^a-zA-Z0-9_.-]/g, '_');

  let name = sanitize(toolName);
  if (toolRegistry.getTool(name)) {
    name = sanitize(mcpServerName + '__' + name);
  }

  // NOTE(review): 63 appears to be the effective tool-name length limit of the
  // model API. 28 + '___' + 32 yields exactly 63 characters.
  if (name.length > 63) {
    name = name.slice(0, 28) + '___' + name.slice(-32);
  }
  return name;
}
| |
|
/**
 * Recursively walks a parameter schema in place and sets `default` to
 * `undefined` on every (sub)schema that uses `anyOf`.
 *
 * The walk descends into `anyOf` variants, `items`, and every value of
 * `properties`. Schemas without `anyOf` keep their `default`.
 * NOTE(review): presumably the model API rejects `default` combined with
 * `anyOf`; confirm.
 *
 * @param schema - Schema to sanitize. Falsy values are ignored.
 */
export function sanatizeParameters(schema?: Schema): void {
  if (!schema) {
    return;
  }

  if (schema.anyOf) {
    schema.default = undefined;
    schema.anyOf.forEach((variant) => sanatizeParameters(variant));
  }

  if (schema.items) {
    sanatizeParameters(schema.items);
  }

  if (schema.properties) {
    Object.values(schema.properties).forEach((child) =>
      sanatizeParameters(child),
    );
  }
}
| |
|