File size: 6,110 Bytes
caa3582
 
d48b278
fcd27aa
d48b278
caa3582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fd9ed7
caa3582
 
 
 
 
 
 
 
 
 
 
ad7792a
caa3582
 
 
 
fcd27aa
 
caa3582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcd27aa
d48b278
 
fcd27aa
d48b278
fcd27aa
d48b278
 
 
 
fcd27aa
d48b278
 
 
fcd27aa
 
 
d48b278
 
 
 
 
 
 
 
 
fcd27aa
d48b278
 
caa3582
 
 
 
d48b278
 
fcd27aa
caa3582
 
 
 
 
0c42d61
 
caa3582
 
 
 
 
 
 
d48b278
caa3582
 
 
73a4a8c
caa3582
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import time
import csv
import shutil
from datetime import datetime
import openai
import gradio as gr

# Embeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma

# Chat Q&A
from langchain.chat_models import ChatOpenAI
from langchain.schema import AIMessage, HumanMessage, SystemMessage

# This sets up OpenAI embeddings model
embeddings = OpenAIEmbeddings()

# Loads the Chroma vector database from its persisted directory; queries are
# embedded with the same OpenAI embeddings model.
db_directory = "2023_12_04_chroma_db"
db = Chroma(persist_directory=db_directory, embedding_function=embeddings)

# Retriever that runs a similarity search and returns the top 2 most relevant
# documents (chunks) for a query.
retriever = db.as_retriever(search_type='similarity', search_kwargs={"k": 2})

# System prompt placed first in every conversation sent to the model. Read
# explicitly as UTF-8 so the result does not depend on the host's locale.
with open('system_prompt.txt', 'r', encoding='utf-8') as file:
    ORIG_SYSTEM_MESSAGE_PROMPT = file.read()

openai.api_key = os.getenv("OPENAI_API_KEY")

# Chat model; temperature=0 makes replies as deterministic as the API allows.
chat = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)

# Make sure we don't exceed estimation of token limit:
TOKEN_LIMIT = 4096  # GPT-3.5 Turbo token limit
BUFFER = 100  # Extra tokens to consider for incoming messages
PERSISTENT_LOG_PATH = "/data/downvoted_responses.csv"  # File in which to log downvoted responses
LOCAL_LOG_PATH = "./data/downvoted_responses.csv"  # Local copy refreshed by log_to_csv after each write

def estimate_tokens(texts):
    """Approximate the token count of *texts* by their total word count.

    A "word" is a whitespace-separated run as produced by ``str.split()``.
    This is a cheap stand-in for the model's real tokenizer, used only to
    decide when the conversation should be trimmed.
    """
    total = 0
    for text in texts:
        total += len(text.split())
    return total

def truncate_history(history):
    """Trim *history* until its estimated size fits within the budget.

    While the word-count estimate plus ``BUFFER`` exceeds ``TOKEN_LIMIT`` and
    more than three messages remain, the messages at positions 1 and 2 are
    dropped. The first message is always kept. In the layout built by
    ``predict`` (system prompt, then human/AI pairs), this removes the oldest
    human/AI exchange. The argument is never mutated; a new list is returned
    whenever something was dropped.
    """
    def over_budget(messages):
        word_total = estimate_tokens([message.content for message in messages])
        return word_total + BUFFER > TOKEN_LIMIT

    trimmed = history
    while over_budget(trimmed) and len(trimmed) > 3:
        head, tail = trimmed[:1], trimmed[3:]
        trimmed = head + tail
    return trimmed

# Here is the langchain
def predict(history, input):
    """Answer *input* with the chat model, using retrieved docs and prior turns.

    Args:
        history: Prior ``(human, ai)`` message pairs (e.g. the Gradio chatbot
            value). Turns whose AI reply is empty or ``None`` -- such as a
            pending ``[question, ""]`` entry appended before the answer
            exists -- are skipped. Otherwise the question would be sent twice,
            or an empty/invalid AI message would be built.
        input: The new user question.

    Returns:
        The ``(human, ai)`` pairs that survived truncation, followed by
        ``(input, reply)`` for the new answer as the last element.
    """
    context = retriever.get_relevant_documents(input)
    print(context)  # For debugging

    messages = [SystemMessage(content=ORIG_SYSTEM_MESSAGE_PROMPT)]
    for human, ai in history:
        if not ai:
            # Unanswered/pending turn: nothing useful to replay to the model.
            continue
        messages.append(HumanMessage(content=human))
        messages.append(AIMessage(content=ai))
    messages.append(HumanMessage(content=input))
    messages.append(SystemMessage(content=f"If you need to answer a question based on the previous message, here is some info: {context}"))

    # Truncate if history is too long
    messages = truncate_history(messages)

    gpt_response = chat(messages)

    # Recover the (human, ai) pairs that survived truncation by scanning
    # adjacent message pairs.
    pairs = [
        (current.content, following.content)
        for current, following in zip(messages, messages[1:])
        if isinstance(current, HumanMessage) and isinstance(following, AIMessage)
    ]

    # Add the new AI response to the pairs for subsequent interactions
    pairs.append((input, gpt_response.content))
    return pairs

# Function to handle user message (this clears the interface)
def user(user_message, chatbot_history):
    """Append the submitted message to the chat as a not-yet-answered turn.

    Returns:
        ``("", new_history)``. The empty string clears the textbox, and
        *new_history* is the old history plus ``[user_message, ""]``, whose
        reply ``bot`` fills in later. A ``None`` history is treated as empty
        (presumably possible after the Clear button resets the chatbot to
        ``None``). The incoming history list is not mutated.
    """
    previous_turns = chatbot_history or []
    return "", previous_turns + [[user_message, ""]]

# Function to handle AI's response
def bot(chatbot_history):
    """Stream the model's answer to the newest chat turn into the history.

    Expects the last entry of *chatbot_history* to be ``[question, ""]``, as
    appended by ``user``. Only the earlier, completed turns are passed to
    ``predict`` as history. Passing the pending turn too would put the
    question into the prompt twice, with an empty AI message in between.

    Yields:
        The same history list after each additional character of the reply
        is revealed (0.05 s apart), giving a typing effect in the UI.
    """
    # The textbox was cleared by `user`, so read the question back from history.
    user_message = chatbot_history[-1][0]
    pairs = predict(chatbot_history[:-1], user_message)
    _, ai_response = pairs[-1]  # The newest (question, reply) pair.

    for shown in range(1, len(ai_response) + 1):
        chatbot_history[-1][1] = ai_response[:shown]
        time.sleep(0.05)
        yield chatbot_history

def log_to_csv(question, answer, *, persistent_path=None, local_path=None):
    """Append a downvoted question/answer pair to a CSV log, then copy it locally.

    The log is created with a ``datetime,user_question,bot_response`` header
    when it does not exist yet (or is empty). After each append, the whole
    file is copied to the local path.

    Args:
        question: The user's question.
        answer: The bot response that was downvoted.
        persistent_path: Log file to append to; defaults to
            ``PERSISTENT_LOG_PATH``.
        local_path: Destination of the copy; defaults to ``LOCAL_LOG_PATH``.
    """
    persistent_path = PERSISTENT_LOG_PATH if persistent_path is None else persistent_path
    local_path = LOCAL_LOG_PATH if local_path is None else local_path

    now = datetime.today().strftime("%Y%m%d_%H:%M:%S")

    os.makedirs(os.path.dirname(persistent_path) or ".", exist_ok=True)
    write_header = (
        not os.path.isfile(persistent_path) or os.path.getsize(persistent_path) == 0
    )

    # newline="" is required by the csv module (avoids blank rows / mangled
    # embedded newlines); UTF-8 so any model output can be written.
    with open(persistent_path, "a", newline="", encoding="utf-8") as csv_file:
        writer = csv.writer(csv_file)
        if write_header:
            writer.writerow(["datetime", "user_question", "bot_response"])
        # Write the disliked message to the CSV
        writer.writerow([now, question, answer])

    # Copy file from persistent storage to local repo (creating its folder if needed).
    os.makedirs(os.path.dirname(local_path) or ".", exist_ok=True)
    shutil.copyfile(persistent_path, local_path)

def get_voted_qa_pair(history, voted_answer):
    """Look up the chat turn whose bot answer equals *voted_answer*.

    The like/dislike event only carries the answer text, so the question that
    produced it has to be recovered from the history. Returns the earliest
    matching ``(question, answer)`` tuple, or ``None`` when no turn matches.
    """
    matching_turns = (
        (question, answer) for question, answer in history if answer == voted_answer
    )
    return next(matching_turns, None)

def _find_voted_pair(history, data):
    """Return the ``(question, answer)`` turn a like/dislike refers to, or ``None``."""
    index = getattr(data, "index", None)
    # For a two-dimensional component such as the tuple-format Chatbot, Gradio
    # reports the index as (row, column). Trust the row only if its answer
    # agrees with the reported value; otherwise fall back to a text match.
    if isinstance(index, (list, tuple)) and index:
        row = index[0]
        if isinstance(row, int) and 0 <= row < len(history):
            question, answer = history[row]
            if answer == data.value:
                return question, answer
    return get_voted_qa_pair(history, data.value)


def vote(data: gr.LikeData, history):
    """Handle a like/dislike on a bot message.

    Upvotes are only printed. Downvotes are printed, and the matching
    question/answer pair is appended to the CSV log via ``log_to_csv``. The
    voted turn is located by ``data.index`` when it is consistent with
    ``data.value``; otherwise the first turn with that answer text is used.
    If no turn matches, a message is printed and nothing is logged.
    """
    print(history)
    if data.liked:
        print(f"You upvoted this response: {data.value}")
        return

    print(f"You downvoted this response: {data.value}")
    # Find Q/A pair that was disliked
    qa_pair = _find_voted_pair(history, data)
    if qa_pair is None:
        print("Could not find the downvoted response in the chat history; not logged.")
        return
    question, answer = qa_pair
    log_to_csv(question, answer)

# The Gradio App interface
# The Gradio App interface: a chatbot, a question textbox, and a Clear button.
with gr.Blocks() as demo:
    gr.Markdown("""<h1><center>DylanAI by CIONIC</center></h1>""")
    gr.Markdown("""<p><center>For best results, please ask DylanAI one question at a time. Unlike Human Dylan, DylanAI cannot multitask.</center></p>""")
    chatbot = gr.Chatbot(label="DylanAI")
    textbox = gr.Textbox(label="Type your question here")
    clear = gr.Button("Clear")
    
    # Chain user and bot functions with `.then()`:
    # `user` (run with queue=False) clears the textbox and appends the pending
    # [question, ""] turn; `bot` then fills in the reply, yielding updates so
    # the answer appears character by character.
    textbox.submit(user, [textbox, chatbot], [textbox, chatbot], queue=False).then(
        bot, chatbot, chatbot,
    )
    # Clear sets the chatbot's value to None (the lambda's return value).
    clear.click(lambda: None, None, chatbot, queue=False)
    # Like/dislike on a chatbot message calls `vote` with the event data and the
    # current chatbot history; nothing is written back to the UI (outputs=None).
    chatbot.like(vote, chatbot, None)

# Enable queuing
# NOTE(review): presumably required for the generator `bot` to stream its
# updates -- confirm for the Gradio version in use.
demo.queue()
# NOTE(review): launches at import time; share=True requests a public share
# link and debug=True is Gradio's debug mode -- confirm both are wanted in deployment.
demo.launch(debug=True, share=True)