#!/usr/bin/env python3
"""Compare heuristic, random, and surrogate autotuning baselines on one task.

Runs each baseline against the softmax surrogate environment, picks a winner
by search-quality metrics, re-benchmarks the winning config live, and writes
the combined summary to a JSON file.
"""
from __future__ import annotations

import argparse
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List

# Make the repo root importable when this script is run directly.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment
from scripts.collect_measurements import benchmark_single_config
from scripts.run_random_baseline import run_random_baseline
from scripts.run_surrogate_baseline import run_surrogate_baseline


@dataclass
class BaselineResult:
    method: str
    latency_ms: float
    config: Dict[str, Any]
    regret: float


def _search_metric_key(summary: Dict[str, Any], oracle_best_ms: float) -> tuple[float, float, float]:
    """Sort key for picking the winning search method.

    Orders by mean AUC regret (lower is better), then oracle hit rate at the
    final step (higher is better), then latency regret of the best config.
    """
    metrics = summary.get("aggregate_metrics", {})
    mean_auc_regret = float(metrics.get("mean_auc_regret", float("inf")))
    oracle_hit_rate_final = float(metrics.get("oracle_hit_rate_final", 0.0))
    best_latency_ms = float(summary["best_overall"]["latency_ms"])
    latency_regret = best_latency_ms / oracle_best_ms - 1.0
    return (mean_auc_regret, -oracle_hit_rate_final, latency_regret)


def _heuristic_for_task(
    task_id: str,
    task_rows: List[Dict[str, Any]],
    env: SoftmaxSurrogateEnvironment,
) -> BaselineResult:
    """One-shot heuristic: smallest block covering n, preferring 4 warps / 2 stages."""
    n = int(task_id.split("_n")[-1])
    block = min(row["block_size"] for row in task_rows if row["block_size"] >= n)
    warp = 4 if 4 in {row["num_warps"] for row in task_rows} else 2
    stage = 2 if 2 in {row["num_stages"] for row in task_rows} else 1
    candidate = None
    for row in task_rows:
        if row["block_size"] == block and row["num_warps"] == warp and row["num_stages"] == stage:
            candidate = row
            break
    if candidate is None:
        # Fall back to the closest available config, weighting warp mismatch heavily.
        candidate = min(
            task_rows,
            key=lambda row: abs(row["block_size"] - block) + 10 * abs(row["num_warps"] - warp),
        )
    latency_ms = env.measured_latency_ms(candidate["config_id"])
    oracle_best_ms = env.oracle_best()["median_ms"]
    return BaselineResult(
        method="heuristic",
        latency_ms=float(latency_ms),
        config=candidate,
        regret=float(latency_ms / oracle_best_ms - 1.0),
    )


def _pick_task(task_arg: str | None, measurement_path: str, budget: int) -> str:
    """Resolve the task ID, letting the environment pick one if none was given."""
    env = SoftmaxSurrogateEnvironment(measurement_path=measurement_path, budget=budget, seed=0)
    if task_arg:
        env.reset(task=task_arg)
    else:
        env.reset()
    return env.state()["task_id"]


def _run_all(
    task: str,
    budget: int,
    episodes: int,
    seed: int,
    measurement_path: str,
    acquisition: str,
    beta: float,
    xi: float,
) -> Dict[str, Any]:
    env = SoftmaxSurrogateEnvironment(measurement_path=measurement_path, budget=budget, seed=seed)
    env.reset(task=task)
    task_rows = env.available_configs()
    oracle_best = env.oracle_best()

    heuristic = _heuristic_for_task(task, task_rows, env)
    random_summary = run_random_baseline(
        task, episodes=episodes, budget=budget, seed=seed, measurement_path=measurement_path
    )
    surrogate_summary = run_surrogate_baseline(
        task,
        episodes=episodes,
        budget=budget,
        seed=seed,
        measurement_path=measurement_path,
        acquisition=acquisition,
        beta=beta,
        xi=xi,
    )
    search_summaries = {
        "random": random_summary,
        "surrogate": surrogate_summary,
    }
    winner_method, winner_summary = min(
        search_summaries.items(),
        key=lambda item: _search_metric_key(item[1], oracle_best["median_ms"]),
    )
    winner_cfg = winner_summary["best_overall"]["config"]
    winner_regret = float(winner_summary["best_overall"]["latency_ms"] / oracle_best["median_ms"] - 1.0)

    # Re-benchmark the winning config live to validate the surrogate's numbers.
    n = int(task.split("_n")[-1])
    live = benchmark_single_config(
        n=n,
        block_size=winner_cfg["block_size"],
        num_warps=winner_cfg["num_warps"],
        num_stages=winner_cfg["num_stages"],
        repeats=max(200, budget * 20),
        warmup=25,
        seed=seed + 999,
    )
    return {
        "task": task,
        "seed": seed,
        "budget": budget,
        "episodes": episodes,
        "acquisition": acquisition,
        "beta": beta,
        "xi": xi,
        "oracle_best": oracle_best,
        "heuristic": heuristic.__dict__,
        "random": random_summary["best_overall"],
        "random_aggregate_metrics": random_summary.get("aggregate_metrics", {}),
        "surrogate": surrogate_summary["best_overall"],
        "surrogate_aggregate_metrics": surrogate_summary.get("aggregate_metrics", {}),
        "winner": {
            "method": winner_method,
            "selection_metric": "min(mean_auc_regret), tie-break max(oracle_hit_rate_final), then best latency",
            "latency_ms": winner_summary["best_overall"]["latency_ms"],
            "config": winner_cfg,
            "regret": winner_regret,
            "live_rerun": live.__dict__,
        },
    }


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Compare heuristic/random/surrogate baselines.")
    parser.add_argument(
        "--task",
        default="softmax_m4096_n2048",
        help="Task ID (e.g., softmax_m4096_n2048)",
    )
    parser.add_argument("--budget", type=int, default=6)
    parser.add_argument("--episodes", type=int, default=20)
    parser.add_argument("--seed", type=int, default=2)
    parser.add_argument(
        "--acquisition",
        type=str,
        choices=("mean", "ucb", "ei"),
        default="ucb",
    )
    parser.add_argument("--beta", type=float, default=2.0)
    parser.add_argument("--xi", type=float, default=0.0)
    parser.add_argument(
        "--measurement-path",
        type=str,
        default="data/autotune_measurements.csv",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("outputs/demo_compare.json"),
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    task = _pick_task(args.task, args.measurement_path, args.budget)
    summary = _run_all(
        task=task,
        budget=args.budget,
        episodes=args.episodes,
        seed=args.seed,
        measurement_path=args.measurement_path,
        acquisition=args.acquisition,
        beta=args.beta,
        xi=args.xi,
    )
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()
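
# Usage sketch: all flags below are defined in parse_args() with these defaults;
# the script path itself is an assumption — adjust to wherever this file lives.
#
#   python scripts/demo_compare.py \
#       --task softmax_m4096_n2048 \
#       --budget 6 --episodes 20 --seed 2 \
#       --acquisition ucb --beta 2.0 --xi 0.0 \
#       --measurement-path data/autotune_measurements.csv \
#       --output outputs/demo_compare.json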