customer-support-agent-space / rag_with_memory.py
Mohammad Wasil
Fix frontend connection: use relative API path
efc9cdf
"""
Secure version of RAG with Memory for customer support agent.
"""
import os
import sys
from typing import Dict
from loguru import logger
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_classic.chains.history_aware_retriever import create_history_aware_retriever
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_classic.chains.retrieval import create_retrieval_chain
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langchain_community.vectorstores import Chroma
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import DirectoryLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from dotenv import load_dotenv
from langchain_community.document_loaders import TextLoader
# Load environment variables (e.g. HF_API_TOKEN, read further down) from a
# local .env file, if one exists, before anything queries os.environ.
load_dotenv()
# Replace the standard-library sqlite3 module with pysqlite3: import
# pysqlite3, then re-register that module object under the name 'sqlite3' in
# sys.modules so any later `import sqlite3` receives it instead.
# NOTE(review): presumably done because Chroma needs a newer SQLite than the
# host's system library provides — confirm against the deployment image.
# This must run before any module actually imports sqlite3, and it raises
# ModuleNotFoundError when the pysqlite3 package is not installed.
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
# Setup production logging: remove the previously configured loguru handlers
# and emit INFO-and-above records to stdout in a compact
# "time | level | message" format.
logger.remove()
logger.add(sys.stdout, format="<green>{time:HH:mm:ss}</green> | <level>{level}</level> | {message}", level="INFO")
class MemoryRAG:
    """Conversational RAG pipeline over a Markdown knowledge base.

    Construction loads every ``*.md`` file below ``docs_path``, splits the
    text into overlapping chunks, indexes them in a Chroma collection
    persisted under ``./chroma_db`` and wires a history-aware retrieval
    chain around a Hugging Face chat model.  Chat history is kept in
    memory, one ``ChatMessageHistory`` per session id.
    """

    def __init__(self, docs_path: str, model: str = "meta-llama/Llama-3.1-8B-Instruct"):
        """Build the vector store, the LLM client and the retrieval chain.

        Args:
            docs_path: Directory searched recursively for ``*.md`` files.
            model: Hugging Face repo id of the chat model to call.

        Raises:
            RuntimeError: If the ``HF_API_TOKEN`` environment variable is
                unset or empty.
            Exception: Any failure while loading, indexing or wiring the
                chain is logged with its traceback and re-raised unchanged.
        """
        self.docs_path = docs_path
        # NOTE: histories live only in this process and are never evicted,
        # so the dict grows with the number of distinct session ids.
        self.store: Dict[str, BaseChatMessageHistory] = {}
        try:
            logger.info(f"Initializing RAG with knowledge base: {docs_path}")

            # Fail fast: validate the credential before the slow
            # load / embed / index steps instead of after them.
            hf_token = os.getenv("HF_API_TOKEN")
            if not hf_token:
                logger.critical("HF_API_TOKEN is missing from environment variables!")
                raise RuntimeError("HF_API_TOKEN not set")

            # 1. Load and chunk documents
            loader = DirectoryLoader(
                docs_path,
                glob="**/*.md",
                loader_cls=TextLoader,
                silent_errors=False,
            )
            docs = loader.load()
            logger.info(f"RAG DATABASE STATUS: Loaded {len(docs)} documents from {docs_path}")
            if not docs:
                logger.warning(f"No documents found in {docs_path}. RAG will be empty.")
            splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=100)
            chunks = splitter.split_documents(docs)

            # 2. Vector DB - Persistent storage
            # The collection survives restarts, so chunks are written under
            # deterministic ids: re-indexing the same files then rewrites the
            # existing records instead of appending a duplicate copy of the
            # corpus on every start.  Records stored earlier under random ids
            # are not cleaned up by this.
            embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
            self.db = Chroma.from_documents(
                chunks,
                embeddings,
                ids=self._stable_chunk_ids(chunks),
                persist_directory="./chroma_db",
            )

            # 3. LLM Setup
            self.raw_llm = HuggingFaceEndpoint(
                repo_id=model,
                huggingfacehub_api_token=hf_token,
                temperature=0.1,
                max_new_tokens=200,
                return_full_text=False,
                task="conversational",
            )
            self.llm = ChatHuggingFace(llm=self.raw_llm)

            # 4. Chains Setup
            self.retriever = self.db.as_retriever(search_kwargs={"k": 6})

            # First stage: rewrite a follow-up question into a standalone
            # one so retrieval does not depend on earlier turns.
            contextualize_q_system_prompt = (
                "Given a chat history and the latest user question "
                "which might reference context in the chat history, "
                "formulate a standalone question which can be understood "
                "without the chat history. Do NOT answer the question, "
                "just reformulate it if needed and otherwise return it as is."
            )
            context_prompt = ChatPromptTemplate.from_messages([
                ("system", contextualize_q_system_prompt),
                MessagesPlaceholder(variable_name="chat_history"),
                ("human", "{input}"),
            ])
            history_aware_retriever = create_history_aware_retriever(
                self.llm, self.retriever, context_prompt
            )

            # Second stage: answer from the retrieved context only.
            qa_prompt = ChatPromptTemplate.from_messages([
                ("system", (
                    "You are the SmartCoffee Support AI. Use the provided context to answer the user's question. "
                    "\n\n"
                    "### FORMATTING RULES:\n"
                    "- Use **Markdown** for all responses.\n"
                    "- If the answer involves a process or multiple steps, use a **numbered list** (1, 2, 3).\n"
                    "- If the answer contains several facts, use **bullet points** (•).\n"
                    "- Use **bold text** for button names or important terms (e.g., 'Press the **Brew** button').\n"
                    # The trailing newline keeps this rule from fusing with
                    # the next one into a single bullet.
                    "- Keep the response concise and avoid long paragraphs.\n"
                    "- If the answer is not in the context, say: 'I'm sorry, I don't have that specific policy in my records.'\n"
                    "- DO NOT use your internal knowledge to invent support tiers, response times, or phone numbers.\n"
                    "\n\n"
                    "Context: {context}"
                )),
                MessagesPlaceholder(variable_name="chat_history"),
                ("human", "{input}"),
            ])
            question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
            self.rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
            logger.success("MemoryRAG system initialized successfully.")
        except Exception:
            logger.exception("Failed to initialize MemoryRAG components")
            # Bare raise keeps the original traceback intact.
            raise

    @staticmethod
    def _stable_chunk_ids(chunks) -> list:
        """Return one deterministic id per chunk: ``"<source>#chunk-<n>"``.

        ``n`` counts chunks per source in the order given, so ids are unique
        within a batch and identical across runs as long as the source path
        and chunk order are unchanged.
        """
        per_source_count = {}
        ids = []
        for chunk in chunks:
            source = str(chunk.metadata.get("source", "unknown"))
            position = per_source_count.get(source, 0)
            per_source_count[source] = position + 1
            ids.append(f"{source}#chunk-{position}")
        return ids

    def get_session_history(self, session_id: str) -> "BaseChatMessageHistory":
        """Return the chat history for ``session_id``, creating it on first use."""
        if session_id not in self.store:
            self.store[session_id] = ChatMessageHistory()
        return self.store[session_id]

    def query(self, question: str, session_id: str = "default_session") -> dict:
        """Answer ``question`` using retrieved context and the session's history.

        Args:
            question: The user's message.
            session_id: Key selecting which in-memory history to read and
                extend.

        Returns:
            A dict with ``"answer"`` (the stripped model reply) and
            ``"sources"`` (sorted, de-duplicated ``source`` metadata of the
            retrieved documents).  On any failure the error is logged and a
            fixed apology with an empty source list is returned instead of
            raising.
        """
        # Create a logger tied to this session
        session_logger = logger.bind(session_id=session_id)
        conversational_rag_chain = RunnableWithMessageHistory(
            self.rag_chain,
            self.get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
        )
        try:
            session_logger.info(f"RAG Query received: {question[:50]}...")
            result = conversational_rag_chain.invoke(
                {"input": question},
                config={"configurable": {"session_id": session_id}},
            )
            # Extract sources directly from the result; sorting makes the
            # order deterministic (plain set iteration order is not).
            sources = sorted(
                {doc.metadata.get("source", "unknown") for doc in result.get("context", [])},
                key=str,
            )
            session_logger.success("RAG Query completed.")
            return {
                "answer": result["answer"].strip(),
                "sources": sources,
            }
        except Exception as e:
            # Deliberate best-effort boundary: callers always get a reply.
            # exception() logs at ERROR level and attaches the traceback.
            session_logger.exception(f"RAG Query Error: {e}")
            return {
                "answer": "I'm sorry, I encountered an error accessing my knowledge base.",
                "sources": [],
            }
def find_knowledge_base(candidate_paths):
    """Return the first directory in ``candidate_paths`` holding Markdown files.

    The search is recursive (``**/*.md``), matching the pattern MemoryRAG
    uses to load documents, so a knowledge base organised in sub-folders is
    detected too.

    Args:
        candidate_paths: Directory paths to probe, in priority order.

    Returns:
        The first path that is an existing directory containing at least one
        ``.md`` file at any depth, or ``None`` when no candidate qualifies.
    """
    import glob  # local import: the module does not import glob at file level

    for path in candidate_paths:
        if not os.path.isdir(path):
            continue
        # iglob is lazy, so any() stops at the first match.
        if any(glob.iglob(os.path.join(path, "**", "*.md"), recursive=True)):
            return path
    return None


if __name__ == "__main__":
    # 1. Define the same safe paths you used in tools.py
    possible_paths = [
        "/app/data/knowledge_base",
        "./data/knowledge_base",
        "./backend/data/knowledge_base",
    ]
    KNOWLEDGE_BASE_PATH = find_knowledge_base(possible_paths)
    if not KNOWLEDGE_BASE_PATH:
        print("CRITICAL ERROR: No knowledge base found in any of the possible paths!")
        # Fallback to a default to prevent crash, but it will be empty
        KNOWLEDGE_BASE_PATH = "./data/knowledge_base"
    # 2. Use the detected path
    rag = MemoryRAG(KNOWLEDGE_BASE_PATH, model="meta-llama/Llama-3.1-8B-Instruct")