Spaces:
Sleeping
Sleeping
| """ | |
| Security Module | |
| Security utilities for rate limiting, secrets management, | |
| and request validation. | |
| """ | |
| from __future__ import annotations | |
| import hashlib | |
| import hmac | |
| import logging | |
| import os | |
| import secrets | |
| import time | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
| from functools import wraps | |
# Module-level logger; used below to report rate-limit rejections and missing secrets.
logger = logging.getLogger(__name__)
| # ============================================================================= | |
| # RATE LIMITING | |
| # ============================================================================= | |
@dataclass
class RateLimitConfig:
    """
    Configuration for rate limiting.

    Attributes:
        requests_per_minute: Sustained rate; bucket tokens refill at this many
            per minute.
        requests_per_hour: Cap on allowed requests per key within one hourly
            window.
        burst_limit: Bucket capacity, i.e. the most requests a key may make
            back-to-back after being idle.

    Raises:
        ValueError: If any limit is not positive.
    """

    requests_per_minute: int = 60
    requests_per_hour: int = 1000
    burst_limit: int = 10

    def __post_init__(self) -> None:
        # A non-positive refill rate makes retry-delay computation divide by
        # zero, and a zero-sized bucket or hourly cap rejects every request.
        for name in ("requests_per_minute", "requests_per_hour", "burst_limit"):
            amount = getattr(self, name)
            if amount <= 0:
                raise ValueError(f"{name} must be positive, got {amount!r}")
@dataclass
class RateLimitState:
    """
    Current per-key state of the token bucket rate limiter.

    Constructed with keyword arguments by ``TokenBucketRateLimiter`` and
    mutated in place on every check.

    Attributes:
        tokens: Tokens currently in the bucket (fractional between refills).
        last_update: ``time.time()`` timestamp of the last refill.
        hourly_count: Requests allowed in the current hourly window.
        hourly_reset: ``time.time()`` timestamp at which the window ends.
    """

    tokens: float = 0.0
    last_update: float = 0.0
    hourly_count: int = 0
    hourly_reset: float = 0.0
class TokenBucketRateLimiter:
    """
    Token bucket rate limiter with an additional hourly cap.

    Each key has a bucket holding at most ``burst_limit`` tokens, refilled
    continuously at ``requests_per_minute / 60`` tokens per second. Every
    allowed request consumes one token. Independently, at most
    ``requests_per_hour`` requests are allowed per key in a one-hour window.
    The window starts when the key is first seen and restarts on the first
    check after it expires.

    State lives in an in-process dict; this class only removes entries in
    ``reset``.
    """

    def __init__(self, config: Optional[RateLimitConfig] = None) -> None:
        self.config = config or RateLimitConfig()
        # Per-key state, created lazily on first check().
        self._states: dict[str, RateLimitState] = {}

    def _get_state(self, key: str) -> RateLimitState:
        """Get the state for *key*, creating a full bucket if it is new."""
        if key not in self._states:
            now = time.time()
            self._states[key] = RateLimitState(
                tokens=float(self.config.burst_limit),
                last_update=now,
                hourly_reset=now + 3600,
            )
        return self._states[key]

    def _refilled_tokens(self, state: RateLimitState, now: float) -> float:
        """Return *state*'s token count refilled up to *now* (state unchanged)."""
        rate = self.config.requests_per_minute / 60.0
        elapsed = now - state.last_update
        # Cap at the burst size so an idle key cannot bank unlimited requests.
        return min(float(self.config.burst_limit), state.tokens + elapsed * rate)

    def check(self, key: str = "default") -> bool:
        """
        Check if a request is allowed and, if so, record it.

        Args:
            key: Rate limit key (e.g., IP address, user ID)

        Returns:
            True if allowed (one token and one hourly slot are consumed),
            False if rate limited (nothing is consumed).
        """
        now = time.time()
        state = self._get_state(key)

        state.tokens = self._refilled_tokens(state, now)
        state.last_update = now

        # Start a new hourly window once the previous one has expired.
        if now >= state.hourly_reset:
            state.hourly_count = 0
            state.hourly_reset = now + 3600

        if state.hourly_count >= self.config.requests_per_hour:
            logger.warning("Hourly rate limit exceeded for %s", key)
            return False

        if state.tokens < 1:
            logger.warning("Rate limit exceeded for %s", key)
            return False

        state.tokens -= 1
        state.hourly_count += 1
        return True

    def get_retry_after(self, key: str = "default") -> int:
        """
        Get seconds until the next request for *key* would be allowed.

        Both limits are considered: the token bucket (refilled up to now,
        without mutating state) and the hourly cap. Unknown keys are not
        added to the state table.

        Args:
            key: Rate limit key

        Returns:
            0 if a request would be allowed now; otherwise whole seconds to
            wait, rounded up (integral waits get one extra second).
        """
        state = self._states.get(key)
        if state is None:
            # A new key starts with a full bucket and an empty hourly window.
            return 0

        now = time.time()
        wait = 0.0

        # Hourly cap exhausted: nothing is allowed until the window rolls
        # over, no matter how many tokens the bucket holds.
        if (
            now < state.hourly_reset
            and state.hourly_count >= self.config.requests_per_hour
        ):
            wait = state.hourly_reset - now

        tokens = self._refilled_tokens(state, now)
        if tokens < 1:
            rate = self.config.requests_per_minute / 60.0
            wait = max(wait, (1 - tokens) / rate)

        return int(wait) + 1 if wait > 0 else 0

    def reset(self, key: str = "default") -> None:
        """Reset rate limiter for key (drops its bucket and hourly window)."""
        self._states.pop(key, None)
# Process-wide limiter shared by the ``rate_limit`` decorator; built lazily.
_rate_limiter: Optional[TokenBucketRateLimiter] = None


def get_rate_limiter() -> TokenBucketRateLimiter:
    """Return the shared limiter, constructing it with default config on first use."""
    global _rate_limiter
    limiter = _rate_limiter
    if limiter is None:
        limiter = _rate_limiter = TokenBucketRateLimiter()
    return limiter
def rate_limit(key: str = "default"):
    """
    Rate limiting decorator for async callables.

    Every call to the decorated coroutine function is checked against the
    shared limiter from ``get_rate_limiter()`` under the fixed *key*. All
    functions decorated with the same key therefore share one bucket.

    Args:
        key: Rate limit key

    Raises (when the decorated function is called):
        RateLimitError: If the limit for *key* is exceeded; ``retry_after``
            is taken from the limiter.
    """

    def decorator(func):
        # functools.wraps copies __name__/__doc__ and sets __wrapped__, so
        # tools that introspect the handler (e.g. inspect.signature) see the
        # original function instead of ``wrapper(*args, **kwargs)``.
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Imported lazily, presumably to avoid an import cycle with
            # src.exceptions — TODO confirm.
            from src.exceptions import RateLimitError

            limiter = get_rate_limiter()
            if not limiter.check(key):
                raise RateLimitError(
                    "Rate limit exceeded",
                    retry_after=limiter.get_retry_after(key),
                )
            return await func(*args, **kwargs)

        return wrapper

    return decorator
| # ============================================================================= | |
| # SECRETS MANAGEMENT | |
| # ============================================================================= | |
def get_secret(name: str, default: Optional[str] = None) -> Optional[str]:
    """
    Look up a secret in the process environment.

    An empty environment value counts as missing. A warning is logged only
    when the secret is missing and no default was supplied.

    Args:
        name: Environment variable holding the secret
        default: Value returned when the variable is unset or empty

    Returns:
        The environment value, else *default* (which may be None)
    """
    found = os.environ.get(name)
    if found:
        return found
    if default is None:
        logger.warning(f"Secret {name} not found in environment")
    return default
def require_secret(name: str) -> str:
    """
    Fetch a secret that must be present in the environment.

    Args:
        name: Environment variable holding the secret

    Returns:
        The non-empty secret value

    Raises:
        ValueError: If the variable is unset or empty
    """
    secret = get_secret(name)
    if secret:
        return secret
    raise ValueError(f"Required secret {name} not configured")
def mask_secret(value: str, visible_chars: int = 4) -> str:
    """
    Mask a secret for logging.

    Args:
        value: Secret to mask
        visible_chars: Number of trailing chars to leave visible; values
            <= 0 mask the whole secret.

    Returns:
        Masked string of the same length (e.g., "****5678" for "12345678").
        Secrets no longer than *visible_chars* are masked completely, and an
        empty/None value yields "".
    """
    if not value:
        return ""
    # Clamp first: value[-0:] is the *whole* string, so a non-positive count
    # must never reach a negative-index slice or the secret leaks in clear.
    shown = max(visible_chars, 0)
    if len(value) <= shown:
        return "*" * len(value)
    hidden = len(value) - shown
    return "*" * hidden + value[hidden:]
| # ============================================================================= | |
| # TOKEN GENERATION | |
| # ============================================================================= | |
def generate_token(length: int = 32) -> str:
    """
    Return a cryptographically secure random token as lowercase hex.

    Args:
        length: Number of random bytes; the result has twice as many hex chars

    Returns:
        Hex encoding of *length* bytes drawn from the ``secrets`` CSPRNG
    """
    raw = secrets.token_bytes(length)
    return raw.hex()
def generate_api_key() -> str:
    """Return a new random API key: ``mcp_`` followed by 48 lowercase hex chars."""
    prefix = "mcp_"
    return prefix + secrets.token_hex(24)
| # ============================================================================= | |
| # HASH FUNCTIONS | |
| # ============================================================================= | |
def hash_value(value: str, salt: Optional[str] = None) -> str:
    """
    Return the SHA-256 hex digest of *value*, optionally salted.

    When *salt* is truthy the hashed text is ``"<salt>:<value>"``; an empty
    or missing salt hashes *value* alone. Text is UTF-8 encoded.

    NOTE(review): a single fast SHA-256 is not suitable for password storage;
    if callers hash passwords, a KDF (hashlib.scrypt / pbkdf2_hmac) fits
    better — confirm usage.

    Args:
        value: Text to hash
        salt: Optional salt prefix

    Returns:
        64-character lowercase hex digest
    """
    material = f"{salt}:{value}" if salt else value
    digest = hashlib.sha256(material.encode())
    return digest.hexdigest()
def verify_signature(
    payload: str,
    signature: str,
    secret: str,
) -> bool:
    """
    Verify an HMAC-SHA256 signature in constant time.

    Args:
        payload: Signed payload (UTF-8 encoded before hashing)
        signature: Signature to check, as a lowercase hex digest
        secret: Signing secret (UTF-8 encoded)

    Returns:
        True if *signature* equals hex(HMAC-SHA256(secret, payload)); False
        for any mismatch, including non-ASCII or otherwise malformed input.
    """
    expected = hmac.new(
        secret.encode(),
        payload.encode(),
        hashlib.sha256,
    ).hexdigest()
    # Compare bytes, not str: hmac.compare_digest raises TypeError for str
    # arguments with non-ASCII characters. *signature* is typically
    # attacker-controlled, so that would turn a bad signature into a crash.
    # "surrogatepass" keeps lone surrogates from raising during encoding;
    # they can never match a hex digest anyway.
    provided = signature.encode("utf-8", "surrogatepass")
    return hmac.compare_digest(expected.encode("ascii"), provided)
| # ============================================================================= | |
| # SECURITY HEADERS | |
| # ============================================================================= | |
# Static security response headers; get_security_headers() returns a copy of
# these with Content-Security-Policy added.
SECURITY_HEADERS = {
    "X-Content-Type-Options": "nosniff",
    "X-Frame-Options": "DENY",
    # NOTE(review): current OWASP/MDN guidance is to send "0" (or omit this
    # header) and rely on CSP, because the legacy browser XSS auditor could
    # itself be abused — confirm "1; mode=block" is still intended.
    "X-XSS-Protection": "1; mode=block",
    "Referrer-Policy": "strict-origin-when-cross-origin",
    "Permissions-Policy": "geolocation=(), microphone=(), camera=()",
}

# Content-Security-Policy directives. get_csp_header() serializes them as
# "<directive> <sources>" joined by "; " in this insertion order.
CSP_DIRECTIVES = {
    "default-src": "'self'",
    # NOTE(review): 'unsafe-inline' allows inline <script>, which weakens the
    # XSS protection CSP provides — consider nonces/hashes if the frontend
    # permits it.
    "script-src": "'self' 'unsafe-inline' https://cdn.jsdelivr.net",
    "style-src": "'self' 'unsafe-inline' https://fonts.googleapis.com",
    "font-src": "'self' https://fonts.gstatic.com",
    "img-src": "'self' data: https: blob:",
    # Presumably the external APIs the browser calls directly — verify.
    "connect-src": "'self' https://api.elevenlabs.io https://api.groq.com",
}
def get_csp_header() -> str:
    """Build the Content-Security-Policy value: "<directive> <sources>" pairs joined by "; "."""
    directives = [f"{name} {sources}" for name, sources in CSP_DIRECTIVES.items()]
    return "; ".join(directives)
def get_security_headers() -> dict[str, str]:
    """Return a fresh dict of every security header, with the CSP header last."""
    return {**SECURITY_HEADERS, "Content-Security-Policy": get_csp_header()}