dpang committed on
Commit
50c6a61
·
verified ·
1 Parent(s): 6de1b43

Update examples/ppo_train.py

Browse files
Files changed (1) hide show
  1. examples/ppo_train.py +406 -0
examples/ppo_train.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Space Robotics Lab, SnT, University of Luxembourg, SpaceR
3
+ # RANS: arXiv:2310.07393 — OpenEnv training examples
4
+
5
+ """
6
+ PPO Training for RANS
7
+ ======================
8
+ Trains a spacecraft navigation policy using Proximal Policy Optimization (PPO),
9
+ the same algorithm used in the original RANS paper (via rl-games).
10
+
11
+ This implementation runs the environment locally (no HTTP server) and uses
12
+ pure PyTorch — no extra RL library required.
13
+
14
+ Architecture
15
+ ------------
16
+ Policy network: MLP obs → [64, 64] → action_mean, log_std
17
+ Value network: MLP obs → [64, 64] → value
18
+ Algorithm: PPO with GAE advantage estimation
19
+
20
+ Usage
21
+ -----
22
+ # GoToPosition (default)
23
+ python examples/ppo_train.py
24
+
25
+ # GoToPose, more steps
26
+ python examples/ppo_train.py --task GoToPose --timesteps 500000
27
+
28
+ # Continue from checkpoint
29
+ python examples/ppo_train.py --checkpoint rans_ppo_GoToPosition.pt
30
+
31
+ # Use trained policy
32
+ python examples/ppo_train.py --eval --checkpoint rans_ppo_GoToPosition.pt
33
+
34
+ Requirements
35
+ ------------
36
+ pip install torch numpy
37
+ """
38
+
39
+ from __future__ import annotations
40
+
41
+ import argparse
42
+ import os
43
+ import sys
44
+ import time
45
+ from typing import List
46
+
47
+ import numpy as np
48
+ import torch
49
+ import torch.nn as nn
50
+ import torch.optim as optim
51
+ from torch.distributions import Normal
52
+
53
+ # ---------------------------------------------------------------------------
54
+ # Local imports (no server needed)
55
+ # ---------------------------------------------------------------------------
56
+ sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
57
+ from examples.gymnasium_wrapper import make_rans_env
58
+
59
+
60
+ # ---------------------------------------------------------------------------
61
+ # Neural network policy
62
+ # ---------------------------------------------------------------------------
63
+
64
+ def _mlp(in_dim: int, hidden: List[int], out_dim: int) -> nn.Sequential:
65
+ layers: List[nn.Module] = []
66
+ prev = in_dim
67
+ for h in hidden:
68
+ layers += [nn.Linear(prev, h), nn.Tanh()]
69
+ prev = h
70
+ layers.append(nn.Linear(prev, out_dim))
71
+ return nn.Sequential(*layers)
72
+
73
+
74
class ActorCritic(nn.Module):
    """
    Gaussian actor plus state-value critic for continuous thruster control.

    The actor and the critic are two separate MLPs built from the same hidden
    layout; no layers are shared between them. The actor output is passed
    through a sigmoid so the distribution mean lies in (0, 1), and the
    per-dimension log standard deviation is a learnable, state-independent
    parameter initialised to zero.
    """

    def __init__(self, obs_dim: int, act_dim: int, hidden: List[int] = None) -> None:
        super().__init__()
        layer_sizes = [64, 64] if hidden is None else hidden
        # Attribute names and creation order are kept as-is: they define the
        # state-dict keys and the order in which seeded initialisation runs.
        self.actor_mean = _mlp(obs_dim, layer_sizes, act_dim)
        self.log_std = nn.Parameter(torch.zeros(act_dim))
        self.critic = _mlp(obs_dim, layer_sizes, 1)

    def _squashed_mean(self, obs: torch.Tensor) -> torch.Tensor:
        """Return the actor output mapped into (0, 1) by a sigmoid."""
        return torch.sigmoid(self.actor_mean(obs))

    def forward(self, obs: torch.Tensor):
        """Return ``(Normal(mean, std), value)`` for the given observation(s)."""
        mu = self._squashed_mean(obs)
        sigma = torch.exp(self.log_std).expand_as(mu)
        state_value = self.critic(obs).squeeze(-1)
        return Normal(mu, sigma), state_value

    @torch.no_grad()
    def act(self, obs: torch.Tensor):
        """
        Sample an action for rollout collection.

        Returns the sample clamped to [0, 1], the summed log-probability of
        that clamped action under the current distribution, and the critic's
        value estimate.
        """
        policy_dist, state_value = self(obs)
        clipped = torch.clamp(policy_dist.sample(), min=0.0, max=1.0)
        return clipped, policy_dist.log_prob(clipped).sum(dim=-1), state_value

    @torch.no_grad()
    def act_deterministic(self, obs: torch.Tensor) -> torch.Tensor:
        """Return the distribution mean (no sampling), clamped to [0, 1]."""
        return self._squashed_mean(obs).clamp(0.0, 1.0)
109
+
110
+
111
+ # ---------------------------------------------------------------------------
112
+ # Rollout buffer
113
+ # ---------------------------------------------------------------------------
114
+
115
class RolloutBuffer:
    """
    Fixed-length storage for one on-policy rollout of a single environment.

    Slot ``t`` holds whatever ``add`` was given for step ``t``. The training
    loop in this file stores the observation seen *before* the step, the
    action with its log-probability and value estimate, the reward, and the
    ``done`` flag produced *by that same step*; the advantage computation
    below relies on that convention.
    """

    def __init__(self, n_steps: int, obs_dim: int, act_dim: int, device: str) -> None:
        self.n = n_steps
        self.device = device
        self.obs = torch.zeros(n_steps, obs_dim, device=device)
        self.actions = torch.zeros(n_steps, act_dim, device=device)
        self.log_probs = torch.zeros(n_steps, device=device)
        self.rewards = torch.zeros(n_steps, device=device)
        self.values = torch.zeros(n_steps, device=device)
        self.dones = torch.zeros(n_steps, device=device)
        self.ptr = 0

    def add(self, obs, action, log_prob, reward, value, done) -> None:
        """
        Write one transition into the next free slot.

        Raises:
            IndexError: if the buffer already holds ``n_steps`` transitions.
        """
        i = self.ptr
        if i >= self.n:
            raise IndexError(
                f"RolloutBuffer is full ({self.n} steps); call reset() first"
            )
        self.obs[i] = obs
        self.actions[i] = action
        self.log_probs[i] = log_prob
        self.rewards[i] = reward
        self.values[i] = value
        self.dones[i] = done
        self.ptr += 1

    def reset(self) -> None:
        """Rewind the write pointer; old contents are overwritten by later adds."""
        self.ptr = 0

    def compute_returns_and_advantages(
        self, last_value: torch.Tensor, gamma: float = 0.99, lam: float = 0.95
    ) -> tuple:
        """
        GAE-λ advantage estimation over the full buffer.

        Args:
            last_value: value estimate for the observation that follows the
                final stored step (used as bootstrap for slot ``n - 1``).
            gamma: discount factor.
            lam: GAE smoothing factor.

        Returns:
            ``(advantages, returns)``, both of shape ``(n_steps,)``, where
            ``returns = advantages + values``.
        """
        advantages = torch.zeros_like(self.rewards)
        last_gae = 0.0
        for t in reversed(range(self.n)):
            next_val = last_value if t == self.n - 1 else self.values[t + 1]
            # dones[t] says whether the episode ended *on* step t. If it did,
            # slot t+1 (or last_value) belongs to a fresh episode, so neither
            # its value nor its accumulated advantage may flow back into t.
            not_done = 1.0 - self.dones[t]
            delta = self.rewards[t] + gamma * next_val * not_done - self.values[t]
            last_gae = delta + gamma * lam * not_done * last_gae
            advantages[t] = last_gae
        returns = advantages + self.values
        return advantages, returns
156
+
157
+
158
+ # ---------------------------------------------------------------------------
159
+ # PPO update
160
+ # ---------------------------------------------------------------------------
161
+
162
def ppo_update(
    policy: "ActorCritic",
    optimizer: optim.Optimizer,
    buffer: "RolloutBuffer",
    advantages: torch.Tensor,
    returns: torch.Tensor,
    clip_eps: float = 0.2,
    entropy_coef: float = 0.01,
    value_coef: float = 0.5,
    n_epochs: int = 10,
    batch_size: int = 64,
) -> dict:
    """
    Run ``n_epochs`` passes of clipped-surrogate PPO over one collected rollout.

    Args:
        policy: module whose call returns ``(distribution, value)``.
        optimizer: optimizer over ``policy.parameters()``.
        buffer: rollout storage providing ``n``, ``device``, ``obs``,
            ``actions`` and ``log_probs``.
        advantages: per-step advantages, indexed like the buffer.
        returns: per-step value targets, indexed like the buffer.
        clip_eps: PPO ratio clipping range.
        entropy_coef: weight of the entropy bonus.
        value_coef: weight of the value loss.
        n_epochs: number of passes over the rollout.
        batch_size: minibatch size.

    Returns:
        Mean ``policy_loss``, ``value_loss`` and ``entropy`` over all gradient
        steps (all zeros when no step was taken).
    """
    n = buffer.n

    stats = {"policy_loss": 0.0, "value_loss": 0.0, "entropy": 0.0}
    n_updates = 0

    for _ in range(n_epochs):
        # Fresh permutation per epoch so minibatch composition differs
        # between passes instead of replaying the same partition.
        idx = torch.randperm(n, device=buffer.device)
        for start in range(0, n, batch_size):
            mb = idx[start: start + batch_size]
            obs_b = buffer.obs[mb]
            act_b = buffer.actions[mb]
            old_lp_b = buffer.log_probs[mb]
            adv_b = advantages[mb]
            ret_b = returns[mb]

            # Normalise advantages. A single-sample minibatch (n not divisible
            # by batch_size) has an undefined sample std (NaN), which would
            # poison every parameter, so it is left unnormalised.
            if adv_b.numel() > 1:
                adv_b = (adv_b - adv_b.mean()) / (adv_b.std() + 1e-8)

            dist, value = policy(obs_b)
            log_prob = dist.log_prob(act_b).sum(-1)
            entropy = dist.entropy().sum(-1).mean()

            ratio = (log_prob - old_lp_b).exp()
            surr1 = ratio * adv_b
            surr2 = ratio.clamp(1 - clip_eps, 1 + clip_eps) * adv_b
            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = (value - ret_b).pow(2).mean()
            loss = policy_loss + value_coef * value_loss - entropy_coef * entropy

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
            optimizer.step()

            stats["policy_loss"] += policy_loss.item()
            stats["value_loss"] += value_loss.item()
            stats["entropy"] += entropy.item()
            n_updates += 1

    if n_updates == 0:
        # n_epochs == 0 or an empty buffer: nothing to average.
        return stats
    return {k: v / n_updates for k, v in stats.items()}
215
+
216
+
217
+ # ---------------------------------------------------------------------------
218
+ # Training loop
219
+ # ---------------------------------------------------------------------------
220
+
221
def train(args: argparse.Namespace) -> None:
    """
    Train a PPO policy on the selected RANS task.

    Alternates between collecting ``args.n_steps`` transitions from a single
    local environment and running ``ppo_update`` on them, until at least
    ``args.timesteps`` environment steps were taken. Whenever the mean reward
    of the last 100 finished episodes improves, the policy and optimizer are
    saved to ``rans_ppo_<task>.pt`` in the working directory.

    Args:
        args: parsed command-line options (see ``main``).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("\nRANS PPO Training")
    print(f" task={args.task} device={device} steps={args.timesteps}")
    print("=" * 60)

    ckpt_path = f"rans_ppo_{args.task}.pt"
    best_mean_reward = -float("inf")
    checkpoint_written = False

    # Environment
    env = make_rans_env(task=args.task, max_episode_steps=args.episode_steps)
    try:
        obs_dim = env.observation_space.shape[0]
        act_dim = env.action_space.shape[0]
        print(f" obs_dim={obs_dim} act_dim={act_dim}")

        # Policy
        policy = ActorCritic(obs_dim, act_dim).to(device)
        optimizer = optim.Adam(policy.parameters(), lr=args.lr)

        if args.checkpoint:
            if os.path.exists(args.checkpoint):
                ckpt = torch.load(args.checkpoint, map_location=device)
                policy.load_state_dict(ckpt["policy"])
                optimizer.load_state_dict(ckpt["optimizer"])
                # Carry the previous best forward so the first (noisy) update
                # of a resumed run does not overwrite a better checkpoint.
                best_mean_reward = float(
                    ckpt.get("best_mean_reward", best_mean_reward)
                )
                print(f" Loaded checkpoint: {args.checkpoint}")
            else:
                print(f" WARNING: checkpoint not found, training from scratch: "
                      f"{args.checkpoint}")

        buffer = RolloutBuffer(args.n_steps, obs_dim, act_dim, device)

        # Tracking
        ep_rewards: List[float] = []
        ep_lengths: List[int] = []
        ep_reward = 0.0
        ep_length = 0

        obs_np, _ = env.reset()
        obs = torch.from_numpy(obs_np).float().to(device)
        total_steps = 0
        update_num = 0
        t0 = time.perf_counter()

        while total_steps < args.timesteps:
            # --- Collect rollout ---
            buffer.reset()
            for _ in range(args.n_steps):
                action, log_prob, value = policy.act(obs)
                action_np = action.cpu().numpy()

                next_obs_np, reward, terminated, truncated, info = env.step(action_np)
                # NOTE: time-limit truncation is treated like termination here.
                done = terminated or truncated

                buffer.add(obs, action, log_prob,
                           torch.tensor(reward, device=device),
                           value,
                           torch.tensor(float(done), device=device))

                ep_reward += reward
                ep_length += 1
                total_steps += 1

                if done:
                    ep_rewards.append(ep_reward)
                    ep_lengths.append(ep_length)
                    ep_reward = 0.0
                    ep_length = 0
                    next_obs_np, _ = env.reset()

                obs = torch.from_numpy(next_obs_np).float().to(device)

            # Bootstrap value for last observation
            with torch.no_grad():
                _, last_value = policy(obs)

            advantages, returns = buffer.compute_returns_and_advantages(
                last_value, gamma=args.gamma, lam=args.lam
            )

            # --- PPO update ---
            stats = ppo_update(
                policy, optimizer, buffer, advantages, returns,
                clip_eps=args.clip_eps, entropy_coef=args.entropy_coef,
                n_epochs=args.n_epochs, batch_size=args.batch_size,
            )
            update_num += 1

            # Plain Python float: a numpy scalar inside the checkpoint dict is
            # rejected by torch.load when it runs with weights_only=True.
            mean_rew = float(np.mean(ep_rewards[-100:])) if ep_rewards else float("nan")

            # --- Logging ---
            if update_num % args.log_interval == 0:
                mean_len = np.mean(ep_lengths[-100:]) if ep_lengths else float("nan")
                elapsed = time.perf_counter() - t0
                fps = total_steps / elapsed
                print(f" Update {update_num:5d} | steps={total_steps:7d} "
                      f"| mean_reward={mean_rew:6.3f} mean_len={mean_len:5.0f} "
                      f"| fps={fps:.0f} "
                      f"| pi_loss={stats['policy_loss']:.4f} "
                      f"| v_loss={stats['value_loss']:.4f}")

            # --- Checkpoint ---
            if ep_rewards and mean_rew > best_mean_reward:
                best_mean_reward = mean_rew
                torch.save({"policy": policy.state_dict(),
                            "optimizer": optimizer.state_dict(),
                            "total_steps": total_steps,
                            "best_mean_reward": best_mean_reward}, ckpt_path)
                checkpoint_written = True
    finally:
        # Release the environment even if training is interrupted or fails.
        env.close()

    print(f"\nTraining complete. Best mean reward: {best_mean_reward:.3f}")
    if checkpoint_written:
        print(f"Checkpoint saved to: {ckpt_path}")
    else:
        print("No checkpoint written (mean reward never improved on the starting best).")
328
+
329
+
330
+ # ---------------------------------------------------------------------------
331
+ # Evaluation loop
332
+ # ---------------------------------------------------------------------------
333
+
334
def evaluate(args: argparse.Namespace) -> None:
    """
    Roll out a saved policy deterministically and print per-episode results.

    Loads ``args.checkpoint`` on the CPU, then runs ``args.eval_episodes``
    episodes using the distribution mean as the action, printing step count,
    total reward and the ``goal_reached`` entry of the final ``info`` dict.

    Args:
        args: parsed command-line options (see ``main``); ``checkpoint`` must
            point to a file containing a ``"policy"`` state dict.
    """
    device = "cpu"
    env = make_rans_env(task=args.task, max_episode_steps=args.episode_steps)
    try:
        obs_dim = env.observation_space.shape[0]
        act_dim = env.action_space.shape[0]

        policy = ActorCritic(obs_dim, act_dim).to(device)
        ckpt = torch.load(args.checkpoint, map_location=device)
        policy.load_state_dict(ckpt["policy"])
        policy.eval()

        # Format the score only when it exists: applying ':.3f' to the '?'
        # placeholder would raise ValueError for checkpoints without the key.
        best = ckpt.get("best_mean_reward")
        best_text = "?" if best is None else f"{float(best):.3f}"
        print(f"\nEvaluating {args.checkpoint} task={args.task}")
        print(f" Best training reward: {best_text}")
        print("=" * 60)

        for ep in range(args.eval_episodes):
            obs_np, _ = env.reset()
            total_reward = 0.0
            steps = 0
            while True:
                obs = torch.from_numpy(obs_np).float().to(device)
                action = policy.act_deterministic(obs).numpy()
                obs_np, reward, terminated, truncated, info = env.step(action)
                total_reward += reward
                steps += 1
                if terminated or truncated:
                    break
            print(f" Episode {ep + 1:2d} | steps={steps:4d} "
                  f"| reward={total_reward:.3f} "
                  f"| goal={info.get('goal_reached', '?')}")
    finally:
        # Release the environment even if loading or a rollout fails.
        env.close()
365
+
366
+
367
+ # ---------------------------------------------------------------------------
368
+ # Entry point
369
+ # ---------------------------------------------------------------------------
370
+
371
def main(argv: "List[str] | None" = None) -> None:
    """
    Command-line entry point: parse options, then train or evaluate.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]``; ``None``
            (the default) keeps the normal command-line behaviour.

    Exits with status 1 when ``--eval`` is given without ``--checkpoint``.
    """
    parser = argparse.ArgumentParser(description="RANS PPO training")
    parser.add_argument("--task", default="GoToPosition",
                        choices=["GoToPosition", "GoToPose",
                                 "TrackLinearVelocity", "TrackLinearAngularVelocity"])
    parser.add_argument("--timesteps", type=int, default=300_000)
    parser.add_argument("--episode-steps", type=int, default=500)
    parser.add_argument("--n-steps", type=int, default=2048,
                        help="Rollout length before each PPO update")
    parser.add_argument("--n-epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--lam", type=float, default=0.95)
    parser.add_argument("--clip-eps", type=float, default=0.2)
    parser.add_argument("--entropy-coef", type=float, default=0.01)
    parser.add_argument("--log-interval", type=int, default=10,
                        help="Log every N PPO updates")
    # The path is only ever read; training always writes rans_ppo_<task>.pt.
    parser.add_argument("--checkpoint", default=None,
                        help="Path to a .pt checkpoint to load "
                             "(resume training, or the policy used by --eval)")
    parser.add_argument("--eval", action="store_true",
                        help="Run evaluation only (requires --checkpoint)")
    parser.add_argument("--eval-episodes", type=int, default=10)
    args = parser.parse_args(argv)

    if not args.eval:
        train(args)
        return

    if not args.checkpoint:
        # Usage errors belong on stderr; the exit status stays 1 as before.
        print("--eval requires --checkpoint PATH", file=sys.stderr)
        sys.exit(1)
    evaluate(args)
403
+
404
+
405
+ if __name__ == "__main__":
406
+ main()