#!/usr/bin/env python3 from __future__ import annotations import argparse import json import sys from pathlib import Path from typing import Any, Dict, List import numpy as np ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.append(str(ROOT)) from scripts.run_random_baseline import run_random_baseline from scripts.run_surrogate_baseline import run_surrogate_baseline def _average_metric_dict(records: List[Dict[str, float]]) -> Dict[str, float]: if not records: return {} keys = sorted({key for record in records for key in record.keys()}, key=lambda value: int(value)) return { key: float(np.mean(np.asarray([record[key] for record in records if key in record], dtype=np.float32))) for key in keys } def _summarize_runs(runs: List[Dict[str, Any]]) -> Dict[str, Any]: mean_regret_records = [run["aggregate_metrics"].get("mean_regret_at", {}) for run in runs] median_regret_records = [run["aggregate_metrics"].get("median_regret_at", {}) for run in runs] auc_values = [run["aggregate_metrics"].get("mean_auc_regret") for run in runs] oracle_hit_values = [run["aggregate_metrics"].get("oracle_hit_rate_final") for run in runs] return { "mean_regret_at": _average_metric_dict(mean_regret_records), "median_regret_at": _average_metric_dict(median_regret_records), "mean_best_so_far_auc": float(np.mean(np.asarray(auc_values, dtype=np.float32))) if auc_values else None, "mean_oracle_hit_rate_final": float(np.mean(np.asarray(oracle_hit_values, dtype=np.float32))) if oracle_hit_values else None, } def _evaluate_section( section_name: str, split: Dict[str, Any], measurement_path: str, episodes: int, budget: int, seed: int, acquisition: str, beta: float, xi: float, ) -> Dict[str, Any]: train_tasks = split["train_tasks"] test_tasks = split["test_tasks"] random_runs: List[Dict[str, Any]] = [] surrogate_runs: List[Dict[str, Any]] = [] for idx, task in enumerate(test_tasks): task_seed = seed + idx * 1000 random_runs.append( run_random_baseline( task=task, episodes=episodes, budget=budget, seed=task_seed, measurement_path=measurement_path, ) ) surrogate_runs.append( run_surrogate_baseline( task=task, episodes=episodes, budget=budget, seed=task_seed, measurement_path=measurement_path, train_task_ids=train_tasks, acquisition=acquisition, beta=beta, xi=xi, ) ) return { "section": section_name, "train_tasks": train_tasks, "test_tasks": test_tasks, "random_summary": _summarize_runs(random_runs), "surrogate_summary": _summarize_runs(surrogate_runs), "task_runs": { "random": random_runs, "surrogate": surrogate_runs, }, } def main() -> None: parser = argparse.ArgumentParser(description="Evaluate random vs surrogate on shape and family holdout splits.") parser.add_argument("--measurement-path", type=str, default="data/autotune_measurements.csv") parser.add_argument("--splits", type=Path, default=Path("data/benchmark_splits.json")) parser.add_argument("--episodes", type=int, default=20) parser.add_argument("--budget", type=int, default=6) parser.add_argument("--seed", type=int, default=2) parser.add_argument("--acquisition", choices=("mean", "ucb", "ei"), default="ucb") parser.add_argument("--beta", type=float, default=2.0) parser.add_argument("--xi", type=float, default=0.0) parser.add_argument("--output", type=Path, default=Path("outputs/generalization_eval.json")) args = parser.parse_args() splits = json.loads(args.splits.read_text(encoding="utf-8")) sections = { "shape_generalization": splits["shape_generalization"], "family_holdout": splits["family_holdout"], } results = { name: _evaluate_section( section_name=name, split=section, measurement_path=args.measurement_path, episodes=args.episodes, budget=args.budget, seed=args.seed, acquisition=args.acquisition, beta=args.beta, xi=args.xi, ) for name, section in sections.items() } summary = { "measurement_path": args.measurement_path, "splits_path": str(args.splits), "episodes": args.episodes, "budget": args.budget, "acquisition": args.acquisition, "beta": args.beta, "xi": args.xi, "results": results, } args.output.parent.mkdir(parents=True, exist_ok=True) with args.output.open("w", encoding="utf-8") as handle: json.dump(summary, handle, indent=2) print(json.dumps(summary, indent=2)) if __name__ == "__main__": main()