# Basic RAG chatbot: Gradio UI + LlamaIndex + Hugging Face Inference API
import os
import time
from llama_index.core import SimpleDirectoryReader
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import VectorStoreIndex
from custom_llm import CustomLLM
import gradio as gr
# import shutil
import tempfile

repo_id = "mistralai/Mistral-7B-Instruct-v0.1"
model_type = 'text-generation'
API_TOKEN = os.getenv('HF_INFER_API')

# Uploaded PDFs are staged in a temporary directory for indexing
temp_dir = tempfile.TemporaryDirectory()

embedding_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
llm = CustomLLM(repo_id=repo_id, model_type=model_type, api_token=API_TOKEN)
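# HF_INFER_API is assumed to hold a Hugging Face API token with Inference API
# access, e.g. (hypothetical value):
#   export HF_INFER_API=hf_xxxxxxxxxxxxxxxx
# CustomLLM (from the local custom_llm module) is assumed to wrap the
# Inference API behind LlamaIndex's LLM interface.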

def add_text(history, text):
    history = history + [(text, None)]
    return history, gr.Textbox(value="", interactive=False)


def hasFile(history):
    # Count PDF filenames that have been echoed into the chat history
    pdf_files = 0
    for user_prompt, bot_response in history:
        if '.pdf' in user_prompt.lower():
            pdf_files += 1
    return pdf_files

def modelChanged(history, drop):
    history = history + [(f'===> {drop}', None)]
    return history, drop


def getEngine(llm):
    loader = SimpleDirectoryReader(
        input_dir=temp_dir.name,
        recursive=True,
        required_exts=[".pdf", ".PDF"],
    )
    # Load the staged files as documents
    documents = loader.load_data()
    # Create an in-memory vector index over the documents
    index = VectorStoreIndex.from_documents(
        documents,
        embed_model=embedding_model,
    )
    # Create a query engine backed by the selected LLM
    query_engine = index.as_query_engine(llm=llm)
    return query_engine
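# A minimal (hypothetical) direct use of the engine, outside the Gradio flow:
#   engine = getEngine(llm)
#   print(engine.query("Summarize the uploaded document in two sentences."))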

def copy_pdf(source_path, destination_path):
    # Open the source PDF file in binary read mode
    with open(source_path, "rb") as source_file:
        # Read the entire content of the source file
        data = source_file.read()
    # Open the destination file in binary write mode
    with open(destination_path, "wb") as destination_file:
        # Write the copied data to the destination file
        destination_file.write(data)
    # Print a success message
    print(f"PDF copied successfully from {source_path} to {destination_path}")

def add_file(history, file):
    pdf_files = hasFile(history)
    if pdf_files >= 3:
        # Mark the rejected file so bot() can report the limit
        history = history + [(f"{os.path.basename(file)}!!!", None)]
        return history
    file_path = os.path.join(temp_dir.name, os.path.basename(file))
    # shutil.copyfile(file.name, file_path)  # <--- asynchronous
    copy_pdf(file.name, file_path)
    history = history + [(os.path.basename(file), None)]
    return history


def clearClick():
    global temp_dir
    print("clear temp files...")
    temp_dir.cleanup()
    # Re-create the directory so later uploads still have a valid target
    temp_dir = tempfile.TemporaryDirectory()

def format_prompt(message, history, model):
    if model is None or 'mistral' in model.lower():
        prompt = "<s>"
        for user_prompt, bot_response in history:
            prompt += f"[INST] {user_prompt} [/INST]"
            prompt += f" {bot_response}</s> "
        prompt += f"[INST] {message} [/INST]"
    elif 'google' in model.lower():
        prompt = "<bos>"
        for user_prompt, bot_response in history:
            prompt += f"<start_of_turn>user {user_prompt} <end_of_turn><start_of_turn>model {bot_response}"
        prompt += f"<start_of_turn>user {message} <end_of_turn><start_of_turn>model"
    else:
        prompt = ""
    return prompt
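# For example, with history [("Hi", "Hello!")] and message "What is RAG?",
# the Mistral branch produces:
#   "<s>[INST] Hi [/INST] Hello!</s> [INST] What is RAG? [/INST]"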

def bot(history, model=None):
    print("===> model: ", model)
    local_llm = llm
    if model:
        local_llm = CustomLLM(repo_id=model, model_type=model_type, api_token=API_TOKEN)
    if history and '===>' in history[-1][0]:
        new_model = history[-1][0].replace("===> ", "")
        response = f"You have changed the model to {new_model}"
    elif history and '.pdf!!!' in history[-1][0].lower():
        response = "Unable to add file. Maximum 3 files allowed."
    elif history and '.pdf' in history[-1][0].lower():
        response = "You uploaded a PDF file. You can ask questions about the file."
    else:
        prompt = history[-1][0]
        if hasFile(history):
            query_engine = getEngine(local_llm)
            response = query_engine.query(prompt)
            print("Response from file")
        else:
            # Exclude the pending message from the history so it is not
            # repeated in the formatted prompt
            response = local_llm.predict(format_prompt(prompt, history[:-1], model))
            print("Response from Model")
    # print(response)
    # response = "Thats cool!"
    history[-1][1] = ""
    for character in str(response):
        history[-1][1] += character
        # time.sleep(0.05)
        yield history
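# Because bot() is a generator, Gradio streams each partial history back to
# the Chatbot component for a typewriter effect (uncomment time.sleep above
# to slow it down).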

with gr.Blocks() as demo:
    gr.Markdown(
        """
        <div style="display: grid; justify-content: center;">
        <h1>Basic RAG with Huggingface Inference API</h1>
        <h4>For best performance start with small PDF files (less than 20 pages).</h4>
        </div>
        """
    )
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
        # avatar_images=(None, os.path.join(os.path.dirname(__file__), "avatar.png")),
    )
    with gr.Row():
        drop = gr.Dropdown(
            [
                ("Mixtral-8x7B-Instruct-v0.1", "mistralai/Mixtral-8x7B-Instruct-v0.1"),
                ("Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.2"),
                ("gemma-7b-it", "google/gemma-7b-it"),
                ("gemma-2b-it", "google/gemma-2b-it"),
            ],
            value="mistralai/Mixtral-8x7B-Instruct-v0.1",
            label="Model",
            info="",
        )
    with gr.Row():
        txt = gr.Textbox(
            scale=4,
            show_label=False,
            placeholder="Type your question and press enter",
            container=False,
        )
        btn = gr.UploadButton("📁", file_types=[".pdf"])
        clear_btn = gr.ClearButton([chatbot, txt])

    drop.change(modelChanged, [chatbot, drop], [chatbot, drop], queue=False).then(
        bot, [chatbot, drop], chatbot
    )
    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        bot, [chatbot, drop], chatbot, api_name="bot_response"
    )
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
    file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(
        bot, [chatbot, drop], chatbot
    )
    clear_btn.click(clearClick)

demo.queue()
demo.launch()
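# Assumed dependencies (not pinned in the original): gradio, llama-index,
# llama-index-embeddings-huggingface, plus the local custom_llm module.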