luanns commited on
Commit
c39c017
·
verified ·
1 Parent(s): 3873289

Upload src/evaluation/eval_gui.py

Browse files
Files changed (1) hide show
  1. src/evaluation/eval_gui.py +246 -0
src/evaluation/eval_gui.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GUI Task Evaluation Scripts for GUI-Shift.
4
+
5
+ Evaluates on:
6
+ - AndroidControl (Low / High)
7
+ - GUI Odyssey
8
+ - ScreenSpot-v2
9
+ - ScreenSpot-Pro
10
+
11
+ Metrics:
12
+ - TM: Type Match (correct action type)
13
+ - EM: Exact Match (correct type + all parameters)
14
+
15
+ From: GUI-Shift paper Section 4 (arXiv:2505.12493)
16
+ """
17
+
18
+ import argparse
19
+ import json
20
+ import os
21
+ from pathlib import Path
22
+ from typing import Dict, Any, List, Tuple, Optional
23
+ from collections import defaultdict
24
+
25
+ import torch
26
+ from transformers import AutoModelForVision2Seq, AutoProcessor
27
+ from PIL import Image
28
+
29
+
30
+ def load_model(model_path: str, device: str = "cuda"):
31
+ """Load trained GUI-Shift model."""
32
+ processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
33
+ model = AutoModelForVision2Seq.from_pretrained(
34
+ model_path,
35
+ torch_dtype=torch.bfloat16,
36
+ device_map=device,
37
+ trust_remote_code=True,
38
+ )
39
+ model.eval()
40
+ return model, processor
41
+
42
+
43
+ def parse_predicted_action(text: str) -> Optional[Dict[str, Any]]:
44
+ """Parse action from model output."""
45
+ import re
46
+ match = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL)
47
+ if not match:
48
+ return None
49
+
50
+ content = match.group(1).strip()
51
+ try:
52
+ return json.loads(content)
53
+ except json.JSONDecodeError:
54
+ # Fallback regex
55
+ action_type_match = re.search(r'"action_type"\s*:\s*"([^"]+)"', content)
56
+ if action_type_match:
57
+ action = {"action_type": action_type_match.group(1)}
58
+ for key in ["x", "y", "direction", "app_name", "text"]:
59
+ match = re.search(rf'"{key}"\s*:\s*(?:"([^"]+)"|(\d+))', content)
60
+ if match:
61
+ val = match.group(1) or int(match.group(2))
62
+ if key in ["x", "y"]:
63
+ val = int(val)
64
+ action[key] = val
65
+ return action
66
+ return None
67
+
68
+
69
+ def type_match(pred: Optional[Dict], gt: Dict) -> bool:
70
+ """Check if predicted action type matches ground truth."""
71
+ if not pred:
72
+ return False
73
+ return pred.get("action_type") == gt.get("action_type")
74
+
75
+
76
+ def exact_match(pred: Optional[Dict], gt: Dict) -> bool:
77
+ """Check if predicted action exactly matches ground truth."""
78
+ if not type_match(pred, gt):
79
+ return False
80
+
81
+ action_type = gt.get("action_type", "")
82
+
83
+ if action_type in ["click", "long_press"]:
84
+ bbox = gt.get("bbox", [0, 0, 0, 0])
85
+ x = pred.get("x", 0)
86
+ y = pred.get("y", 0)
87
+ if bbox and len(bbox) >= 4:
88
+ return bbox[0] <= x <= bbox[2] and bbox[1] <= y <= bbox[3]
89
+ if "x" in gt and "y" in gt:
90
+ tolerance = 20
91
+ return abs(x - gt["x"]) <= tolerance and abs(y - gt["y"]) <= tolerance
92
+ return False
93
+
94
+ elif action_type == "scroll":
95
+ return pred.get("direction") == gt.get("direction")
96
+
97
+ elif action_type == "open_app":
98
+ return pred.get("app_name") == gt.get("app_name")
99
+
100
+ elif action_type == "input_text":
101
+ return pred.get("text") == gt.get("text")
102
+
103
+ elif action_type in ["navigate_back", "navigate_home", "wait"]:
104
+ return True
105
+
106
+ return False
107
+
108
+
109
+ def evaluate_sample(
110
+ model,
111
+ processor,
112
+ sample: Dict[str, Any],
113
+ device: str = "cuda",
114
+ ) -> Tuple[bool, bool, str]:
115
+ """Evaluate a single sample. Returns (type_match, exact_match, prediction_text)."""
116
+ image_paths = sample.get("image_path", sample.get("image", []))
117
+ problem = sample.get("problem", sample.get("instruction", ""))
118
+ ground_truth = sample.get("ground_truth_action", sample.get("action", {}))
119
+
120
+ # Load images
121
+ images = []
122
+ for img_path in image_paths:
123
+ if isinstance(img_path, str) and os.path.exists(img_path):
124
+ images.append(Image.open(img_path).convert("RGB"))
125
+
126
+ # Build prompt
127
+ messages = [
128
+ {"role": "user", "content": [{"type": "image"} for _ in images] + [{"type": "text", "text": problem}]}
129
+ ]
130
+
131
+ text = processor.apply_chat_template(messages, add_generation_prompt=True)
132
+ inputs = processor(text=text, images=images, return_tensors="pt", padding=True).to(model.device)
133
+
134
+ with torch.no_grad():
135
+ outputs = model.generate(
136
+ **inputs,
137
+ max_new_tokens=256,
138
+ do_sample=False,
139
+ )
140
+
141
+ generated_text = processor.batch_decode(
142
+ outputs[:, inputs["input_ids"].shape[1]:],
143
+ skip_special_tokens=True,
144
+ )[0]
145
+
146
+ pred_action = parse_predicted_action(generated_text)
147
+
148
+ tm = type_match(pred_action, ground_truth)
149
+ em = exact_match(pred_action, ground_truth)
150
+
151
+ return tm, em, generated_text
152
+
153
+
154
+ def evaluate_dataset(
155
+ model,
156
+ processor,
157
+ dataset_path: str,
158
+ device: str = "cuda",
159
+ output_path: Optional[str] = None,
160
+ ) -> Dict[str, float]:
161
+ """Evaluate a full dataset and compute metrics."""
162
+ samples = []
163
+ with open(dataset_path, "r") as f:
164
+ for line in f:
165
+ if line.strip():
166
+ samples.append(json.loads(line))
167
+
168
+ results = {
169
+ "total": len(samples),
170
+ "type_match": 0,
171
+ "exact_match": 0,
172
+ "by_type": defaultdict(lambda: {"total": 0, "tm": 0, "em": 0}),
173
+ "details": [],
174
+ }
175
+
176
+ for i, sample in enumerate(samples):
177
+ print(f" Evaluating {i+1}/{len(samples)}...", end="\r")
178
+
179
+ tm, em, pred_text = evaluate_sample(model, processor, sample, device)
180
+
181
+ gt = sample.get("ground_truth_action", sample.get("action", {}))
182
+ action_type = gt.get("action_type", "unknown")
183
+
184
+ results["type_match"] += int(tm)
185
+ results["exact_match"] += int(em)
186
+ results["by_type"][action_type]["total"] += 1
187
+ results["by_type"][action_type]["tm"] += int(tm)
188
+ results["by_type"][action_type]["em"] += int(em)
189
+
190
+ results["details"].append({
191
+ "id": sample.get("id", i),
192
+ "type_match": tm,
193
+ "exact_match": em,
194
+ "predicted": pred_text,
195
+ "ground_truth": gt,
196
+ })
197
+
198
+ # Compute percentages
199
+ results["type_match_pct"] = 100.0 * results["type_match"] / len(samples) if samples else 0
200
+ results["exact_match_pct"] = 100.0 * results["exact_match"] / len(samples) if samples else 0
201
+
202
+ for action_type, counts in results["by_type"].items():
203
+ counts["tm_pct"] = 100.0 * counts["tm"] / counts["total"] if counts["total"] else 0
204
+ counts["em_pct"] = 100.0 * counts["em"] / counts["total"] if counts["total"] else 0
205
+
206
+ print(f"\n TM: {results['type_match_pct']:.1f}% | EM: {results['exact_match_pct']:.1f}%")
207
+
208
+ if output_path:
209
+ with open(output_path, "w") as f:
210
+ json.dump(results, f, indent=2)
211
+ print(f" Results saved to {output_path}")
212
+
213
+ return results
214
+
215
+
216
+ def main():
217
+ parser = argparse.ArgumentParser(description="Evaluate GUI-Shift model on GUI benchmarks")
218
+ parser.add_argument("--model_path", type=str, required=True, help="Path to trained model")
219
+ parser.add_argument("--dataset", type=str, required=True, help="Path to evaluation dataset (JSONL)")
220
+ parser.add_argument("--output", type=str, default="evaluation_results.json", help="Output results file")
221
+ parser.add_argument("--device", type=str, default="cuda", help="Device for inference")
222
+ parser.add_argument("--benchmark", type=str, default="androidcontrol",
223
+ choices=["androidcontrol_low", "androidcontrol_high", "gui_odyssey",
224
+ "screenspot_v2", "screenspot_pro"],
225
+ help="Benchmark name")
226
+ args = parser.parse_args()
227
+
228
+ print(f"Loading model from {args.model_path}...")
229
+ model, processor = load_model(args.model_path, args.device)
230
+
231
+ print(f"Evaluating on {args.benchmark}...")
232
+ results = evaluate_dataset(model, processor, args.dataset, args.device, args.output)
233
+
234
+ print("\n=== Final Results ===")
235
+ print(f"Benchmark: {args.benchmark}")
236
+ print(f"Total samples: {results['total']}")
237
+ print(f"Type Match (TM): {results['type_match']}/{results['total']} = {results['type_match_pct']:.2f}%")
238
+ print(f"Exact Match (EM): {results['exact_match']}/{results['total']} = {results['exact_match_pct']:.2f}%")
239
+
240
+ print("\nPer-action breakdown:")
241
+ for action_type, counts in sorted(results["by_type"].items()):
242
+ print(f" {action_type:20s}: TM={counts['tm_pct']:.1f}% EM={counts['em_pct']:.1f}% ({counts['tm']}/{counts['total']})")
243
+
244
+
245
+ if __name__ == "__main__":
246
+ main()