Wendy-Fly commited on
Commit
74ab802
·
verified ·
1 Parent(s): 6dc3323

Upload extract_eb_alfred_mixed_case_diff_steps.py with huggingface_hub

Browse files
extract_eb_alfred_mixed_case_diff_steps.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Extract divergent steps for aligned EB-ALFRED mixed-task cases.
4
+
5
+ For each aligned success/failure case pair, the script finds the most similar
6
+ success run and failure run by maximizing their shared action prefix length.
7
+ It then saves only the steps from the first divergence onward for both sides.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import argparse
13
+ import copy
14
+ import json
15
+ from pathlib import Path
16
+ from typing import Any
17
+
18
+
19
+ DEFAULT_SUCCESS_INPUT = Path(
20
+ "/data/Top_Spcae/ICML_2026/Dataset/eb-alfred_mixed_tasks_success_cases.json"
21
+ )
22
+ DEFAULT_FAILURE_INPUT = Path(
23
+ "/data/Top_Spcae/ICML_2026/Dataset/eb-alfred_mixed_tasks_failure_cases.json"
24
+ )
25
+ DEFAULT_OUTPUT = Path(
26
+ "/data/Top_Spcae/ICML_2026/Dataset/eb-alfred_mixed_tasks_diff_steps.json"
27
+ )
28
+
29
+
30
+ def parse_args() -> argparse.Namespace:
31
+ parser = argparse.ArgumentParser(
32
+ description="Extract divergent steps from aligned EB-ALFRED mixed-task cases."
33
+ )
34
+ parser.add_argument(
35
+ "--success-input",
36
+ type=Path,
37
+ default=DEFAULT_SUCCESS_INPUT,
38
+ help=f"Path to aligned success cases JSON. Default: {DEFAULT_SUCCESS_INPUT}",
39
+ )
40
+ parser.add_argument(
41
+ "--failure-input",
42
+ type=Path,
43
+ default=DEFAULT_FAILURE_INPUT,
44
+ help=f"Path to aligned failure cases JSON. Default: {DEFAULT_FAILURE_INPUT}",
45
+ )
46
+ parser.add_argument(
47
+ "--output",
48
+ type=Path,
49
+ default=DEFAULT_OUTPUT,
50
+ help=f"Path to output JSON. Default: {DEFAULT_OUTPUT}",
51
+ )
52
+ return parser.parse_args()
53
+
54
+
55
+ def action_sequence(run: dict[str, Any]) -> list[str]:
56
+ actions: list[str] = []
57
+ for step in run.get("trajectory", []) or []:
58
+ executable_plan = step.get("executable_plan", []) or []
59
+ if executable_plan:
60
+ actions.append(str(executable_plan[0]["action"][1]))
61
+ return actions
62
+
63
+
64
+ def shared_prefix_len(a: list[str], b: list[str]) -> int:
65
+ prefix = 0
66
+ for left, right in zip(a, b):
67
+ if left != right:
68
+ break
69
+ prefix += 1
70
+ return prefix
71
+
72
+
73
+ def choose_best_pair(
74
+ success_case: dict[str, Any], failure_case: dict[str, Any]
75
+ ) -> tuple[dict[str, Any], dict[str, Any], int]:
76
+ best_success_run: dict[str, Any] | None = None
77
+ best_failure_run: dict[str, Any] | None = None
78
+ best_prefix = -1
79
+ best_success_len = -1
80
+ best_failure_len = -1
81
+
82
+ for success_run in success_case.get("runs", []):
83
+ success_actions = action_sequence(success_run)
84
+ for failure_run in failure_case.get("runs", []):
85
+ failure_actions = action_sequence(failure_run)
86
+ prefix = shared_prefix_len(success_actions, failure_actions)
87
+
88
+ # Prefer longer shared prefixes. On ties, prefer shorter total tails.
89
+ success_len = len(success_actions)
90
+ failure_len = len(failure_actions)
91
+ is_better = (
92
+ prefix > best_prefix
93
+ or (
94
+ prefix == best_prefix
95
+ and success_len + failure_len < best_success_len + best_failure_len
96
+ )
97
+ )
98
+ if is_better:
99
+ best_success_run = success_run
100
+ best_failure_run = failure_run
101
+ best_prefix = prefix
102
+ best_success_len = success_len
103
+ best_failure_len = failure_len
104
+
105
+ if best_success_run is None or best_failure_run is None:
106
+ raise ValueError(
107
+ f"Could not find a valid run pair for case {success_case.get('case_index')}"
108
+ )
109
+
110
+ return best_success_run, best_failure_run, best_prefix
111
+
112
+
113
+ def annotate_diff_steps(
114
+ trajectory_steps: list[dict[str, Any]],
115
+ prefix: int,
116
+ trajectory_outcome: str,
117
+ selected_model: str | None,
118
+ ) -> list[dict[str, Any]]:
119
+ annotated_steps: list[dict[str, Any]] = []
120
+ trajectory_success = 1.0 if trajectory_outcome == "success" else 0.0
121
+ for original_step_index, step in enumerate(trajectory_steps[prefix:], start=prefix):
122
+ annotated_step = copy.deepcopy(step)
123
+ executable_plan = annotated_step.get("executable_plan", []) or []
124
+ for executable in executable_plan:
125
+ executable["trajectory_outcome"] = trajectory_outcome
126
+ executable["trajectory_success"] = trajectory_success
127
+ executable["selected_model"] = selected_model
128
+ executable["original_step_index_0based"] = original_step_index
129
+ executable["original_step_index_1based"] = original_step_index + 1
130
+ annotated_step["trajectory_outcome"] = trajectory_outcome
131
+ annotated_step["trajectory_success"] = trajectory_success
132
+ annotated_step["selected_model"] = selected_model
133
+ annotated_step["original_step_index_0based"] = original_step_index
134
+ annotated_step["original_step_index_1based"] = original_step_index + 1
135
+ annotated_steps.append(annotated_step)
136
+ return annotated_steps
137
+
138
+
139
+ def build_diff_case(
140
+ success_case: dict[str, Any], failure_case: dict[str, Any]
141
+ ) -> dict[str, Any]:
142
+ keys_to_match = ["case_index", "trajectory_id", "eval_set", "episode_id", "instruction"]
143
+ for key in keys_to_match:
144
+ if success_case.get(key) != failure_case.get(key):
145
+ raise ValueError(
146
+ f"Mismatched aligned cases for key '{key}': "
147
+ f"{success_case.get(key)!r} != {failure_case.get(key)!r}"
148
+ )
149
+
150
+ success_run, failure_run, prefix = choose_best_pair(success_case, failure_case)
151
+ success_actions = action_sequence(success_run)
152
+ failure_actions = action_sequence(failure_run)
153
+
154
+ success_diff_steps = annotate_diff_steps(
155
+ success_run.get("trajectory", []) or [],
156
+ prefix,
157
+ "success",
158
+ success_run.get("model_name"),
159
+ )
160
+ failure_diff_steps = annotate_diff_steps(
161
+ failure_run.get("trajectory", []) or [],
162
+ prefix,
163
+ "failure",
164
+ failure_run.get("model_name"),
165
+ )
166
+
167
+ success_first_diff_action = success_actions[prefix] if prefix < len(success_actions) else None
168
+ failure_first_diff_action = failure_actions[prefix] if prefix < len(failure_actions) else None
169
+
170
+ return {
171
+ "case_index": success_case["case_index"],
172
+ "trajectory_id": success_case["trajectory_id"],
173
+ "eval_set": success_case["eval_set"],
174
+ "episode_id": success_case["episode_id"],
175
+ "instruction": success_case["instruction"],
176
+ "paired_success_run_count": success_case["paired_success_run_count"],
177
+ "paired_failure_run_count": success_case["paired_failure_run_count"],
178
+ "selected_success_model": success_run.get("model_name"),
179
+ "selected_failure_model": failure_run.get("model_name"),
180
+ "selected_success_input": success_run.get("input"),
181
+ "selected_failure_input": failure_run.get("input"),
182
+ "selected_success_num_steps": success_run.get("num_steps"),
183
+ "selected_failure_num_steps": failure_run.get("num_steps"),
184
+ "shared_prefix_len": prefix,
185
+ "shared_prefix_actions": success_actions[:prefix],
186
+ "first_diff_step_index_0based": prefix,
187
+ "first_diff_step_index_1based": prefix + 1,
188
+ "first_diff_success_action": success_first_diff_action,
189
+ "first_diff_failure_action": failure_first_diff_action,
190
+ "success_diff_num_steps": len(success_diff_steps),
191
+ "failure_diff_num_steps": len(failure_diff_steps),
192
+ "success_diff_actions": success_actions[prefix:],
193
+ "failure_diff_actions": failure_actions[prefix:],
194
+ "success_diff_steps": success_diff_steps,
195
+ "failure_diff_steps": failure_diff_steps,
196
+ }
197
+
198
+
199
+ def main() -> None:
200
+ args = parse_args()
201
+
202
+ with args.success_input.open("r", encoding="utf-8") as f:
203
+ success_payload = json.load(f)
204
+ with args.failure_input.open("r", encoding="utf-8") as f:
205
+ failure_payload = json.load(f)
206
+
207
+ success_cases = success_payload["cases"]
208
+ failure_cases = failure_payload["cases"]
209
+ if len(success_cases) != len(failure_cases):
210
+ raise ValueError(
211
+ f"Case count mismatch: {len(success_cases)} success cases vs "
212
+ f"{len(failure_cases)} failure cases"
213
+ )
214
+
215
+ diff_cases = [
216
+ build_diff_case(success_case, failure_case)
217
+ for success_case, failure_case in zip(success_cases, failure_cases)
218
+ ]
219
+
220
+ output_payload = {
221
+ "success_input_file": str(args.success_input),
222
+ "failure_input_file": str(args.failure_input),
223
+ "selection_rule": (
224
+ "For each aligned case, choose the success/failure run pair with the "
225
+ "longest shared action prefix; extract only the divergent tail steps."
226
+ ),
227
+ "num_cases": len(diff_cases),
228
+ "cases": diff_cases,
229
+ }
230
+
231
+ args.output.parent.mkdir(parents=True, exist_ok=True)
232
+ with args.output.open("w", encoding="utf-8") as f:
233
+ json.dump(output_payload, f, ensure_ascii=False, indent=2)
234
+
235
+ print(f"Aligned cases processed: {len(diff_cases)}")
236
+ print(f"Saved divergent-step file to: {args.output}")
237
+
238
+
239
+ if __name__ == "__main__":
240
+ main()