Hitan2004 committed on
Commit
d4cf06c
·
0 Parent(s):

initial RAG system

Browse files
Files changed (11) hide show
  1. .gitignore +6 -0
  2. agent.py +140 -0
  3. app.py +138 -0
  4. config.py +25 -0
  5. ingestion.py +168 -0
  6. main.py +109 -0
  7. requirements.txt +18 -0
  8. retriever.py +62 -0
  9. start.sh +4 -0
  10. test_sources.py +4 -0
  11. verify.py +70 -0
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ venv/
2
+ __pycache__/
3
+ *.pkl
4
+ faiss.index
5
+ embedder_model/
6
+ .env
agent.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TypedDict
2
+ from langgraph.graph import StateGraph, END
3
+ from langchain_groq import ChatGroq
4
+ from langchain_core.messages import HumanMessage, AIMessage
5
+ from config import GROQ_API_KEY, GROQ_MODEL, MAX_RETRIES
6
+
7
+ llm = ChatGroq(
8
+ model=GROQ_MODEL,
9
+ temperature=0,
10
+ api_key=GROQ_API_KEY,
11
+ )
12
+
13
+
14
class RAGState(TypedDict):
    """State dictionary handed from node to node in the corrective-RAG graph."""

    # Supplied by the caller when the graph is invoked.
    question: str
    context_chunks: list
    chat_history: list
    # Updated by the graph nodes while the graph runs.
    answer: str
    validation_result: str
    fail_reason: str
    retry_count: int
22
+
23
+
24
def generate_node(state: RAGState) -> dict:
    """Draft an answer to the question using only the retrieved context.

    Builds a single prompt from the retrieved chunks, the most recent chat
    messages and, on a retry, the validator's rejection reason, then sends
    it to the module-level ``llm``.

    Args:
        state: Current graph state. Reads ``context_chunks`` (items with
            ``source`` and ``chunk`` keys) and ``question``, plus
            ``chat_history``, ``retry_count`` and ``fail_reason`` when set.

    Returns:
        A partial state update of the form ``{"answer": <reply text>}``.
    """
    context_text = "\n\n---\n\n".join(
        f"[Source: {r['source']}]\n{r['chunk']}"
        for r in state["context_chunks"]
    )

    # Only the six most recent messages are replayed into the prompt.
    # ``or []`` also covers a history that is present but None.
    recent_messages = (state.get("chat_history") or [])[-6:]
    history_lines = []
    for msg in recent_messages:
        role = "User" if isinstance(msg, HumanMessage) else "Assistant"
        history_lines.append(f"{role}: {msg.content}")
    history_text = "\n".join(history_lines) or "None"

    correction = ""
    if state.get("retry_count", 0) > 0:
        # ``fail_reason`` is always present in the state (it starts as ""),
        # so a ``.get`` default would never apply; fall back on emptiness.
        reason = state.get("fail_reason") or "unverifiable claims"
        correction = (
            f"\n\nIMPORTANT CORRECTION REQUIRED: Your previous answer was "
            f"rejected because: {reason}. "
            f"Re-answer using ONLY the context provided."
        )

    prompt = (
        "You are an AI assistant that answers questions AND generates content based on provided documents.\n"
        "Answer ONLY using information from the CONTEXT below.\n"
        "If the answer cannot be found, say exactly: "
        '"I don\'t have enough information in the provided documents."\n'
        "Do NOT invent facts or use outside knowledge."
        + correction
        + f"\n\nPREVIOUS CONVERSATION:\n{history_text}"
        + f"\n\nCONTEXT:\n{context_text}"
        + f"\n\nQUESTION: {state['question']}\n\nAnswer:"
    )

    response = llm.invoke([HumanMessage(content=prompt)])
    return {"answer": response.content}
58
+
59
+
60
+ def validate_node(state: RAGState) -> dict:
61
+ context_text = "\n\n".join(r["chunk"] for r in state["context_chunks"])
62
+
63
+ prompt = (
64
+ "You are a strict hallucination checker for a RAG system.\n\n"
65
+ "Given the CONTEXT and the ANSWER below, check:\n"
66
+ "1. Is every factual claim directly supported by the context?\n"
67
+ "2. Does the answer address the question?\n"
68
+ "3. Are there any invented facts not in the context?\n\n"
69
+ f"Context:\n{context_text}\n\n"
70
+ f"Question: {state['question']}\n"
71
+ f"Answer: {state['answer']}\n\n"
72
+ "Respond in EXACTLY this format:\n"
73
+ "VERDICT: PASS\n"
74
+ "REASON: <one sentence>\n\n"
75
+ "or\n\n"
76
+ "VERDICT: FAIL\n"
77
+ "REASON: <one sentence explaining what is wrong>"
78
+ )
79
+
80
+ result = llm.invoke([HumanMessage(content=prompt)])
81
+ text = result.content.strip()
82
+
83
+ verdict = "PASS" if "VERDICT: PASS" in text.upper() else "FAIL"
84
+ reason = ""
85
+ for line in text.splitlines():
86
+ if line.upper().startswith("REASON:"):
87
+ reason = line.split(":", 1)[1].strip()
88
+ break
89
+
90
+ return {"validation_result": verdict, "fail_reason": reason}
91
+
92
+
93
def increment_retry_node(state: RAGState) -> dict:
    """Return a state update that raises ``retry_count`` by one (missing counts as 0)."""
    attempts_so_far = state.get("retry_count", 0)
    return {"retry_count": attempts_so_far + 1}
95
+
96
+
97
def route_after_validation(state: RAGState) -> str:
    """Pick the edge to follow after validation: ``"retry"`` or ``"done"``.

    A retry happens only when validation failed and fewer than
    ``MAX_RETRIES`` retries have been used so far.
    """
    if state["validation_result"] != "FAIL":
        return "done"
    if state.get("retry_count", 0) >= MAX_RETRIES:
        return "done"
    return "retry"
104
+
105
+
106
def _build_graph():
    """Wire up and compile the generate -> validate -> (retry) state graph."""
    graph = StateGraph(RAGState)

    node_table = (
        ("generate", generate_node),
        ("validate", validate_node),
        ("increment_retry", increment_retry_node),
    )
    for node_name, handler in node_table:
        graph.add_node(node_name, handler)

    graph.set_entry_point("generate")
    graph.add_edge("generate", "validate")

    # After validation either loop back through the retry counter or stop.
    branches = {"retry": "increment_retry", "done": END}
    graph.add_conditional_edges("validate", route_after_validation, branches)
    graph.add_edge("increment_retry", "generate")

    return graph.compile()
120
+
121
+
122
+ _rag_graph = _build_graph()
123
+
124
+
125
def run_rag_agent(
    question: str,
    context_chunks: list,
    chat_history: "list | None" = None,
) -> tuple:
    """Run the compiled graph once for a question.

    Args:
        question: The user's question.
        context_chunks: Retrieved chunks; each item needs ``chunk`` and
            ``source`` keys because the graph nodes read them.
        chat_history: Earlier messages of the conversation. ``None`` (the
            default) means no history; a fresh list is used per call so no
            list object is shared between calls.

    Returns:
        A 3-tuple ``(answer, retry_count, validation_result)`` taken from
        the final graph state.
    """
    init_state: RAGState = {
        "question": question,
        "context_chunks": context_chunks,
        "answer": "",
        "validation_result": "",
        "fail_reason": "",
        "retry_count": 0,
        "chat_history": [] if chat_history is None else chat_history,
    }
    final = _rag_graph.invoke(init_state)
    return final["answer"], final["retry_count"], final["validation_result"]
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import uuid
3
+ import streamlit as st
4
+ import requests
5
+
6
+ API = "http://localhost:8000"
7
+
8
+ st.set_page_config(
9
+ page_title="Corrective RAG",
10
+ page_icon="📄",
11
+ layout="wide",
12
+ )
13
+ st.title("📄 Corrective RAG — Document Q&A")
14
+ st.caption("Groq LLaMA 3 · FAISS · BM25 · LangGraph self-correction")
15
+
16
+ # ── Session state init ────────────────────────────────────────
17
+ if "session_id" not in st.session_state:
18
+ st.session_state.session_id = str(uuid.uuid4())
19
+ if "messages" not in st.session_state:
20
+ st.session_state.messages = []
21
+
22
+ # ── Sidebar ───────────────────────────────────────────────────
23
with st.sidebar:
    st.header("Upload documents")
    uploaded_files = st.file_uploader(
        "Choose .txt or .pdf files",
        type=["txt", "pdf"],
        accept_multiple_files=True,
    )
    if st.button("Index documents", type="primary") and uploaded_files:
        for f in uploaded_files:
            try:
                r = requests.post(
                    f"{API}/upload",
                    files={"file": (f.name, f.getvalue())},
                    timeout=30,
                )
            except requests.ConnectionError:
                st.error("Cannot reach backend. Is `uvicorn main:app` running?")
                continue
            except requests.Timeout:
                # The request uses timeout=30; report it instead of letting
                # the exception escape the script.
                st.error(f"{f.name} — upload timed out")
                continue
            except requests.RequestException as exc:
                st.error(f"{f.name} — upload failed: {exc}")
                continue

            if r.status_code == 200:
                st.success(f"{f.name} — uploaded, indexing started")
            else:
                # The error body is not guaranteed to be a JSON object;
                # fall back to the raw text when it is not.
                try:
                    body = r.json()
                except ValueError:
                    body = None
                detail = body.get("detail", r.text) if isinstance(body, dict) else r.text
                st.error(f"{f.name} — {detail}")

    st.divider()

    # Health check: any failure is shown as "not connected".
    try:
        h = requests.get(f"{API}/health", timeout=3).json()
        idx_status = "ready" if h.get("indexes_loaded") else "not loaded"
        st.caption(f"Backend: connected | Indexes: {idx_status}")
    except Exception:
        st.caption("Backend: not connected")

    st.divider()
    if st.button("Clear conversation"):
        # Best effort: the local history is cleared even when the backend
        # cannot be told to drop its copy.
        try:
            requests.delete(f"{API}/session/{st.session_state.session_id}", timeout=5)
        except requests.RequestException:
            pass
        st.session_state.messages = []
        st.rerun()

    st.caption(f"Session ID: `{st.session_state.session_id[:8]}...`")
65
+
66
+ # ── Render chat history ───────────────────────────────────────
67
# Replay every stored message; assistant messages that carry metadata also
# get their metrics row and, when available, the source chunks.
for past in st.session_state.messages:
    with st.chat_message(past["role"]):
        st.markdown(past["content"])

        details = past.get("meta") if past["role"] == "assistant" else None
        if not details:
            continue

        metric_row = (
            ("Retries used", details["retries"]),
            ("Validation", details["validation"]),
            ("Sources found", details["num_sources"]),
        )
        for column, (label, value) in zip(st.columns(3), metric_row):
            column.metric(label, value)

        if details.get("sources"):
            with st.expander("View source chunks"):
                for item in details["sources"]:
                    st.markdown(f"**{item['source']}**")
                    st.text(item["chunk"])
                    st.divider()
82
+
83
+ # ── Chat input ────────────────────────────────────────────────
84
# Handle a newly submitted question: echo it, query the backend, render the
# reply with its metrics, and store both turns in the session history.
question = st.chat_input("Ask a question about your documents...")
if question:
    st.session_state.messages.append({"role": "user", "content": question})
    with st.chat_message("user"):
        st.markdown(question)

    with st.chat_message("assistant"):
        with st.spinner("Retrieving and generating (with self-correction)..."):
            # Defaults shown when the request fails before a reply arrives.
            answer = ""
            meta = {"retries": 0, "validation": "N/A",
                    "num_sources": 0, "sources": []}
            request_body = {
                "question": question,
                "session_id": st.session_state.session_id,
            }
            try:
                r = requests.post(f"{API}/query", json=request_body, timeout=60)
                if r.status_code != 200:
                    answer = f"Error {r.status_code}: {r.json().get('detail', r.text)}"
                else:
                    data = r.json()
                    answer = data["answer"]
                    meta = {
                        "retries": data["retries_used"],
                        "validation": data["validation"],
                        "num_sources": len(data["sources"]),
                        "sources": data["sources"],
                    }
            except requests.ConnectionError:
                answer = "Cannot reach backend. Make sure `uvicorn main:app` is running."
            except requests.Timeout:
                answer = "Request timed out. The model may be slow — try again."
            except Exception as e:
                answer = f"Unexpected error: {e}"

        st.markdown(answer)
        metric_row = (
            ("Retries used", meta["retries"]),
            ("Validation", meta["validation"]),
            ("Sources found", meta["num_sources"]),
        )
        for column, (label, value) in zip(st.columns(3), metric_row):
            column.metric(label, value)

        if meta["sources"]:
            with st.expander("View source chunks"):
                for item in meta["sources"]:
                    st.markdown(f"**{item['source']}**")
                    st.text(item["chunk"])
                    st.divider()

    st.session_state.messages.append({
        "role": "assistant",
        "content": answer,
        "meta": meta,
    })
config.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # config.py
2
+ import os
3
+ from dotenv import load_dotenv
4
+
5
# Read variables from a .env file before any of them is looked up.
load_dotenv()

# --- Groq LLM ------------------------------------------------------------
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
# Importing this module fails when no key is configured.
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY not set in .env file")
GROQ_MODEL = "llama-3.3-70b-versatile"

# --- Locations of the documents and of the persisted artifacts -------------
DOCS_DIR = "./docs"
FAISS_INDEX_PATH = "./faiss.index"
BM25_PATH = "./bm25.pkl"
CHUNKS_PATH = "./chunks.pkl"
SOURCES_PATH = "./sources.pkl"
EMBEDDER_PATH = "./embedder_model"
EMBEDDER_NAME = "all-MiniLM-L6-v2"

# --- Chunking, retrieval and agent limits ----------------------------------
CHUNK_SIZE = 500
CHUNK_OVERLAP = 50
TOP_K = 5
MAX_RETRIES = 3
MAX_HISTORY_TURNS = 5
ingestion.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ingestion.py
2
+ import os, pickle
3
+ from pathlib import Path
4
+ import numpy as np
5
+ import faiss
6
+ from sentence_transformers import SentenceTransformer
7
+ from rank_bm25 import BM25Okapi
8
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
9
+ from config import (
10
+ DOCS_DIR, FAISS_INDEX_PATH, BM25_PATH,
11
+ CHUNKS_PATH, SOURCES_PATH, EMBEDDER_PATH,
12
+ EMBEDDER_NAME, CHUNK_SIZE, CHUNK_OVERLAP
13
+ )
14
+
15
+ # ─────────────────────────────────────────────────────────────
16
+ # Better PDF extraction (IMPORTANT)
17
+ # ─────────────────────────────────────────────────────────────
18
def read_pdf_text(fpath):
    """Return the text of every page of a PDF, joined by newlines and stripped.

    Args:
        fpath: Path of the PDF file to open.

    Returns:
        The concatenated page text; an empty string when no page yields text.
    """
    # Imported here, so this module can be imported without PyMuPDF installed.
    import fitz  # PyMuPDF

    # Opening through ``with`` closes the document even if a page fails;
    # the previous version never closed it.
    with fitz.open(fpath) as doc:
        pages = [page.get_text() for page in doc]
    return "\n".join(pages).strip()
25
+
26
+
27
+ # ─────────────────────────────────────────────────────────────
28
+ # Clean text (removes weird spacing)
29
+ # ─────────────────────────────────────────────────────────────
30
def clean_text(text):
    """Collapse every run of whitespace to one space and trim both ends."""
    words = text.split()
    return " ".join(words)
32
+
33
+
34
+ # ─────────────────────────────────────────────────────────────
35
+ # Load documents
36
+ # ─────────────────────────────────────────────────────────────
37
def load_documents():
    """Read every .txt and .pdf file directly inside ``DOCS_DIR``.

    The directory is created when missing. Each file's text is passed
    through ``clean_text``. A file that cannot be read is reported and
    skipped; a file whose cleaned text is empty is reported and skipped as
    well (this now applies to .txt files too, matching the PDF handling).
    Files are visited in sorted order so repeated runs see the same order.

    Returns:
        ``(docs, filenames)``: two parallel lists with the cleaned text and
        the file name of each loaded document.

    Raises:
        FileNotFoundError: when no document with text could be loaded.
    """
    docs, filenames = [], []
    path = Path(DOCS_DIR)
    path.mkdir(exist_ok=True)

    # (glob pattern, label used in the log line, reader callable)
    loaders = (
        ("*.txt", "text", lambda p: p.read_text(encoding="utf-8")),
        ("*.pdf", "PDF", read_pdf_text),
    )

    for pattern, label, reader in loaders:
        for fpath in sorted(path.glob(pattern)):
            try:
                text = clean_text(reader(fpath))
            except Exception as e:
                # Best effort: one unreadable file must not stop ingestion.
                print(f" Skipped {fpath.name}: {e}")
                continue

            if not text:
                print(f" WARNING: {fpath.name} extracted empty text")
                continue

            docs.append(text)
            filenames.append(fpath.name)
            print(f" Loaded {label}: {fpath.name}")

    if not docs:
        raise FileNotFoundError(
            f"No .txt or .pdf files found in '{DOCS_DIR}'. "
            "Add at least one document and re-run."
        )

    print(f"\nLoaded {len(docs)} document(s)")
    return docs, filenames
76
+
77
+
78
+ # ─────────────────────────────────────────────────────────────
79
+ # Chunking (optimized for resumes)
80
+ # ─────────────────────────────────────────────────────────────
81
def semantic_chunk(docs, filenames, chunk_size=300, chunk_overlap=80):
    """Split each document into overlapping chunks and track their origin.

    Args:
        docs: Document texts.
        filenames: File names, parallel to ``docs``.
        chunk_size: Maximum chunk length handed to the splitter. The default
            keeps the previously hard-coded 300; the ``CHUNK_SIZE`` config
            value is not used here.
        chunk_overlap: Overlap handed to the splitter. The default keeps the
            previously hard-coded 80.

    Returns:
        ``(all_chunks, all_sources)``: the chunks of all documents in order,
        and for each chunk the name of the file it came from.

    Raises:
        ValueError: when splitting produces no chunk at all. Previously this
            surfaced as a ZeroDivisionError while computing the average.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=["\n\n", "\n", ". ", " "],
    )

    all_chunks, all_sources = [], []

    for doc, fname in zip(docs, filenames):
        chunks = splitter.split_text(doc)
        all_chunks.extend(chunks)
        all_sources.extend([fname] * len(chunks))

    if not all_chunks:
        raise ValueError("No text chunks were produced from the loaded documents.")

    average_length = sum(len(c) for c in all_chunks) // len(all_chunks)
    print(f"Created {len(all_chunks)} chunks "
          f"(avg {average_length} chars each)")

    # Debug: show sample chunk
    print("\n--- SAMPLE CHUNK ---")
    print(all_chunks[0][:500])
    print("--------------------\n")

    return all_chunks, all_sources
104
+
105
+
106
+ # ─────────────────────────────────────────────────────────────
107
+ # Build indexes
108
+ # ─────────────────────────────────────────────────────────────
109
def build_indexes(chunks):
    """Build the dense (FAISS) and sparse (BM25) indexes over ``chunks``.

    Embeddings are L2-normalised and stored in an inner-product index.
    BM25 tokens are the lower-cased, whitespace-split words of each chunk.

    Returns:
        ``(faiss_index, bm25_index, model)`` where ``model`` is the
        SentenceTransformer that produced the embeddings.
    """
    print("\nBuilding dense embeddings...")

    embedder = SentenceTransformer(EMBEDDER_NAME)
    vectors = np.array(
        embedder.encode(chunks, show_progress_bar=True, batch_size=32),
        dtype="float32",
    )
    faiss.normalize_L2(vectors)

    n_dims = vectors.shape[1]
    dense_index = faiss.IndexFlatIP(n_dims)
    dense_index.add(vectors)
    print(f"FAISS index: {dense_index.ntotal} vectors, dim={n_dims}")

    sparse_index = BM25Okapi([chunk.lower().split() for chunk in chunks])
    print("BM25 index: built")

    return dense_index, sparse_index, embedder
130
+
131
+
132
+ # ─────────────────────────────────────────────────────────────
133
+ # Save everything
134
+ # ─────────────────────────────────────────────────────────────
135
def save_indexes(faiss_index, bm25_index, chunks, sources, model):
    """Write all retrieval artifacts to the paths configured in ``config``."""
    faiss.write_index(faiss_index, FAISS_INDEX_PATH)

    # The three Python objects are pickled, one file each.
    pickled_artifacts = (
        (BM25_PATH, bm25_index),
        (CHUNKS_PATH, chunks),
        (SOURCES_PATH, sources),
    )
    for target_path, payload in pickled_artifacts:
        with open(target_path, "wb") as handle:
            pickle.dump(payload, handle)

    model.save(EMBEDDER_PATH)

    print("\nSaved indexes to disk.")
150
+
151
+
152
+ # ─────────────────────────────────────────────────────────────
153
+ # Main runner
154
+ # ─────────────────────────────────────────────────────────────
155
def run_ingestion():
    """Run the whole pipeline: load documents, chunk, index and save."""
    print("=== Starting ingestion ===\n")

    loaded_docs, loaded_names = load_documents()
    chunks, sources = semantic_chunk(loaded_docs, loaded_names)

    dense_index, sparse_index, embedder = build_indexes(chunks)
    save_indexes(dense_index, sparse_index, chunks, sources, embedder)

    print("\n=== Ingestion complete ===")
165
+
166
+
167
+ if __name__ == "__main__":
168
+ run_ingestion()
main.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from contextlib import asynccontextmanager
4
+ from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks
5
+ from pydantic import BaseModel
6
+ from langchain_core.messages import HumanMessage, AIMessage
7
+
8
+ from retriever import load_indexes, reload_indexes, hybrid_retrieve
9
+ from agent import run_rag_agent
10
+ from ingestion import run_ingestion
11
+ from config import DOCS_DIR, TOP_K, MAX_HISTORY_TURNS
12
+
13
+ sessions: dict = {}
14
+
15
+
16
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Try to load existing indexes at startup; start without them otherwise.

    ``load_indexes`` opens the FAISS file first, and FAISS reports an
    unreadable or missing file as ``RuntimeError`` rather than
    ``FileNotFoundError``. Both are caught, so a first start with no
    indexes on disk still brings the API up.
    """
    try:
        load_indexes()
    except (FileNotFoundError, RuntimeError) as exc:
        print("WARNING: No indexes found. Upload documents first.")
        print(f"Index load error: {exc}")
    yield
23
+
24
+
25
+ app = FastAPI(title="Corrective RAG API", version="1.0", lifespan=lifespan)
26
+
27
+
28
class QueryRequest(BaseModel):
    """Request body of ``POST /query``."""

    # The question to answer.
    question: str
    # Key under which the conversation history is kept.
    session_id: str = "default"
    # Number of chunks to retrieve; defaults to the configured TOP_K.
    top_k: int = TOP_K
32
+
33
+
34
class QueryResponse(BaseModel):
    """Response body of ``POST /query``."""

    # Final answer text produced by the agent.
    answer: str
    # Retrieved chunks, each a dict with "chunk" and "source".
    sources: list
    # How many retries the agent used.
    retries_used: int
    # Verdict of the last validation pass.
    validation: str
    # Echo of the session the request belonged to.
    session_id: str
40
+
41
+
42
@app.post("/query", response_model=QueryResponse)
async def query(req: QueryRequest):
    """Answer a question from the indexed documents.

    Retrieves chunks, runs the corrective agent with the session's stored
    history, appends the new turn and trims the history to the last
    ``MAX_HISTORY_TURNS`` turns.

    Raises:
        HTTPException: 503 when no indexes are loaded, 400 when ``top_k`` is
            below 1, 404 when retrieval returns nothing.
    """
    import asyncio  # local import: only this handler needs it

    if not indexes_loaded():
        raise HTTPException(
            status_code=503,
            detail="Indexes not ready. Upload and index documents first."
        )
    if req.top_k < 1:
        raise HTTPException(status_code=400, detail="top_k must be at least 1.")

    # Retrieval and the agent (LLM calls) are blocking; run them in a worker
    # thread so other requests such as /health are still served meanwhile.
    results = await asyncio.to_thread(hybrid_retrieve, req.question, top_k=req.top_k)
    if not results:
        raise HTTPException(status_code=404, detail="No relevant chunks found.")

    # Work on a copy so the stored list is replaced only once the turn is complete.
    history = list(sessions.get(req.session_id, []))
    answer, retries, verdict = await asyncio.to_thread(
        run_rag_agent, req.question, results, history
    )

    history.append(HumanMessage(content=req.question))
    history.append(AIMessage(content=answer))
    sessions[req.session_id] = history[-(MAX_HISTORY_TURNS * 2):]

    return QueryResponse(
        answer=answer,
        # Chunks are shortened to 300 characters in the response.
        sources=[{"chunk": r["chunk"][:300], "source": r["source"]} for r in results],
        retries_used=retries,
        validation=verdict,
        session_id=req.session_id,
    )
68
+
69
+
70
@app.post("/upload")
async def upload(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
    """Store an uploaded .txt/.pdf file in ``DOCS_DIR`` and schedule re-indexing.

    Raises:
        HTTPException: 400 when the file name does not end in .txt or .pdf.
    """
    allowed = {".txt", ".pdf"}

    # The file name is chosen by the client. Keep only its last path
    # component so names like "../../x.txt" or an absolute path cannot
    # place the file outside DOCS_DIR.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    ext = os.path.splitext(safe_name)[1].lower()
    if ext not in allowed:
        raise HTTPException(status_code=400, detail="Only .txt and .pdf files allowed.")

    os.makedirs(DOCS_DIR, exist_ok=True)
    dest = os.path.join(DOCS_DIR, safe_name)
    with open(dest, "wb") as f:
        shutil.copyfileobj(file.file, f)

    background_tasks.add_task(_reindex)
    return {"status": "uploaded", "filename": safe_name,
            "message": "Indexing started in background."}
85
+
86
+
87
import threading

# Every upload schedules its own _reindex, and each run rewrites the same
# index files. The lock makes overlapping runs execute one after another.
_reindex_lock = threading.Lock()


def _reindex():
    """Rebuild all indexes from ``DOCS_DIR`` and load them into the retriever.

    Failures are printed rather than raised, because this runs as a
    background task after the upload response has been sent.
    """
    with _reindex_lock:
        try:
            run_ingestion()
            reload_indexes()
            print("Re-indexing complete.")
        except Exception as e:
            print(f"Re-indexing failed: {e}")
94
+
95
+
96
def indexes_loaded():
    """Report whether the retriever module currently has indexes loaded."""
    import retriever

    return retriever.indexes_loaded()
99
+
100
+
101
@app.delete("/session/{session_id}")
def clear_session(session_id: str):
    """Drop the stored history of one session; unknown ids are ignored."""
    sessions.pop(session_id, None)
    reply = {"status": "cleared"}
    reply["session_id"] = session_id
    return reply
105
+
106
+
107
@app.get("/health")
def health():
    """Return a fixed "ok" status together with the index-loaded flag."""
    loaded = indexes_loaded()
    return {"status": "ok", "indexes_loaded": loaded}
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## requirements.txt
3
+
4
+ langchain==0.3.25
5
+ langchain-groq==0.3.2
6
+ langgraph==0.3.29
7
+ sentence-transformers==3.4.1
8
+ faiss-cpu==1.9.0
9
+ rank-bm25==0.2.2
10
+ fastapi==0.115.12
11
+ uvicorn==0.34.0
12
+ streamlit==1.44.1
13
+ pdfplumber==0.11.6
14
+ python-dotenv==1.1.0
15
+ numpy==1.26.4
16
+ requests==2.32.3
17
+ pydantic==2.11.1
18
+ python-multipart
retriever.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import numpy as np
3
+ import faiss
4
+ from sentence_transformers import SentenceTransformer
5
+ from config import (
6
+ FAISS_INDEX_PATH, BM25_PATH, CHUNKS_PATH,
7
+ SOURCES_PATH, EMBEDDER_PATH
8
+ )
9
+
10
+ _faiss_index = None
11
+ _bm25_index = None
12
+ _chunks = None
13
+ _sources = None
14
+ _model = None
15
+
16
+
17
def indexes_loaded() -> bool:
    """Tell whether a FAISS index is currently held in this module."""
    if _faiss_index is None:
        return False
    return True
19
+
20
+
21
def load_indexes():
    """Load all retrieval artifacts from disk into the module globals.

    Everything is read into local variables first and published in a single
    assignment, so a failure half-way leaves the previous state untouched
    instead of a partly loaded one.

    Raises:
        FileNotFoundError: when any artifact path does not exist. Checked up
            front so callers get this type rather than a FAISS RuntimeError.
    """
    import os  # local import: the file's import block does not provide it

    global _faiss_index, _bm25_index, _chunks, _sources, _model

    required = (FAISS_INDEX_PATH, BM25_PATH, CHUNKS_PATH, SOURCES_PATH, EMBEDDER_PATH)
    missing = [p for p in required if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError("Index artifacts not found: " + ", ".join(missing))

    faiss_index = faiss.read_index(FAISS_INDEX_PATH)
    # NOTE(review): pickle.load executes code embedded in the file. These
    # files are written by the ingestion step; never point the paths at
    # data from an untrusted source.
    with open(BM25_PATH, "rb") as f:
        bm25_index = pickle.load(f)
    with open(CHUNKS_PATH, "rb") as f:
        chunks = pickle.load(f)
    with open(SOURCES_PATH, "rb") as f:
        sources = pickle.load(f)
    model = SentenceTransformer(EMBEDDER_PATH)

    _faiss_index, _bm25_index, _chunks, _sources, _model = (
        faiss_index, bm25_index, chunks, sources, model
    )
    print(f"Indexes loaded: {_faiss_index.ntotal} vectors, {len(_chunks)} chunks")
29
+
30
+
31
def reload_indexes():
    """Reload all indexes from disk, keeping the current ones if that fails.

    The previous version cleared the globals before loading, so a failed
    reload left the service without any index. Here the old objects are
    restored on error and the exception is re-raised for the caller.
    """
    global _faiss_index, _bm25_index, _chunks, _sources, _model
    previous = (_faiss_index, _bm25_index, _chunks, _sources, _model)
    try:
        load_indexes()
    except Exception:
        _faiss_index, _bm25_index, _chunks, _sources, _model = previous
        raise
35
+
36
+
37
+ def _reciprocal_rank_fusion(lists: list, k: int = 60) -> list:
38
+ scores: dict = {}
39
+ for ranked_list in lists:
40
+ for rank, doc_id in enumerate(ranked_list):
41
+ scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank + 1)
42
+ return sorted(scores.keys(), key=lambda x: scores[x], reverse=True)
43
+
44
+
45
def hybrid_retrieve(query: str, top_k: int = 5) -> list:
    """Retrieve chunks by fusing dense (FAISS) and sparse (BM25) rankings.

    Each retriever contributes up to ``top_k * 3`` candidates; the two
    rankings are merged with reciprocal rank fusion and cut to ``top_k``.

    Args:
        query: The search text.
        top_k: Number of chunks to return; values below 1 yield ``[]``.

    Returns:
        A list of dicts with ``chunk``, ``source`` and ``chunk_id`` keys.

    Raises:
        RuntimeError: when no indexes are loaded.
    """
    if not indexes_loaded():
        raise RuntimeError("Indexes not loaded. Call load_indexes() first.")
    if top_k < 1:
        return []

    pool_size = top_k * 3

    q_emb = _model.encode([query], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(q_emb)
    _, dense_ids = _faiss_index.search(q_emb, pool_size)
    # Negative ids (returned when fewer hits exist than requested) are dropped.
    dense_ranking = [int(i) for i in dense_ids[0] if i >= 0]

    bm25_scores = _bm25_index.get_scores(query.lower().split())
    best_first = np.argsort(bm25_scores)[::-1][:pool_size]
    # Chunks with a BM25 score of 0 earn no fusion credit: they are not
    # keyword hits, and ranking them would reward their position in the list.
    sparse_ranking = [int(i) for i in best_first if bm25_scores[i] > 0]

    merged = _reciprocal_rank_fusion([dense_ranking, sparse_ranking])[:top_k]

    return [
        {"chunk": _chunks[i], "source": _sources[i], "chunk_id": i}
        for i in merged
    ]
start.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ python ingestion.py
4
+ uvicorn main:app --host 0.0.0.0 --port 10000
test_sources.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
import retriever

# Read ``_sources`` through the module after loading. Importing the name
# directly binds the value it had at import time (None), which
# load_indexes() cannot update, so set() would fail on None.
retriever.load_indexes()
print(set(retriever._sources))
verify.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # verify.py — tests each component individually
2
+ import sys
3
+
4
def check(label, fn):
    """Run one verification step, print PASS or FAIL, and exit(1) on failure."""
    failure = None
    try:
        fn()
        print(f" PASS {label}")
    except Exception as err:
        failure = err

    if failure is not None:
        print(f" FAIL {label}: {failure}")
        sys.exit(1)
11
+
12
+ print("\n=== Corrective RAG — environment check ===\n")
13
+
14
+ # 1. Config / API key
15
def test_config():
    """Import the config module and check the Groq key is longer than 10 characters."""
    import config

    key = config.GROQ_API_KEY
    assert len(key) > 10, "GROQ_API_KEY looks invalid"
18
+ check("Config + GROQ key loaded", test_config)
19
+
20
+ # 2. Groq connection
21
def test_groq():
    """Send a one-line prompt to Groq and require a reply.

    Any non-empty reply passes; containing "ok" is sufficient but not required.
    """
    from langchain_groq import ChatGroq
    from langchain_core.messages import HumanMessage
    import config

    client = ChatGroq(
        model=config.GROQ_MODEL,
        temperature=0,
        api_key=config.GROQ_API_KEY,
    )
    reply = client.invoke([HumanMessage(content="Say OK")]).content
    says_ok = "ok" in reply.lower()
    assert says_ok or len(reply) > 0
28
+ check("Groq API connection", test_groq)
29
+
30
+ # 3. Ingestion
31
def test_ingestion():
    """Run the ingestion pipeline with a small temporary document in ./docs.

    The temporary file is removed even when ingestion raises, so a failed
    check does not leave it behind to be indexed later.
    """
    from pathlib import Path
    from ingestion import run_ingestion

    docs_dir = Path("./docs")
    docs_dir.mkdir(exist_ok=True)
    test_file = docs_dir / "_verify_test.txt"
    # Written as UTF-8 because the loader reads .txt files as UTF-8.
    test_file.write_text(
        "The Eiffel Tower is in Paris, France. "
        "It was built in 1889 for the World's Fair. "
        "It is 330 metres tall.",
        encoding="utf-8",
    )
    try:
        run_ingestion()
    finally:
        test_file.unlink()
44
+ check("Ingestion pipeline", test_ingestion)
45
+
46
+ # 4. Retriever
47
def test_retriever():
    """Load the indexes and check a query returns hits with chunk and source."""
    import retriever

    retriever.load_indexes()
    hits = retriever.hybrid_retrieve("Where is the Eiffel Tower?", top_k=3)
    assert len(hits) > 0
    first_hit = hits[0]
    assert "chunk" in first_hit
    assert "source" in first_hit
54
+ check("Hybrid retrieval (BM25 + FAISS)", test_retriever)
55
+
56
+ # 5. Agent
57
def test_agent():
    """Run retrieval plus the agent for one question and print the outcome."""
    from retriever import hybrid_retrieve
    from agent import run_rag_agent

    prompt = "How tall is the Eiffel Tower?"
    results = hybrid_retrieve(prompt, top_k=3)
    answer, retries, verdict = run_rag_agent(prompt, results)

    assert len(answer) > 10, f"Answer too short: {answer}"
    print(f"\n Answer: {answer[:120]}")
    print(f" Retries: {retries}")
    print(f" Verdict: {verdict}")
68
+ check("LangGraph agent (generate + validate)", test_agent)
69
+
70
+ print("\n=== All checks passed — ready to run ===\n")