| |
| """Benchmark Naive, RL, and Research LLM agents on the same eval seeds.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
# Make the repository root importable so the `lab_env` and `agents` packages
# resolve when this script is run directly (not installed as a package).
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
| from lab_env.env import LabEnv, INITIAL_BUDGET |
| from agents.naive_agent import NaiveAgent |
| from agents.rl_agent import ReinforceAgent |
| from agents.research_llm_agent import ResearchLLMAgent |
|
|
|
|
def run_episode_naive(env: LabEnv, agent: NaiveAgent, seed: int) -> dict:
    """Roll out one complete episode with the naive agent and summarize it.

    Resets both the environment and the agent, steps until the episode
    terminates or is truncated, and returns a summary dict with the keys
    expected by ``aggregate``: reward, success, partial, minutes, cost, steps.
    """
    obs, info = env.reset(seed=seed)
    agent.reset()
    total_reward = 0.0
    steps = 0
    done = False
    while not done:
        obs, reward, terminated, truncated, info = env.step(agent.select_action(obs))
        total_reward += reward
        steps += 1
        done = terminated or truncated
    outcome = info["best_result"]
    return {
        "reward": total_reward,
        "success": outcome == "success",
        "partial": outcome == "partial",
        "minutes": info["elapsed_minutes"],
        # Budget spent = starting budget minus whatever the episode left over.
        "cost": INITIAL_BUDGET - info["remaining_budget"],
        "steps": steps,
    }
|
|
|
|
def aggregate(results: list[dict]) -> dict:
    """Aggregate per-episode result dicts into summary statistics.

    Args:
        results: Episode dicts as produced by the per-agent episode runners
            (keys: ``reward``, ``success``, ``partial``, ``minutes``,
            ``cost``, ``steps``).

    Returns:
        A dict of per-episode averages and rates, plus
        ``experiments_to_success``: the mean step count among *successful*
        episodes only (0.0 when no episode succeeded).

    Raises:
        ValueError: If ``results`` is empty (averages would be undefined).
    """
    if not results:
        raise ValueError("aggregate() requires at least one episode result")
    n = len(results)
    # Steps of successful episodes only; may legitimately be empty.
    # (Previously this was padded with `or [0]`, which made the fallback
    # branch below unreachable and hid the "no successes" case.)
    steps_to_success = [r["steps"] for r in results if r["success"]]
    return {
        "n": n,
        "avg_reward": sum(r["reward"] for r in results) / n,
        "success_rate": sum(r["success"] for r in results) / n,
        "partial_rate": sum(r["partial"] for r in results) / n,
        "avg_minutes": sum(r["minutes"] for r in results) / n,
        "avg_cost": sum(r["cost"] for r in results) / n,
        "avg_steps": sum(r["steps"] for r in results) / n,
        # 0.0 keeps the report printable when nothing succeeded.
        "experiments_to_success": (
            sum(steps_to_success) / len(steps_to_success) if steps_to_success else 0.0
        ),
    }
|
|
|
|
def main() -> None:
    """Benchmark the Naive, REINFORCE, and Research LLM agents.

    Every agent is evaluated on the same block of seeds
    (``100_000 + seed`` onward) so the comparison is apples-to-apples;
    the RL agent is first trained on a disjoint seed range. Prints a
    summary table to stdout.
    """
    parser = argparse.ArgumentParser(description="Compare Naive, RL, and Research LLM agents")
    parser.add_argument("--eval-episodes", type=int, default=50, help="Episodes per agent (eval)")
    parser.add_argument("--train-episodes", type=int, default=500, help="RL training episodes before eval")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max-trials", type=int, default=5, help="Max trials per episode (RL and LLM)")
    parser.add_argument("--no-llm", action="store_true", help="Skip LLM agent (no API key)")
    args = parser.parse_args()

    # aggregate() divides by the episode count; reject nonsensical values
    # up front instead of crashing later with a bare ZeroDivisionError.
    if args.eval_episodes < 1:
        parser.error("--eval-episodes must be >= 1")
    if args.train_episodes < 0:
        parser.error("--train-episodes must be >= 0")

    # Eval seeds are offset far away from the RL training seeds
    # (args.seed + ep) so the agent is never evaluated on a training episode.
    eval_seed_base = 100_000 + args.seed
    env = LabEnv()

    llm_stats = None
    try:
        # --- Naive baseline: fixed strategy, no learning ---
        print("Running Naive agent...")
        naive_agent = NaiveAgent(num_trials=3, seed=args.seed)
        naive_stats = aggregate([
            run_episode_naive(env, naive_agent, eval_seed_base + i)
            for i in range(args.eval_episodes)
        ])

        # --- REINFORCE: train on its own seed range, then evaluate ---
        print("Training REINFORCE agent...")
        rl_agent = ReinforceAgent(max_trials=args.max_trials)
        for ep in range(1, args.train_episodes + 1):
            rl_agent.run_episode(env, seed=args.seed + ep, train=True)
            if ep % 100 == 0:
                print(f" RL train episode {ep}/{args.train_episodes}")
        print("Evaluating REINFORCE agent...")
        rl_stats = aggregate([
            rl_agent.run_episode(env, seed=eval_seed_base + i, train=False)
            for i in range(args.eval_episodes)
        ])

        # --- LLM researcher: optional and best-effort (needs an API key);
        # a failure here drops the column instead of aborting the run ---
        if not args.no_llm:
            print("Running Research LLM agent...")
            try:
                llm_agent = ResearchLLMAgent(max_trials=args.max_trials)
                llm_stats = aggregate([
                    llm_agent.run_episode(env, seed=eval_seed_base + i)
                    for i in range(args.eval_episodes)
                ])
            except Exception as e:
                print(f" Skipping LLM agent: {e}")
    finally:
        # Release environment resources even when an agent run raises
        # (previously close() was skipped on any exception above).
        env.close()

    # --- Report ---
    header = f"{'Metric':<22} {'Naive':>12} {'RL (MLP)':>12}"
    if llm_stats is not None:
        header += f" {'LLM Researcher':>14}"
    sep = "-" * len(header)
    print()
    print(sep)
    print(" Agent comparison (same eval seeds)")
    print(sep)
    print(header)
    print(sep)

    def row(label: str, n_val: str, r_val: str, l_val: str | None = None) -> None:
        # One formatted table row; the LLM column is appended only when present.
        line = f"{label:<22} {n_val:>12} {r_val:>12}"
        if l_val is not None:
            line += f" {l_val:>14}"
        print(line)

    # Use `is not None` consistently (the header check above already does);
    # truthiness would also work but mixing the two styles invites bugs.
    row("Success rate", f"{naive_stats['success_rate']:.1%}", f"{rl_stats['success_rate']:.1%}",
        f"{llm_stats['success_rate']:.1%}" if llm_stats is not None else None)
    row("Experiments to success", f"{naive_stats['experiments_to_success']:.1f}", f"{rl_stats['experiments_to_success']:.1f}",
        f"{llm_stats['experiments_to_success']:.1f}" if llm_stats is not None else None)
    row("Cost/episode", f"${naive_stats['avg_cost']:.1f}", f"${rl_stats['avg_cost']:.1f}",
        f"${llm_stats['avg_cost']:.1f}" if llm_stats is not None else None)
    row("Avg reward", f"{naive_stats['avg_reward']:.1f}", f"{rl_stats['avg_reward']:.1f}",
        f"{llm_stats['avg_reward']:.1f}" if llm_stats is not None else None)
    row("Avg steps", f"{naive_stats['avg_steps']:.1f}", f"{rl_stats['avg_steps']:.1f}",
        f"{llm_stats['avg_steps']:.1f}" if llm_stats is not None else None)
    print(sep)
|
# Standard script entry point: run the benchmark only when executed
# directly, not when imported as a module.
if __name__ == "__main__":
    main()
|
|