File size: 4,469 Bytes
151d11e
 
 
 
 
 
1284aff
 
 
 
 
 
 
 
 
 
151d11e
 
 
 
 
 
 
1284aff
 
 
151d11e
 
 
 
 
1284aff
151d11e
 
 
1284aff
151d11e
 
 
 
1284aff
 
151d11e
1284aff
 
151d11e
 
1284aff
 
151d11e
 
 
 
 
 
 
 
1284aff
 
151d11e
 
 
 
1284aff
 
151d11e
 
1284aff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright (c) Space Robotics Lab, SnT, University of Luxembourg, SpaceR
# RANS: arXiv:2310.07393 — OpenEnv-compatible implementation

"""
RANSEnv — OpenEnv client for the RANS spacecraft navigation environment.

Usage (synchronous)::

    from rans_env import RANSEnv, SpacecraftAction

    with RANSEnv(base_url="http://localhost:8000").sync() as env:
        result = env.reset()
        n = len(result.observation.thruster_masks)
        result = env.step(SpacecraftAction(thrusters=[1, 0, 0, 0, 0, 0, 0, 0]))
        print(result.reward, result.done)

Usage (async)::

    import asyncio
    from rans_env import RANSEnv, SpacecraftAction

    async def main():
        async with RANSEnv(base_url="http://localhost:8000") as env:
            result = await env.reset()
            result = await env.step(SpacecraftAction(thrusters=[0.0] * 8))
            print(result.reward, result.done)

    asyncio.run(main())

Docker::

    env = RANSEnv.from_docker_image("rans-env:latest", env={"RANS_TASK": "GoToPose"})

HuggingFace Spaces::

    env = RANSEnv.from_env("dpang/rans-env")
"""

from __future__ import annotations

from typing import Any, Dict

try:
    from openenv.core.env_client import EnvClient, StepResult
    _OPENENV_AVAILABLE = True
except ImportError:
    EnvClient = object  # type: ignore[assignment,misc]
    StepResult = None   # type: ignore[assignment,misc]
    _OPENENV_AVAILABLE = False

from rans_env.models import SpacecraftAction, SpacecraftObservation, SpacecraftState


class RANSEnv(EnvClient):
    """
    Client for the RANS spacecraft navigation OpenEnv environment.

    Implements the three ``EnvClient`` abstract methods that handle
    JSON serialisation of actions and deserialisation of observations.

    Parameters
    ----------
    base_url:
        HTTP/WebSocket URL of the running server,
        e.g. ``"http://localhost:8000"`` or ``"ws://localhost:8000"``.
    """

    # ------------------------------------------------------------------
    # EnvClient abstract method implementations
    # ------------------------------------------------------------------

    def _step_payload(self, action: SpacecraftAction) -> Dict[str, Any]:
        """Serialise SpacecraftAction → JSON dict for the WebSocket message."""
        return {"thrusters": action.thrusters}

    def _parse_result(self, payload: Dict[str, Any]) -> "StepResult[SpacecraftObservation]":
        """
        Deserialise the server response into a typed StepResult.

        The server sends::

            {
              "observation": { "state_obs": [...], "thruster_transforms": [...],
                               "thruster_masks": [...], "mass": 10.0, "inertia": 0.5,
                               "task": "GoToPosition", "reward": 0.42, "done": false,
                               "info": {...} },
              "reward": 0.42,
              "done": false
            }
        """
        obs_dict = payload.get("observation", payload)
        observation = SpacecraftObservation(
            state_obs=obs_dict.get("state_obs", []),
            thruster_transforms=obs_dict.get("thruster_transforms", []),
            thruster_masks=obs_dict.get("thruster_masks", []),
            mass=obs_dict.get("mass", 10.0),
            inertia=obs_dict.get("inertia", 0.5),
            task=obs_dict.get("task", "GoToPosition"),
            reward=float(obs_dict.get("reward") or 0.0),
            done=bool(obs_dict.get("done", False)),
            info=obs_dict.get("info", {}),
        )
        return StepResult(
            observation=observation,
            reward=payload.get("reward") or observation.reward,
            done=payload.get("done", observation.done),
        )

    def _parse_state(self, payload: Dict[str, Any]) -> SpacecraftState:
        """Deserialise the /state response into a SpacecraftState."""
        return SpacecraftState(
            episode_id=payload.get("episode_id", ""),
            step_count=payload.get("step_count", 0),
            task=payload.get("task", "GoToPosition"),
            x=payload.get("x", 0.0),
            y=payload.get("y", 0.0),
            heading_rad=payload.get("heading_rad", 0.0),
            vx=payload.get("vx", 0.0),
            vy=payload.get("vy", 0.0),
            angular_velocity_rads=payload.get("angular_velocity_rads", 0.0),
            total_reward=payload.get("total_reward", 0.0),
            goal_reached=payload.get("goal_reached", False),
        )