| """ |
| Rate Limiting & DDoS Protection |
| - Rate limiting enforcement |
| - DDoS protection |
| - API quota enforcement |
| """ |
|
|
| import time |
| from typing import Optional, Tuple |
| from functools import wraps |
| from fastapi import Request, HTTPException |
| from server.cache_manager import cache, RateLimiter |
|
|
| |
# Named rate limiters. Each row is (key, max_requests, window_seconds); the keys
# are referenced by ENDPOINT_LIMITS below and by rate_limit(limiter_key=...).
# 'api_default' is the fallback for anything not explicitly mapped.
RATE_LIMITERS = {
    name: RateLimiter(max_requests=max_requests, window_seconds=window_seconds)
    for name, max_requests, window_seconds in (
        ('api_crawl', 10, 3600),
        ('api_analyze', 20, 3600),
        ('api_keywords', 15, 3600),
        ('api_content_generate', 5, 3600),
        ('api_search', 30, 3600),
        ('api_default', 100, 60),
    )
}


# Request path -> RATE_LIMITERS key. Paths not listed here use 'api_default'
# in rate_limit_by_endpoint().
ENDPOINT_LIMITS = {
    '/api/crawl': 'api_crawl',                          # 10 per hour
    '/api/analyze': 'api_analyze',                      # 20 per hour
    '/api/keywords': 'api_keywords',                    # 15 per hour
    '/api/content/generate': 'api_content_generate',    # 5 per hour
    '/api/search': 'api_search',                        # 30 per hour
}
|
|
class RateLimitExceeded(HTTPException):
    """HTTP 429 raised when a client exceeds a rate limit.

    Besides the human-readable ``detail``, the response carries a standard
    ``Retry-After`` header (RFC 6585 §4 / RFC 9110 §10.2.3) so clients can back
    off programmatically instead of parsing the message text.

    Args:
        retry_after: Seconds the client should wait before retrying. Also
            stored on the instance as ``self.retry_after``.
    """

    def __init__(self, retry_after: int = 60):
        self.retry_after = retry_after
        super().__init__(
            status_code=429,
            detail=f'Rate limit exceeded. Retry after {retry_after} seconds.',
            headers={'Retry-After': str(retry_after)},
        )
|
|
def get_client_identifier(request: Request) -> str:
    """Return the key used to account this caller's requests.

    A ``Bearer`` token that ``users.verify_token`` accepts yields a per-user
    key, so limits follow the account across addresses. Every other request
    (no/invalid token, verification error) is keyed by peer IP.

    Args:
        request: The incoming request.

    Returns:
        ``"user:<uid>"`` for a verified bearer token, otherwise
        ``"ip:<host>"`` (``"ip:unknown"`` when the peer address is unavailable).
    """
    auth = request.headers.get('authorization', '')
    scheme, _, token = auth.partition(' ')
    token = token.strip()

    # Auth-scheme names are case-insensitive (RFC 9110 §11.1); an empty token
    # cannot identify anyone, so skip verification for it.
    if scheme.lower() == 'bearer' and token:
        try:
            # Function-local import kept as-is — presumably avoids an import
            # cycle with server.users.
            from server import users
            uid = users.verify_token(token)
        except Exception:
            # Best effort: a verification failure degrades to IP-based
            # limiting rather than failing the request.
            import logging
            logging.getLogger(__name__).debug(
                "Bearer token verification failed; falling back to client IP",
                exc_info=True,
            )
            uid = None
        if uid:
            return f"user:{uid}"

    client_ip = request.client.host if request.client else 'unknown'
    return f"ip:{client_ip}"
|
|
def rate_limit(limiter_key: str = 'api_default'):
    """Decorate an async endpoint with a named limiter from ``RATE_LIMITERS``.

    The wrapped coroutine must receive the ``Request`` as its first argument
    (positionally or as ``request=``). Unknown ``limiter_key`` values fall back
    to ``'api_default'``.

    When the limiter rejects the caller, ``RateLimitExceeded`` (HTTP 429) is
    raised. Otherwise the endpoint runs, and if its return value has a
    ``headers`` mapping, ``X-RateLimit-Remaining`` and ``X-RateLimit-Limit``
    are added to it.

    Args:
        limiter_key: Key into ``RATE_LIMITERS``.
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            limiter = RATE_LIMITERS.get(limiter_key, RATE_LIMITERS['api_default'])
            identifier = get_client_identifier(request)

            if not limiter.is_allowed(identifier):
                # Report the limiter's own window instead of a flat 60 s, which
                # under-reported the wait for the hourly limiters. Assumes
                # RateLimiter exposes its ``window_seconds`` constructor arg as
                # an attribute; falls back to 60 s if it does not.
                retry_after = getattr(limiter, 'window_seconds', 60)
                raise RateLimitExceeded(retry_after=retry_after)

            response = await func(request, *args, **kwargs)
            remaining = limiter.get_remaining(identifier)

            if hasattr(response, 'headers'):
                response.headers['X-RateLimit-Remaining'] = str(remaining)
                response.headers['X-RateLimit-Limit'] = str(limiter.max_requests)

            return response

        return wrapper
    return decorator
|
|
def rate_limit_by_endpoint(request: Request) -> Tuple[bool, Optional[int]]:
    """Apply the limiter configured for the request's path.

    The path is looked up in ``ENDPOINT_LIMITS``; unmapped paths (and unknown
    limiter keys) use ``'api_default'``.

    Args:
        request: The incoming request.

    Returns:
        ``(allowed, remaining)`` as reported by the selected limiter.
    """
    # Strip a trailing slash so '/api/crawl/' cannot dodge the stricter
    # '/api/crawl' limiter by falling through to 'api_default'.
    endpoint = request.url.path.rstrip('/') or '/'
    limiter_key = ENDPOINT_LIMITS.get(endpoint, 'api_default')
    limiter = RATE_LIMITERS.get(limiter_key, RATE_LIMITERS['api_default'])

    identifier = get_client_identifier(request)
    allowed = limiter.is_allowed(identifier)
    remaining = limiter.get_remaining(identifier)

    return allowed, remaining
|
|
class DDoSProtection:
    """Cache-backed heuristics for detecting and blocking abusive clients.

    All methods are static; all state lives in ``cache``:

    * ``ddos:rate:<identifier>``    request counter (Redis TTL 1 s, set on first hit)
    * ``ddos:failed:<identifier>``  failure counter (Redis TTL 1 h, refreshed per failure)
    * ``ddos:unique_ips``           set of recorded IPs (cache TTL 1 h)
    * ``ddos:blocked:<identifier>`` block marker (cache TTL = block duration)

    NOTE(review): the two counters get a TTL only when ``cache.use_redis`` is
    true. On the non-Redis backend nothing in this class expires them, so an
    identifier that once reaches the thresholds would keep failing the checks —
    confirm whether ``cache.increment`` applies its own expiry there.
    """

    # Requests allowed per identifier within one ``ddos:rate`` window
    # (the window is 1 s when the Redis TTL applies).
    REQUESTS_PER_SECOND = 100
    # check_unique_ips() returns False once this many distinct IPs are recorded.
    UNIQUE_IPS_THRESHOLD = 50
    # check_failed_requests() returns False once this many failures are recorded.
    FAILED_REQUESTS_THRESHOLD = 100

    @staticmethod
    def _expire(key: str, seconds: int) -> None:
        """Give ``key`` a TTL of ``seconds`` on the Redis backend; no-op otherwise."""
        if cache.use_redis:
            # Lazy import, presumably because redis_client is only usable when
            # Redis is configured.
            from server.cache_manager import redis_client
            redis_client.expire(key, seconds)

    @staticmethod
    def check_request_rate(identifier: str) -> bool:
        """Count one request for ``identifier``; return False if over the cap."""
        key = f"ddos:rate:{identifier}"
        count = cache.increment(key)

        if count == 1:
            # The first hit opens the window; the counter expires 1 s later.
            DDoSProtection._expire(key, 1)

        return count <= DDoSProtection.REQUESTS_PER_SECOND

    @staticmethod
    def check_failed_requests(identifier: str) -> bool:
        """Return True while ``identifier`` is below the failed-request threshold."""
        key = f"ddos:failed:{identifier}"
        count = cache.get(key) or 0

        return count < DDoSProtection.FAILED_REQUESTS_THRESHOLD

    @staticmethod
    def record_failed_request(identifier: str):
        """Add one failure for ``identifier``; on Redis, restart its 1 h expiry."""
        key = f"ddos:failed:{identifier}"
        cache.increment(key)
        DDoSProtection._expire(key, 3600)

    @staticmethod
    def check_unique_ips() -> bool:
        """Return True while fewer than UNIQUE_IPS_THRESHOLD distinct IPs are recorded."""
        key = "ddos:unique_ips"
        ips = cache.get(key) or set()

        return len(ips) < DDoSProtection.UNIQUE_IPS_THRESHOLD

    @staticmethod
    def record_ip(ip: str):
        """Add ``ip`` to the recorded-IP set and store the set for 1 h.

        NOTE(review): read-modify-write on a shared key; concurrent callers
        can drop each other's additions.
        """
        key = "ddos:unique_ips"
        # Build a fresh set: tolerates a backend that hands back another
        # iterable (e.g. a list after serialization) and avoids mutating the
        # cached object in place before the write below.
        ips = set(cache.get(key) or ())
        ips.add(ip)
        cache.set(key, ips, 3600)

    @staticmethod
    def is_blocked(identifier: str) -> bool:
        """Return True if a block marker exists for ``identifier``."""
        key = f"ddos:blocked:{identifier}"
        return cache.get(key) is not None

    @staticmethod
    def block(identifier: str, duration: int = 3600):
        """Block ``identifier`` for ``duration`` seconds."""
        key = f"ddos:blocked:{identifier}"
        cache.set(key, True, duration)

    @staticmethod
    def unblock(identifier: str):
        """Remove any block marker for ``identifier``."""
        key = f"ddos:blocked:{identifier}"
        cache.delete(key)
|
|
class QuotaManager:
    """Per-user API quota accounting backed by ``cache``.

    Usage counters live at ``quota:<user_id>:<resource>``, where ``resource``
    is the quota name without its period suffix (``'crawls'``,
    ``'api_calls'``, ...). The suffix in ``QUOTAS`` (``_per_month`` or
    ``_per_day``) decides both which limit applies and how long the counter
    lives.
    """

    # Limits per plan, keyed '<resource>_per_month' or '<resource>_per_day'.
    QUOTAS = {
        'free': {
            'crawls_per_month': 10,
            'analyses_per_month': 20,
            'content_generations_per_month': 5,
            'api_calls_per_day': 1000,
        },
        'pro': {
            'crawls_per_month': 100,
            'analyses_per_month': 200,
            'content_generations_per_month': 50,
            'api_calls_per_day': 10000,
        },
        'enterprise': {
            'crawls_per_month': 1000,
            'analyses_per_month': 2000,
            'content_generations_per_month': 500,
            'api_calls_per_day': 100000,
        },
    }

    # Counter lifetimes for daily and (30-day) monthly quotas, in seconds.
    _DAY_SECONDS = 24 * 3600
    _MONTH_SECONDS = 30 * _DAY_SECONDS

    @staticmethod
    def _resource_limit(quota: dict, resource: str) -> int:
        """Return the limit for ``resource`` in ``quota``; 0 means no limit.

        Checks ``<resource>_per_month`` first, then ``<resource>_per_day``, so
        daily quotas such as ``api_calls_per_day`` are enforced too.
        """
        monthly = quota.get(f"{resource}_per_month")
        if monthly is not None:
            return monthly
        return quota.get(f"{resource}_per_day", 0)

    @staticmethod
    def _period_seconds(resource: str) -> int:
        """Return how long a usage counter for ``resource`` should live."""
        is_daily = any(
            f"{resource}_per_day" in limits
            for limits in QuotaManager.QUOTAS.values()
        )
        return QuotaManager._DAY_SECONDS if is_daily else QuotaManager._MONTH_SECONDS

    @staticmethod
    def get_quota(user_id: int, plan: str = 'free') -> dict:
        """Return the limits dict for ``plan``; unknown plans get 'free'.

        ``user_id`` is accepted for interface symmetry but not used.
        """
        return QuotaManager.QUOTAS.get(plan, QuotaManager.QUOTAS['free'])

    @staticmethod
    def check_quota(user_id: int, resource: str, plan: str = 'free') -> Tuple[bool, dict]:
        """Check whether ``user_id`` may consume one more unit of ``resource``.

        Returns:
            ``(allowed, info)`` where ``info`` has ``used``, ``limit``,
            ``remaining`` and ``resource``. A resource with no configured limit
            is always allowed and reported as ``limit == 0``.
        """
        quota = QuotaManager.get_quota(user_id, plan)
        limit = QuotaManager._resource_limit(quota, resource)

        if limit == 0:
            return True, {'used': 0, 'limit': 0, 'remaining': 0, 'resource': resource}

        key = f"quota:{user_id}:{resource}"
        used = cache.get(key) or 0

        return used < limit, {
            'used': used,
            'limit': limit,
            'remaining': max(0, limit - used),
            'resource': resource,
        }

    @staticmethod
    def increment_usage(user_id: int, resource: str, amount: int = 1):
        """Add ``amount`` to the user's usage counter for ``resource``.

        On Redis, the counter expires one period (day or 30 days) after it is
        created. NOTE(review): the non-Redis backend gets no TTL here.
        """
        key = f"quota:{user_id}:{resource}"
        count = cache.increment(key, amount)

        # Start the reset clock only when the counter is created. Re-applying
        # the TTL on every increment would push the expiry forward forever for
        # an active user, so the quota would never reset.
        if count == amount and cache.use_redis:
            from server.cache_manager import redis_client
            redis_client.expire(key, QuotaManager._period_seconds(resource))

    @staticmethod
    def get_usage(user_id: int) -> dict:
        """Return ``{resource: used}`` for every tracked resource (missing = 0)."""
        resources = ['crawls', 'analyses', 'content_generations', 'api_calls']
        return {
            resource: cache.get(f"quota:{user_id}:{resource}") or 0
            for resource in resources
        }
|
|
def check_rate_limit_middleware(request: Request) -> Tuple[bool, Optional[str]]:
    """Run all rate-limit and DDoS checks for ``request``.

    Checks run in order: existing block, per-identifier request rate, failed
    request count, then the per-endpoint limiter. Tripping the request-rate or
    failed-request check blocks the identifier for one hour. A per-endpoint
    rejection is recorded as a failed request.

    Returns:
        ``(True, None)`` if the request may proceed, else ``(False, reason)``.
    """
    identifier = get_client_identifier(request)

    if DDoSProtection.is_blocked(identifier):
        return False, 'Client is blocked due to suspicious activity'

    if not DDoSProtection.check_request_rate(identifier):
        DDoSProtection.block(identifier, duration=3600)
        return False, 'Rate limit exceeded - client blocked'

    if not DDoSProtection.check_failed_requests(identifier):
        DDoSProtection.block(identifier, duration=3600)
        return False, 'Too many failed requests - client blocked'

    allowed, _remaining = rate_limit_by_endpoint(request)
    if not allowed:
        # Endpoint-limit hits feed the failed-request counter checked above,
        # so persistent offenders eventually get blocked outright.
        DDoSProtection.record_failed_request(identifier)
        return False, 'Rate limit exceeded for this endpoint'

    return True, None
|
|
def get_rate_limit_status(user_id: int, plan: str = 'free') -> dict:
    """Summarize quota usage for ``user_id`` under ``plan``.

    Returns:
        ``{resource: {'used', 'limit', 'remaining', 'percent_used'}}`` for each
        resource reported by ``QuotaManager.get_usage``. A resource with no
        configured limit reports ``limit == 0`` and ``percent_used == 0``.
    """
    usage = QuotaManager.get_usage(user_id)
    quota = QuotaManager.get_quota(user_id, plan)

    status = {}
    for resource, used in usage.items():
        # Quotas are keyed '<resource>_per_month' or '<resource>_per_day'
        # (api_calls is daily); check both so daily limits are not shown as 0.
        limit = quota.get(f"{resource}_per_month")
        if limit is None:
            limit = quota.get(f"{resource}_per_day", 0)
        status[resource] = {
            'used': used,
            'limit': limit,
            'remaining': max(0, limit - used),
            'percent_used': (used / limit * 100) if limit > 0 else 0,
        }

    return status
|
|
def reset_rate_limits(user_id: Optional[int] = None):
    """Reset quota usage counters.

    Args:
        user_id: When given, delete that user's ``quota:<user_id>:<resource>``
            counters. When ``None``, call ``cache.clear()``.

    Warning:
        ``cache.clear()`` presumably wipes every key in ``cache``, not only
        quota counters (DDoS blocks/counters also live there) — confirm
        against cache_manager before relying on the no-argument form.
    """
    # Compare against None explicitly: a falsy-but-valid id such as 0 must
    # reset that user's counters, not fall through to clearing the cache.
    if user_id is None:
        cache.clear()
        return

    for resource in ('crawls', 'analyses', 'content_generations', 'api_calls'):
        cache.delete(f"quota:{user_id}:{resource}")
|
|