#!/usr/bin/env python3 """Benchmark Naive, RL, and Research LLM agents on the same eval seeds.""" from __future__ import annotations import argparse import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from lab_env.env import LabEnv, INITIAL_BUDGET from agents.naive_agent import NaiveAgent from agents.rl_agent import ReinforceAgent from agents.research_llm_agent import ResearchLLMAgent def run_episode_naive(env: LabEnv, agent: NaiveAgent, seed: int) -> dict: obs, info = env.reset(seed=seed) agent.reset() total_reward = 0.0 steps = 0 while True: action = agent.select_action(obs) obs, reward, terminated, truncated, info = env.step(action) total_reward += reward steps += 1 if terminated or truncated: break return { "reward": total_reward, "success": info["best_result"] == "success", "partial": info["best_result"] == "partial", "minutes": info["elapsed_minutes"], "cost": INITIAL_BUDGET - info["remaining_budget"], "steps": steps, } def aggregate(results: list[dict]) -> dict: n = len(results) successes = [r["success"] for r in results] steps_to_success = [r["steps"] for r in results if r["success"]] or [0] return { "n": n, "avg_reward": sum(r["reward"] for r in results) / n, "success_rate": sum(successes) / n, "partial_rate": sum(r["partial"] for r in results) / n, "avg_minutes": sum(r["minutes"] for r in results) / n, "avg_cost": sum(r["cost"] for r in results) / n, "avg_steps": sum(r["steps"] for r in results) / n, "experiments_to_success": sum(steps_to_success) / len(steps_to_success) if steps_to_success else 0, } def main() -> None: parser = argparse.ArgumentParser(description="Compare Naive, RL, and Research LLM agents") parser.add_argument("--eval-episodes", type=int, default=50, help="Episodes per agent (eval)") parser.add_argument("--train-episodes", type=int, default=500, help="RL training episodes before eval") parser.add_argument("--seed", type=int, default=42) parser.add_argument("--max-trials", type=int, default=5, help="Max trials per episode (RL and LLM)") parser.add_argument("--no-llm", action="store_true", help="Skip LLM agent (no API key)") args = parser.parse_args() eval_seed_base = 100_000 + args.seed env = LabEnv() # ---- Naive ---- print("Running Naive agent...") naive_agent = NaiveAgent(num_trials=3, seed=args.seed) naive_results = [ run_episode_naive(env, naive_agent, eval_seed_base + i) for i in range(args.eval_episodes) ] naive_stats = aggregate(naive_results) # ---- RL (train then eval) ---- print("Training REINFORCE agent...") rl_agent = ReinforceAgent(max_trials=args.max_trials) for ep in range(1, args.train_episodes + 1): rl_agent.run_episode(env, seed=args.seed + ep, train=True) if ep % 100 == 0: print(f" RL train episode {ep}/{args.train_episodes}") print("Evaluating REINFORCE agent...") rl_results = [ rl_agent.run_episode(env, seed=eval_seed_base + i, train=False) for i in range(args.eval_episodes) ] rl_stats = aggregate(rl_results) # ---- Research LLM ---- llm_stats = None if not args.no_llm: print("Running Research LLM agent...") try: llm_agent = ResearchLLMAgent(max_trials=args.max_trials) llm_results = [ llm_agent.run_episode(env, seed=eval_seed_base + i) for i in range(args.eval_episodes) ] llm_stats = aggregate(llm_results) except Exception as e: print(f" Skipping LLM agent: {e}") env.close() # ---- Table ---- header = f"{'Metric':<22} {'Naive':>12} {'RL (MLP)':>12}" if llm_stats is not None: header += f" {'LLM Researcher':>14}" sep = "-" * len(header) print() print(sep) print(" Agent comparison (same eval seeds)") print(sep) print(header) print(sep) def row(label: str, n_val: str, r_val: str, l_val: str | None = None) -> None: line = f"{label:<22} {n_val:>12} {r_val:>12}" if l_val is not None: line += f" {l_val:>14}" print(line) row("Success rate", f"{naive_stats['success_rate']:.1%}", f"{rl_stats['success_rate']:.1%}", f"{llm_stats['success_rate']:.1%}" if llm_stats else None) row("Experiments to success", f"{naive_stats['experiments_to_success']:.1f}", f"{rl_stats['experiments_to_success']:.1f}", f"{llm_stats['experiments_to_success']:.1f}" if llm_stats else None) row("Cost/episode", f"${naive_stats['avg_cost']:.1f}", f"${rl_stats['avg_cost']:.1f}", f"${llm_stats['avg_cost']:.1f}" if llm_stats else None) row("Avg reward", f"{naive_stats['avg_reward']:.1f}", f"{rl_stats['avg_reward']:.1f}", f"{llm_stats['avg_reward']:.1f}" if llm_stats else None) row("Avg steps", f"{naive_stats['avg_steps']:.1f}", f"{rl_stats['avg_steps']:.1f}", f"{llm_stats['avg_steps']:.1f}" if llm_stats else None) print(sep) if __name__ == "__main__": main()