| |
| """Train a REINFORCE agent on LabEnv and compare against the naive baseline.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| from lab_env.env import LabEnv, INITIAL_BUDGET |
| from agents.naive_agent import NaiveAgent |
| from agents.rl_agent import ReinforceAgent |
|
|
|
|
| |
| |
| |
|
|
def run_episode_naive(env: LabEnv, agent: NaiveAgent, seed: int) -> dict:
    """Roll out one full episode of *agent* in *env* and summarize the outcome.

    Resets both the environment (with the given ``seed``) and the agent,
    then steps until the environment reports termination or truncation.

    Returns a summary dict with keys: ``reward`` (total return), ``success``
    and ``partial`` (flags derived from the final info's ``best_result``),
    ``minutes`` (elapsed time), ``cost`` (budget spent), and ``steps``.
    """
    observation, info = env.reset(seed=seed)
    agent.reset()

    cumulative_reward = 0.0
    step_count = 0
    done = False

    while not done:
        chosen = agent.select_action(observation)
        observation, reward, terminated, truncated, info = env.step(chosen)
        cumulative_reward += reward
        step_count += 1
        done = terminated or truncated

    outcome = info["best_result"]
    return {
        "reward": cumulative_reward,
        "success": outcome == "success",
        "partial": outcome == "partial",
        "minutes": info["elapsed_minutes"],
        # Budget spent = starting budget minus whatever is left at the end.
        "cost": INITIAL_BUDGET - info["remaining_budget"],
        "steps": step_count,
    }
|
|
|
|
| |
| |
| |
|
|
def aggregate(results: list[dict]) -> dict:
    """Average per-episode result dicts into a single summary dict.

    Each entry of *results* is expected to carry the keys produced by an
    episode runner: ``reward``, ``success``, ``partial``, ``minutes``,
    ``cost``, and ``steps``.  Boolean fields average to rates in [0, 1].

    Args:
        results: Non-empty list of per-episode result dicts.

    Returns:
        Dict with ``n`` (episode count) plus the mean of each metric.

    Raises:
        ValueError: If *results* is empty (previously this surfaced as an
            opaque ``ZeroDivisionError``).
    """
    if not results:
        raise ValueError("aggregate() requires at least one episode result")

    n = len(results)

    def _mean(key: str) -> float:
        # sum() over bools counts True values, so this also yields rates.
        return sum(r[key] for r in results) / n

    return {
        "n": n,
        "avg_reward": _mean("reward"),
        "success_rate": _mean("success"),
        "partial_rate": _mean("partial"),
        "avg_minutes": _mean("minutes"),
        "avg_cost": _mean("cost"),
        "avg_steps": _mean("steps"),
    }
|
|
|
|
| |
| |
| |
|
|
def main() -> None:
    """Train a REINFORCE agent on LabEnv, then benchmark it against NaiveAgent.

    Training progress is logged every ``--log-interval`` episodes; afterwards
    both agents are evaluated on the same fixed seed range and a comparison
    table is printed.
    """
    parser = argparse.ArgumentParser(description="Train & evaluate REINFORCE agent")
    parser.add_argument("--train-episodes", type=int, default=2000)
    parser.add_argument("--eval-episodes", type=int, default=100)
    parser.add_argument("--log-interval", type=int, default=100)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--lr", type=float, default=3e-3)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--max-trials", type=int, default=4)
    args = parser.parse_args()

    environment = LabEnv()
    learner = ReinforceAgent(lr=args.lr, gamma=args.gamma, max_trials=args.max_trials)

    banner = "=" * 60
    print(banner)
    print(" Training REINFORCE agent")
    print(banner)

    # Rolling buffers, flushed at every log interval.
    reward_buffer: list[float] = []
    success_buffer: list[bool] = []
    for episode in range(1, args.train_episodes + 1):
        # Offset the seed per episode so training sees varied scenarios.
        outcome = learner.run_episode(environment, seed=args.seed + episode, train=True)
        reward_buffer.append(outcome["reward"])
        success_buffer.append(outcome["success"])

        if episode % args.log_interval == 0:
            mean_reward = sum(reward_buffer) / len(reward_buffer)
            rate = sum(success_buffer) / len(success_buffer)
            print(
                f" Episode {episode:5d} | avg reward (last {args.log_interval}): "
                f"{mean_reward:7.1f} | success rate: {rate:.0%}"
            )
            reward_buffer.clear()
            success_buffer.clear()

    print()
    print(banner)
    print(" Evaluating on fixed seed range")
    print(banner)

    # Seeds disjoint from the training range so evaluation is held out.
    base_seed = 999_999

    rl_results = []
    for offset in range(args.eval_episodes):
        rl_results.append(
            learner.run_episode(environment, seed=base_seed + offset, train=False)
        )

    baseline = NaiveAgent(num_trials=3, seed=0)
    naive_results = []
    for offset in range(args.eval_episodes):
        naive_results.append(
            run_episode_naive(environment, baseline, seed=base_seed + offset)
        )

    environment.close()

    rl_stats = aggregate(rl_results)
    naive_stats = aggregate(naive_results)

    header = f"{'Metric':<20s} {'REINFORCE':>12s} {'Naive':>12s}"
    rule = "-" * len(header)
    rows = [
        ("Avg reward", f"{rl_stats['avg_reward']:.1f}", f"{naive_stats['avg_reward']:.1f}"),
        ("Success rate", f"{rl_stats['success_rate']:.1%}", f"{naive_stats['success_rate']:.1%}"),
        ("Partial rate", f"{rl_stats['partial_rate']:.1%}", f"{naive_stats['partial_rate']:.1%}"),
        ("Avg time", f"{rl_stats['avg_minutes']:.1f}m", f"{naive_stats['avg_minutes']:.1f}m"),
        ("Avg cost", f"${rl_stats['avg_cost']:.1f}", f"${naive_stats['avg_cost']:.1f}"),
        ("Avg steps", f"{rl_stats['avg_steps']:.1f}", f"{naive_stats['avg_steps']:.1f}"),
    ]

    print()
    print(header)
    print(rule)
    for metric_name, rl_cell, naive_cell in rows:
        print(f"{metric_name:<20s} {rl_cell:>12s} {naive_cell:>12s}")
    print(rule)
    print()
|
|
|
|
# Script entry point: train and evaluate only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|