#!/usr/bin/env python3
"""
K-step GUI Transition Data Construction

Builds self-supervised training data from GUI trajectories.
From: GUI-Shift paper (arXiv:2505.12493)
- Given state pairs (S_t, S_{t+k}), predict the first action a_t
- No textual annotations needed — future state is the visual goal
"""

import argparse
import json
import os
import random
from pathlib import Path
from typing import List, Dict, Any, Tuple
from collections import defaultdict


# Fixed human-turn prompt attached to every sample.
# NOTE(review): "Output your answer in tags." reads as if a tag name (e.g. an
# <answer> tag) was lost; the target produced by action_to_answer carries no
# tags. Confirm the intended wording before training on this data.
_K_STEP_PROMPT = (
    "What is the first action that transitions the first screen to the second "
    "screen? Output your answer in tags."
)


def parse_trajectory(trajectory: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Parse a GUI trajectory into a list of state-action pairs.

    Each entry of ``trajectory["steps"]`` is normalized to a dict with keys
    ``screenshot`` (falling back to ``img_path``), ``action`` (default ``{}``)
    and ``instruction`` (default ``""``). A missing or null ``steps`` value
    yields an empty list.
    """
    return [
        {
            "screenshot": step.get("screenshot", step.get("img_path")),
            "action": step.get("action", {}),
            "instruction": step.get("instruction", ""),
        }
        for step in trajectory.get("steps") or []
    ]


def build_k_step_pairs(
    trajectory: List[Dict[str, Any]],
    k: int,
    episode_id: str,
) -> List[Dict[str, Any]]:
    """Build (S_t, S_{t+k}) pairs with action a_t as target.

    Args:
        trajectory: Steps as returned by :func:`parse_trajectory`.
        k: Distance between the two screenshots; must be >= 1.
        episode_id: Used to build each sample's ``id``.

    Returns:
        One sample per valid ``t`` in ``[0, len(trajectory) - k)``. Steps with a
        missing screenshot (at t or t+k) or a missing/non-dict action are skipped.

    Raises:
        ValueError: if ``k < 1`` (k=0 would pair a screen with itself).
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    samples = []
    for t in range(len(trajectory) - k):
        state_t = trajectory[t]
        state_tk = trajectory[t + k]
        action_t = state_t["action"]

        # Skip if missing screenshots
        if not state_t["screenshot"] or not state_tk["screenshot"]:
            continue
        # Skip if action is missing/empty
        if not action_t or not isinstance(action_t, dict):
            continue

        samples.append({
            "id": f"{episode_id}_step_{t:04d}_k{k}",
            "image": [state_t["screenshot"], state_tk["screenshot"]],
            "conversations": [
                {"from": "human", "value": _K_STEP_PROMPT},
                {"from": "gpt", "value": action_to_answer(action_t)},
            ],
            "ground_truth_action": action_t,
            "k": k,
            "episode_id": episode_id,
            "step": t,
        })
    return samples


def _action_point(action: Dict[str, Any]) -> Tuple[Any, Any]:
    """Return the tap point: the bbox centre if a 4-value bbox is present, else explicit x/y."""
    bbox = action.get("bbox")
    if bbox and len(bbox) >= 4:
        return (bbox[0] + bbox[2]) // 2, (bbox[1] + bbox[3]) // 2
    # Actions that carry coordinates directly (no bbox) must use them; a default
    # [0, 0, 0, 0] bbox would otherwise collapse every such tap to (0, 0).
    return action.get("x", 0), action.get("y", 0)


def action_to_answer(action: Dict[str, Any]) -> str:
    """Convert an action dict to the JSON answer string used as the target.

    The action kind is read from ``action_type`` (falling back to ``type``).
    The answer is serialized with :func:`json.dumps`, so quotes, backslashes,
    newlines and other control characters are escaped and the result is always
    valid JSON; non-ASCII text is emitted verbatim.

    Returns:
        e.g. ``{"action_type": "click", "x": 12, "y": 34}``.
    """
    action_type = action.get("action_type", action.get("type", ""))

    if action_type in ("click", "long_press"):
        x, y = _action_point(action)
        payload = {"action_type": action_type, "x": x, "y": y}
    elif action_type == "scroll":
        direction = action.get("direction", action.get("scroll_direction", "down"))
        payload = {"action_type": "scroll", "direction": direction}
    elif action_type == "open_app":
        app_name = action.get("app_name", action.get("app", ""))
        payload = {"action_type": "open_app", "app_name": app_name}
    elif action_type == "input_text":
        text = action.get("text", action.get("input_text", ""))
        payload = {"action_type": "input_text", "text": "" if text is None else text}
    else:
        # navigate_back / navigate_home / wait, and any unrecognized type,
        # carry no arguments beyond the type itself.
        payload = {"action_type": action_type}

    return json.dumps(payload, ensure_ascii=False)


def load_androidcontrol_trajectories(data_dir: str) -> List[Dict[str, Any]]:
    """Load trajectories from every ``*.json`` / ``*.jsonl`` file under *data_dir*.

    ``.jsonl`` files contribute one trajectory per non-blank line. ``.json``
    files contribute each element of a top-level list, or the single top-level
    object otherwise. Files are read as UTF-8 in sorted path order.
    """
    data_path = Path(data_dir)
    # Sorted so the load order, and thus the seeded sampling in main(), does not
    # depend on directory-listing order, which is not guaranteed.
    json_files = sorted(
        p
        for p in list(data_path.glob("**/*.json")) + list(data_path.glob("**/*.jsonl"))
        if p.is_file()
    )

    trajectories: List[Dict[str, Any]] = []
    for file_path in json_files:
        with open(file_path, "r", encoding="utf-8") as f:
            if file_path.suffix == ".jsonl":
                trajectories.extend(json.loads(line) for line in f if line.strip())
                continue
            data = json.load(f)
        if isinstance(data, list):
            trajectories.extend(data)
        else:
            trajectories.append(data)
    return trajectories


def _write_jsonl(path: str, records: List[Dict[str, Any]]) -> None:
    """Write *records* to *path* as UTF-8 JSON Lines (one object per line)."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def main():
    """CLI entry point: build one ``k{k}_transition.jsonl`` per k plus ``metadata.json``."""
    parser = argparse.ArgumentParser(description="Build K-step GUI Transition data")
    parser.add_argument("--input_dir", type=str, required=True,
                        help="Directory with GUI trajectories")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Output directory for K-step data")
    parser.add_argument("--k_values", type=int, nargs="+", default=[1, 2, 3, 4],
                        help="K values to use")
    parser.add_argument("--samples_per_k", type=int, default=2000,
                        help="Target samples per k value")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    if any(k < 1 for k in args.k_values):
        parser.error("--k_values must all be >= 1")
    if args.samples_per_k < 0:
        parser.error("--samples_per_k must be >= 0")

    random.seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    print(f"Loading trajectories from {args.input_dir}...")
    trajectories = load_androidcontrol_trajectories(args.input_dir)
    print(f"Loaded {len(trajectories)} trajectories")

    # Parse all trajectories; keep only those with at least two steps.
    parsed_traj = {}
    duplicate_ids = 0
    for i, traj in enumerate(trajectories):
        episode_id = traj.get("episode_id", traj.get("id", f"ep_{i:05d}"))
        steps = parse_trajectory(traj)
        if len(steps) >= 2:
            if episode_id in parsed_traj:
                duplicate_ids += 1
            parsed_traj[episode_id] = steps

    print(f"Parsed {len(parsed_traj)} valid trajectories")
    if duplicate_ids:
        # A later trajectory with the same id replaces the earlier one.
        print(f"  Warning: {duplicate_ids} trajectories reused an existing episode_id "
              f"and replaced the earlier entry")

    # Build K-step pairs for each k
    num_samples: Dict[str, int] = {}
    for k in args.k_values:
        print(f"\nBuilding K={k} step transition data...")
        all_samples = []
        for episode_id, steps in parsed_traj.items():
            if len(steps) <= k:
                continue
            all_samples.extend(build_k_step_pairs(steps, k, episode_id))

        print(f"  Generated {len(all_samples)} raw samples for k={k}")

        # Sample down to target count if needed
        if len(all_samples) > args.samples_per_k:
            selected = random.sample(all_samples, args.samples_per_k)
        else:
            selected = all_samples

        output_file = os.path.join(args.output_dir, f"k{k}_transition.jsonl")
        _write_jsonl(output_file, selected)
        num_samples[str(k)] = len(selected)
        print(f"  Wrote {len(selected)} samples to {output_file}")

    # Write metadata
    metadata = {
        "source": "AndroidControl",
        "k_values": args.k_values,
        "samples_per_k": args.samples_per_k,
        "num_trajectories": len(parsed_traj),
        "num_samples": num_samples,
        "seed": args.seed,
    }
    with open(os.path.join(args.output_dir, "metadata.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    print("\nData construction complete!")


if __name__ == "__main__":
    main()