# emotion-detection-api / app/rate_limiter.py
# Uploaded by HimAJ in commit 1e4fc28 ("upload 32 files for the ml").
"""
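Simple in-memory rate limiter for API endpoints.
Limits are tracked in this process's memory only, so they are not shared
between worker processes or hosts and are lost on restart. For production
deployments, consider a shared store such as Redis.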
"""
import time
from collections import defaultdict, deque
from threading import Lock
from typing import Callable, Deque, Dict, Optional, Tuple
class RateLimiter:
    """
    In-memory sliding-window-log rate limiter.

    For every identifier the limiter remembers the timestamps of the requests
    it accepted during the last ``window_seconds``. A new request is accepted
    only while fewer than ``max_requests`` such timestamps exist. Rejected
    requests are not recorded, so hammering the endpoint does not extend a
    client's lock-out.

    All shared state is read and written while holding ``self.lock``, so one
    instance can be shared between threads.
    """

    def __init__(
        self,
        max_requests: int = 100,
        window_seconds: int = 60,
        clock: Optional[Callable[[], float]] = None,
    ):
        """
        Args:
            max_requests: Maximum requests allowed in the time window
            window_seconds: Time window in seconds
            clock: Zero-argument callable returning the current time in
                seconds. Defaults to ``time.monotonic`` so wall-clock
                adjustments (NTP sync, manual changes) cannot shrink or
                stretch the window. A custom clock must never go backwards,
                because pruning assumes timestamps are stored oldest first.
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self._clock = time.monotonic if clock is None else clock
        # identifier -> timestamps of accepted requests, oldest first.
        self.requests: Dict[str, Deque[float]] = defaultdict(deque)
        self.lock = Lock()
        # When idle identifiers were last purged (see _purge_idle).
        self._last_purge = self._clock()

    def is_allowed(self, identifier: str) -> Tuple[bool, int]:
        """
        Check whether a request is allowed, recording it when it is.

        Args:
            identifier: Unique identifier (e.g., IP address, user ID)

        Returns:
            Tuple of (is_allowed, remaining_requests). remaining_requests is
            how many more requests fit in the current window; it is 0 when
            the request is rejected.
        """
        with self.lock:
            # Read the clock under the lock so concurrent callers append
            # timestamps in non-decreasing order, which the popleft-based
            # pruning below relies on.
            now = self._clock()

            # Without this, every identifier ever seen would stay in memory
            # forever; purging once per window keeps the cost amortised.
            if now - self._last_purge >= self.window_seconds:
                self._purge_idle(now)

            window_start = now - self.window_seconds
            stamps = self.requests[identifier]
            # A request exactly window_seconds old no longer counts.
            while stamps and stamps[0] <= window_start:
                stamps.popleft()

            if len(stamps) >= self.max_requests:
                return False, 0

            stamps.append(now)
            return True, self.max_requests - len(stamps)

    def reset(self, identifier: Optional[str] = None) -> None:
        """
        Reset the rate limit for one identifier, or for all identifiers.

        Args:
            identifier: Identifier to forget. ``None`` (the default) clears
                every identifier. Any other value, including ``""``, clears
                only that identifier.
        """
        with self.lock:
            if identifier is None:
                self.requests.clear()
            else:
                self.requests.pop(identifier, None)

    def _purge_idle(self, now: float) -> None:
        """Forget identifiers with no accepted request inside the window; caller holds the lock."""
        window_start = now - self.window_seconds
        idle = [
            key
            for key, stamps in self.requests.items()
            if not stamps or stamps[-1] <= window_start
        ]
        for key in idle:
            del self.requests[key]
        self._last_purge = now
# Process-wide limiters shared by the API endpoints. All three use the same
# 60-second window; only the per-identifier request budget differs.
_ONE_MINUTE = 60

# detect: at most 30 requests per identifier per minute
detect_limiter = RateLimiter(window_seconds=_ONE_MINUTE, max_requests=30)
# logs: at most 100 requests per identifier per minute
logs_limiter = RateLimiter(window_seconds=_ONE_MINUTE, max_requests=100)
# images: at most 200 requests per identifier per minute
images_limiter = RateLimiter(window_seconds=_ONE_MINUTE, max_requests=200)
def get_client_identifier(request, trusted_proxies: Optional[int] = None) -> str:
    """
    Get a unique identifier for rate limiting.

    Uses the client IP address, preferring proxy headers when present.

    Args:
        request: Object exposing ``headers`` (with ``.get``) and
            ``remote_addr``. Presumably a Flask/Werkzeug request; confirm
            against callers.
        trusted_proxies: Number of reverse proxies in front of the app.
            ``None`` (default) keeps the legacy behaviour: use the leftmost
            ``X-Forwarded-For`` entry, then ``X-Real-IP``, then
            ``remote_addr``.
            NOTE(security): those headers are client-controlled unless a proxy
            overwrites them. A client can therefore rotate their values to
            get a fresh rate-limit bucket on every request.
            Pass the real proxy count instead (``0`` when the app is exposed
            directly). The function then uses the ``X-Forwarded-For`` entry
            appended by the outermost trusted proxy (``entries[-N]``) and
            ignores anything the client prepended. It falls back to
            ``remote_addr`` when the header has too few entries.

    Returns:
        The client address string, or ``"unknown"`` if none is available.

    Raises:
        ValueError: if ``trusted_proxies`` is negative.
    """
    fallback = request.remote_addr or "unknown"

    if trusted_proxies is not None:
        if trusted_proxies < 0:
            raise ValueError("trusted_proxies must be >= 0")
        if trusted_proxies == 0:
            return fallback
        hops = [hop.strip() for hop in (request.headers.get("X-Forwarded-For") or "").split(",")]
        if len(hops) >= trusted_proxies and hops[-trusted_proxies]:
            return hops[-trusted_proxies]
        return fallback

    # Legacy mode: trust whatever the proxy headers say.
    forwarded_for = request.headers.get("X-Forwarded-For")
    if forwarded_for:
        # Take the first IP in the chain; skip it if blank so clients sending
        # e.g. " , 1.2.3.4" don't all share the "" bucket.
        first = forwarded_for.split(",")[0].strip()
        if first:
            return first
    real_ip = (request.headers.get("X-Real-IP") or "").strip()
    if real_ip:
        return real_ip
    return fallback