import glob
import os
import subprocess
import sys
import time

import numpy as np

# Make the project's local modules importable, and resolve the relative
# "benchmarks/" and "models/" paths used below against the project root.
PROJECT_DIR = os.path.expanduser("~/projects/machineai")
sys.path.insert(0, PROJECT_DIR)
os.chdir(PROJECT_DIR)

from stable_baselines3 import PPO  # noqa: E402  (kept after the path setup, as originally ordered)
from ir_feature_extractor import extract_features  # noqa: E402  (project-local; needs sys.path above)

# glob() returns matches in arbitrary, filesystem-dependent order; sort so the
# benchmark table is printed in a stable, reproducible order.
SOURCE_FILES = sorted(glob.glob("benchmarks/*.c"))
model = PPO.load("models/x86v2_base")

# Action index predicted by the model -> `opt --passes=` pipeline.
# An empty string means the -O1 bitcode is used unchanged (no `opt` run).
passes = {
    0: "",
    1: "loop-vectorize",
    2: "inline,loop-vectorize",
    3: "loop-unroll,loop-vectorize",
    4: "inline,loop-unroll,loop-vectorize",
    5: "loop-unroll",
}
def measure(exe, n=9, timeout=None):
    """Return the median wall-clock run time of ``exe`` in seconds.

    The executable is launched ``n`` times with no arguments. Only runs that
    exit with status 0 contribute a timing sample.

    Args:
        exe: Path of the executable to run.
        n: Number of timed runs.
        timeout: Optional per-run limit in seconds. A run that exceeds it
            stops the measurement instead of hanging the whole benchmark.
            ``None`` (the default) waits indefinitely.

    Returns:
        The median of the successful run times, or ``999.0`` when no run
        succeeded. This includes the case where ``exe`` cannot be launched
        at all, e.g. because its build failed.
    """
    times = []
    for _ in range(n):
        start = time.perf_counter()
        try:
            result = subprocess.run(
                [exe],
                # No stdin: a benchmark that reads input must not block on
                # the terminal. Output is discarded rather than buffered in
                # memory, since it is never inspected.
                stdin=subprocess.DEVNULL,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                timeout=timeout,
            )
        except (OSError, subprocess.TimeoutExpired):
            # The binary is missing/unlaunchable or it hung. Repeating the
            # attempt would only fail (or wait) again.
            break
        elapsed = time.perf_counter() - start
        if result.returncode == 0:
            times.append(elapsed)
    return float(np.median(times)) if times else 999.0
def _stem(path):
    """Return *path* with only its final extension removed ("b/foo.c" -> "b/foo").

    Unlike ``str.replace(".c", ...)``, this touches only the real suffix, so a
    name such as ``my.crc.c`` is not mangled.
    """
    return os.path.splitext(path)[0]


def _run_ok(cmd):
    """Run a build command and return True only if it exits with status 0.

    Failures (a non-zero exit, or the tool missing from PATH) are reported, so
    a broken build is never silently benchmarked using a stale artifact left
    over from an earlier run.
    """
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
    except OSError as err:
        print(f"   ! {cmd[0]}: {err}")
        return False
    if result.returncode != 0:
        print(f"   ! {' '.join(cmd)} exited with {result.returncode}")
        if result.stderr.strip():
            print(result.stderr.rstrip())
        return False
    return True


def _skip(name, reason):
    """Report a benchmark that produced no comparable timing."""
    print(f" {name:<20} SKIPPED ({reason})")


print("=" * 60)
print(" BENCHMARK RESULTS — AI vs -O1 baseline")
print("=" * 60)
total_pct = []
for src in SOURCE_FILES:
    name = os.path.basename(src)
    stem = _stem(src)
    bc = stem + ".bc"
    if not _run_ok(["clang", "-O1", "-emit-llvm", "-c", src, "-o", bc]):
        _skip(name, "bitcode build failed")
        continue

    obs = extract_features(bc)
    raw_action, _ = model.predict(obs, deterministic=True)
    # predict() hands back an ndarray. .item() extracts the single id without
    # relying on int() of an ndim>0 array, which is deprecated since NumPy 1.25.
    action = int(np.asarray(raw_action).item())
    if action not in passes:
        _skip(name, f"unknown action {action}")
        continue

    # An empty pipeline means "use the -O1 bitcode as-is".
    p = passes[action]
    out_bc = bc
    if p:
        out_bc = stem + "_ai.bc"
        if not _run_ok(["opt", f"--passes={p}", bc, "-o", out_bc]):
            _skip(name, f"opt --passes={p} failed")
            continue

    base_exe = stem + "_o1_exe"
    ai_exe = stem + "_ai_exe"
    if not _run_ok(["clang", "-O1", bc, "-o", base_exe, "-lm"]):
        _skip(name, "baseline link failed")
        continue
    if not _run_ok(["clang", "-O1", out_bc, "-o", ai_exe, "-lm"]):
        _skip(name, "AI link failed")
        continue

    t_base = measure(base_exe)
    t_ai = measure(ai_exe)
    # measure() reports 999.0 when no run succeeded. That value is not a
    # timing and would turn the percentage and the average into noise.
    if t_base == 999.0 or t_ai == 999.0:
        _skip(name, "executable failed to run")
        continue

    # Percentage reduction in median run time (positive = the AI build is faster).
    pct = (t_base - t_ai) / (t_base + 1e-9) * 100
    total_pct.append(pct)
    print(f" {name:<20} action={action} base={t_base*1000:.1f}ms ai={t_ai*1000:.1f}ms {pct:+.1f}%")
print("=" * 60)
if total_pct:
    print(f" 평균 speedup: {np.mean(total_pct):+.1f}%")
else:
    # np.mean([]) would print "nan" and emit a RuntimeWarning.
    print(" 평균 speedup: n/a (no benchmark completed)")
print("=" * 60)