Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
import argparse
import base64
import os
import re
import threading
from pathlib import Path

import gradio as gr
from dotenv import load_dotenv
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import GPT4All, LlamaCpp
from langchain.vectorstores import Chroma
# Pull settings from a local .env file (when present) into os.environ.
load_dotenv()

# Runtime configuration. Each value is the raw string from the environment,
# or None when the variable is not set.
embeddings_model_name = os.getenv("EMBEDDINGS_MODEL_NAME")
persist_directory = os.getenv('PERSIST_DIRECTORY')
model_type = os.getenv('MODEL_TYPE')
model_path = os.getenv('MODEL_PATH')
model_n_ctx = os.getenv('MODEL_N_CTX')  # kept as a str; passed straight to the LLM constructors
| from constants import CHROMA_SETTINGS | |
def main():
    """Run an interactive question/answer loop on the terminal.

    Builds a RetrievalQA chain over the persisted Chroma store, using the
    LLM selected by MODEL_TYPE. It then repeatedly reads a query from stdin
    and prints the answer, plus the source documents unless --hide-source
    was given. Typing "exit" ends the loop.

    Raises:
        SystemExit: with status 1 when MODEL_TYPE is neither "LlamaCpp"
            nor "GPT4All".
    """
    # Parse the command line arguments
    args = parse_arguments()
    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
    retriever = db.as_retriever()
    # activate/deactivate the streaming StdOut callback for LLMs
    callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
    # Prepare the LLM
    if model_type == "LlamaCpp":
        llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
    elif model_type == "GPT4All":
        llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
    else:
        print(f"Model {model_type} not supported!")
        # A bare `exit` is a no-op expression; actually stop here so we never
        # reach the chain construction with `llm` unbound.
        raise SystemExit(1)
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=not args.hide_source)
    # Interactive questions and answers
    while True:
        query = input("\nEnter a query: ")
        if query == "exit":
            break
        # Get the answer from the chain
        res = qa(query)
        answer = res['result']
        docs = [] if args.hide_source else res['source_documents']
        # Print the result
        print("\n\n> Question:")
        print(query)
        print("\n> Answer:")
        print(answer)
        # Print the relevant sources used for the answer
        for document in docs:
            print("\n> " + document.metadata["source"] + ":")
            print(document.page_content)
def parse_arguments(argv=None):
    """Parse the privateGPT command-line flags.

    Args:
        argv: List of argument strings to parse. None (the default) makes
            argparse read ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace with boolean attributes ``hide_source`` and
        ``mute_stream``.
    """
    parser = argparse.ArgumentParser(description='privateGPT: Ask questions to your documents without an internet connection, '
                                                 'using the power of LLMs.')
    parser.add_argument("--hide-source", "-S", action='store_true',
                        help='Use this flag to disable printing of source documents used for answers.')
    parser.add_argument("--mute-stream", "-M",
                        action='store_true',
                        help='Use this flag to disable the streaming StdOut callback for LLMs.')
    return parser.parse_args(argv)
# One <table>...</table> span: non-greedy so adjacent tables match separately,
# DOTALL so tables spanning several lines are matched.
_TABLE_RE = re.compile(r"<table>.*?</table>", re.DOTALL)


def _style_table(match):
    """Return the matched table HTML with inline border/padding styles added."""
    table = match.group(0)
    table = table.replace("<table>", "<table style='border-collapse: collapse;'>")
    table = table.replace("<th>", "<th style='border: 1px solid #ddd; padding: 8px; background-color: #f2f2f2;'>")
    table = table.replace("<td>", "<td style='border: 1px solid #ddd; padding: 8px;'>")
    return table


def apply_html(text, color):
    """Add inline styles to every HTML table in *text* so Gradio renders borders.

    Each ``<table>...</table>`` span has its ``<table>``, ``<th>`` and
    ``<td>`` tags rewritten with inline CSS. Text outside tables is returned
    unchanged. A ``</table>`` that precedes any ``<table>`` is left alone
    rather than corrupting the text.

    Args:
        text: Message text, possibly containing HTML tables.
        color: Currently unused; kept for interface compatibility.

    Returns:
        The text with all tables restyled (or the original text if none).
    """
    return _TABLE_RE.sub(_style_table, text)
def add_text(history, text):
    """Append the user's message to the chat history with an empty bot slot.

    The message is passed through apply_html() before being stored. When
    *history* is None nothing is appended. The textbox value is returned
    untouched so the next chained handler can still read it.

    Returns:
        (history, text)
    """
    if history is None:
        return history, text
    pending_pair = [apply_html(text, "blue"), None]
    history.append(pending_pair)
    return history, text
# Built RetrievalQA chains keyed by (mute_stream, hide_source). The embedding
# model, vector store and LLM are then loaded once, not on every query.
_QA_CHAINS = {}
# Guards chain construction and inference. The cached LLM is shared by every
# Gradio request, and nothing here establishes that it is thread-safe.
_QA_LOCK = threading.Lock()


def _get_qa_chain(args):
    """Return the cached RetrievalQA chain for *args*, building it on first use.

    Must be called with ``_QA_LOCK`` held. Returns None (and caches nothing)
    when MODEL_TYPE is neither "LlamaCpp" nor "GPT4All".
    """
    if model_type not in ("LlamaCpp", "GPT4All"):
        return None
    key = (args.mute_stream, args.hide_source)
    qa = _QA_CHAINS.get(key)
    if qa is not None:
        return qa
    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
    retriever = db.as_retriever()
    # activate/deactivate the streaming StdOut callback for LLMs
    callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
    if model_type == "LlamaCpp":
        llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
    else:
        llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=not args.hide_source)
    _QA_CHAINS[key] = qa
    return qa


def _set_last_reply(history, reply):
    """Write *reply* (table-styled) into the bot slot of the newest chat pair, if any."""
    if history is not None and len(history) > 0:
        history[-1][1] = apply_html(reply, "black")


def bot(query, history, fileListHistory, k=5):
    """Answer *query* with the retrieval QA chain and put the reply in the chat.

    Question, answer and (unless --hide-source) source documents are also
    echoed to stdout.

    Args:
        query: The user's question (the textbox value).
        history: Chat history as a list of [user_html, bot_html] pairs. The
            bot slot of the last pair is filled in. May be None or empty.
        fileListHistory: Value of the "References" chatbot; returned unchanged.
        k: Currently unused; kept for interface compatibility.

    Returns:
        (history, fileListHistory) for the two Gradio output components.
    """
    # Parse the command line arguments
    args = parse_arguments()
    print("QUERY : " + query)

    with _QA_LOCK:
        qa = _get_qa_chain(args)
        if qa is None:
            # Report the configuration error in the chat instead of crashing
            # on an unbound `llm` (a bare `exit` statement does nothing).
            reply = f"Model {model_type} not supported!"
            print(reply)
            _set_last_reply(history, reply)
            return history, fileListHistory
        # Get the answer from the chain
        res = qa(query)

    answer = res['result']
    docs = [] if args.hide_source else res['source_documents']
    # Print the result
    print("\n\n> Question:")
    print(query)
    print("\n> Answer:")
    print(answer)
    # Print the relevant sources used for the answer
    for document in docs:
        print("\n> " + document.metadata["source"] + ":")
        print(document.page_content)

    if answer is None:
        # Previously this dereferenced `answer.text` and raised AttributeError.
        reply = "Unfortunately, no answer was generated. Please try again."
        print(reply)
    else:
        print("\nGPT RESPONSE:\n")
        reply = answer.strip()
    _set_last_reply(history, reply)
    return history, fileListHistory
# Read the bot avatar and inline it as a base64 data URI in the header HTML.
# The path is relative to the current working directory.
with open(Path("bot.png"), "rb") as img_file:
    img_str = base64.b64encode(img_file.read()).decode()

# Header markup shown via gr.HTML. Doubled braces are literal CSS braces
# escaped inside the f-string. The data URI declares image/png to match the
# file actually read above (it was previously mislabelled as image/jpg).
html_code = f'''
<!DOCTYPE html>
<html>
<head>
<style>
.center {{
    display: flex;
    justify-content: center;
    align-items: center;
    margin-top: -40px; /* adjust this value as per your requirement */
    margin-bottom: 5px;
}}
.large-text {{
    font-size: 40px;
    font-family: Arial, Helvetica, sans-serif;
    font-weight: 900 !important;
    margin-left: 5px;
    color: #5b5b5b !important;
}}
.image-container {{
    display: inline-block;
    vertical-align: middle;
    height: 10px;
    margin-bottom: 5px;
}}
</style>
</head>
<body>
<div class="center">
<img src="data:image/png;base64,{img_str}" alt="RyBOT image" class="image-container" />
<strong class="large-text">Happy Bot</strong>
</div>
<br>
<div class="center">
<h3> [ "I am Happy Bot, get your answers here" ] </h3>
</div>
</body>
</html>
'''
# Extra page CSS passed to gr.Blocks: tints the feedback textarea and the
# overall Gradio container background.
css = (
    "\n"
    ".feedback textarea {background-color: #e9f0f7}\n"
    ".gradio-container {background-color: #eeeeee}\n"
)
def clear_textbox():
    """Log the reset and return None, the value written back into the query textbox."""
    print("Calling CLEAR")
# Build the chat UI: header, chat pane, references pane, query box and a
# Send button. Submitting the box and clicking Send run the same pipeline.
with gr.Blocks(theme=gr.themes.Soft(), css=css, title="RyBOT") as demo:
    gr.HTML(html_code)
    chatbot = gr.Chatbot([], elem_id="chatbot", label="Chat", color_map=["blue", "grey"]).style(height=450)
    fileListBot = gr.Chatbot([], elem_id="fileListBot", label="References", color_map=["blue", "grey"]).style(height=150)
    txt = gr.Textbox(
        label="Type your query here:",
        placeholder="What would you like to find today?",
    ).style(container=True)
    btn = gr.Button(value="Send")

    # Pipeline: echo the user message -> answer it -> clear the textbox.
    # Registered for the textbox first, then the button.
    for trigger in (txt.submit, btn.click):
        trigger(add_text, [chatbot, txt], [chatbot, txt]).then(
            bot, [txt, chatbot, fileListBot], [chatbot, fileListBot]
        ).then(
            clear_textbox, inputs=None, outputs=[txt]
        )

demo.launch()