# biosim/scripts/train_and_eval_agent.py
# arminfg's picture
# SimLab: lab automation RL env, OpenEnv adapter, Training UI, agents
# da63ca8
#!/usr/bin/env python3
"""Train a REINFORCE agent on LabEnv and compare against the naive baseline."""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from lab_env.env import LabEnv, INITIAL_BUDGET
from agents.naive_agent import NaiveAgent
from agents.rl_agent import ReinforceAgent
# ------------------------------------------------------------------
# Naive episode runner
# ------------------------------------------------------------------
def run_episode_naive(env: LabEnv, agent: NaiveAgent, seed: int) -> dict:
    """Roll out one complete episode with the naive agent and summarise it.

    The environment is reset with ``seed`` and the agent is reset before the
    first step. Stepping continues until the environment reports either
    ``terminated`` or ``truncated``.

    Returns:
        A dict with the keys ``reward`` (sum of step rewards), ``success`` and
        ``partial`` (whether the final ``info["best_result"]`` equals
        ``"success"`` / ``"partial"``), ``minutes`` (final
        ``info["elapsed_minutes"]``), ``cost`` (``INITIAL_BUDGET`` minus the
        final ``info["remaining_budget"]``) and ``steps`` (number of env steps).
    """
    obs, info = env.reset(seed=seed)
    agent.reset()

    episode_return = 0.0
    step_count = 0
    finished = False
    while not finished:
        obs, reward, terminated, truncated, info = env.step(agent.select_action(obs))
        episode_return += reward
        step_count += 1
        finished = terminated or truncated

    # Only the info dict from the final step is used for the summary.
    outcome = info["best_result"]
    spent = INITIAL_BUDGET - info["remaining_budget"]
    return {
        "reward": episode_return,
        "success": outcome == "success",
        "partial": outcome == "partial",
        "minutes": info["elapsed_minutes"],
        "cost": spent,
        "steps": step_count,
    }
# ------------------------------------------------------------------
# Aggregation
# ------------------------------------------------------------------
def aggregate(results: list[dict]) -> dict:
    """Average per-episode result dicts into a single summary.

    Args:
        results: Episode summaries, each carrying the keys ``reward``,
            ``success``, ``partial``, ``minutes``, ``cost`` and ``steps``.

    Returns:
        A dict with ``n`` (number of episodes) and the mean of each field:
        ``avg_reward``, ``success_rate``, ``partial_rate``, ``avg_minutes``,
        ``avg_cost`` and ``avg_steps``. For an empty ``results`` list every
        mean is reported as ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    n = len(results)

    def mean(key: str) -> float:
        # Guard the division so an empty batch yields 0.0 rather than crashing.
        return sum(r[key] for r in results) / n if n else 0.0

    return {
        "n": n,
        "avg_reward": mean("reward"),
        "success_rate": mean("success"),
        "partial_rate": mean("partial"),
        "avg_minutes": mean("minutes"),
        "avg_cost": mean("cost"),
        "avg_steps": mean("steps"),
    }
# ------------------------------------------------------------------
# Main
# ------------------------------------------------------------------
def _report_training_window(ep: int, rewards: list[float], successes: list[bool]) -> None:
    """Print the mean reward and success rate of the current logging window."""
    avg = sum(rewards) / len(rewards)
    sr = sum(successes) / len(successes)
    print(
        f" Episode {ep:5d} | avg reward (last {len(rewards)}): "
        f"{avg:7.1f} | success rate: {sr:.0%}"
    )


def main() -> None:
    """Train a REINFORCE agent, then compare it with the naive baseline.

    Parses the command-line options, trains for ``--train-episodes`` episodes
    (seeded with ``--seed + episode``), printing windowed statistics every
    ``--log-interval`` episodes, and then evaluates both the trained agent
    (``train=False``) and a ``NaiveAgent`` on the same fixed seed range before
    printing a side-by-side comparison table.

    Exits via ``parser.error`` when ``--log-interval`` or ``--eval-episodes``
    is not a positive integer.
    """
    parser = argparse.ArgumentParser(description="Train & evaluate REINFORCE agent")
    parser.add_argument("--train-episodes", type=int, default=2000)
    parser.add_argument("--eval-episodes", type=int, default=100)
    parser.add_argument("--log-interval", type=int, default=100)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--lr", type=float, default=3e-3)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--max-trials", type=int, default=4)
    args = parser.parse_args()

    # A zero interval would divide by zero in the modulo check below, and an
    # empty evaluation set has nothing to average; reject both up front.
    if args.log_interval < 1:
        parser.error("--log-interval must be a positive integer")
    if args.eval_episodes < 1:
        parser.error("--eval-episodes must be a positive integer")

    env = LabEnv()
    try:
        rl_agent = ReinforceAgent(lr=args.lr, gamma=args.gamma, max_trials=args.max_trials)

        # ---- Training ----
        print("=" * 60)
        print(" Training REINFORCE agent")
        print("=" * 60)
        window: list[float] = []
        successes_window: list[bool] = []
        for ep in range(1, args.train_episodes + 1):
            result = rl_agent.run_episode(env, seed=args.seed + ep, train=True)
            window.append(result["reward"])
            successes_window.append(result["success"])
            if ep % args.log_interval == 0:
                _report_training_window(ep, window, successes_window)
                window.clear()
                successes_window.clear()
        if window:
            # Report trailing episodes that did not fill a whole log window.
            _report_training_window(args.train_episodes, window, successes_window)

        # ---- Evaluation ----
        print()
        print("=" * 60)
        print(" Evaluating on fixed seed range")
        print("=" * 60)
        # Both agents are evaluated on the same seeds so the comparison is paired.
        eval_seed_base = 999_999
        rl_results = [
            rl_agent.run_episode(env, seed=eval_seed_base + i, train=False)
            for i in range(args.eval_episodes)
        ]
        naive_agent = NaiveAgent(num_trials=3, seed=0)
        naive_results = [
            run_episode_naive(env, naive_agent, seed=eval_seed_base + i)
            for i in range(args.eval_episodes)
        ]
    finally:
        # Release the environment even if training or evaluation raised.
        env.close()

    rl_stats = aggregate(rl_results)
    naive_stats = aggregate(naive_results)
    header = f"{'Metric':<20s} {'REINFORCE':>12s} {'Naive':>12s}"
    sep = "-" * len(header)
    rows = [
        ("Avg reward", f"{rl_stats['avg_reward']:.1f}", f"{naive_stats['avg_reward']:.1f}"),
        ("Success rate", f"{rl_stats['success_rate']:.1%}", f"{naive_stats['success_rate']:.1%}"),
        ("Partial rate", f"{rl_stats['partial_rate']:.1%}", f"{naive_stats['partial_rate']:.1%}"),
        ("Avg time", f"{rl_stats['avg_minutes']:.1f}m", f"{naive_stats['avg_minutes']:.1f}m"),
        ("Avg cost", f"${rl_stats['avg_cost']:.1f}", f"${naive_stats['avg_cost']:.1f}"),
        ("Avg steps", f"{rl_stats['avg_steps']:.1f}", f"{naive_stats['avg_steps']:.1f}"),
    ]
    print()
    print(header)
    print(sep)
    for label, rl_val, naive_val in rows:
        print(f"{label:<20s} {rl_val:>12s} {naive_val:>12s}")
    print(sep)
    print()
# Run the train-and-evaluate workflow only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()