#!/usr/bin/env python3
"""GEAK task runner — Triton kernel tasks.

Invoked as ``<this script> {compile,correctness,performance}``. Each mode
writes a JSON report into ``<task>/build/`` and exits with status 0 on
success and 1 on failure:

* ``compile``     -> ``compile_report.json``: imports the ``host`` and
  ``kernel_jit`` modules from ``kernel/`` in a fresh interpreter.
* ``correctness`` -> ``correctness_report.json``: runs
  ``kernel/test_harness.py --correctness`` and requires both exit code 0 and
  a ``CORRECTNESS_OVERALL: PASS`` marker in its output.
* ``performance`` -> ``performance_report.json``: runs
  ``kernel/test_harness.py --benchmark`` and collects the
  ``GEAK_RESULT_LATENCY_MS`` values it prints.
"""

import argparse
import json
import os
import re
import subprocess
import sys

# The task root is the parent of the directory holding this script. Relative
# paths used below (e.g. the compile check's 'kernel') rely on it being the
# working directory.
TASK_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
os.chdir(TASK_DIR)
BUILD_DIR = os.path.join(TASK_DIR, "build")
KERNEL_DIR = os.path.join(TASK_DIR, "kernel")
HARNESS = os.path.join(KERNEL_DIR, "test_harness.py")

# How much trailing harness output to keep in failure reports.
_CORRECTNESS_TAIL_CHARS = 8000
_PERFORMANCE_TAIL_CHARS = 4000

# "CASE=<id> GEAK_RESULT_LATENCY_MS=<number>" lines, one per benchmarked case.
_CASE_LATENCY_RE = re.compile(r"CASE=(\S+)\s+GEAK_RESULT_LATENCY_MS=([0-9.eE+-]+)")
# Fallback: a bare latency line, used only when no per-case line was found.
_AGGREGATE_LATENCY_RE = re.compile(r"GEAK_RESULT_LATENCY_MS=([0-9.eE+-]+)")


def _run_harness(argv):
    """Run the kernel test harness with *argv* and return the CompletedProcess.

    stdout/stderr are captured as text. Raises ``subprocess.TimeoutExpired``
    after 7200 s and ``OSError`` if the process cannot be started.
    """
    return subprocess.run(
        [sys.executable, HARNESS] + argv,
        cwd=KERNEL_DIR,
        capture_output=True,
        text=True,
        timeout=7200,
    )


def run_compile():
    """Check that the ``host`` and ``kernel_jit`` modules under ``kernel/`` import.

    Returns:
        ``(True, None)`` on success. Otherwise ``(False, message)``, where the
        message is the child's stderr followed by its stdout. If the child
        could not be run or exceeded 600 s, the message is the exception text.
    """
    try:
        p = subprocess.run(
            [
                sys.executable,
                "-c",
                "import sys; sys.path.insert(0, 'kernel'); import host, kernel_jit",
            ],
            cwd=TASK_DIR,
            capture_output=True,
            text=True,
            timeout=600,
        )
    except Exception as e:  # best-effort: any launch failure is a compile failure
        return False, str(e)
    if p.returncode != 0:
        return False, (p.stderr or "") + "\n" + (p.stdout or "")
    return True, None


def run_correctness():
    """Run the harness in ``--correctness`` mode.

    Returns:
        ``(True, None)`` when the harness exits 0 and prints
        ``CORRECTNESS_OVERALL: PASS``. Otherwise ``(False, message)``. The
        message is the tail of the combined stdout+stderr, or ``"exit <code>"``
        if there was no output. If the harness could not be run or timed out,
        the message is the exception text.
    """
    try:
        p = _run_harness(["--correctness"])
    except (subprocess.SubprocessError, OSError) as e:
        # Like run_compile: a harness that cannot run (or times out) is a
        # reported failure, not a crash that leaves no report behind.
        return False, str(e)
    out = (p.stdout or "") + (p.stderr or "")
    if p.returncode == 0 and "CORRECTNESS_OVERALL: PASS" in out:
        return True, None
    return False, out[-_CORRECTNESS_TAIL_CHARS:] or f"exit {p.returncode}"


def _parse_ms(text):
    """Convert a captured latency string to float, or -1.0 if it is malformed."""
    try:
        return float(text)
    except ValueError:
        # The regex character class also admits junk such as "-" or "1e".
        # Treat it as a failed measurement rather than crashing the runner.
        return -1.0


def parse_latency_rows(out):
    """Extract latency rows from harness output.

    Args:
        out: Combined stdout+stderr text of a ``--benchmark`` run.

    Returns:
        One ``{"test_case_id", "execution_time_ms", "params"}`` dict per
        ``CASE=... GEAK_RESULT_LATENCY_MS=...`` line. If there are no such
        lines, a single ``"aggregate"`` row built from the first bare
        ``GEAK_RESULT_LATENCY_MS=`` value. If there is none of either, an
        empty list. Malformed numbers become ``-1.0``.
    """
    rows = [
        {
            "test_case_id": m.group(1),
            "execution_time_ms": _parse_ms(m.group(2)),
            "params": {},
        }
        for m in _CASE_LATENCY_RE.finditer(out)
    ]
    if not rows:
        m = _AGGREGATE_LATENCY_RE.search(out)
        if m:
            rows.append(
                {
                    "test_case_id": "aggregate",
                    "execution_time_ms": _parse_ms(m.group(1)),
                    "params": {},
                }
            )
    return rows


def _performance_error(detail):
    """Build the single-row result reported when benchmarking fails."""
    return [
        {
            "test_case_id": "_error",
            "execution_time_ms": -1.0,
            "params": {"stderr": detail},
        }
    ]


def run_performance():
    """Run the harness in ``--benchmark`` mode and collect latencies.

    Returns:
        The rows from ``parse_latency_rows`` when the harness exits 0, at
        least one row was found, and every latency is non-negative.
        Otherwise a single ``"_error"`` row with latency ``-1.0``. Its
        ``params["stderr"]`` holds the output tail, or the exception text if
        the harness could not be run or timed out.
    """
    try:
        p = _run_harness(["--benchmark"])
    except (subprocess.SubprocessError, OSError) as e:
        return _performance_error(str(e))
    out = (p.stdout or "") + (p.stderr or "")
    rows = parse_latency_rows(out)
    if p.returncode == 0 and rows and all(r["execution_time_ms"] >= 0 for r in rows):
        return rows
    return _performance_error(out[-_PERFORMANCE_TAIL_CHARS:])


def _write_report(filename, payload, **dump_kwargs):
    """Serialize *payload* as JSON to ``BUILD_DIR/filename``."""
    with open(os.path.join(BUILD_DIR, filename), "w", encoding="utf-8") as f:
        json.dump(payload, f, **dump_kwargs)


def _finish_status(label, filename, ok, err):
    """Write a pass/fail report, print a summary to stdout, and exit 0/1."""
    _write_report(filename, {"status": "ok" if ok else "fail", "error": err})
    print(f"{label}: {'PASS' if ok else 'FAIL'}")
    if err:
        print("Error:", err)
    sys.exit(0 if ok else 1)


def _finish_performance():
    """Benchmark, write the performance report, print latencies, and exit 0/1."""
    cases = run_performance()
    _write_report("performance_report.json", {"test_cases": cases}, indent=2)
    for c in cases:
        ms = c["execution_time_ms"]
        if ms >= 0:
            print(f"Performance: {ms:.4f} ms ({c['test_case_id']})")
    failures = [c for c in cases if c["execution_time_ms"] < 0]
    for c in failures:
        # stderr keeps stdout's "Performance: <ms>" lines unchanged for
        # anything that parses them, while still surfacing the failure.
        print(f"Performance: FAIL ({c['test_case_id']})", file=sys.stderr)
        detail = c["params"].get("stderr")
        if detail:
            print("Error:", detail, file=sys.stderr)
    sys.exit(1 if failures else 0)


def main():
    """Parse the mode argument and run the selected task stage."""
    ap = argparse.ArgumentParser()
    ap.add_argument("mode", choices=["compile", "correctness", "performance"])
    args = ap.parse_args()
    os.makedirs(BUILD_DIR, exist_ok=True)

    if args.mode == "compile":
        ok, err = run_compile()
        _finish_status("Compilation", "compile_report.json", ok, err)
    elif args.mode == "correctness":
        ok, err = run_correctness()
        _finish_status("Correctness", "correctness_report.json", ok, err)
    else:
        _finish_performance()


if __name__ == "__main__":
    main()