import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load model and tokenizer
model_name = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


def respond(message, chat_history, chat_history_ids):
    if not message.strip():
        return "", chat_history or [], chat_history_ids, "Please enter a message."
    if chat_history is None:
        chat_history = []

    # Encode the new user message, appending the EOS token DialoGPT expects
    # as a turn separator
    new_input_ids = tokenizer.encode(
        message + tokenizer.eos_token, return_tensors="pt"
    ).to(device)
    # Prepend the running conversation tokens, if any
    input_ids = (
        torch.cat([chat_history_ids, new_input_ids], dim=-1)
        if chat_history_ids is not None
        else new_input_ids
    )

    try:
        chat_history_ids = model.generate(
            input_ids,
            # Bound the reply length with max_new_tokens; max_length counts
            # the prompt too, so a fixed max_length=200 fails once the
            # accumulated history exceeds it
            max_new_tokens=100,
            pad_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=3,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.8,
        )
        # Decode only the newly generated tokens (everything past the prompt)
        response = tokenizer.decode(
            chat_history_ids[:, input_ids.shape[-1]:][0],
            skip_special_tokens=True,
        )
        chat_history.append((message, response))

        # Keep only the last 10 exchanges and rebuild the token history to
        # match, so the model's context does not grow without bound. Each
        # turn (user message and bot reply) must be terminated with EOS.
        if len(chat_history) > 10:
            chat_history = chat_history[-10:]
        history_text = "".join(
            msg + tokenizer.eos_token + resp + tokenizer.eos_token
            for msg, resp in chat_history
        )
        chat_history_ids = tokenizer.encode(history_text, return_tensors="pt").to(device)

        return "", chat_history, chat_history_ids, ""
    except Exception as e:
        return "", chat_history, chat_history_ids, f"Error: {str(e)}"


def clear_history():
    return [], None, ""


with gr.Blocks() as demo:
    state = gr.State()
    gr.Markdown("## DialoGPT Chatbot")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
    clear = gr.Button("Clear History")
    # Keep the error box visible; a component created with visible=False
    # never shows the strings respond() returns to it
    error = gr.Textbox(label="Error", interactive=False)

    msg.submit(
        respond,
        inputs=[msg, chatbot, state],
        outputs=[msg, chatbot, state, error],
    )
    clear.click(
        fn=clear_history,
        inputs=None,
        outputs=[chatbot, state, error],
        queue=False,
    )
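
# Without a launch() call the script builds the Blocks UI but never serves
# it. A standard entry-point guard (default host/port; pass share=True to
# demo.launch() if a public link is wanted):
if __name__ == "__main__":
    demo.launch()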