#!/usr/bin/env python3
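"""Evaluate the random and surrogate baselines on shape- and family-holdout splits.

For each test task in a split, the script runs both baselines with a shared
per-task seed, aggregates regret and oracle-hit metrics across runs, and writes
a JSON summary to ``--output``.

Example invocation (the script path is illustrative; run from the repository root):

    python scripts/evaluate_generalization.py --episodes 20 --budget 6 --acquisition ucb
"""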
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, List

import numpy as np

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

from scripts.run_random_baseline import run_random_baseline
from scripts.run_surrogate_baseline import run_surrogate_baseline


def _average_metric_dict(records: List[Dict[str, float]]) -> Dict[str, float]:
    """Average per-step metric dictionaries (keyed by step index) across runs."""
    if not records:
        return {}
    # Keys are step indices stored as strings, so sort them numerically rather than lexically.
    keys = sorted({key for record in records for key in record.keys()}, key=lambda value: int(value))
    return {
        key: float(np.mean(np.asarray([record[key] for record in records if key in record], dtype=np.float32)))
        for key in keys
    }


def _summarize_runs(runs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate per-run metrics into split-level means for a single baseline."""
    mean_regret_records = [run["aggregate_metrics"].get("mean_regret_at", {}) for run in runs]
    median_regret_records = [run["aggregate_metrics"].get("median_regret_at", {}) for run in runs]
    # Scalar metrics may be absent from individual runs; drop missing values before averaging.
    auc_values = [value for value in (run["aggregate_metrics"].get("mean_auc_regret") for run in runs) if value is not None]
    oracle_hit_values = [value for value in (run["aggregate_metrics"].get("oracle_hit_rate_final") for run in runs) if value is not None]
    return {
        "mean_regret_at": _average_metric_dict(mean_regret_records),
        "median_regret_at": _average_metric_dict(median_regret_records),
        "mean_best_so_far_auc": float(np.mean(np.asarray(auc_values, dtype=np.float32))) if auc_values else None,
        "mean_oracle_hit_rate_final": float(np.mean(np.asarray(oracle_hit_values, dtype=np.float32))) if oracle_hit_values else None,
    }


def _evaluate_section(
    section_name: str,
    split: Dict[str, Any],
    measurement_path: str,
    episodes: int,
    budget: int,
    seed: int,
    acquisition: str,
    beta: float,
    xi: float,
) -> Dict[str, Any]:
    """Run both baselines on every test task in a split and collect per-run results."""
    train_tasks = split["train_tasks"]
    test_tasks = split["test_tasks"]
    random_runs: List[Dict[str, Any]] = []
    surrogate_runs: List[Dict[str, Any]] = []

    for idx, task in enumerate(test_tasks):
        # Offset the base seed per task so both baselines share a task-specific seed.
        task_seed = seed + idx * 1000
        random_runs.append(
            run_random_baseline(
                task=task,
                episodes=episodes,
                budget=budget,
                seed=task_seed,
                measurement_path=measurement_path,
            )
        )
        surrogate_runs.append(
            run_surrogate_baseline(
                task=task,
                episodes=episodes,
                budget=budget,
                seed=task_seed,
                measurement_path=measurement_path,
                train_task_ids=train_tasks,
                acquisition=acquisition,
                beta=beta,
                xi=xi,
            )
        )

    return {
        "section": section_name,
        "train_tasks": train_tasks,
        "test_tasks": test_tasks,
        "random_summary": _summarize_runs(random_runs),
        "surrogate_summary": _summarize_runs(surrogate_runs),
        "task_runs": {
            "random": random_runs,
            "surrogate": surrogate_runs,
        },
    }


def main() -> None:
    parser = argparse.ArgumentParser(description="Evaluate random vs surrogate on shape and family holdout splits.")
    parser.add_argument("--measurement-path", type=str, default="data/autotune_measurements.csv")
    parser.add_argument("--splits", type=Path, default=Path("data/benchmark_splits.json"))
    parser.add_argument("--episodes", type=int, default=20)
    parser.add_argument("--budget", type=int, default=6)
    parser.add_argument("--seed", type=int, default=2)
    parser.add_argument("--acquisition", choices=("mean", "ucb", "ei"), default="ucb")
    parser.add_argument("--beta", type=float, default=2.0)
    parser.add_argument("--xi", type=float, default=0.0)
    parser.add_argument("--output", type=Path, default=Path("outputs/generalization_eval.json"))
    args = parser.parse_args()

    splits = json.loads(args.splits.read_text(encoding="utf-8"))
    sections = {
        "shape_generalization": splits["shape_generalization"],
        "family_holdout": splits["family_holdout"],
    }
    results = {
        name: _evaluate_section(
            section_name=name,
            split=section,
            measurement_path=args.measurement_path,
            episodes=args.episodes,
            budget=args.budget,
            seed=args.seed,
            acquisition=args.acquisition,
            beta=args.beta,
            xi=args.xi,
        )
        for name, section in sections.items()
    }

    summary = {
        "measurement_path": args.measurement_path,
        "splits_path": str(args.splits),
        "episodes": args.episodes,
        "budget": args.budget,
        "acquisition": args.acquisition,
        "beta": args.beta,
        "xi": args.xi,
        "results": results,
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as handle:
        json.dump(summary, handle, indent=2)
    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()