gui-shift / src /data_construction /build_kstep_data.py
luanns's picture
Upload src/data_construction/build_kstep_data.py
3873289 verified
#!/usr/bin/env python3
"""
K-step GUI Transition Data Construction
Builds self-supervised training data from GUI trajectories.
From: GUI-Shift paper (arXiv:2505.12493)
- Given state pairs (S_t, S_{t+k}), predict the first action a_t
- No textual annotations needed — future state is the visual goal
"""
import argparse
import json
import os
import random
from pathlib import Path
from typing import List, Dict, Any, Tuple
from collections import defaultdict
def parse_trajectory(trajectory: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Normalize a raw trajectory record into a list of per-step dicts.

    Each entry of ``trajectory["steps"]`` (an empty list when the key is
    absent) is reduced to exactly three keys:

    - ``"screenshot"``: the step's ``screenshot`` value; ``img_path`` is used
      only when the ``screenshot`` key is absent (``None`` if both are).
    - ``"action"``: the step's ``action`` value, ``{}`` when absent.
    - ``"instruction"``: the step's ``instruction`` value, ``""`` when absent.

    Args:
        trajectory: Raw episode dict, expected to carry a ``"steps"`` list of
            dicts.

    Returns:
        One normalized dict per step, in the original step order.
    """
    return [
        {
            "screenshot": raw_step.get("screenshot", raw_step.get("img_path")),
            "action": raw_step.get("action", {}),
            "instruction": raw_step.get("instruction", ""),
        }
        for raw_step in trajectory.get("steps", [])
    ]
def build_k_step_pairs(
    trajectory: List[Dict[str, Any]],
    k: int,
    episode_id: str,
) -> List[Dict[str, Any]]:
    """Build (S_t, S_{t+k}) pairs with action a_t as target.

    For every index ``t`` such that ``t + k`` is still inside the trajectory,
    one training sample is emitted that shows the screenshots of step ``t``
    and step ``t + k`` and asks for the action recorded at step ``t``.

    A pair is skipped when either screenshot is missing/empty, or when the
    action at step ``t`` is empty or not a dict.

    Args:
        trajectory: Per-step dicts carrying ``"screenshot"`` and ``"action"``
            keys (the shape produced by ``parse_trajectory``).
        k: Number of steps between the two screenshots; must be >= 1.
        episode_id: Identifier embedded in each sample's ``"id"`` and
            ``"episode_id"`` fields.

    Returns:
        A list of sample dicts with keys ``id``, ``image``, ``conversations``,
        ``ground_truth_action``, ``k``, ``episode_id`` and ``step``. Empty
        when the trajectory has ``k`` or fewer steps.

    Raises:
        ValueError: If ``k`` is smaller than 1. ``k == 0`` would pair every
            screenshot with itself, and a negative ``k`` would silently wrap
            around through negative indexing.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    samples = []
    # zip stops at the shorter sequence, so t ranges over len(trajectory) - k.
    for t, (state_t, state_tk) in enumerate(zip(trajectory, trajectory[k:])):
        shot_t = state_t.get("screenshot")
        shot_tk = state_tk.get("screenshot")
        action_t = state_t.get("action")

        # Skip if missing screenshots
        if not shot_t or not shot_tk:
            continue
        # Skip if action is missing/empty
        if not action_t or not isinstance(action_t, dict):
            continue

        samples.append({
            "id": f"{episode_id}_step_{t:04d}_k{k}",
            "image": [shot_t, shot_tk],
            "conversations": [
                {
                    "from": "human",
                    "value": "<image><image>What is the first action that transitions the first screen to the second screen? Output your answer in <answer></answer> tags."
                },
                {
                    "from": "gpt",
                    "value": action_to_answer(action_t),
                },
            ],
            "ground_truth_action": action_t,
            "k": k,
            "episode_id": episode_id,
            "step": t,
        })
    return samples
def action_to_answer(action: Dict[str, Any]) -> str:
    """Convert an action dict to the ``<answer>{...}</answer>`` target string.

    The JSON payload is produced with ``json.dumps`` so every string field
    (typed text, app name, scroll direction, action type) is escaped
    correctly, including backslashes, newlines and double quotes.

    Emitted fields by action type (read from ``action_type`` or ``type``):

    - ``click`` / ``long_press``: ``x`` and ``y``. When the action carries a
      ``bbox`` with at least four entries, the floor-divided midpoint of
      ``bbox[0]..bbox[2]`` / ``bbox[1]..bbox[3]`` is used; otherwise the
      action's own ``x`` / ``y`` values are used (0 when absent).
    - ``scroll``: ``direction`` (from ``direction`` or ``scroll_direction``,
      default ``"down"``).
    - ``open_app``: ``app_name`` (from ``app_name`` or ``app``, default ``""``).
    - ``input_text``: ``text`` (from ``text`` or ``input_text``, default ``""``).
    - any other type (including ``navigate_back``, ``navigate_home``,
      ``wait``): only ``action_type``.

    Args:
        action: Action dict as stored on a trajectory step.

    Returns:
        The answer string, e.g.
        ``<answer>{"action_type": "click", "x": 20, "y": 30}</answer>``.
    """
    action_type = action.get("action_type", action.get("type", ""))
    payload: Dict[str, Any] = {"action_type": action_type}

    if action_type in ("click", "long_press"):
        # No default bbox here: a [0, 0, 0, 0] default would always win over
        # the explicit x/y coordinates and collapse every click to (0, 0).
        bbox = action.get("bbox")
        if isinstance(bbox, (list, tuple)) and len(bbox) >= 4:
            payload["x"] = (bbox[0] + bbox[2]) // 2
            payload["y"] = (bbox[1] + bbox[3]) // 2
        else:
            payload["x"] = action.get("x", 0)
            payload["y"] = action.get("y", 0)
    elif action_type == "scroll":
        payload["direction"] = action.get(
            "direction", action.get("scroll_direction", "down")
        )
    elif action_type == "open_app":
        payload["app_name"] = action.get("app_name", action.get("app", ""))
    elif action_type == "input_text":
        payload["text"] = action.get("text", action.get("input_text", ""))
    # Every other action type is fully described by its name alone.

    # Default separators (", " and ": ") match the previously hand-written
    # format; ensure_ascii=False keeps non-ASCII text readable as before.
    return f"<answer>{json.dumps(payload, ensure_ascii=False)}</answer>"
def load_androidcontrol_trajectories(data_dir: str) -> List[Dict[str, Any]]:
    """Load every trajectory stored as JSON / JSONL under ``data_dir``.

    The directory is searched recursively for ``*.json`` and ``*.jsonl``
    files. Files are visited in sorted path order so the resulting list (and
    anything derived from positions in it, such as fallback episode ids or
    seeded sampling) is reproducible across filesystems.

    - ``.jsonl``: each non-blank line is parsed as one trajectory.
    - ``.json``: a top-level list contributes all of its elements; any other
      top-level value is appended as a single trajectory.

    Args:
        data_dir: Root directory to scan.

    Returns:
        All parsed trajectory objects, in sorted-file then in-file order.

    Raises:
        json.JSONDecodeError: If a matched file contains malformed JSON.
    """
    data_path = Path(data_dir)
    # AndroidControl format: JSON or JSONL files with episodes.
    # Glob order is filesystem-dependent, so sort; skip directories whose
    # names happen to match the patterns.
    candidates = sorted(
        path
        for pattern in ("**/*.json", "**/*.jsonl")
        for path in data_path.glob(pattern)
        if path.is_file()
    )

    trajectories: List[Dict[str, Any]] = []
    for file_path in candidates:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(file_path, "r", encoding="utf-8") as f:
            if file_path.suffix == ".jsonl":
                trajectories.extend(
                    json.loads(line) for line in f if line.strip()
                )
            else:
                data = json.load(f)
                if isinstance(data, list):
                    trajectories.extend(data)
                else:
                    trajectories.append(data)
    return trajectories
def main():
    """Command-line entry point: build one JSONL file of K-step samples per k.

    Steps:
      1. Parse and validate CLI arguments (``--k_values`` must all be >= 1,
         ``--samples_per_k`` must be >= 0).
      2. Load all trajectories from ``--input_dir`` and keep those with at
         least two parsed steps, keyed by ``episode_id`` / ``id`` (falling
         back to ``ep_<index>``).
      3. For each k, build all (S_t, S_{t+k}) samples, randomly down-sample
         to ``--samples_per_k`` when there are more, and write them to
         ``<output_dir>/k<k>_transition.jsonl``.
      4. Write ``<output_dir>/metadata.json`` describing the run, including
         the number of samples actually written per k.

    All output files are written as UTF-8 because samples are serialized
    with ``ensure_ascii=False`` and may contain non-ASCII text.
    """
    parser = argparse.ArgumentParser(description="Build K-step GUI Transition data")
    parser.add_argument("--input_dir", type=str, required=True, help="Directory with GUI trajectories")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory for K-step data")
    parser.add_argument("--k_values", type=int, nargs="+", default=[1, 2, 3, 4], help="K values to use")
    parser.add_argument("--samples_per_k", type=int, default=2000, help="Target samples per k value")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    # Fail fast on arguments that would otherwise yield degenerate pairs
    # (k < 1) or make random.sample raise after all the loading work.
    invalid_k = [k for k in args.k_values if k < 1]
    if invalid_k:
        parser.error(f"--k_values must all be >= 1, got {invalid_k}")
    if args.samples_per_k < 0:
        parser.error(f"--samples_per_k must be >= 0, got {args.samples_per_k}")

    random.seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    print(f"Loading trajectories from {args.input_dir}...")
    trajectories = load_androidcontrol_trajectories(args.input_dir)
    print(f"Loaded {len(trajectories)} trajectories")

    # Parse all trajectories
    parsed_traj = {}
    for i, traj in enumerate(trajectories):
        episode_id = traj.get("episode_id", traj.get("id", f"ep_{i:05d}"))
        steps = parse_trajectory(traj)
        # A single-step trajectory cannot form any (S_t, S_{t+k}) pair.
        if len(steps) >= 2:
            parsed_traj[episode_id] = steps
    print(f"Parsed {len(parsed_traj)} valid trajectories")

    # Build K-step pairs for each k
    samples_written = {}
    for k in args.k_values:
        print(f"\nBuilding K={k} step transition data...")
        all_samples = []
        for episode_id, steps in parsed_traj.items():
            if len(steps) <= k:
                continue
            all_samples.extend(build_k_step_pairs(steps, k, episode_id))
        print(f"  Generated {len(all_samples)} raw samples for k={k}")

        # Sample down to target count if needed
        if len(all_samples) > args.samples_per_k:
            selected = random.sample(all_samples, args.samples_per_k)
        else:
            selected = all_samples

        # Write to file
        output_file = os.path.join(args.output_dir, f"k{k}_transition.jsonl")
        with open(output_file, "w", encoding="utf-8") as f:
            for sample in selected:
                f.write(json.dumps(sample, ensure_ascii=False) + "\n")
        samples_written[str(k)] = len(selected)
        print(f"  Wrote {len(selected)} samples to {output_file}")

    # Write metadata
    metadata = {
        "source": "AndroidControl",
        "k_values": args.k_values,
        "samples_per_k": args.samples_per_k,
        # The target above may not be reached; record what was written.
        "samples_written": samples_written,
        "num_trajectories": len(parsed_traj),
        "seed": args.seed,
    }
    with open(os.path.join(args.output_dir, "metadata.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)
    print("\nData construction complete!")
# Run the builder only when this file is executed as a script, so importing
# the module (e.g. to reuse its helpers) has no side effects.
if __name__ == "__main__":
    main()