dpang committed on
Commit
7078c69
·
verified ·
1 Parent(s): 7243a06

Add server/rans_environment.py

Browse files
Files changed (1) hide show
  1. server/rans_environment.py +228 -0
server/rans_environment.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Space Robotics Lab, SnT, University of Luxembourg, SpaceR
# RANS: arXiv:2310.07393 — OpenEnv-compatible implementation

"""
RANSEnvironment
===============
OpenEnv ``Environment`` subclass that wraps the 2-D spacecraft simulator and
the RANS task suite.

Supported tasks (set via RANS_TASK env-var or constructor argument):
• GoToPosition — reach a target (x, y)
• GoToPose — reach a target (x, y, θ)
• TrackLinearVelocity — maintain (vx_t, vy_t)
• TrackLinearAngularVelocity — maintain (vx_t, vy_t, ω_t)

The environment follows the RANS paper (arXiv:2310.07393) physics and reward
formulations, adapted to run in CPU-only Docker containers without Isaac Gym.
"""

from __future__ import annotations

import math
import os
import uuid
from typing import Any, Dict, Optional

import numpy as np

try:
    from openenv.core.env_server.interfaces import Action, Environment, Observation
except ImportError:
    # Fallback when OpenEnv is not installed: alias all three interface names
    # to pydantic's BaseModel so this module still imports (e.g. for tooling
    # or tests that never start the server). The aliases only need to be
    # valid base classes / type markers here, not functional equivalents.
    from pydantic import BaseModel as Action  # type: ignore[assignment]
    from pydantic import BaseModel as Environment  # type: ignore[assignment]
    from pydantic import BaseModel as Observation  # type: ignore[assignment]

try:
    # Installed package import
    from rans_env.models import SpacecraftAction, SpacecraftObservation, SpacecraftState
    from rans_env.server.spacecraft_physics import Spacecraft2D, SpacecraftConfig
    from rans_env.server.tasks import TASK_REGISTRY
except ImportError:
    # Development / test import (package not yet installed, RANS dir on sys.path)
    # NOTE(review): prepends the repo root (parent of this file's directory) to
    # sys.path so the flat `models` / `server.*` modules resolve.
    import sys, os as _os
    sys.path.insert(0, _os.path.dirname(_os.path.dirname(__file__)))
    from models import SpacecraftAction, SpacecraftObservation, SpacecraftState  # type: ignore[no-redef]
    from server.spacecraft_physics import Spacecraft2D, SpacecraftConfig  # type: ignore[no-redef]
    from server.tasks import TASK_REGISTRY  # type: ignore[no-redef]
50
class RANSEnvironment(Environment):
    """
    RANS spacecraft navigation environment for OpenEnv.

    Wraps the 2-D thruster-actuated spacecraft simulator (``Spacecraft2D``)
    together with one task from ``TASK_REGISTRY`` and exposes the OpenEnv
    ``reset`` / ``step`` / ``state`` interface.

    References
    ----------
    El-Hariry, Richard, Olivares-Mendez (2023).
    "RANS: Highly-Parallelised Simulator for Reinforcement Learning based
    Autonomous Navigating Spacecrafts." arXiv:2310.07393.
    """

    def __init__(
        self,
        task: str = "GoToPosition",
        spacecraft_config: Optional[SpacecraftConfig] = None,
        task_config: Optional[Dict[str, Any]] = None,
        max_episode_steps: int = 500,
        initial_pos_range: float = 2.0,
        initial_vel_range: float = 0.1,
        seed: Optional[int] = None,
    ) -> None:
        """
        Parameters
        ----------
        task:
            One of TASK_REGISTRY keys. Overridden by RANS_TASK env-var.
        spacecraft_config:
            Physical platform configuration. Uses 8-thruster MFP2D default.
        task_config:
            Dict of task hyper-parameters forwarded to the task constructor.
        max_episode_steps:
            Hard step limit per episode (overrides RANS_MAX_STEPS env-var).
        initial_pos_range:
            Half-width of the uniform distribution for random initial position.
        initial_vel_range:
            Half-width for random initial velocities.
        seed:
            Optional RNG seed for reproducible initial states. Overridden by
            the RANS_SEED env-var; ``None`` gives nondeterministic seeding.

        Raises
        ------
        ValueError
            If ``task`` is not a key of ``TASK_REGISTRY``.
        """
        # Allow env-var overrides (useful for Docker deployments)
        task = os.environ.get("RANS_TASK", task)
        max_episode_steps = int(
            os.environ.get("RANS_MAX_STEPS", str(max_episode_steps))
        )
        seed_override = os.environ.get("RANS_SEED")
        if seed_override is not None:
            seed = int(seed_override)

        if task not in TASK_REGISTRY:
            raise ValueError(
                f"Unknown task '{task}'. "
                f"Available: {sorted(TASK_REGISTRY.keys())}"
            )

        self._task_name = task
        self._max_episode_steps = max_episode_steps
        self._initial_pos_range = initial_pos_range
        self._initial_vel_range = initial_vel_range

        # Dedicated generator (instead of the global numpy RNG state) so
        # initial conditions are reproducible when a seed is supplied.
        self._rng = np.random.default_rng(seed)

        # Physics simulator
        self._spacecraft = Spacecraft2D(
            spacecraft_config or SpacecraftConfig.default_8_thruster()
        )

        # Task
        self._task = TASK_REGISTRY[task](task_config or {})

        # Episode bookkeeping
        self._step_count: int = 0
        self._total_reward: float = 0.0
        self._ep_state = SpacecraftState(task=self._task_name)

    # ------------------------------------------------------------------
    # OpenEnv Environment interface
    # ------------------------------------------------------------------

    def reset(self) -> Observation:
        """Start a new episode with a randomised initial spacecraft state."""
        init_state = self._sample_initial_state()
        self._spacecraft.reset(init_state)

        # Let the task pick a fresh goal relative to the new physical state.
        task_info = self._task.reset(self._spacecraft.state)

        self._step_count = 0
        self._total_reward = 0.0
        self._ep_state = SpacecraftState(
            episode_id=str(uuid.uuid4()),
            step_count=0,
            task=self._task_name,
            **self._physical_state_dict(),
        )

        return self._make_observation(reward=0.0, done=False, info=task_info)

    def step(self, action: Action) -> Observation:
        """
        Apply thruster activations and advance the simulation by one step.

        Parameters
        ----------
        action:
            A ``SpacecraftAction`` carrying a ``thrusters`` activation vector.
            The vector is flattened, then truncated or zero-padded to the
            platform's thruster count, so a scalar or wrong-length input
            cannot crash the physics step.

        Raises
        ------
        ValueError
            If ``action`` has no ``thrusters`` field.
        """
        if not hasattr(action, "thrusters"):
            raise ValueError(
                f"Expected SpacecraftAction (with 'thrusters' field), "
                f"received {type(action).__name__}."
            )

        # Validate / reshape activation vector. atleast_1d + ravel tolerates
        # a bare scalar (0-d array, on which len() would raise TypeError)
        # and flattens any accidentally-nested list.
        activations = np.atleast_1d(
            np.asarray(action.thrusters, dtype=np.float64)
        ).ravel()
        n = self._spacecraft.n_thrusters
        if activations.size != n:
            padded = np.zeros(n, dtype=np.float64)
            padded[: min(activations.size, n)] = activations[:n]
            activations = padded

        # Advance physics
        self._spacecraft.step(activations)
        self._step_count += 1

        # Compute task reward
        reward, goal_reached, info = self._task.compute_reward(
            self._spacecraft.state
        )
        self._total_reward += reward

        # Episode ends on goal achievement or on the hard step limit.
        done = goal_reached or (self._step_count >= self._max_episode_steps)

        # Rebuild persistent state (Pydantic models are immutable by default)
        self._ep_state = SpacecraftState(
            episode_id=self._ep_state.episode_id,
            step_count=self._step_count,
            task=self._task_name,
            total_reward=self._total_reward,
            goal_reached=goal_reached,
            **self._physical_state_dict(),
        )

        return self._make_observation(reward=reward, done=done, info=info)

    @property
    def state(self) -> SpacecraftState:
        """Snapshot of the current episode (id, step count, pose, reward)."""
        return self._ep_state

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _sample_initial_state(self) -> np.ndarray:
        """Uniform random initial state (small velocities, random pose)."""
        r = self._initial_pos_range
        v = self._initial_vel_range
        return np.array(
            [
                self._rng.uniform(-r, r),              # x
                self._rng.uniform(-r, r),              # y
                self._rng.uniform(-math.pi, math.pi),  # θ
                self._rng.uniform(-v, v),              # vx
                self._rng.uniform(-v, v),              # vy
                self._rng.uniform(-v, v),              # ω
            ],
            dtype=np.float64,
        )

    def _physical_state_dict(self) -> Dict[str, float]:
        """Map the simulator state vector [x, y, θ, vx, vy, ω] onto the
        SpacecraftState field names."""
        s = self._spacecraft.state
        return {
            "x": float(s[0]),
            "y": float(s[1]),
            "heading_rad": float(s[2]),
            "vx": float(s[3]),
            "vy": float(s[4]),
            "angular_velocity_rads": float(s[5]),
        }

    def _make_observation(
        self, reward: float, done: bool, info: Dict[str, Any]
    ) -> SpacecraftObservation:
        """Assemble the task-specific observation plus platform metadata."""
        task_obs = self._task.get_observation(self._spacecraft.state)
        return SpacecraftObservation(
            state_obs=task_obs.tolist(),
            thruster_transforms=self._spacecraft.get_thruster_transforms().tolist(),
            thruster_masks=self._spacecraft.get_thruster_masks().tolist(),
            mass=self._spacecraft.config.mass,
            inertia=self._spacecraft.config.inertia,
            task=self._task_name,
            reward=float(reward),
            done=bool(done),
            info={**info, "step": self._step_count},
        )