|
|
|
|
|
|
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
# Conversational checkpoint used for both the tokenizer and the causal LM.
# Both objects are loaded via ``from_pretrained`` once, at import time.
checkpoint = "microsoft/DialoGPT-medium"
tokenizer, model = (
    AutoTokenizer.from_pretrained(checkpoint),
    AutoModelForCausalLM.from_pretrained(checkpoint),
)

# Module-level token ids written by respond(); None before the first reply.
chat_history_ids = None
|
|
|
|
|
def respond(user_input, history=None, *, max_length=1000, max_context_tokens=800):
    """Generate a DialoGPT reply to *user_input* and append it to *history*.

    The model context is rebuilt from *history* on every call, with each
    message encoded as ``text + eos_token``. This means each Gradio session
    keeps its own conversation, and clearing the session state also clears
    what the model conditions on.

    Args:
        user_input: The user's new message.
        history: List of ``(user_message, bot_reply)`` pairs from earlier
            turns; ``None`` starts a new conversation. The list is appended
            to in place.
        max_length: Upper bound on context + reply tokens passed to
            ``model.generate``.
        max_context_tokens: Only the most recent this-many conversation
            tokens are fed to the model. This keeps long chats under
            *max_length* and leaves room for a reply. Must be positive and
            smaller than *max_length*.

    Returns:
        ``(history, history)``: the updated list twice, once for the chatbot
        display and once for the session state.

    Raises:
        ValueError: If *max_context_tokens* is not in ``(0, max_length)``.
    """
    global chat_history_ids

    if not 0 < max_context_tokens < max_length:
        raise ValueError("max_context_tokens must be in (0, max_length)")

    # Avoid the shared-mutable-default pitfall: a fresh list per new conversation.
    if history is None:
        history = []

    # Nothing to answer; don't spend a generate() call on blank input.
    if not (user_input and user_input.strip()):
        return history, history

    eos = tokenizer.eos_token
    context_ids = []
    for past_user, past_bot in history:
        context_ids.extend(tokenizer.encode(past_user + eos))
        context_ids.extend(tokenizer.encode((past_bot or "") + eos))
    context_ids.extend(tokenizer.encode(user_input + eos))

    # Keep only the newest tokens so the prompt never crowds out the reply.
    context_ids = context_ids[-max_context_tokens:]
    bot_input_ids = torch.tensor([context_ids], dtype=torch.long)

    generated = model.generate(
        bot_input_ids,
        # No padding in a single unpadded sequence: attend to every token.
        # Passing it explicitly matters because pad_token_id == eos_token_id.
        attention_mask=torch.ones_like(bot_input_ids),
        max_length=max_length,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Still publish the latest ids at module level for backward compatibility;
    # this function no longer reads the global back.
    chat_history_ids = generated

    # generate() returns prompt + new tokens; decode only the new part.
    output = tokenizer.decode(generated[0, bot_input_ids.shape[-1]:], skip_special_tokens=True)

    history.append((user_input, output))
    return history, history
|
|
|
|
|
|
|
|
def _submit_message(user_input, history):
    """Handle a textbox submit: get a reply from respond(), then empty the textbox.

    Returns the new textbox value (an empty string) followed by the chatbot
    display and session state produced by respond().
    """
    shown, stored = respond(user_input, history)
    return "", shown, stored


def _clear_chat():
    """Reset the chat display, the session state, and the module-level token history.

    respond() writes generated token ids to the module-level
    ``chat_history_ids``. If respond() reuses that global as model context,
    clearing only the UI would leave the model conditioning on the discarded
    conversation.
    NOTE(review): this global is shared by every session of the app.
    """
    global chat_history_ids
    chat_history_ids = None
    return [], []


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Type your message here")
    clear = gr.Button("Clear Chat")

    # Session state: list of (user_message, bot_reply) pairs fed to respond().
    state = gr.State([])

    msg.submit(_submit_message, [msg, state], [msg, chatbot, state])
    clear.click(_clear_chat, None, [chatbot, state])


if __name__ == "__main__":
    demo.launch()