| |
"""
K-step GUI transition data construction.

Builds self-supervised training data from GUI trajectories, following the
GUI-Shift paper (arXiv:2505.12493):

- Given a state pair (S_t, S_{t+k}), predict the first action a_t taken from S_t.
- No textual annotations are needed: the future screenshot serves as the visual goal.
"""
|
|
| import argparse |
| import json |
| import os |
| import random |
| from pathlib import Path |
| from typing import List, Dict, Any, Tuple |
| from collections import defaultdict |
|
|
|
|
def parse_trajectory(trajectory: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Flatten a raw trajectory record into per-step state/action dicts.

    Every element of ``trajectory["steps"]`` (no ``"steps"`` key means no
    steps) becomes a dict with exactly three keys:

    * ``"screenshot"`` -- the step's ``"screenshot"`` value; when that key is
      absent, its ``"img_path"`` value (``None`` if neither key exists).
    * ``"action"`` -- the step's ``"action"`` value, defaulting to ``{}``.
    * ``"instruction"`` -- the step's ``"instruction"`` value, defaulting to ``""``.
    """
    raw_steps = trajectory.get("steps", [])
    return [
        {
            "screenshot": raw.get("screenshot", raw.get("img_path")),
            "action": raw.get("action", {}),
            "instruction": raw.get("instruction", ""),
        }
        for raw in raw_steps
    ]
|
|
|
|
def build_k_step_pairs(
    trajectory: List[Dict[str, Any]],
    k: int,
    episode_id: str,
) -> List[Dict[str, Any]]:
    """Build (S_t, S_{t+k}) training samples whose target is the action a_t.

    Args:
        trajectory: Parsed steps (as produced by ``parse_trajectory``); each
            step must provide ``"screenshot"`` and ``"action"`` keys.
        k: Distance, in steps, between the two screenshots of a pair. Must be >= 1.
        episode_id: Identifier used to build each sample's ``"id"``; also
            stored on the sample.

    Returns:
        One sample per start index ``t`` in ``range(len(trajectory) - k)``
        (empty when the trajectory has ``k`` or fewer steps). Indices are
        skipped when either screenshot is missing/empty or when ``a_t`` is not
        a non-empty dict. Each sample holds both image references, a two-turn
        conversation (question, serialised answer), the raw action, ``k``,
        ``episode_id`` and ``t``.

    Raises:
        ValueError: If ``k`` < 1. ``k == 0`` would pair a state with itself and
            a negative ``k`` would (via negative indexing) pair it with an
            earlier state, silently yielding meaningless samples.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    samples = []
    # Every start index t whose partner t + k is still inside the trajectory.
    for t in range(len(trajectory) - k):
        state_t = trajectory[t]
        state_tk = trajectory[t + k]
        action_t = state_t["action"]

        # Both endpoints need an image, otherwise there is no visual goal.
        if not state_t["screenshot"] or not state_tk["screenshot"]:
            continue

        # Only well-formed, non-empty action dicts can be serialised as targets.
        if not action_t or not isinstance(action_t, dict):
            continue

        samples.append({
            "id": f"{episode_id}_step_{t:04d}_k{k}",
            "image": [state_t["screenshot"], state_tk["screenshot"]],
            "conversations": [
                {
                    "from": "human",
                    "value": "<image><image>What is the first action that transitions the first screen to the second screen? Output your answer in <answer></answer> tags."
                },
                {
                    "from": "gpt",
                    "value": action_to_answer(action_t),
                }
            ],
            "ground_truth_action": action_t,
            "k": k,
            "episode_id": episode_id,
            "step": t,
        })

    return samples
|
|
|
|
def action_to_answer(action: Dict[str, Any]) -> str:
    """Serialise an action dict into the ``<answer>{...}</answer>`` target string.

    The answer payload is a JSON object whose first key is ``action_type``
    (read from ``"action_type"``, falling back to ``"type"``, then ``""``).
    Further keys depend on the action type:

    * ``click`` / ``long_press`` -- ``x`` and ``y``: the centre of ``bbox``
      (``[x1, y1, x2, y2]``, floor-divided) when a bbox with at least four
      values is present, otherwise the action's own ``x`` / ``y`` (default 0).
    * ``scroll`` -- ``direction`` (``"direction"`` or ``"scroll_direction"``,
      default ``"down"``).
    * ``open_app`` -- ``app_name`` (``"app_name"`` or ``"app"``, default ``""``).
    * ``input_text`` -- ``text`` (``"text"`` or ``"input_text"``, default ``""``).
    * anything else (``navigate_back``, ``navigate_home``, ``wait``, unknown
      types) -- ``action_type`` only.

    The payload is produced by :func:`json.dumps`, so quotes, backslashes,
    newlines and other control characters in string fields are escaped and
    the answer is always valid JSON; non-ASCII text is kept unescaped.
    """
    action_type = action.get("action_type", action.get("type", ""))
    # str() keeps the type quoted as a JSON string, as it always has been.
    payload: Dict[str, Any] = {"action_type": str(action_type)}

    if action_type in ("click", "long_press"):
        # No bbox (rather than a zero box) by default, so actions that only
        # carry explicit x/y coordinates keep them; `or` also absorbs None.
        bbox = action.get("bbox") or []
        if len(bbox) >= 4:
            payload["x"] = (bbox[0] + bbox[2]) // 2
            payload["y"] = (bbox[1] + bbox[3]) // 2
        else:
            payload["x"] = action.get("x", 0)
            payload["y"] = action.get("y", 0)

    elif action_type == "scroll":
        payload["direction"] = str(action.get("direction", action.get("scroll_direction", "down")))

    elif action_type == "open_app":
        payload["app_name"] = str(action.get("app_name", action.get("app", "")))

    elif action_type == "input_text":
        text = action.get("text", action.get("input_text", ""))
        # An explicit None used to crash on .replace(); treat it as empty text.
        payload["text"] = "" if text is None else str(text)

    return f"<answer>{json.dumps(payload, ensure_ascii=False)}</answer>"
|
|
|
|
def load_androidcontrol_trajectories(data_dir: str) -> List[Dict[str, Any]]:
    """Load every trajectory record found (recursively) under *data_dir*.

    Reads all regular ``*.json`` and ``*.jsonl`` files below ``data_dir``:

    * ``.jsonl`` -- each non-blank line is parsed as one record.
    * ``.json`` -- the file is parsed once; a top-level list is flattened into
      the result, any other value is appended as a single record.

    Files are visited in sorted path order so the returned list -- and hence
    any seeded sampling performed on it downstream -- does not depend on the
    filesystem's directory-listing order. Files are always decoded as UTF-8
    rather than with the platform's locale encoding.

    Returns an empty list when no matching files exist.
    """
    data_path = Path(data_dir)
    candidates = list(data_path.glob("**/*.json")) + list(data_path.glob("**/*.jsonl"))
    # A directory can also match the pattern; only regular files can be opened.
    json_files = sorted(p for p in candidates if p.is_file())

    trajectories = []
    for file_path in json_files:
        with open(file_path, "r", encoding="utf-8") as f:
            if file_path.suffix == ".jsonl":
                trajectories.extend(json.loads(line) for line in f if line.strip())
                continue
            data = json.load(f)
        if isinstance(data, list):
            trajectories.extend(data)
        else:
            trajectories.append(data)

    return trajectories
|
|
|
|
def main():
    """Command-line entry point: build and write K-step transition datasets.

    Loads trajectories from ``--input_dir``, keeps those with at least two
    parsed steps, and for each requested K writes ``k{K}_transition.jsonl``
    (at most ``--samples_per_k`` samples, randomly chosen with ``--seed``)
    into ``--output_dir``, followed by a ``metadata.json`` summary.
    Exits with a usage error for K < 1 or a negative ``--samples_per_k``.
    """
    parser = argparse.ArgumentParser(description="Build K-step GUI Transition data")
    parser.add_argument("--input_dir", type=str, required=True, help="Directory with GUI trajectories")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory for K-step data")
    parser.add_argument("--k_values", type=int, nargs="+", default=[1, 2, 3, 4], help="K values to use")
    parser.add_argument("--samples_per_k", type=int, default=2000, help="Target samples per k value")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    # Reject inputs that would silently produce degenerate data or crash later:
    # K < 1 pairs a state with itself / an earlier state, and a negative cap
    # makes random.sample raise.
    if any(k < 1 for k in args.k_values):
        parser.error(f"--k_values must all be >= 1, got {args.k_values}")
    if args.samples_per_k < 0:
        parser.error(f"--samples_per_k must be >= 0, got {args.samples_per_k}")
    # Duplicate K values would only rebuild and overwrite the same output file.
    k_values = list(dict.fromkeys(args.k_values))

    random.seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    print(f"Loading trajectories from {args.input_dir}...")
    trajectories = load_androidcontrol_trajectories(args.input_dir)
    print(f"Loaded {len(trajectories)} trajectories")

    # Parse trajectories, keyed by episode id; a pair needs at least two steps.
    parsed_traj = {}
    num_replaced = 0
    for i, traj in enumerate(trajectories):
        episode_id = traj.get("episode_id", traj.get("id", f"ep_{i:05d}"))
        steps = parse_trajectory(traj)
        if len(steps) < 2:
            continue
        if episode_id in parsed_traj:
            # Later records still win, but make the dropped data visible.
            num_replaced += 1
        parsed_traj[episode_id] = steps

    print(f"Parsed {len(parsed_traj)} valid trajectories")
    if num_replaced:
        print(f"Warning: {num_replaced} trajectories reused an earlier episode id and replaced it")

    # Build, subsample and write one JSONL file per K.
    samples_written = {}
    for k in k_values:
        print(f"\nBuilding K={k} step transition data...")
        all_samples = []

        for episode_id, steps in parsed_traj.items():
            if len(steps) <= k:
                continue
            all_samples.extend(build_k_step_pairs(steps, k, episode_id))

        print(f" Generated {len(all_samples)} raw samples for k={k}")

        if len(all_samples) > args.samples_per_k:
            selected = random.sample(all_samples, args.samples_per_k)
        else:
            selected = all_samples

        output_file = os.path.join(args.output_dir, f"k{k}_transition.jsonl")
        # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters,
        # which the locale default encoding may not be able to represent.
        with open(output_file, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(sample, ensure_ascii=False) + "\n" for sample in selected)

        samples_written[str(k)] = len(selected)
        print(f" Wrote {len(selected)} samples to {output_file}")

    metadata = {
        "source": "AndroidControl",
        "k_values": k_values,
        "samples_per_k": args.samples_per_k,
        "num_trajectories": len(parsed_traj),
        "seed": args.seed,
        "samples_written": samples_written,
    }
    with open(os.path.join(args.output_dir, "metadata.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    print("\nData construction complete!")
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()
|
|