File size: 15,177 Bytes
16dd578 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 | """
Shared evaluator for GPU Mode Triton kernel optimization.
No @triton.jit requirement — pure PyTorch submissions are allowed.
Supports local GPU and Modal cloud GPU evaluation.
Set GPUMODE_USE_MODAL=true to evaluate on Modal; GPUMODE_MODAL_GPU selects the
Modal GPU (H100, A100, L40S, T4 or H200; default H100).
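Scoring: combined_score = SCORE_SCALE / geom_mean_us (higher is better), where
geom_mean_us is the geometric mean of the per-benchmark-case mean runtimes, in
microseconds. geom_mean_us is also reported as a metric for absolute runtime tracking.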
Each problem provides a reference.py module with:
- ref_kernel(data)
- generate_input(**kwargs)
- check_implementation(data, output) -> (bool, str)
- TEST_CASES: list of dicts
- BENCHMARK_CASES: list of dicts
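- SCORE_SCALE: float (optional; defaults to 3000.0)
- MODAL_REFERENCE_CODE: required only for Modal evaluation (forwarded to the
  remote evaluator as reference_code)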
Optional benchmark configuration in reference.py:
- BENCH_USE_CUDA_EVENTS: bool (default True)
- BENCH_REL_ERROR: float (default 0.001)
- BENCH_WALL_TIMEOUT_NS: float or None (default 120e9)
- BENCH_NO_GRAD: bool (default False)
- BENCH_MAX_REPEATS: int (default 100)
- BENCH_MAX_TIME_NS: float (default 10e9)
- BENCH_WARMUP_STYLE: str ('tiny_benchmark' or 'timed_calls', default 'tiny_benchmark')
"""
import os
import sys
import copy
import time
import math
import contextlib
import dataclasses
import traceback
import importlib.util
import torch
import torch.cuda
from skydiscover.evaluation.evaluation_result import EvaluationResult
# Import problem-specific reference (the problem dir is already on sys.path
# because SkyDiscover adds the evaluator file's directory before loading it).
import reference
# ---------------------------------------------------------------------------
# Environment configuration
# ---------------------------------------------------------------------------
# Modal evaluation is opt-in: only the value "true" (any letter case) enables it.
USE_MODAL = os.getenv("GPUMODE_USE_MODAL", "false").lower() == "true"
# GPU type requested for Modal runs.
MODAL_GPU = os.getenv("GPUMODE_MODAL_GPU", "H100")


def _reference_setting(name, default):
    """Return ``reference.<name>`` when the problem module defines it, else *default*."""
    return getattr(reference, name, default)


# Per-problem overrides read from reference.py, each with its fallback default.
SCORE_SCALE = _reference_setting('SCORE_SCALE', 3000.0)
BENCH_USE_CUDA_EVENTS = _reference_setting('BENCH_USE_CUDA_EVENTS', True)
BENCH_REL_ERROR = _reference_setting('BENCH_REL_ERROR', 0.001)
BENCH_WALL_TIMEOUT_NS = _reference_setting('BENCH_WALL_TIMEOUT_NS', 120e9)
BENCH_NO_GRAD = _reference_setting('BENCH_NO_GRAD', False)
BENCH_MAX_REPEATS = _reference_setting('BENCH_MAX_REPEATS', 100)
BENCH_MAX_TIME_NS = _reference_setting('BENCH_MAX_TIME_NS', 10e9)
BENCH_WARMUP_STYLE = _reference_setting('BENCH_WARMUP_STYLE', 'tiny_benchmark')
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _clone(data):
    """Recursively copy benchmark input so the checker sees untouched data.

    Containers are rebuilt with cloned elements. This covers tuples
    (namedtuples keep their type), lists and dicts. Tensors are copied with
    ``Tensor.clone()``. Dataclass instances are re-constructed from their
    cloned fields, and ``nn.Module`` instances are deep-copied. Any other
    object is returned as-is, so it is shared with the original.

    Args:
        data: Arbitrarily nested input, e.g. from ``reference.generate_input``.

    Returns:
        A structure of the same shape and container types whose tensors do
        not alias the tensors in *data*.
    """
    if isinstance(data, tuple):
        items = [_clone(x) for x in data]
        # namedtuples are tuple subclasses; rebuilding them as a plain tuple
        # would break attribute access (data.field) on the clone.
        if hasattr(type(data), "_make") and hasattr(type(data), "_fields"):
            return type(data)._make(items)
        return tuple(items)
    if isinstance(data, list):
        return [_clone(x) for x in data]
    if isinstance(data, dict):
        return {k: _clone(v) for k, v in data.items()}
    if isinstance(data, torch.Tensor):
        return data.clone()
    if dataclasses.is_dataclass(data) and not isinstance(data, type):
        fields = {f.name: _clone(getattr(data, f.name)) for f in dataclasses.fields(data)}
        return type(data)(**fields)
    if isinstance(data, torch.nn.Module):
        cloned = copy.deepcopy(data)
        # Explicitly carry over seq_len when present.
        # NOTE(review): presumably some problem modules need this after
        # deepcopy; confirm whether it is still required.
        if hasattr(data, 'seq_len'):
            cloned.seq_len = data.seq_len
        return cloned
    return data
def _stats(durations):
    """Summarize a sequence of timings given in nanoseconds.

    Returns a dict with keys ``runs`` (sample count), ``mean``, ``std``
    (sample standard deviation, n-1 denominator) and ``err`` (standard error
    of the mean). ``std`` and ``err`` are 0.0 for a single sample.
    """
    count = len(durations)
    mean = sum(durations) / count
    if count <= 1:
        return {"runs": count, "mean": mean, "std": 0.0, "err": 0.0}
    squared_dev = sum((d - mean) ** 2 for d in durations)
    std = math.sqrt(squared_dev / (count - 1))
    return {"runs": count, "mean": mean, "std": std, "err": std / math.sqrt(count)}
def _warmup(kernel_fn, bench_args):
    """Run the kernel untimed so compilation/autotuning happens before timing.

    The strategy is selected by BENCH_WARMUP_STYLE:

    * ``'timed_calls'``: call the kernel back-to-back on one generated input
      for ~200 ms of host time, then synchronize.
    * anything else (default ``'tiny_benchmark'``): run ``_bench_single`` on
      the case with a 100 ms (``10e7`` ns) budget of measured kernel time.
      Its result is discarded.

    Both strategies use the grad mode selected by BENCH_NO_GRAD, matching
    the timed runs in ``_bench_single``.

    Args:
        kernel_fn: Submission callable taking the generated input.
        bench_args: Keyword arguments for ``reference.generate_input``.
    """
    if BENCH_WARMUP_STYLE == 'timed_calls':
        # MLA-style: run repeatedly for 200 ms of host time.
        data = reference.generate_input(**bench_args)
        # Warm up under the same grad mode as the timed runs; grad mode can
        # change which code path (or compiled graph) the kernel takes.
        ctx = torch.no_grad() if BENCH_NO_GRAD else contextlib.nullcontext()
        deadline = time.perf_counter() + 0.2
        with ctx:
            while time.perf_counter() < deadline:
                kernel_fn(data)
        torch.cuda.synchronize()
    else:
        # trimul-style: a short benchmark with a 100 ms budget (10e7 ns).
        _bench_single(kernel_fn, bench_args, max_time_ns=10e7)
def _bench_single(kernel_fn, bench_args, max_time_ns=None):
    """Benchmark a kernel on a single case.

    Generates one input from ``bench_args`` and checks the kernel's output
    once against ``reference.check_implementation``. It then times repeated
    calls on that same input object until a stopping rule fires:

    * the standard error of the mean falls below ``BENCH_REL_ERROR`` times
      the mean;
    * the summed measured time (``mean * runs``) exceeds ``max_time_ns``;
    * wall-clock time since timing began exceeds ``BENCH_WALL_TIMEOUT_NS``
      (when it is not None);
    * ``BENCH_MAX_REPEATS`` iterations have run.

    The early-stop rules are only evaluated from the third sample onward
    (``i > 1``). Exceptions raised by ``kernel_fn`` or the reference helpers
    are not caught here.

    Args:
        kernel_fn: Submission callable taking the generated input.
        bench_args: Keyword arguments for ``reference.generate_input``.
        max_time_ns: Budget for the summed measured runtime in nanoseconds;
            defaults to ``BENCH_MAX_TIME_NS``.

    Returns (stats_dict_or_None, error_str_or_None).
    Stats dict (from ``_stats``) has durations in nanoseconds; it is None
    when the correctness check fails, in which case the error string is set.
    """
    if max_time_ns is None:
        max_time_ns = BENCH_MAX_TIME_NS
    data = reference.generate_input(**bench_args)
    # Pristine copy for the checker, in case the kernel mutates `data`.
    data_copy = _clone(data)
    # Correctness check first
    ctx = torch.no_grad() if BENCH_NO_GRAD else contextlib.nullcontext()
    with ctx:
        output = kernel_fn(data)
    torch.cuda.synchronize()
    passed, msg = reference.check_implementation(data_copy, output)
    if not passed:
        return None, f"Benchmark correctness: {msg}"
    del output
    # Timed runs — durations in nanoseconds.
    # NOTE: every call reuses `data`; a kernel that mutates its input in
    # place sees the mutated values on later runs.
    durations_ns = []
    bm_start = time.perf_counter_ns()
    with ctx:
        for i in range(BENCH_MAX_REPEATS):
            # Drain previously queued GPU work so it is not attributed to this run.
            torch.cuda.synchronize()
            if BENCH_USE_CUDA_EVENTS:
                s = torch.cuda.Event(enable_timing=True)
                e = torch.cuda.Event(enable_timing=True)
                s.record()
                output = kernel_fn(data)
                e.record()
                # Both events must have completed before elapsed_time is read.
                torch.cuda.synchronize()
                duration_ns = s.elapsed_time(e) * 1e6  # ms -> ns
            else:
                # Host wall-clock timing; the synchronize makes it include GPU work.
                start_ns = time.perf_counter_ns()
                output = kernel_fn(data)
                torch.cuda.synchronize()
                duration_ns = time.perf_counter_ns() - start_ns
            del output
            durations_ns.append(duration_ns)
            if i > 1:
                st = _stats(durations_ns)
                # Stop once the mean is precise enough...
                if st["mean"] > 0 and st["err"] / st["mean"] < BENCH_REL_ERROR:
                    break
                # ...or the measured-time budget is spent...
                if st["mean"] * st["runs"] > max_time_ns:
                    break
                # ...or the overall wall-clock cap is hit.
                if BENCH_WALL_TIMEOUT_NS is not None and \
                        (time.perf_counter_ns() - bm_start) > BENCH_WALL_TIMEOUT_NS:
                    break
    return _stats(durations_ns), None
# ---------------------------------------------------------------------------
# Modal path
# ---------------------------------------------------------------------------
def _evaluate_modal(submission_code):
    """Evaluate a submission remotely on a Modal GPU.

    The submission source, ``reference.MODAL_REFERENCE_CODE`` and all test
    and benchmark settings are forwarded to the ``modal_eval`` function for
    the GPU named by GPUMODE_MODAL_GPU. The name is matched
    case-insensitively; unsupported names fall back to H100. The
    ``hardware`` artifact records the GPU that was actually used.

    Args:
        submission_code: Source text of the submission.

    Returns:
        EvaluationResult built from the dict returned by the remote call. A
        zero-score result is returned if MODAL_REFERENCE_CODE is missing or
        the remote call returns something other than a dict.

    Raises:
        Any exception from importing ``modal_eval`` or from the remote call
        propagates to the caller.
    """
    # Check the cheap precondition before importing modal_eval.
    ref_code = getattr(reference, 'MODAL_REFERENCE_CODE', None)
    if ref_code is None:
        return EvaluationResult(
            metrics={"combined_score": 0.0, "correctness": 0.0},
            artifacts={"error": "MODAL_REFERENCE_CODE not defined in reference.py",
                       "failure_stage": "modal_setup"},
        )

    # modal_eval lives next to this evaluator file.
    parent_dir = os.path.dirname(os.path.abspath(__file__))
    if parent_dir not in sys.path:
        sys.path.insert(0, parent_dir)
    from modal_eval import (
        eval_triton_h100, eval_triton_a100, eval_triton_l40s, eval_triton_t4,
        eval_triton_h200, app as modal_app,
    )
    gpu_fns = {
        "H100": eval_triton_h100,
        "A100": eval_triton_a100,
        "L40S": eval_triton_l40s,
        "T4": eval_triton_t4,
        "H200": eval_triton_h200,
    }

    # Normalize so e.g. "a100" selects A100 instead of silently using H100.
    gpu_name = MODAL_GPU.strip().upper()
    gpu_fallback = gpu_name not in gpu_fns
    if gpu_fallback:
        gpu_name = "H100"
    eval_fn = gpu_fns[gpu_name]

    with modal_app.run():
        result = eval_fn.remote(
            submission_code=submission_code,
            reference_code=ref_code,
            test_cases=reference.TEST_CASES,
            benchmark_cases=reference.BENCHMARK_CASES,
            score_scale=SCORE_SCALE,
            bench_use_cuda_events=BENCH_USE_CUDA_EVENTS,
            bench_rel_error=BENCH_REL_ERROR,
            bench_wall_timeout_ns=BENCH_WALL_TIMEOUT_NS,
            bench_no_grad=BENCH_NO_GRAD,
            bench_max_repeats=BENCH_MAX_REPEATS,
            bench_max_time_ns=BENCH_MAX_TIME_NS,
            bench_warmup_style=BENCH_WARMUP_STYLE,
        )

    if not isinstance(result, dict):
        return EvaluationResult(
            metrics={"combined_score": 0.0, "correctness": 0.0},
            artifacts={"error": "Modal returned unexpected type", "failure_stage": "modal_eval"},
        )

    error = result.get("error")
    metrics = {
        "combined_score": float(result.get("combined_score", 0.0)),
        "correctness": float(result.get("correctness", 0.0)),
    }
    if "geom_mean_us" in result:
        metrics["geom_mean_us"] = float(result["geom_mean_us"])
    artifacts = {}
    if error:
        artifacts["error"] = str(error)
        artifacts["failure_stage"] = "modal_eval"
    for i, us in enumerate(result.get("bench_means_us") or ()):
        artifacts[f"bench_{i}_mean_us"] = f"{us:.2f}"
    # Report the GPU that actually ran the evaluation, not just the request.
    artifacts["hardware"] = gpu_name
    if gpu_fallback:
        artifacts["hardware_warning"] = (
            f"Unsupported GPUMODE_MODAL_GPU={MODAL_GPU!r}; fell back to H100"
        )
    return EvaluationResult(metrics=metrics, artifacts=artifacts)
# ---------------------------------------------------------------------------
# Local path
# ---------------------------------------------------------------------------
def _evaluate_local(program_path):
    """Import, verify and benchmark a submission on the local GPU.

    Steps:
      1. Import *program_path* as module ``submission`` and fetch
         ``custom_kernel``.
      2. Run every ``reference.TEST_CASES`` entry through the kernel and
         ``reference.check_implementation``.
      3. Warm up on the first benchmark case, then time each
         ``reference.BENCHMARK_CASES`` entry with ``_bench_single``.
      4. Score = SCORE_SCALE / geometric mean of the per-case mean
         runtimes in microseconds.

    Exceptions raised while importing, testing, warming up or benchmarking
    are converted into an EvaluationResult with ``combined_score == 0.0``
    and a ``failure_stage`` artifact.

    Args:
        program_path: Path to the submission's Python source file.

    Returns:
        EvaluationResult with ``combined_score`` and ``correctness`` metrics,
        plus ``geom_mean_us`` on success.
    """
    try:
        spec = importlib.util.spec_from_file_location("submission", program_path)
        mod = importlib.util.module_from_spec(spec)
        # Register before executing so code in the submission that looks up
        # its own module in sys.modules works as with a normal import.
        sys.modules["submission"] = mod
        spec.loader.exec_module(mod)
        custom_kernel = mod.custom_kernel
    except Exception as exc:
        return EvaluationResult(
            metrics={"combined_score": 0.0, "correctness": 0.0},
            artifacts={
                "error": f"Failed to load submission: {exc}",
                "traceback": traceback.format_exc(),
                "failure_stage": "import",
            },
        )

    # Correctness: every test case must pass before anything is timed.
    for i, tc in enumerate(reference.TEST_CASES):
        try:
            data = reference.generate_input(**tc)
            # The checker gets an untouched copy in case the kernel mutates data.
            data_copy = _clone(data)
            torch.cuda.synchronize()
            output = custom_kernel(data)
            torch.cuda.synchronize()
            passed, msg = reference.check_implementation(data_copy, output)
        except Exception as exc:
            return EvaluationResult(
                metrics={"combined_score": 0.0, "correctness": 0.0},
                artifacts={
                    "error": f"Test {i} error: {exc}",
                    "traceback": traceback.format_exc(),
                    "failure_stage": "correctness",
                    "test_index": str(i),
                },
            )
        if not passed:
            return EvaluationResult(
                metrics={"combined_score": 0.0, "correctness": 0.0},
                artifacts={
                    "error": f"Test {i} failed: {msg}",
                    "failure_stage": "correctness",
                    "test_index": str(i),
                },
            )

    bench_cases = list(reference.BENCHMARK_CASES)
    if not bench_cases:
        return EvaluationResult(
            metrics={"combined_score": 0.0, "correctness": 1.0},
            artifacts={
                "error": "reference.BENCHMARK_CASES is empty; nothing to score",
                "failure_stage": "benchmark",
            },
        )

    # Warmup + benchmarks. The submission can still fail here (e.g. OOM or a
    # compile error on larger shapes), so exceptions become a failed result
    # instead of escaping to the caller.
    bench_means_ns = []
    bench_index = None  # None while warming up
    try:
        _warmup(custom_kernel, bench_cases[0])
        for bench_index, bench_args in enumerate(bench_cases):
            st, err = _bench_single(custom_kernel, bench_args)
            if err:
                return EvaluationResult(
                    metrics={"combined_score": 0.0, "correctness": 1.0},
                    artifacts={"error": err, "failure_stage": "benchmark"},
                )
            bench_means_ns.append(st["mean"])
    except Exception as exc:
        return EvaluationResult(
            metrics={"combined_score": 0.0, "correctness": 1.0},
            artifacts={
                "error": f"Benchmark error: {exc}",
                "traceback": traceback.format_exc(),
                "failure_stage": "benchmark",
                "bench_index": "warmup" if bench_index is None else str(bench_index),
            },
        )

    # A zero mean runtime would make the score a division by zero.
    if any(ns <= 0 for ns in bench_means_ns):
        return EvaluationResult(
            metrics={"combined_score": 0.0, "correctness": 1.0},
            artifacts={
                "error": "Non-positive mean benchmark runtime; cannot compute score",
                "failure_stage": "benchmark",
            },
        )

    # Geometric mean computed in log space: a product of many small runtimes
    # (in seconds) can underflow to 0.0.
    log_mean_ns = math.fsum(math.log(ns) for ns in bench_means_ns) / len(bench_means_ns)
    geom_mean_us = math.exp(log_mean_ns) / 1e3
    score = SCORE_SCALE / geom_mean_us
    metrics = {
        "combined_score": score,
        "correctness": 1.0,
        "geom_mean_us": geom_mean_us,
    }
    artifacts = {
        "hardware": "local",
    }
    for i, ns in enumerate(bench_means_ns):
        artifacts[f"bench_{i}_mean_us"] = f"{ns / 1e3:.2f}"
    return EvaluationResult(
        metrics=metrics,
        artifacts=artifacts,
    )
# ---------------------------------------------------------------------------
# Public API (used by SkyDiscover)
# ---------------------------------------------------------------------------
def evaluate(program_path):
    """Full evaluation entry point used by SkyDiscover.

    Reads the submission at *program_path*, then evaluates it on Modal when
    USE_MODAL is set, otherwise on the local GPU. Any exception is converted
    into a zero-score EvaluationResult carrying ``error``, ``traceback`` and
    ``failure_stage`` artifacts, so callers always get a result.

    Args:
        program_path: Path to the submission's Python source file.

    Returns:
        EvaluationResult from the selected evaluation path.
    """
    import tokenize

    try:
        # tokenize.open decodes the file as the import system would
        # (PEP 263 coding cookie / BOM, UTF-8 by default), not per locale.
        with tokenize.open(program_path) as f:
            code = f.read()
    except Exception as exc:
        return EvaluationResult(
            metrics={"combined_score": 0.0, "correctness": 0.0},
            artifacts={"error": f"Failed to read file: {exc}", "failure_stage": "file_read"},
        )
    if USE_MODAL:
        try:
            return _evaluate_modal(code)
        except Exception as exc:
            return EvaluationResult(
                metrics={"combined_score": 0.0, "correctness": 0.0},
                artifacts={
                    "error": f"Modal evaluation failed: {exc}",
                    "traceback": traceback.format_exc(),
                    "failure_stage": "modal_eval",
                },
            )
    # Same top-level safety net as the Modal path: never raise to the caller.
    try:
        return _evaluate_local(program_path)
    except Exception as exc:
        return EvaluationResult(
            metrics={"combined_score": 0.0, "correctness": 0.0},
            artifacts={
                "error": f"Local evaluation failed: {exc}",
                "traceback": traceback.format_exc(),
                "failure_stage": "local_eval",
            },
        )
def evaluate_stage1(program_path):
    """Cheap pre-screen of a submission before the full evaluation.

    Checks, in order:
      1. the file can be read;
      2. its text mentions ``custom_kernel``;
      3. it compiles;
      4. in local mode only, it imports and defines ``custom_kernel``.

    The import check is skipped when USE_MODAL is set, because the
    submission's dependencies (e.g. triton) may not be installed locally.

    Args:
        program_path: Path to the submission's Python source file.

    Returns:
        EvaluationResult with ``combined_score=0.5, stage1_passed=1.0`` when
        every check passes. Otherwise both metrics are ``0.0``, with
        ``error`` and ``failure_stage`` artifacts.
    """
    import tokenize

    def _fail(artifacts):
        # Uniform failure result for every stage-1 check.
        return EvaluationResult(
            metrics={"combined_score": 0.0, "stage1_passed": 0.0},
            artifacts=artifacts,
        )

    try:
        # Decode the file the way the import system does (PEP 263 / BOM).
        with tokenize.open(program_path) as f:
            code = f.read()
    except Exception as exc:
        return _fail({"error": f"Failed to read file: {exc}", "failure_stage": "file_read"})

    if "custom_kernel" not in code:
        return _fail({"error": "Missing custom_kernel function", "failure_stage": "validation"})

    try:
        compile(code, program_path, "exec")
    except SyntaxError as exc:
        return _fail({
            "error": f"Syntax error at line {exc.lineno}: {exc.msg}",
            "failure_stage": "syntax_check",
        })
    except ValueError as exc:
        # Before Python 3.12, source containing null bytes raises ValueError.
        return _fail({"error": f"Invalid source: {exc}", "failure_stage": "syntax_check"})

    # When using Modal, skip local import check (triton may not be installed locally).
    if not USE_MODAL:
        module_name = "submission_check"
        missing = object()
        previous = sys.modules.get(module_name, missing)
        try:
            spec = importlib.util.spec_from_file_location(module_name, program_path)
            mod = importlib.util.module_from_spec(spec)
            # Register before executing, as the full evaluation does, so code
            # that looks up its own module in sys.modules (e.g. dataclasses
            # under `from __future__ import annotations`) behaves identically.
            sys.modules[module_name] = mod
            spec.loader.exec_module(mod)
            has_kernel = hasattr(mod, "custom_kernel")
        except Exception as exc:
            return _fail({
                "error": f"Import failed: {exc}",
                "traceback": traceback.format_exc(),
                "failure_stage": "import",
            })
        finally:
            # Do not leave the throwaway check module registered.
            if previous is missing:
                sys.modules.pop(module_name, None)
            else:
                sys.modules[module_name] = previous
        if not has_kernel:
            return _fail({"error": "custom_kernel not found after import", "failure_stage": "import"})

    return EvaluationResult(
        metrics={"combined_score": 0.5, "stage1_passed": 1.0},
        artifacts={},
    )
def evaluate_stage2(program_path):
    """Stage-2 entry point of the public API: run the full evaluation.

    Does exactly what ``evaluate`` does for *program_path* and returns its
    EvaluationResult unchanged.
    """
    full_result = evaluate(program_path)
    return full_result
|