|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face model repo to serve; override with the MODEL_ID env var.
MODEL_ID = os.environ.get("MODEL_ID", "badanwang/teacher_basic_qwen3-0.6b")

print(f"INFO: 正在加载模型: {MODEL_ID}")
|
|
|
|
|
|
|
|
try:
    # Load the tokenizer and model once at startup. Any failure is caught
    # below so the app can still launch and report the error to users.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",     # use the checkpoint's native dtype
        device_map="auto",      # place weights on available GPU(s)/CPU
        trust_remote_code=True
    )
    print("INFO: 模型和分词器加载成功!")

    def predict(prompt: str, history: list[list[str]]):
        """Generate a reply and return the updated conversation history.

        Gradio automatically creates an API endpoint for this function.

        Args:
            prompt: The latest user message.
            history: Prior turns as ``[user_message, bot_message]`` pairs
                (the Gradio Chatbot "tuples" history format).

        Returns:
            The same ``history`` list with ``[prompt, response]`` appended.
        """
        print(f"INFO: 收到API/UI请求: prompt='{prompt}'")

        # Rebuild the whole conversation as role-tagged messages so the
        # model's chat template can format it.
        messages = []
        for user_message, bot_message in history:
            messages.append({"role": "user", "content": user_message})
            messages.append({"role": "assistant", "content": bot_message})
        messages.append({"role": "user", "content": prompt})

        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt"
        ).to(model.device)

        outputs = model.generate(input_ids, max_new_tokens=1024)

        # Decode only the newly generated tokens, skipping the prompt part.
        response_text = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)

        print(f"INFO: 生成回复: {response_text}")

        history.append([prompt, response_text])
        return history

except Exception as e:
    print(f"FATAL: 加载模型或分词器时发生致命错误: {e}")

    # BUG FIX: Python unbinds the `except ... as e` name once the handler
    # exits, so the original fallback raised NameError (not gr.Error) when
    # called later. Capture the message in a name that outlives the handler.
    load_error_message = str(e)

    def predict(*args, **kwargs):
        """Fallback handler used when model loading failed: always errors."""
        raise gr.Error(f"模型未能加载,应用无法工作。请检查后台日志获取详细错误信息。错误: {load_error_message}")
|
|
|
|
|
|
|
|
|
|
|
# Build the chat UI: a history pane, a text input wired to `predict`,
# and a button that resets the conversation.
soft_blue_theme = gr.themes.Soft(primary_hue="blue")
with gr.Blocks(theme=soft_blue_theme) as demo:
    gr.Markdown(f"## 模型聊天机器人\n当前模型: `{MODEL_ID}`")

    chat_window = gr.Chatbot(label="对话历史", height=600)
    question_box = gr.Textbox(label="在这里输入你的问题...", placeholder="例如:你好,你是谁?")
    reset_btn = gr.Button("清除对话")

    # Submitting the textbox runs inference and refreshes the chat pane.
    question_box.submit(predict, [question_box, chat_window], chat_window)

    # Clearing the chat just replaces the history with an empty list.
    reset_btn.click(lambda: [], None, chat_window)
|
|
|
|
|
|
|
|
print("INFO: 准备启动Gradio应用...")

# Enable request queueing so concurrent users are served in order, then
# launch; share=True asks Gradio for a public tunnel URL.
queued_app = demo.queue()
queued_app.launch(share=True)