"""Vendored PerfSkills-style harness helpers for GEAK Triton tasks.

A task script builds a list of *cases* and passes it to :func:`run_modes`,
which dispatches on the flags produced by :func:`make_argparser`. Each case
is a dict with the keys:

* ``"name"``  -- label printed in every ``CASE=<name> ...`` output line;
* ``"check"`` -- zero-argument callable; a truthy result means the case is
  correct, a falsy result or a raised exception means it failed
  (used by ``--correctness``);
* ``"run"``   -- zero-argument callable that is timed
  (used by ``--benchmark`` / ``--full-benchmark``).

The ``CASE=...``, ``CORRECTNESS_OVERALL: ...`` and ``GEAK_...=`` lines on
stdout look like a machine-readable protocol for an external driver
(assumption -- confirm with the consumer before changing their format).
"""
from __future__ import annotations

import argparse
import math
import sys
import time
import traceback
from typing import Any, Callable, Dict, List

import torch


def make_argparser(title: str) -> argparse.ArgumentParser:
    """Return a parser exposing the three harness modes.

    Args:
        title: Parser description shown by ``--help``.

    Returns:
        An ``ArgumentParser`` with the boolean flags ``--correctness``,
        ``--benchmark`` and ``--full-benchmark`` (the last one stored as
        ``full_benchmark``).
    """
    p = argparse.ArgumentParser(description=title)
    p.add_argument("--correctness", action="store_true")
    p.add_argument("--benchmark", action="store_true")
    p.add_argument("--full-benchmark", dest="full_benchmark", action="store_true")
    return p


def _bench_once(fn: Callable[[], Any], *, warmup: int, repeat: int) -> float:
    """Return the mean wall-clock latency of one ``fn()`` call, in milliseconds.

    ``fn`` is called ``warmup`` times untimed, then ``repeat`` times inside the
    timed region. When CUDA is available the device is synchronized before the
    clock starts and before it stops, so queued asynchronous kernel launches
    are included in the measurement. On hosts without CUDA the synchronization
    is skipped rather than failing.

    Raises:
        ValueError: if ``repeat`` is not positive (the mean would be undefined).
    """
    if repeat <= 0:
        raise ValueError(f"repeat must be positive, got {repeat}")
    # torch.cuda.synchronize() errors out when no CUDA device is present;
    # there is nothing asynchronous to drain in that case, so use a no-op.
    sync: Callable[[], None] = (
        torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    )
    for _ in range(warmup):
        fn()
    sync()  # drain warmup work so it does not leak into the timed region
    t0 = time.perf_counter()
    for _ in range(repeat):
        fn()
    sync()  # wait for every timed launch to finish before stopping the clock
    return (time.perf_counter() - t0) / repeat * 1000.0


def _geomean(xs: List[float]) -> float:
    """Return the geometric mean of the positive, non-NaN entries of ``xs``.

    Non-positive and NaN values are dropped (their logarithm is undefined).
    If nothing remains, the result is NaN.
    """
    vals = [x for x in xs if x > 0 and not math.isnan(x)]
    if not vals:
        return float("nan")
    return math.exp(sum(math.log(x) for x in vals) / len(vals))


def _run_correctness(cases: List[Dict[str, Any]]) -> int:
    """Run every case's ``check``, print per-case results, return the exit status.

    All cases are run even after a failure. Returns 0 if every check passed,
    otherwise 1.
    """
    all_ok = True
    for c in cases:
        name = c["name"]
        # Flush so the log still shows which case was running if the check
        # hard-crashes the process while stdout is piped (and so block-buffered).
        print(f"CASE={name} checking...", flush=True)
        try:
            ok = bool(c["check"]())
        except Exception as e:
            print(f"CASE={name} exception: {e}", flush=True)
            # str(e) can be empty (e.g. a bare ``assert``); keep the full
            # traceback on stderr so the failure is diagnosable.
            traceback.print_exc()
            ok = False
        st = "PASS" if ok else "FAIL"
        print(f"CASE={name} {st}", flush=True)
        all_ok = all_ok and ok
    if all_ok:
        print("CORRECTNESS_OVERALL: PASS")
        return 0
    print("CORRECTNESS_OVERALL: FAIL")
    return 1


def _run_benchmark(cases: List[Dict[str, Any]], *, quick: bool, full: bool) -> int:
    """Time every case's ``run``, print latencies plus their geomean, return 0.

    Args:
        quick: Use the short schedule (5 warmup / 10 timed calls). Otherwise
            use 10 warmup / 100 timed calls.
        full: Also emit the ``GEAK_RESULT_SPEEDUP`` line.
    """
    warmup, repeat = (5, 10) if quick else (10, 100)
    lats: List[float] = []
    for c in cases:
        name = c["name"]
        ms = _bench_once(c["run"], warmup=warmup, repeat=repeat)
        lats.append(ms)
        print(f"CASE={name} GEAK_RESULT_LATENCY_MS={ms:.6f}", flush=True)
    gm = _geomean(lats)
    print(f"GEAK_SHAPES_USED={list(range(len(cases)))}")
    print(f"GEAK_RESULT_LATENCY_MS={gm:.6f}")
    if full:
        # This harness measures no baseline, so the speedup is always 1.0.
        print("GEAK_RESULT_SPEEDUP=1.000000")
    return 0


def run_modes(args: argparse.Namespace, cases: List[Dict[str, Any]]) -> None:
    """Dispatch on the parsed flags and terminate the process via ``sys.exit``.

    Precedence is correctness, then benchmark. If both ``--benchmark`` and
    ``--full-benchmark`` are given, the short benchmark schedule is used but
    the speedup line is still printed.

    Exit status: 0 on success, 1 if any correctness check failed, 2 if no
    mode flag was given. This function never returns normally.
    """
    if args.correctness:
        sys.exit(_run_correctness(cases))
    if args.benchmark or args.full_benchmark:
        sys.exit(
            _run_benchmark(
                cases,
                quick=bool(args.benchmark),
                full=bool(args.full_benchmark),
            )
        )
    print("No mode selected (--correctness | --benchmark | --full-benchmark).", file=sys.stderr)
    sys.exit(2)