| |
| """ |
| Split EB-ALFRED mixed tasks into aligned success/failure case files. |
| |
| A "mixed task" is defined by the tuple: |
| (eval_set, episode_id, instruction) |
| |
| and must contain at least one successful run and at least one failed run. |
| |
| The script writes two JSON files with identical case ordering: |
| - success cases: only successful runs for each mixed task |
| - failure cases: only failed runs for each mixed task |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import hashlib |
| import json |
| from collections import Counter, defaultdict |
| from pathlib import Path |
| from typing import Any |
|
|
|
|
# Default file locations used when the corresponding CLI flag is omitted.
# NOTE(review): the "Top_Spcae" spelling is kept verbatim because it may match
# the directory name on disk -- confirm before "correcting" it.

# Source dataset read by the script.
DEFAULT_INPUT = Path(
    "/data/Top_Spcae/ICML_2026/Dataset/eb-alfred_dataset_single_step.json"
)

# Output holding only the successful runs of each mixed task.
DEFAULT_SUCCESS_OUTPUT = Path(
    "/data/Top_Spcae/ICML_2026/Dataset/eb-alfred_mixed_tasks_success_cases.json"
)

# Output holding only the failed runs of each mixed task.
DEFAULT_FAILURE_OUTPUT = Path(
    "/data/Top_Spcae/ICML_2026/Dataset/eb-alfred_mixed_tasks_failure_cases.json"
)
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the splitter.

    Returns:
        A namespace with ``input``, ``success_output`` and ``failure_output``
        attributes. Each is converted to a ``Path`` and falls back to the
        matching module-level default when the flag is not given.
    """
    cli = argparse.ArgumentParser(
        description="Split EB-ALFRED mixed tasks into aligned success/failure case files."
    )

    # One row per option: (flag, default value, help text).
    option_table = (
        (
            "--input",
            DEFAULT_INPUT,
            f"Path to the source JSON file. Default: {DEFAULT_INPUT}",
        ),
        (
            "--success-output",
            DEFAULT_SUCCESS_OUTPUT,
            f"Path to the success-case JSON. Default: {DEFAULT_SUCCESS_OUTPUT}",
        ),
        (
            "--failure-output",
            DEFAULT_FAILURE_OUTPUT,
            f"Path to the failure-case JSON. Default: {DEFAULT_FAILURE_OUTPUT}",
        ),
    )
    for flag, fallback, help_text in option_table:
        cli.add_argument(flag, type=Path, default=fallback, help=help_text)

    return cli.parse_args()
|
|
|
|
def action_stats(trajectory_steps: list[dict[str, Any]]) -> dict[str, int]:
    """Count steps and per-action outcomes across a trajectory.

    Args:
        trajectory_steps: Step dicts; each may carry an ``executable_plan``
            list whose entries expose an ``action_success`` flag.

    Returns:
        A dict with ``num_steps`` (number of steps), ``successful_actions``
        (flags equal to 1 / 1.0 / True) and ``failed_actions`` (flags equal to
        0 / 0.0 / False). Actions with any other flag value, including a
        missing one, are counted in neither bucket.
    """
    # Flatten every action flag first; a missing or null plan contributes nothing.
    outcomes = [
        action.get("action_success")
        for step in trajectory_steps
        for action in (step.get("executable_plan") or [])
    ]

    succeeded = sum(1 for outcome in outcomes if outcome in (1, 1.0, True))
    failed = sum(1 for outcome in outcomes if outcome in (0, 0.0, False))

    return {
        "num_steps": len(trajectory_steps),
        "successful_actions": succeeded,
        "failed_actions": failed,
    }
|
|
|
|
def build_run(record: dict[str, Any]) -> dict[str, Any]:
    """Project a raw record onto the per-run structure written to the outputs.

    Args:
        record: One input record; ``model_name``, ``success``, ``input`` and
            ``trajectory`` are read with ``dict.get`` and may be absent.

    Returns:
        A dict with those four fields (a missing or null trajectory becomes
        an empty list) followed by the counters from ``action_stats``.
    """
    steps = record.get("trajectory") or []
    return {
        "model_name": record.get("model_name"),
        "success": record.get("success"),
        "input": record.get("input"),
        "trajectory": steps,
        # Step and action counters are appended after the copied fields.
        **action_stats(steps),
    }
|
|
|
|
def sort_group_key(
    key: tuple[str, str, str],
) -> tuple[str, tuple[int, int | str], str]:
    """Build a sort key for an ``(eval_set, episode_id, instruction)`` group.

    Episode ids that parse as integers are ordered numerically; all others
    are ordered as text and placed after the numeric ones within the same
    eval set.

    The middle element is a ``(rank, value)`` pair rather than a bare
    ``int``/``str``. Returning a bare value made keys with numeric and
    non-numeric ids incomparable, so sorting such a mix raised ``TypeError``.
    The rank guarantees an int is only ever compared with an int and a
    string with a string.

    Args:
        key: The group key ``(eval_set, episode_id, instruction)``.

    Returns:
        A 3-tuple ``(eval_set, (rank, episode_value), instruction)`` where
        ``rank`` is 0 for integer ids and 1 for everything else.
    """
    eval_set, episode_id, instruction = key
    try:
        episode_sort: tuple[int, int | str] = (0, int(episode_id))
    except (TypeError, ValueError):
        # str() keeps the fallback branch mutually comparable even if a
        # caller passes a non-string id.
        episode_sort = (1, str(episode_id))
    return eval_set, episode_sort, instruction
|
|
|
|
def _is_success_flag(flag: Any) -> bool:
    """Return True when *flag* is an explicit success marker (1, 1.0 or True)."""
    return flag in (1, 1.0, True)


def _is_failure_flag(flag: Any) -> bool:
    """Return True when *flag* is an explicit failure marker (0, 0.0 or False)."""
    return flag in (0, 0.0, False)


def build_payloads(
    records: list[dict[str, Any]], source_file: Path
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Group records into tasks and build the aligned success/failure payloads.

    Records are grouped by ``(eval_set, episode_id, instruction)``, each
    stringified. Only "mixed" groups are kept: those with at least one
    explicitly successful run and at least one explicitly failed run. Both
    payloads list the same cases in the same order; they differ only in which
    runs each case carries.

    Runs whose ``success`` flag is neither an explicit success nor an explicit
    failure (for example missing or ``None``) are left out of both files.
    They were previously lumped into the failure file, which contradicted the
    explicit failure test used to decide whether a task is mixed.

    Args:
        records: Input records; each must provide ``eval_set``,
            ``episode_id`` and ``instruction``.
        source_file: Path of the input file, recorded in the metadata.

    Returns:
        ``(success_payload, failure_payload)``: the shared metadata plus a
        ``case_type`` marker and the ``cases`` list for each file.

    Raises:
        KeyError: If a record lacks one of the three grouping fields.
    """
    grouped: dict[tuple[str, str, str], list[dict[str, Any]]] = defaultdict(list)
    for record in records:
        key = (
            str(record["eval_set"]),
            str(record["episode_id"]),
            str(record["instruction"]),
        )
        grouped[key].append(record)

    # Keep only the tasks that have both outcomes represented.
    mixed_items: list[tuple[tuple[str, str, str], list[dict[str, Any]]]] = [
        (key, group_records)
        for key, group_records in grouped.items()
        if any(_is_success_flag(record.get("success")) for record in group_records)
        and any(_is_failure_flag(record.get("success")) for record in group_records)
    ]
    mixed_items.sort(key=lambda item: sort_group_key(item[0]))

    success_cases: list[dict[str, Any]] = []
    failure_cases: list[dict[str, Any]] = []
    run_count_distribution: Counter[tuple[int, int]] = Counter()

    for idx, (key, group_records) in enumerate(mixed_items, start=1):
        eval_set, episode_id, instruction = key
        # A short digest keeps the id compact while separating instructions
        # that share an eval set and episode.
        instruction_hash = hashlib.md5(instruction.encode("utf-8")).hexdigest()[:12]
        trajectory_id = f"{eval_set}__episode_{episode_id}__{instruction_hash}"

        success_runs: list[dict[str, Any]] = []
        failure_runs: list[dict[str, Any]] = []
        for record in sorted(group_records, key=lambda item: str(item.get("model_name", ""))):
            outcome = record.get("success")
            if _is_success_flag(outcome):
                success_runs.append(build_run(record))
            elif _is_failure_flag(outcome):
                failure_runs.append(build_run(record))
            # Any other flag value is unclassifiable and belongs to neither file.

        run_count_distribution[(len(success_runs), len(failure_runs))] += 1

        shared_fields = {
            "case_index": idx,
            "trajectory_id": trajectory_id,
            "eval_set": eval_set,
            "episode_id": episode_id,
            "instruction": instruction,
            "paired_success_run_count": len(success_runs),
            "paired_failure_run_count": len(failure_runs),
        }

        success_cases.append(
            {
                **shared_fields,
                "num_runs": len(success_runs),
                "models": [run["model_name"] for run in success_runs],
                "runs": success_runs,
            }
        )
        failure_cases.append(
            {
                **shared_fields,
                "num_runs": len(failure_runs),
                "models": [run["model_name"] for run in failure_runs],
                "runs": failure_runs,
            }
        )

    common_meta = {
        "source_file": str(source_file),
        "group_rule": ["eval_set", "episode_id", "instruction"],
        "selection_rule": "mixed tasks only: at least one success run and at least one failure run",
        "num_input_records": len(records),
        "num_mixed_tasks": len(success_cases),
        "case_order": "sorted by (eval_set, int(episode_id) when possible, instruction)",
        "run_count_distribution": {
            f"success_{succ}__failure_{fail}": count
            for (succ, fail), count in sorted(run_count_distribution.items())
        },
    }

    success_payload = {
        **common_meta,
        "case_type": "success_runs_only",
        "cases": success_cases,
    }
    failure_payload = {
        **common_meta,
        "case_type": "failure_runs_only",
        "cases": failure_cases,
    }
    return success_payload, failure_payload
|
|
|
|
def main() -> None:
    """Run the splitter: load the records, build both payloads, write them.

    Reads the input JSON, creates the parent directories of both output files
    when needed, writes the success and failure payloads as indented UTF-8
    JSON, and prints a short summary.
    """
    args = parse_args()

    with args.input.open("r", encoding="utf-8") as source:
        records = json.load(source)

    success_payload, failure_payload = build_payloads(records, args.input)

    # Pair each destination with its payload so both files are handled alike.
    outputs = (
        (args.success_output, success_payload),
        (args.failure_output, failure_payload),
    )

    # Create both parent directories before writing anything.
    for destination, _ in outputs:
        destination.parent.mkdir(parents=True, exist_ok=True)

    for destination, payload in outputs:
        with destination.open("w", encoding="utf-8") as sink:
            json.dump(payload, sink, ensure_ascii=False, indent=2)

    print(f"Input records: {len(records)}")
    print(f"Mixed tasks: {success_payload['num_mixed_tasks']}")
    print(f"Saved success cases to: {args.success_output}")
    print(f"Saved failure cases to: {args.failure_output}")


if __name__ == "__main__":
    main()
|
|