Medium-MCP / tests /unit /test_security.py
Nikhil Pravin Pise
feat: implement comprehensive improvement plan (Phases 1-5)
e98cc10
"""
Unit Tests for Security Module
Tests for rate limiting, secrets management, and security headers.
"""
from __future__ import annotations
import os
import time
from unittest.mock import patch
import pytest
from src.security import (
TokenBucketRateLimiter,
RateLimitConfig,
get_rate_limiter,
get_secret,
require_secret,
mask_secret,
generate_token,
generate_api_key,
hash_value,
verify_signature,
get_csp_header,
get_security_headers,
)
class TestTokenBucketRateLimiter:
    """Tests for the token-bucket rate limiter."""

    def test_allows_initial_requests(self) -> None:
        """Should allow requests while under the burst limit."""
        config = RateLimitConfig(requests_per_minute=60, burst_limit=10)
        limiter = TokenBucketRateLimiter(config)
        for _ in range(5):
            assert limiter.check("test")

    def test_blocks_after_burst(self) -> None:
        """Should admit ``burst_limit`` requests, then block the next one."""
        config = RateLimitConfig(requests_per_minute=60, burst_limit=3)
        limiter = TokenBucketRateLimiter(config)
        # Assert every burst request is admitted; without this, a limiter
        # that rejects everything would also satisfy the final assertion.
        for _ in range(3):
            assert limiter.check("test")
        # The burst is used up, so the next request must be rejected.
        assert not limiter.check("test")

    def test_refills_over_time(self) -> None:
        """Should refill tokens over time."""
        # 6000 requests/minute == 100 tokens/second, i.e. one token per 10 ms.
        config = RateLimitConfig(requests_per_minute=6000, burst_limit=1)
        limiter = TokenBucketRateLimiter(config)
        assert limiter.check("test")
        assert not limiter.check("test")
        # Sleep for twice the single-token refill interval to leave margin.
        time.sleep(0.02)
        assert limiter.check("test")

    def test_retry_after(self) -> None:
        """Should report a positive retry delay once the burst is used up."""
        config = RateLimitConfig(requests_per_minute=60, burst_limit=1)
        limiter = TokenBucketRateLimiter(config)
        assert limiter.check("test")
        retry = limiter.get_retry_after("test")
        assert retry > 0

    def test_reset_clears_state(self) -> None:
        """Reset should clear rate limit state for the given key."""
        config = RateLimitConfig(burst_limit=1)
        limiter = TokenBucketRateLimiter(config)
        assert limiter.check("test")
        assert not limiter.check("test")
        limiter.reset("test")
        assert limiter.check("test")

    def test_separate_keys(self) -> None:
        """Different keys should have separate limits."""
        config = RateLimitConfig(burst_limit=2)
        limiter = TokenBucketRateLimiter(config)
        assert limiter.check("user1")
        assert limiter.check("user1")
        assert not limiter.check("user1")
        # Exhausting user1's allowance must not affect user2.
        assert limiter.check("user2")
class TestSecretsManagement:
    """Tests for secrets helper functions."""

    def test_get_secret_from_env(self) -> None:
        """Should read a secret from the environment."""
        with patch.dict(os.environ, {"TEST_SECRET": "my-secret"}):
            assert get_secret("TEST_SECRET") == "my-secret"

    def test_get_secret_default(self) -> None:
        """Should return the default when the secret is not set."""
        # Explicitly remove the variable so the test does not depend on the
        # ambient environment; patch.dict restores os.environ on exit, so
        # this cannot leak into other tests.
        with patch.dict(os.environ):
            os.environ.pop("NONEXISTENT_SECRET", None)
            result = get_secret("NONEXISTENT_SECRET", "default")
        assert result == "default"

    def test_require_secret_raises(self) -> None:
        """Should raise ValueError when a required secret is missing."""
        # Same isolation as above: guarantee the variable is absent.
        with patch.dict(os.environ):
            os.environ.pop("DEFINITELY_NOT_SET", None)
            with pytest.raises(ValueError):
                require_secret("DEFINITELY_NOT_SET")

    def test_mask_secret(self) -> None:
        """Should reveal only the trailing 4 characters of the secret."""
        masked = mask_secret("my-secret-key", 4)
        assert masked == "*************-key"
        assert "secret" not in masked

    def test_mask_short_secret(self) -> None:
        """Should mask secrets no longer than the visible count completely."""
        masked = mask_secret("abc", 4)
        assert masked == "***"
class TestTokenGeneration:
    """Tests covering token and API key generation."""

    def test_generate_token_length(self) -> None:
        """A token built from N bytes is rendered as 2*N hex characters."""
        byte_count = 16
        assert len(generate_token(byte_count)) == byte_count * 2

    def test_generate_token_unique(self) -> None:
        """Repeated calls must not produce duplicate tokens."""
        sample_size = 10
        distinct = {generate_token() for _ in range(sample_size)}
        assert len(distinct) == sample_size

    def test_generate_api_key_format(self) -> None:
        """API keys start with ``mcp_`` and are longer than 10 characters."""
        api_key = generate_api_key()
        assert api_key.startswith("mcp_")
        assert len(api_key) > 10
class TestHashing:
    """Tests for hashing and signature-verification helpers."""

    @staticmethod
    def _sign(payload: str, secret: str) -> str:
        """Return the hex HMAC-SHA256 digest of *payload* keyed by *secret*."""
        import hashlib
        import hmac

        return hmac.new(secret.encode(), payload.encode(), hashlib.sha256).hexdigest()

    def test_hash_value_consistent(self) -> None:
        """Same input should produce same hash."""
        h1 = hash_value("test")
        h2 = hash_value("test")
        assert h1 == h2

    def test_hash_value_different(self) -> None:
        """Different inputs should produce different hashes."""
        h1 = hash_value("test1")
        h2 = hash_value("test2")
        assert h1 != h2

    def test_hash_with_salt(self) -> None:
        """Different salts should produce different hashes."""
        h1 = hash_value("test", salt="salt1")
        h2 = hash_value("test", salt="salt2")
        assert h1 != h2

    def test_verify_signature_valid(self) -> None:
        """A correct HMAC-SHA256 hex signature should verify."""
        payload = "test-payload"
        secret = "my-secret"
        assert verify_signature(payload, self._sign(payload, secret), secret)

    def test_verify_signature_invalid(self) -> None:
        """Malformed or mismatched signatures should fail verification."""
        # A malformed signature (not even a hex digest).
        assert not verify_signature("payload", "bad-signature", "secret")

        # Well-formed signatures must be rejected too when they do not match.
        # A malformed string is rejected by trivial checks, so these cases
        # exercise the actual HMAC comparison.
        payload = "test-payload"
        secret = "my-secret"
        # Signed with a different key.
        assert not verify_signature(payload, self._sign(payload, "other-secret"), secret)
        # Valid signature, but presented with a tampered payload.
        assert not verify_signature("tampered-payload", self._sign(payload, secret), secret)
class TestSecurityHeaders:
    """Tests for the HTTP security-header helpers."""

    def test_get_csp_header(self) -> None:
        """The CSP value must mention the default-src and script-src directives."""
        policy = get_csp_header()
        for directive in ("default-src", "script-src"):
            assert directive in policy

    def test_get_security_headers(self) -> None:
        """All expected headers are present and X-Frame-Options is DENY."""
        headers = get_security_headers()
        required_names = (
            "X-Content-Type-Options",
            "X-Frame-Options",
            "Content-Security-Policy",
        )
        for header_name in required_names:
            assert header_name in headers
        assert headers["X-Frame-Options"] == "DENY"