File size: 5,269 Bytes
da63ca8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 | #!/usr/bin/env python3
"""Benchmark Naive, RL, and Research LLM agents on the same eval seeds."""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from lab_env.env import LabEnv, INITIAL_BUDGET
from agents.naive_agent import NaiveAgent
from agents.rl_agent import ReinforceAgent
from agents.research_llm_agent import ResearchLLMAgent
def run_episode_naive(env: LabEnv, agent: NaiveAgent, seed: int) -> dict:
    """Play one complete episode with the naive agent and summarize it.

    The environment is reset with ``seed`` first, then the agent's own state is
    cleared. Actions chosen by ``agent.select_action`` are fed to ``env.step``
    until the environment reports either termination or truncation.

    Returns:
        A dict with the summed ``reward``; ``success`` / ``partial`` flags taken
        from the final step's ``info["best_result"]``; the final
        ``info["elapsed_minutes"]`` as ``minutes``; the budget spent
        (``INITIAL_BUDGET`` minus the final ``info["remaining_budget"]``) as
        ``cost``; and the number of ``steps`` taken.
    """
    observation, _ = env.reset(seed=seed)
    agent.reset()

    episode_reward = 0.0
    step_count = 0
    done = False
    while not done:
        chosen = agent.select_action(observation)
        observation, step_reward, terminated, truncated, step_info = env.step(chosen)
        episode_reward += step_reward
        step_count += 1
        done = terminated or truncated

    # The loop body always runs at least once, so step_info is bound here.
    outcome = step_info["best_result"]
    return {
        "reward": episode_reward,
        "success": outcome == "success",
        "partial": outcome == "partial",
        "minutes": step_info["elapsed_minutes"],
        "cost": INITIAL_BUDGET - step_info["remaining_budget"],
        "steps": step_count,
    }
def aggregate(results: list[dict]) -> dict:
    """Average per-episode metrics over a list of episode results.

    Args:
        results: Episode summaries, each a dict with keys ``reward``,
            ``success``, ``partial``, ``minutes``, ``cost`` and ``steps`` (the
            shape returned by ``run_episode_naive``).

    Returns:
        A dict with the episode count ``n`` and the mean of each metric.
        ``success_rate`` / ``partial_rate`` are the fraction of episodes whose
        flag is truthy. ``experiments_to_success`` is the mean ``steps`` over
        successful episodes only, and is ``0.0`` when no episode succeeded.
        An empty ``results`` list yields ``n == 0`` and all-zero averages
        instead of raising ``ZeroDivisionError``.
    """

    def mean(values) -> float:
        # Single empty-safe average shared by every metric, so neither an
        # empty result list nor "no successes" can divide by zero.
        values = list(values)
        return sum(values) / len(values) if values else 0.0

    return {
        "n": len(results),
        "avg_reward": mean(r["reward"] for r in results),
        "success_rate": mean(r["success"] for r in results),
        "partial_rate": mean(r["partial"] for r in results),
        "avg_minutes": mean(r["minutes"] for r in results),
        "avg_cost": mean(r["cost"] for r in results),
        "avg_steps": mean(r["steps"] for r in results),
        "experiments_to_success": mean(r["steps"] for r in results if r["success"]),
    }
def _print_comparison_table(
    naive_stats: dict, rl_stats: dict, llm_stats: dict | None
) -> None:
    """Print a fixed-width table comparing aggregated agent metrics.

    The LLM column (header and cells) is printed only when ``llm_stats`` is
    not ``None``.
    """
    # (row label, stats key, format spec, literal prefix) per printed metric.
    metrics = [
        ("Success rate", "success_rate", ".1%", ""),
        ("Experiments to success", "experiments_to_success", ".1f", ""),
        ("Cost/episode", "avg_cost", ".1f", "$"),
        ("Avg reward", "avg_reward", ".1f", ""),
        ("Avg steps", "avg_steps", ".1f", ""),
    ]

    header = f"{'Metric':<22} {'Naive':>12} {'RL (MLP)':>12}"
    if llm_stats is not None:
        header += f" {'LLM Researcher':>14}"
    sep = "-" * len(header)

    print()
    print(sep)
    print(" Agent comparison (same eval seeds)")
    print(sep)
    print(header)
    print(sep)
    for label, key, spec, prefix in metrics:
        naive_cell = f"{prefix}{naive_stats[key]:{spec}}"
        rl_cell = f"{prefix}{rl_stats[key]:{spec}}"
        line = f"{label:<22} {naive_cell:>12} {rl_cell:>12}"
        if llm_stats is not None:
            llm_cell = f"{prefix}{llm_stats[key]:{spec}}"
            line += f" {llm_cell:>14}"
        print(line)
    print(sep)


def main() -> None:
    """Benchmark the Naive, REINFORCE and Research LLM agents and print a table.

    All agents are evaluated on the same seeds, ``100_000 + --seed + i`` for
    ``i`` in ``range(--eval-episodes)``. The REINFORCE agent is first trained
    on seeds ``--seed + 1`` through ``--seed + --train-episodes``. The LLM
    agent is best-effort: any exception while constructing or running it is
    printed and its column is omitted. The shared environment is closed even
    if an agent run raises.
    """
    parser = argparse.ArgumentParser(description="Compare Naive, RL, and Research LLM agents")
    parser.add_argument("--eval-episodes", type=int, default=50, help="Episodes per agent (eval)")
    parser.add_argument("--train-episodes", type=int, default=500, help="RL training episodes before eval")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max-trials", type=int, default=5, help="Max trials per episode (RL and LLM)")
    parser.add_argument("--no-llm", action="store_true", help="Skip LLM agent (no API key)")
    args = parser.parse_args()
    if args.eval_episodes < 1:
        # Every reported metric is an average over eval episodes. Fail fast
        # here rather than after a long RL training run.
        parser.error("--eval-episodes must be at least 1")

    eval_seed_base = 100_000 + args.seed
    eval_seeds = [eval_seed_base + i for i in range(args.eval_episodes)]
    # NOTE(review): training seeds (seed + 1 .. seed + train_episodes) only
    # stay disjoint from eval seeds while --train-episodes < 100_000.

    env = LabEnv()
    try:
        # ---- Naive ----
        print("Running Naive agent...")
        # The naive agent's trial count is fixed at 3; --max-trials applies
        # only to the RL and LLM agents (per its help text).
        naive_agent = NaiveAgent(num_trials=3, seed=args.seed)
        naive_stats = aggregate(
            [run_episode_naive(env, naive_agent, s) for s in eval_seeds]
        )

        # ---- RL (train then eval) ----
        print("Training REINFORCE agent...")
        rl_agent = ReinforceAgent(max_trials=args.max_trials)
        for ep in range(1, args.train_episodes + 1):
            rl_agent.run_episode(env, seed=args.seed + ep, train=True)
            if ep % 100 == 0:
                print(f" RL train episode {ep}/{args.train_episodes}")
        print("Evaluating REINFORCE agent...")
        rl_stats = aggregate(
            [rl_agent.run_episode(env, seed=s, train=False) for s in eval_seeds]
        )

        # ---- Research LLM ----
        llm_stats = None
        if not args.no_llm:
            print("Running Research LLM agent...")
            try:
                llm_agent = ResearchLLMAgent(max_trials=args.max_trials)
                llm_results = [llm_agent.run_episode(env, seed=s) for s in eval_seeds]
                llm_stats = aggregate(llm_results)
            except Exception as e:
                # Deliberate best-effort: a missing API key or LLM failure
                # drops the LLM column instead of aborting the benchmark.
                print(f" Skipping LLM agent: {e}")
    finally:
        # Release the environment even if an agent run raised.
        env.close()

    # ---- Table ----
    _print_comparison_table(naive_stats, rl_stats, llm_stats)
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()
|