Spaces:
Running
Running
| import time | |
| from collections import defaultdict | |
| from functools import wraps | |
| from typing import Callable, Optional | |
| import threading | |
| from fastapi import Request, HTTPException, status | |
| from fastapi.responses import JSONResponse | |
class RateLimiter:
    """
    Simple in-memory rate limiter for API endpoints.

    Uses a sliding-window log: the timestamp of every accepted request is
    stored per key, and a new request is accepted only while fewer than
    ``max_requests`` stored timestamps fall inside the trailing
    ``window_seconds``. All state lives in this object and is guarded by a
    re-entrant lock.

    Keys that have gone idle are purged by a periodic sweep, so memory does
    not grow without bound as new keys (e.g. client IPs) appear. The sweep
    assumes a key's ``window_seconds`` does not grow between calls.
    """

    def __init__(
        self,
        clock: Callable[[], float] = time.time,
        sweep_interval: float = 60.0,
    ):
        """
        Args:
            clock: Zero-argument callable returning the current Unix time in
                seconds. Injectable for tests; the ``reset`` values reported
                by ``is_allowed`` are derived from it, so it should be
                wall-clock time.
            sweep_interval: Minimum number of seconds between purges of idle
                keys.
        """
        self._clock = clock
        # key -> timestamps of accepted requests, in call order (oldest first
        # as long as the clock does not go backwards)
        self._requests = defaultdict(list)
        # key -> time at/after which every stored timestamp for that key has
        # left its window, so the sweep may drop the key
        self._expires_at: dict[str, float] = {}
        self._sweep_interval = sweep_interval
        self._next_sweep = clock() + sweep_interval
        self._lock = threading.RLock()

    def is_allowed(
        self,
        key: str,
        max_requests: int = 60,
        window_seconds: int = 60
    ) -> tuple[bool, dict]:
        """
        Check whether a request for ``key`` is allowed, recording it if so.

        Args:
            key: Bucket identifier (e.g. ``"<ip>:<endpoint type>"``).
            max_requests: Maximum accepted requests per window; must be >= 1.
            window_seconds: Length of the sliding window in seconds.

        Returns:
            ``(is_allowed, info)``. ``info`` always has ``limit``,
            ``remaining`` and ``reset`` (integer Unix time). On success
            ``reset`` is when every request in the window, including this
            one, will have expired. On denial ``reset`` is when the next slot
            frees up, and ``info`` also has ``retry_after``: whole seconds
            (rounded up) until a request would be accepted.

        Raises:
            ValueError: if ``max_requests`` is less than 1, since such a limit
                can never admit a request.
        """
        if max_requests < 1:
            raise ValueError(
                f"max_requests must be at least 1, got {max_requests}"
            )
        with self._lock:
            now = self._clock()
            self._sweep_idle_keys(now)
            window_start = now - window_seconds

            # Keep only timestamps still inside the window. .get() avoids
            # materialising defaultdict entries before we decide to store.
            recent = [
                t for t in self._requests.get(key, ()) if t > window_start
            ]
            current_count = len(recent)

            # After this call the key always holds at least one timestamp
            # (<= now), none of which can outlive now + window_seconds.
            self._expires_at[key] = max(
                self._expires_at.get(key, now), now + window_seconds
            )

            if current_count >= max_requests:
                self._requests[key] = recent
                # A request is admitted once fewer than max_requests
                # timestamps remain, i.e. once the
                # (current_count - max_requests + 1) oldest have expired.
                # Index 0 would be too early when the stored count exceeds
                # the limit (e.g. the limit for this key was lowered).
                unlock_at = recent[current_count - max_requests]
                return False, {
                    'limit': max_requests,
                    'remaining': 0,
                    'reset': int(unlock_at + window_seconds),
                    'retry_after': int(unlock_at - window_start) + 1
                }

            # Record the current request
            recent.append(now)
            self._requests[key] = recent
            return True, {
                'limit': max_requests,
                'remaining': max_requests - current_count - 1,
                'reset': int(now + window_seconds)
            }

    def _sweep_idle_keys(self, now: float) -> None:
        """Drop keys whose recorded requests have all left their window.

        Runs at most once per ``sweep_interval``; the caller must hold the lock.
        """
        if now < self._next_sweep:
            return
        self._next_sweep = now + self._sweep_interval
        idle = [k for k, expiry in self._expires_at.items() if expiry <= now]
        for k in idle:
            del self._expires_at[k]
            self._requests.pop(k, None)

    def reset(self, key: str):
        """Forget all recorded requests for ``key`` (no-op if it is unknown)."""
        with self._lock:
            self._requests.pop(key, None)
            self._expires_at.pop(key, None)
# Module-level singleton consulted by rate_limit_middleware below.
rate_limiter = RateLimiter()

# Limits per endpoint type: a client may make at most ``max_requests``
# requests within any sliding ``window`` (seconds) for that type.
RATE_LIMITS = {
    'auth': dict(max_requests=10, window=60),        # 10 per minute
    'chat': dict(max_requests=30, window=60),        # 30 per minute
    'compile': dict(max_requests=5, window=5 * 60),  # 5 per 5 minutes
    'agents': dict(max_requests=60, window=60),      # 60 per minute
    'default': dict(max_requests=100, window=60),    # 100 per minute
}
async def rate_limit_middleware(request: Request, call_next):
    """
    FastAPI middleware for rate limiting.

    Each request is charged to a bucket keyed by ``"<client ip>:<limit type>"``.
    The limit type is chosen from the request path, and its limits are looked
    up in ``RATE_LIMITS``. Over-limit requests are answered here with HTTP 429
    (a JSON body plus ``X-RateLimit-*`` and ``Retry-After`` headers) and never
    reach ``call_next``. Allowed requests are forwarded, and the
    ``X-RateLimit-*`` headers are added to the downstream response.
    """
    # Client identity is the socket peer address only; no user ID is used.
    # NOTE(review): behind a reverse proxy this is the proxy's address, so all
    # clients would share one bucket. Confirm the deployment topology.
    peer = request.client
    if not peer:
        client_ip = "unknown"
    else:
        client_ip = peer.host

    # The first matching fragment wins, so the order here is significant.
    # Matching is by substring anywhere in the path, not by prefix.
    path = request.url.path
    path_rules = (
        ('/auth/', 'auth'),
        ('/chat/', 'chat'),
        ('/compile', 'compile'),
        ('/agents', 'agents'),
    )
    limit_type = next(
        (bucket for fragment, bucket in path_rules if fragment in path),
        'default',
    )

    limits = RATE_LIMITS[limit_type]
    allowed, info = rate_limiter.is_allowed(
        f"{client_ip}:{limit_type}",
        max_requests=limits['max_requests'],
        window_seconds=limits['window'],
    )

    quota_headers = {
        'X-RateLimit-Limit': str(info['limit']),
        'X-RateLimit-Remaining': str(info['remaining']),
        'X-RateLimit-Reset': str(info['reset']),
    }

    if not allowed:
        wait = info['retry_after']
        return JSONResponse(
            status_code=429,
            content={'detail': 'Too many requests', 'retry_after': wait},
            headers={**quota_headers, 'Retry-After': str(wait)},
        )

    response = await call_next(request)
    for header_name, header_value in quota_headers.items():
        response.headers[header_name] = header_value
    return response
# File validation constants
# Largest accepted upload, in bytes (50 MiB).
MAX_FILE_SIZE = 50 * 2**20
# Accepted upload extensions: lower-case, including the leading dot.
ALLOWED_EXTENSIONS = {
    f'.{suffix}' for suffix in ('csv', 'pdf', 'docx', 'txt', 'json', 'xlsx')
}
def validate_file_upload(
    filename: str,
    file_size: int,
    max_size: Optional[int] = None,
    allowed_extensions: Optional[set[str]] = None,
) -> Optional[str]:
    """
    Validate an uploaded file's name and size.

    Args:
        filename: Client-supplied file name; only its extension is checked.
        file_size: Size of the upload in bytes.
        max_size: Maximum allowed size in bytes. Defaults to ``MAX_FILE_SIZE``.
        allowed_extensions: Permitted extensions including the leading dot,
            compared case-insensitively. Defaults to ``ALLOWED_EXTENSIONS``.

    Returns:
        An error message if the file is rejected, or ``None`` if it is valid.
        The extension is checked before the size.
    """
    import os

    if max_size is None:
        max_size = MAX_FILE_SIZE
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_EXTENSIONS
    allowed = {e.lower() for e in allowed_extensions}

    # Check extension (lower-cased so "report.PDF" matches ".pdf").
    ext = os.path.splitext(filename)[1].lower()
    if ext not in allowed:
        # Sorted so the message is stable: raw set iteration order depends on
        # per-process string hash randomization.
        return f"File type '{ext}' not allowed. Allowed types: {', '.join(sorted(allowed))}"

    # Check size
    if file_size > max_size:
        max_mb = max_size / (1024 * 1024)
        return f"File too large. Maximum size is {max_mb}MB"
    return None
# Security headers middleware
async def security_headers_middleware(request: Request, call_next):
    """
    Add baseline security headers to all responses.

    The headers are written after the downstream handler runs, overwriting
    any value the route itself set:

    - ``X-Content-Type-Options: nosniff``: disable MIME-type sniffing.
    - ``X-Frame-Options: DENY``: forbid rendering the response in a frame.
    - ``X-XSS-Protection: 0``: explicitly turn off the legacy browser XSS
      filter. The former ``1; mode=block`` value is discouraged by MDN and
      OWASP because the filter has been removed from modern browsers and,
      where it still exists, can itself introduce vulnerabilities. Use a
      Content-Security-Policy for XSS defence instead.
    - ``Referrer-Policy: strict-origin-when-cross-origin``: full URL for
      same-origin requests, origin only for cross-origin requests, and
      nothing when downgrading from HTTPS to HTTP.
    """
    response = await call_next(request)
    security_headers = {
        'X-Content-Type-Options': 'nosniff',
        'X-Frame-Options': 'DENY',
        'X-XSS-Protection': '0',
        'Referrer-Policy': 'strict-origin-when-cross-origin',
    }
    for name, value in security_headers.items():
        response.headers[name] = value
    return response