Spaces:
Sleeping
Sleeping
| # app.py | |
from typing import Any, Dict, List

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from rag import (
    SESSIONS,
    chunk_text,
    create_session,
    extract_text_from_pdf,
    generate_answer,
    retrieve_top_k,
)
app = FastAPI(title="Mini RAG Backend")

# Fully permissive CORS: every origin, method and request header is accepted.
# NOTE(review): wildcard origins are combined with allow_credentials=True, yet
# nothing visible in this module uses cookies or auth headers — confirm that
# credentials are really needed; an explicit origin list would be safer.
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
class AskRequest(BaseModel):
    """JSON body for asking a question against a previously ingested session."""

    # Key into SESSIONS, as returned in the "session_id" field of ingest's response.
    session_id: str
    # The question to answer from the session's indexed chunks.
    question: str
    # How many chunks to retrieve as context. Constrained to >= 1 so the
    # retriever is never asked for zero or a negative number of hits; invalid
    # values are rejected by request validation instead.
    top_k: int = Field(3, ge=1)
# NOTE(review): no route decorator precedes this handler here; unless it is
# registered elsewhere (e.g. app.add_api_route), FastAPI will not expose it.
def root():
    """Return a small status payload naming the service and pointing at /docs."""
    return dict(status="ok", service="rag-backend", docs="/docs")
# NOTE(review): no route decorator precedes this handler here; unless it is
# registered elsewhere (e.g. app.add_api_route), FastAPI will not expose it.
async def ingest(files: List[UploadFile] = File(...)) -> Dict[str, Any]:
    """
    Accept multiple PDF/TXT files and build ONE combined RAG index across all chunks.

    Each file is converted to text and split with ``chunk_text``
    (chunk_size_words=350, overlap_words=60). Every chunk is prefixed with
    ``[SOURCE: <filename>]``. Chunks from all files are concatenated in upload
    order and passed to ``create_session`` together with per-file metadata.

    Args:
        files: Uploaded files. Only names ending in ``.pdf`` or ``.txt``
            (case-insensitive) are accepted.

    Returns:
        A dict with ``session_id``, ``num_files``, ``num_chunks`` (total across
        all files) and ``docs``. ``docs`` has one entry per file, holding
        ``doc_id`` (upload position), ``filename``, ``num_chunks`` and
        ``chunk_range``: the inclusive ``[first, last]`` positions of that
        file's chunks in the combined list.

    Raises:
        HTTPException: 400 in any of these cases:
            - no files were sent;
            - a file has an unsupported extension;
            - a file yields no text after stripping;
            - chunking yields no chunks.
            The first failing file aborts the whole request, so no session is
            created.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files uploaded")

    all_chunks: List[str] = []
    docs_meta: List[Dict[str, Any]] = []

    # NOTE(review): extract_text_from_pdf, chunk_text and create_session are
    # synchronous calls inside this coroutine, so the event loop is blocked while
    # they run. If they are slow (PDF parsing / indexing), consider
    # fastapi.concurrency.run_in_threadpool, after confirming they are thread-safe.
    for doc_id, file in enumerate(files):
        lowered = (file.filename or "").lower()

        # Reject unsupported types before buffering the upload.
        if not lowered.endswith((".pdf", ".txt")):
            raise HTTPException(status_code=400, detail=f"Only PDF or TXT allowed: {file.filename}")

        content = await file.read()  # the whole upload is held in memory
        if lowered.endswith(".pdf"):
            text = extract_text_from_pdf(content)
        else:
            # "utf-8-sig" drops a leading byte-order mark. Plain "utf-8" would
            # keep it as U+FEFF, which str.strip() does not remove (it is not
            # whitespace), so it would end up at the start of the first chunk.
            text = content.decode("utf-8-sig", errors="ignore")

        text = text.strip()
        if not text:
            raise HTTPException(status_code=400, detail=f"No extractable text found in: {file.filename}")

        chunks = chunk_text(text, chunk_size_words=350, overlap_words=60)
        if not chunks:
            raise HTTPException(status_code=400, detail=f"Chunking produced 0 chunks for: {file.filename}")

        # Prefix each chunk with its filename so the LLM can cite which file a
        # passage came from. The prefixed text is what create_session receives,
        # so presumably it is also what gets indexed — confirm in rag.py.
        chunks = [f"[SOURCE: {file.filename}]\n{c}" for c in chunks]

        start_idx = len(all_chunks)
        all_chunks.extend(chunks)
        docs_meta.append(
            {
                "doc_id": doc_id,
                "filename": file.filename,
                "num_chunks": len(chunks),
                # Inclusive [first, last] positions within all_chunks.
                "chunk_range": [start_idx, len(all_chunks) - 1],
            }
        )

    session_id = create_session(all_chunks, docs_meta)
    return {
        "session_id": session_id,
        "num_files": len(files),
        "num_chunks": len(all_chunks),
        "docs": docs_meta,
    }
# NOTE(review): no route decorator precedes this handler here; unless it is
# registered elsewhere (e.g. app.add_api_route), FastAPI will not expose it.
async def ask(req: AskRequest) -> Dict[str, Any]:
    """
    Answer ``req.question`` from the chunks stored under ``req.session_id``.

    Retrieves ``req.top_k`` hits via ``retrieve_top_k``. Their texts are joined
    with ``---`` separators into one context string, which is passed to
    ``generate_answer``.

    Returns:
        A dict with:
            - ``answer``: the generated answer;
            - ``sources``: one ``{chunk_id, score, text}`` entry per hit, with
              the full, untruncated chunk text;
            - ``docs``: the session's per-file metadata, or ``[]`` if absent.

    Raises:
        HTTPException: 404 when no session is stored for ``req.session_id``.
    """
    session = SESSIONS.get(req.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Invalid session_id")

    # NOTE(review): retrieve_top_k and generate_answer are synchronous calls
    # inside this coroutine, so the event loop is blocked while they run.
    # If they are slow, consider run_in_threadpool (after checking thread-safety).
    # Each hit is indexed positionally as (chunk_id, score, text).
    hits = retrieve_top_k(req.question, session["chunks"], session["index"], k=req.top_k)

    separator = "\n\n---\n\n"
    answer = generate_answer(req.question, separator.join(hit[2] for hit in hits))

    sources = [dict(chunk_id=hit[0], score=hit[1], text=hit[2]) for hit in hits]

    return {
        "answer": answer,
        "sources": sources,
        "docs": session.get("docs", []),
    }
# NOTE(review): no route decorator precedes this handler here; unless it is
# registered elsewhere (e.g. app.add_api_route), FastAPI will not expose it.
async def delete_session(session_id: str):
    """Forget the session stored under ``session_id``.

    Unknown ids are silently ignored, so the response is always ``{"status": "ok"}``.
    """
    SESSIONS.pop(session_id, None)
    return {"status": "ok"}