Spaces:
Sleeping
Sleeping
import json
import os
import shutil
import sqlite3
from contextlib import asynccontextmanager, closing

from fastapi import FastAPI, UploadFile, File, HTTPException
from langchain_core.messages import HumanMessage, AIMessage
from pydantic import BaseModel, Field

from agent import run_rag_agent
from config import DOCS_DIR, TOP_K, MAX_HISTORY_TURNS, SQLITE_PATH
from ingestion import run_ingestion
from retriever import load_indexes, reload_indexes, hybrid_retrieve, indexes_loaded as _indexes_loaded
# ── SQLite session memory ─────────────────────────────
def _init_db():
    """Create the ``sessions`` table in SQLITE_PATH if it does not exist yet.

    Each row maps a ``session_id`` (primary key) to that session's chat
    history, stored as a JSON string that defaults to an empty list.
    """
    # closing() guarantees the connection is closed even if execute() raises;
    # the connection's own context manager only commits/rolls back.
    with closing(sqlite3.connect(SQLITE_PATH)) as con, con:
        con.execute(
            """
            CREATE TABLE IF NOT EXISTS sessions (
                session_id TEXT PRIMARY KEY,
                history TEXT NOT NULL DEFAULT '[]'
            )
            """
        )
def _load_history(session_id: str) -> list:
    """Load a session's chat history as LangChain message objects.

    Args:
        session_id: Key of the session row to read.

    Returns:
        A list of ``HumanMessage`` / ``AIMessage`` objects in stored order,
        or an empty list when the session has no row.
    """
    # closing() releases the connection even if the query raises.
    with closing(sqlite3.connect(SQLITE_PATH)) as con:
        row = con.execute(
            "SELECT history FROM sessions WHERE session_id=?", (session_id,)
        ).fetchone()
    if not row:
        return []
    # Each stored entry is {"role": ..., "content": ...}; role "human" maps to
    # HumanMessage and any other role is rebuilt as AIMessage.
    return [
        HumanMessage(content=m["content"])
        if m["role"] == "human"
        else AIMessage(content=m["content"])
        for m in json.loads(row[0])
    ]
def _save_history(session_id: str, history: list):
    """Persist a session's chat history, replacing any existing row.

    Args:
        session_id: Key of the session row to write.
        history: Message objects exposing ``.content``. ``HumanMessage``
            instances are stored with role "human"; every other message
            type is stored with role "ai".
    """
    payload = json.dumps(
        [
            {
                "role": "human" if isinstance(m, HumanMessage) else "ai",
                "content": m.content,
            }
            for m in history
        ]
    )
    # closing() releases the connection even on failure; the inner `con`
    # context commits on success and rolls back on error.
    with closing(sqlite3.connect(SQLITE_PATH)) as con, con:
        con.execute(
            "INSERT OR REPLACE INTO sessions (session_id, history) VALUES (?,?)",
            (session_id, payload),
        )
# ── app lifecycle ─────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Prepare storage and retrieval indexes before the app serves requests.

    On startup this creates the SQLite sessions table and loads the indexes.
    If no indexes end up loaded but DOCS_DIR contains top-level ``.txt`` or
    ``.pdf`` files, it runs ingestion and reloads the indexes. An ingestion
    failure is printed and startup continues. Nothing runs on shutdown.

    NOTE(review): unlike query(), load_indexes() is not guarded here, so an
    exception from it aborts startup — confirm that is intended.
    """
    # FastAPI's `lifespan` expects an async context manager factory; the
    # decorator (imported at file top) was missing, leaving a bare async
    # generator that Starlette only accepts via a deprecated fallback.
    _init_db()
    load_indexes()
    if not _indexes_loaded():
        from pathlib import Path

        docs_path = Path(DOCS_DIR)
        has_docs = any(docs_path.glob("*.txt")) or any(docs_path.glob("*.pdf"))
        if has_docs:
            print("Cold start: ChromaDB empty, re-indexing docs folder...")
            try:
                run_ingestion()
                reload_indexes()
                print("Cold start ingestion complete.")
            except Exception as e:
                print(f"Cold start ingestion failed: {e}")
        else:
            print("WARNING: No indexes and no docs found. Upload documents first.")
    yield
# Application instance; startup work is delegated to `lifespan`.
app = FastAPI(
    title="Corrective RAG API",
    version="1.0",
    lifespan=lifespan,
)
# ── models ────────────────────────────────────────────
class QueryRequest(BaseModel):
    """Client request body for asking a question."""

    # The user's question; passed to retrieval and to the agent.
    question: str
    # Key for the SQLite-backed conversation history. Clients that omit it
    # all share the "default" session.
    session_id: str = "default"
    # Number of chunks to retrieve. This value comes from untrusted client
    # input, so require a positive count instead of forwarding 0 or a
    # negative number to hybrid_retrieve.
    top_k: int = Field(TOP_K, ge=1)
class QueryResponse(BaseModel):
    """Response body returned by query()."""

    # Final answer text returned by run_rag_agent.
    answer: str
    # query() fills this with {"chunk": <first 300 chars>, "source": ...}
    # dicts, one per retrieved result; the model itself accepts any list.
    sources: list
    # Retry count returned by run_rag_agent.
    retries_used: int
    # Validation verdict returned by run_rag_agent.
    validation: str
    # Echo of the request's session_id.
    session_id: str
# ── routes ────────────────────────────────────────────
def home():
    """Return a static liveness message.

    NOTE(review): no route decorator appears in this copy of the file;
    presumably this is registered as ``GET /`` — confirm.
    """
    # The literal was mis-encoded (UTF-8 rocket emoji decoded as ISO-8859-7
    # "π"); restore the intended character.
    return {"message": "RAG API running 🚀"}
async def query(req: QueryRequest):
    """Answer ``req.question`` using hybrid retrieval and the RAG agent.

    The session's history is loaded from SQLite and passed to the agent.
    After a successful answer, the new question/answer pair is appended and
    the history is saved back, trimmed to the last MAX_HISTORY_TURNS pairs.

    NOTE(review): no route decorator appears in this copy of the file;
    presumably this is registered as a POST endpoint — confirm.
    NOTE(review): hybrid_retrieve, run_rag_agent and the SQLite helpers are
    called without ``await``, so they run on the event-loop thread. If they
    block (e.g. network-bound LLM calls), consider a plain ``def`` handler
    so FastAPI runs it in its threadpool.

    Raises:
        HTTPException: 503 if indexes are still not loaded, 404 if retrieval
            returns nothing, 429 if the agent error text looks like a rate
            limit, 500 for any other agent error.
    """
    if not _indexes_loaded():
        # Indexes may have been built since startup; try loading them lazily.
        try:
            load_indexes()
        except Exception as e:
            # Best-effort: report it and fall through to the 503 below.
            print(f"Lazy index load failed: {e}")
        if not _indexes_loaded():
            raise HTTPException(503, detail="Indexes not ready. Upload documents first.")

    results = hybrid_retrieve(req.question, top_k=req.top_k)
    if not results:
        raise HTTPException(404, detail="No relevant chunks found.")

    history = _load_history(req.session_id)
    try:
        answer, retries, verdict = run_rag_agent(req.question, results, history)
    except Exception as e:
        message = str(e)
        lowered = message.lower()
        # Rate limits are detected heuristically from the error text.
        if "429" in message or "rate_limit" in lowered or "rate limit" in lowered:
            raise HTTPException(
                status_code=429,
                detail="Rate limit reached. Please wait 30 seconds and try again.",
            ) from e
        raise HTTPException(status_code=500, detail=f"Agent error: {message}") from e

    history.append(HumanMessage(content=req.question))
    history.append(AIMessage(content=answer))
    # Each turn is two messages. history[-0:] would keep the *whole* list,
    # so a non-positive limit must explicitly mean "keep nothing".
    keep = MAX_HISTORY_TURNS * 2
    _save_history(req.session_id, history[-keep:] if keep > 0 else [])

    return QueryResponse(
        answer=answer,
        sources=[{"chunk": r["chunk"][:300], "source": r["source"]} for r in results],
        retries_used=retries,
        validation=verdict,
        session_id=req.session_id,
    )
async def upload(file: UploadFile = File(...)):
    """Save an uploaded ``.txt``/``.pdf`` into DOCS_DIR and re-index.

    Returns:
        ``{"status": "uploaded", "filename": <name saved under DOCS_DIR>}``.

    Raises:
        HTTPException: 400 if the file extension is not ``.txt`` or ``.pdf``.

    NOTE(review): no route decorator appears in this copy of the file;
    presumably this is registered as a POST endpoint — confirm.
    NOTE(review): _reindex() runs synchronously inside this async handler,
    so the event loop is blocked for the whole ingestion.
    """
    allowed = {".txt", ".pdf"}
    # The filename is client-controlled: strip any directory components
    # (either separator) so names like "../../app.txt" cannot escape DOCS_DIR.
    filename = os.path.basename((file.filename or "").replace("\\", "/"))
    ext = os.path.splitext(filename)[1].lower()
    if ext not in allowed:
        raise HTTPException(400, detail="Only .txt and .pdf allowed.")
    os.makedirs(DOCS_DIR, exist_ok=True)
    dest = os.path.join(DOCS_DIR, filename)
    with open(dest, "wb") as f:
        shutil.copyfileobj(file.file, f)
    _reindex()
    return {"status": "uploaded", "filename": filename}
def _reindex():
    """Rebuild the indexes from the docs folder and hot-reload them.

    Best-effort: any exception is printed with its traceback and swallowed,
    so callers such as upload() still return normally.
    """
    import traceback

    try:
        run_ingestion()
        reload_indexes()
        loaded = _indexes_loaded()
        print(f"Re-indexing complete. Loaded: {loaded}")
    except Exception as err:
        print(f"Re-indexing failed: {err}")
        traceback.print_exc()
def clear_session(session_id: str):
    """Delete the stored chat history for ``session_id``.

    Succeeds even if the session did not exist.

    NOTE(review): no route decorator appears in this copy of the file —
    confirm how this handler is registered.
    """
    # closing() guarantees the connection is released even if DELETE fails;
    # the inner `con` context commits the deletion on success.
    with closing(sqlite3.connect(SQLITE_PATH)) as con, con:
        con.execute("DELETE FROM sessions WHERE session_id=?", (session_id,))
    return {"status": "cleared", "session_id": session_id}
def health():
    """Report service liveness and whether the retrieval indexes are loaded.

    NOTE(review): no route decorator appears in this copy of the file —
    confirm how this handler is registered.
    """
    loaded = _indexes_loaded()
    return dict(status="ok", indexes_loaded=loaded)
def get_eval():
    """Return the evaluation scores parsed from ``eval_results.json``.

    The path is relative to the process's current working directory.

    Raises:
        HTTPException: 404 if the results file does not exist yet.
    """
    # EAFP instead of exists()-then-open() avoids a check/use race. An
    # explicit encoding avoids depending on the platform default.
    try:
        with open("eval_results.json", "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        raise HTTPException(
            status_code=404, detail="Run evaluate.py first to generate scores."
        ) from None
if __name__ == "__main__":
    # Local dev entry point; the port comes from $PORT, defaulting to 7860.
    import uvicorn

    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)