File size: 5,652 Bytes
3b7f713
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
Rate Limiting Middleware dla GrantForge AI.
Chroni endpointy generatora i audytora przed nadu偶yciami.
U偶ywa prostego in-memory store (dla multi-worker贸w u偶yj Redis).
"""

import time
import logging
from collections import defaultdict
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

logger = logging.getLogger(__name__)

# Limit configuration per endpoint group. Every individual limit maps
#   requests       -> maximum number of requests allowed inside the window
#   window_seconds -> length of the sliding window, in seconds
RATE_LIMITS = {
    # SSE generator: 5 requests per 5 minutes.
    "/api/generator/stream": dict(requests=5, window_seconds=300),
    # Project operations, one limit per operation.
    "/api/projects/": dict(
        audit=dict(requests=10, window_seconds=3600),  # 10 audits per hour
        autofix=dict(requests=200, window_seconds=3600),  # 200 autofixes per hour
    ),
    # Intended for all remaining endpoints: 120 requests per minute.
    # NOTE(review): RateLimitMiddleware in this module never reads this
    # entry — confirm whether it is meant to be enforced.
    "default": dict(requests=120, window_seconds=60),
}

# In-memory, per-process request log:
#   {user_id: {endpoint_key: [timestamp, ...]}}
_request_log: defaultdict[str, defaultdict[str, list[float]]] = defaultdict(
    lambda: defaultdict(list)
)


def _get_user_id(request: Request) -> str:
    """
    Return the identity used as the rate-limit bucket for *request*.

    Resolution order:

    1. ``Authorization: Bearer <token>`` header. The literal
       ``dev_test_token`` maps to ``"dev_user"``; any other token is decoded
       as a JWT and its ``sub`` claim is used.
    2. ``?token=<jwt>`` query parameter (the SSE generator sends its token
       this way); its ``sub`` claim is used.
    3. The client IP address, or ``"unknown"`` when the request carries no
       client information.

    A token that decodes but has no usable ``sub`` claim maps to the client
    IP. A token that cannot be decoded is skipped and the next source is tried.

    SECURITY NOTE(review): tokens are decoded WITHOUT signature verification,
    so a client can put any ``sub`` it likes into a self-made token and get a
    fresh bucket on every request. Verify the signature with the application's
    JWT key if these limits must hold against hostile clients.
    """
    import jwt

    # ``request.client`` may be None; resolve the IP fallback once and safely
    # instead of dereferencing ``request.client.host`` inline, which used to
    # raise inside the try block and discard a perfectly valid ``sub``.
    host = getattr(request.client, "host", "unknown")

    def subject_from(token: str):
        """Return the token's ``sub`` (or *host*), or None if it cannot be decoded."""
        try:
            claims = jwt.decode(token, options={"verify_signature": False})
        except Exception as exc:  # best-effort: a bad token must never fail the request
            logger.debug("Rate limit: ignoring undecodable token (%s)", type(exc).__name__)
            return None
        subject = claims.get("sub")
        # An empty/None ``sub`` would otherwise become one bucket shared by
        # every such client; normalise to str so bucket keys are uniform.
        return str(subject) if subject else host

    auth = request.headers.get("Authorization", "")
    if auth.startswith("Bearer "):
        token = auth.split(" ", 1)[1]
        if token == "dev_test_token":
            return "dev_user"
        user_id = subject_from(token)
        if user_id is not None:
            return user_id

    # Fallback: token in the query string (generator SSE).
    token = request.query_params.get("token", "")
    if token:
        user_id = subject_from(token)
        if user_id is not None:
            return user_id

    return host


def _check_rate_limit(user_id: str, endpoint_key: str, limit: dict) -> tuple[bool, int]:
    """
    Apply a sliding-window limit to one (user, endpoint) bucket.

    Timestamps older than ``limit["window_seconds"]`` are discarded. If fewer
    than ``limit["requests"]`` remain, the current request is recorded and
    allowed.

    Args:
        user_id: Owner of the bucket (see ``_get_user_id``).
        endpoint_key: Bucket name within that user's entries.
        limit: Mapping with ``"requests"`` (max count) and
            ``"window_seconds"`` (window length).

    Returns:
        ``(True, 0)`` when the request is allowed (it is then recorded), or
        ``(False, retry_after)`` where ``retry_after`` is a whole number of
        seconds (>= 1) until the oldest recorded request leaves the window.

    NOTE(review): buckets are never evicted, so ``_request_log`` keeps one
    entry for every distinct user/endpoint key this process has seen.
    """
    # Monotonic clock: intervals must not be affected by wall-clock jumps
    # (NTP corrections, manual changes), which would expire entries early,
    # keep them too long, or break the chronological order relied on below.
    now = time.monotonic()
    window = limit["window_seconds"]
    max_requests = limit["requests"]

    user_log = _request_log[user_id]
    # Drop entries that have slid out of the window.
    recent = [t for t in user_log[endpoint_key] if now - t < window]
    user_log[endpoint_key] = recent

    if len(recent) >= max_requests:
        # Entries are appended in time order, so the head expires first.
        # ``recent`` can only be empty here when max_requests <= 0; wait a
        # full window then instead of raising IndexError.
        oldest = recent[0] if recent else now
        retry_after = int(window - (now - oldest)) + 1
        return False, retry_after

    recent.append(now)
    return True, 0


class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Middleware applying per-user sliding-window rate limits to selected endpoints.

    Rules, checked in this order for every non-exempt request:

    1. ``STRICT_PATHS``: exact path match, limited by
       ``RATE_LIMITS["/api/generator/stream"]`` and bucketed by path.
    2. ``PATTERN_LIMITS``: the first fragment contained anywhere in the path
       applies, bucketed by ``"<path>:<method>"``. Each concrete path (for
       example each project) and each HTTP method therefore has its own budget.

    A rejected request gets HTTP 429 with a ``Retry-After`` header. Responses
    for strict paths that pass the limit carry ``X-RateLimit-*`` headers.

    NOTE(review): ``RATE_LIMITS["default"]`` is not enforced here, so requests
    matching neither rule are unlimited. Confirm whether that is intended.
    """

    # Never rate limited: health checks, API docs and the root page.
    EXEMPT_PATHS = frozenset({"/health", "/api/health", "/", "/docs", "/openapi.json"})

    # Endpoints subject to the strict generator limit (exact path match).
    STRICT_PATHS = {
        "/api/generator/stream",
    }

    # URL fragment -> limit. A rule matches when the fragment occurs anywhere in
    # the path (so "/audit" also matches e.g. "/audit-log"). The limit objects
    # come from RATE_LIMITS so the numbers are configured in a single place.
    PATTERN_LIMITS = {
        "/audit": RATE_LIMITS["/api/projects/"]["audit"],
        "/autofix": RATE_LIMITS["/api/projects/"]["autofix"],
    }

    @staticmethod
    def _reject(user_id: str, path: str, message: str, retry_after: int) -> JSONResponse:
        """Log a rate-limit hit and build the HTTP 429 response returned to the client."""
        logger.warning("Rate limit: %s @ %s (retry in %ss)", user_id, path, retry_after)
        return JSONResponse(
            status_code=429,
            content={
                "detail": message,
                "retry_after": retry_after,
            },
            headers={"Retry-After": str(retry_after)},
        )

    async def dispatch(self, request: Request, call_next):
        """Enforce the configured limits, then forward the request downstream."""
        path = request.url.path

        # Health checks and documentation are never limited.
        if path in self.EXEMPT_PATHS:
            return await call_next(request)

        user_id = _get_user_id(request)
        strict_limit = RATE_LIMITS["/api/generator/stream"]

        # 1. Strict limit for the generator.
        if path in self.STRICT_PATHS:
            allowed, retry_after = _check_rate_limit(user_id, path, strict_limit)
            if not allowed:
                return self._reject(
                    user_id,
                    path,
                    f"Przekroczono limit zapytań. Spróbuj ponownie za {retry_after} sekund.",
                    retry_after,
                )

        # 2. Limits for the audit/autofix URL patterns (first match wins).
        pattern_limit = next(
            (limit for pattern, limit in self.PATTERN_LIMITS.items() if pattern in path),
            None,
        )
        if pattern_limit is not None:
            endpoint_key = f"{path}:{request.method}"
            allowed, retry_after = _check_rate_limit(user_id, endpoint_key, pattern_limit)
            if not allowed:
                return self._reject(
                    user_id,
                    path,
                    f"Przekroczono limit operacji AI. Spróbuj ponownie za {retry_after} sekund.",
                    retry_after,
                )

        response = await call_next(request)

        # Informational headers describing the strict limit (optional).
        if path in self.STRICT_PATHS:
            used = len(_request_log[user_id][path])
            response.headers["X-RateLimit-Limit"] = str(strict_limit["requests"])
            response.headers["X-RateLimit-Remaining"] = str(max(0, strict_limit["requests"] - used))
            response.headers["X-RateLimit-Window"] = str(strict_limit["window_seconds"])

        return response