JLW commited on
Commit
768a16a
·
1 Parent(s): 96559bd

Make tools selectable

Browse files
Files changed (2) hide show
  1. app.py +39 -25
  2. videos/tempfile.mp4 +2 -2
app.py CHANGED
@@ -18,12 +18,17 @@ from langchain.llms import OpenAI
18
  news_api_key = os.environ["NEWS_API_KEY"]
19
  tmdb_bearer_token = os.environ["TMDB_BEARER_TOKEN"]
20
 
 
 
 
 
21
  # UNCOMMENT TO USE WHISPER
22
  # warnings.filterwarnings("ignore")
23
  # WHISPER_MODEL = whisper.load_model("tiny")
24
  # print("WHISPER_MODEL", WHISPER_MODEL)
25
 
26
 
 
27
  # def transcribe(aud_inp):
28
  # if aud_inp is None:
29
  # return ""
@@ -40,36 +45,30 @@ tmdb_bearer_token = os.environ["TMDB_BEARER_TOKEN"]
40
  # return result_text
41
 
42
 
43
- def load_chain():
44
- """Logic for loading the chain you want to use should go here."""
45
- llm = OpenAI(temperature=0)
46
-
47
- tool_names = ['serpapi', 'pal-math', 'pal-colored-objects']
48
- # tool_names = ['serpapi', 'pal-math', 'pal-colored-objects', 'news-api', 'tmdb-api', 'open-meteo-api']
49
 
50
  memory = ConversationBufferMemory(memory_key="chat_history")
51
-
52
- tools = load_tools(tool_names, llm=llm)
53
- # tools = load_tools(tool_names, llm=llm, news_api_key=news_api_key, tmdb_bearer_token=tmdb_bearer_token)
54
-
55
  chain = initialize_agent(tools, llm, agent="conversational-react-description", verbose=True, memory=memory)
56
  return chain
57
 
58
 
59
- def set_openai_api_key(api_key, agent):
60
  """Set the api key and return chain.
61
-
62
  If no api_key, then None is returned.
63
  """
64
  if api_key:
65
  os.environ["OPENAI_API_KEY"] = api_key
66
- chain = load_chain()
 
67
  os.environ["OPENAI_API_KEY"] = ""
68
- return chain
69
 
70
 
71
  def chat(
72
- inp: str, history: Optional[Tuple[str, str]], chain: Optional[ConversationChain]
73
  ):
74
  """Execute the chat functionality."""
75
  print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
@@ -108,15 +107,27 @@ def do_html_video_speak(words_to_speak):
108
  return html_video, "videos/tempfile.mp4"
109
 
110
 
 
 
 
 
 
 
 
111
  block = gr.Blocks(css=".gradio-container {background-color: lightgray}")
112
 
113
  with block:
 
 
 
 
 
114
  with gr.Row():
115
  with gr.Column():
116
  gr.Markdown("<h4><center>Conversational Agent using GPT-3.5 & LangChain</center></h4>")
117
 
118
  openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key (sk-...)",
119
- show_label=False, lines=1, type='password')
120
 
121
  with gr.Row():
122
  with gr.Column(scale=0.25, min_width=240):
@@ -126,6 +137,13 @@ with block:
126
  htm_video = f'<video width="256" height="256" autoplay muted loop><source src={tmp_file_url} type="video/mp4" poster="Masahiro.png"></video>'
127
  video_html = gr.HTML(htm_video)
128
 
 
 
 
 
 
 
 
129
  with gr.Column(scale=0.75):
130
  chatbot = gr.Chatbot()
131
 
@@ -160,15 +178,11 @@ with block:
160
 
161
  gr.HTML("<center>Powered by <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a></center>")
162
 
163
- state = gr.State()
164
- chain_state = gr.State()
165
-
166
- message.submit(chat, inputs=[message, state, chain_state], outputs=[chatbot, state, video_html, my_file, message])
167
- submit.click(chat, inputs=[message, state, chain_state], outputs=[chatbot, state, video_html, my_file, message])
168
 
169
  openai_api_key_textbox.change(set_openai_api_key,
170
- inputs=[openai_api_key_textbox, chain_state],
171
- outputs=[chain_state])
172
-
173
- block.launch(debug = True)
174
 
 
 
18
  news_api_key = os.environ["NEWS_API_KEY"]
19
  tmdb_bearer_token = os.environ["TMDB_BEARER_TOKEN"]
20
 
21
+ TOOLS_LIST = ['serpapi', 'pal-math', 'pal-colored-objects', 'news-api', 'tmdb-api', 'open-meteo-api']
22
+ TOOLS_DEFAULT_LIST = ['serpapi', 'pal-math', 'pal-colored-objects']
23
+
24
+
25
  # UNCOMMENT TO USE WHISPER
26
  # warnings.filterwarnings("ignore")
27
  # WHISPER_MODEL = whisper.load_model("tiny")
28
  # print("WHISPER_MODEL", WHISPER_MODEL)
29
 
30
 
31
+ # UNCOMMENT TO USE WHISPER
32
  # def transcribe(aud_inp):
33
  # if aud_inp is None:
34
  # return ""
 
45
  # return result_text
46
 
47
 
48
+ def load_chain(tools_list, llm):
49
+ print("tools_list", tools_list)
50
+ tool_names = tools_list
51
+ tools = load_tools(tool_names, llm=llm, news_api_key=news_api_key, tmdb_bearer_token=tmdb_bearer_token)
 
 
52
 
53
  memory = ConversationBufferMemory(memory_key="chat_history")
 
 
 
 
54
  chain = initialize_agent(tools, llm, agent="conversational-react-description", verbose=True, memory=memory)
55
  return chain
56
 
57
 
58
+ def set_openai_api_key(api_key):
59
  """Set the api key and return chain.
 
60
  If no api_key, then None is returned.
61
  """
62
  if api_key:
63
  os.environ["OPENAI_API_KEY"] = api_key
64
+ llm = OpenAI(temperature=0)
65
+ chain = load_chain(TOOLS_DEFAULT_LIST, llm)
66
  os.environ["OPENAI_API_KEY"] = ""
67
+ return chain, llm
68
 
69
 
70
  def chat(
71
+ inp: str, history: Optional[Tuple[str, str]], chain: Optional[ConversationChain]
72
  ):
73
  """Execute the chat functionality."""
74
  print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
 
107
  return html_video, "videos/tempfile.mp4"
108
 
109
 
110
+ def update_selected_tools(widget, state, llm):
111
+ if widget:
112
+ state = widget
113
+ chain = load_chain(state, llm)
114
+ return state, llm, chain
115
+
116
+
117
  block = gr.Blocks(css=".gradio-container {background-color: lightgray}")
118
 
119
  with block:
120
+ llm_state = gr.State()
121
+ history_state = gr.State()
122
+ chain_state = gr.State()
123
+ tools_list_state = gr.State(TOOLS_DEFAULT_LIST)
124
+
125
  with gr.Row():
126
  with gr.Column():
127
  gr.Markdown("<h4><center>Conversational Agent using GPT-3.5 & LangChain</center></h4>")
128
 
129
  openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key (sk-...)",
130
+ show_label=False, lines=1, type='password')
131
 
132
  with gr.Row():
133
  with gr.Column(scale=0.25, min_width=240):
 
137
  htm_video = f'<video width="256" height="256" autoplay muted loop><source src={tmp_file_url} type="video/mp4" poster="Masahiro.png"></video>'
138
  video_html = gr.HTML(htm_video)
139
 
140
+ tools_cb_group = gr.CheckboxGroup(label="Tools:", choices=TOOLS_LIST,
141
+ value=TOOLS_DEFAULT_LIST)
142
+
143
+ tools_cb_group.change(update_selected_tools,
144
+ inputs=[tools_cb_group, tools_list_state, llm_state],
145
+ outputs=[tools_list_state, llm_state, chain_state])
146
+
147
  with gr.Column(scale=0.75):
148
  chatbot = gr.Chatbot()
149
 
 
178
 
179
  gr.HTML("<center>Powered by <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a></center>")
180
 
181
+ message.submit(chat, inputs=[message, history_state, chain_state], outputs=[chatbot, history_state, video_html, my_file, message])
182
+ submit.click(chat, inputs=[message, history_state, chain_state], outputs=[chatbot, history_state, video_html, my_file, message])
 
 
 
183
 
184
  openai_api_key_textbox.change(set_openai_api_key,
185
+ inputs=[openai_api_key_textbox],
186
+ outputs=[chain_state, llm_state])
 
 
187
 
188
+ block.launch(debug=True)
videos/tempfile.mp4 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7ad6ea94ca0de42304c461a30340e259f9943ef79c9aaa68d8eef2087ee398a6
3
- size 135190
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:355cfbe21252ee7bd7b3cc6ea13e68abc209330bd139abb0d24e301d42e74b57
3
+ size 75