"""
Data Filtering Pipeline for GUI-Shift.

Filters K-step GUI Transition samples based on model-generated responses:
- Discards samples where all N responses are correct, or all N are incorrect
- Keeps samples with mixed correctness (informative for learning)

From: GUI-Shift paper, Section 3.3 (arXiv:2505.12493)
"""
|
|
| import argparse |
| import json |
| import os |
| import random |
| from pathlib import Path |
| from typing import List, Dict, Any, Tuple, Optional |
|
|
| import torch |
| from transformers import AutoModelForVision2Seq, AutoProcessor |
| from PIL import Image |
|
|
|
|
def load_model_and_processor(model_path: str, device: str = "cuda"):
    """Load the base vision-language model and its processor for filtering.

    Args:
        model_path: Path or identifier handed to ``from_pretrained``.
        device: Value forwarded as ``device_map`` when loading the model.

    Returns:
        A ``(model, processor)`` tuple. The model is loaded in bfloat16 and
        switched to evaluation mode before being returned.
    """
    # Both loaders opt in to custom code shipped with the checkpoint.
    vlm_processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    model_kwargs = {
        "torch_dtype": torch.bfloat16,
        "device_map": device,
        "trust_remote_code": True,
    }
    vlm = AutoModelForVision2Seq.from_pretrained(model_path, **model_kwargs)
    vlm.eval()

    return vlm, vlm_processor
|
|
|
|
def generate_responses(
    model,
    processor,
    sample: Dict[str, Any],
    num_generations: int = 8,
    temperature: float = 0.9,
    max_new_tokens: int = 256,
) -> List[str]:
    """Sample ``num_generations`` candidate responses for a single sample.

    Args:
        model: Generative model exposing ``generate`` and ``device``.
        processor: Matching processor exposing ``apply_chat_template``,
            ``__call__`` and ``batch_decode``.
        sample: Record with a ``"problem"`` prompt and, optionally, image
            paths under ``"image_path"`` (preferred) or ``"image"``. The
            image value may be a single path string or a list of paths.
        num_generations: Number of independent samples to draw.
        temperature: Sampling temperature passed to ``generate``.
        max_new_tokens: Generation length cap passed to ``generate``.

    Returns:
        The decoded continuations (prompt tokens stripped), one per draw.

    Raises:
        KeyError: If ``sample`` has no ``"problem"`` entry.
    """
    image_paths = sample.get("image_path", sample.get("image", []))
    # A single path stored as a bare string must not be iterated character by
    # character (characters such as "/" or "." exist on disk and would be
    # opened as images); a null value is treated as "no images".
    if isinstance(image_paths, str):
        image_paths = [image_paths]
    elif image_paths is None:
        image_paths = []
    problem = sample["problem"]

    images = []
    for img_path in image_paths:
        # Best-effort: entries that are not existing string paths are skipped.
        if isinstance(img_path, str) and os.path.exists(img_path):
            # convert() returns an independent, fully loaded image, so the
            # source file handle can be released immediately.
            with Image.open(img_path) as img:
                images.append(img.convert("RGB"))

    # One image placeholder per successfully loaded image, then the prompt.
    content = [{"type": "image"} for _ in images]
    content.append({"type": "text", "text": problem})
    messages = [{"role": "user", "content": content}]

    text = processor.apply_chat_template(messages, add_generation_prompt=True)

    inputs = processor(
        text=text,
        # Pass None rather than an empty list for text-only samples.
        images=images if images else None,
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    # generate() returns prompt + continuation; remember where the prompt ends.
    prompt_length = inputs["input_ids"].shape[1]

    responses = []
    with torch.no_grad():
        for _ in range(num_generations):
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                num_return_sequences=1,
            )
            generated_text = processor.batch_decode(
                outputs[:, prompt_length:],
                skip_special_tokens=True,
            )[0]
            responses.append(generated_text)

    return responses
|
|
|
|
def _is_number(value: Any) -> bool:
    """Return True for real int/float values (booleans are excluded)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def evaluate_action_correctness(
    response: str,
    ground_truth: Dict[str, Any],
    tolerance: float = 20,
) -> float:
    """Score one model response against the ground-truth action.

    The predicted action is the JSON object found inside the first
    ``<answer>...</answer>`` span of ``response``. Matching rules:

    - ``action_type`` must be equal, otherwise the score is 0.0.
    - ``click`` / ``long_press``: the predicted ``(x, y)`` must fall inside
      the ground-truth ``bbox`` (``[x1, y1, x2, y2]``, inclusive), or lie
      within ``tolerance`` of the ground-truth ``(x, y)`` on both axes.
    - ``scroll`` compares ``direction``; ``open_app`` compares ``app_name``;
      ``input_text`` compares ``text`` (all by exact equality).
    - ``navigate_back`` / ``navigate_home`` / ``wait`` need only the type.
    - Any other action type scores 0.0.

    Args:
        response: Raw model output.
        ground_truth: Ground-truth action as a dict, or a JSON string
            encoding one.
        tolerance: Maximum per-axis distance for the point-based click match.

    Returns:
        1.0 if the response is judged correct, 0.0 otherwise.

    Raises:
        json.JSONDecodeError: If ``ground_truth`` is a string that is not
            valid JSON.
    """
    import re

    match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
    if not match:
        return 0.0

    try:
        pred_action = json.loads(match.group(1).strip())
    except json.JSONDecodeError:
        return 0.0
    # Valid JSON that is not an object (list, number, string) is not an action.
    if not isinstance(pred_action, dict):
        return 0.0

    gt_action = ground_truth if isinstance(ground_truth, dict) else json.loads(ground_truth)

    pred_type = pred_action.get("action_type", "")
    gt_type = gt_action.get("action_type", "")
    if pred_type != gt_type:
        return 0.0

    if pred_type in ("click", "long_press"):
        x = pred_action.get("x")
        y = pred_action.get("y")
        # A prediction without usable coordinates can never be a hit. Falling
        # back to (0, 0) would spuriously match a missing or zero-sized bbox.
        if not (_is_number(x) and _is_number(y)):
            return 0.0

        bbox = gt_action.get("bbox")
        has_bbox = (
            isinstance(bbox, (list, tuple))
            and len(bbox) >= 4
            and all(_is_number(v) for v in bbox[:4])
        )
        if has_bbox and bbox[0] <= x <= bbox[2] and bbox[1] <= y <= bbox[3]:
            return 1.0

        # Fall back to point distance when the ground truth carries a point.
        gt_x = gt_action.get("x")
        gt_y = gt_action.get("y")
        if _is_number(gt_x) and _is_number(gt_y):
            if abs(x - gt_x) <= tolerance and abs(y - gt_y) <= tolerance:
                return 1.0
        return 0.0

    if pred_type == "scroll":
        return 1.0 if pred_action.get("direction") == gt_action.get("direction") else 0.0

    if pred_type == "open_app":
        return 1.0 if pred_action.get("app_name") == gt_action.get("app_name") else 0.0

    if pred_type == "input_text":
        return 1.0 if pred_action.get("text") == gt_action.get("text") else 0.0

    if pred_type in ("navigate_back", "navigate_home", "wait"):
        return 1.0

    return 0.0
|
|
|
|
def filter_sample(
    responses: List[str],
    ground_truth: Dict[str, Any],
    threshold_all_correct: float = 1.0,
    threshold_all_incorrect: float = 0.0,
) -> bool:
    """Tell whether a sample's responses show mixed correctness.

    Every response is scored with ``evaluate_action_correctness``. A sample
    is uninformative when every score reaches ``threshold_all_correct`` or
    every score is at or below ``threshold_all_incorrect``.

    Returns:
        True to KEEP the sample (mixed correctness), False to DISCARD it
        (uniformly correct or uniformly incorrect).
    """
    verdicts = [
        evaluate_action_correctness(candidate, ground_truth)
        for candidate in responses
    ]

    # Keep only when neither uniform outcome holds.
    return not (
        all(v >= threshold_all_correct for v in verdicts)
        or all(v <= threshold_all_incorrect for v in verdicts)
    )
|
|
|
|
def _parse_json_object(text: str) -> Dict[str, Any]:
    """Decode ``text`` as a JSON object; return {} if invalid or not an object."""
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return {}
    return parsed if isinstance(parsed, dict) else {}


def parse_ground_truth(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Extract the ground-truth action from a sample.

    Sources are tried in order:

    1. ``sample["ground_truth_action"]`` as a non-empty dict, or as a string
       holding a JSON object.
    2. The JSON object inside the first ``<answer>...</answer>`` span of
       ``sample["solution"]``.

    Args:
        sample: One data record.

    Returns:
        The ground-truth action dict, or ``{}`` when none can be recovered.
        The result is always a dict, so callers can use ``.get`` safely.
    """
    import re

    explicit = sample.get("ground_truth_action")
    if isinstance(explicit, dict) and explicit:
        return explicit
    if isinstance(explicit, str):
        parsed = _parse_json_object(explicit)
        if parsed:
            return parsed

    # Fall back to the tagged answer embedded in the solution text.
    solution = sample.get("solution", "")
    if isinstance(solution, str):
        match = re.search(r'<answer>(.*?)</answer>', solution, re.DOTALL)
        if match:
            return _parse_json_object(match.group(1).strip())

    return {}
|
|
|
|
def main():
    """Command-line entry point for the filtering pipeline.

    Loads the model and the input JSONL samples, draws several responses per
    sample, and keeps only samples whose responses are neither all correct
    nor all incorrect. Kept samples are annotated with ``filter_scores`` and
    ``filter_mean_score`` and written to ``--output_file``; run statistics
    are written next to it as ``<output stem>_stats.json``.
    """
    parser = argparse.ArgumentParser(description="Filter K-step GUI Transition data")
    parser.add_argument("--input_file", type=str, required=True, help="Input JSONL file with K-step data")
    parser.add_argument("--output_file", type=str, required=True, help="Output filtered JSONL file")
    parser.add_argument("--model_path", type=str, required=True, help="Base VLM model for filtering")
    parser.add_argument("--num_generations", type=int, default=8, help="Number of generations per sample")
    parser.add_argument("--temperature", type=float, default=0.9, help="Sampling temperature")
    parser.add_argument("--max_new_tokens", type=int, default=256, help="Max tokens per generation")
    parser.add_argument("--device", type=str, default="cuda", help="Device for model inference")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    # With no generations every sample has an empty score list, which would
    # silently be classified as "too easy".
    if args.num_generations < 1:
        parser.error("--num_generations must be at least 1")

    random.seed(args.seed)
    torch.manual_seed(args.seed)

    print(f"Loading model from {args.model_path}...")
    model, processor = load_model_and_processor(args.model_path, args.device)

    print(f"Loading samples from {args.input_file}...")
    # Read explicitly as UTF-8 to match the ensure_ascii=False output below.
    with open(args.input_file, "r", encoding="utf-8") as f:
        samples = [json.loads(line) for line in f if line.strip()]

    total = len(samples)
    print(f"Loaded {total} samples. Starting filtering...")

    kept_samples = []
    discarded_easy = 0
    discarded_hard = 0
    skipped_no_ground_truth = 0

    for i, sample in enumerate(samples):
        print(f"  Processing sample {i+1}/{total}...", end="\r")

        # Check the ground truth first so that no inference is spent on
        # samples that cannot be scored anyway.
        gt = parse_ground_truth(sample)
        if not gt:
            print(f"\n  Warning: Could not parse ground truth for sample {sample.get('id', i)}")
            skipped_no_ground_truth += 1
            continue

        responses = generate_responses(
            model, processor, sample,
            num_generations=args.num_generations,
            temperature=args.temperature,
            max_new_tokens=args.max_new_tokens,
        )
        scores = [evaluate_action_correctness(resp, gt) for resp in responses]

        if all(score >= 1.0 for score in scores):
            discarded_easy += 1
            continue
        if all(score <= 0.0 for score in scores):
            discarded_hard += 1
            continue

        # Mixed correctness: keep the sample together with its scores.
        sample["filter_scores"] = scores
        sample["filter_mean_score"] = sum(scores) / len(scores)
        kept_samples.append(sample)

    print("\nFiltering complete!")
    print(f"  Kept: {len(kept_samples)} samples")
    print(f"  Discarded (too easy): {discarded_easy} samples")
    print(f"  Discarded (too hard): {discarded_hard} samples")
    print(f"  Skipped (no ground truth): {skipped_no_ground_truth} samples")

    output_path = Path(args.output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        for sample in kept_samples:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")

    print(f"Wrote filtered data to {args.output_file}")

    stats = {
        "input_file": args.input_file,
        "output_file": args.output_file,
        "model_path": args.model_path,
        "num_generations": args.num_generations,
        "total_samples": total,
        "kept_samples": len(kept_samples),
        "discarded_easy": discarded_easy,
        "discarded_hard": discarded_hard,
        "skipped_no_ground_truth": skipped_no_ground_truth,
        "keep_ratio": len(kept_samples) / total if samples else 0,
    }
    # Derive the stats path from the file stem only. A plain
    # str.replace(".jsonl", ...) returns the output path unchanged when the
    # name has no ".jsonl" (the stats would overwrite the data) and also
    # rewrites ".jsonl" inside directory names.
    stats_file = str(output_path.with_name(output_path.stem + "_stats.json"))
    with open(stats_file, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)

    print(f"Wrote statistics to {stats_file}")
|
|
|
|
# Run the filtering pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|