File size: 6,179 Bytes
3c72c9d d4cf06c aec0e75 d4cf06c ea15168 d4cf06c 3c72c9d d4cf06c 3c72c9d d4cf06c 3c72c9d a977e38 d4cf06c 3c72c9d ea15168 d4cf06c 3c72c9d d4cf06c f097829 3c72c9d ea15168 3c72c9d d4cf06c 3c72c9d a0b0392 3c72c9d d4cf06c 3c72c9d d4cf06c aec0e75 d4cf06c 3c72c9d d4cf06c aec0e75 3c72c9d d4cf06c aec0e75 d4cf06c 3c72c9d d4cf06c 3e4a530 3c72c9d d4cf06c 3c72c9d d4cf06c d9066de a977e38 3d17099 d9066de 3c72c9d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 | import os, shutil, sqlite3, json
from contextlib import asynccontextmanager
from fastapi import FastAPI, UploadFile, File, HTTPException
from pydantic import BaseModel
from langchain_core.messages import HumanMessage, AIMessage
from retriever import load_indexes, reload_indexes, hybrid_retrieve, indexes_loaded as _indexes_loaded
from agent import run_rag_agent
from ingestion import run_ingestion
from config import DOCS_DIR, TOP_K, MAX_HISTORY_TURNS, SQLITE_PATH
# ── SQLite session memory ─────────────────────────────
def _init_db():
    """Create the ``sessions`` table in the SQLite database if it is missing.

    Connects to ``SQLITE_PATH``, issues ``CREATE TABLE IF NOT EXISTS`` for the
    table that stores one JSON-encoded history string per ``session_id``,
    commits, and closes the connection.
    """
    con = sqlite3.connect(SQLITE_PATH)
    try:
        con.execute("""
            CREATE TABLE IF NOT EXISTS sessions (
                session_id TEXT PRIMARY KEY,
                history TEXT NOT NULL DEFAULT '[]'
            )
        """)
        con.commit()
    finally:
        # Close even when execute/commit raises so the handle is never leaked.
        con.close()
def _load_history(session_id: str) -> list:
    """Return the stored conversation for *session_id* as message objects.

    Args:
        session_id: Primary key of the row in the ``sessions`` table.

    Returns:
        A list of ``HumanMessage`` / ``AIMessage`` objects rebuilt from the
        JSON stored in the ``history`` column, or an empty list when no row
        exists for the session.
    """
    con = sqlite3.connect(SQLITE_PATH)
    try:
        row = con.execute(
            "SELECT history FROM sessions WHERE session_id=?", (session_id,)
        ).fetchone()
    finally:
        # Release the connection even if the query itself fails.
        con.close()
    if not row:
        return []
    # Each stored entry is a {"role", "content"} dict; any role other than
    # "human" is reconstructed as an AI message.
    return [
        HumanMessage(content=m["content"])
        if m["role"] == "human"
        else AIMessage(content=m["content"])
        for m in json.loads(row[0])
    ]
def _save_history(session_id: str, history: list):
    """Persist *history* for *session_id*, replacing any previous row.

    Args:
        session_id: Primary key of the row to insert or replace.
        history: Message objects exposing ``.content``. Instances of
            ``HumanMessage`` are stored with role ``"human"``; everything
            else is stored with role ``"ai"``.
    """
    raw = [
        {
            "role": "human" if isinstance(m, HumanMessage) else "ai",
            "content": m.content,
        }
        for m in history
    ]
    con = sqlite3.connect(SQLITE_PATH)
    try:
        con.execute(
            "INSERT OR REPLACE INTO sessions (session_id, history) VALUES (?,?)",
            (session_id, json.dumps(raw)),
        )
        con.commit()
    finally:
        # Close even when the write fails so the handle is never leaked.
        con.close()
# ── app lifecycle ─────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Prepare the session database and retrieval indexes, then yield.

    Everything before the ``yield`` is start-up work: the SQLite table is
    created, the indexes are loaded, and if they are still not loaded a
    re-ingestion of the docs folder is attempted when it contains ``.txt`` or
    ``.pdf`` files. Ingestion failures are printed, not raised.
    """
    _init_db()
    load_indexes()

    if not _indexes_loaded():
        from pathlib import Path

        docs_dir = Path(DOCS_DIR)
        found_docs = any(
            any(docs_dir.glob(pattern)) for pattern in ("*.txt", "*.pdf")
        )
        if not found_docs:
            print("WARNING: No indexes and no docs found. Upload documents first.")
        else:
            print("Cold start: ChromaDB empty, re-indexing docs folder...")
            try:
                run_ingestion()
                reload_indexes()
                print("Cold start ingestion complete.")
            except Exception as err:
                # Best-effort: the app still starts without indexes.
                print(f"Cold start ingestion failed: {err}")

    yield
# Application instance; start-up work is delegated to the lifespan handler.
app = FastAPI(
    title="Corrective RAG API",
    version="1.0",
    lifespan=lifespan,
)
# ── models ────────────────────────────────────────────
class QueryRequest(BaseModel):
    """Request body accepted by the ``/query`` route."""

    # Question text passed to retrieval and to the agent.
    question: str
    # Key under which conversation history is loaded and saved.
    session_id: str = "default"
    # Forwarded to the retriever as ``top_k``; defaults to the configured TOP_K.
    top_k: int = TOP_K
class QueryResponse(BaseModel):
    """Response body returned by the ``/query`` route."""

    # Answer text produced by the agent.
    answer: str
    # One entry per retrieved result (chunk excerpt plus its source).
    sources: list
    # Retry count reported by the agent.
    retries_used: int
    # Verdict value reported by the agent.
    validation: str
    # Echo of the session the request was made under.
    session_id: str
# ── routes ────────────────────────────────────────────
@app.get("/")
def home():
    """Return a static banner confirming the API is reachable."""
    banner = {"message": "RAG API running π"}
    return banner
@app.post("/query", response_model=QueryResponse)
async def query(req: QueryRequest):
    """Answer a question using retrieved chunks and the session's history.

    Args:
        req: Question, session id and ``top_k`` for retrieval.

    Returns:
        A ``QueryResponse`` with the answer, a 300-character excerpt and the
        source of each retrieved result, the agent's retry count and verdict,
        and the session id.

    Raises:
        HTTPException: 503 when indexes are not loaded, 404 when retrieval
            returns nothing, 429 when the agent error looks like a rate
            limit, 500 for any other agent error.
    """
    if not _indexes_loaded():
        # Best-effort lazy load; a failure is reported as 503 just below.
        # Catch Exception (not a bare except) so SystemExit/KeyboardInterrupt
        # still propagate, and print the cause instead of hiding it.
        try:
            load_indexes()
        except Exception as e:
            print(f"Lazy index load failed: {e}")
    if not _indexes_loaded():
        raise HTTPException(503, detail="Indexes not ready. Upload documents first.")

    results = hybrid_retrieve(req.question, top_k=req.top_k)
    if not results:
        raise HTTPException(404, detail="No relevant chunks found.")

    history = _load_history(req.session_id)
    try:
        answer, retries, verdict = run_rag_agent(req.question, results, history)
    except Exception as e:
        message = str(e)
        lowered = message.lower()
        if "429" in message or "rate_limit" in lowered or "rate limit" in lowered:
            raise HTTPException(
                status_code=429,
                detail="Rate limit reached. Please wait 30 seconds and try again.",
            ) from e
        raise HTTPException(status_code=500, detail=f"Agent error: {message}") from e

    history.append(HumanMessage(content=req.question))
    history.append(AIMessage(content=answer))
    # Two messages per turn. Guard the slice: history[-0:] is the WHOLE list,
    # so a limit of zero must be handled explicitly.
    keep = MAX_HISTORY_TURNS * 2
    trimmed = history[-keep:] if keep > 0 else []
    _save_history(req.session_id, trimmed)

    return QueryResponse(
        answer=answer,
        sources=[{"chunk": r["chunk"][:300], "source": r["source"]} for r in results],
        retries_used=retries,
        validation=verdict,
        session_id=req.session_id,
    )
@app.post("/upload")
async def upload(file: UploadFile = File(...)):
    """Store an uploaded ``.txt`` or ``.pdf`` in ``DOCS_DIR`` and re-index.

    Args:
        file: The uploaded file; only its base name is used on disk.

    Returns:
        ``{"status": "uploaded", "filename": <stored name>}``.

    Raises:
        HTTPException: 400 when the extension is not ``.txt`` or ``.pdf``
            (this also covers a missing or empty filename).
    """
    allowed = {".txt", ".pdf"}
    # The filename is client-controlled. Keep only the final path component
    # (normalising Windows separators first) so names like "../../x.txt" or
    # "/etc/x.txt" cannot write outside DOCS_DIR.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    ext = os.path.splitext(safe_name)[1].lower()
    if ext not in allowed:
        raise HTTPException(400, detail="Only .txt and .pdf allowed.")

    os.makedirs(DOCS_DIR, exist_ok=True)
    dest = os.path.join(DOCS_DIR, safe_name)
    with open(dest, "wb") as f:
        shutil.copyfileobj(file.file, f)

    # _reindex prints failures instead of raising, so the upload is reported
    # as successful even when indexing fails.
    _reindex()
    return {"status": "uploaded", "filename": safe_name}
def _reindex():
try:
run_ingestion()
reload_indexes()
print(f"Re-indexing complete. Loaded: {_indexes_loaded()}")
except Exception as e:
import traceback
print(f"Re-indexing failed: {e}"); traceback.print_exc()
@app.delete("/session/{session_id}")
def clear_session(session_id: str):
    """Delete the stored history row for *session_id*.

    Args:
        session_id: Session whose row is removed from the ``sessions`` table.

    Returns:
        ``{"status": "cleared", "session_id": session_id}``; the same payload
        is returned whether or not a row existed.
    """
    con = sqlite3.connect(SQLITE_PATH)
    try:
        con.execute("DELETE FROM sessions WHERE session_id=?", (session_id,))
        con.commit()
    finally:
        # Close even when the delete fails so the handle is never leaked.
        con.close()
    return {"status": "cleared", "session_id": session_id}
@app.get("/health")
def health():
    """Report liveness together with whether the indexes are loaded."""
    loaded = _indexes_loaded()
    return {"status": "ok", "indexes_loaded": loaded}
@app.get("/eval")
def get_eval():
    """Return the parsed contents of ``eval_results.json``.

    Returns:
        Whatever JSON value the file contains.

    Raises:
        HTTPException: 404 when ``eval_results.json`` does not exist in the
            current working directory.
    """
    # EAFP: open directly instead of exists()-then-open, which is race-prone.
    # JSON is read as UTF-8 rather than the platform's locale encoding.
    try:
        f = open("eval_results.json", "r", encoding="utf-8")
    except FileNotFoundError:
        raise HTTPException(
            status_code=404,
            detail="Run evaluate.py first to generate scores.",
        ) from None
    with f:
        return json.load(f)
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # using the PORT environment variable when set and 7860 otherwise.
    import uvicorn

    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
|