| """ |
| rag_chain.py — Core RAG chain using LangChain + NVIDIA. |
| |
| KEY FEATURES: |
| - Per-book candidate retrieval, pooled and reranked to a shared top 5 (no per-book quota is enforced) |
| - Hybrid BM25 + vector search with NVIDIA reranking |
| - Semantic cache for repeated/similar questions |
| - Multi-turn conversation memory (per-session message lists kept in an in-process dict) |
| |
| Public API: |
| query_sacred_texts(question, session_id) -> Generator[str, None, None] |
| clear_session(session_id) |
| """ |
|
|
| import os |
| import json |
| from dotenv import load_dotenv |
| from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA, NVIDIARerank |
| from langchain_chroma import Chroma |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
| from langchain_core.output_parsers import StrOutputParser |
| from langchain_core.messages import HumanMessage, AIMessage |
| from langchain_community.retrievers import BM25Retriever |
| from langchain_classic.retrievers import EnsembleRetriever, ContextualCompressionRetriever |
| from langchain_core.documents import Document |
|
|
# Load variables from a .env file (if any) before the lookups below read
# the process environment.
load_dotenv()

# --- Settings read from the environment -------------------------------------
NVIDIA_API_KEY = os.environ.get("NVIDIA_API_KEY")  # None when the variable is unset
CHROMA_DB_PATH = os.environ.get("CHROMA_DB_PATH", "./chroma_db")
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "sacred_texts")
CHUNKS_PER_BOOK = int(os.environ.get("CHUNKS_PER_BOOK", "3"))
# Number of question/answer exchanges retained per session.
MAX_HISTORY_TURNS = int(os.environ.get("MAX_HISTORY_TURNS", "6"))

# Chroma collection that stores cached question -> response pairs.
CACHE_COLLECTION = "semantic_cache"

# Values of the "book" metadata field that retrieval filters on.
KNOWN_BOOKS = ["Bhagavad Gita", "Quran", "Bible", "Guru Granth Sahib"]

# session_id -> chronological list of HumanMessage / AIMessage objects.
# Lives only in this process; nothing here persists it.
_session_store: dict[str, list] = {}
|
|
|
|
def get_history(session_id: str) -> list:
    """Return the messages recorded for *session_id*.

    A known session yields the stored list object itself (not a copy). An
    unknown session yields a new empty list and is not added to the store.
    """
    try:
        return _session_store[session_id]
    except KeyError:
        return []
|
|
|
|
def append_turn(session_id: str, human_msg: str, ai_msg: str) -> None:
    """Record one question/answer exchange and trim the session's history.

    Args:
        session_id: Key of the conversation in ``_session_store``; the entry
            is created on first use.
        human_msg: The user's question, stored as a ``HumanMessage``.
        ai_msg: The model's reply, stored as an ``AIMessage``.

    At most ``MAX_HISTORY_TURNS`` exchanges (two messages each) are kept;
    the oldest messages are dropped first.
    """
    history = _session_store.setdefault(session_id, [])
    history.append(HumanMessage(content=human_msg))
    history.append(AIMessage(content=ai_msg))

    # Clamp at zero so a zero or negative setting means "keep nothing".
    max_messages = max(MAX_HISTORY_TURNS, 0) * 2
    overflow = len(history) - max_messages
    if overflow > 0:
        # Drop from the front using a positive offset. Slicing with
        # ``history[-max_messages:]`` returns the whole list when
        # max_messages is 0, so the history would never be trimmed.
        _session_store[session_id] = history[overflow:]
|
|
|
|
def clear_session(session_id: str):
    """Forget every stored message for *session_id*; unknown ids are ignored."""
    if session_id in _session_store:
        del _session_store[session_id]
|
|
|
|
def list_sessions() -> list[str]:
    """Return the ids of all sessions currently present in the in-memory store."""
    return [session_id for session_id in _session_store]
|
|
|
|
| |
|
|
# System-message template for the chat prompt. ``{context}`` is the only
# placeholder in this text; it is filled with the formatted retrieval
# passages when the chain is invoked.
# NOTE(review): several characters below (the dash in rule 6, the range in
# "2β4 paragraphs", the symbol before "Sources", and the separator rules
# around {context}) look mis-encoded. This text is used as the system
# message, so confirm the intended characters against the original file
# before changing them.
SYSTEM_PROMPT = """You are a scholarly and compassionate guide to sacred scriptures.
You have deep knowledge of the Bhagavad Gita, the Quran, the Bible, and the Guru Granth Sahib.

STRICT RULES you must ALWAYS follow:
1. Answer ONLY using the provided context passages. Do NOT use any external knowledge.
2. If a specific book's passages are provided but not relevant to the question, skip that book.
3. If NONE of the context is relevant, say: "The provided texts do not directly address this question."
4. Always explicitly name and cite which book(s) your answer draws from in the text of your answer.
5. When the question asks to COMPARE books (e.g. "what do Quran and Gita say"), you MUST
address EACH of those books separately, then synthesise the common thread.
6. Be respectful and neutral toward all faiths β treat each text with equal reverence.
7. Do NOT speculate, invent verses, or add information beyond the context.
8. You have access to the conversation history. Use it to:
- Understand follow-up questions (e.g. "elaborate on the second point", "what about the Bible?")
- Maintain continuity across turns without repeating yourself unnecessarily
- Resolve pronouns and references ("it", "that teaching", "the verse you mentioned") from history

FORMAT your response as:
- A clear, thoughtful answer (2β4 paragraphs)
- A "π Sources" section listing each book referenced with the key insight drawn from it

Context passages from the sacred texts (guaranteed passages from each book):
ββββββββββββββββββββββββββββββββββββββββ
{context}
ββββββββββββββββββββββββββββββββββββββββ
"""
|
|
|
|
| |
|
|
def get_embeddings():
    """Construct the NVIDIA embedding client configured for this module."""
    embedding_options = {
        "model": "nvidia/llama-nemotron-embed-vl-1b-v2",
        "api_key": NVIDIA_API_KEY,
        "truncate": "NONE",
    }
    return NVIDIAEmbeddings(**embedding_options)
|
|
|
|
def get_vector_store(embeddings):
    """Open the Chroma collection at ``CHROMA_DB_PATH`` using *embeddings*.

    Args:
        embeddings: Embedding function handed to Chroma.

    Returns:
        A ``Chroma`` store bound to ``COLLECTION_NAME``.
    """
    store_options = {
        "persist_directory": CHROMA_DB_PATH,
        "embedding_function": embeddings,
        "collection_name": COLLECTION_NAME,
    }
    return Chroma(**store_options)
|
|
|
|
| |
|
|
def retrieve_per_book(question: str, vector_store: "Chroma") -> list:
    """Collect hybrid-search candidates per book, then rerank the pooled set.

    The question is first routed: if it mentions a book (or a closely related
    term such as "Islam" or "Jesus"), only the mentioned books are searched;
    otherwise every entry in ``KNOWN_BOOKS`` is searched. For each selected
    book, a BM25 retriever and a vector retriever (both limited to that book,
    10 candidates each) are combined with equal weights. All candidates are
    then pooled and passed to the NVIDIA reranker, which keeps the top 5
    overall; no per-book quota is applied after reranking.

    Args:
        question: The (possibly history-augmented) retrieval query.
        vector_store: Chroma store whose documents carry a ``book`` metadata
            field.

    Returns:
        The reranked documents, or ``[]`` when no book produced candidates.
        A failure while searching one book is logged and that book is
        skipped; a reranker failure propagates to the caller.
    """
    import re  # local import: the only place in this module that needs it

    # Keywords that pin a query to a specific book. Insertion order fixes the
    # order in which books are searched.
    routing_keywords = {
        "Bhagavad Gita": ("gita", "bhagavad", "hindu", "hinduism"),
        "Quran": ("quran", "koran", "islam", "muslim", "muhammad"),
        "Bible": ("bible", "testament", "christian", "jesus", "christ"),
        "Guru Granth Sahib": ("granth", "guru", "sikh", "sikhism", "nanak"),
    }

    question_lower = question.lower()
    target_books = []
    for book, keywords in routing_keywords.items():
        # Match keywords only at the START of a word. A plain substring test
        # misroutes queries ("digital" contains "gita"), while a prefix match
        # still accepts inflections such as "muslims" or "christianity".
        pattern = r"\b(?:" + "|".join(re.escape(kw) for kw in keywords) + ")"
        if re.search(pattern, question_lower):
            target_books.append(book)

    books_to_search = target_books if target_books else KNOWN_BOOKS
    print(f"π― Routing query to: {books_to_search}")

    CANDIDATE_COUNT = 10  # per retriever, per book
    all_candidates = []

    for book in books_to_search:
        try:
            # BM25 needs the raw texts, so load every chunk of this book.
            book_data = vector_store.get(where={"book": book})
            book_docs = [
                Document(page_content=d, metadata=m)
                for d, m in zip(book_data["documents"], book_data["metadatas"])
            ]
            if not book_docs:
                continue

            bm25_retriever = BM25Retriever.from_documents(book_docs)
            bm25_retriever.k = CANDIDATE_COUNT

            vector_retriever = vector_store.as_retriever(
                search_kwargs={"k": CANDIDATE_COUNT, "filter": {"book": book}}
            )

            ensemble = EnsembleRetriever(
                retrievers=[bm25_retriever, vector_retriever],
                weights=[0.5, 0.5],
            )

            book_candidates = ensemble.invoke(question)
            all_candidates.extend(book_candidates)
            print(f" π¦ {book}: {len(book_candidates)} candidates")

        except Exception as e:
            # Best effort: one failing book must not block the others.
            print(f" β {book}: retrieval error β {e}")

    if not all_candidates:
        return []

    print(f"π Reranking {len(all_candidates)} total candidates...")
    reranker = NVIDIARerank(
        model="nvidia/llama-3.2-nv-rerankqa-1b-v2",
        api_key=NVIDIA_API_KEY,
        top_n=5,
    )
    final_docs = reranker.compress_documents(all_candidates, question)

    for i, doc in enumerate(final_docs):
        score = doc.metadata.get("relevance_score", "N/A")
        # .get() so a document without "book" metadata cannot crash logging.
        print(f"Rank {i+1} [{doc.metadata.get('book', 'Unknown')}]: Score {score}")

    return final_docs
|
|
|
|
| |
|
|
def format_docs(docs: list) -> str:
    """Render retrieved chunks as text, one headed section per book.

    Documents are grouped by their ``book`` metadata ("Unknown" when absent),
    in order of first appearance. Inside a section each chunk is numbered from
    1 and labelled with a citation: ``Ang <n>`` if ``ang`` is set, else
    ``<chapter>:<verse>`` if both are set, else ``Page <page>`` (``?`` when
    no page is recorded).

    Args:
        docs: Objects exposing ``metadata`` (dict) and ``page_content`` (str).

    Returns:
        The sections joined by blank lines; ``""`` for an empty input.
    """

    def cite(meta: dict) -> str:
        # Same precedence as the stored metadata implies: ang > chapter:verse > page.
        if meta.get("ang"):
            return f"Ang {meta['ang']}"
        chapter, verse = meta.get("chapter"), meta.get("verse")
        if chapter and verse:
            return f"{chapter}:{verse}"
        return f"Page {meta.get('page', '?')}"

    grouped: dict[str, list] = {}
    for doc in docs:
        grouped.setdefault(doc.metadata.get("book", "Unknown"), []).append(doc)

    rendered = []
    for title, members in grouped.items():
        body = "\n\n".join(
            f" [{n}] ({cite(member.metadata)}): {member.page_content.strip()}"
            for n, member in enumerate(members, start=1)
        )
        rendered.append(f"βββ {title} βββ" + "\n" + body)

    return "\n\n".join(rendered)
|
|
|
|
| |
|
|
def build_chain():
    """Assemble the prompt -> LLM -> string-parser pipeline and open the vector store.

    Returns:
        A ``(llm_chain, vector_store)`` tuple. The chain expects the inputs
        ``context``, ``history`` and ``question``.
    """
    vector_store = get_vector_store(get_embeddings())

    chat_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", SYSTEM_PROMPT),
            # Prior turns are spliced in between the system and user messages.
            MessagesPlaceholder(variable_name="history"),
            ("human", "{question}"),
        ]
    )

    chat_model = ChatNVIDIA(
        model="meta/llama-3.3-70b-instruct",
        api_key=NVIDIA_API_KEY,
        temperature=0.2,
        top_p=0.7,
        max_output_tokens=2048,
    )

    return chat_prompt | chat_model | StrOutputParser(), vector_store
|
|
|
|
| |
|
|
# Lazily-built singletons: both stay None until the first call to
# query_sacred_texts() populates them via build_chain().
_llm_chain = _vector_store = None
|
|
|
|
| |
|
|
def query_sacred_texts(question: str, session_id: str = "default"):
    """Stream an answer grounded in the sacred texts, with per-session memory.

    Flow:
      1. Build the chain and vector store on the first call.
      2. For the first question of a session, try the semantic cache; a hit is
         returned as a single event and recorded in the session history.
      3. Otherwise retrieve passages (short follow-ups are prefixed with the
         previous user question to give retrieval more context), stream the
         model's answer, store the turn, and cache first-turn answers.

    Args:
        question: The user's question.
        session_id: Conversation key for history lookup and storage.

    Yields:
        Newline-terminated JSON strings of the form:
            {"type": "token", "data": "<chunk>"}
            {"type": "sources", "data": [{"book", "page", "snippet"}, ...]}
            {"type": "cache", "data": {"answer": "...", "sources": [...]}}
        A ``sources`` event follows the tokens of a freshly generated answer.
        When retrieval finds nothing, a single ``token`` event with a notice
        is emitted and nothing is stored.
    """
    global _llm_chain, _vector_store

    if _llm_chain is None:
        print("π§ Initialising RAG chain (first call)...")
        _llm_chain, _vector_store = build_chain()

    history = get_history(session_id)
    is_followup = len(history) > 0

    # The cache is keyed on the question alone, so it is only consulted (and
    # later written) when there is no conversation context to change the answer.
    if not is_followup:
        cached = _lookup_cached_response(question)
        if cached is not None:
            append_turn(session_id, question, cached["answer"])
            yield json.dumps({"type": "cache", "data": cached}) + "\n"
            return

    retrieval_query = question
    if is_followup and len(question.split()) < 8:
        # Short follow-ups ("what about the Bible?") rarely carry enough
        # signal on their own; prepend the previous user question.
        last_human = next(
            (m.content for m in reversed(history) if isinstance(m, HumanMessage)), ""
        )
        retrieval_query = f"{last_human} {question}".strip()
        print(f"π Follow-up detected β augmented retrieval query: '{retrieval_query}'")

    print(f"\nπ Retrieving chunks for: '{retrieval_query}'")
    source_docs = retrieve_per_book(retrieval_query, _vector_store)

    if not source_docs:
        yield json.dumps({"type": "token", "data": "No content found in the knowledge base."}) + "\n"
        return

    sources = _summarise_sources(source_docs)
    context = format_docs(source_docs)

    # Collect chunks in a list and join once instead of repeated str +=.
    answer_parts = []
    for chunk in _llm_chain.stream({
        "context": context,
        "question": question,
        "history": history,
    }):
        answer_parts.append(chunk)
        yield json.dumps({"type": "token", "data": chunk}) + "\n"
    full_answer = "".join(answer_parts)

    append_turn(session_id, question, full_answer)
    print(f"πΎ Session '{session_id}': {len(get_history(session_id)) // 2} turn(s) stored")

    # Never cache an empty answer: it would be replayed to later callers.
    if not is_followup and full_answer.strip():
        # Cache exactly the sources that are streamed below, so a cache hit
        # reproduces what the first caller saw. (Previously a separately
        # filtered list was cached while the unfiltered one was streamed.)
        result_to_cache = {"answer": full_answer, "sources": sources}
        try:
            cache_coll = _vector_store._client.get_or_create_collection(CACHE_COLLECTION)
            cache_coll.add(
                documents=[question],
                metadatas=[{"response_json": json.dumps(result_to_cache)}],
                ids=[question],
            )
        except Exception as e:
            # Caching is best effort; the answer has already been delivered.
            print(f"β οΈ Cache write failed: {e}")

    yield json.dumps({"type": "sources", "data": sources}) + "\n"


def _lookup_cached_response(question: str):
    """Return the cached response dict for a close-enough question, else None.

    Any failure (cache backend error, malformed entry) is logged and treated
    as a miss so the caller falls back to the normal retrieval path.
    """
    THRESHOLD = 0.35  # maximum distance still counted as "the same question"
    try:
        cache_coll = _vector_store._client.get_or_create_collection(CACHE_COLLECTION)
        cache_results = cache_coll.query(query_texts=[question], n_results=1)
        if not (cache_results["ids"] and cache_results["ids"][0]):
            return None
        distance = cache_results["distances"][0][0]
        if distance >= THRESHOLD:
            return None
        cached = json.loads(cache_results["metadatas"][0][0]["response_json"])
        if not isinstance(cached, dict) or "answer" not in cached:
            return None
    except Exception as e:
        print(f"Cache read failed, falling back to retrieval: {e}")
        return None

    print(f"β‘οΈ Semantic Cache Hit! (Distance: {distance:.4f})")
    return cached


def _summarise_sources(source_docs: list) -> list:
    """Build the de-duplicated source list (book + citation + snippet) for the client."""
    seen_sources: set[str] = set()
    sources = []
    for doc in source_docs:
        book = doc.metadata.get("book", "Unknown")
        ang = doc.metadata.get("ang")
        ch = doc.metadata.get("chapter")
        vs = doc.metadata.get("verse")
        if ang:
            cite_val = f"Ang {ang}"
        elif ch and vs:
            cite_val = f"{ch}:{vs}"
        else:
            cite_val = f"p. {doc.metadata.get('page', '?')}"
        display_name = f"{book} {cite_val}"
        if display_name in seen_sources:
            continue
        seen_sources.add(display_name)
        sources.append({
            "book": display_name,
            "page": cite_val,
            "snippet": doc.page_content[:200].strip() + "...",
        })
    return sources
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # Manual smoke test: ask one question and print the answer as it streams.
    test_q = "What do the scriptures say about forgiveness?"
    print(f"\nπ Test query: {test_q}\n")
    for line in query_sacred_texts(test_q, session_id="cli-test"):
        obj = json.loads(line)
        if obj["type"] == "token":
            print(obj["data"], end="", flush=True)
        elif obj["type"] == "cache":
            # A semantic-cache hit delivers the whole answer in one event and
            # emits no tokens; without this branch such a run prints nothing.
            print(obj["data"]["answer"], end="", flush=True)
    print("\n")