# app.py — Streamlit PDF question-answering chatbot
import streamlit as st
import replicate
import os
import PyPDF2
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain, RetrievalQA
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint
from dotenv import load_dotenv
# Load environment variables from a local .env file (if present).
load_dotenv()
st.set_page_config(page_title="💬 Chatbot")
# API tokens read from the environment.
# NOTE(review): replicate_api only gates the chat input widget below;
# generation itself goes through the Hugging Face endpoint using hf_api_token
# — confirm the Replicate token is still needed.
replicate_api = os.getenv("REPLICATE_API_TOKEN")
hf_api_token = os.getenv("HF_AUTH_TOKEN")
def load_doc(list_file_path):
    """Load PDFs from filesystem paths and split their pages into chunks.

    Each path is read with PyPDFLoader, all pages are gathered, and the
    combined page list is split into 1024-character chunks with a
    64-character overlap.
    """
    all_pages = []
    for path in list_file_path:
        all_pages.extend(PyPDFLoader(path).load())
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1024,
        chunk_overlap=64,
    )
    return splitter.split_documents(all_pages)
def read_pdf(uploaded_file):
    """Extract plain text from an uploaded PDF file-like object.

    Parameters
    ----------
    uploaded_file : file-like
        Binary stream of the PDF (e.g. a Streamlit UploadedFile).

    Returns
    -------
    str
        Concatenated text of all pages. Pages with no extractable text
        (e.g. scanned/image-only pages) contribute an empty string.
    """
    pdf_reader = PyPDF2.PdfReader(uploaded_file)
    parts = []
    for page in pdf_reader.pages:
        # extract_text() may return None for pages without a text layer;
        # guard so concatenation cannot raise TypeError.
        parts.append(page.extract_text() or "")
    # join() instead of repeated += avoids quadratic string building.
    return "".join(parts)
def split_chunks(docs):
    """Split a raw text string into ~1000-char chunks with 100-char overlap."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=100,
    )
    return splitter.split_text(docs)
def create_db(splits):
    """Embed text chunks with MiniLM and index them in an in-memory FAISS store."""
    embedder = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2'
    )
    return FAISS.from_texts(splits, embedder)
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    """Build a RetrievalQA chain backed by a Hugging Face Inference Endpoint.

    Parameters
    ----------
    llm_model : str
        Hugging Face repo id of the chat model.
    temperature : float
        Sampling temperature forwarded to the endpoint.
    max_tokens : int
        Maximum number of new tokens to generate.
    top_k : int
        Top-k sampling parameter forwarded to the endpoint.
    vector_db :
        FAISS vector store used as the retriever.

    Returns
    -------
    RetrievalQA
        Chain answering calls of the form ``chain({"query": ...})`` with a
        dict containing "result" and "source_documents".
    """
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=hf_api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    # NOTE: a ConversationBufferMemory used to be created here on every call
    # but was never attached to the chain — RetrievalQA is stateless — so the
    # dead allocation (and the commented-out ConversationalRetrievalChain
    # variant) has been removed.
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vector_db.as_retriever(),
        return_source_documents=True,
    )
    return qa_chain
def format_chat_history(chat_history):
    """Convert message dicts into (role, content) tuples.

    Parameters
    ----------
    chat_history : list[dict]
        Messages with 'role' and 'content' keys.

    Returns
    -------
    list[tuple]
        One (role, content) pair per message, in original order.
    """
    return [(msg['role'], msg['content']) for msg in chat_history]
def generate_llama2_response(prompt_input):
    """Answer a user prompt against the uploaded PDF via the RetrievalQA chain.

    Reads the module-level model settings (llm_model, temperature,
    max_tokens, top_k) and the vector store kept in st.session_state.

    Parameters
    ----------
    prompt_input : str
        The user's current question.

    Returns
    -------
    str
        The chain's "result" answer text.
    """
    # NOTE: an earlier version assembled the whole chat transcript into a
    # prompt string here, but it was never passed to the chain — RetrievalQA
    # only receives the current query — so that dead loop was removed.
    qa_chain = initialize_llmchain(
        llm_model, temperature, max_tokens, top_k, st.session_state.vector_db
    )
    response = qa_chain({"query": prompt_input})
    return response["result"]
def clear_chat_history():
    """Reset the conversation to the initial assistant greeting."""
    st.session_state["messages"] = [
        {"role": "assistant", "content": "How may I assist you today?"}
    ]
# Initialise per-session state on first run.
# Membership testing directly on st.session_state is the idiomatic form;
# .keys() was redundant.
if "messages" not in st.session_state:
    # Store LLM generated responses (seeded with the assistant greeting).
    st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
if "vector_db" not in st.session_state:
    # FAISS index built from the uploaded PDF; None until a file is processed.
    st.session_state.vector_db = None
if "uploaded_file" not in st.session_state:
    # Last-seen upload, used to detect file changes/removal.
    st.session_state.uploaded_file = None
# Main app layout
st.title("QA Chatbot with PDF Upload")
st.markdown("---")
# Sidebar for model selection and parameters
with st.sidebar:
    uploaded_file = st.sidebar.file_uploader("Upload a PDF", type="pdf")
    # Check if a new file has been uploaded or the file has been removed;
    # either way, reset the conversation and drop the stale vector index.
    if uploaded_file != st.session_state.uploaded_file:
        st.session_state.uploaded_file = uploaded_file
        clear_chat_history()
        st.session_state.vector_db = None
    # Build the vector index exactly once per uploaded file.
    if uploaded_file is not None and st.session_state.vector_db is None:
        with st.spinner("Converting to Vectors..."):
            text = read_pdf(uploaded_file)
            chunks = split_chunks(docs=text)
            st.session_state.vector_db = create_db(chunks)
            #st.sidebar.write("PDF processed and vector database created!")
            st.sidebar.markdown('<p style="color:green;">PDF processed and vector database created!</p>', unsafe_allow_html=True)
    st.sidebar.title("Model Settings")
    # Map the friendly dropdown label to the Hugging Face repo id used by
    # initialize_llmchain.
    selected_model = st.sidebar.selectbox('Choose a LLM model', ['Llama-3-8B', 'Mistral-7B'], key='selected_model')
    llm_model = 'meta-llama/Meta-Llama-3-8B-Instruct' if selected_model == 'Llama-3-8B' else 'mistralai/Mistral-7B-Instruct-v0.2'
    # Sampling parameters (min, max, default) consumed by generate_llama2_response.
    temperature = st.sidebar.slider('Temperature', 0.1, 1.0, 0.1)
    top_k = st.sidebar.slider('Top_k', 1, 10, 3)
    max_tokens = st.sidebar.slider('Max Tokens', 1, 512, 256)
    #st.sidebar.button('Clear Chat History', on_click=clear_chat_history)
    # Show the clear button only once the user has sent at least one message.
    show_clear_button = len(st.session_state.messages) > 1
# Chat area: only active once a PDF has been indexed.
if st.session_state.vector_db is not None:
    # Replay the stored conversation so it survives Streamlit reruns.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
    # User-provided prompt
    # NOTE(review): input is gated on the Replicate token, but answers come
    # from the Hugging Face endpoint (hf_api_token) — confirm which token is
    # actually required.
    if prompt := st.chat_input(disabled=not replicate_api):
        #if button_clicked and prompt:
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)
    # Generate a new response if last message is not from assistant
    if st.session_state.messages[-1]["role"] != "assistant":
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = generate_llama2_response(prompt)
                #st.write(response)
                placeholder = st.empty()
                full_response = ''
                # Re-render the placeholder as characters accumulate,
                # giving a typing effect.
                for item in response:
                    full_response += item
                    #placeholder.markdown(full_response)
                    placeholder.markdown(full_response)
        message = {"role": "assistant", "content": full_response}
        #st.write("-------")
        #st.write(message)
        st.session_state.messages.append(message)
        show_clear_button = len(st.session_state.messages) > 1
else:
    st.write("Please upload a PDF file to initialize the database.")
# Manual conversation reset, offered only after the greeting has been replied to.
if show_clear_button and st.button('Clear Chat History'):
    clear_chat_history()
    # NOTE(review): st.experimental_rerun was removed in newer Streamlit
    # releases in favour of st.rerun — confirm the pinned Streamlit version.
    st.experimental_rerun()  # This line ensures the page reruns to reflect the changes