avo_test_cases / scripts /task_runner.py
jiliu1's picture
Upload folder using huggingface_hub
15c2580 verified
Raw
History Blame Contribute Delete
3.82 kB
#!/usr/bin/env python3
"""GEAK task runner — Triton kernel tasks."""
import argparse
import json
import os
import re
import subprocess
import sys
# This script lives in <task root>/scripts/, so the task root is the directory
# two levels above this file. The process then switches into that root so that
# relative paths (such as the 'kernel' entry run_compile puts on sys.path)
# resolve against it.
TASK_DIR = os.path.split(os.path.split(os.path.abspath(__file__))[0])[0]
os.chdir(TASK_DIR)
# Output directory for the JSON stage reports written by main().
BUILD_DIR = os.path.join(TASK_DIR, "build")
# Directory that holds the kernel modules and the harness script.
KERNEL_DIR = os.path.join(TASK_DIR, "kernel")
HARNESS = os.path.join(TASK_DIR, "kernel", "test_harness.py")
def _output_as_text(data):
    """Normalize captured child output (bytes, str, or None) to a str."""
    if data is None:
        return ""
    if isinstance(data, bytes):
        # On POSIX, the partial output attached to TimeoutExpired is bytes even
        # when the child was started with text=True.
        return data.decode("utf-8", errors="replace")
    return data


def _run_harness(argv, timeout=7200):
    """Run ``kernel/test_harness.py`` with *argv* in a child interpreter.

    The child runs with ``cwd=KERNEL_DIR`` and its stdout/stderr are captured
    as text.

    Args:
        argv: Extra command-line arguments for the harness, e.g.
            ``["--benchmark"]``.
        timeout: Seconds to wait before the child is killed (default 7200).

    Returns:
        A ``subprocess.CompletedProcess``. If the harness times out or cannot
        be launched at all, a synthetic ``CompletedProcess`` with returncode
        -1 is returned. Its stderr carries whatever output was captured plus a
        diagnostic line, so callers record a failure in their report instead
        of crashing before writing it.
    """
    cmd = [sys.executable, HARNESS] + list(argv)
    try:
        return subprocess.run(
            cmd,
            cwd=KERNEL_DIR,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as exc:
        return subprocess.CompletedProcess(
            cmd,
            -1,
            stdout=_output_as_text(exc.stdout),
            stderr=_output_as_text(exc.stderr)
            + f"\nTIMEOUT: harness exceeded {timeout} s\n",
        )
    except OSError as exc:
        # e.g. KERNEL_DIR missing or the interpreter not executable.
        return subprocess.CompletedProcess(
            cmd, -1, stdout="", stderr=f"failed to launch harness: {exc}\n"
        )
def run_compile():
    """Check that the kernel modules import cleanly in a fresh interpreter.

    Launches ``python -c`` with cwd=TASK_DIR. The child prepends ``kernel`` to
    ``sys.path`` and imports ``host`` and ``kernel_jit``.

    Returns:
        ``(True, None)`` when the child exits 0. Otherwise ``(False, message)``,
        where message is the child's stderr, a newline, then its stdout. If
        launching or waiting on the child raises (including the 600 s timeout),
        message is the text of that exception.
    """
    probe = "import sys; sys.path.insert(0, 'kernel'); import host, kernel_jit"
    try:
        proc = subprocess.run(
            [sys.executable, "-c", probe],
            cwd=TASK_DIR,
            capture_output=True,
            text=True,
            timeout=600,
        )
    except Exception as exc:
        return False, str(exc)
    if proc.returncode == 0:
        return True, None
    return False, "\n".join([proc.stderr or "", proc.stdout or ""])
def run_correctness():
    """Run the harness with ``--correctness`` and decide pass/fail.

    The run passes only when the harness exits 0 and its combined
    stdout+stderr contains the marker ``CORRECTNESS_OVERALL: PASS``. A zero
    exit code without the marker still counts as a failure.

    Returns:
        ``(True, None)`` on pass. Otherwise ``(False, detail)``, where detail
        is the last 8000 characters of the combined output, or
        ``"exit <code>"`` if the harness produced no output.
    """
    result = _run_harness(["--correctness"])
    log = "".join(stream or "" for stream in (result.stdout, result.stderr))
    passed = result.returncode == 0 and "CORRECTNESS_OVERALL: PASS" in log
    if passed:
        return True, None
    # A negative slice on a shorter string yields the whole string, so short
    # logs are returned intact.
    return False, log[-8000:] or f"exit {result.returncode}"
def _parse_latency_ms(text):
    """Convert a captured latency token to float, or None if it is malformed.

    The capture pattern ``[0-9.eE+-]+`` also matches tokens such as ``"-"``,
    ``"."`` or ``"1.2.3"`` that ``float()`` rejects. Returning None lets the
    caller report a failure instead of raising ValueError.
    """
    try:
        return float(text)
    except ValueError:
        return None


def run_performance():
    """Run the harness with ``--benchmark`` and collect per-case latencies.

    Scans the harness's combined stdout+stderr for
    ``CASE=<id> GEAK_RESULT_LATENCY_MS=<ms>`` lines. If none are present, it
    falls back to the first bare ``GEAK_RESULT_LATENCY_MS=<ms>`` and reports
    it as a single ``"aggregate"`` case.

    Returns:
        A list of dicts with keys ``test_case_id``, ``execution_time_ms`` and
        ``params`` (always ``{}``). The list is replaced by a single
        ``"_error"`` row in any of these cases:

        * the harness exits non-zero;
        * no latency is found;
        * any latency is negative or unparseable.

        The error row has ``execution_time_ms`` -1.0 and the last 4000
        characters of output under ``params["stderr"]``.
    """
    p = _run_harness(["--benchmark"])
    out = (p.stdout or "") + (p.stderr or "")
    rows = [
        {
            "test_case_id": m.group(1),
            "execution_time_ms": _parse_latency_ms(m.group(2)),
            "params": {},
        }
        for m in re.finditer(
            r"CASE=(\S+)\s+GEAK_RESULT_LATENCY_MS=([0-9.eE+-]+)", out
        )
    ]
    if not rows:
        # No per-case lines: fall back to a single aggregate latency, if any.
        m = re.search(r"GEAK_RESULT_LATENCY_MS=([0-9.eE+-]+)", out)
        if m:
            rows.append(
                {
                    "test_case_id": "aggregate",
                    "execution_time_ms": _parse_latency_ms(m.group(1)),
                    "params": {},
                }
            )
    ok = (
        p.returncode == 0
        and bool(rows)
        and all(
            r["execution_time_ms"] is not None and r["execution_time_ms"] >= 0
            for r in rows
        )
    )
    if not ok:
        return [
            {
                "test_case_id": "_error",
                "execution_time_ms": -1.0,
                "params": {"stderr": out[-4000:]},
            }
        ]
    return rows
def _finish_status_stage(label, report_name, ok, err):
    """Write a ``{"status", "error"}`` JSON report into BUILD_DIR and print a summary.

    Shared by the compile and correctness modes.

    Returns:
        The process exit code for the stage: 0 if *ok*, else 1.
    """
    with open(os.path.join(BUILD_DIR, report_name), "w") as f:
        json.dump({"status": "ok" if ok else "fail", "error": err}, f)
    print(f"{label}: {'PASS' if ok else 'FAIL'}")
    if err:
        print("Error:", err)
    return 0 if ok else 1


def main():
    """Command-line entry point: run one stage and record its JSON report.

    The single positional ``mode`` selects the stage:

    * ``compile``     -> ``build/compile_report.json``
    * ``correctness`` -> ``build/correctness_report.json``
    * ``performance`` -> ``build/performance_report.json``

    Prints a human-readable summary and exits with status 0 on success and 1
    on failure. In performance mode, any case with a negative
    ``execution_time_ms`` counts as a failure.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("mode", choices=["compile", "correctness", "performance"])
    args = ap.parse_args()
    os.makedirs(BUILD_DIR, exist_ok=True)
    if args.mode == "compile":
        sys.exit(
            _finish_status_stage("Compilation", "compile_report.json", *run_compile())
        )
    if args.mode == "correctness":
        sys.exit(
            _finish_status_stage(
                "Correctness", "correctness_report.json", *run_correctness()
            )
        )
    cases = run_performance()
    with open(os.path.join(BUILD_DIR, "performance_report.json"), "w") as f:
        json.dump({"test_cases": cases}, f, indent=2)
    for c in cases:
        ms = c.get("execution_time_ms", -1)
        if ms >= 0:
            print(f"Performance: {ms:.4f} ms ({c['test_case_id']})")
    failed = [c for c in cases if c.get("execution_time_ms", -1) < 0]
    if failed:
        # Mirror the compile/correctness modes, which print FAIL plus the error,
        # rather than exiting 1 silently with details only in the JSON report.
        print("Performance: FAIL")
        for c in failed:
            detail = c.get("params", {}).get("stderr")
            if detail:
                print("Error:", detail)
    sys.exit(1 if failed else 0)
# Run the CLI only when executed as a script; importing this module does not call main().
if __name__ == "__main__":
    main()