Spaces:
Sleeping
Sleeping
File size: 51,131 Bytes
27cec6c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 
1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 | """
train_torchrl.py — TorchRL PPO training for the curriculum car racer.

Replaces train_sb3.py with a torchrl-based PPO. Hyperparameters are carried
over from train_sb3.py (SB3 PPO defaults + our overrides); the values below
are this script's defaults:

    learning_rate       3e-4   (Adam, eps=1e-5) — SB3 default
    rollout frames      2048 / update (n_steps × n_envs in SB3)
    batch size          64     (minibatch size for PPO updates)
    n_epochs            10     (passes over rollout data)
    gamma               0.99
    gae_lambda          0.95
    clip_epsilon        0.2
    vf_coef             0.5
    ent_coef            0.01   (--ent-coef default here; SB3's own default is 0.0)
    max_grad_norm       0.5
    normalize_adv       True
    target_kl           0.1    (--target-kl default here; SB3's own default is None)
    log_std_init        -1.0   (initial std ≈ 0.37; SB3 DiagGaussian — no clamp)
    actor mean bias     [0]=0.3 (gentle forward accel)
    ortho init          gain=0.01 (actor) / 1.0 (critic)
    features extractor  RaceEncoder (ImpalaCNN + MLP → 288 dims)
    net_arch            empty  (direct linear heads, no extra MLP)
    share features      True   (encoder params shared across heads)

W&B metrics use identical keys to train_sb3.py.

Usage
─────
    uv run python train_torchrl.py
    uv run python train_torchrl.py --num-envs 8 --total-steps 10_000_000
    uv run python train_torchrl.py --resume checkpoints/ppo_torchrl_step500000.pt
"""
import argparse
import math
import os
import random
import re
import statistics
import sys
import time
from collections import deque
# Ensure project root (parent of training/) is on sys.path so env/ and game/ are importable
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import numpy as np
import torch
import torch.nn as nn
import wandb
# Headless pygame — must come before any game/env import
os.environ.setdefault("SDL_VIDEODRIVER", "dummy")
os.environ.setdefault("SDL_AUDIODRIVER", "dummy")
# Switch stdout/stderr to UTF-8 (replacing unencodable characters) so that
# box-drawing glyphs inside tensordict/torchrl banners don't raise on
# Windows cp1252 consoles when wandb wraps stdout.
try:
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")
    sys.stderr.reconfigure(encoding="utf-8", errors="replace")
except Exception:
    # Deliberately best-effort: if either stream cannot be reconfigured,
    # keep whatever encoding is already in place and carry on.
    pass
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.collectors import Collector
from torchrl.data import LazyTensorStorage, ReplayBuffer, SamplerWithoutReplacement
import multiprocessing as mp
from torchrl.envs import Compose, GymWrapper, ParallelEnv, StepCounter, TransformedEnv
from torchrl.envs.gym_like import BaseInfoDictReader
from torchrl.envs.transforms import RewardSum
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.data.tensor_specs import Composite, Unbounded
from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.modules.distributions import IndependentNormal
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from env import CurriculumBuilder, DriveAction
from env.encoder import RaceEncoder
from env.gym_env import RaceGymEnv
from game.rl_splits import TRAIN, difficulty_of
# ─────────────────────────────────────────────────────────────────────────────
# Args (same flags/defaults as train_sb3.py)
# ─────────────────────────────────────────────────────────────────────────────
def parse_args():
    """Parse the command line and return the resulting namespace.

    Options are organised into W&B, training-budget, PPO, curriculum,
    checkpointing and miscellaneous groups. ``--help`` prints each default
    because the parser uses ``ArgumentDefaultsHelpFormatter``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Experiment tracking.
    wandb_opts = parser.add_argument_group("W&B")
    wandb_opts.add_argument("--wandb-project", default="curriculum-car-racer")
    wandb_opts.add_argument("--wandb-run-name", default=None)
    wandb_opts.add_argument("--wandb-id", default=None)
    wandb_opts.add_argument("--wandb-offline", action="store_true")

    # How much data to collect and how it is batched for updates.
    budget_opts = parser.add_argument_group("Training budget")
    budget_opts.add_argument("--total-steps", type=int, default=5_000_000)
    budget_opts.add_argument(
        "--rollout-steps",
        type=int,
        default=2048,
        help="Total frames per PPO update (across all envs)",
    )
    budget_opts.add_argument("--num-envs", type=int, default=4)
    budget_opts.add_argument("--batch-size", type=int, default=64)
    budget_opts.add_argument("--ppo-epochs", type=int, default=10)

    # Optimiser and PPO loss coefficients.
    ppo_opts = parser.add_argument_group("PPO (SB3 defaults)")
    ppo_opts.add_argument("--lr", type=float, default=3e-4)
    ppo_opts.add_argument("--gamma", type=float, default=0.99)
    ppo_opts.add_argument("--gae-lambda", type=float, default=0.95)
    ppo_opts.add_argument("--clip-eps", type=float, default=0.2)
    ppo_opts.add_argument("--vf-coef", type=float, default=0.5)
    ppo_opts.add_argument("--ent-coef", type=float, default=0.01)
    ppo_opts.add_argument("--max-grad-norm", type=float, default=0.5)
    ppo_opts.add_argument(
        "--target-kl",
        type=float,
        default=0.1,
        help="Stop PPO epochs early when approx_kl exceeds this",
    )

    # Curriculum progression and its greedy evaluation gate.
    curriculum_opts = parser.add_argument_group("Curriculum")
    curriculum_opts.add_argument("--threshold", type=float, default=30.0)
    curriculum_opts.add_argument("--window", type=int, default=20)
    curriculum_opts.add_argument("--replay-frac", type=float, default=0.3)
    curriculum_opts.add_argument(
        "--eval-episodes",
        type=int,
        default=1,
        help="Greedy eval episodes run every --eval-interval-steps for curriculum gating",
    )
    curriculum_opts.add_argument(
        "--eval-interval-steps",
        type=int,
        default=25_000,
        help="Run greedy curriculum eval every N global steps",
    )

    # Checkpoint cadence, location, retention and resume source.
    ckpt_opts = parser.add_argument_group("Checkpointing")
    ckpt_opts.add_argument("--checkpoint-interval", type=int, default=500_000)
    ckpt_opts.add_argument("--checkpoint-dir", default="checkpoints")
    ckpt_opts.add_argument("--keep-checkpoints", type=int, default=5)
    ckpt_opts.add_argument("--resume", default=None)

    # Everything else.
    misc_opts = parser.add_argument_group("Misc")
    misc_opts.add_argument("--seed", type=int, default=42)
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    misc_opts.add_argument("--device", default=default_device)
    misc_opts.add_argument("--compile", action="store_true")
    misc_opts.add_argument("--video-interval", type=int, default=25_000)
    misc_opts.add_argument("--video-dir", default="inference_videos")

    return parser.parse_args()
# ─────────────────────────────────────────────────────────────────────────────
# Shared encoder + actor/critic heads
#
# RaceEncoder (ImpalaCNN + scalar MLP → 288) is shared between the actor and
# critic, matching SB3's share_features_extractor=True. The actor and critic
# each run the encoder once per forward (torchrl's PPO evaluates them
# separately); parameter sharing gives identical gradients to SB3.
# ─────────────────────────────────────────────────────────────────────────────
def _flatten_batch_dims(image: torch.Tensor, scalars: torch.Tensor):
"""
Collapse all leading batch dimensions into one so Conv2d gets a 4D tensor.
Returns (img_flat, sca_flat, lead_shape) where lead_shape is used to
restore the original batch structure on outputs.
RaceEncoder expects image (B, 3, 64, 64) and scalars (B, 9). During PPO
loss/GAE, torchrl hands us (N, T, 3, 64, 64) / (N, T, 9); during rollout
collection it's (N, 3, 64, 64) / (N, 9). Flatten uniformly.
"""
lead_shape = image.shape[:-3] # everything except C,H,W
img_flat = image.reshape(-1, *image.shape[-3:])
sca_flat = scalars.reshape(-1, scalars.shape[-1])
return img_flat, sca_flat, lead_shape
class _ActorNet(nn.Module):
    """Policy head: (image, scalars) -> (loc, scale) of a 2-D Gaussian.

    The mean comes from a single linear layer on top of the encoder features.
    The standard deviation is a learned, state-independent parameter stored
    in log space and never clamped.
    """

    def __init__(self, encoder: RaceEncoder, log_std_init: float = -1.0):
        """
        Args:
            encoder: feature extractor exposing ``out_features``; called as
                ``encoder(image, scalars)``.
            log_std_init: initial value for both entries of ``log_std``.
        """
        super().__init__()
        self.encoder = encoder
        self.mean = nn.Linear(encoder.out_features, 2)
        # One free log-std per action dimension (not a function of the state).
        self.log_std = nn.Parameter(
            torch.full((2,), fill_value=float(log_std_init))
        )

        # Small-gain orthogonal weights keep the initial mean near the bias.
        nn.init.orthogonal_(self.mean.weight, gain=0.01)
        with torch.no_grad():
            self.mean.bias.zero_()
            # Action index 0 is acceleration: start with a gentle forward
            # push so the car keeps moving while it explores.
            self.mean.bias[0] = 0.3

    def forward(self, image: torch.Tensor, scalars: torch.Tensor):
        """Return ``(loc, scale)``, each shaped ``(*leading_batch_dims, 2)``."""
        flat_img, flat_sca, lead = _flatten_batch_dims(image, scalars)
        features = self.encoder(flat_img, flat_sca)
        loc = self.mean(features).reshape(*lead, 2)
        scale = torch.exp(self.log_std).expand_as(loc)
        return loc, scale
class _CriticNet(nn.Module):
    """Value head: (image, scalars) -> scalar state value."""

    def __init__(self, encoder: RaceEncoder):
        """
        Args:
            encoder: feature extractor exposing ``out_features``; called as
                ``encoder(image, scalars)``.
        """
        super().__init__()
        self.encoder = encoder
        self.value = nn.Linear(encoder.out_features, 1)
        # Unit-gain orthogonal weights, zero bias.
        nn.init.orthogonal_(self.value.weight, gain=1.0)
        with torch.no_grad():
            self.value.bias.zero_()

    def forward(self, image: torch.Tensor, scalars: torch.Tensor):
        """Return the value estimate shaped ``(*leading_batch_dims, 1)``."""
        flat_img, flat_sca, lead = _flatten_batch_dims(image, scalars)
        features = self.encoder(flat_img, flat_sca)
        return self.value(features).reshape(*lead, 1)
def _sb3_ortho_init(module: nn.Module, gain: float) -> None:
"""Mirror SB3's ActorCriticPolicy.init_weights(gain) via module.apply:
orthogonal-init every Conv2d/Linear weight with the given gain and zero
biases. SB3 applies gain=sqrt(2) to the features extractor."""
for m in module.modules():
if isinstance(m, (nn.Linear, nn.Conv2d)):
nn.init.orthogonal_(m.weight, gain=gain)
if m.bias is not None:
nn.init.zeros_(m.bias)
def build_policy_and_value(device: torch.device):
    """Create the actor and critic modules on top of one shared RaceEncoder.

    The same encoder instance is handed to both heads, so its parameters are
    shared between policy and value function.

    Args:
        device: device both modules are moved to.

    Returns:
        ``(policy_module, value_module, encoder)``: the ProbabilisticActor,
        the ValueOperator, and the shared encoder instance.
    """
    obs_keys = ("image", "scalars")
    dist_keys = ("loc", "scale")

    shared_encoder = RaceEncoder()
    # sqrt(2)-gain orthogonal init across the whole features extractor.
    _sb3_ortho_init(shared_encoder, gain=math.sqrt(2))

    # The heads apply their own init to their own linear layers only.
    actor_head = _ActorNet(shared_encoder)
    critic_head = _CriticNet(shared_encoder)

    # Observation -> (loc, scale), then sample an action from the Gaussian.
    dist_params = TensorDictModule(
        actor_head,
        in_keys=list(obs_keys),
        out_keys=list(dist_keys),
    )
    policy_module = ProbabilisticActor(
        module=dist_params,
        in_keys=list(dist_keys),
        out_keys=["action"],
        distribution_class=IndependentNormal,
        return_log_prob=True,
    )
    policy_module = policy_module.to(device)

    value_module = ValueOperator(
        module=critic_head,
        in_keys=list(obs_keys),
        out_keys=["state_value"],
    )
    value_module = value_module.to(device)

    return policy_module, value_module, shared_encoder
# ─────────────────────────────────────────────────────────────────────────────
# Environment factory
# ─────────────────────────────────────────────────────────────────────────────
class _EpisodeStatsReader(BaseInfoDictReader):
    """Copy per-episode statistics from the gym ``info`` dict into the tensordict.

    Implemented as a BaseInfoDictReader subclass with an explicit
    ``info_spec`` so that set_info_dict_reader() registers these keys in the
    GymWrapper observation spec. Per the original author's note, this is
    required for ParallelEnv to allocate shared memory and transfer the
    values from the worker subprocess to the main process.
    """

    # Every statistic is exposed as a scalar float32 entry.
    info_spec = Composite(
        episode_laps=Unbounded((), dtype=torch.float32),
        episode_crashes=Unbounded((), dtype=torch.float32),
        on_track_pct=Unbounded((), dtype=torch.float32),
        track_level=Unbounded((), dtype=torch.float32),
    )

    # (info key, value written when the key is missing from ``info``)
    _FIELDS = (
        ("episode_laps", 0),
        ("episode_crashes", 0),
        ("on_track_pct", 0.0),
        ("track_level", 0),
    )

    def __call__(self, info, td):
        """Write each tracked statistic into ``td`` as a float32 scalar tensor."""
        for key, fallback in self._FIELDS:
            td[key] = torch.tensor(info.get(key, fallback), dtype=torch.float32)

    def reset(self, tensordict_reset=None):
        """No per-episode state to clear."""
        return None
def make_vec_env(num_envs: int, max_steps: int, laps_target: int,
                 replay_frac: float, device, shared_level: mp.Value,
                 shared_priority=None, shared_n_priority=None):
    """Build the vectorised training environment.

    Creates a ParallelEnv of ``num_envs`` GymWrapper(RaceGymEnv) workers
    (start method ``"fork"``) and wraps it with StepCounter and RewardSum
    transforms. The ``shared_*`` multiprocessing objects are handed to every
    RaceGymEnv so the main process can publish the frontier level (and
    priority-track list) to the workers.

    Args:
        num_envs: number of worker environments.
        max_steps: forwarded to RaceGymEnv.
        laps_target: forwarded to RaceGymEnv.
        replay_frac: forwarded to RaceGymEnv.
        device: accepted for interface compatibility; not used here, since
            every worker env is created on ``"cpu"``.
        shared_level: shared value holding the current frontier level.
        shared_priority: optional shared array forwarded to RaceGymEnv.
        shared_n_priority: optional shared value forwarded to RaceGymEnv.

    Returns:
        A TransformedEnv wrapping the ParallelEnv.
    """
    race_env_kwargs = dict(
        sampler=None,
        frontier_level=0,
        replay_frac=replay_frac,
        max_steps=max_steps,
        laps_target=laps_target,
        shared_level=shared_level,
        shared_priority=shared_priority,
        shared_n_priority=shared_n_priority,
    )

    def _spawn_env():
        # Each call builds one independent env.
        env = GymWrapper(RaceGymEnv(**race_env_kwargs), device="cpu")
        # Expose episode statistics from the gym info dict.
        env.set_info_dict_reader(_EpisodeStatsReader())
        return env

    workers = ParallelEnv(num_envs, _spawn_env, mp_start_method="fork")
    transforms = Compose(StepCounter(), RewardSum())
    return TransformedEnv(workers, transforms)
# ─────────────────────────────────────────────────────────────────────────────
# Inference video (frontier track only — same as train_sb3.py)
# ─────────────────────────────────────────────────────────────────────────────
def _game_frame(race_env) -> np.ndarray:
    """Render the current game state as a 450x300 frame.

    Draws the headlights and then the car on a copy of the track surface,
    scales it to (450, 300) and returns the pixels with the first two axes of
    pygame's ``array3d`` layout swapped.

    Args:
        race_env: environment whose ``_env`` attribute exposes ``track``,
            ``_x``, ``_y`` and ``_angle``.

    Returns:
        A freshly allocated array holding the rendered frame.
    """
    import pygame
    from game.oval_racer import draw_car, draw_headlights

    car_env = race_env._env
    # Work on a copy so the track's own surface is never drawn on.
    canvas = car_env.track.surface.copy()
    pose = (car_env._x, car_env._y, car_env._angle)
    draw_headlights(canvas, *pose)
    draw_car(canvas, *pose)

    thumbnail = pygame.transform.scale(canvas, (450, 300))
    pixels = pygame.surfarray.array3d(thumbnail)
    return pixels.transpose(1, 0, 2).copy()
@torch.no_grad()
def log_inference_videos(
    policy_module,
    builder: CurriculumBuilder,
    device: torch.device,
    global_step: int,
    video_dir: str = "inference_videos",
    frame_skip: int = 2,
) -> None:
    """Roll out one greedy episode on the frontier track and save it as MP4.

    The policy is evaluated with ``ExplorationType.MEAN``, actions are clamped
    to [-1, 1], and every ``frame_skip``-th step is rendered. The video is
    written to ``video_dir`` as
    ``step{global_step:08d}_track{level:02d}_{name}.mp4`` at 20 fps.

    The policy's train/eval mode is restored to whatever it was on entry,
    even if the rollout or the video encoding raises.

    Args:
        policy_module: actor mapping a tensordict with ``image``/``scalars``
            to one that contains ``action``.
        builder: curriculum builder; ``current_level`` selects the track
            from ``TRAIN``.
        device: device the observation tensors are created on.
        global_step: training step, used only in the file name.
        video_dir: output directory (created if missing).
        frame_skip: keep one rendered frame every this many env steps.
    """
    import imageio.v3 as iio
    from env.environment import RaceEnvironment
    from game.rl_splits import _ensure_pygame

    _ensure_pygame()
    os.makedirs(video_dir, exist_ok=True)

    # Remember the mode so a failure below cannot leave the policy in eval.
    was_training = policy_module.training
    policy_module.eval()
    try:
        track = TRAIN[builder.current_level]
        track.build()
        env = RaceEnvironment(track, max_steps=3000, laps_target=1, use_image=True)
        raw_obs = env.reset()
        frames = [_game_frame(env)]
        step = 0
        while not raw_obs.done:
            # Scale to [0, 1], move the last axis first (presumably HWC ->
            # CHW), and add a batch dimension of 1.
            img = (torch.from_numpy(raw_obs.image.copy())
                   .float().div(255.0).permute(2, 0, 1).unsqueeze(0).to(device))
            scalars = torch.tensor(raw_obs.scalars, dtype=torch.float32,
                                   device=device).unsqueeze(0)
            td = TensorDict({"image": img, "scalars": scalars}, batch_size=[1])
            with set_exploration_type(ExplorationType.MEAN):
                td = policy_module(td)
            action = td.get("action")[0].clamp(-1.0, 1.0).cpu().numpy()
            raw_obs = env.step(DriveAction(
                accel=float(action[0]), steer=float(action[1])
            ))
            step += 1
            if step % frame_skip == 0:
                frames.append(_game_frame(env))

        video = np.stack(frames, axis=0)
        track_slug = track.name.replace(" ", "_")
        filename = f"step{global_step:08d}_track{track.level:02d}_{track_slug}.mp4"
        iio.imwrite(os.path.join(video_dir, filename),
                    video, fps=20, codec="libx264", plugin="pyav")
    finally:
        policy_module.train(was_training)
    print(f" [VIDEO] Saved to {os.path.join(video_dir, filename)}")
# ─────────────────────────────────────────────────────────────────────────────
# Episode iteration over a collected rollout
# ─────────────────────────────────────────────────────────────────────────────
def _iter_episodes(td):
"""
Yield (env_idx, step_idx, episode_reward, episode_length, info_dict) for
every terminal step in the collected rollout.
td has shape (N, T) with the standard torchrl layout:
td["next","done"] β (N, T, 1)
td["next","episode_reward"]β (N, T, 1) (from RewardSum)
td["next","step_count"] β (N, T, 1) (from StepCounter)
td["next","episode_laps"] β (N, T) (from RaceGymEnv info, if present)
RaceGymEnv writes episode_laps/episode_crashes/on_track_pct only on done.
"""
next_td = td.get("next")
dones = next_td.get("done").squeeze(-1) # (N, T) bool
ep_r = next_td.get("episode_reward").squeeze(-1) # (N, T) float
ep_l = next_td.get("step_count").squeeze(-1) # (N, T) int
# Info fields from RaceGymEnv (only populated on done steps). If they were
# never observed, these keys will not exist β tolerate that.
def _get(key, default):
if key in next_td.keys():
v = next_td.get(key)
return v.squeeze(-1) if v.dim() > dones.dim() else v
return torch.full_like(ep_r, float(default))
ep_crashes = _get("episode_crashes", 0).to(torch.float32)
ep_laps_info = _get("episode_laps", 0).to(torch.float32)
on_track = _get("on_track_pct", 0.0).to(torch.float32)
track_level = _get("track_level", 0).to(torch.float32)
track_name = next_td.get("track_name", None) # may be bytes/str tensor
N, T = dones.shape
for n in range(N):
for t in range(T):
if not bool(dones[n, t]):
continue
yield {
"env_idx": n,
"step_idx": t,
"ep_reward": float(ep_r[n, t]),
"ep_length": int(ep_l[n, t]),
"ep_crashes": int(ep_crashes[n, t]),
"ep_laps": int(ep_laps_info[n, t]),
"on_track_pct": float(on_track[n, t]),
"track_level": int(track_level[n, t]),
}
# ─────────────────────────────────────────────────────────────────────────────
# Checkpoints
# ─────────────────────────────────────────────────────────────────────────────
def save_checkpoint(path, policy_module, value_module, optimizer,
                    global_step, builder, args, reward_window, frontier_reward_window,
                    episode_num, wandb_run_id):
    """Serialise the full training state to ``path`` with ``torch.save``.

    The checkpoint holds the model and optimiser state dicts, the step and
    episode counters, the CLI arguments, the reward windows, the curriculum
    sampler's internal buffers and the W&B run id.

    When ``path`` is a filesystem path the data is first written to
    ``<path>.tmp`` and then moved into place with ``os.replace``. An
    interrupted save therefore never leaves a truncated file under the final
    name, where a later resume could pick it up. File-like objects are
    written to directly.

    Args:
        path: destination file path (str / PathLike) or writable binary file.
        policy_module: actor; its ``state_dict()`` is stored under "policy".
        value_module: critic; its ``state_dict()`` is stored under "value".
        optimizer: optimiser; its ``state_dict()`` is stored.
        global_step: number of environment frames collected so far.
        builder: curriculum builder providing ``current_level`` and
            ``_sampler``.
        args: parsed CLI namespace, stored as a plain dict.
        reward_window: iterable of recent episode rewards.
        frontier_reward_window: iterable of recent frontier-track rewards.
        episode_num: number of finished episodes.
        wandb_run_id: W&B run id, stored so a resume can re-attach.
    """
    sampler = builder._sampler
    payload = {
        "step": global_step,
        "curriculum_level": builder.current_level,
        "policy": policy_module.state_dict(),
        "value": value_module.state_dict(),
        "optimizer": optimizer.state_dict(),
        "args": vars(args),
        "reward_window": list(reward_window),
        "frontier_reward_window": list(frontier_reward_window),
        "episode_num": episode_num,
        "sampler_idx": sampler._idx,
        "sampler_rewards": list(sampler._rewards),
        "sampler_crashes": list(sampler._crashes),
        "sampler_laps": list(sampler._laps),
        "sampler_is_frontier": list(sampler._is_frontier),
        "sampler_frontier_crashes": list(sampler._frontier_crashes),
        "sampler_frontier_laps": list(sampler._frontier_laps),
        "wandb_run_id": wandb_run_id,
    }

    if not isinstance(path, (str, os.PathLike)):
        # File-like destination: nothing to rename, write straight through.
        torch.save(payload, path)
        return

    final_path = os.fspath(path)
    suffix = ".tmp" if isinstance(final_path, str) else b".tmp"
    # The ".tmp" suffix keeps the partial file out of "*.pt" globs.
    tmp_path = final_path + suffix
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, final_path)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the cause.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def prune_checkpoints(checkpoint_dir: str, keep: int):
    """Delete all but the ``keep`` newest ``ppo_torchrl_step*.pt`` files.

    "Newest" is decided by the step number embedded in the file name,
    compared numerically. A plain lexicographic sort would rank
    ``step1000000`` before ``step500000`` and so delete the most recent
    checkpoints once the step count gains a digit. Matching files whose step
    cannot be parsed are treated as newest, so they are pruned last.

    Args:
        checkpoint_dir: directory containing the checkpoints.
        keep: number of checkpoints to retain; ``<= 0`` disables pruning.
    """
    if keep <= 0:
        return
    import glob as _glob

    step_pattern = re.compile(r"^ppo_torchrl_step(\d+)\.pt$")

    def _sort_key(path):
        match = step_pattern.match(os.path.basename(path))
        step = int(match.group(1)) if match else float("inf")
        # The path breaks ties deterministically.
        return (step, path)

    candidates = _glob.glob(os.path.join(checkpoint_dir, "ppo_torchrl_step*.pt"))
    pts = sorted(candidates, key=_sort_key)
    for old in pts[:-keep]:
        os.remove(old)
        print(f" [PRUNE] Removed {os.path.basename(old)}")
# ─────────────────────────────────────────────────────────────────────────────
# Greedy curriculum evaluation
# ─────────────────────────────────────────────────────────────────────────────
@torch.no_grad()
def _greedy_eval(policy_module, track, device, n_episodes, max_steps=3000):
    """Run ``n_episodes`` greedy (MEAN-action) episodes on ``track``.

    Uses RaceEnvironment directly (no vectorisation) with ``laps_target=1``.
    Actions are clamped to [-1, 1] before being applied.

    The policy's train/eval mode is restored to whatever it was on entry,
    even if an episode raises.

    Args:
        policy_module: actor mapping a tensordict with ``image``/``scalars``
            to one that contains ``action``.
        track: track object; ``build()`` is called once before the episodes.
        device: device the observation tensors are created on.
        n_episodes: number of episodes to run.
        max_steps: per-episode step limit passed to RaceEnvironment.

    Returns:
        A list with one ``{"laps": ..., "crashes": ...}`` dict per episode.
    """
    from env.environment import RaceEnvironment

    # Remember the mode so a failure below cannot leave the policy in eval.
    was_training = policy_module.training
    policy_module.eval()
    results = []
    try:
        track.build()
        for _ in range(n_episodes):
            env = RaceEnvironment(track, max_steps=max_steps, laps_target=1, use_image=True)
            raw_obs = env.reset()
            while not raw_obs.done:
                # Scale to [0, 1], move the last axis first (presumably
                # HWC -> CHW), and add a batch dimension of 1.
                img = (torch.from_numpy(raw_obs.image.copy())
                       .float().div(255.0).permute(2, 0, 1).unsqueeze(0).to(device))
                scalars = torch.tensor(raw_obs.scalars, dtype=torch.float32,
                                       device=device).unsqueeze(0)
                td = TensorDict({"image": img, "scalars": scalars}, batch_size=[1])
                with set_exploration_type(ExplorationType.MEAN):
                    td = policy_module(td)
                action = td["action"][0].clamp(-1.0, 1.0).cpu().numpy()
                raw_obs = env.step(DriveAction(accel=float(action[0]), steer=float(action[1])))
            # Final counters are read from the wrapped low-level env.
            ce = env._env
            results.append({"laps": ce._laps, "crashes": ce._crash_count})
    finally:
        policy_module.train(was_training)
    return results
# ─────────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────────
def main():
args = parse_args()
device = torch.device(args.device)
# ββ Seed ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if device.type == "cuda":
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.set_float32_matmul_precision("high") # TF32 on A10 tensor cores
# ββ Auto-detect latest checkpoint when resuming a W&B run ββββββββββββββββ
if args.wandb_id and not args.resume:
import glob as _glob
ckpts = sorted(_glob.glob(os.path.join(args.checkpoint_dir, "ppo_torchrl_step*.pt")))
if ckpts:
args.resume = ckpts[-1]
print(f" [RESUME] Auto-detected checkpoint: {args.resume}")
ckpt = None
if args.resume:
print(f"\n [RESUME] Loading {args.resume}")
ckpt = torch.load(args.resume, map_location="cpu", weights_only=False)
print(f" [RESUME] From step {ckpt['step']:,} lvl {ckpt['curriculum_level']}")
# ββ W&B βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
wandb_kwargs = dict(
project = args.wandb_project,
name = args.wandb_run_name,
config = vars(args),
mode = "offline" if args.wandb_offline else "online",
sync_tensorboard = False,
)
if args.wandb_id:
wandb_kwargs["id"] = args.wandb_id
wandb_kwargs["resume"] = "must"
elif ckpt and ckpt.get("wandb_run_id"):
wandb_kwargs["id"] = ckpt["wandb_run_id"]
wandb_kwargs["resume"] = "allow"
run = wandb.init(**wandb_kwargs)
wandb.define_metric("global_step")
for prefix in ("episode", "ppo", "curriculum", "val", "system"):
wandb.define_metric(f"{prefix}/*", step_metric="global_step")
# ββ Curriculum ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
builder = CurriculumBuilder(
threshold = args.threshold,
window = args.window,
replay_frac = args.replay_frac,
use_image = True,
)
if ckpt:
builder._sampler._idx = ckpt["sampler_idx"]
builder._sampler._rewards = deque(ckpt["sampler_rewards"], maxlen=args.window)
builder._sampler._crashes = deque(ckpt.get("sampler_crashes", []), maxlen=args.window)
builder._sampler._laps = deque(ckpt.get("sampler_laps", []), maxlen=args.window)
builder._sampler._is_frontier = deque(ckpt.get("sampler_is_frontier", []), maxlen=args.window)
builder._sampler._frontier_crashes = deque(ckpt.get("sampler_frontier_crashes", []), maxlen=args.window)
builder._sampler._frontier_laps = deque(ckpt.get("sampler_frontier_laps", []), maxlen=args.window)
sampler = builder._sampler
# ββ Environment (torchrl) βββββββββββββββββββββββββββββββββββββββββββββββββ
N = args.num_envs
shared_level = mp.Value("i", builder.current_level)
shared_priority = mp.Array("i", [-1] * 10) # TRAIN indices of failing tracks
shared_n_priority = mp.Value("i", 0) # how many entries are valid
vec_env = make_vec_env(
num_envs = N,
max_steps = 3000,
laps_target = 1,
replay_frac = args.replay_frac,
device = device,
shared_level = shared_level,
shared_priority = shared_priority,
shared_n_priority = shared_n_priority,
)
vec_env.set_seed(args.seed)
# ββ Policy + value ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
policy_module, value_module, encoder = build_policy_and_value(device)
if ckpt:
def _strip_orig_mod(sd):
if any(k.startswith("_orig_mod.") for k in sd):
return {k.replace("_orig_mod.", "", 1): v for k, v in sd.items()}
return sd
policy_module.load_state_dict(_strip_orig_mod(ckpt["policy"]))
value_module.load_state_dict(_strip_orig_mod(ckpt["value"]))
# Sanity: run once through reset so specs match
with torch.no_grad():
td0 = vec_env.reset().to(device)
policy_module(td0)
value_module(td0)
# ββ Loss + optimiser ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
advantage_module = GAE(
gamma = args.gamma,
lmbda = args.gae_lambda,
value_network = value_module,
average_gae = False,
)
loss_module = ClipPPOLoss(
actor_network = policy_module,
critic_network = value_module,
clip_epsilon = args.clip_eps,
entropy_bonus = True, # always compute entropy term
entropy_coeff = args.ent_coef, # 0.0 β SB3 default (no bonus)
critic_coeff = args.vf_coef,
loss_critic_type = "l2",
normalize_advantage = True,
)
optimizer = torch.optim.Adam(loss_module.parameters(), lr=args.lr, eps=1e-5)
if ckpt:
optimizer.load_state_dict(ckpt["optimizer"])
if args.compile:
try:
policy_module = torch.compile(policy_module, mode="default")
print("torch.compile enabled on policy (default)")
except Exception as e:
print(f"torch.compile skipped: {e}")
total_params = (
sum(p.numel() for p in policy_module.parameters() if p.requires_grad)
+ sum(p.numel() for p in value_module.parameters() if p.requires_grad
# exclude shared encoder params (already counted in policy)
and not any(p is q for q in encoder.parameters()))
)
# ββ Collector βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
collector = Collector(
vec_env,
policy_module,
frames_per_batch = args.rollout_steps,
total_frames = args.total_steps,
device = device,
storing_device = device,
reset_at_each_iter = False,
)
# ββ Replay buffer for PPO minibatches βββββββββββββββββββββββββββββββββββββ
replay = ReplayBuffer(
storage = LazyTensorStorage(args.rollout_steps, device=device),
sampler = SamplerWithoutReplacement(),
batch_size = args.batch_size,
)
# ββ Counters ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
global_step = ckpt["step"] if ckpt else 0
episode_num = ckpt["episode_num"] if ckpt else 0
reward_window = deque(ckpt["reward_window"] if ckpt else [], maxlen=args.window)
frontier_reward_window = deque(ckpt.get("frontier_reward_window", []) if ckpt else [], maxlen=args.window)
update_num = 0
start_time = time.time()
greedy_clean = 0 # result of last greedy eval (for display)
next_eval = args.eval_interval_steps
while next_eval <= global_step:
next_eval += args.eval_interval_steps
LOG_INTERVAL = 25_000
next_log = LOG_INTERVAL
while next_log <= global_step:
next_log += LOG_INTERVAL
# Accumulators for the current log window
_log_ep_rewards: list[float] = []
_log_ep_lengths: list[int] = []
if args.checkpoint_interval > 0:
next_ckpt = args.checkpoint_interval
while next_ckpt <= global_step:
next_ckpt += args.checkpoint_interval
else:
next_ckpt = float("inf")
if args.video_interval > 0:
next_video = args.video_interval
while next_video <= global_step:
next_video += args.video_interval
else:
next_video = float("inf")
os.makedirs(args.checkpoint_dir, exist_ok=True)
print(f"\nModel: {total_params:,} parameters | Device: {device} | Envs: {N}")
print(f"Rollout: {args.rollout_steps} frames per update (batch={args.batch_size}, epochs={args.ppo_epochs})")
print(f"PPO: lr={args.lr} gamma={args.gamma} lambda={args.gae_lambda} clip={args.clip_eps}")
print(f" vf={args.vf_coef} ent={args.ent_coef} grad={args.max_grad_norm}")
print(f"Curriculum: threshold={args.threshold} window={args.window} replay={args.replay_frac}")
print(f"Frontier : track {sampler.frontier_track.level} '{sampler.frontier_track.name}'")
print(f"W&B : {run.url}\n")
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# Training loop
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
print("Starting training loop...", flush=True)
for td in collector:
rollout_frames = td.numel()
global_step += rollout_frames
update_num += 1
# ββ Episode bookkeeping + curriculum advance βββββββββββββββββββββββββ
for ep in _iter_episodes(td):
episode_num += 1
reward_window.append(ep["ep_reward"])
_log_ep_rewards.append(ep["ep_reward"])
_log_ep_lengths.append(ep["ep_length"])
frontier = sampler.frontier_track
threshold = args.threshold * frontier.complexity
is_replay = (ep["track_level"] != frontier.level)
if not is_replay:
frontier_reward_window.append(ep["ep_reward"])
rolling_mean = statistics.mean(frontier_reward_window) if frontier_reward_window else 0.0
builder.record(ep["ep_reward"], ep["ep_crashes"], ep["ep_laps"],
is_frontier=not is_replay)
wandb.log({
"global_step": global_step,
"episode/reward": ep["ep_reward"],
"episode/length": ep["ep_length"],
"episode/laps": ep["ep_laps"],
"episode/crashes": ep["ep_crashes"],
"episode/on_track_pct": ep["on_track_pct"],
"episode/number": episode_num,
"curriculum/level": builder.current_level,
"curriculum/track_level": ep["track_level"],
"curriculum/tier": difficulty_of(frontier),
"curriculum/rolling_mean": rolling_mean,
"curriculum/threshold": threshold,
"curriculum/is_replay": int(is_replay),
}, step=global_step)
# ββ Compute GAE advantages & targets, then flatten for PPO βββββββββββ
with torch.no_grad():
advantage_module(td)
data_flat = td.reshape(-1)
replay.extend(data_flat)
# ββ PPO update: n_epochs Γ minibatches βββββββββββββββββββββββββββββββ
pg_losses, v_losses, ent_losses = [], [], []
approx_kls, clip_fracs, grad_norms = [], [], []
for epoch in range(args.ppo_epochs):
for _ in range(args.rollout_steps // args.batch_size):
mb = replay.sample()
loss_vals = loss_module(mb)
loss = (
loss_vals["loss_objective"]
+ loss_vals.get("loss_critic", torch.tensor(0.0, device=device))
+ loss_vals.get("loss_entropy", torch.tensor(0.0, device=device))
)
optimizer.zero_grad()
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
loss_module.parameters(), args.max_grad_norm
)
optimizer.step()
pg_losses.append(loss_vals["loss_objective"].detach().item())
if "loss_critic" in loss_vals:
v_losses.append(loss_vals["loss_critic"].detach().item())
if "loss_entropy" in loss_vals:
ent_losses.append(loss_vals["loss_entropy"].detach().item())
if "kl_approx" in loss_vals:
approx_kls.append(loss_vals["kl_approx"].detach().item())
if "clip_fraction" in loss_vals:
clip_fracs.append(loss_vals["clip_fraction"].detach().item())
grad_norms.append(float(grad_norm))
# After each epoch, re-shuffle for next epoch
replay.empty()
replay.extend(data_flat)
# Early stop if policy has moved too far from rollout data
if args.target_kl is not None and approx_kls:
epoch_kl = float(np.mean(approx_kls))
if epoch_kl > args.target_kl:
break
replay.empty()
# ββ Explained variance ββββββββββββββββββββββββββββββββββββββββββββββββ
with torch.no_grad():
values = td.get("state_value").reshape(-1)
returns = td.get(("next", "value_target")).reshape(-1) \
if ("next", "value_target") in td.keys(include_nested=True) \
else td.get("value_target").reshape(-1)
var_y = torch.var(returns)
ev = float("nan") if var_y == 0 else float(
1.0 - torch.var(returns - values) / var_y
)
# ββ Log PPO + system metrics ββββββββββββββββββββββββββββββββββββββββββ
sps = global_step / max(time.time() - start_time, 1e-6)
_mean = lambda xs: float(np.mean(xs)) if xs else float("nan")
wandb.log({
"global_step": global_step,
"ppo/policy_loss": _mean(pg_losses),
"ppo/value_loss": _mean(v_losses),
"ppo/entropy": _mean(ent_losses),
"ppo/approx_kl": _mean(approx_kls),
"ppo/clip_fraction": _mean(clip_fracs),
"ppo/explained_variance": ev,
"ppo/learning_rate": args.lr,
"ppo/entropy_coef": args.ent_coef,
"ppo/grad_norm": _mean(grad_norms),
"ppo/update": update_num,
"system/steps_per_sec": sps,
"system/elapsed_hours": (time.time() - start_time) / 3600,
}, step=global_step)
# ββ Periodic SB3-style summary ββββββββββββββββββββββββββββββββββββββββ
if global_step >= next_log:
ep_rew_mean = float(np.mean(_log_ep_rewards)) if _log_ep_rewards else float("nan")
ep_len_mean = float(np.mean(_log_ep_lengths)) if _log_ep_lengths else float("nan")
sps_now = global_step / max(time.time() - start_time, 1e-6)
frontier = sampler.frontier_track
win_clean = sum(1 for l, c in zip(sampler._frontier_laps, sampler._frontier_crashes)
if l >= 1 and c == 0)
def _fmt(v, fmt=".3g"):
return ("-" if v != v else format(v, fmt)) # nan β "-"
rows = [
("rollout/", ""),
(" ep_len_mean", _fmt(ep_len_mean, ".1f")),
(" ep_rew_mean", _fmt(ep_rew_mean, ".3f")),
(" episodes", str(episode_num)),
("curriculum/", ""),
(" level", f"{builder.current_level}/{len(TRAIN)-1}"),
(" frontier_track", f"{frontier.level} '{frontier.name}'"),
(" rolling_mean", _fmt(rolling_mean, ".2f")),
(" clean_wins", f"{win_clean}/{args.window} (stochastic)"),
(" greedy_clean", f"{greedy_clean}/{args.eval_episodes} (eval)"),
("time/", ""),
(" fps", _fmt(sps_now, ".0f")),
(" iterations", str(update_num)),
(" total_timesteps", f"{global_step:,}"),
("train/", ""),
(" approx_kl", _fmt(_mean(approx_kls))),
(" clip_fraction", _fmt(_mean(clip_fracs))),
(" entropy_loss", _fmt(_mean(ent_losses))),
(" explained_variance", _fmt(ev)),
(" learning_rate", _fmt(args.lr)),
(" policy_grad_loss", _fmt(_mean(pg_losses))),
(" value_loss", _fmt(_mean(v_losses))),
(" grad_norm", _fmt(_mean(grad_norms))),
]
col_w = max(len(k) for k, _ in rows) + 2
val_w = max((len(v) for _, v in rows if v), default=6) + 2
sep = "-" * (col_w + val_w + 5)
print(sep)
for k, v in rows:
if v:
print(f"| {k:<{col_w}} | {v:>{val_w}} |")
else:
print(f"| {k:<{col_w+val_w+3}} |")
print(sep, flush=True)
_log_ep_rewards.clear()
_log_ep_lengths.clear()
next_log += LOG_INTERVAL
# ββ Greedy curriculum eval ββββββββββββββββββββββββββββββββββββββββββββ
if global_step >= next_eval:
tracks_to_eval = TRAIN
all_pass = True
eval_log = {}
n_passing = 0
eval_passed: dict = {} # tr β bool, used for priority replay update
print(f"\n [EVAL] greedy eval β {len(tracks_to_eval)} track(s):", flush=True)
for tr in tracks_to_eval:
res = _greedy_eval(policy_module, tr, device,
args.eval_episodes, max_steps=3000)
clean = sum(1 for r in res if r["laps"] >= 1 and r["crashes"] == 0)
ok = (clean == args.eval_episodes)
n_passing += int(ok)
eval_passed[tr] = ok
if not ok:
all_pass = False
print(f" track {tr.level:02d} '{tr.name}': "
f"{'PASS' if ok else 'FAIL'} ({clean}/{args.eval_episodes})",
flush=True)
eval_log[f"curriculum/greedy_track{tr.level:02d}"] = int(ok)
# Update priority replay: give failing tracks dedicated 30% of episodes.
failing_indices = [
i for i, tr in enumerate(tracks_to_eval) if not eval_passed[tr]
]
n_fail = min(len(failing_indices), 10)
shared_n_priority.value = n_fail
for i, idx in enumerate(failing_indices[:10]):
shared_priority[i] = idx
if failing_indices:
fail_names = ", ".join(
f"track {TRAIN[i].level}" for i in failing_indices[:10]
)
print(f" [PRIO] priority replay set for {n_fail} failing track(s): {fail_names}",
flush=True)
else:
print(f" [PRIO] all tracks passed β priority replay cleared", flush=True)
wandb.log({"global_step": global_step,
"curriculum/greedy_pass": int(all_pass),
"curriculum/greedy_n_pass": n_passing,
"curriculum/priority_n_tracks": n_fail,
**eval_log}, step=global_step)
if all_pass:
# Every track passed simultaneously β training complete
shared_n_priority.value = 0
print(
f"\n >>> ALL {len(TRAIN)} TRACKS PASSED β training complete "
f"at step {global_step:,}\n",
flush=True,
)
wandb.log({"global_step": global_step, "curriculum/complete": 1},
step=global_step)
adv_path = os.path.join(
args.checkpoint_dir,
f"ppo_torchrl_complete_step{global_step:08d}.pt",
)
save_checkpoint(
adv_path, policy_module, value_module, optimizer,
global_step, builder, args,
reward_window, frontier_reward_window, episode_num, run.id,
)
print(f" [CKPT] Final checkpoint: {adv_path}", flush=True)
break
else:
# Advance only if every track up to and including the frontier passes
prior_ok = all(
eval_passed.get(tr, False)
for tr in TRAIN[: builder.current_level + 1]
)
if prior_ok:
advanced = builder._sampler.advance()
if advanced:
shared_level.value = builder.current_level
new_frontier = sampler.frontier_track
print(
f"\n >>> ADVANCE -> Track {new_frontier.level} "
f"'{new_frontier.name}' "
f"[lvl {builder.current_level}/{len(TRAIN)-1}]\n",
flush=True,
)
wandb.log({
"global_step": global_step,
"curriculum/level": builder.current_level,
"curriculum/advanced_to_level": new_frontier.level,
"curriculum/advanced_to_name": new_frontier.name,
}, step=global_step)
adv_path = os.path.join(
args.checkpoint_dir,
f"ppo_torchrl_advance_lvl{builder.current_level:02d}"
f"_step{global_step:08d}.pt",
)
save_checkpoint(
adv_path, policy_module, value_module, optimizer,
global_step, builder, args,
reward_window, frontier_reward_window, episode_num, run.id,
)
print(f" [CKPT] Advance checkpoint: {adv_path}", flush=True)
next_eval += args.eval_interval_steps
# ββ Checkpoint ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
if global_step >= next_ckpt:
ckpt_path = os.path.join(
args.checkpoint_dir,
f"ppo_torchrl_step{global_step:08d}_lvl{builder.current_level:02d}.pt",
)
save_checkpoint(
ckpt_path, policy_module, value_module, optimizer,
global_step, builder, args,
reward_window, frontier_reward_window, episode_num, run.id,
)
wandb.save(ckpt_path)
print(f"\n [CKPT] {ckpt_path}")
prune_checkpoints(args.checkpoint_dir, args.keep_checkpoints)
next_ckpt += args.checkpoint_interval
# ββ Video βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
if global_step >= next_video:
try:
log_inference_videos(
policy_module = policy_module,
builder = builder,
device = device,
global_step = global_step,
video_dir = args.video_dir,
)
except Exception as e:
print(f" [VIDEO] Warning: failed to render video: {e}")
next_video += args.video_interval
# ββ Update collector's policy weights (in case of compile) βββββββββββ
collector.update_policy_weights_()
# ββ Final checkpoint ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
final = os.path.join(args.checkpoint_dir, "ppo_torchrl_final.pt")
save_checkpoint(
final, policy_module, value_module, optimizer,
global_step, builder, args,
reward_window, frontier_reward_window, episode_num, run.id,
)
wandb.save(final)
collector.shutdown()
elapsed = time.time() - start_time
print(f"\n{'-'*80}")
print(f"Training complete | {global_step:,} steps | {elapsed/3600:.2f} h")
print(f"Final model: {final}")
print(f"W&B run: {run.url}")
run.finish()
if __name__ == "__main__":
    # Script entry point: run the training loop and convert any uncaught
    # error into a printed stack trace plus a non-zero exit status.
    import traceback

    try:
        main()
    except Exception:
        # Report the failure on stderr before terminating.
        traceback.print_exc()
        # Same effect as sys.exit(1): terminate with exit status 1.
        raise SystemExit(1)
|