File size: 7,331 Bytes
3873289 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 | #!/usr/bin/env python3
"""
K-step GUI Transition Data Construction
Builds self-supervised training data from GUI trajectories.
From: GUI-Shift paper (arXiv:2505.12493)
- Given state pairs (S_t, S_{t+k}), predict the first action a_t
- No textual annotations needed — future state is the visual goal
"""
import argparse
import json
import os
import random
from pathlib import Path
from typing import List, Dict, Any, Tuple
from collections import defaultdict
def parse_trajectory(trajectory: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Normalise the raw ``steps`` of one trajectory into uniform step records.

    Every returned record has exactly three keys:

    * ``screenshot`` -- the step's ``screenshot`` value, or its ``img_path``
      value when the ``screenshot`` key is absent (``None`` if both are absent);
    * ``action`` -- the step's ``action`` value, default ``{}``;
    * ``instruction`` -- the step's ``instruction`` value, default ``""``.

    A trajectory without a ``steps`` key yields an empty list.
    """
    return [
        {
            # Fall back to ``img_path`` only when ``screenshot`` is missing
            # entirely; an explicit ``screenshot: None`` is kept as-is.
            "screenshot": raw["screenshot"] if "screenshot" in raw else raw.get("img_path"),
            "action": raw.get("action", {}),
            "instruction": raw.get("instruction", ""),
        }
        for raw in trajectory.get("steps", [])
    ]
def build_k_step_pairs(
    trajectory: List[Dict[str, Any]],
    k: int,
    episode_id: str,
) -> List[Dict[str, Any]]:
    """Build (S_t, S_{t+k}) pairs with action a_t as the training target.

    For every step index ``t`` with ``t + k < len(trajectory)``, one sample is
    produced that pairs the screenshot at step ``t`` with the screenshot ``k``
    steps later. The target answer is the action taken at step ``t``, rendered
    by ``action_to_answer``.

    Args:
        trajectory: Step records with ``screenshot`` and ``action`` keys (the
            shape produced by ``parse_trajectory``).
        k: Step offset between the two screens. Must be >= 1.
        episode_id: Identifier embedded in each sample's ``id``.

    Returns:
        One sample dict per usable step. A step is skipped when either
        screenshot is falsy, or when its action is empty or not a dict.
        A trajectory with ``len(trajectory) <= k`` yields ``[]``.

    Raises:
        ValueError: If ``k < 1``. ``k == 0`` would pair a screen with itself,
            and a negative ``k`` would index past the end of ``trajectory``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    prompt = (
        "<image><image>What is the first action that transitions the first screen "
        "to the second screen? Output your answer in <answer></answer> tags."
    )
    samples = []
    for t in range(len(trajectory) - k):
        state_t = trajectory[t]
        state_tk = trajectory[t + k]
        action_t = state_t["action"]
        # Both screens are needed as the (current, goal) image pair.
        if not state_t["screenshot"] or not state_tk["screenshot"]:
            continue
        # Without a structured action there is no supervision target.
        if not action_t or not isinstance(action_t, dict):
            continue
        samples.append({
            "id": f"{episode_id}_step_{t:04d}_k{k}",
            "image": [state_t["screenshot"], state_tk["screenshot"]],
            "conversations": [
                {"from": "human", "value": prompt},
                {"from": "gpt", "value": action_to_answer(action_t)},
            ],
            "ground_truth_action": action_t,
            "k": k,
            "episode_id": episode_id,
            "step": t,
        })
    return samples
def action_to_answer(action: Dict[str, Any]) -> str:
    """Render an action dict as ``<answer>{json}</answer>``.

    The action type is read from ``action_type``, falling back to ``type``.
    Extra fields by type:

    * ``click`` / ``long_press``: ``x``/``y``. These are the integer-division
      centre of ``bbox`` when ``bbox`` has at least 4 entries, otherwise the
      action's own ``x``/``y`` (default 0).
    * ``scroll``: ``direction``, from ``direction`` or ``scroll_direction``
      (default ``"down"``).
    * ``open_app``: ``app_name``, from ``app_name`` or ``app`` (default ``""``).
    * ``input_text``: ``text``, from ``text`` or ``input_text`` (default ``""``).

    Any other type emits only ``action_type``.

    The payload is serialised with ``json.dumps`` so that quotes, backslashes
    and newlines in free text are escaped correctly. Non-ASCII text is kept
    verbatim.
    """
    action_type = action.get("action_type", action.get("type", ""))
    payload: Dict[str, Any] = {"action_type": action_type}
    if action_type in ("click", "long_press"):
        bbox = action.get("bbox")
        # Only use the bbox centre when a full box is present; otherwise fall
        # back to explicit point coordinates instead of silently using (0, 0).
        if bbox is not None and len(bbox) >= 4:
            payload["x"] = (bbox[0] + bbox[2]) // 2
            payload["y"] = (bbox[1] + bbox[3]) // 2
        else:
            payload["x"] = action.get("x", 0)
            payload["y"] = action.get("y", 0)
    elif action_type == "scroll":
        payload["direction"] = action.get("direction", action.get("scroll_direction", "down"))
    elif action_type == "open_app":
        payload["app_name"] = action.get("app_name", action.get("app", ""))
    elif action_type == "input_text":
        text = action.get("text", action.get("input_text", ""))
        payload["text"] = "" if text is None else text
    # Default separators (", " and ": ") match the previous hand-built format.
    return f"<answer>{json.dumps(payload, ensure_ascii=False)}</answer>"
def load_androidcontrol_trajectories(data_dir: str) -> List[Dict[str, Any]]:
    """Load every trajectory record found under ``data_dir``.

    Recursively collects ``*.json`` and ``*.jsonl`` files:

    * a ``.jsonl`` file contributes one record per non-blank line;
    * a ``.json`` file contributes each element if its top-level value is a
      list, otherwise the value itself.

    Files are read in sorted path order. ``Path.glob`` yields entries in
    filesystem order, which is not guaranteed to be stable, and the order of
    the returned list feeds seeded random sampling downstream.

    Raises:
        json.JSONDecodeError: If a file (or a JSONL line) is not valid JSON.
    """
    data_path = Path(data_dir)
    json_files = sorted(
        path
        for pattern in ("**/*.json", "**/*.jsonl")
        for path in data_path.glob(pattern)
        if path.is_file()  # skip directories whose names match the pattern
    )
    trajectories: List[Dict[str, Any]] = []
    for file_path in json_files:
        # JSON text is UTF-8 by specification; don't depend on the locale.
        with open(file_path, "r", encoding="utf-8") as f:
            if file_path.suffix == ".jsonl":
                trajectories.extend(json.loads(line) for line in f if line.strip())
                continue
            data = json.load(f)
        if isinstance(data, list):
            trajectories.extend(data)
        else:
            trajectories.append(data)
    return trajectories
def main():
    """Build one K-step transition JSONL file per requested k (CLI entry point).

    Reads trajectories from ``--input_dir`` and keeps those that parse to at
    least two steps. For each distinct k in ``--k_values`` it writes
    ``k{k}_transition.jsonl`` to ``--output_dir``. Each file holds at most
    ``--samples_per_k`` samples, drawn with the seeded ``random`` module when
    there is a surplus. Finally it writes a ``metadata.json`` summary.

    Invalid arguments terminate via ``parser.error`` (usage message, exit 2).
    """
    parser = argparse.ArgumentParser(description="Build K-step GUI Transition data")
    parser.add_argument("--input_dir", type=str, required=True, help="Directory with GUI trajectories")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory for K-step data")
    parser.add_argument("--k_values", type=int, nargs="+", default=[1, 2, 3, 4], help="K values to use")
    parser.add_argument("--samples_per_k", type=int, default=2000, help="Target samples per k value")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    # Validate before any I/O. k=0 pairs a screen with itself and k<0 is
    # meaningless; a negative sample count makes random.sample raise.
    bad_k = [k for k in args.k_values if k < 1]
    if bad_k:
        parser.error(f"--k_values must all be >= 1, got {bad_k}")
    if args.samples_per_k < 0:
        parser.error(f"--samples_per_k must be >= 0, got {args.samples_per_k}")
    # A repeated k would rebuild and overwrite the same output file; keep the
    # first occurrence of each value, preserving the requested order.
    k_values = list(dict.fromkeys(args.k_values))

    random.seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    print(f"Loading trajectories from {args.input_dir}...")
    trajectories = load_androidcontrol_trajectories(args.input_dir)
    print(f"Loaded {len(trajectories)} trajectories")

    # Parse all trajectories; a transition pair needs at least two steps.
    parsed_traj = {}
    for i, traj in enumerate(trajectories):
        episode_id = traj.get("episode_id", traj.get("id", f"ep_{i:05d}"))
        steps = parse_trajectory(traj)
        if len(steps) < 2:
            continue
        if episode_id in parsed_traj:
            # Same id seen twice: the later trajectory replaces the earlier one.
            print(f"Warning: duplicate episode id {episode_id!r}; keeping the later trajectory")
        parsed_traj[episode_id] = steps
    print(f"Parsed {len(parsed_traj)} valid trajectories")

    samples_written = {}
    for k in k_values:
        print(f"\nBuilding K={k} step transition data...")
        all_samples = []
        for episode_id, steps in parsed_traj.items():
            if len(steps) <= k:
                continue
            all_samples.extend(build_k_step_pairs(steps, k, episode_id))
        print(f" Generated {len(all_samples)} raw samples for k={k}")

        # Sample down to the target count only when there is a surplus.
        if len(all_samples) > args.samples_per_k:
            selected = random.sample(all_samples, args.samples_per_k)
        else:
            selected = all_samples

        output_file = os.path.join(args.output_dir, f"k{k}_transition.jsonl")
        # ensure_ascii=False writes raw non-ASCII text, so the file encoding
        # must be pinned rather than left to the platform default.
        with open(output_file, "w", encoding="utf-8") as f:
            for sample in selected:
                f.write(json.dumps(sample, ensure_ascii=False) + "\n")
        samples_written[str(k)] = len(selected)
        print(f" Wrote {len(selected)} samples to {output_file}")

    metadata = {
        "source": "AndroidControl",
        "k_values": k_values,
        "samples_per_k": args.samples_per_k,
        "num_trajectories": len(parsed_traj),
        "seed": args.seed,
        "samples_written": samples_written,
    }
    with open(os.path.join(args.output_dir, "metadata.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)
    print("\nData construction complete!")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
|