#!/usr/bin/env python3 from __future__ import annotations import argparse import json from dataclasses import dataclass import sys from pathlib import Path from typing import Any, Dict, List import numpy as np ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.append(str(ROOT)) from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment @dataclass class RunRecord: task_id: str episode: int best_latency_ms: float best_config: Dict[str, int] final_validation_mse: float final_state: Dict[str, Any] final_regret: float history: List[Dict[str, Any]] def _aggregate_metrics(episode_records: List[Dict[str, Any]], budget: int) -> Dict[str, Any]: ks = sorted(set(k for k in (1, 3, 5, budget) if k <= budget)) regrets_by_k: Dict[int, List[float]] = {k: [] for k in ks} auc_regrets: List[float] = [] for episode in episode_records: regrets = [float(step["regret"]) for step in episode["history"]] if regrets: auc_regrets.append(float(sum(regrets) / len(regrets))) for k in ks: if len(regrets) >= k: regrets_by_k[k].append(regrets[k - 1]) return { "mean_regret_at": { str(k): float(sum(vals) / len(vals)) for k, vals in regrets_by_k.items() if vals }, "median_regret_at": { str(k): float(np.median(np.asarray(vals, dtype=np.float32))) for k, vals in regrets_by_k.items() if vals }, "mean_auc_regret": float(sum(auc_regrets) / len(auc_regrets)) if auc_regrets else None, "oracle_hit_rate_final": float( sum(1 for episode in episode_records if float(episode["final_regret"]) == 0.0) / len(episode_records) ) if episode_records else None, } def _pick_task_from_input(args: argparse.Namespace) -> str: if args.task: return args.task env = SoftmaxSurrogateEnvironment( measurement_path=args.measurement_path, budget=args.budget, seed=args.seed, ) return env.reset()["observation"]["task_id"] def run_random_baseline( task: str, episodes: int, budget: int, seed: int, measurement_path: str, ) -> Dict[str, Any]: rng = np.random.default_rng(seed) best_overall: Dict[str, Any] = {"latency_ms": float("inf"), "config": None, "task_id": task} episode_records: List[Dict[str, Any]] = [] env = SoftmaxSurrogateEnvironment( measurement_path=measurement_path, budget=budget, seed=seed, ) for episode in range(episodes): env.reset(task=task, seed=seed + episode) done = False episode_best = float("inf") episode_best_cfg: Dict[str, int] | None = None history: List[Dict[str, Any]] = [] while not done: unseen = [config_id for config_id in env.available_config_ids() if config_id not in env.seen_config_ids()] choice_pool = unseen if unseen else env.available_config_ids() config_id = int(rng.choice(choice_pool)) step_out = env.step({"config_id": config_id}) obs = step_out["observation"] trial = obs["last_trial"] history.append( { "config_id": config_id, "latency_ms": trial["latency_ms"], "config": trial["config"], "reward": step_out["reward"], "regret": step_out["info"]["current_regret"], "validation_mse": step_out["info"]["validation_mse"], } ) if obs["best_so_far_ms"] < episode_best: episode_best = obs["best_so_far_ms"] best_id = env.seen_config_ids()[int(np.argmin([env.measured_latency_ms(cid) for cid in env.seen_config_ids()]))] episode_best_cfg = env.config_info(best_id) done = bool(step_out["done"]) if episode_best < best_overall["latency_ms"]: best_overall = { "latency_ms": float(episode_best), "config": episode_best_cfg, "task_id": task, } diagnostics = env.diagnostics() episode_records.append( RunRecord( task_id=task, episode=episode, best_latency_ms=float(episode_best), best_config=episode_best_cfg or {}, final_validation_mse=float(diagnostics["validation_mse"]), final_state=env.state(), final_regret=float(diagnostics["current_regret"]), history=history, ).__dict__ ) return { "task": task, "method": "random", "episodes": episodes, "budget": budget, "seed": seed, "oracle_best_ms": env.oracle_best()["median_ms"], "best_overall": best_overall, "aggregate_metrics": _aggregate_metrics(episode_records, budget), "episodes_summary": episode_records, } def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Random baseline for surrogate environment.") parser.add_argument("--task", default=None, help="Task ID (e.g., softmax_m4096_n2048)") parser.add_argument("--episodes", type=int, default=20) parser.add_argument("--budget", type=int, default=6) parser.add_argument("--seed", type=int, default=0) parser.add_argument( "--measurement-path", type=str, default="data/autotune_measurements.csv", ) parser.add_argument( "--output", type=Path, default=Path("outputs/random_baseline.json"), ) return parser.parse_args() def main() -> None: args = parse_args() task = _pick_task_from_input(args) summary = run_random_baseline( task=task, episodes=args.episodes, budget=args.budget, seed=args.seed, measurement_path=args.measurement_path, ) args.output.parent.mkdir(parents=True, exist_ok=True) with args.output.open("w", encoding="utf-8") as f: json.dump(summary, f, indent=2) print(json.dumps(summary, indent=2)) if __name__ == "__main__": main()