Spaces:
Running
Running
Merge pull request #299 from saurabhhhcodes/security/chat-rate-limit-278
Browse files- backend/app/rate_limit.py +3 -0
- backend/app/routes/chat.py +3 -3
- backend/tests/conftest.py +4 -1
- backend/tests/test_rate_limit.py +33 -0
backend/app/rate_limit.py
CHANGED
|
@@ -6,6 +6,9 @@ from slowapi import Limiter
|
|
| 6 |
from slowapi.util import get_remote_address
|
| 7 |
|
| 8 |
|
|
|
|
|
|
|
|
|
|
| 9 |
def rate_limit_key_func(request: Request) -> str:
|
| 10 |
"""Use authenticated user id when available, otherwise fall back to client IP."""
|
| 11 |
authorization = request.headers.get("authorization", "")
|
|
|
|
| 6 |
from slowapi.util import get_remote_address
|
| 7 |
|
| 8 |
|
| 9 |
+
CHAT_QUERY_RATE_LIMIT = "15/minute"
|
| 10 |
+
|
| 11 |
+
|
| 12 |
def rate_limit_key_func(request: Request) -> str:
|
| 13 |
"""Use authenticated user id when available, otherwise fall back to client IP."""
|
| 14 |
authorization = request.headers.get("authorization", "")
|
backend/app/routes/chat.py
CHANGED
|
@@ -17,7 +17,7 @@ from app.auth import get_current_user
|
|
| 17 |
from app.database import get_db
|
| 18 |
from app.metrics import record_query_response_time
|
| 19 |
from app.models import User, ChatMessage, Document, SharedMessage, ChatSession
|
| 20 |
-
from app.rate_limit import limiter
|
| 21 |
from app.schemas import (
|
| 22 |
ChatRequest,
|
| 23 |
ChatResponse,
|
|
@@ -221,7 +221,7 @@ def generate_answer_stream(question: str, user_id: str, document_id: Optional[st
|
|
| 221 |
|
| 222 |
|
| 223 |
@router.post("/ask", response_model=ChatResponse)
|
| 224 |
-
@limiter.limit(
|
| 225 |
def ask_question(
|
| 226 |
request: Request,
|
| 227 |
payload: ChatRequest,
|
|
@@ -283,7 +283,7 @@ def ask_question(
|
|
| 283 |
|
| 284 |
|
| 285 |
@router.post("/ask/stream")
|
| 286 |
-
@limiter.limit(
|
| 287 |
def ask_question_stream(
|
| 288 |
request: Request,
|
| 289 |
payload: ChatRequest,
|
|
|
|
| 17 |
from app.database import get_db
|
| 18 |
from app.metrics import record_query_response_time
|
| 19 |
from app.models import User, ChatMessage, Document, SharedMessage, ChatSession
|
| 20 |
+
from app.rate_limit import CHAT_QUERY_RATE_LIMIT, limiter
|
| 21 |
from app.schemas import (
|
| 22 |
ChatRequest,
|
| 23 |
ChatResponse,
|
|
|
|
| 221 |
|
| 222 |
|
| 223 |
@router.post("/ask", response_model=ChatResponse)
|
| 224 |
+
@limiter.limit(CHAT_QUERY_RATE_LIMIT)
|
| 225 |
def ask_question(
|
| 226 |
request: Request,
|
| 227 |
payload: ChatRequest,
|
|
|
|
| 283 |
|
| 284 |
|
| 285 |
@router.post("/ask/stream")
|
| 286 |
+
@limiter.limit(CHAT_QUERY_RATE_LIMIT)
|
| 287 |
def ask_question_stream(
|
| 288 |
request: Request,
|
| 289 |
payload: ChatRequest,
|
backend/tests/conftest.py
CHANGED
|
@@ -65,8 +65,11 @@ class Limiter:
|
|
| 65 |
def __init__(self, key_func=None, *args, **kwargs):
|
| 66 |
self.key_func = key_func
|
| 67 |
|
| 68 |
-
def limit(self,
|
| 69 |
def decorator(fn):
|
|
|
|
|
|
|
|
|
|
| 70 |
return fn
|
| 71 |
return decorator
|
| 72 |
|
|
|
|
| 65 |
def __init__(self, key_func=None, *args, **kwargs):
|
| 66 |
self.key_func = key_func
|
| 67 |
|
| 68 |
+
def limit(self, value):
|
| 69 |
def decorator(fn):
|
| 70 |
+
limits = list(getattr(fn, "__rate_limits__", []))
|
| 71 |
+
limits.append(value)
|
| 72 |
+
setattr(fn, "__rate_limits__", limits)
|
| 73 |
return fn
|
| 74 |
return decorator
|
| 75 |
|
backend/tests/test_rate_limit.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from types import SimpleNamespace
|
| 2 |
+
|
| 3 |
+
from app.auth import create_access_token
|
| 4 |
+
from app.rate_limit import CHAT_QUERY_RATE_LIMIT, rate_limit_key_func
|
| 5 |
+
from app.routes.chat import ask_question, ask_question_stream
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DummyRequest:
|
| 9 |
+
def __init__(self, headers=None):
|
| 10 |
+
self.headers = headers or {}
|
| 11 |
+
self.client = SimpleNamespace(host="203.0.113.10")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def test_rate_limit_key_prefers_authenticated_user_id():
|
| 15 |
+
token = create_access_token("user-123")
|
| 16 |
+
|
| 17 |
+
key = rate_limit_key_func(
|
| 18 |
+
DummyRequest(headers={"authorization": f"Bearer {token}"})
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
assert key == "user:user-123"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def test_rate_limit_key_falls_back_to_client_ip():
|
| 25 |
+
key = rate_limit_key_func(DummyRequest())
|
| 26 |
+
|
| 27 |
+
assert key.startswith("ip:")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def test_chat_endpoints_use_required_rate_limit():
|
| 31 |
+
assert CHAT_QUERY_RATE_LIMIT == "15/minute"
|
| 32 |
+
assert ask_question.__rate_limits__ == [CHAT_QUERY_RATE_LIMIT]
|
| 33 |
+
assert ask_question_stream.__rate_limits__ == [CHAT_QUERY_RATE_LIMIT]
|