#!/usr/bin/env python3
"""Gradio chat front-end ("RyBOT") for asking questions about ingested documents.

Answers come from a LangChain ``RetrievalQA`` chain over a persisted Chroma
vector store, using the local LLM backend named by the ``MODEL_TYPE``
environment variable ("LlamaCpp" or "GPT4All"). ``main()`` provides the same
question/answer loop on the terminal.
"""
import argparse
import base64
import functools
import os
import sys
from pathlib import Path

import gradio as gr
from dotenv import load_dotenv
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import GPT4All, LlamaCpp
from langchain.vectorstores import Chroma

load_dotenv()

embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = os.environ.get('MODEL_N_CTX')

# Imported after load_dotenv() in case the constants module reads environment
# variables at import time (unverified; kept in its original position).
from constants import CHROMA_SETTINGS  # noqa: E402


class UnsupportedModelError(ValueError):
    """Raised when MODEL_TYPE names an LLM backend this script cannot load."""


@functools.lru_cache(maxsize=1)
def _load_embeddings():
    """Return the HuggingFace embedding model, loaded once per process."""
    return HuggingFaceEmbeddings(model_name=embeddings_model_name)


@functools.lru_cache(maxsize=2)
def _load_llm(mute_stream):
    """Return the LLM selected by MODEL_TYPE, loaded once per streaming setting.

    Args:
        mute_stream: when true, no StreamingStdOutCallbackHandler is attached.

    Raises:
        UnsupportedModelError: if MODEL_TYPE is neither "LlamaCpp" nor "GPT4All".
    """
    # activate/deactivate the streaming StdOut callback for LLMs
    callbacks = [] if mute_stream else [StreamingStdOutCallbackHandler()]
    if model_type == "LlamaCpp":
        return LlamaCpp(model_path=model_path, n_ctx=model_n_ctx,
                        callbacks=callbacks, verbose=False)
    if model_type == "GPT4All":
        return GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj',
                       callbacks=callbacks, verbose=False)
    # Raise instead of printing and carrying on: without an LLM the chain
    # cannot be built, and continuing would only fail later with a NameError.
    raise UnsupportedModelError(f"Model {model_type} not supported!")


def _build_qa(args):
    """Build a RetrievalQA chain over the persisted Chroma store.

    The vector store is reopened on every call (presumably so documents
    persisted after startup are searchable — unverified). Only the embedding
    model and the LLM are cached, because reloading them per query is costly.

    Args:
        args: parsed command-line namespace with ``mute_stream`` and
            ``hide_source`` booleans.

    Raises:
        UnsupportedModelError: if MODEL_TYPE is not supported.
    """
    db = Chroma(persist_directory=persist_directory,
                embedding_function=_load_embeddings(),
                client_settings=CHROMA_SETTINGS)
    return RetrievalQA.from_chain_type(
        llm=_load_llm(args.mute_stream),
        chain_type="stuff",
        retriever=db.as_retriever(),
        return_source_documents=not args.hide_source,
    )


def _ask(qa, query, hide_source):
    """Run *query* through *qa*, print question/answer/sources, return (answer, docs)."""
    res = qa(query)
    answer = res['result']
    docs = [] if hide_source else res['source_documents']
    # Print the result
    print("\n\n> Question:")
    print(query)
    print("\n> Answer:")
    print(answer)
    # Print the relevant sources used for the answer
    for document in docs:
        print("\n> " + document.metadata["source"] + ":")
        print(document.page_content)
    return answer, docs


def main():
    """Run an interactive terminal Q&A loop until the user types "exit".

    Note: this script never calls main(); running it launches the Gradio UI.
    """
    # Parse the command line arguments
    args = parse_arguments()
    try:
        qa = _build_qa(args)
    except UnsupportedModelError as err:
        sys.exit(str(err))
    # Interactive questions and answers
    while True:
        query = input("\nEnter a query: ")
        if query == "exit":
            break
        _ask(qa, query, args.hide_source)


def parse_arguments():
    """Parse the --hide-source/-S and --mute-stream/-M command-line flags."""
    parser = argparse.ArgumentParser(description='privateGPT: Ask questions to your documents without an internet connection, '
                                                 'using the power of LLMs.')
    parser.add_argument("--hide-source", "-S", action='store_true',
                        help='Use this flag to disable printing of source documents used for answers.')
    parser.add_argument("--mute-stream", "-M", action='store_true',
                        help='Use this flag to disable the streaming StdOut callback for LLMs.')
    return parser.parse_args()


# (old, new) substitutions applied inside the first <table>...</table> span of
# a chat message.
# NOTE(review): the original substitution strings appear to have been lost
# (HTML markup was stripped from this copy of the file). Restore them here.
# While this is empty, apply_html returns its input unchanged.
_TABLE_REPLACEMENTS = ()


def apply_html(text, color):
    """Restyle the first HTML table in *text* for display in Gradio.

    Args:
        text: message text, possibly containing a <table>...</table> span.
        color: currently unused.

    Returns:
        *text* with _TABLE_REPLACEMENTS applied inside the first table span,
        or *text* unchanged if it contains no complete table.
    """
    open_tag, close_tag = "<table>", "</table>"
    start = text.find(open_tag)
    # Search for the closing tag only after the opening tag, so a stray
    # earlier closing tag cannot yield an inverted slice.
    end = text.find(close_tag, start) if start != -1 else -1
    if end == -1:
        # Return the plain text as is
        return text
    end += len(close_tag)
    table = text[start:end]
    for old, new in _TABLE_REPLACEMENTS:
        table = table.replace(old, new)
    # Replace the modified table back into the original text
    return text[:start] + table + text[end:]


def add_text(history, text):
    """Append the user's message to *history* with a pending (None) bot reply.

    Returns (history, text) so the textbox keeps its contents for ``bot``,
    which runs next in the Gradio event chain.
    """
    if history is not None:
        history.append([apply_html(text, "blue"), None])
    return history, text


def bot(query, history, fileListHistory, k=5):
    """Answer *query* and fill in the pending bot reply in *history*.

    Args:
        query: the user's question (the textbox contents).
        history: chat history as [user, bot] pairs; the last pair's bot slot
            is filled in.
        fileListHistory: contents of the "References" chatbot; returned
            unchanged.
        k: currently unused.

    Returns:
        (history, fileListHistory)

    Raises:
        UnsupportedModelError: if MODEL_TYPE is not supported.
    """
    # Parse the command line arguments
    args = parse_arguments()
    print("QUERY : " + query)
    qa = _build_qa(args)
    answer, _docs = _ask(qa, query, args.hide_source)

    if answer is None:
        # The chain produced no answer; show a fallback instead of crashing.
        reply = "Unfortunately, no answer was generated. Please try again after some time."
        print(reply)
    else:
        # Print the generated response
        print("\nGPT RESPONSE:\n")
        reply = answer.strip()

    if history:
        # Update the chat history with the bot's response
        history[-1][1] = apply_html(reply, "black")
    return history, fileListHistory


# Open the image and convert it to base64
with open(Path("bot.png"), "rb") as img_file:
    img_str = base64.b64encode(img_file.read()).decode()

# NOTE(review): img_str is not referenced by this template. The header markup,
# which presumably embedded the image, appears to have been stripped from this
# copy of the file; restore it if the image should be shown.
html_code = '''
RyBOT image Happy Bot

[ "I am Happy Bot, get your answers here" ]

'''

css = """
.feedback textarea {background-color: #e9f0f7}
.gradio-container {background-color: #eeeeee}
"""


def clear_textbox():
    """Return None so Gradio empties the query textbox after a reply."""
    print("Calling CLEAR")
    return None


with gr.Blocks(theme=gr.themes.Soft(), css=css, title="RyBOT") as demo:
    gr.HTML(html_code)
    chatbot = gr.Chatbot([], elem_id="chatbot", label="Chat",
                         color_map=["blue", "grey"]).style(height=450)
    fileListBot = gr.Chatbot([], elem_id="fileListBot", label="References",
                             color_map=["blue", "grey"]).style(height=150)
    txt = gr.Textbox(
        label="Type your query here:",
        placeholder="What would you like to find today?",
    ).style(container=True)
    btn = gr.Button(value="Send")

    # Pressing Enter and clicking "Send" run the same chain: record the user
    # message, generate the bot reply, then clear the textbox.
    for trigger in (txt.submit, btn.click):
        trigger(
            add_text, [chatbot, txt], [chatbot, txt]
        ).then(
            bot, [txt, chatbot, fileListBot], [chatbot, fileListBot]
        ).then(
            clear_textbox, inputs=None, outputs=[txt]
        )

demo.launch()