suusuu93 commited on
Commit
ba03f5d
·
verified ·
1 Parent(s): 2805f98

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -21
app.py CHANGED
@@ -6,17 +6,16 @@ import torch
6
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
7
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
8
 
9
- # Biến toàn cục lưu lịch sử hội thoại cho model và hiển thị
10
  chat_history_ids = None
11
- history = [] # lưu [("User: ...", "Bot: ..."), ...]
12
 
13
- def chatbot(msg):
14
- global chat_history_ids, history
15
 
16
  # Encode câu hỏi mới
17
  new_input_ids = tokenizer.encode(msg + tokenizer.eos_token, return_tensors='pt')
18
 
19
- # Nếu chưa có lịch sử, khởi tạo; nếu có thì nối thêm
20
  if chat_history_ids is None:
21
  bot_input_ids = new_input_ids
22
  else:
@@ -24,28 +23,30 @@ def chatbot(msg):
24
 
25
  # Sinh câu trả lời dựa trên toàn bộ lịch sử
26
  chat_history_ids = model.generate(
27
- bot_input_ids,
28
- max_length=1000,
29
  pad_token_id=tokenizer.eos_token_id
30
  )
31
 
32
  # Decode câu trả lời (chỉ lấy phần mới sinh ra)
33
- reply = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
34
-
35
- # Lưu vào lịch sử
36
- history.append(("Bạn: " + msg, "Dr. Mom AI: " + reply))
37
 
38
- # Hiển thị toàn bộ lịch sử hội thoại
39
- chat_log = "\n".join([f"{u}\n{b}" for u, b in history])
40
- return chat_log
41
 
42
  # Giao diện Gradio
43
- demo = gr.Interface(
44
- fn=chatbot,
45
- inputs=gr.Textbox(label="Bạn hỏi gì nè?"),
46
- outputs=gr.Textbox(label="Lịch sử hội thoại", lines=15),
47
- title="Dr. Mom AI",
48
- theme="default"
49
- )
 
 
50
 
51
  demo.launch()
 
6
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
7
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
8
 
9
+ # Biến toàn cục để lưu lịch sử cho hình
10
  chat_history_ids = None
 
11
 
12
+ def chatbot(msg, history):
13
+ global chat_history_ids
14
 
15
  # Encode câu hỏi mới
16
  new_input_ids = tokenizer.encode(msg + tokenizer.eos_token, return_tensors='pt')
17
 
18
+ # Nếu chưa có lịch sử thì khởi tạo; nếu có thì nối thêm
19
  if chat_history_ids is None:
20
  bot_input_ids = new_input_ids
21
  else:
 
23
 
24
  # Sinh câu trả lời dựa trên toàn bộ lịch sử
25
  chat_history_ids = model.generate(
26
+ bot_input_ids,
27
+ max_length=1000,
28
  pad_token_id=tokenizer.eos_token_id
29
  )
30
 
31
  # Decode câu trả lời (chỉ lấy phần mới sinh ra)
32
+ reply = tokenizer.decode(
33
+ chat_history_ids[:, bot_input_ids.shape[-1]:][0],
34
+ skip_special_tokens=True
35
+ )
36
 
37
+ # Lưu vào history (Gradio dùng list of tuples: (user, bot))
38
+ history.append((msg, reply))
39
+ return history, history
40
 
41
  # Giao diện Gradio
42
+ with gr.Blocks(title="Dr. Mom AI") as demo:
43
+ gr.Markdown("## 🤖 Dr. Mom AI Chatbot")
44
+
45
+ chatbot_ui = gr.Chatbot(label="Lịch sử hội thoại")
46
+ msg = gr.Textbox(label="Bạn hỏi gì nè?", placeholder="Nhập tin nhắn rồi nhấn Enter...")
47
+ clear = gr.Button("🔄 Reset hội thoại")
48
+
49
+ msg.submit(chatbot, [msg, chatbot_ui], [chatbot_ui, chatbot_ui])
50
+ clear.click(lambda: None, None, chatbot_ui, queue=False)
51
 
52
  demo.launch()