#!/usr/bin/env python3
"""
GUI Task Evaluation Scripts for GUI-Shift.

Evaluates on:
- AndroidControl (Low / High)
- GUI Odyssey
- ScreenSpot-v2
- ScreenSpot-Pro

Metrics:
- TM: Type Match (correct action type)
- EM: Exact Match (correct type + all parameters)

From: GUI-Shift paper Section 4 (arXiv:2505.12493)
"""

import argparse
import json
import os
import re
import sys
from pathlib import Path
from typing import Dict, Any, List, Tuple, Optional
from collections import defaultdict

import torch
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image


# Pixel tolerance for click/long_press when the ground truth gives a point
# ("x"/"y") instead of a bounding box.
CLICK_TOLERANCE = 20

# Action types whose correctness is fully determined by the action type alone.
_PARAMETERLESS_ACTIONS = ("navigate_back", "navigate_home", "wait")

# NOTE(review): the original extraction pattern was r'(.*?)', which always
# matches the empty string (its delimiters appear to have been lost), so every
# prediction parsed to None. <answer>...</answer> is assumed here -- confirm
# against the prompt / training format. Untagged output falls back to the
# whole generated text.
_ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
# Outermost {...} span, for outputs that surround the JSON with extra prose.
_JSON_OBJECT_RE = re.compile(r"\{.*\}", re.DOTALL)
_ACTION_TYPE_RE = re.compile(r'"action_type"\s*:\s*"([^"]+)"')
# Keys recovered by the regex fallback; each value is a quoted string or a number.
_FALLBACK_KEY_RES = {
    key: re.compile(rf'"{key}"\s*:\s*(?:"([^"]+)"|(-?\d+(?:\.\d+)?))')
    for key in ("x", "y", "direction", "app_name", "text")
}
_COORD_KEYS = ("x", "y")


def load_model(model_path: str, device: str = "cuda"):
    """Load a trained GUI-Shift model and its processor.

    Args:
        model_path: Checkpoint path (or hub id) passed to ``from_pretrained``.
        device: Forwarded as ``device_map`` (e.g. ``"cuda"``).

    Returns:
        ``(model, processor)``; the model is loaded in bfloat16 and set to eval mode.
    """
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForVision2Seq.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    )
    model.eval()
    return model, processor


def _to_number(value: Any) -> Optional[float]:
    """Return *value* as an int/float, or None if it is not numeric.

    Numeric strings are parsed ("12" -> 12, "12.5" -> 12.5); booleans are
    rejected even though they are ``int`` subclasses.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, (int, float)):
        return value
    if isinstance(value, str):
        try:
            number = float(value.strip())
        except ValueError:
            return None
        return int(number) if number.is_integer() else number
    return None


def _regex_fallback(content: str) -> Optional[Dict[str, Any]]:
    """Best-effort key extraction from malformed JSON-like action text.

    Returns None unless an ``"action_type"`` string can be found.
    """
    type_found = _ACTION_TYPE_RE.search(content)
    if not type_found:
        return None
    action: Dict[str, Any] = {"action_type": type_found.group(1)}
    for key, pattern in _FALLBACK_KEY_RES.items():
        found = pattern.search(content)
        if not found:
            continue
        value: Any = found.group(1) if found.group(1) is not None else found.group(2)
        if key in _COORD_KEYS:
            # Skip non-numeric coordinates instead of raising ValueError.
            value = _to_number(value)
            if value is None:
                continue
        action[key] = value
    return action


def parse_predicted_action(text: str) -> Optional[Dict[str, Any]]:
    """Parse the action dict from raw model output.

    Strategy: take the ``<answer>`` tag body if present (else the whole text),
    try it as JSON, then try its outermost ``{...}`` span, and finally fall
    back to regex extraction of known keys.

    Returns:
        The parsed action dict, or None if no action could be recovered.
        Non-dict JSON values (lists, numbers) are not returned.
    """
    if not text:
        return None
    tagged = _ANSWER_RE.search(text)
    content = (tagged.group(1) if tagged else text).strip()

    candidates = [content]
    obj = _JSON_OBJECT_RE.search(content)
    if obj and obj.group(0) != content:
        candidates.append(obj.group(0))

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed

    return _regex_fallback(content)


def type_match(pred: Optional[Dict], gt: Dict) -> bool:
    """Return True if the predicted action type equals the ground-truth type.

    A prediction without an ``action_type`` never matches (previously
    ``None == None`` counted as a match when gt also lacked the key).
    """
    if not pred or not isinstance(pred, dict):
        return False
    pred_type = pred.get("action_type")
    return pred_type is not None and pred_type == gt.get("action_type")


def _point_matches(pred: Dict, gt: Dict, tolerance: float) -> bool:
    """Check a predicted (x, y) against the gt bbox, or the gt point +/- tolerance."""
    x = _to_number(pred.get("x"))
    y = _to_number(pred.get("y"))
    if x is None or y is None:
        # Missing coordinates must not default to (0, 0), which would
        # spuriously hit any bbox touching the screen origin.
        return False

    # No default here: a 4-element default made the point branch unreachable.
    bbox = gt.get("bbox")
    if bbox and len(bbox) >= 4:
        x1, y1, x2, y2 = (_to_number(v) for v in bbox[:4])
        if None in (x1, y1, x2, y2):
            return False
        return x1 <= x <= x2 and y1 <= y <= y2

    if "x" in gt and "y" in gt:
        gx, gy = _to_number(gt["x"]), _to_number(gt["y"])
        if gx is None or gy is None:
            return False
        return abs(x - gx) <= tolerance and abs(y - gy) <= tolerance
    return False


def exact_match(
    pred: Optional[Dict],
    gt: Dict,
    *,
    tolerance: float = CLICK_TOLERANCE,
) -> bool:
    """Return True if the prediction matches the ground truth type and parameters.

    Args:
        pred: Parsed predicted action (may be None).
        gt: Ground-truth action dict.
        tolerance: Pixel tolerance for click/long_press when gt has a point
            rather than a bbox.
    """
    if not type_match(pred, gt):
        return False

    action_type = gt.get("action_type", "")
    if action_type in ("click", "long_press"):
        return _point_matches(pred, gt, tolerance)
    if action_type == "scroll":
        return pred.get("direction") == gt.get("direction")
    if action_type == "open_app":
        return pred.get("app_name") == gt.get("app_name")
    if action_type == "input_text":
        return pred.get("text") == gt.get("text")
    if action_type in _PARAMETERLESS_ACTIONS:
        return True
    return False


def _ground_truth(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Return the sample's ground-truth action (``{}`` if absent or null)."""
    return sample.get("ground_truth_action", sample.get("action", {})) or {}


def _image_paths(sample: Dict[str, Any]) -> List[Any]:
    """Return the sample's image path(s) as a list."""
    value = sample.get("image_path", sample.get("image", []))
    if value is None:
        return []
    if isinstance(value, (str, os.PathLike)):
        # A bare path would otherwise be iterated character by character.
        return [value]
    return list(value)


def _load_images(paths: List[Any]) -> List[Image.Image]:
    """Open each existing path as an RGB image, warning about ones skipped."""
    images = []
    for path in paths:
        if isinstance(path, (str, os.PathLike)) and os.path.exists(path):
            # Context manager releases the file handle once converted.
            with Image.open(path) as img:
                images.append(img.convert("RGB"))
        else:
            print(f"  Warning: skipping missing image {path!r}", file=sys.stderr)
    return images


def evaluate_sample(
    model,
    processor,
    sample: Dict[str, Any],
    device: str = "cuda",
) -> Tuple[bool, bool, str]:
    """Evaluate a single sample with greedy decoding.

    ``device`` is accepted for API compatibility; inputs are moved to
    ``model.device``.

    Returns:
        ``(type_match, exact_match, prediction_text)``.
    """
    images = _load_images(_image_paths(sample))
    problem = sample.get("problem", sample.get("instruction", ""))
    ground_truth = _ground_truth(sample)

    # One image placeholder per loaded image, followed by the instruction text.
    messages = [
        {
            "role": "user",
            "content": [{"type": "image"} for _ in images]
            + [{"type": "text", "text": problem}],
        }
    ]
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(
        text=text,
        images=images or None,
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    generated_text = processor.batch_decode(
        outputs[:, inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )[0]

    pred_action = parse_predicted_action(generated_text)
    tm = type_match(pred_action, ground_truth)
    em = exact_match(pred_action, ground_truth)
    return tm, em, generated_text


def _load_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read non-blank lines of a JSONL file into a list of dicts."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _pct(numerator: int, denominator: int) -> float:
    """Percentage, or 0.0 when the denominator is zero."""
    return 100.0 * numerator / denominator if denominator else 0.0


def evaluate_dataset(
    model,
    processor,
    dataset_path: str,
    device: str = "cuda",
    output_path: Optional[str] = None,
) -> Dict[str, float]:
    """Evaluate every sample in a JSONL dataset and compute TM/EM metrics.

    Returns:
        A dict with overall counts/percentages, a per-action-type breakdown
        under ``"by_type"``, and per-sample ``"details"``. If ``output_path``
        is given, the same dict is also written there as JSON.
    """
    samples = _load_jsonl(dataset_path)

    results = {
        "total": len(samples),
        "type_match": 0,
        "exact_match": 0,
        "by_type": defaultdict(lambda: {"total": 0, "tm": 0, "em": 0}),
        "details": [],
    }

    for i, sample in enumerate(samples):
        print(f"  Evaluating {i+1}/{len(samples)}...", end="\r")
        tm, em, pred_text = evaluate_sample(model, processor, sample, device)

        gt = _ground_truth(sample)
        action_type = gt.get("action_type", "unknown")

        results["type_match"] += int(tm)
        results["exact_match"] += int(em)
        bucket = results["by_type"][action_type]
        bucket["total"] += 1
        bucket["tm"] += int(tm)
        bucket["em"] += int(em)

        results["details"].append({
            "id": sample.get("id", i),
            "type_match": tm,
            "exact_match": em,
            "predicted": pred_text,
            "ground_truth": gt,
        })

    results["type_match_pct"] = _pct(results["type_match"], len(samples))
    results["exact_match_pct"] = _pct(results["exact_match"], len(samples))
    for counts in results["by_type"].values():
        counts["tm_pct"] = _pct(counts["tm"], counts["total"])
        counts["em_pct"] = _pct(counts["em"], counts["total"])
    # Plain dict so later lookups cannot silently create empty buckets.
    results["by_type"] = dict(results["by_type"])

    print(f"\n  TM: {results['type_match_pct']:.1f}% | EM: {results['exact_match_pct']:.1f}%")

    if output_path:
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"  Results saved to {output_path}")

    return results


def main():
    """CLI entry point: load a model, evaluate one dataset, print a report."""
    parser = argparse.ArgumentParser(description="Evaluate GUI-Shift model on GUI benchmarks")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to trained model")
    parser.add_argument("--dataset", type=str, required=True,
                        help="Path to evaluation dataset (JSONL)")
    parser.add_argument("--output", type=str, default="evaluation_results.json",
                        help="Output results file")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device for inference")
    # argparse does not validate defaults against `choices`, so the default
    # must itself be one of them (it was "androidcontrol", which is not).
    parser.add_argument("--benchmark", type=str, default="androidcontrol_low",
                        choices=["androidcontrol_low", "androidcontrol_high",
                                 "gui_odyssey", "screenspot_v2", "screenspot_pro"],
                        help="Benchmark name")
    args = parser.parse_args()

    print(f"Loading model from {args.model_path}...")
    model, processor = load_model(args.model_path, args.device)

    print(f"Evaluating on {args.benchmark}...")
    results = evaluate_dataset(model, processor, args.dataset, args.device, args.output)

    print("\n=== Final Results ===")
    print(f"Benchmark: {args.benchmark}")
    print(f"Total samples: {results['total']}")
    print(f"Type Match (TM): {results['type_match']}/{results['total']} = {results['type_match_pct']:.2f}%")
    print(f"Exact Match (EM): {results['exact_match']}/{results['total']} = {results['exact_match_pct']:.2f}%")
    print("\nPer-action breakdown:")
    for action_type, counts in sorted(results["by_type"].items()):
        print(f"  {action_type:20s}: TM={counts['tm_pct']:.1f}% EM={counts['em_pct']:.1f}% ({counts['tm']}/{counts['total']})")


if __name__ == "__main__":
    main()