cafe3310 committed on
Commit
491422c
·
1 Parent(s): ce5f3e0

fix: 显式传递 attention_mask 以修复生成警告

Browse files
Files changed (1) hide show
  1. comp.py +6 -3
comp.py CHANGED
@@ -46,10 +46,13 @@ def completion_node(state: GraphState) -> dict:
46
  prompt += "Assistant:"
47
 
48
  # --- 模型调用 ---
49
- # 虽然模型设备是自动映射的,但输入张量仍需显式移动到模型所在的设备
50
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
 
 
51
  output_ids = model.generate(
52
- input_ids,
 
53
  max_new_tokens=512, # 暂时硬编码
54
  do_sample=True,
55
  pad_token_id=tokenizer.eos_token_id,
 
46
  prompt += "Assistant:"
47
 
48
  # --- 模型调用 ---
49
+ # 调用 tokenizer 时获取 input_ids 和 attention_mask
50
+ inputs = tokenizer(prompt, return_tensors="pt")
51
+
52
+ # 将 attention_mask 和 input_ids 一起传递给 model.generate
53
  output_ids = model.generate(
54
+ inputs.input_ids.to(model.device),
55
+ attention_mask=inputs.attention_mask.to(model.device),
56
  max_new_tokens=512, # 暂时硬编码
57
  do_sample=True,
58
  pad_token_id=tokenizer.eos_token_id,