dpang commited on
Commit
87bf301
·
verified ·
1 Parent(s): 50c6a61

Update examples/gymnasium_wrapper.py

Browse files
Files changed (1) hide show
  1. examples/gymnasium_wrapper.py +217 -0
examples/gymnasium_wrapper.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Space Robotics Lab, SnT, University of Luxembourg, SpaceR
3
+ # RANS: arXiv:2310.07393 — OpenEnv training examples
4
+
5
+ """
6
+ Gymnasium Wrapper for RANS
7
+ ===========================
8
+ Wraps ``RANSEnvironment`` in a standard ``gymnasium.Env`` interface so any
9
+ Gymnasium-compatible RL library can be used for training:
10
+
11
+ • Stable-Baselines3 (PPO, SAC, TD3, …)
12
+ • CleanRL
13
+ • RLlib
14
+ • TorchRL
15
+
16
+ The wrapper runs the environment **locally** (in-process) — no HTTP server
17
+ needed. For server-based training, replace ``RANSEnvironment()`` with the
18
+ ``RANSEnv`` WebSocket client (see remote_train_sb3.py).
19
+
20
+ Usage
21
+ -----
22
+ # Standalone check
23
+ python examples/gymnasium_wrapper.py
24
+
25
+ # Stable-Baselines3 PPO (requires: pip install stable-baselines3)
26
+ from examples.gymnasium_wrapper import make_rans_env
27
+ from stable_baselines3 import PPO
28
+
29
+ env = make_rans_env(task="GoToPosition")
30
+ model = PPO("MlpPolicy", env, verbose=1)
31
+ model.learn(total_timesteps=200_000)
32
+ model.save("rans_ppo_go_to_position")
33
+ """
34
+
35
+ from __future__ import annotations
36
+
37
+ import sys
38
+ from typing import Any, Dict, Optional, Tuple
39
+
40
+ import numpy as np
41
+
42
+ try:
43
+ import gymnasium as gym
44
+ from gymnasium import spaces
45
+ except ImportError:
46
+ print("gymnasium is required: pip install gymnasium")
47
+ sys.exit(1)
48
+
49
+ # Local import (no server needed)
50
+ sys.path.insert(0, __file__.replace("examples/gymnasium_wrapper.py", ""))
51
+ from server.rans_environment import RANSEnvironment
52
+ from server.spacecraft_physics import SpacecraftConfig
53
+ from rans_env.models import SpacecraftAction
54
+
55
+
56
class RANSGymnasiumEnv(gym.Env):
    """In-process ``gymnasium.Env`` adapter around ``RANSEnvironment``.

    Observation space:
        A single flat ``Box`` holding ``[state_obs, thruster_transforms
        (flattened), thruster_masks, mass, inertia]``.

    Action space:
        ``Box([0, 1]^n_thrusters)`` — continuous thruster activations.

    Parameters
    ----------
    task:
        RANS task name.
    spacecraft_config:
        Physical platform configuration.
    task_config:
        Dict of task hyper-parameters.
    max_episode_steps:
        Hard step limit per episode.
    """

    metadata = {"render_modes": []}

    def __init__(
        self,
        task: str = "GoToPosition",
        spacecraft_config: Optional[SpacecraftConfig] = None,
        task_config: Optional[Dict[str, Any]] = None,
        max_episode_steps: int = 500,
    ) -> None:
        super().__init__()
        self._env = RANSEnvironment(
            task=task,
            spacecraft_config=spacecraft_config,
            task_config=task_config,
            max_episode_steps=max_episode_steps,
        )

        # One activation level per thruster, each clamped to [0, 1].
        # NOTE(review): reaches into the private `_spacecraft` attribute of
        # RANSEnvironment — confirm no public accessor exists.
        n_thrusters = self._env._spacecraft.n_thrusters
        self.action_space = spaces.Box(
            low=0.0, high=1.0, shape=(n_thrusters,), dtype=np.float32
        )

        # The observation length is task-dependent, so probe it with a
        # throw-away reset rather than hard-coding per-task dimensions.
        first_obs = self._flatten(self._env.reset())
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(first_obs.shape[0],),
            dtype=np.float32,
        )

        self._last_obs = first_obs

    # ------------------------------------------------------------------
    # Gymnasium interface
    # ------------------------------------------------------------------

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[Dict] = None,
    ) -> Tuple[np.ndarray, Dict]:
        """Begin a new episode and return ``(observation, info)``."""
        # Seeds gymnasium's own RNG bookkeeping; the wrapped environment
        # does not expose a seeding hook here.
        super().reset(seed=seed)
        raw = self._env.reset()
        self._last_obs = self._flatten(raw)
        return self._last_obs, {"task": raw.task}

    def step(
        self, action: np.ndarray
    ) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """Apply thruster activations and advance one simulation step."""
        outcome = self._env.step(
            SpacecraftAction(thrusters=action.tolist())
        )
        self._last_obs = self._flatten(outcome)
        # The wrapped environment folds its step limit into `done`, so a
        # separate `truncated` signal is never raised here.
        return (
            self._last_obs,
            float(outcome.reward or 0.0),
            bool(outcome.done),
            False,
            outcome.info or {},
        )

    def render(self) -> None:
        """No-op: the environment is headless; use step info for diagnostics."""

    def close(self) -> None:
        """No-op: nothing to release for the in-process environment."""

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _flatten(obs) -> np.ndarray:
        """Concatenate all observation fields into one 1-D float32 vector."""
        state = np.asarray(obs.state_obs, dtype=np.float32)
        transforms = np.asarray(obs.thruster_transforms, dtype=np.float32).ravel()
        masks = np.asarray(obs.thruster_masks, dtype=np.float32)
        inertial = np.array([obs.mass, obs.inertia], dtype=np.float32)
        return np.concatenate((state, transforms, masks, inertial))
162
+
163
+
164
def make_rans_env(
    task: str = "GoToPosition",
    task_config: Optional[Dict[str, Any]] = None,
    max_episode_steps: int = 500,
    spacecraft_config: Optional[SpacecraftConfig] = None,
) -> RANSGymnasiumEnv:
    """
    Factory that returns a ``gymnasium.Env``-compatible RANS environment.

    Parameters
    ----------
    task:
        RANS task name.
    task_config:
        Dict of task hyper-parameters.
    max_episode_steps:
        Hard step limit per episode.
    spacecraft_config:
        Physical platform configuration (``None`` keeps the wrapper's
        default platform).  Added for parity with ``RANSGymnasiumEnv``,
        which already accepts it; existing calls are unaffected.

    Example::

        from examples.gymnasium_wrapper import make_rans_env
        from stable_baselines3 import PPO

        env = make_rans_env(task="GoToPose")
        model = PPO("MlpPolicy", env, verbose=1, n_steps=2048)
        model.learn(total_timesteps=500_000)
    """
    return RANSGymnasiumEnv(
        task=task,
        spacecraft_config=spacecraft_config,
        task_config=task_config,
        max_episode_steps=max_episode_steps,
    )
183
+
184
+
185
+ # ---------------------------------------------------------------------------
186
+ # Standalone smoke test
187
+ # ---------------------------------------------------------------------------
188
+
189
def _smoke_test() -> None:
    """Run a short random-policy rollout on every task as a sanity check."""
    print("RANS Gymnasium Wrapper — smoke test")
    print("=" * 50)

    tasks = (
        "GoToPosition",
        "GoToPose",
        "TrackLinearVelocity",
        "TrackLinearAngularVelocity",
    )
    for task in tasks:
        env = make_rans_env(task=task, max_episode_steps=100)
        obs, info = env.reset()
        print(f"\nTask: {task}")
        print(f" obs shape: {obs.shape}")
        print(f" action shape: {env.action_space.shape}")

        # Random policy: we only care that stepping works, not performance.
        episode_return = 0.0
        for _ in range(100):
            obs, reward, terminated, truncated, info = env.step(
                env.action_space.sample()
            )
            episode_return += reward
            if terminated or truncated:
                break

        print(f" total_reward: {episode_return:.3f}")
        print(f" goal_reached: {info.get('goal_reached', False)}")
        env.close()

    print("\nAll tasks OK.")


if __name__ == "__main__":
    _smoke_test()