# Hugging Face Space file: app.py — author wang4067, commit 9f4f4b7 ("Update app.py")
import gradio as gr
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
import os
# --- Configuration ---
# Hugging Face model repo to load.
MODEL_ID = "Qwen/Qwen3-0.6B-Base"
# 1. System prompt prepended to every conversation.
#    NOTE(review): the trailing "/no_think" presumably disables Qwen's
#    "thinking" mode via the chat template — confirm against the model card.
SYSTEM_PROMPT = """As a helper/no_think
"""
# --- Load model and tokenizer (runs once at module import) ---
print("开始加载模型和分词器...")
try:
    # trust_remote_code=True lets the repo's custom tokenizer/model code run.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",   # pick dtype from the checkpoint (e.g. bf16)
        device_map="auto",    # place weights on GPU automatically if available
        trust_remote_code=True
    )
    print("模型和分词器加载成功!")
except Exception as e:
    print(f"模型加载失败: {e}")
    # Surface a fatal, user-visible error in the Space UI if loading fails.
    raise gr.Error(f"关键错误:无法加载模型 {MODEL_ID}。错误信息: {e}")
# --- 核心对话函数 ---
def predict(message, history):
    """Stream a chat completion for `message`, given prior `history`.

    Generator used as the `fn` for `gr.ChatInterface`: yields the
    progressively-growing assistant reply so the UI streams tokens.

    Args:
        message: The current user message (str).
        history: Prior turns supplied by Gradio — either the legacy
            list of (user, assistant) tuples or the newer "messages"
            format (list of {"role", "content"} dicts).

    Yields:
        str: The accumulated assistant response so far.
    """
    # 1. Start every conversation with the system prompt.
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    # 2. Replay prior turns so the model sees multi-turn context.
    #    (Previously this was commented out, so every reply ignored history.)
    for turn in history:
        if isinstance(turn, dict):
            # Gradio "messages" format: already {"role": ..., "content": ...}.
            messages.append({"role": turn["role"], "content": turn["content"]})
        else:
            # Legacy tuple format: (user_msg, assistant_msg).
            user_msg, assistant_msg = turn
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:  # assistant slot may be None/empty mid-turn
                messages.append({"role": "assistant", "content": assistant_msg})

    # 3. Append the current user message.
    messages.append({"role": "user", "content": message})

    # 4. Render the chat template to token ids on the model's device.
    model_inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    # Stream decoded text as it is generated; skip the prompt echo and
    # special tokens. Generous timeout for slow CPU Spaces.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=300.0, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        inputs=model_inputs,
        streamer=streamer,
        max_new_tokens=2048,
        do_sample=True,
        temperature=0.4,
        top_p=0.95,
    )
    # generate() blocks, so run it on a worker thread and consume the
    # streamer on this (generator) side.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    full_response = ""
    for new_text in streamer:
        full_response += new_text
        yield full_response
# --- Build and launch the Gradio UI ---
# `examples` / `cache_examples` were removed to fix an error raised when
# clicking an example in the UI.
demo = gr.ChatInterface(
    fn=predict
)
if __name__ == "__main__":
    # NOTE(review): original comment said share=True is "to allow cross-origin
    # WebSocket connections"; share=True actually creates a public share link —
    # confirm that is the intended effect here.
    demo.queue().launch(share=True)