#!/usr/bin/env python3
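"""Random-search baseline for the softmax surrogate tuning environment.

Each episode uniformly samples kernel configs (without replacement while
unseen configs remain) until the measurement budget is exhausted, then
writes per-episode bests plus aggregate regret metrics as JSON.

Example invocation (paths and task ID assume the repo defaults):
    python scripts/run_random_baseline.py --task softmax_m4096_n2048 \
        --episodes 20 --budget 6 --output outputs/random_baseline.json
"""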
from __future__ import annotations
import argparse
import json
from dataclasses import asdict, dataclass
import sys
from pathlib import Path
from typing import Any, Dict, List
import numpy as np
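# Make the repo root importable so `server.softmax_surrogate_environment`
# resolves when this file is run directly as a script.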
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT))
from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment
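# Per-episode record; serialized as a plain dict into the JSON summary.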
@dataclass
class RunRecord:
task_id: str
episode: int
best_latency_ms: float
best_config: Dict[str, int]
final_validation_mse: float
final_state: Dict[str, Any]
final_regret: float
history: List[Dict[str, Any]]
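# Summarize episodes: mean/median regret after k steps (k in {1, 3, 5, budget},
# capped at the budget), mean area under the regret curve, and the fraction of
# episodes that end at zero regret (i.e. the oracle-best config was found).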
def _aggregate_metrics(episode_records: List[Dict[str, Any]], budget: int) -> Dict[str, Any]:
ks = sorted(set(k for k in (1, 3, 5, budget) if k <= budget))
regrets_by_k: Dict[int, List[float]] = {k: [] for k in ks}
auc_regrets: List[float] = []
for episode in episode_records:
regrets = [float(step["regret"]) for step in episode["history"]]
if regrets:
auc_regrets.append(float(sum(regrets) / len(regrets)))
for k in ks:
if len(regrets) >= k:
regrets_by_k[k].append(regrets[k - 1])
return {
"mean_regret_at": {
str(k): float(sum(vals) / len(vals)) for k, vals in regrets_by_k.items() if vals
},
"median_regret_at": {
str(k): float(np.median(np.asarray(vals, dtype=np.float32))) for k, vals in regrets_by_k.items() if vals
},
"mean_auc_regret": float(sum(auc_regrets) / len(auc_regrets)) if auc_regrets else None,
"oracle_hit_rate_final": float(
sum(1 for episode in episode_records if float(episode["final_regret"]) == 0.0) / len(episode_records)
) if episode_records else None,
}
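# Resolve the task ID: use --task if given, otherwise reset a throwaway
# environment and take whatever task it serves by default.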
def _pick_task_from_input(args: argparse.Namespace) -> str:
if args.task:
return args.task
env = SoftmaxSurrogateEnvironment(
measurement_path=args.measurement_path,
budget=args.budget,
seed=args.seed,
)
return env.reset()["observation"]["task_id"]
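# Run `episodes` independent random-search episodes on one task, collecting
# per-step history plus the best config/latency found overall.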
def run_random_baseline(
task: str,
episodes: int,
budget: int,
seed: int,
measurement_path: str,
) -> Dict[str, Any]:
rng = np.random.default_rng(seed)
best_overall: Dict[str, Any] = {"latency_ms": float("inf"), "config": None, "task_id": task}
episode_records: List[Dict[str, Any]] = []
env = SoftmaxSurrogateEnvironment(
measurement_path=measurement_path,
budget=budget,
seed=seed,
)
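    # One environment instance is reused; each reset reseeds with seed + episode
    # so episodes are reproducible but not identical.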
for episode in range(episodes):
env.reset(task=task, seed=seed + episode)
done = False
episode_best = float("inf")
episode_best_cfg: Dict[str, int] | None = None
history: List[Dict[str, Any]] = []
while not done:
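            # Prefer configs not yet measured this episode; once every config
            # has been tried, fall back to resampling from the full space.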
            seen = set(env.seen_config_ids())
            unseen = [cid for cid in env.available_config_ids() if cid not in seen]
            choice_pool = unseen if unseen else env.available_config_ids()
config_id = int(rng.choice(choice_pool))
step_out = env.step({"config_id": config_id})
obs = step_out["observation"]
trial = obs["last_trial"]
history.append(
{
"config_id": config_id,
"latency_ms": trial["latency_ms"],
"config": trial["config"],
"reward": step_out["reward"],
"regret": step_out["info"]["current_regret"],
"validation_mse": step_out["info"]["validation_mse"],
}
)
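            # Track the episode's best measured latency and recover the config
            # that achieved it from the environment's log of seen configs.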
if obs["best_so_far_ms"] < episode_best:
episode_best = obs["best_so_far_ms"]
                seen_ids = env.seen_config_ids()
                latencies = [env.measured_latency_ms(cid) for cid in seen_ids]
                best_id = seen_ids[int(np.argmin(latencies))]
                episode_best_cfg = env.config_info(best_id)
done = bool(step_out["done"])
if episode_best < best_overall["latency_ms"]:
best_overall = {
"latency_ms": float(episode_best),
"config": episode_best_cfg,
"task_id": task,
}
diagnostics = env.diagnostics()
episode_records.append(
            asdict(RunRecord(
task_id=task,
episode=episode,
best_latency_ms=float(episode_best),
best_config=episode_best_cfg or {},
final_validation_mse=float(diagnostics["validation_mse"]),
final_state=env.state(),
final_regret=float(diagnostics["current_regret"]),
history=history,
            ))
)
return {
"task": task,
"method": "random",
"episodes": episodes,
"budget": budget,
"seed": seed,
"oracle_best_ms": env.oracle_best()["median_ms"],
"best_overall": best_overall,
"aggregate_metrics": _aggregate_metrics(episode_records, budget),
"episodes_summary": episode_records,
}
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Random baseline for surrogate environment.")
parser.add_argument("--task", default=None, help="Task ID (e.g., softmax_m4096_n2048)")
parser.add_argument("--episodes", type=int, default=20)
parser.add_argument("--budget", type=int, default=6)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument(
"--measurement-path",
type=str,
default="data/autotune_measurements.csv",
)
parser.add_argument(
"--output",
type=Path,
default=Path("outputs/random_baseline.json"),
)
return parser.parse_args()
def main() -> None:
args = parse_args()
task = _pick_task_from_input(args)
summary = run_random_baseline(
task=task,
episodes=args.episodes,
budget=args.budget,
seed=args.seed,
measurement_path=args.measurement_path,
)
args.output.parent.mkdir(parents=True, exist_ok=True)
with args.output.open("w", encoding="utf-8") as f:
json.dump(summary, f, indent=2)
print(json.dumps(summary, indent=2))
if __name__ == "__main__":
main()