ARAG / RAG.py
1MR's picture
Upload 8 files
27b7701 verified
import html

import pandas as pd
import streamlit as st
from langchain.chains import create_history_aware_retriever
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.document_loaders import DirectoryLoader
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.llms import Ollama
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers import SentenceTransformer, util
# HTML snippet for one assistant message: a flex row with the chatbot avatar
# image followed by a red (#B22222) text bubble. The `{msg}` placeholder is
# filled with str.format(msg=...), which inserts the value verbatim into the
# markup.
bot_template = '''
<div style="display: flex; align-items: center; margin-bottom: 10px; background-color: #B22222; padding: 10px; border-radius: 10px; border: 1px solid #7A0000;">
<div style="flex-shrink: 0; margin-right: 10px;">
<img src="https://raw.githubusercontent.com/AalaaAyman24/Test/main/chatbot.png"
style="max-height: 50px; max-width: 50px; object-fit: cover;">
</div>
<div style="background-color: #B22222; color: white; padding: 10px; border-radius: 10px; max-width: 75%; word-wrap: break-word; overflow-wrap: break-word;">
{msg}
</div>
</div>
'''
# HTML snippet for one user message: a flex row pushed to the right
# (justify-content: flex-end) holding the question avatar image and a grey
# (#757882) text bubble. The `{msg}` placeholder is filled with
# str.format(msg=...), which inserts the value verbatim into the markup.
user_template = '''
<div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: flex-end;">
<div style="flex-shrink: 0; margin-left: 10px;">
<img src="https://raw.githubusercontent.com/AalaaAyman24/Test/main/question.png"
style="max-height: 50px; max-width: 50px; border-radius: 50%; object-fit: cover;">
</div>
<div style="background-color: #757882; color: white; padding: 10px; border-radius: 10px; max-width: 75%; word-wrap: break-word; overflow-wrap: break-word;">
{msg}
</div>
</div>
'''
# Raw <style> block injected into the page with
# st.markdown(..., unsafe_allow_html=True). It defines two CSS classes:
#   .small-button - compact blue button with a darker hover colour
#   .chat-box     - full-width panel fixed 20px above the bottom of the viewport
# This string is never passed through str.format(), so the literal CSS braces
# are safe as written.
# NOTE(review): no markup in the visible part of this file uses either class -
# confirm they are still needed.
button_style = """
<style>
.small-button {
display: inline-block;
padding: 5px 10px;
font-size: 12px;
color: white;
background-color: #007bff;
border: none;
border-radius: 5px;
cursor: pointer;
margin-right: 5px;
}
.small-button:hover {
background-color: #0056b3;
}
.chat-box {
position: fixed;
bottom: 20px;
width: 100%;
left: 0;
padding: 20px;
background-color: #f1f1f1;
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
</style>
"""
# Function to prepare and split documents from CSV or Excel
def _read_table(file):
    """Load one uploaded CSV or XLSX file into a pandas DataFrame.

    The extension check is case-insensitive, so ``DATA.CSV`` is treated the
    same as ``data.csv``.

    Args:
        file: File-like object exposing a ``name`` attribute.

    Returns:
        The DataFrame parsed from ``file``.

    Raises:
        ValueError: If the file name ends with neither ``.csv`` nor ``.xlsx``.
    """
    lowered_name = file.name.lower()
    if lowered_name.endswith('.csv'):
        return pd.read_csv(file)
    if lowered_name.endswith('.xlsx'):
        return pd.read_excel(file)
    # Fail loudly: silently continuing would either crash later on an unbound
    # variable or re-index the previous file's data under this file's name.
    raise ValueError(
        f"Unsupported file type for {file.name!r}: expected .csv or .xlsx")


def prepare_and_split_docs(files):
    """Convert uploaded CSV/XLSX files into token-bounded document chunks.

    Each file is read into a DataFrame, rendered as plain text without the
    index column, wrapped in a ``Document`` whose metadata records the source
    file name, and then split into overlapping chunks.

    Args:
        files: Iterable of file-like objects, each with a ``name`` attribute.

    Returns:
        A list of chunked ``Document`` objects, in file order. Empty when
        ``files`` is empty.

    Raises:
        ValueError: If any file has an unsupported extension.
    """
    documents = []
    for file in files:
        # Read first so an unsupported file is rejected before any other work.
        frame = _read_table(file)
        documents.append(
            Document(
                page_content=frame.to_string(index=False),
                metadata={"source": file.name},
            )
        )

    if not documents:
        # Nothing to split; skip constructing the tokenizer-backed splitter.
        return []

    # The splitter configuration does not depend on the file, so build it once.
    splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=512,
        chunk_overlap=256,
        disallowed_special=(),
        separators=["\n\n", "\n", " "],
    )
    return splitter.split_documents(documents)
# Function to ingest documents into the vector database
def ingest_into_vectordb(split_docs):
    """Embed the chunks, index them with FAISS, save the index, and return it.

    Args:
        split_docs: Document chunks to embed and index.

    Returns:
        The FAISS vector store built from ``split_docs``. The same store is
        also written to the ``vectorstore/db_faiss`` directory.
    """
    index_dir = 'vectorstore/db_faiss'
    encoder = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2')

    vector_store = FAISS.from_documents(split_docs, encoder)
    vector_store.save_local(index_dir)
    return vector_store
# Function to get the conversation chain
def get_conversation_chain(retriever):
    """Assemble the conversational retrieval chain used to answer questions.

    The chain wraps ``retriever`` in a history-aware retriever, feeds the
    retrieved documents to a "stuff documents" answer chain, and wraps the
    result in ``RunnableWithMessageHistory`` so that chat turns are recorded
    per ``session_id``.

    Args:
        retriever: Retriever over the ingested documents.

    Returns:
        A runnable that takes ``{"input": ...}`` and must be invoked with
        ``config={"configurable": {"session_id": ...}}``. Histories live in a
        dict private to this call, so every call starts with empty histories.
    """
    model = Ollama(model="llama3.2:1b")

    def chat_prompt(system_text):
        # Both prompts share one layout: system text, prior turns, new input.
        return ChatPromptTemplate.from_messages(
            [
                ("system", system_text),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )

    # NOTE(review): this text is the prompt handed to
    # create_history_aware_retriever. LangChain documents that prompt as the
    # one used to generate the retriever's search query, yet the wording asks
    # for an answer - confirm this is intended.
    retrieval_instructions = (
        "Given the chat history and the latest user question, "
        "provide a response that directly addresses the user's query based on the provided documents. "
        "Do not rephrase the question or ask follow-up questions."
    )
    # "{context}" is concatenated directly after the preceding sentence, with
    # no separator; the text is kept exactly as written.
    answer_instructions = (
        "As a personal chat assistant, provide accurate and relevant information based on the provided document in 2-3 sentences. "
        "Answer should be limited to 50 words and 2-3 sentences. Do not prompt to select answers or formulate a stand-alone question."
        "{context}"
    )

    history_aware_retriever = create_history_aware_retriever(
        model, retriever, chat_prompt(retrieval_instructions)
    )
    answer_chain = create_stuff_documents_chain(
        model, chat_prompt(answer_instructions))
    rag_chain = create_retrieval_chain(history_aware_retriever, answer_chain)

    histories = {}

    def get_session_history(session_id: str) -> BaseChatMessageHistory:
        # Create the history lazily the first time a session id is seen.
        try:
            return histories[session_id]
        except KeyError:
            history = histories[session_id] = ChatMessageHistory()
            return history

    return RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )
def calculate_similarity_score(answer: str, context_docs: list) -> float:
    """Return the highest cosine similarity between the answer and any context doc.

    The answer and each document's ``page_content`` are embedded with the
    ``all-MiniLM-L6-v2`` sentence-transformer, and the maximum pairwise cosine
    similarity is returned.

    Args:
        answer: Generated answer text.
        context_docs: Retrieved documents, each exposing ``page_content``.

    Returns:
        The maximum cosine similarity as a float, or ``0.0`` when
        ``context_docs`` is empty (there is nothing to compare against).
    """
    if not context_docs:
        # Without this guard an empty list reaches the encoder and the tensor
        # reduction below, which raises instead of giving a usable score.
        return 0.0

    # The model is loaded on every call; callers scoring many answers may want
    # to cache it.
    model = SentenceTransformer('all-MiniLM-L6-v2')
    context_texts = [doc.page_content for doc in context_docs]
    answer_embedding = model.encode(answer, convert_to_tensor=True)
    context_embeddings = model.encode(context_texts, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(answer_embedding, context_embeddings)
    return similarities.max().item()
# ---------------------------------------------------------------------------
# Streamlit page script: upload and process documents in the sidebar, ask
# questions in the main pane, and render the running chat transcript.
# Streamlit re-executes this script on every interaction, so anything that
# must persist (the chain, the transcript) is kept in st.session_state.
# ---------------------------------------------------------------------------
st.title("What can I help with⁉️")

# Sidebar for file upload
uploaded_files = st.sidebar.file_uploader(
    "Upload CSV/Excel Documents", type=["csv", "xlsx"], accept_multiple_files=True)

# The "Process Documents" button is only rendered once files are uploaded.
if uploaded_files and st.sidebar.button("Process Documents"):
    split_docs = prepare_and_split_docs(uploaded_files)
    vector_db = ingest_into_vectordb(split_docs)
    retriever = vector_db.as_retriever()
    st.sidebar.success("Documents processed and vector database created!")

    # Initialize the conversation chain and keep it across reruns.
    st.session_state.conversational_chain = get_conversation_chain(retriever)

if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

# Chat input (the stylesheet only needs to be injected once per run).
st.markdown(button_style, unsafe_allow_html=True)
user_input = st.text_input("Ask a question about the dataset:",
                           key="user_input", placeholder="Type your question here...")

if st.button("Submit"):
    if 'conversational_chain' not in st.session_state:
        # Previously a silent no-op; tell the user what is missing.
        st.warning("Please upload and process documents before asking a question.")
    elif not user_input.strip():
        st.warning("Please enter a question.")
    else:
        session_id = "abc123"
        response = st.session_state.conversational_chain.invoke(
            {"input": user_input},
            config={"configurable": {"session_id": session_id}},
        )
        st.session_state.chat_history.append({
            "user": user_input,
            "bot": response['answer'],
            "context_docs": response.get('context', []),
        })

# Display chat history
for message in st.session_state.chat_history:
    # Both the user's text and the model's answer are untrusted and are
    # rendered as raw HTML, so escape them before interpolating into markup.
    st.markdown(
        user_template.format(msg=html.escape(message['user'])),
        unsafe_allow_html=True,
    )
    st.markdown(
        bot_template.format(msg=html.escape(message['bot'])),
        unsafe_allow_html=True,
    )