File size: 4,183 Bytes
4b62d23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
Rate limiting middleware for FastAPI.
"""

import logging
from typing import Callable, Optional

from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, Response

from .mongodb_service import get_mongodb_service
from .redis_service import get_redis_service

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Rate limiting middleware using Redis.

    Each request is attributed to one identifier, chosen with priority
    user_id > device_id > client IP, and counted against a budget of
    ``max_requests`` per ``window_seconds`` via
    ``redis_service.check_rate_limit``. Requests over budget get a 429 JSON
    response and a ``RATE_LIMIT_HIT`` event is sent to
    ``mongodb_service.log_event``.

    Default policy: 2 requests per identifier per 60 seconds.
    """

    # Exact request paths that are never rate limited (health checks, docs).
    EXEMPT_PATHS = frozenset({"/health", "/", "/docs", "/openapi.json"})
    # Request path prefix that is never rate limited.
    EXEMPT_PREFIX = "/admin/"

    def __init__(self, app, max_requests: int = 2, window_seconds: int = 60):
        """
        Args:
            app: The ASGI application being wrapped.
            max_requests: Requests allowed per identifier per window.
            window_seconds: Length of the rate-limit window, in seconds.
        """
        super().__init__(app)
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.redis_service = get_redis_service()
        self.mongodb_service = get_mongodb_service()

    @staticmethod
    def _get_client_ids(request: Request) -> tuple[Optional[str], Optional[str]]:
        """
        Return ``(user_id, device_id)`` for the request.

        Each value is read from its header (``X-User-Id`` / ``X-Device-Id``),
        falling back to the matching cookie; missing values are ``None``.
        """
        # NOTE(review): both values are client-supplied and not verified here,
        # so a client can rotate them to obtain a fresh budget. Confirm whether
        # an upstream layer authenticates them.
        user_id = request.headers.get("X-User-Id") or request.cookies.get("user_id")
        device_id = request.headers.get("X-Device-Id") or request.cookies.get("device_id")
        return user_id, device_id

    def _get_identifier(self, request: Request) -> tuple[str, str]:
        """
        Get rate limit identifier with priority: user_id > device_id > IP.

        Returns:
            Tuple of (identifier, identifier_type), e.g. ``("user:42", "user_id")``.
        """
        user_id, device_id = self._get_client_ids(request)
        if user_id:
            return f"user:{user_id}", "user_id"
        if device_id:
            return f"device:{device_id}", "device_id"
        # Fallback to the peer address; request.client may be None (e.g. in tests).
        client_ip = request.client.host if request.client else "unknown"
        return f"ip:{client_ip}", "ip"

    def _is_exempt(self, path: str) -> bool:
        """Return True for admin and health/docs paths that bypass limiting."""
        return path.startswith(self.EXEMPT_PREFIX) or path in self.EXEMPT_PATHS

    def _rate_limit_headers(self, current_count: int) -> dict[str, str]:
        """Build the ``X-RateLimit-*`` headers for the given request count."""
        return {
            "X-RateLimit-Limit": str(self.max_requests),
            "X-RateLimit-Remaining": str(max(0, self.max_requests - current_count)),
            # The window length, not the time remaining until the counter resets.
            "X-RateLimit-Reset": str(self.window_seconds),
        }

    def _log_rate_limit_hit(
        self,
        request: Request,
        identifier: str,
        identifier_type: str,
        current_count: int,
    ) -> None:
        """Record a ``RATE_LIMIT_HIT`` event; failures are logged, not raised."""
        user_id, device_id = self._get_client_ids(request)
        try:
            self.mongodb_service.log_event(
                event_type="RATE_LIMIT_HIT",
                # Fall back to the rate-limit identifier when no device id was sent.
                device_id=device_id or identifier,
                user_id=user_id,
                metadata={
                    "identifier": identifier,
                    "identifier_type": identifier_type,
                    "count": current_count,
                    "path": request.url.path,
                },
            )
        except Exception:
            # Audit logging is best-effort: a MongoDB failure must not turn
            # the 429 into a 500.
            logger.exception("Failed to log RATE_LIMIT_HIT event for %s", identifier)

    async def dispatch(
        self,
        request: Request,
        call_next: Callable
    ) -> Response:
        """
        Process request with rate limiting.

        Exempt paths are passed straight through. Otherwise the request is
        counted in Redis; over-limit requests get a 429 JSON response with a
        ``Retry-After`` header. Allowed requests are forwarded and the response
        is annotated with ``X-RateLimit-*`` headers.
        """
        if self._is_exempt(request.url.path):
            return await call_next(request)

        identifier, identifier_type = self._get_identifier(request)

        # NOTE(review): check_rate_limit (and log_event below) are synchronous
        # calls inside an async handler and block the event loop for the
        # round-trip; consider async clients or run_in_threadpool.
        is_allowed, current_count = self.redis_service.check_rate_limit(
            identifier,
            self.max_requests,
            self.window_seconds
        )
        rate_limit_headers = self._rate_limit_headers(current_count)

        if not is_allowed:
            self._log_rate_limit_hit(request, identifier, identifier_type, current_count)
            logger.warning(
                "Rate limit exceeded for %s (count: %s, max: %s)",
                identifier,
                current_count,
                self.max_requests,
            )
            # Return the response directly rather than raising HTTPException:
            # exceptions raised in BaseHTTPMiddleware.dispatch bypass FastAPI's
            # exception handlers (they sit in an inner middleware) and reach
            # the client as a 500. The body mirrors FastAPI's {"detail": ...}.
            return JSONResponse(
                status_code=429,
                content={
                    "detail": {
                        "error": "Rate limit exceeded",
                        "message": f"Too many requests. Maximum {self.max_requests} requests per {self.window_seconds} seconds.",
                        "retry_after": self.window_seconds,
                    }
                },
                headers={**rate_limit_headers, "Retry-After": str(self.window_seconds)},
            )

        response = await call_next(request)
        for header_name, header_value in rate_limit_headers.items():
            response.headers[header_name] = header_value
        return response