gemini / tests /test_key_redaction.py
JXJBing's picture
Upload 102 files
07a2f32 verified
"""
Unit tests for API key redaction in log output: the redact_key_for_logging helper and the AccessLogFormatter class.
"""
import unittest
import logging
from unittest.mock import patch, MagicMock
from app.utils.helpers import redact_key_for_logging
from app.log.logger import AccessLogFormatter
class TestKeyRedaction(unittest.TestCase):
    """Unit tests for ``redact_key_for_logging``.

    The contract pinned down below:
    * strings longer than 12 characters are shown as their first six and
      last six characters joined by ``...``;
    * strings of 12 characters or fewer become ``[SHORT_KEY]``;
    * empty strings, ``None`` and non-string values become ``[INVALID_KEY]``.
    """

    def test_valid_long_key_redaction(self):
        """Realistic Gemini and OpenAI keys keep only a 6-char head and tail."""
        # Both keys are randomly generated strings used purely as fixtures.
        pairs = (
            # Google/Gemini-style key
            ("AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI", "AIzaSy...xDfGhI"),
            # OpenAI-style key
            (
                "sk-1234567890abcdef1234567890abcdef1234567890abcdef",
                "sk-123...abcdef",
            ),
        )
        for raw_key, redacted in pairs:
            self.assertEqual(redact_key_for_logging(raw_key), redacted)

    def test_short_key_handling(self):
        """Keys of at most 12 characters are masked entirely."""
        # "123456789012" is exactly 12 characters: the longest "short" key.
        for short in ("short", "123456789012"):
            self.assertEqual(redact_key_for_logging(short), "[SHORT_KEY]")

    def test_empty_and_none_keys(self):
        """Empty and missing keys are reported as invalid."""
        for missing in ("", None):
            self.assertEqual(redact_key_for_logging(missing), "[INVALID_KEY]")

    def test_invalid_input_types(self):
        """Non-string inputs (int, list, dict) are reported as invalid."""
        for not_a_string in (123, ["key"], {"key": "value"}):
            self.assertEqual(
                redact_key_for_logging(not_a_string), "[INVALID_KEY]"
            )

    def test_boundary_cases(self):
        """Keys just above the short-key threshold and very long keys."""
        # 13 characters is the shortest input that gets head/tail redaction.
        self.assertEqual(
            redact_key_for_logging("1234567890123"), "123456...890123"
        )
        # However long the key, only six characters survive at each end.
        self.assertEqual(redact_key_for_logging("a" * 100), "aaaaaa...aaaaaa")
class TestAccessLogFormatter(unittest.TestCase):
    """Unit tests for the API-key redaction performed by ``AccessLogFormatter``."""

    # Randomly generated fixture keys (not real credentials) and the redacted
    # forms expected to replace them. Defining them once keeps every test in
    # sync if a fixture ever changes.
    GEMINI_KEY = "AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI"
    GEMINI_REDACTED = "AIzaSy...xDfGhI"
    OPENAI_KEY = "sk-1234567890abcdef1234567890abcdef1234567890abcdef"
    OPENAI_REDACTED = "sk-123...abcdef"

    def setUp(self):
        """Create a fresh formatter for each test so state cannot leak."""
        self.formatter = AccessLogFormatter()

    def test_gemini_key_redaction_in_url(self):
        """A Gemini key embedded in a request path is redacted."""
        log_message = f'POST /verify-key/{self.GEMINI_KEY} HTTP/1.1" 200'
        result = self.formatter._redact_api_keys_in_message(log_message)
        self.assertIn(self.GEMINI_REDACTED, result)
        self.assertNotIn(self.GEMINI_KEY, result)

    def test_openai_key_redaction_in_url(self):
        """An OpenAI key passed as a query parameter is redacted."""
        log_message = f'GET /api/models?key={self.OPENAI_KEY} HTTP/1.1" 200'
        result = self.formatter._redact_api_keys_in_message(log_message)
        self.assertIn(self.OPENAI_REDACTED, result)
        self.assertNotIn(self.OPENAI_KEY, result)

    def test_multiple_keys_in_message(self):
        """Every key in a message is redacted, not just the first match."""
        log_message = f"Request with keys: {self.GEMINI_KEY} and {self.OPENAI_KEY}"
        result = self.formatter._redact_api_keys_in_message(log_message)
        self.assertIn(self.GEMINI_REDACTED, result)
        self.assertIn(self.OPENAI_REDACTED, result)
        self.assertNotIn(self.GEMINI_KEY, result)
        self.assertNotIn(self.OPENAI_KEY, result)

    def test_no_keys_in_message(self):
        """Messages without API keys pass through unchanged."""
        log_message = 'GET /api/health HTTP/1.1" 200'
        result = self.formatter._redact_api_keys_in_message(log_message)
        self.assertEqual(result, log_message)

    def test_partial_key_patterns_not_redacted(self):
        """Bare key prefixes with no key body are not treated as keys."""
        log_message = "Message with partial patterns: AIza sk- incomplete"
        result = self.formatter._redact_api_keys_in_message(log_message)
        self.assertEqual(result, log_message)

    def test_error_handling_in_redaction(self):
        """A failing pattern replaces the whole message with a sentinel.

        The method is expected to return ``[LOG_REDACTION_ERROR]`` instead of
        the message, presumably so the raw key can never reach the log.
        """
        broken_pattern = MagicMock()
        broken_pattern.sub.side_effect = Exception("Regex error")
        log_message = f'POST /verify-key/{self.GEMINI_KEY} HTTP/1.1" 200'
        # patch.object restores the real patterns on exit, even on failure,
        # replacing the hand-written save/try/finally/restore sequence.
        with patch.object(self.formatter, "compiled_patterns", [broken_pattern]):
            result = self.formatter._redact_api_keys_in_message(log_message)
        self.assertEqual(result, "[LOG_REDACTION_ERROR]")

    def test_format_method(self):
        """``format`` redacts keys in the text produced by the base formatter."""
        log_message = f'POST /verify-key/{self.GEMINI_KEY} HTTP/1.1" 200'
        record = MagicMock()
        record.getMessage.return_value = log_message
        # Stub logging.Formatter.format so the test does not depend on a real
        # format string. This assumes AccessLogFormatter.format delegates to it
        # and then redacts the result.
        with patch(
            "logging.Formatter.format",
            return_value=f"2025-01-01 12:00:00 | INFO | {log_message}",
        ):
            result = self.formatter.format(record)
        self.assertIn(self.GEMINI_REDACTED, result)
        self.assertNotIn(self.GEMINI_KEY, result)

    def test_regex_patterns_compilation(self):
        """Constructing a formatter yields two compiled, usable patterns."""
        formatter = AccessLogFormatter()
        self.assertEqual(len(formatter.compiled_patterns), 2)
        self.assertTrue(
            all(hasattr(pattern, "sub") for pattern in formatter.compiled_patterns)
        )

    def test_flexible_openai_pattern(self):
        """The OpenAI pattern matches several key shapes."""
        test_cases = [
            "sk-1234567890abcdef1234567890abcdef1234567890abcdef",  # Standard 48 chars
            "sk-proj-1234567890abcdef1234567890abcdef1234567890abcdef",  # Project key
            "sk-1234567890abcdef_1234567890abcdef-1234567890abcdef",  # With underscores/hyphens
            "sk-12345678901234567890",  # Shorter key (20 chars)
        ]
        for test_key in test_cases:
            # subTest reports which key shape failed and keeps checking the
            # remaining shapes instead of stopping at the first failure.
            with self.subTest(key=test_key):
                log_message = f"Request with key: {test_key}"
                result = self.formatter._redact_api_keys_in_message(log_message)
                self.assertNotIn(test_key, result)
                self.assertIn("sk-", result)  # Should still contain the prefix
if __name__ == "__main__":
    # Run every TestCase in this module when it is executed as a script.
    unittest.main(module="__main__")