# Page header captured along with this source (kept as comments so the file parses):
# runsdata — commit 5961f5c "Update app.py" (raw | history | blame, 5.73 kB)
import os
import time
import csv
import shutil
from datetime import datetime
import openai
import gradio as gr
import pandas as pd
# Embeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
# Chat Q&A
from langchain.chat_models import ChatOpenAI
from langchain.schema import AIMessage, HumanMessage, SystemMessage
# --- Retrieval setup -------------------------------------------------------
# OpenAI embedding model used to embed queries for the vector store.
embeddings = OpenAIEmbeddings()

# Load the pre-built Chroma vector store from its persisted directory.
db_directory = "./docs/2023_12_21_chroma_db"
db = Chroma(persist_directory=db_directory, embedding_function=embeddings)

# Similarity-search retriever that returns the top 2 most relevant chunks.
retriever = db.as_retriever(search_type='similarity', search_kwargs={"k": 2})

# --- Prompt material -------------------------------------------------------
# Read explicitly as UTF-8 so the prompts load identically regardless of the
# host's default locale encoding.
with open('system_prompt.txt', 'r', encoding='utf-8') as file:
    ORIG_SYSTEM_MESSAGE_PROMPT = file.read()
with open('user_info_simulated.txt', 'r', encoding='utf-8') as file:
    user_info_simulated = file.read()

# --- Chat model ------------------------------------------------------------
openai.api_key = os.getenv("OPENAI_API_KEY")
# Swap model_name to "gpt-3.5-turbo" for faster, cheaper experiments.
chat = ChatOpenAI(model_name="gpt-4", temperature=0)

# --- Limits and log paths --------------------------------------------------
# Prompt budget enforced by truncate_history(). The value was originally sized
# for gpt-3.5-turbo's 4096-token limit and is kept as a conservative cap now
# that the model is gpt-4. Tokens are estimated by word count
# (estimate_tokens), so the budget is approximate.
TOKEN_LIMIT = 4096
BUFFER = 100  # Extra tokens to consider for incoming messages
# NOTE(review): the two paths below are not referenced by vote(), which
# appends downvotes to "output.txt" instead.
PERSISTENT_LOG_PATH = "/data/downvoted_responses.csv"  # File in which to log downvoted responses
LOCAL_LOG_PATH = "./data/downvoted_responses.csv"
def estimate_tokens(texts):
    """Roughly estimate how many tokens a collection of strings contains.

    Whitespace-separated words are counted as a cheap stand-in for model
    tokens, so the result is only an approximation.

    Args:
        texts: Iterable of strings.

    Returns:
        Total number of whitespace-separated words across all strings.
    """
    word_total = 0
    for text in texts:
        word_total += len(text.split())
    return word_total
def truncate_history(history):
    """Drop the oldest messages until the prompt fits the token budget.

    The first message is always kept. predict() builds the list so that this
    first message is its system prompt. While the estimated size plus
    ``BUFFER`` exceeds ``TOKEN_LIMIT`` and more than three messages remain,
    the two messages right after the first one are removed. In predict()'s
    layout, those two messages are the oldest human/AI exchange.

    Args:
        history: List of message objects exposing ``.content``.

    Returns:
        ``history`` itself if it already fits, otherwise a new, shorter list.
    """
    def over_budget(messages):
        return estimate_tokens([m.content for m in messages]) + BUFFER > TOKEN_LIMIT

    trimmed = history
    while over_budget(trimmed) and len(trimmed) > 3:
        trimmed = trimmed[:1] + trimmed[3:]
    return trimmed
def get_full_context(input):
    """Render troubleshooting-table rows relevant to ``input`` as text.

    Each document returned by ``retriever`` carries a positional row number in
    ``metadata['index']``. The matching row of the troubleshooting CSV is
    rendered with ``to_string(index=False)``, and every rendered row is
    followed by a blank line.

    Args:
        input: The user's query text (name kept for caller compatibility).

    Returns:
        The concatenated rows, or an empty string if nothing was retrieved.
    """
    matches = retriever.get_relevant_documents(input)
    table = pd.read_csv("./docs/Troubleshooting_Table.csv")
    rendered_rows = (
        table.iloc[[doc.metadata['index']]].to_string(index=False)
        for doc in matches
    )
    return "".join(f"{row}\n\n" for row in rendered_rows)
# Per-case conversation state used by predict() (and reset by reset_flag()).
# NOTE(review): these are module globals, so every browser session served by
# this process shares them; per-session state would need gr.State.
is_first_run = True  # True until the first message of a case has been handled
_case_context = ""  # Troubleshooting rows retrieved for the current case


def predict(history, input):
    """Ask the chat model for a reply to ``input`` given the prior turns.

    On the first message of a case (``is_first_run`` is True), relevant rows of
    the troubleshooting table are retrieved and cached. Every later message in
    the same case re-sends that cached table. This matters because the prompt
    is rebuilt from ``history`` on each call, and ``history`` holds only
    human/AI text, so the table would otherwise be lost after the first turn.

    Args:
        history: Prior ``(human, ai)`` turns.
        input: The new user message (name kept for caller compatibility).

    Returns:
        The ``(human, ai)`` pairs that survived truncation, followed by
        ``(input, <model reply>)``.
    """
    global is_first_run, _case_context
    if is_first_run:
        # Retrieve once per case; refreshed whenever is_first_run is set again.
        _case_context = get_full_context(input)
        print(_case_context)  # For debugging
        is_first_run = False
    context = _case_context

    messages = [
        SystemMessage(content=f"{ORIG_SYSTEM_MESSAGE_PROMPT}, here is the user information: {user_info_simulated}")
    ]
    for human, ai in history:
        messages.append(HumanMessage(content=human))
        messages.append(AIMessage(content=ai))
    messages.append(HumanMessage(content=input))
    messages.append(SystemMessage(content=f"Here is a table with some potentially useful information for troubleshooting: {context}"))

    # Keep the prompt within the estimated token budget.
    messages = truncate_history(messages)
    gpt_response = chat(messages)

    # Recover adjacent Human -> AI pairs from the (possibly truncated) prompt.
    pairs = [
        (first.content, second.content)
        for first, second in zip(messages, messages[1:])
        if isinstance(first, HumanMessage) and isinstance(second, AIMessage)
    ]
    # Add the new AI response for subsequent interactions.
    pairs.append((input, gpt_response.content))
    return pairs
def user(user_message, chatbot_history):
    """Record a submitted message as a pending turn and clear the textbox.

    Args:
        user_message: Text the user just submitted.
        chatbot_history: Current chatbot value, a list of ``[user, ai]`` turns.

    Returns:
        A ``("", new_history)`` tuple. The empty string clears the textbox.
        ``new_history`` is a new list ending in ``[user_message, ""]``, whose
        AI slot ``bot`` fills in afterwards.
    """
    pending_turn = [user_message, ""]
    return "", chatbot_history + [pending_turn]
def bot(chatbot_history):
    """Generate the AI reply for the newest chatbot turn and stream it out.

    ``user`` has already appended ``[user_message, ""]`` (and cleared the
    textbox), so the pending message is read from the last entry. Only the
    earlier, completed turns are passed to ``predict`` as history. Passing the
    pending entry as well would send the user's message twice, once paired
    with an empty AI reply.

    The reply is revealed one character at a time, 0.05 s apart, to mimic
    typing. Each partial state is yielded so Gradio can refresh the chatbot.

    Args:
        chatbot_history: Chatbot value, a list of mutable ``[user, ai]`` turns.

    Yields:
        The same history list, with the last AI slot progressively filled in.
    """
    user_message = chatbot_history[-1][0]
    pairs = predict(chatbot_history[:-1], user_message)
    _, ai_response = pairs[-1]  # The newest pair holds the fresh reply.
    for end in range(1, len(ai_response) + 1):
        chatbot_history[-1][1] = ai_response[:end]
        time.sleep(0.05)
        yield chatbot_history
def vote(data: gr.LikeData):
    """Handle thumbs-up / thumbs-down feedback on a chatbot message.

    Upvotes are only printed. Downvotes are printed and also appended to
    ``output.txt`` in the working directory.

    Args:
        data: Gradio like-event payload. ``data.liked`` gives the vote
            direction and ``data.value`` the rated message.
    """
    # f-strings rather than "+" so a non-str ``data.value`` cannot raise
    # TypeError (presumably possible in some Gradio versions; confirm).
    if data.liked:
        print(f"You upvoted this response: {data.value}")
        return
    print(f"You downvoted this response: {data.value}")
    # Explicit UTF-8 so non-ASCII responses are logged consistently.
    # TODO: the module-level PERSISTENT_LOG_PATH / LOCAL_LOG_PATH look intended
    # for this log but are not used here.
    with open("output.txt", "a", encoding="utf-8") as text_file:
        print(f"Disliked content: {data.value}", file=text_file)
def reset_flag():
    """Mark the next submitted message as the start of a new case.

    Wired to the Clear button so that the following ``predict`` call treats
    its message as the first of a case and retrieves troubleshooting context.
    """
    global is_first_run
    is_first_run = True
# --- Gradio app interface --------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("""<h1><center>TROUBLESHOOTING Bot by CIONIC</center></h1>""")
    gr.Markdown("""<p><center>To open a new case, press the clear button.</center></p>""")
    chatbot = gr.Chatbot()
    textbox = gr.Textbox()

    # Clearing the chat starts a new case: reset_flag makes the next message
    # be treated as the first one of a case.
    clear_button = gr.ClearButton(components=[chatbot])
    clear_button.click(reset_flag, None, None)

    # On submit, first echo the message and clear the textbox (unqueued, so it
    # is immediate), then stream the bot's reply into the chatbot.
    textbox.submit(user, [textbox, chatbot], [textbox, chatbot], queue=False).then(
        bot, chatbot, chatbot,
    )

    # Thumbs-up / thumbs-down feedback on individual messages.
    chatbot.like(vote, None, None)

# Enable queuing so the generator-based bot can stream partial replies.
demo.queue()

# Launch only when run as a script (e.g. `python app.py`); importing the module
# no longer starts a blocking server.
if __name__ == "__main__":
    demo.launch(debug=True)