Udyan commited on
Commit
122aed7
·
verified ·
1 Parent(s): 469baf9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -15
app.py CHANGED
@@ -1,29 +1,33 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
3
 
4
  model_name = "facebook/blenderbot-400M-distill"
5
-
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
8
 
9
  def chat_function(message, history):
10
-
 
 
 
11
  history_text = ""
12
-
13
- for user, bot in history:
14
- history_text += user + " " + bot + " "
15
-
16
- history_text += message
17
-
18
- inputs = tokenizer(history_text, return_tensors="pt", truncation=True)
19
-
20
- outputs = model.generate(**inputs, max_new_tokens=60)
21
-
22
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
23
-
 
 
24
  return response
25
 
26
-
27
  demo = gr.ChatInterface(
28
  fn=chat_function,
29
  title="BlenderBot Chat",
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import torch
4
 
5
  model_name = "facebook/blenderbot-400M-distill"
 
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
8
 
9
  def chat_function(message, history):
10
+ # Keep only the last 2-3 exchanges to stay under the 128 token limit
11
+ # Blenderbot crashes if the input context is too long.
12
+ recent_history = history[-2:] if len(history) > 2 else history
13
+
14
  history_text = ""
15
+ for user, bot in recent_history:
16
+ history_text += f"{user} {bot} "
17
+
18
+ # Format the final input string
19
+ input_text = f"{history_text}{message}"
20
+
21
+ # Use truncation and specifically set max_length for the encoder
22
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=128)
23
+
24
+ # Generate
25
+ with torch.no_grad():
26
+ outputs = model.generate(**inputs, max_new_tokens=60)
27
+
28
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
29
  return response
30
 
 
31
  demo = gr.ChatInterface(
32
  fn=chat_function,
33
  title="BlenderBot Chat",