""" Rate Limiting & DDoS Protection - Rate limiting enforcement - DDoS protection - API quota enforcement """ import time from typing import Optional, Tuple from functools import wraps from fastapi import Request, HTTPException from server.cache_manager import cache, RateLimiter # Rate limiters for different endpoints RATE_LIMITERS = { 'api_crawl': RateLimiter(max_requests=10, window_seconds=3600), # 10/hour 'api_analyze': RateLimiter(max_requests=20, window_seconds=3600), # 20/hour 'api_keywords': RateLimiter(max_requests=15, window_seconds=3600), # 15/hour 'api_content_generate': RateLimiter(max_requests=5, window_seconds=3600), # 5/hour 'api_search': RateLimiter(max_requests=30, window_seconds=3600), # 30/hour 'api_default': RateLimiter(max_requests=100, window_seconds=60), # 100/minute } # Endpoint-specific limits ENDPOINT_LIMITS = { '/api/crawl': 'api_crawl', '/api/analyze': 'api_analyze', '/api/keywords': 'api_keywords', '/api/content/generate': 'api_content_generate', '/api/search': 'api_search', } class RateLimitExceeded(HTTPException): """Rate limit exceeded exception""" def __init__(self, retry_after: int = 60): self.retry_after = retry_after super().__init__( status_code=429, detail=f'Rate limit exceeded. Retry after {retry_after} seconds.' ) def get_client_identifier(request: Request) -> str: """Get unique client identifier""" # Try to get user ID from token try: auth = request.headers.get('authorization', '') if auth.startswith('Bearer '): token = auth.split(' ', 1)[1].strip() from server import users uid = users.verify_token(token) if uid: return f"user:{uid}" except: pass # Fall back to IP address client_ip = request.client.host if request.client else 'unknown' return f"ip:{client_ip}" def rate_limit(limiter_key: str = 'api_default'): """Rate limiting decorator""" def decorator(func): @wraps(func) async def wrapper(request: Request, *args, **kwargs): limiter = RATE_LIMITERS.get(limiter_key, RATE_LIMITERS['api_default']) identifier = get_client_identifier(request) if not limiter.is_allowed(identifier): remaining = limiter.get_remaining(identifier) raise RateLimitExceeded(retry_after=60) # Add rate limit headers response = await func(request, *args, **kwargs) remaining = limiter.get_remaining(identifier) if hasattr(response, 'headers'): response.headers['X-RateLimit-Remaining'] = str(remaining) response.headers['X-RateLimit-Limit'] = str(limiter.max_requests) return response return wrapper return decorator def rate_limit_by_endpoint(request: Request) -> Tuple[bool, Optional[int]]: """Check rate limit for endpoint""" endpoint = request.url.path limiter_key = ENDPOINT_LIMITS.get(endpoint, 'api_default') limiter = RATE_LIMITERS[limiter_key] identifier = get_client_identifier(request) allowed = limiter.is_allowed(identifier) remaining = limiter.get_remaining(identifier) return allowed, remaining class DDoSProtection: """DDoS protection mechanisms""" # Suspicious activity thresholds REQUESTS_PER_SECOND = 100 UNIQUE_IPS_THRESHOLD = 50 FAILED_REQUESTS_THRESHOLD = 100 @staticmethod def check_request_rate(identifier: str) -> bool: """Check if request rate is suspicious""" key = f"ddos:rate:{identifier}" count = cache.increment(key) if count == 1: # Set 1-second window if cache.use_redis: from server.cache_manager import redis_client redis_client.expire(key, 1) return count <= DDoSProtection.REQUESTS_PER_SECOND @staticmethod def check_failed_requests(identifier: str) -> bool: """Check if too many failed requests""" key = f"ddos:failed:{identifier}" count = cache.get(key) or 0 return count < DDoSProtection.FAILED_REQUESTS_THRESHOLD @staticmethod def record_failed_request(identifier: str): """Record failed request""" key = f"ddos:failed:{identifier}" cache.increment(key) # Reset after 1 hour if cache.use_redis: from server.cache_manager import redis_client redis_client.expire(key, 3600) @staticmethod def check_unique_ips() -> bool: """Check if too many unique IPs""" key = "ddos:unique_ips" ips = cache.get(key) or set() return len(ips) < DDoSProtection.UNIQUE_IPS_THRESHOLD @staticmethod def record_ip(ip: str): """Record IP address""" key = "ddos:unique_ips" ips = cache.get(key) or set() ips.add(ip) cache.set(key, ips, 3600) @staticmethod def is_blocked(identifier: str) -> bool: """Check if identifier is blocked""" key = f"ddos:blocked:{identifier}" return cache.get(key) is not None @staticmethod def block(identifier: str, duration: int = 3600): """Block identifier""" key = f"ddos:blocked:{identifier}" cache.set(key, True, duration) @staticmethod def unblock(identifier: str): """Unblock identifier""" key = f"ddos:blocked:{identifier}" cache.delete(key) class QuotaManager: """API quota management""" # Default quotas per plan QUOTAS = { 'free': { 'crawls_per_month': 10, 'analyses_per_month': 20, 'content_generations_per_month': 5, 'api_calls_per_day': 1000, }, 'pro': { 'crawls_per_month': 100, 'analyses_per_month': 200, 'content_generations_per_month': 50, 'api_calls_per_day': 10000, }, 'enterprise': { 'crawls_per_month': 1000, 'analyses_per_month': 2000, 'content_generations_per_month': 500, 'api_calls_per_day': 100000, }, } @staticmethod def get_quota(user_id: int, plan: str = 'free') -> dict: """Get quota for user""" return QuotaManager.QUOTAS.get(plan, QuotaManager.QUOTAS['free']) @staticmethod def check_quota(user_id: int, resource: str, plan: str = 'free') -> Tuple[bool, dict]: """Check if user has quota available""" quota = QuotaManager.get_quota(user_id, plan) key = f"quota:{user_id}:{resource}" used = cache.get(key) or 0 limit = quota.get(f"{resource}_per_month", 0) if limit == 0: return True, {'used': 0, 'limit': 0, 'remaining': 0} remaining = max(0, limit - used) allowed = used < limit return allowed, { 'used': used, 'limit': limit, 'remaining': remaining, 'resource': resource } @staticmethod def increment_usage(user_id: int, resource: str, amount: int = 1): """Increment resource usage""" key = f"quota:{user_id}:{resource}" cache.increment(key, amount) # Reset monthly quota at start of month if cache.use_redis: from server.cache_manager import redis_client redis_client.expire(key, 30 * 24 * 3600) # 30 days @staticmethod def get_usage(user_id: int) -> dict: """Get current usage for user""" resources = ['crawls', 'analyses', 'content_generations', 'api_calls'] usage = {} for resource in resources: key = f"quota:{user_id}:{resource}" usage[resource] = cache.get(key) or 0 return usage def check_rate_limit_middleware(request: Request) -> Tuple[bool, Optional[str]]: """Middleware to check rate limits""" identifier = get_client_identifier(request) # Check if blocked if DDoSProtection.is_blocked(identifier): return False, 'Client is blocked due to suspicious activity' # Check request rate if not DDoSProtection.check_request_rate(identifier): DDoSProtection.block(identifier, duration=3600) return False, 'Rate limit exceeded - client blocked' # Check failed requests if not DDoSProtection.check_failed_requests(identifier): DDoSProtection.block(identifier, duration=3600) return False, 'Too many failed requests - client blocked' # Check endpoint-specific rate limit allowed, remaining = rate_limit_by_endpoint(request) if not allowed: DDoSProtection.record_failed_request(identifier) return False, f'Rate limit exceeded for this endpoint' return True, None def get_rate_limit_status(user_id: int, plan: str = 'free') -> dict: """Get rate limit status for user""" usage = QuotaManager.get_usage(user_id) quota = QuotaManager.get_quota(user_id, plan) status = {} for resource, used in usage.items(): limit = quota.get(f"{resource}_per_month", 0) status[resource] = { 'used': used, 'limit': limit, 'remaining': max(0, limit - used), 'percent_used': (used / limit * 100) if limit > 0 else 0 } return status def reset_rate_limits(user_id: int = None): """Reset rate limits""" if user_id: # Reset specific user resources = ['crawls', 'analyses', 'content_generations', 'api_calls'] for resource in resources: key = f"quota:{user_id}:{resource}" cache.delete(key) else: # Reset all cache.clear()