JLW commited on
Commit
a3ea15f
·
1 Parent(s): bf21e84

Provide ability to reset the chat.

Browse files
Files changed (1) hide show
  1. app.py +16 -5
app.py CHANGED
@@ -243,7 +243,7 @@ def load_chain(tools_list, llm):
243
 
244
  chain = initialize_agent(tools, llm, agent="conversational-react-description", verbose=True, memory=memory)
245
  express_chain = LLMChain(llm=llm, prompt=PROMPT_TEMPLATE, verbose=True)
246
- return chain, express_chain
247
 
248
 
249
  def set_openai_api_key(api_key):
@@ -258,7 +258,7 @@ def set_openai_api_key(api_key):
258
  llm = OpenAI(temperature=0, max_tokens=MAX_TOKENS)
259
  print(str(datetime.datetime.now()) + ": After OpenAI, OPENAI_API_KEY length: " + str(
260
  len(os.environ["OPENAI_API_KEY"])))
261
- chain, express_chain = load_chain(TOOLS_DEFAULT_LIST, llm)
262
 
263
  # Pertains to question answering functionality
264
  embeddings = OpenAIEmbeddings()
@@ -267,8 +267,8 @@ def set_openai_api_key(api_key):
267
  print(str(datetime.datetime.now()) + ": After load_chain, OPENAI_API_KEY length: " + str(
268
  len(os.environ["OPENAI_API_KEY"])))
269
  os.environ["OPENAI_API_KEY"] = ""
270
- return chain, express_chain, llm, embeddings, qa_chain
271
- return None, None, None, None, None
272
 
273
 
274
  def run_chain(chain, inp, capture_hidden_text):
@@ -335,6 +335,12 @@ def run_chain(chain, inp, capture_hidden_text):
335
  return output, hidden_text
336
 
337
 
 
 
 
 
 
 
338
  class ChatWrapper:
339
 
340
  def __init__(self):
@@ -558,6 +564,7 @@ with gr.Blocks(css=".gradio-container {background-color: lightgray}") as block:
558
  speak_text_state = gr.State(False)
559
  talking_head_state = gr.State(True)
560
  monologue_state = gr.State(False) # Takes the input and repeats it back to the user, optionally transforming it.
 
561
 
562
  # Pertains to Express-inator functionality
563
  num_words_state = gr.State(NUM_WORDS_DEFAULT)
@@ -667,6 +674,9 @@ with gr.Blocks(css=".gradio-container {background-color: lightgray}") as block:
667
  monologue_cb.change(update_foo, inputs=[monologue_cb, monologue_state],
668
  outputs=[monologue_state])
669
 
 
 
 
670
  with gr.Tab("Whisper STT"):
671
  whisper_lang_radio = gr.Radio(label="Whisper speech-to-text language:", choices=[
672
  WHISPER_DETECT_LANG, "Arabic", "Arabic (Gulf)", "Catalan", "Chinese (Cantonese)", "Chinese (Mandarin)",
@@ -849,6 +859,7 @@ with gr.Blocks(css=".gradio-container {background-color: lightgray}") as block:
849
 
850
  openai_api_key_textbox.change(set_openai_api_key,
851
  inputs=[openai_api_key_textbox],
852
- outputs=[chain_state, express_chain_state, llm_state, embeddings_state, qa_chain_state])
 
853
 
854
  block.launch(debug=True)
 
243
 
244
  chain = initialize_agent(tools, llm, agent="conversational-react-description", verbose=True, memory=memory)
245
  express_chain = LLMChain(llm=llm, prompt=PROMPT_TEMPLATE, verbose=True)
246
+ return chain, express_chain, memory
247
 
248
 
249
  def set_openai_api_key(api_key):
 
258
  llm = OpenAI(temperature=0, max_tokens=MAX_TOKENS)
259
  print(str(datetime.datetime.now()) + ": After OpenAI, OPENAI_API_KEY length: " + str(
260
  len(os.environ["OPENAI_API_KEY"])))
261
+ chain, express_chain, memory = load_chain(TOOLS_DEFAULT_LIST, llm)
262
 
263
  # Pertains to question answering functionality
264
  embeddings = OpenAIEmbeddings()
 
267
  print(str(datetime.datetime.now()) + ": After load_chain, OPENAI_API_KEY length: " + str(
268
  len(os.environ["OPENAI_API_KEY"])))
269
  os.environ["OPENAI_API_KEY"] = ""
270
+ return chain, express_chain, llm, embeddings, qa_chain, memory
271
+ return None, None, None, None, None, None
272
 
273
 
274
  def run_chain(chain, inp, capture_hidden_text):
 
335
  return output, hidden_text
336
 
337
 
338
+ def reset_memory(history, memory):
339
+ memory.clear()
340
+ history = []
341
+ return history, history, memory
342
+
343
+
344
  class ChatWrapper:
345
 
346
  def __init__(self):
 
564
  speak_text_state = gr.State(False)
565
  talking_head_state = gr.State(True)
566
  monologue_state = gr.State(False) # Takes the input and repeats it back to the user, optionally transforming it.
567
+ memory_state = gr.State()
568
 
569
  # Pertains to Express-inator functionality
570
  num_words_state = gr.State(NUM_WORDS_DEFAULT)
 
674
  monologue_cb.change(update_foo, inputs=[monologue_cb, monologue_state],
675
  outputs=[monologue_state])
676
 
677
+ reset_btn = gr.Button(value="Reset chat", variant="secondary").style(full_width=False)
678
+ reset_btn.click(reset_memory, inputs=[history_state, memory_state], outputs=[chatbot, history_state, memory_state])
679
+
680
  with gr.Tab("Whisper STT"):
681
  whisper_lang_radio = gr.Radio(label="Whisper speech-to-text language:", choices=[
682
  WHISPER_DETECT_LANG, "Arabic", "Arabic (Gulf)", "Catalan", "Chinese (Cantonese)", "Chinese (Mandarin)",
 
859
 
860
  openai_api_key_textbox.change(set_openai_api_key,
861
  inputs=[openai_api_key_textbox],
862
+ outputs=[chain_state, express_chain_state, llm_state, embeddings_state,
863
+ qa_chain_state, memory_state])
864
 
865
  block.launch(debug=True)