| |
| """ |
| GUI Task Evaluation Scripts for GUI-Shift. |
| |
| Evaluates on: |
| - AndroidControl (Low / High) |
| - GUI Odyssey |
| - ScreenSpot-v2 |
| - ScreenSpot-Pro |
| |
| Metrics: |
| - TM: Type Match (correct action type) |
| - EM: Exact Match (correct type + all parameters) |
| |
| From: GUI-Shift paper Section 4 (arXiv:2505.12493) |
| """ |
|
|
| import argparse |
| import json |
| import os |
| from pathlib import Path |
| from typing import Dict, Any, List, Tuple, Optional |
| from collections import defaultdict |
|
|
| import torch |
| from transformers import AutoModelForVision2Seq, AutoProcessor |
| from PIL import Image |
|
|
|
|
def load_model(model_path: str, device: str = "cuda"):
    """Load a GUI-Shift checkpoint together with its processor.

    Args:
        model_path: Local directory or hub id of the trained checkpoint.
        device: Forwarded to ``from_pretrained`` as ``device_map``.

    Returns:
        A ``(model, processor)`` tuple; the model is loaded in bfloat16 and
        switched to eval mode.
    """
    # Both loaders need remote code enabled for custom VLM architectures.
    shared_kwargs = {"trust_remote_code": True}

    processor = AutoProcessor.from_pretrained(model_path, **shared_kwargs)

    net = AutoModelForVision2Seq.from_pretrained(
        model_path,
        device_map=device,
        torch_dtype=torch.bfloat16,
        **shared_kwargs,
    )
    net.eval()

    return net, processor
|
|
|
|
def parse_predicted_action(text: str) -> Optional[Dict[str, Any]]:
    """Extract the predicted action dict from a model completion.

    The action is read from the first ``<answer>...</answer>`` span. The span
    is parsed as JSON first; if that fails, or does not produce a JSON
    object, a regex fallback recovers ``action_type`` plus the optional
    ``x``, ``y``, ``direction``, ``app_name`` and ``text`` fields.

    Args:
        text: Raw generated text from the model.

    Returns:
        The parsed action dict, or ``None`` when there is no answer span or no
        ``action_type`` can be recovered.
    """
    import re

    answer = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL)
    if not answer:
        return None

    content = answer.group(1).strip()
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        parsed = None
    # json.loads may legitimately return a str/list/number; the metric code
    # calls .get() on the result, so only a JSON object is returned directly.
    if isinstance(parsed, dict):
        return parsed

    # Regex fallback for near-JSON output (trailing commas, wrapping list, ...).
    type_hit = re.search(r'"action_type"\s*:\s*"([^"]+)"', content)
    if not type_hit:
        return None

    action: Dict[str, Any] = {"action_type": type_hit.group(1)}
    for key in ("x", "y", "direction", "app_name", "text"):
        field = re.search(rf'"{key}"\s*:\s*(?:"([^"]+)"|(-?\d+(?:\.\d+)?))', content)
        if not field:
            continue
        quoted, number = field.group(1), field.group(2)
        if number is not None or key in ("x", "y"):
            # Numbers (and quoted coordinates like "540") become ints; a
            # non-numeric coordinate such as "left" is dropped rather than
            # raising ValueError.
            raw = number if number is not None else quoted
            try:
                value: Any = int(float(raw))
            except (ValueError, OverflowError):
                continue
        else:
            value = quoted
        action[key] = value
    return action
|
|
|
|
def type_match(pred: Optional[Dict], gt: Dict) -> bool:
    """Return whether the predicted action type equals the ground-truth type.

    Args:
        pred: Parsed prediction; may be ``None`` or, for malformed model
            output, a non-dict value.
        gt: Ground-truth action dict.

    Returns:
        ``False`` for a non-dict prediction or one without an ``action_type``;
        otherwise whether the two ``action_type`` values are equal.
    """
    # Non-dict predictions (None, str, list from odd JSON) cannot match and
    # must not reach .get(), which would raise AttributeError.
    if not isinstance(pred, dict):
        return False
    predicted_type = pred.get("action_type")
    # A prediction with no action_type must not "match" a ground truth that
    # also lacks one (None == None).
    if predicted_type is None:
        return False
    return predicted_type == gt.get("action_type")
|
|
|
|
def _as_coordinate(value: Any) -> Optional[float]:
    """Coerce a predicted coordinate to float; ``None`` if absent or non-numeric."""
    if value is None or isinstance(value, bool):
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def exact_match(pred: Optional[Dict], gt: Dict, tolerance: float = 20) -> bool:
    """Return whether the predicted action fully matches the ground truth.

    The action types must match first. After that:

    * ``click`` / ``long_press``: the predicted (x, y) must lie inside
      ``gt["bbox"]`` (``[x1, y1, x2, y2]``, inclusive) when a bbox with at
      least 4 values is given. Otherwise it must be within ``tolerance`` of
      ``gt["x"]``/``gt["y"]`` on each axis. Missing or non-numeric predicted
      coordinates never match.
    * ``scroll``: ``direction`` must be equal.
    * ``open_app``: ``app_name`` must be equal.
    * ``input_text``: ``text`` must be equal.
    * ``navigate_back`` / ``navigate_home`` / ``wait``: a type match suffices.
    * Any other type: no match.

    Args:
        pred: Parsed prediction; may be ``None`` or a non-dict value.
        gt: Ground-truth action dict.
        tolerance: Per-axis distance allowed in the x/y fallback, in the same
            units as the coordinates.

    Returns:
        ``True`` on an exact match, ``False`` otherwise.
    """
    # Guard before type_match so a non-dict prediction never reaches .get().
    if not isinstance(pred, dict) or not type_match(pred, gt):
        return False

    action_type = gt.get("action_type", "")

    if action_type in ("click", "long_press"):
        x = _as_coordinate(pred.get("x"))
        y = _as_coordinate(pred.get("y"))
        # A prediction without usable coordinates is a miss; defaulting to 0
        # could spuriously land inside a bbox touching the origin.
        if x is None or y is None:
            return False
        # Default None (not a zero bbox) so the x/y fallback below is
        # reachable when the ground truth has no bbox.
        bbox = gt.get("bbox")
        if bbox and len(bbox) >= 4:
            return bbox[0] <= x <= bbox[2] and bbox[1] <= y <= bbox[3]
        if "x" in gt and "y" in gt:
            return abs(x - gt["x"]) <= tolerance and abs(y - gt["y"]) <= tolerance
        return False

    if action_type == "scroll":
        return pred.get("direction") == gt.get("direction")

    if action_type == "open_app":
        return pred.get("app_name") == gt.get("app_name")

    if action_type == "input_text":
        return pred.get("text") == gt.get("text")

    if action_type in ("navigate_back", "navigate_home", "wait"):
        return True

    return False
|
|
|
|
def evaluate_sample(
    model,
    processor,
    sample: Dict[str, Any],
    device: str = "cuda",
    max_new_tokens: int = 256,
) -> Tuple[bool, bool, str]:
    """Generate a prediction for one sample and score it.

    Args:
        model: Generation model; inputs are moved to ``model.device``.
        processor: Matching processor used for ``apply_chat_template``,
            input encoding and ``batch_decode``.
        sample: Dataset record. Reads ``image_path`` (falling back to
            ``image``; a single path or a list of paths), ``problem`` (falling
            back to ``instruction``) and ``ground_truth_action`` (falling back
            to ``action``).
        device: Unused; kept for backward compatibility. Inputs always follow
            ``model.device``.
        max_new_tokens: Generation budget for greedy decoding.

    Returns:
        ``(type_match, exact_match, prediction_text)``.
    """
    import warnings

    image_paths = sample.get("image_path", sample.get("image", []))
    if image_paths is None:
        image_paths = []
    elif isinstance(image_paths, str):
        # A bare path string would otherwise be iterated character by
        # character, silently producing an image-less prompt.
        image_paths = [image_paths]
    problem = sample.get("problem", sample.get("instruction", ""))
    ground_truth = sample.get("ground_truth_action", sample.get("action", {}))

    images = []
    for img_path in image_paths:
        if isinstance(img_path, str) and os.path.exists(img_path):
            # The context manager releases the file handle once the RGB copy exists.
            with Image.open(img_path) as img:
                images.append(img.convert("RGB"))
        else:
            warnings.warn(f"Skipping missing/invalid image path: {img_path!r}")

    # One image placeholder per loaded image, followed by the text prompt.
    content = [{"type": "image"} for _ in images]
    content.append({"type": "text", "text": problem})
    messages = [{"role": "user", "content": content}]

    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    # NOTE(review): pass None instead of [] for text-only prompts; presumably
    # safer for processors that validate the image batch. Confirm for the
    # target processor.
    inputs = processor(
        text=text,
        images=images or None,
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    # Strip the prompt tokens so only the newly generated continuation is decoded.
    generated_text = processor.batch_decode(
        outputs[:, inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )[0]

    pred_action = parse_predicted_action(generated_text)

    tm = type_match(pred_action, ground_truth)
    em = exact_match(pred_action, ground_truth)

    return tm, em, generated_text
|
|
|
|
def evaluate_dataset(
    model,
    processor,
    dataset_path: str,
    device: str = "cuda",
    output_path: Optional[str] = None,
) -> Dict[str, Any]:
    """Evaluate every record of a JSONL dataset and aggregate TM/EM metrics.

    Args:
        model: Generation model passed through to ``evaluate_sample``.
        processor: Matching processor passed through to ``evaluate_sample``.
        dataset_path: JSONL file with one sample per non-blank line.
        device: Passed through to ``evaluate_sample``.
        output_path: If given, the full results dict is written there as JSON.
            Missing parent directories are created.

    Returns:
        A dict with:

        * raw counts: ``total``, ``type_match``, ``exact_match``;
        * percentages: ``type_match_pct``, ``exact_match_pct``;
        * ``by_type``: a plain dict keyed by ground-truth ``action_type`` with
          ``total``/``tm``/``em`` counts plus ``tm_pct``/``em_pct``;
        * ``details``: one record per sample.
    """
    # Explicit UTF-8 so non-ASCII instructions/app names decode the same on every platform.
    with open(dataset_path, "r", encoding="utf-8") as f:
        samples = [json.loads(line) for line in f if line.strip()]

    num_samples = len(samples)
    results: Dict[str, Any] = {
        "total": num_samples,
        "type_match": 0,
        "exact_match": 0,
        "by_type": defaultdict(lambda: {"total": 0, "tm": 0, "em": 0}),
        "details": [],
    }

    for i, sample in enumerate(samples):
        print(f" Evaluating {i+1}/{num_samples}...", end="\r")

        tm, em, pred_text = evaluate_sample(model, processor, sample, device)

        gt = sample.get("ground_truth_action", sample.get("action", {}))
        action_type = gt.get("action_type", "unknown")

        results["type_match"] += int(tm)
        results["exact_match"] += int(em)
        bucket = results["by_type"][action_type]
        bucket["total"] += 1
        bucket["tm"] += int(tm)
        bucket["em"] += int(em)

        results["details"].append({
            "id": sample.get("id", i),
            "type_match": tm,
            "exact_match": em,
            "predicted": pred_text,
            "ground_truth": gt,
        })

    # Freeze into a plain dict (same key position) so later lookups by callers
    # cannot silently create empty buckets.
    results["by_type"] = dict(results["by_type"])

    results["type_match_pct"] = 100.0 * results["type_match"] / num_samples if num_samples else 0
    results["exact_match_pct"] = 100.0 * results["exact_match"] / num_samples if num_samples else 0

    for counts in results["by_type"].values():
        total = counts["total"]
        counts["tm_pct"] = 100.0 * counts["tm"] / total if total else 0
        counts["em_pct"] = 100.0 * counts["em"] / total if total else 0

    print(f"\n TM: {results['type_match_pct']:.1f}% | EM: {results['exact_match_pct']:.1f}%")

    if output_path:
        out_file = Path(output_path)
        # Create missing directories so a long evaluation is not lost at the final write.
        out_file.parent.mkdir(parents=True, exist_ok=True)
        with open(out_file, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f" Results saved to {output_path}")

    return results
|
|
|
|
def main():
    """Command-line entry point: load a model, evaluate one JSONL dataset and print TM/EM.

    ``--benchmark`` is only used as a label in the printed report.
    """
    parser = argparse.ArgumentParser(description="Evaluate GUI-Shift model on GUI benchmarks")
    parser.add_argument("--model_path", type=str, required=True, help="Path to trained model")
    parser.add_argument("--dataset", type=str, required=True, help="Path to evaluation dataset (JSONL)")
    parser.add_argument("--output", type=str, default="evaluation_results.json", help="Output results file")
    parser.add_argument("--device", type=str, default="cuda", help="Device for inference")
    # argparse does not check defaults against `choices`, so the default must
    # itself be a valid choice or an invalid label slips through silently.
    parser.add_argument(
        "--benchmark",
        type=str,
        default="androidcontrol_low",
        choices=["androidcontrol_low", "androidcontrol_high", "gui_odyssey",
                 "screenspot_v2", "screenspot_pro"],
        help="Benchmark name",
    )
    args = parser.parse_args()

    print(f"Loading model from {args.model_path}...")
    model, processor = load_model(args.model_path, args.device)

    print(f"Evaluating on {args.benchmark}...")
    results = evaluate_dataset(model, processor, args.dataset, args.device, args.output)

    print("\n=== Final Results ===")
    print(f"Benchmark: {args.benchmark}")
    print(f"Total samples: {results['total']}")
    print(f"Type Match (TM): {results['type_match']}/{results['total']} = {results['type_match_pct']:.2f}%")
    print(f"Exact Match (EM): {results['exact_match']}/{results['total']} = {results['exact_match_pct']:.2f}%")

    print("\nPer-action breakdown:")
    for action_type, counts in sorted(results["by_type"].items()):
        print(f" {action_type:20s}: TM={counts['tm_pct']:.1f}% EM={counts['em_pct']:.1f}% ({counts['tm']}/{counts['total']})")


if __name__ == "__main__":
    main()
|
|