File size: 2,394 Bytes
ecb7812
105ed0f
 
 
7aee45a
d00fffc
9eb5e7a
7a8fd4e
50baffa
9f4f4b7
abe0e50
43e2809
105ed0f
 
 
 
 
43e2809
 
105ed0f
 
 
 
 
 
 
c4614a3
 
9eb5e7a
50baffa
 
 
 
551277d
 
 
 
 
50baffa
 
c4614a3
105ed0f
50baffa
105ed0f
 
 
 
43e2809
c4614a3
d3329dd
c4614a3
105ed0f
 
 
 
 
551277d
105ed0f
 
 
 
 
c4614a3
105ed0f
 
 
c4614a3
 
43e2809
c4614a3
74a5a3e
c4614a3
 
 
43e2809
8c5305d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import gradio as gr
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
import os

# --- Configuration ---
MODEL_ID = "Qwen/Qwen3-0.6B-Base"
# 1. System prompt prepended to every conversation.
#    NOTE(review): "/no_think" looks like a Qwen chat-template switch to
#    disable thinking mode — confirm against the template actually used.
SYSTEM_PROMPT = """As a helper/no_think
"""
# --- Load model and tokenizer (runs at import time; downloads on first use) ---
print("开始加载模型和分词器...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",  # take dtype from the checkpoint config
        device_map="auto",  # let accelerate place weights on available devices
        trust_remote_code=True
    )
    print("模型和分词器加载成功!")
except Exception as e:
    print(f"模型加载失败: {e}")
    # NOTE(review): raising gr.Error at module level aborts the script before
    # any UI exists; gr.Error is normally meant for event handlers — confirm
    # this fail-fast behavior is intentional.
    raise gr.Error(f"关键错误:无法加载模型 {MODEL_ID}。错误信息: {e}")

# --- 核心对话函数 ---
def predict(message, history):
    messages = []
    
    # ❗ 2. 在消息列表的开头添加系统提示词
    messages.append({"role": "system", "content": SYSTEM_PROMPT})
    
    # # 3. 添加历史对话记录
    # for turn in history:
    #     user_msg, assistant_msg = turn
    #     messages.append({"role": "user", "content": user_msg})
    #     messages.append({"role": "assistant", "content": assistant_msg})
        
    # 4. 添加当前的用户消息
    messages.append({"role": "user", "content": message})
    
    # 使用 tokenizer.apply_chat_template 将消息列表转换为模型输入
    model_inputs = tokenizer.apply_chat_template(
        messages, 
        add_generation_prompt=True, 
        return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=300.0, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        inputs=model_inputs,
        streamer=streamer,
        max_new_tokens=2048,
        do_sample=True,
        temperature=0.4,
        top_p=0.95,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    full_response = ""
    for new_text in streamer:
        full_response += new_text
        yield full_response

# --- Build and launch the Gradio UI ---
# The examples / cache_examples arguments were removed to fix an error
# raised when clicking an example.
demo = gr.ChatInterface(
    fn=predict
)

if __name__ == "__main__":
    # share=True publishes the app via a public gradio.live tunnel link.
    # NOTE(review): the original comment claimed this enables cross-origin
    # WebSocket connections — that is not what share does; verify the intent.
    demo.queue().launch(share=True)