| | """ |
| | PPO Training Against a Running RANS Server |
| | ========================================== |
| | Trains a spacecraft navigation policy via the OpenEnv HTTP/WebSocket client, |
| | connecting to a RANS server running locally (uvicorn) or in Docker. |
| | |
| | This is the canonical OpenEnv training pattern: |
| | |
| | βββββββββββββββββββββββββββββββ HTTP/WS ββββββββββββββββββββββββ |
| | β ppo_train (this script) β βββββββββββΊ β uvicorn / Docker β |
| | β RemoteRANSGymnasiumEnv β β RANSEnvironment β |
| | β ActorCritic + PPO β βββββββββββ β spacecraft physics β |
| | βββββββββββββββββββββββββββββββ ββββββββββββββββββββββββ |
| | |
| | Start the server first: |
| | uvicorn rans_env.server.app:app --host 0.0.0.0 --port 8000 |
| | |
| | Then run this script: |
| | python examples/train_against_server.py --task GoToPosition |
| | python examples/train_against_server.py --task GoToPose --url http://localhost:8000 |
| | python examples/train_against_server.py --eval --checkpoint rans_ppo_remote_GoToPosition.pt |
| | |
| | Requirements: |
| | pip install torch gymnasium openenv-core |
| | """ |

from __future__ import annotations

import argparse
import os
import sys
import time
from typing import Dict, List, Optional, Tuple

import numpy as np

try:
    import gymnasium as gym
    from gymnasium import spaces
except ImportError:
    print("gymnasium is required: pip install gymnasium")
    sys.exit(1)

try:
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.distributions import Normal
except ImportError:
    print("torch is required: pip install torch")
    sys.exit(1)

# Make the repository root importable when running from examples/.
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))


# ---------------------------------------------------------------------------
# Remote Gymnasium wrapper over the OpenEnv client
# ---------------------------------------------------------------------------

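# The wrapper below builds on the raw OpenEnv client, which (as used in this
# script) follows a connect / reset / step / disconnect pattern. A minimal
# sketch, assuming a server is already listening on localhost:8000:
#
#     from rans_env import RANSEnv, SpacecraftAction
#
#     client = RANSEnv(base_url="http://localhost:8000")
#     client.connect()
#     result = client.reset()
#     n_thrusters = len(result.observation.thruster_masks)
#     result = client.step(SpacecraftAction(thrusters=[0.0] * n_thrusters))
#     client.disconnect()
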
class RemoteRANSGymnasiumEnv(gym.Env):
    """
    Gymnasium-compatible environment that connects to a running RANS server
    via the OpenEnv ``RANSEnv`` WebSocket/HTTP client.

    Identical observation and action spaces to ``RANSGymnasiumEnv``, but all
    physics runs inside the server process (or Docker container).
    """

    metadata = {"render_modes": []}

    def __init__(self, base_url: str = "http://localhost:8000") -> None:
        super().__init__()
        from rans_env import RANSEnv, SpacecraftAction

        self._SpacecraftAction = SpacecraftAction
        self._client = RANSEnv(base_url=base_url)
        self._client.connect()

        # Reset once up front so the space shapes can be derived from a real
        # observation returned by the server.
        result = self._client.reset()
        obs = result.observation
        flat = self._flatten(obs)

        n = len(obs.thruster_masks)
        self.action_space = spaces.Box(low=0.0, high=1.0, shape=(n,), dtype=np.float32)
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(flat.shape[0],), dtype=np.float32
        )
        self._last_flat = flat
        self._task = obs.task

    def reset(self, *, seed=None, options=None) -> Tuple[np.ndarray, Dict]:
        super().reset(seed=seed)
        result = self._client.reset()
        self._last_flat = self._flatten(result.observation)
        return self._last_flat, {"task": result.observation.task}

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        result = self._client.step(
            self._SpacecraftAction(thrusters=action.tolist())
        )
        flat = self._flatten(result.observation)
        reward = float(result.reward or 0.0)
        done = bool(result.done)
        self._last_flat = flat
        # The server reports a single done flag, so episode ends are always
        # surfaced as `terminated` and never as `truncated`.
        return flat, reward, done, False, result.observation.info or {}

    def close(self) -> None:
        self._client.disconnect()

    @staticmethod
    def _flatten(obs) -> np.ndarray:
        # Concatenate dynamic state, thruster geometry, availability masks,
        # and the body parameters into a single flat vector.
        return np.concatenate([
            np.array(obs.state_obs, dtype=np.float32),
            np.array(obs.thruster_transforms, dtype=np.float32).flatten(),
            np.array(obs.thruster_masks, dtype=np.float32),
            np.array([obs.mass, obs.inertia], dtype=np.float32),
        ])
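
# Quick smoke test for the wrapper (a sketch, assuming the server from the
# module docstring is already running on localhost:8000):
#
#     env = RemoteRANSGymnasiumEnv()
#     obs, info = env.reset()
#     obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
#     env.close()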


# ---------------------------------------------------------------------------
# Policy network and PPO machinery
# ---------------------------------------------------------------------------

def _mlp(in_dim: int, hidden: List[int], out_dim: int) -> nn.Sequential:
    """Tanh MLP with the given hidden widths."""
    layers: List[nn.Module] = []
    prev = in_dim
    for h in hidden:
        layers += [nn.Linear(prev, h), nn.Tanh()]
        prev = h
    layers.append(nn.Linear(prev, out_dim))
    return nn.Sequential(*layers)


class ActorCritic(nn.Module):
    def __init__(self, obs_dim: int, act_dim: int, hidden: Optional[List[int]] = None):
        super().__init__()
        hidden = hidden or [64, 64]
        self.actor_mean = _mlp(obs_dim, hidden, act_dim)
        self.log_std = nn.Parameter(torch.zeros(act_dim))
        self.critic = _mlp(obs_dim, hidden, 1)

    def forward(self, obs):
        # Sigmoid keeps the Gaussian mean inside the [0, 1] thruster range.
        mean = torch.sigmoid(self.actor_mean(obs))
        std = self.log_std.exp().expand_as(mean)
        return Normal(mean, std), self.critic(obs).squeeze(-1)

    @torch.no_grad()
    def act(self, obs):
        dist, value = self(obs)
        # Samples can fall outside [0, 1], so clamp back into the action box.
        action = dist.sample().clamp(0.0, 1.0)
        return action, dist.log_prob(action).sum(-1), value

    @torch.no_grad()
    def act_deterministic(self, obs):
        return torch.sigmoid(self.actor_mean(obs)).clamp(0.0, 1.0)
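
# Usage sketch for the policy (dimensions here are illustrative only):
#
#     policy = ActorCritic(obs_dim=8, act_dim=4)
#     action, log_prob, value = policy.act(torch.randn(8))
#     # action lies in [0, 1]^4; log_prob and value are 0-d tensors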


class RolloutBuffer:
    def __init__(self, n: int, obs_dim: int, act_dim: int, device: str):
        self.n, self.device = n, device
        self.obs = torch.zeros(n, obs_dim, device=device)
        self.actions = torch.zeros(n, act_dim, device=device)
        self.log_probs = torch.zeros(n, device=device)
        self.rewards = torch.zeros(n, device=device)
        self.values = torch.zeros(n, device=device)
        self.dones = torch.zeros(n, device=device)
        self.ptr = 0

    def add(self, obs, action, log_prob, reward, value, done):
        i = self.ptr
        self.obs[i], self.actions[i] = obs, action
        self.log_probs[i], self.rewards[i] = log_prob, reward
        self.values[i], self.dones[i] = value, done
        self.ptr += 1

    def reset(self):
        self.ptr = 0

    def compute_gae(self, last_value, gamma=0.99, lam=0.95):
        adv = torch.zeros_like(self.rewards)
        last_gae = 0.0
        for t in reversed(range(self.n)):
            nv = last_value if t == self.n - 1 else self.values[t + 1]
            # dones[t] flags that step t ended the episode, in which case the
            # next value belongs to the freshly reset episode and is masked.
            mask = 1.0 - self.dones[t]
            delta = self.rewards[t] + gamma * nv * mask - self.values[t]
            last_gae = delta + gamma * lam * mask * last_gae
            adv[t] = last_gae
        return adv, adv + self.values
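
# The recurrence implemented by compute_gae, with d_t the done flag stored at
# step t, V the critic, gamma the discount, and lam the GAE parameter:
#
#     delta_t = r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t)
#     A_t     = delta_t + gamma * lam * (1 - d_t) * A_{t+1}
#     R_t     = A_t + V(s_t)          # returns, used as the critic target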


def ppo_update(policy, optimizer, buf, adv, returns,
               clip=0.2, ent=0.01, vf=0.5, epochs=10, bs=64):
    n = buf.n
    stats = {"pi": 0.0, "vf": 0.0, "ent": 0.0}
    n_updates = 0
    for _ in range(epochs):
        # Shuffle once per epoch so each transition lands in exactly one
        # minibatch.
        perm = torch.randperm(n, device=buf.device)
        for s in range(0, n, bs):
            mb = perm[s:s + bs]
            # Per-minibatch advantage normalisation.
            a_b = (adv[mb] - adv[mb].mean()) / (adv[mb].std() + 1e-8)
            dist, val = policy(buf.obs[mb])
            lp = dist.log_prob(buf.actions[mb]).sum(-1)
            r = (lp - buf.log_probs[mb]).exp()
            pi_loss = -torch.min(r * a_b, r.clamp(1 - clip, 1 + clip) * a_b).mean()
            vf_loss = (val - returns[mb]).pow(2).mean()
            entropy = dist.entropy().sum(-1).mean()
            loss = pi_loss + vf * vf_loss - ent * entropy
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
            optimizer.step()
            stats["pi"] += pi_loss.item()
            stats["vf"] += vf_loss.item()
            stats["ent"] += entropy.item()
            n_updates += 1
    return {key: val / max(n_updates, 1) for key, val in stats.items()}
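
# ppo_update optimises the standard PPO-Clip surrogate. With probability
# ratio r = pi_new(a|s) / pi_old(a|s) and normalised advantage A:
#
#     L_pi = -E[ min(r * A, clip(r, 1 - eps, 1 + eps) * A) ]
#
# plus a value MSE term weighted by `vf` and an entropy bonus weighted by
# `ent` that discourages premature policy collapse.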


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------

def train(args: argparse.Namespace) -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("\nRANS PPO - Remote Training via OpenEnv Client")
    print(f"  server : {args.url}")
    print("  task   : connecting... (task is set by RANS_TASK on the server)")
    print(f"  device : {device}")
    print(f"  steps  : {args.timesteps:,}")
    print("=" * 60)

    env = RemoteRANSGymnasiumEnv(base_url=args.url)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    print(f"  task    : {env._task}")
    print(f"  obs_dim : {obs_dim}")
    print(f"  act_dim : {act_dim} (thrusters)")
    print()

    policy = ActorCritic(obs_dim, act_dim).to(device)
    optimizer = optim.Adam(policy.parameters(), lr=args.lr)

    if args.checkpoint and os.path.exists(args.checkpoint):
        ck = torch.load(args.checkpoint, map_location=device)
        policy.load_state_dict(ck["policy"])
        optimizer.load_state_dict(ck["optimizer"])
        print(f"  Loaded checkpoint: {args.checkpoint}")

    buf = RolloutBuffer(args.n_steps, obs_dim, act_dim, device)

    ep_rewards: List[float] = []
    ep_lengths: List[int] = []
    ep_reward, ep_len = 0.0, 0
    best_mean = -float("inf")

    obs_np, _ = env.reset()
    obs = torch.from_numpy(obs_np).float().to(device)
    total_steps = update_num = 0
    t0 = time.perf_counter()

    while total_steps < args.timesteps:
        # Collect one on-policy rollout of n_steps transitions.
        buf.reset()
        for _ in range(args.n_steps):
            action, log_prob, value = policy.act(obs)
            next_obs_np, reward, terminated, truncated, info = env.step(
                action.cpu().numpy()
            )
            done = terminated or truncated
            buf.add(obs, action, log_prob,
                    torch.tensor(reward, device=device),
                    value,
                    torch.tensor(float(done), device=device))
            ep_reward += reward
            ep_len += 1
            total_steps += 1
            if done:
                ep_rewards.append(ep_reward)
                ep_lengths.append(ep_len)
                ep_reward, ep_len = 0.0, 0
                next_obs_np, _ = env.reset()
            obs = torch.from_numpy(next_obs_np).float().to(device)

        # Bootstrap the value of the final observation, then compute GAE.
        with torch.no_grad():
            _, last_val = policy(obs)
        adv, returns = buf.compute_gae(last_val, args.gamma, args.lam)

        stats = ppo_update(policy, optimizer, buf, adv, returns,
                           clip=args.clip_eps, ent=args.entropy_coef,
                           epochs=args.n_epochs, bs=args.batch_size)
        update_num += 1

        if update_num % args.log_interval == 0:
            mean_rew = np.mean(ep_rewards[-100:]) if ep_rewards else float("nan")
            fps = total_steps / (time.perf_counter() - t0)
            print(f"  update {update_num:4d} | steps {total_steps:7,} | "
                  f"mean_rew {mean_rew:6.3f} | fps {fps:4.0f} | "
                  f"pi {stats['pi']:+.4f} vf {stats['vf']:.4f}")

        # Keep the checkpoint with the best 100-episode mean reward.
        if ep_rewards:
            mean_rew = np.mean(ep_rewards[-100:])
            if mean_rew > best_mean:
                best_mean = mean_rew
                ck_path = args.checkpoint or f"rans_ppo_remote_{env._task}.pt"
                torch.save({"policy": policy.state_dict(),
                            "optimizer": optimizer.state_dict(),
                            "best_mean_reward": best_mean,
                            "task": env._task}, ck_path)

    env.close()
    print(f"\nTraining complete. Best mean reward: {best_mean:.3f}")
    print(f"Checkpoint: {args.checkpoint or f'rans_ppo_remote_{env._task}.pt'}")


# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------

def evaluate(args: argparse.Namespace) -> None:
    env = RemoteRANSGymnasiumEnv(base_url=args.url)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    ck = torch.load(args.checkpoint, map_location="cpu")
    policy = ActorCritic(obs_dim, act_dim)
    policy.load_state_dict(ck["policy"])
    policy.eval()

    best = ck.get("best_mean_reward", float("nan"))
    print(f"\nEvaluating {args.checkpoint} against {args.url}")
    print(f"  task: {env._task} | best training reward: {best:.3f}")
    print("=" * 60)

    for ep in range(args.eval_episodes):
        obs_np, _ = env.reset()
        total_r, steps = 0.0, 0
        while True:
            # Deterministic actions (sigmoid mean) for evaluation.
            action = policy.act_deterministic(
                torch.from_numpy(obs_np).float()
            ).numpy()
            obs_np, r, term, trunc, info = env.step(action)
            total_r += r
            steps += 1
            if term or trunc:
                break
        print(f"  ep {ep + 1:2d} | steps {steps:4d} | reward {total_r:.3f} | "
              f"goal {info.get('goal_reached', '?')}")
    env.close()


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main():
    p = argparse.ArgumentParser(description="RANS PPO training via OpenEnv client")
    p.add_argument("--url", default="http://localhost:8000")
    p.add_argument("--timesteps", type=int, default=200_000)
    p.add_argument("--n-steps", type=int, default=1024)
    p.add_argument("--n-epochs", type=int, default=10)
    p.add_argument("--batch-size", type=int, default=64)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--gamma", type=float, default=0.99)
    p.add_argument("--lam", type=float, default=0.95)
    p.add_argument("--clip-eps", type=float, default=0.2)
    p.add_argument("--entropy-coef", type=float, default=0.01)
    p.add_argument("--log-interval", type=int, default=5)
    p.add_argument("--checkpoint", default=None)
    p.add_argument("--eval", action="store_true")
    p.add_argument("--eval-episodes", type=int, default=10)
    args = p.parse_args()

    if args.eval:
        if not args.checkpoint:
            print("--eval requires --checkpoint")
            sys.exit(1)
        evaluate(args)
    else:
        train(args)


if __name__ == "__main__":
    main()