Chetan3110 committed on
Commit
3037327
·
verified ·
1 Parent(s): 464dffa

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +149 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from langchain_community.vectorstores import FAISS
4
+ from langchain_community.document_loaders import PyPDFLoader
5
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain.chains import ConversationalRetrievalChain
7
+ from langchain_community.embeddings import HuggingFaceEmbeddings
8
+ from langchain.memory import ConversationBufferMemory
9
+ from langchain_community.llms import HuggingFaceEndpoint
10
+
11
+ api_token = os.getenv("HF_TOKEN")
12
+
13
+ # Available LLMs
14
+ list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
15
+ list_llm_simple = [os.path.basename(llm) for llm in list_llm]
16
+
17
+ # Load and split PDF document
18
+ def load_doc(list_file_path):
19
+ loaders = [PyPDFLoader(file_path) for file_path in list_file_path]
20
+ pages = [page for loader in loaders for page in loader.load()]
21
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
22
+ return text_splitter.split_documents(pages)
23
+
24
+ # Create vector database
25
+ def create_db(splits):
26
+ embeddings = HuggingFaceEmbeddings()
27
+ return FAISS.from_documents(splits, embeddings)
28
+
29
+ # Initialize LLM chain
30
+ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
31
+ llm = HuggingFaceEndpoint(
32
+ repo_id=llm_model,
33
+ huggingfacehub_api_token=api_token,
34
+ temperature=temperature,
35
+ max_new_tokens=max_tokens,
36
+ top_k=top_k,
37
+ )
38
+
39
+ memory = ConversationBufferMemory(
40
+ memory_key="chat_history",
41
+ output_key="answer",
42
+ return_messages=True,
43
+ )
44
+
45
+ retriever = vector_db.as_retriever()
46
+ return ConversationalRetrievalChain.from_llm(
47
+ llm,
48
+ retriever=retriever,
49
+ chain_type="stuff",
50
+ memory=memory,
51
+ return_source_documents=True,
52
+ verbose=False,
53
+ )
54
+
55
+ # Initialize database
56
+ def initialize_database(list_file_obj, progress=gr.Progress()):
57
+ list_file_path = [file.name for file in list_file_obj if file is not None]
58
+ doc_splits = load_doc(list_file_path)
59
+ vector_db = create_db(doc_splits)
60
+ return vector_db, "✅ Vector database created successfully!"
61
+
62
+ # Initialize LLM
63
+ def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
64
+ llm_name = list_llm[llm_option]
65
+ qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
66
+ return qa_chain, "✅ Chatbot initialized. Ready to assist!"
67
+
68
+ # Format chat history for better readability
69
+ def format_chat_history(message, chat_history):
70
+ return [f"User: {user_message}\nAssistant: {bot_message}" for user_message, bot_message in chat_history]
71
+
72
+ # Handle conversation
73
+ def conversation(qa_chain, message, history):
74
+ formatted_chat_history = format_chat_history(message, history)
75
+ response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
76
+ response_answer = response["answer"].split("Helpful Answer:")[-1].strip() if "Helpful Answer:" in response["answer"] else response["answer"]
77
+ response_sources = response["source_documents"]
78
+
79
+ # Extract sources with their pages
80
+ sources = [(src.page_content.strip(), src.metadata["page"] + 1) for src in response_sources[:3]]
81
+ new_history = history + [(message, response_answer)]
82
+ return qa_chain, gr.update(value=""), new_history, *(item for sublist in sources for item in sublist)
83
+
84
+ # File upload handling
85
+ def upload_file(file_obj):
86
+ return [file.name for file in file_obj]
87
+
88
+ # Gradio UI
89
+ def demo():
90
+ with gr.Blocks() as demo:
91
+ vector_db = gr.State()
92
+ qa_chain = gr.State()
93
+ gr.HTML("""
94
+ <div style="background-color: #101010; padding: 15px; border-radius: 0px;">
95
+ <h1 style="text-align: center; color: white;">📄 DocuQuery AI</h1>
96
+ </div>
97
+ <div style="background-color: #101010; padding: 15px; border-radius: 0px; margin-bottom: 20px;">
98
+ <p style="color: white; font-size: 16px; text-align: center; font-weight: normal;">
99
+ This chatbot enables you to query your PDF documents using Retrieval-Augmented Generation (RAG).<br>
100
+ 🛑 Please refrain from uploading confidential documents! <br>
101
+ This is only for education purpose.
102
+ </p>
103
+ </div>
104
+ """)
105
+
106
+ with gr.Row():
107
+ with gr.Column(scale=86):
108
+ gr.Markdown("### Step 1: Upload PDF files and Initialize RAG Pipeline")
109
+ document = gr.Files(height=300, file_count="multiple", file_types=[".pdf"], interactive=True, label="Upload PDF Files")
110
+ db_btn = gr.Button("Create Vector Database")
111
+ db_progress = gr.Textbox(value="⏳ Waiting for input...", show_label=False)
112
+
113
+ gr.Markdown("### Step 2: Configure Large Language Model (LLM)")
114
+ llm_btn = gr.Radio(list_llm_simple, label="Select LLM", value=list_llm_simple[0], type="index")
115
+
116
+ with gr.Accordion("LLM Settings (Optional)", open=False):
117
+ slider_temperature = gr.Slider(0.01, 1.0, 0.5, 0.1, label="Temperature")
118
+ slider_maxtokens = gr.Slider(128, 4096, 2048, 128, label="Max Tokens")
119
+ slider_topk = gr.Slider(1, 10, 3, 1, label="Top-k")
120
+ qachain_btn = gr.Button("Initialize Chatbot")
121
+ llm_progress = gr.Textbox(value="⏳ Waiting for LLM setup...", show_label=False)
122
+
123
+ with gr.Column(scale=200):
124
+ gr.Markdown("### Step 3: Chat with Your Document")
125
+ chatbot = gr.Chatbot(height=505)
126
+
127
+ with gr.Accordion("Context from Source Document", open=False):
128
+ doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
129
+ source1_page = gr.Number(label="Page", scale=1)
130
+ doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
131
+ source2_page = gr.Number(label="Page", scale=1)
132
+ doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
133
+ source3_page = gr.Number(label="Page", scale=1)
134
+
135
+ msg = gr.Textbox(placeholder="Type your question here...", container=True)
136
+ submit_btn = gr.Button("Submit")
137
+ clear_btn = gr.ClearButton([msg, chatbot], value="Clear Chat")
138
+
139
+ # Event bindings
140
+ db_btn.click(initialize_database, [document], [vector_db, db_progress])
141
+ qachain_btn.click(initialize_LLM, [llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], [qa_chain, llm_progress])
142
+ msg.submit(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])
143
+ submit_btn.click(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])
144
+ clear_btn.click(lambda: [None, "", 0, "", 0, "", 0], None, [chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])
145
+
146
+ demo.queue().launch(debug=True)
147
+
148
+ if __name__ == "__main__":
149
+ demo()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ transformers
3
+ sentence-transformers
4
+ langchain
5
+ langchain-community
6
+ tqdm
7
+ accelerate
8
+ pypdf
9
+ faiss-cpu