ChatBotAPI / app.py
hamza2923's picture
Update app.py
6fe40f1 verified
raw
history blame
2.47 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Checkpoint identifier shared by the tokenizer and the model weights.
model_name = "microsoft/DialoGPT-small"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the tokenizer and use its end-of-sequence token as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Load the causal language model and place it on the selected device.
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
# Upper bound on the number of tokens generated for a single reply. This is a
# budget for NEW tokens only, so it does not shrink as the conversation grows.
_MAX_NEW_TOKENS = 200
# Number of most recent (user, bot) turns kept as conversational context.
_MAX_TURNS = 10


def respond(message, chat_history, chat_history_ids):
    """Generate a reply to ``message`` and return the updated conversation.

    Args:
        message: Text typed by the user. ``None``, empty, or whitespace-only
            input is rejected with a validation message.
        chat_history: List of ``(user_text, bot_text)`` pairs shown so far,
            or ``None`` / empty for a new conversation. The list passed in
            is never mutated.
        chat_history_ids: Tensor of token ids encoding the previous turns,
            or ``None`` when there is no prior context.

    Returns:
        A 4-tuple ``(textbox_value, chat_history, chat_history_ids, error)``:
        an empty string for the input box, the (possibly extended) list of
        turns, the token ids to feed into the next call, and either ``None``
        on success or a human-readable error string. On failure the history
        and token ids are returned unchanged.
    """
    # Work on a copy so the caller's list is only replaced, never mutated.
    chat_history = list(chat_history) if chat_history else []

    if not message or not message.strip():
        return "", chat_history, chat_history_ids, "Please enter a message."

    eos = tokenizer.eos_token
    try:
        new_input_ids = tokenizer.encode(
            message + eos, return_tensors="pt"
        ).to(device)
        if chat_history_ids is not None:
            input_ids = torch.cat(
                [chat_history_ids.to(device), new_input_ids], dim=-1
            )
        else:
            input_ids = new_input_ids

        # Keep only the most recent tokens so that prompt + reply always fit
        # inside the model's positional context window.
        context_len = getattr(model.config, "max_position_embeddings", 1024)
        prompt_budget = max(context_len - _MAX_NEW_TOKENS, 1)
        input_ids = input_ids[:, -prompt_budget:]

        output_ids = model.generate(
            input_ids,
            # The prompt contains no padding; an explicit mask is required
            # because the pad token is the same as the EOS token.
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=_MAX_NEW_TOKENS,
            pad_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=3,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.8,
        )
        # Everything past the prompt length is the newly generated reply.
        response = tokenizer.decode(
            output_ids[0, input_ids.shape[-1]:],
            skip_special_tokens=True,
        )

        updated_history = (chat_history + [(message, response)])[-_MAX_TURNS:]

        # Re-encode the retained turns so the token context matches the
        # trimmed history. Every turn (user AND bot) is terminated by EOS,
        # the same way a new user message is encoded above.
        history_text = "".join(
            f"{user or ''}{eos}{bot or ''}{eos}" for user, bot in updated_history
        )
        new_history_ids = tokenizer.encode(
            history_text, return_tensors="pt"
        ).to(device)
    except Exception as e:
        # UI boundary: report the failure instead of crashing the handler,
        # and leave the previous conversation state untouched.
        return "", chat_history, chat_history_ids, f"Error: {str(e)}"

    return "", updated_history, new_history_ids, None
def clear_history():
    """Return the reset values used to wipe the conversation.

    Returns:
        A 3-tuple: a fresh empty list of chat turns, ``None`` for the stored
        token-id history, and ``None`` for the error message.
    """
    empty_turns = []
    cleared_ids = None
    cleared_error = None
    return empty_turns, cleared_ids, cleared_error
with gr.Blocks() as demo:
    # Per-session slot holding the token-id history passed to `respond`.
    state = gr.State()
    gr.Markdown("## DialoGPT Chatbot")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
    clear = gr.Button("Clear History")
    # Starts hidden; revealed by `_error_visibility` only when a handler
    # reports a message, so an empty "Error" box is never displayed.
    error = gr.Textbox(label="Error", interactive=False, visible=False)

    def _error_visibility(error_text):
        """Show the error box when it holds a message, hide it otherwise."""
        return gr.update(visible=bool(error_text))

    # The main handlers keep their original wiring and outputs; the chained
    # step only toggles the visibility of the error box afterwards. Without
    # it the box stays hidden forever and reported errors are never seen.
    submit_event = msg.submit(
        respond,
        inputs=[msg, chatbot, state],
        outputs=[msg, chatbot, state, error],
    )
    submit_event.then(
        fn=_error_visibility,
        inputs=error,
        outputs=error,
        queue=False,
    )

    clear_event = clear.click(
        fn=clear_history,
        inputs=None,
        outputs=[chatbot, state, error],
        queue=False,
    )
    clear_event.then(
        fn=_error_visibility,
        inputs=error,
        outputs=error,
        queue=False,
    )