Paramjit Singh commited on
Commit
d85fcca
Β·
unverified Β·
2 Parent(s): 87c191fbb497db

Merge pull request #150 from Jiya3177/feat/slowapi-rate-limit-124

Browse files
backend/app/main.py CHANGED
@@ -9,11 +9,15 @@ from contextlib import asynccontextmanager
9
  from fastapi import FastAPI
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from fastapi.staticfiles import StaticFiles
12
- from fastapi.responses import FileResponse
13
  from sqlalchemy import select
14
  from sqlalchemy.exc import SQLAlchemyError
15
 
 
 
 
16
  from app.config import get_settings
 
17
  from app.database import init_db, get_db
18
  from app.rag.vectorstore import get_chroma_client
19
 
@@ -63,6 +67,16 @@ app = FastAPI(
63
  lifespan=lifespan,
64
  )
65
 
 
 
 
 
 
 
 
 
 
 
66
  # ── CORS (allow frontend dev server) ─────────────────
67
  app.add_middleware(
68
  CORSMiddleware,
 
9
  from fastapi import FastAPI
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from fastapi.staticfiles import StaticFiles
12
+ from fastapi.responses import FileResponse, JSONResponse
13
  from sqlalchemy import select
14
  from sqlalchemy.exc import SQLAlchemyError
15
 
16
+ from slowapi.errors import RateLimitExceeded
17
+ from slowapi.middleware import SlowAPIMiddleware
18
+
19
  from app.config import get_settings
20
+ from app.rate_limit import limiter
21
  from app.database import init_db, get_db
22
  from app.rag.vectorstore import get_chroma_client
23
 
 
67
  lifespan=lifespan,
68
  )
69
 
70
+ app.state.limiter = limiter
71
+ app.add_exception_handler(
72
+ RateLimitExceeded,
73
+ lambda request, exc: JSONResponse(
74
+ status_code=429,
75
+ content={"detail": "Rate limit exceeded. Please try again later."},
76
+ ),
77
+ )
78
+ app.add_middleware(SlowAPIMiddleware)
79
+
80
  # ── CORS (allow frontend dev server) ─────────────────
81
  app.add_middleware(
82
  CORSMiddleware,
backend/app/rate_limit.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SlowAPI rate limiting configuration.
3
+ """
4
+ from fastapi import Request
5
+ from slowapi import Limiter
6
+ from slowapi.util import get_remote_address
7
+
8
+
9
+ def rate_limit_key_func(request: Request) -> str:
10
+ """Use authenticated user id when available, otherwise fall back to client IP."""
11
+ authorization = request.headers.get("authorization", "")
12
+ if authorization.lower().startswith("bearer "):
13
+ try:
14
+ from app.auth import decode_token
15
+
16
+ user_id = decode_token(authorization.split(" ", 1)[1])
17
+ if user_id:
18
+ return f"user:{user_id}"
19
+ except Exception:
20
+ pass
21
+ return f"ip:{get_remote_address(request)}"
22
+
23
+
24
+ limiter = Limiter(key_func=rate_limit_key_func)
backend/app/routes/chat.py CHANGED
@@ -8,7 +8,7 @@ from io import BytesIO
8
  import logging
9
  from typing import Optional
10
 
11
- from fastapi import APIRouter, Depends, HTTPException
12
  from fastapi.responses import Response, StreamingResponse
13
  from reportlab.lib.pagesizes import letter
14
  from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
@@ -21,6 +21,7 @@ from app.models import User, ChatMessage, Document
21
  from app.schemas import ChatRequest, ChatResponse, ChatMessageResponse, ChatHistoryResponse, SourceChunk
22
  from app.auth import get_current_user
23
  from app.rag.agent import generate_answer, generate_answer_stream
 
24
 
25
  logger = logging.getLogger(__name__)
26
 
@@ -28,7 +29,9 @@ router = APIRouter(prefix="/chat", tags=["Chat"])
28
 
29
 
30
  @router.post("/ask", response_model=ChatResponse)
 
31
  def ask_question(
 
32
  payload: ChatRequest,
33
  user: User = Depends(get_current_user),
34
  db: Session = Depends(get_db),
@@ -95,7 +98,9 @@ def ask_question(
95
 
96
 
97
  @router.post("/ask/stream")
 
98
  def ask_question_stream(
 
99
  payload: ChatRequest,
100
  user: User = Depends(get_current_user),
101
  db: Session = Depends(get_db),
 
8
  import logging
9
  from typing import Optional
10
 
11
+ from fastapi import APIRouter, Depends, HTTPException, Request
12
  from fastapi.responses import Response, StreamingResponse
13
  from reportlab.lib.pagesizes import letter
14
  from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
 
21
  from app.schemas import ChatRequest, ChatResponse, ChatMessageResponse, ChatHistoryResponse, SourceChunk
22
  from app.auth import get_current_user
23
  from app.rag.agent import generate_answer, generate_answer_stream
24
+ from app.rate_limit import limiter
25
 
26
  logger = logging.getLogger(__name__)
27
 
 
29
 
30
 
31
  @router.post("/ask", response_model=ChatResponse)
32
+ @limiter.limit("10/minute")
33
  def ask_question(
34
+ request: Request,
35
  payload: ChatRequest,
36
  user: User = Depends(get_current_user),
37
  db: Session = Depends(get_db),
 
98
 
99
 
100
  @router.post("/ask/stream")
101
+ @limiter.limit("10/minute")
102
  def ask_question_stream(
103
+ request: Request,
104
  payload: ChatRequest,
105
  user: User = Depends(get_current_user),
106
  db: Session = Depends(get_db),
backend/requirements.txt CHANGED
@@ -40,6 +40,7 @@ huggingface-hub
40
 
41
  # Production
42
  gunicorn
 
43
 
44
  # File Validation
45
  #sudo apt-get install libmagic1 // for Debian/Ubuntu
 
40
 
41
  # Production
42
  gunicorn
43
+ slowapi
44
 
45
  # File Validation
46
  #sudo apt-get install libmagic1 // for Debian/Ubuntu