Update main.py
Browse files
main.py
CHANGED
|
@@ -4,7 +4,6 @@ from contextlib import asynccontextmanager
|
|
| 4 |
from fastapi import FastAPI, UploadFile, File, HTTPException
|
| 5 |
from pydantic import BaseModel
|
| 6 |
from langchain_core.messages import HumanMessage, AIMessage
|
| 7 |
-
|
| 8 |
from retriever import load_indexes, reload_indexes, hybrid_retrieve, indexes_loaded as _indexes_loaded
|
| 9 |
from agent import run_rag_agent
|
| 10 |
from ingestion import run_ingestion
|
|
@@ -12,7 +11,6 @@ from config import DOCS_DIR, TOP_K, MAX_HISTORY_TURNS
|
|
| 12 |
|
| 13 |
sessions: dict = {}
|
| 14 |
|
| 15 |
-
|
| 16 |
@asynccontextmanager
|
| 17 |
async def lifespan(app: FastAPI):
|
| 18 |
try:
|
|
@@ -21,21 +19,17 @@ async def lifespan(app: FastAPI):
|
|
| 21 |
print("WARNING: No indexes found. Upload documents first.")
|
| 22 |
yield
|
| 23 |
|
| 24 |
-
|
| 25 |
app = FastAPI(title="Corrective RAG API", version="1.0", lifespan=lifespan)
|
| 26 |
|
| 27 |
-
|
| 28 |
@app.get("/")
|
| 29 |
def home():
|
| 30 |
return {"message": "RAG API running 🚀"}
|
| 31 |
|
| 32 |
-
|
| 33 |
class QueryRequest(BaseModel):
|
| 34 |
question: str
|
| 35 |
session_id: str = "default"
|
| 36 |
top_k: int = TOP_K
|
| 37 |
|
| 38 |
-
|
| 39 |
class QueryResponse(BaseModel):
|
| 40 |
answer: str
|
| 41 |
sources: list
|
|
@@ -43,7 +37,6 @@ class QueryResponse(BaseModel):
|
|
| 43 |
validation: str
|
| 44 |
session_id: str
|
| 45 |
|
| 46 |
-
|
| 47 |
@app.post("/query", response_model=QueryResponse)
|
| 48 |
async def query(req: QueryRequest):
|
| 49 |
if not _indexes_loaded():
|
|
@@ -56,20 +49,14 @@ async def query(req: QueryRequest):
|
|
| 56 |
status_code=503,
|
| 57 |
detail="Indexes not ready. Upload and index documents first."
|
| 58 |
)
|
| 59 |
-
results = hybrid_retrieve(req.question, top_k=req.top_k)
|
| 60 |
-
)
|
| 61 |
-
|
| 62 |
results = hybrid_retrieve(req.question, top_k=req.top_k)
|
| 63 |
if not results:
|
| 64 |
raise HTTPException(status_code=404, detail="No relevant chunks found.")
|
| 65 |
-
|
| 66 |
history = sessions.get(req.session_id, [])
|
| 67 |
answer, retries, verdict = run_rag_agent(req.question, results, history)
|
| 68 |
-
|
| 69 |
history.append(HumanMessage(content=req.question))
|
| 70 |
history.append(AIMessage(content=answer))
|
| 71 |
sessions[req.session_id] = history[-(MAX_HISTORY_TURNS * 2):]
|
| 72 |
-
|
| 73 |
return QueryResponse(
|
| 74 |
answer=answer,
|
| 75 |
sources=[{"chunk": r["chunk"][:300], "source": r["source"]} for r in results],
|
|
@@ -78,25 +65,20 @@ async def query(req: QueryRequest):
|
|
| 78 |
session_id=req.session_id,
|
| 79 |
)
|
| 80 |
|
| 81 |
-
|
| 82 |
@app.post("/upload")
|
| 83 |
-
|
| 84 |
async def upload(file: UploadFile = File(...)):
|
| 85 |
allowed = {".txt", ".pdf"}
|
| 86 |
ext = os.path.splitext(file.filename or "")[1].lower()
|
| 87 |
if ext not in allowed:
|
| 88 |
raise HTTPException(status_code=400, detail="Only .txt and .pdf files allowed.")
|
| 89 |
-
|
| 90 |
os.makedirs(DOCS_DIR, exist_ok=True)
|
| 91 |
dest = os.path.join(DOCS_DIR, file.filename)
|
| 92 |
with open(dest, "wb") as f:
|
| 93 |
shutil.copyfileobj(file.file, f)
|
| 94 |
-
|
| 95 |
_reindex()
|
| 96 |
return {"status": "uploaded", "filename": file.filename,
|
| 97 |
"message": "Indexing complete."}
|
| 98 |
|
| 99 |
-
|
| 100 |
def _reindex():
|
| 101 |
try:
|
| 102 |
run_ingestion()
|
|
@@ -105,16 +87,15 @@ def _reindex():
|
|
| 105 |
except Exception as e:
|
| 106 |
print(f"Re-indexing failed: {e}")
|
| 107 |
|
| 108 |
-
|
| 109 |
@app.delete("/session/{session_id}")
|
| 110 |
def clear_session(session_id: str):
|
| 111 |
sessions.pop(session_id, None)
|
| 112 |
return {"status": "cleared", "session_id": session_id}
|
| 113 |
|
| 114 |
-
|
| 115 |
@app.get("/health")
|
| 116 |
def health():
|
| 117 |
return {"status": "ok", "indexes_loaded": _indexes_loaded()}
|
|
|
|
| 118 |
if __name__ == "__main__":
|
| 119 |
import uvicorn
|
| 120 |
-
uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))
|
|
|
|
| 4 |
from fastapi import FastAPI, UploadFile, File, HTTPException
|
| 5 |
from pydantic import BaseModel
|
| 6 |
from langchain_core.messages import HumanMessage, AIMessage
|
|
|
|
| 7 |
from retriever import load_indexes, reload_indexes, hybrid_retrieve, indexes_loaded as _indexes_loaded
|
| 8 |
from agent import run_rag_agent
|
| 9 |
from ingestion import run_ingestion
|
|
|
|
| 11 |
|
| 12 |
sessions: dict = {}
|
| 13 |
|
|
|
|
| 14 |
@asynccontextmanager
|
| 15 |
async def lifespan(app: FastAPI):
|
| 16 |
try:
|
|
|
|
| 19 |
print("WARNING: No indexes found. Upload documents first.")
|
| 20 |
yield
|
| 21 |
|
|
|
|
| 22 |
app = FastAPI(title="Corrective RAG API", version="1.0", lifespan=lifespan)
|
| 23 |
|
|
|
|
| 24 |
@app.get("/")
|
| 25 |
def home():
|
| 26 |
return {"message": "RAG API running 🚀"}
|
| 27 |
|
|
|
|
| 28 |
class QueryRequest(BaseModel):
|
| 29 |
question: str
|
| 30 |
session_id: str = "default"
|
| 31 |
top_k: int = TOP_K
|
| 32 |
|
|
|
|
| 33 |
class QueryResponse(BaseModel):
|
| 34 |
answer: str
|
| 35 |
sources: list
|
|
|
|
| 37 |
validation: str
|
| 38 |
session_id: str
|
| 39 |
|
|
|
|
| 40 |
@app.post("/query", response_model=QueryResponse)
|
| 41 |
async def query(req: QueryRequest):
|
| 42 |
if not _indexes_loaded():
|
|
|
|
| 49 |
status_code=503,
|
| 50 |
detail="Indexes not ready. Upload and index documents first."
|
| 51 |
)
|
|
|
|
|
|
|
|
|
|
| 52 |
results = hybrid_retrieve(req.question, top_k=req.top_k)
|
| 53 |
if not results:
|
| 54 |
raise HTTPException(status_code=404, detail="No relevant chunks found.")
|
|
|
|
| 55 |
history = sessions.get(req.session_id, [])
|
| 56 |
answer, retries, verdict = run_rag_agent(req.question, results, history)
|
|
|
|
| 57 |
history.append(HumanMessage(content=req.question))
|
| 58 |
history.append(AIMessage(content=answer))
|
| 59 |
sessions[req.session_id] = history[-(MAX_HISTORY_TURNS * 2):]
|
|
|
|
| 60 |
return QueryResponse(
|
| 61 |
answer=answer,
|
| 62 |
sources=[{"chunk": r["chunk"][:300], "source": r["source"]} for r in results],
|
|
|
|
| 65 |
session_id=req.session_id,
|
| 66 |
)
|
| 67 |
|
|
|
|
| 68 |
@app.post("/upload")
|
|
|
|
| 69 |
async def upload(file: UploadFile = File(...)):
|
| 70 |
allowed = {".txt", ".pdf"}
|
| 71 |
ext = os.path.splitext(file.filename or "")[1].lower()
|
| 72 |
if ext not in allowed:
|
| 73 |
raise HTTPException(status_code=400, detail="Only .txt and .pdf files allowed.")
|
|
|
|
| 74 |
os.makedirs(DOCS_DIR, exist_ok=True)
|
| 75 |
dest = os.path.join(DOCS_DIR, file.filename)
|
| 76 |
with open(dest, "wb") as f:
|
| 77 |
shutil.copyfileobj(file.file, f)
|
|
|
|
| 78 |
_reindex()
|
| 79 |
return {"status": "uploaded", "filename": file.filename,
|
| 80 |
"message": "Indexing complete."}
|
| 81 |
|
|
|
|
| 82 |
def _reindex():
|
| 83 |
try:
|
| 84 |
run_ingestion()
|
|
|
|
| 87 |
except Exception as e:
|
| 88 |
print(f"Re-indexing failed: {e}")
|
| 89 |
|
|
|
|
| 90 |
@app.delete("/session/{session_id}")
|
| 91 |
def clear_session(session_id: str):
|
| 92 |
sessions.pop(session_id, None)
|
| 93 |
return {"status": "cleared", "session_id": session_id}
|
| 94 |
|
|
|
|
| 95 |
@app.get("/health")
|
| 96 |
def health():
|
| 97 |
return {"status": "ok", "indexes_loaded": _indexes_loaded()}
|
| 98 |
+
|
| 99 |
if __name__ == "__main__":
|
| 100 |
import uvicorn
|
| 101 |
+
uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))
|