import gradio as gr
import os
import string
import random
import requests
from bs4 import BeautifulSoup
import wget
import chromadb

from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain.memory import ConversationBufferMemory
# default_persist_directory = './chroma_HF/'

list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.1"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
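# With the repo ids above, the radio options show just the model names, i.e.:
#   list_llm_simple == ['Mistral-7B-Instruct-v0.2', 'Mistral-7B-Instruct-v0.1']
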
# Load PDF documents and create doc splits
def load_doc(list_file_path, chunk_size, chunk_overlap):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits
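
# Usage sketch (hypothetical file name): load one PDF and split it into
# overlapping chunks; each split is a langchain Document carrying page metadata.
#   splits = load_doc(["./my_paper.pdf"], chunk_size=600, chunk_overlap=40)
#   print(len(splits), splits[0].metadata)  # e.g. 42 {'source': './my_paper.pdf', 'page': 0}
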
def convert_github_url_to_raw(url):
    """Fetch a page, find the first github.com link in it, and rewrite that link
    into a raw.githubusercontent.com URL so the file can be downloaded directly."""
    try:
        response = requests.get(url)
        html_content = response.text
        # Find the first GitHub link on the page
        soup = BeautifulSoup(html_content, "html.parser")
        github_icon_link = None
        for a in soup.find_all('a', href=True):
            if "github.com" in a['href']:
                github_icon_link = a['href']
                print(github_icon_link)
                break
        if github_icon_link is None:
            return ''
        raw_url = github_icon_link.replace("github.com", "raw.githubusercontent.com").replace("/blob", "")
        return raw_url
    except Exception as e:
        print(e)
        return ''
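
# Sketch of the rewrite this performs, assuming the input page links to its own
# Markdown source on GitHub (Hugging Face blog posts do):
#   input page:  https://huggingface.co/blog/segmoe
#   found link:  https://github.com/huggingface/blog/blob/main/segmoe.md
#   raw url:     https://raw.githubusercontent.com/huggingface/blog/main/segmoe.md
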
# Load URLs (via their raw Markdown sources) and create doc splits
def load_url(list_url_path, chunk_size, chunk_overlap):
    urls = [convert_github_url_to_raw(x) for x in list_url_path]
    files = [wget.download(x) for x in urls]
    loaders = [UnstructuredMarkdownLoader(f'./{x}') for x in files]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    # Clean up the temporary downloads
    for x in files:
        os.remove(f'./{x}')
    return doc_splits
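
# Usage sketch: fetch the Markdown behind a blog post and chunk it. These splits
# carry a 'source' entry in their metadata instead of a 'page' number, which is
# why conversation() below falls back to 'source' for URL-based documents.
#   url_splits = load_url(["https://huggingface.co/blog/segmoe"], 600, 40)
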
# Create vector database
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
        # persist_directory=default_persist_directory
    )
    return vectordb
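
# Sanity-check sketch (hypothetical query): query the store directly before
# wiring it into a chain, using Chroma's similarity_search.
#   db = create_db(doc_splits, "HuggingFace101_test")
#   for doc in db.similarity_search("What is SegMoE?", k=2):
#       print(doc.metadata, doc.page_content[:80])
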
# Load vector database (unused here; kept for a persisted-directory setup)
def load_db():
    embedding = HuggingFaceEmbeddings()
    vectordb = Chroma(
        # persist_directory=default_persist_directory,
        embedding_function=embedding)
    return vectordb
# Initialize langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    # HuggingFaceHub uses HF inference endpoints
    progress(0.5, desc="Initializing HF Hub...")
    # trust_remote_code must be passed through model_kwargs, see
    # https://github.com/langchain-ai/langchain/issues/6080
    llm = HuggingFaceHub(
        repo_id=llm_model,
        # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
        model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
    )
    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )
    # retriever = vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
    retriever = vector_db.as_retriever()
    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        # combine_docs_chain_kwargs={"prompt": your_prompt},
        return_source_documents=True,
        # return_generated_question=False,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return qa_chain
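
# Usage sketch: the chain is called with a question plus formatted chat history
# and returns the answer together with the retrieved source chunks.
#   response = qa_chain({"question": "What is this document about?", "chat_history": []})
#   print(response["answer"])
#   print(response["source_documents"][0].metadata)
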
# Initialize database
def initialize_database(list_file_obj, input_urls, chunk_size, chunk_overlap, progress=gr.Progress()):
    # Build the list of uploaded file paths (when valid)
    try:
        list_file_path = [x.name for x in list_file_obj if x is not None]
    except TypeError:
        list_file_path = None
    # Build the list of URLs from the comma-separated input (when valid)
    try:
        list_url = [url.strip() for url in input_urls.split(',') if url.strip()]
    except AttributeError:
        list_url = None
    # Create a unique collection_name for the vector database
    progress(0.1, desc="Creating collection...")
    res = ''.join(random.choices(string.ascii_letters, k=10))
    collection_name = f"HuggingFace101_{res}"
    print('Collection name: ', collection_name)
    progress(0.25, desc="Loading document...")
    # Load documents and create splits
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap) if list_file_path else []
    url_splits = load_url(list_url, chunk_size, chunk_overlap) if list_url else []
    total_splits = doc_splits + url_splits
    # Create or load vector database
    progress(0.5, desc="Generating vector database...")
    vector_db = create_db(total_splits, collection_name)
    progress(0.9, desc="Done!")
    return vector_db, collection_name, "Complete!"
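
# Headless usage sketch (hypothetical, outside the UI): the same pipeline can be
# driven without file uploads; note that gr.Progress() reporting assumes a Gradio
# event context, so the progress callbacks may be no-ops here.
#   vector_db, name, status = initialize_database(None, "https://huggingface.co/blog/segmoe", 600, 40)
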
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    # llm_option is the Radio index (type="index"), so map it back to the repo id
    llm_name = list_llm[llm_option]
    print("llm_name: ", llm_name)
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Complete!"
def format_chat_history(message, chat_history):
    # `message` is unused but kept so the signature mirrors conversation()'s inputs
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history
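
# Example of the shape produced (hypothetical exchange):
#   format_chat_history("next question", [("Hi", "Hello!")])
#   -> ["User: Hi", "Assistant: Hello!"]
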
def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
    # Generate response using the QA chain
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    response_sources = response["source_documents"]
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
    # PDF sources carry zero-based page numbers; URL sources only carry 'source'
    try:
        response_source1_page = response_sources[0].metadata["page"] + 1
        response_source2_page = response_sources[1].metadata["page"] + 1
    except KeyError:
        response_source1_page = response_sources[0].metadata['source']
        response_source2_page = response_sources[1].metadata['source']
    # Append user message and response to chat history
    new_history = history + [(message, response_answer)]
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page
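
# The 7-tuple above maps one-to-one onto the outputs wired up in demo():
# qa_chain state, cleared message box, updated chatbot history, then the text
# and page (or source URL) for each of the two top retrieved chunks.
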
def upload_file(file_obj):
    list_file_path = []
    for file in file_obj:
        file_path = file.name  # temp path of each uploaded file
        list_file_path.append(file_path)
    return list_file_path
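
# Currently unused: gr.Files already hands its file list straight to
# initialize_database. Kept in case a dedicated upload button (see the
# commented-out upload_btn event below) is re-enabled.
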
def demo():
    with gr.Blocks(theme="base") as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()
        gr.Markdown(
            """<center><h2>Hugging Face articles URL-based chatbot (powered by LangChain and open-source LLMs)</h2></center>
            <h3>Ask any question about your Hugging Face articles, with follow-ups</h3>
            <b>Note:</b> This AI assistant performs retrieval-augmented generation over Hugging Face articles. \
            When generating answers, it takes past questions into account (via conversational memory) and includes document references for clarity.
            <br><b>Warning:</b> This Space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate an output.<br>
            """)
| with gr.Tab("Step 1 - Document pre-processing"): | |
| with gr.Row(): | |
| document = gr.Files(height=100, | |
| file_count="multiple", | |
| file_types=["pdf"], | |
| interactive=True, | |
| label="Upload your PDF documents (single or multiple)") | |
| input_url = gr.Textbox(label="Or Enter a URL", | |
| value="https://huggingface.co/blog/segmoe", | |
| placeholder="Enter URLs separated by commas" | |
| ) | |
| with gr.Row(): | |
| db_btn = gr.Radio(["ChromaDB"], | |
| label="Vector database type", | |
| value = "ChromaDB", | |
| type="index", | |
| info="Choose your vector database") | |
| with gr.Accordion("Advanced options - Document text splitter", open=False): | |
| with gr.Row(): | |
| slider_chunk_size = gr.Slider(minimum = 100, | |
| maximum = 1000, | |
| value=600, | |
| step=20, | |
| label="Chunk size", | |
| info="Chunk size", | |
| interactive=True) | |
| with gr.Row(): | |
| slider_chunk_overlap = gr.Slider(minimum = 10, | |
| maximum = 200, | |
| value=40, | |
| step=10, | |
| label="Chunk overlap", | |
| info="Chunk overlap", | |
| interactive=True) | |
| with gr.Row(): | |
| db_progress = gr.Textbox(label="Vector database initialization", value="None") | |
| with gr.Row(): | |
| db_btn = gr.Button("Generating vector database...") | |
| with gr.Tab("Step 2 - QA chain initialization"): | |
| with gr.Row(): | |
| llm_btn = gr.Radio(list_llm_simple, | |
| label="LLM models", | |
| value = list_llm_simple[0], | |
| type="index", | |
| info="Choose your LLM model") | |
| with gr.Accordion("Advanced options - LLM model", open=False): | |
| with gr.Row(): | |
| slider_temperature = gr.Slider(minimum = 0.0, | |
| maximum = 1.0, | |
| value=0.7, | |
| step=0.1, | |
| label="Temperature", | |
| info="Model temperature", | |
| interactive=True) | |
| with gr.Row(): | |
| slider_maxtokens = gr.Slider(minimum = 224, | |
| maximum = 4096, | |
| value=1024, | |
| step=32, | |
| label="Max Tokens", | |
| info="Model max tokens", | |
| interactive=True) | |
| with gr.Row(): | |
| slider_topk = gr.Slider(minimum = 1, | |
| maximum = 10, | |
| value=3, | |
| step=1, | |
| label="top-k samples", | |
| info="Model top-k samples", | |
| interactive=True) | |
| with gr.Row(): | |
| llm_progress = gr.Textbox(value="None",label="QA chain initialization") | |
| with gr.Row(): | |
| qachain_btn = gr.Button("Initialize question-answering chain...") | |
| with gr.Tab("Step 3 - Conversation with chatbot"): | |
| chatbot = gr.Chatbot(height=300) | |
| with gr.Accordion("Advanced - Document references", open=False): | |
| with gr.Row(): | |
| doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20) | |
| source1_page = gr.Number(label="Page", scale=1) | |
| with gr.Row(): | |
| doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20) | |
| source2_page = gr.Number(label="Page", scale=1) | |
| with gr.Row(): | |
| msg = gr.Textbox(placeholder="Type message", container=True) | |
| with gr.Row(): | |
| submit_btn = gr.Button("Submit") | |
| clear_btn = gr.ClearButton([msg, chatbot]) | |
        # Preprocessing events
        # upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
        db_btn.click(initialize_database,
            inputs=[document, input_url, slider_chunk_size, slider_chunk_overlap],
            outputs=[vector_db, collection_name, db_progress])
        qachain_btn.click(initialize_LLM,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
            outputs=[qa_chain, llm_progress]).then(lambda: [None, "", 0, "", 0],
            inputs=None,
            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page],
            queue=False)
        # Chatbot events: conversation() returns 7 values, so the outputs list
        # must include the reference boxes as well
        msg.submit(conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page],
            queue=False)
        submit_btn.click(conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page],
            queue=False)
        clear_btn.click(lambda: [None, "", 0, "", 0],
            inputs=None,
            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page],
            queue=False)
    demo.queue().launch(debug=True)


if __name__ == "__main__":
    demo()