File size: 2,201 Bytes
ecb7812
105ed0f
 
 
7aee45a
d00fffc
9eb5e7a
105ed0f
 
43e2809
105ed0f
 
 
 
 
43e2809
 
105ed0f
 
 
 
 
 
 
c4614a3
 
9eb5e7a
c4614a3
 
9eb5e7a
 
c4614a3
105ed0f
 
 
 
 
43e2809
c4614a3
d3329dd
c4614a3
105ed0f
 
 
 
 
 
 
 
 
 
 
c4614a3
105ed0f
 
 
c4614a3
 
43e2809
c4614a3
 
43e2809
105ed0f
c4614a3
 
 
43e2809
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
import os

# --- Configuration ---
MODEL_ID = "badanwang/teacher_basic_qwen3-0.6b"

# --- Load model and tokenizer at import time ---
# NOTE: trust_remote_code=True executes custom code shipped with the model
# repository — acceptable only because MODEL_ID is a fixed, trusted repo.
print("开始加载模型和分词器...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",   # let transformers pick the checkpoint's dtype
        device_map="auto",    # place weights on GPU if available, else CPU
        trust_remote_code=True
    )
    print("模型和分词器加载成功!")
except Exception as e:
    # Surface the failure in the Gradio UI instead of a silent broken app.
    print(f"模型加载失败: {e}")
    raise gr.Error(f"关键错误:无法加载模型 {MODEL_ID}。错误信息: {e}")

# --- 核心对话函数 ---
def predict(message, history):
    """Stream a chat completion for *message* given the conversation *history*.

    Args:
        message: The newest user utterance (str).
        history: Prior turns as supplied by Gradio — either legacy
            ``(user, assistant)`` tuples or messages-format
            ``{"role": ..., "content": ...}`` dicts.

    Yields:
        The cumulative assistant response text, so Gradio renders the reply
        incrementally as tokens arrive.
    """
    # Rebuild the conversation in chat-template form. Accept both history
    # shapes: older Gradio passes (user, assistant) pairs, newer versions
    # (type="messages") pass one role/content dict per message.
    messages = []
    for turn in history:
        if isinstance(turn, dict):
            messages.append({"role": turn["role"], "content": turn["content"]})
        else:
            user_msg, assistant_msg = turn
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # return_dict=True yields input_ids AND attention_mask; generating
    # without the mask triggers warnings and can mishandle padding.
    model_inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    # skip_prompt drops the echoed prompt; the long timeout accommodates
    # slow CPU inference on the Space.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=300.0, skip_prompt=True, skip_special_tokens=True
    )

    generation_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=2048,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
    )
    # model.generate blocks, so run it on a worker thread and consume the
    # streamer from this generator.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    full_response = ""
    for new_text in streamer:
        full_response += new_text
        yield full_response
    # Streamer exhaustion means generation finished; reap the worker thread.
    thread.join()

# --- Build the Gradio chat interface ---
# The examples / cache_examples parameters were removed to fix an error that
# occurred when clicking an example prompt.
demo = gr.ChatInterface(
    fn=predict,
    title="小Q老师 - 基础问答 (本地加载)",
    description=f"直接在Space中运行 {MODEL_ID} 模型进行流式对话。CPU推理可能较慢,请耐心等待。",
)

if __name__ == "__main__":
    # share=True requests a public Gradio share link. NOTE(review): the
    # original comment claimed this enables cross-origin WebSocket
    # connections — confirm it is actually required in this deployment.
    demo.launch(share=True)