File size: 2,547 Bytes
fba30db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26f3f24
 
fba30db
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""Phase 15.2 — Rate limiting via slowapi.

Exposes a shared `limiter` plus a request key function that differentiates
authenticated users (keyed by user id) from anonymous clients (keyed by IP).

slowapi 0.1.9 calls `exempt_when()` with no arguments, so we stash the current
Request in a ContextVar via a middleware (`RateLimitContextMiddleware`) and
read from it inside the exempt predicates.
"""
from __future__ import annotations

from contextvars import ContextVar

from fastapi import Request
from slowapi import Limiter
from slowapi.util import get_remote_address
from starlette.middleware.base import BaseHTTPMiddleware

from services.auth_service import decode_token

# Holds the Request currently being handled so the zero-argument predicates
# in this module (is_authed / is_anon) can reach it. It is set and reset by
# RateLimitContextMiddleware and reads as None when nothing has been stashed.
_REQUEST_STASH: ContextVar[Request | None] = ContextVar("_REQUEST_STASH", default=None)


def _bearer(request: Request) -> str | None:
    """Extract the credential from an ``Authorization: Bearer <token>`` header.

    Returns the token part when the header consists of exactly two
    whitespace-separated fields and the first one is "bearer" (compared
    case-insensitively). Returns None when the header is missing, empty,
    uses another scheme, or has a different number of fields.
    """
    header_value = request.headers.get("authorization")
    # A missing/empty header becomes an empty field list, which fails the
    # length check below just like any other malformed value.
    fields = header_value.split() if header_value else []
    if len(fields) == 2 and fields[0].lower() == "bearer":
        return fields[1]
    return None


def _authed_user_id(request: Request) -> str | None:
    """Return the authenticated caller's user id as a string, or None.

    The id is taken from the ``sub`` entry of the payload that
    ``decode_token`` produces for the request's bearer token. None is
    returned when there is no bearer token, when ``decode_token`` returns a
    falsy payload, when the payload has no ``sub`` entry, or when ``sub`` is
    null or stringifies to an empty value.

    NOTE(review): ``decode_token`` is assumed to signal an invalid token by
    returning a falsy value; if it raises instead, the exception propagates
    to the caller — confirm against services.auth_service.
    """
    token = _bearer(request)
    if not token:
        return None
    payload = decode_token(token)
    if not payload or "sub" not in payload:
        return None
    subject = payload["sub"]
    if subject is None:
        # str(None) would produce the literal "None", which would then be
        # treated as a real user id by every caller.
        return None
    # Normalise an empty id to None so truthiness checks (request_key) and
    # `is not None` checks (is_authed / is_anon) always agree.
    return str(subject) or None


def request_key(request: Request) -> str:
    """Build the rate-limit bucket key for a request.

    Authenticated callers get ``user:<id>``; everyone else is bucketed by
    ``ip:<remote address>`` as reported by slowapi's ``get_remote_address``.
    """
    user_id = _authed_user_id(request)
    return f"user:{user_id}" if user_id else f"ip:{get_remote_address(request)}"


def is_authed() -> bool:
    """Report whether the stashed request carries a resolvable user id.

    Takes no arguments so it can be used as a slowapi ``exempt_when``
    predicate; the request is read from the module's ContextVar. Returns
    False when no request has been stashed.
    """
    current = _REQUEST_STASH.get()
    return current is not None and _authed_user_id(current) is not None


def is_anon() -> bool:
    """Report whether the stashed request has no resolvable user id.

    Takes no arguments so it can be used as a slowapi ``exempt_when``
    predicate; the request is read from the module's ContextVar. Returns
    True when no request has been stashed.
    """
    current = _REQUEST_STASH.get()
    return current is None or _authed_user_id(current) is None


class RateLimitContextMiddleware(BaseHTTPMiddleware):
    """Middleware that publishes the in-flight Request through a ContextVar.

    slowapi invokes `exempt_when` predicates without arguments, so the
    request is stashed here for the duration of the downstream call and the
    previous value is restored afterwards.
    """

    async def dispatch(self, request: Request, call_next):
        """Stash `request`, run the rest of the stack, then restore the stash."""
        stash_token = _REQUEST_STASH.set(request)
        try:
            response = await call_next(request)
        finally:
            # Restore the previous value whether the downstream call
            # returned or raised.
            _REQUEST_STASH.reset(stash_token)
        return response


# Per-route rate limits — anon gets strict caps, authed gets generous quotas.
# Each value is a "<count>/<period>" limit string. The routes that consume
# these constants are not defined in this module.
# NOTE(review): ANON_* / AUTH_* are presumably paired with the is_authed /
# is_anon predicates via `exempt_when` at the route decorators — verify there.
ANON_ANALYZE = "5/hour"
AUTH_ANALYZE = "50/hour"
ANON_REPORT = "2/hour"
AUTH_REPORT = "20/hour"
# NOTE(review): judging by the names, these two cover the register and login
# endpoints, which have no AUTH_* counterpart here — confirm at the routes.
ANON_AUTH_REGISTER = "5/hour"
ANON_AUTH_LOGIN = "10/minute"

# Shared limiter instance. Buckets are keyed by request_key (user id when
# authenticated, IP otherwise) and no default limits are configured here.
limiter = Limiter(
    key_func=request_key,
    default_limits=[],
    headers_enabled=True,  # presumably adds rate-limit response headers — confirm in slowapi docs
    enabled=True,
)