GBhaveshKumar commited on
Commit
e07de5b
·
verified ·
1 Parent(s): 6c57e45

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -7
app.py CHANGED
@@ -2,7 +2,6 @@ from transformers import GPT2LMHeadModel, AutoTokenizer, pipeline
2
  import gradio as gr
3
  import torch
4
 
5
- # Load model from Hugging Face Hub
6
  model = GPT2LMHeadModel.from_pretrained("GBhaveshKumar/ConvoAI")
7
  tokenizer = AutoTokenizer.from_pretrained("GBhaveshKumar/ConvoAI")
8
 
@@ -13,9 +12,12 @@ generator = pipeline(
13
  device=0 if torch.cuda.is_available() else -1
14
  )
15
 
16
- def chat_with_bot(history, user_input):
17
- context = "".join([f"A: {q}\nB: {a}\n" for q, a in history])
18
- context += f"A: {user_input}\nB:"
 
 
 
19
  output = generator(
20
  context,
21
  max_length=len(tokenizer.encode(context)) + 50,
@@ -25,8 +27,13 @@ def chat_with_bot(history, user_input):
25
  top_p=0.95,
26
  temperature=0.8
27
  )[0]['generated_text']
 
28
  reply = output[len(context):].split("\n")[0].strip()
29
- history.append((user_input, reply))
30
- return history, ""
31
 
32
- gr.ChatInterface(fn=chat_with_bot, title="💬 DailyDialog Chatbot").launch()
 
 
 
 
 
2
  import gradio as gr
3
  import torch
4
 
 
5
  model = GPT2LMHeadModel.from_pretrained("GBhaveshKumar/ConvoAI")
6
  tokenizer = AutoTokenizer.from_pretrained("GBhaveshKumar/ConvoAI")
7
 
 
12
  device=0 if torch.cuda.is_available() else -1
13
  )
14
 
15
+ def respond(message, chat_history):
16
+ context = ""
17
+ for user, bot in chat_history:
18
+ context += f"A: {user}\nB: {bot}\n"
19
+ context += f"A: {message}\nB:"
20
+
21
  output = generator(
22
  context,
23
  max_length=len(tokenizer.encode(context)) + 50,
 
27
  top_p=0.95,
28
  temperature=0.8
29
  )[0]['generated_text']
30
+
31
  reply = output[len(context):].split("\n")[0].strip()
32
+ chat_history.append((message, reply))
33
+ return "", chat_history
34
 
35
+ gr.ChatInterface(
36
+ fn=respond,
37
+ title="💬 DailyDialog Chatbot",
38
+ description="A fine-tuned GPT-2 chatbot trained on the DailyDialog dataset.",
39
+ ).launch()