File size: 5,407 Bytes
2c826a6
 
 
 
 
b4a831f
 
2c826a6
 
 
 
 
 
 
 
 
 
 
 
 
 
a9713d7
 
2c826a6
7b559e5
1322f40
 
 
 
2c826a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6fa095
2c826a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32c790e
ec8a6f6
 
 
 
b4edc2a
 
 
 
 
 
 
32c790e
2c826a6
 
 
 
 
0fdc900
2c826a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77f64bb
09fcba2
05356c3
2c826a6
 
 
 
 
 
 
 
 
 
37bab95
2c826a6
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import gradio as gr
import os
import pinecone
import openai

from langchain_openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI
from langchain.vectorstores import Pinecone

from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.schema.messages import SystemMessage
from langchain.prompts import MessagesPlaceholder
from langchain.agents import AgentExecutor
from langchain.agents.agent_toolkits import create_retriever_tool

from langchain.callbacks.base import BaseCallbackHandler

from queue import Queue
from threading import Thread

# pinecone db index
# Name of the pre-built Pinecone index holding the MarkStrat document
# embeddings; predict() opens it via Pinecone.from_existing_index.
index_name = "kellogg-markstrat"

# Initialise the Pinecone client once, at import time, from environment
# variables.
# NOTE(review): pinecone.init() is the pinecone-client v2 API and is not
# available in v3+ — confirm the pinned client version.
pinecone.init(
    api_key=os.getenv("PINECONE_API_KEY"),  # find at app.pinecone.io
    environment=os.getenv("PINECONE_ENV"),  # next to api key in console
)

# set up OpenAI environment vars and embeddings
# (os.environ.get yields None when OPENAI_API_KEY is unset.)
openai.api_key = os.environ.get("OPENAI_API_KEY")
# Embedding model shared by every predict() call when querying the index.
embeddings = OpenAIEmbeddings()

# Startup progress marker printed to stdout.
print("CHECK - setting up conversational retrieval agent")

# callback handler for streaming
class QueueCallback(BaseCallbackHandler):
    """Callback handler that forwards streamed LLM tokens to a queue.

    Each token the LLM emits is put on ``q`` as soon as it arrives, so a
    consumer running in another thread can stream the response incrementally.
    """

    def __init__(self, q):
        """Store the queue that will receive streamed tokens.

        Args:
            q: A ``queue.Queue`` (anything with a ``put`` method works).
        """
        super().__init__()
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs: object) -> None:
        """Push a newly generated token onto the queue."""
        self.q.put(token)

    def on_llm_end(self, *args: object, **kwargs: object) -> None:
        """Do nothing when a single LLM call finishes.

        An agent run may involve several LLM calls, so the end of one call
        does not mean the stream is complete; the producer that drives the
        run is responsible for signalling completion on the queue.
        """

# System prompt for the MarkStrat agent. The pieces below are implicitly
# concatenated, so every fragment except the last must end with a space to
# keep sentences from running together.
system_message = SystemMessage(
    content=(
        "You are an AI bot marketing professor at a business school helping students understand how to play the markstrat simulation. For every question or comment compose a well-structured response to the user's question, using context and conversation information. Your tone of voice will be conversational and engaging, while still being to the point and direct. "
        "Use the MarkStrat search tool to generate helpful answers for the user question. "
        "If referring to Kellogg or Northwestern use terms like 'we' instead of 'they', for example 'Here at Kellogg, we calculate profit as....' "
        "If it's a simple question that asks for a quantitative answer, then provide a much more succinct response. "
        "Respond to questions that are at least 80% similar to the content within the specified context being sent to you, if the question has nothing to do with the additional information supplied to you, then reply with 'I can only answer questions related to MarkStrat.'"
    )
)

print("CHECK - setting up gradio chatbot UI")

# Model picker rendered in the chat UI's options accordion. With
# type="index" Gradio hands the chat function the position of the selected
# choice (0 or 1) instead of its label text.
model_type = gr.Dropdown(
    label="LLM Models",
    choices=["gpt-4 + rag", "gpt-3.5-turbo + rag"],
    value="gpt-4 + rag",
    type="index",
)

# RAG agent function
def predict(message, model_type):
    """Run the MarkStrat RAG agent on *message* and stream its answer.

    Every call builds a fresh Pinecone-backed retriever tool, chat LLM,
    token-buffer memory and agent executor, then runs the agent in a
    background thread. Tokens produced by the LLM are forwarded through a
    queue by ``QueueCallback`` and yielded as they arrive.

    Args:
        message: The user's question.
        model_type: Index chosen in the "LLM Models" dropdown; ``1`` selects
            gpt-3.5-turbo-16k, any other value selects gpt-4-turbo-preview.

    Yields:
        ``(next_token, content)`` tuples, where ``content`` is all text
        streamed so far, including ``next_token``.

    Raises:
        Exception: whatever the agent executor raised in the worker thread,
            re-raised here after the tokens streamed before the failure.
    """
    q = Queue()
    job_done = object()  # unique sentinel marking the end of the stream
    errors = []  # holds the worker thread's exception, if it raised one

    # load existing index
    vectorsearch = Pinecone.from_existing_index(index_name, embeddings)
    retriever = vectorsearch.as_retriever()

    # create retrieval tool
    tool = create_retriever_tool(
        retriever,
        "search_markstrat",
        "Searches and returns information about the MarkStrat simulation program.",
    )
    tools = [tool]

    # conversational retrieval agent component construction - memory, prompt template, agent, agent executor
    # specifying LLM to use; the callback pushes every streamed token onto q
    model_name = "gpt-3.5-turbo-16k" if model_type == 1 else "gpt-4-turbo-preview"
    llm = ChatOpenAI(
        temperature=0.1,
        model_name=model_name,
        streaming=True,
        callbacks=[QueueCallback(q)],
    )

    # This is needed for both the memory and the prompt
    memory_key = "history"
    memory = AgentTokenBufferMemory(memory_key=memory_key, llm=llm)

    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=system_message,
        extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
    )

    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
    agent_executor = AgentExecutor(
        agent=agent,
        tools=tools,
        memory=memory,
        verbose=False,
        return_intermediate_steps=True,
    )

    def task():
        """Run the agent; always terminate the stream, even on failure."""
        try:
            agent_executor({"input": message})
        except Exception as exc:  # handed to the consumer and re-raised below
            errors.append(exc)
        finally:
            # Without this the consumer loop below would block forever
            # whenever the agent raised.
            q.put(job_done)

    t = Thread(target=task)
    t.start()

    content = ""

    # Get each new token from the queue and yield for our generator.
    # task() guarantees job_done is queued, so a blocking get is safe.
    while True:
        next_token = q.get()
        if next_token is job_done:
            break
        content += next_token
        yield next_token, content

    if errors:
        raise errors[0]

def ask_llm(message, history, model_type):
    """Gradio ChatInterface handler that streams the growing answer.

    ``history`` is supplied by Gradio but not used here: each call starts a
    fresh agent via ``predict``. After every streamed token this yields the
    full response text accumulated so far.
    """
    stream = predict(message, model_type)
    for _token, accumulated in stream:
        yield accumulated

# set up and run chat interface
kellogg_agent = gr.ChatInterface(
    title="Kellogg MarkStrat AI Assistant",
    description="Please provide your questions about MarkStrat.",
    fn=ask_llm,
    chatbot=gr.Chatbot(height=500),
    textbox=gr.Textbox(placeholder="Ask me a question", container=False, scale=7),
    examples=[["How do I play MarkStrat?"]],
    # The model dropdown becomes ask_llm's third argument and is shown
    # inside a collapsible accordion with this label.
    additional_inputs=[model_type],
    additional_inputs_accordion="AI Assistant Options:",
    # cache_examples, retry_btn and undo_btn are left at Gradio's defaults.
    clear_btn="Clear",
)

# Optional login credentials for the web UI (None when the variable is unset).
user_cred = os.environ.get("USER_CRED")
pass_cred = os.environ.get("PASS_CRED")


def main():
    """Launch the chat UI with Gradio's request queue enabled.

    When both ``USER_CRED`` and ``PASS_CRED`` are set, the app is protected
    by Gradio's username/password login using those values. If either is
    missing, it launches without authentication.
    """
    auth = (user_cred, pass_cred) if user_cred and pass_cred else None
    kellogg_agent.queue().launch(auth=auth)


# start UI
if __name__ == "__main__":
    main()