# Meta-Hackathon-main / test_env.py
# Uploaded by Parth3841 using huggingface_hub (commit 7c2f148, verified).
"""
Quick sanity test for the Compiler Pass Ordering environment.
Run with: python test_env.py
Server must be running: uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
"""
from compiler_opt_env import CompilerOptAction, CompilerOptEnv
from compiler_opt_env.models import PASS_NAMES, TASK_EASY, TASK_MEDIUM, TASK_HARD
from compiler_opt_env.server.compiler_opt_env_environment import BASE_PASS_EFFECTS
def run_greedy(env, task_id: int, label: str):
    """Run one episode with a greedy agent and print a per-step trace.

    At every step the agent looks at ``obs.passes_available``, picks the pass
    with the largest entry in ``BASE_PASS_EFFECTS`` and submits it through
    ``env.step``. The loop ends when the observation reports ``done``.

    Args:
        env: Synchronous environment client. ``reset()`` and ``step(action)``
            must return a result object carrying an ``observation``
            attribute; the ``step`` result must also carry ``reward``.
        task_id: Task identifier forwarded in every ``CompilerOptAction``.
        label: Heading printed above the trace.

    Returns:
        Tuple ``(improvement_pct, grader_score)`` read from the final
        observation.

    Raises:
        ValueError: If the environment reports no available passes while the
            episode is not yet done.
    """
    print(f"\n{'─'*60}")
    print(f" {label} (greedy agent)")
    print(f"{'─'*60}")

    # Reset exactly once. The previous version probed reset() with hasattr(),
    # called it again for the observation and then a third time, so two
    # episodes were started and discarded before the one actually played.
    result = env.reset()
    obs = result.observation
    print(f" Program type: {obs.program_type}")
    print(f" Baseline cost: {obs.baseline_cost:.1f}")
    print()

    while not obs.done:
        available = obs.passes_available
        if not available:
            # max() on an empty sequence raises ValueError anyway; raise the
            # same type with a message that names the real problem.
            raise ValueError(
                "environment reported no available passes before the episode finished"
            )
        best = max(available, key=lambda p: BASE_PASS_EFFECTS[p])
        step_result = env.step(CompilerOptAction(pass_id=best, task_id=task_id))
        obs = step_result.observation
        print(f" Step {obs.step_count:2d}: {PASS_NAMES[best]:35s} "
              f"cost={obs.estimated_cost:7.1f} "
              f"improvement={obs.improvement_pct:5.1f}% "
              f"reward={step_result.reward:+.4f}")

    print(f"\n Greedy improvement: {obs.improvement_pct:.1f}%")
    print(f" Grader score: {obs.grader_score:.3f}")
    return obs.improvement_pct, obs.grader_score
def run_optimal_task1(env):
    """Replay a fixed, hand-picked pass order on Task 1 and print the trace.

    The order is the author's hand-crafted "alias → DCE → vectorization"
    chain. Passes are submitted one by one with ``TASK_EASY`` until either
    the list is exhausted or the observation reports ``done``.

    Args:
        env: Synchronous environment client whose ``reset()`` and
            ``step(action)`` results expose ``observation`` (and ``reward``
            for ``step``).

    Returns:
        Tuple ``(improvement_pct, grader_score)`` read from the last
        observation.
    """
    rule = "─" * 60
    print(f"\n{rule}")
    print(" Task 1 — Optimal sequence (alias → DCE → vectorization)")
    print(rule)

    obs = env.reset().observation
    print(f" Program type: {obs.program_type}")
    print(f" Baseline cost: {obs.baseline_cost:.1f}\n")

    # Ten entries: per the original author's note the list was padded to this
    # length so that the episode reaches done=True.
    pass_order = (13, 0, 4, 5, 2, 7, 1, 10, 8, 9)
    remaining = iter(pass_order)

    while not obs.done:
        pass_id = next(remaining, None)
        if pass_id is None:
            # Ran out of scripted passes before the episode ended.
            break
        outcome = env.step(CompilerOptAction(pass_id=pass_id, task_id=TASK_EASY))
        obs = outcome.observation
        trace = (
            f" Step {obs.step_count:2d}: {PASS_NAMES[pass_id]:35s} "
            f"cost={obs.estimated_cost:7.1f} "
            f"improvement={obs.improvement_pct:5.1f}% "
            f"reward={outcome.reward:+.4f}"
        )
        print(trace)

    print(f"\n Optimal improvement: {obs.improvement_pct:.1f}%")
    print(f" Grader score: {obs.grader_score:.3f}")
    return obs.improvement_pct, obs.grader_score
if __name__ == "__main__":
    # Script entry point: run every scenario against a local server and
    # print a side-by-side summary of improvement and grader score.
    banner = "=" * 60
    print("Compiler Pass Ordering — Environment Sanity Test")
    print(banner)

    with CompilerOptEnv(base_url="http://localhost:8000").sync() as env:
        # List elements are evaluated left to right, so the scenarios run in
        # this order: Task 1 greedy, Task 1 optimal, Task 2 greedy, Task 3
        # greedy. Each entry pairs a summary heading with the
        # (improvement_pct, grader_score) tuple returned by the runner.
        summary = [
            ("Task 1 greedy", run_greedy(env, TASK_EASY, "Task 1 (Easy)")),
            ("Task 1 optimal", run_optimal_task1(env)),
            ("Task 2 greedy", run_greedy(env, TASK_MEDIUM, "Task 2 (Medium)")),
            ("Task 3 greedy", run_greedy(env, TASK_HARD, "Task 3 (Hard)")),
        ]

        print(f"\n{banner}")
        print(" SUMMARY")
        print(banner)
        for heading, (improvement, score) in summary:
            print(f" {heading}: {improvement:.1f}% improvement score={score:.3f}")
        print()
        print(" Expected: greedy ~19-24% | optimal Task 1 ~40-50%")
        print(" If greedy << optimal: ✓ environment requires RL")
        print(banner)