# private-gpt / app.py
# Last edit: praneeth dodedu — "changed icons" (commit 630760f)
#!/usr/bin/env python3
from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma
from langchain.llms import GPT4All, LlamaCpp
import os
import argparse
from pathlib import Path
import base64
import gradio as gr
# Pull local configuration from a .env file so nothing is hard-coded here.
load_dotenv()

embeddings_model_name = os.getenv("EMBEDDINGS_MODEL_NAME")
persist_directory = os.getenv("PERSIST_DIRECTORY")
model_type = os.getenv("MODEL_TYPE")
model_path = os.getenv("MODEL_PATH")
model_n_ctx = os.getenv("MODEL_N_CTX")

from constants import CHROMA_SETTINGS
def main():
    """Run an interactive console Q&A loop over the local document index.

    Builds a RetrievalQA chain from the persisted Chroma vector store and the
    configured local LLM (LlamaCpp or GPT4All), then repeatedly prompts the
    user for queries until they type "exit". Sources are printed unless the
    --hide-source flag was given.
    """
    # Parse the command line arguments.
    args = parse_arguments()
    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
    retriever = db.as_retriever()
    # Activate/deactivate the streaming StdOut callback for LLMs.
    callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
    # Prepare the LLM.
    if model_type == "LlamaCpp":
        llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
    elif model_type == "GPT4All":
        llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
    else:
        # BUG FIX: the original `exit;` only referenced the builtin without
        # calling it (a no-op), so execution fell through to an undefined
        # `llm` and crashed with NameError. Return early instead.
        print(f"Model {model_type} not supported!")
        return
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=not args.hide_source)
    # Interactive questions and answers.
    while True:
        query = input("\nEnter a query: ")
        if query == "exit":
            break
        # Get the answer from the chain.
        res = qa(query)
        answer, docs = res['result'], [] if args.hide_source else res['source_documents']
        # Print the result.
        print("\n\n> Question:")
        print(query)
        print("\n> Answer:")
        print(answer)
        # Print the relevant sources used for the answer.
        for document in docs:
            print("\n> " + document.metadata["source"] + ":")
            print(document.page_content)
def parse_arguments():
    """Define and parse the command line options for privateGPT.

    Returns:
        argparse.Namespace with boolean attributes `hide_source` and
        `mute_stream`.
    """
    parser = argparse.ArgumentParser(
        description='privateGPT: Ask questions to your documents without an internet connection, '
                    'using the power of LLMs.')
    # Both options are simple on/off switches.
    option_specs = (
        (("--hide-source", "-S"),
         'Use this flag to disable printing of source documents used for answers.'),
        (("--mute-stream", "-M"),
         'Use this flag to disable the streaming StdOut callback for LLMs.'),
    )
    for flags, help_text in option_specs:
        parser.add_argument(*flags, action='store_true', help=help_text)
    return parser.parse_args()
def apply_html(text, color):
    """Return `text` with the first embedded <table> block restyled inline.

    Adds border/padding styles so the table renders nicely inside a Gradio
    chat bubble. Text without a complete <table>...</table> pair is returned
    unchanged. The `color` parameter is kept for interface compatibility but
    is not used here.
    """
    if "<table>" not in text or "</table>" not in text:
        # Plain text: nothing to restyle.
        return text
    start = text.index("<table>")
    end = text.index("</table>") + len("</table>")
    # Inject inline styles into the table, header and data cells.
    styled = (
        text[start:end]
        .replace("<table>", "<table style='border-collapse: collapse;'>")
        .replace("<th>", "<th style='border: 1px solid #ddd; padding: 8px; background-color: #f2f2f2;'>")
        .replace("<td>", "<td style='border: 1px solid #ddd; padding: 8px;'>")
    )
    # Splice the restyled table back into the surrounding text.
    return text[:start] + styled + text[end:]
def add_text(history, text):
    """Append the user's message to the chat history.

    The bot reply slot is set to None so a later `bot` call can fill it in.
    A None history is passed through untouched.
    """
    if history is None:
        return history, text
    history.append([apply_html(text, "blue"), None])
    return history, text
def bot(query, history, fileListHistory, k=5):
    """Answer `query` with the RetrievalQA chain and update the chat history.

    Args:
        query: The user's question.
        history: Gradio chat history; the last entry's reply slot is filled
            with the generated answer.
        fileListHistory: The references panel history, returned unchanged.
        k: Unused here; kept for interface compatibility.

    Returns:
        (history, fileListHistory) tuple for Gradio to render.
    """
    # Parse the command line arguments.
    args = parse_arguments()
    print("QUERY : " + query)
    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
    retriever = db.as_retriever()
    # Activate/deactivate the streaming StdOut callback for LLMs.
    callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
    # Prepare the LLM.
    if model_type == "LlamaCpp":
        llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
    elif model_type == "GPT4All":
        llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
    else:
        # BUG FIX: the original `exit;` only referenced the builtin without
        # calling it (a no-op), so execution continued into an undefined
        # `llm`. Return the histories unchanged instead.
        print(f"Model {model_type} not supported!")
        return history, fileListHistory
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=not args.hide_source)
    # Get the answer from the chain.
    res = qa(query)
    answer, docs = res['result'], [] if args.hide_source else res['source_documents']
    # Print the result.
    print("\n\n> Question:")
    print(query)
    print("\n> Answer:")
    print(answer)
    # Print the relevant sources used for the answer.
    for document in docs:
        print("\n> " + document.metadata["source"] + ":")
        print(document.page_content)
    if answer is None:
        # BUG FIX: the original branch called `answer.text.strip()` while
        # `answer` was None, raising AttributeError. Show the timeout
        # message in the chat instead.
        timeout_msg = "Unfortunately, the connection to ChatGPT timed out. Please try after some time."
        print(timeout_msg)
        if history is not None and len(history) > 0:
            # Update the chat history with the timeout notice.
            history[-1][1] = apply_html(timeout_msg, "black")
    else:
        print("\nGPT RESPONSE:\n")
        if history is not None and len(history) > 0:
            # Update the chat history with the bot's response.
            history[-1][1] = apply_html(answer.strip(), "black")
    return history, fileListHistory
# Inline the chatbot avatar as base64 so the HTML header needs no static
# file route. Path.read_bytes() opens, reads, and closes "bot.png" for us.
img_str = base64.b64encode(Path("bot.png").read_bytes()).decode()
# Static HTML header for the Gradio page: a centered avatar (embedded via the
# base64 `img_str` above) next to the app title and a tagline. CSS braces are
# doubled to escape them inside the f-string.
html_code = f'''
<!DOCTYPE html>
<html>
<head>
<style>
.center {{
display: flex;
justify-content: center;
align-items: center;
margin-top: -40px; /* adjust this value as per your requirement */
margin-bottom: 5px;
}}
.large-text {{
font-size: 40px;
font-family: Arial, Helvetica, sans-serif;
font-weight: 900 !important;
margin-left: 5px;
color: #5b5b5b !important;
}}
.image-container {{
display: inline-block;
vertical-align: middle;
height: 10px; /* Twice the font-size */
margin-bottom: 5px;
}}
</style>
</head>
<body>
<div class="center">
<img src="data:image/jpg;base64,{img_str}" alt="RyBOT image" class="image-container" />
<strong class="large-text">Happy Bot</strong>
</div>
<br>
<div class="center">
<h3> [ "I am Happy Bot, get your answers here" ] </h3>
</div>
</body>
</html>
'''
# Extra page CSS passed to gr.Blocks: light-blue feedback textarea and a
# grey app background.
css = """
.feedback textarea {background-color: #e9f0f7}
.gradio-container {background-color: #eeeeee}
"""
def clear_textbox():
    """Clear the query textbox by returning None to its Gradio output."""
    print("Calling CLEAR")
# Assemble the Gradio UI: HTML header, two chat panels (conversation +
# references), a query textbox, and a Send button. Submitting text runs
# add_text -> bot -> clear_textbox in sequence.
with gr.Blocks(theme=gr.themes.Soft(), css=css, title="RyBOT") as demo:
    gr.HTML(html_code)
    # Main conversation panel.
    chatbot = gr.Chatbot([], elem_id="chatbot", label="Chat", color_map=["blue","grey"]).style(height=450)
    # Secondary panel intended to list reference documents.
    fileListBot = gr.Chatbot([], elem_id="fileListBot", label="References", color_map=["blue","grey"]).style(height=150)
    txt = gr.Textbox(
        label="Type your query here:",
        placeholder="What would you like to find today?"
    ).style(container=True)
    # Pressing Enter in the textbox: append the user message, generate the
    # answer, then clear the input field.
    txt.submit(
        add_text,
        [chatbot, txt],
        [chatbot, txt]
    ).then(
        bot,
        [txt, chatbot, fileListBot],
        [chatbot, fileListBot]
    ).then(
        clear_textbox,
        inputs=None,
        outputs=[txt]
    )
    # The Send button mirrors the textbox submit behaviour.
    btn = gr.Button(value="Send")
    btn.click(
        add_text,
        [chatbot, txt],
        [chatbot, txt],
    ).then(
        bot,
        [txt, chatbot, fileListBot],
        [chatbot, fileListBot]
    ).then(
        clear_textbox,
        inputs=None,
        outputs=[txt]
    )

# Start the local Gradio server (blocks until the app is stopped).
demo.launch()