# DocuQuery AI — Gradio chatbot for querying uploaded PDFs via RAG (FAISS + HuggingFace LLMs).
import gradio as gr
import os
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint
# HuggingFace Inference API token, read from the environment (may be None if unset).
api_token = os.getenv("HF_TOKEN")
# Available LLMs (full hub repo ids passed to HuggingFaceEndpoint).
list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
# Short display names (repo basename only) shown in the UI radio selector.
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
# Load and split PDF document
def load_doc(list_file_path):
    """Load every PDF in *list_file_path* and split the pages into chunks.

    Returns a list of langchain Documents of at most 1024 characters each,
    with 64 characters of overlap between consecutive chunks.
    """
    pages = []
    for file_path in list_file_path:
        pages.extend(PyPDFLoader(file_path).load())
    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    return splitter.split_documents(pages)
# Create vector database
def create_db(splits):
    """Embed the document chunks and index them in an in-memory FAISS store."""
    return FAISS.from_documents(splits, HuggingFaceEmbeddings())
# Initialize LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Build a ConversationalRetrievalChain over *vector_db*.

    The chain wraps a HuggingFace Inference endpoint for *llm_model* and a
    buffer memory keyed on "chat_history"; source documents are returned
    alongside each answer so the UI can show references.
    """
    endpoint_llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        endpoint_llm,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        memory=chat_memory,
        return_source_documents=True,
        verbose=False,
    )
# Initialize database
def initialize_database(list_file_obj, progress=gr.Progress()):
    """Create the FAISS vector database from the uploaded PDF file objects."""
    paths = [uploaded.name for uploaded in list_file_obj if uploaded is not None]
    chunks = load_doc(paths)
    return create_db(chunks), "✅ Vector database created successfully!"
# Initialize LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Resolve the radio-selected index to a repo id and build the QA chain."""
    chain = initialize_llmchain(
        list_llm[llm_option], llm_temperature, max_tokens, top_k, vector_db, progress
    )
    return chain, "✅ Chatbot initialized. Ready to assist!"
# Format chat history for better readability
def format_chat_history(message, chat_history):
    """Render each completed (user, bot) turn as a plain-text exchange.

    NOTE(review): *message* (the pending question) is not used here; it is
    kept in the signature for call-site compatibility.
    """
    formatted = []
    for user_turn, bot_turn in chat_history:
        formatted.append(f"User: {user_turn}\nAssistant: {bot_turn}")
    return formatted
# Handle conversation
def conversation(qa_chain, message, history):
    """Run one chat turn through the QA chain and update the UI.

    Returns the chain (unchanged), a cleared message box, the extended chat
    history, and exactly three (source text, page number) pairs for the
    reference widgets.

    Bug fix: the original flattened however many source documents came back
    (possibly fewer than 3), so the return tuple could be shorter than the
    nine outputs declared in the event bindings — a Gradio runtime error.
    The source list is now padded to exactly three entries.
    """
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
    answer = response["answer"]
    # Some instruct models echo a "Helpful Answer:" prefix; keep only what follows it.
    if "Helpful Answer:" in answer:
        answer = answer.split("Helpful Answer:")[-1].strip()
    # Up to three supporting chunks; .get() guards documents missing page metadata.
    sources = [
        (doc.page_content.strip(), doc.metadata.get("page", 0) + 1)
        for doc in response["source_documents"][:3]
    ]
    while len(sources) < 3:
        sources.append(("", 0))  # pad so all nine outputs always receive a value
    new_history = history + [(message, answer)]
    return qa_chain, gr.update(value=""), new_history, *(item for pair in sources for item in pair)
# File upload handling
def upload_file(file_obj):
    """Return the filesystem name of each uploaded file object."""
    names = []
    for uploaded in file_obj:
        names.append(uploaded.name)
    return names
# Gradio UI
def demo():
    """Assemble and launch the Gradio Blocks UI for the PDF RAG chatbot.

    Three steps are wired together: (1) upload PDFs and build the FAISS
    vector store, (2) select and initialize an LLM, (3) chat with the
    retrieval-augmented chain, with up to three source references shown.
    """
    with gr.Blocks() as demo:
        # Session state shared across event handlers.
        vector_db = gr.State()  # FAISS store produced by initialize_database
        qa_chain = gr.State()   # retrieval chain produced by initialize_LLM
        gr.HTML("""
<div style="background-color: #101010; padding: 15px; border-radius: 0px;">
<h1 style="text-align: center; color: white;">📄 DocuQuery AI</h1>
</div>
<div style="background-color: #101010; padding: 15px; border-radius: 0px; margin-bottom: 20px;">
<p style="color: white; font-size: 16px; text-align: center; font-weight: normal;">
This chatbot enables you to query your PDF documents using Retrieval-Augmented Generation (RAG).<br>
🛑 Please refrain from uploading confidential documents! <br>
This is only for education purpose.
</p>
</div>
""")
        with gr.Row():
            with gr.Column(scale=86):
                gr.Markdown("### Step 1: Upload PDF files and Initialize RAG Pipeline")
                document = gr.Files(height=300, file_count="multiple", file_types=[".pdf"], interactive=True, label="Upload PDF Files")
                db_btn = gr.Button("Create Vector Database")
                db_progress = gr.Textbox(value="⏳ Waiting for input...", show_label=False)
                gr.Markdown("### Step 2: Configure Large Language Model (LLM)")
                # type="index" makes the Radio pass an int index into list_llm.
                llm_btn = gr.Radio(list_llm_simple, label="Select LLM", value=list_llm_simple[0], type="index")
                with gr.Accordion("LLM Settings (Optional)", open=False):
                    slider_temperature = gr.Slider(0.01, 1.0, 0.5, 0.1, label="Temperature")
                    slider_maxtokens = gr.Slider(128, 4096, 2048, 128, label="Max Tokens")
                    slider_topk = gr.Slider(1, 10, 3, 1, label="Top-k")
                qachain_btn = gr.Button("Initialize Chatbot")
                llm_progress = gr.Textbox(value="⏳ Waiting for LLM setup...", show_label=False)
            with gr.Column(scale=200):
                gr.Markdown("### Step 3: Chat with Your Document")
                chatbot = gr.Chatbot(height=505)
                with gr.Accordion("Context from Source Document", open=False):
                    # Three fixed reference slots filled from conversation()'s
                    # (source text, page) outputs.
                    doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                    source1_page = gr.Number(label="Page", scale=1)
                    doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                    source2_page = gr.Number(label="Page", scale=1)
                    doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                    source3_page = gr.Number(label="Page", scale=1)
                msg = gr.Textbox(placeholder="Type your question here...", container=True)
                submit_btn = gr.Button("Submit")
                clear_btn = gr.ClearButton([msg, chatbot], value="Clear Chat")
        # Event bindings
        db_btn.click(initialize_database, [document], [vector_db, db_progress])
        qachain_btn.click(initialize_LLM, [llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], [qa_chain, llm_progress])
        # Both Enter-in-textbox and the Submit button trigger a chat turn;
        # conversation() must return one value per output listed here.
        msg.submit(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])
        submit_btn.click(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])
        # Reset the chat pane and all six reference widgets.
        clear_btn.click(lambda: [None, "", 0, "", 0, "", 0], None, [chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])
    demo.queue().launch(debug=True)
# Launch the UI only when executed as a script, not on import.
if __name__ == "__main__":
    demo()