# RL_Surrogate_ENV/server/softmax_surrogate_environment.py
from __future__ import annotations
import csv
import json
import math
import random
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
import numpy as np
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT))
from scripts.collect_measurements import BLOCK_SIZES, NUM_STAGES, NUM_WARPS
DEFAULT_MEASUREMENT_PATH = "data/autotune_measurements.csv"
DEFAULT_BUDGET = 6
INITIAL_DATASET_SIZE = 2
DUPLICATE_PENALTY = -1e-4
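# One-hot position of each kernel family in the surrogate feature vector.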
FAMILY_INDEX = {
"softmax": 0,
"layernorm": 1,
"grouped_gemm": 2,
"rmsnorm": 3,
"gemm": 4,
}
@dataclass(frozen=True)
class Measurement:
family_group: str
family: str
task_id: str
m: int
n: int
k: int
config_id: int
block_size: int
num_warps: int
num_stages: int
shape_json: str
config_json: str
median_ms: float
effective_gbps: float
score: float
validation_error: float
def _normalize_discrete(values: Sequence[int], value: int) -> float:
idx = list(values).index(int(value))
if len(values) == 1:
return 0.0
return 2.0 * (idx / (len(values) - 1)) - 1.0
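def _de_norm(value: float, values: Sequence[int]) -> int:
    """Inverse of _normalize_discrete: snap a normalized value in [-1, 1]
    back to the nearest entry of the discrete grid (used by
    _map_legacy_action_to_config to decode legacy continuous actions)."""
    choices = list(values)
    if len(choices) == 1:
        return int(choices[0])
    # Undo the [-1, 1] normalization and round to the nearest grid index.
    idx = round((float(value) + 1.0) / 2.0 * (len(choices) - 1))
    return int(choices[min(max(idx, 0), len(choices) - 1)])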
class SoftmaxSurrogateEnvironment:
"""
Generic discrete-action autotuning environment backed by measured GPU data.
The class name is kept for compatibility with the existing local server and
baseline scripts, but the task space is now multi-family.
"""
def __init__(
self,
measurement_path: str = DEFAULT_MEASUREMENT_PATH,
budget: int = DEFAULT_BUDGET,
seed: int = 0,
initial_samples: int = INITIAL_DATASET_SIZE,
train_task_ids: Optional[Sequence[str]] = None,
) -> None:
self.measurement_path = Path(measurement_path)
self.budget = int(budget)
self.seed = int(seed)
self.initial_samples = max(1, int(initial_samples))
self.train_task_ids = set(train_task_ids or [])
self._measurements = self._load_measurements()
self._task_ids = sorted(self._measurements.keys())
if not self._task_ids:
raise RuntimeError(
"No measurement data found. Run the measurement collectors first."
)
self._rng = random.Random(self.seed)
self._episode_counter = 0
self._task_id: Optional[str] = None
self._family: Optional[str] = None
self._episode_id: Optional[str] = None
self._task_rows: List[Measurement] = []
self._prior_rows: List[Measurement] = []
self._config_by_id: Dict[int, Measurement] = {}
self._observed_ids: List[int] = []
        self._observed_id_set: Set[int] = set()
self._observed_rows: List[Measurement] = []
self._observed_latencies: List[float] = []
self._steps_taken = 0
self._steps_remaining = 0
self._best_latency_ms = float("inf")
self._best_config_id: Optional[int] = None
self._validation_mse = float("inf")
self._surrogate_version = 0
self._surrogate_fitted_version = -1
self._surrogate_x: Optional[np.ndarray] = None
self._surrogate_y: Optional[np.ndarray] = None
self._surrogate_alpha: Optional[np.ndarray] = None
self._surrogate_k: Optional[np.ndarray] = None
self._surrogate_length_scale: float = 0.5
def reset(self, task: Optional[str] = None, seed: Optional[int] = None) -> Dict[str, Any]:
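        """Start a new episode on ``task`` (random if None), seed the history
        with a few randomly observed configs, and return the initial payload."""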
if seed is not None:
self._rng = random.Random(int(seed))
if task is None:
task = self._rng.choice(self._task_ids)
if task not in self._measurements:
raise ValueError(f"Unknown task: {task}")
rows = self._measurements[task]
self._task_id = task
self._family = rows[0].family
self._task_rows = rows
self._config_by_id = {row.config_id: row for row in rows}
self._prior_rows = self._build_prior_rows(task)
self._observed_ids = []
self._observed_id_set = set()
self._observed_rows = []
self._observed_latencies = []
self._steps_taken = 0
self._steps_remaining = self.budget
self._best_latency_ms = float("inf")
self._best_config_id = None
self._episode_counter += 1
self._episode_id = f"{task}:{self.seed}:{self._episode_counter}"
sample_count = min(self.initial_samples, len(rows))
for config_id in self._rng.sample(list(self._config_by_id.keys()), k=sample_count):
self._observe_config(config_id)
self._invalidate_surrogate()
self._validation_mse = self._compute_validation_mse()
return self._format_step_output(
observation=self._observation_payload(kind="reset"),
reward=0.0,
done=False,
info=self.diagnostics(),
)
def step(self, action: Any) -> Dict[str, Any]:
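        """Evaluate one config against the measured data, spend one unit of
        budget, and return the (observation, reward, done, state, info) payload."""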
if self._task_id is None:
raise RuntimeError("Call reset() before step().")
if self._steps_remaining <= 0:
return self._format_step_output(
observation=self._observation_payload(kind="done"),
reward=0.0,
done=True,
info=self.diagnostics(),
)
config_id = self._extract_config_id(action)
row = self._row_for_id(config_id)
prev_best = self._best_latency_ms
duplicate = config_id in self._observed_id_set
if not duplicate:
self._observe_config(config_id)
self._surrogate_version += 1
self._steps_taken += 1
self._steps_remaining -= 1
self._validation_mse = self._compute_validation_mse()
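        # Reward is the log-latency improvement of the incumbent best; a
        # re-proposed config still spends the step and earns a small flat penalty.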
        if duplicate:
            reward = DUPLICATE_PENALTY
        else:
            reward = max(0.0, math.log(prev_best) - math.log(self._best_latency_ms))
observation = self._observation_payload(
kind="step",
last_trial={
"config_id": config_id,
"config": self.config_info(config_id),
"latency_ms": row.median_ms,
"score": row.score,
"duplicate": duplicate,
},
)
return self._format_step_output(
observation=observation,
reward=reward,
done=self._steps_remaining <= 0,
info=self.diagnostics(),
)
def state(self) -> Dict[str, Any]:
if self._task_id is None:
return {"status": "uninitialized"}
return {
"episode_id": self._episode_id,
"step_count": self._steps_taken,
"task_id": self._task_id,
"family": self._family,
"tried_config_ids": list(self._observed_ids),
}
def diagnostics(self) -> Dict[str, Any]:
if self._task_id is None:
return {"status": "uninitialized"}
oracle_best_ms = self.oracle_best()["median_ms"]
regret = self._best_latency_ms / oracle_best_ms - 1.0
return {
"validation_mse": self._validation_mse,
"best_so_far_ms": self._best_latency_ms,
"oracle_best_ms": oracle_best_ms,
"current_regret": regret,
"observed_count": len(self._observed_ids),
"prior_count": len(self._prior_rows),
}
def available_tasks(self) -> List[str]:
return list(self._task_ids)
def available_config_ids(self) -> List[int]:
if self._task_id is None:
raise RuntimeError("Call reset() before accessing config ids.")
return sorted(self._config_by_id.keys())
def available_configs(self) -> List[Dict[str, Any]]:
return [self.config_info(config_id) for config_id in self.available_config_ids()]
def config_info(self, config_id: int) -> Dict[str, Any]:
row = self._row_for_id(config_id)
return {
"config_id": int(config_id),
"family": row.family,
"task_id": row.task_id,
"block_size": row.block_size,
"num_warps": row.num_warps,
"num_stages": row.num_stages,
}
def measured_latency_ms(self, config_id: int) -> float:
return self._row_for_id(config_id).median_ms
def oracle_best(self) -> Dict[str, Any]:
if self._task_id is None:
raise RuntimeError("Call reset() before querying oracle_best().")
best = min(self._task_rows, key=lambda row: row.median_ms)
return {
"config_id": best.config_id,
"family": best.family,
"task_id": best.task_id,
"block_size": best.block_size,
"num_warps": best.num_warps,
"num_stages": best.num_stages,
"median_ms": best.median_ms,
"score": best.score,
}
def predict_score(self, config_id: int) -> float:
return float(self._predict_with_uncertainty(config_id)[0])
def acquisition_score(
self,
config_id: int,
strategy: str = "ucb",
beta: float = 1.0,
xi: float = 0.0,
) -> float:
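        """Score a candidate config under the surrogate.

        - ``mean``: posterior mean (pure exploitation).
        - ``ucb``: mean + beta * sigma (upper confidence bound).
        - ``ei``: expected improvement over the best observed score,
          delta * Phi(z) + sigma * phi(z) with z = delta / sigma and
          delta = mean - best_observed - xi.
        """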
mean, sigma = self._predict_with_uncertainty(config_id)
if strategy == "mean":
return float(mean)
if strategy == "ucb":
return float(mean + float(beta) * sigma)
if strategy == "ei":
best_observed = max(row.score for row in self._observed_rows) if self._observed_rows else mean
delta = mean - best_observed - float(xi)
if sigma <= 0.0:
return float(max(delta, 0.0))
z = delta / sigma
return float(max(delta * _normal_cdf(z) + sigma * _normal_pdf(z), 0.0))
raise ValueError(f"Unknown acquisition strategy: {strategy}")
def seen_config_ids(self) -> List[int]:
return list(self._observed_ids)
def _build_prior_rows(self, current_task: str) -> List[Measurement]:
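        """Gather measurements from the configured training tasks (excluding
        the current task) so the surrogate starts with cross-task priors."""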
if not self.train_task_ids:
return []
prior_rows: List[Measurement] = []
for task_id in sorted(self.train_task_ids):
if task_id == current_task or task_id not in self._measurements:
continue
prior_rows.extend(self._measurements[task_id])
return prior_rows
def _predict_with_uncertainty(self, config_id: int) -> Tuple[float, float]:
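        """GP posterior at one config: mean ``k* @ alpha`` with
        ``alpha = K^-1 y`` and std ``sqrt(1 - k* @ K^-1 @ k*)``
        (the RBF kernel has unit prior variance)."""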
if not self._observed_rows and not self._prior_rows:
raise RuntimeError("No surrogate data available.")
self._fit_surrogate()
if self._surrogate_x is None or self._surrogate_y is None:
raise RuntimeError("Surrogate model unavailable.")
if self._surrogate_x.shape[0] == 1:
return float(self._surrogate_y[0]), 0.0
cfg = _config_to_vector(self._row_for_id(config_id)).reshape(1, -1)
if self._surrogate_k is None or self._surrogate_alpha is None:
raise RuntimeError("Surrogate model unavailable.")
k = _rbf_kernel(self._surrogate_x, cfg, self._surrogate_length_scale).reshape(-1)
pred = float(k @ self._surrogate_alpha)
solve = np.linalg.solve(self._surrogate_k, k)
var = max(0.0, float(1.0 - k @ solve))
return pred, float(math.sqrt(max(var, 1e-12)))
def _fit_surrogate(self) -> None:
if self._surrogate_fitted_version == self._surrogate_version:
return
rows = self._prior_rows + self._observed_rows
if not rows:
self._surrogate_x = None
self._surrogate_y = None
self._surrogate_alpha = None
self._surrogate_k = None
self._surrogate_fitted_version = self._surrogate_version
return
self._surrogate_x = np.array([_config_to_vector(row) for row in rows], dtype=np.float32)
self._surrogate_y = np.array([row.score for row in rows], dtype=np.float32)
if self._surrogate_x.shape[0] == 1:
self._surrogate_alpha = self._surrogate_y.copy()
self._surrogate_k = None
self._surrogate_fitted_version = self._surrogate_version
return
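        # Median heuristic: set the RBF length scale to the median pairwise
        # distance between training points, floored to avoid a degenerate fit.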
pairwise = _pairwise_sq_dists(self._surrogate_x)
triu = pairwise[np.triu_indices(self._surrogate_x.shape[0], k=1)]
med_dist = float(np.median(np.sqrt(triu))) if triu.size else 0.5
self._surrogate_length_scale = max(0.15, med_dist)
k = _rbf_kernel(self._surrogate_x, self._surrogate_x, self._surrogate_length_scale)
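        # A small diagonal jitter keeps the kernel matrix well conditioned.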
k[np.diag_indices_from(k)] += 1e-3
self._surrogate_k = k
self._surrogate_alpha = np.linalg.solve(k, self._surrogate_y)
self._surrogate_fitted_version = self._surrogate_version
def _compute_validation_mse(self) -> float:
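        """Mean squared error of the surrogate's score predictions over every
        config of the current task (reported in diagnostics)."""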
if not self._task_rows:
return float("inf")
preds = np.array(
[self._predict_with_uncertainty(config_id)[0] for config_id in self.available_config_ids()],
dtype=np.float32,
)
target = np.array([self._row_for_id(config_id).score for config_id in self.available_config_ids()], dtype=np.float32)
return float(np.mean((preds - target) ** 2))
def _observe_config(self, config_id: int) -> None:
row = self._row_for_id(config_id)
self._observed_ids.append(config_id)
self._observed_id_set.add(config_id)
self._observed_rows.append(row)
self._observed_latencies.append(row.median_ms)
if row.median_ms < self._best_latency_ms:
self._best_latency_ms = row.median_ms
self._best_config_id = config_id
def _observation_payload(
self,
kind: str,
last_trial: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
payload = {
"type": kind,
"task_id": self._task_id,
"family": self._family,
"M": self._task_rows[0].m if self._task_rows else None,
"N": self._task_rows[0].n if self._task_rows else None,
"dtype": "fp16",
"tried_config_ids": list(self._observed_ids),
"tried_latencies_ms": list(self._observed_latencies),
"best_so_far_ms": self._best_latency_ms,
"steps_remaining": self._steps_remaining,
}
if last_trial is not None:
payload["last_trial"] = last_trial
return payload
def _extract_config_id(self, action: Any) -> int:
if isinstance(action, (str, bytes)):
action = json.loads(action)
if isinstance(action, dict):
if "config_id" in action:
return int(action["config_id"])
if "x" in action:
normalized = self._extract_legacy_action(action["x"])
config = self._map_legacy_action_to_config(normalized)
return config
if isinstance(action, (int, np.integer)):
return int(action)
raise TypeError("Action must be an int config_id or dict with config_id.")
def _extract_legacy_action(self, action: Any) -> List[float]:
arr = np.clip(np.asarray(action, dtype=float), -1.0, 1.0)
if arr.shape != (3,):
raise ValueError("Legacy action vector must have 3 values.")
return arr.tolist()
def _map_legacy_action_to_config(self, action: Sequence[float]) -> int:
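        """Snap a legacy continuous action in [-1, 1]^3 to the measured config
        nearest in (block_size, num_warps, num_stages) space."""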
base = (
_de_norm(float(action[0]), BLOCK_SIZES),
_de_norm(float(action[1]), NUM_WARPS),
_de_norm(float(action[2]), NUM_STAGES),
)
best_id = min(
self.available_config_ids(),
key=lambda config_id: (
self._row_for_id(config_id).block_size - base[0]
) ** 2
+ (self._row_for_id(config_id).num_warps - base[1]) ** 2
+ (self._row_for_id(config_id).num_stages - base[2]) ** 2,
)
return int(best_id)
def _row_for_id(self, config_id: int) -> Measurement:
if config_id not in self._config_by_id:
raise ValueError(f"Unknown config_id={config_id}")
return self._config_by_id[int(config_id)]
def _invalidate_surrogate(self) -> None:
self._surrogate_version += 1
self._surrogate_fitted_version = -1
self._surrogate_x = None
self._surrogate_y = None
self._surrogate_alpha = None
self._surrogate_k = None
def _format_step_output(
self,
observation: Dict[str, Any],
reward: float,
done: bool,
info: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
return {
"observation": observation,
"reward": float(reward),
"done": bool(done),
"state": self.state(),
"info": info or {},
}
def _load_measurements(self) -> Dict[str, List[Measurement]]:
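        """Parse the measurement CSV into per-task lists of Measurement rows,
        sorted by config_id; rows lacking a config_id get a stable per-task
        fallback id."""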
if not self.measurement_path.exists():
raise FileNotFoundError(
f"Missing measurement file at {self.measurement_path}. "
"Run the measurement collectors first."
)
grouped: Dict[str, List[Measurement]] = {}
with self.measurement_path.open("r", newline="", encoding="utf-8") as handle:
reader = csv.DictReader(handle)
fallback_config_ids: Dict[str, int] = {}
for row in reader:
                # csv.DictReader yields "" (not a missing key) for empty cells,
                # so fall back with `or` rather than dict.get defaults.
                family = row.get("family") or "softmax"
                family_group = row.get("family_group") or (
                    "A" if family in {"softmax", "layernorm"} else "B"
                )
task_id = row["task_id"]
block_size = int(row["block_size"])
num_warps = int(row["num_warps"])
num_stages = int(row["num_stages"])
config_id_raw = row.get("config_id")
if config_id_raw in (None, ""):
key = f"{task_id}|{block_size}|{num_warps}|{num_stages}"
if key not in fallback_config_ids:
fallback_config_ids[key] = len([k for k in fallback_config_ids if k.startswith(f"{task_id}|")])
config_id = fallback_config_ids[key]
else:
config_id = int(config_id_raw)
measurement = Measurement(
family_group=family_group,
family=family,
task_id=task_id,
m=int(row["m"]),
n=int(row["n"]),
k=int(row.get("k", 0) or 0),
config_id=config_id,
block_size=block_size,
num_warps=num_warps,
num_stages=num_stages,
shape_json=row.get("shape_json", "{}"),
config_json=row.get("config_json", "{}"),
median_ms=float(row["median_ms"]),
effective_gbps=float(row["effective_gbps"]),
score=float(row["score"]),
validation_error=float(row["validation_error"]),
)
grouped.setdefault(task_id, []).append(measurement)
for task_id in grouped:
grouped[task_id].sort(key=lambda row: row.config_id)
return grouped
def _config_to_vector(row: Measurement) -> np.ndarray:
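    """Featurize a measurement: one-hot family, log2-scaled shape fields, and
    the three tuning knobs normalized to [-1, 1]."""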
family_vec = np.zeros(len(FAMILY_INDEX), dtype=np.float32)
if row.family in FAMILY_INDEX:
family_vec[FAMILY_INDEX[row.family]] = 1.0
shape_fields = {}
try:
shape_fields = json.loads(row.shape_json) if row.shape_json else {}
except Exception:
shape_fields = {}
mode_val = _shape_scalar(shape_fields.get("mode_id"))
role_val = _shape_scalar(shape_fields.get("role_id"))
seq_val = _shape_scalar(shape_fields.get("seq_len"))
ctx_val = _shape_scalar(shape_fields.get("ctx_len"))
group_val = _shape_scalar(shape_fields.get("group_count"))
numeric = np.array(
[
math.log2(max(row.m, 1)) / 16.0,
math.log2(max(row.n, 1)) / 16.0,
math.log2(max(row.k, 1)) / 16.0 if row.k > 0 else 0.0,
math.log2(max(seq_val, 1.0)) / 16.0 if seq_val > 0 else 0.0,
math.log2(max(ctx_val, 1.0)) / 16.0 if ctx_val > 0 else 0.0,
math.log2(max(group_val, 1.0)) / 8.0 if group_val > 0 else 0.0,
mode_val / 8.0,
role_val / 16.0,
_normalize_discrete(BLOCK_SIZES, row.block_size),
_normalize_discrete(NUM_WARPS, row.num_warps),
_normalize_discrete(NUM_STAGES, row.num_stages),
],
dtype=np.float32,
)
return np.concatenate([family_vec, numeric], axis=0)
def _pairwise_sq_dists(X: np.ndarray) -> np.ndarray:
diff = X[:, None, :] - X[None, :, :]
return np.sum(diff * diff, axis=2)
def _rbf_kernel(X: np.ndarray, Y: np.ndarray, length_scale: float) -> np.ndarray:
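    """Gaussian RBF kernel exp(-||x - y||^2 / (2 * length_scale^2)), computed
    via the expanded squared-distance identity."""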
sigma2 = float(length_scale * length_scale)
if sigma2 <= 0:
sigma2 = 1e-6
xy = X @ Y.T
x2 = np.sum(X * X, axis=1)[:, None]
y2 = np.sum(Y * Y, axis=1)[None, :]
d2 = np.maximum(x2 - 2.0 * xy + y2, 0.0)
return np.exp(-0.5 * d2 / sigma2).astype(np.float32)
def _normal_pdf(z: float) -> float:
inv_sqrt_2pi = 1.0 / math.sqrt(2.0 * math.pi)
return float(inv_sqrt_2pi * math.exp(-0.5 * z * z))
def _normal_cdf(z: float) -> float:
return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))
def _shape_scalar(value: Any) -> float:
if value is None:
return 0.0
try:
return float(value)
except (TypeError, ValueError):
return 0.0
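# Minimal smoke-test sketch (assumes a populated data/autotune_measurements.csv
# at the default path): run one episode, greedily following the UCB acquisition
# over the not-yet-tried configs.
if __name__ == "__main__":
    env = SoftmaxSurrogateEnvironment()
    out = env.reset(seed=0)
    while not out["done"]:
        tried = set(env.seen_config_ids())
        candidates = [c for c in env.available_config_ids() if c not in tried]
        if not candidates:
            break
        choice = max(candidates, key=lambda c: env.acquisition_score(c, strategy="ucb"))
        out = env.step(choice)
        trial = out["observation"]["last_trial"]
        print(f"config {trial['config_id']}: {trial['latency_ms']:.4f} ms, reward={out['reward']:.4f}")
    print("diagnostics:", env.diagnostics())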