Udyan committed on
Commit
469baf9
·
verified ·
1 Parent(s): bdbfde3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -20
app.py CHANGED
@@ -1,29 +1,32 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
 
4
- # Load the model using the pipeline (more stable for Gradio)
5
- # Blenderbot-400M can sometimes hit memory limits; pipeline is efficient.
6
- chatbot_pipeline = pipeline("text2text-generation", model="facebook/blenderbot-400M-distill")
 
7
 
8
  def chat_function(message, history):
9
- # Gradio history is a list of [user, bot] pairs.
10
- # We turn it into one string for Blenderbot.
11
- conversation = ""
12
- for user_msg, bot_msg in history:
13
- conversation += f"{user_msg} {bot_msg} "
14
-
15
- conversation += message
16
-
17
- # Generate the response
18
- # truncation=True prevents "input too long" errors
19
- result = chatbot_pipeline(conversation, max_new_tokens=60, truncation=True)
20
-
21
- return result[0]['generated_text']
22
-
23
- # Define the interface
 
 
24
  demo = gr.ChatInterface(
25
  fn=chat_function,
26
- title="Blenderbot Chat",
27
  description="Ask me anything!"
28
  )
29
 
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
 
4
+ model_name = "facebook/blenderbot-400M-distill"
5
+
6
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
7
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
8
 
9
  def chat_function(message, history):
10
+
11
+ history_text = ""
12
+
13
+ for user, bot in history:
14
+ history_text += user + " " + bot + " "
15
+
16
+ history_text += message
17
+
18
+ inputs = tokenizer(history_text, return_tensors="pt", truncation=True)
19
+
20
+ outputs = model.generate(**inputs, max_new_tokens=60)
21
+
22
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
23
+
24
+ return response
25
+
26
+
27
  demo = gr.ChatInterface(
28
  fn=chat_function,
29
+ title="BlenderBot Chat",
30
  description="Ask me anything!"
31
  )
32