File size: 8,019 Bytes
9a3b3da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f571a27
9a3b3da
a760e19
 
 
 
9a3b3da
 
 
 
 
 
 
 
 
 
 
 
 
f571a27
6117f75
9a3b3da
a760e19
9a3b3da
 
 
 
 
 
 
 
 
 
 
efc9cdf
9a3b3da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a760e19
 
9a3b3da
a760e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""
    Secure version of RAG with Memory for customer support agent.    
"""

import os
import sys
from typing import Dict
from loguru import logger
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_classic.chains.history_aware_retriever import create_history_aware_retriever
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_classic.chains.retrieval import create_retrieval_chain
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langchain_community.vectorstores import Chroma
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import DirectoryLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from dotenv import load_dotenv
from langchain_community.document_loaders import TextLoader
load_dotenv()
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')

# Setup production logging
logger.remove()
logger.add(sys.stdout, format="<green>{time:HH:mm:ss}</green> | <level>{level}</level> | {message}", level="INFO")

class MemoryRAG:
    """Conversational RAG over a markdown knowledge base, with per-session memory.

    Builds a persistent Chroma vector store from the documents in ``docs_path``,
    wires a history-aware retriever and a grounded QA chain on top of a
    HuggingFace-hosted chat model, and keeps one ``ChatMessageHistory`` per
    session id so follow-up questions resolve against prior turns.
    """

    def __init__(self, docs_path: str, model: str = "meta-llama/Llama-3.1-8B-Instruct"):
        """Load documents, build the vector DB, and assemble the RAG chain.

        Args:
            docs_path: Directory scanned recursively for ``*.md`` files.
            model: HuggingFace repo id of the chat model to use.

        Raises:
            RuntimeError: If ``HF_API_TOKEN`` is not set in the environment.
            Exception: Any initialization failure is logged and re-raised.
        """
        self.docs_path = docs_path
        # Per-session chat histories, keyed by session id.
        self.store: Dict[str, BaseChatMessageHistory] = {}

        try:
            logger.info(f"Initializing RAG with knowledge base: {docs_path}")

            # 1. Load and chunk documents
            loader = DirectoryLoader(docs_path, glob="**/*.md",
                                     loader_cls=TextLoader, silent_errors=False)
            docs = loader.load()
            logger.info(f"RAG DATABASE STATUS: Loaded {len(docs)} documents from {docs_path}")
            if not docs:
                logger.warning(f"No documents found in {docs_path}. RAG will be empty.")

            splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=100)
            chunks = splitter.split_documents(docs)

            # 2. Vector DB - persisted to disk so re-runs can reuse the index
            embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
            self.db = Chroma.from_documents(
                chunks,
                embeddings,
                persist_directory="./chroma_db"
            )

            # 3. LLM setup - fail fast if the API token is missing
            hf_token = os.getenv("HF_API_TOKEN")
            if not hf_token:
                logger.critical("HF_API_TOKEN is missing from environment variables!")
                raise RuntimeError("HF_API_TOKEN not set")

            self.raw_llm = HuggingFaceEndpoint(
                repo_id=model,
                huggingfacehub_api_token=hf_token,
                temperature=0.1,  # low temperature keeps support answers near-deterministic
                max_new_tokens=200,
                return_full_text=False,
                task="conversational"
            )
            self.llm = ChatHuggingFace(llm=self.raw_llm)

            # 4. Chains setup
            self.retriever = self.db.as_retriever(search_kwargs={"k": 6})

            # Rewrites a follow-up question into a standalone one using the
            # chat history, so retrieval works without conversational context.
            contextualize_q_system_prompt = (
                "Given a chat history and the latest user question "
                "which might reference context in the chat history, "
                "formulate a standalone question which can be understood "
                "without the chat history. Do NOT answer the question, "
                "just reformulate it if needed and otherwise return it as is."
            )
            context_prompt = ChatPromptTemplate.from_messages([
                ("system", contextualize_q_system_prompt),
                MessagesPlaceholder(variable_name="chat_history"),
                ("human", "{input}"),
            ])

            history_aware_retriever = create_history_aware_retriever(self.llm, self.retriever, context_prompt)

            # Answer strictly from retrieved context; refuse instead of inventing.
            qa_prompt = ChatPromptTemplate.from_messages([
                ("system", (
                    "You are the SmartCoffee Support AI. Use the provided context to answer the user's question. "
                    "\n\n"
                    "### FORMATTING RULES:\n"
                    "- Use **Markdown** for all responses.\n"
                    "- If the answer involves a process or multiple steps, use a **numbered list** (1, 2, 3).\n"
                    "- If the answer contains several facts, use **bullet points** (•).\n"
                    "- Use **bold text** for button names or important terms (e.g., 'Press the **Brew** button').\n"
                    "- Keep the response concise and avoid long paragraphs."
                    "- If the answer is not in the context, say: 'I'm sorry, I don't have that specific policy in my records.'\n"
                    "- DO NOT use your internal knowledge to invent support tiers, response times, or phone numbers.\n"
                    "\n\n"
                    "Context: {context}"
                )),
                MessagesPlaceholder(variable_name="chat_history"),
                ("human", "{input}"),
            ])
            question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
            self.rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

            # Wrap the chain with session memory once here, instead of
            # rebuilding the wrapper on every query() call.
            self.conversational_rag_chain = RunnableWithMessageHistory(
                self.rag_chain,
                self.get_session_history,
                input_messages_key="input",
                history_messages_key="chat_history",
                output_messages_key="answer",
            )

            logger.success("MemoryRAG system initialized successfully.")

        except Exception:
            logger.exception("Failed to initialize MemoryRAG components")
            # Bare `raise` re-raises with the original traceback intact.
            raise

    def get_session_history(self, session_id: str) -> BaseChatMessageHistory:
        """Return the chat history for ``session_id``, creating it on first use."""
        if session_id not in self.store:
            self.store[session_id] = ChatMessageHistory()
        return self.store[session_id]

    def query(self, question: str, session_id: str = "default_session") -> dict:
        """Answer ``question`` within the given session's conversation.

        Args:
            question: The user's question (raw text).
            session_id: Key selecting which conversation history to use.

        Returns:
            dict with ``answer`` (str) and ``sources`` (list of unique source
            paths from the retrieved context). On failure, a canned apology
            with an empty source list — internal errors never reach the user.
        """
        # Logger bound to this session so log lines are traceable per conversation.
        session_logger = logger.bind(session_id=session_id)

        try:
            session_logger.info(f"RAG Query received: {question[:50]}...")

            result = self.conversational_rag_chain.invoke(
                {"input": question},
                config={"configurable": {"session_id": session_id}},
            )

            # Deduplicate source paths from the retrieved context documents.
            sources = list({doc.metadata.get("source", "unknown") for doc in result.get("context", [])})

            session_logger.success("RAG Query completed.")
            return {
                "answer": result["answer"].strip(),
                "sources": sources
            }

        except Exception as e:
            session_logger.error(f"RAG Query Error: {e}")
            return {
                "answer": "I'm sorry, I encountered an error accessing my knowledge base.",
                "sources": []
            }


if __name__ == "__main__":
    import os
    import glob

    # 1. Define the same safe paths you used in tools.py
    possible_paths = [
        "/app/data/knowledge_base",
        "./data/knowledge_base",
        "./backend/data/knowledge_base"
    ]

    KNOWLEDGE_BASE_PATH = None
    for p in possible_paths:
        # Check if the folder exists and actually has .md files inside
        if os.path.exists(p) and glob.glob(os.path.join(p, "*.md")):
            KNOWLEDGE_BASE_PATH = p
            break

    if not KNOWLEDGE_BASE_PATH:
        print("CRITICAL ERROR: No knowledge base found in any of the possible paths!")
        # Fallback to a default to prevent crash, but it will be empty
        KNOWLEDGE_BASE_PATH = "./data/knowledge_base" 

    # 2. Use the detected path
    rag = MemoryRAG(KNOWLEDGE_BASE_PATH, model="meta-llama/Llama-3.1-8B-Instruct")