#!/usr/bin/env python3
"""
Split EB-ALFRED mixed tasks into aligned success/failure case files.

A "mixed task" is defined by the tuple:
    (eval_set, episode_id, instruction)
and must contain at least one successful run and at least one failed run.

The script writes two JSON files with identical case ordering:
- success cases: only successful runs for each mixed task
- failure cases: only failed runs for each mixed task

Runs whose ``success`` flag is neither a recognized success value
(1, 1.0, True) nor a recognized failure value (0, 0.0, False) are excluded
from both files; how many input records were in that state is reported in the
metadata of both payloads.
"""

from __future__ import annotations

import argparse
import hashlib
import json
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any

DEFAULT_INPUT = Path("/data/Top_Spcae/ICML_2026/Dataset/eb-alfred_dataset_single_step.json")
DEFAULT_SUCCESS_OUTPUT = Path(
    "/data/Top_Spcae/ICML_2026/Dataset/eb-alfred_mixed_tasks_success_cases.json"
)
DEFAULT_FAILURE_OUTPUT = Path(
    "/data/Top_Spcae/ICML_2026/Dataset/eb-alfred_mixed_tasks_failure_cases.json"
)

# Flag values treated as "succeeded" / "failed", both for run-level ``success``
# and step-level ``action_success``. Membership is tested with ==, so 1, 1.0
# and True are interchangeable; anything else (None, strings, ...) is
# "unrecognized" and counted as neither.
_SUCCESS_VALUES = (1, 1.0, True)
_FAILURE_VALUES = (0, 0.0, False)


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments.

    Exits with a usage error when the success and failure outputs resolve to
    the same file, because the second write would silently overwrite the first.
    """
    parser = argparse.ArgumentParser(
        description="Split EB-ALFRED mixed tasks into aligned success/failure case files."
    )
    parser.add_argument(
        "--input",
        type=Path,
        default=DEFAULT_INPUT,
        help=f"Path to the source JSON file. Default: {DEFAULT_INPUT}",
    )
    parser.add_argument(
        "--success-output",
        type=Path,
        default=DEFAULT_SUCCESS_OUTPUT,
        help=f"Path to the success-case JSON. Default: {DEFAULT_SUCCESS_OUTPUT}",
    )
    parser.add_argument(
        "--failure-output",
        type=Path,
        default=DEFAULT_FAILURE_OUTPUT,
        help=f"Path to the failure-case JSON. Default: {DEFAULT_FAILURE_OUTPUT}",
    )
    args = parser.parse_args()
    if args.success_output.resolve() == args.failure_output.resolve():
        parser.error("--success-output and --failure-output must point to different files")
    return args


def _outcome(value: Any) -> bool | None:
    """Classify a success flag: True = success, False = failure, None = unrecognized."""
    if value in _SUCCESS_VALUES:
        return True
    if value in _FAILURE_VALUES:
        return False
    return None


def action_stats(trajectory_steps: list[dict[str, Any]]) -> dict[str, int]:
    """Count steps and per-action outcomes in a trajectory.

    Each step's ``executable_plan`` (missing or None is treated as empty) is
    scanned. Actions whose ``action_success`` is unrecognized are counted in
    neither total.

    Returns:
        ``{"num_steps", "successful_actions", "failed_actions"}`` counts.
    """
    success_actions = 0
    failed_actions = 0
    for step in trajectory_steps:
        for executable in step.get("executable_plan") or []:
            outcome = _outcome(executable.get("action_success"))
            if outcome is True:
                success_actions += 1
            elif outcome is False:
                failed_actions += 1
    return {
        "num_steps": len(trajectory_steps),
        "successful_actions": success_actions,
        "failed_actions": failed_actions,
    }


def build_run(record: dict[str, Any]) -> dict[str, Any]:
    """Project one input record onto the per-run output schema plus action stats."""
    trajectory_steps = record.get("trajectory", []) or []
    run = {
        "model_name": record.get("model_name"),
        "success": record.get("success"),
        "input": record.get("input"),
        "trajectory": trajectory_steps,
    }
    run.update(action_stats(trajectory_steps))
    return run


def sort_group_key(key: tuple[str, str, str]) -> tuple[str, tuple[int, int, str], str]:
    """Sort key for a (eval_set, episode_id, instruction) group.

    Numeric episode IDs sort numerically and come before non-numeric ones,
    which sort lexically. The type rank in the middle tuple guarantees an int
    is never compared with a str, which would raise TypeError when one
    eval_set mixes both kinds of ID.
    """
    eval_set, episode_id, instruction = key
    try:
        # Keep the raw string as a final tiebreak so "07" and "7" order deterministically.
        episode_sort = (0, int(episode_id), episode_id)
    except (TypeError, ValueError):
        episode_sort = (1, 0, str(episode_id))
    return eval_set, episode_sort, instruction


def _group_records(
    records: list[dict[str, Any]],
) -> dict[tuple[str, str, str], list[dict[str, Any]]]:
    """Group records by (eval_set, episode_id, instruction), each coerced to str.

    Raises:
        KeyError: if a record lacks one of the grouping fields; the message
            names the record index and the missing field.
    """
    grouped: dict[tuple[str, str, str], list[dict[str, Any]]] = defaultdict(list)
    for position, record in enumerate(records):
        try:
            key = (
                str(record["eval_set"]),
                str(record["episode_id"]),
                str(record["instruction"]),
            )
        except KeyError as exc:
            raise KeyError(
                f"record #{position} is missing required field {exc.args[0]!r}"
            ) from exc
        grouped[key].append(record)
    return grouped


def build_payloads(
    records: list[dict[str, Any]], source_file: Path
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Build the aligned success-only and failure-only payloads.

    Only groups with at least one recognized success and at least one
    recognized failure are kept. Within a case, runs are ordered by
    ``str(model_name)``. Runs with an unrecognized ``success`` flag go into
    neither payload.

    Args:
        records: Input run records (dicts with eval_set/episode_id/instruction,
            plus optional model_name/success/input/trajectory).
        source_file: Recorded verbatim (as str) in the metadata.

    Returns:
        ``(success_payload, failure_payload)``. Both share metadata and case
        order, so ``case_index`` / ``trajectory_id`` line up one-to-one.
    """
    grouped = _group_records(records)

    mixed_items: list[tuple[tuple[str, str, str], list[dict[str, Any]]]] = []
    for key, group_records in grouped.items():
        outcomes = {_outcome(record.get("success")) for record in group_records}
        if True in outcomes and False in outcomes:
            mixed_items.append((key, group_records))
    mixed_items.sort(key=lambda item: sort_group_key(item[0]))

    success_cases: list[dict[str, Any]] = []
    failure_cases: list[dict[str, Any]] = []
    run_count_distribution: Counter[tuple[int, int]] = Counter()

    for idx, (key, group_records) in enumerate(mixed_items, start=1):
        eval_set, episode_id, instruction = key
        # Short content hash keeps IDs readable while disambiguating instructions
        # that share an (eval_set, episode_id).
        instruction_hash = hashlib.md5(instruction.encode("utf-8")).hexdigest()[:12]
        trajectory_id = f"{eval_set}__episode_{episode_id}__{instruction_hash}"

        success_runs: list[dict[str, Any]] = []
        failure_runs: list[dict[str, Any]] = []
        for record in sorted(group_records, key=lambda item: str(item.get("model_name", ""))):
            outcome = _outcome(record.get("success"))
            if outcome is True:
                success_runs.append(build_run(record))
            elif outcome is False:
                failure_runs.append(build_run(record))
            # outcome is None: unlabeled run, deliberately kept out of both files.

        run_count_distribution[(len(success_runs), len(failure_runs))] += 1

        shared_fields = {
            "case_index": idx,
            "trajectory_id": trajectory_id,
            "eval_set": eval_set,
            "episode_id": episode_id,
            "instruction": instruction,
            "paired_success_run_count": len(success_runs),
            "paired_failure_run_count": len(failure_runs),
        }
        success_cases.append(
            {
                **shared_fields,
                "num_runs": len(success_runs),
                "models": [run["model_name"] for run in success_runs],
                "runs": success_runs,
            }
        )
        failure_cases.append(
            {
                **shared_fields,
                "num_runs": len(failure_runs),
                "models": [run["model_name"] for run in failure_runs],
                "runs": failure_runs,
            }
        )

    num_unrecognized = sum(1 for record in records if _outcome(record.get("success")) is None)

    common_meta = {
        "source_file": str(source_file),
        "group_rule": ["eval_set", "episode_id", "instruction"],
        "selection_rule": "mixed tasks only: at least one success run and at least one failure run",
        "num_input_records": len(records),
        "num_mixed_tasks": len(success_cases),
        "num_records_with_unrecognized_success": num_unrecognized,
        "case_order": "sorted by (eval_set, int(episode_id) when possible, instruction)",
        "run_count_distribution": {
            f"success_{succ}__failure_{fail}": count
            for (succ, fail), count in sorted(run_count_distribution.items())
        },
    }

    success_payload = {
        **common_meta,
        "case_type": "success_runs_only",
        "cases": success_cases,
    }
    failure_payload = {
        **common_meta,
        "case_type": "failure_runs_only",
        "cases": failure_cases,
    }
    return success_payload, failure_payload


def main() -> None:
    """Load the input records, split the mixed tasks, and write both payload files."""
    args = parse_args()

    with args.input.open("r", encoding="utf-8") as f:
        records = json.load(f)
    if not isinstance(records, list):
        raise SystemExit(
            f"Expected a JSON list of records in {args.input}, got {type(records).__name__}"
        )

    success_payload, failure_payload = build_payloads(records, args.input)

    for output_path, payload in (
        (args.success_output, success_payload),
        (args.failure_output, failure_payload),
    ):
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)

    print(f"Input records: {len(records)}")
    print(f"Mixed tasks: {success_payload['num_mixed_tasks']}")
    print(f"Saved success cases to: {args.success_output}")
    print(f"Saved failure cases to: {args.failure_output}")


if __name__ == "__main__":
    main()