Udyan committed on
Commit
ea46755
·
verified ·
1 Parent(s): 122aed7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -19
app.py CHANGED
@@ -3,36 +3,43 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
 
5
  model_name = "facebook/blenderbot-400M-distill"
 
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
8
 
9
  def chat_function(message, history):
10
- # Keep only the last 2-3 exchanges to stay under the 128 token limit
11
- # Blenderbot crashes if the input context is too long.
12
- recent_history = history[-2:] if len(history) > 2 else history
13
-
14
  history_text = ""
15
- for user, bot in recent_history:
16
- history_text += f"{user} {bot} "
17
-
18
- # Format the final input string
19
- input_text = f"{history_text}{message}"
20
-
21
- # Use truncation and specifically set max_length for the encoder
22
- inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=128)
23
-
24
- # Generate
 
 
 
 
 
25
  with torch.no_grad():
26
- outputs = model.generate(**inputs, max_new_tokens=60)
27
-
28
- response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
 
 
 
 
29
  return response
30
 
 
31
  demo = gr.ChatInterface(
32
  fn=chat_function,
33
  title="BlenderBot Chat",
34
  description="Ask me anything!"
35
  )
36
 
37
- if __name__ == "__main__":
38
- demo.launch()
 
3
import torch

# Checkpoint name shared by the tokenizer and the model so they stay in sync.
model_name = "facebook/blenderbot-400M-distill"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
9
 
10
def chat_function(message, history):
    """Generate a BlenderBot reply to *message*, conditioned on recent history.

    Parameters
    ----------
    message : str
        The user's newest utterance.
    history : list
        Gradio ChatInterface history as (user, bot) pairs.
        # NOTE(review): assumes the tuple-pair history format (Gradio default,
        # not type="messages") — confirm against the installed Gradio version.

    Returns
    -------
    str
        The model's decoded reply, stripped of surrounding whitespace.
    """
    # Keep only the last 2 exchanges: BlenderBot's encoder context is capped
    # at 128 tokens, and long context degrades or truncates the input.
    parts = []
    for pair in history[-2:]:
        # Skip incomplete exchanges (e.g. a pending bot turn).
        if pair[0] and pair[1]:
            parts.append(f"{pair[0]} {pair[1]}")
    history_text = " ".join(parts) + " " if parts else ""

    input_text = history_text + message

    # Truncate defensively so the encoder never sees more than 128 tokens.
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
    )

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=60)

    # Fix: restore the .strip() this commit dropped — BlenderBot's decoded
    # output carries a leading space, so unstripped replies look malformed.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    return response
37
 
38
+
39
demo = gr.ChatInterface(
    fn=chat_function,
    title="BlenderBot Chat",
    description="Ask me anything!",
)

# Fix: restore the __main__ guard this commit removed, so importing the
# module (e.g. for tests) does not start the web server as a side effect.
# Hugging Face Spaces still works: the runtime detects the `demo` object.
if __name__ == "__main__":
    demo.launch()