Spaces:
Running
Running
| """ | |
| Simple in-memory rate limiter for API endpoints. | |
| For production, consider using Redis-based rate limiting. | |
| """ | |
import time
from collections import defaultdict
from threading import Lock
from typing import Callable, Dict, Optional, Tuple
class RateLimiter:
    """
    In-memory sliding-window rate limiter.

    For every identifier the limiter keeps the timestamps of the requests it
    accepted during the last ``window_seconds``. A new request is accepted
    only while fewer than ``max_requests`` of those timestamps remain.
    (This is a sliding-window log, not a token bucket: there is no refill
    rate and no burst capacity separate from ``max_requests``.)

    All state is guarded by a single ``threading.Lock``, so one instance may
    be shared between threads of one process. State lives only in this
    process's memory.
    """

    def __init__(
        self,
        max_requests: int = 100,
        window_seconds: int = 60,
        *,
        clock: Callable[[], float] = time.monotonic,
    ):
        """
        Args:
            max_requests: Maximum requests allowed per identifier in the window.
            window_seconds: Length of the sliding window in seconds.
            clock: Zero-argument callable returning the current time in
                seconds. Defaults to ``time.monotonic`` so wall-clock
                adjustments (NTP, DST, manual changes) cannot shrink or
                stretch a window. It must never go backwards.
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.requests: Dict[str, list] = defaultdict(list)
        self.lock = Lock()
        self._clock = clock
        # Clock reading at the last sweep for identifiers that went quiet.
        self._last_sweep = clock()

    def is_allowed(self, identifier: str) -> Tuple[bool, int]:
        """
        Check whether a request is allowed and, if so, record it.

        Args:
            identifier: Unique identifier (e.g., IP address, user ID).

        Returns:
            Tuple of (is_allowed, remaining_requests). A rejected request is
            not recorded and reports 0 remaining.
        """
        with self.lock:
            # Read the clock under the lock so timestamps are appended in
            # non-decreasing order; the sweep relies on the last being newest.
            now = self._clock()
            window_start = now - self.window_seconds
            self._sweep_stale(now, window_start)

            # .get() avoids creating an entry merely by looking one up.
            recent = [
                stamp
                for stamp in self.requests.get(identifier, ())
                if stamp > window_start
            ]
            if len(recent) >= self.max_requests:
                self.requests[identifier] = recent
                return False, 0

            recent.append(now)
            self.requests[identifier] = recent
            return True, self.max_requests - len(recent)

    def _sweep_stale(self, now: float, window_start: float) -> None:
        """Drop identifiers with no request inside the window (caller holds lock).

        Runs at most once per ``window_seconds``. Without it, every identifier
        ever seen would stay in ``self.requests`` forever, because pruning in
        ``is_allowed`` only touches the identifier being checked.
        """
        if now - self._last_sweep < self.window_seconds:
            return
        self._last_sweep = now
        stale = [
            key
            for key, stamps in self.requests.items()
            if not stamps or stamps[-1] <= window_start
        ]
        for key in stale:
            del self.requests[key]

    def reset(self, identifier: Optional[str] = None) -> None:
        """Reset the limit for one identifier, or for all when ``identifier`` is None.

        ``None`` is checked explicitly so an empty-string identifier resets only
        itself instead of wiping every client's state.
        """
        with self.lock:
            if identifier is None:
                self.requests.clear()
            else:
                self.requests.pop(identifier, None)
# Process-wide limiters, one per endpoint family. Arguments are
# (max_requests, window_seconds): each identifier may make that many
# requests in any rolling 60-second window.
detect_limiter = RateLimiter(30, 60)    # 30 requests / minute
logs_limiter = RateLimiter(100, 60)     # 100 requests / minute
images_limiter = RateLimiter(200, 60)   # 200 requests / minute
def get_client_identifier(request, *, trust_forwarded_headers: bool = True) -> str:
    """
    Get a unique identifier for rate limiting, based on the client IP.

    Lookup order: first non-empty entry of ``X-Forwarded-For``, then
    ``X-Real-IP``, then ``request.remote_addr``, then the literal "unknown".
    Blank header values are skipped so they cannot collapse unrelated
    clients onto an empty-string identifier.

    Args:
        request: Incoming request. Presumably a Flask/Werkzeug-style object
            exposing ``headers.get`` and ``remote_addr`` — confirm against
            callers.
        trust_forwarded_headers: When False, ignore ``X-Forwarded-For`` /
            ``X-Real-IP`` and use only ``remote_addr``. Set this when the app
            is NOT behind a proxy that overwrites these headers.

    Returns:
        A non-empty identifier string.
    """
    if trust_forwarded_headers:
        # NOTE(security): these headers are client-controlled unless a trusted
        # proxy rewrites them; a client can rotate fake values to evade limits.
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            # The left-most non-empty hop is the originating client.
            for hop in forwarded_for.split(","):
                candidate = hop.strip()
                if candidate:
                    return candidate

        real_ip = (request.headers.get("X-Real-IP") or "").strip()
        if real_ip:
            return real_ip

    return request.remote_addr or "unknown"