Car-Racing-Agent / training /train_torchrl.py
nirmalpratheep's picture
Upload 2 files
27cec6c verified
"""
train_torchrl.py β€” TorchRL PPO training for the curriculum car racer.
Replaces train_sb3.py with a torchrl-based PPO. All hyperparameters are
transferred exactly from train_sb3.py (SB3 PPO defaults + our overrides):
learning_rate 3e-4 (Adam, eps=1e-5) β€” SB3 default
rollout frames 2048 / update (n_steps Γ— n_envs in SB3)
batch size 64 (minibatch size for PPO updates)
n_epochs 10 (passes over rollout data)
gamma 0.99
gae_lambda 0.95
clip_epsilon 0.2
vf_coef 0.5
ent_coef 0.0 (SB3 default)
max_grad_norm 0.5
normalize_adv True
target_kl None (no early stop)
log_std_init -1.0 (initial std β‰ˆ 0.37; SB3 DiagGaussian β€” no clamp)
actor mean bias [0]=0.3 (gentle forward accel)
ortho init gain=0.01 (actor) / 1.0 (critic)
features extractor RaceEncoder (ImpalaCNN + MLP β†’ 288 dims)
net_arch empty (direct linear heads, no extra MLP)
share features True (encoder params shared across heads)
W&B metrics use identical keys to train_sb3.py.
Usage
─────
uv run python train_torchrl.py
uv run python train_torchrl.py --num-envs 8 --total-steps 10_000_000
uv run python train_torchrl.py --resume checkpoints/ppo_torchrl_step500000.pt
"""
import argparse
import math
import os
import random
import re
import statistics
import sys
import time
from collections import deque
# Ensure project root (parent of training/) is on sys.path so env/ and game/ are importable
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import numpy as np
import torch
import torch.nn as nn
import wandb
# Headless pygame β€” must come before any game/env import
os.environ.setdefault("SDL_VIDEODRIVER", "dummy")
os.environ.setdefault("SDL_AUDIODRIVER", "dummy")
# UTF-8 stdout so box-drawing glyphs inside tensordict/torchrl banners don't
# explode on Windows cp1252 when wandb wraps stdout.
try:
sys.stdout.reconfigure(encoding="utf-8", errors="replace")
sys.stderr.reconfigure(encoding="utf-8", errors="replace")
except Exception:
pass
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.collectors import Collector
from torchrl.data import LazyTensorStorage, ReplayBuffer, SamplerWithoutReplacement
import multiprocessing as mp
from torchrl.envs import Compose, GymWrapper, ParallelEnv, StepCounter, TransformedEnv
from torchrl.envs.gym_like import BaseInfoDictReader
from torchrl.envs.transforms import RewardSum
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.data.tensor_specs import Composite, Unbounded
from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.modules.distributions import IndependentNormal
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from env import CurriculumBuilder, DriveAction
from env.encoder import RaceEncoder
from env.gym_env import RaceGymEnv
from game.rl_splits import TRAIN, difficulty_of
# ─────────────────────────────────────────────────────────────────────────────
# Args (same flags/defaults as train_sb3.py)
# ─────────────────────────────────────────────────────────────────────────────
def parse_args():
p = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
g = p.add_argument_group("W&B")
g.add_argument("--wandb-project", default="curriculum-car-racer")
g.add_argument("--wandb-run-name", default=None)
g.add_argument("--wandb-id", default=None)
g.add_argument("--wandb-offline", action="store_true")
g = p.add_argument_group("Training budget")
g.add_argument("--total-steps", type=int, default=5_000_000)
g.add_argument("--rollout-steps", type=int, default=2048,
help="Total frames per PPO update (across all envs)")
g.add_argument("--num-envs", type=int, default=4)
g.add_argument("--batch-size", type=int, default=64)
g.add_argument("--ppo-epochs", type=int, default=10)
g = p.add_argument_group("PPO (SB3 defaults)")
g.add_argument("--lr", type=float, default=3e-4)
g.add_argument("--gamma", type=float, default=0.99)
g.add_argument("--gae-lambda", type=float, default=0.95)
g.add_argument("--clip-eps", type=float, default=0.2)
g.add_argument("--vf-coef", type=float, default=0.5)
g.add_argument("--ent-coef", type=float, default=0.01)
g.add_argument("--max-grad-norm", type=float, default=0.5)
g.add_argument("--target-kl", type=float, default=0.1,
help="Stop PPO epochs early when approx_kl exceeds this")
g = p.add_argument_group("Curriculum")
g.add_argument("--threshold", type=float, default=30.0)
g.add_argument("--window", type=int, default=20)
g.add_argument("--replay-frac", type=float, default=0.3)
g.add_argument("--eval-episodes", type=int, default=1,
help="Greedy eval episodes run every --eval-interval-steps for curriculum gating")
g.add_argument("--eval-interval-steps", type=int, default=25_000,
help="Run greedy curriculum eval every N global steps")
g = p.add_argument_group("Checkpointing")
g.add_argument("--checkpoint-interval", type=int, default=500_000)
g.add_argument("--checkpoint-dir", default="checkpoints")
g.add_argument("--keep-checkpoints", type=int, default=5)
g.add_argument("--resume", default=None)
g = p.add_argument_group("Misc")
g.add_argument("--seed", type=int, default=42)
g.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
g.add_argument("--compile", action="store_true")
g.add_argument("--video-interval", type=int, default=25_000)
g.add_argument("--video-dir", default="inference_videos")
return p.parse_args()
# ─────────────────────────────────────────────────────────────────────────────
# Shared encoder + actor/critic heads
#
# RaceEncoder (ImpalaCNN + scalar MLP β†’ 288) is shared between the actor and
# critic, matching SB3's share_features_extractor=True. The actor and critic
# each run the encoder once per forward (torchrl's PPO evaluates them
# separately); parameter sharing gives identical gradients to SB3.
# ─────────────────────────────────────────────────────────────────────────────
def _flatten_batch_dims(image: torch.Tensor, scalars: torch.Tensor):
"""
Collapse all leading batch dimensions into one so Conv2d gets a 4D tensor.
Returns (img_flat, sca_flat, lead_shape) where lead_shape is used to
restore the original batch structure on outputs.
RaceEncoder expects image (B, 3, 64, 64) and scalars (B, 9). During PPO
loss/GAE, torchrl hands us (N, T, 3, 64, 64) / (N, T, 9); during rollout
collection it's (N, 3, 64, 64) / (N, 9). Flatten uniformly.
"""
lead_shape = image.shape[:-3] # everything except C,H,W
img_flat = image.reshape(-1, *image.shape[-3:])
sca_flat = scalars.reshape(-1, scalars.shape[-1])
return img_flat, sca_flat, lead_shape
class _ActorNet(nn.Module):
"""Image + scalars β†’ (loc, scale) for IndependentNormal."""
def __init__(self, encoder: RaceEncoder, log_std_init: float = -1.0):
super().__init__()
self.encoder = encoder
self.mean = nn.Linear(encoder.out_features, 2)
# log_std is a free parameter (not state-conditioned), matching SB3's
# DiagGaussianDistribution. Unbounded β€” SB3 does not clamp log_std.
self.log_std = nn.Parameter(torch.full((2,), float(log_std_init)))
nn.init.orthogonal_(self.mean.weight, gain=0.01)
nn.init.zeros_(self.mean.bias)
# Gentle forward accel so the car stays moving while exploring.
with torch.no_grad():
self.mean.bias[0] = 0.3
def forward(self, image: torch.Tensor, scalars: torch.Tensor):
img_f, sca_f, lead = _flatten_batch_dims(image, scalars)
feat = self.encoder(img_f, sca_f)
loc = self.mean(feat).reshape(*lead, 2)
scale = self.log_std.exp().expand_as(loc)
return loc, scale
class _CriticNet(nn.Module):
"""Image + scalars β†’ value (1,)."""
def __init__(self, encoder: RaceEncoder):
super().__init__()
self.encoder = encoder
self.value = nn.Linear(encoder.out_features, 1)
nn.init.orthogonal_(self.value.weight, gain=1.0)
nn.init.zeros_(self.value.bias)
def forward(self, image: torch.Tensor, scalars: torch.Tensor):
img_f, sca_f, lead = _flatten_batch_dims(image, scalars)
v = self.value(self.encoder(img_f, sca_f))
return v.reshape(*lead, 1)
def _sb3_ortho_init(module: nn.Module, gain: float) -> None:
"""Mirror SB3's ActorCriticPolicy.init_weights(gain) via module.apply:
orthogonal-init every Conv2d/Linear weight with the given gain and zero
biases. SB3 applies gain=sqrt(2) to the features extractor."""
for m in module.modules():
if isinstance(m, (nn.Linear, nn.Conv2d)):
nn.init.orthogonal_(m.weight, gain=gain)
if m.bias is not None:
nn.init.zeros_(m.bias)
def build_policy_and_value(device: torch.device):
"""Build actor + critic modules with a shared RaceEncoder (shared params)."""
encoder = RaceEncoder()
_sb3_ortho_init(encoder, gain=math.sqrt(2)) # SB3 ortho_init=True on features extractor
actor_net = _ActorNet(encoder) # head overrides encoder-wide init on its own layers
critic_net = _CriticNet(encoder)
actor_tdm = TensorDictModule(
actor_net,
in_keys=["image", "scalars"],
out_keys=["loc", "scale"],
)
policy_module = ProbabilisticActor(
module = actor_tdm,
in_keys = ["loc", "scale"],
out_keys = ["action"],
distribution_class = IndependentNormal,
return_log_prob = True,
).to(device)
value_module = ValueOperator(
module = critic_net,
in_keys = ["image", "scalars"],
out_keys = ["state_value"],
).to(device)
return policy_module, value_module, encoder
# ─────────────────────────────────────────────────────────────────────────────
# Environment factory
# ─────────────────────────────────────────────────────────────────────────────
class _EpisodeStatsReader(BaseInfoDictReader):
"""
BaseInfoDictReader subclass so set_info_dict_reader() registers the keys in
GymWrapper.observation_spec β€” required for ParallelEnv to allocate shared
memory and transfer these values from subprocess to main process.
"""
info_spec = Composite(
episode_laps = Unbounded((), dtype=torch.float32),
episode_crashes = Unbounded((), dtype=torch.float32),
on_track_pct = Unbounded((), dtype=torch.float32),
track_level = Unbounded((), dtype=torch.float32),
)
def __call__(self, info, td):
td["episode_laps"] = torch.tensor(info.get("episode_laps", 0), dtype=torch.float32)
td["episode_crashes"] = torch.tensor(info.get("episode_crashes", 0), dtype=torch.float32)
td["on_track_pct"] = torch.tensor(info.get("on_track_pct", 0.0), dtype=torch.float32)
td["track_level"] = torch.tensor(info.get("track_level", 0), dtype=torch.float32)
def reset(self, tensordict_reset=None):
pass
def make_vec_env(num_envs: int, max_steps: int, laps_target: int,
replay_frac: float, device, shared_level: mp.Value,
shared_priority=None, shared_n_priority=None):
"""
ParallelEnv of GymWrapper(RaceGymEnv) β€” each env runs in its own subprocess
for parallel CPU stepping. Frontier level is shared via a multiprocessing.Value
so curriculum advances in the main process propagate instantly to all workers.
"""
def _factory():
gym_env = RaceGymEnv(
sampler = None,
frontier_level = 0,
replay_frac = replay_frac,
max_steps = max_steps,
laps_target = laps_target,
shared_level = shared_level,
shared_priority = shared_priority,
shared_n_priority = shared_n_priority,
)
wrapped = GymWrapper(gym_env, device="cpu")
wrapped.set_info_dict_reader(_EpisodeStatsReader())
return wrapped
base = ParallelEnv(num_envs, _factory, mp_start_method="fork")
return TransformedEnv(base, Compose(StepCounter(), RewardSum()))
# ─────────────────────────────────────────────────────────────────────────────
# Inference video (frontier track only β€” same as train_sb3.py)
# ─────────────────────────────────────────────────────────────────────────────
def _game_frame(race_env) -> np.ndarray:
import pygame
from game.oval_racer import draw_car, draw_headlights
ce = race_env._env
surf = ce.track.surface.copy()
draw_headlights(surf, ce._x, ce._y, ce._angle)
draw_car(surf, ce._x, ce._y, ce._angle)
small = pygame.transform.scale(surf, (450, 300))
return pygame.surfarray.array3d(small).transpose(1, 0, 2).copy()
@torch.no_grad()
def log_inference_videos(
policy_module,
builder: CurriculumBuilder,
device: torch.device,
global_step: int,
video_dir: str = "inference_videos",
frame_skip: int = 2,
) -> None:
import imageio.v3 as iio
from env.environment import RaceEnvironment
from game.rl_splits import _ensure_pygame
_ensure_pygame()
os.makedirs(video_dir, exist_ok=True)
policy_module.eval()
frontier_track = TRAIN[builder.current_level]
track = frontier_track
track.build()
env = RaceEnvironment(track, max_steps=3000, laps_target=1, use_image=True)
raw_obs = env.reset()
frames = [_game_frame(env)]
step = 0
while not raw_obs.done:
img = (torch.from_numpy(raw_obs.image.copy())
.float().div(255.0).permute(2, 0, 1).unsqueeze(0).to(device))
scalars = torch.tensor(raw_obs.scalars, dtype=torch.float32,
device=device).unsqueeze(0)
td = TensorDict({"image": img, "scalars": scalars}, batch_size=[1])
with set_exploration_type(ExplorationType.MEAN):
td = policy_module(td)
action = td.get("action")[0].clamp(-1.0, 1.0).cpu().numpy()
raw_obs = env.step(DriveAction(
accel=float(action[0]), steer=float(action[1])
))
step += 1
if step % frame_skip == 0:
frames.append(_game_frame(env))
video = np.stack(frames, axis=0)
track_slug = track.name.replace(" ", "_")
filename = f"step{global_step:08d}_track{track.level:02d}_{track_slug}.mp4"
iio.imwrite(os.path.join(video_dir, filename),
video, fps=20, codec="libx264", plugin="pyav")
policy_module.train()
print(f" [VIDEO] Saved to {os.path.join(video_dir, filename)}")
# ─────────────────────────────────────────────────────────────────────────────
# Episode iteration over a collected rollout
# ─────────────────────────────────────────────────────────────────────────────
def _iter_episodes(td):
"""
Yield (env_idx, step_idx, episode_reward, episode_length, info_dict) for
every terminal step in the collected rollout.
td has shape (N, T) with the standard torchrl layout:
td["next","done"] β†’ (N, T, 1)
td["next","episode_reward"]β†’ (N, T, 1) (from RewardSum)
td["next","step_count"] β†’ (N, T, 1) (from StepCounter)
td["next","episode_laps"] β†’ (N, T) (from RaceGymEnv info, if present)
RaceGymEnv writes episode_laps/episode_crashes/on_track_pct only on done.
"""
next_td = td.get("next")
dones = next_td.get("done").squeeze(-1) # (N, T) bool
ep_r = next_td.get("episode_reward").squeeze(-1) # (N, T) float
ep_l = next_td.get("step_count").squeeze(-1) # (N, T) int
# Info fields from RaceGymEnv (only populated on done steps). If they were
# never observed, these keys will not exist β€” tolerate that.
def _get(key, default):
if key in next_td.keys():
v = next_td.get(key)
return v.squeeze(-1) if v.dim() > dones.dim() else v
return torch.full_like(ep_r, float(default))
ep_crashes = _get("episode_crashes", 0).to(torch.float32)
ep_laps_info = _get("episode_laps", 0).to(torch.float32)
on_track = _get("on_track_pct", 0.0).to(torch.float32)
track_level = _get("track_level", 0).to(torch.float32)
track_name = next_td.get("track_name", None) # may be bytes/str tensor
N, T = dones.shape
for n in range(N):
for t in range(T):
if not bool(dones[n, t]):
continue
yield {
"env_idx": n,
"step_idx": t,
"ep_reward": float(ep_r[n, t]),
"ep_length": int(ep_l[n, t]),
"ep_crashes": int(ep_crashes[n, t]),
"ep_laps": int(ep_laps_info[n, t]),
"on_track_pct": float(on_track[n, t]),
"track_level": int(track_level[n, t]),
}
# ─────────────────────────────────────────────────────────────────────────────
# Checkpoints
# ─────────────────────────────────────────────────────────────────────────────
def save_checkpoint(path, policy_module, value_module, optimizer,
global_step, builder, args, reward_window, frontier_reward_window,
episode_num, wandb_run_id):
torch.save({
"step": global_step,
"curriculum_level": builder.current_level,
"policy": policy_module.state_dict(),
"value": value_module.state_dict(),
"optimizer": optimizer.state_dict(),
"args": vars(args),
"reward_window": list(reward_window),
"frontier_reward_window": list(frontier_reward_window),
"episode_num": episode_num,
"sampler_idx": builder._sampler._idx,
"sampler_rewards": list(builder._sampler._rewards),
"sampler_crashes": list(builder._sampler._crashes),
"sampler_laps": list(builder._sampler._laps),
"sampler_is_frontier": list(builder._sampler._is_frontier),
"sampler_frontier_crashes": list(builder._sampler._frontier_crashes),
"sampler_frontier_laps": list(builder._sampler._frontier_laps),
"wandb_run_id": wandb_run_id,
}, path)
def prune_checkpoints(checkpoint_dir: str, keep: int):
if keep <= 0:
return
import glob as _glob
pts = sorted(_glob.glob(os.path.join(checkpoint_dir, "ppo_torchrl_step*.pt")))
for old in pts[:-keep]:
os.remove(old)
print(f" [PRUNE] Removed {os.path.basename(old)}")
# ─────────────────────────────────────────────────────────────────────────────
# Greedy curriculum evaluation
# ─────────────────────────────────────────────────────────────────────────────
@torch.no_grad()
def _greedy_eval(policy_module, track, device, n_episodes, max_steps=3000):
"""
Run n_episodes greedy (MEAN action) episodes on track.
Returns list of dicts: {laps, crashes}.
Uses RaceEnvironment directly β€” no vectorisation overhead.
"""
from env.environment import RaceEnvironment
policy_module.eval()
results = []
track.build()
for _ in range(n_episodes):
env = RaceEnvironment(track, max_steps=max_steps, laps_target=1, use_image=True)
raw_obs = env.reset()
while not raw_obs.done:
img = (torch.from_numpy(raw_obs.image.copy())
.float().div(255.0).permute(2, 0, 1).unsqueeze(0).to(device))
scalars = torch.tensor(raw_obs.scalars, dtype=torch.float32,
device=device).unsqueeze(0)
td = TensorDict({"image": img, "scalars": scalars}, batch_size=[1])
with set_exploration_type(ExplorationType.MEAN):
td = policy_module(td)
action = td["action"][0].clamp(-1.0, 1.0).cpu().numpy()
raw_obs = env.step(DriveAction(accel=float(action[0]), steer=float(action[1])))
ce = env._env
results.append({"laps": ce._laps, "crashes": ce._crash_count})
policy_module.train()
return results
# ─────────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────────
def main():
args = parse_args()
device = torch.device(args.device)
# ── Seed ──────────────────────────────────────────────────────────────────
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if device.type == "cuda":
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.set_float32_matmul_precision("high") # TF32 on A10 tensor cores
# ── Auto-detect latest checkpoint when resuming a W&B run ────────────────
if args.wandb_id and not args.resume:
import glob as _glob
ckpts = sorted(_glob.glob(os.path.join(args.checkpoint_dir, "ppo_torchrl_step*.pt")))
if ckpts:
args.resume = ckpts[-1]
print(f" [RESUME] Auto-detected checkpoint: {args.resume}")
ckpt = None
if args.resume:
print(f"\n [RESUME] Loading {args.resume}")
ckpt = torch.load(args.resume, map_location="cpu", weights_only=False)
print(f" [RESUME] From step {ckpt['step']:,} lvl {ckpt['curriculum_level']}")
# ── W&B ───────────────────────────────────────────────────────────────────
wandb_kwargs = dict(
project = args.wandb_project,
name = args.wandb_run_name,
config = vars(args),
mode = "offline" if args.wandb_offline else "online",
sync_tensorboard = False,
)
if args.wandb_id:
wandb_kwargs["id"] = args.wandb_id
wandb_kwargs["resume"] = "must"
elif ckpt and ckpt.get("wandb_run_id"):
wandb_kwargs["id"] = ckpt["wandb_run_id"]
wandb_kwargs["resume"] = "allow"
run = wandb.init(**wandb_kwargs)
wandb.define_metric("global_step")
for prefix in ("episode", "ppo", "curriculum", "val", "system"):
wandb.define_metric(f"{prefix}/*", step_metric="global_step")
# ── Curriculum ────────────────────────────────────────────────────────────
builder = CurriculumBuilder(
threshold = args.threshold,
window = args.window,
replay_frac = args.replay_frac,
use_image = True,
)
if ckpt:
builder._sampler._idx = ckpt["sampler_idx"]
builder._sampler._rewards = deque(ckpt["sampler_rewards"], maxlen=args.window)
builder._sampler._crashes = deque(ckpt.get("sampler_crashes", []), maxlen=args.window)
builder._sampler._laps = deque(ckpt.get("sampler_laps", []), maxlen=args.window)
builder._sampler._is_frontier = deque(ckpt.get("sampler_is_frontier", []), maxlen=args.window)
builder._sampler._frontier_crashes = deque(ckpt.get("sampler_frontier_crashes", []), maxlen=args.window)
builder._sampler._frontier_laps = deque(ckpt.get("sampler_frontier_laps", []), maxlen=args.window)
sampler = builder._sampler
# ── Environment (torchrl) ─────────────────────────────────────────────────
N = args.num_envs
shared_level = mp.Value("i", builder.current_level)
shared_priority = mp.Array("i", [-1] * 10) # TRAIN indices of failing tracks
shared_n_priority = mp.Value("i", 0) # how many entries are valid
vec_env = make_vec_env(
num_envs = N,
max_steps = 3000,
laps_target = 1,
replay_frac = args.replay_frac,
device = device,
shared_level = shared_level,
shared_priority = shared_priority,
shared_n_priority = shared_n_priority,
)
vec_env.set_seed(args.seed)
# ── Policy + value ────────────────────────────────────────────────────────
policy_module, value_module, encoder = build_policy_and_value(device)
if ckpt:
def _strip_orig_mod(sd):
if any(k.startswith("_orig_mod.") for k in sd):
return {k.replace("_orig_mod.", "", 1): v for k, v in sd.items()}
return sd
policy_module.load_state_dict(_strip_orig_mod(ckpt["policy"]))
value_module.load_state_dict(_strip_orig_mod(ckpt["value"]))
# Sanity: run once through reset so specs match
with torch.no_grad():
td0 = vec_env.reset().to(device)
policy_module(td0)
value_module(td0)
# ── Loss + optimiser ──────────────────────────────────────────────────────
advantage_module = GAE(
gamma = args.gamma,
lmbda = args.gae_lambda,
value_network = value_module,
average_gae = False,
)
loss_module = ClipPPOLoss(
actor_network = policy_module,
critic_network = value_module,
clip_epsilon = args.clip_eps,
entropy_bonus = True, # always compute entropy term
entropy_coeff = args.ent_coef, # 0.0 β‡’ SB3 default (no bonus)
critic_coeff = args.vf_coef,
loss_critic_type = "l2",
normalize_advantage = True,
)
optimizer = torch.optim.Adam(loss_module.parameters(), lr=args.lr, eps=1e-5)
if ckpt:
optimizer.load_state_dict(ckpt["optimizer"])
if args.compile:
try:
policy_module = torch.compile(policy_module, mode="default")
print("torch.compile enabled on policy (default)")
except Exception as e:
print(f"torch.compile skipped: {e}")
total_params = (
sum(p.numel() for p in policy_module.parameters() if p.requires_grad)
+ sum(p.numel() for p in value_module.parameters() if p.requires_grad
# exclude shared encoder params (already counted in policy)
and not any(p is q for q in encoder.parameters()))
)
# ── Collector ─────────────────────────────────────────────────────────────
collector = Collector(
vec_env,
policy_module,
frames_per_batch = args.rollout_steps,
total_frames = args.total_steps,
device = device,
storing_device = device,
reset_at_each_iter = False,
)
# ── Replay buffer for PPO minibatches ─────────────────────────────────────
replay = ReplayBuffer(
storage = LazyTensorStorage(args.rollout_steps, device=device),
sampler = SamplerWithoutReplacement(),
batch_size = args.batch_size,
)
# ── Counters ──────────────────────────────────────────────────────────────
global_step = ckpt["step"] if ckpt else 0
episode_num = ckpt["episode_num"] if ckpt else 0
reward_window = deque(ckpt["reward_window"] if ckpt else [], maxlen=args.window)
frontier_reward_window = deque(ckpt.get("frontier_reward_window", []) if ckpt else [], maxlen=args.window)
update_num = 0
start_time = time.time()
greedy_clean = 0 # result of last greedy eval (for display)
next_eval = args.eval_interval_steps
while next_eval <= global_step:
next_eval += args.eval_interval_steps
LOG_INTERVAL = 25_000
next_log = LOG_INTERVAL
while next_log <= global_step:
next_log += LOG_INTERVAL
# Accumulators for the current log window
_log_ep_rewards: list[float] = []
_log_ep_lengths: list[int] = []
if args.checkpoint_interval > 0:
next_ckpt = args.checkpoint_interval
while next_ckpt <= global_step:
next_ckpt += args.checkpoint_interval
else:
next_ckpt = float("inf")
if args.video_interval > 0:
next_video = args.video_interval
while next_video <= global_step:
next_video += args.video_interval
else:
next_video = float("inf")
os.makedirs(args.checkpoint_dir, exist_ok=True)
print(f"\nModel: {total_params:,} parameters | Device: {device} | Envs: {N}")
print(f"Rollout: {args.rollout_steps} frames per update (batch={args.batch_size}, epochs={args.ppo_epochs})")
print(f"PPO: lr={args.lr} gamma={args.gamma} lambda={args.gae_lambda} clip={args.clip_eps}")
print(f" vf={args.vf_coef} ent={args.ent_coef} grad={args.max_grad_norm}")
print(f"Curriculum: threshold={args.threshold} window={args.window} replay={args.replay_frac}")
print(f"Frontier : track {sampler.frontier_track.level} '{sampler.frontier_track.name}'")
print(f"W&B : {run.url}\n")
# ─────────────────────────────────────────────────────────────────────────
# Training loop
# ─────────────────────────────────────────────────────────────────────────
print("Starting training loop...", flush=True)
for td in collector:
rollout_frames = td.numel()
global_step += rollout_frames
update_num += 1
# ── Episode bookkeeping + curriculum advance ─────────────────────────
for ep in _iter_episodes(td):
episode_num += 1
reward_window.append(ep["ep_reward"])
_log_ep_rewards.append(ep["ep_reward"])
_log_ep_lengths.append(ep["ep_length"])
frontier = sampler.frontier_track
threshold = args.threshold * frontier.complexity
is_replay = (ep["track_level"] != frontier.level)
if not is_replay:
frontier_reward_window.append(ep["ep_reward"])
rolling_mean = statistics.mean(frontier_reward_window) if frontier_reward_window else 0.0
builder.record(ep["ep_reward"], ep["ep_crashes"], ep["ep_laps"],
is_frontier=not is_replay)
wandb.log({
"global_step": global_step,
"episode/reward": ep["ep_reward"],
"episode/length": ep["ep_length"],
"episode/laps": ep["ep_laps"],
"episode/crashes": ep["ep_crashes"],
"episode/on_track_pct": ep["on_track_pct"],
"episode/number": episode_num,
"curriculum/level": builder.current_level,
"curriculum/track_level": ep["track_level"],
"curriculum/tier": difficulty_of(frontier),
"curriculum/rolling_mean": rolling_mean,
"curriculum/threshold": threshold,
"curriculum/is_replay": int(is_replay),
}, step=global_step)
# ── Compute GAE advantages & targets, then flatten for PPO ───────────
with torch.no_grad():
advantage_module(td)
data_flat = td.reshape(-1)
replay.extend(data_flat)
# ── PPO update: n_epochs Γ— minibatches ───────────────────────────────
pg_losses, v_losses, ent_losses = [], [], []
approx_kls, clip_fracs, grad_norms = [], [], []
for epoch in range(args.ppo_epochs):
for _ in range(args.rollout_steps // args.batch_size):
mb = replay.sample()
loss_vals = loss_module(mb)
loss = (
loss_vals["loss_objective"]
+ loss_vals.get("loss_critic", torch.tensor(0.0, device=device))
+ loss_vals.get("loss_entropy", torch.tensor(0.0, device=device))
)
optimizer.zero_grad()
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
loss_module.parameters(), args.max_grad_norm
)
optimizer.step()
pg_losses.append(loss_vals["loss_objective"].detach().item())
if "loss_critic" in loss_vals:
v_losses.append(loss_vals["loss_critic"].detach().item())
if "loss_entropy" in loss_vals:
ent_losses.append(loss_vals["loss_entropy"].detach().item())
if "kl_approx" in loss_vals:
approx_kls.append(loss_vals["kl_approx"].detach().item())
if "clip_fraction" in loss_vals:
clip_fracs.append(loss_vals["clip_fraction"].detach().item())
grad_norms.append(float(grad_norm))
# After each epoch, re-shuffle for next epoch
replay.empty()
replay.extend(data_flat)
# Early stop if policy has moved too far from rollout data
if args.target_kl is not None and approx_kls:
epoch_kl = float(np.mean(approx_kls))
if epoch_kl > args.target_kl:
break
replay.empty()
# ── Explained variance ────────────────────────────────────────────────
with torch.no_grad():
values = td.get("state_value").reshape(-1)
returns = td.get(("next", "value_target")).reshape(-1) \
if ("next", "value_target") in td.keys(include_nested=True) \
else td.get("value_target").reshape(-1)
var_y = torch.var(returns)
ev = float("nan") if var_y == 0 else float(
1.0 - torch.var(returns - values) / var_y
)
# ── Log PPO + system metrics ──────────────────────────────────────────
sps = global_step / max(time.time() - start_time, 1e-6)
_mean = lambda xs: float(np.mean(xs)) if xs else float("nan")
wandb.log({
"global_step": global_step,
"ppo/policy_loss": _mean(pg_losses),
"ppo/value_loss": _mean(v_losses),
"ppo/entropy": _mean(ent_losses),
"ppo/approx_kl": _mean(approx_kls),
"ppo/clip_fraction": _mean(clip_fracs),
"ppo/explained_variance": ev,
"ppo/learning_rate": args.lr,
"ppo/entropy_coef": args.ent_coef,
"ppo/grad_norm": _mean(grad_norms),
"ppo/update": update_num,
"system/steps_per_sec": sps,
"system/elapsed_hours": (time.time() - start_time) / 3600,
}, step=global_step)
# ── Periodic SB3-style summary ────────────────────────────────────────
if global_step >= next_log:
ep_rew_mean = float(np.mean(_log_ep_rewards)) if _log_ep_rewards else float("nan")
ep_len_mean = float(np.mean(_log_ep_lengths)) if _log_ep_lengths else float("nan")
sps_now = global_step / max(time.time() - start_time, 1e-6)
frontier = sampler.frontier_track
win_clean = sum(1 for l, c in zip(sampler._frontier_laps, sampler._frontier_crashes)
if l >= 1 and c == 0)
def _fmt(v, fmt=".3g"):
return ("-" if v != v else format(v, fmt)) # nan β†’ "-"
rows = [
("rollout/", ""),
(" ep_len_mean", _fmt(ep_len_mean, ".1f")),
(" ep_rew_mean", _fmt(ep_rew_mean, ".3f")),
(" episodes", str(episode_num)),
("curriculum/", ""),
(" level", f"{builder.current_level}/{len(TRAIN)-1}"),
(" frontier_track", f"{frontier.level} '{frontier.name}'"),
(" rolling_mean", _fmt(rolling_mean, ".2f")),
(" clean_wins", f"{win_clean}/{args.window} (stochastic)"),
(" greedy_clean", f"{greedy_clean}/{args.eval_episodes} (eval)"),
("time/", ""),
(" fps", _fmt(sps_now, ".0f")),
(" iterations", str(update_num)),
(" total_timesteps", f"{global_step:,}"),
("train/", ""),
(" approx_kl", _fmt(_mean(approx_kls))),
(" clip_fraction", _fmt(_mean(clip_fracs))),
(" entropy_loss", _fmt(_mean(ent_losses))),
(" explained_variance", _fmt(ev)),
(" learning_rate", _fmt(args.lr)),
(" policy_grad_loss", _fmt(_mean(pg_losses))),
(" value_loss", _fmt(_mean(v_losses))),
(" grad_norm", _fmt(_mean(grad_norms))),
]
col_w = max(len(k) for k, _ in rows) + 2
val_w = max((len(v) for _, v in rows if v), default=6) + 2
sep = "-" * (col_w + val_w + 5)
print(sep)
for k, v in rows:
if v:
print(f"| {k:<{col_w}} | {v:>{val_w}} |")
else:
print(f"| {k:<{col_w+val_w+3}} |")
print(sep, flush=True)
_log_ep_rewards.clear()
_log_ep_lengths.clear()
next_log += LOG_INTERVAL
# ── Greedy curriculum eval ────────────────────────────────────────────
if global_step >= next_eval:
tracks_to_eval = TRAIN
all_pass = True
eval_log = {}
n_passing = 0
eval_passed: dict = {} # tr β†’ bool, used for priority replay update
print(f"\n [EVAL] greedy eval β€” {len(tracks_to_eval)} track(s):", flush=True)
for tr in tracks_to_eval:
res = _greedy_eval(policy_module, tr, device,
args.eval_episodes, max_steps=3000)
clean = sum(1 for r in res if r["laps"] >= 1 and r["crashes"] == 0)
ok = (clean == args.eval_episodes)
n_passing += int(ok)
eval_passed[tr] = ok
if not ok:
all_pass = False
print(f" track {tr.level:02d} '{tr.name}': "
f"{'PASS' if ok else 'FAIL'} ({clean}/{args.eval_episodes})",
flush=True)
eval_log[f"curriculum/greedy_track{tr.level:02d}"] = int(ok)
# Update priority replay: give failing tracks dedicated 30% of episodes.
failing_indices = [
i for i, tr in enumerate(tracks_to_eval) if not eval_passed[tr]
]
n_fail = min(len(failing_indices), 10)
shared_n_priority.value = n_fail
for i, idx in enumerate(failing_indices[:10]):
shared_priority[i] = idx
if failing_indices:
fail_names = ", ".join(
f"track {TRAIN[i].level}" for i in failing_indices[:10]
)
print(f" [PRIO] priority replay set for {n_fail} failing track(s): {fail_names}",
flush=True)
else:
print(f" [PRIO] all tracks passed β€” priority replay cleared", flush=True)
wandb.log({"global_step": global_step,
"curriculum/greedy_pass": int(all_pass),
"curriculum/greedy_n_pass": n_passing,
"curriculum/priority_n_tracks": n_fail,
**eval_log}, step=global_step)
if all_pass:
# Every track passed simultaneously β€” training complete
shared_n_priority.value = 0
print(
f"\n >>> ALL {len(TRAIN)} TRACKS PASSED β€” training complete "
f"at step {global_step:,}\n",
flush=True,
)
wandb.log({"global_step": global_step, "curriculum/complete": 1},
step=global_step)
adv_path = os.path.join(
args.checkpoint_dir,
f"ppo_torchrl_complete_step{global_step:08d}.pt",
)
save_checkpoint(
adv_path, policy_module, value_module, optimizer,
global_step, builder, args,
reward_window, frontier_reward_window, episode_num, run.id,
)
print(f" [CKPT] Final checkpoint: {adv_path}", flush=True)
break
else:
# Advance only if every track up to and including the frontier passes
prior_ok = all(
eval_passed.get(tr, False)
for tr in TRAIN[: builder.current_level + 1]
)
if prior_ok:
advanced = builder._sampler.advance()
if advanced:
shared_level.value = builder.current_level
new_frontier = sampler.frontier_track
print(
f"\n >>> ADVANCE -> Track {new_frontier.level} "
f"'{new_frontier.name}' "
f"[lvl {builder.current_level}/{len(TRAIN)-1}]\n",
flush=True,
)
wandb.log({
"global_step": global_step,
"curriculum/level": builder.current_level,
"curriculum/advanced_to_level": new_frontier.level,
"curriculum/advanced_to_name": new_frontier.name,
}, step=global_step)
adv_path = os.path.join(
args.checkpoint_dir,
f"ppo_torchrl_advance_lvl{builder.current_level:02d}"
f"_step{global_step:08d}.pt",
)
save_checkpoint(
adv_path, policy_module, value_module, optimizer,
global_step, builder, args,
reward_window, frontier_reward_window, episode_num, run.id,
)
print(f" [CKPT] Advance checkpoint: {adv_path}", flush=True)
next_eval += args.eval_interval_steps
# ── Checkpoint ────────────────────────────────────────────────────────
if global_step >= next_ckpt:
ckpt_path = os.path.join(
args.checkpoint_dir,
f"ppo_torchrl_step{global_step:08d}_lvl{builder.current_level:02d}.pt",
)
save_checkpoint(
ckpt_path, policy_module, value_module, optimizer,
global_step, builder, args,
reward_window, frontier_reward_window, episode_num, run.id,
)
wandb.save(ckpt_path)
print(f"\n [CKPT] {ckpt_path}")
prune_checkpoints(args.checkpoint_dir, args.keep_checkpoints)
next_ckpt += args.checkpoint_interval
# ── Video ─────────────────────────────────────────────────────────────
if global_step >= next_video:
try:
log_inference_videos(
policy_module = policy_module,
builder = builder,
device = device,
global_step = global_step,
video_dir = args.video_dir,
)
except Exception as e:
print(f" [VIDEO] Warning: failed to render video: {e}")
next_video += args.video_interval
# ── Update collector's policy weights (in case of compile) ───────────
collector.update_policy_weights_()
# ── Final checkpoint ──────────────────────────────────────────────────────
final = os.path.join(args.checkpoint_dir, "ppo_torchrl_final.pt")
save_checkpoint(
final, policy_module, value_module, optimizer,
global_step, builder, args,
reward_window, frontier_reward_window, episode_num, run.id,
)
wandb.save(final)
collector.shutdown()
elapsed = time.time() - start_time
print(f"\n{'-'*80}")
print(f"Training complete | {global_step:,} steps | {elapsed/3600:.2f} h")
print(f"Final model: {final}")
print(f"W&B run: {run.url}")
run.finish()
if __name__ == "__main__":
import traceback
try:
main()
except Exception:
traceback.print_exc()
sys.exit(1)