File size: 4,965 Bytes
da63ca8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 | #!/usr/bin/env python3
"""Train a REINFORCE agent on LabEnv and compare against the naive baseline."""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from lab_env.env import LabEnv, INITIAL_BUDGET
from agents.naive_agent import NaiveAgent
from agents.rl_agent import ReinforceAgent
# ------------------------------------------------------------------
# Naive episode runner
# ------------------------------------------------------------------
def run_episode_naive(env: LabEnv, agent: NaiveAgent, seed: int) -> dict:
    """Play one seeded episode with the naive agent and summarise it.

    The environment is reset with *seed* and the agent is reset. Actions
    are then taken until the environment reports ``terminated`` or
    ``truncated``.

    Args:
        env: Environment to run in; reset at the start of the episode.
        agent: Naive agent queried for one action per step.
        seed: Seed passed to ``env.reset``.

    Returns:
        A dict with:

        - ``reward``: summed step reward.
        - ``success`` / ``partial``: whether the final ``info["best_result"]``
          equals ``"success"`` / ``"partial"``.
        - ``minutes``: final ``info["elapsed_minutes"]``.
        - ``cost``: ``INITIAL_BUDGET`` minus the final ``info["remaining_budget"]``.
        - ``steps``: number of ``env.step`` calls made.
    """
    observation, info = env.reset(seed=seed)
    agent.reset()

    episode_return = 0.0
    step_count = 0
    done = False
    while not done:
        observation, reward, terminated, truncated, info = env.step(
            agent.select_action(observation)
        )
        episode_return += reward
        step_count += 1
        done = terminated or truncated

    # Only the info dict from the final step is used for the summary.
    best = info["best_result"]
    elapsed = info["elapsed_minutes"]
    spent = INITIAL_BUDGET - info["remaining_budget"]
    return {
        "reward": episode_return,
        "success": best == "success",
        "partial": best == "partial",
        "minutes": elapsed,
        "cost": spent,
        "steps": step_count,
    }
# ------------------------------------------------------------------
# Aggregation
# ------------------------------------------------------------------
def aggregate(results: list[dict]) -> dict:
    """Average per-episode metrics over a batch of episode results.

    Each entry of *results* must carry the keys that ``run_episode_naive``
    produces: ``reward``, ``success``, ``partial``, ``minutes``, ``cost``
    and ``steps``. Boolean flags are summed as 0/1, so ``success_rate`` and
    ``partial_rate`` are fractions of the episodes.

    Args:
        results: Per-episode result dicts.

    Returns:
        A dict with the episode count ``n`` plus ``avg_reward``,
        ``success_rate``, ``partial_rate``, ``avg_minutes``, ``avg_cost``
        and ``avg_steps``. When *results* is empty, ``n`` is 0 and every
        average is NaN instead of raising ``ZeroDivisionError``.
    """
    # (output key, per-episode key it averages), in report order.
    fields = (
        ("avg_reward", "reward"),
        ("success_rate", "success"),
        ("partial_rate", "partial"),
        ("avg_minutes", "minutes"),
        ("avg_cost", "cost"),
        ("avg_steps", "steps"),
    )
    n = len(results)
    stats: dict = {"n": n}
    for out_key, in_key in fields:
        # An average over zero episodes is undefined; report NaN, not a crash.
        stats[out_key] = sum(r[in_key] for r in results) / n if n else float("nan")
    return stats
# ------------------------------------------------------------------
# Main
# ------------------------------------------------------------------
def _parse_args() -> argparse.Namespace:
    """Parse and validate the command-line options for this script."""
    parser = argparse.ArgumentParser(description="Train & evaluate REINFORCE agent")
    parser.add_argument("--train-episodes", type=int, default=2000)
    parser.add_argument("--eval-episodes", type=int, default=100)
    parser.add_argument("--log-interval", type=int, default=100)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--lr", type=float, default=3e-3)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--max-trials", type=int, default=4)
    args = parser.parse_args()
    # Fail early with a clear usage message instead of a later opaque crash.
    # log_interval is used as a modulus, and a comparison over zero
    # evaluation episodes has no meaningful averages.
    if args.train_episodes < 0:
        parser.error("--train-episodes must be >= 0")
    if args.eval_episodes < 1:
        parser.error("--eval-episodes must be >= 1")
    if args.log_interval < 1:
        parser.error("--log-interval must be >= 1")
    return args


def _print_banner(title: str) -> None:
    """Print *title* between two 60-character '=' rules."""
    print("=" * 60)
    print(f" {title}")
    print("=" * 60)


def _train(env: LabEnv, agent: ReinforceAgent, args: argparse.Namespace) -> None:
    """Train *agent* for ``args.train_episodes`` episodes, logging windowed stats.

    Episode ``ep`` (1-based) is seeded with ``args.seed + ep``. Every
    ``args.log_interval`` episodes, the mean reward and success rate of the
    episodes since the previous report are printed.
    """
    window: list[float] = []
    successes_window: list[bool] = []

    def report(ep: int) -> None:
        # Callers only invoke this with a non-empty window.
        avg = sum(window) / len(window)
        sr = sum(successes_window) / len(successes_window)
        print(
            f" Episode {ep:5d} | avg reward (last {len(window)}): "
            f"{avg:7.1f} | success rate: {sr:.0%}"
        )
        window.clear()
        successes_window.clear()

    for ep in range(1, args.train_episodes + 1):
        result = agent.run_episode(env, seed=args.seed + ep, train=True)
        window.append(result["reward"])
        successes_window.append(result["success"])
        if ep % args.log_interval == 0:
            report(ep)
    # Flush a trailing partial window so the final episodes are not silently
    # dropped when train_episodes is not a multiple of log_interval.
    if window:
        report(args.train_episodes)


def _evaluate(
    env: LabEnv, rl_agent: ReinforceAgent, eval_seeds: range
) -> tuple[list[dict], list[dict]]:
    """Run the trained agent and a fresh naive baseline on the same seeds.

    Returns:
        ``(rl_results, naive_results)``, one result dict per seed for each agent.
    """
    rl_results = [rl_agent.run_episode(env, seed=s, train=False) for s in eval_seeds]
    naive_agent = NaiveAgent(num_trials=3, seed=0)
    naive_results = [run_episode_naive(env, naive_agent, seed=s) for s in eval_seeds]
    return rl_results, naive_results


def _print_comparison(rl_stats: dict, naive_stats: dict) -> None:
    """Print a side-by-side table of aggregated REINFORCE vs naive metrics."""
    header = f"{'Metric':<20s} {'REINFORCE':>12s} {'Naive':>12s}"
    sep = "-" * len(header)
    rows = [
        ("Avg reward", f"{rl_stats['avg_reward']:.1f}", f"{naive_stats['avg_reward']:.1f}"),
        ("Success rate", f"{rl_stats['success_rate']:.1%}", f"{naive_stats['success_rate']:.1%}"),
        ("Partial rate", f"{rl_stats['partial_rate']:.1%}", f"{naive_stats['partial_rate']:.1%}"),
        ("Avg time", f"{rl_stats['avg_minutes']:.1f}m", f"{naive_stats['avg_minutes']:.1f}m"),
        ("Avg cost", f"${rl_stats['avg_cost']:.1f}", f"${naive_stats['avg_cost']:.1f}"),
        ("Avg steps", f"{rl_stats['avg_steps']:.1f}", f"{naive_stats['avg_steps']:.1f}"),
    ]
    print()
    print(header)
    print(sep)
    for label, rl_val, naive_val in rows:
        print(f"{label:<20s} {rl_val:>12s} {naive_val:>12s}")
    print(sep)
    print()


def main() -> None:
    """Train a REINFORCE agent on LabEnv, then compare it with the naive baseline.

    Both agents are evaluated on the same fixed seed range, starting at
    999_999. The environment is always closed, even if training or
    evaluation raises.
    """
    args = _parse_args()

    eval_seed_base = 999_999
    eval_seeds = range(eval_seed_base, eval_seed_base + args.eval_episodes)
    train_seeds = range(args.seed + 1, args.seed + args.train_episodes + 1)
    # Training on evaluation seeds would make the comparison not held-out.
    if max(train_seeds.start, eval_seeds.start) < min(train_seeds.stop, eval_seeds.stop):
        print(
            f"warning: training seeds [{train_seeds.start}, {train_seeds.stop}) overlap "
            f"evaluation seeds [{eval_seeds.start}, {eval_seeds.stop}); "
            "evaluation is not held-out",
            file=sys.stderr,
        )

    env = LabEnv()
    try:
        rl_agent = ReinforceAgent(lr=args.lr, gamma=args.gamma, max_trials=args.max_trials)

        # ---- Training ----
        _print_banner("Training REINFORCE agent")
        _train(env, rl_agent, args)

        # ---- Evaluation ----
        print()
        _print_banner("Evaluating on fixed seed range")
        rl_results, naive_results = _evaluate(env, rl_agent, eval_seeds)
    finally:
        env.close()

    _print_comparison(aggregate(rl_results), aggregate(naive_results))


if __name__ == "__main__":
    main()
|