import os
import tempfile

import accelerate
import streamlit as st
import torch
import tqdm
import transformers
from langchain.chains import ConversationChain
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from transformers import AutoTokenizer
|
|
# Local directory where the Chroma vector store is written to / read from.
default_persist_directory = "./chroma_HF/"

# Hugging Face Hub repository ids of the models selectable in the UI.
llm_name1 = "mistralai/Mistral-7B-Instruct-v0.2"
llm_name2 = "mistralai/Mistral-7B-Instruct-v0.1"
llm_name3 = "meta-llama/Llama-2-7b-chat-hf"
llm_name4 = "microsoft/phi-2"
llm_name5 = "mosaicml/mpt-7b-instruct"
llm_name6 = "tiiuae/falcon-7b-instruct"
llm_name7 = "google/flan-t5-xxl"
list_llm = [
    llm_name1,
    llm_name2,
    llm_name3,
    llm_name4,
    llm_name5,
    llm_name6,
    llm_name7,
]
# Display names: each repo id with its "<owner>/" prefix removed.
list_llm_simple = [repo_id.rsplit("/", 1)[-1] for repo_id in list_llm]
|
|
|
|
|
|
# Load PDF documents and split them into chunks
def load_doc(list_file_path, chunk_size, chunk_overlap):
    """Read every page of the given PDF files and split the text into chunks.

    Args:
        list_file_path: Paths of the PDF files to load.
        chunk_size: ``chunk_size`` passed to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: ``chunk_overlap`` passed to ``RecursiveCharacterTextSplitter``.

    Returns:
        The list of chunked documents produced by ``split_documents``.
    """
    # Create every loader first, then load them in list order; pages are
    # concatenated file by file.
    pdf_loaders = [PyPDFLoader(path) for path in list_file_path]
    all_pages = [page for pdf_loader in pdf_loaders for page in pdf_loader.load()]

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(all_pages)
|
|
|
|
def create_db(splits):
    """Embed document chunks and index them in a Chroma vector store on disk.

    Args:
        splits: Document chunks to index (e.g. the output of ``load_doc``).

    Returns:
        The ``Chroma`` store built with ``HuggingFaceEmbeddings`` and persisted
        under ``default_persist_directory``.
    """
    # NOTE(review): every call targets the same persist directory; confirm
    # whether rebuilding should replace or extend an existing collection.
    return Chroma.from_documents(
        documents=splits,
        embedding=HuggingFaceEmbeddings(),
        persist_directory=default_persist_directory,
    )
|
|
|
|
def load_db():
    """Open the Chroma vector store stored in ``default_persist_directory``.

    Returns:
        A ``Chroma`` instance that embeds queries with ``HuggingFaceEmbeddings``.
    """
    embedder = HuggingFaceEmbeddings()
    return Chroma(embedding_function=embedder, persist_directory=default_persist_directory)
|
|
|
|
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=None):
    """Build a conversational retrieval QA chain backed by a Hugging Face Hub LLM.

    Args:
        llm_model: Hugging Face Hub repo id of the model (an entry of ``list_llm``).
        temperature: Sampling temperature forwarded in ``model_kwargs``.
        max_tokens: Forwarded as ``max_new_tokens`` in ``model_kwargs``.
        top_k: Top-k sampling value forwarded in ``model_kwargs``.
        vector_db: Vector store whose ``as_retriever()`` supplies context chunks.
        progress: Optional callable ``progress(fraction, desc=...)`` used to report
            progress (a ``gradio.Progress`` fits). Defaults to a no-op.

    Returns:
        A ``ConversationalRetrievalChain`` using the "stuff" chain type, returning
        source documents, with chat history kept in a ``ConversationBufferMemory``.
    """
    # The previous default ``gr.Progress()`` was evaluated at definition time and
    # ``gr`` is never imported, so merely importing this module raised NameError.
    def _ignore_progress(*_args, **_kwargs):
        return None

    report = progress if progress is not None else _ignore_progress

    report(0.1, desc="Initializing HF tokenizer...")
    report(0.5, desc="Initializing HF Hub...")
    llm = HuggingFaceHub(
        repo_id=llm_model,
        model_kwargs={
            "temperature": temperature,
            "max_new_tokens": max_tokens,
            "top_k": top_k,
            "trust_remote_code": True,
            "torch_dtype": "auto",
        },
    )

    report(0.75, desc="Defining buffer memory...")
    # The chain returns both "answer" and "source_documents"; output_key tells
    # the memory which of them to record.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )

    retriever = vector_db.as_retriever()
    report(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
    )
    report(0.9, desc="Done!")
    return qa_chain
|
|
|
|
def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=None):
    """Load the given PDF files, split them into chunks and index them in Chroma.

    Args:
        list_file_obj: File objects whose ``.name`` is a readable path on disk.
            ``None`` entries are skipped and ``None`` itself means "no files".
        chunk_size: Splitter chunk size forwarded to ``load_doc``.
        chunk_overlap: Splitter chunk overlap forwarded to ``load_doc``.
        progress: Optional callable ``progress(fraction, desc=...)``; no-op by default.

    Returns:
        Tuple ``(vector_db, "Complete!")``.

    Raises:
        ValueError: if no file was provided.
    """
    # NOTE(review): Streamlit's UploadedFile.name is only the original file name,
    # not a path on disk; Streamlit callers must save uploads to disk first.
    # The previous default ``gr.Progress()`` raised NameError at import time
    # because ``gr`` is never imported.
    def _ignore_progress(*_args, **_kwargs):
        return None

    report = progress if progress is not None else _ignore_progress

    list_file_path = [x.name for x in (list_file_obj or ()) if x is not None]
    if not list_file_path:
        # Fail early with a clear message instead of an obscure error from the
        # vector store when it is given zero documents.
        raise ValueError("No PDF document provided.")

    report(0.25, desc="Loading document...")
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)

    report(0.5, desc="Generating vector database...")
    vector_db = create_db(doc_splits)
    report(0.9, desc="Done!")
    return vector_db, "Complete!"
|
|
|
|
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=None):
    """Resolve the selected model and build its retrieval QA chain.

    Args:
        llm_option: The model to use. It may be an integer index into ``list_llm``
            (e.g. a Gradio radio in index mode). It may also be a model name: a
            short name from ``list_llm_simple`` (what ``st.radio`` over that list
            returns) or a full repo id from ``list_llm``.
        llm_temperature: Sampling temperature forwarded to ``initialize_llmchain``.
        max_tokens: Maximum new tokens forwarded to ``initialize_llmchain``.
        top_k: Top-k sampling value forwarded to ``initialize_llmchain``.
        vector_db: Vector store the chain retrieves from.
        progress: Optional callable ``progress(fraction, desc=...)``; no-op by default.

    Returns:
        Tuple ``(qa_chain, "Complete!")``.

    Raises:
        ValueError: if ``llm_option`` is a string naming no known model.
    """
    # The previous default ``gr.Progress()`` raised NameError at import time
    # because ``gr`` is never imported. Always hand a callable downstream.
    def _ignore_progress(*_args, **_kwargs):
        return None

    report = progress if progress is not None else _ignore_progress

    if isinstance(llm_option, str):
        # st.radio returns the selected label, not its index.
        if llm_option in list_llm:
            llm_name = llm_option
        elif llm_option in list_llm_simple:
            llm_name = list_llm[list_llm_simple.index(llm_option)]
        else:
            raise ValueError(f"Unknown LLM model: {llm_option!r}")
    else:
        llm_name = list_llm[llm_option]

    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, report)
    return qa_chain, "Complete!"
|
|
|
|
def format_chat_history(message, chat_history):
    """Flatten ``(user, assistant)`` pairs into alternating transcript lines.

    Args:
        message: The pending user message; accepted for interface compatibility
            but not used.
        chat_history: Iterable of ``(user_message, bot_message)`` pairs.

    Returns:
        A list such as ``["User: hi", "Assistant: hello", ...]``, oldest turn first.
    """
    return [
        line
        for user_text, bot_text in chat_history
        for line in (f"User: {user_text}", f"Assistant: {bot_text}")
    ]
|
|
def conversation(qa_chain, message, history):
    """Send a question to the QA chain and return the updated UI state.

    Args:
        qa_chain: Chain built by ``initialize_llmchain``; called with a dict holding
            ``question`` and ``chat_history`` and expected to return ``answer`` and
            ``source_documents``.
        message: The user's question.
        history: List of ``(user_message, bot_message)`` pairs so far (``None`` is
            treated as empty).

    Returns:
        Tuple ``(qa_chain, cleared_message, new_history, source1, page1, source2, page2)``:

        - ``cleared_message`` is ``""``, used to reset the input box.
        - ``new_history`` is ``history`` plus the new exchange.
        - Each source/page pair describes one retrieved chunk, with the page
          number 1-based.
        - When a chunk is missing, its text is ``""`` and its page is ``None``.
    """
    history = [] if history is None else list(history)
    formatted_chat_history = format_chat_history(message, history)

    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    response_sources = response["source_documents"]

    def _source_at(index):
        # The retriever can return fewer than two chunks (e.g. a tiny document);
        # previously this raised IndexError.
        if index >= len(response_sources):
            return "", None
        doc = response_sources[index]
        page = doc.metadata.get("page")
        # Stored page numbers are shown +1 (1-based), as before.
        return doc.page_content.strip(), (page + 1 if page is not None else None)

    response_source1, response_source1_page = _source_at(0)
    response_source2, response_source2_page = _source_at(1)

    new_history = history + [(message, response_answer)]

    # This used to return ``gr.update(value="")``, but ``gr`` (gradio) is never
    # imported, so every call raised NameError. Return a plain empty string instead.
    return (
        qa_chain,
        "",
        new_history,
        response_source1,
        response_source1_page,
        response_source2,
        response_source2_page,
    )
|
|
def upload_file(file_obj):
    """Return the ``.name`` of every uploaded file object, in upload order.

    Args:
        file_obj: Iterable of uploaded file objects. ``None`` entries are skipped,
            and ``None`` itself is treated as "no uploads".

    Returns:
        List of the objects' ``name`` attributes.
    """
    # Bug fix: the loop used to read ``file_obj.name`` (the collection itself,
    # raising AttributeError for a list) instead of each element's ``.name``.
    return [uploaded.name for uploaded in (file_obj or ()) if uploaded is not None]
|
|
|
|
|
|
|
|
|
|
def main():
    """Render the Streamlit app: upload PDFs, build the vector DB, set up the QA chain, chat.

    Streamlit re-runs this function from the top on every interaction. The vector
    store, QA chain, chat history and last references are therefore kept in
    ``st.session_state``.
    """
    st.title("PDF-based chatbot (powered by LangChain and open-source LLMs)")
    st.markdown("""
## Ask any questions about your PDF documents, along with follow-ups
**Note:** This AI assistant performs retrieval-augmented generation from your PDF documents.
When generating answers, it takes past questions into account (via conversational memory),
and includes document references for clarity purposes.
\n**Warning:** This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate an output.
""")

    # Per-session state that must survive reruns.
    for state_key, default in (
        ("vector_db", None),
        ("qa_chain", None),
        ("chat_history", []),
        ("sources", []),
    ):
        if state_key not in st.session_state:
            st.session_state[state_key] = default

    def _ignore_progress(*_args, **_kwargs):
        # Gradio-style progress reporting has no Streamlit counterpart here;
        # spinners are used instead.
        return None

    # Step 1: turn the uploaded PDFs into a vector database.
    st.header("Step 1 - Document pre-processing")
    uploaded_files = st.file_uploader(
        "Upload your PDF documents (single or multiple)", type="pdf", accept_multiple_files=True
    )
    # Only ChromaDB is supported; the radio is informational.
    st.radio("Vector database type", ["ChromaDB"])
    chunk_size = st.slider("Chunk size", 100, 1000, 600, 20, key="chunk_size")
    chunk_overlap = st.slider("Chunk overlap", 10, 200, 40, 10, key="chunk_overlap")

    if st.button("Generating vector database..."):
        if not uploaded_files:
            st.warning("Please upload at least one PDF document first.")
        else:
            with st.spinner("Generating vector database..."):
                # Uploaded files live in memory, but PyPDFLoader needs real paths.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    pdf_paths = []
                    for index, uploaded in enumerate(uploaded_files):
                        # Index prefix keeps same-named uploads from overwriting each other.
                        pdf_path = os.path.join(tmp_dir, f"{index}_{os.path.basename(uploaded.name)}")
                        with open(pdf_path, "wb") as pdf_file:
                            pdf_file.write(uploaded.getvalue())
                        pdf_paths.append(pdf_path)
                    # load_doc reads the files, so it must run before tmp_dir is removed.
                    doc_splits = load_doc(pdf_paths, chunk_size, chunk_overlap)
                st.session_state["vector_db"] = create_db(doc_splits)
            # A chain built on the previous database is stale.
            st.session_state["qa_chain"] = None
            st.session_state["chat_history"] = []
            st.session_state["sources"] = []
            st.success("Complete!")

    # Step 2: pick a model and build the retrieval QA chain.
    st.header("Step 2 - QA chain initialization")
    llm_option = st.radio("LLM models", list_llm_simple)
    llm_temperature = st.slider("Temperature", 0.0, 1.0, 0.7, 0.1, key="llm_temperature")
    max_tokens = st.slider("Max Tokens", 224, 4096, 1024, 32, key="max_tokens")
    top_k = st.slider("Top-k samples", 1, 10, 3, 1, key="top_k")

    if st.button("Initializing question-answering chain..."):
        if st.session_state["vector_db"] is None:
            st.warning("Please generate the vector database first (Step 1).")
        else:
            with st.spinner("Initializing question-answering chain..."):
                qa_chain, _ = initialize_LLM(
                    list_llm_simple.index(llm_option),
                    llm_temperature,
                    max_tokens,
                    top_k,
                    st.session_state["vector_db"],
                    progress=_ignore_progress,
                )
            st.session_state["qa_chain"] = qa_chain
            # The new chain starts with empty memory, so reset the visible history too.
            st.session_state["chat_history"] = []
            st.session_state["sources"] = []
            st.success("Complete!")

    # Step 3: chat with the documents.
    st.header("Step 3 - Conversation with chatbot")
    # A form with clear_on_submit empties the text box after each question.
    with st.form("chat_form", clear_on_submit=True):
        msg = st.text_input("Type message", key="message")
        submitted = st.form_submit_button("Submit")

    if submitted:
        if st.session_state["qa_chain"] is None:
            st.warning("Please initialize the question-answering chain first (Step 2).")
        elif not msg.strip():
            st.warning("Please type a message.")
        else:
            with st.spinner("Generating answer..."):
                (_, _, new_history, source1, page1, source2, page2) = conversation(
                    st.session_state["qa_chain"], msg, st.session_state["chat_history"]
                )
            st.session_state["chat_history"] = new_history
            st.session_state["sources"] = [(source1, page1), (source2, page2)]

    for user_message, bot_message in st.session_state["chat_history"]:
        st.markdown(f"**User:** {user_message}")
        st.markdown(f"**Assistant:** {bot_message}")

    if st.session_state["sources"]:
        with st.expander("Relevant context from the source document"):
            for index, (text, page) in enumerate(st.session_state["sources"], start=1):
                if text:
                    st.markdown(f"**Reference {index}** (page {page})")
                    st.text(text)
|
|
if __name__ == "__main__":
    # Script entry point: build the app only when run directly (e.g. `streamlit run`),
    # not when imported as a module.
    main()
|
|