|
|
import streamlit as st
|
|
|
import pandas as pd
|
|
|
from langchain.document_loaders import DirectoryLoader
|
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
from langchain_community.llms import Ollama
|
|
|
from langchain.vectorstores import FAISS
|
|
|
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
|
from langchain_core.runnables.history import RunnableWithMessageHistory
|
|
|
from langchain_community.chat_message_histories import ChatMessageHistory
|
|
|
from langchain.chains import create_retrieval_chain
|
|
|
from langchain.chains.combine_documents import create_stuff_documents_chain
|
|
|
from sentence_transformers import SentenceTransformer, util
|
|
|
from langchain.schema import Document
|
|
|
from langchain_core.chat_history import BaseChatMessageHistory
|
|
|
from langchain.chains import create_history_aware_retriever
|
|
|
from langchain_huggingface import HuggingFaceEmbeddings
|
|
|
|
|
|
# Chat-bubble HTML templates.  Each contains a single ``{msg}`` placeholder
# filled in with str.format() and rendered via
# st.markdown(..., unsafe_allow_html=True).

# Bot (assistant) bubble: avatar on the left, dark-red message box.
bot_template = '''
<div style="display: flex; align-items: center; margin-bottom: 10px; background-color: #B22222; padding: 10px; border-radius: 10px; border: 1px solid #7A0000;">
<div style="flex-shrink: 0; margin-right: 10px;">
<img src="https://raw.githubusercontent.com/AalaaAyman24/Test/main/chatbot.png"
style="max-height: 50px; max-width: 50px; object-fit: cover;">
</div>
<div style="background-color: #B22222; color: white; padding: 10px; border-radius: 10px; max-width: 75%; word-wrap: break-word; overflow-wrap: break-word;">
{msg}
</div>
</div>
'''

# User bubble: right-aligned grey message box with the avatar on the right.
user_template = '''
<div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: flex-end;">
<div style="flex-shrink: 0; margin-left: 10px;">
<img src="https://raw.githubusercontent.com/AalaaAyman24/Test/main/question.png"
style="max-height: 50px; max-width: 50px; border-radius: 50%; object-fit: cover;">
</div>
<div style="background-color: #757882; color: white; padding: 10px; border-radius: 10px; max-width: 75%; word-wrap: break-word; overflow-wrap: break-word;">
{msg}
</div>
</div>
'''
|
|
|
|
|
|
button_style = """
|
|
|
<style>
|
|
|
.small-button {
|
|
|
display: inline-block;
|
|
|
padding: 5px 10px;
|
|
|
font-size: 12px;
|
|
|
color: white;
|
|
|
background-color: #007bff;
|
|
|
border: none;
|
|
|
border-radius: 5px;
|
|
|
cursor: pointer;
|
|
|
margin-right: 5px;
|
|
|
}
|
|
|
.small-button:hover {
|
|
|
background-color: #0056b3;
|
|
|
}
|
|
|
.chat-box {
|
|
|
position: fixed;
|
|
|
bottom: 20px;
|
|
|
width: 100%;
|
|
|
left: 0;
|
|
|
padding: 20px;
|
|
|
background-color: #f1f1f1;
|
|
|
border-radius: 10px;
|
|
|
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
|
|
|
}
|
|
|
</style>
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_and_split_docs(files):
    """Load uploaded CSV/Excel files and split them into overlapping chunks.

    Each file is read into a DataFrame, rendered to plain text, wrapped in a
    LangChain ``Document`` (keeping the filename in metadata for source
    attribution), and split with a token-based recursive splitter.

    Args:
        files: iterable of uploaded file objects exposing ``.name`` and a
            file-like read interface accepted by pandas.

    Returns:
        list[Document]: split chunks for every readable file.
    """
    # The splitter is pure configuration, so build it once instead of once
    # per file (it was previously re-created inside the loop).
    splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=512,
        chunk_overlap=256,
        disallowed_special=(),
        separators=["\n\n", "\n", " "]
    )

    split_docs = []
    for file in files:
        # Only CSV and XLSX are supported; skip anything else instead of
        # falling through to an unbound ``df`` (the original raised
        # NameError for unrecognized extensions).
        if file.name.endswith('.csv'):
            df = pd.read_csv(file)
        elif file.name.endswith('.xlsx'):
            df = pd.read_excel(file)
        else:
            continue

        # Flatten the table to plain text; the index is noise for retrieval.
        text = df.to_string(index=False)
        document = Document(page_content=text, metadata={"source": file.name})
        split_docs.extend(splitter.split_documents([document]))

    return split_docs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ingest_into_vectordb(split_docs, db_path='vectorstore/db_faiss'):
    """Embed document chunks and persist them in a local FAISS index.

    Args:
        split_docs: list of LangChain ``Document`` chunks to index.
        db_path: directory to save the FAISS index to.  Defaults to the
            original hard-coded location, so existing callers are unaffected.

    Returns:
        FAISS: the in-memory vector store (also saved to ``db_path``).
    """
    # MiniLM is a small, CPU-friendly sentence embedding model.
    embeddings = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2')
    db = FAISS.from_documents(split_docs, embeddings)
    # Persist so the index can be reloaded without re-embedding.
    db.save_local(db_path)
    return db
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_conversation_chain(retriever):
    """Build a conversational RAG chain over ``retriever``.

    The chain grounds the latest user turn against the chat history
    (history-aware retrieval), then stuffs the retrieved documents into a
    QA prompt answered by a local Ollama model.  Per-session message history
    is held in an in-memory dict closed over by the returned chain, so it
    lives exactly as long as the chain object does.

    Args:
        retriever: a LangChain retriever over the ingested documents.

    Returns:
        RunnableWithMessageHistory: invoke with ``{"input": ...}`` and
        ``config={"configurable": {"session_id": ...}}``; the response dict
        exposes ``"answer"`` and ``"context"`` keys.
    """
    llm = Ollama(model="llama3.2:1b")

    # Prompt driving the history-aware retriever: answer directly from the
    # documents, without rephrasing or follow-ups.
    contextualize_q_system_prompt = (
        "Given the chat history and the latest user question, "
        "provide a response that directly addresses the user's query based on the provided documents. "
        "Do not rephrase the question or ask follow-up questions."
    )
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    history_aware_retriever = create_history_aware_retriever(
        llm, retriever, contextualize_q_prompt
    )

    # Answering prompt.  FIX: a blank-line separator is inserted before
    # "{context}" — previously the retrieved text was concatenated directly
    # onto the final sentence ("...stand-alone question.{context}"), gluing
    # the instructions and the document text together.
    system_prompt = (
        "As a personal chat assistant, provide accurate and relevant information based on the provided document in 2-3 sentences. "
        "Answer should be limited to 50 words and 2-3 sentences. Do not prompt to select answers or formulate a stand-alone question."
        "\n\n{context}"
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

    rag_chain = create_retrieval_chain(
        history_aware_retriever, question_answer_chain)

    # session_id -> ChatMessageHistory, captured by the closure below.
    store = {}

    def get_session_history(session_id: str) -> BaseChatMessageHistory:
        # Lazily create one history object per session id.
        if session_id not in store:
            store[session_id] = ChatMessageHistory()
        return store[session_id]

    conversational_rag_chain = RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )
    return conversational_rag_chain
|
|
|
|
|
|
|
|
|
def calculate_similarity_score(answer: str, context_docs: list) -> float:
    """Return the maximum cosine similarity between ``answer`` and the docs.

    Args:
        answer: generated answer text.
        context_docs: LangChain documents (objects with ``page_content``).

    Returns:
        float: highest cosine similarity (range [-1, 1]); 0.0 when
        ``context_docs`` is empty.
    """
    # Guard: encoding an empty list and calling .max() on the result raised
    # in the original implementation.
    if not context_docs:
        return 0.0

    # Loading a SentenceTransformer is expensive (model weights from disk);
    # cache it on the function object so repeated calls reuse one instance.
    model = getattr(calculate_similarity_score, "_model", None)
    if model is None:
        model = SentenceTransformer('all-MiniLM-L6-v2')
        calculate_similarity_score._model = model

    texts = [doc.page_content for doc in context_docs]
    answer_embedding = model.encode(answer, convert_to_tensor=True)
    context_embeddings = model.encode(texts, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(answer_embedding, context_embeddings)
    return similarities.max().item()
|
|
|
|
|
|
|
|
|
st.title("What can I help with⁉️")
|
|
|
|
|
|
|
|
|
uploaded_files = st.sidebar.file_uploader(
|
|
|
"Upload CSV/Excel Documents", type=["csv", "xlsx"], accept_multiple_files=True)
|
|
|
|
|
|
if uploaded_files:
    if st.sidebar.button("Process Documents"):
        # Build the full retrieval pipeline: load & chunk -> embed & index
        # -> retriever -> conversational RAG chain.
        split_docs = prepare_and_split_docs(uploaded_files)
        vector_db = ingest_into_vectordb(split_docs)
        retriever = vector_db.as_retriever()
        conversational_chain = get_conversation_chain(retriever)
        # Keep the chain in session state so it survives Streamlit reruns.
        st.session_state.conversational_chain = conversational_chain
        # Report success only once the whole pipeline is actually ready
        # (previously the message was shown before the chain was built).
        st.sidebar.success("Documents processed and vector database created!")
|
|
|
|
|
|
# Per-session transcript: a list of {"user", "bot", "context_docs"} dicts.
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

# Inject the shared CSS, then render the question input box.
st.markdown(button_style, unsafe_allow_html=True)
user_input = st.text_input("Ask a question about the dataset:",
                           key="user_input", placeholder="Type your question here...")
|
|
|
|
|
|
|
|
|
if st.button("Submit"):
|
|
|
st.markdown(button_style, unsafe_allow_html=True)
|
|
|
if user_input and 'conversational_chain' in st.session_state:
|
|
|
session_id = "abc123"
|
|
|
conversational_chain = st.session_state.conversational_chain
|
|
|
response = conversational_chain.invoke({"input": user_input}, config={
|
|
|
"configurable": {"session_id": session_id}})
|
|
|
context_docs = response.get('context', [])
|
|
|
st.session_state.chat_history.append(
|
|
|
{"user": user_input, "bot": response['answer'], "context_docs": context_docs})
|
|
|
|
|
|
|
|
|
# Render the transcript: each entry produces the user's bubble followed by
# the bot's reply, injected as raw HTML.  Iterating an empty history is a
# no-op, so no explicit guard is needed.
for entry in st.session_state.chat_history:
    st.markdown(user_template.format(msg=entry['user']),
                unsafe_allow_html=True)
    st.markdown(bot_template.format(msg=entry['bot']),
                unsafe_allow_html=True)
|