luanns commited on
Commit
3873289
·
verified ·
1 Parent(s): 1345eac

Upload src/data_construction/build_kstep_data.py

Browse files
src/data_construction/build_kstep_data.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ K-step GUI Transition Data Construction
4
+ Builds self-supervised training data from GUI trajectories.
5
+
6
+ From: GUI-Shift paper (arXiv:2505.12493)
7
+ - Given state pairs (S_t, S_{t+k}), predict the first action a_t
8
+ - No textual annotations needed — future state is the visual goal
9
+ """
10
+
11
+ import argparse
12
+ import json
13
+ import os
14
+ import random
15
+ from pathlib import Path
16
+ from typing import List, Dict, Any, Tuple
17
+ from collections import defaultdict
18
+
19
+
20
+ def parse_trajectory(trajectory: Dict[str, Any]) -> List[Dict[str, Any]]:
21
+ """Parse a GUI trajectory into a list of state-action pairs."""
22
+ steps = trajectory.get("steps", [])
23
+ parsed = []
24
+ for step in steps:
25
+ parsed.append({
26
+ "screenshot": step.get("screenshot", step.get("img_path")),
27
+ "action": step.get("action", {}),
28
+ "instruction": step.get("instruction", ""),
29
+ })
30
+ return parsed
31
+
32
+
33
+ def build_k_step_pairs(
34
+ trajectory: List[Dict[str, Any]],
35
+ k: int,
36
+ episode_id: str,
37
+ ) -> List[Dict[str, Any]]:
38
+ """Build (S_t, S_{t+k}) pairs with action a_t as target."""
39
+ samples = []
40
+ max_t = len(trajectory) - k
41
+
42
+ for t in range(max_t):
43
+ state_t = trajectory[t]
44
+ state_tk = trajectory[t + k]
45
+ action_t = state_t["action"]
46
+
47
+ # Skip if missing screenshots
48
+ if not state_t["screenshot"] or not state_tk["screenshot"]:
49
+ continue
50
+
51
+ # Skip if action is missing/empty
52
+ if not action_t or not isinstance(action_t, dict):
53
+ continue
54
+
55
+ sample = {
56
+ "id": f"{episode_id}_step_{t:04d}_k{k}",
57
+ "image": [state_t["screenshot"], state_tk["screenshot"]],
58
+ "conversations": [
59
+ {
60
+ "from": "human",
61
+ "value": "<image><image>What is the first action that transitions the first screen to the second screen? Output your answer in <answer></answer> tags."
62
+ },
63
+ {
64
+ "from": "gpt",
65
+ "value": action_to_answer(action_t),
66
+ }
67
+ ],
68
+ "ground_truth_action": action_t,
69
+ "k": k,
70
+ "episode_id": episode_id,
71
+ "step": t,
72
+ }
73
+ samples.append(sample)
74
+
75
+ return samples
76
+
77
+
78
+ def action_to_answer(action: Dict[str, Any]) -> str:
79
+ """Convert action dict to answer string in required format."""
80
+ action_type = action.get("action_type", action.get("type", ""))
81
+
82
+ if action_type in ["click", "long_press"]:
83
+ bbox = action.get("bbox", [0, 0, 0, 0])
84
+ x = (bbox[0] + bbox[2]) // 2 if len(bbox) >= 4 else action.get("x", 0)
85
+ y = (bbox[1] + bbox[3]) // 2 if len(bbox) >= 4 else action.get("y", 0)
86
+ return f'<answer>{{"action_type": "{action_type}", "x": {x}, "y": {y}}}</answer>'
87
+
88
+ elif action_type == "scroll":
89
+ direction = action.get("direction", action.get("scroll_direction", "down"))
90
+ return f'<answer>{{"action_type": "scroll", "direction": "{direction}"}}</answer>'
91
+
92
+ elif action_type == "open_app":
93
+ app_name = action.get("app_name", action.get("app", ""))
94
+ return f'<answer>{{"action_type": "open_app", "app_name": "{app_name}"}}</answer>'
95
+
96
+ elif action_type == "input_text":
97
+ text = action.get("text", action.get("input_text", ""))
98
+ # Escape quotes in text
99
+ text = text.replace('"', '\\"')
100
+ return f'<answer>{{"action_type": "input_text", "text": "{text}"}}</answer>'
101
+
102
+ elif action_type in ["navigate_back", "navigate_home", "wait"]:
103
+ return f'<answer>{{"action_type": "{action_type}"}}</answer>'
104
+
105
+ else:
106
+ return f'<answer>{{"action_type": "{action_type}"}}</answer>'
107
+
108
+
109
+ def load_androidcontrol_trajectories(data_dir: str) -> List[Dict[str, Any]]:
110
+ """Load AndroidControl dataset trajectories."""
111
+ trajectories = []
112
+ data_path = Path(data_dir)
113
+
114
+ # AndroidControl format: JSON or JSONL files with episodes
115
+ json_files = list(data_path.glob("**/*.json")) + list(data_path.glob("**/*.jsonl"))
116
+
117
+ for file_path in json_files:
118
+ if file_path.suffix == ".jsonl":
119
+ with open(file_path, "r") as f:
120
+ for line in f:
121
+ if line.strip():
122
+ trajectories.append(json.loads(line))
123
+ else:
124
+ with open(file_path, "r") as f:
125
+ data = json.load(f)
126
+ if isinstance(data, list):
127
+ trajectories.extend(data)
128
+ else:
129
+ trajectories.append(data)
130
+
131
+ return trajectories
132
+
133
+
134
+ def main():
135
+ parser = argparse.ArgumentParser(description="Build K-step GUI Transition data")
136
+ parser.add_argument("--input_dir", type=str, required=True, help="Directory with GUI trajectories")
137
+ parser.add_argument("--output_dir", type=str, required=True, help="Output directory for K-step data")
138
+ parser.add_argument("--k_values", type=int, nargs="+", default=[1, 2, 3, 4], help="K values to use")
139
+ parser.add_argument("--samples_per_k", type=int, default=2000, help="Target samples per k value")
140
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
141
+ args = parser.parse_args()
142
+
143
+ random.seed(args.seed)
144
+ os.makedirs(args.output_dir, exist_ok=True)
145
+
146
+ print(f"Loading trajectories from {args.input_dir}...")
147
+ trajectories = load_androidcontrol_trajectories(args.input_dir)
148
+ print(f"Loaded {len(trajectories)} trajectories")
149
+
150
+ # Parse all trajectories
151
+ parsed_traj = {}
152
+ for i, traj in enumerate(trajectories):
153
+ episode_id = traj.get("episode_id", traj.get("id", f"ep_{i:05d}"))
154
+ steps = parse_trajectory(traj)
155
+ if len(steps) >= 2:
156
+ parsed_traj[episode_id] = steps
157
+
158
+ print(f"Parsed {len(parsed_traj)} valid trajectories")
159
+
160
+ # Build K-step pairs for each k
161
+ for k in args.k_values:
162
+ print(f"\nBuilding K={k} step transition data...")
163
+ all_samples = []
164
+
165
+ for episode_id, steps in parsed_traj.items():
166
+ if len(steps) <= k:
167
+ continue
168
+ samples = build_k_step_pairs(steps, k, episode_id)
169
+ all_samples.extend(samples)
170
+
171
+ print(f" Generated {len(all_samples)} raw samples for k={k}")
172
+
173
+ # Sample down to target count if needed
174
+ if len(all_samples) > args.samples_per_k:
175
+ selected = random.sample(all_samples, args.samples_per_k)
176
+ else:
177
+ selected = all_samples
178
+
179
+ # Write to file
180
+ output_file = os.path.join(args.output_dir, f"k{k}_transition.jsonl")
181
+ with open(output_file, "w") as f:
182
+ for sample in selected:
183
+ f.write(json.dumps(sample, ensure_ascii=False) + "\n")
184
+
185
+ print(f" Wrote {len(selected)} samples to {output_file}")
186
+
187
+ # Write metadata
188
+ metadata = {
189
+ "source": "AndroidControl",
190
+ "k_values": args.k_values,
191
+ "samples_per_k": args.samples_per_k,
192
+ "num_trajectories": len(parsed_traj),
193
+ "seed": args.seed,
194
+ }
195
+ with open(os.path.join(args.output_dir, "metadata.json"), "w") as f:
196
+ json.dump(metadata, f, indent=2)
197
+
198
+ print("\nData construction complete!")
199
+
200
+
201
+ if __name__ == "__main__":
202
+ main()