import torch from transformers import AutoModelForCausalLM, AutoTokenizer from langchain_core.messages import AIMessage from typing import TypedDict, Annotated, List import operator # 定义此组件操作的图状态的子集 class GraphState(TypedDict): messages: Annotated[List[AIMessage], operator.add] # --- 模型加载 --- # 使用 "auto" 模式加载模型和分词器,Hugging Face Accelerate 会自动处理设备和精度 MODEL_NAME = "inclusionAI/Ring-mini-2.0" tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, torch_dtype="auto", device_map="auto", trust_remote_code=True ) def completion_node(state: GraphState) -> dict: """ 一个调用语言模型以获取响应的节点。 Args: state (GraphState): 图的当前状态,包含消息历史。 Returns: dict: 一个包含新 AI 消息的字典,该消息将被添加到状态中。 """ messages = state["messages"] # --- 提示工程 --- # 从消息历史中组装提示。 prompt = "" for msg in messages: if msg.type == "system": prompt += f"{msg.content}\n" elif msg.type == "human": prompt += f"User: {msg.content}\n" elif msg.type == "ai": prompt += f"Assistant: {msg.content}\n" prompt += "Assistant:" # --- 模型调用 --- # 虽然模型设备是自动映射的,但输入张量仍需显式移动到模型所在的设备 input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device) output_ids = model.generate( input_ids, max_new_tokens=512, # 暂时硬编码 do_sample=True, pad_token_id=tokenizer.eos_token_id, ) output = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True) # 以 AIMessage 的形式返回响应,以添加到图的状态中。 return {"messages": [AIMessage(content=output)]}