"""
Security Module

Security utilities for rate limiting, secrets management, and request validation.
"""

from __future__ import annotations

import hashlib
import hmac
import logging
import os
import secrets
import time
from dataclasses import dataclass, field
from functools import wraps
from typing import Optional

logger = logging.getLogger(__name__)


# =============================================================================
# RATE LIMITING
# =============================================================================

# Length of the fixed window used for the per-hour request cap, in seconds.
_HOUR_SECONDS = 3600


@dataclass
class RateLimitConfig:
    """Configuration for rate limiting."""

    requests_per_minute: int = 60  # steady-state token refill rate
    requests_per_hour: int = 1000  # cap per fixed one-hour window
    burst_limit: int = 10  # bucket capacity (max back-to-back requests)


@dataclass
class RateLimitState:
    """Current state of the rate limiter for a single key."""

    tokens: float = 0.0  # tokens currently in the bucket
    last_update: float = 0.0  # time.time() of the last refill
    hourly_count: int = 0  # requests accepted in the current hourly window
    hourly_reset: float = 0.0  # time.time() at which the hourly window resets


class TokenBucketRateLimiter:
    """
    Token bucket rate limiter.

    Allows bursts up to bucket size while maintaining average rate over time.
    On top of the bucket, a fixed one-hour window caps the total number of
    accepted requests per key.

    NOTE(review): one state entry is kept per distinct key and is only removed
    by reset(); if keys come from an unbounded space (e.g. client IPs) memory
    grows without limit — consider pruning idle keys.
    """

    def __init__(self, config: Optional[RateLimitConfig] = None) -> None:
        self.config = config or RateLimitConfig()
        self._states: dict[str, RateLimitState] = {}

    def _get_state(self, key: str) -> RateLimitState:
        """Get or create state for key; new keys start with a full bucket."""
        state = self._states.get(key)
        if state is None:
            now = time.time()
            state = RateLimitState(
                tokens=float(self.config.burst_limit),
                last_update=now,
                hourly_reset=now + _HOUR_SECONDS,
            )
            self._states[key] = state
        return state

    def _refill_rate(self) -> float:
        """Tokens added to the bucket per second."""
        return self.config.requests_per_minute / 60.0

    def _available_tokens(self, state: RateLimitState, now: float) -> float:
        """Tokens the bucket would hold at ``now``; does not mutate ``state``."""
        # Clamp so a backwards wall-clock jump cannot drain the bucket.
        elapsed = max(0.0, now - state.last_update)
        return min(
            float(self.config.burst_limit),
            state.tokens + elapsed * self._refill_rate(),
        )

    def check(self, key: str = "default") -> bool:
        """
        Check if request is allowed, consuming one token if it is.

        Args:
            key: Rate limit key (e.g., IP address, user ID)

        Returns:
            True if allowed, False if rate limited
        """
        now = time.time()
        state = self._get_state(key)

        # Refill tokens based on elapsed time
        state.tokens = self._available_tokens(state, now)
        state.last_update = now

        # Start a new hourly window once the current one has expired
        if now >= state.hourly_reset:
            state.hourly_count = 0
            state.hourly_reset = now + _HOUR_SECONDS

        if state.hourly_count >= self.config.requests_per_hour:
            logger.warning("Hourly rate limit exceeded for %s", key)
            return False

        if state.tokens < 1:
            logger.warning("Rate limit exceeded for %s", key)
            return False

        state.tokens -= 1
        state.hourly_count += 1
        return True

    def get_retry_after(self, key: str = "default") -> int:
        """
        Get seconds until next request is allowed.

        Considers both the token bucket (refilled up to the current time,
        without mutating stored state) and the hourly cap, and returns the
        longer of the two waits.

        Args:
            key: Rate limit key

        Returns:
            Seconds to wait (0 if a request would be allowed now)
        """
        now = time.time()
        state = self._get_state(key)

        # Hourly cap: blocked until the current window expires.
        hourly_wait = 0
        if (
            state.hourly_count >= self.config.requests_per_hour
            and now < state.hourly_reset
        ):
            hourly_wait = int(state.hourly_reset - now) + 1

        tokens = self._available_tokens(state, now)
        rate = self._refill_rate()
        if tokens >= 1:
            token_wait = 0
        elif rate > 0:
            token_wait = int((1 - tokens) / rate) + 1
        else:
            # With a non-positive rate the bucket never refills, so no finite
            # wait exists; report the time left in the hourly window as a
            # best-effort hint instead of dividing by zero.
            token_wait = int(max(0.0, state.hourly_reset - now)) + 1

        return max(hourly_wait, token_wait)

    def reset(self, key: str = "default") -> None:
        """Reset rate limiter for key (no-op if the key is unknown)."""
        self._states.pop(key, None)


# Global rate limiter instance
_rate_limiter: Optional[TokenBucketRateLimiter] = None


def get_rate_limiter() -> TokenBucketRateLimiter:
    """Get or create the process-wide rate limiter (default config)."""
    global _rate_limiter
    if _rate_limiter is None:
        _rate_limiter = TokenBucketRateLimiter()
    return _rate_limiter


def rate_limit(key: str = "default"):
    """
    Rate limiting decorator for async callables.

    Every call to a decorated function consumes from the global limiter's
    bucket named ``key``; all functions decorated with the same key share
    one bucket.

    Args:
        key: Rate limit key

    Raises:
        src.exceptions.RateLimitError: when the limiter rejects the call,
            carrying the suggested ``retry_after`` in seconds.
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Imported lazily; presumably avoids an import cycle — verify.
            from src.exceptions import RateLimitError

            limiter = get_rate_limiter()
            if not limiter.check(key):
                retry_after = limiter.get_retry_after(key)
                raise RateLimitError(
                    "Rate limit exceeded",
                    retry_after=retry_after,
                )
            return await func(*args, **kwargs)

        return wrapper

    return decorator


# =============================================================================
# SECRETS MANAGEMENT
# =============================================================================


def get_secret(name: str, default: Optional[str] = None) -> Optional[str]:
    """
    Get a secret from environment variables.

    An environment variable that is set but empty is treated as missing.

    Args:
        name: Secret name (environment variable)
        default: Default value if not found

    Returns:
        Secret value, or ``default``, or None (with a warning logged) when
        neither is available
    """
    value = os.environ.get(name)
    if value:
        return value
    if default is not None:
        return default
    logger.warning("Secret %s not found in environment", name)
    return None


def require_secret(name: str) -> str:
    """
    Get a required secret from environment.

    Args:
        name: Secret name

    Returns:
        Secret value

    Raises:
        ValueError: If secret not found (or set to an empty string)
    """
    value = get_secret(name)
    if not value:
        raise ValueError(f"Required secret {name} not configured")
    return value


def mask_secret(value: str, visible_chars: int = 4) -> str:
    """
    Mask a secret for logging.

    Args:
        value: Secret to mask
        visible_chars: Number of chars to show at end; values <= 0 mask
            everything

    Returns:
        Masked string (e.g., "****xyz"); secrets no longer than
        ``visible_chars`` are fully masked
    """
    if not value:
        return ""
    # Guard <= 0: value[-0:] is the *whole* string, which would leak it.
    if visible_chars <= 0 or len(value) <= visible_chars:
        return "*" * len(value)
    return "*" * (len(value) - visible_chars) + value[-visible_chars:]


# =============================================================================
# TOKEN GENERATION
# =============================================================================


def generate_token(length: int = 32) -> str:
    """
    Generate a cryptographically secure random token.

    Args:
        length: Token length in bytes (will be hex-encoded)

    Returns:
        Hex-encoded random token (``2 * length`` characters)
    """
    return secrets.token_hex(length)


def generate_api_key() -> str:
    """Generate an API key in format 'mcp_xxx...' (24 random bytes, hex)."""
    return f"mcp_{secrets.token_hex(24)}"


# =============================================================================
# HASH FUNCTIONS
# =============================================================================


def hash_value(value: str, salt: Optional[str] = None) -> str:
    """
    Hash a value using SHA-256.

    NOTE(review): a single salted SHA-256 is fast by design and therefore not
    suitable for password storage; use a KDF (e.g. hashlib.scrypt) for that.

    Args:
        value: Value to hash
        salt: Optional salt, prepended as "salt:value" (empty salt is ignored)

    Returns:
        Hex-encoded hash
    """
    if salt:
        value = f"{salt}:{value}"
    return hashlib.sha256(value.encode()).hexdigest()


def verify_signature(
    payload: str,
    signature: str,
    secret: str,
) -> bool:
    """
    Verify HMAC-SHA256 signature in constant time.

    Args:
        payload: Signed payload
        signature: Expected signature (lowercase hex-encoded)
        secret: Signing secret

    Returns:
        True if signature is valid; False otherwise, including for
        signatures containing non-ASCII characters
    """
    expected = hmac.new(
        secret.encode(),
        payload.encode(),
        hashlib.sha256,
    ).hexdigest()
    # Compare bytes: compare_digest raises TypeError for str arguments that
    # contain non-ASCII characters, and the signature is typically untrusted.
    # "replace" can only introduce non-hex bytes, so it cannot forge a match.
    return hmac.compare_digest(
        expected.encode("ascii"),
        signature.encode("utf-8", "replace"),
    )


# =============================================================================
# SECURITY HEADERS
# =============================================================================

SECURITY_HEADERS = {
    "X-Content-Type-Options": "nosniff",
    "X-Frame-Options": "DENY",
    # NOTE(review): this header is deprecated; OWASP now recommends "0" (or
    # omitting it) and relying on Content-Security-Policy instead.
    "X-XSS-Protection": "1; mode=block",
    "Referrer-Policy": "strict-origin-when-cross-origin",
    "Permissions-Policy": "geolocation=(), microphone=(), camera=()",
}

CSP_DIRECTIVES = {
    "default-src": "'self'",
    # NOTE(review): 'unsafe-inline' in script-src largely defeats CSP's XSS
    # protection; consider nonces or hashes.
    "script-src": "'self' 'unsafe-inline' https://cdn.jsdelivr.net",
    "style-src": "'self' 'unsafe-inline' https://fonts.googleapis.com",
    "font-src": "'self' https://fonts.gstatic.com",
    "img-src": "'self' data: https: blob:",
    "connect-src": "'self' https://api.elevenlabs.io https://api.groq.com",
}


def get_csp_header() -> str:
    """Get Content-Security-Policy header value built from CSP_DIRECTIVES."""
    return "; ".join(f"{k} {v}" for k, v in CSP_DIRECTIVES.items())


def get_security_headers() -> dict[str, str]:
    """Get all security headers including CSP, as a fresh dict."""
    headers = SECURITY_HEADERS.copy()
    headers["Content-Security-Policy"] = get_csp_header()
    return headers