luanns committed on
Commit
1345eac
·
verified ·
1 Parent(s): 37433c4

Upload src/filtering/filter_data.py

Browse files
Files changed (1) hide show
  1. src/filtering/filter_data.py +294 -0
src/filtering/filter_data.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Data Filtering Pipeline for GUI-Shift.
4
+
5
+ Filters K-step GUI Transition samples based on model-generated responses.
6
+ - Discards samples where all N responses are entirely correct or incorrect
7
+ - Keeps samples with mixed correctness (informative for learning)
8
+
9
+ From: GUI-Shift paper Section 3.3 (arXiv:2505.12493)
10
+ """
11
+
12
+ import argparse
13
+ import json
14
+ import os
15
+ import random
16
+ from pathlib import Path
17
+ from typing import List, Dict, Any, Tuple, Optional
18
+
19
+ import torch
20
+ from transformers import AutoModelForVision2Seq, AutoProcessor
21
+ from PIL import Image
22
+
23
+
24
def load_model_and_processor(model_path: str, device: str = "cuda"):
    """Load the base vision-language model and its processor for filtering.

    Args:
        model_path: Checkpoint identifier or path handed to ``from_pretrained``.
        device: Value forwarded as ``device_map`` when loading the model.

    Returns:
        A ``(model, processor)`` tuple. The model is loaded in bfloat16 and
        switched to eval mode before being returned.
    """
    # NOTE: trust_remote_code=True allows the checkpoint repository to supply
    # custom Python code, so only point this at checkpoints you trust.
    shared_kwargs = {"trust_remote_code": True}

    processor = AutoProcessor.from_pretrained(model_path, **shared_kwargs)

    model = AutoModelForVision2Seq.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map=device,
        **shared_kwargs,
    )
    model.eval()

    return model, processor
38
+
39
+
40
def generate_responses(
    model,
    processor,
    sample: Dict[str, Any],
    num_generations: int = 8,
    temperature: float = 0.9,
    max_new_tokens: int = 256,
) -> List[str]:
    """Generate N sampled candidate responses for a single sample.

    Args:
        model: Loaded generation model; must expose ``generate`` and ``device``.
        processor: Matching processor providing ``apply_chat_template``,
            ``batch_decode`` and a callable tokenization/featurization entry.
        sample: Sample dict. ``sample["problem"]`` is the prompt text; image
            paths are read from ``"image_path"`` (or ``"image"`` as fallback)
            and may be a list of paths or a single path string.
        num_generations: Number of independent samples to draw.
        temperature: Sampling temperature passed to ``generate``.
        max_new_tokens: Generation length cap passed to ``generate``.

    Returns:
        A list of ``num_generations`` decoded continuations (prompt stripped).

    Raises:
        KeyError: If ``sample`` has no ``"problem"`` entry.
    """
    image_paths = sample.get("image_path", sample.get("image", []))
    # A single path stored as a bare string must not be iterated character by
    # character (which would silently drop the image).
    if isinstance(image_paths, str):
        image_paths = [image_paths]
    elif image_paths is None:
        image_paths = []
    problem = sample["problem"]

    # Load images; unreadable/missing paths are skipped (best effort) but
    # reported so a broken dataset does not go unnoticed.
    images = []
    for img_path in image_paths:
        if isinstance(img_path, str) and os.path.exists(img_path):
            # convert() returns an independent copy, so the underlying file
            # handle can be released as soon as the conversion is done.
            with Image.open(img_path) as img:
                images.append(img.convert("RGB"))
        else:
            print(f"\n  Warning: image not found, skipping: {img_path!r}")

    # Build prompt: one image placeholder per loaded image, then the text.
    content = [{"type": "image"} for _ in images]
    content.append({"type": "text", "text": problem})
    messages = [{"role": "user", "content": content}]

    text = processor.apply_chat_template(messages, add_generation_prompt=True)

    inputs = processor(
        text=text,
        # Pass None rather than an empty list for text-only samples.
        images=images or None,
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    # Everything up to this index in the output is the echoed prompt.
    prompt_length = inputs["input_ids"].shape[1]

    responses = []
    with torch.no_grad():
        for _ in range(num_generations):
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                num_return_sequences=1,
            )

            # Decode only the newly generated tokens.
            generated_text = processor.batch_decode(
                outputs[:, prompt_length:],
                skip_special_tokens=True,
            )[0]
            responses.append(generated_text)

    return responses
91
+
92
+
93
def evaluate_action_correctness(
    response: str,
    ground_truth: Dict[str, Any],
) -> float:
    """Score whether the action in a model response matches the ground truth.

    The predicted action is the JSON object inside the first
    ``<answer>...</answer>`` span of ``response``.

    Matching rules:
      * ``action_type`` must be equal.
      * ``click`` / ``long_press``: the predicted ``(x, y)`` must fall inside
        the ground-truth ``bbox`` (``[x1, y1, x2, y2]``) or, failing that, lie
        within 20 units of the ground-truth ``x`` / ``y`` on both axes.
      * ``scroll``: ``direction`` must be equal.
      * ``open_app``: ``app_name`` must be equal.
      * ``input_text``: ``text`` must be equal.
      * ``navigate_back`` / ``navigate_home`` / ``wait``: type match suffices.

    Args:
        response: Raw model output text.
        ground_truth: Ground-truth action as a dict, or a JSON string
            encoding one.

    Returns:
        1.0 if the prediction is correct, 0.0 otherwise (including when the
        response has no parsable JSON object in ``<answer>`` tags).

    Raises:
        json.JSONDecodeError: If ``ground_truth`` is a string that is not
            valid JSON.
    """
    import re

    def _is_number(value: Any) -> bool:
        # bool is an int subclass but is never a meaningful coordinate.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    # Extract action from response
    match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
    if not match:
        return 0.0

    try:
        pred_action = json.loads(match.group(1).strip())
    except json.JSONDecodeError:
        return 0.0
    # Valid JSON that is not an object (list, number, string) is not an action.
    if not isinstance(pred_action, dict):
        return 0.0

    gt_action = ground_truth if isinstance(ground_truth, dict) else json.loads(ground_truth)
    if not isinstance(gt_action, dict):
        return 0.0

    pred_type = pred_action.get("action_type", "")
    gt_type = gt_action.get("action_type", "")
    if pred_type != gt_type:
        return 0.0

    # Check parameters based on action type
    if pred_type in ("click", "long_press"):
        x = pred_action.get("x")
        y = pred_action.get("y")
        # A prediction without usable coordinates can never be a correct tap;
        # do not default to (0, 0), which could spuriously match.
        if not (_is_number(x) and _is_number(y)):
            return 0.0

        # Only use a bbox that is actually present in the ground truth.
        bbox = gt_action.get("bbox")
        if (
            isinstance(bbox, (list, tuple))
            and len(bbox) >= 4
            and all(_is_number(v) for v in bbox[:4])
        ):
            if bbox[0] <= x <= bbox[2] and bbox[1] <= y <= bbox[3]:
                return 1.0

        # Fallback: compare against exact coordinates with a tolerance.
        gt_x = gt_action.get("x")
        gt_y = gt_action.get("y")
        if _is_number(gt_x) and _is_number(gt_y):
            tolerance = 20
            if abs(x - gt_x) <= tolerance and abs(y - gt_y) <= tolerance:
                return 1.0
        return 0.0

    if pred_type == "scroll":
        return 1.0 if pred_action.get("direction") == gt_action.get("direction") else 0.0

    if pred_type == "open_app":
        return 1.0 if pred_action.get("app_name") == gt_action.get("app_name") else 0.0

    if pred_type == "input_text":
        return 1.0 if pred_action.get("text") == gt_action.get("text") else 0.0

    if pred_type in ("navigate_back", "navigate_home", "wait"):
        return 1.0

    # Unknown or missing action type.
    return 0.0
151
+
152
+
153
def filter_sample(
    responses: List[str],
    ground_truth: Dict[str, Any],
    threshold_all_correct: float = 1.0,
    threshold_all_incorrect: float = 0.0,
) -> bool:
    """Tell whether a sample is informative enough to keep.

    Each response is scored against ``ground_truth``. A sample is discarded
    when every score reaches ``threshold_all_correct`` (too easy) or when
    every score is at or below ``threshold_all_incorrect`` (too hard).

    Returns:
        True to KEEP the sample (mixed correctness), False to DISCARD it.
    """
    scores = [
        evaluate_action_correctness(candidate, ground_truth)
        for candidate in responses
    ]

    too_easy = all(value >= threshold_all_correct for value in scores)
    too_hard = all(value <= threshold_all_incorrect for value in scores)

    # Only samples that are neither uniformly solved nor uniformly failed
    # carry a learning signal.
    return not (too_easy or too_hard)
177
+
178
+
179
def parse_ground_truth(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Extract the ground-truth action dict from a sample.

    Lookup order:
      1. ``sample["ground_truth_action"]`` — used if it is a dict, or a
         string containing a JSON object.
      2. The JSON object inside ``<answer>...</answer>`` in
         ``sample["solution"]``.

    Args:
        sample: One dataset record.

    Returns:
        The ground-truth action as a dict, or ``{}`` when none can be parsed.
        The result is always a dict, so callers can safely use ``.get``.
    """
    import re

    def _as_action(raw: Any) -> Optional[Dict[str, Any]]:
        # Normalize a dict or JSON-object string to a dict; None otherwise.
        if isinstance(raw, dict):
            return raw
        if isinstance(raw, str):
            try:
                parsed = json.loads(raw)
            except json.JSONDecodeError:
                return None
            return parsed if isinstance(parsed, dict) else None
        return None

    if "ground_truth_action" in sample:
        action = _as_action(sample["ground_truth_action"])
        if action is not None:
            return action
        # Unusable explicit field: fall through and try the solution text.

    # Extract from the <answer> tags of the solution text.
    solution = sample.get("solution", "")
    if isinstance(solution, str):
        match = re.search(r'<answer>(.*?)</answer>', solution, re.DOTALL)
        if match:
            action = _as_action(match.group(1).strip())
            if action is not None:
                return action

    return {}
196
+
197
+
198
def main():
    """Command-line entry point: filter a JSONL dataset by response diversity.

    For every input sample, N responses are sampled from the model and scored
    against the sample's ground truth. Samples where all responses are
    correct ("too easy") or all incorrect ("too hard") are dropped; the rest
    are written to ``--output_file`` with ``filter_scores`` and
    ``filter_mean_score`` added. A ``<output stem>_stats.json`` summary is
    written next to the output file.
    """
    parser = argparse.ArgumentParser(description="Filter K-step GUI Transition data")
    parser.add_argument("--input_file", type=str, required=True, help="Input JSONL file with K-step data")
    parser.add_argument("--output_file", type=str, required=True, help="Output filtered JSONL file")
    parser.add_argument("--model_path", type=str, required=True, help="Base VLM model for filtering")
    parser.add_argument("--num_generations", type=int, default=8, help="Number of generations per sample")
    parser.add_argument("--temperature", type=float, default=0.9, help="Sampling temperature")
    parser.add_argument("--max_new_tokens", type=int, default=256, help="Max tokens per generation")
    parser.add_argument("--device", type=str, default="cuda", help="Device for model inference")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    # With zero generations every sample would vacuously count as "all correct".
    if args.num_generations < 1:
        parser.error("--num_generations must be at least 1")

    random.seed(args.seed)
    torch.manual_seed(args.seed)

    print(f"Loading model from {args.model_path}...")
    model, processor = load_model_and_processor(args.model_path, args.device)

    print(f"Loading samples from {args.input_file}...")
    samples = []
    with open(args.input_file, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                samples.append(json.loads(line))

    print(f"Loaded {len(samples)} samples. Starting filtering...")

    kept_samples = []
    discarded_easy = 0
    discarded_hard = 0
    skipped_no_ground_truth = 0

    for i, sample in enumerate(samples):
        print(f"  Processing sample {i+1}/{len(samples)}...", end="\r")

        # Resolve the ground truth first so no generation time is spent on
        # samples that cannot be scored anyway.
        gt = parse_ground_truth(sample)
        if not gt:
            print(f"\n  Warning: Could not parse ground truth for sample {sample.get('id', i)}")
            skipped_no_ground_truth += 1
            continue

        # Generate responses
        responses = generate_responses(
            model, processor, sample,
            num_generations=args.num_generations,
            temperature=args.temperature,
            max_new_tokens=args.max_new_tokens,
        )

        # Evaluate and filter
        scores = [evaluate_action_correctness(resp, gt) for resp in responses]

        if all(score >= 1.0 for score in scores):
            discarded_easy += 1
            continue
        if all(score <= 0.0 for score in scores):
            discarded_hard += 1
            continue

        # Add correctness scores to sample metadata
        sample["filter_scores"] = scores
        sample["filter_mean_score"] = sum(scores) / len(scores)
        kept_samples.append(sample)

    print(f"\nFiltering complete!")
    print(f"  Kept: {len(kept_samples)} samples")
    print(f"  Discarded (too easy): {discarded_easy} samples")
    print(f"  Discarded (too hard): {discarded_hard} samples")
    print(f"  Skipped (no ground truth): {skipped_no_ground_truth} samples")

    # Write filtered data. UTF-8 is required because ensure_ascii=False emits
    # non-ASCII characters verbatim.
    output_path = Path(args.output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        for sample in kept_samples:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")

    print(f"Wrote filtered data to {args.output_file}")

    # Write statistics
    stats = {
        "input_file": args.input_file,
        "output_file": args.output_file,
        "model_path": args.model_path,
        "num_generations": args.num_generations,
        "total_samples": len(samples),
        "kept_samples": len(kept_samples),
        "discarded_easy": discarded_easy,
        "discarded_hard": discarded_hard,
        "skipped_no_ground_truth": skipped_no_ground_truth,
        "keep_ratio": len(kept_samples) / len(samples) if samples else 0,
    }
    # Derive the stats name from the stem so it can never equal the output
    # path, even when the output file does not end in ".jsonl".
    stats_file = output_path.with_name(f"{output_path.stem}_stats.json")
    with open(stats_file, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)

    print(f"Wrote statistics to {stats_file}")


if __name__ == "__main__":
    main()