File size: 3,353 Bytes
7f2c7f6
ed1d652
7aee45a
d00fffc
ed1d652
b965102
ed1d652
7aee45a
ed1d652
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7aee45a
ed1d652
7aee45a
 
ed1d652
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b69a39
ed1d652
548ffa6
ed1d652
b965102
7aee45a
ed1d652
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# --- 1. 配置与模型加载 ---

# 从环境变量或默认值加载模型ID
MODEL_ID = os.getenv("MODEL_ID", "badanwang/teacher_basic_qwen3-0.6b")
print(f"正在加载模型: {MODEL_ID}")

# 加载分词器和模型
# trust_remote_code=True 是加载Qwen等模型所必需的
# device_map="auto" 会自动将模型分配到可用的硬件上(如GPU)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True
)
print("模型加载成功!")


# --- 2. 核心推理函数 (API) ---

def get_response(prompt: str, history: list[list[str]] = None):
    """
    一个简单的函数,用于与模型进行单次对话。

    Args:
        prompt (str): 用户当前输入的问题。
        history (list[list[str]], optional): 对话历史,格式为 [[user_msg_1, bot_msg_1], ...]。默认为 None。

    Returns:
        str: 模型生成的回复。
    """
    if history is None:
        history = []

    # 1. 构建消息列表
    messages = []
    for user_message, bot_message in history:
        messages.append({"role": "user", "content": user_message})
        messages.append({"role": "assistant", "content": bot_message})
    messages.append({"role": "user", "content": prompt})

    # 2. 应用聊天模板并进行分词
    # 这是与聊天模型正确交互的关键步骤
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt"
    ).to(model.device)

    # 3. 生成回复
    # 这是一个阻塞式调用,会等待模型生成完毕
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.7,
        top_p=0.9
    )
    
    # 4. 解码生成的文本
    # `outputs[0]` 包含了输入的token和新生成的token,我们需要切片只获取新生成的部分
    response_ids = outputs[0][input_ids.shape[-1]:]
    response_text = tokenizer.decode(response_ids, skip_special_tokens=True)

    return response_text

# --- 3. 使用示例 ---

if __name__ == "__main__":
    # 示例1: 单轮对话
    print("\n--- 示例 1: 单轮对话 ---")
    question1 = "你好,你是谁?"
    print(f"用户: {question1}")
    answer1 = get_response(question1)
    print(f"模型: {answer1}")

    # 示例2: 多轮对话
    print("\n--- 示例 2: 多轮对话 ---")
    # 首先,定义一个对话历史
    chat_history = [
        ["用Python写一个快速排序", "当然,这是快速排序的Python实现:\n```python\ndef quick_sort(arr):\n    if len(arr) <= 1:\n        return arr\n    pivot = arr[len(arr) // 2]\n    left = [x for x in arr if x < pivot]\n    middle = [x for x in arr if x == pivot]\n    right = [x for x in arr if x > pivot]\n    return quick_sort(left) + middle + quick_sort(right)\n\nprint(quick_sort())\n```"]
    ]
    question2 = "很好,你能解释一下它的工作原理吗?"
    print(f"历史: {chat_history}")
    print(f"用户: {question2}")
    # 调用时传入历史记录
    answer2 = get_response(question2, history=chat_history)
    print(f"模型: {answer2}")