dpang commited on
Commit
1284aff
·
verified ·
1 Parent(s): 6cb4f35

Update client.py

Browse files
Files changed (1) hide show
  1. client.py +81 -33
client.py CHANGED
@@ -4,6 +4,16 @@
4
  """
5
  RANSEnv — OpenEnv client for the RANS spacecraft navigation environment.
6
 
 
 
 
 
 
 
 
 
 
 
7
  Usage (async)::
8
 
9
  import asyncio
@@ -11,44 +21,32 @@ Usage (async)::
11
 
12
  async def main():
13
  async with RANSEnv(base_url="http://localhost:8000") as env:
14
- obs = await env.reset()
15
- print("Task:", obs.task)
16
- print("Observation:", obs.state_obs)
17
-
18
- # Zero-thrust step
19
- n = len(obs.thruster_masks)
20
- result = await env.step(SpacecraftAction(thrusters=[0.0] * n))
21
- print("Reward:", result.reward)
22
- print("Done:", result.done)
23
 
24
  asyncio.run(main())
25
 
26
- Usage (synchronous)::
27
-
28
- from rans_env import RANSEnv, SpacecraftAction
29
-
30
- with RANSEnv(base_url="http://localhost:8000").sync() as env:
31
- obs = env.reset()
32
- result = env.step(SpacecraftAction(thrusters=[1, 0, 0, 0, 0, 0, 0, 0]))
33
-
34
  Docker::
35
 
36
- env = RANSEnv.from_docker_image(
37
- "rans-env:latest",
38
- env={"RANS_TASK": "GoToPose"},
39
- )
40
 
41
  HuggingFace Spaces::
42
 
43
- env = RANSEnv.from_env("openenv/rans-env")
44
  """
45
 
46
  from __future__ import annotations
47
 
 
 
48
  try:
49
- from openenv.core.env_client import EnvClient
 
50
  except ImportError:
51
  EnvClient = object # type: ignore[assignment,misc]
 
 
52
 
53
  from rans_env.models import SpacecraftAction, SpacecraftObservation, SpacecraftState
54
 
@@ -57,19 +55,69 @@ class RANSEnv(EnvClient):
57
  """
58
  Client for the RANS spacecraft navigation OpenEnv environment.
59
 
60
- All functionality (``reset``, ``step``, ``state``, ``sync``,
61
- ``from_docker_image``, ``from_env``) is provided by the ``EnvClient``
62
- base class from openenv-core.
63
-
64
- The client is typed: it sends ``SpacecraftAction`` objects and receives
65
- ``SpacecraftObservation`` objects.
66
 
67
  Parameters
68
  ----------
69
  base_url:
70
- Base URL of the running RANS server, e.g. ``"http://localhost:8000"``.
 
71
  """
72
 
73
- action_type = SpacecraftAction
74
- observation_type = SpacecraftObservation
75
- state_type = SpacecraftState
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  """
5
  RANSEnv — OpenEnv client for the RANS spacecraft navigation environment.
6
 
7
+ Usage (synchronous)::
8
+
9
+ from rans_env import RANSEnv, SpacecraftAction
10
+
11
+ with RANSEnv(base_url="http://localhost:8000").sync() as env:
12
+ result = env.reset()
13
+ n = len(result.observation.thruster_masks)
14
+ result = env.step(SpacecraftAction(thrusters=[1, 0, 0, 0, 0, 0, 0, 0]))
15
+ print(result.reward, result.done)
16
+
17
  Usage (async)::
18
 
19
  import asyncio
 
21
 
22
  async def main():
23
  async with RANSEnv(base_url="http://localhost:8000") as env:
24
+ result = await env.reset()
25
+ result = await env.step(SpacecraftAction(thrusters=[0.0] * 8))
26
+ print(result.reward, result.done)
 
 
 
 
 
 
27
 
28
  asyncio.run(main())
29
 
 
 
 
 
 
 
 
 
30
  Docker::
31
 
32
+ env = RANSEnv.from_docker_image("rans-env:latest", env={"RANS_TASK": "GoToPose"})
 
 
 
33
 
34
  HuggingFace Spaces::
35
 
36
+ env = RANSEnv.from_env("dpang/rans-env")
37
  """
38
 
39
  from __future__ import annotations
40
 
41
+ from typing import Any, Dict
42
+
43
  try:
44
+ from openenv.core.env_client import EnvClient, StepResult
45
+ _OPENENV_AVAILABLE = True
46
  except ImportError:
47
  EnvClient = object # type: ignore[assignment,misc]
48
+ StepResult = None # type: ignore[assignment,misc]
49
+ _OPENENV_AVAILABLE = False
50
 
51
  from rans_env.models import SpacecraftAction, SpacecraftObservation, SpacecraftState
52
 
 
55
  """
56
  Client for the RANS spacecraft navigation OpenEnv environment.
57
 
58
+ Implements the three ``EnvClient`` abstract methods that handle
59
+ JSON serialisation of actions and deserialisation of observations.
 
 
 
 
60
 
61
  Parameters
62
  ----------
63
  base_url:
64
+ HTTP/WebSocket URL of the running server,
65
+ e.g. ``"http://localhost:8000"`` or ``"ws://localhost:8000"``.
66
  """
67
 
68
+ # ------------------------------------------------------------------
69
+ # EnvClient abstract method implementations
70
+ # ------------------------------------------------------------------
71
+
72
+ def _step_payload(self, action: SpacecraftAction) -> Dict[str, Any]:
73
+ """Serialise SpacecraftAction → JSON dict for the WebSocket message."""
74
+ return {"thrusters": action.thrusters}
75
+
76
+ def _parse_result(self, payload: Dict[str, Any]) -> "StepResult[SpacecraftObservation]":
77
+ """
78
+ Deserialise the server response into a typed StepResult.
79
+
80
+ The server sends::
81
+
82
+ {
83
+ "observation": { "state_obs": [...], "thruster_transforms": [...],
84
+ "thruster_masks": [...], "mass": 10.0, "inertia": 0.5,
85
+ "task": "GoToPosition", "reward": 0.42, "done": false,
86
+ "info": {...} },
87
+ "reward": 0.42,
88
+ "done": false
89
+ }
90
+ """
91
+ obs_dict = payload.get("observation", payload)
92
+ observation = SpacecraftObservation(
93
+ state_obs=obs_dict.get("state_obs", []),
94
+ thruster_transforms=obs_dict.get("thruster_transforms", []),
95
+ thruster_masks=obs_dict.get("thruster_masks", []),
96
+ mass=obs_dict.get("mass", 10.0),
97
+ inertia=obs_dict.get("inertia", 0.5),
98
+ task=obs_dict.get("task", "GoToPosition"),
99
+ reward=float(obs_dict.get("reward") or 0.0),
100
+ done=bool(obs_dict.get("done", False)),
101
+ info=obs_dict.get("info", {}),
102
+ )
103
+ return StepResult(
104
+ observation=observation,
105
+ reward=payload.get("reward") or observation.reward,
106
+ done=payload.get("done", observation.done),
107
+ )
108
+
109
+ def _parse_state(self, payload: Dict[str, Any]) -> SpacecraftState:
110
+ """Deserialise the /state response into a SpacecraftState."""
111
+ return SpacecraftState(
112
+ episode_id=payload.get("episode_id", ""),
113
+ step_count=payload.get("step_count", 0),
114
+ task=payload.get("task", "GoToPosition"),
115
+ x=payload.get("x", 0.0),
116
+ y=payload.get("y", 0.0),
117
+ heading_rad=payload.get("heading_rad", 0.0),
118
+ vx=payload.get("vx", 0.0),
119
+ vy=payload.get("vy", 0.0),
120
+ angular_velocity_rads=payload.get("angular_velocity_rads", 0.0),
121
+ total_reward=payload.get("total_reward", 0.0),
122
+ goal_reached=payload.get("goal_reached", False),
123
+ )