import streamlit as st
import pandas as pd
from langchain.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.llms import Ollama
from langchain.vectorstores import FAISS
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from sentence_transformers import SentenceTransformer, util
from langchain.schema import Document
from langchain_core.chat_history import BaseChatMessageHistory
from langchain.chains import create_history_aware_retriever
from langchain_huggingface import HuggingFaceEmbeddings
# HTML templates for rendering chat bubbles. Each template MUST contain a
# "{msg}" placeholder: the display loop fills it via str.format(msg=...).
# The originals were empty strings, so every chat message rendered as nothing.
bot_template = '''
<div style="background-color:#2b313e;color:#ffffff;padding:10px;border-radius:8px;margin:4px 0;">
    🤖 {msg}
</div>
'''
user_template = '''
<div style="background-color:#475063;color:#ffffff;padding:10px;border-radius:8px;margin:4px 0;">
    🙋 {msg}
</div>
'''
# CSS injected with st.markdown(..., unsafe_allow_html=True) to style buttons.
button_style = """
<style>
div.stButton > button {{
    border-radius: 8px;
}}
</style>
"""
# Function to prepare and split documents from CSV or Excel
def prepare_and_split_docs(files):
    """Load uploaded CSV/XLSX files and split their contents into chunks.

    Args:
        files: iterable of uploaded file-like objects; each must expose a
            ``.name`` attribute used to pick the pandas reader.

    Returns:
        list[Document]: token-based chunks (512 tokens, 256 overlap), each
        tagged with its source file name in ``metadata["source"]``.
    """
    # The splitter is pure configuration — build it once, not once per file.
    splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=512,
        chunk_overlap=256,
        disallowed_special=(),
        separators=["\n\n", "\n", " "]
    )
    split_docs = []
    for file in files:
        if file.name.endswith('.csv'):
            df = pd.read_csv(file)
        elif file.name.endswith('.xlsx'):
            df = pd.read_excel(file)
        else:
            # Previously an unsupported extension left `df` unbound (NameError
            # on the first file) or silently reused the previous file's
            # dataframe. Skip unknown types explicitly instead.
            continue
        # Serialize the whole table to plain text for splitting; this mirrors
        # the original behavior (no per-row or per-column structure kept).
        text = df.to_string(index=False)
        document = Document(page_content=text, metadata={"source": file.name})
        split_docs.extend(splitter.split_documents([document]))
    return split_docs
# Function to ingest documents into the vector database
def ingest_into_vectordb(split_docs):
    """Embed the given document chunks and persist them in a FAISS index.

    The index is written to ``vectorstore/db_faiss`` on disk and also
    returned so the caller can create a retriever without reloading it.
    """
    db_path = 'vectorstore/db_faiss'
    embedder = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2'
    )
    vector_store = FAISS.from_documents(split_docs, embedder)
    vector_store.save_local(db_path)
    return vector_store
# Function to get the conversation chain
def get_conversation_chain(retriever):
    """Assemble a history-aware conversational RAG chain.

    Args:
        retriever: a vector-store retriever over the ingested documents.

    Returns:
        RunnableWithMessageHistory: invoke with ``{"input": ...}`` and a
        ``config={"configurable": {"session_id": ...}}``; the response dict
        carries the generated text under the "answer" key.
    """
    llm = Ollama(model="llama3.2:1b")

    # Prompt used to resolve the latest question against the chat history
    # before retrieval (the "history-aware" step).
    contextualize_q_system_prompt = (
        "Given the chat history and the latest user question, "
        "provide a response that directly addresses the user's query based on the provided documents. "
        "Do not rephrase the question or ask follow-up questions."
    )
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    history_aware_retriever = create_history_aware_retriever(
        llm, retriever, contextualize_q_prompt
    )

    # Answering prompt. Fix: the original concatenated "...question." and
    # "{context}" with no separator, gluing the instructions directly onto
    # the first retrieved chunk; a blank line now separates them.
    system_prompt = (
        "As a personal chat assistant, provide accurate and relevant information based on the provided document in 2-3 sentences. "
        "Answer should be limited to 50 words and 2-3 sentences. Do not prompt to select answers or formulate a stand-alone question."
        "\n\n{context}"
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    rag_chain = create_retrieval_chain(
        history_aware_retriever, question_answer_chain)

    # In-memory per-session message store. NOTE(review): `store` is recreated
    # on every call to this function, so chat history only survives as long
    # as the returned chain object does — callers must keep the chain around
    # (the app stores it in st.session_state).
    store = {}

    def get_session_history(session_id: str) -> BaseChatMessageHistory:
        # Lazily create one ChatMessageHistory per session id.
        if session_id not in store:
            store[session_id] = ChatMessageHistory()
        return store[session_id]

    conversational_rag_chain = RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )
    return conversational_rag_chain
def calculate_similarity_score(answer: str, context_docs: list) -> float:
    """Return the maximum cosine similarity between the answer and any context doc.

    Args:
        answer: the generated answer text.
        context_docs: retrieved documents, each exposing ``page_content``.

    Returns:
        A float (cosine similarity, roughly in [-1, 1]); 0.0 when
        ``context_docs`` is empty.
    """
    if not context_docs:
        # Fix: the original called .max() on an empty similarity matrix here,
        # which raises at runtime.
        return 0.0
    # Fix: cache the embedding model on the function object — the original
    # reloaded SentenceTransformer from disk on every single call.
    model = getattr(calculate_similarity_score, "_model", None)
    if model is None:
        model = SentenceTransformer('all-MiniLM-L6-v2')
        calculate_similarity_score._model = model
    # Use a new name instead of shadowing the `context_docs` parameter.
    texts = [doc.page_content for doc in context_docs]
    answer_embedding = model.encode(answer, convert_to_tensor=True)
    context_embeddings = model.encode(texts, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(answer_embedding, context_embeddings)
    return similarities.max().item()
# ---------------------------------------------------------------------------
# Streamlit page flow. Streamlit re-executes this whole script top-to-bottom
# on every user interaction, so anything that must survive a rerun lives in
# st.session_state.
# ---------------------------------------------------------------------------
st.title("What can I help with⁉️")
# Sidebar for file upload
uploaded_files = st.sidebar.file_uploader(
    "Upload CSV/Excel Documents", type=["csv", "xlsx"], accept_multiple_files=True)
if uploaded_files:
    if st.sidebar.button("Process Documents"):
        # Build the vector index and RAG chain on each click; the chain is
        # stashed in session_state so later reruns (e.g. the Submit button)
        # can reuse it without reprocessing the files.
        split_docs = prepare_and_split_docs(uploaded_files)
        vector_db = ingest_into_vectordb(split_docs)
        retriever = vector_db.as_retriever()
        st.sidebar.success("Documents processed and vector database created!")
        # Initialize the conversation chain
        conversational_chain = get_conversation_chain(retriever)
        st.session_state.conversational_chain = conversational_chain

# Chat transcript persists across reruns via session_state.
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

# Chat input
st.markdown(button_style, unsafe_allow_html=True)
user_input = st.text_input("Ask a question about the dataset:",
                           key="user_input", placeholder="Type your question here...")

if st.button("Submit"):
    st.markdown(button_style, unsafe_allow_html=True)
    # Silently does nothing if documents were never processed (no chain yet).
    if user_input and 'conversational_chain' in st.session_state:
        # NOTE(review): one hard-coded session id means every browser session
        # shares a single chat history — confirm whether per-user ids are needed.
        session_id = "abc123"
        conversational_chain = st.session_state.conversational_chain
        response = conversational_chain.invoke({"input": user_input}, config={
            "configurable": {"session_id": session_id}})
        # Retrieved source documents, if the chain returned any.
        context_docs = response.get('context', [])
        st.session_state.chat_history.append(
            {"user": user_input, "bot": response['answer'], "context_docs": context_docs})

# Display chat history
if st.session_state.chat_history:
    for message in st.session_state.chat_history:
        st.markdown(user_template.format(
            msg=message['user']), unsafe_allow_html=True)
        st.markdown(bot_template.format(
            msg=message['bot']), unsafe_allow_html=True)