"""
EB-Alfred K-step Trajectory Prediction Reward Function

Scoring rules:
1. format_reward: whether the output follows the
   <think>...</think><answer>...</answer> format.
2. accuracy_reward: LCS-based F1 between the action sequence inside <answer>
   and the ground truth.
   - LCS precision = lcs_len / len(pred)   how much of the prediction is valid
   - LCS recall    = lcs_len / len(gt)     how much of the GT is covered
   - F1 = 2 * precision * recall / (precision + recall)
   - Parse failure: 0.0

Input format:
    reward_input = {
        "response": "<think>...</think><answer>[{...}]</answer>",
        "response_length": int,
        "ground_truth": '[{"action_id": 64, "action_name": "find a Ladle"}, ...]'
    }
"""
|
|
import json
import logging
import re
from typing import Any
|
|
|
|
# Name of this reward function.
# NOTE(review): presumably looked up by the reward-loading framework — confirm.
REWARD_NAME = "eb_alfred_k_step"
# Category label for this reward.
# NOTE(review): the set of valid values is defined outside this file — confirm.
REWARD_TYPE = "sequential"
|
|
|
|
def format_reward(response: str) -> float:
    """Check that the output has the ``<think>...</think><answer>...</answer>`` shape.

    Ignoring surrounding whitespace, the whole response must consist of
    exactly one ``<think>`` block followed by exactly one ``<answer>`` block,
    with only whitespace allowed between the two.

    Args:
        response: Raw model output.

    Returns:
        1.0 if the format is respected, otherwise 0.0. Non-string input
        scores 0.0 instead of raising.
    """
    if not isinstance(response, str):
        return 0.0
    # Tempered body: any character that does not begin a think/answer tag.
    # A plain lazy ``.*?`` under ``fullmatch`` is forced to stretch across
    # extra tags, which would accept e.g. two <answer> blocks in a row.
    body = r"(?:(?!</?(?:think|answer)>).)*"
    pattern = re.compile(
        rf"<think>{body}</think>\s*<answer>{body}</answer>", re.DOTALL
    )
    return 1.0 if pattern.fullmatch(response.strip()) else 0.0
|
|
|
|
| def extract_actions(text: str) -> list[dict] | None: |
| """从文本中提取动作列表, 返回 None 表示解析失败""" |
| |
| try: |
| actions = json.loads(text.strip()) |
| if isinstance(actions, list): |
| return actions |
| except (json.JSONDecodeError, ValueError): |
| pass |
|
|
| |
| match = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL) |
| if match: |
| try: |
| actions = json.loads(match.group(1)) |
| if isinstance(actions, list): |
| return actions |
| except (json.JSONDecodeError, ValueError): |
| pass |
|
|
| |
| match = re.search(r"\[.*\]", text, re.DOTALL) |
| if match: |
| try: |
| actions = json.loads(match.group(0)) |
| if isinstance(actions, list): |
| return actions |
| except (json.JSONDecodeError, ValueError): |
| pass |
|
|
| return None |
|
|
|
|
| def _extract_action_ids(actions: list[dict]) -> list[int]: |
| """从动作列表中提取所有合法的 action_id, 过滤掉缺失或无效的""" |
| ids = [] |
| for a in actions: |
| aid = a.get("action_id") |
| if aid is not None: |
| ids.append(int(aid)) |
| return ids |
|
|
|
|
| def _lcs_length(seq1: list[int], seq2: list[int]) -> int: |
| """ |
| 计算两个 int 序列的 LCS 长度. |
| 使用滚动数组优化空间到 O(min(n, m)). |
| """ |
| if len(seq1) < len(seq2): |
| seq1, seq2 = seq2, seq1 |
|
|
| m = len(seq2) |
| prev = [0] * (m + 1) |
| curr = [0] * (m + 1) |
|
|
| for i in range(1, len(seq1) + 1): |
| for j in range(1, m + 1): |
| if seq1[i - 1] == seq2[j - 1]: |
| curr[j] = prev[j - 1] + 1 |
| else: |
| curr[j] = max(prev[j], curr[j - 1]) |
| prev, curr = curr, [0] * (m + 1) |
|
|
| return prev[m] |
|
|
|
|
def _lcs_f1(pred_ids: list[int], gt_ids: list[int]) -> float:
    """Score a predicted id sequence against the ground truth with an LCS-based F1.

    - precision = lcs_len / len(pred): penalizes redundant predictions
    - recall    = lcs_len / len(gt):   penalizes missed actions
    - F1 = 2 * P * R / (P + R)

    Returns 0.0 when either sequence is empty or they share no common
    subsequence.
    """
    if not pred_ids or not gt_ids:
        return 0.0

    matched = _lcs_length(pred_ids, gt_ids)
    if matched == 0:
        # Also avoids a 0 / 0 division in the F1 formula below.
        return 0.0

    precision, recall = matched / len(pred_ids), matched / len(gt_ids)
    return 2 * precision * recall / (precision + recall)
|
|
|
|
def compute_score(reward_input: dict[str, Any], format_weight: float = 0.2) -> dict[str, float]:
    """Compute the combined format + accuracy reward for one response.

    The predicted actions are parsed from the first ``<answer>...</answer>``
    block of the response (or from the whole response when there is no such
    block) and compared with the ground truth via an LCS-based F1 over their
    ``action_id`` sequences.

    Args:
        reward_input: Mapping with ``"response"`` (model output) and
            ``"ground_truth"`` (JSON string of the reference action list; an
            already-decoded list is accepted as well). A missing or
            non-string response is treated as an empty response.
        format_weight: Weight of the format score (default 0.2); the accuracy
            score gets ``1 - format_weight``.

    Returns:
        ``{"overall": float, "format": float, "accuracy": float}``. The
        accuracy is 0.0 when either side cannot be parsed.
    """
    logger = logging.getLogger(__name__)

    response = reward_input.get("response", "")
    if not isinstance(response, str):
        # e.g. an explicit None: score it like an empty response, don't crash.
        response = ""
    ground_truth = reward_input.get("ground_truth", "")

    fmt_score = format_reward(response)

    # Prefer the <answer> payload; fall back to the raw response so that a
    # correct answer with a broken format can still earn accuracy credit.
    answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
    answer_text = answer_match.group(1) if answer_match else response

    pred_actions = extract_actions(answer_text)
    if isinstance(ground_truth, list):
        gt_actions = ground_truth
    elif isinstance(ground_truth, str):
        gt_actions = extract_actions(ground_truth)
    else:
        gt_actions = None

    if gt_actions is None:
        # An unparseable reference is a data problem, not a model error.
        logger.warning("ground_truth is not a parseable action list: %r", ground_truth)

    if pred_actions is None or gt_actions is None:
        pred_ids, gt_ids = [], []
    else:
        pred_ids = _extract_action_ids(pred_actions)
        gt_ids = _extract_action_ids(gt_actions)
    logger.debug("pred_ids=%s gt_ids=%s", pred_ids, gt_ids)

    acc_score = _lcs_f1(pred_ids, gt_ids)
    logger.debug("acc_score=%s", acc_score)

    overall = format_weight * fmt_score + (1 - format_weight) * acc_score

    return {
        "overall": overall,
        "format": fmt_score,
        "accuracy": acc_score,
    }
|
|
|
|
if __name__ == "__main__":
    # Reference trajectory shared by every demo case below.
    gt = '[{"action_id": 64, "action_name": "find a Ladle"}, {"action_id": 109, "action_name": "pick up the Ladle"}, {"action_id": 18, "action_name": "find a DiningTable"}]'

    # (case name, model response) pairs.
    cases = [
        # Perfect match.
        (
            "完美匹配",
            '<think>I need to find the ladle first.</think><answer>[{"action_id": 64, "action_name": "find a Ladle"}, {"action_id": 109, "action_name": "pick up the Ladle"}, {"action_id": 18, "action_name": "find a DiningTable"}]</answer>',
        ),
        # Partial match: two of the three actions are right.
        (
            "部分匹配 (2/3)",
            '<think>thinking</think><answer>[{"action_id": 64, "action_name": "find a Ladle"}, {"action_id": 109, "action_name": "pick up the Ladle"}, {"action_id": 99, "action_name": "wrong action"}]</answer>',
        ),
        # Wrong format (no tags) but a correct answer.
        (
            "格式错误但答案正确",
            '[{"action_id": 64, "action_name": "find a Ladle"}, {"action_id": 109, "action_name": "pick up the Ladle"}, {"action_id": 18, "action_name": "find a DiningTable"}]',
        ),
        # Well-formed but the answer is not an action list at all.
        (
            "完全错误",
            "<think>I have no idea</think><answer>I don't know</answer>",
        ),
        # Prediction too long: redundant actions interleaved.
        (
            "预测过长 (冗余动作)",
            '<think>do everything</think><answer>[{"action_id": 64, "action_name": "find a Ladle"}, {"action_id": 50, "action_name": "look around"}, {"action_id": 109, "action_name": "pick up the Ladle"}, {"action_id": 77, "action_name": "turn left"}, {"action_id": 18, "action_name": "find a DiningTable"}, {"action_id": 33, "action_name": "open drawer"}]</answer>',
        ),
        # Prediction too short: actions missing.
        (
            "预测过短 (遗漏动作)",
            '<think>just one step</think><answer>[{"action_id": 64, "action_name": "find a Ladle"}]</answer>',
        ),
        # Right actions in the wrong order.
        (
            "顺序错位",
            '<think>wrong order</think><answer>[{"action_id": 18, "action_name": "find a DiningTable"}, {"action_id": 109, "action_name": "pick up the Ladle"}, {"action_id": 64, "action_name": "find a Ladle"}]</answer>',
        ),
    ]
    tests = [
        {"name": case_name, "response": case_response, "ground_truth": gt}
        for case_name, case_response in cases
    ]

    print("EB-Alfred K-step Reward Function Tests (LCS-F1)")
    print("=" * 70)
    for index, case in enumerate(tests, 1):
        result = compute_score(case)
        print(f"{index}. [{case['name']}]")
        print(f" format={result['format']:.1f} accuracy={result['accuracy']:.2f} overall={result['overall']:.2f}")
        print()
|
|