# Source path: avo_test_cases/scripts/_bench_common.py
# Hugging Face file-viewer residue: uploader "jiliu1", commit message
# "Upload folder using huggingface_hub", revision 15c2580 (verified), 2.43 kB.
"""Vendored PerfSkills-style harness helpers for GEAK Triton tasks."""
from __future__ import annotations
import argparse
import math
import sys
import time
from typing import Any, Callable, Dict, List
import torch
def make_argparser(title: str) -> argparse.ArgumentParser:
    """Build the command-line parser shared by the harness scripts.

    Args:
        title: Text used as the parser's description (shown by ``--help``).

    Returns:
        A parser exposing three boolean flags, all defaulting to ``False``:
        ``--correctness``, ``--benchmark`` and ``--full-benchmark`` (stored
        on the namespace as ``full_benchmark``).
    """
    parser = argparse.ArgumentParser(description=title)
    parser.add_argument(
        "--correctness",
        action="store_true",
        help="run every case's correctness check and report PASS/FAIL",
    )
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="run the quick benchmark (fewer warmup and timed iterations)",
    )
    parser.add_argument(
        "--full-benchmark",
        dest="full_benchmark",
        action="store_true",
        help="run the full benchmark (more warmup and timed iterations)",
    )
    return parser
def _bench_once(fn: Callable[[], Any], *, warmup: int, repeat: int) -> float:
    """Time ``fn`` and return its average wall-clock latency in milliseconds.

    ``fn`` is first called ``warmup`` times (untimed), then ``repeat`` times
    inside a single ``time.perf_counter`` window. When CUDA is available the
    device is synchronized before the timer starts and again before it stops,
    so queued GPU work is included in the measurement.

    Args:
        fn: Zero-argument callable to benchmark; its return value is ignored.
        warmup: Number of untimed calls made before measuring.
        repeat: Number of timed calls; must be at least 1.

    Returns:
        Elapsed time of the timed loop divided by ``repeat``, in milliseconds.

    Raises:
        ValueError: If ``repeat`` is smaller than 1.
    """
    # Fail before doing any work; otherwise repeat=0 would run the warmup and
    # then die on a division by zero.
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")

    # torch.cuda.synchronize() raises on hosts without a usable CUDA device,
    # so only synchronize when there is a device to wait on.
    use_cuda = torch.cuda.is_available()

    for _ in range(warmup):
        fn()
    if use_cuda:
        torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(repeat):
        fn()
    if use_cuda:
        torch.cuda.synchronize()
    elapsed_s = time.perf_counter() - start

    return elapsed_s / repeat * 1000.0
def _geomean(xs: List[float]) -> float:
    """Return the geometric mean of the strictly positive values in ``xs``.

    Zero, negative and NaN entries are ignored. The mean is computed in log
    space as ``exp(mean(log(x)))``.

    Args:
        xs: Values to average; the list is not modified.

    Returns:
        The geometric mean of the retained values, or NaN when no value
        is retained (including when ``xs`` is empty).
    """
    # ``x > 0`` is False for NaN, so this single test drops NaNs as well as
    # non-positive values.
    logs = [math.log(x) for x in xs if x > 0]
    if not logs:
        return float("nan")
    # fsum gives a correctly rounded sum of the logs.
    return math.exp(math.fsum(logs) / len(logs))
def _run_correctness(cases: List[Dict[str, Any]]) -> bool:
    """Run each case's ``check`` callable, print per-case results, return overall pass."""
    all_ok = True
    for case in cases:
        name = case["name"]
        print(f"CASE={name} checking...")
        try:
            ok = bool(case["check"]())
        except Exception as exc:
            # Harness boundary: a crashing check counts as a FAIL for that
            # case instead of aborting the remaining cases. The type name is
            # printed because str(exc) is empty for e.g. a bare ``assert``.
            print(f"CASE={name} exception: {type(exc).__name__}: {exc}")
            ok = False
        verdict = "PASS" if ok else "FAIL"
        print(f"CASE={name} {verdict}")
        all_ok = all_ok and ok
    return all_ok


def _run_benchmark(
    cases: List[Dict[str, Any]],
    *,
    warmup: int,
    repeat: int,
    report_speedup: bool,
) -> None:
    """Time each case's ``run`` callable and print per-case and geomean latency."""
    latencies: List[float] = []
    for case in cases:
        name = case["name"]
        ms = _bench_once(case["run"], warmup=warmup, repeat=repeat)
        latencies.append(ms)
        print(f"CASE={name} GEAK_RESULT_LATENCY_MS={ms:.6f}")
    print(f"GEAK_SHAPES_USED={list(range(len(cases)))}")
    print(f"GEAK_RESULT_LATENCY_MS={_geomean(latencies):.6f}")
    if report_speedup:
        # No baseline is measured here, so the speedup line is a constant.
        print("GEAK_RESULT_SPEEDUP=1.000000")


def run_modes(args: argparse.Namespace, cases: List[Dict[str, Any]]) -> None:
    """Dispatch to the mode selected on the command line, then exit the process.

    This function never returns normally; every path ends in ``sys.exit``.

    Modes, in priority order:

    * ``args.correctness``: call each case's ``check`` and print
      ``CASE=<name> PASS|FAIL`` followed by ``CORRECTNESS_OVERALL``. Exit
      status is 0 when every check is truthy, otherwise 1. With no cases the
      overall result is PASS.
    * ``args.benchmark`` / ``args.full_benchmark``: time each case's ``run``
      and print per-case and geometric-mean ``GEAK_RESULT_LATENCY_MS`` lines,
      then exit 0. The quick settings (5 warmup, 10 timed calls) apply
      whenever ``args.benchmark`` is set; otherwise 10 warmup and 100 timed
      calls are used. ``GEAK_RESULT_SPEEDUP`` is printed only when
      ``args.full_benchmark`` is set.
    * No mode: print a usage hint to stderr and exit with status 2.

    Args:
        args: Namespace with boolean attributes ``correctness``,
            ``benchmark`` and ``full_benchmark``.
        cases: One dict per case. Each needs a ``"name"`` entry, plus a
            ``"check"`` callable for correctness mode and a ``"run"``
            callable for the benchmark modes.
    """
    if args.correctness:
        passed = _run_correctness(cases)
        if passed:
            print("CORRECTNESS_OVERALL: PASS")
            sys.exit(0)
        print("CORRECTNESS_OVERALL: FAIL")
        sys.exit(1)

    if args.benchmark or args.full_benchmark:
        warmup, repeat = (5, 10) if args.benchmark else (10, 100)
        _run_benchmark(
            cases,
            warmup=warmup,
            repeat=repeat,
            report_speedup=bool(args.full_benchmark),
        )
        sys.exit(0)

    print(
        "No mode selected (--correctness | --benchmark | --full-benchmark).",
        file=sys.stderr,
    )
    sys.exit(2)