File size: 15,721 Bytes
22fd41f
56da115
 
 
 
 
 
 
 
 
 
 
22fd41f
 
 
56da115
22fd41f
7ae27cd
22fd41f
56da115
22fd41f
56da115
7ae27cd
 
56da115
22fd41f
56da115
22fd41f
56da115
 
 
 
 
 
22fd41f
 
 
 
 
 
 
 
56da115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22fd41f
 
 
 
 
 
 
 
 
 
7ae27cd
22fd41f
 
 
 
56da115
 
 
 
22fd41f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56da115
7ae27cd
22fd41f
 
56da115
 
22fd41f
7ae27cd
 
56da115
 
7ae27cd
 
 
 
 
 
 
 
56da115
7ae27cd
 
56da115
 
 
7ae27cd
22fd41f
7ae27cd
56da115
 
 
 
7ae27cd
 
56da115
7ae27cd
 
56da115
 
 
 
 
 
 
 
 
 
 
7ae27cd
56da115
 
22fd41f
 
56da115
7ae27cd
 
56da115
7ae27cd
 
56da115
7ae27cd
56da115
7ae27cd
56da115
 
7ae27cd
 
 
22fd41f
7ae27cd
22fd41f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ae27cd
56da115
 
7ae27cd
 
 
 
 
 
56da115
22fd41f
 
 
 
 
 
 
 
56da115
22fd41f
 
da91d6e
 
 
22fd41f
da91d6e
 
22fd41f
 
56da115
22fd41f
 
56da115
 
22fd41f
 
 
 
 
 
56da115
22fd41f
56da115
22fd41f
 
 
56da115
22fd41f
56da115
 
 
 
22fd41f
56da115
 
 
 
22fd41f
 
 
 
 
 
 
56da115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22fd41f
 
7ae27cd
 
22fd41f
56da115
 
22fd41f
 
 
56da115
 
 
7ae27cd
 
 
 
 
 
 
22fd41f
7ae27cd
 
 
 
56da115
 
 
 
 
 
 
 
 
7ae27cd
56da115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ae27cd
22fd41f
 
 
 
 
56da115
22fd41f
56da115
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
"""
rag_chain.py — Core RAG chain using LangChain + NVIDIA.

KEY FEATURES:
- Per-book candidate retrieval (questions naming a scripture are routed to it)
- Hybrid BM25 + vector search with NVIDIA reranking of the pooled candidates
- Semantic cache for repeated/similar questions
- Multi-turn conversation memory (per-session message history kept in process memory)

Public API:
    query_sacred_texts(question, session_id) -> Generator[str, None, None]
    clear_session(session_id)
"""

import os
import json
from dotenv import load_dotenv
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA, NVIDIARerank
from langchain_chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage, AIMessage
from langchain_community.retrievers import BM25Retriever
from langchain_classic.retrievers import EnsembleRetriever, ContextualCompressionRetriever
from langchain_core.documents import Document

# Load variables from a local .env file into the environment. This must run
# before the os.getenv() calls below so they can see those values.
load_dotenv()

# Credential passed to every NVIDIA client constructed in this module
# (None when the variable is not set).
NVIDIA_API_KEY    = os.getenv("NVIDIA_API_KEY")
# Directory of the persisted Chroma database.
CHROMA_DB_PATH    = os.getenv("CHROMA_DB_PATH", "./chroma_db")
# Chroma collection holding the scripture chunks.
COLLECTION_NAME   = os.getenv("COLLECTION_NAME", "sacred_texts")
# NOTE(review): not referenced by any code visible in this file — confirm
# whether retrieval is still meant to honour it.
CHUNKS_PER_BOOK   = int(os.getenv("CHUNKS_PER_BOOK", "3"))
# Separate Chroma collection used as the question -> answer semantic cache.
CACHE_COLLECTION  = "semantic_cache"
MAX_HISTORY_TURNS = int(os.getenv("MAX_HISTORY_TURNS", "6"))   # last N human+AI pairs kept

# Values of the "book" metadata field that retrieval filters on when a
# question does not name a specific scripture.
KNOWN_BOOKS = [
    "Bhagavad Gita",
    "Quran",
    "Bible",
    "Guru Granth Sahib",
]

# ─── In-memory session store ──────────────────────────────────────────────────
# { session_id: [HumanMessage | AIMessage, ...] }
# Plain module-level dict: contents live only as long as this process does.
_session_store: dict[str, list] = {}


def get_history(session_id: str) -> list:
    """Return a snapshot of the messages stored for ``session_id``.

    The result is always a new list (empty for an unknown session), so a
    caller can hold on to it or modify it without altering the session store.
    Previously a known session returned the live internal list while an
    unknown session returned a detached one.

    Args:
        session_id: Key identifying the conversation.

    Returns:
        The stored messages in chronological order, oldest first.
    """
    return list(_session_store.get(session_id, ()))


def append_turn(session_id: str, human_msg: str, ai_msg: str) -> None:
    """Record one question/answer exchange and trim the session to its cap.

    Args:
        session_id: Key identifying the conversation; created on first use.
        human_msg: The user's question for this turn.
        ai_msg: The complete answer produced for this turn.
    """
    history = _session_store.setdefault(session_id, [])
    history.append(HumanMessage(content=human_msg))
    history.append(AIMessage(content=ai_msg))

    # Each turn is two messages. Clamp at zero so a negative setting cannot
    # turn the slice below into "drop the oldest N messages".
    max_messages = max(MAX_HISTORY_TURNS, 0) * 2
    if len(history) > max_messages:
        # history[-0:] is the whole list, so a cap of zero needs its own
        # branch; otherwise a "keep nothing" setting would keep everything.
        _session_store[session_id] = history[-max_messages:] if max_messages else []


def clear_session(session_id: str):
    """Forget the stored conversation for ``session_id``; unknown ids are ignored."""
    if session_id in _session_store:
        del _session_store[session_id]


def list_sessions() -> list[str]:
    """Return the ids of all sessions currently held in the in-memory store."""
    return [*_session_store]


# ─── System Prompt ────────────────────────────────────────────────────────────

# System message for the chat prompt. `{context}` is the only template
# variable; it is filled with the formatted retrieval results on each request.
# The punctuation, emoji and rule lines below are real Unicode characters:
# the mis-decoded byte sequences that stood in their place were being sent
# to the model verbatim.
SYSTEM_PROMPT = """You are a scholarly and compassionate guide to sacred scriptures.
You have deep knowledge of the Bhagavad Gita, the Quran, the Bible, and the Guru Granth Sahib.

STRICT RULES you must ALWAYS follow:
1. Answer ONLY using the provided context passages. Do NOT use any external knowledge.
2. If a specific book's passages are provided but not relevant to the question, skip that book.
3. If NONE of the context is relevant, say: "The provided texts do not directly address this question."
4. Always explicitly name and cite which book(s) your answer draws from in the text of your answer.
5. When the question asks to COMPARE books (e.g. "what do Quran and Gita say"), you MUST
   address EACH of those books separately, then synthesise the common thread.
6. Be respectful and neutral toward all faiths — treat each text with equal reverence.
7. Do NOT speculate, invent verses, or add information beyond the context.
8. You have access to the conversation history. Use it to:
   - Understand follow-up questions (e.g. "elaborate on the second point", "what about the Bible?")
   - Maintain continuity across turns without repeating yourself unnecessarily
   - Resolve pronouns and references ("it", "that teaching", "the verse you mentioned") from history

FORMAT your response as:
- A clear, thoughtful answer (2–4 paragraphs)
- A "📚 Sources" section listing each book referenced with the key insight drawn from it

Context passages from the sacred texts (guaranteed passages from each book):
────────────────────────────────────────
{context}
────────────────────────────────────────
"""


# ─── Embeddings & Vector Store ────────────────────────────────────────────────

def get_embeddings():
    """Create the NVIDIA embedding client used for the scripture collection.

    Returns:
        An ``NVIDIAEmbeddings`` instance authenticated with ``NVIDIA_API_KEY``.
    """
    embedder = NVIDIAEmbeddings(
        api_key=NVIDIA_API_KEY,
        model="nvidia/llama-nemotron-embed-vl-1b-v2",
        truncate="NONE",
    )
    return embedder


def get_vector_store(embeddings):
    """Open the persisted Chroma collection that holds the scripture chunks.

    Args:
        embeddings: Embedding function the collection uses for queries.

    Returns:
        A ``Chroma`` store bound to ``CHROMA_DB_PATH`` / ``COLLECTION_NAME``.
    """
    store = Chroma(
        collection_name=COLLECTION_NAME,
        embedding_function=embeddings,
        persist_directory=CHROMA_DB_PATH,
    )
    return store


# ─── Per-Book Hybrid Retrieval ────────────────────────────────────────────────

# Lower-case keywords that route a question to one specific book. Each keyword
# is matched at the START of a word, so "muslims" or "christianity" still
# route, while an accidental inner match ("digital" contains "gita") does not.
_BOOK_KEYWORDS = {
    "Bhagavad Gita": ("gita", "bhagavad", "hindu", "hinduism"),
    "Quran": ("quran", "koran", "islam", "muslim", "muhammad"),
    "Bible": ("bible", "testament", "christian", "jesus", "christ"),
    "Guru Granth Sahib": ("granth", "guru", "sikh", "sikhism", "nanak"),
}


def _route_books(question: str) -> list:
    """Return the books whose keywords appear in ``question`` (possibly none)."""
    import re  # function-scope import keeps this block self-contained

    lowered = question.lower()
    return [
        book
        for book, keywords in _BOOK_KEYWORDS.items()
        if re.search(r"\b(?:" + "|".join(map(re.escape, keywords)) + ")", lowered)
    ]


def retrieve_per_book(question: str, vector_store: Chroma) -> list:
    """
    Gather hybrid BM25+vector candidates book by book, then rerank them together.

    If the question names one or more scriptures, only those books are
    searched; otherwise every entry of KNOWN_BOOKS is. Each searched book
    contributes the output of an equally weighted BM25/vector ensemble (10
    results requested from each retriever). The pooled candidates are passed
    to the NVIDIA reranker, which keeps the top 5 overall — a book is
    therefore NOT guaranteed a slot in the final result.

    Args:
        question: Query text used for routing, retrieval and reranking.
        vector_store: Chroma store whose documents carry a "book" metadata field.

    Returns:
        The reranked documents, or an empty list when no book produced any
        candidates.

    Raises:
        Whatever the reranker raises. Failures while retrieving from a single
        book are logged and that book is skipped.
    """
    books_to_search = _route_books(question) or KNOWN_BOOKS
    print(f"🎯 Routing query to: {books_to_search}")

    CANDIDATE_COUNT = 10
    all_candidates = []

    for book in books_to_search:
        try:
            # BM25 needs the raw texts, so every chunk of the book is loaded
            # and the index rebuilt on each call.
            book_data = vector_store.get(where={"book": book})
            book_docs = [
                Document(page_content=d, metadata=m)
                for d, m in zip(book_data["documents"], book_data["metadatas"])
            ]
            if not book_docs:
                continue

            bm25_retriever = BM25Retriever.from_documents(book_docs)
            bm25_retriever.k = CANDIDATE_COUNT

            vector_retriever = vector_store.as_retriever(
                search_kwargs={"k": CANDIDATE_COUNT, "filter": {"book": book}}
            )

            ensemble = EnsembleRetriever(
                retrievers=[bm25_retriever, vector_retriever],
                weights=[0.5, 0.5],
            )

            book_candidates = ensemble.invoke(question)
            all_candidates.extend(book_candidates)
            print(f"  📦 {book}: {len(book_candidates)} candidates")

        except Exception as e:
            # Best effort: one failing book must not sink the whole query.
            print(f"  ❌  {book}: retrieval error — {e}")

    if not all_candidates:
        return []

    print(f"🚀 Reranking {len(all_candidates)} total candidates...")
    reranker = NVIDIARerank(
        model="nvidia/llama-3.2-nv-rerankqa-1b-v2",
        api_key=NVIDIA_API_KEY,
        top_n=5,
    )
    final_docs = reranker.compress_documents(all_candidates, question)

    for rank, doc in enumerate(final_docs, 1):
        score = doc.metadata.get("relevance_score", "N/A")
        # .get(): a chunk without "book" metadata must not crash a log line.
        book = doc.metadata.get("book", "Unknown")
        print(f"Rank {rank} [{book}]: Score {score}")

    return final_docs


# ─── Format Retrieved Docs ────────────────────────────────────────────────────

def format_docs(docs: list) -> str:
    """Render retrieved chunks as the context text inserted into the prompt.

    Chunks are grouped by their "book" metadata ("Unknown" when absent), with
    books in order of first appearance. Each group starts with a
    "═══ <book> ═══" header followed by its chunks, numbered from 1 and
    labelled with the most specific citation available: "Ang N", then
    "chapter:verse", then "Page N" ("Page ?" when no page is recorded).

    Args:
        docs: Objects exposing ``metadata`` (a dict) and ``page_content`` (str).

    Returns:
        The formatted context, or an empty string for an empty ``docs``.
    """
    by_book: dict[str, list] = {}
    for doc in docs:
        by_book.setdefault(doc.metadata.get("book", "Unknown"), []).append(doc)

    sections = []
    for book, book_docs in by_book.items():
        # Real box-drawing characters: the mis-decoded bytes previously used
        # here ended up verbatim in the text sent to the model.
        header = f"═══ {book} ═══"
        chunks = []
        for i, doc in enumerate(book_docs, 1):
            meta = doc.metadata
            ang = meta.get("ang")
            ch = meta.get("chapter")
            vs = meta.get("verse")
            if ang:
                citation = f"Ang {ang}"
            elif ch and vs:
                citation = f"{ch}:{vs}"
            else:
                citation = f"Page {meta.get('page', '?')}"
            chunks.append(f"  [{i}] ({citation}): {doc.page_content.strip()}")
        sections.append(header + "\n" + "\n\n".join(chunks))

    return "\n\n".join(sections)


# ─── Build the RAG Chain ──────────────────────────────────────────────────────

def build_chain():
    """Assemble the prompt | LLM | string-parser pipeline and open the vector store.

    The prompt is made of the system message (SYSTEM_PROMPT, which takes
    ``context``), a ``history`` placeholder for earlier turns, and the human
    ``question``; all three are supplied per request.

    Returns:
        A ``(llm_chain, vector_store)`` tuple.
    """
    vector_store = get_vector_store(get_embeddings())

    # NOTE(review): confirm `max_output_tokens` is the keyword ChatNVIDIA
    # expects for the response-length cap.
    chat_model = ChatNVIDIA(
        model="meta/llama-3.3-70b-instruct",
        api_key=NVIDIA_API_KEY,
        temperature=0.2,
        top_p=0.7,
        max_output_tokens=2048,
    )

    # Earlier turns sit between the system message and the new question.
    message_specs = [
        ("system", SYSTEM_PROMPT),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question}"),
    ]
    template = ChatPromptTemplate.from_messages(message_specs)

    return template | chat_model | StrOutputParser(), vector_store


# ─── Singleton init ───────────────────────────────────────────────────────────

# Both stay None until the first query_sacred_texts() call, which fills them
# from build_chain() and reuses them for every later request.
_llm_chain    = None
_vector_store = None


# ─── Public API ───────────────────────────────────────────────────────────────

def _cache_lookup(question: str):
    """Return the cached payload for a sufficiently similar question, else None."""
    THRESHOLD = 0.35  # largest distance (as reported by the cache collection) counted as a hit
    try:
        cache_coll = _vector_store._client.get_or_create_collection(CACHE_COLLECTION)
        cache_results = cache_coll.query(query_texts=[question], n_results=1)
        if not (cache_results["ids"] and cache_results["ids"][0]):
            return None
        distance = cache_results["distances"][0][0]
        if distance >= THRESHOLD:
            return None
        cached = json.loads(cache_results["metadatas"][0][0]["response_json"])
    except Exception as e:
        # The cache is an optimisation: a broken read degrades to a miss,
        # mirroring how a failed cache write is handled.
        print(f"⚠️  Cache read failed: {e}")
        return None
    if not isinstance(cached, dict) or "answer" not in cached:
        return None
    print(f"⚡️ Semantic Cache Hit! (Distance: {distance:.4f})")
    return cached


def _cache_store(question: str, payload: dict) -> None:
    """Best-effort write of an answer payload to the semantic cache."""
    try:
        cache_coll = _vector_store._client.get_or_create_collection(CACHE_COLLECTION)
        cache_coll.add(
            documents=[question],
            metadatas=[{"response_json": json.dumps(payload)}],
            ids=[question],
        )
    except Exception as e:
        print(f"⚠️  Cache write failed: {e}")


def _collect_sources(source_docs: list) -> list:
    """Return (book, entry) pairs, one per distinct citation, in retrieval order."""
    seen: set[str] = set()
    paired = []
    for doc in source_docs:
        book = doc.metadata.get("book", "Unknown")
        ang = doc.metadata.get("ang")
        ch = doc.metadata.get("chapter")
        vs = doc.metadata.get("verse")
        if ang:
            cite_val = f"Ang {ang}"
        elif ch and vs:
            cite_val = f"{ch}:{vs}"
        else:
            cite_val = f"p. {doc.metadata.get('page', '?')}"
        display_name = f"{book} {cite_val}"
        if display_name in seen:
            continue
        seen.add(display_name)
        snippet = doc.page_content[:200].strip() + "..."
        paired.append((book, {"book": display_name, "page": cite_val, "snippet": snippet}))
    return paired


def query_sacred_texts(question: str, session_id: str = "default"):
    """
    Stream an answer grounded in the sacred texts, maintaining per-session
    conversation history for natural follow-up questions.

    The first question of a session is checked against the semantic cache and,
    on a miss, its answer is written back to it. Follow-ups skip the cache in
    both directions because their meaning depends on the earlier turns.

    Args:
        question: The user's question.
        session_id: Conversation key; turns are stored and replayed per session.

    Yields JSON-lines of the form:
        {"type": "token",   "data": "<chunk>"}
        {"type": "sources", "data": [...]}
        {"type": "cache",   "data": {"answer": "...", "sources": [...]}}

    The "sources" event lists the retrieved passages whose book is named in
    the answer, or every retrieved passage when no book name is found there.
    The same list is what gets cached, so a later cache hit reports the same
    sources as the original answer did.
    """
    global _llm_chain, _vector_store

    if _llm_chain is None:
        print("🔧  Initialising RAG chain (first call)...")
        _llm_chain, _vector_store = build_chain()

    # Snapshot: the turn appended below must not alter what the prompt sees.
    history = list(get_history(session_id))
    is_followup = bool(history)

    # ── Semantic cache check (skip for follow-ups that reference history) ──
    if not is_followup:
        cached = _cache_lookup(question)
        if cached is not None:
            # Store this cache hit in session memory too
            append_turn(session_id, question, cached["answer"])
            yield json.dumps({"type": "cache", "data": cached}) + "\n"
            return

    # ── Retrieval ──────────────────────────────────────────────────────────
    # A short follow-up ("and the Bible?") is too vague to search on by
    # itself, so it is prefixed with the previous human question.
    retrieval_query = question
    if is_followup and len(question.split()) < 8:
        last_human = next(
            (m.content for m in reversed(history) if isinstance(m, HumanMessage)), ""
        )
        retrieval_query = f"{last_human} {question}".strip()
        print(f"🔁 Follow-up detected — augmented retrieval query: '{retrieval_query}'")

    print(f"\n🔍  Retrieving chunks for: '{retrieval_query}'")
    source_docs = retrieve_per_book(retrieval_query, _vector_store)

    if not source_docs:
        yield json.dumps({"type": "token", "data": "No content found in the knowledge base."}) + "\n"
        return

    paired_sources = _collect_sources(source_docs)
    context = format_docs(source_docs)

    # ── Stream LLM response (history injected here) ────────────────────────
    pieces = []
    for chunk in _llm_chain.stream({
        "context":  context,
        "question": question,   # the user's own wording, not the augmented query
        "history":  history,
    }):
        pieces.append(chunk)
        yield json.dumps({"type": "token", "data": chunk}) + "\n"
    full_answer = "".join(pieces)

    # ── Filter sources to those actually cited in the answer ───────────────
    # Match on the bare book name: the display name ("Bible 3:16") almost
    # never occurs verbatim in the answer, which used to empty this list.
    answer_lower = full_answer.lower()
    cited = [entry for book, entry in paired_sources if str(book).lower() in answer_lower]
    final_sources = cited or [entry for _, entry in paired_sources]

    # ── Persist this turn into session memory ─────────────────────────────
    append_turn(session_id, question, full_answer)
    print(f"💾 Session '{session_id}': {len(get_history(session_id)) // 2} turn(s) stored")

    # ── Cache first-turn answers only ─────────────────────────────────────
    if not is_followup:
        _cache_store(question, {"answer": full_answer, "sources": final_sources})

    yield json.dumps({"type": "sources", "data": final_sources}) + "\n"


# ─── Quick CLI Test ───────────────────────────────────────────────────────────

if __name__ == "__main__":
    # Smoke test: ask one question and print the answer as it streams in.
    test_q = "What do the scriptures say about forgiveness?"
    print(f"\n🔍  Test query: {test_q}\n")
    for line in query_sacred_texts(test_q, session_id="cli-test"):
        obj = json.loads(line)
        if obj["type"] == "token":
            print(obj["data"], end="", flush=True)
        elif obj["type"] == "cache":
            # A repeat run is served from the semantic cache as one event;
            # without this branch the script printed nothing at all.
            print(obj["data"]["answer"], end="", flush=True)
    print("\n")