import os

import PyPDF2
import streamlit as st
from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS

load_dotenv()

st.set_page_config(page_title="💬 Chatbot")

# Token for the Hugging Face Inference API used by HuggingFaceEndpoint.
hf_api_token = os.getenv("HF_AUTH_TOKEN")
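

# Load one or more PDFs from disk with PyPDFLoader and split them into
# overlapping chunks. Not used by the Streamlit upload flow below, which
# goes through read_pdf()/split_chunks() instead.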
def load_doc(list_file_path):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1024,
        chunk_overlap=64,
    )
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits
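

# Extract the raw text of every page from the uploaded PDF file object.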
def read_pdf(uploaded_file):
    pdf_reader = PyPDF2.PdfReader(uploaded_file)
    text = ""
    for page in pdf_reader.pages:
        # extract_text() can return an empty string (or None in older PyPDF2
        # releases) for image-only pages, so guard the concatenation.
        text += page.extract_text() or ""
    return text
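

# Split the extracted text into overlapping chunks sized for embedding.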
def split_chunks(text):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=100,
    )
    return text_splitter.split_text(text)
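

# Embed the chunks with a sentence-transformers model and index them in an
# in-memory FAISS vector store.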
def create_db(splits):
    embeddings_model = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
    vectordb = FAISS.from_texts(splits, embeddings_model)
    return vectordb
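

# Build a RetrievalQA chain over the FAISS store, backed by the Hugging Face
# Inference API. The chain itself is stateless; the chat transcript shown in
# the UI is kept in st.session_state rather than in a LangChain memory object.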
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=hf_api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    retriever = vector_db.as_retriever()
    # "stuff" packs all retrieved chunks into a single prompt for the LLM.
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
    return qa_chain
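

# Convert session-state message dicts into (role, content) tuples, for chains
# that accept an explicit chat history. Not called in the current flow.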
def format_chat_history(chat_history):
    formatted_chat_history = []
    for message in chat_history:
        formatted_chat_history.append((message['role'], message['content']))
    return formatted_chat_history
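

# Answer a user question against the uploaded PDF through the RetrievalQA chain.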
def generate_response(prompt_input):
    # Rebuild the chain so the current sidebar settings (model, temperature,
    # max tokens, top_k) take effect on every query. The chain receives only
    # the query itself; earlier turns are not part of the prompt.
    qa_chain = initialize_llmchain(llm_model, temperature, max_tokens, top_k, st.session_state.vector_db)
    response = qa_chain.invoke({"query": prompt_input})
    return response["result"]
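

# Reset the conversation to the initial assistant greeting.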
def clear_chat_history():
    st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
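

# Initialize session state: chat transcript, vector store, and last uploaded file.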
if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]

if "vector_db" not in st.session_state:
    st.session_state.vector_db = None

if "uploaded_file" not in st.session_state:
    st.session_state.uploaded_file = None

st.title("QA Chatbot with PDF Upload")
st.markdown("---")
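
# Sidebar: PDF upload. Uploading a different file resets the chat and the
# vector store so the next question runs against the new document.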
with st.sidebar:
    uploaded_file = st.file_uploader("Upload a PDF", type="pdf")

if uploaded_file != st.session_state.uploaded_file:
    st.session_state.uploaded_file = uploaded_file
    clear_chat_history()
    st.session_state.vector_db = None

if uploaded_file is not None and st.session_state.vector_db is None:
    with st.spinner("Converting to Vectors..."):
        text = read_pdf(uploaded_file)
        chunks = split_chunks(text)
        st.session_state.vector_db = create_db(chunks)
    st.sidebar.markdown('<p style="color:green;">PDF processed and vector database created!</p>', unsafe_allow_html=True)
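
# Sidebar: model selection and sampling parameters consumed by initialize_llmchain().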
st.sidebar.title("Model Settings")
selected_model = st.sidebar.selectbox('Choose an LLM model', ['Llama-3-8B', 'Mistral-7B'], key='selected_model')
llm_model = 'meta-llama/Meta-Llama-3-8B-Instruct' if selected_model == 'Llama-3-8B' else 'mistralai/Mistral-7B-Instruct-v0.2'
temperature = st.sidebar.slider('Temperature', 0.1, 1.0, 0.1)
top_k = st.sidebar.slider('Top_k', 1, 10, 3)
max_tokens = st.sidebar.slider('Max Tokens', 1, 512, 256)

show_clear_button = len(st.session_state.messages) > 1
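
# Main chat area: replay the transcript, accept new input, and answer it.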
if st.session_state.vector_db is not None:
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # Disable input until the Hugging Face token is available.
    if prompt := st.chat_input(disabled=not hf_api_token):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)

    # Generate a reply whenever the last message is from the user.
    if st.session_state.messages[-1]["role"] != "assistant":
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = generate_response(prompt)
            st.write(response)
        message = {"role": "assistant", "content": response}
        st.session_state.messages.append(message)
        show_clear_button = len(st.session_state.messages) > 1
else:
    st.write("Please upload a PDF file to initialize the database.")

if show_clear_button and st.button('Clear Chat History'):
    clear_chat_history()
    st.rerun()  # replaces the deprecated st.experimental_rerun()