# ICML-2027/examples/reward_function/eb_alfred_k_step.py
# Codex Restore
# Restore previous repo contents and add deep-dive document
# Commit: 48c208c
"""
EB-Alfred K-step Trajectory Prediction Reward Function
评分规则:
1. format_reward: 输出是否符合 <think>...</think><answer>...</answer> 格式
2. accuracy_reward: <answer> 中的动作序列与 ground_truth 基于 LCS 的 F1 匹配度
- LCS precision = lcs_len / len(pred) 预测中有多少是有效的
- LCS recall = lcs_len / len(gt) GT 被覆盖了多少
- F1 = 2 * precision * recall / (precision + recall)
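- 解析失败: 0.0
3. overall = format_weight * format_reward + (1 - format_weight) * accuracy_reward (default format_weight = 0.2)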
输入格式:
reward_input = {
"response": "<think>...</think><answer>[{...}]</answer>",
"response_length": int,
"ground_truth": '[{"action_id": 64, "action_name": "find a Ladle"}, ...]'
}
"""
import json
import re
from typing import Any
# Identifier of this reward function; presumably looked up by name by the
# training framework's reward loader — confirm against the caller.
REWARD_NAME: str = "eb_alfred_k_step"
# Reward category tag; the meaning of "sequential" is defined by the consumer
# of this module, not here — NOTE(review): confirm the accepted values.
REWARD_TYPE: str = "sequential"
def format_reward(response: str) -> float:
    """Check that ``response`` is exactly one think block followed by one answer block.

    The whitespace-stripped response must have the shape
    ``<think>...</think><answer>...</answer>`` with only whitespace between the
    two blocks. Neither block's content may itself contain a ``<think>``,
    ``</think>``, ``<answer>`` or ``</answer>`` tag, so responses with repeated
    or nested blocks are rejected.

    Args:
        response: Raw model output. Non-string values are treated as malformed.

    Returns:
        1.0 if the format is satisfied, otherwise 0.0.
    """
    if not isinstance(response, str):
        return 0.0
    # A bare lazy ``.*?`` is not enough under fullmatch: it will expand across
    # "</answer>...<answer>" to make the whole string match, so duplicated
    # blocks would pass. Instead, forbid tag markers inside each block's body.
    body = r"(?:(?!</?(?:think|answer)>).)*"
    pattern = re.compile(rf"<think>{body}</think>\s*<answer>{body}</answer>", re.DOTALL)
    return 1.0 if pattern.fullmatch(response.strip()) else 0.0
def extract_actions(text: str) -> list[dict] | None:
"""从文本中提取动作列表, 返回 None 表示解析失败"""
# 尝试直接解析
try:
actions = json.loads(text.strip())
if isinstance(actions, list):
return actions
except (json.JSONDecodeError, ValueError):
pass
# 尝试从 ```json ... ``` 中提取
match = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL)
if match:
try:
actions = json.loads(match.group(1))
if isinstance(actions, list):
return actions
except (json.JSONDecodeError, ValueError):
pass
# 尝试提取看起来像 JSON 数组的部分
match = re.search(r"\[.*\]", text, re.DOTALL)
if match:
try:
actions = json.loads(match.group(0))
if isinstance(actions, list):
return actions
except (json.JSONDecodeError, ValueError):
pass
return None
def _extract_action_ids(actions: list[dict]) -> list[int]:
"""从动作列表中提取所有合法的 action_id, 过滤掉缺失或无效的"""
ids = []
for a in actions:
aid = a.get("action_id")
if aid is not None:
ids.append(int(aid))
return ids
def _lcs_length(seq1: list[int], seq2: list[int]) -> int:
"""
计算两个 int 序列的 LCS 长度.
使用滚动数组优化空间到 O(min(n, m)).
"""
if len(seq1) < len(seq2):
seq1, seq2 = seq2, seq1
m = len(seq2)
prev = [0] * (m + 1)
curr = [0] * (m + 1)
for i in range(1, len(seq1) + 1):
for j in range(1, m + 1):
if seq1[i - 1] == seq2[j - 1]:
curr[j] = prev[j - 1] + 1
else:
curr[j] = max(prev[j], curr[j - 1])
prev, curr = curr, [0] * (m + 1)
return prev[m]
def _lcs_f1(pred_ids: list[int], gt_ids: list[int]) -> float:
    """LCS-based F1 between a predicted and a reference action-id sequence.

    precision = LCS / len(pred)   penalises extra, unmatched predictions
    recall    = LCS / len(gt)     penalises ground-truth steps that were missed
    F1        = 2 * P * R / (P + R)

    Returns 0.0 when either sequence is empty or they share no common element.
    """
    if not pred_ids or not gt_ids:
        return 0.0
    common = _lcs_length(pred_ids, gt_ids)
    if not common:
        return 0.0
    p = common / len(pred_ids)
    r = common / len(gt_ids)
    return 2 * p * r / (p + r)
def compute_score(reward_input: dict[str, Any], format_weight: float = 0.2) -> dict[str, float]:
    """Compute the combined format + accuracy reward for one sample.

    Args:
        reward_input: Mapping with ``"response"`` (the model output) and
            ``"ground_truth"`` (a JSON array of ``{"action_id": ...}`` objects).
            Missing or ``None`` values are treated as empty strings.
        format_weight: Weight of the format score in ``overall``. The accuracy
            score gets ``1 - format_weight``. Defaults to 0.2.

    Returns:
        ``{"overall": float, "format": float, "accuracy": float}`` where
        ``overall = format_weight * format + (1 - format_weight) * accuracy``.
    """
    import logging  # local import: used only for debug output below

    logger = logging.getLogger(__name__)

    # `or ""` also covers keys that are present but explicitly None.
    response = reward_input.get("response") or ""
    ground_truth = reward_input.get("ground_truth") or ""

    # --- Format score ---
    fmt_score = format_reward(response)

    # --- Parse both sides into action_id lists ---
    # Score the first <answer> block. If there is none, fall back to the whole
    # response so a correct but badly formatted answer can still earn accuracy.
    answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
    answer_text = answer_match.group(1) if answer_match else response
    pred_actions = extract_actions(answer_text)
    gt_actions = extract_actions(ground_truth)
    if pred_actions is None or gt_actions is None:
        pred_ids, gt_ids = [], []
    else:
        pred_ids = _extract_action_ids(pred_actions)
        gt_ids = _extract_action_ids(gt_actions)

    # --- Accuracy score ---
    acc_score = _lcs_f1(pred_ids, gt_ids)
    # Debug-level logging instead of unconditional print(): this runs once per
    # sample during training and would otherwise flood stdout.
    logger.debug("pred_ids=%s gt_ids=%s acc_score=%s", pred_ids, gt_ids, acc_score)

    overall = format_weight * fmt_score + (1 - format_weight) * acc_score
    return {
        "overall": overall,
        "format": fmt_score,
        "accuracy": acc_score,
    }
if __name__ == "__main__":
    gt = '[{"action_id": 64, "action_name": "find a Ladle"}, {"action_id": 109, "action_name": "pick up the Ladle"}, {"action_id": 18, "action_name": "find a DiningTable"}]'
    # Each case records the expected LCS-F1 accuracy, so this demo doubles as a
    # quick self-check of the scoring rules.
    tests = [
        # 1. Exact match -> F1 = 1.0
        {
            "name": "完美匹配",
            "response": '<think>I need to find the ladle first.</think><answer>[{"action_id": 64, "action_name": "find a Ladle"}, {"action_id": 109, "action_name": "pick up the Ladle"}, {"action_id": 18, "action_name": "find a DiningTable"}]</answer>',
            "ground_truth": gt,
            "expected_accuracy": 1.0,
        },
        # 2. Partial match (LCS=2, pred=3, gt=3) -> P=2/3, R=2/3, F1=0.67
        {
            "name": "部分匹配 (2/3)",
            "response": '<think>thinking</think><answer>[{"action_id": 64, "action_name": "find a Ladle"}, {"action_id": 109, "action_name": "pick up the Ladle"}, {"action_id": 99, "action_name": "wrong action"}]</answer>',
            "ground_truth": gt,
            "expected_accuracy": 2 / 3,
        },
        # 3. Wrong format but correct answer -> format=0, accuracy=1.0
        {
            "name": "格式错误但答案正确",
            "response": '[{"action_id": 64, "action_name": "find a Ladle"}, {"action_id": 109, "action_name": "pick up the Ladle"}, {"action_id": 18, "action_name": "find a DiningTable"}]',
            "ground_truth": gt,
            "expected_accuracy": 1.0,
        },
        # 4. Completely wrong (answer is not a JSON array) -> accuracy=0.0
        {
            "name": "完全错误",
            "response": "<think>I have no idea</think><answer>I don't know</answer>",
            "ground_truth": gt,
            "expected_accuracy": 0.0,
        },
        # 5. Prediction too long (LCS=3, pred=6, gt=3) -> P=3/6=0.5, R=3/3=1.0, F1=0.67
        {
            "name": "预测过长 (冗余动作)",
            "response": '<think>do everything</think><answer>[{"action_id": 64, "action_name": "find a Ladle"}, {"action_id": 50, "action_name": "look around"}, {"action_id": 109, "action_name": "pick up the Ladle"}, {"action_id": 77, "action_name": "turn left"}, {"action_id": 18, "action_name": "find a DiningTable"}, {"action_id": 33, "action_name": "open drawer"}]</answer>',
            "ground_truth": gt,
            "expected_accuracy": 2 / 3,
        },
        # 6. Prediction too short (LCS=1, pred=1, gt=3) -> P=1/1=1.0, R=1/3=0.33, F1=0.5
        {
            "name": "预测过短 (遗漏动作)",
            "response": '<think>just one step</think><answer>[{"action_id": 64, "action_name": "find a Ladle"}]</answer>',
            "ground_truth": gt,
            "expected_accuracy": 0.5,
        },
        # 7. Same actions, reversed order (LCS=1, pred=3, gt=3) -> P=1/3, R=1/3, F1=0.33.
        #    Reversing [64, 109, 18] leaves only one action in the right relative order.
        {
            "name": "顺序错位",
            "response": '<think>wrong order</think><answer>[{"action_id": 18, "action_name": "find a DiningTable"}, {"action_id": 109, "action_name": "pick up the Ladle"}, {"action_id": 64, "action_name": "find a Ladle"}]</answer>',
            "ground_truth": gt,
            "expected_accuracy": 1 / 3,
        },
    ]
    print("EB-Alfred K-step Reward Function Tests (LCS-F1)")
    print("=" * 70)
    mismatches = []
    for i, t in enumerate(tests, 1):
        score = compute_score(t)
        expected = t["expected_accuracy"]
        ok = abs(score["accuracy"] - expected) < 1e-9
        if not ok:
            mismatches.append(t["name"])
        print(f"{i}. [{t['name']}]")
        print(f" format={score['format']:.1f} accuracy={score['accuracy']:.2f} overall={score['overall']:.2f}")
        print(f" expected accuracy={expected:.2f} -> {'OK' if ok else 'MISMATCH'}")
        print()
    if mismatches:
        raise SystemExit(f"{len(mismatches)} case(s) mismatched: {mismatches}")