| from transformers import AutoModelForCausalLM, AutoTokenizer |
| import gradio as gr |
| import torch |
|
|
| |
# Text and sample inputs displayed by the Gradio interface.
title = "🤖 AI ChatBot"
description = (
    "Building open-domain chatbots is a challenging area for machine learning research."
)
examples = [
    ["How are you?"],
]
|
|
| |
# Load the pretrained DialoGPT tokenizer and model from one shared checkpoint
# name so the vocabulary and the weights always match.
_checkpoint = "microsoft/DialoGPT-large"
tokenizer = AutoTokenizer.from_pretrained(_checkpoint)
model = AutoModelForCausalLM.from_pretrained(_checkpoint)
|
|
def predict(input_text, history=None):
    """Generate a DialoGPT reply to ``input_text`` and return the updated chat state.

    Args:
        input_text: The user's new message.
        history: Token ids of the conversation so far, exactly as returned by
            the previous call (the nested list produced by ``Tensor.tolist()``,
            one inner list per batch row), or ``None``/empty to start a new
            conversation. Gradio feeds this back through its ``"state"``
            component.

    Returns:
        A tuple ``(chatbot_messages, history)``: ``chatbot_messages`` is a
        one-element list holding the latest ``(user message, bot reply)`` pair
        for the ``"chatbot"`` output, and ``history`` is the token ids of the
        conversation including the new reply, to be passed back next turn.
    """
    # Upper bound on prompt tokens sent to the model. generate() below stops at
    # 1000 total tokens, so this leaves room for a reply. Without a cap the
    # prompt grows every turn and eventually exceeds the model's position
    # limit (1024 for the GPT-2-based DialoGPT checkpoints — verify if the
    # checkpoint changes).
    max_context_tokens = 800

    if history is None:
        history = []

    # End the user turn with EOS, which DialoGPT uses as its turn separator.
    new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors="pt")

    if history:
        past_ids = torch.tensor(history, dtype=torch.long)
        bot_input_ids = torch.cat([past_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # Keep only the most recent tokens so long conversations cannot overflow
    # the model; short prompts are unaffected by this slice.
    bot_input_ids = bot_input_ids[:, -max_context_tokens:]
    prompt_length = bot_input_ids.shape[-1]

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
    history = output_ids.tolist()

    # Decode only the tokens generated after the prompt. Splitting the whole
    # decoded conversation on input_text (the previous approach) truncated
    # replies that contained the user's text and raised ValueError when
    # input_text was empty.
    bot_reply = tokenizer.decode(output_ids[0, prompt_length:], skip_special_tokens=True).strip()

    return [(input_text, bot_reply)], history
|
|
| |
# Build the Gradio UI around predict(): a textbox plus hidden state go in,
# a chat transcript plus the updated state come out. launch() then starts
# serving the app.
demo = gr.Interface(
    fn=predict,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    title=title,
    description=description,
    examples=examples,
    theme="finlaymacklon/boxy_violet",
)
demo.launch()
|
|