File size: 6,190 Bytes
92bfe31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
"""
Rate limiting middleware using slowapi.
"""
import logging
import os
from typing import Optional

from fastapi import Request
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded as SlowAPIRateLimitExceeded

logger = logging.getLogger("mathpulse.ratelimit")


def _env_int(name: str, default: str) -> int:
    """Return environment variable *name* parsed as an int (*default* if unset)."""
    return int(os.getenv(name, default))


# Base limits, in requests per minute. Each one can be overridden through the
# environment variable of the same name.
RATE_LIMIT_AI_RPM = _env_int("RATE_LIMIT_AI_RPM", "20")
RATE_LIMIT_QUIZ_GENERATE_RPM = _env_int("RATE_LIMIT_QUIZ_GENERATE_RPM", "10")
RATE_LIMIT_QUIZ_SUBMIT_RPM = _env_int("RATE_LIMIT_QUIZ_SUBMIT_RPM", "30")
RATE_LIMIT_AUTH_RPM = _env_int("RATE_LIMIT_AUTH_RPM", "5")
RATE_LIMIT_LEADERBOARD_RPM = _env_int("RATE_LIMIT_LEADERBOARD_RPM", "60")
RATE_LIMIT_DEFAULT_RPM = _env_int("RATE_LIMIT_DEFAULT_RPM", "100")

# Factors by which a base limit is multiplied for elevated roles.
RATE_LIMIT_ADMIN_MULTIPLIER = _env_int("RATE_LIMIT_ADMIN_MULTIPLIER", "10")
RATE_LIMIT_TEACHER_MULTIPLIER = _env_int("RATE_LIMIT_TEACHER_MULTIPLIER", "3")

# Role name -> multiplier applied to the base limit.
ROLE_MULTIPLIERS = dict(
    admin=RATE_LIMIT_ADMIN_MULTIPLIER,
    teacher=RATE_LIMIT_TEACHER_MULTIPLIER,
    student=1,
)


def _get_user_identifier(request: Request) -> str:
    """
    Build the key under which a request is rate limited.

    Returns ``"uid:<uid>"`` when ``request.state.user`` is truthy and carries a
    truthy ``uid`` attribute. Otherwise returns ``"ip:<host>"`` taken from
    ``request.client``, or ``"ip:unknown"`` when no client is available.
    """
    principal = getattr(request.state, "user", None)
    uid = getattr(principal, "uid", None) if principal else None
    if uid:
        return f"uid:{uid}"

    # Unauthenticated (or uid-less) callers are bucketed by peer address.
    client = request.client
    return f"ip:{client.host}" if client else "ip:unknown"


def _get_user_role(request: Request) -> str:
    """
    Return the role stored on ``request.state.user``.

    Falls back to ``"student"`` when there is no (truthy) user or the user has
    no truthy ``role`` attribute.
    """
    principal = getattr(request.state, "user", None)
    if not principal:
        return "student"
    return getattr(principal, "role", None) or "student"


def _get_role_multiplier(request: Request) -> int:
    """
    Look up the rate limit multiplier for the requesting user's role.

    Roles that are not listed in ``ROLE_MULTIPLIERS`` get a multiplier of 1.
    """
    return ROLE_MULTIPLIERS.get(_get_user_role(request), 1)


class MathPulseLimiter:
    """
    Rate limiter with role-aware multipliers for MathPulse AI.

    Wraps a slowapi ``Limiter`` that keys requests with
    ``_get_user_identifier``, uses ``memory://`` storage and applies
    ``RATE_LIMIT_DEFAULT_RPM`` per minute as its default limit. Every
    ``*_limit`` method returns a limit string such as ``"20/minute"``: the
    endpoint's base RPM multiplied by the caller's role multiplier.

    The ``request`` argument of the ``*_limit`` methods is optional. slowapi's
    dynamic-limit support invokes a limit provider without arguments (or with
    the rate-limit key when the provider declares a ``key`` parameter), so a
    provider that *requires* a request raises ``TypeError`` when used through
    ``Limiter.limit``. Without a request the base limit is returned
    un-multiplied.

    NOTE(review): role multipliers are therefore not applied when slowapi
    calls these providers itself - confirm against the installed slowapi
    version and decide how the role should reach the provider.
    NOTE(review): ``memory://`` storage is presumably per-process, so counters
    would not be shared between workers - confirm for the deployment.
    """

    def __init__(self) -> None:
        self._limiter = Limiter(
            key_func=_get_user_identifier,
            storage_uri="memory://",
            default_limits=[f"{RATE_LIMIT_DEFAULT_RPM}/minute"],
        )

    @property
    def limiter(self) -> Limiter:
        """The wrapped slowapi ``Limiter``."""
        return self._limiter

    def _get_adjusted_limit(
        self, base_rpm: int, request: Optional[Request] = None
    ) -> int:
        """
        Apply the caller's role multiplier to ``base_rpm``.

        Returns ``base_rpm`` unchanged when no request is available.
        """
        if request is None:
            # No request means no user to inspect: use the plain base limit.
            return base_rpm
        return base_rpm * _get_role_multiplier(request)

    def _format_limit(self, base_rpm: int, request: Optional[Request]) -> str:
        """Render the role-adjusted limit as a ``"<n>/minute"`` limit string."""
        return f"{self._get_adjusted_limit(base_rpm, request)}/minute"

    def ai_limit(self, request: Optional[Request] = None) -> str:
        """Rate limit for AI endpoints with role adjustment."""
        return self._format_limit(RATE_LIMIT_AI_RPM, request)

    def quiz_generate_limit(self, request: Optional[Request] = None) -> str:
        """Rate limit for quiz generation with role adjustment."""
        return self._format_limit(RATE_LIMIT_QUIZ_GENERATE_RPM, request)

    def quiz_submit_limit(self, request: Optional[Request] = None) -> str:
        """Rate limit for quiz submission with role adjustment."""
        return self._format_limit(RATE_LIMIT_QUIZ_SUBMIT_RPM, request)

    def auth_limit(self, request: Optional[Request] = None) -> str:
        """Rate limit for auth endpoints with role adjustment."""
        return self._format_limit(RATE_LIMIT_AUTH_RPM, request)

    def leaderboard_limit(self, request: Optional[Request] = None) -> str:
        """Rate limit for leaderboard with role adjustment."""
        return self._format_limit(RATE_LIMIT_LEADERBOARD_RPM, request)

    def default_limit(self, request: Optional[Request] = None) -> str:
        """Default rate limit with role adjustment."""
        return self._format_limit(RATE_LIMIT_DEFAULT_RPM, request)


# Global rate limiter instance: created once at import time and shared by
# setup_rate_limiting() and the *_rate_limit() decorator helpers in this module.
rate_limiter = MathPulseLimiter()


def setup_rate_limiting(app) -> None:
    """
    Wire rate limiting into the FastAPI application.

    Stores the shared slowapi limiter on ``app.state.limiter``, registers the
    JSON 429 handler for slowapi's ``RateLimitExceeded`` and logs the
    configured limits.

    NOTE(review): no slowapi middleware is added here, so the limiter's
    default limits presumably only take effect if middleware is installed
    elsewhere - confirm.
    """

    def _on_rate_limited(request, exc):
        # Forward to the module-level handler, resolved when the error occurs.
        return _rate_limit_exceeded_handler(request, exc)

    # slowapi looks the limiter up on the application state.
    app.state.limiter = rate_limiter.limiter
    app.add_exception_handler(SlowAPIRateLimitExceeded, _on_rate_limited)

    summary = (
        f"Rate limiting configured: AI={RATE_LIMIT_AI_RPM}/min, "
        f"QuizGen={RATE_LIMIT_QUIZ_GENERATE_RPM}/min, "
        f"Auth={RATE_LIMIT_AUTH_RPM}/min, "
        f"Admin={RATE_LIMIT_ADMIN_MULTIPLIER}x, Teacher={RATE_LIMIT_TEACHER_MULTIPLIER}x"
    )
    logger.info(summary)


def _rate_limit_exceeded_handler(request: Request, exc: SlowAPIRateLimitExceeded):
    """
    Turn a slowapi rate limit error into a JSON 429 response.

    The wait time is read from ``exc.retry_after`` when the exception has that
    attribute and defaults to 60 otherwise. It is reported both in the JSON
    body and in the ``Retry-After`` header. ``request`` is not used.

    NOTE(review): it is not visible here whether slowapi's exception ever
    defines ``retry_after``; if it does not, the value is always 60 - confirm.
    """
    from fastapi.responses import JSONResponse

    wait_seconds = getattr(exc, "retry_after", 60)

    payload = {
        "error": "rate_limit_exceeded",
        "message": "Too many requests. Please try again later.",
        "retry_after": wait_seconds,
    }
    extra_headers = {
        "Retry-After": str(wait_seconds),
        "Content-Type": "application/json",
    }
    return JSONResponse(status_code=429, content=payload, headers=extra_headers)


# Decorator helpers
def ai_rate_limit():
    """
    Return the slowapi limit decorator for AI endpoints.

    The limit string is supplied by ``rate_limiter.ai_limit``.
    """
    limit_provider = rate_limiter.ai_limit
    return rate_limiter.limiter.limit(limit_provider)


def quiz_generate_rate_limit():
    """
    Return the slowapi limit decorator for quiz generation endpoints.

    The limit string is supplied by ``rate_limiter.quiz_generate_limit``.
    """
    limit_provider = rate_limiter.quiz_generate_limit
    return rate_limiter.limiter.limit(limit_provider)


def quiz_submit_rate_limit():
    """
    Return the slowapi limit decorator for quiz submission endpoints.

    The limit string is supplied by ``rate_limiter.quiz_submit_limit``.
    """
    limit_provider = rate_limiter.quiz_submit_limit
    return rate_limiter.limiter.limit(limit_provider)


def auth_rate_limit():
    """
    Return the slowapi limit decorator for auth endpoints.

    The limit string is supplied by ``rate_limiter.auth_limit``.
    """
    limit_provider = rate_limiter.auth_limit
    return rate_limiter.limiter.limit(limit_provider)


def leaderboard_rate_limit():
    """
    Return the slowapi limit decorator for leaderboard endpoints.

    The limit string is supplied by ``rate_limiter.leaderboard_limit``.
    """
    limit_provider = rate_limiter.leaderboard_limit
    return rate_limiter.limiter.limit(limit_provider)


def default_rate_limit():
    """
    Return the slowapi limit decorator that applies the default limit.

    The limit string is supplied by ``rate_limiter.default_limit``.
    """
    limit_provider = rate_limiter.default_limit
    return rate_limiter.limiter.limit(limit_provider)