# Source: ICML-2027 / split_eb_alfred_mixed_tasks.py
# Uploaded by Wendy-Fly with huggingface_hub (commit d2b5b43, verified).
#!/usr/bin/env python3
"""
Split EB-ALFRED mixed tasks into aligned success/failure case files.
A "mixed task" is defined by the tuple:
(eval_set, episode_id, instruction)
and must contain at least one successful run and at least one failed run.
The script writes two JSON files with identical case ordering:
- success cases: only successful runs for each mixed task
- failure cases: only failed runs for each mixed task
"""
from __future__ import annotations
import argparse
import hashlib
import json
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any
# Default file locations on the authors' machine; each can be overridden on
# the command line (see parse_args).
# NOTE(review): "Top_Spcae" looks like a misspelling of "Top_Space" but is
# kept verbatim because it has to match the real on-disk directory — confirm.
_DATASET_DIR = Path("/data/Top_Spcae/ICML_2026/Dataset")

DEFAULT_INPUT = _DATASET_DIR / "eb-alfred_dataset_single_step.json"
DEFAULT_SUCCESS_OUTPUT = _DATASET_DIR / "eb-alfred_mixed_tasks_success_cases.json"
DEFAULT_FAILURE_OUTPUT = _DATASET_DIR / "eb-alfred_mixed_tasks_failure_cases.json"
def parse_args() -> argparse.Namespace:
    """Parse the command line.

    Returns:
        A namespace with ``input``, ``success_output`` and ``failure_output``
        attributes. Each is a :class:`pathlib.Path` and falls back to the
        matching ``DEFAULT_*`` constant when the flag is omitted.
    """
    cli = argparse.ArgumentParser(
        description="Split EB-ALFRED mixed tasks into aligned success/failure case files."
    )
    # (flag, default path, help text) for every path-valued option.
    option_specs = (
        (
            "--input",
            DEFAULT_INPUT,
            f"Path to the source JSON file. Default: {DEFAULT_INPUT}",
        ),
        (
            "--success-output",
            DEFAULT_SUCCESS_OUTPUT,
            f"Path to the success-case JSON. Default: {DEFAULT_SUCCESS_OUTPUT}",
        ),
        (
            "--failure-output",
            DEFAULT_FAILURE_OUTPUT,
            f"Path to the failure-case JSON. Default: {DEFAULT_FAILURE_OUTPUT}",
        ),
    )
    for flag, default_path, help_text in option_specs:
        cli.add_argument(flag, type=Path, default=default_path, help=help_text)
    return cli.parse_args()
def action_stats(trajectory_steps: list[dict[str, Any]]) -> dict[str, int]:
    """Tally per-action outcomes over a trajectory.

    Every entry of each step's ``executable_plan`` is inspected. A missing or
    falsy plan is treated as empty. An ``action_success`` equal to 1/True
    counts as a successful action and one equal to 0/False as a failed action.
    Any other value (e.g. ``None`` or absent) is ignored.

    Returns:
        A dict with ``num_steps`` (number of trajectory steps),
        ``successful_actions`` and ``failed_actions``.
    """
    # Flatten all action_success flags across every step's plan.
    flags = [
        executable.get("action_success")
        for step in trajectory_steps
        for executable in (step.get("executable_plan", []) or [])
    ]
    return {
        "num_steps": len(trajectory_steps),
        "successful_actions": sum(1 for flag in flags if flag in (1, 1.0, True)),
        "failed_actions": sum(1 for flag in flags if flag in (0, 0.0, False)),
    }
def build_run(record: dict[str, Any]) -> dict[str, Any]:
    """Project one raw record into a run entry with action statistics.

    ``model_name``, ``success`` and ``input`` are copied as-is and become
    ``None`` when absent. A missing or falsy ``trajectory`` becomes ``[]``.
    The counters returned by :func:`action_stats` for that trajectory are
    appended after these fields.
    """
    steps = record.get("trajectory", []) or []
    return {
        "model_name": record.get("model_name"),
        "success": record.get("success"),
        "input": record.get("input"),
        "trajectory": steps,
        **action_stats(steps),
    }
def sort_group_key(key: tuple[str, str, str]) -> tuple[str, tuple[int, int, str], str]:
    """Return an ordering key for an ``(eval_set, episode_id, instruction)`` group.

    Ordering rules:

    - Groups are ordered by ``eval_set``, then by episode, then by ``instruction``.
    - Episode ids that ``int()`` accepts sort numerically, so ``"2"`` comes
      before ``"10"``.
    - Numeric ids come before non-numeric ids, which sort as text.

    The episode component is a ``(rank, number, text)`` tuple, so an ``int``
    is never compared with a ``str``. A plain ``int | str`` key raises
    ``TypeError`` when both kinds of id occur under the same ``eval_set``.
    The raw text is the final tiebreak, so ids such as ``"7"`` and ``"007"``
    still get a deterministic order.
    """
    eval_set, episode_id, instruction = key
    episode_text = str(episode_id)
    try:
        episode_sort = (0, int(episode_id), episode_text)
    except (TypeError, ValueError):
        # Non-numeric ids: ranked after all numeric ids, ordered by text.
        episode_sort = (1, 0, episode_text)
    return eval_set, episode_sort, instruction
def _is_success_flag(value: Any) -> bool:
    """Return True when a ``success`` value denotes a successful run (1 / 1.0 / True)."""
    return value in (1, 1.0, True)


def _is_failure_flag(value: Any) -> bool:
    """Return True when a ``success`` value denotes a failed run (0 / 0.0 / False)."""
    return value in (0, 0.0, False)


def _group_records(
    records: list[dict[str, Any]],
) -> dict[tuple[str, str, str], list[dict[str, Any]]]:
    """Bucket records by stringified ``(eval_set, episode_id, instruction)``.

    Raises:
        KeyError: if a record lacks one of the three grouping fields; the
            message names the record's position in ``records``.
    """
    grouped: dict[tuple[str, str, str], list[dict[str, Any]]] = defaultdict(list)
    for position, record in enumerate(records):
        try:
            key = (
                str(record["eval_set"]),
                str(record["episode_id"]),
                str(record["instruction"]),
            )
        except KeyError as err:
            raise KeyError(f"record #{position} is missing required field {err}") from err
        grouped[key].append(record)
    return grouped


def _make_case(shared_fields: dict[str, Any], runs: list[dict[str, Any]]) -> dict[str, Any]:
    """Combine task-level fields with one side's runs into a single case entry."""
    return {
        **shared_fields,
        "num_runs": len(runs),
        "models": [run["model_name"] for run in runs],
        "runs": runs,
    }


def build_payloads(
    records: list[dict[str, Any]], source_file: Path
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Build the aligned success-only and failure-only payloads.

    Steps:

    1. Group records by ``(eval_set, episode_id, instruction)``.
    2. Keep a group only if it has at least one run whose ``success`` equals
       1/True and at least one whose ``success`` equals 0/False.
    3. Order the kept groups with :func:`sort_group_key` and number them from 1.

    Case *i* in both payloads describes the same task. Runs inside a case
    are ordered by model name. Runs whose ``success`` is neither (e.g.
    ``None`` or missing) go into neither payload. Their number is reported as
    ``num_excluded_unknown_status_runs``.

    Args:
        records: Raw input records. Each must have ``eval_set``,
            ``episode_id`` and ``instruction`` keys.
        source_file: Input path, recorded verbatim in the metadata.

    Returns:
        ``(success_payload, failure_payload)``. Both carry the same metadata,
        plus a ``case_type`` tag and their own ``cases`` list.

    Raises:
        KeyError: if a record lacks one of the grouping fields.
    """
    grouped = _group_records(records)

    # A task is "mixed" when at least one run succeeded and at least one failed.
    mixed_items = sorted(
        (
            (key, group_records)
            for key, group_records in grouped.items()
            if any(_is_success_flag(r.get("success")) for r in group_records)
            and any(_is_failure_flag(r.get("success")) for r in group_records)
        ),
        key=lambda item: sort_group_key(item[0]),
    )

    success_cases: list[dict[str, Any]] = []
    failure_cases: list[dict[str, Any]] = []
    run_count_distribution = Counter()
    num_excluded_runs = 0

    for idx, (key, group_records) in enumerate(mixed_items, start=1):
        eval_set, episode_id, instruction = key
        # Short, stable fingerprint of the instruction. It keeps groups that
        # share (eval_set, episode_id) but differ in instruction apart
        # (barring a collision on the 12-char prefix). md5 is used as a
        # checksum here, not for security.
        instruction_hash = hashlib.md5(instruction.encode("utf-8")).hexdigest()[:12]
        trajectory_id = f"{eval_set}__episode_{episode_id}__{instruction_hash}"

        success_runs: list[dict[str, Any]] = []
        failure_runs: list[dict[str, Any]] = []
        # Stable sort by model name so the output order is deterministic.
        for record in sorted(group_records, key=lambda item: str(item.get("model_name", ""))):
            status = record.get("success")
            if _is_success_flag(status):
                success_runs.append(build_run(record))
            elif _is_failure_flag(status):
                failure_runs.append(build_run(record))
            else:
                # Neither success nor failure: keep it out of both files
                # rather than mislabel it as a failed run.
                num_excluded_runs += 1

        run_count_distribution[(len(success_runs), len(failure_runs))] += 1
        shared_fields = {
            "case_index": idx,
            "trajectory_id": trajectory_id,
            "eval_set": eval_set,
            "episode_id": episode_id,
            "instruction": instruction,
            "paired_success_run_count": len(success_runs),
            "paired_failure_run_count": len(failure_runs),
        }
        success_cases.append(_make_case(shared_fields, success_runs))
        failure_cases.append(_make_case(shared_fields, failure_runs))

    common_meta = {
        "source_file": str(source_file),
        "group_rule": ["eval_set", "episode_id", "instruction"],
        "selection_rule": "mixed tasks only: at least one success run and at least one failure run",
        "num_input_records": len(records),
        "num_mixed_tasks": len(success_cases),
        "case_order": "sorted by (eval_set, int(episode_id) when possible, instruction)",
        "run_count_distribution": {
            f"success_{succ}__failure_{fail}": count
            for (succ, fail), count in sorted(run_count_distribution.items())
        },
        # Runs inside mixed tasks whose `success` was neither 1/True nor 0/False.
        "num_excluded_unknown_status_runs": num_excluded_runs,
    }
    success_payload = {
        **common_meta,
        "case_type": "success_runs_only",
        "cases": success_cases,
    }
    failure_payload = {
        **common_meta,
        "case_type": "failure_runs_only",
        "cases": failure_cases,
    }
    return success_payload, failure_payload
def main() -> None:
    """CLI entry point: load the records, split them, write both files, report.

    Raises:
        SystemExit: in any of these cases:

            - the two output paths resolve to the same file (the second
              write would silently replace the first);
            - an output path resolves to the input file (the source data
              would be overwritten);
            - the input JSON is not a list of records.
    """
    args = parse_args()

    input_path = args.input.resolve()
    success_path = args.success_output.resolve()
    failure_path = args.failure_output.resolve()
    if success_path == failure_path:
        raise SystemExit(
            "--success-output and --failure-output must be different files "
            f"(both resolve to {success_path})"
        )
    if input_path in (success_path, failure_path):
        raise SystemExit(f"Refusing to overwrite the input file {input_path} with an output payload")

    with args.input.open("r", encoding="utf-8") as f:
        records = json.load(f)
    if not isinstance(records, list):
        raise SystemExit(
            f"Expected {args.input} to contain a JSON list of records, "
            f"got {type(records).__name__}"
        )

    success_payload, failure_payload = build_payloads(records, args.input)

    args.success_output.parent.mkdir(parents=True, exist_ok=True)
    args.failure_output.parent.mkdir(parents=True, exist_ok=True)
    with args.success_output.open("w", encoding="utf-8") as f:
        json.dump(success_payload, f, ensure_ascii=False, indent=2)
    with args.failure_output.open("w", encoding="utf-8") as f:
        json.dump(failure_payload, f, ensure_ascii=False, indent=2)

    print(f"Input records: {len(records)}")
    print(f"Mixed tasks: {success_payload['num_mixed_tasks']}")
    print(f"Saved success cases to: {args.success_output}")
    print(f"Saved failure cases to: {args.failure_output}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()