Spaces:
Configuration error
Configuration error
| #!/usr/bin/env python3 | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
| import numpy as np | |
# ROOT is the parent of the directory that contains this file (parents[1] of
# the resolved path). It is appended to sys.path when absent; presumably this
# is the repository root, which lets the `server` package imported just below
# resolve when the file is executed directly as a script.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
| from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment | |
def _choose_surrogate_action(
    env: SoftmaxSurrogateEnvironment,
    acquisition: str,
    beta: float,
    xi: float,
) -> int:
    """Return the available config id with the highest acquisition score.

    Configs the environment reports as already seen are skipped for as long
    as fewer configs have been seen than are available; once the seen count
    reaches the available count, every config becomes eligible again. Ties
    are resolved in favour of the config that appears first in
    ``env.available_config_ids()``.

    Args:
        env: Environment queried for seen ids, available ids and scores.
        acquisition: Strategy name forwarded to ``env.acquisition_score``.
        beta: Forwarded unchanged to ``env.acquisition_score``.
        xi: Forwarded unchanged to ``env.acquisition_score``.

    Returns:
        The id of the best-scoring eligible config.

    Raises:
        RuntimeError: If no eligible config scored strictly above ``-inf``
            (for example, there are no available configs or every score is
            NaN), or if the winning id is negative.
    """
    seen = set(env.seen_config_ids())
    # Fetch the candidate list once instead of once per iteration; it is
    # assumed not to change while candidates are being scored.
    candidates = list(env.available_config_ids())
    skip_seen = len(seen) < len(candidates)

    best_config_id = -1
    best_score = float("-inf")
    for config_id in candidates:
        if skip_seen and config_id in seen:
            continue
        score = env.acquisition_score(config_id, strategy=acquisition, beta=beta, xi=xi)
        # Strict comparison keeps the earliest candidate on ties and never
        # selects a NaN score.
        if score > best_score:
            best_score = score
            best_config_id = config_id

    if best_config_id < 0:
        raise RuntimeError("Failed to choose a surrogate action.")
    return best_config_id
def _aggregate_metrics(episode_records: List[Dict[str, Any]], budget: int) -> Dict[str, Any]:
    """Summarise per-episode regret histories into aggregate statistics.

    Regret is sampled at the 1-based step checkpoints 1, 3, 5 and ``budget``
    (only those that lie in ``1..budget``). An episode contributes to a
    checkpoint only if its history has at least that many steps.

    Args:
        episode_records: One dict per episode. Each must provide ``"history"``
            (a list of step dicts with a ``"regret"`` entry) and
            ``"final_regret"``.
        budget: Step budget per episode; also used as the last checkpoint.

    Returns:
        A dict with:
          * ``"mean_regret_at"`` / ``"median_regret_at"``: checkpoint (as a
            string) -> mean / median regret over contributing episodes.
            Checkpoints with no contributing episode are omitted.
          * ``"mean_auc_regret"``: mean over episodes of the per-episode mean
            regret, or ``None`` if no episode has any steps.
          * ``"oracle_hit_rate_final"``: fraction of episodes whose
            ``final_regret`` equals exactly 0.0, or ``None`` if there are no
            episodes.
    """
    # Checkpoints are step counts, so they must be >= 1: with k <= 0 the
    # lookup regrets[k - 1] would silently wrap around to the end of the
    # history (or raise IndexError on an empty one).
    checkpoints = sorted({k for k in (1, 3, 5, budget) if 1 <= k <= budget})
    regrets_at: Dict[int, List[float]] = {k: [] for k in checkpoints}
    per_episode_mean: List[float] = []

    for record in episode_records:
        regrets = [float(step["regret"]) for step in record["history"]]
        if regrets:
            per_episode_mean.append(float(sum(regrets) / len(regrets)))
        for k in checkpoints:
            if len(regrets) >= k:
                regrets_at[k].append(regrets[k - 1])

    populated = {k: vals for k, vals in regrets_at.items() if vals}

    mean_auc = None
    if per_episode_mean:
        mean_auc = float(sum(per_episode_mean) / len(per_episode_mean))

    hit_rate = None
    if episode_records:
        hits = sum(1 for record in episode_records if float(record["final_regret"]) == 0.0)
        hit_rate = float(hits / len(episode_records))

    return {
        "mean_regret_at": {
            str(k): float(sum(vals) / len(vals)) for k, vals in populated.items()
        },
        # float64 keeps the median at the same precision as the mean; a
        # float32 cast would round the reported values (0.1 -> 0.10000000149).
        "median_regret_at": {
            str(k): float(np.median(np.asarray(vals, dtype=np.float64)))
            for k, vals in populated.items()
        },
        "mean_auc_regret": mean_auc,
        "oracle_hit_rate_final": hit_rate,
    }
def run_surrogate_baseline(
    task: str,
    episodes: int,
    budget: int,
    seed: int,
    measurement_path: str,
    train_task_ids: List[str] | None = None,
    acquisition: str = "ucb",
    beta: float = 1.5,
    xi: float = 0.0,
) -> Dict[str, Any]:
    """Run the surrogate-guided baseline on one task and summarise it.

    A single environment is created and reused. Each episode resets it with
    ``seed + episode_index``, then repeatedly picks a config through
    ``_choose_surrogate_action`` and steps the environment until a step
    reports ``done``.

    Args:
        task: Task id passed to ``env.reset`` and echoed in the summary.
        episodes: Number of episodes to run.
        budget: Passed to the environment and to the metric aggregation.
        seed: Base seed; also passed to the environment constructor.
        measurement_path: Passed to the environment constructor.
        train_task_ids: Optional ids passed to the environment constructor.
        acquisition: Acquisition strategy used when choosing configs.
        beta: Forwarded to the acquisition scoring.
        xi: Forwarded to the acquisition scoring.

    Returns:
        A summary dict holding the run parameters, the oracle best latency,
        the best latency/config over all episodes, aggregate regret metrics
        and the per-episode records (including step histories).
    """
    environment = SoftmaxSurrogateEnvironment(
        measurement_path=measurement_path,
        budget=budget,
        seed=seed,
        train_task_ids=train_task_ids,
    )
    overall_best: Dict[str, Any] = {
        "latency_ms": float("inf"),
        "config": None,
        "task_id": task,
    }
    records: List[Dict[str, Any]] = []

    for episode_index in range(episodes):
        environment.reset(task=task, seed=seed + episode_index)
        best_latency = float("inf")
        best_config: Dict[str, int] | None = None
        steps: List[Dict[str, Any]] = []

        while True:
            chosen_id = _choose_surrogate_action(
                environment, acquisition=acquisition, beta=beta, xi=xi
            )
            result = environment.step({"config_id": chosen_id})
            observation = result["observation"]
            last_trial = observation["last_trial"]
            step_record = {
                "config_id": chosen_id,
                "latency_ms": last_trial["latency_ms"],
                "config": last_trial["config"],
                "reward": result["reward"],
                "regret": result["info"]["current_regret"],
                "validation_mse": result["info"]["validation_mse"],
            }
            steps.append(step_record)

            # When the running best improves, look up which seen config has
            # the lowest measured latency and remember its description.
            if observation["best_so_far_ms"] < best_latency:
                best_latency = observation["best_so_far_ms"]
                fastest_id = min(
                    environment.seen_config_ids(),
                    key=environment.measured_latency_ms,
                )
                best_config = environment.config_info(fastest_id)

            if bool(result["done"]):
                break

        if best_latency < overall_best["latency_ms"]:
            overall_best = {
                "latency_ms": float(best_latency),
                "config": best_config,
                "task_id": task,
            }

        final_state = environment.diagnostics()
        records.append(
            {
                "task_id": task,
                "episode": episode_index,
                "best_latency_ms": best_latency,
                "best_config": best_config or {},
                "final_validation_mse": final_state["validation_mse"],
                "final_regret": final_state["current_regret"],
                "history": steps,
            }
        )

    summary: Dict[str, Any] = {
        "task": task,
        "method": "surrogate",
        "episodes": episodes,
        "budget": budget,
        "seed": seed,
        "train_task_ids": list(train_task_ids or []),
        "acquisition": acquisition,
        "beta": beta,
        "xi": xi,
        "oracle_best_ms": environment.oracle_best()["median_ms"],
        "best_overall": overall_best,
        "aggregate_metrics": _aggregate_metrics(records, budget),
        "episodes_summary": records,
    }
    return summary
def parse_args(argv: List[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the surrogate baseline script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``; passing a list allows the parser to be
            driven programmatically.

    Returns:
        Namespace with ``task``, ``episodes``, ``budget``, ``seed``,
        ``acquisition``, ``beta``, ``xi``, ``train_tasks_file``,
        ``measurement_path`` and ``output``.
    """
    parser = argparse.ArgumentParser(description="Surrogate-guided baseline.")
    parser.add_argument("--task", default=None, help="Task ID (e.g., softmax_m4096_n2048)")
    parser.add_argument("--episodes", type=int, default=20)
    parser.add_argument("--budget", type=int, default=6)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--acquisition",
        type=str,
        choices=("mean", "ucb", "ei"),
        default="ucb",
        help="Candidate selection mode: mean, ucb, or ei.",
    )
    parser.add_argument("--beta", type=float, default=1.5, help="UCB exploration strength.")
    parser.add_argument("--xi", type=float, default=0.0, help="Expected-improvement margin.")
    parser.add_argument(
        "--train-tasks-file",
        type=Path,
        default=None,
        help="Optional JSON file containing a list of train task ids.",
    )
    parser.add_argument(
        "--measurement-path",
        type=str,
        default="data/autotune_measurements.csv",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("outputs/surrogate_baseline.json"),
    )
    return parser.parse_args(argv)
def main() -> None:
    """Command-line entry point: run the baseline and emit its summary.

    Steps, in order:
      1. Parse the command-line options.
      2. If no task was given, build a throwaway environment and take the
         task id from the observation returned by its ``reset()``.
      3. If a train-tasks file was given, load its JSON content.
      4. Run ``run_surrogate_baseline`` with the resolved options.
      5. Create the output directory if needed, write the summary there as
         indented JSON, and print the same JSON to stdout.
    """
    options = parse_args()

    if not options.task:
        probe_env = SoftmaxSurrogateEnvironment(
            measurement_path=options.measurement_path,
            budget=options.budget,
            seed=options.seed,
        )
        first_reset = probe_env.reset()
        options.task = first_reset["observation"]["task_id"]

    tasks_file = options.train_tasks_file
    train_task_ids = (
        None if tasks_file is None else json.loads(tasks_file.read_text(encoding="utf-8"))
    )

    summary = run_surrogate_baseline(
        task=options.task,
        episodes=options.episodes,
        budget=options.budget,
        seed=options.seed,
        measurement_path=options.measurement_path,
        train_task_ids=train_task_ids,
        acquisition=options.acquisition,
        beta=options.beta,
        xi=options.xi,
    )

    destination = options.output
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as handle:
        json.dump(summary, handle, indent=2)
    print(json.dumps(summary, indent=2))
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()