"""
Secure version of RAG with Memory for customer support agent.
"""
import os
import sys
from typing import Dict
from loguru import logger
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_classic.chains.history_aware_retriever import create_history_aware_retriever
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_classic.chains.retrieval import create_retrieval_chain
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langchain_community.vectorstores import Chroma
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import DirectoryLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from dotenv import load_dotenv
from langchain_community.document_loaders import TextLoader
# Populate os.environ from a local .env file (if any) before anything below
# reads environment variables such as HF_API_TOKEN.
load_dotenv()
# Import pysqlite3 and register it in sys.modules under the name 'sqlite3',
# so every later `import sqlite3` in this process resolves to pysqlite3.
# NOTE(review): presumably needed because Chroma requires a newer SQLite than
# the host provides - confirm. This must run before any module that imports
# sqlite3 is loaded, and it raises ImportError when pysqlite3 is not installed.
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
# Setup production logging: drop loguru's pre-configured handler and emit
# INFO-and-above records to stdout as "HH:mm:ss | LEVEL | message".
logger.remove()
logger.add(sys.stdout, format="<green>{time:HH:mm:ss}</green> | <level>{level}</level> | {message}", level="INFO")
class MemoryRAG:
    """Retrieval-augmented question answering with per-session chat memory.

    Construction loads every ``*.md`` file below ``docs_path``, splits the
    text into overlapping chunks, embeds the chunks into a Chroma collection
    persisted at ``./chroma_db`` and wires a history-aware retrieval chain
    around a Hugging Face hosted chat model.

    Attributes:
        docs_path: Directory the knowledge base was loaded from.
        store: In-process mapping of session id -> chat history. Entries are
            created on first use and never evicted by this class.
        db: The Chroma vector store holding the embedded chunks.
        raw_llm: The ``HuggingFaceEndpoint`` text-generation client.
        llm: ``ChatHuggingFace`` wrapper around ``raw_llm``.
        retriever: Retriever over ``db`` returning the top 6 chunks.
        rag_chain: Retrieval chain (question rewriting + answer generation).
    """

    def __init__(self, docs_path: str, model: str = "meta-llama/Llama-3.1-8B-Instruct"):
        """Build the vector store, the LLM client and the RAG chain.

        Args:
            docs_path: Directory searched recursively for ``*.md`` documents.
            model: Hugging Face repo id of the chat model to call.

        Raises:
            RuntimeError: If the ``HF_API_TOKEN`` environment variable is
                unset or empty.
            Exception: Any error raised while loading documents, building
                the vector store or assembling the chains is logged with its
                traceback and re-raised unchanged.
        """
        self.docs_path = docs_path
        self.store: Dict[str, BaseChatMessageHistory] = {}
        try:
            logger.info(f"Initializing RAG with knowledge base: {docs_path}")

            # Fail fast: validate credentials before the expensive load/embed
            # step so a doomed start-up does not touch ./chroma_db at all.
            hf_token = os.getenv("HF_API_TOKEN")
            if not hf_token:
                logger.critical("HF_API_TOKEN is missing from environment variables!")
                raise RuntimeError("HF_API_TOKEN not set")

            # 1. Load and chunk documents
            loader = DirectoryLoader(
                docs_path,
                glob="**/*.md",
                loader_cls=TextLoader,
                silent_errors=False,
            )
            docs = loader.load()
            logger.info(f"RAG DATABASE STATUS: Loaded {len(docs)} documents from {docs_path}")
            if not docs:
                logger.warning(f"No documents found in {docs_path}. RAG will be empty.")
            splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=100)
            chunks = splitter.split_documents(docs)

            # 2. Vector DB - Persistent storage
            # NOTE(review): from_documents runs on every construction against
            # the same persist_directory; presumably this re-adds the chunks
            # on each start - confirm, and consider stable ids if so.
            embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
            self.db = Chroma.from_documents(
                chunks,
                embeddings,
                persist_directory="./chroma_db",
            )

            # 3. LLM Setup
            self.raw_llm = HuggingFaceEndpoint(
                repo_id=model,
                huggingfacehub_api_token=hf_token,
                temperature=0.1,
                max_new_tokens=200,
                return_full_text=False,
                task="conversational",
            )
            self.llm = ChatHuggingFace(llm=self.raw_llm)

            # 4. Chains Setup
            self.retriever = self.db.as_retriever(search_kwargs={"k": 6})

            # First stage: rewrite a follow-up question into a standalone one
            # so retrieval does not depend on earlier turns.
            contextualize_q_system_prompt = (
                "Given a chat history and the latest user question "
                "which might reference context in the chat history, "
                "formulate a standalone question which can be understood "
                "without the chat history. Do NOT answer the question, "
                "just reformulate it if needed and otherwise return it as is."
            )
            context_prompt = ChatPromptTemplate.from_messages([
                ("system", contextualize_q_system_prompt),
                MessagesPlaceholder(variable_name="chat_history"),
                ("human", "{input}"),
            ])
            history_aware_retriever = create_history_aware_retriever(
                self.llm, self.retriever, context_prompt
            )

            # Second stage: answer from the retrieved context only.
            # Every rule line ends with "\n" so each bullet reaches the model
            # on its own line.
            qa_system_prompt = (
                "You are the SmartCoffee Support AI. Use the provided context to answer the user's question. "
                "\n\n"
                "### FORMATTING RULES:\n"
                "- Use **Markdown** for all responses.\n"
                "- If the answer involves a process or multiple steps, use a **numbered list** (1, 2, 3).\n"
                "- If the answer contains several facts, use **bullet points** (•).\n"
                "- Use **bold text** for button names or important terms (e.g., 'Press the **Brew** button').\n"
                "- Keep the response concise and avoid long paragraphs.\n"
                "- If the answer is not in the context, say: 'I'm sorry, I don't have that specific policy in my records.'\n"
                "- DO NOT use your internal knowledge to invent support tiers, response times, or phone numbers.\n"
                "\n\n"
                "Context: {context}"
            )
            qa_prompt = ChatPromptTemplate.from_messages([
                ("system", qa_system_prompt),
                MessagesPlaceholder(variable_name="chat_history"),
                ("human", "{input}"),
            ])
            question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
            self.rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
            logger.success("MemoryRAG system initialized successfully.")
        except Exception:
            logger.exception("Failed to initialize MemoryRAG components")
            # Bare raise keeps the original exception and traceback intact.
            raise

    def get_session_history(self, session_id: str) -> BaseChatMessageHistory:
        """Return the chat history for ``session_id``, creating it on first use.

        Args:
            session_id: Key identifying one conversation.

        Returns:
            The history object stored for that session in ``self.store``.
        """
        if session_id not in self.store:
            self.store[session_id] = ChatMessageHistory()
        return self.store[session_id]

    def query(self, question: str, session_id: str = "default_session") -> dict:
        """Answer ``question`` using retrieved context and the session's history.

        Args:
            question: The user's message.
            session_id: Conversation key; turns sharing a key share history.

        Returns:
            A dict with two keys:
            ``"answer"`` - the model's reply with surrounding whitespace
            stripped, or a fixed apology if the chain raised;
            ``"sources"`` - the distinct ``source`` metadata values of the
            retrieved documents in retrieval order (``"unknown"`` when a
            document has none), or ``[]`` if the chain raised.
        """
        # Create a logger tied to this session
        session_logger = logger.bind(session_id=session_id)
        conversational_rag_chain = RunnableWithMessageHistory(
            self.rag_chain,
            self.get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
        )
        try:
            # Only a 50-character prefix of the question is logged.
            session_logger.info(f"RAG Query received: {question[:50]}...")
            result = conversational_rag_chain.invoke(
                {"input": question},
                config={"configurable": {"session_id": session_id}},
            )
            # De-duplicate while keeping retrieval order, so the list is
            # deterministic (a set would yield an arbitrary order).
            sources = list(dict.fromkeys(
                doc.metadata.get("source", "unknown")
                for doc in result.get("context", [])
            ))
            session_logger.success("RAG Query completed.")
            return {
                "answer": result["answer"].strip(),
                "sources": sources,
            }
        except Exception as e:
            # Best-effort boundary: callers always get a well-formed dict.
            session_logger.error(f"RAG Query Error: {e}")
            return {
                "answer": "I'm sorry, I encountered an error accessing my knowledge base.",
                "sources": [],
            }
def find_knowledge_base(candidates):
    """Return the first directory in ``candidates`` that contains Markdown.

    The search is recursive, matching the ``**/*.md`` pattern the document
    loader in this file uses, so a knowledge base whose files live only in
    sub-folders is still detected.

    Args:
        candidates: Iterable of directory paths, checked in order.

    Returns:
        The first path that is an existing directory with at least one
        ``*.md`` file at any depth, or ``None`` when no candidate qualifies.
    """
    import glob  # local import: glob is not imported at module top level

    for candidate in candidates:
        if not os.path.isdir(candidate):
            continue
        # Escape the directory part so glob metacharacters in it are literal.
        pattern = os.path.join(glob.escape(candidate), "**", "*.md")
        # iglob is lazy: stop at the first hit instead of listing every file.
        if any(glob.iglob(pattern, recursive=True)):
            return candidate
    return None


if __name__ == "__main__":
    # 1. Candidate locations (the original note says these mirror tools.py)
    possible_paths = [
        "/app/data/knowledge_base",
        "./data/knowledge_base",
        "./backend/data/knowledge_base",
    ]
    KNOWLEDGE_BASE_PATH = find_knowledge_base(possible_paths)
    if not KNOWLEDGE_BASE_PATH:
        print("CRITICAL ERROR: No knowledge base found in any of the possible paths!")
        # Fall back to the default location so start-up is still attempted.
        KNOWLEDGE_BASE_PATH = "./data/knowledge_base"
    # 2. Use the detected path
    rag = MemoryRAG(KNOWLEDGE_BASE_PATH, model="meta-llama/Llama-3.1-8B-Instruct")
|