Spaces:
Sleeping
Sleeping
File size: 6,110 Bytes
caa3582 d48b278 fcd27aa d48b278 caa3582 1fd9ed7 caa3582 ad7792a caa3582 fcd27aa caa3582 fcd27aa d48b278 fcd27aa d48b278 fcd27aa d48b278 fcd27aa d48b278 fcd27aa d48b278 fcd27aa d48b278 caa3582 d48b278 fcd27aa caa3582 0c42d61 caa3582 d48b278 caa3582 73a4a8c caa3582 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import os
import time
import csv
import shutil
from datetime import datetime
import openai
import gradio as gr
# Embeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
# Chat Q&A
from langchain.chat_models import ChatOpenAI
from langchain.schema import AIMessage, HumanMessage, SystemMessage
# Embeddings model used both when the vector store was built and for queries;
# the query-time model must match the one used at indexing time.
embeddings = OpenAIEmbeddings()
# Load the pre-built Chroma vector database from its persisted directory.
db_directory = "2023_12_04_chroma_db"
db = Chroma(persist_directory=db_directory, embedding_function=embeddings)
# Retriever performs a similarity search and returns the top-2 relevant chunks.
retriever = db.as_retriever(search_type='similarity', search_kwargs={"k":2})
# System prompt that seeds every conversation (read once at startup).
with open('system_prompt.txt', 'r') as file:
    ORIG_SYSTEM_MESSAGE_PROMPT = file.read()
openai.api_key = os.getenv("OPENAI_API_KEY")
# temperature=0 for deterministic answers.
chat = ChatOpenAI(model_name="gpt-3.5-turbo",temperature=0)
# Make sure we don't exceed estimation of token limit:
TOKEN_LIMIT = 4096 # GPT-3.5 Turbo token limit
BUFFER = 100 # Extra tokens to consider for incoming messages
PERSISTENT_LOG_PATH = "/data/downvoted_responses.csv" # File in which to log downvoted responses
LOCAL_LOG_PATH = "./data/downvoted_responses.csv" # Repo-local mirror of the persistent log
def estimate_tokens(texts):
    """Crudely estimate the token count of *texts* as the total number of
    whitespace-separated words.

    This undercounts what the model's tokenizer would report, but it is
    cheap and consistent, which is all the truncation logic needs.
    """
    # Generator expression: no need to materialize an intermediate list.
    return sum(len(text.split()) for text in texts)
def truncate_history(history):
    """Shrink a LangChain message list until it fits the token budget.

    While the estimated token count plus BUFFER exceeds TOKEN_LIMIT,
    drop the oldest exchange (indices 1 and 2) while always keeping
    index 0 (the system prompt). Stops once only 3 messages remain.
    Returns a new list; the caller's list is never mutated.
    """
    current = estimate_tokens([message.content for message in history])
    while len(history) > 3 and current + BUFFER > TOKEN_LIMIT:
        # Rebuild rather than delete in place: preserve the caller's list.
        history = [history[0], *history[3:]]
        current = estimate_tokens([message.content for message in history])
    return history
# Here is the langchain
def predict(history, input):
    """Answer *input* using retrieved context plus the chat history.

    Builds the LangChain message list (system prompt, prior exchanges,
    the new question, and a context-bearing system message), truncates it
    to the token budget, queries the model, and returns the updated list
    of (human, ai) content pairs.
    """
    context = retriever.get_relevant_documents(input)
    print(context) #For debugging

    messages = [SystemMessage(content=f"{ORIG_SYSTEM_MESSAGE_PROMPT}")]
    for human_text, ai_text in history:
        messages.append(HumanMessage(content=human_text))
        messages.append(AIMessage(content=ai_text))
    messages.append(HumanMessage(content=input))
    messages.append(SystemMessage(content=f"If you need to answer a question based on the previous message, here is some info: {context}"))

    # Truncate if history is too long
    messages = truncate_history(messages)
    gpt_response = chat(messages)

    # Rebuild (human, ai) pairs from whatever survived truncation.
    pairs = [
        (message.content, messages[i + 1].content)
        for i, message in enumerate(messages[:-1])
        if isinstance(message, HumanMessage) and isinstance(messages[i + 1], AIMessage)
    ]
    # Add the new AI response to the pairs for subsequent interactions
    pairs.append((input, gpt_response.content))
    return pairs
# Function to handle user message (this clears the interface)
def user(user_message, chatbot_history):
    """Record the new question in the history with an empty answer slot,
    and return "" so the textbox is cleared in the UI."""
    updated_history = chatbot_history + [[user_message, ""]]
    return "", updated_history
# Function to handle AI's response
def bot(chatbot_history):
    """Stream the AI answer for the most recent question, one character
    at a time (typewriter effect), yielding the history after each step.
    """
    # user() stored the question with an empty answer slot, so the last
    # history entry holds the question we need to answer.
    latest_question = chatbot_history[-1][0]
    qa_pairs = predict(chatbot_history, latest_question)
    answer = qa_pairs[-1][1]  # Latest model response

    shown_so_far = ""
    for character in answer:
        shown_so_far += character
        chatbot_history[-1][1] = shown_so_far
        time.sleep(0.05)  # Pace the typewriter effect
        yield chatbot_history
def log_to_csv(question, answer):
    """Append a downvoted Q/A pair to the persistent CSV log and mirror
    the log into the local repo copy.

    Creates the CSV (with a header row) on first use. The timestamp is
    local time formatted as YYYYMMDD_HH:MM:SS.
    """
    now = datetime.today().strftime("%Y%m%d_%H:%M:%S")
    # Decide on the header before opening in append mode.
    needs_header = not os.path.isfile(PERSISTENT_LOG_PATH)
    # newline="" is required by the csv module so the writer controls line
    # endings (otherwise blank rows appear on Windows); one open replaces
    # the previous create-then-append double open.
    with open(PERSISTENT_LOG_PATH, "a", newline="", encoding="utf-8") as csv_file:
        writer = csv.writer(csv_file)
        if needs_header:
            writer.writerow(["datetime", "user_question", "bot_response"])
        writer.writerow([now, question, answer])
    # Copy file from persistent storage to local repo; make sure the local
    # directory exists first (copyfile fails on a missing directory).
    os.makedirs(os.path.dirname(LOCAL_LOG_PATH), exist_ok=True)
    shutil.copyfile(PERSISTENT_LOG_PATH, LOCAL_LOG_PATH)
def get_voted_qa_pair(history, voted_answer):
    """Return the (question, answer) tuple from the chat history whose
    answer equals *voted_answer*, or None if no entry matches.

    Note: This is required because the 'vote' event handler only has
    access to the answer that was liked/disliked.
    """
    matches = ((q, a) for q, a in history if a == voted_answer)
    return next(matches, None)
def vote(data: gr.LikeData, history):
    """Handle a like/dislike event on a chatbot response.

    Upvotes are only printed; downvotes are looked up in the history and
    appended to the downvoted-responses CSV log (per the log file's
    purpose — it records downvoted responses only).
    """
    print(history)
    if data.liked:
        print("You upvoted this response: " + data.value)
        return
    print("You downvoted this response: " + data.value)
    # Find Q/A pair that was disliked
    qa_pair = get_voted_qa_pair(history, data.value)
    if qa_pair is None:
        # Defensive: previously an unmatched answer crashed on unpacking None.
        print("Could not find the voted response in the chat history.")
        return
    question, answer = qa_pair
    log_to_csv(question, answer)
# The Gradio App interface
with gr.Blocks() as demo:
    gr.Markdown("""<h1><center>DylanAI by CIONIC</center></h1>""")
    gr.Markdown("""<p><center>For best results, please ask DylanAI one question at a time. Unlike Human Dylan, DylanAI cannot multitask.</center></p>""")
    chatbot = gr.Chatbot(label="DylanAI")
    textbox = gr.Textbox(label="Type your question here")
    clear = gr.Button("Clear")
    # Chain user and bot functions with `.then()`: user() records the question
    # and clears the textbox, then bot() streams the answer into the chatbot.
    textbox.submit(user, [textbox, chatbot], [textbox, chatbot], queue=False).then(
        bot, chatbot, chatbot,
    )
    # Reset the chat window.
    clear.click(lambda: None, None, chatbot, queue=False)
    # Thumbs up/down on a bot message fires vote() with the full history.
    chatbot.like(vote, chatbot, None)
# Enable queuing (needed so the streaming bot() generator can yield updates)
demo.queue()
demo.launch(debug=True, share=True)
|