""" EB-Alfred K-step Trajectory Prediction Reward Function 评分规则: 1. format_reward: 输出是否符合 ...... 格式 2. accuracy_reward: 中的动作序列与 ground_truth 基于 LCS 的 F1 匹配度 - LCS precision = lcs_len / len(pred) 预测中有多少是有效的 - LCS recall = lcs_len / len(gt) GT 被覆盖了多少 - F1 = 2 * precision * recall / (precision + recall) - 解析失败: 0.0 输入格式: reward_input = { "response": "...[{...}]", "response_length": int, "ground_truth": '[{"action_id": 64, "action_name": "find a Ladle"}, ...]' } """ import json import re from typing import Any REWARD_NAME = "eb_alfred_k_step" REWARD_TYPE = "sequential" def format_reward(response: str) -> float: """检查输出是否符合 ...... 格式""" pattern = re.compile(r".*?\s*.*?", re.DOTALL) return 1.0 if re.fullmatch(pattern, response.strip()) else 0.0 def extract_actions(text: str) -> list[dict] | None: """从文本中提取动作列表, 返回 None 表示解析失败""" # 尝试直接解析 try: actions = json.loads(text.strip()) if isinstance(actions, list): return actions except (json.JSONDecodeError, ValueError): pass # 尝试从 ```json ... ``` 中提取 match = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL) if match: try: actions = json.loads(match.group(1)) if isinstance(actions, list): return actions except (json.JSONDecodeError, ValueError): pass # 尝试提取看起来像 JSON 数组的部分 match = re.search(r"\[.*\]", text, re.DOTALL) if match: try: actions = json.loads(match.group(0)) if isinstance(actions, list): return actions except (json.JSONDecodeError, ValueError): pass return None def _extract_action_ids(actions: list[dict]) -> list[int]: """从动作列表中提取所有合法的 action_id, 过滤掉缺失或无效的""" ids = [] for a in actions: aid = a.get("action_id") if aid is not None: ids.append(int(aid)) return ids def _lcs_length(seq1: list[int], seq2: list[int]) -> int: """ 计算两个 int 序列的 LCS 长度. 使用滚动数组优化空间到 O(min(n, m)). """ if len(seq1) < len(seq2): seq1, seq2 = seq2, seq1 m = len(seq2) prev = [0] * (m + 1) curr = [0] * (m + 1) for i in range(1, len(seq1) + 1): for j in range(1, m + 1): if seq1[i - 1] == seq2[j - 1]: curr[j] = prev[j - 1] + 1 else: curr[j] = max(prev[j], curr[j - 1]) prev, curr = curr, [0] * (m + 1) return prev[m] def _lcs_f1(pred_ids: list[int], gt_ids: list[int]) -> float: """ 基于 LCS 计算 F1. - precision = lcs_len / len(pred) 惩罚冗余预测 - recall = lcs_len / len(gt) 惩罚遗漏 - F1 = 2 * P * R / (P + R) """ if len(pred_ids) == 0 or len(gt_ids) == 0: return 0.0 lcs_len = _lcs_length(pred_ids, gt_ids) if lcs_len == 0: return 0.0 precision = lcs_len / len(pred_ids) recall = lcs_len / len(gt_ids) return 2 * precision * recall / (precision + recall) def compute_score(reward_input: dict[str, Any], format_weight: float = 0.2) -> dict[str, float]: """ 计算总分. Args: reward_input: {"response": str, "ground_truth": str} format_weight: 格式分权重 (默认 0.2) Returns: {"overall": float, "format": float, "accuracy": float} """ response = reward_input.get("response", "") ground_truth = reward_input.get("ground_truth", "") # --- 格式分 --- fmt_score = format_reward(response) # --- 解析并提取 action_id 列表 --- answer_match = re.search(r"(.*?)", response, re.DOTALL) answer_text = answer_match.group(1) if answer_match else response pred_actions = extract_actions(answer_text) gt_actions = extract_actions(ground_truth) if pred_actions is None or gt_actions is None: pred_ids, gt_ids = [], [] else: pred_ids = _extract_action_ids(pred_actions) gt_ids = _extract_action_ids(gt_actions) print('-------------------------------------------') print('pred_ids: ', pred_ids) print('gt_ids: ', gt_ids) # --- 准确度分 --- acc_score = _lcs_f1(pred_ids, gt_ids) print('acc_score: ', acc_score) overall = format_weight * fmt_score + (1 - format_weight) * acc_score return { "overall": overall, "format": fmt_score, "accuracy": acc_score, } if __name__ == "__main__": gt = '[{"action_id": 64, "action_name": "find a Ladle"}, {"action_id": 109, "action_name": "pick up the Ladle"}, {"action_id": 18, "action_name": "find a DiningTable"}]' tests = [ # 1. 完美匹配 → F1=1.0 { "name": "完美匹配", "response": 'I need to find the ladle first.[{"action_id": 64, "action_name": "find a Ladle"}, {"action_id": 109, "action_name": "pick up the Ladle"}, {"action_id": 18, "action_name": "find a DiningTable"}]', "ground_truth": gt, }, # 2. 部分匹配 (LCS=2, pred=3, gt=3) → P=2/3, R=2/3, F1=0.67 { "name": "部分匹配 (2/3)", "response": 'thinking[{"action_id": 64, "action_name": "find a Ladle"}, {"action_id": 109, "action_name": "pick up the Ladle"}, {"action_id": 99, "action_name": "wrong action"}]', "ground_truth": gt, }, # 3. 格式错误但答案正确 → format=0, accuracy=1.0 { "name": "格式错误但答案正确", "response": '[{"action_id": 64, "action_name": "find a Ladle"}, {"action_id": 109, "action_name": "pick up the Ladle"}, {"action_id": 18, "action_name": "find a DiningTable"}]', "ground_truth": gt, }, # 4. 完全错误 { "name": "完全错误", "response": "I have no ideaI don't know", "ground_truth": gt, }, # 5. 预测过长 (LCS=3, pred=6, gt=3) → P=3/6=0.5, R=3/3=1.0, F1=0.67 { "name": "预测过长 (冗余动作)", "response": 'do everything[{"action_id": 64, "action_name": "find a Ladle"}, {"action_id": 50, "action_name": "look around"}, {"action_id": 109, "action_name": "pick up the Ladle"}, {"action_id": 77, "action_name": "turn left"}, {"action_id": 18, "action_name": "find a DiningTable"}, {"action_id": 33, "action_name": "open drawer"}]', "ground_truth": gt, }, # 6. 预测过短 (LCS=1, pred=1, gt=3) → P=1/1=1.0, R=1/3=0.33, F1=0.5 { "name": "预测过短 (遗漏动作)", "response": 'just one step[{"action_id": 64, "action_name": "find a Ladle"}]', "ground_truth": gt, }, # 7. 顺序错位但内容对 (LCS=2, pred=3, gt=3) → P=2/3, R=2/3, F1=0.67 { "name": "顺序错位", "response": 'wrong order[{"action_id": 18, "action_name": "find a DiningTable"}, {"action_id": 109, "action_name": "pick up the Ladle"}, {"action_id": 64, "action_name": "find a Ladle"}]', "ground_truth": gt, }, ] print("EB-Alfred K-step Reward Function Tests (LCS-F1)") print("=" * 70) for i, t in enumerate(tests, 1): score = compute_score(t) print(f"{i}. [{t['name']}]") print(f" format={score['format']:.1f} accuracy={score['accuracy']:.2f} overall={score['overall']:.2f}") print()