# MohitGupta41
# FastAPI RAG backend (Docker)
# f7c12a3
# app.py
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Dict, Any, List
from rag import (
extract_text_from_pdf,
chunk_text,
create_session,
retrieve_top_k,
generate_answer,
SESSIONS,
)
# FastAPI application object; the route handlers in this module register on it.
app = FastAPI(title="Mini RAG Backend")

# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive combination -- confirm whether credentialed cross-origin
# requests are actually needed, and pin explicit origins if they are.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
class AskRequest(BaseModel):
    """Request body accepted by the ``/ask`` endpoint."""

    # Key used by the /ask handler to look the session up in SESSIONS.
    session_id: str
    # Question text forwarded to retrieval and answer generation.
    question: str
    # How many chunks to retrieve for the answer context (defaults to 3).
    top_k: int = 3
@app.get("/")
def root():
    """Return a fixed status payload that also points at the docs path."""
    status_payload = {"status": "ok", "service": "rag-backend", "docs": "/docs"}
    return status_payload
def _extract_upload_text(display_name, content: bytes) -> str:
    """Return the stripped text of one uploaded PDF or TXT file.

    Args:
        display_name: The upload's filename as sent by the client (may be
            None); used for the extension check and in error messages.
        content: Raw bytes of the uploaded file.

    Returns:
        The extracted text with surrounding whitespace removed.

    Raises:
        HTTPException: 400 if the extension is neither ``.pdf`` nor ``.txt``,
            or if no text remains after stripping.
    """
    # Extension matching is case-insensitive; a missing filename matches nothing.
    lowered = (display_name or "").lower()
    if lowered.endswith(".pdf"):
        text = extract_text_from_pdf(content)
    elif lowered.endswith(".txt"):
        # "utf-8-sig" decodes exactly like "utf-8" but also drops a leading
        # byte-order mark. str.strip() does not remove U+FEFF, so with plain
        # "utf-8" the BOM would end up inside the first chunk.
        text = content.decode("utf-8-sig", errors="ignore")
    else:
        raise HTTPException(status_code=400, detail=f"Only PDF or TXT allowed: {display_name}")

    text = text.strip()
    if not text:
        raise HTTPException(status_code=400, detail=f"No extractable text found in: {display_name}")
    return text


@app.post("/ingest")
async def ingest(files: List[UploadFile] = File(...)) -> Dict[str, Any]:
    """
    Accept multiple PDF/TXT files and register all of their chunks as ONE session.

    Each file is read, converted to text, split with ``chunk_text`` (350-word
    chunks, 60-word overlap) and appended to a single combined chunk list,
    which is then handed to ``create_session`` together with per-file metadata.

    Args:
        files: The uploaded files; only ``.pdf`` and ``.txt`` are accepted.

    Returns:
        A dict with the new ``session_id``, the number of files and chunks,
        and ``docs``: one entry per file holding its ``doc_id``, ``filename``,
        ``num_chunks`` and inclusive ``chunk_range`` into the combined list.

    Raises:
        HTTPException: 400 if no files were uploaded, a file has an
            unsupported extension, yields no text, or produces no chunks.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files uploaded")

    all_chunks: List[str] = []
    docs_meta: List[Dict[str, Any]] = []

    for doc_id, file in enumerate(files):
        content = await file.read()
        text = _extract_upload_text(file.filename, content)

        chunks = chunk_text(text, chunk_size_words=350, overlap_words=60)
        if not chunks:
            raise HTTPException(status_code=400, detail=f"Chunking produced 0 chunks for: {file.filename}")

        # Prefix each chunk with its filename so the originating file is
        # visible in the text returned as a source.
        # NOTE(review): the prefix is part of the chunk text passed to
        # create_session, so it may also influence indexing -- confirm.
        chunks = [f"[SOURCE: {file.filename}]\n{c}" for c in chunks]

        # Record where this file's chunks sit in the combined list
        # (both ends of the range are inclusive).
        start_idx = len(all_chunks)
        all_chunks.extend(chunks)
        end_idx = len(all_chunks) - 1

        docs_meta.append(
            {
                "doc_id": doc_id,
                "filename": file.filename,
                "num_chunks": len(chunks),
                "chunk_range": [start_idx, end_idx],
            }
        )

    session_id = create_session(all_chunks, docs_meta)
    return {
        "session_id": session_id,
        "num_files": len(files),
        "num_chunks": len(all_chunks),
        "docs": docs_meta,
    }
@app.post("/ask")
async def ask(req: AskRequest) -> Dict[str, Any]:
    """
    Answer a question using the chunks stored for an existing session.

    Looks the session up in ``SESSIONS``, retrieves the best-matching chunks
    with ``retrieve_top_k``, joins their text into one context string and
    passes it to ``generate_answer``.

    Args:
        req: Request body carrying ``session_id``, ``question`` and ``top_k``.

    Returns:
        A dict with the generated ``answer``, the retrieved ``sources``
        (``chunk_id``, ``score`` and full ``text`` per hit) and the session's
        ``docs`` metadata (empty list if the session has none).

    Raises:
        HTTPException: 404 if the session is unknown; 400 if ``top_k`` is
            below 1 or the question is blank.
    """
    sess = SESSIONS.get(req.session_id)
    if not sess:
        raise HTTPException(status_code=404, detail="Invalid session_id")

    # Reject inputs that cannot produce a meaningful retrieval instead of
    # passing them straight through to the retriever and generator.
    if req.top_k < 1:
        raise HTTPException(status_code=400, detail="top_k must be at least 1")
    if not req.question.strip():
        raise HTTPException(status_code=400, detail="question must not be empty")

    chunks = sess["chunks"]
    index = sess["index"]

    # Never request more hits than there are chunks. retrieve_top_k is defined
    # elsewhere, so this avoids depending on how it treats an oversized k.
    k = min(req.top_k, len(chunks))
    hits = retrieve_top_k(req.question, chunks, index, k=k)

    # Each hit is indexed as (chunk_id, score, text); the texts form the context.
    context = "\n\n---\n\n".join(h[2] for h in hits)
    answer = generate_answer(req.question, context)

    # Return full chunk text (no truncation)
    sources = [
        {"chunk_id": h[0], "score": h[1], "text": h[2]}
        for h in hits
    ]
    return {
        "answer": answer,
        "sources": sources,
        "docs": sess.get("docs", []),
    }
@app.delete("/session/{session_id}")
async def delete_session(session_id: str):
    """Remove the session from SESSIONS if it exists; always report success."""
    # Unknown ids are ignored, so deleting twice is harmless.
    SESSIONS.pop(session_id, None)
    return {"status": "ok"}