biosim / scripts /compare_all_agents.py
arminfg's picture
SimLab: lab automation RL env, OpenEnv adapter, Training UI, agents
da63ca8
#!/usr/bin/env python3
"""Benchmark Naive, RL, and Research LLM agents on the same eval seeds."""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from lab_env.env import LabEnv, INITIAL_BUDGET
from agents.naive_agent import NaiveAgent
from agents.rl_agent import ReinforceAgent
from agents.research_llm_agent import ResearchLLMAgent
def run_episode_naive(env: LabEnv, agent: NaiveAgent, seed: int) -> dict:
    """Roll out a single evaluation episode of ``agent`` in ``env``.

    The environment is reset with ``seed`` and the agent is reset before the
    first action; actions are taken until the environment reports the episode
    as terminated or truncated.

    Returns a dict with the summed episode reward (``reward``), whether the
    final ``best_result`` was ``"success"`` / ``"partial"``, the elapsed
    minutes, the budget spent (``INITIAL_BUDGET`` minus the remaining budget),
    and the number of steps taken.
    """
    observation, info = env.reset(seed=seed)
    agent.reset()
    episode_return = 0.0
    step_count = 0
    episode_over = False
    # At least one step always runs, so ``info`` below comes from the last step.
    while not episode_over:
        observation, reward, terminated, truncated, info = env.step(agent.select_action(observation))
        episode_return += reward
        step_count += 1
        episode_over = terminated or truncated
    best_result = info["best_result"]
    return {
        "reward": episode_return,
        "success": best_result == "success",
        "partial": best_result == "partial",
        "minutes": info["elapsed_minutes"],
        "cost": INITIAL_BUDGET - info["remaining_budget"],
        "steps": step_count,
    }
def _mean(values) -> float:
    """Return the arithmetic mean of ``values``, or 0.0 when there are none."""
    items = list(values)
    return sum(items) / len(items) if items else 0.0


def aggregate(results: list[dict]) -> dict:
    """Summarise per-episode result dicts into averaged benchmark statistics.

    Each element of ``results`` must provide the keys ``reward``, ``success``,
    ``partial``, ``minutes``, ``cost`` and ``steps`` (as produced by the
    episode runners in this script).

    Returns a dict with the episode count ``n`` and the mean reward, success
    rate, partial rate, minutes, cost and steps. ``experiments_to_success`` is
    the mean ``steps`` over successful episodes only, and 0.0 when no episode
    succeeded (so 0.0 means "never succeeded", not "succeeded instantly").
    An empty ``results`` list yields ``n == 0`` and 0.0 for every average
    instead of raising ZeroDivisionError.
    """
    return {
        "n": len(results),
        "avg_reward": _mean(r["reward"] for r in results),
        "success_rate": _mean(r["success"] for r in results),
        "partial_rate": _mean(r["partial"] for r in results),
        "avg_minutes": _mean(r["minutes"] for r in results),
        "avg_cost": _mean(r["cost"] for r in results),
        "avg_steps": _mean(r["steps"] for r in results),
        "experiments_to_success": _mean(r["steps"] for r in results if r["success"]),
    }
# (row label, key in the aggregate() stats dict, format string for the value)
_TABLE_ROWS = (
    ("Success rate", "success_rate", "{:.1%}"),
    ("Experiments to success", "experiments_to_success", "{:.1f}"),
    ("Cost/episode", "avg_cost", "${:.1f}"),
    ("Avg reward", "avg_reward", "{:.1f}"),
    ("Avg steps", "avg_steps", "{:.1f}"),
)


def _print_comparison_table(naive_stats: dict, rl_stats: dict, llm_stats: dict | None) -> None:
    """Print a fixed-width table comparing the aggregated stats of each agent.

    The "LLM Researcher" column is only printed when ``llm_stats`` is not None.
    """
    show_llm = llm_stats is not None
    header = f"{'Metric':<22} {'Naive':>12} {'RL (MLP)':>12}"
    if show_llm:
        header += f" {'LLM Researcher':>14}"
    sep = "-" * len(header)
    print()
    print(sep)
    print(" Agent comparison (same eval seeds)")
    print(sep)
    print(header)
    print(sep)
    for label, key, fmt in _TABLE_ROWS:
        line = f"{label:<22} {fmt.format(naive_stats[key]):>12} {fmt.format(rl_stats[key]):>12}"
        if show_llm:
            line += f" {fmt.format(llm_stats[key]):>14}"
        print(line)
    print(sep)


def main() -> None:
    """Benchmark the Naive, REINFORCE and Research LLM agents and print a table.

    The Naive agent is evaluated directly; the REINFORCE agent is first trained
    for ``--train-episodes`` episodes and then evaluated with ``train=False``;
    the Research LLM agent is evaluated unless ``--no-llm`` is given, and is
    skipped (with a message) if constructing or running it raises. All agents
    are evaluated on the same list of eval seeds and share one ``LabEnv``,
    which is closed even if a run fails.
    """
    parser = argparse.ArgumentParser(description="Compare Naive, RL, and Research LLM agents")
    parser.add_argument("--eval-episodes", type=int, default=50, help="Episodes per agent (eval)")
    parser.add_argument("--train-episodes", type=int, default=500, help="RL training episodes before eval")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max-trials", type=int, default=5, help="Max trials per episode (RL and LLM)")
    parser.add_argument("--no-llm", action="store_true", help="Skip LLM agent (no API key)")
    args = parser.parse_args()
    if args.eval_episodes < 1:
        # Every statistic is averaged over the eval episodes; zero episodes has no stats.
        parser.error("--eval-episodes must be at least 1")

    # RL training uses seeds args.seed + 1 .. args.seed + train_episodes. Start the
    # eval seeds past that range (default offset 100_000) so no agent is evaluated
    # on a seed the RL agent trained on.
    eval_seed_base = args.seed + max(100_000, args.train_episodes + 1)
    eval_seeds = [eval_seed_base + i for i in range(args.eval_episodes)]

    env = LabEnv()
    try:
        # ---- Naive ----
        print("Running Naive agent...")
        naive_agent = NaiveAgent(num_trials=3, seed=args.seed)
        naive_stats = aggregate([run_episode_naive(env, naive_agent, s) for s in eval_seeds])

        # ---- RL (train then eval) ----
        print("Training REINFORCE agent...")
        rl_agent = ReinforceAgent(max_trials=args.max_trials)
        for ep in range(1, args.train_episodes + 1):
            rl_agent.run_episode(env, seed=args.seed + ep, train=True)
            if ep % 100 == 0:
                print(f" RL train episode {ep}/{args.train_episodes}")
        print("Evaluating REINFORCE agent...")
        rl_stats = aggregate([rl_agent.run_episode(env, seed=s, train=False) for s in eval_seeds])

        # ---- Research LLM (optional, best-effort) ----
        llm_stats = None
        if not args.no_llm:
            print("Running Research LLM agent...")
            try:
                llm_agent = ResearchLLMAgent(max_trials=args.max_trials)
                llm_stats = aggregate([llm_agent.run_episode(env, seed=s) for s in eval_seeds])
            except Exception as e:
                # Broad on purpose: the LLM agent is optional, so any failure
                # (e.g. no API key configured) skips it instead of aborting the run.
                print(f" Skipping LLM agent: {e}")
    finally:
        env.close()

    _print_comparison_table(naive_stats, rl_stats, llm_stats)
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()