# Tests for ankigen_core.agents.security.
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from ankigen_core.agents.security import (
    RateLimitConfig,
    SecurityConfig,
    RateLimiter,
    SecurityValidator,
    SecureAgentWrapper,
    SecurityError,
    get_rate_limiter,
    get_security_validator,
    create_secure_agent,
    strip_html_tags,
    validate_api_key_format,
    sanitize_for_logging,
)


def test_rate_limit_config_defaults():
    """A bare RateLimitConfig() carries the expected default limits."""
    cfg = RateLimitConfig()

    # Compare all four limits at once; pytest reports which position differs.
    actual = (
        cfg.requests_per_minute,
        cfg.requests_per_hour,
        cfg.burst_limit,
        cfg.cooldown_period,
    )
    assert actual == (60, 1000, 10, 300)


def test_rate_limit_config_custom():
    """Explicit constructor arguments override every RateLimitConfig default."""
    cfg = RateLimitConfig(
        cooldown_period=600,
        burst_limit=5,
        requests_per_hour=500,
        requests_per_minute=30,
    )

    actual = (
        cfg.requests_per_minute,
        cfg.requests_per_hour,
        cfg.burst_limit,
        cfg.cooldown_period,
    )
    assert actual == (30, 500, 5, 600)


def test_security_config_defaults():
    """SecurityConfig() switches every safeguard on and uses the expected size caps."""
    cfg = SecurityConfig()

    # All three protection layers are enabled by default.
    for flag in (
        cfg.enable_input_validation,
        cfg.enable_output_filtering,
        cfg.enable_rate_limiting,
    ):
        assert flag is True

    assert cfg.max_input_length == 10000
    assert cfg.max_output_length == 50000

    assert len(cfg.blocked_patterns) >= 1
    assert '.txt' in cfg.allowed_file_extensions


def test_security_config_blocked_patterns():
    """The default blocked patterns mention API keys, secrets and passwords."""
    patterns = SecurityConfig().blocked_patterns

    # Case-insensitive keyword search across the pattern strings.
    for keyword in ('api', 'secret', 'password'):
        assert any(keyword in pattern.lower() for pattern in patterns)


@pytest.fixture
def rate_limiter():
    """Provide a RateLimiter with small limits (burst 3, 5/min, 50/h) that tests can exhaust quickly."""
    tight_limits = RateLimitConfig(
        burst_limit=3,
        requests_per_minute=5,
        requests_per_hour=50,
    )
    limiter = RateLimiter(tight_limits)
    return limiter


async def test_rate_limiter_allows_requests_under_limit(rate_limiter):
    """Up to the fixture's burst limit (3), every request is admitted."""
    user = "test_user"

    for _ in range(3):
        allowed = await rate_limiter.check_rate_limit(user)
        assert allowed is True


async def test_rate_limiter_blocks_burst_limit(rate_limiter):
    """Once the burst allowance is spent, the very next request is refused."""
    user = "test_user"
    burst = 3  # burst_limit configured by the rate_limiter fixture

    for _ in range(burst):
        assert await rate_limiter.check_rate_limit(user) is True

    # Request number burst + 1 must be throttled.
    over_limit = await rate_limiter.check_rate_limit(user)
    assert over_limit is False


async def test_rate_limiter_per_minute_limit():
    """The per-minute quota is enforced independently of the burst limit.

    The shared ``rate_limiter`` fixture has ``burst_limit=3`` but
    ``requests_per_minute=5``. test_rate_limiter_blocks_burst_limit expects
    its 4th back-to-back request to be rejected, so expecting 5 back-to-back
    successes from that fixture would contradict it. This test instead builds
    a limiter whose burst limit is above the per-minute limit. That way only
    the per-minute check can reject the 6th request.
    """
    per_minute = 5
    limiter = RateLimiter(
        RateLimitConfig(
            requests_per_minute=per_minute,
            requests_per_hour=50,
            burst_limit=per_minute * 2,  # keep the burst limit out of the way
        )
    )
    identifier = "test_user"

    # Freeze the clock so every request falls inside the same minute.
    # NOTE(review): assumes the limiter reads the clock via ``time.time()``
    # on the ``time`` module (as the original patch target did) — confirm.
    with patch('time.time', return_value=1000.0):
        for _ in range(per_minute):
            assert await limiter.check_rate_limit(identifier) is True

        # One request beyond the per-minute quota must be rejected.
        assert await limiter.check_rate_limit(identifier) is False


async def test_rate_limiter_different_identifiers(rate_limiter):
    """Quotas are tracked per identifier: exhausting one user leaves another unaffected."""
    noisy_user, quiet_user = "user1", "user2"

    # Use up user1's burst allowance...
    for _ in range(3):
        assert await rate_limiter.check_rate_limit(noisy_user) is True
    assert await rate_limiter.check_rate_limit(noisy_user) is False

    # ...which must not throttle user2.
    assert await rate_limiter.check_rate_limit(quiet_user) is True


async def test_rate_limiter_reset_time(rate_limiter):
    """After an identifier has sent its burst of requests, a reset time is reported."""
    user = "test_user"
    burst = 3  # burst_limit configured by the rate_limiter fixture

    for _ in range(burst):
        await rate_limiter.check_rate_limit(user)

    assert rate_limiter.get_reset_time(user) is not None


@pytest.fixture
def security_validator():
    """Provide a SecurityValidator capped at 100 input and 200 output characters."""
    return SecurityValidator(
        SecurityConfig(max_output_length=200, max_input_length=100)
    )


def test_security_validator_valid_input(security_validator):
    """Ordinary, harmless text passes input validation."""
    benign = "This is a normal, safe input text."
    verdict = security_validator.validate_input(benign, "test")
    assert verdict is True


def test_security_validator_input_too_long(security_validator):
    """Input longer than the fixture's 100-character cap is rejected."""
    oversized = 1000 * "x"
    verdict = security_validator.validate_input(oversized, "test")
    assert verdict is False


def test_security_validator_blocked_patterns(security_validator):
    """Inputs that leak credentials or carry a script tag are rejected."""
    samples = (
        "Here is my API key: sk-1234567890abcdef",
        "My password is secret123",
        "The access_token is abc123",
        "<script>alert('xss')</script>",
    )

    for sample in samples:
        assert security_validator.validate_input(sample, "test") is False


def test_security_validator_output_validation(security_validator):
    """Output validation accepts clean text and rejects text exposing an API key."""
    agent_name = "test_agent"

    clean = "This is a safe response with no sensitive information."
    assert security_validator.validate_output(clean, agent_name) is True

    leaky = "Here's your API key: sk-1234567890abcdef"
    assert security_validator.validate_output(leaky, agent_name) is False


def test_security_validator_sanitize_input(security_validator):
    """sanitize_input removes the script tag but keeps the surrounding plain text."""
    cleaned = security_validator.sanitize_input(
        "<script>alert('xss')</script>Normal text"
    )

    assert "<script>" not in cleaned
    assert "Normal text" in cleaned


def test_security_validator_sanitize_output(security_validator):
    """sanitize_output hides an embedded API key behind a [REDACTED] marker."""
    raw = "Response with API key sk-1234567890abcdef"
    cleaned = security_validator.sanitize_output(raw)

    assert "sk-1234567890abcdef" not in cleaned
    assert "[REDACTED]" in cleaned


def test_security_validator_disabled_validation():
    """With validation and filtering switched off, even sensitive-looking text passes."""
    permissive = SecurityValidator(
        SecurityConfig(enable_output_filtering=False, enable_input_validation=False)
    )

    assert permissive.validate_input("api_key: sk-123", "test") is True
    assert permissive.validate_output("secret: password", "test") is True


@pytest.fixture
def mock_base_agent():
    """Provide a stand-in agent whose async ``execute`` returns "test response"."""
    fake = MagicMock()
    fake.config = dict(name="test_agent")
    fake.execute = AsyncMock(return_value="test response")
    return fake


@pytest.fixture
def secure_agent_wrapper(mock_base_agent):
    """Wrap the mock agent with a limiter that allows two back-to-back calls (burst_limit=2)."""
    return SecureAgentWrapper(
        mock_base_agent,
        RateLimiter(RateLimitConfig(burst_limit=2)),
        SecurityValidator(SecurityConfig()),
    )


async def test_secure_agent_wrapper_successful_execution(secure_agent_wrapper, mock_base_agent):
    """A clean request goes through the wrapper to the agent, and the agent's response comes back."""
    response = await secure_agent_wrapper.secure_execute("Safe input")

    assert response == "test response"
    assert mock_base_agent.execute.call_count == 1


async def test_secure_agent_wrapper_rate_limit_exceeded(secure_agent_wrapper, mock_base_agent):
    """A call beyond the wrapper's burst limit is refused and never reaches the agent.

    The ``secure_agent_wrapper`` fixture uses ``burst_limit=2``. The first two
    calls succeed, and the third must raise ``SecurityError``. The wrapped
    agent (the same ``mock_base_agent`` instance the fixture wraps) must
    therefore have been executed exactly twice.
    """
    await secure_agent_wrapper.secure_execute("input1")
    await secure_agent_wrapper.secure_execute("input2")

    with pytest.raises(SecurityError, match="Rate limit exceeded"):
        await secure_agent_wrapper.secure_execute("input3")

    # The throttled request must not have been forwarded to the agent.
    assert mock_base_agent.execute.await_count == 2


async def test_secure_agent_wrapper_input_validation_failed():
    """Input carrying an API key is rejected before the wrapped agent runs."""
    rate_limiter = RateLimiter(RateLimitConfig())
    validator = SecurityValidator(SecurityConfig())

    mock_agent = MagicMock()
    mock_agent.config = {"name": "test_agent"}
    # Awaitable mock: if input validation ever regressed and forwarded the
    # call, the wrapper would get a harmless response and the checks below
    # (DID NOT RAISE / assert_not_called) would fail with a clear message.
    mock_agent.execute = AsyncMock(return_value="should not be returned")

    wrapper = SecureAgentWrapper(mock_agent, rate_limiter, validator)

    with pytest.raises(SecurityError, match="Input validation failed"):
        await wrapper.secure_execute("API key: sk-1234567890abcdef")

    # Rejected input must never reach the agent.
    mock_agent.execute.assert_not_called()


async def test_secure_agent_wrapper_output_validation_failed():
    """An agent response that leaks an API key is blocked by output validation."""
    leaking_agent = MagicMock()
    leaking_agent.execute = AsyncMock(
        return_value="Response with API key: sk-1234567890abcdef"
    )

    wrapper = SecureAgentWrapper(
        leaking_agent,
        RateLimiter(RateLimitConfig()),
        SecurityValidator(SecurityConfig()),
    )

    with pytest.raises(SecurityError, match="Output validation failed"):
        await wrapper.secure_execute("Safe input")


def test_strip_html_tags():
    """strip_html_tags removes markup, including script tags, and keeps the visible text."""
    stripped = strip_html_tags(
        "<p>Hello <b>World</b>!</p><script>alert('xss')</script>"
    )

    for tag in ("<p>", "<b>", "<script>"):
        assert tag not in stripped
    assert "Hello World!" in stripped


def test_validate_api_key_format():
    """A well-formed ``sk-`` key is accepted; empty, malformed and placeholder-looking keys are rejected."""
    assert validate_api_key_format("sk-1234567890abcdef1234567890abcdef") is True

    for candidate in ("", "invalid", "sk-test", "sk-fake1234567890abcdef"):
        assert validate_api_key_format(candidate) is False


def test_sanitize_for_logging():
    """Log sanitization hides API keys and stays close to max_length."""
    cap = 50
    # Tolerance above ``cap``; presumably room for a truncation/redaction marker.
    slack = 20

    cleaned = sanitize_for_logging(
        "User input with API key sk-1234567890abcdef", max_length=cap
    )

    assert "sk-1234567890abcdef" not in cleaned
    assert len(cleaned) <= cap + slack


def test_get_rate_limiter():
    """get_rate_limiter() returns the same shared instance on repeated calls."""
    first, second = get_rate_limiter(), get_rate_limiter()
    assert first is second


def test_get_security_validator():
    """get_security_validator() returns the same shared instance on repeated calls."""
    first, second = get_security_validator(), get_security_validator()
    assert first is second


def test_create_secure_agent():
    """create_secure_agent wraps the given agent in a SecureAgentWrapper."""
    inner = MagicMock()
    wrapped = create_secure_agent(inner)

    assert isinstance(wrapped, SecureAgentWrapper)
    assert wrapped.base_agent is inner


async def test_rate_limiter_cleanup():
    """Requests recorded over an hour earlier are pruned from the limiter's history."""
    limiter = RateLimiter(
        RateLimitConfig(requests_per_minute=10, requests_per_hour=100)
    )
    user = "test_user"

    with patch('time.time') as fake_clock:
        # Five requests at t = 1000 s.
        fake_clock.return_value = 1000.0
        for _ in range(5):
            await limiter.check_rate_limit(user)

        # Jump 4000 s ahead (more than an hour) and send one more.
        fake_clock.return_value = 5000.0
        assert await limiter.check_rate_limit(user) is True

        # Only the fresh request should remain in the per-identifier history.
        assert len(limiter._requests[user]) == 1


def test_security_config_file_permissions():
    """set_secure_file_permissions leaves no permission bits for group or others.

    The call itself runs on every platform. The resulting mode is only
    checked on POSIX systems, because elsewhere the group/other bits are not
    meaningful.
    """
    import os
    import stat
    import tempfile

    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        tmp_path = tmp_file.name

    try:
        from ankigen_core.agents.security import set_secure_file_permissions

        set_secure_file_permissions(tmp_path)

        # os.chmod exists on every platform (so a hasattr check never skips).
        # On Windows it only toggles the read-only flag, so the group/other
        # bits can only be verified on POSIX.
        if os.name == "posix":
            mode = stat.S_IMODE(os.stat(tmp_path).st_mode)
            # No read/write/execute for group (0o070) or others (0o007).
            assert mode & 0o077 == 0, oct(mode)
    finally:
        os.unlink(tmp_path)


async def test_rate_limiter_concurrent_access():
    """Concurrent checks for one identifier admit exactly ``burst_limit`` requests.

    Asserting only ``<= burst_limit`` would also pass for a limiter that
    wrongly rejects every request. Requiring equality catches both over- and
    under-admission. The default per-minute (60) and per-hour (1000) limits
    are well above the burst, so the burst limit is the binding one.
    """
    burst_limit = 5
    limiter = RateLimiter(RateLimitConfig(burst_limit=burst_limit))
    identifier = "concurrent_user"

    # Fire twice as many checks as the burst allows, all at once.
    results = await asyncio.gather(
        *(limiter.check_rate_limit(identifier) for _ in range(burst_limit * 2))
    )

    success_count = sum(1 for result in results if result)
    assert success_count == burst_limit


def test_security_validator_error_handling():
    """validate_input returns False, rather than raising, for None and for oversized input."""
    validator = SecurityValidator(SecurityConfig())

    # A non-string input is rejected.
    assert validator.validate_input(None, "test") is False

    # One million characters, far above the default 10,000-character cap.
    huge = "x" * 1_000_000
    assert validator.validate_input(huge, "test") is False


async def test_secure_agent_wrapper_base_agent_error():
    """A failure inside the wrapped agent surfaces from secure_execute with its message intact."""
    failing_agent = MagicMock()
    failing_agent.config = {"name": "test_agent"}
    failing_agent.execute = AsyncMock(side_effect=Exception("Base agent failed"))

    wrapper = SecureAgentWrapper(
        failing_agent,
        RateLimiter(RateLimitConfig()),
        SecurityValidator(SecurityConfig()),
    )

    with pytest.raises(Exception, match="Base agent failed"):
        await wrapper.secure_execute("Safe input")