#!/usr/bin/env python3
"""Train a REINFORCE agent on LabEnv and compare against the naive baseline."""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

# Make the directory above this script's directory importable so that the
# `lab_env` and `agents` packages resolve when the script is run directly.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from lab_env.env import LabEnv, INITIAL_BUDGET
from agents.naive_agent import NaiveAgent
from agents.rl_agent import ReinforceAgent

# First seed of the fixed evaluation range: both agents are evaluated on
# seeds EVAL_SEED_BASE .. EVAL_SEED_BASE + eval_episodes - 1.
EVAL_SEED_BASE = 999_999


# ------------------------------------------------------------------
# Naive episode runner
# ------------------------------------------------------------------
def run_episode_naive(env: LabEnv, agent: NaiveAgent, seed: int) -> dict:
    """Play one full episode with the naive baseline agent.

    Args:
        env: Environment to play in; it is reset with ``seed`` first.
        agent: Baseline agent; its ``reset()`` is called before the episode.
        seed: Seed passed to ``env.reset``.

    Returns:
        A dict with the same keys ``aggregate`` consumes:
        ``reward`` (sum of step rewards), ``success`` / ``partial`` (whether
        the final ``info["best_result"]`` is ``"success"`` / ``"partial"``),
        ``minutes`` (final ``info["elapsed_minutes"]``), ``cost``
        (``INITIAL_BUDGET`` minus the final ``info["remaining_budget"]``) and
        ``steps`` (number of ``env.step`` calls).
    """
    obs, info = env.reset(seed=seed)
    agent.reset()
    total_reward = 0.0
    steps = 0
    while True:
        action = agent.select_action(obs)
        obs, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
        steps += 1
        if terminated or truncated:
            break
    return {
        "reward": total_reward,
        "success": info["best_result"] == "success",
        "partial": info["best_result"] == "partial",
        "minutes": info["elapsed_minutes"],
        # NOTE(review): assumes every episode starts with INITIAL_BUDGET.
        "cost": INITIAL_BUDGET - info["remaining_budget"],
        "steps": steps,
    }


# ------------------------------------------------------------------
# Aggregation
# ------------------------------------------------------------------
def aggregate(results: list[dict]) -> dict:
    """Average per-episode result dicts into summary statistics.

    Args:
        results: Episode dicts with keys ``reward``, ``success``, ``partial``,
            ``minutes``, ``cost`` and ``steps``.

    Returns:
        ``{"n": len(results), "avg_reward", "success_rate", "partial_rate",
        "avg_minutes", "avg_cost", "avg_steps"}``. The boolean keys average
        to rates because ``True`` sums as 1. For an empty ``results`` every
        average is NaN instead of raising ``ZeroDivisionError``.
    """
    n = len(results)

    def mean(key: str) -> float:
        if n == 0:
            return float("nan")
        return sum(r[key] for r in results) / n

    return {
        "n": n,
        "avg_reward": mean("reward"),
        "success_rate": mean("success"),
        "partial_rate": mean("partial"),
        "avg_minutes": mean("minutes"),
        "avg_cost": mean("cost"),
        "avg_steps": mean("steps"),
    }


# ------------------------------------------------------------------
# Main
# ------------------------------------------------------------------
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse and validate command-line arguments (exits via argparse on error)."""
    parser = argparse.ArgumentParser(description="Train & evaluate REINFORCE agent")
    parser.add_argument("--train-episodes", type=int, default=2000,
                        help="number of training episodes (>= 0)")
    parser.add_argument("--eval-episodes", type=int, default=100,
                        help="number of evaluation episodes per agent (>= 1)")
    parser.add_argument("--log-interval", type=int, default=100,
                        help="training episodes per progress line (>= 1)")
    parser.add_argument("--seed", type=int, default=42,
                        help="training episode i uses env seed SEED + i")
    parser.add_argument("--lr", type=float, default=3e-3,
                        help="passed to ReinforceAgent(lr=...)")
    parser.add_argument("--gamma", type=float, default=0.99,
                        help="passed to ReinforceAgent(gamma=...)")
    parser.add_argument("--max-trials", type=int, default=4,
                        help="passed to ReinforceAgent(max_trials=...)")
    args = parser.parse_args(argv)

    # Reject values that would otherwise crash later (ZeroDivisionError in
    # `ep % log_interval` or in `aggregate`).
    if args.train_episodes < 0:
        parser.error("--train-episodes must be >= 0")
    if args.eval_episodes < 1:
        parser.error("--eval-episodes must be >= 1")
    if args.log_interval < 1:
        parser.error("--log-interval must be >= 1")
    return args


def _train(
    env: LabEnv,
    agent: ReinforceAgent,
    episodes: int,
    seed: int,
    log_interval: int,
) -> None:
    """Run ``episodes`` training episodes, logging windowed averages."""
    print("=" * 60)
    print(" Training REINFORCE agent")
    print("=" * 60)

    rewards: list[float] = []
    successes: list[bool] = []

    def flush(ep: int) -> None:
        # Print stats for the episodes collected since the last flush.
        avg = sum(rewards) / len(rewards)
        sr = sum(successes) / len(successes)
        print(
            f" Episode {ep:5d} | avg reward (last {len(rewards)}): "
            f"{avg:7.1f} | success rate: {sr:.0%}"
        )
        rewards.clear()
        successes.clear()

    for ep in range(1, episodes + 1):
        result = agent.run_episode(env, seed=seed + ep, train=True)
        rewards.append(result["reward"])
        successes.append(result["success"])
        if ep % log_interval == 0:
            flush(ep)

    # Report the trailing partial window when `episodes` is not a multiple
    # of `log_interval`; previously these episodes were never logged.
    if rewards:
        flush(episodes)


def _evaluate(
    env: LabEnv, rl_agent: ReinforceAgent, episodes: int
) -> tuple[list[dict], list[dict]]:
    """Evaluate the RL agent and a fresh naive baseline on the same seeds."""
    print()
    print("=" * 60)
    print(" Evaluating on fixed seed range")
    print("=" * 60)

    rl_results = [
        rl_agent.run_episode(env, seed=EVAL_SEED_BASE + i, train=False)
        for i in range(episodes)
    ]
    naive_agent = NaiveAgent(num_trials=3, seed=0)
    naive_results = [
        run_episode_naive(env, naive_agent, seed=EVAL_SEED_BASE + i)
        for i in range(episodes)
    ]
    return rl_results, naive_results


def _print_comparison(rl_stats: dict, naive_stats: dict) -> None:
    """Print a side-by-side table of aggregated REINFORCE vs naive stats."""
    header = f"{'Metric':<20s} {'REINFORCE':>12s} {'Naive':>12s}"
    sep = "-" * len(header)
    rows = [
        ("Avg reward", f"{rl_stats['avg_reward']:.1f}", f"{naive_stats['avg_reward']:.1f}"),
        ("Success rate", f"{rl_stats['success_rate']:.1%}", f"{naive_stats['success_rate']:.1%}"),
        ("Partial rate", f"{rl_stats['partial_rate']:.1%}", f"{naive_stats['partial_rate']:.1%}"),
        ("Avg time", f"{rl_stats['avg_minutes']:.1f}m", f"{naive_stats['avg_minutes']:.1f}m"),
        ("Avg cost", f"${rl_stats['avg_cost']:.1f}", f"${naive_stats['avg_cost']:.1f}"),
        ("Avg steps", f"{rl_stats['avg_steps']:.1f}", f"{naive_stats['avg_steps']:.1f}"),
    ]
    print()
    print(header)
    print(sep)
    for label, rl_val, naive_val in rows:
        print(f"{label:<20s} {rl_val:>12s} {naive_val:>12s}")
    print(sep)
    print()


def main(argv: list[str] | None = None) -> None:
    """Train a REINFORCE agent, then compare it with the naive baseline.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``).
    """
    args = _parse_args(argv)

    # Warn when training seeds collide with the fixed evaluation seeds, since
    # evaluating on training seeds would inflate the RL agent's results.
    train_lo, train_hi = args.seed + 1, args.seed + args.train_episodes
    eval_lo, eval_hi = EVAL_SEED_BASE, EVAL_SEED_BASE + args.eval_episodes - 1
    if args.train_episodes > 0 and train_lo <= eval_hi and eval_lo <= train_hi:
        print(
            "warning: training seeds overlap the evaluation seed range",
            file=sys.stderr,
        )

    env = LabEnv()
    try:
        rl_agent = ReinforceAgent(
            lr=args.lr, gamma=args.gamma, max_trials=args.max_trials
        )
        _train(env, rl_agent, args.train_episodes, args.seed, args.log_interval)
        rl_results, naive_results = _evaluate(env, rl_agent, args.eval_episodes)
    finally:
        # Release the environment even if training or evaluation raises.
        env.close()

    _print_comparison(aggregate(rl_results), aggregate(naive_results))


if __name__ == "__main__":
    main()