koichi12 commited on
Commit
f710598
·
verified ·
1 Parent(s): c84597e

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/ray/rllib/core/__init__.py +35 -0
  2. .venv/lib/python3.11/site-packages/ray/rllib/env/__init__.py +37 -0
  3. .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/__init__.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/base_env.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_context.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_runner.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_runner_group.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/external_env.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/external_multi_agent_env.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_env.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_env_runner.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/policy_client.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/policy_server_input.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/remote_base_env.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/single_agent_env_runner.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/single_agent_episode.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/tcp_client_inference_env_runner.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/vector_env.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/ray/rllib/env/base_env.py +428 -0
  20. .venv/lib/python3.11/site-packages/ray/rllib/env/env_context.py +128 -0
  21. .venv/lib/python3.11/site-packages/ray/rllib/env/env_runner.py +187 -0
  22. .venv/lib/python3.11/site-packages/ray/rllib/env/env_runner_group.py +1262 -0
  23. .venv/lib/python3.11/site-packages/ray/rllib/env/external_env.py +481 -0
  24. .venv/lib/python3.11/site-packages/ray/rllib/env/external_multi_agent_env.py +161 -0
  25. .venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_env.py +799 -0
  26. .venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_env_runner.py +1107 -0
  27. .venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_episode.py +0 -0
  28. .venv/lib/python3.11/site-packages/ray/rllib/env/policy_client.py +403 -0
  29. .venv/lib/python3.11/site-packages/ray/rllib/env/policy_server_input.py +341 -0
  30. .venv/lib/python3.11/site-packages/ray/rllib/env/remote_base_env.py +462 -0
  31. .venv/lib/python3.11/site-packages/ray/rllib/env/single_agent_env_runner.py +853 -0
  32. .venv/lib/python3.11/site-packages/ray/rllib/env/single_agent_episode.py +1862 -0
  33. .venv/lib/python3.11/site-packages/ray/rllib/env/tcp_client_inference_env_runner.py +589 -0
  34. .venv/lib/python3.11/site-packages/ray/rllib/env/utils/__pycache__/external_env_protocol.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/ray/rllib/env/vector_env.py +544 -0
  36. .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__init__.py +0 -0
  37. .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/atari_wrappers.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/group_agents_wrapper.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/multi_agent_env_compatibility.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/open_spiel.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/unity3d_env.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/atari_wrappers.py +400 -0
  43. .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/dm_control_wrapper.py +220 -0
  44. .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/dm_env_wrapper.py +98 -0
  45. .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/group_agents_wrapper.py +157 -0
  46. .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/multi_agent_env_compatibility.py +73 -0
  47. .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/open_spiel.py +130 -0
  48. .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/pettingzoo_env.py +214 -0
  49. .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/unity3d_env.py +381 -0
  50. .venv/lib/python3.11/site-packages/ray/rllib/examples/actions/__pycache__/nested_action_spaces.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/ray/rllib/core/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Shared constants for `ray.rllib.core`: default IDs and component name strings."""

from ray.rllib.core.columns import Columns


# Default ID string constants used when no explicit agent/policy ID is set up.
DEFAULT_AGENT_ID = "default_agent"
DEFAULT_POLICY_ID = "default_policy"
# TODO (sven): Change this to "default_module"
DEFAULT_MODULE_ID = DEFAULT_POLICY_ID
# Special ID — presumably addresses all modules at once; confirm against usage sites.
# NOTE(review): ALL_MODULES is not listed in `__all__` below — confirm whether
# that omission is intentional.
ALL_MODULES = "__all_modules__"

# String name constants for RLlib's individual components — presumably used as
# addressing keys (e.g. for state/checkpointing); TODO confirm against callers.
COMPONENT_ENV_RUNNER = "env_runner"
COMPONENT_ENV_TO_MODULE_CONNECTOR = "env_to_module_connector"
COMPONENT_EVAL_ENV_RUNNER = "eval_env_runner"
COMPONENT_LEARNER = "learner"
COMPONENT_LEARNER_GROUP = "learner_group"
COMPONENT_METRICS_LOGGER = "metrics_logger"
COMPONENT_MODULE_TO_ENV_CONNECTOR = "module_to_env_connector"
COMPONENT_OPTIMIZER = "optimizer"
COMPONENT_RL_MODULE = "rl_module"


# Public API of this package.
__all__ = [
    "Columns",
    "COMPONENT_ENV_RUNNER",
    "COMPONENT_ENV_TO_MODULE_CONNECTOR",
    "COMPONENT_EVAL_ENV_RUNNER",
    "COMPONENT_LEARNER",
    "COMPONENT_LEARNER_GROUP",
    "COMPONENT_METRICS_LOGGER",
    "COMPONENT_MODULE_TO_ENV_CONNECTOR",
    "COMPONENT_OPTIMIZER",
    "COMPONENT_RL_MODULE",
    "DEFAULT_AGENT_ID",
    "DEFAULT_MODULE_ID",
    "DEFAULT_POLICY_ID",
]
.venv/lib/python3.11/site-packages/ray/rllib/env/__init__.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Public entry points of `ray.rllib.env`: env base classes and wrappers re-exported via `__all__`."""

from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.policy_client import PolicyClient
from ray.rllib.env.policy_server_input import PolicyServerInput
from ray.rllib.env.remote_base_env import RemoteBaseEnv
from ray.rllib.env.vector_env import VectorEnv

# Third-party env adapters (dm_env, dm_control, PettingZoo, Unity3D, ...).
from ray.rllib.env.wrappers.dm_env_wrapper import DMEnv
from ray.rllib.env.wrappers.dm_control_wrapper import DMCEnv
from ray.rllib.env.wrappers.group_agents_wrapper import GroupAgentsWrapper
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv

# Special key — presumably references the env's spaces in (offline) input
# data; TODO confirm against the offline/input pipeline that consumes it.
INPUT_ENV_SPACES = "__env__"

__all__ = [
    "BaseEnv",
    "DMEnv",
    "DMCEnv",
    "EnvContext",
    "ExternalEnv",
    "ExternalMultiAgentEnv",
    "GroupAgentsWrapper",
    "MultiAgentEnv",
    "PettingZooEnv",
    "ParallelPettingZooEnv",
    "PolicyClient",
    "PolicyServerInput",
    "RemoteBaseEnv",
    "Unity3DEnv",
    "VectorEnv",
    "INPUT_ENV_SPACES",
]
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.58 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/base_env.cpython-311.pyc ADDED
Binary file (17.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_context.cpython-311.pyc ADDED
Binary file (6.18 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_runner.cpython-311.pyc ADDED
Binary file (9.09 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_runner_group.cpython-311.pyc ADDED
Binary file (54.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/external_env.cpython-311.pyc ADDED
Binary file (21.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/external_multi_agent_env.cpython-311.pyc ADDED
Binary file (7.48 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_env.cpython-311.pyc ADDED
Binary file (37.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_env_runner.cpython-311.pyc ADDED
Binary file (40 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/policy_client.cpython-311.pyc ADDED
Binary file (18.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/policy_server_input.cpython-311.pyc ADDED
Binary file (17.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/remote_base_env.cpython-311.pyc ADDED
Binary file (20.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/single_agent_env_runner.cpython-311.pyc ADDED
Binary file (32.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/single_agent_episode.cpython-311.pyc ADDED
Binary file (92 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/tcp_client_inference_env_runner.cpython-311.pyc ADDED
Binary file (28.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/vector_env.cpython-311.pyc ADDED
Binary file (27.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/base_env.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING, Union, Set
3
+
4
+ import gymnasium as gym
5
+ import ray
6
+ from ray.rllib.utils.annotations import OldAPIStack
7
+ from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiEnvDict
8
+
9
+ if TYPE_CHECKING:
10
+ from ray.rllib.evaluation.rollout_worker import RolloutWorker
11
+
12
+ ASYNC_RESET_RETURN = "async_reset_return"
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
@OldAPIStack
class BaseEnv:
    """The lowest-level env interface used by RLlib for sampling.

    BaseEnv models multiple agents executing asynchronously in multiple
    vectorized sub-environments. A call to `poll()` returns observations from
    ready agents keyed by their sub-environment ID and agent IDs, and
    actions for those agents can be sent back via `send_actions()`.

    All other RLlib supported env types can be converted to BaseEnv.
    RLlib handles these conversions internally in RolloutWorker, for example:

    gym.Env => rllib.VectorEnv => rllib.BaseEnv
    rllib.MultiAgentEnv (is-a gym.Env) => rllib.VectorEnv => rllib.BaseEnv
    rllib.ExternalEnv => rllib.BaseEnv

    .. testcode::
        :skipif: True

        MyBaseEnv = ...
        env = MyBaseEnv()
        obs, rewards, terminateds, truncateds, infos, off_policy_actions = (
            env.poll()
        )
        print(obs)

        env.send_actions({
            "env_0": {
                "car_0": 0,
                "car_1": 1,
            }, ...
        })
        obs, rewards, terminateds, truncateds, infos, off_policy_actions = (
            env.poll()
        )
        print(obs)

        print(terminateds)

    .. testoutput::

        {
            "env_0": {
                "car_0": [2.4, 1.6],
                "car_1": [3.4, -3.2],
            },
            "env_1": {
                "car_0": [8.0, 4.1],
            },
            "env_2": {
                "car_0": [2.3, 3.3],
                "car_1": [1.4, -0.2],
                "car_3": [1.2, 0.1],
            },
        }
        {
            "env_0": {
                "car_0": [4.1, 1.7],
                "car_1": [3.2, -4.2],
            }, ...
        }
        {
            "env_0": {
                "__all__": False,
                "car_0": False,
                "car_1": True,
            }, ...
        }

    """

    def to_base_env(
        self,
        make_env: Optional[Callable[[int], EnvType]] = None,
        num_envs: int = 1,
        remote_envs: bool = False,
        remote_env_batch_wait_ms: int = 0,
        restart_failed_sub_environments: bool = False,
    ) -> "BaseEnv":
        """Converts an RLlib-supported env into a BaseEnv object.

        Supported types for the `env` arg are gym.Env, BaseEnv,
        VectorEnv, MultiAgentEnv, ExternalEnv, or ExternalMultiAgentEnv.

        The resulting BaseEnv is always vectorized (contains n
        sub-environments) to support batched forward passes, where n may also
        be 1. BaseEnv also supports async execution via the `poll` and
        `send_actions` methods and thus supports external simulators.

        TODO: Support gym3 environments, which are already vectorized.

        Args:
            make_env: A callable taking an int as input (which indicates the
                number of individual sub-environments within the final
                vectorized BaseEnv) and returning one individual
                sub-environment.
            num_envs: The number of sub-environments to create in the
                resulting (vectorized) BaseEnv. The already existing `env`
                will be one of the `num_envs`.
            remote_envs: Whether each sub-env should be a @ray.remote actor.
                You can set this behavior in your config via the
                `remote_worker_envs=True` option.
            remote_env_batch_wait_ms: The wait time (in ms) to poll remote
                sub-environments for, if applicable. Only used if
                `remote_envs` is True.
            restart_failed_sub_environments: If True, attempt to restart a
                faulty sub-environment instead of crashing.

        Returns:
            The resulting BaseEnv object.
        """
        # Self is already a BaseEnv -> Nothing to convert; all args are
        # ignored here (they only matter for the non-BaseEnv overrides).
        return self

    def poll(
        self,
    ) -> Tuple[
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
    ]:
        """Returns observations from ready agents.

        All return values are two-level dicts mapping from EnvID to dicts
        mapping from AgentIDs to (observation/reward/etc..) values.
        The number of agents and sub-environments may vary over time.

        Returns:
            Tuple consisting of:
            New observations for each ready agent.
            Reward values for each ready agent. If the episode is just started,
            the value will be None.
            Terminated values for each ready agent. The special key "__all__" is used to
            indicate episode termination.
            Truncated values for each ready agent. The special key "__all__"
            is used to indicate episode truncation.
            Info values for each ready agent.
            Agents may take off-policy actions, in which case, there will be an entry
            in this dict that contains the taken action. There is no need to
            `send_actions()` for agents that have already chosen off-policy actions.
        """
        raise NotImplementedError

    def send_actions(self, action_dict: MultiEnvDict) -> None:
        """Called to send actions back to running agents in this env.

        Actions should be sent for each ready agent that returned observations
        in the previous poll() call.

        Args:
            action_dict: Actions values keyed by env_id and agent_id.
        """
        raise NotImplementedError

    def try_reset(
        self,
        env_id: Optional[EnvID] = None,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> Tuple[Optional[MultiEnvDict], Optional[MultiEnvDict]]:
        """Attempt to reset the sub-env with the given id or all sub-envs.

        If the environment does not support synchronous reset, a tuple of
        (ASYNC_RESET_REQUEST, ASYNC_RESET_REQUEST) can be returned here.

        Note: A MultiAgentDict is returned when using the deprecated wrapper
        classes such as `ray.rllib.env.base_env._MultiAgentEnvToBaseEnv`,
        however for consistency with the poll() method, a `MultiEnvDict` is
        returned from the new wrapper classes, such as
        `ray.rllib.env.multi_agent_env.MultiAgentEnvWrapper`.

        Args:
            env_id: The sub-environment's ID if applicable. If None, reset
                the entire Env (i.e. all sub-environments).
            seed: The seed to be passed to the sub-environment(s) when
                resetting it. If None, will not reset any existing PRNG. If you pass an
                integer, the PRNG will be reset even if it already exists.
            options: An options dict to be passed to the sub-environment(s) when
                resetting it.

        Returns:
            A tuple consisting of a) the reset (multi-env/multi-agent) observation
            dict and b) the reset (multi-env/multi-agent) infos dict. Returns the
            (ASYNC_RESET_REQUEST, ASYNC_RESET_REQUEST) tuple, if not supported.
        """
        # Default: signal "no synchronous reset available".
        return None, None

    def try_restart(self, env_id: Optional[EnvID] = None) -> None:
        """Attempt to restart the sub-env with the given id or all sub-envs.

        This could result in the sub-env being completely removed (gc'd) and recreated.

        Args:
            env_id: The sub-environment's ID, if applicable. If None, restart
                the entire Env (i.e. all sub-environments).
        """
        # Default: no-op.
        return None

    def get_sub_environments(self, as_dict: bool = False) -> Union[List[EnvType], dict]:
        """Return a reference to the underlying sub environments, if any.

        Args:
            as_dict: If True, return a dict mapping from env_id to env.

        Returns:
            List or dictionary of the underlying sub environments or [] / {}.
        """
        if as_dict:
            return {}
        return []

    def get_agent_ids(self) -> Set[AgentID]:
        """Return the agent ids for the sub_environment.

        Returns:
            All agent ids for each the environment.
        """
        # Return an empty `set` (NOT `{}`, which is an empty *dict*) so the
        # default matches the annotated `Set[AgentID]` return type and
        # supports set operations (union, `add`, etc.) in callers.
        return set()

    def try_render(self, env_id: Optional[EnvID] = None) -> None:
        """Tries to render the sub-environment with the given id or all.

        Args:
            env_id: The sub-environment's ID, if applicable.
                If None, renders the entire Env (i.e. all sub-environments).
        """

        # By default, do nothing.
        pass

    def stop(self) -> None:
        """Releases all resources used."""

        # Try calling `close` on all sub-environments.
        for env in self.get_sub_environments():
            if hasattr(env, "close"):
                env.close()

    @property
    def observation_space(self) -> gym.Space:
        """Returns the observation space for each agent.

        Note: samples from the observation space need to be preprocessed into a
        `MultiEnvDict` before being used by a policy.

        Returns:
            The observation space for each environment.
        """
        raise NotImplementedError

    @property
    def action_space(self) -> gym.Space:
        """Returns the action space for each agent.

        Note: samples from the action space need to be preprocessed into a
        `MultiEnvDict` before being passed to `send_actions`.

        Returns:
            The action space for each environment.
        """
        raise NotImplementedError

    def last(
        self,
    ) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
        """Returns the last observations, rewards, done- truncated flags and infos ...

        that were returned by the environment.

        Returns:
            The last observations, rewards, done- and truncated flags, and infos
            for each sub-environment.
        """
        logger.warning("last has not been implemented for this environment.")
        return {}, {}, {}, {}, {}
298
+
299
+
300
# Fixed agent identifier when there is only the single agent in the env
_DUMMY_AGENT_ID = "agent0"


@OldAPIStack
def with_dummy_agent_id(
    env_id_to_values: Dict[EnvID, Any], dummy_id: "AgentID" = _DUMMY_AGENT_ID
) -> MultiEnvDict:
    """Nests each per-env value under the (single) dummy agent ID.

    Args:
        env_id_to_values: Mapping from env ID to a single (per-env) value,
            e.g. an observation or reward of the only agent.
        dummy_id: The agent ID to key each value under.

    Returns:
        A MultiEnvDict mapping env ID -> {dummy_id: value}. Exception values
        are kept as-is (published directly under the env ID) so the caller of
        `poll()` can tell that the entire sub-environment has crashed.
    """
    return {
        env_id: (val if isinstance(val, Exception) else {dummy_id: val})
        for env_id, val in env_id_to_values.items()
    }
315
+
316
+
317
@OldAPIStack
def convert_to_base_env(
    env: EnvType,
    make_env: Callable[[int], EnvType] = None,
    num_envs: int = 1,
    remote_envs: bool = False,
    remote_env_batch_wait_ms: int = 0,
    worker: Optional["RolloutWorker"] = None,
    restart_failed_sub_environments: bool = False,
) -> "BaseEnv":
    """Converts an RLlib-supported env into a BaseEnv object.

    Supported types for the `env` arg are gym.Env, BaseEnv,
    VectorEnv, MultiAgentEnv, ExternalEnv, or ExternalMultiAgentEnv.

    The resulting BaseEnv is always vectorized (contains n
    sub-environments) to support batched forward passes, where n may also
    be 1. BaseEnv also supports async execution via the `poll` and
    `send_actions` methods and thus supports external simulators.

    TODO: Support gym3 environments, which are already vectorized.

    Args:
        env: An already existing environment of any supported env type
            to convert/wrap into a BaseEnv. Supported types are gym.Env,
            BaseEnv, VectorEnv, MultiAgentEnv, ExternalEnv, and
            ExternalMultiAgentEnv.
        make_env: A callable taking an int as input (which indicates the
            number of individual sub-environments within the final
            vectorized BaseEnv) and returning one individual
            sub-environment.
        num_envs: The number of sub-environments to create in the
            resulting (vectorized) BaseEnv. The already existing `env`
            will be one of the `num_envs`.
        remote_envs: Whether each sub-env should be a @ray.remote actor.
            You can set this behavior in your config via the
            `remote_worker_envs=True` option.
        remote_env_batch_wait_ms: The wait time (in ms) to poll remote
            sub-environments for, if applicable. Only used if
            `remote_envs` is True.
        worker: An optional RolloutWorker that owns the env. This is only
            used if `remote_worker_envs` is True in your config and the
            `on_sub_environment_created` custom callback needs to be called
            on each created actor.
        restart_failed_sub_environments: If True and any sub-environment (within
            a vectorized env) throws any error during env stepping, the
            Sampler will try to restart the faulty sub-environment. This is done
            without disturbing the other (still intact) sub-environment and without
            the RolloutWorker crashing.

    Returns:
        The resulting BaseEnv object.
    """

    # Imported at function scope (not module top) — presumably to avoid
    # circular imports between the env submodules; confirm before hoisting.
    from ray.rllib.env.remote_base_env import RemoteBaseEnv
    from ray.rllib.env.external_env import ExternalEnv
    from ray.rllib.env.multi_agent_env import MultiAgentEnv
    from ray.rllib.env.vector_env import VectorEnv, VectorEnvWrapper

    if remote_envs and num_envs == 1:
        raise ValueError(
            "Remote envs only make sense to use if num_envs > 1 "
            "(i.e. environment vectorization is enabled)."
        )

    # Given `env` has a `to_base_env` method -> Call that to convert to a BaseEnv type.
    if isinstance(env, (BaseEnv, MultiAgentEnv, VectorEnv, ExternalEnv)):
        return env.to_base_env(
            make_env=make_env,
            num_envs=num_envs,
            remote_envs=remote_envs,
            remote_env_batch_wait_ms=remote_env_batch_wait_ms,
            restart_failed_sub_environments=restart_failed_sub_environments,
        )
    # `env` is not a BaseEnv yet -> Need to convert/vectorize.
    else:
        # Sub-environments are ray.remote actors:
        if remote_envs:
            # Determine, whether the already existing sub-env (could
            # be a ray.actor) is multi-agent or not.
            multiagent = (
                ray.get(env._is_multi_agent.remote())
                if hasattr(env, "_is_multi_agent")
                else False
            )
            env = RemoteBaseEnv(
                make_env,
                num_envs,
                multiagent=multiagent,
                remote_env_batch_wait_ms=remote_env_batch_wait_ms,
                existing_envs=[env],
                worker=worker,
                restart_failed_sub_environments=restart_failed_sub_environments,
            )
        # Sub-environments are not ray.remote actors.
        else:
            # Convert gym.Env to VectorEnv ...
            # NOTE(review): spaces are read off the existing `env`; assumes
            # `env` exposes `action_space`/`observation_space` — gym.Env does.
            env = VectorEnv.vectorize_gym_envs(
                make_env=make_env,
                existing_envs=[env],
                num_envs=num_envs,
                action_space=env.action_space,
                observation_space=env.observation_space,
                restart_failed_sub_environments=restart_failed_sub_environments,
            )
            # ... then the resulting VectorEnv to a BaseEnv.
            env = VectorEnvWrapper(env)

    # Make sure conversion went well.
    assert isinstance(env, BaseEnv), env

    return env
.venv/lib/python3.11/site-packages/ray/rllib/env/env_context.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import Optional
3
+
4
+ from ray.rllib.utils.annotations import OldAPIStack
5
+ from ray.rllib.utils.typing import EnvConfigDict
6
+
7
+
8
@OldAPIStack
class EnvContext(dict):
    """A dict of env config values plus extra RLlib metadata attributes.

    The metadata attributes (e.g. `worker_index`) let users parameterize
    environments per process — for example, picking a different data file
    per worker at env construction time.

    RLlib auto-sets these attributes when constructing registered envs.
    """

    def __init__(
        self,
        env_config: EnvConfigDict,
        worker_index: int,
        vector_index: int = 0,
        remote: bool = False,
        num_workers: Optional[int] = None,
        recreated_worker: bool = False,
    ):
        """Initializes an EnvContext instance.

        Args:
            env_config: The env's configuration defined under the
                "env_config" key in the Algorithm's config.
            worker_index: When there are multiple workers created, this
                uniquely identifies the worker the env is created in.
                0 for local worker, >0 for remote workers.
            vector_index: When there are multiple envs per worker, this
                uniquely identifies the env index within the worker.
                Starts from 0.
            remote: Whether individual sub-environments (in a vectorized
                env) should be @ray.remote actors or not.
            num_workers: The total number of (remote) workers in the set.
                0 if only a local worker exists.
            recreated_worker: Whether the worker that holds this env is a recreated one.
                This means that it replaced a previous (failed) worker when
                `restart_failed_env_runners=True` in the Algorithm's config.
        """
        # The env_config's key/value pairs become the dict contents of `self`.
        super().__init__(env_config)

        # Attach the RLlib metadata as plain attributes.
        self.worker_index = worker_index
        self.vector_index = vector_index
        self.remote = remote
        self.num_workers = num_workers
        self.recreated_worker = recreated_worker

    def copy_with_overrides(
        self,
        env_config: Optional[EnvConfigDict] = None,
        worker_index: Optional[int] = None,
        vector_index: Optional[int] = None,
        remote: Optional[bool] = None,
        num_workers: Optional[int] = None,
        recreated_worker: Optional[bool] = None,
    ) -> "EnvContext":
        """Returns a copy of this EnvContext with some attributes overridden.

        Each argument left as None keeps the corresponding value from `self`.

        Args:
            env_config: Optional env config to use (deep-copied if given).
            worker_index: Optional worker index to use.
            vector_index: Optional vector index to use.
            remote: Optional remote setting to use.
            num_workers: Optional num_workers to use.
            recreated_worker: Optional flag, indicating, whether the worker that holds
                the env is a recreated one (it replaced a previous failed worker when
                `restart_failed_env_runners=True` in the Algorithm's config).

        Returns:
            A new EnvContext object as a copy of self plus the provided
            overrides.
        """
        # A given env_config is deep-copied; otherwise pass `self` (whose
        # key/value pairs the EnvContext constructor copies into the new dict).
        if env_config is not None:
            new_config = copy.deepcopy(env_config)
        else:
            new_config = self
        return EnvContext(
            new_config,
            self.worker_index if worker_index is None else worker_index,
            self.vector_index if vector_index is None else vector_index,
            self.remote if remote is None else remote,
            self.num_workers if num_workers is None else num_workers,
            self.recreated_worker if recreated_worker is None else recreated_worker,
        )

    def set_defaults(self, defaults: dict) -> None:
        """Sets missing keys of self to the values given in `defaults`.

        Keys of `defaults` that already exist in self keep their current
        values (the defaults do NOT override them).

        Args:
            defaults: The key/value pairs to add to self, but only for those
                keys in `defaults` that don't exist yet in self.

        .. testcode::
            :skipif: True

            from ray.rllib.env.env_context import EnvContext
            env_ctx = EnvContext({"a": 1, "b": 2}, worker_index=0)
            env_ctx.set_defaults({"a": -42, "c": 3})
            print(env_ctx)

        .. testoutput::

            {"a": 1, "b": 2, "c": 3}
        """
        for key, value in defaults.items():
            # `setdefault` only inserts when the key is absent.
            self.setdefault(key, value)

    def __str__(self):
        # Render as the plain dict repr with the metadata spliced in before
        # the closing brace.
        base = super().__str__()
        suffix = (
            f", worker={self.worker_index}/{self.num_workers}, "
            f"vector_idx={self.vector_index}, remote={self.remote}}}"
        )
        return base[:-1] + suffix
.venv/lib/python3.11/site-packages/ray/rllib/env/env_runner.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import logging
3
+ from typing import Any, Dict, Tuple, TYPE_CHECKING
4
+
5
+ import gymnasium as gym
6
+ import tree # pip install dm_tree
7
+
8
+ from ray.rllib.utils.actor_manager import FaultAwareApply
9
+ from ray.rllib.utils.framework import try_import_tf
10
+ from ray.rllib.utils.torch_utils import convert_to_torch_tensor
11
+ from ray.rllib.utils.typing import TensorType
12
+ from ray.util.annotations import PublicAPI
13
+
14
+ if TYPE_CHECKING:
15
+ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
16
+
17
logger = logging.getLogger("ray.rllib")

# Framework handles from `try_import_tf()`; presumably None-like if TF is not
# installed (the `tf1` truthiness check in `EnvRunner.__init__` relies on this)
# - TODO confirm against `try_import_tf` docs.
tf1, tf, _ = try_import_tf()

# Sentinel values signaling that an env reset/step attempt failed and was
# handled gracefully (see `_try_env_reset` / `_try_env_step` below).
ENV_RESET_FAILURE = "env_reset_failure"
ENV_STEP_FAILURE = "env_step_failure"
23
+
24
+
25
# TODO (sven): As soon as RolloutWorker is no longer supported, make this base class
# a Checkpointable. Currently, only some of its subclasses are Checkpointables.
@PublicAPI(stability="alpha")
class EnvRunner(FaultAwareApply, metaclass=abc.ABCMeta):
    """Base class for distributed RL-style data collection from an environment.

    The EnvRunner API's core functionalities can be summarized as:
    - Gets configured via passing a AlgorithmConfig object to the constructor.
    Normally, subclasses of EnvRunner then construct their own environment (possibly
    vectorized) copies and RLModules/Policies and use the latter to step through the
    environment in order to collect training data.
    - Clients of EnvRunner can use the `sample()` method to collect data for training
    from the environment(s).
    - EnvRunner offers parallelism via creating n remote Ray Actors based on this class.
    Use `ray.remote([resources])(EnvRunner)` method to create the corresponding Ray
    remote class. Then instantiate n Actors using the Ray `[ctor].remote(...)` syntax.
    - EnvRunner clients can get information about the server/node on which the
    individual Actors are running.
    """

    def __init__(self, *, config: "AlgorithmConfig", **kwargs):
        """Initializes an EnvRunner instance.

        Args:
            config: The AlgorithmConfig to use to setup this EnvRunner.
            **kwargs: Forward compatibility kwargs.
        """
        # Keep an unfrozen copy so this EnvRunner (and users) may still mutate
        # it, e.g. change `env_config` before re-creating the env.
        self.config = config.copy(copy_frozen=False)
        # The actual env; set by subclasses (see `make_env()`).
        self.env = None

        super().__init__(**kwargs)

        # This eager check is necessary for certain all-framework tests
        # that use tf's eager_mode() context generator.
        if (
            tf1
            and (self.config.framework_str == "tf2" or config.enable_tf1_exec_eagerly)
            and not tf1.executing_eagerly()
        ):
            tf1.enable_eager_execution()

    @abc.abstractmethod
    def assert_healthy(self):
        """Checks that self.__init__() has been completed properly.

        Useful in case an `EnvRunner` is run as @ray.remote (Actor) and the owner
        would like to make sure the Ray Actor has been properly initialized.

        Raises:
            AssertionError: If the EnvRunner Actor has NOT been properly initialized.
        """

    # TODO: Make this an abstract method that must be implemented.
    def make_env(self):
        """Creates the RL environment for this EnvRunner and assigns it to `self.env`.

        Note that users should be able to change the EnvRunner's config (e.g. change
        `self.config.env_config`) and then call this method to create new environments
        with the updated configuration.
        It should also be called after a failure of an earlier env in order to clean up
        the existing env (for example `close()` it), re-create a new one, and then
        continue sampling with that new env.
        """
        pass

    # TODO: Make this an abstract method that must be implemented.
    def make_module(self):
        """Creates the RLModule for this EnvRunner and assigns it to `self.module`.

        Note that users should be able to change the EnvRunner's config (e.g. change
        `self.config.rl_module_spec`) and then call this method to create a new RLModule
        with the updated configuration.
        """
        pass

    @abc.abstractmethod
    def sample(self, **kwargs) -> Any:
        """Returns experiences (of any form) sampled from this EnvRunner.

        The exact nature and size of collected data are defined via the EnvRunner's
        config and may be overridden by the given arguments.

        Args:
            **kwargs: Forward compatibility kwargs.

        Returns:
            The collected experience in any form.
        """

    # TODO (sven): Make this an abstract method that must be overridden.
    def get_metrics(self) -> Any:
        """Returns metrics (in any form) of the thus far collected, completed episodes.

        Returns:
            Metrics of any form.
        """
        pass

    @abc.abstractmethod
    def get_spaces(self) -> Dict[str, Tuple[gym.Space, gym.Space]]:
        """Returns a dict mapping ModuleIDs to 2-tuples of obs- and action space."""

    def stop(self) -> None:
        """Releases all resources used by this EnvRunner.

        For example, when using a gym.Env in this EnvRunner, you should make sure
        that its `close()` method is called.
        """
        pass

    def __del__(self) -> None:
        """If this Actor is deleted, clears all resources used by it."""
        pass

    def _try_env_reset(self):
        """Tries resetting the env and - if an error occurs - handles it gracefully.

        If `config.restart_failed_sub_environments` is True, a failing reset leads
        to the env being re-created (via `make_env()`) and this method being called
        again recursively until a reset succeeds. Otherwise, the original error is
        re-raised.
        """
        # Try to reset.
        try:
            obs, infos = self.env.reset()
            # Everything ok -> return.
            return obs, infos
        # Error.
        except Exception as e:
            # If user wants to simply restart the env -> recreate env and try again
            # (calling this method recursively until success).
            if self.config.restart_failed_sub_environments:
                logger.exception(
                    "Resetting the env resulted in an error! The original error "
                    f"is: {e.args[0]}"
                )
                # Recreate the env and simply try again.
                self.make_env()
                return self._try_env_reset()
            else:
                raise e

    def _try_env_step(self, actions):
        """Tries stepping the env and - if an error occurs - handles it gracefully.

        If `config.restart_failed_sub_environments` is True, a failing step leads
        to the env being re-created (via `make_env()`) and the `ENV_STEP_FAILURE`
        sentinel being returned (the caller is expected to handle cleanup, e.g.
        discard partial data and retry). Otherwise, the original error is re-raised.
        """
        try:
            results = self.env.step(actions)
            return results
        except Exception as e:
            if self.config.restart_failed_sub_environments:
                logger.exception(
                    "Stepping the env resulted in an error! The original error "
                    f"is: {e.args[0]}"
                )
                # Recreate the env.
                self.make_env()
                # And return that the stepping failed. The caller will then handle
                # specific cleanup operations (for example discarding thus-far collected
                # data and repeating the step attempt).
                return ENV_STEP_FAILURE
            else:
                raise e

    def _convert_to_tensor(self, struct) -> TensorType:
        """Converts structs to a framework-specific tensor."""

        if self.config.framework_str == "torch":
            return convert_to_torch_tensor(struct)
        else:
            # Non-torch frameworks fall back to TF tensor conversion.
            return tree.map_structure(tf.convert_to_tensor, struct)
.venv/lib/python3.11/site-packages/ray/rllib/env/env_runner_group.py ADDED
@@ -0,0 +1,1262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import gymnasium as gym
3
+ import logging
4
+ import importlib.util
5
+ import os
6
+ from typing import (
7
+ Any,
8
+ Callable,
9
+ Collection,
10
+ Dict,
11
+ List,
12
+ Optional,
13
+ Tuple,
14
+ Type,
15
+ TYPE_CHECKING,
16
+ TypeVar,
17
+ Union,
18
+ )
19
+
20
+ import ray
21
+ from ray.actor import ActorHandle
22
+ from ray.exceptions import RayActorError
23
+ from ray.rllib.core import (
24
+ COMPONENT_ENV_TO_MODULE_CONNECTOR,
25
+ COMPONENT_LEARNER,
26
+ COMPONENT_MODULE_TO_ENV_CONNECTOR,
27
+ COMPONENT_RL_MODULE,
28
+ )
29
+ from ray.rllib.core.learner import LearnerGroup
30
+ from ray.rllib.core.rl_module import validate_module_id
31
+ from ray.rllib.core.rl_module.rl_module import RLModuleSpec
32
+ from ray.rllib.evaluation.rollout_worker import RolloutWorker
33
+ from ray.rllib.env.base_env import BaseEnv
34
+ from ray.rllib.env.env_context import EnvContext
35
+ from ray.rllib.env.env_runner import EnvRunner
36
+ from ray.rllib.offline import get_dataset_and_shards
37
+ from ray.rllib.policy.policy import Policy, PolicyState
38
+ from ray.rllib.utils.actor_manager import FaultTolerantActorManager
39
+ from ray.rllib.utils.annotations import OldAPIStack
40
+ from ray.rllib.utils.deprecation import (
41
+ Deprecated,
42
+ deprecation_warning,
43
+ DEPRECATED_VALUE,
44
+ )
45
+ from ray.rllib.utils.framework import try_import_tf
46
+ from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED_LIFETIME, WEIGHTS_SEQ_NO
47
+ from ray.rllib.utils.typing import (
48
+ AgentID,
49
+ EnvCreator,
50
+ EnvType,
51
+ EpisodeID,
52
+ PartialAlgorithmConfigDict,
53
+ PolicyID,
54
+ SampleBatchType,
55
+ TensorType,
56
+ )
57
+ from ray.util.annotations import DeveloperAPI
58
+
59
+ if TYPE_CHECKING:
60
+ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
61
+
62
# Framework handles from `try_import_tf()` (presumably None-like when TF is not
# installed - TODO confirm against `try_import_tf` docs).
tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)

# Generic type var for foreach_* methods.
T = TypeVar("T")
68
+
69
+
70
+ @DeveloperAPI
71
+ class EnvRunnerGroup:
72
+ """Set of EnvRunners with n @ray.remote workers and zero or one local worker.
73
+
74
+ Where: n >= 0.
75
+ """
76
+
77
    def __init__(
        self,
        *,
        env_creator: Optional[EnvCreator] = None,
        validate_env: Optional[Callable[[EnvType], None]] = None,
        default_policy_class: Optional[Type[Policy]] = None,
        config: Optional["AlgorithmConfig"] = None,
        local_env_runner: bool = True,
        logdir: Optional[str] = None,
        _setup: bool = True,
        tune_trial_id: Optional[str] = None,
        # Deprecated args.
        num_env_runners: Optional[int] = None,
        num_workers=DEPRECATED_VALUE,
        local_worker=DEPRECATED_VALUE,
    ):
        """Initializes a EnvRunnerGroup instance.

        Args:
            env_creator: Function that returns env given env config.
            validate_env: Optional callable to validate the generated
                environment (only on worker=0). This callable should raise
                an exception if the environment is invalid.
            default_policy_class: An optional default Policy class to use inside
                the (multi-agent) `policies` dict. In case the PolicySpecs in there
                have no class defined, use this `default_policy_class`.
                If None, PolicySpecs will be using the Algorithm's default Policy
                class.
            config: Optional AlgorithmConfig (or config dict).
            local_env_runner: Whether to create a local (non @ray.remote) EnvRunner
                in the returned set as well (default: True). If `num_env_runners`
                is 0, always create a local EnvRunner.
            logdir: Optional logging directory for workers.
            _setup: Whether to actually set up workers. This is only for testing.
            tune_trial_id: The Ray Tune trial ID, if this EnvRunnerGroup is part of
                an Algorithm run as a Tune trial. None, otherwise.
        """
        # Hard error on usage of the long-deprecated WorkerSet arg names.
        if num_workers != DEPRECATED_VALUE or local_worker != DEPRECATED_VALUE:
            deprecation_warning(
                old="WorkerSet(num_workers=..., local_worker=...)",
                new="EnvRunnerGroup(num_env_runners=..., local_env_runner=...)",
                error=True,
            )

        # Local import to avoid a circular dependency with AlgorithmConfig.
        from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

        # Make sure `config` is an AlgorithmConfig object.
        if not config:
            config = AlgorithmConfig()
        elif isinstance(config, dict):
            config = AlgorithmConfig.from_dict(config)

        self._env_creator = env_creator
        self._policy_class = default_policy_class
        self._remote_config = config
        # Ray actor options used when creating each remote EnvRunner.
        self._remote_args = {
            "num_cpus": self._remote_config.num_cpus_per_env_runner,
            "num_gpus": self._remote_config.num_gpus_per_env_runner,
            "resources": self._remote_config.custom_resources_per_env_runner,
            "max_restarts": (
                config.max_num_env_runner_restarts
                if config.restart_failed_env_runners
                else 0
            ),
        }
        self._tune_trial_id = tune_trial_id

        # Set the EnvRunner subclass to be used as "workers". Default: RolloutWorker.
        self.env_runner_cls = config.env_runner_cls
        if self.env_runner_cls is None:
            if config.enable_env_runner_and_connector_v2:
                # If experiences should be recorded, use the `
                # OfflineSingleAgentEnvRunner`.
                if config.output:
                    # No multi-agent support.
                    if config.is_multi_agent:
                        raise ValueError("Multi-agent recording is not supported, yet.")
                    # Otherwise, load the single-agent env runner for
                    # recording.
                    else:
                        from ray.rllib.offline.offline_env_runner import (
                            OfflineSingleAgentEnvRunner,
                        )

                        self.env_runner_cls = OfflineSingleAgentEnvRunner
                else:
                    if config.is_multi_agent:
                        from ray.rllib.env.multi_agent_env_runner import (
                            MultiAgentEnvRunner,
                        )

                        self.env_runner_cls = MultiAgentEnvRunner
                    else:
                        from ray.rllib.env.single_agent_env_runner import (
                            SingleAgentEnvRunner,
                        )

                        self.env_runner_cls = SingleAgentEnvRunner
            else:
                self.env_runner_cls = RolloutWorker
        # Bound remote constructor for the chosen EnvRunner class.
        self._cls = ray.remote(**self._remote_args)(self.env_runner_cls).remote

        self._logdir = logdir
        self._ignore_ray_errors_on_env_runners = (
            config.ignore_env_runner_failures or config.restart_failed_env_runners
        )

        # Create remote worker manager.
        # ID=0 is used by the local worker.
        # Starting remote workers from ID=1 to avoid conflicts.
        self._worker_manager = FaultTolerantActorManager(
            max_remote_requests_in_flight_per_actor=(
                config.max_requests_in_flight_per_env_runner
            ),
            init_id=1,
        )

        if _setup:
            try:
                self._setup(
                    validate_env=validate_env,
                    config=config,
                    num_env_runners=(
                        num_env_runners
                        if num_env_runners is not None
                        else config.num_env_runners
                    ),
                    local_env_runner=local_env_runner,
                )
            # EnvRunnerGroup creation possibly fails, if some (remote) workers cannot
            # be initialized properly (due to some errors in the EnvRunners's
            # constructor).
            except RayActorError as e:
                # In case of an actor (remote worker) init failure, the remote worker
                # may still exist and will be accessible, however, e.g. calling
                # its `sample.remote()` would result in strange "property not found"
                # errors.
                if e.actor_init_failed:
                    # Raise the original error here that the EnvRunners raised
                    # during its construction process. This is to enforce transparency
                    # for the user (better to understand the real reason behind the
                    # failure).
                    # - e.args[0]: The RayTaskError (inside the caught RayActorError).
                    # - e.args[0].args[2]: The original Exception (e.g. a ValueError due
                    # to a config mismatch) thrown inside the actor.
                    raise e.args[0].args[2]
                # In any other case, raise the RayActorError as-is.
                else:
                    raise e
226
+
227
    def _setup(
        self,
        *,
        validate_env: Optional[Callable[[EnvType], None]] = None,
        config: Optional["AlgorithmConfig"] = None,
        num_env_runners: int = 0,
        local_env_runner: bool = True,
    ):
        """Sets up an EnvRunnerGroup instance.

        Args:
            validate_env: Optional callable to validate the generated
                environment (only on worker=0).
            config: Optional dict that extends the common config of
                the Algorithm class.
            num_env_runners: Number of remote EnvRunner workers to create.
            local_env_runner: Whether to create a local (non @ray.remote) EnvRunner
                in the returned set as well (default: True). If `num_env_runners`
                is 0, always create a local EnvRunner.
        """
        # Force a local worker if num_env_runners == 0 (no remote workers).
        # Otherwise, this EnvRunnerGroup would be empty.
        self._local_env_runner = None
        if num_env_runners == 0:
            local_env_runner = True
        # Create a local (learner) version of the config for the local worker.
        # The only difference is the tf_session_args, which - for the local worker -
        # will be `config.tf_session_args` updated/overridden with
        # `config.local_tf_session_args`.
        local_tf_session_args = config.tf_session_args.copy()
        local_tf_session_args.update(config.local_tf_session_args)
        self._local_config = config.copy(copy_frozen=False).framework(
            tf_session_args=local_tf_session_args
        )

        if config.input_ == "dataset":
            # Create the set of dataset readers to be shared by all the
            # rollout workers.
            self._ds, self._ds_shards = get_dataset_and_shards(config, num_env_runners)
        else:
            self._ds = None
            self._ds_shards = None

        # Create a number of @ray.remote workers.
        self.add_workers(
            num_env_runners,
            validate=config.validate_env_runners_after_construction,
        )

        # If num_env_runners > 0 and we don't have an env on the local worker,
        # get the observation- and action spaces for each policy from
        # the first remote worker (which does have an env).
        if (
            local_env_runner
            and self._worker_manager.num_actors() > 0
            and not config.enable_env_runner_and_connector_v2
            and not config.create_env_on_local_worker
            and (not config.observation_space or not config.action_space)
        ):
            spaces = self.get_spaces()
        else:
            spaces = None

        # Create a local worker, if needed.
        if local_env_runner:
            self._local_env_runner = self._make_worker(
                cls=self.env_runner_cls,
                env_creator=self._env_creator,
                validate_env=validate_env,
                worker_index=0,
                num_workers=num_env_runners,
                config=self._local_config,
                spaces=spaces,
            )
300
+
301
+ def get_spaces(self):
302
+ """Infer observation and action spaces from one (local or remote) EnvRunner.
303
+
304
+ Returns:
305
+ A dict mapping from ModuleID to a 2-tuple containing obs- and action-space.
306
+ """
307
+ # Get ID of the first remote worker.
308
+ remote_worker_ids = (
309
+ [self._worker_manager.actor_ids()[0]]
310
+ if self._worker_manager.actor_ids()
311
+ else []
312
+ )
313
+
314
+ spaces = self.foreach_env_runner(
315
+ lambda env_runner: env_runner.get_spaces(),
316
+ remote_worker_ids=remote_worker_ids,
317
+ local_env_runner=not remote_worker_ids,
318
+ )[0]
319
+
320
+ logger.info(
321
+ "Inferred observation/action spaces from remote "
322
+ f"worker (local worker has no env): {spaces}"
323
+ )
324
+
325
+ return spaces
326
+
327
    @property
    def local_env_runner(self) -> EnvRunner:
        """Returns the local EnvRunner."""
        return self._local_env_runner

    def healthy_env_runner_ids(self) -> List[int]:
        """Returns the list of remote worker IDs."""
        return self._worker_manager.healthy_actor_ids()

    # NOTE: The `*_worker*` methods below are aliases kept for backward
    # compatibility with the old "worker" terminology; each simply delegates to
    # its `*_env_runner*` counterpart.
    def healthy_worker_ids(self) -> List[int]:
        """Returns the list of remote worker IDs."""
        return self.healthy_env_runner_ids()

    def num_remote_env_runners(self) -> int:
        """Returns the number of remote EnvRunners."""
        return self._worker_manager.num_actors()

    def num_remote_workers(self) -> int:
        """Returns the number of remote EnvRunners."""
        return self.num_remote_env_runners()

    def num_healthy_remote_env_runners(self) -> int:
        """Returns the number of healthy remote workers."""
        return self._worker_manager.num_healthy_actors()

    def num_healthy_remote_workers(self) -> int:
        """Returns the number of healthy remote workers."""
        return self.num_healthy_remote_env_runners()

    def num_healthy_env_runners(self) -> int:
        """Returns the number of all healthy workers, including the local worker."""
        # The local EnvRunner (if present) counts as one additional healthy worker.
        return int(bool(self._local_env_runner)) + self.num_healthy_remote_workers()

    def num_healthy_workers(self) -> int:
        """Returns the number of all healthy workers, including the local worker."""
        return self.num_healthy_env_runners()

    def num_in_flight_async_reqs(self) -> int:
        """Returns the number of in-flight async requests."""
        return self._worker_manager.num_outstanding_async_reqs()

    def num_remote_worker_restarts(self) -> int:
        """Total number of times managed remote workers have been restarted."""
        return self._worker_manager.total_num_restarts()
371
+
372
    def sync_env_runner_states(
        self,
        *,
        config: "AlgorithmConfig",
        from_worker: Optional[EnvRunner] = None,
        env_steps_sampled: Optional[int] = None,
        connector_states: Optional[List[Dict[str, Any]]] = None,
        rl_module_state: Optional[Dict[str, Any]] = None,
        env_runner_indices_to_update: Optional[List[int]] = None,
    ) -> None:
        """Synchronizes the connectors of this EnvRunnerGroup's EnvRunners.

        The exact procedure works as follows:
        - If `from_worker` is None, set `from_worker=self.local_env_runner`.
        - If `config.use_worker_filter_stats` is True, gather all remote EnvRunners'
        ConnectorV2 states. Otherwise, only use the ConnectorV2 states of `from_worker`.
        - Merge all gathered states into one resulting state.
        - Broadcast the resulting state back to all remote EnvRunners AND the local
        EnvRunner.

        Args:
            config: The AlgorithmConfig object to use to determine, in which
                direction(s) we need to synch and what the timeouts are.
            from_worker: The EnvRunner from which to synch. If None, will use the local
                worker of this EnvRunnerGroup.
            env_steps_sampled: The total number of env steps taken thus far by all
                workers combined. Used to broadcast this number to all remote workers
                if `update_worker_filter_stats` is True in `config`.
            connector_states: Optional list of already-gathered ConnectorV2 state
                dicts (one per remote EnvRunner). If None, the states are fetched
                from the remote EnvRunners here. Pass an empty list to skip merging
                remote connector states altogether.
            rl_module_state: Optional RLModule state dict to broadcast to the
                EnvRunners alongside the (merged) connector states.
            env_runner_indices_to_update: The indices of those EnvRunners to update
                with the merged state. Use None (default) to update all remote
                EnvRunners.
        """
        from_worker = from_worker or self.local_env_runner

        # Early out if the number of (healthy) remote workers is 0. In this case, the
        # local worker is the only operating worker and thus of course always holds
        # the reference connector state.
        if self.num_healthy_remote_workers() == 0:
            self.local_env_runner.set_state(
                {
                    **(
                        {NUM_ENV_STEPS_SAMPLED_LIFETIME: env_steps_sampled}
                        if env_steps_sampled is not None
                        else {}
                    ),
                    **(rl_module_state if rl_module_state is not None else {}),
                }
            )
            return

        # Also early out, if we a) don't use the remote states AND b) don't want to
        # broadcast back from `from_worker` to all remote workers.
        # TODO (sven): Rename these to proper "..env_runner_states.." containing names.
        if not config.update_worker_filter_stats and not config.use_worker_filter_stats:
            return

        # Use states from all remote EnvRunners.
        if config.use_worker_filter_stats:
            # An explicitly-empty list means: nothing to merge.
            if connector_states == []:
                env_runner_states = {}
            else:
                if connector_states is None:
                    connector_states = self.foreach_env_runner(
                        lambda w: w.get_state(
                            components=[
                                COMPONENT_ENV_TO_MODULE_CONNECTOR,
                                COMPONENT_MODULE_TO_ENV_CONNECTOR,
                            ]
                        ),
                        local_env_runner=False,
                        timeout_seconds=(
                            config.sync_filters_on_rollout_workers_timeout_s
                        ),
                    )
                env_to_module_states = [
                    s[COMPONENT_ENV_TO_MODULE_CONNECTOR]
                    for s in connector_states
                    if COMPONENT_ENV_TO_MODULE_CONNECTOR in s
                ]
                module_to_env_states = [
                    s[COMPONENT_MODULE_TO_ENV_CONNECTOR]
                    for s in connector_states
                    if COMPONENT_MODULE_TO_ENV_CONNECTOR in s
                ]

                # Merge the per-worker states via the local EnvRunner's connectors.
                env_runner_states = {}
                if env_to_module_states:
                    env_runner_states.update(
                        {
                            COMPONENT_ENV_TO_MODULE_CONNECTOR: (
                                self.local_env_runner._env_to_module.merge_states(
                                    env_to_module_states
                                )
                            ),
                        }
                    )
                if module_to_env_states:
                    env_runner_states.update(
                        {
                            COMPONENT_MODULE_TO_ENV_CONNECTOR: (
                                self.local_env_runner._module_to_env.merge_states(
                                    module_to_env_states
                                )
                            ),
                        }
                    )
        # Ignore states from remote EnvRunners (use the current `from_worker` states
        # only).
        else:
            env_runner_states = from_worker.get_state(
                components=[
                    COMPONENT_ENV_TO_MODULE_CONNECTOR,
                    COMPONENT_MODULE_TO_ENV_CONNECTOR,
                ]
            )

        # Update the global number of environment steps, if necessary.
        # Make sure to divide by the number of env runners (such that each EnvRunner
        # knows (roughly) its own(!) lifetime count and can infer the global lifetime
        # count from it).
        if env_steps_sampled is not None:
            env_runner_states[NUM_ENV_STEPS_SAMPLED_LIFETIME] = env_steps_sampled // (
                config.num_env_runners or 1
            )

        # Update the rl_module component of the EnvRunner states, if necessary:
        if rl_module_state:
            env_runner_states.update(rl_module_state)

        # If we do NOT want remote EnvRunners to get their Connector states updated,
        # only update the local worker here (with all state components) and then remove
        # the connector components.
        if not config.update_worker_filter_stats:
            self.local_env_runner.set_state(env_runner_states)
            env_runner_states.pop(COMPONENT_ENV_TO_MODULE_CONNECTOR, None)
            env_runner_states.pop(COMPONENT_MODULE_TO_ENV_CONNECTOR, None)

        # If there are components in the state left -> Update remote workers with these
        # state components (and maybe the local worker, if it hasn't been updated yet).
        if env_runner_states:
            # Put the state dictionary into Ray's object store to avoid having to make n
            # pickled copies of the state dict.
            ref_env_runner_states = ray.put(env_runner_states)

            def _update(_env_runner: EnvRunner) -> None:
                _env_runner.set_state(ray.get(ref_env_runner_states))

            # Broadcast updated states back to all workers.
            self.foreach_env_runner(
                _update,
                remote_worker_ids=env_runner_indices_to_update,
                local_env_runner=config.update_worker_filter_stats,
                timeout_seconds=0.0,  # This is a state update -> Fire-and-forget.
            )
526
+
527
+ def sync_weights(
528
+ self,
529
+ policies: Optional[List[PolicyID]] = None,
530
+ from_worker_or_learner_group: Optional[Union[EnvRunner, "LearnerGroup"]] = None,
531
+ to_worker_indices: Optional[List[int]] = None,
532
+ global_vars: Optional[Dict[str, TensorType]] = None,
533
+ timeout_seconds: Optional[float] = 0.0,
534
+ inference_only: Optional[bool] = False,
535
+ ) -> None:
536
+ """Syncs model weights from the given weight source to all remote workers.
537
+
538
+ Weight source can be either a (local) rollout worker or a learner_group. It
539
+ should just implement a `get_weights` method.
540
+
541
+ Args:
542
+ policies: Optional list of PolicyIDs to sync weights for.
543
+ If None (default), sync weights to/from all policies.
544
+ from_worker_or_learner_group: Optional (local) EnvRunner instance or
545
+ LearnerGroup instance to sync from. If None (default),
546
+ sync from this EnvRunnerGroup's local worker.
547
+ to_worker_indices: Optional list of worker indices to sync the
548
+ weights to. If None (default), sync to all remote workers.
549
+ global_vars: An optional global vars dict to set this
550
+ worker to. If None, do not update the global_vars.
551
+ timeout_seconds: Timeout in seconds to wait for the sync weights
552
+ calls to complete. Default is 0.0 (fire-and-forget, do not wait
553
+ for any sync calls to finish). Setting this to 0.0 might significantly
554
+ improve algorithm performance, depending on the algo's `training_step`
555
+ logic.
556
+ inference_only: Sync weights with workers that keep inference-only
557
+ modules. This is needed for algorithms in the new stack that
558
+ use inference-only modules. In this case only a part of the
559
+ parameters are synced to the workers. Default is False.
560
+ """
561
+ if self.local_env_runner is None and from_worker_or_learner_group is None:
562
+ raise TypeError(
563
+ "No `local_env_runner` in EnvRunnerGroup! Must provide "
564
+ "`from_worker_or_learner_group` arg in `sync_weights()`!"
565
+ )
566
+
567
+ # Only sync if we have remote workers or `from_worker_or_trainer` is provided.
568
+ rl_module_state = None
569
+ if self.num_remote_workers() or from_worker_or_learner_group is not None:
570
+ weights_src = from_worker_or_learner_group or self.local_env_runner
571
+
572
+ if weights_src is None:
573
+ raise ValueError(
574
+ "`from_worker_or_trainer` is None. In this case, EnvRunnerGroup "
575
+ "should have local_env_runner. But local_env_runner is also None."
576
+ )
577
+
578
+ modules = (
579
+ [COMPONENT_RL_MODULE + "/" + p for p in policies]
580
+ if policies is not None
581
+ else [COMPONENT_RL_MODULE]
582
+ )
583
+ # LearnerGroup has-a Learner has-a RLModule.
584
+ if isinstance(weights_src, LearnerGroup):
585
+ rl_module_state = weights_src.get_state(
586
+ components=[COMPONENT_LEARNER + "/" + m for m in modules],
587
+ inference_only=inference_only,
588
+ )[COMPONENT_LEARNER]
589
+ # EnvRunner has-a RLModule.
590
+ elif self._remote_config.enable_env_runner_and_connector_v2:
591
+ rl_module_state = weights_src.get_state(
592
+ components=modules,
593
+ inference_only=inference_only,
594
+ )
595
+ else:
596
+ rl_module_state = weights_src.get_weights(
597
+ policies=policies,
598
+ inference_only=inference_only,
599
+ )
600
+
601
+ if self._remote_config.enable_env_runner_and_connector_v2:
602
+
603
+ # Make sure `rl_module_state` only contains the weights and the
604
+ # weight seq no, nothing else.
605
+ rl_module_state = {
606
+ k: v
607
+ for k, v in rl_module_state.items()
608
+ if k in [COMPONENT_RL_MODULE, WEIGHTS_SEQ_NO]
609
+ }
610
+
611
+ # Move weights to the object store to avoid having to make n pickled
612
+ # copies of the weights dict for each worker.
613
+ rl_module_state_ref = ray.put(rl_module_state)
614
+
615
+ def _set_weights(env_runner):
616
+ env_runner.set_state(ray.get(rl_module_state_ref))
617
+
618
+ else:
619
+ rl_module_state_ref = ray.put(rl_module_state)
620
+
621
+ def _set_weights(env_runner):
622
+ env_runner.set_weights(ray.get(rl_module_state_ref), global_vars)
623
+
624
+ # Sync to specified remote workers in this EnvRunnerGroup.
625
+ self.foreach_env_runner(
626
+ func=_set_weights,
627
+ local_env_runner=False, # Do not sync back to local worker.
628
+ remote_worker_ids=to_worker_indices,
629
+ timeout_seconds=timeout_seconds,
630
+ )
631
+
632
+ # If `from_worker_or_learner_group` is provided, also sync to this
633
+ # EnvRunnerGroup's local worker.
634
+ if self.local_env_runner is not None:
635
+ if from_worker_or_learner_group is not None:
636
+ if self._remote_config.enable_env_runner_and_connector_v2:
637
+ self.local_env_runner.set_state(rl_module_state)
638
+ else:
639
+ self.local_env_runner.set_weights(rl_module_state)
640
+ # If `global_vars` is provided and local worker exists -> Update its
641
+ # global_vars.
642
+ if global_vars is not None:
643
+ self.local_env_runner.set_global_vars(global_vars)
644
+
645
    @OldAPIStack
    def add_policy(
        self,
        policy_id: PolicyID,
        policy_cls: Optional[Type[Policy]] = None,
        policy: Optional[Policy] = None,
        *,
        observation_space: Optional[gym.spaces.Space] = None,
        action_space: Optional[gym.spaces.Space] = None,
        config: Optional[Union["AlgorithmConfig", PartialAlgorithmConfigDict]] = None,
        policy_state: Optional[PolicyState] = None,
        policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None,
        policies_to_train: Optional[
            Union[
                Collection[PolicyID],
                Callable[[PolicyID, Optional[SampleBatchType]], bool],
            ]
        ] = None,
        module_spec: Optional[RLModuleSpec] = None,
        # Deprecated.
        workers: Optional[List[Union[EnvRunner, ActorHandle]]] = DEPRECATED_VALUE,
    ) -> None:
        """Adds a policy to this EnvRunnerGroup's workers or a specific list of workers.

        Args:
            policy_id: ID of the policy to add.
            policy_cls: The Policy class to use for constructing the new Policy.
                Note: Only one of `policy_cls` or `policy` must be provided.
            policy: The Policy instance to add to this EnvRunnerGroup. If not None, the
                given Policy object will be directly inserted into the
                local worker and clones of that Policy will be created on all remote
                workers.
                Note: Only one of `policy_cls` or `policy` must be provided.
            observation_space: The observation space of the policy to add.
                If None, try to infer this space from the environment.
            action_space: The action space of the policy to add.
                If None, try to infer this space from the environment.
            config: The config object or overrides for the policy to add.
            policy_state: Optional state dict to apply to the new
                policy instance, right after its construction.
            policy_mapping_fn: An optional (updated) policy mapping function
                to use from here on. Note that already ongoing episodes will
                not change their mapping but will use the old mapping till
                the end of the episode.
            policies_to_train: An optional list of policy IDs to be trained
                or a callable taking PolicyID and SampleBatchType and
                returning a bool (trainable or not?).
                If None, will keep the existing setup in place. Policies,
                whose IDs are not in the list (or for which the callable
                returns False) will not be updated.
            module_spec: In the new RLModule API we need to pass in the module_spec for
                the new module that is supposed to be added. Knowing the policy spec is
                not sufficient.
            workers: A list of EnvRunner/ActorHandles (remote
                EnvRunners) to add this policy to. If defined, will only
                add the given policy to these workers. Deprecated: passing any
                value raises an error.

        Raises:
            KeyError: If the given `policy_id` already exists in this EnvRunnerGroup.
            ValueError: If both or neither of `policy_cls` and `policy` are given.
        """
        # Reject duplicate policy IDs early (checked against the local worker's map).
        if self.local_env_runner and policy_id in self.local_env_runner.policy_map:
            raise KeyError(
                f"Policy ID '{policy_id}' already exists in policy map! "
                "Make sure you use a Policy ID that has not been taken yet."
                " Policy IDs that are already in your policy map: "
                f"{list(self.local_env_runner.policy_map.keys())}"
            )

        # `workers` arg is hard-deprecated: error out if anything was passed.
        if workers is not DEPRECATED_VALUE:
            deprecation_warning(
                old="EnvRunnerGroup.add_policy(.., workers=..)",
                help=(
                    "The `workers` argument to `EnvRunnerGroup.add_policy()` is "
                    "deprecated! Please do not use it anymore."
                ),
                error=True,
            )

        # Exactly one of `policy_cls` / `policy` must be given (XOR check).
        if (policy_cls is None) == (policy is None):
            raise ValueError(
                "Only one of `policy_cls` or `policy` must be provided to "
                "staticmethod: `EnvRunnerGroup.add_policy()`!"
            )
        validate_module_id(policy_id, error=False)

        # Policy instance not provided: Use the information given here.
        if policy_cls is not None:
            new_policy_instance_kwargs = dict(
                policy_id=policy_id,
                policy_cls=policy_cls,
                observation_space=observation_space,
                action_space=action_space,
                config=config,
                policy_state=policy_state,
                policy_mapping_fn=policy_mapping_fn,
                policies_to_train=list(policies_to_train)
                if policies_to_train
                else None,
                module_spec=module_spec,
            )
        # Policy instance provided: Create clones of this very policy on the different
        # workers (copy all its properties here for the calls to add_policy on the
        # remote workers).
        else:
            new_policy_instance_kwargs = dict(
                policy_id=policy_id,
                policy_cls=type(policy),
                observation_space=policy.observation_space,
                action_space=policy.action_space,
                config=policy.config,
                policy_state=policy.get_state(),
                policy_mapping_fn=policy_mapping_fn,
                policies_to_train=list(policies_to_train)
                if policies_to_train
                else None,
                module_spec=module_spec,
            )

        def _create_new_policy_fn(worker):
            # `foreach_env_runner` function: Adds the policy to the worker (and
            # maybe changes its policy_mapping_fn - if provided here).
            worker.add_policy(**new_policy_instance_kwargs)

        if self.local_env_runner is not None:
            # Add policy directly by (already instantiated) object.
            if policy is not None:
                self.local_env_runner.add_policy(
                    policy_id=policy_id,
                    policy=policy,
                    policy_mapping_fn=policy_mapping_fn,
                    policies_to_train=policies_to_train,
                    module_spec=module_spec,
                )
            # Add policy by constructor kwargs.
            else:
                self.local_env_runner.add_policy(**new_policy_instance_kwargs)

        # Add the policy to all remote workers.
        self.foreach_env_runner(_create_new_policy_fn, local_env_runner=False)
784
+
785
    def add_workers(self, num_workers: int, validate: bool = False) -> None:
        """Creates and adds a number of remote workers to this worker set.

        Can be called several times on the same EnvRunnerGroup to add more
        EnvRunners to the set.

        Args:
            num_workers: The number of remote Workers to add to this
                EnvRunnerGroup.
            validate: Whether to validate remote workers after their construction
                process.

        Raises:
            RayError: If any of the constructed remote workers is not up and running
                properly (and errors are not ignored via
                `self._ignore_ray_errors_on_env_runners`).
        """
        old_num_workers = self._worker_manager.num_actors()
        # New worker indices continue after the already registered actors
        # (hence the `+ i + 1` offset).
        new_workers = [
            self._make_worker(
                cls=self._cls,
                env_creator=self._env_creator,
                validate_env=None,
                worker_index=old_num_workers + i + 1,
                num_workers=old_num_workers + num_workers,
                config=self._remote_config,
            )
            for i in range(num_workers)
        ]
        self._worker_manager.add_actors(new_workers)

        # Validate here, whether all remote workers have been constructed properly
        # and are "up and running". Establish initial states.
        if validate:
            for result in self._worker_manager.foreach_actor(
                lambda w: w.assert_healthy()
            ):
                # Simply raise the error, which will get handled by the try-except
                # clause around the _setup().
                if not result.ok:
                    e = result.get()
                    if self._ignore_ray_errors_on_env_runners:
                        logger.error(f"Validation of EnvRunner failed! Error={str(e)}")
                    else:
                        raise e
829
+
830
    def reset(self, new_remote_workers: List[ActorHandle]) -> None:
        """Hard overrides the remote EnvRunners in this set with the provided ones.

        All currently registered remote actors are dropped from the manager
        before the new ones are added; the local EnvRunner is untouched.

        Args:
            new_remote_workers: A list of new EnvRunners (as `ActorHandles`) to use as
                new remote workers.
        """
        self._worker_manager.clear()
        self._worker_manager.add_actors(new_remote_workers)
839
+
840
    def stop(self) -> None:
        """Calls `stop` on all EnvRunners (including the local one).

        Best-effort: failures while stopping are logged, not re-raised, and the
        actor manager is cleared regardless.
        """
        try:
            # Make sure we stop all EnvRunners, include the ones that were just
            # restarted / recovered or that are tagged unhealthy (at least, we should
            # try).
            self.foreach_env_runner(
                lambda w: w.stop(), healthy_only=False, local_env_runner=True
            )
        except Exception:
            logger.exception("Failed to stop workers!")
        finally:
            self._worker_manager.clear()
853
+
854
+ def is_policy_to_train(
855
+ self, policy_id: PolicyID, batch: Optional[SampleBatchType] = None
856
+ ) -> bool:
857
+ """Whether given PolicyID (optionally inside some batch) is trainable."""
858
+ if self.local_env_runner:
859
+ if self.local_env_runner.is_policy_to_train is None:
860
+ return True
861
+ return self.local_env_runner.is_policy_to_train(policy_id, batch)
862
+ else:
863
+ raise NotImplementedError
864
+
865
    def foreach_env_runner(
        self,
        func: Callable[[EnvRunner], T],
        *,
        local_env_runner: bool = True,
        healthy_only: bool = True,
        remote_worker_ids: List[int] = None,
        timeout_seconds: Optional[float] = None,
        return_obj_refs: bool = False,
        mark_healthy: bool = False,
    ) -> List[T]:
        """Calls the given function with each EnvRunner as its argument.

        Args:
            func: The function to call for each EnvRunners. The only call argument is
                the respective EnvRunner instance.
            local_env_runner: Whether to apply `func` to local EnvRunner, too.
                Default is True.
            healthy_only: Apply `func` on known-to-be healthy EnvRunners only.
            remote_worker_ids: Apply `func` on a selected set of remote EnvRunners.
                Use None (default) for all remote EnvRunners.
            timeout_seconds: Time to wait (in seconds) for results. Set this to 0.0 for
                fire-and-forget. Set this to None (default) to wait infinitely (i.e. for
                synchronous execution).
            return_obj_refs: Whether to return ObjectRef instead of actual results.
                Note, for fault tolerance reasons, these returned ObjectRefs should
                never be resolved with ray.get() outside of this EnvRunnerGroup.
            mark_healthy: Whether to mark all those EnvRunners healthy again that are
                currently marked unhealthy AND that returned results from the remote
                call (within the given `timeout_seconds`).
                Note that EnvRunners are NOT set unhealthy, if they simply time out
                (only if they return a RayActorError).
                Also note that this setting is ignored if `healthy_only=True` (b/c
                `mark_healthy` only affects EnvRunners that are currently tagged as
                unhealthy).

        Returns:
            The list of return values of all calls to `func([worker])`. The local
            worker's result (if any) comes first, followed by the remote results.
        """
        # ObjectRefs only make sense for remote calls, never for the local worker.
        assert (
            not return_obj_refs or not local_env_runner
        ), "Can not return ObjectRef from local worker."

        local_result = []
        if local_env_runner and self.local_env_runner is not None:
            local_result = [func(self.local_env_runner)]

        # No remote actors registered -> nothing more to do.
        if not self._worker_manager.actor_ids():
            return local_result

        remote_results = self._worker_manager.foreach_actor(
            func,
            healthy_only=healthy_only,
            remote_actor_ids=remote_worker_ids,
            timeout_seconds=timeout_seconds,
            return_obj_refs=return_obj_refs,
            mark_healthy=mark_healthy,
        )

        # Raise (or log, depending on config) any errors returned by remote calls.
        FaultTolerantActorManager.handle_remote_call_result_errors(
            remote_results, ignore_ray_errors=self._ignore_ray_errors_on_env_runners
        )

        # With application errors handled, return good results.
        remote_results = [r.get() for r in remote_results.ignore_errors()]

        return local_result + remote_results
932
+
933
    def foreach_env_runner_with_id(
        self,
        func: Callable[[int, EnvRunner], T],
        *,
        local_env_runner: bool = True,
        healthy_only: bool = True,
        remote_worker_ids: List[int] = None,
        timeout_seconds: Optional[float] = None,
        return_obj_refs: bool = False,
        mark_healthy: bool = False,
        # Deprecated args.
        local_worker=DEPRECATED_VALUE,  # NOTE(review): accepted but never read in this body.
    ) -> List[T]:
        """Calls the given function with each EnvRunner and its ID as its arguments.

        The local EnvRunner (if applied) always gets ID 0.

        Args:
            func: The function to call for each EnvRunners. The call arguments are
                the EnvRunner's index (int) and the respective EnvRunner instance
                itself.
            local_env_runner: Whether to apply `func` to the local EnvRunner, too.
                Default is True.
            healthy_only: Apply `func` on known-to-be healthy EnvRunners only.
            remote_worker_ids: Apply `func` on a selected set of remote EnvRunners.
            timeout_seconds: Time to wait for results. Default is None.
            return_obj_refs: Whether to return ObjectRef instead of actual results.
                Note, for fault tolerance reasons, these returned ObjectRefs should
                never be resolved with ray.get() outside of this EnvRunnerGroup.
            mark_healthy: Whether to mark all those EnvRunners healthy again that are
                currently marked unhealthy AND that returned results from the remote
                call (within the given `timeout_seconds`).
                Note that workers are NOT set unhealthy, if they simply time out
                (only if they return a RayActorError).
                Also note that this setting is ignored if `healthy_only=True` (b/c
                `mark_healthy` only affects EnvRunners that are currently tagged as
                unhealthy).

        Returns:
            The list of return values of all calls to `func([worker, id])`.
        """
        local_result = []
        if local_env_runner and self.local_env_runner is not None:
            local_result = [func(0, self.local_env_runner)]

        if not remote_worker_ids:
            remote_worker_ids = self._worker_manager.actor_ids()

        # Bind each worker's ID as the first argument of `func`.
        funcs = [functools.partial(func, i) for i in remote_worker_ids]

        remote_results = self._worker_manager.foreach_actor(
            funcs,
            healthy_only=healthy_only,
            remote_actor_ids=remote_worker_ids,
            timeout_seconds=timeout_seconds,
            return_obj_refs=return_obj_refs,
            mark_healthy=mark_healthy,
        )

        # Raise (or log, depending on config) any errors returned by remote calls.
        FaultTolerantActorManager.handle_remote_call_result_errors(
            remote_results,
            ignore_ray_errors=self._ignore_ray_errors_on_env_runners,
        )

        remote_results = [r.get() for r in remote_results.ignore_errors()]

        return local_result + remote_results
998
+
999
+ def foreach_env_runner_async(
1000
+ self,
1001
+ func: Callable[[EnvRunner], T],
1002
+ *,
1003
+ healthy_only: bool = True,
1004
+ remote_worker_ids: List[int] = None,
1005
+ ) -> int:
1006
+ """Calls the given function asynchronously with each EnvRunner as the argument.
1007
+
1008
+ Does not return results directly. Instead, `fetch_ready_async_reqs()` can be
1009
+ used to pull results in an async manner whenever they are available.
1010
+
1011
+ Args:
1012
+ func: The function to call for each EnvRunners. The only call argument is
1013
+ the respective EnvRunner instance.
1014
+ healthy_only: Apply `func` on known-to-be healthy EnvRunners only.
1015
+ remote_worker_ids: Apply `func` on a selected set of remote EnvRunners.
1016
+
1017
+ Returns:
1018
+ The number of async requests that have actually been made. This is the
1019
+ length of `remote_worker_ids` (or self.num_remote_workers()` if
1020
+ `remote_worker_ids` is None) minus the number of requests that were NOT
1021
+ made b/c a remote EnvRunner already had its
1022
+ `max_remote_requests_in_flight_per_actor` counter reached.
1023
+ """
1024
+ return self._worker_manager.foreach_actor_async(
1025
+ func,
1026
+ healthy_only=healthy_only,
1027
+ remote_actor_ids=remote_worker_ids,
1028
+ )
1029
+
1030
    def fetch_ready_async_reqs(
        self,
        *,
        timeout_seconds: Optional[float] = 0.0,
        return_obj_refs: bool = False,
        mark_healthy: bool = False,
    ) -> List[Tuple[int, T]]:
        """Gets results from outstanding asynchronous requests that are ready.

        Args:
            timeout_seconds: Time to wait for results. Default is 0, meaning
                only those requests that are already ready are returned.
            return_obj_refs: Whether to return ObjectRef instead of actual results.
            mark_healthy: Whether to mark all those workers healthy again that are
                currently marked unhealthy AND that returned results from the remote
                call (within the given `timeout_seconds`).
                Note that workers are NOT set unhealthy, if they simply time out
                (only if they return a RayActorError).

        Returns:
            A list of results successfully returned from outstanding remote calls,
            paired with the indices of the callee workers.
        """
        remote_results = self._worker_manager.fetch_ready_async_reqs(
            timeout_seconds=timeout_seconds,
            return_obj_refs=return_obj_refs,
            mark_healthy=mark_healthy,
        )

        # Raise (or log, depending on config) any errors returned by remote calls.
        FaultTolerantActorManager.handle_remote_call_result_errors(
            remote_results,
            ignore_ray_errors=self._ignore_ray_errors_on_env_runners,
        )

        return [(r.actor_id, r.get()) for r in remote_results.ignore_errors()]
1068
+
1069
+ def foreach_policy(self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
1070
+ """Calls `func` with each worker's (policy, PolicyID) tuple.
1071
+
1072
+ Note that in the multi-agent case, each worker may have more than one
1073
+ policy.
1074
+
1075
+ Args:
1076
+ func: A function - taking a Policy and its ID - that is
1077
+ called on all workers' Policies.
1078
+
1079
+ Returns:
1080
+ The list of return values of func over all workers' policies. The
1081
+ length of this list is:
1082
+ (num_workers + 1 (local-worker)) *
1083
+ [num policies in the multi-agent config dict].
1084
+ The local workers' results are first, followed by all remote
1085
+ workers' results
1086
+ """
1087
+ results = []
1088
+ for r in self.foreach_env_runner(
1089
+ lambda w: w.foreach_policy(func), local_env_runner=True
1090
+ ):
1091
+ results.extend(r)
1092
+ return results
1093
+
1094
+ def foreach_policy_to_train(self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
1095
+ """Apply `func` to all workers' Policies iff in `policies_to_train`.
1096
+
1097
+ Args:
1098
+ func: A function - taking a Policy and its ID - that is
1099
+ called on all workers' Policies, for which
1100
+ `worker.is_policy_to_train()` returns True.
1101
+
1102
+ Returns:
1103
+ List[any]: The list of n return values of all
1104
+ `func([trainable policy], [ID])`-calls.
1105
+ """
1106
+ results = []
1107
+ for r in self.foreach_env_runner(
1108
+ lambda w: w.foreach_policy_to_train(func), local_env_runner=True
1109
+ ):
1110
+ results.extend(r)
1111
+ return results
1112
+
1113
+ def foreach_env(self, func: Callable[[EnvType], List[T]]) -> List[List[T]]:
1114
+ """Calls `func` with all workers' sub-environments as args.
1115
+
1116
+ An "underlying sub environment" is a single clone of an env within
1117
+ a vectorized environment.
1118
+ `func` takes a single underlying sub environment as arg, e.g. a
1119
+ gym.Env object.
1120
+
1121
+ Args:
1122
+ func: A function - taking an EnvType (normally a gym.Env object)
1123
+ as arg and returning a list of lists of return values, one
1124
+ value per underlying sub-environment per each worker.
1125
+
1126
+ Returns:
1127
+ The list (workers) of lists (sub environments) of results.
1128
+ """
1129
+ return list(
1130
+ self.foreach_env_runner(
1131
+ lambda w: w.foreach_env(func),
1132
+ local_env_runner=True,
1133
+ )
1134
+ )
1135
+
1136
+ def foreach_env_with_context(
1137
+ self, func: Callable[[BaseEnv, EnvContext], List[T]]
1138
+ ) -> List[List[T]]:
1139
+ """Calls `func` with all workers' sub-environments and env_ctx as args.
1140
+
1141
+ An "underlying sub environment" is a single clone of an env within
1142
+ a vectorized environment.
1143
+ `func` takes a single underlying sub environment and the env_context
1144
+ as args.
1145
+
1146
+ Args:
1147
+ func: A function - taking a BaseEnv object and an EnvContext as
1148
+ arg - and returning a list of lists of return values over envs
1149
+ of the worker.
1150
+
1151
+ Returns:
1152
+ The list (1 item per workers) of lists (1 item per sub-environment)
1153
+ of results.
1154
+ """
1155
+ return list(
1156
+ self.foreach_env_runner(
1157
+ lambda w: w.foreach_env_with_context(func),
1158
+ local_env_runner=True,
1159
+ )
1160
+ )
1161
+
1162
    def probe_unhealthy_env_runners(self) -> List[int]:
        """Checks for unhealthy workers and tries restoring their states.

        Delegates to the actor manager; actors that respond within the
        configured health-probe timeout are re-tagged as healthy.

        Returns:
            List of IDs of the workers that were restored.
        """
        return self._worker_manager.probe_unhealthy_actors(
            timeout_seconds=self._remote_config.env_runner_health_probe_timeout_s,
            mark_healthy=True,
        )
1172
+
1173
    def _make_worker(
        self,
        *,
        cls: Callable,
        env_creator: EnvCreator,
        validate_env: Optional[Callable[[EnvType], None]],
        worker_index: int,
        num_workers: int,
        recreated_worker: bool = False,
        config: "AlgorithmConfig",
        spaces: Optional[
            Dict[PolicyID, Tuple[gym.spaces.Space, gym.spaces.Space]]
        ] = None,
    ) -> Union[EnvRunner, ActorHandle]:
        """Constructs a single EnvRunner via `cls`, forwarding group-level state.

        Args:
            cls: The callable (EnvRunner class or remote actor class) to
                instantiate.
            env_creator: Callable that builds the environment for the worker.
            validate_env: Optional callable used to validate the created env.
            worker_index: Index to assign to the new worker.
            num_workers: Total number of workers to report to the new worker.
            recreated_worker: Whether this worker replaces a previously failed one.
            config: The AlgorithmConfig for the new worker.
            spaces: Optional dict mapping PolicyIDs to (observation-space,
                action-space) tuples.

        Returns:
            The newly constructed worker (an ActorHandle if `cls` is a remote
            actor class — presumably; depends on the `cls` passed in).
        """
        worker = cls(
            env_creator=env_creator,
            validate_env=validate_env,
            default_policy_class=self._policy_class,
            config=config,
            worker_index=worker_index,
            num_workers=num_workers,
            recreated_worker=recreated_worker,
            log_dir=self._logdir,
            spaces=spaces,
            dataset_shards=self._ds_shards,
            tune_trial_id=self._tune_trial_id,
        )

        return worker
1202
+
1203
+ @classmethod
1204
+ def _valid_module(cls, class_path):
1205
+ del cls
1206
+ if (
1207
+ isinstance(class_path, str)
1208
+ and not os.path.isfile(class_path)
1209
+ and "." in class_path
1210
+ ):
1211
+ module_path, class_name = class_path.rsplit(".", 1)
1212
+ try:
1213
+ spec = importlib.util.find_spec(module_path)
1214
+ if spec is not None:
1215
+ return True
1216
+ except (ModuleNotFoundError, ValueError):
1217
+ print(
1218
+ f"module {module_path} not found while trying to get "
1219
+ f"input {class_path}"
1220
+ )
1221
+ return False
1222
+
1223
    # Deprecated API: old `*worker*`-named entry points. The `error=False` ones
    # forward to the renamed `*env_runner*` methods; the `error=True` ones are
    # fully removed (presumably raise on use via the decorator — bodies unused).

    @Deprecated(new="EnvRunnerGroup.probe_unhealthy_env_runners", error=False)
    def probe_unhealthy_workers(self, *args, **kwargs):
        # Forward to the renamed method.
        return self.probe_unhealthy_env_runners(*args, **kwargs)

    @Deprecated(new="EnvRunnerGroup.foreach_env_runner", error=False)
    def foreach_worker(self, *args, **kwargs):
        # Forward to the renamed method.
        return self.foreach_env_runner(*args, **kwargs)

    @Deprecated(new="EnvRunnerGroup.foreach_env_runner_with_id", error=False)
    def foreach_worker_with_id(self, *args, **kwargs):
        # Forward to the renamed method.
        return self.foreach_env_runner_with_id(*args, **kwargs)

    @Deprecated(new="EnvRunnerGroup.foreach_env_runner_async", error=False)
    def foreach_worker_async(self, *args, **kwargs):
        # Forward to the renamed method.
        return self.foreach_env_runner_async(*args, **kwargs)

    @Deprecated(new="EnvRunnerGroup.local_env_runner", error=True)
    def local_worker(self) -> EnvRunner:
        pass

    @property
    @Deprecated(
        old="_remote_workers",
        new="Use either the `foreach_env_runner()`, `foreach_env_runner_with_id()`, or "
        "`foreach_env_runner_async()` APIs of `EnvRunnerGroup`, which all handle fault "
        "tolerance.",
        error=True,
    )
    def _remote_workers(self):
        pass

    @Deprecated(
        old="remote_workers()",
        new="Use either the `foreach_env_runner()`, `foreach_env_runner_with_id()`, or "
        "`foreach_env_runner_async()` APIs of `EnvRunnerGroup`, which all handle fault "
        "tolerance.",
        error=True,
    )
    def remote_workers(self):
        pass
.venv/lib/python3.11/site-packages/ray/rllib/env/external_env.py ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ import queue
3
+ import threading
4
+ import uuid
5
+ from typing import Callable, Tuple, Optional, TYPE_CHECKING
6
+
7
+ from ray.rllib.env.base_env import BaseEnv
8
+ from ray.rllib.utils.annotations import override, OldAPIStack
9
+ from ray.rllib.utils.typing import (
10
+ EnvActionType,
11
+ EnvInfoDict,
12
+ EnvObsType,
13
+ EnvType,
14
+ MultiEnvDict,
15
+ )
16
+ from ray.rllib.utils.deprecation import deprecation_warning
17
+
18
+ if TYPE_CHECKING:
19
+ from ray.rllib.models.preprocessors import Preprocessor
20
+
21
+
22
+ @OldAPIStack
23
+ class ExternalEnv(threading.Thread):
24
+ """An environment that interfaces with external agents.
25
+
26
+ Unlike simulator envs, control is inverted: The environment queries the
27
+ policy to obtain actions and in return logs observations and rewards for
28
+ training. This is in contrast to gym.Env, where the algorithm drives the
29
+ simulation through env.step() calls.
30
+
31
+ You can use ExternalEnv as the backend for policy serving (by serving HTTP
32
+ requests in the run loop), for ingesting offline logs data (by reading
33
+ offline transitions in the run loop), or other custom use cases not easily
34
+ expressed through gym.Env.
35
+
36
+ ExternalEnv supports both on-policy actions (through self.get_action()),
37
+ and off-policy actions (through self.log_action()).
38
+
39
+ This env is thread-safe, but individual episodes must be executed serially.
40
+
41
+ .. testcode::
42
+ :skipif: True
43
+
44
+ from ray.tune import register_env
45
+ from ray.rllib.algorithms.dqn import DQN
46
+ YourExternalEnv = ...
47
+ register_env("my_env", lambda config: YourExternalEnv(config))
48
+ algo = DQN(env="my_env")
49
+ while True:
50
+ print(algo.train())
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ action_space: gym.Space,
56
+ observation_space: gym.Space,
57
+ max_concurrent: int = None,
58
+ ):
59
+ """Initializes an ExternalEnv instance.
60
+
61
+ Args:
62
+ action_space: Action space of the env.
63
+ observation_space: Observation space of the env.
64
+ """
65
+
66
+ threading.Thread.__init__(self)
67
+
68
+ self.daemon = True
69
+ self.action_space = action_space
70
+ self.observation_space = observation_space
71
+ self._episodes = {}
72
+ self._finished = set()
73
+ self._results_avail_condition = threading.Condition()
74
+ if max_concurrent is not None:
75
+ deprecation_warning(
76
+ "The `max_concurrent` argument has been deprecated. Please configure"
77
+ "the number of episodes using the `rollout_fragment_length` and"
78
+ "`batch_mode` arguments. Please raise an issue on the Ray Github if "
79
+ "these arguments do not support your expected use case for ExternalEnv",
80
+ error=True,
81
+ )
82
+
83
    def run(self):
        """Override this to implement the run loop.

        Your loop should continuously:
        1. Call self.start_episode(episode_id)
        2. Call self.[get|log]_action(episode_id, obs, [action]?)
        3. Call self.log_returns(episode_id, reward)
        4. Call self.end_episode(episode_id, obs)
        5. Wait if nothing to do.

        Multiple episodes may be started at the same time.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must
                override.
        """
        raise NotImplementedError
96
+
97
+ def start_episode(
98
+ self, episode_id: Optional[str] = None, training_enabled: bool = True
99
+ ) -> str:
100
+ """Record the start of an episode.
101
+
102
+ Args:
103
+ episode_id: Unique string id for the episode or
104
+ None for it to be auto-assigned and returned.
105
+ training_enabled: Whether to use experiences for this
106
+ episode to improve the policy.
107
+
108
+ Returns:
109
+ Unique string id for the episode.
110
+ """
111
+
112
+ if episode_id is None:
113
+ episode_id = uuid.uuid4().hex
114
+
115
+ if episode_id in self._finished:
116
+ raise ValueError("Episode {} has already completed.".format(episode_id))
117
+
118
+ if episode_id in self._episodes:
119
+ raise ValueError("Episode {} is already started".format(episode_id))
120
+
121
+ self._episodes[episode_id] = _ExternalEnvEpisode(
122
+ episode_id, self._results_avail_condition, training_enabled
123
+ )
124
+
125
+ return episode_id
126
+
127
+ def get_action(self, episode_id: str, observation: EnvObsType) -> EnvActionType:
128
+ """Record an observation and get the on-policy action.
129
+
130
+ Args:
131
+ episode_id: Episode id returned from start_episode().
132
+ observation: Current environment observation.
133
+
134
+ Returns:
135
+ Action from the env action space.
136
+ """
137
+
138
+ episode = self._get(episode_id)
139
+ return episode.wait_for_action(observation)
140
+
141
def log_action(
    self, episode_id: str, observation: EnvObsType, action: EnvActionType
) -> None:
    """Records an observation together with an (off-policy) action taken.

    Args:
        episode_id: Episode id returned from start_episode().
        observation: Current environment observation.
        action: The (externally chosen) action for the observation.
    """
    self._get(episode_id).log_action(observation, action)
154
+
155
def log_returns(
    self, episode_id: str, reward: float, info: Optional[EnvInfoDict] = None
) -> None:
    """Records returns (rewards and infos) from the environment.

    The reward is attributed to the previous action taken in the episode.
    Rewards accumulate until the next action; if no reward is logged before
    the next action, a reward of 0.0 is assumed.

    Args:
        episode_id: Episode id returned from start_episode().
        reward: Reward from the environment.
        info: Optional info dict (replaces the currently stored one).
    """
    ep = self._get(episode_id)
    # Accumulate rewards between actions.
    ep.cur_reward += reward
    # An empty/None info dict leaves the previously stored one in place.
    if info:
        ep.cur_info = info
175
+
176
def end_episode(self, episode_id: str, observation: EnvObsType) -> None:
    """Records the end of an episode.

    Args:
        episode_id: Episode id returned from start_episode().
        observation: Final environment observation.
    """
    ep = self._get(episode_id)
    # Mark finished first, so the id can no longer be re-used or re-fetched.
    self._finished.add(ep.episode_id)
    ep.done(observation)
187
+
188
+ def _get(self, episode_id: str) -> "_ExternalEnvEpisode":
189
+ """Get a started episode by its ID or raise an error."""
190
+
191
+ if episode_id in self._finished:
192
+ raise ValueError("Episode {} has already completed.".format(episode_id))
193
+
194
+ if episode_id not in self._episodes:
195
+ raise ValueError("Episode {} not found.".format(episode_id))
196
+
197
+ return self._episodes[episode_id]
198
+
199
def to_base_env(
    self,
    make_env: Optional[Callable[[int], EnvType]] = None,
    num_envs: int = 1,
    remote_envs: bool = False,
    remote_env_batch_wait_ms: int = 0,
    restart_failed_sub_environments: bool = False,
) -> "BaseEnv":
    """Converts this ExternalEnv into a BaseEnv object.

    BaseEnv supports async execution via the `poll` and `send_actions`
    methods and thus supports external simulators.

    Args:
        make_env: A callable taking an int as input (the number of
            individual sub-environments within the final vectorized
            BaseEnv) and returning one individual sub-environment.
            Unused here (no vectorization support).
        num_envs: The number of sub-environments in the resulting
            BaseEnv. Must be 1 for External(MultiAgent)Env.
        remote_envs: Whether each sub-env should be a @ray.remote actor.
            Unused here.
        remote_env_batch_wait_ms: The wait time (in ms) to poll remote
            sub-environments for, if applicable. Unused here.
        restart_failed_sub_environments: Whether to restart failed
            sub-environments. Unused here.

    Returns:
        The resulting BaseEnv object.

    Raises:
        ValueError: If `num_envs` != 1 (vectorization is not supported).
    """
    if num_envs != 1:
        raise ValueError(
            "External(MultiAgent)Env does not currently support "
            "num_envs > 1. One way of solving this would be to "
            "treat your Env as a MultiAgentEnv hosting only one "
            "type of agent but with several copies."
        )
    return ExternalEnvWrapper(self)
242
+
243
+
244
@OldAPIStack
class _ExternalEnvEpisode:
    """Tracked state for one active (single- or multi-agent) episode.

    Data flows between two threads: the user's `run()` loop pushes
    observation/reward items into `data_queue` (via `_send`) and blocks on
    `action_queue` for the next action, while the RLlib side (the wrapping
    BaseEnv's `poll`/`send_actions`) consumes `data_queue` and feeds
    `action_queue`.
    """

    def __init__(
        self,
        episode_id: str,
        results_avail_condition: threading.Condition,
        training_enabled: bool,
        multiagent: bool = False,
    ):
        self.episode_id = episode_id
        self.results_avail_condition = results_avail_condition
        self.training_enabled = training_enabled
        self.multiagent = multiagent
        # Producer/consumer queues connecting the user thread and RLlib.
        self.data_queue = queue.Queue()
        self.action_queue = queue.Queue()
        if multiagent:
            self.new_observation_dict = None
            self.new_action_dict = None
            self.cur_reward_dict = {}
            self.cur_terminated_dict = {"__all__": False}
            self.cur_truncated_dict = {"__all__": False}
            self.cur_info_dict = {}
        else:
            self.new_observation = None
            self.new_action = None
            self.cur_reward = 0.0
            self.cur_terminated = False
            self.cur_truncated = False
            self.cur_info = {}

    def get_data(self):
        """Returns the next pending data item, or None if there is none."""
        if self.data_queue.empty():
            return None
        return self.data_queue.get_nowait()

    def log_action(self, observation, action):
        """Records an off-policy (externally chosen) action for `observation`."""
        if self.multiagent:
            self.new_observation_dict = observation
            self.new_action_dict = action
        else:
            self.new_observation = observation
            self.new_action = action
        self._send()
        # Block until RLlib acknowledges by echoing an action back.
        self.action_queue.get(True, timeout=60.0)

    def wait_for_action(self, observation):
        """Records `observation` and blocks until RLlib provides an action."""
        if self.multiagent:
            self.new_observation_dict = observation
        else:
            self.new_observation = observation
        self._send()
        return self.action_queue.get(True, timeout=300.0)

    def done(self, observation):
        """Marks the episode as terminated with the final `observation`."""
        if self.multiagent:
            self.new_observation_dict = observation
            self.cur_terminated_dict = {"__all__": True}
            # TODO(sven): External env API does not currently support truncated,
            # but we should deprecate external Env anyways in favor of a
            # client-only approach.
            self.cur_truncated_dict = {"__all__": False}
        else:
            self.new_observation = observation
            self.cur_terminated = True
            self.cur_truncated = False
        self._send()

    def _send(self):
        """Packages the current step's data and enqueues it for RLlib."""
        if self.multiagent:
            if not self.training_enabled:
                # Tag each agent's info so RLlib skips training on this data.
                for agent_id in self.cur_info_dict:
                    self.cur_info_dict[agent_id]["training_enabled"] = False
            item = {
                "obs": self.new_observation_dict,
                "reward": self.cur_reward_dict,
                "terminated": self.cur_terminated_dict,
                "truncated": self.cur_truncated_dict,
                "info": self.cur_info_dict,
            }
            if self.new_action_dict is not None:
                item["off_policy_action"] = self.new_action_dict
            # Reset per-step state; rewards re-accumulate until next action.
            self.new_observation_dict = None
            self.new_action_dict = None
            self.cur_reward_dict = {}
        else:
            item = {
                "obs": self.new_observation,
                "reward": self.cur_reward,
                "terminated": self.cur_terminated,
                "truncated": self.cur_truncated,
                "info": self.cur_info,
            }
            if self.new_action is not None:
                item["off_policy_action"] = self.new_action
            self.new_observation = None
            self.new_action = None
            self.cur_reward = 0.0
            if not self.training_enabled:
                item["info"]["training_enabled"] = False

        # Wake up any poller blocked on the shared condition variable.
        with self.results_avail_condition:
            self.data_queue.put_nowait(item)
            self.results_avail_condition.notify()
349
+
350
+
351
@OldAPIStack
class ExternalEnvWrapper(BaseEnv):
    """Internal adapter of ExternalEnv to BaseEnv."""

    def __init__(
        self, external_env: "ExternalEnv", preprocessor: "Preprocessor" = None
    ):
        # Imported here to avoid a circular import at module load time.
        from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv

        self.external_env = external_env
        self.prep = preprocessor
        self.multiagent = issubclass(type(external_env), ExternalMultiAgentEnv)
        self._action_space = external_env.action_space
        if preprocessor:
            self._observation_space = preprocessor.observation_space
        else:
            self._observation_space = external_env.observation_space
        # Kick off the user's `run()` serving thread.
        external_env.start()

    @override(BaseEnv)
    def poll(
        self,
    ) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
        with self.external_env._results_avail_condition:
            results = self._poll()
            # Block until at least one episode has produced an observation.
            while len(results[0]) == 0:
                self.external_env._results_avail_condition.wait()
                results = self._poll()
                if not self.external_env.is_alive():
                    raise Exception("Serving thread has stopped.")
        return results

    @override(BaseEnv)
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        from ray.rllib.env.base_env import _DUMMY_AGENT_ID

        if self.multiagent:
            for env_id, actions in action_dict.items():
                self.external_env._episodes[env_id].action_queue.put(actions)
        else:
            # Single-agent: unwrap the dummy agent id before forwarding.
            for env_id, action in action_dict.items():
                self.external_env._episodes[env_id].action_queue.put(
                    action[_DUMMY_AGENT_ID]
                )

    def _poll(
        self,
    ) -> Tuple[
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
    ]:
        """Drains all pending episode data into per-env dicts (non-blocking)."""
        from ray.rllib.env.base_env import with_dummy_agent_id

        all_obs = {}
        all_rewards = {}
        all_terminateds = {}
        all_truncateds = {}
        all_infos = {}
        off_policy_actions = {}
        # Iterate over a copy, since finished episodes are deleted inline.
        for eid, episode in self.external_env._episodes.copy().items():
            data = episode.get_data()
            if self.multiagent:
                cur_terminated = episode.cur_terminated_dict["__all__"]
                cur_truncated = episode.cur_truncated_dict["__all__"]
            else:
                cur_terminated = episode.cur_terminated
                cur_truncated = episode.cur_truncated
            if cur_terminated or cur_truncated:
                del self.external_env._episodes[eid]
            if data:
                if self.prep:
                    all_obs[eid] = self.prep.transform(data["obs"])
                else:
                    all_obs[eid] = data["obs"]
                all_rewards[eid] = data["reward"]
                all_terminateds[eid] = data["terminated"]
                all_truncateds[eid] = data["truncated"]
                all_infos[eid] = data["info"]
                if "off_policy_action" in data:
                    off_policy_actions[eid] = data["off_policy_action"]
        if self.multiagent:
            # Ensure a consistent set of keys; rely on all_obs having all
            # possible keys for now.
            for eid, eid_dict in all_obs.items():
                for agent_id in eid_dict.keys():

                    def fix(d, zero_val):
                        if agent_id not in d[eid]:
                            d[eid][agent_id] = zero_val

                    fix(all_rewards, 0.0)
                    fix(all_terminateds, False)
                    fix(all_truncateds, False)
                    fix(all_infos, {})
            return (
                all_obs,
                all_rewards,
                all_terminateds,
                all_truncateds,
                all_infos,
                off_policy_actions,
            )
        else:
            return (
                with_dummy_agent_id(all_obs),
                with_dummy_agent_id(all_rewards),
                with_dummy_agent_id(all_terminateds, "__all__"),
                with_dummy_agent_id(all_truncateds, "__all__"),
                with_dummy_agent_id(all_infos),
                with_dummy_agent_id(off_policy_actions),
            )

    @property
    @override(BaseEnv)
    def observation_space(self) -> gym.spaces.Dict:
        return self._observation_space

    @property
    @override(BaseEnv)
    def action_space(self) -> gym.Space:
        return self._action_space
.venv/lib/python3.11/site-packages/ray/rllib/env/external_multi_agent_env.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ import gymnasium as gym
3
+ from typing import Optional
4
+
5
+ from ray.rllib.utils.annotations import override, OldAPIStack
6
+ from ray.rllib.env.external_env import ExternalEnv, _ExternalEnvEpisode
7
+ from ray.rllib.utils.typing import MultiAgentDict
8
+
9
+
10
@OldAPIStack
class ExternalMultiAgentEnv(ExternalEnv):
    """This is the multi-agent version of ExternalEnv."""

    def __init__(
        self,
        action_space: gym.Space,
        observation_space: gym.Space,
    ):
        """Initializes an ExternalMultiAgentEnv instance.

        Args:
            action_space: Action space of the env.
            observation_space: Observation space of the env.

        Raises:
            ValueError: If dict-based spaces are given whose agent-id key
                sets do not match.
        """
        ExternalEnv.__init__(self, action_space, observation_space)

        # We require to know all agents' spaces: when dict-based spaces are
        # given, their agent-id keys must agree.
        if isinstance(self.action_space, dict) or isinstance(
            self.observation_space, dict
        ):
            if not (self.action_space.keys() == self.observation_space.keys()):
                raise ValueError(
                    "Agent ids disagree for action space and obs "
                    "space dict: {} {}".format(
                        self.action_space.keys(), self.observation_space.keys()
                    )
                )

    def run(self):
        """Override this to implement the multi-agent run loop.

        Your loop should continuously:
        1. Call self.start_episode(episode_id)
        2. Call self.get_action(episode_id, obs_dict)
           -or-
           self.log_action(episode_id, obs_dict, action_dict)
        3. Call self.log_returns(episode_id, reward_dict)
        4. Call self.end_episode(episode_id, obs_dict)
        5. Wait if nothing to do.

        Multiple episodes may be started at the same time.
        """
        raise NotImplementedError

    @override(ExternalEnv)
    def start_episode(
        self, episode_id: Optional[str] = None, training_enabled: bool = True
    ) -> str:
        """Records the start of a (multi-agent) episode; returns its id."""
        if episode_id is None:
            episode_id = uuid.uuid4().hex

        if episode_id in self._finished:
            raise ValueError("Episode {} has already completed.".format(episode_id))

        if episode_id in self._episodes:
            raise ValueError("Episode {} is already started".format(episode_id))

        # Same as the single-agent version, except `multiagent=True`.
        self._episodes[episode_id] = _ExternalEnvEpisode(
            episode_id, self._results_avail_condition, training_enabled, multiagent=True
        )

        return episode_id

    @override(ExternalEnv)
    def get_action(
        self, episode_id: str, observation_dict: MultiAgentDict
    ) -> MultiAgentDict:
        """Record an observation and get the on-policy action.

        `observation_dict` is expected to contain the observation
        of all agents acting in this episode step.

        Args:
            episode_id: Episode id returned from start_episode().
            observation_dict: Current environment observation.

        Returns:
            action: Action dict from the env action space.
        """
        episode = self._get(episode_id)
        return episode.wait_for_action(observation_dict)

    @override(ExternalEnv)
    def log_action(
        self,
        episode_id: str,
        observation_dict: MultiAgentDict,
        action_dict: MultiAgentDict,
    ) -> None:
        """Record an observation and (off-policy) action taken.

        Args:
            episode_id: Episode id returned from start_episode().
            observation_dict: Current environment observation.
            action_dict: Action for the observation.
        """
        episode = self._get(episode_id)
        episode.log_action(observation_dict, action_dict)

    @override(ExternalEnv)
    def log_returns(
        self,
        episode_id: str,
        reward_dict: MultiAgentDict,
        info_dict: MultiAgentDict = None,
        multiagent_done_dict: MultiAgentDict = None,
    ) -> None:
        """Record returns from the environment.

        The reward will be attributed to the previous action taken by the
        episode. Rewards accumulate until the next action. If no reward is
        logged before the next action, a reward of 0.0 is assumed.

        Args:
            episode_id: Episode id returned from start_episode().
            reward_dict: Reward from the environment agents.
            info_dict: Optional info dict.
            multiagent_done_dict: Optional done dict for agents.
        """
        episode = self._get(episode_id)

        # Accumulate reward by agent.
        # For existing agents, we want to add the reward up.
        for agent, rew in reward_dict.items():
            if agent in episode.cur_reward_dict:
                episode.cur_reward_dict[agent] += rew
            else:
                episode.cur_reward_dict[agent] = rew

        if multiagent_done_dict:
            for agent, done in multiagent_done_dict.items():
                # Bug fix: `_ExternalEnvEpisode` only defines
                # `cur_terminated_dict` (there is no `cur_done_dict`
                # attribute), so writing to `cur_done_dict` raised an
                # AttributeError whenever a done dict was supplied.
                episode.cur_terminated_dict[agent] = done

        if info_dict:
            episode.cur_info_dict = info_dict or {}

    @override(ExternalEnv)
    def end_episode(self, episode_id: str, observation_dict: MultiAgentDict) -> None:
        """Record the end of an episode.

        Args:
            episode_id: Episode id returned from start_episode().
            observation_dict: Current environment observation.
        """
        episode = self._get(episode_id)
        self._finished.add(episode.episode_id)
        episode.done(observation_dict)
.venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_env.py ADDED
@@ -0,0 +1,799 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ import logging
3
+ from typing import Callable, Dict, List, Tuple, Optional, Union, Set, Type
4
+
5
+ import numpy as np
6
+
7
+ from ray.rllib.env.base_env import BaseEnv
8
+ from ray.rllib.env.env_context import EnvContext
9
+ from ray.rllib.utils.annotations import OldAPIStack, override
10
+ from ray.rllib.utils.deprecation import Deprecated
11
+ from ray.rllib.utils.typing import (
12
+ AgentID,
13
+ EnvCreator,
14
+ EnvID,
15
+ EnvType,
16
+ MultiAgentDict,
17
+ MultiEnvDict,
18
+ )
19
+ from ray.util import log_once
20
+ from ray.util.annotations import DeveloperAPI, PublicAPI
21
+
22
+ # If the obs space is Dict type, look for the global state under this key.
23
+ ENV_STATE = "state"
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
@PublicAPI(stability="beta")
class MultiAgentEnv(gym.Env):
    """An environment that hosts multiple independent agents.

    Agents are identified by AgentIDs (string).
    """

    # Optional mappings from AgentID to individual agents' spaces.
    # Set these to "exhaustive" dicts, mapping all possible AgentIDs to
    # individual agents' spaces. Alternatively, override
    # `get_observation_space(agent_id=...)` and `get_action_space(agent_id=...)`,
    # which is the API that RLlib uses to get individual spaces and whose
    # default implementation is to simply look up `agent_id` in these dicts.
    observation_spaces: Optional[Dict[AgentID, gym.Space]] = None
    action_spaces: Optional[Dict[AgentID, gym.Space]] = None

    # All agents currently active in the environment. This attribute may change
    # during the lifetime of the env or even during an individual episode.
    agents: List[AgentID] = []
    # All agents that may appear in the environment, ever.
    # This attribute should not be changed during the lifetime of this env.
    possible_agents: List[AgentID] = []

    # @OldAPIStack, use `observation_spaces` and `action_spaces`, instead.
    observation_space: Optional[gym.Space] = None
    action_space: Optional[gym.Space] = None

    def __init__(self):
        super().__init__()

        # @OldAPIStack
        if not hasattr(self, "_agent_ids"):
            self._agent_ids = set()

        # If the (new API stack) agent-list attributes were not set by the
        # subclass, try to infer them from the old `_agent_ids` set.
        if not self.agents:
            self.agents = list(self._agent_ids)
        if not self.possible_agents:
            self.possible_agents = self.agents.copy()

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> Tuple[MultiAgentDict, MultiAgentDict]:  # type: ignore
        """Resets the env and returns observations from ready agents.

        Args:
            seed: An optional seed to use for the new episode.
            options: An optional options dict, forwarded to gym.Env.reset().

        Returns:
            Tuple of new observations and infos, each a dict keyed by the
            ready agents' ids, e.g.
            `{"car_0": [2.4, 1.6], "car_1": [3.4, -3.2]}`.
        """
        # Call super's `reset()` method to (maybe) set the given `seed`.
        super().reset(seed=seed, options=options)

    def step(
        self, action_dict: MultiAgentDict
    ) -> Tuple[
        MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict
    ]:
        """Returns observations from ready agents.

        The returns are dicts mapping from agent_id strings to values. The
        number of agents in the env can vary over time.

        Returns:
            Tuple containing 1) new observations for each ready agent,
            2) reward values for each ready agent (None if the episode has
            just started), 3) terminated values for each ready agent, where
            the special key "__all__" (required) indicates env termination,
            4) truncated values for each ready agent, and 5) info values for
            each agent id (may be empty dicts).
        """
        raise NotImplementedError

    def render(self) -> None:
        """Tries to render the environment."""
        # By default, do nothing.
        pass

    def get_observation_space(self, agent_id: AgentID) -> gym.Space:
        """Returns the observation space for the given `agent_id`."""
        if self.observation_spaces is not None:
            return self.observation_spaces[agent_id]

        # @OldAPIStack behavior: if `self.observation_space` is a
        # `gym.spaces.Dict` containing `agent_id`, return that sub-space;
        # otherwise the single defined space is most likely meant to apply to
        # all agents.
        space = self.observation_space
        if isinstance(space, gym.spaces.Dict) and agent_id in space.spaces:
            return space[agent_id]
        return space

    def get_action_space(self, agent_id: AgentID) -> gym.Space:
        """Returns the action space for the given `agent_id`."""
        if self.action_spaces is not None:
            return self.action_spaces[agent_id]

        # @OldAPIStack behavior: if `self.action_space` is a
        # `gym.spaces.Dict` containing `agent_id`, return that sub-space;
        # otherwise the single defined space is most likely meant to apply to
        # all agents.
        space = self.action_space
        if isinstance(space, gym.spaces.Dict) and agent_id in space.spaces:
            return space[agent_id]
        return space

    @property
    def num_agents(self) -> int:
        """Number of agents currently active in the env."""
        return len(self.agents)

    @property
    def max_num_agents(self) -> int:
        """Number of agents that may ever appear in the env."""
        return len(self.possible_agents)

    # fmt: off
    # __grouping_doc_begin__
    def with_agent_groups(
        self,
        groups: Dict[str, List[AgentID]],
        obs_space: gym.Space = None,
        act_space: gym.Space = None,
    ) -> "MultiAgentEnv":
        """Convenience method for grouping together agents in this env.

        An agent group is a list of agent IDs that are mapped to a single
        logical agent. All agents of the group must act at the same time in
        the environment. The grouped agent exposes Tuple action and
        observation spaces that are the concatenated action and obs spaces
        of the individual agents.

        The rewards of all the agents in a group are summed. The individual
        agent rewards are available under the "individual_rewards" key of
        the group info return.

        Agent grouping is required to leverage algorithms such as Q-Mix.

        Args:
            groups: Mapping from group id to a list of the agent ids of
                group members. If an agent id is not present in any group
                value, it will be left ungrouped. The group id becomes a new
                agent ID in the final environment.
            obs_space: Optional observation space for the grouped env. Must
                be a tuple space. If not provided, will infer this to be a
                Tuple of n individual agents spaces (n=num agents in a
                group).
            act_space: Optional action space for the grouped env. Must be a
                tuple space. If not provided, will infer this to be a Tuple
                of n individual agents spaces (n=num agents in a group).
        """

        from ray.rllib.env.wrappers.group_agents_wrapper import \
            GroupAgentsWrapper
        return GroupAgentsWrapper(self, groups, obs_space, act_space)

    # __grouping_doc_end__
    # fmt: on

    @OldAPIStack
    @Deprecated(new="MultiAgentEnv.possible_agents", error=False)
    def get_agent_ids(self) -> Set[AgentID]:
        if not hasattr(self, "_agent_ids"):
            self._agent_ids = set()
        if not isinstance(self._agent_ids, set):
            self._agent_ids = set(self._agent_ids)
        # Make this backward compatible as much as possible.
        return self._agent_ids if self._agent_ids else set(self.agents)

    @OldAPIStack
    def to_base_env(
        self,
        make_env: Optional[Callable[[int], EnvType]] = None,
        num_envs: int = 1,
        remote_envs: bool = False,
        remote_env_batch_wait_ms: int = 0,
        restart_failed_sub_environments: bool = False,
    ) -> "BaseEnv":
        """Converts an RLlib MultiAgentEnv into a BaseEnv object.

        The resulting BaseEnv is always vectorized (contains n
        sub-environments) to support batched forward passes, where n may
        also be 1. BaseEnv also supports async execution via the `poll` and
        `send_actions` methods and thus supports external simulators.

        Args:
            make_env: A callable taking an int as input (which indicates the
                number of individual sub-environments within the final
                vectorized BaseEnv) and returning one individual
                sub-environment.
            num_envs: The number of sub-environments to create in the
                resulting (vectorized) BaseEnv. The already existing `env`
                will be one of the `num_envs`.
            remote_envs: Whether each sub-env should be a @ray.remote actor.
                You can set this behavior in your config via the
                `remote_worker_envs=True` option.
            remote_env_batch_wait_ms: The wait time (in ms) to poll remote
                sub-environments for, if applicable. Only used if
                `remote_envs` is True.
            restart_failed_sub_environments: If True and any sub-environment
                (within a vectorized env) throws any error during env
                stepping, we will try to restart the faulty sub-environment.
                This is done without disturbing the other (still intact)
                sub-environments.

        Returns:
            The resulting BaseEnv object.
        """
        from ray.rllib.env.remote_base_env import RemoteBaseEnv

        if remote_envs:
            env = RemoteBaseEnv(
                make_env,
                num_envs,
                multiagent=True,
                remote_env_batch_wait_ms=remote_env_batch_wait_ms,
                restart_failed_sub_environments=restart_failed_sub_environments,
            )
        # Sub-environments are not ray.remote actors.
        else:
            env = MultiAgentEnvWrapper(
                make_env=make_env,
                existing_envs=[self],
                num_envs=num_envs,
                restart_failed_sub_environments=restart_failed_sub_environments,
            )

        return env
327
+
328
+
329
+ @DeveloperAPI
330
+ def make_multi_agent(
331
+ env_name_or_creator: Union[str, EnvCreator],
332
+ ) -> Type["MultiAgentEnv"]:
333
+ """Convenience wrapper for any single-agent env to be converted into MA.
334
+
335
+ Allows you to convert a simple (single-agent) `gym.Env` class
336
+ into a `MultiAgentEnv` class. This function simply stacks n instances
337
+ of the given ```gym.Env``` class into one unified ``MultiAgentEnv`` class
338
+ and returns this class, thus pretending the agents act together in the
339
+ same environment, whereas - under the hood - they live separately from
340
+ each other in n parallel single-agent envs.
341
+
342
+ Agent IDs in the resulting and are int numbers starting from 0
343
+ (first agent).
344
+
345
+ Args:
346
+ env_name_or_creator: String specifier or env_maker function taking
347
+ an EnvContext object as only arg and returning a gym.Env.
348
+
349
+ Returns:
350
+ New MultiAgentEnv class to be used as env.
351
+ The constructor takes a config dict with `num_agents` key
352
+ (default=1). The rest of the config dict will be passed on to the
353
+ underlying single-agent env's constructor.
354
+
355
+ .. testcode::
356
+ :skipif: True
357
+
358
+ from ray.rllib.env.multi_agent_env import make_multi_agent
359
+ # By gym string:
360
+ ma_cartpole_cls = make_multi_agent("CartPole-v1")
361
+ # Create a 2 agent multi-agent cartpole.
362
+ ma_cartpole = ma_cartpole_cls({"num_agents": 2})
363
+ obs = ma_cartpole.reset()
364
+ print(obs)
365
+
366
+ # By env-maker callable:
367
+ from ray.rllib.examples.envs.classes.stateless_cartpole import StatelessCartPole
368
+ ma_stateless_cartpole_cls = make_multi_agent(
369
+ lambda config: StatelessCartPole(config))
370
+ # Create a 3 agent multi-agent stateless cartpole.
371
+ ma_stateless_cartpole = ma_stateless_cartpole_cls(
372
+ {"num_agents": 3})
373
+ print(obs)
374
+
375
+ .. testoutput::
376
+
377
+ {0: [...], 1: [...]}
378
+ {0: [...], 1: [...], 2: [...]}
379
+ """
380
+
381
    class MultiEnv(MultiAgentEnv):
        """n independent single-agent envs presented as one `MultiAgentEnv`.

        Agent IDs are the ints 0..n-1; agent i's transitions come entirely from
        `self.envs[i]`. The agents never interact with each other.
        """

        def __init__(self, config: EnvContext = None):
            super().__init__()

            # Note: Explicitly check for None here, because config
            # can have an empty dict but meaningful data fields (worker_index,
            # vector_index) etc.
            # TODO (sven): Clean this up, so we are not mixing up dict fields
            # with data fields.
            if config is None:
                config = {}
            # NOTE(review): `pop` mutates the caller's config dict (removes the
            # "num_agents" key) - confirm no caller relies on it afterwards.
            num = config.pop("num_agents", 1)
            # String specifier -> make via gym registry; else treat as a
            # callable env-maker taking the (remaining) config as only arg.
            if isinstance(env_name_or_creator, str):
                self.envs = [gym.make(env_name_or_creator) for _ in range(num)]
            else:
                self.envs = [env_name_or_creator(config) for _ in range(num)]
            # Sets of agent IDs whose sub-env has terminated/truncated since
            # the last reset; used to compute the "__all__" flags in `step()`.
            self.terminateds = set()
            self.truncateds = set()
            # Per-agent spaces mirror the underlying single-agent envs 1:1.
            self.observation_spaces = {
                i: self.envs[i].observation_space for i in range(num)
            }
            self.action_spaces = {i: self.envs[i].action_space for i in range(num)}
            self.agents = list(range(num))
            self.possible_agents = self.agents.copy()

        @override(MultiAgentEnv)
        def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
            # Resets ALL sub-envs and returns per-agent obs/infos dicts.
            # NOTE(review): the same `seed` is passed to every sub-env, so all
            # agents start from identical states when a seed is given - confirm
            # this is intended.
            self.terminateds = set()
            self.truncateds = set()
            obs, infos = {}, {}
            for i, env in enumerate(self.envs):
                obs[i], infos[i] = env.reset(seed=seed, options=options)

            return obs, infos

        @override(MultiAgentEnv)
        def step(self, action_dict):
            # Steps only those sub-envs whose agent IDs appear in `action_dict`
            # and returns per-agent dicts plus the "__all__" done flags.
            obs, rew, terminated, truncated, info = {}, {}, {}, {}, {}

            # The environment is expecting an action for at least one agent.
            if len(action_dict) == 0:
                raise ValueError(
                    "The environment is expecting an action for at least one agent."
                )

            for i, action in action_dict.items():
                obs[i], rew[i], terminated[i], truncated[i], info[i] = self.envs[
                    i
                ].step(action)
                if terminated[i]:
                    self.terminateds.add(i)
                if truncated[i]:
                    self.truncateds.add(i)
            # TODO: Flaw in our MultiAgentEnv API wrt. new gymnasium: Need to return
            # an additional episode_done bool that covers cases where all agents are
            # either terminated or truncated, but not all are truncated and not all are
            # terminated. We can then get rid of the aweful `__all__` special keys!
            # "__all__" terminated: every sub-env is done (terminated OR
            # truncated); "__all__" truncated: every sub-env truncated.
            terminated["__all__"] = len(self.terminateds) + len(self.truncateds) == len(
                self.envs
            )
            truncated["__all__"] = len(self.truncateds) == len(self.envs)
            return obs, rew, terminated, truncated, info

        @override(MultiAgentEnv)
        def render(self):
            # This render method simply renders all n underlying individual single-agent
            # envs and concatenates their images (on top of each other if the returned
            # images have dims where [width] > [height], otherwise next to each other).
            render_images = [e.render() for e in self.envs]
            if render_images[0].shape[1] > render_images[0].shape[0]:
                concat_dim = 0
            else:
                concat_dim = 1
            return np.concatenate(render_images, axis=concat_dim)

    return MultiEnv
457
+
458
+
459
@OldAPIStack
class MultiAgentEnvWrapper(BaseEnv):
    """Internal adapter of MultiAgentEnv to BaseEnv.

    This also supports vectorization if num_envs > 1.

    Each sub-env is tracked by a `_MultiAgentEnvState` (in `self.env_states`)
    that buffers step results between `send_actions()` and `poll()` calls.
    """

    def __init__(
        self,
        make_env: Callable[[int], EnvType],
        existing_envs: List["MultiAgentEnv"],
        num_envs: int,
        restart_failed_sub_environments: bool = False,
    ):
        """Wraps MultiAgentEnv(s) into the BaseEnv API.

        Args:
            make_env: Factory that produces a new MultiAgentEnv instance taking the
                vector index as only call argument.
                Must be defined, if the number of existing envs is less than num_envs.
            existing_envs: List of already existing multi-agent envs.
            num_envs: Desired num multiagent envs to have at the end in
                total. This will include the given (already created)
                `existing_envs`.
            restart_failed_sub_environments: If True and any sub-environment (within
                this vectorized env) throws any error during env stepping, we will try
                to restart the faulty sub-environment. This is done
                without disturbing the other (still intact) sub-environments.
        """
        self.make_env = make_env
        self.envs = existing_envs
        self.num_envs = num_envs
        self.restart_failed_sub_environments = restart_failed_sub_environments

        # Env indices whose episode is fully done; they must be reset via
        # `try_reset()` before accepting new actions.
        self.terminateds = set()
        self.truncateds = set()
        # Top up `self.envs` to `num_envs` using the factory.
        while len(self.envs) < self.num_envs:
            self.envs.append(self.make_env(len(self.envs)))
        for env in self.envs:
            assert isinstance(env, MultiAgentEnv)
        self._init_env_state(idx=None)
        self._unwrapped_env = self.envs[0].unwrapped

    @override(BaseEnv)
    def poll(
        self,
    ) -> Tuple[
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
    ]:
        # Collects the buffered (obs, rewards, terminateds, truncateds, infos)
        # from every sub-env state; the trailing {} is the (unused here)
        # off-policy-actions dict of the BaseEnv API.
        obs, rewards, terminateds, truncateds, infos = {}, {}, {}, {}, {}
        for i, env_state in enumerate(self.env_states):
            (
                obs[i],
                rewards[i],
                terminateds[i],
                truncateds[i],
                infos[i],
            ) = env_state.poll()
        return obs, rewards, terminateds, truncateds, infos, {}

    @override(BaseEnv)
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        # Steps each addressed sub-env with its per-agent action dict and
        # buffers the results in the respective `_MultiAgentEnvState`.
        for env_id, agent_dict in action_dict.items():
            if env_id in self.terminateds or env_id in self.truncateds:
                raise ValueError(
                    f"Env {env_id} is already done and cannot accept new actions"
                )
            env = self.envs[env_id]
            try:
                obs, rewards, terminateds, truncateds, infos = env.step(agent_dict)
            except Exception as e:
                if self.restart_failed_sub_environments:
                    logger.exception(e.args[0])
                    self.try_restart(env_id=env_id)
                    # Propagate the error object in place of the obs dict and
                    # mark the episode as terminated.
                    obs = e
                    rewards = {}
                    terminateds = {"__all__": True}
                    truncateds = {"__all__": False}
                    infos = {}
                else:
                    raise e

            assert isinstance(
                obs, (dict, Exception)
            ), "Not a multi-agent obs dict or an Exception!"
            assert isinstance(rewards, dict), "Not a multi-agent reward dict!"
            assert isinstance(terminateds, dict), "Not a multi-agent terminateds dict!"
            assert isinstance(truncateds, dict), "Not a multi-agent truncateds dict!"
            assert isinstance(infos, dict), "Not a multi-agent info dict!"
            # Validate the env's returned dicts (only possible if the step
            # succeeded and `obs` is a real dict).
            if isinstance(obs, dict):
                info_diff = set(infos).difference(set(obs))
                if info_diff and info_diff != {"__common__"}:
                    raise ValueError(
                        "Key set for infos must be a subset of obs (plus optionally "
                        "the '__common__' key for infos concerning all/no agents): "
                        "{} vs {}".format(infos.keys(), obs.keys())
                    )
                if "__all__" not in terminateds:
                    raise ValueError(
                        "In multi-agent environments, '__all__': True|False must "
                        "be included in the 'terminateds' dict: got {}.".format(terminateds)
                    )
                elif "__all__" not in truncateds:
                    raise ValueError(
                        "In multi-agent environments, '__all__': True|False must "
                        "be included in the 'truncateds' dict: got {}.".format(truncateds)
                    )

            if terminateds["__all__"]:
                self.terminateds.add(env_id)
            if truncateds["__all__"]:
                self.truncateds.add(env_id)
            self.env_states[env_id].observe(
                obs, rewards, terminateds, truncateds, infos
            )

    @override(BaseEnv)
    def try_reset(
        self,
        env_id: Optional[EnvID] = None,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> Optional[Tuple[MultiEnvDict, MultiEnvDict]]:
        # Resets one, several, or (env_id=None) all sub-envs and returns the
        # per-env reset obs/infos dicts.
        ret_obs = {}
        ret_infos = {}
        if isinstance(env_id, int):
            env_id = [env_id]
        if env_id is None:
            env_id = list(range(len(self.envs)))
        for idx in env_id:
            obs, infos = self.env_states[idx].reset(seed=seed, options=options)

            if isinstance(obs, Exception):
                # Reset itself failed; optionally replace the faulty sub-env.
                if self.restart_failed_sub_environments:
                    self.env_states[idx].env = self.envs[idx] = self.make_env(idx)
                else:
                    raise obs
            else:
                assert isinstance(obs, dict), "Not a multi-agent obs dict!"
            if obs is not None:
                # Successful reset -> this env is no longer marked done.
                if idx in self.terminateds:
                    self.terminateds.remove(idx)
                if idx in self.truncateds:
                    self.truncateds.remove(idx)
                ret_obs[idx] = obs
                ret_infos[idx] = infos
        return ret_obs, ret_infos

    @override(BaseEnv)
    def try_restart(self, env_id: Optional[EnvID] = None) -> None:
        # Replaces one, several, or (env_id=None) all sub-envs with freshly
        # created instances (used after a sub-env crashed).
        if isinstance(env_id, int):
            env_id = [env_id]
        if env_id is None:
            env_id = list(range(len(self.envs)))
        for idx in env_id:
            # Try closing down the old (possibly faulty) sub-env, but ignore errors.
            try:
                self.envs[idx].close()
            except Exception as e:
                if log_once("close_sub_env"):
                    logger.warning(
                        "Trying to close old and replaced sub-environment (at vector "
                        f"index={idx}), but closing resulted in error:\n{e}"
                    )
            # Try recreating the sub-env.
            logger.warning(f"Trying to restart sub-environment at index {idx}.")
            self.env_states[idx].env = self.envs[idx] = self.make_env(idx)
            logger.warning(f"Sub-environment at index {idx} restarted successfully.")

    @override(BaseEnv)
    def get_sub_environments(
        self, as_dict: bool = False
    ) -> Union[Dict[str, EnvType], List[EnvType]]:
        # Returns the wrapped sub-envs, keyed by vector index if `as_dict`.
        if as_dict:
            return {_id: env_state.env for _id, env_state in enumerate(self.env_states)}
        return [state.env for state in self.env_states]

    @override(BaseEnv)
    def try_render(self, env_id: Optional[EnvID] = None) -> None:
        # Renders a single sub-env (index 0 by default).
        if env_id is None:
            env_id = 0
        assert isinstance(env_id, int)
        return self.envs[env_id].render()

    @property
    @override(BaseEnv)
    def observation_space(self) -> gym.spaces.Dict:
        # All sub-envs are assumed to share the first sub-env's space.
        return self.envs[0].observation_space

    @property
    @override(BaseEnv)
    def action_space(self) -> gym.Space:
        # All sub-envs are assumed to share the first sub-env's space.
        return self.envs[0].action_space

    @override(BaseEnv)
    def get_agent_ids(self) -> Set[AgentID]:
        return self.envs[0].get_agent_ids()

    def _init_env_state(self, idx: Optional[int] = None) -> None:
        """Resets all or one particular sub-environment's state (by index).

        Args:
            idx: The index to reset at. If None, reset all the sub-environments' states.
        """
        # If index is None, reset all sub-envs' states:
        if idx is None:
            self.env_states = [
                _MultiAgentEnvState(env, self.restart_failed_sub_environments)
                for env in self.envs
            ]
        # Index provided, reset only the sub-env's state at the given index.
        else:
            assert isinstance(idx, int)
            self.env_states[idx] = _MultiAgentEnvState(
                self.envs[idx], self.restart_failed_sub_environments
            )
681
+
682
+
683
@OldAPIStack
class _MultiAgentEnvState:
    """Buffers one sub-env's step results between `observe()` and `poll()`.

    `observe()` accumulates rewards and OR-combines done flags per agent;
    `poll()` releases the buffered data (only for agents that currently have
    an observation, unless the episode is done or errored, in which case
    everything is flushed).
    """

    def __init__(self, env: MultiAgentEnv, return_error_as_obs: bool = False):
        assert isinstance(env, MultiAgentEnv)
        self.env = env
        # If True, a failing `env.reset()` yields the Exception object as obs
        # instead of raising.
        self.return_error_as_obs = return_error_as_obs

        # Lazily reset on first `poll()`.
        self.initialized = False
        # Buffered per-agent data since the last `poll()`.
        self.last_obs = {}
        self.last_rewards = {}
        self.last_terminateds = {"__all__": False}
        self.last_truncateds = {"__all__": False}
        self.last_infos = {}

    def poll(
        self,
    ) -> Tuple[
        MultiAgentDict,
        MultiAgentDict,
        MultiAgentDict,
        MultiAgentDict,
        MultiAgentDict,
    ]:
        if not self.initialized:
            # TODO(sven): Should we make it possible to pass in a seed here?
            self.reset()
            self.initialized = True

        observations = self.last_obs
        rewards = {}
        terminateds = {"__all__": self.last_terminateds["__all__"]}
        truncateds = {"__all__": self.last_truncateds["__all__"]}
        infos = self.last_infos

        # If episode is done or we have an error, release everything we have.
        if (
            terminateds["__all__"]
            or truncateds["__all__"]
            or isinstance(observations, Exception)
        ):
            rewards = self.last_rewards
            self.last_rewards = {}
            terminateds = self.last_terminateds
            if isinstance(observations, Exception):
                # An errored step counts as a terminated episode.
                terminateds["__all__"] = True
                truncateds["__all__"] = False
            self.last_terminateds = {}
            truncateds = self.last_truncateds
            self.last_truncateds = {}
            self.last_obs = {}
            infos = self.last_infos
            self.last_infos = {}
        # Only release those agents' rewards/terminateds/truncateds/infos, whose
        # observations we have.
        else:
            for ag in observations.keys():
                if ag in self.last_rewards:
                    rewards[ag] = self.last_rewards[ag]
                    del self.last_rewards[ag]
                if ag in self.last_terminateds:
                    terminateds[ag] = self.last_terminateds[ag]
                    del self.last_terminateds[ag]
                if ag in self.last_truncateds:
                    truncateds[ag] = self.last_truncateds[ag]
                    del self.last_truncateds[ag]

        # NOTE: in the done/error branch above, `terminateds`/`truncateds`
        # alias the old buffered dicts, so these writes reset the RETURNED
        # "__all__" flags too (the buffers themselves were replaced by {}).
        self.last_terminateds["__all__"] = False
        self.last_truncateds["__all__"] = False
        return observations, rewards, terminateds, truncateds, infos

    def observe(
        self,
        obs: MultiAgentDict,
        rewards: MultiAgentDict,
        terminateds: MultiAgentDict,
        truncateds: MultiAgentDict,
        infos: MultiAgentDict,
    ):
        # Buffer a new step result: obs/infos replace, rewards accumulate,
        # done flags OR-combine per agent.
        self.last_obs = obs
        for ag, r in rewards.items():
            if ag in self.last_rewards:
                self.last_rewards[ag] += r
            else:
                self.last_rewards[ag] = r
        for ag, d in terminateds.items():
            if ag in self.last_terminateds:
                self.last_terminateds[ag] = self.last_terminateds[ag] or d
            else:
                self.last_terminateds[ag] = d
        for ag, t in truncateds.items():
            if ag in self.last_truncateds:
                self.last_truncateds[ag] = self.last_truncateds[ag] or t
            else:
                self.last_truncateds[ag] = t
        self.last_infos = infos

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> Tuple[MultiAgentDict, MultiAgentDict]:
        # Resets the wrapped env and re-initializes all buffers. On failure,
        # either raises or (if `return_error_as_obs`) stores/returns the
        # Exception object as both obs and infos.
        try:
            obs_and_infos = self.env.reset(seed=seed, options=options)
        except Exception as e:
            if self.return_error_as_obs:
                logger.exception(e.args[0])
                obs_and_infos = e, e
            else:
                raise e

        self.last_obs, self.last_infos = obs_and_infos
        self.last_rewards = {}
        self.last_terminateds = {"__all__": False}
        self.last_truncateds = {"__all__": False}

        return self.last_obs, self.last_infos
.venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_env_runner.py ADDED
@@ -0,0 +1,1107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from functools import partial
3
+ import logging
4
+ import time
5
+ from typing import Collection, DefaultDict, Dict, List, Optional, Union
6
+
7
+ import gymnasium as gym
8
+
9
+ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
10
+ from ray.rllib.callbacks.utils import make_callback
11
+ from ray.rllib.core import (
12
+ COMPONENT_ENV_TO_MODULE_CONNECTOR,
13
+ COMPONENT_MODULE_TO_ENV_CONNECTOR,
14
+ COMPONENT_RL_MODULE,
15
+ )
16
+ from ray.rllib.core.columns import Columns
17
+ from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule, MultiRLModuleSpec
18
+ from ray.rllib.env.env_context import EnvContext
19
+ from ray.rllib.env.env_runner import EnvRunner, ENV_STEP_FAILURE
20
+ from ray.rllib.env.multi_agent_env import MultiAgentEnv
21
+ from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
22
+ from ray.rllib.env.utils import _gym_env_creator
23
+ from ray.rllib.utils import force_list
24
+ from ray.rllib.utils.annotations import override
25
+ from ray.rllib.utils.checkpoints import Checkpointable
26
+ from ray.rllib.utils.deprecation import Deprecated
27
+ from ray.rllib.utils.framework import get_device, try_import_torch
28
+ from ray.rllib.utils.metrics import (
29
+ EPISODE_DURATION_SEC_MEAN,
30
+ EPISODE_LEN_MAX,
31
+ EPISODE_LEN_MEAN,
32
+ EPISODE_LEN_MIN,
33
+ EPISODE_RETURN_MAX,
34
+ EPISODE_RETURN_MEAN,
35
+ EPISODE_RETURN_MIN,
36
+ NUM_AGENT_STEPS_SAMPLED,
37
+ NUM_AGENT_STEPS_SAMPLED_LIFETIME,
38
+ NUM_ENV_STEPS_SAMPLED,
39
+ NUM_ENV_STEPS_SAMPLED_LIFETIME,
40
+ NUM_EPISODES,
41
+ NUM_EPISODES_LIFETIME,
42
+ NUM_MODULE_STEPS_SAMPLED,
43
+ NUM_MODULE_STEPS_SAMPLED_LIFETIME,
44
+ TIME_BETWEEN_SAMPLING,
45
+ WEIGHTS_SEQ_NO,
46
+ )
47
+ from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
48
+ from ray.rllib.utils.pre_checks.env import check_multiagent_environments
49
+ from ray.rllib.utils.typing import EpisodeID, ModelWeights, ResultDict, StateDict
50
+ from ray.tune.registry import ENV_CREATOR, _global_registry
51
+ from ray.util.annotations import PublicAPI
52
+
53
+ torch, _ = try_import_torch()
54
+ logger = logging.getLogger("ray.rllib")
55
+
56
+
57
+ # TODO (sven): As soon as RolloutWorker is no longer supported, make `EnvRunner` itself
58
+ # a Checkpointable. Currently, only some of its subclasses are Checkpointables.
59
+ @PublicAPI(stability="alpha")
60
+ class MultiAgentEnvRunner(EnvRunner, Checkpointable):
61
    """The generic environment runner for the multi-agent case."""
62
+
63
    @override(EnvRunner)
    def __init__(self, config: AlgorithmConfig, **kwargs):
        """Initializes a MultiAgentEnvRunner instance.

        Args:
            config: An `AlgorithmConfig` object containing all settings needed to
                build this `EnvRunner` class.

        Raises:
            ValueError: If `config` is not a multi-agent config.
        """
        super().__init__(config=config)

        # Raise an Error, if the provided config is not a multi-agent one.
        if not self.config.is_multi_agent:
            raise ValueError(
                f"Cannot use this EnvRunner class ({type(self).__name__}), if your "
                "setup is not multi-agent! Try adding multi-agent information to your "
                "AlgorithmConfig via calling the `config.multi_agent(policies=..., "
                "policy_mapping_fn=...)`."
            )

        # Get the worker index on which this instance is running.
        self.worker_index: int = kwargs.get("worker_index")
        self.tune_trial_id: str = kwargs.get("tune_trial_id")

        # Set up all metrics-related structures and counters.
        self.metrics: Optional[MetricsLogger] = None
        self._setup_metrics()

        # Create our callbacks object.
        self._callbacks = [cls() for cls in force_list(self.config.callbacks_class)]

        # Set device.
        # Local worker (index 0 / None) gets no GPU share.
        self._device = get_device(
            self.config,
            0 if not self.worker_index else self.config.num_gpus_per_env_runner,
        )

        # Create the vectorized gymnasium env.
        self.env: Optional[gym.Wrapper] = None
        self.num_envs: int = 0
        self.make_env()

        # Create the env-to-module connector pipeline.
        self._env_to_module = self.config.build_env_to_module_connector(
            self.env.unwrapped, device=self._device
        )
        # Cached env-to-module results taken at the end of a `_sample_timesteps()`
        # call to make sure the final observation (before an episode cut) gets properly
        # processed (and maybe postprocessed and re-stored into the episode).
        # For example, if we had a connector that normalizes observations and directly
        # re-inserts these new obs back into the episode, the last observation in each
        # sample call would NOT be processed, which could be very harmful in cases,
        # in which value function bootstrapping of those (truncation) observations is
        # required in the learning step.
        self._cached_to_module = None

        # Construct the MultiRLModule.
        self.module: Optional[MultiRLModule] = None
        self.make_module()

        # Create the module-to-env connector pipeline.
        self._module_to_env = self.config.build_module_to_env_connector(
            self.env.unwrapped
        )

        # Force an env reset on the first `sample()` call.
        self._needs_initial_reset: bool = True
        # The single ongoing episode (this runner does not vectorize envs yet).
        self._episode: Optional[MultiAgentEpisode] = None
        # Dict shared between the connector pipelines within one sample run.
        self._shared_data = None

        # Monotonic counter of weight updates received via `set_state`.
        self._weights_seq_no: int = 0

        # Measures the time passed between returning from `sample()`
        # and receiving the next `sample()` request from the user.
        self._time_after_sampling = None
136
+
137
+ @override(EnvRunner)
138
+ def sample(
139
+ self,
140
+ *,
141
+ num_timesteps: int = None,
142
+ num_episodes: int = None,
143
+ explore: bool = None,
144
+ random_actions: bool = False,
145
+ force_reset: bool = False,
146
+ ) -> List[MultiAgentEpisode]:
147
+ """Runs and returns a sample (n timesteps or m episodes) on the env(s).
148
+
149
+ Args:
150
+ num_timesteps: The number of timesteps to sample during this call.
151
+ Note that only one of `num_timetseps` or `num_episodes` may be provided.
152
+ num_episodes: The number of episodes to sample during this call.
153
+ Note that only one of `num_timetseps` or `num_episodes` may be provided.
154
+ explore: If True, will use the RLModule's `forward_exploration()`
155
+ method to compute actions. If False, will use the RLModule's
156
+ `forward_inference()` method. If None (default), will use the `explore`
157
+ boolean setting from `self.config` passed into this EnvRunner's
158
+ constructor. You can change this setting in your config via
159
+ `config.env_runners(explore=True|False)`.
160
+ random_actions: If True, actions will be sampled randomly (from the action
161
+ space of the environment). If False (default), actions or action
162
+ distribution parameters are computed by the RLModule.
163
+ force_reset: Whether to force-reset all (vector) environments before
164
+ sampling. Useful if you would like to collect a clean slate of new
165
+ episodes via this call. Note that when sampling n episodes
166
+ (`num_episodes != None`), this is fixed to True.
167
+
168
+ Returns:
169
+ A list of `MultiAgentEpisode` instances, carrying the sampled data.
170
+ """
171
+ assert not (num_timesteps is not None and num_episodes is not None)
172
+
173
+ # Log time between `sample()` requests.
174
+ if self._time_after_sampling is not None:
175
+ self.metrics.log_value(
176
+ key=TIME_BETWEEN_SAMPLING,
177
+ value=time.perf_counter() - self._time_after_sampling,
178
+ )
179
+
180
+ # If no execution details are provided, use the config to try to infer the
181
+ # desired timesteps/episodes to sample and the exploration behavior.
182
+ if explore is None:
183
+ explore = self.config.explore
184
+ if num_timesteps is None and num_episodes is None:
185
+ if self.config.batch_mode == "truncate_episodes":
186
+ num_timesteps = self.config.get_rollout_fragment_length(
187
+ worker_index=self.worker_index,
188
+ )
189
+ else:
190
+ num_episodes = 1
191
+
192
+ # Sample n timesteps.
193
+ if num_timesteps is not None:
194
+ samples = self._sample_timesteps(
195
+ num_timesteps=num_timesteps,
196
+ explore=explore,
197
+ random_actions=random_actions,
198
+ force_reset=force_reset,
199
+ )
200
+ # Sample m episodes.
201
+ else:
202
+ samples = self._sample_episodes(
203
+ num_episodes=num_episodes,
204
+ explore=explore,
205
+ random_actions=random_actions,
206
+ )
207
+
208
+ # Make the `on_sample_end` callback.
209
+ make_callback(
210
+ "on_sample_end",
211
+ callbacks_objects=self._callbacks,
212
+ callbacks_functions=self.config.callbacks_on_sample_end,
213
+ kwargs=dict(
214
+ env_runner=self,
215
+ metrics_logger=self.metrics,
216
+ samples=samples,
217
+ ),
218
+ )
219
+
220
+ self._time_after_sampling = time.perf_counter()
221
+
222
+ return samples
223
+
224
+ def _sample_timesteps(
225
+ self,
226
+ num_timesteps: int,
227
+ explore: bool,
228
+ random_actions: bool = False,
229
+ force_reset: bool = False,
230
+ ) -> List[MultiAgentEpisode]:
231
+ """Helper method to sample n timesteps.
232
+
233
+ Args:
234
+ num_timesteps: int. Number of timesteps to sample during rollout.
235
+ explore: boolean. If in exploration or inference mode. Exploration
236
            mode might for some algorithms provide extra model outputs that
237
+ are redundant in inference mode.
238
+ random_actions: boolean. If actions should be sampled from the action
239
                space. In default mode (i.e. `False`) we sample actions from the
240
+ policy.
241
+
242
+ Returns:
243
            Lists of `MultiAgentEpisode` instances, carrying the collected sample data.
244
+ """
245
+ done_episodes_to_return: List[MultiAgentEpisode] = []
246
+
247
+ # Have to reset the env.
248
+ if force_reset or self._needs_initial_reset:
249
+ # Create n new episodes and make the `on_episode_created` callbacks.
250
+ self._episode = self._new_episode()
251
+ self._make_on_episode_callback("on_episode_created")
252
+
253
+ # Erase all cached ongoing episodes (these will never be completed and
254
+ # would thus never be returned/cleaned by `get_metrics` and cause a memory
255
+ # leak).
256
+ self._ongoing_episodes_for_metrics.clear()
257
+
258
+ # Try resetting the environment.
259
+ # TODO (simon): Check, if we need here the seed from the config.
260
+ obs, infos = self._try_env_reset()
261
+
262
+ self._cached_to_module = None
263
+
264
+ # Call `on_episode_start()` callbacks.
265
+ self._make_on_episode_callback("on_episode_start")
266
+
267
+ # We just reset the env. Don't have to force this again in the next
268
+ # call to `self._sample_timesteps()`.
269
+ self._needs_initial_reset = False
270
+
271
+ # Set the initial observations in the episodes.
272
+ self._episode.add_env_reset(observations=obs, infos=infos)
273
+
274
+ self._shared_data = {
275
+ "agent_to_module_mapping_fn": self.config.policy_mapping_fn,
276
+ }
277
+
278
+ # Loop through timesteps.
279
+ ts = 0
280
+
281
+ while ts < num_timesteps:
282
+ # Act randomly.
283
+ if random_actions:
284
+ # Only act (randomly) for those agents that had an observation.
285
+ to_env = {
286
+ Columns.ACTIONS: [
287
+ {
288
+ aid: self.env.unwrapped.get_action_space(aid).sample()
289
+ for aid in self._episode.get_agents_to_act()
290
+ }
291
+ ]
292
+ }
293
+ # Compute an action using the RLModule.
294
+ else:
295
+ # Env-to-module connector.
296
+ to_module = self._cached_to_module or self._env_to_module(
297
+ rl_module=self.module,
298
+ episodes=[self._episode],
299
+ explore=explore,
300
+ shared_data=self._shared_data,
301
+ metrics=self.metrics,
302
+ )
303
+ self._cached_to_module = None
304
+
305
+ # MultiRLModule forward pass: Explore or not.
306
+ if explore:
307
+ env_steps_lifetime = (
308
+ self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0)
309
+ + self.metrics.peek(NUM_ENV_STEPS_SAMPLED, default=0)
310
+ ) * (self.config.num_env_runners or 1)
311
+ to_env = self.module.forward_exploration(
312
+ to_module, t=env_steps_lifetime
313
+ )
314
+ else:
315
+ to_env = self.module.forward_inference(to_module)
316
+
317
+ # Module-to-env connector.
318
+ to_env = self._module_to_env(
319
+ rl_module=self.module,
320
+ batch=to_env,
321
+ episodes=[self._episode],
322
+ explore=explore,
323
+ shared_data=self._shared_data,
324
+ metrics=self.metrics,
325
+ )
326
+
327
+ # Extract the (vectorized) actions (to be sent to the env) from the
328
+ # module/connector output. Note that these actions are fully ready (e.g.
329
+ # already unsquashed/clipped) to be sent to the environment) and might not
330
+ # be identical to the actions produced by the RLModule/distribution, which
331
+ # are the ones stored permanently in the episode objects.
332
+ actions = to_env.pop(Columns.ACTIONS)
333
+ actions_for_env = to_env.pop(Columns.ACTIONS_FOR_ENV, actions)
334
+
335
+ # Try stepping the environment.
336
+ # TODO (sven): [0] = actions is vectorized, but env is NOT a vector Env.
337
+ # Support vectorized multi-agent envs.
338
+ results = self._try_env_step(actions_for_env[0])
339
+ # If any failure occurs during stepping -> Throw away all data collected
340
+ # thus far and restart sampling procedure.
341
+ if results == ENV_STEP_FAILURE:
342
+ return self._sample_timesteps(
343
+ num_timesteps=num_timesteps,
344
+ explore=explore,
345
+ random_actions=random_actions,
346
+ force_reset=True,
347
+ )
348
+ obs, rewards, terminateds, truncateds, infos = results
349
+
350
+ # TODO (sven): This simple approach to re-map `to_env` from a
351
+ # dict[col, List[MADict]] to a dict[agentID, MADict] would not work for
352
+ # a vectorized env.
353
+ extra_model_outputs = defaultdict(dict)
354
+ for col, ma_dict_list in to_env.items():
355
+ # TODO (sven): Support vectorized MA env.
356
+ ma_dict = ma_dict_list[0]
357
+ for agent_id, val in ma_dict.items():
358
+ extra_model_outputs[agent_id][col] = val
359
+ extra_model_outputs[agent_id][WEIGHTS_SEQ_NO] = self._weights_seq_no
360
+ extra_model_outputs = dict(extra_model_outputs)
361
+
362
+ # Record the timestep in the episode instance.
363
+ self._episode.add_env_step(
364
+ obs,
365
+ actions[0],
366
+ rewards,
367
+ infos=infos,
368
+ terminateds=terminateds,
369
+ truncateds=truncateds,
370
+ extra_model_outputs=extra_model_outputs,
371
+ )
372
+
373
+ ts += self._increase_sampled_metrics(self.num_envs, obs, self._episode)
374
+
375
+ # Make the `on_episode_step` callback (before finalizing the episode
376
+ # object).
377
+ self._make_on_episode_callback("on_episode_step")
378
+
379
+ # Episode is done for all agents. Wrap up the old one and create a new
380
+ # one (and reset it) to continue.
381
+ if self._episode.is_done:
382
+ # We have to perform an extra env-to-module pass here, just in case
383
+ # the user's connector pipeline performs (permanent) transforms
384
+ # on each observation (including this final one here). Without such
385
+ # a call and in case the structure of the observations change
386
+ # sufficiently, the following `to_numpy()` call on the episode will
387
+ # fail.
388
+ if self.module is not None:
389
+ self._env_to_module(
390
+ episodes=[self._episode],
391
+ explore=explore,
392
+ rl_module=self.module,
393
+ shared_data=self._shared_data,
394
+ metrics=self.metrics,
395
+ )
396
+
397
+ # Make the `on_episode_end` callback (before finalizing the episode,
398
+ # but after(!) the last env-to-module connector call has been made.
399
+ # -> All obs (even the terminal one) should have been processed now (by
400
+ # the connector, if applicable).
401
+ self._make_on_episode_callback("on_episode_end")
402
+
403
+ self._prune_zero_len_sa_episodes(self._episode)
404
+
405
+ # Numpy'ize the episode.
406
+ if self.config.episodes_to_numpy:
407
+ done_episodes_to_return.append(self._episode.to_numpy())
408
+ # Leave episode as lists of individual (obs, action, etc..) items.
409
+ else:
410
+ done_episodes_to_return.append(self._episode)
411
+
412
+ # Create a new episode instance.
413
+ self._episode = self._new_episode()
414
+ self._make_on_episode_callback("on_episode_created")
415
+
416
+ # Reset the environment.
417
+ obs, infos = self._try_env_reset()
418
+ # Add initial observations and infos.
419
+ self._episode.add_env_reset(observations=obs, infos=infos)
420
+
421
+ # Make the `on_episode_start` callback.
422
+ self._make_on_episode_callback("on_episode_start")
423
+
424
+ # Already perform env-to-module connector call for next call to
425
+ # `_sample_timesteps()`. See comment in c'tor for `self._cached_to_module`.
426
+ if self.module is not None:
427
+ self._cached_to_module = self._env_to_module(
428
+ rl_module=self.module,
429
+ episodes=[self._episode],
430
+ explore=explore,
431
+ shared_data=self._shared_data,
432
+ metrics=self.metrics,
433
+ )
434
+
435
+ # Store done episodes for metrics.
436
+ self._done_episodes_for_metrics.extend(done_episodes_to_return)
437
+
438
+ # Also, make sure we start new episode chunks (continuing the ongoing episodes
439
+ # from the to-be-returned chunks).
440
+ ongoing_episode_continuation = self._episode.cut(
441
+ len_lookback_buffer=self.config.episode_lookback_horizon
442
+ )
443
+
444
+ ongoing_episodes_to_return = []
445
+ # Just started Episodes do not have to be returned. There is no data
446
+ # in them anyway.
447
+ if self._episode.env_t > 0:
448
+ self._episode.validate()
449
+ self._ongoing_episodes_for_metrics[self._episode.id_].append(self._episode)
450
+
451
+ self._prune_zero_len_sa_episodes(self._episode)
452
+
453
+ # Numpy'ize the episode.
454
+ if self.config.episodes_to_numpy:
455
+ ongoing_episodes_to_return.append(self._episode.to_numpy())
456
+ # Leave episode as lists of individual (obs, action, etc..) items.
457
+ else:
458
+ ongoing_episodes_to_return.append(self._episode)
459
+
460
+ # Continue collecting into the cut Episode chunk.
461
+ self._episode = ongoing_episode_continuation
462
+
463
+ # Return collected episode data.
464
+ return done_episodes_to_return + ongoing_episodes_to_return
465
+
466
+ def _sample_episodes(
467
+ self,
468
+ num_episodes: int,
469
+ explore: bool,
470
+ random_actions: bool = False,
471
+ ) -> List[MultiAgentEpisode]:
472
+ """Helper method to run n episodes.
473
+
474
+ See docstring of `self.sample()` for more details.
475
+ """
476
+ # If user calls sample(num_timesteps=..) after this, we must reset again
477
+ # at the beginning.
478
+ self._needs_initial_reset = True
479
+
480
+ done_episodes_to_return: List[MultiAgentEpisode] = []
481
+
482
+ # Create a new multi-agent episode.
483
+ _episode = self._new_episode()
484
+ self._make_on_episode_callback("on_episode_created", _episode)
485
+ _shared_data = {
486
+ "agent_to_module_mapping_fn": self.config.policy_mapping_fn,
487
+ }
488
+
489
+ # Try resetting the environment.
490
+ # TODO (simon): Check, if we need here the seed from the config.
491
+ obs, infos = self._try_env_reset()
492
+ # Set initial obs and infos in the episodes.
493
+ _episode.add_env_reset(observations=obs, infos=infos)
494
+ self._make_on_episode_callback("on_episode_start", _episode)
495
+
496
+ # Loop over episodes.
497
+ eps = 0
498
+ ts = 0
499
+ while eps < num_episodes:
500
+ # Act randomly.
501
+ if random_actions:
502
+ # Only act (randomly) for those agents that had an observation.
503
+ to_env = {
504
+ Columns.ACTIONS: [
505
+ {
506
+ aid: self.env.unwrapped.get_action_space(aid).sample()
507
+ for aid in self._episode.get_agents_to_act()
508
+ }
509
+ ]
510
+ }
511
+ # Compute an action using the RLModule.
512
+ else:
513
+ # Env-to-module connector.
514
+ to_module = self._env_to_module(
515
+ rl_module=self.module,
516
+ episodes=[_episode],
517
+ explore=explore,
518
+ shared_data=_shared_data,
519
+ metrics=self.metrics,
520
+ )
521
+
522
+ # MultiRLModule forward pass: Explore or not.
523
+ if explore:
524
+ env_steps_lifetime = (
525
+ self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0)
526
+ + self.metrics.peek(NUM_ENV_STEPS_SAMPLED, default=0)
527
+ ) * (self.config.num_env_runners or 1)
528
+ to_env = self.module.forward_exploration(
529
+ to_module, t=env_steps_lifetime
530
+ )
531
+ else:
532
+ to_env = self.module.forward_inference(to_module)
533
+
534
+ # Module-to-env connector.
535
+ to_env = self._module_to_env(
536
+ rl_module=self.module,
537
+ batch=to_env,
538
+ episodes=[_episode],
539
+ explore=explore,
540
+ shared_data=_shared_data,
541
+ metrics=self.metrics,
542
+ )
543
+
544
+ # Extract the (vectorized) actions (to be sent to the env) from the
545
+ # module/connector output. Note that these actions are fully ready (e.g.
546
+ # already unsquashed/clipped) to be sent to the environment) and might not
547
+ # be identical to the actions produced by the RLModule/distribution, which
548
+ # are the ones stored permanently in the episode objects.
549
+ actions = to_env.pop(Columns.ACTIONS)
550
+ actions_for_env = to_env.pop(Columns.ACTIONS_FOR_ENV, actions)
551
+
552
+ # Try stepping the environment.
553
+ # TODO (sven): [0] = actions is vectorized, but env is NOT a vector Env.
554
+ # Support vectorized multi-agent envs.
555
+ results = self._try_env_step(actions_for_env[0])
556
+ # If any failure occurs during stepping -> Throw away all data collected
557
+ # thus far and restart sampling procedure.
558
+ if results == ENV_STEP_FAILURE:
559
+ return self._sample_episodes(
560
+ num_episodes=num_episodes,
561
+ explore=explore,
562
+ random_actions=random_actions,
563
+ )
564
+ obs, rewards, terminateds, truncateds, infos = results
565
+
566
+ # TODO (sven): This simple approach to re-map `to_env` from a
567
+ # dict[col, List[MADict]] to a dict[agentID, MADict] would not work for
568
+ # a vectorized env.
569
+ extra_model_outputs = defaultdict(dict)
570
+ for col, ma_dict_list in to_env.items():
571
+ # TODO (sven): Support vectorized MA env.
572
+ ma_dict = ma_dict_list[0]
573
+ for agent_id, val in ma_dict.items():
574
+ extra_model_outputs[agent_id][col] = val
575
+ extra_model_outputs[agent_id][WEIGHTS_SEQ_NO] = self._weights_seq_no
576
+ extra_model_outputs = dict(extra_model_outputs)
577
+
578
+ # Record the timestep in the episode instance.
579
+ _episode.add_env_step(
580
+ obs,
581
+ actions[0],
582
+ rewards,
583
+ infos=infos,
584
+ terminateds=terminateds,
585
+ truncateds=truncateds,
586
+ extra_model_outputs=extra_model_outputs,
587
+ )
588
+
589
+ ts += self._increase_sampled_metrics(self.num_envs, obs, _episode)
590
+
591
+ # Make `on_episode_step` callback before finalizing the episode.
592
+ self._make_on_episode_callback("on_episode_step", _episode)
593
+
594
+ # TODO (sven, simon): We have to check, if we need this elaborate
595
+ # function here or if the `MultiAgentEnv` defines the cases that
596
+ # can happen.
597
+ # Right now we have:
598
+ # 1. Most times only agents that step get `terminated`, `truncated`
599
+ # i.e. the rest we have to check in the episode.
600
+ # 2. There are edge cases like, some agents terminated, all others
601
+ # truncated and vice versa.
602
+ # See also `MultiAgentEpisode` for handling the `__all__`.
603
+ if _episode.is_done:
604
+ # Increase episode count.
605
+ eps += 1
606
+
607
+ # We have to perform an extra env-to-module pass here, just in case
608
+ # the user's connector pipeline performs (permanent) transforms
609
+ # on each observation (including this final one here). Without such
610
+ # a call and in case the structure of the observations change
611
+ # sufficiently, the following `to_numpy()` call on the episode will
612
+ # fail.
613
+ if self.module is not None:
614
+ self._env_to_module(
615
+ episodes=[_episode],
616
+ explore=explore,
617
+ rl_module=self.module,
618
+ shared_data=_shared_data,
619
+ metrics=self.metrics,
620
+ )
621
+
622
+ # Make the `on_episode_end` callback (before finalizing the episode,
623
+ # but after(!) the last env-to-module connector call has been made.
624
+ # -> All obs (even the terminal one) should have been processed now (by
625
+ # the connector, if applicable).
626
+ self._make_on_episode_callback("on_episode_end", _episode)
627
+
628
+ self._prune_zero_len_sa_episodes(_episode)
629
+
630
+ # Numpy'ize the episode.
631
+ if self.config.episodes_to_numpy:
632
+ done_episodes_to_return.append(_episode.to_numpy())
633
+ # Leave episode as lists of individual (obs, action, etc..) items.
634
+ else:
635
+ done_episodes_to_return.append(_episode)
636
+
637
+ # Also early-out if we reach the number of episodes within this
638
+ # for-loop.
639
+ if eps == num_episodes:
640
+ break
641
+
642
+ # Create a new episode instance.
643
+ _episode = self._new_episode()
644
+ self._make_on_episode_callback("on_episode_created", _episode)
645
+
646
+ # Try resetting the environment.
647
+ obs, infos = self._try_env_reset()
648
+ # Add initial observations and infos.
649
+ _episode.add_env_reset(observations=obs, infos=infos)
650
+
651
+ # Make `on_episode_start` callback.
652
+ self._make_on_episode_callback("on_episode_start", _episode)
653
+
654
+ self._done_episodes_for_metrics.extend(done_episodes_to_return)
655
+
656
+ return done_episodes_to_return
657
+
658
+ @override(EnvRunner)
659
+ def get_spaces(self):
660
+ # Return the already agent-to-module translated spaces from our connector
661
+ # pipeline.
662
+ return {
663
+ **{
664
+ mid: (o, self._env_to_module.action_space[mid])
665
+ for mid, o in self._env_to_module.observation_space.spaces.items()
666
+ },
667
+ }
668
+
669
    @override(EnvRunner)
    def get_metrics(self) -> ResultDict:
        """Computes, logs, and returns all (reduced) sampling metrics.

        Aggregates per-episode stats (length, return, duration, per-agent and
        per-module returns, per-agent step counts) over all episodes completed
        since the last call, merges in the stats of previously returned chunks
        of the same episodes, logs everything via `_log_episode_metrics()`, and
        finally returns `self.metrics.reduce()`.
        """
        # Compute per-episode metrics (only on already completed episodes).
        for eps in self._done_episodes_for_metrics:
            assert eps.is_done
            episode_length = len(eps)
            # Per-agent step counts, keyed by stringified agent ID.
            agent_steps = defaultdict(
                int,
                {str(aid): len(sa_eps) for aid, sa_eps in eps.agent_episodes.items()},
            )
            episode_return = eps.get_return()
            episode_duration_s = eps.get_duration_s()

            agent_episode_returns = defaultdict(
                float,
                {
                    str(sa_eps.agent_id): sa_eps.get_return()
                    for sa_eps in eps.agent_episodes.values()
                },
            )
            # NOTE(review): If several agents map to the same module, this
            # comprehension keeps only the LAST such agent's return for that
            # module (instead of summing) -- confirm this is intended.
            module_episode_returns = defaultdict(
                float,
                {
                    sa_eps.module_id: sa_eps.get_return()
                    for sa_eps in eps.agent_episodes.values()
                },
            )

            # Don't forget about the already returned chunks of this episode.
            if eps.id_ in self._ongoing_episodes_for_metrics:
                for eps2 in self._ongoing_episodes_for_metrics[eps.id_]:
                    return_eps2 = eps2.get_return()
                    episode_length += len(eps2)
                    episode_return += return_eps2
                    episode_duration_s += eps2.get_duration_s()

                    for sa_eps in eps2.agent_episodes.values():
                        return_sa = sa_eps.get_return()
                        agent_steps[str(sa_eps.agent_id)] += len(sa_eps)
                        agent_episode_returns[str(sa_eps.agent_id)] += return_sa
                        module_episode_returns[sa_eps.module_id] += return_sa

                # All chunks of this episode are accounted for -> drop them.
                del self._ongoing_episodes_for_metrics[eps.id_]

            self._log_episode_metrics(
                episode_length,
                episode_return,
                episode_duration_s,
                agent_episode_returns,
                module_episode_returns,
                dict(agent_steps),
            )

        # Now that we have logged everything, clear cache of done episodes.
        self._done_episodes_for_metrics.clear()

        # Return reduced metrics.
        return self.metrics.reduce()
727
+
728
    @override(Checkpointable)
    def get_state(
        self,
        components: Optional[Union[str, Collection[str]]] = None,
        *,
        not_components: Optional[Union[str, Collection[str]]] = None,
        **kwargs,
    ) -> StateDict:
        """Returns this EnvRunner's state as a dict.

        Args:
            components: If given, only include these components' states.
            not_components: If given, exclude these components' states.
            **kwargs: Forwarded to the RLModule's `get_state()` call.

        Returns:
            A StateDict with (depending on the filters) the lifetime env-step
            count, the RLModule state + weights seq-no, and the two connector
            pipelines' states.
        """
        # Basic state dict.
        state = {
            NUM_ENV_STEPS_SAMPLED_LIFETIME: (
                self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0)
            ),
        }

        # RLModule (MultiRLModule) component.
        if self._check_component(COMPONENT_RL_MODULE, components, not_components):
            state[COMPONENT_RL_MODULE] = self.module.get_state(
                components=self._get_subcomponents(COMPONENT_RL_MODULE, components),
                not_components=self._get_subcomponents(
                    COMPONENT_RL_MODULE, not_components
                ),
                **kwargs,
            )
            # Ship the weights version so receivers can detect stale weights.
            state[WEIGHTS_SEQ_NO] = self._weights_seq_no

        # Env-to-module connector.
        if self._check_component(
            COMPONENT_ENV_TO_MODULE_CONNECTOR, components, not_components
        ):
            state[COMPONENT_ENV_TO_MODULE_CONNECTOR] = self._env_to_module.get_state()
        # Module-to-env connector.
        if self._check_component(
            COMPONENT_MODULE_TO_ENV_CONNECTOR, components, not_components
        ):
            state[COMPONENT_MODULE_TO_ENV_CONNECTOR] = self._module_to_env.get_state()

        return state
766
+
767
    @override(Checkpointable)
    def set_state(self, state: StateDict) -> None:
        """Restores this EnvRunner's state from the given state dict.

        Only the components present in `state` are updated (connectors,
        RLModule weights, lifetime env-step counter).
        """
        if COMPONENT_ENV_TO_MODULE_CONNECTOR in state:
            self._env_to_module.set_state(state[COMPONENT_ENV_TO_MODULE_CONNECTOR])
        if COMPONENT_MODULE_TO_ENV_CONNECTOR in state:
            self._module_to_env.set_state(state[COMPONENT_MODULE_TO_ENV_CONNECTOR])

        # Update RLModule state.
        if COMPONENT_RL_MODULE in state:
            # A missing value for WEIGHTS_SEQ_NO or a value of 0 means: Force the
            # update.
            weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0)

            # Only update the weights, if this is the first synchronization or
            # if the weights of this `EnvRunner` lag behind the actual ones.
            if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no:
                self.module.set_state(state[COMPONENT_RL_MODULE])

            # Update weights_seq_no, if the new one is > 0.
            if weights_seq_no > 0:
                self._weights_seq_no = weights_seq_no

        # Update lifetime counters.
        if NUM_ENV_STEPS_SAMPLED_LIFETIME in state:
            self.metrics.set_value(
                key=NUM_ENV_STEPS_SAMPLED_LIFETIME,
                value=state[NUM_ENV_STEPS_SAMPLED_LIFETIME],
                reduce="sum",
                with_throughput=True,
            )
797
+
798
+ @override(Checkpointable)
799
+ def get_ctor_args_and_kwargs(self):
800
+ return (
801
+ (), # *args
802
+ {"config": self.config}, # **kwargs
803
+ )
804
+
805
+ @override(Checkpointable)
806
+ def get_metadata(self):
807
+ metadata = Checkpointable.get_metadata(self)
808
+ metadata.update(
809
+ {
810
+ # TODO (sven): Maybe add serialized (JSON-writable) config here?
811
+ }
812
+ )
813
+ return metadata
814
+
815
+ @override(Checkpointable)
816
+ def get_checkpointable_components(self):
817
+ return [
818
+ (COMPONENT_RL_MODULE, self.module),
819
+ (COMPONENT_ENV_TO_MODULE_CONNECTOR, self._env_to_module),
820
+ (COMPONENT_MODULE_TO_ENV_CONNECTOR, self._module_to_env),
821
+ ]
822
+
823
    @override(EnvRunner)
    def assert_healthy(self):
        """Checks that self.__init__() has been completed properly.

        Ensures that the instances has a `MultiRLModule` and an
        environment defined.

        Raises:
            AssertionError: If the EnvRunner Actor has NOT been properly initialized.
        """
        # Make sure, we have built our gym.vector.Env and RLModule properly.
        assert self.env and self.module
835
+
836
    @override(EnvRunner)
    def make_env(self):
        """(Re)creates the multi-agent environment for this EnvRunner.

        Closes any pre-existing env, builds an `EnvContext` from the config,
        registers and `gym.make()`s the env under a fixed local id, optionally
        runs the multi-agent env checker, and finally fires the
        `on_environment_created` callback.
        """
        # If an env already exists, try closing it first (to allow it to properly
        # cleanup).
        if self.env is not None:
            try:
                self.env.close()
            except Exception as e:
                # Best-effort close: log and continue with re-creation.
                logger.warning(
                    "Tried closing the existing env (multi-agent), but failed with "
                    f"error: {e.args[0]}"
                )
            del self.env

        env_ctx = self.config.env_config
        if not isinstance(env_ctx, EnvContext):
            # Wrap the plain config dict into an EnvContext carrying worker info.
            env_ctx = EnvContext(
                env_ctx,
                worker_index=self.worker_index,
                num_workers=self.config.num_env_runners,
                remote=self.config.remote_worker_envs,
            )

        # No env provided -> Error.
        if not self.config.env:
            raise ValueError(
                "`config.env` is not provided! You should provide a valid environment "
                "to your config through `config.environment([env descriptor e.g. "
                "'CartPole-v1'])`."
            )
        # Register env for the local context.
        # Note, `gym.register` has to be called on each worker.
        elif isinstance(self.config.env, str) and _global_registry.contains(
            ENV_CREATOR, self.config.env
        ):
            # Env is registered in RLlib's tune registry -> use that creator.
            entry_point = partial(
                _global_registry.get(ENV_CREATOR, self.config.env),
                env_ctx,
            )
        else:
            # Fall back to RLlib's generic gym env creator.
            entry_point = partial(
                _gym_env_creator,
                env_descriptor=self.config.env,
                env_context=env_ctx,
            )
        gym.register(
            "rllib-multi-agent-env-v0",
            entry_point=entry_point,
            disable_env_checker=True,
        )

        # Perform actual gym.make call.
        self.env: MultiAgentEnv = gym.make("rllib-multi-agent-env-v0")
        # NOTE: Multi-agent envs are not vectorized (yet) -> always one env.
        self.num_envs = 1
        # If required, check the created MultiAgentEnv.
        if not self.config.disable_env_checking:
            try:
                check_multiagent_environments(self.env.unwrapped)
            except Exception as e:
                logger.exception(e.args[0])
        # If not required, still check the type (must be MultiAgentEnv).
        else:
            assert isinstance(self.env.unwrapped, MultiAgentEnv), (
                "ERROR: When using the `MultiAgentEnvRunner` the environment needs "
                "to inherit from `ray.rllib.env.multi_agent_env.MultiAgentEnv`."
            )

        # Set the flag to reset all envs upon the next `sample()` call.
        self._needs_initial_reset = True

        # Call the `on_environment_created` callback.
        make_callback(
            "on_environment_created",
            callbacks_objects=self._callbacks,
            callbacks_functions=self.config.callbacks_on_environment_created,
            kwargs=dict(
                env_runner=self,
                metrics_logger=self.metrics,
                env=self.env.unwrapped,
                env_context=env_ctx,
            ),
        )
918
+
919
    @override(EnvRunner)
    def make_module(self):
        """Builds the (inference-only) MultiRLModule for this EnvRunner.

        On success, assigns the built module to `self.module` and moves any
        torch sub-modules to this runner's device. If the config does not
        implement a module spec, `self.module` is set to None (random-action
        sampling may still work).
        """
        try:
            module_spec: MultiRLModuleSpec = self.config.get_multi_rl_module_spec(
                env=self.env.unwrapped, spaces=self.get_spaces(), inference_only=True
            )
            # Build the module from its spec.
            self.module = module_spec.build()
            # Move the RLModule to our device.
            # TODO (sven): In order to make this framework-agnostic, we should maybe
            #  make the MultiRLModule.build() method accept a device OR create an
            #  additional `(Multi)RLModule.to()` override.
            if torch:
                self.module.foreach_module(
                    lambda mid, mod: (
                        mod.to(self._device)
                        if isinstance(mod, torch.nn.Module)
                        else mod
                    )
                )

        # If `AlgorithmConfig.get_rl_module_spec()` is not implemented, this env runner
        # will not have an RLModule, but might still be usable with random actions.
        except NotImplementedError:
            self.module = None
944
+
945
+ @override(EnvRunner)
946
+ def stop(self):
947
+ # Note, `MultiAgentEnv` inherits `close()`-method from `gym.Env`.
948
+ self.env.close()
949
+
950
+ def _setup_metrics(self):
951
+ self.metrics = MetricsLogger()
952
+
953
+ self._done_episodes_for_metrics: List[MultiAgentEpisode] = []
954
+ self._ongoing_episodes_for_metrics: DefaultDict[
955
+ EpisodeID, List[MultiAgentEpisode]
956
+ ] = defaultdict(list)
957
+
958
+ def _new_episode(self):
959
+ return MultiAgentEpisode(
960
+ observation_space={
961
+ aid: self.env.unwrapped.get_observation_space(aid)
962
+ for aid in self.env.unwrapped.possible_agents
963
+ },
964
+ action_space={
965
+ aid: self.env.unwrapped.get_action_space(aid)
966
+ for aid in self.env.unwrapped.possible_agents
967
+ },
968
+ agent_to_module_mapping_fn=self.config.policy_mapping_fn,
969
+ )
970
+
971
+ def _make_on_episode_callback(self, which: str, episode=None):
972
+ episode = episode if episode is not None else self._episode
973
+ make_callback(
974
+ which,
975
+ callbacks_objects=self._callbacks,
976
+ callbacks_functions=getattr(self.config, f"callbacks_{which}"),
977
+ kwargs=dict(
978
+ episode=episode,
979
+ env_runner=self,
980
+ metrics_logger=self.metrics,
981
+ env=self.env.unwrapped,
982
+ rl_module=self.module,
983
+ env_index=0,
984
+ ),
985
+ )
986
+
987
    def _increase_sampled_metrics(self, num_steps, next_obs, episode):
        """Logs per-step sampling metrics and returns `num_steps`.

        Logs env-step counters (windowed and lifetime), episode-completion
        counters (if `episode.is_done`), and per-agent/per-module step counters
        for every agent present in `next_obs`.

        Args:
            num_steps: Number of env steps just taken (1 for non-vectorized).
            next_obs: The multi-agent observation dict returned by the env.
            episode: The episode the step was added to (used for done-check and
                agent-to-module lookups).

        Returns:
            The (unchanged) `num_steps`, so callers can accumulate it.
        """
        # Env steps.
        self.metrics.log_value(
            NUM_ENV_STEPS_SAMPLED, num_steps, reduce="sum", clear_on_reduce=True
        )
        self.metrics.log_value(
            NUM_ENV_STEPS_SAMPLED_LIFETIME,
            num_steps,
            reduce="sum",
            with_throughput=True,
        )
        # Completed episodes.
        if episode.is_done:
            self.metrics.log_value(NUM_EPISODES, 1, reduce="sum", clear_on_reduce=True)
            self.metrics.log_value(NUM_EPISODES_LIFETIME, 1, reduce="sum")

        # TODO (sven): obs is not-vectorized. Support vectorized MA envs.
        for aid in next_obs:
            self.metrics.log_value(
                (NUM_AGENT_STEPS_SAMPLED, str(aid)),
                1,
                reduce="sum",
                clear_on_reduce=True,
            )
            self.metrics.log_value(
                (NUM_AGENT_STEPS_SAMPLED_LIFETIME, str(aid)),
                1,
                reduce="sum",
            )
            self.metrics.log_value(
                (NUM_MODULE_STEPS_SAMPLED, episode.module_for(aid)),
                1,
                reduce="sum",
                clear_on_reduce=True,
            )
            self.metrics.log_value(
                (NUM_MODULE_STEPS_SAMPLED_LIFETIME, episode.module_for(aid)),
                1,
                reduce="sum",
            )
        return num_steps
1028
+
1029
    def _log_episode_metrics(
        self,
        length,
        ret,
        sec,
        agents=None,
        modules=None,
        agent_steps=None,
    ):
        """Logs one completed episode's stats to `self.metrics`.

        Args:
            length: Total episode length (env steps).
            ret: Total episode return.
            sec: Episode wall-clock duration in seconds.
            agents: Optional dict of per-agent returns.
            modules: Optional dict of per-module returns.
            agent_steps: Optional dict of per-agent step counts.
        """
        # Log general episode metrics.
        self.metrics.log_dict(
            {
                EPISODE_LEN_MEAN: length,
                EPISODE_RETURN_MEAN: ret,
                EPISODE_DURATION_SEC_MEAN: sec,
                **(
                    {
                        # Per-agent returns.
                        "agent_episode_returns_mean": agents,
                        # Per-RLModule returns.
                        "module_episode_returns_mean": modules,
                        "agent_steps": agent_steps,
                    }
                    if agents is not None
                    else {}
                ),
            },
            # To mimic the old API stack behavior, we'll use `window` here for
            # these particular stats (instead of the default EMA).
            window=self.config.metrics_num_episodes_for_smoothing,
        )
        # For some metrics, log min/max as well.
        self.metrics.log_dict(
            {
                EPISODE_LEN_MIN: length,
                EPISODE_RETURN_MIN: ret,
            },
            reduce="min",
            window=self.config.metrics_num_episodes_for_smoothing,
        )
        self.metrics.log_dict(
            {
                EPISODE_LEN_MAX: length,
                EPISODE_RETURN_MAX: ret,
            },
            reduce="max",
            window=self.config.metrics_num_episodes_for_smoothing,
        )
1077
+
1078
+ @staticmethod
1079
+ def _prune_zero_len_sa_episodes(episode: MultiAgentEpisode):
1080
+ for agent_id, agent_eps in episode.agent_episodes.copy().items():
1081
+ if len(agent_eps) == 0:
1082
+ del episode.agent_episodes[agent_id]
1083
+
1084
+ @Deprecated(
1085
+ new="MultiAgentEnvRunner.get_state(components='rl_module')",
1086
+ error=False,
1087
+ )
1088
+ def get_weights(self, modules=None):
1089
+ rl_module_state = self.get_state(components=COMPONENT_RL_MODULE)[
1090
+ COMPONENT_RL_MODULE
1091
+ ]
1092
+ return rl_module_state
1093
+
1094
+ @Deprecated(new="MultiAgentEnvRunner.set_state()", error=False)
1095
+ def set_weights(
1096
+ self,
1097
+ weights: ModelWeights,
1098
+ global_vars: Optional[Dict] = None,
1099
+ weights_seq_no: int = 0,
1100
+ ) -> None:
1101
+ assert global_vars is None
1102
+ return self.set_state(
1103
+ {
1104
+ COMPONENT_RL_MODULE: weights,
1105
+ WEIGHTS_SEQ_NO: weights_seq_no,
1106
+ }
1107
+ )
.venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_episode.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/ray/rllib/env/policy_client.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """REST client to interact with a policy server.
2
+
3
+ This client supports both local and remote policy inference modes. Local
4
+ inference is faster but causes more compute to be done on the client.
5
+ """
6
+
7
+ import logging
8
+ import threading
9
+ import time
10
+ from typing import Union, Optional
11
+
12
+ import ray.cloudpickle as pickle
13
+ from ray.rllib.env.external_env import ExternalEnv
14
+ from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
15
+ from ray.rllib.env.multi_agent_env import MultiAgentEnv
16
+ from ray.rllib.policy.sample_batch import MultiAgentBatch
17
+ from ray.rllib.utils.annotations import OldAPIStack
18
+ from ray.rllib.utils.typing import (
19
+ MultiAgentDict,
20
+ EnvInfoDict,
21
+ EnvObsType,
22
+ EnvActionType,
23
+ )
24
+
25
+ # Backward compatibility.
26
+ from ray.rllib.env.utils.external_env_protocol import RLlink as Commands
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+ try:
31
+ import requests # `requests` is not part of stdlib.
32
+ except ImportError:
33
+ requests = None
34
+ logger.warning(
35
+ "Couldn't import `requests` library. Be sure to install it on"
36
+ " the client side."
37
+ )
38
+
39
+
40
@OldAPIStack
class PolicyClient:
    """REST client to interact with an RLlib policy server."""

    def __init__(
        self,
        address: str,
        inference_mode: str = "local",
        update_interval: float = 10.0,
        session: Optional["requests.Session"] = None,
    ):
        """Create a PolicyClient instance.

        Args:
            address: Server to connect to (e.g., "localhost:9090").
            inference_mode: Whether to use 'local' or 'remote' policy
                inference for computing actions.
            update_interval (float or None): If using 'local' inference mode,
                the policy is refreshed after this many seconds have passed,
                or None for manual control via client.
            session (requests.Session or None): If available the session object
                is used to communicate with the policy server. Using a session
                can lead to speedups as connections are reused. It is the
                responsibility of the creator of the session to close it.
        """
        # NOTE: The `session` annotation above must stay a string (forward
        # reference): the module deliberately tolerates a missing `requests`
        # package (see try/except import at file top, which sets
        # `requests = None`), and a non-string `Optional[requests.Session]`
        # would raise AttributeError at class-definition time in that case.
        self.address = address
        self.session = session
        self.env: ExternalEnv = None
        if inference_mode == "local":
            self.local = True
            self._setup_local_rollout_worker(update_interval)
        elif inference_mode == "remote":
            self.local = False
        else:
            raise ValueError("inference_mode must be either 'local' or 'remote'")

    def start_episode(
        self, episode_id: Optional[str] = None, training_enabled: bool = True
    ) -> str:
        """Record the start of one or more episode(s).

        Args:
            episode_id (Optional[str]): Unique string id for the episode or
                None for it to be auto-assigned.
            training_enabled: Whether to use experiences for this
                episode to improve the policy.

        Returns:
            episode_id: Unique string id for the episode.
        """

        if self.local:
            self._update_local_policy()
            return self.env.start_episode(episode_id, training_enabled)

        return self._send(
            {
                "episode_id": episode_id,
                "command": Commands.START_EPISODE,
                "training_enabled": training_enabled,
            }
        )["episode_id"]

    def get_action(
        self, episode_id: str, observation: Union[EnvObsType, MultiAgentDict]
    ) -> Union[EnvActionType, MultiAgentDict]:
        """Record an observation and get the on-policy action.

        Args:
            episode_id: Episode id returned from start_episode().
            observation: Current environment observation.

        Returns:
            action: Action from the env action space.
        """

        if self.local:
            self._update_local_policy()
            # A list/tuple of episode IDs means: batched multi-episode request.
            if isinstance(episode_id, (list, tuple)):
                actions = {
                    eid: self.env.get_action(eid, observation[eid])
                    for eid in episode_id
                }
                return actions
            else:
                return self.env.get_action(episode_id, observation)
        else:
            return self._send(
                {
                    "command": Commands.GET_ACTION,
                    "observation": observation,
                    "episode_id": episode_id,
                }
            )["action"]

    def log_action(
        self,
        episode_id: str,
        observation: Union[EnvObsType, MultiAgentDict],
        action: Union[EnvActionType, MultiAgentDict],
    ) -> None:
        """Record an observation and (off-policy) action taken.

        Args:
            episode_id: Episode id returned from start_episode().
            observation: Current environment observation.
            action: Action for the observation.
        """

        if self.local:
            self._update_local_policy()
            return self.env.log_action(episode_id, observation, action)

        self._send(
            {
                "command": Commands.LOG_ACTION,
                "observation": observation,
                "action": action,
                "episode_id": episode_id,
            }
        )

    def log_returns(
        self,
        episode_id: str,
        reward: float,
        info: Union[EnvInfoDict, MultiAgentDict] = None,
        multiagent_done_dict: Optional[MultiAgentDict] = None,
    ) -> None:
        """Record returns from the environment.

        The reward will be attributed to the previous action taken by the
        episode. Rewards accumulate until the next action. If no reward is
        logged before the next action, a reward of 0.0 is assumed.

        Args:
            episode_id: Episode id returned from start_episode().
            reward: Reward from the environment.
            info: Extra info dict.
            multiagent_done_dict: Multi-agent done information.
        """

        if self.local:
            self._update_local_policy()
            if multiagent_done_dict is not None:
                # Multi-agent case: reward must be a per-agent dict.
                assert isinstance(reward, dict)
                return self.env.log_returns(
                    episode_id, reward, info, multiagent_done_dict
                )
            return self.env.log_returns(episode_id, reward, info)

        self._send(
            {
                "command": Commands.LOG_RETURNS,
                "reward": reward,
                "info": info,
                "episode_id": episode_id,
                "done": multiagent_done_dict,
            }
        )

    def end_episode(
        self, episode_id: str, observation: Union[EnvObsType, MultiAgentDict]
    ) -> None:
        """Record the end of an episode.

        Args:
            episode_id: Episode id returned from start_episode().
            observation: Current environment observation.
        """

        if self.local:
            self._update_local_policy()
            return self.env.end_episode(episode_id, observation)

        self._send(
            {
                "command": Commands.END_EPISODE,
                "observation": observation,
                "episode_id": episode_id,
            }
        )

    def update_policy_weights(self) -> None:
        """Query the server for new policy weights, if local inference is enabled."""
        self._update_local_policy(force=True)

    def _send(self, data):
        """Pickles `data`, POSTs it to the server, and unpickles the response.

        Raises:
            requests.HTTPError: If the server responds with a non-200 status.
        """
        payload = pickle.dumps(data)

        if self.session is None:
            response = requests.post(self.address, data=payload)
        else:
            response = self.session.post(self.address, data=payload)

        if response.status_code != 200:
            logger.error("Request failed {}: {}".format(response.text, data))
        response.raise_for_status()
        parsed = pickle.loads(response.content)
        return parsed

    def _setup_local_rollout_worker(self, update_interval):
        """Creates the embedded rollout worker used for local inference.

        Queries the server for the worker construction args, spawns the
        background inference thread, and exposes the worker's env.
        """
        self.update_interval = update_interval
        self.last_updated = 0

        logger.info("Querying server for rollout worker settings.")
        kwargs = self._send(
            {
                "command": Commands.GET_WORKER_ARGS,
            }
        )["worker_args"]
        (self.rollout_worker, self.inference_thread) = _create_embedded_rollout_worker(
            kwargs, self._send
        )
        self.env = self.rollout_worker.env

    def _update_local_policy(self, force=False):
        """Pulls new weights from the server if the update interval elapsed.

        Args:
            force: If True, fetch weights regardless of the interval.
        """
        assert self.inference_thread.is_alive()
        if (
            self.update_interval
            and time.time() - self.last_updated > self.update_interval
        ) or force:
            logger.info("Querying server for new policy weights.")
            resp = self._send(
                {
                    "command": Commands.GET_WEIGHTS,
                }
            )
            weights = resp["weights"]
            global_vars = resp["global_vars"]
            logger.info(
                "Updating rollout worker weights and global vars {}.".format(
                    global_vars
                )
            )
            self.rollout_worker.set_weights(weights, global_vars)
            self.last_updated = time.time()
+ self.last_updated = time.time()
277
+
278
+
279
class _LocalInferenceThread(threading.Thread):
    """Thread that handles experience generation (worker.sample() loop)."""

    def __init__(self, rollout_worker, send_fn):
        """Initializes a _LocalInferenceThread instance.

        Args:
            rollout_worker: The RolloutWorker to repeatedly `sample()` from.
            send_fn: Callable used to push sample batches/metrics to the
                server.
        """
        super().__init__()
        # Daemonize so the process can exit even if sampling is in flight.
        self.daemon = True
        self.rollout_worker = rollout_worker
        self.send_fn = send_fn

    def run(self):
        try:
            while True:
                logger.info("Generating new batch of experiences.")
                samples = self.rollout_worker.sample()
                metrics = self.rollout_worker.get_metrics()
                if isinstance(samples, MultiAgentBatch):
                    logger.info(
                        "Sending batch of {} env steps ({} agent steps) to "
                        "server.".format(samples.env_steps(), samples.agent_steps())
                    )
                else:
                    logger.info(
                        "Sending batch of {} steps back to server.".format(
                            samples.count
                        )
                    )
                self.send_fn(
                    {
                        "command": Commands.REPORT_SAMPLES,
                        "samples": samples,
                        "metrics": metrics,
                    }
                )
        except Exception:
            # Fix: `logger.error("...died!", e)` passed the exception as a
            # stray %-format argument with no placeholder in the message, so
            # the logging call itself errored and the exception details were
            # lost. `logger.exception` logs the message plus full traceback.
            logger.exception("Error: inference worker thread died!")
314
+
315
+
316
def _auto_wrap_external(real_env_creator):
    """Wrap an env-creator so its envs conform to the ExternalEnv interface.

    Args:
        real_env_creator: Create an env given the env_config.
    """

    def wrapped_creator(env_config):
        real_env = real_env_creator(env_config)
        # Already an ExternalEnv (or multi-agent variant): use as-is.
        if isinstance(real_env, (ExternalEnv, ExternalMultiAgentEnv)):
            return real_env

        logger.info(
            "The env you specified is not a supported (sub-)type of "
            "ExternalEnv. Attempting to convert it automatically to "
            "ExternalEnv."
        )

        external_cls = (
            ExternalMultiAgentEnv
            if isinstance(real_env, MultiAgentEnv)
            else ExternalEnv
        )

        class _ExternalEnvWrapper(external_cls):
            def __init__(self, real_env):
                super().__init__(
                    observation_space=real_env.observation_space,
                    action_space=real_env.action_space,
                )

            def run(self):
                # Since we are calling methods on this class in the
                # client, run doesn't need to do anything.
                time.sleep(999999)

        return _ExternalEnvWrapper(real_env)

    return wrapped_creator
353
+
354
+
355
def _create_embedded_rollout_worker(kwargs, send_fn):
    """Create a local rollout worker and a thread that samples from it.

    Args:
        kwargs: Args for the RolloutWorker constructor.
        send_fn: Function to send a JSON request to the server.
    """
    # The server acts as the input datasource here, so reset the input
    # config back to the default (env rollouts via the sampler).
    kwargs = kwargs.copy()
    kwargs["config"] = kwargs["config"].copy(copy_frozen=False)
    config = kwargs["config"]
    config.output = None
    config.input_ = "sampler"
    config.input_config = {}

    if config.env is None:
        # Server has no env (the expected case): generate a dummy
        # ExternalEnv from RandomEnv using the given obs/action spaces.
        from ray.rllib.examples.envs.classes.random_env import (
            RandomEnv,
            RandomMultiAgentEnv,
        )

        env_config = {
            "action_space": config.action_space,
            "observation_space": config.observation_space,
        }
        is_ma = config.is_multi_agent
        kwargs["env_creator"] = _auto_wrap_external(
            lambda _: (RandomMultiAgentEnv if is_ma else RandomEnv)(env_config)
        )
    else:
        # Otherwise, use the env specified by the server args (auto-wrapped).
        kwargs["env_creator"] = _auto_wrap_external(kwargs["env_creator"])

    logger.info("Creating rollout worker with kwargs={}".format(kwargs))
    from ray.rllib.evaluation.rollout_worker import RolloutWorker

    worker = RolloutWorker(**kwargs)

    sampling_thread = _LocalInferenceThread(worker, send_fn)
    sampling_thread.start()

    return worker, sampling_thread
.venv/lib/python3.11/site-packages/ray/rllib/env/policy_server_input.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ from http.server import HTTPServer, SimpleHTTPRequestHandler
3
+ import logging
4
+ import queue
5
+ from socketserver import ThreadingMixIn
6
+ import threading
7
+ import time
8
+ import traceback
9
+
10
+ from typing import List
11
+ import ray.cloudpickle as pickle
12
+ from ray.rllib.env.policy_client import (
13
+ _create_embedded_rollout_worker,
14
+ Commands,
15
+ )
16
+ from ray.rllib.offline.input_reader import InputReader
17
+ from ray.rllib.offline.io_context import IOContext
18
+ from ray.rllib.policy.sample_batch import SampleBatch
19
+ from ray.rllib.utils.annotations import override, PublicAPI
20
+ from ray.rllib.evaluation.metrics import RolloutMetrics
21
+ from ray.rllib.evaluation.sampler import SamplerInput
22
+ from ray.rllib.utils.typing import SampleBatchType
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
@PublicAPI
class PolicyServerInput(ThreadingMixIn, HTTPServer, InputReader):
    """REST policy server that acts as an offline data source.

    This launches a multi-threaded server that listens on the specified host
    and port to serve policy requests and forward experiences to RLlib. For
    high performance experience collection, it implements InputReader.

    For an example, run `examples/envs/external_envs/cartpole_server.py` along
    with `examples/envs/external_envs/cartpole_client.py --inference-mode=local|remote`.

    WARNING: This class is not meant to be publicly exposed. Anyone that can
    communicate with this server can execute arbitary code on the machine. Use
    this with caution, in isolated environments, and at your own risk.

    .. testcode::
        :skipif: True

        import gymnasium as gym
        from ray.rllib.algorithms.ppo import PPOConfig
        from ray.rllib.env.policy_client import PolicyClient
        from ray.rllib.env.policy_server_input import PolicyServerInput
        addr, port = ...
        config = (
            PPOConfig()
            .api_stack(
                enable_rl_module_and_learner=False,
                enable_env_runner_and_connector_v2=False,
            )
            .environment("CartPole-v1")
            .offline_data(
                input_=lambda ioctx: PolicyServerInput(ioctx, addr, port)
            )
            # Run just 1 server (in the Algorithm's EnvRunnerGroup).
            .env_runners(num_env_runners=0)
        )
        algo = config.build()
        while True:
            algo.train()
        client = PolicyClient(
            "localhost:9900", inference_mode="local")
        eps_id = client.start_episode()
        env = gym.make("CartPole-v1")
        obs, info = env.reset()
        action = client.get_action(eps_id, obs)
        _, reward, _, _, _ = env.step(action)
        client.log_returns(eps_id, reward)
        client.log_returns(eps_id, reward)
        algo.stop()
    """

    @PublicAPI
    def __init__(
        self,
        ioctx: IOContext,
        address: str,
        port: int,
        idle_timeout: float = 3.0,
        max_sample_queue_size: int = 20,
    ):
        """Create a PolicyServerInput.

        This class implements rllib.offline.InputReader, and can be used with
        any Algorithm by configuring

        [AlgorithmConfig object]
        .env_runners(num_env_runners=0)
        .offline_data(input_=lambda ioctx: PolicyServerInput(ioctx, addr, port))

        Note that by setting num_env_runners: 0, the algorithm will only create one
        rollout worker / PolicyServerInput. Clients can connect to the launched
        server using rllib.env.PolicyClient. You can increase the number of available
        connections (ports) by setting num_env_runners to a larger number. The ports
        used will then be `port` + the worker's index.

        Args:
            ioctx: IOContext provided by RLlib.
            address: Server addr (e.g., "localhost").
            port: Server port (e.g., 9900).
            idle_timeout: Interval (in seconds) at which an empty SampleBatch
                is placed on the sample queue while no client data arrives
                (keeps synchronous sampling from blocking forever).
            max_sample_queue_size: The maximum size for the sample queue. Once
                full, will purge (throw away) 50% of all samples, oldest
                first, and continue.
        """
        self.rollout_worker = ioctx.worker
        # Protect ourselves from having a bottleneck on the server (learning)
        # side. Once the queue (deque) is full, we throw away 50% (oldest
        # samples first) of the samples, warn, and continue.
        self.samples_queue = deque(maxlen=max_sample_queue_size)
        self.metrics_queue = queue.Queue()
        self.idle_timeout = idle_timeout

        # Forwards client-reported metrics directly into the local rollout
        # worker.
        if self.rollout_worker.sampler is not None:
            # This is a bit of a hack since it is patching the get_metrics
            # function of the sampler.

            def get_metrics():
                completed = []
                while True:
                    try:
                        completed.append(self.metrics_queue.get_nowait())
                    except queue.Empty:
                        break

                return completed

            self.rollout_worker.sampler.get_metrics = get_metrics
        else:
            # If there is no sampler, act like if there would be one to
            # collect metrics from.
            class MetricsDummySampler(SamplerInput):
                """This sampler only maintains a queue to get metrics from."""

                def __init__(self, metrics_queue):
                    """Initializes a MetricsDummySampler instance.

                    Args:
                        metrics_queue: A queue of metrics
                    """
                    self.metrics_queue = metrics_queue

                def get_data(self) -> SampleBatchType:
                    raise NotImplementedError

                def get_extra_batches(self) -> List[SampleBatchType]:
                    raise NotImplementedError

                def get_metrics(self) -> List[RolloutMetrics]:
                    """Returns metrics computed on a policy client rollout worker."""
                    completed = []
                    while True:
                        try:
                            completed.append(self.metrics_queue.get_nowait())
                        except queue.Empty:
                            break
                    return completed

            self.rollout_worker.sampler = MetricsDummySampler(self.metrics_queue)

        # Create a request handler that receives commands from the clients
        # and sends data and metrics into the queues.
        handler = _make_handler(
            self.rollout_worker, self.samples_queue, self.metrics_queue
        )
        # Fix: removed the redundant function-level `import time` statements
        # below (`time` is already imported at module level). The short
        # sleeps (presumably to let a previously bound socket be released --
        # TODO confirm) are kept to preserve behavior.
        try:
            time.sleep(1)
            HTTPServer.__init__(self, (address, port), handler)
        except OSError:
            print(f"Creating a PolicyServer on {address}:{port} failed!")
            time.sleep(1)
            raise

        logger.info(
            "Starting connector server at " f"{self.server_name}:{self.server_port}"
        )

        # Start the serving thread, listening on socket and handling commands.
        serving_thread = threading.Thread(name="server", target=self.serve_forever)
        serving_thread.daemon = True
        serving_thread.start()

        # Start a dummy thread that puts empty SampleBatches on the queue, just
        # in case we don't receive anything from clients (or there aren't
        # any). The latter would block sample collection entirely otherwise,
        # even if other workers' PolicyServerInput receive incoming data from
        # actual clients.
        heart_beat_thread = threading.Thread(
            name="heart-beat", target=self._put_empty_sample_batch_every_n_sec
        )
        heart_beat_thread.daemon = True
        heart_beat_thread.start()

    @override(InputReader)
    def next(self):
        # Blocking wait until there is something in the deque.
        while len(self.samples_queue) == 0:
            time.sleep(0.1)
        # Utilize last items first in order to remain as closely as possible
        # to operating on-policy.
        return self.samples_queue.pop()

    def _put_empty_sample_batch_every_n_sec(self):
        # Places an empty SampleBatch every `idle_timeout` seconds onto the
        # `samples_queue`. This avoids hanging of all RolloutWorkers parallel
        # to this one in case this PolicyServerInput does not have incoming
        # data (e.g. no client connected) and the driver algorithm uses parallel
        # synchronous sampling (e.g. PPO).
        while True:
            time.sleep(self.idle_timeout)
            self.samples_queue.append(SampleBatch())
222
+
223
+
224
def _make_handler(rollout_worker, samples_queue, metrics_queue):
    """Build the HTTP request-handler class bound to the given worker/queues.

    Args:
        rollout_worker: The server-side RolloutWorker (weights/args source).
        samples_queue: Deque receiving client-reported SampleBatches.
        metrics_queue: Queue receiving client-reported rollout metrics.

    Returns:
        A `SimpleHTTPRequestHandler` subclass for `HTTPServer`.
    """
    # Only used in remote inference mode. We must create a new rollout worker
    # then since the original worker doesn't have the env properly wrapped in
    # an ExternalEnv interface.
    child_rollout_worker = None
    inference_thread = None
    lock = threading.Lock()

    def setup_child_rollout_worker():
        # Lazily create the embedded rollout worker exactly once (guarded by
        # `lock`, since the HTTP server handles requests on multiple threads).
        nonlocal lock

        with lock:
            nonlocal child_rollout_worker
            nonlocal inference_thread

            if child_rollout_worker is None:
                (
                    child_rollout_worker,
                    inference_thread,
                ) = _create_embedded_rollout_worker(
                    rollout_worker.creation_args(), report_data
                )
                child_rollout_worker.set_weights(rollout_worker.get_weights())

    def report_data(data):
        # Forward a client-reported batch plus metrics into the queues.
        nonlocal child_rollout_worker

        batch = data["samples"]
        batch.decompress_if_needed()
        samples_queue.append(batch)
        # Deque is full -> purge 50% (oldest samples)
        if len(samples_queue) == samples_queue.maxlen:
            logger.warning(
                "PolicyServerInput queue is full! Purging half of the samples (oldest)."
            )
            for _ in range(samples_queue.maxlen // 2):
                samples_queue.popleft()
        for rollout_metric in data["metrics"]:
            metrics_queue.put(rollout_metric)

        if child_rollout_worker is not None:
            child_rollout_worker.set_weights(
                rollout_worker.get_weights(), rollout_worker.get_global_vars()
            )

    # NOTE: the redundant no-op `Handler.__init__` override was removed.
    class Handler(SimpleHTTPRequestHandler):
        def do_POST(self):
            # Fix: the original `int(self.headers.get("Content-Length"), 0)`
            # passed 0 as the *base* argument of int(), so a request without
            # a Content-Length header raised a TypeError (int(None, 0)).
            # Default the header to 0 (empty body) instead.
            content_len = int(self.headers.get("Content-Length", 0))
            raw_body = self.rfile.read(content_len)
            # SECURITY: pickle.loads() on network input allows arbitrary code
            # execution -- this server must only run in trusted, isolated
            # environments (see class-level WARNING in PolicyServerInput).
            parsed_input = pickle.loads(raw_body)
            try:
                response = self.execute_command(parsed_input)
                self.send_response(200)
                self.end_headers()
                self.wfile.write(pickle.dumps(response))
            except Exception:
                self.send_error(500, traceback.format_exc())

        def execute_command(self, args):
            command = args["command"]
            response = {}

            # Local inference commands:
            if command == Commands.GET_WORKER_ARGS:
                logger.info("Sending worker creation args to client.")
                response["worker_args"] = rollout_worker.creation_args()
            elif command == Commands.GET_WEIGHTS:
                logger.info("Sending worker weights to client.")
                response["weights"] = rollout_worker.get_weights()
                response["global_vars"] = rollout_worker.get_global_vars()
            elif command == Commands.REPORT_SAMPLES:
                logger.info(
                    "Got sample batch of size {} from client.".format(
                        args["samples"].count
                    )
                )
                report_data(args)

            # Remote inference commands:
            elif command == Commands.START_EPISODE:
                setup_child_rollout_worker()
                assert inference_thread.is_alive()
                response["episode_id"] = child_rollout_worker.env.start_episode(
                    args["episode_id"], args["training_enabled"]
                )
            elif command == Commands.GET_ACTION:
                assert inference_thread.is_alive()
                response["action"] = child_rollout_worker.env.get_action(
                    args["episode_id"], args["observation"]
                )
            elif command == Commands.LOG_ACTION:
                assert inference_thread.is_alive()
                child_rollout_worker.env.log_action(
                    args["episode_id"], args["observation"], args["action"]
                )
            elif command == Commands.LOG_RETURNS:
                assert inference_thread.is_alive()
                if args["done"]:
                    child_rollout_worker.env.log_returns(
                        args["episode_id"], args["reward"], args["info"], args["done"]
                    )
                else:
                    child_rollout_worker.env.log_returns(
                        args["episode_id"], args["reward"], args["info"]
                    )
            elif command == Commands.END_EPISODE:
                assert inference_thread.is_alive()
                child_rollout_worker.env.end_episode(
                    args["episode_id"], args["observation"]
                )
            else:
                raise ValueError("Unknown command: {}".format(command))
            return response

    return Handler
.venv/lib/python3.11/site-packages/ray/rllib/env/remote_base_env.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ import logging
3
+ from typing import Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
4
+
5
+ import ray
6
+ from ray.util import log_once
7
+ from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
8
+ from ray.rllib.utils.annotations import override, OldAPIStack
9
+ from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiEnvDict
10
+
11
+ if TYPE_CHECKING:
12
+ from ray.rllib.evaluation.rollout_worker import RolloutWorker
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ @OldAPIStack
18
+ class RemoteBaseEnv(BaseEnv):
19
+ """BaseEnv that executes its sub environments as @ray.remote actors.
20
+
21
+ This provides dynamic batching of inference as observations are returned
22
+ from the remote simulator actors. Both single and multi-agent child envs
23
+ are supported, and envs can be stepped synchronously or asynchronously.
24
+
25
+ NOTE: This class implicitly assumes that the remote envs are gym.Env's
26
+
27
+ You shouldn't need to instantiate this class directly. It's automatically
28
+ inserted when you use the `remote_worker_envs=True` option in your
29
+ Algorithm's config.
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ make_env: Callable[[int], EnvType],
35
+ num_envs: int,
36
+ multiagent: bool,
37
+ remote_env_batch_wait_ms: int,
38
+ existing_envs: Optional[List[ray.actor.ActorHandle]] = None,
39
+ worker: Optional["RolloutWorker"] = None,
40
+ restart_failed_sub_environments: bool = False,
41
+ ):
42
+ """Initializes a RemoteVectorEnv instance.
43
+
44
+ Args:
45
+ make_env: Callable that produces a single (non-vectorized) env,
46
+ given the vector env index as only arg.
47
+ num_envs: The number of sub-environments to create for the
48
+ vectorization.
49
+ multiagent: Whether this is a multiagent env or not.
50
+ remote_env_batch_wait_ms: Time to wait for (ray.remote)
51
+ sub-environments to have new observations available when
52
+ polled. Only when none of the sub-environments is ready,
53
+ repeat the `ray.wait()` call until at least one sub-env
54
+ is ready. Then return only the observations of the ready
55
+ sub-environment(s).
56
+ existing_envs: Optional list of already created sub-environments.
57
+ These will be used as-is and only as many new sub-envs as
58
+ necessary (`num_envs - len(existing_envs)`) will be created.
59
+ worker: An optional RolloutWorker that owns the env. This is only
60
+ used if `remote_worker_envs` is True in your config and the
61
+ `on_sub_environment_created` custom callback needs to be
62
+ called on each created actor.
63
+ restart_failed_sub_environments: If True and any sub-environment (within
64
+ a vectorized env) throws any error during env stepping, the
65
+ Sampler will try to restart the faulty sub-environment. This is done
66
+ without disturbing the other (still intact) sub-environment and without
67
+ the RolloutWorker crashing.
68
+ """
69
+
70
+ # Could be creating local or remote envs.
71
+ self.make_env = make_env
72
+ self.num_envs = num_envs
73
+ self.multiagent = multiagent
74
+ self.poll_timeout = remote_env_batch_wait_ms / 1000
75
+ self.worker = worker
76
+ self.restart_failed_sub_environments = restart_failed_sub_environments
77
+
78
+ # Already existing env objects (generated by the RolloutWorker).
79
+ existing_envs = existing_envs or []
80
+
81
+ # Whether the given `make_env` callable already returns ActorHandles
82
+ # (@ray.remote class instances) or not.
83
+ self.make_env_creates_actors = False
84
+
85
+ self._observation_space = None
86
+ self._action_space = None
87
+
88
+ # List of ray actor handles (each handle points to one @ray.remote
89
+ # sub-environment).
90
+ self.actors: Optional[List[ray.actor.ActorHandle]] = None
91
+
92
+ # `self.make_env` already produces Actors: Use it directly.
93
+ if len(existing_envs) > 0 and isinstance(
94
+ existing_envs[0], ray.actor.ActorHandle
95
+ ):
96
+ self.make_env_creates_actors = True
97
+ self.actors = existing_envs
98
+ while len(self.actors) < self.num_envs:
99
+ self.actors.append(self._make_sub_env(len(self.actors)))
100
+
101
+ # `self.make_env` produces gym.Envs (or children thereof, such
102
+ # as MultiAgentEnv): Need to auto-wrap it here. The problem with
103
+ # this is that custom methods wil get lost. If you would like to
104
+ # keep your custom methods in your envs, you should provide the
105
+ # env class directly in your config (w/o tune.register_env()),
106
+ # such that your class can directly be made a @ray.remote
107
+ # (w/o the wrapping via `_Remote[Multi|Single]AgentEnv`).
108
+ # Also, if `len(existing_envs) > 0`, we have to throw those away
109
+ # as we need to create ray actors here.
110
+ else:
111
+ self.actors = [self._make_sub_env(i) for i in range(self.num_envs)]
112
+ # Utilize existing envs for inferring observation/action spaces.
113
+ if len(existing_envs) > 0:
114
+ self._observation_space = existing_envs[0].observation_space
115
+ self._action_space = existing_envs[0].action_space
116
+ # Have to call actors' remote methods to get observation/action spaces.
117
+ else:
118
+ self._observation_space, self._action_space = ray.get(
119
+ [
120
+ self.actors[0].observation_space.remote(),
121
+ self.actors[0].action_space.remote(),
122
+ ]
123
+ )
124
+
125
+ # Dict mapping object refs (return values of @ray.remote calls),
126
+ # whose actual values we are waiting for (via ray.wait in
127
+ # `self.poll()`) to their corresponding actor handles (the actors
128
+ # that created these return values).
129
+ # Call `reset()` on all @ray.remote sub-environment actors.
130
+ self.pending: Dict[ray.actor.ActorHandle] = {
131
+ a.reset.remote(): a for a in self.actors
132
+ }
133
+
134
+ @override(BaseEnv)
135
+ def poll(
136
+ self,
137
+ ) -> Tuple[
138
+ MultiEnvDict,
139
+ MultiEnvDict,
140
+ MultiEnvDict,
141
+ MultiEnvDict,
142
+ MultiEnvDict,
143
+ MultiEnvDict,
144
+ ]:
145
+
146
+ # each keyed by env_id in [0, num_remote_envs)
147
+ obs, rewards, terminateds, truncateds, infos = {}, {}, {}, {}, {}
148
+ ready = []
149
+
150
+ # Wait for at least 1 env to be ready here.
151
+ while not ready:
152
+ ready, _ = ray.wait(
153
+ list(self.pending),
154
+ num_returns=len(self.pending),
155
+ timeout=self.poll_timeout,
156
+ )
157
+
158
+ # Get and return observations for each of the ready envs
159
+ env_ids = set()
160
+ for obj_ref in ready:
161
+ # Get the corresponding actor handle from our dict and remove the
162
+ # object ref (we will call `ray.get()` on it and it will no longer
163
+ # be "pending").
164
+ actor = self.pending.pop(obj_ref)
165
+ env_id = self.actors.index(actor)
166
+ env_ids.add(env_id)
167
+ # Get the ready object ref (this may be return value(s) of
168
+ # `reset()` or `step()`).
169
+ try:
170
+ ret = ray.get(obj_ref)
171
+ except Exception as e:
172
+ # Something happened on the actor during stepping/resetting.
173
+ # Restart sub-environment (create new actor; close old one).
174
+ if self.restart_failed_sub_environments:
175
+ logger.exception(e.args[0])
176
+ self.try_restart(env_id)
177
+ # Always return multi-agent data.
178
+ # Set the observation to the exception, no rewards,
179
+ # terminated[__all__]=True (episode will be discarded anyways),
180
+ # no infos.
181
+ ret = (
182
+ e,
183
+ {},
184
+ {"__all__": True},
185
+ {"__all__": False},
186
+ {},
187
+ )
188
+ # Do not try to restart. Just raise the error.
189
+ else:
190
+ raise e
191
+
192
+ # Our sub-envs are simple Actor-turned gym.Envs or MultiAgentEnvs.
193
+ if self.make_env_creates_actors:
194
+ rew, terminated, truncated, info = None, None, None, None
195
+ if self.multiagent:
196
+ if isinstance(ret, tuple):
197
+ # Gym >= 0.26: `step()` result: Obs, reward, terminated,
198
+ # truncated, info.
199
+ if len(ret) == 5:
200
+ ob, rew, terminated, truncated, info = ret
201
+ # Gym >= 0.26: `reset()` result: Obs and infos.
202
+ elif len(ret) == 2:
203
+ ob = ret[0]
204
+ info = ret[1]
205
+ # Gym < 0.26? Something went wrong.
206
+ else:
207
+ raise AssertionError(
208
+ "Your gymnasium.Env seems to NOT return the correct "
209
+ "number of return values for `step()` (needs to return"
210
+ " 5 values: obs, reward, terminated, truncated and "
211
+ "info) or `reset()` (needs to return 2 values: obs and "
212
+ "info)!"
213
+ )
214
+ # Gym < 0.26: `reset()` result: Only obs.
215
+ else:
216
+ raise AssertionError(
217
+ "Your gymnasium.Env seems to only return a single value "
218
+ "upon `reset()`! Must return 2 (obs AND infos)."
219
+ )
220
+ else:
221
+ if isinstance(ret, tuple):
222
+ # `step()` result: Obs, reward, terminated, truncated, info.
223
+ if len(ret) == 5:
224
+ ob = {_DUMMY_AGENT_ID: ret[0]}
225
+ rew = {_DUMMY_AGENT_ID: ret[1]}
226
+ terminated = {_DUMMY_AGENT_ID: ret[2], "__all__": ret[2]}
227
+ truncated = {_DUMMY_AGENT_ID: ret[3], "__all__": ret[3]}
228
+ info = {_DUMMY_AGENT_ID: ret[4]}
229
+ # `reset()` result: Obs and infos.
230
+ elif len(ret) == 2:
231
+ ob = {_DUMMY_AGENT_ID: ret[0]}
232
+ info = {_DUMMY_AGENT_ID: ret[1]}
233
+ # Gym < 0.26? Something went wrong.
234
+ else:
235
+ raise AssertionError(
236
+ "Your gymnasium.Env seems to NOT return the correct "
237
+ "number of return values for `step()` (needs to return"
238
+ " 5 values: obs, reward, terminated, truncated and "
239
+ "info) or `reset()` (needs to return 2 values: obs and "
240
+ "info)!"
241
+ )
242
+ # Gym < 0.26?
243
+ else:
244
+ raise AssertionError(
245
+ "Your gymnasium.Env seems to only return a single value "
246
+ "upon `reset()`! Must return 2 (obs and infos)."
247
+ )
248
+
249
+ # If this is a `reset()` return value, we only have the initial
250
+ # observations and infos: Set rewards, terminateds, and truncateds to
251
+ # dummy values.
252
+ if rew is None:
253
+ rew = {agent_id: 0 for agent_id in ob.keys()}
254
+ terminated = {"__all__": False}
255
+ truncated = {"__all__": False}
256
+
257
+ # Our sub-envs are auto-wrapped (by `_RemoteSingleAgentEnv` or
258
+ # `_RemoteMultiAgentEnv`) and already behave like multi-agent
259
+ # envs.
260
+ else:
261
+ ob, rew, terminated, truncated, info = ret
262
+ obs[env_id] = ob
263
+ rewards[env_id] = rew
264
+ terminateds[env_id] = terminated
265
+ truncateds[env_id] = truncated
266
+ infos[env_id] = info
267
+
268
+ logger.debug(f"Got obs batch for actors {env_ids}")
269
+ return obs, rewards, terminateds, truncateds, infos, {}
270
+
271
+ @override(BaseEnv)
272
+ def send_actions(self, action_dict: MultiEnvDict) -> None:
273
+ for env_id, actions in action_dict.items():
274
+ actor = self.actors[env_id]
275
+ # `actor` is a simple single-agent (remote) env, e.g. a gym.Env
276
+ # that was made a @ray.remote.
277
+ if not self.multiagent and self.make_env_creates_actors:
278
+ obj_ref = actor.step.remote(actions[_DUMMY_AGENT_ID])
279
+ # `actor` is already a _RemoteSingleAgentEnv or
280
+ # _RemoteMultiAgentEnv wrapper
281
+ # (handles the multi-agent action_dict automatically).
282
+ else:
283
+ obj_ref = actor.step.remote(actions)
284
+ self.pending[obj_ref] = actor
285
+
286
+ @override(BaseEnv)
287
+ def try_reset(
288
+ self,
289
+ env_id: Optional[EnvID] = None,
290
+ *,
291
+ seed: Optional[int] = None,
292
+ options: Optional[dict] = None,
293
+ ) -> Tuple[MultiEnvDict, MultiEnvDict]:
294
+ actor = self.actors[env_id]
295
+ obj_ref = actor.reset.remote(seed=seed, options=options)
296
+
297
+ self.pending[obj_ref] = actor
298
+ # Because this env type does not support synchronous reset requests (with
299
+ # immediate return value), we return ASYNC_RESET_RETURN here to indicate
300
+ # that the reset results will be available via the next `poll()` call.
301
+ return ASYNC_RESET_RETURN, ASYNC_RESET_RETURN
302
+
303
+ @override(BaseEnv)
304
+ def try_restart(self, env_id: Optional[EnvID] = None) -> None:
305
+ # Try closing down the old (possibly faulty) sub-env, but ignore errors.
306
+ try:
307
+ # Close the env on the remote side.
308
+ self.actors[env_id].close.remote()
309
+ except Exception as e:
310
+ if log_once("close_sub_env"):
311
+ logger.warning(
312
+ "Trying to close old and replaced sub-environment (at vector "
313
+ f"index={env_id}), but closing resulted in error:\n{e}"
314
+ )
315
+
316
+ # Terminate the actor itself to free up its resources.
317
+ self.actors[env_id].__ray_terminate__.remote()
318
+
319
+ # Re-create a new sub-environment.
320
+ self.actors[env_id] = self._make_sub_env(env_id)
321
+
322
+ @override(BaseEnv)
323
+ def stop(self) -> None:
324
+ if self.actors is not None:
325
+ for actor in self.actors:
326
+ actor.__ray_terminate__.remote()
327
+
328
+ @override(BaseEnv)
329
+ def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
330
+ if as_dict:
331
+ return {env_id: actor for env_id, actor in enumerate(self.actors)}
332
+ return self.actors
333
+
334
    @property
    @override(BaseEnv)
    def observation_space(self) -> gym.spaces.Dict:
        """Returns the observation space stored at construction time."""
        return self._observation_space
338
+
339
    @property
    @override(BaseEnv)
    def action_space(self) -> gym.Space:
        """Returns the action space stored at construction time."""
        return self._action_space
343
+
344
+ def _make_sub_env(self, idx: Optional[int] = None):
345
+ """Re-creates a sub-environment at the new index."""
346
+
347
+ # Our `make_env` creates ray actors directly.
348
+ if self.make_env_creates_actors:
349
+ sub_env = self.make_env(idx)
350
+ if self.worker is not None:
351
+ self.worker.callbacks.on_sub_environment_created(
352
+ worker=self.worker,
353
+ sub_environment=self.actors[idx],
354
+ env_context=self.worker.env_context.copy_with_overrides(
355
+ vector_index=idx
356
+ ),
357
+ )
358
+
359
+ # Our `make_env` returns actual envs -> Have to convert them into actors
360
+ # using our utility wrapper classes.
361
+ else:
362
+
363
+ def make_remote_env(i):
364
+ logger.info("Launching env {} in remote actor".format(i))
365
+ if self.multiagent:
366
+ sub_env = _RemoteMultiAgentEnv.remote(self.make_env, i)
367
+ else:
368
+ sub_env = _RemoteSingleAgentEnv.remote(self.make_env, i)
369
+
370
+ if self.worker is not None:
371
+ self.worker.callbacks.on_sub_environment_created(
372
+ worker=self.worker,
373
+ sub_environment=sub_env,
374
+ env_context=self.worker.env_context.copy_with_overrides(
375
+ vector_index=i
376
+ ),
377
+ )
378
+
379
+ return sub_env
380
+
381
+ sub_env = make_remote_env(idx)
382
+
383
+ return sub_env
384
+
385
+ @override(BaseEnv)
386
+ def get_agent_ids(self) -> Set[AgentID]:
387
+ if self.multiagent:
388
+ return ray.get(self.actors[0].get_agent_ids.remote())
389
+ else:
390
+ return {_DUMMY_AGENT_ID}
391
+
392
+
393
@ray.remote(num_cpus=0)
class _RemoteMultiAgentEnv:
    """Wrapper class for making a multi-agent env a remote actor."""

    def __init__(self, make_env, i):
        self.env = make_env(i)
        self.agent_ids = set()

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Resets the wrapped env; returns obs/rew/term/trunc/info dicts."""
        obs, info = self.env.reset(seed=seed, options=options)

        # Remember all agent IDs seen in the reset obs and pair each with a
        # zero reward (each returned dict is keyed by agent ID).
        self.agent_ids.update(obs.keys())
        rew = {agent_id: 0.0 for agent_id in obs}
        return obs, rew, {"__all__": False}, {"__all__": False}, info

    def step(self, action_dict):
        return self.env.step(action_dict)

    # These are plain methods (not properties), so the information can be
    # queried remotely via a call to ray.get().
    def observation_space(self):
        return self.env.observation_space

    def action_space(self):
        return self.env.action_space

    def get_agent_ids(self) -> Set[AgentID]:
        return self.agent_ids
426
+
427
+
428
@ray.remote(num_cpus=0)
class _RemoteSingleAgentEnv:
    """Wrapper class for making a gym env a remote actor."""

    def __init__(self, make_env, i):
        self.env = make_env(i)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Resets the wrapped env; returns multi-agent-style (dummy-keyed) dicts."""
        obs, info = self.env.reset(seed=seed, options=options)
        return (
            {_DUMMY_AGENT_ID: obs},
            {_DUMMY_AGENT_ID: 0.0},
            {"__all__": False},
            {"__all__": False},
            {_DUMMY_AGENT_ID: info},
        )

    def step(self, action):
        """Steps the wrapped env with the dummy agent's action."""
        obs, rew, terminated, truncated, info = self.env.step(
            action[_DUMMY_AGENT_ID]
        )
        # Mirror the single env's done flags under "__all__" as well.
        return (
            {_DUMMY_AGENT_ID: obs},
            {_DUMMY_AGENT_ID: rew},
            {_DUMMY_AGENT_ID: terminated, "__all__": terminated},
            {_DUMMY_AGENT_ID: truncated, "__all__": truncated},
            {_DUMMY_AGENT_ID: info},
        )

    # These are plain methods (not properties), so the information can be
    # queried remotely via a call to ray.get().
    def observation_space(self):
        return self.env.observation_space

    def action_space(self):
        return self.env.action_space
.venv/lib/python3.11/site-packages/ray/rllib/env/single_agent_env_runner.py ADDED
@@ -0,0 +1,853 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from functools import partial
3
+ import logging
4
+ import time
5
+ from typing import Collection, DefaultDict, List, Optional, Union
6
+
7
+ import gymnasium as gym
8
+ from gymnasium.wrappers.vector import DictInfoToList
9
+
10
+ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
11
+ from ray.rllib.callbacks.callbacks import RLlibCallback
12
+ from ray.rllib.callbacks.utils import make_callback
13
+ from ray.rllib.core import (
14
+ COMPONENT_ENV_TO_MODULE_CONNECTOR,
15
+ COMPONENT_MODULE_TO_ENV_CONNECTOR,
16
+ COMPONENT_RL_MODULE,
17
+ DEFAULT_AGENT_ID,
18
+ DEFAULT_MODULE_ID,
19
+ )
20
+ from ray.rllib.core.columns import Columns
21
+ from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
22
+ from ray.rllib.env import INPUT_ENV_SPACES
23
+ from ray.rllib.env.env_context import EnvContext
24
+ from ray.rllib.env.env_runner import EnvRunner, ENV_STEP_FAILURE
25
+ from ray.rllib.env.single_agent_episode import SingleAgentEpisode
26
+ from ray.rllib.env.utils import _gym_env_creator
27
+ from ray.rllib.utils import force_list
28
+ from ray.rllib.utils.annotations import override
29
+ from ray.rllib.utils.checkpoints import Checkpointable
30
+ from ray.rllib.utils.deprecation import Deprecated
31
+ from ray.rllib.utils.framework import get_device
32
+ from ray.rllib.utils.metrics import (
33
+ EPISODE_DURATION_SEC_MEAN,
34
+ EPISODE_LEN_MAX,
35
+ EPISODE_LEN_MEAN,
36
+ EPISODE_LEN_MIN,
37
+ EPISODE_RETURN_MAX,
38
+ EPISODE_RETURN_MEAN,
39
+ EPISODE_RETURN_MIN,
40
+ NUM_AGENT_STEPS_SAMPLED,
41
+ NUM_AGENT_STEPS_SAMPLED_LIFETIME,
42
+ NUM_ENV_STEPS_SAMPLED,
43
+ NUM_ENV_STEPS_SAMPLED_LIFETIME,
44
+ NUM_EPISODES,
45
+ NUM_EPISODES_LIFETIME,
46
+ NUM_MODULE_STEPS_SAMPLED,
47
+ NUM_MODULE_STEPS_SAMPLED_LIFETIME,
48
+ SAMPLE_TIMER,
49
+ TIME_BETWEEN_SAMPLING,
50
+ WEIGHTS_SEQ_NO,
51
+ )
52
+ from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
53
+ from ray.rllib.utils.spaces.space_utils import unbatch
54
+ from ray.rllib.utils.typing import EpisodeID, ResultDict, StateDict
55
+ from ray.tune.registry import ENV_CREATOR, _global_registry
56
+ from ray.util.annotations import PublicAPI
57
+
58
+ logger = logging.getLogger("ray.rllib")
59
+
60
+
61
+ # TODO (sven): As soon as RolloutWorker is no longer supported, make `EnvRunner` itself
62
+ # a Checkpointable. Currently, only some of its subclasses are Checkpointables.
63
+ @PublicAPI(stability="alpha")
64
+ class SingleAgentEnvRunner(EnvRunner, Checkpointable):
65
+ """The generic environment runner for the single agent case."""
66
+
67
    @override(EnvRunner)
    def __init__(self, *, config: AlgorithmConfig, **kwargs):
        """Initializes a SingleAgentEnvRunner instance.

        Args:
            config: An `AlgorithmConfig` object containing all settings needed to
                build this `EnvRunner` class.
        """
        super().__init__(config=config)

        # Worker/runner bookkeeping passed in by the creating EnvRunnerGroup.
        self.worker_index: int = kwargs.get("worker_index")
        self.num_workers: int = kwargs.get("num_workers", self.config.num_env_runners)
        self.tune_trial_id: str = kwargs.get("tune_trial_id")

        # Create a MetricsLogger object for logging custom stats.
        self.metrics = MetricsLogger()

        # Create our callbacks object.
        self._callbacks: List[RLlibCallback] = [
            cls() for cls in force_list(self.config.callbacks_class)
        ]

        # Set device. The local runner (falsy `worker_index`) always requests
        # 0 GPUs here.
        self._device = get_device(
            self.config,
            0 if not self.worker_index else self.config.num_gpus_per_env_runner,
        )

        # Create the vectorized gymnasium env.
        self.env: Optional[gym.vector.VectorEnvWrapper] = None
        self.num_envs: int = 0
        self.make_env()

        # Create the env-to-module connector pipeline.
        self._env_to_module = self.config.build_env_to_module_connector(
            self.env, device=self._device
        )
        # Cached env-to-module results taken at the end of a `_sample_timesteps()`
        # call to make sure the final observation (before an episode cut) gets properly
        # processed (and maybe postprocessed and re-stored into the episode).
        # For example, if we had a connector that normalizes observations and directly
        # re-inserts these new obs back into the episode, the last observation in each
        # sample call would NOT be processed, which could be very harmful in cases,
        # in which value function bootstrapping of those (truncation) observations is
        # required in the learning step.
        self._cached_to_module = None

        # Create the RLModule (may remain None; see `make_module`).
        self.module: Optional[RLModule] = None
        self.make_module()

        # Create the module-to-env connector pipeline.
        self._module_to_env = self.config.build_module_to_env_connector(self.env)

        # This should be the default.
        self._needs_initial_reset: bool = True
        # One (currently empty) episode slot per vector sub-env.
        self._episodes: List[Optional[SingleAgentEpisode]] = [
            None for _ in range(self.num_envs)
        ]
        self._shared_data = None

        # Completed episodes awaiting metrics computation, plus chunks of
        # still-ongoing episodes already returned to the caller (keyed by
        # episode ID).
        self._done_episodes_for_metrics: List[SingleAgentEpisode] = []
        self._ongoing_episodes_for_metrics: DefaultDict[
            EpisodeID, List[SingleAgentEpisode]
        ] = defaultdict(list)
        self._weights_seq_no: int = 0

        # Measures the time passed between returning from `sample()`
        # and receiving the next `sample()` request from the user.
        self._time_after_sampling = None
137
+
138
    @override(EnvRunner)
    def sample(
        self,
        *,
        num_timesteps: int = None,
        num_episodes: int = None,
        explore: bool = None,
        random_actions: bool = False,
        force_reset: bool = False,
    ) -> List[SingleAgentEpisode]:
        """Runs and returns a sample (n timesteps or m episodes) on the env(s).

        Args:
            num_timesteps: The number of timesteps to sample during this call.
                Note that only one of `num_timesteps` or `num_episodes` may be
                provided.
            num_episodes: The number of episodes to sample during this call.
                Note that only one of `num_timesteps` or `num_episodes` may be
                provided.
            explore: If True, will use the RLModule's `forward_exploration()`
                method to compute actions. If False, will use the RLModule's
                `forward_inference()` method. If None (default), will use the `explore`
                boolean setting from `self.config` passed into this EnvRunner's
                constructor. You can change this setting in your config via
                `config.env_runners(explore=True|False)`.
            random_actions: If True, actions will be sampled randomly (from the action
                space of the environment). If False (default), actions or action
                distribution parameters are computed by the RLModule.
            force_reset: Whether to force-reset all (vector) environments before
                sampling. Useful if you would like to collect a clean slate of new
                episodes via this call. Note that when sampling n episodes
                (`num_episodes != None`), this is fixed to True.

        Returns:
            A list of `SingleAgentEpisode` instances, carrying the sampled data.
        """
        # Caller may request timesteps OR episodes, never both.
        assert not (num_timesteps is not None and num_episodes is not None)

        # Log time between `sample()` requests.
        if self._time_after_sampling is not None:
            self.metrics.log_value(
                key=TIME_BETWEEN_SAMPLING,
                value=time.perf_counter() - self._time_after_sampling,
            )

        # Log current weight seq no.
        self.metrics.log_value(
            key=WEIGHTS_SEQ_NO,
            value=self._weights_seq_no,
            window=1,
        )

        with self.metrics.log_time(SAMPLE_TIMER):
            # If no execution details are provided, use the config to try to infer the
            # desired timesteps/episodes to sample and exploration behavior.
            if explore is None:
                explore = self.config.explore
            if (
                num_timesteps is None
                and num_episodes is None
                and self.config.batch_mode == "truncate_episodes"
            ):
                num_timesteps = (
                    self.config.get_rollout_fragment_length(self.worker_index)
                    * self.num_envs
                )

            # Sample n timesteps.
            if num_timesteps is not None:
                samples = self._sample(
                    num_timesteps=num_timesteps,
                    explore=explore,
                    random_actions=random_actions,
                    force_reset=force_reset,
                )
            # Sample m episodes.
            elif num_episodes is not None:
                samples = self._sample(
                    num_episodes=num_episodes,
                    explore=explore,
                    random_actions=random_actions,
                )
            # For complete episodes mode, sample as long as the number of timesteps
            # done is smaller than the `train_batch_size`.
            else:
                samples = self._sample(
                    num_episodes=self.num_envs,
                    explore=explore,
                    random_actions=random_actions,
                )

            # Make the `on_sample_end` callback.
            make_callback(
                "on_sample_end",
                callbacks_objects=self._callbacks,
                callbacks_functions=self.config.callbacks_on_sample_end,
                kwargs=dict(
                    env_runner=self,
                    metrics_logger=self.metrics,
                    samples=samples,
                ),
            )

        # Remember when we returned, so TIME_BETWEEN_SAMPLING can be logged on
        # the next request.
        self._time_after_sampling = time.perf_counter()

        return samples
242
+
243
    def _sample(
        self,
        *,
        num_timesteps: Optional[int] = None,
        num_episodes: Optional[int] = None,
        explore: bool,
        random_actions: bool = False,
        force_reset: bool = False,
    ) -> List[SingleAgentEpisode]:
        """Helper method to sample n timesteps or m episodes.

        Runs the act -> step -> record loop over the vectorized env until the
        requested number of timesteps (across all sub-envs) or episodes has been
        collected, then returns completed episodes plus cut ongoing chunks.
        """

        done_episodes_to_return: List[SingleAgentEpisode] = []

        # Have to reset the env (on all vector sub_envs).
        if force_reset or num_episodes is not None or self._needs_initial_reset:
            episodes = self._episodes = [None for _ in range(self.num_envs)]
            shared_data = self._shared_data = {}
            self._reset_envs(episodes, shared_data, explore)
            # We just reset the env. Don't have to force this again in the next
            # call to `self._sample_timesteps()`.
            self._needs_initial_reset = False
        else:
            episodes = self._episodes
            shared_data = self._shared_data

        # Episode-mode sampling always starts fresh on the next call.
        if num_episodes is not None:
            self._needs_initial_reset = True

        # Loop through `num_timesteps` timesteps or `num_episodes` episodes.
        ts = 0
        eps = 0
        while (
            (ts < num_timesteps) if num_timesteps is not None else (eps < num_episodes)
        ):
            # Act randomly.
            if random_actions:
                to_env = {
                    Columns.ACTIONS: self.env.action_space.sample(),
                }
            # Compute an action using the RLModule.
            else:
                # Env-to-module connector output (already cached by the previous
                # iteration or by `_reset_envs`).
                to_module = self._cached_to_module
                assert to_module is not None
                self._cached_to_module = None

                # RLModule forward pass: Explore or not.
                if explore:
                    # Global env steps sampled are (roughly) this EnvRunner's lifetime
                    # count times the number of env runners in the algo.
                    global_env_steps_lifetime = (
                        self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0)
                        + ts
                    ) * (self.config.num_env_runners or 1)
                    to_env = self.module.forward_exploration(
                        to_module, t=global_env_steps_lifetime
                    )
                else:
                    to_env = self.module.forward_inference(to_module)

                # Module-to-env connector.
                to_env = self._module_to_env(
                    rl_module=self.module,
                    batch=to_env,
                    episodes=episodes,
                    explore=explore,
                    shared_data=shared_data,
                    metrics=self.metrics,
                )

            # Extract the (vectorized) actions (to be sent to the env) from the
            # module/connector output. Note that these actions are fully ready (e.g.
            # already unsquashed/clipped) to be sent to the environment and might not
            # be identical to the actions produced by the RLModule/distribution, which
            # are the ones stored permanently in the episode objects.
            actions = to_env.pop(Columns.ACTIONS)
            actions_for_env = to_env.pop(Columns.ACTIONS_FOR_ENV, actions)
            # Try stepping the environment. On failure, restart sampling with a
            # forced env reset.
            results = self._try_env_step(actions_for_env)
            if results == ENV_STEP_FAILURE:
                return self._sample(
                    num_timesteps=num_timesteps,
                    num_episodes=num_episodes,
                    explore=explore,
                    random_actions=random_actions,
                    force_reset=True,
                )
            observations, rewards, terminateds, truncateds, infos = results
            observations, actions = unbatch(observations), unbatch(actions)

            call_on_episode_start = set()
            for env_index in range(self.num_envs):
                extra_model_output = {k: v[env_index] for k, v in to_env.items()}
                extra_model_output[WEIGHTS_SEQ_NO] = self._weights_seq_no

                # Episode has no data in it yet -> Was just reset and needs to be called
                # with its `add_env_reset()` method.
                if not self._episodes[env_index].is_reset:
                    episodes[env_index].add_env_reset(
                        observation=observations[env_index],
                        infos=infos[env_index],
                    )
                    call_on_episode_start.add(env_index)

                # Call `add_env_step()` method on episode.
                else:
                    # Only increase ts when we actually stepped (not reset'd as a reset
                    # does not count as a timestep).
                    ts += 1
                    episodes[env_index].add_env_step(
                        observation=observations[env_index],
                        action=actions[env_index],
                        reward=rewards[env_index],
                        infos=infos[env_index],
                        terminated=terminateds[env_index],
                        truncated=truncateds[env_index],
                        extra_model_outputs=extra_model_output,
                    )

            # Env-to-module connector pass (cache results as we will do the RLModule
            # forward pass only in the next `while`-iteration).
            if self.module is not None:
                self._cached_to_module = self._env_to_module(
                    episodes=episodes,
                    explore=explore,
                    rl_module=self.module,
                    shared_data=shared_data,
                    metrics=self.metrics,
                )

            for env_index in range(self.num_envs):
                # Call `on_episode_start()` callback (always after reset).
                if env_index in call_on_episode_start:
                    self._make_on_episode_callback(
                        "on_episode_start", env_index, episodes
                    )
                # Make the `on_episode_step` callbacks.
                else:
                    self._make_on_episode_callback(
                        "on_episode_step", env_index, episodes
                    )

                # Episode is done.
                if episodes[env_index].is_done:
                    eps += 1

                    # Make the `on_episode_end` callbacks (before finalizing the episode
                    # object).
                    self._make_on_episode_callback(
                        "on_episode_end", env_index, episodes
                    )

                    # Numpy'ize the episode.
                    if self.config.episodes_to_numpy:
                        # And possibly compress observations.
                        done_episodes_to_return.append(episodes[env_index].to_numpy())
                    # Leave episode as lists of individual (obs, action, etc..) items.
                    else:
                        done_episodes_to_return.append(episodes[env_index])

                    # Also early-out if we reach the number of episodes within this
                    # for-loop.
                    if eps == num_episodes:
                        break

                    # Create a new episode object with no data in it and execute
                    # `on_episode_created` callback (before the `env.reset()` call).
                    episodes[env_index] = SingleAgentEpisode(
                        observation_space=self.env.single_observation_space,
                        action_space=self.env.single_action_space,
                    )
                    self._make_on_episode_callback(
                        "on_episode_created",
                        env_index,
                        episodes,
                    )

        # Return done episodes ...
        self._done_episodes_for_metrics.extend(done_episodes_to_return)
        # ... and all ongoing episode chunks.

        # Also, make sure we start new episode chunks (continuing the ongoing episodes
        # from the to-be-returned chunks).
        ongoing_episodes_to_return = []
        # Only if we are doing individual timesteps: We have to maybe cut an ongoing
        # episode and continue building it on the next call to `sample()`.
        if num_timesteps is not None:
            ongoing_episodes_continuations = [
                eps.cut(len_lookback_buffer=self.config.episode_lookback_horizon)
                for eps in self._episodes
            ]

            for eps in self._episodes:
                # Just started Episodes do not have to be returned. There is no data
                # in them anyway.
                if eps.t == 0:
                    continue
                eps.validate()
                # Remember the returned chunk so `get_metrics` can later stitch
                # together the full episode's stats.
                self._ongoing_episodes_for_metrics[eps.id_].append(eps)

                # Numpy'ize the episode.
                if self.config.episodes_to_numpy:
                    # And possibly compress observations.
                    ongoing_episodes_to_return.append(eps.to_numpy())
                # Leave episode as lists of individual (obs, action, etc..) items.
                else:
                    ongoing_episodes_to_return.append(eps)

            # Continue collecting into the cut Episode chunks.
            self._episodes = ongoing_episodes_continuations

        self._increase_sampled_metrics(ts, len(done_episodes_to_return))

        # Return collected episode data.
        return done_episodes_to_return + ongoing_episodes_to_return
458
+
459
+ @override(EnvRunner)
460
+ def get_spaces(self):
461
+ return {
462
+ INPUT_ENV_SPACES: (self.env.observation_space, self.env.action_space),
463
+ DEFAULT_MODULE_ID: (
464
+ self._env_to_module.observation_space,
465
+ self.env.single_action_space,
466
+ ),
467
+ }
468
+
469
+ @override(EnvRunner)
470
+ def get_metrics(self) -> ResultDict:
471
+ # Compute per-episode metrics (only on already completed episodes).
472
+ for eps in self._done_episodes_for_metrics:
473
+ assert eps.is_done
474
+ episode_length = len(eps)
475
+ episode_return = eps.get_return()
476
+ episode_duration_s = eps.get_duration_s()
477
+ # Don't forget about the already returned chunks of this episode.
478
+ if eps.id_ in self._ongoing_episodes_for_metrics:
479
+ for eps2 in self._ongoing_episodes_for_metrics[eps.id_]:
480
+ episode_length += len(eps2)
481
+ episode_return += eps2.get_return()
482
+ episode_duration_s += eps2.get_duration_s()
483
+ del self._ongoing_episodes_for_metrics[eps.id_]
484
+
485
+ self._log_episode_metrics(
486
+ episode_length, episode_return, episode_duration_s
487
+ )
488
+
489
+ # Now that we have logged everything, clear cache of done episodes.
490
+ self._done_episodes_for_metrics.clear()
491
+
492
+ # Return reduced metrics.
493
+ return self.metrics.reduce()
494
+
495
    @override(Checkpointable)
    def get_state(
        self,
        components: Optional[Union[str, Collection[str]]] = None,
        *,
        not_components: Optional[Union[str, Collection[str]]] = None,
        **kwargs,
    ) -> StateDict:
        """Returns this EnvRunner's state (module weights, connectors, counters).

        Args:
            components: If given, only include these components' states.
            not_components: If given, exclude these components' states.
            **kwargs: Forwarded to `self.module.get_state()`.

        Returns:
            A state dict; always contains the lifetime env-step count.
        """
        state = {
            NUM_ENV_STEPS_SAMPLED_LIFETIME: (
                self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0)
            ),
        }

        if self._check_component(COMPONENT_RL_MODULE, components, not_components):
            state[COMPONENT_RL_MODULE] = self.module.get_state(
                components=self._get_subcomponents(COMPONENT_RL_MODULE, components),
                not_components=self._get_subcomponents(
                    COMPONENT_RL_MODULE, not_components
                ),
                **kwargs,
            )
            # The weights seq-no travels together with the module weights (it
            # versions exactly those weights).
            state[WEIGHTS_SEQ_NO] = self._weights_seq_no
        if self._check_component(
            COMPONENT_ENV_TO_MODULE_CONNECTOR, components, not_components
        ):
            state[COMPONENT_ENV_TO_MODULE_CONNECTOR] = self._env_to_module.get_state()
        if self._check_component(
            COMPONENT_MODULE_TO_ENV_CONNECTOR, components, not_components
        ):
            state[COMPONENT_MODULE_TO_ENV_CONNECTOR] = self._module_to_env.get_state()

        return state
528
+
529
    @override(Checkpointable)
    def set_state(self, state: StateDict) -> None:
        """Restores EnvRunner state: connectors, module weights, and counters."""
        if COMPONENT_ENV_TO_MODULE_CONNECTOR in state:
            self._env_to_module.set_state(state[COMPONENT_ENV_TO_MODULE_CONNECTOR])
        if COMPONENT_MODULE_TO_ENV_CONNECTOR in state:
            self._module_to_env.set_state(state[COMPONENT_MODULE_TO_ENV_CONNECTOR])

        # Update the RLModule state.
        if COMPONENT_RL_MODULE in state:
            # A missing value for WEIGHTS_SEQ_NO or a value of 0 means: Force the
            # update.
            weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0)

            # Only update the weights, if this is the first synchronization or
            # if the weights of this `EnvRunner` lag behind the actual ones.
            if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no:
                rl_module_state = state[COMPONENT_RL_MODULE]
                # Unwrap a multi-module state dict down to the single default
                # module's state, if necessary.
                if (
                    isinstance(rl_module_state, dict)
                    and DEFAULT_MODULE_ID in rl_module_state
                ):
                    rl_module_state = rl_module_state[DEFAULT_MODULE_ID]
                self.module.set_state(rl_module_state)

            # Update our weights_seq_no, if the new one is > 0.
            if weights_seq_no > 0:
                self._weights_seq_no = weights_seq_no

        # Update our lifetime counters.
        if NUM_ENV_STEPS_SAMPLED_LIFETIME in state:
            self.metrics.set_value(
                key=NUM_ENV_STEPS_SAMPLED_LIFETIME,
                value=state[NUM_ENV_STEPS_SAMPLED_LIFETIME],
                reduce="sum",
                with_throughput=True,
            )
565
+
566
+ @override(Checkpointable)
567
+ def get_ctor_args_and_kwargs(self):
568
+ return (
569
+ (), # *args
570
+ {"config": self.config}, # **kwargs
571
+ )
572
+
573
+ @override(Checkpointable)
574
+ def get_metadata(self):
575
+ metadata = Checkpointable.get_metadata(self)
576
+ metadata.update(
577
+ {
578
+ # TODO (sven): Maybe add serialized (JSON-writable) config here?
579
+ }
580
+ )
581
+ return metadata
582
+
583
+ @override(Checkpointable)
584
+ def get_checkpointable_components(self):
585
+ return [
586
+ (COMPONENT_RL_MODULE, self.module),
587
+ (COMPONENT_ENV_TO_MODULE_CONNECTOR, self._env_to_module),
588
+ (COMPONENT_MODULE_TO_ENV_CONNECTOR, self._module_to_env),
589
+ ]
590
+
591
+ @override(EnvRunner)
592
+ def assert_healthy(self):
593
+ """Checks that self.__init__() has been completed properly.
594
+
595
+ Ensures that the instances has a `MultiRLModule` and an
596
+ environment defined.
597
+
598
+ Raises:
599
+ AssertionError: If the EnvRunner Actor has NOT been properly initialized.
600
+ """
601
+ # Make sure, we have built our gym.vector.Env and RLModule properly.
602
+ assert self.env and hasattr(self, "module")
603
+
604
    @override(EnvRunner)
    def make_env(self) -> None:
        """Creates a vectorized gymnasium env and stores it in `self.env`.

        Note that users can change the EnvRunner's config (e.g. change
        `self.config.env_config`) and then call this method to create new environments
        with the updated configuration.
        """
        # If an env already exists, try closing it first (to allow it to properly
        # cleanup).
        if self.env is not None:
            try:
                self.env.close()
            except Exception as e:
                logger.warning(
                    "Tried closing the existing env, but failed with error: "
                    f"{e.args[0]}"
                )

        # Wrap the raw env config into an EnvContext carrying this runner's
        # worker index information, if necessary.
        env_ctx = self.config.env_config
        if not isinstance(env_ctx, EnvContext):
            env_ctx = EnvContext(
                env_ctx,
                worker_index=self.worker_index,
                num_workers=self.num_workers,
                remote=self.config.remote_worker_envs,
            )

        # No env provided -> Error.
        if not self.config.env:
            raise ValueError(
                "`config.env` is not provided! You should provide a valid environment "
                "to your config through `config.environment([env descriptor e.g. "
                "'CartPole-v1'])`."
            )
        # Register env for the local context.
        # Note, `gym.register` has to be called on each worker.
        elif isinstance(self.config.env, str) and _global_registry.contains(
            ENV_CREATOR, self.config.env
        ):
            entry_point = partial(
                _global_registry.get(ENV_CREATOR, self.config.env),
                env_ctx,
            )
        else:
            entry_point = partial(
                _gym_env_creator,
                env_descriptor=self.config.env,
                env_context=env_ctx,
            )
        gym.register("rllib-single-agent-env-v0", entry_point=entry_point)
        vectorize_mode = self.config.gym_env_vectorize_mode

        # Build the vector env; `DictInfoToList` converts the vector env's
        # dict-of-arrays infos into a per-sub-env list of info dicts.
        self.env = DictInfoToList(
            gym.make_vec(
                "rllib-single-agent-env-v0",
                num_envs=self.config.num_envs_per_env_runner,
                vectorization_mode=(
                    vectorize_mode
                    if isinstance(vectorize_mode, gym.envs.registration.VectorizeMode)
                    else gym.envs.registration.VectorizeMode(vectorize_mode.lower())
                ),
            )
        )

        self.num_envs: int = self.env.num_envs
        assert self.num_envs == self.config.num_envs_per_env_runner

        # Set the flag to reset all envs upon the next `sample()` call.
        self._needs_initial_reset = True

        # Call the `on_environment_created` callback.
        make_callback(
            "on_environment_created",
            callbacks_objects=self._callbacks,
            callbacks_functions=self.config.callbacks_on_environment_created,
            kwargs=dict(
                env_runner=self,
                metrics_logger=self.metrics,
                env=self.env.unwrapped,
                env_context=env_ctx,
            ),
        )
687
+
688
    @override(EnvRunner)
    def make_module(self):
        """Builds the (inference-only) RLModule and stores it in `self.module`.

        Leaves `self.module` as None if the config does not implement
        `get_rl_module_spec()`; the runner may then still be used with
        `random_actions=True`.
        """
        try:
            module_spec: RLModuleSpec = self.config.get_rl_module_spec(
                env=self.env.unwrapped, spaces=self.get_spaces(), inference_only=True
            )
            # Build the module from its spec.
            self.module = module_spec.build()

            # Move the RLModule to our device.
            # TODO (sven): In order to make this framework-agnostic, we should maybe
            # make the RLModule.build() method accept a device OR create an additional
            # `RLModule.to()` override.
            self.module.to(self._device)

        # If `AlgorithmConfig.get_rl_module_spec()` is not implemented, this env runner
        # will not have an RLModule, but might still be usable with random actions.
        except NotImplementedError:
            self.module = None
707
+
708
    @override(EnvRunner)
    def stop(self):
        """Releases all resources used by this EnvRunner."""
        # Close our env object via gymnasium's API.
        self.env.close()
712
+
713
    def _reset_envs(self, episodes, shared_data, explore):
        """Resets all sub-environments and (re)starts the given episode objects.

        Order matters here: new episodes are created first (firing
        `on_episode_created`), then the vector env is reset, reset obs/infos are
        written into the episodes, the env-to-module connector is primed, and
        only then `on_episode_start` callbacks fire.

        Args:
            episodes: List (one slot per sub-env) to be filled with brand-new
                SingleAgentEpisode objects.
            shared_data: Dict in which connector pieces may exchange data.
            explore: Whether the env-to-module connector runs in exploration
                mode.
        """
        # Create n new episodes and make the `on_episode_created` callbacks.
        for env_index in range(self.num_envs):
            self._new_episode(env_index, episodes)

        # Erase all cached ongoing episodes (these will never be completed and
        # would thus never be returned/cleaned by `get_metrics` and cause a memory
        # leak).
        self._ongoing_episodes_for_metrics.clear()

        # Try resetting the environment.
        # TODO (simon): Check, if we need here the seed from the config.
        observations, infos = self._try_env_reset()
        # Unbatch the vectorized reset observations into one item per sub-env.
        observations = unbatch(observations)

        # Set initial obs and infos in the episodes.
        for env_index in range(self.num_envs):
            episodes[env_index].add_env_reset(
                observation=observations[env_index],
                infos=infos[env_index],
            )

        # Run the env-to-module connector to make sure the reset-obs/infos have
        # properly been processed (if applicable).
        self._cached_to_module = None
        if self.module:
            self._cached_to_module = self._env_to_module(
                rl_module=self.module,
                episodes=episodes,
                explore=explore,
                shared_data=shared_data,
                metrics=self.metrics,
            )

        # Call `on_episode_start()` callbacks (always after reset).
        for env_index in range(self.num_envs):
            self._make_on_episode_callback("on_episode_start", env_index, episodes)
750
+
751
+ def _new_episode(self, env_index, episodes=None):
752
+ episodes = episodes if episodes is not None else self._episodes
753
+ episodes[env_index] = SingleAgentEpisode(
754
+ observation_space=self.env.single_observation_space,
755
+ action_space=self.env.single_action_space,
756
+ )
757
+ self._make_on_episode_callback("on_episode_created", env_index, episodes)
758
+
759
    def _make_on_episode_callback(self, which: str, idx: int, episodes):
        """Fires the episode callback hook named `which` for the episode at `idx`.

        Args:
            which: Hook name, e.g. "on_episode_start". User-defined callback
                functions are looked up on the config under `callbacks_{which}`.
            idx: Index of the sub-environment/episode the callback refers to.
            episodes: The list of current SingleAgentEpisode objects.
        """
        make_callback(
            which,
            callbacks_objects=self._callbacks,
            callbacks_functions=getattr(self.config, f"callbacks_{which}"),
            kwargs=dict(
                episode=episodes[idx],
                env_runner=self,
                metrics_logger=self.metrics,
                env=self.env.unwrapped,
                rl_module=self.module,
                env_index=idx,
            ),
        )
773
+
774
+ def _increase_sampled_metrics(self, num_steps, num_episodes_completed):
775
+ # Per sample cycle stats.
776
+ self.metrics.log_value(
777
+ NUM_ENV_STEPS_SAMPLED, num_steps, reduce="sum", clear_on_reduce=True
778
+ )
779
+ self.metrics.log_value(
780
+ (NUM_AGENT_STEPS_SAMPLED, DEFAULT_AGENT_ID),
781
+ num_steps,
782
+ reduce="sum",
783
+ clear_on_reduce=True,
784
+ )
785
+ self.metrics.log_value(
786
+ (NUM_MODULE_STEPS_SAMPLED, DEFAULT_MODULE_ID),
787
+ num_steps,
788
+ reduce="sum",
789
+ clear_on_reduce=True,
790
+ )
791
+ self.metrics.log_value(
792
+ NUM_EPISODES,
793
+ num_episodes_completed,
794
+ reduce="sum",
795
+ clear_on_reduce=True,
796
+ )
797
+ # Lifetime stats.
798
+ self.metrics.log_value(
799
+ NUM_ENV_STEPS_SAMPLED_LIFETIME,
800
+ num_steps,
801
+ reduce="sum",
802
+ with_throughput=True,
803
+ )
804
+ self.metrics.log_value(
805
+ (NUM_AGENT_STEPS_SAMPLED_LIFETIME, DEFAULT_AGENT_ID),
806
+ num_steps,
807
+ reduce="sum",
808
+ )
809
+ self.metrics.log_value(
810
+ (NUM_MODULE_STEPS_SAMPLED_LIFETIME, DEFAULT_MODULE_ID),
811
+ num_steps,
812
+ reduce="sum",
813
+ )
814
+ self.metrics.log_value(
815
+ NUM_EPISODES_LIFETIME,
816
+ num_episodes_completed,
817
+ reduce="sum",
818
+ )
819
+ return num_steps
820
+
821
+ def _log_episode_metrics(self, length, ret, sec):
822
+ # Log general episode metrics.
823
+ # To mimic the old API stack behavior, we'll use `window` here for
824
+ # these particular stats (instead of the default EMA).
825
+ win = self.config.metrics_num_episodes_for_smoothing
826
+ self.metrics.log_value(EPISODE_LEN_MEAN, length, window=win)
827
+ self.metrics.log_value(EPISODE_RETURN_MEAN, ret, window=win)
828
+ self.metrics.log_value(EPISODE_DURATION_SEC_MEAN, sec, window=win)
829
+ # Per-agent returns.
830
+ self.metrics.log_value(
831
+ ("agent_episode_returns_mean", DEFAULT_AGENT_ID), ret, window=win
832
+ )
833
+ # Per-RLModule returns.
834
+ self.metrics.log_value(
835
+ ("module_episode_returns_mean", DEFAULT_MODULE_ID), ret, window=win
836
+ )
837
+
838
+ # For some metrics, log min/max as well.
839
+ self.metrics.log_value(EPISODE_LEN_MIN, length, reduce="min", window=win)
840
+ self.metrics.log_value(EPISODE_RETURN_MIN, ret, reduce="min", window=win)
841
+ self.metrics.log_value(EPISODE_LEN_MAX, length, reduce="max", window=win)
842
+ self.metrics.log_value(EPISODE_RETURN_MAX, ret, reduce="max", window=win)
843
+
844
    @Deprecated(
        new="SingleAgentEnvRunner.get_state(components='rl_module')",
        error=True,
    )
    def get_weights(self, *args, **kwargs):
        # Deprecated stub: `@Deprecated(error=True)` raises on any call; the
        # body is never reached.
        pass
850
+
851
    @Deprecated(new="SingleAgentEnvRunner.set_state()", error=True)
    def set_weights(self, *args, **kwargs):
        # Deprecated stub: `@Deprecated(error=True)` raises on any call; the
        # body is never reached.
        pass
.venv/lib/python3.11/site-packages/ray/rllib/env/single_agent_episode.py ADDED
@@ -0,0 +1,1862 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from collections import defaultdict
3
+ import numpy as np
4
+ import time
5
+ import uuid
6
+
7
+ import gymnasium as gym
8
+ from gymnasium.core import ActType, ObsType
9
+ from typing import Any, Dict, List, Optional, SupportsFloat, Union
10
+
11
+ from ray.rllib.core.columns import Columns
12
+ from ray.rllib.env.utils.infinite_lookback_buffer import InfiniteLookbackBuffer
13
+ from ray.rllib.policy.sample_batch import SampleBatch
14
+ from ray.rllib.utils.serialization import gym_space_from_dict, gym_space_to_dict
15
+ from ray.rllib.utils.deprecation import Deprecated
16
+ from ray.rllib.utils.typing import AgentID, ModuleID
17
+ from ray.util.annotations import PublicAPI
18
+
19
+
20
+ @PublicAPI(stability="alpha")
21
+ class SingleAgentEpisode:
22
+ """A class representing RL environment episodes for individual agents.
23
+
24
+ SingleAgentEpisode stores observations, info dicts, actions, rewards, and all
25
+ module outputs (e.g. state outs, action logp, etc..) for an individual agent within
26
+ some single-agent or multi-agent environment.
27
+ The two main APIs to add data to an ongoing episode are the `add_env_reset()`
28
+ and `add_env_step()` methods, which should be called passing the outputs of the
29
+ respective gym.Env API calls: `env.reset()` and `env.step()`.
30
+
31
+ A SingleAgentEpisode might also only represent a chunk of an episode, which is
32
+ useful for cases, in which partial (non-complete episode) sampling is performed
33
+ and collected episode data has to be returned before the actual gym.Env episode has
34
+ finished (see `SingleAgentEpisode.cut()`). In order to still maintain visibility
35
+ onto past experiences within such a "cut" episode, SingleAgentEpisode instances
36
+ can have a "lookback buffer" of n timesteps at their beginning (left side), which
37
+ solely exists for the purpose of compiling extra data (e.g. "prev. reward"), but
38
+ is not considered part of the finished/packaged episode (b/c the data in the
39
+ lookback buffer is already part of a previous episode chunk).
40
+
41
+ Powerful getter methods, such as `get_observations()` help collect different types
42
+ of data from the episode at individual time indices or time ranges, including the
43
+ "lookback buffer" range described above. For example, to extract the last 4 rewards
44
+ of an ongoing episode, one can call `self.get_rewards(slice(-4, None))` or
45
+ `self.rewards[-4:]`. This would work, even if the ongoing SingleAgentEpisode is
46
+ a continuation chunk from a much earlier started episode, as long as it has a
47
+ lookback buffer size of sufficient size.
48
+
49
+ Examples:
50
+
51
+ .. testcode::
52
+
53
+ import gymnasium as gym
54
+ import numpy as np
55
+
56
+ from ray.rllib.env.single_agent_episode import SingleAgentEpisode
57
+
58
+ # Construct a new episode (without any data in it yet).
59
+ episode = SingleAgentEpisode()
60
+ assert len(episode) == 0
61
+
62
+ # Fill the episode with some data (10 timesteps).
63
+ env = gym.make("CartPole-v1")
64
+ obs, infos = env.reset()
65
+ episode.add_env_reset(obs, infos)
66
+
67
+ # Even with the initial obs/infos, the episode is still considered len=0.
68
+ assert len(episode) == 0
69
+ for _ in range(5):
70
+ action = env.action_space.sample()
71
+ obs, reward, term, trunc, infos = env.step(action)
72
+ episode.add_env_step(
73
+ observation=obs,
74
+ action=action,
75
+ reward=reward,
76
+ terminated=term,
77
+ truncated=trunc,
78
+ infos=infos,
79
+ )
80
+ assert len(episode) == 5
81
+
82
+ # We can now access information from the episode via the getter APIs.
83
+
84
+ # Get the last 3 rewards (in a batch of size 3).
85
+ episode.get_rewards(slice(-3, None)) # same as `episode.rewards[-3:]`
86
+
87
+ # Get the most recent action (single item, not batched).
88
+ # This works regardless of the action space or whether the episode has
89
+ # been numpy'ized or not (see below).
90
+ episode.get_actions(-1) # same as episode.actions[-1]
91
+
92
+ # Looking back from ts=1, get the previous 4 rewards AND fill with 0.0
93
+ # in case we go over the beginning (ts=0). So we would expect
94
+ # [0.0, 0.0, 0.0, r0] to be returned here, where r0 is the very first received
95
+ # reward in the episode:
96
+ episode.get_rewards(slice(-4, 0), neg_index_as_lookback=True, fill=0.0)
97
+
98
+ # Note the use of fill=0.0 here (fill everything that's out of range with this
99
+ # value) AND the argument `neg_index_as_lookback=True`, which interprets
100
+ # negative indices as being left of ts=0 (e.g. -1 being the timestep before
101
+ # ts=0).
102
+
103
+ # Assuming we had a complex action space (nested gym.spaces.Dict) with one or
104
+ # more elements being Discrete or MultiDiscrete spaces:
105
+ # 1) The `fill=...` argument would still work, filling all spaces (Boxes,
106
+ # Discrete) with that provided value.
107
+ # 2) Setting the flag `one_hot_discrete=True` would convert those discrete
108
+ # sub-components automatically into one-hot (or multi-one-hot) tensors.
109
+ # This simplifies the task of having to provide the previous 4 (nested and
110
+ # partially discrete/multi-discrete) actions for each timestep within a training
111
+ # batch, thereby filling timesteps before the episode started with 0.0s and
112
+ # one-hot'ing the discrete/multi-discrete components in these actions:
113
+ episode = SingleAgentEpisode(action_space=gym.spaces.Dict({
114
+ "a": gym.spaces.Discrete(3),
115
+ "b": gym.spaces.MultiDiscrete([2, 3]),
116
+ "c": gym.spaces.Box(-1.0, 1.0, (2,)),
117
+ }))
118
+
119
+ # ... fill episode with data ...
120
+ episode.add_env_reset(observation=0)
121
+ # ... from a few steps.
122
+ episode.add_env_step(
123
+ observation=1,
124
+ action={"a":0, "b":np.array([1, 2]), "c":np.array([.5, -.5], np.float32)},
125
+ reward=1.0,
126
+ )
127
+
128
+ # In your connector
129
+ prev_4_a = []
130
+ # Note here that len(episode) does NOT include the lookback buffer.
131
+ for ts in range(len(episode)):
132
+ prev_4_a.append(
133
+ episode.get_actions(
134
+ indices=slice(ts - 4, ts),
135
+ # Make sure negative indices are interpreted as
136
+ # "into lookback buffer"
137
+ neg_index_as_lookback=True,
138
+ # Zero-out everything even further before the lookback buffer.
139
+ fill=0.0,
140
+ # Take care of discrete components (get ready as NN input).
141
+ one_hot_discrete=True,
142
+ )
143
+ )
144
+
145
+ # Finally, convert from list of batch items to a struct (same as action space)
146
+ # of batched (numpy) arrays, in which all leafs have B==len(prev_4_a).
147
+ from ray.rllib.utils.spaces.space_utils import batch
148
+
149
+ prev_4_actions_col = batch(prev_4_a)
150
+ """
151
+
152
+ __slots__ = (
153
+ "actions",
154
+ "agent_id",
155
+ "extra_model_outputs",
156
+ "id_",
157
+ "infos",
158
+ "is_terminated",
159
+ "is_truncated",
160
+ "module_id",
161
+ "multi_agent_episode_id",
162
+ "observations",
163
+ "rewards",
164
+ "t",
165
+ "t_started",
166
+ "_action_space",
167
+ "_last_added_observation",
168
+ "_last_added_infos",
169
+ "_last_step_time",
170
+ "_observation_space",
171
+ "_start_time",
172
+ "_temporary_timestep_data",
173
+ )
174
+
175
    def __init__(
        self,
        id_: Optional[str] = None,
        *,
        observations: Optional[Union[List[ObsType], InfiniteLookbackBuffer]] = None,
        observation_space: Optional[gym.Space] = None,
        infos: Optional[Union[List[Dict], InfiniteLookbackBuffer]] = None,
        actions: Optional[Union[List[ActType], InfiniteLookbackBuffer]] = None,
        action_space: Optional[gym.Space] = None,
        rewards: Optional[Union[List[SupportsFloat], InfiniteLookbackBuffer]] = None,
        terminated: bool = False,
        truncated: bool = False,
        extra_model_outputs: Optional[Dict[str, Any]] = None,
        t_started: Optional[int] = None,
        len_lookback_buffer: Union[int, str] = "auto",
        agent_id: Optional[AgentID] = None,
        module_id: Optional[ModuleID] = None,
        multi_agent_episode_id: Optional[int] = None,
    ):
        """Initializes a SingleAgentEpisode instance.

        This constructor can be called with or without already sampled data,
        part of which might then go into the lookback buffer.

        Args:
            id_: Unique identifier for this episode. If None, a unique
                hexadecimal code is generated.
            observations: Either a list of individual observations from a
                sampling or an already instantiated `InfiniteLookbackBuffer`
                (possibly with data in it). If a list, the buffer is constructed
                automatically (given the data and `len_lookback_buffer`).
            observation_space: Optional gym.Space all individual observations
                should abide to. If not None and this episode is numpy'ized
                (via `self.to_numpy()`), newly appended/set data is checked for
                correctness.
            infos: Either a list of individual info dicts from a sampling or an
                already instantiated `InfiniteLookbackBuffer`. If a list, the
                buffer is constructed automatically.
            actions: Either a list of individual actions from a sampling or an
                already instantiated `InfiniteLookbackBuffer`. If a list, the
                buffer is constructed automatically.
            action_space: Optional gym.Space all individual actions should
                abide to. Same checking semantics as `observation_space`.
            rewards: Either a list of individual rewards from a sampling or an
                already instantiated `InfiniteLookbackBuffer`. If a list, the
                buffer is constructed automatically.
            terminated: Whether the episode is already terminated.
            truncated: Whether the episode has been truncated.
            extra_model_outputs: Dict mapping keys to either lists of individual
                extra model output tensors (e.g. `action_logp`) or to already
                instantiated `InfiniteLookbackBuffer` objects. Lists are
                converted to buffers automatically.
            t_started: Optional starting (global) timestep of this episode
                chunk; defaults to 0.
            len_lookback_buffer: Size of the (optional) lookback buffers to keep
                in front of this episode for each data type. If > 0, the first
                `len_lookback_buffer` items per data type are interpreted as NOT
                part of this chunk but as "historical" record (e.g. needed for
                frame-stacking across a chunk cut). If "auto" (default), ALL
                provided data goes into the lookback buffers.
            agent_id: Optional AgentID this episode belongs to (reference only).
            module_id: Optional ModuleID this episode belongs to (reference
                only).
            multi_agent_episode_id: Optional EpisodeID of the encapsulating
                `MultiAgentEpisode` this episode belongs to.
        """
        self.id_ = id_ or uuid.uuid4().hex

        self.agent_id = agent_id
        self.module_id = module_id
        self.multi_agent_episode_id = multi_agent_episode_id

        # Lookback buffer length is not provided (or larger than the given
        # data). Interpret all already given data as lookback buffer.
        len_rewards = len(rewards) if rewards is not None else 0
        if len_lookback_buffer == "auto" or len_lookback_buffer > len_rewards:
            len_lookback_buffer = len_rewards

        # Default infos to one empty dict per observation.
        infos = infos or [{} for _ in range(len(observations or []))]

        # Observations: t0 (initial obs) to T.
        self._observation_space = None
        if isinstance(observations, InfiniteLookbackBuffer):
            self.observations = observations
        else:
            self.observations = InfiniteLookbackBuffer(
                data=observations,
                lookback=len_lookback_buffer,
            )
        self.observation_space = observation_space
        # Infos: t0 (initial info) to T.
        if isinstance(infos, InfiniteLookbackBuffer):
            self.infos = infos
        else:
            self.infos = InfiniteLookbackBuffer(
                data=infos,
                lookback=len_lookback_buffer,
            )
        # Actions: t1 to T.
        self._action_space = None
        if isinstance(actions, InfiniteLookbackBuffer):
            self.actions = actions
        else:
            self.actions = InfiniteLookbackBuffer(
                data=actions,
                lookback=len_lookback_buffer,
            )
        self.action_space = action_space
        # Rewards: t1 to T.
        if isinstance(rewards, InfiniteLookbackBuffer):
            self.rewards = rewards
        else:
            self.rewards = InfiniteLookbackBuffer(
                data=rewards,
                lookback=len_lookback_buffer,
                space=gym.spaces.Box(float("-inf"), float("inf"), (), np.float32),
            )

        # obs[-1] is the final observation in the episode.
        self.is_terminated = terminated
        # obs[-1] is the last obs in a truncated-by-the-env episode (there will no more
        # observations in following chunks for this episode).
        self.is_truncated = truncated

        # Extra model outputs, e.g. `action_dist_input` needed in the batch.
        self.extra_model_outputs = {}
        for k, v in (extra_model_outputs or {}).items():
            if isinstance(v, InfiniteLookbackBuffer):
                self.extra_model_outputs[k] = v
            else:
                # We cannot use the defaultdict's own constructor here as this would
                # auto-set the lookback buffer to 0 (there is no data passed to that
                # constructor). Then, when we manually have to set the data property,
                # the lookback buffer would still be (incorrectly) 0.
                self.extra_model_outputs[k] = InfiniteLookbackBuffer(
                    data=v, lookback=len_lookback_buffer
                )

        # The (global) timestep when this episode (possibly an episode chunk) started,
        # excluding a possible lookback buffer.
        self.t_started = t_started or 0
        # The current (global) timestep in the episode (possibly an episode chunk).
        self.t = len(self.rewards) + self.t_started

        # Caches for temporary per-timestep data. May be used to store custom metrics
        # from within a callback for the ongoing episode (e.g. render images).
        self._temporary_timestep_data = defaultdict(list)

        # Keep timer stats on deltas between steps.
        self._start_time = None
        self._last_step_time = None

        self._last_added_observation = None
        self._last_added_infos = None

        # Validate the episode data thus far.
        self.validate()
352
+
353
    def add_env_reset(
        self,
        observation: ObsType,
        infos: Optional[Dict] = None,
    ) -> None:
        """Adds the initial data (after an `env.reset()`) to the episode.

        This data consists of initial observations and initial infos. May only
        be called once, on a completely empty (never-stepped, not-done)
        episode; also starts this episode's wall-clock timer.

        Args:
            observation: The initial observation returned by `env.reset()`.
            infos: An (optional) info dict returned by `env.reset()`.
        """
        # Precondition: episode must be completely empty and not done.
        assert not self.is_reset
        assert not self.is_done
        assert len(self.observations) == 0
        # Assume that this episode is completely empty and has not stepped yet.
        # Leave self.t (and self.t_started) at 0.
        assert self.t == self.t_started == 0

        infos = infos or {}

        # Optionally check the reset obs against the configured space.
        if self.observation_space is not None:
            assert self.observation_space.contains(observation), (
                f"`observation` {observation} does NOT fit SingleAgentEpisode's "
                f"observation_space: {self.observation_space}!"
            )

        self.observations.append(observation)
        self.infos.append(infos)

        self._last_added_observation = observation
        self._last_added_infos = infos

        # Validate our data.
        self.validate()

        # Start the timer for this episode.
        self._start_time = time.perf_counter()
392
+
393
    def add_env_step(
        self,
        observation: ObsType,
        action: ActType,
        reward: SupportsFloat,
        infos: Optional[Dict[str, Any]] = None,
        *,
        terminated: bool = False,
        truncated: bool = False,
        extra_model_outputs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Adds results of an `env.step()` call (including the action) to this episode.

        This data consists of an observation and info dict, an action, a reward,
        terminated/truncated flags, and extra model outputs (e.g. action probabilities
        or RNN internal state outputs).

        Args:
            observation: The next observation received from the environment after(!)
                taking `action`.
            action: The last action used by the agent during the call to `env.step()`.
            reward: The last reward received by the agent after taking `action`.
            infos: The last info received from the environment after taking `action`.
            terminated: A boolean indicating, if the environment has been
                terminated (after taking `action`).
            truncated: A boolean indicating, if the environment has been
                truncated (after taking `action`).
            extra_model_outputs: The last timestep's specific model outputs.
                These are normally outputs of an RLModule that were computed along with
                `action`, e.g. `action_logp` or `action_dist_inputs`.
        """
        # Cannot add data to an already done episode.
        assert (
            not self.is_done
        ), "The agent is already done: no data can be added to its episode."

        # Append one step's worth of data; observations/infos get one more entry
        # than actions/rewards (MDP reset/last-obs logic, see `validate()`).
        self.observations.append(observation)
        self.actions.append(action)
        self.rewards.append(reward)
        infos = infos or {}
        self.infos.append(infos)
        self.t += 1
        if extra_model_outputs is not None:
            for k, v in extra_model_outputs.items():
                # Lazily create a buffer for keys seen for the first time.
                if k not in self.extra_model_outputs:
                    self.extra_model_outputs[k] = InfiniteLookbackBuffer([v])
                else:
                    self.extra_model_outputs[k].append(v)
        self.is_terminated = terminated
        self.is_truncated = truncated

        self._last_added_observation = observation
        self._last_added_infos = infos

        # Only check spaces if numpy'ized AND every n timesteps.
        # NOTE(review): `self.t % 100` is truthy on every step EXCEPT multiples
        # of 100, i.e. this validates on ~99 of every 100 steps and SKIPS
        # t=100, 200, ... If the intent was "only every 100th step", the
        # condition should read `self.t % 100 == 0` — confirm intended
        # sampling frequency before changing.
        if self.is_numpy and self.t % 100:
            if self.observation_space is not None:
                assert self.observation_space.contains(observation), (
                    f"`observation` {observation} does NOT fit SingleAgentEpisode's "
                    f"observation_space: {self.observation_space}!"
                )
            if self.action_space is not None:
                assert self.action_space.contains(action), (
                    f"`action` {action} does NOT fit SingleAgentEpisode's "
                    f"action_space: {self.action_space}!"
                )

        # Validate our data.
        self.validate()

        # Step time stats; if no reset timer was started, fall back to now.
        self._last_step_time = time.perf_counter()
        if self._start_time is None:
            self._start_time = self._last_step_time
467
+
468
    def validate(self) -> None:
        """Validates the episode's data.

        This function ensures that the data stored to a `SingleAgentEpisode` is
        in order (e.g. that the correct number of observations, actions, rewards
        are there).

        Raises:
            AssertionError: If the per-type data lengths are inconsistent.
        """
        # Observations and infos always come in pairs.
        assert len(self.observations) == len(self.infos)
        if len(self.observations) == 0:
            # Completely empty episode: no actions/rewards/extra outputs either.
            assert len(self.infos) == len(self.rewards) == len(self.actions) == 0
            for k, v in self.extra_model_outputs.items():
                assert len(v) == 0, (k, v, v.data, len(v))
        # Make sure we always have one more obs stored than rewards (and actions)
        # due to the reset/last-obs logic of an MDP.
        else:
            assert (
                len(self.observations)
                == len(self.infos)
                == len(self.rewards) + 1
                == len(self.actions) + 1
            ), (
                len(self.observations),
                len(self.infos),
                len(self.rewards),
                len(self.actions),
            )
            # Extra model outputs are per-action, hence one fewer than obs.
            for k, v in self.extra_model_outputs.items():
                assert len(v) == len(self.observations) - 1
496
+
497
    @property
    def is_reset(self) -> bool:
        """Returns True if `self.add_env_reset()` has already been called."""
        # The reset call stores the initial observation, so a non-empty
        # observations buffer implies a completed reset.
        return len(self.observations) > 0
501
+
502
    @property
    def is_numpy(self) -> bool:
        """True, if the data in this episode is already stored as numpy arrays."""
        # If rewards are still a list, return False.
        # Otherwise, rewards should already be a (1D) numpy array.
        # Delegates to the rewards buffer's `finalized` flag (set by
        # `to_numpy()`).
        return self.rewards.finalized
508
+
509
    @property
    def is_done(self) -> bool:
        """Whether the episode is actually done (terminated or truncated).

        A done episode cannot be continued via `self.add_timestep()` or being
        concatenated on its right-side with another episode chunk or being
        succeeded via `self.create_successor()`.
        """
        # Note: returns the truthy operand itself (not coerced via `bool()`).
        return self.is_terminated or self.is_truncated
518
+
519
def to_numpy(self) -> "SingleAgentEpisode":
    """Converts this episode's list-based data to numpy arrays (in place).

    Lists of (possibly complex) data (e.g. with a dict obs space) are
    converted to (possibly complex) structs whose leafs are numpy arrays,
    each with the same length (batch dimension) as the original lists.

    Note that data under `Columns.INFOS` is NEVER numpy'ized and remains a
    list (normally of the original, env-returned dicts), due to the
    heterogeneous nature of env-returned infos.

    After calling this method, no further data may be added to this episode
    via the `self.add_env_step()` method.

    Returns:
        This `SingleAgentEpisode` instance (for chaining) with its data
        converted to numpy arrays.
    """
    self.observations.finalize()
    # Only step data (actions/rewards/extra outputs) exists if at least one
    # step has been taken.
    if len(self) > 0:
        self.actions.finalize()
        self.rewards.finalize()
        # Iterate over the buffer values directly; the previous
        # `for k, v in ...items()` form used neither `v` nor needed `k`.
        for buf in self.extra_model_outputs.values():
            buf.finalize()

    return self
592
+
593
def concat_episode(self, other: "SingleAgentEpisode") -> None:
    """Adds the given `other` SingleAgentEpisode to the right side of self.

    For this to work, both chunks must fit together: identical IDs,
    `self.t` equal to `other.t_started`, and matching observations at the
    concatenation boundary. Also, `self.is_done` must not be True (neither
    terminated nor truncated).

    Args:
        other: The other `SingleAgentEpisode` to be concatenated onto this
            one (in place).

    Raises:
        AssertionError: If the two chunks do not fit together (IDs,
            timesteps, boundary observations) or `self` is already done.
    """
    assert other.id_ == self.id_
    # NOTE (sven): Concatenating onto a finished episode is not allowed;
    # replay buffers rely on this invariant.
    assert not self.is_done
    # Make sure the timesteps match.
    assert self.t == other.t_started
    # Validate `other` before merging any of its data.
    other.validate()

    # `other`'s first observation must equal our last one ...
    assert np.all(other.observations[0] == self.observations[-1])
    # ... so drop our duplicate boundary observation and info.
    self.observations.pop()
    self.infos.pop()

    # Extend ourselves. In case `other` is already terminated and
    # numpy'ized, its getters return the data in a concatenable form.
    self.observations.extend(other.get_observations())
    self.actions.extend(other.get_actions())
    self.rewards.extend(other.get_rewards())
    self.infos.extend(other.get_infos())
    self.t = other.t

    if other.is_terminated:
        self.is_terminated = True
    elif other.is_truncated:
        self.is_truncated = True

    # Iterate over the dict directly (redundant `.keys()` removed).
    for key in other.extra_model_outputs:
        assert key in self.extra_model_outputs
        self.extra_model_outputs[key].extend(other.get_extra_model_outputs(key))

    # Validate the combined result.
    self.validate()
644
+
645
def cut(self, len_lookback_buffer: int = 0) -> "SingleAgentEpisode":
    """Returns a new, len-0 successor chunk continuing from this episode.

    The successor keeps the same ID and starts at `self.t`. With
    `len_lookback_buffer=0`, it contains only the most recent
    observation(s)/info(s) of `self` (comparable to a fresh `env.reset()`)
    and has length 0. With `len_lookback_buffer` > 0, the last
    `len_lookback_buffer` actions, rewards, etc. of `self` are carried over
    as lookback data — visible to the getters, but not counted toward the
    new chunk's length. For example, with `len_lookback_buffer=2` the
    successor's lookback actions equal `self.actions[-2:]`, while
    `t_started == t == self.t` and `len == 0`. If `self` holds fewer steps
    than requested, the lookback length is automatically lowered.

    This is useful when an episode chunk must be returned/shipped somewhere
    while the actual gym.Env episode continues to be built later: via the
    lookback buffer the continuing chunk can still "look back" (to some
    extent) into this predecessor's data.

    Args:
        len_lookback_buffer: Number of trailing timesteps of `self` to make
            available as lookback data in the successor. Must be >= 0.

    Returns:
        The successor episode chunk (length 0) with the same ID and state,
        starting at `self.t`.
    """
    assert not self.is_done and len_lookback_buffer >= 0

    # Observations/infos carry one extra (reset-like) boundary item, hence
    # the additional -1 in their slice.
    obs_and_infos_slice = slice(-len_lookback_buffer - 1, None)
    if len_lookback_buffer > 0:
        rest_slice = slice(-len_lookback_buffer, None)
    else:
        # Empty slice -> no lookback actions/rewards/etc.
        rest_slice = slice(None, 0)

    # Temporary per-timestep caches do not carry over to the successor.
    self._temporary_timestep_data.clear()

    return SingleAgentEpisode(
        # Same ID as the predecessor.
        id_=self.id_,
        observations=self.get_observations(indices=obs_and_infos_slice),
        observation_space=self.observation_space,
        infos=self.get_infos(indices=obs_and_infos_slice),
        actions=self.get_actions(indices=rest_slice),
        action_space=self.action_space,
        rewards=self.get_rewards(indices=rest_slice),
        extra_model_outputs={
            key: self.get_extra_model_outputs(key, rest_slice)
            for key in self.extra_model_outputs
        },
        # Continue at self's current timestep.
        t_started=self.t,
        # Infer the lookback length from the provided data.
        len_lookback_buffer="auto",
    )
712
+
713
+ # TODO (sven): Distinguish between:
714
+ # - global index: This is the absolute, global timestep whose values always
715
+ # start from 0 (at the env reset). So doing get_observations(0, global_ts=True)
716
+ # should always return the exact 1st observation (reset obs), no matter what. In
717
+ # case we are in an episode chunk and `fill` or a sufficient lookback buffer is
718
+ # provided, this should yield a result. Otherwise, error.
719
+ # - global index=False -> indices are relative to the chunk start. If a chunk has
720
+ # t_started=6 and we ask for index=0, then return observation at timestep 6
721
+ # (t_started).
722
def get_observations(
    self,
    indices: Optional[Union[int, List[int], slice]] = None,
    *,
    neg_index_as_lookback: bool = False,
    fill: Optional[Any] = None,
    one_hot_discrete: bool = False,
) -> Any:
    """Returns individual observations or batched ranges thereof.

    Args:
        indices: A single int (one observation), a list of ints (a batch of
            size `len(indices)`), a slice (a range), or None (everything
            from ts=0 to the end). Negative entries count from the end
            unless `neg_index_as_lookback=True`.
        neg_index_as_lookback: If True, negative indices address positions
            before ts=0, i.e. reach back into the lookback buffer.
        fill: Optional value used to pad the result wherever the requested
            index range exceeds the episode's boundaries (including the
            lookback buffer on the left side), e.g. for zero-padding.
        one_hot_discrete: If True, Discrete/MultiDiscrete sub-components of
            the (possibly complex) observation space are returned one-hot
            encoded. Note that with `fill=0` and out-of-range `indices`,
            the returned vectors are all-zero ("zero-hot").

    Returns:
        The collected observations: batched along a 0-axis for list/slice
        `indices`, a single item (no extra 0-axis) for a single int index.
    """
    # All storage/lookback mechanics live in the underlying buffer; simply
    # delegate.
    return self.observations.get(
        indices=indices,
        neg_index_as_lookback=neg_index_as_lookback,
        fill=fill,
        one_hot_discrete=one_hot_discrete,
    )
820
+
821
def get_infos(
    self,
    indices: Optional[Union[int, List[int], slice]] = None,
    *,
    neg_index_as_lookback: bool = False,
    fill: Optional[Any] = None,
) -> Any:
    """Returns individual info dicts or lists (ranges) thereof.

    Args:
        indices: A single int (one info dict), a list of ints (a list of
            size `len(indices)`), a slice (a range), or None (everything
            from ts=0 to the end). Negative entries count from the end
            unless `neg_index_as_lookback=True`.
        neg_index_as_lookback: If True, negative indices address positions
            before ts=0, i.e. reach back into the lookback buffer.
        fill: Optional value used to pad the result wherever the requested
            index range exceeds the episode's boundaries (including the
            lookback buffer on the left side).

    Returns:
        The collected info dicts: a list (0-axis batch) for list/slice
        `indices`, a single item for a single int index.
    """
    # Infos share the same buffer machinery as observations; delegate.
    return self.infos.get(
        indices=indices,
        neg_index_as_lookback=neg_index_as_lookback,
        fill=fill,
    )
897
+
898
def get_actions(
    self,
    indices: Optional[Union[int, List[int], slice]] = None,
    *,
    neg_index_as_lookback: bool = False,
    fill: Optional[Any] = None,
    one_hot_discrete: bool = False,
) -> Any:
    """Returns individual actions or batched ranges thereof.

    Args:
        indices: A single int (one action), a list of ints (a batch of size
            `len(indices)`), a slice (a range), or None (everything from
            ts=0 to the end). Negative entries count from the end unless
            `neg_index_as_lookback=True`.
        neg_index_as_lookback: If True, negative indices address positions
            before ts=0, i.e. reach back into the lookback buffer.
        fill: Optional value used to pad the result wherever the requested
            index range exceeds the episode's boundaries (including the
            lookback buffer on the left side), e.g. for zero-padding.
        one_hot_discrete: If True, Discrete/MultiDiscrete sub-components of
            the (possibly complex) action space are returned one-hot
            encoded. Note that with `fill=0` and out-of-range `indices`,
            the returned vectors are all-zero ("zero-hot").

    Returns:
        The collected actions: batched along a 0-axis for list/slice
        `indices`, a single item (no extra 0-axis) for a single int index.
    """
    # Same delegation pattern as `get_observations()`.
    return self.actions.get(
        indices=indices,
        neg_index_as_lookback=neg_index_as_lookback,
        fill=fill,
        one_hot_discrete=one_hot_discrete,
    )
992
+
993
def get_rewards(
    self,
    indices: Optional[Union[int, List[int], slice]] = None,
    *,
    neg_index_as_lookback: bool = False,
    fill: Optional[float] = None,
) -> Any:
    """Returns individual rewards or batched ranges thereof.

    Args:
        indices: A single int (one reward), a list of ints (a batch of size
            `len(indices)`), a slice (a range), or None (everything from
            ts=0 to the end). Negative entries count from the end unless
            `neg_index_as_lookback=True`.
        neg_index_as_lookback: If True, negative indices address positions
            before ts=0, i.e. reach back into the lookback buffer.
        fill: Optional float used to pad the result wherever the requested
            index range exceeds the episode's boundaries (including the
            lookback buffer on the left side), e.g. for zero-padding.

    Returns:
        The collected rewards: batched along a 0-axis for list/slice
        `indices`, a single item (no extra 0-axis) for a single int index.
    """
    # Rewards need no one-hot handling; straight delegation to the buffer.
    return self.rewards.get(
        indices=indices,
        neg_index_as_lookback=neg_index_as_lookback,
        fill=fill,
    )
1063
+
1064
def get_extra_model_outputs(
    self,
    key: str,
    indices: Optional[Union[int, List[int], slice]] = None,
    *,
    neg_index_as_lookback: bool = False,
    fill: Optional[Any] = None,
) -> Any:
    """Returns extra model outputs (stored under `key`) from this episode.

    Args:
        key: The key within `self.extra_model_outputs` to extract data for.
        indices: A single int (one item), a list of ints (a batch of size
            `len(indices)`), a slice (a range), or None (everything from
            ts=0 to the end). Negative entries count from the end unless
            `neg_index_as_lookback=True`.
        neg_index_as_lookback: If True, negative indices address positions
            before ts=0, i.e. reach back into the lookback buffer.
        fill: Optional value used to pad the result wherever the requested
            index range exceeds the episode's boundaries (including the
            lookback buffer on the left side).

    Returns:
        The collected `extra_model_outputs[key]` data: batched along a
        0-axis for list/slice `indices`, a single item for an int index.
    """
    buffer = self.extra_model_outputs[key]
    # Expected case: the stored value is an `InfiniteLookbackBuffer`.
    if isinstance(buffer, InfiniteLookbackBuffer):
        return buffer.get(
            indices=indices,
            neg_index_as_lookback=neg_index_as_lookback,
            fill=fill,
        )
    # TODO (sven): Users should NOT be able to write raw data directly into
    #  our buffers. Instead, a future
    #  `self.set_extra_model_outputs(key, new_data, at_indices=...)` API
    #  should handle unknown keys; until then, a raw value here is an error.
    assert False
    # Unreachable legacy path: wrap raw (e.g. numpy) data added by custom
    # postprocessing/connector logic in a buffer and serve from there.
    return InfiniteLookbackBuffer(buffer).get(
        indices, fill=fill, neg_index_as_lookback=neg_index_as_lookback
    )
1160
+
1161
def set_observations(
    self,
    *,
    new_data,
    at_indices: Optional[Union[int, List[int], slice]] = None,
    neg_index_as_lookback: bool = False,
) -> None:
    """Overwrites all or some of this episode's observations with `new_data`.

    Observation data cannot be written to directly as it is managed by an
    `InfiniteLookbackBuffer`; individual observations are normally appended
    via `self.add_env_step()` (or `self.observations.append|extend()`).
    This setter exists for postprocessing steps that must rewrite the
    entirety (or a slice) of the data.

    Args:
        new_data: The new observation data. A list of individual
            observation(s) while the episode is not yet numpy'ized; once
            numpy'ized, a (possibly complex) struct matching the
            observation space whose leafs have exactly the batch size of
            the segment addressed by `at_indices`.
        at_indices: A single int, a list of ints, or a slice selecting the
            positions to overwrite; the size must match `new_data`.
            Negative entries count from the end unless
            `neg_index_as_lookback=True`.
        neg_index_as_lookback: If True, negative indices address positions
            before ts=0, i.e. reach back into the lookback buffer.

    Raises:
        IndexError: If the provided `at_indices` do not match the size of
            `new_data`.
    """
    # Delegate the actual (size-checked) write to the buffer.
    self.observations.set(
        new_data=new_data,
        at_indices=at_indices,
        neg_index_as_lookback=neg_index_as_lookback,
    )
1215
+
1216
+ def set_actions(
1217
+ self,
1218
+ *,
1219
+ new_data,
1220
+ at_indices: Optional[Union[int, List[int], slice]] = None,
1221
+ neg_index_as_lookback: bool = False,
1222
+ ) -> None:
1223
+ """Overwrites all or some of this Episode's actions with the provided data.
1224
+
1225
+ Note that an episode's action data cannot be written to directly as it is
1226
+ managed by a `InfiniteLookbackBuffer` object. Normally, individual, current
1227
+ actions are added to the episode either by calling `self.add_env_step` or
1228
+ more directly (and manually) via `self.actions.append|extend()`.
1229
+ However, for certain postprocessing steps, the entirety (or a slice) of an
1230
+ episode's actions might have to be rewritten, which is when
1231
+ `self.set_actions()` should be used.
1232
+
1233
+ Args:
1234
+ new_data: The new action data to overwrite existing data with.
1235
+ This may be a list of individual action(s) in case this episode
1236
+ is still not numpy'ized yet. In case this episode has already been
1237
+ numpy'ized, this should be (possibly complex) struct matching the
1238
+ action space and with a batch size of its leafs exactly the size
1239
+ of the to-be-overwritten slice or segment (provided by `at_indices`).
1240
+ at_indices: A single int is interpreted as one index, which to overwrite
1241
+ with `new_data` (which is expected to be a single action).
1242
+ A list of ints is interpreted as a list of indices, all of which to
1243
+ overwrite with `new_data` (which is expected to be of the same size
1244
+ as `len(at_indices)`).
1245
+ A slice object is interpreted as a range of indices to be overwritten
1246
+ with `new_data` (which is expected to be of the same size as the
1247
+ provided slice).
1248
+ Thereby, negative indices by default are interpreted as "before the end"
1249
+ unless the `neg_index_as_lookback=True` option is used, in which case
1250
+ negative indices are interpreted as "before ts=0", meaning going back
1251
+ into the lookback buffer.
1252
+ neg_index_as_lookback: If True, negative values in `at_indices` are
1253
+ interpreted as "before ts=0", meaning going back into the lookback
1254
+ buffer. For example, an episode with
1255
+ actions = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the
1256
+ lookback buffer range (ts=0 item is 7), will handle a call to
1257
+ `set_actions(individual_action, -1,
1258
+ neg_index_as_lookback=True)` by overwriting the value of 6 in our
1259
+ actions buffer with the provided "individual_action".
1260
+
1261
+ Raises:
1262
+ IndexError: If the provided `at_indices` do not match the size of
1263
+ `new_data`.
1264
+ """
1265
+ self.actions.set(
1266
+ new_data=new_data,
1267
+ at_indices=at_indices,
1268
+ neg_index_as_lookback=neg_index_as_lookback,
1269
+ )
1270
+
1271
+ def set_rewards(
1272
+ self,
1273
+ *,
1274
+ new_data,
1275
+ at_indices: Optional[Union[int, List[int], slice]] = None,
1276
+ neg_index_as_lookback: bool = False,
1277
+ ) -> None:
1278
+ """Overwrites all or some of this Episode's rewards with the provided data.
1279
+
1280
+ Note that an episode's reward data cannot be written to directly as it is
1281
+ managed by a `InfiniteLookbackBuffer` object. Normally, individual, current
1282
+ rewards are added to the episode either by calling `self.add_env_step` or
1283
+ more directly (and manually) via `self.rewards.append|extend()`.
1284
+ However, for certain postprocessing steps, the entirety (or a slice) of an
1285
+ episode's rewards might have to be rewritten, which is when
1286
+ `self.set_rewards()` should be used.
1287
+
1288
+ Args:
1289
+ new_data: The new reward data to overwrite existing data with.
1290
+ This may be a list of individual reward(s) in case this episode
1291
+ is still not numpy'ized yet. In case this episode has already been
1292
+ numpy'ized, this should be a np.ndarray with a length exactly
1293
+ the size of the to-be-overwritten slice or segment (provided by
1294
+ `at_indices`).
1295
+ at_indices: A single int is interpreted as one index, which to overwrite
1296
+ with `new_data` (which is expected to be a single reward).
1297
+ A list of ints is interpreted as a list of indices, all of which to
1298
+ overwrite with `new_data` (which is expected to be of the same size
1299
+ as `len(at_indices)`).
1300
+ A slice object is interpreted as a range of indices to be overwritten
1301
+ with `new_data` (which is expected to be of the same size as the
1302
+ provided slice).
1303
+ Thereby, negative indices by default are interpreted as "before the end"
1304
+ unless the `neg_index_as_lookback=True` option is used, in which case
1305
+ negative indices are interpreted as "before ts=0", meaning going back
1306
+ into the lookback buffer.
1307
+ neg_index_as_lookback: If True, negative values in `at_indices` are
1308
+ interpreted as "before ts=0", meaning going back into the lookback
1309
+ buffer. For example, an episode with
1310
+ rewards = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the
1311
+ lookback buffer range (ts=0 item is 7), will handle a call to
1312
+ `set_rewards(individual_reward, -1,
1313
+ neg_index_as_lookback=True)` by overwriting the value of 6 in our
1314
+ rewards buffer with the provided "individual_reward".
1315
+
1316
+ Raises:
1317
+ IndexError: If the provided `at_indices` do not match the size of
1318
+ `new_data`.
1319
+ """
1320
+ self.rewards.set(
1321
+ new_data=new_data,
1322
+ at_indices=at_indices,
1323
+ neg_index_as_lookback=neg_index_as_lookback,
1324
+ )
1325
+
1326
+ def set_extra_model_outputs(
1327
+ self,
1328
+ *,
1329
+ key,
1330
+ new_data,
1331
+ at_indices: Optional[Union[int, List[int], slice]] = None,
1332
+ neg_index_as_lookback: bool = False,
1333
+ ) -> None:
1334
+ """Overwrites all or some of this Episode's extra model outputs with `new_data`.
1335
+
1336
+ Note that an episode's `extra_model_outputs` data cannot be written to directly
1337
+ as it is managed by a `InfiniteLookbackBuffer` object. Normally, individual,
1338
+ current `extra_model_output` values are added to the episode either by calling
1339
+ `self.add_env_step` or more directly (and manually) via
1340
+ `self.extra_model_outputs[key].append|extend()`. However, for certain
1341
+ postprocessing steps, the entirety (or a slice) of an episode's
1342
+ `extra_model_outputs` might have to be rewritten or a new key (a new type of
1343
+ `extra_model_outputs`) must be inserted, which is when
1344
+ `self.set_extra_model_outputs()` should be used.
1345
+
1346
+ Args:
1347
+ key: The `key` within `self.extra_model_outputs` to override data on or
1348
+ to insert as a new key into `self.extra_model_outputs`.
1349
+ new_data: The new data to overwrite existing data with.
1350
+ This may be a list of individual reward(s) in case this episode
1351
+ is still not numpy'ized yet. In case this episode has already been
1352
+ numpy'ized, this should be a np.ndarray with a length exactly
1353
+ the size of the to-be-overwritten slice or segment (provided by
1354
+ `at_indices`).
1355
+ at_indices: A single int is interpreted as one index, which to overwrite
1356
+ with `new_data` (which is expected to be a single reward).
1357
+ A list of ints is interpreted as a list of indices, all of which to
1358
+ overwrite with `new_data` (which is expected to be of the same size
1359
+ as `len(at_indices)`).
1360
+ A slice object is interpreted as a range of indices to be overwritten
1361
+ with `new_data` (which is expected to be of the same size as the
1362
+ provided slice).
1363
+ Thereby, negative indices by default are interpreted as "before the end"
1364
+ unless the `neg_index_as_lookback=True` option is used, in which case
1365
+ negative indices are interpreted as "before ts=0", meaning going back
1366
+ into the lookback buffer.
1367
+ neg_index_as_lookback: If True, negative values in `at_indices` are
1368
+ interpreted as "before ts=0", meaning going back into the lookback
1369
+ buffer. For example, an episode with
1370
+ rewards = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the
1371
+ lookback buffer range (ts=0 item is 7), will handle a call to
1372
+ `set_rewards(individual_reward, -1,
1373
+ neg_index_as_lookback=True)` by overwriting the value of 6 in our
1374
+ rewards buffer with the provided "individual_reward".
1375
+
1376
+ Raises:
1377
+ IndexError: If the provided `at_indices` do not match the size of
1378
+ `new_data`.
1379
+ """
1380
+ # Record already exists -> Set existing record's data to new values.
1381
+ assert key in self.extra_model_outputs
1382
+ self.extra_model_outputs[key].set(
1383
+ new_data=new_data,
1384
+ at_indices=at_indices,
1385
+ neg_index_as_lookback=neg_index_as_lookback,
1386
+ )
1387
+
1388
+ def add_temporary_timestep_data(self, key: str, data: Any) -> None:
1389
+ """Temporarily adds (until `to_numpy()` called) per-timestep data to self.
1390
+
1391
+ The given `data` is appended to a list (`self._temporary_timestep_data`), which
1392
+ is cleared upon calling `self.to_numpy()`. To get the thus-far accumulated
1393
+ temporary timestep data for a certain key, use the `get_temporary_timestep_data`
1394
+ API.
1395
+ Note that the size of the per timestep list is NOT checked or validated against
1396
+ the other, non-temporary data in this episode (like observations).
1397
+
1398
+ Args:
1399
+ key: The key under which to find the list to append `data` to. If `data` is
1400
+ the first data to be added for this key, start a new list.
1401
+ data: The data item (representing a single timestep) to be stored.
1402
+ """
1403
+ if self.is_numpy:
1404
+ raise ValueError(
1405
+ "Cannot use the `add_temporary_timestep_data` API on an already "
1406
+ f"numpy'ized {type(self).__name__}!"
1407
+ )
1408
+ self._temporary_timestep_data[key].append(data)
1409
+
1410
    def get_temporary_timestep_data(self, key: str) -> List[Any]:
        """Returns all temporarily stored data items (list) under the given key.

        Note that all temporary timestep data is erased/cleared when calling
        `self.to_numpy()`.

        Args:
            key: The key under which the data was stored via
                `add_temporary_timestep_data`.

        Returns:
            The current list storing temporary timestep data under `key`.

        Raises:
            ValueError: If this episode has already been numpy'ized (all
                temporary data was erased by `to_numpy()`).
            KeyError: If no temporary data is stored under `key`.
        """
        if self.is_numpy:
            raise ValueError(
                "Cannot use the `get_temporary_timestep_data` API on an already "
                f"numpy'ized {type(self).__name__}! All temporary data has been erased "
                f"upon `{type(self).__name__}.to_numpy()`."
            )
        try:
            return self._temporary_timestep_data[key]
        except KeyError:
            # Re-raise with a clearer message.
            # NOTE(review): `_temporary_timestep_data` appears to be a
            # defaultdict(list) elsewhere in this class (see `from_state`), in
            # which case a missing key auto-creates an empty list and this
            # branch never fires — confirm whether the KeyError path is
            # actually reachable.
            raise KeyError(f"Key {key} not found in temporary timestep data!")
1429
+
1430
    def slice(
        self,
        slice_: slice,
        *,
        len_lookback_buffer: Optional[int] = None,
    ) -> "SingleAgentEpisode":
        """Returns a slice of this episode with the given slice object.

        For example, if `self` contains o0 (the reset observation), o1, o2, o3, and o4
        and the actions a1, a2, a3, and a4 (len of `self` is 4), then a call to
        `self.slice(slice(1, 3))` would return a new SingleAgentEpisode with
        observations o1, o2, and o3, and actions a2 and a3. Note here that there is
        always one observation more in an episode than there are actions (and rewards
        and extra model outputs) due to the initial observation received after an env
        reset.

        .. testcode::

            from ray.rllib.env.single_agent_episode import SingleAgentEpisode
            from ray.rllib.utils.test_utils import check

            # Generate a simple single-agent episode.
            observations = [0, 1, 2, 3, 4, 5]
            actions = [1, 2, 3, 4, 5]
            rewards = [0.1, 0.2, 0.3, 0.4, 0.5]
            episode = SingleAgentEpisode(
                observations=observations,
                actions=actions,
                rewards=rewards,
                len_lookback_buffer=0,  # all given data is part of the episode
            )
            slice_1 = episode[:1]
            check(slice_1.observations, [0, 1])
            check(slice_1.actions, [1])
            check(slice_1.rewards, [0.1])

            slice_2 = episode[-2:]
            check(slice_2.observations, [3, 4, 5])
            check(slice_2.actions, [4, 5])
            check(slice_2.rewards, [0.4, 0.5])

        Args:
            slice_: The slice object to use for slicing. This should exclude the
                lookback buffer, which will be prepended automatically to the returned
                slice.
            len_lookback_buffer: If not None, forces the returned slice to try to have
                this number of timesteps in its lookback buffer (if available). If None
                (default), tries to make the returned slice's lookback as large as the
                current lookback buffer of this episode (`self`).

        Returns:
            The new SingleAgentEpisode representing the requested slice.
        """
        # Translate `slice_` into one that only contains 0-or-positive ints and will
        # NOT contain any None.
        start = slice_.start
        stop = slice_.stop

        # Start is None -> 0.
        if start is None:
            start = 0
        # Start is negative -> Interpret index as counting "from end".
        elif start < 0:
            start = len(self) + start

        # Stop is None -> Set stop to our len (one ts past last valid index).
        if stop is None:
            stop = len(self)
        # Stop is negative -> Interpret index as counting "from end".
        elif stop < 0:
            stop = len(self) + stop

        step = slice_.step if slice_.step is not None else 1

        # Figure out, whether slicing stops at the very end of this episode to know
        # whether `self.is_terminated/is_truncated` should be kept as-is.
        keep_done = stop == len(self)
        # Provide correct timestep- and pre-buffer information.
        t_started = self.t_started + start

        # For each buffer below: determine the requested lookback (`_lb`), then
        # clamp it so it never reaches further back than the data actually
        # available before `start` in this episode's buffer.
        _lb = (
            len_lookback_buffer
            if len_lookback_buffer is not None
            else self.observations.lookback
        )
        if (
            start >= 0
            and start - _lb < 0
            and self.observations.lookback < (_lb - start)
        ):
            _lb = self.observations.lookback + start
        # Observations (and infos below) use `stop + 1`: an episode slice always
        # carries one more observation/info than actions/rewards (the post-reset
        # observation).
        observations = InfiniteLookbackBuffer(
            data=self.get_observations(
                slice(start - _lb, stop + 1, step),
                neg_index_as_lookback=True,
            ),
            lookback=_lb,
            space=self.observation_space,
        )

        _lb = (
            len_lookback_buffer
            if len_lookback_buffer is not None
            else self.infos.lookback
        )
        if start >= 0 and start - _lb < 0 and self.infos.lookback < (_lb - start):
            _lb = self.infos.lookback + start
        infos = InfiniteLookbackBuffer(
            data=self.get_infos(
                slice(start - _lb, stop + 1, step),
                neg_index_as_lookback=True,
            ),
            lookback=_lb,
        )

        _lb = (
            len_lookback_buffer
            if len_lookback_buffer is not None
            else self.actions.lookback
        )
        if start >= 0 and start - _lb < 0 and self.actions.lookback < (_lb - start):
            _lb = self.actions.lookback + start
        actions = InfiniteLookbackBuffer(
            data=self.get_actions(
                slice(start - _lb, stop, step),
                neg_index_as_lookback=True,
            ),
            lookback=_lb,
            space=self.action_space,
        )

        _lb = (
            len_lookback_buffer
            if len_lookback_buffer is not None
            else self.rewards.lookback
        )
        if start >= 0 and start - _lb < 0 and self.rewards.lookback < (_lb - start):
            _lb = self.rewards.lookback + start
        rewards = InfiniteLookbackBuffer(
            data=self.get_rewards(
                slice(start - _lb, stop, step),
                neg_index_as_lookback=True,
            ),
            lookback=_lb,
        )

        # Each extra-model-outputs buffer gets the same lookback-clamping
        # treatment, keyed individually.
        extra_model_outputs = {}
        for k, v in self.extra_model_outputs.items():
            _lb = len_lookback_buffer if len_lookback_buffer is not None else v.lookback
            if start >= 0 and start - _lb < 0 and v.lookback < (_lb - start):
                _lb = v.lookback + start
            extra_model_outputs[k] = InfiniteLookbackBuffer(
                data=self.get_extra_model_outputs(
                    key=k,
                    indices=slice(start - _lb, stop, step),
                    neg_index_as_lookback=True,
                ),
                lookback=_lb,
            )

        return SingleAgentEpisode(
            id_=self.id_,
            # In the following, offset `start`s automatically by lookbacks.
            observations=observations,
            observation_space=self.observation_space,
            infos=infos,
            actions=actions,
            action_space=self.action_space,
            rewards=rewards,
            extra_model_outputs=extra_model_outputs,
            terminated=(self.is_terminated if keep_done else False),
            truncated=(self.is_truncated if keep_done else False),
            t_started=t_started,
        )
1604
+
1605
    def get_data_dict(self):
        """Converts a SingleAgentEpisode into a data dict mapping str keys to data.

        The keys used are:
        Columns.EPS_ID, T, OBS, INFOS, ACTIONS, REWARDS, TERMINATEDS, TRUNCATEDS,
        and those in `self.extra_model_outputs`.

        Returns:
            A data dict mapping str keys to data records.
        """
        # One entry per env step in this chunk (lookback buffer excluded).
        t = list(range(self.t_started, self.t))
        # Only the final step of the chunk may be terminated/truncated.
        terminateds = [False] * (len(self) - 1) + [self.is_terminated]
        truncateds = [False] * (len(self) - 1) + [self.is_truncated]
        eps_id = [self.id_] * len(self)

        # Match the numpy format of the buffer-held columns if this episode has
        # already been numpy'ized.
        if self.is_numpy:
            t = np.array(t)
            terminateds = np.array(terminateds)
            truncateds = np.array(truncateds)
            eps_id = np.array(eps_id)

        return dict(
            {
                # Trivial 1D data (compiled above).
                Columns.TERMINATEDS: terminateds,
                Columns.TRUNCATEDS: truncateds,
                Columns.T: t,
                Columns.EPS_ID: eps_id,
                # Retrieve obs, infos, actions, rewards using our get_... APIs,
                # which return all relevant timesteps (excluding the lookback
                # buffer!). Slice off last obs and infos to have the same number
                # of them as we have actions and rewards.
                Columns.OBS: self.get_observations(slice(None, -1)),
                Columns.INFOS: self.get_infos(slice(None, -1)),
                Columns.ACTIONS: self.get_actions(),
                Columns.REWARDS: self.get_rewards(),
            },
            # All `extra_model_outs`: Same as obs: Use get_... API.
            **{
                k: self.get_extra_model_outputs(k)
                for k in self.extra_model_outputs.keys()
            },
        )
1648
+
1649
+ def get_sample_batch(self) -> SampleBatch:
1650
+ """Converts this `SingleAgentEpisode` into a `SampleBatch`.
1651
+
1652
+ Returns:
1653
+ A SampleBatch containing all of this episode's data.
1654
+ """
1655
+ return SampleBatch(self.get_data_dict())
1656
+
1657
+ def get_return(self) -> float:
1658
+ """Calculates an episode's return, excluding the lookback buffer's rewards.
1659
+
1660
+ The return is computed by a simple sum, neglecting the discount factor.
1661
+ Note that if `self` is a continuation chunk (resulting from a call to
1662
+ `self.cut()`), the previous chunk's rewards are NOT counted and thus NOT
1663
+ part of the returned reward sum.
1664
+
1665
+ Returns:
1666
+ The sum of rewards collected during this episode, excluding possible data
1667
+ inside the lookback buffer and excluding possible data in a predecessor
1668
+ chunk.
1669
+ """
1670
+ return sum(self.get_rewards())
1671
+
1672
+ def get_duration_s(self) -> float:
1673
+ """Returns the duration of this Episode (chunk) in seconds."""
1674
+ if self._last_step_time is None:
1675
+ return 0.0
1676
+ return self._last_step_time - self._start_time
1677
+
1678
+ def env_steps(self) -> int:
1679
+ """Returns the number of environment steps.
1680
+
1681
+ Note, this episode instance could be a chunk of an actual episode.
1682
+
1683
+ Returns:
1684
+ An integer that counts the number of environment steps this episode instance
1685
+ has seen.
1686
+ """
1687
+ return len(self)
1688
+
1689
+ def agent_steps(self) -> int:
1690
+ """Returns the number of agent steps.
1691
+
1692
+ Note, these are identical to the environment steps for a single-agent episode.
1693
+
1694
+ Returns:
1695
+ An integer counting the number of agent steps executed during the time this
1696
+ episode instance records.
1697
+ """
1698
+ return self.env_steps()
1699
+
1700
+ def get_state(self) -> Dict[str, Any]:
1701
+ """Returns the pickable state of an episode.
1702
+
1703
+ The data in the episode is stored into a dictionary. Note that episodes
1704
+ can also be generated from states (see `SingleAgentEpisode.from_state()`).
1705
+
1706
+ Returns:
1707
+ A dict containing all the data from the episode.
1708
+ """
1709
+ infos = self.infos.get_state()
1710
+ infos["data"] = np.array([info if info else None for info in infos["data"]])
1711
+ return {
1712
+ "id_": self.id_,
1713
+ "agent_id": self.agent_id,
1714
+ "module_id": self.module_id,
1715
+ "multi_agent_episode_id": self.multi_agent_episode_id,
1716
+ # Note, all data is stored in `InfiniteLookbackBuffer`s.
1717
+ "observations": self.observations.get_state(),
1718
+ "actions": self.actions.get_state(),
1719
+ "rewards": self.rewards.get_state(),
1720
+ "infos": self.infos.get_state(),
1721
+ "extra_model_outputs": {
1722
+ k: v.get_state() if v else v
1723
+ for k, v in self.extra_model_outputs.items()
1724
+ }
1725
+ if len(self.extra_model_outputs) > 0
1726
+ else None,
1727
+ "is_terminated": self.is_terminated,
1728
+ "is_truncated": self.is_truncated,
1729
+ "t_started": self.t_started,
1730
+ "t": self.t,
1731
+ "_observation_space": gym_space_to_dict(self._observation_space)
1732
+ if self._observation_space
1733
+ else None,
1734
+ "_action_space": gym_space_to_dict(self._action_space)
1735
+ if self._action_space
1736
+ else None,
1737
+ "_start_time": self._start_time,
1738
+ "_last_step_time": self._last_step_time,
1739
+ "_temporary_timestep_data": dict(self._temporary_timestep_data)
1740
+ if len(self._temporary_timestep_data) > 0
1741
+ else None,
1742
+ }
1743
+
1744
    @staticmethod
    def from_state(state: Dict[str, Any]) -> "SingleAgentEpisode":
        """Creates a new `SingleAgentEpisode` instance from a state dict.

        Args:
            state: The state dict, as returned by `self.get_state()`.

        Returns:
            A new `SingleAgentEpisode` instance with the data from the state dict.
        """
        # Create an empty episode instance.
        episode = SingleAgentEpisode(id_=state["id_"])
        # Load all the data from the state dict into the episode.
        episode.agent_id = state["agent_id"]
        episode.module_id = state["module_id"]
        episode.multi_agent_episode_id = state["multi_agent_episode_id"]
        # Convert data back to `InfiniteLookbackBuffer`s.
        episode.observations = InfiniteLookbackBuffer.from_state(state["observations"])
        episode.actions = InfiniteLookbackBuffer.from_state(state["actions"])
        episode.rewards = InfiniteLookbackBuffer.from_state(state["rewards"])
        episode.infos = InfiniteLookbackBuffer.from_state(state["infos"])
        # Restore `extra_model_outputs` as a defaultdict whose factory creates
        # empty buffers with the same lookback as the observations buffer
        # (state may have stored None when no extra outputs existed).
        episode.extra_model_outputs = (
            defaultdict(
                functools.partial(
                    InfiniteLookbackBuffer, lookback=episode.observations.lookback
                ),
                {
                    k: InfiniteLookbackBuffer.from_state(v)
                    for k, v in state["extra_model_outputs"].items()
                },
            )
            if state["extra_model_outputs"]
            else defaultdict(
                functools.partial(
                    InfiniteLookbackBuffer, lookback=episode.observations.lookback
                ),
            )
        )
        episode.is_terminated = state["is_terminated"]
        episode.is_truncated = state["is_truncated"]
        episode.t_started = state["t_started"]
        episode.t = state["t"]
        # Spaces were serialized as plain dicts (see `get_state()`); convert
        # them back into gymnasium spaces here (or keep None if unset).
        episode._observation_space = (
            gym_space_from_dict(state["_observation_space"])
            if state["_observation_space"]
            else None
        )
        episode._action_space = (
            gym_space_from_dict(state["_action_space"])
            if state["_action_space"]
            else None
        )
        episode._start_time = state["_start_time"]
        episode._last_step_time = state["_last_step_time"]
        episode._temporary_timestep_data = defaultdict(
            list, state["_temporary_timestep_data"] or {}
        )
        # Validate the episode's consistency before handing it back.
        episode.validate()

        return episode
1806
+
1807
    @property
    def observation_space(self):
        # The (gymnasium) observation space of this episode; None if never set.
        return self._observation_space
1810
+
1811
    @observation_space.setter
    def observation_space(self, value):
        # Keep the observations buffer's space in sync with the episode's.
        self._observation_space = self.observations.space = value
1814
+
1815
    @property
    def action_space(self):
        # The (gymnasium) action space of this episode; None if never set.
        return self._action_space
1818
+
1819
    @action_space.setter
    def action_space(self, value):
        # Keep the actions buffer's space in sync with the episode's.
        self._action_space = self.actions.space = value
1822
+
1823
    def __len__(self) -> int:
        """Returns the length of an episode.

        The length of an episode is defined by the length of its data, excluding
        the lookback buffer data. The length is the number of timesteps an agent has
        stepped through an environment thus far.

        The length is 0 in case of an episode whose env has NOT been reset yet, but
        also 0 right after the `env.reset()` data has been added via
        `self.add_env_reset()`. Only after the first call to `env.step()` (and
        `self.add_env_step()`) will the length be 1.

        Returns:
            An integer, defining the length of an episode.
        """
        return self.t - self.t_started
1839
+
1840
+ def __repr__(self):
1841
+ return (
1842
+ f"SAEps(len={len(self)} done={self.is_done} "
1843
+ f"R={self.get_return()} id_={self.id_})"
1844
+ )
1845
+
1846
+ def __getitem__(self, item: slice) -> "SingleAgentEpisode":
1847
+ """Enable squared bracket indexing- and slicing syntax, e.g. episode[-4:]."""
1848
+ if isinstance(item, slice):
1849
+ return self.slice(slice_=item)
1850
+ else:
1851
+ raise NotImplementedError(
1852
+ f"SingleAgentEpisode does not support getting item '{item}'! "
1853
+ "Only slice objects allowed with the syntax: `episode[a:b]`."
1854
+ )
1855
+
1856
    @Deprecated(new="SingleAgentEpisode.is_numpy()", error=True)
    def is_finalized(self):
        # Deprecated API stub; with `error=True`, the decorator presumably
        # raises on call, pointing users to `is_numpy` — body is never reached.
        pass
1859
+
1860
    @Deprecated(new="SingleAgentEpisode.to_numpy()", error=True)
    def finalize(self):
        # Deprecated API stub; with `error=True`, the decorator presumably
        # raises on call, pointing users to `to_numpy` — body is never reached.
        pass
.venv/lib/python3.11/site-packages/ray/rllib/env/tcp_client_inference_env_runner.py ADDED
@@ -0,0 +1,589 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from collections import defaultdict
3
+ import gzip
4
+ import json
5
+ import pathlib
6
+ import socket
7
+ import tempfile
8
+ import threading
9
+ import time
10
+ from typing import Collection, DefaultDict, List, Optional, Union
11
+
12
+ import gymnasium as gym
13
+ import numpy as np
14
+ import onnxruntime
15
+
16
+ from ray.rllib.core import (
17
+ Columns,
18
+ COMPONENT_RL_MODULE,
19
+ DEFAULT_AGENT_ID,
20
+ DEFAULT_MODULE_ID,
21
+ )
22
+ from ray.rllib.env import INPUT_ENV_SPACES
23
+ from ray.rllib.env.env_runner import EnvRunner
24
+ from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner
25
+ from ray.rllib.env.single_agent_episode import SingleAgentEpisode
26
+ from ray.rllib.env.utils.external_env_protocol import RLlink as rllink
27
+ from ray.rllib.utils.annotations import ExperimentalAPI, override
28
+ from ray.rllib.utils.checkpoints import Checkpointable
29
+ from ray.rllib.utils.framework import try_import_torch
30
+ from ray.rllib.utils.metrics import (
31
+ EPISODE_DURATION_SEC_MEAN,
32
+ EPISODE_LEN_MAX,
33
+ EPISODE_LEN_MEAN,
34
+ EPISODE_LEN_MIN,
35
+ EPISODE_RETURN_MAX,
36
+ EPISODE_RETURN_MEAN,
37
+ EPISODE_RETURN_MIN,
38
+ WEIGHTS_SEQ_NO,
39
+ )
40
+ from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
41
+ from ray.rllib.utils.numpy import softmax
42
+ from ray.rllib.utils.typing import EpisodeID, StateDict
43
+
44
+ torch, _ = try_import_torch()
45
+
46
+
47
+ @ExperimentalAPI
48
+ class TcpClientInferenceEnvRunner(EnvRunner, Checkpointable):
49
+ """An EnvRunner communicating with an external env through a TCP socket.
50
+
51
+ This implementation assumes:
52
+ - Only one external client ever connects to this env runner.
53
+ - The external client performs inference locally through an ONNX model. Thus,
54
+ samples are sent in bulk once a certain number of timesteps has been executed on the
55
+ client's side (no individual action requests).
56
+ - A copy of the RLModule is kept at all times on the env runner, but never used
57
+ for inference, only as a data (weights) container.
58
+ TODO (sven): The above might be inefficient as we have to store basically two
59
+ models, one in this EnvRunner, one in the env (as ONNX).
60
+ - There is no environment and no connectors on this env runner. The external env
61
+ is responsible for generating all the data to create episodes.
62
+ """
63
+
64
    @override(EnvRunner)
    def __init__(self, *, config, **kwargs):
        """Initializes a TcpClientInferenceEnvRunner instance.

        Args:
            config: The AlgorithmConfig to use for setup.

        Keyword Args:
            worker_index: This EnvRunner's worker index (default 0); added to
                the base port to compute the actual port to bind to.

        NOTE(review): The base port is read from `config.env_config["port"]`
        (default 5555), not from a `port` kwarg as previously documented —
        confirm against callers.
        """
        super().__init__(config=config)

        self.worker_index: int = kwargs.get("worker_index", 0)

        # Tracks which version of the learner weights this runner holds.
        self._weights_seq_no = 0

        # Build the module from its spec.
        module_spec = self.config.get_rl_module_spec(
            spaces=self.get_spaces(), inference_only=True
        )
        self.module = module_spec.build()

        # Socket setup: actual port = base port (from env_config) + worker index,
        # so parallel workers bind to distinct ports.
        self.host = "localhost"
        self.port = int(self.config.env_config.get("port", 5555)) + self.worker_index
        self.server_socket = None
        self.client_socket = None
        self.address = None

        self.metrics = MetricsLogger()

        # Episode chunks deposited by the listener thread; consumed by `sample()`.
        self._episode_chunks_to_return: Optional[List[SingleAgentEpisode]] = None
        self._done_episodes_for_metrics: List[SingleAgentEpisode] = []
        self._ongoing_episodes_for_metrics: DefaultDict[
            EpisodeID, List[SingleAgentEpisode]
        ] = defaultdict(list)

        # `_sample_lock` guards `_episode_chunks_to_return` between the listener
        # thread and `sample()`.
        self._sample_lock = threading.Lock()
        self._on_policy_lock = threading.Lock()
        self._blocked_on_state = False

        # Start a background thread for client communication.
        self.thread = threading.Thread(
            target=self._client_message_listener, daemon=True
        )
        self.thread.start()
111
+
112
+ @override(EnvRunner)
113
+ def assert_healthy(self):
114
+ """Checks that the server socket is open and listening."""
115
+ assert (
116
+ self.server_socket is not None
117
+ ), "Server socket is None (not connected, not listening)."
118
+
119
    @override(EnvRunner)
    def sample(self, **kwargs):
        """Waits for the client to send episodes.

        Busy-polls (10ms sleep per iteration) until the listener thread has
        deposited episode chunks, sorts them into done/ongoing metric caches,
        logs sampled-steps counters, and returns the chunks.
        """
        while True:
            with self._sample_lock:
                if self._episode_chunks_to_return is not None:
                    num_env_steps = 0
                    num_episodes_completed = 0
                    for eps in self._episode_chunks_to_return:
                        if eps.is_done:
                            self._done_episodes_for_metrics.append(eps)
                            num_episodes_completed += 1
                        else:
                            self._ongoing_episodes_for_metrics[eps.id_].append(eps)
                        num_env_steps += len(eps)

                    ret = self._episode_chunks_to_return
                    self._episode_chunks_to_return = None

                    # Reuse SingleAgentEnvRunner's metrics helper via an
                    # unbound call (this class is not a SingleAgentEnvRunner).
                    SingleAgentEnvRunner._increase_sampled_metrics(
                        self, num_env_steps, num_episodes_completed
                    )

                    return ret
            time.sleep(0.01)
145
    @override(EnvRunner)
    def get_metrics(self):
        """Computes and returns this EnvRunner's (reduced) metrics.

        Folds any previously returned chunks of a now-finished episode into
        that episode's totals before logging, then clears the done-episodes
        cache and returns the reduced metrics dict.
        """
        # TODO (sven): We should probably make this a utility function to be called
        # from within Single/MultiAgentEnvRunner and other EnvRunner subclasses, as
        # needed.
        # Compute per-episode metrics (only on already completed episodes).
        for eps in self._done_episodes_for_metrics:
            assert eps.is_done
            episode_length = len(eps)
            episode_return = eps.get_return()
            episode_duration_s = eps.get_duration_s()
            # Don't forget about the already returned chunks of this episode.
            if eps.id_ in self._ongoing_episodes_for_metrics:
                for eps2 in self._ongoing_episodes_for_metrics[eps.id_]:
                    episode_length += len(eps2)
                    episode_return += eps2.get_return()
                    episode_duration_s += eps2.get_duration_s()
                del self._ongoing_episodes_for_metrics[eps.id_]

            self._log_episode_metrics(
                episode_length, episode_return, episode_duration_s
            )

        # Now that we have logged everything, clear cache of done episodes.
        self._done_episodes_for_metrics.clear()

        # Return reduced metrics.
        return self.metrics.reduce()
174
+ def get_spaces(self):
175
+ return {
176
+ INPUT_ENV_SPACES: (self.config.observation_space, self.config.action_space),
177
+ DEFAULT_MODULE_ID: (
178
+ self.config.observation_space,
179
+ self.config.action_space,
180
+ ),
181
+ }
182
+
183
+ @override(EnvRunner)
184
+ def stop(self):
185
+ """Closes the client and server sockets."""
186
+ self._close_sockets_if_necessary()
187
+
188
+ @override(Checkpointable)
189
+ def get_ctor_args_and_kwargs(self):
190
+ return (
191
+ (), # *args
192
+ {"config": self.config}, # **kwargs
193
+ )
194
+
195
+ @override(Checkpointable)
196
+ def get_checkpointable_components(self):
197
+ return [
198
+ (COMPONENT_RL_MODULE, self.module),
199
+ ]
200
+
201
+ @override(Checkpointable)
202
+ def get_state(
203
+ self,
204
+ components: Optional[Union[str, Collection[str]]] = None,
205
+ *,
206
+ not_components: Optional[Union[str, Collection[str]]] = None,
207
+ **kwargs,
208
+ ) -> StateDict:
209
+ return {}
210
+
211
    @override(Checkpointable)
    def set_state(self, state: StateDict) -> None:
        """Updates this runner's RLModule weights from `state`.

        If the listener thread is currently blocked waiting for fresh weights
        (on-policy handshake), the new state is pushed to the client and the
        listener is unblocked.
        """
        # Update the RLModule state.
        if COMPONENT_RL_MODULE in state:
            # A missing value for WEIGHTS_SEQ_NO or a value of 0 means: Force the
            # update.
            weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0)

            # Only update the weights, if this is the first synchronization or
            # if the weights of this `EnvRunner` lag behind the actual ones.
            if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no:
                rl_module_state = state[COMPONENT_RL_MODULE]
                # Unwrap a multi-module-style dict down to the single default
                # module's state.
                if (
                    isinstance(rl_module_state, dict)
                    and DEFAULT_MODULE_ID in rl_module_state
                ):
                    rl_module_state = rl_module_state[DEFAULT_MODULE_ID]
                self.module.set_state(rl_module_state)

            # Update our weights_seq_no, if the new one is > 0.
            if weights_seq_no > 0:
                self._weights_seq_no = weights_seq_no

        # Client is waiting for these weights (on-policy) -> send them out and
        # resume processing incoming messages.
        if self._blocked_on_state is True:
            self._send_set_state_message()
            self._blocked_on_state = False
238
    def _client_message_listener(self):
        """Entry point for the listener thread.

        Binds the server socket, accepts a (single) client connection, then
        loops forever processing incoming protocol messages. On connection
        errors, sockets are recycled and listening resumes.
        """
        # Set up the server socket and bind to the specified host and port.
        self._recycle_sockets()

        # Enter an endless message receival- and processing loop.
        while True:
            # As long as we are blocked on a new state, sleep a bit and continue.
            # Do NOT process any incoming messages (until we send out the new state
            # back to the client).
            if self._blocked_on_state is True:
                time.sleep(0.01)
                continue

            try:
                # Blocking call to get next message.
                msg_type, msg_body = _get_message(self.client_socket)

                # Process the message received based on its type.
                # Initial handshake.
                if msg_type == rllink.PING:
                    self._send_pong_message()

                # Episode data from the client.
                elif msg_type in [
                    rllink.EPISODES,
                    rllink.EPISODES_AND_GET_STATE,
                ]:
                    self._process_episodes_message(msg_type, msg_body)

                # Client requests the state (model weights).
                elif msg_type == rllink.GET_STATE:
                    self._send_set_state_message()

                # Clients requests some (relevant) config information.
                elif msg_type == rllink.GET_CONFIG:
                    self._send_set_config_message()

            except ConnectionError as e:
                # Connection broke -> wait 5s, then re-listen for a client.
                print(f"Messaging/connection error {e}! Recycling sockets ...")
                self._recycle_sockets(5.0)
                continue
282
    def _recycle_sockets(self, sleep: float = 0.0):
        """(Re)creates the server socket and blocks until a client connects.

        Args:
            sleep: Seconds to wait between closing the old sockets and
                re-binding the new server socket.
        """
        # Close all old sockets, if they exist.
        self._close_sockets_if_necessary()

        time.sleep(sleep)

        # Start listening on the configured port.
        self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Allow reuse of the address.
        self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.server_socket.bind((self.host, self.port))
        # Listen for a single connection.
        self.server_socket.listen(1)
        print(f"Waiting for client to connect to port {self.port}...")

        # Blocks until the (single) client connects.
        self.client_socket, self.address = self.server_socket.accept()
        print(f"Connected to client at {self.address}")
300
+ def _close_sockets_if_necessary(self):
301
+ if self.client_socket:
302
+ self.client_socket.close()
303
+ if self.server_socket:
304
+ self.server_socket.close()
305
+
306
+ def _send_pong_message(self):
307
+ _send_message(self.client_socket, {"type": rllink.PONG.name})
308
+
309
    def _process_episodes_message(self, msg_type, msg_body):
        """Converts raw episode data from the client into SingleAgentEpisodes.

        The resulting (numpy'ized) episodes are deposited for pickup by
        `sample()`. For on-policy messages, also flags this runner as blocked
        until new weights arrive via `set_state()`.
        """
        # On-policy training -> we have to block until we get a new `set_state` call
        # (b/c the learning step is done and we can send new weights back to all
        # clients).
        if msg_type == rllink.EPISODES_AND_GET_STATE:
            self._blocked_on_state = True

        episodes = []
        for episode_data in msg_body["episodes"]:
            episode = SingleAgentEpisode(
                observation_space=self.config.observation_space,
                observations=[np.array(o) for o in episode_data[Columns.OBS]],
                action_space=self.config.action_space,
                actions=episode_data[Columns.ACTIONS],
                rewards=episode_data[Columns.REWARDS],
                extra_model_outputs={
                    Columns.ACTION_DIST_INPUTS: [
                        np.array(a) for a in episode_data[Columns.ACTION_DIST_INPUTS]
                    ],
                    Columns.ACTION_LOGP: episode_data[Columns.ACTION_LOGP],
                },
                terminated=episode_data["is_terminated"],
                truncated=episode_data["is_truncated"],
                len_lookback_buffer=0,
            )
            episodes.append(episode.to_numpy())

        # Push episodes into the to-be-returned list (for `sample()` requests).
        with self._sample_lock:
            if isinstance(self._episode_chunks_to_return, list):
                self._episode_chunks_to_return.extend(episodes)
            else:
                self._episode_chunks_to_return = episodes
343
    def _send_set_state_message(self):
        """Exports the RLModule to ONNX and sends it to the client.

        The ONNX file is gzip-compressed and base64-encoded so it can travel
        inside the JSON message body, together with the current
        weights-seq-no.
        """
        with tempfile.TemporaryDirectory() as dir:
            onnx_file = pathlib.Path(dir) / "_temp_model.onnx"
            # Export with a dummy (batch size 1) observation input.
            # NOTE(review): assumes a flat (Box) observation space whose
            # `shape` is usable by `torch.randn` -- confirm for other spaces.
            torch.onnx.export(
                self.module,
                {
                    "batch": {
                        "obs": torch.randn(1, *self.config.observation_space.shape)
                    }
                },
                onnx_file,
                export_params=True,
            )
            with open(onnx_file, "rb") as f:
                compressed = gzip.compress(f.read())
                onnx_binary = base64.b64encode(compressed).decode("utf-8")
            _send_message(
                self.client_socket,
                {
                    "type": rllink.SET_STATE.name,
                    "onnx_file": onnx_binary,
                    WEIGHTS_SEQ_NO: self._weights_seq_no,
                },
            )
368
    def _send_set_config_message(self):
        """Sends the client the config values it needs for sampling.

        Currently: how many env steps to collect per sample request, and that
        strictly on-policy sampling is enforced.
        """
        _send_message(
            self.client_socket,
            {
                "type": rllink.SET_CONFIG.name,
                "env_steps_per_sample": self.config.get_rollout_fragment_length(
                    worker_index=self.worker_index
                ),
                # Client must wait for new weights after each sample batch.
                "force_on_policy": True,
            },
        )
380
    def _log_episode_metrics(self, length, ret, sec):
        """Logs one finished episode's stats to `self.metrics`.

        Args:
            length: Total number of env steps in the episode.
            ret: The episode's total return.
            sec: The episode's wall-clock duration in seconds.
        """
        # Log general episode metrics.
        # To mimic the old API stack behavior, we'll use `window` here for
        # these particular stats (instead of the default EMA).
        win = self.config.metrics_num_episodes_for_smoothing
        self.metrics.log_value(EPISODE_LEN_MEAN, length, window=win)
        self.metrics.log_value(EPISODE_RETURN_MEAN, ret, window=win)
        self.metrics.log_value(EPISODE_DURATION_SEC_MEAN, sec, window=win)
        # Per-agent returns.
        self.metrics.log_value(
            ("agent_episode_returns_mean", DEFAULT_AGENT_ID), ret, window=win
        )
        # Per-RLModule returns.
        self.metrics.log_value(
            ("module_episode_returns_mean", DEFAULT_MODULE_ID), ret, window=win
        )

        # For some metrics, log min/max as well.
        self.metrics.log_value(EPISODE_LEN_MIN, length, reduce="min", window=win)
        self.metrics.log_value(EPISODE_RETURN_MIN, ret, reduce="min", window=win)
        self.metrics.log_value(EPISODE_LEN_MAX, length, reduce="max", window=win)
        self.metrics.log_value(EPISODE_RETURN_MAX, ret, reduce="max", window=win)
403
+
404
+ def _send_message(sock_, message: dict):
405
+ """Sends a message to the client with a length header."""
406
+ body = json.dumps(message).encode("utf-8")
407
+ header = str(len(body)).zfill(8).encode("utf-8")
408
+ try:
409
+ sock_.sendall(header + body)
410
+ except Exception as e:
411
+ raise ConnectionError(
412
+ f"Error sending message {message} to server on socket {sock_}! "
413
+ f"Original error was: {e}"
414
+ )
415
+
416
+
417
def _get_message(sock_):
    """Receives one length-prefixed JSON message from the peer.

    Returns:
        Tuple of (rllink message type, message dict without the "type" key).

    Raises:
        ConnectionError: On any receive/parse/protocol error. The original
            error is attached as `__cause__` (exception chaining was
            previously lost here, hiding the real failure from tracebacks).
    """
    try:
        # Read the length header (8 bytes).
        header = _get_num_bytes(sock_, 8)
        msg_length = int(header.decode("utf-8"))
        # Read the message body.
        body = _get_num_bytes(sock_, msg_length)
        # Decode JSON.
        message = json.loads(body.decode("utf-8"))
        # Check for proper protocol.
        if "type" not in message:
            raise ConnectionError(
                "Protocol Error! Message from peer does not contain `type` field."
            )
        return rllink(message.pop("type")), message
    except Exception as e:
        raise ConnectionError(
            f"Error receiving message from peer on socket {sock_}! "
            f"Original error was: {e}"
        ) from e
439
+
440
+ def _get_num_bytes(sock_, num_bytes):
441
+ """Helper function to receive a specific number of bytes."""
442
+ data = b""
443
+ while len(data) < num_bytes:
444
+ packet = sock_.recv(num_bytes - len(data))
445
+ if not packet:
446
+ raise ConnectionError(f"No data received from socket {sock_}!")
447
+ data += packet
448
+ return data
449
+
450
+
451
def _dummy_client(port: int = 5556):
    """A dummy client that runs CartPole and acts as a testing external env.

    Connects to the TcpClientInferenceEnvRunner server at `localhost:port`,
    performs the PING/PONG handshake, fetches the config and ONNX weights,
    then loops forever: inferring actions via onnxruntime, stepping CartPole,
    and sending collected episodes back to the server (blocking for fresh
    weights after each batch when on-policy sampling is enforced).

    Args:
        port: The TCP port of the server to connect to.
    """

    def _set_state(msg_body):
        # Decode, decompress, and write the received ONNX model to disk, then
        # build an onnxruntime session and collect its output names.
        # NOTE(review): the TemporaryDirectory is created but never used --
        # "_temp_onnx" is written to the current working directory; confirm
        # whether writing into the temp dir was intended.
        with tempfile.TemporaryDirectory():
            with open("_temp_onnx", "wb") as f:
                f.write(
                    gzip.decompress(
                        base64.b64decode(msg_body["onnx_file"].encode("utf-8"))
                    )
                )
            onnx_session = onnxruntime.InferenceSession("_temp_onnx")
            output_names = [o.name for o in onnx_session.get_outputs()]
            return onnx_session, output_names

    # Connect to server (retry every 5s until the server is up).
    while True:
        try:
            print(f"Trying to connect to localhost:{port} ...")
            sock_ = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock_.connect(("localhost", port))
            break
        except ConnectionRefusedError:
            time.sleep(5)

    # Send ping-pong.
    _send_message(sock_, {"type": rllink.PING.name})
    msg_type, msg_body = _get_message(sock_)
    assert msg_type == rllink.PONG

    # Request config.
    _send_message(sock_, {"type": rllink.GET_CONFIG.name})
    msg_type, msg_body = _get_message(sock_)
    assert msg_type == rllink.SET_CONFIG
    env_steps_per_sample = msg_body["env_steps_per_sample"]
    force_on_policy = msg_body["force_on_policy"]

    # Request ONNX weights.
    _send_message(sock_, {"type": rllink.GET_STATE.name})
    msg_type, msg_body = _get_message(sock_)
    assert msg_type == rllink.SET_STATE
    onnx_session, output_names = _set_state(msg_body)

    # Episode collection buckets.
    episodes = []
    observations = []
    actions = []
    action_dist_inputs = []
    action_logps = []
    rewards = []

    timesteps = 0
    episode_return = 0.0

    # Start actual env loop.
    env = gym.make("CartPole-v1")
    obs, info = env.reset()
    observations.append(obs.tolist())

    while True:
        timesteps += 1
        # Perform action inference using the ONNX model.
        # NOTE(review): the input name "onnx::Gemm_0" is whatever name the
        # torch->ONNX export assigned to the obs input; it may change with
        # torch versions -- confirm against the exporting side.
        logits = onnx_session.run(
            output_names,
            {"onnx::Gemm_0": np.array([obs], np.float32)},
        )[0][
            0
        ]  # [0]=first return item, [0]=batch size 1

        # Stochastic sample.
        action_probs = softmax(logits)
        action = int(np.random.choice(list(range(env.action_space.n)), p=action_probs))
        logp = float(np.log(action_probs[action]))

        # Perform the env step.
        obs, reward, terminated, truncated, info = env.step(action)

        # Collect step data.
        observations.append(obs.tolist())
        actions.append(action)
        action_dist_inputs.append(logits.tolist())
        action_logps.append(logp)
        rewards.append(reward)
        episode_return += reward

        # We have to create a new episode record.
        if timesteps == env_steps_per_sample or terminated or truncated:
            episodes.append(
                {
                    Columns.OBS: observations,
                    Columns.ACTIONS: actions,
                    Columns.ACTION_DIST_INPUTS: action_dist_inputs,
                    Columns.ACTION_LOGP: action_logps,
                    Columns.REWARDS: rewards,
                    "is_terminated": terminated,
                    "is_truncated": truncated,
                }
            )
            # We collected enough samples -> Send them to server.
            if timesteps == env_steps_per_sample:
                # Make sure the amount of data we collected is correct.
                assert sum(len(e["actions"]) for e in episodes) == env_steps_per_sample

                # Send the data to the server.
                if force_on_policy:
                    _send_message(
                        sock_,
                        {
                            "type": rllink.EPISODES_AND_GET_STATE.name,
                            "episodes": episodes,
                            "timesteps": timesteps,
                        },
                    )
                    # We are forced to sample on-policy. Have to wait for a response
                    # with the state (weights) in it.
                    msg_type, msg_body = _get_message(sock_)
                    assert msg_type == rllink.SET_STATE
                    onnx_session, output_names = _set_state(msg_body)

                # Sampling doesn't have to be on-policy -> continue collecting
                # samples.
                else:
                    raise NotImplementedError

                episodes = []
                timesteps = 0

            # Set new buckets to empty lists (for next episode); carry the
            # last observation over as the next chunk's first observation.
            observations = [observations[-1]]
            actions = []
            action_dist_inputs = []
            action_logps = []
            rewards = []

            # The episode is done -> Reset.
            if terminated or truncated:
                obs, _ = env.reset()
                observations = [obs.tolist()]
                episode_return = 0.0
.venv/lib/python3.11/site-packages/ray/rllib/env/utils/__pycache__/external_env_protocol.cpython-311.pyc ADDED
Binary file (1.28 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/vector_env.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import gymnasium as gym
3
+ import numpy as np
4
+ from typing import Callable, List, Optional, Tuple, Union, Set
5
+
6
+ from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID
7
+ from ray.rllib.utils.annotations import Deprecated, OldAPIStack, override
8
+ from ray.rllib.utils.typing import (
9
+ EnvActionType,
10
+ EnvID,
11
+ EnvInfoDict,
12
+ EnvObsType,
13
+ EnvType,
14
+ MultiEnvDict,
15
+ AgentID,
16
+ )
17
+ from ray.util import log_once
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
@OldAPIStack
class VectorEnv:
    """An environment that supports batch evaluation using clones of sub-envs."""

    def __init__(
        self, observation_space: gym.Space, action_space: gym.Space, num_envs: int
    ):
        """Initializes a VectorEnv instance.

        Args:
            observation_space: The observation Space of a single
                sub-env.
            action_space: The action Space of a single sub-env.
            num_envs: The number of clones to make of the given sub-env.
        """
        self.observation_space = observation_space
        self.action_space = action_space
        self.num_envs = num_envs

    @staticmethod
    def vectorize_gym_envs(
        make_env: Optional[Callable[[int], EnvType]] = None,
        existing_envs: Optional[List[gym.Env]] = None,
        num_envs: int = 1,
        action_space: Optional[gym.Space] = None,
        observation_space: Optional[gym.Space] = None,
        restart_failed_sub_environments: bool = False,
        # Deprecated. These seem to have never been used.
        env_config=None,
        policy_config=None,
    ) -> "_VectorizedGymEnv":
        """Translates any given gym.Env(s) into a VectorizedEnv object.

        Args:
            make_env: Factory that produces a new gym.Env taking the sub-env's
                vector index as only arg. Must be defined if the
                number of `existing_envs` is less than `num_envs`.
            existing_envs: Optional list of already instantiated sub
                environments.
            num_envs: Total number of sub environments in this VectorEnv.
            action_space: The action space. If None, use existing_envs[0]'s
                action space.
            observation_space: The observation space. If None, use
                existing_envs[0]'s observation space.
            restart_failed_sub_environments: If True and any sub-environment (within
                a vectorized env) throws any error during env stepping, the
                Sampler will try to restart the faulty sub-environment. This is done
                without disturbing the other (still intact) sub-environment and without
                the RolloutWorker crashing.

        Returns:
            The resulting _VectorizedGymEnv object (subclass of VectorEnv).
        """
        return _VectorizedGymEnv(
            make_env=make_env,
            existing_envs=existing_envs or [],
            num_envs=num_envs,
            observation_space=observation_space,
            action_space=action_space,
            restart_failed_sub_environments=restart_failed_sub_environments,
        )

    def vector_reset(
        self, *, seeds: Optional[List[int]] = None, options: Optional[List[dict]] = None
    ) -> Tuple[List[EnvObsType], List[EnvInfoDict]]:
        """Resets all sub-environments.

        Args:
            seeds: The list of seeds to be passed to the sub-environments' when
                resetting them. If None, will not reset any existing PRNGs. If you
                pass integers, the PRNGs will be reset even if they already exists.
            options: The list of options dicts to be passed to the sub-environments'
                when resetting them.

        Returns:
            Tuple consisting of a list of observations from each environment and
            a list of info dicts from each environment.
        """
        raise NotImplementedError

    def reset_at(
        self,
        index: Optional[int] = None,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> Union[Tuple[EnvObsType, EnvInfoDict], Exception]:
        """Resets a single sub-environment.

        Args:
            index: An optional sub-env index to reset.
            seed: The seed to be passed to the sub-environment at index `index` when
                resetting it. If None, will not reset any existing PRNG. If you pass an
                integer, the PRNG will be reset even if it already exists.
            options: An options dict to be passed to the sub-environment at index
                `index` when resetting it.

        Returns:
            Tuple consisting of observations from the reset sub environment and
            an info dict of the reset sub environment. Alternatively an Exception
            can be returned, indicating that the reset operation on the sub environment
            has failed (and why it failed).
        """
        raise NotImplementedError

    def restart_at(self, index: Optional[int] = None) -> None:
        """Restarts a single sub-environment.

        Args:
            index: An optional sub-env index to restart.
        """
        raise NotImplementedError

    def vector_step(
        self, actions: List[EnvActionType]
    ) -> Tuple[
        List[EnvObsType], List[float], List[bool], List[bool], List[EnvInfoDict]
    ]:
        """Performs a vectorized step on all sub environments using `actions`.

        Args:
            actions: List of actions (one for each sub-env).

        Returns:
            A tuple consisting of
            1) New observations for each sub-env.
            2) Reward values for each sub-env.
            3) Terminated values for each sub-env.
            4) Truncated values for each sub-env.
            5) Info values for each sub-env.
        """
        raise NotImplementedError

    def get_sub_environments(self) -> List[EnvType]:
        """Returns the underlying sub environments.

        Returns:
            List of all underlying sub environments.
        """
        return []

    # TODO: (sven) Experimental method. Make @PublicAPI at some point.
    def try_render_at(self, index: Optional[int] = None) -> Optional[np.ndarray]:
        """Renders a single environment.

        Args:
            index: An optional sub-env index to render.

        Returns:
            Either a numpy RGB image (shape=(w x h x 3) dtype=uint8) or
            None in case rendering is handled directly by this method.
        """
        pass

    def to_base_env(
        self,
        make_env: Optional[Callable[[int], EnvType]] = None,
        num_envs: int = 1,
        remote_envs: bool = False,
        remote_env_batch_wait_ms: int = 0,
        restart_failed_sub_environments: bool = False,
    ) -> "BaseEnv":
        """Converts this VectorEnv into a BaseEnv object.

        The resulting BaseEnv is always vectorized (contains n
        sub-environments) to support batched forward passes, where n may
        also be 1. BaseEnv also supports async execution via the `poll` and
        `send_actions` methods and thus supports external simulators.

        NOTE: All arguments are ignored by this implementation; the wrapper
        always wraps `self` as-is.

        Args:
            make_env: A callable taking an int as input (which indicates
                the number of individual sub-environments within the final
                vectorized BaseEnv) and returning one individual
                sub-environment.
            num_envs: The number of sub-environments to create in the
                resulting (vectorized) BaseEnv. The already existing `env`
                will be one of the `num_envs`.
            remote_envs: Whether each sub-env should be a @ray.remote
                actor. You can set this behavior in your config via the
                `remote_worker_envs=True` option.
            remote_env_batch_wait_ms: The wait time (in ms) to poll remote
                sub-environments for, if applicable. Only used if
                `remote_envs` is True.

        Returns:
            The resulting BaseEnv object.
        """
        env = VectorEnvWrapper(self)
        return env

    @Deprecated(new="vectorize_gym_envs", error=True)
    def wrap(self, *args, **kwargs) -> "_VectorizedGymEnv":
        pass

    @Deprecated(new="get_sub_environments", error=True)
    def get_unwrapped(self) -> List[EnvType]:
        pass
220
+
221
@OldAPIStack
class _VectorizedGymEnv(VectorEnv):
    """Internal wrapper to translate any gym.Envs into a VectorEnv object."""

    def __init__(
        self,
        make_env: Optional[Callable[[int], EnvType]] = None,
        existing_envs: Optional[List[gym.Env]] = None,
        num_envs: int = 1,
        *,
        observation_space: Optional[gym.Space] = None,
        action_space: Optional[gym.Space] = None,
        restart_failed_sub_environments: bool = False,
        # Deprecated. These seem to have never been used.
        env_config=None,
        policy_config=None,
    ):
        """Initializes a _VectorizedGymEnv object.

        Args:
            make_env: Factory that produces a new gym.Env taking the sub-env's
                vector index as only arg. Must be defined if the
                number of `existing_envs` is less than `num_envs`.
            existing_envs: Optional list of already instantiated sub
                environments.
            num_envs: Total number of sub environments in this VectorEnv.
            action_space: The action space. If None, use existing_envs[0]'s
                action space.
            observation_space: The observation space. If None, use
                existing_envs[0]'s observation space.
            restart_failed_sub_environments: If True and any sub-environment (within
                a vectorized env) throws any error during env stepping, we will try to
                restart the faulty sub-environment. This is done
                without disturbing the other (still intact) sub-environments.
        """
        # Guard against `existing_envs=None` (the declared default): start
        # from an empty list instead of crashing with a TypeError in the
        # fill-up loop below.
        self.envs = existing_envs or []
        self.make_env = make_env
        self.restart_failed_sub_environments = restart_failed_sub_environments

        # Fill up missing envs (so we have exactly num_envs sub-envs in this
        # VectorEnv).
        while len(self.envs) < num_envs:
            self.envs.append(make_env(len(self.envs)))

        super().__init__(
            observation_space=observation_space or self.envs[0].observation_space,
            action_space=action_space or self.envs[0].action_space,
            num_envs=num_envs,
        )

    @override(VectorEnv)
    def vector_reset(
        self, *, seeds: Optional[List[int]] = None, options: Optional[List[dict]] = None
    ) -> Tuple[List[EnvObsType], List[EnvInfoDict]]:
        seeds = seeds or [None] * self.num_envs
        options = options or [None] * self.num_envs
        # Use reset_at(index) to restart and retry until
        # we successfully create a new env.
        # NOTE: This retries forever if a sub-env keeps failing to reset
        # (only reachable with `restart_failed_sub_environments=True`).
        resetted_obs = []
        resetted_infos = []
        for i in range(len(self.envs)):
            while True:
                obs, infos = self.reset_at(i, seed=seeds[i], options=options[i])
                if not isinstance(obs, Exception):
                    break
            resetted_obs.append(obs)
            resetted_infos.append(infos)
        return resetted_obs, resetted_infos

    @override(VectorEnv)
    def reset_at(
        self,
        index: Optional[int] = None,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> Tuple[Union[EnvObsType, Exception], Union[EnvInfoDict, Exception]]:
        if index is None:
            index = 0
        try:
            obs_and_infos = self.envs[index].reset(seed=seed, options=options)

        except Exception as e:
            if self.restart_failed_sub_environments:
                # Log the exception object itself (not `e.args[0]`, which
                # raises IndexError for exceptions constructed without args).
                logger.exception(e)
                self.restart_at(index)
                # Return the error so the caller can decide how to proceed.
                obs_and_infos = e, {}
            else:
                raise e

        return obs_and_infos

    @override(VectorEnv)
    def restart_at(self, index: Optional[int] = None) -> None:
        if index is None:
            index = 0

        # Try closing down the old (possibly faulty) sub-env, but ignore errors.
        try:
            self.envs[index].close()
        except Exception as e:
            if log_once("close_sub_env"):
                logger.warning(
                    "Trying to close old and replaced sub-environment (at vector "
                    f"index={index}), but closing resulted in error:\n{e}"
                )
        # Explicitly drop the old reference before re-creating, so resources
        # held by the faulty env can be freed.
        env_to_del = self.envs[index]
        self.envs[index] = None
        del env_to_del

        # Re-create the sub-env at the new index.
        logger.warning(f"Trying to restart sub-environment at index {index}.")
        self.envs[index] = self.make_env(index)
        logger.warning(f"Sub-environment at index {index} restarted successfully.")

    @override(VectorEnv)
    def vector_step(
        self, actions: List[EnvActionType]
    ) -> Tuple[
        List[EnvObsType], List[float], List[bool], List[bool], List[EnvInfoDict]
    ]:
        obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = (
            [],
            [],
            [],
            [],
            [],
        )
        for i in range(self.num_envs):
            try:
                results = self.envs[i].step(actions[i])
            except Exception as e:
                if self.restart_failed_sub_environments:
                    # See `reset_at`: log the exception object itself.
                    logger.exception(e)
                    self.restart_at(i)
                    # Fake a terminated+truncated step result carrying the
                    # error in the obs slot.
                    results = e, 0.0, True, True, {}
                else:
                    raise e

            obs, reward, terminated, truncated, info = results

            if not isinstance(info, dict):
                raise ValueError(
                    "Info should be a dict, got {} ({})".format(info, type(info))
                )
            obs_batch.append(obs)
            reward_batch.append(reward)
            terminated_batch.append(terminated)
            truncated_batch.append(truncated)
            info_batch.append(info)
        return obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch

    @override(VectorEnv)
    def get_sub_environments(self) -> List[EnvType]:
        return self.envs

    @override(VectorEnv)
    def try_render_at(self, index: Optional[int] = None):
        if index is None:
            index = 0
        return self.envs[index].render()
383
+
384
+ @OldAPIStack
385
+ class VectorEnvWrapper(BaseEnv):
386
+ """Internal adapter of VectorEnv to BaseEnv.
387
+
388
+ We assume the caller will always send the full vector of actions in each
389
+ call to send_actions(), and that they call reset_at() on all completed
390
+ environments before calling send_actions().
391
+ """
392
+
393
+ def __init__(self, vector_env: VectorEnv):
394
+ self.vector_env = vector_env
395
+ self.num_envs = vector_env.num_envs
396
+ self._observation_space = vector_env.observation_space
397
+ self._action_space = vector_env.action_space
398
+
399
+ # Sub-environments' states.
400
+ self.new_obs = None
401
+ self.cur_rewards = None
402
+ self.cur_terminateds = None
403
+ self.cur_truncateds = None
404
+ self.cur_infos = None
405
+ # At first `poll()`, reset everything (all sub-environments).
406
+ self.first_reset_done = False
407
+ # Initialize sub-environments' state.
408
+ self._init_env_state(idx=None)
409
+
410
+ @override(BaseEnv)
411
+ def poll(
412
+ self,
413
+ ) -> Tuple[
414
+ MultiEnvDict,
415
+ MultiEnvDict,
416
+ MultiEnvDict,
417
+ MultiEnvDict,
418
+ MultiEnvDict,
419
+ MultiEnvDict,
420
+ ]:
421
+ from ray.rllib.env.base_env import with_dummy_agent_id
422
+
423
+ if not self.first_reset_done:
424
+ self.first_reset_done = True
425
+ # TODO(sven): We probably would like to seed this call here as well.
426
+ self.new_obs, self.cur_infos = self.vector_env.vector_reset()
427
+ new_obs = dict(enumerate(self.new_obs))
428
+ rewards = dict(enumerate(self.cur_rewards))
429
+ terminateds = dict(enumerate(self.cur_terminateds))
430
+ truncateds = dict(enumerate(self.cur_truncateds))
431
+ infos = dict(enumerate(self.cur_infos))
432
+
433
+ # Empty all states (in case `poll()` gets called again).
434
+ self.new_obs = []
435
+ self.cur_rewards = []
436
+ self.cur_terminateds = []
437
+ self.cur_truncateds = []
438
+ self.cur_infos = []
439
+
440
+ return (
441
+ with_dummy_agent_id(new_obs),
442
+ with_dummy_agent_id(rewards),
443
+ with_dummy_agent_id(terminateds, "__all__"),
444
+ with_dummy_agent_id(truncateds, "__all__"),
445
+ with_dummy_agent_id(infos),
446
+ {},
447
+ )
448
+
449
+ @override(BaseEnv)
450
+ def send_actions(self, action_dict: MultiEnvDict) -> None:
451
+ from ray.rllib.env.base_env import _DUMMY_AGENT_ID
452
+
453
+ action_vector = [None] * self.num_envs
454
+ for i in range(self.num_envs):
455
+ action_vector[i] = action_dict[i][_DUMMY_AGENT_ID]
456
+ (
457
+ self.new_obs,
458
+ self.cur_rewards,
459
+ self.cur_terminateds,
460
+ self.cur_truncateds,
461
+ self.cur_infos,
462
+ ) = self.vector_env.vector_step(action_vector)
463
+
464
+ @override(BaseEnv)
465
+ def try_reset(
466
+ self,
467
+ env_id: Optional[EnvID] = None,
468
+ *,
469
+ seed: Optional[int] = None,
470
+ options: Optional[dict] = None,
471
+ ) -> Tuple[MultiEnvDict, MultiEnvDict]:
472
+ from ray.rllib.env.base_env import _DUMMY_AGENT_ID
473
+
474
+ if env_id is None:
475
+ env_id = 0
476
+ assert isinstance(env_id, int)
477
+ obs, infos = self.vector_env.reset_at(env_id, seed=seed, options=options)
478
+
479
+ # If exceptions were returned, return MultiEnvDict mapping env indices to
480
+ # these exceptions (for obs and infos).
481
+ if isinstance(obs, Exception):
482
+ return {env_id: obs}, {env_id: infos}
483
+ # Otherwise, return a MultiEnvDict (with single agent ID) and the actual
484
+ # obs and info dicts.
485
+ else:
486
+ return {env_id: {_DUMMY_AGENT_ID: obs}}, {env_id: {_DUMMY_AGENT_ID: infos}}
487
+
488
+ @override(BaseEnv)
489
+ def try_restart(self, env_id: Optional[EnvID] = None) -> None:
490
+ assert env_id is None or isinstance(env_id, int)
491
+ # Restart the sub-env at the index.
492
+ self.vector_env.restart_at(env_id)
493
+ # Auto-reset (get ready for next `poll()`).
494
+ self._init_env_state(env_id)
495
+
496
+ @override(BaseEnv)
497
+ def get_sub_environments(self, as_dict: bool = False) -> Union[List[EnvType], dict]:
498
+ if not as_dict:
499
+ return self.vector_env.get_sub_environments()
500
+ else:
501
+ return {
502
+ _id: env
503
+ for _id, env in enumerate(self.vector_env.get_sub_environments())
504
+ }
505
+
506
+ @override(BaseEnv)
507
+ def try_render(self, env_id: Optional[EnvID] = None) -> None:
508
+ assert env_id is None or isinstance(env_id, int)
509
+ return self.vector_env.try_render_at(env_id)
510
+
511
+ @property
512
+ @override(BaseEnv)
513
+ def observation_space(self) -> gym.Space:
514
+ return self._observation_space
515
+
516
+ @property
517
+ @override(BaseEnv)
518
+ def action_space(self) -> gym.Space:
519
+ return self._action_space
520
+
521
+ @override(BaseEnv)
522
+ def get_agent_ids(self) -> Set[AgentID]:
523
+ return {_DUMMY_AGENT_ID}
524
+
525
+ def _init_env_state(self, idx: Optional[int] = None) -> None:
526
+ """Resets all or one particular sub-environment's state (by index).
527
+
528
+ Args:
529
+ idx: The index to reset at. If None, reset all the sub-environments' states.
530
+ """
531
+ # If index is None, reset all sub-envs' states:
532
+ if idx is None:
533
+ self.new_obs = [None for _ in range(self.num_envs)]
534
+ self.cur_rewards = [0.0 for _ in range(self.num_envs)]
535
+ self.cur_terminateds = [False for _ in range(self.num_envs)]
536
+ self.cur_truncateds = [False for _ in range(self.num_envs)]
537
+ self.cur_infos = [{} for _ in range(self.num_envs)]
538
+ # Index provided, reset only the sub-env's state at the given index.
539
+ else:
540
+ self.new_obs[idx], self.cur_infos[idx] = self.vector_env.reset_at(idx)
541
+ # Reset all other states to null values.
542
+ self.cur_rewards[idx] = 0.0
543
+ self.cur_terminateds[idx] = False
544
+ self.cur_truncateds[idx] = False
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/atari_wrappers.cpython-311.pyc ADDED
Binary file (22.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/group_agents_wrapper.cpython-311.pyc ADDED
Binary file (7.97 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/multi_agent_env_compatibility.cpython-311.pyc ADDED
Binary file (4.45 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/open_spiel.cpython-311.pyc ADDED
Binary file (8.82 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/unity3d_env.cpython-311.pyc ADDED
Binary file (16.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/atari_wrappers.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ import gymnasium as gym
3
+ from gymnasium import spaces
4
+ import numpy as np
5
+ from typing import Optional, Union
6
+
7
+ from ray.rllib.utils.annotations import PublicAPI
8
+ from ray.rllib.utils.images import rgb2gray, resize
9
+
10
+
11
+ @PublicAPI
12
+ def is_atari(env: Union[gym.Env, str]) -> bool:
13
+ """Returns, whether a given env object or env descriptor (str) is an Atari env.
14
+
15
+ Args:
16
+ env: The gym.Env object or a string descriptor of the env (for example,
17
+ "ale_py:ALE/Pong-v5").
18
+
19
+ Returns:
20
+ Whether `env` is an Atari environment.
21
+ """
22
+ # If a gym.Env, check proper spaces as well as occurrence of the "Atari<ALE" string
23
+ # in the class name.
24
+ if not isinstance(env, str):
25
+ if (
26
+ hasattr(env.observation_space, "shape")
27
+ and env.observation_space.shape is not None
28
+ and len(env.observation_space.shape) <= 2
29
+ ):
30
+ return False
31
+ return "AtariEnv<ALE" in str(env)
32
+ # If string, check for "ale_py:ALE/" prefix.
33
+ else:
34
+ return env.startswith("ALE/") or env.startswith("ale_py:")
35
+
36
+
37
+ @PublicAPI
38
+ def get_wrapper_by_cls(env, cls):
39
+ """Returns the gym env wrapper of the given class, or None."""
40
+ currentenv = env
41
+ while True:
42
+ if isinstance(currentenv, cls):
43
+ return currentenv
44
+ elif isinstance(currentenv, gym.Wrapper):
45
+ currentenv = currentenv.env
46
+ else:
47
+ return None
48
+
49
+
50
+ @PublicAPI
51
+ class ClipRewardEnv(gym.RewardWrapper):
52
+ def __init__(self, env):
53
+ gym.RewardWrapper.__init__(self, env)
54
+
55
+ def reward(self, reward):
56
+ """Bin reward to {+1, 0, -1} by its sign."""
57
+ return np.sign(reward)
58
+
59
+
60
+ @PublicAPI
61
+ class EpisodicLifeEnv(gym.Wrapper):
62
+ def __init__(self, env):
63
+ """Make end-of-life == end-of-episode, but only reset on true game over.
64
+ Done by DeepMind for the DQN and co. since it helps value estimation.
65
+ """
66
+ gym.Wrapper.__init__(self, env)
67
+ self.lives = 0
68
+ self.was_real_terminated = True
69
+
70
+ def step(self, action):
71
+ obs, reward, terminated, truncated, info = self.env.step(action)
72
+ self.was_real_terminated = terminated
73
+ # check current lives, make loss of life terminal,
74
+ # then update lives to handle bonus lives
75
+ lives = self.env.unwrapped.ale.lives()
76
+ if lives < self.lives and lives > 0:
77
+ # for Qbert sometimes we stay in lives == 0 condtion for a few fr
78
+ # so its important to keep lives > 0, so that we only reset once
79
+ # the environment advertises `terminated`.
80
+ terminated = True
81
+ self.lives = lives
82
+ return obs, reward, terminated, truncated, info
83
+
84
+ def reset(self, **kwargs):
85
+ """Reset only when lives are exhausted.
86
+ This way all states are still reachable even though lives are episodic,
87
+ and the learner need not know about any of this behind-the-scenes.
88
+ """
89
+ if self.was_real_terminated:
90
+ obs, info = self.env.reset(**kwargs)
91
+ else:
92
+ # no-op step to advance from terminal/lost life state
93
+ obs, _, _, _, info = self.env.step(0)
94
+ self.lives = self.env.unwrapped.ale.lives()
95
+ return obs, info
96
+
97
+
98
+ @PublicAPI
99
+ class FireResetEnv(gym.Wrapper):
100
+ def __init__(self, env):
101
+ """Take action on reset.
102
+
103
+ For environments that are fixed until firing."""
104
+ gym.Wrapper.__init__(self, env)
105
+ assert env.unwrapped.get_action_meanings()[1] == "FIRE"
106
+ assert len(env.unwrapped.get_action_meanings()) >= 3
107
+
108
+ def reset(self, **kwargs):
109
+ self.env.reset(**kwargs)
110
+ obs, _, terminated, truncated, _ = self.env.step(1)
111
+ if terminated or truncated:
112
+ self.env.reset(**kwargs)
113
+ obs, _, terminated, truncated, info = self.env.step(2)
114
+ if terminated or truncated:
115
+ self.env.reset(**kwargs)
116
+ return obs, info
117
+
118
+ def step(self, ac):
119
+ return self.env.step(ac)
120
+
121
+
122
+ @PublicAPI
123
+ class FrameStack(gym.Wrapper):
124
+ def __init__(self, env, k):
125
+ """Stack k last frames."""
126
+ gym.Wrapper.__init__(self, env)
127
+ self.k = k
128
+ self.frames = deque([], maxlen=k)
129
+ shp = env.observation_space.shape
130
+ self.observation_space = spaces.Box(
131
+ low=np.repeat(env.observation_space.low, repeats=k, axis=-1),
132
+ high=np.repeat(env.observation_space.high, repeats=k, axis=-1),
133
+ shape=(shp[0], shp[1], shp[2] * k),
134
+ dtype=env.observation_space.dtype,
135
+ )
136
+
137
+ def reset(self, *, seed=None, options=None):
138
+ ob, infos = self.env.reset(seed=seed, options=options)
139
+ for _ in range(self.k):
140
+ self.frames.append(ob)
141
+ return self._get_ob(), infos
142
+
143
+ def step(self, action):
144
+ ob, reward, terminated, truncated, info = self.env.step(action)
145
+ self.frames.append(ob)
146
+ return self._get_ob(), reward, terminated, truncated, info
147
+
148
+ def _get_ob(self):
149
+ assert len(self.frames) == self.k
150
+ return np.concatenate(self.frames, axis=2)
151
+
152
+
153
+ @PublicAPI
154
+ class FrameStackTrajectoryView(gym.ObservationWrapper):
155
+ def __init__(self, env):
156
+ """No stacking. Trajectory View API takes care of this."""
157
+ gym.Wrapper.__init__(self, env)
158
+ shp = env.observation_space.shape
159
+ assert shp[2] == 1
160
+ self.observation_space = spaces.Box(
161
+ low=0, high=255, shape=(shp[0], shp[1]), dtype=env.observation_space.dtype
162
+ )
163
+
164
+ def observation(self, observation):
165
+ return np.squeeze(observation, axis=-1)
166
+
167
+
168
+ @PublicAPI
169
+ class MaxAndSkipEnv(gym.Wrapper):
170
+ def __init__(self, env, skip=4):
171
+ """Return only every `skip`-th frame"""
172
+ gym.Wrapper.__init__(self, env)
173
+ # most recent raw observations (for max pooling across time steps)
174
+ self._obs_buffer = np.zeros(
175
+ (2,) + env.observation_space.shape, dtype=env.observation_space.dtype
176
+ )
177
+ self._skip = skip
178
+
179
+ def step(self, action):
180
+ """Repeat action, sum reward, and max over last observations."""
181
+ total_reward = 0.0
182
+ terminated = truncated = info = None
183
+ for i in range(self._skip):
184
+ obs, reward, terminated, truncated, info = self.env.step(action)
185
+ if i == self._skip - 2:
186
+ self._obs_buffer[0] = obs
187
+ if i == self._skip - 1:
188
+ self._obs_buffer[1] = obs
189
+ total_reward += reward
190
+ if terminated or truncated:
191
+ break
192
+ # Note that the observation on the terminated|truncated=True frame
193
+ # doesn't matter
194
+ max_frame = self._obs_buffer.max(axis=0)
195
+
196
+ return max_frame, total_reward, terminated, truncated, info
197
+
198
+ def reset(self, **kwargs):
199
+ return self.env.reset(**kwargs)
200
+
201
+
202
+ @PublicAPI
203
+ class MonitorEnv(gym.Wrapper):
204
+ def __init__(self, env=None):
205
+ """Record episodes stats prior to EpisodicLifeEnv, etc."""
206
+ gym.Wrapper.__init__(self, env)
207
+ self._current_reward = None
208
+ self._num_steps = None
209
+ self._total_steps = None
210
+ self._episode_rewards = []
211
+ self._episode_lengths = []
212
+ self._num_episodes = 0
213
+ self._num_returned = 0
214
+
215
+ def reset(self, **kwargs):
216
+ obs, info = self.env.reset(**kwargs)
217
+
218
+ if self._total_steps is None:
219
+ self._total_steps = sum(self._episode_lengths)
220
+
221
+ if self._current_reward is not None:
222
+ self._episode_rewards.append(self._current_reward)
223
+ self._episode_lengths.append(self._num_steps)
224
+ self._num_episodes += 1
225
+
226
+ self._current_reward = 0
227
+ self._num_steps = 0
228
+
229
+ return obs, info
230
+
231
+ def step(self, action):
232
+ obs, rew, terminated, truncated, info = self.env.step(action)
233
+ self._current_reward += rew
234
+ self._num_steps += 1
235
+ self._total_steps += 1
236
+ return obs, rew, terminated, truncated, info
237
+
238
+ def get_episode_rewards(self):
239
+ return self._episode_rewards
240
+
241
+ def get_episode_lengths(self):
242
+ return self._episode_lengths
243
+
244
+ def get_total_steps(self):
245
+ return self._total_steps
246
+
247
+ def next_episode_results(self):
248
+ for i in range(self._num_returned, len(self._episode_rewards)):
249
+ yield (self._episode_rewards[i], self._episode_lengths[i])
250
+ self._num_returned = len(self._episode_rewards)
251
+
252
+
253
+ @PublicAPI
254
+ class NoopResetEnv(gym.Wrapper):
255
+ def __init__(self, env, noop_max=30):
256
+ """Sample initial states by taking random number of no-ops on reset.
257
+ No-op is assumed to be action 0.
258
+ """
259
+ gym.Wrapper.__init__(self, env)
260
+ self.noop_max = noop_max
261
+ self.override_num_noops = None
262
+ self.noop_action = 0
263
+ assert env.unwrapped.get_action_meanings()[0] == "NOOP"
264
+
265
+ def reset(self, **kwargs):
266
+ """Do no-op action for a number of steps in [1, noop_max]."""
267
+ self.env.reset(**kwargs)
268
+ if self.override_num_noops is not None:
269
+ noops = self.override_num_noops
270
+ else:
271
+ # This environment now uses the pcg64 random number generator which
272
+ # does not have `randint` as an attribute only has `integers`.
273
+ try:
274
+ noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
275
+ # Also still support older versions.
276
+ except AttributeError:
277
+ noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
278
+ assert noops > 0
279
+ obs = None
280
+ for _ in range(noops):
281
+ obs, _, terminated, truncated, info = self.env.step(self.noop_action)
282
+ if terminated or truncated:
283
+ obs, info = self.env.reset(**kwargs)
284
+ return obs, info
285
+
286
+ def step(self, ac):
287
+ return self.env.step(ac)
288
+
289
+
290
+ @PublicAPI
291
+ class NormalizedImageEnv(gym.ObservationWrapper):
292
+ def __init__(self, *args, **kwargs):
293
+ super().__init__(*args, **kwargs)
294
+ self.observation_space = gym.spaces.Box(
295
+ -1.0,
296
+ 1.0,
297
+ shape=self.observation_space.shape,
298
+ dtype=np.float32,
299
+ )
300
+
301
+ # Divide by scale and center around 0.0, such that observations are in the range
302
+ # of -1.0 and 1.0.
303
+ def observation(self, observation):
304
+ return (observation.astype(np.float32) / 128.0) - 1.0
305
+
306
+
307
+ @PublicAPI
308
+ class WarpFrame(gym.ObservationWrapper):
309
+ def __init__(self, env, dim):
310
+ """Warp frames to the specified size (dim x dim)."""
311
+ gym.ObservationWrapper.__init__(self, env)
312
+ self.width = dim
313
+ self.height = dim
314
+ self.observation_space = spaces.Box(
315
+ low=0, high=255, shape=(self.height, self.width, 1), dtype=np.uint8
316
+ )
317
+
318
+ def observation(self, frame):
319
+ frame = rgb2gray(frame)
320
+ frame = resize(frame, height=self.height, width=self.width)
321
+ return frame[:, :, None]
322
+
323
+
324
+ @PublicAPI
325
+ def wrap_atari_for_new_api_stack(
326
+ env: gym.Env,
327
+ dim: int = 64,
328
+ frameskip: int = 4,
329
+ framestack: Optional[int] = None,
330
+ # TODO (sven): Add option to NOT grayscale, in which case framestack must be None
331
+ # (b/c we are using the 3 color channels already as stacking frames).
332
+ ) -> gym.Env:
333
+ """Wraps `env` for new-API-stack-friendly RLlib Atari experiments.
334
+
335
+ Note that we assume reward clipping is done outside the wrapper.
336
+
337
+ Args:
338
+ env: The env object to wrap.
339
+ dim: Dimension to resize observations to (dim x dim).
340
+ frameskip: Whether to skip n frames and max over them (keep brightest pixels).
341
+ framestack: Whether to stack the last n (grayscaled) frames. Note that this
342
+ step happens after(!) a possible frameskip step, meaning that if
343
+ frameskip=4 and framestack=2, we would perform the following over this
344
+ trajectory:
345
+ actual env timesteps: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 -> ...
346
+ frameskip: ( max ) ( max ) ( max ) ( max )
347
+ framestack: ( stack ) (stack )
348
+
349
+ Returns:
350
+ The wrapped gym.Env.
351
+ """
352
+ # Time limit.
353
+ env = gym.wrappers.TimeLimit(env, max_episode_steps=108000)
354
+ # Grayscale + resize.
355
+ env = WarpFrame(env, dim=dim)
356
+ # Normalize the image.
357
+ env = NormalizedImageEnv(env)
358
+ # Frameskip: Take max over these n frames.
359
+ if frameskip > 1:
360
+ assert env.spec is not None
361
+ env = MaxAndSkipEnv(env, skip=frameskip)
362
+ # Send n noop actions into env after reset to increase variance in the
363
+ # "start states" of the trajectories. These dummy steps are NOT included in the
364
+ # sampled data used for learning.
365
+ env = NoopResetEnv(env, noop_max=30)
366
+ # Each life is one episode.
367
+ env = EpisodicLifeEnv(env)
368
+ # Some envs only start playing after pressing fire. Unblock those.
369
+ if "FIRE" in env.unwrapped.get_action_meanings():
370
+ env = FireResetEnv(env)
371
+ # Framestack.
372
+ if framestack:
373
+ env = FrameStack(env, k=framestack)
374
+ return env
375
+
376
+
377
+ @PublicAPI
378
+ def wrap_deepmind(env, dim=84, framestack=True, noframeskip=False):
379
+ """Configure environment for DeepMind-style Atari.
380
+
381
+ Note that we assume reward clipping is done outside the wrapper.
382
+
383
+ Args:
384
+ env: The env object to wrap.
385
+ dim: Dimension to resize observations to (dim x dim).
386
+ framestack: Whether to framestack observations.
387
+ """
388
+ env = MonitorEnv(env)
389
+ env = NoopResetEnv(env, noop_max=30)
390
+ if env.spec is not None and noframeskip is True:
391
+ env = MaxAndSkipEnv(env, skip=4)
392
+ env = EpisodicLifeEnv(env)
393
+ if "FIRE" in env.unwrapped.get_action_meanings():
394
+ env = FireResetEnv(env)
395
+ env = WarpFrame(env, dim)
396
+ # env = ClipRewardEnv(env) # reward clipping is handled by policy eval
397
+ # 4x image framestacking.
398
+ if framestack is True:
399
+ env = FrameStack(env, 4)
400
+ return env
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/dm_control_wrapper.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DeepMind Control Suite Wrapper directly sourced from:
3
+ https://github.com/denisyarats/dmc2gym
4
+
5
+ MIT License
6
+
7
+ Copyright (c) 2020 Denis Yarats
8
+
9
+ Permission is hereby granted, free of charge, to any person obtaining a copy
10
+ of this software and associated documentation files (the "Software"), to deal
11
+ in the Software without restriction, including without limitation the rights
12
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13
+ copies of the Software, and to permit persons to whom the Software is
14
+ furnished to do so, subject to the following conditions:
15
+
16
+ The above copyright notice and this permission notice shall be included in all
17
+ copies or substantial portions of the Software.
18
+
19
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25
+ SOFTWARE.
26
+ """
27
+ from gymnasium import core, spaces
28
+
29
+ try:
30
+ from dm_env import specs
31
+ except ImportError:
32
+ specs = None
33
+ try:
34
+ # Suppress MuJoCo warning (dm_control uses absl logging).
35
+ import absl.logging
36
+
37
+ absl.logging.set_verbosity("error")
38
+ from dm_control import suite
39
+ except (ImportError, OSError):
40
+ suite = None
41
+ import numpy as np
42
+
43
+ from ray.rllib.utils.annotations import PublicAPI
44
+
45
+
46
+ def _spec_to_box(spec):
47
+ def extract_min_max(s):
48
+ assert s.dtype == np.float64 or s.dtype == np.float32
49
+ dim = np.int_(np.prod(s.shape))
50
+ if type(s) is specs.Array:
51
+ bound = np.inf * np.ones(dim, dtype=np.float32)
52
+ return -bound, bound
53
+ elif type(s) is specs.BoundedArray:
54
+ zeros = np.zeros(dim, dtype=np.float32)
55
+ return s.minimum + zeros, s.maximum + zeros
56
+
57
+ mins, maxs = [], []
58
+ for s in spec:
59
+ mn, mx = extract_min_max(s)
60
+ mins.append(mn)
61
+ maxs.append(mx)
62
+ low = np.concatenate(mins, axis=0)
63
+ high = np.concatenate(maxs, axis=0)
64
+ assert low.shape == high.shape
65
+ return spaces.Box(low, high, dtype=np.float32)
66
+
67
+
68
+ def _flatten_obs(obs):
69
+ obs_pieces = []
70
+ for v in obs.values():
71
+ flat = np.array([v]) if np.isscalar(v) else v.ravel()
72
+ obs_pieces.append(flat)
73
+ return np.concatenate(obs_pieces, axis=0)
74
+
75
+
76
+ @PublicAPI
77
+ class DMCEnv(core.Env):
78
+ def __init__(
79
+ self,
80
+ domain_name,
81
+ task_name,
82
+ task_kwargs=None,
83
+ visualize_reward=False,
84
+ from_pixels=False,
85
+ height=64,
86
+ width=64,
87
+ camera_id=0,
88
+ frame_skip=2,
89
+ environment_kwargs=None,
90
+ channels_first=True,
91
+ preprocess=True,
92
+ ):
93
+ self._from_pixels = from_pixels
94
+ self._height = height
95
+ self._width = width
96
+ self._camera_id = camera_id
97
+ self._frame_skip = frame_skip
98
+ self._channels_first = channels_first
99
+ self.preprocess = preprocess
100
+
101
+ if specs is None:
102
+ raise RuntimeError(
103
+ (
104
+ "The `specs` module from `dm_env` was not imported. Make sure "
105
+ "`dm_env` is installed and visible in the current python "
106
+ "environment."
107
+ )
108
+ )
109
+ if suite is None:
110
+ raise RuntimeError(
111
+ (
112
+ "The `suite` module from `dm_control` was not imported. Make "
113
+ "sure `dm_control` is installed and visible in the current "
114
+ "python enviornment."
115
+ )
116
+ )
117
+
118
+ # create task
119
+ self._env = suite.load(
120
+ domain_name=domain_name,
121
+ task_name=task_name,
122
+ task_kwargs=task_kwargs,
123
+ visualize_reward=visualize_reward,
124
+ environment_kwargs=environment_kwargs,
125
+ )
126
+
127
+ # true and normalized action spaces
128
+ self._true_action_space = _spec_to_box([self._env.action_spec()])
129
+ self._norm_action_space = spaces.Box(
130
+ low=-1.0, high=1.0, shape=self._true_action_space.shape, dtype=np.float32
131
+ )
132
+
133
+ # create observation space
134
+ if from_pixels:
135
+ shape = [3, height, width] if channels_first else [height, width, 3]
136
+ self._observation_space = spaces.Box(
137
+ low=0, high=255, shape=shape, dtype=np.uint8
138
+ )
139
+ if preprocess:
140
+ self._observation_space = spaces.Box(
141
+ low=-0.5, high=0.5, shape=shape, dtype=np.float32
142
+ )
143
+ else:
144
+ self._observation_space = _spec_to_box(
145
+ self._env.observation_spec().values()
146
+ )
147
+
148
+ self._state_space = _spec_to_box(self._env.observation_spec().values())
149
+
150
+ self.current_state = None
151
+
152
+ def __getattr__(self, name):
153
+ return getattr(self._env, name)
154
+
155
+ def _get_obs(self, time_step):
156
+ if self._from_pixels:
157
+ obs = self.render(
158
+ height=self._height, width=self._width, camera_id=self._camera_id
159
+ )
160
+ if self._channels_first:
161
+ obs = obs.transpose(2, 0, 1).copy()
162
+ if self.preprocess:
163
+ obs = obs / 255.0 - 0.5
164
+ else:
165
+ obs = _flatten_obs(time_step.observation)
166
+ return obs.astype(np.float32)
167
+
168
+ def _convert_action(self, action):
169
+ action = action.astype(np.float64)
170
+ true_delta = self._true_action_space.high - self._true_action_space.low
171
+ norm_delta = self._norm_action_space.high - self._norm_action_space.low
172
+ action = (action - self._norm_action_space.low) / norm_delta
173
+ action = action * true_delta + self._true_action_space.low
174
+ action = action.astype(np.float32)
175
+ return action
176
+
177
+ @property
178
+ def observation_space(self):
179
+ return self._observation_space
180
+
181
+ @property
182
+ def state_space(self):
183
+ return self._state_space
184
+
185
+ @property
186
+ def action_space(self):
187
+ return self._norm_action_space
188
+
189
+ def step(self, action):
190
+ assert self._norm_action_space.contains(action)
191
+ action = self._convert_action(action)
192
+ assert self._true_action_space.contains(action)
193
+ reward = 0.0
194
+ extra = {"internal_state": self._env.physics.get_state().copy()}
195
+
196
+ terminated = truncated = False
197
+ for _ in range(self._frame_skip):
198
+ time_step = self._env.step(action)
199
+ reward += time_step.reward or 0.0
200
+ terminated = False
201
+ truncated = time_step.last()
202
+ if terminated or truncated:
203
+ break
204
+ obs = self._get_obs(time_step)
205
+ self.current_state = _flatten_obs(time_step.observation)
206
+ extra["discount"] = time_step.discount
207
+ return obs, reward, terminated, truncated, extra
208
+
209
+ def reset(self, *, seed=None, options=None):
210
+ time_step = self._env.reset()
211
+ self.current_state = _flatten_obs(time_step.observation)
212
+ obs = self._get_obs(time_step)
213
+ return obs, {}
214
+
215
+ def render(self, mode="rgb_array", height=None, width=None, camera_id=0):
216
+ assert mode == "rgb_array", "only support for rgb_array mode"
217
+ height = height or self._height
218
+ width = width or self._width
219
+ camera_id = camera_id or self._camera_id
220
+ return self._env.physics.render(height=height, width=width, camera_id=camera_id)
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/dm_env_wrapper.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ from gymnasium import spaces
3
+
4
+ import numpy as np
5
+
6
+ try:
7
+ from dm_env import specs
8
+ except ImportError:
9
+ specs = None
10
+
11
+ from ray.rllib.utils.annotations import PublicAPI
12
+
13
+
14
+ def _convert_spec_to_space(spec):
15
+ if isinstance(spec, dict):
16
+ return spaces.Dict({k: _convert_spec_to_space(v) for k, v in spec.items()})
17
+ if isinstance(spec, specs.DiscreteArray):
18
+ return spaces.Discrete(spec.num_values)
19
+ elif isinstance(spec, specs.BoundedArray):
20
+ return spaces.Box(
21
+ low=np.asscalar(spec.minimum),
22
+ high=np.asscalar(spec.maximum),
23
+ shape=spec.shape,
24
+ dtype=spec.dtype,
25
+ )
26
+ elif isinstance(spec, specs.Array):
27
+ return spaces.Box(
28
+ low=-float("inf"), high=float("inf"), shape=spec.shape, dtype=spec.dtype
29
+ )
30
+
31
+ raise NotImplementedError(
32
+ (
33
+ "Could not convert `Array` spec of type {} to Gym space. "
34
+ "Attempted to convert: {}"
35
+ ).format(type(spec), spec)
36
+ )
37
+
38
+
39
+ @PublicAPI
40
+ class DMEnv(gym.Env):
41
+ """A `gym.Env` wrapper for the `dm_env` API."""
42
+
43
+ metadata = {"render.modes": ["rgb_array"]}
44
+
45
+ def __init__(self, dm_env):
46
+ super(DMEnv, self).__init__()
47
+ self._env = dm_env
48
+ self._prev_obs = None
49
+
50
+ if specs is None:
51
+ raise RuntimeError(
52
+ (
53
+ "The `specs` module from `dm_env` was not imported. Make sure "
54
+ "`dm_env` is installed and visible in the current python "
55
+ "environment."
56
+ )
57
+ )
58
+
59
+ def step(self, action):
60
+ ts = self._env.step(action)
61
+
62
+ reward = ts.reward
63
+ if reward is None:
64
+ reward = 0.0
65
+
66
+ return ts.observation, reward, ts.last(), False, {"discount": ts.discount}
67
+
68
+ def reset(self, *, seed=None, options=None):
69
+ ts = self._env.reset()
70
+ return ts.observation, {}
71
+
72
+ def render(self, mode="rgb_array"):
73
+ if self._prev_obs is None:
74
+ raise ValueError(
75
+ "Environment not started. Make sure to reset before rendering."
76
+ )
77
+
78
+ if mode == "rgb_array":
79
+ return self._prev_obs
80
+ else:
81
+ raise NotImplementedError("Render mode '{}' is not supported.".format(mode))
82
+
83
+ @property
84
+ def action_space(self):
85
+ spec = self._env.action_spec()
86
+ return _convert_spec_to_space(spec)
87
+
88
+ @property
89
+ def observation_space(self):
90
+ spec = self._env.observation_spec()
91
+ return _convert_spec_to_space(spec)
92
+
93
+ @property
94
+ def reward_range(self):
95
+ spec = self._env.reward_spec()
96
+ if isinstance(spec, specs.BoundedArray):
97
+ return spec.minimum, spec.maximum
98
+ return -float("inf"), float("inf")
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/group_agents_wrapper.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ import gymnasium as gym
3
+ from typing import Dict, List, Optional
4
+
5
+ from ray.rllib.env.multi_agent_env import MultiAgentEnv
6
+ from ray.rllib.utils.annotations import DeveloperAPI
7
+ from ray.rllib.utils.typing import AgentID
8
+
9
# Info key for the individual rewards of an agent, for example:
# info: {
#   group_1: {
#     _group_rewards: [5, -1, 1],  # 3 agents in this group
#   }
# }
GROUP_REWARDS = "_group_rewards"

# Info key for the individual infos of an agent, for example
# (note: the key's actual value is "_group_info", singular):
# info: {
#   group_1: {
#     _group_info: [{"foo": ...}, {}],  # 2 agents in this group
#   }
# }
GROUP_INFO = "_group_info"
24
+
25
+
26
@DeveloperAPI
class GroupAgentsWrapper(MultiAgentEnv):
    """Wraps a MultiAgentEnv environment with agents grouped as specified.

    Each group id becomes a single (new) agent ID in the wrapped env:
    observations/rewards/infos of the members are collected into lists
    (ordered by the group's member list), and actions for a group are
    distributed back to the individual members.

    See multi_agent_env.py for the specification of groups.

    This API is experimental.
    """

    def __init__(
        self,
        env: MultiAgentEnv,
        groups: Dict[str, List[AgentID]],
        obs_space: Optional[gym.Space] = None,
        act_space: Optional[gym.Space] = None,
    ):
        """Wrap an existing MultiAgentEnv to group agent ID together.

        See `MultiAgentEnv.with_agent_groups()` for more detailed usage info.

        Args:
            env: The env to wrap and whose agent IDs to group into new agents.
            groups: Mapping from group id to a list of the agent ids
                of group members. If an agent id is not present in any group
                value, it will be left ungrouped. The group id becomes a new
                agent ID in the final environment.
            obs_space: Optional observation space for the grouped
                env. Must be a tuple space. If not provided, will infer this
                to be a Tuple of n individual agents spaces (n=num agents in
                a group).
            act_space: Optional action space for the grouped env.
                Must be a tuple space. If not provided, will infer this to be
                a Tuple of n individual agents spaces (n=num agents in a
                group).

        Raises:
            ValueError: If an agent id appears in more than one group.
        """
        super().__init__()
        self.env = env
        self.groups = groups
        # Reverse mapping (agent id -> group id) for fast lookups in
        # `_group_items()`. Built eagerly so duplicate membership fails fast.
        self.agent_id_to_group = {}
        for group_id, agent_ids in groups.items():
            for agent_id in agent_ids:
                if agent_id in self.agent_id_to_group:
                    raise ValueError(
                        "Agent id {} is in multiple groups".format(agent_id)
                    )
                self.agent_id_to_group[agent_id] = group_id
        if obs_space is not None:
            self.observation_space = obs_space
        if act_space is not None:
            self.action_space = act_space
        # Register each group id as a (new) agent id of this wrapper.
        # `_agent_ids` is maintained by the MultiAgentEnv base class.
        for group_id in groups.keys():
            self._agent_ids.add(group_id)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Resets the wrapped env; returns grouped observations and infos.

        Per-group infos are aggregated into a dict under the GROUP_INFO key
        (one entry per group member, in member-list order).
        """
        obs, info = self.env.reset(seed=seed, options=options)

        return (
            self._group_items(obs),
            self._group_items(
                info,
                agg_fn=lambda gvals: {GROUP_INFO: list(gvals.values())},
            ),
        )

    def step(self, action_dict):
        """Steps the env with grouped actions; returns grouped results.

        Group rewards are summed into a scalar per group; the individual
        member rewards are preserved in `infos[group_id][GROUP_REWARDS]`.
        """
        # Ungroup and send actions.
        action_dict = self._ungroup_items(action_dict)
        obs, rewards, terminateds, truncateds, infos = self.env.step(action_dict)

        # Apply grouping transforms to the env outputs
        obs = self._group_items(obs)
        rewards = self._group_items(rewards, agg_fn=lambda gvals: list(gvals.values()))
        # Only if all of the agents are terminated, the group is terminated as well.
        terminateds = self._group_items(
            terminateds, agg_fn=lambda gvals: all(gvals.values())
        )
        # If all of the agents are truncated, the group is truncated as well.
        truncateds = self._group_items(
            truncateds,
            agg_fn=lambda gvals: all(gvals.values()),
        )
        infos = self._group_items(
            infos, agg_fn=lambda gvals: {GROUP_INFO: list(gvals.values())}
        )

        # Aggregate rewards, but preserve the original values in infos.
        # Ungrouped agents' rewards (non-lists) are passed through untouched.
        for agent_id, rew in rewards.items():
            if isinstance(rew, list):
                rewards[agent_id] = sum(rew)
                if agent_id not in infos:
                    infos[agent_id] = {}
                infos[agent_id][GROUP_REWARDS] = rew

        return obs, rewards, terminateds, truncateds, infos

    def _ungroup_items(self, items):
        """Distributes per-group tuples/lists back to the individual agents.

        Values for ungrouped agent ids are passed through unchanged. Group
        values are zipped against the group's member list (must match in
        length).
        """
        out = {}
        for agent_id, value in items.items():
            if agent_id in self.groups:
                assert len(value) == len(self.groups[agent_id]), (
                    agent_id,
                    value,
                    self.groups,
                )
                for a, v in zip(self.groups[agent_id], value):
                    out[a] = v
            else:
                out[agent_id] = value
        return out

    def _group_items(self, items, agg_fn=None):
        """Collects per-agent items into per-group values via `agg_fn`.

        `agg_fn` receives an OrderedDict of member_id -> item (in member-list
        order) and returns the aggregated group value. Defaults to a plain
        list of the member items. Raises ValueError if any group member is
        missing from `items` while another member is present.
        """
        if agg_fn is None:
            agg_fn = lambda gvals: list(gvals.values())  # noqa: E731

        grouped_items = {}
        for agent_id, item in items.items():
            if agent_id in self.agent_id_to_group:
                group_id = self.agent_id_to_group[agent_id]
                if group_id in grouped_items:
                    continue  # already added
                group_out = OrderedDict()
                for a in self.groups[group_id]:
                    if a in items:
                        group_out[a] = items[a]
                    else:
                        raise ValueError(
                            "Missing member of group {}: {}: {}".format(
                                group_id, a, items
                            )
                        )
                grouped_items[group_id] = agg_fn(group_out)
            else:
                grouped_items[agent_id] = item
        return grouped_items
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/multi_agent_env_compatibility.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ from ray.rllib.env.multi_agent_env import MultiAgentEnv
4
+ from ray.rllib.utils.typing import MultiAgentDict
5
+
6
+
7
class MultiAgentEnvCompatibility(MultiAgentEnv):
    """Adapts a MultiAgentEnv written against the old gym API to the new one.

    "Old API": `step()` returns `(obs, rewards, dones, infos)` and `reset()`
    returns only the observations.
    "New API": `step()` returns `(obs, rewards, terminateds, truncateds,
    infos)` and `reset()` returns `(obs, infos)`.

    Known limitations:
        - Environments that use `self.np_random` might not work as expected.
    """

    def __init__(self, old_env, render_mode: Optional[str] = None):
        """A wrapper which converts old-style envs to valid modern envs.

        Some information may be lost in the conversion, so we recommend
        updating your environment.

        Args:
            old_env: The old MultiAgentEnv to wrap. Implemented with the old
                API.
            render_mode: The render mode to use when rendering the
                environment, passed automatically to `env.render()`.
        """
        super().__init__()

        self.env = old_env
        self.render_mode = render_mode
        # Carry over optional attributes from the old env, with defaults.
        self.metadata = getattr(old_env, "metadata", {"render_modes": []})
        self.reward_range = getattr(old_env, "reward_range", None)
        self.spec = getattr(old_env, "spec", None)

        self.observation_space = old_env.observation_space
        self.action_space = old_env.action_space

    def reset(
        self, *, seed: Optional[int] = None, options: Optional[dict] = None
    ) -> Tuple[MultiAgentDict, MultiAgentDict]:
        """Resets via the old API; returns (obs, per-agent empty infos)."""
        # The old API seeds through a separate `seed()` call.
        if seed is not None:
            self.env.seed(seed)
        # `options` has no old-API equivalent and is ignored.

        if self.render_mode == "human":
            self.render()

        observations = self.env.reset()
        # Old `reset()` returns no infos -> synthesize empty per-agent infos.
        return observations, {agent: {} for agent in observations}

    def step(
        self, action
    ) -> Tuple[
        MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict
    ]:
        """Steps via the old API; maps `dones` to `terminateds`."""
        observations, rewards, dones, infos = self.env.step(action)

        # The old `done` signal carries no truncation info -> always False.
        truncations = {agent: False for agent in dones}

        return observations, rewards, dones, truncations, infos

    def render(self):
        # The old `render()` API takes the mode on every call.
        return self.env.render(mode=self.render_mode)

    def close(self):
        self.env.close()
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/open_spiel.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import numpy as np
4
+ import gymnasium as gym
5
+
6
+ from ray.rllib.env.multi_agent_env import MultiAgentEnv
7
+ from ray.rllib.env.utils import try_import_pyspiel
8
+
9
+ pyspiel = try_import_pyspiel(error=True)
10
+
11
+
12
class OpenSpielEnv(MultiAgentEnv):
    """RLlib MultiAgentEnv wrapper around an open_spiel game.

    Agent IDs are the integer player indices 0 .. num_players - 1. Supports
    both sequential- and simultaneous-move open_spiel dynamics, including
    chance nodes (resolved internally by random sampling).
    """

    def __init__(self, env):
        """Initializes an OpenSpielEnv wrapping the given open_spiel game."""
        super().__init__()
        self.env = env
        # One RLlib agent per open_spiel player (integer agent IDs).
        self.agents = self.possible_agents = list(range(self.env.num_players()))
        # Store the open-spiel game type.
        self.type = self.env.get_type()
        # Stores the current open-spiel game state.
        self.state = None

        # Flat float32 observation (open_spiel's observation tensor),
        # identical for every agent.
        self.observation_space = gym.spaces.Dict(
            {
                aid: gym.spaces.Box(
                    float("-inf"),
                    float("inf"),
                    (self.env.observation_tensor_size(),),
                    dtype=np.float32,
                )
                for aid in self.possible_agents
            }
        )
        self.action_space = gym.spaces.Dict(
            {
                aid: gym.spaces.Discrete(self.env.num_distinct_actions())
                for aid in self.possible_agents
            }
        )

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        # NOTE(review): `seed` and `options` are ignored here — confirm this
        # is acceptable for reproducibility requirements.
        self.state = self.env.new_initial_state()
        return self._get_obs(), {}

    def step(self, action):
        """Applies the given action(s) and returns the 5-tuple step result.

        Invalid actions in sequential games are replaced with a random legal
        action and penalized with -0.1 reward.
        """
        # Before applying action(s), there could be chance nodes.
        # E.g. if env has to figure out, which agent's action should get
        # resolved first in a simultaneous node.
        self._solve_chance_nodes()
        penalties = {}

        # Sequential game:
        if str(self.type.dynamics) == "Dynamics.SEQUENTIAL":
            curr_player = self.state.current_player()
            assert curr_player in action
            try:
                self.state.apply_action(action[curr_player])
            # TODO: (sven) resolve this hack by publishing legal actions
            #  with each step.
            except pyspiel.SpielError:
                # Illegal action -> pick a random legal one and penalize.
                self.state.apply_action(np.random.choice(self.state.legal_actions()))
                penalties[curr_player] = -0.1

            # Compile rewards dict.
            rewards = {ag: r for ag, r in enumerate(self.state.returns())}
        # Simultaneous game.
        else:
            # -2 is open_spiel's "simultaneous node" player marker.
            assert self.state.current_player() == -2
            # Apparently, this works, even if one or more actions are invalid.
            self.state.apply_actions([action[ag] for ag in range(self.num_agents)])

        # Now that we have applied all actions, get the next obs.
        obs = self._get_obs()

        # Compile rewards dict and add the accumulated penalties
        # (for taking invalid actions).
        # NOTE(review): `self.num_agents` is provided by the MultiAgentEnv
        # base class (presumably len(self.agents)) — confirm.
        rewards = {ag: r for ag, r in enumerate(self.state.returns())}
        for ag, penalty in penalties.items():
            rewards[ag] += penalty

        # Are we done?
        is_terminated = self.state.is_terminal()
        terminateds = dict(
            {ag: is_terminated for ag in range(self.num_agents)},
            **{"__all__": is_terminated}
        )
        # open_spiel has no truncation concept -> always False.
        truncateds = dict(
            {ag: False for ag in range(self.num_agents)}, **{"__all__": False}
        )

        return obs, rewards, terminateds, truncateds, {}

    def render(self, mode=None) -> None:
        # Only "human" mode is supported: print the state to stdout.
        if mode == "human":
            print(self.state)

    def _get_obs(self):
        """Returns the obs dict for whichever agent(s) must act next."""
        # Before calculating an observation, there could be chance nodes
        # (that may have an effect on the actual observations).
        # E.g. After reset, figure out initial (random) positions of the
        # agents.
        self._solve_chance_nodes()

        # Terminal state -> no agent has to act anymore.
        if self.state.is_terminal():
            return {}

        # Sequential game:
        if str(self.type.dynamics) == "Dynamics.SEQUENTIAL":
            curr_player = self.state.current_player()
            return {
                curr_player: np.reshape(self.state.observation_tensor(), [-1]).astype(
                    np.float32
                )
            }
        # Simultaneous game.
        else:
            assert self.state.current_player() == -2
            return {
                ag: np.reshape(self.state.observation_tensor(ag), [-1]).astype(
                    np.float32
                )
                for ag in range(self.num_agents)
            }

    def _solve_chance_nodes(self):
        """Resolves all pending chance nodes by sampling their outcomes."""
        # Chance node(s): Sample a (non-player) action and apply.
        while self.state.is_chance_node():
            # -1 is open_spiel's "chance" player marker.
            assert self.state.current_player() == -1
            actions, probs = zip(*self.state.chance_outcomes())
            action = np.random.choice(actions, p=probs)
            self.state.apply_action(action)
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/pettingzoo_env.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import gymnasium as gym
4
+
5
+ from ray.rllib.env.multi_agent_env import MultiAgentEnv
6
+ from ray.rllib.utils.annotations import PublicAPI
7
+
8
+
9
@PublicAPI
class PettingZooEnv(MultiAgentEnv):
    """An interface to the PettingZoo MARL environment library.

    See: https://github.com/Farama-Foundation/PettingZoo

    Inherits from MultiAgentEnv and exposes a given AEC
    (actor-environment-cycle) game from the PettingZoo project via the
    MultiAgentEnv public API: `reset()` and `step()` return dicts that
    contain entries only for the agent(s) whose turn it is to act next.

    Note that the wrapper has the following important limitation:
    environments are assumed to be positive sum games (agents are expected to
    cooperate to maximize reward). This isn't a hard restriction; it's just
    that standard algorithms aren't expected to work well in highly
    competitive games.

    Agents may have different observation- and action spaces: the wrapper's
    `self.observation_space` and `self.action_space` are Dict spaces mapping
    agent IDs to the individual agents' spaces.
    """

    def __init__(self, env):
        super().__init__()
        self.env = env
        # PettingZoo requires a reset before agents/spaces can be queried.
        env.reset()

        self._agent_ids = set(self.env.agents)

        self.observation_space = gym.spaces.Dict(
            {aid: self.env.observation_space(aid) for aid in self._agent_ids}
        )
        self.action_space = gym.spaces.Dict(
            {aid: self.env.action_space(aid) for aid in self._agent_ids}
        )

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Resets the AEC env; returns obs for the first agent to act."""
        info = self.env.reset(seed=seed, options=options)
        acting_agent = self.env.agent_selection
        return {acting_agent: self.env.observe(acting_agent)}, info or {}

    def step(self, action):
        """Applies the current agent's action; returns the next agent's data."""
        self.env.step(action[self.env.agent_selection])

        observations = {}
        rewards = {}
        terminations = {}
        truncations = {}
        infos = {}

        # Advance past agents that are already done (stepping them with a
        # None action removes them from the env), collecting each one's final
        # data, until an agent that still has to act is found (or none are
        # left).
        while self.env.agents:
            agent = self.env.agent_selection
            (
                observations[agent],
                rewards[agent],
                terminations[agent],
                truncations[agent],
                infos[agent],
            ) = self.env.last()
            if not (self.env.terminations[agent] or self.env.truncations[agent]):
                break
            self.env.step(None)

        # The episode ends only once every agent has been removed.
        everyone_gone = not self.env.agents
        terminations["__all__"] = everyone_gone and all(terminations.values())
        truncations["__all__"] = everyone_gone and all(truncations.values())

        return observations, rewards, terminations, truncations, infos

    def close(self):
        self.env.close()

    def render(self):
        return self.env.render(self.render_mode)

    @property
    def get_sub_environments(self):
        return self.env.unwrapped
179
+
180
+
181
@PublicAPI
class ParallelPettingZooEnv(MultiAgentEnv):
    """MultiAgentEnv interface for PettingZoo's parallel-API environments."""

    def __init__(self, env):
        super().__init__()
        self.par_env = env
        # PettingZoo requires a reset before agents/spaces can be queried.
        self.par_env.reset()
        self._agent_ids = set(self.par_env.agents)

        self.observation_space = gym.spaces.Dict(
            {aid: self.par_env.observation_space(aid) for aid in self._agent_ids}
        )
        self.action_space = gym.spaces.Dict(
            {aid: self.par_env.action_space(aid) for aid in self._agent_ids}
        )

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Resets the parallel env; returns (obs dict, info dict)."""
        observations, infos = self.par_env.reset(seed=seed, options=options)
        return observations, infos or {}

    def step(self, action_dict):
        """Steps all agents at once; adds the `__all__` done entries."""
        observations, rewards, terminations, truncations, infos = self.par_env.step(
            action_dict
        )
        # The whole episode is done only when every agent is done.
        terminations["__all__"] = all(terminations.values())
        truncations["__all__"] = all(truncations.values())
        return observations, rewards, terminations, truncations, infos

    def close(self):
        self.par_env.close()

    def render(self):
        return self.par_env.render(self.render_mode)

    @property
    def get_sub_environments(self):
        return self.par_env.unwrapped
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/unity3d_env.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gymnasium.spaces import Box, MultiDiscrete, Tuple as TupleSpace
2
+ import logging
3
+ import numpy as np
4
+ import random
5
+ import time
6
+ from typing import Callable, Optional, Tuple
7
+
8
+ from ray.rllib.env.multi_agent_env import MultiAgentEnv
9
+ from ray.rllib.policy.policy import PolicySpec
10
+ from ray.rllib.utils.annotations import PublicAPI
11
+ from ray.rllib.utils.typing import MultiAgentDict, PolicyID, AgentID
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
@PublicAPI
class Unity3DEnv(MultiAgentEnv):
    """A MultiAgentEnv representing a single Unity3D game instance.

    For an example on how to use this Env with a running Unity3D editor
    or with a compiled game, see:
    `rllib/examples/unity3d_env_local.py`
    For an example on how to use it inside a Unity game client, which
    connects to an RLlib Policy server, see:
    `rllib/examples/envs/external_envs/unity3d_[client|server].py`

    Supports all Unity3D (MLAgents) examples, multi- or single-agent and
    gets converted automatically into an ExternalMultiAgentEnv, when used
    inside an RLlib PolicyClient for cloud/distributed training of Unity games.
    """

    # Default base port when connecting directly to the Editor
    _BASE_PORT_EDITOR = 5004
    # Default base port when connecting to a compiled environment
    _BASE_PORT_ENVIRONMENT = 5005
    # The worker_id for each environment instance
    _WORKER_ID = 0

    def __init__(
        self,
        file_name: str = None,
        port: Optional[int] = None,
        seed: int = 0,
        no_graphics: bool = False,
        timeout_wait: int = 300,
        episode_horizon: int = 1000,
    ):
        """Initializes a Unity3DEnv object.

        Args:
            file_name (Optional[str]): Name of the Unity game binary.
                If None, will assume a locally running Unity3D editor
                to be used, instead.
            port (Optional[int]): Port number to connect to Unity environment.
            seed: A random seed value to use for the Unity3D game.
            no_graphics: Whether to run the Unity3D simulator in
                no-graphics mode. Default: False.
            timeout_wait: Time (in seconds) to wait for connection from
                the Unity3D instance.
            episode_horizon: A hard horizon to abide to. After at most
                this many steps (per-agent episode `step()` calls), the
                Unity3D game is reset and will start again (finishing the
                multi-agent episode that the game represents).
                Note: The game itself may contain its own episode length
                limits, which are always obeyed (on top of this value here).
        """
        super().__init__()

        if file_name is None:
            print(
                "No game binary provided, will use a running Unity editor "
                "instead.\nMake sure you are pressing the Play (|>) button in "
                "your editor to start."
            )

        # Deferred import so the file can be imported without mlagents_envs
        # installed.
        import mlagents_envs
        from mlagents_envs.environment import UnityEnvironment

        # Try connecting to the Unity3D game instance. If a port is blocked
        # (UnityWorkerInUseException), retry with the next worker id.
        # NOTE(review): when connecting to an editor (file_name is None),
        # `worker_id_` is always 0 and `port_` never changes, so a blocked
        # port loops forever — confirm this is the intended behavior.
        port_ = None
        while True:
            # Sleep for random time to allow for concurrent startup of many
            # environments (num_env_runners >> 1). Otherwise, would lead to port
            # conflicts sometimes.
            if port_ is not None:
                time.sleep(random.randint(1, 10))
            port_ = port or (
                self._BASE_PORT_ENVIRONMENT if file_name else self._BASE_PORT_EDITOR
            )
            # cache the worker_id and
            # increase it for the next environment
            worker_id_ = Unity3DEnv._WORKER_ID if file_name else 0
            Unity3DEnv._WORKER_ID += 1
            try:
                self.unity_env = UnityEnvironment(
                    file_name=file_name,
                    worker_id=worker_id_,
                    base_port=port_,
                    seed=seed,
                    no_graphics=no_graphics,
                    timeout_wait=timeout_wait,
                )
                print("Created UnityEnvironment for port {}".format(port_ + worker_id_))
            except mlagents_envs.exception.UnityWorkerInUseException:
                pass
            else:
                break

        # ML-Agents API version, parsed as a list of ints (e.g. [1, 4, 0]),
        # used to pick the correct action-setting API in `step()`.
        self.api_version = self.unity_env.API_VERSION.split(".")
        self.api_version = [int(s) for s in self.api_version]

        # Reset entire env every this number of step calls.
        self.episode_horizon = episode_horizon
        # Keep track of how many times we have called `step` so far.
        self.episode_timesteps = 0

    def step(
        self, action_dict: MultiAgentDict
    ) -> Tuple[
        MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict
    ]:
        """Performs one multi-agent step through the game.

        Args:
            action_dict: Multi-agent action dict with:
                keys=agent identifier consisting of
                [MLagents behavior name, e.g. "Goalie?team=1"] + "_" +
                [Agent index, a unique MLAgent-assigned index per single agent]

        Returns:
            tuple:
                - obs: Multi-agent observation dict.
                    Only those observations for which to get new actions are
                    returned.
                - rewards: Rewards dict matching `obs`.
                - dones: Done dict with only an __all__ multi-agent entry in
                    it. __all__=True, if episode is done for all agents.
                - infos: An (empty) info dict.
        """
        from mlagents_envs.base_env import ActionTuple

        # Set only the required actions (from the DecisionSteps) in Unity3D.
        all_agents = []
        for behavior_name in self.unity_env.behavior_specs:
            # New ML-Agents API: Set all agents actions at the same time
            # via an ActionTuple. Since API v1.4.0.
            if self.api_version[0] > 1 or (
                self.api_version[0] == 1 and self.api_version[1] >= 4
            ):
                actions = []
                for agent_id in self.unity_env.get_steps(behavior_name)[0].agent_id:
                    key = behavior_name + "_{}".format(agent_id)
                    all_agents.append(key)
                    actions.append(action_dict[key])
                if actions:
                    # float32 actions -> continuous; anything else -> discrete.
                    if actions[0].dtype == np.float32:
                        action_tuple = ActionTuple(continuous=np.array(actions))
                    else:
                        action_tuple = ActionTuple(discrete=np.array(actions))
                    self.unity_env.set_actions(behavior_name, action_tuple)
            # Old behavior: Do not use an ActionTuple and set each agent's
            # action individually.
            else:
                for agent_id in self.unity_env.get_steps(behavior_name)[
                    0
                ].agent_id_to_index.keys():
                    key = behavior_name + "_{}".format(agent_id)
                    all_agents.append(key)
                    self.unity_env.set_action_for_agent(
                        behavior_name, agent_id, action_dict[key]
                    )
        # Do the step.
        self.unity_env.step()

        obs, rewards, terminateds, truncateds, infos = self._get_step_results()

        # Global horizon reached? -> Return __all__ truncated=True, so user
        # can reset. Set all agents' individual `truncated` to True as well.
        self.episode_timesteps += 1
        if self.episode_timesteps > self.episode_horizon:
            return (
                obs,
                rewards,
                terminateds,
                dict({"__all__": True}, **{agent_id: True for agent_id in all_agents}),
                infos,
            )

        return obs, rewards, terminateds, truncateds, infos

    def reset(
        self, *, seed=None, options=None
    ) -> Tuple[MultiAgentDict, MultiAgentDict]:
        """Resets the entire Unity3D scene (a single multi-agent episode)."""
        # NOTE: `seed` and `options` are not forwarded to the Unity env here.
        self.episode_timesteps = 0
        self.unity_env.reset()
        obs, _, _, _, infos = self._get_step_results()
        return obs, infos

    def _get_step_results(self):
        """Collects those agents' obs/rewards that have to act in next `step`.

        Returns:
            Tuple:
                obs: Multi-agent observation dict.
                    Only those observations for which to get new actions are
                    returned.
                rewards: Rewards dict matching `obs`.
                dones: Done dict with only an __all__ multi-agent entry in it.
                    __all__=True, if episode is done for all agents.
                infos: An (empty) info dict.
        """
        obs = {}
        rewards = {}
        infos = {}
        for behavior_name in self.unity_env.behavior_specs:
            decision_steps, terminal_steps = self.unity_env.get_steps(behavior_name)
            # Important: Only update those sub-envs that are currently
            # available within _env_state.
            # Loop through all envs ("agents") and fill in, whatever
            # information we have.
            for agent_id, idx in decision_steps.agent_id_to_index.items():
                key = behavior_name + "_{}".format(agent_id)
                # Single-component observations are unwrapped from the tuple.
                os = tuple(o[idx] for o in decision_steps.obs)
                os = os[0] if len(os) == 1 else os
                obs[key] = os
                rewards[key] = (
                    decision_steps.reward[idx] + decision_steps.group_reward[idx]
                )
            for agent_id, idx in terminal_steps.agent_id_to_index.items():
                key = behavior_name + "_{}".format(agent_id)
                # Only overwrite rewards (last reward in episode), b/c obs
                # here is the last obs (which doesn't matter anyways).
                # Unless key does not exist in obs.
                if key not in obs:
                    os = tuple(o[idx] for o in terminal_steps.obs)
                    obs[key] = os = os[0] if len(os) == 1 else os
                rewards[key] = (
                    terminal_steps.reward[idx] + terminal_steps.group_reward[idx]
                )

        # Only use dones if all agents are done, then we should do a reset.
        return obs, rewards, {"__all__": False}, {"__all__": False}, infos

    @staticmethod
    def get_policy_configs_for_game(
        game_name: str,
    ) -> Tuple[dict, Callable[[AgentID], PolicyID]]:
        """Returns (policies dict, policy_mapping_fn) for a known game name.

        The RLlib server must know about the Spaces that the Client will be
        using inside Unity3D, up-front, so the per-game obs/action spaces are
        hard-coded below.
        """
        obs_spaces = {
            # 3DBall.
            "3DBall": Box(float("-inf"), float("inf"), (8,)),
            # 3DBallHard.
            "3DBallHard": Box(float("-inf"), float("inf"), (45,)),
            # GridFoodCollector
            "GridFoodCollector": Box(float("-inf"), float("inf"), (40, 40, 6)),
            # Pyramids.
            "Pyramids": TupleSpace(
                [
                    Box(float("-inf"), float("inf"), (56,)),
                    Box(float("-inf"), float("inf"), (56,)),
                    Box(float("-inf"), float("inf"), (56,)),
                    Box(float("-inf"), float("inf"), (4,)),
                ]
            ),
            # SoccerTwos.
            "SoccerPlayer": TupleSpace(
                [
                    Box(-1.0, 1.0, (264,)),
                    Box(-1.0, 1.0, (72,)),
                ]
            ),
            # SoccerStrikersVsGoalie.
            "Goalie": Box(float("-inf"), float("inf"), (738,)),
            "Striker": TupleSpace(
                [
                    Box(float("-inf"), float("inf"), (231,)),
                    Box(float("-inf"), float("inf"), (63,)),
                ]
            ),
            # Sorter.
            "Sorter": TupleSpace(
                [
                    Box(
                        float("-inf"),
                        float("inf"),
                        (
                            20,
                            23,
                        ),
                    ),
                    Box(float("-inf"), float("inf"), (10,)),
                    Box(float("-inf"), float("inf"), (8,)),
                ]
            ),
            # Tennis.
            "Tennis": Box(float("-inf"), float("inf"), (27,)),
            # VisualHallway.
            "VisualHallway": Box(float("-inf"), float("inf"), (84, 84, 3)),
            # Walker.
            "Walker": Box(float("-inf"), float("inf"), (212,)),
            # FoodCollector.
            "FoodCollector": TupleSpace(
                [
                    Box(float("-inf"), float("inf"), (49,)),
                    Box(float("-inf"), float("inf"), (4,)),
                ]
            ),
        }
        action_spaces = {
            # 3DBall.
            "3DBall": Box(-1.0, 1.0, (2,), dtype=np.float32),
            # 3DBallHard.
            "3DBallHard": Box(-1.0, 1.0, (2,), dtype=np.float32),
            # GridFoodCollector.
            "GridFoodCollector": MultiDiscrete([3, 3, 3, 2]),
            # Pyramids.
            "Pyramids": MultiDiscrete([5]),
            # SoccerStrikersVsGoalie.
            "Goalie": MultiDiscrete([3, 3, 3]),
            "Striker": MultiDiscrete([3, 3, 3]),
            # SoccerTwos.
            "SoccerPlayer": MultiDiscrete([3, 3, 3]),
            # Sorter.
            "Sorter": MultiDiscrete([3, 3, 3]),
            # Tennis.
            "Tennis": Box(-1.0, 1.0, (3,)),
            # VisualHallway.
            "VisualHallway": MultiDiscrete([5]),
            # Walker.
            "Walker": Box(-1.0, 1.0, (39,)),
            # FoodCollector.
            "FoodCollector": MultiDiscrete([3, 3, 3, 2]),
        }

        # Policies (Unity: "behaviors") and agent-to-policy mapping fns.
        if game_name == "SoccerStrikersVsGoalie":
            policies = {
                "Goalie": PolicySpec(
                    observation_space=obs_spaces["Goalie"],
                    action_space=action_spaces["Goalie"],
                ),
                "Striker": PolicySpec(
                    observation_space=obs_spaces["Striker"],
                    action_space=action_spaces["Striker"],
                ),
            }

            def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                return "Striker" if "Striker" in agent_id else "Goalie"

        elif game_name == "SoccerTwos":
            policies = {
                "PurplePlayer": PolicySpec(
                    observation_space=obs_spaces["SoccerPlayer"],
                    action_space=action_spaces["SoccerPlayer"],
                ),
                "BluePlayer": PolicySpec(
                    observation_space=obs_spaces["SoccerPlayer"],
                    action_space=action_spaces["SoccerPlayer"],
                ),
            }

            def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                return "BluePlayer" if "1_" in agent_id else "PurplePlayer"

        # Default: A single policy per game, named after the game itself.
        else:
            policies = {
                game_name: PolicySpec(
                    observation_space=obs_spaces[game_name],
                    action_space=action_spaces[game_name],
                ),
            }

            def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                return game_name

        return policies, policy_mapping_fn
.venv/lib/python3.11/site-packages/ray/rllib/examples/actions/__pycache__/nested_action_spaces.cpython-311.pyc ADDED
Binary file (3.77 kB). View file