# LifeGuide / rag_chain.py
# Shouvik599
# Added the multi turn conversation feature
# 56da115
"""
rag_chain.py — Core RAG chain using LangChain + NVIDIA.
KEY FEATURES:
- Per-book retrieval (guaranteed slots per scripture)
- Hybrid BM25 + vector search with NVIDIA reranking
- Semantic cache for repeated/similar questions
- Multi-turn conversation memory (in-process, per-session message history trimmed to the last MAX_HISTORY_TURNS exchanges)
Public API:
query_sacred_texts(question, session_id) -> Generator[str, None, None]
clear_session(session_id)
"""
import json
import os
import re

from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_classic.retrievers import EnsembleRetriever, ContextualCompressionRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA, NVIDIARerank
# Pull settings from a local .env file (if one exists) into the process
# environment. This must run before the os.getenv() calls below.
load_dotenv()
# Key passed to every NVIDIA client built in this module (embeddings, chat
# model, reranker). NOTE(review): None when unset; nothing validates it here.
NVIDIA_API_KEY = os.getenv("NVIDIA_API_KEY")
# Directory of the persistent Chroma database opened by get_vector_store().
CHROMA_DB_PATH = os.getenv("CHROMA_DB_PATH", "./chroma_db")
# Chroma collection that holds the scripture chunks.
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "sacred_texts")
# Intended per-scripture passage quota (see retrieve_per_book's docstring).
# int() raises ValueError at import time if the env value is not an integer.
CHUNKS_PER_BOOK = int(os.getenv("CHUNKS_PER_BOOK", "3"))
# Chroma collection, opened on the same client as the scripture store, that
# serves as the semantic answer cache. Unlike the settings above, it is not
# read from the environment.
CACHE_COLLECTION = "semantic_cache"
# How many recent question/answer exchanges append_turn() keeps per session.
MAX_HISTORY_TURNS = int(os.getenv("MAX_HISTORY_TURNS", "6")) # last N human+AI pairs kept
# Scriptures searched when a question does not name a specific one. Each name
# must equal the "book" metadata value stored on that scripture's chunks,
# because retrieval filters on it.
KNOWN_BOOKS = [
    "Bhagavad Gita",
    "Quran",
    "Bible",
    "Guru Granth Sahib",
]
# ─── In-memory session store ──────────────────────────────────────────────────
# { session_id: [HumanMessage | AIMessage, ...] }  (oldest message first)
# Process-local: history is lost on restart and is not shared between processes.
_session_store: dict[str, list] = {}


def get_history(session_id: str) -> list:
    """Return a snapshot of the conversation history for *session_id*.

    The result is a shallow copy, oldest message first. Callers therefore can
    neither corrupt the stored history by mutating it nor see it change
    underneath them when append_turn() runs later. An unknown session yields
    an empty list.
    """
    return list(_session_store.get(session_id, ()))
def append_turn(session_id: str, human_msg: str, ai_msg: str):
    """Record one question/answer exchange and trim the session's history.

    Appends a HumanMessage and an AIMessage to the session (creating the
    session on first use), then keeps only the most recent MAX_HISTORY_TURNS
    exchanges. A MAX_HISTORY_TURNS of 0 or less keeps no history at all.
    """
    history = _session_store.setdefault(session_id, [])
    history.append(HumanMessage(content=human_msg))
    history.append(AIMessage(content=ai_msg))
    # Each exchange is two messages. Slice with an explicit start index: for a
    # limit of 0, ``history[-0:]`` is the WHOLE list, so the history would
    # never be trimmed.
    keep = max(MAX_HISTORY_TURNS, 0) * 2
    if len(history) > keep:
        _session_store[session_id] = history[len(history) - keep:]
def clear_session(session_id: str):
    """Forget all stored conversation turns for *session_id*.

    Clearing a session that does not exist is a silent no-op.
    """
    try:
        del _session_store[session_id]
    except KeyError:
        pass
def list_sessions() -> list[str]:
    """Return the ids of all sessions with stored history, in insertion order."""
    return [*_session_store]
# ─── System Prompt ────────────────────────────────────────────────────────────
# System message sent with every request. It goes through ChatPromptTemplate,
# so ``{context}`` is its only template variable (filled with the output of
# format_docs()). Any literal brace added here must be doubled: {{ }}.
SYSTEM_PROMPT = """You are a scholarly and compassionate guide to sacred scriptures.
You have deep knowledge of the Bhagavad Gita, the Quran, the Bible, and the Guru Granth Sahib.
STRICT RULES you must ALWAYS follow:
1. Answer ONLY using the provided context passages. Do NOT use any external knowledge.
2. If a specific book's passages are provided but not relevant to the question, skip that book.
3. If NONE of the context is relevant, say: "The provided texts do not directly address this question."
4. Always explicitly name and cite which book(s) your answer draws from in the text of your answer.
5. When the question asks to COMPARE books (e.g. "what do Quran and Gita say"), you MUST
   address EACH of those books separately, then synthesise the common thread.
6. Be respectful and neutral toward all faiths — treat each text with equal reverence.
7. Do NOT speculate, invent verses, or add information beyond the context.
8. You have access to the conversation history. Use it to:
   - Understand follow-up questions (e.g. "elaborate on the second point", "what about the Bible?")
   - Maintain continuity across turns without repeating yourself unnecessarily
   - Resolve pronouns and references ("it", "that teaching", "the verse you mentioned") from history
FORMAT your response as:
- A clear, thoughtful answer (2–4 paragraphs)
- A "📚 Sources" section listing each book referenced with the key insight drawn from it
Context passages from the sacred texts (guaranteed passages from each book):
────────────────────────────────────────
{context}
────────────────────────────────────────
"""
# ─── Embeddings & Vector Store ────────────────────────────────────────────────
def get_embeddings():
    """Return the NVIDIA embedding client used by the scripture vector store.

    ``truncate="NONE"`` turns off server-side truncation. Per the
    NVIDIAEmbeddings docs, over-long inputs then raise an error instead of
    being silently cut.
    """
    model_name = "nvidia/llama-nemotron-embed-vl-1b-v2"
    return NVIDIAEmbeddings(model=model_name, api_key=NVIDIA_API_KEY, truncate="NONE")
def get_vector_store(embeddings):
    """Open the persistent Chroma collection of scripture chunks.

    *embeddings* becomes the store's embedding function. The collection
    name and on-disk location come from COLLECTION_NAME and CHROMA_DB_PATH.
    """
    store_kwargs = dict(
        collection_name=COLLECTION_NAME,
        persist_directory=CHROMA_DB_PATH,
        embedding_function=embeddings,
    )
    return Chroma(**store_kwargs)
# ─── Per-Book Hybrid Retrieval ────────────────────────────────────────────────
# Routing keywords per scripture, in KNOWN_BOOKS order. A keyword must start
# at a word boundary, so "gita" no longer fires on "agitation". It may be
# followed by further letters, so "islam" still matches "islamic" and
# "christ" still matches "christian".
_BOOK_KEYWORDS: dict[str, tuple[str, ...]] = {
    "Bhagavad Gita": ("gita", "bhagavad", "hindu", "hinduism"),
    "Quran": ("quran", "koran", "islam", "muslim", "muhammad"),
    "Bible": ("bible", "testament", "christian", "jesus", "christ"),
    "Guru Granth Sahib": ("granth", "guru", "sikh", "sikhism", "nanak"),
}
_BOOK_PATTERNS = {
    book: re.compile(r"\b(?:" + "|".join(map(re.escape, keywords)) + r")")
    for book, keywords in _BOOK_KEYWORDS.items()
}

# Candidates fetched by EACH retriever (BM25 and vector) per book before reranking.
_CANDIDATE_COUNT = 10


def _route_books(question: str) -> list[str]:
    """Return the scriptures *question* names by keyword; [] if it names none."""
    text = question.lower()
    return [book for book, pattern in _BOOK_PATTERNS.items() if pattern.search(text)]


def _book_candidates(question: str, vector_store, book: str) -> list:
    """Hybrid BM25 + vector candidates for a single book ([] if it has no chunks).

    NOTE(review): the BM25 index is rebuilt from every chunk of the book on
    each call (``vector_store.get``). Consider caching it per book if latency
    matters.
    """
    book_data = vector_store.get(where={"book": book})
    book_docs = [
        Document(page_content=text, metadata=meta)
        for text, meta in zip(book_data["documents"], book_data["metadatas"])
    ]
    if not book_docs:
        return []

    bm25_retriever = BM25Retriever.from_documents(book_docs)
    bm25_retriever.k = _CANDIDATE_COUNT
    vector_retriever = vector_store.as_retriever(
        search_kwargs={"k": _CANDIDATE_COUNT, "filter": {"book": book}}
    )
    ensemble = EnsembleRetriever(
        retrievers=[bm25_retriever, vector_retriever],
        weights=[0.5, 0.5],
    )
    return ensemble.invoke(question)


def retrieve_per_book(question: str, vector_store: Chroma) -> list:
    """
    Retrieve context passages for *question* with a per-book quota.

    1. Route: search the books the question names by keyword, otherwise every
       book in KNOWN_BOOKS.
    2. For each routed book, collect candidates from a hybrid BM25 + vector
       ensemble restricted to that book. A failing book is logged and skipped.
    3. Rerank the pooled candidates with the NVIDIA reranker. Then keep, in
       rank order, at most CHUNKS_PER_BOOK passages per book, so every book
       that produced candidates gets its own slots.

    Returns the selected Documents (with ``relevance_score`` in their
    metadata), or [] when no candidates were found.
    """
    books_to_search = _route_books(question) or KNOWN_BOOKS
    print(f"🎯 Routing query to: {books_to_search}")

    all_candidates = []
    for book in books_to_search:
        try:
            book_candidates = _book_candidates(question, vector_store, book)
        except Exception as e:
            # Best-effort: one failing book must not sink the whole query.
            print(f" ❌ {book}: retrieval error — {e}")
            continue
        all_candidates.extend(book_candidates)
        print(f" 📦 {book}: {len(book_candidates)} candidates")

    if not all_candidates:
        return []

    print(f"🚀 Reranking {len(all_candidates)} total candidates...")
    # Rank the WHOLE pool (top_n = pool size). A small global top_n could give
    # every slot to one book and leave a book the user asked about with no
    # passages at all.
    reranker = NVIDIARerank(
        model="nvidia/llama-3.2-nv-rerankqa-1b-v2",
        api_key=NVIDIA_API_KEY,
        top_n=len(all_candidates),
    )
    ranked = reranker.compress_documents(all_candidates, question)

    quota = max(CHUNKS_PER_BOOK, 1)
    taken: dict[str, int] = {}
    final_docs = []
    for doc in ranked:
        book = doc.metadata.get("book", "Unknown")
        if taken.get(book, 0) < quota:
            taken[book] = taken.get(book, 0) + 1
            final_docs.append(doc)

    for rank, doc in enumerate(final_docs, 1):
        score = doc.metadata.get("relevance_score", "N/A")
        print(f"Rank {rank} [{doc.metadata.get('book', 'Unknown')}]: Score {score}")
    return final_docs
# ─── Format Retrieved Docs ────────────────────────────────────────────────────
def format_docs(docs: list) -> str:
    """Render retrieved passages as the ``{context}`` text of the system prompt.

    Passages are grouped under a ``═══ <book> ═══`` header (books in order of
    first appearance) and numbered from 1 within each book. Each passage is
    cited as ``Ang N`` when it has an "ang", ``chapter:verse`` when both are
    set, and otherwise ``Page N`` ("?" when no page is recorded). Returns ""
    for an empty list.
    """
    def cite(meta: dict) -> str:
        ang = meta.get("ang")
        if ang:
            return f"Ang {ang}"
        chapter, verse = meta.get("chapter"), meta.get("verse")
        if chapter and verse:
            return f"{chapter}:{verse}"
        return f"Page {meta.get('page', '?')}"

    grouped: dict[str, list] = {}
    for d in docs:
        grouped.setdefault(d.metadata.get("book", "Unknown"), []).append(d)

    rendered = []
    for book_name, members in grouped.items():
        passages = "\n\n".join(
            f" [{n}] ({cite(d.metadata)}): {d.page_content.strip()}"
            for n, d in enumerate(members, start=1)
        )
        rendered.append(f"═══ {book_name} ═══\n{passages}")
    return "\n\n".join(rendered)
# ─── Build the RAG Chain ──────────────────────────────────────────────────────
def build_chain():
    """Create the answer-generation chain and open the scripture vector store.

    Returns:
        tuple ``(llm_chain, vector_store)``:
        * ``llm_chain`` is ``prompt | llm | StrOutputParser()``. Invoke or
          stream it with ``{"context": str, "history": list of messages,
          "question": str}``; it produces plain-text output.
        * ``vector_store`` is the Chroma store returned by get_vector_store().
    """
    embeddings = get_embeddings()
    vector_store = get_vector_store(embeddings)

    llm = ChatNVIDIA(
        model="meta/llama-3.3-70b-instruct",
        api_key=NVIDIA_API_KEY,
        temperature=0.2,
        top_p=0.7,
        # ``max_tokens`` is ChatNVIDIA's output-length parameter.
        # ``max_output_tokens`` is not one of its fields, so it did not set
        # the limit.
        max_tokens=2048,
    )

    # Message order: system prompt (carrying the retrieved {context}), then the
    # prior turns of this session, then the new question.
    prompt = ChatPromptTemplate.from_messages([
        ("system", SYSTEM_PROMPT),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question}"),
    ])

    llm_chain = prompt | llm | StrOutputParser()
    return llm_chain, vector_store
# ─── Singleton init ───────────────────────────────────────────────────────────
# Both are populated together by build_chain() on the first query_sacred_texts() call.
_llm_chain = _vector_store = None
# ─── Public API ───────────────────────────────────────────────────────────────
# A cached answer is served only when the cache collection reports a distance
# below this value for the nearest stored question.
_CACHE_DISTANCE_THRESHOLD = 0.35


def _ensure_chain():
    """Build the LLM chain and vector store once, on first use."""
    global _llm_chain, _vector_store
    if _llm_chain is None:
        print("🔧 Initialising RAG chain (first call)...")
        _llm_chain, _vector_store = build_chain()


def _cache_lookup(question: str) -> "dict | None":
    """Return the cached ``{"answer", "sources"}`` payload for a near-identical question.

    Returns None on a miss. Errors while reading the cache are logged and
    treated as a miss: the cache is an optimisation and must never stop a
    question from being answered.
    """
    try:
        cache_coll = _vector_store._client.get_or_create_collection(CACHE_COLLECTION)
        results = cache_coll.query(query_texts=[question], n_results=1)
        if not (results["ids"] and results["ids"][0]):
            return None
        distance = results["distances"][0][0]
        if distance >= _CACHE_DISTANCE_THRESHOLD:
            return None
        cached = json.loads(results["metadatas"][0][0]["response_json"])
    except Exception as e:
        print(f"⚠️ Cache read failed: {e}")
        return None
    print(f"⚑️ Semantic Cache Hit! (Distance: {distance:.4f})")
    return cached


def _cache_store(question: str, payload: dict) -> None:
    """Best-effort write of a first-turn answer to the semantic cache.

    Uses ``upsert`` keyed on the question text, so re-storing the same
    question replaces its entry instead of colliding on the id.
    """
    try:
        cache_coll = _vector_store._client.get_or_create_collection(CACHE_COLLECTION)
        cache_coll.upsert(
            documents=[question],
            metadatas=[{"response_json": json.dumps(payload)}],
            ids=[question],
        )
    except Exception as e:
        print(f"⚠️ Cache write failed: {e}")


def _retrieval_query(question: str, history: list) -> str:
    """Return the text to search the scriptures with.

    Short follow-ups (fewer than 8 words, e.g. "what about the Bible?") are
    often too vague to retrieve on. They are prefixed with the previous human
    question from *history*. Anything else is used verbatim.
    """
    if not history or len(question.split()) >= 8:
        return question
    last_human = next(
        (m.content for m in reversed(history) if isinstance(m, HumanMessage)), ""
    )
    augmented = f"{last_human} {question}".strip()
    print(f"🔍 Follow-up detected — augmented retrieval query: '{augmented}'")
    return augmented


def _build_sources(docs: list) -> list[dict]:
    """Build one source entry per distinct "<book> <citation>" label, in retrieval order."""
    seen: set[str] = set()
    sources = []
    for doc in docs:
        meta = doc.metadata
        book = meta.get("book", "Unknown")
        ang = meta.get("ang")
        ch = meta.get("chapter")
        vs = meta.get("verse")
        if ang:
            cite_val = f"Ang {ang}"
        elif ch and vs:
            cite_val = f"{ch}:{vs}"
        else:
            cite_val = f"p. {meta.get('page', '?')}"
        display_name = f"{book} {cite_val}"
        if display_name in seen:
            continue
        seen.add(display_name)
        sources.append({
            "book": display_name,
            "page": cite_val,
            "snippet": doc.page_content[:200].strip() + "...",
        })
    return sources


def query_sacred_texts(question: str, session_id: str = "default"):
    """
    Stream an answer grounded in the sacred texts, maintaining per-session
    conversation history for natural follow-up questions.

    Yields newline-terminated JSON lines of the form:
        {"type": "token",   "data": "<chunk>"}   streamed answer text
        {"type": "sources", "data": [...]}       once, after the last token
        {"type": "cache",   "data": {"answer": "...", "sources": [...]}}
    A single "cache" line replaces the token/sources lines when a first-turn
    question matches a cached one. Each source is {"book", "page", "snippet"}.
    The sources sent (and cached) are those whose "<book> <citation>" label
    appears in the answer, or every retrieved source if none does.
    """
    _ensure_chain()

    history = get_history(session_id)
    is_followup = len(history) > 0

    # Semantic cache is consulted for first turns only: a follow-up's meaning
    # depends on the conversation, so its text alone is not a safe cache key.
    if not is_followup:
        cached = _cache_lookup(question)
        if cached is not None:
            append_turn(session_id, question, cached["answer"])
            yield json.dumps({"type": "cache", "data": cached}) + "\n"
            return

    retrieval_query = _retrieval_query(question, history)
    print(f"\n🔍 Retrieving chunks for: '{retrieval_query}'")
    source_docs = retrieve_per_book(retrieval_query, _vector_store)
    if not source_docs:
        yield json.dumps({"type": "token", "data": "No content found in the knowledge base."}) + "\n"
        return

    sources = _build_sources(source_docs)
    context = format_docs(source_docs)

    # The LLM sees the user's original question plus the history. The
    # (possibly augmented) retrieval query is only used for search.
    parts: list[str] = []
    for chunk in _llm_chain.stream({
        "context": context,
        "question": question,
        "history": history,
    }):
        parts.append(chunk)
        yield json.dumps({"type": "token", "data": chunk}) + "\n"
    full_answer = "".join(parts)

    # Keep the sources the answer actually cites, falling back to everything
    # retrieved when none match. The same list is streamed and cached, so
    # cache hits and fresh answers report sources consistently.
    answer_lower = full_answer.lower()
    final_sources = [s for s in sources if s["book"].lower() in answer_lower] or sources

    append_turn(session_id, question, full_answer)
    print(f"💾 Session '{session_id}': {len(get_history(session_id)) // 2} turn(s) stored")

    if not is_followup:
        _cache_store(question, {"answer": full_answer, "sources": final_sources})

    yield json.dumps({"type": "sources", "data": final_sources}) + "\n"
# ─── Quick CLI Test ───────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke test: stream one answer to stdout. The semantic cache persists in
    # Chroma, so on a repeat run the answer arrives as a single "cache" line
    # rather than as tokens; print it in that case too.
    test_q = "What do the scriptures say about forgiveness?"
    print(f"\n🔍 Test query: {test_q}\n")
    for line in query_sacred_texts(test_q, session_id="cli-test"):
        obj = json.loads(line)
        if obj["type"] == "token":
            print(obj["data"], end="", flush=True)
        elif obj["type"] == "cache":
            print(obj["data"]["answer"], end="", flush=True)
    print("\n")