| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | import { |
| | describe, |
| | it, |
| | expect, |
| | vi, |
| | beforeEach, |
| | afterEach, |
| | Mocked, |
| | } from 'vitest'; |
| | import { ToolRegistry, DiscoveredTool } from './tool-registry.js'; |
| | import { DiscoveredMCPTool } from './mcp-tool.js'; |
| | import { |
| | Config, |
| | ConfigParameters, |
| | MCPServerConfig, |
| | ApprovalMode, |
| | } from '../config/config.js'; |
| | import { BaseTool, ToolResult } from './tools.js'; |
| | import { |
| | FunctionDeclaration, |
| | CallableTool, |
| | mcpToTool, |
| | Type, |
| | } from '@google/genai'; |
| | import { execSync } from 'node:child_process'; |
| |
|
| | |
| | const mockDiscoverMcpTools = vi.hoisted(() => vi.fn()); |
| |
|
| | |
| | vi.mock('./mcp-client.js', () => ({ |
| | discoverMcpTools: mockDiscoverMcpTools, |
| | })); |
| |
|
| | |
| | vi.mock('node:child_process', async () => { |
| | const actual = await vi.importActual('node:child_process'); |
| | return { |
| | ...actual, |
| | execSync: vi.fn(), |
| | spawn: vi.fn(), |
| | }; |
| | }); |
| |
|
| | |
| | const mockMcpClientConnect = vi.fn(); |
| | const mockMcpClientOnError = vi.fn(); |
| | const mockStdioTransportClose = vi.fn(); |
| | const mockSseTransportClose = vi.fn(); |
| |
|
| | vi.mock('@modelcontextprotocol/sdk/client/index.js', () => { |
| | const MockClient = vi.fn().mockImplementation(() => ({ |
| | connect: mockMcpClientConnect, |
| | set onerror(handler: any) { |
| | mockMcpClientOnError(handler); |
| | }, |
| | |
| | })); |
| | return { Client: MockClient }; |
| | }); |
| |
|
| | vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => { |
| | const MockStdioClientTransport = vi.fn().mockImplementation(() => ({ |
| | stderr: { |
| | on: vi.fn(), |
| | }, |
| | close: mockStdioTransportClose, |
| | })); |
| | return { StdioClientTransport: MockStdioClientTransport }; |
| | }); |
| |
|
| | vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => { |
| | const MockSSEClientTransport = vi.fn().mockImplementation(() => ({ |
| | close: mockSseTransportClose, |
| | })); |
| | return { SSEClientTransport: MockSSEClientTransport }; |
| | }); |
| |
|
| | |
| | vi.mock('@google/genai', async () => { |
| | const actualGenai = |
| | await vi.importActual<typeof import('@google/genai')>('@google/genai'); |
| | return { |
| | ...actualGenai, |
| | mcpToTool: vi.fn().mockImplementation(() => ({ |
| | |
| | tool: vi.fn().mockResolvedValue({ functionDeclarations: [] }), |
| | callTool: vi.fn(), |
| | })), |
| | }; |
| | }); |
| |
|
| | |
| | const createMockCallableTool = ( |
| | toolDeclarations: FunctionDeclaration[], |
| | ): Mocked<CallableTool> => ({ |
| | tool: vi.fn().mockResolvedValue({ functionDeclarations: toolDeclarations }), |
| | callTool: vi.fn(), |
| | }); |
| |
|
| | class MockTool extends BaseTool<{ param: string }, ToolResult> { |
| | constructor(name = 'mock-tool', description = 'A mock tool') { |
| | super(name, name, description, { |
| | type: 'object', |
| | properties: { |
| | param: { type: 'string' }, |
| | }, |
| | required: ['param'], |
| | }); |
| | } |
| | async execute(params: { param: string }): Promise<ToolResult> { |
| | return { |
| | llmContent: `Executed with ${params.param}`, |
| | returnDisplay: `Executed with ${params.param}`, |
| | }; |
| | } |
| | } |
| |
|
| | const baseConfigParams: ConfigParameters = { |
| | cwd: '/tmp', |
| | model: 'test-model', |
| | embeddingModel: 'test-embedding-model', |
| | sandbox: undefined, |
| | targetDir: '/test/dir', |
| | debugMode: false, |
| | userMemory: '', |
| | geminiMdFileCount: 0, |
| | approvalMode: ApprovalMode.DEFAULT, |
| | sessionId: 'test-session-id', |
| | }; |
| |
|
| | describe('ToolRegistry', () => { |
| | let config: Config; |
| | let toolRegistry: ToolRegistry; |
| |
|
| | beforeEach(() => { |
| | config = new Config(baseConfigParams); |
| | toolRegistry = new ToolRegistry(config); |
| | vi.spyOn(console, 'warn').mockImplementation(() => {}); |
| | vi.spyOn(console, 'error').mockImplementation(() => {}); |
| | vi.spyOn(console, 'debug').mockImplementation(() => {}); |
| | vi.spyOn(console, 'log').mockImplementation(() => {}); |
| |
|
| | |
| | mockMcpClientConnect.mockReset().mockResolvedValue(undefined); |
| | mockStdioTransportClose.mockReset(); |
| | mockSseTransportClose.mockReset(); |
| | vi.mocked(mcpToTool).mockClear(); |
| | |
| | vi.mocked(mcpToTool).mockReturnValue(createMockCallableTool([])); |
| | }); |
| |
|
| | afterEach(() => { |
| | vi.restoreAllMocks(); |
| | }); |
| |
|
| | describe('registerTool', () => { |
| | it('should register a new tool', () => { |
| | const tool = new MockTool(); |
| | toolRegistry.registerTool(tool); |
| | expect(toolRegistry.getTool('mock-tool')).toBe(tool); |
| | }); |
| | |
| | }); |
| |
|
| | describe('getToolsByServer', () => { |
| | it('should return an empty array if no tools match the server name', () => { |
| | toolRegistry.registerTool(new MockTool()); |
| | expect(toolRegistry.getToolsByServer('any-mcp-server')).toEqual([]); |
| | }); |
| |
|
| | it('should return only tools matching the server name', async () => { |
| | const server1Name = 'mcp-server-uno'; |
| | const server2Name = 'mcp-server-dos'; |
| |
|
| | |
| | const mockCallable = {} as CallableTool; |
| | const mcpTool1 = new DiscoveredMCPTool( |
| | mockCallable, |
| | server1Name, |
| | 'server1Name__tool-on-server1', |
| | 'd1', |
| | {}, |
| | 'tool-on-server1', |
| | ); |
| | const mcpTool2 = new DiscoveredMCPTool( |
| | mockCallable, |
| | server2Name, |
| | 'server2Name__tool-on-server2', |
| | 'd2', |
| | {}, |
| | 'tool-on-server2', |
| | ); |
| | const nonMcpTool = new MockTool('regular-tool'); |
| |
|
| | toolRegistry.registerTool(mcpTool1); |
| | toolRegistry.registerTool(mcpTool2); |
| | toolRegistry.registerTool(nonMcpTool); |
| |
|
| | const toolsFromServer1 = toolRegistry.getToolsByServer(server1Name); |
| | expect(toolsFromServer1).toHaveLength(1); |
| | expect(toolsFromServer1[0].name).toBe(mcpTool1.name); |
| | expect((toolsFromServer1[0] as DiscoveredMCPTool).serverName).toBe( |
| | server1Name, |
| | ); |
| |
|
| | const toolsFromServer2 = toolRegistry.getToolsByServer(server2Name); |
| | expect(toolsFromServer2).toHaveLength(1); |
| | expect(toolsFromServer2[0].name).toBe(mcpTool2.name); |
| | expect((toolsFromServer2[0] as DiscoveredMCPTool).serverName).toBe( |
| | server2Name, |
| | ); |
| |
|
| | expect(toolRegistry.getToolsByServer('non-existent-server')).toEqual([]); |
| | }); |
| | }); |
| |
|
| | describe('discoverTools', () => { |
| | let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>; |
| | let mockConfigGetMcpServers: ReturnType<typeof vi.spyOn>; |
| | let mockConfigGetMcpServerCommand: ReturnType<typeof vi.spyOn>; |
| | let mockExecSync: ReturnType<typeof vi.mocked<typeof execSync>>; |
| |
|
| | beforeEach(() => { |
| | mockConfigGetToolDiscoveryCommand = vi.spyOn( |
| | config, |
| | 'getToolDiscoveryCommand', |
| | ); |
| | mockConfigGetMcpServers = vi.spyOn(config, 'getMcpServers'); |
| | mockConfigGetMcpServerCommand = vi.spyOn(config, 'getMcpServerCommand'); |
| | mockExecSync = vi.mocked(execSync); |
| | toolRegistry = new ToolRegistry(config); |
| | |
| | mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined); |
| | }); |
| |
|
| | it('should discover tools using discovery command', async () => { |
| | |
| | const discoveryCommand = 'my-discovery-command'; |
| | mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand); |
| | const mockToolDeclarations: FunctionDeclaration[] = [ |
| | { |
| | name: 'discovered-tool-1', |
| | description: 'A discovered tool', |
| | parameters: { type: Type.OBJECT, properties: {} }, |
| | }, |
| | ]; |
| | mockExecSync.mockReturnValue( |
| | Buffer.from( |
| | JSON.stringify([{ function_declarations: mockToolDeclarations }]), |
| | ), |
| | ); |
| | await toolRegistry.discoverTools(); |
| | expect(execSync).toHaveBeenCalledWith(discoveryCommand); |
| | const discoveredTool = toolRegistry.getTool('discovered-tool-1'); |
| | expect(discoveredTool).toBeInstanceOf(DiscoveredTool); |
| | }); |
| |
|
| | it('should discover tools using MCP servers defined in getMcpServers', async () => { |
| | mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined); |
| | mockConfigGetMcpServerCommand.mockReturnValue(undefined); |
| | const mcpServerConfigVal = { |
| | 'my-mcp-server': { |
| | command: 'mcp-server-cmd', |
| | args: ['--port', '1234'], |
| | trust: true, |
| | } as MCPServerConfig, |
| | }; |
| | mockConfigGetMcpServers.mockReturnValue(mcpServerConfigVal); |
| |
|
| | await toolRegistry.discoverTools(); |
| |
|
| | expect(mockDiscoverMcpTools).toHaveBeenCalledWith( |
| | mcpServerConfigVal, |
| | undefined, |
| | toolRegistry, |
| | ); |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | }); |
| |
|
| | it('should discover tools using MCP server command from getMcpServerCommand', async () => { |
| | mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined); |
| | mockConfigGetMcpServers.mockReturnValue({}); |
| | mockConfigGetMcpServerCommand.mockReturnValue( |
| | 'mcp-server-start-command --param', |
| | ); |
| |
|
| | await toolRegistry.discoverTools(); |
| | expect(mockDiscoverMcpTools).toHaveBeenCalledWith( |
| | {}, |
| | 'mcp-server-start-command --param', |
| | toolRegistry, |
| | ); |
| | }); |
| |
|
| | it('should handle errors during MCP client connection gracefully and close transport', async () => { |
| | mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined); |
| | mockConfigGetMcpServers.mockReturnValue({ |
| | 'failing-mcp': { command: 'fail-cmd' } as MCPServerConfig, |
| | }); |
| |
|
| | mockMcpClientConnect.mockRejectedValue(new Error('Connection failed')); |
| |
|
| | await toolRegistry.discoverTools(); |
| | expect(mockDiscoverMcpTools).toHaveBeenCalledWith( |
| | { |
| | 'failing-mcp': { command: 'fail-cmd' }, |
| | }, |
| | undefined, |
| | toolRegistry, |
| | ); |
| | expect(toolRegistry.getAllTools()).toHaveLength(0); |
| | }); |
| | }); |
| | |
| | |
| | }); |
| |
|