DDDDEvvvvv commited on
Commit
dc0f20f
·
verified ·
1 Parent(s): b675c93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -8
app.py CHANGED
@@ -5,22 +5,31 @@ model_name = "facebook/blenderbot-400M-distill"
5
  tokenizer = AutoTokenizer.from_pretrained(model_name)
6
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
7
 
8
- history = []
9
 
10
- def respond(message):
11
- global history
12
- history.append(message)
13
- input_text = " ".join(history[-5:]) # keep last 5 messages
 
 
 
 
 
 
14
  inputs = tokenizer(input_text, return_tensors="pt")
15
  outputs = model.generate(**inputs, max_new_tokens=100)
16
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
17
- history.append(response)
18
- return response
 
 
 
19
 
20
  with gr.Blocks() as demo:
21
  chatbot = gr.Chatbot()
22
  msg = gr.Textbox(placeholder="Say something...")
23
- msg.submit(respond, msg, chatbot)
24
 
25
  demo.launch(css="""
26
  body {background-color: #000 !important; color: #fff !important;}
 
5
  tokenizer = AutoTokenizer.from_pretrained(model_name)
6
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
7
 
8
+ history = [] # list of (user_message, bot_response) pairs
9
 
10
+ def respond(message, chat_history):
11
+ # chat_history is a list of tuples: (user, bot)
12
+ if chat_history is None:
13
+ chat_history = []
14
+ # add user message to history
15
+ chat_history.append((message, None))
16
+
17
+ # flatten last few messages for context, e.g. last 5 user+bot turns (10 messages)
18
+ # but keep it simple here and just join user messages
19
+ input_text = " ".join([msg for pair in chat_history[-5:] for msg in pair if msg])
20
  inputs = tokenizer(input_text, return_tensors="pt")
21
  outputs = model.generate(**inputs, max_new_tokens=100)
22
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
23
+
24
+ # update last user message's bot response
25
+ chat_history[-1] = (message, response)
26
+
27
+ return chat_history, chat_history
28
 
29
  with gr.Blocks() as demo:
30
  chatbot = gr.Chatbot()
31
  msg = gr.Textbox(placeholder="Say something...")
32
+ msg.submit(respond, inputs=[msg, chatbot], outputs=[chatbot, chatbot])
33
 
34
  demo.launch(css="""
35
  body {background-color: #000 !important; color: #fff !important;}