| |
| """GEAK task runner — Triton kernel tasks.""" |
| import argparse |
| import json |
| import os |
| import re |
| import subprocess |
| import sys |
|
|
# The task root is one level above the directory that holds this script.
TASK_DIR = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
# Import-time side effect: make the task root the working directory.
os.chdir(TASK_DIR)
# Directory holding the kernel modules (host, kernel_jit) and the harness.
KERNEL_DIR = os.path.join(TASK_DIR, "kernel")
HARNESS = os.path.join(KERNEL_DIR, "test_harness.py")
# The JSON reports written by main() go here.
BUILD_DIR = os.path.join(TASK_DIR, "build")
|
|
|
|
def _decode_output(data):
    """Return captured process output as ``str`` ("" for None, bytes decoded)."""
    if data is None:
        return ""
    if isinstance(data, bytes):
        return data.decode(errors="replace")
    return data


def _run_harness(argv, timeout=7200):
    """Run ``test_harness.py`` in a subprocess from ``KERNEL_DIR`` and capture output.

    Args:
        argv: Extra command-line arguments for the harness (any sequence of str).
        timeout: Wall-clock limit in seconds; defaults to 7200.

    Returns:
        A ``subprocess.CompletedProcess`` with text ``stdout``/``stderr``.
        If the harness runs past ``timeout``, it is killed and a
        ``CompletedProcess`` with ``returncode == -1`` is returned instead of
        raising. That result carries whatever output was captured, plus a
        timeout note appended to ``stderr``, so callers report a failure
        rather than crash.
    """
    cmd = [sys.executable, HARNESS, *argv]
    try:
        return subprocess.run(
            cmd,
            cwd=KERNEL_DIR,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as exc:
        # run() kills the child before re-raising. The exception's captured
        # output may be bytes even though text=True, so normalise it to str.
        stdout = _decode_output(exc.stdout)
        stderr = _decode_output(exc.stderr)
        note = f"test harness timed out after {timeout} seconds"
        stderr = f"{stderr}\n{note}" if stderr else note
        return subprocess.CompletedProcess(cmd, -1, stdout=stdout, stderr=stderr)
|
|
|
|
def run_compile():
    """Check that the kernel modules import cleanly in a fresh interpreter.

    Runs ``import host, kernel_jit`` in a child Python process with
    ``TASK_DIR`` as the working directory and ``kernel`` prepended to
    ``sys.path``, with a 600 second limit.

    Returns:
        ``(True, None)`` when the import succeeds. Otherwise
        ``(False, message)``, where ``message`` is one of:

        * the child's non-empty stderr and stdout, newline-joined;
        * ``"exit <code>"`` if the child printed nothing;
        * the text of the exception that stopped the check from running,
          such as the timeout.
    """
    try:
        proc = subprocess.run(
            [
                sys.executable,
                "-c",
                "import sys; sys.path.insert(0, 'kernel'); import host, kernel_jit",
            ],
            cwd=TASK_DIR,
            capture_output=True,
            text=True,
            timeout=600,
        )
    except Exception as exc:
        # Best effort: any launch failure or timeout counts as a compile failure.
        return False, str(exc) or type(exc).__name__
    if proc.returncode == 0:
        return True, None
    # Join only non-empty streams so a silent failure doesn't report a bare "\n".
    output = "\n".join(s for s in (proc.stderr, proc.stdout) if s)
    return False, output or f"exit {proc.returncode}"
|
|
|
|
def run_correctness():
    """Run the harness in ``--correctness`` mode and judge the outcome.

    The check passes only if the harness exits with status 0 and its combined
    stdout+stderr contains the marker ``CORRECTNESS_OVERALL: PASS``.

    Returns:
        ``(True, None)`` on pass. Otherwise ``(False, detail)``, where
        ``detail`` is the last 8000 characters of the combined output, or
        ``"exit <code>"`` if the harness printed nothing.
    """
    result = _run_harness(["--correctness"])
    combined = "".join(stream or "" for stream in (result.stdout, result.stderr))
    if result.returncode == 0 and "CORRECTNESS_OVERALL: PASS" in combined:
        return True, None
    # A negative slice on a shorter string simply yields the whole string.
    return False, combined[-8000:] or f"exit {result.returncode}"
|
|
|
|
# One decimal or scientific float literal ("1.5", ".25", "3.", "-2.5e-3").
# Unlike the character class [0-9.eE+-]+, every match is valid float() input,
# so trailing punctuation like "1.23." can't raise ValueError.
_FLOAT_PATTERN = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?"
_CASE_LATENCY_RE = re.compile(
    r"CASE=(\S+)\s+GEAK_RESULT_LATENCY_MS=(" + _FLOAT_PATTERN + r")"
)
_LATENCY_RE = re.compile(r"GEAK_RESULT_LATENCY_MS=(" + _FLOAT_PATTERN + r")")


def _parse_latency_rows(out):
    """Extract latency rows from harness output.

    Collects every ``CASE=<id> GEAK_RESULT_LATENCY_MS=<ms>`` occurrence. If
    there are none, falls back to the first bare
    ``GEAK_RESULT_LATENCY_MS=<ms>`` and reports it as case ``"aggregate"``.
    Returns an empty list if neither form is found.
    """
    rows = [
        {
            "test_case_id": m.group(1),
            "execution_time_ms": float(m.group(2)),
            "params": {},
        }
        for m in _CASE_LATENCY_RE.finditer(out)
    ]
    if not rows:
        m = _LATENCY_RE.search(out)
        if m:
            rows.append(
                {
                    "test_case_id": "aggregate",
                    "execution_time_ms": float(m.group(1)),
                    "params": {},
                }
            )
    return rows


def run_performance():
    """Run the harness in ``--benchmark`` mode and collect per-case latencies.

    Returns:
        A list of ``{"test_case_id", "execution_time_ms", "params"}`` dicts.
        If the harness exits non-zero, reports no latency, or reports a
        negative or non-finite value, a single ``"_error"`` row is returned
        instead. That row has ``execution_time_ms == -1.0`` and the last
        4000 characters of the combined output in ``params["stderr"]``.
    """
    p = _run_harness(["--benchmark"])
    out = (p.stdout or "") + (p.stderr or "")
    rows = _parse_latency_rows(out)
    inf = float("inf")
    # 0 <= ms < inf rejects negatives, NaN and overflowed literals such as
    # "1e999". The last would otherwise pass and be dumped as JSON Infinity.
    ok = (
        p.returncode == 0
        and bool(rows)
        and all(0 <= r["execution_time_ms"] < inf for r in rows)
    )
    if not ok:
        return [
            {
                "test_case_id": "_error",
                "execution_time_ms": -1.0,
                "params": {"stderr": out[-4000:]},
            }
        ]
    return rows
|
|
|
|
def _write_report(name, payload, **dump_kwargs):
    """Write *payload* as JSON to ``BUILD_DIR/<name>``, replacing any old file."""
    with open(os.path.join(BUILD_DIR, name), "w", encoding="utf-8") as f:
        json.dump(payload, f, **dump_kwargs)


def _run_status_mode(runner, report_name, label):
    """Run an ``(ok, err)`` check, record it, print a summary, then exit 0/1."""
    try:
        ok, err = runner()
    except Exception as exc:
        # Record an unexpected crash as a failure, so the report always
        # reflects this run rather than a stale one left by an earlier run.
        ok, err = False, f"{type(exc).__name__}: {exc}"
    _write_report(report_name, {"status": "ok" if ok else "fail", "error": err})
    print(f"{label}: {'PASS' if ok else 'FAIL'}")
    if err:
        print("Error:", err)
    sys.exit(0 if ok else 1)


def _run_performance_mode():
    """Run the benchmark, write its report, print latencies, and exit 0/1."""
    try:
        cases = run_performance()
    except Exception as exc:
        # Use the same error-row shape run_performance() emits for failures.
        cases = [
            {
                "test_case_id": "_error",
                "execution_time_ms": -1.0,
                "params": {"stderr": f"{type(exc).__name__}: {exc}"},
            }
        ]
    _write_report("performance_report.json", {"test_cases": cases}, indent=2)
    for case in cases:
        ms = case.get("execution_time_ms", -1)
        if ms >= 0:
            print(f"Performance: {ms:.4f} ms ({case['test_case_id']})")
    bad = any(case.get("execution_time_ms", -1) < 0 for case in cases)
    sys.exit(1 if bad else 0)


def main():
    """Parse the requested mode, run it, write its JSON report under BUILD_DIR, and exit.

    Modes:
        compile      -> compile_report.json      (exit 0 on PASS, else 1)
        correctness  -> correctness_report.json  (exit 0 on PASS, else 1)
        performance  -> performance_report.json  (exit 1 if any case has a
                                                  negative execution_time_ms)
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("mode", choices=["compile", "correctness", "performance"])
    args = ap.parse_args()
    os.makedirs(BUILD_DIR, exist_ok=True)
    if args.mode == "compile":
        _run_status_mode(run_compile, "compile_report.json", "Compilation")
    elif args.mode == "correctness":
        _run_status_mode(run_correctness, "correctness_report.json", "Correctness")
    else:
        _run_performance_mode()
|
|
|
|
if __name__ == "__main__":
    # Script entry point. main() ends every mode with sys.exit(), so the
    # process exit status is the mode's pass/fail result.
    main()
|
|