File size: 4,496 Bytes
67befa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""Simple rate limiting middleware for demo/single user scenarios."""

import math
import time
from collections import defaultdict, deque
from typing import Callable

from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse

from app.utils.constants import (
    RATE_LIMIT_REQUESTS_PER_MINUTE,
    RATE_LIMIT_REQUESTS_PER_HOUR,
)


class SimpleRateLimiter:
    """Simple in-memory rate limiter for demo use (not for production with multiple servers)."""
    
    def __init__(self):
        # Track requests by IP address
        self._requests_by_ip: dict[str, deque] = defaultdict(lambda: deque())
        self._last_cleanup = time.time()
        self._cleanup_interval = 300  # Clean up old entries every 5 minutes
    
    def _cleanup_old_entries(self):
        """Remove old request timestamps to prevent memory growth."""
        current_time = time.time()
        if current_time - self._last_cleanup < self._cleanup_interval:
            return
        
        cutoff_minute = current_time - 60
        cutoff_hour = current_time - 3600
        
        for ip in list(self._requests_by_ip.keys()):
            requests = self._requests_by_ip[ip]
            # Keep only requests from last hour
            while requests and requests[0] < cutoff_hour:
                requests.popleft()
            
            # Remove IP if no recent requests
            if not requests:
                del self._requests_by_ip[ip]
        
        self._last_cleanup = current_time
    
    def check_rate_limit(self, ip: str) -> tuple[bool, str | None]:
        """
        Check if request should be allowed.
        
        Returns:
            (allowed, error_message)
        """
        self._cleanup_old_entries()
        
        current_time = time.time()
        requests = self._requests_by_ip[ip]
        
        # Remove requests older than 1 hour
        cutoff_hour = current_time - 3600
        while requests and requests[0] < cutoff_hour:
            requests.popleft()
        
        # Check hourly limit
        if len(requests) >= RATE_LIMIT_REQUESTS_PER_HOUR:
            return False, f"Rate limit exceeded: {RATE_LIMIT_REQUESTS_PER_HOUR} requests per hour"
        
        # Check per-minute limit (last 60 seconds)
        cutoff_minute = current_time - 60
        recent_requests = [r for r in requests if r >= cutoff_minute]
        if len(recent_requests) >= RATE_LIMIT_REQUESTS_PER_MINUTE:
            return False, f"Rate limit exceeded: {RATE_LIMIT_REQUESTS_PER_MINUTE} requests per minute"
        
        # Record this request
        requests.append(current_time)
        return True, None


# Module-level singleton shared by every request this process handles; its
# counts live only in this process's memory.
_rate_limiter: SimpleRateLimiter = SimpleRateLimiter()


# Paths exempt from rate limiting (exact match against request.url.path).
_PUBLIC_PATHS = frozenset({"/", "/health", "/docs", "/redoc", "/openapi.json", "/v1/stats"})


def _retry_after_seconds(timestamps, now: float) -> int:
    """Return whole seconds until a rejected client can be allowed again.

    Mirrors the limiter's check order: the hourly window is considered first,
    then the minute window. Falls back to 60 if neither window looks exhausted.
    """
    for window, limit in ((3600, RATE_LIMIT_REQUESTS_PER_HOUR), (60, RATE_LIMIT_REQUESTS_PER_MINUTE)):
        in_window = [ts for ts in timestamps if ts >= now - window]
        excess = len(in_window) - limit
        if excess < 0:
            continue
        if excess >= len(in_window):
            # limit <= 0: no amount of waiting frees capacity; suggest a full window.
            return window
        # The (excess + 1) oldest in-window requests must age out before a new one fits.
        return max(1, math.ceil(in_window[excess] + window - now))
    return 60


async def rate_limit_middleware(request: Request, call_next: Callable):
    """Rate limiting middleware.

    Requests to ``_PUBLIC_PATHS`` pass straight through. Other requests are
    keyed by client IP (``"unknown"`` when ``request.client`` is missing) and
    checked against the global limiter.

    - Rejected requests get a 429 JSON error. Its ``Retry-After`` is derived
      from the window (hour or minute) that is exhausted.
    - Allowed requests are forwarded, and the response is annotated with
      ``X-RateLimit-*`` limit/remaining headers.
    """
    if request.url.path in _PUBLIC_PATHS:
        return await call_next(request)

    client_ip = request.client.host if request.client else "unknown"

    allowed, error_msg = _rate_limiter.check_rate_limit(client_ip)

    # Snapshot usage right after the check, before any await, so other
    # requests cannot change it in between. .get() avoids creating an entry.
    now = time.time()
    timestamps = _rate_limiter._requests_by_ip.get(client_ip, ())
    limit_headers = {
        "X-RateLimit-Limit-Minute": str(RATE_LIMIT_REQUESTS_PER_MINUTE),
        "X-RateLimit-Limit-Hour": str(RATE_LIMIT_REQUESTS_PER_HOUR),
    }

    if not allowed:
        return JSONResponse(
            status_code=429,
            content={
                "error": {
                    "message": error_msg,
                    "type": "rate_limit_error"
                }
            },
            headers={
                "Retry-After": str(_retry_after_seconds(timestamps, now)),
                **limit_headers,
            },
        )

    used_minute = sum(1 for ts in timestamps if ts >= now - 60)
    used_hour = sum(1 for ts in timestamps if ts >= now - 3600)

    response = await call_next(request)

    for name, value in limit_headers.items():
        response.headers[name] = value
    response.headers["X-RateLimit-Remaining-Minute"] = str(max(0, RATE_LIMIT_REQUESTS_PER_MINUTE - used_minute))
    response.headers["X-RateLimit-Remaining-Hour"] = str(max(0, RATE_LIMIT_REQUESTS_PER_HOUR - used_hour))

    return response