Spaces:
Sleeping
Sleeping
File size: 2,022 Bytes
551e9e2 a074dc6 551e9e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain_core.messages import AIMessage
from typing import TypedDict, Annotated, List
import operator
# 定义此组件操作的图状态的子集
class GraphState(TypedDict):
messages: Annotated[List[AIMessage], operator.add]
# --- 模型加载 ---
# 使用 "auto" 模式加载模型和分词器,Hugging Face Accelerate 会自动处理设备和精度
MODEL_NAME = "inclusionAI/Ring-mini-2.0"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype="auto",
device_map="auto",
trust_remote_code=True
)
def completion_node(state: GraphState) -> dict:
"""
一个调用语言模型以获取响应的节点。
Args:
state (GraphState): 图的当前状态,包含消息历史。
Returns:
dict: 一个包含新 AI 消息的字典,该消息将被添加到状态中。
"""
messages = state["messages"]
# --- 提示工程 ---
# 从消息历史中组装提示。
prompt = ""
for msg in messages:
if msg.type == "system":
prompt += f"{msg.content}\n"
elif msg.type == "human":
prompt += f"User: {msg.content}\n"
elif msg.type == "ai":
prompt += f"Assistant: {msg.content}\n"
prompt += "Assistant:"
# --- 模型调用 ---
# 虽然模型设备是自动映射的,但输入张量仍需显式移动到模型所在的设备
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
output_ids = model.generate(
input_ids,
max_new_tokens=512, # 暂时硬编码
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
output = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
# 以 AIMessage 的形式返回响应,以添加到图的状态中。
return {"messages": [AIMessage(content=output)]} |