File size: 6,444 Bytes
019a8fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from Config import Config
from openai import OpenAI
import streamlit as st
from langchain.vectorstores import Pinecone
from pinecone import Pinecone

class Obnoxious_Agent:
    """Ask an LLM to classify a user message's tone and flag it when rude."""

    # Punctuation/quotes stripped from each reply token before comparison, so
    # replies like "Rude." or "'rude'" are still recognised.
    _STRIP_CHARS = ".,!?;:'\"()[]"

    def __init__(self, client) -> None:
        # OpenAI-style client; only ``client.chat.completions.create`` is used.
        self.client = client
        # Last classification prompt built by ``set_prompt``.
        self.prompt = ""

    def set_prompt(self, prompt):
        """Build the tone-classification prompt for *prompt* and store it on ``self.prompt``."""
        self.prompt = f"Would you describe the tone of this prompt as 'rude', 'polite', or 'neutral'?: '{prompt}'"

    def extract_action(self, response) -> bool:
        """Return True when the model's reply contains the word 'rude'.

        Matching is case-insensitive and ignores surrounding punctuation and
        quotes on each whitespace-separated token. A missing (``None``) reply
        counts as not rude.

        NOTE(review): a reply such as "not rude" is still flagged, because
        only the presence of the token is checked.
        """
        content = response.choices[0].message.content or ""
        tokens = (token.strip(self._STRIP_CHARS) for token in content.lower().split())
        return "rude" in tokens

    def check_query(self, query, model="gpt-3.5-turbo"):
        """Classify *query* with one chat-completion call.

        Args:
            query: The raw user message.
            model: Chat model name to use (defaults to the previously
                hard-coded "gpt-3.5-turbo").

        Returns:
            True if the model described the message as rude, else False.
        """
        self.set_prompt(query)
        response = self.client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": self.prompt}],
        )
        return self.extract_action(response)

class Query_Agent:
    """Embed a query with OpenAI and retrieve matching passages from a vector index."""

    # Separator placed between retrieved passages so their text does not run together.
    _PASSAGE_SEPARATOR = "\n\n"

    def __init__(self, pinecone_index, openai_client, embeddings) -> None:
        self.pinecone_index = pinecone_index
        # OpenAI-style client; ``embeddings.create`` is used to embed queries.
        self.openai_client = openai_client
        # Object whose ``.query(...)`` performs the vector search.
        self.embeddings = embeddings
        self.prompt = ""

    def get_embedding(self, text, model="text-embedding-ada-002"):
        """Return the embedding vector for *text* (newlines replaced by spaces)."""
        cleaned = text.replace("\n", " ")
        result = self.openai_client.embeddings.create(input=[cleaned], model=model)
        return result.data[0].embedding

    def query_vector_store(self, query, k=5, namespace="ns1", min_score=0.75):
        """Retrieve passages relevant to *query*.

        Args:
            query: User question to embed and search for.
            k: Number of nearest matches to request from the index.
            namespace: Index namespace to search (previously hard-coded 'ns1').
            min_score: Matches must score strictly above this to be kept.

        Returns:
            The kept passages' ``metadata['text']`` joined by blank lines, or
            "" when nothing passes the threshold.
        """
        query_embedding = self.get_embedding(query)
        # NOTE(review): the embedding is wrapped in an extra list here; the
        # Pinecone ``Index.query(vector=...)`` docs describe a flat list of
        # floats — confirm the installed client accepts this nesting.
        response = self.embeddings.query(
            vector=[query_embedding],
            top_k=k,
            namespace=namespace,
            include_metadata=True,
        )
        return self.extract_action(response, query, min_score=min_score)

    def set_prompt(self, prompt):
        """Store *prompt* on ``self.prompt`` and return it."""
        self.prompt = prompt
        return self.prompt

    def extract_action(self, response, query=None, min_score=0.75):
        """Collect the text of matches scoring strictly above *min_score*.

        *query* is accepted for interface compatibility and is unused.
        Passages are joined with a blank line so adjacent documents do not
        merge into a single run-on word at the boundary.
        """
        passages = [
            match["metadata"]["text"]
            for match in response["matches"]
            if match["score"] > min_score
        ]
        return self._PASSAGE_SEPARATOR.join(passages)


class Answering_Agent:
    """Produce the final answer to a user question from retrieved context."""

    def __init__(self, openai_client) -> None:
        # OpenAI-style client; only ``chat.completions.create`` is used.
        self.client = openai_client

    @staticmethod
    def _format_history(conv_history) -> str:
        """Render conversation history as readable text.

        Dicts carrying a ``content`` key become ``"role: content"`` lines.
        Strings are used verbatim. Any other item falls back to ``str()``.
        """
        if not conv_history:
            return ""
        if isinstance(conv_history, str):
            return conv_history
        lines = []
        for turn in conv_history:
            if isinstance(turn, dict) and "content" in turn:
                lines.append(f"{turn.get('role', 'user')}: {turn['content']}")
            else:
                lines.append(str(turn))
        return "\n".join(lines)

    def generate_response(self, query, docs, conv_history, k=5, model=None):
        """Answer *query* using *docs* as context and *conv_history* for continuity.

        Args:
            query: The (possibly instruction-prefixed) user question.
            docs: Retrieved context text; may be empty.
            conv_history: Recent messages (list of ``{"role", "content"}``
                dicts as passed by the caller, or a pre-rendered string).
            k: Unused; kept for interface compatibility.
            model: Chat model name. When None it is read from
                ``st.session_state["openai_model"]``.

        Returns:
            The assistant message content from the completion.
        """
        if model is None:
            model = st.session_state["openai_model"]

        # Keep history, instruction and question in separate paragraphs; the
        # previous version glued the history repr directly onto the instruction.
        sections = []
        history = self._format_history(conv_history)
        if history:
            sections.append(f"Conversation so far:\n{history}")
        sections.append(
            f"Please reference the following context to answer the question. Context: {docs}"
        )
        sections.append(f"Question: {query}")
        context_prompt = "\n\n".join(sections)

        response = self.client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": context_prompt}],
        )
        return response.choices[0].message.content

class Relevant_Documents_Agent:
    """Ask an LLM whether a conversation is on-topic (ML) or small talk."""

    # Canonical spelling for each label the prompt asks the model to use.
    _LABELS = {"yes": "Yes", "no": "No", "somewhat": "Somewhat"}
    # Punctuation/quotes stripped from the first reply word before matching.
    _STRIP_CHARS = ".,!?;:'\"()[]"

    def __init__(self, openai_client) -> None:
        # OpenAI-style client; only ``chat.completions.create`` is used.
        self.client = openai_client

    def get_relevance(self, conversation, prompt, model=None) -> str:
        """Classify how relevant the conversation is.

        Args:
            conversation: Recent messages, interpolated into the prompt as-is.
            prompt: The latest user message.
            model: Chat model name. When None it is read from
                ``st.session_state["openai_model"]``.

        Returns:
            "Yes", "No" or "Somewhat" when the reply starts with one of those
            words (case/punctuation-insensitive, e.g. "No." -> "No"). Otherwise
            the stripped raw reply; "" if the model returned no content.
        """
        if model is None:
            model = st.session_state["openai_model"]

        context_prompt = (
            "Is the following conversation either related to machine learning "
            "or does it consist of pleasantries? Answer with exactly one word: "
            "'Yes', 'No', or 'Somewhat'.\n"
            f"Conversation: {conversation}\n"
            f"Latest message: {prompt}"
        )
        response = self.client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": context_prompt}],
        )
        raw = (response.choices[0].message.content or "").strip()
        words = raw.split()
        first = words[0].strip(self._STRIP_CHARS).lower() if words else ""
        # Normalise the leading label so callers can compare against "No" exactly.
        return self._LABELS.get(first, raw)

class Head_Agent:
    """Top-level controller that routes each chat turn through the sub-agents.

    Per user message: Obnoxious_Agent screens the tone. Otherwise
    Query_Agent retrieves context from the Pinecone index. When nothing is
    retrieved, Relevant_Documents_Agent decides whether to refuse. Finally
    Answering_Agent writes the reply.
    """

    def __init__(self, openai_key, pinecone_key, pinecone_index_name) -> None:
        self.client = OpenAI(api_key=openai_key)
        self.pinecone_key = pinecone_key
        self.pinecone_index_name = pinecone_index_name
        # Sub-agents; populated by ``setup_sub_agents``.
        self.Obnoxious_Agent = None
        self.Query_Agent = None
        self.Relevant_Documents_Agent = None
        self.Answering_Agent = None
        self.setup_sub_agents()
        self.conv_history = []
        # In-memory trace of prompts, retrieved docs and responses.
        self.logs = []

    def setup_sub_agents(self):
        """Instantiate every sub-agent, sharing one OpenAI client and one index handle."""
        self.Obnoxious_Agent = Obnoxious_Agent(self.client)

        # ``Pinecone`` here resolves to ``pinecone.Pinecone``: that import comes
        # after (and shadows) the langchain ``Pinecone`` import at file top.
        vectorstore = Pinecone(api_key=self.pinecone_key)
        vs_index = vectorstore.Index(self.pinecone_index_name)
        self.Query_Agent = Query_Agent(vs_index, self.client, vs_index)

        self.Relevant_Documents_Agent = Relevant_Documents_Agent(self.client)
        self.Answering_Agent = Answering_Agent(self.client)

    @staticmethod
    def _is_negative(relevance) -> bool:
        """Return True when the relevance verdict starts with the word "No".

        Tolerates case, surrounding whitespace and trailing punctuation
        ("No.", "no, ..."), which the previous exact ``"No" ==`` check missed.
        """
        words = (relevance or "").strip().split()
        return bool(words) and words[0].strip(".,!?;:'\"").lower() == "no"

    @staticmethod
    def _post_message(role, content):
        """Render *content* as a chat bubble for *role* and append it to the session history."""
        with st.chat_message(role):
            st.markdown(content)
        st.session_state.messages.append({"role": role, "content": content})

    def _respond(self, prompt):
        """Compute the assistant reply for *prompt* by routing through the sub-agents."""
        if self.Obnoxious_Agent.check_query(prompt):
            return "I'm sorry, but let's keep our conversation civil."

        self.Query_Agent.set_prompt(prompt)
        docs = self.Query_Agent.query_vector_store(prompt)
        self.logs.append(f"docs: {docs}")
        # Last five messages; these already include the current user prompt.
        history = st.session_state.messages[-5:]

        if not docs:
            relevance = self.Relevant_Documents_Agent.get_relevance(history, prompt)
            print(relevance)
            if self._is_negative(relevance):
                return (
                    f"Sorry, no relevant docs found for '{prompt}'."
                    f"\nPlease ask a question about ML"
                )

        if not Config.chatty:
            prompt = f"Answering in two sentences or less, {prompt}"
        return self.Answering_Agent.generate_response(prompt, docs, history)

    def main_loop(self):
        """Render the chat history, then handle at most one new user message."""
        self.logs.append("Session Start")
        if "openai_model" not in st.session_state:
            st.session_state["openai_model"] = "gpt-3.5-turbo"
        if "messages" not in st.session_state:
            st.session_state.messages = []

        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        prompt = st.chat_input("Ask me about ML!")
        if not prompt:
            return

        # Logged once; the original appended this same entry twice.
        self.logs.append(f"Prompt: {prompt}")
        self._post_message("user", prompt)
        response = self._respond(prompt)
        self._post_message("assistant", response)
        self.logs.append(f"response: {response}")
        print(self.logs)