File size: 6,631 Bytes
71e61b7
3842010
 
 
 
71e61b7
b5d7000
71e61b7
b5d7000
3842010
 
 
1dbbef2
3842010
71e61b7
 
 
 
 
 
b5d7000
 
 
71e61b7
 
 
 
 
 
3842010
b5d7000
71e61b7
b5d7000
71e61b7
 
b5d7000
71e61b7
 
 
 
b5d7000
71e61b7
 
3842010
71e61b7
 
 
b5d7000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71e61b7
 
 
 
b5d7000
 
 
71e61b7
 
b5d7000
71e61b7
b5d7000
 
71e61b7
b5d7000
71e61b7
b5d7000
71e61b7
b5d7000
 
 
71e61b7
b5d7000
 
 
 
 
 
 
 
71e61b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc5c6b1
3842010
71e61b7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import streamlit as st
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from pathlib import Path
import chromadb
from unidecode import unidecode
import re
import os

# List of available LLMs
# Hugging Face Hub repo IDs offered in the model-selection dropdown.
# NOTE: some of these are rejected at chain-init time as too large for the
# free inference endpoint (see initialize_llmchain).
list_llm = [
    "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1",
    "google/gemma-7b-it", "google/gemma-2b-it", "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1",
    "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2", "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct",
    "tiiuae/falcon-7b-instruct", "google/flan-t5-xxl"
]
# Short display names for the UI: the repo ID with the org prefix stripped.
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

# Load and split PDF document
def load_doc(file_paths, chunk_size, chunk_overlap):
    """Load every PDF in *file_paths* and split its pages into overlapping text chunks.

    Returns the list of chunk Documents produced by the recursive splitter.
    """
    all_pages = []
    for path in file_paths:
        all_pages.extend(PyPDFLoader(path).load())
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(all_pages)

# Create vector database
def create_db(docs, collection_name):
    """Build an in-memory (ephemeral) Chroma vector store over *docs*.

    Uses the default HuggingFace sentence-embedding model; the store lives
    only for the lifetime of the process.
    """
    return Chroma.from_documents(
        documents=docs,
        embedding=HuggingFaceEmbeddings(),
        client=chromadb.EphemeralClient(),
        collection_name=collection_name,
    )

# Initialize LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    """Build a ConversationalRetrievalChain over *vector_db* backed by a HF endpoint LLM.

    Raises ValueError up front for models known to exceed the free
    inference tier, before any remote call is attempted.
    """
    oversized_models = {
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "HuggingFaceH4/zephyr-7b-gemma-v0.1",
        "mosaicml/mpt-7b-instruct",
    }
    if llm_model in oversized_models:
        raise ValueError("LLM model is too large to be loaded automatically on free inference endpoint")

    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )

    # Buffer memory keyed so the chain threads prior turns into each query.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )

# Generate collection name for vector database
def create_collection_name(filepath):
    """Derive a Chroma-safe collection name from *filepath*.

    The name is the file stem transliterated to ASCII, with runs of
    non-alphanumerics collapsed to '-', clipped to 50 chars, padded to at
    least 3 chars, and forced to start/end with an alphanumeric character
    (Chroma's collection-name constraints).
    """
    name = unidecode(Path(filepath).stem)
    name = re.sub('[^A-Za-z0-9]+', '-', name)[:50]
    if len(name) < 3:
        name += 'xyz'
    if not name[0].isalnum():
        name = 'A' + name[1:]
    if not name[-1].isalnum():
        name = name[:-1] + 'Z'
    return name

# Initialize database
def initialize_database(files, chunk_size, chunk_overlap):
    """Build the vector database from the uploaded PDF files.

    Bug fix: Streamlit's UploadedFile is an in-memory buffer whose ``.name``
    is only the original filename, not a readable path, so passing it to
    PyPDFLoader fails. Each upload is therefore persisted to a temporary
    file first; objects that already expose a real path fall through as-is.

    Returns (vector_db, collection_name, status_message).
    """
    import tempfile  # local import: only needed for in-memory uploads

    file_paths = []
    for file in files:
        if hasattr(file, "getbuffer"):
            # Streamlit UploadedFile: write the bytes to a temp file on disk.
            suffix = os.path.splitext(file.name)[1] or ".pdf"
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                tmp.write(file.getbuffer())
                file_paths.append(tmp.name)
        else:
            # Already path-backed (e.g. a tempfile wrapper from another frontend).
            file_paths.append(file.name)

    # Name the collection after the first document's original filename, not
    # the randomized temp path.
    collection_name = create_collection_name(files[0].name)
    doc_splits = load_doc(file_paths, chunk_size, chunk_overlap)
    vector_db = create_db(doc_splits, collection_name)
    return vector_db, collection_name, "Complete!"

# Initialize LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    """Resolve the selected model index to its repo ID and build the QA chain.

    Returns (qa_chain, status_message).
    """
    selected_model = list_llm[llm_option]
    chain = initialize_llmchain(
        selected_model, llm_temperature, max_tokens, top_k, vector_db
    )
    return chain, "Complete!"

# Format chat history
def format_chat_history(message, chat_history):
    """Render past (user, assistant) turn pairs as plain-text transcript lines.

    *message* (the pending user input) is accepted for interface
    compatibility but does not appear in the output.
    """
    transcript = []
    for user_turn, assistant_turn in chat_history:
        transcript.append(f"User: {user_turn}\nAssistant: {assistant_turn}")
    return transcript

# Handle conversation
def conversation(qa_chain, message, history):
    """Run one QA turn through *qa_chain* and return UI-ready results.

    Bug fix: the original indexed source_documents[0..2] and
    metadata["page"] unconditionally, crashing with IndexError/KeyError
    when the retriever returned fewer than three documents or a document
    lacked page metadata. Sources are now padded to exactly three entries
    (empty text, page 0) and a missing page maps to 0.

    Returns (qa_chain, "", new_history,
             source1, page1, source2, page2, source3, page3).
    """
    # Render prior turns as plain text (same shape format_chat_history produces).
    formatted_chat_history = [
        f"User: {user_turn}\nAssistant: {assistant_turn}"
        for user_turn, assistant_turn in history
    ]
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})

    answer = response["answer"]
    # Some models echo the prompt scaffold; keep only text after the marker.
    if "Helpful Answer:" in answer:
        answer = answer.split("Helpful Answer:")[-1]

    # Extract up to three sources; pad so the 9-tuple shape is always stable.
    sources, pages = [], []
    for doc in response["source_documents"][:3]:
        sources.append(doc.page_content.strip())
        pages.append(doc.metadata.get("page", -1) + 1)  # 0-based -> 1-based
    while len(sources) < 3:
        sources.append("")
        pages.append(0)

    new_history = history + [(message, answer)]
    return (qa_chain, "", new_history,
            sources[0], pages[0], sources[1], pages[1], sources[2], pages[2])

# Streamlit app
def main():
    """Streamlit UI: upload PDFs, configure the LLM, and chat over the documents.

    Bug fix: Streamlit re-runs this script top-to-bottom on every widget
    interaction, so the original ``chat_history = []`` wiped the history
    each turn, and the vector DB / QA chain were rebuilt on every rerun.
    State now lives in ``st.session_state``, keyed by the settings that
    produced it, and is only rebuilt when those settings change.
    """
    st.title("PDF-based Chatbot")
    st.write("Upload your PDF documents and interact with the chatbot to get insights from your PDFs.")

    uploaded_files = st.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)
    if uploaded_files:
        chunk_size = st.slider("Chunk Size", 100, 1000, 600)
        chunk_overlap = st.slider("Chunk Overlap", 10, 200, 40)

        # Rebuild the vector DB only when the uploads or chunking change.
        db_key = (tuple(f.name for f in uploaded_files), chunk_size, chunk_overlap)
        if st.session_state.get("db_key") != db_key:
            vector_db, collection_name, db_status = initialize_database(
                uploaded_files, chunk_size, chunk_overlap
            )
            st.session_state["db_key"] = db_key
            st.session_state["vector_db"] = vector_db
            st.session_state["db_status"] = db_status
            st.session_state["qa_key"] = None  # force chain rebuild on new DB
        st.write(f"Vector Database Initialized: {st.session_state['db_status']}")

        llm_option = st.selectbox("Select LLM Model", options=list_llm_simple)
        llm_temperature = st.slider("Temperature", 0.01, 1.0, 0.7)
        max_tokens = st.slider("Max Tokens", 224, 4096, 1024)
        top_k = st.slider("Top-K Samples", 1, 10, 3)

        # Rebuild the QA chain only when the LLM settings (or DB) change.
        qa_key = (llm_option, llm_temperature, max_tokens, top_k, db_key)
        if st.session_state.get("qa_key") != qa_key:
            qa_chain, llm_status = initialize_LLM(
                list_llm_simple.index(llm_option), llm_temperature, max_tokens,
                top_k, st.session_state["vector_db"],
            )
            st.session_state["qa_key"] = qa_key
            st.session_state["qa_chain"] = qa_chain
            st.session_state["llm_status"] = llm_status
        st.write(f"QA Chain Initialized: {st.session_state['llm_status']}")

        st.write("Chat with the bot:")
        chat_history = st.session_state.setdefault("chat_history", [])
        user_message = st.text_input("Your Message:")
        if st.button("Submit") and user_message:
            (_, _, chat_history,
             doc_source1, source1_page,
             doc_source2, source2_page,
             doc_source3, source3_page) = conversation(
                st.session_state["qa_chain"], user_message, chat_history
            )
            st.session_state["chat_history"] = chat_history  # persist across reruns
            st.write(f"**Bot's Response:** {chat_history[-1][1]}")
            st.write(f"**Reference 1:** {doc_source1} (Page {source1_page})")
            st.write(f"**Reference 2:** {doc_source2} (Page {source2_page})")
            st.write(f"**Reference 3:** {doc_source3} (Page {source3_page})")

if __name__ == "__main__":
    main()