| import { TokenMethods } from '@librechat/data-schemas'; |
| import { FlowStateManager, MCPConnection, MCPOAuthTokens, MCPOptions } from '../..'; |
| import { MCPManager } from '../MCPManager'; |
| import { mcpServersRegistry } from '../../mcp/registry/MCPServersRegistry'; |
| import { OAuthReconnectionManager } from './OAuthReconnectionManager'; |
| import { OAuthReconnectionTracker } from './OAuthReconnectionTracker'; |
|
|
| // Swap the data-schemas module for a stub whose logger exposes jest.fn() for info, warn, error and debug. |
| jest.mock('@librechat/data-schemas', () => { |
| const logger = { |
| info: jest.fn(), |
| warn: jest.fn(), |
| error: jest.fn(), |
| debug: jest.fn(), |
| }; |
| return { logger }; |
| }); |
|
|
| // MCPManager is mocked without a factory; the registry module is replaced by a stub with two jest.fn() lookups. |
| jest.mock('../MCPManager'); |
| jest.mock('../../mcp/registry/MCPServersRegistry', () => { |
| const registryStub = { |
| getServerConfig: jest.fn(), |
| getOAuthServers: jest.fn(), |
| }; |
| return { mcpServersRegistry: registryStub }; |
| }); |
|
|
| describe('OAuthReconnectionManager', () => { |
| let flowManager: jest.Mocked<FlowStateManager<null>>; |
| let tokenMethods: jest.Mocked<TokenMethods>; |
| let mockMCPManager: jest.Mocked<MCPManager>; |
| let reconnectionManager: OAuthReconnectionManager; |
|
|
| beforeEach(() => { |
| jest.clearAllMocks(); |
|
|
| // Reset the singleton between tests; createInstance rejects when an instance already exists. |
| // Use a narrow structural cast (same form as the 'MCPManager not available' test) instead of `any`. |
| (OAuthReconnectionManager as unknown as { instance: null }).instance = null; |
|
|
| // FlowStateManager stand-in: only these five methods exist, each a bare jest.fn(). |
| flowManager = { |
| createFlow: jest.fn(), |
| completeFlow: jest.fn(), |
| failFlow: jest.fn(), |
| deleteFlow: jest.fn(), |
| getFlow: jest.fn(), |
| } as unknown as jest.Mocked<FlowStateManager<null>>; |
|
|
| // TokenMethods stand-in: find/create/update/delete are bare jest.fn(); tests configure findToken as needed. |
| tokenMethods = { |
| findToken: jest.fn(), |
| createToken: jest.fn(), |
| updateToken: jest.fn(), |
| deleteToken: jest.fn(), |
| } as unknown as jest.Mocked<TokenMethods>; |
|
|
| // MCPManager stand-in exposing only the connection-related methods the tests configure or assert on. |
| mockMCPManager = { |
| getOAuthServers: jest.fn(), |
| getUserConnection: jest.fn(), |
| getUserConnections: jest.fn(), |
| disconnectUserConnection: jest.fn(), |
| } as unknown as jest.Mocked<MCPManager>; |
|
|
| // Defaults: MCPManager.getInstance() returns the stand-in; getServerConfig resolves to an empty object. |
| (MCPManager.getInstance as jest.Mock).mockReturnValue(mockMCPManager); |
| (mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue({}); |
| }); |
|
|
| // Clear recorded mock calls once each test has finished. |
| afterEach(function clearMocksAfterTest() { |
| jest.clearAllMocks(); |
| }); |
|
|
| describe('Singleton Pattern', () => { |
| // Shared shorthand for the two-argument factory call used by the cases below. |
| const create = () => OAuthReconnectionManager.createInstance(flowManager, tokenMethods); |
|
| it('should create instance successfully', async () => { |
| const created = await create(); |
| expect(created).toBeInstanceOf(OAuthReconnectionManager); |
| }); |
|
| it('should throw error when creating instance twice', async () => { |
| await create(); |
| // A second call while an instance exists must reject with the "already initialized" message. |
| const secondAttempt = create(); |
| await expect(secondAttempt).rejects.toThrow('OAuthReconnectionManager already initialized'); |
| }); |
|
| it('should throw error when getting instance before creation', () => { |
| const readInstance = () => OAuthReconnectionManager.getInstance(); |
| expect(readInstance).toThrow('OAuthReconnectionManager not initialized'); |
| }); |
| }); |
|
|
| describe('isReconnecting', () => { |
| let reconnectionTracker: OAuthReconnectionTracker; |
| // Both cases use the same user/server pair. |
| const userId = 'user-123'; |
| const serverName = 'test-server'; |
|
| beforeEach(async () => { |
| // Inject a fresh tracker so tests can drive its state directly. |
| const tracker = new OAuthReconnectionTracker(); |
| reconnectionTracker = tracker; |
| reconnectionManager = await OAuthReconnectionManager.createInstance(flowManager, tokenMethods, tracker); |
| }); |
|
| it('should return true when server is actively reconnecting', () => { |
| const beforeMarking = reconnectionManager.isReconnecting(userId, serverName); |
| expect(beforeMarking).toBe(false); |
|
| reconnectionTracker.setActive(userId, serverName); |
| expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(true); |
| }); |
|
| it('should return false when server is not reconnecting', () => { |
| expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(false); |
| }); |
| }); |
|
|
| describe('clearReconnection', () => { |
| let reconnectionTracker: OAuthReconnectionTracker; |
|
| beforeEach(async () => { |
| // Inject a fresh tracker so the test can seed and inspect its state. |
| const tracker = new OAuthReconnectionTracker(); |
| reconnectionTracker = tracker; |
| reconnectionManager = await OAuthReconnectionManager.createInstance(flowManager, tokenMethods, tracker); |
| }); |
|
| it('should clear both failed and active reconnection states', () => { |
| const uid = 'user-123'; |
| const server = 'test-server'; |
|
| // Seed both flags, then clear them through the manager. |
| reconnectionTracker.setFailed(uid, server); |
| reconnectionTracker.setActive(uid, server); |
| reconnectionManager.clearReconnection(uid, server); |
|
| const stillReconnecting = reconnectionManager.isReconnecting(uid, server); |
| expect(stillReconnecting).toBe(false); |
| const stillFailed = reconnectionTracker.isFailed(uid, server); |
| expect(stillFailed).toBe(false); |
| const stillActive = reconnectionTracker.isActive(uid, server); |
| expect(stillActive).toBe(false); |
| }); |
| }); |
|
|
| describe('reconnectServers', () => { |
| let reconnectionTracker: OAuthReconnectionTracker; |
| // Each test gets a new tracker injected into a freshly created manager. |
| beforeEach(async () => { |
| const tracker = new OAuthReconnectionTracker(); |
| reconnectionTracker = tracker; |
| reconnectionManager = await OAuthReconnectionManager.createInstance(flowManager, tokenMethods, tracker); |
| }); |
|
|
| it('should reconnect eligible servers', async () => { |
| const userId = 'user-123'; |
| const oauthServers = new Set(['server1', 'server2', 'server3']); |
| (mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); |
|
|
| // Scenario: three OAuth servers are registered. server1 is flagged as failed in the tracker before the run. |
| reconnectionTracker.setFailed(userId, 'server1'); |
|
|
| // server2 is already in the user's connection map and its isConnected() resolves to true. |
| const mockConnection = { |
| isConnected: jest.fn().mockResolvedValue(true), |
| }; |
| const userConnections = new Map([['server2', mockConnection]]); |
| mockMCPManager.getUserConnections.mockReturnValue( |
| userConnections as unknown as Map<string, MCPConnection>, |
| ); |
|
|
| // Only the 'mcp:server3' lookup yields a token (expiring one hour from now); every other lookup resolves to null. |
| tokenMethods.findToken.mockImplementation(async ({ identifier }) => { |
| if (identifier === 'mcp:server3') { |
| return { |
| userId, |
| identifier, |
| expiresAt: new Date(Date.now() + 3600000), |
| } as unknown as MCPOAuthTokens; |
| } |
| return null; |
| }); |
|
|
| // getUserConnection resolves to a connection that reports connected; server config carries initTimeout 5000. |
| const mockNewConnection = { |
| isConnected: jest.fn().mockResolvedValue(true), |
| disconnect: jest.fn(), |
| }; |
| mockMCPManager.getUserConnection.mockResolvedValue( |
| mockNewConnection as unknown as MCPConnection, |
| ); |
| (mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue({ |
| initTimeout: 5000, |
| } as unknown as MCPOptions); |
|
|
| await reconnectionManager.reconnectServers(userId); |
|
|
| // Right after reconnectServers resolves, server3 is flagged active in the tracker. |
| expect(reconnectionTracker.isActive(userId, 'server3')).toBe(true); |
|
|
| // Wait 100 ms of real time; presumably the connection attempt continues after reconnectServers returns. |
| await new Promise((resolve) => setTimeout(resolve, 100)); |
|
|
| // A connection was requested for server3, with connectionTimeout equal to the mocked initTimeout. |
| expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith({ |
| serverName: 'server3', |
| user: { id: userId }, |
| flowManager, |
| tokenMethods, |
| forceNew: false, |
| connectionTimeout: 5000, |
| returnOnOAuth: true, |
| }); |
|
|
| // Once settled, server3 is neither failed nor active. |
| expect(reconnectionTracker.isFailed(userId, 'server3')).toBe(false); |
| expect(reconnectionTracker.isActive(userId, 'server3')).toBe(false); |
| }); |
|
|
| it('should handle failed reconnection attempts', async () => { |
| const userId = 'user-123'; |
| const oauthServers = new Set(['server1']); |
| (mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); |
|
|
| // Every token lookup resolves to a server1 token that expires one hour from now. |
| tokenMethods.findToken.mockResolvedValue({ |
| userId, |
| identifier: 'mcp:server1', |
| expiresAt: new Date(Date.now() + 3600000), |
| } as unknown as MCPOAuthTokens); |
|
|
| // getUserConnection rejects with 'Connection failed'; the server config resolves to an empty object. |
| mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed')); |
| (mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue( |
| {} as unknown as MCPOptions, |
| ); |
|
|
| await reconnectionManager.reconnectServers(userId); |
|
|
| // Wait 100 ms of real time; presumably the failing attempt settles after reconnectServers returns. |
| await new Promise((resolve) => setTimeout(resolve, 100)); |
|
|
| // The failure is recorded: server1 is failed, not active, and its user connection was disconnected. |
| expect(reconnectionTracker.isFailed(userId, 'server1')).toBe(true); |
| expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false); |
| expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1'); |
| }); |
|
|
| it('should not reconnect servers with expired tokens', async () => { |
| const uid = 'user-123'; |
| const expiredServer = 'server1'; |
| const oneHourMs = 3600000; |
| (mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(new Set([expiredServer])); |
|
| // The only token on record lapsed one hour ago. |
| const expiredToken = { |
| userId: uid, |
| identifier: 'mcp:server1', |
| expiresAt: new Date(Date.now() - oneHourMs), |
| } as unknown as MCPOAuthTokens; |
| tokenMethods.findToken.mockResolvedValue(expiredToken); |
|
| await reconnectionManager.reconnectServers(uid); |
|
| // The server is not flagged active and no connection was requested. |
| const flaggedActive = reconnectionTracker.isActive(uid, expiredServer); |
| expect(flaggedActive).toBe(false); |
| expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled(); |
| }); |
|
|
| it('should handle connection that returns but is not connected', async () => { |
| const userId = 'user-123'; |
| const oauthServers = new Set(['server1']); |
| (mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); |
|
|
| tokenMethods.findToken.mockResolvedValue({ |
| userId, |
| identifier: 'mcp:server1', |
| expiresAt: new Date(Date.now() + 3600000), |
| } as unknown as MCPOAuthTokens); |
|
|
| // getUserConnection resolves, but the returned connection's isConnected() resolves to false. |
| const mockConnection = { |
| isConnected: jest.fn().mockResolvedValue(false), |
| disconnect: jest.fn(), |
| }; |
| mockMCPManager.getUserConnection.mockResolvedValue( |
| mockConnection as unknown as MCPConnection, |
| ); |
| (mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue( |
| {} as unknown as MCPOptions, |
| ); |
|
|
| await reconnectionManager.reconnectServers(userId); |
|
|
| // Wait 100 ms of real time; presumably the attempt settles after reconnectServers returns. |
| await new Promise((resolve) => setTimeout(resolve, 100)); |
|
|
| // The unconnected connection is disconnected, and server1 ends up failed, not active, with its user connection removed. |
| expect(mockConnection.disconnect).toHaveBeenCalled(); |
| expect(reconnectionTracker.isFailed(userId, 'server1')).toBe(true); |
| expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false); |
| expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1'); |
| }); |
|
|
| it('should handle MCPManager not available gracefully', async () => { |
| const userId = 'user-123'; |
|
|
| // Drop the instance created by this suite's beforeEach so a new one can be built below. |
| (OAuthReconnectionManager as unknown as { instance: null }).instance = null; |
|
|
| // From here on, MCPManager.getInstance() throws, as if the manager had not been initialized. |
| (MCPManager.getInstance as jest.Mock).mockImplementation(() => { |
| throw new Error('MCPManager has not been initialized.'); |
| }); |
|
|
| // Build a manager while MCPManager.getInstance() is throwing. |
| const reconnectionTracker = new OAuthReconnectionTracker(); |
| const reconnectionManagerWithoutMCP = await OAuthReconnectionManager.createInstance( |
| flowManager, |
| tokenMethods, |
| reconnectionTracker, |
| ); |
|
|
| // reconnectServers must resolve to undefined instead of rejecting. |
| await expect(reconnectionManagerWithoutMCP.reconnectServers(userId)).resolves.toBeUndefined(); |
|
|
| // No token lookup, connection request, or disconnect happened. |
| expect(tokenMethods.findToken).not.toHaveBeenCalled(); |
| expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled(); |
| expect(mockMCPManager.disconnectUserConnection).not.toHaveBeenCalled(); |
| }); |
| }); |
|
|
| describe('reconnection timeout behavior', () => { |
| let reconnectionTracker: OAuthReconnectionTracker; |
|
|
| // Fake timers are installed first, then a fresh tracker is injected into a newly created manager. |
| beforeEach(async () => { |
| jest.useFakeTimers(); |
| const tracker = new OAuthReconnectionTracker(); |
| reconnectionTracker = tracker; |
| reconnectionManager = await OAuthReconnectionManager.createInstance(flowManager, tokenMethods, tracker); |
| }); |
|
|
| // Put real timers back once each fake-timer test is done. |
| afterEach(function restoreRealTimers() { |
| jest.useRealTimers(); |
| }); |
|
|
| it('should handle timed out reconnections via isReconnecting check', () => { |
| const userId = 'user-123'; |
| const serverName = 'test-server'; |
| const now = Date.now(); |
| jest.setSystemTime(now); |
|
|
| // With the fake clock pinned, mark the server active; it is reported as reconnecting. |
| reconnectionTracker.setActive(userId, serverName); |
| expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(true); |
|
|
| // 2 min 59 s later it is still reported as reconnecting. |
| jest.advanceTimersByTime(2 * 60 * 1000 + 59 * 1000); |
| expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(true); |
|
|
| // 2 s more (3 min 1 s in total) and it no longer is; presumably a 3-minute timeout, confirm in the tracker. |
| jest.advanceTimersByTime(2000); |
| expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(false); |
| }); |
|
|
| it('should not attempt to reconnect servers that have timed out during reconnection', async () => { |
| const userId = 'user-123'; |
| const oauthServers = new Set(['server1', 'server2']); |
| (mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); |
|
|
| const now = Date.now(); |
| jest.setSystemTime(now); |
|
|
| // server1: mark active, then move the fake clock 6 minutes ahead (asserted below to no longer count as still reconnecting). |
| reconnectionTracker.setActive(userId, 'server1'); |
| jest.advanceTimersByTime(6 * 60 * 1000); |
|
|
| // Only the 'mcp:server2' lookup yields a token (valid for one more hour); every other lookup resolves to null. |
| tokenMethods.findToken.mockImplementation(async ({ identifier }) => { |
| if (identifier === 'mcp:server2') { |
| return { |
| userId, |
| identifier, |
| expiresAt: new Date(Date.now() + 3600000), |
| } as unknown as MCPOAuthTokens; |
| } |
| return null; |
| }); |
|
|
| // getUserConnection resolves to a connection that reports connected. |
| const mockNewConnection = { |
| isConnected: jest.fn().mockResolvedValue(true), |
| disconnect: jest.fn(), |
| }; |
| mockMCPManager.getUserConnection.mockResolvedValue( |
| mockNewConnection as unknown as MCPConnection, |
| ); |
|
|
| await reconnectionManager.reconnectServers(userId); |
|
|
| // server1 keeps its active flag, yet the tracker no longer considers it still reconnecting. |
| expect(reconnectionTracker.isActive(userId, 'server1')).toBe(true); |
| expect(reconnectionTracker.isStillReconnecting(userId, 'server1')).toBe(false); |
|
|
| // server2 is flagged active right after reconnectServers resolves. |
| expect(reconnectionTracker.isActive(userId, 'server2')).toBe(true); |
|
|
| // Flush every pending fake timer along with the async work queued behind them. |
| await jest.runAllTimersAsync(); |
|
|
| // Exactly one connection was requested, and it targeted server2. |
| expect(mockMCPManager.getUserConnection).toHaveBeenCalledTimes(1); |
| expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith( |
| expect.objectContaining({ |
| serverName: 'server2', |
| }), |
| ); |
| }); |
|
|
| it('should properly track reconnection time for multiple sequential reconnect attempts', async () => { |
| const userId = 'user-123'; |
| const serverName = 'server1'; |
| const oauthServers = new Set([serverName]); |
| (mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); |
|
|
| const now = Date.now(); |
| jest.setSystemTime(now); |
|
|
| // Every token lookup resolves to a token for this server that is valid for one more hour. |
| tokenMethods.findToken.mockResolvedValue({ |
| userId, |
| identifier: `mcp:${serverName}`, |
| expiresAt: new Date(Date.now() + 3600000), |
| } as unknown as MCPOAuthTokens); |
|
|
| // The first getUserConnection call rejects (once only); the server config resolves to an empty object. |
| mockMCPManager.getUserConnection.mockRejectedValueOnce(new Error('Connection failed')); |
| (mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue( |
| {} as unknown as MCPOptions, |
| ); |
|
|
| await reconnectionManager.reconnectServers(userId); |
| await jest.runAllTimersAsync(); |
|
|
| // After the first run settles, the server is marked failed and is not active. |
| expect(reconnectionTracker.isFailed(userId, serverName)).toBe(true); |
| expect(reconnectionTracker.isActive(userId, serverName)).toBe(false); |
|
|
| // Clear the server's reconnection state through the manager. |
| reconnectionManager.clearReconnection(userId, serverName); |
|
|
| // Move the fake clock 3 minutes ahead before the second attempt. |
| jest.advanceTimersByTime(3 * 60 * 1000); |
|
|
| // From now on getUserConnection resolves to a connection that reports connected. |
| const mockConnection = { |
| isConnected: jest.fn().mockResolvedValue(true), |
| }; |
| mockMCPManager.getUserConnection.mockResolvedValue( |
| mockConnection as unknown as MCPConnection, |
| ); |
|
|
| await reconnectionManager.reconnectServers(userId); |
|
|
| // Second run: the server is flagged active right after reconnectServers resolves. |
| expect(reconnectionTracker.isActive(userId, serverName)).toBe(true); |
|
|
| await jest.runAllTimersAsync(); |
|
|
| // Once timers are flushed, the server is neither active nor failed. |
| expect(reconnectionTracker.isActive(userId, serverName)).toBe(false); |
| expect(reconnectionTracker.isFailed(userId, serverName)).toBe(false); |
| }); |
| }); |
| }); |
|
|