diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2ff5f54fe854f475d62a765cdafff7835cbde7e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/connector_pipeline_v2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/connector_pipeline_v2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6451f2907598efc078e255530e32a5b311a384fa Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/connector_pipeline_v2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/connector.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/connector.py new file mode 100644 index 0000000000000000000000000000000000000000..80c5003085b17bd586e22e5d39e4dbec34a50c79 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/connector.py @@ -0,0 +1,478 @@ +"""This file defines base types and common structures for RLlib connectors. 
+""" + +import abc +import logging +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union + +import gymnasium as gym + +from ray.rllib.policy.view_requirement import ViewRequirement +from ray.rllib.utils.typing import ( + ActionConnectorDataType, + AgentConnectorDataType, + AlgorithmConfigDict, + TensorType, +) +from ray.rllib.utils.annotations import OldAPIStack + +if TYPE_CHECKING: + from ray.rllib.policy.policy import Policy + +logger = logging.getLogger(__name__) + + +@OldAPIStack +class ConnectorContext: + """Data bits that may be needed for running connectors. + + Note(jungong) : we need to be really careful with the data fields here. + E.g., everything needs to be serializable, in case we need to fetch them + in a remote setting. + """ + + # TODO(jungong) : figure out how to fetch these in a remote setting. + # Probably from a policy server when initializing a policy client. + + def __init__( + self, + config: AlgorithmConfigDict = None, + initial_states: List[TensorType] = None, + observation_space: gym.Space = None, + action_space: gym.Space = None, + view_requirements: Dict[str, ViewRequirement] = None, + is_policy_recurrent: bool = False, + ): + """Construct a ConnectorContext instance. + + Args: + initial_states: States that are used for constructing + the initial input dict for RNN models. [] if a model is not recurrent. + action_space_struct: a policy's action space, in python + data format. E.g., python dict instead of DictSpace, python tuple + instead of TupleSpace. + """ + self.config = config or {} + self.initial_states = initial_states or [] + self.observation_space = observation_space + self.action_space = action_space + self.view_requirements = view_requirements + self.is_policy_recurrent = is_policy_recurrent + + @staticmethod + def from_policy(policy: "Policy") -> "ConnectorContext": + """Build ConnectorContext from a given policy. + + Args: + policy: Policy + + Returns: + A ConnectorContext instance. 
+ """ + return ConnectorContext( + config=policy.config, + initial_states=policy.get_initial_state(), + observation_space=policy.observation_space, + action_space=policy.action_space, + view_requirements=policy.view_requirements, + is_policy_recurrent=policy.is_recurrent(), + ) + + +@OldAPIStack +class Connector(abc.ABC): + """Connector base class. + + A connector is a step of transformation, of either envrionment data before they + get to a policy, or policy output before it is sent back to the environment. + + Connectors may be training-aware, for example, behave slightly differently + during training and inference. + + All connectors are required to be serializable and implement to_state(). + """ + + def __init__(self, ctx: ConnectorContext): + # Default is training mode. + self._is_training = True + + def in_training(self): + self._is_training = True + + def in_eval(self): + self._is_training = False + + def __str__(self, indentation: int = 0): + return " " * indentation + self.__class__.__name__ + + def to_state(self) -> Tuple[str, Any]: + """Serialize a connector into a JSON serializable Tuple. + + to_state is required, so that all Connectors are serializable. + + Returns: + A tuple of connector's name and its serialized states. + String should match the name used to register the connector, + while state can be any single data structure that contains the + serialized state of the connector. If a connector is stateless, + state can simply be None. + """ + # Must implement by each connector. + return NotImplementedError + + @staticmethod + def from_state(self, ctx: ConnectorContext, params: Any) -> "Connector": + """De-serialize a JSON params back into a Connector. + + from_state is required, so that all Connectors are serializable. + + Args: + ctx: Context for constructing this connector. + params: Serialized states of the connector to be recovered. + + Returns: + De-serialized connector. + """ + # Must implement by each connector. 
+ return NotImplementedError + + +@OldAPIStack +class AgentConnector(Connector): + """Connector connecting user environments to RLlib policies. + + An agent connector transforms a list of agent data in AgentConnectorDataType + format into a new list in the same AgentConnectorDataTypes format. + The input API is designed so agent connectors can have access to all the + agents assigned to a particular policy. + + AgentConnectorDataTypes can be used to specify arbitrary type of env data, + + Example: + + Represent a list of agent data from one env step() call. + + .. testcode:: + + import numpy as np + ac = AgentConnectorDataType( + env_id="env_1", + agent_id=None, + data={ + "agent_1": np.array([1, 2, 3]), + "agent_2": np.array([4, 5, 6]), + } + ) + + Or a single agent data ready to be preprocessed. + + .. testcode:: + + ac = AgentConnectorDataType( + env_id="env_1", + agent_id="agent_1", + data=np.array([1, 2, 3]), + ) + + We can also adapt a simple stateless function into an agent connector by + using register_lambda_agent_connector: + + .. testcode:: + + import numpy as np + from ray.rllib.connectors.agent.lambdas import ( + register_lambda_agent_connector + ) + TimesTwoAgentConnector = register_lambda_agent_connector( + "TimesTwoAgentConnector", lambda data: data * 2 + ) + + # More complicated agent connectors can be implemented by extending this + # AgentConnector class: + + class FrameSkippingAgentConnector(AgentConnector): + def __init__(self, n): + self._n = n + self._frame_count = default_dict(str, default_dict(str, int)) + + def reset(self, env_id: str): + del self._frame_count[env_id] + + def __call__( + self, ac_data: List[AgentConnectorDataType] + ) -> List[AgentConnectorDataType]: + ret = [] + for d in ac_data: + assert d.env_id and d.agent_id, "Skipping works per agent!" 
+ + count = self._frame_count[ac_data.env_id][ac_data.agent_id] + self._frame_count[ac_data.env_id][ac_data.agent_id] = ( + count + 1 + ) + + if count % self._n == 0: + ret.append(d) + return ret + + As shown, an agent connector may choose to emit an empty list to stop input + observations from being further prosessed. + """ + + def reset(self, env_id: str): + """Reset connector state for a specific environment. + + For example, at the end of an episode. + + Args: + env_id: required. ID of a user environment. Required. + """ + pass + + def on_policy_output(self, output: ActionConnectorDataType): + """Callback on agent connector of policy output. + + This is useful for certain connectors, for example RNN state buffering, + where the agent connect needs to be aware of the output of a policy + forward pass. + + Args: + ctx: Context for running this connector call. + output: Env and agent IDs, plus data output from policy forward pass. + """ + pass + + def __call__( + self, acd_list: List[AgentConnectorDataType] + ) -> List[AgentConnectorDataType]: + """Transform a list of data items from env before they reach policy. + + Args: + ac_data: List of env and agent IDs, plus arbitrary data items from + an environment or upstream agent connectors. + + Returns: + A list of transformed data items in AgentConnectorDataType format. + The shape of a returned list does not have to match that of the input list. + An AgentConnector may choose to derive multiple outputs for a single piece + of input data, for example multi-agent obs -> multiple single agent obs. + Agent connectors may also choose to skip emitting certain inputs, + useful for connectors such as frame skipping. + """ + assert isinstance( + acd_list, (list, tuple) + ), "Input to agent connectors are list of AgentConnectorDataType." + # Default implementation. Simply call transform on each agent connector data. 
+ return [self.transform(d) for d in acd_list] + + def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType: + """Transform a single agent connector data item. + + Args: + data: Env and agent IDs, plus arbitrary data item from a single agent + of an environment. + + Returns: + A transformed piece of agent connector data. + """ + raise NotImplementedError + + +@OldAPIStack +class ActionConnector(Connector): + """Action connector connects policy outputs including actions, + to user environments. + + An action connector transforms a single piece of policy output in + ActionConnectorDataType format, which is basically PolicyOutputType plus env and + agent IDs. + + Any functions that operate directly on PolicyOutputType can be easily adapted + into an ActionConnector by using register_lambda_action_connector. + + Example: + + .. testcode:: + + from ray.rllib.connectors.action.lambdas import ( + register_lambda_action_connector + ) + ZeroActionConnector = register_lambda_action_connector( + "ZeroActionsConnector", + lambda actions, states, fetches: ( + np.zeros_like(actions), states, fetches + ) + ) + + More complicated action connectors can also be implemented by sub-classing + this ActionConnector class. + """ + + def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType: + """Transform policy output before they are sent to a user environment. + + Args: + ac_data: Env and agent IDs, plus policy output. + + Returns: + The processed action connector data. + """ + return self.transform(ac_data) + + def transform(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType: + """Implementation of the actual transform. + + Users should override transform instead of __call__ directly. + + Args: + ac_data: Env and agent IDs, plus policy output. + + Returns: + The processed action connector data. 
+ """ + raise NotImplementedError + + +@OldAPIStack +class ConnectorPipeline(abc.ABC): + """Utility class for quick manipulation of a connector pipeline.""" + + def __init__(self, ctx: ConnectorContext, connectors: List[Connector]): + self.connectors = connectors + + def in_training(self): + for c in self.connectors: + c.in_training() + + def in_eval(self): + for c in self.connectors: + c.in_eval() + + def remove(self, name: str): + """Remove a connector by + + Args: + name: name of the connector to be removed. + """ + idx = -1 + for i, c in enumerate(self.connectors): + if c.__class__.__name__ == name: + idx = i + break + if idx >= 0: + del self.connectors[idx] + logger.info(f"Removed connector {name} from {self.__class__.__name__}.") + else: + logger.warning(f"Trying to remove a non-existent connector {name}.") + + def insert_before(self, name: str, connector: Connector): + """Insert a new connector before connector + + Args: + name: name of the connector before which a new connector + will get inserted. + connector: a new connector to be inserted. + """ + idx = -1 + for idx, c in enumerate(self.connectors): + if c.__class__.__name__ == name: + break + if idx < 0: + raise ValueError(f"Can not find connector {name}") + self.connectors.insert(idx, connector) + + logger.info( + f"Inserted {connector.__class__.__name__} before {name} " + f"to {self.__class__.__name__}." + ) + + def insert_after(self, name: str, connector: Connector): + """Insert a new connector after connector + + Args: + name: name of the connector after which a new connector + will get inserted. + connector: a new connector to be inserted. + """ + idx = -1 + for idx, c in enumerate(self.connectors): + if c.__class__.__name__ == name: + break + if idx < 0: + raise ValueError(f"Can not find connector {name}") + self.connectors.insert(idx + 1, connector) + + logger.info( + f"Inserted {connector.__class__.__name__} after {name} " + f"to {self.__class__.__name__}." 
+ ) + + def prepend(self, connector: Connector): + """Append a new connector at the beginning of a connector pipeline. + + Args: + connector: a new connector to be appended. + """ + self.connectors.insert(0, connector) + + logger.info( + f"Added {connector.__class__.__name__} to the beginning of " + f"{self.__class__.__name__}." + ) + + def append(self, connector: Connector): + """Append a new connector at the end of a connector pipeline. + + Args: + connector: a new connector to be appended. + """ + self.connectors.append(connector) + + logger.info( + f"Added {connector.__class__.__name__} to the end of " + f"{self.__class__.__name__}." + ) + + def __str__(self, indentation: int = 0): + return "\n".join( + [" " * indentation + self.__class__.__name__] + + [c.__str__(indentation + 4) for c in self.connectors] + ) + + def __getitem__(self, key: Union[str, int, type]): + """Returns a list of connectors that fit 'key'. + + If key is a number n, we return a list with the nth element of this pipeline. + If key is a Connector class or a string matching the class name of a + Connector class, we return a list of all connectors in this pipeline matching + the specified class. + + Args: + key: The key to index by + + Returns: The Connector at index `key`. + """ + # In case key is a class + if not isinstance(key, str): + if isinstance(key, slice): + raise NotImplementedError( + "Slicing of ConnectorPipeline is currently not supported." 
+ ) + elif isinstance(key, int): + return [self.connectors[key]] + elif isinstance(key, type): + results = [] + for c in self.connectors: + if issubclass(c.__class__, key): + results.append(c) + return results + else: + raise NotImplementedError( + "Indexing by {} is currently not supported.".format(type(key)) + ) + + results = [] + for c in self.connectors: + if c.__class__.__name__ == key: + results.append(c) + + return results diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/connector_pipeline_v2.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/connector_pipeline_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..b26c4d0542c52fcc41b0a11e461c8e5bf4860124 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/connector_pipeline_v2.py @@ -0,0 +1,394 @@ +import logging +from typing import Any, Collection, Dict, List, Optional, Tuple, Type, Union + +import gymnasium as gym + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.checkpoints import Checkpointable +from ray.rllib.utils.metrics import TIMERS, CONNECTOR_TIMERS +from ray.rllib.utils.metrics.metrics_logger import MetricsLogger +from ray.rllib.utils.typing import EpisodeType, StateDict +from ray.util.annotations import PublicAPI + +logger = logging.getLogger(__name__) + + +@PublicAPI(stability="alpha") +class ConnectorPipelineV2(ConnectorV2): + """Utility class for quick manipulation of a connector pipeline.""" + + @override(ConnectorV2) + def recompute_output_observation_space( + self, + input_observation_space: gym.Space, + input_action_space: gym.Space, + ) -> gym.Space: + self._fix_spaces(input_observation_space, input_action_space) + return self.observation_space + + @override(ConnectorV2) + def recompute_output_action_space( + self, + input_observation_space: gym.Space, + input_action_space: 
gym.Space, + ) -> gym.Space: + self._fix_spaces(input_observation_space, input_action_space) + return self.action_space + + def __init__( + self, + input_observation_space: Optional[gym.Space] = None, + input_action_space: Optional[gym.Space] = None, + *, + connectors: Optional[List[ConnectorV2]] = None, + **kwargs, + ): + """Initializes a ConnectorPipelineV2 instance. + + Args: + input_observation_space: The (optional) input observation space for this + connector piece. This is the space coming from a previous connector + piece in the (env-to-module or learner) pipeline or is directly + defined within the gym.Env. + input_action_space: The (optional) input action space for this connector + piece. This is the space coming from a previous connector piece in the + (module-to-env) pipeline or is directly defined within the gym.Env. + connectors: A list of individual ConnectorV2 pieces to be added to this + pipeline during construction. Note that you can always add (or remove) + more ConnectorV2 pieces later on the fly. + """ + self.connectors = [] + + for conn in connectors: + # If we have a `ConnectorV2` instance just append. + if isinstance(conn, ConnectorV2): + self.connectors.append(conn) + # If, we have a class with `args` and `kwargs`, build the instance. + # Note that this way of constructing a pipeline should only be + # used internally when restoring the pipeline state from a + # checkpoint. 
+ elif isinstance(conn, tuple) and len(conn) == 3: + self.connectors.append(conn[0](*conn[1], **conn[2])) + + super().__init__(input_observation_space, input_action_space, **kwargs) + + def __len__(self): + return len(self.connectors) + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: RLModule, + batch: Dict[str, Any], + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + metrics: Optional[MetricsLogger] = None, + **kwargs, + ) -> Any: + """In a pipeline, we simply call each of our connector pieces after each other. + + Each connector piece receives as input the output of the previous connector + piece in the pipeline. + """ + shared_data = shared_data if shared_data is not None else {} + # Loop through connector pieces and call each one with the output of the + # previous one. Thereby, time each connector piece's call. + for connector in self.connectors: + # TODO (sven): Add MetricsLogger to non-Learner components that have a + # LearnerConnector pipeline. + stats = None + if metrics: + stats = metrics.log_time( + kwargs.get("metrics_prefix_key", ()) + + (TIMERS, CONNECTOR_TIMERS, connector.__class__.__name__) + ) + stats.__enter__() + + batch = connector( + rl_module=rl_module, + batch=batch, + episodes=episodes, + explore=explore, + shared_data=shared_data, + metrics=metrics, + # Deprecated arg. + data=batch, + **kwargs, + ) + + if metrics: + stats.__exit__(None, None, None) + + if not isinstance(batch, dict): + raise ValueError( + f"`data` returned by ConnectorV2 {connector} must be a dict! " + f"You returned {batch}. Check your (custom) connectors' " + f"`__call__()` method's return value and make sure you return " + f"the `data` arg passed in (either altered or unchanged)." + ) + + return batch + + def remove(self, name_or_class: Union[str, Type]): + """Remove a single connector piece in this pipeline by its name or class. 
+ + Args: + name: The name of the connector piece to be removed from the pipeline. + """ + idx = -1 + for i, c in enumerate(self.connectors): + if c.__class__.__name__ == name_or_class: + idx = i + break + if idx >= 0: + del self.connectors[idx] + self._fix_spaces(self.input_observation_space, self.input_action_space) + logger.info( + f"Removed connector {name_or_class} from {self.__class__.__name__}." + ) + else: + logger.warning( + f"Trying to remove a non-existent connector {name_or_class}." + ) + + def insert_before( + self, + name_or_class: Union[str, type], + connector: ConnectorV2, + ) -> ConnectorV2: + """Insert a new connector piece before an existing piece (by name or class). + + Args: + name_or_class: Name or class of the connector piece before which `connector` + will get inserted. + connector: The new connector piece to be inserted. + + Returns: + The ConnectorV2 before which `connector` has been inserted. + """ + idx = -1 + for idx, c in enumerate(self.connectors): + if ( + isinstance(name_or_class, str) and c.__class__.__name__ == name_or_class + ) or (isinstance(name_or_class, type) and c.__class__ is name_or_class): + break + if idx < 0: + raise ValueError( + f"Can not find connector with name or type '{name_or_class}'!" + ) + next_connector = self.connectors[idx] + + self.connectors.insert(idx, connector) + self._fix_spaces(self.input_observation_space, self.input_action_space) + + logger.info( + f"Inserted {connector.__class__.__name__} before {name_or_class} " + f"to {self.__class__.__name__}." + ) + return next_connector + + def insert_after( + self, + name_or_class: Union[str, Type], + connector: ConnectorV2, + ) -> ConnectorV2: + """Insert a new connector piece after an existing piece (by name or class). + + Args: + name_or_class: Name or class of the connector piece after which `connector` + will get inserted. + connector: The new connector piece to be inserted. + + Returns: + The ConnectorV2 after which `connector` has been inserted. 
+ """ + idx = -1 + for idx, c in enumerate(self.connectors): + if ( + isinstance(name_or_class, str) and c.__class__.__name__ == name_or_class + ) or (isinstance(name_or_class, type) and c.__class__ is name_or_class): + break + if idx < 0: + raise ValueError( + f"Can not find connector with name or type '{name_or_class}'!" + ) + prev_connector = self.connectors[idx] + + self.connectors.insert(idx + 1, connector) + self._fix_spaces(self.input_observation_space, self.input_action_space) + + logger.info( + f"Inserted {connector.__class__.__name__} after {name_or_class} " + f"to {self.__class__.__name__}." + ) + + return prev_connector + + def prepend(self, connector: ConnectorV2) -> None: + """Prepend a new connector at the beginning of a connector pipeline. + + Args: + connector: The new connector piece to be prepended to this pipeline. + """ + self.connectors.insert(0, connector) + self._fix_spaces(self.input_observation_space, self.input_action_space) + + logger.info( + f"Added {connector.__class__.__name__} to the beginning of " + f"{self.__class__.__name__}." + ) + + def append(self, connector: ConnectorV2) -> None: + """Append a new connector at the end of a connector pipeline. + + Args: + connector: The new connector piece to be appended to this pipeline. + """ + self.connectors.append(connector) + self._fix_spaces(self.input_observation_space, self.input_action_space) + + logger.info( + f"Added {connector.__class__.__name__} to the end of " + f"{self.__class__.__name__}." 
+ ) + + @override(ConnectorV2) + def get_state( + self, + components: Optional[Union[str, Collection[str]]] = None, + *, + not_components: Optional[Union[str, Collection[str]]] = None, + **kwargs, + ) -> StateDict: + state = {} + for conn in self.connectors: + conn_name = type(conn).__name__ + if self._check_component(conn_name, components, not_components): + state[conn_name] = conn.get_state( + components=self._get_subcomponents(conn_name, components), + not_components=self._get_subcomponents(conn_name, not_components), + **kwargs, + ) + return state + + @override(ConnectorV2) + def set_state(self, state: Dict[str, Any]) -> None: + for conn in self.connectors: + conn_name = type(conn).__name__ + if conn_name in state: + conn.set_state(state[conn_name]) + + @override(Checkpointable) + def get_checkpointable_components(self) -> List[Tuple[str, "Checkpointable"]]: + return [(type(conn).__name__, conn) for conn in self.connectors] + + # Note that we don't have to override Checkpointable.get_ctor_args_and_kwargs and + # don't have to return the `connectors` c'tor kwarg from there. This is b/c all + # connector pieces in this pipeline are themselves Checkpointable components, + # so they will be properly written into this pipeline's checkpoint. 
+ @override(Checkpointable) + def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]: + return ( + (self.input_observation_space, self.input_action_space), # *args + { + "connectors": [ + (type(conn), *conn.get_ctor_args_and_kwargs()) + for conn in self.connectors + ] + }, + ) + + @override(ConnectorV2) + def reset_state(self) -> None: + for conn in self.connectors: + conn.reset_state() + + @override(ConnectorV2) + def merge_states(self, states: List[Dict[str, Any]]) -> Dict[str, Any]: + merged_states = {} + if not states: + return merged_states + for i, (key, item) in enumerate(states[0].items()): + state_list = [state[key] for state in states] + conn = self.connectors[i] + merged_states[key] = conn.merge_states(state_list) + return merged_states + + def __repr__(self, indentation: int = 0): + return "\n".join( + [" " * indentation + self.__class__.__name__] + + [c.__str__(indentation + 4) for c in self.connectors] + ) + + def __getitem__( + self, + key: Union[str, int, Type], + ) -> Union[ConnectorV2, List[ConnectorV2]]: + """Returns a single ConnectorV2 or list of ConnectorV2s that fit `key`. + + If key is an int, we return a single ConnectorV2 at that index in this pipeline. + If key is a ConnectorV2 type or a string matching the class name of a + ConnectorV2 in this pipeline, we return a list of all ConnectorV2s in this + pipeline matching the specified class. + + Args: + key: The key to find or to index by. + + Returns: + A single ConnectorV2 or a list of ConnectorV2s matching `key`. + """ + # Key is an int -> Index into pipeline and return. + if isinstance(key, int): + return self.connectors[key] + # Key is a class. + elif isinstance(key, type): + results = [] + for c in self.connectors: + if issubclass(c.__class__, key): + results.append(c) + return results + # Key is a string -> Find connector(s) by name. 
+ elif isinstance(key, str): + results = [] + for c in self.connectors: + if c.name == key: + results.append(c) + return results + # Slicing not supported (yet). + elif isinstance(key, slice): + raise NotImplementedError( + "Slicing of ConnectorPipelineV2 is currently not supported!" + ) + else: + raise NotImplementedError( + f"Indexing ConnectorPipelineV2 by {type(key)} is currently not " + f"supported!" + ) + + @property + def observation_space(self): + if len(self) > 0: + return self.connectors[-1].observation_space + return self._observation_space + + @property + def action_space(self): + if len(self) > 0: + return self.connectors[-1].action_space + return self._action_space + + def _fix_spaces(self, input_observation_space, input_action_space): + if len(self) > 0: + # Fix each connector's input_observation- and input_action space in + # the pipeline. + obs_space = input_observation_space + act_space = input_action_space + for con in self.connectors: + con.input_action_space = act_space + con.input_observation_space = obs_space + obs_space = con.observation_space + act_space = con.action_space diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/connector_v2.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/connector_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..5b4b2b86bdc803ca78d35403128a63ba2e9dedd3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/connector_v2.py @@ -0,0 +1,1017 @@ +import abc +from collections import defaultdict +import inspect +from typing import ( + Any, + Callable, + Collection, + Dict, + Iterator, + List, + Optional, + Tuple, + Union, +) + +import gymnasium as gym +import tree + +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.env.single_agent_episode import SingleAgentEpisode +from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import override, OverrideToImplementCustomLogic +from ray.rllib.utils.checkpoints 
import Checkpointable +from ray.rllib.utils.metrics.metrics_logger import MetricsLogger +from ray.rllib.utils.spaces.space_utils import BatchedNdArray +from ray.rllib.utils.typing import AgentID, EpisodeType, ModuleID, StateDict +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class ConnectorV2(Checkpointable, abc.ABC): + """Base class defining the API for an individual "connector piece". + + A ConnectorV2 ("connector piece") is usually part of a whole series of connector + pieces within a so-called connector pipeline, which in itself also abides to this + very API. + For example, you might have a connector pipeline consisting of two connector pieces, + A and B, both instances of subclasses of ConnectorV2 and each one performing a + particular transformation on their input data. The resulting connector pipeline + (A->B) itself also abides to this very ConnectorV2 API and could thus be part of yet + another, higher-level connector pipeline, e.g. (A->B)->C->D. + + Any ConnectorV2 instance (individual pieces or several connector pieces in a + pipeline) is a callable and users should override the `__call__()` method. + When called, they take the outputs of a previous connector piece (or an empty dict + if there are no previous pieces) and all the data collected thus far in the + ongoing episode(s) (only applies to connectors used in EnvRunners) or retrieved + from a replay buffer or from an environment sampling step (only applies to + connectors used in Learner pipelines). From this input data, a ConnectorV2 then + performs a transformation step. + + There are 3 types of pipelines any ConnectorV2 piece can belong to: + 1) EnvToModulePipeline: The connector transforms environment data before it gets to + the RLModule. This type of pipeline is used by an EnvRunner for transforming + env output data into RLModule readable data (for the next RLModule forward pass). 
+ For example, such a pipeline would include observation postprocessors, -filters, + or any RNN preparation code related to time-sequences and zero-padding. + 2) ModuleToEnvPipeline: This type of pipeline is used by an + EnvRunner to transform RLModule output data to env readable actions (for the next + `env.step()` call). For example, in case the RLModule only outputs action + distribution parameters (but not actual actions), the ModuleToEnvPipeline would + take care of sampling the actions to be sent back to the end from the + resulting distribution (made deterministic if exploration is off). + 3) LearnerConnectorPipeline: This connector pipeline type transforms data coming + from an `EnvRunner.sample()` call or a replay buffer and will then be sent into the + RLModule's `forward_train()` method in order to compute loss function inputs. + This type of pipeline is used by a Learner worker to transform raw training data + (a batch or a list of episodes) to RLModule readable training data (for the next + RLModule `forward_train()` call). + + Some connectors might be stateful, for example for keeping track of observation + filtering stats (mean and stddev values). Any Algorithm, which uses connectors is + responsible for frequently synchronizing the states of all connectors and connector + pipelines between the EnvRunners (owning the env-to-module and module-to-env + pipelines) and the Learners (owning the Learner pipelines). + """ + + def __init__( + self, + input_observation_space: Optional[gym.Space] = None, + input_action_space: Optional[gym.Space] = None, + **kwargs, + ): + """Initializes a ConnectorV2 instance. + + Args: + input_observation_space: The (optional) input observation space for this + connector piece. This is the space coming from a previous connector + piece in the (env-to-module or learner) pipeline or is directly + defined within the gym.Env. + input_action_space: The (optional) input action space for this connector + piece. 
This is the space coming from a previous connector piece in the + (module-to-env) pipeline or is directly defined within the gym.Env. + **kwargs: Forward API-compatibility kwargs. + """ + self._observation_space = None + self._action_space = None + self._input_observation_space = None + self._input_action_space = None + + self.input_action_space = input_action_space + self.input_observation_space = input_observation_space + + # Store child's constructor args and kwargs for the default + # `get_ctor_args_and_kwargs` implementation (to be able to restore from a + # checkpoint). + if self.__class__.__dict__.get("__init__") is not None: + caller_frame = inspect.stack()[1].frame + arg_info = inspect.getargvalues(caller_frame) + # Separate positional arguments and keyword arguments. + caller_locals = ( + arg_info.locals + ) # Dictionary of all local variables in the caller + self._ctor_kwargs = { + arg: caller_locals[arg] for arg in arg_info.args if arg != "self" + } + else: + self._ctor_kwargs = { + "input_observation_space": self.input_observation_space, + "input_action_space": self.input_action_space, + } + + @OverrideToImplementCustomLogic + def recompute_output_observation_space( + self, + input_observation_space: gym.Space, + input_action_space: gym.Space, + ) -> gym.Space: + """Re-computes a new (output) observation space based on the input spaces. + + This method should be overridden by users to make sure a ConnectorPipelineV2 + knows how the input spaces through its individual ConnectorV2 pieces are being + transformed. + + .. 
testcode:: + + from gymnasium.spaces import Box, Discrete + import numpy as np + + from ray.rllib.connectors.connector_v2 import ConnectorV2 + from ray.rllib.utils.numpy import one_hot + from ray.rllib.utils.test_utils import check + + class OneHotConnector(ConnectorV2): + def recompute_output_observation_space( + self, + input_observation_space, + input_action_space, + ): + return Box(0.0, 1.0, (input_observation_space.n,), np.float32) + + def __call__( + self, + *, + rl_module, + batch, + episodes, + explore=None, + shared_data=None, + metrics=None, + **kwargs, + ): + assert "obs" in batch + batch["obs"] = one_hot(batch["obs"]) + return batch + + connector = OneHotConnector(input_observation_space=Discrete(2)) + batch = {"obs": np.array([1, 0, 0], np.int32)} + output = connector(rl_module=None, batch=batch, episodes=None) + + check(output, {"obs": np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 0.0]])}) + + If this ConnectorV2 does not change the observation space in any way, leave + this parent method implementation untouched. + + Args: + input_observation_space: The input observation space (either coming from the + environment if `self` is the first connector piece in the pipeline or + from the previous connector piece in the pipeline). + input_action_space: The input action space (either coming from the + environment if `self is the first connector piece in the pipeline or + from the previous connector piece in the pipeline). + + Returns: + The new observation space (after data has passed through this ConnectorV2 + piece). + """ + return self.input_observation_space + + @OverrideToImplementCustomLogic + def recompute_output_action_space( + self, + input_observation_space: gym.Space, + input_action_space: gym.Space, + ) -> gym.Space: + """Re-computes a new (output) action space based on the input space. 
+ + This method should be overridden by users to make sure a ConnectorPipelineV2 + knows how the input spaces through its individual ConnectorV2 pieces are being + transformed. + + If this ConnectorV2 does not change the action space in any way, leave + this parent method implementation untouched. + + Args: + input_observation_space: The input observation space (either coming from the + environment if `self` is the first connector piece in the pipeline or + from the previous connector piece in the pipeline). + input_action_space: The input action space (either coming from the + environment if `self is the first connector piece in the pipeline or + from the previous connector piece in the pipeline). + + Returns: + The new action space (after data has passed through this ConenctorV2 + piece). + """ + return self.input_action_space + + @abc.abstractmethod + def __call__( + self, + *, + rl_module: RLModule, + batch: Dict[str, Any], + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + metrics: Optional[MetricsLogger] = None, + **kwargs, + ) -> Any: + """Method for transforming an input `batch` into an output `batch`. + + Args: + rl_module: The RLModule object that the connector connects to or from. + batch: The input data to be transformed by this connector. Transformations + might either be done in-place or a new structure may be returned. + Note that the information in `batch` will eventually either become the + forward batch for the RLModule (env-to-module and learner connectors) + or the input to the `env.step()` call (module-to-env connectors). Note + that in the first case (`batch` is a forward batch for RLModule), the + information in `batch` will be discarded after that RLModule forward + pass. Any transformation of information (e.g. observation preprocessing) + that you have only done inside `batch` will be lost, unless you have + written it back into the corresponding `episodes` during the connector + pass. 
+ episodes: The list of SingleAgentEpisode or MultiAgentEpisode objects, + each corresponding to one slot in the vector env. Note that episodes + can be read from (e.g. to place information into `batch`), but also + written to. You should only write back (changed, transformed) + information into the episodes, if you want these changes to be + "permanent". For example if you sample from an environment, pick up + observations from the episodes and place them into `batch`, then + transform these observations, and would like to make these + transformations permanent (note that `batch` gets discarded after the + RLModule forward pass), then you have to write the transformed + observations back into the episode to make sure you do not have to + perform the same transformation again on the learner (or replay buffer) + side. The Learner will hence work on the already changed episodes (and + compile the train batch using the Learner connector). + explore: Whether `explore` is currently on. Per convention, if True, the + RLModule's `forward_exploration` method should be called, if False, the + EnvRunner should call `forward_inference` instead. + shared_data: Optional additional context data that needs to be exchanged + between different ConnectorV2 pieces (in the same pipeline) or across + ConnectorV2 pipelines (meaning between env-to-module and module-to-env). + metrics: Optional MetricsLogger instance to log custom metrics to. + kwargs: Forward API-compatibility kwargs. + + Returns: + The transformed connector output. + """ + + @staticmethod + def single_agent_episode_iterator( + episodes: List[EpisodeType], + agents_that_stepped_only: bool = True, + zip_with_batch_column: Optional[Union[List[Any], Dict[Tuple, Any]]] = None, + ) -> Iterator[SingleAgentEpisode]: + """An iterator over a list of episodes yielding always SingleAgentEpisodes. 
+ + In case items in the list are MultiAgentEpisodes, these are broken down + into their individual agents' SingleAgentEpisodes and those are then yielded + one after the other. + + Useful for connectors that operate on both single-agent and multi-agent + episodes. + + Args: + episodes: The list of SingleAgent- or MultiAgentEpisode objects. + agents_that_stepped_only: If True (and multi-agent setup), will only place + items of those agents into the batch that have just stepped in the + actual MultiAgentEpisode (this is checked via a + `MultiAgentEpside.episode.get_agents_to_act()`). Note that this setting + is ignored in a single-agent setups b/c the agent steps at each timestep + regardless. + zip_with_batch_column: If provided, must be a list of batch items + corresponding to the given `episodes` (single agent case) or a dict + mapping (AgentID, ModuleID) tuples to lists of individual batch items + corresponding to this agent/module combination. The iterator will then + yield tuples of SingleAgentEpisode objects (1st item) along with the + data item (2nd item) that this episode was responsible for generating + originally. + + Yields: + All SingleAgentEpisodes in the input list, whereby MultiAgentEpisodes will + be broken down into their individual SingleAgentEpisode components. + """ + list_indices = defaultdict(int) + + # Single-agent case. + if episodes and isinstance(episodes[0], SingleAgentEpisode): + if zip_with_batch_column is not None: + if len(zip_with_batch_column) != len(episodes): + raise ValueError( + "Invalid `zip_with_batch_column` data: Must have the same " + f"length as the list of episodes ({len(episodes)}), but has " + f"length {len(zip_with_batch_column)}!" + ) + # Simple case: Items are stored in lists directly under the column (str) + # key. 
+ if isinstance(zip_with_batch_column, list): + for episode, data in zip(episodes, zip_with_batch_column): + yield episode, data + # Normal single-agent case: Items are stored in dicts under the column + # (str) key. These dicts map (eps_id,)-tuples to lists of individual + # items. + else: + for episode, (eps_id_tuple, data) in zip( + episodes, + zip_with_batch_column.items(), + ): + assert episode.id_ == eps_id_tuple[0] + d = data[list_indices[eps_id_tuple]] + list_indices[eps_id_tuple] += 1 + yield episode, d + else: + for episode in episodes: + yield episode + return + + # Multi-agent case. + for episode in episodes: + for agent_id in ( + episode.get_agents_that_stepped() + if agents_that_stepped_only + else episode.agent_ids + ): + sa_episode = episode.agent_episodes[agent_id] + # for sa_episode in episode.agent_episodes.values(): + if zip_with_batch_column is not None: + key = ( + sa_episode.multi_agent_episode_id, + sa_episode.agent_id, + sa_episode.module_id, + ) + if len(zip_with_batch_column[key]) <= list_indices[key]: + raise ValueError( + "Invalid `zip_with_batch_column` data: Must structurally " + "match the single-agent contents in the given list of " + "(multi-agent) episodes!" + ) + d = zip_with_batch_column[key][list_indices[key]] + list_indices[key] += 1 + yield sa_episode, d + else: + yield sa_episode + + @staticmethod + def add_batch_item( + batch: Dict[str, Any], + column: str, + item_to_add: Any, + single_agent_episode: Optional[SingleAgentEpisode] = None, + ) -> None: + """Adds a data item under `column` to the given `batch`. + + The `item_to_add` is stored in the `batch` in the following manner: + 1) If `single_agent_episode` is not provided (None), will store the item in a + list directly under `column`: + `column` -> [item, item, ...] 
+ 2) If `single_agent_episode`'s `agent_id` and `module_id` properties are None + (`single_agent_episode` is not part of a multi-agent episode), will append + `item_to_add` to a list under a `(,)` key under `column`: + `column` -> `(,)` -> [item, item, ...] + 3) If `single_agent_episode`'s `agent_id` and `module_id` are NOT None + (`single_agent_episode` is part of a multi-agent episode), will append + `item_to_add` to a list under a `(,,)` key + under `column`: + `column` -> `(,,)` -> [item, item, ...] + + See the these examples here for clarification of these three cases: + + .. testcode:: + + from ray.rllib.connectors.connector_v2 import ConnectorV2 + from ray.rllib.env.multi_agent_episode import MultiAgentEpisode + from ray.rllib.env.single_agent_episode import SingleAgentEpisode + from ray.rllib.utils.test_utils import check + + # 1) Simple case (no episodes provided) -> Store data in a list directly + # under `column`: + batch = {} + ConnectorV2.add_batch_item(batch, "test_col", item_to_add=5) + ConnectorV2.add_batch_item(batch, "test_col", item_to_add=6) + check(batch, {"test_col": [5, 6]}) + ConnectorV2.add_batch_item(batch, "test_col_2", item_to_add=-10) + check(batch, { + "test_col": [5, 6], + "test_col_2": [-10], + }) + + # 2) Single-agent case (SingleAgentEpisode provided) -> Store data in a list + # under the keys: `column` -> `(,)` -> [...]: + batch = {} + episode = SingleAgentEpisode( + id_="SA-EPS0", + observations=[0, 1, 2, 3], + actions=[1, 2, 3], + rewards=[1.0, 2.0, 3.0], + ) + ConnectorV2.add_batch_item(batch, "test_col", 5, episode) + ConnectorV2.add_batch_item(batch, "test_col", 6, episode) + ConnectorV2.add_batch_item(batch, "test_col_2", -10, episode) + check(batch, { + "test_col": {("SA-EPS0",): [5, 6]}, + "test_col_2": {("SA-EPS0",): [-10]}, + }) + + # 3) Multi-agent case (SingleAgentEpisode provided that has `agent_id` and + # `module_id` information) -> Store data in a list under the keys: + # `column` -> `(,,)` -> [...]: + batch = {} + 
ma_episode = MultiAgentEpisode( + id_="MA-EPS1", + observations=[ + {"ag0": 0, "ag1": 1}, {"ag0": 2, "ag1": 4} + ], + actions=[{"ag0": 0, "ag1": 1}], + rewards=[{"ag0": -0.1, "ag1": -0.2}], + # ag0 maps to mod0, ag1 maps to mod1, etc.. + agent_to_module_mapping_fn=lambda aid, eps: f"mod{aid[2:]}", + ) + ConnectorV2.add_batch_item( + batch, + "test_col", + item_to_add=5, + single_agent_episode=ma_episode.agent_episodes["ag0"], + ) + ConnectorV2.add_batch_item( + batch, + "test_col", + item_to_add=6, + single_agent_episode=ma_episode.agent_episodes["ag0"], + ) + ConnectorV2.add_batch_item( + batch, + "test_col_2", + item_to_add=10, + single_agent_episode=ma_episode.agent_episodes["ag1"], + ) + check( + batch, + { + "test_col": {("MA-EPS1", "ag0", "mod0"): [5, 6]}, + "test_col_2": {("MA-EPS1", "ag1", "mod1"): [10]}, + }, + ) + + Args: + batch: The batch to store `item_to_add` in. + column: The column name (str) within the `batch` to store `item_to_add` + under. + item_to_add: The data item to store in the batch. + single_agent_episode: An optional SingleAgentEpisode. + If provided and its `agent_id` and `module_id` properties are None, + creates a further sub dictionary under `column`, mapping from + `(,)` to a list of data items (to which `item_to_add` will + be appended in this call). + If provided and its `agent_id` and `module_id` properties are NOT None, + creates a further sub dictionary under `column`, mapping from + `(,,,)` to a list of data items (to which + `item_to_add` will be appended in this call). + If not provided, will append `item_to_add` to a list directly under + `column`. + """ + sub_key = None + # SAEpisode is provided ... + if single_agent_episode is not None: + module_id = single_agent_episode.module_id + # ... and has `module_id` AND that `module_id` is already a top-level key in + # `batch` (`batch` is already in module-major form, mapping ModuleID to + # columns mapping to data). 
+ if module_id is not None and module_id in batch: + raise ValueError( + "Can't call `add_batch_item` on a `batch` that is already " + "module-major (meaning ModuleID is top-level with column names on " + "the level thereunder)! Make sure to only call `add_batch_items` " + "before the `AgentToModuleMapping` ConnectorV2 piece is applied." + ) + + # ... and has `agent_id` -> Use `single_agent_episode`'s agent ID and + # module ID. + elif single_agent_episode.agent_id is not None: + sub_key = ( + single_agent_episode.multi_agent_episode_id, + single_agent_episode.agent_id, + single_agent_episode.module_id, + ) + # Otherwise, just use episode's ID. + else: + sub_key = (single_agent_episode.id_,) + + if column not in batch: + batch[column] = [] if sub_key is None else {sub_key: []} + if sub_key is not None: + if sub_key not in batch[column]: + batch[column][sub_key] = [] + batch[column][sub_key].append(item_to_add) + else: + batch[column].append(item_to_add) + + @staticmethod + def add_n_batch_items( + batch: Dict[str, Any], + column: str, + items_to_add: Any, + num_items: int, + single_agent_episode: Optional[SingleAgentEpisode] = None, + ) -> None: + """Adds a list of items (or batched item) under `column` to the given `batch`. + + If `items_to_add` is not a list, but an already batched struct (of np.ndarray + leafs), the `items_to_add` will be appended to possibly existing data under the + same `column` as-is. A subsequent `BatchIndividualItems` ConnectorV2 piece will + recognize this and batch the data properly into a single (batched) item. + This is much faster than first splitting up `items_to_add` and then adding each + item individually. + + If `single_agent_episode` is provided and its `agent_id` and `module_id` + properties are None, creates a further sub dictionary under `column`, mapping + from `(,)` to a list of data items (to which `items_to_add` will + be appended in this call). 
+ If `single_agent_episode` is provided and its `agent_id` and `module_id` + properties are NOT None, creates a further sub dictionary under `column`, + mapping from `(,,,)` to a list of data items (to + which `items_to_add` will be appended in this call). + If `single_agent_episode` is not provided, will append `items_to_add` to a list + directly under `column`. + + .. testcode:: + + import numpy as np + + from ray.rllib.connectors.connector_v2 import ConnectorV2 + from ray.rllib.env.multi_agent_episode import MultiAgentEpisode + from ray.rllib.env.single_agent_episode import SingleAgentEpisode + from ray.rllib.utils.test_utils import check + + # Simple case (no episodes provided) -> Store data in a list directly under + # `column`: + batch = {} + ConnectorV2.add_n_batch_items( + batch, + "test_col", + # List of (complex) structs. + [{"a": np.array(3), "b": 4}, {"a": np.array(5), "b": 6}], + num_items=2, + ) + check( + batch["test_col"], + [{"a": np.array(3), "b": 4}, {"a": np.array(5), "b": 6}], + ) + # In a new column (test_col_2), store some already batched items. + # This way, you may avoid having to disassemble an already batched item + # (e.g. a numpy array of shape (10, 2)) into its individual items (e.g. + # split the array into a list of len=10) and then adding these individually. + # The performance gains may be quite large when providing already batched + # items (such as numpy arrays with a batch dim): + ConnectorV2.add_n_batch_items( + batch, + "test_col_2", + # One (complex) already batched struct. 
+ {"a": np.array([3, 5]), "b": np.array([4, 6])}, + num_items=2, + ) + # Add more already batched items (this time with a different batch size) + ConnectorV2.add_n_batch_items( + batch, + "test_col_2", + {"a": np.array([7, 7, 7]), "b": np.array([8, 8, 8])}, + num_items=3, # <- in this case, this must be the batch size + ) + check( + batch["test_col_2"], + [ + {"a": np.array([3, 5]), "b": np.array([4, 6])}, + {"a": np.array([7, 7, 7]), "b": np.array([8, 8, 8])}, + ], + ) + + # Single-agent case (SingleAgentEpisode provided) -> Store data in a list + # under the keys: `column` -> `(,)`: + batch = {} + episode = SingleAgentEpisode( + id_="SA-EPS0", + observations=[0, 1, 2, 3], + actions=[1, 2, 3], + rewards=[1.0, 2.0, 3.0], + ) + ConnectorV2.add_n_batch_items( + batch=batch, + column="test_col", + items_to_add=[5, 6, 7], + num_items=3, + single_agent_episode=episode, + ) + check(batch, { + "test_col": {("SA-EPS0",): [5, 6, 7]}, + }) + + # Multi-agent case (SingleAgentEpisode provided that has `agent_id` and + # `module_id` information) -> Store data in a list under the keys: + # `column` -> `(,,)`: + batch = {} + ma_episode = MultiAgentEpisode( + id_="MA-EPS1", + observations=[ + {"ag0": 0, "ag1": 1}, {"ag0": 2, "ag1": 4} + ], + actions=[{"ag0": 0, "ag1": 1}], + rewards=[{"ag0": -0.1, "ag1": -0.2}], + # ag0 maps to mod0, ag1 maps to mod1, etc.. 
+ agent_to_module_mapping_fn=lambda aid, eps: f"mod{aid[2:]}", + ) + ConnectorV2.add_batch_item( + batch, + "test_col", + item_to_add=5, + single_agent_episode=ma_episode.agent_episodes["ag0"], + ) + ConnectorV2.add_batch_item( + batch, + "test_col", + item_to_add=6, + single_agent_episode=ma_episode.agent_episodes["ag0"], + ) + ConnectorV2.add_batch_item( + batch, + "test_col_2", + item_to_add=10, + single_agent_episode=ma_episode.agent_episodes["ag1"], + ) + check( + batch, + { + "test_col": {("MA-EPS1", "ag0", "mod0"): [5, 6]}, + "test_col_2": {("MA-EPS1", "ag1", "mod1"): [10]}, + }, + ) + + Args: + batch: The batch to store n `items_to_add` in. + column: The column name (str) within the `batch` to store `item_to_add` + under. + items_to_add: The list of data items to store in the batch OR an already + batched (possibly nested) struct. In the latter case, the `items_to_add` + will be appended to possibly existing data under the same `column` + as-is. A subsequent `BatchIndividualItems` ConnectorV2 piece will + recognize this and batch the data properly into a single (batched) item. + This is much faster than first splitting up `items_to_add` and then + adding each item individually. + num_items: The number of items in `items_to_add`. This arg is mostly for + asserting the correct usage of this method by checking, whether the + given data in `items_to_add` really has the right amount of individual + items. + single_agent_episode: An optional SingleAgentEpisode. + If provided and its `agent_id` and `module_id` properties are None, + creates a further sub dictionary under `column`, mapping from + `(,)` to a list of data items (to which `items_to_add` will + be appended in this call). + If provided and its `agent_id` and `module_id` properties are NOT None, + creates a further sub dictionary under `column`, mapping from + `(,,,)` to a list of data items (to which + `items_to_add` will be appended in this call). 
+ If not provided, will append `items_to_add` to a list directly under + `column`. + """ + # Process n list items by calling `add_batch_item` on each of them individually. + if isinstance(items_to_add, list): + if len(items_to_add) != num_items: + raise ValueError( + f"Mismatch between `num_items` ({num_items}) and the length " + f"of the provided list ({len(items_to_add)}) in " + f"{ConnectorV2.__name__}.add_n_batch_items()!" + ) + for item in items_to_add: + ConnectorV2.add_batch_item( + batch=batch, + column=column, + item_to_add=item, + single_agent_episode=single_agent_episode, + ) + return + + # Process a batched (possibly complex) struct. + # We could just unbatch the item (split it into a list) and then add each + # individual item to our `batch`. However, this comes with a heavy performance + # penalty. Instead, we tag the thus added array(s) here as "_has_batch_dim=True" + # and then know that when batching the entire list under the respective + # (eps_id, agent_id, module_id)-tuple key, we need to concatenate, not stack + # the items in there. + def _tag(s): + return BatchedNdArray(s) + + ConnectorV2.add_batch_item( + batch=batch, + column=column, + # Convert given input into BatchedNdArray(s) such that the `batch` utility + # knows that it'll have to concat, not stack. + item_to_add=tree.map_structure(_tag, items_to_add), + single_agent_episode=single_agent_episode, + ) + + @staticmethod + def foreach_batch_item_change_in_place( + batch: Dict[str, Any], + column: Union[str, List[str], Tuple[str]], + func: Callable[ + [Any, Optional[int], Optional[AgentID], Optional[ModuleID]], Any + ], + ) -> None: + """Runs the provided `func` on all items under one or more columns in the batch. + + Use this method to conveniently loop through all items in a batch + and transform them in place. + + `func` takes the following as arguments: + - The item itself. If column is a list of column names, this argument is a tuple + of items. + - The EpisodeID. 
This value might be None. + - The AgentID. This value might be None in the single-agent case. + - The ModuleID. This value might be None in the single-agent case. + + The return value(s) of `func` are used to directly override the values in the + given `batch`. + + Args: + batch: The batch to process in-place. + column: A single column name (str) or a list thereof. If a list is provided, + the first argument to `func` is a tuple of items. If a single + str is provided, the first argument to `func` is an individual + item. + func: The function to call on each item or tuple of item(s). + + .. testcode:: + + from ray.rllib.connectors.connector_v2 import ConnectorV2 + from ray.rllib.utils.test_utils import check + + # Simple case: Batch items are in lists directly under their column names. + batch = { + "col1": [0, 1, 2, 3], + "col2": [0, -1, -2, -3], + } + # Increase all ints by 1. + ConnectorV2.foreach_batch_item_change_in_place( + batch=batch, + column="col1", + func=lambda item, *args: item + 1, + ) + check(batch["col1"], [1, 2, 3, 4]) + + # Further increase all ints by 1 in col1 and flip sign in col2. + ConnectorV2.foreach_batch_item_change_in_place( + batch=batch, + column=["col1", "col2"], + func=(lambda items, *args: (items[0] + 1, -items[1])), + ) + check(batch["col1"], [2, 3, 4, 5]) + check(batch["col2"], [0, 1, 2, 3]) + + # Single-agent case: Batch items are in lists under (eps_id,)-keys in a dict + # under their column names. + batch = { + "col1": { + ("eps1",): [0, 1, 2, 3], + ("eps2",): [400, 500, 600], + }, + } + # Increase all ints of eps1 by 1 and divide all ints of eps2 by 100. 
+ ConnectorV2.foreach_batch_item_change_in_place( + batch=batch, + column="col1", + func=lambda item, eps_id, *args: ( + item + 1 if eps_id == "eps1" else item / 100 + ), + ) + check(batch["col1"], { + ("eps1",): [1, 2, 3, 4], + ("eps2",): [4, 5, 6], + }) + + # Multi-agent case: Batch items are in lists under + # (eps_id, agent_id, module_id)-keys in a dict + # under their column names. + batch = { + "col1": { + ("eps1", "ag1", "mod1"): [1, 2, 3, 4], + ("eps2", "ag1", "mod2"): [400, 500, 600], + ("eps2", "ag2", "mod3"): [-1, -2, -3, -4, -5], + }, + } + # Decrease all ints of "eps1" by 1, divide all ints of "mod2" by 100, and + # flip sign of all ints of "ag2". + ConnectorV2.foreach_batch_item_change_in_place( + batch=batch, + column="col1", + func=lambda item, eps_id, ag_id, mod_id: ( + item - 1 + if eps_id == "eps1" + else item / 100 + if mod_id == "mod2" + else -item + ), + ) + check(batch["col1"], { + ("eps1", "ag1", "mod1"): [0, 1, 2, 3], + ("eps2", "ag1", "mod2"): [4, 5, 6], + ("eps2", "ag2", "mod3"): [1, 2, 3, 4, 5], + }) + """ + data_to_process = [batch.get(c) for c in force_list(column)] + single_col = isinstance(column, str) + if any(d is None for d in data_to_process): + raise ValueError( + f"Invalid column name(s) ({column})! One or more not found in " + f"given batch. Found columns {list(batch.keys())}." + ) + + # Simple case: Data items are stored in a list directly under the column + # name(s). + if isinstance(data_to_process[0], list): + for list_pos, data_tuple in enumerate(zip(*data_to_process)): + results = func( + data_tuple[0] if single_col else data_tuple, + None, # episode_id + None, # agent_id + None, # module_id + ) + # Tuple'ize results if single_col. + results = (results,) if single_col else results + for col_slot, result in enumerate(force_list(results)): + data_to_process[col_slot][list_pos] = result + # Single-agent/multi-agent cases. 
+ else: + for key, d0_list in data_to_process[0].items(): + # Multi-agent case: There is a dict mapping from a + # (eps id, AgentID, ModuleID)-tuples to lists of individual data items. + if len(key) == 3: + eps_id, agent_id, module_id = key + # Single-agent case: There is a dict mapping from a (eps_id,)-tuple + # to lists of individual data items. + # AgentID and ModuleID are both None. + else: + eps_id = key[0] + agent_id = module_id = None + other_lists = [d[key] for d in data_to_process[1:]] + for list_pos, data_tuple in enumerate(zip(d0_list, *other_lists)): + results = func( + data_tuple[0] if single_col else data_tuple, + eps_id, + agent_id, + module_id, + ) + # Tuple'ize results if single_col. + results = (results,) if single_col else results + for col_slot, result in enumerate(results): + data_to_process[col_slot][key][list_pos] = result + + @staticmethod + def switch_batch_from_column_to_module_ids( + batch: Dict[str, Dict[ModuleID, Any]] + ) -> Dict[ModuleID, Dict[str, Any]]: + """Switches the first two levels of a `col_name -> ModuleID -> data` type batch. + + Assuming that the top level consists of column names as keys and the second + level (under these columns) consists of ModuleID keys, the resulting batch + will have these two reversed and thus map ModuleIDs to dicts mapping column + names to data items. + + .. testcode:: + + from ray.rllib.utils.test_utils import check + + batch = { + "obs": {"module_0": [1, 2, 3]}, + "actions": {"module_0": [4, 5, 6], "module_1": [7]}, + } + switched_batch = ConnectorV2.switch_batch_from_column_to_module_ids(batch) + check( + switched_batch, + { + "module_0": {"obs": [1, 2, 3], "actions": [4, 5, 6]}, + "module_1": {"actions": [7]}, + }, + ) + + Args: + batch: The batch to switch from being column name based (then ModuleIDs) + to being ModuleID based (then column names). + + Returns: + A new batch dict mapping ModuleIDs to dicts mapping column names (e.g. + "obs") to data. 
+ """ + module_data = defaultdict(dict) + for column, column_data in batch.items(): + for module_id, data in column_data.items(): + module_data[module_id][column] = data + return dict(module_data) + + @override(Checkpointable) + def get_state( + self, + components: Optional[Union[str, Collection[str]]] = None, + *, + not_components: Optional[Union[str, Collection[str]]] = None, + **kwargs, + ) -> StateDict: + return {} + + @override(Checkpointable) + def set_state(self, state: StateDict) -> None: + pass + + @override(Checkpointable) + def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]: + return ( + (), # *args + self._ctor_kwargs, # **kwargs + ) + + def reset_state(self) -> None: + """Resets the state of this ConnectorV2 to some initial value. + + Note that this may NOT be the exact state that this ConnectorV2 was originally + constructed with. + """ + return + + def merge_states(self, states: List[Dict[str, Any]]) -> Dict[str, Any]: + """Computes a resulting state given self's state and a list of other states. + + Algorithms should use this method for merging states between connectors + running on parallel EnvRunner workers. For example, to synchronize the connector + states of n remote workers and a local worker, one could: + - Gather all remote worker connector states in a list. + - Call `self.merge_states()` on the local worker passing it the states list. + - Broadcast the resulting local worker's connector state back to all remote + workers. After this, all workers (including the local one) hold a + merged/synchronized new connecto state. + + Args: + states: The list of n other ConnectorV2 states to merge with self's state + into a single resulting state. + + Returns: + The resulting state dict. + """ + return {} + + @property + def observation_space(self): + """Getter for our (output) observation space. 
+ + Logic: Use user provided space (if set via `observation_space` setter) + otherwise, use the same as the input space, assuming this connector piece + does not alter the space. + """ + return self._observation_space + + @property + def action_space(self): + """Getter for our (output) action space. + + Logic: Use user provided space (if set via `action_space` setter) + otherwise, use the same as the input space, assuming this connector piece + does not alter the space. + """ + return self._action_space + + @property + def input_observation_space(self): + return self._input_observation_space + + @input_observation_space.setter + def input_observation_space(self, value): + self._input_observation_space = value + if value is not None: + self._observation_space = self.recompute_output_observation_space( + value, self.input_action_space + ) + + @property + def input_action_space(self): + return self._input_action_space + + @input_action_space.setter + def input_action_space(self, value): + self._input_action_space = value + if value is not None: + self._action_space = self.recompute_output_action_space( + self.input_observation_space, value + ) + + def __str__(self, indentation: int = 0): + return " " * indentation + self.__class__.__name__ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d513e596446c3d51ee0b0a065ca4cc434caf65ed --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__init__.py @@ -0,0 +1,40 @@ +from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import ( + AddObservationsFromEpisodesToBatch, +) +from ray.rllib.connectors.common.add_states_from_episodes_to_batch import ( + AddStatesFromEpisodesToBatch, +) +from ray.rllib.connectors.common.add_time_dim_to_batch_and_zero_pad import ( + 
AddTimeDimToBatchAndZeroPad, +) +from ray.rllib.connectors.common.agent_to_module_mapping import AgentToModuleMapping +from ray.rllib.connectors.common.batch_individual_items import BatchIndividualItems +from ray.rllib.connectors.common.numpy_to_tensor import NumpyToTensor +from ray.rllib.connectors.env_to_module.env_to_module_pipeline import ( + EnvToModulePipeline, +) +from ray.rllib.connectors.env_to_module.flatten_observations import ( + FlattenObservations, +) +from ray.rllib.connectors.env_to_module.mean_std_filter import MeanStdFilter +from ray.rllib.connectors.env_to_module.prev_actions_prev_rewards import ( + PrevActionsPrevRewards, +) +from ray.rllib.connectors.env_to_module.write_observations_to_episodes import ( + WriteObservationsToEpisodes, +) + + +__all__ = [ + "AddObservationsFromEpisodesToBatch", + "AddStatesFromEpisodesToBatch", + "AddTimeDimToBatchAndZeroPad", + "AgentToModuleMapping", + "BatchIndividualItems", + "EnvToModulePipeline", + "FlattenObservations", + "MeanStdFilter", + "NumpyToTensor", + "PrevActionsPrevRewards", + "WriteObservationsToEpisodes", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__pycache__/mean_std_filter.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__pycache__/mean_std_filter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b79fd6ecd958b73af34b978b72f08e5a6e84a932 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__pycache__/mean_std_filter.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__pycache__/prev_actions_prev_rewards.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__pycache__/prev_actions_prev_rewards.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b035056a6ffa604ef61a0fa2879cb1c44c979737 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/__pycache__/prev_actions_prev_rewards.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/env_to_module_pipeline.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/env_to_module_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..04a5389737e1ffded64793b77b48207687ad6af0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/env_to_module_pipeline.py @@ -0,0 +1,55 @@ +from typing import Any, Dict, List, Optional + +from ray.rllib.connectors.connector_pipeline_v2 import ConnectorPipelineV2 +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.metrics import ( + ENV_TO_MODULE_SUM_EPISODES_LENGTH_IN, + ENV_TO_MODULE_SUM_EPISODES_LENGTH_OUT, +) +from ray.rllib.utils.metrics.metrics_logger import MetricsLogger +from ray.rllib.utils.typing import EpisodeType +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class EnvToModulePipeline(ConnectorPipelineV2): + @override(ConnectorPipelineV2) + def __call__( + self, + *, + rl_module: RLModule, + batch: Optional[Dict[str, Any]] = None, + episodes: List[EpisodeType], + explore: bool, + shared_data: Optional[dict] = None, + metrics: Optional[MetricsLogger] = None, + **kwargs, + ): + # Log the sum of lengths of all episodes incoming. + if metrics: + metrics.log_value( + ENV_TO_MODULE_SUM_EPISODES_LENGTH_IN, + sum(map(len, episodes)), + ) + + # Make sure user does not necessarily send initial input into this pipeline. + # Might just be empty and to be populated from `episodes`. 
+ ret = super().__call__( + rl_module=rl_module, + batch=batch if batch is not None else {}, + episodes=episodes, + explore=explore, + shared_data=shared_data if shared_data is not None else {}, + metrics=metrics, + **kwargs, + ) + + # Log the sum of lengths of all episodes outgoing. + if metrics: + metrics.log_value( + ENV_TO_MODULE_SUM_EPISODES_LENGTH_OUT, + sum(map(len, episodes)), + ) + + return ret diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/flatten_observations.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/flatten_observations.py new file mode 100644 index 0000000000000000000000000000000000000000..986341685d5e4dc8ff7bce163e49c5d1fe22308e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/flatten_observations.py @@ -0,0 +1,208 @@ +from typing import Any, Collection, Dict, List, Optional + +import gymnasium as gym +from gymnasium.spaces import Box +import numpy as np +import tree # pip install dm_tree + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.numpy import flatten_inputs_to_1d_tensor +from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space +from ray.rllib.utils.typing import AgentID, EpisodeType +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class FlattenObservations(ConnectorV2): + """A connector piece that flattens all observation components into a 1D array. + + - Should be used only in env-to-module pipelines. + - Works directly on the incoming episodes list and changes the last observation + in-place (write the flattened observation back into the episode). + - This connector does NOT alter the incoming batch (`data`) when called. 
+ - This connector does NOT work in a `LearnerConnectorPipeline` because it requires + the incoming episodes to still be ongoing (in progress) as it only alters the + latest observation, not all observations in an episode. + + .. testcode:: + + import gymnasium as gym + import numpy as np + + from ray.rllib.connectors.env_to_module import FlattenObservations + from ray.rllib.env.single_agent_episode import SingleAgentEpisode + from ray.rllib.utils.test_utils import check + + # Some arbitrarily nested, complex observation space. + obs_space = gym.spaces.Dict({ + "a": gym.spaces.Box(-10.0, 10.0, (), np.float32), + "b": gym.spaces.Tuple([ + gym.spaces.Discrete(2), + gym.spaces.Box(-1.0, 1.0, (2, 1), np.float32), + ]), + "c": gym.spaces.MultiDiscrete([2, 3]), + }) + act_space = gym.spaces.Discrete(2) + + # Two example episodes, both with initial (reset) observations coming from the + # above defined observation space. + episode_1 = SingleAgentEpisode( + observations=[ + { + "a": np.array(-10.0, np.float32), + "b": (1, np.array([[-1.0], [-1.0]], np.float32)), + "c": np.array([0, 2]), + }, + ], + ) + episode_2 = SingleAgentEpisode( + observations=[ + { + "a": np.array(10.0, np.float32), + "b": (0, np.array([[1.0], [1.0]], np.float32)), + "c": np.array([1, 1]), + }, + ], + ) + + # Construct our connector piece. + connector = FlattenObservations(obs_space, act_space) + + # Call our connector piece with the example data. + output_batch = connector( + rl_module=None, # This connector works without an RLModule. + batch={}, # This connector does not alter the input batch. + episodes=[episode_1, episode_2], + explore=True, + shared_data={}, + ) + + # The connector does not alter the data and acts as pure pass-through. + check(output_batch, {}) + + # The connector has flattened each item in the episodes to a 1D tensor. + check( + episode_1.get_observations(0), + # box() disc(2). box(2, 1). multidisc(2, 3)........ 
+ np.array([-10.0, 0.0, 1.0, -1.0, -1.0, 1.0, 0.0, 0.0, 0.0, 1.0]), + ) + check( + episode_2.get_observations(0), + # box() disc(2). box(2, 1). multidisc(2, 3)........ + np.array([10.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0]), + ) + """ + + @override(ConnectorV2) + def recompute_output_observation_space( + self, + input_observation_space, + input_action_space, + ) -> gym.Space: + self._input_obs_base_struct = get_base_struct_from_space( + self.input_observation_space + ) + if self._multi_agent: + spaces = {} + for agent_id, space in self._input_obs_base_struct.items(): + if self._agent_ids and agent_id not in self._agent_ids: + spaces[agent_id] = self._input_obs_base_struct[agent_id] + else: + sample = flatten_inputs_to_1d_tensor( + tree.map_structure( + lambda s: s.sample(), + self._input_obs_base_struct[agent_id], + ), + self._input_obs_base_struct[agent_id], + batch_axis=False, + ) + spaces[agent_id] = Box( + float("-inf"), float("inf"), (len(sample),), np.float32 + ) + return gym.spaces.Dict(spaces) + else: + sample = flatten_inputs_to_1d_tensor( + tree.map_structure( + lambda s: s.sample(), + self._input_obs_base_struct, + ), + self._input_obs_base_struct, + batch_axis=False, + ) + return Box(float("-inf"), float("inf"), (len(sample),), np.float32) + + def __init__( + self, + input_observation_space: Optional[gym.Space] = None, + input_action_space: Optional[gym.Space] = None, + *, + multi_agent: bool = False, + agent_ids: Optional[Collection[AgentID]] = None, + **kwargs, + ): + """Initializes a FlattenObservations instance. + + Args: + multi_agent: Whether this connector operates on multi-agent observations, + in which case, the top-level of the Dict space (where agent IDs are + mapped to individual agents' observation spaces) is left as-is. + agent_ids: If multi_agent is True, this argument defines a collection of + AgentIDs for which to flatten. AgentIDs not in this collection are + ignored. + If None, flatten observations for all AgentIDs. 
None is the default. + """ + self._input_obs_base_struct = None + self._multi_agent = multi_agent + self._agent_ids = agent_ids + + super().__init__(input_observation_space, input_action_space, **kwargs) + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: RLModule, + batch: Dict[str, Any], + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + for sa_episode in self.single_agent_episode_iterator( + episodes, agents_that_stepped_only=True + ): + last_obs = sa_episode.get_observations(-1) + + if self._multi_agent: + if ( + self._agent_ids is not None + and sa_episode.agent_id not in self._agent_ids + ): + flattened_obs = last_obs + else: + flattened_obs = flatten_inputs_to_1d_tensor( + inputs=last_obs, + # In the multi-agent case, we need to use the specific agent's + # space struct, not the multi-agent observation space dict. + spaces_struct=self._input_obs_base_struct[sa_episode.agent_id], + # Our items are individual observations (no batch axis present). + batch_axis=False, + ) + else: + flattened_obs = flatten_inputs_to_1d_tensor( + inputs=last_obs, + spaces_struct=self._input_obs_base_struct, + # Our items are individual observations (no batch axis present). + batch_axis=False, + ) + + # Write new observation directly back into the episode. + sa_episode.set_observations(at_indices=-1, new_data=flattened_obs) + # We set the Episode's observation space to ours so that we can safely + # set the last obs to the new value (without causing a space mismatch + # error). 
+ sa_episode.observation_space = self.observation_space + + return batch diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/frame_stacking.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/frame_stacking.py new file mode 100644 index 0000000000000000000000000000000000000000..25c12fa4526a8cfb090360ab704e31ab889467da --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/frame_stacking.py @@ -0,0 +1,6 @@ +from functools import partial + +from ray.rllib.connectors.common.frame_stacking import _FrameStacking + + +FrameStackingEnvToModule = partial(_FrameStacking, as_learner_connector=False) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/mean_std_filter.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/mean_std_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..39a452657f5dc6a153de71ee75463b85abcd79f7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/mean_std_filter.py @@ -0,0 +1,253 @@ +from typing import Any, Collection, Dict, List, Optional, Union + +import gymnasium as gym +from gymnasium.spaces import Discrete, MultiDiscrete +import numpy as np +import tree + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.filter import MeanStdFilter as _MeanStdFilter, RunningStat +from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space +from ray.rllib.utils.typing import AgentID, EpisodeType, StateDict +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class MeanStdFilter(ConnectorV2): + """A connector used to mean-std-filter observations. + + Incoming observations are filtered such that the output of this filter is on + average 0.0 and has a standard deviation of 1.0. 
If the observation space is + a (possibly nested) dict, this filtering is applied separately per element of + the observation space (except for discrete- and multi-discrete elements, which + are left as-is). + + This connector is stateful as it continues to update its internal stats on mean + and std values as new data is pushed through it (unless `update_stats` is False). + """ + + @override(ConnectorV2) + def recompute_output_observation_space( + self, + input_observation_space: gym.Space, + input_action_space: gym.Space, + ) -> gym.Space: + _input_observation_space_struct = get_base_struct_from_space( + input_observation_space + ) + + # Adjust our observation space's Boxes (only if clipping is active). + _observation_space_struct = tree.map_structure( + lambda s: ( + s + if not isinstance(s, gym.spaces.Box) + else gym.spaces.Box( + low=-self.clip_by_value, + high=self.clip_by_value, + shape=s.shape, + dtype=s.dtype, + ) + ), + _input_observation_space_struct, + ) + if isinstance(input_observation_space, (gym.spaces.Dict, gym.spaces.Tuple)): + return type(input_observation_space)(_observation_space_struct) + else: + return _observation_space_struct + + def __init__( + self, + *, + multi_agent: bool = False, + de_mean_to_zero: bool = True, + de_std_to_one: bool = True, + clip_by_value: Optional[float] = 10.0, + update_stats: bool = True, + **kwargs, + ): + """Initializes a MeanStdFilter instance. + + Args: + multi_agent: Whether this is a connector operating on a multi-agent + observation space mapping AgentIDs to individual agents' observations. + de_mean_to_zero: Whether to transform the mean values of the output data to + 0.0. This is done by subtracting the incoming data by the currently + stored mean value. + de_std_to_one: Whether to transform the standard deviation values of the + output data to 1.0. This is done by dividing the incoming data by the + currently stored std value. 
+ clip_by_value: If not None, clip the incoming data within the interval: + [-clip_by_value, +clip_by_value]. + update_stats: Whether to update the internal mean and std stats with each + incoming sample (with each `__call__()`) or not. You should set this to + False if you would like to perform inference in a production + environment, without continuing to "learn" stats from new data. + """ + super().__init__(**kwargs) + + self._multi_agent = multi_agent + + # We simply use the old MeanStdFilter until non-connector env_runner is fully + # deprecated to avoid duplicate code + self.de_mean_to_zero = de_mean_to_zero + self.de_std_to_one = de_std_to_one + self.clip_by_value = clip_by_value + self._update_stats = update_stats + + self._filters: Optional[Dict[AgentID, _MeanStdFilter]] = None + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: RLModule, + batch: Dict[str, Any], + episodes: List[EpisodeType], + explore: Optional[bool] = None, + persistent_data: Optional[dict] = None, + **kwargs, + ) -> Any: + if self._filters is None: + self._init_new_filters() + + # This connector acts as a classic preprocessor. We process and then replace + # observations inside the episodes directly. Thus, all following connectors + # will only see and operate on the already normalized data (w/o having access + # anymore to the original observations). + for sa_episode in self.single_agent_episode_iterator(episodes): + sa_obs = sa_episode.get_observations(indices=-1) + try: + normalized_sa_obs = self._filters[sa_episode.agent_id]( + sa_obs, update=self._update_stats + ) + except KeyError: + raise KeyError( + "KeyError trying to access a filter by agent ID " + f"`{sa_episode.agent_id}`! You probably did NOT pass the " + f"`multi_agent=True` flag into the `MeanStdFilter()` constructor. 
" + ) + sa_episode.set_observations(at_indices=-1, new_data=normalized_sa_obs) + # We set the Episode's observation space to ours so that we can safely + # set the last obs to the new value (without causing a space mismatch + # error). + sa_episode.observation_space = self.observation_space + + # Leave `batch` as is. RLlib's default connector will automatically + # populate the OBS column therein from the episodes' now transformed + # observations. + return batch + + @override(ConnectorV2) + def get_state( + self, + components: Optional[Union[str, Collection[str]]] = None, + *, + not_components: Optional[Union[str, Collection[str]]] = None, + **kwargs, + ) -> StateDict: + if self._filters is None: + self._init_new_filters() + return self._get_state_from_filters(self._filters) + + @override(ConnectorV2) + def set_state(self, state: StateDict) -> None: + if self._filters is None: + self._init_new_filters() + for agent_id, agent_state in state.items(): + filter = self._filters[agent_id] + filter.shape = agent_state["shape"] + filter.demean = agent_state["de_mean_to_zero"] + filter.destd = agent_state["de_std_to_one"] + filter.clip = agent_state["clip_by_value"] + filter.running_stats = tree.unflatten_as( + filter.shape, + [RunningStat.from_state(s) for s in agent_state["running_stats"]], + ) + # Do not update the buffer. + + @override(ConnectorV2) + def reset_state(self) -> None: + """Creates copy of current state and resets accumulated state""" + if not self._update_stats: + raise ValueError( + f"State of {type(self).__name__} can only be changed when " + f"`update_stats` was set to False." + ) + self._init_new_filters() + + @override(ConnectorV2) + def merge_states(self, states: List[Dict[str, Any]]) -> Dict[str, Any]: + if self._filters is None: + self._init_new_filters() + + # Make sure data is uniform across given states. 
+ ref = next(iter(states[0].values())) + + for state in states: + for agent_id, agent_state in state.items(): + assert ( + agent_state["shape"] == ref["shape"] + and agent_state["de_mean_to_zero"] == ref["de_mean_to_zero"] + and agent_state["de_std_to_one"] == ref["de_std_to_one"] + and agent_state["clip_by_value"] == ref["clip_by_value"] + ) + + _filter = _MeanStdFilter( + ref["shape"], + demean=ref["de_mean_to_zero"], + destd=ref["de_std_to_one"], + clip=ref["clip_by_value"], + ) + # Override running stats of the filter with the ones stored in + # `agent_state`. + _filter.buffer = tree.unflatten_as( + agent_state["shape"], + [ + RunningStat.from_state(stats) + for stats in agent_state["running_stats"] + ], + ) + + # Leave the buffers as-is, since they should always only reflect + # what has happened on the particular env runner. + self._filters[agent_id].apply_changes(_filter, with_buffer=False) + + return MeanStdFilter._get_state_from_filters(self._filters) + + def _init_new_filters(self): + filter_shape = tree.map_structure( + lambda s: ( + None if isinstance(s, (Discrete, MultiDiscrete)) else np.array(s.shape) + ), + get_base_struct_from_space(self.input_observation_space), + ) + if not self._multi_agent: + filter_shape = {None: filter_shape} + + del self._filters + self._filters = { + agent_id: _MeanStdFilter( + agent_filter_shape, + demean=self.de_mean_to_zero, + destd=self.de_std_to_one, + clip=self.clip_by_value, + ) + for agent_id, agent_filter_shape in filter_shape.items() + } + + @staticmethod + def _get_state_from_filters(filters: Dict[AgentID, Dict[str, Any]]): + ret = {} + for agent_id, agent_filter in filters.items(): + ret[agent_id] = { + "shape": agent_filter.shape, + "de_mean_to_zero": agent_filter.demean, + "de_std_to_one": agent_filter.destd, + "clip_by_value": agent_filter.clip, + "running_stats": [ + s.to_state() for s in tree.flatten(agent_filter.running_stats) + ], + } + return ret diff --git 
a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/observation_preprocessor.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/observation_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..120099ffe50b82f74490b71102b942dce1ad7e78 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/observation_preprocessor.py @@ -0,0 +1,80 @@ +import abc +from typing import Any, Dict, List, Optional + +import gymnasium as gym + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.typing import EpisodeType +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class ObservationPreprocessor(ConnectorV2, abc.ABC): + """Env-to-module connector performing one preprocessor step on the last observation. + + This is a convenience class that simplifies the writing of few-step preprocessor + connectors. + + Users must implement the `preprocess()` method, which simplifies the usual procedure + of extracting some data from a list of episodes and adding it to the batch to a mere + "old-observation --transform--> return new-observation" step. + """ + + @override(ConnectorV2) + def recompute_output_observation_space( + self, + input_observation_space: gym.Space, + input_action_space: gym.Space, + ) -> gym.Space: + # Users should override this method only in case the `ObservationPreprocessor` + # changes the observation space of the pipeline. In this case, return the new + # observation space based on the incoming one (`input_observation_space`). + return super().recompute_output_observation_space( + input_observation_space, input_action_space + ) + + @abc.abstractmethod + def preprocess(self, observation): + """Override to implement the preprocessing logic. 
+ + Args: + observation: A single (non-batched) observation item for a single agent to + be processed by this connector. + + Returns: + The new observation after `observation` has been preprocessed. + """ + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: RLModule, + batch: Dict[str, Any], + episodes: List[EpisodeType], + explore: Optional[bool] = None, + persistent_data: Optional[dict] = None, + **kwargs, + ) -> Any: + # We process and then replace observations inside the episodes directly. + # Thus, all following connectors will only see and operate on the already + # processed observation (w/o having access anymore to the original + # observations). + for sa_episode in self.single_agent_episode_iterator(episodes): + observation = sa_episode.get_observations(-1) + + # Process the observation and write the new observation back into the + # episode. + new_observation = self.preprocess(observation=observation) + sa_episode.set_observations(at_indices=-1, new_data=new_observation) + # We set the Episode's observation space to ours so that we can safely + # set the last obs to the new value (without causing a space mismatch + # error). + sa_episode.observation_space = self.observation_space + + # Leave `batch` as is. RLlib's default connector will automatically + # populate the OBS column therein from the episodes' now transformed + # observations. 
+ return batch diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/prev_actions_prev_rewards.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/prev_actions_prev_rewards.py new file mode 100644 index 0000000000000000000000000000000000000000..35e29d02a5214938ee175129eb49787a829b0df4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/prev_actions_prev_rewards.py @@ -0,0 +1,168 @@ +from typing import Any, Dict, List, Optional + +import gymnasium as gym +from gymnasium.spaces import Box +import numpy as np + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.spaces.space_utils import ( + batch as batch_fn, + flatten_to_single_ndarray, +) +from ray.rllib.utils.typing import EpisodeType +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class PrevActionsPrevRewards(ConnectorV2): + """A connector piece that adds previous rewards and actions to the input obs. + + - Requires Columns.OBS to be already a part of the batch. + - This connector makes the assumption that under the Columns.OBS key in batch, + there is either a list of individual env observations to be flattened (single-agent + case) or a dict mapping (AgentID, ModuleID)-tuples to lists of data items to be + flattened (multi-agent case). + - Converts Columns.OBS data into a dict (or creates a sub-dict if obs are + already a dict), and adds "prev_rewards" and "prev_actions" + to this dict. The original observations are stored under the self.ORIG_OBS_KEY in + that dict. + - If your RLModule does not handle dict inputs, you will have to plug in an + `FlattenObservations` connector piece after this one. + - Does NOT work in a Learner pipeline as it operates on individual observation + items (as opposed to batched/time-ranked data). 
+ - Therefore, assumes that the altered (flattened) observations will be written + back into the episode by a later connector piece in the env-to-module pipeline + (which this piece is part of as well). + - Only reads reward- and action information from the given list of Episode objects. + - Does NOT write any observations (or other data) to the given Episode objects. + """ + + ORIG_OBS_KEY = "_orig_obs" + PREV_ACTIONS_KEY = "prev_n_actions" + PREV_REWARDS_KEY = "prev_n_rewards" + + @override(ConnectorV2) + def recompute_output_observation_space( + self, + input_observation_space: gym.Space, + input_action_space: gym.Space, + ) -> gym.Space: + if self._multi_agent: + ret = {} + for agent_id, obs_space in input_observation_space.spaces.items(): + act_space = input_action_space[agent_id] + ret[agent_id] = self._convert_individual_space(obs_space, act_space) + return gym.spaces.Dict(ret) + else: + return self._convert_individual_space( + input_observation_space, input_action_space + ) + + def __init__( + self, + input_observation_space: Optional[gym.Space] = None, + input_action_space: Optional[gym.Space] = None, + *, + multi_agent: bool = False, + n_prev_actions: int = 1, + n_prev_rewards: int = 1, + **kwargs, + ): + """Initializes a PrevActionsPrevRewards instance. + + Args: + multi_agent: Whether this is a connector operating on a multi-agent + observation space mapping AgentIDs to individual agents' observations. + n_prev_actions: The number of previous actions to include in the output + data. Discrete actions are ont-hot'd. If > 1, will concatenate the + individual action tensors. + n_prev_rewards: The number of previous rewards to include in the output + data. 
+ """ + super().__init__( + input_observation_space=input_observation_space, + input_action_space=input_action_space, + **kwargs, + ) + + self._multi_agent = multi_agent + self.n_prev_actions = n_prev_actions + self.n_prev_rewards = n_prev_rewards + + # TODO: Move into input_observation_space setter + # Thus far, this connector piece only operates on discrete action spaces. + # act_spaces = [self.input_action_space] + # if self._multi_agent: + # act_spaces = self.input_action_space.spaces.values() + # if not all(isinstance(s, gym.spaces.Discrete) for s in act_spaces): + # raise ValueError( + # f"{type(self).__name__} only works on Discrete action spaces " + # f"thus far (or, for multi-agent, on Dict spaces mapping AgentIDs to " + # f"the individual agents' Discrete action spaces)!" + # ) + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: RLModule, + batch: Optional[Dict[str, Any]], + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + for sa_episode in self.single_agent_episode_iterator( + episodes, agents_that_stepped_only=True + ): + # Episode is not numpy'ized yet and thus still operates on lists of items. + assert not sa_episode.is_numpy + + augmented_obs = {self.ORIG_OBS_KEY: sa_episode.get_observations(-1)} + + if self.n_prev_actions: + augmented_obs[self.PREV_ACTIONS_KEY] = flatten_to_single_ndarray( + batch_fn( + sa_episode.get_actions( + indices=slice(-self.n_prev_actions, None), + fill=0.0, + one_hot_discrete=True, + ) + ) + ) + + if self.n_prev_rewards: + augmented_obs[self.PREV_REWARDS_KEY] = np.array( + sa_episode.get_rewards( + indices=slice(-self.n_prev_rewards, None), + fill=0.0, + ) + ) + + # Write new observation directly back into the episode. 
+ sa_episode.set_observations(at_indices=-1, new_data=augmented_obs) + # We set the Episode's observation space to ours so that we can safely + # set the last obs to the new value (without causing a space mismatch + # error). + sa_episode.observation_space = self.observation_space + + return batch + + def _convert_individual_space(self, obs_space, act_space): + return gym.spaces.Dict( + { + self.ORIG_OBS_KEY: obs_space, + # Currently only works for Discrete action spaces. + self.PREV_ACTIONS_KEY: Box( + 0.0, 1.0, (act_space.n * self.n_prev_actions,), np.float32 + ), + self.PREV_REWARDS_KEY: Box( + float("-inf"), + float("inf"), + (self.n_prev_rewards,), + np.float32, + ), + } + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/write_observations_to_episodes.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/write_observations_to_episodes.py new file mode 100644 index 0000000000000000000000000000000000000000..9b92da4984fd20cb8fe704692e1473eab4b925a4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/env_to_module/write_observations_to_episodes.py @@ -0,0 +1,131 @@ +from typing import Any, Dict, List, Optional + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.typing import EpisodeType +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class WriteObservationsToEpisodes(ConnectorV2): + """Writes the observations from the batch into the running episodes. + + Note: This is one of the default env-to-module ConnectorV2 pieces that are added + automatically by RLlib into every env-to-module connector pipelines, unless + `config.add_default_connectors_to_env_to_module_pipeline` is set to False. 
+ + The default env-to-module connector pipeline is: + [ + [0 or more user defined ConnectorV2 pieces], + AddObservationsFromEpisodesToBatch, + AddStatesFromEpisodesToBatch, + AgentToModuleMapping, # only in multi-agent setups! + BatchIndividualItems, + NumpyToTensor, + ] + + This ConnectorV2: + - Operates on a batch that already has observations in it and a list of Episode + objects. + - Writes the observation(s) from the batch to all the given episodes. Thereby + the number of observations in the batch must match the length of the list of + episodes given. + - Does NOT alter any observations (or other data) in the batch. + - Can only be used in an EnvToModule pipeline (writing into Episode objects in a + Learner pipeline does not make a lot of sense as - after the learner update - the + list of episodes is discarded). + + .. testcode:: + + import gymnasium as gym + import numpy as np + + from ray.rllib.connectors.env_to_module import WriteObservationsToEpisodes + from ray.rllib.env.single_agent_episode import SingleAgentEpisode + from ray.rllib.utils.test_utils import check + + # Assume we have two episodes (vectorized), then our forward batch will carry + # two observation records (batch size = 2). + # The connector in this example will write these two (possibly transformed) + # observations back into the two respective SingleAgentEpisode objects. + batch = { + "obs": [np.array([0.0, 1.0], np.float32), np.array([2.0, 3.0], np.float32)], + } + + # Our two episodes have one observation each (i.e. the reset one). This is the + # one that will be overwritten by the connector in this example. + obs_space = gym.spaces.Box(-10.0, 10.0, (2,), np.float32) + act_space = gym.spaces.Discrete(2) + episodes = [ + SingleAgentEpisode( + observation_space=obs_space, + observations=[np.array([-10, -20], np.float32)], + len_lookback_buffer=0, + ) for _ in range(2) + ] + # Make sure everything is setup correctly. 
+ check(episodes[0].get_observations(0), [-10.0, -20.0]) + check(episodes[1].get_observations(-1), [-10.0, -20.0]) + + # Create our connector piece. + connector = WriteObservationsToEpisodes(obs_space, act_space) + + # Call the connector (and thereby write the transformed observations back + # into the episodes). + output_batch = connector( + rl_module=None, # This particular connector works without an RLModule. + batch=batch, + episodes=episodes, + explore=True, + shared_data={}, + ) + + # The connector does NOT change the data batch being passed through. + check(output_batch, batch) + + # However, the connector has overwritten the last observations in the episodes. + check(episodes[0].get_observations(-1), [0.0, 1.0]) + check(episodes[1].get_observations(0), [2.0, 3.0]) + """ + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: RLModule, + batch: Optional[Dict[str, Any]], + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + observations = batch.get(Columns.OBS) + + if observations is None: + raise ValueError( + f"`batch` must already have a column named {Columns.OBS} in it " + f"for this connector to work!" + ) + + # Note that the following loop works with multi-agent as well as with + # single-agent episode, as long as the following conditions are met (these + # will be validated by `self.single_agent_episode_iterator()`): + # - Per single agent episode, one observation item is expected to exist in + # `data`, either in a list directly under the "obs" key OR for multi-agent: + # in a list sitting under a key `(agent_id, module_id)` of a dict sitting + # under the "obs" key. + for sa_episode, obs in self.single_agent_episode_iterator( + episodes=episodes, zip_with_batch_column=observations + ): + # Make sure episodes are NOT numpy'ized yet (we are expecting to run in an + # env-to-module pipeline). 
+ assert not sa_episode.is_numpy + # Write new information into the episode. + sa_episode.set_observations(at_indices=-1, new_data=obs) + # Change the observation space of the sa_episode. + sa_episode.observation_space = self.observation_space + + # Return the unchanged `batch`. + return batch diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2fc76488be7b7550e3fdbf7177731c758194fd5b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__init__.py @@ -0,0 +1,30 @@ +from ray.rllib.connectors.common.tensor_to_numpy import TensorToNumpy +from ray.rllib.connectors.common.module_to_agent_unmapping import ModuleToAgentUnmapping +from ray.rllib.connectors.module_to_env.get_actions import GetActions +from ray.rllib.connectors.module_to_env.listify_data_for_vector_env import ( + ListifyDataForVectorEnv, +) +from ray.rllib.connectors.module_to_env.module_to_env_pipeline import ( + ModuleToEnvPipeline, +) +from ray.rllib.connectors.module_to_env.normalize_and_clip_actions import ( + NormalizeAndClipActions, +) +from ray.rllib.connectors.module_to_env.remove_single_ts_time_rank_from_batch import ( + RemoveSingleTsTimeRankFromBatch, +) +from ray.rllib.connectors.module_to_env.unbatch_to_individual_items import ( + UnBatchToIndividualItems, +) + + +__all__ = [ + "GetActions", + "ListifyDataForVectorEnv", + "ModuleToAgentUnmapping", + "ModuleToEnvPipeline", + "NormalizeAndClipActions", + "RemoveSingleTsTimeRankFromBatch", + "TensorToNumpy", + "UnBatchToIndividualItems", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/listify_data_for_vector_env.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/listify_data_for_vector_env.cpython-311.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..f25cc5df44d3f85f65a0bd07fd9febef7c4c88e1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/listify_data_for_vector_env.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/get_actions.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/get_actions.py new file mode 100644 index 0000000000000000000000000000000000000000..1e862231b4fbd10500ffee1f4b03f604a8af2472 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/get_actions.py @@ -0,0 +1,91 @@ +from typing import Any, Dict, List, Optional + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.env.multi_agent_episode import MultiAgentEpisode +from ray.rllib.utils.annotations import override +from ray.rllib.utils.typing import EpisodeType +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class GetActions(ConnectorV2): + """Connector piece sampling actions from ACTION_DIST_INPUTS from an RLModule. + + Note: This is one of the default module-to-env ConnectorV2 pieces that + are added automatically by RLlib into every module-to-env connector pipeline, + unless `config.add_default_connectors_to_module_to_env_pipeline` is set to + False. + + The default module-to-env connector pipeline is: + [ + GetActions, + TensorToNumpy, + UnBatchToIndividualItems, + ModuleToAgentUnmapping, # only in multi-agent setups! + RemoveSingleTsTimeRankFromBatch, + + [0 or more user defined ConnectorV2 pieces], + + NormalizeAndClipActions, + ListifyDataForVectorEnv, + ] + + If necessary, this connector samples actions, given action dist. inputs and a + dist. class. 
+ The connector will only sample from the action distribution, if the + Columns.ACTIONS key cannot be found in `data`. Otherwise, it'll behave + as pass-through. If Columns.ACTIONS is NOT present in `data`, but + Columns.ACTION_DIST_INPUTS is, this connector will create a new action + distribution using the given RLModule and sample from its distribution class + (deterministically, if we are not exploring, stochastically, if we are). + """ + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: RLModule, + batch: Dict[str, Any], + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + is_multi_agent = isinstance(episodes[0], MultiAgentEpisode) + + if is_multi_agent: + for module_id, module_data in batch.copy().items(): + self._get_actions(module_data, rl_module[module_id], explore) + else: + self._get_actions(batch, rl_module, explore) + + return batch + + def _get_actions(self, batch, sa_rl_module, explore): + # Action have already been sampled -> Early out. + if Columns.ACTIONS in batch: + return + + # ACTION_DIST_INPUTS field returned by `forward_exploration|inference()` -> + # Create a new action distribution object. + if Columns.ACTION_DIST_INPUTS in batch: + if explore: + action_dist_class = sa_rl_module.get_exploration_action_dist_cls() + else: + action_dist_class = sa_rl_module.get_inference_action_dist_cls() + action_dist = action_dist_class.from_logits( + batch[Columns.ACTION_DIST_INPUTS], + ) + if not explore: + action_dist = action_dist.to_deterministic() + + # Sample actions from the distribution. + actions = action_dist.sample() + batch[Columns.ACTIONS] = actions + + # For convenience and if possible, compute action logp from distribution + # and add to output. 
+ if Columns.ACTION_LOGP not in batch: + batch[Columns.ACTION_LOGP] = action_dist.logp(actions) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/listify_data_for_vector_env.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/listify_data_for_vector_env.py new file mode 100644 index 0000000000000000000000000000000000000000..af30a0d31e7f5b745342c2705613b838cf9d2c70 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/listify_data_for_vector_env.py @@ -0,0 +1,82 @@ +from typing import Any, Dict, List, Optional + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.env.multi_agent_episode import MultiAgentEpisode +from ray.rllib.utils.annotations import override +from ray.rllib.utils.spaces.space_utils import batch as batch_fn +from ray.rllib.utils.typing import EpisodeType +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class ListifyDataForVectorEnv(ConnectorV2): + """Performs conversion from ConnectorV2-style format to env/episode insertion. + + Note: This is one of the default module-to-env ConnectorV2 pieces that + are added automatically by RLlib into every module-to-env connector pipeline, + unless `config.add_default_connectors_to_module_to_env_pipeline` is set to + False. + + The default module-to-env connector pipeline is: + [ + GetActions, + TensorToNumpy, + UnBatchToIndividualItems, + ModuleToAgentUnmapping, # only in multi-agent setups! + RemoveSingleTsTimeRankFromBatch, + + [0 or more user defined ConnectorV2 pieces], + + NormalizeAndClipActions, + ListifyDataForVectorEnv, + ] + + Single agent case: + Convert from: + [col] -> [(episode_id,)] -> [list of items]. + To: + [col] -> [list of items]. + + Multi-agent case: + Convert from: + [col] -> [(episode_id, agent_id, module_id)] -> list of items. 
+ To: + [col] -> [list of multi-agent dicts]. + """ + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: RLModule, + batch: Dict[str, Any], + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + for column, column_data in batch.copy().items(): + # Multi-agent case: Create lists of multi-agent dicts under each column. + if isinstance(episodes[0], MultiAgentEpisode): + # TODO (sven): Support vectorized MultiAgentEnv + assert len(episodes) == 1 + new_column_data = [{}] + + for key, value in batch[column].items(): + assert len(value) == 1 + eps_id, agent_id, module_id = key + new_column_data[0][agent_id] = value[0] + batch[column] = new_column_data + # Single-agent case: Create simple lists under each column. + else: + batch[column] = [ + d for key in batch[column].keys() for d in batch[column][key] + ] + # Batch actions for (single-agent) gym.vector.Env. + # All other columns, leave listify'ed. + if column in [Columns.ACTIONS_FOR_ENV, Columns.ACTIONS]: + batch[column] = batch_fn(batch[column]) + + return batch diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/module_to_env_pipeline.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/module_to_env_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..63d5f00bfa4cda7a59e1cd22b3476e2091241090 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/module_to_env_pipeline.py @@ -0,0 +1,7 @@ +from ray.rllib.connectors.connector_pipeline_v2 import ConnectorPipelineV2 +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class ModuleToEnvPipeline(ConnectorPipelineV2): + pass diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/normalize_and_clip_actions.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/normalize_and_clip_actions.py new file 
mode 100644 index 0000000000000000000000000000000000000000..1698b1d0fe1e1aa5a0a3b4bfd32e65b6efc3b60b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/normalize_and_clip_actions.py @@ -0,0 +1,146 @@ +import copy +from typing import Any, Dict, List, Optional + +import gymnasium as gym + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.spaces.space_utils import ( + clip_action, + get_base_struct_from_space, + unsquash_action, +) +from ray.rllib.utils.typing import EpisodeType +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class NormalizeAndClipActions(ConnectorV2): + """Normalizes or clips actions in the input data (coming from the RLModule). + + Note: This is one of the default module-to-env ConnectorV2 pieces that + are added automatically by RLlib into every module-to-env connector pipeline, + unless `config.add_default_connectors_to_module_to_env_pipeline` is set to + False. + + The default module-to-env connector pipeline is: + [ + GetActions, + TensorToNumpy, + UnBatchToIndividualItems, + ModuleToAgentUnmapping, # only in multi-agent setups! + RemoveSingleTsTimeRankFromBatch, + + [0 or more user defined ConnectorV2 pieces], + + NormalizeAndClipActions, + ListifyDataForVectorEnv, + ] + + This ConnectorV2: + - Deep copies the Columns.ACTIONS in the incoming `data` into a new column: + Columns.ACTIONS_FOR_ENV. + - Loops through the Columns.ACTIONS in the incoming `data` and normalizes or clips + these depending on the c'tor settings in `config.normalize_actions` and + `config.clip_actions`. + - Only applies to envs with Box action spaces. + + Normalizing is the process of mapping NN-outputs (which are usually small + numbers, e.g. between -1.0 and 1.0) to the bounds defined by the action-space. 
+ Normalizing helps the NN to learn faster in environments with large ranges between + `low` and `high` bounds or skewed action bounds (e.g. Box(-3000.0, 1.0, ...)). + + Clipping clips the actions computed by the NN (and sampled from a distribution) + between the bounds defined by the action-space. Note that clipping is only performed + if `normalize_actions` is False. + """ + + @override(ConnectorV2) + def recompute_output_action_space( + self, + input_observation_space: gym.Space, + input_action_space: gym.Space, + ) -> gym.Space: + self._action_space_struct = get_base_struct_from_space(input_action_space) + return input_action_space + + def __init__( + self, + input_observation_space: Optional[gym.Space] = None, + input_action_space: Optional[gym.Space] = None, + *, + normalize_actions: bool, + clip_actions: bool, + **kwargs, + ): + """Initializes a DefaultModuleToEnv (connector piece) instance. + + Args: + normalize_actions: If True, actions coming from the RLModule's distribution + (or are directly computed by the RLModule w/o sampling) will + be assumed 0.0 centered with a small stddev (only affecting Box + components) and thus be unsquashed (and clipped, just in case) to the + bounds of the env's action space. For example, if the action space of + the environment is `Box(-2.0, -0.5, (1,))`, the model outputs + mean and stddev as 0.1 and exp(0.2), and we sample an action of 0.9 + from the resulting distribution, then this 0.9 will be unsquashed into + the [-2.0 -0.5] interval. If - after unsquashing - the action still + breaches the action space, it will simply be clipped. + clip_actions: If True, actions coming from the RLModule's distribution + (or are directly computed by the RLModule w/o sampling) will be clipped + such that they fit into the env's action space's bounds. 
+ For example, if the action space of the environment is + `Box(-0.5, 0.5, (1,))`, the model outputs + mean and stddev as 0.1 and exp(0.2), and we sample an action of 0.9 + from the resulting distribution, then this 0.9 will be clipped to 0.5 + to fit into the [-0.5 0.5] interval. + """ + self._action_space_struct = None + + super().__init__(input_observation_space, input_action_space, **kwargs) + + self.normalize_actions = normalize_actions + self.clip_actions = clip_actions + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: RLModule, + batch: Optional[Dict[str, Any]], + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + """Based on settings, will normalize (unsquash) and/or clip computed actions. + + This is such that the final actions (to be sent to the env) match the + environment's action space and thus don't lead to an error. + """ + + def _unsquash_or_clip(action_for_env, env_id, agent_id, module_id): + if agent_id is not None: + struct = self._action_space_struct[agent_id] + else: + struct = self._action_space_struct + + if self.normalize_actions: + return unsquash_action(action_for_env, struct) + else: + return clip_action(action_for_env, struct) + + # Normalize or clip this new actions_for_env column, leaving the originally + # computed/sampled actions intact. + if self.normalize_actions or self.clip_actions: + # Copy actions into separate column, just to go to the env. 
+ batch[Columns.ACTIONS_FOR_ENV] = copy.deepcopy(batch[Columns.ACTIONS]) + self.foreach_batch_item_change_in_place( + batch=batch, + column=Columns.ACTIONS_FOR_ENV, + func=_unsquash_or_clip, + ) + + return batch diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/remove_single_ts_time_rank_from_batch.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/remove_single_ts_time_rank_from_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..7297080595ad11b5069953a36dd9c269b1006912 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/remove_single_ts_time_rank_from_batch.py @@ -0,0 +1,70 @@ +from typing import Any, Dict, List, Optional + +import numpy as np +import tree # pip install dm_tree + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.typing import EpisodeType +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class RemoveSingleTsTimeRankFromBatch(ConnectorV2): + """ + Note: This is one of the default module-to-env ConnectorV2 pieces that + are added automatically by RLlib into every module-to-env connector pipeline, + unless `config.add_default_connectors_to_module_to_env_pipeline` is set to + False. + + The default module-to-env connector pipeline is: + [ + GetActions, + TensorToNumpy, + UnBatchToIndividualItems, + ModuleToAgentUnmapping, # only in multi-agent setups! 
+ RemoveSingleTsTimeRankFromBatch, + + [0 or more user defined ConnectorV2 pieces], + + NormalizeAndClipActions, + ListifyDataForVectorEnv, + ] + + """ + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: RLModule, + batch: Optional[Dict[str, Any]], + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + # If single ts time-rank had not been added, early out. + if shared_data is None or not shared_data.get("_added_single_ts_time_rank"): + return batch + + def _remove_single_ts(item, eps_id, aid, mid): + # Only remove time-rank for modules that are statefule (only for those has + # a timerank been added). + if mid is None or rl_module[mid].is_stateful(): + return tree.map_structure(lambda s: np.squeeze(s, axis=0), item) + return item + + for column, column_data in batch.copy().items(): + # Skip state_out (doesn't have a time rank). + if column == Columns.STATE_OUT: + continue + self.foreach_batch_item_change_in_place( + batch, + column=column, + func=_remove_single_ts, + ) + + return batch diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/unbatch_to_individual_items.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/unbatch_to_individual_items.py new file mode 100644 index 0000000000000000000000000000000000000000..bfac443ff4c9208f72a8f8223e43857a999dbe46 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/unbatch_to_individual_items.py @@ -0,0 +1,92 @@ +from collections import defaultdict +from typing import Any, Dict, List, Optional + +import tree # pip install dm_tree + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.spaces.space_utils import unbatch as unbatch_fn +from ray.rllib.utils.typing import EpisodeType +from ray.util.annotations import 
PublicAPI + + +@PublicAPI(stability="alpha") +class UnBatchToIndividualItems(ConnectorV2): + """Unbatches the given `data` back into the individual-batch-items format. + + Note: This is one of the default module-to-env ConnectorV2 pieces that + are added automatically by RLlib into every module-to-env connector pipeline, + unless `config.add_default_connectors_to_module_to_env_pipeline` is set to + False. + + The default module-to-env connector pipeline is: + [ + GetActions, + TensorToNumpy, + UnBatchToIndividualItems, + ModuleToAgentUnmapping, # only in multi-agent setups! + RemoveSingleTsTimeRankFromBatch, + + [0 or more user defined ConnectorV2 pieces], + + NormalizeAndClipActions, + ListifyDataForVectorEnv, + ] + """ + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: RLModule, + batch: Dict[str, Any], + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + memorized_map_structure = shared_data.get("memorized_map_structure") + + # Simple case (no structure stored): Just unbatch. + if memorized_map_structure is None: + return tree.map_structure(lambda s: unbatch_fn(s), batch) + # Single agent case: Memorized structure is a list, whose indices map to + # eps_id values. + elif isinstance(memorized_map_structure, list): + for column, column_data in batch.copy().items(): + column_data = unbatch_fn(column_data) + new_column_data = defaultdict(list) + for i, eps_id in enumerate(memorized_map_structure): + # Keys are always tuples to resemble multi-agent keys, which + # have the structure (eps_id, agent_id, module_id). + key = (eps_id,) + new_column_data[key].append(column_data[i]) + batch[column] = dict(new_column_data) + # Multi-agent case: Memorized structure is dict mapping module_ids to lists of + # (eps_id, agent_id)-tuples, such that the original individual-items-based form + # can be constructed. 
+ else: + for module_id, module_data in batch.copy().items(): + if module_id not in memorized_map_structure: + raise KeyError( + f"ModuleID={module_id} not found in `memorized_map_structure`!" + ) + for column, column_data in module_data.items(): + column_data = unbatch_fn(column_data) + new_column_data = defaultdict(list) + for i, (eps_id, agent_id) in enumerate( + memorized_map_structure[module_id] + ): + key = (eps_id, agent_id, module_id) + # TODO (sven): Support vectorization for MultiAgentEnvRunner. + # AgentIDs whose SingleAgentEpisodes are already done, should + # not send any data back to the EnvRunner for further + # processing. + if episodes[0].agent_episodes[agent_id].is_done: + continue + + new_column_data[key].append(column_data[i]) + module_data[column] = dict(new_column_data) + + return batch diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/registry.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..8efe64515eea485a13799330260758a68a11e21b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/registry.py @@ -0,0 +1,46 @@ +"""Registry of connector names for global access.""" +from typing import Any + +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.connectors.connector import Connector, ConnectorContext + + +ALL_CONNECTORS = dict() + + +@OldAPIStack +def register_connector(name: str, cls: Connector): + """Register a connector for use with RLlib. + + Args: + name: Name to register. + cls: Callable that creates an env. + """ + if name in ALL_CONNECTORS: + return + + if not issubclass(cls, Connector): + raise TypeError("Can only register Connector type.", cls) + + # Record it in local registry in case we need to register everything + # again in the global registry, for example in the event of cluster + # restarts. 
+ ALL_CONNECTORS[name] = cls + + +@OldAPIStack +def get_connector(name: str, ctx: ConnectorContext, params: Any = None) -> Connector: + # TODO(jungong) : switch the order of parameters man!! + """Get a connector by its name and serialized config. + + Args: + name: name of the connector. + ctx: Connector context. + params: serialized parameters of the connector. + + Returns: + Constructed connector. + """ + if name not in ALL_CONNECTORS: + raise NameError("connector not found.", name) + return ALL_CONNECTORS[name].from_state(ctx, params) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/connectors/util.py b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/util.py new file mode 100644 index 0000000000000000000000000000000000000000..ff00b6d49dfec12ef766c166c9fe7d02ca4a33a0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/connectors/util.py @@ -0,0 +1,170 @@ +import logging +from typing import Any, Tuple, TYPE_CHECKING + +from ray.rllib.connectors.action.clip import ClipActionsConnector +from ray.rllib.connectors.action.immutable import ImmutableActionsConnector +from ray.rllib.connectors.action.lambdas import ConvertToNumpyConnector +from ray.rllib.connectors.action.normalize import NormalizeActionsConnector +from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline +from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector +from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector +from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline +from ray.rllib.connectors.agent.state_buffer import StateBufferConnector +from ray.rllib.connectors.agent.view_requirement import ViewRequirementAgentConnector +from ray.rllib.connectors.connector import Connector, ConnectorContext +from ray.rllib.connectors.registry import get_connector +from ray.rllib.connectors.agent.mean_std_filter import ( + MeanStdObservationFilterAgentConnector, + ConcurrentMeanStdObservationFilterAgentConnector, 
+) +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.connectors.agent.synced_filter import SyncedFilterAgentConnector + +if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + from ray.rllib.policy.policy import Policy + +logger = logging.getLogger(__name__) + + +def __preprocessing_enabled(config: "AlgorithmConfig"): + if config._disable_preprocessor_api: + return False + # Same conditions as in RolloutWorker.__init__. + if config.is_atari and config.preprocessor_pref == "deepmind": + return False + if config.preprocessor_pref is None: + return False + return True + + +def __clip_rewards(config: "AlgorithmConfig"): + # Same logic as in RolloutWorker.__init__. + # We always clip rewards for Atari games. + return config.clip_rewards or config.is_atari + + +@OldAPIStack +def get_agent_connectors_from_config( + ctx: ConnectorContext, + config: "AlgorithmConfig", +) -> AgentConnectorPipeline: + connectors = [] + + clip_rewards = __clip_rewards(config) + if clip_rewards is True: + connectors.append(ClipRewardAgentConnector(ctx, sign=True)) + elif type(clip_rewards) is float: + connectors.append(ClipRewardAgentConnector(ctx, limit=abs(clip_rewards))) + + if __preprocessing_enabled(config): + connectors.append(ObsPreprocessorConnector(ctx)) + + # Filters should be after observation preprocessing + filter_connector = get_synced_filter_connector( + ctx, + ) + # Configuration option "NoFilter" results in `filter_connector==None`. + if filter_connector: + connectors.append(filter_connector) + + connectors.extend( + [ + StateBufferConnector(ctx), + ViewRequirementAgentConnector(ctx), + ] + ) + + return AgentConnectorPipeline(ctx, connectors) + + +@OldAPIStack +def get_action_connectors_from_config( + ctx: ConnectorContext, + config: "AlgorithmConfig", +) -> ActionConnectorPipeline: + """Default list of action connectors to use for a new policy. + + Args: + ctx: context used to create connectors. 
+ config: The AlgorithmConfig object. + """ + connectors = [ConvertToNumpyConnector(ctx)] + if config.get("normalize_actions", False): + connectors.append(NormalizeActionsConnector(ctx)) + if config.get("clip_actions", False): + connectors.append(ClipActionsConnector(ctx)) + connectors.append(ImmutableActionsConnector(ctx)) + return ActionConnectorPipeline(ctx, connectors) + + +@OldAPIStack +def create_connectors_for_policy(policy: "Policy", config: "AlgorithmConfig"): + """Util to create agent and action connectors for a Policy. + + Args: + policy: Policy instance. + config: Algorithm config dict. + """ + ctx: ConnectorContext = ConnectorContext.from_policy(policy) + + assert ( + policy.agent_connectors is None and policy.action_connectors is None + ), "Can not create connectors for a policy that already has connectors." + + policy.agent_connectors = get_agent_connectors_from_config(ctx, config) + policy.action_connectors = get_action_connectors_from_config(ctx, config) + + logger.info("Using connectors:") + logger.info(policy.agent_connectors.__str__(indentation=4)) + logger.info(policy.action_connectors.__str__(indentation=4)) + + +@OldAPIStack +def restore_connectors_for_policy( + policy: "Policy", connector_config: Tuple[str, Tuple[Any]] +) -> Connector: + """Util to create connector for a Policy based on serialized config. + + Args: + policy: Policy instance. + connector_config: Serialized connector config. 
+ """ + ctx: ConnectorContext = ConnectorContext.from_policy(policy) + name, params = connector_config + return get_connector(name, ctx, params) + + +# We need this filter selection mechanism temporarily to remain compatible to old API +@OldAPIStack +def get_synced_filter_connector(ctx: ConnectorContext): + filter_specifier = ctx.config.get("observation_filter") + if filter_specifier == "MeanStdFilter": + return MeanStdObservationFilterAgentConnector(ctx, clip=None) + elif filter_specifier == "ConcurrentMeanStdFilter": + return ConcurrentMeanStdObservationFilterAgentConnector(ctx, clip=None) + elif filter_specifier == "NoFilter": + return None + else: + raise Exception("Unknown observation_filter: " + str(filter_specifier)) + + +@OldAPIStack +def maybe_get_filters_for_syncing(rollout_worker, policy_id): + # As long as the historic filter synchronization mechanism is in + # place, we need to put filters into self.filters so that they get + # synchronized + policy = rollout_worker.policy_map[policy_id] + if not policy.agent_connectors: + return + + filter_connectors = policy.agent_connectors[SyncedFilterAgentConnector] + # There can only be one filter at a time + if not filter_connectors: + return + + assert len(filter_connectors) == 1, ( + "ConnectorPipeline has multiple connectors of type " + "SyncedFilterAgentConnector but can only have one." 
+ ) + rollout_worker.filters[policy_id] = filter_connectors[0].filter diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1265532aa05f4cf40792ccec9b2309465c7fae54 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__init__.py @@ -0,0 +1,8 @@ +from ray.rllib.core.learner.learner import Learner +from ray.rllib.core.learner.learner_group import LearnerGroup + + +__all__ = [ + "Learner", + "LearnerGroup", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bb401ff1cfe3ca6b17544d0a8e2e68e95256027 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/learner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/learner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..119e9d09f37c0f0a833692f919598016f6705481 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/learner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/learner_group.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/learner_group.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce0eea9870e541fa61b9277cb940c759c7127e74 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/learner_group.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0229595246c0858e98ab629f937ed8561f4d7ef4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/learner.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/learner.py new file mode 100644 index 0000000000000000000000000000000000000000..497a71fa2923436dfef59a49c697e8269dc6b75f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/learner.py @@ -0,0 +1,1795 @@ +import abc +from collections import defaultdict +import copy +import logging +import numpy +import platform +from typing import ( + Any, + Callable, + Collection, + Dict, + List, + Hashable, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, + Union, +) + +import tree # pip install dm_tree + +import ray +from ray.data.iterator import DataIterator +from ray.rllib.connectors.learner.learner_connector_pipeline import ( + LearnerConnectorPipeline, +) +from ray.rllib.core import ( + COMPONENT_METRICS_LOGGER, + COMPONENT_OPTIMIZER, + COMPONENT_RL_MODULE, + DEFAULT_MODULE_ID, +) +from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI +from ray.rllib.core.rl_module import validate_module_id +from ray.rllib.core.rl_module.multi_rl_module import ( + MultiRLModule, + MultiRLModuleSpec, +) +from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec +from ray.rllib.policy.policy import PolicySpec +from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic, + OverrideToImplementCustomLogic_CallToSuperRecommended, +) +from ray.rllib.utils.checkpoints import Checkpointable +from 
ray.rllib.utils.debug import update_global_seed_if_necessary +from ray.rllib.utils.deprecation import ( + Deprecated, + DEPRECATED_VALUE, + deprecation_warning, +) +from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.metrics import ( + ALL_MODULES, + NUM_ENV_STEPS_SAMPLED_LIFETIME, + NUM_ENV_STEPS_TRAINED, + NUM_ENV_STEPS_TRAINED_LIFETIME, + NUM_MODULE_STEPS_TRAINED, + NUM_MODULE_STEPS_TRAINED_LIFETIME, + MODULE_TRAIN_BATCH_SIZE_MEAN, + WEIGHTS_SEQ_NO, +) +from ray.rllib.utils.metrics.metrics_logger import MetricsLogger +from ray.rllib.utils.minibatch_utils import ( + MiniBatchDummyIterator, + MiniBatchCyclicIterator, +) +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.schedules.scheduler import Scheduler +from ray.rllib.utils.typing import ( + EpisodeType, + LearningRateOrSchedule, + ModuleID, + Optimizer, + Param, + ParamRef, + ParamDict, + ResultDict, + ShouldModuleBeUpdatedFn, + StateDict, + TensorType, +) +from ray.util.annotations import PublicAPI + +if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + + +torch, _ = try_import_torch() +tf1, tf, tfv = try_import_tf() + +logger = logging.getLogger(__name__) + +DEFAULT_OPTIMIZER = "default_optimizer" + +# COMMON LEARNER LOSS_KEYS +POLICY_LOSS_KEY = "policy_loss" +VF_LOSS_KEY = "vf_loss" +ENTROPY_KEY = "entropy" + +# Additional update keys +LR_KEY = "learning_rate" + + +@PublicAPI(stability="alpha") +class Learner(Checkpointable): + """Base class for Learners. + + This class will be used to train RLModules. It is responsible for defining the loss + function, and updating the neural network weights that it owns. It also provides a + way to add/remove modules to/from RLModules in a multi-agent scenario, in the + middle of training (This is useful for league based training). 
+ + TF and Torch specific implementation of this class fills in the framework-specific + implementation details for distributed training, and for computing and applying + gradients. User should not need to sub-class this class, but instead inherit from + the TF or Torch specific sub-classes to implement their algorithm-specific update + logic. + + Args: + config: The AlgorithmConfig object from which to derive most of the settings + needed to build the Learner. + module_spec: The module specification for the RLModule that is being trained. + If the module is a single agent module, after building the module it will + be converted to a multi-agent module with a default key. Can be none if the + module is provided directly via the `module` argument. Refer to + ray.rllib.core.rl_module.RLModuleSpec + or ray.rllib.core.rl_module.MultiRLModuleSpec for more info. + module: If learner is being used stand-alone, the RLModule can be optionally + passed in directly instead of the through the `module_spec`. + + Note: We use PPO and torch as an example here because many of the showcased + components need implementations to come together. However, the same + pattern is generally applicable. + + .. testcode:: + + import gymnasium as gym + + from ray.rllib.algorithms.ppo.ppo import PPOConfig + from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog + from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( + PPOTorchRLModule + ) + from ray.rllib.core import COMPONENT_RL_MODULE, DEFAULT_MODULE_ID + from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig + from ray.rllib.core.rl_module.rl_module import RLModuleSpec + + env = gym.make("CartPole-v1") + + # Create a PPO config object first. + config = ( + PPOConfig() + .framework("torch") + .training(model={"fcnet_hiddens": [128, 128]}) + ) + + # Create a learner instance directly from our config. 
All we need as + # extra information here is the env to be able to extract space information + # (needed to construct the RLModule inside the Learner). + learner = config.build_learner(env=env) + + # Take one gradient update on the module and report the results. + # results = learner.update(...) + + # Add a new module, perhaps for league based training. + learner.add_module( + module_id="new_player", + module_spec=RLModuleSpec( + module_class=PPOTorchRLModule, + observation_space=env.observation_space, + action_space=env.action_space, + model_config=DefaultModelConfig(fcnet_hiddens=[64, 64]), + catalog_class=PPOCatalog, + ) + ) + + # Take another gradient update with both previous and new modules. + # results = learner.update(...) + + # Remove a module. + learner.remove_module("new_player") + + # Will train previous modules only. + # results = learner.update(...) + + # Get the state of the learner. + state = learner.get_state() + + # Set the state of the learner. + learner.set_state(state) + + # Get the weights of the underlying MultiRLModule. + weights = learner.get_state(components=COMPONENT_RL_MODULE) + + # Set the weights of the underlying MultiRLModule. + learner.set_state({COMPONENT_RL_MODULE: weights}) + + + Extension pattern: + + .. testcode:: + + from ray.rllib.core.learner.torch.torch_learner import TorchLearner + + class MyLearner(TorchLearner): + + def compute_losses(self, fwd_out, batch): + # Compute the losses per module based on `batch` and output of the + # forward pass (`fwd_out`). To access the (algorithm) config for a + # specific RLModule, do: + # `self.config.get_config_for_module([moduleID])`. 
+ return {DEFAULT_MODULE_ID: module_loss} + """ + + framework: str = None + TOTAL_LOSS_KEY: str = "total_loss" + + def __init__( + self, + *, + config: "AlgorithmConfig", + module_spec: Optional[Union[RLModuleSpec, MultiRLModuleSpec]] = None, + module: Optional[RLModule] = None, + ): + # TODO (sven): Figure out how to do this + self.config = config.copy(copy_frozen=False) + self._module_spec: Optional[MultiRLModuleSpec] = module_spec + self._module_obj: Optional[MultiRLModule] = module + + # Make node and device of this Learner available. + self._node = platform.node() + self._device = None + + # Set a seed, if necessary. + if self.config.seed is not None: + update_global_seed_if_necessary(self.framework, self.config.seed) + + # Whether self.build has already been called. + self._is_built = False + + # These are the attributes that are set during build. + + # The actual MultiRLModule used by this Learner. + self._module: Optional[MultiRLModule] = None + self._weights_seq_no = 0 + # Our Learner connector pipeline. + self._learner_connector: Optional[LearnerConnectorPipeline] = None + # These are set for properly applying optimizers and adding or removing modules. + self._optimizer_parameters: Dict[Optimizer, List[ParamRef]] = {} + self._named_optimizers: Dict[str, Optimizer] = {} + self._params: ParamDict = {} + # Dict mapping ModuleID to a list of optimizer names. Note that the optimizer + # name includes the ModuleID as a prefix: optimizer_name=`[ModuleID]_[.. rest]`. + self._module_optimizers: Dict[ModuleID, List[str]] = defaultdict(list) + self._optimizer_name_to_module: Dict[str, ModuleID] = {} + + # Only manage optimizer's learning rate if user has NOT overridden + # the `configure_optimizers_for_module` method. Otherwise, leave responsibility + # to handle lr-updates entirely in user's hands. 
+ self._optimizer_lr_schedules: Dict[Optimizer, Scheduler] = {} + + # The Learner's own MetricsLogger to be used to log RLlib's built-in metrics or + # custom user-defined ones (e.g. custom loss values). When returning from an + # `update_from_...()` method call, the Learner will do a `self.metrics.reduce()` + # and return the resulting (reduced) dict. + self.metrics = MetricsLogger() + + # In case of offline learning and multiple learners, each learner receives a + # repeatable iterator that iterates over a split of the streamed data. + self.iterator: DataIterator = None + + # TODO (sven): Do we really need this API? It seems like LearnerGroup constructs + # all Learner workers and then immediately builds them any ways? Unless there is + # a reason related to Train worker group setup. + @OverrideToImplementCustomLogic_CallToSuperRecommended + def build(self) -> None: + """Builds the Learner. + + This method should be called before the learner is used. It is responsible for + setting up the LearnerConnectorPipeline, the RLModule, optimizer(s), and + (optionally) the optimizers' learning rate schedulers. + """ + if self._is_built: + logger.debug("Learner already built. Skipping build.") + return + + # Build learner connector pipeline used on this Learner worker. + self._learner_connector = None + # If the Algorithm uses aggregation actors to run episodes through the learner + # connector, its Learners don't need a connector pipelines and instead learn + # directly from pre-loaded batches already on the GPU. + if self.config.num_aggregator_actors_per_learner == 0: + # TODO (sven): Figure out which space to provide here. For now, + # it doesn't matter, as the default connector piece doesn't use + # this information anyway. 
+ # module_spec = self._module_spec.as_multi_rl_module_spec() + self._learner_connector = self.config.build_learner_connector( + input_observation_space=None, + input_action_space=None, + device=self._device, + ) + + # Build the module to be trained by this learner. + self._module = self._make_module() + + # Configure, construct, and register all optimizers needed to train + # `self.module`. + self.configure_optimizers() + + # Log the number of trainable/non-trainable parameters. + self._log_trainable_parameters() + + self._is_built = True + + @property + def distributed(self) -> bool: + """Whether the learner is running in distributed mode.""" + return self.config.num_learners > 1 + + @property + def module(self) -> MultiRLModule: + """The MultiRLModule that is being trained.""" + return self._module + + @property + def node(self) -> Any: + return self._node + + @property + def device(self) -> Any: + return self._device + + def register_optimizer( + self, + *, + module_id: ModuleID = ALL_MODULES, + optimizer_name: str = DEFAULT_OPTIMIZER, + optimizer: Optimizer, + params: Sequence[Param], + lr_or_lr_schedule: Optional[LearningRateOrSchedule] = None, + ) -> None: + """Registers an optimizer with a ModuleID, name, param list and lr-scheduler. + + Use this method in your custom implementations of either + `self.configure_optimizers()` or `self.configure_optimzers_for_module()` (you + should only override one of these!). If you register a learning rate Scheduler + setting together with an optimizer, RLlib will automatically keep this + optimizer's learning rate updated throughout the training process. + Alternatively, you can construct your optimizers directly with a learning rate + and manage learning rate scheduling or updating yourself. + + Args: + module_id: The `module_id` under which to register the optimizer. If not + provided, will assume ALL_MODULES. + optimizer_name: The name (str) of the optimizer. If not provided, will + assume DEFAULT_OPTIMIZER. 
+ optimizer: The already instantiated optimizer object to register. + params: A list of parameters (framework-specific variables) that will be + trained/updated + lr_or_lr_schedule: An optional fixed learning rate or learning rate schedule + setup. If provided, RLlib will automatically keep the optimizer's + learning rate updated. + """ + # Validate optimizer instance and its param list. + self._check_registered_optimizer(optimizer, params) + + full_registration_name = module_id + "_" + optimizer_name + + # Store the given optimizer under the given `module_id`. + self._module_optimizers[module_id].append(full_registration_name) + self._optimizer_name_to_module[full_registration_name] = module_id + + # Store the optimizer instance under its full `module_id`_`optimizer_name` + # key. + self._named_optimizers[full_registration_name] = optimizer + + # Store all given parameters under the given optimizer. + self._optimizer_parameters[optimizer] = [] + for param in params: + param_ref = self.get_param_ref(param) + self._optimizer_parameters[optimizer].append(param_ref) + self._params[param_ref] = param + + # Optionally, store a scheduler object along with this optimizer. If such a + # setting is provided, RLlib will handle updating the optimizer's learning rate + # over time. + if lr_or_lr_schedule is not None: + # Validate the given setting. + Scheduler.validate( + fixed_value_or_schedule=lr_or_lr_schedule, + setting_name="lr_or_lr_schedule", + description="learning rate or schedule", + ) + # Create the scheduler object for this optimizer. + scheduler = Scheduler( + fixed_value_or_schedule=lr_or_lr_schedule, + framework=self.framework, + device=self._device, + ) + self._optimizer_lr_schedules[optimizer] = scheduler + # Set the optimizer to the current (first) learning rate. 
+ self._set_optimizer_lr( + optimizer=optimizer, + lr=scheduler.get_current_value(), + ) + + @OverrideToImplementCustomLogic + def configure_optimizers(self) -> None: + """Configures, creates, and registers the optimizers for this Learner. + + Optimizers are responsible for updating the model's parameters during training, + based on the computed gradients. + + Normally, you should not override this method for your custom algorithms + (which require certain optimizers), but rather override the + `self.configure_optimizers_for_module(module_id=..)` method and register those + optimizers in there that you need for the given `module_id`. + + You can register an optimizer for any RLModule within `self.module` (or for + the ALL_MODULES ID) by calling `self.register_optimizer()` and passing the + module_id, optimizer_name (only in case you would like to register more than + one optimizer for a given module), the optimizer instane itself, a list + of all the optimizer's parameters (to be updated by the optimizer), and + an optional learning rate or learning rate schedule setting. + + This method is called once during building (`self.build()`). + """ + # The default implementation simply calls `self.configure_optimizers_for_module` + # on each RLModule within `self.module`. + for module_id in self.module.keys(): + if self.rl_module_is_compatible(self.module[module_id]): + config = self.config.get_config_for_module(module_id) + self.configure_optimizers_for_module(module_id=module_id, config=config) + + @OverrideToImplementCustomLogic + @abc.abstractmethod + def configure_optimizers_for_module( + self, module_id: ModuleID, config: "AlgorithmConfig" = None + ) -> None: + """Configures an optimizer for the given module_id. + + This method is called for each RLModule in the MultiRLModule being + trained by the Learner, as well as any new module added during training via + `self.add_module()`. 
It should configure and construct one or more optimizers + and register them via calls to `self.register_optimizer()` along with the + `module_id`, an optional optimizer name (str), a list of the optimizer's + framework specific parameters (variables), and an optional learning rate value + or -schedule. + + Args: + module_id: The module_id of the RLModule that is being configured. + config: The AlgorithmConfig specific to the given `module_id`. + """ + + @OverrideToImplementCustomLogic + @abc.abstractmethod + def compute_gradients( + self, loss_per_module: Dict[ModuleID, TensorType], **kwargs + ) -> ParamDict: + """Computes the gradients based on the given losses. + + Args: + loss_per_module: Dict mapping module IDs to their individual total loss + terms, computed by the individual `compute_loss_for_module()` calls. + The overall total loss (sum of loss terms over all modules) is stored + under `loss_per_module[ALL_MODULES]`. + **kwargs: Forward compatibility kwargs. + + Returns: + The gradients in the same (flat) format as self._params. Note that all + top-level structures, such as module IDs, will not be present anymore in + the returned dict. It will merely map parameter tensor references to their + respective gradient tensors. + """ + + @OverrideToImplementCustomLogic + def postprocess_gradients(self, gradients_dict: ParamDict) -> ParamDict: + """Applies potential postprocessing operations on the gradients. + + This method is called after gradients have been computed and modifies them + before they are applied to the respective module(s) by the optimizer(s). + This might include grad clipping by value, norm, or global-norm, or other + algorithm specific gradient postprocessing steps. + + This default implementation calls `self.postprocess_gradients_for_module()` + on each of the sub-modules in our MultiRLModule: `self.module` and + returns the accumulated gradients dicts. 
+ + Args: + gradients_dict: A dictionary of gradients in the same (flat) format as + self._params. Note that top-level structures, such as module IDs, + will not be present anymore in this dict. It will merely map gradient + tensor references to gradient tensors. + + Returns: + A dictionary with the updated gradients and the exact same (flat) structure + as the incoming `gradients_dict` arg. + """ + + # The flat gradients dict (mapping param refs to params), returned by this + # method. + postprocessed_gradients = {} + + for module_id in self.module.keys(): + # Send a gradients dict for only this `module_id` to the + # `self.postprocess_gradients_for_module()` method. + module_grads_dict = {} + for optimizer_name, optimizer in self.get_optimizers_for_module(module_id): + module_grads_dict.update( + self.filter_param_dict_for_optimizer(gradients_dict, optimizer) + ) + + module_grads_dict = self.postprocess_gradients_for_module( + module_id=module_id, + config=self.config.get_config_for_module(module_id), + module_gradients_dict=module_grads_dict, + ) + assert isinstance(module_grads_dict, dict) + + # Update our return dict. + postprocessed_gradients.update(module_grads_dict) + + return postprocessed_gradients + + @OverrideToImplementCustomLogic_CallToSuperRecommended + def postprocess_gradients_for_module( + self, + *, + module_id: ModuleID, + config: Optional["AlgorithmConfig"] = None, + module_gradients_dict: ParamDict, + ) -> ParamDict: + """Applies postprocessing operations on the gradients of the given module. + + Args: + module_id: The module ID for which we will postprocess computed gradients. + Note that `module_gradients_dict` already only carries those gradient + tensors that belong to this `module_id`. Other `module_id`'s gradients + are not available in this call. + config: The AlgorithmConfig specific to the given `module_id`. 
+ module_gradients_dict: A dictionary of gradients in the same (flat) format + as self._params, mapping gradient refs to gradient tensors, which are to + be postprocessed. You may alter these tensors in place or create new + ones and return these in a new dict. + + Returns: + A dictionary with the updated gradients and the exact same (flat) structure + as the incoming `module_gradients_dict` arg. + """ + postprocessed_grads = {} + + if config.grad_clip is None and not config.log_gradients: + postprocessed_grads.update(module_gradients_dict) + return postprocessed_grads + + for optimizer_name, optimizer in self.get_optimizers_for_module(module_id): + grad_dict_to_clip = self.filter_param_dict_for_optimizer( + param_dict=module_gradients_dict, + optimizer=optimizer, + ) + if config.grad_clip: + # Perform gradient clipping, if configured. + global_norm = self._get_clip_function()( + grad_dict_to_clip, + grad_clip=config.grad_clip, + grad_clip_by=config.grad_clip_by, + ) + if config.grad_clip_by == "global_norm" or config.log_gradients: + # If we want to log gradients, but do not use the global norm + # for clipping compute it here. + if config.log_gradients and config.grad_clip_by != "global_norm": + # Compute the global norm of gradients. + global_norm = self._get_global_norm_function()( + # Note, `tf.linalg.global_norm` needs a list of tensors. + list(grad_dict_to_clip.values()), + ) + self.metrics.log_value( + key=(module_id, f"gradients_{optimizer_name}_global_norm"), + value=global_norm, + window=1, + ) + postprocessed_grads.update(grad_dict_to_clip) + # In the other case check, if we want to log gradients only. + elif config.log_gradients: + # Compute the global norm of gradients and log it. + global_norm = self._get_global_norm_function()( + # Note, `tf.linalg.global_norm` needs a list of tensors. 
+ list(grad_dict_to_clip.values()), + ) + self.metrics.log_value( + key=(module_id, f"gradients_{optimizer_name}_global_norm"), + value=global_norm, + window=1, + ) + + return postprocessed_grads + + @OverrideToImplementCustomLogic + @abc.abstractmethod + def apply_gradients(self, gradients_dict: ParamDict) -> None: + """Applies the gradients to the MultiRLModule parameters. + + Args: + gradients_dict: A dictionary of gradients in the same (flat) format as + self._params. Note that top-level structures, such as module IDs, + will not be present anymore in this dict. It will merely map gradient + tensor references to gradient tensors. + """ + + def get_optimizer( + self, + module_id: ModuleID = DEFAULT_MODULE_ID, + optimizer_name: str = DEFAULT_OPTIMIZER, + ) -> Optimizer: + """Returns the optimizer object, configured under the given module_id and name. + + If only one optimizer was registered under `module_id` (or ALL_MODULES) + via the `self.register_optimizer` method, `optimizer_name` is assumed to be + DEFAULT_OPTIMIZER. + + Args: + module_id: The ModuleID for which to return the configured optimizer. + If not provided, will assume DEFAULT_MODULE_ID. + optimizer_name: The name of the optimizer (registered under `module_id` via + `self.register_optimizer()`) to return. If not provided, will assume + DEFAULT_OPTIMIZER. + + Returns: + The optimizer object, configured under the given `module_id` and + `optimizer_name`. + """ + # `optimizer_name` could possibly be the full optimizer name (including the + # module_id under which it is registered). + if optimizer_name in self._named_optimizers: + return self._named_optimizers[optimizer_name] + + # Normally, `optimizer_name` is just the optimizer's name, not including the + # `module_id`. + full_registration_name = module_id + "_" + optimizer_name + if full_registration_name in self._named_optimizers: + return self._named_optimizers[full_registration_name] + + # No optimizer found. 
+ raise KeyError( + f"Optimizer not found! module_id={module_id} " + f"optimizer_name={optimizer_name}" + ) + + def get_optimizers_for_module( + self, module_id: ModuleID = ALL_MODULES + ) -> List[Tuple[str, Optimizer]]: + """Returns a list of (optimizer_name, optimizer instance)-tuples for module_id. + + Args: + module_id: The ModuleID for which to return the configured + (optimizer name, optimizer)-pairs. If not provided, will return + optimizers registered under ALL_MODULES. + + Returns: + A list of tuples of the format: ([optimizer_name], [optimizer object]), + where optimizer_name is the name under which the optimizer was registered + in `self.register_optimizer`. If only a single optimizer was + configured for `module_id`, [optimizer_name] will be DEFAULT_OPTIMIZER. + """ + named_optimizers = [] + for full_registration_name in self._module_optimizers[module_id]: + optimizer = self._named_optimizers[full_registration_name] + # TODO (sven): How can we avoid registering optimziers under this + # constructed `[module_id]_[optim_name]` format? + optim_name = full_registration_name[len(module_id) + 1 :] + named_optimizers.append((optim_name, optimizer)) + return named_optimizers + + def filter_param_dict_for_optimizer( + self, param_dict: ParamDict, optimizer: Optimizer + ) -> ParamDict: + """Reduces the given ParamDict to contain only parameters for given optimizer. + + Args: + param_dict: The ParamDict to reduce/filter down to the given `optimizer`. + The returned dict will be a subset of `param_dict` only containing keys + (param refs) that were registered together with `optimizer` (and thus + that `optimizer` is responsible for applying gradients to). + optimizer: The optimizer object to whose parameter refs the given + `param_dict` should be reduced. + + Returns: + A new ParamDict only containing param ref keys that belong to `optimizer`. + """ + # Return a sub-dict only containing those param_ref keys (and their values) + # that belong to the `optimizer`. 
+ return { + ref: param_dict[ref] + for ref in self._optimizer_parameters[optimizer] + if ref in param_dict and param_dict[ref] is not None + } + + @abc.abstractmethod + def get_param_ref(self, param: Param) -> Hashable: + """Returns a hashable reference to a trainable parameter. + + This should be overridden in framework specific specialization. For example in + torch it will return the parameter itself, while in tf it returns the .ref() of + the variable. The purpose is to retrieve a unique reference to the parameters. + + Args: + param: The parameter to get the reference to. + + Returns: + A reference to the parameter. + """ + + @abc.abstractmethod + def get_parameters(self, module: RLModule) -> Sequence[Param]: + """Returns the list of parameters of a module. + + This should be overridden in framework specific learner. For example in torch it + will return .parameters(), while in tf it returns .trainable_variables. + + Args: + module: The module to get the parameters from. + + Returns: + The parameters of the module. + """ + + @abc.abstractmethod + def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch: + """Converts the elements of a MultiAgentBatch to Tensors on the correct device. + + Args: + batch: The MultiAgentBatch object to convert. + + Returns: + The resulting MultiAgentBatch with framework-specific tensor values placed + on the correct device. + """ + + @OverrideToImplementCustomLogic_CallToSuperRecommended + def add_module( + self, + *, + module_id: ModuleID, + module_spec: RLModuleSpec, + config_overrides: Optional[Dict] = None, + new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None, + ) -> MultiRLModuleSpec: + """Adds a module to the underlying MultiRLModule. + + Changes this Learner's config in order to make this architectural change + permanent wrt. to checkpointing. + + Args: + module_id: The ModuleID of the module to be added. + module_spec: The ModuleSpec of the module to be added. 
+ config_overrides: The `AlgorithmConfig` overrides that should apply to + the new Module, if any. + new_should_module_be_updated: An optional sequence of ModuleIDs or a + callable taking ModuleID and SampleBatchType and returning whether the + ModuleID should be updated (trained). + If None, will keep the existing setup in place. RLModules, + whose IDs are not in the list (or for which the callable + returns False) will not be updated. + + Returns: + The new MultiRLModuleSpec (after the RLModule has been added). + """ + validate_module_id(module_id, error=True) + self._check_is_built() + + # Force-set inference-only = False. + module_spec = copy.deepcopy(module_spec) + module_spec.inference_only = False + + # Build the new RLModule and add it to self.module. + module = module_spec.build() + self.module.add_module(module_id, module) + + # Change our config (AlgorithmConfig) to contain the new Module. + # TODO (sven): This is a hack to manipulate the AlgorithmConfig directly, + # but we'll deprecate config.policies soon anyway. + self.config.policies[module_id] = PolicySpec() + if config_overrides is not None: + self.config.multi_agent( + algorithm_config_overrides_per_module={module_id: config_overrides} + ) + self.config.rl_module(rl_module_spec=MultiRLModuleSpec.from_module(self.module)) + self._module_spec = self.config.rl_module_spec + if new_should_module_be_updated is not None: + self.config.multi_agent(policies_to_train=new_should_module_be_updated) + + # Allow the user to configure one or more optimizers for this new module. + self.configure_optimizers_for_module( + module_id=module_id, + config=self.config.get_config_for_module(module_id), + ) + return self.config.rl_module_spec + + @OverrideToImplementCustomLogic_CallToSuperRecommended + def remove_module( + self, + module_id: ModuleID, + *, + new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None, + ) -> MultiRLModuleSpec: + """Removes a module from the Learner. 
+ + Args: + module_id: The ModuleID of the module to be removed. + new_should_module_be_updated: An optional sequence of ModuleIDs or a + callable taking ModuleID and SampleBatchType and returning whether the + ModuleID should be updated (trained). + If None, will keep the existing setup in place. RLModules, + whose IDs are not in the list (or for which the callable + returns False) will not be updated. + + Returns: + The new MultiRLModuleSpec (after the RLModule has been removed). + """ + self._check_is_built() + module = self.module[module_id] + + # Delete the removed module's parameters and optimizers. + if self.rl_module_is_compatible(module): + parameters = self.get_parameters(module) + for param in parameters: + param_ref = self.get_param_ref(param) + if param_ref in self._params: + del self._params[param_ref] + for optimizer_name, optimizer in self.get_optimizers_for_module(module_id): + del self._optimizer_parameters[optimizer] + name = module_id + "_" + optimizer_name + del self._named_optimizers[name] + if optimizer in self._optimizer_lr_schedules: + del self._optimizer_lr_schedules[optimizer] + del self._module_optimizers[module_id] + + # Remove the module from the MultiRLModule. + self.module.remove_module(module_id) + + # Change self.config to reflect the new architecture. + # TODO (sven): This is a hack to manipulate the AlgorithmConfig directly, + # but we'll deprecate config.policies soon anyway. + del self.config.policies[module_id] + self.config.algorithm_config_overrides_per_module.pop(module_id, None) + if new_should_module_be_updated is not None: + self.config.multi_agent(policies_to_train=new_should_module_be_updated) + self.config.rl_module(rl_module_spec=MultiRLModuleSpec.from_module(self.module)) + + # Remove all stats from the module from our metrics logger, so we don't report + # results from this module again. 
+ if module_id in self.metrics.stats: + del self.metrics.stats[module_id] + + return self.config.rl_module_spec + + @OverrideToImplementCustomLogic + def should_module_be_updated(self, module_id, multi_agent_batch=None): + """Returns whether a module should be updated or not based on `self.config`. + + Args: + module_id: The ModuleID that we want to query on whether this module + should be updated or not. + multi_agent_batch: An optional MultiAgentBatch to possibly provide further + information on the decision on whether the RLModule should be updated + or not. + """ + should_module_be_updated_fn = self.config.policies_to_train + # If None, return True (by default, all modules should be updated). + if should_module_be_updated_fn is None: + return True + # If collection given, return whether `module_id` is in that container. + elif not callable(should_module_be_updated_fn): + return module_id in set(should_module_be_updated_fn) + + return should_module_be_updated_fn(module_id, multi_agent_batch) + + @OverrideToImplementCustomLogic + def compute_losses( + self, *, fwd_out: Dict[str, Any], batch: Dict[str, Any] + ) -> Dict[str, Any]: + """Computes the loss(es) for the module being optimized. + + This method must be overridden by MultiRLModule-specific Learners in order to + define the specific loss computation logic. If the algorithm is single-agent, + only `compute_loss_for_module()` should be overridden instead. If the algorithm + uses independent multi-agent learning (default behavior for RLlib's multi-agent + setups), also only `compute_loss_for_module()` should be overridden, but it will + be called for each individual RLModule inside the MultiRLModule. + It is recommended to not compute any forward passes within this method, and to + use the `forward_train()` outputs of the RLModule(s) to compute the required + loss tensors. 
+ See here for a custom loss function example script: + https://github.com/ray-project/ray/blob/master/rllib/examples/learners/custom_loss_fn_simple.py # noqa + + Args: + fwd_out: Output from a call to the `forward_train()` method of the + underlying MultiRLModule (`self.module`) during training + (`self.update()`). + batch: The train batch that was used to compute `fwd_out`. + + Returns: + A dictionary mapping module IDs to individual loss terms. + """ + loss_per_module = {} + for module_id in fwd_out: + module_batch = batch[module_id] + module_fwd_out = fwd_out[module_id] + + module = self.module[module_id].unwrapped() + if isinstance(module, SelfSupervisedLossAPI): + loss = module.compute_self_supervised_loss( + learner=self, + module_id=module_id, + config=self.config.get_config_for_module(module_id), + batch=module_batch, + fwd_out=module_fwd_out, + ) + else: + loss = self.compute_loss_for_module( + module_id=module_id, + config=self.config.get_config_for_module(module_id), + batch=module_batch, + fwd_out=module_fwd_out, + ) + loss_per_module[module_id] = loss + + return loss_per_module + + @OverrideToImplementCustomLogic + @abc.abstractmethod + def compute_loss_for_module( + self, + *, + module_id: ModuleID, + config: "AlgorithmConfig", + batch: Dict[str, Any], + fwd_out: Dict[str, TensorType], + ) -> TensorType: + """Computes the loss for a single module. + + Think of this as computing loss for a single agent. For multi-agent use-cases + that require more complicated computation for loss, consider overriding the + `compute_losses` method instead. + + Args: + module_id: The id of the module. + config: The AlgorithmConfig specific to the given `module_id`. + batch: The train batch for this particular module. + fwd_out: The output of the forward pass for this particular module. + + Returns: + A single total loss tensor. 
If you have more than one optimizer on the + provided `module_id` and would like to compute gradients separately using + these different optimizers, simply add up the individual loss terms for + each optimizer and return the sum. Also, for recording/logging any + individual loss terms, you can use the `Learner.metrics.log_value( + key=..., value=...)` or `Learner.metrics.log_dict()` APIs. See: + :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` for more + information. + """ + + def update_from_batch( + self, + batch: MultiAgentBatch, + *, + # TODO (sven): Make this a more formal structure with its own type. + timesteps: Optional[Dict[str, Any]] = None, + num_epochs: int = 1, + minibatch_size: Optional[int] = None, + shuffle_batch_per_epoch: bool = False, + # Deprecated args. + num_iters=DEPRECATED_VALUE, + **kwargs, + ) -> ResultDict: + """Run `num_epochs` epochs over the given train batch. + + You can use this method to take more than one backward pass on the batch. + The same `minibatch_size` and `num_epochs` will be used for all module ids in + MultiRLModule. + + Args: + batch: A batch of training data to update from. + timesteps: Timesteps dict, which must have the key + `NUM_ENV_STEPS_SAMPLED_LIFETIME`. + # TODO (sven): Make this a more formal structure with its own type. + num_epochs: The number of complete passes over the entire train batch. Each + pass might be further split into n minibatches (if `minibatch_size` + provided). + minibatch_size: The size of minibatches to use to further split the train + `batch` into sub-batches. The `batch` is then iterated over n times + where n is `len(batch) // minibatch_size`. + shuffle_batch_per_epoch: Whether to shuffle the train batch once per epoch. + If the train batch has a time rank (axis=1), shuffling will only take + place along the batch axis to not disturb any intact (episode) + trajectories. 
Also, shuffling is always skipped if `minibatch_size` is + None, meaning the entire train batch is processed each epoch, making it + unnecessary to shuffle. + + Returns: + A `ResultDict` object produced by a call to `self.metrics.reduce()`. The + returned dict may be arbitrarily nested and must have `Stats` objects at + all its leafs, allowing components further downstream (i.e. a user of this + Learner) to further reduce these results (for example over n parallel + Learners). + """ + if num_iters != DEPRECATED_VALUE: + deprecation_warning( + old="Learner.update_from_episodes(num_iters=...)", + new="Learner.update_from_episodes(num_epochs=...)", + error=True, + ) + self._update_from_batch_or_episodes( + batch=batch, + timesteps=timesteps, + num_epochs=num_epochs, + minibatch_size=minibatch_size, + shuffle_batch_per_epoch=shuffle_batch_per_epoch, + ) + return self.metrics.reduce() + + def update_from_episodes( + self, + episodes: List[EpisodeType], + *, + # TODO (sven): Make this a more formal structure with its own type. + timesteps: Optional[Dict[str, Any]] = None, + num_epochs: int = 1, + minibatch_size: Optional[int] = None, + shuffle_batch_per_epoch: bool = False, + num_total_minibatches: int = 0, + # Deprecated args. + num_iters=DEPRECATED_VALUE, + ) -> ResultDict: + """Run `num_epochs` epochs over the train batch generated from `episodes`. + + You can use this method to take more than one backward pass on the batch. + The same `minibatch_size` and `num_epochs` will be used for all module ids in + MultiRLModule. + + Args: + episodes: An list of episode objects to update from. + timesteps: Timesteps dict, which must have the key + `NUM_ENV_STEPS_SAMPLED_LIFETIME`. + # TODO (sven): Make this a more formal structure with its own type. + num_epochs: The number of complete passes over the entire train batch. Each + pass might be further split into n minibatches (if `minibatch_size` + provided). 
                The train batch is generated from the given `episodes`
                through the Learner connector pipeline.
            minibatch_size: The size of minibatches to use to further split the train
                `batch` into sub-batches. The `batch` is then iterated over n times
                where n is `len(batch) // minibatch_size`. The train batch is generated
                from the given `episodes` through the Learner connector pipeline.
            shuffle_batch_per_epoch: Whether to shuffle the train batch once per epoch.
                If the train batch has a time rank (axis=1), shuffling will only take
                place along the batch axis to not disturb any intact (episode)
                trajectories. Also, shuffling is always skipped if `minibatch_size` is
                None, meaning the entire train batch is processed each epoch, making it
                unnecessary to shuffle. The train batch is generated from the given
                `episodes` through the Learner connector pipeline.
            num_total_minibatches: The total number of minibatches to loop through
                (over all `num_epochs` epochs). It's only required to set this to != 0
                in multi-agent + multi-GPU situations, in which the MultiAgentEpisodes
                themselves are roughly sharded equally, however, they might contain
                SingleAgentEpisodes with very lopsided length distributions. Thus,
                without this fixed, pre-computed value, one Learner might go through a
                different number of minibatche passes than others causing a deadlock.

        Returns:
            A `ResultDict` object produced by a call to `self.metrics.reduce()`. The
            returned dict may be arbitrarily nested and must have `Stats` objects at
            all its leafs, allowing components further downstream (i.e. a user of this
            Learner) to further reduce these results (for example over n parallel
            Learners).
        """
        if num_iters != DEPRECATED_VALUE:
            deprecation_warning(
                old="Learner.update_from_episodes(num_iters=...)",
                new="Learner.update_from_episodes(num_epochs=...)",
                error=True,
            )
        self._update_from_batch_or_episodes(
            episodes=episodes,
            timesteps=timesteps,
            num_epochs=num_epochs,
            minibatch_size=minibatch_size,
            shuffle_batch_per_epoch=shuffle_batch_per_epoch,
            num_total_minibatches=num_total_minibatches,
        )
        return self.metrics.reduce()

    def update_from_iterator(
        self,
        iterator,
        *,
        timesteps: Optional[Dict[str, Any]] = None,
        minibatch_size: Optional[int] = None,
        num_iters: int = None,
        **kwargs,
    ):
        # Unlike `update_from_batch/episodes`, this runs up to `num_iters`
        # minibatches pulled from a (streaming) data iterator.
        if "num_epochs" in kwargs:
            raise ValueError(
                "`num_epochs` arg NOT supported by Learner.update_from_iterator! Use "
                "`num_iters` instead."
            )

        # Keep the very first iterator passed in; later calls reuse it.
        # NOTE(review): a newly passed `iterator` is silently ignored once
        # `self.iterator` is set — confirm this is intended.
        if not self.iterator:
            self.iterator = iterator

        self._check_is_built()

        # Call `before_gradient_based_update` to allow for non-gradient based
        # preparations-, logging-, and update logic to happen.
        self.before_gradient_based_update(timesteps=timesteps or {})

        def _finalize_fn(batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]:
            # Note, the incoming batch is a dictionary with a numpy array
            # holding the `MultiAgentBatch`.
            batch = self._convert_batch_type(batch["batch"][0])
            return {"batch": self._set_slicing_by_batch_id(batch, value=True)}

        i = 0
        logger.debug(f"===> [Learner {id(self)}]: Looping through batches ... ")
        for batch in self.iterator.iter_batches(
            # Note, this needs to be one b/c data is already mapped to
            # `MultiAgentBatch`es of `minibatch_size`.
            batch_size=1,
            _finalize_fn=_finalize_fn,
            **kwargs,
        ):
            # Update the iteration counter.
            i += 1

            # Note, `_finalize_fn` must return a dictionary.
            batch = batch["batch"]
            logger.debug(
                f"===> [Learner {id(self)}]: batch {i} with {batch.env_steps()} rows."
            )
            # Check the MultiAgentBatch, whether our RLModule contains all ModuleIDs
            # found in this batch.
If not, throw an error. + unknown_module_ids = set(batch.policy_batches.keys()) - set( + self.module.keys() + ) + if len(unknown_module_ids) > 0: + raise ValueError( + "Batch contains one or more ModuleIDs that are not in this " + f"Learner! Found IDs: {unknown_module_ids}" + ) + + # Log metrics. + self._log_steps_trained_metrics(batch) + + # Make the actual in-graph/traced `_update` call. This should return + # all tensor values (no numpy). + fwd_out, loss_per_module, tensor_metrics = self._update( + batch.policy_batches + ) + # Convert logged tensor metrics (logged during tensor-mode of MetricsLogger) + # to actual (numpy) values. + self.metrics.tensors_to_numpy(tensor_metrics) + + self._set_slicing_by_batch_id(batch, value=False) + # If `num_iters` is reached break and return. + if num_iters and i == num_iters: + break + + logger.debug( + f"===> [Learner {id(self)}] number of iterations run in this epoch: {i}" + ) + + # Log all individual RLModules' loss terms and its registered optimizers' + # current learning rates. + for mid, loss in convert_to_numpy(loss_per_module).items(): + self.metrics.log_value( + key=(mid, self.TOTAL_LOSS_KEY), + value=loss, + window=1, + ) + # Call `after_gradient_based_update` to allow for non-gradient based + # cleanups-, logging-, and update logic to happen. + # TODO (simon): Check, if this should stay here, when running multiple + # gradient steps inside the iterator loop above (could be a complete epoch) + # the target networks might need to be updated earlier. + self.after_gradient_based_update(timesteps=timesteps or {}) + + # Reduce results across all minibatch update steps. + return self.metrics.reduce() + + @OverrideToImplementCustomLogic + @abc.abstractmethod + def _update( + self, + batch: Dict[str, Any], + **kwargs, + ) -> Tuple[Any, Any, Any]: + """Contains all logic for an in-graph/traceable update step. + + Framework specific subclasses must implement this method. 
This should include + calls to the RLModule's `forward_train`, `compute_loss`, compute_gradients`, + `postprocess_gradients`, and `apply_gradients` methods and return a tuple + with all the individual results. + + Args: + batch: The train batch already converted to a Dict mapping str to (possibly + nested) tensors. + kwargs: Forward compatibility kwargs. + + Returns: + A tuple consisting of: + 1) The `forward_train()` output of the RLModule, + 2) the loss_per_module dictionary mapping module IDs to individual loss + tensors + 3) a metrics dict mapping module IDs to metrics key/value pairs. + + """ + + @override(Checkpointable) + def get_state( + self, + components: Optional[Union[str, Collection[str]]] = None, + *, + not_components: Optional[Union[str, Collection[str]]] = None, + **kwargs, + ) -> StateDict: + self._check_is_built() + + state = { + "should_module_be_updated": self.config.policies_to_train, + } + + if self._check_component(COMPONENT_RL_MODULE, components, not_components): + state[COMPONENT_RL_MODULE] = self.module.get_state( + components=self._get_subcomponents(COMPONENT_RL_MODULE, components), + not_components=self._get_subcomponents( + COMPONENT_RL_MODULE, not_components + ), + **kwargs, + ) + state[WEIGHTS_SEQ_NO] = self._weights_seq_no + if self._check_component(COMPONENT_OPTIMIZER, components, not_components): + state[COMPONENT_OPTIMIZER] = self._get_optimizer_state() + + if self._check_component(COMPONENT_METRICS_LOGGER, components, not_components): + # TODO (sven): Make `MetricsLogger` a Checkpointable. 
+ state[COMPONENT_METRICS_LOGGER] = self.metrics.get_state() + + return state + + @override(Checkpointable) + def set_state(self, state: StateDict) -> None: + self._check_is_built() + + weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0) + + if COMPONENT_RL_MODULE in state: + if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no: + self.module.set_state(state[COMPONENT_RL_MODULE]) + + if COMPONENT_OPTIMIZER in state: + self._set_optimizer_state(state[COMPONENT_OPTIMIZER]) + + # Update our weights_seq_no, if the new one is > 0. + if weights_seq_no > 0: + self._weights_seq_no = weights_seq_no + + # Update our trainable Modules information/function via our config. + # If not provided in state (None), all Modules will be trained by default. + if "should_module_be_updated" in state: + self.config.multi_agent(policies_to_train=state["should_module_be_updated"]) + + # TODO (sven): Make `MetricsLogger` a Checkpointable. + if COMPONENT_METRICS_LOGGER in state: + self.metrics.set_state(state[COMPONENT_METRICS_LOGGER]) + + @override(Checkpointable) + def get_ctor_args_and_kwargs(self): + return ( + (), # *args, + { + "config": self.config, + "module_spec": self._module_spec, + "module": self._module_obj, + }, # **kwargs + ) + + @override(Checkpointable) + def get_checkpointable_components(self): + if not self._check_is_built(error=False): + self.build() + return [ + (COMPONENT_RL_MODULE, self.module), + ] + + def _get_optimizer_state(self) -> StateDict: + """Returns the state of all optimizers currently registered in this Learner. + + Returns: + The current state of all optimizers currently registered in this Learner. + """ + raise NotImplementedError + + def _set_optimizer_state(self, state: StateDict) -> None: + """Sets the state of all optimizers currently registered in this Learner. + + Args: + state: The state of the optimizers. 
+ """ + raise NotImplementedError + + def _update_from_batch_or_episodes( + self, + *, + # TODO (sven): We should allow passing in a single agent batch here + # as well for simplicity. + batch: Optional[MultiAgentBatch] = None, + episodes: Optional[List[EpisodeType]] = None, + # TODO (sven): Make this a more formal structure with its own type. + timesteps: Optional[Dict[str, Any]] = None, + # TODO (sven): Deprecate these in favor of config attributes for only those + # algos that actually need (and know how) to do minibatching. + num_epochs: int = 1, + minibatch_size: Optional[int] = None, + shuffle_batch_per_epoch: bool = False, + num_total_minibatches: int = 0, + ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + + self._check_is_built() + + # Call `before_gradient_based_update` to allow for non-gradient based + # preparations-, logging-, and update logic to happen. + self.before_gradient_based_update(timesteps=timesteps or {}) + + # Resolve batch/episodes being ray object refs (instead of + # actual batch/episodes objects). + if isinstance(batch, ray.ObjectRef): + batch = ray.get(batch) + if isinstance(episodes, ray.ObjectRef): + episodes = ray.get(episodes) + elif isinstance(episodes, list) and isinstance(episodes[0], ray.ObjectRef): + # It's possible that individual refs are invalid due to the EnvRunner + # that produced the ref has crashed or had its entire node go down. + # In this case, try each ref individually and collect only valid results. + try: + episodes = tree.flatten(ray.get(episodes)) + except ray.exceptions.OwnerDiedError: + episode_refs = episodes + episodes = [] + for ref in episode_refs: + try: + episodes.extend(ray.get(ref)) + except ray.exceptions.OwnerDiedError: + pass + + # Call the learner connector on the given `episodes` (if we have one). + if episodes is not None and self._learner_connector is not None: + # Call the learner connector pipeline. 
+ shared_data = {} + batch = self._learner_connector( + rl_module=self.module, + batch=batch if batch is not None else {}, + episodes=episodes, + shared_data=shared_data, + metrics=self.metrics, + ) + # Convert to a batch. + # TODO (sven): Try to not require MultiAgentBatch anymore. + batch = MultiAgentBatch( + { + module_id: ( + SampleBatch(module_data, _zero_padded=True) + if shared_data.get(f"_zero_padded_for_mid={module_id}") + else SampleBatch(module_data) + ) + for module_id, module_data in batch.items() + }, + env_steps=sum(len(e) for e in episodes), + ) + # Single-agent SampleBatch: Have to convert to MultiAgentBatch. + elif isinstance(batch, SampleBatch): + assert len(self.module) == 1 + batch = MultiAgentBatch( + {next(iter(self.module.keys())): batch}, env_steps=len(batch) + ) + + # Check the MultiAgentBatch, whether our RLModule contains all ModuleIDs + # found in this batch. If not, throw an error. + unknown_module_ids = set(batch.policy_batches.keys()) - set(self.module.keys()) + if len(unknown_module_ids) > 0: + raise ValueError( + "Batch contains one or more ModuleIDs that are not in this Learner! " + f"Found IDs: {unknown_module_ids}" + ) + + # TODO: Move this into LearnerConnector pipeline? + # Filter out those RLModules from the final train batch that should not be + # updated. + for module_id in list(batch.policy_batches.keys()): + if not self.should_module_be_updated(module_id, batch): + del batch.policy_batches[module_id] + + # Log all timesteps (env, agent, modules) based on given episodes/batch. + self._log_steps_trained_metrics(batch) + + if minibatch_size: + batch_iter = MiniBatchCyclicIterator + elif num_epochs > 1: + # `minibatch_size` was not set but `num_epochs` > 1. + # Under the old training stack, users could do multiple epochs + # over a batch without specifying a minibatch size. We enable + # this behavior here by setting the minibatch size to be the size + # of the batch (e.g. 
1 minibatch of size batch.count) + minibatch_size = batch.count + # Note that there is no need to shuffle here, b/c we don't have minibatches. + batch_iter = MiniBatchCyclicIterator + else: + # `minibatch_size` and `num_epochs` are not set by the user. + batch_iter = MiniBatchDummyIterator + + batch = self._set_slicing_by_batch_id(batch, value=True) + + for tensor_minibatch in batch_iter( + batch, + num_epochs=num_epochs, + minibatch_size=minibatch_size, + shuffle_batch_per_epoch=shuffle_batch_per_epoch and (num_epochs > 1), + num_total_minibatches=num_total_minibatches, + ): + # Make the actual in-graph/traced `_update` call. This should return + # all tensor values (no numpy). + fwd_out, loss_per_module, tensor_metrics = self._update( + tensor_minibatch.policy_batches + ) + + # Convert logged tensor metrics (logged during tensor-mode of MetricsLogger) + # to actual (numpy) values. + self.metrics.tensors_to_numpy(tensor_metrics) + + # Log all individual RLModules' loss terms and its registered optimizers' + # current learning rates. + for mid, loss in convert_to_numpy(loss_per_module).items(): + self.metrics.log_value( + key=(mid, self.TOTAL_LOSS_KEY), + value=loss, + window=1, + ) + + self._weights_seq_no += 1 + self.metrics.log_dict( + { + (mid, WEIGHTS_SEQ_NO): self._weights_seq_no + for mid in batch.policy_batches.keys() + }, + window=1, + ) + + self._set_slicing_by_batch_id(batch, value=False) + + # Call `after_gradient_based_update` to allow for non-gradient based + # cleanups-, logging-, and update logic to happen. + self.after_gradient_based_update(timesteps=timesteps or {}) + + @OverrideToImplementCustomLogic_CallToSuperRecommended + def before_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None: + """Called before gradient-based updates are completed. + + Should be overridden to implement custom preparation-, logging-, or + non-gradient-based Learner/RLModule update logic before(!) gradient-based + updates are performed. 
+ + Args: + timesteps: Timesteps dict, which must have the key + `NUM_ENV_STEPS_SAMPLED_LIFETIME`. + # TODO (sven): Make this a more formal structure with its own type. + """ + + @OverrideToImplementCustomLogic_CallToSuperRecommended + def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None: + """Called after gradient-based updates are completed. + + Should be overridden to implement custom cleanup-, logging-, or non-gradient- + based Learner/RLModule update logic after(!) gradient-based updates have been + completed. + + Args: + timesteps: Timesteps dict, which must have the key + `NUM_ENV_STEPS_SAMPLED_LIFETIME`. + # TODO (sven): Make this a more formal structure with its own type. + """ + # Only update this optimizer's lr, if a scheduler has been registered + # along with it. + for module_id, optimizer_names in self._module_optimizers.items(): + for optimizer_name in optimizer_names: + optimizer = self._named_optimizers[optimizer_name] + # Update and log learning rate of this optimizer. + lr_schedule = self._optimizer_lr_schedules.get(optimizer) + if lr_schedule is not None: + new_lr = lr_schedule.update( + timestep=timesteps.get(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0) + ) + self._set_optimizer_lr(optimizer, lr=new_lr) + self.metrics.log_value( + # Cut out the module ID from the beginning since it's already part + # of the key sequence: (ModuleID, "[optim name]_lr"). + key=(module_id, f"{optimizer_name[len(module_id) + 1:]}_{LR_KEY}"), + value=convert_to_numpy(self._get_optimizer_lr(optimizer)), + window=1, + ) + + def _set_slicing_by_batch_id( + self, batch: MultiAgentBatch, *, value: bool + ) -> MultiAgentBatch: + """Enables slicing by batch id in the given batch. + + If the input batch contains batches of sequences we need to make sure when + slicing happens it is sliced via batch id and not timestamp. Calling this + method enables the same flag on each SampleBatch within the input + MultiAgentBatch. 
+ + Args: + batch: The MultiAgentBatch to enable slicing by batch id on. + value: The value to set the flag to. + + Returns: + The input MultiAgentBatch with the indexing flag is enabled / disabled on. + """ + + for pid, policy_batch in batch.policy_batches.items(): + # We assume that arriving batches for recurrent modules OR batches that + # have a SEQ_LENS column are already zero-padded to the max sequence length + # and have tensors of shape [B, T, ...]. Therefore, we slice sequence + # lengths in B. See SampleBatch for more information. + if ( + self.module[pid].is_stateful() + or policy_batch.get("seq_lens") is not None + ): + if value: + policy_batch.enable_slicing_by_batch_id() + else: + policy_batch.disable_slicing_by_batch_id() + + return batch + + def _make_module(self) -> MultiRLModule: + """Construct the multi-agent RL module for the learner. + + This method uses `self._module_specs` or `self._module_obj` to construct the + module. If the module_class is a single agent RL module it will be wrapped to a + multi-agent RL module. Override this method if there are other things that + need to happen for instantiation of the module. + + Returns: + A constructed MultiRLModule. + """ + # Module was provided directly through constructor -> Use as-is. + if self._module_obj is not None: + module = self._module_obj + self._module_spec = MultiRLModuleSpec.from_module(module) + # RLModuleSpec was provided directly through constructor -> Use it to build the + # RLModule. + elif self._module_spec is not None: + module = self._module_spec.build() + # Try using our config object. Note that this would only work if the config + # object has all the necessary space information already in it. + else: + module = self.config.get_multi_rl_module_spec().build() + + # If not already, convert to MultiRLModule. 
+ module = module.as_multi_rl_module() + + return module + + def rl_module_is_compatible(self, module: RLModule) -> bool: + """Check whether the given `module` is compatible with this Learner. + + The default implementation checks the Learner-required APIs and whether the + given `module` implements all of them (if not, returns False). + + Args: + module: The RLModule to check. + + Returns: + True if the module is compatible with this Learner. + """ + return all(isinstance(module, api) for api in self.rl_module_required_apis()) + + @classmethod + def rl_module_required_apis(cls) -> list[type]: + """Returns the required APIs for an RLModule to be compatible with this Learner. + + The returned values may or may not be used inside the `rl_module_is_compatible` + method. + + Args: + module: The RLModule to check. + + Returns: + A list of RLModule API classes that an RLModule must implement in order + to be compatible with this Learner. + """ + return [] + + def _check_registered_optimizer( + self, + optimizer: Optimizer, + params: Sequence[Param], + ) -> None: + """Checks that the given optimizer and parameters are valid for the framework. + + Args: + optimizer: The optimizer object to check. + params: The list of parameters to check. + """ + if not isinstance(params, list): + raise ValueError( + f"`params` ({params}) must be a list of framework-specific parameters " + "(variables)!" + ) + + def _log_trainable_parameters(self) -> None: + """Logs the number of trainable and non-trainable parameters to self.metrics. + + Use MetricsLogger (self.metrics) tuple-keys: + (ALL_MODULES, NUM_TRAINABLE_PARAMETERS) and + (ALL_MODULES, NUM_NON_TRAINABLE_PARAMETERS) with EMA. + """ + pass + + def _check_is_built(self, error: bool = True) -> bool: + if self.module is None: + if error: + raise ValueError( + "Learner.build() must be called after constructing a " + "Learner and before calling any methods on it." 
+ ) + return False + return True + + def _reset(self): + self._params = {} + self._optimizer_parameters = {} + self._named_optimizers = {} + self._module_optimizers = defaultdict(list) + self._optimizer_lr_schedules = {} + self.metrics = MetricsLogger() + self._is_built = False + + def apply(self, func, *_args, **_kwargs): + return func(self, *_args, **_kwargs) + + @abc.abstractmethod + def _get_tensor_variable( + self, + value: Any, + dtype: Any = None, + trainable: bool = False, + ) -> TensorType: + """Returns a framework-specific tensor variable with the initial given value. + + This is a framework specific method that should be implemented by the + framework specific sub-classes. + + Args: + value: The initial value for the tensor variable variable. + + Returns: + The framework specific tensor variable of the given initial value, + dtype and trainable/requires_grad property. + """ + + @staticmethod + @abc.abstractmethod + def _get_optimizer_lr(optimizer: Optimizer) -> float: + """Returns the current learning rate of the given local optimizer. + + Args: + optimizer: The local optimizer to get the current learning rate for. + + Returns: + The learning rate value (float) of the given optimizer. + """ + + @staticmethod + @abc.abstractmethod + def _set_optimizer_lr(optimizer: Optimizer, lr: float) -> None: + """Updates the learning rate of the given local optimizer. + + Args: + optimizer: The local optimizer to update the learning rate for. + lr: The new learning rate. 
+ """ + + @staticmethod + @abc.abstractmethod + def _get_clip_function() -> Callable: + """Returns the gradient clipping function to use, given the framework.""" + + @staticmethod + @abc.abstractmethod + def _get_global_norm_function() -> Callable: + """Returns the global norm function to use, given the framework.""" + + def _log_steps_trained_metrics(self, batch: MultiAgentBatch): + """Logs this iteration's steps trained, based on given `batch`.""" + for mid, module_batch in batch.policy_batches.items(): + module_batch_size = len(module_batch) + # Log average batch size (for each module). + self.metrics.log_value( + key=(mid, MODULE_TRAIN_BATCH_SIZE_MEAN), + value=module_batch_size, + ) + # Log module steps (for each module). + self.metrics.log_value( + key=(mid, NUM_MODULE_STEPS_TRAINED), + value=module_batch_size, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + key=(mid, NUM_MODULE_STEPS_TRAINED_LIFETIME), + value=module_batch_size, + reduce="sum", + ) + # Log module steps (sum of all modules). + self.metrics.log_value( + key=(ALL_MODULES, NUM_MODULE_STEPS_TRAINED), + value=module_batch_size, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + key=(ALL_MODULES, NUM_MODULE_STEPS_TRAINED_LIFETIME), + value=module_batch_size, + reduce="sum", + ) + # Log env steps (all modules). 
+ self.metrics.log_value( + (ALL_MODULES, NUM_ENV_STEPS_TRAINED), + batch.env_steps(), + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + (ALL_MODULES, NUM_ENV_STEPS_TRAINED_LIFETIME), + batch.env_steps(), + reduce="sum", + with_throughput=True, + ) + + @Deprecated( + new="Learner.before_gradient_based_update(" + "timesteps={'num_env_steps_sampled_lifetime': ...}) and/or " + "Learner.after_gradient_based_update(" + "timesteps={'num_env_steps_sampled_lifetime': ...})", + error=True, + ) + def additional_update_for_module(self, *args, **kwargs): + pass + + @Deprecated(new="Learner.save_to_path(...)", error=True) + def save_state(self, *args, **kwargs): + pass + + @Deprecated(new="Learner.restore_from_path(...)", error=True) + def load_state(self, *args, **kwargs): + pass + + @Deprecated(new="Learner.module.get_state()", error=True) + def get_module_state(self, *args, **kwargs): + pass + + @Deprecated(new="Learner.module.set_state()", error=True) + def set_module_state(self, *args, **kwargs): + pass + + @Deprecated(new="Learner._get_optimizer_state()", error=True) + def get_optimizer_state(self, *args, **kwargs): + pass + + @Deprecated(new="Learner._set_optimizer_state()", error=True) + def set_optimizer_state(self, *args, **kwargs): + pass + + @Deprecated(new="Learner.compute_losses(...)", error=False) + def compute_loss(self, *args, **kwargs): + losses_per_module = self.compute_losses(*args, **kwargs) + # To continue supporting the old `compute_loss` behavior (instead of + # the new `compute_losses`, add the ALL_MODULES key here holding the sum + # of all individual loss terms. 
+ if ALL_MODULES not in losses_per_module: + losses_per_module[ALL_MODULES] = sum(losses_per_module.values()) + return losses_per_module diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/learner_group.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/learner_group.py new file mode 100644 index 0000000000000000000000000000000000000000..9ef6abb3748d1c6e70bd544365ac837748b81c97 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/learner_group.py @@ -0,0 +1,1030 @@ +import pathlib +from collections import defaultdict, Counter +import copy +from functools import partial +import itertools +from typing import ( + Any, + Callable, + Collection, + Dict, + List, + Optional, + Set, + Type, + TYPE_CHECKING, + Union, +) + +import ray +from ray import ObjectRef +from ray.rllib.core import ( + COMPONENT_LEARNER, + COMPONENT_RL_MODULE, +) +from ray.rllib.core.learner.learner import Learner +from ray.rllib.core.rl_module import validate_module_id +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.env.multi_agent_episode import MultiAgentEpisode +from ray.rllib.policy.policy import PolicySpec +from ray.rllib.policy.sample_batch import MultiAgentBatch +from ray.rllib.utils.actor_manager import ( + FaultTolerantActorManager, + RemoteCallResults, + ResultOrError, +) +from ray.rllib.utils.annotations import override +from ray.rllib.utils.checkpoints import Checkpointable +from ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.metrics import ALL_MODULES +from ray.rllib.utils.minibatch_utils import ( + ShardBatchIterator, + ShardEpisodesIterator, + ShardObjectRefIterator, +) +from ray.rllib.utils.typing import ( + EpisodeType, + ModuleID, + RLModuleSpecType, + ShouldModuleBeUpdatedFn, + StateDict, + T, +) +from ray.train._internal.backend_executor import BackendExecutor +from ray.util.annotations import PublicAPI + 
+if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + + +def _get_backend_config(learner_class: Type[Learner]) -> str: + if learner_class.framework == "torch": + from ray.train.torch import TorchConfig + + backend_config = TorchConfig() + elif learner_class.framework == "tf2": + from ray.train.tensorflow import TensorflowConfig + + backend_config = TensorflowConfig() + else: + raise ValueError( + "`learner_class.framework` must be either 'torch' or 'tf2' (but is " + f"{learner_class.framework}!" + ) + + return backend_config + + +@PublicAPI(stability="alpha") +class LearnerGroup(Checkpointable): + """Coordinator of n (possibly remote) Learner workers. + + Each Learner worker has a copy of the RLModule, the loss function(s), and + one or more optimizers. + """ + + def __init__( + self, + *, + config: "AlgorithmConfig", + # TODO (sven): Rename into `rl_module_spec`. + module_spec: Optional[RLModuleSpecType] = None, + ): + """Initializes a LearnerGroup instance. + + Args: + config: The AlgorithmConfig object to use to configure this LearnerGroup. + Call the `learners(num_learners=...)` method on your config to + specify the number of learner workers to use. + Call the same method with arguments `num_cpus_per_learner` and/or + `num_gpus_per_learner` to configure the compute used by each + Learner worker in this LearnerGroup. + Call the `training(learner_class=...)` method on your config to specify, + which exact Learner class to use. + Call the `rl_module(rl_module_spec=...)` method on your config to set up + the specifics for your RLModule to be used in each Learner. + module_spec: If not already specified in `config`, a separate overriding + RLModuleSpec may be provided via this argument. 
+ """ + self.config = config.copy(copy_frozen=False) + self._module_spec = module_spec + + learner_class = self.config.learner_class + module_spec = module_spec or self.config.get_multi_rl_module_spec() + + self._learner = None + self._workers = None + # If a user calls self.shutdown() on their own then this flag is set to true. + # When del is called the backend executor isn't shutdown twice if this flag is + # true. the backend executor would otherwise log a warning to the console from + # ray train. + self._is_shut_down = False + + # How many timesteps had to be dropped due to a full input queue? + self._ts_dropped = 0 + + # A single local Learner. + if not self.is_remote: + self._learner = learner_class(config=config, module_spec=module_spec) + self._learner.build() + self._worker_manager = None + # N remote Learner workers. + else: + backend_config = _get_backend_config(learner_class) + + # TODO (sven): Can't set both `num_cpus_per_learner`>1 and + # `num_gpus_per_learner`>0! Users must set one or the other due + # to issues with placement group fragmentation. See + # https://github.com/ray-project/ray/issues/35409 for more details. + num_cpus_per_learner = ( + self.config.num_cpus_per_learner + if not self.config.num_gpus_per_learner + else 0 + ) + num_gpus_per_learner = max( + 0, + self.config.num_gpus_per_learner + - (0.01 * self.config.num_aggregator_actors_per_learner), + ) + resources_per_learner = { + "CPU": num_cpus_per_learner, + "GPU": num_gpus_per_learner, + } + + backend_executor = BackendExecutor( + backend_config=backend_config, + num_workers=self.config.num_learners, + resources_per_worker=resources_per_learner, + max_retries=0, + ) + backend_executor.start( + train_cls=learner_class, + train_cls_kwargs={ + "config": config, + "module_spec": module_spec, + }, + ) + self._backend_executor = backend_executor + + self._workers = [w.actor for w in backend_executor.worker_group.workers] + + # Run the neural network building code on remote workers. 
+            ray.get([w.build.remote() for w in self._workers])
+
+            self._worker_manager = FaultTolerantActorManager(
+                self._workers,
+                max_remote_requests_in_flight_per_actor=(
+                    self.config.max_requests_in_flight_per_learner
+                ),
+            )
+            # Counters for the tags for asynchronous update requests that are
+            # in-flight. Used for keeping track of and grouping together the results of
+            # requests that were sent to the workers at the same time.
+            self._update_request_tags = Counter()
+            self._update_request_tag = 0
+            self._update_request_results = {}
+
+    # TODO (sven): Replace this with call to `self.metrics.peek()`?
+    #  Currently LearnerGroup does not have a metrics object.
+    def get_stats(self) -> Dict[str, Any]:
+        """Returns the current stats for the input queue for this learner group."""
+        return {
+            "learner_group_ts_dropped": self._ts_dropped,
+            "actor_manager_num_outstanding_async_reqs": (
+                0
+                if self.is_local
+                else self._worker_manager.num_outstanding_async_reqs()
+            ),
+        }
+
+    @property
+    def is_remote(self) -> bool:
+        return self.config.num_learners > 0
+
+    @property
+    def is_local(self) -> bool:
+        return not self.is_remote
+
+    def update_from_batch(
+        self,
+        batch: MultiAgentBatch,
+        *,
+        timesteps: Optional[Dict[str, Any]] = None,
+        async_update: bool = False,
+        return_state: bool = False,
+        num_epochs: int = 1,
+        minibatch_size: Optional[int] = None,
+        shuffle_batch_per_epoch: bool = False,
+        # User kwargs.
+        **kwargs,
+    ) -> Union[Dict[str, Any], List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
+        """Performs gradient based update(s) on the Learner(s), based on given batch.
+
+        Args:
+            batch: A data batch to use for the update. If there are more
+                than one Learner workers, the batch is split amongst these and one
+                shard is sent to each Learner.
+            async_update: Whether the update request(s) to the Learner workers should be
+                sent asynchronously.
If True, will return NOT the results from the + update on the given data, but all results from prior asynchronous update + requests that have not been returned thus far. + return_state: Whether to include one of the Learner worker's state from + after the update step in the returned results dict (under the + `_rl_module_state_after_update` key). Note that after an update, all + Learner workers' states should be identical, so we use the first + Learner's state here. Useful for avoiding an extra `get_weights()` call, + e.g. for synchronizing EnvRunner weights. + num_epochs: The number of complete passes over the entire train batch. Each + pass might be further split into n minibatches (if `minibatch_size` + provided). + minibatch_size: The size of minibatches to use to further split the train + `batch` into sub-batches. The `batch` is then iterated over n times + where n is `len(batch) // minibatch_size`. + shuffle_batch_per_epoch: Whether to shuffle the train batch once per epoch. + If the train batch has a time rank (axis=1), shuffling will only take + place along the batch axis to not disturb any intact (episode) + trajectories. Also, shuffling is always skipped if `minibatch_size` is + None, meaning the entire train batch is processed each epoch, making it + unnecessary to shuffle. + + Returns: + If `async_update` is False, a dictionary with the reduced results of the + updates from the Learner(s) or a list of dictionaries of results from the + updates from the Learner(s). + If `async_update` is True, a list of list of dictionaries of results, where + the outer list corresponds to separate previous calls to this method, and + the inner list corresponds to the results from each Learner(s). Or if the + results are reduced, a list of dictionaries of the reduced results from each + call to async_update that is ready. 
+ """ + return self._update( + batch=batch, + timesteps=timesteps, + async_update=async_update, + return_state=return_state, + num_epochs=num_epochs, + minibatch_size=minibatch_size, + shuffle_batch_per_epoch=shuffle_batch_per_epoch, + **kwargs, + ) + + def update_from_episodes( + self, + episodes: List[EpisodeType], + *, + timesteps: Optional[Dict[str, Any]] = None, + async_update: bool = False, + return_state: bool = False, + num_epochs: int = 1, + minibatch_size: Optional[int] = None, + shuffle_batch_per_epoch: bool = False, + # User kwargs. + **kwargs, + ) -> Union[Dict[str, Any], List[Dict[str, Any]], List[List[Dict[str, Any]]]]: + """Performs gradient based update(s) on the Learner(s), based on given episodes. + + Args: + episodes: A list of Episodes to process and perform the update + for. If there are more than one Learner workers, the list of episodes + is split amongst these and one list shard is sent to each Learner. + async_update: Whether the update request(s) to the Learner workers should be + sent asynchronously. If True, will return NOT the results from the + update on the given data, but all results from prior asynchronous update + requests that have not been returned thus far. + return_state: Whether to include one of the Learner worker's state from + after the update step in the returned results dict (under the + `_rl_module_state_after_update` key). Note that after an update, all + Learner workers' states should be identical, so we use the first + Learner's state here. Useful for avoiding an extra `get_weights()` call, + e.g. for synchronizing EnvRunner weights. + num_epochs: The number of complete passes over the entire train batch. Each + pass might be further split into n minibatches (if `minibatch_size` + provided). The train batch is generated from the given `episodes` + through the Learner connector pipeline. + minibatch_size: The size of minibatches to use to further split the train + `batch` into sub-batches. 
The `batch` is then iterated over n times + where n is `len(batch) // minibatch_size`. The train batch is generated + from the given `episodes` through the Learner connector pipeline. + shuffle_batch_per_epoch: Whether to shuffle the train batch once per epoch. + If the train batch has a time rank (axis=1), shuffling will only take + place along the batch axis to not disturb any intact (episode) + trajectories. Also, shuffling is always skipped if `minibatch_size` is + None, meaning the entire train batch is processed each epoch, making it + unnecessary to shuffle. The train batch is generated from the given + `episodes` through the Learner connector pipeline. + + Returns: + If async_update is False, a dictionary with the reduced results of the + updates from the Learner(s) or a list of dictionaries of results from the + updates from the Learner(s). + If async_update is True, a list of list of dictionaries of results, where + the outer list corresponds to separate previous calls to this method, and + the inner list corresponds to the results from each Learner(s). Or if the + results are reduced, a list of dictionaries of the reduced results from each + call to async_update that is ready. + """ + return self._update( + episodes=episodes, + timesteps=timesteps, + async_update=async_update, + return_state=return_state, + num_epochs=num_epochs, + minibatch_size=minibatch_size, + shuffle_batch_per_epoch=shuffle_batch_per_epoch, + **kwargs, + ) + + def _update( + self, + *, + batch: Optional[MultiAgentBatch] = None, + episodes: Optional[List[EpisodeType]] = None, + timesteps: Optional[Dict[str, Any]] = None, + async_update: bool = False, + return_state: bool = False, + num_epochs: int = 1, + num_iters: int = 1, + minibatch_size: Optional[int] = None, + shuffle_batch_per_epoch: bool = False, + **kwargs, + ) -> Union[Dict[str, Any], List[Dict[str, Any]], List[List[Dict[str, Any]]]]: + + # Define function to be called on all Learner actors (or the local learner). 
+ def _learner_update( + _learner: Learner, + *, + _batch_shard=None, + _episodes_shard=None, + _timesteps=None, + _return_state=False, + _num_total_minibatches=0, + **_kwargs, + ): + # If the batch shard is an `DataIterator` we have an offline + # multi-learner setup and `update_from_iterator` needs to + # handle updating. + if isinstance(_batch_shard, ray.data.DataIterator): + result = _learner.update_from_iterator( + iterator=_batch_shard, + timesteps=_timesteps, + minibatch_size=minibatch_size, + num_iters=num_iters, + **_kwargs, + ) + elif _batch_shard is not None: + result = _learner.update_from_batch( + batch=_batch_shard, + timesteps=_timesteps, + num_epochs=num_epochs, + minibatch_size=minibatch_size, + shuffle_batch_per_epoch=shuffle_batch_per_epoch, + **_kwargs, + ) + else: + result = _learner.update_from_episodes( + episodes=_episodes_shard, + timesteps=_timesteps, + num_epochs=num_epochs, + minibatch_size=minibatch_size, + shuffle_batch_per_epoch=shuffle_batch_per_epoch, + num_total_minibatches=_num_total_minibatches, + **_kwargs, + ) + if _return_state and result: + result["_rl_module_state_after_update"] = _learner.get_state( + # Only return the state of those RLModules that actually returned + # results and thus got probably updated. + components=[ + COMPONENT_RL_MODULE + "/" + mid + for mid in result + if mid != ALL_MODULES + ], + inference_only=True, + ) + + return result + + # Local Learner worker: Don't shard batch/episodes, just run data as-is through + # this Learner. + if self.is_local: + if async_update: + raise ValueError( + "Cannot call `update_from_batch(async_update=True)` when running in" + " local mode! Try setting `config.num_learners > 0`." 
+ ) + + if isinstance(batch, list) and isinstance(batch[0], ray.ObjectRef): + assert len(batch) == 1 + batch = ray.get(batch[0]) + + results = [ + _learner_update( + _learner=self._learner, + _batch_shard=batch, + _episodes_shard=episodes, + _timesteps=timesteps, + _return_state=return_state, + **kwargs, + ) + ] + # One or more remote Learners: Shard batch/episodes into equal pieces (roughly + # equal if multi-agent AND episodes) and send each Learner worker one of these + # shards. + else: + # MultiAgentBatch: Shard into equal pieces. + # TODO (sven): The sharder used here destroys - for multi-agent only - + # the relationship of the different agents' timesteps to each other. + # Thus, in case the algorithm requires agent-synchronized data (aka. + # "lockstep"), the `ShardBatchIterator` should not be used. + # Then again, we might move into a world where Learner always + # receives Episodes, never batches. + if isinstance(batch, list) and isinstance(batch[0], ray.data.DataIterator): + partials = [ + partial( + _learner_update, + _batch_shard=iterator, + _return_state=(return_state and i == 0), + _timesteps=timesteps, + **kwargs, + ) + # Note, `OfflineData` defines exactly as many iterators as there + # are learners. 
+ for i, iterator in enumerate(batch) + ] + elif isinstance(batch, list) and isinstance(batch[0], ObjectRef): + assert len(batch) == len(self._workers) + partials = [ + partial( + _learner_update, + _batch_shard=batch_shard, + _timesteps=timesteps, + _return_state=(return_state and i == 0), + **kwargs, + ) + for i, batch_shard in enumerate(batch) + ] + elif batch is not None: + partials = [ + partial( + _learner_update, + _batch_shard=batch_shard, + _return_state=(return_state and i == 0), + _timesteps=timesteps, + **kwargs, + ) + for i, batch_shard in enumerate( + ShardBatchIterator(batch, len(self._workers)) + ) + ] + elif isinstance(episodes, list) and isinstance(episodes[0], ObjectRef): + partials = [ + partial( + _learner_update, + _episodes_shard=episodes_shard, + _timesteps=timesteps, + _return_state=(return_state and i == 0), + **kwargs, + ) + for i, episodes_shard in enumerate( + ShardObjectRefIterator(episodes, len(self._workers)) + ) + ] + # Single- or MultiAgentEpisodes: Shard into equal pieces (only roughly equal + # in case of multi-agent). + else: + from ray.data.iterator import DataIterator + + if isinstance(episodes[0], DataIterator): + num_total_minibatches = 0 + partials = [ + partial( + _learner_update, + _episodes_shard=episodes_shard, + _timesteps=timesteps, + _num_total_minibatches=num_total_minibatches, + ) + for episodes_shard in episodes + ] + else: + eps_shards = list( + ShardEpisodesIterator( + episodes, + len(self._workers), + len_lookback_buffer=self.config.episode_lookback_horizon, + ) + ) + # In the multi-agent case AND `minibatch_size` AND num_workers + # > 1, we compute a max iteration counter such that the different + # Learners will not go through a different number of iterations. 
+ num_total_minibatches = 0 + if minibatch_size and len(self._workers) > 1: + num_total_minibatches = self._compute_num_total_minibatches( + episodes, + len(self._workers), + minibatch_size, + num_epochs, + ) + partials = [ + partial( + _learner_update, + _episodes_shard=eps_shard, + _timesteps=timesteps, + _num_total_minibatches=num_total_minibatches, + ) + for eps_shard in eps_shards + ] + + if async_update: + # Retrieve all ready results (kicked off by prior calls to this method). + tags_to_get = [] + for tag in self._update_request_tags.keys(): + result = self._worker_manager.fetch_ready_async_reqs( + tags=[str(tag)], timeout_seconds=0.0 + ) + if tag not in self._update_request_results: + self._update_request_results[tag] = result + else: + for r in result: + self._update_request_results[tag].add_result( + r.actor_id, r.result_or_error, tag + ) + + # Still not done with this `tag` -> skip out early. + if ( + self._update_request_tags[tag] + > len(self._update_request_results[tag].result_or_errors) + > 0 + ): + break + tags_to_get.append(tag) + + # Send out new request(s), if there is still capacity on the actors + # (each actor is allowed only some number of max in-flight requests + # at the same time). + update_tag = self._update_request_tag + self._update_request_tag += 1 + num_sent_requests = self._worker_manager.foreach_actor_async( + partials, tag=str(update_tag) + ) + if num_sent_requests: + self._update_request_tags[update_tag] = num_sent_requests + + # Some requests were dropped, record lost ts/data. + if num_sent_requests != len(self._workers): + factor = 1 - (num_sent_requests / len(self._workers)) + # Batch: Measure its length. 
+ if episodes is None: + dropped = len(batch) + # List of Ray ObjectRefs (each object ref is a list of episodes of + # total len=`rollout_fragment_length * num_envs_per_env_runner`) + elif isinstance(episodes[0], ObjectRef): + dropped = ( + len(episodes) + * self.config.get_rollout_fragment_length() + * self.config.num_envs_per_env_runner + ) + else: + dropped = sum(len(e) for e in episodes) + + self._ts_dropped += factor * dropped + + # NOTE: There is a strong assumption here that the requests launched to + # learner workers will return at the same time, since they have a + # barrier inside for gradient aggregation. Therefore, results should be + # a list of lists where each inner list should be the length of the + # number of learner workers, if results from an non-blocking update are + # ready. + results = self._get_async_results(tags_to_get) + + else: + results = self._get_results( + self._worker_manager.foreach_actor(partials) + ) + + return results + + # TODO (sven): Move this into FaultTolerantActorManager? + def _get_results(self, results): + processed_results = [] + for result in results: + result_or_error = result.get() + if result.ok: + processed_results.append(result_or_error) + else: + raise result_or_error + return processed_results + + def _get_async_results(self, tags_to_get): + """Get results from the worker manager and group them by tag. + + Returns: + A list of lists of results, where each inner list contains all results + for same tags. + + """ + unprocessed_results = defaultdict(list) + for tag in tags_to_get: + results = self._update_request_results[tag] + for result in results: + result_or_error = result.get() + if result.ok: + if result.tag is None: + raise RuntimeError( + "Cannot call `LearnerGroup._get_async_results()` on " + "untagged async requests!" 
+ ) + tag = int(result.tag) + unprocessed_results[tag].append(result_or_error) + + if tag in self._update_request_tags: + self._update_request_tags[tag] -= 1 + if self._update_request_tags[tag] == 0: + del self._update_request_tags[tag] + del self._update_request_results[tag] + else: + assert False + + else: + raise result_or_error + + return list(unprocessed_results.values()) + + def add_module( + self, + *, + module_id: ModuleID, + module_spec: RLModuleSpec, + config_overrides: Optional[Dict] = None, + new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None, + ) -> MultiRLModuleSpec: + """Adds a module to the underlying MultiRLModule. + + Changes this Learner's config in order to make this architectural change + permanent wrt. to checkpointing. + + Args: + module_id: The ModuleID of the module to be added. + module_spec: The ModuleSpec of the module to be added. + config_overrides: The `AlgorithmConfig` overrides that should apply to + the new Module, if any. + new_should_module_be_updated: An optional sequence of ModuleIDs or a + callable taking ModuleID and SampleBatchType and returning whether the + ModuleID should be updated (trained). + If None, will keep the existing setup in place. RLModules, + whose IDs are not in the list (or for which the callable + returns False) will not be updated. + + Returns: + The new MultiRLModuleSpec (after the change has been performed). + """ + validate_module_id(module_id, error=True) + + # Force-set inference-only = False. + module_spec = copy.deepcopy(module_spec) + module_spec.inference_only = False + + results = self.foreach_learner( + func=lambda _learner: _learner.add_module( + module_id=module_id, + module_spec=module_spec, + config_overrides=config_overrides, + new_should_module_be_updated=new_should_module_be_updated, + ), + ) + marl_spec = self._get_results(results)[0] + + # Change our config (AlgorithmConfig) to contain the new Module. 
+ # TODO (sven): This is a hack to manipulate the AlgorithmConfig directly, + # but we'll deprecate config.policies soon anyway. + self.config.policies[module_id] = PolicySpec() + if config_overrides is not None: + self.config.multi_agent( + algorithm_config_overrides_per_module={module_id: config_overrides} + ) + self.config.rl_module(rl_module_spec=marl_spec) + if new_should_module_be_updated is not None: + self.config.multi_agent(policies_to_train=new_should_module_be_updated) + + return marl_spec + + def remove_module( + self, + module_id: ModuleID, + *, + new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None, + ) -> MultiRLModuleSpec: + """Removes a module from the Learner. + + Args: + module_id: The ModuleID of the module to be removed. + new_should_module_be_updated: An optional sequence of ModuleIDs or a + callable taking ModuleID and SampleBatchType and returning whether the + ModuleID should be updated (trained). + If None, will keep the existing setup in place. RLModules, + whose IDs are not in the list (or for which the callable + returns False) will not be updated. + + Returns: + The new MultiRLModuleSpec (after the change has been performed). + """ + results = self.foreach_learner( + func=lambda _learner: _learner.remove_module( + module_id=module_id, + new_should_module_be_updated=new_should_module_be_updated, + ), + ) + marl_spec = self._get_results(results)[0] + + # Change self.config to reflect the new architecture. + # TODO (sven): This is a hack to manipulate the AlgorithmConfig directly, + # but we'll deprecate config.policies soon anyway. 
+ del self.config.policies[module_id] + self.config.algorithm_config_overrides_per_module.pop(module_id, None) + if new_should_module_be_updated is not None: + self.config.multi_agent(policies_to_train=new_should_module_be_updated) + self.config.rl_module(rl_module_spec=marl_spec) + + return marl_spec + + @override(Checkpointable) + def get_state( + self, + components: Optional[Union[str, Collection[str]]] = None, + *, + not_components: Optional[Union[str, Collection[str]]] = None, + **kwargs, + ) -> StateDict: + state = {} + + if self._check_component(COMPONENT_LEARNER, components, not_components): + if self.is_local: + state[COMPONENT_LEARNER] = self._learner.get_state( + components=self._get_subcomponents(COMPONENT_LEARNER, components), + not_components=self._get_subcomponents( + COMPONENT_LEARNER, not_components + ), + **kwargs, + ) + else: + worker = self._worker_manager.healthy_actor_ids()[0] + assert len(self._workers) == self._worker_manager.num_healthy_actors() + _comps = self._get_subcomponents(COMPONENT_LEARNER, components) + _not_comps = self._get_subcomponents(COMPONENT_LEARNER, not_components) + results = self._worker_manager.foreach_actor( + lambda w: w.get_state(_comps, not_components=_not_comps, **kwargs), + remote_actor_ids=[worker], + ) + state[COMPONENT_LEARNER] = self._get_results(results)[0] + + return state + + @override(Checkpointable) + def set_state(self, state: StateDict) -> None: + if COMPONENT_LEARNER in state: + if self.is_local: + self._learner.set_state(state[COMPONENT_LEARNER]) + else: + state_ref = ray.put(state[COMPONENT_LEARNER]) + self.foreach_learner( + lambda _learner, _ref=state_ref: _learner.set_state(ray.get(_ref)) + ) + + def get_weights( + self, module_ids: Optional[Collection[ModuleID]] = None + ) -> StateDict: + """Convenience method instead of self.get_state(components=...). + + Args: + module_ids: An optional collection of ModuleIDs for which to return weights. + If None (default), return weights of all RLModules. 
+ + Returns: + The results of + `self.get_state(components='learner/rl_module')['learner']['rl_module']`. + """ + # Return the entire RLModule state (all possible single-agent RLModules). + if module_ids is None: + components = COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE + # Return a subset of the single-agent RLModules. + else: + components = [ + "".join(tup) + for tup in itertools.product( + [COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE + "/"], + list(module_ids), + ) + ] + state = self.get_state(components)[COMPONENT_LEARNER][COMPONENT_RL_MODULE] + return state + + def set_weights(self, weights) -> None: + """Convenience method instead of self.set_state({'learner': {'rl_module': ..}}). + + Args: + weights: The weights dict of the MultiRLModule of a Learner inside this + LearnerGroup. + """ + self.set_state({COMPONENT_LEARNER: {COMPONENT_RL_MODULE: weights}}) + + @override(Checkpointable) + def get_ctor_args_and_kwargs(self): + return ( + (), # *args + { + "config": self.config, + "module_spec": self._module_spec, + }, # **kwargs + ) + + @override(Checkpointable) + def get_checkpointable_components(self): + # Return the entire ActorManager, if remote. Otherwise, return the + # local worker. Also, don't give the component (Learner) a name ("") + # as it's the only component in this LearnerGroup to be saved. + return [ + ( + COMPONENT_LEARNER, + self._learner if self.is_local else self._worker_manager, + ) + ] + + def foreach_learner( + self, + func: Callable[[Learner, Optional[Any]], T], + *, + healthy_only: bool = True, + remote_actor_ids: List[int] = None, + timeout_seconds: Optional[float] = None, + return_obj_refs: bool = False, + mark_healthy: bool = False, + **kwargs, + ) -> RemoteCallResults: + """Calls the given function on each Learner L with the args: (L, \*\*kwargs). + + Args: + func: The function to call on each Learner L with args: (L, \*\*kwargs). 
+ healthy_only: If True, applies `func` only to Learner actors currently + tagged "healthy", otherwise to all actors. If `healthy_only=False` and + `mark_healthy=True`, will send `func` to all actors and mark those + actors "healthy" that respond to the request within `timeout_seconds` + and are currently tagged as "unhealthy". + remote_actor_ids: Apply func on a selected set of remote actors. Use None + (default) for all actors. + timeout_seconds: Time to wait (in seconds) for results. Set this to 0.0 for + fire-and-forget. Set this to None (default) to wait infinitely (i.e. for + synchronous execution). + return_obj_refs: whether to return ObjectRef instead of actual results. + Note, for fault tolerance reasons, these returned ObjectRefs should + never be resolved with ray.get() outside of the context of this manager. + mark_healthy: Whether to mark all those actors healthy again that are + currently marked unhealthy AND that returned results from the remote + call (within the given `timeout_seconds`). + Note that actors are NOT set unhealthy, if they simply time out + (only if they return a RayActorError). + Also not that this setting is ignored if `healthy_only=True` (b/c this + setting only affects actors that are currently tagged as unhealthy). + + Returns: + A list of size len(Learners) with the return values of all calls to `func`. 
+ """ + if self.is_local: + results = RemoteCallResults() + results.add_result( + None, + ResultOrError(result=func(self._learner, **kwargs)), + None, + ) + return results + + return self._worker_manager.foreach_actor( + func=partial(func, **kwargs), + healthy_only=healthy_only, + remote_actor_ids=remote_actor_ids, + timeout_seconds=timeout_seconds, + return_obj_refs=return_obj_refs, + mark_healthy=mark_healthy, + ) + + def shutdown(self): + """Shuts down the LearnerGroup.""" + if self.is_remote and hasattr(self, "_backend_executor"): + self._backend_executor.shutdown() + self._is_shut_down = True + + def __del__(self): + if not self._is_shut_down: + self.shutdown() + + @staticmethod + def _compute_num_total_minibatches( + episodes, + num_shards, + minibatch_size, + num_epochs, + ): + # Count total number of timesteps per module ID. + if isinstance(episodes[0], MultiAgentEpisode): + per_mod_ts = defaultdict(int) + for ma_episode in episodes: + for sa_episode in ma_episode.agent_episodes.values(): + per_mod_ts[sa_episode.module_id] += len(sa_episode) + max_ts = max(per_mod_ts.values()) + else: + max_ts = sum(map(len, episodes)) + + return int((num_epochs * max_ts) / (num_shards * minibatch_size)) + + @Deprecated(new="LearnerGroup.update_from_batch(async=False)", error=False) + def update(self, *args, **kwargs): + # Just in case, we would like to revert this API retirement, we can do so + # easily. + return self._update(*args, **kwargs, async_update=False) + + @Deprecated(new="LearnerGroup.update_from_batch(async=True)", error=False) + def async_update(self, *args, **kwargs): + # Just in case, we would like to revert this API retirement, we can do so + # easily. 
+ return self._update(*args, **kwargs, async_update=True) + + @Deprecated(new="LearnerGroup.save_to_path(...)", error=True) + def save_state(self, *args, **kwargs): + pass + + @Deprecated(new="LearnerGroup.restore_from_path(...)", error=True) + def load_state(self, *args, **kwargs): + pass + + @Deprecated(new="LearnerGroup.load_from_path(path=..., component=...)", error=False) + def load_module_state( + self, + *, + multi_rl_module_ckpt_dir: Optional[str] = None, + modules_to_load: Optional[Set[str]] = None, + rl_module_ckpt_dirs: Optional[Dict[ModuleID, str]] = None, + ) -> None: + """Load the checkpoints of the modules being trained by this LearnerGroup. + + `load_module_state` can be used 3 ways: + 1. Load a checkpoint for the MultiRLModule being trained by this + LearnerGroup. Limit the modules that are loaded from the checkpoint + by specifying the `modules_to_load` argument. + 2. Load the checkpoint(s) for single agent RLModules that + are in the MultiRLModule being trained by this LearnerGroup. + 3. Load a checkpoint for the MultiRLModule being trained by this + LearnerGroup and load the checkpoint(s) for single agent RLModules + that are in the MultiRLModule. The checkpoints for the single + agent RLModules take precedence over the module states in the + MultiRLModule checkpoint. + + NOTE: At lease one of multi_rl_module_ckpt_dir or rl_module_ckpt_dirs is + must be specified. modules_to_load can only be specified if + multi_rl_module_ckpt_dir is specified. + + Args: + multi_rl_module_ckpt_dir: The path to the checkpoint for the + MultiRLModule. + modules_to_load: A set of module ids to load from the checkpoint. + rl_module_ckpt_dirs: A mapping from module ids to the path to a + checkpoint for a single agent RLModule. + """ + if not (multi_rl_module_ckpt_dir or rl_module_ckpt_dirs): + raise ValueError( + "At least one of `multi_rl_module_ckpt_dir` or " + "`rl_module_ckpt_dirs` must be provided!" 
+ ) + if multi_rl_module_ckpt_dir: + multi_rl_module_ckpt_dir = pathlib.Path(multi_rl_module_ckpt_dir) + if rl_module_ckpt_dirs: + for module_id, path in rl_module_ckpt_dirs.items(): + rl_module_ckpt_dirs[module_id] = pathlib.Path(path) + + # MultiRLModule checkpoint is provided. + if multi_rl_module_ckpt_dir: + # Restore the entire MultiRLModule state. + if modules_to_load is None: + self.restore_from_path( + multi_rl_module_ckpt_dir, + component=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE, + ) + # Restore individual module IDs. + else: + for module_id in modules_to_load: + self.restore_from_path( + multi_rl_module_ckpt_dir / module_id, + component=( + COMPONENT_LEARNER + + "/" + + COMPONENT_RL_MODULE + + "/" + + module_id + ), + ) + if rl_module_ckpt_dirs: + for module_id, path in rl_module_ckpt_dirs.items(): + self.restore_from_path( + path, + component=( + COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE + "/" + module_id + ), + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca07a6b45144c8a8caa48fff807c73a7aee235c3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__pycache__/tf_learner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__pycache__/tf_learner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c87da11cd7ba4f75d65693d6a132e24f2a9dd8e0 
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/__pycache__/tf_learner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/tf_learner.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/tf_learner.py new file mode 100644 index 0000000000000000000000000000000000000000..4c8fd29cada89032ed7087a05615e8f6a05aa839 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/tf/tf_learner.py @@ -0,0 +1,357 @@ +import logging +import pathlib +from typing import ( + Any, + Callable, + Dict, + Hashable, + Sequence, + Tuple, + TYPE_CHECKING, + Union, +) + +from ray.rllib.core.learner.learner import Learner +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec +from ray.rllib.core.rl_module.rl_module import ( + RLModule, + RLModuleSpec, +) +from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule +from ray.rllib.policy.eager_tf_policy import _convert_to_tf +from ray.rllib.policy.sample_batch import MultiAgentBatch +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic, +) +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.typing import ( + ModuleID, + Optimizer, + Param, + ParamDict, + StateDict, + TensorType, +) + +if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + +tf1, tf, tfv = try_import_tf() + +logger = logging.getLogger(__name__) + + +class TfLearner(Learner): + + framework: str = "tf2" + + def __init__(self, **kwargs): + # by default in rllib we disable tf2 behavior + # This call re-enables it as it is needed for using + # this class. + try: + tf1.enable_v2_behavior() + except ValueError: + # This is a hack to avoid the error that happens when calling + # enable_v2_behavior after variables have already been created. 
+ pass + + super().__init__(**kwargs) + + self._enable_tf_function = self.config.eager_tracing + + # This is a placeholder which will be filled by + # `_make_distributed_strategy_if_necessary`. + self._strategy: tf.distribute.Strategy = None + + @OverrideToImplementCustomLogic + @override(Learner) + def configure_optimizers_for_module( + self, module_id: ModuleID, config: "AlgorithmConfig" = None + ) -> None: + module = self._module[module_id] + + # For this default implementation, the learning rate is handled by the + # attached lr Scheduler (controlled by self.config.lr, which can be a + # fixed value or a schedule setting). + optimizer = tf.keras.optimizers.Adam() + params = self.get_parameters(module) + + # This isn't strictly necessary, but makes it so that if a checkpoint is + # computed before training actually starts, then it will be the same in + # shape / size as a checkpoint after training starts. + optimizer.build(module.trainable_variables) + + # Register the created optimizer (under the default optimizer name). + self.register_optimizer( + module_id=module_id, + optimizer=optimizer, + params=params, + lr_or_lr_schedule=config.lr, + ) + + @override(Learner) + def compute_gradients( + self, + loss_per_module: Dict[str, TensorType], + gradient_tape: "tf.GradientTape", + **kwargs, + ) -> ParamDict: + total_loss = sum(loss_per_module.values()) + grads = gradient_tape.gradient(total_loss, self._params) + return grads + + @override(Learner) + def apply_gradients(self, gradients_dict: ParamDict) -> None: + # TODO (Avnishn, kourosh): apply gradients doesn't work in cases where + # only some agents have a sample batch that is passed but not others. + # This is probably because of the way that we are iterating over the + # parameters in the optim_to_param_dictionary. 
+ for optimizer in self._optimizer_parameters: + optim_grad_dict = self.filter_param_dict_for_optimizer( + optimizer=optimizer, param_dict=gradients_dict + ) + variable_list = [] + gradient_list = [] + for param_ref, grad in optim_grad_dict.items(): + if grad is not None: + variable_list.append(self._params[param_ref]) + gradient_list.append(grad) + optimizer.apply_gradients(zip(gradient_list, variable_list)) + + @override(Learner) + def restore_from_path(self, path: Union[str, pathlib.Path]) -> None: + # This operation is potentially very costly because a MultiRLModule is created + # at build time, destroyed, and then a new one is created from a checkpoint. + # However, it is necessary due to complications with the way that Ray Tune + # restores failed trials. When Tune restores a failed trial, it reconstructs the + # entire experiment from the initial config. Therefore, to reflect any changes + # made to the learner's modules, the module created by Tune is destroyed and + # then rebuilt from the checkpoint. + with self._strategy.scope(): + super().restore_from_path(path) + + @override(Learner) + def _get_optimizer_state(self) -> StateDict: + optim_state = {} + with tf.init_scope(): + for name, optim in self._named_optimizers.items(): + optim_state[name] = [var.numpy() for var in optim.variables()] + return optim_state + + @override(Learner) + def _set_optimizer_state(self, state: StateDict) -> None: + for name, state_array in state.items(): + if name not in self._named_optimizers: + raise ValueError( + f"Optimizer {name} in `state` is not known! 
" + f"Known optimizers are {self._named_optimizers.keys()}" + ) + optim = self._named_optimizers[name] + optim.set_weights(state_array) + + @override(Learner) + def get_param_ref(self, param: Param) -> Hashable: + return param.ref() + + @override(Learner) + def get_parameters(self, module: RLModule) -> Sequence[Param]: + return list(module.trainable_variables) + + @override(Learner) + def rl_module_is_compatible(self, module: RLModule) -> bool: + return isinstance(module, TfRLModule) + + @override(Learner) + def _check_registered_optimizer( + self, + optimizer: Optimizer, + params: Sequence[Param], + ) -> None: + super()._check_registered_optimizer(optimizer, params) + if not isinstance(optimizer, tf.keras.optimizers.Optimizer): + raise ValueError( + f"The optimizer ({optimizer}) is not a tf keras optimizer! " + "Only use tf.keras.optimizers.Optimizer subclasses for TfLearner." + ) + for param in params: + if not isinstance(param, tf.Variable): + raise ValueError( + f"One of the parameters ({param}) in the registered optimizer " + "is not a tf.Variable!" + ) + + @override(Learner) + def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch: + batch = _convert_to_tf(batch.policy_batches) + length = max(len(b) for b in batch.values()) + batch = MultiAgentBatch(batch, env_steps=length) + return batch + + @override(Learner) + def add_module( + self, + *, + module_id: ModuleID, + module_spec: RLModuleSpec, + ) -> None: + # TODO(Avnishn): + # WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead + # currently. We will be working on improving this in the future, but for now + # please wrap `call_for_each_replica` or `experimental_run` or `run` inside a + # tf.function to get the best performance. + # I get this warning any time I add a new module. I see the warning a few times + # and then it disappears. I think that I will need to open an issue with the TF + # team. 
+ with self._strategy.scope(): + super().add_module( + module_id=module_id, + module_spec=module_spec, + ) + if self._enable_tf_function: + self._possibly_traced_update = tf.function( + self._untraced_update, reduce_retracing=True + ) + + @override(Learner) + def remove_module(self, module_id: ModuleID, **kwargs) -> MultiRLModuleSpec: + with self._strategy.scope(): + marl_spec = super().remove_module(module_id, **kwargs) + + if self._enable_tf_function: + self._possibly_traced_update = tf.function( + self._untraced_update, reduce_retracing=True + ) + + return marl_spec + + def _make_distributed_strategy_if_necessary(self) -> "tf.distribute.Strategy": + """Create a distributed strategy for the learner. + + A stratgey is a tensorflow object that is used for distributing training and + gradient computation across multiple devices. By default, a no-op strategy is + used that is not distributed. + + Returns: + A strategy for the learner to use for distributed training. + + """ + if self.config.num_learners > 1: + strategy = tf.distribute.MultiWorkerMirroredStrategy() + elif self.config.num_gpus_per_learner > 0: + # mirrored strategy is typically used for multi-gpu training + # on a single machine, however we can use it for single-gpu + devices = tf.config.list_logical_devices("GPU") + assert self.config.local_gpu_idx < len(devices), ( + f"local_gpu_idx {self.config.local_gpu_idx} is not a valid GPU id or " + "is not available." + ) + local_gpu = [devices[self.config.local_gpu_idx].name] + strategy = tf.distribute.MirroredStrategy(devices=local_gpu) + else: + # the default strategy is a no-op that can be used in the local mode + # cpu only case, build will override this if needed. + strategy = tf.distribute.get_strategy() + return strategy + + @override(Learner) + def build(self) -> None: + """Build the TfLearner. + + This method is specific TfLearner. 
Before running super() it sets the correct + distributing strategy with the right device, so that computational graph is + placed on the correct device. After running super(), depending on eager_tracing + flag it will decide whether to wrap the update function with tf.function or not. + """ + + # we call build anytime we make a learner, or load a learner from a checkpoint. + # we can't make a new strategy every time we build, so we only make one the + # first time build is called. + if not self._strategy: + self._strategy = self._make_distributed_strategy_if_necessary() + + with self._strategy.scope(): + super().build() + + if self._enable_tf_function: + self._possibly_traced_update = tf.function( + self._untraced_update, reduce_retracing=True + ) + else: + self._possibly_traced_update = self._untraced_update + + @override(Learner) + def _update(self, batch: Dict) -> Tuple[Any, Any, Any]: + return self._possibly_traced_update(batch) + + def _untraced_update( + self, + batch: Dict, + # TODO: Figure out, why _ray_trace_ctx=None helps to prevent a crash in + # eager_tracing=True mode. + # It seems there may be a clash between the traced-by-tf function and the + # traced-by-ray functions (for making the TfLearner class a ray actor). + _ray_trace_ctx=None, + ): + # Activate tensor-mode on our MetricsLogger. + self.metrics.activate_tensor_mode() + + def helper(_batch): + with tf.GradientTape(persistent=True) as tape: + fwd_out = self._module.forward_train(_batch) + loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=_batch) + gradients = self.compute_gradients(loss_per_module, gradient_tape=tape) + del tape + postprocessed_gradients = self.postprocess_gradients(gradients) + self.apply_gradients(postprocessed_gradients) + + # Deactivate tensor-mode on our MetricsLogger and collect the (tensor) + # results. 
+ return fwd_out, loss_per_module, self.metrics.deactivate_tensor_mode() + + return self._strategy.run(helper, args=(batch,)) + + @override(Learner) + def _get_tensor_variable(self, value, dtype=None, trainable=False) -> "tf.Tensor": + return tf.Variable( + value, + trainable=trainable, + dtype=( + dtype + or ( + tf.float32 + if isinstance(value, float) + else tf.int32 + if isinstance(value, int) + else None + ) + ), + ) + + @staticmethod + @override(Learner) + def _get_optimizer_lr(optimizer: "tf.Optimizer") -> float: + return optimizer.lr + + @staticmethod + @override(Learner) + def _set_optimizer_lr(optimizer: "tf.Optimizer", lr: float) -> None: + # When tf creates the optimizer, it seems to detach the optimizer's lr value + # from the given tf variable. + # Thus, updating this variable is NOT sufficient to update the actual + # optimizer's learning rate, so we have to explicitly set it here inside the + # optimizer object. + optimizer.lr = lr + + @staticmethod + @override(Learner) + def _get_clip_function() -> Callable: + from ray.rllib.utils.tf_utils import clip_gradients + + return clip_gradients + + @staticmethod + @override(Learner) + def _get_global_norm_function() -> Callable: + return tf.linalg.global_norm diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c02bbe64cbcbaca26bc2e4e7cc46511fcea5df48 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__pycache__/__init__.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__pycache__/torch_learner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__pycache__/torch_learner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c10acf07e5338da5631d939e062d7a038556628 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/__pycache__/torch_learner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/torch_learner.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/torch_learner.py new file mode 100644 index 0000000000000000000000000000000000000000..e4c10787d55d7954fdc4403fdab92f38c9f639c7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/torch/torch_learner.py @@ -0,0 +1,664 @@ +from collections import defaultdict +import logging +from typing import ( + Any, + Callable, + Dict, + Hashable, + Optional, + Sequence, + Tuple, +) + +from ray.rllib.algorithms.algorithm_config import ( + AlgorithmConfig, + TorchCompileWhatToCompile, +) +from ray.rllib.core.columns import Columns +from ray.rllib.core.learner.learner import Learner, LR_KEY +from ray.rllib.core.rl_module.multi_rl_module import ( + MultiRLModule, + MultiRLModuleSpec, +) +from ray.rllib.core.rl_module.rl_module import ( + RLModule, + RLModuleSpec, +) +from ray.rllib.core.rl_module.torch.torch_rl_module import ( + TorchCompileConfig, + TorchDDPRLModule, + TorchRLModule, +) +from ray.rllib.policy.sample_batch import MultiAgentBatch +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic, + OverrideToImplementCustomLogic_CallToSuperRecommended, +) +from ray.rllib.utils.framework import get_device, try_import_torch +from ray.rllib.utils.metrics import ( + ALL_MODULES, + DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY, + NUM_TRAINABLE_PARAMETERS, + NUM_NON_TRAINABLE_PARAMETERS, + WEIGHTS_SEQ_NO, +) +from 
ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.torch_utils import convert_to_torch_tensor +from ray.rllib.utils.typing import ( + ModuleID, + Optimizer, + Param, + ParamDict, + ShouldModuleBeUpdatedFn, + StateDict, + TensorType, +) + +torch, nn = try_import_torch() +logger = logging.getLogger(__name__) + + +class TorchLearner(Learner): + + framework: str = "torch" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # Whether to compile the RL Module of this learner. This implies that the. + # forward_train method of the RL Module will be compiled. Further more, + # other forward methods of the RL Module will be compiled on demand. + # This is assumed to not happen, since other forwrad methods are not expected + # to be used during training. + self._torch_compile_forward_train = False + self._torch_compile_cfg = None + # Whether to compile the `_uncompiled_update` method of this learner. This + # implies that everything within `_uncompiled_update` will be compiled, + # not only the forward_train method of the RL Module. + # Note that this is experimental. + # Note that this requires recompiling the forward methods once we add/remove + # RL Modules. + self._torch_compile_complete_update = False + if self.config.torch_compile_learner: + if ( + self.config.torch_compile_learner_what_to_compile + == TorchCompileWhatToCompile.COMPLETE_UPDATE + ): + self._torch_compile_complete_update = True + self._compiled_update_initialized = False + else: + self._torch_compile_forward_train = True + + self._torch_compile_cfg = TorchCompileConfig( + torch_dynamo_backend=self.config.torch_compile_learner_dynamo_backend, + torch_dynamo_mode=self.config.torch_compile_learner_dynamo_mode, + ) + + # Loss scalers for mixed precision training. Map optimizer names to + # associated torch GradScaler objects. 
+ self._grad_scalers = None + if self.config._torch_grad_scaler_class: + self._grad_scalers = defaultdict( + lambda: self.config._torch_grad_scaler_class() + ) + self._lr_schedulers = {} + self._lr_scheduler_classes = None + if self.config._torch_lr_scheduler_classes: + self._lr_scheduler_classes = self.config._torch_lr_scheduler_classes + + @OverrideToImplementCustomLogic + @override(Learner) + def configure_optimizers_for_module( + self, + module_id: ModuleID, + config: "AlgorithmConfig" = None, + ) -> None: + module = self._module[module_id] + + # For this default implementation, the learning rate is handled by the + # attached lr Scheduler (controlled by self.config.lr, which can be a + # fixed value or a schedule setting). + params = self.get_parameters(module) + optimizer = torch.optim.Adam(params) + + # Register the created optimizer (under the default optimizer name). + self.register_optimizer( + module_id=module_id, + optimizer=optimizer, + params=params, + lr_or_lr_schedule=config.lr, + ) + + def _uncompiled_update( + self, + batch: Dict, + **kwargs, + ): + """Performs a single update given a batch of data.""" + # Activate tensor-mode on our MetricsLogger. + self.metrics.activate_tensor_mode() + + # TODO (sven): Causes weird cuda error when WandB is used. + # Diagnosis thus far: + # - All peek values during metrics.reduce are non-tensors. + # - However, in impala.py::training_step(), a tensor does arrive after learner + # group.update_from_episodes(), so somehow, there is still a race condition + # possible (learner, which performs the reduce() and learner thread, which + # performs the logging of tensors into metrics logger). 
+ self._compute_off_policyness(batch) + + fwd_out = self.module.forward_train(batch) + loss_per_module = self.compute_losses(fwd_out=fwd_out, batch=batch) + + gradients = self.compute_gradients(loss_per_module) + postprocessed_gradients = self.postprocess_gradients(gradients) + self.apply_gradients(postprocessed_gradients) + + # Deactivate tensor-mode on our MetricsLogger and collect the (tensor) + # results. + return fwd_out, loss_per_module, self.metrics.deactivate_tensor_mode() + + @override(Learner) + def compute_gradients( + self, loss_per_module: Dict[ModuleID, TensorType], **kwargs + ) -> ParamDict: + for optim in self._optimizer_parameters: + # `set_to_none=True` is a faster way to zero out the gradients. + optim.zero_grad(set_to_none=True) + + if self._grad_scalers is not None: + total_loss = sum( + self._grad_scalers[mid].scale(loss) + for mid, loss in loss_per_module.items() + ) + else: + total_loss = sum(loss_per_module.values()) + + total_loss.backward() + grads = {pid: p.grad for pid, p in self._params.items()} + + return grads + + @override(Learner) + def apply_gradients(self, gradients_dict: ParamDict) -> None: + # Set the gradient of the parameters. + for pid, grad in gradients_dict.items(): + # If updates should not be skipped turn `nan` and `inf` gradients to zero. + if ( + not torch.isfinite(grad).all() + and not self.config.torch_skip_nan_gradients + ): + # Warn the user about `nan` gradients. + logger.warning(f"Gradients {pid} contain `nan/inf` values.") + # If updates should be skipped, do not step the optimizer and return. + if not self.config.torch_skip_nan_gradients: + logger.warning( + "Setting `nan/inf` gradients to zero. If updates with " + "`nan/inf` gradients should not be set to zero and instead " + "the update be skipped entirely set `torch_skip_nan_gradients` " + "to `True`." + ) + # If necessary turn `nan` gradients to zero. Note this can corrupt the + # internal state of the optimizer, if many `nan` gradients occur. 
+ self._params[pid].grad = torch.nan_to_num(grad) + # Otherwise, use the gradient as is. + else: + self._params[pid].grad = grad + + # For each optimizer call its step function. + for module_id, optimizer_names in self._module_optimizers.items(): + for optimizer_name in optimizer_names: + optim = self.get_optimizer(module_id, optimizer_name) + # If we have learning rate schedulers for a module add them, if + # necessary. + if self._lr_scheduler_classes is not None: + if ( + module_id not in self._lr_schedulers + or optimizer_name not in self._lr_schedulers[module_id] + ): + # Set for each module and optimizer a scheduler. + self._lr_schedulers[module_id] = {optimizer_name: []} + # If the classes are in a dictionary each module might have + # a different set of schedulers. + if isinstance(self._lr_scheduler_classes, dict): + scheduler_classes = self._lr_scheduler_classes[module_id] + # Else, each module has the same learning rate schedulers. + else: + scheduler_classes = self._lr_scheduler_classes + # Initialize and add the schedulers. + for scheduler_class in scheduler_classes: + self._lr_schedulers[module_id][optimizer_name].append( + scheduler_class(optim) + ) + + # Step through the scaler (unscales gradients, if applicable). + if self._grad_scalers is not None: + scaler = self._grad_scalers[module_id] + scaler.step(optim) + self.metrics.log_value( + (module_id, "_torch_grad_scaler_current_scale"), + scaler.get_scale(), + window=1, # snapshot in time, no EMA/mean. + ) + # Update the scaler. + scaler.update() + # `step` the optimizer (default), but only if all gradients are finite. + elif all( + param.grad is None or torch.isfinite(param.grad).all() + for group in optim.param_groups + for param in group["params"] + ): + optim.step() + # If gradients are not all finite warn the user that the update will be + # skipped. 
+ elif not all( + torch.isfinite(param.grad).all() + for group in optim.param_groups + for param in group["params"] + ): + logger.warning( + "Skipping this update. If updates with `nan/inf` gradients " + "should not be skipped entirely and instead `nan/inf` " + "gradients set to `zero` set `torch_skip_nan_gradients` to " + "`False`." + ) + + @OverrideToImplementCustomLogic_CallToSuperRecommended + @override(Learner) + def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None: + """Called after gradient-based updates are completed. + + Should be overridden to implement custom cleanup-, logging-, or non-gradient- + based Learner/RLModule update logic after(!) gradient-based updates have been + completed. + + Note, for `framework="torch"` users can register + `torch.optim.lr_scheduler.LRScheduler` via + `AlgorithmConfig._torch_lr_scheduler_classes`. These schedulers need to be + stepped here after gradient updates and reported. + + Args: + timesteps: Timesteps dict, which must have the key + `NUM_ENV_STEPS_SAMPLED_LIFETIME`. + # TODO (sven): Make this a more formal structure with its own type. + """ + + # If we have no `torch.optim.lr_scheduler.LRScheduler` registered call the + # `super()`'s method to update RLlib's learning rate schedules. + if not self._lr_schedulers: + return super().after_gradient_based_update(timesteps=timesteps) + + # Only update this optimizer's lr, if a scheduler has been registered + # along with it. + for module_id, optimizer_names in self._module_optimizers.items(): + for optimizer_name in optimizer_names: + # If learning rate schedulers are provided step them here. Note, + # stepping them in `TorchLearner.apply_gradients` updates the + # learning rates during minibatch updates; we want to update + # between whole batch updates. 
+ if ( + module_id in self._lr_schedulers + and optimizer_name in self._lr_schedulers[module_id] + ): + for scheduler in self._lr_schedulers[module_id][optimizer_name]: + scheduler.step() + optimizer = self.get_optimizer(module_id, optimizer_name) + self.metrics.log_value( + # Cut out the module ID from the beginning since it's already + # part of the key sequence: (ModuleID, "[optim name]_lr"). + key=( + module_id, + f"{optimizer_name[len(module_id) + 1:]}_{LR_KEY}", + ), + value=convert_to_numpy(self._get_optimizer_lr(optimizer)), + window=1, + ) + + @override(Learner) + def _get_optimizer_state(self) -> StateDict: + ret = {} + for name, optim in self._named_optimizers.items(): + ret[name] = { + "module_id": self._optimizer_name_to_module[name], + "state": convert_to_numpy(optim.state_dict()), + } + return ret + + @override(Learner) + def _set_optimizer_state(self, state: StateDict) -> None: + for name, state_dict in state.items(): + # Ignore updating optimizers matching to submodules not present in this + # Learner's MultiRLModule. + module_id = state_dict["module_id"] + if name not in self._named_optimizers and module_id in self.module: + self.configure_optimizers_for_module( + module_id=module_id, + config=self.config.get_config_for_module(module_id=module_id), + ) + if name in self._named_optimizers: + self._named_optimizers[name].load_state_dict( + convert_to_torch_tensor(state_dict["state"], device=self._device) + ) + + @override(Learner) + def get_param_ref(self, param: Param) -> Hashable: + return param + + @override(Learner) + def get_parameters(self, module: RLModule) -> Sequence[Param]: + return list(module.parameters()) + + @override(Learner) + def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch: + batch = convert_to_torch_tensor(batch.policy_batches, device=self._device) + # TODO (sven): This computation of `env_steps` is not accurate! 
+ length = max(len(b) for b in batch.values()) + batch = MultiAgentBatch(batch, env_steps=length) + return batch + + @override(Learner) + def add_module( + self, + *, + module_id: ModuleID, + # TODO (sven): Rename to `rl_module_spec`. + module_spec: RLModuleSpec, + config_overrides: Optional[Dict] = None, + new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None, + ) -> MultiRLModuleSpec: + # Call super's add_module method. + marl_spec = super().add_module( + module_id=module_id, + module_spec=module_spec, + config_overrides=config_overrides, + new_should_module_be_updated=new_should_module_be_updated, + ) + + # we need to ddpify the module that was just added to the pool + module = self._module[module_id] + + if self._torch_compile_forward_train: + module.compile(self._torch_compile_cfg) + elif self._torch_compile_complete_update: + # When compiling the update, we need to reset and recompile + # _uncompiled_update every time we add/remove a module anew. + torch._dynamo.reset() + self._compiled_update_initialized = False + self._possibly_compiled_update = torch.compile( + self._uncompiled_update, + backend=self._torch_compile_cfg.torch_dynamo_backend, + mode=self._torch_compile_cfg.torch_dynamo_mode, + **self._torch_compile_cfg.kwargs, + ) + + if isinstance(module, TorchRLModule): + self._module[module_id].to(self._device) + if self.distributed: + if ( + self._torch_compile_complete_update + or self._torch_compile_forward_train + ): + raise ValueError( + "Using torch distributed and torch compile " + "together tested for now. Please disable " + "torch compile." 
+ ) + self._module.add_module( + module_id, + TorchDDPRLModule(module, **self.config.torch_ddp_kwargs), + override=True, + ) + + self._log_trainable_parameters() + + return marl_spec + + @override(Learner) + def remove_module(self, module_id: ModuleID, **kwargs) -> MultiRLModuleSpec: + marl_spec = super().remove_module(module_id, **kwargs) + + if self._torch_compile_complete_update: + # When compiling the update, we need to reset and recompile + # _uncompiled_update every time we add/remove a module anew. + torch._dynamo.reset() + self._compiled_update_initialized = False + self._possibly_compiled_update = torch.compile( + self._uncompiled_update, + backend=self._torch_compile_cfg.torch_dynamo_backend, + mode=self._torch_compile_cfg.torch_dynamo_mode, + **self._torch_compile_cfg.kwargs, + ) + + self._log_trainable_parameters() + + return marl_spec + + @override(Learner) + def build(self) -> None: + """Builds the TorchLearner. + + This method is specific to TorchLearner. Before running super() it will + initialize the device properly based on `self.config`, so that `_make_module()` + can place the created module on the correct device. After running super() it + wraps the module in a TorchDDPRLModule if `config.num_learners > 0`. + Note, in inherited classes it is advisable to call the parent's `build()` + after setting up all variables because `configure_optimizer_for_module` is + called in this `Learner.build()`. 
+ """ + self._device = get_device(self.config, self.config.num_gpus_per_learner) + + super().build() + + if self._torch_compile_complete_update: + torch._dynamo.reset() + self._compiled_update_initialized = False + self._possibly_compiled_update = torch.compile( + self._uncompiled_update, + backend=self._torch_compile_cfg.torch_dynamo_backend, + mode=self._torch_compile_cfg.torch_dynamo_mode, + **self._torch_compile_cfg.kwargs, + ) + else: + if self._torch_compile_forward_train: + if isinstance(self._module, TorchRLModule): + self._module.compile(self._torch_compile_cfg) + elif isinstance(self._module, MultiRLModule): + for module in self._module._rl_modules.values(): + # Compile only TorchRLModules, e.g. we don't want to compile + # a RandomRLModule. + if isinstance(self._module, TorchRLModule): + module.compile(self._torch_compile_cfg) + else: + raise ValueError( + "Torch compile is only supported for TorchRLModule and " + "MultiRLModule." + ) + + self._possibly_compiled_update = self._uncompiled_update + + self._make_modules_ddp_if_necessary() + + @override(Learner) + def _update(self, batch: Dict[str, Any]) -> Tuple[Any, Any, Any]: + # The first time we call _update after building the learner or + # adding/removing models, we update with the uncompiled update method. + # This makes it so that any variables that may be created during the first + # update step are already there when compiling. More specifically, + # this avoids errors that occur around using defaultdicts with + # torch.compile(). 
+ if ( + self._torch_compile_complete_update + and not self._compiled_update_initialized + ): + self._compiled_update_initialized = True + return self._uncompiled_update(batch) + else: + return self._possibly_compiled_update(batch) + + @OverrideToImplementCustomLogic + def _make_modules_ddp_if_necessary(self) -> None: + """Default logic for (maybe) making all Modules within self._module DDP.""" + + # If the module is a MultiRLModule and nn.Module we can simply assume + # all the submodules are registered. Otherwise, we need to loop through + # each submodule and move it to the correct device. + # TODO (Kourosh): This can result in missing modules if the user does not + # register them in the MultiRLModule. We should find a better way to + # handle this. + if self.config.num_learners > 1: + # Single agent module: Convert to `TorchDDPRLModule`. + if isinstance(self._module, TorchRLModule): + self._module = TorchDDPRLModule( + self._module, **self.config.torch_ddp_kwargs + ) + # Multi agent module: Convert each submodule to `TorchDDPRLModule`. + else: + assert isinstance(self._module, MultiRLModule) + for key in self._module.keys(): + sub_module = self._module[key] + if isinstance(sub_module, TorchRLModule): + # Wrap and override the module ID key in self._module. + self._module.add_module( + key, + TorchDDPRLModule( + sub_module, **self.config.torch_ddp_kwargs + ), + override=True, + ) + + def rl_module_is_compatible(self, module: RLModule) -> bool: + return isinstance(module, nn.Module) + + @override(Learner) + def _check_registered_optimizer( + self, + optimizer: Optimizer, + params: Sequence[Param], + ) -> None: + super()._check_registered_optimizer(optimizer, params) + if not isinstance(optimizer, torch.optim.Optimizer): + raise ValueError( + f"The optimizer ({optimizer}) is not a torch.optim.Optimizer! " + "Only use torch.optim.Optimizer subclasses for TorchLearner." 
+ ) + for param in params: + if not isinstance(param, torch.Tensor): + raise ValueError( + f"One of the parameters ({param}) in the registered optimizer " + "is not a torch.Tensor!" + ) + + @override(Learner) + def _make_module(self) -> MultiRLModule: + module = super()._make_module() + self._map_module_to_device(module) + return module + + def _map_module_to_device(self, module: MultiRLModule) -> None: + """Moves the module to the correct device.""" + if isinstance(module, torch.nn.Module): + module.to(self._device) + else: + for key in module.keys(): + if isinstance(module[key], torch.nn.Module): + module[key].to(self._device) + + @override(Learner) + def _log_trainable_parameters(self) -> None: + # Log number of non-trainable and trainable parameters of our RLModule. + num_trainable_params = { + (mid, NUM_TRAINABLE_PARAMETERS): sum( + p.numel() for p in rlm.parameters() if p.requires_grad + ) + for mid, rlm in self.module._rl_modules.items() + if isinstance(rlm, TorchRLModule) + } + num_non_trainable_params = { + (mid, NUM_NON_TRAINABLE_PARAMETERS): sum( + p.numel() for p in rlm.parameters() if not p.requires_grad + ) + for mid, rlm in self.module._rl_modules.items() + if isinstance(rlm, TorchRLModule) + } + + self.metrics.log_dict( + { + **{ + (ALL_MODULES, NUM_TRAINABLE_PARAMETERS): sum( + num_trainable_params.values() + ), + (ALL_MODULES, NUM_NON_TRAINABLE_PARAMETERS): sum( + num_non_trainable_params.values() + ), + }, + **num_trainable_params, + **num_non_trainable_params, + } + ) + + def _compute_off_policyness(self, batch): + # Log off-policy'ness of this batch wrt the current weights. 
+ off_policyness = { + (mid, DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY): ( + (self._weights_seq_no - module_batch[WEIGHTS_SEQ_NO]).float() + ) + for mid, module_batch in batch.items() + if WEIGHTS_SEQ_NO in module_batch + } + for key in off_policyness.keys(): + mid = key[0] + if Columns.LOSS_MASK not in batch[mid]: + off_policyness[key] = torch.mean(off_policyness[key]) + else: + mask = batch[mid][Columns.LOSS_MASK] + num_valid = torch.sum(mask) + off_policyness[key] = torch.sum(off_policyness[key][mask]) / num_valid + self.metrics.log_dict(off_policyness, window=1) + + @override(Learner) + def _get_tensor_variable( + self, value, dtype=None, trainable=False + ) -> "torch.Tensor": + tensor = torch.tensor( + value, + requires_grad=trainable, + device=self._device, + dtype=( + dtype + or ( + torch.float32 + if isinstance(value, float) + else torch.int32 + if isinstance(value, int) + else None + ) + ), + ) + return nn.Parameter(tensor) if trainable else tensor + + @staticmethod + @override(Learner) + def _get_optimizer_lr(optimizer: "torch.optim.Optimizer") -> float: + for g in optimizer.param_groups: + return g["lr"] + + @staticmethod + @override(Learner) + def _set_optimizer_lr(optimizer: "torch.optim.Optimizer", lr: float) -> None: + for g in optimizer.param_groups: + g["lr"] = lr + + @staticmethod + @override(Learner) + def _get_clip_function() -> Callable: + from ray.rllib.utils.torch_utils import clip_gradients + + return clip_gradients + + @staticmethod + @override(Learner) + def _get_global_norm_function() -> Callable: + from ray.rllib.utils.torch_utils import compute_global_norm + + return compute_global_norm diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/utils.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7682725cf9a222617968b1b3e578fd36cc2f2595 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/learner/utils.py @@ -0,0 
+1,59 @@ +import copy + +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import NetworkType +from ray.util import PublicAPI + + +torch, _ = try_import_torch() + + +def make_target_network(main_net: NetworkType) -> NetworkType: + """Creates a (deep) copy of `main_net` (including synched weights) and returns it. + + Args: + main_net: The main network to return a target network for + + Returns: + The copy of `main_net` that can be used as a target net. Note that the weights + of the returned net are already synched (identical) with `main_net`. + """ + # Deepcopy the main net (this should already take care of synching all weights). + target_net = copy.deepcopy(main_net) + # Make the target net not trainable. + if isinstance(main_net, torch.nn.Module): + target_net.requires_grad_(False) + else: + raise ValueError(f"Unsupported framework for given `main_net` {main_net}!") + + return target_net + + +@PublicAPI(stability="beta") +def update_target_network( + *, + main_net: NetworkType, + target_net: NetworkType, + tau: float, +) -> None: + """Updates a target network (from a "main" network) using Polyak averaging. + + Thereby: + new_target_net_weight = ( + tau * main_net_weight + (1.0 - tau) * current_target_net_weight + ) + + Args: + main_net: The nn.Module to update from. + target_net: The target network to update. + tau: The tau value to use in the Polyak averaging formula. Use 1.0 for a + complete sync of the weights (target and main net will be the exact same + after updating). 
+ """ + if isinstance(main_net, torch.nn.Module): + from ray.rllib.utils.torch_utils import update_target_network as _update_target + + else: + raise ValueError(f"Unsupported framework for given `main_net` {main_net}!") + + _update_target(main_net=main_net, target_net=target_net, tau=tau) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a62a14504bd132eef48dc0c933aff99635f3f04f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/heads.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/heads.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6afe7bb4994f3cd6140c85f77aea7e0b4ba0f1ac Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/heads.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..992bf08a2b9328c587682c27907c5d70afc14768 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/models/torch/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..490cd7942947d8cc14625fe069893f8347697b12 --- /dev/null +++ 
b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__init__.py @@ -0,0 +1,53 @@ +import logging +import re + +from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import ( + MultiRLModule, + MultiRLModuleSpec, +) +from ray.util import log_once +from ray.util.annotations import DeveloperAPI + +logger = logging.getLogger("ray.rllib") + + +@DeveloperAPI +def validate_module_id(policy_id: str, error: bool = False) -> None: + """Makes sure the given `policy_id` is valid. + + Args: + policy_id: The Policy ID to check. + IMPORTANT: Must not contain characters that + are also not allowed in Unix/Win filesystems, such as: `<>:"/\\|?*` + or a dot `.` or space ` ` at the end of the ID. + error: Whether to raise an error (ValueError) or a warning in case of an + invalid `policy_id`. + + Raises: + ValueError: If the given `policy_id` is not a valid one and `error` is True. + """ + if ( + not isinstance(policy_id, str) + or len(policy_id) == 0 + or re.search('[<>:"/\\\\|?]', policy_id) + or policy_id[-1] in (" ", ".") + ): + msg = ( + f"PolicyID `{policy_id}` not valid! IDs must be a non-empty string, " + "must not contain characters that are also disallowed file- or directory " + "names on Unix/Windows and must not end with a dot `.` or a space ` `." 
+ ) + if error: + raise ValueError(msg) + elif log_once("invalid_policy_id"): + logger.warning(msg) + + +__all__ = [ + "MultiRLModule", + "MultiRLModuleSpec", + "RLModule", + "RLModuleSpec", + "validate_module_id", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfae3cd81c0f67197dbedbad97e35e77f012199d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/default_model_config.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/default_model_config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a61abc90d208c90e79a5950f7f832946e77425f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/default_model_config.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/multi_rl_module.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/multi_rl_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6eb6ff2d294d58ddf5831c74ce34a2cc1130e93 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/multi_rl_module.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/rl_module.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/rl_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..539321cb6c2f22fbcf47e4d4aec52f14001bc455 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/__pycache__/rl_module.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e51e91a1b11e2c0d37a66d8cd7ff904ef2fcf6e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__init__.py @@ -0,0 +1,18 @@ +from ray.rllib.core.rl_module.apis.inference_only_api import InferenceOnlyAPI +from ray.rllib.core.rl_module.apis.q_net_api import QNetAPI +from ray.rllib.core.rl_module.apis.self_supervised_loss_api import SelfSupervisedLossAPI +from ray.rllib.core.rl_module.apis.target_network_api import ( + TargetNetworkAPI, + TARGET_NETWORK_ACTION_DIST_INPUTS, +) +from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI + + +__all__ = [ + "InferenceOnlyAPI", + "QNetAPI", + "SelfSupervisedLossAPI", + "TargetNetworkAPI", + "TARGET_NETWORK_ACTION_DIST_INPUTS", + "ValueFunctionAPI", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7eb56c7730163bf4bdc0cb73e029deddb4f7f2c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__pycache__/inference_only_api.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__pycache__/inference_only_api.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95e3db1112a48a04fdf05cd3f3405cd274df7e4c Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__pycache__/inference_only_api.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__pycache__/q_net_api.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__pycache__/q_net_api.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c6eaf123f38d23c1ceffc4214d0f9d79d156964 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__pycache__/q_net_api.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__pycache__/self_supervised_loss_api.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__pycache__/self_supervised_loss_api.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90f2224cfd4f766a5d7b3a644a8dfcaebeee3dc1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__pycache__/self_supervised_loss_api.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__pycache__/target_network_api.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__pycache__/target_network_api.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20d14072b6130d540c4a5c5583d1b2317fa388af Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__pycache__/target_network_api.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__pycache__/value_function_api.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__pycache__/value_function_api.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4325e766df33e1bb53c07fe7df0bac8ddd6f6933 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/__pycache__/value_function_api.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/inference_only_api.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/inference_only_api.py new file mode 100644 index 0000000000000000000000000000000000000000..34bed76781d25bca38cce7802409f93c7e62f2f2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/inference_only_api.py @@ -0,0 +1,65 @@ +import abc +from typing import List + +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class InferenceOnlyAPI(abc.ABC): + """An API to be implemented by RLModules that have an inference-only mode. + + Only the `get_non_inference_attributes` method needs to get implemented for + an RLModule to have the following functionality: + - On EnvRunners (or when self.inference_only=True), RLlib will remove + those parts of the model not required for action computation. + - An RLModule on a Learner (where `self.inference_only=False`) will + return only those weights from `get_state()` that are part of its inference-only + version, thus possibly saving network traffic/time. + """ + + @abc.abstractmethod + def get_non_inference_attributes(self) -> List[str]: + """Returns a list of attribute names (str) of components NOT used for inference. + + RLlib will use this information to remove those attributes/components from an + RLModule, whose `config.inference_only` is set to True. This so-called + "inference-only setup" is activated. Normally, all RLModules located on + EnvRunners are constructed this way (because they are only used for computing + actions). Similarly, when deployed into a production environment, users should + consider building their RLModules with this flag set to True as well. + + For example: + + .. 
testcode:: + :skipif: True + + from ray.rllib.core.rl_module.rl_module import RLModuleSpec + + spec = RLModuleSpec(module_class=..., inference_only=True) + + If an RLModule has the following `setup()` implementation: + + .. testcode:: + :skipif: True + + class MyRLModule(RLModule): + + def setup(self): + self._policy_head = [some NN component] + self._value_function_head = [some NN component] + + self._encoder = [some NN component with attributes: `pol` and `vf` + (policy- and value func. encoder)] + + Then its `get_non_inference_attributes()` should return: + `["_value_function_head", "_encoder.vf"]` + + Note the "." notation to separate attributes and their sub-attributes in case + you need more fine-grained control over which exact sub-attributes to exclude in + the inference-only setup. + + Returns: + A list of names (str) of those attributes (or sub-attributes) that should be + excluded (deleted) from this RLModule in case it's setup in + `inference_only` mode. + """ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/q_net_api.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/q_net_api.py new file mode 100644 index 0000000000000000000000000000000000000000..88fb1a00c0fe7c27e4ca72b594d6f548940fe56f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/q_net_api.py @@ -0,0 +1,56 @@ +import abc +from typing import Dict + +from ray.rllib.utils.typing import TensorType +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class QNetAPI(abc.ABC): + """An API to be implemented by RLModules used for (distributional) Q-learning. + + RLModules implementing this API must override the `compute_q_values` and the + `compute_advantage_distribution` methods. + """ + + @abc.abstractmethod + def compute_q_values( + self, + batch: Dict[str, TensorType], + ) -> Dict[str, TensorType]: + """Computes Q-values, given encoder, q-net and (optionally), advantage net. 
+ + Note, these can be accompanied by logits and probabilities + in case of distributional Q-learning, i.e. `self.num_atoms > 1`. + + Args: + batch: The batch received in the forward pass. + + Results: + A dictionary containing the Q-value predictions ("qf_preds") + and in case of distributional Q-learning - in addition to the Q-value + predictions ("qf_preds") - the support atoms ("atoms"), the Q-logits + ("qf_logits"), and the probabilities ("qf_probs"). + """ + + def compute_advantage_distribution( + self, + batch: Dict[str, TensorType], + ) -> Dict[str, TensorType]: + """Computes the advantage distribution. + + Note this distribution is identical to the Q-distribution in case no dueling + architecture is used. + + Args: + batch: A dictionary containing a tensor with the outputs of the + forward pass of the Q-head or advantage stream head. + + Returns: + A `dict` containing the support of the discrete distribution for + either Q-values or advantages (in case of a dueling architecture), + ("atoms"), the logits per action and atom and the probabilities + of the discrete distribution (per action and atom of the support). + """ + # Return the Q-distribution by default. 
+ return self.compute_q_values(batch) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/self_supervised_loss_api.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/self_supervised_loss_api.py new file mode 100644 index 0000000000000000000000000000000000000000..713dc68929ad75fffa8d7343f7979571fdab2bfe --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/self_supervised_loss_api.py @@ -0,0 +1,54 @@ +import abc +from typing import Any, Dict, TYPE_CHECKING + +from ray.rllib.utils.typing import ModuleID, TensorType +from ray.util.annotations import PublicAPI + +if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + from ray.rllib.core.learner.learner import Learner + + +@PublicAPI(stability="alpha") +class SelfSupervisedLossAPI(abc.ABC): + """An API to be implemented by RLModules that bring their own self-supervised loss. + + Learners will call these model's `compute_self_supervised_loss()` method instead of + the Learner's own `compute_loss_for_module()` method. + The call signature is identical to the Learner's `compute_loss_for_module()` method + except of an additional mandatory `learner` kwarg. + """ + + @abc.abstractmethod + def compute_self_supervised_loss( + self, + *, + learner: "Learner", + module_id: ModuleID, + config: "AlgorithmConfig", + batch: Dict[str, Any], + fwd_out: Dict[str, TensorType], + ) -> TensorType: + """Computes the loss for a single module. + + Think of this as computing loss for a single agent. For multi-agent use-cases + that require more complicated computation for loss, consider overriding the + `compute_losses` method instead. + + Args: + learner: The Learner calling this loss method on the RLModule. + module_id: The ID of the RLModule (within a MultiRLModule). + config: The AlgorithmConfig specific to the given `module_id`. + batch: The sample batch for this particular RLModule. 
+ fwd_out: The output of the forward pass for this particular RLModule. + + Returns: + A single total loss tensor. If you have more than one optimizer on the + provided `module_id` and would like to compute gradients separately using + these different optimizers, simply add up the individual loss terms for + each optimizer and return the sum. Also, for recording/logging any + individual loss terms, you can use the `Learner.metrics.log_value( + key=..., value=...)` or `Learner.metrics.log_dict()` APIs. See: + :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` for more + information. + """ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/target_network_api.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/target_network_api.py new file mode 100644 index 0000000000000000000000000000000000000000..d1615edff1e039236a67bb841e495289fbbdbeb1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/target_network_api.py @@ -0,0 +1,57 @@ +import abc +from typing import Any, Dict, List, Tuple + +from ray.rllib.utils.typing import NetworkType +from ray.util.annotations import PublicAPI + + +TARGET_NETWORK_ACTION_DIST_INPUTS = "target_network_action_dist_inputs" + + +@PublicAPI(stability="alpha") +class TargetNetworkAPI(abc.ABC): + """An API to be implemented by RLModules for handling target networks. + + RLModules implementing this API must override the `make_target_networks`, + `get_target_network_pairs`, and the `forward_target` methods. + + Note that the respective Learner that owns the implementing RLModule handles all + target syncing logic. + """ + + @abc.abstractmethod + def make_target_networks(self) -> None: + """Creates the required target nets for this RLModule. + + Use the convenience `ray.rllib.core.learner.utils.make_target_network()` utility + when implementing this method. Pass in an already existing, corresponding "main" + net (for which you need a target net). 
+ This function already takes care of initialization (from the "main" net). + """ + + @abc.abstractmethod + def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]: + """Returns a list of 2-tuples of (main_net, target_net). + + For example, if your RLModule has a property: `self.q_net` and this network + has a corresponding target net `self.target_q_net`, return from this + (overridden) method: [(self.q_net, self.target_q_net)]. + + Note that you need to create all target nets in your overridden + `make_target_networks` method and store the target nets in any properly of your + choice. + + Returns: + A list of 2-tuples of (main_net, target_net) + """ + + @abc.abstractmethod + def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]: + """Performs the forward pass through the target net(s). + + Args: + batch: The batch to use for the forward pass. + + Returns: + The results from the forward pass(es) through the target net(s). + """ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/value_function_api.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/value_function_api.py new file mode 100644 index 0000000000000000000000000000000000000000..46e33d2fd315fe75e555e1629d4a1e4a0e44324b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/apis/value_function_api.py @@ -0,0 +1,35 @@ +import abc +from typing import Any, Dict, Optional + +from ray.rllib.utils.typing import TensorType +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class ValueFunctionAPI(abc.ABC): + """An API to be implemented by RLModules for handling value function-based learning. + + RLModules implementing this API must override the `compute_values` method. + """ + + @abc.abstractmethod + def compute_values( + self, + batch: Dict[str, Any], + embeddings: Optional[Any] = None, + ) -> TensorType: + """Computes the value estimates given `batch`. 
+ + Args: + batch: The batch to compute value function estimates for. + embeddings: Optional embeddings already computed from the `batch` (by + another forward pass through the model's encoder (or other subcomponent + that computes an embedding). For example, the caller of thie method + should provide `embeddings` - if available - to avoid duplicate passes + through a shared encoder. + + Returns: + A tensor of shape (B,) or (B, T) (in case the input `batch` has a + time dimension. Note that the last value dimension should already be + squeezed out (not 1!). + """ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/default_model_config.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/default_model_config.py new file mode 100644 index 0000000000000000000000000000000000000000..e53b852a4e7b5ec3c25a2b713e8779d8cdec834e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/default_model_config.py @@ -0,0 +1,204 @@ +from dataclasses import dataclass, field +from typing import Callable, List, Optional, Union + +from ray.rllib.utils.typing import ConvFilterSpec +from ray.util.annotations import DeveloperAPI + + +@DeveloperAPI +@dataclass +class DefaultModelConfig: + """Dataclass to configure all default RLlib RLModules. + + Users should NOT use this class for configuring their own custom RLModules, but + use a custom `model_config` dict with arbitrary (str) keys passed into the + `RLModuleSpec` used to define the custom RLModule. + For example: + + .. 
testcode:: + + import gymnasium as gym + import numpy as np + from ray.rllib.core.rl_module.rl_module import RLModuleSpec + from ray.rllib.examples.rl_modules.classes.tiny_atari_cnn_rlm import ( + TinyAtariCNN + ) + + my_rl_module = RLModuleSpec( + module_class=TinyAtariCNN, + observation_space=gym.spaces.Box(-1.0, 1.0, (64, 64, 4), np.float32), + action_space=gym.spaces.Discrete(7), + # DreamerV3-style stack working on a 64x64, color or 4x-grayscale-stacked, + # normalized image. + model_config={ + "conv_filters": [[16, 4, 2], [32, 4, 2], [64, 4, 2], [128, 4, 2]], + }, + ).build() + + Only RLlib's default RLModules (defined by the various algorithms) should use + this dataclass. Pass an instance of it into your algorithm config like so: + + .. testcode:: + + from ray.rllib.algorithms.ppo import PPOConfig + from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig + + config = ( + PPOConfig() + .rl_module( + model_config=DefaultModelConfig(fcnet_hiddens=[32, 32]), + ) + ) + """ + + # ==================================================== + # MLP stacks + # ==================================================== + # __sphinx_doc_default_model_config_fcnet_begin__ + #: List containing the sizes (number of nodes) of a fully connected (MLP) stack. + #: Note that in an encoder-based default architecture with a policy head (and + #: possible value head), this setting only affects the encoder component. To set the + #: policy (and value) head sizes, use `post_fcnet_hiddens`, instead. For example, + #: if you set `fcnet_hiddens=[32, 32]` and `post_fcnet_hiddens=[64]`, you would get + #: an RLModule with a [32, 32] encoder, a [64, act-dim] policy head, and a [64, 1] + #: value head (if applicable). + fcnet_hiddens: List[int] = field(default_factory=lambda: [256, 256]) + #: Activation function descriptor for the stack configured by `fcnet_hiddens`. + #: Supported values are: 'tanh', 'relu', 'swish' (or 'silu', which is the same), + #: and 'linear' (or None). 
+ fcnet_activation: str = "tanh" + #: Initializer function or class descriptor for the weight/kernel matrices in the + #: stack configured by `fcnet_hiddens`. Supported values are the initializer names + #: (str), classes or functions listed by the frameworks (`torch`). See + #: https://pytorch.org/docs/stable/nn.init.html for `torch`. If `None` (default), + #: the default initializer defined by `torch` is used. + fcnet_kernel_initializer: Optional[Union[str, Callable]] = None + #: Kwargs passed into the initializer function defined through + #: `fcnet_kernel_initializer`. + fcnet_kernel_initializer_kwargs: Optional[dict] = None + #: Initializer function or class descriptor for the bias vectors in the stack + #: configured by `fcnet_hiddens`. Supported values are the initializer names (str), + #: classes or functions listed by the frameworks (`torch`). See + #: https://pytorch.org/docs/stable/nn.init.html for `torch`. If `None` (default), + #: the default initializer defined by `torch` is used. + fcnet_bias_initializer: Optional[Union[str, Callable]] = None + #: Kwargs passed into the initializer function defined through + #: `fcnet_bias_initializer`. + fcnet_bias_initializer_kwargs: Optional[dict] = None + # __sphinx_doc_default_model_config_fcnet_end__ + + # ==================================================== + # Conv2D stacks + # ==================================================== + # __sphinx_doc_default_model_config_conv_begin__ + #: List of lists of format [num_out_channels, kernel, stride] defining a Conv2D + #: stack if the input space is 2D. Each item in the outer list represents one Conv2D + #: layer. `kernel` and `stride` may be single ints (width and height have same + #: value) or 2-tuples (int, int) specifying width and height dimensions separately. + #: If None (default) and the input space is 2D, RLlib tries to find a default filter + #: setup given the exact input dimensions. 
+ conv_filters: Optional[ConvFilterSpec] = None + #: Activation function descriptor for the stack configured by `conv_filters`. + #: Supported values are: 'tanh', 'relu', 'swish' (or 'silu', which is the same), and + #: 'linear' (or None). + conv_activation: str = "relu" + #: Initializer function or class descriptor for the weight/kernel matrices in the + #: stack configured by `conv_filters`. Supported values are the initializer names + #: (str), classes or functions listed by the frameworks (`torch`). See + #: https://pytorch.org/docs/stable/nn.init.html for `torch`. If `None` (default), + #: the default initializer defined by `torch` is used. + conv_kernel_initializer: Optional[Union[str, Callable]] = None + #: Kwargs passed into the initializer function defined through + #: `conv_kernel_initializer`. + conv_kernel_initializer_kwargs: Optional[dict] = None + #: Initializer function or class descriptor for the bias vectors in the stack + #: configured by `conv_filters`. Supported values are the initializer names (str), + #: classes or functions listed by the frameworks (`torch`). See + #: https://pytorch.org/docs/stable/nn.init.html for `torch`. If `None` (default), + #: the default initializer defined by `torch` is used. + conv_bias_initializer: Optional[Union[str, Callable]] = None + #: Kwargs passed into the initializer function defined through + #: `conv_bias_initializer`. + conv_bias_initializer_kwargs: Optional[dict] = None + # __sphinx_doc_default_model_config_conv_end__ + + # ==================================================== + # Head configs (e.g. policy- or value function heads) + # ==================================================== + #: List containing the sizes (number of nodes) of a fully connected (MLP) head (ex. + #: policy-, value-, or Q-head). Note that in order to configure the encoder + #: architecture, use `fcnet_hiddens`, instead. 
+ head_fcnet_hiddens: List[int] = field(default_factory=lambda: []) + #: Activation function descriptor for the stack configured by `head_fcnet_hiddens`. + #: Supported values are: 'tanh', 'relu', 'swish' (or 'silu', which is the same), + #: and 'linear' (or None). + head_fcnet_activation: str = "relu" + #: Initializer function or class descriptor for the weight/kernel matrices in the + #: stack configured by `head_fcnet_hiddens`. Supported values are the initializer + #: names (str), classes or functions listed by the frameworks (`torch`). See + #: https://pytorch.org/docs/stable/nn.init.html for `torch`. If `None` (default), + #: the default initializer defined by `torch` is used. + head_fcnet_kernel_initializer: Optional[Union[str, Callable]] = None + #: Kwargs passed into the initializer function defined through + #: `head_fcnet_kernel_initializer`. + head_fcnet_kernel_initializer_kwargs: Optional[dict] = None + #: Initializer function or class descriptor for the bias vectors in the stack + #: configured by `head_fcnet_hiddens`. Supported values are the initializer names + #: (str), classes or functions listed by the frameworks (`torch`). See + #: https://pytorch.org/docs/stable/nn.init.html for `torch`. If `None` (default), + #: the default initializer defined by `torch` is used. + head_fcnet_bias_initializer: Optional[Union[str, Callable]] = None + #: Kwargs passed into the initializer function defined through + #: `head_fcnet_bias_initializer`. + head_fcnet_bias_initializer_kwargs: Optional[dict] = None + + # ==================================================== + # Continuous action settings + # ==================================================== + #: If True, for DiagGaussian action distributions (or any other continuous control + #: distribution), make the second half of the policy's outputs a "free" bias + #: parameter, rather than state-/NN-dependent nodes. 
In this case, the number of + #: nodes of the policy head have the same dimension as the action space as no slots + #: for log(stddev) are required (only for the mean values). + free_log_std: bool = False + #: Whether to clip the log(stddev) when using a DiagGaussian action distribution + #: (or any other continuous control distribution). This can stabilize training and + #: avoid very small or large log(stddev) values leading to numerical instabilities + #: turning outputs to `nan`. The default is to clamp the log(stddev) in between + #: -20 and 20. Set to float("inf") for no clamping. + log_std_clip_param: float = 20.0 + #: Whether encoder layers (defined by `fcnet_hiddens` or `conv_filters`) should be + #: shared between policy- and value function. + vf_share_layers: bool = True + + # ==================================================== + # LSTM settings + # ==================================================== + #: Whether to wrap the encoder component (defined by `fcnet_hiddens` or + #: `conv_filters`) with an LSTM. + use_lstm: bool = False + #: The maximum seq len for building the train batch for an LSTM model. + #: Defaults to 20. + max_seq_len: int = 20 + #: The size of the LSTM cell. + lstm_cell_size: int = 256 + lstm_use_prev_action: bool = False + lstm_use_prev_reward: bool = False + #: Initializer function or class descriptor for the weight/kernel matrices in the + #: LSTM layer. Supported values are the initializer names (str), classes or + #: functions listed by the frameworks (`torch`). See + #: https://pytorch.org/docs/stable/nn.init.html for `torch`. If `None` (default), + #: the default initializer defined by `torch` is used. + lstm_kernel_initializer: Optional[Union[str, Callable]] = None + #: Kwargs passed into the initializer function defined through + #: `lstm_kernel_initializer`. 
+ lstm_kernel_initializer_kwargs: Optional[dict] = None + #: Initializer function or class descriptor for the bias vectors in the stack + #: configured by the LSTM layer. Supported values are the initializer names (str), + #: classes or functions listed by the frameworks (`torch`). See + #: https://pytorch.org/docs/stable/nn.init.html for `torch`. If `None` (default), + #: the default initializer defined by `torch` is used. + lstm_bias_initializer: Optional[Union[str, Callable]] = None + #: Kwargs passed into the initializer function defined through + #: `lstm_bias_initializer`. + lstm_bias_initializer_kwargs: Optional[dict] = None diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/multi_rl_module.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/multi_rl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..49b8097675f5487ecf9aa45716d091ccef2b4080 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/multi_rl_module.py @@ -0,0 +1,850 @@ +import copy +import dataclasses +import logging +import pprint +from typing import ( + Any, + Callable, + Collection, + Dict, + ItemsView, + KeysView, + List, + Optional, + Set, + Tuple, + Type, + Union, + ValuesView, +) + +import gymnasium as gym + +from ray.rllib.core.models.specs.typing import SpecType +from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec +from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic, +) +from ray.rllib.utils.checkpoints import Checkpointable +from ray.rllib.utils.deprecation import ( + Deprecated, + DEPRECATED_VALUE, + deprecation_warning, +) +from ray.rllib.utils.serialization import ( + gym_space_from_dict, + gym_space_to_dict, + serialize_type, + deserialize_type, +) +from ray.rllib.utils.typing import ModuleID, StateDict, T +from ray.util.annotations import PublicAPI + +logger = logging.getLogger("ray.rllib") + + 
+@PublicAPI(stability="alpha") +class MultiRLModule(RLModule): + """Base class for an RLModule that contains n sub-RLModules. + + This class holds a mapping from ModuleID to underlying RLModules. It provides + a convenient way of accessing each individual module, as well as accessing all of + them with only one API call. Whether a given module is trainable is + determined by the caller of this class (not the instance of this class itself). + + The extension of this class can include any arbitrary neural networks as part of + the MultiRLModule. For example, a MultiRLModule can include a shared encoder network + that is used by all the individual (single-agent) RLModules. It is up to the user + to decide how to implement this class. + + The default implementation assumes the data communicated as input and output of + the APIs in this class are `Dict[ModuleID, Dict[str, Any]]` types. The + `MultiRLModule` by default loops through each `module_id`, and runs the forward pass + of the corresponding `RLModule` object with the associated `batch` within the + input. + It also assumes that the underlying RLModules do not share any parameters or + communication with one another. The behavior of modules with such advanced + communication would be undefined by default. To share parameters or communication + between the underlying RLModules, you should implement your own + `MultiRLModule` subclass. + """ + + def __init__( + self, + config=DEPRECATED_VALUE, + *, + observation_space: Optional[gym.Space] = None, + action_space: Optional[gym.Space] = None, + inference_only: Optional[bool] = None, + # TODO (sven): Ignore learner_only setting for now on MultiRLModule. + learner_only: Optional[bool] = None, + model_config: Optional[dict] = None, + rl_module_specs: Optional[Dict[ModuleID, RLModuleSpec]] = None, + **kwargs, + ) -> None: + """Initializes a MultiRLModule instance. + + Args: + observation_space: The MultiRLModule's observation space. 
+ action_space: The MultiRLModule's action space. + inference_only: The MultiRLModule's `inference_only` setting. If True, force + sets all inference_only flags inside `rl_module_specs` also to True. + If None, infers the value for `inference_only` by setting it to True, + iff all `inference_only` flags inside `rl_module_specs`, otherwise to + False. + model_config: The MultiRLModule's `model_config` dict. + rl_module_specs: A dict mapping ModuleIDs to `RLModuleSpec` instances used + to create the submodules. + """ + if config != DEPRECATED_VALUE and isinstance(config, MultiRLModuleConfig): + deprecation_warning( + old="MultiRLModule(config=..)", + new="MultiRLModule(*, observation_space=.., action_space=.., " + "inference_only=.., model_config=.., rl_module_specs=..)", + error=True, + ) + + # Make sure we don't alter incoming module specs in this c'tor. + rl_module_specs = copy.deepcopy(rl_module_specs or {}) + # Figure out global inference_only setting. + # If not provided (None), only if all submodules are + # inference_only, this MultiRLModule will be inference_only. + inference_only = ( + inference_only + if inference_only is not None + else all(spec.inference_only for spec in rl_module_specs.values()) + ) + # If given inference_only=True, make all submodules also inference_only (before + # creating them). + if inference_only is True: + for rl_module_spec in rl_module_specs.values(): + rl_module_spec.inference_only = True + self._check_module_specs(rl_module_specs) + self.rl_module_specs = rl_module_specs + + super().__init__( + observation_space=observation_space, + action_space=action_space, + inference_only=inference_only, + learner_only=None, + catalog_class=None, + model_config=model_config, + **kwargs, + ) + + @OverrideToImplementCustomLogic + @override(RLModule) + def setup(self): + """Sets up the underlying, individual RLModules.""" + self._rl_modules = {} + # Make sure all individual RLModules have the same framework OR framework=None. 
+ framework = None + for module_id, rl_module_spec in self.rl_module_specs.items(): + self._rl_modules[module_id] = rl_module_spec.build() + if framework is None: + framework = self._rl_modules[module_id].framework + else: + assert self._rl_modules[module_id].framework in [None, framework] + self.framework = framework + + @override(RLModule) + def _forward( + self, + batch: Dict[ModuleID, Any], + **kwargs, + ) -> Dict[ModuleID, Dict[str, Any]]: + """Generic forward pass method, used in all phases of training and evaluation. + + If you need a more nuanced distinction between forward passes in the different + phases of training and evaluation, override the following methods instead: + For distinct action computation logic w/o exploration, override the + `self._forward_inference()` method. + For distinct action computation logic with exploration, override the + `self._forward_exploration()` method. + For distinct forward pass logic before loss computation, override the + `self._forward_train()` method. + + Args: + batch: The input batch, a dict mapping from ModuleID to individual modules' + batches. + **kwargs: Additional keyword arguments. + + Returns: + The output of the forward pass. + """ + return { + mid: self._rl_modules[mid]._forward(batch[mid], **kwargs) + for mid in batch.keys() + if mid in self + } + + @override(RLModule) + def _forward_inference( + self, batch: Dict[str, Any], **kwargs + ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: + """Forward-pass used for action computation without exploration behavior. + + Override this method only, if you need specific behavior for non-exploratory + action computation behavior. If you have only one generic behavior for all + phases of training and evaluation, override `self._forward()` instead. + + By default, this calls the generic `self._forward()` method. 
+ """ + return { + mid: self._rl_modules[mid]._forward_inference(batch[mid], **kwargs) + for mid in batch.keys() + if mid in self + } + + @override(RLModule) + def _forward_exploration( + self, batch: Dict[str, Any], **kwargs + ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: + """Forward-pass used for action computation with exploration behavior. + + Override this method only, if you need specific behavior for exploratory + action computation behavior. If you have only one generic behavior for all + phases of training and evaluation, override `self._forward()` instead. + + By default, this calls the generic `self._forward()` method. + """ + return { + mid: self._rl_modules[mid]._forward_exploration(batch[mid], **kwargs) + for mid in batch.keys() + if mid in self + } + + @override(RLModule) + def _forward_train( + self, batch: Dict[str, Any], **kwargs + ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: + """Forward-pass used before the loss computation (training). + + Override this method only, if you need specific behavior and outputs for your + loss computations. If you have only one generic behavior for all + phases of training and evaluation, override `self._forward()` instead. + + By default, this calls the generic `self._forward()` method. + """ + return { + mid: self._rl_modules[mid]._forward_train(batch[mid], **kwargs) + for mid in batch.keys() + if mid in self + } + + @OverrideToImplementCustomLogic + @override(RLModule) + def get_initial_state(self) -> Any: + # TODO (sven): Replace by call to `self.foreach_module`, but only if this method + # supports returning dicts. 
+ ret = {} + for module_id, module in self._rl_modules.items(): + ret[module_id] = module.get_initial_state() + return ret + + @OverrideToImplementCustomLogic + @override(RLModule) + def is_stateful(self) -> bool: + initial_state = self.get_initial_state() + assert isinstance(initial_state, dict), ( + "The initial state of an RLModule must be a dict, but is " + f"{type(initial_state)} instead." + ) + return bool(any(sa_init_state for sa_init_state in initial_state.values())) + + def add_module( + self, + module_id: ModuleID, + module: RLModule, + *, + override: bool = False, + ) -> None: + """Adds a module at run time to the multi-agent module. + + Args: + module_id: The module ID to add. If the module ID already exists and + override is False, an error is raised. If override is True, the module + is replaced. + module: The module to add. + override: Whether to override the module if it already exists. + + Raises: + ValueError: If the module ID already exists and override is False. + Warnings are raised if the module id is not valid according to the + logic of ``validate_module_id()``. + """ + from ray.rllib.core.rl_module import validate_module_id + + validate_module_id(module_id) + + if module_id in self._rl_modules and not override: + raise ValueError( + f"Module ID {module_id} already exists. If your intention is to " + "override, set override=True." + ) + # Set our own inference_only flag to False as soon as any added Module + # has `inference_only=False`. + if not module.inference_only: + self.inference_only = False + + # Check framework of incoming RLModule against `self.framework`. + if module.framework is not None: + if self.framework is None: + self.framework = module.framework + elif module.framework != self.framework: + raise ValueError( + f"Framework ({module.framework}) of incoming RLModule does NOT " + f"match framework ({self.framework}) of MultiRLModule! If the " + f"added module should not be trained, try setting its framework " + f"to None." 
+ ) + + self._rl_modules[module_id] = module + # Update our RLModuleSpecs dict, such that - if written to disk - + # it'll allow for proper restoring this instance through `.from_checkpoint()`. + self.rl_module_specs[module_id] = RLModuleSpec.from_module(module) + + def remove_module( + self, module_id: ModuleID, *, raise_err_if_not_found: bool = True + ) -> None: + """Removes a module at runtime from the multi-agent module. + + Args: + module_id: The module ID to remove. + raise_err_if_not_found: Whether to raise an error if the module ID is not + found. + Raises: + ValueError: If the module ID does not exist and raise_err_if_not_found is + True. + """ + if raise_err_if_not_found: + self._check_module_exists(module_id) + del self._rl_modules[module_id] + del self.rl_module_specs[module_id] + + def foreach_module( + self, + func: Callable[[ModuleID, RLModule, Optional[Any]], T], + *, + return_dict: bool = False, + **kwargs, + ) -> Union[List[T], Dict[ModuleID, T]]: + """Calls the given function with each (module_id, module). + + Args: + func: The function to call with each (module_id, module) tuple. + return_dict: Whether to return a dict mapping ModuleID to the individual + module's return values of calling `func`. If False (default), return + a list. + + Returns: + The list of return values of all calls to + `func([module_id, module, **kwargs])` or a dict (if `return_dict=True`) + mapping ModuleIDs to the respective models' return values. + """ + ret_dict = { + module_id: func(module_id, module.unwrapped(), **kwargs) + for module_id, module in self._rl_modules.items() + } + if return_dict: + return ret_dict + return list(ret_dict.values()) + + def __contains__(self, item) -> bool: + """Returns whether the given `item` (ModuleID) is present in self.""" + return item in self._rl_modules + + def __getitem__(self, module_id: ModuleID) -> RLModule: + """Returns the RLModule with the given module ID. + + Args: + module_id: The module ID to get. 
+ + Returns: + The RLModule with the given module ID. + + Raises: + KeyError: If `module_id` cannot be found in self. + """ + self._check_module_exists(module_id) + return self._rl_modules[module_id] + + def get( + self, + module_id: ModuleID, + default: Optional[RLModule] = None, + ) -> Optional[RLModule]: + """Returns the module with the given module ID or default if not found in self. + + Args: + module_id: The module ID to get. + + Returns: + The RLModule with the given module ID or `default` if `module_id` not found + in `self`. + """ + if module_id not in self._rl_modules: + return default + return self._rl_modules[module_id] + + def items(self) -> ItemsView[ModuleID, RLModule]: + """Returns an ItemsView over the module IDs in this MultiRLModule.""" + return self._rl_modules.items() + + def keys(self) -> KeysView[ModuleID]: + """Returns a KeysView over the module IDs in this MultiRLModule.""" + return self._rl_modules.keys() + + def values(self) -> ValuesView[ModuleID]: + """Returns a ValuesView over the module IDs in this MultiRLModule.""" + return self._rl_modules.values() + + def __len__(self) -> int: + """Returns the number of RLModules within this MultiRLModule.""" + return len(self._rl_modules) + + def __repr__(self) -> str: + return f"MARL({pprint.pformat(self._rl_modules)})" + + @override(RLModule) + def get_state( + self, + components: Optional[Union[str, Collection[str]]] = None, + *, + not_components: Optional[Union[str, Collection[str]]] = None, + inference_only: bool = False, + **kwargs, + ) -> StateDict: + state = {} + + for module_id, rl_module in self.get_checkpointable_components(): + if self._check_component(module_id, components, not_components): + state[module_id] = rl_module.get_state( + components=self._get_subcomponents(module_id, components), + not_components=self._get_subcomponents(module_id, not_components), + inference_only=inference_only, + ) + return state + + @override(RLModule) + def set_state(self, state: StateDict) -> None: + 
"""Sets the state of the multi-agent module. + + It is assumed that the state_dict is a mapping from module IDs to the + corresponding module's state. This method sets the state of each module by + calling their set_state method. If you want to set the state of some of the + RLModules within this MultiRLModule your state_dict can only include the + state of those RLModules. Override this method to customize the state_dict for + custom more advanced multi-agent use cases. + + Args: + state: The state dict to set. + """ + # Now, set the individual states + for module_id, module_state in state.items(): + if module_id in self: + self._rl_modules[module_id].set_state(module_state) + + @override(Checkpointable) + def get_ctor_args_and_kwargs(self): + return ( + (), # *args + { + "observation_space": self.observation_space, + "action_space": self.action_space, + "inference_only": self.inference_only, + "learner_only": self.learner_only, + "model_config": self.model_config, + "rl_module_specs": self.rl_module_specs, + }, # **kwargs + ) + + @override(Checkpointable) + def get_checkpointable_components(self) -> List[Tuple[str, Checkpointable]]: + return list(self._rl_modules.items()) + + @override(RLModule) + def output_specs_train(self) -> SpecType: + return [] + + @override(RLModule) + def output_specs_inference(self) -> SpecType: + return [] + + @override(RLModule) + def output_specs_exploration(self) -> SpecType: + return [] + + @override(RLModule) + def _default_input_specs(self) -> SpecType: + """MultiRLModule should not check the input specs. + + The underlying single-agent RLModules will check the input specs. + """ + return [] + + @override(RLModule) + def as_multi_rl_module(self) -> "MultiRLModule": + """Returns self in order to match `RLModule.as_multi_rl_module()` behavior. + + This method is overridden to avoid double wrapping. + + Returns: + The instance itself. 
+ """ + return self + + @classmethod + def _check_module_specs(cls, rl_module_specs: Dict[ModuleID, RLModuleSpec]): + """Checks the individual RLModuleSpecs for validity. + + Args: + rl_module_specs: Dict mapping ModuleIDs to the respective RLModuleSpec. + + Raises: + ValueError: If any RLModuleSpec is invalid. + """ + for module_id, rl_module_spec in rl_module_specs.items(): + if not isinstance(rl_module_spec, RLModuleSpec): + raise ValueError(f"Module {module_id} is not a RLModuleSpec object.") + + def _check_module_exists(self, module_id: ModuleID) -> None: + if module_id not in self._rl_modules: + raise KeyError( + f"Module with module_id {module_id} not found. " + f"Available modules: {set(self.keys())}" + ) + + +@PublicAPI(stability="alpha") +@dataclasses.dataclass +class MultiRLModuleSpec: + """A utility spec class to make it constructing MultiRLModules easier. + + Users can extend this class to modify the behavior of base class. For example to + share neural networks across the modules, the build method can be overridden to + create the shared module first and then pass it to custom module classes that would + then use it as a shared module. + + Args: + multi_rl_module_class: The class of the MultiRLModule to construct. By + default, this is the base `MultiRLModule` class. + observation_space: Optional global observation space for the MultiRLModule. + Useful for shared network components that live only inside the MultiRLModule + and don't have their own ModuleID and own RLModule within + `self._rl_modules`. + action_space: Optional global action space for the MultiRLModule. + Useful for shared network components that live only inside the MultiRLModule + and don't have their own ModuleID and own RLModule within + `self._rl_modules`. + inference_only: An optional global inference_only flag. If not set (None by + default), considers the MultiRLModule to be inference_only=True, only + if all submodules also have their own inference_only flags set to True. 
+ model_config: An optional global model_config dict. Useful to configure shared + network components that only live inside the MultiRLModule and don't have + their own ModuleID and own RLModule within `self._rl_modules`. + rl_module_specs: The module specs for each individual module. It can be either a + RLModuleSpec used for all module_ids or a dictionary mapping + from module IDs to RLModuleSpecs for each individual module. + load_state_path: The path to the module state to load from. NOTE: This must be + an absolute path. NOTE: If the load_state_path of this spec is set, and + the load_state_path of one of the RLModuleSpecs' is also set, + the weights of that RL Module will be loaded from the path specified in + the RLModuleSpec. This is useful if you want to load the weights + of a MultiRLModule and also manually load the weights of some of the RL + modules within that MultiRLModule from other checkpoints. + modules_to_load: A set of module ids to load from the checkpoint. This is + only used if load_state_path is set. If this is None, all modules are + loaded. + """ + + multi_rl_module_class: Type[MultiRLModule] = MultiRLModule + observation_space: Optional[gym.Space] = None + action_space: Optional[gym.Space] = None + inference_only: Optional[bool] = None + # TODO (sven): Once we support MultiRLModules inside other MultiRLModules, we would + # need this flag in here as well, but for now, we'll leave it out for simplicity. + # learner_only: bool = False + model_config: Optional[dict] = None + rl_module_specs: Union[RLModuleSpec, Dict[ModuleID, RLModuleSpec]] = None + + # TODO (sven): Deprecate these in favor of using the pure Checkpointable APIs for + # loading and saving state. + load_state_path: Optional[str] = None + modules_to_load: Optional[Set[ModuleID]] = None + + # Deprecated: Do not use anymore. 
+ module_specs: Optional[Union[RLModuleSpec, Dict[ModuleID, RLModuleSpec]]] = None + + def __post_init__(self): + if self.module_specs is not None: + deprecation_warning( + old="MultiRLModuleSpec(module_specs=..)", + new="MultiRLModuleSpec(rl_module_specs=..)", + error=True, + ) + if self.rl_module_specs is None: + raise ValueError( + "Module_specs cannot be None. It should be either a " + "RLModuleSpec or a dictionary mapping from module IDs to " + "RLModuleSpecs for each individual module." + ) + self.module_specs = self.rl_module_specs + # Figure out global inference_only setting. + # If not provided (None), only if all submodules are + # inference_only, this MultiRLModule will be inference_only. + self.inference_only = ( + self.inference_only + if self.inference_only is not None + else all(spec.inference_only for spec in self.rl_module_specs.values()) + ) + + @OverrideToImplementCustomLogic + def build(self, module_id: Optional[ModuleID] = None) -> RLModule: + """Builds either the MultiRLModule or a (single) sub-RLModule under `module_id`. + + Args: + module_id: Optional ModuleID of a single RLModule to be built. If None + (default), builds the MultiRLModule. + + Returns: + The built RLModule if `module_id` is provided, otherwise the built + MultiRLModule. + """ + self._check_before_build() + + # ModuleID provided, return single-agent RLModule. + if module_id: + return self.rl_module_specs[module_id].build() + + # Return MultiRLModule. + try: + module = self.multi_rl_module_class( + observation_space=self.observation_space, + action_space=self.action_space, + inference_only=self.inference_only, + model_config=( + dataclasses.asdict(self.model_config) + if dataclasses.is_dataclass(self.model_config) + else self.model_config + ), + rl_module_specs=self.rl_module_specs, + ) + # Older custom model might still require the old `MultiRLModuleConfig` under + # the `config` arg. 
+ except AttributeError as e: + if self.multi_rl_module_class is not MultiRLModule: + multi_rl_module_config = self.get_rl_module_config() + module = self.multi_rl_module_class(multi_rl_module_config) + else: + raise e + + return module + + def add_modules( + self, + module_specs: Dict[ModuleID, RLModuleSpec], + override: bool = True, + ) -> None: + """Add new module specs to the spec or updates existing ones. + + Args: + module_specs: The mapping for the module_id to the single-agent module + specs to be added to this multi-agent module spec. + override: Whether to override the existing module specs if they already + exist. If False, they are only updated. + """ + if self.rl_module_specs is None: + self.rl_module_specs = {} + for module_id, module_spec in module_specs.items(): + if override or module_id not in self.rl_module_specs: + # Disable our `inference_only` as soon as any single-agent module has + # `inference_only=False`. + if not module_spec.inference_only: + self.inference_only = False + self.rl_module_specs[module_id] = module_spec + else: + self.rl_module_specs[module_id].update(module_spec) + + def remove_modules(self, module_ids: Union[ModuleID, Collection[ModuleID]]) -> None: + """Removes the provided ModuleIDs from this MultiRLModuleSpec. + + Args: + module_ids: Collection of the ModuleIDs to remove from this spec. + """ + for module_id in force_list(module_ids): + self.rl_module_specs.pop(module_id, None) + + @classmethod + def from_module(self, module: MultiRLModule) -> "MultiRLModuleSpec": + """Creates a MultiRLModuleSpec from a MultiRLModule. + + Args: + module: The MultiRLModule to create the spec from. + + Returns: + The MultiRLModuleSpec. + """ + # we want to get the spec of the underlying unwrapped module that way we can + # easily reconstruct it. The only wrappers that we expect to support today are + # wrappers that allow us to do distributed training. Those will be added back + # by the learner if necessary. 
        # Build one single-agent spec per sub-module, from the *unwrapped* module
        # (strips e.g. distributed-training wrappers; see comment above).
        rl_module_specs = {
            module_id: RLModuleSpec.from_module(rl_module.unwrapped())
            for module_id, rl_module in module._rl_modules.items()
        }
        multi_rl_module_class = module.__class__
        return MultiRLModuleSpec(
            multi_rl_module_class=multi_rl_module_class,
            observation_space=module.observation_space,
            action_space=module.action_space,
            inference_only=module.inference_only,
            model_config=module.model_config,
            rl_module_specs=rl_module_specs,
        )

    def _check_before_build(self):
        # `build()` requires the dict form; a bare RLModuleSpec (also allowed by the
        # field's type annotation) cannot be built into a MultiRLModule.
        if not isinstance(self.rl_module_specs, dict):
            raise ValueError(
                f"When build() is called on {self.__class__}, the `rl_module_specs` "
                "attribute should be a dictionary mapping ModuleIDs to "
                "RLModuleSpecs for each individual RLModule."
            )

    def to_dict(self) -> Dict[str, Any]:
        """Converts the MultiRLModuleSpec to a dictionary."""
        return {
            "multi_rl_module_class": serialize_type(self.multi_rl_module_class),
            "observation_space": gym_space_to_dict(self.observation_space),
            "action_space": gym_space_to_dict(self.action_space),
            "inference_only": self.inference_only,
            "model_config": self.model_config,
            "rl_module_specs": {
                module_id: rl_module_spec.to_dict()
                for module_id, rl_module_spec in self.rl_module_specs.items()
            },
        }

    @classmethod
    def from_dict(cls, d) -> "MultiRLModuleSpec":
        """Creates a MultiRLModuleSpec from a dictionary."""
        return MultiRLModuleSpec(
            multi_rl_module_class=deserialize_type(d["multi_rl_module_class"]),
            observation_space=gym_space_from_dict(d.get("observation_space")),
            action_space=gym_space_from_dict(d.get("action_space")),
            model_config=d.get("model_config"),
            inference_only=d["inference_only"],
            rl_module_specs={
                module_id: RLModuleSpec.from_dict(rl_module_spec)
                for module_id, rl_module_spec in (
                    # Fall back to the old "module_specs" key for dicts produced by
                    # older versions.
                    # NOTE(review): if BOTH keys are missing this raises
                    # AttributeError (None.items()), not KeyError — confirm intended.
                    d.get("rl_module_specs", d.get("module_specs")).items()
                )
            },
        )

    def update(
        self,
        other: Union["MultiRLModuleSpec", RLModuleSpec],
        override: bool = False,
    ) -> None:
        """Updates this spec
with the other spec. + + Traverses this MultiRLModuleSpec's module_specs and updates them with + the module specs from the `other` (Multi)RLModuleSpec. + + Args: + other: The other spec to update this spec with. + override: Whether to override the existing module specs if they already + exist. If False, they are only updated. + """ + if isinstance(other, RLModuleSpec): + # Disable our `inference_only` as soon as any single-agent module has + # `inference_only=False`. + if not other.inference_only: + self.inference_only = False + for mid, spec in self.rl_module_specs.items(): + self.rl_module_specs[mid].update(other, override=False) + elif isinstance(other.module_specs, dict): + self.add_modules(other.rl_module_specs, override=override) + else: + assert isinstance(other, MultiRLModuleSpec) + if not self.rl_module_specs: + self.inference_only = other.inference_only + self.rl_module_specs = other.rl_module_specs + else: + if not other.inference_only: + self.inference_only = False + self.rl_module_specs.update(other.rl_module_specs) + + def as_multi_rl_module_spec(self) -> "MultiRLModuleSpec": + """Returns self in order to match `RLModuleSpec.as_multi_rl_module_spec()`.""" + return self + + def __contains__(self, item) -> bool: + """Returns whether the given `item` (ModuleID) is present in self.""" + return item in self.rl_module_specs + + def __getitem__(self, item) -> RLModuleSpec: + """Returns the RLModuleSpec under the ModuleID.""" + return self.rl_module_specs[item] + + @Deprecated( + new="MultiRLModule(*, module_specs={module1: [RLModuleSpec], " + "module2: [RLModuleSpec], ..}, inference_only=..)", + error=True, + ) + def get_multi_rl_module_config(self): + pass + + @Deprecated(new="MultiRLModuleSpec.as_multi_rl_module_spec()", error=True) + def as_multi_agent(self): + pass + + @Deprecated(new="MultiRLModuleSpec.get_multi_rl_module_config", error=True) + def get_marl_config(self, *args, **kwargs): + pass + + @Deprecated( + new="MultiRLModule(*, 
observation_space=.., action_space=.., ....)", + error=False, + ) + def get_rl_module_config(self): + return MultiRLModuleConfig( + inference_only=self.inference_only, + modules=self.rl_module_specs, + ) + + +@Deprecated( + new="MultiRLModule(*, rl_module_specs={module1: [RLModuleSpec], " + "module2: [RLModuleSpec], ..}, inference_only=..)", + error=False, +) +@dataclasses.dataclass +class MultiRLModuleConfig: + inference_only: bool = False + modules: Dict[ModuleID, RLModuleSpec] = dataclasses.field(default_factory=dict) + + def to_dict(self): + return { + "inference_only": self.inference_only, + "modules": { + module_id: module_spec.to_dict() + for module_id, module_spec in self.modules.items() + }, + } + + @classmethod + def from_dict(cls, d) -> "MultiRLModuleConfig": + return cls( + inference_only=d["inference_only"], + modules={ + module_id: RLModuleSpec.from_dict(module_spec) + for module_id, module_spec in d["modules"].items() + }, + ) + + def get_catalog(self) -> None: + return None diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/rl_module.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/rl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..27f0861f365ba9f5c5ef8c0acca9ebe6c53c63bb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/rl_module.py @@ -0,0 +1,801 @@ +import abc +import dataclasses +from dataclasses import dataclass, field +import logging +from typing import Any, Collection, Dict, Optional, Type, TYPE_CHECKING, Union + +import gymnasium as gym + +from ray.rllib.core import DEFAULT_MODULE_ID +from ray.rllib.core.columns import Columns +from ray.rllib.core.models.specs.typing import SpecType +from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig +from ray.rllib.models.distributions import Distribution +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic, +) +from ray.rllib.utils.checkpoints 
import Checkpointable +from ray.rllib.utils.deprecation import ( + Deprecated, + DEPRECATED_VALUE, + deprecation_warning, +) +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.serialization import ( + gym_space_from_dict, + gym_space_to_dict, + serialize_type, + deserialize_type, +) +from ray.rllib.utils.typing import StateDict +from ray.util.annotations import PublicAPI + +if TYPE_CHECKING: + from ray.rllib.core.rl_module.multi_rl_module import ( + MultiRLModule, + MultiRLModuleSpec, + ) + from ray.rllib.core.models.catalog import Catalog + +logger = logging.getLogger("ray.rllib") +torch, _ = try_import_torch() + + +@PublicAPI(stability="beta") +@dataclass +class RLModuleSpec: + """Utility spec class to make constructing RLModules (in single-agent case) easier. + + Args: + module_class: The RLModule class to use. + observation_space: The observation space of the RLModule. This may differ + from the observation space of the environment. For example, a discrete + observation space of an environment, would usually correspond to a + one-hot encoded observation space of the RLModule because of preprocessing. + action_space: The action space of the RLModule. + inference_only: Whether the RLModule should be configured in its inference-only + state, in which those components not needed for action computing (for + example a value function or a target network) might be missing. + Note that `inference_only=True` AND `learner_only=True` is not allowed. + learner_only: Whether this RLModule should only be built on Learner workers, but + NOT on EnvRunners. Useful for RLModules inside a MultiRLModule that are only + used for training, for example a shared value function in a multi-agent + setup or a world model in a curiosity-learning setup. + Note that `inference_only=True` AND `learner_only=True` is not allowed. + model_config: The model config dict or default RLlib dataclass to use. + catalog_class: The Catalog class to use. 
+ load_state_path: The path to the module state to load from. NOTE: This must be + an absolute path. + """ + + module_class: Optional[Type["RLModule"]] = None + observation_space: Optional[gym.Space] = None + action_space: Optional[gym.Space] = None + inference_only: bool = False + learner_only: bool = False + model_config: Optional[Union[Dict[str, Any], DefaultModelConfig]] = None + catalog_class: Optional[Type["Catalog"]] = None + load_state_path: Optional[str] = None + + # Deprecated field. + model_config_dict: Optional[Union[dict, int]] = None + + def __post_init__(self): + if self.model_config_dict is not None: + deprecation_warning( + old="RLModuleSpec(model_config_dict=..)", + new="RLModuleSpec(model_config=..)", + error=True, + ) + + def build(self) -> "RLModule": + """Builds the RLModule from this spec.""" + if self.module_class is None: + raise ValueError("RLModule class is not set.") + if self.observation_space is None: + raise ValueError("Observation space is not set.") + if self.action_space is None: + raise ValueError("Action space is not set.") + + try: + module = self.module_class( + observation_space=self.observation_space, + action_space=self.action_space, + inference_only=self.inference_only, + model_config=self._get_model_config(), + catalog_class=self.catalog_class, + ) + # Older custom model might still require the old `RLModuleConfig` under + # the `config` arg. + except AttributeError: + module_config = self.get_rl_module_config() + module = self.module_class(module_config) + return module + + @classmethod + def from_module(cls, module: "RLModule") -> "RLModuleSpec": + from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule + + if isinstance(module, MultiRLModule): + raise ValueError("MultiRLModule cannot be converted to RLModuleSpec.") + + # Try instantiating a new RLModule from the spec using the new c'tor args. 
+ try: + rl_module_spec = RLModuleSpec( + module_class=type(module), + observation_space=module.observation_space, + action_space=module.action_space, + inference_only=module.inference_only, + learner_only=module.learner_only, + model_config=module.model_config, + catalog_class=( + type(module.catalog) if module.catalog is not None else None + ), + ) + + # Old path through deprecated `RLModuleConfig` class. Used only if `module` + # still has a valid `config` attribute. + except AttributeError: + rl_module_spec = RLModuleSpec( + module_class=type(module), + observation_space=module.config.observation_space, + action_space=module.config.action_space, + inference_only=module.config.inference_only, + learner_only=module.config.learner_only, + model_config=module.config.model_config_dict, + catalog_class=module.config.catalog_class, + ) + return rl_module_spec + + def to_dict(self): + """Returns a serialized representation of the spec.""" + return { + "module_class": serialize_type(self.module_class), + "observation_space": gym_space_to_dict(self.observation_space), + "action_space": gym_space_to_dict(self.action_space), + "inference_only": self.inference_only, + "learner_only": self.learner_only, + "model_config": self._get_model_config(), + "catalog_class": serialize_type(self.catalog_class) + if self.catalog_class is not None + else None, + } + + @classmethod + def from_dict(cls, d): + """Returns a single agent RLModule spec from a serialized representation.""" + module_class = deserialize_type(d["module_class"]) + try: + spec = RLModuleSpec( + module_class=module_class, + observation_space=gym_space_from_dict(d["observation_space"]), + action_space=gym_space_from_dict(d["action_space"]), + inference_only=d["inference_only"], + learner_only=d["learner_only"], + model_config=d["model_config"], + catalog_class=deserialize_type(d["catalog_class"]) + if d["catalog_class"] is not None + else None, + ) + + # Old path through deprecated `RLModuleConfig` class. 
+ except KeyError: + module_config = RLModuleConfig.from_dict(d["module_config"]) + spec = RLModuleSpec( + module_class=module_class, + observation_space=module_config.observation_space, + action_space=module_config.action_space, + inference_only=module_config.inference_only, + learner_only=module_config.learner_only, + model_config=module_config.model_config_dict, + catalog_class=module_config.catalog_class, + ) + return spec + + def update(self, other, override: bool = True) -> None: + """Updates this spec with the given other spec. Works like dict.update(). + + Args: + other: The other SingleAgentRLModule spec to update this one from. + override: Whether to update all properties in `self` with those of `other. + If False, only update those properties in `self` that are not None. + """ + if not isinstance(other, RLModuleSpec): + raise ValueError("Can only update with another RLModuleSpec.") + + # If the field is None in the other, keep the current field, otherwise update + # with the new value. + if override: + self.module_class = other.module_class or self.module_class + self.observation_space = other.observation_space or self.observation_space + self.action_space = other.action_space or self.action_space + self.inference_only = other.inference_only or self.inference_only + self.learner_only = other.learner_only and self.learner_only + self.model_config = other.model_config or self.model_config + self.catalog_class = other.catalog_class or self.catalog_class + self.load_state_path = other.load_state_path or self.load_state_path + # Only override, if the field is None in `self`. + # Do NOT override the boolean settings: `inference_only` and `learner_only`. 
+ else: + self.module_class = self.module_class or other.module_class + self.observation_space = self.observation_space or other.observation_space + self.action_space = self.action_space or other.action_space + self.model_config = self.model_config or other.model_config + self.catalog_class = self.catalog_class or other.catalog_class + self.load_state_path = self.load_state_path or other.load_state_path + + def as_multi_rl_module_spec(self) -> "MultiRLModuleSpec": + """Returns a MultiRLModuleSpec (`self` under DEFAULT_MODULE_ID key).""" + from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec + + return MultiRLModuleSpec( + rl_module_specs={DEFAULT_MODULE_ID: self}, + load_state_path=self.load_state_path, + ) + + def _get_model_config(self): + return ( + dataclasses.asdict(self.model_config) + if dataclasses.is_dataclass(self.model_config) + else (self.model_config or {}) + ) + + @Deprecated( + new="RLModule(*, observation_space=.., action_space=.., ....)", + error=False, + ) + def get_rl_module_config(self): + return RLModuleConfig( + observation_space=self.observation_space, + action_space=self.action_space, + inference_only=self.inference_only, + learner_only=self.learner_only, + model_config_dict=self._get_model_config(), + catalog_class=self.catalog_class, + ) + + +@PublicAPI(stability="beta") +class RLModule(Checkpointable, abc.ABC): + """Base class for RLlib modules. + + Subclasses should call `super().__init__(observation_space=.., action_space=.., + inference_only=.., learner_only=.., model_config={..})` in their __init__ methods. + + Here is the pseudocode for how the forward methods are called: + + Example for creating a (inference-only) sampling loop: + + .. 
testcode:: + + from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import ( + DefaultPPOTorchRLModule + ) + from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog + import gymnasium as gym + import torch + + env = gym.make("CartPole-v1") + + # Create an instance of the default RLModule used by PPO. + module = DefaultPPOTorchRLModule( + observation_space=env.observation_space, + action_space=env.action_space, + model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]), + catalog_class=PPOCatalog, + ) + action_dist_class = module.get_inference_action_dist_cls() + obs, info = env.reset() + terminated = False + + while not terminated: + fwd_ins = {"obs": torch.Tensor([obs])} + fwd_outputs = module.forward_exploration(fwd_ins) + # This can be either deterministic or stochastic distribution. + action_dist = action_dist_class.from_logits( + fwd_outputs["action_dist_inputs"] + ) + action = action_dist.sample()[0].numpy() + obs, reward, terminated, truncated, info = env.step(action) + + + Example for training: + + .. testcode:: + + import gymnasium as gym + import torch + + from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import ( + DefaultPPOTorchRLModule + ) + from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog + + env = gym.make("CartPole-v1") + + # Create an instance of the default RLModule used by PPO. + module = DefaultPPOTorchRLModule( + observation_space=env.observation_space, + action_space=env.action_space, + model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]), + catalog_class=PPOCatalog, + ) + + fwd_ins = {"obs": torch.Tensor([obs])} + fwd_outputs = module.forward_train(fwd_ins) + # loss = compute_loss(fwd_outputs, fwd_ins) + # update_params(module, loss) + + Example for inference: + + .. 
testcode:: + + import gymnasium as gym + import torch + + from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import ( + DefaultPPOTorchRLModule + ) + from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog + + env = gym.make("CartPole-v1") + + # Create an instance of the default RLModule used by PPO. + module = DefaultPPOTorchRLModule( + observation_space=env.observation_space, + action_space=env.action_space, + model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]), + catalog_class=PPOCatalog, + ) + + while not terminated: + fwd_ins = {"obs": torch.Tensor([obs])} + fwd_outputs = module.forward_inference(fwd_ins) + # this can be either deterministic or stochastic distribution + action_dist = action_dist_class.from_logits( + fwd_outputs["action_dist_inputs"] + ) + action = action_dist.sample()[0].numpy() + obs, reward, terminated, truncated, info = env.step(action) + + + Args: + config: The config for the RLModule. + + Abstract Methods: + ``~_forward_train``: Forward pass during training. + + ``~_forward_exploration``: Forward pass during training for exploration. + + ``~_forward_inference``: Forward pass during inference. + """ + + framework: str = None + + STATE_FILE_NAME = "module_state" + + def __init__( + self, + config=DEPRECATED_VALUE, + *, + observation_space: Optional[gym.Space] = None, + action_space: Optional[gym.Space] = None, + inference_only: Optional[bool] = None, + learner_only: bool = False, + model_config: Optional[Union[dict, DefaultModelConfig]] = None, + catalog_class=None, + **kwargs, + ): + # TODO (sven): Deprecate Catalog and replace with utility functions to create + # primitive components based on obs- and action spaces. 
+ self.catalog = None + self._catalog_ctor_error = None + + # Deprecated + self.config = config + if self.config != DEPRECATED_VALUE: + deprecation_warning( + old="RLModule(config=[RLModuleConfig])", + new="RLModule(observation_space=.., action_space=.., inference_only=..," + " learner_only=.., model_config=..)", + help="See https://github.com/ray-project/ray/blob/master/rllib/examples/rl_modules/custom_cnn_rl_module.py " # noqa + "for how to write a custom RLModule.", + error=True, + ) + else: + self.observation_space = observation_space + self.action_space = action_space + self.inference_only = inference_only + self.learner_only = learner_only + self.model_config = model_config + try: + self.catalog = catalog_class( + observation_space=self.observation_space, + action_space=self.action_space, + model_config_dict=self.model_config, + ) + except Exception as e: + logger.warning( + "Could not create a Catalog object for your RLModule! If you are " + "not using the new API stack yet, make sure to switch it off in " + "your config: `config.api_stack(enable_rl_module_and_learner=False" + ", enable_env_runner_and_connector_v2=False)`. All algos " + "use the new stack by default. Ignore this message, if your " + "RLModule does not use a Catalog to build its sub-components." + ) + self._catalog_ctor_error = e + + # TODO (sven): Deprecate this. We keep it here for now in case users + # still have custom models (or subclasses of RLlib default models) + # into which they pass in a `config` argument. + self.config = RLModuleConfig( + observation_space=self.observation_space, + action_space=self.action_space, + inference_only=self.inference_only, + learner_only=self.learner_only, + model_config_dict=self.model_config, + catalog_class=catalog_class, + ) + + self.action_dist_cls = None + if self.catalog is not None: + self.action_dist_cls = self.catalog.get_action_dist_cls( + framework=self.framework + ) + + # Make sure, `setup()` is only called once, no matter what. 
+ if hasattr(self, "_is_setup") and self._is_setup: + raise RuntimeError( + "`RLModule.setup()` called twice within your RLModule implementation " + f"{self}! Make sure you are using the proper inheritance order " + "(TorchRLModule before [Algo]RLModule) or (TfRLModule before " + "[Algo]RLModule) and that you are NOT overriding the constructor, but " + "only the `setup()` method of your subclass." + ) + self.setup() + self._is_setup = True + + @OverrideToImplementCustomLogic + def setup(self): + """Sets up the components of the module. + + This is called automatically during the __init__ method of this class, + therefore, the subclass should call super.__init__() in its constructor. This + abstraction can be used to create any components (e.g. NN layers) that your + RLModule needs. + """ + return None + + @OverrideToImplementCustomLogic + def get_exploration_action_dist_cls(self) -> Type[Distribution]: + """Returns the action distribution class for this RLModule used for exploration. + + This class is used to create action distributions from outputs of the + forward_exploration method. If the case that no action distribution class is + needed, this method can return None. + + Note that RLlib's distribution classes all implement the `Distribution` + interface. This requires two special methods: `Distribution.from_logits()` and + `Distribution.to_deterministic()`. See the documentation of the + :py:class:`~ray.rllib.models.distributions.Distribution` class for more details. + """ + raise NotImplementedError + + @OverrideToImplementCustomLogic + def get_inference_action_dist_cls(self) -> Type[Distribution]: + """Returns the action distribution class for this RLModule used for inference. + + This class is used to create action distributions from outputs of the forward + inference method. If the case that no action distribution class is needed, + this method can return None. + + Note that RLlib's distribution classes all implement the `Distribution` + interface. 
This requires two special methods: `Distribution.from_logits()` and + `Distribution.to_deterministic()`. See the documentation of the + :py:class:`~ray.rllib.models.distributions.Distribution` class for more details. + """ + raise NotImplementedError + + @OverrideToImplementCustomLogic + def get_train_action_dist_cls(self) -> Type[Distribution]: + """Returns the action distribution class for this RLModule used for training. + + This class is used to get the correct action distribution class to be used by + the training components. In case that no action distribution class is needed, + this method can return None. + + Note that RLlib's distribution classes all implement the `Distribution` + interface. This requires two special methods: `Distribution.from_logits()` and + `Distribution.to_deterministic()`. See the documentation of the + :py:class:`~ray.rllib.models.distributions.Distribution` class for more details. + """ + raise NotImplementedError + + @OverrideToImplementCustomLogic + def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """Generic forward pass method, used in all phases of training and evaluation. + + If you need a more nuanced distinction between forward passes in the different + phases of training and evaluation, override the following methods instead: + + For distinct action computation logic w/o exploration, override the + `self._forward_inference()` method. + For distinct action computation logic with exploration, override the + `self._forward_exploration()` method. + For distinct forward pass logic before loss computation, override the + `self._forward_train()` method. + + Args: + batch: The input batch. + **kwargs: Additional keyword arguments. + + Returns: + The output of the forward pass. + """ + return {} + + def forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """DO NOT OVERRIDE! Forward-pass during evaluation, called from the sampler. + + This method should not be overridden. 
Override the `self._forward_inference()` + method instead. + + Args: + batch: The input batch. This input batch should comply with + input_specs_inference(). + **kwargs: Additional keyword arguments. + + Returns: + The output of the forward pass. This output should comply with the + ouptut_specs_inference(). + """ + return self._forward_inference(batch, **kwargs) + + @OverrideToImplementCustomLogic + def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """Forward-pass used for action computation without exploration behavior. + + Override this method only, if you need specific behavior for non-exploratory + action computation behavior. If you have only one generic behavior for all + phases of training and evaluation, override `self._forward()` instead. + + By default, this calls the generic `self._forward()` method. + """ + with torch.no_grad(): + return self._forward(batch, **kwargs) + + def forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """DO NOT OVERRIDE! Forward-pass during exploration, called from the sampler. + + This method should not be overridden. Override the `self._forward_exploration()` + method instead. + + Args: + batch: The input batch. This input batch should comply with + input_specs_exploration(). + **kwargs: Additional keyword arguments. + + Returns: + The output of the forward pass. This output should comply with the + output_specs_exploration(). + """ + return self._forward_exploration(batch, **kwargs) + + @OverrideToImplementCustomLogic + def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """Forward-pass used for action computation with exploration behavior. + + Override this method only, if you need specific behavior for exploratory + action computation behavior. If you have only one generic behavior for all + phases of training and evaluation, override `self._forward()` instead. + + By default, this calls the generic `self._forward()` method. 
+ """ + with torch.no_grad(): + return self._forward(batch, **kwargs) + + def forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """DO NOT OVERRIDE! Forward-pass during training called from the learner. + + This method should not be overridden. Override the `self._forward_train()` + method instead. + + Args: + batch: The input batch. This input batch should comply with + input_specs_train(). + **kwargs: Additional keyword arguments. + + Returns: + The output of the forward pass. This output should comply with the + output_specs_train(). + """ + if self.inference_only: + raise RuntimeError( + "Calling `forward_train` on an inference_only module is not allowed! " + "Set the `inference_only=False` flag in the RLModuleSpec (or the " + "RLModule's constructor)." + ) + return self._forward_train(batch, **kwargs) + + @OverrideToImplementCustomLogic + def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """Forward-pass used before the loss computation (training). + + Override this method only, if you need specific behavior and outputs for your + loss computations. If you have only one generic behavior for all + phases of training and evaluation, override `self._forward()` instead. + + By default, this calls the generic `self._forward()` method. + """ + return self._forward(batch, **kwargs) + + @OverrideToImplementCustomLogic + def get_initial_state(self) -> Any: + """Returns the initial state of the RLModule, in case this is a stateful module. + + Returns: + A tensor or any nested struct of tensors, representing an initial state for + this (stateful) RLModule. + """ + return {} + + @OverrideToImplementCustomLogic + def is_stateful(self) -> bool: + """By default, returns False if the initial state is an empty dict (or None). + + By default, RLlib assumes that the module is non-recurrent, if the initial + state is an empty dict and recurrent otherwise. + This behavior can be customized by overriding this method. 
+ """ + initial_state = self.get_initial_state() + assert isinstance(initial_state, dict), ( + "The initial state of an RLModule must be a dict, but is " + f"{type(initial_state)} instead." + ) + return bool(initial_state) + + @OverrideToImplementCustomLogic + @override(Checkpointable) + def get_state( + self, + components: Optional[Union[str, Collection[str]]] = None, + *, + not_components: Optional[Union[str, Collection[str]]] = None, + inference_only: bool = False, + **kwargs, + ) -> StateDict: + """Returns the state dict of the module. + + Args: + inference_only: Whether the returned state should be an inference-only + state (w/o those model components that are not needed for action + computations, such as a value function or a target network). + Note that setting this to `False` might raise an error if + `self.inference_only` is True. + + Returns: + This RLModule's state dict. + """ + return {} + + @OverrideToImplementCustomLogic + @override(Checkpointable) + def set_state(self, state: StateDict) -> None: + pass + + @override(Checkpointable) + def get_ctor_args_and_kwargs(self): + return ( + (), # *args + { + "observation_space": self.observation_space, + "action_space": self.action_space, + "inference_only": self.inference_only, + "learner_only": self.learner_only, + "model_config": self.model_config, + "catalog_class": ( + type(self.catalog) if self.catalog is not None else None + ), + }, # **kwargs + ) + + def as_multi_rl_module(self) -> "MultiRLModule": + """Returns a multi-agent wrapper around this module.""" + from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule + + multi_rl_module = MultiRLModule( + rl_module_specs={DEFAULT_MODULE_ID: RLModuleSpec.from_module(self)} + ) + return multi_rl_module + + def unwrapped(self) -> "RLModule": + """Returns the underlying module if this module is a wrapper. + + An example of a wrapped is the TorchDDPRLModule class, which wraps + a TorchRLModule. + + Returns: + The underlying module. 
+ """ + return self + + def output_specs_inference(self) -> SpecType: + return [Columns.ACTION_DIST_INPUTS] + + def output_specs_exploration(self) -> SpecType: + return [Columns.ACTION_DIST_INPUTS] + + def output_specs_train(self) -> SpecType: + """Returns the output specs of the forward_train method.""" + return {} + + def input_specs_inference(self) -> SpecType: + """Returns the input specs of the forward_inference method.""" + return self._default_input_specs() + + def input_specs_exploration(self) -> SpecType: + """Returns the input specs of the forward_exploration method.""" + return self._default_input_specs() + + def input_specs_train(self) -> SpecType: + """Returns the input specs of the forward_train method.""" + return self._default_input_specs() + + def _default_input_specs(self) -> SpecType: + """Returns the default input specs.""" + return [Columns.OBS] + + +@Deprecated( + old="RLModule(config=[RLModuleConfig object])", + new="RLModule(observation_space=.., action_space=.., inference_only=.., " + "model_config=.., catalog_class=..)", + error=False, +) +@dataclass +class RLModuleConfig: + observation_space: gym.Space = None + action_space: gym.Space = None + inference_only: bool = False + learner_only: bool = False + model_config_dict: Dict[str, Any] = field(default_factory=dict) + catalog_class: Type["Catalog"] = None + + def get_catalog(self) -> Optional["Catalog"]: + if self.catalog_class is not None: + return self.catalog_class( + observation_space=self.observation_space, + action_space=self.action_space, + model_config_dict=self.model_config_dict, + ) + return None + + def to_dict(self): + catalog_class_path = ( + serialize_type(self.catalog_class) if self.catalog_class else "" + ) + return { + "observation_space": gym_space_to_dict(self.observation_space), + "action_space": gym_space_to_dict(self.action_space), + "inference_only": self.inference_only, + "learner_only": self.learner_only, + "model_config_dict": self.model_config_dict, + 
"catalog_class_path": catalog_class_path, + } + + @classmethod + def from_dict(cls, d: Dict[str, Any]): + catalog_class = ( + None + if d["catalog_class_path"] == "" + else deserialize_type(d["catalog_class_path"]) + ) + return cls( + observation_space=gym_space_from_dict(d["observation_space"]), + action_space=gym_space_from_dict(d["action_space"]), + inference_only=d["inference_only"], + learner_only=d["learner_only"], + model_config_dict=d["model_config_dict"], + catalog_class=catalog_class, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/tf/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/tf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/tf/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/tf/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aee44f4967a204cb03397e3859dcde9fe82d514c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/tf/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/tf/__pycache__/tf_rl_module.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/tf/__pycache__/tf_rl_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98ff3bcdedb4e4c62b2f0a8ffe8e299598c1b207 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/tf/__pycache__/tf_rl_module.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/tf/tf_rl_module.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/tf/tf_rl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..144ba00953e62b23c3de5772034a1e0c22929bcc --- 
/dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/tf/tf_rl_module.py @@ -0,0 +1,91 @@ +from typing import Any, Collection, Dict, Optional, Type, Union + +import gymnasium as gym + +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.models.tf.tf_distributions import ( + TfCategorical, + TfDiagGaussian, + TfDistribution, +) +from ray.rllib.utils.annotations import override, OverrideToImplementCustomLogic +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.typing import StateDict + +_, tf, _ = try_import_tf() + + +class TfRLModule(tf.keras.Model, RLModule): + """Base class for RLlib TensorFlow RLModules.""" + + framework = "tf2" + + def __init__(self, *args, **kwargs) -> None: + tf.keras.Model.__init__(self) + RLModule.__init__(self, *args, **kwargs) + + def call(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """Forward pass of the module. + + Note: + This is aliased to forward_train to follow the Keras Model API. + + Args: + batch: The input batch. This input batch should comply with + input_specs_train(). + **kwargs: Additional keyword arguments. + + Returns: + The output of the forward pass. This output should comply with the + ouptut_specs_train(). 
+ + """ + return self.forward_train(batch) + + @OverrideToImplementCustomLogic + @override(RLModule) + def get_state( + self, + components: Optional[Union[str, Collection[str]]] = None, + *, + not_components: Optional[Union[str, Collection[str]]] = None, + inference_only: bool = False, + **kwargs, + ) -> StateDict: + return self.get_weights() + + @OverrideToImplementCustomLogic + @override(RLModule) + def set_state(self, state: StateDict) -> None: + self.set_weights(state) + + @OverrideToImplementCustomLogic + @override(RLModule) + def get_inference_action_dist_cls(self) -> Type[TfDistribution]: + if self.action_dist_cls is not None: + return self.action_dist_cls + elif isinstance(self.action_space, gym.spaces.Discrete): + return TfCategorical + elif isinstance(self.action_space, gym.spaces.Box): + return TfDiagGaussian + else: + raise ValueError( + f"Default action distribution for action space " + f"{self.action_space} not supported! Either set the " + f"`self.action_dist_cls` property in your RLModule's `setup()` method " + f"to a subclass of `ray.rllib.models.tf.tf_distributions." + f"TfDistribution` or - if you need different distributions for " + f"inference and training - override the three methods: " + f"`get_inference_action_dist_cls`, `get_exploration_action_dist_cls`, " + f"and `get_train_action_dist_cls` in your RLModule." 
+ ) + + @OverrideToImplementCustomLogic + @override(RLModule) + def get_exploration_action_dist_cls(self) -> Type[TfDistribution]: + return self.get_inference_action_dist_cls() + + @OverrideToImplementCustomLogic + @override(RLModule) + def get_train_action_dist_cls(self) -> Type[TfDistribution]: + return self.get_inference_action_dist_cls() diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/torch/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..217c5a5abc8aa4ec29708e81d0b2884256d8fae7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/torch/__init__.py @@ -0,0 +1,3 @@ +from .torch_rl_module import TorchRLModule + +__all__ = ["TorchRLModule"] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/torch/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/torch/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e822d03008199a37092fb1e3cc7b7504540151db Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/torch/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/torch/__pycache__/torch_compile_config.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/torch/__pycache__/torch_compile_config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f5a175546ad4809df7d3051b60d9f00d563a6d8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/torch/__pycache__/torch_compile_config.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/torch/__pycache__/torch_rl_module.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/torch/__pycache__/torch_rl_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1033f3403318c82968253ccbd7551de29d41ce45 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/torch/__pycache__/torch_rl_module.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/torch/torch_compile_config.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/torch/torch_compile_config.py new file mode 100644 index 0000000000000000000000000000000000000000..7c95c62fb52acab3c054c5cce7cc429046f1c84e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/torch/torch_compile_config.py @@ -0,0 +1,39 @@ +import sys +from dataclasses import dataclass, field + + +@dataclass +class TorchCompileConfig: + """Configuration options for RLlib's usage of torch.compile in RLModules. + + # On `torch.compile` in Torch RLModules + `torch.compile` invokes torch's dynamo JIT compiler that can potentially bring + speedups to RL Module's forward methods. + This is a performance optimization that should be disabled for debugging. + + General usage: + - Usually, you only want to `RLModule._forward_train` to be compiled on + instances of RLModule used for learning. (e.g. the learner) + - In some cases, it can bring speedups to also compile + `RLModule._forward_exploration` on instances used for exploration. (e.g. + RolloutWorker) + - In some cases, it can bring speedups to also compile + `RLModule._forward_inference` on instances used for inference. (e.g. + RolloutWorker) + + Note that different backends are available on different platforms. + Also note that the default backend for torch dynamo is "aot_eager" on macOS. + This is a debugging backend that is expected not to improve performance because + the inductor backend is not supported on OSX so far. 
+ + Args: + torch_dynamo_backend: The torch.dynamo backend to use. + torch_dynamo_mode: The torch.dynamo mode to use. + kwargs: Additional keyword arguments to pass to `torch.compile()` + """ + + torch_dynamo_backend: str = ( + "aot_eager" if sys.platform == "darwin" else "cudagraphs" + ) + torch_dynamo_mode: str = None + kwargs: dict = field(default_factory=lambda: dict()) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/torch/torch_rl_module.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/torch/torch_rl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..a9e479b156889ae3feaa75a850cf854844083ac7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/rl_module/torch/torch_rl_module.py @@ -0,0 +1,305 @@ +from typing import Any, Collection, Dict, Optional, Union, Type + +import gymnasium as gym +from packaging import version + +from ray.rllib.core.rl_module.apis import InferenceOnlyAPI +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.core.rl_module.torch.torch_compile_config import TorchCompileConfig +from ray.rllib.models.torch.torch_distributions import ( + TorchCategorical, + TorchDiagGaussian, + TorchDistribution, +) +from ray.rllib.utils.annotations import override, OverrideToImplementCustomLogic +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.torch_utils import ( + convert_to_torch_tensor, + TORCH_COMPILE_REQUIRED_VERSION, +) +from ray.rllib.utils.typing import StateDict + +torch, nn = try_import_torch() + + +class TorchRLModule(nn.Module, RLModule): + """A base class for RLlib PyTorch RLModules. 
+ + Note that the `_forward` methods of this class can be 'torch.compiled' individually: + - `TorchRLModule._forward_train()` + - `TorchRLModule._forward_inference()` + - `TorchRLModule._forward_exploration()` + + As a rule of thumb, they should only contain torch-native tensor manipulations, + or otherwise they may yield wrong outputs. In particular, the creation of RLlib + distributions inside these methods should be avoided when using `torch.compile`. + When in doubt, you can use `torch.dynamo.explain()` to check whether a compiled + method has broken up into multiple sub-graphs. + + Compiling these methods can bring speedups under certain conditions. + """ + + framework: str = "torch" + + # Stick with torch default. + STATE_FILE_NAME = "module_state" + + def __init__(self, *args, **kwargs) -> None: + nn.Module.__init__(self) + RLModule.__init__(self, *args, **kwargs) + + # If an inference-only class AND self.inference_only is True, + # remove all attributes that are returned by + # `self.get_non_inference_attributes()`. + if self.inference_only and isinstance(self, InferenceOnlyAPI): + for attr in self.get_non_inference_attributes(): + parts = attr.split(".") + if not hasattr(self, parts[0]): + continue + target_name = parts[0] + target_obj = getattr(self, target_name) + # Traverse from the next part on (if nested). + for part in parts[1:]: + if not hasattr(target_obj, part): + target_obj = None + break + target_name = part + target_obj = getattr(target_obj, target_name) + # Delete, if target is valid. + if target_obj is not None: + delattr(self, target_name) + + def compile(self, compile_config: TorchCompileConfig): + """Compile the forward methods of this module. + + This is a convenience method that calls `compile_wrapper` with the given + compile_config. + + Args: + compile_config: The compile config to use. 
+ """ + return compile_wrapper(self, compile_config) + + @OverrideToImplementCustomLogic + def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + # By default, calls the generic `_forward()` method, but with a no-grad context + # for performance reasons. + with torch.no_grad(): + return self._forward(batch, **kwargs) + + @OverrideToImplementCustomLogic + def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + # By default, calls the generic `_forward()` method, but with a no-grad context + # for performance reasons. + with torch.no_grad(): + return self._forward(batch, **kwargs) + + @OverrideToImplementCustomLogic + @override(RLModule) + def get_state( + self, + components: Optional[Union[str, Collection[str]]] = None, + *, + not_components: Optional[Union[str, Collection[str]]] = None, + inference_only: bool = False, + **kwargs, + ) -> StateDict: + state_dict = self.state_dict() + # Filter out `inference_only` keys from the state dict if `inference_only` and + # this RLModule is NOT `inference_only` (but does implement the + # InferenceOnlyAPI). + if ( + inference_only + and not self.inference_only + and isinstance(self, InferenceOnlyAPI) + ): + attr = self.get_non_inference_attributes() + for key in list(state_dict.keys()): + if any( + key.startswith(a) and (len(key) == len(a) or key[len(a)] == ".") + for a in attr + ): + del state_dict[key] + return convert_to_numpy(state_dict) + + @OverrideToImplementCustomLogic + @override(RLModule) + def set_state(self, state: StateDict) -> None: + # If state contains more keys than `self.state_dict()`, then we simply ignore + # these keys (strict=False). This is most likely due to `state` coming from + # an `inference_only=False` RLModule, while `self` is an `inference_only=True` + # RLModule. 
+ self.load_state_dict(convert_to_torch_tensor(state), strict=False) + + @OverrideToImplementCustomLogic + @override(RLModule) + def get_inference_action_dist_cls(self) -> Type[TorchDistribution]: + if self.action_dist_cls is not None: + return self.action_dist_cls + elif isinstance(self.action_space, gym.spaces.Discrete): + return TorchCategorical + elif isinstance(self.action_space, gym.spaces.Box): + return TorchDiagGaussian + else: + raise ValueError( + f"Default action distribution for action space " + f"{self.action_space} not supported! Either set the " + f"`self.action_dist_cls` property in your RLModule's `setup()` method " + f"to a subclass of `ray.rllib.models.torch.torch_distributions." + f"TorchDistribution` or - if you need different distributions for " + f"inference and training - override the three methods: " + f"`get_inference_action_dist_cls`, `get_exploration_action_dist_cls`, " + f"and `get_train_action_dist_cls` in your RLModule." + ) + + @OverrideToImplementCustomLogic + @override(RLModule) + def get_exploration_action_dist_cls(self) -> Type[TorchDistribution]: + return self.get_inference_action_dist_cls() + + @OverrideToImplementCustomLogic + @override(RLModule) + def get_train_action_dist_cls(self) -> Type[TorchDistribution]: + return self.get_inference_action_dist_cls() + + def forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """DO NOT OVERRIDE! + + This is aliased to `self.forward_train` because Torch DDP requires a forward + method to be implemented for backpropagation to work. + + Instead, override: + `_forward()` to define a generic forward pass for all phases (exploration, + inference, training) + `_forward_inference()` to define the forward pass for action inference in + deployment/production (no exploration). + `_forward_exploration()` to define the forward pass for action inference during + training sample collection (w/ exploration behavior). 
+ `_forward_train()` to define the forward pass prior to loss computation. + """ + # TODO (sven): Experimental to make ONNX exported models work. + if self.config.inference_only: + return self.forward_exploration(batch, **kwargs) + else: + return self.forward_train(batch, **kwargs) + + +class TorchDDPRLModule(RLModule, nn.parallel.DistributedDataParallel): + def __init__(self, *args, **kwargs) -> None: + nn.parallel.DistributedDataParallel.__init__(self, *args, **kwargs) + # We do not want to call RLModule.__init__ here because all we need is + # the interface of that base-class not the actual implementation. + # RLModule.__init__(self, *args, **kwargs) + self.observation_space = self.unwrapped().observation_space + self.action_space = self.unwrapped().action_space + self.inference_only = self.unwrapped().inference_only + self.learner_only = self.unwrapped().learner_only + self.model_config = self.unwrapped().model_config + self.catalog = self.unwrapped().catalog + + # Deprecated. + self.config = self.unwrapped().config + + @override(RLModule) + def get_inference_action_dist_cls(self, *args, **kwargs) -> Type[TorchDistribution]: + return self.unwrapped().get_inference_action_dist_cls(*args, **kwargs) + + @override(RLModule) + def get_exploration_action_dist_cls( + self, *args, **kwargs + ) -> Type[TorchDistribution]: + return self.unwrapped().get_exploration_action_dist_cls(*args, **kwargs) + + @override(RLModule) + def get_train_action_dist_cls(self, *args, **kwargs) -> Type[TorchDistribution]: + return self.unwrapped().get_train_action_dist_cls(*args, **kwargs) + + @override(RLModule) + def get_initial_state(self) -> Any: + return self.unwrapped().get_initial_state() + + @override(RLModule) + def is_stateful(self) -> bool: + return self.unwrapped().is_stateful() + + @override(RLModule) + def _forward(self, *args, **kwargs): + return self.unwrapped()._forward(*args, **kwargs) + + @override(RLModule) + def _forward_inference(self, *args, **kwargs) -> Dict[str, 
Any]: + return self.unwrapped()._forward_inference(*args, **kwargs) + + @override(RLModule) + def _forward_exploration(self, *args, **kwargs) -> Dict[str, Any]: + return self.unwrapped()._forward_exploration(*args, **kwargs) + + @override(RLModule) + def _forward_train(self, *args, **kwargs): + return self(*args, **kwargs) + + @override(RLModule) + def get_state(self, *args, **kwargs): + return self.unwrapped().get_state(*args, **kwargs) + + @override(RLModule) + def set_state(self, *args, **kwargs): + self.unwrapped().set_state(*args, **kwargs) + + @override(RLModule) + def save_to_path(self, *args, **kwargs): + self.unwrapped().save_to_path(*args, **kwargs) + + @override(RLModule) + def restore_from_path(self, *args, **kwargs): + self.unwrapped().restore_from_path(*args, **kwargs) + + @override(RLModule) + def get_metadata(self, *args, **kwargs): + self.unwrapped().get_metadata(*args, **kwargs) + + @override(RLModule) + def unwrapped(self) -> "RLModule": + return self.module + + +def compile_wrapper(rl_module: "TorchRLModule", compile_config: TorchCompileConfig): + """A wrapper that compiles the forward methods of a TorchRLModule.""" + + # TODO(Artur): Remove this once our requirements enforce torch >= 2.0.0 + # Check if torch framework supports torch.compile. 
+ if ( + torch is not None + and version.parse(torch.__version__) < TORCH_COMPILE_REQUIRED_VERSION + ): + raise ValueError("torch.compile is only supported from torch 2.0.0") + + compiled_forward_train = torch.compile( + rl_module._forward_train, + backend=compile_config.torch_dynamo_backend, + mode=compile_config.torch_dynamo_mode, + **compile_config.kwargs, + ) + + rl_module._forward_train = compiled_forward_train + + compiled_forward_inference = torch.compile( + rl_module._forward_inference, + backend=compile_config.torch_dynamo_backend, + mode=compile_config.torch_dynamo_mode, + **compile_config.kwargs, + ) + + rl_module._forward_inference = compiled_forward_inference + + compiled_forward_exploration = torch.compile( + rl_module._forward_exploration, + backend=compile_config.torch_dynamo_backend, + mode=compile_config.torch_dynamo_mode, + **compile_config.kwargs, + ) + + rl_module._forward_exploration = compiled_forward_exploration + + return rl_module