# Medium-MCP / src/security.py
# Nikhil Pravin Pise
# feat: implement comprehensive improvement plan (Phases 1-5)
# e98cc10
"""
Security Module

Security utilities for rate limiting, secrets management, token
generation, hashing, HMAC request-signature verification, and HTTP
security headers.
"""
from __future__ import annotations
import hashlib
import hmac
import logging
import os
import secrets
import time
from dataclasses import dataclass, field
from typing import Optional
from functools import wraps
logger = logging.getLogger(__name__)
# =============================================================================
# RATE LIMITING
# =============================================================================
@dataclass
class RateLimitConfig:
    """Tunable limits consumed by ``TokenBucketRateLimiter``."""

    # Token refill rate of the bucket (tokens added per minute).
    requests_per_minute: int = field(default=60)
    # Maximum requests admitted per key within one hourly window.
    requests_per_hour: int = field(default=1000)
    # Bucket capacity, i.e. the largest burst allowed at once.
    burst_limit: int = field(default=10)
@dataclass
class RateLimitState:
    """Mutable per-key bookkeeping kept by ``TokenBucketRateLimiter``."""

    # Tokens currently in the bucket (fractional while refilling).
    tokens: float = field(default=0.0)
    # Timestamp of the most recent refill.
    last_update: float = field(default=0.0)
    # Requests admitted during the current hourly window.
    hourly_count: int = field(default=0)
    # Timestamp at which the current hourly window expires.
    hourly_reset: float = field(default=0.0)
class TokenBucketRateLimiter:
    """
    Token bucket rate limiter.

    Each key gets a bucket holding up to ``config.burst_limit`` tokens that
    refills at ``config.requests_per_minute`` tokens per minute, so short
    bursts are allowed while the average rate is bounded. Independently,
    each key may make at most ``config.requests_per_hour`` requests per
    hourly window. The window starts when the key is first seen and restarts
    on the first check after it expires.

    State lives in memory on this instance. Entries are only removed by
    ``reset``.
    """

    # Length of the hourly counting window, in seconds.
    _HOUR_SECONDS = 3600.0

    def __init__(self, config: Optional[RateLimitConfig] = None) -> None:
        self.config = config or RateLimitConfig()
        self._states: dict[str, RateLimitState] = {}

    @staticmethod
    def _now() -> float:
        # Monotonic clock: interval math must not be disturbed by wall-clock
        # jumps (NTP corrections, manual changes), which time.time() is.
        return time.monotonic()

    def _get_state(self, key: str) -> RateLimitState:
        """Get or create the state for ``key``; new keys start with a full bucket."""
        state = self._states.get(key)
        if state is None:
            now = self._now()
            state = RateLimitState(
                tokens=float(self.config.burst_limit),
                last_update=now,
                hourly_reset=now + self._HOUR_SECONDS,
            )
            self._states[key] = state
        return state

    def _refill(self, state: RateLimitState, now: float) -> None:
        """Advance ``state`` to ``now``: add earned tokens and roll the hourly window."""
        rate = self.config.requests_per_minute / 60.0
        elapsed = max(0.0, now - state.last_update)
        state.tokens = min(
            float(self.config.burst_limit),
            state.tokens + elapsed * rate,
        )
        state.last_update = now
        if now >= state.hourly_reset:
            state.hourly_count = 0
            state.hourly_reset = now + self._HOUR_SECONDS

    def check(self, key: str = "default") -> bool:
        """
        Check if a request is allowed and, if so, consume one token.

        Args:
            key: Rate limit key (e.g., IP address, user ID)

        Returns:
            True if allowed, False if rate limited (hourly cap reached or
            bucket empty).
        """
        now = self._now()
        state = self._get_state(key)
        self._refill(state, now)

        if state.hourly_count >= self.config.requests_per_hour:
            logger.warning("Hourly rate limit exceeded for %s", key)
            return False
        if state.tokens < 1:
            logger.warning("Rate limit exceeded for %s", key)
            return False

        state.tokens -= 1
        state.hourly_count += 1
        return True

    def get_retry_after(self, key: str = "default") -> int:
        """
        Get seconds until the next request for ``key`` would be allowed.

        Accounts for both limits. If the hourly cap is exhausted, the wait
        lasts until the hourly window expires, even when tokens remain.

        Args:
            key: Rate limit key

        Returns:
            Seconds to wait, rounded up; 0 if a request would be allowed now
            or the key has never been seen.
        """
        state = self._states.get(key)
        if state is None:
            # Unknown key: its first request would get a full bucket.
            return 0
        now = self._now()
        # Bring tokens/window up to date so the estimate is not based on a
        # stale token count from the last check().
        self._refill(state, now)

        wait = 0.0
        if state.hourly_count >= self.config.requests_per_hour:
            wait = state.hourly_reset - now
        if state.tokens < 1:
            rate = self.config.requests_per_minute / 60.0
            wait = max(wait, (1 - state.tokens) / rate)

        if wait <= 0:
            return 0
        # Round up conservatively so callers never retry too early.
        return int(wait) + 1

    def reset(self, key: str = "default") -> None:
        """Reset rate limiter for key (no-op for unknown keys)."""
        self._states.pop(key, None)
# Process-wide limiter shared by the rate_limit decorator; built lazily.
_rate_limiter: Optional[TokenBucketRateLimiter] = None
def get_rate_limiter() -> TokenBucketRateLimiter:
    """Return the shared limiter, creating it with the default config on first use."""
    global _rate_limiter
    if _rate_limiter is not None:
        return _rate_limiter
    _rate_limiter = TokenBucketRateLimiter()
    return _rate_limiter
def rate_limit(key: str = "default"):
    """
    Rate limiting decorator backed by the global limiter.

    Works on both ``async def`` functions and regular callables. The
    original async-only wrapper broke regular functions by awaiting their
    non-awaitable return value. Every call to any function decorated with
    the same ``key`` draws from the same bucket.

    Args:
        key: Rate limit key

    Raises (from the decorated call):
        RateLimitError: When the limit is exceeded; carries ``retry_after``
            seconds from the limiter.
    """
    import inspect

    def _enforce() -> None:
        limiter = get_rate_limiter()
        if not limiter.check(key):
            # Imported lazily as in the original (presumably to avoid an
            # import cycle — TODO confirm); only needed on rejection.
            from src.exceptions import RateLimitError
            retry_after = limiter.get_retry_after(key)
            raise RateLimitError(
                "Rate limit exceeded",
                retry_after=retry_after
            )

    def decorator(func):
        if inspect.iscoroutinefunction(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                _enforce()
                return await func(*args, **kwargs)
            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            _enforce()
            return func(*args, **kwargs)
        return sync_wrapper

    return decorator
# =============================================================================
# SECRETS MANAGEMENT
# =============================================================================
def get_secret(name: str, default: Optional[str] = None) -> Optional[str]:
    """
    Read a secret from the process environment.

    An unset or empty variable counts as missing. In that case ``default``
    is returned; a warning is logged only when no default was supplied.

    Args:
        name: Environment variable holding the secret
        default: Value returned when the variable is missing

    Returns:
        The secret, ``default``, or None
    """
    value = os.environ.get(name)
    if value:
        return value
    if default is None:
        logger.warning(f"Secret {name} not found in environment")
    return default
def require_secret(name: str) -> str:
    """
    Read a secret that must be present in the environment.

    Args:
        name: Environment variable holding the secret

    Returns:
        The non-empty secret value

    Raises:
        ValueError: If the variable is unset or empty
    """
    secret = get_secret(name)
    if secret:
        return secret
    raise ValueError(f"Required secret {name} not configured")
def mask_secret(value: str, visible_chars: int = 4) -> str:
    """
    Mask a secret for logging.

    Args:
        value: Secret to mask
        visible_chars: Number of trailing characters to leave visible.
            Values <= 0 mask the entire secret.

    Returns:
        Masked string of the same length as ``value`` (e.g. ``"******wxyz"``
        for ``"abcdefwxyz"`` with the default of 4). Returns ``""`` for an
        empty value. Secrets no longer than ``visible_chars`` are fully
        masked.
    """
    if not value:
        return ""
    # visible_chars <= 0 must be guarded: value[-0:] is the *whole* string,
    # so the slice below would otherwise emit the secret in clear text.
    if visible_chars <= 0 or len(value) <= visible_chars:
        return "*" * len(value)
    hidden = len(value) - visible_chars
    return "*" * hidden + value[-visible_chars:]
# =============================================================================
# TOKEN GENERATION
# =============================================================================
def generate_token(length: int = 32) -> str:
    """
    Produce a random token from the ``secrets`` CSPRNG.

    Args:
        length: Number of random bytes; the result has ``2 * length`` hex chars

    Returns:
        Lower-case hex string
    """
    token = secrets.token_hex(nbytes=length)
    return token
def generate_api_key() -> str:
    """Return a new API key: ``mcp_`` followed by 48 random hex characters."""
    prefix = "mcp_"
    return prefix + secrets.token_hex(24)
# =============================================================================
# HASH FUNCTIONS
# =============================================================================
def hash_value(value: str, salt: Optional[str] = None) -> str:
    """
    Return the SHA-256 hex digest of ``value``.

    If a non-empty ``salt`` is given, the hashed text is ``"<salt>:<value>"``.

    Args:
        value: Text to hash (UTF-8 encoded)
        salt: Optional salt prefix

    Returns:
        64-character hex digest
    """
    material = f"{salt}:{value}" if salt else value
    digest = hashlib.sha256(material.encode())
    return digest.hexdigest()
def verify_signature(
    payload: str,
    signature: str,
    secret: str,
) -> bool:
    """
    Verify an HMAC-SHA256 signature in constant time.

    Args:
        payload: Signed payload
        signature: Signature to check (lower-case hex, as produced by
            ``hexdigest()``); typically untrusted, caller-supplied input
        secret: Signing secret

    Returns:
        True if signature is valid. Returns False for a mismatch, including
        non-str or non-ASCII signatures, instead of raising.
    """
    if not isinstance(signature, str):
        return False
    expected = hmac.new(
        secret.encode(),
        payload.encode(),
        hashlib.sha256
    ).hexdigest()
    # Compare as bytes: hmac.compare_digest raises TypeError on str
    # arguments containing non-ASCII characters, which an attacker controls.
    return hmac.compare_digest(expected.encode(), signature.encode())
# =============================================================================
# SECURITY HEADERS
# =============================================================================
# Static security response headers. Content-Security-Policy is added
# separately by get_security_headers().
SECURITY_HEADERS = {
    "X-Content-Type-Options": "nosniff",
    "X-Frame-Options": "DENY",
    # Explicitly disable the legacy browser XSS auditor. "1; mode=block" is
    # deprecated, and the auditor itself can be abused to introduce XSS and
    # information leaks. OWASP recommends "0" and relying on CSP instead.
    "X-XSS-Protection": "0",
    "Referrer-Policy": "strict-origin-when-cross-origin",
    "Permissions-Policy": "geolocation=(), microphone=(), camera=()",
}
# Content-Security-Policy directives as (directive, sources) pairs. Insertion
# order is the order in which they are rendered into the header.
# NOTE(review): 'unsafe-inline' in script-src permits inline scripts, which
# removes much of CSP's XSS protection. Consider nonces or hashes if the
# frontend allows it.
CSP_DIRECTIVES = dict(
    [
        ("default-src", "'self'"),
        ("script-src", "'self' 'unsafe-inline' https://cdn.jsdelivr.net"),
        ("style-src", "'self' 'unsafe-inline' https://fonts.googleapis.com"),
        ("font-src", "'self' https://fonts.gstatic.com"),
        ("img-src", "'self' data: https: blob:"),
        ("connect-src", "'self' https://api.elevenlabs.io https://api.groq.com"),
    ]
)
def get_csp_header() -> str:
    """
    Render ``CSP_DIRECTIVES`` as a Content-Security-Policy header value.

    Each directive becomes ``"<name> <sources>"``, in dict order, and the
    directives are separated by ``"; "``.
    """
    rendered = [f"{name} {sources}" for name, sources in CSP_DIRECTIVES.items()]
    return "; ".join(rendered)
def get_security_headers() -> dict[str, str]:
    """
    Return a new dict with every security header, Content-Security-Policy included.

    ``SECURITY_HEADERS`` is copied rather than mutated, so callers may
    modify the result freely.
    """
    return {**SECURITY_HEADERS, "Content-Security-Policy": get_csp_header()}