| | |
| |
|
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr

# Conversational model for the demo; weights are downloaded (and cached
# by the Hugging Face hub) the first time this script runs.
checkpoint = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# Running token-id tensor holding the entire conversation so far.
# None until the first turn; `respond` appends to it on every call.
chat_history_ids = None
| |
|
def respond(user_input, history=None):
    """Generate one DialoGPT chat turn and append it to the history.

    Args:
        user_input: The user's message for this turn.
        history: Gradio chat history as a list of (user, bot) string
            tuples. Defaults to a fresh list each call.

    Returns:
        A pair ``(history, history)``: the updated history emitted twice
        so it can feed both the Chatbot display and the State component.
    """
    global chat_history_ids

    # BUG FIX: the original used a mutable default (history=[]), which is
    # shared across calls and silently accumulates turns. Use None sentinel.
    if history is None:
        history = []

    # Encode the new user turn, terminated by EOS as DialoGPT expects.
    new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')

    # Prepend the accumulated conversation tokens, if any.
    if chat_history_ids is not None:
        bot_input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
    else:
        bot_input_ids = new_input_ids

    # Inference only — no gradients needed; saves memory and time.
    # NOTE(review): max_length=1000 caps the WHOLE sequence; long
    # conversations will eventually hit it — consider truncating history.
    with torch.no_grad():
        chat_history_ids = model.generate(
            bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens (everything past the prompt).
    output = tokenizer.decode(
        chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True
    )

    history.append((user_input, output))
    return history, history
| |
|
| | |
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Type your message here")
    clear = gr.Button("Clear Chat")

    # Gradio-side chat history: list of (user, bot) message tuples.
    state = gr.State([])

    def _clear_chat():
        """Reset both the visible chat and the model's token context.

        BUG FIX: the original clear handler only emptied the Chatbot and
        the State, leaving the global ``chat_history_ids`` intact — so the
        model kept feeding the "cleared" conversation into the next
        generation. Reset it as well.
        """
        global chat_history_ids
        chat_history_ids = None
        return [], []

    msg.submit(respond, [msg, state], [chatbot, state])
    clear.click(_clear_chat, None, [chatbot, state])

if __name__ == "__main__":
    demo.launch()