# app.py
"""HTTP API for a small retrieval-augmented generation (RAG) backend.

Endpoints:
    GET    /                      service status
    POST   /ingest                upload PDF/TXT files and build one session
    POST   /ask                   ask a question against an existing session
    DELETE /session/{session_id}  drop a session
"""

from typing import Any, Dict, List

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from rag import (
    SESSIONS,
    chunk_text,
    create_session,
    extract_text_from_pdf,
    generate_answer,
    retrieve_top_k,
)

# Chunking parameters handed to ``chunk_text`` for every ingested document.
CHUNK_SIZE_WORDS = 350
CHUNK_OVERLAP_WORDS = 60

# Lower-cased filename suffixes that ``/ingest`` accepts.
_SUPPORTED_EXTENSIONS = (".pdf", ".txt")

app = FastAPI(title="Mini RAG Backend")

# NOTE(review): wildcard origins combined with ``allow_credentials=True`` is
# discouraged by the FastAPI/Starlette CORS documentation. Left unchanged to
# avoid breaking an existing frontend; confirm whether credentials are needed
# and, if so, list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class AskRequest(BaseModel):
    """Request body for ``POST /ask``."""

    # Identifier returned by ``POST /ingest``.
    session_id: str
    # Natural-language question to answer.
    question: str
    # Number of chunks to retrieve; must be at least 1.
    top_k: int = Field(3, ge=1)


@app.get("/")
def root() -> Dict[str, str]:
    """Return a static status payload pointing at the interactive docs."""
    return {"status": "ok", "service": "rag-backend", "docs": "/docs"}


@app.post("/ingest")
async def ingest(files: List[UploadFile] = File(...)) -> Dict[str, Any]:
    """Accept multiple PDF/TXT files and build ONE combined RAG session.

    Each file is converted to text, split into chunks, and every chunk is
    prefixed with ``[SOURCE: <filename>]``. The chunks of all files are
    concatenated in upload order and passed to ``create_session`` together
    with per-document metadata.

    Returns:
        A dict with the new ``session_id``, the number of files and chunks,
        and ``docs``: one entry per file holding ``doc_id``, ``filename``,
        ``num_chunks`` and an inclusive ``chunk_range`` ``[start, end]`` into
        the combined chunk list.

    Raises:
        HTTPException: 400 when no files are uploaded, a file is not
            ``.pdf``/``.txt``, a file yields no text, or chunking yields
            no chunks.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files uploaded")

    all_chunks: List[str] = []
    docs_meta: List[Dict[str, Any]] = []

    for doc_id, file in enumerate(files):
        filename = (file.filename or "").lower()

        # Reject unsupported types before reading the upload into memory.
        if not filename.endswith(_SUPPORTED_EXTENSIONS):
            raise HTTPException(
                status_code=400,
                detail=f"Only PDF or TXT allowed: {file.filename}",
            )

        content = await file.read()

        # NOTE(review): the rag helpers below are called synchronously inside
        # an async handler; if they are slow they will hold up the event
        # loop. Confirm whether they should be moved to a worker thread.
        if filename.endswith(".pdf"):
            text = extract_text_from_pdf(content)
        else:
            # Undecodable bytes are dropped rather than rejected.
            text = content.decode("utf-8", errors="ignore")

        text = text.strip()
        if not text:
            raise HTTPException(
                status_code=400,
                detail=f"No extractable text found in: {file.filename}",
            )

        chunks = chunk_text(
            text,
            chunk_size_words=CHUNK_SIZE_WORDS,
            overlap_words=CHUNK_OVERLAP_WORDS,
        )
        if not chunks:
            raise HTTPException(
                status_code=400,
                detail=f"Chunking produced 0 chunks for: {file.filename}",
            )

        # Prefix each chunk with its source filename so the generated answer
        # can cite which file a passage came from. The prefixed text is what
        # gets stored in the session.
        chunks = [f"[SOURCE: {file.filename}]\n{c}" for c in chunks]

        start_idx = len(all_chunks)
        all_chunks.extend(chunks)
        end_idx = len(all_chunks) - 1  # inclusive upper bound

        docs_meta.append(
            {
                "doc_id": doc_id,
                "filename": file.filename,
                "num_chunks": len(chunks),
                "chunk_range": [start_idx, end_idx],
            }
        )

    session_id = create_session(all_chunks, docs_meta)

    return {
        "session_id": session_id,
        "num_files": len(files),
        "num_chunks": len(all_chunks),
        "docs": docs_meta,
    }


@app.post("/ask")
async def ask(req: AskRequest) -> Dict[str, Any]:
    """Answer a question using the chunks stored for ``req.session_id``.

    Retrieves the top ``req.top_k`` hits for the question, joins their text
    into a single context string, and passes it to ``generate_answer``.

    Returns:
        A dict with the generated ``answer``, the ``sources`` used (each with
        ``chunk_id``, ``score`` and the full, untruncated chunk ``text``),
        and the session's ``docs`` metadata (empty list if absent).

    Raises:
        HTTPException: 404 when the session id is unknown.
    """
    sess = SESSIONS.get(req.session_id)
    if not sess:
        raise HTTPException(status_code=404, detail="Invalid session_id")

    hits = retrieve_top_k(
        req.question, sess["chunks"], sess["index"], k=req.top_k
    )

    # Each hit is indexed as (chunk_id, score, text).
    context = "\n\n---\n\n".join(h[2] for h in hits)

    # NOTE(review): synchronous call inside an async handler; if this blocks
    # (e.g. on a model or network call) it will stall other requests.
    answer = generate_answer(req.question, context)

    # Return full chunk text (no truncation).
    sources = [
        {"chunk_id": h[0], "score": h[1], "text": h[2]}
        for h in hits
    ]

    return {
        "answer": answer,
        "sources": sources,
        "docs": sess.get("docs", []),
    }


@app.delete("/session/{session_id}")
async def delete_session(session_id: str) -> Dict[str, str]:
    """Remove a session if it exists; succeed either way."""
    # Single-step removal; unknown ids are ignored on purpose.
    SESSIONS.pop(session_id, None)
    return {"status": "ok"}