ecarr-bend commited on
Commit
482c177
·
1 Parent(s): e762673

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -0
app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import pinecone
4
+ import openai
5
+
6
+ from langchain.embeddings.openai import OpenAIEmbeddings
7
+ from langchain.chat_models import ChatOpenAI
8
+ from langchain.vectorstores import Pinecone
9
+
10
+ from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
11
+ from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
12
+ from langchain.schema.messages import SystemMessage
13
+ from langchain.prompts import MessagesPlaceholder
14
+ from langchain.agents import AgentExecutor
15
+ from langchain.agents.agent_toolkits import create_retriever_tool
16
+
17
+ from langchain.callbacks.base import BaseCallbackHandler
18
+
19
+ from queue import Queue
20
+ from threading import Thread
21
+
22
print("CHECK - Pinecone vector db setup")

# set up OpenAI environment vars and embeddings
# NOTE(review): assumes OPENAI_API_KEY is set in the environment;
# OpenAIEmbeddings reads it implicitly as well — confirm before deploy.
openai.api_key = os.environ.get("OPENAI_API_KEY")
embeddings = OpenAIEmbeddings()

# initialize pinecone db — the index is only *loaded* below, so it must
# already exist in the Pinecone project identified by the env vars
index_name = "kellogg-course-assistant"

pinecone.init(
    api_key=os.getenv("PINECONE_API_KEY"),  # find at app.pinecone.io
    environment=os.getenv("PINECONE_ENV"),  # next to api key in console
)

# load existing index and expose it as a LangChain retriever; `retriever`
# is consumed later by create_retriever_tool for the agent
vectorsearch = Pinecone.from_existing_index(index_name, embeddings)
retriever = vectorsearch.as_retriever()

print("CHECK - setting up conversational retrieval agent")
41
+
42
+ # callback handler for streaming
43
class QueueCallback(BaseCallbackHandler):
    """Callback handler that streams LLM tokens into a queue.

    A consumer on another thread (see ``predict``) drains the queue to
    yield tokens incrementally to the UI.
    """

    def __init__(self, q):
        # q: queue.Queue that generated tokens are pushed onto
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Push each newly generated token onto the queue."""
        # Fix: the original annotated **kwargs with the builtin function
        # ``any`` (not a type); the annotation is dropped here.
        self.q.put(token)

    def on_llm_end(self, *args, **kwargs) -> None:
        # Fix: previously returned ``self.q.empty()`` despite the ``-> None``
        # annotation; the handler's return value is ignored by the caller.
        # End-of-stream is signalled by the ``job_done`` sentinel that the
        # worker thread enqueues in ``predict``, not by this hook.
        return None
54
+
55
+ # create retrieval tool
56
# create retrieval tool the agent can call to search the Kellogg site
tool = create_retriever_tool(
    retriever,
    "search_kellogg_site",
    "Searches and returns content from within the Kellogg website."
)
tools = [tool]

# System prompt for the agent.
# Fixes vs. original: the italic example had an unclosed tag
# ("<i>italic<i>"), and the sentence ending "...laid out." was missing a
# trailing space, fusing it with "Do not include..." at string concatenation.
system_message = SystemMessage(
    content=(
        "You are a helpful educational expert providing advice to students of the Northwestern business school Kellogg. "
        "Use both your knowledge and the Kellogg site search tool to generate helpful answers for questions about courses and providing a list of suggested web course articles for more information. "
        "Format your answer with distinct <h3>titles</h3> and <h3>subtitles</h3>, <b>emphasis</b>, <b>bold</b>, <i>italic</i>, <li>lists</li>, and tables *use html code*. For lists, or bullet points, always start them by having a topic in <b>emphasis</b> before going into the description. Ensure to frequently take concepts and break them down into bullet points or lists following the emphasis directions that were just laid out. "
        "Do not include details of your intermediate steps in the final response. "
        "At the end of your response, provide links to relevant web course articles returned by the retriever."
    )
)
72
+
73
print("CHECK - setting up gradio chatbot UI")

# Model picker shown in the chat UI. ``type="index"`` makes Gradio pass the
# selected option's position to the handler (0 = gpt-4, 1 = gpt-3.5-turbo).
_model_choices = ["gpt-4 + rag", "gpt-3.5-turbo + rag"]
model_type = gr.Dropdown(
    choices=_model_choices,
    value="gpt-4 + rag",
    type="index",
    label="LLM Models",
)
82
+
83
+ # RAG agent function
84
def predict(message, model_type):
    """Run the retrieval-augmented agent on *message*, streaming tokens.

    Parameters
    ----------
    message : str
        The user's question.
    model_type : int
        Dropdown index: 1 selects gpt-3.5-turbo, anything else (0 is the
        dropdown default) selects gpt-4.

    Yields
    ------
    tuple[str, str]
        ``(token, content)`` — the newest token and the accumulated answer.
    """
    from queue import Empty  # local import: only the drain loop needs it

    # Queue bridging the agent worker thread (producer) and this generator
    # (consumer); job_done is a unique sentinel marking end of stream.
    q = Queue()
    job_done = object()

    # choose the LLM; the streaming callback pushes tokens onto the queue
    if model_type == 1:
        llm = ChatOpenAI(temperature=0.1, model_name="gpt-3.5-turbo-1106", streaming=True, callbacks=[QueueCallback(q)])
    else:
        llm = ChatOpenAI(temperature=0.1, model_name="gpt-4-1106-preview", streaming=True, callbacks=[QueueCallback(q)])

    # This memory key is needed for both the memory and the prompt
    memory_key = "history"
    memory = AgentTokenBufferMemory(memory_key=memory_key, llm=llm)

    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=system_message,
        extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)]
    )

    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
    agent_executor = AgentExecutor(agent=agent, tools=tools, memory=memory, verbose=False, return_intermediate_steps=True)

    # Run the agent in a worker thread so tokens can be streamed from here.
    def task():
        try:
            agent_executor({"input": message})
        finally:
            # Fix: always post the sentinel, even if the agent raises —
            # otherwise the drain loop below would spin forever.
            q.put(job_done)

    t = Thread(target=task)
    t.start()

    content = ""

    # Drain tokens until the sentinel arrives, yielding incrementally.
    while True:
        try:
            next_token = q.get(True, timeout=1)
        except Empty:
            # Fix: catch only queue.Empty (the original bare ``except: pass``
            # also swallowed KeyboardInterrupt and real errors).
            continue
        if next_token is job_done:
            break
        content += next_token
        yield next_token, content
132
+
133
def ask_llm(message, history, model_type):
    """Gradio-facing generator: stream only the accumulated answer text.

    ``history`` is supplied by gr.ChatInterface and intentionally unused —
    conversation memory lives inside the agent executor instead.
    """
    for _token, answer_so_far in predict(message, model_type):
        yield answer_so_far
136
+
137
+ # set up and run chat interface
138
# Quick-start example questions shown beneath the chat box.
_example_questions = [
    ["Can you tell me about a marketing major? What would I want from my career if I went that way instead of say strategy?"],
    ["I'm interested in strategy. Can you give me a recommendation of courses I should consider over the next year?"],
    ["I'm wanting to know more about advertising. Can you recommend some courses on that subject?"],
    ["How many credits do I need to graduate?"],
    ["I loved the Competitive Strategy and industrial structure class. Can you tell me others like that one?"],
]

# Chat UI wiring: streams ask_llm's output, with the model dropdown exposed
# as an additional input in an accordion.
kellogg_agent = gr.ChatInterface(
    fn=ask_llm,
    chatbot=gr.Chatbot(height=500),
    textbox=gr.Textbox(placeholder="Ask me a question", container=False, scale=7),
    title="Kellogg Course AI Assistant",
    description="Please provide your questions about courses offered by Kellogg.",
    additional_inputs=[model_type],
    additional_inputs_accordion_name="AI Assistant Options:",
    examples=_example_questions,
    clear_btn="Clear",
)
156
+
157
def main():
    """Enable request queuing (required for streaming generators) and launch."""
    app = kellogg_agent.queue()
    app.launch()


# start UI only when executed as a script
if __name__ == "__main__":
    main()