Spaces:
Sleeping
Sleeping
File size: 2,290 Bytes
6c2e024 786d074 6c2e024 786d074 6c2e024 786d074 6c2e024 786d074 6c2e024 786d074 6c2e024 786d074 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# The tokenizer and the causal-LM weights are both pulled from the same
# checkpoint, so the identifier is spelled out once and reused.
_CHECKPOINT = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)
# Upper bound handed to ``model.generate`` as ``max_length``; that limit
# counts prompt tokens and generated tokens together.
_MAX_TOTAL_TOKENS = 1000
# Tokens kept free for the reply, so a long conversation cannot use up the
# whole budget and leave the model nothing to generate.
_RESERVED_REPLY_TOKENS = 128
_MAX_PROMPT_TOKENS = _MAX_TOTAL_TOKENS - _RESERVED_REPLY_TOKENS


def _iter_history_texts(history):
    """Yield the text of every past utterance, oldest first.

    Two history layouts are accepted:

    * "messages" style: a flat list of dicts carrying a ``"content"`` entry
      (what a chat interface created with ``type="messages"`` passes in);
    * legacy pair style: a list of ``(user_text, bot_text)`` sequences.

    Entries whose content is not a string (for example ``None`` for a
    missing bot reply) are skipped instead of raising.
    """
    for turn in history or []:
        if isinstance(turn, dict):
            parts = (turn.get("content"),)
        elif isinstance(turn, (list, tuple)):
            parts = turn
        else:
            # Unknown entry shape: ignore it rather than fail the whole reply.
            continue
        for text in parts:
            if isinstance(text, str):
                yield text


def chat_response(message, history):
    """Generate the bot's reply to ``message`` given the prior conversation.

    Args:
        message: The new user message (a string).
        history: Previous turns, either as a list of ``{"role", "content"}``
            dicts or as a list of ``(user_text, bot_text)`` pairs. May be
            empty or ``None``.

    Returns:
        The decoded reply text. If anything goes wrong, the string
        ``"Error: <exception message>"`` is returned instead of raising, so
        the failure shows up as a chat message.
    """
    try:
        eos = tokenizer.eos_token
        # Every utterance is terminated with EOS and the turns are simply
        # concatenated in order, ending with the new user message.
        utterances = [*_iter_history_texts(history), message]
        segments = [
            tokenizer.encode(text + eos, return_tensors="pt")
            for text in utterances
        ]
        prompt_ids = torch.cat(segments, dim=-1)

        # Keep only the most recent tokens so the prompt always leaves room
        # for a reply within the max_length budget below.
        prompt_ids = prompt_ids[:, -_MAX_PROMPT_TOKENS:]

        output_ids = model.generate(
            prompt_ids,
            max_length=_MAX_TOTAL_TOKENS,
            pad_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=3,
            do_sample=True,
            top_k=100,
            top_p=0.7,
            temperature=0.8,
        )

        # The output starts with the prompt; everything after it is the reply.
        reply_ids = output_ids[0, prompt_ids.shape[-1]:]
        return tokenizer.decode(reply_ids, skip_special_tokens=True)
    except Exception as e:
        # UI boundary: report the problem in the chat window rather than
        # crashing the request.
        return f"Error: {e}"
# Build the chat UI around chat_response. With type="messages" the
# conversation history is exchanged as role/content message entries.
demo = gr.ChatInterface(
    fn=chat_response,
    type="messages",
    title="DialoGPT Chatbot",
    examples=[
        "Hello!",
        "What's AI?",
        "Tell me a joke",
    ],
)
if __name__ == "__main__":
    # Start the Gradio app only when this file is executed as a script.
    demo.launch()