#!/usr/bin/env python3
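"""Surrogate-guided baseline for the softmax autotuning environment.

For each episode the script greedily selects the configuration with the best
acquisition score (mean, UCB, or EI), steps SoftmaxSurrogateEnvironment until
the trial budget is exhausted, and writes per-episode histories plus aggregate
regret metrics to a JSON summary.

Example invocation (flags and defaults as defined in parse_args below):

    python scripts/run_surrogate_baseline.py \
        --task softmax_m4096_n2048 \
        --acquisition ucb --beta 1.5 \
        --episodes 20 --budget 6 \
        --output outputs/surrogate_baseline.json
"""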

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, List

import numpy as np

# Make the repository root importable so the server package resolves when this
# script is run directly from scripts/.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment


def _choose_surrogate_action(
    env: SoftmaxSurrogateEnvironment,
    acquisition: str,
    beta: float,
    xi: float,
) -> int:
    """Pick the unseen config with the highest acquisition score.

    Already-measured configs are skipped unless every config has been seen,
    in which case the best-scoring config is returned regardless.
    """
    seen = set(env.seen_config_ids())
    best_config_id = -1
    best_score = float("-inf")
    for config_id in env.available_config_ids():
        # Skip configs we have already measured while unseen ones remain.
        if config_id in seen and len(seen) < len(env.available_config_ids()):
            continue
        score = env.acquisition_score(config_id, strategy=acquisition, beta=beta, xi=xi)
        if score > best_score:
            best_score = score
            best_config_id = config_id
    if best_config_id < 0:
        raise RuntimeError("Failed to choose a surrogate action.")
    return best_config_id
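
# The acquisition formulas themselves live inside SoftmaxSurrogateEnvironment.
# A typical shape for latency minimization (an assumption, not the verified
# implementation) would be:
#   mean: score = -predicted_latency
#   ucb:  score = -(predicted_latency - beta * predicted_std)
#   ei:   score = expected improvement over the best measured latency, offset by xi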


def _aggregate_metrics(episode_records: List[Dict[str, Any]], budget: int) -> Dict[str, Any]:
    """Summarize per-episode regret curves into regret@k, AUC-regret, and oracle hit rate."""
    ks = sorted(set(k for k in (1, 3, 5, budget) if k <= budget))
    regrets_by_k: Dict[int, List[float]] = {k: [] for k in ks}
    auc_regrets: List[float] = []
    for episode in episode_records:
        regrets = [float(step["regret"]) for step in episode["history"]]
        if regrets:
            auc_regrets.append(float(sum(regrets) / len(regrets)))
        for k in ks:
            if len(regrets) >= k:
                regrets_by_k[k].append(regrets[k - 1])
    return {
        "mean_regret_at": {
            str(k): float(sum(vals) / len(vals)) for k, vals in regrets_by_k.items() if vals
        },
        "median_regret_at": {
            str(k): float(np.median(np.asarray(vals, dtype=np.float32)))
            for k, vals in regrets_by_k.items() if vals
        },
        "mean_auc_regret": float(sum(auc_regrets) / len(auc_regrets)) if auc_regrets else None,
        "oracle_hit_rate_final": float(
            sum(1 for episode in episode_records if float(episode["final_regret"]) == 0.0)
            / len(episode_records)
        ) if episode_records else None,
    }
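
# Shape of the returned metrics dict (the regret@k keys depend on the budget):
#
#   {
#     "mean_regret_at":   {"1": ..., "3": ..., "5": ..., "<budget>": ...},
#     "median_regret_at": {"1": ..., "3": ..., "5": ..., "<budget>": ...},
#     "mean_auc_regret": ...,
#     "oracle_hit_rate_final": ...,
#   }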


def run_surrogate_baseline(
    task: str,
    episodes: int,
    budget: int,
    seed: int,
    measurement_path: str,
    train_task_ids: List[str] | None = None,
    acquisition: str = "ucb",
    beta: float = 1.5,
    xi: float = 0.0,
) -> Dict[str, Any]:
    """Run `episodes` surrogate-guided tuning episodes on `task` and return a JSON-serializable summary."""
    env = SoftmaxSurrogateEnvironment(
        measurement_path=measurement_path,
        budget=budget,
        seed=seed,
        train_task_ids=train_task_ids,
    )
    best_overall: Dict[str, Any] = {"latency_ms": float("inf"), "config": None, "task_id": task}
    episode_records: List[Dict[str, Any]] = []
    for episode in range(episodes):
        env.reset(task=task, seed=seed + episode)
        done = False
        episode_best = float("inf")
        episode_best_cfg: Dict[str, int] | None = None
        history: List[Dict[str, Any]] = []
        while not done:
            config_id = _choose_surrogate_action(env, acquisition=acquisition, beta=beta, xi=xi)
            out = env.step({"config_id": config_id})
            obs = out["observation"]
            trial = obs["last_trial"]
            history.append(
                {
                    "config_id": config_id,
                    "latency_ms": trial["latency_ms"],
                    "config": trial["config"],
                    "reward": out["reward"],
                    "regret": out["info"]["current_regret"],
                    "validation_mse": out["info"]["validation_mse"],
                }
            )
            # Track the best configuration measured so far within this episode.
            if obs["best_so_far_ms"] < episode_best:
                episode_best = obs["best_so_far_ms"]
                best_seen = min(env.seen_config_ids(), key=env.measured_latency_ms)
                episode_best_cfg = env.config_info(best_seen)
            done = bool(out["done"])
        if episode_best < best_overall["latency_ms"]:
            best_overall = {
                "latency_ms": float(episode_best),
                "config": episode_best_cfg,
                "task_id": task,
            }
        diagnostics = env.diagnostics()
        episode_records.append(
            {
                "task_id": task,
                "episode": episode,
                "best_latency_ms": episode_best,
                "best_config": episode_best_cfg or {},
                "final_validation_mse": diagnostics["validation_mse"],
                "final_regret": diagnostics["current_regret"],
                "history": history,
            }
        )
    return {
        "task": task,
        "method": "surrogate",
        "episodes": episodes,
        "budget": budget,
        "seed": seed,
        "train_task_ids": list(train_task_ids or []),
        "acquisition": acquisition,
        "beta": beta,
        "xi": xi,
        "oracle_best_ms": env.oracle_best()["median_ms"],
        "best_overall": best_overall,
        "aggregate_metrics": _aggregate_metrics(episode_records, budget),
        "episodes_summary": episode_records,
    }
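
# Minimal programmatic usage (task name and path are illustrative):
#
#   summary = run_surrogate_baseline(
#       task="softmax_m4096_n2048",
#       episodes=20,
#       budget=6,
#       seed=0,
#       measurement_path="data/autotune_measurements.csv",
#   )
#   print(summary["best_overall"], summary["aggregate_metrics"]["mean_auc_regret"])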


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Surrogate-guided baseline.")
    parser.add_argument("--task", default=None, help="Task ID (e.g., softmax_m4096_n2048)")
    parser.add_argument("--episodes", type=int, default=20)
    parser.add_argument("--budget", type=int, default=6)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--acquisition",
        type=str,
        choices=("mean", "ucb", "ei"),
        default="ucb",
        help="Candidate selection mode: mean, ucb, or ei.",
    )
    parser.add_argument("--beta", type=float, default=1.5, help="UCB exploration strength.")
    parser.add_argument("--xi", type=float, default=0.0, help="Expected-improvement margin.")
    parser.add_argument(
        "--train-tasks-file",
        type=Path,
        default=None,
        help="Optional JSON file containing a list of train task ids.",
    )
    parser.add_argument(
        "--measurement-path",
        type=str,
        default="data/autotune_measurements.csv",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("outputs/surrogate_baseline.json"),
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    if not args.task:
        # No task supplied: fall back to whichever task the environment picks on reset.
        env = SoftmaxSurrogateEnvironment(
            measurement_path=args.measurement_path, budget=args.budget, seed=args.seed
        )
        args.task = env.reset()["observation"]["task_id"]
    train_task_ids = None
    if args.train_tasks_file is not None:
        train_task_ids = json.loads(args.train_tasks_file.read_text(encoding="utf-8"))
    summary = run_surrogate_baseline(
        task=args.task,
        episodes=args.episodes,
        budget=args.budget,
        seed=args.seed,
        measurement_path=args.measurement_path,
        train_task_ids=train_task_ids,
        acquisition=args.acquisition,
        beta=args.beta,
        xi=args.xi,
    )
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()