"""
Unit tests for API key redaction functionality.
"""
import unittest
import logging
from unittest.mock import patch, MagicMock
from app.utils.helpers import redact_key_for_logging
from app.log.logger import AccessLogFormatter
class TestKeyRedaction(unittest.TestCase):
    """Checks the output of redact_key_for_logging for a range of inputs."""

    def test_valid_long_key_redaction(self):
        """Long keys are expected to come back as '<first 6>...<last 6>'."""
        # Google/Gemini-style key; a randomly generated string for testing.
        redacted_gemini = redact_key_for_logging(
            "AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI"
        )
        self.assertEqual(redacted_gemini, "AIzaSy...xDfGhI")

        # OpenAI-style key; a randomly generated string for testing.
        redacted_openai = redact_key_for_logging(
            "sk-1234567890abcdef1234567890abcdef1234567890abcdef"
        )
        self.assertEqual(redacted_openai, "sk-123...abcdef")

    def test_short_key_handling(self):
        """Keys of 12 characters or fewer are expected to yield '[SHORT_KEY]'."""
        self.assertEqual(redact_key_for_logging("short"), "[SHORT_KEY]")

        # Exactly 12 characters: the upper edge of the "short" range.
        twelve_chars = "123456789012"
        self.assertEqual(redact_key_for_logging(twelve_chars), "[SHORT_KEY]")

    def test_empty_and_none_keys(self):
        """An empty string and None are both expected to yield '[INVALID_KEY]'."""
        from_empty = redact_key_for_logging("")
        self.assertEqual(from_empty, "[INVALID_KEY]")

        from_none = redact_key_for_logging(None)
        self.assertEqual(from_none, "[INVALID_KEY]")

    def test_invalid_input_types(self):
        """Non-string inputs are expected to yield '[INVALID_KEY]'."""
        # Integer input.
        self.assertEqual(redact_key_for_logging(123), "[INVALID_KEY]")

        # List input.
        self.assertEqual(redact_key_for_logging(["key"]), "[INVALID_KEY]")

        # Dict input.
        self.assertEqual(
            redact_key_for_logging({"key": "value"}),
            "[INVALID_KEY]",
        )

    def test_boundary_cases(self):
        """Lengths just above the short-key threshold, and far above it."""
        # 13 characters: the first length expected to be redacted normally.
        thirteen_chars = "1234567890123"
        self.assertEqual(
            redact_key_for_logging(thirteen_chars),
            "123456...890123",
        )

        # A very long key still collapses to six characters on each side.
        hundred_chars = "a" * 100
        self.assertEqual(
            redact_key_for_logging(hundred_chars),
            "aaaaaa...aaaaaa",
        )
class TestAccessLogFormatter(unittest.TestCase):
    """Test cases for the AccessLogFormatter class.

    The key fixtures below are randomly generated strings used only for
    testing; each is paired with the redacted form the formatter is
    expected to emit in its place.
    """

    GEMINI_KEY = "AIzaSyDhKGfJ8xYzQwErTyUiOpLkMnBvCxDfGhI"
    GEMINI_REDACTED = "AIzaSy...xDfGhI"
    OPENAI_KEY = "sk-1234567890abcdef1234567890abcdef1234567890abcdef"
    OPENAI_REDACTED = "sk-123...abcdef"

    def setUp(self):
        """Create a fresh formatter for every test."""
        self.formatter = AccessLogFormatter()

    def _redact(self, message):
        """Run *message* through the formatter's message-level redaction."""
        return self.formatter._redact_api_keys_in_message(message)

    def test_gemini_key_redaction_in_url(self):
        """A Gemini key in a URL path is replaced by its redacted form."""
        log_message = f'POST /verify-key/{self.GEMINI_KEY} HTTP/1.1" 200'
        result = self._redact(log_message)
        self.assertIn(self.GEMINI_REDACTED, result)
        self.assertNotIn(self.GEMINI_KEY, result)

    def test_openai_key_redaction_in_url(self):
        """An OpenAI key in a query string is replaced by its redacted form."""
        log_message = f'GET /api/models?key={self.OPENAI_KEY} HTTP/1.1" 200'
        result = self._redact(log_message)
        self.assertIn(self.OPENAI_REDACTED, result)
        self.assertNotIn(self.OPENAI_KEY, result)

    def test_multiple_keys_in_message(self):
        """Every key in a message is redacted, not only the first one."""
        log_message = (
            f"Request with keys: {self.GEMINI_KEY} and {self.OPENAI_KEY}"
        )
        result = self._redact(log_message)
        self.assertIn(self.GEMINI_REDACTED, result)
        self.assertIn(self.OPENAI_REDACTED, result)
        self.assertNotIn(self.GEMINI_KEY, result)
        self.assertNotIn(self.OPENAI_KEY, result)

    def test_no_keys_in_message(self):
        """Messages without API keys are returned unchanged."""
        log_message = 'GET /api/health HTTP/1.1" 200'
        self.assertEqual(self._redact(log_message), log_message)

    def test_partial_key_patterns_not_redacted(self):
        """Bare key prefixes without a key body are left untouched."""
        log_message = "Message with partial patterns: AIza sk- incomplete"
        self.assertEqual(self._redact(log_message), log_message)

    def test_error_handling_in_redaction(self):
        """A pattern that raises makes redaction return the error sentinel."""
        # A stand-in pattern whose sub() always raises.
        failing_pattern = MagicMock()
        failing_pattern.sub.side_effect = Exception("Regex error")
        log_message = f'POST /verify-key/{self.GEMINI_KEY} HTTP/1.1" 200'

        # patch.object restores the real patterns on exit, even on failure.
        with patch.object(
            self.formatter, "compiled_patterns", [failing_pattern]
        ):
            result = self._redact(log_message)

        self.assertEqual(result, "[LOG_REDACTION_ERROR]")

    def test_format_method(self):
        """format() redacts keys in the line produced by the base formatter."""
        # Create a mock log record
        record = MagicMock()
        record.getMessage.return_value = (
            f'POST /verify-key/{self.GEMINI_KEY} HTTP/1.1" 200'
        )
        formatted_line = (
            "2025-01-01 12:00:00 | INFO | "
            f'POST /verify-key/{self.GEMINI_KEY} HTTP/1.1" 200'
        )

        # Mock the parent format method so the test controls the exact line
        # that is handed to the redaction step.
        with patch("logging.Formatter.format", return_value=formatted_line):
            result = self.formatter.format(record)

        self.assertIn(self.GEMINI_REDACTED, result)
        self.assertNotIn(self.GEMINI_KEY, result)

    def test_regex_patterns_compilation(self):
        """The formatter exposes two pattern objects that support sub()."""
        patterns = self.formatter.compiled_patterns
        self.assertEqual(len(patterns), 2)
        self.assertTrue(all(hasattr(pattern, "sub") for pattern in patterns))

    def test_flexible_openai_pattern(self):
        """The OpenAI pattern matches several key formats."""
        test_cases = [
            self.OPENAI_KEY,  # Standard 48 chars
            "sk-proj-1234567890abcdef1234567890abcdef1234567890abcdef",  # Project key
            "sk-1234567890abcdef_1234567890abcdef-1234567890abcdef",  # With underscores/hyphens
            "sk-12345678901234567890",  # Shorter key (20 chars)
        ]
        for test_key in test_cases:
            # subTest reports each failing key separately instead of
            # stopping at the first failure.
            with self.subTest(key=test_key):
                result = self._redact(f"Request with key: {test_key}")
                self.assertNotIn(test_key, result)
                self.assertIn("sk-", result)  # Should still contain the prefix
if __name__ == "__main__":
    # Run the tests via unittest's command-line runner when this file is
    # executed directly rather than imported.
    unittest.main()
|