#!/usr/bin/env python3
from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma
from langchain.llms import GPT4All, LlamaCpp
import os
import argparse
from pathlib import Path
import base64
import gradio as gr

load_dotenv()

# Configuration is read from the environment (typically a .env file)
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = os.environ.get('MODEL_N_CTX')

from constants import CHROMA_SETTINGS


def main():
    # Parse the command line arguments
    args = parse_arguments()
    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
    retriever = db.as_retriever()
    # activate/deactivate the streaming StdOut callback for LLMs
    callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
    # Prepare the LLM
    '''match model_type:
        case "LlamaCpp":
            llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
        case "GPT4All":
            llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
        case _:
            print(f"Model {model_type} not supported!")
            exit()'''
    if model_type == "LlamaCpp":
        llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
    elif model_type == "GPT4All":
        llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
    else:
        print(f"Model {model_type} not supported!")
        exit()
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=not args.hide_source)

    # Interactive questions and answers
    while True:
        query = input("\nEnter a query: ")
        if query == "exit":
            break

        # Get the answer from the chain
        res = qa(query)
        answer, docs = res['result'], [] if args.hide_source else res['source_documents']

        # Print the result
        print("\n\n> Question:")
        print(query)
        print("\n> Answer:")
        print(answer)

        # Print the relevant sources used for the answer
        for document in docs:
            print("\n> " + document.metadata["source"] + ":")
            print(document.page_content)


def parse_arguments():
    parser = argparse.ArgumentParser(description='privateGPT: Ask questions to your documents without an internet connection, '
                                                 'using the power of LLMs.')
    parser.add_argument("--hide-source", "-S", action='store_true',
                        help='Use this flag to disable printing of source documents used for answers.')
    parser.add_argument("--mute-stream", "-M", action='store_true',
                        help='Use this flag to disable the streaming StdOut callback for LLMs.')
    return parser.parse_args()


def apply_html(text, color):
    # If the text contains a markdown-style table, recolor its cells; otherwise return the text as is.
    if " | " in text:
        # Locate the extent of the table inside the text
        table_start = text.find(" | ")
        table_end = text.rfind(" | ") + len(" | ")
        modified_table = text[table_start:table_end]
        # Color the cell separators (NOTE: the exact HTML markup here is an assumption)
        modified_table = modified_table.replace(" | ", f' <span style="color: {color};">|</span> ')
        # Replace the modified table back into the original text
        modified_text = text[:table_start] + modified_table + text[table_end:]
        return modified_text
    else:
        # Return the plain text as is
        return text


def add_text(history, text):
    # Add the user's message to the chat history, with the bot's response set to None for now
    if history is not None:
        history.append([apply_html(text, "blue"), None])
    return history, text


def bot(query, history, fileListHistory, k=5):
    # Parse the command line arguments
    args = parse_arguments()
    print("QUERY : " + query)
    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
    # Retrieve the k most relevant chunks for each query
    retriever = db.as_retriever(search_kwargs={"k": k})
    # activate/deactivate the streaming StdOut callback for LLMs
    callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
    # Prepare the LLM
    '''match model_type:
        case "LlamaCpp":
            llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
        case "GPT4All":
            llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
        case _:
            print(f"Model {model_type} not supported!")
            exit()'''
    if model_type == "LlamaCpp":
        llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
    elif model_type == "GPT4All":
        llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
    else:
        print(f"Model {model_type} not supported!")
        exit()
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=not args.hide_source)
    # Get the answer from the chain
    res = qa(query)
    answer, docs = res['result'], [] if args.hide_source else res['source_documents']

    # Print the result
    print("\n\n> Question:")
    print(query)
    print("\n> Answer:")
    print(answer)

    # Print the relevant sources used for the answer
    for document in docs:
        print("\n> " + document.metadata["source"] + ":")
        print(document.page_content)
    if answer is None:
        # The chain did not return an answer: report a timeout-style message instead
        timeout_message = "Unfortunately, the request to the language model timed out. Please try again after some time."
        print(timeout_message)
        if history is not None and len(history) > 0:
            # Update the chat history with the timeout message
            history[-1][1] = apply_html(timeout_message, "black")
    else:
        # Print the generated response
        print("\nGPT RESPONSE:\n")
        # print(answer['choices'][0]['message']['content'].strip())
        if history is not None and len(history) > 0:
            # Update the chat history with the bot's response
            history[-1][1] = apply_html(answer.strip(), "black")
    return history, fileListHistory


# Open the image and convert it to base64
with open(Path("bot.png"), "rb") as img_file:
    img_str = base64.b64encode(img_file.read()).decode()
html_code = f'''
| <img src="data:image/png;base64,{img_str}"> [ "I am Happy Bot, get your answers here" ] |
|---|