File size: 7,455 Bytes
3037327
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import gradio as gr
import os
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings 
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint

# Hugging Face API token read from the environment (None when HF_TOKEN is unset).
api_token = os.environ.get("HF_TOKEN")

# Available LLMs: full Hub repo ids, plus the short names shown in the UI radio.
list_llm = [
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
]
list_llm_simple = list(map(os.path.basename, list_llm))

# Load and split PDF document
def load_doc(list_file_path, chunk_size=1024, chunk_overlap=64):
    """Load PDF files and split their pages into overlapping text chunks.

    Args:
        list_file_path: Paths of the PDF files to load, processed in order.
        chunk_size: Maximum chunk size passed to RecursiveCharacterTextSplitter
            (default 1024, the value previously hard-coded).
        chunk_overlap: Overlap between consecutive chunks passed to the splitter
            (default 64).

    Returns:
        The list of split documents; empty when ``list_file_path`` is empty.
    """
    pages = []
    for file_path in list_file_path:
        pages.extend(PyPDFLoader(file_path).load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_documents(pages)

# Create vector database
def create_db(splits):
    """Build a FAISS vector store from document chunks.

    Args:
        splits: Document chunks, e.g. the output of ``load_doc``.

    Returns:
        The FAISS store, embedded with a default-configured HuggingFaceEmbeddings.
    """
    embedding_model = HuggingFaceEmbeddings()
    vector_store = FAISS.from_documents(splits, embedding_model)
    return vector_store

# Initialize LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Create a conversational retrieval chain over *vector_db*.

    Args:
        llm_model: Hugging Face Hub repo id served through ``HuggingFaceEndpoint``.
        temperature: Sampling temperature for the LLM.
        max_tokens: Forwarded to the endpoint as ``max_new_tokens``.
        top_k: Forwarded to the endpoint as its ``top_k`` sampling parameter.
            It does not change how many documents the retriever returns.
        vector_db: Vector store whose ``as_retriever()`` supplies the context.
        progress: Gradio progress tracker. Accepted but not used here.

    Returns:
        A ``ConversationalRetrievalChain`` using the "stuff" strategy. It has a
        ``ConversationBufferMemory`` stored under ``"chat_history"``, and its
        output includes the source documents.
    """
    endpoint_settings = {
        "repo_id": llm_model,
        "huggingfacehub_api_token": api_token,
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        "top_k": top_k,
    }
    llm = HuggingFaceEndpoint(**endpoint_settings)

    # output_key must name the chain's answer field, because the chain also
    # returns source documents and the memory has to know which value to store.
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )

    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        memory=chat_memory,
        return_source_documents=True,
        verbose=False,
    )

# Initialize database
def initialize_database(list_file_obj, progress=gr.Progress()):
    """Build the FAISS vector database from the uploaded PDF files.

    Args:
        list_file_obj: Value of the ``gr.Files`` component. This is a list of
            uploaded file objects exposing ``.name`` (their path), or ``None``
            when nothing has been uploaded.
        progress: Gradio progress tracker. Accepted but not used here.

    Returns:
        ``(vector_db, status_message)``. ``vector_db`` is ``None`` when there
        was nothing to index; the message then tells the user why.
    """
    # gr.Files yields None (not []) when no file has been uploaded.
    list_file_path = [file.name for file in (list_file_obj or []) if file is not None]
    if not list_file_path:
        return None, "⚠️ Please upload at least one PDF file first."

    doc_splits = load_doc(list_file_path)
    if not doc_splits:
        # FAISS.from_documents needs at least one document. An empty result is
        # possible, e.g. for scanned PDFs without a text layer.
        return None, "⚠️ No text could be extracted from the uploaded PDF file(s)."

    vector_db = create_db(doc_splits)
    return vector_db, "✅ Vector database created successfully!"

# Initialize LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Create the QA chain for the model selected in the UI.

    Args:
        llm_option: Index into ``list_llm``. The radio button uses ``type="index"``.
        llm_temperature: Sampling temperature.
        max_tokens: Maximum number of new tokens to generate.
        top_k: LLM top-k sampling parameter.
        vector_db: Vector store created in Step 1, or ``None`` if Step 1 has not
            been completed.
        progress: Gradio progress tracker, passed through to ``initialize_llmchain``.

    Returns:
        ``(qa_chain, status_message)``. ``qa_chain`` is ``None`` when there is
        no vector database to retrieve from yet.
    """
    if vector_db is None:
        # Without this guard, initialize_llmchain would fail on
        # vector_db.as_retriever() with an AttributeError.
        return None, "⚠️ Please create the vector database (Step 1) first."

    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "✅ Chatbot initialized. Ready to assist!"

# Format chat history for better readability
def format_chat_history(message, chat_history):
    """Render each (user, assistant) pair of *chat_history* as one text entry.

    *message* is accepted for interface compatibility but is not used.
    Returns one "User: ...\\nAssistant: ..." string per pair, in order.
    """
    rendered = []
    for user_turn, assistant_turn in chat_history:
        rendered.append(f"User: {user_turn}\nAssistant: {assistant_turn}")
    return rendered

# Handle conversation
def conversation(qa_chain, message, history):
    """Answer *message* with the RAG chain and refresh the chat and source panels.

    Args:
        qa_chain: Chain returned by ``initialize_LLM``. It is ``None`` until
            the chatbot has been initialized.
        message: The user's question.
        history: Current chatbot history as ``(user, assistant)`` pairs,
            or ``None``.

    Returns:
        A flat tuple matching the event outputs:
        - the chain;
        - an update that clears the input box;
        - the new history;
        - ``(text, page)`` for each of the three source slots.
        Slots with no retrieved document are filled with ``""`` and ``0``.

    Raises:
        gr.Error: If the chatbot has not been initialized yet.
    """
    num_source_slots = 3  # the UI has exactly three Reference/Page pairs

    if qa_chain is None:
        raise gr.Error("Please initialize the chatbot (Step 2) before asking a question.")

    history = history or []
    formatted_chat_history = format_chat_history(message, history)
    # NOTE(review): the chain is built with a ConversationBufferMemory keyed on
    # "chat_history" (see initialize_llmchain), which may take precedence over
    # the value passed here. Confirm which history the chain actually uses.
    response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})

    answer = response["answer"]
    # Some prompts echo "Helpful Answer:"; keep only the text after the last marker.
    if "Helpful Answer:" in answer:
        answer = answer.split("Helpful Answer:")[-1].strip()

    # Gradio requires one value per output component. The retriever can return
    # fewer than three documents (e.g. a short PDF), so pad the unused slots.
    source_fields = []
    for src in response["source_documents"][:num_source_slots]:
        page = src.metadata.get("page")
        # Displayed page = stored page index + 1 (presumably 0-based in the loader).
        source_fields.extend([src.page_content.strip(), page + 1 if page is not None else 0])
    source_fields.extend(["", 0] * (num_source_slots - len(source_fields) // 2))

    new_history = history + [(message, answer)]
    return (qa_chain, gr.update(value=""), new_history, *source_fields)

# File upload handling
def upload_file(file_obj):
    """Return the ``.name`` attribute (the file path) of each uploaded file object, in order."""
    paths = []
    for uploaded in file_obj:
        paths.append(uploaded.name)
    return paths

# Gradio UI
def demo():
    """Build the DocuQuery AI Gradio interface and launch it with queuing enabled.

    The left column covers Step 1 (upload PDFs and build the vector database)
    and Step 2 (choose and configure the LLM, then initialize the chain). The
    right column covers Step 3: the chat, plus three Reference/Page slots that
    show the retrieved context.
    """
    with gr.Blocks() as app:
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.HTML("""

        <div style="background-color: #101010; padding: 15px; border-radius: 0px;">

            <h1 style="text-align: center; color: white;">📄 DocuQuery AI</h1>

        </div>

        <div style="background-color: #101010; padding: 15px; border-radius: 0px; margin-bottom: 20px;">

            <p style="color: white; font-size: 16px; text-align: center; font-weight: normal;">

                This chatbot enables you to query your PDF documents using Retrieval-Augmented Generation (RAG).<br>  

                🛑 Please refrain from uploading confidential documents! <br>

                This is only for education purpose.

            </p>

        </div>

        """)

        with gr.Row():
            with gr.Column(scale=86):
                gr.Markdown("### Step 1: Upload PDF files and Initialize RAG Pipeline")
                document = gr.Files(height=300, file_count="multiple", file_types=[".pdf"], interactive=True, label="Upload PDF Files")
                db_btn = gr.Button("Create Vector Database")
                db_progress = gr.Textbox(value="⏳ Waiting for input...", show_label=False)

                gr.Markdown("### Step 2: Configure Large Language Model (LLM)")
                llm_btn = gr.Radio(list_llm_simple, label="Select LLM", value=list_llm_simple[0], type="index")

                with gr.Accordion("LLM Settings (Optional)", open=False):
                    slider_temperature = gr.Slider(0.01, 1.0, 0.5, 0.1, label="Temperature")
                    slider_maxtokens = gr.Slider(128, 4096, 2048, 128, label="Max Tokens")
                    slider_topk = gr.Slider(1, 10, 3, 1, label="Top-k")
                qachain_btn = gr.Button("Initialize Chatbot")
                llm_progress = gr.Textbox(value="⏳ Waiting for LLM setup...", show_label=False)

            with gr.Column(scale=200):
                gr.Markdown("### Step 3: Chat with Your Document")
                chatbot = gr.Chatbot(height=505)

                with gr.Accordion("Context from Source Document", open=False):
                    doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                    source1_page = gr.Number(label="Page", scale=1)
                    doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                    source2_page = gr.Number(label="Page", scale=1)
                    doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                    source3_page = gr.Number(label="Page", scale=1)

                msg = gr.Textbox(placeholder="Type your question here...", container=True)
                submit_btn = gr.Button("Submit")
                clear_btn = gr.ClearButton([msg, chatbot], value="Clear Chat")

        # Event bindings
        source_outputs = [doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page]
        chat_inputs = [qa_chain, msg, chatbot]
        chat_outputs = [qa_chain, msg, chatbot, *source_outputs]

        def _reset_chat(chain):
            # gr.ClearButton only wipes the visible widgets. Clear the chain's
            # ConversationBufferMemory as well; otherwise the "cleared"
            # conversation would still be fed to the LLM.
            if chain is not None and getattr(chain, "memory", None) is not None:
                chain.memory.clear()
            return [chain, None, "", 0, "", 0, "", 0]

        db_btn.click(initialize_database, [document], [vector_db, db_progress])
        qachain_btn.click(initialize_LLM, [llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], [qa_chain, llm_progress])
        msg.submit(conversation, chat_inputs, chat_outputs)
        submit_btn.click(conversation, chat_inputs, chat_outputs)
        clear_btn.click(_reset_chat, [qa_chain], [qa_chain, chatbot, *source_outputs])

    app.queue().launch(debug=True)

if __name__ == "__main__":
    demo()