# rag-visualizer / api/chat.py
# Ahmed Sadik
# feat: implement rate limiting for chat and document API endpoints
# 13f26b4
import json
import logging

from fastapi import APIRouter, Form, Request
from fastapi.params import Depends
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from db.postgres import get_pg_pool, update_document_activity
from db.qdrant import get_qdrant, get_relevant_chunks
from services.embedding import embed_question
from services.generation import craft_prompt, is_grounded_response, request_llm_response, parse_llm_json, relevant_chunks_to_json
from services.prompt import FALLBACK_RESPONSE
from services.rate_limiter import limiter
from services.resource_manager import get_llm_client
class ChatRequest(BaseModel):
    """Form payload for the chat ``/ask`` endpoint."""

    # The user's question in plain text; it is embedded for retrieval and
    # also inserted into the LLM prompt.
    question: str
    # Identifier of the target document; presumably the document's UUID
    # (used for activity tracking and to scope chunk retrieval).
    uuid: str
# Router for the chat endpoints; every path registered on it is served under "/chat".
route = APIRouter(tags=["Chat"], prefix="/chat")
# Relevance cut-off applied to the highest value in ``distances`` returned by
# get_relevant_chunks. Higher values count as more relevant, so if even the best
# match is below this, the question is treated as unrelated to the document and
# the canned fallback answer is returned without calling the LLM.
# Chosen from empirical testing.
RELEVANCE_THRESHOLD = 0.52

_logger = logging.getLogger(__name__)


def _early_exit_response(*, ok, error, data, status_code):
    """Build the JSON response for paths that return without supporting chunks.

    The body always carries ``retrieved_chunks`` (the key the success path
    uses), so clients see one consistent shape. ``relevant_chunks`` is kept as
    a deprecated alias because these paths previously emitted only that key.
    """
    return JSONResponse(
        content={
            "ok": ok,
            "error": error,
            "data": data,
            "retrieved_chunks": [],
            "relevant_chunks": [],  # deprecated alias; always empty here
        },
        status_code=status_code,
    )


@route.post("/ask")
@limiter.limit("5/minute")
async def ask_question(request: Request, body: ChatRequest = Form(...), pg_pool = Depends(get_pg_pool), qdrant = Depends(get_qdrant)):
    """Answer a question about one document using retrieved chunks and an LLM.

    Steps: embed the question, record activity on the document, fetch relevant
    chunks, then ask the LLM for a JSON answer built from those chunks.
    ``request`` is not used directly; it is presumably required by the
    ``limiter`` decorator (5 requests/minute).

    Every response body contains ``ok``, ``error``, ``data`` and
    ``retrieved_chunks``:

    * 500 if embedding, the activity update or chunk retrieval raises, or the
      LLM call fails;
    * 404 if no chunks are returned for the document;
    * 200 with the fallback answer and no chunks if the best match is below
      ``RELEVANCE_THRESHOLD`` or the LLM answer is not judged grounded;
    * 200 with the parsed answer and its supporting chunks otherwise.
    """
    try:
        # NOTE(review): embed_question is called synchronously inside an async
        # handler; if it is CPU-heavy it blocks the event loop. Consider
        # fastapi.concurrency.run_in_threadpool. TODO confirm its cost.
        question_vector = embed_question(body.question)
        await update_document_activity(body.uuid, pg_pool)
        chunks, distances = await get_relevant_chunks(question_vector, body.uuid, qdrant)
    except Exception:
        _logger.exception("Error retrieving chunks for document %s", body.uuid)
        return _early_exit_response(
            ok=False,
            error="Error retrieving relevant document chunks.",
            data=None,
            status_code=500,
        )

    if not chunks:
        return _early_exit_response(ok=False, error="No chunks found.", data=None, status_code=404)

    if distances and max(distances) < RELEVANCE_THRESHOLD:
        # Nothing in the document is close enough to the question: skip the LLM.
        return _early_exit_response(
            ok=True,
            error=None,
            data=json.loads(FALLBACK_RESPONSE),
            status_code=200,
        )

    prompt = craft_prompt(body.question, chunks)
    try:
        response = await request_llm_response(get_llm_client(), prompt)
    except Exception:
        # Convert unexpected client/network failures into the same structured
        # 500 the endpoint returns for a reported LLM failure.
        _logger.exception("LLM request failed for document %s", body.uuid)
        response = {"ok": False}

    if not response.get("ok"):
        return _early_exit_response(
            ok=False,
            error=response.get("error", "LLM generation failed."),
            data=None,
            status_code=500,
        )

    parsed_data = parse_llm_json(response["data"])
    if is_grounded_response(parsed_data):
        # Grounded answer: return the chunks it was built from, for transparency.
        retrieved_chunks = relevant_chunks_to_json(chunks, distances)
    else:
        # Ungrounded, malicious or unparseable output: replace it with the
        # fallback and withhold the chunks.
        parsed_data = json.loads(FALLBACK_RESPONSE)
        retrieved_chunks = []

    return JSONResponse(
        content={
            "ok": True,
            "error": None,
            "data": parsed_data,
            "retrieved_chunks": retrieved_chunks,
        }
    )