| |
| """ |
| Extract divergent steps for aligned EB-ALFRED mixed-task cases. |
| |
| For each aligned success/failure case pair, the script finds the most similar |
| success run and failure run by maximizing their shared action prefix length. |
| It then saves only the steps from the first divergence onward for both sides. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import copy |
| import json |
| from pathlib import Path |
| from typing import Any |
|
|
|
|
# Default locations used by the CLI when no path overrides are given:
# the aligned success cases, the aligned failure cases, and the output file.
DEFAULT_SUCCESS_INPUT, DEFAULT_FAILURE_INPUT, DEFAULT_OUTPUT = (
    Path("/data/Top_Spcae/ICML_2026/Dataset/eb-alfred_mixed_tasks_success_cases.json"),
    Path("/data/Top_Spcae/ICML_2026/Dataset/eb-alfred_mixed_tasks_failure_cases.json"),
    Path("/data/Top_Spcae/ICML_2026/Dataset/eb-alfred_mixed_tasks_diff_steps.json"),
)
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for the divergent-step extractor.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, i.e. the original command-line behaviour;
            passing an explicit list lets callers and tests drive the parser.

    Returns:
        Namespace with ``success_input``, ``failure_input`` and ``output``,
        each a ``Path`` (falling back to the module-level ``DEFAULT_*`` paths).
    """
    parser = argparse.ArgumentParser(
        description="Extract divergent steps from aligned EB-ALFRED mixed-task cases."
    )
    parser.add_argument(
        "--success-input",
        type=Path,
        default=DEFAULT_SUCCESS_INPUT,
        help=f"Path to aligned success cases JSON. Default: {DEFAULT_SUCCESS_INPUT}",
    )
    parser.add_argument(
        "--failure-input",
        type=Path,
        default=DEFAULT_FAILURE_INPUT,
        help=f"Path to aligned failure cases JSON. Default: {DEFAULT_FAILURE_INPUT}",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=DEFAULT_OUTPUT,
        help=f"Path to output JSON. Default: {DEFAULT_OUTPUT}",
    )
    # argparse treats None as "use sys.argv[1:]", so the default is unchanged.
    return parser.parse_args(argv)
|
|
|
|
def action_sequence(run: dict[str, Any]) -> list[str]:
    """Return the run's per-step action labels, in trajectory order.

    For every step of ``run["trajectory"]`` whose ``executable_plan`` is
    non-empty, the label is ``str(plan[0]["action"][1])`` — only the first
    executable of each step is considered. Steps with a missing, ``None`` or
    empty ``executable_plan`` contribute nothing, so the result can be
    shorter than the trajectory. A missing/``None`` trajectory yields ``[]``.
    """
    steps = run.get("trajectory", []) or []
    plans = ((step.get("executable_plan", []) or []) for step in steps)
    return [str(plan[0]["action"][1]) for plan in plans if plan]
|
|
|
|
def shared_prefix_len(a: list[str], b: list[str]) -> int:
    """Count the leading positions at which ``a`` and ``b`` hold equal items.

    Counting stops at the first mismatch or when the shorter sequence is
    exhausted, e.g. ``["x", "y", "z"]`` vs ``["x", "y", "q"]`` gives 2.
    """
    matched = 0
    for matched, (left, right) in enumerate(zip(a, b), start=1):
        if left != right:
            # The current position mismatched, so only the ones before it count.
            return matched - 1
    return matched
|
|
|
|
def choose_best_pair(
    success_case: dict[str, Any], failure_case: dict[str, Any]
) -> tuple[dict[str, Any], dict[str, Any], int]:
    """Pick the success/failure run pair whose action sequences agree longest.

    Every run in ``success_case["runs"]`` is compared with every run in
    ``failure_case["runs"]`` using ``action_sequence`` + ``shared_prefix_len``.
    The pair with the longest shared action prefix wins; on a tie the pair
    with the smaller combined action count wins; on a full tie the first pair
    encountered (success runs outer, failure runs inner) is kept.

    Returns:
        ``(success_run, failure_run, shared_prefix_length)``.

    Raises:
        ValueError: if either case has no runs (missing, ``None`` or empty).
    """
    success_runs = success_case.get("runs", []) or []
    failure_runs = failure_case.get("runs", []) or []
    if not success_runs or not failure_runs:
        raise ValueError(
            f"Could not find a valid run pair for case {success_case.get('case_index')}"
        )

    # A failure run's action sequence does not depend on the success run it is
    # compared against, so build each one once instead of once per pair.
    failure_candidates = [(run, action_sequence(run)) for run in failure_runs]

    best_pair: tuple[dict[str, Any], dict[str, Any]] | None = None
    best_prefix = -1
    best_total_len = -1
    for success_run in success_runs:
        success_actions = action_sequence(success_run)
        for failure_run, failure_actions in failure_candidates:
            prefix = shared_prefix_len(success_actions, failure_actions)
            total_len = len(success_actions) + len(failure_actions)
            # Strict comparisons keep the earliest pair on a full tie.
            if prefix > best_prefix or (
                prefix == best_prefix and total_len < best_total_len
            ):
                best_pair = (success_run, failure_run)
                best_prefix = prefix
                best_total_len = total_len

    # Both run lists are non-empty, and any real prefix (>= 0) beats the
    # initial -1, so the first comparison always sets best_pair.
    assert best_pair is not None
    best_success_run, best_failure_run = best_pair
    return best_success_run, best_failure_run, best_prefix
|
|
|
|
def annotate_diff_steps(
    trajectory_steps: list[dict[str, Any]],
    prefix: int,
    trajectory_outcome: str,
    selected_model: str | None,
) -> list[dict[str, Any]]:
    """Return annotated deep copies of ``trajectory_steps[prefix:]``.

    Each copied step, and each entry of its ``executable_plan``, receives five
    fields:
    - ``trajectory_outcome``: the outcome label.
    - ``trajectory_success``: 1.0 if the label is ``"success"``, else 0.0.
    - ``selected_model``: the model name passed in.
    - ``original_step_index_0based`` / ``original_step_index_1based``: the
      step's position in the full trajectory.

    The input steps are deep-copied first, so they are never mutated.
    """
    success_flag = 1.0 if trajectory_outcome == "success" else 0.0

    def stamp(target: dict[str, Any], tags: dict[str, Any]) -> None:
        # Item assignment (not dict.update) so non-dict targets fail the same way.
        for key, value in tags.items():
            target[key] = value

    def annotate(step: dict[str, Any], position: int) -> dict[str, Any]:
        tags = {
            "trajectory_outcome": trajectory_outcome,
            "trajectory_success": success_flag,
            "selected_model": selected_model,
            "original_step_index_0based": position,
            "original_step_index_1based": position + 1,
        }
        step_copy = copy.deepcopy(step)
        for executable in step_copy.get("executable_plan", []) or []:
            stamp(executable, tags)
        stamp(step_copy, tags)
        return step_copy

    return [
        annotate(step, position)
        for position, step in enumerate(trajectory_steps[prefix:], start=prefix)
    ]
|
|
|
|
def _diff_start_index(run: dict[str, Any], prefix: int) -> int:
    """Map an action-sequence prefix length onto ``run``'s trajectory index.

    ``action_sequence`` only records steps with a non-empty
    ``executable_plan``, so ``prefix`` shared actions do not necessarily span
    the first ``prefix`` trajectory steps. This returns the index just past
    the ``prefix``-th action-bearing step (0 when ``prefix`` is 0). When every
    step has an executable plan the result equals ``prefix``.
    """
    trajectory = run.get("trajectory", []) or []
    if prefix <= 0:
        return 0
    actions_seen = 0
    for index, step in enumerate(trajectory):
        # Same truthiness test action_sequence uses to decide a step counts.
        if step.get("executable_plan", []) or []:
            actions_seen += 1
            if actions_seen == prefix:
                return index + 1
    # prefix exceeds this run's action count (not expected when prefix comes
    # from this run's own actions); fall back to an empty tail.
    return len(trajectory)


def build_diff_case(
    success_case: dict[str, Any], failure_case: dict[str, Any]
) -> dict[str, Any]:
    """Build the divergent-step record for one aligned success/failure case.

    Checks that both cases agree on their identifying keys, selects the run
    pair with the longest shared action prefix (``choose_best_pair``), and
    keeps only the trajectory steps after the last shared action on each side.

    ``shared_prefix_len`` and ``first_diff_step_index_*`` are positions in the
    action sequence. Because steps without an executable plan are skipped
    there, the trajectory index at which each side's tail starts is reported
    separately as ``success_/failure_diff_start_step_index_0based``.

    Raises:
        ValueError: if the cases disagree on an identifying key, or if either
            case has no runs (raised by ``choose_best_pair``).
        KeyError: if a key read with ``[]`` (e.g. ``paired_success_run_count``)
            is missing from ``success_case``.
    """
    keys_to_match = ["case_index", "trajectory_id", "eval_set", "episode_id", "instruction"]
    for key in keys_to_match:
        if success_case.get(key) != failure_case.get(key):
            raise ValueError(
                f"Mismatched aligned cases for key '{key}': "
                f"{success_case.get(key)!r} != {failure_case.get(key)!r}"
            )

    success_run, failure_run, prefix = choose_best_pair(success_case, failure_case)
    success_actions = action_sequence(success_run)
    failure_actions = action_sequence(failure_run)

    # Translate the action-space prefix into each run's own trajectory index;
    # slicing the trajectory by `prefix` directly misaligns when earlier steps
    # had no executable plan.
    success_start = _diff_start_index(success_run, prefix)
    failure_start = _diff_start_index(failure_run, prefix)

    success_diff_steps = annotate_diff_steps(
        success_run.get("trajectory", []) or [],
        success_start,
        "success",
        success_run.get("model_name"),
    )
    failure_diff_steps = annotate_diff_steps(
        failure_run.get("trajectory", []) or [],
        failure_start,
        "failure",
        failure_run.get("model_name"),
    )

    # None when a run ends exactly at the shared prefix.
    success_first_diff_action = success_actions[prefix] if prefix < len(success_actions) else None
    failure_first_diff_action = failure_actions[prefix] if prefix < len(failure_actions) else None

    return {
        "case_index": success_case["case_index"],
        "trajectory_id": success_case["trajectory_id"],
        "eval_set": success_case["eval_set"],
        "episode_id": success_case["episode_id"],
        "instruction": success_case["instruction"],
        "paired_success_run_count": success_case["paired_success_run_count"],
        "paired_failure_run_count": success_case["paired_failure_run_count"],
        "selected_success_model": success_run.get("model_name"),
        "selected_failure_model": failure_run.get("model_name"),
        "selected_success_input": success_run.get("input"),
        "selected_failure_input": failure_run.get("input"),
        "selected_success_num_steps": success_run.get("num_steps"),
        "selected_failure_num_steps": failure_run.get("num_steps"),
        "shared_prefix_len": prefix,
        "shared_prefix_actions": success_actions[:prefix],
        "first_diff_step_index_0based": prefix,
        "first_diff_step_index_1based": prefix + 1,
        "success_diff_start_step_index_0based": success_start,
        "failure_diff_start_step_index_0based": failure_start,
        "first_diff_success_action": success_first_diff_action,
        "first_diff_failure_action": failure_first_diff_action,
        "success_diff_num_steps": len(success_diff_steps),
        "failure_diff_num_steps": len(failure_diff_steps),
        "success_diff_actions": success_actions[prefix:],
        "failure_diff_actions": failure_actions[prefix:],
        "success_diff_steps": success_diff_steps,
        "failure_diff_steps": failure_diff_steps,
    }
|
|
|
|
def main() -> None:
    """CLI entry point: pair the aligned cases and write the divergent steps.

    Reads the success and failure JSON payloads (each with a ``"cases"``
    list), pairs cases by position, builds one diff record per pair with
    ``build_diff_case``, and writes the result as indented UTF-8 JSON,
    creating the output directory if needed.

    Raises:
        ValueError: if the two inputs contain a different number of cases
            (``build_diff_case`` may also raise for a mismatched pair).
    """
    args = parse_args()

    # Load success first, then failure, mirroring the argument order.
    loaded: list[Any] = []
    for source_path in (args.success_input, args.failure_input):
        with source_path.open("r", encoding="utf-8") as handle:
            loaded.append(json.load(handle))
    success_payload, failure_payload = loaded

    success_cases = success_payload["cases"]
    failure_cases = failure_payload["cases"]
    n_success, n_failure = len(success_cases), len(failure_cases)
    if n_success != n_failure:
        raise ValueError(
            f"Case count mismatch: {n_success} success cases vs "
            f"{n_failure} failure cases"
        )

    # Cases are aligned by position; the lengths match, so map() pairs them all.
    diff_cases = list(map(build_diff_case, success_cases, failure_cases))

    output_payload = dict(
        success_input_file=str(args.success_input),
        failure_input_file=str(args.failure_input),
        selection_rule=(
            "For each aligned case, choose the success/failure run pair with the "
            "longest shared action prefix; extract only the divergent tail steps."
        ),
        num_cases=len(diff_cases),
        cases=diff_cases,
    )

    destination = args.output
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as handle:
        json.dump(output_payload, handle, ensure_ascii=False, indent=2)

    print(f"Aligned cases processed: {len(diff_cases)}")
    print(f"Saved divergent-step file to: {args.output}")


if __name__ == "__main__":
    main()
|
|