last_edit / server /rate_limiter.py
Moharek
Deploy Moharek GEO Platform
a74b879
"""
Rate Limiting & DDoS Protection
- Rate limiting enforcement
- DDoS protection
- API quota enforcement
"""
import time
from typing import Optional, Tuple
from functools import wraps
from fastapi import Request, HTTPException
from server.cache_manager import cache, RateLimiter
# Named rate-limit buckets, built from a (max_requests, window_seconds) spec.
# The hourly buckets guard the expensive endpoints; 'api_default' is the
# per-minute catch-all.
RATE_LIMITERS = {
    bucket: RateLimiter(max_requests=max_requests, window_seconds=window_seconds)
    for bucket, (max_requests, window_seconds) in {
        'api_crawl': (10, 3600),            # 10 per hour
        'api_analyze': (20, 3600),          # 20 per hour
        'api_keywords': (15, 3600),         # 15 per hour
        'api_content_generate': (5, 3600),  # 5 per hour
        'api_search': (30, 3600),           # 30 per hour
        'api_default': (100, 60),           # 100 per minute
    }.items()
}
# Exact request path -> RATE_LIMITERS bucket name. Paths not listed here are
# looked up with an 'api_default' fallback by the code that consumes this map.
ENDPOINT_LIMITS = {
    path: bucket
    for path, bucket in (
        ('/api/crawl', 'api_crawl'),
        ('/api/analyze', 'api_analyze'),
        ('/api/keywords', 'api_keywords'),
        ('/api/content/generate', 'api_content_generate'),
        ('/api/search', 'api_search'),
    )
}
class RateLimitExceeded(HTTPException):
    """HTTP 429 error raised when a client exceeds its rate limit.

    Args:
        retry_after: Seconds the client should wait before retrying. The
            value is stored on the instance, included in the ``detail``
            message, and sent as the standard ``Retry-After`` header.
    """

    def __init__(self, retry_after: int = 60):
        self.retry_after = retry_after
        super().__init__(
            status_code=429,
            detail=f'Rate limit exceeded. Retry after {retry_after} seconds.',
            # RFC 6585 says a 429 response MAY carry Retry-After. Clients
            # and proxies read the header, not the JSON body, so without it
            # they cannot back off automatically.
            headers={'Retry-After': str(retry_after)},
        )
def get_client_identifier(request: Request) -> str:
    """Return the key used to rate-limit the caller of ``request``.

    If the ``Authorization`` header carries a Bearer token that
    ``server.users.verify_token`` accepts, the key is ``"user:<uid>"``, so
    the limit follows the user across addresses. Otherwise the key is the
    connecting address, ``"ip:<host>"``, or ``"ip:unknown"`` when
    ``request.client`` is missing.
    """
    auth = request.headers.get('authorization', '')
    scheme, _, token = auth.partition(' ')
    token = token.strip()
    # Auth scheme names are case-insensitive (RFC 7235 section 2.1).
    if scheme.lower() == 'bearer' and token:
        try:
            from server import users
            uid = users.verify_token(token)
        except Exception:
            # Best effort: a token that cannot be verified (or an
            # unavailable users module) falls back to IP-based limiting
            # instead of failing the request.
            import logging
            logging.getLogger(__name__).debug(
                "Token verification failed; rate limiting by IP", exc_info=True
            )
            uid = None
        if uid:
            return f"user:{uid}"
    # Fall back to IP address
    client_ip = request.client.host if request.client else 'unknown'
    return f"ip:{client_ip}"
def rate_limit(limiter_key: str = 'api_default'):
    """Decorate an async endpoint so each call is checked against a limiter.

    Args:
        limiter_key: Key into ``RATE_LIMITERS``. Unknown keys fall back to
            the ``'api_default'`` limiter.

    The wrapped coroutine must receive the ``Request`` as its first argument
    (positionally or as the ``request`` keyword). If ``is_allowed`` rejects
    the caller, ``RateLimitExceeded`` (HTTP 429) is raised and the endpoint
    does not run. Otherwise, when the endpoint's response has a ``headers``
    attribute, ``X-RateLimit-Remaining`` and ``X-RateLimit-Limit`` are set
    on it.
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            limiter = RATE_LIMITERS.get(limiter_key, RATE_LIMITERS['api_default'])
            identifier = get_client_identifier(request)
            if not limiter.is_allowed(identifier):
                # Hint one full window rather than a fixed 60 s. For the
                # hourly limiters a 60 s hint only produces more rejected
                # retries.
                # NOTE(review): assumes RateLimiter keeps its constructor
                # kwarg as a `window_seconds` attribute. Falls back to 60 s
                # if it does not.
                retry_after = int(getattr(limiter, 'window_seconds', 60))
                raise RateLimitExceeded(retry_after=retry_after)
            response = await func(request, *args, **kwargs)
            remaining = limiter.get_remaining(identifier)
            # Plain dict/model return values have no headers to annotate.
            if hasattr(response, 'headers'):
                response.headers['X-RateLimit-Remaining'] = str(remaining)
                response.headers['X-RateLimit-Limit'] = str(limiter.max_requests)
            return response
        return wrapper
    return decorator
def rate_limit_by_endpoint(request: Request) -> Tuple[bool, Optional[int]]:
    """Check ``request`` against the limiter configured for its path.

    The URL path is matched verbatim against ``ENDPOINT_LIMITS``. Unmapped
    paths use the ``'api_default'`` limiter. The caller is identified with
    ``get_client_identifier``.

    Returns:
        ``(allowed, remaining)``: the limiter's ``is_allowed`` verdict, and
        what ``get_remaining`` reports immediately afterwards.
    """
    limiter = RATE_LIMITERS[ENDPOINT_LIMITS.get(request.url.path, 'api_default')]
    client_key = get_client_identifier(request)
    verdict = limiter.is_allowed(client_key)
    return verdict, limiter.get_remaining(client_key)
class DDoSProtection:
    """DDoS protection heuristics backed by the shared ``cache``.

    All methods are static and read or write ``ddos:*`` keys. TTLs are
    applied through ``redis_client.expire`` only when ``cache.use_redis``
    is true.
    """

    # Suspicious activity thresholds
    REQUESTS_PER_SECOND = 100  # max requests per identifier within one second
    UNIQUE_IPS_THRESHOLD = 50  # max IPs held in the record_ip set
    FAILED_REQUESTS_THRESHOLD = 100  # failures per identifier per hour bucket

    @staticmethod
    def _rate_key(identifier: str) -> str:
        """Return the counter key for ``identifier`` in the current second."""
        return f"ddos:rate:{identifier}:{int(time.time())}"

    @staticmethod
    def _failed_key(identifier: str) -> str:
        """Return the failed-request counter key for ``identifier`` in the current hour."""
        return f"ddos:failed:{identifier}:{int(time.time()) // 3600}"

    @staticmethod
    def check_request_rate(identifier: str) -> bool:
        """Count one request for ``identifier`` and report whether it is still
        within ``REQUESTS_PER_SECOND`` for the current second."""
        # The key embeds the current second, so each second starts a fresh
        # counter regardless of backend. A single reused key would only ever
        # be reset by Redis's expire. With any other backend it would keep
        # growing, and every client would eventually be blocked.
        key = DDoSProtection._rate_key(identifier)
        count = cache.increment(key)
        if count == 1 and cache.use_redis:
            # Free the bucket shortly after its second has passed.
            from server.cache_manager import redis_client
            redis_client.expire(key, 2)
        # NOTE(review): without Redis nothing here evicts finished buckets;
        # confirm the in-memory cache expires or bounds its keys.
        return count <= DDoSProtection.REQUESTS_PER_SECOND

    @staticmethod
    def check_failed_requests(identifier: str) -> bool:
        """Return True while ``identifier`` has fewer than
        ``FAILED_REQUESTS_THRESHOLD`` failures recorded in the current hour."""
        count = cache.get(DDoSProtection._failed_key(identifier)) or 0
        return count < DDoSProtection.FAILED_REQUESTS_THRESHOLD

    @staticmethod
    def record_failed_request(identifier: str):
        """Add one failure to ``identifier``'s counter for the current hour."""
        key = DDoSProtection._failed_key(identifier)
        cache.increment(key)
        # Hour-bucketed keys reset on their own on every backend. The Redis
        # TTL only reclaims memory for finished buckets.
        if cache.use_redis:
            from server.cache_manager import redis_client
            redis_client.expire(key, 3600)

    @staticmethod
    def check_unique_ips() -> bool:
        """Return True while fewer than ``UNIQUE_IPS_THRESHOLD`` IPs are recorded."""
        ips = cache.get("ddos:unique_ips") or ()
        return len(ips) < DDoSProtection.UNIQUE_IPS_THRESHOLD

    @staticmethod
    def record_ip(ip: str):
        """Add ``ip`` to the shared set of seen IPs and reset its 1-hour TTL."""
        key = "ddos:unique_ips"
        # Build a fresh set: this avoids mutating a cached object in place and
        # tolerates a backend that hands the value back as another iterable.
        ips = set(cache.get(key) or ())
        ips.add(ip)
        # NOTE(review): get/add/set is not atomic, so concurrent callers can
        # drop entries. It also assumes the backend can store a set; confirm.
        cache.set(key, ips, 3600)

    @staticmethod
    def is_blocked(identifier: str) -> bool:
        """Return True if ``block`` has marked ``identifier`` and the mark is still cached."""
        key = f"ddos:blocked:{identifier}"
        return cache.get(key) is not None

    @staticmethod
    def block(identifier: str, duration: int = 3600):
        """Mark ``identifier`` as blocked for ``duration`` seconds (cache TTL)."""
        key = f"ddos:blocked:{identifier}"
        cache.set(key, True, duration)

    @staticmethod
    def unblock(identifier: str):
        """Remove any block mark for ``identifier``."""
        key = f"ddos:blocked:{identifier}"
        cache.delete(key)
class QuotaManager:
    """API quota management.

    Usage counters live in ``cache`` under ``quota:<user_id>:<resource>``.
    Plan limits come from ``QUOTAS``.
    """

    # Resource names tracked per user (reported by get_usage).
    RESOURCES = ('crawls', 'analyses', 'content_generations', 'api_calls')

    # Default quotas per plan. Limits are keyed "<resource>_per_month",
    # except api_calls, which only has a "_per_day" entry.
    QUOTAS = {
        'free': {
            'crawls_per_month': 10,
            'analyses_per_month': 20,
            'content_generations_per_month': 5,
            'api_calls_per_day': 1000,
        },
        'pro': {
            'crawls_per_month': 100,
            'analyses_per_month': 200,
            'content_generations_per_month': 50,
            'api_calls_per_day': 10000,
        },
        'enterprise': {
            'crawls_per_month': 1000,
            'analyses_per_month': 2000,
            'content_generations_per_month': 500,
            'api_calls_per_day': 100000,
        },
    }

    @staticmethod
    def get_quota(user_id: int, plan: str = 'free') -> dict:
        """Return a copy of the limits for ``plan`` (the ``'free'`` limits if unknown).

        ``user_id`` is currently unused. A copy is returned so that callers
        cannot accidentally edit the shared ``QUOTAS`` table.
        """
        return dict(QuotaManager.QUOTAS.get(plan, QuotaManager.QUOTAS['free']))

    @staticmethod
    def check_quota(user_id: int, resource: str, plan: str = 'free') -> Tuple[bool, dict]:
        """Check whether ``user_id`` may consume more of ``resource``.

        Returns:
            ``(allowed, info)``, where ``info`` has ``used``, ``limit``,
            ``remaining`` and ``resource``. Only ``<resource>_per_month``
            limits are consulted. A resource without one (including
            ``api_calls``, which has only a per-day limit) is reported with
            ``limit`` 0 and is always allowed.
        """
        quota = QuotaManager.get_quota(user_id, plan)
        used = cache.get(f"quota:{user_id}:{resource}") or 0
        limit = quota.get(f"{resource}_per_month", 0)
        if limit == 0:
            # No monthly limit configured: treat as unlimited.
            return True, {'used': used, 'limit': 0, 'remaining': 0, 'resource': resource}
        return used < limit, {
            'used': used,
            'limit': limit,
            'remaining': max(0, limit - used),
            'resource': resource,
        }

    @staticmethod
    def increment_usage(user_id: int, resource: str, amount: int = 1):
        """Add ``amount`` to ``user_id``'s usage counter for ``resource``.

        With Redis, the counter is given a 30-day TTL when it has none. It
        therefore resets 30 days after the window's first use, which is a
        fixed window and not the calendar month. Without Redis, nothing here
        expires the counter.
        """
        key = f"quota:{user_id}:{resource}"
        cache.increment(key, amount)
        if cache.use_redis:
            from server.cache_manager import redis_client
            # Start the window only when the key has no TTL yet. Re-arming
            # the TTL on every call would keep an active user's counter alive
            # forever, so their quota would never reset. (Old redis-py clients
            # report "no TTL" as None, newer ones as -1.)
            ttl = redis_client.ttl(key)
            if ttl is None or ttl < 0:
                redis_client.expire(key, 30 * 24 * 3600)  # 30 days

    @staticmethod
    def get_usage(user_id: int) -> dict:
        """Return ``{resource: used}`` for every name in ``RESOURCES`` (0 when unset)."""
        return {
            resource: cache.get(f"quota:{user_id}:{resource}") or 0
            for resource in QuotaManager.RESOURCES
        }
def check_rate_limit_middleware(request: Request) -> Tuple[bool, Optional[str]]:
    """Run the rate-limit and DDoS checks for ``request``, stopping at the first failure.

    Checks run in this order:
      1. Existing block.
      2. Per-second request rate; exceeding it blocks the client for an hour.
      3. Failed-request count; exceeding it also blocks the client for an hour.
      4. Endpoint-specific limiter; a rejection is recorded as a failed request.

    Returns:
        ``(True, None)`` if the request may proceed, otherwise
        ``(False, reason)`` with a human-readable reason.
    """
    identifier = get_client_identifier(request)

    if DDoSProtection.is_blocked(identifier):
        return False, 'Client is blocked due to suspicious activity'

    if not DDoSProtection.check_request_rate(identifier):
        DDoSProtection.block(identifier, duration=3600)
        return False, 'Rate limit exceeded - client blocked'

    if not DDoSProtection.check_failed_requests(identifier):
        DDoSProtection.block(identifier, duration=3600)
        return False, 'Too many failed requests - client blocked'

    allowed, _remaining = rate_limit_by_endpoint(request)
    if not allowed:
        # Repeated endpoint-limit hits feed the failed-request counter,
        # which eventually triggers the block in the previous check.
        DDoSProtection.record_failed_request(identifier)
        return False, 'Rate limit exceeded for this endpoint'

    return True, None
def get_rate_limit_status(user_id: int, plan: str = 'free') -> dict:
    """Summarize ``user_id``'s quota usage against the monthly limits of ``plan``.

    Returns:
        A mapping ``{resource: {'used', 'limit', 'remaining', 'percent_used'}}``
        for each resource reported by ``QuotaManager.get_usage``. Limits are
        read from ``<resource>_per_month`` entries only. A resource without
        one gets ``limit`` 0 and ``percent_used`` 0.
    """
    usage = QuotaManager.get_usage(user_id)
    limits = QuotaManager.get_quota(user_id, plan)

    def summarize(count, cap):
        # Guard the percentage against a zero (missing) limit.
        return {
            'used': count,
            'limit': cap,
            'remaining': max(0, cap - count),
            'percent_used': 0 if cap <= 0 else count / cap * 100,
        }

    return {
        name: summarize(count, limits.get(f"{name}_per_month", 0))
        for name, count in usage.items()
    }
def reset_rate_limits(user_id: Optional[int] = None):
    """Reset stored usage counters.

    Args:
        user_id: If given, delete that user's ``quota:<user_id>:<resource>``
            counters for the known resources. If ``None``, call
            ``cache.clear()``. That presumably drops every key in the shared
            cache, including ``ddos:*`` counters and blocks; confirm against
            cache_manager.
    """
    # Compare with None explicitly: a falsy but valid id such as 0 must reset
    # only that user, not wipe the whole cache.
    if user_id is None:
        cache.clear()
        return
    for resource in ('crawls', 'analyses', 'content_generations', 'api_calls'):
        cache.delete(f"quota:{user_id}:{resource}")