File size: 9,784 Bytes
1345eac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
#!/usr/bin/env python3
"""
Data Filtering Pipeline for GUI-Shift.

Filters K-step GUI Transition samples based on model-generated responses.
- Discards samples whose N sampled responses are all correct or all incorrect
- Keeps samples with mixed correctness (informative for learning)

From: GUI-Shift paper Section 3.3 (arXiv:2505.12493)
"""

import argparse
import json
import os
import random
from pathlib import Path
from typing import List, Dict, Any, Tuple, Optional

import torch
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image


def load_model_and_processor(model_path: str, device: str = "cuda"):
    """Load the base VLM used to score samples during filtering.

    The processor is loaded first, then the model in bfloat16 placed via
    ``device_map=device``; both allow remote code from the checkpoint.
    The model is switched to eval mode before being returned.

    Returns:
        A ``(model, processor)`` tuple.
    """
    remote_code_opts = {"trust_remote_code": True}

    vlm_processor = AutoProcessor.from_pretrained(model_path, **remote_code_opts)
    vlm = AutoModelForVision2Seq.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map=device,
        **remote_code_opts,
    )
    vlm.eval()
    return vlm, vlm_processor


def generate_responses(
    model,
    processor,
    sample: Dict[str, Any],
    num_generations: int = 8,
    temperature: float = 0.9,
    max_new_tokens: int = 256,
) -> List[str]:
    """Sample ``num_generations`` independent responses for one sample.

    The prompt is a single user turn: one image placeholder per loadable
    screenshot, followed by ``sample["problem"]`` as text. Screenshots come
    from ``sample["image_path"]`` (falling back to ``sample["image"]``), which
    may be a single path string or a list of paths. Entries that are not
    existing file paths are skipped.

    Args:
        model: Generation model as returned by ``load_model_and_processor``.
        processor: The matching processor (chat template, tokenizer, images).
        sample: One K-step record; must contain ``"problem"``.
        num_generations: Number of independent sampled responses.
        temperature: Sampling temperature passed to ``model.generate``.
        max_new_tokens: Per-response generation length cap.

    Returns:
        Decoded completions (prompt tokens stripped), in generation order.
    """
    image_paths = sample.get("image_path", sample.get("image", []))
    # A bare string would otherwise be iterated character by character and
    # the screenshot silently dropped; a null field would crash the loop.
    if isinstance(image_paths, str):
        image_paths = [image_paths]
    elif image_paths is None:
        image_paths = []
    problem = sample["problem"]

    images = []
    for img_path in image_paths:
        if isinstance(img_path, str) and os.path.exists(img_path):
            # The context manager releases the file handle; convert() returns
            # a fully loaded, independent image.
            with Image.open(img_path) as img:
                images.append(img.convert("RGB"))

    messages = [
        {
            "role": "user",
            "content": [{"type": "image"} for _ in images] + [{"type": "text", "text": problem}],
        }
    ]
    text = processor.apply_chat_template(messages, add_generation_prompt=True)

    inputs = processor(
        text=text,
        # For text-only prompts pass None rather than []: some image
        # processors reject an empty image batch.
        images=images or None,
        return_tensors="pt",
        padding=True,
    ).to(model.device)
    # Generated sequences start with the prompt; slice it off before decoding.
    prompt_len = inputs["input_ids"].shape[1]

    responses = []
    with torch.no_grad():
        for _ in range(num_generations):
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                num_return_sequences=1,
            )
            generated_text = processor.batch_decode(
                outputs[:, prompt_len:],
                skip_special_tokens=True,
            )[0]
            responses.append(generated_text)

    return responses


def evaluate_action_correctness(
    response: str,
    ground_truth: Dict[str, Any],
    tolerance: float = 20,
) -> float:
    """Score one model response against the ground-truth action.

    The response must contain a JSON object inside ``<answer>...</answer>``.
    Its ``action_type`` must equal the ground truth's, and then:

    * ``click`` / ``long_press``: the predicted numeric ``(x, y)`` must lie
      inside the ground-truth ``bbox`` (compared as ``[x1, y1, x2, y2]``).
      Alternatively, when the ground truth has ``x``/``y``, the prediction
      must be within ``tolerance`` of that point on both axes.
    * ``scroll``: ``direction`` must match.
    * ``open_app``: ``app_name`` must match.
    * ``input_text``: ``text`` must match exactly.
    * ``navigate_back`` / ``navigate_home`` / ``wait``: matching type suffices.

    Args:
        response: Raw generated text.
        ground_truth: Ground-truth action dict, or a JSON string encoding one.
        tolerance: Per-axis tolerance for the point fallback, in the same units
            as the coordinates.

    Returns:
        1.0 if correct, otherwise 0.0. Unparsable or malformed predictions
        (no answer tag, invalid JSON, non-object JSON, missing or non-numeric
        coordinates) score 0.0.
    """
    import re

    match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
    if not match:
        return 0.0

    try:
        pred_action = json.loads(match.group(1).strip())
    except json.JSONDecodeError:
        return 0.0
    # Valid JSON that is not an object (e.g. "[1, 2]") is not an action.
    if not isinstance(pred_action, dict):
        return 0.0

    gt_action = ground_truth if isinstance(ground_truth, dict) else json.loads(ground_truth)
    if not isinstance(gt_action, dict):
        return 0.0

    action_type = pred_action.get("action_type", "")
    if action_type != gt_action.get("action_type", ""):
        return 0.0

    if action_type in ("click", "long_press"):
        x = pred_action.get("x")
        y = pred_action.get("y")
        # Missing or non-numeric coordinates can never hit the target. They
        # previously defaulted to 0 and "matched" the default [0, 0, 0, 0] bbox.
        if not all(
            isinstance(v, (int, float)) and not isinstance(v, bool) for v in (x, y)
        ):
            return 0.0
        bbox = gt_action.get("bbox")
        if isinstance(bbox, (list, tuple)) and len(bbox) >= 4:
            if bbox[0] <= x <= bbox[2] and bbox[1] <= y <= bbox[3]:
                return 1.0
        # Fallback: point ground truth with a per-axis tolerance.
        if "x" in gt_action and "y" in gt_action:
            if abs(x - gt_action["x"]) <= tolerance and abs(y - gt_action["y"]) <= tolerance:
                return 1.0
        return 0.0

    # Single-field action types: correct iff that field matches.
    compared_field = {
        "scroll": "direction",
        "open_app": "app_name",
        "input_text": "text",
    }.get(action_type)
    if compared_field is not None:
        return 1.0 if pred_action.get(compared_field) == gt_action.get(compared_field) else 0.0

    if action_type in ("navigate_back", "navigate_home", "wait"):
        return 1.0

    return 0.0


def filter_sample(
    responses: List[str],
    ground_truth: Dict[str, Any],
    threshold_all_correct: float = 1.0,
    threshold_all_incorrect: float = 0.0,
) -> bool:
    """Report whether a sample's responses disagree on correctness.

    Each response is scored with ``evaluate_action_correctness``. The sample
    is worth keeping only when the scores are mixed. It is dropped when every
    score is at or above ``threshold_all_correct`` (too easy), or when every
    score is at or below ``threshold_all_incorrect`` (too hard).

    Returns:
        True to KEEP the sample, False to DISCARD it.
    """
    per_response = [
        evaluate_action_correctness(candidate, ground_truth) for candidate in responses
    ]

    uniformly_right = all(s >= threshold_all_correct for s in per_response)
    # `or` short-circuits, so the "too hard" scan only runs when not too easy.
    return not (
        uniformly_right
        or all(s <= threshold_all_incorrect for s in per_response)
    )


def parse_ground_truth(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Extract the ground-truth action dict from *sample*.

    Sources, in order:

    1. ``sample["ground_truth_action"]``, used as-is if it is a dict or
       decoded if it is a JSON string.
    2. The JSON inside the first ``<answer>...</answer>`` tag of
       ``sample["solution"]``.

    Only JSON objects are accepted. Anything else (lists, scalars, ``None``,
    invalid JSON) is treated as missing. The result is therefore always a dict,
    and an empty dict means "no usable ground truth".
    """
    import re

    def _as_action(value: Any) -> Dict[str, Any]:
        # Decode JSON strings; accept only dicts.
        if isinstance(value, str):
            try:
                value = json.loads(value.strip())
            except json.JSONDecodeError:
                return {}
        return value if isinstance(value, dict) else {}

    if "ground_truth_action" in sample:
        action = _as_action(sample["ground_truth_action"])
        if action:
            return action

    # Fall back to the answer embedded in the reference solution text.
    solution = sample.get("solution", "")
    if isinstance(solution, str):
        match = re.search(r'<answer>(.*?)</answer>', solution, re.DOTALL)
        if match:
            return _as_action(match.group(1))

    return {}


def main():
    """Command-line entry point: filter K-step samples by response diversity.

    For each input sample:

    1. Parse its ground truth. Samples without one are skipped *before*
       spending generations on them.
    2. Sample ``--num_generations`` responses and score each with
       ``evaluate_action_correctness``.
    3. Keep the sample only if the scores are mixed.

    Kept samples are written as JSONL to ``--output_file`` with
    ``filter_scores`` / ``filter_mean_score`` added. Summary counts go to a
    sibling stats file: ``<stem>_stats.json`` for ``.jsonl`` outputs, otherwise
    ``<name>_stats.json``.
    """
    parser = argparse.ArgumentParser(description="Filter K-step GUI Transition data")
    parser.add_argument("--input_file", type=str, required=True, help="Input JSONL file with K-step data")
    parser.add_argument("--output_file", type=str, required=True, help="Output filtered JSONL file")
    parser.add_argument("--model_path", type=str, required=True, help="Base VLM model for filtering")
    parser.add_argument("--num_generations", type=int, default=8, help="Number of generations per sample")
    parser.add_argument("--temperature", type=float, default=0.9, help="Sampling temperature")
    parser.add_argument("--max_new_tokens", type=int, default=256, help="Max tokens per generation")
    parser.add_argument("--device", type=str, default="cuda", help="Device for model inference")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()
    # With zero generations, all() over the empty score list is True, so every
    # sample would be silently counted as "too easy".
    if args.num_generations < 1:
        parser.error("--num_generations must be >= 1")

    random.seed(args.seed)
    torch.manual_seed(args.seed)

    print(f"Loading model from {args.model_path}...")
    model, processor = load_model_and_processor(args.model_path, args.device)

    print(f"Loading samples from {args.input_file}...")
    with open(args.input_file, "r", encoding="utf-8") as f:
        samples = [json.loads(line) for line in f if line.strip()]

    print(f"Loaded {len(samples)} samples. Starting filtering...")

    kept_samples = []
    discarded_easy = 0
    discarded_hard = 0
    skipped_no_gt = 0

    for i, sample in enumerate(samples):
        print(f"  Processing sample {i+1}/{len(samples)}...", end="\r")

        # Resolve the ground truth first: an unscorable sample is skipped
        # anyway, so don't spend N generations on it.
        gt = parse_ground_truth(sample)
        if not gt:
            print(f"\n  Warning: Could not parse ground truth for sample {sample.get('id', i)}")
            skipped_no_gt += 1
            continue

        responses = generate_responses(
            model, processor, sample,
            num_generations=args.num_generations,
            temperature=args.temperature,
            max_new_tokens=args.max_new_tokens,
        )

        # Same thresholds as filter_sample()'s defaults. The scoring is done
        # inline so the discard reason can be counted and the scores recorded.
        scores = [evaluate_action_correctness(resp, gt) for resp in responses]
        if all(score >= 1.0 for score in scores):
            discarded_easy += 1
            continue
        if all(score <= 0.0 for score in scores):
            discarded_hard += 1
            continue

        # Attach the correctness scores as sample metadata.
        sample["filter_scores"] = scores
        sample["filter_mean_score"] = sum(scores) / len(scores)
        kept_samples.append(sample)

    print("\nFiltering complete!")
    print(f"  Kept: {len(kept_samples)} samples")
    print(f"  Discarded (too easy): {discarded_easy} samples")
    print(f"  Discarded (too hard): {discarded_hard} samples")
    print(f"  Skipped (no ground truth): {skipped_no_gt} samples")

    output_path = Path(args.output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False may emit non-ASCII text, so pin the file encoding
    # instead of relying on the platform default.
    with open(output_path, "w", encoding="utf-8") as f:
        for sample in kept_samples:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")

    print(f"Wrote filtered data to {args.output_file}")

    stats = {
        "input_file": args.input_file,
        "output_file": args.output_file,
        "model_path": args.model_path,
        "num_generations": args.num_generations,
        "total_samples": len(samples),
        "kept_samples": len(kept_samples),
        "discarded_easy": discarded_easy,
        "discarded_hard": discarded_hard,
        "skipped_no_gt": skipped_no_gt,
        "keep_ratio": len(kept_samples) / len(samples) if samples else 0,
    }
    # Derive the stats path from the file name only. The former
    # str.replace(".jsonl", ...) left non-.jsonl names unchanged, so the stats
    # dump overwrote the filtered output.
    if output_path.suffix == ".jsonl":
        stats_path = output_path.with_name(output_path.stem + "_stats.json")
    else:
        stats_path = output_path.with_name(output_path.name + "_stats.json")
    with open(stats_path, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)

    print(f"Wrote statistics to {stats_path}")

if __name__ == "__main__":
    # Run the filtering CLI only when executed as a script, not on import.
    main()