suusuu93 commited on
Commit
2805f98
·
verified ·
1 Parent(s): 9d2f0df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -6,11 +6,12 @@ import torch
6
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
7
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
8
 
9
- # Biến toàn cục lưu lịch sử hội thoại
10
  chat_history_ids = None
 
11
 
12
  def chatbot(msg):
13
- global chat_history_ids
14
 
15
  # Encode câu hỏi mới
16
  new_input_ids = tokenizer.encode(msg + tokenizer.eos_token, return_tensors='pt')
@@ -24,19 +25,25 @@ def chatbot(msg):
24
  # Sinh câu trả lời dựa trên toàn bộ lịch sử
25
  chat_history_ids = model.generate(
26
  bot_input_ids,
27
- max_length=1000, # tăng để giữ nhiều context
28
  pad_token_id=tokenizer.eos_token_id
29
  )
30
 
31
  # Decode câu trả lời (chỉ lấy phần mới sinh ra)
32
  reply = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
33
- return f"Dr. Mom AI: {reply}"
 
 
 
 
 
 
34
 
35
  # Giao diện Gradio
36
  demo = gr.Interface(
37
  fn=chatbot,
38
  inputs=gr.Textbox(label="Bạn hỏi gì nè?"),
39
- outputs=gr.Textbox(label="Dr. Mom AI trả lời"),
40
  title="Dr. Mom AI",
41
  theme="default"
42
  )
 
6
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
7
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
8
 
9
+ # Biến toàn cục lưu lịch sử hội thoại cho model và hiển thị
10
  chat_history_ids = None
11
+ history = [] # lưu [("User: ...", "Bot: ..."), ...]
12
 
13
  def chatbot(msg):
14
+ global chat_history_ids, history
15
 
16
  # Encode câu hỏi mới
17
  new_input_ids = tokenizer.encode(msg + tokenizer.eos_token, return_tensors='pt')
 
25
  # Sinh câu trả lời dựa trên toàn bộ lịch sử
26
  chat_history_ids = model.generate(
27
  bot_input_ids,
28
+ max_length=1000,
29
  pad_token_id=tokenizer.eos_token_id
30
  )
31
 
32
  # Decode câu trả lời (chỉ lấy phần mới sinh ra)
33
  reply = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
34
+
35
+ # Lưu vào lịch sử
36
+ history.append(("Bạn: " + msg, "Dr. Mom AI: " + reply))
37
+
38
+ # Hiển thị toàn bộ lịch sử hội thoại
39
+ chat_log = "\n".join([f"{u}\n{b}" for u, b in history])
40
+ return chat_log
41
 
42
  # Giao diện Gradio
43
  demo = gr.Interface(
44
  fn=chatbot,
45
  inputs=gr.Textbox(label="Bạn hỏi gì nè?"),
46
+ outputs=gr.Textbox(label="Lịch sử hội thoại", lines=15),
47
  title="Dr. Mom AI",
48
  theme="default"
49
  )