import torch from transformers import AutoModelForCausalLM, AutoTokenizer import os # --- 1. 配置与模型加载 --- # 从环境变量或默认值加载模型ID MODEL_ID = os.getenv("MODEL_ID", "badanwang/teacher_basic_qwen3-0.6b") print(f"正在加载模型: {MODEL_ID}") # 加载分词器和模型 # trust_remote_code=True 是加载Qwen等模型所必需的 # device_map="auto" 会自动将模型分配到可用的硬件上(如GPU) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype="auto", device_map="auto", trust_remote_code=True ) print("模型加载成功!") # --- 2. 核心推理函数 (API) --- def get_response(prompt: str, history: list[list[str]] = None): """ 一个简单的函数,用于与模型进行单次对话。 Args: prompt (str): 用户当前输入的问题。 history (list[list[str]], optional): 对话历史,格式为 [[user_msg_1, bot_msg_1], ...]。默认为 None。 Returns: str: 模型生成的回复。 """ if history is None: history = [] # 1. 构建消息列表 messages = [] for user_message, bot_message in history: messages.append({"role": "user", "content": user_message}) messages.append({"role": "assistant", "content": bot_message}) messages.append({"role": "user", "content": prompt}) # 2. 应用聊天模板并进行分词 # 这是与聊天模型正确交互的关键步骤 input_ids = tokenizer.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_tensors="pt" ).to(model.device) # 3. 生成回复 # 这是一个阻塞式调用,会等待模型生成完毕 outputs = model.generate( input_ids, max_new_tokens=1024, do_sample=True, temperature=0.7, top_p=0.9 ) # 4. 解码生成的文本 # `outputs[0]` 包含了输入的token和新生成的token,我们需要切片只获取新生成的部分 response_ids = outputs[0][input_ids.shape[-1]:] response_text = tokenizer.decode(response_ids, skip_special_tokens=True) return response_text # --- 3. 使用示例 --- if __name__ == "__main__": # 示例1: 单轮对话 print("\n--- 示例 1: 单轮对话 ---") question1 = "你好,你是谁?" print(f"用户: {question1}") answer1 = get_response(question1) print(f"模型: {answer1}") # 示例2: 多轮对话 print("\n--- 示例 2: 多轮对话 ---") # 首先,定义一个对话历史 chat_history = [ ["用Python写一个快速排序", "当然,这是快速排序的Python实现:\n```python\ndef quick_sort(arr):\n if len(arr) <= 1:\n return arr\n pivot = arr[len(arr) // 2]\n left = [x for x in arr if x < pivot]\n middle = [x for x in arr if x == pivot]\n right = [x for x in arr if x > pivot]\n return quick_sort(left) + middle + quick_sort(right)\n\nprint(quick_sort())\n```"] ] question2 = "很好,你能解释一下它的工作原理吗?" print(f"历史: {chat_history}") print(f"用户: {question2}") # 调用时传入历史记录 answer2 = get_response(question2, history=chat_history) print(f"模型: {answer2}")