|
|
import gradio as gr |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import torch |
|
|
|
|
|
|
|
|
# Single source of truth for the Hugging Face checkpoint used by both the
# tokenizer and the model, so they can never drift apart.
_MODEL_ID = "microsoft/DialoGPT-medium"

# Load the tokenizer first, then the causal-LM weights, for the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(_MODEL_ID)

# Token ids of the conversation so far (prompt + generated replies), fed back
# into the model on every turn. None means there is no context yet.
chat_history_ids = None
|
|
|
|
|
def chatbot(msg, history):
    """Generate a DialoGPT reply to *msg* and append the turn to *history*.

    The model's running token context is kept in the module-level
    ``chat_history_ids`` so each reply is conditioned on earlier turns.

    Args:
        msg: The user's new message text.
        history: The transcript shown in the ``gr.Chatbot`` component, a list
            of ``(user, bot)`` pairs. May be ``None`` or empty, e.g. after the
            UI was reset.

    Returns:
        The updated transcript twice, matching the two output components the
        UI wires this function to.
    """
    # NOTE(review): chat_history_ids is a single module-level global, so every
    # browser session shares one model context. Consider gr.State for
    # per-user history if this app serves more than one user.
    global chat_history_ids

    # Token budget for prompt + reply, equal to the original max_length=1000.
    # Reserving room for the reply keeps generation from starving once the
    # conversation gets long.
    max_total_tokens = 1000
    max_new_tokens = 200
    max_context_tokens = max_total_tokens - max_new_tokens

    # Work on a copy (don't mutate the caller's list) and tolerate None, which
    # is what the chatbot component holds after being cleared.
    history = list(history) if history else []

    # An empty transcript means a brand-new conversation (first message or a
    # UI reset), so drop any stale model context from a previous one.
    if not history:
        chat_history_ids = None

    # Ignore blank submissions instead of prompting the model with a lone EOS.
    if not msg or not msg.strip():
        return history, history

    # DialoGPT separates turns with the EOS token.
    new_input_ids = tokenizer.encode(msg + tokenizer.eos_token, return_tensors="pt")

    if chat_history_ids is None:
        bot_input_ids = new_input_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)

    # Keep only the most recent tokens. Without this, once the history nears
    # the length budget the reply degenerates to nothing and the context can
    # keep growing past the model's position limit.
    bot_input_ids = bot_input_ids[:, -max_context_tokens:]

    chat_history_ids = model.generate(
        bot_input_ids,
        # Explicit mask: pad_token_id equals eos_token_id, so the model cannot
        # infer which positions are real tokens on its own.
        attention_mask=torch.ones_like(bot_input_ids),
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )

    # The generated sequence starts with the prompt; decode only the new part.
    reply = tokenizer.decode(
        chat_history_ids[0, bot_input_ids.shape[-1]:],
        skip_special_tokens=True,
    )

    history.append((msg, reply))
    return history, history
|
|
|
|
|
|
|
|
with gr.Blocks(title="Dr. Mom AI") as demo:
    gr.Markdown("## 🤖 Dr. Mom AI Chatbot")

    # Transcript display, message input, and a button to start over.
    chatbot_ui = gr.Chatbot(label="Lịch sử hội thoại")
    msg = gr.Textbox(label="Bạn hỏi gì nè?", placeholder="Nhập tin nhắn rồi nhấn Enter...")
    clear = gr.Button("🔄 Reset hội thoại")

    def _respond(message, history):
        """Run one chat turn via `chatbot`, then empty the input textbox."""
        updated_history, _ = chatbot(message, history)
        return "", updated_history

    def _reset_conversation():
        """Clear both the visible transcript and the model's token context."""
        # Blanking only the UI would leave the old conversation in
        # chat_history_ids, and the model would keep conditioning on it.
        global chat_history_ids
        chat_history_ids = None
        # An empty list (not None) keeps the transcript a list for later turns.
        return []

    # On Enter: clear the textbox and refresh the transcript.
    msg.submit(_respond, [msg, chatbot_ui], [msg, chatbot_ui])

    # Reset runs outside the queue so it takes effect immediately.
    clear.click(_reset_conversation, None, chatbot_ui, queue=False)

# Start the Gradio server; this executes whenever the module is run or imported.
demo.launch()
|
|
|