File size: 3,997 Bytes
ecb7812 7f2c7f6 ed1d652 7aee45a d00fffc ed1d652 128f145 7aee45a 128f145 ed1d652 128f145 ecb7812 128f145 72967c5 128f145 72967c5 128f145 72967c5 128f145 ed1d652 128f145 ecb7812 128f145 ecb7812 128f145 ecb7812 128f145 ecb7812 128f145 ecb7812 128f145 ecb7812 128f145 ecb7812 128f145 ecb7812 128f145 ecb7812 128f145 ecb7812 128f145 ecb7812 128f145 ecb7812 128f145 ecb7812 0e98b1f 128f145 0e98b1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
# --- 1. 配置与模型加载 ---
# 假设运行环境的硬件资源是充足的。
MODEL_ID = os.getenv("MODEL_ID", "badanwang/teacher_basic_qwen3-0.6b")
print(f"INFO: 正在加载模型: {MODEL_ID}")
# 使用 try-except 来捕获任何可能的加载错误 (例如网络问题、模型名称错误等)
try:
# 加载分词器和模型
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# device_map="auto" 会自动利用可用的硬件 (如 CPU 或 GPU)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype="auto", # 自动选择最佳数据类型
device_map="auto",
trust_remote_code=True
)
print("INFO: 模型和分词器加载成功!")
# 将核心推理逻辑定义为一个函数
# 只有在模型成功加载后,这个函数才会被有效定义
def predict(prompt: str, history: list[list[str]]):
"""
接收用户输入和对话历史,返回更新后的完整对话历史。
Gradio 会自动为这个函数创建 API 端点。
"""
print(f"INFO: 收到API/UI请求: prompt='{prompt}'")
# 1. 构建符合模型要求的消息列表
messages = []
for user_message, bot_message in history:
messages.append({"role": "user", "content": user_message})
messages.append({"role": "assistant", "content": bot_message})
messages.append({"role": "user", "content": prompt})
# 2. 应用聊天模板并进行分词
input_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt"
).to(model.device)
# 3. 生成回复
# 使用简单的 .generate(),不加额外的采样参数以保持简洁
outputs = model.generate(input_ids, max_new_tokens=1024)
# 4. 解码生成的文本,跳过输入的token
response_text = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
print(f"INFO: 生成回复: {response_text}")
# 5. 更新并返回对话历史
history.append([prompt, response_text])
return history
except Exception as e:
print(f"FATAL: 加载模型或分词器时发生致命错误: {e}")
# 如果模型加载失败,则定义一个专门用于报错的函数
# 这能确保Gradio界面依然可以启动,并向用户显示一个清晰的错误信息
def predict(*args, **kwargs):
raise gr.Error(f"模型未能加载,应用无法工作。请检查后台日志获取详细错误信息。错误: {e}")
# --- 2. 创建并启动 Gradio 应用 ---
# 使用 gr.Blocks 来自定义界面布局
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
gr.Markdown(f"## 模型聊天机器人\n当前模型: `{MODEL_ID}`")
# 定义聊天机器人组件和输入框
chatbot = gr.Chatbot(label="对话历史", height=600)
msg_input = gr.Textbox(label="在这里输入你的问题...", placeholder="例如:你好,你是谁?")
clear_button = gr.Button("清除对话")
# 设定组件的交互逻辑
# 当用户在输入框中按回车时,调用 predict 函数
msg_input.submit(predict, [msg_input, chatbot], chatbot)
# 当用户点击“清除对话”按钮时,清空聊天机器人组件
clear_button.click(lambda: [], None, chatbot)
# --- 3. 启动应用并开放API ---
print("INFO: 准备启动Gradio应用...")
# .queue() 使应用能够处理多个排队的请求,并且在 4.29.0 版本中会自动开放API。
# share=True 是解决CORS问题的关键。它会生成一个公开的、已配置好CORS的 .gradio.live 网址。
# *** 已移除 'api_open=True' 参数以适配 gradio==4.29.0 ***
demo.queue().launch(share=True) |