import gradio as gr
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
import os
# --- Configuration ---
MODEL_ID = "Qwen/Qwen3-0.6B-Base"
# ❗ 1. Define the system prompt
SYSTEM_PROMPT = """As a helper/no_think
"""
# --- Load the model and tokenizer ---
print("Loading model and tokenizer...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",   # let transformers pick the dtype (e.g. bfloat16 on recent GPUs)
        device_map="auto",    # automatic device placement; requires the accelerate package
        trust_remote_code=True
    )
    print("Model and tokenizer loaded successfully!")
except Exception as e:
    print(f"Failed to load the model: {e}")
    raise gr.Error(f"Fatal error: could not load model {MODEL_ID}. Details: {e}")
# --- Core chat function ---
def predict(message, history):
    messages = []
    # ❗ 2. Put the system prompt at the start of the message list
    messages.append({"role": "system", "content": SYSTEM_PROMPT})
    # # 3. Add the conversation history
    # for turn in history:
    #     user_msg, assistant_msg = turn
    #     messages.append({"role": "user", "content": user_msg})
    #     messages.append({"role": "assistant", "content": assistant_msg})
    # 4. Add the current user message
    messages.append({"role": "user", "content": message})
    # Convert the message list into model inputs via tokenizer.apply_chat_template
    model_inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
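    # With return_tensors="pt", apply_chat_template returns the input ids tensor only
    # (no attention_mask); generate() then infers the mask itself, which may log a warning.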
    # Stream decoded text as it is generated; skip_prompt drops the echoed input
    streamer = TextIteratorStreamer(tokenizer, timeout=300.0, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs=model_inputs,
        streamer=streamer,
        max_new_tokens=2048,
        do_sample=True,
        temperature=0.4,
        top_p=0.95,
    )
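    # model.generate blocks until generation finishes, so it runs in a background
    # thread while the streamer is consumed here; yielding the growing string lets
    # Gradio update the chat window live.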
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    full_response = ""
    for new_text in streamer:
        full_response += new_text
        yield full_response
# --- Build and launch the Gradio UI ---
# The examples and cache_examples parameters were removed to fix an error raised when clicking an example
demo = gr.ChatInterface(
    fn=predict
)
if __name__ == "__main__":
    # share=True serves the app through a public Gradio link, which also avoids cross-origin WebSocket issues
    demo.queue().launch(share=True)