Gavroche43 commited on
Commit
019a8fd
·
verified ·
1 Parent(s): ae80f44

Upload 4 files

Browse files
Files changed (4) hide show
  1. Agents.py +166 -0
  2. Config.py +7 -0
  3. requirements.txt +4 -0
  4. streamlit_app.py +50 -0
Agents.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from Config import Config
2
+ from openai import OpenAI
3
+ import streamlit as st
4
+ from textblob import TextBlob
5
+ from langchain.vectorstores import Pinecone
6
+ from pinecone import Pinecone
7
+
8
+ class Obnoxious_Agent:
9
+ def __init__(self, client) -> None:
10
+ self.client = client
11
+ self.prompt = ""
12
+
13
+ def set_prompt(self, prompt):
14
+ self.prompt = f"Would you describe the tone of this prompt as 'rude', 'polite', or 'neutral'?: '{prompt}'"
15
+
16
+ def extract_action(self, response) -> bool:
17
+ out = 'rude' in response.choices[0].message.content.lower().split()
18
+ return out
19
+
20
+ def check_query(self, query):
21
+ self.set_prompt(query)
22
+ prompt = self.prompt
23
+ message = {"role": "user", "content": prompt}
24
+ response = self.client.chat.completions.create(
25
+ model="gpt-3.5-turbo",
26
+ messages=[message]
27
+ )
28
+ return self.extract_action(response)
29
+
30
+ class Query_Agent:
31
+ def __init__(self, pinecone_index, openai_client, embeddings) -> None:
32
+ self.pinecone_index = pinecone_index
33
+ self.openai_client = openai_client
34
+ self.embeddings = embeddings
35
+ self.prompt = ""
36
+
37
+ def get_embedding(self, text, model="text-embedding-ada-002"):
38
+ text = text.replace("\n", " ")
39
+ return self.openai_client.embeddings.create(input=[text], model=model).data[0].embedding
40
+
41
+ def query_vector_store(self, query, k=5):
42
+ query_embedding = self.get_embedding(query)
43
+ response = self.embeddings.query(vector=[query_embedding], top_k=k, namespace='ns1', include_metadata=True)
44
+ docs = self.extract_action(response, query)
45
+ return docs
46
+
47
+ def set_prompt(self, prompt):
48
+ self.prompt = prompt
49
+ return self.prompt
50
+
51
+ def extract_action(self, response, query = None):
52
+ relevant_docs = ""
53
+ for match in response['matches']:
54
+ if match['score'] > 0.75:
55
+ relevant_docs += match['metadata']['text']
56
+ return relevant_docs
57
+
58
+
59
+ class Answering_Agent:
60
+ def __init__(self, openai_client) -> None:
61
+ self.client = openai_client
62
+
63
+ def generate_response(self, query, docs, conv_history, k=5):
64
+ # TODO: Generate a response to the user's query
65
+ context_prompt =\
66
+ f"{conv_history}"\
67
+ f"Please reference the following context to answer the question. Context: {docs}:" \
68
+ f" \n Question: {query}"
69
+
70
+ message = {"role": "user", "content": context_prompt}
71
+ response = self.client.chat.completions.create(
72
+ model=st.session_state["openai_model"],
73
+ messages=[message],
74
+ ).choices[0].message.content
75
+ return response
76
+
77
+ class Relevant_Documents_Agent:
78
+ def __init__(self, openai_client) -> None:
79
+ self.client = openai_client
80
+
81
+ def get_relevance(self, conversation, prompt) -> str:
82
+ context_prompt = \
83
+ f"is the following conversation either related to machine learning or consist of pleasanties? 'Yes', 'No', or 'Somewhat' {conversation} {prompt}:"
84
+
85
+ message = {"role": "user", "content": context_prompt}
86
+ response = self.client.chat.completions.create(
87
+ model=st.session_state["openai_model"],
88
+ messages=[message],
89
+ ).choices[0].message.content
90
+ return response
91
+
92
+ class Head_Agent:
93
+ def __init__(self, openai_key, pinecone_key, pinecone_index_name) -> None:
94
+ self.client = OpenAI(api_key=openai_key)
95
+ self.pinecone_key = pinecone_key
96
+ self.pinecone_index_name = pinecone_index_name
97
+ self.Obnoxious_Agent = None
98
+ self.Query_Agent = None
99
+ self.Answering_Agent = None
100
+ self.setup_sub_agents()
101
+ self.conv_history = []
102
+ self.logs = []
103
+
104
+
105
+ def setup_sub_agents(self):
106
+ # Initialize Obnoxious_Agent
107
+ self.Obnoxious_Agent = Obnoxious_Agent(self.client)
108
+
109
+ # Initialize Query_Agent
110
+ vectorstore = Pinecone(api_key=self.pinecone_key)
111
+ vs_index = vectorstore.Index(self.pinecone_index_name)
112
+ self.Query_Agent = Query_Agent(vs_index, self.client, vs_index)
113
+
114
+ # Relevant Document Agent
115
+ self.Relevant_Documents_Agent = Relevant_Documents_Agent(self.client)
116
+
117
+ #Answering Agent
118
+ self.Answering_Agent = Answering_Agent(self.client)
119
+
120
+
121
+ def main_loop(self):
122
+ self.logs.append("Session Start")
123
+ if "openai_model" not in st.session_state:
124
+ st.session_state["openai_model"] = "gpt-3.5-turbo"
125
+
126
+ if "messages" not in st.session_state:
127
+ st.session_state.messages = []
128
+
129
+ for message in st.session_state.messages:
130
+ with st.chat_message(message["role"]):
131
+ st.markdown(message["content"])
132
+
133
+ if prompt := st.chat_input("Ask me about ML!"):
134
+ self.logs.append(f"Prompt: {prompt}")
135
+ st.session_state.messages.append({"role": "user", "content": prompt})
136
+ with st.chat_message("user"):
137
+ st.markdown(prompt)
138
+ self.logs.append(f"Prompt: {prompt}")
139
+ if self.Obnoxious_Agent.check_query(prompt):
140
+ response = "I'm sorry, but let's keep our conversation civil."
141
+ with st.chat_message("assistant"):
142
+ st.markdown(response)
143
+ st.session_state.messages.append({"role": "assistant", "content": response})
144
+ else:
145
+ self.Query_Agent.set_prompt(prompt)
146
+ docs = self.Query_Agent.query_vector_store(prompt)
147
+ response = None
148
+
149
+ self.logs.append(f"docs: {docs}")
150
+ if len(docs) == 0:
151
+ relevance = self.Relevant_Documents_Agent.get_relevance(st.session_state.messages[-5:], prompt)
152
+ print(relevance)
153
+ if "No" == relevance:
154
+ response = f"Sorry, no relevant docs found for '{prompt}'."\
155
+ f"\nPlease ask a question about ML"
156
+ if not Config.chatty:
157
+ prompt = f"Answering in two sentences or less, {prompt}"
158
+
159
+ if not response:
160
+ response = self.Answering_Agent.generate_response(prompt, docs, st.session_state.messages[-5:])
161
+
162
+ with st.chat_message("assistant"):
163
+ st.markdown(response)
164
+ st.session_state.messages.append({"role": "assistant", "content": response})
165
+ self.logs.append(f"response: {response}")
166
+ print(self.logs)
Config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ class Config:
4
+ openai_key = st.secrets["OPENAI_KEY"]
5
+ pinecone_key = st.secrets["PINECONE_KEY"]
6
+ pinecone_index_name = "ee596-pinecone-index"
7
+ chatty = True
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ langchain
2
+ pinecone
3
+ openai
4
+ streamlit
streamlit_app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multiprocessing import AuthenticationError
2
+
3
+ from langchain import requests
4
+ from pinecone import Pinecone
5
+
6
+ from Config import Config
7
+ import streamlit as st
8
+ from openai import OpenAI, OpenAIError
9
+ from Agents import Head_Agent
10
+
11
+ st.title("Mini Project 2: Streamlit Chatbot")
12
+
13
+ if Config.openai_key:
14
+ openai_key = Config.openai_key
15
+ else:
16
+ openai_key = st.sidebar.text_input("OpenAI API Key", type="password")
17
+ if not openai_key:
18
+ st.info("Please add your OpenAI API key to continue.")
19
+ st.stop()
20
+
21
+ if Config.pinecone_key:
22
+ pinecone_key = Config.pinecone_key
23
+ else:
24
+ pinecone_key = st.sidebar.text_input("Pinecone API Key", type="password")
25
+ if not pinecone_key:
26
+ st.info("Please add your Pinecone API key to continue.")
27
+ st.stop()
28
+
29
+ try:
30
+ client = OpenAI(api_key=openai_key)
31
+ message = {"role": "user", "content": "ping"}
32
+ client.chat.completions.create(model="gpt-3.5-turbo", messages=[message])
33
+ Head_Agent(openai_key, pinecone_key, Config.pinecone_index_name)
34
+ except AuthenticationError:
35
+ st.error("Failed to authenticate with OpenAI. Please check your API key.")
36
+ st.stop()
37
+ except OpenAIError as e:
38
+ st.error(f"An error occurred while trying to communicate with OpenAI: {e}")
39
+ st.stop()
40
+
41
+ try:
42
+ Pinecone(api_key=pinecone_key)
43
+ except requests.exceptions.HTTPError as e:
44
+ st.error(f"Failed to authenticate with Pinecone or communicate properly: {e}")
45
+ st.stop()
46
+
47
+ run = Head_Agent(openai_key, pinecone_key, Config.pinecone_index_name)
48
+ run.main_loop()
49
+
50
+