| """ |
| Shared evaluator for GPU Mode Triton kernel optimization. |
| |
| No @triton.jit requirement — pure PyTorch submissions are allowed. |
| Supports local GPU and Modal cloud GPU evaluation. |
| Set GPUMODE_USE_MODAL=true and GPUMODE_MODAL_GPU=H100 for Modal. |
| |
| Scoring: combined_score = SCORE_SCALE / geom_mean_us (higher is better). |
| The geom_mean_us metric is also reported for absolute runtime tracking. |
| |
| Each problem provides a reference.py module with: |
| - ref_kernel(data) |
| - generate_input(**kwargs) |
| - check_implementation(data, output) -> (bool, str) |
| - TEST_CASES: list of dicts |
| - BENCHMARK_CASES: list of dicts |
| - SCORE_SCALE: float |
| |
| Optional benchmark configuration in reference.py: |
| - BENCH_USE_CUDA_EVENTS: bool (default True) |
| - BENCH_REL_ERROR: float (default 0.001) |
| - BENCH_WALL_TIMEOUT_NS: float or None (default 120e9) |
| - BENCH_NO_GRAD: bool (default False) |
| - BENCH_MAX_REPEATS: int (default 100) |
| - BENCH_MAX_TIME_NS: float (default 10e9) |
| - BENCH_WARMUP_STYLE: str ('tiny_benchmark' or 'timed_calls', default 'tiny_benchmark') |
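
A minimal reference.py sketch (illustrative only; the vector-add task and
the tolerances below are invented for this example, not taken from any
real problem):

    import torch

    def generate_input(n=1024):
        return (torch.randn(n, device='cuda'),
                torch.randn(n, device='cuda'))

    def ref_kernel(data):
        a, b = data
        return a + b

    def check_implementation(data, output):
        expected = ref_kernel(data)
        if torch.allclose(output, expected, rtol=1e-3, atol=1e-3):
            return True, ''
        return False, 'output mismatch'

    TEST_CASES = [{'n': 1024}, {'n': 4096}]
    BENCHMARK_CASES = [{'n': 1 << 20}]
    SCORE_SCALE = 3000.0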
| """ |

import os
import sys
import copy
import time
import math
import contextlib
import dataclasses
import traceback
import importlib.util

import torch
import torch.cuda

from skydiscover.evaluation.evaluation_result import EvaluationResult

# Problem definition; reference.py must be importable (for example, from the
# same directory as this evaluator).
import reference

USE_MODAL = os.environ.get("GPUMODE_USE_MODAL", "false").lower() == "true"
MODAL_GPU = os.environ.get("GPUMODE_MODAL_GPU", "H100")

# Benchmark configuration; reference.py may override any of these defaults.
SCORE_SCALE = getattr(reference, 'SCORE_SCALE', 3000.0)
BENCH_USE_CUDA_EVENTS = getattr(reference, 'BENCH_USE_CUDA_EVENTS', True)
BENCH_REL_ERROR = getattr(reference, 'BENCH_REL_ERROR', 0.001)
BENCH_WALL_TIMEOUT_NS = getattr(reference, 'BENCH_WALL_TIMEOUT_NS', 120e9)
BENCH_NO_GRAD = getattr(reference, 'BENCH_NO_GRAD', False)
BENCH_MAX_REPEATS = getattr(reference, 'BENCH_MAX_REPEATS', 100)
BENCH_MAX_TIME_NS = getattr(reference, 'BENCH_MAX_TIME_NS', 10e9)
BENCH_WARMUP_STYLE = getattr(reference, 'BENCH_WARMUP_STYLE', 'tiny_benchmark')


def _clone(data):
    """Recursively clone data, handling tensors, dataclasses, and nn.Modules."""
    if isinstance(data, tuple):
        return tuple(_clone(x) for x in data)
    if isinstance(data, list):
        return [_clone(x) for x in data]
    if isinstance(data, dict):
        return {k: _clone(v) for k, v in data.items()}
    if isinstance(data, torch.Tensor):
        return data.clone()
    if dataclasses.is_dataclass(data) and not isinstance(data, type):
        fields = {f.name: _clone(getattr(data, f.name)) for f in dataclasses.fields(data)}
        return type(data)(**fields)
    if isinstance(data, torch.nn.Module):
        cloned = copy.deepcopy(data)
        # Carry over seq_len explicitly; some problems attach it to the
        # module after construction.
        if hasattr(data, 'seq_len'):
            cloned.seq_len = data.seq_len
        return cloned
    # Scalars, strings, and other immutables are returned as-is.
    return data
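
# Illustration (comment only, not executed): _clone((torch.ones(2), {'n': 3}))
# clones the tensor and rebuilds the containers while passing the int through,
# so in-place mutation by a kernel cannot corrupt the copy later handed to
# check_implementation.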


def _stats(durations):
    """Compute sample statistics for a list of durations (nanoseconds).

    Returns a dict with the run count, mean, sample standard deviation,
    and standard error of the mean (std / sqrt(n)).
    """
    n = len(durations)
    avg = sum(durations) / n
    if n > 1:
        var = sum((x - avg) ** 2 for x in durations) / (n - 1)
        std = math.sqrt(var)
        err = std / math.sqrt(n)
    else:
        std, err = 0.0, 0.0
    return {"runs": n, "mean": avg, "std": std, "err": err}


def _warmup(kernel_fn, bench_args):
    """Warm up the kernel so Triton compilation happens outside the timed runs."""
    if BENCH_WARMUP_STYLE == 'timed_calls':
        # Call the kernel repeatedly for a fixed 0.2 s wall-clock budget.
        data = reference.generate_input(**bench_args)
        start = time.perf_counter()
        while time.perf_counter() - start < 0.2:
            kernel_fn(data)
            torch.cuda.synchronize()
    else:
        # 'tiny_benchmark': run a short throwaway benchmark capped at 1e8 ns
        # (0.1 s) of measured time and discard the result.
        _bench_single(kernel_fn, bench_args, max_time_ns=1e8)


def _bench_single(kernel_fn, bench_args, max_time_ns=None):
    """Benchmark a kernel on a single case.

    Returns (stats_dict_or_None, error_str_or_None); the stats dict holds
    durations in nanoseconds.
    """
    if max_time_ns is None:
        max_time_ns = BENCH_MAX_TIME_NS

    data = reference.generate_input(**bench_args)
    data_copy = _clone(data)

    # Check correctness on this case first; timings of a wrong kernel are useless.
    ctx = torch.no_grad() if BENCH_NO_GRAD else contextlib.nullcontext()
    with ctx:
        output = kernel_fn(data)
        torch.cuda.synchronize()
        passed, msg = reference.check_implementation(data_copy, output)
        if not passed:
            return None, f"Benchmark correctness: {msg}"
        del output

    # Timed repetitions with adaptive early stopping.
    durations_ns = []
    bm_start = time.perf_counter_ns()

    with ctx:
        for i in range(BENCH_MAX_REPEATS):
            torch.cuda.synchronize()

            if BENCH_USE_CUDA_EVENTS:
                # CUDA events measure GPU execution time only; elapsed_time()
                # returns milliseconds, so multiply by 1e6 for nanoseconds.
                s = torch.cuda.Event(enable_timing=True)
                e = torch.cuda.Event(enable_timing=True)
                s.record()
                output = kernel_fn(data)
                e.record()
                torch.cuda.synchronize()
                duration_ns = s.elapsed_time(e) * 1e6
            else:
                # Host wall clock, including kernel launch overhead.
                start_ns = time.perf_counter_ns()
                output = kernel_fn(data)
                torch.cuda.synchronize()
                duration_ns = time.perf_counter_ns() - start_ns

            del output
            durations_ns.append(duration_ns)

            # After three samples, stop once the relative standard error is
            # small enough, the measured-time budget is spent, or the
            # wall-clock timeout has elapsed.
            if i > 1:
                st = _stats(durations_ns)
                if st["mean"] > 0 and st["err"] / st["mean"] < BENCH_REL_ERROR:
                    break
                if st["mean"] * st["runs"] > max_time_ns:
                    break
                if (BENCH_WALL_TIMEOUT_NS is not None
                        and (time.perf_counter_ns() - bm_start) > BENCH_WALL_TIMEOUT_NS):
                    break

    return _stats(durations_ns), None
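
# Numeric illustration of the early stop above: with the default
# BENCH_REL_ERROR of 0.001, a case averaging 1e6 ns (1 ms) per run stops as
# soon as the standard error of the mean drops below 1,000 ns.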


def _evaluate_modal(submission_code):
    """Evaluate a submission on a Modal cloud GPU selected by GPUMODE_MODAL_GPU."""
    parent_dir = os.path.dirname(os.path.abspath(__file__))
    if parent_dir not in sys.path:
        sys.path.insert(0, parent_dir)
    from modal_eval import (
        eval_triton_h100, eval_triton_a100, eval_triton_l40s, eval_triton_t4,
        eval_triton_h200, app as modal_app,
    )

    gpu_fns = {
        "H100": eval_triton_h100,
        "A100": eval_triton_a100,
        "L40S": eval_triton_l40s,
        "T4": eval_triton_t4,
        "H200": eval_triton_h200,
    }
    # Fall back to H100 for unrecognized GPU names.
    eval_fn = gpu_fns.get(MODAL_GPU, eval_triton_h100)

    ref_code = getattr(reference, 'MODAL_REFERENCE_CODE', None)
    if ref_code is None:
        return EvaluationResult(
            metrics={"combined_score": 0.0, "correctness": 0.0},
            artifacts={"error": "MODAL_REFERENCE_CODE not defined in reference.py",
                       "failure_stage": "modal_setup"},
        )

    with modal_app.run():
        result = eval_fn.remote(
            submission_code=submission_code,
            reference_code=ref_code,
            test_cases=reference.TEST_CASES,
            benchmark_cases=reference.BENCHMARK_CASES,
            score_scale=SCORE_SCALE,
            bench_use_cuda_events=BENCH_USE_CUDA_EVENTS,
            bench_rel_error=BENCH_REL_ERROR,
            bench_wall_timeout_ns=BENCH_WALL_TIMEOUT_NS,
            bench_no_grad=BENCH_NO_GRAD,
            bench_max_repeats=BENCH_MAX_REPEATS,
            bench_max_time_ns=BENCH_MAX_TIME_NS,
            bench_warmup_style=BENCH_WARMUP_STYLE,
        )

    if isinstance(result, dict):
        error = result.get("error")
        score = float(result.get("combined_score", 0.0))
        metrics = {"combined_score": score, "correctness": float(result.get("correctness", 0.0))}
        if "geom_mean_us" in result:
            metrics["geom_mean_us"] = float(result["geom_mean_us"])
        artifacts = {}
        if error:
            artifacts["error"] = str(error)
            artifacts["failure_stage"] = "modal_eval"
        if "bench_means_us" in result:
            for i, us in enumerate(result["bench_means_us"]):
                artifacts[f"bench_{i}_mean_us"] = f"{us:.2f}"
        artifacts["hardware"] = MODAL_GPU
        return EvaluationResult(metrics=metrics, artifacts=artifacts)

    return EvaluationResult(
        metrics={"combined_score": 0.0, "correctness": 0.0},
        artifacts={"error": "Modal returned unexpected type", "failure_stage": "modal_eval"},
    )
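

# A submission is any Python file defining custom_kernel(data), returning
# output in the same format as reference.ref_kernel; pure PyTorch is fine.
# Illustrative sketch for a hypothetical elementwise-add problem (the (a, b)
# input layout is invented for this example):
#
#     import torch
#
#     def custom_kernel(data):
#         a, b = data
#         return a + b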
def _evaluate_local(program_path):
    """Evaluate a submission on the local GPU: import, correctness, then benchmarks."""
    try:
        spec = importlib.util.spec_from_file_location("submission", program_path)
        mod = importlib.util.module_from_spec(spec)
        sys.modules["submission"] = mod
        spec.loader.exec_module(mod)
        custom_kernel = mod.custom_kernel
    except Exception as exc:
        return EvaluationResult(
            metrics={"combined_score": 0.0, "correctness": 0.0},
            artifacts={
                "error": f"Failed to load submission: {exc}",
                "traceback": traceback.format_exc(),
                "failure_stage": "import",
            },
        )

    # Correctness gate: every test case must pass before any benchmarking.
    for i, tc in enumerate(reference.TEST_CASES):
        try:
            data = reference.generate_input(**tc)
            data_copy = _clone(data)
            torch.cuda.synchronize()
            output = custom_kernel(data)
            torch.cuda.synchronize()
            passed, msg = reference.check_implementation(data_copy, output)
            if not passed:
                return EvaluationResult(
                    metrics={"combined_score": 0.0, "correctness": 0.0},
                    artifacts={
                        "error": f"Test {i} failed: {msg}",
                        "failure_stage": "correctness",
                        "test_index": str(i),
                    },
                )
        except Exception as exc:
            return EvaluationResult(
                metrics={"combined_score": 0.0, "correctness": 0.0},
                artifacts={
                    "error": f"Test {i} error: {exc}",
                    "traceback": traceback.format_exc(),
                    "failure_stage": "correctness",
                    "test_index": str(i),
                },
            )

    # One warmup on the first benchmark case triggers Triton compilation.
    _warmup(custom_kernel, reference.BENCHMARK_CASES[0])

    bench_means_ns = []
    for bench_args in reference.BENCHMARK_CASES:
        st, err = _bench_single(custom_kernel, bench_args)
        if err:
            # Correctness held on TEST_CASES but failed on a benchmark case.
            return EvaluationResult(
                metrics={"combined_score": 0.0, "correctness": 1.0},
                artifacts={"error": err, "failure_stage": "benchmark"},
            )
        bench_means_ns.append(st["mean"])

    # Score from the geometric mean of the per-case means, in microseconds.
    means_seconds = [ns / 1e9 for ns in bench_means_ns]
    geom_mean_s = math.pow(math.prod(means_seconds), 1.0 / len(means_seconds))
    geom_mean_us = geom_mean_s * 1e6
    score = SCORE_SCALE / geom_mean_us

    metrics = {
        "combined_score": score,
        "correctness": 1.0,
        "geom_mean_us": geom_mean_us,
    }
    artifacts = {
        "hardware": "local",
    }
    for i, ns in enumerate(bench_means_ns):
        artifacts[f"bench_{i}_mean_us"] = f"{ns / 1e3:.2f}"

    return EvaluationResult(
        metrics=metrics,
        artifacts=artifacts,
    )
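
# Worked example of the scoring above (illustrative numbers): per-case means
# of 100 us and 400 us give geom_mean_us = sqrt(100 * 400) = 200, so with the
# default SCORE_SCALE of 3000.0 the combined_score is 3000 / 200 = 15.0.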


def evaluate(program_path):
    """Top-level entry point: dispatch to Modal or local evaluation."""
    try:
        with open(program_path, "r") as f:
            code = f.read()
    except Exception as exc:
        return EvaluationResult(
            metrics={"combined_score": 0.0, "correctness": 0.0},
            artifacts={"error": f"Failed to read file: {exc}", "failure_stage": "file_read"},
        )

    if USE_MODAL:
        try:
            return _evaluate_modal(code)
        except Exception as exc:
            return EvaluationResult(
                metrics={"combined_score": 0.0, "correctness": 0.0},
                artifacts={
                    "error": f"Modal evaluation failed: {exc}",
                    "traceback": traceback.format_exc(),
                    "failure_stage": "modal_eval",
                },
            )

    return _evaluate_local(program_path)


def evaluate_stage1(program_path):
    """Cheap pre-check: the file reads, names custom_kernel, and compiles.

    When running locally this also verifies that the module imports; under
    Modal the import check is skipped here.
    """
    try:
        with open(program_path, "r") as f:
            code = f.read()
    except Exception as exc:
        return EvaluationResult(
            metrics={"combined_score": 0.0, "stage1_passed": 0.0},
            artifacts={"error": f"Failed to read file: {exc}", "failure_stage": "file_read"},
        )

    if "custom_kernel" not in code:
        return EvaluationResult(
            metrics={"combined_score": 0.0, "stage1_passed": 0.0},
            artifacts={"error": "Missing custom_kernel function", "failure_stage": "validation"},
        )

    try:
        compile(code, program_path, "exec")
    except SyntaxError as exc:
        return EvaluationResult(
            metrics={"combined_score": 0.0, "stage1_passed": 0.0},
            artifacts={
                "error": f"Syntax error at line {exc.lineno}: {exc.msg}",
                "failure_stage": "syntax_check",
            },
        )

    if not USE_MODAL:
        try:
            spec = importlib.util.spec_from_file_location("submission_check", program_path)
            mod = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(mod)
            if not hasattr(mod, "custom_kernel"):
                return EvaluationResult(
                    metrics={"combined_score": 0.0, "stage1_passed": 0.0},
                    artifacts={"error": "custom_kernel not found after import", "failure_stage": "import"},
                )
        except Exception as exc:
            return EvaluationResult(
                metrics={"combined_score": 0.0, "stage1_passed": 0.0},
                artifacts={
                    "error": f"Import failed: {exc}",
                    "traceback": traceback.format_exc(),
                    "failure_stage": "import",
                },
            )

    return EvaluationResult(
        metrics={"combined_score": 0.5, "stage1_passed": 1.0},
        artifacts={},
    )


def evaluate_stage2(program_path):
    """Stage 2 runs the full evaluation."""
    return evaluate(program_path)
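

if __name__ == "__main__":
    # Minimal CLI sketch, not part of the original harness. Assumes
    # EvaluationResult exposes .metrics and .artifacts attributes matching
    # the keyword arguments used above. Usage:
    #     python <this file> path/to/submission.py
    result = evaluate(sys.argv[1])
    print(result.metrics)
    print(result.artifacts)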