File size: 9,990 Bytes
a74b879
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
"""
Rate Limiting & DDoS Protection
- Rate limiting enforcement
- DDoS protection
- API quota enforcement
"""

import time
from typing import Optional, Tuple
from functools import wraps
from fastapi import Request, HTTPException
from server.cache_manager import cache, RateLimiter

# Named request buckets. Each RateLimiter is configured with a maximum number of
# requests per window (in seconds); 'api_default' is the catch-all bucket.
RATE_LIMITERS = dict(
    api_crawl=RateLimiter(max_requests=10, window_seconds=60 * 60),            # 10 per hour
    api_analyze=RateLimiter(max_requests=20, window_seconds=60 * 60),          # 20 per hour
    api_keywords=RateLimiter(max_requests=15, window_seconds=60 * 60),         # 15 per hour
    api_content_generate=RateLimiter(max_requests=5, window_seconds=60 * 60),  # 5 per hour
    api_search=RateLimiter(max_requests=30, window_seconds=60 * 60),           # 30 per hour
    api_default=RateLimiter(max_requests=100, window_seconds=60),              # 100 per minute
)

# Exact request paths that get a dedicated bucket (values are RATE_LIMITERS keys).
# Paths not listed here are charged to 'api_default' by rate_limit_by_endpoint.
ENDPOINT_LIMITS = {
    '/api/crawl': 'api_crawl',                        # request path -> bucket
    '/api/analyze': 'api_analyze',
    '/api/keywords': 'api_keywords',
    '/api/content/generate': 'api_content_generate',
    '/api/search': 'api_search',
}

class RateLimitExceeded(HTTPException):
    """HTTP 429 error raised when a client exceeds a rate limit.

    Args:
        retry_after: Suggested wait in seconds before retrying. It is included
            in the detail message and also sent as the standard ``Retry-After``
            response header so well-behaved clients can back off automatically.
    """
    def __init__(self, retry_after: int = 60):
        self.retry_after = retry_after
        super().__init__(
            status_code=429,
            detail=f'Rate limit exceeded. Retry after {retry_after} seconds.',
            # RFC 6585 section 4: a 429 response MAY carry Retry-After. Headers set on
            # the exception are copied onto the error response by FastAPI.
            headers={'Retry-After': str(retry_after)},
        )

def get_client_identifier(request: Request) -> str:
    """Return the key used to attribute requests to a caller for rate limiting.

    A request with an ``Authorization: Bearer <token>`` header whose token is
    accepted by ``users.verify_token`` is keyed as ``"user:<uid>"``. Every other
    request falls back to ``"ip:<host>"`` (``"ip:unknown"`` when the connection
    carries no client address).
    """
    auth = request.headers.get('authorization', '')
    # RFC 7235 section 2.1: the auth-scheme token is case-insensitive.
    scheme, _, token = auth.partition(' ')
    token = token.strip()

    if scheme.lower() == 'bearer' and token:
        try:
            from server import users
            uid = users.verify_token(token)
        except Exception:
            # Best effort: an unverifiable token (or an unavailable users module)
            # downgrades the caller to IP-based limiting instead of failing the request.
            uid = None
        if uid:
            return f"user:{uid}"

    # Fall back to IP address
    client_ip = request.client.host if request.client else 'unknown'
    return f"ip:{client_ip}"

def rate_limit(limiter_key: str = 'api_default'):
    """Decorator that enforces one ``RATE_LIMITERS`` bucket on an async endpoint.

    The wrapped coroutine must take the ``Request`` as its first argument. An
    unknown ``limiter_key`` falls back to the ``'api_default'`` bucket. When the
    endpoint returns an object with ``headers``, ``X-RateLimit-Remaining`` and
    ``X-RateLimit-Limit`` are added to it.

    Raises:
        RateLimitExceeded: when the caller has exhausted the bucket.
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            limiter = RATE_LIMITERS.get(limiter_key, RATE_LIMITERS['api_default'])
            identifier = get_client_identifier(request)

            if not limiter.is_allowed(identifier):
                # Advertise the bucket's own window instead of a flat 60 s: with the
                # hourly buckets a 60 s hint sends clients straight back into a 429.
                # getattr keeps the old 60 s if the limiter does not expose it.
                retry_after = int(getattr(limiter, 'window_seconds', 60))
                raise RateLimitExceeded(retry_after=retry_after)

            response = await func(request, *args, **kwargs)
            remaining = limiter.get_remaining(identifier)

            # Only Response-like objects can carry headers; plain dicts/models are
            # returned untouched.
            if hasattr(response, 'headers'):
                response.headers['X-RateLimit-Remaining'] = str(remaining)
                response.headers['X-RateLimit-Limit'] = str(limiter.max_requests)

            return response

        return wrapper
    return decorator

def rate_limit_by_endpoint(request: Request) -> Tuple[bool, Optional[int]]:
    """Charge one request to the bucket mapped to ``request.url.path``.

    Paths absent from ``ENDPOINT_LIMITS`` use the ``'api_default'`` bucket.
    Returns ``(allowed, remaining)`` as reported by that bucket's limiter,
    with ``remaining`` read after the request has been counted.
    """
    bucket = RATE_LIMITERS[ENDPOINT_LIMITS.get(request.url.path, 'api_default')]
    caller = get_client_identifier(request)
    permitted = bucket.is_allowed(caller)
    return permitted, bucket.get_remaining(caller)

class DDoSProtection:
    """DDoS protection mechanisms.

    All state lives in the shared ``cache`` under ``ddos:*`` keys; the class
    holds only thresholds and acts as a namespace of static helpers.
    """

    # Suspicious activity thresholds
    REQUESTS_PER_SECOND = 100
    UNIQUE_IPS_THRESHOLD = 50
    FAILED_REQUESTS_THRESHOLD = 100

    @staticmethod
    def check_request_rate(identifier: str) -> bool:
        """Count one request for *identifier* in the current one-second window.

        Returns False once the window's count exceeds ``REQUESTS_PER_SECOND``.
        """
        # Bucket the counter by wall-clock second (Redis "INCR rate limiter"
        # pattern). Each second gets a fresh key, so the window resets even where
        # no expiry can be attached. The previous single key only expired on the
        # Redis path; unless cache.increment applies its own TTL, the count grew
        # forever on other backends and every client was eventually blocked.
        key = f"ddos:rate:{identifier}:{int(time.time())}"
        count = cache.increment(key)

        if count == 1 and cache.use_redis:
            from server.cache_manager import redis_client
            # Let each closed per-second bucket disappear shortly afterwards.
            redis_client.expire(key, 2)
        # NOTE(review): on a non-Redis backend these per-second keys are reclaimed
        # only if cache_manager evicts them itself -- confirm.

        return count <= DDoSProtection.REQUESTS_PER_SECOND

    @staticmethod
    def check_failed_requests(identifier: str) -> bool:
        """Return True while *identifier* has fewer than ``FAILED_REQUESTS_THRESHOLD`` recorded failures."""
        key = f"ddos:failed:{identifier}"
        count = cache.get(key) or 0

        return count < DDoSProtection.FAILED_REQUESTS_THRESHOLD

    @staticmethod
    def record_failed_request(identifier: str):
        """Increment the failure counter for *identifier*.

        On Redis the 1-hour expiry is re-applied on every call, so the counter
        clears only after an hour with no further failures.
        """
        key = f"ddos:failed:{identifier}"
        cache.increment(key)

        if cache.use_redis:
            from server.cache_manager import redis_client
            redis_client.expire(key, 3600)

    @staticmethod
    def check_unique_ips() -> bool:
        """Return True while fewer than ``UNIQUE_IPS_THRESHOLD`` distinct IPs are recorded."""
        key = "ddos:unique_ips"
        ips = cache.get(key) or set()

        return len(ips) < DDoSProtection.UNIQUE_IPS_THRESHOLD

    @staticmethod
    def record_ip(ip: str):
        """Add *ip* to the recorded IP set and re-store it with a 3600 TTL."""
        key = "ddos:unique_ips"
        # NOTE(review): this read-modify-write is not atomic; concurrent callers
        # can drop each other's additions.
        ips = cache.get(key) or set()
        ips.add(ip)
        cache.set(key, ips, 3600)

    @staticmethod
    def is_blocked(identifier: str) -> bool:
        """Return True if a block marker exists for *identifier*."""
        key = f"ddos:blocked:{identifier}"
        return cache.get(key) is not None

    @staticmethod
    def block(identifier: str, duration: int = 3600):
        """Store a block marker for *identifier* lasting *duration*."""
        key = f"ddos:blocked:{identifier}"
        cache.set(key, True, duration)

    @staticmethod
    def unblock(identifier: str):
        """Remove the block marker for *identifier*."""
        key = f"ddos:blocked:{identifier}"
        cache.delete(key)

class QuotaManager:
    """API quota management.

    Usage counters live in the shared ``cache`` under ``quota:{user_id}:{resource}``.
    A plan limits a resource through either a ``{resource}_per_month`` or a
    ``{resource}_per_day`` entry. The suffix also decides when the counter resets:
    at the start of the next UTC month or day.
    """

    # Default quotas per plan
    QUOTAS = {
        'free': {
            'crawls_per_month': 10,
            'analyses_per_month': 20,
            'content_generations_per_month': 5,
            'api_calls_per_day': 1000,
        },
        'pro': {
            'crawls_per_month': 100,
            'analyses_per_month': 200,
            'content_generations_per_month': 50,
            'api_calls_per_day': 10000,
        },
        'enterprise': {
            'crawls_per_month': 1000,
            'analyses_per_month': 2000,
            'content_generations_per_month': 500,
            'api_calls_per_day': 100000,
        },
    }

    # Recognised limit-key suffixes, in lookup order, and the reset period each implies.
    _PERIOD_SUFFIXES = (('_per_month', 'month'), ('_per_day', 'day'))

    @staticmethod
    def get_quota(user_id: int, plan: str = 'free') -> dict:
        """Return the limits for *plan*; unknown plans fall back to ``'free'``.

        ``user_id`` is currently unused because quotas are per plan. A copy is
        returned so callers cannot mutate the shared ``QUOTAS`` table.
        """
        return dict(QuotaManager.QUOTAS.get(plan, QuotaManager.QUOTAS['free']))

    @staticmethod
    def _limit_for(quota: dict, resource: str) -> Tuple[int, Optional[str]]:
        """Return ``(limit, period)`` for *resource*, or ``(0, None)`` when *quota* has no entry."""
        for suffix, period in QuotaManager._PERIOD_SUFFIXES:
            limit = quota.get(f"{resource}{suffix}")
            if limit is not None:
                return limit, period
        return 0, None

    @staticmethod
    def _seconds_until_reset(period: Optional[str], now: Optional[float] = None) -> int:
        """Return seconds from *now* until the next UTC *period* boundary.

        *now* is epoch seconds and defaults to the current time. ``'day'`` means the
        next midnight; any other value means the first day of the next month.
        The result is at least 1, so it is always a usable expiry.
        """
        from datetime import datetime, timedelta, timezone

        current = datetime.fromtimestamp(time.time() if now is None else now, tz=timezone.utc)
        midnight = current.replace(hour=0, minute=0, second=0, microsecond=0)
        if period == 'day':
            boundary = midnight + timedelta(days=1)
        elif current.month == 12:
            boundary = midnight.replace(year=current.year + 1, month=1, day=1)
        else:
            boundary = midnight.replace(month=current.month + 1, day=1)
        return max(1, int((boundary - current).total_seconds()))

    @staticmethod
    def check_quota(user_id: int, resource: str, plan: str = 'free') -> Tuple[bool, dict]:
        """Report whether *user_id* may consume more of *resource* under *plan*.

        Returns ``(allowed, info)``, where ``info`` holds ``used``, ``limit``,
        ``remaining`` and ``resource``. A resource with no limit (or a zero
        limit) configured is treated as unlimited and always allowed.
        """
        quota = QuotaManager.get_quota(user_id, plan)
        # Daily quotas (e.g. api_calls_per_day) are honoured too; previously only
        # *_per_month keys were consulted, so api_calls was never enforced.
        limit, _period = QuotaManager._limit_for(quota, resource)

        if limit == 0:
            return True, {'used': 0, 'limit': 0, 'remaining': 0, 'resource': resource}

        used = cache.get(f"quota:{user_id}:{resource}") or 0
        return used < limit, {
            'used': used,
            'limit': limit,
            'remaining': max(0, limit - used),
            'resource': resource,
        }

    @staticmethod
    def increment_usage(user_id: int, resource: str, amount: int = 1):
        """Add *amount* to *user_id*'s usage counter for *resource*.

        On Redis, the counter's expiry is set to the next UTC period boundary
        (day for ``*_per_day`` resources, otherwise month) when it has none yet.
        """
        key = f"quota:{user_id}:{resource}"
        cache.increment(key, amount)

        if cache.use_redis:
            from server.cache_manager import redis_client
            # Only attach an expiry when none exists. TTL is negative (or None on
            # old redis-py) when unset. Re-applying it on every increment slid the
            # reset forward with each use, so active users' counters never reset.
            ttl = redis_client.ttl(key)
            if ttl is None or ttl < 0:
                # All plans share the same resource keys, so 'free' decides the period.
                _limit, period = QuotaManager._limit_for(QuotaManager.QUOTAS['free'], resource)
                redis_client.expire(key, QuotaManager._seconds_until_reset(period))

    @staticmethod
    def get_usage(user_id: int) -> dict:
        """Return the current counters for the tracked resources (0 when unset)."""
        return {
            resource: cache.get(f"quota:{user_id}:{resource}") or 0
            for resource in ('crawls', 'analyses', 'content_generations', 'api_calls')
        }

def check_rate_limit_middleware(request: Request) -> Tuple[bool, Optional[str]]:
    """Run the block, DDoS and endpoint rate-limit checks for *request*.

    Returns ``(True, None)`` when the request may proceed, otherwise
    ``(False, reason)``. Side effects:
    - counts the request against the per-second and endpoint limiters;
    - blocks the client for 3600 when the per-second or failed-request
      threshold is crossed;
    - records an endpoint-limit violation as a failed request.
    """
    identifier = get_client_identifier(request)

    # Check if blocked
    if DDoSProtection.is_blocked(identifier):
        return False, 'Client is blocked due to suspicious activity'

    # Check request rate
    if not DDoSProtection.check_request_rate(identifier):
        DDoSProtection.block(identifier, duration=3600)
        return False, 'Rate limit exceeded - client blocked'

    # Check failed requests
    if not DDoSProtection.check_failed_requests(identifier):
        DDoSProtection.block(identifier, duration=3600)
        return False, 'Too many failed requests - client blocked'

    # Check endpoint-specific rate limit; the remaining count is not needed here.
    allowed, _remaining = rate_limit_by_endpoint(request)
    if not allowed:
        # Repeated endpoint-limit violations count toward the failed-request
        # threshold checked above, eventually escalating to a block.
        DDoSProtection.record_failed_request(identifier)
        return False, 'Rate limit exceeded for this endpoint'

    return True, None

def get_rate_limit_status(user_id: int, plan: str = 'free') -> dict:
    """Summarise quota consumption for *user_id* on *plan*.

    Returns ``{resource: {'used', 'limit', 'remaining', 'percent_used'}}`` for
    every resource reported by ``QuotaManager.get_usage``. The limit comes from the
    plan's ``{resource}_per_month`` entry, or else its ``{resource}_per_day`` entry.
    A limit of 0 means none is configured and yields ``percent_used`` 0.
    """
    usage = QuotaManager.get_usage(user_id)
    quota = QuotaManager.get_quota(user_id, plan)

    status = {}
    for resource, used in usage.items():
        # api_calls is configured per day; looking only at *_per_month always
        # reported its limit as 0.
        limit = quota.get(f"{resource}_per_month", quota.get(f"{resource}_per_day", 0))
        status[resource] = {
            'used': used,
            'limit': limit,
            'remaining': max(0, limit - used),
            'percent_used': (used / limit * 100) if limit > 0 else 0,
        }

    return status

def reset_rate_limits(user_id: Optional[int] = None):
    """Reset quota counters.

    With *user_id*, deletes that user's ``quota:{user_id}:<resource>`` counters
    for the tracked resources. Without it, calls ``cache.clear()``, which
    presumably wipes every cached entry, not only rate-limit state.
    """
    # Compare against None explicitly: a falsy id such as 0 must reset that user,
    # not fall through to clearing the entire cache.
    if user_id is None:
        cache.clear()
        return

    for resource in ('crawls', 'analyses', 'content_generations', 'api_calls'):
        cache.delete(f"quota:{user_id}:{resource}")