Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/connector_pipeline_v2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/connector.py +478 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/connector_pipeline_v2.py +394 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/connector_v2.py +1017 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__init__.py +40 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__pycache__/mean_std_filter.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__pycache__/prev_actions_prev_rewards.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/env_to_module_pipeline.py +55 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/flatten_observations.py +208 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/frame_stacking.py +6 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/mean_std_filter.py +253 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/observation_preprocessor.py +80 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/prev_actions_prev_rewards.py +168 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/write_observations_to_episodes.py +131 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__init__.py +30 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/listify_data_for_vector_env.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/get_actions.py +91 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/listify_data_for_vector_env.py +82 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/module_to_env_pipeline.py +7 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/normalize_and_clip_actions.py +146 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/remove_single_ts_time_rank_from_batch.py +70 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/unbatch_to_individual_items.py +92 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/registry.py +46 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/util.py +170 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/learner/__init__.py +8 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/learner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/learner_group.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/learner/learner.py +1795 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/learner/learner_group.py +1030 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__pycache__/tf_learner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/tf_learner.py +357 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__pycache__/torch_learner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/torch_learner.py +664 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/learner/utils.py +59 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/heads.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__init__.py +53 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/default_model_config.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/multi_rl_module.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/rl_module.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/ray/rllib/connectors/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (193 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/connector_pipeline_v2.cpython-311.pyc
ADDED
|
Binary file (20 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/connector.py
ADDED
|
@@ -0,0 +1,478 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This file defines base types and common structures for RLlib connectors.
|
| 2 |
+
"""
|
| 3 |
+
|
| 4 |
+
import abc
|
| 5 |
+
import logging
|
| 6 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import gymnasium as gym
|
| 9 |
+
|
| 10 |
+
from ray.rllib.policy.view_requirement import ViewRequirement
|
| 11 |
+
from ray.rllib.utils.typing import (
|
| 12 |
+
ActionConnectorDataType,
|
| 13 |
+
AgentConnectorDataType,
|
| 14 |
+
AlgorithmConfigDict,
|
| 15 |
+
TensorType,
|
| 16 |
+
)
|
| 17 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 18 |
+
|
| 19 |
+
if TYPE_CHECKING:
|
| 20 |
+
from ray.rllib.policy.policy import Policy
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@OldAPIStack
|
| 26 |
+
class ConnectorContext:
|
| 27 |
+
"""Data bits that may be needed for running connectors.
|
| 28 |
+
|
| 29 |
+
Note(jungong) : we need to be really careful with the data fields here.
|
| 30 |
+
E.g., everything needs to be serializable, in case we need to fetch them
|
| 31 |
+
in a remote setting.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
# TODO(jungong) : figure out how to fetch these in a remote setting.
|
| 35 |
+
# Probably from a policy server when initializing a policy client.
|
| 36 |
+
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
config: AlgorithmConfigDict = None,
|
| 40 |
+
initial_states: List[TensorType] = None,
|
| 41 |
+
observation_space: gym.Space = None,
|
| 42 |
+
action_space: gym.Space = None,
|
| 43 |
+
view_requirements: Dict[str, ViewRequirement] = None,
|
| 44 |
+
is_policy_recurrent: bool = False,
|
| 45 |
+
):
|
| 46 |
+
"""Construct a ConnectorContext instance.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
initial_states: States that are used for constructing
|
| 50 |
+
the initial input dict for RNN models. [] if a model is not recurrent.
|
| 51 |
+
action_space_struct: a policy's action space, in python
|
| 52 |
+
data format. E.g., python dict instead of DictSpace, python tuple
|
| 53 |
+
instead of TupleSpace.
|
| 54 |
+
"""
|
| 55 |
+
self.config = config or {}
|
| 56 |
+
self.initial_states = initial_states or []
|
| 57 |
+
self.observation_space = observation_space
|
| 58 |
+
self.action_space = action_space
|
| 59 |
+
self.view_requirements = view_requirements
|
| 60 |
+
self.is_policy_recurrent = is_policy_recurrent
|
| 61 |
+
|
| 62 |
+
@staticmethod
|
| 63 |
+
def from_policy(policy: "Policy") -> "ConnectorContext":
|
| 64 |
+
"""Build ConnectorContext from a given policy.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
policy: Policy
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
A ConnectorContext instance.
|
| 71 |
+
"""
|
| 72 |
+
return ConnectorContext(
|
| 73 |
+
config=policy.config,
|
| 74 |
+
initial_states=policy.get_initial_state(),
|
| 75 |
+
observation_space=policy.observation_space,
|
| 76 |
+
action_space=policy.action_space,
|
| 77 |
+
view_requirements=policy.view_requirements,
|
| 78 |
+
is_policy_recurrent=policy.is_recurrent(),
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@OldAPIStack
|
| 83 |
+
class Connector(abc.ABC):
|
| 84 |
+
"""Connector base class.
|
| 85 |
+
|
| 86 |
+
A connector is a step of transformation, of either envrionment data before they
|
| 87 |
+
get to a policy, or policy output before it is sent back to the environment.
|
| 88 |
+
|
| 89 |
+
Connectors may be training-aware, for example, behave slightly differently
|
| 90 |
+
during training and inference.
|
| 91 |
+
|
| 92 |
+
All connectors are required to be serializable and implement to_state().
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
def __init__(self, ctx: ConnectorContext):
|
| 96 |
+
# Default is training mode.
|
| 97 |
+
self._is_training = True
|
| 98 |
+
|
| 99 |
+
def in_training(self):
|
| 100 |
+
self._is_training = True
|
| 101 |
+
|
| 102 |
+
def in_eval(self):
|
| 103 |
+
self._is_training = False
|
| 104 |
+
|
| 105 |
+
def __str__(self, indentation: int = 0):
|
| 106 |
+
return " " * indentation + self.__class__.__name__
|
| 107 |
+
|
| 108 |
+
def to_state(self) -> Tuple[str, Any]:
|
| 109 |
+
"""Serialize a connector into a JSON serializable Tuple.
|
| 110 |
+
|
| 111 |
+
to_state is required, so that all Connectors are serializable.
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
A tuple of connector's name and its serialized states.
|
| 115 |
+
String should match the name used to register the connector,
|
| 116 |
+
while state can be any single data structure that contains the
|
| 117 |
+
serialized state of the connector. If a connector is stateless,
|
| 118 |
+
state can simply be None.
|
| 119 |
+
"""
|
| 120 |
+
# Must implement by each connector.
|
| 121 |
+
return NotImplementedError
|
| 122 |
+
|
| 123 |
+
@staticmethod
|
| 124 |
+
def from_state(self, ctx: ConnectorContext, params: Any) -> "Connector":
|
| 125 |
+
"""De-serialize a JSON params back into a Connector.
|
| 126 |
+
|
| 127 |
+
from_state is required, so that all Connectors are serializable.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
ctx: Context for constructing this connector.
|
| 131 |
+
params: Serialized states of the connector to be recovered.
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
De-serialized connector.
|
| 135 |
+
"""
|
| 136 |
+
# Must implement by each connector.
|
| 137 |
+
return NotImplementedError
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
@OldAPIStack
|
| 141 |
+
class AgentConnector(Connector):
|
| 142 |
+
"""Connector connecting user environments to RLlib policies.
|
| 143 |
+
|
| 144 |
+
An agent connector transforms a list of agent data in AgentConnectorDataType
|
| 145 |
+
format into a new list in the same AgentConnectorDataTypes format.
|
| 146 |
+
The input API is designed so agent connectors can have access to all the
|
| 147 |
+
agents assigned to a particular policy.
|
| 148 |
+
|
| 149 |
+
AgentConnectorDataTypes can be used to specify arbitrary type of env data,
|
| 150 |
+
|
| 151 |
+
Example:
|
| 152 |
+
|
| 153 |
+
Represent a list of agent data from one env step() call.
|
| 154 |
+
|
| 155 |
+
.. testcode::
|
| 156 |
+
|
| 157 |
+
import numpy as np
|
| 158 |
+
ac = AgentConnectorDataType(
|
| 159 |
+
env_id="env_1",
|
| 160 |
+
agent_id=None,
|
| 161 |
+
data={
|
| 162 |
+
"agent_1": np.array([1, 2, 3]),
|
| 163 |
+
"agent_2": np.array([4, 5, 6]),
|
| 164 |
+
}
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
Or a single agent data ready to be preprocessed.
|
| 168 |
+
|
| 169 |
+
.. testcode::
|
| 170 |
+
|
| 171 |
+
ac = AgentConnectorDataType(
|
| 172 |
+
env_id="env_1",
|
| 173 |
+
agent_id="agent_1",
|
| 174 |
+
data=np.array([1, 2, 3]),
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
We can also adapt a simple stateless function into an agent connector by
|
| 178 |
+
using register_lambda_agent_connector:
|
| 179 |
+
|
| 180 |
+
.. testcode::
|
| 181 |
+
|
| 182 |
+
import numpy as np
|
| 183 |
+
from ray.rllib.connectors.agent.lambdas import (
|
| 184 |
+
register_lambda_agent_connector
|
| 185 |
+
)
|
| 186 |
+
TimesTwoAgentConnector = register_lambda_agent_connector(
|
| 187 |
+
"TimesTwoAgentConnector", lambda data: data * 2
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
# More complicated agent connectors can be implemented by extending this
|
| 191 |
+
# AgentConnector class:
|
| 192 |
+
|
| 193 |
+
class FrameSkippingAgentConnector(AgentConnector):
|
| 194 |
+
def __init__(self, n):
|
| 195 |
+
self._n = n
|
| 196 |
+
self._frame_count = default_dict(str, default_dict(str, int))
|
| 197 |
+
|
| 198 |
+
def reset(self, env_id: str):
|
| 199 |
+
del self._frame_count[env_id]
|
| 200 |
+
|
| 201 |
+
def __call__(
|
| 202 |
+
self, ac_data: List[AgentConnectorDataType]
|
| 203 |
+
) -> List[AgentConnectorDataType]:
|
| 204 |
+
ret = []
|
| 205 |
+
for d in ac_data:
|
| 206 |
+
assert d.env_id and d.agent_id, "Skipping works per agent!"
|
| 207 |
+
|
| 208 |
+
count = self._frame_count[ac_data.env_id][ac_data.agent_id]
|
| 209 |
+
self._frame_count[ac_data.env_id][ac_data.agent_id] = (
|
| 210 |
+
count + 1
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
if count % self._n == 0:
|
| 214 |
+
ret.append(d)
|
| 215 |
+
return ret
|
| 216 |
+
|
| 217 |
+
As shown, an agent connector may choose to emit an empty list to stop input
|
| 218 |
+
observations from being further prosessed.
|
| 219 |
+
"""
|
| 220 |
+
|
| 221 |
+
def reset(self, env_id: str):
|
| 222 |
+
"""Reset connector state for a specific environment.
|
| 223 |
+
|
| 224 |
+
For example, at the end of an episode.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
env_id: required. ID of a user environment. Required.
|
| 228 |
+
"""
|
| 229 |
+
pass
|
| 230 |
+
|
| 231 |
+
def on_policy_output(self, output: ActionConnectorDataType):
|
| 232 |
+
"""Callback on agent connector of policy output.
|
| 233 |
+
|
| 234 |
+
This is useful for certain connectors, for example RNN state buffering,
|
| 235 |
+
where the agent connect needs to be aware of the output of a policy
|
| 236 |
+
forward pass.
|
| 237 |
+
|
| 238 |
+
Args:
|
| 239 |
+
ctx: Context for running this connector call.
|
| 240 |
+
output: Env and agent IDs, plus data output from policy forward pass.
|
| 241 |
+
"""
|
| 242 |
+
pass
|
| 243 |
+
|
| 244 |
+
def __call__(
|
| 245 |
+
self, acd_list: List[AgentConnectorDataType]
|
| 246 |
+
) -> List[AgentConnectorDataType]:
|
| 247 |
+
"""Transform a list of data items from env before they reach policy.
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
ac_data: List of env and agent IDs, plus arbitrary data items from
|
| 251 |
+
an environment or upstream agent connectors.
|
| 252 |
+
|
| 253 |
+
Returns:
|
| 254 |
+
A list of transformed data items in AgentConnectorDataType format.
|
| 255 |
+
The shape of a returned list does not have to match that of the input list.
|
| 256 |
+
An AgentConnector may choose to derive multiple outputs for a single piece
|
| 257 |
+
of input data, for example multi-agent obs -> multiple single agent obs.
|
| 258 |
+
Agent connectors may also choose to skip emitting certain inputs,
|
| 259 |
+
useful for connectors such as frame skipping.
|
| 260 |
+
"""
|
| 261 |
+
assert isinstance(
|
| 262 |
+
acd_list, (list, tuple)
|
| 263 |
+
), "Input to agent connectors are list of AgentConnectorDataType."
|
| 264 |
+
# Default implementation. Simply call transform on each agent connector data.
|
| 265 |
+
return [self.transform(d) for d in acd_list]
|
| 266 |
+
|
| 267 |
+
def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
|
| 268 |
+
"""Transform a single agent connector data item.
|
| 269 |
+
|
| 270 |
+
Args:
|
| 271 |
+
data: Env and agent IDs, plus arbitrary data item from a single agent
|
| 272 |
+
of an environment.
|
| 273 |
+
|
| 274 |
+
Returns:
|
| 275 |
+
A transformed piece of agent connector data.
|
| 276 |
+
"""
|
| 277 |
+
raise NotImplementedError
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
@OldAPIStack
|
| 281 |
+
class ActionConnector(Connector):
|
| 282 |
+
"""Action connector connects policy outputs including actions,
|
| 283 |
+
to user environments.
|
| 284 |
+
|
| 285 |
+
An action connector transforms a single piece of policy output in
|
| 286 |
+
ActionConnectorDataType format, which is basically PolicyOutputType plus env and
|
| 287 |
+
agent IDs.
|
| 288 |
+
|
| 289 |
+
Any functions that operate directly on PolicyOutputType can be easily adapted
|
| 290 |
+
into an ActionConnector by using register_lambda_action_connector.
|
| 291 |
+
|
| 292 |
+
Example:
|
| 293 |
+
|
| 294 |
+
.. testcode::
|
| 295 |
+
|
| 296 |
+
from ray.rllib.connectors.action.lambdas import (
|
| 297 |
+
register_lambda_action_connector
|
| 298 |
+
)
|
| 299 |
+
ZeroActionConnector = register_lambda_action_connector(
|
| 300 |
+
"ZeroActionsConnector",
|
| 301 |
+
lambda actions, states, fetches: (
|
| 302 |
+
np.zeros_like(actions), states, fetches
|
| 303 |
+
)
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
More complicated action connectors can also be implemented by sub-classing
|
| 307 |
+
this ActionConnector class.
|
| 308 |
+
"""
|
| 309 |
+
|
| 310 |
+
def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
|
| 311 |
+
"""Transform policy output before they are sent to a user environment.
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
ac_data: Env and agent IDs, plus policy output.
|
| 315 |
+
|
| 316 |
+
Returns:
|
| 317 |
+
The processed action connector data.
|
| 318 |
+
"""
|
| 319 |
+
return self.transform(ac_data)
|
| 320 |
+
|
| 321 |
+
def transform(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
|
| 322 |
+
"""Implementation of the actual transform.
|
| 323 |
+
|
| 324 |
+
Users should override transform instead of __call__ directly.
|
| 325 |
+
|
| 326 |
+
Args:
|
| 327 |
+
ac_data: Env and agent IDs, plus policy output.
|
| 328 |
+
|
| 329 |
+
Returns:
|
| 330 |
+
The processed action connector data.
|
| 331 |
+
"""
|
| 332 |
+
raise NotImplementedError
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
@OldAPIStack
|
| 336 |
+
class ConnectorPipeline(abc.ABC):
|
| 337 |
+
"""Utility class for quick manipulation of a connector pipeline."""
|
| 338 |
+
|
| 339 |
+
def __init__(self, ctx: ConnectorContext, connectors: List[Connector]):
|
| 340 |
+
self.connectors = connectors
|
| 341 |
+
|
| 342 |
+
def in_training(self):
|
| 343 |
+
for c in self.connectors:
|
| 344 |
+
c.in_training()
|
| 345 |
+
|
| 346 |
+
def in_eval(self):
|
| 347 |
+
for c in self.connectors:
|
| 348 |
+
c.in_eval()
|
| 349 |
+
|
| 350 |
+
def remove(self, name: str):
|
| 351 |
+
"""Remove a connector by <name>
|
| 352 |
+
|
| 353 |
+
Args:
|
| 354 |
+
name: name of the connector to be removed.
|
| 355 |
+
"""
|
| 356 |
+
idx = -1
|
| 357 |
+
for i, c in enumerate(self.connectors):
|
| 358 |
+
if c.__class__.__name__ == name:
|
| 359 |
+
idx = i
|
| 360 |
+
break
|
| 361 |
+
if idx >= 0:
|
| 362 |
+
del self.connectors[idx]
|
| 363 |
+
logger.info(f"Removed connector {name} from {self.__class__.__name__}.")
|
| 364 |
+
else:
|
| 365 |
+
logger.warning(f"Trying to remove a non-existent connector {name}.")
|
| 366 |
+
|
| 367 |
+
def insert_before(self, name: str, connector: Connector):
|
| 368 |
+
"""Insert a new connector before connector <name>
|
| 369 |
+
|
| 370 |
+
Args:
|
| 371 |
+
name: name of the connector before which a new connector
|
| 372 |
+
will get inserted.
|
| 373 |
+
connector: a new connector to be inserted.
|
| 374 |
+
"""
|
| 375 |
+
idx = -1
|
| 376 |
+
for idx, c in enumerate(self.connectors):
|
| 377 |
+
if c.__class__.__name__ == name:
|
| 378 |
+
break
|
| 379 |
+
if idx < 0:
|
| 380 |
+
raise ValueError(f"Can not find connector {name}")
|
| 381 |
+
self.connectors.insert(idx, connector)
|
| 382 |
+
|
| 383 |
+
logger.info(
|
| 384 |
+
f"Inserted {connector.__class__.__name__} before {name} "
|
| 385 |
+
f"to {self.__class__.__name__}."
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
def insert_after(self, name: str, connector: Connector):
|
| 389 |
+
"""Insert a new connector after connector <name>
|
| 390 |
+
|
| 391 |
+
Args:
|
| 392 |
+
name: name of the connector after which a new connector
|
| 393 |
+
will get inserted.
|
| 394 |
+
connector: a new connector to be inserted.
|
| 395 |
+
"""
|
| 396 |
+
idx = -1
|
| 397 |
+
for idx, c in enumerate(self.connectors):
|
| 398 |
+
if c.__class__.__name__ == name:
|
| 399 |
+
break
|
| 400 |
+
if idx < 0:
|
| 401 |
+
raise ValueError(f"Can not find connector {name}")
|
| 402 |
+
self.connectors.insert(idx + 1, connector)
|
| 403 |
+
|
| 404 |
+
logger.info(
|
| 405 |
+
f"Inserted {connector.__class__.__name__} after {name} "
|
| 406 |
+
f"to {self.__class__.__name__}."
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
def prepend(self, connector: Connector):
|
| 410 |
+
"""Append a new connector at the beginning of a connector pipeline.
|
| 411 |
+
|
| 412 |
+
Args:
|
| 413 |
+
connector: a new connector to be appended.
|
| 414 |
+
"""
|
| 415 |
+
self.connectors.insert(0, connector)
|
| 416 |
+
|
| 417 |
+
logger.info(
|
| 418 |
+
f"Added {connector.__class__.__name__} to the beginning of "
|
| 419 |
+
f"{self.__class__.__name__}."
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
def append(self, connector: Connector):
|
| 423 |
+
"""Append a new connector at the end of a connector pipeline.
|
| 424 |
+
|
| 425 |
+
Args:
|
| 426 |
+
connector: a new connector to be appended.
|
| 427 |
+
"""
|
| 428 |
+
self.connectors.append(connector)
|
| 429 |
+
|
| 430 |
+
logger.info(
|
| 431 |
+
f"Added {connector.__class__.__name__} to the end of "
|
| 432 |
+
f"{self.__class__.__name__}."
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
def __str__(self, indentation: int = 0):
|
| 436 |
+
return "\n".join(
|
| 437 |
+
[" " * indentation + self.__class__.__name__]
|
| 438 |
+
+ [c.__str__(indentation + 4) for c in self.connectors]
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
def __getitem__(self, key: Union[str, int, type]):
|
| 442 |
+
"""Returns a list of connectors that fit 'key'.
|
| 443 |
+
|
| 444 |
+
If key is a number n, we return a list with the nth element of this pipeline.
|
| 445 |
+
If key is a Connector class or a string matching the class name of a
|
| 446 |
+
Connector class, we return a list of all connectors in this pipeline matching
|
| 447 |
+
the specified class.
|
| 448 |
+
|
| 449 |
+
Args:
|
| 450 |
+
key: The key to index by
|
| 451 |
+
|
| 452 |
+
Returns: The Connector at index `key`.
|
| 453 |
+
"""
|
| 454 |
+
# In case key is a class
|
| 455 |
+
if not isinstance(key, str):
|
| 456 |
+
if isinstance(key, slice):
|
| 457 |
+
raise NotImplementedError(
|
| 458 |
+
"Slicing of ConnectorPipeline is currently not supported."
|
| 459 |
+
)
|
| 460 |
+
elif isinstance(key, int):
|
| 461 |
+
return [self.connectors[key]]
|
| 462 |
+
elif isinstance(key, type):
|
| 463 |
+
results = []
|
| 464 |
+
for c in self.connectors:
|
| 465 |
+
if issubclass(c.__class__, key):
|
| 466 |
+
results.append(c)
|
| 467 |
+
return results
|
| 468 |
+
else:
|
| 469 |
+
raise NotImplementedError(
|
| 470 |
+
"Indexing by {} is currently not supported.".format(type(key))
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
results = []
|
| 474 |
+
for c in self.connectors:
|
| 475 |
+
if c.__class__.__name__ == key:
|
| 476 |
+
results.append(c)
|
| 477 |
+
|
| 478 |
+
return results
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/connector_pipeline_v2.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Any, Collection, Dict, List, Optional, Tuple, Type, Union
|
| 3 |
+
|
| 4 |
+
import gymnasium as gym
|
| 5 |
+
|
| 6 |
+
from ray.rllib.connectors.connector_v2 import ConnectorV2
|
| 7 |
+
from ray.rllib.core.rl_module.rl_module import RLModule
|
| 8 |
+
from ray.rllib.utils.annotations import override
|
| 9 |
+
from ray.rllib.utils.checkpoints import Checkpointable
|
| 10 |
+
from ray.rllib.utils.metrics import TIMERS, CONNECTOR_TIMERS
|
| 11 |
+
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
|
| 12 |
+
from ray.rllib.utils.typing import EpisodeType, StateDict
|
| 13 |
+
from ray.util.annotations import PublicAPI
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@PublicAPI(stability="alpha")
|
| 19 |
+
class ConnectorPipelineV2(ConnectorV2):
|
| 20 |
+
"""Utility class for quick manipulation of a connector pipeline."""
|
| 21 |
+
|
| 22 |
+
@override(ConnectorV2)
|
| 23 |
+
def recompute_output_observation_space(
|
| 24 |
+
self,
|
| 25 |
+
input_observation_space: gym.Space,
|
| 26 |
+
input_action_space: gym.Space,
|
| 27 |
+
) -> gym.Space:
|
| 28 |
+
self._fix_spaces(input_observation_space, input_action_space)
|
| 29 |
+
return self.observation_space
|
| 30 |
+
|
| 31 |
+
@override(ConnectorV2)
|
| 32 |
+
def recompute_output_action_space(
|
| 33 |
+
self,
|
| 34 |
+
input_observation_space: gym.Space,
|
| 35 |
+
input_action_space: gym.Space,
|
| 36 |
+
) -> gym.Space:
|
| 37 |
+
self._fix_spaces(input_observation_space, input_action_space)
|
| 38 |
+
return self.action_space
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
input_observation_space: Optional[gym.Space] = None,
|
| 43 |
+
input_action_space: Optional[gym.Space] = None,
|
| 44 |
+
*,
|
| 45 |
+
connectors: Optional[List[ConnectorV2]] = None,
|
| 46 |
+
**kwargs,
|
| 47 |
+
):
|
| 48 |
+
"""Initializes a ConnectorPipelineV2 instance.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
input_observation_space: The (optional) input observation space for this
|
| 52 |
+
connector piece. This is the space coming from a previous connector
|
| 53 |
+
piece in the (env-to-module or learner) pipeline or is directly
|
| 54 |
+
defined within the gym.Env.
|
| 55 |
+
input_action_space: The (optional) input action space for this connector
|
| 56 |
+
piece. This is the space coming from a previous connector piece in the
|
| 57 |
+
(module-to-env) pipeline or is directly defined within the gym.Env.
|
| 58 |
+
connectors: A list of individual ConnectorV2 pieces to be added to this
|
| 59 |
+
pipeline during construction. Note that you can always add (or remove)
|
| 60 |
+
more ConnectorV2 pieces later on the fly.
|
| 61 |
+
"""
|
| 62 |
+
self.connectors = []
|
| 63 |
+
|
| 64 |
+
for conn in connectors:
|
| 65 |
+
# If we have a `ConnectorV2` instance just append.
|
| 66 |
+
if isinstance(conn, ConnectorV2):
|
| 67 |
+
self.connectors.append(conn)
|
| 68 |
+
# If, we have a class with `args` and `kwargs`, build the instance.
|
| 69 |
+
# Note that this way of constructing a pipeline should only be
|
| 70 |
+
# used internally when restoring the pipeline state from a
|
| 71 |
+
# checkpoint.
|
| 72 |
+
elif isinstance(conn, tuple) and len(conn) == 3:
|
| 73 |
+
self.connectors.append(conn[0](*conn[1], **conn[2]))
|
| 74 |
+
|
| 75 |
+
super().__init__(input_observation_space, input_action_space, **kwargs)
|
| 76 |
+
|
| 77 |
+
def __len__(self):
    """Returns the number of ConnectorV2 pieces in this pipeline."""
    num_pieces = len(self.connectors)
    return num_pieces
|
| 79 |
+
|
| 80 |
+
@override(ConnectorV2)
def __call__(
    self,
    *,
    rl_module: RLModule,
    batch: Dict[str, Any],
    episodes: List[EpisodeType],
    explore: Optional[bool] = None,
    shared_data: Optional[dict] = None,
    metrics: Optional[MetricsLogger] = None,
    **kwargs,
) -> Any:
    """In a pipeline, we simply call each of our connector pieces after each other.

    Each connector piece receives as input the output of the previous connector
    piece in the pipeline.

    Args:
        rl_module: The RLModule that this pipeline connects to or from.
        batch: The input batch to be transformed by the pipeline.
        episodes: The list of episodes to read from and/or write to.
        explore: Whether exploration is currently switched on.
        shared_data: Optional context dict exchanged between the pieces.
        metrics: Optional MetricsLogger used to time each piece's call.

    Returns:
        The output of the last connector piece in the pipeline (a dict).

    Raises:
        ValueError: If any connector piece returns a non-dict object.
    """
    shared_data = shared_data if shared_data is not None else {}
    # Loop through connector pieces and call each one with the output of the
    # previous one. Thereby, time each connector piece's call.
    for connector in self.connectors:
        # TODO (sven): Add MetricsLogger to non-Learner components that have a
        #  LearnerConnector pipeline.
        stats = None
        if metrics:
            stats = metrics.log_time(
                kwargs.get("metrics_prefix_key", ())
                + (TIMERS, CONNECTOR_TIMERS, connector.__class__.__name__)
            )
            stats.__enter__()

        # Fix: close the timer in a `finally` block so that a connector piece
        # raising an exception does not leave the timer entered (un-exited).
        try:
            batch = connector(
                rl_module=rl_module,
                batch=batch,
                episodes=episodes,
                explore=explore,
                shared_data=shared_data,
                metrics=metrics,
                # Deprecated arg.
                data=batch,
                **kwargs,
            )
        finally:
            if stats is not None:
                stats.__exit__(None, None, None)

        if not isinstance(batch, dict):
            raise ValueError(
                f"`data` returned by ConnectorV2 {connector} must be a dict! "
                f"You returned {batch}. Check your (custom) connectors' "
                f"`__call__()` method's return value and make sure you return "
                f"the `data` arg passed in (either altered or unchanged)."
            )

    return batch
|
| 135 |
+
|
| 136 |
+
def remove(self, name_or_class: Union[str, Type]):
    """Remove a single connector piece in this pipeline by its name or class.

    Args:
        name_or_class: The class name (str) or class (type) of the connector
            piece to be removed from the pipeline.
    """
    idx = -1
    for i, c in enumerate(self.connectors):
        # Fix: also match by class, as advertised by the docstring and the
        # argument name, and consistent with `insert_before`/`insert_after`.
        if (
            isinstance(name_or_class, str)
            and c.__class__.__name__ == name_or_class
        ) or (isinstance(name_or_class, type) and c.__class__ is name_or_class):
            idx = i
            break
    if idx >= 0:
        del self.connectors[idx]
        # Re-chain the pieces' input/output spaces after the removal.
        self._fix_spaces(self.input_observation_space, self.input_action_space)
        logger.info(
            f"Removed connector {name_or_class} from {self.__class__.__name__}."
        )
    else:
        logger.warning(
            f"Trying to remove a non-existent connector {name_or_class}."
        )
|
| 157 |
+
|
| 158 |
+
def insert_before(
    self,
    name_or_class: Union[str, type],
    connector: ConnectorV2,
) -> ConnectorV2:
    """Insert a new connector piece before an existing piece (by name or class).

    Args:
        name_or_class: Name or class of the connector piece before which `connector`
            will get inserted.
        connector: The new connector piece to be inserted.

    Returns:
        The ConnectorV2 before which `connector` has been inserted.

    Raises:
        ValueError: If no piece matching `name_or_class` is found.
    """
    for idx, c in enumerate(self.connectors):
        if (
            isinstance(name_or_class, str) and c.__class__.__name__ == name_or_class
        ) or (isinstance(name_or_class, type) and c.__class__ is name_or_class):
            break
    else:
        # Fix: the original `idx = -1` sentinel was clobbered by the loop
        # variable, so a completed (match-less) loop left `idx == len - 1` and
        # the `if idx < 0` check never fired. `for`-`else` detects the missing
        # match reliably (and also covers the empty-pipeline case).
        raise ValueError(
            f"Can not find connector with name or type '{name_or_class}'!"
        )
    next_connector = self.connectors[idx]

    self.connectors.insert(idx, connector)
    # Re-chain the pieces' input/output spaces after the insertion.
    self._fix_spaces(self.input_observation_space, self.input_action_space)

    logger.info(
        f"Inserted {connector.__class__.__name__} before {name_or_class} "
        f"to {self.__class__.__name__}."
    )
    return next_connector
|
| 193 |
+
|
| 194 |
+
def insert_after(
    self,
    name_or_class: Union[str, Type],
    connector: ConnectorV2,
) -> ConnectorV2:
    """Insert a new connector piece after an existing piece (by name or class).

    Args:
        name_or_class: Name or class of the connector piece after which `connector`
            will get inserted.
        connector: The new connector piece to be inserted.

    Returns:
        The ConnectorV2 after which `connector` has been inserted.

    Raises:
        ValueError: If no piece matching `name_or_class` is found.
    """
    for idx, c in enumerate(self.connectors):
        if (
            isinstance(name_or_class, str) and c.__class__.__name__ == name_or_class
        ) or (isinstance(name_or_class, type) and c.__class__ is name_or_class):
            break
    else:
        # Fix: the original `idx = -1` sentinel was clobbered by the loop
        # variable, so a completed (match-less) loop left `idx == len - 1` and
        # the `if idx < 0` check never fired. `for`-`else` detects the missing
        # match reliably (and also covers the empty-pipeline case).
        raise ValueError(
            f"Can not find connector with name or type '{name_or_class}'!"
        )
    prev_connector = self.connectors[idx]

    self.connectors.insert(idx + 1, connector)
    # Re-chain the pieces' input/output spaces after the insertion.
    self._fix_spaces(self.input_observation_space, self.input_action_space)

    logger.info(
        f"Inserted {connector.__class__.__name__} after {name_or_class} "
        f"to {self.__class__.__name__}."
    )

    return prev_connector
|
| 230 |
+
|
| 231 |
+
def prepend(self, connector: ConnectorV2) -> None:
    """Prepend a new connector at the beginning of a connector pipeline.

    Args:
        connector: The new connector piece to be prepended to this pipeline.
    """
    # Insert at position 0, then re-chain all pieces' input/output spaces.
    self.connectors[:0] = [connector]
    self._fix_spaces(self.input_observation_space, self.input_action_space)

    logger.info(
        f"Added {connector.__class__.__name__} to the beginning of "
        f"{self.__class__.__name__}."
    )
|
| 244 |
+
|
| 245 |
+
def append(self, connector: ConnectorV2) -> None:
    """Append a new connector at the end of a connector pipeline.

    Args:
        connector: The new connector piece to be appended to this pipeline.
    """
    # Add as last piece, then re-chain all pieces' input/output spaces.
    self.connectors += [connector]
    self._fix_spaces(self.input_observation_space, self.input_action_space)

    logger.info(
        f"Added {connector.__class__.__name__} to the end of "
        f"{self.__class__.__name__}."
    )
|
| 258 |
+
|
| 259 |
+
@override(ConnectorV2)
def get_state(
    self,
    components: Optional[Union[str, Collection[str]]] = None,
    *,
    not_components: Optional[Union[str, Collection[str]]] = None,
    **kwargs,
) -> StateDict:
    """Collects the states of all pieces in this pipeline, keyed by class name."""
    pipeline_state = {}
    for piece in self.connectors:
        piece_name = type(piece).__name__
        # Skip pieces filtered out by the components/not_components selectors.
        if not self._check_component(piece_name, components, not_components):
            continue
        pipeline_state[piece_name] = piece.get_state(
            components=self._get_subcomponents(piece_name, components),
            not_components=self._get_subcomponents(piece_name, not_components),
            **kwargs,
        )
    return pipeline_state
|
| 277 |
+
|
| 278 |
+
@override(ConnectorV2)
def set_state(self, state: Dict[str, Any]) -> None:
    """Restores each piece's state from `state` (keyed by the piece's class name)."""
    for piece in self.connectors:
        key = type(piece).__name__
        # Pieces without an entry in `state` are left untouched.
        if key not in state:
            continue
        piece.set_state(state[key])
|
| 284 |
+
|
| 285 |
+
@override(Checkpointable)
def get_checkpointable_components(self) -> List[Tuple[str, "Checkpointable"]]:
    """Returns one (class-name, piece) tuple per connector piece in this pipeline."""
    components = []
    for piece in self.connectors:
        components.append((type(piece).__name__, piece))
    return components
|
| 288 |
+
|
| 289 |
+
# Note: we don't have to return the `connectors` c'tor kwarg from
# `Checkpointable.get_ctor_args_and_kwargs`. All connector pieces in this
# pipeline are themselves Checkpointable components and thus get properly
# written into this pipeline's checkpoint on their own.
@override(Checkpointable)
def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
    """Returns the (args, kwargs) needed to re-construct this pipeline."""
    # Positional args: the pipeline's own input spaces.
    ctor_args = (self.input_observation_space, self.input_action_space)
    # One (class, args, kwargs) triple per piece, in pipeline order.
    piece_specs = []
    for piece in self.connectors:
        args_and_kwargs = piece.get_ctor_args_and_kwargs()
        piece_specs.append((type(piece),) + tuple(args_and_kwargs))
    return ctor_args, {"connectors": piece_specs}
|
| 304 |
+
|
| 305 |
+
@override(ConnectorV2)
def reset_state(self) -> None:
    """Resets the state of every connector piece in this pipeline."""
    for piece in self.connectors:
        piece.reset_state()
|
| 309 |
+
|
| 310 |
+
@override(ConnectorV2)
def merge_states(self, states: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merges the given list of per-worker pipeline states into one state dict.

    Args:
        states: A list of pipeline state dicts (e.g. one per remote worker),
            each mapping a connector piece's class name to that piece's state.

    Returns:
        A single state dict of the same shape, in which each piece's states
        have been merged via that piece's own `merge_states()`.
    """
    merged_states = {}
    # Nothing to merge if no states were provided.
    if not states:
        return merged_states
    # NOTE(review): this pairs the i-th key of `states[0]` with
    # `self.connectors[i]`, i.e. it assumes the state dicts contain one entry
    # per connector piece, in pipeline order and with unique class names, and
    # that every state dict in `states` has the same keys — TODO confirm these
    # invariants hold for all callers.
    for i, (key, item) in enumerate(states[0].items()):
        # Collect this piece's state from every worker's state dict.
        state_list = [state[key] for state in states]
        conn = self.connectors[i]
        merged_states[key] = conn.merge_states(state_list)
    return merged_states
|
| 320 |
+
|
| 321 |
+
def __repr__(self, indentation: int = 0):
    """Returns the pipeline's class name plus one further-indented line per piece."""
    lines = [" " * indentation + self.__class__.__name__]
    for piece in self.connectors:
        lines.append(piece.__str__(indentation + 4))
    return "\n".join(lines)
|
| 326 |
+
|
| 327 |
+
def __getitem__(
    self,
    key: Union[str, int, Type],
) -> Union[ConnectorV2, List[ConnectorV2]]:
    """Returns a single ConnectorV2 or list of ConnectorV2s that fit `key`.

    If key is an int, a single ConnectorV2 at that index in this pipeline is
    returned. If key is a ConnectorV2 type or a string matching the class name
    of a ConnectorV2 in this pipeline, a list of all ConnectorV2s in this
    pipeline matching the specified class is returned.

    Args:
        key: The key to find or to index by.

    Returns:
        A single ConnectorV2 or a list of ConnectorV2s matching `key`.
    """
    # Plain integer index -> single piece.
    if isinstance(key, int):
        return self.connectors[key]
    # Slicing not supported (yet); reject before the generic fallback.
    if isinstance(key, slice):
        raise NotImplementedError(
            "Slicing of ConnectorPipelineV2 is currently not supported!"
        )
    # Class -> all pieces that are instances of (subclasses of) that class.
    if isinstance(key, type):
        return [c for c in self.connectors if issubclass(c.__class__, key)]
    # String -> all pieces whose name matches.
    if isinstance(key, str):
        return [c for c in self.connectors if c.name == key]
    raise NotImplementedError(
        f"Indexing ConnectorPipelineV2 by {type(key)} is currently not "
        f"supported!"
    )
|
| 371 |
+
|
| 372 |
+
@property
def observation_space(self):
    """The pipeline's output observation space (last piece's, if any pieces exist)."""
    if not self.connectors:
        # Empty pipeline -> input passes through unchanged.
        return self._observation_space
    return self.connectors[-1].observation_space
|
| 377 |
+
|
| 378 |
+
@property
def action_space(self):
    """The pipeline's output action space (last piece's, if any pieces exist)."""
    if not self.connectors:
        # Empty pipeline -> input passes through unchanged.
        return self._action_space
    return self.connectors[-1].action_space
|
| 383 |
+
|
| 384 |
+
def _fix_spaces(self, input_observation_space, input_action_space):
    """Re-chains every piece's input spaces through the pipeline.

    Starting from the given pipeline input spaces, each piece is assigned its
    input spaces and that piece's (possibly transformed) output spaces become
    the next piece's inputs.
    """
    if not len(self):
        return
    obs_space, act_space = input_observation_space, input_action_space
    for piece in self.connectors:
        # Keep assignment order (action before observation): the observation-
        # space setter may depend on the action space already being set.
        piece.input_action_space = act_space
        piece.input_observation_space = obs_space
        obs_space = piece.observation_space
        act_space = piece.action_space
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/connector_v2.py
ADDED
|
@@ -0,0 +1,1017 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
import inspect
|
| 4 |
+
from typing import (
|
| 5 |
+
Any,
|
| 6 |
+
Callable,
|
| 7 |
+
Collection,
|
| 8 |
+
Dict,
|
| 9 |
+
Iterator,
|
| 10 |
+
List,
|
| 11 |
+
Optional,
|
| 12 |
+
Tuple,
|
| 13 |
+
Union,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
import gymnasium as gym
|
| 17 |
+
import tree
|
| 18 |
+
|
| 19 |
+
from ray.rllib.core.rl_module.rl_module import RLModule
|
| 20 |
+
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
|
| 21 |
+
from ray.rllib.utils import force_list
|
| 22 |
+
from ray.rllib.utils.annotations import override, OverrideToImplementCustomLogic
|
| 23 |
+
from ray.rllib.utils.checkpoints import Checkpointable
|
| 24 |
+
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
|
| 25 |
+
from ray.rllib.utils.spaces.space_utils import BatchedNdArray
|
| 26 |
+
from ray.rllib.utils.typing import AgentID, EpisodeType, ModuleID, StateDict
|
| 27 |
+
from ray.util.annotations import PublicAPI
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@PublicAPI(stability="alpha")
|
| 31 |
+
class ConnectorV2(Checkpointable, abc.ABC):
|
| 32 |
+
"""Base class defining the API for an individual "connector piece".
|
| 33 |
+
|
| 34 |
+
A ConnectorV2 ("connector piece") is usually part of a whole series of connector
|
| 35 |
+
pieces within a so-called connector pipeline, which in itself also abides to this
|
| 36 |
+
very API.
|
| 37 |
+
For example, you might have a connector pipeline consisting of two connector pieces,
|
| 38 |
+
A and B, both instances of subclasses of ConnectorV2 and each one performing a
|
| 39 |
+
particular transformation on their input data. The resulting connector pipeline
|
| 40 |
+
(A->B) itself also abides to this very ConnectorV2 API and could thus be part of yet
|
| 41 |
+
another, higher-level connector pipeline, e.g. (A->B)->C->D.
|
| 42 |
+
|
| 43 |
+
Any ConnectorV2 instance (individual pieces or several connector pieces in a
|
| 44 |
+
pipeline) is a callable and users should override the `__call__()` method.
|
| 45 |
+
When called, they take the outputs of a previous connector piece (or an empty dict
|
| 46 |
+
if there are no previous pieces) and all the data collected thus far in the
|
| 47 |
+
ongoing episode(s) (only applies to connectors used in EnvRunners) or retrieved
|
| 48 |
+
from a replay buffer or from an environment sampling step (only applies to
|
| 49 |
+
connectors used in Learner pipelines). From this input data, a ConnectorV2 then
|
| 50 |
+
performs a transformation step.
|
| 51 |
+
|
| 52 |
+
There are 3 types of pipelines any ConnectorV2 piece can belong to:
|
| 53 |
+
1) EnvToModulePipeline: The connector transforms environment data before it gets to
|
| 54 |
+
the RLModule. This type of pipeline is used by an EnvRunner for transforming
|
| 55 |
+
env output data into RLModule readable data (for the next RLModule forward pass).
|
| 56 |
+
For example, such a pipeline would include observation postprocessors, -filters,
|
| 57 |
+
or any RNN preparation code related to time-sequences and zero-padding.
|
| 58 |
+
2) ModuleToEnvPipeline: This type of pipeline is used by an
|
| 59 |
+
EnvRunner to transform RLModule output data to env readable actions (for the next
|
| 60 |
+
`env.step()` call). For example, in case the RLModule only outputs action
|
| 61 |
+
distribution parameters (but not actual actions), the ModuleToEnvPipeline would
|
| 62 |
+
take care of sampling the actions to be sent back to the end from the
|
| 63 |
+
resulting distribution (made deterministic if exploration is off).
|
| 64 |
+
3) LearnerConnectorPipeline: This connector pipeline type transforms data coming
|
| 65 |
+
from an `EnvRunner.sample()` call or a replay buffer and will then be sent into the
|
| 66 |
+
RLModule's `forward_train()` method in order to compute loss function inputs.
|
| 67 |
+
This type of pipeline is used by a Learner worker to transform raw training data
|
| 68 |
+
(a batch or a list of episodes) to RLModule readable training data (for the next
|
| 69 |
+
RLModule `forward_train()` call).
|
| 70 |
+
|
| 71 |
+
Some connectors might be stateful, for example for keeping track of observation
|
| 72 |
+
filtering stats (mean and stddev values). Any Algorithm, which uses connectors is
|
| 73 |
+
responsible for frequently synchronizing the states of all connectors and connector
|
| 74 |
+
pipelines between the EnvRunners (owning the env-to-module and module-to-env
|
| 75 |
+
pipelines) and the Learners (owning the Learner pipelines).
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def __init__(
    self,
    input_observation_space: Optional[gym.Space] = None,
    input_action_space: Optional[gym.Space] = None,
    **kwargs,
):
    """Initializes a ConnectorV2 instance.

    Args:
        input_observation_space: The (optional) input observation space for this
            connector piece. This is the space coming from a previous connector
            piece in the (env-to-module or learner) pipeline or is directly
            defined within the gym.Env.
        input_action_space: The (optional) input action space for this connector
            piece. This is the space coming from a previous connector piece in the
            (module-to-env) pipeline or is directly defined within the gym.Env.
        **kwargs: Forward API-compatibility kwargs.
    """
    # Backing fields for the observation/action space properties and the
    # input-space setters used below.
    self._observation_space = None
    self._action_space = None
    self._input_observation_space = None
    self._input_action_space = None

    # NOTE(review): the action space is assigned before the observation space
    # here — presumably so the observation-space setter can already rely on the
    # action space being set; confirm before reordering.
    self.input_action_space = input_action_space
    self.input_observation_space = input_observation_space

    # Store child's constructor args and kwargs for the default
    # `get_ctor_args_and_kwargs` implementation (to be able to restore from a
    # checkpoint).
    if self.__class__.__dict__.get("__init__") is not None:
        # The subclass defines its own `__init__` -> capture that caller's
        # argument values off the call stack. This assumes `super().__init__()`
        # is invoked directly from the subclass's `__init__` (stack frame [1])
        # — TODO confirm this holds for indirect call chains.
        caller_frame = inspect.stack()[1].frame
        arg_info = inspect.getargvalues(caller_frame)
        # Separate positional arguments and keyword arguments.
        caller_locals = (
            arg_info.locals
        )  # Dictionary of all local variables in the caller
        self._ctor_kwargs = {
            arg: caller_locals[arg] for arg in arg_info.args if arg != "self"
        }
    else:
        # No custom subclass `__init__` -> only the two space kwargs are
        # needed to reconstruct this piece.
        self._ctor_kwargs = {
            "input_observation_space": self.input_observation_space,
            "input_action_space": self.input_action_space,
        }
|
| 122 |
+
|
| 123 |
+
@OverrideToImplementCustomLogic
def recompute_output_observation_space(
    self,
    input_observation_space: gym.Space,
    input_action_space: gym.Space,
) -> gym.Space:
    """Re-computes a new (output) observation space based on the input spaces.

    This method should be overridden by users to make sure a ConnectorPipelineV2
    knows how the input spaces through its individual ConnectorV2 pieces are being
    transformed.

    .. testcode::

        from gymnasium.spaces import Box, Discrete
        import numpy as np

        from ray.rllib.connectors.connector_v2 import ConnectorV2
        from ray.rllib.utils.numpy import one_hot
        from ray.rllib.utils.test_utils import check

        class OneHotConnector(ConnectorV2):
            def recompute_output_observation_space(
                self,
                input_observation_space,
                input_action_space,
            ):
                return Box(0.0, 1.0, (input_observation_space.n,), np.float32)

            def __call__(
                self,
                *,
                rl_module,
                batch,
                episodes,
                explore=None,
                shared_data=None,
                metrics=None,
                **kwargs,
            ):
                assert "obs" in batch
                batch["obs"] = one_hot(batch["obs"])
                return batch

        connector = OneHotConnector(input_observation_space=Discrete(2))
        batch = {"obs": np.array([1, 0, 0], np.int32)}
        output = connector(rl_module=None, batch=batch, episodes=None)

        check(output, {"obs": np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 0.0]])})

    If this ConnectorV2 does not change the observation space in any way, leave
    this parent method implementation untouched.

    Args:
        input_observation_space: The input observation space (either coming from the
            environment if `self` is the first connector piece in the pipeline or
            from the previous connector piece in the pipeline).
        input_action_space: The input action space (either coming from the
            environment if `self` is the first connector piece in the pipeline or
            from the previous connector piece in the pipeline).

    Returns:
        The new observation space (after data has passed through this ConnectorV2
        piece).
    """
    # Default: the observation space passes through this piece unchanged.
    return self.input_observation_space
|
| 189 |
+
|
| 190 |
+
@OverrideToImplementCustomLogic
def recompute_output_action_space(
    self,
    input_observation_space: gym.Space,
    input_action_space: gym.Space,
) -> gym.Space:
    """Re-computes a new (output) action space based on the input space.

    This method should be overridden by users to make sure a ConnectorPipelineV2
    knows how the input spaces through its individual ConnectorV2 pieces are being
    transformed.

    If this ConnectorV2 does not change the action space in any way, leave
    this parent method implementation untouched.

    Args:
        input_observation_space: The input observation space (either coming from the
            environment if `self` is the first connector piece in the pipeline or
            from the previous connector piece in the pipeline).
        input_action_space: The input action space (either coming from the
            environment if `self` is the first connector piece in the pipeline or
            from the previous connector piece in the pipeline).

    Returns:
        The new action space (after data has passed through this ConnectorV2
        piece).
    """
    # Default: the action space passes through this piece unchanged.
    return self.input_action_space
|
| 218 |
+
|
| 219 |
+
@abc.abstractmethod
def __call__(
    self,
    *,
    rl_module: RLModule,
    batch: Dict[str, Any],
    episodes: List[EpisodeType],
    explore: Optional[bool] = None,
    shared_data: Optional[dict] = None,
    metrics: Optional[MetricsLogger] = None,
    **kwargs,
) -> Any:
    """Method for transforming an input `batch` into an output `batch`.

    Subclasses must implement this; the returned (transformed) batch is fed to
    the next connector piece in the pipeline.

    Args:
        rl_module: The RLModule object that the connector connects to or from.
        batch: The input data to be transformed by this connector. Transformations
            might either be done in-place or a new structure may be returned.
            Note that the information in `batch` will eventually either become the
            forward batch for the RLModule (env-to-module and learner connectors)
            or the input to the `env.step()` call (module-to-env connectors). Note
            that in the first case (`batch` is a forward batch for RLModule), the
            information in `batch` will be discarded after that RLModule forward
            pass. Any transformation of information (e.g. observation preprocessing)
            that you have only done inside `batch` will be lost, unless you have
            written it back into the corresponding `episodes` during the connector
            pass.
        episodes: The list of SingleAgentEpisode or MultiAgentEpisode objects,
            each corresponding to one slot in the vector env. Note that episodes
            can be read from (e.g. to place information into `batch`), but also
            written to. You should only write back (changed, transformed)
            information into the episodes, if you want these changes to be
            "permanent". For example if you sample from an environment, pick up
            observations from the episodes and place them into `batch`, then
            transform these observations, and would like to make these
            transformations permanent (note that `batch` gets discarded after the
            RLModule forward pass), then you have to write the transformed
            observations back into the episode to make sure you do not have to
            perform the same transformation again on the learner (or replay buffer)
            side. The Learner will hence work on the already changed episodes (and
            compile the train batch using the Learner connector).
        explore: Whether `explore` is currently on. Per convention, if True, the
            RLModule's `forward_exploration` method should be called, if False, the
            EnvRunner should call `forward_inference` instead.
        shared_data: Optional additional context data that needs to be exchanged
            between different ConnectorV2 pieces (in the same pipeline) or across
            ConnectorV2 pipelines (meaning between env-to-module and module-to-env).
        metrics: Optional MetricsLogger instance to log custom metrics to.
        kwargs: Forward API-compatibility kwargs.

    Returns:
        The transformed connector output.
    """
|
| 272 |
+
|
| 273 |
+
@staticmethod
def single_agent_episode_iterator(
    episodes: List[EpisodeType],
    agents_that_stepped_only: bool = True,
    zip_with_batch_column: Optional[Union[List[Any], Dict[Tuple, Any]]] = None,
) -> Iterator[SingleAgentEpisode]:
    """An iterator over a list of episodes yielding always SingleAgentEpisodes.

    In case items in the list are MultiAgentEpisodes, these are broken down
    into their individual agents' SingleAgentEpisodes and those are then yielded
    one after the other.

    Useful for connectors that operate on both single-agent and multi-agent
    episodes.

    Args:
        episodes: The list of SingleAgent- or MultiAgentEpisode objects.
        agents_that_stepped_only: If True (and multi-agent setup), will only place
            items of those agents into the batch that have just stepped in the
            actual MultiAgentEpisode (this is checked via a
            `MultiAgentEpisode.get_agents_that_stepped()` call). Note that this
            setting is ignored in single-agent setups b/c the agent steps at each
            timestep regardless.
        zip_with_batch_column: If provided, must be a list of batch items
            corresponding to the given `episodes` (single agent case) or a dict
            mapping (AgentID, ModuleID) tuples to lists of individual batch items
            corresponding to this agent/module combination. The iterator will then
            yield tuples of SingleAgentEpisode objects (1st item) along with the
            data item (2nd item) that this episode was responsible for generating
            originally.

    Yields:
        All SingleAgentEpisodes in the input list, whereby MultiAgentEpisodes will
        be broken down into their individual SingleAgentEpisode components.
    """
    # Per-key read cursor into the item lists of `zip_with_batch_column`;
    # tracks how many items have already been yielded for each key.
    list_indices = defaultdict(int)

    # Single-agent case.
    if episodes and isinstance(episodes[0], SingleAgentEpisode):
        if zip_with_batch_column is not None:
            if len(zip_with_batch_column) != len(episodes):
                raise ValueError(
                    "Invalid `zip_with_batch_column` data: Must have the same "
                    f"length as the list of episodes ({len(episodes)}), but has "
                    f"length {len(zip_with_batch_column)}!"
                )
            # Simple case: Items are stored in lists directly under the column (str)
            # key.
            if isinstance(zip_with_batch_column, list):
                for episode, data in zip(episodes, zip_with_batch_column):
                    yield episode, data
            # Normal single-agent case: Items are stored in dicts under the column
            # (str) key. These dicts map (eps_id,)-tuples to lists of individual
            # items.
            else:
                for episode, (eps_id_tuple, data) in zip(
                    episodes,
                    zip_with_batch_column.items(),
                ):
                    # Episodes and dict entries must be aligned 1:1 (same order).
                    assert episode.id_ == eps_id_tuple[0]
                    d = data[list_indices[eps_id_tuple]]
                    list_indices[eps_id_tuple] += 1
                    yield episode, d
        else:
            for episode in episodes:
                yield episode
        return

    # Multi-agent case.
    for episode in episodes:
        for agent_id in (
            episode.get_agents_that_stepped()
            if agents_that_stepped_only
            else episode.agent_ids
        ):
            sa_episode = episode.agent_episodes[agent_id]
            if zip_with_batch_column is not None:
                # Batch items of this agent live under a
                # (ma_episode_id, agent_id, module_id) key.
                key = (
                    sa_episode.multi_agent_episode_id,
                    sa_episode.agent_id,
                    sa_episode.module_id,
                )
                if len(zip_with_batch_column[key]) <= list_indices[key]:
                    raise ValueError(
                        "Invalid `zip_with_batch_column` data: Must structurally "
                        "match the single-agent contents in the given list of "
                        "(multi-agent) episodes!"
                    )
                d = zip_with_batch_column[key][list_indices[key]]
                list_indices[key] += 1
                yield sa_episode, d
            else:
                yield sa_episode
|
| 367 |
+
|
| 368 |
+
@staticmethod
def add_batch_item(
    batch: Dict[str, Any],
    column: str,
    item_to_add: Any,
    single_agent_episode: Optional[SingleAgentEpisode] = None,
) -> None:
    """Adds a data item under `column` to the given `batch`.

    The `item_to_add` is stored in the `batch` in the following manner:
    1) If `single_agent_episode` is not provided (None), will store the item in a
    list directly under `column`:
    `column` -> [item, item, ...]
    2) If `single_agent_episode`'s `agent_id` and `module_id` properties are None
    (`single_agent_episode` is not part of a multi-agent episode), will append
    `item_to_add` to a list under a `(<episodeID>,)` key under `column`:
    `column` -> `(<episodeID>,)` -> [item, item, ...]
    3) If `single_agent_episode`'s `agent_id` and `module_id` are NOT None
    (`single_agent_episode` is part of a multi-agent episode), will append
    `item_to_add` to a list under a `(<episodeID>,<AgentID>,<ModuleID>)` key
    under `column`:
    `column` -> `(<episodeID>,<AgentID>,<ModuleID>)` -> [item, item, ...]

    See these examples for clarification of these three cases:

    .. testcode::

        from ray.rllib.connectors.connector_v2 import ConnectorV2
        from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
        from ray.rllib.env.single_agent_episode import SingleAgentEpisode
        from ray.rllib.utils.test_utils import check

        # 1) Simple case (no episodes provided) -> Store data in a list directly
        # under `column`:
        batch = {}
        ConnectorV2.add_batch_item(batch, "test_col", item_to_add=5)
        ConnectorV2.add_batch_item(batch, "test_col", item_to_add=6)
        check(batch, {"test_col": [5, 6]})
        ConnectorV2.add_batch_item(batch, "test_col_2", item_to_add=-10)
        check(batch, {
            "test_col": [5, 6],
            "test_col_2": [-10],
        })

        # 2) Single-agent case (SingleAgentEpisode provided) -> Store data in a list
        # under the keys: `column` -> `(<eps_id>,)` -> [...]:
        batch = {}
        episode = SingleAgentEpisode(
            id_="SA-EPS0",
            observations=[0, 1, 2, 3],
            actions=[1, 2, 3],
            rewards=[1.0, 2.0, 3.0],
        )
        ConnectorV2.add_batch_item(batch, "test_col", 5, episode)
        ConnectorV2.add_batch_item(batch, "test_col", 6, episode)
        ConnectorV2.add_batch_item(batch, "test_col_2", -10, episode)
        check(batch, {
            "test_col": {("SA-EPS0",): [5, 6]},
            "test_col_2": {("SA-EPS0",): [-10]},
        })

        # 3) Multi-agent case (SingleAgentEpisode provided that has `agent_id` and
        # `module_id` information) -> Store data in a list under the keys:
        # `column` -> `(<episodeID>,<AgentID>,<ModuleID>)` -> [...]:
        batch = {}
        ma_episode = MultiAgentEpisode(
            id_="MA-EPS1",
            observations=[
                {"ag0": 0, "ag1": 1}, {"ag0": 2, "ag1": 4}
            ],
            actions=[{"ag0": 0, "ag1": 1}],
            rewards=[{"ag0": -0.1, "ag1": -0.2}],
            # ag0 maps to mod0, ag1 maps to mod1, etc..
            agent_to_module_mapping_fn=lambda aid, eps: f"mod{aid[2:]}",
        )
        ConnectorV2.add_batch_item(
            batch,
            "test_col",
            item_to_add=5,
            single_agent_episode=ma_episode.agent_episodes["ag0"],
        )
        ConnectorV2.add_batch_item(
            batch,
            "test_col",
            item_to_add=6,
            single_agent_episode=ma_episode.agent_episodes["ag0"],
        )
        ConnectorV2.add_batch_item(
            batch,
            "test_col_2",
            item_to_add=10,
            single_agent_episode=ma_episode.agent_episodes["ag1"],
        )
        check(
            batch,
            {
                "test_col": {("MA-EPS1", "ag0", "mod0"): [5, 6]},
                "test_col_2": {("MA-EPS1", "ag1", "mod1"): [10]},
            },
        )

    Args:
        batch: The batch to store `item_to_add` in.
        column: The column name (str) within the `batch` to store `item_to_add`
            under.
        item_to_add: The data item to store in the batch.
        single_agent_episode: An optional SingleAgentEpisode.
            If provided and its `agent_id` and `module_id` properties are None,
            creates a further sub dictionary under `column`, mapping from
            `(<episodeID>,)` to a list of data items (to which `item_to_add` will
            be appended in this call).
            If provided and its `agent_id` and `module_id` properties are NOT None,
            creates a further sub dictionary under `column`, mapping from
            `(<episodeID>,<AgentID>,<ModuleID>)` to a list of data items (to which
            `item_to_add` will be appended in this call).
            If not provided, will append `item_to_add` to a list directly under
            `column`.
    """
    # `sub_key` stays None in the "no episode" case (case 1 above).
    sub_key = None
    # SAEpisode is provided ...
    if single_agent_episode is not None:
        module_id = single_agent_episode.module_id
        # ... and has `module_id` AND that `module_id` is already a top-level key in
        # `batch` (`batch` is already in module-major form, mapping ModuleID to
        # columns mapping to data).
        if module_id is not None and module_id in batch:
            raise ValueError(
                "Can't call `add_batch_item` on a `batch` that is already "
                "module-major (meaning ModuleID is top-level with column names on "
                "the level thereunder)! Make sure to only call `add_batch_items` "
                "before the `AgentToModuleMapping` ConnectorV2 piece is applied."
            )

        # ... and has `agent_id` -> Use `single_agent_episode`'s agent ID and
        # module ID.
        elif single_agent_episode.agent_id is not None:
            sub_key = (
                single_agent_episode.multi_agent_episode_id,
                single_agent_episode.agent_id,
                single_agent_episode.module_id,
            )
        # Otherwise, just use episode's ID.
        else:
            sub_key = (single_agent_episode.id_,)

    # Create the column (with the proper empty container) on first use.
    if column not in batch:
        batch[column] = [] if sub_key is None else {sub_key: []}
    if sub_key is not None:
        if sub_key not in batch[column]:
            batch[column][sub_key] = []
        batch[column][sub_key].append(item_to_add)
    else:
        batch[column].append(item_to_add)
|
| 521 |
+
|
| 522 |
+
@staticmethod
def add_n_batch_items(
    batch: Dict[str, Any],
    column: str,
    items_to_add: Any,
    num_items: int,
    single_agent_episode: Optional[SingleAgentEpisode] = None,
) -> None:
    """Adds a list of items (or batched item) under `column` to the given `batch`.

    If `items_to_add` is not a list, but an already batched struct (of np.ndarray
    leafs), the `items_to_add` will be appended to possibly existing data under the
    same `column` as-is. A subsequent `BatchIndividualItems` ConnectorV2 piece will
    recognize this and batch the data properly into a single (batched) item.
    This is much faster than first splitting up `items_to_add` and then adding each
    item individually.

    If `single_agent_episode` is provided and its `agent_id` and `module_id`
    properties are None, creates a further sub dictionary under `column`, mapping
    from `(<episodeID>,)` to a list of data items (to which `items_to_add` will
    be appended in this call).
    If `single_agent_episode` is provided and its `agent_id` and `module_id`
    properties are NOT None, creates a further sub dictionary under `column`,
    mapping from `(<episodeID>,<AgentID>,<ModuleID>)` to a list of data items (to
    which `items_to_add` will be appended in this call).
    If `single_agent_episode` is not provided, will append `items_to_add` to a list
    directly under `column`.

    .. testcode::

        import numpy as np

        from ray.rllib.connectors.connector_v2 import ConnectorV2
        from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
        from ray.rllib.env.single_agent_episode import SingleAgentEpisode
        from ray.rllib.utils.test_utils import check

        # Simple case (no episodes provided) -> Store data in a list directly under
        # `column`:
        batch = {}
        ConnectorV2.add_n_batch_items(
            batch,
            "test_col",
            # List of (complex) structs.
            [{"a": np.array(3), "b": 4}, {"a": np.array(5), "b": 6}],
            num_items=2,
        )
        check(
            batch["test_col"],
            [{"a": np.array(3), "b": 4}, {"a": np.array(5), "b": 6}],
        )
        # In a new column (test_col_2), store some already batched items.
        # This way, you may avoid having to disassemble an already batched item
        # (e.g. a numpy array of shape (10, 2)) into its individual items (e.g.
        # split the array into a list of len=10) and then adding these individually.
        # The performance gains may be quite large when providing already batched
        # items (such as numpy arrays with a batch dim):
        ConnectorV2.add_n_batch_items(
            batch,
            "test_col_2",
            # One (complex) already batched struct.
            {"a": np.array([3, 5]), "b": np.array([4, 6])},
            num_items=2,
        )
        # Add more already batched items (this time with a different batch size)
        ConnectorV2.add_n_batch_items(
            batch,
            "test_col_2",
            {"a": np.array([7, 7, 7]), "b": np.array([8, 8, 8])},
            num_items=3,  # <- in this case, this must be the batch size
        )
        check(
            batch["test_col_2"],
            [
                {"a": np.array([3, 5]), "b": np.array([4, 6])},
                {"a": np.array([7, 7, 7]), "b": np.array([8, 8, 8])},
            ],
        )

        # Single-agent case (SingleAgentEpisode provided) -> Store data in a list
        # under the keys: `column` -> `(<eps_id>,)`:
        batch = {}
        episode = SingleAgentEpisode(
            id_="SA-EPS0",
            observations=[0, 1, 2, 3],
            actions=[1, 2, 3],
            rewards=[1.0, 2.0, 3.0],
        )
        ConnectorV2.add_n_batch_items(
            batch=batch,
            column="test_col",
            items_to_add=[5, 6, 7],
            num_items=3,
            single_agent_episode=episode,
        )
        check(batch, {
            "test_col": {("SA-EPS0",): [5, 6, 7]},
        })

        # Multi-agent case (SingleAgentEpisode provided that has `agent_id` and
        # `module_id` information) -> Store data in a list under the keys:
        # `column` -> `(<episodeID>,<AgentID>,<ModuleID>)`:
        batch = {}
        ma_episode = MultiAgentEpisode(
            id_="MA-EPS1",
            observations=[
                {"ag0": 0, "ag1": 1}, {"ag0": 2, "ag1": 4}
            ],
            actions=[{"ag0": 0, "ag1": 1}],
            rewards=[{"ag0": -0.1, "ag1": -0.2}],
            # ag0 maps to mod0, ag1 maps to mod1, etc..
            agent_to_module_mapping_fn=lambda aid, eps: f"mod{aid[2:]}",
        )
        ConnectorV2.add_n_batch_items(
            batch,
            "test_col",
            items_to_add=[5, 6],
            num_items=2,
            single_agent_episode=ma_episode.agent_episodes["ag0"],
        )
        ConnectorV2.add_n_batch_items(
            batch,
            "test_col_2",
            items_to_add=[10],
            num_items=1,
            single_agent_episode=ma_episode.agent_episodes["ag1"],
        )
        check(
            batch,
            {
                "test_col": {("MA-EPS1", "ag0", "mod0"): [5, 6]},
                "test_col_2": {("MA-EPS1", "ag1", "mod1"): [10]},
            },
        )

    Args:
        batch: The batch to store n `items_to_add` in.
        column: The column name (str) within the `batch` to store `item_to_add`
            under.
        items_to_add: The list of data items to store in the batch OR an already
            batched (possibly nested) struct. In the latter case, the `items_to_add`
            will be appended to possibly existing data under the same `column`
            as-is. A subsequent `BatchIndividualItems` ConnectorV2 piece will
            recognize this and batch the data properly into a single (batched) item.
            This is much faster than first splitting up `items_to_add` and then
            adding each item individually.
        num_items: The number of items in `items_to_add`. This arg is mostly for
            asserting the correct usage of this method by checking, whether the
            given data in `items_to_add` really has the right amount of individual
            items.
        single_agent_episode: An optional SingleAgentEpisode.
            If provided and its `agent_id` and `module_id` properties are None,
            creates a further sub dictionary under `column`, mapping from
            `(<episodeID>,)` to a list of data items (to which `items_to_add` will
            be appended in this call).
            If provided and its `agent_id` and `module_id` properties are NOT None,
            creates a further sub dictionary under `column`, mapping from
            `(<episodeID>,<AgentID>,<ModuleID>)` to a list of data items (to which
            `items_to_add` will be appended in this call).
            If not provided, will append `items_to_add` to a list directly under
            `column`.
    """
    # Process n list items by calling `add_batch_item` on each of them individually.
    if isinstance(items_to_add, list):
        if len(items_to_add) != num_items:
            raise ValueError(
                f"Mismatch between `num_items` ({num_items}) and the length "
                f"of the provided list ({len(items_to_add)}) in "
                f"{ConnectorV2.__name__}.add_n_batch_items()!"
            )
        for item in items_to_add:
            ConnectorV2.add_batch_item(
                batch=batch,
                column=column,
                item_to_add=item,
                single_agent_episode=single_agent_episode,
            )
        return

    # Process a batched (possibly complex) struct.
    # We could just unbatch the item (split it into a list) and then add each
    # individual item to our `batch`. However, this comes with a heavy performance
    # penalty. Instead, we tag the thus added array(s) here as "_has_batch_dim=True"
    # and then know that when batching the entire list under the respective
    # (eps_id, agent_id, module_id)-tuple key, we need to concatenate, not stack
    # the items in there.
    def _tag(s):
        # Wrap each leaf so downstream batching knows to concat (not stack) it.
        return BatchedNdArray(s)

    ConnectorV2.add_batch_item(
        batch=batch,
        column=column,
        # Convert given input into BatchedNdArray(s) such that the `batch` utility
        # knows that it'll have to concat, not stack.
        item_to_add=tree.map_structure(_tag, items_to_add),
        single_agent_episode=single_agent_episode,
    )
|
| 723 |
+
|
| 724 |
+
@staticmethod
def foreach_batch_item_change_in_place(
    batch: Dict[str, Any],
    column: Union[str, List[str], Tuple[str]],
    func: Callable[
        [Any, Optional[int], Optional[AgentID], Optional[ModuleID]], Any
    ],
) -> None:
    """Runs the provided `func` on all items under one or more columns in the batch.

    Use this method to conveniently loop through all items in a batch
    and transform them in place.

    `func` takes the following as arguments:
    - The item itself. If column is a list of column names, this argument is a tuple
    of items.
    - The EpisodeID. This value might be None.
    - The AgentID. This value might be None in the single-agent case.
    - The ModuleID. This value might be None in the single-agent case.

    The return value(s) of `func` are used to directly override the values in the
    given `batch`.

    Args:
        batch: The batch to process in-place.
        column: A single column name (str) or a list thereof. If a list is provided,
            the first argument to `func` is a tuple of items. If a single
            str is provided, the first argument to `func` is an individual
            item.
        func: The function to call on each item or tuple of item(s).

    .. testcode::

        from ray.rllib.connectors.connector_v2 import ConnectorV2
        from ray.rllib.utils.test_utils import check

        # Simple case: Batch items are in lists directly under their column names.
        batch = {
            "col1": [0, 1, 2, 3],
            "col2": [0, -1, -2, -3],
        }
        # Increase all ints by 1.
        ConnectorV2.foreach_batch_item_change_in_place(
            batch=batch,
            column="col1",
            func=lambda item, *args: item + 1,
        )
        check(batch["col1"], [1, 2, 3, 4])

        # Further increase all ints by 1 in col1 and flip sign in col2.
        ConnectorV2.foreach_batch_item_change_in_place(
            batch=batch,
            column=["col1", "col2"],
            func=(lambda items, *args: (items[0] + 1, -items[1])),
        )
        check(batch["col1"], [2, 3, 4, 5])
        check(batch["col2"], [0, 1, 2, 3])

        # Single-agent case: Batch items are in lists under (eps_id,)-keys in a dict
        # under their column names.
        batch = {
            "col1": {
                ("eps1",): [0, 1, 2, 3],
                ("eps2",): [400, 500, 600],
            },
        }
        # Increase all ints of eps1 by 1 and divide all ints of eps2 by 100.
        ConnectorV2.foreach_batch_item_change_in_place(
            batch=batch,
            column="col1",
            func=lambda item, eps_id, *args: (
                item + 1 if eps_id == "eps1" else item / 100
            ),
        )
        check(batch["col1"], {
            ("eps1",): [1, 2, 3, 4],
            ("eps2",): [4, 5, 6],
        })

        # Multi-agent case: Batch items are in lists under
        # (eps_id, agent_id, module_id)-keys in a dict
        # under their column names.
        batch = {
            "col1": {
                ("eps1", "ag1", "mod1"): [1, 2, 3, 4],
                ("eps2", "ag1", "mod2"): [400, 500, 600],
                ("eps2", "ag2", "mod3"): [-1, -2, -3, -4, -5],
            },
        }
        # Decrease all ints of "eps1" by 1, divide all ints of "mod2" by 100, and
        # flip sign of all ints of "ag2".
        ConnectorV2.foreach_batch_item_change_in_place(
            batch=batch,
            column="col1",
            func=lambda item, eps_id, ag_id, mod_id: (
                item - 1
                if eps_id == "eps1"
                else item / 100
                if mod_id == "mod2"
                else -item
            ),
        )
        check(batch["col1"], {
            ("eps1", "ag1", "mod1"): [0, 1, 2, 3],
            ("eps2", "ag1", "mod2"): [4, 5, 6],
            ("eps2", "ag2", "mod3"): [1, 2, 3, 4, 5],
        })
    """
    # Collect the per-column containers; a None entry means the column is missing.
    data_to_process = [batch.get(c) for c in force_list(column)]
    single_col = isinstance(column, str)
    if any(d is None for d in data_to_process):
        raise ValueError(
            f"Invalid column name(s) ({column})! One or more not found in "
            f"given batch. Found columns {list(batch.keys())}."
        )

    # Simple case: Data items are stored in a list directly under the column
    # name(s).
    if isinstance(data_to_process[0], list):
        for list_pos, data_tuple in enumerate(zip(*data_to_process)):
            results = func(
                data_tuple[0] if single_col else data_tuple,
                None,  # episode_id
                None,  # agent_id
                None,  # module_id
            )
            # Tuple'ize results if single_col.
            results = (results,) if single_col else results
            # Write each returned value back into its column's list in place.
            for col_slot, result in enumerate(force_list(results)):
                data_to_process[col_slot][list_pos] = result
    # Single-agent/multi-agent cases.
    else:
        for key, d0_list in data_to_process[0].items():
            # Multi-agent case: There is a dict mapping from a
            # (eps id, AgentID, ModuleID)-tuples to lists of individual data items.
            if len(key) == 3:
                eps_id, agent_id, module_id = key
            # Single-agent case: There is a dict mapping from a (eps_id,)-tuple
            # to lists of individual data items.
            # AgentID and ModuleID are both None.
            else:
                eps_id = key[0]
                agent_id = module_id = None
            # The same sub-key must exist in every requested column.
            other_lists = [d[key] for d in data_to_process[1:]]
            for list_pos, data_tuple in enumerate(zip(d0_list, *other_lists)):
                results = func(
                    data_tuple[0] if single_col else data_tuple,
                    eps_id,
                    agent_id,
                    module_id,
                )
                # Tuple'ize results if single_col.
                results = (results,) if single_col else results
                for col_slot, result in enumerate(results):
                    data_to_process[col_slot][key][list_pos] = result
|
| 879 |
+
|
| 880 |
+
@staticmethod
|
| 881 |
+
def switch_batch_from_column_to_module_ids(
|
| 882 |
+
batch: Dict[str, Dict[ModuleID, Any]]
|
| 883 |
+
) -> Dict[ModuleID, Dict[str, Any]]:
|
| 884 |
+
"""Switches the first two levels of a `col_name -> ModuleID -> data` type batch.
|
| 885 |
+
|
| 886 |
+
Assuming that the top level consists of column names as keys and the second
|
| 887 |
+
level (under these columns) consists of ModuleID keys, the resulting batch
|
| 888 |
+
will have these two reversed and thus map ModuleIDs to dicts mapping column
|
| 889 |
+
names to data items.
|
| 890 |
+
|
| 891 |
+
.. testcode::
|
| 892 |
+
|
| 893 |
+
from ray.rllib.utils.test_utils import check
|
| 894 |
+
|
| 895 |
+
batch = {
|
| 896 |
+
"obs": {"module_0": [1, 2, 3]},
|
| 897 |
+
"actions": {"module_0": [4, 5, 6], "module_1": [7]},
|
| 898 |
+
}
|
| 899 |
+
switched_batch = ConnectorV2.switch_batch_from_column_to_module_ids(batch)
|
| 900 |
+
check(
|
| 901 |
+
switched_batch,
|
| 902 |
+
{
|
| 903 |
+
"module_0": {"obs": [1, 2, 3], "actions": [4, 5, 6]},
|
| 904 |
+
"module_1": {"actions": [7]},
|
| 905 |
+
},
|
| 906 |
+
)
|
| 907 |
+
|
| 908 |
+
Args:
|
| 909 |
+
batch: The batch to switch from being column name based (then ModuleIDs)
|
| 910 |
+
to being ModuleID based (then column names).
|
| 911 |
+
|
| 912 |
+
Returns:
|
| 913 |
+
A new batch dict mapping ModuleIDs to dicts mapping column names (e.g.
|
| 914 |
+
"obs") to data.
|
| 915 |
+
"""
|
| 916 |
+
module_data = defaultdict(dict)
|
| 917 |
+
for column, column_data in batch.items():
|
| 918 |
+
for module_id, data in column_data.items():
|
| 919 |
+
module_data[module_id][column] = data
|
| 920 |
+
return dict(module_data)
|
| 921 |
+
|
| 922 |
+
@override(Checkpointable)
def get_state(
    self,
    components: Optional[Union[str, Collection[str]]] = None,
    *,
    not_components: Optional[Union[str, Collection[str]]] = None,
    **kwargs,
) -> StateDict:
    """Returns this ConnectorV2's state dict.

    The base implementation is stateless, so an empty dict is returned; the
    `components`/`not_components` filters are accepted for API compatibility
    but have nothing to select from here. Stateful subclasses should override
    this together with `set_state`.
    """
    return {}
|
| 931 |
+
|
| 932 |
+
@override(Checkpointable)
def set_state(self, state: StateDict) -> None:
    """Restores this ConnectorV2 from the given state dict.

    The base implementation is stateless, so this is a no-op. Stateful
    subclasses should override this together with `get_state`.
    """
|
| 935 |
+
|
| 936 |
+
@override(Checkpointable)
def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
    """Returns the (args, kwargs) pair needed to re-construct this ConnectorV2.

    Positional args are always empty; the keyword args captured at construction
    time (`self._ctor_kwargs`) fully describe how to rebuild this piece.
    """
    ctor_args = ()  # This class takes no positional ctor args.
    return ctor_args, self._ctor_kwargs
|
| 942 |
+
|
| 943 |
+
def reset_state(self) -> None:
    """Resets the state of this ConnectorV2 to some initial value.

    Note that this may NOT be the exact state that this ConnectorV2 was
    originally constructed with. The base implementation is stateless and
    therefore does nothing.
    """
|
| 950 |
+
|
| 951 |
+
def merge_states(self, states: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Computes a resulting state given self's state and a list of other states.

    Algorithms should use this method for merging states between connectors
    running on parallel EnvRunner workers. For example, to synchronize the
    connector states of n remote workers and a local worker, one could:
    - Gather all remote worker connector states in a list.
    - Call `self.merge_states()` on the local worker passing it the states list.
    - Broadcast the resulting local worker's connector state back to all remote
    workers. After this, all workers (including the local one) hold a
    merged/synchronized new connector state.

    The base implementation is stateless and thus always yields an empty state.

    Args:
        states: The list of n other ConnectorV2 states to merge with self's state
            into a single resulting state.

    Returns:
        The resulting state dict.
    """
    return {}
|
| 971 |
+
|
| 972 |
+
@property
def observation_space(self):
    """Getter for our (output) observation space.

    Logic: Use the user provided space (if one was set via the
    `observation_space` setter); otherwise fall back to whatever is stored in
    `self._observation_space` (derived from the input space), assuming this
    connector piece does not alter the space.
    """
    space = self._observation_space
    return space
|
| 981 |
+
|
| 982 |
+
@property
def action_space(self):
    """Getter for our (output) action space.

    Logic: Use the user provided space (if one was set via the `action_space`
    setter); otherwise fall back to whatever is stored in `self._action_space`
    (derived from the input space), assuming this connector piece does not
    alter the space.
    """
    space = self._action_space
    return space
|
| 991 |
+
|
| 992 |
+
@property
def input_observation_space(self):
    """The observation space coming into this connector piece (read-only view)."""
    return self._input_observation_space
|
| 995 |
+
|
| 996 |
+
@input_observation_space.setter
|
| 997 |
+
def input_observation_space(self, value):
|
| 998 |
+
self._input_observation_space = value
|
| 999 |
+
if value is not None:
|
| 1000 |
+
self._observation_space = self.recompute_output_observation_space(
|
| 1001 |
+
value, self.input_action_space
|
| 1002 |
+
)
|
| 1003 |
+
|
| 1004 |
+
@property
def input_action_space(self):
    """Getter for the action space flowing INTO this connector piece."""
    return self._input_action_space

@input_action_space.setter
def input_action_space(self, value):
    # Assigning a new input space triggers a recompute of this piece's
    # OUTPUT action space, so downstream pieces see an up-to-date space.
    self._input_action_space = value
    if value is not None:
        self._action_space = self.recompute_output_action_space(
            self.input_observation_space, value
        )
|
| 1015 |
+
|
| 1016 |
+
def __str__(self, indentation: int = 0):
    """Returns this piece's class name, prefixed by `indentation` spaces."""
    return " " * indentation + self.__class__.__name__
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__init__.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
|
| 2 |
+
AddObservationsFromEpisodesToBatch,
|
| 3 |
+
)
|
| 4 |
+
from ray.rllib.connectors.common.add_states_from_episodes_to_batch import (
|
| 5 |
+
AddStatesFromEpisodesToBatch,
|
| 6 |
+
)
|
| 7 |
+
from ray.rllib.connectors.common.add_time_dim_to_batch_and_zero_pad import (
|
| 8 |
+
AddTimeDimToBatchAndZeroPad,
|
| 9 |
+
)
|
| 10 |
+
from ray.rllib.connectors.common.agent_to_module_mapping import AgentToModuleMapping
|
| 11 |
+
from ray.rllib.connectors.common.batch_individual_items import BatchIndividualItems
|
| 12 |
+
from ray.rllib.connectors.common.numpy_to_tensor import NumpyToTensor
|
| 13 |
+
from ray.rllib.connectors.env_to_module.env_to_module_pipeline import (
|
| 14 |
+
EnvToModulePipeline,
|
| 15 |
+
)
|
| 16 |
+
from ray.rllib.connectors.env_to_module.flatten_observations import (
|
| 17 |
+
FlattenObservations,
|
| 18 |
+
)
|
| 19 |
+
from ray.rllib.connectors.env_to_module.mean_std_filter import MeanStdFilter
|
| 20 |
+
from ray.rllib.connectors.env_to_module.prev_actions_prev_rewards import (
|
| 21 |
+
PrevActionsPrevRewards,
|
| 22 |
+
)
|
| 23 |
+
from ray.rllib.connectors.env_to_module.write_observations_to_episodes import (
|
| 24 |
+
WriteObservationsToEpisodes,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Explicit public API of `ray.rllib.connectors.env_to_module`.
__all__ = [
    "AddObservationsFromEpisodesToBatch",
    "AddStatesFromEpisodesToBatch",
    "AddTimeDimToBatchAndZeroPad",
    "AgentToModuleMapping",
    "BatchIndividualItems",
    "EnvToModulePipeline",
    "FlattenObservations",
    "MeanStdFilter",
    "NumpyToTensor",
    "PrevActionsPrevRewards",
    "WriteObservationsToEpisodes",
]
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__pycache__/mean_std_filter.cpython-311.pyc
ADDED
|
Binary file (13.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__pycache__/prev_actions_prev_rewards.cpython-311.pyc
ADDED
|
Binary file (7.42 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/env_to_module_pipeline.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Optional
|
| 2 |
+
|
| 3 |
+
from ray.rllib.connectors.connector_pipeline_v2 import ConnectorPipelineV2
|
| 4 |
+
from ray.rllib.core.rl_module.rl_module import RLModule
|
| 5 |
+
from ray.rllib.utils.annotations import override
|
| 6 |
+
from ray.rllib.utils.metrics import (
|
| 7 |
+
ENV_TO_MODULE_SUM_EPISODES_LENGTH_IN,
|
| 8 |
+
ENV_TO_MODULE_SUM_EPISODES_LENGTH_OUT,
|
| 9 |
+
)
|
| 10 |
+
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
|
| 11 |
+
from ray.rllib.utils.typing import EpisodeType
|
| 12 |
+
from ray.util.annotations import PublicAPI
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@PublicAPI(stability="alpha")
class EnvToModulePipeline(ConnectorPipelineV2):
    """Connector pipeline running between env sampling and the RLModule forward pass.

    Adds two conveniences on top of the generic `ConnectorPipelineV2.__call__`:
    defaulting `batch`/`shared_data` to empty dicts and logging the summed
    lengths of all incoming/outgoing episodes (if a `MetricsLogger` is given).
    """

    @override(ConnectorPipelineV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Optional[Dict[str, Any]] = None,
        episodes: List[EpisodeType],
        explore: bool,
        shared_data: Optional[dict] = None,
        metrics: Optional[MetricsLogger] = None,
        **kwargs,
    ):
        def _log_summed_episode_lengths(key):
            # Sum of lengths of all episodes currently in the pipeline.
            if metrics:
                metrics.log_value(key, sum(len(episode) for episode in episodes))

        _log_summed_episode_lengths(ENV_TO_MODULE_SUM_EPISODES_LENGTH_IN)

        # Users need not seed this pipeline with initial input; `batch` and
        # `shared_data` may be empty and get populated from `episodes`.
        output = super().__call__(
            rl_module=rl_module,
            batch={} if batch is None else batch,
            episodes=episodes,
            explore=explore,
            shared_data={} if shared_data is None else shared_data,
            metrics=metrics,
            **kwargs,
        )

        _log_summed_episode_lengths(ENV_TO_MODULE_SUM_EPISODES_LENGTH_OUT)

        return output
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/flatten_observations.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Collection, Dict, List, Optional
|
| 2 |
+
|
| 3 |
+
import gymnasium as gym
|
| 4 |
+
from gymnasium.spaces import Box
|
| 5 |
+
import numpy as np
|
| 6 |
+
import tree # pip install dm_tree
|
| 7 |
+
|
| 8 |
+
from ray.rllib.connectors.connector_v2 import ConnectorV2
|
| 9 |
+
from ray.rllib.core.rl_module.rl_module import RLModule
|
| 10 |
+
from ray.rllib.utils.annotations import override
|
| 11 |
+
from ray.rllib.utils.numpy import flatten_inputs_to_1d_tensor
|
| 12 |
+
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
|
| 13 |
+
from ray.rllib.utils.typing import AgentID, EpisodeType
|
| 14 |
+
from ray.util.annotations import PublicAPI
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@PublicAPI(stability="alpha")
class FlattenObservations(ConnectorV2):
    """A connector piece that flattens all observation components into a 1D array.

    - Should be used only in env-to-module pipelines.
    - Works directly on the incoming episodes list and changes the last observation
    in-place (write the flattened observation back into the episode).
    - This connector does NOT alter the incoming batch (`data`) when called.
    - This connector does NOT work in a `LearnerConnectorPipeline` because it requires
    the incoming episodes to still be ongoing (in progress) as it only alters the
    latest observation, not all observations in an episode.

    .. testcode::

        import gymnasium as gym
        import numpy as np

        from ray.rllib.connectors.env_to_module import FlattenObservations
        from ray.rllib.env.single_agent_episode import SingleAgentEpisode
        from ray.rllib.utils.test_utils import check

        # Some arbitrarily nested, complex observation space.
        obs_space = gym.spaces.Dict({
            "a": gym.spaces.Box(-10.0, 10.0, (), np.float32),
            "b": gym.spaces.Tuple([
                gym.spaces.Discrete(2),
                gym.spaces.Box(-1.0, 1.0, (2, 1), np.float32),
            ]),
            "c": gym.spaces.MultiDiscrete([2, 3]),
        })
        act_space = gym.spaces.Discrete(2)

        # Two example episodes, both with initial (reset) observations coming from the
        # above defined observation space.
        episode_1 = SingleAgentEpisode(
            observations=[
                {
                    "a": np.array(-10.0, np.float32),
                    "b": (1, np.array([[-1.0], [-1.0]], np.float32)),
                    "c": np.array([0, 2]),
                },
            ],
        )
        episode_2 = SingleAgentEpisode(
            observations=[
                {
                    "a": np.array(10.0, np.float32),
                    "b": (0, np.array([[1.0], [1.0]], np.float32)),
                    "c": np.array([1, 1]),
                },
            ],
        )

        # Construct our connector piece.
        connector = FlattenObservations(obs_space, act_space)

        # Call our connector piece with the example data.
        output_batch = connector(
            rl_module=None,  # This connector works without an RLModule.
            batch={},  # This connector does not alter the input batch.
            episodes=[episode_1, episode_2],
            explore=True,
            shared_data={},
        )

        # The connector does not alter the data and acts as pure pass-through.
        check(output_batch, {})

        # The connector has flattened each item in the episodes to a 1D tensor.
        check(
            episode_1.get_observations(0),
            #         box()  disc(2).  box(2, 1).  multidisc(2, 3)........
            np.array([-10.0, 0.0, 1.0, -1.0, -1.0, 1.0, 0.0, 0.0, 0.0, 1.0]),
        )
        check(
            episode_2.get_observations(0),
            #         box()  disc(2).  box(2, 1).  multidisc(2, 3)........
            np.array([10.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0]),
        )
    """

    @override(ConnectorV2)
    def recompute_output_observation_space(
        self,
        input_observation_space,
        input_action_space,
    ) -> gym.Space:
        """Returns a (per-agent) 1D Box matching the flattened input obs space.

        Also caches the base struct of the input observation space for use in
        `__call__`.
        """
        self._input_obs_base_struct = get_base_struct_from_space(
            self.input_observation_space
        )
        if not self._multi_agent:
            # Determine the flattened length by flattening one sample.
            sample = flatten_inputs_to_1d_tensor(
                tree.map_structure(
                    lambda s: s.sample(),
                    self._input_obs_base_struct,
                ),
                self._input_obs_base_struct,
                batch_axis=False,
            )
            return Box(float("-inf"), float("inf"), (len(sample),), np.float32)

        spaces = {}
        for agent_id, space in self._input_obs_base_struct.items():
            # Same exclusion check as in `__call__`: only a non-None
            # `agent_ids` collection excludes agents (an empty collection
            # excludes ALL agents; `None` means "flatten for all agents").
            # Previously, this used truthiness (`if self._agent_ids and ...`),
            # which disagreed with `__call__` for empty collections.
            if self._agent_ids is not None and agent_id not in self._agent_ids:
                spaces[agent_id] = space
            else:
                sample = flatten_inputs_to_1d_tensor(
                    tree.map_structure(lambda s: s.sample(), space),
                    space,
                    batch_axis=False,
                )
                spaces[agent_id] = Box(
                    float("-inf"), float("inf"), (len(sample),), np.float32
                )
        return gym.spaces.Dict(spaces)

    def __init__(
        self,
        input_observation_space: Optional[gym.Space] = None,
        input_action_space: Optional[gym.Space] = None,
        *,
        multi_agent: bool = False,
        agent_ids: Optional[Collection[AgentID]] = None,
        **kwargs,
    ):
        """Initializes a FlattenObservations instance.

        Args:
            multi_agent: Whether this connector operates on multi-agent observations,
                in which case, the top-level of the Dict space (where agent IDs are
                mapped to individual agents' observation spaces) is left as-is.
            agent_ids: If multi_agent is True, this argument defines a collection of
                AgentIDs for which to flatten. AgentIDs not in this collection are
                ignored.
                If None, flatten observations for all AgentIDs. None is the default.
        """
        # Set these BEFORE calling super().__init__(), which may trigger
        # `recompute_output_observation_space()` (reads these attributes).
        self._input_obs_base_struct = None
        self._multi_agent = multi_agent
        self._agent_ids = agent_ids

        super().__init__(input_observation_space, input_action_space, **kwargs)

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Dict[str, Any],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        shared_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        for sa_episode in self.single_agent_episode_iterator(
            episodes, agents_that_stepped_only=True
        ):
            last_obs = sa_episode.get_observations(-1)

            if self._multi_agent:
                if (
                    self._agent_ids is not None
                    and sa_episode.agent_id not in self._agent_ids
                ):
                    # Agent explicitly excluded from flattening -> pass through.
                    flattened_obs = last_obs
                else:
                    flattened_obs = flatten_inputs_to_1d_tensor(
                        inputs=last_obs,
                        # In the multi-agent case, we need to use the specific agent's
                        # space struct, not the multi-agent observation space dict.
                        spaces_struct=self._input_obs_base_struct[sa_episode.agent_id],
                        # Our items are individual observations (no batch axis present).
                        batch_axis=False,
                    )
            else:
                flattened_obs = flatten_inputs_to_1d_tensor(
                    inputs=last_obs,
                    spaces_struct=self._input_obs_base_struct,
                    # Our items are individual observations (no batch axis present).
                    batch_axis=False,
                )

            # Write new observation directly back into the episode.
            sa_episode.set_observations(at_indices=-1, new_data=flattened_obs)
            # We set the Episode's observation space to ours so that we can safely
            # set the last obs to the new value (without causing a space mismatch
            # error).
            sa_episode.observation_space = self.observation_space

        # This connector leaves the batch untouched (pure pass-through).
        return batch
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/frame_stacking.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial

from ray.rllib.connectors.common.frame_stacking import _FrameStacking


# Env-to-module flavor of the shared frame-stacking connector: `_FrameStacking`
# with `as_learner_connector=False` pre-bound via `functools.partial`.
FrameStackingEnvToModule = partial(_FrameStacking, as_learner_connector=False)
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/mean_std_filter.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Collection, Dict, List, Optional, Union
|
| 2 |
+
|
| 3 |
+
import gymnasium as gym
|
| 4 |
+
from gymnasium.spaces import Discrete, MultiDiscrete
|
| 5 |
+
import numpy as np
|
| 6 |
+
import tree
|
| 7 |
+
|
| 8 |
+
from ray.rllib.connectors.connector_v2 import ConnectorV2
|
| 9 |
+
from ray.rllib.core.rl_module.rl_module import RLModule
|
| 10 |
+
from ray.rllib.utils.annotations import override
|
| 11 |
+
from ray.rllib.utils.filter import MeanStdFilter as _MeanStdFilter, RunningStat
|
| 12 |
+
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
|
| 13 |
+
from ray.rllib.utils.typing import AgentID, EpisodeType, StateDict
|
| 14 |
+
from ray.util.annotations import PublicAPI
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@PublicAPI(stability="alpha")
class MeanStdFilter(ConnectorV2):
    """A connector used to mean-std-filter observations.

    Incoming observations are filtered such that the output of this filter is on
    average 0.0 and has a standard deviation of 1.0. If the observation space is
    a (possibly nested) dict, this filtering is applied separately per element of
    the observation space (except for discrete- and multi-discrete elements, which
    are left as-is).

    This connector is stateful as it continues to update its internal stats on mean
    and std values as new data is pushed through it (unless `update_stats` is False).
    """

    @override(ConnectorV2)
    def recompute_output_observation_space(
        self,
        input_observation_space: gym.Space,
        input_action_space: gym.Space,
    ) -> gym.Space:
        """Returns the input obs space with Box bounds narrowed to the clip range."""
        # If clipping is disabled, this connector does not change any bounds ->
        # output space == input space. (Without this guard, `low=-None` below
        # would raise a TypeError for the documented `clip_by_value=None` case.)
        if self.clip_by_value is None:
            return input_observation_space

        _input_observation_space_struct = get_base_struct_from_space(
            input_observation_space
        )

        # Adjust our observation space's Boxes (only if clipping is active);
        # non-Box components (e.g. Discrete) pass through unchanged.
        _observation_space_struct = tree.map_structure(
            lambda s: (
                s
                if not isinstance(s, gym.spaces.Box)
                else gym.spaces.Box(
                    low=-self.clip_by_value,
                    high=self.clip_by_value,
                    shape=s.shape,
                    dtype=s.dtype,
                )
            ),
            _input_observation_space_struct,
        )
        # Re-wrap the mapped struct in the original container space type.
        if isinstance(input_observation_space, (gym.spaces.Dict, gym.spaces.Tuple)):
            return type(input_observation_space)(_observation_space_struct)
        else:
            return _observation_space_struct

    def __init__(
        self,
        *,
        multi_agent: bool = False,
        de_mean_to_zero: bool = True,
        de_std_to_one: bool = True,
        clip_by_value: Optional[float] = 10.0,
        update_stats: bool = True,
        **kwargs,
    ):
        """Initializes a MeanStdFilter instance.

        Args:
            multi_agent: Whether this is a connector operating on a multi-agent
                observation space mapping AgentIDs to individual agents' observations.
            de_mean_to_zero: Whether to transform the mean values of the output data to
                0.0. This is done by subtracting the incoming data by the currently
                stored mean value.
            de_std_to_one: Whether to transform the standard deviation values of the
                output data to 1.0. This is done by dividing the incoming data by the
                currently stored std value.
            clip_by_value: If not None, clip the incoming data within the interval:
                [-clip_by_value, +clip_by_value].
            update_stats: Whether to update the internal mean and std stats with each
                incoming sample (with each `__call__()`) or not. You should set this to
                False if you would like to perform inference in a production
                environment, without continuing to "learn" stats from new data.
        """
        self._multi_agent = multi_agent

        # We simply use the old MeanStdFilter until non-connector env_runner is fully
        # deprecated to avoid duplicate code.
        self.de_mean_to_zero = de_mean_to_zero
        self.de_std_to_one = de_std_to_one
        self.clip_by_value = clip_by_value
        self._update_stats = update_stats

        # Lazily initialized via `_init_new_filters()`; maps AgentID (or None in
        # the single-agent case) to its filter.
        self._filters: Optional[Dict[AgentID, _MeanStdFilter]] = None

        # Call super().__init__() LAST: if spaces are passed via kwargs, the
        # base class's `input_observation_space` setter triggers
        # `recompute_output_observation_space()`, which reads the attributes
        # set above (same ordering as in `FlattenObservations`).
        super().__init__(**kwargs)

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Dict[str, Any],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        persistent_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        # NOTE(review): Sibling connectors name this parameter `shared_data`; the
        # pipeline passes it by keyword, so here it lands in `**kwargs` (unused
        # either way) — confirm intended naming before unifying.
        if self._filters is None:
            self._init_new_filters()

        # This connector acts as a classic preprocessor. We process and then replace
        # observations inside the episodes directly. Thus, all following connectors
        # will only see and operate on the already normalized data (w/o having access
        # anymore to the original observations).
        for sa_episode in self.single_agent_episode_iterator(episodes):
            sa_obs = sa_episode.get_observations(indices=-1)
            try:
                normalized_sa_obs = self._filters[sa_episode.agent_id](
                    sa_obs, update=self._update_stats
                )
            except KeyError:
                raise KeyError(
                    "KeyError trying to access a filter by agent ID "
                    f"`{sa_episode.agent_id}`! You probably did NOT pass the "
                    f"`multi_agent=True` flag into the `MeanStdFilter()` constructor. "
                )
            sa_episode.set_observations(at_indices=-1, new_data=normalized_sa_obs)
            # We set the Episode's observation space to ours so that we can safely
            # set the last obs to the new value (without causing a space mismatch
            # error).
            sa_episode.observation_space = self.observation_space

        # Leave `batch` as is. RLlib's default connector will automatically
        # populate the OBS column therein from the episodes' now transformed
        # observations.
        return batch

    @override(ConnectorV2)
    def get_state(
        self,
        components: Optional[Union[str, Collection[str]]] = None,
        *,
        not_components: Optional[Union[str, Collection[str]]] = None,
        **kwargs,
    ) -> StateDict:
        """Returns a per-agent dict of filter settings and running stats.

        Note: `components`/`not_components` are accepted for API compatibility
        but not used; the full per-agent state is always returned.
        """
        if self._filters is None:
            self._init_new_filters()
        return self._get_state_from_filters(self._filters)

    @override(ConnectorV2)
    def set_state(self, state: StateDict) -> None:
        """Restores per-agent filter settings and running stats from `state`."""
        if self._filters is None:
            self._init_new_filters()
        for agent_id, agent_state in state.items():
            agent_filter = self._filters[agent_id]
            agent_filter.shape = agent_state["shape"]
            agent_filter.demean = agent_state["de_mean_to_zero"]
            agent_filter.destd = agent_state["de_std_to_one"]
            agent_filter.clip = agent_state["clip_by_value"]
            agent_filter.running_stats = tree.unflatten_as(
                agent_filter.shape,
                [RunningStat.from_state(s) for s in agent_state["running_stats"]],
            )
            # Do not update the buffer.

    @override(ConnectorV2)
    def reset_state(self) -> None:
        """Creates copy of current state and resets accumulated state.

        Raises:
            ValueError: If `update_stats=False` (a frozen filter's stats must not
                be discarded).
        """
        if not self._update_stats:
            # Fixed error message: the guard requires `update_stats=True`
            # (the previous message said "False", contradicting the condition).
            raise ValueError(
                f"State of {type(self).__name__} can only be changed when "
                f"`update_stats` was set to True."
            )
        self._init_new_filters()

    @override(ConnectorV2)
    def merge_states(self, states: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Folds the given per-EnvRunner filter states into self's filters.

        All given states must share identical settings (shape, demean/destd,
        clipping); only their running stats may differ.
        """
        if self._filters is None:
            self._init_new_filters()

        # Use the first agent state of the first incoming state as the
        # settings reference that all other states must match.
        # NOTE(review): `shape` values may be np.arrays, making `==` elementwise;
        # confirm this assert behaves for multi-dim shapes.
        ref = next(iter(states[0].values()))

        for state in states:
            for agent_id, agent_state in state.items():
                assert (
                    agent_state["shape"] == ref["shape"]
                    and agent_state["de_mean_to_zero"] == ref["de_mean_to_zero"]
                    and agent_state["de_std_to_one"] == ref["de_std_to_one"]
                    and agent_state["clip_by_value"] == ref["clip_by_value"]
                )

                _filter = _MeanStdFilter(
                    ref["shape"],
                    demean=ref["de_mean_to_zero"],
                    destd=ref["de_std_to_one"],
                    clip=ref["clip_by_value"],
                )
                # Load the incoming worker's running stats into the temp
                # filter's buffer, then fold them into our own stats below.
                _filter.buffer = tree.unflatten_as(
                    agent_state["shape"],
                    [
                        RunningStat.from_state(stats)
                        for stats in agent_state["running_stats"]
                    ],
                )

                # Leave the buffers as-is, since they should always only reflect
                # what has happened on the particular env runner.
                self._filters[agent_id].apply_changes(_filter, with_buffer=False)

        return MeanStdFilter._get_state_from_filters(self._filters)

    def _init_new_filters(self):
        """(Re)creates one fresh `_MeanStdFilter` per agent ID.

        Discrete/MultiDiscrete components get a `None` shape (left unfiltered);
        in the single-agent case, a single filter is stored under key `None`.
        """
        filter_shape = tree.map_structure(
            lambda s: (
                None if isinstance(s, (Discrete, MultiDiscrete)) else np.array(s.shape)
            ),
            get_base_struct_from_space(self.input_observation_space),
        )
        if not self._multi_agent:
            filter_shape = {None: filter_shape}

        self._filters = {
            agent_id: _MeanStdFilter(
                agent_filter_shape,
                demean=self.de_mean_to_zero,
                destd=self.de_std_to_one,
                clip=self.clip_by_value,
            )
            for agent_id, agent_filter_shape in filter_shape.items()
        }

    @staticmethod
    def _get_state_from_filters(filters: Dict[AgentID, _MeanStdFilter]):
        """Serializes the given filters into a per-agent state dict."""
        ret = {}
        for agent_id, agent_filter in filters.items():
            ret[agent_id] = {
                "shape": agent_filter.shape,
                "de_mean_to_zero": agent_filter.demean,
                "de_std_to_one": agent_filter.destd,
                "clip_by_value": agent_filter.clip,
                "running_stats": [
                    s.to_state() for s in tree.flatten(agent_filter.running_stats)
                ],
            }
        return ret
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/observation_preprocessor.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from typing import Any, Dict, List, Optional
|
| 3 |
+
|
| 4 |
+
import gymnasium as gym
|
| 5 |
+
|
| 6 |
+
from ray.rllib.connectors.connector_v2 import ConnectorV2
|
| 7 |
+
from ray.rllib.core.rl_module.rl_module import RLModule
|
| 8 |
+
from ray.rllib.utils.annotations import override
|
| 9 |
+
from ray.rllib.utils.typing import EpisodeType
|
| 10 |
+
from ray.util.annotations import PublicAPI
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@PublicAPI(stability="alpha")
class ObservationPreprocessor(ConnectorV2, abc.ABC):
    """Env-to-module connector performing one preprocessor step on the last observation.

    This is a convenience class that simplifies the writing of few-step preprocessor
    connectors.

    Users must implement the `preprocess()` method, which simplifies the usual procedure
    of extracting some data from a list of episodes and adding it to the batch to a mere
    "old-observation --transform--> return new-observation" step.
    """

    @override(ConnectorV2)
    def recompute_output_observation_space(
        self,
        input_observation_space: gym.Space,
        input_action_space: gym.Space,
    ) -> gym.Space:
        # Subclasses that change the pipeline's observation space should
        # override this and derive the new space from `input_observation_space`.
        # By default, defer to the base class (space passes through unchanged).
        return super().recompute_output_observation_space(
            input_observation_space, input_action_space
        )

    @abc.abstractmethod
    def preprocess(self, observation):
        """Override to implement the preprocessing logic.

        Args:
            observation: A single (non-batched) observation item for a single agent to
                be processed by this connector.

        Returns:
            The new observation after `observation` has been preprocessed.
        """

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Dict[str, Any],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        persistent_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        # Preprocess and replace the latest observation inside each episode
        # in-place, so every downstream connector only ever sees the processed
        # observation (the original one is no longer accessible).
        for sa_episode in self.single_agent_episode_iterator(episodes):
            raw_obs = sa_episode.get_observations(-1)
            processed_obs = self.preprocess(observation=raw_obs)

            # Write the result back into the episode, then align the episode's
            # observation space with ours to avoid a space-mismatch error when
            # setting the new value.
            sa_episode.set_observations(at_indices=-1, new_data=processed_obs)
            sa_episode.observation_space = self.observation_space

        # The batch is returned untouched; RLlib's default connector fills the
        # OBS column from the episodes' now-transformed observations.
        return batch
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/prev_actions_prev_rewards.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Optional
|
| 2 |
+
|
| 3 |
+
import gymnasium as gym
|
| 4 |
+
from gymnasium.spaces import Box
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from ray.rllib.connectors.connector_v2 import ConnectorV2
|
| 8 |
+
from ray.rllib.core.rl_module.rl_module import RLModule
|
| 9 |
+
from ray.rllib.utils.annotations import override
|
| 10 |
+
from ray.rllib.utils.spaces.space_utils import (
|
| 11 |
+
batch as batch_fn,
|
| 12 |
+
flatten_to_single_ndarray,
|
| 13 |
+
)
|
| 14 |
+
from ray.rllib.utils.typing import EpisodeType
|
| 15 |
+
from ray.util.annotations import PublicAPI
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@PublicAPI(stability="alpha")
|
| 19 |
+
class PrevActionsPrevRewards(ConnectorV2):
|
| 20 |
+
"""A connector piece that adds previous rewards and actions to the input obs.
|
| 21 |
+
|
| 22 |
+
- Requires Columns.OBS to be already a part of the batch.
|
| 23 |
+
- This connector makes the assumption that under the Columns.OBS key in batch,
|
| 24 |
+
there is either a list of individual env observations to be flattened (single-agent
|
| 25 |
+
case) or a dict mapping (AgentID, ModuleID)-tuples to lists of data items to be
|
| 26 |
+
flattened (multi-agent case).
|
| 27 |
+
- Converts Columns.OBS data into a dict (or creates a sub-dict if obs are
|
| 28 |
+
already a dict), and adds "prev_rewards" and "prev_actions"
|
| 29 |
+
to this dict. The original observations are stored under the self.ORIG_OBS_KEY in
|
| 30 |
+
that dict.
|
| 31 |
+
- If your RLModule does not handle dict inputs, you will have to plug in an
|
| 32 |
+
`FlattenObservations` connector piece after this one.
|
| 33 |
+
- Does NOT work in a Learner pipeline as it operates on individual observation
|
| 34 |
+
items (as opposed to batched/time-ranked data).
|
| 35 |
+
- Therefore, assumes that the altered (flattened) observations will be written
|
| 36 |
+
back into the episode by a later connector piece in the env-to-module pipeline
|
| 37 |
+
(which this piece is part of as well).
|
| 38 |
+
- Only reads reward- and action information from the given list of Episode objects.
|
| 39 |
+
- Does NOT write any observations (or other data) to the given Episode objects.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
ORIG_OBS_KEY = "_orig_obs"
|
| 43 |
+
PREV_ACTIONS_KEY = "prev_n_actions"
|
| 44 |
+
PREV_REWARDS_KEY = "prev_n_rewards"
|
| 45 |
+
|
| 46 |
+
@override(ConnectorV2)
|
| 47 |
+
def recompute_output_observation_space(
|
| 48 |
+
self,
|
| 49 |
+
input_observation_space: gym.Space,
|
| 50 |
+
input_action_space: gym.Space,
|
| 51 |
+
) -> gym.Space:
|
| 52 |
+
if self._multi_agent:
|
| 53 |
+
ret = {}
|
| 54 |
+
for agent_id, obs_space in input_observation_space.spaces.items():
|
| 55 |
+
act_space = input_action_space[agent_id]
|
| 56 |
+
ret[agent_id] = self._convert_individual_space(obs_space, act_space)
|
| 57 |
+
return gym.spaces.Dict(ret)
|
| 58 |
+
else:
|
| 59 |
+
return self._convert_individual_space(
|
| 60 |
+
input_observation_space, input_action_space
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
def __init__(
|
| 64 |
+
self,
|
| 65 |
+
input_observation_space: Optional[gym.Space] = None,
|
| 66 |
+
input_action_space: Optional[gym.Space] = None,
|
| 67 |
+
*,
|
| 68 |
+
multi_agent: bool = False,
|
| 69 |
+
n_prev_actions: int = 1,
|
| 70 |
+
n_prev_rewards: int = 1,
|
| 71 |
+
**kwargs,
|
| 72 |
+
):
|
| 73 |
+
"""Initializes a PrevActionsPrevRewards instance.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
multi_agent: Whether this is a connector operating on a multi-agent
|
| 77 |
+
observation space mapping AgentIDs to individual agents' observations.
|
| 78 |
+
n_prev_actions: The number of previous actions to include in the output
|
| 79 |
+
data. Discrete actions are ont-hot'd. If > 1, will concatenate the
|
| 80 |
+
individual action tensors.
|
| 81 |
+
n_prev_rewards: The number of previous rewards to include in the output
|
| 82 |
+
data.
|
| 83 |
+
"""
|
| 84 |
+
super().__init__(
|
| 85 |
+
input_observation_space=input_observation_space,
|
| 86 |
+
input_action_space=input_action_space,
|
| 87 |
+
**kwargs,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
self._multi_agent = multi_agent
|
| 91 |
+
self.n_prev_actions = n_prev_actions
|
| 92 |
+
self.n_prev_rewards = n_prev_rewards
|
| 93 |
+
|
| 94 |
+
# TODO: Move into input_observation_space setter
|
| 95 |
+
# Thus far, this connector piece only operates on discrete action spaces.
|
| 96 |
+
# act_spaces = [self.input_action_space]
|
| 97 |
+
# if self._multi_agent:
|
| 98 |
+
# act_spaces = self.input_action_space.spaces.values()
|
| 99 |
+
# if not all(isinstance(s, gym.spaces.Discrete) for s in act_spaces):
|
| 100 |
+
# raise ValueError(
|
| 101 |
+
# f"{type(self).__name__} only works on Discrete action spaces "
|
| 102 |
+
# f"thus far (or, for multi-agent, on Dict spaces mapping AgentIDs to "
|
| 103 |
+
# f"the individual agents' Discrete action spaces)!"
|
| 104 |
+
# )
|
| 105 |
+
|
| 106 |
+
@override(ConnectorV2)
|
| 107 |
+
def __call__(
|
| 108 |
+
self,
|
| 109 |
+
*,
|
| 110 |
+
rl_module: RLModule,
|
| 111 |
+
batch: Optional[Dict[str, Any]],
|
| 112 |
+
episodes: List[EpisodeType],
|
| 113 |
+
explore: Optional[bool] = None,
|
| 114 |
+
shared_data: Optional[dict] = None,
|
| 115 |
+
**kwargs,
|
| 116 |
+
) -> Any:
|
| 117 |
+
for sa_episode in self.single_agent_episode_iterator(
|
| 118 |
+
episodes, agents_that_stepped_only=True
|
| 119 |
+
):
|
| 120 |
+
# Episode is not numpy'ized yet and thus still operates on lists of items.
|
| 121 |
+
assert not sa_episode.is_numpy
|
| 122 |
+
|
| 123 |
+
augmented_obs = {self.ORIG_OBS_KEY: sa_episode.get_observations(-1)}
|
| 124 |
+
|
| 125 |
+
if self.n_prev_actions:
|
| 126 |
+
augmented_obs[self.PREV_ACTIONS_KEY] = flatten_to_single_ndarray(
|
| 127 |
+
batch_fn(
|
| 128 |
+
sa_episode.get_actions(
|
| 129 |
+
indices=slice(-self.n_prev_actions, None),
|
| 130 |
+
fill=0.0,
|
| 131 |
+
one_hot_discrete=True,
|
| 132 |
+
)
|
| 133 |
+
)
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
if self.n_prev_rewards:
|
| 137 |
+
augmented_obs[self.PREV_REWARDS_KEY] = np.array(
|
| 138 |
+
sa_episode.get_rewards(
|
| 139 |
+
indices=slice(-self.n_prev_rewards, None),
|
| 140 |
+
fill=0.0,
|
| 141 |
+
)
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
# Write new observation directly back into the episode.
|
| 145 |
+
sa_episode.set_observations(at_indices=-1, new_data=augmented_obs)
|
| 146 |
+
# We set the Episode's observation space to ours so that we can safely
|
| 147 |
+
# set the last obs to the new value (without causing a space mismatch
|
| 148 |
+
# error).
|
| 149 |
+
sa_episode.observation_space = self.observation_space
|
| 150 |
+
|
| 151 |
+
return batch
|
| 152 |
+
|
| 153 |
+
def _convert_individual_space(self, obs_space, act_space):
|
| 154 |
+
return gym.spaces.Dict(
|
| 155 |
+
{
|
| 156 |
+
self.ORIG_OBS_KEY: obs_space,
|
| 157 |
+
# Currently only works for Discrete action spaces.
|
| 158 |
+
self.PREV_ACTIONS_KEY: Box(
|
| 159 |
+
0.0, 1.0, (act_space.n * self.n_prev_actions,), np.float32
|
| 160 |
+
),
|
| 161 |
+
self.PREV_REWARDS_KEY: Box(
|
| 162 |
+
float("-inf"),
|
| 163 |
+
float("inf"),
|
| 164 |
+
(self.n_prev_rewards,),
|
| 165 |
+
np.float32,
|
| 166 |
+
),
|
| 167 |
+
}
|
| 168 |
+
)
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/write_observations_to_episodes.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Optional
|
| 2 |
+
|
| 3 |
+
from ray.rllib.connectors.connector_v2 import ConnectorV2
|
| 4 |
+
from ray.rllib.core.columns import Columns
|
| 5 |
+
from ray.rllib.core.rl_module.rl_module import RLModule
|
| 6 |
+
from ray.rllib.utils.annotations import override
|
| 7 |
+
from ray.rllib.utils.typing import EpisodeType
|
| 8 |
+
from ray.util.annotations import PublicAPI
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@PublicAPI(stability="alpha")
|
| 12 |
+
class WriteObservationsToEpisodes(ConnectorV2):
|
| 13 |
+
"""Writes the observations from the batch into the running episodes.
|
| 14 |
+
|
| 15 |
+
Note: This is one of the default env-to-module ConnectorV2 pieces that are added
|
| 16 |
+
automatically by RLlib into every env-to-module connector pipelines, unless
|
| 17 |
+
`config.add_default_connectors_to_env_to_module_pipeline` is set to False.
|
| 18 |
+
|
| 19 |
+
The default env-to-module connector pipeline is:
|
| 20 |
+
[
|
| 21 |
+
[0 or more user defined ConnectorV2 pieces],
|
| 22 |
+
AddObservationsFromEpisodesToBatch,
|
| 23 |
+
AddStatesFromEpisodesToBatch,
|
| 24 |
+
AgentToModuleMapping, # only in multi-agent setups!
|
| 25 |
+
BatchIndividualItems,
|
| 26 |
+
NumpyToTensor,
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
This ConnectorV2:
|
| 30 |
+
- Operates on a batch that already has observations in it and a list of Episode
|
| 31 |
+
objects.
|
| 32 |
+
- Writes the observation(s) from the batch to all the given episodes. Thereby
|
| 33 |
+
the number of observations in the batch must match the length of the list of
|
| 34 |
+
episodes given.
|
| 35 |
+
- Does NOT alter any observations (or other data) in the batch.
|
| 36 |
+
- Can only be used in an EnvToModule pipeline (writing into Episode objects in a
|
| 37 |
+
Learner pipeline does not make a lot of sense as - after the learner update - the
|
| 38 |
+
list of episodes is discarded).
|
| 39 |
+
|
| 40 |
+
.. testcode::
|
| 41 |
+
|
| 42 |
+
import gymnasium as gym
|
| 43 |
+
import numpy as np
|
| 44 |
+
|
| 45 |
+
from ray.rllib.connectors.env_to_module import WriteObservationsToEpisodes
|
| 46 |
+
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
|
| 47 |
+
from ray.rllib.utils.test_utils import check
|
| 48 |
+
|
| 49 |
+
# Assume we have two episodes (vectorized), then our forward batch will carry
|
| 50 |
+
# two observation records (batch size = 2).
|
| 51 |
+
# The connector in this example will write these two (possibly transformed)
|
| 52 |
+
# observations back into the two respective SingleAgentEpisode objects.
|
| 53 |
+
batch = {
|
| 54 |
+
"obs": [np.array([0.0, 1.0], np.float32), np.array([2.0, 3.0], np.float32)],
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
# Our two episodes have one observation each (i.e. the reset one). This is the
|
| 58 |
+
# one that will be overwritten by the connector in this example.
|
| 59 |
+
obs_space = gym.spaces.Box(-10.0, 10.0, (2,), np.float32)
|
| 60 |
+
act_space = gym.spaces.Discrete(2)
|
| 61 |
+
episodes = [
|
| 62 |
+
SingleAgentEpisode(
|
| 63 |
+
observation_space=obs_space,
|
| 64 |
+
observations=[np.array([-10, -20], np.float32)],
|
| 65 |
+
len_lookback_buffer=0,
|
| 66 |
+
) for _ in range(2)
|
| 67 |
+
]
|
| 68 |
+
# Make sure everything is setup correctly.
|
| 69 |
+
check(episodes[0].get_observations(0), [-10.0, -20.0])
|
| 70 |
+
check(episodes[1].get_observations(-1), [-10.0, -20.0])
|
| 71 |
+
|
| 72 |
+
# Create our connector piece.
|
| 73 |
+
connector = WriteObservationsToEpisodes(obs_space, act_space)
|
| 74 |
+
|
| 75 |
+
# Call the connector (and thereby write the transformed observations back
|
| 76 |
+
# into the episodes).
|
| 77 |
+
output_batch = connector(
|
| 78 |
+
rl_module=None, # This particular connector works without an RLModule.
|
| 79 |
+
batch=batch,
|
| 80 |
+
episodes=episodes,
|
| 81 |
+
explore=True,
|
| 82 |
+
shared_data={},
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# The connector does NOT change the data batch being passed through.
|
| 86 |
+
check(output_batch, batch)
|
| 87 |
+
|
| 88 |
+
# However, the connector has overwritten the last observations in the episodes.
|
| 89 |
+
check(episodes[0].get_observations(-1), [0.0, 1.0])
|
| 90 |
+
check(episodes[1].get_observations(0), [2.0, 3.0])
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
@override(ConnectorV2)
|
| 94 |
+
def __call__(
|
| 95 |
+
self,
|
| 96 |
+
*,
|
| 97 |
+
rl_module: RLModule,
|
| 98 |
+
batch: Optional[Dict[str, Any]],
|
| 99 |
+
episodes: List[EpisodeType],
|
| 100 |
+
explore: Optional[bool] = None,
|
| 101 |
+
shared_data: Optional[dict] = None,
|
| 102 |
+
**kwargs,
|
| 103 |
+
) -> Any:
|
| 104 |
+
observations = batch.get(Columns.OBS)
|
| 105 |
+
|
| 106 |
+
if observations is None:
|
| 107 |
+
raise ValueError(
|
| 108 |
+
f"`batch` must already have a column named {Columns.OBS} in it "
|
| 109 |
+
f"for this connector to work!"
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# Note that the following loop works with multi-agent as well as with
|
| 113 |
+
# single-agent episode, as long as the following conditions are met (these
|
| 114 |
+
# will be validated by `self.single_agent_episode_iterator()`):
|
| 115 |
+
# - Per single agent episode, one observation item is expected to exist in
|
| 116 |
+
# `data`, either in a list directly under the "obs" key OR for multi-agent:
|
| 117 |
+
# in a list sitting under a key `(agent_id, module_id)` of a dict sitting
|
| 118 |
+
# under the "obs" key.
|
| 119 |
+
for sa_episode, obs in self.single_agent_episode_iterator(
|
| 120 |
+
episodes=episodes, zip_with_batch_column=observations
|
| 121 |
+
):
|
| 122 |
+
# Make sure episodes are NOT numpy'ized yet (we are expecting to run in an
|
| 123 |
+
# env-to-module pipeline).
|
| 124 |
+
assert not sa_episode.is_numpy
|
| 125 |
+
# Write new information into the episode.
|
| 126 |
+
sa_episode.set_observations(at_indices=-1, new_data=obs)
|
| 127 |
+
# Change the observation space of the sa_episode.
|
| 128 |
+
sa_episode.observation_space = self.observation_space
|
| 129 |
+
|
| 130 |
+
# Return the unchanged `batch`.
|
| 131 |
+
return batch
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.connectors.common.tensor_to_numpy import TensorToNumpy
|
| 2 |
+
from ray.rllib.connectors.common.module_to_agent_unmapping import ModuleToAgentUnmapping
|
| 3 |
+
from ray.rllib.connectors.module_to_env.get_actions import GetActions
|
| 4 |
+
from ray.rllib.connectors.module_to_env.listify_data_for_vector_env import (
|
| 5 |
+
ListifyDataForVectorEnv,
|
| 6 |
+
)
|
| 7 |
+
from ray.rllib.connectors.module_to_env.module_to_env_pipeline import (
|
| 8 |
+
ModuleToEnvPipeline,
|
| 9 |
+
)
|
| 10 |
+
from ray.rllib.connectors.module_to_env.normalize_and_clip_actions import (
|
| 11 |
+
NormalizeAndClipActions,
|
| 12 |
+
)
|
| 13 |
+
from ray.rllib.connectors.module_to_env.remove_single_ts_time_rank_from_batch import (
|
| 14 |
+
RemoveSingleTsTimeRankFromBatch,
|
| 15 |
+
)
|
| 16 |
+
from ray.rllib.connectors.module_to_env.unbatch_to_individual_items import (
|
| 17 |
+
UnBatchToIndividualItems,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
__all__ = [
|
| 22 |
+
"GetActions",
|
| 23 |
+
"ListifyDataForVectorEnv",
|
| 24 |
+
"ModuleToAgentUnmapping",
|
| 25 |
+
"ModuleToEnvPipeline",
|
| 26 |
+
"NormalizeAndClipActions",
|
| 27 |
+
"RemoveSingleTsTimeRankFromBatch",
|
| 28 |
+
"TensorToNumpy",
|
| 29 |
+
"UnBatchToIndividualItems",
|
| 30 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/listify_data_for_vector_env.cpython-311.pyc
ADDED
|
Binary file (4.19 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/get_actions.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Optional
|
| 2 |
+
|
| 3 |
+
from ray.rllib.connectors.connector_v2 import ConnectorV2
|
| 4 |
+
from ray.rllib.core.columns import Columns
|
| 5 |
+
from ray.rllib.core.rl_module.rl_module import RLModule
|
| 6 |
+
from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
|
| 7 |
+
from ray.rllib.utils.annotations import override
|
| 8 |
+
from ray.rllib.utils.typing import EpisodeType
|
| 9 |
+
from ray.util.annotations import PublicAPI
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@PublicAPI(stability="alpha")
|
| 13 |
+
class GetActions(ConnectorV2):
|
| 14 |
+
"""Connector piece sampling actions from ACTION_DIST_INPUTS from an RLModule.
|
| 15 |
+
|
| 16 |
+
Note: This is one of the default module-to-env ConnectorV2 pieces that
|
| 17 |
+
are added automatically by RLlib into every module-to-env connector pipeline,
|
| 18 |
+
unless `config.add_default_connectors_to_module_to_env_pipeline` is set to
|
| 19 |
+
False.
|
| 20 |
+
|
| 21 |
+
The default module-to-env connector pipeline is:
|
| 22 |
+
[
|
| 23 |
+
GetActions,
|
| 24 |
+
TensorToNumpy,
|
| 25 |
+
UnBatchToIndividualItems,
|
| 26 |
+
ModuleToAgentUnmapping, # only in multi-agent setups!
|
| 27 |
+
RemoveSingleTsTimeRankFromBatch,
|
| 28 |
+
|
| 29 |
+
[0 or more user defined ConnectorV2 pieces],
|
| 30 |
+
|
| 31 |
+
NormalizeAndClipActions,
|
| 32 |
+
ListifyDataForVectorEnv,
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
If necessary, this connector samples actions, given action dist. inputs and a
|
| 36 |
+
dist. class.
|
| 37 |
+
The connector will only sample from the action distribution, if the
|
| 38 |
+
Columns.ACTIONS key cannot be found in `data`. Otherwise, it'll behave
|
| 39 |
+
as pass-through. If Columns.ACTIONS is NOT present in `data`, but
|
| 40 |
+
Columns.ACTION_DIST_INPUTS is, this connector will create a new action
|
| 41 |
+
distribution using the given RLModule and sample from its distribution class
|
| 42 |
+
(deterministically, if we are not exploring, stochastically, if we are).
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
@override(ConnectorV2)
|
| 46 |
+
def __call__(
|
| 47 |
+
self,
|
| 48 |
+
*,
|
| 49 |
+
rl_module: RLModule,
|
| 50 |
+
batch: Dict[str, Any],
|
| 51 |
+
episodes: List[EpisodeType],
|
| 52 |
+
explore: Optional[bool] = None,
|
| 53 |
+
shared_data: Optional[dict] = None,
|
| 54 |
+
**kwargs,
|
| 55 |
+
) -> Any:
|
| 56 |
+
is_multi_agent = isinstance(episodes[0], MultiAgentEpisode)
|
| 57 |
+
|
| 58 |
+
if is_multi_agent:
|
| 59 |
+
for module_id, module_data in batch.copy().items():
|
| 60 |
+
self._get_actions(module_data, rl_module[module_id], explore)
|
| 61 |
+
else:
|
| 62 |
+
self._get_actions(batch, rl_module, explore)
|
| 63 |
+
|
| 64 |
+
return batch
|
| 65 |
+
|
| 66 |
+
def _get_actions(self, batch, sa_rl_module, explore):
|
| 67 |
+
# Action have already been sampled -> Early out.
|
| 68 |
+
if Columns.ACTIONS in batch:
|
| 69 |
+
return
|
| 70 |
+
|
| 71 |
+
# ACTION_DIST_INPUTS field returned by `forward_exploration|inference()` ->
|
| 72 |
+
# Create a new action distribution object.
|
| 73 |
+
if Columns.ACTION_DIST_INPUTS in batch:
|
| 74 |
+
if explore:
|
| 75 |
+
action_dist_class = sa_rl_module.get_exploration_action_dist_cls()
|
| 76 |
+
else:
|
| 77 |
+
action_dist_class = sa_rl_module.get_inference_action_dist_cls()
|
| 78 |
+
action_dist = action_dist_class.from_logits(
|
| 79 |
+
batch[Columns.ACTION_DIST_INPUTS],
|
| 80 |
+
)
|
| 81 |
+
if not explore:
|
| 82 |
+
action_dist = action_dist.to_deterministic()
|
| 83 |
+
|
| 84 |
+
# Sample actions from the distribution.
|
| 85 |
+
actions = action_dist.sample()
|
| 86 |
+
batch[Columns.ACTIONS] = actions
|
| 87 |
+
|
| 88 |
+
# For convenience and if possible, compute action logp from distribution
|
| 89 |
+
# and add to output.
|
| 90 |
+
if Columns.ACTION_LOGP not in batch:
|
| 91 |
+
batch[Columns.ACTION_LOGP] = action_dist.logp(actions)
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/listify_data_for_vector_env.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Optional
|
| 2 |
+
|
| 3 |
+
from ray.rllib.connectors.connector_v2 import ConnectorV2
|
| 4 |
+
from ray.rllib.core.columns import Columns
|
| 5 |
+
from ray.rllib.core.rl_module.rl_module import RLModule
|
| 6 |
+
from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
|
| 7 |
+
from ray.rllib.utils.annotations import override
|
| 8 |
+
from ray.rllib.utils.spaces.space_utils import batch as batch_fn
|
| 9 |
+
from ray.rllib.utils.typing import EpisodeType
|
| 10 |
+
from ray.util.annotations import PublicAPI
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@PublicAPI(stability="alpha")
|
| 14 |
+
class ListifyDataForVectorEnv(ConnectorV2):
|
| 15 |
+
"""Performs conversion from ConnectorV2-style format to env/episode insertion.
|
| 16 |
+
|
| 17 |
+
Note: This is one of the default module-to-env ConnectorV2 pieces that
|
| 18 |
+
are added automatically by RLlib into every module-to-env connector pipeline,
|
| 19 |
+
unless `config.add_default_connectors_to_module_to_env_pipeline` is set to
|
| 20 |
+
False.
|
| 21 |
+
|
| 22 |
+
The default module-to-env connector pipeline is:
|
| 23 |
+
[
|
| 24 |
+
GetActions,
|
| 25 |
+
TensorToNumpy,
|
| 26 |
+
UnBatchToIndividualItems,
|
| 27 |
+
ModuleToAgentUnmapping, # only in multi-agent setups!
|
| 28 |
+
RemoveSingleTsTimeRankFromBatch,
|
| 29 |
+
|
| 30 |
+
[0 or more user defined ConnectorV2 pieces],
|
| 31 |
+
|
| 32 |
+
NormalizeAndClipActions,
|
| 33 |
+
ListifyDataForVectorEnv,
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
Single agent case:
|
| 37 |
+
Convert from:
|
| 38 |
+
[col] -> [(episode_id,)] -> [list of items].
|
| 39 |
+
To:
|
| 40 |
+
[col] -> [list of items].
|
| 41 |
+
|
| 42 |
+
Multi-agent case:
|
| 43 |
+
Convert from:
|
| 44 |
+
[col] -> [(episode_id, agent_id, module_id)] -> list of items.
|
| 45 |
+
To:
|
| 46 |
+
[col] -> [list of multi-agent dicts].
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
@override(ConnectorV2)
|
| 50 |
+
def __call__(
|
| 51 |
+
self,
|
| 52 |
+
*,
|
| 53 |
+
rl_module: RLModule,
|
| 54 |
+
batch: Dict[str, Any],
|
| 55 |
+
episodes: List[EpisodeType],
|
| 56 |
+
explore: Optional[bool] = None,
|
| 57 |
+
shared_data: Optional[dict] = None,
|
| 58 |
+
**kwargs,
|
| 59 |
+
) -> Any:
|
| 60 |
+
for column, column_data in batch.copy().items():
|
| 61 |
+
# Multi-agent case: Create lists of multi-agent dicts under each column.
|
| 62 |
+
if isinstance(episodes[0], MultiAgentEpisode):
|
| 63 |
+
# TODO (sven): Support vectorized MultiAgentEnv
|
| 64 |
+
assert len(episodes) == 1
|
| 65 |
+
new_column_data = [{}]
|
| 66 |
+
|
| 67 |
+
for key, value in batch[column].items():
|
| 68 |
+
assert len(value) == 1
|
| 69 |
+
eps_id, agent_id, module_id = key
|
| 70 |
+
new_column_data[0][agent_id] = value[0]
|
| 71 |
+
batch[column] = new_column_data
|
| 72 |
+
# Single-agent case: Create simple lists under each column.
|
| 73 |
+
else:
|
| 74 |
+
batch[column] = [
|
| 75 |
+
d for key in batch[column].keys() for d in batch[column][key]
|
| 76 |
+
]
|
| 77 |
+
# Batch actions for (single-agent) gym.vector.Env.
|
| 78 |
+
# All other columns, leave listify'ed.
|
| 79 |
+
if column in [Columns.ACTIONS_FOR_ENV, Columns.ACTIONS]:
|
| 80 |
+
batch[column] = batch_fn(batch[column])
|
| 81 |
+
|
| 82 |
+
return batch
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/module_to_env_pipeline.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.connectors.connector_pipeline_v2 import ConnectorPipelineV2
|
| 2 |
+
from ray.util.annotations import PublicAPI
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@PublicAPI(stability="alpha")
|
| 6 |
+
class ModuleToEnvPipeline(ConnectorPipelineV2):
|
| 7 |
+
pass
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/normalize_and_clip_actions.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from typing import Any, Dict, List, Optional
|
| 3 |
+
|
| 4 |
+
import gymnasium as gym
|
| 5 |
+
|
| 6 |
+
from ray.rllib.connectors.connector_v2 import ConnectorV2
|
| 7 |
+
from ray.rllib.core.columns import Columns
|
| 8 |
+
from ray.rllib.core.rl_module.rl_module import RLModule
|
| 9 |
+
from ray.rllib.utils.annotations import override
|
| 10 |
+
from ray.rllib.utils.spaces.space_utils import (
|
| 11 |
+
clip_action,
|
| 12 |
+
get_base_struct_from_space,
|
| 13 |
+
unsquash_action,
|
| 14 |
+
)
|
| 15 |
+
from ray.rllib.utils.typing import EpisodeType
|
| 16 |
+
from ray.util.annotations import PublicAPI
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@PublicAPI(stability="alpha")
|
| 20 |
+
class NormalizeAndClipActions(ConnectorV2):
|
| 21 |
+
"""Normalizes or clips actions in the input data (coming from the RLModule).
|
| 22 |
+
|
| 23 |
+
Note: This is one of the default module-to-env ConnectorV2 pieces that
|
| 24 |
+
are added automatically by RLlib into every module-to-env connector pipeline,
|
| 25 |
+
unless `config.add_default_connectors_to_module_to_env_pipeline` is set to
|
| 26 |
+
False.
|
| 27 |
+
|
| 28 |
+
The default module-to-env connector pipeline is:
|
| 29 |
+
[
|
| 30 |
+
GetActions,
|
| 31 |
+
TensorToNumpy,
|
| 32 |
+
UnBatchToIndividualItems,
|
| 33 |
+
ModuleToAgentUnmapping, # only in multi-agent setups!
|
| 34 |
+
RemoveSingleTsTimeRankFromBatch,
|
| 35 |
+
|
| 36 |
+
[0 or more user defined ConnectorV2 pieces],
|
| 37 |
+
|
| 38 |
+
NormalizeAndClipActions,
|
| 39 |
+
ListifyDataForVectorEnv,
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
This ConnectorV2:
|
| 43 |
+
- Deep copies the Columns.ACTIONS in the incoming `data` into a new column:
|
| 44 |
+
Columns.ACTIONS_FOR_ENV.
|
| 45 |
+
- Loops through the Columns.ACTIONS in the incoming `data` and normalizes or clips
|
| 46 |
+
these depending on the c'tor settings in `config.normalize_actions` and
|
| 47 |
+
`config.clip_actions`.
|
| 48 |
+
- Only applies to envs with Box action spaces.
|
| 49 |
+
|
| 50 |
+
Normalizing is the process of mapping NN-outputs (which are usually small
|
| 51 |
+
numbers, e.g. between -1.0 and 1.0) to the bounds defined by the action-space.
|
| 52 |
+
Normalizing helps the NN to learn faster in environments with large ranges between
|
| 53 |
+
`low` and `high` bounds or skewed action bounds (e.g. Box(-3000.0, 1.0, ...)).
|
| 54 |
+
|
| 55 |
+
Clipping clips the actions computed by the NN (and sampled from a distribution)
|
| 56 |
+
between the bounds defined by the action-space. Note that clipping is only performed
|
| 57 |
+
if `normalize_actions` is False.
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
@override(ConnectorV2)
|
| 61 |
+
def recompute_output_action_space(
|
| 62 |
+
self,
|
| 63 |
+
input_observation_space: gym.Space,
|
| 64 |
+
input_action_space: gym.Space,
|
| 65 |
+
) -> gym.Space:
|
| 66 |
+
self._action_space_struct = get_base_struct_from_space(input_action_space)
|
| 67 |
+
return input_action_space
|
| 68 |
+
|
| 69 |
+
def __init__(
|
| 70 |
+
self,
|
| 71 |
+
input_observation_space: Optional[gym.Space] = None,
|
| 72 |
+
input_action_space: Optional[gym.Space] = None,
|
| 73 |
+
*,
|
| 74 |
+
normalize_actions: bool,
|
| 75 |
+
clip_actions: bool,
|
| 76 |
+
**kwargs,
|
| 77 |
+
):
|
| 78 |
+
"""Initializes a DefaultModuleToEnv (connector piece) instance.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
normalize_actions: If True, actions coming from the RLModule's distribution
|
| 82 |
+
(or are directly computed by the RLModule w/o sampling) will
|
| 83 |
+
be assumed 0.0 centered with a small stddev (only affecting Box
|
| 84 |
+
components) and thus be unsquashed (and clipped, just in case) to the
|
| 85 |
+
bounds of the env's action space. For example, if the action space of
|
| 86 |
+
the environment is `Box(-2.0, -0.5, (1,))`, the model outputs
|
| 87 |
+
mean and stddev as 0.1 and exp(0.2), and we sample an action of 0.9
|
| 88 |
+
from the resulting distribution, then this 0.9 will be unsquashed into
|
| 89 |
+
the [-2.0 -0.5] interval. If - after unsquashing - the action still
|
| 90 |
+
breaches the action space, it will simply be clipped.
|
| 91 |
+
clip_actions: If True, actions coming from the RLModule's distribution
|
| 92 |
+
(or are directly computed by the RLModule w/o sampling) will be clipped
|
| 93 |
+
such that they fit into the env's action space's bounds.
|
| 94 |
+
For example, if the action space of the environment is
|
| 95 |
+
`Box(-0.5, 0.5, (1,))`, the model outputs
|
| 96 |
+
mean and stddev as 0.1 and exp(0.2), and we sample an action of 0.9
|
| 97 |
+
from the resulting distribution, then this 0.9 will be clipped to 0.5
|
| 98 |
+
to fit into the [-0.5 0.5] interval.
|
| 99 |
+
"""
|
| 100 |
+
self._action_space_struct = None
|
| 101 |
+
|
| 102 |
+
super().__init__(input_observation_space, input_action_space, **kwargs)
|
| 103 |
+
|
| 104 |
+
self.normalize_actions = normalize_actions
|
| 105 |
+
self.clip_actions = clip_actions
|
| 106 |
+
|
| 107 |
+
@override(ConnectorV2)
|
| 108 |
+
def __call__(
|
| 109 |
+
self,
|
| 110 |
+
*,
|
| 111 |
+
rl_module: RLModule,
|
| 112 |
+
batch: Optional[Dict[str, Any]],
|
| 113 |
+
episodes: List[EpisodeType],
|
| 114 |
+
explore: Optional[bool] = None,
|
| 115 |
+
shared_data: Optional[dict] = None,
|
| 116 |
+
**kwargs,
|
| 117 |
+
) -> Any:
|
| 118 |
+
"""Based on settings, will normalize (unsquash) and/or clip computed actions.
|
| 119 |
+
|
| 120 |
+
This is such that the final actions (to be sent to the env) match the
|
| 121 |
+
environment's action space and thus don't lead to an error.
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
def _unsquash_or_clip(action_for_env, env_id, agent_id, module_id):
|
| 125 |
+
if agent_id is not None:
|
| 126 |
+
struct = self._action_space_struct[agent_id]
|
| 127 |
+
else:
|
| 128 |
+
struct = self._action_space_struct
|
| 129 |
+
|
| 130 |
+
if self.normalize_actions:
|
| 131 |
+
return unsquash_action(action_for_env, struct)
|
| 132 |
+
else:
|
| 133 |
+
return clip_action(action_for_env, struct)
|
| 134 |
+
|
| 135 |
+
# Normalize or clip this new actions_for_env column, leaving the originally
|
| 136 |
+
# computed/sampled actions intact.
|
| 137 |
+
if self.normalize_actions or self.clip_actions:
|
| 138 |
+
# Copy actions into separate column, just to go to the env.
|
| 139 |
+
batch[Columns.ACTIONS_FOR_ENV] = copy.deepcopy(batch[Columns.ACTIONS])
|
| 140 |
+
self.foreach_batch_item_change_in_place(
|
| 141 |
+
batch=batch,
|
| 142 |
+
column=Columns.ACTIONS_FOR_ENV,
|
| 143 |
+
func=_unsquash_or_clip,
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
return batch
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/remove_single_ts_time_rank_from_batch.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Optional
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import tree # pip install dm_tree
|
| 5 |
+
|
| 6 |
+
from ray.rllib.connectors.connector_v2 import ConnectorV2
|
| 7 |
+
from ray.rllib.core.columns import Columns
|
| 8 |
+
from ray.rllib.core.rl_module.rl_module import RLModule
|
| 9 |
+
from ray.rllib.utils.annotations import override
|
| 10 |
+
from ray.rllib.utils.typing import EpisodeType
|
| 11 |
+
from ray.util.annotations import PublicAPI
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@PublicAPI(stability="alpha")
|
| 15 |
+
class RemoveSingleTsTimeRankFromBatch(ConnectorV2):
|
| 16 |
+
"""
|
| 17 |
+
Note: This is one of the default module-to-env ConnectorV2 pieces that
|
| 18 |
+
are added automatically by RLlib into every module-to-env connector pipeline,
|
| 19 |
+
unless `config.add_default_connectors_to_module_to_env_pipeline` is set to
|
| 20 |
+
False.
|
| 21 |
+
|
| 22 |
+
The default module-to-env connector pipeline is:
|
| 23 |
+
[
|
| 24 |
+
GetActions,
|
| 25 |
+
TensorToNumpy,
|
| 26 |
+
UnBatchToIndividualItems,
|
| 27 |
+
ModuleToAgentUnmapping, # only in multi-agent setups!
|
| 28 |
+
RemoveSingleTsTimeRankFromBatch,
|
| 29 |
+
|
| 30 |
+
[0 or more user defined ConnectorV2 pieces],
|
| 31 |
+
|
| 32 |
+
NormalizeAndClipActions,
|
| 33 |
+
ListifyDataForVectorEnv,
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
@override(ConnectorV2)
|
| 39 |
+
def __call__(
|
| 40 |
+
self,
|
| 41 |
+
*,
|
| 42 |
+
rl_module: RLModule,
|
| 43 |
+
batch: Optional[Dict[str, Any]],
|
| 44 |
+
episodes: List[EpisodeType],
|
| 45 |
+
explore: Optional[bool] = None,
|
| 46 |
+
shared_data: Optional[dict] = None,
|
| 47 |
+
**kwargs,
|
| 48 |
+
) -> Any:
|
| 49 |
+
# If single ts time-rank had not been added, early out.
|
| 50 |
+
if shared_data is None or not shared_data.get("_added_single_ts_time_rank"):
|
| 51 |
+
return batch
|
| 52 |
+
|
| 53 |
+
def _remove_single_ts(item, eps_id, aid, mid):
|
| 54 |
+
# Only remove time-rank for modules that are statefule (only for those has
|
| 55 |
+
# a timerank been added).
|
| 56 |
+
if mid is None or rl_module[mid].is_stateful():
|
| 57 |
+
return tree.map_structure(lambda s: np.squeeze(s, axis=0), item)
|
| 58 |
+
return item
|
| 59 |
+
|
| 60 |
+
for column, column_data in batch.copy().items():
|
| 61 |
+
# Skip state_out (doesn't have a time rank).
|
| 62 |
+
if column == Columns.STATE_OUT:
|
| 63 |
+
continue
|
| 64 |
+
self.foreach_batch_item_change_in_place(
|
| 65 |
+
batch,
|
| 66 |
+
column=column,
|
| 67 |
+
func=_remove_single_ts,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
return batch
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/unbatch_to_individual_items.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
from typing import Any, Dict, List, Optional
|
| 3 |
+
|
| 4 |
+
import tree # pip install dm_tree
|
| 5 |
+
|
| 6 |
+
from ray.rllib.connectors.connector_v2 import ConnectorV2
|
| 7 |
+
from ray.rllib.core.rl_module.rl_module import RLModule
|
| 8 |
+
from ray.rllib.utils.annotations import override
|
| 9 |
+
from ray.rllib.utils.spaces.space_utils import unbatch as unbatch_fn
|
| 10 |
+
from ray.rllib.utils.typing import EpisodeType
|
| 11 |
+
from ray.util.annotations import PublicAPI
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@PublicAPI(stability="alpha")
|
| 15 |
+
class UnBatchToIndividualItems(ConnectorV2):
|
| 16 |
+
"""Unbatches the given `data` back into the individual-batch-items format.
|
| 17 |
+
|
| 18 |
+
Note: This is one of the default module-to-env ConnectorV2 pieces that
|
| 19 |
+
are added automatically by RLlib into every module-to-env connector pipeline,
|
| 20 |
+
unless `config.add_default_connectors_to_module_to_env_pipeline` is set to
|
| 21 |
+
False.
|
| 22 |
+
|
| 23 |
+
The default module-to-env connector pipeline is:
|
| 24 |
+
[
|
| 25 |
+
GetActions,
|
| 26 |
+
TensorToNumpy,
|
| 27 |
+
UnBatchToIndividualItems,
|
| 28 |
+
ModuleToAgentUnmapping, # only in multi-agent setups!
|
| 29 |
+
RemoveSingleTsTimeRankFromBatch,
|
| 30 |
+
|
| 31 |
+
[0 or more user defined ConnectorV2 pieces],
|
| 32 |
+
|
| 33 |
+
NormalizeAndClipActions,
|
| 34 |
+
ListifyDataForVectorEnv,
|
| 35 |
+
]
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
@override(ConnectorV2)
|
| 39 |
+
def __call__(
|
| 40 |
+
self,
|
| 41 |
+
*,
|
| 42 |
+
rl_module: RLModule,
|
| 43 |
+
batch: Dict[str, Any],
|
| 44 |
+
episodes: List[EpisodeType],
|
| 45 |
+
explore: Optional[bool] = None,
|
| 46 |
+
shared_data: Optional[dict] = None,
|
| 47 |
+
**kwargs,
|
| 48 |
+
) -> Any:
|
| 49 |
+
memorized_map_structure = shared_data.get("memorized_map_structure")
|
| 50 |
+
|
| 51 |
+
# Simple case (no structure stored): Just unbatch.
|
| 52 |
+
if memorized_map_structure is None:
|
| 53 |
+
return tree.map_structure(lambda s: unbatch_fn(s), batch)
|
| 54 |
+
# Single agent case: Memorized structure is a list, whose indices map to
|
| 55 |
+
# eps_id values.
|
| 56 |
+
elif isinstance(memorized_map_structure, list):
|
| 57 |
+
for column, column_data in batch.copy().items():
|
| 58 |
+
column_data = unbatch_fn(column_data)
|
| 59 |
+
new_column_data = defaultdict(list)
|
| 60 |
+
for i, eps_id in enumerate(memorized_map_structure):
|
| 61 |
+
# Keys are always tuples to resemble multi-agent keys, which
|
| 62 |
+
# have the structure (eps_id, agent_id, module_id).
|
| 63 |
+
key = (eps_id,)
|
| 64 |
+
new_column_data[key].append(column_data[i])
|
| 65 |
+
batch[column] = dict(new_column_data)
|
| 66 |
+
# Multi-agent case: Memorized structure is dict mapping module_ids to lists of
|
| 67 |
+
# (eps_id, agent_id)-tuples, such that the original individual-items-based form
|
| 68 |
+
# can be constructed.
|
| 69 |
+
else:
|
| 70 |
+
for module_id, module_data in batch.copy().items():
|
| 71 |
+
if module_id not in memorized_map_structure:
|
| 72 |
+
raise KeyError(
|
| 73 |
+
f"ModuleID={module_id} not found in `memorized_map_structure`!"
|
| 74 |
+
)
|
| 75 |
+
for column, column_data in module_data.items():
|
| 76 |
+
column_data = unbatch_fn(column_data)
|
| 77 |
+
new_column_data = defaultdict(list)
|
| 78 |
+
for i, (eps_id, agent_id) in enumerate(
|
| 79 |
+
memorized_map_structure[module_id]
|
| 80 |
+
):
|
| 81 |
+
key = (eps_id, agent_id, module_id)
|
| 82 |
+
# TODO (sven): Support vectorization for MultiAgentEnvRunner.
|
| 83 |
+
# AgentIDs whose SingleAgentEpisodes are already done, should
|
| 84 |
+
# not send any data back to the EnvRunner for further
|
| 85 |
+
# processing.
|
| 86 |
+
if episodes[0].agent_episodes[agent_id].is_done:
|
| 87 |
+
continue
|
| 88 |
+
|
| 89 |
+
new_column_data[key].append(column_data[i])
|
| 90 |
+
module_data[column] = dict(new_column_data)
|
| 91 |
+
|
| 92 |
+
return batch
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/registry.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Registry of connector names for global access."""
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 5 |
+
from ray.rllib.connectors.connector import Connector, ConnectorContext
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
ALL_CONNECTORS = dict()
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@OldAPIStack
|
| 12 |
+
def register_connector(name: str, cls: Connector):
|
| 13 |
+
"""Register a connector for use with RLlib.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
name: Name to register.
|
| 17 |
+
cls: Callable that creates an env.
|
| 18 |
+
"""
|
| 19 |
+
if name in ALL_CONNECTORS:
|
| 20 |
+
return
|
| 21 |
+
|
| 22 |
+
if not issubclass(cls, Connector):
|
| 23 |
+
raise TypeError("Can only register Connector type.", cls)
|
| 24 |
+
|
| 25 |
+
# Record it in local registry in case we need to register everything
|
| 26 |
+
# again in the global registry, for example in the event of cluster
|
| 27 |
+
# restarts.
|
| 28 |
+
ALL_CONNECTORS[name] = cls
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@OldAPIStack
|
| 32 |
+
def get_connector(name: str, ctx: ConnectorContext, params: Any = None) -> Connector:
|
| 33 |
+
# TODO(jungong) : switch the order of parameters man!!
|
| 34 |
+
"""Get a connector by its name and serialized config.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
name: name of the connector.
|
| 38 |
+
ctx: Connector context.
|
| 39 |
+
params: serialized parameters of the connector.
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
Constructed connector.
|
| 43 |
+
"""
|
| 44 |
+
if name not in ALL_CONNECTORS:
|
| 45 |
+
raise NameError("connector not found.", name)
|
| 46 |
+
return ALL_CONNECTORS[name].from_state(ctx, params)
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/util.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Any, Tuple, TYPE_CHECKING
|
| 3 |
+
|
| 4 |
+
from ray.rllib.connectors.action.clip import ClipActionsConnector
|
| 5 |
+
from ray.rllib.connectors.action.immutable import ImmutableActionsConnector
|
| 6 |
+
from ray.rllib.connectors.action.lambdas import ConvertToNumpyConnector
|
| 7 |
+
from ray.rllib.connectors.action.normalize import NormalizeActionsConnector
|
| 8 |
+
from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline
|
| 9 |
+
from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
|
| 10 |
+
from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
|
| 11 |
+
from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
|
| 12 |
+
from ray.rllib.connectors.agent.state_buffer import StateBufferConnector
|
| 13 |
+
from ray.rllib.connectors.agent.view_requirement import ViewRequirementAgentConnector
|
| 14 |
+
from ray.rllib.connectors.connector import Connector, ConnectorContext
|
| 15 |
+
from ray.rllib.connectors.registry import get_connector
|
| 16 |
+
from ray.rllib.connectors.agent.mean_std_filter import (
|
| 17 |
+
MeanStdObservationFilterAgentConnector,
|
| 18 |
+
ConcurrentMeanStdObservationFilterAgentConnector,
|
| 19 |
+
)
|
| 20 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 21 |
+
from ray.rllib.connectors.agent.synced_filter import SyncedFilterAgentConnector
|
| 22 |
+
|
| 23 |
+
if TYPE_CHECKING:
|
| 24 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 25 |
+
from ray.rllib.policy.policy import Policy
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def __preprocessing_enabled(config: "AlgorithmConfig"):
|
| 31 |
+
if config._disable_preprocessor_api:
|
| 32 |
+
return False
|
| 33 |
+
# Same conditions as in RolloutWorker.__init__.
|
| 34 |
+
if config.is_atari and config.preprocessor_pref == "deepmind":
|
| 35 |
+
return False
|
| 36 |
+
if config.preprocessor_pref is None:
|
| 37 |
+
return False
|
| 38 |
+
return True
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def __clip_rewards(config: "AlgorithmConfig"):
|
| 42 |
+
# Same logic as in RolloutWorker.__init__.
|
| 43 |
+
# We always clip rewards for Atari games.
|
| 44 |
+
return config.clip_rewards or config.is_atari
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@OldAPIStack
|
| 48 |
+
def get_agent_connectors_from_config(
|
| 49 |
+
ctx: ConnectorContext,
|
| 50 |
+
config: "AlgorithmConfig",
|
| 51 |
+
) -> AgentConnectorPipeline:
|
| 52 |
+
connectors = []
|
| 53 |
+
|
| 54 |
+
clip_rewards = __clip_rewards(config)
|
| 55 |
+
if clip_rewards is True:
|
| 56 |
+
connectors.append(ClipRewardAgentConnector(ctx, sign=True))
|
| 57 |
+
elif type(clip_rewards) is float:
|
| 58 |
+
connectors.append(ClipRewardAgentConnector(ctx, limit=abs(clip_rewards)))
|
| 59 |
+
|
| 60 |
+
if __preprocessing_enabled(config):
|
| 61 |
+
connectors.append(ObsPreprocessorConnector(ctx))
|
| 62 |
+
|
| 63 |
+
# Filters should be after observation preprocessing
|
| 64 |
+
filter_connector = get_synced_filter_connector(
|
| 65 |
+
ctx,
|
| 66 |
+
)
|
| 67 |
+
# Configuration option "NoFilter" results in `filter_connector==None`.
|
| 68 |
+
if filter_connector:
|
| 69 |
+
connectors.append(filter_connector)
|
| 70 |
+
|
| 71 |
+
connectors.extend(
|
| 72 |
+
[
|
| 73 |
+
StateBufferConnector(ctx),
|
| 74 |
+
ViewRequirementAgentConnector(ctx),
|
| 75 |
+
]
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
return AgentConnectorPipeline(ctx, connectors)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@OldAPIStack
|
| 82 |
+
def get_action_connectors_from_config(
|
| 83 |
+
ctx: ConnectorContext,
|
| 84 |
+
config: "AlgorithmConfig",
|
| 85 |
+
) -> ActionConnectorPipeline:
|
| 86 |
+
"""Default list of action connectors to use for a new policy.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
ctx: context used to create connectors.
|
| 90 |
+
config: The AlgorithmConfig object.
|
| 91 |
+
"""
|
| 92 |
+
connectors = [ConvertToNumpyConnector(ctx)]
|
| 93 |
+
if config.get("normalize_actions", False):
|
| 94 |
+
connectors.append(NormalizeActionsConnector(ctx))
|
| 95 |
+
if config.get("clip_actions", False):
|
| 96 |
+
connectors.append(ClipActionsConnector(ctx))
|
| 97 |
+
connectors.append(ImmutableActionsConnector(ctx))
|
| 98 |
+
return ActionConnectorPipeline(ctx, connectors)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@OldAPIStack
|
| 102 |
+
def create_connectors_for_policy(policy: "Policy", config: "AlgorithmConfig"):
|
| 103 |
+
"""Util to create agent and action connectors for a Policy.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
policy: Policy instance.
|
| 107 |
+
config: Algorithm config dict.
|
| 108 |
+
"""
|
| 109 |
+
ctx: ConnectorContext = ConnectorContext.from_policy(policy)
|
| 110 |
+
|
| 111 |
+
assert (
|
| 112 |
+
policy.agent_connectors is None and policy.action_connectors is None
|
| 113 |
+
), "Can not create connectors for a policy that already has connectors."
|
| 114 |
+
|
| 115 |
+
policy.agent_connectors = get_agent_connectors_from_config(ctx, config)
|
| 116 |
+
policy.action_connectors = get_action_connectors_from_config(ctx, config)
|
| 117 |
+
|
| 118 |
+
logger.info("Using connectors:")
|
| 119 |
+
logger.info(policy.agent_connectors.__str__(indentation=4))
|
| 120 |
+
logger.info(policy.action_connectors.__str__(indentation=4))
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@OldAPIStack
|
| 124 |
+
def restore_connectors_for_policy(
|
| 125 |
+
policy: "Policy", connector_config: Tuple[str, Tuple[Any]]
|
| 126 |
+
) -> Connector:
|
| 127 |
+
"""Util to create connector for a Policy based on serialized config.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
policy: Policy instance.
|
| 131 |
+
connector_config: Serialized connector config.
|
| 132 |
+
"""
|
| 133 |
+
ctx: ConnectorContext = ConnectorContext.from_policy(policy)
|
| 134 |
+
name, params = connector_config
|
| 135 |
+
return get_connector(name, ctx, params)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# We need this filter selection mechanism temporarily to remain compatible to old API
|
| 139 |
+
@OldAPIStack
|
| 140 |
+
def get_synced_filter_connector(ctx: ConnectorContext):
|
| 141 |
+
filter_specifier = ctx.config.get("observation_filter")
|
| 142 |
+
if filter_specifier == "MeanStdFilter":
|
| 143 |
+
return MeanStdObservationFilterAgentConnector(ctx, clip=None)
|
| 144 |
+
elif filter_specifier == "ConcurrentMeanStdFilter":
|
| 145 |
+
return ConcurrentMeanStdObservationFilterAgentConnector(ctx, clip=None)
|
| 146 |
+
elif filter_specifier == "NoFilter":
|
| 147 |
+
return None
|
| 148 |
+
else:
|
| 149 |
+
raise Exception("Unknown observation_filter: " + str(filter_specifier))
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
@OldAPIStack
|
| 153 |
+
def maybe_get_filters_for_syncing(rollout_worker, policy_id):
|
| 154 |
+
# As long as the historic filter synchronization mechanism is in
|
| 155 |
+
# place, we need to put filters into self.filters so that they get
|
| 156 |
+
# synchronized
|
| 157 |
+
policy = rollout_worker.policy_map[policy_id]
|
| 158 |
+
if not policy.agent_connectors:
|
| 159 |
+
return
|
| 160 |
+
|
| 161 |
+
filter_connectors = policy.agent_connectors[SyncedFilterAgentConnector]
|
| 162 |
+
# There can only be one filter at a time
|
| 163 |
+
if not filter_connectors:
|
| 164 |
+
return
|
| 165 |
+
|
| 166 |
+
assert len(filter_connectors) == 1, (
|
| 167 |
+
"ConnectorPipeline has multiple connectors of type "
|
| 168 |
+
"SyncedFilterAgentConnector but can only have one."
|
| 169 |
+
)
|
| 170 |
+
rollout_worker.filters[policy_id] = filter_connectors[0].filter
|
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.core.learner.learner import Learner
|
| 2 |
+
from ray.rllib.core.learner.learner_group import LearnerGroup
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"Learner",
|
| 7 |
+
"LearnerGroup",
|
| 8 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (400 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/learner.cpython-311.pyc
ADDED
|
Binary file (75.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/learner_group.cpython-311.pyc
ADDED
|
Binary file (44.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (2.59 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/learner.py
ADDED
|
@@ -0,0 +1,1795 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
import copy
|
| 4 |
+
import logging
|
| 5 |
+
import numpy
|
| 6 |
+
import platform
|
| 7 |
+
from typing import (
|
| 8 |
+
Any,
|
| 9 |
+
Callable,
|
| 10 |
+
Collection,
|
| 11 |
+
Dict,
|
| 12 |
+
List,
|
| 13 |
+
Hashable,
|
| 14 |
+
Optional,
|
| 15 |
+
Sequence,
|
| 16 |
+
Tuple,
|
| 17 |
+
TYPE_CHECKING,
|
| 18 |
+
Union,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
import tree # pip install dm_tree
|
| 22 |
+
|
| 23 |
+
import ray
|
| 24 |
+
from ray.data.iterator import DataIterator
|
| 25 |
+
from ray.rllib.connectors.learner.learner_connector_pipeline import (
|
| 26 |
+
LearnerConnectorPipeline,
|
| 27 |
+
)
|
| 28 |
+
from ray.rllib.core import (
|
| 29 |
+
COMPONENT_METRICS_LOGGER,
|
| 30 |
+
COMPONENT_OPTIMIZER,
|
| 31 |
+
COMPONENT_RL_MODULE,
|
| 32 |
+
DEFAULT_MODULE_ID,
|
| 33 |
+
)
|
| 34 |
+
from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI
|
| 35 |
+
from ray.rllib.core.rl_module import validate_module_id
|
| 36 |
+
from ray.rllib.core.rl_module.multi_rl_module import (
|
| 37 |
+
MultiRLModule,
|
| 38 |
+
MultiRLModuleSpec,
|
| 39 |
+
)
|
| 40 |
+
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
|
| 41 |
+
from ray.rllib.policy.policy import PolicySpec
|
| 42 |
+
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
|
| 43 |
+
from ray.rllib.utils.annotations import (
|
| 44 |
+
override,
|
| 45 |
+
OverrideToImplementCustomLogic,
|
| 46 |
+
OverrideToImplementCustomLogic_CallToSuperRecommended,
|
| 47 |
+
)
|
| 48 |
+
from ray.rllib.utils.checkpoints import Checkpointable
|
| 49 |
+
from ray.rllib.utils.debug import update_global_seed_if_necessary
|
| 50 |
+
from ray.rllib.utils.deprecation import (
|
| 51 |
+
Deprecated,
|
| 52 |
+
DEPRECATED_VALUE,
|
| 53 |
+
deprecation_warning,
|
| 54 |
+
)
|
| 55 |
+
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
| 56 |
+
from ray.rllib.utils.metrics import (
|
| 57 |
+
ALL_MODULES,
|
| 58 |
+
NUM_ENV_STEPS_SAMPLED_LIFETIME,
|
| 59 |
+
NUM_ENV_STEPS_TRAINED,
|
| 60 |
+
NUM_ENV_STEPS_TRAINED_LIFETIME,
|
| 61 |
+
NUM_MODULE_STEPS_TRAINED,
|
| 62 |
+
NUM_MODULE_STEPS_TRAINED_LIFETIME,
|
| 63 |
+
MODULE_TRAIN_BATCH_SIZE_MEAN,
|
| 64 |
+
WEIGHTS_SEQ_NO,
|
| 65 |
+
)
|
| 66 |
+
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
|
| 67 |
+
from ray.rllib.utils.minibatch_utils import (
|
| 68 |
+
MiniBatchDummyIterator,
|
| 69 |
+
MiniBatchCyclicIterator,
|
| 70 |
+
)
|
| 71 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 72 |
+
from ray.rllib.utils.schedules.scheduler import Scheduler
|
| 73 |
+
from ray.rllib.utils.typing import (
|
| 74 |
+
EpisodeType,
|
| 75 |
+
LearningRateOrSchedule,
|
| 76 |
+
ModuleID,
|
| 77 |
+
Optimizer,
|
| 78 |
+
Param,
|
| 79 |
+
ParamRef,
|
| 80 |
+
ParamDict,
|
| 81 |
+
ResultDict,
|
| 82 |
+
ShouldModuleBeUpdatedFn,
|
| 83 |
+
StateDict,
|
| 84 |
+
TensorType,
|
| 85 |
+
)
|
| 86 |
+
from ray.util.annotations import PublicAPI
|
| 87 |
+
|
| 88 |
+
if TYPE_CHECKING:
|
| 89 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
torch, _ = try_import_torch()
|
| 93 |
+
tf1, tf, tfv = try_import_tf()
|
| 94 |
+
|
| 95 |
+
logger = logging.getLogger(__name__)
|
| 96 |
+
|
| 97 |
+
DEFAULT_OPTIMIZER = "default_optimizer"
|
| 98 |
+
|
| 99 |
+
# COMMON LEARNER LOSS_KEYS
|
| 100 |
+
POLICY_LOSS_KEY = "policy_loss"
|
| 101 |
+
VF_LOSS_KEY = "vf_loss"
|
| 102 |
+
ENTROPY_KEY = "entropy"
|
| 103 |
+
|
| 104 |
+
# Additional update keys
|
| 105 |
+
LR_KEY = "learning_rate"
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@PublicAPI(stability="alpha")
|
| 109 |
+
class Learner(Checkpointable):
|
| 110 |
+
"""Base class for Learners.
|
| 111 |
+
|
| 112 |
+
This class will be used to train RLModules. It is responsible for defining the loss
|
| 113 |
+
function, and updating the neural network weights that it owns. It also provides a
|
| 114 |
+
way to add/remove modules to/from RLModules in a multi-agent scenario, in the
|
| 115 |
+
middle of training (This is useful for league based training).
|
| 116 |
+
|
| 117 |
+
TF and Torch specific implementation of this class fills in the framework-specific
|
| 118 |
+
implementation details for distributed training, and for computing and applying
|
| 119 |
+
gradients. User should not need to sub-class this class, but instead inherit from
|
| 120 |
+
the TF or Torch specific sub-classes to implement their algorithm-specific update
|
| 121 |
+
logic.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
config: The AlgorithmConfig object from which to derive most of the settings
|
| 125 |
+
needed to build the Learner.
|
| 126 |
+
module_spec: The module specification for the RLModule that is being trained.
|
| 127 |
+
If the module is a single agent module, after building the module it will
|
| 128 |
+
be converted to a multi-agent module with a default key. Can be none if the
|
| 129 |
+
module is provided directly via the `module` argument. Refer to
|
| 130 |
+
ray.rllib.core.rl_module.RLModuleSpec
|
| 131 |
+
or ray.rllib.core.rl_module.MultiRLModuleSpec for more info.
|
| 132 |
+
module: If learner is being used stand-alone, the RLModule can be optionally
|
| 133 |
+
passed in directly instead of the through the `module_spec`.
|
| 134 |
+
|
| 135 |
+
Note: We use PPO and torch as an example here because many of the showcased
|
| 136 |
+
components need implementations to come together. However, the same
|
| 137 |
+
pattern is generally applicable.
|
| 138 |
+
|
| 139 |
+
.. testcode::
|
| 140 |
+
|
| 141 |
+
import gymnasium as gym
|
| 142 |
+
|
| 143 |
+
from ray.rllib.algorithms.ppo.ppo import PPOConfig
|
| 144 |
+
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
|
| 145 |
+
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
|
| 146 |
+
PPOTorchRLModule
|
| 147 |
+
)
|
| 148 |
+
from ray.rllib.core import COMPONENT_RL_MODULE, DEFAULT_MODULE_ID
|
| 149 |
+
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
|
| 150 |
+
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
|
| 151 |
+
|
| 152 |
+
env = gym.make("CartPole-v1")
|
| 153 |
+
|
| 154 |
+
# Create a PPO config object first.
|
| 155 |
+
config = (
|
| 156 |
+
PPOConfig()
|
| 157 |
+
.framework("torch")
|
| 158 |
+
.training(model={"fcnet_hiddens": [128, 128]})
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# Create a learner instance directly from our config. All we need as
|
| 162 |
+
# extra information here is the env to be able to extract space information
|
| 163 |
+
# (needed to construct the RLModule inside the Learner).
|
| 164 |
+
learner = config.build_learner(env=env)
|
| 165 |
+
|
| 166 |
+
# Take one gradient update on the module and report the results.
|
| 167 |
+
# results = learner.update(...)
|
| 168 |
+
|
| 169 |
+
# Add a new module, perhaps for league based training.
|
| 170 |
+
learner.add_module(
|
| 171 |
+
module_id="new_player",
|
| 172 |
+
module_spec=RLModuleSpec(
|
| 173 |
+
module_class=PPOTorchRLModule,
|
| 174 |
+
observation_space=env.observation_space,
|
| 175 |
+
action_space=env.action_space,
|
| 176 |
+
model_config=DefaultModelConfig(fcnet_hiddens=[64, 64]),
|
| 177 |
+
catalog_class=PPOCatalog,
|
| 178 |
+
)
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# Take another gradient update with both previous and new modules.
|
| 182 |
+
# results = learner.update(...)
|
| 183 |
+
|
| 184 |
+
# Remove a module.
|
| 185 |
+
learner.remove_module("new_player")
|
| 186 |
+
|
| 187 |
+
# Will train previous modules only.
|
| 188 |
+
# results = learner.update(...)
|
| 189 |
+
|
| 190 |
+
# Get the state of the learner.
|
| 191 |
+
state = learner.get_state()
|
| 192 |
+
|
| 193 |
+
# Set the state of the learner.
|
| 194 |
+
learner.set_state(state)
|
| 195 |
+
|
| 196 |
+
# Get the weights of the underlying MultiRLModule.
|
| 197 |
+
weights = learner.get_state(components=COMPONENT_RL_MODULE)
|
| 198 |
+
|
| 199 |
+
# Set the weights of the underlying MultiRLModule.
|
| 200 |
+
learner.set_state({COMPONENT_RL_MODULE: weights})
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
Extension pattern:
|
| 204 |
+
|
| 205 |
+
.. testcode::
|
| 206 |
+
|
| 207 |
+
from ray.rllib.core.learner.torch.torch_learner import TorchLearner
|
| 208 |
+
|
| 209 |
+
class MyLearner(TorchLearner):
|
| 210 |
+
|
| 211 |
+
def compute_losses(self, fwd_out, batch):
|
| 212 |
+
# Compute the losses per module based on `batch` and output of the
|
| 213 |
+
# forward pass (`fwd_out`). To access the (algorithm) config for a
|
| 214 |
+
# specific RLModule, do:
|
| 215 |
+
# `self.config.get_config_for_module([moduleID])`.
|
| 216 |
+
return {DEFAULT_MODULE_ID: module_loss}
|
| 217 |
+
"""
|
| 218 |
+
|
| 219 |
+
framework: str = None
|
| 220 |
+
TOTAL_LOSS_KEY: str = "total_loss"
|
| 221 |
+
|
| 222 |
+
def __init__(
    self,
    *,
    config: "AlgorithmConfig",
    module_spec: Optional[Union[RLModuleSpec, MultiRLModuleSpec]] = None,
    module: Optional[RLModule] = None,
):
    """Initializes the Learner instance (does NOT build it yet; see `self.build()`).

    Args:
        config: The AlgorithmConfig from which to derive most settings. Stored
            as an unfrozen copy so the Learner may adjust it locally.
        module_spec: Optional spec from which to construct the RLModule to be
            trained. May be None if `module` is passed directly.
        module: Optional, already-instantiated RLModule (alternative to
            passing `module_spec`).
    """
    # TODO (sven): Figure out how to do this
    # Keep an unfrozen copy so this Learner can mutate its config locally.
    self.config = config.copy(copy_frozen=False)
    self._module_spec: Optional[MultiRLModuleSpec] = module_spec
    self._module_obj: Optional[MultiRLModule] = module

    # Make node and device of this Learner available.
    self._node = platform.node()
    self._device = None

    # Set a seed, if necessary.
    if self.config.seed is not None:
        update_global_seed_if_necessary(self.framework, self.config.seed)

    # Whether self.build has already been called.
    self._is_built = False

    # These are the attributes that are set during build.

    # The actual MultiRLModule used by this Learner.
    self._module: Optional[MultiRLModule] = None
    # Monotonically increasing version counter for the module weights.
    self._weights_seq_no = 0
    # Our Learner connector pipeline.
    self._learner_connector: Optional[LearnerConnectorPipeline] = None
    # These are set for properly applying optimizers and adding or removing modules.
    self._optimizer_parameters: Dict[Optimizer, List[ParamRef]] = {}
    self._named_optimizers: Dict[str, Optimizer] = {}
    self._params: ParamDict = {}
    # Dict mapping ModuleID to a list of optimizer names. Note that the optimizer
    # name includes the ModuleID as a prefix: optimizer_name=`[ModuleID]_[.. rest]`.
    self._module_optimizers: Dict[ModuleID, List[str]] = defaultdict(list)
    self._optimizer_name_to_module: Dict[str, ModuleID] = {}

    # Only manage optimizer's learning rate if user has NOT overridden
    # the `configure_optimizers_for_module` method. Otherwise, leave responsibility
    # to handle lr-updates entirely in user's hands.
    self._optimizer_lr_schedules: Dict[Optimizer, Scheduler] = {}

    # The Learner's own MetricsLogger to be used to log RLlib's built-in metrics or
    # custom user-defined ones (e.g. custom loss values). When returning from an
    # `update_from_...()` method call, the Learner will do a `self.metrics.reduce()`
    # and return the resulting (reduced) dict.
    self.metrics = MetricsLogger()

    # In case of offline learning and multiple learners, each learner receives a
    # repeatable iterator that iterates over a split of the streamed data.
    self.iterator: Optional[DataIterator] = None
|
| 275 |
+
|
| 276 |
+
# TODO (sven): Do we really need this API? It seems like LearnerGroup constructs
|
| 277 |
+
# all Learner workers and then immediately builds them any ways? Unless there is
|
| 278 |
+
# a reason related to Train worker group setup.
|
| 279 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
def build(self) -> None:
    """Builds this Learner: connector pipeline, RLModule, and optimizers.

    Must be called before the Learner is used. Calling it again on an
    already-built Learner is a no-op.
    """
    if self._is_built:
        logger.debug("Learner already built. Skipping build.")
        return

    # Learners fed by aggregation actors receive batches that already went
    # through the learner connector (pre-loaded on the GPU), so they need no
    # connector pipeline of their own.
    if self.config.num_aggregator_actors_per_learner == 0:
        # TODO (sven): Figure out which space to provide here. For now,
        #  it doesn't matter, as the default connector piece doesn't use
        #  this information anyway.
        self._learner_connector = self.config.build_learner_connector(
            input_observation_space=None,
            input_action_space=None,
            device=self._device,
        )
    else:
        self._learner_connector = None

    # Construct the MultiRLModule this Learner is going to train.
    self._module = self._make_module()

    # Configure, construct, and register all optimizers needed to train
    # `self.module`.
    self.configure_optimizers()

    # Log the number of trainable/non-trainable parameters.
    self._log_trainable_parameters()

    self._is_built = True
|
| 318 |
+
|
| 319 |
+
@property
def distributed(self) -> bool:
    """Whether the learner is running in distributed mode (more than 1 Learner)."""
    return self.config.num_learners > 1
|
| 323 |
+
|
| 324 |
+
@property
def module(self) -> MultiRLModule:
    """The MultiRLModule that is being trained (built in `self.build()`)."""
    return self._module
|
| 328 |
+
|
| 329 |
+
@property
def node(self) -> Any:
    """The network node (hostname) this Learner runs on, via `platform.node()`."""
    return self._node
|
| 332 |
+
|
| 333 |
+
@property
def device(self) -> Any:
    """The device this Learner computes on.

    None right after `__init__`; presumably assigned by a framework-specific
    subclass during build — TODO(review): confirm the assignment site.
    """
    return self._device
|
| 336 |
+
|
| 337 |
+
def register_optimizer(
    self,
    *,
    module_id: ModuleID = ALL_MODULES,
    optimizer_name: str = DEFAULT_OPTIMIZER,
    optimizer: Optimizer,
    params: Sequence[Param],
    lr_or_lr_schedule: Optional[LearningRateOrSchedule] = None,
) -> None:
    """Registers an optimizer under a ModuleID and name, with params and LR setup.

    Call this from your custom `self.configure_optimizers()` or
    `self.configure_optimizers_for_module()` override (only override one of
    these!). If `lr_or_lr_schedule` is provided, RLlib keeps this optimizer's
    learning rate updated over the course of training; otherwise, you manage
    the learning rate yourself.

    Args:
        module_id: The `module_id` under which to register the optimizer.
            Defaults to ALL_MODULES.
        optimizer_name: The name (str) of the optimizer. Defaults to
            DEFAULT_OPTIMIZER.
        optimizer: The already instantiated optimizer object to register.
        params: A sequence of framework-specific parameters (variables) this
            optimizer will update.
        lr_or_lr_schedule: Optional fixed learning rate or schedule setup. If
            given, RLlib automatically keeps the optimizer's LR updated.
    """
    # Fail fast on an invalid optimizer instance or parameter sequence.
    self._check_registered_optimizer(optimizer, params)

    # Optimizers are keyed by the fully qualified `[module_id]_[name]` string.
    full_name = module_id + "_" + optimizer_name

    # Index the optimizer by module and by its fully qualified name.
    self._module_optimizers[module_id].append(full_name)
    self._optimizer_name_to_module[full_name] = module_id
    self._named_optimizers[full_name] = optimizer

    # Record every parameter this optimizer is responsible for updating.
    self._optimizer_parameters[optimizer] = []
    for p in params:
        ref = self.get_param_ref(p)
        self._optimizer_parameters[optimizer].append(ref)
        self._params[ref] = p

    # No LR management requested -> user handles learning rates entirely.
    if lr_or_lr_schedule is None:
        return

    # Validate the given LR setting, then attach a Scheduler so RLlib keeps
    # this optimizer's learning rate updated throughout training.
    Scheduler.validate(
        fixed_value_or_schedule=lr_or_lr_schedule,
        setting_name="lr_or_lr_schedule",
        description="learning rate or schedule",
    )
    scheduler = Scheduler(
        fixed_value_or_schedule=lr_or_lr_schedule,
        framework=self.framework,
        device=self._device,
    )
    self._optimizer_lr_schedules[optimizer] = scheduler
    # Sync the optimizer to the schedule's initial (first) learning rate.
    self._set_optimizer_lr(
        optimizer=optimizer,
        lr=scheduler.get_current_value(),
    )
|
| 410 |
+
|
| 411 |
+
@OverrideToImplementCustomLogic
def configure_optimizers(self) -> None:
    """Configures, creates, and registers all optimizers for this Learner.

    Called once during `self.build()`. Normally you should NOT override this
    method, but rather `self.configure_optimizers_for_module()`, and register
    the optimizers you need per `module_id` there via
    `self.register_optimizer()` (passing the module_id, an optional optimizer
    name, the optimizer instance, its parameter list, and an optional learning
    rate or LR-schedule setting).

    The default implementation delegates to
    `self.configure_optimizers_for_module()` for each compatible RLModule
    within `self.module`.
    """
    for mid in self.module.keys():
        # Skip sub-modules that are not compatible with this Learner.
        if not self.rl_module_is_compatible(self.module[mid]):
            continue
        self.configure_optimizers_for_module(
            module_id=mid,
            config=self.config.get_config_for_module(mid),
        )
|
| 438 |
+
|
| 439 |
+
@OverrideToImplementCustomLogic
@abc.abstractmethod
def configure_optimizers_for_module(
    self, module_id: ModuleID, config: "AlgorithmConfig" = None
) -> None:
    """Configures an optimizer for the given module_id.

    This method is called for each RLModule in the MultiRLModule being
    trained by the Learner, as well as any new module added during training via
    `self.add_module()`. It should configure and construct one or more optimizers
    and register them via calls to `self.register_optimizer()` along with the
    `module_id`, an optional optimizer name (str), a list of the optimizer's
    framework specific parameters (variables), and an optional learning rate value
    or -schedule.

    Args:
        module_id: The module_id of the RLModule that is being configured.
        config: The AlgorithmConfig specific to the given `module_id`.
    """
|
| 458 |
+
|
| 459 |
+
@OverrideToImplementCustomLogic
@abc.abstractmethod
def compute_gradients(
    self, loss_per_module: Dict[ModuleID, TensorType], **kwargs
) -> ParamDict:
    """Computes the gradients based on the given losses.

    Implemented by the framework-specific (TF/Torch) Learner subclasses.

    Args:
        loss_per_module: Dict mapping module IDs to their individual total loss
            terms, computed by the individual `compute_loss_for_module()` calls.
            The overall total loss (sum of loss terms over all modules) is stored
            under `loss_per_module[ALL_MODULES]`.
        **kwargs: Forward compatibility kwargs.

    Returns:
        The gradients in the same (flat) format as self._params. Note that all
        top-level structures, such as module IDs, will not be present anymore in
        the returned dict. It will merely map parameter tensor references to their
        respective gradient tensors.
    """
|
| 479 |
+
|
| 480 |
+
@OverrideToImplementCustomLogic
def postprocess_gradients(self, gradients_dict: ParamDict) -> ParamDict:
    """Applies potential postprocessing operations on the gradients.

    Runs after gradient computation and before the optimizers apply them.
    Typical postprocessing is gradient clipping (by value, norm, or
    global-norm) or other algorithm-specific steps.

    This default implementation routes each sub-module's gradients through
    `self.postprocess_gradients_for_module()` and merges the results.

    Args:
        gradients_dict: A dictionary of gradients in the same (flat) format as
            self._params: gradient tensor references mapped to gradient
            tensors (no module-ID level structure).

    Returns:
        A dictionary with the updated gradients and the exact same (flat)
        structure as the incoming `gradients_dict` arg.
    """
    merged: ParamDict = {}

    for mid in self.module.keys():
        # Collect only those gradients that belong to this sub-module's
        # registered optimizers.
        per_module_grads = {}
        for _, optim in self.get_optimizers_for_module(mid):
            per_module_grads.update(
                self.filter_param_dict_for_optimizer(gradients_dict, optim)
            )

        per_module_grads = self.postprocess_gradients_for_module(
            module_id=mid,
            config=self.config.get_config_for_module(mid),
            module_gradients_dict=per_module_grads,
        )
        assert isinstance(per_module_grads, dict)

        # Fold this sub-module's postprocessed gradients into the result.
        merged.update(per_module_grads)

    return merged
|
| 528 |
+
|
| 529 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
def postprocess_gradients_for_module(
    self,
    *,
    module_id: ModuleID,
    config: Optional["AlgorithmConfig"] = None,
    module_gradients_dict: ParamDict,
) -> ParamDict:
    """Applies postprocessing operations on the gradients of the given module.

    Performs (optional) gradient clipping per `config.grad_clip` /
    `config.grad_clip_by` and (optional) global-norm logging per
    `config.log_gradients`.

    Args:
        module_id: The module ID for which we will postprocess computed gradients.
            Note that `module_gradients_dict` already only carries those gradient
            tensors that belong to this `module_id`. Other `module_id`'s gradients
            are not available in this call.
        config: The AlgorithmConfig specific to the given `module_id`.
        module_gradients_dict: A dictionary of gradients in the same (flat) format
            as self._params, mapping gradient refs to gradient tensors, which are
            to be postprocessed. You may alter these tensors in place or create
            new ones and return these in a new dict.

    Returns:
        A dictionary with the updated gradients and the exact same (flat) structure
        as the incoming `module_gradients_dict` arg.
    """
    postprocessed_grads = {}

    # Nothing to clip and nothing to log -> pass gradients through unchanged.
    if config.grad_clip is None and not config.log_gradients:
        postprocessed_grads.update(module_gradients_dict)
        return postprocessed_grads

    for optimizer_name, optimizer in self.get_optimizers_for_module(module_id):
        # Only this optimizer's slice of the module's gradients.
        grad_dict_to_clip = self.filter_param_dict_for_optimizer(
            param_dict=module_gradients_dict,
            optimizer=optimizer,
        )
        if config.grad_clip:
            # Perform gradient clipping, if configured.
            global_norm = self._get_clip_function()(
                grad_dict_to_clip,
                grad_clip=config.grad_clip,
                grad_clip_by=config.grad_clip_by,
            )
            if config.grad_clip_by == "global_norm" or config.log_gradients:
                # If we want to log gradients, but do not use the global norm
                # for clipping, compute it here.
                if config.log_gradients and config.grad_clip_by != "global_norm":
                    # Compute the global norm of gradients.
                    global_norm = self._get_global_norm_function()(
                        # Note, `tf.linalg.global_norm` needs a list of tensors.
                        list(grad_dict_to_clip.values()),
                    )
                self.metrics.log_value(
                    key=(module_id, f"gradients_{optimizer_name}_global_norm"),
                    value=global_norm,
                    window=1,
                )
        elif config.log_gradients:
            # Log-only mode: compute the global norm of gradients and log it.
            global_norm = self._get_global_norm_function()(
                # Note, `tf.linalg.global_norm` needs a list of tensors.
                list(grad_dict_to_clip.values()),
            )
            self.metrics.log_value(
                key=(module_id, f"gradients_{optimizer_name}_global_norm"),
                value=global_norm,
                window=1,
            )
        # Always carry this optimizer's gradients into the returned dict.
        # Fix: previously this update only happened on the clipping path, so
        # with `grad_clip=None` and `log_gradients=True` the gradients were
        # logged but dropped from the return value (and thus never applied).
        postprocessed_grads.update(grad_dict_to_clip)

    return postprocessed_grads
|
| 601 |
+
|
| 602 |
+
@OverrideToImplementCustomLogic
@abc.abstractmethod
def apply_gradients(self, gradients_dict: ParamDict) -> None:
    """Applies the gradients to the MultiRLModule parameters.

    Implemented by the framework-specific (TF/Torch) Learner subclasses.

    Args:
        gradients_dict: A dictionary of gradients in the same (flat) format as
            self._params. Note that top-level structures, such as module IDs,
            will not be present anymore in this dict. It will merely map gradient
            tensor references to gradient tensors.
    """
|
| 613 |
+
|
| 614 |
+
def get_optimizer(
    self,
    module_id: ModuleID = DEFAULT_MODULE_ID,
    optimizer_name: str = DEFAULT_OPTIMIZER,
) -> Optimizer:
    """Returns the optimizer object, configured under the given module_id and name.

    `optimizer_name` may either be the plain name used at registration time
    (in which case it is qualified with `module_id`) or an already fully
    qualified `[module_id]_[name]` string.

    Args:
        module_id: The ModuleID for which to return the configured optimizer.
            If not provided, will assume DEFAULT_MODULE_ID.
        optimizer_name: The name of the optimizer (registered under `module_id`
            via `self.register_optimizer()`) to return. If not provided, will
            assume DEFAULT_OPTIMIZER.

    Returns:
        The optimizer object, configured under the given `module_id` and
        `optimizer_name`.

    Raises:
        KeyError: If no optimizer is found under either lookup key.
    """
    # Try `optimizer_name` as-is first (it may already be fully qualified),
    # then qualified with the `module_id` prefix.
    for lookup_key in (optimizer_name, module_id + "_" + optimizer_name):
        if lookup_key in self._named_optimizers:
            return self._named_optimizers[lookup_key]

    # No optimizer found.
    raise KeyError(
        f"Optimizer not found! module_id={module_id} "
        f"optimizer_name={optimizer_name}"
    )
|
| 652 |
+
|
| 653 |
+
def get_optimizers_for_module(
    self, module_id: ModuleID = ALL_MODULES
) -> List[Tuple[str, Optimizer]]:
    """Returns a list of (optimizer_name, optimizer instance)-tuples for module_id.

    Args:
        module_id: The ModuleID for which to return the configured
            (optimizer name, optimizer)-pairs. If not provided, will return
            optimizers registered under ALL_MODULES.

    Returns:
        A list of tuples of the format: ([optimizer_name], [optimizer object]),
        where optimizer_name is the name under which the optimizer was registered
        in `self.register_optimizer`. If only a single optimizer was
        configured for `module_id`, [optimizer_name] will be DEFAULT_OPTIMIZER.
    """
    # TODO (sven): How can we avoid registering optimziers under this
    #  constructed `[module_id]_[optim_name]` format?
    # Strip the `[module_id]_` prefix off the full registration name to
    # recover the plain optimizer name.
    return [
        (full_name[len(module_id) + 1 :], self._named_optimizers[full_name])
        for full_name in self._module_optimizers[module_id]
    ]
|
| 677 |
+
|
| 678 |
+
def filter_param_dict_for_optimizer(
|
| 679 |
+
self, param_dict: ParamDict, optimizer: Optimizer
|
| 680 |
+
) -> ParamDict:
|
| 681 |
+
"""Reduces the given ParamDict to contain only parameters for given optimizer.
|
| 682 |
+
|
| 683 |
+
Args:
|
| 684 |
+
param_dict: The ParamDict to reduce/filter down to the given `optimizer`.
|
| 685 |
+
The returned dict will be a subset of `param_dict` only containing keys
|
| 686 |
+
(param refs) that were registered together with `optimizer` (and thus
|
| 687 |
+
that `optimizer` is responsible for applying gradients to).
|
| 688 |
+
optimizer: The optimizer object to whose parameter refs the given
|
| 689 |
+
`param_dict` should be reduced.
|
| 690 |
+
|
| 691 |
+
Returns:
|
| 692 |
+
A new ParamDict only containing param ref keys that belong to `optimizer`.
|
| 693 |
+
"""
|
| 694 |
+
# Return a sub-dict only containing those param_ref keys (and their values)
|
| 695 |
+
# that belong to the `optimizer`.
|
| 696 |
+
return {
|
| 697 |
+
ref: param_dict[ref]
|
| 698 |
+
for ref in self._optimizer_parameters[optimizer]
|
| 699 |
+
if ref in param_dict and param_dict[ref] is not None
|
| 700 |
+
}
|
| 701 |
+
|
| 702 |
+
@abc.abstractmethod
|
| 703 |
+
def get_param_ref(self, param: Param) -> Hashable:
|
| 704 |
+
"""Returns a hashable reference to a trainable parameter.
|
| 705 |
+
|
| 706 |
+
This should be overridden in framework specific specialization. For example in
|
| 707 |
+
torch it will return the parameter itself, while in tf it returns the .ref() of
|
| 708 |
+
the variable. The purpose is to retrieve a unique reference to the parameters.
|
| 709 |
+
|
| 710 |
+
Args:
|
| 711 |
+
param: The parameter to get the reference to.
|
| 712 |
+
|
| 713 |
+
Returns:
|
| 714 |
+
A reference to the parameter.
|
| 715 |
+
"""
|
| 716 |
+
|
| 717 |
+
@abc.abstractmethod
|
| 718 |
+
def get_parameters(self, module: RLModule) -> Sequence[Param]:
|
| 719 |
+
"""Returns the list of parameters of a module.
|
| 720 |
+
|
| 721 |
+
This should be overridden in framework specific learner. For example in torch it
|
| 722 |
+
will return .parameters(), while in tf it returns .trainable_variables.
|
| 723 |
+
|
| 724 |
+
Args:
|
| 725 |
+
module: The module to get the parameters from.
|
| 726 |
+
|
| 727 |
+
Returns:
|
| 728 |
+
The parameters of the module.
|
| 729 |
+
"""
|
| 730 |
+
|
| 731 |
+
@abc.abstractmethod
|
| 732 |
+
def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch:
|
| 733 |
+
"""Converts the elements of a MultiAgentBatch to Tensors on the correct device.
|
| 734 |
+
|
| 735 |
+
Args:
|
| 736 |
+
batch: The MultiAgentBatch object to convert.
|
| 737 |
+
|
| 738 |
+
Returns:
|
| 739 |
+
The resulting MultiAgentBatch with framework-specific tensor values placed
|
| 740 |
+
on the correct device.
|
| 741 |
+
"""
|
| 742 |
+
|
| 743 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
| 744 |
+
def add_module(
|
| 745 |
+
self,
|
| 746 |
+
*,
|
| 747 |
+
module_id: ModuleID,
|
| 748 |
+
module_spec: RLModuleSpec,
|
| 749 |
+
config_overrides: Optional[Dict] = None,
|
| 750 |
+
new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
|
| 751 |
+
) -> MultiRLModuleSpec:
|
| 752 |
+
"""Adds a module to the underlying MultiRLModule.
|
| 753 |
+
|
| 754 |
+
Changes this Learner's config in order to make this architectural change
|
| 755 |
+
permanent wrt. to checkpointing.
|
| 756 |
+
|
| 757 |
+
Args:
|
| 758 |
+
module_id: The ModuleID of the module to be added.
|
| 759 |
+
module_spec: The ModuleSpec of the module to be added.
|
| 760 |
+
config_overrides: The `AlgorithmConfig` overrides that should apply to
|
| 761 |
+
the new Module, if any.
|
| 762 |
+
new_should_module_be_updated: An optional sequence of ModuleIDs or a
|
| 763 |
+
callable taking ModuleID and SampleBatchType and returning whether the
|
| 764 |
+
ModuleID should be updated (trained).
|
| 765 |
+
If None, will keep the existing setup in place. RLModules,
|
| 766 |
+
whose IDs are not in the list (or for which the callable
|
| 767 |
+
returns False) will not be updated.
|
| 768 |
+
|
| 769 |
+
Returns:
|
| 770 |
+
The new MultiRLModuleSpec (after the RLModule has been added).
|
| 771 |
+
"""
|
| 772 |
+
validate_module_id(module_id, error=True)
|
| 773 |
+
self._check_is_built()
|
| 774 |
+
|
| 775 |
+
# Force-set inference-only = False.
|
| 776 |
+
module_spec = copy.deepcopy(module_spec)
|
| 777 |
+
module_spec.inference_only = False
|
| 778 |
+
|
| 779 |
+
# Build the new RLModule and add it to self.module.
|
| 780 |
+
module = module_spec.build()
|
| 781 |
+
self.module.add_module(module_id, module)
|
| 782 |
+
|
| 783 |
+
# Change our config (AlgorithmConfig) to contain the new Module.
|
| 784 |
+
# TODO (sven): This is a hack to manipulate the AlgorithmConfig directly,
|
| 785 |
+
# but we'll deprecate config.policies soon anyway.
|
| 786 |
+
self.config.policies[module_id] = PolicySpec()
|
| 787 |
+
if config_overrides is not None:
|
| 788 |
+
self.config.multi_agent(
|
| 789 |
+
algorithm_config_overrides_per_module={module_id: config_overrides}
|
| 790 |
+
)
|
| 791 |
+
self.config.rl_module(rl_module_spec=MultiRLModuleSpec.from_module(self.module))
|
| 792 |
+
self._module_spec = self.config.rl_module_spec
|
| 793 |
+
if new_should_module_be_updated is not None:
|
| 794 |
+
self.config.multi_agent(policies_to_train=new_should_module_be_updated)
|
| 795 |
+
|
| 796 |
+
# Allow the user to configure one or more optimizers for this new module.
|
| 797 |
+
self.configure_optimizers_for_module(
|
| 798 |
+
module_id=module_id,
|
| 799 |
+
config=self.config.get_config_for_module(module_id),
|
| 800 |
+
)
|
| 801 |
+
return self.config.rl_module_spec
|
| 802 |
+
|
| 803 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
| 804 |
+
def remove_module(
|
| 805 |
+
self,
|
| 806 |
+
module_id: ModuleID,
|
| 807 |
+
*,
|
| 808 |
+
new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
|
| 809 |
+
) -> MultiRLModuleSpec:
|
| 810 |
+
"""Removes a module from the Learner.
|
| 811 |
+
|
| 812 |
+
Args:
|
| 813 |
+
module_id: The ModuleID of the module to be removed.
|
| 814 |
+
new_should_module_be_updated: An optional sequence of ModuleIDs or a
|
| 815 |
+
callable taking ModuleID and SampleBatchType and returning whether the
|
| 816 |
+
ModuleID should be updated (trained).
|
| 817 |
+
If None, will keep the existing setup in place. RLModules,
|
| 818 |
+
whose IDs are not in the list (or for which the callable
|
| 819 |
+
returns False) will not be updated.
|
| 820 |
+
|
| 821 |
+
Returns:
|
| 822 |
+
The new MultiRLModuleSpec (after the RLModule has been removed).
|
| 823 |
+
"""
|
| 824 |
+
self._check_is_built()
|
| 825 |
+
module = self.module[module_id]
|
| 826 |
+
|
| 827 |
+
# Delete the removed module's parameters and optimizers.
|
| 828 |
+
if self.rl_module_is_compatible(module):
|
| 829 |
+
parameters = self.get_parameters(module)
|
| 830 |
+
for param in parameters:
|
| 831 |
+
param_ref = self.get_param_ref(param)
|
| 832 |
+
if param_ref in self._params:
|
| 833 |
+
del self._params[param_ref]
|
| 834 |
+
for optimizer_name, optimizer in self.get_optimizers_for_module(module_id):
|
| 835 |
+
del self._optimizer_parameters[optimizer]
|
| 836 |
+
name = module_id + "_" + optimizer_name
|
| 837 |
+
del self._named_optimizers[name]
|
| 838 |
+
if optimizer in self._optimizer_lr_schedules:
|
| 839 |
+
del self._optimizer_lr_schedules[optimizer]
|
| 840 |
+
del self._module_optimizers[module_id]
|
| 841 |
+
|
| 842 |
+
# Remove the module from the MultiRLModule.
|
| 843 |
+
self.module.remove_module(module_id)
|
| 844 |
+
|
| 845 |
+
# Change self.config to reflect the new architecture.
|
| 846 |
+
# TODO (sven): This is a hack to manipulate the AlgorithmConfig directly,
|
| 847 |
+
# but we'll deprecate config.policies soon anyway.
|
| 848 |
+
del self.config.policies[module_id]
|
| 849 |
+
self.config.algorithm_config_overrides_per_module.pop(module_id, None)
|
| 850 |
+
if new_should_module_be_updated is not None:
|
| 851 |
+
self.config.multi_agent(policies_to_train=new_should_module_be_updated)
|
| 852 |
+
self.config.rl_module(rl_module_spec=MultiRLModuleSpec.from_module(self.module))
|
| 853 |
+
|
| 854 |
+
# Remove all stats from the module from our metrics logger, so we don't report
|
| 855 |
+
# results from this module again.
|
| 856 |
+
if module_id in self.metrics.stats:
|
| 857 |
+
del self.metrics.stats[module_id]
|
| 858 |
+
|
| 859 |
+
return self.config.rl_module_spec
|
| 860 |
+
|
| 861 |
+
@OverrideToImplementCustomLogic
|
| 862 |
+
def should_module_be_updated(self, module_id, multi_agent_batch=None):
|
| 863 |
+
"""Returns whether a module should be updated or not based on `self.config`.
|
| 864 |
+
|
| 865 |
+
Args:
|
| 866 |
+
module_id: The ModuleID that we want to query on whether this module
|
| 867 |
+
should be updated or not.
|
| 868 |
+
multi_agent_batch: An optional MultiAgentBatch to possibly provide further
|
| 869 |
+
information on the decision on whether the RLModule should be updated
|
| 870 |
+
or not.
|
| 871 |
+
"""
|
| 872 |
+
should_module_be_updated_fn = self.config.policies_to_train
|
| 873 |
+
# If None, return True (by default, all modules should be updated).
|
| 874 |
+
if should_module_be_updated_fn is None:
|
| 875 |
+
return True
|
| 876 |
+
# If collection given, return whether `module_id` is in that container.
|
| 877 |
+
elif not callable(should_module_be_updated_fn):
|
| 878 |
+
return module_id in set(should_module_be_updated_fn)
|
| 879 |
+
|
| 880 |
+
return should_module_be_updated_fn(module_id, multi_agent_batch)
|
| 881 |
+
|
| 882 |
+
@OverrideToImplementCustomLogic
|
| 883 |
+
def compute_losses(
|
| 884 |
+
self, *, fwd_out: Dict[str, Any], batch: Dict[str, Any]
|
| 885 |
+
) -> Dict[str, Any]:
|
| 886 |
+
"""Computes the loss(es) for the module being optimized.
|
| 887 |
+
|
| 888 |
+
This method must be overridden by MultiRLModule-specific Learners in order to
|
| 889 |
+
define the specific loss computation logic. If the algorithm is single-agent,
|
| 890 |
+
only `compute_loss_for_module()` should be overridden instead. If the algorithm
|
| 891 |
+
uses independent multi-agent learning (default behavior for RLlib's multi-agent
|
| 892 |
+
setups), also only `compute_loss_for_module()` should be overridden, but it will
|
| 893 |
+
be called for each individual RLModule inside the MultiRLModule.
|
| 894 |
+
It is recommended to not compute any forward passes within this method, and to
|
| 895 |
+
use the `forward_train()` outputs of the RLModule(s) to compute the required
|
| 896 |
+
loss tensors.
|
| 897 |
+
See here for a custom loss function example script:
|
| 898 |
+
https://github.com/ray-project/ray/blob/master/rllib/examples/learners/custom_loss_fn_simple.py # noqa
|
| 899 |
+
|
| 900 |
+
Args:
|
| 901 |
+
fwd_out: Output from a call to the `forward_train()` method of the
|
| 902 |
+
underlying MultiRLModule (`self.module`) during training
|
| 903 |
+
(`self.update()`).
|
| 904 |
+
batch: The train batch that was used to compute `fwd_out`.
|
| 905 |
+
|
| 906 |
+
Returns:
|
| 907 |
+
A dictionary mapping module IDs to individual loss terms.
|
| 908 |
+
"""
|
| 909 |
+
loss_per_module = {}
|
| 910 |
+
for module_id in fwd_out:
|
| 911 |
+
module_batch = batch[module_id]
|
| 912 |
+
module_fwd_out = fwd_out[module_id]
|
| 913 |
+
|
| 914 |
+
module = self.module[module_id].unwrapped()
|
| 915 |
+
if isinstance(module, SelfSupervisedLossAPI):
|
| 916 |
+
loss = module.compute_self_supervised_loss(
|
| 917 |
+
learner=self,
|
| 918 |
+
module_id=module_id,
|
| 919 |
+
config=self.config.get_config_for_module(module_id),
|
| 920 |
+
batch=module_batch,
|
| 921 |
+
fwd_out=module_fwd_out,
|
| 922 |
+
)
|
| 923 |
+
else:
|
| 924 |
+
loss = self.compute_loss_for_module(
|
| 925 |
+
module_id=module_id,
|
| 926 |
+
config=self.config.get_config_for_module(module_id),
|
| 927 |
+
batch=module_batch,
|
| 928 |
+
fwd_out=module_fwd_out,
|
| 929 |
+
)
|
| 930 |
+
loss_per_module[module_id] = loss
|
| 931 |
+
|
| 932 |
+
return loss_per_module
|
| 933 |
+
|
| 934 |
+
@OverrideToImplementCustomLogic
|
| 935 |
+
@abc.abstractmethod
|
| 936 |
+
def compute_loss_for_module(
|
| 937 |
+
self,
|
| 938 |
+
*,
|
| 939 |
+
module_id: ModuleID,
|
| 940 |
+
config: "AlgorithmConfig",
|
| 941 |
+
batch: Dict[str, Any],
|
| 942 |
+
fwd_out: Dict[str, TensorType],
|
| 943 |
+
) -> TensorType:
|
| 944 |
+
"""Computes the loss for a single module.
|
| 945 |
+
|
| 946 |
+
Think of this as computing loss for a single agent. For multi-agent use-cases
|
| 947 |
+
that require more complicated computation for loss, consider overriding the
|
| 948 |
+
`compute_losses` method instead.
|
| 949 |
+
|
| 950 |
+
Args:
|
| 951 |
+
module_id: The id of the module.
|
| 952 |
+
config: The AlgorithmConfig specific to the given `module_id`.
|
| 953 |
+
batch: The train batch for this particular module.
|
| 954 |
+
fwd_out: The output of the forward pass for this particular module.
|
| 955 |
+
|
| 956 |
+
Returns:
|
| 957 |
+
A single total loss tensor. If you have more than one optimizer on the
|
| 958 |
+
provided `module_id` and would like to compute gradients separately using
|
| 959 |
+
these different optimizers, simply add up the individual loss terms for
|
| 960 |
+
each optimizer and return the sum. Also, for recording/logging any
|
| 961 |
+
individual loss terms, you can use the `Learner.metrics.log_value(
|
| 962 |
+
key=..., value=...)` or `Learner.metrics.log_dict()` APIs. See:
|
| 963 |
+
:py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` for more
|
| 964 |
+
information.
|
| 965 |
+
"""
|
| 966 |
+
|
| 967 |
+
def update_from_batch(
|
| 968 |
+
self,
|
| 969 |
+
batch: MultiAgentBatch,
|
| 970 |
+
*,
|
| 971 |
+
# TODO (sven): Make this a more formal structure with its own type.
|
| 972 |
+
timesteps: Optional[Dict[str, Any]] = None,
|
| 973 |
+
num_epochs: int = 1,
|
| 974 |
+
minibatch_size: Optional[int] = None,
|
| 975 |
+
shuffle_batch_per_epoch: bool = False,
|
| 976 |
+
# Deprecated args.
|
| 977 |
+
num_iters=DEPRECATED_VALUE,
|
| 978 |
+
**kwargs,
|
| 979 |
+
) -> ResultDict:
|
| 980 |
+
"""Run `num_epochs` epochs over the given train batch.
|
| 981 |
+
|
| 982 |
+
You can use this method to take more than one backward pass on the batch.
|
| 983 |
+
The same `minibatch_size` and `num_epochs` will be used for all module ids in
|
| 984 |
+
MultiRLModule.
|
| 985 |
+
|
| 986 |
+
Args:
|
| 987 |
+
batch: A batch of training data to update from.
|
| 988 |
+
timesteps: Timesteps dict, which must have the key
|
| 989 |
+
`NUM_ENV_STEPS_SAMPLED_LIFETIME`.
|
| 990 |
+
# TODO (sven): Make this a more formal structure with its own type.
|
| 991 |
+
num_epochs: The number of complete passes over the entire train batch. Each
|
| 992 |
+
pass might be further split into n minibatches (if `minibatch_size`
|
| 993 |
+
provided).
|
| 994 |
+
minibatch_size: The size of minibatches to use to further split the train
|
| 995 |
+
`batch` into sub-batches. The `batch` is then iterated over n times
|
| 996 |
+
where n is `len(batch) // minibatch_size`.
|
| 997 |
+
shuffle_batch_per_epoch: Whether to shuffle the train batch once per epoch.
|
| 998 |
+
If the train batch has a time rank (axis=1), shuffling will only take
|
| 999 |
+
place along the batch axis to not disturb any intact (episode)
|
| 1000 |
+
trajectories. Also, shuffling is always skipped if `minibatch_size` is
|
| 1001 |
+
None, meaning the entire train batch is processed each epoch, making it
|
| 1002 |
+
unnecessary to shuffle.
|
| 1003 |
+
|
| 1004 |
+
Returns:
|
| 1005 |
+
A `ResultDict` object produced by a call to `self.metrics.reduce()`. The
|
| 1006 |
+
returned dict may be arbitrarily nested and must have `Stats` objects at
|
| 1007 |
+
all its leafs, allowing components further downstream (i.e. a user of this
|
| 1008 |
+
Learner) to further reduce these results (for example over n parallel
|
| 1009 |
+
Learners).
|
| 1010 |
+
"""
|
| 1011 |
+
if num_iters != DEPRECATED_VALUE:
|
| 1012 |
+
deprecation_warning(
|
| 1013 |
+
old="Learner.update_from_episodes(num_iters=...)",
|
| 1014 |
+
new="Learner.update_from_episodes(num_epochs=...)",
|
| 1015 |
+
error=True,
|
| 1016 |
+
)
|
| 1017 |
+
self._update_from_batch_or_episodes(
|
| 1018 |
+
batch=batch,
|
| 1019 |
+
timesteps=timesteps,
|
| 1020 |
+
num_epochs=num_epochs,
|
| 1021 |
+
minibatch_size=minibatch_size,
|
| 1022 |
+
shuffle_batch_per_epoch=shuffle_batch_per_epoch,
|
| 1023 |
+
)
|
| 1024 |
+
return self.metrics.reduce()
|
| 1025 |
+
|
| 1026 |
+
def update_from_episodes(
|
| 1027 |
+
self,
|
| 1028 |
+
episodes: List[EpisodeType],
|
| 1029 |
+
*,
|
| 1030 |
+
# TODO (sven): Make this a more formal structure with its own type.
|
| 1031 |
+
timesteps: Optional[Dict[str, Any]] = None,
|
| 1032 |
+
num_epochs: int = 1,
|
| 1033 |
+
minibatch_size: Optional[int] = None,
|
| 1034 |
+
shuffle_batch_per_epoch: bool = False,
|
| 1035 |
+
num_total_minibatches: int = 0,
|
| 1036 |
+
# Deprecated args.
|
| 1037 |
+
num_iters=DEPRECATED_VALUE,
|
| 1038 |
+
) -> ResultDict:
|
| 1039 |
+
"""Run `num_epochs` epochs over the train batch generated from `episodes`.
|
| 1040 |
+
|
| 1041 |
+
You can use this method to take more than one backward pass on the batch.
|
| 1042 |
+
The same `minibatch_size` and `num_epochs` will be used for all module ids in
|
| 1043 |
+
MultiRLModule.
|
| 1044 |
+
|
| 1045 |
+
Args:
|
| 1046 |
+
episodes: An list of episode objects to update from.
|
| 1047 |
+
timesteps: Timesteps dict, which must have the key
|
| 1048 |
+
`NUM_ENV_STEPS_SAMPLED_LIFETIME`.
|
| 1049 |
+
# TODO (sven): Make this a more formal structure with its own type.
|
| 1050 |
+
num_epochs: The number of complete passes over the entire train batch. Each
|
| 1051 |
+
pass might be further split into n minibatches (if `minibatch_size`
|
| 1052 |
+
provided). The train batch is generated from the given `episodes`
|
| 1053 |
+
through the Learner connector pipeline.
|
| 1054 |
+
minibatch_size: The size of minibatches to use to further split the train
|
| 1055 |
+
`batch` into sub-batches. The `batch` is then iterated over n times
|
| 1056 |
+
where n is `len(batch) // minibatch_size`. The train batch is generated
|
| 1057 |
+
from the given `episodes` through the Learner connector pipeline.
|
| 1058 |
+
shuffle_batch_per_epoch: Whether to shuffle the train batch once per epoch.
|
| 1059 |
+
If the train batch has a time rank (axis=1), shuffling will only take
|
| 1060 |
+
place along the batch axis to not disturb any intact (episode)
|
| 1061 |
+
trajectories. Also, shuffling is always skipped if `minibatch_size` is
|
| 1062 |
+
None, meaning the entire train batch is processed each epoch, making it
|
| 1063 |
+
unnecessary to shuffle. The train batch is generated from the given
|
| 1064 |
+
`episodes` through the Learner connector pipeline.
|
| 1065 |
+
num_total_minibatches: The total number of minibatches to loop through
|
| 1066 |
+
(over all `num_epochs` epochs). It's only required to set this to != 0
|
| 1067 |
+
in multi-agent + multi-GPU situations, in which the MultiAgentEpisodes
|
| 1068 |
+
themselves are roughly sharded equally, however, they might contain
|
| 1069 |
+
SingleAgentEpisodes with very lopsided length distributions. Thus,
|
| 1070 |
+
without this fixed, pre-computed value, one Learner might go through a
|
| 1071 |
+
different number of minibatche passes than others causing a deadlock.
|
| 1072 |
+
|
| 1073 |
+
Returns:
|
| 1074 |
+
A `ResultDict` object produced by a call to `self.metrics.reduce()`. The
|
| 1075 |
+
returned dict may be arbitrarily nested and must have `Stats` objects at
|
| 1076 |
+
all its leafs, allowing components further downstream (i.e. a user of this
|
| 1077 |
+
Learner) to further reduce these results (for example over n parallel
|
| 1078 |
+
Learners).
|
| 1079 |
+
"""
|
| 1080 |
+
if num_iters != DEPRECATED_VALUE:
|
| 1081 |
+
deprecation_warning(
|
| 1082 |
+
old="Learner.update_from_episodes(num_iters=...)",
|
| 1083 |
+
new="Learner.update_from_episodes(num_epochs=...)",
|
| 1084 |
+
error=True,
|
| 1085 |
+
)
|
| 1086 |
+
self._update_from_batch_or_episodes(
|
| 1087 |
+
episodes=episodes,
|
| 1088 |
+
timesteps=timesteps,
|
| 1089 |
+
num_epochs=num_epochs,
|
| 1090 |
+
minibatch_size=minibatch_size,
|
| 1091 |
+
shuffle_batch_per_epoch=shuffle_batch_per_epoch,
|
| 1092 |
+
num_total_minibatches=num_total_minibatches,
|
| 1093 |
+
)
|
| 1094 |
+
return self.metrics.reduce()
|
| 1095 |
+
|
| 1096 |
+
def update_from_iterator(
|
| 1097 |
+
self,
|
| 1098 |
+
iterator,
|
| 1099 |
+
*,
|
| 1100 |
+
timesteps: Optional[Dict[str, Any]] = None,
|
| 1101 |
+
minibatch_size: Optional[int] = None,
|
| 1102 |
+
num_iters: int = None,
|
| 1103 |
+
**kwargs,
|
| 1104 |
+
):
|
| 1105 |
+
if "num_epochs" in kwargs:
|
| 1106 |
+
raise ValueError(
|
| 1107 |
+
"`num_epochs` arg NOT supported by Learner.update_from_iterator! Use "
|
| 1108 |
+
"`num_iters` instead."
|
| 1109 |
+
)
|
| 1110 |
+
|
| 1111 |
+
if not self.iterator:
|
| 1112 |
+
self.iterator = iterator
|
| 1113 |
+
|
| 1114 |
+
self._check_is_built()
|
| 1115 |
+
|
| 1116 |
+
# Call `before_gradient_based_update` to allow for non-gradient based
|
| 1117 |
+
# preparations-, logging-, and update logic to happen.
|
| 1118 |
+
self.before_gradient_based_update(timesteps=timesteps or {})
|
| 1119 |
+
|
| 1120 |
+
def _finalize_fn(batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]:
|
| 1121 |
+
# Note, the incoming batch is a dictionary with a numpy array
|
| 1122 |
+
# holding the `MultiAgentBatch`.
|
| 1123 |
+
batch = self._convert_batch_type(batch["batch"][0])
|
| 1124 |
+
return {"batch": self._set_slicing_by_batch_id(batch, value=True)}
|
| 1125 |
+
|
| 1126 |
+
i = 0
|
| 1127 |
+
logger.debug(f"===> [Learner {id(self)}]: Looping through batches ... ")
|
| 1128 |
+
for batch in self.iterator.iter_batches(
|
| 1129 |
+
# Note, this needs to be one b/c data is already mapped to
|
| 1130 |
+
# `MultiAgentBatch`es of `minibatch_size`.
|
| 1131 |
+
batch_size=1,
|
| 1132 |
+
_finalize_fn=_finalize_fn,
|
| 1133 |
+
**kwargs,
|
| 1134 |
+
):
|
| 1135 |
+
# Update the iteration counter.
|
| 1136 |
+
i += 1
|
| 1137 |
+
|
| 1138 |
+
# Note, `_finalize_fn` must return a dictionary.
|
| 1139 |
+
batch = batch["batch"]
|
| 1140 |
+
logger.debug(
|
| 1141 |
+
f"===> [Learner {id(self)}]: batch {i} with {batch.env_steps()} rows."
|
| 1142 |
+
)
|
| 1143 |
+
# Check the MultiAgentBatch, whether our RLModule contains all ModuleIDs
|
| 1144 |
+
# found in this batch. If not, throw an error.
|
| 1145 |
+
unknown_module_ids = set(batch.policy_batches.keys()) - set(
|
| 1146 |
+
self.module.keys()
|
| 1147 |
+
)
|
| 1148 |
+
if len(unknown_module_ids) > 0:
|
| 1149 |
+
raise ValueError(
|
| 1150 |
+
"Batch contains one or more ModuleIDs that are not in this "
|
| 1151 |
+
f"Learner! Found IDs: {unknown_module_ids}"
|
| 1152 |
+
)
|
| 1153 |
+
|
| 1154 |
+
# Log metrics.
|
| 1155 |
+
self._log_steps_trained_metrics(batch)
|
| 1156 |
+
|
| 1157 |
+
# Make the actual in-graph/traced `_update` call. This should return
|
| 1158 |
+
# all tensor values (no numpy).
|
| 1159 |
+
fwd_out, loss_per_module, tensor_metrics = self._update(
|
| 1160 |
+
batch.policy_batches
|
| 1161 |
+
)
|
| 1162 |
+
# Convert logged tensor metrics (logged during tensor-mode of MetricsLogger)
|
| 1163 |
+
# to actual (numpy) values.
|
| 1164 |
+
self.metrics.tensors_to_numpy(tensor_metrics)
|
| 1165 |
+
|
| 1166 |
+
self._set_slicing_by_batch_id(batch, value=False)
|
| 1167 |
+
# If `num_iters` is reached break and return.
|
| 1168 |
+
if num_iters and i == num_iters:
|
| 1169 |
+
break
|
| 1170 |
+
|
| 1171 |
+
logger.debug(
|
| 1172 |
+
f"===> [Learner {id(self)}] number of iterations run in this epoch: {i}"
|
| 1173 |
+
)
|
| 1174 |
+
|
| 1175 |
+
# Log all individual RLModules' loss terms and its registered optimizers'
|
| 1176 |
+
# current learning rates.
|
| 1177 |
+
for mid, loss in convert_to_numpy(loss_per_module).items():
|
| 1178 |
+
self.metrics.log_value(
|
| 1179 |
+
key=(mid, self.TOTAL_LOSS_KEY),
|
| 1180 |
+
value=loss,
|
| 1181 |
+
window=1,
|
| 1182 |
+
)
|
| 1183 |
+
# Call `after_gradient_based_update` to allow for non-gradient based
|
| 1184 |
+
# cleanups-, logging-, and update logic to happen.
|
| 1185 |
+
# TODO (simon): Check, if this should stay here, when running multiple
|
| 1186 |
+
# gradient steps inside the iterator loop above (could be a complete epoch)
|
| 1187 |
+
# the target networks might need to be updated earlier.
|
| 1188 |
+
self.after_gradient_based_update(timesteps=timesteps or {})
|
| 1189 |
+
|
| 1190 |
+
# Reduce results across all minibatch update steps.
|
| 1191 |
+
return self.metrics.reduce()
|
| 1192 |
+
|
| 1193 |
+
@OverrideToImplementCustomLogic
|
| 1194 |
+
@abc.abstractmethod
|
| 1195 |
+
def _update(
|
| 1196 |
+
self,
|
| 1197 |
+
batch: Dict[str, Any],
|
| 1198 |
+
**kwargs,
|
| 1199 |
+
) -> Tuple[Any, Any, Any]:
|
| 1200 |
+
"""Contains all logic for an in-graph/traceable update step.
|
| 1201 |
+
|
| 1202 |
+
Framework specific subclasses must implement this method. This should include
|
| 1203 |
+
calls to the RLModule's `forward_train`, `compute_loss`, compute_gradients`,
|
| 1204 |
+
`postprocess_gradients`, and `apply_gradients` methods and return a tuple
|
| 1205 |
+
with all the individual results.
|
| 1206 |
+
|
| 1207 |
+
Args:
|
| 1208 |
+
batch: The train batch already converted to a Dict mapping str to (possibly
|
| 1209 |
+
nested) tensors.
|
| 1210 |
+
kwargs: Forward compatibility kwargs.
|
| 1211 |
+
|
| 1212 |
+
Returns:
|
| 1213 |
+
A tuple consisting of:
|
| 1214 |
+
1) The `forward_train()` output of the RLModule,
|
| 1215 |
+
2) the loss_per_module dictionary mapping module IDs to individual loss
|
| 1216 |
+
tensors
|
| 1217 |
+
3) a metrics dict mapping module IDs to metrics key/value pairs.
|
| 1218 |
+
|
| 1219 |
+
"""
|
| 1220 |
+
|
| 1221 |
+
@override(Checkpointable)
|
| 1222 |
+
def get_state(
|
| 1223 |
+
self,
|
| 1224 |
+
components: Optional[Union[str, Collection[str]]] = None,
|
| 1225 |
+
*,
|
| 1226 |
+
not_components: Optional[Union[str, Collection[str]]] = None,
|
| 1227 |
+
**kwargs,
|
| 1228 |
+
) -> StateDict:
|
| 1229 |
+
self._check_is_built()
|
| 1230 |
+
|
| 1231 |
+
state = {
|
| 1232 |
+
"should_module_be_updated": self.config.policies_to_train,
|
| 1233 |
+
}
|
| 1234 |
+
|
| 1235 |
+
if self._check_component(COMPONENT_RL_MODULE, components, not_components):
|
| 1236 |
+
state[COMPONENT_RL_MODULE] = self.module.get_state(
|
| 1237 |
+
components=self._get_subcomponents(COMPONENT_RL_MODULE, components),
|
| 1238 |
+
not_components=self._get_subcomponents(
|
| 1239 |
+
COMPONENT_RL_MODULE, not_components
|
| 1240 |
+
),
|
| 1241 |
+
**kwargs,
|
| 1242 |
+
)
|
| 1243 |
+
state[WEIGHTS_SEQ_NO] = self._weights_seq_no
|
| 1244 |
+
if self._check_component(COMPONENT_OPTIMIZER, components, not_components):
|
| 1245 |
+
state[COMPONENT_OPTIMIZER] = self._get_optimizer_state()
|
| 1246 |
+
|
| 1247 |
+
if self._check_component(COMPONENT_METRICS_LOGGER, components, not_components):
|
| 1248 |
+
# TODO (sven): Make `MetricsLogger` a Checkpointable.
|
| 1249 |
+
state[COMPONENT_METRICS_LOGGER] = self.metrics.get_state()
|
| 1250 |
+
|
| 1251 |
+
return state
|
| 1252 |
+
|
| 1253 |
+
@override(Checkpointable)
|
| 1254 |
+
def set_state(self, state: StateDict) -> None:
|
| 1255 |
+
self._check_is_built()
|
| 1256 |
+
|
| 1257 |
+
weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0)
|
| 1258 |
+
|
| 1259 |
+
if COMPONENT_RL_MODULE in state:
|
| 1260 |
+
if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no:
|
| 1261 |
+
self.module.set_state(state[COMPONENT_RL_MODULE])
|
| 1262 |
+
|
| 1263 |
+
if COMPONENT_OPTIMIZER in state:
|
| 1264 |
+
self._set_optimizer_state(state[COMPONENT_OPTIMIZER])
|
| 1265 |
+
|
| 1266 |
+
# Update our weights_seq_no, if the new one is > 0.
|
| 1267 |
+
if weights_seq_no > 0:
|
| 1268 |
+
self._weights_seq_no = weights_seq_no
|
| 1269 |
+
|
| 1270 |
+
# Update our trainable Modules information/function via our config.
|
| 1271 |
+
# If not provided in state (None), all Modules will be trained by default.
|
| 1272 |
+
if "should_module_be_updated" in state:
|
| 1273 |
+
self.config.multi_agent(policies_to_train=state["should_module_be_updated"])
|
| 1274 |
+
|
| 1275 |
+
# TODO (sven): Make `MetricsLogger` a Checkpointable.
|
| 1276 |
+
if COMPONENT_METRICS_LOGGER in state:
|
| 1277 |
+
self.metrics.set_state(state[COMPONENT_METRICS_LOGGER])
|
| 1278 |
+
|
| 1279 |
+
@override(Checkpointable)
|
| 1280 |
+
def get_ctor_args_and_kwargs(self):
|
| 1281 |
+
return (
|
| 1282 |
+
(), # *args,
|
| 1283 |
+
{
|
| 1284 |
+
"config": self.config,
|
| 1285 |
+
"module_spec": self._module_spec,
|
| 1286 |
+
"module": self._module_obj,
|
| 1287 |
+
}, # **kwargs
|
| 1288 |
+
)
|
| 1289 |
+
|
| 1290 |
+
@override(Checkpointable)
|
| 1291 |
+
def get_checkpointable_components(self):
|
| 1292 |
+
if not self._check_is_built(error=False):
|
| 1293 |
+
self.build()
|
| 1294 |
+
return [
|
| 1295 |
+
(COMPONENT_RL_MODULE, self.module),
|
| 1296 |
+
]
|
| 1297 |
+
|
| 1298 |
+
def _get_optimizer_state(self) -> StateDict:
|
| 1299 |
+
"""Returns the state of all optimizers currently registered in this Learner.
|
| 1300 |
+
|
| 1301 |
+
Returns:
|
| 1302 |
+
The current state of all optimizers currently registered in this Learner.
|
| 1303 |
+
"""
|
| 1304 |
+
raise NotImplementedError
|
| 1305 |
+
|
| 1306 |
+
def _set_optimizer_state(self, state: StateDict) -> None:
|
| 1307 |
+
"""Sets the state of all optimizers currently registered in this Learner.
|
| 1308 |
+
|
| 1309 |
+
Args:
|
| 1310 |
+
state: The state of the optimizers.
|
| 1311 |
+
"""
|
| 1312 |
+
raise NotImplementedError
|
| 1313 |
+
|
| 1314 |
+
def _update_from_batch_or_episodes(
|
| 1315 |
+
self,
|
| 1316 |
+
*,
|
| 1317 |
+
# TODO (sven): We should allow passing in a single agent batch here
|
| 1318 |
+
# as well for simplicity.
|
| 1319 |
+
batch: Optional[MultiAgentBatch] = None,
|
| 1320 |
+
episodes: Optional[List[EpisodeType]] = None,
|
| 1321 |
+
# TODO (sven): Make this a more formal structure with its own type.
|
| 1322 |
+
timesteps: Optional[Dict[str, Any]] = None,
|
| 1323 |
+
# TODO (sven): Deprecate these in favor of config attributes for only those
|
| 1324 |
+
# algos that actually need (and know how) to do minibatching.
|
| 1325 |
+
num_epochs: int = 1,
|
| 1326 |
+
minibatch_size: Optional[int] = None,
|
| 1327 |
+
shuffle_batch_per_epoch: bool = False,
|
| 1328 |
+
num_total_minibatches: int = 0,
|
| 1329 |
+
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
|
| 1330 |
+
|
| 1331 |
+
self._check_is_built()
|
| 1332 |
+
|
| 1333 |
+
# Call `before_gradient_based_update` to allow for non-gradient based
|
| 1334 |
+
# preparations-, logging-, and update logic to happen.
|
| 1335 |
+
self.before_gradient_based_update(timesteps=timesteps or {})
|
| 1336 |
+
|
| 1337 |
+
# Resolve batch/episodes being ray object refs (instead of
|
| 1338 |
+
# actual batch/episodes objects).
|
| 1339 |
+
if isinstance(batch, ray.ObjectRef):
|
| 1340 |
+
batch = ray.get(batch)
|
| 1341 |
+
if isinstance(episodes, ray.ObjectRef):
|
| 1342 |
+
episodes = ray.get(episodes)
|
| 1343 |
+
elif isinstance(episodes, list) and isinstance(episodes[0], ray.ObjectRef):
|
| 1344 |
+
# It's possible that individual refs are invalid due to the EnvRunner
|
| 1345 |
+
# that produced the ref has crashed or had its entire node go down.
|
| 1346 |
+
# In this case, try each ref individually and collect only valid results.
|
| 1347 |
+
try:
|
| 1348 |
+
episodes = tree.flatten(ray.get(episodes))
|
| 1349 |
+
except ray.exceptions.OwnerDiedError:
|
| 1350 |
+
episode_refs = episodes
|
| 1351 |
+
episodes = []
|
| 1352 |
+
for ref in episode_refs:
|
| 1353 |
+
try:
|
| 1354 |
+
episodes.extend(ray.get(ref))
|
| 1355 |
+
except ray.exceptions.OwnerDiedError:
|
| 1356 |
+
pass
|
| 1357 |
+
|
| 1358 |
+
# Call the learner connector on the given `episodes` (if we have one).
|
| 1359 |
+
if episodes is not None and self._learner_connector is not None:
|
| 1360 |
+
# Call the learner connector pipeline.
|
| 1361 |
+
shared_data = {}
|
| 1362 |
+
batch = self._learner_connector(
|
| 1363 |
+
rl_module=self.module,
|
| 1364 |
+
batch=batch if batch is not None else {},
|
| 1365 |
+
episodes=episodes,
|
| 1366 |
+
shared_data=shared_data,
|
| 1367 |
+
metrics=self.metrics,
|
| 1368 |
+
)
|
| 1369 |
+
# Convert to a batch.
|
| 1370 |
+
# TODO (sven): Try to not require MultiAgentBatch anymore.
|
| 1371 |
+
batch = MultiAgentBatch(
|
| 1372 |
+
{
|
| 1373 |
+
module_id: (
|
| 1374 |
+
SampleBatch(module_data, _zero_padded=True)
|
| 1375 |
+
if shared_data.get(f"_zero_padded_for_mid={module_id}")
|
| 1376 |
+
else SampleBatch(module_data)
|
| 1377 |
+
)
|
| 1378 |
+
for module_id, module_data in batch.items()
|
| 1379 |
+
},
|
| 1380 |
+
env_steps=sum(len(e) for e in episodes),
|
| 1381 |
+
)
|
| 1382 |
+
# Single-agent SampleBatch: Have to convert to MultiAgentBatch.
|
| 1383 |
+
elif isinstance(batch, SampleBatch):
|
| 1384 |
+
assert len(self.module) == 1
|
| 1385 |
+
batch = MultiAgentBatch(
|
| 1386 |
+
{next(iter(self.module.keys())): batch}, env_steps=len(batch)
|
| 1387 |
+
)
|
| 1388 |
+
|
| 1389 |
+
# Check the MultiAgentBatch, whether our RLModule contains all ModuleIDs
|
| 1390 |
+
# found in this batch. If not, throw an error.
|
| 1391 |
+
unknown_module_ids = set(batch.policy_batches.keys()) - set(self.module.keys())
|
| 1392 |
+
if len(unknown_module_ids) > 0:
|
| 1393 |
+
raise ValueError(
|
| 1394 |
+
"Batch contains one or more ModuleIDs that are not in this Learner! "
|
| 1395 |
+
f"Found IDs: {unknown_module_ids}"
|
| 1396 |
+
)
|
| 1397 |
+
|
| 1398 |
+
# TODO: Move this into LearnerConnector pipeline?
|
| 1399 |
+
# Filter out those RLModules from the final train batch that should not be
|
| 1400 |
+
# updated.
|
| 1401 |
+
for module_id in list(batch.policy_batches.keys()):
|
| 1402 |
+
if not self.should_module_be_updated(module_id, batch):
|
| 1403 |
+
del batch.policy_batches[module_id]
|
| 1404 |
+
|
| 1405 |
+
# Log all timesteps (env, agent, modules) based on given episodes/batch.
|
| 1406 |
+
self._log_steps_trained_metrics(batch)
|
| 1407 |
+
|
| 1408 |
+
if minibatch_size:
|
| 1409 |
+
batch_iter = MiniBatchCyclicIterator
|
| 1410 |
+
elif num_epochs > 1:
|
| 1411 |
+
# `minibatch_size` was not set but `num_epochs` > 1.
|
| 1412 |
+
# Under the old training stack, users could do multiple epochs
|
| 1413 |
+
# over a batch without specifying a minibatch size. We enable
|
| 1414 |
+
# this behavior here by setting the minibatch size to be the size
|
| 1415 |
+
# of the batch (e.g. 1 minibatch of size batch.count)
|
| 1416 |
+
minibatch_size = batch.count
|
| 1417 |
+
# Note that there is no need to shuffle here, b/c we don't have minibatches.
|
| 1418 |
+
batch_iter = MiniBatchCyclicIterator
|
| 1419 |
+
else:
|
| 1420 |
+
# `minibatch_size` and `num_epochs` are not set by the user.
|
| 1421 |
+
batch_iter = MiniBatchDummyIterator
|
| 1422 |
+
|
| 1423 |
+
batch = self._set_slicing_by_batch_id(batch, value=True)
|
| 1424 |
+
|
| 1425 |
+
for tensor_minibatch in batch_iter(
|
| 1426 |
+
batch,
|
| 1427 |
+
num_epochs=num_epochs,
|
| 1428 |
+
minibatch_size=minibatch_size,
|
| 1429 |
+
shuffle_batch_per_epoch=shuffle_batch_per_epoch and (num_epochs > 1),
|
| 1430 |
+
num_total_minibatches=num_total_minibatches,
|
| 1431 |
+
):
|
| 1432 |
+
# Make the actual in-graph/traced `_update` call. This should return
|
| 1433 |
+
# all tensor values (no numpy).
|
| 1434 |
+
fwd_out, loss_per_module, tensor_metrics = self._update(
|
| 1435 |
+
tensor_minibatch.policy_batches
|
| 1436 |
+
)
|
| 1437 |
+
|
| 1438 |
+
# Convert logged tensor metrics (logged during tensor-mode of MetricsLogger)
|
| 1439 |
+
# to actual (numpy) values.
|
| 1440 |
+
self.metrics.tensors_to_numpy(tensor_metrics)
|
| 1441 |
+
|
| 1442 |
+
# Log all individual RLModules' loss terms and its registered optimizers'
|
| 1443 |
+
# current learning rates.
|
| 1444 |
+
for mid, loss in convert_to_numpy(loss_per_module).items():
|
| 1445 |
+
self.metrics.log_value(
|
| 1446 |
+
key=(mid, self.TOTAL_LOSS_KEY),
|
| 1447 |
+
value=loss,
|
| 1448 |
+
window=1,
|
| 1449 |
+
)
|
| 1450 |
+
|
| 1451 |
+
self._weights_seq_no += 1
|
| 1452 |
+
self.metrics.log_dict(
|
| 1453 |
+
{
|
| 1454 |
+
(mid, WEIGHTS_SEQ_NO): self._weights_seq_no
|
| 1455 |
+
for mid in batch.policy_batches.keys()
|
| 1456 |
+
},
|
| 1457 |
+
window=1,
|
| 1458 |
+
)
|
| 1459 |
+
|
| 1460 |
+
self._set_slicing_by_batch_id(batch, value=False)
|
| 1461 |
+
|
| 1462 |
+
# Call `after_gradient_based_update` to allow for non-gradient based
|
| 1463 |
+
# cleanups-, logging-, and update logic to happen.
|
| 1464 |
+
self.after_gradient_based_update(timesteps=timesteps or {})
|
| 1465 |
+
|
| 1466 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
| 1467 |
+
def before_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
|
| 1468 |
+
"""Called before gradient-based updates are completed.
|
| 1469 |
+
|
| 1470 |
+
Should be overridden to implement custom preparation-, logging-, or
|
| 1471 |
+
non-gradient-based Learner/RLModule update logic before(!) gradient-based
|
| 1472 |
+
updates are performed.
|
| 1473 |
+
|
| 1474 |
+
Args:
|
| 1475 |
+
timesteps: Timesteps dict, which must have the key
|
| 1476 |
+
`NUM_ENV_STEPS_SAMPLED_LIFETIME`.
|
| 1477 |
+
# TODO (sven): Make this a more formal structure with its own type.
|
| 1478 |
+
"""
|
| 1479 |
+
|
| 1480 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
| 1481 |
+
def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
|
| 1482 |
+
"""Called after gradient-based updates are completed.
|
| 1483 |
+
|
| 1484 |
+
Should be overridden to implement custom cleanup-, logging-, or non-gradient-
|
| 1485 |
+
based Learner/RLModule update logic after(!) gradient-based updates have been
|
| 1486 |
+
completed.
|
| 1487 |
+
|
| 1488 |
+
Args:
|
| 1489 |
+
timesteps: Timesteps dict, which must have the key
|
| 1490 |
+
`NUM_ENV_STEPS_SAMPLED_LIFETIME`.
|
| 1491 |
+
# TODO (sven): Make this a more formal structure with its own type.
|
| 1492 |
+
"""
|
| 1493 |
+
# Only update this optimizer's lr, if a scheduler has been registered
|
| 1494 |
+
# along with it.
|
| 1495 |
+
for module_id, optimizer_names in self._module_optimizers.items():
|
| 1496 |
+
for optimizer_name in optimizer_names:
|
| 1497 |
+
optimizer = self._named_optimizers[optimizer_name]
|
| 1498 |
+
# Update and log learning rate of this optimizer.
|
| 1499 |
+
lr_schedule = self._optimizer_lr_schedules.get(optimizer)
|
| 1500 |
+
if lr_schedule is not None:
|
| 1501 |
+
new_lr = lr_schedule.update(
|
| 1502 |
+
timestep=timesteps.get(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0)
|
| 1503 |
+
)
|
| 1504 |
+
self._set_optimizer_lr(optimizer, lr=new_lr)
|
| 1505 |
+
self.metrics.log_value(
|
| 1506 |
+
# Cut out the module ID from the beginning since it's already part
|
| 1507 |
+
# of the key sequence: (ModuleID, "[optim name]_lr").
|
| 1508 |
+
key=(module_id, f"{optimizer_name[len(module_id) + 1:]}_{LR_KEY}"),
|
| 1509 |
+
value=convert_to_numpy(self._get_optimizer_lr(optimizer)),
|
| 1510 |
+
window=1,
|
| 1511 |
+
)
|
| 1512 |
+
|
| 1513 |
+
def _set_slicing_by_batch_id(
|
| 1514 |
+
self, batch: MultiAgentBatch, *, value: bool
|
| 1515 |
+
) -> MultiAgentBatch:
|
| 1516 |
+
"""Enables slicing by batch id in the given batch.
|
| 1517 |
+
|
| 1518 |
+
If the input batch contains batches of sequences we need to make sure when
|
| 1519 |
+
slicing happens it is sliced via batch id and not timestamp. Calling this
|
| 1520 |
+
method enables the same flag on each SampleBatch within the input
|
| 1521 |
+
MultiAgentBatch.
|
| 1522 |
+
|
| 1523 |
+
Args:
|
| 1524 |
+
batch: The MultiAgentBatch to enable slicing by batch id on.
|
| 1525 |
+
value: The value to set the flag to.
|
| 1526 |
+
|
| 1527 |
+
Returns:
|
| 1528 |
+
The input MultiAgentBatch with the indexing flag is enabled / disabled on.
|
| 1529 |
+
"""
|
| 1530 |
+
|
| 1531 |
+
for pid, policy_batch in batch.policy_batches.items():
|
| 1532 |
+
# We assume that arriving batches for recurrent modules OR batches that
|
| 1533 |
+
# have a SEQ_LENS column are already zero-padded to the max sequence length
|
| 1534 |
+
# and have tensors of shape [B, T, ...]. Therefore, we slice sequence
|
| 1535 |
+
# lengths in B. See SampleBatch for more information.
|
| 1536 |
+
if (
|
| 1537 |
+
self.module[pid].is_stateful()
|
| 1538 |
+
or policy_batch.get("seq_lens") is not None
|
| 1539 |
+
):
|
| 1540 |
+
if value:
|
| 1541 |
+
policy_batch.enable_slicing_by_batch_id()
|
| 1542 |
+
else:
|
| 1543 |
+
policy_batch.disable_slicing_by_batch_id()
|
| 1544 |
+
|
| 1545 |
+
return batch
|
| 1546 |
+
|
| 1547 |
+
def _make_module(self) -> MultiRLModule:
|
| 1548 |
+
"""Construct the multi-agent RL module for the learner.
|
| 1549 |
+
|
| 1550 |
+
This method uses `self._module_specs` or `self._module_obj` to construct the
|
| 1551 |
+
module. If the module_class is a single agent RL module it will be wrapped to a
|
| 1552 |
+
multi-agent RL module. Override this method if there are other things that
|
| 1553 |
+
need to happen for instantiation of the module.
|
| 1554 |
+
|
| 1555 |
+
Returns:
|
| 1556 |
+
A constructed MultiRLModule.
|
| 1557 |
+
"""
|
| 1558 |
+
# Module was provided directly through constructor -> Use as-is.
|
| 1559 |
+
if self._module_obj is not None:
|
| 1560 |
+
module = self._module_obj
|
| 1561 |
+
self._module_spec = MultiRLModuleSpec.from_module(module)
|
| 1562 |
+
# RLModuleSpec was provided directly through constructor -> Use it to build the
|
| 1563 |
+
# RLModule.
|
| 1564 |
+
elif self._module_spec is not None:
|
| 1565 |
+
module = self._module_spec.build()
|
| 1566 |
+
# Try using our config object. Note that this would only work if the config
|
| 1567 |
+
# object has all the necessary space information already in it.
|
| 1568 |
+
else:
|
| 1569 |
+
module = self.config.get_multi_rl_module_spec().build()
|
| 1570 |
+
|
| 1571 |
+
# If not already, convert to MultiRLModule.
|
| 1572 |
+
module = module.as_multi_rl_module()
|
| 1573 |
+
|
| 1574 |
+
return module
|
| 1575 |
+
|
| 1576 |
+
def rl_module_is_compatible(self, module: RLModule) -> bool:
|
| 1577 |
+
"""Check whether the given `module` is compatible with this Learner.
|
| 1578 |
+
|
| 1579 |
+
The default implementation checks the Learner-required APIs and whether the
|
| 1580 |
+
given `module` implements all of them (if not, returns False).
|
| 1581 |
+
|
| 1582 |
+
Args:
|
| 1583 |
+
module: The RLModule to check.
|
| 1584 |
+
|
| 1585 |
+
Returns:
|
| 1586 |
+
True if the module is compatible with this Learner.
|
| 1587 |
+
"""
|
| 1588 |
+
return all(isinstance(module, api) for api in self.rl_module_required_apis())
|
| 1589 |
+
|
| 1590 |
+
@classmethod
|
| 1591 |
+
def rl_module_required_apis(cls) -> list[type]:
|
| 1592 |
+
"""Returns the required APIs for an RLModule to be compatible with this Learner.
|
| 1593 |
+
|
| 1594 |
+
The returned values may or may not be used inside the `rl_module_is_compatible`
|
| 1595 |
+
method.
|
| 1596 |
+
|
| 1597 |
+
Args:
|
| 1598 |
+
module: The RLModule to check.
|
| 1599 |
+
|
| 1600 |
+
Returns:
|
| 1601 |
+
A list of RLModule API classes that an RLModule must implement in order
|
| 1602 |
+
to be compatible with this Learner.
|
| 1603 |
+
"""
|
| 1604 |
+
return []
|
| 1605 |
+
|
| 1606 |
+
def _check_registered_optimizer(
|
| 1607 |
+
self,
|
| 1608 |
+
optimizer: Optimizer,
|
| 1609 |
+
params: Sequence[Param],
|
| 1610 |
+
) -> None:
|
| 1611 |
+
"""Checks that the given optimizer and parameters are valid for the framework.
|
| 1612 |
+
|
| 1613 |
+
Args:
|
| 1614 |
+
optimizer: The optimizer object to check.
|
| 1615 |
+
params: The list of parameters to check.
|
| 1616 |
+
"""
|
| 1617 |
+
if not isinstance(params, list):
|
| 1618 |
+
raise ValueError(
|
| 1619 |
+
f"`params` ({params}) must be a list of framework-specific parameters "
|
| 1620 |
+
"(variables)!"
|
| 1621 |
+
)
|
| 1622 |
+
|
| 1623 |
+
def _log_trainable_parameters(self) -> None:
|
| 1624 |
+
"""Logs the number of trainable and non-trainable parameters to self.metrics.
|
| 1625 |
+
|
| 1626 |
+
Use MetricsLogger (self.metrics) tuple-keys:
|
| 1627 |
+
(ALL_MODULES, NUM_TRAINABLE_PARAMETERS) and
|
| 1628 |
+
(ALL_MODULES, NUM_NON_TRAINABLE_PARAMETERS) with EMA.
|
| 1629 |
+
"""
|
| 1630 |
+
pass
|
| 1631 |
+
|
| 1632 |
+
def _check_is_built(self, error: bool = True) -> bool:
|
| 1633 |
+
if self.module is None:
|
| 1634 |
+
if error:
|
| 1635 |
+
raise ValueError(
|
| 1636 |
+
"Learner.build() must be called after constructing a "
|
| 1637 |
+
"Learner and before calling any methods on it."
|
| 1638 |
+
)
|
| 1639 |
+
return False
|
| 1640 |
+
return True
|
| 1641 |
+
|
| 1642 |
+
def _reset(self):
|
| 1643 |
+
self._params = {}
|
| 1644 |
+
self._optimizer_parameters = {}
|
| 1645 |
+
self._named_optimizers = {}
|
| 1646 |
+
self._module_optimizers = defaultdict(list)
|
| 1647 |
+
self._optimizer_lr_schedules = {}
|
| 1648 |
+
self.metrics = MetricsLogger()
|
| 1649 |
+
self._is_built = False
|
| 1650 |
+
|
| 1651 |
+
def apply(self, func, *_args, **_kwargs):
|
| 1652 |
+
return func(self, *_args, **_kwargs)
|
| 1653 |
+
|
| 1654 |
+
@abc.abstractmethod
|
| 1655 |
+
def _get_tensor_variable(
|
| 1656 |
+
self,
|
| 1657 |
+
value: Any,
|
| 1658 |
+
dtype: Any = None,
|
| 1659 |
+
trainable: bool = False,
|
| 1660 |
+
) -> TensorType:
|
| 1661 |
+
"""Returns a framework-specific tensor variable with the initial given value.
|
| 1662 |
+
|
| 1663 |
+
This is a framework specific method that should be implemented by the
|
| 1664 |
+
framework specific sub-classes.
|
| 1665 |
+
|
| 1666 |
+
Args:
|
| 1667 |
+
value: The initial value for the tensor variable variable.
|
| 1668 |
+
|
| 1669 |
+
Returns:
|
| 1670 |
+
The framework specific tensor variable of the given initial value,
|
| 1671 |
+
dtype and trainable/requires_grad property.
|
| 1672 |
+
"""
|
| 1673 |
+
|
| 1674 |
+
@staticmethod
|
| 1675 |
+
@abc.abstractmethod
|
| 1676 |
+
def _get_optimizer_lr(optimizer: Optimizer) -> float:
|
| 1677 |
+
"""Returns the current learning rate of the given local optimizer.
|
| 1678 |
+
|
| 1679 |
+
Args:
|
| 1680 |
+
optimizer: The local optimizer to get the current learning rate for.
|
| 1681 |
+
|
| 1682 |
+
Returns:
|
| 1683 |
+
The learning rate value (float) of the given optimizer.
|
| 1684 |
+
"""
|
| 1685 |
+
|
| 1686 |
+
@staticmethod
|
| 1687 |
+
@abc.abstractmethod
|
| 1688 |
+
def _set_optimizer_lr(optimizer: Optimizer, lr: float) -> None:
|
| 1689 |
+
"""Updates the learning rate of the given local optimizer.
|
| 1690 |
+
|
| 1691 |
+
Args:
|
| 1692 |
+
optimizer: The local optimizer to update the learning rate for.
|
| 1693 |
+
lr: The new learning rate.
|
| 1694 |
+
"""
|
| 1695 |
+
|
| 1696 |
+
@staticmethod
|
| 1697 |
+
@abc.abstractmethod
|
| 1698 |
+
def _get_clip_function() -> Callable:
|
| 1699 |
+
"""Returns the gradient clipping function to use, given the framework."""
|
| 1700 |
+
|
| 1701 |
+
@staticmethod
|
| 1702 |
+
@abc.abstractmethod
|
| 1703 |
+
def _get_global_norm_function() -> Callable:
|
| 1704 |
+
"""Returns the global norm function to use, given the framework."""
|
| 1705 |
+
|
| 1706 |
+
def _log_steps_trained_metrics(self, batch: MultiAgentBatch):
|
| 1707 |
+
"""Logs this iteration's steps trained, based on given `batch`."""
|
| 1708 |
+
for mid, module_batch in batch.policy_batches.items():
|
| 1709 |
+
module_batch_size = len(module_batch)
|
| 1710 |
+
# Log average batch size (for each module).
|
| 1711 |
+
self.metrics.log_value(
|
| 1712 |
+
key=(mid, MODULE_TRAIN_BATCH_SIZE_MEAN),
|
| 1713 |
+
value=module_batch_size,
|
| 1714 |
+
)
|
| 1715 |
+
# Log module steps (for each module).
|
| 1716 |
+
self.metrics.log_value(
|
| 1717 |
+
key=(mid, NUM_MODULE_STEPS_TRAINED),
|
| 1718 |
+
value=module_batch_size,
|
| 1719 |
+
reduce="sum",
|
| 1720 |
+
clear_on_reduce=True,
|
| 1721 |
+
)
|
| 1722 |
+
self.metrics.log_value(
|
| 1723 |
+
key=(mid, NUM_MODULE_STEPS_TRAINED_LIFETIME),
|
| 1724 |
+
value=module_batch_size,
|
| 1725 |
+
reduce="sum",
|
| 1726 |
+
)
|
| 1727 |
+
# Log module steps (sum of all modules).
|
| 1728 |
+
self.metrics.log_value(
|
| 1729 |
+
key=(ALL_MODULES, NUM_MODULE_STEPS_TRAINED),
|
| 1730 |
+
value=module_batch_size,
|
| 1731 |
+
reduce="sum",
|
| 1732 |
+
clear_on_reduce=True,
|
| 1733 |
+
)
|
| 1734 |
+
self.metrics.log_value(
|
| 1735 |
+
key=(ALL_MODULES, NUM_MODULE_STEPS_TRAINED_LIFETIME),
|
| 1736 |
+
value=module_batch_size,
|
| 1737 |
+
reduce="sum",
|
| 1738 |
+
)
|
| 1739 |
+
# Log env steps (all modules).
|
| 1740 |
+
self.metrics.log_value(
|
| 1741 |
+
(ALL_MODULES, NUM_ENV_STEPS_TRAINED),
|
| 1742 |
+
batch.env_steps(),
|
| 1743 |
+
reduce="sum",
|
| 1744 |
+
clear_on_reduce=True,
|
| 1745 |
+
)
|
| 1746 |
+
self.metrics.log_value(
|
| 1747 |
+
(ALL_MODULES, NUM_ENV_STEPS_TRAINED_LIFETIME),
|
| 1748 |
+
batch.env_steps(),
|
| 1749 |
+
reduce="sum",
|
| 1750 |
+
with_throughput=True,
|
| 1751 |
+
)
|
| 1752 |
+
|
| 1753 |
+
@Deprecated(
|
| 1754 |
+
new="Learner.before_gradient_based_update("
|
| 1755 |
+
"timesteps={'num_env_steps_sampled_lifetime': ...}) and/or "
|
| 1756 |
+
"Learner.after_gradient_based_update("
|
| 1757 |
+
"timesteps={'num_env_steps_sampled_lifetime': ...})",
|
| 1758 |
+
error=True,
|
| 1759 |
+
)
|
| 1760 |
+
def additional_update_for_module(self, *args, **kwargs):
|
| 1761 |
+
pass
|
| 1762 |
+
|
| 1763 |
+
@Deprecated(new="Learner.save_to_path(...)", error=True)
|
| 1764 |
+
def save_state(self, *args, **kwargs):
|
| 1765 |
+
pass
|
| 1766 |
+
|
| 1767 |
+
@Deprecated(new="Learner.restore_from_path(...)", error=True)
|
| 1768 |
+
def load_state(self, *args, **kwargs):
|
| 1769 |
+
pass
|
| 1770 |
+
|
| 1771 |
+
@Deprecated(new="Learner.module.get_state()", error=True)
|
| 1772 |
+
def get_module_state(self, *args, **kwargs):
|
| 1773 |
+
pass
|
| 1774 |
+
|
| 1775 |
+
@Deprecated(new="Learner.module.set_state()", error=True)
|
| 1776 |
+
def set_module_state(self, *args, **kwargs):
|
| 1777 |
+
pass
|
| 1778 |
+
|
| 1779 |
+
@Deprecated(new="Learner._get_optimizer_state()", error=True)
|
| 1780 |
+
def get_optimizer_state(self, *args, **kwargs):
|
| 1781 |
+
pass
|
| 1782 |
+
|
| 1783 |
+
@Deprecated(new="Learner._set_optimizer_state()", error=True)
|
| 1784 |
+
def set_optimizer_state(self, *args, **kwargs):
|
| 1785 |
+
pass
|
| 1786 |
+
|
| 1787 |
+
@Deprecated(new="Learner.compute_losses(...)", error=False)
|
| 1788 |
+
def compute_loss(self, *args, **kwargs):
|
| 1789 |
+
losses_per_module = self.compute_losses(*args, **kwargs)
|
| 1790 |
+
# To continue supporting the old `compute_loss` behavior (instead of
|
| 1791 |
+
# the new `compute_losses`, add the ALL_MODULES key here holding the sum
|
| 1792 |
+
# of all individual loss terms.
|
| 1793 |
+
if ALL_MODULES not in losses_per_module:
|
| 1794 |
+
losses_per_module[ALL_MODULES] = sum(losses_per_module.values())
|
| 1795 |
+
return losses_per_module
|
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/learner_group.py
ADDED
|
@@ -0,0 +1,1030 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pathlib
|
| 2 |
+
from collections import defaultdict, Counter
|
| 3 |
+
import copy
|
| 4 |
+
from functools import partial
|
| 5 |
+
import itertools
|
| 6 |
+
from typing import (
|
| 7 |
+
Any,
|
| 8 |
+
Callable,
|
| 9 |
+
Collection,
|
| 10 |
+
Dict,
|
| 11 |
+
List,
|
| 12 |
+
Optional,
|
| 13 |
+
Set,
|
| 14 |
+
Type,
|
| 15 |
+
TYPE_CHECKING,
|
| 16 |
+
Union,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
import ray
|
| 20 |
+
from ray import ObjectRef
|
| 21 |
+
from ray.rllib.core import (
|
| 22 |
+
COMPONENT_LEARNER,
|
| 23 |
+
COMPONENT_RL_MODULE,
|
| 24 |
+
)
|
| 25 |
+
from ray.rllib.core.learner.learner import Learner
|
| 26 |
+
from ray.rllib.core.rl_module import validate_module_id
|
| 27 |
+
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
|
| 28 |
+
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
|
| 29 |
+
from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
|
| 30 |
+
from ray.rllib.policy.policy import PolicySpec
|
| 31 |
+
from ray.rllib.policy.sample_batch import MultiAgentBatch
|
| 32 |
+
from ray.rllib.utils.actor_manager import (
|
| 33 |
+
FaultTolerantActorManager,
|
| 34 |
+
RemoteCallResults,
|
| 35 |
+
ResultOrError,
|
| 36 |
+
)
|
| 37 |
+
from ray.rllib.utils.annotations import override
|
| 38 |
+
from ray.rllib.utils.checkpoints import Checkpointable
|
| 39 |
+
from ray.rllib.utils.deprecation import Deprecated
|
| 40 |
+
from ray.rllib.utils.metrics import ALL_MODULES
|
| 41 |
+
from ray.rllib.utils.minibatch_utils import (
|
| 42 |
+
ShardBatchIterator,
|
| 43 |
+
ShardEpisodesIterator,
|
| 44 |
+
ShardObjectRefIterator,
|
| 45 |
+
)
|
| 46 |
+
from ray.rllib.utils.typing import (
|
| 47 |
+
EpisodeType,
|
| 48 |
+
ModuleID,
|
| 49 |
+
RLModuleSpecType,
|
| 50 |
+
ShouldModuleBeUpdatedFn,
|
| 51 |
+
StateDict,
|
| 52 |
+
T,
|
| 53 |
+
)
|
| 54 |
+
from ray.train._internal.backend_executor import BackendExecutor
|
| 55 |
+
from ray.util.annotations import PublicAPI
|
| 56 |
+
|
| 57 |
+
if TYPE_CHECKING:
|
| 58 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _get_backend_config(learner_class: Type[Learner]) -> str:
|
| 62 |
+
if learner_class.framework == "torch":
|
| 63 |
+
from ray.train.torch import TorchConfig
|
| 64 |
+
|
| 65 |
+
backend_config = TorchConfig()
|
| 66 |
+
elif learner_class.framework == "tf2":
|
| 67 |
+
from ray.train.tensorflow import TensorflowConfig
|
| 68 |
+
|
| 69 |
+
backend_config = TensorflowConfig()
|
| 70 |
+
else:
|
| 71 |
+
raise ValueError(
|
| 72 |
+
"`learner_class.framework` must be either 'torch' or 'tf2' (but is "
|
| 73 |
+
f"{learner_class.framework}!"
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
return backend_config
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@PublicAPI(stability="alpha")
|
| 80 |
+
class LearnerGroup(Checkpointable):
|
| 81 |
+
"""Coordinator of n (possibly remote) Learner workers.
|
| 82 |
+
|
| 83 |
+
Each Learner worker has a copy of the RLModule, the loss function(s), and
|
| 84 |
+
one or more optimizers.
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
def __init__(
|
| 88 |
+
self,
|
| 89 |
+
*,
|
| 90 |
+
config: "AlgorithmConfig",
|
| 91 |
+
# TODO (sven): Rename into `rl_module_spec`.
|
| 92 |
+
module_spec: Optional[RLModuleSpecType] = None,
|
| 93 |
+
):
|
| 94 |
+
"""Initializes a LearnerGroup instance.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
config: The AlgorithmConfig object to use to configure this LearnerGroup.
|
| 98 |
+
Call the `learners(num_learners=...)` method on your config to
|
| 99 |
+
specify the number of learner workers to use.
|
| 100 |
+
Call the same method with arguments `num_cpus_per_learner` and/or
|
| 101 |
+
`num_gpus_per_learner` to configure the compute used by each
|
| 102 |
+
Learner worker in this LearnerGroup.
|
| 103 |
+
Call the `training(learner_class=...)` method on your config to specify,
|
| 104 |
+
which exact Learner class to use.
|
| 105 |
+
Call the `rl_module(rl_module_spec=...)` method on your config to set up
|
| 106 |
+
the specifics for your RLModule to be used in each Learner.
|
| 107 |
+
module_spec: If not already specified in `config`, a separate overriding
|
| 108 |
+
RLModuleSpec may be provided via this argument.
|
| 109 |
+
"""
|
| 110 |
+
self.config = config.copy(copy_frozen=False)
|
| 111 |
+
self._module_spec = module_spec
|
| 112 |
+
|
| 113 |
+
learner_class = self.config.learner_class
|
| 114 |
+
module_spec = module_spec or self.config.get_multi_rl_module_spec()
|
| 115 |
+
|
| 116 |
+
self._learner = None
|
| 117 |
+
self._workers = None
|
| 118 |
+
# If a user calls self.shutdown() on their own then this flag is set to true.
|
| 119 |
+
# When del is called the backend executor isn't shutdown twice if this flag is
|
| 120 |
+
# true. the backend executor would otherwise log a warning to the console from
|
| 121 |
+
# ray train.
|
| 122 |
+
self._is_shut_down = False
|
| 123 |
+
|
| 124 |
+
# How many timesteps had to be dropped due to a full input queue?
|
| 125 |
+
self._ts_dropped = 0
|
| 126 |
+
|
| 127 |
+
# A single local Learner.
|
| 128 |
+
if not self.is_remote:
|
| 129 |
+
self._learner = learner_class(config=config, module_spec=module_spec)
|
| 130 |
+
self._learner.build()
|
| 131 |
+
self._worker_manager = None
|
| 132 |
+
# N remote Learner workers.
|
| 133 |
+
else:
|
| 134 |
+
backend_config = _get_backend_config(learner_class)
|
| 135 |
+
|
| 136 |
+
# TODO (sven): Can't set both `num_cpus_per_learner`>1 and
|
| 137 |
+
# `num_gpus_per_learner`>0! Users must set one or the other due
|
| 138 |
+
# to issues with placement group fragmentation. See
|
| 139 |
+
# https://github.com/ray-project/ray/issues/35409 for more details.
|
| 140 |
+
num_cpus_per_learner = (
|
| 141 |
+
self.config.num_cpus_per_learner
|
| 142 |
+
if not self.config.num_gpus_per_learner
|
| 143 |
+
else 0
|
| 144 |
+
)
|
| 145 |
+
num_gpus_per_learner = max(
|
| 146 |
+
0,
|
| 147 |
+
self.config.num_gpus_per_learner
|
| 148 |
+
- (0.01 * self.config.num_aggregator_actors_per_learner),
|
| 149 |
+
)
|
| 150 |
+
resources_per_learner = {
|
| 151 |
+
"CPU": num_cpus_per_learner,
|
| 152 |
+
"GPU": num_gpus_per_learner,
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
backend_executor = BackendExecutor(
|
| 156 |
+
backend_config=backend_config,
|
| 157 |
+
num_workers=self.config.num_learners,
|
| 158 |
+
resources_per_worker=resources_per_learner,
|
| 159 |
+
max_retries=0,
|
| 160 |
+
)
|
| 161 |
+
backend_executor.start(
|
| 162 |
+
train_cls=learner_class,
|
| 163 |
+
train_cls_kwargs={
|
| 164 |
+
"config": config,
|
| 165 |
+
"module_spec": module_spec,
|
| 166 |
+
},
|
| 167 |
+
)
|
| 168 |
+
self._backend_executor = backend_executor
|
| 169 |
+
|
| 170 |
+
self._workers = [w.actor for w in backend_executor.worker_group.workers]
|
| 171 |
+
|
| 172 |
+
# Run the neural network building code on remote workers.
|
| 173 |
+
ray.get([w.build.remote() for w in self._workers])
|
| 174 |
+
|
| 175 |
+
self._worker_manager = FaultTolerantActorManager(
|
| 176 |
+
self._workers,
|
| 177 |
+
max_remote_requests_in_flight_per_actor=(
|
| 178 |
+
self.config.max_requests_in_flight_per_learner
|
| 179 |
+
),
|
| 180 |
+
)
|
| 181 |
+
# Counters for the tags for asynchronous update requests that are
|
| 182 |
+
# in-flight. Used for keeping trakc of and grouping together the results of
|
| 183 |
+
# requests that were sent to the workers at the same time.
|
| 184 |
+
self._update_request_tags = Counter()
|
| 185 |
+
self._update_request_tag = 0
|
| 186 |
+
self._update_request_results = {}
|
| 187 |
+
|
| 188 |
+
# TODO (sven): Replace this with call to `self.metrics.peek()`?
|
| 189 |
+
# Currently LearnerGroup does not have a metrics object.
|
| 190 |
+
def get_stats(self) -> Dict[str, Any]:
|
| 191 |
+
"""Returns the current stats for the input queue for this learner group."""
|
| 192 |
+
return {
|
| 193 |
+
"learner_group_ts_dropped": self._ts_dropped,
|
| 194 |
+
"actor_manager_num_outstanding_async_reqs": (
|
| 195 |
+
0
|
| 196 |
+
if self.is_local
|
| 197 |
+
else self._worker_manager.num_outstanding_async_reqs()
|
| 198 |
+
),
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
@property
|
| 202 |
+
def is_remote(self) -> bool:
|
| 203 |
+
return self.config.num_learners > 0
|
| 204 |
+
|
| 205 |
+
@property
|
| 206 |
+
def is_local(self) -> bool:
|
| 207 |
+
return not self.is_remote
|
| 208 |
+
|
| 209 |
+
def update_from_batch(
|
| 210 |
+
self,
|
| 211 |
+
batch: MultiAgentBatch,
|
| 212 |
+
*,
|
| 213 |
+
timesteps: Optional[Dict[str, Any]] = None,
|
| 214 |
+
async_update: bool = False,
|
| 215 |
+
return_state: bool = False,
|
| 216 |
+
num_epochs: int = 1,
|
| 217 |
+
minibatch_size: Optional[int] = None,
|
| 218 |
+
shuffle_batch_per_epoch: bool = False,
|
| 219 |
+
# User kwargs.
|
| 220 |
+
**kwargs,
|
| 221 |
+
) -> Union[Dict[str, Any], List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
|
| 222 |
+
"""Performs gradient based update(s) on the Learner(s), based on given batch.
|
| 223 |
+
|
| 224 |
+
Args:
|
| 225 |
+
batch: A data batch to use for the update. If there are more
|
| 226 |
+
than one Learner workers, the batch is split amongst these and one
|
| 227 |
+
shard is sent to each Learner.
|
| 228 |
+
async_update: Whether the update request(s) to the Learner workers should be
|
| 229 |
+
sent asynchronously. If True, will return NOT the results from the
|
| 230 |
+
update on the given data, but all results from prior asynchronous update
|
| 231 |
+
requests that have not been returned thus far.
|
| 232 |
+
return_state: Whether to include one of the Learner worker's state from
|
| 233 |
+
after the update step in the returned results dict (under the
|
| 234 |
+
`_rl_module_state_after_update` key). Note that after an update, all
|
| 235 |
+
Learner workers' states should be identical, so we use the first
|
| 236 |
+
Learner's state here. Useful for avoiding an extra `get_weights()` call,
|
| 237 |
+
e.g. for synchronizing EnvRunner weights.
|
| 238 |
+
num_epochs: The number of complete passes over the entire train batch. Each
|
| 239 |
+
pass might be further split into n minibatches (if `minibatch_size`
|
| 240 |
+
provided).
|
| 241 |
+
minibatch_size: The size of minibatches to use to further split the train
|
| 242 |
+
`batch` into sub-batches. The `batch` is then iterated over n times
|
| 243 |
+
where n is `len(batch) // minibatch_size`.
|
| 244 |
+
shuffle_batch_per_epoch: Whether to shuffle the train batch once per epoch.
|
| 245 |
+
If the train batch has a time rank (axis=1), shuffling will only take
|
| 246 |
+
place along the batch axis to not disturb any intact (episode)
|
| 247 |
+
trajectories. Also, shuffling is always skipped if `minibatch_size` is
|
| 248 |
+
None, meaning the entire train batch is processed each epoch, making it
|
| 249 |
+
unnecessary to shuffle.
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
If `async_update` is False, a dictionary with the reduced results of the
|
| 253 |
+
updates from the Learner(s) or a list of dictionaries of results from the
|
| 254 |
+
updates from the Learner(s).
|
| 255 |
+
If `async_update` is True, a list of list of dictionaries of results, where
|
| 256 |
+
the outer list corresponds to separate previous calls to this method, and
|
| 257 |
+
the inner list corresponds to the results from each Learner(s). Or if the
|
| 258 |
+
results are reduced, a list of dictionaries of the reduced results from each
|
| 259 |
+
call to async_update that is ready.
|
| 260 |
+
"""
|
| 261 |
+
return self._update(
|
| 262 |
+
batch=batch,
|
| 263 |
+
timesteps=timesteps,
|
| 264 |
+
async_update=async_update,
|
| 265 |
+
return_state=return_state,
|
| 266 |
+
num_epochs=num_epochs,
|
| 267 |
+
minibatch_size=minibatch_size,
|
| 268 |
+
shuffle_batch_per_epoch=shuffle_batch_per_epoch,
|
| 269 |
+
**kwargs,
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
def update_from_episodes(
|
| 273 |
+
self,
|
| 274 |
+
episodes: List[EpisodeType],
|
| 275 |
+
*,
|
| 276 |
+
timesteps: Optional[Dict[str, Any]] = None,
|
| 277 |
+
async_update: bool = False,
|
| 278 |
+
return_state: bool = False,
|
| 279 |
+
num_epochs: int = 1,
|
| 280 |
+
minibatch_size: Optional[int] = None,
|
| 281 |
+
shuffle_batch_per_epoch: bool = False,
|
| 282 |
+
# User kwargs.
|
| 283 |
+
**kwargs,
|
| 284 |
+
) -> Union[Dict[str, Any], List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
|
| 285 |
+
"""Performs gradient based update(s) on the Learner(s), based on given episodes.
|
| 286 |
+
|
| 287 |
+
Args:
|
| 288 |
+
episodes: A list of Episodes to process and perform the update
|
| 289 |
+
for. If there are more than one Learner workers, the list of episodes
|
| 290 |
+
is split amongst these and one list shard is sent to each Learner.
|
| 291 |
+
async_update: Whether the update request(s) to the Learner workers should be
|
| 292 |
+
sent asynchronously. If True, will return NOT the results from the
|
| 293 |
+
update on the given data, but all results from prior asynchronous update
|
| 294 |
+
requests that have not been returned thus far.
|
| 295 |
+
return_state: Whether to include one of the Learner worker's state from
|
| 296 |
+
after the update step in the returned results dict (under the
|
| 297 |
+
`_rl_module_state_after_update` key). Note that after an update, all
|
| 298 |
+
Learner workers' states should be identical, so we use the first
|
| 299 |
+
Learner's state here. Useful for avoiding an extra `get_weights()` call,
|
| 300 |
+
e.g. for synchronizing EnvRunner weights.
|
| 301 |
+
num_epochs: The number of complete passes over the entire train batch. Each
|
| 302 |
+
pass might be further split into n minibatches (if `minibatch_size`
|
| 303 |
+
provided). The train batch is generated from the given `episodes`
|
| 304 |
+
through the Learner connector pipeline.
|
| 305 |
+
minibatch_size: The size of minibatches to use to further split the train
|
| 306 |
+
`batch` into sub-batches. The `batch` is then iterated over n times
|
| 307 |
+
where n is `len(batch) // minibatch_size`. The train batch is generated
|
| 308 |
+
from the given `episodes` through the Learner connector pipeline.
|
| 309 |
+
shuffle_batch_per_epoch: Whether to shuffle the train batch once per epoch.
|
| 310 |
+
If the train batch has a time rank (axis=1), shuffling will only take
|
| 311 |
+
place along the batch axis to not disturb any intact (episode)
|
| 312 |
+
trajectories. Also, shuffling is always skipped if `minibatch_size` is
|
| 313 |
+
None, meaning the entire train batch is processed each epoch, making it
|
| 314 |
+
unnecessary to shuffle. The train batch is generated from the given
|
| 315 |
+
`episodes` through the Learner connector pipeline.
|
| 316 |
+
|
| 317 |
+
Returns:
|
| 318 |
+
If async_update is False, a dictionary with the reduced results of the
|
| 319 |
+
updates from the Learner(s) or a list of dictionaries of results from the
|
| 320 |
+
updates from the Learner(s).
|
| 321 |
+
If async_update is True, a list of list of dictionaries of results, where
|
| 322 |
+
the outer list corresponds to separate previous calls to this method, and
|
| 323 |
+
the inner list corresponds to the results from each Learner(s). Or if the
|
| 324 |
+
results are reduced, a list of dictionaries of the reduced results from each
|
| 325 |
+
call to async_update that is ready.
|
| 326 |
+
"""
|
| 327 |
+
return self._update(
|
| 328 |
+
episodes=episodes,
|
| 329 |
+
timesteps=timesteps,
|
| 330 |
+
async_update=async_update,
|
| 331 |
+
return_state=return_state,
|
| 332 |
+
num_epochs=num_epochs,
|
| 333 |
+
minibatch_size=minibatch_size,
|
| 334 |
+
shuffle_batch_per_epoch=shuffle_batch_per_epoch,
|
| 335 |
+
**kwargs,
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
def _update(
|
| 339 |
+
self,
|
| 340 |
+
*,
|
| 341 |
+
batch: Optional[MultiAgentBatch] = None,
|
| 342 |
+
episodes: Optional[List[EpisodeType]] = None,
|
| 343 |
+
timesteps: Optional[Dict[str, Any]] = None,
|
| 344 |
+
async_update: bool = False,
|
| 345 |
+
return_state: bool = False,
|
| 346 |
+
num_epochs: int = 1,
|
| 347 |
+
num_iters: int = 1,
|
| 348 |
+
minibatch_size: Optional[int] = None,
|
| 349 |
+
shuffle_batch_per_epoch: bool = False,
|
| 350 |
+
**kwargs,
|
| 351 |
+
) -> Union[Dict[str, Any], List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
|
| 352 |
+
|
| 353 |
+
# Define function to be called on all Learner actors (or the local learner).
|
| 354 |
+
def _learner_update(
|
| 355 |
+
_learner: Learner,
|
| 356 |
+
*,
|
| 357 |
+
_batch_shard=None,
|
| 358 |
+
_episodes_shard=None,
|
| 359 |
+
_timesteps=None,
|
| 360 |
+
_return_state=False,
|
| 361 |
+
_num_total_minibatches=0,
|
| 362 |
+
**_kwargs,
|
| 363 |
+
):
|
| 364 |
+
# If the batch shard is an `DataIterator` we have an offline
|
| 365 |
+
# multi-learner setup and `update_from_iterator` needs to
|
| 366 |
+
# handle updating.
|
| 367 |
+
if isinstance(_batch_shard, ray.data.DataIterator):
|
| 368 |
+
result = _learner.update_from_iterator(
|
| 369 |
+
iterator=_batch_shard,
|
| 370 |
+
timesteps=_timesteps,
|
| 371 |
+
minibatch_size=minibatch_size,
|
| 372 |
+
num_iters=num_iters,
|
| 373 |
+
**_kwargs,
|
| 374 |
+
)
|
| 375 |
+
elif _batch_shard is not None:
|
| 376 |
+
result = _learner.update_from_batch(
|
| 377 |
+
batch=_batch_shard,
|
| 378 |
+
timesteps=_timesteps,
|
| 379 |
+
num_epochs=num_epochs,
|
| 380 |
+
minibatch_size=minibatch_size,
|
| 381 |
+
shuffle_batch_per_epoch=shuffle_batch_per_epoch,
|
| 382 |
+
**_kwargs,
|
| 383 |
+
)
|
| 384 |
+
else:
|
| 385 |
+
result = _learner.update_from_episodes(
|
| 386 |
+
episodes=_episodes_shard,
|
| 387 |
+
timesteps=_timesteps,
|
| 388 |
+
num_epochs=num_epochs,
|
| 389 |
+
minibatch_size=minibatch_size,
|
| 390 |
+
shuffle_batch_per_epoch=shuffle_batch_per_epoch,
|
| 391 |
+
num_total_minibatches=_num_total_minibatches,
|
| 392 |
+
**_kwargs,
|
| 393 |
+
)
|
| 394 |
+
if _return_state and result:
|
| 395 |
+
result["_rl_module_state_after_update"] = _learner.get_state(
|
| 396 |
+
# Only return the state of those RLModules that actually returned
|
| 397 |
+
# results and thus got probably updated.
|
| 398 |
+
components=[
|
| 399 |
+
COMPONENT_RL_MODULE + "/" + mid
|
| 400 |
+
for mid in result
|
| 401 |
+
if mid != ALL_MODULES
|
| 402 |
+
],
|
| 403 |
+
inference_only=True,
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
return result
|
| 407 |
+
|
| 408 |
+
# Local Learner worker: Don't shard batch/episodes, just run data as-is through
|
| 409 |
+
# this Learner.
|
| 410 |
+
if self.is_local:
|
| 411 |
+
if async_update:
|
| 412 |
+
raise ValueError(
|
| 413 |
+
"Cannot call `update_from_batch(async_update=True)` when running in"
|
| 414 |
+
" local mode! Try setting `config.num_learners > 0`."
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
if isinstance(batch, list) and isinstance(batch[0], ray.ObjectRef):
|
| 418 |
+
assert len(batch) == 1
|
| 419 |
+
batch = ray.get(batch[0])
|
| 420 |
+
|
| 421 |
+
results = [
|
| 422 |
+
_learner_update(
|
| 423 |
+
_learner=self._learner,
|
| 424 |
+
_batch_shard=batch,
|
| 425 |
+
_episodes_shard=episodes,
|
| 426 |
+
_timesteps=timesteps,
|
| 427 |
+
_return_state=return_state,
|
| 428 |
+
**kwargs,
|
| 429 |
+
)
|
| 430 |
+
]
|
| 431 |
+
# One or more remote Learners: Shard batch/episodes into equal pieces (roughly
|
| 432 |
+
# equal if multi-agent AND episodes) and send each Learner worker one of these
|
| 433 |
+
# shards.
|
| 434 |
+
else:
|
| 435 |
+
# MultiAgentBatch: Shard into equal pieces.
|
| 436 |
+
# TODO (sven): The sharder used here destroys - for multi-agent only -
|
| 437 |
+
# the relationship of the different agents' timesteps to each other.
|
| 438 |
+
# Thus, in case the algorithm requires agent-synchronized data (aka.
|
| 439 |
+
# "lockstep"), the `ShardBatchIterator` should not be used.
|
| 440 |
+
# Then again, we might move into a world where Learner always
|
| 441 |
+
# receives Episodes, never batches.
|
| 442 |
+
if isinstance(batch, list) and isinstance(batch[0], ray.data.DataIterator):
|
| 443 |
+
partials = [
|
| 444 |
+
partial(
|
| 445 |
+
_learner_update,
|
| 446 |
+
_batch_shard=iterator,
|
| 447 |
+
_return_state=(return_state and i == 0),
|
| 448 |
+
_timesteps=timesteps,
|
| 449 |
+
**kwargs,
|
| 450 |
+
)
|
| 451 |
+
# Note, `OfflineData` defines exactly as many iterators as there
|
| 452 |
+
# are learners.
|
| 453 |
+
for i, iterator in enumerate(batch)
|
| 454 |
+
]
|
| 455 |
+
elif isinstance(batch, list) and isinstance(batch[0], ObjectRef):
|
| 456 |
+
assert len(batch) == len(self._workers)
|
| 457 |
+
partials = [
|
| 458 |
+
partial(
|
| 459 |
+
_learner_update,
|
| 460 |
+
_batch_shard=batch_shard,
|
| 461 |
+
_timesteps=timesteps,
|
| 462 |
+
_return_state=(return_state and i == 0),
|
| 463 |
+
**kwargs,
|
| 464 |
+
)
|
| 465 |
+
for i, batch_shard in enumerate(batch)
|
| 466 |
+
]
|
| 467 |
+
elif batch is not None:
|
| 468 |
+
partials = [
|
| 469 |
+
partial(
|
| 470 |
+
_learner_update,
|
| 471 |
+
_batch_shard=batch_shard,
|
| 472 |
+
_return_state=(return_state and i == 0),
|
| 473 |
+
_timesteps=timesteps,
|
| 474 |
+
**kwargs,
|
| 475 |
+
)
|
| 476 |
+
for i, batch_shard in enumerate(
|
| 477 |
+
ShardBatchIterator(batch, len(self._workers))
|
| 478 |
+
)
|
| 479 |
+
]
|
| 480 |
+
elif isinstance(episodes, list) and isinstance(episodes[0], ObjectRef):
|
| 481 |
+
partials = [
|
| 482 |
+
partial(
|
| 483 |
+
_learner_update,
|
| 484 |
+
_episodes_shard=episodes_shard,
|
| 485 |
+
_timesteps=timesteps,
|
| 486 |
+
_return_state=(return_state and i == 0),
|
| 487 |
+
**kwargs,
|
| 488 |
+
)
|
| 489 |
+
for i, episodes_shard in enumerate(
|
| 490 |
+
ShardObjectRefIterator(episodes, len(self._workers))
|
| 491 |
+
)
|
| 492 |
+
]
|
| 493 |
+
# Single- or MultiAgentEpisodes: Shard into equal pieces (only roughly equal
|
| 494 |
+
# in case of multi-agent).
|
| 495 |
+
else:
|
| 496 |
+
from ray.data.iterator import DataIterator
|
| 497 |
+
|
| 498 |
+
if isinstance(episodes[0], DataIterator):
|
| 499 |
+
num_total_minibatches = 0
|
| 500 |
+
partials = [
|
| 501 |
+
partial(
|
| 502 |
+
_learner_update,
|
| 503 |
+
_episodes_shard=episodes_shard,
|
| 504 |
+
_timesteps=timesteps,
|
| 505 |
+
_num_total_minibatches=num_total_minibatches,
|
| 506 |
+
)
|
| 507 |
+
for episodes_shard in episodes
|
| 508 |
+
]
|
| 509 |
+
else:
|
| 510 |
+
eps_shards = list(
|
| 511 |
+
ShardEpisodesIterator(
|
| 512 |
+
episodes,
|
| 513 |
+
len(self._workers),
|
| 514 |
+
len_lookback_buffer=self.config.episode_lookback_horizon,
|
| 515 |
+
)
|
| 516 |
+
)
|
| 517 |
+
# In the multi-agent case AND `minibatch_size` AND num_workers
|
| 518 |
+
# > 1, we compute a max iteration counter such that the different
|
| 519 |
+
# Learners will not go through a different number of iterations.
|
| 520 |
+
num_total_minibatches = 0
|
| 521 |
+
if minibatch_size and len(self._workers) > 1:
|
| 522 |
+
num_total_minibatches = self._compute_num_total_minibatches(
|
| 523 |
+
episodes,
|
| 524 |
+
len(self._workers),
|
| 525 |
+
minibatch_size,
|
| 526 |
+
num_epochs,
|
| 527 |
+
)
|
| 528 |
+
partials = [
|
| 529 |
+
partial(
|
| 530 |
+
_learner_update,
|
| 531 |
+
_episodes_shard=eps_shard,
|
| 532 |
+
_timesteps=timesteps,
|
| 533 |
+
_num_total_minibatches=num_total_minibatches,
|
| 534 |
+
)
|
| 535 |
+
for eps_shard in eps_shards
|
| 536 |
+
]
|
| 537 |
+
|
| 538 |
+
if async_update:
|
| 539 |
+
# Retrieve all ready results (kicked off by prior calls to this method).
|
| 540 |
+
tags_to_get = []
|
| 541 |
+
for tag in self._update_request_tags.keys():
|
| 542 |
+
result = self._worker_manager.fetch_ready_async_reqs(
|
| 543 |
+
tags=[str(tag)], timeout_seconds=0.0
|
| 544 |
+
)
|
| 545 |
+
if tag not in self._update_request_results:
|
| 546 |
+
self._update_request_results[tag] = result
|
| 547 |
+
else:
|
| 548 |
+
for r in result:
|
| 549 |
+
self._update_request_results[tag].add_result(
|
| 550 |
+
r.actor_id, r.result_or_error, tag
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
# Still not done with this `tag` -> skip out early.
|
| 554 |
+
if (
|
| 555 |
+
self._update_request_tags[tag]
|
| 556 |
+
> len(self._update_request_results[tag].result_or_errors)
|
| 557 |
+
> 0
|
| 558 |
+
):
|
| 559 |
+
break
|
| 560 |
+
tags_to_get.append(tag)
|
| 561 |
+
|
| 562 |
+
# Send out new request(s), if there is still capacity on the actors
|
| 563 |
+
# (each actor is allowed only some number of max in-flight requests
|
| 564 |
+
# at the same time).
|
| 565 |
+
update_tag = self._update_request_tag
|
| 566 |
+
self._update_request_tag += 1
|
| 567 |
+
num_sent_requests = self._worker_manager.foreach_actor_async(
|
| 568 |
+
partials, tag=str(update_tag)
|
| 569 |
+
)
|
| 570 |
+
if num_sent_requests:
|
| 571 |
+
self._update_request_tags[update_tag] = num_sent_requests
|
| 572 |
+
|
| 573 |
+
# Some requests were dropped, record lost ts/data.
|
| 574 |
+
if num_sent_requests != len(self._workers):
|
| 575 |
+
factor = 1 - (num_sent_requests / len(self._workers))
|
| 576 |
+
# Batch: Measure its length.
|
| 577 |
+
if episodes is None:
|
| 578 |
+
dropped = len(batch)
|
| 579 |
+
# List of Ray ObjectRefs (each object ref is a list of episodes of
|
| 580 |
+
# total len=`rollout_fragment_length * num_envs_per_env_runner`)
|
| 581 |
+
elif isinstance(episodes[0], ObjectRef):
|
| 582 |
+
dropped = (
|
| 583 |
+
len(episodes)
|
| 584 |
+
* self.config.get_rollout_fragment_length()
|
| 585 |
+
* self.config.num_envs_per_env_runner
|
| 586 |
+
)
|
| 587 |
+
else:
|
| 588 |
+
dropped = sum(len(e) for e in episodes)
|
| 589 |
+
|
| 590 |
+
self._ts_dropped += factor * dropped
|
| 591 |
+
|
| 592 |
+
# NOTE: There is a strong assumption here that the requests launched to
|
| 593 |
+
# learner workers will return at the same time, since they have a
|
| 594 |
+
# barrier inside for gradient aggregation. Therefore, results should be
|
| 595 |
+
# a list of lists where each inner list should be the length of the
|
| 596 |
+
# number of learner workers, if results from an non-blocking update are
|
| 597 |
+
# ready.
|
| 598 |
+
results = self._get_async_results(tags_to_get)
|
| 599 |
+
|
| 600 |
+
else:
|
| 601 |
+
results = self._get_results(
|
| 602 |
+
self._worker_manager.foreach_actor(partials)
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
return results
|
| 606 |
+
|
| 607 |
+
# TODO (sven): Move this into FaultTolerantActorManager?
|
| 608 |
+
def _get_results(self, results):
|
| 609 |
+
processed_results = []
|
| 610 |
+
for result in results:
|
| 611 |
+
result_or_error = result.get()
|
| 612 |
+
if result.ok:
|
| 613 |
+
processed_results.append(result_or_error)
|
| 614 |
+
else:
|
| 615 |
+
raise result_or_error
|
| 616 |
+
return processed_results
|
| 617 |
+
|
| 618 |
+
def _get_async_results(self, tags_to_get):
|
| 619 |
+
"""Get results from the worker manager and group them by tag.
|
| 620 |
+
|
| 621 |
+
Returns:
|
| 622 |
+
A list of lists of results, where each inner list contains all results
|
| 623 |
+
for same tags.
|
| 624 |
+
|
| 625 |
+
"""
|
| 626 |
+
unprocessed_results = defaultdict(list)
|
| 627 |
+
for tag in tags_to_get:
|
| 628 |
+
results = self._update_request_results[tag]
|
| 629 |
+
for result in results:
|
| 630 |
+
result_or_error = result.get()
|
| 631 |
+
if result.ok:
|
| 632 |
+
if result.tag is None:
|
| 633 |
+
raise RuntimeError(
|
| 634 |
+
"Cannot call `LearnerGroup._get_async_results()` on "
|
| 635 |
+
"untagged async requests!"
|
| 636 |
+
)
|
| 637 |
+
tag = int(result.tag)
|
| 638 |
+
unprocessed_results[tag].append(result_or_error)
|
| 639 |
+
|
| 640 |
+
if tag in self._update_request_tags:
|
| 641 |
+
self._update_request_tags[tag] -= 1
|
| 642 |
+
if self._update_request_tags[tag] == 0:
|
| 643 |
+
del self._update_request_tags[tag]
|
| 644 |
+
del self._update_request_results[tag]
|
| 645 |
+
else:
|
| 646 |
+
assert False
|
| 647 |
+
|
| 648 |
+
else:
|
| 649 |
+
raise result_or_error
|
| 650 |
+
|
| 651 |
+
return list(unprocessed_results.values())
|
| 652 |
+
|
| 653 |
+
def add_module(
|
| 654 |
+
self,
|
| 655 |
+
*,
|
| 656 |
+
module_id: ModuleID,
|
| 657 |
+
module_spec: RLModuleSpec,
|
| 658 |
+
config_overrides: Optional[Dict] = None,
|
| 659 |
+
new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
|
| 660 |
+
) -> MultiRLModuleSpec:
|
| 661 |
+
"""Adds a module to the underlying MultiRLModule.
|
| 662 |
+
|
| 663 |
+
Changes this Learner's config in order to make this architectural change
|
| 664 |
+
permanent wrt. to checkpointing.
|
| 665 |
+
|
| 666 |
+
Args:
|
| 667 |
+
module_id: The ModuleID of the module to be added.
|
| 668 |
+
module_spec: The ModuleSpec of the module to be added.
|
| 669 |
+
config_overrides: The `AlgorithmConfig` overrides that should apply to
|
| 670 |
+
the new Module, if any.
|
| 671 |
+
new_should_module_be_updated: An optional sequence of ModuleIDs or a
|
| 672 |
+
callable taking ModuleID and SampleBatchType and returning whether the
|
| 673 |
+
ModuleID should be updated (trained).
|
| 674 |
+
If None, will keep the existing setup in place. RLModules,
|
| 675 |
+
whose IDs are not in the list (or for which the callable
|
| 676 |
+
returns False) will not be updated.
|
| 677 |
+
|
| 678 |
+
Returns:
|
| 679 |
+
The new MultiRLModuleSpec (after the change has been performed).
|
| 680 |
+
"""
|
| 681 |
+
validate_module_id(module_id, error=True)
|
| 682 |
+
|
| 683 |
+
# Force-set inference-only = False.
|
| 684 |
+
module_spec = copy.deepcopy(module_spec)
|
| 685 |
+
module_spec.inference_only = False
|
| 686 |
+
|
| 687 |
+
results = self.foreach_learner(
|
| 688 |
+
func=lambda _learner: _learner.add_module(
|
| 689 |
+
module_id=module_id,
|
| 690 |
+
module_spec=module_spec,
|
| 691 |
+
config_overrides=config_overrides,
|
| 692 |
+
new_should_module_be_updated=new_should_module_be_updated,
|
| 693 |
+
),
|
| 694 |
+
)
|
| 695 |
+
marl_spec = self._get_results(results)[0]
|
| 696 |
+
|
| 697 |
+
# Change our config (AlgorithmConfig) to contain the new Module.
|
| 698 |
+
# TODO (sven): This is a hack to manipulate the AlgorithmConfig directly,
|
| 699 |
+
# but we'll deprecate config.policies soon anyway.
|
| 700 |
+
self.config.policies[module_id] = PolicySpec()
|
| 701 |
+
if config_overrides is not None:
|
| 702 |
+
self.config.multi_agent(
|
| 703 |
+
algorithm_config_overrides_per_module={module_id: config_overrides}
|
| 704 |
+
)
|
| 705 |
+
self.config.rl_module(rl_module_spec=marl_spec)
|
| 706 |
+
if new_should_module_be_updated is not None:
|
| 707 |
+
self.config.multi_agent(policies_to_train=new_should_module_be_updated)
|
| 708 |
+
|
| 709 |
+
return marl_spec
|
| 710 |
+
|
| 711 |
+
def remove_module(
|
| 712 |
+
self,
|
| 713 |
+
module_id: ModuleID,
|
| 714 |
+
*,
|
| 715 |
+
new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
|
| 716 |
+
) -> MultiRLModuleSpec:
|
| 717 |
+
"""Removes a module from the Learner.
|
| 718 |
+
|
| 719 |
+
Args:
|
| 720 |
+
module_id: The ModuleID of the module to be removed.
|
| 721 |
+
new_should_module_be_updated: An optional sequence of ModuleIDs or a
|
| 722 |
+
callable taking ModuleID and SampleBatchType and returning whether the
|
| 723 |
+
ModuleID should be updated (trained).
|
| 724 |
+
If None, will keep the existing setup in place. RLModules,
|
| 725 |
+
whose IDs are not in the list (or for which the callable
|
| 726 |
+
returns False) will not be updated.
|
| 727 |
+
|
| 728 |
+
Returns:
|
| 729 |
+
The new MultiRLModuleSpec (after the change has been performed).
|
| 730 |
+
"""
|
| 731 |
+
results = self.foreach_learner(
|
| 732 |
+
func=lambda _learner: _learner.remove_module(
|
| 733 |
+
module_id=module_id,
|
| 734 |
+
new_should_module_be_updated=new_should_module_be_updated,
|
| 735 |
+
),
|
| 736 |
+
)
|
| 737 |
+
marl_spec = self._get_results(results)[0]
|
| 738 |
+
|
| 739 |
+
# Change self.config to reflect the new architecture.
|
| 740 |
+
# TODO (sven): This is a hack to manipulate the AlgorithmConfig directly,
|
| 741 |
+
# but we'll deprecate config.policies soon anyway.
|
| 742 |
+
del self.config.policies[module_id]
|
| 743 |
+
self.config.algorithm_config_overrides_per_module.pop(module_id, None)
|
| 744 |
+
if new_should_module_be_updated is not None:
|
| 745 |
+
self.config.multi_agent(policies_to_train=new_should_module_be_updated)
|
| 746 |
+
self.config.rl_module(rl_module_spec=marl_spec)
|
| 747 |
+
|
| 748 |
+
return marl_spec
|
| 749 |
+
|
| 750 |
+
@override(Checkpointable)
|
| 751 |
+
def get_state(
|
| 752 |
+
self,
|
| 753 |
+
components: Optional[Union[str, Collection[str]]] = None,
|
| 754 |
+
*,
|
| 755 |
+
not_components: Optional[Union[str, Collection[str]]] = None,
|
| 756 |
+
**kwargs,
|
| 757 |
+
) -> StateDict:
|
| 758 |
+
state = {}
|
| 759 |
+
|
| 760 |
+
if self._check_component(COMPONENT_LEARNER, components, not_components):
|
| 761 |
+
if self.is_local:
|
| 762 |
+
state[COMPONENT_LEARNER] = self._learner.get_state(
|
| 763 |
+
components=self._get_subcomponents(COMPONENT_LEARNER, components),
|
| 764 |
+
not_components=self._get_subcomponents(
|
| 765 |
+
COMPONENT_LEARNER, not_components
|
| 766 |
+
),
|
| 767 |
+
**kwargs,
|
| 768 |
+
)
|
| 769 |
+
else:
|
| 770 |
+
worker = self._worker_manager.healthy_actor_ids()[0]
|
| 771 |
+
assert len(self._workers) == self._worker_manager.num_healthy_actors()
|
| 772 |
+
_comps = self._get_subcomponents(COMPONENT_LEARNER, components)
|
| 773 |
+
_not_comps = self._get_subcomponents(COMPONENT_LEARNER, not_components)
|
| 774 |
+
results = self._worker_manager.foreach_actor(
|
| 775 |
+
lambda w: w.get_state(_comps, not_components=_not_comps, **kwargs),
|
| 776 |
+
remote_actor_ids=[worker],
|
| 777 |
+
)
|
| 778 |
+
state[COMPONENT_LEARNER] = self._get_results(results)[0]
|
| 779 |
+
|
| 780 |
+
return state
|
| 781 |
+
|
| 782 |
+
@override(Checkpointable)
|
| 783 |
+
def set_state(self, state: StateDict) -> None:
|
| 784 |
+
if COMPONENT_LEARNER in state:
|
| 785 |
+
if self.is_local:
|
| 786 |
+
self._learner.set_state(state[COMPONENT_LEARNER])
|
| 787 |
+
else:
|
| 788 |
+
state_ref = ray.put(state[COMPONENT_LEARNER])
|
| 789 |
+
self.foreach_learner(
|
| 790 |
+
lambda _learner, _ref=state_ref: _learner.set_state(ray.get(_ref))
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
+
def get_weights(
|
| 794 |
+
self, module_ids: Optional[Collection[ModuleID]] = None
|
| 795 |
+
) -> StateDict:
|
| 796 |
+
"""Convenience method instead of self.get_state(components=...).
|
| 797 |
+
|
| 798 |
+
Args:
|
| 799 |
+
module_ids: An optional collection of ModuleIDs for which to return weights.
|
| 800 |
+
If None (default), return weights of all RLModules.
|
| 801 |
+
|
| 802 |
+
Returns:
|
| 803 |
+
The results of
|
| 804 |
+
`self.get_state(components='learner/rl_module')['learner']['rl_module']`.
|
| 805 |
+
"""
|
| 806 |
+
# Return the entire RLModule state (all possible single-agent RLModules).
|
| 807 |
+
if module_ids is None:
|
| 808 |
+
components = COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE
|
| 809 |
+
# Return a subset of the single-agent RLModules.
|
| 810 |
+
else:
|
| 811 |
+
components = [
|
| 812 |
+
"".join(tup)
|
| 813 |
+
for tup in itertools.product(
|
| 814 |
+
[COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE + "/"],
|
| 815 |
+
list(module_ids),
|
| 816 |
+
)
|
| 817 |
+
]
|
| 818 |
+
state = self.get_state(components)[COMPONENT_LEARNER][COMPONENT_RL_MODULE]
|
| 819 |
+
return state
|
| 820 |
+
|
| 821 |
+
def set_weights(self, weights) -> None:
|
| 822 |
+
"""Convenience method instead of self.set_state({'learner': {'rl_module': ..}}).
|
| 823 |
+
|
| 824 |
+
Args:
|
| 825 |
+
weights: The weights dict of the MultiRLModule of a Learner inside this
|
| 826 |
+
LearnerGroup.
|
| 827 |
+
"""
|
| 828 |
+
self.set_state({COMPONENT_LEARNER: {COMPONENT_RL_MODULE: weights}})
|
| 829 |
+
|
| 830 |
+
@override(Checkpointable)
|
| 831 |
+
def get_ctor_args_and_kwargs(self):
|
| 832 |
+
return (
|
| 833 |
+
(), # *args
|
| 834 |
+
{
|
| 835 |
+
"config": self.config,
|
| 836 |
+
"module_spec": self._module_spec,
|
| 837 |
+
}, # **kwargs
|
| 838 |
+
)
|
| 839 |
+
|
| 840 |
+
@override(Checkpointable)
|
| 841 |
+
def get_checkpointable_components(self):
|
| 842 |
+
# Return the entire ActorManager, if remote. Otherwise, return the
|
| 843 |
+
# local worker. Also, don't give the component (Learner) a name ("")
|
| 844 |
+
# as it's the only component in this LearnerGroup to be saved.
|
| 845 |
+
return [
|
| 846 |
+
(
|
| 847 |
+
COMPONENT_LEARNER,
|
| 848 |
+
self._learner if self.is_local else self._worker_manager,
|
| 849 |
+
)
|
| 850 |
+
]
|
| 851 |
+
|
| 852 |
+
def foreach_learner(
|
| 853 |
+
self,
|
| 854 |
+
func: Callable[[Learner, Optional[Any]], T],
|
| 855 |
+
*,
|
| 856 |
+
healthy_only: bool = True,
|
| 857 |
+
remote_actor_ids: List[int] = None,
|
| 858 |
+
timeout_seconds: Optional[float] = None,
|
| 859 |
+
return_obj_refs: bool = False,
|
| 860 |
+
mark_healthy: bool = False,
|
| 861 |
+
**kwargs,
|
| 862 |
+
) -> RemoteCallResults:
|
| 863 |
+
"""Calls the given function on each Learner L with the args: (L, \*\*kwargs).
|
| 864 |
+
|
| 865 |
+
Args:
|
| 866 |
+
func: The function to call on each Learner L with args: (L, \*\*kwargs).
|
| 867 |
+
healthy_only: If True, applies `func` only to Learner actors currently
|
| 868 |
+
tagged "healthy", otherwise to all actors. If `healthy_only=False` and
|
| 869 |
+
`mark_healthy=True`, will send `func` to all actors and mark those
|
| 870 |
+
actors "healthy" that respond to the request within `timeout_seconds`
|
| 871 |
+
and are currently tagged as "unhealthy".
|
| 872 |
+
remote_actor_ids: Apply func on a selected set of remote actors. Use None
|
| 873 |
+
(default) for all actors.
|
| 874 |
+
timeout_seconds: Time to wait (in seconds) for results. Set this to 0.0 for
|
| 875 |
+
fire-and-forget. Set this to None (default) to wait infinitely (i.e. for
|
| 876 |
+
synchronous execution).
|
| 877 |
+
return_obj_refs: whether to return ObjectRef instead of actual results.
|
| 878 |
+
Note, for fault tolerance reasons, these returned ObjectRefs should
|
| 879 |
+
never be resolved with ray.get() outside of the context of this manager.
|
| 880 |
+
mark_healthy: Whether to mark all those actors healthy again that are
|
| 881 |
+
currently marked unhealthy AND that returned results from the remote
|
| 882 |
+
call (within the given `timeout_seconds`).
|
| 883 |
+
Note that actors are NOT set unhealthy, if they simply time out
|
| 884 |
+
(only if they return a RayActorError).
|
| 885 |
+
Also not that this setting is ignored if `healthy_only=True` (b/c this
|
| 886 |
+
setting only affects actors that are currently tagged as unhealthy).
|
| 887 |
+
|
| 888 |
+
Returns:
|
| 889 |
+
A list of size len(Learners) with the return values of all calls to `func`.
|
| 890 |
+
"""
|
| 891 |
+
if self.is_local:
|
| 892 |
+
results = RemoteCallResults()
|
| 893 |
+
results.add_result(
|
| 894 |
+
None,
|
| 895 |
+
ResultOrError(result=func(self._learner, **kwargs)),
|
| 896 |
+
None,
|
| 897 |
+
)
|
| 898 |
+
return results
|
| 899 |
+
|
| 900 |
+
return self._worker_manager.foreach_actor(
|
| 901 |
+
func=partial(func, **kwargs),
|
| 902 |
+
healthy_only=healthy_only,
|
| 903 |
+
remote_actor_ids=remote_actor_ids,
|
| 904 |
+
timeout_seconds=timeout_seconds,
|
| 905 |
+
return_obj_refs=return_obj_refs,
|
| 906 |
+
mark_healthy=mark_healthy,
|
| 907 |
+
)
|
| 908 |
+
|
| 909 |
+
def shutdown(self):
|
| 910 |
+
"""Shuts down the LearnerGroup."""
|
| 911 |
+
if self.is_remote and hasattr(self, "_backend_executor"):
|
| 912 |
+
self._backend_executor.shutdown()
|
| 913 |
+
self._is_shut_down = True
|
| 914 |
+
|
| 915 |
+
def __del__(self):
|
| 916 |
+
if not self._is_shut_down:
|
| 917 |
+
self.shutdown()
|
| 918 |
+
|
| 919 |
+
@staticmethod
|
| 920 |
+
def _compute_num_total_minibatches(
|
| 921 |
+
episodes,
|
| 922 |
+
num_shards,
|
| 923 |
+
minibatch_size,
|
| 924 |
+
num_epochs,
|
| 925 |
+
):
|
| 926 |
+
# Count total number of timesteps per module ID.
|
| 927 |
+
if isinstance(episodes[0], MultiAgentEpisode):
|
| 928 |
+
per_mod_ts = defaultdict(int)
|
| 929 |
+
for ma_episode in episodes:
|
| 930 |
+
for sa_episode in ma_episode.agent_episodes.values():
|
| 931 |
+
per_mod_ts[sa_episode.module_id] += len(sa_episode)
|
| 932 |
+
max_ts = max(per_mod_ts.values())
|
| 933 |
+
else:
|
| 934 |
+
max_ts = sum(map(len, episodes))
|
| 935 |
+
|
| 936 |
+
return int((num_epochs * max_ts) / (num_shards * minibatch_size))
|
| 937 |
+
|
| 938 |
+
@Deprecated(new="LearnerGroup.update_from_batch(async=False)", error=False)
|
| 939 |
+
def update(self, *args, **kwargs):
|
| 940 |
+
# Just in case, we would like to revert this API retirement, we can do so
|
| 941 |
+
# easily.
|
| 942 |
+
return self._update(*args, **kwargs, async_update=False)
|
| 943 |
+
|
| 944 |
+
@Deprecated(new="LearnerGroup.update_from_batch(async=True)", error=False)
|
| 945 |
+
def async_update(self, *args, **kwargs):
|
| 946 |
+
# Just in case, we would like to revert this API retirement, we can do so
|
| 947 |
+
# easily.
|
| 948 |
+
return self._update(*args, **kwargs, async_update=True)
|
| 949 |
+
|
| 950 |
+
@Deprecated(new="LearnerGroup.save_to_path(...)", error=True)
|
| 951 |
+
def save_state(self, *args, **kwargs):
|
| 952 |
+
pass
|
| 953 |
+
|
| 954 |
+
@Deprecated(new="LearnerGroup.restore_from_path(...)", error=True)
|
| 955 |
+
def load_state(self, *args, **kwargs):
|
| 956 |
+
pass
|
| 957 |
+
|
| 958 |
+
@Deprecated(new="LearnerGroup.load_from_path(path=..., component=...)", error=False)
|
| 959 |
+
def load_module_state(
|
| 960 |
+
self,
|
| 961 |
+
*,
|
| 962 |
+
multi_rl_module_ckpt_dir: Optional[str] = None,
|
| 963 |
+
modules_to_load: Optional[Set[str]] = None,
|
| 964 |
+
rl_module_ckpt_dirs: Optional[Dict[ModuleID, str]] = None,
|
| 965 |
+
) -> None:
|
| 966 |
+
"""Load the checkpoints of the modules being trained by this LearnerGroup.
|
| 967 |
+
|
| 968 |
+
`load_module_state` can be used 3 ways:
|
| 969 |
+
1. Load a checkpoint for the MultiRLModule being trained by this
|
| 970 |
+
LearnerGroup. Limit the modules that are loaded from the checkpoint
|
| 971 |
+
by specifying the `modules_to_load` argument.
|
| 972 |
+
2. Load the checkpoint(s) for single agent RLModules that
|
| 973 |
+
are in the MultiRLModule being trained by this LearnerGroup.
|
| 974 |
+
3. Load a checkpoint for the MultiRLModule being trained by this
|
| 975 |
+
LearnerGroup and load the checkpoint(s) for single agent RLModules
|
| 976 |
+
that are in the MultiRLModule. The checkpoints for the single
|
| 977 |
+
agent RLModules take precedence over the module states in the
|
| 978 |
+
MultiRLModule checkpoint.
|
| 979 |
+
|
| 980 |
+
NOTE: At lease one of multi_rl_module_ckpt_dir or rl_module_ckpt_dirs is
|
| 981 |
+
must be specified. modules_to_load can only be specified if
|
| 982 |
+
multi_rl_module_ckpt_dir is specified.
|
| 983 |
+
|
| 984 |
+
Args:
|
| 985 |
+
multi_rl_module_ckpt_dir: The path to the checkpoint for the
|
| 986 |
+
MultiRLModule.
|
| 987 |
+
modules_to_load: A set of module ids to load from the checkpoint.
|
| 988 |
+
rl_module_ckpt_dirs: A mapping from module ids to the path to a
|
| 989 |
+
checkpoint for a single agent RLModule.
|
| 990 |
+
"""
|
| 991 |
+
if not (multi_rl_module_ckpt_dir or rl_module_ckpt_dirs):
|
| 992 |
+
raise ValueError(
|
| 993 |
+
"At least one of `multi_rl_module_ckpt_dir` or "
|
| 994 |
+
"`rl_module_ckpt_dirs` must be provided!"
|
| 995 |
+
)
|
| 996 |
+
if multi_rl_module_ckpt_dir:
|
| 997 |
+
multi_rl_module_ckpt_dir = pathlib.Path(multi_rl_module_ckpt_dir)
|
| 998 |
+
if rl_module_ckpt_dirs:
|
| 999 |
+
for module_id, path in rl_module_ckpt_dirs.items():
|
| 1000 |
+
rl_module_ckpt_dirs[module_id] = pathlib.Path(path)
|
| 1001 |
+
|
| 1002 |
+
# MultiRLModule checkpoint is provided.
|
| 1003 |
+
if multi_rl_module_ckpt_dir:
|
| 1004 |
+
# Restore the entire MultiRLModule state.
|
| 1005 |
+
if modules_to_load is None:
|
| 1006 |
+
self.restore_from_path(
|
| 1007 |
+
multi_rl_module_ckpt_dir,
|
| 1008 |
+
component=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE,
|
| 1009 |
+
)
|
| 1010 |
+
# Restore individual module IDs.
|
| 1011 |
+
else:
|
| 1012 |
+
for module_id in modules_to_load:
|
| 1013 |
+
self.restore_from_path(
|
| 1014 |
+
multi_rl_module_ckpt_dir / module_id,
|
| 1015 |
+
component=(
|
| 1016 |
+
COMPONENT_LEARNER
|
| 1017 |
+
+ "/"
|
| 1018 |
+
+ COMPONENT_RL_MODULE
|
| 1019 |
+
+ "/"
|
| 1020 |
+
+ module_id
|
| 1021 |
+
),
|
| 1022 |
+
)
|
| 1023 |
+
if rl_module_ckpt_dirs:
|
| 1024 |
+
for module_id, path in rl_module_ckpt_dirs.items():
|
| 1025 |
+
self.restore_from_path(
|
| 1026 |
+
path,
|
| 1027 |
+
component=(
|
| 1028 |
+
COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE + "/" + module_id
|
| 1029 |
+
),
|
| 1030 |
+
)
|
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (198 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__pycache__/tf_learner.cpython-311.pyc
ADDED
|
Binary file (18.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/tf_learner.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import pathlib
|
| 3 |
+
from typing import (
|
| 4 |
+
Any,
|
| 5 |
+
Callable,
|
| 6 |
+
Dict,
|
| 7 |
+
Hashable,
|
| 8 |
+
Sequence,
|
| 9 |
+
Tuple,
|
| 10 |
+
TYPE_CHECKING,
|
| 11 |
+
Union,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
from ray.rllib.core.learner.learner import Learner
|
| 15 |
+
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
|
| 16 |
+
from ray.rllib.core.rl_module.rl_module import (
|
| 17 |
+
RLModule,
|
| 18 |
+
RLModuleSpec,
|
| 19 |
+
)
|
| 20 |
+
from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule
|
| 21 |
+
from ray.rllib.policy.eager_tf_policy import _convert_to_tf
|
| 22 |
+
from ray.rllib.policy.sample_batch import MultiAgentBatch
|
| 23 |
+
from ray.rllib.utils.annotations import (
|
| 24 |
+
override,
|
| 25 |
+
OverrideToImplementCustomLogic,
|
| 26 |
+
)
|
| 27 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 28 |
+
from ray.rllib.utils.typing import (
|
| 29 |
+
ModuleID,
|
| 30 |
+
Optimizer,
|
| 31 |
+
Param,
|
| 32 |
+
ParamDict,
|
| 33 |
+
StateDict,
|
| 34 |
+
TensorType,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
if TYPE_CHECKING:
|
| 38 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 39 |
+
|
| 40 |
+
tf1, tf, tfv = try_import_tf()
|
| 41 |
+
|
| 42 |
+
logger = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class TfLearner(Learner):
|
| 46 |
+
|
| 47 |
+
framework: str = "tf2"
|
| 48 |
+
|
| 49 |
+
def __init__(self, **kwargs):
|
| 50 |
+
# by default in rllib we disable tf2 behavior
|
| 51 |
+
# This call re-enables it as it is needed for using
|
| 52 |
+
# this class.
|
| 53 |
+
try:
|
| 54 |
+
tf1.enable_v2_behavior()
|
| 55 |
+
except ValueError:
|
| 56 |
+
# This is a hack to avoid the error that happens when calling
|
| 57 |
+
# enable_v2_behavior after variables have already been created.
|
| 58 |
+
pass
|
| 59 |
+
|
| 60 |
+
super().__init__(**kwargs)
|
| 61 |
+
|
| 62 |
+
self._enable_tf_function = self.config.eager_tracing
|
| 63 |
+
|
| 64 |
+
# This is a placeholder which will be filled by
|
| 65 |
+
# `_make_distributed_strategy_if_necessary`.
|
| 66 |
+
self._strategy: tf.distribute.Strategy = None
|
| 67 |
+
|
| 68 |
+
@OverrideToImplementCustomLogic
|
| 69 |
+
@override(Learner)
|
| 70 |
+
def configure_optimizers_for_module(
|
| 71 |
+
self, module_id: ModuleID, config: "AlgorithmConfig" = None
|
| 72 |
+
) -> None:
|
| 73 |
+
module = self._module[module_id]
|
| 74 |
+
|
| 75 |
+
# For this default implementation, the learning rate is handled by the
|
| 76 |
+
# attached lr Scheduler (controlled by self.config.lr, which can be a
|
| 77 |
+
# fixed value or a schedule setting).
|
| 78 |
+
optimizer = tf.keras.optimizers.Adam()
|
| 79 |
+
params = self.get_parameters(module)
|
| 80 |
+
|
| 81 |
+
# This isn't strictly necessary, but makes it so that if a checkpoint is
|
| 82 |
+
# computed before training actually starts, then it will be the same in
|
| 83 |
+
# shape / size as a checkpoint after training starts.
|
| 84 |
+
optimizer.build(module.trainable_variables)
|
| 85 |
+
|
| 86 |
+
# Register the created optimizer (under the default optimizer name).
|
| 87 |
+
self.register_optimizer(
|
| 88 |
+
module_id=module_id,
|
| 89 |
+
optimizer=optimizer,
|
| 90 |
+
params=params,
|
| 91 |
+
lr_or_lr_schedule=config.lr,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
@override(Learner)
|
| 95 |
+
def compute_gradients(
|
| 96 |
+
self,
|
| 97 |
+
loss_per_module: Dict[str, TensorType],
|
| 98 |
+
gradient_tape: "tf.GradientTape",
|
| 99 |
+
**kwargs,
|
| 100 |
+
) -> ParamDict:
|
| 101 |
+
total_loss = sum(loss_per_module.values())
|
| 102 |
+
grads = gradient_tape.gradient(total_loss, self._params)
|
| 103 |
+
return grads
|
| 104 |
+
|
| 105 |
+
@override(Learner)
|
| 106 |
+
def apply_gradients(self, gradients_dict: ParamDict) -> None:
|
| 107 |
+
# TODO (Avnishn, kourosh): apply gradients doesn't work in cases where
|
| 108 |
+
# only some agents have a sample batch that is passed but not others.
|
| 109 |
+
# This is probably because of the way that we are iterating over the
|
| 110 |
+
# parameters in the optim_to_param_dictionary.
|
| 111 |
+
for optimizer in self._optimizer_parameters:
|
| 112 |
+
optim_grad_dict = self.filter_param_dict_for_optimizer(
|
| 113 |
+
optimizer=optimizer, param_dict=gradients_dict
|
| 114 |
+
)
|
| 115 |
+
variable_list = []
|
| 116 |
+
gradient_list = []
|
| 117 |
+
for param_ref, grad in optim_grad_dict.items():
|
| 118 |
+
if grad is not None:
|
| 119 |
+
variable_list.append(self._params[param_ref])
|
| 120 |
+
gradient_list.append(grad)
|
| 121 |
+
optimizer.apply_gradients(zip(gradient_list, variable_list))
|
| 122 |
+
|
| 123 |
+
@override(Learner)
|
| 124 |
+
def restore_from_path(self, path: Union[str, pathlib.Path]) -> None:
|
| 125 |
+
# This operation is potentially very costly because a MultiRLModule is created
|
| 126 |
+
# at build time, destroyed, and then a new one is created from a checkpoint.
|
| 127 |
+
# However, it is necessary due to complications with the way that Ray Tune
|
| 128 |
+
# restores failed trials. When Tune restores a failed trial, it reconstructs the
|
| 129 |
+
# entire experiment from the initial config. Therefore, to reflect any changes
|
| 130 |
+
# made to the learner's modules, the module created by Tune is destroyed and
|
| 131 |
+
# then rebuilt from the checkpoint.
|
| 132 |
+
with self._strategy.scope():
|
| 133 |
+
super().restore_from_path(path)
|
| 134 |
+
|
| 135 |
+
@override(Learner)
|
| 136 |
+
def _get_optimizer_state(self) -> StateDict:
|
| 137 |
+
optim_state = {}
|
| 138 |
+
with tf.init_scope():
|
| 139 |
+
for name, optim in self._named_optimizers.items():
|
| 140 |
+
optim_state[name] = [var.numpy() for var in optim.variables()]
|
| 141 |
+
return optim_state
|
| 142 |
+
|
| 143 |
+
@override(Learner)
|
| 144 |
+
def _set_optimizer_state(self, state: StateDict) -> None:
|
| 145 |
+
for name, state_array in state.items():
|
| 146 |
+
if name not in self._named_optimizers:
|
| 147 |
+
raise ValueError(
|
| 148 |
+
f"Optimizer {name} in `state` is not known! "
|
| 149 |
+
f"Known optimizers are {self._named_optimizers.keys()}"
|
| 150 |
+
)
|
| 151 |
+
optim = self._named_optimizers[name]
|
| 152 |
+
optim.set_weights(state_array)
|
| 153 |
+
|
| 154 |
+
@override(Learner)
|
| 155 |
+
def get_param_ref(self, param: Param) -> Hashable:
|
| 156 |
+
return param.ref()
|
| 157 |
+
|
| 158 |
+
@override(Learner)
|
| 159 |
+
def get_parameters(self, module: RLModule) -> Sequence[Param]:
|
| 160 |
+
return list(module.trainable_variables)
|
| 161 |
+
|
| 162 |
+
@override(Learner)
|
| 163 |
+
def rl_module_is_compatible(self, module: RLModule) -> bool:
|
| 164 |
+
return isinstance(module, TfRLModule)
|
| 165 |
+
|
| 166 |
+
@override(Learner)
|
| 167 |
+
def _check_registered_optimizer(
|
| 168 |
+
self,
|
| 169 |
+
optimizer: Optimizer,
|
| 170 |
+
params: Sequence[Param],
|
| 171 |
+
) -> None:
|
| 172 |
+
super()._check_registered_optimizer(optimizer, params)
|
| 173 |
+
if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
|
| 174 |
+
raise ValueError(
|
| 175 |
+
f"The optimizer ({optimizer}) is not a tf keras optimizer! "
|
| 176 |
+
"Only use tf.keras.optimizers.Optimizer subclasses for TfLearner."
|
| 177 |
+
)
|
| 178 |
+
for param in params:
|
| 179 |
+
if not isinstance(param, tf.Variable):
|
| 180 |
+
raise ValueError(
|
| 181 |
+
f"One of the parameters ({param}) in the registered optimizer "
|
| 182 |
+
"is not a tf.Variable!"
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
@override(Learner)
|
| 186 |
+
def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch:
|
| 187 |
+
batch = _convert_to_tf(batch.policy_batches)
|
| 188 |
+
length = max(len(b) for b in batch.values())
|
| 189 |
+
batch = MultiAgentBatch(batch, env_steps=length)
|
| 190 |
+
return batch
|
| 191 |
+
|
| 192 |
+
@override(Learner)
|
| 193 |
+
def add_module(
|
| 194 |
+
self,
|
| 195 |
+
*,
|
| 196 |
+
module_id: ModuleID,
|
| 197 |
+
module_spec: RLModuleSpec,
|
| 198 |
+
) -> None:
|
| 199 |
+
# TODO(Avnishn):
|
| 200 |
+
# WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead
|
| 201 |
+
# currently. We will be working on improving this in the future, but for now
|
| 202 |
+
# please wrap `call_for_each_replica` or `experimental_run` or `run` inside a
|
| 203 |
+
# tf.function to get the best performance.
|
| 204 |
+
# I get this warning any time I add a new module. I see the warning a few times
|
| 205 |
+
# and then it disappears. I think that I will need to open an issue with the TF
|
| 206 |
+
# team.
|
| 207 |
+
with self._strategy.scope():
|
| 208 |
+
super().add_module(
|
| 209 |
+
module_id=module_id,
|
| 210 |
+
module_spec=module_spec,
|
| 211 |
+
)
|
| 212 |
+
if self._enable_tf_function:
|
| 213 |
+
self._possibly_traced_update = tf.function(
|
| 214 |
+
self._untraced_update, reduce_retracing=True
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
@override(Learner)
|
| 218 |
+
def remove_module(self, module_id: ModuleID, **kwargs) -> MultiRLModuleSpec:
|
| 219 |
+
with self._strategy.scope():
|
| 220 |
+
marl_spec = super().remove_module(module_id, **kwargs)
|
| 221 |
+
|
| 222 |
+
if self._enable_tf_function:
|
| 223 |
+
self._possibly_traced_update = tf.function(
|
| 224 |
+
self._untraced_update, reduce_retracing=True
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
return marl_spec
|
| 228 |
+
|
| 229 |
+
def _make_distributed_strategy_if_necessary(self) -> "tf.distribute.Strategy":
|
| 230 |
+
"""Create a distributed strategy for the learner.
|
| 231 |
+
|
| 232 |
+
A stratgey is a tensorflow object that is used for distributing training and
|
| 233 |
+
gradient computation across multiple devices. By default, a no-op strategy is
|
| 234 |
+
used that is not distributed.
|
| 235 |
+
|
| 236 |
+
Returns:
|
| 237 |
+
A strategy for the learner to use for distributed training.
|
| 238 |
+
|
| 239 |
+
"""
|
| 240 |
+
if self.config.num_learners > 1:
|
| 241 |
+
strategy = tf.distribute.MultiWorkerMirroredStrategy()
|
| 242 |
+
elif self.config.num_gpus_per_learner > 0:
|
| 243 |
+
# mirrored strategy is typically used for multi-gpu training
|
| 244 |
+
# on a single machine, however we can use it for single-gpu
|
| 245 |
+
devices = tf.config.list_logical_devices("GPU")
|
| 246 |
+
assert self.config.local_gpu_idx < len(devices), (
|
| 247 |
+
f"local_gpu_idx {self.config.local_gpu_idx} is not a valid GPU id or "
|
| 248 |
+
"is not available."
|
| 249 |
+
)
|
| 250 |
+
local_gpu = [devices[self.config.local_gpu_idx].name]
|
| 251 |
+
strategy = tf.distribute.MirroredStrategy(devices=local_gpu)
|
| 252 |
+
else:
|
| 253 |
+
# the default strategy is a no-op that can be used in the local mode
|
| 254 |
+
# cpu only case, build will override this if needed.
|
| 255 |
+
strategy = tf.distribute.get_strategy()
|
| 256 |
+
return strategy
|
| 257 |
+
|
| 258 |
+
@override(Learner)
|
| 259 |
+
def build(self) -> None:
|
| 260 |
+
"""Build the TfLearner.
|
| 261 |
+
|
| 262 |
+
This method is specific TfLearner. Before running super() it sets the correct
|
| 263 |
+
distributing strategy with the right device, so that computational graph is
|
| 264 |
+
placed on the correct device. After running super(), depending on eager_tracing
|
| 265 |
+
flag it will decide whether to wrap the update function with tf.function or not.
|
| 266 |
+
"""
|
| 267 |
+
|
| 268 |
+
# we call build anytime we make a learner, or load a learner from a checkpoint.
|
| 269 |
+
# we can't make a new strategy every time we build, so we only make one the
|
| 270 |
+
# first time build is called.
|
| 271 |
+
if not self._strategy:
|
| 272 |
+
self._strategy = self._make_distributed_strategy_if_necessary()
|
| 273 |
+
|
| 274 |
+
with self._strategy.scope():
|
| 275 |
+
super().build()
|
| 276 |
+
|
| 277 |
+
if self._enable_tf_function:
|
| 278 |
+
self._possibly_traced_update = tf.function(
|
| 279 |
+
self._untraced_update, reduce_retracing=True
|
| 280 |
+
)
|
| 281 |
+
else:
|
| 282 |
+
self._possibly_traced_update = self._untraced_update
|
| 283 |
+
|
| 284 |
+
@override(Learner)
|
| 285 |
+
def _update(self, batch: Dict) -> Tuple[Any, Any, Any]:
|
| 286 |
+
return self._possibly_traced_update(batch)
|
| 287 |
+
|
| 288 |
+
def _untraced_update(
|
| 289 |
+
self,
|
| 290 |
+
batch: Dict,
|
| 291 |
+
# TODO: Figure out, why _ray_trace_ctx=None helps to prevent a crash in
|
| 292 |
+
# eager_tracing=True mode.
|
| 293 |
+
# It seems there may be a clash between the traced-by-tf function and the
|
| 294 |
+
# traced-by-ray functions (for making the TfLearner class a ray actor).
|
| 295 |
+
_ray_trace_ctx=None,
|
| 296 |
+
):
|
| 297 |
+
# Activate tensor-mode on our MetricsLogger.
|
| 298 |
+
self.metrics.activate_tensor_mode()
|
| 299 |
+
|
| 300 |
+
def helper(_batch):
|
| 301 |
+
with tf.GradientTape(persistent=True) as tape:
|
| 302 |
+
fwd_out = self._module.forward_train(_batch)
|
| 303 |
+
loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=_batch)
|
| 304 |
+
gradients = self.compute_gradients(loss_per_module, gradient_tape=tape)
|
| 305 |
+
del tape
|
| 306 |
+
postprocessed_gradients = self.postprocess_gradients(gradients)
|
| 307 |
+
self.apply_gradients(postprocessed_gradients)
|
| 308 |
+
|
| 309 |
+
# Deactivate tensor-mode on our MetricsLogger and collect the (tensor)
|
| 310 |
+
# results.
|
| 311 |
+
return fwd_out, loss_per_module, self.metrics.deactivate_tensor_mode()
|
| 312 |
+
|
| 313 |
+
return self._strategy.run(helper, args=(batch,))
|
| 314 |
+
|
| 315 |
+
@override(Learner)
|
| 316 |
+
def _get_tensor_variable(self, value, dtype=None, trainable=False) -> "tf.Tensor":
|
| 317 |
+
return tf.Variable(
|
| 318 |
+
value,
|
| 319 |
+
trainable=trainable,
|
| 320 |
+
dtype=(
|
| 321 |
+
dtype
|
| 322 |
+
or (
|
| 323 |
+
tf.float32
|
| 324 |
+
if isinstance(value, float)
|
| 325 |
+
else tf.int32
|
| 326 |
+
if isinstance(value, int)
|
| 327 |
+
else None
|
| 328 |
+
)
|
| 329 |
+
),
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
@staticmethod
|
| 333 |
+
@override(Learner)
|
| 334 |
+
def _get_optimizer_lr(optimizer: "tf.Optimizer") -> float:
|
| 335 |
+
return optimizer.lr
|
| 336 |
+
|
| 337 |
+
@staticmethod
|
| 338 |
+
@override(Learner)
|
| 339 |
+
def _set_optimizer_lr(optimizer: "tf.Optimizer", lr: float) -> None:
|
| 340 |
+
# When tf creates the optimizer, it seems to detach the optimizer's lr value
|
| 341 |
+
# from the given tf variable.
|
| 342 |
+
# Thus, updating this variable is NOT sufficient to update the actual
|
| 343 |
+
# optimizer's learning rate, so we have to explicitly set it here inside the
|
| 344 |
+
# optimizer object.
|
| 345 |
+
optimizer.lr = lr
|
| 346 |
+
|
| 347 |
+
@staticmethod
|
| 348 |
+
@override(Learner)
|
| 349 |
+
def _get_clip_function() -> Callable:
|
| 350 |
+
from ray.rllib.utils.tf_utils import clip_gradients
|
| 351 |
+
|
| 352 |
+
return clip_gradients
|
| 353 |
+
|
| 354 |
+
@staticmethod
|
| 355 |
+
@override(Learner)
|
| 356 |
+
def _get_global_norm_function() -> Callable:
|
| 357 |
+
return tf.linalg.global_norm
|
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (201 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__pycache__/torch_learner.cpython-311.pyc
ADDED
|
Binary file (31.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/torch_learner.py
ADDED
|
@@ -0,0 +1,664 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
import logging
|
| 3 |
+
from typing import (
|
| 4 |
+
Any,
|
| 5 |
+
Callable,
|
| 6 |
+
Dict,
|
| 7 |
+
Hashable,
|
| 8 |
+
Optional,
|
| 9 |
+
Sequence,
|
| 10 |
+
Tuple,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
from ray.rllib.algorithms.algorithm_config import (
|
| 14 |
+
AlgorithmConfig,
|
| 15 |
+
TorchCompileWhatToCompile,
|
| 16 |
+
)
|
| 17 |
+
from ray.rllib.core.columns import Columns
|
| 18 |
+
from ray.rllib.core.learner.learner import Learner, LR_KEY
|
| 19 |
+
from ray.rllib.core.rl_module.multi_rl_module import (
|
| 20 |
+
MultiRLModule,
|
| 21 |
+
MultiRLModuleSpec,
|
| 22 |
+
)
|
| 23 |
+
from ray.rllib.core.rl_module.rl_module import (
|
| 24 |
+
RLModule,
|
| 25 |
+
RLModuleSpec,
|
| 26 |
+
)
|
| 27 |
+
from ray.rllib.core.rl_module.torch.torch_rl_module import (
|
| 28 |
+
TorchCompileConfig,
|
| 29 |
+
TorchDDPRLModule,
|
| 30 |
+
TorchRLModule,
|
| 31 |
+
)
|
| 32 |
+
from ray.rllib.policy.sample_batch import MultiAgentBatch
|
| 33 |
+
from ray.rllib.utils.annotations import (
|
| 34 |
+
override,
|
| 35 |
+
OverrideToImplementCustomLogic,
|
| 36 |
+
OverrideToImplementCustomLogic_CallToSuperRecommended,
|
| 37 |
+
)
|
| 38 |
+
from ray.rllib.utils.framework import get_device, try_import_torch
|
| 39 |
+
from ray.rllib.utils.metrics import (
|
| 40 |
+
ALL_MODULES,
|
| 41 |
+
DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
|
| 42 |
+
NUM_TRAINABLE_PARAMETERS,
|
| 43 |
+
NUM_NON_TRAINABLE_PARAMETERS,
|
| 44 |
+
WEIGHTS_SEQ_NO,
|
| 45 |
+
)
|
| 46 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 47 |
+
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
|
| 48 |
+
from ray.rllib.utils.typing import (
|
| 49 |
+
ModuleID,
|
| 50 |
+
Optimizer,
|
| 51 |
+
Param,
|
| 52 |
+
ParamDict,
|
| 53 |
+
ShouldModuleBeUpdatedFn,
|
| 54 |
+
StateDict,
|
| 55 |
+
TensorType,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
torch, nn = try_import_torch()
|
| 59 |
+
logger = logging.getLogger(__name__)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class TorchLearner(Learner):
    """PyTorch-specific implementation of the :py:class:`Learner` API.

    Adds to the framework-agnostic base class:
    - device resolution and placement of the RLModule(s),
    - optional ``torch.compile`` of either the RLModule's ``forward_train`` or
      of the entire update step (experimental),
    - DDP-wrapping of (sub-)modules when running with >1 Learner workers,
    - optional mixed-precision grad scalers and torch LR schedulers.
    """

    framework: str = "torch"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Whether to compile the RL Module of this learner. This implies that the
        # forward_train method of the RL Module will be compiled. Furthermore,
        # other forward methods of the RL Module will be compiled on demand.
        # This is assumed to not happen, since other forward methods are not
        # expected to be used during training.
        self._torch_compile_forward_train = False
        self._torch_compile_cfg = None
        # Whether to compile the `_uncompiled_update` method of this learner. This
        # implies that everything within `_uncompiled_update` will be compiled,
        # not only the forward_train method of the RL Module.
        # Note that this is experimental.
        # Note that this requires recompiling the forward methods once we add/remove
        # RL Modules.
        self._torch_compile_complete_update = False
        if self.config.torch_compile_learner:
            if (
                self.config.torch_compile_learner_what_to_compile
                == TorchCompileWhatToCompile.COMPLETE_UPDATE
            ):
                self._torch_compile_complete_update = True
                self._compiled_update_initialized = False
            else:
                self._torch_compile_forward_train = True

            self._torch_compile_cfg = TorchCompileConfig(
                torch_dynamo_backend=self.config.torch_compile_learner_dynamo_backend,
                torch_dynamo_mode=self.config.torch_compile_learner_dynamo_mode,
            )

        # Loss scalers for mixed precision training. Map optimizer names to
        # associated torch GradScaler objects.
        self._grad_scalers = None
        if self.config._torch_grad_scaler_class:
            self._grad_scalers = defaultdict(
                lambda: self.config._torch_grad_scaler_class()
            )
        # Torch LR schedulers, lazily created in `apply_gradients`:
        # module_id -> optimizer name -> list of scheduler instances.
        self._lr_schedulers = {}
        self._lr_scheduler_classes = None
        if self.config._torch_lr_scheduler_classes:
            self._lr_scheduler_classes = self.config._torch_lr_scheduler_classes

    def _reset_and_compile_update(self) -> None:
        """Resets dynamo and (re)compiles `_uncompiled_update`.

        Must be called at build time and every time an RLModule is added or
        removed, because a previously compiled update graph bakes in the old
        module structure.
        """
        torch._dynamo.reset()
        self._compiled_update_initialized = False
        self._possibly_compiled_update = torch.compile(
            self._uncompiled_update,
            backend=self._torch_compile_cfg.torch_dynamo_backend,
            mode=self._torch_compile_cfg.torch_dynamo_mode,
            **self._torch_compile_cfg.kwargs,
        )

    @OverrideToImplementCustomLogic
    @override(Learner)
    def configure_optimizers_for_module(
        self,
        module_id: ModuleID,
        config: "AlgorithmConfig" = None,
    ) -> None:
        """Registers a default Adam optimizer for the given module.

        The learning rate is handled by the attached LR scheduler (controlled
        by `config.lr`, which can be a fixed value or a schedule setting).
        """
        module = self._module[module_id]

        params = self.get_parameters(module)
        optimizer = torch.optim.Adam(params)

        # Register the created optimizer (under the default optimizer name).
        self.register_optimizer(
            module_id=module_id,
            optimizer=optimizer,
            params=params,
            lr_or_lr_schedule=config.lr,
        )

    def _uncompiled_update(
        self,
        batch: Dict,
        **kwargs,
    ):
        """Performs a single update given a batch of data."""
        # Activate tensor-mode on our MetricsLogger.
        self.metrics.activate_tensor_mode()

        # TODO (sven): Causes weird cuda error when WandB is used.
        #  Diagnosis thus far:
        #  - All peek values during metrics.reduce are non-tensors.
        #  - However, in impala.py::training_step(), a tensor does arrive after
        #  learner group.update_from_episodes(), so somehow, there is still a race
        #  condition possible (learner, which performs the reduce() and learner
        #  thread, which performs the logging of tensors into metrics logger).
        self._compute_off_policyness(batch)

        fwd_out = self.module.forward_train(batch)
        loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=batch)

        gradients = self.compute_gradients(loss_per_module)
        postprocessed_gradients = self.postprocess_gradients(gradients)
        self.apply_gradients(postprocessed_gradients)

        # Deactivate tensor-mode on our MetricsLogger and collect the (tensor)
        # results.
        return fwd_out, loss_per_module, self.metrics.deactivate_tensor_mode()

    @override(Learner)
    def compute_gradients(
        self, loss_per_module: Dict[ModuleID, TensorType], **kwargs
    ) -> ParamDict:
        """Backprops the (scaled) total loss and returns the raw gradients."""
        for optim in self._optimizer_parameters:
            # `set_to_none=True` is a faster way to zero out the gradients.
            optim.zero_grad(set_to_none=True)

        # Under mixed precision, scale each module's loss before backprop.
        if self._grad_scalers is not None:
            total_loss = sum(
                self._grad_scalers[mid].scale(loss)
                for mid, loss in loss_per_module.items()
            )
        else:
            total_loss = sum(loss_per_module.values())

        total_loss.backward()
        grads = {pid: p.grad for pid, p in self._params.items()}

        return grads

    @override(Learner)
    def apply_gradients(self, gradients_dict: ParamDict) -> None:
        """Writes `gradients_dict` into the params and steps all optimizers."""
        # Set the gradient of the parameters.
        for pid, grad in gradients_dict.items():
            # If updates should not be skipped turn `nan` and `inf` gradients to zero.
            if (
                not torch.isfinite(grad).all()
                and not self.config.torch_skip_nan_gradients
            ):
                # Warn the user about `nan` gradients.
                logger.warning(f"Gradients {pid} contain `nan/inf` values.")
                # If updates should be skipped, do not step the optimizer and return.
                if not self.config.torch_skip_nan_gradients:
                    logger.warning(
                        "Setting `nan/inf` gradients to zero. If updates with "
                        "`nan/inf` gradients should not be set to zero and instead "
                        "the update be skipped entirely set `torch_skip_nan_gradients` "
                        "to `True`."
                    )
                # If necessary turn `nan` gradients to zero. Note this can corrupt the
                # internal state of the optimizer, if many `nan` gradients occur.
                self._params[pid].grad = torch.nan_to_num(grad)
            # Otherwise, use the gradient as is.
            else:
                self._params[pid].grad = grad

        # For each optimizer call its step function.
        for module_id, optimizer_names in self._module_optimizers.items():
            for optimizer_name in optimizer_names:
                optim = self.get_optimizer(module_id, optimizer_name)
                # If we have learning rate schedulers for a module add them, if
                # necessary.
                if self._lr_scheduler_classes is not None:
                    if (
                        module_id not in self._lr_schedulers
                        or optimizer_name not in self._lr_schedulers[module_id]
                    ):
                        # Set up schedulers for this module/optimizer pair. Use
                        # `setdefault` so schedulers already registered for a
                        # sibling optimizer of the same module are kept (a plain
                        # dict assignment here would clobber them).
                        self._lr_schedulers.setdefault(module_id, {})[
                            optimizer_name
                        ] = []
                        # If the classes are in a dictionary each module might have
                        # a different set of schedulers.
                        if isinstance(self._lr_scheduler_classes, dict):
                            scheduler_classes = self._lr_scheduler_classes[module_id]
                        # Else, each module has the same learning rate schedulers.
                        else:
                            scheduler_classes = self._lr_scheduler_classes
                        # Initialize and add the schedulers.
                        for scheduler_class in scheduler_classes:
                            self._lr_schedulers[module_id][optimizer_name].append(
                                scheduler_class(optim)
                            )

                # Step through the scaler (unscales gradients, if applicable).
                if self._grad_scalers is not None:
                    scaler = self._grad_scalers[module_id]
                    scaler.step(optim)
                    self.metrics.log_value(
                        (module_id, "_torch_grad_scaler_current_scale"),
                        scaler.get_scale(),
                        window=1,  # snapshot in time, no EMA/mean.
                    )
                    # Update the scaler.
                    scaler.update()
                # `step` the optimizer (default), but only if all gradients are finite.
                elif all(
                    param.grad is None or torch.isfinite(param.grad).all()
                    for group in optim.param_groups
                    for param in group["params"]
                ):
                    optim.step()
                # Otherwise (at least one non-None gradient is non-finite), warn
                # the user that the update will be skipped. Note: this must be a
                # plain `else` — re-evaluating `torch.isfinite(param.grad)` here
                # without the `is None` guard (as previously done) raises a
                # TypeError when any other param still has `grad=None`.
                else:
                    logger.warning(
                        "Skipping this update. If updates with `nan/inf` gradients "
                        "should not be skipped entirely and instead `nan/inf` "
                        "gradients set to `zero` set `torch_skip_nan_gradients` to "
                        "`False`."
                    )

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    @override(Learner)
    def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
        """Called after gradient-based updates are completed.

        Should be overridden to implement custom cleanup-, logging-, or non-gradient-
        based Learner/RLModule update logic after(!) gradient-based updates have been
        completed.

        Note, for `framework="torch"` users can register
        `torch.optim.lr_scheduler.LRScheduler` via
        `AlgorithmConfig._torch_lr_scheduler_classes`. These schedulers need to be
        stepped here after gradient updates and reported.

        Args:
            timesteps: Timesteps dict, which must have the key
                `NUM_ENV_STEPS_SAMPLED_LIFETIME`.
                # TODO (sven): Make this a more formal structure with its own type.
        """

        # If we have no `torch.optim.lr_scheduler.LRScheduler` registered call the
        # `super()`'s method to update RLlib's learning rate schedules.
        if not self._lr_schedulers:
            return super().after_gradient_based_update(timesteps=timesteps)

        # Only update this optimizer's lr, if a scheduler has been registered
        # along with it.
        for module_id, optimizer_names in self._module_optimizers.items():
            for optimizer_name in optimizer_names:
                # If learning rate schedulers are provided step them here. Note,
                # stepping them in `TorchLearner.apply_gradients` updates the
                # learning rates during minibatch updates; we want to update
                # between whole batch updates.
                if (
                    module_id in self._lr_schedulers
                    and optimizer_name in self._lr_schedulers[module_id]
                ):
                    for scheduler in self._lr_schedulers[module_id][optimizer_name]:
                        scheduler.step()
                optimizer = self.get_optimizer(module_id, optimizer_name)
                self.metrics.log_value(
                    # Cut out the module ID from the beginning since it's already
                    # part of the key sequence: (ModuleID, "[optim name]_lr").
                    key=(
                        module_id,
                        f"{optimizer_name[len(module_id) + 1:]}_{LR_KEY}",
                    ),
                    value=convert_to_numpy(self._get_optimizer_lr(optimizer)),
                    window=1,
                )

    @override(Learner)
    def _get_optimizer_state(self) -> StateDict:
        """Returns all optimizers' state dicts (as numpy), keyed by name."""
        ret = {}
        for name, optim in self._named_optimizers.items():
            ret[name] = {
                "module_id": self._optimizer_name_to_module[name],
                "state": convert_to_numpy(optim.state_dict()),
            }
        return ret

    @override(Learner)
    def _set_optimizer_state(self, state: StateDict) -> None:
        """Restores optimizer states, creating missing optimizers on the fly."""
        for name, state_dict in state.items():
            # Ignore updating optimizers matching to submodules not present in this
            # Learner's MultiRLModule.
            module_id = state_dict["module_id"]
            if name not in self._named_optimizers and module_id in self.module:
                self.configure_optimizers_for_module(
                    module_id=module_id,
                    config=self.config.get_config_for_module(module_id=module_id),
                )
            if name in self._named_optimizers:
                self._named_optimizers[name].load_state_dict(
                    convert_to_torch_tensor(state_dict["state"], device=self._device)
                )

    @override(Learner)
    def get_param_ref(self, param: Param) -> Hashable:
        # Torch tensors are hashable by identity -> use the param itself as ref.
        return param

    @override(Learner)
    def get_parameters(self, module: RLModule) -> Sequence[Param]:
        return list(module.parameters())

    @override(Learner)
    def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch:
        batch = convert_to_torch_tensor(batch.policy_batches, device=self._device)
        # TODO (sven): This computation of `env_steps` is not accurate!
        length = max(len(b) for b in batch.values())
        batch = MultiAgentBatch(batch, env_steps=length)
        return batch

    @override(Learner)
    def add_module(
        self,
        *,
        module_id: ModuleID,
        # TODO (sven): Rename to `rl_module_spec`.
        module_spec: RLModuleSpec,
        config_overrides: Optional[Dict] = None,
        new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
    ) -> MultiRLModuleSpec:
        # Call super's add_module method.
        marl_spec = super().add_module(
            module_id=module_id,
            module_spec=module_spec,
            config_overrides=config_overrides,
            new_should_module_be_updated=new_should_module_be_updated,
        )

        # We need to DDP-ify the module that was just added to the pool.
        module = self._module[module_id]

        if self._torch_compile_forward_train:
            module.compile(self._torch_compile_cfg)
        elif self._torch_compile_complete_update:
            # When compiling the update, we need to reset and recompile
            # _uncompiled_update every time we add/remove a module anew.
            self._reset_and_compile_update()

        if isinstance(module, TorchRLModule):
            self._module[module_id].to(self._device)
            if self.distributed:
                if (
                    self._torch_compile_complete_update
                    or self._torch_compile_forward_train
                ):
                    raise ValueError(
                        "Using torch distributed and torch compile "
                        "together has not been tested. Please disable "
                        "torch compile."
                    )
                self._module.add_module(
                    module_id,
                    TorchDDPRLModule(module, **self.config.torch_ddp_kwargs),
                    override=True,
                )

        self._log_trainable_parameters()

        return marl_spec

    @override(Learner)
    def remove_module(self, module_id: ModuleID, **kwargs) -> MultiRLModuleSpec:
        marl_spec = super().remove_module(module_id, **kwargs)

        if self._torch_compile_complete_update:
            # When compiling the update, we need to reset and recompile
            # _uncompiled_update every time we add/remove a module anew.
            self._reset_and_compile_update()

        self._log_trainable_parameters()

        return marl_spec

    @override(Learner)
    def build(self) -> None:
        """Builds the TorchLearner.

        This method is specific to TorchLearner. Before running super() it will
        initialize the device properly based on `self.config`, so that `_make_module()`
        can place the created module on the correct device. After running super() it
        wraps the module in a TorchDDPRLModule if `config.num_learners > 0`.
        Note, in inherited classes it is advisable to call the parent's `build()`
        after setting up all variables because `configure_optimizer_for_module` is
        called in this `Learner.build()`.
        """
        self._device = get_device(self.config, self.config.num_gpus_per_learner)

        super().build()

        if self._torch_compile_complete_update:
            self._reset_and_compile_update()
        else:
            if self._torch_compile_forward_train:
                if isinstance(self._module, TorchRLModule):
                    self._module.compile(self._torch_compile_cfg)
                elif isinstance(self._module, MultiRLModule):
                    for module in self._module._rl_modules.values():
                        # Compile only TorchRLModules, e.g. we don't want to compile
                        # a RandomRLModule. NOTE: This must check the submodule
                        # (`module`), not `self._module` (previously a bug that
                        # made this loop a no-op).
                        if isinstance(module, TorchRLModule):
                            module.compile(self._torch_compile_cfg)
                else:
                    raise ValueError(
                        "Torch compile is only supported for TorchRLModule and "
                        "MultiRLModule."
                    )

            self._possibly_compiled_update = self._uncompiled_update

        self._make_modules_ddp_if_necessary()

    @override(Learner)
    def _update(self, batch: Dict[str, Any]) -> Tuple[Any, Any, Any]:
        # The first time we call _update after building the learner or
        # adding/removing models, we update with the uncompiled update method.
        # This makes it so that any variables that may be created during the first
        # update step are already there when compiling. More specifically,
        # this avoids errors that occur around using defaultdicts with
        # torch.compile().
        if (
            self._torch_compile_complete_update
            and not self._compiled_update_initialized
        ):
            self._compiled_update_initialized = True
            return self._uncompiled_update(batch)
        else:
            return self._possibly_compiled_update(batch)

    @OverrideToImplementCustomLogic
    def _make_modules_ddp_if_necessary(self) -> None:
        """Default logic for (maybe) making all Modules within self._module DDP."""

        # If the module is a MultiRLModule and nn.Module we can simply assume
        # all the submodules are registered. Otherwise, we need to loop through
        # each submodule and move it to the correct device.
        # TODO (Kourosh): This can result in missing modules if the user does not
        #  register them in the MultiRLModule. We should find a better way to
        #  handle this.
        if self.config.num_learners > 1:
            # Single agent module: Convert to `TorchDDPRLModule`.
            if isinstance(self._module, TorchRLModule):
                self._module = TorchDDPRLModule(
                    self._module, **self.config.torch_ddp_kwargs
                )
            # Multi agent module: Convert each submodule to `TorchDDPRLModule`.
            else:
                assert isinstance(self._module, MultiRLModule)
                for key in self._module.keys():
                    sub_module = self._module[key]
                    if isinstance(sub_module, TorchRLModule):
                        # Wrap and override the module ID key in self._module.
                        self._module.add_module(
                            key,
                            TorchDDPRLModule(
                                sub_module, **self.config.torch_ddp_kwargs
                            ),
                            override=True,
                        )

    def rl_module_is_compatible(self, module: RLModule) -> bool:
        """A module is compatible with this Learner iff it is a torch nn.Module."""
        return isinstance(module, nn.Module)

    @override(Learner)
    def _check_registered_optimizer(
        self,
        optimizer: Optimizer,
        params: Sequence[Param],
    ) -> None:
        super()._check_registered_optimizer(optimizer, params)
        if not isinstance(optimizer, torch.optim.Optimizer):
            raise ValueError(
                f"The optimizer ({optimizer}) is not a torch.optim.Optimizer! "
                "Only use torch.optim.Optimizer subclasses for TorchLearner."
            )
        for param in params:
            if not isinstance(param, torch.Tensor):
                raise ValueError(
                    f"One of the parameters ({param}) in the registered optimizer "
                    "is not a torch.Tensor!"
                )

    @override(Learner)
    def _make_module(self) -> MultiRLModule:
        module = super()._make_module()
        self._map_module_to_device(module)
        return module

    def _map_module_to_device(self, module: MultiRLModule) -> None:
        """Moves the module (or each of its torch submodules) to `self._device`."""
        if isinstance(module, torch.nn.Module):
            module.to(self._device)
        else:
            for key in module.keys():
                if isinstance(module[key], torch.nn.Module):
                    module[key].to(self._device)

    @override(Learner)
    def _log_trainable_parameters(self) -> None:
        # Log number of non-trainable and trainable parameters of our RLModule.
        num_trainable_params = {
            (mid, NUM_TRAINABLE_PARAMETERS): sum(
                p.numel() for p in rlm.parameters() if p.requires_grad
            )
            for mid, rlm in self.module._rl_modules.items()
            if isinstance(rlm, TorchRLModule)
        }
        num_non_trainable_params = {
            (mid, NUM_NON_TRAINABLE_PARAMETERS): sum(
                p.numel() for p in rlm.parameters() if not p.requires_grad
            )
            for mid, rlm in self.module._rl_modules.items()
            if isinstance(rlm, TorchRLModule)
        }

        self.metrics.log_dict(
            {
                **{
                    (ALL_MODULES, NUM_TRAINABLE_PARAMETERS): sum(
                        num_trainable_params.values()
                    ),
                    (ALL_MODULES, NUM_NON_TRAINABLE_PARAMETERS): sum(
                        num_non_trainable_params.values()
                    ),
                },
                **num_trainable_params,
                **num_non_trainable_params,
            }
        )

    def _compute_off_policyness(self, batch):
        # Log off-policy'ness of this batch wrt the current weights.
        off_policyness = {
            (mid, DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY): (
                (self._weights_seq_no - module_batch[WEIGHTS_SEQ_NO]).float()
            )
            for mid, module_batch in batch.items()
            if WEIGHTS_SEQ_NO in module_batch
        }
        for key in off_policyness.keys():
            mid = key[0]
            if Columns.LOSS_MASK not in batch[mid]:
                off_policyness[key] = torch.mean(off_policyness[key])
            else:
                mask = batch[mid][Columns.LOSS_MASK]
                num_valid = torch.sum(mask)
                off_policyness[key] = torch.sum(off_policyness[key][mask]) / num_valid
        self.metrics.log_dict(off_policyness, window=1)

    @override(Learner)
    def _get_tensor_variable(
        self, value, dtype=None, trainable=False
    ) -> "torch.Tensor":
        """Creates a (possibly trainable) tensor on this Learner's device.

        If `dtype` is not given, floats map to float32, ints to int32, and
        anything else lets torch infer the dtype.
        """
        tensor = torch.tensor(
            value,
            requires_grad=trainable,
            device=self._device,
            dtype=(
                dtype
                or (
                    torch.float32
                    if isinstance(value, float)
                    else torch.int32
                    if isinstance(value, int)
                    else None
                )
            ),
        )
        return nn.Parameter(tensor) if trainable else tensor

    @staticmethod
    @override(Learner)
    def _get_optimizer_lr(optimizer: "torch.optim.Optimizer") -> float:
        # All param groups of our optimizers share one lr -> return the first.
        for g in optimizer.param_groups:
            return g["lr"]

    @staticmethod
    @override(Learner)
    def _set_optimizer_lr(optimizer: "torch.optim.Optimizer", lr: float) -> None:
        for g in optimizer.param_groups:
            g["lr"] = lr

    @staticmethod
    @override(Learner)
    def _get_clip_function() -> Callable:
        from ray.rllib.utils.torch_utils import clip_gradients

        return clip_gradients

    @staticmethod
    @override(Learner)
    def _get_global_norm_function() -> Callable:
        from ray.rllib.utils.torch_utils import compute_global_norm

        return compute_global_norm
.venv/lib/python3.11/site-packages/ray/rllib/core/learner/utils.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
|
| 3 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 4 |
+
from ray.rllib.utils.typing import NetworkType
|
| 5 |
+
from ray.util import PublicAPI
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
torch, _ = try_import_torch()
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def make_target_network(main_net: NetworkType) -> NetworkType:
    """Creates a (deep) copy of `main_net` (including synched weights) and returns it.

    Args:
        main_net: The main network to return a target network for

    Returns:
        The copy of `main_net` that can be used as a target net. Note that the weights
        of the returned net are already synched (identical) with `main_net`.
    """
    # A deep copy carries over all current weights, so the new target net starts
    # out fully synched with `main_net`.
    target_net = copy.deepcopy(main_net)

    # Target nets are never trained directly -> freeze all of their parameters.
    if not isinstance(main_net, torch.nn.Module):
        raise ValueError(f"Unsupported framework for given `main_net` {main_net}!")
    target_net.requires_grad_(False)

    return target_net
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@PublicAPI(stability="beta")
def update_target_network(
    *,
    main_net: NetworkType,
    target_net: NetworkType,
    tau: float,
) -> None:
    """Updates a target network (from a "main" network) using Polyak averaging.

    Thereby:
    new_target_net_weight = (
        tau * main_net_weight + (1.0 - tau) * current_target_net_weight
    )

    Args:
        main_net: The nn.Module to update from.
        target_net: The target network to update.
        tau: The tau value to use in the Polyak averaging formula. Use 1.0 for a
            complete sync of the weights (target and main net will be the exact same
            after updating).
    """
    # Only torch nets are supported here; bail out early for anything else.
    if not isinstance(main_net, torch.nn.Module):
        raise ValueError(f"Unsupported framework for given `main_net` {main_net}!")

    # Delegate the actual Polyak update to the torch-specific utility.
    from ray.rllib.utils.torch_utils import update_target_network as _update_target

    _update_target(main_net=main_net, target_net=target_net, tau=tau)
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (200 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/heads.cpython-311.pyc
ADDED
|
Binary file (9.66 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (4.04 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__init__.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
|
| 5 |
+
from ray.rllib.core.rl_module.multi_rl_module import (
|
| 6 |
+
MultiRLModule,
|
| 7 |
+
MultiRLModuleSpec,
|
| 8 |
+
)
|
| 9 |
+
from ray.util import log_once
|
| 10 |
+
from ray.util.annotations import DeveloperAPI
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger("ray.rllib")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@DeveloperAPI
|
| 16 |
+
def validate_module_id(policy_id: str, error: bool = False) -> None:
    """Makes sure the given `policy_id` is valid.

    Args:
        policy_id: The Policy ID to check.
            IMPORTANT: Must not contain characters that
            are also not allowed in Unix/Win filesystems, such as: `<>:"/\\|?*`
            or a dot `.` or space ` ` at the end of the ID.
        error: Whether to raise an error (ValueError) or a warning in case of an
            invalid `policy_id`.

    Raises:
        ValueError: If the given `policy_id` is not a valid one and `error` is True.
    """
    if (
        not isinstance(policy_id, str)
        or len(policy_id) == 0
        # Fix: the character class now includes `*`, which the docstring (and
        # Windows filename rules) disallow but the original regex omitted.
        or re.search('[<>:"/\\\\|?*]', policy_id)
        # IDs ending in a space or dot are invalid directory names on Windows.
        or policy_id[-1] in (" ", ".")
    ):
        msg = (
            f"PolicyID `{policy_id}` not valid! IDs must be a non-empty string, "
            "must not contain characters that are also disallowed file- or directory "
            "names on Unix/Windows and must not end with a dot `.` or a space ` `."
        )
        if error:
            raise ValueError(msg)
        # Warn only once per process to avoid log spam for repeated bad IDs.
        elif log_once("invalid_policy_id"):
            logger.warning(msg)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Public API of this package: the RLModule base classes, their spec
# counterparts, and the module-ID validation helper defined above.
__all__ = [
    "MultiRLModule",
    "MultiRLModuleSpec",
    "RLModule",
    "RLModuleSpec",
    "validate_module_id",
]
|
.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (2.31 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/default_model_config.cpython-311.pyc
ADDED
|
Binary file (5.21 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/multi_rl_module.cpython-311.pyc
ADDED
|
Binary file (43.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/rl_module.cpython-311.pyc
ADDED
|
Binary file (36.3 kB). View file
|
|
|