"""Streamlit RAG chat app.

Upload CSV/Excel files, index them into a FAISS vector store, and answer
questions with a history-aware retrieval chain backed by a local Ollama model.
"""

import streamlit as st
import pandas as pd
from langchain.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.llms import Ollama
from langchain.vectorstores import FAISS
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain
from sentence_transformers import SentenceTransformer, util
from langchain.schema import Document
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_huggingface import HuggingFaceEmbeddings

# HTML templates for rendering chat messages; filled via .format(msg=...).
bot_template = '''
{msg}
'''
user_template = '''
{msg}
'''
# Custom CSS injected before the chat input (currently empty).
button_style = """ """


def prepare_and_split_docs(files):
    """Read uploaded CSV/XLSX files and split their contents into chunks.

    Each dataframe is flattened to a plain-text string, wrapped in a
    `Document`, and split into 512-token chunks with 256-token overlap.

    Args:
        files: uploaded file objects exposing `.name` and a readable buffer
            (as provided by `st.file_uploader`).

    Returns:
        list[Document]: the split document chunks from all readable files.
    """
    # Hoisted out of the loop: the splitter configuration never changes.
    splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=512,
        chunk_overlap=256,
        disallowed_special=(),
        separators=["\n\n", "\n", " "],
    )
    split_docs = []
    for file in files:
        if file.name.endswith('.csv'):
            df = pd.read_csv(file)
        elif file.name.endswith('.xlsx'):
            df = pd.read_excel(file)
        else:
            # BUG FIX: the original fell through with `df` undefined
            # (NameError) for any other extension — skip such files instead.
            continue
        # Flatten the whole dataframe to text (no index column) so the
        # splitter can chunk it like any other document.
        text = df.to_string(index=False)
        document = Document(page_content=text, metadata={"source": file.name})
        split_docs.extend(splitter.split_documents([document]))
    return split_docs


def ingest_into_vectordb(split_docs):
    """Embed document chunks and persist them in a local FAISS store.

    Args:
        split_docs: list of `Document` chunks to index.

    Returns:
        The in-memory FAISS vector store (also saved to disk).
    """
    embeddings = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2')
    db = FAISS.from_documents(split_docs, embeddings)
    DB_FAISS_PATH = 'vectorstore/db_faiss'
    db.save_local(DB_FAISS_PATH)
    return db


def get_conversation_chain(retriever):
    """Build a history-aware conversational RAG chain over `retriever`.

    The chain first reformulates the user's question in light of the chat
    history, retrieves relevant chunks, then answers with a short,
    document-grounded response. Per-session histories are kept in a
    closure-local dict keyed by session id.

    Args:
        retriever: a LangChain retriever over the ingested documents.

    Returns:
        RunnableWithMessageHistory: invoke with
        {"input": ...} and config {"configurable": {"session_id": ...}}.
    """
    llm = Ollama(model="llama3.2:1b")

    contextualize_q_system_prompt = (
        "Given the chat history and the latest user question, "
        "provide a response that directly addresses the user's query based on the provided documents. "
        "Do not rephrase the question or ask follow-up questions."
    )
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    history_aware_retriever = create_history_aware_retriever(
        llm, retriever, contextualize_q_prompt
    )

    system_prompt = (
        "As a personal chat assistant, provide accurate and relevant information based on the provided document in 2-3 sentences. "
        "Answer should be limited to 50 words and 2-3 sentences. Do not prompt to select answers or formulate a stand-alone question."
        "{context}"
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    rag_chain = create_retrieval_chain(
        history_aware_retriever, question_answer_chain)

    # Per-session message stores; lives as long as this chain object does.
    store = {}

    def get_session_history(session_id: str) -> BaseChatMessageHistory:
        # Lazily create one history per session id.
        if session_id not in store:
            store[session_id] = ChatMessageHistory()
        return store[session_id]

    conversational_rag_chain = RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )
    return conversational_rag_chain


@st.cache_resource(show_spinner=False)
def _load_similarity_model():
    """Load the sentence-embedding model once and reuse it across reruns."""
    return SentenceTransformer('all-MiniLM-L6-v2')


def calculate_similarity_score(answer: str, context_docs: list) -> float:
    """Return the max cosine similarity between `answer` and any context doc.

    Args:
        answer: the generated answer text.
        context_docs: retrieved `Document` objects (uses `.page_content`).

    Returns:
        float: highest cosine similarity score in [-1, 1].
    """
    # PERF FIX: the original reloaded SentenceTransformer on every call;
    # the cached loader keeps one instance per Streamlit process.
    model = _load_similarity_model()
    texts = [doc.page_content for doc in context_docs]
    answer_embedding = model.encode(answer, convert_to_tensor=True)
    context_embeddings = model.encode(texts, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(answer_embedding, context_embeddings)
    return similarities.max().item()


st.title("What can I help with⁉️")

# Sidebar for file upload
uploaded_files = st.sidebar.file_uploader(
    "Upload CSV/Excel Documents", type=["csv", "xlsx"], accept_multiple_files=True)

if uploaded_files:
    if st.sidebar.button("Process Documents"):
        split_docs = prepare_and_split_docs(uploaded_files)
        vector_db = ingest_into_vectordb(split_docs)
        retriever = vector_db.as_retriever()
        st.sidebar.success("Documents processed and vector database created!")
        # Keep the chain in session state so it survives Streamlit reruns.
        st.session_state.conversational_chain = get_conversation_chain(retriever)

if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

# Chat input
st.markdown(button_style, unsafe_allow_html=True)
user_input = st.text_input("Ask a question about the dataset:",
                           key="user_input",
                           placeholder="Type your question here...")

if st.button("Submit"):
    if user_input and 'conversational_chain' in st.session_state:
        # NOTE(review): single hard-coded session id — all users of this app
        # instance share one chat history; confirm this is intended.
        session_id = "abc123"
        conversational_chain = st.session_state.conversational_chain
        response = conversational_chain.invoke(
            {"input": user_input},
            config={"configurable": {"session_id": session_id}})
        context_docs = response.get('context', [])
        st.session_state.chat_history.append(
            {"user": user_input,
             "bot": response['answer'],
             "context_docs": context_docs})
    elif user_input:
        # BUG FIX: the original silently ignored questions asked before the
        # documents were processed — tell the user what to do instead.
        st.warning("Please process documents before asking questions.")

# Display chat history
if st.session_state.chat_history:
    for message in st.session_state.chat_history:
        st.markdown(user_template.format(msg=message['user']),
                    unsafe_allow_html=True)
        st.markdown(bot_template.format(msg=message['bot']),
                    unsafe_allow_html=True)