| """Vendored PerfSkills-style harness helpers for GEAK Triton tasks.""" |
| from __future__ import annotations |
|
|
| import argparse |
| import math |
| import sys |
| import time |
| from typing import Any, Callable, Dict, List |
|
|
| import torch |
|
|
|
|
def make_argparser(title: str) -> argparse.ArgumentParser:
    """Build the command-line parser shared by the harness scripts.

    Args:
        title: Text used as the parser's description.

    Returns:
        A parser exposing three boolean switches, all off by default:
        ``--correctness``, ``--benchmark`` and ``--full-benchmark`` (the
        last one is stored on the namespace as ``full_benchmark``).
    """
    parser = argparse.ArgumentParser(description=title)
    # (option string, namespace attribute) for every mode switch.
    mode_flags = (
        ("--correctness", "correctness"),
        ("--benchmark", "benchmark"),
        ("--full-benchmark", "full_benchmark"),
    )
    for option, attribute in mode_flags:
        parser.add_argument(option, dest=attribute, action="store_true")
    return parser
|
|
|
|
| def _bench_once(fn: Callable[[], Any], *, warmup: int, repeat: int) -> float: |
| for _ in range(warmup): |
| fn() |
| torch.cuda.synchronize() |
| t0 = time.perf_counter() |
| for _ in range(repeat): |
| fn() |
| torch.cuda.synchronize() |
| return (time.perf_counter() - t0) / repeat * 1000.0 |
|
|
|
|
| def _geomean(xs: List[float]) -> float: |
| xs = [x for x in xs if x > 0 and not math.isnan(x)] |
| if not xs: |
| return float("nan") |
| return math.exp(sum(math.log(x) for x in xs) / len(xs)) |
|
|
|
|
def run_modes(args: argparse.Namespace, cases: List[Dict[str, Any]]) -> None:
    """Execute the mode selected on the command line, then exit the process.

    Every path ends in ``sys.exit``; the function never returns normally.

    * ``args.correctness``: calls each case's ``"check"`` callable. A truthy
      result is a PASS; a falsy result or a raised ``Exception`` (which is
      printed) is a FAIL. Exits with 0 when every case passed, otherwise 1.
      This mode is tested first, so it wins over the benchmark flags.
    * ``args.benchmark`` / ``args.full_benchmark``: times each case's
      ``"run"`` callable with ``_bench_once`` using warmup/repeat of 5/10
      when ``args.benchmark`` is set and 10/100 otherwise, prints one latency
      line per case followed by the ``_geomean`` of all latencies, and exits
      with 0. With ``args.full_benchmark`` a fixed speedup line of 1.000000
      is printed as well.
    * No flag set: writes a usage hint to stderr and exits with 2.

    Args:
        args: Namespace carrying the ``correctness``, ``benchmark`` and
            ``full_benchmark`` attributes.
        cases: One dict per case; ``"name"`` is always read, ``"check"`` in
            correctness mode and ``"run"`` in the benchmark modes.
    """
    if args.correctness:
        failed = 0
        for case in cases:
            label = case["name"]
            print(f"CASE={label} checking...")
            try:
                passed = bool(case["check"]())
            except Exception as exc:
                # A crashing check is reported and counted as a failure so
                # the remaining cases still get evaluated.
                print(f"CASE={label} exception: {exc}")
                passed = False
            if passed:
                status = "PASS"
            else:
                status = "FAIL"
                failed += 1
            print(f"CASE={label} {status}")
        if failed:
            print("CORRECTNESS_OVERALL: FAIL")
            sys.exit(1)
        print("CORRECTNESS_OVERALL: PASS")
        sys.exit(0)

    if not (args.benchmark or args.full_benchmark):
        print("No mode selected (--correctness | --benchmark | --full-benchmark).", file=sys.stderr)
        sys.exit(2)

    # The quick settings apply whenever --benchmark is given, even if
    # --full-benchmark is present too.
    if args.benchmark:
        warmup, repeat = 5, 10
    else:
        warmup, repeat = 10, 100

    latencies: List[float] = []
    for case in cases:
        label = case["name"]
        elapsed_ms = _bench_once(case["run"], warmup=warmup, repeat=repeat)
        latencies.append(elapsed_ms)
        print(f"CASE={label} GEAK_RESULT_LATENCY_MS={elapsed_ms:.6f}")

    overall_ms = _geomean(latencies)
    case_indices = list(range(len(cases)))
    print(f"GEAK_SHAPES_USED={case_indices}")
    print(f"GEAK_RESULT_LATENCY_MS={overall_ms:.6f}")
    if args.full_benchmark:
        print("GEAK_RESULT_SPEEDUP=1.000000")
    sys.exit(0)
|
|