# Spaces:
# Running
# Running
"""
Security regression tests
"""
import logging
import warnings
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient
class TestInputValidation:
    """Test input validation and sanitization"""

    # Reference point (New York City) shared by the location-based requests below.
    _NYC = {"latitude": 40.7128, "longitude": -74.0060}

    def test_sql_injection_prevention(self, client: TestClient):
        """A SQL-injection string in the category filter is rejected or neutralised.

        A 400 is acceptable; a 200 must still carry an ordinary JSON list.
        """
        sql_payload = "'; DROP TABLE merchants; --"
        resp = client.get(
            "/api/v1/merchants/search",
            params={"category": sql_payload, **self._NYC},
        )
        assert resp.status_code in (200, 400)
        if resp.status_code != 200:
            return
        # The request was processed: the result must still be a plain list.
        assert isinstance(resp.json(), list)

    def test_xss_prevention_in_nlp_query(self, client: TestClient):
        """Script tags in an NLP query never appear in the analysis response.

        The pipeline is patched to return an already-sanitised analysis.
        NOTE(review): patching ``app.services.advanced_nlp.advanced_nlp_pipeline``
        only intercepts code that looks the name up on that module — confirm.
        """
        xss_query = "<script>alert('xss')</script>find salon"
        canned_analysis = {
            "query": "find salon",  # what a sanitising pipeline should hand back
            "primary_intent": {"intent": "SEARCH_SERVICE", "confidence": 0.8},
            "entities": {},
            "similar_services": [],
            "search_parameters": {},
            "processing_time": 0.1,
        }
        with patch('app.services.advanced_nlp.advanced_nlp_pipeline') as pipeline:
            pipeline.process_query.return_value = canned_analysis
            resp = client.post("/api/v1/nlp/analyze-query", params={"query": xss_query})
            assert resp.status_code == 200
            assert "<script>" not in str(resp.json())

    def test_command_injection_prevention(self, client: TestClient):
        """Shell metacharacters appended to free text are handled without harm."""
        body = {"text": "find salon" + "; rm -rf /", **self._NYC}
        resp = client.post("/api/v1/helpers/process-text", json=body)
        assert resp.status_code in (200, 400)

    def test_path_traversal_prevention(self, client: TestClient):
        """A ``../`` sequence in the merchant-ID position never resolves to a file.

        NOTE(review): HTTP clients may normalise dot segments before sending, in
        which case this request never reaches the merchant route — confirm.
        """
        resp = client.get("/api/v1/merchants/" + "../../../etc/passwd")
        assert resp.status_code in (400, 404)

    def test_large_payload_handling(self, client: TestClient):
        """A 10 MiB text body is rejected rather than processed."""
        oversized = "A" * (10 << 20)  # 10 * 1024 * 1024 characters
        body = {"text": oversized, **self._NYC}
        resp = client.post("/api/v1/helpers/process-text", json=body)
        assert resp.status_code in (400, 413, 422)

    def test_invalid_coordinates_handling(self, client: TestClient):
        """Out-of-range, non-numeric and null coordinates do not crash the search."""
        bad_pairs = (
            (999, 999),      # far outside the valid range
            (-999, -999),    # far outside the valid range, negative
            ("abc", "def"),  # not numbers at all
            (None, None),    # null values
        )
        for latitude, longitude in bad_pairs:
            resp = client.get(
                "/api/v1/merchants/search",
                params={"latitude": latitude, "longitude": longitude, "radius": 5000},
            )
            assert resp.status_code in (200, 400, 422)
class TestAuthentication:
    """Test authentication mechanisms (if implemented)"""

    def test_unauthenticated_access_to_public_endpoints(self, client: TestClient):
        """Public GET endpoints answer without demanding credentials (never 401)."""
        for path in ("/health", "/api/v1/merchants/", "/api/v1/merchants/search"):
            assert client.get(path).status_code != 401

    def test_api_key_validation(self, client: TestClient):
        """A bogus ``X-API-Key`` header is either ignored (200) or refused (401/403).

        Written so it holds whether or not API-key authentication exists yet.
        """
        bogus_key = "invalid_key_12345"
        resp = client.get("/api/v1/merchants/", headers={"X-API-Key": bogus_key})
        assert resp.status_code in (200, 401, 403)
class TestAuthorization:
    """Test authorization and access control"""

    def test_admin_endpoint_access(self, client: TestClient):
        """Admin-style routes are either protected (401/403) or absent (404)."""
        # These paths are probes; no admin router is visible from this file.
        for section in ("users", "merchants", "system"):
            resp = client.get(f"/admin/{section}")
            assert resp.status_code in (401, 403, 404)

    def test_user_data_isolation(self, client: TestClient):
        """A caller identifying as one user cannot read another user's data."""
        requester, target = "user123", "user456"
        resp = client.get(
            f"/api/v1/users/{target}/data",
            headers={"User-ID": requester},
        )
        assert resp.status_code in (401, 403, 404)
class TestDataProtection:
    """Test data protection and privacy"""

    def test_sensitive_data_not_exposed(self, client: TestClient, sample_merchant_data):
        """Internal-only merchant fields never appear in API output.

        ``get_merchants`` is patched to return a merchant carrying secret fields.
        The serialised response must contain neither the field names nor their values.
        NOTE(review): the patch target ``app.services.merchant.get_merchants`` only
        takes effect if the route resolves the function through that module — confirm.
        """
        secrets_by_field = {
            "internal_id": "INTERNAL_123",
            "api_key": "secret_api_key",
            "database_password": "secret_password",
            "private_notes": "Internal business notes",
        }
        with patch('app.services.merchant.get_merchants') as mock_get:
            merchant_with_sensitive = sample_merchant_data.copy()
            merchant_with_sensitive.update(secrets_by_field)
            mock_get.return_value = [merchant_with_sensitive]

            response = client.get("/api/v1/merchants/")
            assert response.status_code == 200
            response_text = str(response.json())

            # Every injected field is checked (``private_notes`` was previously
            # skipped), and so are the values, since a renamed key could still leak them.
            for field, value in secrets_by_field.items():
                assert field not in response_text
                assert value not in response_text

    def test_error_message_information_disclosure(self, client: TestClient):
        """A 500 response does not echo credentials or hosts from the underlying error."""
        # Starlette's ServerErrorMiddleware re-raises unhandled exceptions after
        # sending the 500, and a default TestClient re-raises them inside the
        # test. A client built with raise_server_exceptions=False returns the
        # response instead, so its body can be inspected.
        # NOTE(review): this client is not entered as a context manager, so app
        # startup/lifespan handlers do not run for it. The data layer is mocked here.
        error_client = TestClient(client.app, raise_server_exceptions=False)
        with patch('app.services.merchant.get_merchants') as mock_get:
            mock_get.side_effect = Exception(
                "Database connection failed: host=internal-db.company.com, "
                "user=admin, password=secret123"
            )
            response = error_client.get("/api/v1/merchants/")
            assert response.status_code == 500

            # Inspect the raw body: a default 500 body is plain text, not JSON.
            error_text = response.text
            for info in ("password=", "host=internal-db", "user=admin", "secret123"):
                assert info not in error_text

    def test_log_sanitization(self, client: TestClient, caplog):
        """A card number submitted in a query is never written to application logs."""
        card_number = "4111-1111-1111-1111"
        sensitive_query = f"my credit card is {card_number}"
        caplog.set_level(logging.DEBUG)

        response = client.post("/api/v1/nlp/analyze-query", params={
            "query": sensitive_query
        })
        assert response.status_code in [200, 400]

        # httpx logs each request URL, including the query string. Its loggers
        # are excluded so that only application-side records are checked.
        app_messages = [
            record.getMessage()
            for record in caplog.records
            if not record.name.startswith(("httpx", "httpcore"))
        ]
        assert not any(card_number in message for message in app_messages)
class TestCORSAndHeaders:
    """Test CORS and security headers"""

    def test_cors_configuration(self, client: TestClient):
        """A CORS preflight from the local frontend origin succeeds."""
        response = client.options("/api/v1/merchants/", headers={
            "Origin": "http://localhost:3000",
            "Access-Control-Request-Method": "GET",
        })
        assert response.status_code in [200, 204]

    def test_cors_origin_validation(self, client: TestClient):
        """Requests from an untrusted origin are served, but not granted CORS access."""
        response = client.get("/api/v1/merchants/", headers={
            "Origin": "http://localhost:3000"
        })
        assert response.status_code == 200

        malicious_origin = "http://malicious-site.com"
        response = client.get("/api/v1/merchants/", headers={
            "Origin": malicious_origin
        })
        # The browser enforces CORS, so the request itself still succeeds...
        assert response.status_code == 200
        # ...but the server must not reflect the untrusted origin as allowed.
        assert response.headers.get("access-control-allow-origin") != malicious_origin

    def test_security_headers(self, client: TestClient):
        """Headers do not leak internals; missing hardening headers are reported."""
        response = client.get("/health")
        headers = response.headers

        recommended_headers = [
            "X-Content-Type-Options",
            "X-Frame-Options",
            "X-XSS-Protection",
            "Strict-Transport-Security",
        ]
        # These may not be implemented yet, so their absence is reported as a
        # warning rather than a failure. Previously this list was never used.
        # Note: a project that turns warnings into errors will fail here instead.
        missing = [name for name in recommended_headers if name not in headers]
        if missing:
            warnings.warn(
                f"Missing recommended security headers: {', '.join(missing)}"
            )

        # Headers that, if present, must not expose server or stack internals.
        for header in ("Server", "X-Powered-By"):
            if header in headers:
                value = headers[header].lower()
                assert "internal" not in value
                assert "secret" not in value
class TestRateLimiting:
    """Test rate limiting and abuse prevention"""

    # Unless pytest-asyncio runs in "auto" mode, an unmarked ``async def`` test
    # is not awaited: pytest skips it with a warning. The marker is required
    # for this body to run.
    # NOTE(review): assumes pytest-asyncio provides the ``async_client``
    # fixture's event loop — confirm.
    @pytest.mark.asyncio
    async def test_rate_limiting_basic(self, async_client):
        """A burst of 100 sequential health checks yields only 200 or 429."""
        status_codes = []
        for _ in range(100):
            response = await async_client.get("/health")
            status_codes.append(response.status_code)
        # Every response is either served (200) or throttled (429), never an error.
        assert all(code in [200, 429] for code in status_codes)

    def test_rate_limiting_per_endpoint(self, client: TestClient):
        """Twenty quick requests per endpoint never produce unexpected statuses."""
        endpoints = [
            "/health",
            "/api/v1/merchants/",
            "/api/v1/nlp/supported-intents",
        ]
        for endpoint in endpoints:
            status_codes = [client.get(endpoint).status_code for _ in range(20)]
            # 500 is tolerated here; presumably some endpoints depend on
            # backends that may be unavailable in tests — TODO confirm.
            assert all(code in [200, 429, 500] for code in status_codes)
class TestInputSanitization:
    """Test comprehensive input sanitization"""

    def test_html_sanitization(self, client: TestClient):
        """HTML/script markup in free text is accepted or rejected, never reflected as script."""
        payloads = (
            "<b>bold text</b>",
            "<img src='x' onerror='alert(1)'>",
            "<iframe src='javascript:alert(1)'></iframe>",
            "<<script>alert('xss')</script>script>alert('xss')<</script>/script>",
        )
        for markup in payloads:
            body = {"text": markup, "latitude": 40.7128, "longitude": -74.0060}
            resp = client.post("/api/v1/helpers/process-text", json=body)
            assert resp.status_code in (200, 400)
            if resp.status_code != 200:
                continue
            # An accepted request must not echo executable markup back.
            rendered = str(resp.json())
            for forbidden in ("<script>", "javascript:"):
                assert forbidden not in rendered

    def test_unicode_handling(self, client: TestClient):
        """Non-ASCII and control-character queries are handled without a crash."""
        samples = (
            "café résumé naïve",  # accented Latin letters
            "🏪🔍💇♀️",  # emoji
            "测试中文字符",  # Chinese characters
            "тест кириллица",  # Cyrillic
            "\u0000\u0001\u0002",  # control characters
        )
        for text in samples:
            resp = client.post("/api/v1/nlp/analyze-query", params={"query": text})
            assert resp.status_code in (200, 400)

    def test_numeric_input_validation(self, client: TestClient):
        """Malformed numeric query parameters are validated rather than crashing."""
        cases = (
            ("latitude", "not_a_number"),
            ("longitude", "infinity"),
            ("radius", "-1000"),
            ("limit", "999999999999999999999"),
            ("skip", "-1"),
        )
        for name, raw in cases:
            query = {name: raw}
            # Fill in default coordinates unless this case overrides one of them.
            query.setdefault("latitude", 40.7128)
            query.setdefault("longitude", -74.0060)
            resp = client.get("/api/v1/merchants/search", params=query)
            assert resp.status_code in (200, 400, 422)
class TestDatabaseSecurity:
    """Test database security measures"""

    # Without the marker, and outside pytest-asyncio's "auto" mode, pytest
    # skips an ``async def`` test with a warning instead of running it.
    # NOTE(review): assumes pytest-asyncio is the async plugin in use — confirm.
    @pytest.mark.asyncio
    async def test_mongodb_injection_prevention(self):
        """Operator dicts supplied as ``category`` never reach MongoDB verbatim.

        Each attempt is either rejected with ``ValueError``/``TypeError``, or, if
        accepted, the filter passed to ``collection.find``:
        - must not carry the injected operator/operand pair under ``category``;
        - must not gain a top-level ``$where``.
        """
        from app.repositories.db_repository import search_merchants_in_db

        injection_attempts = [
            {"$where": "this.name == 'test'"},
            {"$regex": ".*"},
            {"$ne": None},
        ]
        absent = object()  # distinguishes "operator key missing" from a stored None

        # NOTE(review): this patch only takes effect if db_repository resolves
        # ``get_mongodb_client`` through ``app.nosql`` at call time. If that
        # module uses ``from app.nosql import get_mongodb_client``, patch
        # ``app.repositories.db_repository.get_mongodb_client`` instead.
        with patch('app.nosql.get_mongodb_client') as mock_client:
            mock_collection = MagicMock()
            mock_client.return_value.__getitem__.return_value.__getitem__.return_value = mock_collection
            # The mock chain assumes the repository awaits ``to_list``. A plain
            # MagicMock returning a list is not awaitable: the resulting
            # TypeError was swallowed by the handler below, so no query was
            # ever inspected. Make ``to_list`` awaitable, with or without ``limit``.
            mock_collection.find.return_value.limit.return_value.to_list = AsyncMock(return_value=[])
            mock_collection.find.return_value.to_list = AsyncMock(return_value=[])

            for injection in injection_attempts:
                # Forget the previous attempt's call so a stale query is never checked.
                mock_collection.find.reset_mock()
                try:
                    await search_merchants_in_db(category=injection)
                except (ValueError, TypeError):
                    # Rejecting operator-shaped input outright is acceptable.
                    continue

                call_args = mock_collection.find.call_args
                if call_args is None:
                    continue
                query = call_args[0][0]
                assert "$where" not in query
                category_filter = query.get("category")
                if isinstance(category_filter, dict):
                    # The repository may legitimately use operators such as
                    # $regex itself; what must not happen is the user's
                    # operator/operand pair being passed through unchanged.
                    for operator, operand in injection.items():
                        assert category_filter.get(operator, absent) != operand

    def test_connection_string_security(self):
        """Smoke test: a MongoDB client object can be obtained from ``app.nosql``.

        Credential exposure is not checked yet. Doing so would require
        triggering a connection error and inspecting its message.
        """
        from app.nosql import get_mongodb_client

        client = get_mongodb_client()
        assert client is not None
class TestAPISecurityBestPractices:
    """Test API security best practices"""

    def test_http_methods_restriction(self, client: TestClient):
        """Endpoints reject HTTP methods they do not implement with 405."""
        # A GET-only collection endpoint must refuse POST.
        response = client.post("/api/v1/merchants/")
        assert response.status_code == 405  # Method Not Allowed
        # A POST-only helper endpoint must refuse GET.
        response = client.get("/api/v1/helpers/process-text")
        assert response.status_code == 405  # Method Not Allowed

    def test_content_type_validation(self, client: TestClient):
        """A JSON document declared as ``text/plain`` is not accepted as JSON."""
        # Send the same fields the other process-text calls send, so the
        # declared content type is the only thing wrong. A bare {"text": ...}
        # could be rejected for missing fields and pass for the wrong reason.
        # ``content=`` sends raw text; httpx deprecates passing a str via ``data=``.
        response = client.post(
            "/api/v1/helpers/process-text",
            content='{"text": "test", "latitude": 40.7128, "longitude": -74.0060}',
            headers={"Content-Type": "text/plain"},
        )
        assert response.status_code in [400, 415, 422]

    def test_parameter_pollution(self, client: TestClient):
        """Duplicated query parameters are resolved or rejected cleanly."""
        # Coordinates are included, as in the other search calls here, so the
        # duplicated ``category`` is the only anomaly in the request.
        response = client.get("/api/v1/merchants/search", params=[
            ("category", "salon"),
            ("category", "spa"),
            ("latitude", 40.7128),
            ("longitude", -74.0060),
        ])
        assert response.status_code in [200, 400]

    def test_null_byte_injection(self, client: TestClient):
        """Null bytes in a query are stripped or rejected, never echoed back."""
        null_byte_input = "test\x00malicious"
        response = client.post("/api/v1/nlp/analyze-query", params={
            "query": null_byte_input
        })
        assert response.status_code in [200, 400]
        if response.status_code == 200:
            # Check the raw body. ``str(response.json())`` renders NUL as the
            # four-character escape "\x00", so the old check could never fail.
            # JSON encoders emit NUL as the escape sequence \u0000.
            body = response.text
            assert "\x00" not in body
            assert "\\u0000" not in body