DDDDEvvvvv committed on
Commit
706b7d9
·
verified ·
1 Parent(s): dc0f20f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -21
app.py CHANGED
@@ -1,45 +1,53 @@
1
  import gradio as gr
2
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
 
4
- model_name = "facebook/blenderbot-400M-distill"
5
  tokenizer = AutoTokenizer.from_pretrained(model_name)
6
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
7
 
8
- history = [] # list of (user_message, bot_response) pairs
9
 
10
  def respond(message, chat_history):
11
- # chat_history is a list of tuples: (user, bot)
12
  if chat_history is None:
13
  chat_history = []
14
- # add user message to history
 
15
  chat_history.append((message, None))
16
 
17
- # flatten last few messages for context, e.g. last 5 user+bot turns (10 messages)
18
- # but keep it simple here and just join user messages
19
- input_text = " ".join([msg for pair in chat_history[-5:] for msg in pair if msg])
20
- inputs = tokenizer(input_text, return_tensors="pt")
21
- outputs = model.generate(**inputs, max_new_tokens=100)
 
 
 
 
 
22
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
23
 
24
- # update last user message's bot response
25
  chat_history[-1] = (message, response)
26
 
27
  return chat_history, chat_history
28
 
29
  with gr.Blocks() as demo:
30
- chatbot = gr.Chatbot()
31
  msg = gr.Textbox(placeholder="Say something...")
32
  msg.submit(respond, inputs=[msg, chatbot], outputs=[chatbot, chatbot])
33
 
34
  demo.launch(css="""
35
- body {background-color: #000 !important; color: #fff !important;}
36
- #chat-box {
37
- height: 500px;
38
- overflow-y: scroll;
39
- border: 1px solid #fff;
40
- padding: 10px;
41
- border-radius: 8px;
42
- background-color: #111;
 
 
 
43
  }
44
  .bubble {
45
  padding: 10px;
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
+ model_name = "microsoft/DialoGPT-small"
5
  tokenizer = AutoTokenizer.from_pretrained(model_name)
6
+ model = AutoModelForCausalLM.from_pretrained(model_name)
7
 
8
+ history = []
9
 
10
  def respond(message, chat_history):
 
11
  if chat_history is None:
12
  chat_history = []
13
+
14
+
15
  chat_history.append((message, None))
16
 
17
+
18
+ input_ids = None
19
+ for user_msg, bot_msg in chat_history[-5:]:
20
+ if input_ids is None:
21
+ input_ids = tokenizer.encode(user_msg + tokenizer.eos_token, return_tensors="pt")
22
+ else:
23
+ input_ids = tokenizer.encode(user_msg + tokenizer.eos_token, return_tensors="pt", add_special_tokens=False).to(input_ids.device)
24
+
25
+
26
+ outputs = model.generate(input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)
27
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
28
 
29
+
30
  chat_history[-1] = (message, response)
31
 
32
  return chat_history, chat_history
33
 
34
  with gr.Blocks() as demo:
35
+ chatbot = gr.Chatbot(elem_id="chatbox")
36
  msg = gr.Textbox(placeholder="Say something...")
37
  msg.submit(respond, inputs=[msg, chatbot], outputs=[chatbot, chatbot])
38
 
39
  demo.launch(css="""
40
+ body, html, #chatbox, .gradio-container {
41
+ height: 100% !important;
42
+ margin: 0; padding: 0;
43
+ background-color: #000 !important;
44
+ color: #fff !important;
45
+ }
46
+ #chatbox {
47
+ border-radius: 8px;
48
+ border: 1px solid #fff;
49
+ background-color: #111 !important;
50
+ overflow-y: auto;
51
  }
52
  .bubble {
53
  padding: 10px;