|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Model repo to load; override via the MODEL_ID environment variable.
MODEL_ID = os.getenv("MODEL_ID", "badanwang/teacher_basic_qwen3-0.6b")

print(f"正在加载模型: {MODEL_ID}")  # "Loading model: ..."

# trust_remote_code=True allows the hub repo to ship custom modeling/tokenizer
# code — acceptable only for repos you trust.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",  # use the checkpoint's native dtype (e.g. bf16)
    device_map="auto",   # place weights on GPU(s) automatically when available
    trust_remote_code=True
)

print("模型加载成功!")  # "Model loaded successfully!"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_response(prompt: str, history: list[list[str]] | None = None) -> str:
    """Generate a single reply from the model for *prompt*.

    Args:
        prompt: The user's current question.
        history: Prior conversation turns, formatted as
            [[user_msg_1, bot_msg_1], ...]. Defaults to None (no history).

    Returns:
        The assistant's generated reply text, with special tokens stripped.
    """
    if history is None:
        history = []

    # Rebuild the conversation as a chat-template message list.
    messages = []
    for user_message, bot_message in history:
        messages.append({"role": "user", "content": user_message})
        messages.append({"role": "assistant", "content": bot_message})
    messages.append({"role": "user", "content": prompt})

    # return_dict=True yields both input_ids and attention_mask, so
    # generate() receives an explicit mask instead of inferring one
    # (avoids the HF warning and incorrect masking with pad tokens).
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )

    # Slice off the prompt tokens; decode only the newly generated part.
    response_ids = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(response_ids, skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":

    # Demo 1: a single-turn exchange with no prior context.
    print("\n--- 示例 1: 单轮对话 ---")
    first_question = "你好,你是谁?"
    print(f"用户: {first_question}")
    print(f"模型: {get_response(first_question)}")

    # Demo 2: the same call, this time seeded with one prior turn of history.
    print("\n--- 示例 2: 多轮对话 ---")
    prior_turns = [
        ["用Python写一个快速排序", "当然,这是快速排序的Python实现:\n```python\ndef quick_sort(arr):\n    if len(arr) <= 1:\n        return arr\n    pivot = arr[len(arr) // 2]\n    left = [x for x in arr if x < pivot]\n    middle = [x for x in arr if x == pivot]\n    right = [x for x in arr if x > pivot]\n    return quick_sort(left) + middle + quick_sort(right)\n\nprint(quick_sort())\n```"]
    ]
    follow_up = "很好,你能解释一下它的工作原理吗?"
    print(f"历史: {prior_turns}")
    print(f"用户: {follow_up}")
    print(f"模型: {get_response(follow_up, history=prior_turns)}")