# Multi-family GPU kernel autotuning environment backed by offline measurements.
| from __future__ import annotations | |
| import csv | |
| import json | |
| import math | |
| import random | |
| import sys | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Sequence, Tuple | |
| import numpy as np | |
# Make the repository root importable when this module is run as a script.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
# Shared discrete search-space definitions (candidate launch parameters).
from scripts.collect_measurements import BLOCK_SIZES, NUM_STAGES, NUM_WARPS
# Default location of the offline measurement CSV produced by the collectors.
DEFAULT_MEASUREMENT_PATH = "data/autotune_measurements.csv"
# Default number of agent steps allowed per episode.
DEFAULT_BUDGET = 6
# Number of random configs observed at reset() to seed the surrogate.
INITIAL_DATASET_SIZE = 2
# Small negative reward for re-trying an already observed config.
DUPLICATE_PENALTY = -1e-4
# One-hot encoding positions for each kernel family (see _config_to_vector).
FAMILY_INDEX = {
    "softmax": 0,
    "layernorm": 1,
    "grouped_gemm": 2,
    "rmsnorm": 3,
    "gemm": 4,
}
@dataclass
class Measurement:
    """One offline benchmark row: a (task, kernel config) pair and its metrics.

    Bug fix: the class declared only annotations but is instantiated with
    keyword arguments in _load_measurements, which fails without the
    @dataclass decorator generating __init__.
    """

    family_group: str  # coarse grouping of kernel families (e.g. "A"/"B")
    family: str  # kernel family name; see FAMILY_INDEX
    task_id: str
    m: int
    n: int
    k: int  # 0 for kernels with no K dimension
    config_id: int
    block_size: int
    num_warps: int
    num_stages: int
    shape_json: str  # raw JSON with extra shape metadata
    config_json: str  # raw JSON of the full launch config
    median_ms: float  # measured median latency in milliseconds
    effective_gbps: float
    score: float  # surrogate target (maximized by the acquisition functions)
    validation_error: float
| def _normalize_discrete(values: Sequence[int], value: int) -> float: | |
| idx = list(values).index(int(value)) | |
| if len(values) == 1: | |
| return 0.0 | |
| return 2.0 * (idx / (len(values) - 1)) - 1.0 | |
class SoftmaxSurrogateEnvironment:
    """
    Generic discrete-action autotuning environment backed by measured GPU data.
    The class name is kept for compatibility with the existing local server and
    baseline scripts, but the task space is now multi-family.
    """

    def __init__(
        self,
        measurement_path: str = DEFAULT_MEASUREMENT_PATH,
        budget: int = DEFAULT_BUDGET,
        seed: int = 0,
        initial_samples: int = INITIAL_DATASET_SIZE,
        train_task_ids: Optional[Sequence[str]] = None,
    ) -> None:
        # CSV of offline kernel measurements; loaded eagerly below.
        self.measurement_path = Path(measurement_path)
        # Agent steps allowed per episode.
        self.budget = int(budget)
        self.seed = int(seed)
        # Random configs observed at reset() to seed the surrogate (at least 1).
        self.initial_samples = max(1, int(initial_samples))
        # Tasks whose rows may be used as cross-task priors for the surrogate.
        self.train_task_ids = set(train_task_ids or [])
        self._measurements = self._load_measurements()
        self._task_ids = sorted(self._measurements.keys())
        if not self._task_ids:
            raise RuntimeError(
                "No measurement data found. Run the measurement collectors first."
            )
        self._rng = random.Random(self.seed)
        self._episode_counter = 0
        # --- per-episode state, populated by reset() ---
        self._task_id: Optional[str] = None
        self._family: Optional[str] = None
        self._episode_id: Optional[str] = None
        self._task_rows: List[Measurement] = []
        self._prior_rows: List[Measurement] = []
        self._config_by_id: Dict[int, Measurement] = {}
        self._observed_ids: List[int] = []
        self._observed_id_set = set()
        self._observed_rows: List[Measurement] = []
        self._observed_latencies: List[float] = []
        self._steps_taken = 0
        self._steps_remaining = 0
        self._best_latency_ms = float("inf")
        self._best_config_id: Optional[int] = None
        self._validation_mse = float("inf")
        # --- lazily fitted GP surrogate cache (see _fit_surrogate) ---
        # Version counters let predictions reuse a fit until new data arrives.
        self._surrogate_version = 0
        self._surrogate_fitted_version = -1
        self._surrogate_x: Optional[np.ndarray] = None
        self._surrogate_y: Optional[np.ndarray] = None
        self._surrogate_alpha: Optional[np.ndarray] = None
        self._surrogate_k: Optional[np.ndarray] = None
        self._surrogate_length_scale: float = 0.5
    def reset(self, task: Optional[str] = None, seed: Optional[int] = None) -> Dict[str, Any]:
        """Begin a new episode on *task* (random task when None) and return the
        initial observation envelope.

        Reseeds the episode RNG when *seed* is given, seeds the surrogate with
        `initial_samples` randomly observed configs, and computes the initial
        validation MSE.

        Raises:
            ValueError: if *task* is not present in the measurement data.
        """
        if seed is not None:
            self._rng = random.Random(int(seed))
        if task is None:
            task = self._rng.choice(self._task_ids)
        if task not in self._measurements:
            raise ValueError(f"Unknown task: {task}")
        rows = self._measurements[task]
        self._task_id = task
        self._family = rows[0].family
        self._task_rows = rows
        self._config_by_id = {row.config_id: row for row in rows}
        # Cross-task rows used as surrogate priors (never the current task).
        self._prior_rows = self._build_prior_rows(task)
        self._observed_ids = []
        self._observed_id_set = set()
        self._observed_rows = []
        self._observed_latencies = []
        self._steps_taken = 0
        self._steps_remaining = self.budget
        self._best_latency_ms = float("inf")
        self._best_config_id = None
        self._episode_counter += 1
        self._episode_id = f"{task}:{self.seed}:{self._episode_counter}"
        # Observe a few random configs so the surrogate has data to fit.
        sample_count = min(self.initial_samples, len(rows))
        for config_id in self._rng.sample(list(self._config_by_id.keys()), k=sample_count):
            self._observe_config(config_id)
        self._invalidate_surrogate()
        self._validation_mse = self._compute_validation_mse()
        return self._format_step_output(
            observation=self._observation_payload(kind="reset"),
            reward=0.0,
            done=False,
            info=self.diagnostics(),
        )
    def step(self, action: Any) -> Dict[str, Any]:
        """Try one kernel config and return (observation, reward, done, info).

        Reward is the log-latency improvement of the incumbent best (0 when no
        improvement); re-trying an already observed config costs a small fixed
        penalty and consumes budget without adding data.

        Raises:
            RuntimeError: if called before reset().
        """
        if self._task_id is None:
            raise RuntimeError("Call reset() before step().")
        if self._steps_remaining <= 0:
            # Budget exhausted: keep returning a terminal, zero-reward step.
            return self._format_step_output(
                observation=self._observation_payload(kind="done"),
                reward=0.0,
                done=True,
                info=self.diagnostics(),
            )
        config_id = self._extract_config_id(action)
        row = self._row_for_id(config_id)
        prev_best = self._best_latency_ms
        duplicate = config_id in self._observed_id_set
        if not duplicate:
            self._observe_config(config_id)
            # New data point: bump the version so the surrogate is refit.
            # NOTE(review): assumed to apply only to fresh observations (a
            # duplicate adds no data, so no refit is needed) — confirm.
            self._surrogate_version += 1
        self._steps_taken += 1
        self._steps_remaining -= 1
        self._validation_mse = self._compute_validation_mse()
        reward = DUPLICATE_PENALTY if duplicate else max(0.0, math.log(prev_best) - math.log(self._best_latency_ms))
        observation = self._observation_payload(
            kind="step",
            last_trial={
                "config_id": config_id,
                "config": self.config_info(config_id),
                "latency_ms": row.median_ms,
                "score": row.score,
                "duplicate": duplicate,
            },
        )
        return self._format_step_output(
            observation=observation,
            reward=reward,
            done=self._steps_remaining <= 0,
            info=self.diagnostics(),
        )
| def state(self) -> Dict[str, Any]: | |
| if self._task_id is None: | |
| return {"status": "uninitialized"} | |
| return { | |
| "episode_id": self._episode_id, | |
| "step_count": self._steps_taken, | |
| "task_id": self._task_id, | |
| "family": self._family, | |
| "tried_config_ids": list(self._observed_ids), | |
| } | |
| def diagnostics(self) -> Dict[str, Any]: | |
| if self._task_id is None: | |
| return {"status": "uninitialized"} | |
| oracle_best_ms = self.oracle_best()["median_ms"] | |
| regret = self._best_latency_ms / oracle_best_ms - 1.0 | |
| return { | |
| "validation_mse": self._validation_mse, | |
| "best_so_far_ms": self._best_latency_ms, | |
| "oracle_best_ms": oracle_best_ms, | |
| "current_regret": regret, | |
| "observed_count": len(self._observed_ids), | |
| "prior_count": len(self._prior_rows), | |
| } | |
| def available_tasks(self) -> List[str]: | |
| return list(self._task_ids) | |
| def available_config_ids(self) -> List[int]: | |
| if self._task_id is None: | |
| raise RuntimeError("Call reset() before accessing config ids.") | |
| return sorted(self._config_by_id.keys()) | |
| def available_configs(self) -> List[Dict[str, Any]]: | |
| return [self.config_info(config_id) for config_id in self.available_config_ids()] | |
| def config_info(self, config_id: int) -> Dict[str, Any]: | |
| row = self._row_for_id(config_id) | |
| return { | |
| "config_id": int(config_id), | |
| "family": row.family, | |
| "task_id": row.task_id, | |
| "block_size": row.block_size, | |
| "num_warps": row.num_warps, | |
| "num_stages": row.num_stages, | |
| } | |
| def measured_latency_ms(self, config_id: int) -> float: | |
| return self._row_for_id(config_id).median_ms | |
| def oracle_best(self) -> Dict[str, Any]: | |
| if self._task_id is None: | |
| raise RuntimeError("Call reset() before querying oracle_best().") | |
| best = min(self._task_rows, key=lambda row: row.median_ms) | |
| return { | |
| "config_id": best.config_id, | |
| "family": best.family, | |
| "task_id": best.task_id, | |
| "block_size": best.block_size, | |
| "num_warps": best.num_warps, | |
| "num_stages": best.num_stages, | |
| "median_ms": best.median_ms, | |
| "score": best.score, | |
| } | |
| def predict_score(self, config_id: int) -> float: | |
| return float(self._predict_with_uncertainty(config_id)[0]) | |
| def acquisition_score( | |
| self, | |
| config_id: int, | |
| strategy: str = "ucb", | |
| beta: float = 1.0, | |
| xi: float = 0.0, | |
| ) -> float: | |
| mean, sigma = self._predict_with_uncertainty(config_id) | |
| if strategy == "mean": | |
| return float(mean) | |
| if strategy == "ucb": | |
| return float(mean + float(beta) * sigma) | |
| if strategy == "ei": | |
| best_observed = max(row.score for row in self._observed_rows) if self._observed_rows else mean | |
| delta = mean - best_observed - float(xi) | |
| if sigma <= 0.0: | |
| return float(max(delta, 0.0)) | |
| z = delta / sigma | |
| return float(max(delta * _normal_cdf(z) + sigma * _normal_pdf(z), 0.0)) | |
| raise ValueError(f"Unknown acquisition strategy: {strategy}") | |
| def seen_config_ids(self) -> List[int]: | |
| return list(self._observed_ids) | |
| def _build_prior_rows(self, current_task: str) -> List[Measurement]: | |
| if not self.train_task_ids: | |
| return [] | |
| prior_rows: List[Measurement] = [] | |
| for task_id in sorted(self.train_task_ids): | |
| if task_id == current_task or task_id not in self._measurements: | |
| continue | |
| prior_rows.extend(self._measurements[task_id]) | |
| return prior_rows | |
    def _predict_with_uncertainty(self, config_id: int) -> Tuple[float, float]:
        """GP posterior (mean, stddev) of the score for one config.

        Fits (or reuses) the surrogate over prior + observed rows. With a
        single training point, the prediction is that point's score with zero
        stddev.

        Raises:
            RuntimeError: if there is no data to fit a surrogate on.
        """
        if not self._observed_rows and not self._prior_rows:
            raise RuntimeError("No surrogate data available.")
        self._fit_surrogate()
        if self._surrogate_x is None or self._surrogate_y is None:
            raise RuntimeError("Surrogate model unavailable.")
        if self._surrogate_x.shape[0] == 1:
            # Degenerate single-point fit: return it with no uncertainty.
            return float(self._surrogate_y[0]), 0.0
        cfg = _config_to_vector(self._row_for_id(config_id)).reshape(1, -1)
        if self._surrogate_k is None or self._surrogate_alpha is None:
            raise RuntimeError("Surrogate model unavailable.")
        # Cross-covariances between training points and the query config.
        k = _rbf_kernel(self._surrogate_x, cfg, self._surrogate_length_scale).reshape(-1)
        # Posterior mean: k^T K^-1 y, with alpha = K^-1 y precomputed in the fit.
        pred = float(k @ self._surrogate_alpha)
        solve = np.linalg.solve(self._surrogate_k, k)
        # Posterior variance: prior variance (1.0 for this RBF) minus explained part,
        # clamped at 0 against floating-point cancellation.
        var = max(0.0, float(1.0 - k @ solve))
        return pred, float(math.sqrt(max(var, 1e-12)))
    def _fit_surrogate(self) -> None:
        """Fit (or refresh) the RBF-kernel GP surrogate over prior + observed rows.

        Caching: a no-op while `_surrogate_fitted_version` equals
        `_surrogate_version`; callers bump the version when data changes. The
        length scale uses the median heuristic over pairwise distances (floored
        at 0.15), and a small jitter is added to the kernel diagonal for
        numerical stability.
        """
        if self._surrogate_fitted_version == self._surrogate_version:
            return
        rows = self._prior_rows + self._observed_rows
        if not rows:
            # Nothing to fit: clear cached arrays but mark this version as done.
            self._surrogate_x = None
            self._surrogate_y = None
            self._surrogate_alpha = None
            self._surrogate_k = None
            self._surrogate_fitted_version = self._surrogate_version
            return
        self._surrogate_x = np.array([_config_to_vector(row) for row in rows], dtype=np.float32)
        self._surrogate_y = np.array([row.score for row in rows], dtype=np.float32)
        if self._surrogate_x.shape[0] == 1:
            # Degenerate single-point fit: predictions return the point itself.
            self._surrogate_alpha = self._surrogate_y.copy()
            self._surrogate_k = None
            self._surrogate_fitted_version = self._surrogate_version
            return
        pairwise = _pairwise_sq_dists(self._surrogate_x)
        triu = pairwise[np.triu_indices(self._surrogate_x.shape[0], k=1)]
        # Median heuristic for the kernel length scale (0.5 fallback when all
        # points coincide, i.e. no off-diagonal entries).
        med_dist = float(np.median(np.sqrt(triu))) if triu.size else 0.5
        self._surrogate_length_scale = max(0.15, med_dist)
        k = _rbf_kernel(self._surrogate_x, self._surrogate_x, self._surrogate_length_scale)
        k[np.diag_indices_from(k)] += 1e-3  # diagonal jitter for invertibility
        self._surrogate_k = k
        self._surrogate_alpha = np.linalg.solve(k, self._surrogate_y)
        self._surrogate_fitted_version = self._surrogate_version
| def _compute_validation_mse(self) -> float: | |
| if not self._task_rows: | |
| return float("inf") | |
| preds = np.array( | |
| [self._predict_with_uncertainty(config_id)[0] for config_id in self.available_config_ids()], | |
| dtype=np.float32, | |
| ) | |
| target = np.array([self._row_for_id(config_id).score for config_id in self.available_config_ids()], dtype=np.float32) | |
| return float(np.mean((preds - target) ** 2)) | |
| def _observe_config(self, config_id: int) -> None: | |
| row = self._row_for_id(config_id) | |
| self._observed_ids.append(config_id) | |
| self._observed_id_set.add(config_id) | |
| self._observed_rows.append(row) | |
| self._observed_latencies.append(row.median_ms) | |
| if row.median_ms < self._best_latency_ms: | |
| self._best_latency_ms = row.median_ms | |
| self._best_config_id = config_id | |
| def _observation_payload( | |
| self, | |
| kind: str, | |
| last_trial: Optional[Dict[str, Any]] = None, | |
| ) -> Dict[str, Any]: | |
| payload = { | |
| "type": kind, | |
| "task_id": self._task_id, | |
| "family": self._family, | |
| "M": self._task_rows[0].m if self._task_rows else None, | |
| "N": self._task_rows[0].n if self._task_rows else None, | |
| "dtype": "fp16", | |
| "tried_config_ids": list(self._observed_ids), | |
| "tried_latencies_ms": list(self._observed_latencies), | |
| "best_so_far_ms": self._best_latency_ms, | |
| "steps_remaining": self._steps_remaining, | |
| } | |
| if last_trial is not None: | |
| payload["last_trial"] = last_trial | |
| return payload | |
| def _extract_config_id(self, action: Any) -> int: | |
| if isinstance(action, (str, bytes)): | |
| action = json.loads(action) | |
| if isinstance(action, dict): | |
| if "config_id" in action: | |
| return int(action["config_id"]) | |
| if "x" in action: | |
| normalized = self._extract_legacy_action(action["x"]) | |
| config = self._map_legacy_action_to_config(normalized) | |
| return config | |
| if isinstance(action, (int, np.integer)): | |
| return int(action) | |
| raise TypeError("Action must be an int config_id or dict with config_id.") | |
| def _extract_legacy_action(self, action: Any) -> List[float]: | |
| arr = np.clip(np.asarray(action, dtype=float), -1.0, 1.0) | |
| if arr.shape != (3,): | |
| raise ValueError("Legacy action vector must have 3 values.") | |
| return arr.tolist() | |
| def _map_legacy_action_to_config(self, action: Sequence[float]) -> int: | |
| base = ( | |
| _de_norm(float(action[0]), BLOCK_SIZES), | |
| _de_norm(float(action[1]), NUM_WARPS), | |
| _de_norm(float(action[2]), NUM_STAGES), | |
| ) | |
| best_id = min( | |
| self.available_config_ids(), | |
| key=lambda config_id: ( | |
| self._row_for_id(config_id).block_size - base[0] | |
| ) ** 2 | |
| + (self._row_for_id(config_id).num_warps - base[1]) ** 2 | |
| + (self._row_for_id(config_id).num_stages - base[2]) ** 2, | |
| ) | |
| return int(best_id) | |
| def _row_for_id(self, config_id: int) -> Measurement: | |
| if config_id not in self._config_by_id: | |
| raise ValueError(f"Unknown config_id={config_id}") | |
| return self._config_by_id[int(config_id)] | |
| def _invalidate_surrogate(self) -> None: | |
| self._surrogate_version += 1 | |
| self._surrogate_fitted_version = -1 | |
| self._surrogate_x = None | |
| self._surrogate_y = None | |
| self._surrogate_alpha = None | |
| self._surrogate_k = None | |
| def _format_step_output( | |
| self, | |
| observation: Dict[str, Any], | |
| reward: float, | |
| done: bool, | |
| info: Optional[Dict[str, Any]] = None, | |
| ) -> Dict[str, Any]: | |
| return { | |
| "observation": observation, | |
| "reward": float(reward), | |
| "done": bool(done), | |
| "state": self.state(), | |
| "info": info or {}, | |
| } | |
    def _load_measurements(self) -> Dict[str, List[Measurement]]:
        """Parse the measurement CSV into rows grouped by task id.

        Rows missing a `config_id` column get a deterministic fallback id that
        numbers distinct (block_size, num_warps, num_stages) combos per task in
        the order they appear. Each task's rows are sorted by config_id.

        Raises:
            FileNotFoundError: if the CSV has not been generated yet.
        """
        if not self.measurement_path.exists():
            raise FileNotFoundError(
                f"Missing measurement file at {self.measurement_path}. "
                "Run the measurement collectors first."
            )
        grouped: Dict[str, List[Measurement]] = {}
        with self.measurement_path.open("r", newline="", encoding="utf-8") as handle:
            reader = csv.DictReader(handle)
            fallback_config_ids: Dict[str, int] = {}
            for row in reader:
                family = row.get("family", "softmax")
                # Legacy rows predate the family_group column; infer it here.
                family_group = row.get("family_group", "A" if family in {"softmax", "layernorm"} else "B")
                task_id = row["task_id"]
                block_size = int(row["block_size"])
                num_warps = int(row["num_warps"])
                num_stages = int(row["num_stages"])
                config_id_raw = row.get("config_id")
                if config_id_raw in (None, ""):
                    # Fallback: next sequential id among combos seen for this task.
                    key = f"{task_id}|{block_size}|{num_warps}|{num_stages}"
                    if key not in fallback_config_ids:
                        fallback_config_ids[key] = len([k for k in fallback_config_ids if k.startswith(f"{task_id}|")])
                    config_id = fallback_config_ids[key]
                else:
                    config_id = int(config_id_raw)
                measurement = Measurement(
                    family_group=family_group,
                    family=family,
                    task_id=task_id,
                    m=int(row["m"]),
                    n=int(row["n"]),
                    # "k" may be absent or empty; both fall back to 0.
                    k=int(row.get("k", 0) or 0),
                    config_id=config_id,
                    block_size=block_size,
                    num_warps=num_warps,
                    num_stages=num_stages,
                    shape_json=row.get("shape_json", "{}"),
                    config_json=row.get("config_json", "{}"),
                    median_ms=float(row["median_ms"]),
                    effective_gbps=float(row["effective_gbps"]),
                    score=float(row["score"]),
                    validation_error=float(row["validation_error"]),
                )
                grouped.setdefault(task_id, []).append(measurement)
        for task_id in grouped:
            grouped[task_id].sort(key=lambda row: row.config_id)
        return grouped
def _config_to_vector(row: Measurement) -> np.ndarray:
    """Encode a measurement row as a fixed-length float32 feature vector.

    Layout: a one-hot kernel-family prefix (len(FAMILY_INDEX)) followed by
    log2-scaled shape features and the discrete launch parameters normalized
    to [-1, 1].

    Robustness fixes: the shape-JSON parse catches only the exceptions that
    json.loads can raise (the original used a bare `except Exception`), and a
    decoded payload that is not a dict is treated as empty instead of crashing
    on `.get()`.
    """
    family_vec = np.zeros(len(FAMILY_INDEX), dtype=np.float32)
    if row.family in FAMILY_INDEX:
        family_vec[FAMILY_INDEX[row.family]] = 1.0
    shape_fields: Dict[str, Any] = {}
    if row.shape_json:
        try:
            decoded = json.loads(row.shape_json)
        except (TypeError, ValueError):  # json.JSONDecodeError subclasses ValueError
            decoded = {}
        if isinstance(decoded, dict):
            shape_fields = decoded
    mode_val = _shape_scalar(shape_fields.get("mode_id"))
    role_val = _shape_scalar(shape_fields.get("role_id"))
    seq_val = _shape_scalar(shape_fields.get("seq_len"))
    ctx_val = _shape_scalar(shape_fields.get("ctx_len"))
    group_val = _shape_scalar(shape_fields.get("group_count"))
    numeric = np.array(
        [
            math.log2(max(row.m, 1)) / 16.0,
            math.log2(max(row.n, 1)) / 16.0,
            math.log2(max(row.k, 1)) / 16.0 if row.k > 0 else 0.0,
            math.log2(max(seq_val, 1.0)) / 16.0 if seq_val > 0 else 0.0,
            math.log2(max(ctx_val, 1.0)) / 16.0 if ctx_val > 0 else 0.0,
            math.log2(max(group_val, 1.0)) / 8.0 if group_val > 0 else 0.0,
            mode_val / 8.0,
            role_val / 16.0,
            _normalize_discrete(BLOCK_SIZES, row.block_size),
            _normalize_discrete(NUM_WARPS, row.num_warps),
            _normalize_discrete(NUM_STAGES, row.num_stages),
        ],
        dtype=np.float32,
    )
    return np.concatenate([family_vec, numeric], axis=0)
| def _pairwise_sq_dists(X: np.ndarray) -> np.ndarray: | |
| diff = X[:, None, :] - X[None, :, :] | |
| return np.sum(diff * diff, axis=2) | |
| def _rbf_kernel(X: np.ndarray, Y: np.ndarray, length_scale: float) -> np.ndarray: | |
| sigma2 = float(length_scale * length_scale) | |
| if sigma2 <= 0: | |
| sigma2 = 1e-6 | |
| xy = X @ Y.T | |
| x2 = np.sum(X * X, axis=1)[:, None] | |
| y2 = np.sum(Y * Y, axis=1)[None, :] | |
| d2 = np.maximum(x2 - 2.0 * xy + y2, 0.0) | |
| return np.exp(-0.5 * d2 / sigma2).astype(np.float32) | |
| def _normal_pdf(z: float) -> float: | |
| inv_sqrt_2pi = 1.0 / math.sqrt(2.0 * math.pi) | |
| return float(inv_sqrt_2pi * math.exp(-0.5 * z * z)) | |
| def _normal_cdf(z: float) -> float: | |
| return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0))) | |
| def _shape_scalar(value: Any) -> float: | |
| if value is None: | |
| return 0.0 | |
| try: | |
| return float(value) | |
| except (TypeError, ValueError): | |
| return 0.0 | |