Spaces:
Running
Running
| # src/api/middleware/rate_limit.py | |
| from starlette.middleware.base import BaseHTTPMiddleware | |
| from starlette.requests import Request | |
| from starlette.responses import JSONResponse | |
| import time | |
| from collections import defaultdict | |
| # We add this rate-limiting middleware to protect the application from abuse: an attacker could flood it with requests to block legitimate users or crash it. | |
class SimpleRateLimitMiddleware(BaseHTTPMiddleware):
    """
    Simple sliding-window rate limiter keyed by client IP.

    Each client may make at most ``requests_per_minute`` accepted requests in
    any rolling window of ``window_seconds`` (one minute by default, hence the
    parameter name). Further requests get HTTP 429 with a ``Retry-After``
    header and never reach the route.

    State is kept in this process's memory, so limits apply per worker
    process and are lost on restart. In production you'd use a shared store
    such as Redis for this. This is a simple example for learning.

    NOTE(review): clients are identified by ``request.client.host``; behind a
    reverse proxy that may be the proxy's address unless the server is
    configured to trust forwarded headers — confirm for your deployment.
    """

    def __init__(self, app, requests_per_minute: int = 60, window_seconds: float = 60.0):
        """
        Args:
            app: The ASGI application to wrap.
            requests_per_minute: Maximum accepted requests per client per window.
            window_seconds: Length of the rolling window in seconds.

        Raises:
            ValueError: If ``window_seconds`` is not positive.
        """
        if window_seconds <= 0:
            raise ValueError("window_seconds must be positive")
        super().__init__(app)
        self.requests_per_minute = requests_per_minute
        self.window_seconds = window_seconds
        # IP -> time.monotonic() timestamps of accepted requests, oldest first.
        # Only accepted requests are recorded, so each list holds at most
        # requests_per_minute entries.
        self.requests = defaultdict(list)
        self._last_sweep = time.monotonic()

    async def dispatch(self, request: Request, call_next):
        """Answer 429 if this client is over its limit, otherwise forward the request."""
        # request.client is None when the server reports no peer address
        # (e.g. some test clients); such requests share one bucket instead of
        # crashing with AttributeError.
        client_ip = request.client.host if request.client is not None else "unknown"
        # Monotonic clock: unaffected by wall-clock jumps (NTP, manual changes)
        # that would otherwise expire or freeze every client's window.
        now = time.monotonic()
        cutoff = now - self.window_seconds

        self._evict_idle_clients(now, cutoff)

        # .get() avoids creating an entry for a client we might not keep.
        recent = [t for t in self.requests.get(client_ip, ()) if t > cutoff]
        if recent:
            self.requests[client_ip] = recent
        else:
            # Don't retain empty buckets; they would accumulate forever.
            self.requests.pop(client_ip, None)

        if len(recent) >= self.requests_per_minute:
            # Seconds until the oldest recorded request leaves the window, as a
            # whole number of seconds (at least 1); with nothing recorded
            # (limit <= 0), wait a full window.
            oldest = recent[0] if recent else now
            retry_after = int(oldest + self.window_seconds - now) + 1
            return JSONResponse(
                status_code=429,
                content={"error": "Too many requests. Please slow down."},
                headers={"Retry-After": str(retry_after)},
            )

        # Record this accepted request, then continue to the route.
        recent.append(now)
        self.requests[client_ip] = recent
        return await call_next(request)

    def _evict_idle_clients(self, now: float, cutoff: float) -> None:
        """Forget clients with no request newer than ``cutoff``; runs at most once per window.

        Without this sweep every IP ever seen would stay in ``self.requests``
        forever, so memory would grow without bound.
        """
        if now - self._last_sweep < self.window_seconds:
            return
        self._last_sweep = now
        # Timestamps are appended in monotonic order, so the last one is newest.
        idle = [
            ip for ip, stamps in self.requests.items()
            if not stamps or stamps[-1] <= cutoff
        ]
        for ip in idle:
            del self.requests[ip]