# Source: Hugging Face Space by runsdata, commit b91fbb2 ("Update app.py")
import os
import time
import openai
import gradio as gr
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chat_models import ChatOpenAI
from langchain.schema import AIMessage, HumanMessage, SystemMessage
# Vector store setup: queries are embedded with OpenAI embeddings and matched
# against a Chroma collection loaded from its persisted directory.
embeddings = OpenAIEmbeddings()

# Directory holding the persisted Chroma database.
db_directory = "chroma_db"
db = Chroma(embedding_function=embeddings, persist_directory=db_directory)

# Plain similarity search returning the 3 closest documents for each query.
retriever = db.as_retriever(
    search_kwargs=dict(k=3),
    search_type="similarity",
)
# The system prompt lives in a separate text file so it can be edited without
# touching code. Read it explicitly as UTF-8: relying on the platform default
# encoding can fail to decode non-ASCII characters (e.g. on Windows).
with open("system_prompt.txt", "r", encoding="utf-8") as prompt_file:
    ORIG_SYSTEM_MESSAGE_PROMPT = prompt_file.read()

# API key taken from the environment (None if the variable is unset).
openai.api_key = os.getenv("OPENAI_API_KEY")

# GPT-4 chat model; temperature 0 for the least random sampling.
chat = ChatOpenAI(model_name="gpt-4", temperature=0)
# Retrieval-augmented chat turn: look up related stories for the user's
# message, then ask the chat model for a reply.
def predict(history, input):
    """Generate the assistant's reply to ``input`` given the prior chat turns.

    Args:
        history: Completed ``(user_text, assistant_text)`` turns, oldest
            first (2-item lists or tuples). It must NOT already contain the
            turn for ``input``; otherwise the message is sent to the model
            twice. (``input`` shadows the builtin; the name is kept because
            it is part of this function's signature.)
        input: The new user message.

    Returns:
        A list of ``(user_text, assistant_text)`` tuples: every turn from
        ``history``, followed by ``(input, reply)`` for this call.
    """
    # Documents similar to the user's message, from the module-level retriever.
    context = retriever.get_relevant_documents(input)
    print(context)  # For debugging

    messages = [SystemMessage(content=ORIG_SYSTEM_MESSAGE_PROMPT)]
    pairs = []
    for human, ai in history:
        messages.append(HumanMessage(content=human))
        messages.append(AIMessage(content=ai))
        # Each Human/AI message pair above corresponds exactly to one turn of
        # `history`, so the returned pairs are simply those turns.
        pairs.append((human, ai))
    messages.append(HumanMessage(content=input))
    # Retrieved stories are supplied last, as a system message after the
    # user's turn.
    messages.append(
        SystemMessage(content=f"Here are some stories the user may like: {context}")
    )

    gpt_response = chat(messages)

    # Add the new exchange so callers can read the reply from pairs[-1].
    pairs.append((input, gpt_response.content))
    return pairs
# Gradio submit callback: clear the textbox and open a new chat turn.
def user(user_message, chatbot_history):
    """Append ``user_message`` as a new turn that awaits the bot's reply.

    Returns a 2-tuple for the (textbox, chatbot) outputs: an empty string,
    which clears the textbox, and a new history list ending in
    ``[user_message, ""]``. The list passed in is not modified.
    """
    pending_turn = [user_message, ""]
    updated_history = chatbot_history + [pending_turn]
    return "", updated_history
# Gradio callback that produces the assistant's reply for the newest turn and
# streams it into the chat window one character at a time.
def bot(chatbot_history):
    """Fill in the assistant half of the last chat turn, streaming it.

    Args:
        chatbot_history: The Chatbot value, a list of ``[user, assistant]``
            lists whose last entry is the pending turn added by ``user``
            (its assistant half is "").

    Yields:
        ``chatbot_history`` (mutated in place) after each character of the
        reply is added, which gives the UI a typing effect.
    """
    # The textbox was already cleared by `user`, so read the message back
    # from the pending turn.
    user_message = chatbot_history[-1][0]
    # Pass only the completed turns. The pending turn is the current message,
    # which predict() appends itself; passing the whole history sent the
    # user's message to the model twice, with an empty AI message between.
    pairs = predict(chatbot_history[:-1], user_message)
    _, ai_response = pairs[-1]  # Reply to the current message

    response_in_progress = ""
    for character in ai_response:
        response_in_progress += character
        chatbot_history[-1][1] = response_in_progress
        time.sleep(0.05)
        yield chatbot_history
# Gradio `.like` callback: report thumbs-up/down votes on bot messages and log
# the disliked ones.
def vote(data: gr.LikeData):
    """Record a user's like or dislike of a chatbot message.

    Args:
        data: Event payload injected by Gradio. ``data.liked`` is truthy for
            an upvote; ``data.value`` is the rated message.

    Every vote is printed to stdout. Downvoted messages are also appended to
    ``logs.txt``.
    """
    # f-strings rather than `+`, so a non-str value cannot raise TypeError.
    if data.liked:
        print(f"You upvoted this response: {data.value}")
        return

    print(f"You downvoted this response: {data.value}")
    # Explicit UTF-8 so non-ASCII model output can always be written,
    # regardless of the platform's default encoding.
    with open("logs.txt", "a", encoding="utf-8") as text_file:
        print(f"Disliked content: {data.value}", file=text_file)
# The Gradio App interface
with gr.Blocks() as demo:
    gr.Markdown("""<h1><center>Technocomplex Bot</center></h1>""")
    gr.Markdown("""<h3><center>This is a demo for Our Complex Relationships with Technology course, Duke, 2023</center></h3>""")
    chatbot = gr.Chatbot(label="Technocomplex Bot")
    textbox = gr.Textbox(label="Start chatting here and click 'Enter' to submit")
    clear = gr.Button("Clear")

    # Chain user and bot functions with `.then()`. `user` runs unqueued to
    # echo the message and clear the textbox; then `bot` streams the reply.
    textbox.submit(user, [textbox, chatbot], [textbox, chatbot], queue=False).then(
        bot, chatbot, chatbot,
    )
    # Returning None to the chatbot clears the conversation.
    clear.click(lambda: None, None, chatbot, queue=False)
    # Thumbs up/down on a message calls `vote` with the like event data.
    chatbot.like(vote, None, None)

# Enable queuing so the `bot` generator can stream partial updates.
demo.queue()

# Launch only when run as a script, not as a side effect of importing.
if __name__ == "__main__":
    demo.launch(debug=True, share=True)