bookmyservice-mhs / app /tests /test_security.py
MukeshKapoor25's picture
test(performance): Add comprehensive test suite for performance optimization
7611990
"""
Security regression tests
"""
import pytest
from unittest.mock import patch, MagicMock
from fastapi.testclient import TestClient
class TestInputValidation:
"""Test input validation and sanitization"""
def test_sql_injection_prevention(self, client: TestClient):
"""Test SQL injection prevention in merchant search"""
malicious_input = "'; DROP TABLE merchants; --"
response = client.get("/api/v1/merchants/search", params={
"category": malicious_input,
"latitude": 40.7128,
"longitude": -74.0060
})
# Should either sanitize input or return 400
assert response.status_code in [200, 400]
if response.status_code == 200:
# If processed, should not contain malicious SQL
data = response.json()
assert isinstance(data, list)
def test_xss_prevention_in_nlp_query(self, client: TestClient):
"""Test XSS prevention in NLP query processing"""
xss_payload = "<script>alert('xss')</script>find salon"
with patch('app.services.advanced_nlp.advanced_nlp_pipeline') as mock_nlp:
mock_nlp.process_query.return_value = {
"query": "find salon", # Should be sanitized
"primary_intent": {"intent": "SEARCH_SERVICE", "confidence": 0.8},
"entities": {},
"similar_services": [],
"search_parameters": {},
"processing_time": 0.1
}
response = client.post("/api/v1/nlp/analyze-query", params={
"query": xss_payload
})
assert response.status_code == 200
data = response.json()
# Script tags should be removed/sanitized
assert "<script>" not in str(data)
def test_command_injection_prevention(self, client: TestClient):
"""Test command injection prevention"""
command_injection = "; rm -rf /"
response = client.post("/api/v1/helpers/process-text", json={
"text": f"find salon{command_injection}",
"latitude": 40.7128,
"longitude": -74.0060
})
# Should handle malicious input safely
assert response.status_code in [200, 400]
def test_path_traversal_prevention(self, client: TestClient):
"""Test path traversal prevention in merchant ID"""
path_traversal = "../../../etc/passwd"
response = client.get(f"/api/v1/merchants/{path_traversal}")
# Should not allow path traversal
assert response.status_code in [400, 404]
def test_large_payload_handling(self, client: TestClient):
"""Test handling of excessively large payloads"""
large_text = "A" * (10 * 1024 * 1024) # 10MB
response = client.post("/api/v1/helpers/process-text", json={
"text": large_text,
"latitude": 40.7128,
"longitude": -74.0060
})
# Should reject or handle large payloads appropriately
assert response.status_code in [400, 413, 422]
def test_invalid_coordinates_handling(self, client: TestClient):
"""Test handling of invalid coordinates"""
invalid_coords = [
(999, 999), # Out of range
(-999, -999), # Out of range
("abc", "def"), # Non-numeric
(None, None) # Null values
]
for lat, lng in invalid_coords:
response = client.get("/api/v1/merchants/search", params={
"latitude": lat,
"longitude": lng,
"radius": 5000
})
# Should handle invalid coordinates gracefully
assert response.status_code in [200, 400, 422]
class TestAuthentication:
"""Test authentication mechanisms (if implemented)"""
def test_unauthenticated_access_to_public_endpoints(self, client: TestClient):
"""Test that public endpoints don't require authentication"""
public_endpoints = [
"/health",
"/api/v1/merchants/",
"/api/v1/merchants/search"
]
for endpoint in public_endpoints:
response = client.get(endpoint)
# Should not require authentication
assert response.status_code != 401
def test_api_key_validation(self, client: TestClient):
"""Test API key validation if implemented"""
# This test assumes API key authentication might be implemented
invalid_api_key = "invalid_key_12345"
response = client.get("/api/v1/merchants/", headers={
"X-API-Key": invalid_api_key
})
# Should either ignore invalid key or reject it
assert response.status_code in [200, 401, 403]
class TestAuthorization:
"""Test authorization and access control"""
def test_admin_endpoint_access(self, client: TestClient):
"""Test access to admin endpoints"""
# Assuming there might be admin endpoints
admin_endpoints = [
"/admin/users",
"/admin/merchants",
"/admin/system"
]
for endpoint in admin_endpoints:
response = client.get(endpoint)
# Should require proper authorization or not exist
assert response.status_code in [401, 403, 404]
def test_user_data_isolation(self, client: TestClient):
"""Test that users can only access their own data"""
# This would be relevant if user-specific data exists
user_id = "user123"
other_user_id = "user456"
# Try to access another user's data
response = client.get(f"/api/v1/users/{other_user_id}/data", headers={
"User-ID": user_id
})
# Should not allow access to other user's data
assert response.status_code in [401, 403, 404]
class TestDataProtection:
"""Test data protection and privacy"""
def test_sensitive_data_not_exposed(self, client: TestClient, sample_merchant_data):
"""Test that sensitive data is not exposed in API responses"""
with patch('app.services.merchant.get_merchants') as mock_get:
# Add sensitive data to mock
merchant_with_sensitive = sample_merchant_data.copy()
merchant_with_sensitive.update({
"internal_id": "INTERNAL_123",
"api_key": "secret_api_key",
"database_password": "secret_password",
"private_notes": "Internal business notes"
})
mock_get.return_value = [merchant_with_sensitive]
response = client.get("/api/v1/merchants/")
assert response.status_code == 200
data = response.json()
response_text = str(data)
# Sensitive fields should not be in response
sensitive_fields = ["api_key", "database_password", "internal_id"]
for field in sensitive_fields:
assert field not in response_text
def test_error_message_information_disclosure(self, client: TestClient):
"""Test that error messages don't disclose sensitive information"""
with patch('app.services.merchant.get_merchants') as mock_get:
# Simulate database error with sensitive info
mock_get.side_effect = Exception("Database connection failed: host=internal-db.company.com, user=admin, password=secret123")
response = client.get("/api/v1/merchants/")
assert response.status_code == 500
# Error response should not contain sensitive database info
error_text = str(response.json())
sensitive_info = ["password=", "host=internal-db", "user=admin"]
for info in sensitive_info:
assert info not in error_text
def test_log_sanitization(self, client: TestClient):
"""Test that logs don't contain sensitive information"""
# This would require checking actual log output
# For now, test that endpoints handle sensitive data properly
sensitive_query = "my credit card is 4111-1111-1111-1111"
response = client.post("/api/v1/nlp/analyze-query", params={
"query": sensitive_query
})
# Should process without exposing sensitive data
assert response.status_code in [200, 400]
class TestCORSAndHeaders:
"""Test CORS and security headers"""
def test_cors_configuration(self, client: TestClient):
"""Test CORS configuration"""
# Test preflight request
response = client.options("/api/v1/merchants/", headers={
"Origin": "http://localhost:3000",
"Access-Control-Request-Method": "GET"
})
# Should handle CORS properly
assert response.status_code in [200, 204]
def test_cors_origin_validation(self, client: TestClient):
"""Test CORS origin validation"""
# Test with allowed origin
response = client.get("/api/v1/merchants/", headers={
"Origin": "http://localhost:3000"
})
assert response.status_code == 200
# Test with disallowed origin
response = client.get("/api/v1/merchants/", headers={
"Origin": "http://malicious-site.com"
})
# Should still work but without CORS headers for invalid origin
assert response.status_code == 200
def test_security_headers(self, client: TestClient):
"""Test security headers in responses"""
response = client.get("/health")
# Check for common security headers
headers = response.headers
# These might not be implemented yet, but good to test for
security_headers = [
"X-Content-Type-Options",
"X-Frame-Options",
"X-XSS-Protection",
"Strict-Transport-Security"
]
# At minimum, should not have dangerous headers
dangerous_headers = [
"Server", # Should not expose server details
"X-Powered-By" # Should not expose technology stack
]
for header in dangerous_headers:
if header in headers:
# If present, should not contain sensitive info
assert "internal" not in headers[header].lower()
assert "secret" not in headers[header].lower()
class TestRateLimiting:
"""Test rate limiting and abuse prevention"""
@pytest.mark.asyncio
async def test_rate_limiting_basic(self, async_client):
"""Test basic rate limiting functionality"""
# Make many requests quickly
responses = []
for _ in range(100):
response = await async_client.get("/health")
responses.append(response)
# Should either all succeed or some be rate limited
status_codes = [r.status_code for r in responses]
# All should be either 200 (OK) or 429 (Too Many Requests)
assert all(code in [200, 429] for code in status_codes)
def test_rate_limiting_per_endpoint(self, client: TestClient):
"""Test rate limiting per endpoint"""
endpoints = [
"/health",
"/api/v1/merchants/",
"/api/v1/nlp/supported-intents"
]
for endpoint in endpoints:
# Make multiple requests to each endpoint
responses = []
for _ in range(20):
response = client.get(endpoint)
responses.append(response)
# Should handle multiple requests appropriately
status_codes = [r.status_code for r in responses]
assert all(code in [200, 429, 500] for code in status_codes)
class TestInputSanitization:
"""Test comprehensive input sanitization"""
def test_html_sanitization(self, client: TestClient):
"""Test HTML tag sanitization"""
html_inputs = [
"<b>bold text</b>",
"<img src='x' onerror='alert(1)'>",
"<iframe src='javascript:alert(1)'></iframe>",
"<<script>alert('xss')</script>script>alert('xss')<</script>/script>"
]
for html_input in html_inputs:
response = client.post("/api/v1/helpers/process-text", json={
"text": html_input,
"latitude": 40.7128,
"longitude": -74.0060
})
# Should handle HTML input safely
assert response.status_code in [200, 400]
if response.status_code == 200:
# Response should not contain dangerous HTML
response_text = str(response.json())
assert "<script>" not in response_text
assert "javascript:" not in response_text
def test_unicode_handling(self, client: TestClient):
"""Test Unicode and special character handling"""
unicode_inputs = [
"café résumé naïve", # Accented characters
"🏪🔍💇‍♀️", # Emojis
"测试中文字符", # Chinese characters
"тест кириллица", # Cyrillic
"\u0000\u0001\u0002", # Control characters
]
for unicode_input in unicode_inputs:
response = client.post("/api/v1/nlp/analyze-query", params={
"query": unicode_input
})
# Should handle Unicode safely
assert response.status_code in [200, 400]
def test_numeric_input_validation(self, client: TestClient):
"""Test numeric input validation"""
invalid_numeric_inputs = [
("latitude", "not_a_number"),
("longitude", "infinity"),
("radius", "-1000"),
("limit", "999999999999999999999"),
("skip", "-1")
]
for param, value in invalid_numeric_inputs:
response = client.get("/api/v1/merchants/search", params={
param: value,
"latitude": 40.7128 if param != "latitude" else value,
"longitude": -74.0060 if param != "longitude" else value
})
# Should validate numeric inputs
assert response.status_code in [200, 400, 422]
class TestDatabaseSecurity:
"""Test database security measures"""
@pytest.mark.asyncio
async def test_mongodb_injection_prevention(self):
"""Test MongoDB injection prevention"""
from app.repositories.db_repository import search_merchants_in_db
# MongoDB injection attempts
injection_attempts = [
{"$where": "this.name == 'test'"},
{"$regex": ".*"},
{"$ne": None}
]
with patch('app.nosql.get_mongodb_client') as mock_client:
mock_collection = MagicMock()
mock_client.return_value.__getitem__.return_value.__getitem__.return_value = mock_collection
mock_collection.find.return_value.limit.return_value.to_list.return_value = []
for injection in injection_attempts:
try:
# Should sanitize or reject injection attempts
await search_merchants_in_db(category=injection)
# If it doesn't raise an exception, check that query was sanitized
call_args = mock_collection.find.call_args
if call_args:
query = call_args[0][0]
# Should not contain MongoDB operators in user input
assert "$where" not in str(query.get("category", ""))
except (ValueError, TypeError):
# Expected for invalid input
pass
def test_connection_string_security(self):
"""Test that database connection strings don't expose credentials"""
from app.nosql import get_mongodb_client
# This test ensures connection strings are properly configured
# In a real scenario, you'd check that credentials aren't hardcoded
client = get_mongodb_client()
# Should have a client instance
assert client is not None
# Connection string should not be exposed in error messages
# This would require triggering a connection error and checking the message
class TestAPISecurityBestPractices:
"""Test API security best practices"""
def test_http_methods_restriction(self, client: TestClient):
"""Test that endpoints only accept appropriate HTTP methods"""
# Test that GET endpoints don't accept POST
response = client.post("/api/v1/merchants/")
assert response.status_code == 405 # Method Not Allowed
# Test that POST endpoints don't accept GET
response = client.get("/api/v1/helpers/process-text")
assert response.status_code == 405 # Method Not Allowed
def test_content_type_validation(self, client: TestClient):
"""Test content type validation"""
# Send JSON data with wrong content type
response = client.post(
"/api/v1/helpers/process-text",
data='{"text": "test"}',
headers={"Content-Type": "text/plain"}
)
# Should reject or handle appropriately
assert response.status_code in [400, 415, 422]
def test_parameter_pollution(self, client: TestClient):
"""Test handling of parameter pollution"""
# Send duplicate parameters
response = client.get("/api/v1/merchants/search?category=salon&category=spa")
# Should handle duplicate parameters appropriately
assert response.status_code in [200, 400]
def test_null_byte_injection(self, client: TestClient):
"""Test null byte injection prevention"""
null_byte_input = "test\x00malicious"
response = client.post("/api/v1/nlp/analyze-query", params={
"query": null_byte_input
})
# Should handle null bytes safely
assert response.status_code in [200, 400]
if response.status_code == 200:
# Response should not contain null bytes
response_text = str(response.json())
assert "\x00" not in response_text