koichi12 commited on
Commit
1f700b9
·
verified ·
1 Parent(s): 1d3adaf

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/ray/rllib/connectors/__init__.py +0 -0
  2. .venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/__init__.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/connector_pipeline_v2.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/ray/rllib/connectors/connector.py +478 -0
  5. .venv/lib/python3.11/site-packages/ray/rllib/connectors/connector_pipeline_v2.py +394 -0
  6. .venv/lib/python3.11/site-packages/ray/rllib/connectors/connector_v2.py +1017 -0
  7. .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__init__.py +40 -0
  8. .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__pycache__/mean_std_filter.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__pycache__/prev_actions_prev_rewards.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/env_to_module_pipeline.py +55 -0
  11. .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/flatten_observations.py +208 -0
  12. .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/frame_stacking.py +6 -0
  13. .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/mean_std_filter.py +253 -0
  14. .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/observation_preprocessor.py +80 -0
  15. .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/prev_actions_prev_rewards.py +168 -0
  16. .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/write_observations_to_episodes.py +131 -0
  17. .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__init__.py +30 -0
  18. .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/listify_data_for_vector_env.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/get_actions.py +91 -0
  20. .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/listify_data_for_vector_env.py +82 -0
  21. .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/module_to_env_pipeline.py +7 -0
  22. .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/normalize_and_clip_actions.py +146 -0
  23. .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/remove_single_ts_time_rank_from_batch.py +70 -0
  24. .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/unbatch_to_individual_items.py +92 -0
  25. .venv/lib/python3.11/site-packages/ray/rllib/connectors/registry.py +46 -0
  26. .venv/lib/python3.11/site-packages/ray/rllib/connectors/util.py +170 -0
  27. .venv/lib/python3.11/site-packages/ray/rllib/core/learner/__init__.py +8 -0
  28. .venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/__init__.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/learner.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/learner_group.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/utils.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/ray/rllib/core/learner/learner.py +1795 -0
  33. .venv/lib/python3.11/site-packages/ray/rllib/core/learner/learner_group.py +1030 -0
  34. .venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__init__.py +0 -0
  35. .venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__pycache__/__init__.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__pycache__/tf_learner.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/tf_learner.py +357 -0
  38. .venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__init__.py +0 -0
  39. .venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__pycache__/__init__.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__pycache__/torch_learner.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/torch_learner.py +664 -0
  42. .venv/lib/python3.11/site-packages/ray/rllib/core/learner/utils.py +59 -0
  43. .venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/__init__.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/heads.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/utils.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__init__.py +53 -0
  47. .venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/__init__.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/default_model_config.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/multi_rl_module.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/rl_module.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/ray/rllib/connectors/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (193 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/connector_pipeline_v2.cpython-311.pyc ADDED
Binary file (20 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/connectors/connector.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This file defines base types and common structures for RLlib connectors.
2
+ """
3
+
4
+ import abc
5
+ import logging
6
+ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
7
+
8
+ import gymnasium as gym
9
+
10
+ from ray.rllib.policy.view_requirement import ViewRequirement
11
+ from ray.rllib.utils.typing import (
12
+ ActionConnectorDataType,
13
+ AgentConnectorDataType,
14
+ AlgorithmConfigDict,
15
+ TensorType,
16
+ )
17
+ from ray.rllib.utils.annotations import OldAPIStack
18
+
19
+ if TYPE_CHECKING:
20
+ from ray.rllib.policy.policy import Policy
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
@OldAPIStack
class ConnectorContext:
    """Data bits that may be needed for running connectors.

    Note(jungong) : we need to be really careful with the data fields here.
    E.g., everything needs to be serializable, in case we need to fetch them
    in a remote setting.
    """

    # TODO(jungong) : figure out how to fetch these in a remote setting.
    # Probably from a policy server when initializing a policy client.

    def __init__(
        self,
        config: AlgorithmConfigDict = None,
        initial_states: List[TensorType] = None,
        observation_space: gym.Space = None,
        action_space: gym.Space = None,
        view_requirements: Dict[str, ViewRequirement] = None,
        is_policy_recurrent: bool = False,
    ):
        """Construct a ConnectorContext instance.

        Args:
            config: Algorithm config dict of the policy this context belongs
                to. Defaults to an empty dict.
            initial_states: States that are used for constructing
                the initial input dict for RNN models. [] if a model is not
                recurrent.
            observation_space: The policy's observation space.
            action_space: The policy's action space.
            view_requirements: The policy's view requirements dict, mapping
                column names to ViewRequirement objects.
            is_policy_recurrent: True if the policy is recurrent (RNN-based).
        """
        # Falsy values (None/{}/[]) are normalized to fresh empty containers.
        self.config = config or {}
        self.initial_states = initial_states or []
        self.observation_space = observation_space
        self.action_space = action_space
        self.view_requirements = view_requirements
        self.is_policy_recurrent = is_policy_recurrent

    @staticmethod
    def from_policy(policy: "Policy") -> "ConnectorContext":
        """Build ConnectorContext from a given policy.

        Args:
            policy: The Policy instance from which to take config, spaces,
                view requirements, initial RNN states, and recurrence info.

        Returns:
            A ConnectorContext instance.
        """
        return ConnectorContext(
            config=policy.config,
            initial_states=policy.get_initial_state(),
            observation_space=policy.observation_space,
            action_space=policy.action_space,
            view_requirements=policy.view_requirements,
            is_policy_recurrent=policy.is_recurrent(),
        )
80
+
81
+
82
@OldAPIStack
class Connector(abc.ABC):
    """Connector base class.

    A connector is a step of transformation, of either environment data before
    they get to a policy, or policy output before it is sent back to the
    environment.

    Connectors may be training-aware, for example, behave slightly differently
    during training and inference.

    All connectors are required to be serializable and implement to_state().
    """

    def __init__(self, ctx: ConnectorContext):
        # Default is training mode.
        self._is_training = True

    def in_training(self):
        """Switch this connector into training mode."""
        self._is_training = True

    def in_eval(self):
        """Switch this connector into inference (evaluation) mode."""
        self._is_training = False

    def __str__(self, indentation: int = 0):
        return " " * indentation + self.__class__.__name__

    def to_state(self) -> Tuple[str, Any]:
        """Serialize a connector into a JSON serializable Tuple.

        to_state is required, so that all Connectors are serializable.

        Returns:
            A tuple of connector's name and its serialized states.
            String should match the name used to register the connector,
            while state can be any single data structure that contains the
            serialized state of the connector. If a connector is stateless,
            state can simply be None.

        Raises:
            NotImplementedError: Always; must be implemented by subclasses.
        """
        # Must be implemented by each connector.
        # Bug fix: `raise` the error; previously the NotImplementedError
        # *class* was returned, which silently looked like a valid state.
        raise NotImplementedError

    @staticmethod
    def from_state(self, ctx: ConnectorContext, params: Any) -> "Connector":
        """De-serialize a JSON params back into a Connector.

        from_state is required, so that all Connectors are serializable.

        Args:
            ctx: Context for constructing this connector.
            params: Serialized states of the connector to be recovered.

        Returns:
            De-serialized connector.

        Raises:
            NotImplementedError: Always; must be implemented by subclasses.
        """
        # NOTE(review): the stray `self` arg on this @staticmethod is kept
        # as-is for backward compatibility with existing overrides/callers.
        # Must be implemented by each connector.
        # Bug fix: raise instead of returning the exception class.
        raise NotImplementedError
138
+
139
+
140
@OldAPIStack
class AgentConnector(Connector):
    """Connector connecting user environments to RLlib policies.

    An agent connector transforms a list of agent data in
    AgentConnectorDataType format into a new list in the same
    AgentConnectorDataType format. The input API is designed so agent
    connectors can have access to all the agents assigned to a particular
    policy.

    AgentConnectorDataTypes can be used to specify arbitrary type of env data,

    Example:

        Represent a list of agent data from one env step() call.

        .. testcode::

            import numpy as np
            ac = AgentConnectorDataType(
                env_id="env_1",
                agent_id=None,
                data={
                    "agent_1": np.array([1, 2, 3]),
                    "agent_2": np.array([4, 5, 6]),
                }
            )

        Or a single agent data ready to be preprocessed.

        .. testcode::

            ac = AgentConnectorDataType(
                env_id="env_1",
                agent_id="agent_1",
                data=np.array([1, 2, 3]),
            )

        We can also adapt a simple stateless function into an agent connector
        by using register_lambda_agent_connector:

        .. testcode::

            import numpy as np
            from ray.rllib.connectors.agent.lambdas import (
                register_lambda_agent_connector
            )
            TimesTwoAgentConnector = register_lambda_agent_connector(
                "TimesTwoAgentConnector", lambda data: data * 2
            )

        # More complicated agent connectors can be implemented by extending
        # this AgentConnector class:

        class FrameSkippingAgentConnector(AgentConnector):
            def __init__(self, n):
                self._n = n
                self._frame_count = default_dict(str, default_dict(str, int))

            def reset(self, env_id: str):
                del self._frame_count[env_id]

            def __call__(
                self, ac_data: List[AgentConnectorDataType]
            ) -> List[AgentConnectorDataType]:
                ret = []
                for d in ac_data:
                    assert d.env_id and d.agent_id, "Skipping works per agent!"

                    count = self._frame_count[ac_data.env_id][ac_data.agent_id]
                    self._frame_count[ac_data.env_id][ac_data.agent_id] = (
                        count + 1
                    )

                    if count % self._n == 0:
                        ret.append(d)
                return ret

    As shown, an agent connector may choose to emit an empty list to stop
    input observations from being further processed.
    """

    def reset(self, env_id: str):
        """Reset connector state for a specific environment.

        For example, at the end of an episode. Stateless connectors may simply
        rely on this default no-op.

        Args:
            env_id: ID of a user environment. Required.
        """
        pass

    def on_policy_output(self, output: ActionConnectorDataType):
        """Callback on agent connector of policy output.

        This hook is useful for certain connectors, for example RNN state
        buffering, where the agent connector needs to be aware of the output
        of a policy forward pass. Default is a no-op.

        Args:
            output: Env and agent IDs, plus data output from policy forward
                pass.
        """
        pass

    def __call__(
        self, acd_list: List[AgentConnectorDataType]
    ) -> List[AgentConnectorDataType]:
        """Transform a list of data items from env before they reach policy.

        Args:
            acd_list: List of env and agent IDs, plus arbitrary data items
                from an environment or upstream agent connectors.

        Returns:
            A list of transformed data items in AgentConnectorDataType format.
            The shape of a returned list does not have to match that of the
            input list: a connector may emit several outputs per input item
            (e.g. multi-agent obs -> multiple single agent obs) or drop items
            entirely (e.g. frame skipping).
        """
        assert isinstance(
            acd_list, (list, tuple)
        ), "Input to agent connectors are list of AgentConnectorDataType."
        # Default behavior: map `transform` over each item one by one.
        return [self.transform(item) for item in acd_list]

    def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
        """Transform a single agent connector data item.

        Subclasses that keep the default `__call__` implement this instead.

        Args:
            ac_data: Env and agent IDs, plus arbitrary data item from a single
                agent of an environment.

        Returns:
            A transformed piece of agent connector data.
        """
        raise NotImplementedError
278
+
279
+
280
@OldAPIStack
class ActionConnector(Connector):
    """Action connector connects policy outputs including actions,
    to user environments.

    An action connector transforms a single piece of policy output in
    ActionConnectorDataType format, which is basically PolicyOutputType plus
    env and agent IDs.

    Any functions that operate directly on PolicyOutputType can be easily
    adapted into an ActionConnector by using register_lambda_action_connector.

    Example:

    .. testcode::

        from ray.rllib.connectors.action.lambdas import (
            register_lambda_action_connector
        )
        ZeroActionConnector = register_lambda_action_connector(
            "ZeroActionsConnector",
            lambda actions, states, fetches: (
                np.zeros_like(actions), states, fetches
            )
        )

    More complicated action connectors can also be implemented by sub-classing
    this ActionConnector class.
    """

    def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
        """Transform policy output before they are sent to a user environment.

        Simply delegates to `transform`; subclasses should override that
        method rather than `__call__` itself.

        Args:
            ac_data: Env and agent IDs, plus policy output.

        Returns:
            The processed action connector data.
        """
        return self.transform(ac_data)

    def transform(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
        """Implementation of the actual transform.

        Users should override transform instead of __call__ directly.

        Args:
            ac_data: Env and agent IDs, plus policy output.

        Returns:
            The processed action connector data.
        """
        raise NotImplementedError
333
+
334
+
335
@OldAPIStack
class ConnectorPipeline(abc.ABC):
    """Utility class for quick manipulation of a connector pipeline."""

    def __init__(self, ctx: ConnectorContext, connectors: List[Connector]):
        self.connectors = connectors

    def in_training(self):
        """Switch all connectors in this pipeline into training mode."""
        for c in self.connectors:
            c.in_training()

    def in_eval(self):
        """Switch all connectors in this pipeline into inference mode."""
        for c in self.connectors:
            c.in_eval()

    def _find_index(self, name: str) -> int:
        """Return index of the first connector whose class name is `name`.

        Returns -1 if no such connector exists.
        """
        for i, c in enumerate(self.connectors):
            if c.__class__.__name__ == name:
                return i
        return -1

    def remove(self, name: str):
        """Remove a connector by <name>

        Args:
            name: name of the connector to be removed.
        """
        idx = self._find_index(name)
        if idx >= 0:
            del self.connectors[idx]
            logger.info(f"Removed connector {name} from {self.__class__.__name__}.")
        else:
            logger.warning(f"Trying to remove a non-existent connector {name}.")

    def insert_before(self, name: str, connector: Connector):
        """Insert a new connector before connector <name>

        Args:
            name: name of the connector before which a new connector
                will get inserted.
            connector: a new connector to be inserted.

        Raises:
            ValueError: If no connector with class name `name` exists.
        """
        # Bug fix: the previous `for idx, c in enumerate(...)` search left
        # `idx` at the last index when `name` was not found, so the
        # `idx < 0` guard never fired and the connector was silently
        # inserted at the wrong position.
        idx = self._find_index(name)
        if idx < 0:
            raise ValueError(f"Can not find connector {name}")
        self.connectors.insert(idx, connector)

        logger.info(
            f"Inserted {connector.__class__.__name__} before {name} "
            f"to {self.__class__.__name__}."
        )

    def insert_after(self, name: str, connector: Connector):
        """Insert a new connector after connector <name>

        Args:
            name: name of the connector after which a new connector
                will get inserted.
            connector: a new connector to be inserted.

        Raises:
            ValueError: If no connector with class name `name` exists.
        """
        # Same not-found fix as in `insert_before` (see comment there).
        idx = self._find_index(name)
        if idx < 0:
            raise ValueError(f"Can not find connector {name}")
        self.connectors.insert(idx + 1, connector)

        logger.info(
            f"Inserted {connector.__class__.__name__} after {name} "
            f"to {self.__class__.__name__}."
        )

    def prepend(self, connector: Connector):
        """Insert a new connector at the beginning of a connector pipeline.

        Args:
            connector: a new connector to be prepended.
        """
        self.connectors.insert(0, connector)

        logger.info(
            f"Added {connector.__class__.__name__} to the beginning of "
            f"{self.__class__.__name__}."
        )

    def append(self, connector: Connector):
        """Append a new connector at the end of a connector pipeline.

        Args:
            connector: a new connector to be appended.
        """
        self.connectors.append(connector)

        logger.info(
            f"Added {connector.__class__.__name__} to the end of "
            f"{self.__class__.__name__}."
        )

    def __str__(self, indentation: int = 0):
        return "\n".join(
            [" " * indentation + self.__class__.__name__]
            + [c.__str__(indentation + 4) for c in self.connectors]
        )

    def __getitem__(self, key: Union[str, int, type]):
        """Returns a list of connectors that fit 'key'.

        If key is a number n, we return a list with the nth element of this
        pipeline. If key is a Connector class or a string matching the class
        name of a Connector class, we return a list of all connectors in this
        pipeline matching the specified class.

        Args:
            key: The key to index by

        Returns: The Connector at index `key`.
        """
        # In case key is a class
        if not isinstance(key, str):
            if isinstance(key, slice):
                raise NotImplementedError(
                    "Slicing of ConnectorPipeline is currently not supported."
                )
            elif isinstance(key, int):
                return [self.connectors[key]]
            elif isinstance(key, type):
                results = []
                for c in self.connectors:
                    if issubclass(c.__class__, key):
                        results.append(c)
                return results
            else:
                raise NotImplementedError(
                    "Indexing by {} is currently not supported.".format(type(key))
                )

        results = []
        for c in self.connectors:
            if c.__class__.__name__ == key:
                results.append(c)

        return results
.venv/lib/python3.11/site-packages/ray/rllib/connectors/connector_pipeline_v2.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, Collection, Dict, List, Optional, Tuple, Type, Union
3
+
4
+ import gymnasium as gym
5
+
6
+ from ray.rllib.connectors.connector_v2 import ConnectorV2
7
+ from ray.rllib.core.rl_module.rl_module import RLModule
8
+ from ray.rllib.utils.annotations import override
9
+ from ray.rllib.utils.checkpoints import Checkpointable
10
+ from ray.rllib.utils.metrics import TIMERS, CONNECTOR_TIMERS
11
+ from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
12
+ from ray.rllib.utils.typing import EpisodeType, StateDict
13
+ from ray.util.annotations import PublicAPI
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ @PublicAPI(stability="alpha")
19
+ class ConnectorPipelineV2(ConnectorV2):
20
+ """Utility class for quick manipulation of a connector pipeline."""
21
+
22
+ @override(ConnectorV2)
23
+ def recompute_output_observation_space(
24
+ self,
25
+ input_observation_space: gym.Space,
26
+ input_action_space: gym.Space,
27
+ ) -> gym.Space:
28
+ self._fix_spaces(input_observation_space, input_action_space)
29
+ return self.observation_space
30
+
31
+ @override(ConnectorV2)
32
+ def recompute_output_action_space(
33
+ self,
34
+ input_observation_space: gym.Space,
35
+ input_action_space: gym.Space,
36
+ ) -> gym.Space:
37
+ self._fix_spaces(input_observation_space, input_action_space)
38
+ return self.action_space
39
+
40
+ def __init__(
41
+ self,
42
+ input_observation_space: Optional[gym.Space] = None,
43
+ input_action_space: Optional[gym.Space] = None,
44
+ *,
45
+ connectors: Optional[List[ConnectorV2]] = None,
46
+ **kwargs,
47
+ ):
48
+ """Initializes a ConnectorPipelineV2 instance.
49
+
50
+ Args:
51
+ input_observation_space: The (optional) input observation space for this
52
+ connector piece. This is the space coming from a previous connector
53
+ piece in the (env-to-module or learner) pipeline or is directly
54
+ defined within the gym.Env.
55
+ input_action_space: The (optional) input action space for this connector
56
+ piece. This is the space coming from a previous connector piece in the
57
+ (module-to-env) pipeline or is directly defined within the gym.Env.
58
+ connectors: A list of individual ConnectorV2 pieces to be added to this
59
+ pipeline during construction. Note that you can always add (or remove)
60
+ more ConnectorV2 pieces later on the fly.
61
+ """
62
+ self.connectors = []
63
+
64
+ for conn in connectors:
65
+ # If we have a `ConnectorV2` instance just append.
66
+ if isinstance(conn, ConnectorV2):
67
+ self.connectors.append(conn)
68
+ # If, we have a class with `args` and `kwargs`, build the instance.
69
+ # Note that this way of constructing a pipeline should only be
70
+ # used internally when restoring the pipeline state from a
71
+ # checkpoint.
72
+ elif isinstance(conn, tuple) and len(conn) == 3:
73
+ self.connectors.append(conn[0](*conn[1], **conn[2]))
74
+
75
+ super().__init__(input_observation_space, input_action_space, **kwargs)
76
+
77
+ def __len__(self):
78
+ return len(self.connectors)
79
+
80
+ @override(ConnectorV2)
81
+ def __call__(
82
+ self,
83
+ *,
84
+ rl_module: RLModule,
85
+ batch: Dict[str, Any],
86
+ episodes: List[EpisodeType],
87
+ explore: Optional[bool] = None,
88
+ shared_data: Optional[dict] = None,
89
+ metrics: Optional[MetricsLogger] = None,
90
+ **kwargs,
91
+ ) -> Any:
92
+ """In a pipeline, we simply call each of our connector pieces after each other.
93
+
94
+ Each connector piece receives as input the output of the previous connector
95
+ piece in the pipeline.
96
+ """
97
+ shared_data = shared_data if shared_data is not None else {}
98
+ # Loop through connector pieces and call each one with the output of the
99
+ # previous one. Thereby, time each connector piece's call.
100
+ for connector in self.connectors:
101
+ # TODO (sven): Add MetricsLogger to non-Learner components that have a
102
+ # LearnerConnector pipeline.
103
+ stats = None
104
+ if metrics:
105
+ stats = metrics.log_time(
106
+ kwargs.get("metrics_prefix_key", ())
107
+ + (TIMERS, CONNECTOR_TIMERS, connector.__class__.__name__)
108
+ )
109
+ stats.__enter__()
110
+
111
+ batch = connector(
112
+ rl_module=rl_module,
113
+ batch=batch,
114
+ episodes=episodes,
115
+ explore=explore,
116
+ shared_data=shared_data,
117
+ metrics=metrics,
118
+ # Deprecated arg.
119
+ data=batch,
120
+ **kwargs,
121
+ )
122
+
123
+ if metrics:
124
+ stats.__exit__(None, None, None)
125
+
126
+ if not isinstance(batch, dict):
127
+ raise ValueError(
128
+ f"`data` returned by ConnectorV2 {connector} must be a dict! "
129
+ f"You returned {batch}. Check your (custom) connectors' "
130
+ f"`__call__()` method's return value and make sure you return "
131
+ f"the `data` arg passed in (either altered or unchanged)."
132
+ )
133
+
134
+ return batch
135
+
136
+ def remove(self, name_or_class: Union[str, Type]):
137
+ """Remove a single connector piece in this pipeline by its name or class.
138
+
139
+ Args:
140
+ name: The name of the connector piece to be removed from the pipeline.
141
+ """
142
+ idx = -1
143
+ for i, c in enumerate(self.connectors):
144
+ if c.__class__.__name__ == name_or_class:
145
+ idx = i
146
+ break
147
+ if idx >= 0:
148
+ del self.connectors[idx]
149
+ self._fix_spaces(self.input_observation_space, self.input_action_space)
150
+ logger.info(
151
+ f"Removed connector {name_or_class} from {self.__class__.__name__}."
152
+ )
153
+ else:
154
+ logger.warning(
155
+ f"Trying to remove a non-existent connector {name_or_class}."
156
+ )
157
+
158
+ def insert_before(
159
+ self,
160
+ name_or_class: Union[str, type],
161
+ connector: ConnectorV2,
162
+ ) -> ConnectorV2:
163
+ """Insert a new connector piece before an existing piece (by name or class).
164
+
165
+ Args:
166
+ name_or_class: Name or class of the connector piece before which `connector`
167
+ will get inserted.
168
+ connector: The new connector piece to be inserted.
169
+
170
+ Returns:
171
+ The ConnectorV2 before which `connector` has been inserted.
172
+ """
173
+ idx = -1
174
+ for idx, c in enumerate(self.connectors):
175
+ if (
176
+ isinstance(name_or_class, str) and c.__class__.__name__ == name_or_class
177
+ ) or (isinstance(name_or_class, type) and c.__class__ is name_or_class):
178
+ break
179
+ if idx < 0:
180
+ raise ValueError(
181
+ f"Can not find connector with name or type '{name_or_class}'!"
182
+ )
183
+ next_connector = self.connectors[idx]
184
+
185
+ self.connectors.insert(idx, connector)
186
+ self._fix_spaces(self.input_observation_space, self.input_action_space)
187
+
188
+ logger.info(
189
+ f"Inserted {connector.__class__.__name__} before {name_or_class} "
190
+ f"to {self.__class__.__name__}."
191
+ )
192
+ return next_connector
193
+
194
+ def insert_after(
195
+ self,
196
+ name_or_class: Union[str, Type],
197
+ connector: ConnectorV2,
198
+ ) -> ConnectorV2:
199
+ """Insert a new connector piece after an existing piece (by name or class).
200
+
201
+ Args:
202
+ name_or_class: Name or class of the connector piece after which `connector`
203
+ will get inserted.
204
+ connector: The new connector piece to be inserted.
205
+
206
+ Returns:
207
+ The ConnectorV2 after which `connector` has been inserted.
208
+ """
209
+ idx = -1
210
+ for idx, c in enumerate(self.connectors):
211
+ if (
212
+ isinstance(name_or_class, str) and c.__class__.__name__ == name_or_class
213
+ ) or (isinstance(name_or_class, type) and c.__class__ is name_or_class):
214
+ break
215
+ if idx < 0:
216
+ raise ValueError(
217
+ f"Can not find connector with name or type '{name_or_class}'!"
218
+ )
219
+ prev_connector = self.connectors[idx]
220
+
221
+ self.connectors.insert(idx + 1, connector)
222
+ self._fix_spaces(self.input_observation_space, self.input_action_space)
223
+
224
+ logger.info(
225
+ f"Inserted {connector.__class__.__name__} after {name_or_class} "
226
+ f"to {self.__class__.__name__}."
227
+ )
228
+
229
+ return prev_connector
230
+
231
+ def prepend(self, connector: ConnectorV2) -> None:
232
+ """Prepend a new connector at the beginning of a connector pipeline.
233
+
234
+ Args:
235
+ connector: The new connector piece to be prepended to this pipeline.
236
+ """
237
+ self.connectors.insert(0, connector)
238
+ self._fix_spaces(self.input_observation_space, self.input_action_space)
239
+
240
+ logger.info(
241
+ f"Added {connector.__class__.__name__} to the beginning of "
242
+ f"{self.__class__.__name__}."
243
+ )
244
+
245
+ def append(self, connector: ConnectorV2) -> None:
246
+ """Append a new connector at the end of a connector pipeline.
247
+
248
+ Args:
249
+ connector: The new connector piece to be appended to this pipeline.
250
+ """
251
+ self.connectors.append(connector)
252
+ self._fix_spaces(self.input_observation_space, self.input_action_space)
253
+
254
+ logger.info(
255
+ f"Added {connector.__class__.__name__} to the end of "
256
+ f"{self.__class__.__name__}."
257
+ )
258
+
259
+ @override(ConnectorV2)
260
+ def get_state(
261
+ self,
262
+ components: Optional[Union[str, Collection[str]]] = None,
263
+ *,
264
+ not_components: Optional[Union[str, Collection[str]]] = None,
265
+ **kwargs,
266
+ ) -> StateDict:
267
+ state = {}
268
+ for conn in self.connectors:
269
+ conn_name = type(conn).__name__
270
+ if self._check_component(conn_name, components, not_components):
271
+ state[conn_name] = conn.get_state(
272
+ components=self._get_subcomponents(conn_name, components),
273
+ not_components=self._get_subcomponents(conn_name, not_components),
274
+ **kwargs,
275
+ )
276
+ return state
277
+
278
+ @override(ConnectorV2)
279
+ def set_state(self, state: Dict[str, Any]) -> None:
280
+ for conn in self.connectors:
281
+ conn_name = type(conn).__name__
282
+ if conn_name in state:
283
+ conn.set_state(state[conn_name])
284
+
285
+ @override(Checkpointable)
286
+ def get_checkpointable_components(self) -> List[Tuple[str, "Checkpointable"]]:
287
+ return [(type(conn).__name__, conn) for conn in self.connectors]
288
+
289
+ # Note that we don't have to override Checkpointable.get_ctor_args_and_kwargs and
290
+ # don't have to return the `connectors` c'tor kwarg from there. This is b/c all
291
+ # connector pieces in this pipeline are themselves Checkpointable components,
292
+ # so they will be properly written into this pipeline's checkpoint.
293
+ @override(Checkpointable)
294
+ def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
295
+ return (
296
+ (self.input_observation_space, self.input_action_space), # *args
297
+ {
298
+ "connectors": [
299
+ (type(conn), *conn.get_ctor_args_and_kwargs())
300
+ for conn in self.connectors
301
+ ]
302
+ },
303
+ )
304
+
305
+ @override(ConnectorV2)
306
+ def reset_state(self) -> None:
307
+ for conn in self.connectors:
308
+ conn.reset_state()
309
+
310
+ @override(ConnectorV2)
311
+ def merge_states(self, states: List[Dict[str, Any]]) -> Dict[str, Any]:
312
+ merged_states = {}
313
+ if not states:
314
+ return merged_states
315
+ for i, (key, item) in enumerate(states[0].items()):
316
+ state_list = [state[key] for state in states]
317
+ conn = self.connectors[i]
318
+ merged_states[key] = conn.merge_states(state_list)
319
+ return merged_states
320
+
321
+ def __repr__(self, indentation: int = 0):
322
+ return "\n".join(
323
+ [" " * indentation + self.__class__.__name__]
324
+ + [c.__str__(indentation + 4) for c in self.connectors]
325
+ )
326
+
327
+ def __getitem__(
328
+ self,
329
+ key: Union[str, int, Type],
330
+ ) -> Union[ConnectorV2, List[ConnectorV2]]:
331
+ """Returns a single ConnectorV2 or list of ConnectorV2s that fit `key`.
332
+
333
+ If key is an int, we return a single ConnectorV2 at that index in this pipeline.
334
+ If key is a ConnectorV2 type or a string matching the class name of a
335
+ ConnectorV2 in this pipeline, we return a list of all ConnectorV2s in this
336
+ pipeline matching the specified class.
337
+
338
+ Args:
339
+ key: The key to find or to index by.
340
+
341
+ Returns:
342
+ A single ConnectorV2 or a list of ConnectorV2s matching `key`.
343
+ """
344
+ # Key is an int -> Index into pipeline and return.
345
+ if isinstance(key, int):
346
+ return self.connectors[key]
347
+ # Key is a class.
348
+ elif isinstance(key, type):
349
+ results = []
350
+ for c in self.connectors:
351
+ if issubclass(c.__class__, key):
352
+ results.append(c)
353
+ return results
354
+ # Key is a string -> Find connector(s) by name.
355
+ elif isinstance(key, str):
356
+ results = []
357
+ for c in self.connectors:
358
+ if c.name == key:
359
+ results.append(c)
360
+ return results
361
+ # Slicing not supported (yet).
362
+ elif isinstance(key, slice):
363
+ raise NotImplementedError(
364
+ "Slicing of ConnectorPipelineV2 is currently not supported!"
365
+ )
366
+ else:
367
+ raise NotImplementedError(
368
+ f"Indexing ConnectorPipelineV2 by {type(key)} is currently not "
369
+ f"supported!"
370
+ )
371
+
372
+ @property
373
+ def observation_space(self):
374
+ if len(self) > 0:
375
+ return self.connectors[-1].observation_space
376
+ return self._observation_space
377
+
378
+ @property
379
+ def action_space(self):
380
+ if len(self) > 0:
381
+ return self.connectors[-1].action_space
382
+ return self._action_space
383
+
384
+ def _fix_spaces(self, input_observation_space, input_action_space):
385
+ if len(self) > 0:
386
+ # Fix each connector's input_observation- and input_action space in
387
+ # the pipeline.
388
+ obs_space = input_observation_space
389
+ act_space = input_action_space
390
+ for con in self.connectors:
391
+ con.input_action_space = act_space
392
+ con.input_observation_space = obs_space
393
+ obs_space = con.observation_space
394
+ act_space = con.action_space
.venv/lib/python3.11/site-packages/ray/rllib/connectors/connector_v2.py ADDED
@@ -0,0 +1,1017 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from collections import defaultdict
3
+ import inspect
4
+ from typing import (
5
+ Any,
6
+ Callable,
7
+ Collection,
8
+ Dict,
9
+ Iterator,
10
+ List,
11
+ Optional,
12
+ Tuple,
13
+ Union,
14
+ )
15
+
16
+ import gymnasium as gym
17
+ import tree
18
+
19
+ from ray.rllib.core.rl_module.rl_module import RLModule
20
+ from ray.rllib.env.single_agent_episode import SingleAgentEpisode
21
+ from ray.rllib.utils import force_list
22
+ from ray.rllib.utils.annotations import override, OverrideToImplementCustomLogic
23
+ from ray.rllib.utils.checkpoints import Checkpointable
24
+ from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
25
+ from ray.rllib.utils.spaces.space_utils import BatchedNdArray
26
+ from ray.rllib.utils.typing import AgentID, EpisodeType, ModuleID, StateDict
27
+ from ray.util.annotations import PublicAPI
28
+
29
+
30
+ @PublicAPI(stability="alpha")
31
+ class ConnectorV2(Checkpointable, abc.ABC):
32
+ """Base class defining the API for an individual "connector piece".
33
+
34
+ A ConnectorV2 ("connector piece") is usually part of a whole series of connector
35
+ pieces within a so-called connector pipeline, which in itself also abides to this
36
+ very API.
37
+ For example, you might have a connector pipeline consisting of two connector pieces,
38
+ A and B, both instances of subclasses of ConnectorV2 and each one performing a
39
+ particular transformation on their input data. The resulting connector pipeline
40
+ (A->B) itself also abides to this very ConnectorV2 API and could thus be part of yet
41
+ another, higher-level connector pipeline, e.g. (A->B)->C->D.
42
+
43
+ Any ConnectorV2 instance (individual pieces or several connector pieces in a
44
+ pipeline) is a callable and users should override the `__call__()` method.
45
+ When called, they take the outputs of a previous connector piece (or an empty dict
46
+ if there are no previous pieces) and all the data collected thus far in the
47
+ ongoing episode(s) (only applies to connectors used in EnvRunners) or retrieved
48
+ from a replay buffer or from an environment sampling step (only applies to
49
+ connectors used in Learner pipelines). From this input data, a ConnectorV2 then
50
+ performs a transformation step.
51
+
52
+ There are 3 types of pipelines any ConnectorV2 piece can belong to:
53
+ 1) EnvToModulePipeline: The connector transforms environment data before it gets to
54
+ the RLModule. This type of pipeline is used by an EnvRunner for transforming
55
+ env output data into RLModule readable data (for the next RLModule forward pass).
56
+ For example, such a pipeline would include observation postprocessors, -filters,
57
+ or any RNN preparation code related to time-sequences and zero-padding.
58
+ 2) ModuleToEnvPipeline: This type of pipeline is used by an
59
+ EnvRunner to transform RLModule output data to env readable actions (for the next
60
+ `env.step()` call). For example, in case the RLModule only outputs action
61
+ distribution parameters (but not actual actions), the ModuleToEnvPipeline would
62
+ take care of sampling the actions to be sent back to the end from the
63
+ resulting distribution (made deterministic if exploration is off).
64
+ 3) LearnerConnectorPipeline: This connector pipeline type transforms data coming
65
+ from an `EnvRunner.sample()` call or a replay buffer and will then be sent into the
66
+ RLModule's `forward_train()` method in order to compute loss function inputs.
67
+ This type of pipeline is used by a Learner worker to transform raw training data
68
+ (a batch or a list of episodes) to RLModule readable training data (for the next
69
+ RLModule `forward_train()` call).
70
+
71
+ Some connectors might be stateful, for example for keeping track of observation
72
+ filtering stats (mean and stddev values). Any Algorithm, which uses connectors is
73
+ responsible for frequently synchronizing the states of all connectors and connector
74
+ pipelines between the EnvRunners (owning the env-to-module and module-to-env
75
+ pipelines) and the Learners (owning the Learner pipelines).
76
+ """
77
+
78
    def __init__(
        self,
        input_observation_space: Optional[gym.Space] = None,
        input_action_space: Optional[gym.Space] = None,
        **kwargs,
    ):
        """Initializes a ConnectorV2 instance.

        Args:
            input_observation_space: The (optional) input observation space for this
                connector piece. This is the space coming from a previous connector
                piece in the (env-to-module or learner) pipeline or is directly
                defined within the gym.Env.
            input_action_space: The (optional) input action space for this connector
                piece. This is the space coming from a previous connector piece in the
                (module-to-env) pipeline or is directly defined within the gym.Env.
            **kwargs: Forward API-compatibility kwargs.
        """
        # Cached output spaces (lazily recomputed) and raw input spaces.
        self._observation_space = None
        self._action_space = None
        self._input_observation_space = None
        self._input_action_space = None

        # Go through the property setters so any space-recomputation logic runs.
        self.input_action_space = input_action_space
        self.input_observation_space = input_observation_space

        # Store child's constructor args and kwargs for the default
        # `get_ctor_args_and_kwargs` implementation (to be able to restore from a
        # checkpoint).
        if self.__class__.__dict__.get("__init__") is not None:
            # The subclass defines its own `__init__` -> capture that caller's
            # arguments from its frame.
            # NOTE(review): `inspect.stack()[1]` assumes the subclass `__init__`
            # calls `super().__init__()` directly (no intermediate frame) —
            # TODO confirm for deeper call chains.
            caller_frame = inspect.stack()[1].frame
            arg_info = inspect.getargvalues(caller_frame)
            # Separate positional arguments and keyword arguments.
            caller_locals = (
                arg_info.locals
            )  # Dictionary of all local variables in the caller
            self._ctor_kwargs = {
                arg: caller_locals[arg] for arg in arg_info.args if arg != "self"
            }
        else:
            # No subclass `__init__` -> only the base-class ctor args matter.
            self._ctor_kwargs = {
                "input_observation_space": self.input_observation_space,
                "input_action_space": self.input_action_space,
            }
122
+
123
    @OverrideToImplementCustomLogic
    def recompute_output_observation_space(
        self,
        input_observation_space: gym.Space,
        input_action_space: gym.Space,
    ) -> gym.Space:
        """Re-computes a new (output) observation space based on the input spaces.

        This method should be overridden by users to make sure a ConnectorPipelineV2
        knows how the input spaces through its individual ConnectorV2 pieces are being
        transformed.

        .. testcode::

            from gymnasium.spaces import Box, Discrete
            import numpy as np

            from ray.rllib.connectors.connector_v2 import ConnectorV2
            from ray.rllib.utils.numpy import one_hot
            from ray.rllib.utils.test_utils import check

            class OneHotConnector(ConnectorV2):
                def recompute_output_observation_space(
                    self,
                    input_observation_space,
                    input_action_space,
                ):
                    return Box(0.0, 1.0, (input_observation_space.n,), np.float32)

                def __call__(
                    self,
                    *,
                    rl_module,
                    batch,
                    episodes,
                    explore=None,
                    shared_data=None,
                    metrics=None,
                    **kwargs,
                ):
                    assert "obs" in batch
                    batch["obs"] = one_hot(batch["obs"])
                    return batch

            connector = OneHotConnector(input_observation_space=Discrete(2))
            batch = {"obs": np.array([1, 0, 0], np.int32)}
            output = connector(rl_module=None, batch=batch, episodes=None)

            check(output, {"obs": np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 0.0]])})

        If this ConnectorV2 does not change the observation space in any way, leave
        this parent method implementation untouched.

        Args:
            input_observation_space: The input observation space (either coming from the
                environment if `self` is the first connector piece in the pipeline or
                from the previous connector piece in the pipeline).
            input_action_space: The input action space (either coming from the
                environment if `self` is the first connector piece in the pipeline or
                from the previous connector piece in the pipeline).

        Returns:
            The new observation space (after data has passed through this ConnectorV2
            piece).
        """
        # Default: this piece does not alter the observation space.
        return self.input_observation_space
189
+
190
    @OverrideToImplementCustomLogic
    def recompute_output_action_space(
        self,
        input_observation_space: gym.Space,
        input_action_space: gym.Space,
    ) -> gym.Space:
        """Re-computes a new (output) action space based on the input space.

        This method should be overridden by users to make sure a ConnectorPipelineV2
        knows how the input spaces through its individual ConnectorV2 pieces are being
        transformed.

        If this ConnectorV2 does not change the action space in any way, leave
        this parent method implementation untouched.

        Args:
            input_observation_space: The input observation space (either coming from the
                environment if `self` is the first connector piece in the pipeline or
                from the previous connector piece in the pipeline).
            input_action_space: The input action space (either coming from the
                environment if `self` is the first connector piece in the pipeline or
                from the previous connector piece in the pipeline).

        Returns:
            The new action space (after data has passed through this ConnectorV2
            piece).
        """
        # Default: this piece does not alter the action space.
        return self.input_action_space
218
+
219
    @abc.abstractmethod
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Dict[str, Any],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        shared_data: Optional[dict] = None,
        metrics: Optional[MetricsLogger] = None,
        **kwargs,
    ) -> Any:
        """Method for transforming an input `batch` into an output `batch`.

        Abstract; every ConnectorV2 subclass must implement this method.

        Args:
            rl_module: The RLModule object that the connector connects to or from.
            batch: The input data to be transformed by this connector. Transformations
                might either be done in-place or a new structure may be returned.
                Note that the information in `batch` will eventually either become the
                forward batch for the RLModule (env-to-module and learner connectors)
                or the input to the `env.step()` call (module-to-env connectors). Note
                that in the first case (`batch` is a forward batch for RLModule), the
                information in `batch` will be discarded after that RLModule forward
                pass. Any transformation of information (e.g. observation preprocessing)
                that you have only done inside `batch` will be lost, unless you have
                written it back into the corresponding `episodes` during the connector
                pass.
            episodes: The list of SingleAgentEpisode or MultiAgentEpisode objects,
                each corresponding to one slot in the vector env. Note that episodes
                can be read from (e.g. to place information into `batch`), but also
                written to. You should only write back (changed, transformed)
                information into the episodes, if you want these changes to be
                "permanent". For example if you sample from an environment, pick up
                observations from the episodes and place them into `batch`, then
                transform these observations, and would like to make these
                transformations permanent (note that `batch` gets discarded after the
                RLModule forward pass), then you have to write the transformed
                observations back into the episode to make sure you do not have to
                perform the same transformation again on the learner (or replay buffer)
                side. The Learner will hence work on the already changed episodes (and
                compile the train batch using the Learner connector).
            explore: Whether `explore` is currently on. Per convention, if True, the
                RLModule's `forward_exploration` method should be called, if False, the
                EnvRunner should call `forward_inference` instead.
            shared_data: Optional additional context data that needs to be exchanged
                between different ConnectorV2 pieces (in the same pipeline) or across
                ConnectorV2 pipelines (meaning between env-to-module and module-to-env).
            metrics: Optional MetricsLogger instance to log custom metrics to.
            kwargs: Forward API-compatibility kwargs.

        Returns:
            The transformed connector output.
        """
272
+
273
    @staticmethod
    def single_agent_episode_iterator(
        episodes: List[EpisodeType],
        agents_that_stepped_only: bool = True,
        zip_with_batch_column: Optional[Union[List[Any], Dict[Tuple, Any]]] = None,
    ) -> Iterator[SingleAgentEpisode]:
        """An iterator over a list of episodes yielding always SingleAgentEpisodes.

        In case items in the list are MultiAgentEpisodes, these are broken down
        into their individual agents' SingleAgentEpisodes and those are then yielded
        one after the other.

        Useful for connectors that operate on both single-agent and multi-agent
        episodes.

        Args:
            episodes: The list of SingleAgent- or MultiAgentEpisode objects.
            agents_that_stepped_only: If True (and multi-agent setup), will only place
                items of those agents into the batch that have just stepped in the
                actual MultiAgentEpisode (this is checked via
                `MultiAgentEpisode.get_agents_that_stepped()`). Note that this setting
                is ignored in single-agent setups b/c the agent steps at each timestep
                regardless.
            zip_with_batch_column: If provided, must be a list of batch items
                corresponding to the given `episodes` (single agent case) or a dict
                mapping (AgentID, ModuleID) tuples to lists of individual batch items
                corresponding to this agent/module combination. The iterator will then
                yield tuples of SingleAgentEpisode objects (1st item) along with the
                data item (2nd item) that this episode was responsible for generating
                originally.

        Yields:
            All SingleAgentEpisodes in the input list, whereby MultiAgentEpisodes will
            be broken down into their individual SingleAgentEpisode components.
        """
        # Per sub-key cursor into the per-episode/per-agent item lists of
        # `zip_with_batch_column` (advanced each time an item is consumed).
        list_indices = defaultdict(int)

        # Single-agent case (type of the first episode decides; assumes a
        # homogeneous list).
        if episodes and isinstance(episodes[0], SingleAgentEpisode):
            if zip_with_batch_column is not None:
                if len(zip_with_batch_column) != len(episodes):
                    raise ValueError(
                        "Invalid `zip_with_batch_column` data: Must have the same "
                        f"length as the list of episodes ({len(episodes)}), but has "
                        f"length {len(zip_with_batch_column)}!"
                    )
                # Simple case: Items are stored in lists directly under the column (str)
                # key.
                if isinstance(zip_with_batch_column, list):
                    for episode, data in zip(episodes, zip_with_batch_column):
                        yield episode, data
                # Normal single-agent case: Items are stored in dicts under the column
                # (str) key. These dicts map (eps_id,)-tuples to lists of individual
                # items.
                else:
                    for episode, (eps_id_tuple, data) in zip(
                        episodes,
                        zip_with_batch_column.items(),
                    ):
                        # Each episode must line up with its own (eps_id,) entry.
                        assert episode.id_ == eps_id_tuple[0]
                        d = data[list_indices[eps_id_tuple]]
                        list_indices[eps_id_tuple] += 1
                        yield episode, d
            else:
                for episode in episodes:
                    yield episode
            return

        # Multi-agent case.
        for episode in episodes:
            for agent_id in (
                episode.get_agents_that_stepped()
                if agents_that_stepped_only
                else episode.agent_ids
            ):
                sa_episode = episode.agent_episodes[agent_id]
                if zip_with_batch_column is not None:
                    # Items are keyed by (ma_episode_id, agent_id, module_id).
                    key = (
                        sa_episode.multi_agent_episode_id,
                        sa_episode.agent_id,
                        sa_episode.module_id,
                    )
                    if len(zip_with_batch_column[key]) <= list_indices[key]:
                        raise ValueError(
                            "Invalid `zip_with_batch_column` data: Must structurally "
                            "match the single-agent contents in the given list of "
                            "(multi-agent) episodes!"
                        )
                    d = zip_with_batch_column[key][list_indices[key]]
                    list_indices[key] += 1
                    yield sa_episode, d
                else:
                    yield sa_episode
367
+
368
    @staticmethod
    def add_batch_item(
        batch: Dict[str, Any],
        column: str,
        item_to_add: Any,
        single_agent_episode: Optional[SingleAgentEpisode] = None,
    ) -> None:
        """Adds a data item under `column` to the given `batch`.

        The `item_to_add` is stored in the `batch` in the following manner:
        1) If `single_agent_episode` is not provided (None), will store the item in a
        list directly under `column`:
        `column` -> [item, item, ...]
        2) If `single_agent_episode`'s `agent_id` and `module_id` properties are None
        (`single_agent_episode` is not part of a multi-agent episode), will append
        `item_to_add` to a list under a `(<episodeID>,)` key under `column`:
        `column` -> `(<episodeID>,)` -> [item, item, ...]
        3) If `single_agent_episode`'s `agent_id` and `module_id` are NOT None
        (`single_agent_episode` is part of a multi-agent episode), will append
        `item_to_add` to a list under a `(<episodeID>,<AgentID>,<ModuleID>)` key
        under `column`:
        `column` -> `(<episodeID>,<AgentID>,<ModuleID>)` -> [item, item, ...]

        See the these examples here for clarification of these three cases:

        .. testcode::

            from ray.rllib.connectors.connector_v2 import ConnectorV2
            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
            from ray.rllib.env.single_agent_episode import SingleAgentEpisode
            from ray.rllib.utils.test_utils import check

            # 1) Simple case (no episodes provided) -> Store data in a list directly
            # under `column`:
            batch = {}
            ConnectorV2.add_batch_item(batch, "test_col", item_to_add=5)
            ConnectorV2.add_batch_item(batch, "test_col", item_to_add=6)
            check(batch, {"test_col": [5, 6]})
            ConnectorV2.add_batch_item(batch, "test_col_2", item_to_add=-10)
            check(batch, {
                "test_col": [5, 6],
                "test_col_2": [-10],
            })

            # 2) Single-agent case (SingleAgentEpisode provided) -> Store data in a list
            # under the keys: `column` -> `(<eps_id>,)` -> [...]:
            batch = {}
            episode = SingleAgentEpisode(
                id_="SA-EPS0",
                observations=[0, 1, 2, 3],
                actions=[1, 2, 3],
                rewards=[1.0, 2.0, 3.0],
            )
            ConnectorV2.add_batch_item(batch, "test_col", 5, episode)
            ConnectorV2.add_batch_item(batch, "test_col", 6, episode)
            ConnectorV2.add_batch_item(batch, "test_col_2", -10, episode)
            check(batch, {
                "test_col": {("SA-EPS0",): [5, 6]},
                "test_col_2": {("SA-EPS0",): [-10]},
            })

            # 3) Multi-agent case (SingleAgentEpisode provided that has `agent_id` and
            # `module_id` information) -> Store data in a list under the keys:
            # `column` -> `(<episodeID>,<AgentID>,<ModuleID>)` -> [...]:
            batch = {}
            ma_episode = MultiAgentEpisode(
                id_="MA-EPS1",
                observations=[
                    {"ag0": 0, "ag1": 1}, {"ag0": 2, "ag1": 4}
                ],
                actions=[{"ag0": 0, "ag1": 1}],
                rewards=[{"ag0": -0.1, "ag1": -0.2}],
                # ag0 maps to mod0, ag1 maps to mod1, etc..
                agent_to_module_mapping_fn=lambda aid, eps: f"mod{aid[2:]}",
            )
            ConnectorV2.add_batch_item(
                batch,
                "test_col",
                item_to_add=5,
                single_agent_episode=ma_episode.agent_episodes["ag0"],
            )
            ConnectorV2.add_batch_item(
                batch,
                "test_col",
                item_to_add=6,
                single_agent_episode=ma_episode.agent_episodes["ag0"],
            )
            ConnectorV2.add_batch_item(
                batch,
                "test_col_2",
                item_to_add=10,
                single_agent_episode=ma_episode.agent_episodes["ag1"],
            )
            check(
                batch,
                {
                    "test_col": {("MA-EPS1", "ag0", "mod0"): [5, 6]},
                    "test_col_2": {("MA-EPS1", "ag1", "mod1"): [10]},
                },
            )

        Args:
            batch: The batch to store `item_to_add` in.
            column: The column name (str) within the `batch` to store `item_to_add`
                under.
            item_to_add: The data item to store in the batch.
            single_agent_episode: An optional SingleAgentEpisode.
                If provided and its `agent_id` and `module_id` properties are None,
                creates a further sub dictionary under `column`, mapping from
                `(<episodeID>,)` to a list of data items (to which `item_to_add` will
                be appended in this call).
                If provided and its `agent_id` and `module_id` properties are NOT None,
                creates a further sub dictionary under `column`, mapping from
                `(<episodeID>,<AgentID>,<ModuleID>)` to a list of data items (to which
                `item_to_add` will be appended in this call).
                If not provided, will append `item_to_add` to a list directly under
                `column`.

        Raises:
            ValueError: If `batch` is already module-major (maps ModuleIDs to
                column dicts), which this method cannot safely extend.
        """
        sub_key = None
        # SAEpisode is provided ...
        if single_agent_episode is not None:
            module_id = single_agent_episode.module_id
            # ... and has `module_id` AND that `module_id` is already a top-level key in
            # `batch` (`batch` is already in module-major form, mapping ModuleID to
            # columns mapping to data).
            if module_id is not None and module_id in batch:
                raise ValueError(
                    "Can't call `add_batch_item` on a `batch` that is already "
                    "module-major (meaning ModuleID is top-level with column names on "
                    "the level thereunder)! Make sure to only call `add_batch_items` "
                    "before the `AgentToModuleMapping` ConnectorV2 piece is applied."
                )

            # ... and has `agent_id` -> Use `single_agent_episode`'s agent ID and
            # module ID.
            elif single_agent_episode.agent_id is not None:
                sub_key = (
                    single_agent_episode.multi_agent_episode_id,
                    single_agent_episode.agent_id,
                    single_agent_episode.module_id,
                )
            # Otherwise, just use episode's ID.
            else:
                sub_key = (single_agent_episode.id_,)

        # Lazily create the column container: a plain list (no episode given) or
        # a dict of sub_key -> list (episode given).
        if column not in batch:
            batch[column] = [] if sub_key is None else {sub_key: []}
        if sub_key is not None:
            if sub_key not in batch[column]:
                batch[column][sub_key] = []
            batch[column][sub_key].append(item_to_add)
        else:
            batch[column].append(item_to_add)
521
+
522
+ @staticmethod
523
+ def add_n_batch_items(
524
+ batch: Dict[str, Any],
525
+ column: str,
526
+ items_to_add: Any,
527
+ num_items: int,
528
+ single_agent_episode: Optional[SingleAgentEpisode] = None,
529
+ ) -> None:
530
+ """Adds a list of items (or batched item) under `column` to the given `batch`.
531
+
532
+ If `items_to_add` is not a list, but an already batched struct (of np.ndarray
533
+ leafs), the `items_to_add` will be appended to possibly existing data under the
534
+ same `column` as-is. A subsequent `BatchIndividualItems` ConnectorV2 piece will
535
+ recognize this and batch the data properly into a single (batched) item.
536
+ This is much faster than first splitting up `items_to_add` and then adding each
537
+ item individually.
538
+
539
+ If `single_agent_episode` is provided and its `agent_id` and `module_id`
540
+ properties are None, creates a further sub dictionary under `column`, mapping
541
+ from `(<episodeID>,)` to a list of data items (to which `items_to_add` will
542
+ be appended in this call).
543
+ If `single_agent_episode` is provided and its `agent_id` and `module_id`
544
+ properties are NOT None, creates a further sub dictionary under `column`,
545
+ mapping from `(<episodeID>,,<AgentID>,<ModuleID>)` to a list of data items (to
546
+ which `items_to_add` will be appended in this call).
547
+ If `single_agent_episode` is not provided, will append `items_to_add` to a list
548
+ directly under `column`.
549
+
550
+ .. testcode::
551
+
552
+ import numpy as np
553
+
554
+ from ray.rllib.connectors.connector_v2 import ConnectorV2
555
+ from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
556
+ from ray.rllib.env.single_agent_episode import SingleAgentEpisode
557
+ from ray.rllib.utils.test_utils import check
558
+
559
+ # Simple case (no episodes provided) -> Store data in a list directly under
560
+ # `column`:
561
+ batch = {}
562
+ ConnectorV2.add_n_batch_items(
563
+ batch,
564
+ "test_col",
565
+ # List of (complex) structs.
566
+ [{"a": np.array(3), "b": 4}, {"a": np.array(5), "b": 6}],
567
+ num_items=2,
568
+ )
569
+ check(
570
+ batch["test_col"],
571
+ [{"a": np.array(3), "b": 4}, {"a": np.array(5), "b": 6}],
572
+ )
573
+ # In a new column (test_col_2), store some already batched items.
574
+ # This way, you may avoid having to disassemble an already batched item
575
+ # (e.g. a numpy array of shape (10, 2)) into its individual items (e.g.
576
+ # split the array into a list of len=10) and then adding these individually.
577
+ # The performance gains may be quite large when providing already batched
578
+ # items (such as numpy arrays with a batch dim):
579
+ ConnectorV2.add_n_batch_items(
580
+ batch,
581
+ "test_col_2",
582
+ # One (complex) already batched struct.
583
+ {"a": np.array([3, 5]), "b": np.array([4, 6])},
584
+ num_items=2,
585
+ )
586
+ # Add more already batched items (this time with a different batch size)
587
+ ConnectorV2.add_n_batch_items(
588
+ batch,
589
+ "test_col_2",
590
+ {"a": np.array([7, 7, 7]), "b": np.array([8, 8, 8])},
591
+ num_items=3, # <- in this case, this must be the batch size
592
+ )
593
+ check(
594
+ batch["test_col_2"],
595
+ [
596
+ {"a": np.array([3, 5]), "b": np.array([4, 6])},
597
+ {"a": np.array([7, 7, 7]), "b": np.array([8, 8, 8])},
598
+ ],
599
+ )
600
+
601
+ # Single-agent case (SingleAgentEpisode provided) -> Store data in a list
602
+ # under the keys: `column` -> `(<eps_id>,)`:
603
+ batch = {}
604
+ episode = SingleAgentEpisode(
605
+ id_="SA-EPS0",
606
+ observations=[0, 1, 2, 3],
607
+ actions=[1, 2, 3],
608
+ rewards=[1.0, 2.0, 3.0],
609
+ )
610
+ ConnectorV2.add_n_batch_items(
611
+ batch=batch,
612
+ column="test_col",
613
+ items_to_add=[5, 6, 7],
614
+ num_items=3,
615
+ single_agent_episode=episode,
616
+ )
617
+ check(batch, {
618
+ "test_col": {("SA-EPS0",): [5, 6, 7]},
619
+ })
620
+
621
+ # Multi-agent case (SingleAgentEpisode provided that has `agent_id` and
622
+ # `module_id` information) -> Store data in a list under the keys:
623
+ # `column` -> `(<episodeID>,<AgentID>,<ModuleID>)`:
624
+ batch = {}
625
+ ma_episode = MultiAgentEpisode(
626
+ id_="MA-EPS1",
627
+ observations=[
628
+ {"ag0": 0, "ag1": 1}, {"ag0": 2, "ag1": 4}
629
+ ],
630
+ actions=[{"ag0": 0, "ag1": 1}],
631
+ rewards=[{"ag0": -0.1, "ag1": -0.2}],
632
+ # ag0 maps to mod0, ag1 maps to mod1, etc..
633
+ agent_to_module_mapping_fn=lambda aid, eps: f"mod{aid[2:]}",
634
+ )
635
+ ConnectorV2.add_batch_item(
636
+ batch,
637
+ "test_col",
638
+ item_to_add=5,
639
+ single_agent_episode=ma_episode.agent_episodes["ag0"],
640
+ )
641
+ ConnectorV2.add_batch_item(
642
+ batch,
643
+ "test_col",
644
+ item_to_add=6,
645
+ single_agent_episode=ma_episode.agent_episodes["ag0"],
646
+ )
647
+ ConnectorV2.add_batch_item(
648
+ batch,
649
+ "test_col_2",
650
+ item_to_add=10,
651
+ single_agent_episode=ma_episode.agent_episodes["ag1"],
652
+ )
653
+ check(
654
+ batch,
655
+ {
656
+ "test_col": {("MA-EPS1", "ag0", "mod0"): [5, 6]},
657
+ "test_col_2": {("MA-EPS1", "ag1", "mod1"): [10]},
658
+ },
659
+ )
660
+
661
+ Args:
662
+ batch: The batch to store n `items_to_add` in.
663
+ column: The column name (str) within the `batch` to store `item_to_add`
664
+ under.
665
+ items_to_add: The list of data items to store in the batch OR an already
666
+ batched (possibly nested) struct. In the latter case, the `items_to_add`
667
+ will be appended to possibly existing data under the same `column`
668
+ as-is. A subsequent `BatchIndividualItems` ConnectorV2 piece will
669
+ recognize this and batch the data properly into a single (batched) item.
670
+ This is much faster than first splitting up `items_to_add` and then
671
+ adding each item individually.
672
+ num_items: The number of items in `items_to_add`. This arg is mostly for
673
+ asserting the correct usage of this method by checking, whether the
674
+ given data in `items_to_add` really has the right amount of individual
675
+ items.
676
+ single_agent_episode: An optional SingleAgentEpisode.
677
+ If provided and its `agent_id` and `module_id` properties are None,
678
+ creates a further sub dictionary under `column`, mapping from
679
+ `(<episodeID>,)` to a list of data items (to which `items_to_add` will
680
+ be appended in this call).
681
+ If provided and its `agent_id` and `module_id` properties are NOT None,
682
+ creates a further sub dictionary under `column`, mapping from
683
+ `(<episodeID>,,<AgentID>,<ModuleID>)` to a list of data items (to which
684
+ `items_to_add` will be appended in this call).
685
+ If not provided, will append `items_to_add` to a list directly under
686
+ `column`.
687
+ """
688
+ # Process n list items by calling `add_batch_item` on each of them individually.
689
+ if isinstance(items_to_add, list):
690
+ if len(items_to_add) != num_items:
691
+ raise ValueError(
692
+ f"Mismatch between `num_items` ({num_items}) and the length "
693
+ f"of the provided list ({len(items_to_add)}) in "
694
+ f"{ConnectorV2.__name__}.add_n_batch_items()!"
695
+ )
696
+ for item in items_to_add:
697
+ ConnectorV2.add_batch_item(
698
+ batch=batch,
699
+ column=column,
700
+ item_to_add=item,
701
+ single_agent_episode=single_agent_episode,
702
+ )
703
+ return
704
+
705
+ # Process a batched (possibly complex) struct.
706
+ # We could just unbatch the item (split it into a list) and then add each
707
+ # individual item to our `batch`. However, this comes with a heavy performance
708
+ # penalty. Instead, we tag the thus added array(s) here as "_has_batch_dim=True"
709
+ # and then know that when batching the entire list under the respective
710
+ # (eps_id, agent_id, module_id)-tuple key, we need to concatenate, not stack
711
+ # the items in there.
712
+ def _tag(s):
713
+ return BatchedNdArray(s)
714
+
715
+ ConnectorV2.add_batch_item(
716
+ batch=batch,
717
+ column=column,
718
+ # Convert given input into BatchedNdArray(s) such that the `batch` utility
719
+ # knows that it'll have to concat, not stack.
720
+ item_to_add=tree.map_structure(_tag, items_to_add),
721
+ single_agent_episode=single_agent_episode,
722
+ )
723
+
724
+ @staticmethod
725
+ def foreach_batch_item_change_in_place(
726
+ batch: Dict[str, Any],
727
+ column: Union[str, List[str], Tuple[str]],
728
+ func: Callable[
729
+ [Any, Optional[int], Optional[AgentID], Optional[ModuleID]], Any
730
+ ],
731
+ ) -> None:
732
+ """Runs the provided `func` on all items under one or more columns in the batch.
733
+
734
+ Use this method to conveniently loop through all items in a batch
735
+ and transform them in place.
736
+
737
+ `func` takes the following as arguments:
738
+ - The item itself. If column is a list of column names, this argument is a tuple
739
+ of items.
740
+ - The EpisodeID. This value might be None.
741
+ - The AgentID. This value might be None in the single-agent case.
742
+ - The ModuleID. This value might be None in the single-agent case.
743
+
744
+ The return value(s) of `func` are used to directly override the values in the
745
+ given `batch`.
746
+
747
+ Args:
748
+ batch: The batch to process in-place.
749
+ column: A single column name (str) or a list thereof. If a list is provided,
750
+ the first argument to `func` is a tuple of items. If a single
751
+ str is provided, the first argument to `func` is an individual
752
+ item.
753
+ func: The function to call on each item or tuple of item(s).
754
+
755
+ .. testcode::
756
+
757
+ from ray.rllib.connectors.connector_v2 import ConnectorV2
758
+ from ray.rllib.utils.test_utils import check
759
+
760
+ # Simple case: Batch items are in lists directly under their column names.
761
+ batch = {
762
+ "col1": [0, 1, 2, 3],
763
+ "col2": [0, -1, -2, -3],
764
+ }
765
+ # Increase all ints by 1.
766
+ ConnectorV2.foreach_batch_item_change_in_place(
767
+ batch=batch,
768
+ column="col1",
769
+ func=lambda item, *args: item + 1,
770
+ )
771
+ check(batch["col1"], [1, 2, 3, 4])
772
+
773
+ # Further increase all ints by 1 in col1 and flip sign in col2.
774
+ ConnectorV2.foreach_batch_item_change_in_place(
775
+ batch=batch,
776
+ column=["col1", "col2"],
777
+ func=(lambda items, *args: (items[0] + 1, -items[1])),
778
+ )
779
+ check(batch["col1"], [2, 3, 4, 5])
780
+ check(batch["col2"], [0, 1, 2, 3])
781
+
782
+ # Single-agent case: Batch items are in lists under (eps_id,)-keys in a dict
783
+ # under their column names.
784
+ batch = {
785
+ "col1": {
786
+ ("eps1",): [0, 1, 2, 3],
787
+ ("eps2",): [400, 500, 600],
788
+ },
789
+ }
790
+ # Increase all ints of eps1 by 1 and divide all ints of eps2 by 100.
791
+ ConnectorV2.foreach_batch_item_change_in_place(
792
+ batch=batch,
793
+ column="col1",
794
+ func=lambda item, eps_id, *args: (
795
+ item + 1 if eps_id == "eps1" else item / 100
796
+ ),
797
+ )
798
+ check(batch["col1"], {
799
+ ("eps1",): [1, 2, 3, 4],
800
+ ("eps2",): [4, 5, 6],
801
+ })
802
+
803
+ # Multi-agent case: Batch items are in lists under
804
+ # (eps_id, agent_id, module_id)-keys in a dict
805
+ # under their column names.
806
+ batch = {
807
+ "col1": {
808
+ ("eps1", "ag1", "mod1"): [1, 2, 3, 4],
809
+ ("eps2", "ag1", "mod2"): [400, 500, 600],
810
+ ("eps2", "ag2", "mod3"): [-1, -2, -3, -4, -5],
811
+ },
812
+ }
813
+ # Decrease all ints of "eps1" by 1, divide all ints of "mod2" by 100, and
814
+ # flip sign of all ints of "ag2".
815
+ ConnectorV2.foreach_batch_item_change_in_place(
816
+ batch=batch,
817
+ column="col1",
818
+ func=lambda item, eps_id, ag_id, mod_id: (
819
+ item - 1
820
+ if eps_id == "eps1"
821
+ else item / 100
822
+ if mod_id == "mod2"
823
+ else -item
824
+ ),
825
+ )
826
+ check(batch["col1"], {
827
+ ("eps1", "ag1", "mod1"): [0, 1, 2, 3],
828
+ ("eps2", "ag1", "mod2"): [4, 5, 6],
829
+ ("eps2", "ag2", "mod3"): [1, 2, 3, 4, 5],
830
+ })
831
+ """
832
+ data_to_process = [batch.get(c) for c in force_list(column)]
833
+ single_col = isinstance(column, str)
834
+ if any(d is None for d in data_to_process):
835
+ raise ValueError(
836
+ f"Invalid column name(s) ({column})! One or more not found in "
837
+ f"given batch. Found columns {list(batch.keys())}."
838
+ )
839
+
840
+ # Simple case: Data items are stored in a list directly under the column
841
+ # name(s).
842
+ if isinstance(data_to_process[0], list):
843
+ for list_pos, data_tuple in enumerate(zip(*data_to_process)):
844
+ results = func(
845
+ data_tuple[0] if single_col else data_tuple,
846
+ None, # episode_id
847
+ None, # agent_id
848
+ None, # module_id
849
+ )
850
+ # Tuple'ize results if single_col.
851
+ results = (results,) if single_col else results
852
+ for col_slot, result in enumerate(force_list(results)):
853
+ data_to_process[col_slot][list_pos] = result
854
+ # Single-agent/multi-agent cases.
855
+ else:
856
+ for key, d0_list in data_to_process[0].items():
857
+ # Multi-agent case: There is a dict mapping from a
858
+ # (eps id, AgentID, ModuleID)-tuples to lists of individual data items.
859
+ if len(key) == 3:
860
+ eps_id, agent_id, module_id = key
861
+ # Single-agent case: There is a dict mapping from a (eps_id,)-tuple
862
+ # to lists of individual data items.
863
+ # AgentID and ModuleID are both None.
864
+ else:
865
+ eps_id = key[0]
866
+ agent_id = module_id = None
867
+ other_lists = [d[key] for d in data_to_process[1:]]
868
+ for list_pos, data_tuple in enumerate(zip(d0_list, *other_lists)):
869
+ results = func(
870
+ data_tuple[0] if single_col else data_tuple,
871
+ eps_id,
872
+ agent_id,
873
+ module_id,
874
+ )
875
+ # Tuple'ize results if single_col.
876
+ results = (results,) if single_col else results
877
+ for col_slot, result in enumerate(results):
878
+ data_to_process[col_slot][key][list_pos] = result
879
+
880
+ @staticmethod
881
+ def switch_batch_from_column_to_module_ids(
882
+ batch: Dict[str, Dict[ModuleID, Any]]
883
+ ) -> Dict[ModuleID, Dict[str, Any]]:
884
+ """Switches the first two levels of a `col_name -> ModuleID -> data` type batch.
885
+
886
+ Assuming that the top level consists of column names as keys and the second
887
+ level (under these columns) consists of ModuleID keys, the resulting batch
888
+ will have these two reversed and thus map ModuleIDs to dicts mapping column
889
+ names to data items.
890
+
891
+ .. testcode::
892
+
893
+ from ray.rllib.utils.test_utils import check
894
+
895
+ batch = {
896
+ "obs": {"module_0": [1, 2, 3]},
897
+ "actions": {"module_0": [4, 5, 6], "module_1": [7]},
898
+ }
899
+ switched_batch = ConnectorV2.switch_batch_from_column_to_module_ids(batch)
900
+ check(
901
+ switched_batch,
902
+ {
903
+ "module_0": {"obs": [1, 2, 3], "actions": [4, 5, 6]},
904
+ "module_1": {"actions": [7]},
905
+ },
906
+ )
907
+
908
+ Args:
909
+ batch: The batch to switch from being column name based (then ModuleIDs)
910
+ to being ModuleID based (then column names).
911
+
912
+ Returns:
913
+ A new batch dict mapping ModuleIDs to dicts mapping column names (e.g.
914
+ "obs") to data.
915
+ """
916
+ module_data = defaultdict(dict)
917
+ for column, column_data in batch.items():
918
+ for module_id, data in column_data.items():
919
+ module_data[module_id][column] = data
920
+ return dict(module_data)
921
+
922
+ @override(Checkpointable)
923
+ def get_state(
924
+ self,
925
+ components: Optional[Union[str, Collection[str]]] = None,
926
+ *,
927
+ not_components: Optional[Union[str, Collection[str]]] = None,
928
+ **kwargs,
929
+ ) -> StateDict:
930
+ return {}
931
+
932
+ @override(Checkpointable)
933
+ def set_state(self, state: StateDict) -> None:
934
+ pass
935
+
936
+ @override(Checkpointable)
937
+ def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
938
+ return (
939
+ (), # *args
940
+ self._ctor_kwargs, # **kwargs
941
+ )
942
+
943
+ def reset_state(self) -> None:
944
+ """Resets the state of this ConnectorV2 to some initial value.
945
+
946
+ Note that this may NOT be the exact state that this ConnectorV2 was originally
947
+ constructed with.
948
+ """
949
+ return
950
+
951
+ def merge_states(self, states: List[Dict[str, Any]]) -> Dict[str, Any]:
952
+ """Computes a resulting state given self's state and a list of other states.
953
+
954
+ Algorithms should use this method for merging states between connectors
955
+ running on parallel EnvRunner workers. For example, to synchronize the connector
956
+ states of n remote workers and a local worker, one could:
957
+ - Gather all remote worker connector states in a list.
958
+ - Call `self.merge_states()` on the local worker passing it the states list.
959
+ - Broadcast the resulting local worker's connector state back to all remote
960
+ workers. After this, all workers (including the local one) hold a
961
+ merged/synchronized new connecto state.
962
+
963
+ Args:
964
+ states: The list of n other ConnectorV2 states to merge with self's state
965
+ into a single resulting state.
966
+
967
+ Returns:
968
+ The resulting state dict.
969
+ """
970
+ return {}
971
+
972
+ @property
973
+ def observation_space(self):
974
+ """Getter for our (output) observation space.
975
+
976
+ Logic: Use user provided space (if set via `observation_space` setter)
977
+ otherwise, use the same as the input space, assuming this connector piece
978
+ does not alter the space.
979
+ """
980
+ return self._observation_space
981
+
982
+ @property
983
+ def action_space(self):
984
+ """Getter for our (output) action space.
985
+
986
+ Logic: Use user provided space (if set via `action_space` setter)
987
+ otherwise, use the same as the input space, assuming this connector piece
988
+ does not alter the space.
989
+ """
990
+ return self._action_space
991
+
992
+ @property
993
+ def input_observation_space(self):
994
+ return self._input_observation_space
995
+
996
+ @input_observation_space.setter
997
+ def input_observation_space(self, value):
998
+ self._input_observation_space = value
999
+ if value is not None:
1000
+ self._observation_space = self.recompute_output_observation_space(
1001
+ value, self.input_action_space
1002
+ )
1003
+
1004
+ @property
1005
+ def input_action_space(self):
1006
+ return self._input_action_space
1007
+
1008
+ @input_action_space.setter
1009
+ def input_action_space(self, value):
1010
+ self._input_action_space = value
1011
+ if value is not None:
1012
+ self._action_space = self.recompute_output_action_space(
1013
+ self.input_observation_space, value
1014
+ )
1015
+
1016
+ def __str__(self, indentation: int = 0):
1017
+ return " " * indentation + self.__class__.__name__
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__init__.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
2
+ AddObservationsFromEpisodesToBatch,
3
+ )
4
+ from ray.rllib.connectors.common.add_states_from_episodes_to_batch import (
5
+ AddStatesFromEpisodesToBatch,
6
+ )
7
+ from ray.rllib.connectors.common.add_time_dim_to_batch_and_zero_pad import (
8
+ AddTimeDimToBatchAndZeroPad,
9
+ )
10
+ from ray.rllib.connectors.common.agent_to_module_mapping import AgentToModuleMapping
11
+ from ray.rllib.connectors.common.batch_individual_items import BatchIndividualItems
12
+ from ray.rllib.connectors.common.numpy_to_tensor import NumpyToTensor
13
+ from ray.rllib.connectors.env_to_module.env_to_module_pipeline import (
14
+ EnvToModulePipeline,
15
+ )
16
+ from ray.rllib.connectors.env_to_module.flatten_observations import (
17
+ FlattenObservations,
18
+ )
19
+ from ray.rllib.connectors.env_to_module.mean_std_filter import MeanStdFilter
20
+ from ray.rllib.connectors.env_to_module.prev_actions_prev_rewards import (
21
+ PrevActionsPrevRewards,
22
+ )
23
+ from ray.rllib.connectors.env_to_module.write_observations_to_episodes import (
24
+ WriteObservationsToEpisodes,
25
+ )
26
+
27
+
28
+ __all__ = [
29
+ "AddObservationsFromEpisodesToBatch",
30
+ "AddStatesFromEpisodesToBatch",
31
+ "AddTimeDimToBatchAndZeroPad",
32
+ "AgentToModuleMapping",
33
+ "BatchIndividualItems",
34
+ "EnvToModulePipeline",
35
+ "FlattenObservations",
36
+ "MeanStdFilter",
37
+ "NumpyToTensor",
38
+ "PrevActionsPrevRewards",
39
+ "WriteObservationsToEpisodes",
40
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__pycache__/mean_std_filter.cpython-311.pyc ADDED
Binary file (13.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__pycache__/prev_actions_prev_rewards.cpython-311.pyc ADDED
Binary file (7.42 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/env_to_module_pipeline.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ from ray.rllib.connectors.connector_pipeline_v2 import ConnectorPipelineV2
4
+ from ray.rllib.core.rl_module.rl_module import RLModule
5
+ from ray.rllib.utils.annotations import override
6
+ from ray.rllib.utils.metrics import (
7
+ ENV_TO_MODULE_SUM_EPISODES_LENGTH_IN,
8
+ ENV_TO_MODULE_SUM_EPISODES_LENGTH_OUT,
9
+ )
10
+ from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
11
+ from ray.rllib.utils.typing import EpisodeType
12
+ from ray.util.annotations import PublicAPI
13
+
14
+
15
+ @PublicAPI(stability="alpha")
16
+ class EnvToModulePipeline(ConnectorPipelineV2):
17
+ @override(ConnectorPipelineV2)
18
+ def __call__(
19
+ self,
20
+ *,
21
+ rl_module: RLModule,
22
+ batch: Optional[Dict[str, Any]] = None,
23
+ episodes: List[EpisodeType],
24
+ explore: bool,
25
+ shared_data: Optional[dict] = None,
26
+ metrics: Optional[MetricsLogger] = None,
27
+ **kwargs,
28
+ ):
29
+ # Log the sum of lengths of all episodes incoming.
30
+ if metrics:
31
+ metrics.log_value(
32
+ ENV_TO_MODULE_SUM_EPISODES_LENGTH_IN,
33
+ sum(map(len, episodes)),
34
+ )
35
+
36
+ # Make sure user does not necessarily send initial input into this pipeline.
37
+ # Might just be empty and to be populated from `episodes`.
38
+ ret = super().__call__(
39
+ rl_module=rl_module,
40
+ batch=batch if batch is not None else {},
41
+ episodes=episodes,
42
+ explore=explore,
43
+ shared_data=shared_data if shared_data is not None else {},
44
+ metrics=metrics,
45
+ **kwargs,
46
+ )
47
+
48
+ # Log the sum of lengths of all episodes outgoing.
49
+ if metrics:
50
+ metrics.log_value(
51
+ ENV_TO_MODULE_SUM_EPISODES_LENGTH_OUT,
52
+ sum(map(len, episodes)),
53
+ )
54
+
55
+ return ret
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/flatten_observations.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Collection, Dict, List, Optional
2
+
3
+ import gymnasium as gym
4
+ from gymnasium.spaces import Box
5
+ import numpy as np
6
+ import tree # pip install dm_tree
7
+
8
+ from ray.rllib.connectors.connector_v2 import ConnectorV2
9
+ from ray.rllib.core.rl_module.rl_module import RLModule
10
+ from ray.rllib.utils.annotations import override
11
+ from ray.rllib.utils.numpy import flatten_inputs_to_1d_tensor
12
+ from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
13
+ from ray.rllib.utils.typing import AgentID, EpisodeType
14
+ from ray.util.annotations import PublicAPI
15
+
16
+
17
+ @PublicAPI(stability="alpha")
18
+ class FlattenObservations(ConnectorV2):
19
+ """A connector piece that flattens all observation components into a 1D array.
20
+
21
+ - Should be used only in env-to-module pipelines.
22
+ - Works directly on the incoming episodes list and changes the last observation
23
+ in-place (write the flattened observation back into the episode).
24
+ - This connector does NOT alter the incoming batch (`data`) when called.
25
+ - This connector does NOT work in a `LearnerConnectorPipeline` because it requires
26
+ the incoming episodes to still be ongoing (in progress) as it only alters the
27
+ latest observation, not all observations in an episode.
28
+
29
+ .. testcode::
30
+
31
+ import gymnasium as gym
32
+ import numpy as np
33
+
34
+ from ray.rllib.connectors.env_to_module import FlattenObservations
35
+ from ray.rllib.env.single_agent_episode import SingleAgentEpisode
36
+ from ray.rllib.utils.test_utils import check
37
+
38
+ # Some arbitrarily nested, complex observation space.
39
+ obs_space = gym.spaces.Dict({
40
+ "a": gym.spaces.Box(-10.0, 10.0, (), np.float32),
41
+ "b": gym.spaces.Tuple([
42
+ gym.spaces.Discrete(2),
43
+ gym.spaces.Box(-1.0, 1.0, (2, 1), np.float32),
44
+ ]),
45
+ "c": gym.spaces.MultiDiscrete([2, 3]),
46
+ })
47
+ act_space = gym.spaces.Discrete(2)
48
+
49
+ # Two example episodes, both with initial (reset) observations coming from the
50
+ # above defined observation space.
51
+ episode_1 = SingleAgentEpisode(
52
+ observations=[
53
+ {
54
+ "a": np.array(-10.0, np.float32),
55
+ "b": (1, np.array([[-1.0], [-1.0]], np.float32)),
56
+ "c": np.array([0, 2]),
57
+ },
58
+ ],
59
+ )
60
+ episode_2 = SingleAgentEpisode(
61
+ observations=[
62
+ {
63
+ "a": np.array(10.0, np.float32),
64
+ "b": (0, np.array([[1.0], [1.0]], np.float32)),
65
+ "c": np.array([1, 1]),
66
+ },
67
+ ],
68
+ )
69
+
70
+ # Construct our connector piece.
71
+ connector = FlattenObservations(obs_space, act_space)
72
+
73
+ # Call our connector piece with the example data.
74
+ output_batch = connector(
75
+ rl_module=None, # This connector works without an RLModule.
76
+ batch={}, # This connector does not alter the input batch.
77
+ episodes=[episode_1, episode_2],
78
+ explore=True,
79
+ shared_data={},
80
+ )
81
+
82
+ # The connector does not alter the data and acts as pure pass-through.
83
+ check(output_batch, {})
84
+
85
+ # The connector has flattened each item in the episodes to a 1D tensor.
86
+ check(
87
+ episode_1.get_observations(0),
88
+ # box() disc(2). box(2, 1). multidisc(2, 3)........
89
+ np.array([-10.0, 0.0, 1.0, -1.0, -1.0, 1.0, 0.0, 0.0, 0.0, 1.0]),
90
+ )
91
+ check(
92
+ episode_2.get_observations(0),
93
+ # box() disc(2). box(2, 1). multidisc(2, 3)........
94
+ np.array([10.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0]),
95
+ )
96
+ """
97
+
98
+ @override(ConnectorV2)
99
+ def recompute_output_observation_space(
100
+ self,
101
+ input_observation_space,
102
+ input_action_space,
103
+ ) -> gym.Space:
104
+ self._input_obs_base_struct = get_base_struct_from_space(
105
+ self.input_observation_space
106
+ )
107
+ if self._multi_agent:
108
+ spaces = {}
109
+ for agent_id, space in self._input_obs_base_struct.items():
110
+ if self._agent_ids and agent_id not in self._agent_ids:
111
+ spaces[agent_id] = self._input_obs_base_struct[agent_id]
112
+ else:
113
+ sample = flatten_inputs_to_1d_tensor(
114
+ tree.map_structure(
115
+ lambda s: s.sample(),
116
+ self._input_obs_base_struct[agent_id],
117
+ ),
118
+ self._input_obs_base_struct[agent_id],
119
+ batch_axis=False,
120
+ )
121
+ spaces[agent_id] = Box(
122
+ float("-inf"), float("inf"), (len(sample),), np.float32
123
+ )
124
+ return gym.spaces.Dict(spaces)
125
+ else:
126
+ sample = flatten_inputs_to_1d_tensor(
127
+ tree.map_structure(
128
+ lambda s: s.sample(),
129
+ self._input_obs_base_struct,
130
+ ),
131
+ self._input_obs_base_struct,
132
+ batch_axis=False,
133
+ )
134
+ return Box(float("-inf"), float("inf"), (len(sample),), np.float32)
135
+
136
+ def __init__(
137
+ self,
138
+ input_observation_space: Optional[gym.Space] = None,
139
+ input_action_space: Optional[gym.Space] = None,
140
+ *,
141
+ multi_agent: bool = False,
142
+ agent_ids: Optional[Collection[AgentID]] = None,
143
+ **kwargs,
144
+ ):
145
+ """Initializes a FlattenObservations instance.
146
+
147
+ Args:
148
+ multi_agent: Whether this connector operates on multi-agent observations,
149
+ in which case, the top-level of the Dict space (where agent IDs are
150
+ mapped to individual agents' observation spaces) is left as-is.
151
+ agent_ids: If multi_agent is True, this argument defines a collection of
152
+ AgentIDs for which to flatten. AgentIDs not in this collection are
153
+ ignored.
154
+ If None, flatten observations for all AgentIDs. None is the default.
155
+ """
156
+ self._input_obs_base_struct = None
157
+ self._multi_agent = multi_agent
158
+ self._agent_ids = agent_ids
159
+
160
+ super().__init__(input_observation_space, input_action_space, **kwargs)
161
+
162
+ @override(ConnectorV2)
163
+ def __call__(
164
+ self,
165
+ *,
166
+ rl_module: RLModule,
167
+ batch: Dict[str, Any],
168
+ episodes: List[EpisodeType],
169
+ explore: Optional[bool] = None,
170
+ shared_data: Optional[dict] = None,
171
+ **kwargs,
172
+ ) -> Any:
173
+ for sa_episode in self.single_agent_episode_iterator(
174
+ episodes, agents_that_stepped_only=True
175
+ ):
176
+ last_obs = sa_episode.get_observations(-1)
177
+
178
+ if self._multi_agent:
179
+ if (
180
+ self._agent_ids is not None
181
+ and sa_episode.agent_id not in self._agent_ids
182
+ ):
183
+ flattened_obs = last_obs
184
+ else:
185
+ flattened_obs = flatten_inputs_to_1d_tensor(
186
+ inputs=last_obs,
187
+ # In the multi-agent case, we need to use the specific agent's
188
+ # space struct, not the multi-agent observation space dict.
189
+ spaces_struct=self._input_obs_base_struct[sa_episode.agent_id],
190
+ # Our items are individual observations (no batch axis present).
191
+ batch_axis=False,
192
+ )
193
+ else:
194
+ flattened_obs = flatten_inputs_to_1d_tensor(
195
+ inputs=last_obs,
196
+ spaces_struct=self._input_obs_base_struct,
197
+ # Our items are individual observations (no batch axis present).
198
+ batch_axis=False,
199
+ )
200
+
201
+ # Write new observation directly back into the episode.
202
+ sa_episode.set_observations(at_indices=-1, new_data=flattened_obs)
203
+ # We set the Episode's observation space to ours so that we can safely
204
+ # set the last obs to the new value (without causing a space mismatch
205
+ # error).
206
+ sa_episode.observation_space = self.observation_space
207
+
208
+ return batch
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/frame_stacking.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ from ray.rllib.connectors.common.frame_stacking import _FrameStacking
4
+
5
+
6
+ FrameStackingEnvToModule = partial(_FrameStacking, as_learner_connector=False)
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/mean_std_filter.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Collection, Dict, List, Optional, Union
2
+
3
+ import gymnasium as gym
4
+ from gymnasium.spaces import Discrete, MultiDiscrete
5
+ import numpy as np
6
+ import tree
7
+
8
+ from ray.rllib.connectors.connector_v2 import ConnectorV2
9
+ from ray.rllib.core.rl_module.rl_module import RLModule
10
+ from ray.rllib.utils.annotations import override
11
+ from ray.rllib.utils.filter import MeanStdFilter as _MeanStdFilter, RunningStat
12
+ from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
13
+ from ray.rllib.utils.typing import AgentID, EpisodeType, StateDict
14
+ from ray.util.annotations import PublicAPI
15
+
16
+
17
+ @PublicAPI(stability="alpha")
18
+ class MeanStdFilter(ConnectorV2):
19
+ """A connector used to mean-std-filter observations.
20
+
21
+ Incoming observations are filtered such that the output of this filter is on
22
+ average 0.0 and has a standard deviation of 1.0. If the observation space is
23
+ a (possibly nested) dict, this filtering is applied separately per element of
24
+ the observation space (except for discrete- and multi-discrete elements, which
25
+ are left as-is).
26
+
27
+ This connector is stateful as it continues to update its internal stats on mean
28
+ and std values as new data is pushed through it (unless `update_stats` is False).
29
+ """
30
+
31
+ @override(ConnectorV2)
32
+ def recompute_output_observation_space(
33
+ self,
34
+ input_observation_space: gym.Space,
35
+ input_action_space: gym.Space,
36
+ ) -> gym.Space:
37
+ _input_observation_space_struct = get_base_struct_from_space(
38
+ input_observation_space
39
+ )
40
+
41
+ # Adjust our observation space's Boxes (only if clipping is active).
42
+ _observation_space_struct = tree.map_structure(
43
+ lambda s: (
44
+ s
45
+ if not isinstance(s, gym.spaces.Box)
46
+ else gym.spaces.Box(
47
+ low=-self.clip_by_value,
48
+ high=self.clip_by_value,
49
+ shape=s.shape,
50
+ dtype=s.dtype,
51
+ )
52
+ ),
53
+ _input_observation_space_struct,
54
+ )
55
+ if isinstance(input_observation_space, (gym.spaces.Dict, gym.spaces.Tuple)):
56
+ return type(input_observation_space)(_observation_space_struct)
57
+ else:
58
+ return _observation_space_struct
59
+
60
+ def __init__(
61
+ self,
62
+ *,
63
+ multi_agent: bool = False,
64
+ de_mean_to_zero: bool = True,
65
+ de_std_to_one: bool = True,
66
+ clip_by_value: Optional[float] = 10.0,
67
+ update_stats: bool = True,
68
+ **kwargs,
69
+ ):
70
+ """Initializes a MeanStdFilter instance.
71
+
72
+ Args:
73
+ multi_agent: Whether this is a connector operating on a multi-agent
74
+ observation space mapping AgentIDs to individual agents' observations.
75
+ de_mean_to_zero: Whether to transform the mean values of the output data to
76
+ 0.0. This is done by subtracting the incoming data by the currently
77
+ stored mean value.
78
+ de_std_to_one: Whether to transform the standard deviation values of the
79
+ output data to 1.0. This is done by dividing the incoming data by the
80
+ currently stored std value.
81
+ clip_by_value: If not None, clip the incoming data within the interval:
82
+ [-clip_by_value, +clip_by_value].
83
+ update_stats: Whether to update the internal mean and std stats with each
84
+ incoming sample (with each `__call__()`) or not. You should set this to
85
+ False if you would like to perform inference in a production
86
+ environment, without continuing to "learn" stats from new data.
87
+ """
88
+ super().__init__(**kwargs)
89
+
90
+ self._multi_agent = multi_agent
91
+
92
+ # We simply use the old MeanStdFilter until non-connector env_runner is fully
93
+ # deprecated to avoid duplicate code
94
+ self.de_mean_to_zero = de_mean_to_zero
95
+ self.de_std_to_one = de_std_to_one
96
+ self.clip_by_value = clip_by_value
97
+ self._update_stats = update_stats
98
+
99
+ self._filters: Optional[Dict[AgentID, _MeanStdFilter]] = None
100
+
101
+ @override(ConnectorV2)
102
+ def __call__(
103
+ self,
104
+ *,
105
+ rl_module: RLModule,
106
+ batch: Dict[str, Any],
107
+ episodes: List[EpisodeType],
108
+ explore: Optional[bool] = None,
109
+ persistent_data: Optional[dict] = None,
110
+ **kwargs,
111
+ ) -> Any:
112
+ if self._filters is None:
113
+ self._init_new_filters()
114
+
115
+ # This connector acts as a classic preprocessor. We process and then replace
116
+ # observations inside the episodes directly. Thus, all following connectors
117
+ # will only see and operate on the already normalized data (w/o having access
118
+ # anymore to the original observations).
119
+ for sa_episode in self.single_agent_episode_iterator(episodes):
120
+ sa_obs = sa_episode.get_observations(indices=-1)
121
+ try:
122
+ normalized_sa_obs = self._filters[sa_episode.agent_id](
123
+ sa_obs, update=self._update_stats
124
+ )
125
+ except KeyError:
126
+ raise KeyError(
127
+ "KeyError trying to access a filter by agent ID "
128
+ f"`{sa_episode.agent_id}`! You probably did NOT pass the "
129
+ f"`multi_agent=True` flag into the `MeanStdFilter()` constructor. "
130
+ )
131
+ sa_episode.set_observations(at_indices=-1, new_data=normalized_sa_obs)
132
+ # We set the Episode's observation space to ours so that we can safely
133
+ # set the last obs to the new value (without causing a space mismatch
134
+ # error).
135
+ sa_episode.observation_space = self.observation_space
136
+
137
+ # Leave `batch` as is. RLlib's default connector will automatically
138
+ # populate the OBS column therein from the episodes' now transformed
139
+ # observations.
140
+ return batch
141
+
142
+ @override(ConnectorV2)
143
+ def get_state(
144
+ self,
145
+ components: Optional[Union[str, Collection[str]]] = None,
146
+ *,
147
+ not_components: Optional[Union[str, Collection[str]]] = None,
148
+ **kwargs,
149
+ ) -> StateDict:
150
+ if self._filters is None:
151
+ self._init_new_filters()
152
+ return self._get_state_from_filters(self._filters)
153
+
154
+ @override(ConnectorV2)
155
+ def set_state(self, state: StateDict) -> None:
156
+ if self._filters is None:
157
+ self._init_new_filters()
158
+ for agent_id, agent_state in state.items():
159
+ filter = self._filters[agent_id]
160
+ filter.shape = agent_state["shape"]
161
+ filter.demean = agent_state["de_mean_to_zero"]
162
+ filter.destd = agent_state["de_std_to_one"]
163
+ filter.clip = agent_state["clip_by_value"]
164
+ filter.running_stats = tree.unflatten_as(
165
+ filter.shape,
166
+ [RunningStat.from_state(s) for s in agent_state["running_stats"]],
167
+ )
168
+ # Do not update the buffer.
169
+
170
+ @override(ConnectorV2)
171
+ def reset_state(self) -> None:
172
+ """Creates copy of current state and resets accumulated state"""
173
+ if not self._update_stats:
174
+ raise ValueError(
175
+ f"State of {type(self).__name__} can only be changed when "
176
+ f"`update_stats` was set to False."
177
+ )
178
+ self._init_new_filters()
179
+
180
    @override(ConnectorV2)
    def merge_states(self, states: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Merges incoming per-EnvRunner filter states into this connector.

        Args:
            states: List of state dicts (as returned by `get_state()`), each
                mapping AgentID to that agent's serialized filter state.

        Returns:
            The merged state dict, in the same format as `get_state()`.
        """
        if self._filters is None:
            self._init_new_filters()

        # Make sure data is uniform across given states.
        ref = next(iter(states[0].values()))

        for state in states:
            for agent_id, agent_state in state.items():
                # All incoming agent states must share the same filter
                # configuration (shape and demean/destd/clip settings).
                assert (
                    agent_state["shape"] == ref["shape"]
                    and agent_state["de_mean_to_zero"] == ref["de_mean_to_zero"]
                    and agent_state["de_std_to_one"] == ref["de_std_to_one"]
                    and agent_state["clip_by_value"] == ref["clip_by_value"]
                )

                # Temporary filter acting solely as a carrier for the incoming
                # running stats (applied to our own filters below).
                _filter = _MeanStdFilter(
                    ref["shape"],
                    demean=ref["de_mean_to_zero"],
                    destd=ref["de_std_to_one"],
                    clip=ref["clip_by_value"],
                )
                # Override running stats of the filter with the ones stored in
                # `agent_state`.
                _filter.buffer = tree.unflatten_as(
                    agent_state["shape"],
                    [
                        RunningStat.from_state(stats)
                        for stats in agent_state["running_stats"]
                    ],
                )

                # Leave the buffers as-is, since they should always only reflect
                # what has happened on the particular env runner.
                self._filters[agent_id].apply_changes(_filter, with_buffer=False)

        return MeanStdFilter._get_state_from_filters(self._filters)
218
+
219
+ def _init_new_filters(self):
220
+ filter_shape = tree.map_structure(
221
+ lambda s: (
222
+ None if isinstance(s, (Discrete, MultiDiscrete)) else np.array(s.shape)
223
+ ),
224
+ get_base_struct_from_space(self.input_observation_space),
225
+ )
226
+ if not self._multi_agent:
227
+ filter_shape = {None: filter_shape}
228
+
229
+ del self._filters
230
+ self._filters = {
231
+ agent_id: _MeanStdFilter(
232
+ agent_filter_shape,
233
+ demean=self.de_mean_to_zero,
234
+ destd=self.de_std_to_one,
235
+ clip=self.clip_by_value,
236
+ )
237
+ for agent_id, agent_filter_shape in filter_shape.items()
238
+ }
239
+
240
+ @staticmethod
241
+ def _get_state_from_filters(filters: Dict[AgentID, Dict[str, Any]]):
242
+ ret = {}
243
+ for agent_id, agent_filter in filters.items():
244
+ ret[agent_id] = {
245
+ "shape": agent_filter.shape,
246
+ "de_mean_to_zero": agent_filter.demean,
247
+ "de_std_to_one": agent_filter.destd,
248
+ "clip_by_value": agent_filter.clip,
249
+ "running_stats": [
250
+ s.to_state() for s in tree.flatten(agent_filter.running_stats)
251
+ ],
252
+ }
253
+ return ret
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/observation_preprocessor.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import gymnasium as gym
5
+
6
+ from ray.rllib.connectors.connector_v2 import ConnectorV2
7
+ from ray.rllib.core.rl_module.rl_module import RLModule
8
+ from ray.rllib.utils.annotations import override
9
+ from ray.rllib.utils.typing import EpisodeType
10
+ from ray.util.annotations import PublicAPI
11
+
12
+
13
@PublicAPI(stability="alpha")
class ObservationPreprocessor(ConnectorV2, abc.ABC):
    """Env-to-module connector that transforms each episode's latest observation.

    Convenience base class that simplifies writing few-step preprocessor
    connectors: instead of extracting data from a list of episodes and adding it
    to the batch, subclasses only implement a plain
    "old observation --transform--> new observation" step via `preprocess()`.
    """

    @override(ConnectorV2)
    def recompute_output_observation_space(
        self,
        input_observation_space: gym.Space,
        input_action_space: gym.Space,
    ) -> gym.Space:
        # Only override this method in subclasses whose `preprocess()` changes
        # the observation space of the pipeline; then return the new space
        # derived from the incoming `input_observation_space`. The default is a
        # plain pass-through to the base class.
        return super().recompute_output_observation_space(
            input_observation_space, input_action_space
        )

    @abc.abstractmethod
    def preprocess(self, observation):
        """Override to implement the preprocessing logic.

        Args:
            observation: A single (non-batched) observation item for a single
                agent to be processed by this connector.

        Returns:
            The new observation after `observation` has been preprocessed.
        """

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Dict[str, Any],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        persistent_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        # Transform each episode's most recent observation in-place. All
        # following connector pieces will thus only ever see (and operate on)
        # the already-processed observation; the original one is gone.
        for sa_episode in self.single_agent_episode_iterator(episodes):
            raw_obs = sa_episode.get_observations(-1)
            processed_obs = self.preprocess(observation=raw_obs)
            sa_episode.set_observations(at_indices=-1, new_data=processed_obs)
            # Adopt this connector's observation space so writing the new value
            # does not cause a space-mismatch error.
            sa_episode.observation_space = self.observation_space

        # Leave `batch` untouched; RLlib's default connector will populate the
        # OBS column from the episodes' transformed observations later on.
        return batch
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/prev_actions_prev_rewards.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ import gymnasium as gym
4
+ from gymnasium.spaces import Box
5
+ import numpy as np
6
+
7
+ from ray.rllib.connectors.connector_v2 import ConnectorV2
8
+ from ray.rllib.core.rl_module.rl_module import RLModule
9
+ from ray.rllib.utils.annotations import override
10
+ from ray.rllib.utils.spaces.space_utils import (
11
+ batch as batch_fn,
12
+ flatten_to_single_ndarray,
13
+ )
14
+ from ray.rllib.utils.typing import EpisodeType
15
+ from ray.util.annotations import PublicAPI
16
+
17
+
18
@PublicAPI(stability="alpha")
class PrevActionsPrevRewards(ConnectorV2):
    """A connector piece that adds previous rewards and actions to the input obs.

    - Requires Columns.OBS to be already a part of the batch.
    - This connector makes the assumption that under the Columns.OBS key in batch,
    there is either a list of individual env observations to be flattened
    (single-agent case) or a dict mapping (AgentID, ModuleID)-tuples to lists of
    data items to be flattened (multi-agent case).
    - Converts Columns.OBS data into a dict (or creates a sub-dict if obs are
    already a dict), and adds `self.PREV_ACTIONS_KEY` and `self.PREV_REWARDS_KEY`
    entries to this dict. The original observations are stored under
    `self.ORIG_OBS_KEY` in that dict.
    - If your RLModule does not handle dict inputs, you will have to plug in a
    `FlattenObservations` connector piece after this one.
    - Does NOT work in a Learner pipeline as it operates on individual observation
    items (as opposed to batched/time-ranked data).
    - Therefore, assumes that the altered (flattened) observations will be written
    back into the episode by a later connector piece in the env-to-module pipeline
    (which this piece is part of as well).
    - Only reads reward- and action information from the given list of Episode
    objects.
    - Does NOT write any observations (or other data) to the given Episode
    objects.
    """

    ORIG_OBS_KEY = "_orig_obs"
    PREV_ACTIONS_KEY = "prev_n_actions"
    PREV_REWARDS_KEY = "prev_n_rewards"

    @override(ConnectorV2)
    def recompute_output_observation_space(
        self,
        input_observation_space: gym.Space,
        input_action_space: gym.Space,
    ) -> gym.Space:
        # Multi-agent: augment each agent's observation space individually.
        if self._multi_agent:
            ret = {}
            for agent_id, obs_space in input_observation_space.spaces.items():
                act_space = input_action_space[agent_id]
                ret[agent_id] = self._convert_individual_space(obs_space, act_space)
            return gym.spaces.Dict(ret)
        else:
            return self._convert_individual_space(
                input_observation_space, input_action_space
            )

    def __init__(
        self,
        input_observation_space: Optional[gym.Space] = None,
        input_action_space: Optional[gym.Space] = None,
        *,
        multi_agent: bool = False,
        n_prev_actions: int = 1,
        n_prev_rewards: int = 1,
        **kwargs,
    ):
        """Initializes a PrevActionsPrevRewards instance.

        Args:
            multi_agent: Whether this is a connector operating on a multi-agent
                observation space mapping AgentIDs to individual agents'
                observations.
            n_prev_actions: The number of previous actions to include in the
                output data. Discrete actions are one-hot'd. If > 1, will
                concatenate the individual action tensors.
            n_prev_rewards: The number of previous rewards to include in the
                output data.
        """
        # Bugfix: These attributes are read by
        # `recompute_output_observation_space()` (and by
        # `_convert_individual_space()`), which the base class may already
        # trigger while processing the input spaces handed to
        # `super().__init__()` below. They must therefore be assigned BEFORE
        # the super call; otherwise that recompute would fail with an
        # AttributeError on `self._multi_agent` et al.
        self._multi_agent = multi_agent
        self.n_prev_actions = n_prev_actions
        self.n_prev_rewards = n_prev_rewards

        super().__init__(
            input_observation_space=input_observation_space,
            input_action_space=input_action_space,
            **kwargs,
        )

        # TODO: Move into input_observation_space setter
        # Thus far, this connector piece only operates on discrete action spaces.
        # act_spaces = [self.input_action_space]
        # if self._multi_agent:
        #     act_spaces = self.input_action_space.spaces.values()
        # if not all(isinstance(s, gym.spaces.Discrete) for s in act_spaces):
        #     raise ValueError(
        #         f"{type(self).__name__} only works on Discrete action spaces "
        #         f"thus far (or, for multi-agent, on Dict spaces mapping AgentIDs "
        #         f"to the individual agents' Discrete action spaces)!"
        #     )

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Optional[Dict[str, Any]],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        shared_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        for sa_episode in self.single_agent_episode_iterator(
            episodes, agents_that_stepped_only=True
        ):
            # Episode is not numpy'ized yet and thus still operates on lists of
            # items.
            assert not sa_episode.is_numpy

            # Keep the original observation under a dedicated key.
            augmented_obs = {self.ORIG_OBS_KEY: sa_episode.get_observations(-1)}

            # Concatenate the n most recent (one-hot'd) actions; steps missing
            # at the episode start are zero-filled.
            if self.n_prev_actions:
                augmented_obs[self.PREV_ACTIONS_KEY] = flatten_to_single_ndarray(
                    batch_fn(
                        sa_episode.get_actions(
                            indices=slice(-self.n_prev_actions, None),
                            fill=0.0,
                            one_hot_discrete=True,
                        )
                    )
                )

            # Add the n most recent rewards (zero-filled at the episode start).
            if self.n_prev_rewards:
                augmented_obs[self.PREV_REWARDS_KEY] = np.array(
                    sa_episode.get_rewards(
                        indices=slice(-self.n_prev_rewards, None),
                        fill=0.0,
                    )
                )

            # Write new observation directly back into the episode.
            sa_episode.set_observations(at_indices=-1, new_data=augmented_obs)
            # We set the Episode's observation space to ours so that we can safely
            # set the last obs to the new value (without causing a space mismatch
            # error).
            sa_episode.observation_space = self.observation_space

        return batch

    def _convert_individual_space(self, obs_space, act_space):
        """Returns the augmented (Dict) observation space for a single agent."""
        return gym.spaces.Dict(
            {
                self.ORIG_OBS_KEY: obs_space,
                # Currently only works for Discrete action spaces.
                self.PREV_ACTIONS_KEY: Box(
                    0.0, 1.0, (act_space.n * self.n_prev_actions,), np.float32
                ),
                self.PREV_REWARDS_KEY: Box(
                    float("-inf"),
                    float("inf"),
                    (self.n_prev_rewards,),
                    np.float32,
                ),
            }
        )
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/write_observations_to_episodes.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ from ray.rllib.connectors.connector_v2 import ConnectorV2
4
+ from ray.rllib.core.columns import Columns
5
+ from ray.rllib.core.rl_module.rl_module import RLModule
6
+ from ray.rllib.utils.annotations import override
7
+ from ray.rllib.utils.typing import EpisodeType
8
+ from ray.util.annotations import PublicAPI
9
+
10
+
11
@PublicAPI(stability="alpha")
class WriteObservationsToEpisodes(ConnectorV2):
    """Writes the observations from the batch into the running episodes.

    Note: This is one of the default env-to-module ConnectorV2 pieces that are
    added automatically by RLlib into every env-to-module connector pipeline,
    unless `config.add_default_connectors_to_env_to_module_pipeline` is set to
    False.

    The default env-to-module connector pipeline is:
    [
        [0 or more user defined ConnectorV2 pieces],
        AddObservationsFromEpisodesToBatch,
        AddStatesFromEpisodesToBatch,
        AgentToModuleMapping,  # only in multi-agent setups!
        BatchIndividualItems,
        NumpyToTensor,
    ]

    This ConnectorV2:
    - Expects a batch that already holds observations, plus the list of ongoing
      Episode objects.
    - Writes one observation from the batch into each given single-agent
      episode; the number of observations in the batch must therefore match the
      number of episodes.
    - Does NOT alter any observations (or other data) in the batch itself.
    - Can only be used in an EnvToModule pipeline (writing into Episode objects
      in a Learner pipeline does not make a lot of sense as - after the learner
      update - the list of episodes is discarded).

    .. testcode::

        import gymnasium as gym
        import numpy as np

        from ray.rllib.connectors.env_to_module import WriteObservationsToEpisodes
        from ray.rllib.env.single_agent_episode import SingleAgentEpisode
        from ray.rllib.utils.test_utils import check

        # Assume we have two episodes (vectorized), then our forward batch will
        # carry two observation records (batch size = 2). The connector in this
        # example writes these two (possibly transformed) observations back into
        # the two respective SingleAgentEpisode objects.
        batch = {
            "obs": [np.array([0.0, 1.0], np.float32), np.array([2.0, 3.0], np.float32)],
        }

        # Our two episodes have one observation each (i.e. the reset one). This
        # is the one that gets overwritten by the connector in this example.
        obs_space = gym.spaces.Box(-10.0, 10.0, (2,), np.float32)
        act_space = gym.spaces.Discrete(2)
        episodes = [
            SingleAgentEpisode(
                observation_space=obs_space,
                observations=[np.array([-10, -20], np.float32)],
                len_lookback_buffer=0,
            ) for _ in range(2)
        ]
        # Make sure everything is setup correctly.
        check(episodes[0].get_observations(0), [-10.0, -20.0])
        check(episodes[1].get_observations(-1), [-10.0, -20.0])

        # Create our connector piece.
        connector = WriteObservationsToEpisodes(obs_space, act_space)

        # Call the connector (and thereby write the transformed observations
        # back into the episodes).
        output_batch = connector(
            rl_module=None,  # This particular connector works without an RLModule.
            batch=batch,
            episodes=episodes,
            explore=True,
            shared_data={},
        )

        # The connector does NOT change the data batch being passed through.
        check(output_batch, batch)

        # However, the connector has overwritten the last observations in the
        # episodes.
        check(episodes[0].get_observations(-1), [0.0, 1.0])
        check(episodes[1].get_observations(0), [2.0, 3.0])
    """

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Optional[Dict[str, Any]],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        shared_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        observations = batch.get(Columns.OBS)
        if observations is None:
            raise ValueError(
                f"`batch` must already have a column named {Columns.OBS} in it "
                f"for this connector to work!"
            )

        # The loop below handles the multi-agent as well as the single-agent
        # case, as long as exactly one observation item exists in the batch per
        # single-agent episode (validated by `single_agent_episode_iterator()`):
        # either in a list directly under the "obs" key OR - for multi-agent -
        # in a list sitting under a key `(agent_id, module_id)` of a dict
        # sitting under the "obs" key.
        episode_obs_pairs = self.single_agent_episode_iterator(
            episodes=episodes, zip_with_batch_column=observations
        )
        for sa_episode, new_obs in episode_obs_pairs:
            # Episodes must not be numpy'ized yet (we are expecting to run in
            # an env-to-module pipeline).
            assert not sa_episode.is_numpy
            # Overwrite the episode's latest observation ...
            sa_episode.set_observations(at_indices=-1, new_data=new_obs)
            # ... and adopt this connector's observation space to avoid a
            # space-mismatch error.
            sa_episode.observation_space = self.observation_space

        # The batch itself passes through unchanged.
        return batch
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.connectors.common.tensor_to_numpy import TensorToNumpy
2
+ from ray.rllib.connectors.common.module_to_agent_unmapping import ModuleToAgentUnmapping
3
+ from ray.rllib.connectors.module_to_env.get_actions import GetActions
4
+ from ray.rllib.connectors.module_to_env.listify_data_for_vector_env import (
5
+ ListifyDataForVectorEnv,
6
+ )
7
+ from ray.rllib.connectors.module_to_env.module_to_env_pipeline import (
8
+ ModuleToEnvPipeline,
9
+ )
10
+ from ray.rllib.connectors.module_to_env.normalize_and_clip_actions import (
11
+ NormalizeAndClipActions,
12
+ )
13
+ from ray.rllib.connectors.module_to_env.remove_single_ts_time_rank_from_batch import (
14
+ RemoveSingleTsTimeRankFromBatch,
15
+ )
16
+ from ray.rllib.connectors.module_to_env.unbatch_to_individual_items import (
17
+ UnBatchToIndividualItems,
18
+ )
19
+
20
+
21
+ __all__ = [
22
+ "GetActions",
23
+ "ListifyDataForVectorEnv",
24
+ "ModuleToAgentUnmapping",
25
+ "ModuleToEnvPipeline",
26
+ "NormalizeAndClipActions",
27
+ "RemoveSingleTsTimeRankFromBatch",
28
+ "TensorToNumpy",
29
+ "UnBatchToIndividualItems",
30
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/listify_data_for_vector_env.cpython-311.pyc ADDED
Binary file (4.19 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/get_actions.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ from ray.rllib.connectors.connector_v2 import ConnectorV2
4
+ from ray.rllib.core.columns import Columns
5
+ from ray.rllib.core.rl_module.rl_module import RLModule
6
+ from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
7
+ from ray.rllib.utils.annotations import override
8
+ from ray.rllib.utils.typing import EpisodeType
9
+ from ray.util.annotations import PublicAPI
10
+
11
+
12
@PublicAPI(stability="alpha")
class GetActions(ConnectorV2):
    """Connector piece sampling actions from ACTION_DIST_INPUTS from an RLModule.

    Note: This is one of the default module-to-env ConnectorV2 pieces that
    are added automatically by RLlib into every module-to-env connector pipeline,
    unless `config.add_default_connectors_to_module_to_env_pipeline` is set to
    False.

    The default module-to-env connector pipeline is:
    [
        GetActions,
        TensorToNumpy,
        UnBatchToIndividualItems,
        ModuleToAgentUnmapping,  # only in multi-agent setups!
        RemoveSingleTsTimeRankFromBatch,

        [0 or more user defined ConnectorV2 pieces],

        NormalizeAndClipActions,
        ListifyDataForVectorEnv,
    ]

    Behaves as a pass-through if Columns.ACTIONS is already present in the
    batch. Otherwise, if Columns.ACTION_DIST_INPUTS is present, this connector
    builds an action distribution from those inputs (using the RLModule's
    exploration- or inference-distribution class, depending on `explore`) and
    samples actions from it - stochastically while exploring, deterministically
    otherwise.
    """

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Dict[str, Any],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        shared_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        # Multi-agent: `batch` maps ModuleIDs to per-module sub-batches; sample
        # actions per module. Single-agent: sample on the top-level batch.
        if isinstance(episodes[0], MultiAgentEpisode):
            for module_id, module_batch in batch.copy().items():
                self._get_actions(module_batch, rl_module[module_id], explore)
        else:
            self._get_actions(batch, rl_module, explore)

        return batch

    def _get_actions(self, batch, sa_rl_module, explore):
        """Samples actions into `batch` (in-place) for one single-agent module."""
        # Actions have already been sampled -> Early out.
        if Columns.ACTIONS in batch:
            return
        # No distribution inputs available -> Nothing to sample from.
        if Columns.ACTION_DIST_INPUTS not in batch:
            return

        # Build a new action distribution object from the module's dist inputs;
        # the distribution class depends on whether we are exploring.
        if explore:
            dist_cls = sa_rl_module.get_exploration_action_dist_cls()
        else:
            dist_cls = sa_rl_module.get_inference_action_dist_cls()
        action_dist = dist_cls.from_logits(
            batch[Columns.ACTION_DIST_INPUTS],
        )
        # When not exploring, sample greedily/deterministically.
        if not explore:
            action_dist = action_dist.to_deterministic()

        # Sample actions from the distribution and add them to the batch.
        actions = action_dist.sample()
        batch[Columns.ACTIONS] = actions

        # For convenience and if possible, compute action logp from distribution
        # and add to output.
        if Columns.ACTION_LOGP not in batch:
            batch[Columns.ACTION_LOGP] = action_dist.logp(actions)
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/listify_data_for_vector_env.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ from ray.rllib.connectors.connector_v2 import ConnectorV2
4
+ from ray.rllib.core.columns import Columns
5
+ from ray.rllib.core.rl_module.rl_module import RLModule
6
+ from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
7
+ from ray.rllib.utils.annotations import override
8
+ from ray.rllib.utils.spaces.space_utils import batch as batch_fn
9
+ from ray.rllib.utils.typing import EpisodeType
10
+ from ray.util.annotations import PublicAPI
11
+
12
+
13
@PublicAPI(stability="alpha")
class ListifyDataForVectorEnv(ConnectorV2):
    """Performs conversion from ConnectorV2-style format to env/episode insertion.

    Note: This is one of the default module-to-env ConnectorV2 pieces that
    are added automatically by RLlib into every module-to-env connector pipeline,
    unless `config.add_default_connectors_to_module_to_env_pipeline` is set to
    False.

    The default module-to-env connector pipeline is:
    [
        GetActions,
        TensorToNumpy,
        UnBatchToIndividualItems,
        ModuleToAgentUnmapping,  # only in multi-agent setups!
        RemoveSingleTsTimeRankFromBatch,

        [0 or more user defined ConnectorV2 pieces],

        NormalizeAndClipActions,
        ListifyDataForVectorEnv,
    ]

    Single agent case:
    Convert from:
    [col] -> [(episode_id,)] -> [list of items].
    To:
    [col] -> [list of items].

    Multi-agent case:
    Convert from:
    [col] -> [(episode_id, agent_id, module_id)] -> list of items.
    To:
    [col] -> [list of multi-agent dicts].
    """

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Dict[str, Any],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        shared_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        # Loop-invariant: all given episodes share the same type.
        is_multi_agent = isinstance(episodes[0], MultiAgentEpisode)

        for column in list(batch.keys()):
            # Multi-agent case: Create lists of multi-agent dicts under each
            # column.
            if is_multi_agent:
                # TODO (sven): Support vectorized MultiAgentEnv
                assert len(episodes) == 1
                # Collapse the `(eps_id, agent_id, module_id) -> [item]`
                # mapping into one multi-agent dict (a single vector index).
                ma_dict = {}
                for (_eps_id, agent_id, _module_id), items in batch[column].items():
                    assert len(items) == 1
                    ma_dict[agent_id] = items[0]
                batch[column] = [ma_dict]
            # Single-agent case: Create simple lists under each column.
            else:
                batch[column] = [
                    item
                    for items in batch[column].values()
                    for item in items
                ]
                # Batch actions for (single-agent) gym.vector.Env.
                # All other columns stay listify'ed.
                if column in [Columns.ACTIONS_FOR_ENV, Columns.ACTIONS]:
                    batch[column] = batch_fn(batch[column])

        return batch
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/module_to_env_pipeline.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from ray.rllib.connectors.connector_pipeline_v2 import ConnectorPipelineV2
2
+ from ray.util.annotations import PublicAPI
3
+
4
+
5
@PublicAPI(stability="alpha")
class ModuleToEnvPipeline(ConnectorPipelineV2):
    # Marker subclass of `ConnectorPipelineV2` for the module-to-env side.
    # Adds no behavior of its own; presumably exists so module-to-env pipelines
    # have their own distinct type (e.g. for isinstance checks) - verify
    # against callers.
    pass
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/normalize_and_clip_actions.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import gymnasium as gym
5
+
6
+ from ray.rllib.connectors.connector_v2 import ConnectorV2
7
+ from ray.rllib.core.columns import Columns
8
+ from ray.rllib.core.rl_module.rl_module import RLModule
9
+ from ray.rllib.utils.annotations import override
10
+ from ray.rllib.utils.spaces.space_utils import (
11
+ clip_action,
12
+ get_base_struct_from_space,
13
+ unsquash_action,
14
+ )
15
+ from ray.rllib.utils.typing import EpisodeType
16
+ from ray.util.annotations import PublicAPI
17
+
18
+
19
@PublicAPI(stability="alpha")
class NormalizeAndClipActions(ConnectorV2):
    """Normalizes or clips actions in the input data (coming from the RLModule).

    Note: This is one of the default module-to-env ConnectorV2 pieces that
    are added automatically by RLlib into every module-to-env connector pipeline,
    unless `config.add_default_connectors_to_module_to_env_pipeline` is set to
    False.

    The default module-to-env connector pipeline is:
    [
        GetActions,
        TensorToNumpy,
        UnBatchToIndividualItems,
        ModuleToAgentUnmapping,  # only in multi-agent setups!
        RemoveSingleTsTimeRankFromBatch,

        [0 or more user defined ConnectorV2 pieces],

        NormalizeAndClipActions,
        ListifyDataForVectorEnv,
    ]

    This ConnectorV2:
    - Deep copies the Columns.ACTIONS in the incoming `data` into a new column:
    Columns.ACTIONS_FOR_ENV.
    - Loops through the Columns.ACTIONS in the incoming `data` and normalizes or clips
    these depending on the c'tor settings in `config.normalize_actions` and
    `config.clip_actions`.
    - Only applies to envs with Box action spaces.

    Normalizing is the process of mapping NN-outputs (which are usually small
    numbers, e.g. between -1.0 and 1.0) to the bounds defined by the action-space.
    Normalizing helps the NN to learn faster in environments with large ranges between
    `low` and `high` bounds or skewed action bounds (e.g. Box(-3000.0, 1.0, ...)).

    Clipping clips the actions computed by the NN (and sampled from a distribution)
    between the bounds defined by the action-space. Note that clipping is only performed
    if `normalize_actions` is False.
    """

    @override(ConnectorV2)
    def recompute_output_action_space(
        self,
        input_observation_space: gym.Space,
        input_action_space: gym.Space,
    ) -> gym.Space:
        # Cache the action space as a (possibly nested) struct of primitive
        # spaces; `_unsquash_or_clip` in `__call__` operates on this struct.
        self._action_space_struct = get_base_struct_from_space(input_action_space)
        return input_action_space

    def __init__(
        self,
        input_observation_space: Optional[gym.Space] = None,
        input_action_space: Optional[gym.Space] = None,
        *,
        normalize_actions: bool,
        clip_actions: bool,
        **kwargs,
    ):
        """Initializes a NormalizeAndClipActions (connector piece) instance.

        Args:
            input_observation_space: The (optional) input observation space.
            input_action_space: The (optional) input action space.
            normalize_actions: If True, actions coming from the RLModule's distribution
                (or are directly computed by the RLModule w/o sampling) will
                be assumed 0.0 centered with a small stddev (only affecting Box
                components) and thus be unsquashed (and clipped, just in case) to the
                bounds of the env's action space. For example, if the action space of
                the environment is `Box(-2.0, -0.5, (1,))`, the model outputs
                mean and stddev as 0.1 and exp(0.2), and we sample an action of 0.9
                from the resulting distribution, then this 0.9 will be unsquashed into
                the [-2.0 -0.5] interval. If - after unsquashing - the action still
                breaches the action space, it will simply be clipped.
            clip_actions: If True, actions coming from the RLModule's distribution
                (or are directly computed by the RLModule w/o sampling) will be clipped
                such that they fit into the env's action space's bounds.
                For example, if the action space of the environment is
                `Box(-0.5, 0.5, (1,))`, the model outputs
                mean and stddev as 0.1 and exp(0.2), and we sample an action of 0.9
                from the resulting distribution, then this 0.9 will be clipped to 0.5
                to fit into the [-0.5 0.5] interval.
        """
        # Set before `super().__init__()`; NOTE(review): the base c'tor presumably
        # triggers `recompute_output_action_space`, which overwrites this - confirm.
        self._action_space_struct = None

        super().__init__(input_observation_space, input_action_space, **kwargs)

        self.normalize_actions = normalize_actions
        self.clip_actions = clip_actions

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Optional[Dict[str, Any]],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        shared_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        """Based on settings, will normalize (unsquash) and/or clip computed actions.

        This is such that the final actions (to be sent to the env) match the
        environment's action space and thus don't lead to an error.
        """

        def _unsquash_or_clip(action_for_env, env_id, agent_id, module_id):
            # Multi-agent case: pick the sub-struct belonging to this agent.
            if agent_id is not None:
                struct = self._action_space_struct[agent_id]
            else:
                struct = self._action_space_struct

            # Normalization takes precedence; clipping only happens when
            # `normalize_actions` is False (unsquashing already clips as needed).
            if self.normalize_actions:
                return unsquash_action(action_for_env, struct)
            else:
                return clip_action(action_for_env, struct)

        # Normalize or clip this new actions_for_env column, leaving the originally
        # computed/sampled actions intact. If neither option is enabled, the
        # ACTIONS_FOR_ENV column is not created at all.
        if self.normalize_actions or self.clip_actions:
            # Copy actions into separate column, just to go to the env.
            batch[Columns.ACTIONS_FOR_ENV] = copy.deepcopy(batch[Columns.ACTIONS])
            self.foreach_batch_item_change_in_place(
                batch=batch,
                column=Columns.ACTIONS_FOR_ENV,
                func=_unsquash_or_clip,
            )

        return batch
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/remove_single_ts_time_rank_from_batch.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ import numpy as np
4
+ import tree # pip install dm_tree
5
+
6
+ from ray.rllib.connectors.connector_v2 import ConnectorV2
7
+ from ray.rllib.core.columns import Columns
8
+ from ray.rllib.core.rl_module.rl_module import RLModule
9
+ from ray.rllib.utils.annotations import override
10
+ from ray.rllib.utils.typing import EpisodeType
11
+ from ray.util.annotations import PublicAPI
12
+
13
+
14
@PublicAPI(stability="alpha")
class RemoveSingleTsTimeRankFromBatch(ConnectorV2):
    """Removes a single-timestep time rank (axis=0) previously added to the batch.

    Only acts if an upstream piece recorded `_added_single_ts_time_rank=True` in
    `shared_data`; otherwise the batch passes through unchanged.

    Note: This is one of the default module-to-env ConnectorV2 pieces that
    are added automatically by RLlib into every module-to-env connector pipeline,
    unless `config.add_default_connectors_to_module_to_env_pipeline` is set to
    False.

    The default module-to-env connector pipeline is:
    [
        GetActions,
        TensorToNumpy,
        UnBatchToIndividualItems,
        ModuleToAgentUnmapping,  # only in multi-agent setups!
        RemoveSingleTsTimeRankFromBatch,

        [0 or more user defined ConnectorV2 pieces],

        NormalizeAndClipActions,
        ListifyDataForVectorEnv,
    ]

    """

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Optional[Dict[str, Any]],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        shared_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        # If single ts time-rank had not been added, early out.
        if shared_data is None or not shared_data.get("_added_single_ts_time_rank"):
            return batch

        def _remove_single_ts(item, eps_id, aid, mid):
            # Only remove time-rank for modules that are stateful (only for those
            # has a time-rank been added).
            if mid is None or rl_module[mid].is_stateful():
                # Squeeze away the leading (time) axis on every leaf.
                return tree.map_structure(lambda s: np.squeeze(s, axis=0), item)
            return item

        # NOTE(review): `column_data` is unused; iterating a shallow copy of the
        # batch guards against in-place changes by the helper below.
        for column, column_data in batch.copy().items():
            # Skip state_out (doesn't have a time rank).
            if column == Columns.STATE_OUT:
                continue
            self.foreach_batch_item_change_in_place(
                batch,
                column=column,
                func=_remove_single_ts,
            )

        return batch
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/unbatch_to_individual_items.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import tree # pip install dm_tree
5
+
6
+ from ray.rllib.connectors.connector_v2 import ConnectorV2
7
+ from ray.rllib.core.rl_module.rl_module import RLModule
8
+ from ray.rllib.utils.annotations import override
9
+ from ray.rllib.utils.spaces.space_utils import unbatch as unbatch_fn
10
+ from ray.rllib.utils.typing import EpisodeType
11
+ from ray.util.annotations import PublicAPI
12
+
13
+
14
@PublicAPI(stability="alpha")
class UnBatchToIndividualItems(ConnectorV2):
    """Unbatches the given `data` back into the individual-batch-items format.

    Note: This is one of the default module-to-env ConnectorV2 pieces that
    are added automatically by RLlib into every module-to-env connector pipeline,
    unless `config.add_default_connectors_to_module_to_env_pipeline` is set to
    False.

    The default module-to-env connector pipeline is:
    [
        GetActions,
        TensorToNumpy,
        UnBatchToIndividualItems,
        ModuleToAgentUnmapping,  # only in multi-agent setups!
        RemoveSingleTsTimeRankFromBatch,

        [0 or more user defined ConnectorV2 pieces],

        NormalizeAndClipActions,
        ListifyDataForVectorEnv,
    ]
    """

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Dict[str, Any],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        shared_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        # NOTE(review): `shared_data` is accessed unconditionally here, so this
        # piece assumes the pipeline always passes a dict (unlike other pieces,
        # which guard against None) - confirm against the pipeline caller.
        memorized_map_structure = shared_data.get("memorized_map_structure")

        # Simple case (no structure stored): Just unbatch.
        if memorized_map_structure is None:
            return tree.map_structure(lambda s: unbatch_fn(s), batch)
        # Single agent case: Memorized structure is a list, whose indices map to
        # eps_id values.
        elif isinstance(memorized_map_structure, list):
            for column, column_data in batch.copy().items():
                column_data = unbatch_fn(column_data)
                new_column_data = defaultdict(list)
                for i, eps_id in enumerate(memorized_map_structure):
                    # Keys are always tuples to resemble multi-agent keys, which
                    # have the structure (eps_id, agent_id, module_id).
                    key = (eps_id,)
                    new_column_data[key].append(column_data[i])
                batch[column] = dict(new_column_data)
        # Multi-agent case: Memorized structure is dict mapping module_ids to lists of
        # (eps_id, agent_id)-tuples, such that the original individual-items-based form
        # can be constructed.
        else:
            # In this case, top-level batch keys are ModuleIDs.
            for module_id, module_data in batch.copy().items():
                if module_id not in memorized_map_structure:
                    raise KeyError(
                        f"ModuleID={module_id} not found in `memorized_map_structure`!"
                    )
                for column, column_data in module_data.items():
                    column_data = unbatch_fn(column_data)
                    new_column_data = defaultdict(list)
                    for i, (eps_id, agent_id) in enumerate(
                        memorized_map_structure[module_id]
                    ):
                        key = (eps_id, agent_id, module_id)
                        # TODO (sven): Support vectorization for MultiAgentEnvRunner.
                        #  AgentIDs whose SingleAgentEpisodes are already done, should
                        #  not send any data back to the EnvRunner for further
                        #  processing.
                        if episodes[0].agent_episodes[agent_id].is_done:
                            continue

                        new_column_data[key].append(column_data[i])
                    module_data[column] = dict(new_column_data)

        return batch
.venv/lib/python3.11/site-packages/ray/rllib/connectors/registry.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Registry of connector names for global access."""
2
+ from typing import Any
3
+
4
+ from ray.rllib.utils.annotations import OldAPIStack
5
+ from ray.rllib.connectors.connector import Connector, ConnectorContext
6
+
7
+
8
+ ALL_CONNECTORS = dict()
9
+
10
+
11
+ @OldAPIStack
12
+ def register_connector(name: str, cls: Connector):
13
+ """Register a connector for use with RLlib.
14
+
15
+ Args:
16
+ name: Name to register.
17
+ cls: Callable that creates an env.
18
+ """
19
+ if name in ALL_CONNECTORS:
20
+ return
21
+
22
+ if not issubclass(cls, Connector):
23
+ raise TypeError("Can only register Connector type.", cls)
24
+
25
+ # Record it in local registry in case we need to register everything
26
+ # again in the global registry, for example in the event of cluster
27
+ # restarts.
28
+ ALL_CONNECTORS[name] = cls
29
+
30
+
31
@OldAPIStack
def get_connector(name: str, ctx: ConnectorContext, params: Any = None) -> Connector:
    # TODO(jungong) : switch the order of parameters man!!
    """Get a connector by its name and serialized config.

    Args:
        name: name of the connector.
        ctx: Connector context.
        params: serialized parameters of the connector.

    Returns:
        Constructed connector (restored via the class' `from_state`).

    Raises:
        NameError: If `name` was never registered via `register_connector`.
    """
    if name not in ALL_CONNECTORS:
        raise NameError("connector not found.", name)
    return ALL_CONNECTORS[name].from_state(ctx, params)
.venv/lib/python3.11/site-packages/ray/rllib/connectors/util.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, Tuple, TYPE_CHECKING
3
+
4
+ from ray.rllib.connectors.action.clip import ClipActionsConnector
5
+ from ray.rllib.connectors.action.immutable import ImmutableActionsConnector
6
+ from ray.rllib.connectors.action.lambdas import ConvertToNumpyConnector
7
+ from ray.rllib.connectors.action.normalize import NormalizeActionsConnector
8
+ from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline
9
+ from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
10
+ from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
11
+ from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
12
+ from ray.rllib.connectors.agent.state_buffer import StateBufferConnector
13
+ from ray.rllib.connectors.agent.view_requirement import ViewRequirementAgentConnector
14
+ from ray.rllib.connectors.connector import Connector, ConnectorContext
15
+ from ray.rllib.connectors.registry import get_connector
16
+ from ray.rllib.connectors.agent.mean_std_filter import (
17
+ MeanStdObservationFilterAgentConnector,
18
+ ConcurrentMeanStdObservationFilterAgentConnector,
19
+ )
20
+ from ray.rllib.utils.annotations import OldAPIStack
21
+ from ray.rllib.connectors.agent.synced_filter import SyncedFilterAgentConnector
22
+
23
+ if TYPE_CHECKING:
24
+ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
25
+ from ray.rllib.policy.policy import Policy
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
def __preprocessing_enabled(config: "AlgorithmConfig"):
    """Returns whether built-in observation preprocessing should be applied.

    Mirrors the conditions checked in `RolloutWorker.__init__`.
    """
    # Explicitly disabled via the experimental preprocessor API flag?
    if config._disable_preprocessor_api:
        return False
    # Same conditions as in RolloutWorker.__init__: Atari + "deepmind"
    # preprocessing is handled elsewhere, and no preference means disabled.
    deepmind_atari = config.is_atari and config.preprocessor_pref == "deepmind"
    return not (deepmind_atari or config.preprocessor_pref is None)
39
+
40
+
41
def __clip_rewards(config: "AlgorithmConfig"):
    # Same logic as in RolloutWorker.__init__.
    # We always clip rewards for Atari games.
    # NOTE: The return value is not necessarily a bool: `config.clip_rewards`
    # may itself be a float limit, which is handled by the caller
    # (`get_agent_connectors_from_config`).
    return config.clip_rewards or config.is_atari
45
+
46
+
47
@OldAPIStack
def get_agent_connectors_from_config(
    ctx: ConnectorContext,
    config: "AlgorithmConfig",
) -> AgentConnectorPipeline:
    """Builds the default agent connector pipeline for a new policy.

    Args:
        ctx: Context used to create connectors.
        config: The AlgorithmConfig object.

    Returns:
        An AgentConnectorPipeline with reward clipping, (optional) obs
        preprocessing, (optional) synced filtering, state buffering and
        view-requirement handling.
    """
    connectors = []

    # `clip_rewards` may be True (sign clipping), a float (limit clipping),
    # or falsy (no clipping).
    clip_rewards = __clip_rewards(config)
    if clip_rewards is True:
        connectors.append(ClipRewardAgentConnector(ctx, sign=True))
    elif type(clip_rewards) is float:
        connectors.append(ClipRewardAgentConnector(ctx, limit=abs(clip_rewards)))

    if __preprocessing_enabled(config):
        connectors.append(ObsPreprocessorConnector(ctx))

    # Filters should be after observation preprocessing
    filter_connector = get_synced_filter_connector(
        ctx,
    )
    # Configuration option "NoFilter" results in `filter_connector==None`.
    if filter_connector:
        connectors.append(filter_connector)

    connectors.extend(
        [
            StateBufferConnector(ctx),
            ViewRequirementAgentConnector(ctx),
        ]
    )

    return AgentConnectorPipeline(ctx, connectors)
79
+
80
+
81
@OldAPIStack
def get_action_connectors_from_config(
    ctx: ConnectorContext,
    config: "AlgorithmConfig",
) -> ActionConnectorPipeline:
    """Default list of action connectors to use for a new policy.

    Args:
        ctx: context used to create connectors.
        config: The AlgorithmConfig object.

    Returns:
        An ActionConnectorPipeline that converts actions to numpy, optionally
        normalizes and/or clips them, and finally makes them immutable.
    """
    connectors = [ConvertToNumpyConnector(ctx)]
    if config.get("normalize_actions", False):
        connectors.append(NormalizeActionsConnector(ctx))
    if config.get("clip_actions", False):
        connectors.append(ClipActionsConnector(ctx))
    connectors.append(ImmutableActionsConnector(ctx))
    return ActionConnectorPipeline(ctx, connectors)
99
+
100
+
101
@OldAPIStack
def create_connectors_for_policy(policy: "Policy", config: "AlgorithmConfig"):
    """Util to create agent and action connectors for a Policy.

    Mutates `policy` in place by assigning its `agent_connectors` and
    `action_connectors` attributes.

    Args:
        policy: Policy instance.
        config: The AlgorithmConfig object.

    Raises:
        AssertionError: If the policy already has connectors attached.
    """
    ctx: ConnectorContext = ConnectorContext.from_policy(policy)

    assert (
        policy.agent_connectors is None and policy.action_connectors is None
    ), "Can not create connectors for a policy that already has connectors."

    policy.agent_connectors = get_agent_connectors_from_config(ctx, config)
    policy.action_connectors = get_action_connectors_from_config(ctx, config)

    logger.info("Using connectors:")
    logger.info(policy.agent_connectors.__str__(indentation=4))
    logger.info(policy.action_connectors.__str__(indentation=4))
121
+
122
+
123
@OldAPIStack
def restore_connectors_for_policy(
    policy: "Policy", connector_config: Tuple[str, Tuple[Any]]
) -> Connector:
    """Util to create connector for a Policy based on serialized config.

    Args:
        policy: Policy instance.
        connector_config: Serialized connector config, a
            (registered_name, serialized_params) tuple.

    Returns:
        The restored Connector (looked up by name in the global registry).
    """
    ctx: ConnectorContext = ConnectorContext.from_policy(policy)
    name, params = connector_config
    return get_connector(name, ctx, params)
136
+
137
+
138
# We need this filter selection mechanism temporarily to remain compatible to old API
@OldAPIStack
def get_synced_filter_connector(ctx: ConnectorContext):
    """Creates a synced observation-filter connector from the config.

    Args:
        ctx: Connector context; its `config` must provide the
            "observation_filter" setting.

    Returns:
        A new filter agent connector instance, or None if "NoFilter" is
        configured.

    Raises:
        ValueError: If the configured "observation_filter" is unknown.
    """
    filter_specifier = ctx.config.get("observation_filter")
    if filter_specifier == "MeanStdFilter":
        return MeanStdObservationFilterAgentConnector(ctx, clip=None)
    elif filter_specifier == "ConcurrentMeanStdFilter":
        return ConcurrentMeanStdObservationFilterAgentConnector(ctx, clip=None)
    elif filter_specifier == "NoFilter":
        return None
    else:
        # Raise a specific exception type instead of bare `Exception`.
        # ValueError is a subclass of Exception, so existing
        # `except Exception` callers keep working.
        raise ValueError("Unknown observation_filter: " + str(filter_specifier))
150
+
151
+
152
@OldAPIStack
def maybe_get_filters_for_syncing(rollout_worker, policy_id):
    """Registers a policy's synced filter (if any) with its RolloutWorker.

    As long as the historic filter synchronization mechanism is in place, we
    need to put filters into `rollout_worker.filters` so that they get
    synchronized. No-op if the policy has no agent connectors or no
    SyncedFilterAgentConnector among them.
    """
    policy = rollout_worker.policy_map[policy_id]
    if not policy.agent_connectors:
        return

    # Indexing the pipeline by type returns all connectors of that type.
    filter_connectors = policy.agent_connectors[SyncedFilterAgentConnector]
    # There can only be one filter at a time
    if not filter_connectors:
        return

    assert len(filter_connectors) == 1, (
        "ConnectorPipeline has multiple connectors of type "
        "SyncedFilterAgentConnector but can only have one."
    )
    rollout_worker.filters[policy_id] = filter_connectors[0].filter
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.core.learner.learner import Learner
2
+ from ray.rllib.core.learner.learner_group import LearnerGroup
3
+
4
+
5
+ __all__ = [
6
+ "Learner",
7
+ "LearnerGroup",
8
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (400 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/learner.cpython-311.pyc ADDED
Binary file (75.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/learner_group.cpython-311.pyc ADDED
Binary file (44.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/utils.cpython-311.pyc ADDED
Binary file (2.59 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/learner.py ADDED
@@ -0,0 +1,1795 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from collections import defaultdict
3
+ import copy
4
+ import logging
5
+ import numpy
6
+ import platform
7
+ from typing import (
8
+ Any,
9
+ Callable,
10
+ Collection,
11
+ Dict,
12
+ List,
13
+ Hashable,
14
+ Optional,
15
+ Sequence,
16
+ Tuple,
17
+ TYPE_CHECKING,
18
+ Union,
19
+ )
20
+
21
+ import tree # pip install dm_tree
22
+
23
+ import ray
24
+ from ray.data.iterator import DataIterator
25
+ from ray.rllib.connectors.learner.learner_connector_pipeline import (
26
+ LearnerConnectorPipeline,
27
+ )
28
+ from ray.rllib.core import (
29
+ COMPONENT_METRICS_LOGGER,
30
+ COMPONENT_OPTIMIZER,
31
+ COMPONENT_RL_MODULE,
32
+ DEFAULT_MODULE_ID,
33
+ )
34
+ from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI
35
+ from ray.rllib.core.rl_module import validate_module_id
36
+ from ray.rllib.core.rl_module.multi_rl_module import (
37
+ MultiRLModule,
38
+ MultiRLModuleSpec,
39
+ )
40
+ from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
41
+ from ray.rllib.policy.policy import PolicySpec
42
+ from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
43
+ from ray.rllib.utils.annotations import (
44
+ override,
45
+ OverrideToImplementCustomLogic,
46
+ OverrideToImplementCustomLogic_CallToSuperRecommended,
47
+ )
48
+ from ray.rllib.utils.checkpoints import Checkpointable
49
+ from ray.rllib.utils.debug import update_global_seed_if_necessary
50
+ from ray.rllib.utils.deprecation import (
51
+ Deprecated,
52
+ DEPRECATED_VALUE,
53
+ deprecation_warning,
54
+ )
55
+ from ray.rllib.utils.framework import try_import_tf, try_import_torch
56
+ from ray.rllib.utils.metrics import (
57
+ ALL_MODULES,
58
+ NUM_ENV_STEPS_SAMPLED_LIFETIME,
59
+ NUM_ENV_STEPS_TRAINED,
60
+ NUM_ENV_STEPS_TRAINED_LIFETIME,
61
+ NUM_MODULE_STEPS_TRAINED,
62
+ NUM_MODULE_STEPS_TRAINED_LIFETIME,
63
+ MODULE_TRAIN_BATCH_SIZE_MEAN,
64
+ WEIGHTS_SEQ_NO,
65
+ )
66
+ from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
67
+ from ray.rllib.utils.minibatch_utils import (
68
+ MiniBatchDummyIterator,
69
+ MiniBatchCyclicIterator,
70
+ )
71
+ from ray.rllib.utils.numpy import convert_to_numpy
72
+ from ray.rllib.utils.schedules.scheduler import Scheduler
73
+ from ray.rllib.utils.typing import (
74
+ EpisodeType,
75
+ LearningRateOrSchedule,
76
+ ModuleID,
77
+ Optimizer,
78
+ Param,
79
+ ParamRef,
80
+ ParamDict,
81
+ ResultDict,
82
+ ShouldModuleBeUpdatedFn,
83
+ StateDict,
84
+ TensorType,
85
+ )
86
+ from ray.util.annotations import PublicAPI
87
+
88
+ if TYPE_CHECKING:
89
+ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
90
+
91
+
92
+ torch, _ = try_import_torch()
93
+ tf1, tf, tfv = try_import_tf()
94
+
95
+ logger = logging.getLogger(__name__)
96
+
97
+ DEFAULT_OPTIMIZER = "default_optimizer"
98
+
99
+ # COMMON LEARNER LOSS_KEYS
100
+ POLICY_LOSS_KEY = "policy_loss"
101
+ VF_LOSS_KEY = "vf_loss"
102
+ ENTROPY_KEY = "entropy"
103
+
104
+ # Additional update keys
105
+ LR_KEY = "learning_rate"
106
+
107
+
108
+ @PublicAPI(stability="alpha")
109
+ class Learner(Checkpointable):
110
+ """Base class for Learners.
111
+
112
+ This class will be used to train RLModules. It is responsible for defining the loss
113
+ function, and updating the neural network weights that it owns. It also provides a
114
+ way to add/remove modules to/from RLModules in a multi-agent scenario, in the
115
+ middle of training (This is useful for league based training).
116
+
117
+ TF and Torch specific implementation of this class fills in the framework-specific
118
+ implementation details for distributed training, and for computing and applying
119
+ gradients. User should not need to sub-class this class, but instead inherit from
120
+ the TF or Torch specific sub-classes to implement their algorithm-specific update
121
+ logic.
122
+
123
+ Args:
124
+ config: The AlgorithmConfig object from which to derive most of the settings
125
+ needed to build the Learner.
126
+ module_spec: The module specification for the RLModule that is being trained.
127
+ If the module is a single agent module, after building the module it will
128
+ be converted to a multi-agent module with a default key. Can be none if the
129
+ module is provided directly via the `module` argument. Refer to
130
+ ray.rllib.core.rl_module.RLModuleSpec
131
+ or ray.rllib.core.rl_module.MultiRLModuleSpec for more info.
132
+ module: If learner is being used stand-alone, the RLModule can be optionally
133
+ passed in directly instead of the through the `module_spec`.
134
+
135
+ Note: We use PPO and torch as an example here because many of the showcased
136
+ components need implementations to come together. However, the same
137
+ pattern is generally applicable.
138
+
139
+ .. testcode::
140
+
141
+ import gymnasium as gym
142
+
143
+ from ray.rllib.algorithms.ppo.ppo import PPOConfig
144
+ from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
145
+ from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
146
+ PPOTorchRLModule
147
+ )
148
+ from ray.rllib.core import COMPONENT_RL_MODULE, DEFAULT_MODULE_ID
149
+ from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
150
+ from ray.rllib.core.rl_module.rl_module import RLModuleSpec
151
+
152
+ env = gym.make("CartPole-v1")
153
+
154
+ # Create a PPO config object first.
155
+ config = (
156
+ PPOConfig()
157
+ .framework("torch")
158
+ .training(model={"fcnet_hiddens": [128, 128]})
159
+ )
160
+
161
+ # Create a learner instance directly from our config. All we need as
162
+ # extra information here is the env to be able to extract space information
163
+ # (needed to construct the RLModule inside the Learner).
164
+ learner = config.build_learner(env=env)
165
+
166
+ # Take one gradient update on the module and report the results.
167
+ # results = learner.update(...)
168
+
169
+ # Add a new module, perhaps for league based training.
170
+ learner.add_module(
171
+ module_id="new_player",
172
+ module_spec=RLModuleSpec(
173
+ module_class=PPOTorchRLModule,
174
+ observation_space=env.observation_space,
175
+ action_space=env.action_space,
176
+ model_config=DefaultModelConfig(fcnet_hiddens=[64, 64]),
177
+ catalog_class=PPOCatalog,
178
+ )
179
+ )
180
+
181
+ # Take another gradient update with both previous and new modules.
182
+ # results = learner.update(...)
183
+
184
+ # Remove a module.
185
+ learner.remove_module("new_player")
186
+
187
+ # Will train previous modules only.
188
+ # results = learner.update(...)
189
+
190
+ # Get the state of the learner.
191
+ state = learner.get_state()
192
+
193
+ # Set the state of the learner.
194
+ learner.set_state(state)
195
+
196
+ # Get the weights of the underlying MultiRLModule.
197
+ weights = learner.get_state(components=COMPONENT_RL_MODULE)
198
+
199
+ # Set the weights of the underlying MultiRLModule.
200
+ learner.set_state({COMPONENT_RL_MODULE: weights})
201
+
202
+
203
+ Extension pattern:
204
+
205
+ .. testcode::
206
+
207
+ from ray.rllib.core.learner.torch.torch_learner import TorchLearner
208
+
209
+ class MyLearner(TorchLearner):
210
+
211
+ def compute_losses(self, fwd_out, batch):
212
+ # Compute the losses per module based on `batch` and output of the
213
+ # forward pass (`fwd_out`). To access the (algorithm) config for a
214
+ # specific RLModule, do:
215
+ # `self.config.get_config_for_module([moduleID])`.
216
+ return {DEFAULT_MODULE_ID: module_loss}
217
+ """
218
+
219
+ framework: str = None
220
+ TOTAL_LOSS_KEY: str = "total_loss"
221
+
222
+ def __init__(
223
+ self,
224
+ *,
225
+ config: "AlgorithmConfig",
226
+ module_spec: Optional[Union[RLModuleSpec, MultiRLModuleSpec]] = None,
227
+ module: Optional[RLModule] = None,
228
+ ):
229
+ # TODO (sven): Figure out how to do this
230
+ self.config = config.copy(copy_frozen=False)
231
+ self._module_spec: Optional[MultiRLModuleSpec] = module_spec
232
+ self._module_obj: Optional[MultiRLModule] = module
233
+
234
+ # Make node and device of this Learner available.
235
+ self._node = platform.node()
236
+ self._device = None
237
+
238
+ # Set a seed, if necessary.
239
+ if self.config.seed is not None:
240
+ update_global_seed_if_necessary(self.framework, self.config.seed)
241
+
242
+ # Whether self.build has already been called.
243
+ self._is_built = False
244
+
245
+ # These are the attributes that are set during build.
246
+
247
+ # The actual MultiRLModule used by this Learner.
248
+ self._module: Optional[MultiRLModule] = None
249
+ self._weights_seq_no = 0
250
+ # Our Learner connector pipeline.
251
+ self._learner_connector: Optional[LearnerConnectorPipeline] = None
252
+ # These are set for properly applying optimizers and adding or removing modules.
253
+ self._optimizer_parameters: Dict[Optimizer, List[ParamRef]] = {}
254
+ self._named_optimizers: Dict[str, Optimizer] = {}
255
+ self._params: ParamDict = {}
256
+ # Dict mapping ModuleID to a list of optimizer names. Note that the optimizer
257
+ # name includes the ModuleID as a prefix: optimizer_name=`[ModuleID]_[.. rest]`.
258
+ self._module_optimizers: Dict[ModuleID, List[str]] = defaultdict(list)
259
+ self._optimizer_name_to_module: Dict[str, ModuleID] = {}
260
+
261
+ # Only manage optimizer's learning rate if user has NOT overridden
262
+ # the `configure_optimizers_for_module` method. Otherwise, leave responsibility
263
+ # to handle lr-updates entirely in user's hands.
264
+ self._optimizer_lr_schedules: Dict[Optimizer, Scheduler] = {}
265
+
266
+ # The Learner's own MetricsLogger to be used to log RLlib's built-in metrics or
267
+ # custom user-defined ones (e.g. custom loss values). When returning from an
268
+ # `update_from_...()` method call, the Learner will do a `self.metrics.reduce()`
269
+ # and return the resulting (reduced) dict.
270
+ self.metrics = MetricsLogger()
271
+
272
+ # In case of offline learning and multiple learners, each learner receives a
273
+ # repeatable iterator that iterates over a split of the streamed data.
274
+ self.iterator: DataIterator = None
275
+
276
+ # TODO (sven): Do we really need this API? It seems like LearnerGroup constructs
277
+ # all Learner workers and then immediately builds them anyway? Unless there is
278
+ # a reason related to Train worker group setup.
279
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
280
+ def build(self) -> None:
281
+ """Builds the Learner.
282
+
283
+ This method should be called before the learner is used. It is responsible for
284
+ setting up the LearnerConnectorPipeline, the RLModule, optimizer(s), and
285
+ (optionally) the optimizers' learning rate schedulers.
286
+ """
287
+ if self._is_built:
288
+ logger.debug("Learner already built. Skipping build.")
289
+ return
290
+
291
+ # Build learner connector pipeline used on this Learner worker.
292
+ self._learner_connector = None
293
+ # If the Algorithm uses aggregation actors to run episodes through the learner
294
+ # connector, its Learners don't need a connector pipeline and instead learn
295
+ # directly from pre-loaded batches already on the GPU.
296
+ if self.config.num_aggregator_actors_per_learner == 0:
297
+ # TODO (sven): Figure out which space to provide here. For now,
298
+ # it doesn't matter, as the default connector piece doesn't use
299
+ # this information anyway.
300
+ # module_spec = self._module_spec.as_multi_rl_module_spec()
301
+ self._learner_connector = self.config.build_learner_connector(
302
+ input_observation_space=None,
303
+ input_action_space=None,
304
+ device=self._device,
305
+ )
306
+
307
+ # Build the module to be trained by this learner.
308
+ self._module = self._make_module()
309
+
310
+ # Configure, construct, and register all optimizers needed to train
311
+ # `self.module`.
312
+ self.configure_optimizers()
313
+
314
+ # Log the number of trainable/non-trainable parameters.
315
+ self._log_trainable_parameters()
316
+
317
+ self._is_built = True
318
+
319
+ @property
320
+ def distributed(self) -> bool:
321
+ """Whether the learner is running in distributed mode."""
322
+ return self.config.num_learners > 1
323
+
324
+ @property
325
+ def module(self) -> MultiRLModule:
326
+ """The MultiRLModule that is being trained."""
327
+ return self._module
328
+
329
+ @property
330
+ def node(self) -> Any:
331
+ return self._node
332
+
333
+ @property
334
+ def device(self) -> Any:
335
+ return self._device
336
+
337
+ def register_optimizer(
338
+ self,
339
+ *,
340
+ module_id: ModuleID = ALL_MODULES,
341
+ optimizer_name: str = DEFAULT_OPTIMIZER,
342
+ optimizer: Optimizer,
343
+ params: Sequence[Param],
344
+ lr_or_lr_schedule: Optional[LearningRateOrSchedule] = None,
345
+ ) -> None:
346
+ """Registers an optimizer with a ModuleID, name, param list and lr-scheduler.
347
+
348
+ Use this method in your custom implementations of either
349
+ `self.configure_optimizers()` or `self.configure_optimizers_for_module()` (you
350
+ should only override one of these!). If you register a learning rate Scheduler
351
+ setting together with an optimizer, RLlib will automatically keep this
352
+ optimizer's learning rate updated throughout the training process.
353
+ Alternatively, you can construct your optimizers directly with a learning rate
354
+ and manage learning rate scheduling or updating yourself.
355
+
356
+ Args:
357
+ module_id: The `module_id` under which to register the optimizer. If not
358
+ provided, will assume ALL_MODULES.
359
+ optimizer_name: The name (str) of the optimizer. If not provided, will
360
+ assume DEFAULT_OPTIMIZER.
361
+ optimizer: The already instantiated optimizer object to register.
362
+ params: A list of parameters (framework-specific variables) that will be
363
+ trained/updated
364
+ lr_or_lr_schedule: An optional fixed learning rate or learning rate schedule
365
+ setup. If provided, RLlib will automatically keep the optimizer's
366
+ learning rate updated.
367
+ """
368
+ # Validate optimizer instance and its param list.
369
+ self._check_registered_optimizer(optimizer, params)
370
+
371
+ full_registration_name = module_id + "_" + optimizer_name
372
+
373
+ # Store the given optimizer under the given `module_id`.
374
+ self._module_optimizers[module_id].append(full_registration_name)
375
+ self._optimizer_name_to_module[full_registration_name] = module_id
376
+
377
+ # Store the optimizer instance under its full `module_id`_`optimizer_name`
378
+ # key.
379
+ self._named_optimizers[full_registration_name] = optimizer
380
+
381
+ # Store all given parameters under the given optimizer.
382
+ self._optimizer_parameters[optimizer] = []
383
+ for param in params:
384
+ param_ref = self.get_param_ref(param)
385
+ self._optimizer_parameters[optimizer].append(param_ref)
386
+ self._params[param_ref] = param
387
+
388
+ # Optionally, store a scheduler object along with this optimizer. If such a
389
+ # setting is provided, RLlib will handle updating the optimizer's learning rate
390
+ # over time.
391
+ if lr_or_lr_schedule is not None:
392
+ # Validate the given setting.
393
+ Scheduler.validate(
394
+ fixed_value_or_schedule=lr_or_lr_schedule,
395
+ setting_name="lr_or_lr_schedule",
396
+ description="learning rate or schedule",
397
+ )
398
+ # Create the scheduler object for this optimizer.
399
+ scheduler = Scheduler(
400
+ fixed_value_or_schedule=lr_or_lr_schedule,
401
+ framework=self.framework,
402
+ device=self._device,
403
+ )
404
+ self._optimizer_lr_schedules[optimizer] = scheduler
405
+ # Set the optimizer to the current (first) learning rate.
406
+ self._set_optimizer_lr(
407
+ optimizer=optimizer,
408
+ lr=scheduler.get_current_value(),
409
+ )
410
+
411
+ @OverrideToImplementCustomLogic
412
+ def configure_optimizers(self) -> None:
413
+ """Configures, creates, and registers the optimizers for this Learner.
414
+
415
+ Optimizers are responsible for updating the model's parameters during training,
416
+ based on the computed gradients.
417
+
418
+ Normally, you should not override this method for your custom algorithms
419
+ (which require certain optimizers), but rather override the
420
+ `self.configure_optimizers_for_module(module_id=..)` method and register those
421
+ optimizers in there that you need for the given `module_id`.
422
+
423
+ You can register an optimizer for any RLModule within `self.module` (or for
424
+ the ALL_MODULES ID) by calling `self.register_optimizer()` and passing the
425
+ module_id, optimizer_name (only in case you would like to register more than
426
+ one optimizer for a given module), the optimizer instance itself, a list
427
+ of all the optimizer's parameters (to be updated by the optimizer), and
428
+ an optional learning rate or learning rate schedule setting.
429
+
430
+ This method is called once during building (`self.build()`).
431
+ """
432
+ # The default implementation simply calls `self.configure_optimizers_for_module`
433
+ # on each RLModule within `self.module`.
434
+ for module_id in self.module.keys():
435
+ if self.rl_module_is_compatible(self.module[module_id]):
436
+ config = self.config.get_config_for_module(module_id)
437
+ self.configure_optimizers_for_module(module_id=module_id, config=config)
438
+
439
+ @OverrideToImplementCustomLogic
440
+ @abc.abstractmethod
441
+ def configure_optimizers_for_module(
442
+ self, module_id: ModuleID, config: "AlgorithmConfig" = None
443
+ ) -> None:
444
+ """Configures an optimizer for the given module_id.
445
+
446
+ This method is called for each RLModule in the MultiRLModule being
447
+ trained by the Learner, as well as any new module added during training via
448
+ `self.add_module()`. It should configure and construct one or more optimizers
449
+ and register them via calls to `self.register_optimizer()` along with the
450
+ `module_id`, an optional optimizer name (str), a list of the optimizer's
451
+ framework specific parameters (variables), and an optional learning rate value
452
+ or -schedule.
453
+
454
+ Args:
455
+ module_id: The module_id of the RLModule that is being configured.
456
+ config: The AlgorithmConfig specific to the given `module_id`.
457
+ """
458
+
459
+ @OverrideToImplementCustomLogic
460
+ @abc.abstractmethod
461
+ def compute_gradients(
462
+ self, loss_per_module: Dict[ModuleID, TensorType], **kwargs
463
+ ) -> ParamDict:
464
+ """Computes the gradients based on the given losses.
465
+
466
+ Args:
467
+ loss_per_module: Dict mapping module IDs to their individual total loss
468
+ terms, computed by the individual `compute_loss_for_module()` calls.
469
+ The overall total loss (sum of loss terms over all modules) is stored
470
+ under `loss_per_module[ALL_MODULES]`.
471
+ **kwargs: Forward compatibility kwargs.
472
+
473
+ Returns:
474
+ The gradients in the same (flat) format as self._params. Note that all
475
+ top-level structures, such as module IDs, will not be present anymore in
476
+ the returned dict. It will merely map parameter tensor references to their
477
+ respective gradient tensors.
478
+ """
479
+
480
+ @OverrideToImplementCustomLogic
481
+ def postprocess_gradients(self, gradients_dict: ParamDict) -> ParamDict:
482
+ """Applies potential postprocessing operations on the gradients.
483
+
484
+ This method is called after gradients have been computed and modifies them
485
+ before they are applied to the respective module(s) by the optimizer(s).
486
+ This might include grad clipping by value, norm, or global-norm, or other
487
+ algorithm specific gradient postprocessing steps.
488
+
489
+ This default implementation calls `self.postprocess_gradients_for_module()`
490
+ on each of the sub-modules in our MultiRLModule: `self.module` and
491
+ returns the accumulated gradients dicts.
492
+
493
+ Args:
494
+ gradients_dict: A dictionary of gradients in the same (flat) format as
495
+ self._params. Note that top-level structures, such as module IDs,
496
+ will not be present anymore in this dict. It will merely map gradient
497
+ tensor references to gradient tensors.
498
+
499
+ Returns:
500
+ A dictionary with the updated gradients and the exact same (flat) structure
501
+ as the incoming `gradients_dict` arg.
502
+ """
503
+
504
+ # The flat gradients dict (mapping param refs to params), returned by this
505
+ # method.
506
+ postprocessed_gradients = {}
507
+
508
+ for module_id in self.module.keys():
509
+ # Send a gradients dict for only this `module_id` to the
510
+ # `self.postprocess_gradients_for_module()` method.
511
+ module_grads_dict = {}
512
+ for optimizer_name, optimizer in self.get_optimizers_for_module(module_id):
513
+ module_grads_dict.update(
514
+ self.filter_param_dict_for_optimizer(gradients_dict, optimizer)
515
+ )
516
+
517
+ module_grads_dict = self.postprocess_gradients_for_module(
518
+ module_id=module_id,
519
+ config=self.config.get_config_for_module(module_id),
520
+ module_gradients_dict=module_grads_dict,
521
+ )
522
+ assert isinstance(module_grads_dict, dict)
523
+
524
+ # Update our return dict.
525
+ postprocessed_gradients.update(module_grads_dict)
526
+
527
+ return postprocessed_gradients
528
+
529
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
530
+ def postprocess_gradients_for_module(
531
+ self,
532
+ *,
533
+ module_id: ModuleID,
534
+ config: Optional["AlgorithmConfig"] = None,
535
+ module_gradients_dict: ParamDict,
536
+ ) -> ParamDict:
537
+ """Applies postprocessing operations on the gradients of the given module.
538
+
539
+ Args:
540
+ module_id: The module ID for which we will postprocess computed gradients.
541
+ Note that `module_gradients_dict` already only carries those gradient
542
+ tensors that belong to this `module_id`. Other `module_id`'s gradients
543
+ are not available in this call.
544
+ config: The AlgorithmConfig specific to the given `module_id`.
545
+ module_gradients_dict: A dictionary of gradients in the same (flat) format
546
+ as self._params, mapping gradient refs to gradient tensors, which are to
547
+ be postprocessed. You may alter these tensors in place or create new
548
+ ones and return these in a new dict.
549
+
550
+ Returns:
551
+ A dictionary with the updated gradients and the exact same (flat) structure
552
+ as the incoming `module_gradients_dict` arg.
553
+ """
554
+ postprocessed_grads = {}
555
+
556
+ if config.grad_clip is None and not config.log_gradients:
557
+ postprocessed_grads.update(module_gradients_dict)
558
+ return postprocessed_grads
559
+
560
+ for optimizer_name, optimizer in self.get_optimizers_for_module(module_id):
561
+ grad_dict_to_clip = self.filter_param_dict_for_optimizer(
562
+ param_dict=module_gradients_dict,
563
+ optimizer=optimizer,
564
+ )
565
+ if config.grad_clip:
566
+ # Perform gradient clipping, if configured.
567
+ global_norm = self._get_clip_function()(
568
+ grad_dict_to_clip,
569
+ grad_clip=config.grad_clip,
570
+ grad_clip_by=config.grad_clip_by,
571
+ )
572
+ if config.grad_clip_by == "global_norm" or config.log_gradients:
573
+ # If we want to log gradients, but do not use the global norm
574
+ # for clipping compute it here.
575
+ if config.log_gradients and config.grad_clip_by != "global_norm":
576
+ # Compute the global norm of gradients.
577
+ global_norm = self._get_global_norm_function()(
578
+ # Note, `tf.linalg.global_norm` needs a list of tensors.
579
+ list(grad_dict_to_clip.values()),
580
+ )
581
+ self.metrics.log_value(
582
+ key=(module_id, f"gradients_{optimizer_name}_global_norm"),
583
+ value=global_norm,
584
+ window=1,
585
+ )
586
+ postprocessed_grads.update(grad_dict_to_clip)
587
+ # In the other case check, if we want to log gradients only.
588
+ elif config.log_gradients:
589
+ # Compute the global norm of gradients and log it.
590
+ global_norm = self._get_global_norm_function()(
591
+ # Note, `tf.linalg.global_norm` needs a list of tensors.
592
+ list(grad_dict_to_clip.values()),
593
+ )
594
+ self.metrics.log_value(
595
+ key=(module_id, f"gradients_{optimizer_name}_global_norm"),
596
+ value=global_norm,
597
+ window=1,
598
+ )
599
+
600
+ return postprocessed_grads
601
+
602
+ @OverrideToImplementCustomLogic
603
+ @abc.abstractmethod
604
+ def apply_gradients(self, gradients_dict: ParamDict) -> None:
605
+ """Applies the gradients to the MultiRLModule parameters.
606
+
607
+ Args:
608
+ gradients_dict: A dictionary of gradients in the same (flat) format as
609
+ self._params. Note that top-level structures, such as module IDs,
610
+ will not be present anymore in this dict. It will merely map gradient
611
+ tensor references to gradient tensors.
612
+ """
613
+
614
+ def get_optimizer(
615
+ self,
616
+ module_id: ModuleID = DEFAULT_MODULE_ID,
617
+ optimizer_name: str = DEFAULT_OPTIMIZER,
618
+ ) -> Optimizer:
619
+ """Returns the optimizer object, configured under the given module_id and name.
620
+
621
+ If only one optimizer was registered under `module_id` (or ALL_MODULES)
622
+ via the `self.register_optimizer` method, `optimizer_name` is assumed to be
623
+ DEFAULT_OPTIMIZER.
624
+
625
+ Args:
626
+ module_id: The ModuleID for which to return the configured optimizer.
627
+ If not provided, will assume DEFAULT_MODULE_ID.
628
+ optimizer_name: The name of the optimizer (registered under `module_id` via
629
+ `self.register_optimizer()`) to return. If not provided, will assume
630
+ DEFAULT_OPTIMIZER.
631
+
632
+ Returns:
633
+ The optimizer object, configured under the given `module_id` and
634
+ `optimizer_name`.
635
+ """
636
+ # `optimizer_name` could possibly be the full optimizer name (including the
637
+ # module_id under which it is registered).
638
+ if optimizer_name in self._named_optimizers:
639
+ return self._named_optimizers[optimizer_name]
640
+
641
+ # Normally, `optimizer_name` is just the optimizer's name, not including the
642
+ # `module_id`.
643
+ full_registration_name = module_id + "_" + optimizer_name
644
+ if full_registration_name in self._named_optimizers:
645
+ return self._named_optimizers[full_registration_name]
646
+
647
+ # No optimizer found.
648
+ raise KeyError(
649
+ f"Optimizer not found! module_id={module_id} "
650
+ f"optimizer_name={optimizer_name}"
651
+ )
652
+
653
+ def get_optimizers_for_module(
654
+ self, module_id: ModuleID = ALL_MODULES
655
+ ) -> List[Tuple[str, Optimizer]]:
656
+ """Returns a list of (optimizer_name, optimizer instance)-tuples for module_id.
657
+
658
+ Args:
659
+ module_id: The ModuleID for which to return the configured
660
+ (optimizer name, optimizer)-pairs. If not provided, will return
661
+ optimizers registered under ALL_MODULES.
662
+
663
+ Returns:
664
+ A list of tuples of the format: ([optimizer_name], [optimizer object]),
665
+ where optimizer_name is the name under which the optimizer was registered
666
+ in `self.register_optimizer`. If only a single optimizer was
667
+ configured for `module_id`, [optimizer_name] will be DEFAULT_OPTIMIZER.
668
+ """
669
+ named_optimizers = []
670
+ for full_registration_name in self._module_optimizers[module_id]:
671
+ optimizer = self._named_optimizers[full_registration_name]
672
+ # TODO (sven): How can we avoid registering optimizers under this
673
+ # constructed `[module_id]_[optim_name]` format?
674
+ optim_name = full_registration_name[len(module_id) + 1 :]
675
+ named_optimizers.append((optim_name, optimizer))
676
+ return named_optimizers
677
+
678
+ def filter_param_dict_for_optimizer(
679
+ self, param_dict: ParamDict, optimizer: Optimizer
680
+ ) -> ParamDict:
681
+ """Reduces the given ParamDict to contain only parameters for given optimizer.
682
+
683
+ Args:
684
+ param_dict: The ParamDict to reduce/filter down to the given `optimizer`.
685
+ The returned dict will be a subset of `param_dict` only containing keys
686
+ (param refs) that were registered together with `optimizer` (and thus
687
+ that `optimizer` is responsible for applying gradients to).
688
+ optimizer: The optimizer object to whose parameter refs the given
689
+ `param_dict` should be reduced.
690
+
691
+ Returns:
692
+ A new ParamDict only containing param ref keys that belong to `optimizer`.
693
+ """
694
+ # Return a sub-dict only containing those param_ref keys (and their values)
695
+ # that belong to the `optimizer`.
696
+ return {
697
+ ref: param_dict[ref]
698
+ for ref in self._optimizer_parameters[optimizer]
699
+ if ref in param_dict and param_dict[ref] is not None
700
+ }
701
+
702
+ @abc.abstractmethod
703
+ def get_param_ref(self, param: Param) -> Hashable:
704
+ """Returns a hashable reference to a trainable parameter.
705
+
706
+ This should be overridden in framework specific specialization. For example in
707
+ torch it will return the parameter itself, while in tf it returns the .ref() of
708
+ the variable. The purpose is to retrieve a unique reference to the parameters.
709
+
710
+ Args:
711
+ param: The parameter to get the reference to.
712
+
713
+ Returns:
714
+ A reference to the parameter.
715
+ """
716
+
717
+ @abc.abstractmethod
718
+ def get_parameters(self, module: RLModule) -> Sequence[Param]:
719
+ """Returns the list of parameters of a module.
720
+
721
+ This should be overridden in framework specific learner. For example in torch it
722
+ will return .parameters(), while in tf it returns .trainable_variables.
723
+
724
+ Args:
725
+ module: The module to get the parameters from.
726
+
727
+ Returns:
728
+ The parameters of the module.
729
+ """
730
+
731
+ @abc.abstractmethod
732
+ def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch:
733
+ """Converts the elements of a MultiAgentBatch to Tensors on the correct device.
734
+
735
+ Args:
736
+ batch: The MultiAgentBatch object to convert.
737
+
738
+ Returns:
739
+ The resulting MultiAgentBatch with framework-specific tensor values placed
740
+ on the correct device.
741
+ """
742
+
743
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
744
+ def add_module(
745
+ self,
746
+ *,
747
+ module_id: ModuleID,
748
+ module_spec: RLModuleSpec,
749
+ config_overrides: Optional[Dict] = None,
750
+ new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
751
+ ) -> MultiRLModuleSpec:
752
+ """Adds a module to the underlying MultiRLModule.
753
+
754
+ Changes this Learner's config in order to make this architectural change
755
+ permanent wrt. to checkpointing.
756
+
757
+ Args:
758
+ module_id: The ModuleID of the module to be added.
759
+ module_spec: The ModuleSpec of the module to be added.
760
+ config_overrides: The `AlgorithmConfig` overrides that should apply to
761
+ the new Module, if any.
762
+ new_should_module_be_updated: An optional sequence of ModuleIDs or a
763
+ callable taking ModuleID and SampleBatchType and returning whether the
764
+ ModuleID should be updated (trained).
765
+ If None, will keep the existing setup in place. RLModules,
766
+ whose IDs are not in the list (or for which the callable
767
+ returns False) will not be updated.
768
+
769
+ Returns:
770
+ The new MultiRLModuleSpec (after the RLModule has been added).
771
+ """
772
+ validate_module_id(module_id, error=True)
773
+ self._check_is_built()
774
+
775
+ # Force-set inference-only = False.
776
+ module_spec = copy.deepcopy(module_spec)
777
+ module_spec.inference_only = False
778
+
779
+ # Build the new RLModule and add it to self.module.
780
+ module = module_spec.build()
781
+ self.module.add_module(module_id, module)
782
+
783
+ # Change our config (AlgorithmConfig) to contain the new Module.
784
+ # TODO (sven): This is a hack to manipulate the AlgorithmConfig directly,
785
+ # but we'll deprecate config.policies soon anyway.
786
+ self.config.policies[module_id] = PolicySpec()
787
+ if config_overrides is not None:
788
+ self.config.multi_agent(
789
+ algorithm_config_overrides_per_module={module_id: config_overrides}
790
+ )
791
+ self.config.rl_module(rl_module_spec=MultiRLModuleSpec.from_module(self.module))
792
+ self._module_spec = self.config.rl_module_spec
793
+ if new_should_module_be_updated is not None:
794
+ self.config.multi_agent(policies_to_train=new_should_module_be_updated)
795
+
796
+ # Allow the user to configure one or more optimizers for this new module.
797
+ self.configure_optimizers_for_module(
798
+ module_id=module_id,
799
+ config=self.config.get_config_for_module(module_id),
800
+ )
801
+ return self.config.rl_module_spec
802
+
803
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
804
+ def remove_module(
805
+ self,
806
+ module_id: ModuleID,
807
+ *,
808
+ new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
809
+ ) -> MultiRLModuleSpec:
810
+ """Removes a module from the Learner.
811
+
812
+ Args:
813
+ module_id: The ModuleID of the module to be removed.
814
+ new_should_module_be_updated: An optional sequence of ModuleIDs or a
815
+ callable taking ModuleID and SampleBatchType and returning whether the
816
+ ModuleID should be updated (trained).
817
+ If None, will keep the existing setup in place. RLModules,
818
+ whose IDs are not in the list (or for which the callable
819
+ returns False) will not be updated.
820
+
821
+ Returns:
822
+ The new MultiRLModuleSpec (after the RLModule has been removed).
823
+ """
824
+ self._check_is_built()
825
+ module = self.module[module_id]
826
+
827
+ # Delete the removed module's parameters and optimizers.
828
+ if self.rl_module_is_compatible(module):
829
+ parameters = self.get_parameters(module)
830
+ for param in parameters:
831
+ param_ref = self.get_param_ref(param)
832
+ if param_ref in self._params:
833
+ del self._params[param_ref]
834
+ for optimizer_name, optimizer in self.get_optimizers_for_module(module_id):
835
+ del self._optimizer_parameters[optimizer]
836
+ name = module_id + "_" + optimizer_name
837
+ del self._named_optimizers[name]
838
+ if optimizer in self._optimizer_lr_schedules:
839
+ del self._optimizer_lr_schedules[optimizer]
840
+ del self._module_optimizers[module_id]
841
+
842
+ # Remove the module from the MultiRLModule.
843
+ self.module.remove_module(module_id)
844
+
845
+ # Change self.config to reflect the new architecture.
846
+ # TODO (sven): This is a hack to manipulate the AlgorithmConfig directly,
847
+ # but we'll deprecate config.policies soon anyway.
848
+ del self.config.policies[module_id]
849
+ self.config.algorithm_config_overrides_per_module.pop(module_id, None)
850
+ if new_should_module_be_updated is not None:
851
+ self.config.multi_agent(policies_to_train=new_should_module_be_updated)
852
+ self.config.rl_module(rl_module_spec=MultiRLModuleSpec.from_module(self.module))
853
+
854
+ # Remove all stats from the module from our metrics logger, so we don't report
855
+ # results from this module again.
856
+ if module_id in self.metrics.stats:
857
+ del self.metrics.stats[module_id]
858
+
859
+ return self.config.rl_module_spec
860
+
861
+ @OverrideToImplementCustomLogic
862
+ def should_module_be_updated(self, module_id, multi_agent_batch=None):
863
+ """Returns whether a module should be updated or not based on `self.config`.
864
+
865
+ Args:
866
+ module_id: The ModuleID that we want to query on whether this module
867
+ should be updated or not.
868
+ multi_agent_batch: An optional MultiAgentBatch to possibly provide further
869
+ information on the decision on whether the RLModule should be updated
870
+ or not.
871
+ """
872
+ should_module_be_updated_fn = self.config.policies_to_train
873
+ # If None, return True (by default, all modules should be updated).
874
+ if should_module_be_updated_fn is None:
875
+ return True
876
+ # If collection given, return whether `module_id` is in that container.
877
+ elif not callable(should_module_be_updated_fn):
878
+ return module_id in set(should_module_be_updated_fn)
879
+
880
+ return should_module_be_updated_fn(module_id, multi_agent_batch)
881
+
882
+ @OverrideToImplementCustomLogic
883
+ def compute_losses(
884
+ self, *, fwd_out: Dict[str, Any], batch: Dict[str, Any]
885
+ ) -> Dict[str, Any]:
886
+ """Computes the loss(es) for the module being optimized.
887
+
888
+ This method must be overridden by MultiRLModule-specific Learners in order to
889
+ define the specific loss computation logic. If the algorithm is single-agent,
890
+ only `compute_loss_for_module()` should be overridden instead. If the algorithm
891
+ uses independent multi-agent learning (default behavior for RLlib's multi-agent
892
+ setups), also only `compute_loss_for_module()` should be overridden, but it will
893
+ be called for each individual RLModule inside the MultiRLModule.
894
+ It is recommended to not compute any forward passes within this method, and to
895
+ use the `forward_train()` outputs of the RLModule(s) to compute the required
896
+ loss tensors.
897
+ See here for a custom loss function example script:
898
+ https://github.com/ray-project/ray/blob/master/rllib/examples/learners/custom_loss_fn_simple.py # noqa
899
+
900
+ Args:
901
+ fwd_out: Output from a call to the `forward_train()` method of the
902
+ underlying MultiRLModule (`self.module`) during training
903
+ (`self.update()`).
904
+ batch: The train batch that was used to compute `fwd_out`.
905
+
906
+ Returns:
907
+ A dictionary mapping module IDs to individual loss terms.
908
+ """
909
+ loss_per_module = {}
910
+ for module_id in fwd_out:
911
+ module_batch = batch[module_id]
912
+ module_fwd_out = fwd_out[module_id]
913
+
914
+ module = self.module[module_id].unwrapped()
915
+ if isinstance(module, SelfSupervisedLossAPI):
916
+ loss = module.compute_self_supervised_loss(
917
+ learner=self,
918
+ module_id=module_id,
919
+ config=self.config.get_config_for_module(module_id),
920
+ batch=module_batch,
921
+ fwd_out=module_fwd_out,
922
+ )
923
+ else:
924
+ loss = self.compute_loss_for_module(
925
+ module_id=module_id,
926
+ config=self.config.get_config_for_module(module_id),
927
+ batch=module_batch,
928
+ fwd_out=module_fwd_out,
929
+ )
930
+ loss_per_module[module_id] = loss
931
+
932
+ return loss_per_module
933
+
934
+ @OverrideToImplementCustomLogic
935
+ @abc.abstractmethod
936
+ def compute_loss_for_module(
937
+ self,
938
+ *,
939
+ module_id: ModuleID,
940
+ config: "AlgorithmConfig",
941
+ batch: Dict[str, Any],
942
+ fwd_out: Dict[str, TensorType],
943
+ ) -> TensorType:
944
+ """Computes the loss for a single module.
945
+
946
+ Think of this as computing loss for a single agent. For multi-agent use-cases
947
+ that require more complicated computation for loss, consider overriding the
948
+ `compute_losses` method instead.
949
+
950
+ Args:
951
+ module_id: The id of the module.
952
+ config: The AlgorithmConfig specific to the given `module_id`.
953
+ batch: The train batch for this particular module.
954
+ fwd_out: The output of the forward pass for this particular module.
955
+
956
+ Returns:
957
+ A single total loss tensor. If you have more than one optimizer on the
958
+ provided `module_id` and would like to compute gradients separately using
959
+ these different optimizers, simply add up the individual loss terms for
960
+ each optimizer and return the sum. Also, for recording/logging any
961
+ individual loss terms, you can use the `Learner.metrics.log_value(
962
+ key=..., value=...)` or `Learner.metrics.log_dict()` APIs. See:
963
+ :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` for more
964
+ information.
965
+ """
966
+
967
+ def update_from_batch(
968
+ self,
969
+ batch: MultiAgentBatch,
970
+ *,
971
+ # TODO (sven): Make this a more formal structure with its own type.
972
+ timesteps: Optional[Dict[str, Any]] = None,
973
+ num_epochs: int = 1,
974
+ minibatch_size: Optional[int] = None,
975
+ shuffle_batch_per_epoch: bool = False,
976
+ # Deprecated args.
977
+ num_iters=DEPRECATED_VALUE,
978
+ **kwargs,
979
+ ) -> ResultDict:
980
+ """Run `num_epochs` epochs over the given train batch.
981
+
982
+ You can use this method to take more than one backward pass on the batch.
983
+ The same `minibatch_size` and `num_epochs` will be used for all module ids in
984
+ MultiRLModule.
985
+
986
+ Args:
987
+ batch: A batch of training data to update from.
988
+ timesteps: Timesteps dict, which must have the key
989
+ `NUM_ENV_STEPS_SAMPLED_LIFETIME`.
990
+ # TODO (sven): Make this a more formal structure with its own type.
991
+ num_epochs: The number of complete passes over the entire train batch. Each
992
+ pass might be further split into n minibatches (if `minibatch_size`
993
+ provided).
994
+ minibatch_size: The size of minibatches to use to further split the train
995
+ `batch` into sub-batches. The `batch` is then iterated over n times
996
+ where n is `len(batch) // minibatch_size`.
997
+ shuffle_batch_per_epoch: Whether to shuffle the train batch once per epoch.
998
+ If the train batch has a time rank (axis=1), shuffling will only take
999
+ place along the batch axis to not disturb any intact (episode)
1000
+ trajectories. Also, shuffling is always skipped if `minibatch_size` is
1001
+ None, meaning the entire train batch is processed each epoch, making it
1002
+ unnecessary to shuffle.
1003
+
1004
+ Returns:
1005
+ A `ResultDict` object produced by a call to `self.metrics.reduce()`. The
1006
+ returned dict may be arbitrarily nested and must have `Stats` objects at
1007
+ all its leafs, allowing components further downstream (i.e. a user of this
1008
+ Learner) to further reduce these results (for example over n parallel
1009
+ Learners).
1010
+ """
1011
+ if num_iters != DEPRECATED_VALUE:
1012
+ deprecation_warning(
1013
+ old="Learner.update_from_episodes(num_iters=...)",
1014
+ new="Learner.update_from_episodes(num_epochs=...)",
1015
+ error=True,
1016
+ )
1017
+ self._update_from_batch_or_episodes(
1018
+ batch=batch,
1019
+ timesteps=timesteps,
1020
+ num_epochs=num_epochs,
1021
+ minibatch_size=minibatch_size,
1022
+ shuffle_batch_per_epoch=shuffle_batch_per_epoch,
1023
+ )
1024
+ return self.metrics.reduce()
1025
+
1026
+ def update_from_episodes(
1027
+ self,
1028
+ episodes: List[EpisodeType],
1029
+ *,
1030
+ # TODO (sven): Make this a more formal structure with its own type.
1031
+ timesteps: Optional[Dict[str, Any]] = None,
1032
+ num_epochs: int = 1,
1033
+ minibatch_size: Optional[int] = None,
1034
+ shuffle_batch_per_epoch: bool = False,
1035
+ num_total_minibatches: int = 0,
1036
+ # Deprecated args.
1037
+ num_iters=DEPRECATED_VALUE,
1038
+ ) -> ResultDict:
1039
+ """Run `num_epochs` epochs over the train batch generated from `episodes`.
1040
+
1041
+ You can use this method to take more than one backward pass on the batch.
1042
+ The same `minibatch_size` and `num_epochs` will be used for all module ids in
1043
+ MultiRLModule.
1044
+
1045
+ Args:
1046
+ episodes: An list of episode objects to update from.
1047
+ timesteps: Timesteps dict, which must have the key
1048
+ `NUM_ENV_STEPS_SAMPLED_LIFETIME`.
1049
+ # TODO (sven): Make this a more formal structure with its own type.
1050
+ num_epochs: The number of complete passes over the entire train batch. Each
1051
+ pass might be further split into n minibatches (if `minibatch_size`
1052
+ provided). The train batch is generated from the given `episodes`
1053
+ through the Learner connector pipeline.
1054
+ minibatch_size: The size of minibatches to use to further split the train
1055
+ `batch` into sub-batches. The `batch` is then iterated over n times
1056
+ where n is `len(batch) // minibatch_size`. The train batch is generated
1057
+ from the given `episodes` through the Learner connector pipeline.
1058
+ shuffle_batch_per_epoch: Whether to shuffle the train batch once per epoch.
1059
+ If the train batch has a time rank (axis=1), shuffling will only take
1060
+ place along the batch axis to not disturb any intact (episode)
1061
+ trajectories. Also, shuffling is always skipped if `minibatch_size` is
1062
+ None, meaning the entire train batch is processed each epoch, making it
1063
+ unnecessary to shuffle. The train batch is generated from the given
1064
+ `episodes` through the Learner connector pipeline.
1065
+ num_total_minibatches: The total number of minibatches to loop through
1066
+ (over all `num_epochs` epochs). It's only required to set this to != 0
1067
+ in multi-agent + multi-GPU situations, in which the MultiAgentEpisodes
1068
+ themselves are roughly sharded equally, however, they might contain
1069
+ SingleAgentEpisodes with very lopsided length distributions. Thus,
1070
+ without this fixed, pre-computed value, one Learner might go through a
1071
+ different number of minibatche passes than others causing a deadlock.
1072
+
1073
+ Returns:
1074
+ A `ResultDict` object produced by a call to `self.metrics.reduce()`. The
1075
+ returned dict may be arbitrarily nested and must have `Stats` objects at
1076
+ all its leafs, allowing components further downstream (i.e. a user of this
1077
+ Learner) to further reduce these results (for example over n parallel
1078
+ Learners).
1079
+ """
1080
+ if num_iters != DEPRECATED_VALUE:
1081
+ deprecation_warning(
1082
+ old="Learner.update_from_episodes(num_iters=...)",
1083
+ new="Learner.update_from_episodes(num_epochs=...)",
1084
+ error=True,
1085
+ )
1086
+ self._update_from_batch_or_episodes(
1087
+ episodes=episodes,
1088
+ timesteps=timesteps,
1089
+ num_epochs=num_epochs,
1090
+ minibatch_size=minibatch_size,
1091
+ shuffle_batch_per_epoch=shuffle_batch_per_epoch,
1092
+ num_total_minibatches=num_total_minibatches,
1093
+ )
1094
+ return self.metrics.reduce()
1095
+
1096
+ def update_from_iterator(
1097
+ self,
1098
+ iterator,
1099
+ *,
1100
+ timesteps: Optional[Dict[str, Any]] = None,
1101
+ minibatch_size: Optional[int] = None,
1102
+ num_iters: int = None,
1103
+ **kwargs,
1104
+ ):
1105
+ if "num_epochs" in kwargs:
1106
+ raise ValueError(
1107
+ "`num_epochs` arg NOT supported by Learner.update_from_iterator! Use "
1108
+ "`num_iters` instead."
1109
+ )
1110
+
1111
+ if not self.iterator:
1112
+ self.iterator = iterator
1113
+
1114
+ self._check_is_built()
1115
+
1116
+ # Call `before_gradient_based_update` to allow for non-gradient based
1117
+ # preparations-, logging-, and update logic to happen.
1118
+ self.before_gradient_based_update(timesteps=timesteps or {})
1119
+
1120
+ def _finalize_fn(batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]:
1121
+ # Note, the incoming batch is a dictionary with a numpy array
1122
+ # holding the `MultiAgentBatch`.
1123
+ batch = self._convert_batch_type(batch["batch"][0])
1124
+ return {"batch": self._set_slicing_by_batch_id(batch, value=True)}
1125
+
1126
+ i = 0
1127
+ logger.debug(f"===> [Learner {id(self)}]: Looping through batches ... ")
1128
+ for batch in self.iterator.iter_batches(
1129
+ # Note, this needs to be one b/c data is already mapped to
1130
+ # `MultiAgentBatch`es of `minibatch_size`.
1131
+ batch_size=1,
1132
+ _finalize_fn=_finalize_fn,
1133
+ **kwargs,
1134
+ ):
1135
+ # Update the iteration counter.
1136
+ i += 1
1137
+
1138
+ # Note, `_finalize_fn` must return a dictionary.
1139
+ batch = batch["batch"]
1140
+ logger.debug(
1141
+ f"===> [Learner {id(self)}]: batch {i} with {batch.env_steps()} rows."
1142
+ )
1143
+ # Check the MultiAgentBatch, whether our RLModule contains all ModuleIDs
1144
+ # found in this batch. If not, throw an error.
1145
+ unknown_module_ids = set(batch.policy_batches.keys()) - set(
1146
+ self.module.keys()
1147
+ )
1148
+ if len(unknown_module_ids) > 0:
1149
+ raise ValueError(
1150
+ "Batch contains one or more ModuleIDs that are not in this "
1151
+ f"Learner! Found IDs: {unknown_module_ids}"
1152
+ )
1153
+
1154
+ # Log metrics.
1155
+ self._log_steps_trained_metrics(batch)
1156
+
1157
+ # Make the actual in-graph/traced `_update` call. This should return
1158
+ # all tensor values (no numpy).
1159
+ fwd_out, loss_per_module, tensor_metrics = self._update(
1160
+ batch.policy_batches
1161
+ )
1162
+ # Convert logged tensor metrics (logged during tensor-mode of MetricsLogger)
1163
+ # to actual (numpy) values.
1164
+ self.metrics.tensors_to_numpy(tensor_metrics)
1165
+
1166
+ self._set_slicing_by_batch_id(batch, value=False)
1167
+ # If `num_iters` is reached break and return.
1168
+ if num_iters and i == num_iters:
1169
+ break
1170
+
1171
+ logger.debug(
1172
+ f"===> [Learner {id(self)}] number of iterations run in this epoch: {i}"
1173
+ )
1174
+
1175
+ # Log all individual RLModules' loss terms and its registered optimizers'
1176
+ # current learning rates.
1177
+ for mid, loss in convert_to_numpy(loss_per_module).items():
1178
+ self.metrics.log_value(
1179
+ key=(mid, self.TOTAL_LOSS_KEY),
1180
+ value=loss,
1181
+ window=1,
1182
+ )
1183
+ # Call `after_gradient_based_update` to allow for non-gradient based
1184
+ # cleanups-, logging-, and update logic to happen.
1185
+ # TODO (simon): Check, if this should stay here, when running multiple
1186
+ # gradient steps inside the iterator loop above (could be a complete epoch)
1187
+ # the target networks might need to be updated earlier.
1188
+ self.after_gradient_based_update(timesteps=timesteps or {})
1189
+
1190
+ # Reduce results across all minibatch update steps.
1191
+ return self.metrics.reduce()
1192
+
1193
+ @OverrideToImplementCustomLogic
1194
+ @abc.abstractmethod
1195
+ def _update(
1196
+ self,
1197
+ batch: Dict[str, Any],
1198
+ **kwargs,
1199
+ ) -> Tuple[Any, Any, Any]:
1200
+ """Contains all logic for an in-graph/traceable update step.
1201
+
1202
+ Framework specific subclasses must implement this method. This should include
1203
+ calls to the RLModule's `forward_train`, `compute_loss`, compute_gradients`,
1204
+ `postprocess_gradients`, and `apply_gradients` methods and return a tuple
1205
+ with all the individual results.
1206
+
1207
+ Args:
1208
+ batch: The train batch already converted to a Dict mapping str to (possibly
1209
+ nested) tensors.
1210
+ kwargs: Forward compatibility kwargs.
1211
+
1212
+ Returns:
1213
+ A tuple consisting of:
1214
+ 1) The `forward_train()` output of the RLModule,
1215
+ 2) the loss_per_module dictionary mapping module IDs to individual loss
1216
+ tensors
1217
+ 3) a metrics dict mapping module IDs to metrics key/value pairs.
1218
+
1219
+ """
1220
+
1221
+ @override(Checkpointable)
1222
+ def get_state(
1223
+ self,
1224
+ components: Optional[Union[str, Collection[str]]] = None,
1225
+ *,
1226
+ not_components: Optional[Union[str, Collection[str]]] = None,
1227
+ **kwargs,
1228
+ ) -> StateDict:
1229
+ self._check_is_built()
1230
+
1231
+ state = {
1232
+ "should_module_be_updated": self.config.policies_to_train,
1233
+ }
1234
+
1235
+ if self._check_component(COMPONENT_RL_MODULE, components, not_components):
1236
+ state[COMPONENT_RL_MODULE] = self.module.get_state(
1237
+ components=self._get_subcomponents(COMPONENT_RL_MODULE, components),
1238
+ not_components=self._get_subcomponents(
1239
+ COMPONENT_RL_MODULE, not_components
1240
+ ),
1241
+ **kwargs,
1242
+ )
1243
+ state[WEIGHTS_SEQ_NO] = self._weights_seq_no
1244
+ if self._check_component(COMPONENT_OPTIMIZER, components, not_components):
1245
+ state[COMPONENT_OPTIMIZER] = self._get_optimizer_state()
1246
+
1247
+ if self._check_component(COMPONENT_METRICS_LOGGER, components, not_components):
1248
+ # TODO (sven): Make `MetricsLogger` a Checkpointable.
1249
+ state[COMPONENT_METRICS_LOGGER] = self.metrics.get_state()
1250
+
1251
+ return state
1252
+
1253
+ @override(Checkpointable)
1254
+ def set_state(self, state: StateDict) -> None:
1255
+ self._check_is_built()
1256
+
1257
+ weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0)
1258
+
1259
+ if COMPONENT_RL_MODULE in state:
1260
+ if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no:
1261
+ self.module.set_state(state[COMPONENT_RL_MODULE])
1262
+
1263
+ if COMPONENT_OPTIMIZER in state:
1264
+ self._set_optimizer_state(state[COMPONENT_OPTIMIZER])
1265
+
1266
+ # Update our weights_seq_no, if the new one is > 0.
1267
+ if weights_seq_no > 0:
1268
+ self._weights_seq_no = weights_seq_no
1269
+
1270
+ # Update our trainable Modules information/function via our config.
1271
+ # If not provided in state (None), all Modules will be trained by default.
1272
+ if "should_module_be_updated" in state:
1273
+ self.config.multi_agent(policies_to_train=state["should_module_be_updated"])
1274
+
1275
+ # TODO (sven): Make `MetricsLogger` a Checkpointable.
1276
+ if COMPONENT_METRICS_LOGGER in state:
1277
+ self.metrics.set_state(state[COMPONENT_METRICS_LOGGER])
1278
+
1279
+ @override(Checkpointable)
1280
+ def get_ctor_args_and_kwargs(self):
1281
+ return (
1282
+ (), # *args,
1283
+ {
1284
+ "config": self.config,
1285
+ "module_spec": self._module_spec,
1286
+ "module": self._module_obj,
1287
+ }, # **kwargs
1288
+ )
1289
+
1290
+ @override(Checkpointable)
1291
+ def get_checkpointable_components(self):
1292
+ if not self._check_is_built(error=False):
1293
+ self.build()
1294
+ return [
1295
+ (COMPONENT_RL_MODULE, self.module),
1296
+ ]
1297
+
1298
+ def _get_optimizer_state(self) -> StateDict:
1299
+ """Returns the state of all optimizers currently registered in this Learner.
1300
+
1301
+ Returns:
1302
+ The current state of all optimizers currently registered in this Learner.
1303
+ """
1304
+ raise NotImplementedError
1305
+
1306
+ def _set_optimizer_state(self, state: StateDict) -> None:
1307
+ """Sets the state of all optimizers currently registered in this Learner.
1308
+
1309
+ Args:
1310
+ state: The state of the optimizers.
1311
+ """
1312
+ raise NotImplementedError
1313
+
1314
+ def _update_from_batch_or_episodes(
1315
+ self,
1316
+ *,
1317
+ # TODO (sven): We should allow passing in a single agent batch here
1318
+ # as well for simplicity.
1319
+ batch: Optional[MultiAgentBatch] = None,
1320
+ episodes: Optional[List[EpisodeType]] = None,
1321
+ # TODO (sven): Make this a more formal structure with its own type.
1322
+ timesteps: Optional[Dict[str, Any]] = None,
1323
+ # TODO (sven): Deprecate these in favor of config attributes for only those
1324
+ # algos that actually need (and know how) to do minibatching.
1325
+ num_epochs: int = 1,
1326
+ minibatch_size: Optional[int] = None,
1327
+ shuffle_batch_per_epoch: bool = False,
1328
+ num_total_minibatches: int = 0,
1329
+ ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
1330
+
1331
+ self._check_is_built()
1332
+
1333
+ # Call `before_gradient_based_update` to allow for non-gradient based
1334
+ # preparations-, logging-, and update logic to happen.
1335
+ self.before_gradient_based_update(timesteps=timesteps or {})
1336
+
1337
+ # Resolve batch/episodes being ray object refs (instead of
1338
+ # actual batch/episodes objects).
1339
+ if isinstance(batch, ray.ObjectRef):
1340
+ batch = ray.get(batch)
1341
+ if isinstance(episodes, ray.ObjectRef):
1342
+ episodes = ray.get(episodes)
1343
+ elif isinstance(episodes, list) and isinstance(episodes[0], ray.ObjectRef):
1344
+ # It's possible that individual refs are invalid due to the EnvRunner
1345
+ # that produced the ref has crashed or had its entire node go down.
1346
+ # In this case, try each ref individually and collect only valid results.
1347
+ try:
1348
+ episodes = tree.flatten(ray.get(episodes))
1349
+ except ray.exceptions.OwnerDiedError:
1350
+ episode_refs = episodes
1351
+ episodes = []
1352
+ for ref in episode_refs:
1353
+ try:
1354
+ episodes.extend(ray.get(ref))
1355
+ except ray.exceptions.OwnerDiedError:
1356
+ pass
1357
+
1358
+ # Call the learner connector on the given `episodes` (if we have one).
1359
+ if episodes is not None and self._learner_connector is not None:
1360
+ # Call the learner connector pipeline.
1361
+ shared_data = {}
1362
+ batch = self._learner_connector(
1363
+ rl_module=self.module,
1364
+ batch=batch if batch is not None else {},
1365
+ episodes=episodes,
1366
+ shared_data=shared_data,
1367
+ metrics=self.metrics,
1368
+ )
1369
+ # Convert to a batch.
1370
+ # TODO (sven): Try to not require MultiAgentBatch anymore.
1371
+ batch = MultiAgentBatch(
1372
+ {
1373
+ module_id: (
1374
+ SampleBatch(module_data, _zero_padded=True)
1375
+ if shared_data.get(f"_zero_padded_for_mid={module_id}")
1376
+ else SampleBatch(module_data)
1377
+ )
1378
+ for module_id, module_data in batch.items()
1379
+ },
1380
+ env_steps=sum(len(e) for e in episodes),
1381
+ )
1382
+ # Single-agent SampleBatch: Have to convert to MultiAgentBatch.
1383
+ elif isinstance(batch, SampleBatch):
1384
+ assert len(self.module) == 1
1385
+ batch = MultiAgentBatch(
1386
+ {next(iter(self.module.keys())): batch}, env_steps=len(batch)
1387
+ )
1388
+
1389
+ # Check the MultiAgentBatch, whether our RLModule contains all ModuleIDs
1390
+ # found in this batch. If not, throw an error.
1391
+ unknown_module_ids = set(batch.policy_batches.keys()) - set(self.module.keys())
1392
+ if len(unknown_module_ids) > 0:
1393
+ raise ValueError(
1394
+ "Batch contains one or more ModuleIDs that are not in this Learner! "
1395
+ f"Found IDs: {unknown_module_ids}"
1396
+ )
1397
+
1398
+ # TODO: Move this into LearnerConnector pipeline?
1399
+ # Filter out those RLModules from the final train batch that should not be
1400
+ # updated.
1401
+ for module_id in list(batch.policy_batches.keys()):
1402
+ if not self.should_module_be_updated(module_id, batch):
1403
+ del batch.policy_batches[module_id]
1404
+
1405
+ # Log all timesteps (env, agent, modules) based on given episodes/batch.
1406
+ self._log_steps_trained_metrics(batch)
1407
+
1408
+ if minibatch_size:
1409
+ batch_iter = MiniBatchCyclicIterator
1410
+ elif num_epochs > 1:
1411
+ # `minibatch_size` was not set but `num_epochs` > 1.
1412
+ # Under the old training stack, users could do multiple epochs
1413
+ # over a batch without specifying a minibatch size. We enable
1414
+ # this behavior here by setting the minibatch size to be the size
1415
+ # of the batch (e.g. 1 minibatch of size batch.count)
1416
+ minibatch_size = batch.count
1417
+ # Note that there is no need to shuffle here, b/c we don't have minibatches.
1418
+ batch_iter = MiniBatchCyclicIterator
1419
+ else:
1420
+ # `minibatch_size` and `num_epochs` are not set by the user.
1421
+ batch_iter = MiniBatchDummyIterator
1422
+
1423
+ batch = self._set_slicing_by_batch_id(batch, value=True)
1424
+
1425
+ for tensor_minibatch in batch_iter(
1426
+ batch,
1427
+ num_epochs=num_epochs,
1428
+ minibatch_size=minibatch_size,
1429
+ shuffle_batch_per_epoch=shuffle_batch_per_epoch and (num_epochs > 1),
1430
+ num_total_minibatches=num_total_minibatches,
1431
+ ):
1432
+ # Make the actual in-graph/traced `_update` call. This should return
1433
+ # all tensor values (no numpy).
1434
+ fwd_out, loss_per_module, tensor_metrics = self._update(
1435
+ tensor_minibatch.policy_batches
1436
+ )
1437
+
1438
+ # Convert logged tensor metrics (logged during tensor-mode of MetricsLogger)
1439
+ # to actual (numpy) values.
1440
+ self.metrics.tensors_to_numpy(tensor_metrics)
1441
+
1442
+ # Log all individual RLModules' loss terms and its registered optimizers'
1443
+ # current learning rates.
1444
+ for mid, loss in convert_to_numpy(loss_per_module).items():
1445
+ self.metrics.log_value(
1446
+ key=(mid, self.TOTAL_LOSS_KEY),
1447
+ value=loss,
1448
+ window=1,
1449
+ )
1450
+
1451
+ self._weights_seq_no += 1
1452
+ self.metrics.log_dict(
1453
+ {
1454
+ (mid, WEIGHTS_SEQ_NO): self._weights_seq_no
1455
+ for mid in batch.policy_batches.keys()
1456
+ },
1457
+ window=1,
1458
+ )
1459
+
1460
+ self._set_slicing_by_batch_id(batch, value=False)
1461
+
1462
+ # Call `after_gradient_based_update` to allow for non-gradient based
1463
+ # cleanups-, logging-, and update logic to happen.
1464
+ self.after_gradient_based_update(timesteps=timesteps or {})
1465
+
1466
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
1467
+ def before_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
1468
+ """Called before gradient-based updates are completed.
1469
+
1470
+ Should be overridden to implement custom preparation-, logging-, or
1471
+ non-gradient-based Learner/RLModule update logic before(!) gradient-based
1472
+ updates are performed.
1473
+
1474
+ Args:
1475
+ timesteps: Timesteps dict, which must have the key
1476
+ `NUM_ENV_STEPS_SAMPLED_LIFETIME`.
1477
+ # TODO (sven): Make this a more formal structure with its own type.
1478
+ """
1479
+
1480
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
1481
+ def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
1482
+ """Called after gradient-based updates are completed.
1483
+
1484
+ Should be overridden to implement custom cleanup-, logging-, or non-gradient-
1485
+ based Learner/RLModule update logic after(!) gradient-based updates have been
1486
+ completed.
1487
+
1488
+ Args:
1489
+ timesteps: Timesteps dict, which must have the key
1490
+ `NUM_ENV_STEPS_SAMPLED_LIFETIME`.
1491
+ # TODO (sven): Make this a more formal structure with its own type.
1492
+ """
1493
+ # Only update this optimizer's lr, if a scheduler has been registered
1494
+ # along with it.
1495
+ for module_id, optimizer_names in self._module_optimizers.items():
1496
+ for optimizer_name in optimizer_names:
1497
+ optimizer = self._named_optimizers[optimizer_name]
1498
+ # Update and log learning rate of this optimizer.
1499
+ lr_schedule = self._optimizer_lr_schedules.get(optimizer)
1500
+ if lr_schedule is not None:
1501
+ new_lr = lr_schedule.update(
1502
+ timestep=timesteps.get(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0)
1503
+ )
1504
+ self._set_optimizer_lr(optimizer, lr=new_lr)
1505
+ self.metrics.log_value(
1506
+ # Cut out the module ID from the beginning since it's already part
1507
+ # of the key sequence: (ModuleID, "[optim name]_lr").
1508
+ key=(module_id, f"{optimizer_name[len(module_id) + 1:]}_{LR_KEY}"),
1509
+ value=convert_to_numpy(self._get_optimizer_lr(optimizer)),
1510
+ window=1,
1511
+ )
1512
+
1513
+ def _set_slicing_by_batch_id(
1514
+ self, batch: MultiAgentBatch, *, value: bool
1515
+ ) -> MultiAgentBatch:
1516
+ """Enables slicing by batch id in the given batch.
1517
+
1518
+ If the input batch contains batches of sequences we need to make sure when
1519
+ slicing happens it is sliced via batch id and not timestamp. Calling this
1520
+ method enables the same flag on each SampleBatch within the input
1521
+ MultiAgentBatch.
1522
+
1523
+ Args:
1524
+ batch: The MultiAgentBatch to enable slicing by batch id on.
1525
+ value: The value to set the flag to.
1526
+
1527
+ Returns:
1528
+ The input MultiAgentBatch with the indexing flag is enabled / disabled on.
1529
+ """
1530
+
1531
+ for pid, policy_batch in batch.policy_batches.items():
1532
+ # We assume that arriving batches for recurrent modules OR batches that
1533
+ # have a SEQ_LENS column are already zero-padded to the max sequence length
1534
+ # and have tensors of shape [B, T, ...]. Therefore, we slice sequence
1535
+ # lengths in B. See SampleBatch for more information.
1536
+ if (
1537
+ self.module[pid].is_stateful()
1538
+ or policy_batch.get("seq_lens") is not None
1539
+ ):
1540
+ if value:
1541
+ policy_batch.enable_slicing_by_batch_id()
1542
+ else:
1543
+ policy_batch.disable_slicing_by_batch_id()
1544
+
1545
+ return batch
1546
+
1547
+ def _make_module(self) -> MultiRLModule:
1548
+ """Construct the multi-agent RL module for the learner.
1549
+
1550
+ This method uses `self._module_specs` or `self._module_obj` to construct the
1551
+ module. If the module_class is a single agent RL module it will be wrapped to a
1552
+ multi-agent RL module. Override this method if there are other things that
1553
+ need to happen for instantiation of the module.
1554
+
1555
+ Returns:
1556
+ A constructed MultiRLModule.
1557
+ """
1558
+ # Module was provided directly through constructor -> Use as-is.
1559
+ if self._module_obj is not None:
1560
+ module = self._module_obj
1561
+ self._module_spec = MultiRLModuleSpec.from_module(module)
1562
+ # RLModuleSpec was provided directly through constructor -> Use it to build the
1563
+ # RLModule.
1564
+ elif self._module_spec is not None:
1565
+ module = self._module_spec.build()
1566
+ # Try using our config object. Note that this would only work if the config
1567
+ # object has all the necessary space information already in it.
1568
+ else:
1569
+ module = self.config.get_multi_rl_module_spec().build()
1570
+
1571
+ # If not already, convert to MultiRLModule.
1572
+ module = module.as_multi_rl_module()
1573
+
1574
+ return module
1575
+
1576
+ def rl_module_is_compatible(self, module: RLModule) -> bool:
1577
+ """Check whether the given `module` is compatible with this Learner.
1578
+
1579
+ The default implementation checks the Learner-required APIs and whether the
1580
+ given `module` implements all of them (if not, returns False).
1581
+
1582
+ Args:
1583
+ module: The RLModule to check.
1584
+
1585
+ Returns:
1586
+ True if the module is compatible with this Learner.
1587
+ """
1588
+ return all(isinstance(module, api) for api in self.rl_module_required_apis())
1589
+
1590
+ @classmethod
1591
+ def rl_module_required_apis(cls) -> list[type]:
1592
+ """Returns the required APIs for an RLModule to be compatible with this Learner.
1593
+
1594
+ The returned values may or may not be used inside the `rl_module_is_compatible`
1595
+ method.
1596
+
1597
+ Args:
1598
+ module: The RLModule to check.
1599
+
1600
+ Returns:
1601
+ A list of RLModule API classes that an RLModule must implement in order
1602
+ to be compatible with this Learner.
1603
+ """
1604
+ return []
1605
+
1606
+ def _check_registered_optimizer(
1607
+ self,
1608
+ optimizer: Optimizer,
1609
+ params: Sequence[Param],
1610
+ ) -> None:
1611
+ """Checks that the given optimizer and parameters are valid for the framework.
1612
+
1613
+ Args:
1614
+ optimizer: The optimizer object to check.
1615
+ params: The list of parameters to check.
1616
+ """
1617
+ if not isinstance(params, list):
1618
+ raise ValueError(
1619
+ f"`params` ({params}) must be a list of framework-specific parameters "
1620
+ "(variables)!"
1621
+ )
1622
+
1623
+ def _log_trainable_parameters(self) -> None:
1624
+ """Logs the number of trainable and non-trainable parameters to self.metrics.
1625
+
1626
+ Use MetricsLogger (self.metrics) tuple-keys:
1627
+ (ALL_MODULES, NUM_TRAINABLE_PARAMETERS) and
1628
+ (ALL_MODULES, NUM_NON_TRAINABLE_PARAMETERS) with EMA.
1629
+ """
1630
+ pass
1631
+
1632
+ def _check_is_built(self, error: bool = True) -> bool:
1633
+ if self.module is None:
1634
+ if error:
1635
+ raise ValueError(
1636
+ "Learner.build() must be called after constructing a "
1637
+ "Learner and before calling any methods on it."
1638
+ )
1639
+ return False
1640
+ return True
1641
+
1642
+ def _reset(self):
1643
+ self._params = {}
1644
+ self._optimizer_parameters = {}
1645
+ self._named_optimizers = {}
1646
+ self._module_optimizers = defaultdict(list)
1647
+ self._optimizer_lr_schedules = {}
1648
+ self.metrics = MetricsLogger()
1649
+ self._is_built = False
1650
+
1651
+ def apply(self, func, *_args, **_kwargs):
1652
+ return func(self, *_args, **_kwargs)
1653
+
1654
+ @abc.abstractmethod
1655
+ def _get_tensor_variable(
1656
+ self,
1657
+ value: Any,
1658
+ dtype: Any = None,
1659
+ trainable: bool = False,
1660
+ ) -> TensorType:
1661
+ """Returns a framework-specific tensor variable with the initial given value.
1662
+
1663
+ This is a framework specific method that should be implemented by the
1664
+ framework specific sub-classes.
1665
+
1666
+ Args:
1667
+ value: The initial value for the tensor variable variable.
1668
+
1669
+ Returns:
1670
+ The framework specific tensor variable of the given initial value,
1671
+ dtype and trainable/requires_grad property.
1672
+ """
1673
+
1674
+ @staticmethod
1675
+ @abc.abstractmethod
1676
+ def _get_optimizer_lr(optimizer: Optimizer) -> float:
1677
+ """Returns the current learning rate of the given local optimizer.
1678
+
1679
+ Args:
1680
+ optimizer: The local optimizer to get the current learning rate for.
1681
+
1682
+ Returns:
1683
+ The learning rate value (float) of the given optimizer.
1684
+ """
1685
+
1686
+ @staticmethod
1687
+ @abc.abstractmethod
1688
+ def _set_optimizer_lr(optimizer: Optimizer, lr: float) -> None:
1689
+ """Updates the learning rate of the given local optimizer.
1690
+
1691
+ Args:
1692
+ optimizer: The local optimizer to update the learning rate for.
1693
+ lr: The new learning rate.
1694
+ """
1695
+
1696
+ @staticmethod
1697
+ @abc.abstractmethod
1698
+ def _get_clip_function() -> Callable:
1699
+ """Returns the gradient clipping function to use, given the framework."""
1700
+
1701
+ @staticmethod
1702
+ @abc.abstractmethod
1703
+ def _get_global_norm_function() -> Callable:
1704
+ """Returns the global norm function to use, given the framework."""
1705
+
1706
+ def _log_steps_trained_metrics(self, batch: MultiAgentBatch):
1707
+ """Logs this iteration's steps trained, based on given `batch`."""
1708
+ for mid, module_batch in batch.policy_batches.items():
1709
+ module_batch_size = len(module_batch)
1710
+ # Log average batch size (for each module).
1711
+ self.metrics.log_value(
1712
+ key=(mid, MODULE_TRAIN_BATCH_SIZE_MEAN),
1713
+ value=module_batch_size,
1714
+ )
1715
+ # Log module steps (for each module).
1716
+ self.metrics.log_value(
1717
+ key=(mid, NUM_MODULE_STEPS_TRAINED),
1718
+ value=module_batch_size,
1719
+ reduce="sum",
1720
+ clear_on_reduce=True,
1721
+ )
1722
+ self.metrics.log_value(
1723
+ key=(mid, NUM_MODULE_STEPS_TRAINED_LIFETIME),
1724
+ value=module_batch_size,
1725
+ reduce="sum",
1726
+ )
1727
+ # Log module steps (sum of all modules).
1728
+ self.metrics.log_value(
1729
+ key=(ALL_MODULES, NUM_MODULE_STEPS_TRAINED),
1730
+ value=module_batch_size,
1731
+ reduce="sum",
1732
+ clear_on_reduce=True,
1733
+ )
1734
+ self.metrics.log_value(
1735
+ key=(ALL_MODULES, NUM_MODULE_STEPS_TRAINED_LIFETIME),
1736
+ value=module_batch_size,
1737
+ reduce="sum",
1738
+ )
1739
+ # Log env steps (all modules).
1740
+ self.metrics.log_value(
1741
+ (ALL_MODULES, NUM_ENV_STEPS_TRAINED),
1742
+ batch.env_steps(),
1743
+ reduce="sum",
1744
+ clear_on_reduce=True,
1745
+ )
1746
+ self.metrics.log_value(
1747
+ (ALL_MODULES, NUM_ENV_STEPS_TRAINED_LIFETIME),
1748
+ batch.env_steps(),
1749
+ reduce="sum",
1750
+ with_throughput=True,
1751
+ )
1752
+
1753
+ @Deprecated(
1754
+ new="Learner.before_gradient_based_update("
1755
+ "timesteps={'num_env_steps_sampled_lifetime': ...}) and/or "
1756
+ "Learner.after_gradient_based_update("
1757
+ "timesteps={'num_env_steps_sampled_lifetime': ...})",
1758
+ error=True,
1759
+ )
1760
+ def additional_update_for_module(self, *args, **kwargs):
1761
+ pass
1762
+
1763
+ @Deprecated(new="Learner.save_to_path(...)", error=True)
1764
+ def save_state(self, *args, **kwargs):
1765
+ pass
1766
+
1767
+ @Deprecated(new="Learner.restore_from_path(...)", error=True)
1768
+ def load_state(self, *args, **kwargs):
1769
+ pass
1770
+
1771
+ @Deprecated(new="Learner.module.get_state()", error=True)
1772
+ def get_module_state(self, *args, **kwargs):
1773
+ pass
1774
+
1775
+ @Deprecated(new="Learner.module.set_state()", error=True)
1776
+ def set_module_state(self, *args, **kwargs):
1777
+ pass
1778
+
1779
+ @Deprecated(new="Learner._get_optimizer_state()", error=True)
1780
+ def get_optimizer_state(self, *args, **kwargs):
1781
+ pass
1782
+
1783
+ @Deprecated(new="Learner._set_optimizer_state()", error=True)
1784
+ def set_optimizer_state(self, *args, **kwargs):
1785
+ pass
1786
+
1787
+ @Deprecated(new="Learner.compute_losses(...)", error=False)
1788
+ def compute_loss(self, *args, **kwargs):
1789
+ losses_per_module = self.compute_losses(*args, **kwargs)
1790
+ # To continue supporting the old `compute_loss` behavior (instead of
1791
+ # the new `compute_losses`, add the ALL_MODULES key here holding the sum
1792
+ # of all individual loss terms.
1793
+ if ALL_MODULES not in losses_per_module:
1794
+ losses_per_module[ALL_MODULES] = sum(losses_per_module.values())
1795
+ return losses_per_module
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/learner_group.py ADDED
@@ -0,0 +1,1030 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ from collections import defaultdict, Counter
3
+ import copy
4
+ from functools import partial
5
+ import itertools
6
+ from typing import (
7
+ Any,
8
+ Callable,
9
+ Collection,
10
+ Dict,
11
+ List,
12
+ Optional,
13
+ Set,
14
+ Type,
15
+ TYPE_CHECKING,
16
+ Union,
17
+ )
18
+
19
+ import ray
20
+ from ray import ObjectRef
21
+ from ray.rllib.core import (
22
+ COMPONENT_LEARNER,
23
+ COMPONENT_RL_MODULE,
24
+ )
25
+ from ray.rllib.core.learner.learner import Learner
26
+ from ray.rllib.core.rl_module import validate_module_id
27
+ from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
28
+ from ray.rllib.core.rl_module.rl_module import RLModuleSpec
29
+ from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
30
+ from ray.rllib.policy.policy import PolicySpec
31
+ from ray.rllib.policy.sample_batch import MultiAgentBatch
32
+ from ray.rllib.utils.actor_manager import (
33
+ FaultTolerantActorManager,
34
+ RemoteCallResults,
35
+ ResultOrError,
36
+ )
37
+ from ray.rllib.utils.annotations import override
38
+ from ray.rllib.utils.checkpoints import Checkpointable
39
+ from ray.rllib.utils.deprecation import Deprecated
40
+ from ray.rllib.utils.metrics import ALL_MODULES
41
+ from ray.rllib.utils.minibatch_utils import (
42
+ ShardBatchIterator,
43
+ ShardEpisodesIterator,
44
+ ShardObjectRefIterator,
45
+ )
46
+ from ray.rllib.utils.typing import (
47
+ EpisodeType,
48
+ ModuleID,
49
+ RLModuleSpecType,
50
+ ShouldModuleBeUpdatedFn,
51
+ StateDict,
52
+ T,
53
+ )
54
+ from ray.train._internal.backend_executor import BackendExecutor
55
+ from ray.util.annotations import PublicAPI
56
+
57
+ if TYPE_CHECKING:
58
+ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
59
+
60
+
61
+ def _get_backend_config(learner_class: Type[Learner]) -> str:
62
+ if learner_class.framework == "torch":
63
+ from ray.train.torch import TorchConfig
64
+
65
+ backend_config = TorchConfig()
66
+ elif learner_class.framework == "tf2":
67
+ from ray.train.tensorflow import TensorflowConfig
68
+
69
+ backend_config = TensorflowConfig()
70
+ else:
71
+ raise ValueError(
72
+ "`learner_class.framework` must be either 'torch' or 'tf2' (but is "
73
+ f"{learner_class.framework}!"
74
+ )
75
+
76
+ return backend_config
77
+
78
+
79
+ @PublicAPI(stability="alpha")
80
+ class LearnerGroup(Checkpointable):
81
+ """Coordinator of n (possibly remote) Learner workers.
82
+
83
+ Each Learner worker has a copy of the RLModule, the loss function(s), and
84
+ one or more optimizers.
85
+ """
86
+
87
    def __init__(
        self,
        *,
        config: "AlgorithmConfig",
        # TODO (sven): Rename into `rl_module_spec`.
        module_spec: Optional[RLModuleSpecType] = None,
    ):
        """Initializes a LearnerGroup instance.

        Args:
            config: The AlgorithmConfig object to use to configure this LearnerGroup.
                Call the `learners(num_learners=...)` method on your config to
                specify the number of learner workers to use.
                Call the same method with arguments `num_cpus_per_learner` and/or
                `num_gpus_per_learner` to configure the compute used by each
                Learner worker in this LearnerGroup.
                Call the `training(learner_class=...)` method on your config to
                specify, which exact Learner class to use.
                Call the `rl_module(rl_module_spec=...)` method on your config to
                set up the specifics for your RLModule to be used in each Learner.
            module_spec: If not already specified in `config`, a separate overriding
                RLModuleSpec may be provided via this argument.
        """
        self.config = config.copy(copy_frozen=False)
        self._module_spec = module_spec

        learner_class = self.config.learner_class
        module_spec = module_spec or self.config.get_multi_rl_module_spec()

        self._learner = None
        self._workers = None
        # If a user calls self.shutdown() on their own then this flag is set to true.
        # When del is called the backend executor isn't shutdown twice if this flag is
        # true. The backend executor would otherwise log a warning to the console from
        # ray train.
        self._is_shut_down = False

        # How many timesteps had to be dropped due to a full input queue?
        self._ts_dropped = 0

        # A single local Learner.
        if not self.is_remote:
            self._learner = learner_class(config=config, module_spec=module_spec)
            self._learner.build()
            self._worker_manager = None
        # N remote Learner workers.
        else:
            backend_config = _get_backend_config(learner_class)

            # TODO (sven): Can't set both `num_cpus_per_learner`>1 and
            #  `num_gpus_per_learner`>0! Users must set one or the other due
            #  to issues with placement group fragmentation. See
            #  https://github.com/ray-project/ray/issues/35409 for more details.
            num_cpus_per_learner = (
                self.config.num_cpus_per_learner
                if not self.config.num_gpus_per_learner
                else 0
            )
            # Reserve a small GPU fraction per aggregator actor (hence the 0.01
            # factor); never go below zero.
            num_gpus_per_learner = max(
                0,
                self.config.num_gpus_per_learner
                - (0.01 * self.config.num_aggregator_actors_per_learner),
            )
            resources_per_learner = {
                "CPU": num_cpus_per_learner,
                "GPU": num_gpus_per_learner,
            }

            backend_executor = BackendExecutor(
                backend_config=backend_config,
                num_workers=self.config.num_learners,
                resources_per_worker=resources_per_learner,
                max_retries=0,
            )
            backend_executor.start(
                train_cls=learner_class,
                train_cls_kwargs={
                    "config": config,
                    "module_spec": module_spec,
                },
            )
            self._backend_executor = backend_executor

            self._workers = [w.actor for w in backend_executor.worker_group.workers]

            # Run the neural network building code on remote workers.
            ray.get([w.build.remote() for w in self._workers])

            self._worker_manager = FaultTolerantActorManager(
                self._workers,
                max_remote_requests_in_flight_per_actor=(
                    self.config.max_requests_in_flight_per_learner
                ),
            )
        # Counters for the tags for asynchronous update requests that are
        # in-flight. Used for keeping track of and grouping together the results of
        # requests that were sent to the workers at the same time.
        self._update_request_tags = Counter()
        self._update_request_tag = 0
        self._update_request_results = {}
187
+
188
+ # TODO (sven): Replace this with call to `self.metrics.peek()`?
189
+ # Currently LearnerGroup does not have a metrics object.
190
+ def get_stats(self) -> Dict[str, Any]:
191
+ """Returns the current stats for the input queue for this learner group."""
192
+ return {
193
+ "learner_group_ts_dropped": self._ts_dropped,
194
+ "actor_manager_num_outstanding_async_reqs": (
195
+ 0
196
+ if self.is_local
197
+ else self._worker_manager.num_outstanding_async_reqs()
198
+ ),
199
+ }
200
+
201
    @property
    def is_remote(self) -> bool:
        """Whether this group manages remote Learner actors (num_learners > 0)."""
        return self.config.num_learners > 0
204
+
205
    @property
    def is_local(self) -> bool:
        """Whether this group uses a single, local Learner (not self.is_remote)."""
        return not self.is_remote
208
+
209
    def update_from_batch(
        self,
        batch: MultiAgentBatch,
        *,
        timesteps: Optional[Dict[str, Any]] = None,
        async_update: bool = False,
        return_state: bool = False,
        num_epochs: int = 1,
        minibatch_size: Optional[int] = None,
        shuffle_batch_per_epoch: bool = False,
        # User kwargs.
        **kwargs,
    ) -> Union[Dict[str, Any], List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        """Performs gradient based update(s) on the Learner(s), based on given batch.

        Args:
            batch: A data batch to use for the update. If there are more
                than one Learner workers, the batch is split amongst these and one
                shard is sent to each Learner.
            timesteps: Optional timesteps dict passed through to each Learner's
                update call (presumably containing keys such as
                `num_env_steps_sampled_lifetime` — TODO confirm exact schema
                against `Learner.update_from_batch`).
            async_update: Whether the update request(s) to the Learner workers should be
                sent asynchronously. If True, will return NOT the results from the
                update on the given data, but all results from prior asynchronous update
                requests that have not been returned thus far.
            return_state: Whether to include one of the Learner worker's state from
                after the update step in the returned results dict (under the
                `_rl_module_state_after_update` key). Note that after an update, all
                Learner workers' states should be identical, so we use the first
                Learner's state here. Useful for avoiding an extra `get_weights()` call,
                e.g. for synchronizing EnvRunner weights.
            num_epochs: The number of complete passes over the entire train batch. Each
                pass might be further split into n minibatches (if `minibatch_size`
                provided).
            minibatch_size: The size of minibatches to use to further split the train
                `batch` into sub-batches. The `batch` is then iterated over n times
                where n is `len(batch) // minibatch_size`.
            shuffle_batch_per_epoch: Whether to shuffle the train batch once per epoch.
                If the train batch has a time rank (axis=1), shuffling will only take
                place along the batch axis to not disturb any intact (episode)
                trajectories. Also, shuffling is always skipped if `minibatch_size` is
                None, meaning the entire train batch is processed each epoch, making it
                unnecessary to shuffle.

        Returns:
            If `async_update` is False, a dictionary with the reduced results of the
            updates from the Learner(s) or a list of dictionaries of results from the
            updates from the Learner(s).
            If `async_update` is True, a list of list of dictionaries of results, where
            the outer list corresponds to separate previous calls to this method, and
            the inner list corresponds to the results from each Learner(s). Or if the
            results are reduced, a list of dictionaries of the reduced results from each
            call to async_update that is ready.
        """
        # Thin wrapper: all real work happens in the shared `_update()` path.
        return self._update(
            batch=batch,
            timesteps=timesteps,
            async_update=async_update,
            return_state=return_state,
            num_epochs=num_epochs,
            minibatch_size=minibatch_size,
            shuffle_batch_per_epoch=shuffle_batch_per_epoch,
            **kwargs,
        )
271
+
272
    def update_from_episodes(
        self,
        episodes: List[EpisodeType],
        *,
        timesteps: Optional[Dict[str, Any]] = None,
        async_update: bool = False,
        return_state: bool = False,
        num_epochs: int = 1,
        minibatch_size: Optional[int] = None,
        shuffle_batch_per_epoch: bool = False,
        # User kwargs.
        **kwargs,
    ) -> Union[Dict[str, Any], List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        """Performs gradient based update(s) on the Learner(s), based on given episodes.

        Args:
            episodes: A list of Episodes to process and perform the update
                for. If there are more than one Learner workers, the list of episodes
                is split amongst these and one list shard is sent to each Learner.
            timesteps: Optional timesteps dict passed through to each Learner's
                update call (presumably containing keys such as
                `num_env_steps_sampled_lifetime` — TODO confirm exact schema
                against `Learner.update_from_episodes`).
            async_update: Whether the update request(s) to the Learner workers should be
                sent asynchronously. If True, will return NOT the results from the
                update on the given data, but all results from prior asynchronous update
                requests that have not been returned thus far.
            return_state: Whether to include one of the Learner worker's state from
                after the update step in the returned results dict (under the
                `_rl_module_state_after_update` key). Note that after an update, all
                Learner workers' states should be identical, so we use the first
                Learner's state here. Useful for avoiding an extra `get_weights()` call,
                e.g. for synchronizing EnvRunner weights.
            num_epochs: The number of complete passes over the entire train batch. Each
                pass might be further split into n minibatches (if `minibatch_size`
                provided). The train batch is generated from the given `episodes`
                through the Learner connector pipeline.
            minibatch_size: The size of minibatches to use to further split the train
                `batch` into sub-batches. The `batch` is then iterated over n times
                where n is `len(batch) // minibatch_size`. The train batch is generated
                from the given `episodes` through the Learner connector pipeline.
            shuffle_batch_per_epoch: Whether to shuffle the train batch once per epoch.
                If the train batch has a time rank (axis=1), shuffling will only take
                place along the batch axis to not disturb any intact (episode)
                trajectories. Also, shuffling is always skipped if `minibatch_size` is
                None, meaning the entire train batch is processed each epoch, making it
                unnecessary to shuffle. The train batch is generated from the given
                `episodes` through the Learner connector pipeline.

        Returns:
            If async_update is False, a dictionary with the reduced results of the
            updates from the Learner(s) or a list of dictionaries of results from the
            updates from the Learner(s).
            If async_update is True, a list of list of dictionaries of results, where
            the outer list corresponds to separate previous calls to this method, and
            the inner list corresponds to the results from each Learner(s). Or if the
            results are reduced, a list of dictionaries of the reduced results from each
            call to async_update that is ready.
        """
        # Thin wrapper: all real work happens in the shared `_update()` path.
        return self._update(
            episodes=episodes,
            timesteps=timesteps,
            async_update=async_update,
            return_state=return_state,
            num_epochs=num_epochs,
            minibatch_size=minibatch_size,
            shuffle_batch_per_epoch=shuffle_batch_per_epoch,
            **kwargs,
        )
337
+
338
    def _update(
        self,
        *,
        batch: Optional[MultiAgentBatch] = None,
        episodes: Optional[List[EpisodeType]] = None,
        timesteps: Optional[Dict[str, Any]] = None,
        async_update: bool = False,
        return_state: bool = False,
        num_epochs: int = 1,
        num_iters: int = 1,
        minibatch_size: Optional[int] = None,
        shuffle_batch_per_epoch: bool = False,
        **kwargs,
    ) -> Union[Dict[str, Any], List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        """Shared implementation behind `update_from_batch`/`update_from_episodes`.

        Exactly one of `batch` or `episodes` is expected to be provided.
        Dispatches on deployment mode (local vs. remote Learners) and on the
        data's form (plain batch, list of `ray.data.DataIterator`s, list of
        ObjectRefs, or episode lists), sharding the data across Learner workers
        where applicable.
        """

        # Define function to be called on all Learner actors (or the local learner).
        def _learner_update(
            _learner: Learner,
            *,
            _batch_shard=None,
            _episodes_shard=None,
            _timesteps=None,
            _return_state=False,
            _num_total_minibatches=0,
            **_kwargs,
        ):
            # If the batch shard is a `DataIterator` we have an offline
            # multi-learner setup and `update_from_iterator` needs to
            # handle updating.
            if isinstance(_batch_shard, ray.data.DataIterator):
                result = _learner.update_from_iterator(
                    iterator=_batch_shard,
                    timesteps=_timesteps,
                    minibatch_size=minibatch_size,
                    num_iters=num_iters,
                    **_kwargs,
                )
            elif _batch_shard is not None:
                result = _learner.update_from_batch(
                    batch=_batch_shard,
                    timesteps=_timesteps,
                    num_epochs=num_epochs,
                    minibatch_size=minibatch_size,
                    shuffle_batch_per_epoch=shuffle_batch_per_epoch,
                    **_kwargs,
                )
            else:
                result = _learner.update_from_episodes(
                    episodes=_episodes_shard,
                    timesteps=_timesteps,
                    num_epochs=num_epochs,
                    minibatch_size=minibatch_size,
                    shuffle_batch_per_epoch=shuffle_batch_per_epoch,
                    num_total_minibatches=_num_total_minibatches,
                    **_kwargs,
                )
            if _return_state and result:
                result["_rl_module_state_after_update"] = _learner.get_state(
                    # Only return the state of those RLModules that actually returned
                    # results and thus got probably updated.
                    components=[
                        COMPONENT_RL_MODULE + "/" + mid
                        for mid in result
                        if mid != ALL_MODULES
                    ],
                    inference_only=True,
                )

            return result

        # Local Learner worker: Don't shard batch/episodes, just run data as-is through
        # this Learner.
        if self.is_local:
            if async_update:
                raise ValueError(
                    "Cannot call `update_from_batch(async_update=True)` when running in"
                    " local mode! Try setting `config.num_learners > 0`."
                )

            # A single ObjectRef wrapped in a list: resolve it locally.
            if isinstance(batch, list) and isinstance(batch[0], ray.ObjectRef):
                assert len(batch) == 1
                batch = ray.get(batch[0])

            results = [
                _learner_update(
                    _learner=self._learner,
                    _batch_shard=batch,
                    _episodes_shard=episodes,
                    _timesteps=timesteps,
                    _return_state=return_state,
                    **kwargs,
                )
            ]
        # One or more remote Learners: Shard batch/episodes into equal pieces (roughly
        # equal if multi-agent AND episodes) and send each Learner worker one of these
        # shards.
        else:
            # MultiAgentBatch: Shard into equal pieces.
            # TODO (sven): The sharder used here destroys - for multi-agent only -
            #  the relationship of the different agents' timesteps to each other.
            #  Thus, in case the algorithm requires agent-synchronized data (aka.
            #  "lockstep"), the `ShardBatchIterator` should not be used.
            #  Then again, we might move into a world where Learner always
            #  receives Episodes, never batches.
            if isinstance(batch, list) and isinstance(batch[0], ray.data.DataIterator):
                partials = [
                    partial(
                        _learner_update,
                        _batch_shard=iterator,
                        # Only the first Learner reports its post-update state.
                        _return_state=(return_state and i == 0),
                        _timesteps=timesteps,
                        **kwargs,
                    )
                    # Note, `OfflineData` defines exactly as many iterators as there
                    # are learners.
                    for i, iterator in enumerate(batch)
                ]
            elif isinstance(batch, list) and isinstance(batch[0], ObjectRef):
                # Pre-sharded batches, one ObjectRef per Learner worker.
                assert len(batch) == len(self._workers)
                partials = [
                    partial(
                        _learner_update,
                        _batch_shard=batch_shard,
                        _timesteps=timesteps,
                        _return_state=(return_state and i == 0),
                        **kwargs,
                    )
                    for i, batch_shard in enumerate(batch)
                ]
            elif batch is not None:
                partials = [
                    partial(
                        _learner_update,
                        _batch_shard=batch_shard,
                        _return_state=(return_state and i == 0),
                        _timesteps=timesteps,
                        **kwargs,
                    )
                    for i, batch_shard in enumerate(
                        ShardBatchIterator(batch, len(self._workers))
                    )
                ]
            elif isinstance(episodes, list) and isinstance(episodes[0], ObjectRef):
                partials = [
                    partial(
                        _learner_update,
                        _episodes_shard=episodes_shard,
                        _timesteps=timesteps,
                        _return_state=(return_state and i == 0),
                        **kwargs,
                    )
                    for i, episodes_shard in enumerate(
                        ShardObjectRefIterator(episodes, len(self._workers))
                    )
                ]
            # Single- or MultiAgentEpisodes: Shard into equal pieces (only roughly equal
            # in case of multi-agent).
            else:
                from ray.data.iterator import DataIterator

                if isinstance(episodes[0], DataIterator):
                    num_total_minibatches = 0
                    partials = [
                        partial(
                            _learner_update,
                            _episodes_shard=episodes_shard,
                            _timesteps=timesteps,
                            _num_total_minibatches=num_total_minibatches,
                        )
                        for episodes_shard in episodes
                    ]
                else:
                    eps_shards = list(
                        ShardEpisodesIterator(
                            episodes,
                            len(self._workers),
                            len_lookback_buffer=self.config.episode_lookback_horizon,
                        )
                    )
                    # In the multi-agent case AND `minibatch_size` AND num_workers
                    # > 1, we compute a max iteration counter such that the different
                    # Learners will not go through a different number of iterations.
                    num_total_minibatches = 0
                    if minibatch_size and len(self._workers) > 1:
                        num_total_minibatches = self._compute_num_total_minibatches(
                            episodes,
                            len(self._workers),
                            minibatch_size,
                            num_epochs,
                        )
                    partials = [
                        partial(
                            _learner_update,
                            _episodes_shard=eps_shard,
                            _timesteps=timesteps,
                            _num_total_minibatches=num_total_minibatches,
                        )
                        for eps_shard in eps_shards
                    ]

            if async_update:
                # Retrieve all ready results (kicked off by prior calls to this method).
                tags_to_get = []
                for tag in self._update_request_tags.keys():
                    result = self._worker_manager.fetch_ready_async_reqs(
                        tags=[str(tag)], timeout_seconds=0.0
                    )
                    if tag not in self._update_request_results:
                        self._update_request_results[tag] = result
                    else:
                        for r in result:
                            self._update_request_results[tag].add_result(
                                r.actor_id, r.result_or_error, tag
                            )

                    # Still not done with this `tag` -> skip out early.
                    if (
                        self._update_request_tags[tag]
                        > len(self._update_request_results[tag].result_or_errors)
                        > 0
                    ):
                        break
                    tags_to_get.append(tag)

                # Send out new request(s), if there is still capacity on the actors
                # (each actor is allowed only some number of max in-flight requests
                # at the same time).
                update_tag = self._update_request_tag
                self._update_request_tag += 1
                num_sent_requests = self._worker_manager.foreach_actor_async(
                    partials, tag=str(update_tag)
                )
                if num_sent_requests:
                    self._update_request_tags[update_tag] = num_sent_requests

                # Some requests were dropped, record lost ts/data.
                if num_sent_requests != len(self._workers):
                    factor = 1 - (num_sent_requests / len(self._workers))
                    # Batch: Measure its length.
                    if episodes is None:
                        dropped = len(batch)
                    # List of Ray ObjectRefs (each object ref is a list of episodes of
                    # total len=`rollout_fragment_length * num_envs_per_env_runner`)
                    elif isinstance(episodes[0], ObjectRef):
                        dropped = (
                            len(episodes)
                            * self.config.get_rollout_fragment_length()
                            * self.config.num_envs_per_env_runner
                        )
                    else:
                        dropped = sum(len(e) for e in episodes)

                    self._ts_dropped += factor * dropped

                # NOTE: There is a strong assumption here that the requests launched to
                # learner workers will return at the same time, since they have a
                # barrier inside for gradient aggregation. Therefore, results should be
                # a list of lists where each inner list should be the length of the
                # number of learner workers, if results from an non-blocking update are
                # ready.
                results = self._get_async_results(tags_to_get)

            else:
                results = self._get_results(
                    self._worker_manager.foreach_actor(partials)
                )

        return results
606
+
607
+ # TODO (sven): Move this into FaultTolerantActorManager?
608
+ def _get_results(self, results):
609
+ processed_results = []
610
+ for result in results:
611
+ result_or_error = result.get()
612
+ if result.ok:
613
+ processed_results.append(result_or_error)
614
+ else:
615
+ raise result_or_error
616
+ return processed_results
617
+
618
+ def _get_async_results(self, tags_to_get):
619
+ """Get results from the worker manager and group them by tag.
620
+
621
+ Returns:
622
+ A list of lists of results, where each inner list contains all results
623
+ for same tags.
624
+
625
+ """
626
+ unprocessed_results = defaultdict(list)
627
+ for tag in tags_to_get:
628
+ results = self._update_request_results[tag]
629
+ for result in results:
630
+ result_or_error = result.get()
631
+ if result.ok:
632
+ if result.tag is None:
633
+ raise RuntimeError(
634
+ "Cannot call `LearnerGroup._get_async_results()` on "
635
+ "untagged async requests!"
636
+ )
637
+ tag = int(result.tag)
638
+ unprocessed_results[tag].append(result_or_error)
639
+
640
+ if tag in self._update_request_tags:
641
+ self._update_request_tags[tag] -= 1
642
+ if self._update_request_tags[tag] == 0:
643
+ del self._update_request_tags[tag]
644
+ del self._update_request_results[tag]
645
+ else:
646
+ assert False
647
+
648
+ else:
649
+ raise result_or_error
650
+
651
+ return list(unprocessed_results.values())
652
+
653
    def add_module(
        self,
        *,
        module_id: ModuleID,
        module_spec: RLModuleSpec,
        config_overrides: Optional[Dict] = None,
        new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
    ) -> MultiRLModuleSpec:
        """Adds a module to the underlying MultiRLModule.

        Changes this Learner's config in order to make this architectural change
        permanent wrt. to checkpointing.

        Args:
            module_id: The ModuleID of the module to be added.
            module_spec: The ModuleSpec of the module to be added.
            config_overrides: The `AlgorithmConfig` overrides that should apply to
                the new Module, if any.
            new_should_module_be_updated: An optional sequence of ModuleIDs or a
                callable taking ModuleID and SampleBatchType and returning whether the
                ModuleID should be updated (trained).
                If None, will keep the existing setup in place. RLModules,
                whose IDs are not in the list (or for which the callable
                returns False) will not be updated.

        Returns:
            The new MultiRLModuleSpec (after the change has been performed).
        """
        validate_module_id(module_id, error=True)

        # Force-set inference-only = False (deep-copy first so the caller's spec
        # object is not mutated).
        module_spec = copy.deepcopy(module_spec)
        module_spec.inference_only = False

        # Add the module on every (local or remote) Learner.
        results = self.foreach_learner(
            func=lambda _learner: _learner.add_module(
                module_id=module_id,
                module_spec=module_spec,
                config_overrides=config_overrides,
                new_should_module_be_updated=new_should_module_be_updated,
            ),
        )
        # All Learners should return the same spec; use the first one.
        marl_spec = self._get_results(results)[0]

        # Change our config (AlgorithmConfig) to contain the new Module.
        # TODO (sven): This is a hack to manipulate the AlgorithmConfig directly,
        #  but we'll deprecate config.policies soon anyway.
        self.config.policies[module_id] = PolicySpec()
        if config_overrides is not None:
            self.config.multi_agent(
                algorithm_config_overrides_per_module={module_id: config_overrides}
            )
        self.config.rl_module(rl_module_spec=marl_spec)
        if new_should_module_be_updated is not None:
            self.config.multi_agent(policies_to_train=new_should_module_be_updated)

        return marl_spec
710
+
711
    def remove_module(
        self,
        module_id: ModuleID,
        *,
        new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
    ) -> MultiRLModuleSpec:
        """Removes a module from the Learner.

        Args:
            module_id: The ModuleID of the module to be removed.
            new_should_module_be_updated: An optional sequence of ModuleIDs or a
                callable taking ModuleID and SampleBatchType and returning whether the
                ModuleID should be updated (trained).
                If None, will keep the existing setup in place. RLModules,
                whose IDs are not in the list (or for which the callable
                returns False) will not be updated.

        Returns:
            The new MultiRLModuleSpec (after the change has been performed).
        """
        # Remove the module from every (local or remote) Learner.
        results = self.foreach_learner(
            func=lambda _learner: _learner.remove_module(
                module_id=module_id,
                new_should_module_be_updated=new_should_module_be_updated,
            ),
        )
        # All Learners should return the same spec; use the first one.
        marl_spec = self._get_results(results)[0]

        # Change self.config to reflect the new architecture.
        # TODO (sven): This is a hack to manipulate the AlgorithmConfig directly,
        #  but we'll deprecate config.policies soon anyway.
        del self.config.policies[module_id]
        self.config.algorithm_config_overrides_per_module.pop(module_id, None)
        if new_should_module_be_updated is not None:
            self.config.multi_agent(policies_to_train=new_should_module_be_updated)
        self.config.rl_module(rl_module_spec=marl_spec)

        return marl_spec
749
+
750
    @override(Checkpointable)
    def get_state(
        self,
        components: Optional[Union[str, Collection[str]]] = None,
        *,
        not_components: Optional[Union[str, Collection[str]]] = None,
        **kwargs,
    ) -> StateDict:
        """Returns the state of this group's Learner (under COMPONENT_LEARNER)."""
        state = {}

        if self._check_component(COMPONENT_LEARNER, components, not_components):
            if self.is_local:
                state[COMPONENT_LEARNER] = self._learner.get_state(
                    components=self._get_subcomponents(COMPONENT_LEARNER, components),
                    not_components=self._get_subcomponents(
                        COMPONENT_LEARNER, not_components
                    ),
                    **kwargs,
                )
            else:
                # Query only one (the first healthy) remote Learner; all Learners
                # are expected to be in sync (hence the healthy-actor assert).
                worker = self._worker_manager.healthy_actor_ids()[0]
                assert len(self._workers) == self._worker_manager.num_healthy_actors()
                _comps = self._get_subcomponents(COMPONENT_LEARNER, components)
                _not_comps = self._get_subcomponents(COMPONENT_LEARNER, not_components)
                results = self._worker_manager.foreach_actor(
                    lambda w: w.get_state(_comps, not_components=_not_comps, **kwargs),
                    remote_actor_ids=[worker],
                )
                state[COMPONENT_LEARNER] = self._get_results(results)[0]

        return state
781
+
782
+ @override(Checkpointable)
783
+ def set_state(self, state: StateDict) -> None:
784
+ if COMPONENT_LEARNER in state:
785
+ if self.is_local:
786
+ self._learner.set_state(state[COMPONENT_LEARNER])
787
+ else:
788
+ state_ref = ray.put(state[COMPONENT_LEARNER])
789
+ self.foreach_learner(
790
+ lambda _learner, _ref=state_ref: _learner.set_state(ray.get(_ref))
791
+ )
792
+
793
+ def get_weights(
794
+ self, module_ids: Optional[Collection[ModuleID]] = None
795
+ ) -> StateDict:
796
+ """Convenience method instead of self.get_state(components=...).
797
+
798
+ Args:
799
+ module_ids: An optional collection of ModuleIDs for which to return weights.
800
+ If None (default), return weights of all RLModules.
801
+
802
+ Returns:
803
+ The results of
804
+ `self.get_state(components='learner/rl_module')['learner']['rl_module']`.
805
+ """
806
+ # Return the entire RLModule state (all possible single-agent RLModules).
807
+ if module_ids is None:
808
+ components = COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE
809
+ # Return a subset of the single-agent RLModules.
810
+ else:
811
+ components = [
812
+ "".join(tup)
813
+ for tup in itertools.product(
814
+ [COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE + "/"],
815
+ list(module_ids),
816
+ )
817
+ ]
818
+ state = self.get_state(components)[COMPONENT_LEARNER][COMPONENT_RL_MODULE]
819
+ return state
820
+
821
+ def set_weights(self, weights) -> None:
822
+ """Convenience method instead of self.set_state({'learner': {'rl_module': ..}}).
823
+
824
+ Args:
825
+ weights: The weights dict of the MultiRLModule of a Learner inside this
826
+ LearnerGroup.
827
+ """
828
+ self.set_state({COMPONENT_LEARNER: {COMPONENT_RL_MODULE: weights}})
829
+
830
+ @override(Checkpointable)
831
+ def get_ctor_args_and_kwargs(self):
832
+ return (
833
+ (), # *args
834
+ {
835
+ "config": self.config,
836
+ "module_spec": self._module_spec,
837
+ }, # **kwargs
838
+ )
839
+
840
+ @override(Checkpointable)
841
+ def get_checkpointable_components(self):
842
+ # Return the entire ActorManager, if remote. Otherwise, return the
843
+ # local worker. Also, don't give the component (Learner) a name ("")
844
+ # as it's the only component in this LearnerGroup to be saved.
845
+ return [
846
+ (
847
+ COMPONENT_LEARNER,
848
+ self._learner if self.is_local else self._worker_manager,
849
+ )
850
+ ]
851
+
852
+ def foreach_learner(
853
+ self,
854
+ func: Callable[[Learner, Optional[Any]], T],
855
+ *,
856
+ healthy_only: bool = True,
857
+ remote_actor_ids: List[int] = None,
858
+ timeout_seconds: Optional[float] = None,
859
+ return_obj_refs: bool = False,
860
+ mark_healthy: bool = False,
861
+ **kwargs,
862
+ ) -> RemoteCallResults:
863
+ """Calls the given function on each Learner L with the args: (L, \*\*kwargs).
864
+
865
+ Args:
866
+ func: The function to call on each Learner L with args: (L, \*\*kwargs).
867
+ healthy_only: If True, applies `func` only to Learner actors currently
868
+ tagged "healthy", otherwise to all actors. If `healthy_only=False` and
869
+ `mark_healthy=True`, will send `func` to all actors and mark those
870
+ actors "healthy" that respond to the request within `timeout_seconds`
871
+ and are currently tagged as "unhealthy".
872
+ remote_actor_ids: Apply func on a selected set of remote actors. Use None
873
+ (default) for all actors.
874
+ timeout_seconds: Time to wait (in seconds) for results. Set this to 0.0 for
875
+ fire-and-forget. Set this to None (default) to wait infinitely (i.e. for
876
+ synchronous execution).
877
+ return_obj_refs: whether to return ObjectRef instead of actual results.
878
+ Note, for fault tolerance reasons, these returned ObjectRefs should
879
+ never be resolved with ray.get() outside of the context of this manager.
880
+ mark_healthy: Whether to mark all those actors healthy again that are
881
+ currently marked unhealthy AND that returned results from the remote
882
+ call (within the given `timeout_seconds`).
883
+ Note that actors are NOT set unhealthy, if they simply time out
884
+ (only if they return a RayActorError).
885
+ Also not that this setting is ignored if `healthy_only=True` (b/c this
886
+ setting only affects actors that are currently tagged as unhealthy).
887
+
888
+ Returns:
889
+ A list of size len(Learners) with the return values of all calls to `func`.
890
+ """
891
+ if self.is_local:
892
+ results = RemoteCallResults()
893
+ results.add_result(
894
+ None,
895
+ ResultOrError(result=func(self._learner, **kwargs)),
896
+ None,
897
+ )
898
+ return results
899
+
900
+ return self._worker_manager.foreach_actor(
901
+ func=partial(func, **kwargs),
902
+ healthy_only=healthy_only,
903
+ remote_actor_ids=remote_actor_ids,
904
+ timeout_seconds=timeout_seconds,
905
+ return_obj_refs=return_obj_refs,
906
+ mark_healthy=mark_healthy,
907
+ )
908
+
909
+ def shutdown(self):
910
+ """Shuts down the LearnerGroup."""
911
+ if self.is_remote and hasattr(self, "_backend_executor"):
912
+ self._backend_executor.shutdown()
913
+ self._is_shut_down = True
914
+
915
+ def __del__(self):
916
+ if not self._is_shut_down:
917
+ self.shutdown()
918
+
919
+ @staticmethod
920
+ def _compute_num_total_minibatches(
921
+ episodes,
922
+ num_shards,
923
+ minibatch_size,
924
+ num_epochs,
925
+ ):
926
+ # Count total number of timesteps per module ID.
927
+ if isinstance(episodes[0], MultiAgentEpisode):
928
+ per_mod_ts = defaultdict(int)
929
+ for ma_episode in episodes:
930
+ for sa_episode in ma_episode.agent_episodes.values():
931
+ per_mod_ts[sa_episode.module_id] += len(sa_episode)
932
+ max_ts = max(per_mod_ts.values())
933
+ else:
934
+ max_ts = sum(map(len, episodes))
935
+
936
+ return int((num_epochs * max_ts) / (num_shards * minibatch_size))
937
+
938
+ @Deprecated(new="LearnerGroup.update_from_batch(async=False)", error=False)
939
+ def update(self, *args, **kwargs):
940
+ # Just in case, we would like to revert this API retirement, we can do so
941
+ # easily.
942
+ return self._update(*args, **kwargs, async_update=False)
943
+
944
+ @Deprecated(new="LearnerGroup.update_from_batch(async=True)", error=False)
945
+ def async_update(self, *args, **kwargs):
946
+ # Just in case, we would like to revert this API retirement, we can do so
947
+ # easily.
948
+ return self._update(*args, **kwargs, async_update=True)
949
+
950
+ @Deprecated(new="LearnerGroup.save_to_path(...)", error=True)
951
+ def save_state(self, *args, **kwargs):
952
+ pass
953
+
954
+ @Deprecated(new="LearnerGroup.restore_from_path(...)", error=True)
955
+ def load_state(self, *args, **kwargs):
956
+ pass
957
+
958
+ @Deprecated(new="LearnerGroup.load_from_path(path=..., component=...)", error=False)
959
+ def load_module_state(
960
+ self,
961
+ *,
962
+ multi_rl_module_ckpt_dir: Optional[str] = None,
963
+ modules_to_load: Optional[Set[str]] = None,
964
+ rl_module_ckpt_dirs: Optional[Dict[ModuleID, str]] = None,
965
+ ) -> None:
966
+ """Load the checkpoints of the modules being trained by this LearnerGroup.
967
+
968
+ `load_module_state` can be used 3 ways:
969
+ 1. Load a checkpoint for the MultiRLModule being trained by this
970
+ LearnerGroup. Limit the modules that are loaded from the checkpoint
971
+ by specifying the `modules_to_load` argument.
972
+ 2. Load the checkpoint(s) for single agent RLModules that
973
+ are in the MultiRLModule being trained by this LearnerGroup.
974
+ 3. Load a checkpoint for the MultiRLModule being trained by this
975
+ LearnerGroup and load the checkpoint(s) for single agent RLModules
976
+ that are in the MultiRLModule. The checkpoints for the single
977
+ agent RLModules take precedence over the module states in the
978
+ MultiRLModule checkpoint.
979
+
980
+ NOTE: At lease one of multi_rl_module_ckpt_dir or rl_module_ckpt_dirs is
981
+ must be specified. modules_to_load can only be specified if
982
+ multi_rl_module_ckpt_dir is specified.
983
+
984
+ Args:
985
+ multi_rl_module_ckpt_dir: The path to the checkpoint for the
986
+ MultiRLModule.
987
+ modules_to_load: A set of module ids to load from the checkpoint.
988
+ rl_module_ckpt_dirs: A mapping from module ids to the path to a
989
+ checkpoint for a single agent RLModule.
990
+ """
991
+ if not (multi_rl_module_ckpt_dir or rl_module_ckpt_dirs):
992
+ raise ValueError(
993
+ "At least one of `multi_rl_module_ckpt_dir` or "
994
+ "`rl_module_ckpt_dirs` must be provided!"
995
+ )
996
+ if multi_rl_module_ckpt_dir:
997
+ multi_rl_module_ckpt_dir = pathlib.Path(multi_rl_module_ckpt_dir)
998
+ if rl_module_ckpt_dirs:
999
+ for module_id, path in rl_module_ckpt_dirs.items():
1000
+ rl_module_ckpt_dirs[module_id] = pathlib.Path(path)
1001
+
1002
+ # MultiRLModule checkpoint is provided.
1003
+ if multi_rl_module_ckpt_dir:
1004
+ # Restore the entire MultiRLModule state.
1005
+ if modules_to_load is None:
1006
+ self.restore_from_path(
1007
+ multi_rl_module_ckpt_dir,
1008
+ component=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE,
1009
+ )
1010
+ # Restore individual module IDs.
1011
+ else:
1012
+ for module_id in modules_to_load:
1013
+ self.restore_from_path(
1014
+ multi_rl_module_ckpt_dir / module_id,
1015
+ component=(
1016
+ COMPONENT_LEARNER
1017
+ + "/"
1018
+ + COMPONENT_RL_MODULE
1019
+ + "/"
1020
+ + module_id
1021
+ ),
1022
+ )
1023
+ if rl_module_ckpt_dirs:
1024
+ for module_id, path in rl_module_ckpt_dirs.items():
1025
+ self.restore_from_path(
1026
+ path,
1027
+ component=(
1028
+ COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE + "/" + module_id
1029
+ ),
1030
+ )
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (198 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__pycache__/tf_learner.cpython-311.pyc ADDED
Binary file (18.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/tf_learner.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import pathlib
3
+ from typing import (
4
+ Any,
5
+ Callable,
6
+ Dict,
7
+ Hashable,
8
+ Sequence,
9
+ Tuple,
10
+ TYPE_CHECKING,
11
+ Union,
12
+ )
13
+
14
+ from ray.rllib.core.learner.learner import Learner
15
+ from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
16
+ from ray.rllib.core.rl_module.rl_module import (
17
+ RLModule,
18
+ RLModuleSpec,
19
+ )
20
+ from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule
21
+ from ray.rllib.policy.eager_tf_policy import _convert_to_tf
22
+ from ray.rllib.policy.sample_batch import MultiAgentBatch
23
+ from ray.rllib.utils.annotations import (
24
+ override,
25
+ OverrideToImplementCustomLogic,
26
+ )
27
+ from ray.rllib.utils.framework import try_import_tf
28
+ from ray.rllib.utils.typing import (
29
+ ModuleID,
30
+ Optimizer,
31
+ Param,
32
+ ParamDict,
33
+ StateDict,
34
+ TensorType,
35
+ )
36
+
37
+ if TYPE_CHECKING:
38
+ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
39
+
40
+ tf1, tf, tfv = try_import_tf()
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
+ class TfLearner(Learner):
46
+
47
+ framework: str = "tf2"
48
+
49
+ def __init__(self, **kwargs):
50
+ # by default in rllib we disable tf2 behavior
51
+ # This call re-enables it as it is needed for using
52
+ # this class.
53
+ try:
54
+ tf1.enable_v2_behavior()
55
+ except ValueError:
56
+ # This is a hack to avoid the error that happens when calling
57
+ # enable_v2_behavior after variables have already been created.
58
+ pass
59
+
60
+ super().__init__(**kwargs)
61
+
62
+ self._enable_tf_function = self.config.eager_tracing
63
+
64
+ # This is a placeholder which will be filled by
65
+ # `_make_distributed_strategy_if_necessary`.
66
+ self._strategy: tf.distribute.Strategy = None
67
+
68
+ @OverrideToImplementCustomLogic
69
+ @override(Learner)
70
+ def configure_optimizers_for_module(
71
+ self, module_id: ModuleID, config: "AlgorithmConfig" = None
72
+ ) -> None:
73
+ module = self._module[module_id]
74
+
75
+ # For this default implementation, the learning rate is handled by the
76
+ # attached lr Scheduler (controlled by self.config.lr, which can be a
77
+ # fixed value or a schedule setting).
78
+ optimizer = tf.keras.optimizers.Adam()
79
+ params = self.get_parameters(module)
80
+
81
+ # This isn't strictly necessary, but makes it so that if a checkpoint is
82
+ # computed before training actually starts, then it will be the same in
83
+ # shape / size as a checkpoint after training starts.
84
+ optimizer.build(module.trainable_variables)
85
+
86
+ # Register the created optimizer (under the default optimizer name).
87
+ self.register_optimizer(
88
+ module_id=module_id,
89
+ optimizer=optimizer,
90
+ params=params,
91
+ lr_or_lr_schedule=config.lr,
92
+ )
93
+
94
+ @override(Learner)
95
+ def compute_gradients(
96
+ self,
97
+ loss_per_module: Dict[str, TensorType],
98
+ gradient_tape: "tf.GradientTape",
99
+ **kwargs,
100
+ ) -> ParamDict:
101
+ total_loss = sum(loss_per_module.values())
102
+ grads = gradient_tape.gradient(total_loss, self._params)
103
+ return grads
104
+
105
+ @override(Learner)
106
+ def apply_gradients(self, gradients_dict: ParamDict) -> None:
107
+ # TODO (Avnishn, kourosh): apply gradients doesn't work in cases where
108
+ # only some agents have a sample batch that is passed but not others.
109
+ # This is probably because of the way that we are iterating over the
110
+ # parameters in the optim_to_param_dictionary.
111
+ for optimizer in self._optimizer_parameters:
112
+ optim_grad_dict = self.filter_param_dict_for_optimizer(
113
+ optimizer=optimizer, param_dict=gradients_dict
114
+ )
115
+ variable_list = []
116
+ gradient_list = []
117
+ for param_ref, grad in optim_grad_dict.items():
118
+ if grad is not None:
119
+ variable_list.append(self._params[param_ref])
120
+ gradient_list.append(grad)
121
+ optimizer.apply_gradients(zip(gradient_list, variable_list))
122
+
123
+ @override(Learner)
124
+ def restore_from_path(self, path: Union[str, pathlib.Path]) -> None:
125
+ # This operation is potentially very costly because a MultiRLModule is created
126
+ # at build time, destroyed, and then a new one is created from a checkpoint.
127
+ # However, it is necessary due to complications with the way that Ray Tune
128
+ # restores failed trials. When Tune restores a failed trial, it reconstructs the
129
+ # entire experiment from the initial config. Therefore, to reflect any changes
130
+ # made to the learner's modules, the module created by Tune is destroyed and
131
+ # then rebuilt from the checkpoint.
132
+ with self._strategy.scope():
133
+ super().restore_from_path(path)
134
+
135
+ @override(Learner)
136
+ def _get_optimizer_state(self) -> StateDict:
137
+ optim_state = {}
138
+ with tf.init_scope():
139
+ for name, optim in self._named_optimizers.items():
140
+ optim_state[name] = [var.numpy() for var in optim.variables()]
141
+ return optim_state
142
+
143
+ @override(Learner)
144
+ def _set_optimizer_state(self, state: StateDict) -> None:
145
+ for name, state_array in state.items():
146
+ if name not in self._named_optimizers:
147
+ raise ValueError(
148
+ f"Optimizer {name} in `state` is not known! "
149
+ f"Known optimizers are {self._named_optimizers.keys()}"
150
+ )
151
+ optim = self._named_optimizers[name]
152
+ optim.set_weights(state_array)
153
+
154
+ @override(Learner)
155
+ def get_param_ref(self, param: Param) -> Hashable:
156
+ return param.ref()
157
+
158
+ @override(Learner)
159
+ def get_parameters(self, module: RLModule) -> Sequence[Param]:
160
+ return list(module.trainable_variables)
161
+
162
+ @override(Learner)
163
+ def rl_module_is_compatible(self, module: RLModule) -> bool:
164
+ return isinstance(module, TfRLModule)
165
+
166
+ @override(Learner)
167
+ def _check_registered_optimizer(
168
+ self,
169
+ optimizer: Optimizer,
170
+ params: Sequence[Param],
171
+ ) -> None:
172
+ super()._check_registered_optimizer(optimizer, params)
173
+ if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
174
+ raise ValueError(
175
+ f"The optimizer ({optimizer}) is not a tf keras optimizer! "
176
+ "Only use tf.keras.optimizers.Optimizer subclasses for TfLearner."
177
+ )
178
+ for param in params:
179
+ if not isinstance(param, tf.Variable):
180
+ raise ValueError(
181
+ f"One of the parameters ({param}) in the registered optimizer "
182
+ "is not a tf.Variable!"
183
+ )
184
+
185
+ @override(Learner)
186
+ def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch:
187
+ batch = _convert_to_tf(batch.policy_batches)
188
+ length = max(len(b) for b in batch.values())
189
+ batch = MultiAgentBatch(batch, env_steps=length)
190
+ return batch
191
+
192
+ @override(Learner)
193
+ def add_module(
194
+ self,
195
+ *,
196
+ module_id: ModuleID,
197
+ module_spec: RLModuleSpec,
198
+ ) -> None:
199
+ # TODO(Avnishn):
200
+ # WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead
201
+ # currently. We will be working on improving this in the future, but for now
202
+ # please wrap `call_for_each_replica` or `experimental_run` or `run` inside a
203
+ # tf.function to get the best performance.
204
+ # I get this warning any time I add a new module. I see the warning a few times
205
+ # and then it disappears. I think that I will need to open an issue with the TF
206
+ # team.
207
+ with self._strategy.scope():
208
+ super().add_module(
209
+ module_id=module_id,
210
+ module_spec=module_spec,
211
+ )
212
+ if self._enable_tf_function:
213
+ self._possibly_traced_update = tf.function(
214
+ self._untraced_update, reduce_retracing=True
215
+ )
216
+
217
+ @override(Learner)
218
+ def remove_module(self, module_id: ModuleID, **kwargs) -> MultiRLModuleSpec:
219
+ with self._strategy.scope():
220
+ marl_spec = super().remove_module(module_id, **kwargs)
221
+
222
+ if self._enable_tf_function:
223
+ self._possibly_traced_update = tf.function(
224
+ self._untraced_update, reduce_retracing=True
225
+ )
226
+
227
+ return marl_spec
228
+
229
+ def _make_distributed_strategy_if_necessary(self) -> "tf.distribute.Strategy":
230
+ """Create a distributed strategy for the learner.
231
+
232
+ A stratgey is a tensorflow object that is used for distributing training and
233
+ gradient computation across multiple devices. By default, a no-op strategy is
234
+ used that is not distributed.
235
+
236
+ Returns:
237
+ A strategy for the learner to use for distributed training.
238
+
239
+ """
240
+ if self.config.num_learners > 1:
241
+ strategy = tf.distribute.MultiWorkerMirroredStrategy()
242
+ elif self.config.num_gpus_per_learner > 0:
243
+ # mirrored strategy is typically used for multi-gpu training
244
+ # on a single machine, however we can use it for single-gpu
245
+ devices = tf.config.list_logical_devices("GPU")
246
+ assert self.config.local_gpu_idx < len(devices), (
247
+ f"local_gpu_idx {self.config.local_gpu_idx} is not a valid GPU id or "
248
+ "is not available."
249
+ )
250
+ local_gpu = [devices[self.config.local_gpu_idx].name]
251
+ strategy = tf.distribute.MirroredStrategy(devices=local_gpu)
252
+ else:
253
+ # the default strategy is a no-op that can be used in the local mode
254
+ # cpu only case, build will override this if needed.
255
+ strategy = tf.distribute.get_strategy()
256
+ return strategy
257
+
258
+ @override(Learner)
259
+ def build(self) -> None:
260
+ """Build the TfLearner.
261
+
262
+ This method is specific TfLearner. Before running super() it sets the correct
263
+ distributing strategy with the right device, so that computational graph is
264
+ placed on the correct device. After running super(), depending on eager_tracing
265
+ flag it will decide whether to wrap the update function with tf.function or not.
266
+ """
267
+
268
+ # we call build anytime we make a learner, or load a learner from a checkpoint.
269
+ # we can't make a new strategy every time we build, so we only make one the
270
+ # first time build is called.
271
+ if not self._strategy:
272
+ self._strategy = self._make_distributed_strategy_if_necessary()
273
+
274
+ with self._strategy.scope():
275
+ super().build()
276
+
277
+ if self._enable_tf_function:
278
+ self._possibly_traced_update = tf.function(
279
+ self._untraced_update, reduce_retracing=True
280
+ )
281
+ else:
282
+ self._possibly_traced_update = self._untraced_update
283
+
284
+ @override(Learner)
285
+ def _update(self, batch: Dict) -> Tuple[Any, Any, Any]:
286
+ return self._possibly_traced_update(batch)
287
+
288
+ def _untraced_update(
289
+ self,
290
+ batch: Dict,
291
+ # TODO: Figure out, why _ray_trace_ctx=None helps to prevent a crash in
292
+ # eager_tracing=True mode.
293
+ # It seems there may be a clash between the traced-by-tf function and the
294
+ # traced-by-ray functions (for making the TfLearner class a ray actor).
295
+ _ray_trace_ctx=None,
296
+ ):
297
+ # Activate tensor-mode on our MetricsLogger.
298
+ self.metrics.activate_tensor_mode()
299
+
300
+ def helper(_batch):
301
+ with tf.GradientTape(persistent=True) as tape:
302
+ fwd_out = self._module.forward_train(_batch)
303
+ loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=_batch)
304
+ gradients = self.compute_gradients(loss_per_module, gradient_tape=tape)
305
+ del tape
306
+ postprocessed_gradients = self.postprocess_gradients(gradients)
307
+ self.apply_gradients(postprocessed_gradients)
308
+
309
+ # Deactivate tensor-mode on our MetricsLogger and collect the (tensor)
310
+ # results.
311
+ return fwd_out, loss_per_module, self.metrics.deactivate_tensor_mode()
312
+
313
+ return self._strategy.run(helper, args=(batch,))
314
+
315
+ @override(Learner)
316
+ def _get_tensor_variable(self, value, dtype=None, trainable=False) -> "tf.Tensor":
317
+ return tf.Variable(
318
+ value,
319
+ trainable=trainable,
320
+ dtype=(
321
+ dtype
322
+ or (
323
+ tf.float32
324
+ if isinstance(value, float)
325
+ else tf.int32
326
+ if isinstance(value, int)
327
+ else None
328
+ )
329
+ ),
330
+ )
331
+
332
+ @staticmethod
333
+ @override(Learner)
334
+ def _get_optimizer_lr(optimizer: "tf.Optimizer") -> float:
335
+ return optimizer.lr
336
+
337
+ @staticmethod
338
+ @override(Learner)
339
+ def _set_optimizer_lr(optimizer: "tf.Optimizer", lr: float) -> None:
340
+ # When tf creates the optimizer, it seems to detach the optimizer's lr value
341
+ # from the given tf variable.
342
+ # Thus, updating this variable is NOT sufficient to update the actual
343
+ # optimizer's learning rate, so we have to explicitly set it here inside the
344
+ # optimizer object.
345
+ optimizer.lr = lr
346
+
347
+ @staticmethod
348
+ @override(Learner)
349
+ def _get_clip_function() -> Callable:
350
+ from ray.rllib.utils.tf_utils import clip_gradients
351
+
352
+ return clip_gradients
353
+
354
+ @staticmethod
355
+ @override(Learner)
356
+ def _get_global_norm_function() -> Callable:
357
+ return tf.linalg.global_norm
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (201 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__pycache__/torch_learner.cpython-311.pyc ADDED
Binary file (31.8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/torch_learner.py ADDED
@@ -0,0 +1,664 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import logging
3
+ from typing import (
4
+ Any,
5
+ Callable,
6
+ Dict,
7
+ Hashable,
8
+ Optional,
9
+ Sequence,
10
+ Tuple,
11
+ )
12
+
13
+ from ray.rllib.algorithms.algorithm_config import (
14
+ AlgorithmConfig,
15
+ TorchCompileWhatToCompile,
16
+ )
17
+ from ray.rllib.core.columns import Columns
18
+ from ray.rllib.core.learner.learner import Learner, LR_KEY
19
+ from ray.rllib.core.rl_module.multi_rl_module import (
20
+ MultiRLModule,
21
+ MultiRLModuleSpec,
22
+ )
23
+ from ray.rllib.core.rl_module.rl_module import (
24
+ RLModule,
25
+ RLModuleSpec,
26
+ )
27
+ from ray.rllib.core.rl_module.torch.torch_rl_module import (
28
+ TorchCompileConfig,
29
+ TorchDDPRLModule,
30
+ TorchRLModule,
31
+ )
32
+ from ray.rllib.policy.sample_batch import MultiAgentBatch
33
+ from ray.rllib.utils.annotations import (
34
+ override,
35
+ OverrideToImplementCustomLogic,
36
+ OverrideToImplementCustomLogic_CallToSuperRecommended,
37
+ )
38
+ from ray.rllib.utils.framework import get_device, try_import_torch
39
+ from ray.rllib.utils.metrics import (
40
+ ALL_MODULES,
41
+ DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
42
+ NUM_TRAINABLE_PARAMETERS,
43
+ NUM_NON_TRAINABLE_PARAMETERS,
44
+ WEIGHTS_SEQ_NO,
45
+ )
46
+ from ray.rllib.utils.numpy import convert_to_numpy
47
+ from ray.rllib.utils.torch_utils import convert_to_torch_tensor
48
+ from ray.rllib.utils.typing import (
49
+ ModuleID,
50
+ Optimizer,
51
+ Param,
52
+ ParamDict,
53
+ ShouldModuleBeUpdatedFn,
54
+ StateDict,
55
+ TensorType,
56
+ )
57
+
58
+ torch, nn = try_import_torch()
59
+ logger = logging.getLogger(__name__)
60
+
61
+
62
+ class TorchLearner(Learner):
63
+
64
+ framework: str = "torch"
65
+
66
+ def __init__(self, **kwargs):
67
+ super().__init__(**kwargs)
68
+
69
+ # Whether to compile the RL Module of this learner. This implies that the.
70
+ # forward_train method of the RL Module will be compiled. Further more,
71
+ # other forward methods of the RL Module will be compiled on demand.
72
+ # This is assumed to not happen, since other forwrad methods are not expected
73
+ # to be used during training.
74
+ self._torch_compile_forward_train = False
75
+ self._torch_compile_cfg = None
76
+ # Whether to compile the `_uncompiled_update` method of this learner. This
77
+ # implies that everything within `_uncompiled_update` will be compiled,
78
+ # not only the forward_train method of the RL Module.
79
+ # Note that this is experimental.
80
+ # Note that this requires recompiling the forward methods once we add/remove
81
+ # RL Modules.
82
+ self._torch_compile_complete_update = False
83
+ if self.config.torch_compile_learner:
84
+ if (
85
+ self.config.torch_compile_learner_what_to_compile
86
+ == TorchCompileWhatToCompile.COMPLETE_UPDATE
87
+ ):
88
+ self._torch_compile_complete_update = True
89
+ self._compiled_update_initialized = False
90
+ else:
91
+ self._torch_compile_forward_train = True
92
+
93
+ self._torch_compile_cfg = TorchCompileConfig(
94
+ torch_dynamo_backend=self.config.torch_compile_learner_dynamo_backend,
95
+ torch_dynamo_mode=self.config.torch_compile_learner_dynamo_mode,
96
+ )
97
+
98
+ # Loss scalers for mixed precision training. Map optimizer names to
99
+ # associated torch GradScaler objects.
100
+ self._grad_scalers = None
101
+ if self.config._torch_grad_scaler_class:
102
+ self._grad_scalers = defaultdict(
103
+ lambda: self.config._torch_grad_scaler_class()
104
+ )
105
+ self._lr_schedulers = {}
106
+ self._lr_scheduler_classes = None
107
+ if self.config._torch_lr_scheduler_classes:
108
+ self._lr_scheduler_classes = self.config._torch_lr_scheduler_classes
109
+
110
+ @OverrideToImplementCustomLogic
111
+ @override(Learner)
112
+ def configure_optimizers_for_module(
113
+ self,
114
+ module_id: ModuleID,
115
+ config: "AlgorithmConfig" = None,
116
+ ) -> None:
117
+ module = self._module[module_id]
118
+
119
+ # For this default implementation, the learning rate is handled by the
120
+ # attached lr Scheduler (controlled by self.config.lr, which can be a
121
+ # fixed value or a schedule setting).
122
+ params = self.get_parameters(module)
123
+ optimizer = torch.optim.Adam(params)
124
+
125
+ # Register the created optimizer (under the default optimizer name).
126
+ self.register_optimizer(
127
+ module_id=module_id,
128
+ optimizer=optimizer,
129
+ params=params,
130
+ lr_or_lr_schedule=config.lr,
131
+ )
132
+
133
+ def _uncompiled_update(
134
+ self,
135
+ batch: Dict,
136
+ **kwargs,
137
+ ):
138
+ """Performs a single update given a batch of data."""
139
+ # Activate tensor-mode on our MetricsLogger.
140
+ self.metrics.activate_tensor_mode()
141
+
142
+ # TODO (sven): Causes weird cuda error when WandB is used.
143
+ # Diagnosis thus far:
144
+ # - All peek values during metrics.reduce are non-tensors.
145
+ # - However, in impala.py::training_step(), a tensor does arrive after learner
146
+ # group.update_from_episodes(), so somehow, there is still a race condition
147
+ # possible (learner, which performs the reduce() and learner thread, which
148
+ # performs the logging of tensors into metrics logger).
149
+ self._compute_off_policyness(batch)
150
+
151
+ fwd_out = self.module.forward_train(batch)
152
+ loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=batch)
153
+
154
+ gradients = self.compute_gradients(loss_per_module)
155
+ postprocessed_gradients = self.postprocess_gradients(gradients)
156
+ self.apply_gradients(postprocessed_gradients)
157
+
158
+ # Deactivate tensor-mode on our MetricsLogger and collect the (tensor)
159
+ # results.
160
+ return fwd_out, loss_per_module, self.metrics.deactivate_tensor_mode()
161
+
162
+ @override(Learner)
163
+ def compute_gradients(
164
+ self, loss_per_module: Dict[ModuleID, TensorType], **kwargs
165
+ ) -> ParamDict:
166
+ for optim in self._optimizer_parameters:
167
+ # `set_to_none=True` is a faster way to zero out the gradients.
168
+ optim.zero_grad(set_to_none=True)
169
+
170
+ if self._grad_scalers is not None:
171
+ total_loss = sum(
172
+ self._grad_scalers[mid].scale(loss)
173
+ for mid, loss in loss_per_module.items()
174
+ )
175
+ else:
176
+ total_loss = sum(loss_per_module.values())
177
+
178
+ total_loss.backward()
179
+ grads = {pid: p.grad for pid, p in self._params.items()}
180
+
181
+ return grads
182
+
183
    @override(Learner)
    def apply_gradients(self, gradients_dict: ParamDict) -> None:
        """Assigns the given gradients to the parameters and steps all optimizers.

        Depending on `config.torch_skip_nan_gradients`, non-finite (`nan`/`inf`)
        gradients are either replaced by zeros (False) or the affected optimizer
        step is skipped entirely (True). When grad scalers are configured
        (mixed precision), stepping goes through the scaler instead of calling
        `optim.step()` directly.

        Args:
            gradients_dict: Mapping from parameter ref (see `get_param_ref`) to
                the gradient tensor to write into that parameter's `.grad`.
        """
        # Set the gradient of the parameters.
        for pid, grad in gradients_dict.items():
            # If updates should not be skipped turn `nan` and `inf` gradients to zero.
            if (
                not torch.isfinite(grad).all()
                and not self.config.torch_skip_nan_gradients
            ):
                # Warn the user about `nan` gradients.
                logger.warning(f"Gradients {pid} contain `nan/inf` values.")
                # NOTE(review): this inner check is always True here, since the
                # enclosing `if` already required
                # `not self.config.torch_skip_nan_gradients` — confirm intent.
                if not self.config.torch_skip_nan_gradients:
                    logger.warning(
                        "Setting `nan/inf` gradients to zero. If updates with "
                        "`nan/inf` gradients should not be set to zero and instead "
                        "the update be skipped entirely set `torch_skip_nan_gradients` "
                        "to `True`."
                    )
                # If necessary turn `nan` gradients to zero. Note this can corrupt the
                # internal state of the optimizer, if many `nan` gradients occur.
                self._params[pid].grad = torch.nan_to_num(grad)
            # Otherwise, use the gradient as is.
            else:
                self._params[pid].grad = grad

        # For each optimizer call its step function.
        for module_id, optimizer_names in self._module_optimizers.items():
            for optimizer_name in optimizer_names:
                optim = self.get_optimizer(module_id, optimizer_name)
                # If we have learning rate schedulers for a module add them, if
                # necessary (lazy initialization on first step).
                if self._lr_scheduler_classes is not None:
                    if (
                        module_id not in self._lr_schedulers
                        or optimizer_name not in self._lr_schedulers[module_id]
                    ):
                        # Set for each module and optimizer a scheduler.
                        self._lr_schedulers[module_id] = {optimizer_name: []}
                        # If the classes are in a dictionary each module might have
                        # a different set of schedulers.
                        if isinstance(self._lr_scheduler_classes, dict):
                            scheduler_classes = self._lr_scheduler_classes[module_id]
                        # Else, each module has the same learning rate schedulers.
                        else:
                            scheduler_classes = self._lr_scheduler_classes
                        # Initialize and add the schedulers.
                        for scheduler_class in scheduler_classes:
                            self._lr_schedulers[module_id][optimizer_name].append(
                                scheduler_class(optim)
                            )

                # Step through the scaler (unscales gradients, if applicable).
                if self._grad_scalers is not None:
                    scaler = self._grad_scalers[module_id]
                    scaler.step(optim)
                    self.metrics.log_value(
                        (module_id, "_torch_grad_scaler_current_scale"),
                        scaler.get_scale(),
                        window=1,  # snapshot in time, no EMA/mean.
                    )
                    # Update the scaler.
                    scaler.update()
                # `step` the optimizer (default), but only if all gradients are finite.
                elif all(
                    param.grad is None or torch.isfinite(param.grad).all()
                    for group in optim.param_groups
                    for param in group["params"]
                ):
                    optim.step()
                # If gradients are not all finite warn the user that the update will be
                # skipped. NOTE(review): this branch omits the `param.grad is None`
                # guard used above, so a None grad here would raise — confirm
                # whether None grads can reach this point.
                elif not all(
                    torch.isfinite(param.grad).all()
                    for group in optim.param_groups
                    for param in group["params"]
                ):
                    logger.warning(
                        "Skipping this update. If updates with `nan/inf` gradients "
                        "should not be skipped entirely and instead `nan/inf` "
                        "gradients set to `zero` set `torch_skip_nan_gradients` to "
                        "`False`."
                    )
266
+
267
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    @override(Learner)
    def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
        """Called after gradient-based updates are completed.

        Should be overridden to implement custom cleanup-, logging-, or non-gradient-
        based Learner/RLModule update logic after(!) gradient-based updates have been
        completed.

        Note, for `framework="torch"` users can register
        `torch.optim.lr_scheduler.LRScheduler` via
        `AlgorithmConfig._torch_lr_scheduler_classes`. These schedulers need to be
        stepped here after gradient updates and reported.

        Args:
            timesteps: Timesteps dict, which must have the key
                `NUM_ENV_STEPS_SAMPLED_LIFETIME`.
                # TODO (sven): Make this a more formal structure with its own type.
        """

        # If we have no `torch.optim.lr_scheduler.LRScheduler` registered call the
        # `super()`'s method to update RLlib's learning rate schedules.
        if not self._lr_schedulers:
            return super().after_gradient_based_update(timesteps=timesteps)

        # Only update this optimizer's lr, if a scheduler has been registered
        # along with it.
        for module_id, optimizer_names in self._module_optimizers.items():
            for optimizer_name in optimizer_names:
                # If learning rate schedulers are provided step them here. Note,
                # stepping them in `TorchLearner.apply_gradients` updates the
                # learning rates during minibatch updates; we want to update
                # between whole batch updates.
                if (
                    module_id in self._lr_schedulers
                    and optimizer_name in self._lr_schedulers[module_id]
                ):
                    for scheduler in self._lr_schedulers[module_id][optimizer_name]:
                        scheduler.step()
                # Log the (possibly just updated) lr for this optimizer. Note:
                # this logs for every optimizer, scheduled or not.
                optimizer = self.get_optimizer(module_id, optimizer_name)
                self.metrics.log_value(
                    # Cut out the module ID from the beginning since it's already
                    # part of the key sequence: (ModuleID, "[optim name]_lr").
                    key=(
                        module_id,
                        f"{optimizer_name[len(module_id) + 1:]}_{LR_KEY}",
                    ),
                    value=convert_to_numpy(self._get_optimizer_lr(optimizer)),
                    window=1,
                )
317
+
318
+ @override(Learner)
319
+ def _get_optimizer_state(self) -> StateDict:
320
+ ret = {}
321
+ for name, optim in self._named_optimizers.items():
322
+ ret[name] = {
323
+ "module_id": self._optimizer_name_to_module[name],
324
+ "state": convert_to_numpy(optim.state_dict()),
325
+ }
326
+ return ret
327
+
328
    @override(Learner)
    def _set_optimizer_state(self, state: StateDict) -> None:
        """Restores optimizer states produced by `_get_optimizer_state`.

        Lazily creates optimizers (via `configure_optimizers_for_module`) for
        module IDs that exist in this Learner's MultiRLModule but do not yet
        have a named optimizer; entries for modules not present here are
        silently skipped.

        Args:
            state: Dict mapping optimizer name to {"module_id": ..., "state": ...}.
        """
        for name, state_dict in state.items():
            # Ignore updating optimizers matching to submodules not present in this
            # Learner's MultiRLModule.
            module_id = state_dict["module_id"]
            if name not in self._named_optimizers and module_id in self.module:
                self.configure_optimizers_for_module(
                    module_id=module_id,
                    config=self.config.get_config_for_module(module_id=module_id),
                )
            if name in self._named_optimizers:
                # Move the stored (numpy) state onto this Learner's device.
                self._named_optimizers[name].load_state_dict(
                    convert_to_torch_tensor(state_dict["state"], device=self._device)
                )
343
+
344
    @override(Learner)
    def get_param_ref(self, param: Param) -> Hashable:
        """Returns a hashable ref for `param`; torch tensors are hashable, so the
        parameter itself serves as its own ref."""
        return param
347
+
348
+ @override(Learner)
349
+ def get_parameters(self, module: RLModule) -> Sequence[Param]:
350
+ return list(module.parameters())
351
+
352
    @override(Learner)
    def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch:
        """Converts the policy batches to torch tensors on this Learner's device.

        Args:
            batch: The MultiAgentBatch whose `policy_batches` to convert.

        Returns:
            A new MultiAgentBatch with torch tensors; `env_steps` is set to the
            length of the longest per-policy batch.
        """
        batch = convert_to_torch_tensor(batch.policy_batches, device=self._device)
        # TODO (sven): This computation of `env_steps` is not accurate!
        length = max(len(b) for b in batch.values())
        batch = MultiAgentBatch(batch, env_steps=length)
        return batch
359
+
360
+ @override(Learner)
361
+ def add_module(
362
+ self,
363
+ *,
364
+ module_id: ModuleID,
365
+ # TODO (sven): Rename to `rl_module_spec`.
366
+ module_spec: RLModuleSpec,
367
+ config_overrides: Optional[Dict] = None,
368
+ new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
369
+ ) -> MultiRLModuleSpec:
370
+ # Call super's add_module method.
371
+ marl_spec = super().add_module(
372
+ module_id=module_id,
373
+ module_spec=module_spec,
374
+ config_overrides=config_overrides,
375
+ new_should_module_be_updated=new_should_module_be_updated,
376
+ )
377
+
378
+ # we need to ddpify the module that was just added to the pool
379
+ module = self._module[module_id]
380
+
381
+ if self._torch_compile_forward_train:
382
+ module.compile(self._torch_compile_cfg)
383
+ elif self._torch_compile_complete_update:
384
+ # When compiling the update, we need to reset and recompile
385
+ # _uncompiled_update every time we add/remove a module anew.
386
+ torch._dynamo.reset()
387
+ self._compiled_update_initialized = False
388
+ self._possibly_compiled_update = torch.compile(
389
+ self._uncompiled_update,
390
+ backend=self._torch_compile_cfg.torch_dynamo_backend,
391
+ mode=self._torch_compile_cfg.torch_dynamo_mode,
392
+ **self._torch_compile_cfg.kwargs,
393
+ )
394
+
395
+ if isinstance(module, TorchRLModule):
396
+ self._module[module_id].to(self._device)
397
+ if self.distributed:
398
+ if (
399
+ self._torch_compile_complete_update
400
+ or self._torch_compile_forward_train
401
+ ):
402
+ raise ValueError(
403
+ "Using torch distributed and torch compile "
404
+ "together tested for now. Please disable "
405
+ "torch compile."
406
+ )
407
+ self._module.add_module(
408
+ module_id,
409
+ TorchDDPRLModule(module, **self.config.torch_ddp_kwargs),
410
+ override=True,
411
+ )
412
+
413
+ self._log_trainable_parameters()
414
+
415
+ return marl_spec
416
+
417
+ @override(Learner)
418
+ def remove_module(self, module_id: ModuleID, **kwargs) -> MultiRLModuleSpec:
419
+ marl_spec = super().remove_module(module_id, **kwargs)
420
+
421
+ if self._torch_compile_complete_update:
422
+ # When compiling the update, we need to reset and recompile
423
+ # _uncompiled_update every time we add/remove a module anew.
424
+ torch._dynamo.reset()
425
+ self._compiled_update_initialized = False
426
+ self._possibly_compiled_update = torch.compile(
427
+ self._uncompiled_update,
428
+ backend=self._torch_compile_cfg.torch_dynamo_backend,
429
+ mode=self._torch_compile_cfg.torch_dynamo_mode,
430
+ **self._torch_compile_cfg.kwargs,
431
+ )
432
+
433
+ self._log_trainable_parameters()
434
+
435
+ return marl_spec
436
+
437
+ @override(Learner)
438
+ def build(self) -> None:
439
+ """Builds the TorchLearner.
440
+
441
+ This method is specific to TorchLearner. Before running super() it will
442
+ initialize the device properly based on `self.config`, so that `_make_module()`
443
+ can place the created module on the correct device. After running super() it
444
+ wraps the module in a TorchDDPRLModule if `config.num_learners > 0`.
445
+ Note, in inherited classes it is advisable to call the parent's `build()`
446
+ after setting up all variables because `configure_optimizer_for_module` is
447
+ called in this `Learner.build()`.
448
+ """
449
+ self._device = get_device(self.config, self.config.num_gpus_per_learner)
450
+
451
+ super().build()
452
+
453
+ if self._torch_compile_complete_update:
454
+ torch._dynamo.reset()
455
+ self._compiled_update_initialized = False
456
+ self._possibly_compiled_update = torch.compile(
457
+ self._uncompiled_update,
458
+ backend=self._torch_compile_cfg.torch_dynamo_backend,
459
+ mode=self._torch_compile_cfg.torch_dynamo_mode,
460
+ **self._torch_compile_cfg.kwargs,
461
+ )
462
+ else:
463
+ if self._torch_compile_forward_train:
464
+ if isinstance(self._module, TorchRLModule):
465
+ self._module.compile(self._torch_compile_cfg)
466
+ elif isinstance(self._module, MultiRLModule):
467
+ for module in self._module._rl_modules.values():
468
+ # Compile only TorchRLModules, e.g. we don't want to compile
469
+ # a RandomRLModule.
470
+ if isinstance(self._module, TorchRLModule):
471
+ module.compile(self._torch_compile_cfg)
472
+ else:
473
+ raise ValueError(
474
+ "Torch compile is only supported for TorchRLModule and "
475
+ "MultiRLModule."
476
+ )
477
+
478
+ self._possibly_compiled_update = self._uncompiled_update
479
+
480
+ self._make_modules_ddp_if_necessary()
481
+
482
    @override(Learner)
    def _update(self, batch: Dict[str, Any]) -> Tuple[Any, Any, Any]:
        """Runs one update step, dispatching to the compiled or eager update fn.

        Args:
            batch: The (already device-converted) training batch.

        Returns:
            Whatever the (possibly compiled) update function returns (a 3-tuple).
        """
        # The first time we call _update after building the learner or
        # adding/removing models, we update with the uncompiled update method.
        # This makes it so that any variables that may be created during the first
        # update step are already there when compiling. More specifically,
        # this avoids errors that occur around using defaultdicts with
        # torch.compile().
        if (
            self._torch_compile_complete_update
            and not self._compiled_update_initialized
        ):
            self._compiled_update_initialized = True
            return self._uncompiled_update(batch)
        else:
            return self._possibly_compiled_update(batch)
498
+
499
    @OverrideToImplementCustomLogic
    def _make_modules_ddp_if_necessary(self) -> None:
        """Default logic for (maybe) making all Modules within self._module DDP."""

        # If the module is a MultiRLModule and nn.Module we can simply assume
        # all the submodules are registered. Otherwise, we need to loop through
        # each submodule and move it to the correct device.
        # TODO (Kourosh): This can result in missing modules if the user does not
        # register them in the MultiRLModule. We should find a better way to
        # handle this.
        # NOTE(review): this checks `num_learners > 1` while `build()`'s
        # docstring mentions wrapping "if config.num_learners > 0" — confirm
        # which threshold is intended.
        if self.config.num_learners > 1:
            # Single agent module: Convert to `TorchDDPRLModule`.
            if isinstance(self._module, TorchRLModule):
                self._module = TorchDDPRLModule(
                    self._module, **self.config.torch_ddp_kwargs
                )
            # Multi agent module: Convert each submodule to `TorchDDPRLModule`.
            else:
                assert isinstance(self._module, MultiRLModule)
                for key in self._module.keys():
                    sub_module = self._module[key]
                    if isinstance(sub_module, TorchRLModule):
                        # Wrap and override the module ID key in self._module.
                        self._module.add_module(
                            key,
                            TorchDDPRLModule(
                                sub_module, **self.config.torch_ddp_kwargs
                            ),
                            override=True,
                        )
529
+
530
    def rl_module_is_compatible(self, module: RLModule) -> bool:
        """Returns True iff `module` is a torch `nn.Module` (trainable by this
        Learner)."""
        return isinstance(module, nn.Module)
532
+
533
+ @override(Learner)
534
+ def _check_registered_optimizer(
535
+ self,
536
+ optimizer: Optimizer,
537
+ params: Sequence[Param],
538
+ ) -> None:
539
+ super()._check_registered_optimizer(optimizer, params)
540
+ if not isinstance(optimizer, torch.optim.Optimizer):
541
+ raise ValueError(
542
+ f"The optimizer ({optimizer}) is not a torch.optim.Optimizer! "
543
+ "Only use torch.optim.Optimizer subclasses for TorchLearner."
544
+ )
545
+ for param in params:
546
+ if not isinstance(param, torch.Tensor):
547
+ raise ValueError(
548
+ f"One of the parameters ({param}) in the registered optimizer "
549
+ "is not a torch.Tensor!"
550
+ )
551
+
552
    @override(Learner)
    def _make_module(self) -> MultiRLModule:
        """Creates the MultiRLModule (via super) and moves it to this Learner's
        device."""
        module = super()._make_module()
        self._map_module_to_device(module)
        return module
557
+
558
+ def _map_module_to_device(self, module: MultiRLModule) -> None:
559
+ """Moves the module to the correct device."""
560
+ if isinstance(module, torch.nn.Module):
561
+ module.to(self._device)
562
+ else:
563
+ for key in module.keys():
564
+ if isinstance(module[key], torch.nn.Module):
565
+ module[key].to(self._device)
566
+
567
    @override(Learner)
    def _log_trainable_parameters(self) -> None:
        """Logs per-module and total (non-)trainable parameter counts as metrics.

        Only TorchRLModule submodules are counted; trainability is determined by
        each parameter's `requires_grad` flag.
        """
        # Log number of non-trainable and trainable parameters of our RLModule.
        num_trainable_params = {
            (mid, NUM_TRAINABLE_PARAMETERS): sum(
                p.numel() for p in rlm.parameters() if p.requires_grad
            )
            for mid, rlm in self.module._rl_modules.items()
            if isinstance(rlm, TorchRLModule)
        }
        num_non_trainable_params = {
            (mid, NUM_NON_TRAINABLE_PARAMETERS): sum(
                p.numel() for p in rlm.parameters() if not p.requires_grad
            )
            for mid, rlm in self.module._rl_modules.items()
            if isinstance(rlm, TorchRLModule)
        }

        # Log per-module counts plus ALL_MODULES totals in one call.
        self.metrics.log_dict(
            {
                **{
                    (ALL_MODULES, NUM_TRAINABLE_PARAMETERS): sum(
                        num_trainable_params.values()
                    ),
                    (ALL_MODULES, NUM_NON_TRAINABLE_PARAMETERS): sum(
                        num_non_trainable_params.values()
                    ),
                },
                **num_trainable_params,
                **num_non_trainable_params,
            }
        )
599
+
600
    def _compute_off_policyness(self, batch):
        """Logs, per module, how many weight versions the batch lags behind.

        For each module batch carrying a `WEIGHTS_SEQ_NO` column, computes
        `self._weights_seq_no - batch_seq_no` and logs its mean (masked by
        `Columns.LOSS_MASK` when present).

        Args:
            batch: Mapping from ModuleID to that module's (tensor) batch.
        """
        # Log off-policy'ness of this batch wrt the current weights.
        off_policyness = {
            (mid, DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY): (
                (self._weights_seq_no - module_batch[WEIGHTS_SEQ_NO]).float()
            )
            for mid, module_batch in batch.items()
            if WEIGHTS_SEQ_NO in module_batch
        }
        for key in off_policyness.keys():
            mid = key[0]
            if Columns.LOSS_MASK not in batch[mid]:
                off_policyness[key] = torch.mean(off_policyness[key])
            else:
                # Average only over the valid (unmasked) timesteps.
                mask = batch[mid][Columns.LOSS_MASK]
                num_valid = torch.sum(mask)
                off_policyness[key] = torch.sum(off_policyness[key][mask]) / num_valid
        self.metrics.log_dict(off_policyness, window=1)
618
+
619
+ @override(Learner)
620
+ def _get_tensor_variable(
621
+ self, value, dtype=None, trainable=False
622
+ ) -> "torch.Tensor":
623
+ tensor = torch.tensor(
624
+ value,
625
+ requires_grad=trainable,
626
+ device=self._device,
627
+ dtype=(
628
+ dtype
629
+ or (
630
+ torch.float32
631
+ if isinstance(value, float)
632
+ else torch.int32
633
+ if isinstance(value, int)
634
+ else None
635
+ )
636
+ ),
637
+ )
638
+ return nn.Parameter(tensor) if trainable else tensor
639
+
640
    @staticmethod
    @override(Learner)
    def _get_optimizer_lr(optimizer: "torch.optim.Optimizer") -> float:
        """Returns the learning rate of `optimizer`'s first param group.

        Note: implicitly returns None if the optimizer has no param groups.
        """
        for g in optimizer.param_groups:
            return g["lr"]
645
+
646
    @staticmethod
    @override(Learner)
    def _set_optimizer_lr(optimizer: "torch.optim.Optimizer", lr: float) -> None:
        """Sets `lr` as the learning rate on every param group of `optimizer`."""
        for g in optimizer.param_groups:
            g["lr"] = lr
651
+
652
    @staticmethod
    @override(Learner)
    def _get_clip_function() -> Callable:
        """Returns the torch gradient-clipping utility function."""
        # Imported lazily to avoid a module-level torch_utils dependency.
        from ray.rllib.utils.torch_utils import clip_gradients

        return clip_gradients
658
+
659
    @staticmethod
    @override(Learner)
    def _get_global_norm_function() -> Callable:
        """Returns the torch global-norm computation utility function."""
        # Imported lazily to avoid a module-level torch_utils dependency.
        from ray.rllib.utils.torch_utils import compute_global_norm

        return compute_global_norm
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ from ray.rllib.utils.framework import try_import_torch
4
+ from ray.rllib.utils.typing import NetworkType
5
+ from ray.util import PublicAPI
6
+
7
+
8
+ torch, _ = try_import_torch()
9
+
10
+
11
def make_target_network(main_net: NetworkType) -> NetworkType:
    """Creates a (deep) copy of `main_net` (including synched weights) and returns it.

    Args:
        main_net: The main network to return a target network for.

    Returns:
        The copy of `main_net` that can be used as a target net. Note that the weights
        of the returned net are already synched (identical) with `main_net`.

    Raises:
        ValueError: If `main_net` is not a torch `nn.Module`.
    """
    if not isinstance(main_net, torch.nn.Module):
        raise ValueError(f"Unsupported framework for given `main_net` {main_net}!")

    # Deepcopying already synchronizes all weights with the main net.
    target_net = copy.deepcopy(main_net)
    # Target nets are never trained via backprop (they are updated by copying
    # weights), so freeze all their parameters.
    target_net.requires_grad_(False)
    return target_net
30
+
31
+
32
@PublicAPI(stability="beta")
def update_target_network(
    *,
    main_net: NetworkType,
    target_net: NetworkType,
    tau: float,
) -> None:
    """Updates a target network (from a "main" network) using Polyak averaging.

    Thereby:
    new_target_net_weight = (
        tau * main_net_weight + (1.0 - tau) * current_target_net_weight
    )

    Args:
        main_net: The nn.Module to update from.
        target_net: The target network to update.
        tau: The tau value to use in the Polyak averaging formula. Use 1.0 for a
            complete sync of the weights (target and main net will be the exact same
            after updating).

    Raises:
        ValueError: If `main_net` is not a torch `nn.Module`.
    """
    if not isinstance(main_net, torch.nn.Module):
        raise ValueError(f"Unsupported framework for given `main_net` {main_net}!")

    # Delegate to the torch-specific implementation (imported lazily).
    from ray.rllib.utils.torch_utils import update_target_network as _torch_update

    _torch_update(main_net=main_net, target_net=target_net, tau=tau)
.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (200 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/heads.cpython-311.pyc ADDED
Binary file (9.66 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/utils.cpython-311.pyc ADDED
Binary file (4.04 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__init__.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import re
3
+
4
+ from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
5
+ from ray.rllib.core.rl_module.multi_rl_module import (
6
+ MultiRLModule,
7
+ MultiRLModuleSpec,
8
+ )
9
+ from ray.util import log_once
10
+ from ray.util.annotations import DeveloperAPI
11
+
12
+ logger = logging.getLogger("ray.rllib")
13
+
14
+
15
+ @DeveloperAPI
16
+ def validate_module_id(policy_id: str, error: bool = False) -> None:
17
+ """Makes sure the given `policy_id` is valid.
18
+
19
+ Args:
20
+ policy_id: The Policy ID to check.
21
+ IMPORTANT: Must not contain characters that
22
+ are also not allowed in Unix/Win filesystems, such as: `<>:"/\\|?*`
23
+ or a dot `.` or space ` ` at the end of the ID.
24
+ error: Whether to raise an error (ValueError) or a warning in case of an
25
+ invalid `policy_id`.
26
+
27
+ Raises:
28
+ ValueError: If the given `policy_id` is not a valid one and `error` is True.
29
+ """
30
+ if (
31
+ not isinstance(policy_id, str)
32
+ or len(policy_id) == 0
33
+ or re.search('[<>:"/\\\\|?]', policy_id)
34
+ or policy_id[-1] in (" ", ".")
35
+ ):
36
+ msg = (
37
+ f"PolicyID `{policy_id}` not valid! IDs must be a non-empty string, "
38
+ "must not contain characters that are also disallowed file- or directory "
39
+ "names on Unix/Windows and must not end with a dot `.` or a space ` `."
40
+ )
41
+ if error:
42
+ raise ValueError(msg)
43
+ elif log_once("invalid_policy_id"):
44
+ logger.warning(msg)
45
+
46
+
47
+ __all__ = [
48
+ "MultiRLModule",
49
+ "MultiRLModuleSpec",
50
+ "RLModule",
51
+ "RLModuleSpec",
52
+ "validate_module_id",
53
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.31 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/default_model_config.cpython-311.pyc ADDED
Binary file (5.21 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/multi_rl_module.cpython-311.pyc ADDED
Binary file (43.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/rl_module.cpython-311.pyc ADDED
Binary file (36.3 kB). View file