File size: 1,894 Bytes
1c68161 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# 加载 DialoGPT 模型和 Tokenizer
model_name = "microsoft/DialoGPT-medium"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 存储对话历史
conversation_history = []
# 对话生成函数
def respond_to_input(user_input):
global conversation_history
# 编码用户输入并将其附加到对话历史
conversation_history.append(f"User: {user_input}")
# 将历史对话作为模型输入
input_text = " ".join(conversation_history[-5:]) # 只传递最近的5条对话,避免过长
input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors="pt")
# 生成对话的响应
response_ids = model.generate(input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
# 解码模型生成的响应
bot_response = tokenizer.decode(response_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
# 将机器人响应添加到对话历史
conversation_history.append(f"Bot: {bot_response}")
# 返回更新后的对话历史
chat_history = "\n".join(conversation_history[-10:]) # 显示最近的10条对话
return chat_history, "" # 更新对话历史并清空输入框
# 创建 Gradio 界面
iface = gr.Interface(
fn=respond_to_input,
inputs=gr.Textbox(label="", placeholder="Type here...", lines=1, scale=2),
outputs=[gr.Textbox(label="Conversation History", lines=15, interactive=False), gr.Textbox()],
title="ChatGPT-like Chatbot",
description="Chat with a bot powered by DialoGPT. Type your question below!",
theme="default", # 使用默认的主题
live=True,
allow_flagging="never", # 禁用标记按钮
css=".output-textbox { height: 400px; }" # 自定义输出框高度
)
# 启动应用
iface.launch() |