| import gradio as gr |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| import torch |
|
|
| |
# Hugging Face checkpoint backing the chatbot; the tokenizer and model are
# both loaded from it once, when this module is executed.
_CHECKPOINT = "microsoft/DialoGPT-medium"

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)

# Module-level generation state that ``chatbot`` writes to; starts as None.
chat_history_ids = None
|
|
def chatbot(user_input, history):
    """Generate a DialoGPT reply to *user_input* for one chat session.

    The model context is rebuilt from this session's own ``history`` on every
    call. Concurrent users and page reloads therefore never share (or leak)
    conversation state. The context is also truncated so it always fits in
    the model's positional window.

    Args:
        user_input: The message the user just typed.
        history: This session's list of ``(user_message, bot_reply)`` pairs,
            as returned by the previous call; ``None`` (or empty) on the
            first turn.

    Returns:
        ``(response, history)``: the decoded reply text, and ``history``
        with ``(user_input, response)`` appended.
    """
    global chat_history_ids

    history = history or []

    # Upper bound on the reply length in tokens.
    max_new_tokens = 200
    # GPT-2-based models cannot attend beyond their positional-embedding
    # limit (prompt + reply must fit), so reserve room for the reply.
    max_context = model.config.max_position_embeddings - max_new_tokens

    # DialoGPT format: every turn (user or bot) is terminated by EOS.
    turns = [text for pair in history for text in pair]
    turns.append(user_input)
    context = "".join(turn + tokenizer.eos_token for turn in turns)
    bot_input_ids = tokenizer.encode(context, return_tensors="pt")
    # Keep only the most recent tokens; the oldest turns are dropped first
    # (the cut may fall mid-turn, which the model tolerates).
    bot_input_ids = bot_input_ids[:, -max_context:]

    output_ids = model.generate(
        bot_input_ids,
        # Single unpadded sequence: every position is real input.
        attention_mask=torch.ones_like(bot_input_ids),
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_k=100,
        top_p=0.7,
        temperature=0.8,
    )
    # The global is still updated with the latest generation for any code
    # that inspects it; it is no longer fed back in as model context.
    chat_history_ids = output_ids

    # Everything after the prompt tokens is the newly generated reply.
    response = tokenizer.decode(
        output_ids[0, bot_input_ids.shape[-1]:],
        skip_special_tokens=True,
    )

    history.append((user_input, response))
    return response, history
|
|
| |
# Web UI: a text box plus a gr.State input/output pair, which feeds the
# history returned by ``chatbot`` back into its next call for that session.
iface = gr.Interface(
    fn=chatbot,
    inputs=[
        gr.Textbox(lines=2, placeholder="Type your message here..."),
        gr.State(),
    ],
    outputs=[
        "text",
        gr.State(),
    ],
    title="Mini Chatbot 🤖",
    description="This chatbot uses Microsoft DialoGPT-medium",
)

# Start the server only when run as a script, so importing this module
# (e.g. from tests or another app) does not launch a blocking web server.
if __name__ == "__main__":
    iface.launch()