ecarr-bend committed on
Commit
2c826a6
·
1 Parent(s): dd5cfa4

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -0
app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
from queue import Empty, Queue
from threading import Thread
from typing import Any

import gradio as gr
import openai
import pinecone
from langchain.agents import AgentExecutor
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import MessagesPlaceholder
from langchain.schema.messages import SystemMessage
from langchain.vectorstores import Pinecone
22
# --- Pinecone vector database setup --------------------------------------
print("CHECK - Pinecone vector db setup")

# Configure OpenAI credentials and the embedding model used for retrieval.
openai.api_key = os.environ.get("OPENAI_API_KEY")
embeddings = OpenAIEmbeddings()

# Connect to the existing Pinecone index holding the MarkStrat documents.
index_name = "kellogg-markstrat"

pinecone.init(
    api_key=os.getenv("PINECONE_API_KEY"),  # find at app.pinecone.io
    environment=os.getenv("PINECONE_ENV"),  # next to api key in console
)

# Wrap the existing index in a LangChain vector store and expose it as a
# retriever for the agent's search tool.
vectorsearch = Pinecone.from_existing_index(index_name, embeddings)
retriever = vectorsearch.as_retriever()
40
+ print("CHECK - setting up conversational retrieval agent")
41
+
42
# callback handler for streaming
class QueueCallback(BaseCallbackHandler):
    """Callback handler for streaming LLM responses to a queue.

    Each freshly generated token is pushed onto the shared queue so a
    consumer generator (see `predict`) can yield tokens to the UI as they
    arrive.
    """

    def __init__(self, q: Queue):
        # Queue shared with the consumer that drains streamed tokens.
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Push each newly streamed token onto the queue."""
        # Fix: the original annotated **kwargs with the builtin function
        # `any`; the correct type is `typing.Any`.
        self.q.put(token)

    def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
        # NOTE(review): the return value of this callback is ignored by the
        # caller; returning `self.q.empty()` is kept exactly as the original
        # wrote it so behavior is unchanged.
        return self.q.empty()
55
# create retrieval tool: exposes the Pinecone retriever to the agent as a
# single named search tool.
tool = create_retriever_tool(
    retriever,
    "search_markstrat",
    "Searches and returns information about the MarkStrat simulation program.",
)
tools = [tool]
62
+
63
# System prompt fixing the assistant's persona and answer guardrails.
system_message = SystemMessage(
    content=(
        "You are an AI bot marketing professor at a business school helping students understand how to play the markstrat simulation. For every question or comment compose a well-structured response to the user's question, using context and conversation information. Your tone of voice will be conversational and engaging, while still being to the point and direct. "
        "Use the MarkStrat search tool to generate helpful answers for the user question. "
        "If its a simple question that asks for a quantitative answer, then provide a much more succinct response. "
        # Fix: the quoted canned refusal was missing its closing apostrophe.
        "Respond to questions that are at least 80% similar to the content within the specified context being sent to you, if the question has nothing to do with the additional information supplied to you, then reply with 'I can only answer questions related to MarkStrat.'"
    )
)
71
+
72
print("CHECK - setting up gradio chatbot UI")

# Model selector shown in the chat UI. `type="index"` means the callback
# receives the selected option's position: 0 -> gpt-4, 1 -> gpt-3.5-turbo.
model_type = gr.Dropdown(
    choices=["gpt-4 + rag", "gpt-3.5-turbo + rag"],
    value="gpt-4 + rag",
    type="index",
    label="LLM Models",
)
81
+
82
# RAG agent function
def predict(message, model_type):
    """Stream an agent answer for `message`.

    Args:
        message: the user's question.
        model_type: dropdown index (1 selects gpt-3.5-turbo-16k; any other
            value selects gpt-4-1106-preview).

    Yields:
        (next_token, content) pairs, where `content` is the accumulated
        answer text so far.
    """
    # Queue carries streamed tokens from the LLM callback to this generator;
    # `job_done` is a unique sentinel marking end-of-stream.
    q = Queue()
    job_done = object()

    # conversational retrieval agent component construction - memory, prompt
    # template, agent, agent executor.
    # Dropdown uses type="index": 1 -> gpt-3.5, otherwise gpt-4.
    if model_type == 1:
        llm = ChatOpenAI(temperature=0.1, model_name="gpt-3.5-turbo-16k", streaming=True, callbacks=[QueueCallback(q)])
    else:
        llm = ChatOpenAI(temperature=0.1, model_name="gpt-4-1106-preview", streaming=True, callbacks=[QueueCallback(q)])

    # This key is needed for both the memory and the prompt.
    memory_key = "history"
    memory = AgentTokenBufferMemory(memory_key=memory_key, llm=llm)

    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=system_message,
        extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
    )

    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
    agent_executor = AgentExecutor(agent=agent, tools=tools, memory=memory, verbose=False, return_intermediate_steps=True)

    # Worker function to run in a thread; tokens arrive on `q` via
    # QueueCallback while it executes. The sentinel is enqueued in `finally`
    # so the consumer loop below always terminates — the original only
    # enqueued it on success, so an agent exception combined with the bare
    # `except: pass` below made the loop spin forever.
    def task():
        try:
            agent_executor({"input": message})
        finally:
            q.put(job_done)

    # Create a thread and start the worker.
    t = Thread(target=task)
    t.start()

    content = ""

    # Get each new token from the queue and yield it for our generator.
    while True:
        try:
            next_token = q.get(True, timeout=1)
        except Empty:
            # No token within the timeout — keep waiting for the stream.
            # (Catching only queue.Empty replaces the original bare except,
            # which silently swallowed every error.)
            continue
        if next_token is job_done:
            break
        content += next_token
        yield next_token, content
128
+
129
def ask_llm(message, history, model_type):
    """Gradio ChatInterface callback: yield the growing answer text."""
    # `history` is supplied by gr.ChatInterface but unused here — the agent
    # inside `predict` keeps its own token-buffer memory.
    for _token, partial in predict(message, model_type):
        yield partial
132
+
133
# set up the chat interface (commented-out options preserved from the
# original for future tuning)
kellogg_agent = gr.ChatInterface(
    fn=ask_llm,
    chatbot=gr.Chatbot(height=500),
    textbox=gr.Textbox(placeholder="Ask me a question", container=False, scale=7),
    title="Kellogg MarkStrat AI Assistant",
    description="Please provide your questions about MarkStrat.",
    # additional_inputs=[model_type],
    # additional_inputs_accordion_name="AI Assistant Options:",
    examples=[["What is MarkStrat?"]],
    # cache_examples=True,
    # retry_btn=None,
    # undo_btn="Delete Previous",
    clear_btn="Clear",
)
148
+
149
# Basic-auth credentials for the Gradio app, read from the environment.
user_cred = os.environ.get("USER_CRED")
pass_cred = os.environ.get("PASS_CRED")


def main():
    """Launch the queued Gradio app behind simple username/password auth."""
    kellogg_agent.queue().launch(auth=(user_cred, pass_cred))


# start UI
if __name__ == "__main__":
    main()