Spaces:
Runtime error
Runtime error
| """ | |
| PPO trainer β ported directly from openenv/training/ppo_trainer.py. | |
| Same algorithm, same hyperparameters, same GAE implementation. | |
| Only change: uses OverflowGymEnv instead of CarEnv3D. | |
| Usage: | |
| from overflow_env.training.ppo_trainer import run_training | |
| run_training(policy_type="attention", total_steps=2_000_000) | |
| """ | |
| from __future__ import annotations | |
| import time | |
| from collections import deque | |
| from pathlib import Path | |
| from typing import Optional | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from .overflow_gym_env import OverflowGymEnv | |
| from .curriculum import CurriculumManager | |
| from .reward import compute_episode_bonus | |
| from ..policies.base_policy import BasePolicy | |
| from ..policies.policy_spec import OBS_DIM | |
# ── Rollout buffer ────────────────────────────────────────────────────────────
# Identical to openenv/training/ppo_trainer.py
class RolloutBuffer:
    """Fixed-size, on-device storage for one PPO rollout.

    Holds observations, actions, rewards, value estimates, log-probs and
    done flags for ``n_steps`` transitions; :meth:`compute_returns` fills
    ``self.ret`` with GAE(lambda) returns.
    """

    def __init__(self, n_steps: int, obs_dim: int, device: torch.device):
        def _zeros(*shape):
            return torch.zeros(*shape, device=device)

        self.n = n_steps
        self.obs = _zeros(n_steps, obs_dim)
        self.acts = _zeros(n_steps, 3)   # 3-dim continuous action
        self.rew = _zeros(n_steps)
        self.val = _zeros(n_steps)
        self.logp = _zeros(n_steps)
        self.done = _zeros(n_steps)
        self.ptr = 0                      # next write position

    def add(self, obs, act, rew, val, logp, done):
        """Store one transition at the current write position."""
        slot = self.ptr
        self.obs[slot] = torch.as_tensor(obs, dtype=torch.float32)
        self.acts[slot] = torch.as_tensor(act, dtype=torch.float32)
        for store, scalar in (
            (self.rew, rew),
            (self.val, val),
            (self.logp, logp),
            (self.done, done),
        ):
            store[slot] = float(scalar)
        self.ptr = slot + 1

    def full(self) -> bool:
        """True once n_steps transitions have been stored."""
        return self.ptr >= self.n

    def reset(self):
        """Rewind the write cursor; stale data is overwritten lazily."""
        self.ptr = 0

    def compute_returns(self, last_val: float, gamma: float, gae_lambda: float):
        """Generalized Advantage Estimation; sets ``self.ret = adv + val``."""
        advantages = torch.zeros_like(self.rew)
        running = 0.0
        bootstrap = last_val  # value of the state after the final stored step
        for t in range(self.n - 1, -1, -1):
            mask = 1 - self.done[t]  # zero out bootstrap across episode ends
            delta = self.rew[t] + gamma * bootstrap * mask - self.val[t]
            running = delta + gamma * gae_lambda * mask * running
            advantages[t] = running
            bootstrap = float(self.val[t])
        self.ret = advantages + self.val
# ── PPO Trainer ───────────────────────────────────────────────────────────────
class PPOTrainer:
    """
    Proximal Policy Optimization trainer.

    Ported from openenv's PPOTrainer — same hyperparameters and the same
    clipped-surrogate update; the only difference is that the environment
    is OverflowGymEnv instead of CarEnv3D.
    """

    def __init__(
        self,
        policy: BasePolicy,
        env: OverflowGymEnv,
        curriculum: Optional[CurriculumManager] = None,
        # PPO hyperparameters — same defaults as openenv
        lr: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: float = 0.2,
        clip_range_vf: float = 0.2,
        ent_coef: float = 0.02,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        n_steps: int = 2048,
        batch_size: int = 256,
        n_epochs: int = 10,
        save_dir: str = "checkpoints",
        log_interval: int = 10,
        device: str = "auto",
    ):
        self.policy = policy
        self.env = env
        # Fall back to a default curriculum when none is supplied.
        self.curriculum = curriculum or CurriculumManager()
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip = clip_range
        self.clip_vf = clip_range_vf
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.max_grad = max_grad_norm
        self.n_steps = n_steps
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.log_every = log_interval
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        # Resolve "auto" to the best available accelerator (cuda > mps > cpu).
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else \
                "mps" if torch.backends.mps.is_available() else "cpu"
        self.device = torch.device(device)
        self.policy.to(self.device)
        self.optimizer = optim.Adam(policy.parameters(), lr=lr, eps=1e-5)
        # Linear LR decay to 10% of the initial rate over 500 updates.
        self.scheduler = optim.lr_scheduler.LinearLR(
            self.optimizer, start_factor=1.0, end_factor=0.1, total_iters=500,
        )
        self.buffer = RolloutBuffer(n_steps, OBS_DIM, self.device)
        # Rolling 100-episode windows for mean reward / length logging.
        self.ep_rewards = deque(maxlen=100)
        self.ep_lengths = deque(maxlen=100)
        self.total_steps = 0
        self.n_updates = 0

    # ── Main training loop ─────────────────────────────────────────────────────
    def train(self, total_steps: int = 2_000_000) -> None:
        """Alternate rollout collection and PPO updates until ``total_steps``
        environment steps have been taken. Logs per-episode and per-update
        stats to stdout and checkpoints every 50 updates."""
        print(f"\n{'='*70}", flush=True)
        print(f" OpenENV PPO Training β policy={self.policy.__class__.__name__}", flush=True)
        print(f" total_steps={total_steps} n_steps={self.n_steps} lr={self.optimizer.param_groups[0]['lr']:.0e}", flush=True)
        print(f" gamma={self.gamma} gae_lambda={self.gae_lambda} clip={self.clip} ent_coef={self.ent_coef}", flush=True)
        print(f"{'='*70}\n", flush=True)
        obs, _ = self.env.reset()
        ep_reward = 0.0
        ep_steps = 0
        t0 = time.time()
        while self.total_steps < total_steps:
            self.buffer.reset()
            self.policy.eval()
            # ── Collect one rollout of n_steps transitions ──────────────────
            for _ in range(self.n_steps):
                # Curriculum step (returns [] for OverflowEnv — kept for API compat)
                self.curriculum.step(self.env._sim_time)
                obs_t = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
                with torch.no_grad():
                    act_mean, val = self.policy(obs_t.unsqueeze(0))
                act_mean = act_mean.squeeze(0)
                val = val.squeeze(0)
                # Fixed-std (0.3) Gaussian policy; sampled action clamped to [-1, 1].
                # NOTE(review): log-prob is taken of the clamped action under the
                # unclamped density — standard shortcut, slightly biased at bounds.
                dist = torch.distributions.Normal(act_mean, torch.ones_like(act_mean) * 0.3)
                action = dist.sample().clamp(-1, 1)
                logp = dist.log_prob(action).sum()
                next_obs, reward, term, trunc, info = self.env.step(action.cpu().numpy())
                self.buffer.add(
                    obs, action.cpu().numpy(), reward,
                    float(val), float(logp), float(term or trunc),
                )
                obs = next_obs
                ep_reward += reward
                ep_steps += 1
                self.total_steps += 1
                if term or trunc:
                    # Terminal bonus depends on surviving (no collision); the
                    # bonus is logged but not stored in the rollout buffer.
                    bonus = compute_episode_bonus(
                        total_steps=ep_steps,
                        survived=not info.get("collision", False),
                    )
                    ep_reward += bonus
                    self.ep_rewards.append(ep_reward)
                    self.ep_lengths.append(ep_steps)
                    # Return value (stage-advance flag) is not used here.
                    advanced = self.curriculum.record_episode_reward(ep_reward)
                    outcome = "CRASH" if info.get("collision") else ("GOAL" if info.get("goal_reached") else "timeout")
                    print(
                        f" ep#{len(self.ep_rewards):>4d} | "
                        f"steps={ep_steps:>3d} | "
                        f"reward={ep_reward:>8.2f} | "
                        f"outcome={outcome:<8} | "
                        f"stage={self.curriculum.current_stage} | "
                        f"total_steps={self.total_steps}",
                        flush=True,
                    )
                    obs, _ = self.env.reset()
                    ep_reward = 0.0
                    ep_steps = 0
            # ── PPO update ──────────────────────────────────────────────────
            # Bootstrap value for the (possibly unfinished) last transition.
            with torch.no_grad():
                obs_t = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
                _, last_val = self.policy(obs_t.unsqueeze(0))
            self.buffer.compute_returns(float(last_val), self.gamma, self.gae_lambda)
            self.policy.train()
            self._ppo_update()
            self.n_updates += 1
            self.scheduler.step()
            elapsed = time.time() - t0
            sps = self.total_steps / max(elapsed, 1)  # steps per second
            mean_r = np.mean(self.ep_rewards) if self.ep_rewards else 0.0
            mean_l = np.mean(self.ep_lengths) if self.ep_lengths else 0.0
            print(
                f"\n[PPO update #{self.n_updates}] "
                f"step={self.total_steps} "
                f"mean_reward={mean_r:.2f} "
                f"mean_ep_len={mean_l:.0f} "
                f"stage={self.curriculum.current_stage} "
                f"sps={sps:.0f}\n",
                flush=True,
            )
            # ── Checkpoint every 50 updates ─────────────────────────────────
            if self.n_updates % 50 == 0:
                ckpt = self.save_dir / f"policy_step{self.total_steps}_stage{self.curriculum.current_stage}.pt"
                torch.save({
                    "step": self.total_steps,
                    "stage": self.curriculum.current_stage,
                    "policy": self.policy.state_dict(),
                    "optim": self.optimizer.state_dict(),
                }, ckpt)
                print(f"[PPO] Saved checkpoint β {ckpt}")

    # ── PPO update pass — identical to openenv ─────────────────────────────────
    def _ppo_update(self) -> None:
        """One PPO optimization pass: n_epochs of minibatch updates with
        clipped policy loss, clipped value loss and an entropy bonus."""
        obs = self.buffer.obs
        acts = self.buffer.acts
        old_logp = self.buffer.logp
        # Advantages = returns - values, normalized across the rollout.
        adv = self.buffer.ret - self.buffer.val
        adv = (adv - adv.mean()) / (adv.std() + 1e-8)
        ret = self.buffer.ret
        old_val = self.buffer.val
        # NOTE(review): the permutation is drawn once and reused for every
        # epoch, so minibatch composition repeats across epochs (matches the
        # openenv original; most PPO implementations reshuffle per epoch).
        indices = torch.randperm(self.n_steps, device=self.device)
        for _ in range(self.n_epochs):
            for start in range(0, self.n_steps, self.batch_size):
                idx = indices[start: start + self.batch_size]
                act_mean, val = self.policy(obs[idx])
                val = val.squeeze(-1)
                # Same fixed-std Gaussian as used during collection.
                dist = torch.distributions.Normal(act_mean, torch.ones_like(act_mean) * 0.3)
                logp = dist.log_prob(acts[idx]).sum(dim=-1)
                entropy = dist.entropy().sum(dim=-1).mean()
                ratio = torch.exp(logp - old_logp[idx])
                # Clipped surrogate policy loss (pessimistic max of the two).
                pg_loss1 = -adv[idx] * ratio
                pg_loss2 = -adv[idx] * ratio.clamp(1 - self.clip, 1 + self.clip)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()
                # PPO2-style clipped value loss.
                val_unclipped = (val - ret[idx]) ** 2
                val_clipped = (
                    old_val[idx]
                    + (val - old_val[idx]).clamp(-self.clip_vf, self.clip_vf)
                    - ret[idx]
                ) ** 2
                vf_loss = 0.5 * torch.max(val_unclipped, val_clipped).mean()
                loss = pg_loss + self.vf_coef * vf_loss - self.ent_coef * entropy
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad)
                self.optimizer.step()
# ── Entry point ───────────────────────────────────────────────────────────────
def run_training(
    policy_type: str = "attention",
    total_steps: int = 2_000_000,
    start_stage: int = 1,
    checkpoint: Optional[str] = None,
    device: str = "auto",
) -> None:
    """Construct policy, environment and curriculum, then launch PPO training.

    policy_type: "attention" or "mlp"; raises KeyError on anything else.
    checkpoint:  optional path to a saved state dict to warm-start from.
    start_stage: curriculum stage to force if greater than 1.
    """
    # Imported lazily to avoid a circular import at module load time.
    from ..policies.ticket_attention_policy import TicketAttentionPolicy
    from ..policies.flat_mlp_policy import FlatMLPPolicy

    factories = {
        "attention": lambda: TicketAttentionPolicy(obs_dim=OBS_DIM),
        "mlp": lambda: FlatMLPPolicy(obs_dim=OBS_DIM),
    }
    policy = factories[policy_type]()

    if checkpoint:
        state = torch.load(checkpoint, map_location="cpu")
        policy.load_state_dict(state["policy"])
        print(f"[PPO] Loaded checkpoint from {checkpoint}")

    env = OverflowGymEnv()
    curriculum = CurriculumManager()
    if start_stage > 1:
        curriculum.force_stage(start_stage)

    trainer = PPOTrainer(
        policy=policy,
        env=env,
        curriculum=curriculum,
        device=device,
        n_steps=512,
    )
    trainer.train(total_steps=total_steps)
| if __name__ == "__main__": | |
| import argparse | |
| p = argparse.ArgumentParser() | |
| p.add_argument("--policy", default="attention", choices=["attention", "mlp"]) | |
| p.add_argument("--steps", default=2_000_000, type=int) | |
| p.add_argument("--stage", default=1, type=int) | |
| p.add_argument("--checkpoint", default=None) | |
| p.add_argument("--device", default="auto") | |
| args = p.parse_args() | |
| run_training(args.policy, args.steps, args.stage, args.checkpoint, args.device) | |