File size: 6,179 Bytes
3c72c9d
d4cf06c
aec0e75
d4cf06c
 
ea15168
d4cf06c
 
3c72c9d
d4cf06c
3c72c9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4cf06c
 
 
3c72c9d
a977e38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4cf06c
 
 
3c72c9d
ea15168
d4cf06c
 
 
 
 
 
 
 
 
 
 
 
3c72c9d
 
 
 
 
 
d4cf06c
 
f097829
3c72c9d
 
ea15168
3c72c9d
 
d4cf06c
 
3c72c9d
 
 
a0b0392
 
 
 
 
 
 
 
 
3c72c9d
d4cf06c
 
3c72c9d
 
d4cf06c
 
 
 
 
 
 
 
 
aec0e75
d4cf06c
 
 
3c72c9d
d4cf06c
 
 
 
aec0e75
3c72c9d
d4cf06c
 
 
aec0e75
d4cf06c
3c72c9d
d4cf06c
3e4a530
3c72c9d
d4cf06c
 
 
3c72c9d
 
 
d4cf06c
 
 
 
d9066de
a977e38
 
 
 
 
 
3d17099
d9066de
 
3c72c9d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import os, shutil, sqlite3, json
from contextlib import asynccontextmanager
from fastapi import FastAPI, UploadFile, File, HTTPException
from pydantic import BaseModel
from langchain_core.messages import HumanMessage, AIMessage
from retriever import load_indexes, reload_indexes, hybrid_retrieve, indexes_loaded as _indexes_loaded
from agent import run_rag_agent
from ingestion import run_ingestion
from config import DOCS_DIR, TOP_K, MAX_HISTORY_TURNS, SQLITE_PATH

# ── SQLite session memory ─────────────────────────────

def _init_db():
    """Create the ``sessions`` table in the SQLite file if it is missing.

    The table maps ``session_id`` (primary key) to ``history``, a text
    column that defaults to the JSON string ``'[]'``.
    """
    con = sqlite3.connect(SQLITE_PATH)
    try:
        # ``with con`` commits on success and rolls back on error.
        with con:
            con.execute("""
                CREATE TABLE IF NOT EXISTS sessions (
                    session_id TEXT PRIMARY KEY,
                    history    TEXT NOT NULL DEFAULT '[]'
                )
            """)
    finally:
        # Always release the connection, even when the DDL fails.
        con.close()

def _load_history(session_id: str) -> list:
    """Return the stored conversation for *session_id* as message objects.

    The ``history`` column holds a JSON list of ``{"role", "content"}``
    dicts. Entries whose role is ``"human"`` become ``HumanMessage``;
    every other role becomes ``AIMessage``. An unknown session yields an
    empty list.
    """
    con = sqlite3.connect(SQLITE_PATH)
    try:
        row = con.execute(
            "SELECT history FROM sessions WHERE session_id=?", (session_id,)
        ).fetchone()
    finally:
        # Close even if the query raises so the connection is not leaked.
        con.close()
    if not row:
        return []
    # Reconstruct LangChain message objects from the stored dicts.
    return [
        HumanMessage(content=m["content"])
        if m["role"] == "human"
        else AIMessage(content=m["content"])
        for m in json.loads(row[0])
    ]

def _save_history(session_id: str, history: list):
    """Persist *history* for *session_id*, replacing any existing row.

    Each message is stored as ``{"role", "content"}``: ``HumanMessage``
    instances get role ``"human"``, anything else gets role ``"ai"``.
    The list is serialised to JSON before being written.
    """
    raw = [
        {
            "role": "human" if isinstance(m, HumanMessage) else "ai",
            "content": m.content,
        }
        for m in history
    ]
    con = sqlite3.connect(SQLITE_PATH)
    try:
        # ``with con`` commits on success and rolls back on error.
        with con:
            con.execute(
                "INSERT OR REPLACE INTO sessions (session_id, history) VALUES (?,?)",
                (session_id, json.dumps(raw)),
            )
    finally:
        # Always release the connection, even when the write fails.
        con.close()

# ── app lifecycle ─────────────────────────────────────

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Prepare the session table and retrieval indexes before serving.

    If the indexes are still not loaded after ``load_indexes()``, the docs
    folder is checked for ``*.txt`` / ``*.pdf`` files. When any exist,
    ingestion is run and the indexes are reloaded. A failure there is
    printed rather than raised, so startup continues either way.
    """
    _init_db()
    load_indexes()
    if not _indexes_loaded():
        from pathlib import Path

        docs_root = Path(DOCS_DIR)
        # Short-circuits exactly like ``a or b``: .pdf is only probed when
        # no .txt file is found.
        found_docs = any(
            any(docs_root.glob(pattern)) for pattern in ("*.txt", "*.pdf")
        )
        if not found_docs:
            print("WARNING: No indexes and no docs found. Upload documents first.")
        else:
            print("Cold start: ChromaDB empty, re-indexing docs folder...")
            try:
                run_ingestion()
                reload_indexes()
                print("Cold start ingestion complete.")
            except Exception as err:
                print(f"Cold start ingestion failed: {err}")
    yield


app = FastAPI(title="Corrective RAG API", version="1.0", lifespan=lifespan)

# ── models ────────────────────────────────────────────

# Request body for POST /query.
class QueryRequest(BaseModel):
    # The user's question; passed to hybrid_retrieve and run_rag_agent.
    question:   str
    # Key of the conversation row in the SQLite ``sessions`` table.
    session_id: str = "default"
    # Forwarded to hybrid_retrieve as ``top_k``; defaults to the configured TOP_K.
    top_k:      int = TOP_K

# Response body for POST /query.
class QueryResponse(BaseModel):
    # Answer text returned by run_rag_agent.
    answer:       str
    # One dict per retrieved result: {"chunk": first 300 chars, "source": ...}.
    sources:      list
    # Retry count reported by run_rag_agent.
    retries_used: int
    # Verdict value reported by run_rag_agent.
    validation:   str
    # Echo of the session id from the request.
    session_id:   str

# ── routes ────────────────────────────────────────────

@app.get("/")
def home():
    """Return a static banner confirming the API is up."""
    # The literal previously held mojibake ("πŸš€"), i.e. the UTF-8 bytes of
    # the rocket emoji decoded with the wrong codec. An escape keeps the
    # source ASCII-only so it cannot be re-garbled by an encoding mismatch.
    return {"message": "RAG API running \U0001F680"}

@app.post("/query", response_model=QueryResponse)
async def query(req: QueryRequest):
    """Answer a question using retrieved chunks and the session's history.

    Responds 503 when indexes are not loaded, 404 when retrieval returns
    nothing, 429 when the agent error looks like a rate limit, and 500 for
    any other agent failure. On success the exchange is appended to the
    session history, which is trimmed to the last ``MAX_HISTORY_TURNS``
    question/answer pairs before being saved.
    """
    if not _indexes_loaded():
        # Best-effort lazy load; a failure surfaces as the 503 below.
        # ``except Exception`` (not bare ``except``) lets KeyboardInterrupt
        # and SystemExit propagate.
        try:
            load_indexes()
        except Exception as load_err:
            print(f"Lazy index load failed: {load_err}")
    if not _indexes_loaded():
        raise HTTPException(503, detail="Indexes not ready. Upload documents first.")

    results = hybrid_retrieve(req.question, top_k=req.top_k)
    if not results:
        raise HTTPException(404, detail="No relevant chunks found.")

    history = _load_history(req.session_id)
    try:
        answer, retries, verdict = run_rag_agent(req.question, results, history)
    except Exception as e:
        message = str(e)
        lowered = message.lower()
        if "429" in message or "rate_limit" in lowered or "rate limit" in lowered:
            raise HTTPException(
                status_code=429,
                detail="Rate limit reached. Please wait 30 seconds and try again."
            ) from e
        raise HTTPException(status_code=500, detail=f"Agent error: {message}") from e

    history.append(HumanMessage(content=req.question))
    history.append(AIMessage(content=answer))
    # Two messages per turn. ``history[-0:]`` is the WHOLE list, so a
    # non-positive cap must be handled explicitly to mean "keep nothing".
    keep = MAX_HISTORY_TURNS * 2
    _save_history(req.session_id, history[-keep:] if keep > 0 else [])

    return QueryResponse(
        answer=answer,
        sources=[{"chunk": r["chunk"][:300], "source": r["source"]} for r in results],
        retries_used=retries,
        validation=verdict,
        session_id=req.session_id,
    )

@app.post("/upload")
async def upload(file: UploadFile = File(...)):
    """Save an uploaded .txt/.pdf into the docs folder and re-index.

    Responds 400 when the file extension is not ``.txt`` or ``.pdf``.
    The stored (and returned) filename is the basename of the name the
    client supplied.
    """
    allowed = {".txt", ".pdf"}
    # The filename is client-controlled. Strip every directory component
    # (both separator styles) so names like "../../x.txt" or "/abs/x.txt"
    # cannot write outside DOCS_DIR.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    ext = os.path.splitext(safe_name)[1].lower()
    if ext not in allowed:
        raise HTTPException(400, detail="Only .txt and .pdf allowed.")
    os.makedirs(DOCS_DIR, exist_ok=True)
    dest = os.path.join(DOCS_DIR, safe_name)
    with open(dest, "wb") as f:
        shutil.copyfileobj(file.file, f)
    _reindex()
    return {"status": "uploaded", "filename": safe_name}

def _reindex():
    """Run ingestion and reload the indexes, printing the outcome.

    Any exception is caught, reported with its traceback, and not re-raised.
    """
    import traceback

    try:
        run_ingestion()
        reload_indexes()
        loaded = _indexes_loaded()
        print(f"Re-indexing complete. Loaded: {loaded}")
    except Exception as err:
        print(f"Re-indexing failed: {err}")
        traceback.print_exc()

@app.delete("/session/{session_id}")
def clear_session(session_id: str):
    """Delete the stored history row for *session_id*, if any."""
    con = sqlite3.connect(SQLITE_PATH)
    try:
        # ``with con`` commits on success and rolls back on error.
        with con:
            con.execute("DELETE FROM sessions WHERE session_id=?", (session_id,))
    finally:
        # Always release the connection, even when the delete fails.
        con.close()
    return {"status": "cleared", "session_id": session_id}

@app.get("/health")
def health():
    # Fixed "ok" status plus the current index-loaded flag.
    payload = {"status": "ok"}
    payload["indexes_loaded"] = _indexes_loaded()
    return payload
@app.get("/eval")
def get_eval():
    """Return the parsed contents of ``eval_results.json``.

    Responds 404 when the file does not exist.
    """
    # Open directly instead of checking existence first: this avoids the
    # check/use race and a second filesystem lookup.
    try:
        with open("eval_results.json", "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        raise HTTPException(
            status_code=404, detail="Run evaluate.py first to generate scores."
        ) from None

if __name__ == "__main__":
    import uvicorn

    # PORT comes from the environment; 7860 is used when it is unset.
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)