diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/core/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bff33528c9af02fa036eb82c6c6833ceb59bff08 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/core/__init__.py @@ -0,0 +1,35 @@ +from ray.rllib.core.columns import Columns + + +DEFAULT_AGENT_ID = "default_agent" +DEFAULT_POLICY_ID = "default_policy" +# TODO (sven): Change this to "default_module" +DEFAULT_MODULE_ID = DEFAULT_POLICY_ID +ALL_MODULES = "__all_modules__" + +COMPONENT_ENV_RUNNER = "env_runner" +COMPONENT_ENV_TO_MODULE_CONNECTOR = "env_to_module_connector" +COMPONENT_EVAL_ENV_RUNNER = "eval_env_runner" +COMPONENT_LEARNER = "learner" +COMPONENT_LEARNER_GROUP = "learner_group" +COMPONENT_METRICS_LOGGER = "metrics_logger" +COMPONENT_MODULE_TO_ENV_CONNECTOR = "module_to_env_connector" +COMPONENT_OPTIMIZER = "optimizer" +COMPONENT_RL_MODULE = "rl_module" + + +__all__ = [ + "Columns", + "COMPONENT_ENV_RUNNER", + "COMPONENT_ENV_TO_MODULE_CONNECTOR", + "COMPONENT_EVAL_ENV_RUNNER", + "COMPONENT_LEARNER", + "COMPONENT_LEARNER_GROUP", + "COMPONENT_METRICS_LOGGER", + "COMPONENT_MODULE_TO_ENV_CONNECTOR", + "COMPONENT_OPTIMIZER", + "COMPONENT_RL_MODULE", + "DEFAULT_AGENT_ID", + "DEFAULT_MODULE_ID", + "DEFAULT_POLICY_ID", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..623eb2f7ed4c915ac2e6412dd52665ea635633d0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/__init__.py @@ -0,0 +1,37 @@ +from ray.rllib.env.base_env import BaseEnv +from ray.rllib.env.env_context import EnvContext +from ray.rllib.env.external_env import ExternalEnv +from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv +from ray.rllib.env.multi_agent_env import MultiAgentEnv +from 
ray.rllib.env.policy_client import PolicyClient +from ray.rllib.env.policy_server_input import PolicyServerInput +from ray.rllib.env.remote_base_env import RemoteBaseEnv +from ray.rllib.env.vector_env import VectorEnv + +from ray.rllib.env.wrappers.dm_env_wrapper import DMEnv +from ray.rllib.env.wrappers.dm_control_wrapper import DMCEnv +from ray.rllib.env.wrappers.group_agents_wrapper import GroupAgentsWrapper +from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv +from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv +from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv + +INPUT_ENV_SPACES = "__env__" + +__all__ = [ + "BaseEnv", + "DMEnv", + "DMCEnv", + "EnvContext", + "ExternalEnv", + "ExternalMultiAgentEnv", + "GroupAgentsWrapper", + "MultiAgentEnv", + "PettingZooEnv", + "ParallelPettingZooEnv", + "PolicyClient", + "PolicyServerInput", + "RemoteBaseEnv", + "Unity3DEnv", + "VectorEnv", + "INPUT_ENV_SPACES", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efb8fefeb8ca9e929df58153b7ea672419b08f85 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/base_env.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/base_env.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98e6b38da63061b7f649381e8af9b3638aa35d34 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/base_env.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_context.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_context.cpython-311.pyc new file 
mode 100644 index 0000000000000000000000000000000000000000..6e16bb25a18306d76037ced49322faf33e7364d6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_context.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_runner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_runner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d08a2c3df976cb4feab415eb2df3ddf8602343c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_runner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_runner_group.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_runner_group.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25bfee9f8a15084b40b02b911fcaba0f6807e2fb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_runner_group.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/external_env.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/external_env.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aced12690ea3f2c149c922c4ada0c3ade447e5d1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/external_env.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/external_multi_agent_env.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/external_multi_agent_env.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a1672ea29e4abf5748aa58555343eb23ab96b25 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/external_multi_agent_env.cpython-311.pyc differ diff 
--git a/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_env.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_env.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d988d0db6b7abc8f22fa1a5567da981d2b55e4bd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_env.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_env_runner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_env_runner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a6cd6cf3890ea67fa1eac1af2d53ab8460b2ef4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_env_runner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/policy_client.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/policy_client.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..694d077e9eb7d01144d199bfaa6c73e34b84876b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/policy_client.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/policy_server_input.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/policy_server_input.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5fa41828738d100d52ed83899658e7e7a01c61a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/policy_server_input.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/remote_base_env.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/remote_base_env.cpython-311.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..2f0b2f5244ec22fb732735eb97f05abc53cb9c39 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/remote_base_env.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/single_agent_env_runner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/single_agent_env_runner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93f126822b358f398cdd19ada3239ed88a785864 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/single_agent_env_runner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/single_agent_episode.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/single_agent_episode.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a828b7985e5fc7fa9a2cc554eca74173adccf42a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/single_agent_episode.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/tcp_client_inference_env_runner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/tcp_client_inference_env_runner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eda13c5e3e46b22e4712a5a9446ce0f1af22f765 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/tcp_client_inference_env_runner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/vector_env.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/vector_env.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4851ce29508a926de2252be06dea52e82c25153f Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/vector_env.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/base_env.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/base_env.py new file mode 100644 index 0000000000000000000000000000000000000000..c67c642e4763f819e1b6143f362cab8ad1aa8472 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/base_env.py @@ -0,0 +1,428 @@ +import logging +from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING, Union, Set + +import gymnasium as gym +import ray +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiEnvDict + +if TYPE_CHECKING: + from ray.rllib.evaluation.rollout_worker import RolloutWorker + +ASYNC_RESET_RETURN = "async_reset_return" + +logger = logging.getLogger(__name__) + + +@OldAPIStack +class BaseEnv: + """The lowest-level env interface used by RLlib for sampling. + + BaseEnv models multiple agents executing asynchronously in multiple + vectorized sub-environments. A call to `poll()` returns observations from + ready agents keyed by their sub-environment ID and agent IDs, and + actions for those agents can be sent back via `send_actions()`. + + All other RLlib supported env types can be converted to BaseEnv. + RLlib handles these conversions internally in RolloutWorker, for example: + + gym.Env => rllib.VectorEnv => rllib.BaseEnv + rllib.MultiAgentEnv (is-a gym.Env) => rllib.VectorEnv => rllib.BaseEnv + rllib.ExternalEnv => rllib.BaseEnv + + .. testcode:: + :skipif: True + + MyBaseEnv = ... + env = MyBaseEnv() + obs, rewards, terminateds, truncateds, infos, off_policy_actions = ( + env.poll() + ) + print(obs) + + env.send_actions({ + "env_0": { + "car_0": 0, + "car_1": 1, + }, ... + }) + obs, rewards, terminateds, truncateds, infos, off_policy_actions = ( + env.poll() + ) + print(obs) + + print(terminateds) + + .. 
testoutput:: + + { + "env_0": { + "car_0": [2.4, 1.6], + "car_1": [3.4, -3.2], + }, + "env_1": { + "car_0": [8.0, 4.1], + }, + "env_2": { + "car_0": [2.3, 3.3], + "car_1": [1.4, -0.2], + "car_3": [1.2, 0.1], + }, + } + { + "env_0": { + "car_0": [4.1, 1.7], + "car_1": [3.2, -4.2], + }, ... + } + { + "env_0": { + "__all__": False, + "car_0": False, + "car_1": True, + }, ... + } + + """ + + def to_base_env( + self, + make_env: Optional[Callable[[int], EnvType]] = None, + num_envs: int = 1, + remote_envs: bool = False, + remote_env_batch_wait_ms: int = 0, + restart_failed_sub_environments: bool = False, + ) -> "BaseEnv": + """Converts an RLlib-supported env into a BaseEnv object. + + Supported types for the `env` arg are gym.Env, BaseEnv, + VectorEnv, MultiAgentEnv, ExternalEnv, or ExternalMultiAgentEnv. + + The resulting BaseEnv is always vectorized (contains n + sub-environments) to support batched forward passes, where n may also + be 1. BaseEnv also supports async execution via the `poll` and + `send_actions` methods and thus supports external simulators. + + TODO: Support gym3 environments, which are already vectorized. + + Args: + env: An already existing environment of any supported env type + to convert/wrap into a BaseEnv. Supported types are gym.Env, + BaseEnv, VectorEnv, MultiAgentEnv, ExternalEnv, and + ExternalMultiAgentEnv. + make_env: A callable taking an int as input (which indicates the + number of individual sub-environments within the final + vectorized BaseEnv) and returning one individual + sub-environment. + num_envs: The number of sub-environments to create in the + resulting (vectorized) BaseEnv. The already existing `env` + will be one of the `num_envs`. + remote_envs: Whether each sub-env should be a @ray.remote actor. + You can set this behavior in your config via the + `remote_worker_envs=True` option. + remote_env_batch_wait_ms: The wait time (in ms) to poll remote + sub-environments for, if applicable. Only used if + `remote_envs` is True. 
+ policy_config: Optional policy config dict. + + Returns: + The resulting BaseEnv object. + """ + return self + + def poll( + self, + ) -> Tuple[ + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + ]: + """Returns observations from ready agents. + + All return values are two-level dicts mapping from EnvID to dicts + mapping from AgentIDs to (observation/reward/etc..) values. + The number of agents and sub-environments may vary over time. + + Returns: + Tuple consisting of: + New observations for each ready agent. + Reward values for each ready agent. If the episode is just started, + the value will be None. + Terminated values for each ready agent. The special key "__all__" is used to + indicate episode termination. + Truncated values for each ready agent. The special key "__all__" + is used to indicate episode truncation. + Info values for each ready agent. + Agents may take off-policy actions, in which case, there will be an entry + in this dict that contains the taken action. There is no need to + `send_actions()` for agents that have already chosen off-policy actions. + """ + raise NotImplementedError + + def send_actions(self, action_dict: MultiEnvDict) -> None: + """Called to send actions back to running agents in this env. + + Actions should be sent for each ready agent that returned observations + in the previous poll() call. + + Args: + action_dict: Actions values keyed by env_id and agent_id. + """ + raise NotImplementedError + + def try_reset( + self, + env_id: Optional[EnvID] = None, + *, + seed: Optional[int] = None, + options: Optional[dict] = None, + ) -> Tuple[Optional[MultiEnvDict], Optional[MultiEnvDict]]: + """Attempt to reset the sub-env with the given id or all sub-envs. + + If the environment does not support synchronous reset, a tuple of + (ASYNC_RESET_REQUEST, ASYNC_RESET_REQUEST) can be returned here. 
+ + Note: A MultiAgentDict is returned when using the deprecated wrapper + classes such as `ray.rllib.env.base_env._MultiAgentEnvToBaseEnv`, + however for consistency with the poll() method, a `MultiEnvDict` is + returned from the new wrapper classes, such as + `ray.rllib.env.multi_agent_env.MultiAgentEnvWrapper`. + + Args: + env_id: The sub-environment's ID if applicable. If None, reset + the entire Env (i.e. all sub-environments). + seed: The seed to be passed to the sub-environment(s) when + resetting it. If None, will not reset any existing PRNG. If you pass an + integer, the PRNG will be reset even if it already exists. + options: An options dict to be passed to the sub-environment(s) when + resetting it. + + Returns: + A tuple consisting of a) the reset (multi-env/multi-agent) observation + dict and b) the reset (multi-env/multi-agent) infos dict. Returns the + (ASYNC_RESET_REQUEST, ASYNC_RESET_REQUEST) tuple, if not supported. + """ + return None, None + + def try_restart(self, env_id: Optional[EnvID] = None) -> None: + """Attempt to restart the sub-env with the given id or all sub-envs. + + This could result in the sub-env being completely removed (gc'd) and recreated. + + Args: + env_id: The sub-environment's ID, if applicable. If None, restart + the entire Env (i.e. all sub-environments). + """ + return None + + def get_sub_environments(self, as_dict: bool = False) -> Union[List[EnvType], dict]: + """Return a reference to the underlying sub environments, if any. + + Args: + as_dict: If True, return a dict mapping from env_id to env. + + Returns: + List or dictionary of the underlying sub environments or [] / {}. + """ + if as_dict: + return {} + return [] + + def get_agent_ids(self) -> Set[AgentID]: + """Return the agent ids for the sub_environment. + + Returns: + All agent ids for each the environment. + """ + return {} + + def try_render(self, env_id: Optional[EnvID] = None) -> None: + """Tries to render the sub-environment with the given id or all. 
+ + Args: + env_id: The sub-environment's ID, if applicable. + If None, renders the entire Env (i.e. all sub-environments). + """ + + # By default, do nothing. + pass + + def stop(self) -> None: + """Releases all resources used.""" + + # Try calling `close` on all sub-environments. + for env in self.get_sub_environments(): + if hasattr(env, "close"): + env.close() + + @property + def observation_space(self) -> gym.Space: + """Returns the observation space for each agent. + + Note: samples from the observation space need to be preprocessed into a + `MultiEnvDict` before being used by a policy. + + Returns: + The observation space for each environment. + """ + raise NotImplementedError + + @property + def action_space(self) -> gym.Space: + """Returns the action space for each agent. + + Note: samples from the action space need to be preprocessed into a + `MultiEnvDict` before being passed to `send_actions`. + + Returns: + The observation space for each environment. + """ + raise NotImplementedError + + def last( + self, + ) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]: + """Returns the last observations, rewards, done- truncated flags and infos ... + + that were returned by the environment. + + Returns: + The last observations, rewards, done- and truncated flags, and infos + for each sub-environment. + """ + logger.warning("last has not been implemented for this environment.") + return {}, {}, {}, {}, {} + + +# Fixed agent identifier when there is only the single agent in the env +_DUMMY_AGENT_ID = "agent0" + + +@OldAPIStack +def with_dummy_agent_id( + env_id_to_values: Dict[EnvID, Any], dummy_id: "AgentID" = _DUMMY_AGENT_ID +) -> MultiEnvDict: + ret = {} + for (env_id, value) in env_id_to_values.items(): + # If the value (e.g. the observation) is an Exception, publish this error + # under the env ID so the caller of `poll()` knows that the entire episode + # (sub-environment) has crashed. 
+ ret[env_id] = value if isinstance(value, Exception) else {dummy_id: value} + return ret + + +@OldAPIStack +def convert_to_base_env( + env: EnvType, + make_env: Callable[[int], EnvType] = None, + num_envs: int = 1, + remote_envs: bool = False, + remote_env_batch_wait_ms: int = 0, + worker: Optional["RolloutWorker"] = None, + restart_failed_sub_environments: bool = False, +) -> "BaseEnv": + """Converts an RLlib-supported env into a BaseEnv object. + + Supported types for the `env` arg are gym.Env, BaseEnv, + VectorEnv, MultiAgentEnv, ExternalEnv, or ExternalMultiAgentEnv. + + The resulting BaseEnv is always vectorized (contains n + sub-environments) to support batched forward passes, where n may also + be 1. BaseEnv also supports async execution via the `poll` and + `send_actions` methods and thus supports external simulators. + + TODO: Support gym3 environments, which are already vectorized. + + Args: + env: An already existing environment of any supported env type + to convert/wrap into a BaseEnv. Supported types are gym.Env, + BaseEnv, VectorEnv, MultiAgentEnv, ExternalEnv, and + ExternalMultiAgentEnv. + make_env: A callable taking an int as input (which indicates the + number of individual sub-environments within the final + vectorized BaseEnv) and returning one individual + sub-environment. + num_envs: The number of sub-environments to create in the + resulting (vectorized) BaseEnv. The already existing `env` + will be one of the `num_envs`. + remote_envs: Whether each sub-env should be a @ray.remote actor. + You can set this behavior in your config via the + `remote_worker_envs=True` option. + remote_env_batch_wait_ms: The wait time (in ms) to poll remote + sub-environments for, if applicable. Only used if + `remote_envs` is True. + worker: An optional RolloutWorker that owns the env. This is only + used if `remote_worker_envs` is True in your config and the + `on_sub_environment_created` custom callback needs to be called + on each created actor. 
+ restart_failed_sub_environments: If True and any sub-environment (within + a vectorized env) throws any error during env stepping, the + Sampler will try to restart the faulty sub-environment. This is done + without disturbing the other (still intact) sub-environment and without + the RolloutWorker crashing. + + Returns: + The resulting BaseEnv object. + """ + + from ray.rllib.env.remote_base_env import RemoteBaseEnv + from ray.rllib.env.external_env import ExternalEnv + from ray.rllib.env.multi_agent_env import MultiAgentEnv + from ray.rllib.env.vector_env import VectorEnv, VectorEnvWrapper + + if remote_envs and num_envs == 1: + raise ValueError( + "Remote envs only make sense to use if num_envs > 1 " + "(i.e. environment vectorization is enabled)." + ) + + # Given `env` has a `to_base_env` method -> Call that to convert to a BaseEnv type. + if isinstance(env, (BaseEnv, MultiAgentEnv, VectorEnv, ExternalEnv)): + return env.to_base_env( + make_env=make_env, + num_envs=num_envs, + remote_envs=remote_envs, + remote_env_batch_wait_ms=remote_env_batch_wait_ms, + restart_failed_sub_environments=restart_failed_sub_environments, + ) + # `env` is not a BaseEnv yet -> Need to convert/vectorize. + else: + # Sub-environments are ray.remote actors: + if remote_envs: + # Determine, whether the already existing sub-env (could + # be a ray.actor) is multi-agent or not. + multiagent = ( + ray.get(env._is_multi_agent.remote()) + if hasattr(env, "_is_multi_agent") + else False + ) + env = RemoteBaseEnv( + make_env, + num_envs, + multiagent=multiagent, + remote_env_batch_wait_ms=remote_env_batch_wait_ms, + existing_envs=[env], + worker=worker, + restart_failed_sub_environments=restart_failed_sub_environments, + ) + # Sub-environments are not ray.remote actors. + else: + # Convert gym.Env to VectorEnv ... 
+ env = VectorEnv.vectorize_gym_envs( + make_env=make_env, + existing_envs=[env], + num_envs=num_envs, + action_space=env.action_space, + observation_space=env.observation_space, + restart_failed_sub_environments=restart_failed_sub_environments, + ) + # ... then the resulting VectorEnv to a BaseEnv. + env = VectorEnvWrapper(env) + + # Make sure conversion went well. + assert isinstance(env, BaseEnv), env + + return env diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/env_context.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/env_context.py new file mode 100644 index 0000000000000000000000000000000000000000..296246fe638c1915205f569247e96d79975b25a1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/env_context.py @@ -0,0 +1,128 @@ +import copy +from typing import Optional + +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.typing import EnvConfigDict + + +@OldAPIStack +class EnvContext(dict): + """Wraps env configurations to include extra rllib metadata. + + These attributes can be used to parameterize environments per process. + For example, one might use `worker_index` to control which data file an + environment reads in on initialization. + + RLlib auto-sets these attributes when constructing registered envs. + """ + + def __init__( + self, + env_config: EnvConfigDict, + worker_index: int, + vector_index: int = 0, + remote: bool = False, + num_workers: Optional[int] = None, + recreated_worker: bool = False, + ): + """Initializes an EnvContext instance. + + Args: + env_config: The env's configuration defined under the + "env_config" key in the Algorithm's config. + worker_index: When there are multiple workers created, this + uniquely identifies the worker the env is created in. + 0 for local worker, >0 for remote workers. + vector_index: When there are multiple envs per worker, this + uniquely identifies the env index within the worker. + Starts from 0. 
+ remote: Whether individual sub-environments (in a vectorized + env) should be @ray.remote actors or not. + num_workers: The total number of (remote) workers in the set. + 0 if only a local worker exists. + recreated_worker: Whether the worker that holds this env is a recreated one. + This means that it replaced a previous (failed) worker when + `restart_failed_env_runners=True` in the Algorithm's config. + """ + # Store the env_config in the (super) dict. + dict.__init__(self, env_config) + + # Set some metadata attributes. + self.worker_index = worker_index + self.vector_index = vector_index + self.remote = remote + self.num_workers = num_workers + self.recreated_worker = recreated_worker + + def copy_with_overrides( + self, + env_config: Optional[EnvConfigDict] = None, + worker_index: Optional[int] = None, + vector_index: Optional[int] = None, + remote: Optional[bool] = None, + num_workers: Optional[int] = None, + recreated_worker: Optional[bool] = None, + ) -> "EnvContext": + """Returns a copy of this EnvContext with some attributes overridden. + + Args: + env_config: Optional env config to use. None for not overriding + the one from the source (self). + worker_index: Optional worker index to use. None for not + overriding the one from the source (self). + vector_index: Optional vector index to use. None for not + overriding the one from the source (self). + remote: Optional remote setting to use. None for not overriding + the one from the source (self). + num_workers: Optional num_workers to use. None for not overriding + the one from the source (self). + recreated_worker: Optional flag, indicating, whether the worker that holds + the env is a recreated one. This means that it replaced a previous + (failed) worker when `restart_failed_env_runners=True` in the + Algorithm's config. + + Returns: + A new EnvContext object as a copy of self plus the provided + overrides. 
+ """ + return EnvContext( + copy.deepcopy(env_config) if env_config is not None else self, + worker_index if worker_index is not None else self.worker_index, + vector_index if vector_index is not None else self.vector_index, + remote if remote is not None else self.remote, + num_workers if num_workers is not None else self.num_workers, + recreated_worker if recreated_worker is not None else self.recreated_worker, + ) + + def set_defaults(self, defaults: dict) -> None: + """Sets missing keys of self to the values given in `defaults`. + + If `defaults` contains keys that already exist in self, don't override + the values with these defaults. + + Args: + defaults: The key/value pairs to add to self, but only for those + keys in `defaults` that don't exist yet in self. + + .. testcode:: + :skipif: True + + from ray.rllib.env.env_context import EnvContext + env_ctx = EnvContext({"a": 1, "b": 2}, worker_index=0) + env_ctx.set_defaults({"a": -42, "c": 3}) + print(env_ctx) + + .. testoutput:: + + {"a": 1, "b": 2, "c": 3} + """ + for key, value in defaults.items(): + if key not in self: + self[key] = value + + def __str__(self): + return ( + super().__str__()[:-1] + + f", worker={self.worker_index}/{self.num_workers}, " + f"vector_idx={self.vector_index}, remote={self.remote}" + "}" + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/env_runner.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/env_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..6129ed40cf19797fe1adb6f4464f724edd19d0e4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/env_runner.py @@ -0,0 +1,187 @@ +import abc +import logging +from typing import Any, Dict, Tuple, TYPE_CHECKING + +import gymnasium as gym +import tree # pip install dm_tree + +from ray.rllib.utils.actor_manager import FaultAwareApply +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.torch_utils import convert_to_torch_tensor +from 
ray.rllib.utils.typing import TensorType +from ray.util.annotations import PublicAPI + +if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + +logger = logging.getLogger("ray.rllib") + +tf1, tf, _ = try_import_tf() + +ENV_RESET_FAILURE = "env_reset_failure" +ENV_STEP_FAILURE = "env_step_failure" + + +# TODO (sven): As soon as RolloutWorker is no longer supported, make this base class +# a Checkpointable. Currently, only some of its subclasses are Checkpointables. +@PublicAPI(stability="alpha") +class EnvRunner(FaultAwareApply, metaclass=abc.ABCMeta): + """Base class for distributed RL-style data collection from an environment. + + The EnvRunner API's core functionalities can be summarized as: + - Gets configured via passing a AlgorithmConfig object to the constructor. + Normally, subclasses of EnvRunner then construct their own environment (possibly + vectorized) copies and RLModules/Policies and use the latter to step through the + environment in order to collect training data. + - Clients of EnvRunner can use the `sample()` method to collect data for training + from the environment(s). + - EnvRunner offers parallelism via creating n remote Ray Actors based on this class. + Use `ray.remote([resources])(EnvRunner)` method to create the corresponding Ray + remote class. Then instantiate n Actors using the Ray `[ctor].remote(...)` syntax. + - EnvRunner clients can get information about the server/node on which the + individual Actors are running. + """ + + def __init__(self, *, config: "AlgorithmConfig", **kwargs): + """Initializes an EnvRunner instance. + + Args: + config: The AlgorithmConfig to use to setup this EnvRunner. + **kwargs: Forward compatibility kwargs. + """ + self.config = config.copy(copy_frozen=False) + self.env = None + + super().__init__(**kwargs) + + # This eager check is necessary for certain all-framework tests + # that use tf's eager_mode() context generator. 
+ if ( + tf1 + and (self.config.framework_str == "tf2" or config.enable_tf1_exec_eagerly) + and not tf1.executing_eagerly() + ): + tf1.enable_eager_execution() + + @abc.abstractmethod + def assert_healthy(self): + """Checks that self.__init__() has been completed properly. + + Useful in case an `EnvRunner` is run as @ray.remote (Actor) and the owner + would like to make sure the Ray Actor has been properly initialized. + + Raises: + AssertionError: If the EnvRunner Actor has NOT been properly initialized. + """ + + # TODO: Make this an abstract method that must be implemented. + def make_env(self): + """Creates the RL environment for this EnvRunner and assigns it to `self.env`. + + Note that users should be able to change the EnvRunner's config (e.g. change + `self.config.env_config`) and then call this method to create new environments + with the updated configuration. + It should also be called after a failure of an earlier env in order to clean up + the existing env (for example `close()` it), re-create a new one, and then + continue sampling with that new env. + """ + pass + + # TODO: Make this an abstract method that must be implemented. + def make_module(self): + """Creates the RLModule for this EnvRunner and assigns it to `self.module`. + + Note that users should be able to change the EnvRunner's config (e.g. change + `self.config.rl_module_spec`) and then call this method to create a new RLModule + with the updated configuration. + """ + pass + + @abc.abstractmethod + def sample(self, **kwargs) -> Any: + """Returns experiences (of any form) sampled from this EnvRunner. + + The exact nature and size of collected data are defined via the EnvRunner's + config and may be overridden by the given arguments. + + Args: + **kwargs: Forward compatibility kwargs. + + Returns: + The collected experience in any form. + """ + + # TODO (sven): Make this an abstract method that must be overridden. 
+ def get_metrics(self) -> Any: + """Returns metrics (in any form) of the thus far collected, completed episodes. + + Returns: + Metrics of any form. + """ + pass + + @abc.abstractmethod + def get_spaces(self) -> Dict[str, Tuple[gym.Space, gym.Space]]: + """Returns a dict mapping ModuleIDs to 2-tuples of obs- and action space.""" + + def stop(self) -> None: + """Releases all resources used by this EnvRunner. + + For example, when using a gym.Env in this EnvRunner, you should make sure + that its `close()` method is called. + """ + pass + + def __del__(self) -> None: + """If this Actor is deleted, clears all resources used by it.""" + pass + + def _try_env_reset(self): + """Tries resetting the env and - if an error orrurs - handles it gracefully.""" + # Try to reset. + try: + obs, infos = self.env.reset() + # Everything ok -> return. + return obs, infos + # Error. + except Exception as e: + # If user wants to simply restart the env -> recreate env and try again + # (calling this method recursively until success). + if self.config.restart_failed_sub_environments: + logger.exception( + "Resetting the env resulted in an error! The original error " + f"is: {e.args[0]}" + ) + # Recreate the env and simply try again. + self.make_env() + return self._try_env_reset() + else: + raise e + + def _try_env_step(self, actions): + """Tries stepping the env and - if an error orrurs - handles it gracefully.""" + try: + results = self.env.step(actions) + return results + except Exception as e: + if self.config.restart_failed_sub_environments: + logger.exception( + "Stepping the env resulted in an error! The original error " + f"is: {e.args[0]}" + ) + # Recreate the env. + self.make_env() + # And return that the stepping failed. The caller will then handle + # specific cleanup operations (for example discarding thus-far collected + # data and repeating the step attempt). 
+ return ENV_STEP_FAILURE + else: + raise e + + def _convert_to_tensor(self, struct) -> TensorType: + """Converts structs to a framework-specific tensor.""" + + if self.config.framework_str == "torch": + return convert_to_torch_tensor(struct) + else: + return tree.map_structure(tf.convert_to_tensor, struct) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/env_runner_group.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/env_runner_group.py new file mode 100644 index 0000000000000000000000000000000000000000..8e1e53e534dfdd153874196f29ee8c24badee891 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/env_runner_group.py @@ -0,0 +1,1262 @@ +import functools +import gymnasium as gym +import logging +import importlib.util +import os +from typing import ( + Any, + Callable, + Collection, + Dict, + List, + Optional, + Tuple, + Type, + TYPE_CHECKING, + TypeVar, + Union, +) + +import ray +from ray.actor import ActorHandle +from ray.exceptions import RayActorError +from ray.rllib.core import ( + COMPONENT_ENV_TO_MODULE_CONNECTOR, + COMPONENT_LEARNER, + COMPONENT_MODULE_TO_ENV_CONNECTOR, + COMPONENT_RL_MODULE, +) +from ray.rllib.core.learner import LearnerGroup +from ray.rllib.core.rl_module import validate_module_id +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.evaluation.rollout_worker import RolloutWorker +from ray.rllib.env.base_env import BaseEnv +from ray.rllib.env.env_context import EnvContext +from ray.rllib.env.env_runner import EnvRunner +from ray.rllib.offline import get_dataset_and_shards +from ray.rllib.policy.policy import Policy, PolicyState +from ray.rllib.utils.actor_manager import FaultTolerantActorManager +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.deprecation import ( + Deprecated, + deprecation_warning, + DEPRECATED_VALUE, +) +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED_LIFETIME, WEIGHTS_SEQ_NO 
+from ray.rllib.utils.typing import ( + AgentID, + EnvCreator, + EnvType, + EpisodeID, + PartialAlgorithmConfigDict, + PolicyID, + SampleBatchType, + TensorType, +) +from ray.util.annotations import DeveloperAPI + +if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + +tf1, tf, tfv = try_import_tf() + +logger = logging.getLogger(__name__) + +# Generic type var for foreach_* methods. +T = TypeVar("T") + + +@DeveloperAPI +class EnvRunnerGroup: + """Set of EnvRunners with n @ray.remote workers and zero or one local worker. + + Where: n >= 0. + """ + + def __init__( + self, + *, + env_creator: Optional[EnvCreator] = None, + validate_env: Optional[Callable[[EnvType], None]] = None, + default_policy_class: Optional[Type[Policy]] = None, + config: Optional["AlgorithmConfig"] = None, + local_env_runner: bool = True, + logdir: Optional[str] = None, + _setup: bool = True, + tune_trial_id: Optional[str] = None, + # Deprecated args. + num_env_runners: Optional[int] = None, + num_workers=DEPRECATED_VALUE, + local_worker=DEPRECATED_VALUE, + ): + """Initializes a EnvRunnerGroup instance. + + Args: + env_creator: Function that returns env given env config. + validate_env: Optional callable to validate the generated + environment (only on worker=0). This callable should raise + an exception if the environment is invalid. + default_policy_class: An optional default Policy class to use inside + the (multi-agent) `policies` dict. In case the PolicySpecs in there + have no class defined, use this `default_policy_class`. + If None, PolicySpecs will be using the Algorithm's default Policy + class. + config: Optional AlgorithmConfig (or config dict). + local_env_runner: Whether to create a local (non @ray.remote) EnvRunner + in the returned set as well (default: True). If `num_env_runners` + is 0, always create a local EnvRunner. + logdir: Optional logging directory for workers. + _setup: Whether to actually set up workers. This is only for testing. 
+ tune_trial_id: The Ray Tune trial ID, if this EnvRunnerGroup is part of + an Algorithm run as a Tune trial. None, otherwise. + """ + if num_workers != DEPRECATED_VALUE or local_worker != DEPRECATED_VALUE: + deprecation_warning( + old="WorkerSet(num_workers=..., local_worker=...)", + new="EnvRunnerGroup(num_env_runners=..., local_env_runner=...)", + error=True, + ) + + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + + # Make sure `config` is an AlgorithmConfig object. + if not config: + config = AlgorithmConfig() + elif isinstance(config, dict): + config = AlgorithmConfig.from_dict(config) + + self._env_creator = env_creator + self._policy_class = default_policy_class + self._remote_config = config + self._remote_args = { + "num_cpus": self._remote_config.num_cpus_per_env_runner, + "num_gpus": self._remote_config.num_gpus_per_env_runner, + "resources": self._remote_config.custom_resources_per_env_runner, + "max_restarts": ( + config.max_num_env_runner_restarts + if config.restart_failed_env_runners + else 0 + ), + } + self._tune_trial_id = tune_trial_id + + # Set the EnvRunner subclass to be used as "workers". Default: RolloutWorker. + self.env_runner_cls = config.env_runner_cls + if self.env_runner_cls is None: + if config.enable_env_runner_and_connector_v2: + # If experiences should be recorded, use the ` + # OfflineSingleAgentEnvRunner`. + if config.output: + # No multi-agent support. + if config.is_multi_agent: + raise ValueError("Multi-agent recording is not supported, yet.") + # Otherwise, load the single-agent env runner for + # recording. 
+ else: + from ray.rllib.offline.offline_env_runner import ( + OfflineSingleAgentEnvRunner, + ) + + self.env_runner_cls = OfflineSingleAgentEnvRunner + else: + if config.is_multi_agent: + from ray.rllib.env.multi_agent_env_runner import ( + MultiAgentEnvRunner, + ) + + self.env_runner_cls = MultiAgentEnvRunner + else: + from ray.rllib.env.single_agent_env_runner import ( + SingleAgentEnvRunner, + ) + + self.env_runner_cls = SingleAgentEnvRunner + else: + self.env_runner_cls = RolloutWorker + self._cls = ray.remote(**self._remote_args)(self.env_runner_cls).remote + + self._logdir = logdir + self._ignore_ray_errors_on_env_runners = ( + config.ignore_env_runner_failures or config.restart_failed_env_runners + ) + + # Create remote worker manager. + # ID=0 is used by the local worker. + # Starting remote workers from ID=1 to avoid conflicts. + self._worker_manager = FaultTolerantActorManager( + max_remote_requests_in_flight_per_actor=( + config.max_requests_in_flight_per_env_runner + ), + init_id=1, + ) + + if _setup: + try: + self._setup( + validate_env=validate_env, + config=config, + num_env_runners=( + num_env_runners + if num_env_runners is not None + else config.num_env_runners + ), + local_env_runner=local_env_runner, + ) + # EnvRunnerGroup creation possibly fails, if some (remote) workers cannot + # be initialized properly (due to some errors in the EnvRunners's + # constructor). + except RayActorError as e: + # In case of an actor (remote worker) init failure, the remote worker + # may still exist and will be accessible, however, e.g. calling + # its `sample.remote()` would result in strange "property not found" + # errors. + if e.actor_init_failed: + # Raise the original error here that the EnvRunners raised + # during its construction process. This is to enforce transparency + # for the user (better to understand the real reason behind the + # failure). + # - e.args[0]: The RayTaskError (inside the caught RayActorError). 
+ # - e.args[0].args[2]: The original Exception (e.g. a ValueError due + # to a config mismatch) thrown inside the actor. + raise e.args[0].args[2] + # In any other case, raise the RayActorError as-is. + else: + raise e + + def _setup( + self, + *, + validate_env: Optional[Callable[[EnvType], None]] = None, + config: Optional["AlgorithmConfig"] = None, + num_env_runners: int = 0, + local_env_runner: bool = True, + ): + """Sets up an EnvRunnerGroup instance. + Args: + validate_env: Optional callable to validate the generated + environment (only on worker=0). + config: Optional dict that extends the common config of + the Algorithm class. + num_env_runners: Number of remote EnvRunner workers to create. + local_env_runner: Whether to create a local (non @ray.remote) EnvRunner + in the returned set as well (default: True). If `num_env_runners` + is 0, always create a local EnvRunner. + """ + # Force a local worker if num_env_runners == 0 (no remote workers). + # Otherwise, this EnvRunnerGroup would be empty. + self._local_env_runner = None + if num_env_runners == 0: + local_env_runner = True + # Create a local (learner) version of the config for the local worker. + # The only difference is the tf_session_args, which - for the local worker - + # will be `config.tf_session_args` updated/overridden with + # `config.local_tf_session_args`. + local_tf_session_args = config.tf_session_args.copy() + local_tf_session_args.update(config.local_tf_session_args) + self._local_config = config.copy(copy_frozen=False).framework( + tf_session_args=local_tf_session_args + ) + + if config.input_ == "dataset": + # Create the set of dataset readers to be shared by all the + # rollout workers. + self._ds, self._ds_shards = get_dataset_and_shards(config, num_env_runners) + else: + self._ds = None + self._ds_shards = None + + # Create a number of @ray.remote workers. 
+ self.add_workers( + num_env_runners, + validate=config.validate_env_runners_after_construction, + ) + + # If num_env_runners > 0 and we don't have an env on the local worker, + # get the observation- and action spaces for each policy from + # the first remote worker (which does have an env). + if ( + local_env_runner + and self._worker_manager.num_actors() > 0 + and not config.enable_env_runner_and_connector_v2 + and not config.create_env_on_local_worker + and (not config.observation_space or not config.action_space) + ): + spaces = self.get_spaces() + else: + spaces = None + + # Create a local worker, if needed. + if local_env_runner: + self._local_env_runner = self._make_worker( + cls=self.env_runner_cls, + env_creator=self._env_creator, + validate_env=validate_env, + worker_index=0, + num_workers=num_env_runners, + config=self._local_config, + spaces=spaces, + ) + + def get_spaces(self): + """Infer observation and action spaces from one (local or remote) EnvRunner. + + Returns: + A dict mapping from ModuleID to a 2-tuple containing obs- and action-space. + """ + # Get ID of the first remote worker. 
+ remote_worker_ids = ( + [self._worker_manager.actor_ids()[0]] + if self._worker_manager.actor_ids() + else [] + ) + + spaces = self.foreach_env_runner( + lambda env_runner: env_runner.get_spaces(), + remote_worker_ids=remote_worker_ids, + local_env_runner=not remote_worker_ids, + )[0] + + logger.info( + "Inferred observation/action spaces from remote " + f"worker (local worker has no env): {spaces}" + ) + + return spaces + + @property + def local_env_runner(self) -> EnvRunner: + """Returns the local EnvRunner.""" + return self._local_env_runner + + def healthy_env_runner_ids(self) -> List[int]: + """Returns the list of remote worker IDs.""" + return self._worker_manager.healthy_actor_ids() + + def healthy_worker_ids(self) -> List[int]: + """Returns the list of remote worker IDs.""" + return self.healthy_env_runner_ids() + + def num_remote_env_runners(self) -> int: + """Returns the number of remote EnvRunners.""" + return self._worker_manager.num_actors() + + def num_remote_workers(self) -> int: + """Returns the number of remote EnvRunners.""" + return self.num_remote_env_runners() + + def num_healthy_remote_env_runners(self) -> int: + """Returns the number of healthy remote workers.""" + return self._worker_manager.num_healthy_actors() + + def num_healthy_remote_workers(self) -> int: + """Returns the number of healthy remote workers.""" + return self.num_healthy_remote_env_runners() + + def num_healthy_env_runners(self) -> int: + """Returns the number of all healthy workers, including the local worker.""" + return int(bool(self._local_env_runner)) + self.num_healthy_remote_workers() + + def num_healthy_workers(self) -> int: + """Returns the number of all healthy workers, including the local worker.""" + return self.num_healthy_env_runners() + + def num_in_flight_async_reqs(self) -> int: + """Returns the number of in-flight async requests.""" + return self._worker_manager.num_outstanding_async_reqs() + + def num_remote_worker_restarts(self) -> int: + """Total 
number of times managed remote workers have been restarted.""" + return self._worker_manager.total_num_restarts() + + def sync_env_runner_states( + self, + *, + config: "AlgorithmConfig", + from_worker: Optional[EnvRunner] = None, + env_steps_sampled: Optional[int] = None, + connector_states: Optional[List[Dict[str, Any]]] = None, + rl_module_state: Optional[Dict[str, Any]] = None, + env_runner_indices_to_update: Optional[List[int]] = None, + ) -> None: + """Synchronizes the connectors of this EnvRunnerGroup's EnvRunners. + + The exact procedure works as follows: + - If `from_worker` is None, set `from_worker=self.local_env_runner`. + - If `config.use_worker_filter_stats` is True, gather all remote EnvRunners' + ConnectorV2 states. Otherwise, only use the ConnectorV2 states of `from_worker`. + - Merge all gathered states into one resulting state. + - Broadcast the resulting state back to all remote EnvRunners AND the local + EnvRunner. + + Args: + config: The AlgorithmConfig object to use to determine, in which + direction(s) we need to synch and what the timeouts are. + from_worker: The EnvRunner from which to synch. If None, will use the local + worker of this EnvRunnerGroup. + env_steps_sampled: The total number of env steps taken thus far by all + workers combined. Used to broadcast this number to all remote workers + if `update_worker_filter_stats` is True in `config`. + env_runner_indices_to_update: The indices of those EnvRunners to update + with the merged state. Use None (default) to update all remote + EnvRunners. + """ + from_worker = from_worker or self.local_env_runner + + # Early out if the number of (healthy) remote workers is 0. In this case, the + # local worker is the only operating worker and thus of course always holds + # the reference connector state. 
+ if self.num_healthy_remote_workers() == 0: + self.local_env_runner.set_state( + { + **( + {NUM_ENV_STEPS_SAMPLED_LIFETIME: env_steps_sampled} + if env_steps_sampled is not None + else {} + ), + **(rl_module_state if rl_module_state is not None else {}), + } + ) + return + + # Also early out, if we a) don't use the remote states AND b) don't want to + # broadcast back from `from_worker` to all remote workers. + # TODO (sven): Rename these to proper "..env_runner_states.." containing names. + if not config.update_worker_filter_stats and not config.use_worker_filter_stats: + return + + # Use states from all remote EnvRunners. + if config.use_worker_filter_stats: + if connector_states == []: + env_runner_states = {} + else: + if connector_states is None: + connector_states = self.foreach_env_runner( + lambda w: w.get_state( + components=[ + COMPONENT_ENV_TO_MODULE_CONNECTOR, + COMPONENT_MODULE_TO_ENV_CONNECTOR, + ] + ), + local_env_runner=False, + timeout_seconds=( + config.sync_filters_on_rollout_workers_timeout_s + ), + ) + env_to_module_states = [ + s[COMPONENT_ENV_TO_MODULE_CONNECTOR] + for s in connector_states + if COMPONENT_ENV_TO_MODULE_CONNECTOR in s + ] + module_to_env_states = [ + s[COMPONENT_MODULE_TO_ENV_CONNECTOR] + for s in connector_states + if COMPONENT_MODULE_TO_ENV_CONNECTOR in s + ] + + env_runner_states = {} + if env_to_module_states: + env_runner_states.update( + { + COMPONENT_ENV_TO_MODULE_CONNECTOR: ( + self.local_env_runner._env_to_module.merge_states( + env_to_module_states + ) + ), + } + ) + if module_to_env_states: + env_runner_states.update( + { + COMPONENT_MODULE_TO_ENV_CONNECTOR: ( + self.local_env_runner._module_to_env.merge_states( + module_to_env_states + ) + ), + } + ) + # Ignore states from remote EnvRunners (use the current `from_worker` states + # only). 
+ else: + env_runner_states = from_worker.get_state( + components=[ + COMPONENT_ENV_TO_MODULE_CONNECTOR, + COMPONENT_MODULE_TO_ENV_CONNECTOR, + ] + ) + + # Update the global number of environment steps, if necessary. + # Make sure to divide by the number of env runners (such that each EnvRunner + # knows (roughly) its own(!) lifetime count and can infer the global lifetime + # count from it). + if env_steps_sampled is not None: + env_runner_states[NUM_ENV_STEPS_SAMPLED_LIFETIME] = env_steps_sampled // ( + config.num_env_runners or 1 + ) + + # Update the rl_module component of the EnvRunner states, if necessary: + if rl_module_state: + env_runner_states.update(rl_module_state) + + # If we do NOT want remote EnvRunners to get their Connector states updated, + # only update the local worker here (with all state components) and then remove + # the connector components. + if not config.update_worker_filter_stats: + self.local_env_runner.set_state(env_runner_states) + env_runner_states.pop(COMPONENT_ENV_TO_MODULE_CONNECTOR, None) + env_runner_states.pop(COMPONENT_MODULE_TO_ENV_CONNECTOR, None) + + # If there are components in the state left -> Update remote workers with these + # state components (and maybe the local worker, if it hasn't been updated yet). + if env_runner_states: + # Put the state dictionary into Ray's object store to avoid having to make n + # pickled copies of the state dict. + ref_env_runner_states = ray.put(env_runner_states) + + def _update(_env_runner: EnvRunner) -> None: + _env_runner.set_state(ray.get(ref_env_runner_states)) + + # Broadcast updated states back to all workers. + self.foreach_env_runner( + _update, + remote_worker_ids=env_runner_indices_to_update, + local_env_runner=config.update_worker_filter_stats, + timeout_seconds=0.0, # This is a state update -> Fire-and-forget. 
+ ) + + def sync_weights( + self, + policies: Optional[List[PolicyID]] = None, + from_worker_or_learner_group: Optional[Union[EnvRunner, "LearnerGroup"]] = None, + to_worker_indices: Optional[List[int]] = None, + global_vars: Optional[Dict[str, TensorType]] = None, + timeout_seconds: Optional[float] = 0.0, + inference_only: Optional[bool] = False, + ) -> None: + """Syncs model weights from the given weight source to all remote workers. + + Weight source can be either a (local) rollout worker or a learner_group. It + should just implement a `get_weights` method. + + Args: + policies: Optional list of PolicyIDs to sync weights for. + If None (default), sync weights to/from all policies. + from_worker_or_learner_group: Optional (local) EnvRunner instance or + LearnerGroup instance to sync from. If None (default), + sync from this EnvRunnerGroup's local worker. + to_worker_indices: Optional list of worker indices to sync the + weights to. If None (default), sync to all remote workers. + global_vars: An optional global vars dict to set this + worker to. If None, do not update the global_vars. + timeout_seconds: Timeout in seconds to wait for the sync weights + calls to complete. Default is 0.0 (fire-and-forget, do not wait + for any sync calls to finish). Setting this to 0.0 might significantly + improve algorithm performance, depending on the algo's `training_step` + logic. + inference_only: Sync weights with workers that keep inference-only + modules. This is needed for algorithms in the new stack that + use inference-only modules. In this case only a part of the + parameters are synced to the workers. Default is False. + """ + if self.local_env_runner is None and from_worker_or_learner_group is None: + raise TypeError( + "No `local_env_runner` in EnvRunnerGroup! Must provide " + "`from_worker_or_learner_group` arg in `sync_weights()`!" + ) + + # Only sync if we have remote workers or `from_worker_or_trainer` is provided. 
+ rl_module_state = None + if self.num_remote_workers() or from_worker_or_learner_group is not None: + weights_src = from_worker_or_learner_group or self.local_env_runner + + if weights_src is None: + raise ValueError( + "`from_worker_or_trainer` is None. In this case, EnvRunnerGroup " + "should have local_env_runner. But local_env_runner is also None." + ) + + modules = ( + [COMPONENT_RL_MODULE + "/" + p for p in policies] + if policies is not None + else [COMPONENT_RL_MODULE] + ) + # LearnerGroup has-a Learner has-a RLModule. + if isinstance(weights_src, LearnerGroup): + rl_module_state = weights_src.get_state( + components=[COMPONENT_LEARNER + "/" + m for m in modules], + inference_only=inference_only, + )[COMPONENT_LEARNER] + # EnvRunner has-a RLModule. + elif self._remote_config.enable_env_runner_and_connector_v2: + rl_module_state = weights_src.get_state( + components=modules, + inference_only=inference_only, + ) + else: + rl_module_state = weights_src.get_weights( + policies=policies, + inference_only=inference_only, + ) + + if self._remote_config.enable_env_runner_and_connector_v2: + + # Make sure `rl_module_state` only contains the weights and the + # weight seq no, nothing else. + rl_module_state = { + k: v + for k, v in rl_module_state.items() + if k in [COMPONENT_RL_MODULE, WEIGHTS_SEQ_NO] + } + + # Move weights to the object store to avoid having to make n pickled + # copies of the weights dict for each worker. + rl_module_state_ref = ray.put(rl_module_state) + + def _set_weights(env_runner): + env_runner.set_state(ray.get(rl_module_state_ref)) + + else: + rl_module_state_ref = ray.put(rl_module_state) + + def _set_weights(env_runner): + env_runner.set_weights(ray.get(rl_module_state_ref), global_vars) + + # Sync to specified remote workers in this EnvRunnerGroup. + self.foreach_env_runner( + func=_set_weights, + local_env_runner=False, # Do not sync back to local worker. 
+ remote_worker_ids=to_worker_indices, + timeout_seconds=timeout_seconds, + ) + + # If `from_worker_or_learner_group` is provided, also sync to this + # EnvRunnerGroup's local worker. + if self.local_env_runner is not None: + if from_worker_or_learner_group is not None: + if self._remote_config.enable_env_runner_and_connector_v2: + self.local_env_runner.set_state(rl_module_state) + else: + self.local_env_runner.set_weights(rl_module_state) + # If `global_vars` is provided and local worker exists -> Update its + # global_vars. + if global_vars is not None: + self.local_env_runner.set_global_vars(global_vars) + + @OldAPIStack + def add_policy( + self, + policy_id: PolicyID, + policy_cls: Optional[Type[Policy]] = None, + policy: Optional[Policy] = None, + *, + observation_space: Optional[gym.spaces.Space] = None, + action_space: Optional[gym.spaces.Space] = None, + config: Optional[Union["AlgorithmConfig", PartialAlgorithmConfigDict]] = None, + policy_state: Optional[PolicyState] = None, + policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None, + policies_to_train: Optional[ + Union[ + Collection[PolicyID], + Callable[[PolicyID, Optional[SampleBatchType]], bool], + ] + ] = None, + module_spec: Optional[RLModuleSpec] = None, + # Deprecated. + workers: Optional[List[Union[EnvRunner, ActorHandle]]] = DEPRECATED_VALUE, + ) -> None: + """Adds a policy to this EnvRunnerGroup's workers or a specific list of workers. + + Args: + policy_id: ID of the policy to add. + policy_cls: The Policy class to use for constructing the new Policy. + Note: Only one of `policy_cls` or `policy` must be provided. + policy: The Policy instance to add to this EnvRunnerGroup. If not None, the + given Policy object will be directly inserted into the + local worker and clones of that Policy will be created on all remote + workers. + Note: Only one of `policy_cls` or `policy` must be provided. + observation_space: The observation space of the policy to add. 
+ If None, try to infer this space from the environment. + action_space: The action space of the policy to add. + If None, try to infer this space from the environment. + config: The config object or overrides for the policy to add. + policy_state: Optional state dict to apply to the new + policy instance, right after its construction. + policy_mapping_fn: An optional (updated) policy mapping function + to use from here on. Note that already ongoing episodes will + not change their mapping but will use the old mapping till + the end of the episode. + policies_to_train: An optional list of policy IDs to be trained + or a callable taking PolicyID and SampleBatchType and + returning a bool (trainable or not?). + If None, will keep the existing setup in place. Policies, + whose IDs are not in the list (or for which the callable + returns False) will not be updated. + module_spec: In the new RLModule API we need to pass in the module_spec for + the new module that is supposed to be added. Knowing the policy spec is + not sufficient. + workers: A list of EnvRunner/ActorHandles (remote + EnvRunners) to add this policy to. If defined, will only + add the given policy to these workers. + + Raises: + KeyError: If the given `policy_id` already exists in this EnvRunnerGroup. + """ + if self.local_env_runner and policy_id in self.local_env_runner.policy_map: + raise KeyError( + f"Policy ID '{policy_id}' already exists in policy map! " + "Make sure you use a Policy ID that has not been taken yet." + " Policy IDs that are already in your policy map: " + f"{list(self.local_env_runner.policy_map.keys())}" + ) + + if workers is not DEPRECATED_VALUE: + deprecation_warning( + old="EnvRunnerGroup.add_policy(.., workers=..)", + help=( + "The `workers` argument to `EnvRunnerGroup.add_policy()` is " + "deprecated! Please do not use it anymore." 
+ ), + error=True, + ) + + if (policy_cls is None) == (policy is None): + raise ValueError( + "Only one of `policy_cls` or `policy` must be provided to " + "staticmethod: `EnvRunnerGroup.add_policy()`!" + ) + validate_module_id(policy_id, error=False) + + # Policy instance not provided: Use the information given here. + if policy_cls is not None: + new_policy_instance_kwargs = dict( + policy_id=policy_id, + policy_cls=policy_cls, + observation_space=observation_space, + action_space=action_space, + config=config, + policy_state=policy_state, + policy_mapping_fn=policy_mapping_fn, + policies_to_train=list(policies_to_train) + if policies_to_train + else None, + module_spec=module_spec, + ) + # Policy instance provided: Create clones of this very policy on the different + # workers (copy all its properties here for the calls to add_policy on the + # remote workers). + else: + new_policy_instance_kwargs = dict( + policy_id=policy_id, + policy_cls=type(policy), + observation_space=policy.observation_space, + action_space=policy.action_space, + config=policy.config, + policy_state=policy.get_state(), + policy_mapping_fn=policy_mapping_fn, + policies_to_train=list(policies_to_train) + if policies_to_train + else None, + module_spec=module_spec, + ) + + def _create_new_policy_fn(worker): + # `foreach_env_runner` function: Adds the policy the the worker (and + # maybe changes its policy_mapping_fn - if provided here). + worker.add_policy(**new_policy_instance_kwargs) + + if self.local_env_runner is not None: + # Add policy directly by (already instantiated) object. + if policy is not None: + self.local_env_runner.add_policy( + policy_id=policy_id, + policy=policy, + policy_mapping_fn=policy_mapping_fn, + policies_to_train=policies_to_train, + module_spec=module_spec, + ) + # Add policy by constructor kwargs. + else: + self.local_env_runner.add_policy(**new_policy_instance_kwargs) + + # Add the policy to all remote workers. 
+ self.foreach_env_runner(_create_new_policy_fn, local_env_runner=False) + + def add_workers(self, num_workers: int, validate: bool = False) -> None: + """Creates and adds a number of remote workers to this worker set. + + Can be called several times on the same EnvRunnerGroup to add more + EnvRunners to the set. + + Args: + num_workers: The number of remote Workers to add to this + EnvRunnerGroup. + validate: Whether to validate remote workers after their construction + process. + + Raises: + RayError: If any of the constructed remote workers is not up and running + properly. + """ + old_num_workers = self._worker_manager.num_actors() + new_workers = [ + self._make_worker( + cls=self._cls, + env_creator=self._env_creator, + validate_env=None, + worker_index=old_num_workers + i + 1, + num_workers=old_num_workers + num_workers, + config=self._remote_config, + ) + for i in range(num_workers) + ] + self._worker_manager.add_actors(new_workers) + + # Validate here, whether all remote workers have been constructed properly + # and are "up and running". Establish initial states. + if validate: + for result in self._worker_manager.foreach_actor( + lambda w: w.assert_healthy() + ): + # Simiply raise the error, which will get handled by the try-except + # clause around the _setup(). + if not result.ok: + e = result.get() + if self._ignore_ray_errors_on_env_runners: + logger.error(f"Validation of EnvRunner failed! Error={str(e)}") + else: + raise e + + def reset(self, new_remote_workers: List[ActorHandle]) -> None: + """Hard overrides the remote EnvRunners in this set with the provided ones. + + Args: + new_remote_workers: A list of new EnvRunners (as `ActorHandles`) to use as + new remote workers. 
+ """ + self._worker_manager.clear() + self._worker_manager.add_actors(new_remote_workers) + + def stop(self) -> None: + """Calls `stop` on all EnvRunners (including the local one).""" + try: + # Make sure we stop all EnvRunners, include the ones that were just + # restarted / recovered or that are tagged unhealthy (at least, we should + # try). + self.foreach_env_runner( + lambda w: w.stop(), healthy_only=False, local_env_runner=True + ) + except Exception: + logger.exception("Failed to stop workers!") + finally: + self._worker_manager.clear() + + def is_policy_to_train( + self, policy_id: PolicyID, batch: Optional[SampleBatchType] = None + ) -> bool: + """Whether given PolicyID (optionally inside some batch) is trainable.""" + if self.local_env_runner: + if self.local_env_runner.is_policy_to_train is None: + return True + return self.local_env_runner.is_policy_to_train(policy_id, batch) + else: + raise NotImplementedError + + def foreach_env_runner( + self, + func: Callable[[EnvRunner], T], + *, + local_env_runner: bool = True, + healthy_only: bool = True, + remote_worker_ids: List[int] = None, + timeout_seconds: Optional[float] = None, + return_obj_refs: bool = False, + mark_healthy: bool = False, + ) -> List[T]: + """Calls the given function with each EnvRunner as its argument. + + Args: + func: The function to call for each EnvRunners. The only call argument is + the respective EnvRunner instance. + local_env_runner: Whether to apply `func` to local EnvRunner, too. + Default is True. + healthy_only: Apply `func` on known-to-be healthy EnvRunners only. + remote_worker_ids: Apply `func` on a selected set of remote EnvRunners. + Use None (default) for all remote EnvRunners. + timeout_seconds: Time to wait (in seconds) for results. Set this to 0.0 for + fire-and-forget. Set this to None (default) to wait infinitely (i.e. for + synchronous execution). + return_obj_refs: Whether to return ObjectRef instead of actual results. 
+ Note, for fault tolerance reasons, these returned ObjectRefs should + never be resolved with ray.get() outside of this EnvRunnerGroup. + mark_healthy: Whether to mark all those EnvRunners healthy again that are + currently marked unhealthy AND that returned results from the remote + call (within the given `timeout_seconds`). + Note that EnvRunners are NOT set unhealthy, if they simply time out + (only if they return a RayActorError). + Also note that this setting is ignored if `healthy_only=True` (b/c + `mark_healthy` only affects EnvRunners that are currently tagged as + unhealthy). + + Returns: + The list of return values of all calls to `func([worker])`. + """ + assert ( + not return_obj_refs or not local_env_runner + ), "Can not return ObjectRef from local worker." + + local_result = [] + if local_env_runner and self.local_env_runner is not None: + local_result = [func(self.local_env_runner)] + + if not self._worker_manager.actor_ids(): + return local_result + + remote_results = self._worker_manager.foreach_actor( + func, + healthy_only=healthy_only, + remote_actor_ids=remote_worker_ids, + timeout_seconds=timeout_seconds, + return_obj_refs=return_obj_refs, + mark_healthy=mark_healthy, + ) + + FaultTolerantActorManager.handle_remote_call_result_errors( + remote_results, ignore_ray_errors=self._ignore_ray_errors_on_env_runners + ) + + # With application errors handled, return good results. + remote_results = [r.get() for r in remote_results.ignore_errors()] + + return local_result + remote_results + + def foreach_env_runner_with_id( + self, + func: Callable[[int, EnvRunner], T], + *, + local_env_runner: bool = True, + healthy_only: bool = True, + remote_worker_ids: List[int] = None, + timeout_seconds: Optional[float] = None, + return_obj_refs: bool = False, + mark_healthy: bool = False, + # Deprecated args. + local_worker=DEPRECATED_VALUE, + ) -> List[T]: + """Calls the given function with each EnvRunner and its ID as its arguments. 
+ + Args: + func: The function to call for each EnvRunners. The call arguments are + the EnvRunner's index (int) and the respective EnvRunner instance + itself. + local_env_runner: Whether to apply `func` to the local EnvRunner, too. + Default is True. + healthy_only: Apply `func` on known-to-be healthy EnvRunners only. + remote_worker_ids: Apply `func` on a selected set of remote EnvRunners. + timeout_seconds: Time to wait for results. Default is None. + return_obj_refs: Whether to return ObjectRef instead of actual results. + Note, for fault tolerance reasons, these returned ObjectRefs should + never be resolved with ray.get() outside of this EnvRunnerGroup. + mark_healthy: Whether to mark all those EnvRunners healthy again that are + currently marked unhealthy AND that returned results from the remote + call (within the given `timeout_seconds`). + Note that workers are NOT set unhealthy, if they simply time out + (only if they return a RayActorError). + Also note that this setting is ignored if `healthy_only=True` (b/c + `mark_healthy` only affects EnvRunners that are currently tagged as + unhealthy). + + Returns: + The list of return values of all calls to `func([worker, id])`. 
+ """ + local_result = [] + if local_env_runner and self.local_env_runner is not None: + local_result = [func(0, self.local_env_runner)] + + if not remote_worker_ids: + remote_worker_ids = self._worker_manager.actor_ids() + + funcs = [functools.partial(func, i) for i in remote_worker_ids] + + remote_results = self._worker_manager.foreach_actor( + funcs, + healthy_only=healthy_only, + remote_actor_ids=remote_worker_ids, + timeout_seconds=timeout_seconds, + return_obj_refs=return_obj_refs, + mark_healthy=mark_healthy, + ) + + FaultTolerantActorManager.handle_remote_call_result_errors( + remote_results, + ignore_ray_errors=self._ignore_ray_errors_on_env_runners, + ) + + remote_results = [r.get() for r in remote_results.ignore_errors()] + + return local_result + remote_results + + def foreach_env_runner_async( + self, + func: Callable[[EnvRunner], T], + *, + healthy_only: bool = True, + remote_worker_ids: List[int] = None, + ) -> int: + """Calls the given function asynchronously with each EnvRunner as the argument. + + Does not return results directly. Instead, `fetch_ready_async_reqs()` can be + used to pull results in an async manner whenever they are available. + + Args: + func: The function to call for each EnvRunners. The only call argument is + the respective EnvRunner instance. + healthy_only: Apply `func` on known-to-be healthy EnvRunners only. + remote_worker_ids: Apply `func` on a selected set of remote EnvRunners. + + Returns: + The number of async requests that have actually been made. This is the + length of `remote_worker_ids` (or self.num_remote_workers()` if + `remote_worker_ids` is None) minus the number of requests that were NOT + made b/c a remote EnvRunner already had its + `max_remote_requests_in_flight_per_actor` counter reached. 
+ """ + return self._worker_manager.foreach_actor_async( + func, + healthy_only=healthy_only, + remote_actor_ids=remote_worker_ids, + ) + + def fetch_ready_async_reqs( + self, + *, + timeout_seconds: Optional[float] = 0.0, + return_obj_refs: bool = False, + mark_healthy: bool = False, + ) -> List[Tuple[int, T]]: + """Get esults from outstanding asynchronous requests that are ready. + + Args: + timeout_seconds: Time to wait for results. Default is 0, meaning + those requests that are already ready. + return_obj_refs: Whether to return ObjectRef instead of actual results. + mark_healthy: Whether to mark all those workers healthy again that are + currently marked unhealthy AND that returned results from the remote + call (within the given `timeout_seconds`). + Note that workers are NOT set unhealthy, if they simply time out + (only if they return a RayActorError). + Also note that this setting is ignored if `healthy_only=True` (b/c + `mark_healthy` only affects workers that are currently tagged as + unhealthy). + + Returns: + A list of results successfully returned from outstanding remote calls, + paired with the indices of the callee workers. + """ + remote_results = self._worker_manager.fetch_ready_async_reqs( + timeout_seconds=timeout_seconds, + return_obj_refs=return_obj_refs, + mark_healthy=mark_healthy, + ) + + FaultTolerantActorManager.handle_remote_call_result_errors( + remote_results, + ignore_ray_errors=self._ignore_ray_errors_on_env_runners, + ) + + return [(r.actor_id, r.get()) for r in remote_results.ignore_errors()] + + def foreach_policy(self, func: Callable[[Policy, PolicyID], T]) -> List[T]: + """Calls `func` with each worker's (policy, PolicyID) tuple. + + Note that in the multi-agent case, each worker may have more than one + policy. + + Args: + func: A function - taking a Policy and its ID - that is + called on all workers' Policies. + + Returns: + The list of return values of func over all workers' policies. 
The + length of this list is: + (num_workers + 1 (local-worker)) * + [num policies in the multi-agent config dict]. + The local workers' results are first, followed by all remote + workers' results + """ + results = [] + for r in self.foreach_env_runner( + lambda w: w.foreach_policy(func), local_env_runner=True + ): + results.extend(r) + return results + + def foreach_policy_to_train(self, func: Callable[[Policy, PolicyID], T]) -> List[T]: + """Apply `func` to all workers' Policies iff in `policies_to_train`. + + Args: + func: A function - taking a Policy and its ID - that is + called on all workers' Policies, for which + `worker.is_policy_to_train()` returns True. + + Returns: + List[any]: The list of n return values of all + `func([trainable policy], [ID])`-calls. + """ + results = [] + for r in self.foreach_env_runner( + lambda w: w.foreach_policy_to_train(func), local_env_runner=True + ): + results.extend(r) + return results + + def foreach_env(self, func: Callable[[EnvType], List[T]]) -> List[List[T]]: + """Calls `func` with all workers' sub-environments as args. + + An "underlying sub environment" is a single clone of an env within + a vectorized environment. + `func` takes a single underlying sub environment as arg, e.g. a + gym.Env object. + + Args: + func: A function - taking an EnvType (normally a gym.Env object) + as arg and returning a list of lists of return values, one + value per underlying sub-environment per each worker. + + Returns: + The list (workers) of lists (sub environments) of results. + """ + return list( + self.foreach_env_runner( + lambda w: w.foreach_env(func), + local_env_runner=True, + ) + ) + + def foreach_env_with_context( + self, func: Callable[[BaseEnv, EnvContext], List[T]] + ) -> List[List[T]]: + """Calls `func` with all workers' sub-environments and env_ctx as args. + + An "underlying sub environment" is a single clone of an env within + a vectorized environment. 
+ `func` takes a single underlying sub environment and the env_context + as args. + + Args: + func: A function - taking a BaseEnv object and an EnvContext as + arg - and returning a list of lists of return values over envs + of the worker. + + Returns: + The list (1 item per workers) of lists (1 item per sub-environment) + of results. + """ + return list( + self.foreach_env_runner( + lambda w: w.foreach_env_with_context(func), + local_env_runner=True, + ) + ) + + def probe_unhealthy_env_runners(self) -> List[int]: + """Checks for unhealthy workers and tries restoring their states. + + Returns: + List of IDs of the workers that were restored. + """ + return self._worker_manager.probe_unhealthy_actors( + timeout_seconds=self._remote_config.env_runner_health_probe_timeout_s, + mark_healthy=True, + ) + + def _make_worker( + self, + *, + cls: Callable, + env_creator: EnvCreator, + validate_env: Optional[Callable[[EnvType], None]], + worker_index: int, + num_workers: int, + recreated_worker: bool = False, + config: "AlgorithmConfig", + spaces: Optional[ + Dict[PolicyID, Tuple[gym.spaces.Space, gym.spaces.Space]] + ] = None, + ) -> Union[EnvRunner, ActorHandle]: + worker = cls( + env_creator=env_creator, + validate_env=validate_env, + default_policy_class=self._policy_class, + config=config, + worker_index=worker_index, + num_workers=num_workers, + recreated_worker=recreated_worker, + log_dir=self._logdir, + spaces=spaces, + dataset_shards=self._ds_shards, + tune_trial_id=self._tune_trial_id, + ) + + return worker + + @classmethod + def _valid_module(cls, class_path): + del cls + if ( + isinstance(class_path, str) + and not os.path.isfile(class_path) + and "." 
in class_path + ): + module_path, class_name = class_path.rsplit(".", 1) + try: + spec = importlib.util.find_spec(module_path) + if spec is not None: + return True + except (ModuleNotFoundError, ValueError): + print( + f"module {module_path} not found while trying to get " + f"input {class_path}" + ) + return False + + @Deprecated(new="EnvRunnerGroup.probe_unhealthy_env_runners", error=False) + def probe_unhealthy_workers(self, *args, **kwargs): + return self.probe_unhealthy_env_runners(*args, **kwargs) + + @Deprecated(new="EnvRunnerGroup.foreach_env_runner", error=False) + def foreach_worker(self, *args, **kwargs): + return self.foreach_env_runner(*args, **kwargs) + + @Deprecated(new="EnvRunnerGroup.foreach_env_runner_with_id", error=False) + def foreach_worker_with_id(self, *args, **kwargs): + return self.foreach_env_runner_with_id(*args, **kwargs) + + @Deprecated(new="EnvRunnerGroup.foreach_env_runner_async", error=False) + def foreach_worker_async(self, *args, **kwargs): + return self.foreach_env_runner_async(*args, **kwargs) + + @Deprecated(new="EnvRunnerGroup.local_env_runner", error=True) + def local_worker(self) -> EnvRunner: + pass + + @property + @Deprecated( + old="_remote_workers", + new="Use either the `foreach_env_runner()`, `foreach_env_runner_with_id()`, or " + "`foreach_env_runner_async()` APIs of `EnvRunnerGroup`, which all handle fault " + "tolerance.", + error=True, + ) + def _remote_workers(self): + pass + + @Deprecated( + old="remote_workers()", + new="Use either the `foreach_env_runner()`, `foreach_env_runner_with_id()`, or " + "`foreach_env_runner_async()` APIs of `EnvRunnerGroup`, which all handle fault " + "tolerance.", + error=True, + ) + def remote_workers(self): + pass diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/external_env.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/external_env.py new file mode 100644 index 0000000000000000000000000000000000000000..41eb89d6c471571beca90b9659d262a5283e4519 --- /dev/null +++ 
b/.venv/lib/python3.11/site-packages/ray/rllib/env/external_env.py @@ -0,0 +1,481 @@ +import gymnasium as gym +import queue +import threading +import uuid +from typing import Callable, Tuple, Optional, TYPE_CHECKING + +from ray.rllib.env.base_env import BaseEnv +from ray.rllib.utils.annotations import override, OldAPIStack +from ray.rllib.utils.typing import ( + EnvActionType, + EnvInfoDict, + EnvObsType, + EnvType, + MultiEnvDict, +) +from ray.rllib.utils.deprecation import deprecation_warning + +if TYPE_CHECKING: + from ray.rllib.models.preprocessors import Preprocessor + + +@OldAPIStack +class ExternalEnv(threading.Thread): + """An environment that interfaces with external agents. + + Unlike simulator envs, control is inverted: The environment queries the + policy to obtain actions and in return logs observations and rewards for + training. This is in contrast to gym.Env, where the algorithm drives the + simulation through env.step() calls. + + You can use ExternalEnv as the backend for policy serving (by serving HTTP + requests in the run loop), for ingesting offline logs data (by reading + offline transitions in the run loop), or other custom use cases not easily + expressed through gym.Env. + + ExternalEnv supports both on-policy actions (through self.get_action()), + and off-policy actions (through self.log_action()). + + This env is thread-safe, but individual episodes must be executed serially. + + .. testcode:: + :skipif: True + + from ray.tune import register_env + from ray.rllib.algorithms.dqn import DQN + YourExternalEnv = ... + register_env("my_env", lambda config: YourExternalEnv(config)) + algo = DQN(env="my_env") + while True: + print(algo.train()) + """ + + def __init__( + self, + action_space: gym.Space, + observation_space: gym.Space, + max_concurrent: int = None, + ): + """Initializes an ExternalEnv instance. + + Args: + action_space: Action space of the env. + observation_space: Observation space of the env. 
+ """ + + threading.Thread.__init__(self) + + self.daemon = True + self.action_space = action_space + self.observation_space = observation_space + self._episodes = {} + self._finished = set() + self._results_avail_condition = threading.Condition() + if max_concurrent is not None: + deprecation_warning( + "The `max_concurrent` argument has been deprecated. Please configure" + "the number of episodes using the `rollout_fragment_length` and" + "`batch_mode` arguments. Please raise an issue on the Ray Github if " + "these arguments do not support your expected use case for ExternalEnv", + error=True, + ) + + def run(self): + """Override this to implement the run loop. + + Your loop should continuously: + 1. Call self.start_episode(episode_id) + 2. Call self.[get|log]_action(episode_id, obs, [action]?) + 3. Call self.log_returns(episode_id, reward) + 4. Call self.end_episode(episode_id, obs) + 5. Wait if nothing to do. + + Multiple episodes may be started at the same time. + """ + raise NotImplementedError + + def start_episode( + self, episode_id: Optional[str] = None, training_enabled: bool = True + ) -> str: + """Record the start of an episode. + + Args: + episode_id: Unique string id for the episode or + None for it to be auto-assigned and returned. + training_enabled: Whether to use experiences for this + episode to improve the policy. + + Returns: + Unique string id for the episode. + """ + + if episode_id is None: + episode_id = uuid.uuid4().hex + + if episode_id in self._finished: + raise ValueError("Episode {} has already completed.".format(episode_id)) + + if episode_id in self._episodes: + raise ValueError("Episode {} is already started".format(episode_id)) + + self._episodes[episode_id] = _ExternalEnvEpisode( + episode_id, self._results_avail_condition, training_enabled + ) + + return episode_id + + def get_action(self, episode_id: str, observation: EnvObsType) -> EnvActionType: + """Record an observation and get the on-policy action. 
+ + Args: + episode_id: Episode id returned from start_episode(). + observation: Current environment observation. + + Returns: + Action from the env action space. + """ + + episode = self._get(episode_id) + return episode.wait_for_action(observation) + + def log_action( + self, episode_id: str, observation: EnvObsType, action: EnvActionType + ) -> None: + """Record an observation and (off-policy) action taken. + + Args: + episode_id: Episode id returned from start_episode(). + observation: Current environment observation. + action: Action for the observation. + """ + + episode = self._get(episode_id) + episode.log_action(observation, action) + + def log_returns( + self, episode_id: str, reward: float, info: Optional[EnvInfoDict] = None + ) -> None: + """Records returns (rewards and infos) from the environment. + + The reward will be attributed to the previous action taken by the + episode. Rewards accumulate until the next action. If no reward is + logged before the next action, a reward of 0.0 is assumed. + + Args: + episode_id: Episode id returned from start_episode(). + reward: Reward from the environment. + info: Optional info dict. + """ + + episode = self._get(episode_id) + episode.cur_reward += reward + + if info: + episode.cur_info = info or {} + + def end_episode(self, episode_id: str, observation: EnvObsType) -> None: + """Records the end of an episode. + + Args: + episode_id: Episode id returned from start_episode(). + observation: Current environment observation. 
+ """ + + episode = self._get(episode_id) + self._finished.add(episode.episode_id) + episode.done(observation) + + def _get(self, episode_id: str) -> "_ExternalEnvEpisode": + """Get a started episode by its ID or raise an error.""" + + if episode_id in self._finished: + raise ValueError("Episode {} has already completed.".format(episode_id)) + + if episode_id not in self._episodes: + raise ValueError("Episode {} not found.".format(episode_id)) + + return self._episodes[episode_id] + + def to_base_env( + self, + make_env: Optional[Callable[[int], EnvType]] = None, + num_envs: int = 1, + remote_envs: bool = False, + remote_env_batch_wait_ms: int = 0, + restart_failed_sub_environments: bool = False, + ) -> "BaseEnv": + """Converts an RLlib MultiAgentEnv into a BaseEnv object. + + The resulting BaseEnv is always vectorized (contains n + sub-environments) to support batched forward passes, where n may + also be 1. BaseEnv also supports async execution via the `poll` and + `send_actions` methods and thus supports external simulators. + + Args: + make_env: A callable taking an int as input (which indicates + the number of individual sub-environments within the final + vectorized BaseEnv) and returning one individual + sub-environment. + num_envs: The number of sub-environments to create in the + resulting (vectorized) BaseEnv. The already existing `env` + will be one of the `num_envs`. + remote_envs: Whether each sub-env should be a @ray.remote + actor. You can set this behavior in your config via the + `remote_worker_envs=True` option. + remote_env_batch_wait_ms: The wait time (in ms) to poll remote + sub-environments for, if applicable. Only used if + `remote_envs` is True. + + Returns: + The resulting BaseEnv object. + """ + if num_envs != 1: + raise ValueError( + "External(MultiAgent)Env does not currently support " + "num_envs > 1. One way of solving this would be to " + "treat your Env as a MultiAgentEnv hosting only one " + "type of agent but with several copies." 
+ ) + env = ExternalEnvWrapper(self) + + return env + + +@OldAPIStack +class _ExternalEnvEpisode: + """Tracked state for each active episode.""" + + def __init__( + self, + episode_id: str, + results_avail_condition: threading.Condition, + training_enabled: bool, + multiagent: bool = False, + ): + self.episode_id = episode_id + self.results_avail_condition = results_avail_condition + self.training_enabled = training_enabled + self.multiagent = multiagent + self.data_queue = queue.Queue() + self.action_queue = queue.Queue() + if multiagent: + self.new_observation_dict = None + self.new_action_dict = None + self.cur_reward_dict = {} + self.cur_terminated_dict = {"__all__": False} + self.cur_truncated_dict = {"__all__": False} + self.cur_info_dict = {} + else: + self.new_observation = None + self.new_action = None + self.cur_reward = 0.0 + self.cur_terminated = False + self.cur_truncated = False + self.cur_info = {} + + def get_data(self): + if self.data_queue.empty(): + return None + return self.data_queue.get_nowait() + + def log_action(self, observation, action): + if self.multiagent: + self.new_observation_dict = observation + self.new_action_dict = action + else: + self.new_observation = observation + self.new_action = action + self._send() + self.action_queue.get(True, timeout=60.0) + + def wait_for_action(self, observation): + if self.multiagent: + self.new_observation_dict = observation + else: + self.new_observation = observation + self._send() + return self.action_queue.get(True, timeout=300.0) + + def done(self, observation): + if self.multiagent: + self.new_observation_dict = observation + self.cur_terminated_dict = {"__all__": True} + # TODO(sven): External env API does not currently support truncated, + # but we should deprecate external Env anyways in favor of a client-only + # approach. 
+ self.cur_truncated_dict = {"__all__": False} + else: + self.new_observation = observation + self.cur_terminated = True + self.cur_truncated = False + self._send() + + def _send(self): + if self.multiagent: + if not self.training_enabled: + for agent_id in self.cur_info_dict: + self.cur_info_dict[agent_id]["training_enabled"] = False + item = { + "obs": self.new_observation_dict, + "reward": self.cur_reward_dict, + "terminated": self.cur_terminated_dict, + "truncated": self.cur_truncated_dict, + "info": self.cur_info_dict, + } + if self.new_action_dict is not None: + item["off_policy_action"] = self.new_action_dict + self.new_observation_dict = None + self.new_action_dict = None + self.cur_reward_dict = {} + else: + item = { + "obs": self.new_observation, + "reward": self.cur_reward, + "terminated": self.cur_terminated, + "truncated": self.cur_truncated, + "info": self.cur_info, + } + if self.new_action is not None: + item["off_policy_action"] = self.new_action + self.new_observation = None + self.new_action = None + self.cur_reward = 0.0 + if not self.training_enabled: + item["info"]["training_enabled"] = False + + with self.results_avail_condition: + self.data_queue.put_nowait(item) + self.results_avail_condition.notify() + + +@OldAPIStack +class ExternalEnvWrapper(BaseEnv): + """Internal adapter of ExternalEnv to BaseEnv.""" + + def __init__( + self, external_env: "ExternalEnv", preprocessor: "Preprocessor" = None + ): + from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv + + self.external_env = external_env + self.prep = preprocessor + self.multiagent = issubclass(type(external_env), ExternalMultiAgentEnv) + self._action_space = external_env.action_space + if preprocessor: + self._observation_space = preprocessor.observation_space + else: + self._observation_space = external_env.observation_space + external_env.start() + + @override(BaseEnv) + def poll( + self, + ) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, 
MultiEnvDict]: + with self.external_env._results_avail_condition: + results = self._poll() + while len(results[0]) == 0: + self.external_env._results_avail_condition.wait() + results = self._poll() + if not self.external_env.is_alive(): + raise Exception("Serving thread has stopped.") + return results + + @override(BaseEnv) + def send_actions(self, action_dict: MultiEnvDict) -> None: + from ray.rllib.env.base_env import _DUMMY_AGENT_ID + + if self.multiagent: + for env_id, actions in action_dict.items(): + self.external_env._episodes[env_id].action_queue.put(actions) + else: + for env_id, action in action_dict.items(): + self.external_env._episodes[env_id].action_queue.put( + action[_DUMMY_AGENT_ID] + ) + + def _poll( + self, + ) -> Tuple[ + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + ]: + from ray.rllib.env.base_env import with_dummy_agent_id + + all_obs, all_rewards, all_terminateds, all_truncateds, all_infos = ( + {}, + {}, + {}, + {}, + {}, + ) + off_policy_actions = {} + for eid, episode in self.external_env._episodes.copy().items(): + data = episode.get_data() + cur_terminated = ( + episode.cur_terminated_dict["__all__"] + if self.multiagent + else episode.cur_terminated + ) + cur_truncated = ( + episode.cur_truncated_dict["__all__"] + if self.multiagent + else episode.cur_truncated + ) + if cur_terminated or cur_truncated: + del self.external_env._episodes[eid] + if data: + if self.prep: + all_obs[eid] = self.prep.transform(data["obs"]) + else: + all_obs[eid] = data["obs"] + all_rewards[eid] = data["reward"] + all_terminateds[eid] = data["terminated"] + all_truncateds[eid] = data["truncated"] + all_infos[eid] = data["info"] + if "off_policy_action" in data: + off_policy_actions[eid] = data["off_policy_action"] + if self.multiagent: + # Ensure a consistent set of keys + # rely on all_obs having all possible keys for now. 
+ for eid, eid_dict in all_obs.items(): + for agent_id in eid_dict.keys(): + + def fix(d, zero_val): + if agent_id not in d[eid]: + d[eid][agent_id] = zero_val + + fix(all_rewards, 0.0) + fix(all_terminateds, False) + fix(all_truncateds, False) + fix(all_infos, {}) + return ( + all_obs, + all_rewards, + all_terminateds, + all_truncateds, + all_infos, + off_policy_actions, + ) + else: + return ( + with_dummy_agent_id(all_obs), + with_dummy_agent_id(all_rewards), + with_dummy_agent_id(all_terminateds, "__all__"), + with_dummy_agent_id(all_truncateds, "__all__"), + with_dummy_agent_id(all_infos), + with_dummy_agent_id(off_policy_actions), + ) + + @property + @override(BaseEnv) + def observation_space(self) -> gym.spaces.Dict: + return self._observation_space + + @property + @override(BaseEnv) + def action_space(self) -> gym.Space: + return self._action_space diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/external_multi_agent_env.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/external_multi_agent_env.py new file mode 100644 index 0000000000000000000000000000000000000000..1350d5c7c3563f52a8c2c8a7951369d702a0307f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/external_multi_agent_env.py @@ -0,0 +1,161 @@ +import uuid +import gymnasium as gym +from typing import Optional + +from ray.rllib.utils.annotations import override, OldAPIStack +from ray.rllib.env.external_env import ExternalEnv, _ExternalEnvEpisode +from ray.rllib.utils.typing import MultiAgentDict + + +@OldAPIStack +class ExternalMultiAgentEnv(ExternalEnv): + """This is the multi-agent version of ExternalEnv.""" + + def __init__( + self, + action_space: gym.Space, + observation_space: gym.Space, + ): + """Initializes an ExternalMultiAgentEnv instance. + + Args: + action_space: Action space of the env. + observation_space: Observation space of the env. + """ + ExternalEnv.__init__(self, action_space, observation_space) + + # We require to know all agents' spaces. 
+ if isinstance(self.action_space, dict) or isinstance( + self.observation_space, dict + ): + if not (self.action_space.keys() == self.observation_space.keys()): + raise ValueError( + "Agent ids disagree for action space and obs " + "space dict: {} {}".format( + self.action_space.keys(), self.observation_space.keys() + ) + ) + + def run(self): + """Override this to implement the multi-agent run loop. + + Your loop should continuously: + 1. Call self.start_episode(episode_id) + 2. Call self.get_action(episode_id, obs_dict) + -or- + self.log_action(episode_id, obs_dict, action_dict) + 3. Call self.log_returns(episode_id, reward_dict) + 4. Call self.end_episode(episode_id, obs_dict) + 5. Wait if nothing to do. + + Multiple episodes may be started at the same time. + """ + raise NotImplementedError + + @override(ExternalEnv) + def start_episode( + self, episode_id: Optional[str] = None, training_enabled: bool = True + ) -> str: + if episode_id is None: + episode_id = uuid.uuid4().hex + + if episode_id in self._finished: + raise ValueError("Episode {} has already completed.".format(episode_id)) + + if episode_id in self._episodes: + raise ValueError("Episode {} is already started".format(episode_id)) + + self._episodes[episode_id] = _ExternalEnvEpisode( + episode_id, self._results_avail_condition, training_enabled, multiagent=True + ) + + return episode_id + + @override(ExternalEnv) + def get_action( + self, episode_id: str, observation_dict: MultiAgentDict + ) -> MultiAgentDict: + """Record an observation and get the on-policy action. + + Thereby, observation_dict is expected to contain the observation + of all agents acting in this episode step. + + Args: + episode_id: Episode id returned from start_episode(). + observation_dict: Current environment observation. + + Returns: + action: Action from the env action space. 
+ """ + + episode = self._get(episode_id) + return episode.wait_for_action(observation_dict) + + @override(ExternalEnv) + def log_action( + self, + episode_id: str, + observation_dict: MultiAgentDict, + action_dict: MultiAgentDict, + ) -> None: + """Record an observation and (off-policy) action taken. + + Args: + episode_id: Episode id returned from start_episode(). + observation_dict: Current environment observation. + action_dict: Action for the observation. + """ + + episode = self._get(episode_id) + episode.log_action(observation_dict, action_dict) + + @override(ExternalEnv) + def log_returns( + self, + episode_id: str, + reward_dict: MultiAgentDict, + info_dict: MultiAgentDict = None, + multiagent_done_dict: MultiAgentDict = None, + ) -> None: + """Record returns from the environment. + + The reward will be attributed to the previous action taken by the + episode. Rewards accumulate until the next action. If no reward is + logged before the next action, a reward of 0.0 is assumed. + + Args: + episode_id: Episode id returned from start_episode(). + reward_dict: Reward from the environment agents. + info_dict: Optional info dict. + multiagent_done_dict: Optional done dict for agents. + """ + + episode = self._get(episode_id) + + # Accumulate reward by agent. + # For existing agents, we want to add the reward up. + for agent, rew in reward_dict.items(): + if agent in episode.cur_reward_dict: + episode.cur_reward_dict[agent] += rew + else: + episode.cur_reward_dict[agent] = rew + + if multiagent_done_dict: + for agent, done in multiagent_done_dict.items(): + episode.cur_done_dict[agent] = done + + if info_dict: + episode.cur_info_dict = info_dict or {} + + @override(ExternalEnv) + def end_episode(self, episode_id: str, observation_dict: MultiAgentDict) -> None: + """Record the end of an episode. + + Args: + episode_id: Episode id returned from start_episode(). + observation_dict: Current environment observation. 
+ """ + + episode = self._get(episode_id) + self._finished.add(episode.episode_id) + episode.done(observation_dict) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_env.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_env.py new file mode 100644 index 0000000000000000000000000000000000000000..c21acec528c2a9527374254702e9f9c00b1a7131 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_env.py @@ -0,0 +1,799 @@ +import gymnasium as gym +import logging +from typing import Callable, Dict, List, Tuple, Optional, Union, Set, Type + +import numpy as np + +from ray.rllib.env.base_env import BaseEnv +from ray.rllib.env.env_context import EnvContext +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.typing import ( + AgentID, + EnvCreator, + EnvID, + EnvType, + MultiAgentDict, + MultiEnvDict, +) +from ray.util import log_once +from ray.util.annotations import DeveloperAPI, PublicAPI + +# If the obs space is Dict type, look for the global state under this key. +ENV_STATE = "state" + +logger = logging.getLogger(__name__) + + +@PublicAPI(stability="beta") +class MultiAgentEnv(gym.Env): + """An environment that hosts multiple independent agents. + + Agents are identified by AgentIDs (string). + """ + + # Optional mappings from AgentID to individual agents' spaces. + # Set this to an "exhaustive" dictionary, mapping all possible AgentIDs to + # individual agents' spaces. Alternatively, override + # `get_observation_space(agent_id=...)` and `get_action_space(agent_id=...)`, which + # is the API that RLlib uses to get individual spaces and whose default + # implementation is to simply look up `agent_id` in these dicts. + observation_spaces: Optional[Dict[AgentID, gym.Space]] = None + action_spaces: Optional[Dict[AgentID, gym.Space]] = None + + # All agents currently active in the environment. 
    # This attribute may change during
    # the lifetime of the env or even during an individual episode.
    # NOTE(review): mutable class-level default ([]) — instances that never
    # reassign `self.agents` share this list; `__init__` below only rebinds it
    # when it is falsy. Confirm no subclass mutates it in place.
    agents: List[AgentID] = []
    # All agents that may appear in the environment, ever.
    # This attribute should not be changed during the lifetime of this env.
    possible_agents: List[AgentID] = []

    # @OldAPIStack, use `observation_spaces` and `action_spaces`, instead.
    observation_space: Optional[gym.Space] = None
    action_space: Optional[gym.Space] = None

    def __init__(self):
        super().__init__()

        # @OldAPIStack
        if not hasattr(self, "_agent_ids"):
            self._agent_ids = set()

        # If these important attributes are not set, try to infer them.
        if not self.agents:
            self.agents = list(self._agent_ids)
        if not self.possible_agents:
            self.possible_agents = self.agents.copy()

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> Tuple[MultiAgentDict, MultiAgentDict]:  # type: ignore
        """Resets the env and returns observations from ready agents.

        Args:
            seed: An optional seed to use for the new episode.

        Returns:
            New observations for each ready agent.

        .. testcode::
            :skipif: True

            from ray.rllib.env.multi_agent_env import MultiAgentEnv
            class MyMultiAgentEnv(MultiAgentEnv):
                # Define your env here.
            env = MyMultiAgentEnv()
            obs, infos = env.reset(seed=42, options={})
            print(obs)

        .. testoutput::

            {
                "car_0": [2.4, 1.6],
                "car_1": [3.4, -3.2],
                "traffic_light_1": [0, 3, 5, 1],
            }
        """
        # Call super's `reset()` method to (maybe) set the given `seed`.
        # NOTE(review): despite the annotated return type, this base
        # implementation returns None — subclasses are expected to override
        # and return the (obs, infos) tuple themselves.
        super().reset(seed=seed, options=options)

    def step(
        self, action_dict: MultiAgentDict
    ) -> Tuple[
        MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict
    ]:
        """Returns observations from ready agents.

        The returns are dicts mapping from agent_id strings to values. The
        number of agents in the env can vary over time.

        Returns:
            Tuple containing 1) new observations for
            each ready agent, 2) reward values for each ready agent. If
            the episode is just started, the value will be None.
            3) Terminated values for each ready agent. The special key
            "__all__" (required) is used to indicate env termination.
            4) Truncated values for each ready agent.
            5) Info values for each agent id (may be empty dicts).

        .. testcode::
            :skipif: True

            env = ...
            obs, rewards, terminateds, truncateds, infos = env.step(action_dict={
                "car_0": 1, "car_1": 0, "traffic_light_1": 2,
            })
            print(rewards)

            print(terminateds)

            print(infos)

        .. testoutput::

            {
                "car_0": 3,
                "car_1": -1,
                "traffic_light_1": 0,
            }
            {
                "car_0": False,  # car_0 is still running
                "car_1": True,   # car_1 is terminated
                "__all__": False,  # the env is not terminated
            }
            {
                "car_0": {},  # info for car_0
                "car_1": {},  # info for car_1
            }

        """
        raise NotImplementedError

    def render(self) -> None:
        """Tries to render the environment."""

        # By default, do nothing.
        pass

    def get_observation_space(self, agent_id: AgentID) -> gym.Space:
        # New API stack: exhaustive per-agent dict takes precedence.
        if self.observation_spaces is not None:
            return self.observation_spaces[agent_id]

        # @OldAPIStack behavior.
        # `self.observation_space` is a `gym.spaces.Dict` AND contains `agent_id`.
        if (
            isinstance(self.observation_space, gym.spaces.Dict)
            and agent_id in self.observation_space.spaces
        ):
            return self.observation_space[agent_id]
        # `self.observation_space` is not a `gym.spaces.Dict` OR doesn't contain
        # `agent_id` -> The defined space is most likely meant to be the space
        # for all agents.
        else:
            return self.observation_space

    def get_action_space(self, agent_id: AgentID) -> gym.Space:
        # New API stack: exhaustive per-agent dict takes precedence.
        if self.action_spaces is not None:
            return self.action_spaces[agent_id]

        # @OldAPIStack behavior.
        # `self.action_space` is a `gym.spaces.Dict` AND contains `agent_id`.
        if (
            isinstance(self.action_space, gym.spaces.Dict)
            and agent_id in self.action_space.spaces
        ):
            return self.action_space[agent_id]
        # `self.action_space` is not a `gym.spaces.Dict` OR doesn't contain
        # `agent_id` -> The defined space is most likely meant to be the space
        # for all agents.
        else:
            return self.action_space

    @property
    def num_agents(self) -> int:
        # Number of currently active agents.
        return len(self.agents)

    @property
    def max_num_agents(self) -> int:
        # Number of agents that may ever appear in this env.
        return len(self.possible_agents)

    # fmt: off
    # __grouping_doc_begin__
    def with_agent_groups(
        self,
        groups: Dict[str, List[AgentID]],
        obs_space: gym.Space = None,
        act_space: gym.Space = None,
    ) -> "MultiAgentEnv":
        """Convenience method for grouping together agents in this env.

        An agent group is a list of agent IDs that are mapped to a single
        logical agent. All agents of the group must act at the same time in the
        environment. The grouped agent exposes Tuple action and observation
        spaces that are the concatenated action and obs spaces of the
        individual agents.

        The rewards of all the agents in a group are summed. The individual
        agent rewards are available under the "individual_rewards" key of the
        group info return.

        Agent grouping is required to leverage algorithms such as Q-Mix.

        Args:
            groups: Mapping from group id to a list of the agent ids
                of group members. If an agent id is not present in any group
                value, it will be left ungrouped. The group id becomes a new agent ID
                in the final environment.
            obs_space: Optional observation space for the grouped
                env. Must be a tuple space. If not provided, will infer this to be a
                Tuple of n individual agents spaces (n=num agents in a group).
            act_space: Optional action space for the grouped env.
                Must be a tuple space. If not provided, will infer this to be a Tuple
                of n individual agents spaces (n=num agents in a group).

        .. testcode::
            :skipif: True

            from ray.rllib.env.multi_agent_env import MultiAgentEnv
            class MyMultiAgentEnv(MultiAgentEnv):
                # define your env here
                ...
            env = MyMultiAgentEnv(...)
            grouped_env = env.with_agent_groups(env, {
                "group1": ["agent1", "agent2", "agent3"],
                "group2": ["agent4", "agent5"],
            })

        """

        from ray.rllib.env.wrappers.group_agents_wrapper import \
            GroupAgentsWrapper
        return GroupAgentsWrapper(self, groups, obs_space, act_space)

    # __grouping_doc_end__
    # fmt: on

    @OldAPIStack
    @Deprecated(new="MultiAgentEnv.possible_agents", error=False)
    def get_agent_ids(self) -> Set[AgentID]:
        # Lazily create/normalize the legacy `_agent_ids` attribute, in case a
        # subclass never called `MultiAgentEnv.__init__`.
        if not hasattr(self, "_agent_ids"):
            self._agent_ids = set()
        if not isinstance(self._agent_ids, set):
            self._agent_ids = set(self._agent_ids)
        # Make this backward compatible as much as possible.
        return self._agent_ids if self._agent_ids else set(self.agents)

    @OldAPIStack
    def to_base_env(
        self,
        make_env: Optional[Callable[[int], EnvType]] = None,
        num_envs: int = 1,
        remote_envs: bool = False,
        remote_env_batch_wait_ms: int = 0,
        restart_failed_sub_environments: bool = False,
    ) -> "BaseEnv":
        """Converts an RLlib MultiAgentEnv into a BaseEnv object.

        The resulting BaseEnv is always vectorized (contains n
        sub-environments) to support batched forward passes, where n may
        also be 1. BaseEnv also supports async execution via the `poll` and
        `send_actions` methods and thus supports external simulators.

        Args:
            make_env: A callable taking an int as input (which indicates
                the number of individual sub-environments within the final
                vectorized BaseEnv) and returning one individual
                sub-environment.
            num_envs: The number of sub-environments to create in the
                resulting (vectorized) BaseEnv. The already existing `env`
                will be one of the `num_envs`.
            remote_envs: Whether each sub-env should be a @ray.remote
                actor. You can set this behavior in your config via the
                `remote_worker_envs=True` option.
            remote_env_batch_wait_ms: The wait time (in ms) to poll remote
                sub-environments for, if applicable. Only used if
                `remote_envs` is True.
            restart_failed_sub_environments: If True and any sub-environment (within
                a vectorized env) throws any error during env stepping, we will try to
                restart the faulty sub-environment. This is done
                without disturbing the other (still intact) sub-environments.

        Returns:
            The resulting BaseEnv object.
        """
        from ray.rllib.env.remote_base_env import RemoteBaseEnv

        if remote_envs:
            env = RemoteBaseEnv(
                make_env,
                num_envs,
                multiagent=True,
                remote_env_batch_wait_ms=remote_env_batch_wait_ms,
                restart_failed_sub_environments=restart_failed_sub_environments,
            )
        # Sub-environments are not ray.remote actors.
        else:
            env = MultiAgentEnvWrapper(
                make_env=make_env,
                existing_envs=[self],
                num_envs=num_envs,
                restart_failed_sub_environments=restart_failed_sub_environments,
            )

        return env


@DeveloperAPI
def make_multi_agent(
    env_name_or_creator: Union[str, EnvCreator],
) -> Type["MultiAgentEnv"]:
    """Convenience wrapper for any single-agent env to be converted into MA.

    Allows you to convert a simple (single-agent) `gym.Env` class
    into a `MultiAgentEnv` class. This function simply stacks n instances
    of the given ```gym.Env``` class into one unified ``MultiAgentEnv`` class
    and returns this class, thus pretending the agents act together in the
    same environment, whereas - under the hood - they live separately from
    each other in n parallel single-agent envs.

    Agent IDs in the resulting env are int numbers starting from 0
    (first agent).

    Args:
        env_name_or_creator: String specifier or env_maker function taking
            an EnvContext object as only arg and returning a gym.Env.

    Returns:
        New MultiAgentEnv class to be used as env.
        The constructor takes a config dict with `num_agents` key
        (default=1).
        The rest of the config dict will be passed on to the
        underlying single-agent env's constructor.

    .. testcode::
        :skipif: True

        from ray.rllib.env.multi_agent_env import make_multi_agent
        # By gym string:
        ma_cartpole_cls = make_multi_agent("CartPole-v1")
        # Create a 2 agent multi-agent cartpole.
        ma_cartpole = ma_cartpole_cls({"num_agents": 2})
        obs = ma_cartpole.reset()
        print(obs)

        # By env-maker callable:
        from ray.rllib.examples.envs.classes.stateless_cartpole import StatelessCartPole
        ma_stateless_cartpole_cls = make_multi_agent(
            lambda config: StatelessCartPole(config))
        # Create a 3 agent multi-agent stateless cartpole.
        ma_stateless_cartpole = ma_stateless_cartpole_cls(
            {"num_agents": 3})
        print(obs)

    .. testoutput::

        {0: [...], 1: [...]}
        {0: [...], 1: [...], 2: [...]}
    """

    class MultiEnv(MultiAgentEnv):
        def __init__(self, config: Optional[EnvContext] = None):
            super().__init__()

            # Note: Explicitly check for None here, because config
            # can have an empty dict but meaningful data fields (worker_index,
            # vector_index) etc.
            # TODO (sven): Clean this up, so we are not mixing up dict fields
            # with data fields.
            if config is None:
                config = {}
            # NOTE(review): `pop` mutates the caller-provided config dict.
            num = config.pop("num_agents", 1)
            if isinstance(env_name_or_creator, str):
                self.envs = [gym.make(env_name_or_creator) for _ in range(num)]
            else:
                self.envs = [env_name_or_creator(config) for _ in range(num)]
            self.terminateds = set()
            self.truncateds = set()
            self.observation_spaces = {
                i: self.envs[i].observation_space for i in range(num)
            }
            self.action_spaces = {i: self.envs[i].action_space for i in range(num)}
            self.agents = list(range(num))
            self.possible_agents = self.agents.copy()

        @override(MultiAgentEnv)
        def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
            self.terminateds = set()
            self.truncateds = set()
            obs, infos = {}, {}
            # NOTE(review): the same `seed` is passed to every sub-env here,
            # so all agents start from identical episodes when seeded — confirm
            # this is intended.
            for i, env in enumerate(self.envs):
                obs[i], infos[i] = env.reset(seed=seed, options=options)

            return obs, infos

        @override(MultiAgentEnv)
        def step(self, action_dict):
            obs, rew, terminated, truncated, info = {}, {}, {}, {}, {}

            # The environment is expecting an action for at least one agent.
            if len(action_dict) == 0:
                raise ValueError(
                    "The environment is expecting an action for at least one agent."
                )

            for i, action in action_dict.items():
                obs[i], rew[i], terminated[i], truncated[i], info[i] = self.envs[
                    i
                ].step(action)
                if terminated[i]:
                    self.terminateds.add(i)
                if truncated[i]:
                    self.truncateds.add(i)
            # TODO: Flaw in our MultiAgentEnv API wrt. new gymnasium: Need to return
            #  an additional episode_done bool that covers cases where all agents are
            #  either terminated or truncated, but not all are truncated and not all are
            #  terminated. We can then get rid of the aweful `__all__` special keys!
            terminated["__all__"] = len(self.terminateds) + len(self.truncateds) == len(
                self.envs
            )
            truncated["__all__"] = len(self.truncateds) == len(self.envs)
            return obs, rew, terminated, truncated, info

        @override(MultiAgentEnv)
        def render(self):
            # This render method simply renders all n underlying individual single-agent
            # envs and concatenates their images (on top of each other if the returned
            # images have dims where [width] > [height], otherwise next to each other).
            render_images = [e.render() for e in self.envs]
            if render_images[0].shape[1] > render_images[0].shape[0]:
                concat_dim = 0
            else:
                concat_dim = 1
            return np.concatenate(render_images, axis=concat_dim)

    return MultiEnv


@OldAPIStack
class MultiAgentEnvWrapper(BaseEnv):
    """Internal adapter of MultiAgentEnv to BaseEnv.

    This also supports vectorization if num_envs > 1.
    """

    def __init__(
        self,
        make_env: Callable[[int], EnvType],
        existing_envs: List["MultiAgentEnv"],
        num_envs: int,
        restart_failed_sub_environments: bool = False,
    ):
        """Wraps MultiAgentEnv(s) into the BaseEnv API.

        Args:
            make_env: Factory that produces a new MultiAgentEnv instance taking the
                vector index as only call argument.
                Must be defined, if the number of existing envs is less than num_envs.
            existing_envs: List of already existing multi-agent envs.
            num_envs: Desired num multiagent envs to have at the end in
                total. This will include the given (already created)
                `existing_envs`.
            restart_failed_sub_environments: If True and any sub-environment (within
                this vectorized env) throws any error during env stepping, we will try
                to restart the faulty sub-environment. This is done
                without disturbing the other (still intact) sub-environments.
+ """ + self.make_env = make_env + self.envs = existing_envs + self.num_envs = num_envs + self.restart_failed_sub_environments = restart_failed_sub_environments + + self.terminateds = set() + self.truncateds = set() + while len(self.envs) < self.num_envs: + self.envs.append(self.make_env(len(self.envs))) + for env in self.envs: + assert isinstance(env, MultiAgentEnv) + self._init_env_state(idx=None) + self._unwrapped_env = self.envs[0].unwrapped + + @override(BaseEnv) + def poll( + self, + ) -> Tuple[ + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + ]: + obs, rewards, terminateds, truncateds, infos = {}, {}, {}, {}, {} + for i, env_state in enumerate(self.env_states): + ( + obs[i], + rewards[i], + terminateds[i], + truncateds[i], + infos[i], + ) = env_state.poll() + return obs, rewards, terminateds, truncateds, infos, {} + + @override(BaseEnv) + def send_actions(self, action_dict: MultiEnvDict) -> None: + for env_id, agent_dict in action_dict.items(): + if env_id in self.terminateds or env_id in self.truncateds: + raise ValueError( + f"Env {env_id} is already done and cannot accept new actions" + ) + env = self.envs[env_id] + try: + obs, rewards, terminateds, truncateds, infos = env.step(agent_dict) + except Exception as e: + if self.restart_failed_sub_environments: + logger.exception(e.args[0]) + self.try_restart(env_id=env_id) + obs = e + rewards = {} + terminateds = {"__all__": True} + truncateds = {"__all__": False} + infos = {} + else: + raise e + + assert isinstance( + obs, (dict, Exception) + ), "Not a multi-agent obs dict or an Exception!" + assert isinstance(rewards, dict), "Not a multi-agent reward dict!" + assert isinstance(terminateds, dict), "Not a multi-agent terminateds dict!" + assert isinstance(truncateds, dict), "Not a multi-agent truncateds dict!" + assert isinstance(infos, dict), "Not a multi-agent info dict!" 
+ if isinstance(obs, dict): + info_diff = set(infos).difference(set(obs)) + if info_diff and info_diff != {"__common__"}: + raise ValueError( + "Key set for infos must be a subset of obs (plus optionally " + "the '__common__' key for infos concerning all/no agents): " + "{} vs {}".format(infos.keys(), obs.keys()) + ) + if "__all__" not in terminateds: + raise ValueError( + "In multi-agent environments, '__all__': True|False must " + "be included in the 'terminateds' dict: got {}.".format(terminateds) + ) + elif "__all__" not in truncateds: + raise ValueError( + "In multi-agent environments, '__all__': True|False must " + "be included in the 'truncateds' dict: got {}.".format(truncateds) + ) + + if terminateds["__all__"]: + self.terminateds.add(env_id) + if truncateds["__all__"]: + self.truncateds.add(env_id) + self.env_states[env_id].observe( + obs, rewards, terminateds, truncateds, infos + ) + + @override(BaseEnv) + def try_reset( + self, + env_id: Optional[EnvID] = None, + *, + seed: Optional[int] = None, + options: Optional[dict] = None, + ) -> Optional[Tuple[MultiEnvDict, MultiEnvDict]]: + ret_obs = {} + ret_infos = {} + if isinstance(env_id, int): + env_id = [env_id] + if env_id is None: + env_id = list(range(len(self.envs))) + for idx in env_id: + obs, infos = self.env_states[idx].reset(seed=seed, options=options) + + if isinstance(obs, Exception): + if self.restart_failed_sub_environments: + self.env_states[idx].env = self.envs[idx] = self.make_env(idx) + else: + raise obs + else: + assert isinstance(obs, dict), "Not a multi-agent obs dict!" 
+ if obs is not None: + if idx in self.terminateds: + self.terminateds.remove(idx) + if idx in self.truncateds: + self.truncateds.remove(idx) + ret_obs[idx] = obs + ret_infos[idx] = infos + return ret_obs, ret_infos + + @override(BaseEnv) + def try_restart(self, env_id: Optional[EnvID] = None) -> None: + if isinstance(env_id, int): + env_id = [env_id] + if env_id is None: + env_id = list(range(len(self.envs))) + for idx in env_id: + # Try closing down the old (possibly faulty) sub-env, but ignore errors. + try: + self.envs[idx].close() + except Exception as e: + if log_once("close_sub_env"): + logger.warning( + "Trying to close old and replaced sub-environment (at vector " + f"index={idx}), but closing resulted in error:\n{e}" + ) + # Try recreating the sub-env. + logger.warning(f"Trying to restart sub-environment at index {idx}.") + self.env_states[idx].env = self.envs[idx] = self.make_env(idx) + logger.warning(f"Sub-environment at index {idx} restarted successfully.") + + @override(BaseEnv) + def get_sub_environments( + self, as_dict: bool = False + ) -> Union[Dict[str, EnvType], List[EnvType]]: + if as_dict: + return {_id: env_state.env for _id, env_state in enumerate(self.env_states)} + return [state.env for state in self.env_states] + + @override(BaseEnv) + def try_render(self, env_id: Optional[EnvID] = None) -> None: + if env_id is None: + env_id = 0 + assert isinstance(env_id, int) + return self.envs[env_id].render() + + @property + @override(BaseEnv) + def observation_space(self) -> gym.spaces.Dict: + return self.envs[0].observation_space + + @property + @override(BaseEnv) + def action_space(self) -> gym.Space: + return self.envs[0].action_space + + @override(BaseEnv) + def get_agent_ids(self) -> Set[AgentID]: + return self.envs[0].get_agent_ids() + + def _init_env_state(self, idx: Optional[int] = None) -> None: + """Resets all or one particular sub-environment's state (by index). + + Args: + idx: The index to reset at. 
If None, reset all the sub-environments' states. + """ + # If index is None, reset all sub-envs' states: + if idx is None: + self.env_states = [ + _MultiAgentEnvState(env, self.restart_failed_sub_environments) + for env in self.envs + ] + # Index provided, reset only the sub-env's state at the given index. + else: + assert isinstance(idx, int) + self.env_states[idx] = _MultiAgentEnvState( + self.envs[idx], self.restart_failed_sub_environments + ) + + +@OldAPIStack +class _MultiAgentEnvState: + def __init__(self, env: MultiAgentEnv, return_error_as_obs: bool = False): + assert isinstance(env, MultiAgentEnv) + self.env = env + self.return_error_as_obs = return_error_as_obs + + self.initialized = False + self.last_obs = {} + self.last_rewards = {} + self.last_terminateds = {"__all__": False} + self.last_truncateds = {"__all__": False} + self.last_infos = {} + + def poll( + self, + ) -> Tuple[ + MultiAgentDict, + MultiAgentDict, + MultiAgentDict, + MultiAgentDict, + MultiAgentDict, + ]: + if not self.initialized: + # TODO(sven): Should we make it possible to pass in a seed here? + self.reset() + self.initialized = True + + observations = self.last_obs + rewards = {} + terminateds = {"__all__": self.last_terminateds["__all__"]} + truncateds = {"__all__": self.last_truncateds["__all__"]} + infos = self.last_infos + + # If episode is done or we have an error, release everything we have. + if ( + terminateds["__all__"] + or truncateds["__all__"] + or isinstance(observations, Exception) + ): + rewards = self.last_rewards + self.last_rewards = {} + terminateds = self.last_terminateds + if isinstance(observations, Exception): + terminateds["__all__"] = True + truncateds["__all__"] = False + self.last_terminateds = {} + truncateds = self.last_truncateds + self.last_truncateds = {} + self.last_obs = {} + infos = self.last_infos + self.last_infos = {} + # Only release those agents' rewards/terminateds/truncateds/infos, whose + # observations we have. 
+ else: + for ag in observations.keys(): + if ag in self.last_rewards: + rewards[ag] = self.last_rewards[ag] + del self.last_rewards[ag] + if ag in self.last_terminateds: + terminateds[ag] = self.last_terminateds[ag] + del self.last_terminateds[ag] + if ag in self.last_truncateds: + truncateds[ag] = self.last_truncateds[ag] + del self.last_truncateds[ag] + + self.last_terminateds["__all__"] = False + self.last_truncateds["__all__"] = False + return observations, rewards, terminateds, truncateds, infos + + def observe( + self, + obs: MultiAgentDict, + rewards: MultiAgentDict, + terminateds: MultiAgentDict, + truncateds: MultiAgentDict, + infos: MultiAgentDict, + ): + self.last_obs = obs + for ag, r in rewards.items(): + if ag in self.last_rewards: + self.last_rewards[ag] += r + else: + self.last_rewards[ag] = r + for ag, d in terminateds.items(): + if ag in self.last_terminateds: + self.last_terminateds[ag] = self.last_terminateds[ag] or d + else: + self.last_terminateds[ag] = d + for ag, t in truncateds.items(): + if ag in self.last_truncateds: + self.last_truncateds[ag] = self.last_truncateds[ag] or t + else: + self.last_truncateds[ag] = t + self.last_infos = infos + + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict] = None, + ) -> Tuple[MultiAgentDict, MultiAgentDict]: + try: + obs_and_infos = self.env.reset(seed=seed, options=options) + except Exception as e: + if self.return_error_as_obs: + logger.exception(e.args[0]) + obs_and_infos = e, e + else: + raise e + + self.last_obs, self.last_infos = obs_and_infos + self.last_rewards = {} + self.last_terminateds = {"__all__": False} + self.last_truncateds = {"__all__": False} + + return self.last_obs, self.last_infos diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_env_runner.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_env_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..9a2c8d32ec46f9b755f746fb68208bb0a5710f65 
--- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_env_runner.py @@ -0,0 +1,1107 @@ +from collections import defaultdict +from functools import partial +import logging +import time +from typing import Collection, DefaultDict, Dict, List, Optional, Union + +import gymnasium as gym + +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig +from ray.rllib.callbacks.utils import make_callback +from ray.rllib.core import ( + COMPONENT_ENV_TO_MODULE_CONNECTOR, + COMPONENT_MODULE_TO_ENV_CONNECTOR, + COMPONENT_RL_MODULE, +) +from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule, MultiRLModuleSpec +from ray.rllib.env.env_context import EnvContext +from ray.rllib.env.env_runner import EnvRunner, ENV_STEP_FAILURE +from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.env.multi_agent_episode import MultiAgentEpisode +from ray.rllib.env.utils import _gym_env_creator +from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import override +from ray.rllib.utils.checkpoints import Checkpointable +from ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.framework import get_device, try_import_torch +from ray.rllib.utils.metrics import ( + EPISODE_DURATION_SEC_MEAN, + EPISODE_LEN_MAX, + EPISODE_LEN_MEAN, + EPISODE_LEN_MIN, + EPISODE_RETURN_MAX, + EPISODE_RETURN_MEAN, + EPISODE_RETURN_MIN, + NUM_AGENT_STEPS_SAMPLED, + NUM_AGENT_STEPS_SAMPLED_LIFETIME, + NUM_ENV_STEPS_SAMPLED, + NUM_ENV_STEPS_SAMPLED_LIFETIME, + NUM_EPISODES, + NUM_EPISODES_LIFETIME, + NUM_MODULE_STEPS_SAMPLED, + NUM_MODULE_STEPS_SAMPLED_LIFETIME, + TIME_BETWEEN_SAMPLING, + WEIGHTS_SEQ_NO, +) +from ray.rllib.utils.metrics.metrics_logger import MetricsLogger +from ray.rllib.utils.pre_checks.env import check_multiagent_environments +from ray.rllib.utils.typing import EpisodeID, ModelWeights, ResultDict, StateDict +from ray.tune.registry import ENV_CREATOR, _global_registry 
+from ray.util.annotations import PublicAPI + +torch, _ = try_import_torch() +logger = logging.getLogger("ray.rllib") + + +# TODO (sven): As soon as RolloutWorker is no longer supported, make `EnvRunner` itself +# a Checkpointable. Currently, only some of its subclasses are Checkpointables. +@PublicAPI(stability="alpha") +class MultiAgentEnvRunner(EnvRunner, Checkpointable): + """The genetic environment runner for the multi-agent case.""" + + @override(EnvRunner) + def __init__(self, config: AlgorithmConfig, **kwargs): + """Initializes a MultiAgentEnvRunner instance. + + Args: + config: An `AlgorithmConfig` object containing all settings needed to + build this `EnvRunner` class. + """ + super().__init__(config=config) + + # Raise an Error, if the provided config is not a multi-agent one. + if not self.config.is_multi_agent: + raise ValueError( + f"Cannot use this EnvRunner class ({type(self).__name__}), if your " + "setup is not multi-agent! Try adding multi-agent information to your " + "AlgorithmConfig via calling the `config.multi_agent(policies=..., " + "policy_mapping_fn=...)`." + ) + + # Get the worker index on which this instance is running. + self.worker_index: int = kwargs.get("worker_index") + self.tune_trial_id: str = kwargs.get("tune_trial_id") + + # Set up all metrics-related structures and counters. + self.metrics: Optional[MetricsLogger] = None + self._setup_metrics() + + # Create our callbacks object. + self._callbacks = [cls() for cls in force_list(self.config.callbacks_class)] + + # Set device. + self._device = get_device( + self.config, + 0 if not self.worker_index else self.config.num_gpus_per_env_runner, + ) + + # Create the vectorized gymnasium env. + self.env: Optional[gym.Wrapper] = None + self.num_envs: int = 0 + self.make_env() + + # Create the env-to-module connector pipeline. 
+ self._env_to_module = self.config.build_env_to_module_connector( + self.env.unwrapped, device=self._device + ) + # Cached env-to-module results taken at the end of a `_sample_timesteps()` + # call to make sure the final observation (before an episode cut) gets properly + # processed (and maybe postprocessed and re-stored into the episode). + # For example, if we had a connector that normalizes observations and directly + # re-inserts these new obs back into the episode, the last observation in each + # sample call would NOT be processed, which could be very harmful in cases, + # in which value function bootstrapping of those (truncation) observations is + # required in the learning step. + self._cached_to_module = None + + # Construct the MultiRLModule. + self.module: Optional[MultiRLModule] = None + self.make_module() + + # Create the module-to-env connector pipeline. + self._module_to_env = self.config.build_module_to_env_connector( + self.env.unwrapped + ) + + self._needs_initial_reset: bool = True + self._episode: Optional[MultiAgentEpisode] = None + self._shared_data = None + + self._weights_seq_no: int = 0 + + # Measures the time passed between returning from `sample()` + # and receiving the next `sample()` request from the user. + self._time_after_sampling = None + + @override(EnvRunner) + def sample( + self, + *, + num_timesteps: int = None, + num_episodes: int = None, + explore: bool = None, + random_actions: bool = False, + force_reset: bool = False, + ) -> List[MultiAgentEpisode]: + """Runs and returns a sample (n timesteps or m episodes) on the env(s). + + Args: + num_timesteps: The number of timesteps to sample during this call. + Note that only one of `num_timetseps` or `num_episodes` may be provided. + num_episodes: The number of episodes to sample during this call. + Note that only one of `num_timetseps` or `num_episodes` may be provided. + explore: If True, will use the RLModule's `forward_exploration()` + method to compute actions. 
If False, will use the RLModule's + `forward_inference()` method. If None (default), will use the `explore` + boolean setting from `self.config` passed into this EnvRunner's + constructor. You can change this setting in your config via + `config.env_runners(explore=True|False)`. + random_actions: If True, actions will be sampled randomly (from the action + space of the environment). If False (default), actions or action + distribution parameters are computed by the RLModule. + force_reset: Whether to force-reset all (vector) environments before + sampling. Useful if you would like to collect a clean slate of new + episodes via this call. Note that when sampling n episodes + (`num_episodes != None`), this is fixed to True. + + Returns: + A list of `MultiAgentEpisode` instances, carrying the sampled data. + """ + assert not (num_timesteps is not None and num_episodes is not None) + + # Log time between `sample()` requests. + if self._time_after_sampling is not None: + self.metrics.log_value( + key=TIME_BETWEEN_SAMPLING, + value=time.perf_counter() - self._time_after_sampling, + ) + + # If no execution details are provided, use the config to try to infer the + # desired timesteps/episodes to sample and the exploration behavior. + if explore is None: + explore = self.config.explore + if num_timesteps is None and num_episodes is None: + if self.config.batch_mode == "truncate_episodes": + num_timesteps = self.config.get_rollout_fragment_length( + worker_index=self.worker_index, + ) + else: + num_episodes = 1 + + # Sample n timesteps. + if num_timesteps is not None: + samples = self._sample_timesteps( + num_timesteps=num_timesteps, + explore=explore, + random_actions=random_actions, + force_reset=force_reset, + ) + # Sample m episodes. + else: + samples = self._sample_episodes( + num_episodes=num_episodes, + explore=explore, + random_actions=random_actions, + ) + + # Make the `on_sample_end` callback. 
+ make_callback( + "on_sample_end", + callbacks_objects=self._callbacks, + callbacks_functions=self.config.callbacks_on_sample_end, + kwargs=dict( + env_runner=self, + metrics_logger=self.metrics, + samples=samples, + ), + ) + + self._time_after_sampling = time.perf_counter() + + return samples + + def _sample_timesteps( + self, + num_timesteps: int, + explore: bool, + random_actions: bool = False, + force_reset: bool = False, + ) -> List[MultiAgentEpisode]: + """Helper method to sample n timesteps. + + Args: + num_timesteps: int. Number of timesteps to sample during rollout. + explore: boolean. If in exploration or inference mode. Exploration + mode might for some algorithms provide extza model outputs that + are redundant in inference mode. + random_actions: boolean. If actions should be sampled from the action + space. In default mode (i.e. `False`) we sample actions frokm the + policy. + + Returns: + `Lists of `MultiAgentEpisode` instances, carrying the collected sample data. + """ + done_episodes_to_return: List[MultiAgentEpisode] = [] + + # Have to reset the env. + if force_reset or self._needs_initial_reset: + # Create n new episodes and make the `on_episode_created` callbacks. + self._episode = self._new_episode() + self._make_on_episode_callback("on_episode_created") + + # Erase all cached ongoing episodes (these will never be completed and + # would thus never be returned/cleaned by `get_metrics` and cause a memory + # leak). + self._ongoing_episodes_for_metrics.clear() + + # Try resetting the environment. + # TODO (simon): Check, if we need here the seed from the config. + obs, infos = self._try_env_reset() + + self._cached_to_module = None + + # Call `on_episode_start()` callbacks. + self._make_on_episode_callback("on_episode_start") + + # We just reset the env. Don't have to force this again in the next + # call to `self._sample_timesteps()`. + self._needs_initial_reset = False + + # Set the initial observations in the episodes. 
+ self._episode.add_env_reset(observations=obs, infos=infos) + + self._shared_data = { + "agent_to_module_mapping_fn": self.config.policy_mapping_fn, + } + + # Loop through timesteps. + ts = 0 + + while ts < num_timesteps: + # Act randomly. + if random_actions: + # Only act (randomly) for those agents that had an observation. + to_env = { + Columns.ACTIONS: [ + { + aid: self.env.unwrapped.get_action_space(aid).sample() + for aid in self._episode.get_agents_to_act() + } + ] + } + # Compute an action using the RLModule. + else: + # Env-to-module connector. + to_module = self._cached_to_module or self._env_to_module( + rl_module=self.module, + episodes=[self._episode], + explore=explore, + shared_data=self._shared_data, + metrics=self.metrics, + ) + self._cached_to_module = None + + # MultiRLModule forward pass: Explore or not. + if explore: + env_steps_lifetime = ( + self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0) + + self.metrics.peek(NUM_ENV_STEPS_SAMPLED, default=0) + ) * (self.config.num_env_runners or 1) + to_env = self.module.forward_exploration( + to_module, t=env_steps_lifetime + ) + else: + to_env = self.module.forward_inference(to_module) + + # Module-to-env connector. + to_env = self._module_to_env( + rl_module=self.module, + batch=to_env, + episodes=[self._episode], + explore=explore, + shared_data=self._shared_data, + metrics=self.metrics, + ) + + # Extract the (vectorized) actions (to be sent to the env) from the + # module/connector output. Note that these actions are fully ready (e.g. + # already unsquashed/clipped) to be sent to the environment) and might not + # be identical to the actions produced by the RLModule/distribution, which + # are the ones stored permanently in the episode objects. + actions = to_env.pop(Columns.ACTIONS) + actions_for_env = to_env.pop(Columns.ACTIONS_FOR_ENV, actions) + + # Try stepping the environment. + # TODO (sven): [0] = actions is vectorized, but env is NOT a vector Env. 
+ # Support vectorized multi-agent envs. + results = self._try_env_step(actions_for_env[0]) + # If any failure occurs during stepping -> Throw away all data collected + # thus far and restart sampling procedure. + if results == ENV_STEP_FAILURE: + return self._sample_timesteps( + num_timesteps=num_timesteps, + explore=explore, + random_actions=random_actions, + force_reset=True, + ) + obs, rewards, terminateds, truncateds, infos = results + + # TODO (sven): This simple approach to re-map `to_env` from a + # dict[col, List[MADict]] to a dict[agentID, MADict] would not work for + # a vectorized env. + extra_model_outputs = defaultdict(dict) + for col, ma_dict_list in to_env.items(): + # TODO (sven): Support vectorized MA env. + ma_dict = ma_dict_list[0] + for agent_id, val in ma_dict.items(): + extra_model_outputs[agent_id][col] = val + extra_model_outputs[agent_id][WEIGHTS_SEQ_NO] = self._weights_seq_no + extra_model_outputs = dict(extra_model_outputs) + + # Record the timestep in the episode instance. + self._episode.add_env_step( + obs, + actions[0], + rewards, + infos=infos, + terminateds=terminateds, + truncateds=truncateds, + extra_model_outputs=extra_model_outputs, + ) + + ts += self._increase_sampled_metrics(self.num_envs, obs, self._episode) + + # Make the `on_episode_step` callback (before finalizing the episode + # object). + self._make_on_episode_callback("on_episode_step") + + # Episode is done for all agents. Wrap up the old one and create a new + # one (and reset it) to continue. + if self._episode.is_done: + # We have to perform an extra env-to-module pass here, just in case + # the user's connector pipeline performs (permanent) transforms + # on each observation (including this final one here). Without such + # a call and in case the structure of the observations change + # sufficiently, the following `to_numpy()` call on the episode will + # fail. 
+ if self.module is not None: + self._env_to_module( + episodes=[self._episode], + explore=explore, + rl_module=self.module, + shared_data=self._shared_data, + metrics=self.metrics, + ) + + # Make the `on_episode_end` callback (before finalizing the episode, + # but after(!) the last env-to-module connector call has been made. + # -> All obs (even the terminal one) should have been processed now (by + # the connector, if applicable). + self._make_on_episode_callback("on_episode_end") + + self._prune_zero_len_sa_episodes(self._episode) + + # Numpy'ize the episode. + if self.config.episodes_to_numpy: + done_episodes_to_return.append(self._episode.to_numpy()) + # Leave episode as lists of individual (obs, action, etc..) items. + else: + done_episodes_to_return.append(self._episode) + + # Create a new episode instance. + self._episode = self._new_episode() + self._make_on_episode_callback("on_episode_created") + + # Reset the environment. + obs, infos = self._try_env_reset() + # Add initial observations and infos. + self._episode.add_env_reset(observations=obs, infos=infos) + + # Make the `on_episode_start` callback. + self._make_on_episode_callback("on_episode_start") + + # Already perform env-to-module connector call for next call to + # `_sample_timesteps()`. See comment in c'tor for `self._cached_to_module`. + if self.module is not None: + self._cached_to_module = self._env_to_module( + rl_module=self.module, + episodes=[self._episode], + explore=explore, + shared_data=self._shared_data, + metrics=self.metrics, + ) + + # Store done episodes for metrics. + self._done_episodes_for_metrics.extend(done_episodes_to_return) + + # Also, make sure we start new episode chunks (continuing the ongoing episodes + # from the to-be-returned chunks). + ongoing_episode_continuation = self._episode.cut( + len_lookback_buffer=self.config.episode_lookback_horizon + ) + + ongoing_episodes_to_return = [] + # Just started Episodes do not have to be returned. 
There is no data + # in them anyway. + if self._episode.env_t > 0: + self._episode.validate() + self._ongoing_episodes_for_metrics[self._episode.id_].append(self._episode) + + self._prune_zero_len_sa_episodes(self._episode) + + # Numpy'ize the episode. + if self.config.episodes_to_numpy: + ongoing_episodes_to_return.append(self._episode.to_numpy()) + # Leave episode as lists of individual (obs, action, etc..) items. + else: + ongoing_episodes_to_return.append(self._episode) + + # Continue collecting into the cut Episode chunk. + self._episode = ongoing_episode_continuation + + # Return collected episode data. + return done_episodes_to_return + ongoing_episodes_to_return + + def _sample_episodes( + self, + num_episodes: int, + explore: bool, + random_actions: bool = False, + ) -> List[MultiAgentEpisode]: + """Helper method to run n episodes. + + See docstring of `self.sample()` for more details. + """ + # If user calls sample(num_timesteps=..) after this, we must reset again + # at the beginning. + self._needs_initial_reset = True + + done_episodes_to_return: List[MultiAgentEpisode] = [] + + # Create a new multi-agent episode. + _episode = self._new_episode() + self._make_on_episode_callback("on_episode_created", _episode) + _shared_data = { + "agent_to_module_mapping_fn": self.config.policy_mapping_fn, + } + + # Try resetting the environment. + # TODO (simon): Check, if we need here the seed from the config. + obs, infos = self._try_env_reset() + # Set initial obs and infos in the episodes. + _episode.add_env_reset(observations=obs, infos=infos) + self._make_on_episode_callback("on_episode_start", _episode) + + # Loop over episodes. + eps = 0 + ts = 0 + while eps < num_episodes: + # Act randomly. + if random_actions: + # Only act (randomly) for those agents that had an observation. 
+ to_env = { + Columns.ACTIONS: [ + { + aid: self.env.unwrapped.get_action_space(aid).sample() + for aid in self._episode.get_agents_to_act() + } + ] + } + # Compute an action using the RLModule. + else: + # Env-to-module connector. + to_module = self._env_to_module( + rl_module=self.module, + episodes=[_episode], + explore=explore, + shared_data=_shared_data, + metrics=self.metrics, + ) + + # MultiRLModule forward pass: Explore or not. + if explore: + env_steps_lifetime = ( + self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0) + + self.metrics.peek(NUM_ENV_STEPS_SAMPLED, default=0) + ) * (self.config.num_env_runners or 1) + to_env = self.module.forward_exploration( + to_module, t=env_steps_lifetime + ) + else: + to_env = self.module.forward_inference(to_module) + + # Module-to-env connector. + to_env = self._module_to_env( + rl_module=self.module, + batch=to_env, + episodes=[_episode], + explore=explore, + shared_data=_shared_data, + metrics=self.metrics, + ) + + # Extract the (vectorized) actions (to be sent to the env) from the + # module/connector output. Note that these actions are fully ready (e.g. + # already unsquashed/clipped) to be sent to the environment) and might not + # be identical to the actions produced by the RLModule/distribution, which + # are the ones stored permanently in the episode objects. + actions = to_env.pop(Columns.ACTIONS) + actions_for_env = to_env.pop(Columns.ACTIONS_FOR_ENV, actions) + + # Try stepping the environment. + # TODO (sven): [0] = actions is vectorized, but env is NOT a vector Env. + # Support vectorized multi-agent envs. + results = self._try_env_step(actions_for_env[0]) + # If any failure occurs during stepping -> Throw away all data collected + # thus far and restart sampling procedure. 
+ if results == ENV_STEP_FAILURE: + return self._sample_episodes( + num_episodes=num_episodes, + explore=explore, + random_actions=random_actions, + ) + obs, rewards, terminateds, truncateds, infos = results + + # TODO (sven): This simple approach to re-map `to_env` from a + # dict[col, List[MADict]] to a dict[agentID, MADict] would not work for + # a vectorized env. + extra_model_outputs = defaultdict(dict) + for col, ma_dict_list in to_env.items(): + # TODO (sven): Support vectorized MA env. + ma_dict = ma_dict_list[0] + for agent_id, val in ma_dict.items(): + extra_model_outputs[agent_id][col] = val + extra_model_outputs[agent_id][WEIGHTS_SEQ_NO] = self._weights_seq_no + extra_model_outputs = dict(extra_model_outputs) + + # Record the timestep in the episode instance. + _episode.add_env_step( + obs, + actions[0], + rewards, + infos=infos, + terminateds=terminateds, + truncateds=truncateds, + extra_model_outputs=extra_model_outputs, + ) + + ts += self._increase_sampled_metrics(self.num_envs, obs, _episode) + + # Make `on_episode_step` callback before finalizing the episode. + self._make_on_episode_callback("on_episode_step", _episode) + + # TODO (sven, simon): We have to check, if we need this elaborate + # function here or if the `MultiAgentEnv` defines the cases that + # can happen. + # Right now we have: + # 1. Most times only agents that step get `terminated`, `truncated` + # i.e. the rest we have to check in the episode. + # 2. There are edge cases like, some agents terminated, all others + # truncated and vice versa. + # See also `MultiAgentEpisode` for handling the `__all__`. + if _episode.is_done: + # Increase episode count. + eps += 1 + + # We have to perform an extra env-to-module pass here, just in case + # the user's connector pipeline performs (permanent) transforms + # on each observation (including this final one here). 
Without such + # a call and in case the structure of the observations change + # sufficiently, the following `to_numpy()` call on the episode will + # fail. + if self.module is not None: + self._env_to_module( + episodes=[_episode], + explore=explore, + rl_module=self.module, + shared_data=_shared_data, + metrics=self.metrics, + ) + + # Make the `on_episode_end` callback (before finalizing the episode, + # but after(!) the last env-to-module connector call has been made. + # -> All obs (even the terminal one) should have been processed now (by + # the connector, if applicable). + self._make_on_episode_callback("on_episode_end", _episode) + + self._prune_zero_len_sa_episodes(_episode) + + # Numpy'ize the episode. + if self.config.episodes_to_numpy: + done_episodes_to_return.append(_episode.to_numpy()) + # Leave episode as lists of individual (obs, action, etc..) items. + else: + done_episodes_to_return.append(_episode) + + # Also early-out if we reach the number of episodes within this + # for-loop. + if eps == num_episodes: + break + + # Create a new episode instance. + _episode = self._new_episode() + self._make_on_episode_callback("on_episode_created", _episode) + + # Try resetting the environment. + obs, infos = self._try_env_reset() + # Add initial observations and infos. + _episode.add_env_reset(observations=obs, infos=infos) + + # Make `on_episode_start` callback. + self._make_on_episode_callback("on_episode_start", _episode) + + self._done_episodes_for_metrics.extend(done_episodes_to_return) + + return done_episodes_to_return + + @override(EnvRunner) + def get_spaces(self): + # Return the already agent-to-module translated spaces from our connector + # pipeline. + return { + **{ + mid: (o, self._env_to_module.action_space[mid]) + for mid, o in self._env_to_module.observation_space.spaces.items() + }, + } + + @override(EnvRunner) + def get_metrics(self) -> ResultDict: + # Compute per-episode metrics (only on already completed episodes). 
+ for eps in self._done_episodes_for_metrics: + assert eps.is_done + episode_length = len(eps) + agent_steps = defaultdict( + int, + {str(aid): len(sa_eps) for aid, sa_eps in eps.agent_episodes.items()}, + ) + episode_return = eps.get_return() + episode_duration_s = eps.get_duration_s() + + agent_episode_returns = defaultdict( + float, + { + str(sa_eps.agent_id): sa_eps.get_return() + for sa_eps in eps.agent_episodes.values() + }, + ) + module_episode_returns = defaultdict( + float, + { + sa_eps.module_id: sa_eps.get_return() + for sa_eps in eps.agent_episodes.values() + }, + ) + + # Don't forget about the already returned chunks of this episode. + if eps.id_ in self._ongoing_episodes_for_metrics: + for eps2 in self._ongoing_episodes_for_metrics[eps.id_]: + return_eps2 = eps2.get_return() + episode_length += len(eps2) + episode_return += return_eps2 + episode_duration_s += eps2.get_duration_s() + + for sa_eps in eps2.agent_episodes.values(): + return_sa = sa_eps.get_return() + agent_steps[str(sa_eps.agent_id)] += len(sa_eps) + agent_episode_returns[str(sa_eps.agent_id)] += return_sa + module_episode_returns[sa_eps.module_id] += return_sa + + del self._ongoing_episodes_for_metrics[eps.id_] + + self._log_episode_metrics( + episode_length, + episode_return, + episode_duration_s, + agent_episode_returns, + module_episode_returns, + dict(agent_steps), + ) + + # Now that we have logged everything, clear cache of done episodes. + self._done_episodes_for_metrics.clear() + + # Return reduced metrics. + return self.metrics.reduce() + + @override(Checkpointable) + def get_state( + self, + components: Optional[Union[str, Collection[str]]] = None, + *, + not_components: Optional[Union[str, Collection[str]]] = None, + **kwargs, + ) -> StateDict: + # Basic state dict. + state = { + NUM_ENV_STEPS_SAMPLED_LIFETIME: ( + self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0) + ), + } + + # RLModule (MultiRLModule) component. 
+ if self._check_component(COMPONENT_RL_MODULE, components, not_components): + state[COMPONENT_RL_MODULE] = self.module.get_state( + components=self._get_subcomponents(COMPONENT_RL_MODULE, components), + not_components=self._get_subcomponents( + COMPONENT_RL_MODULE, not_components + ), + **kwargs, + ) + state[WEIGHTS_SEQ_NO] = self._weights_seq_no + + # Env-to-module connector. + if self._check_component( + COMPONENT_ENV_TO_MODULE_CONNECTOR, components, not_components + ): + state[COMPONENT_ENV_TO_MODULE_CONNECTOR] = self._env_to_module.get_state() + # Module-to-env connector. + if self._check_component( + COMPONENT_MODULE_TO_ENV_CONNECTOR, components, not_components + ): + state[COMPONENT_MODULE_TO_ENV_CONNECTOR] = self._module_to_env.get_state() + + return state + + @override(Checkpointable) + def set_state(self, state: StateDict) -> None: + if COMPONENT_ENV_TO_MODULE_CONNECTOR in state: + self._env_to_module.set_state(state[COMPONENT_ENV_TO_MODULE_CONNECTOR]) + if COMPONENT_MODULE_TO_ENV_CONNECTOR in state: + self._module_to_env.set_state(state[COMPONENT_MODULE_TO_ENV_CONNECTOR]) + + # Update RLModule state. + if COMPONENT_RL_MODULE in state: + # A missing value for WEIGHTS_SEQ_NO or a value of 0 means: Force the + # update. + weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0) + + # Only update the weigths, if this is the first synchronization or + # if the weights of this `EnvRunner` lacks behind the actual ones. + if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no: + self.module.set_state(state[COMPONENT_RL_MODULE]) + + # Update weights_seq_no, if the new one is > 0. + if weights_seq_no > 0: + self._weights_seq_no = weights_seq_no + + # Update lifetime counters. 
+ if NUM_ENV_STEPS_SAMPLED_LIFETIME in state: + self.metrics.set_value( + key=NUM_ENV_STEPS_SAMPLED_LIFETIME, + value=state[NUM_ENV_STEPS_SAMPLED_LIFETIME], + reduce="sum", + with_throughput=True, + ) + + @override(Checkpointable) + def get_ctor_args_and_kwargs(self): + return ( + (), # *args + {"config": self.config}, # **kwargs + ) + + @override(Checkpointable) + def get_metadata(self): + metadata = Checkpointable.get_metadata(self) + metadata.update( + { + # TODO (sven): Maybe add serialized (JSON-writable) config here? + } + ) + return metadata + + @override(Checkpointable) + def get_checkpointable_components(self): + return [ + (COMPONENT_RL_MODULE, self.module), + (COMPONENT_ENV_TO_MODULE_CONNECTOR, self._env_to_module), + (COMPONENT_MODULE_TO_ENV_CONNECTOR, self._module_to_env), + ] + + @override(EnvRunner) + def assert_healthy(self): + """Checks that self.__init__() has been completed properly. + + Ensures that the instances has a `MultiRLModule` and an + environment defined. + + Raises: + AssertionError: If the EnvRunner Actor has NOT been properly initialized. + """ + # Make sure, we have built our gym.vector.Env and RLModule properly. + assert self.env and self.module + + @override(EnvRunner) + def make_env(self): + # If an env already exists, try closing it first (to allow it to properly + # cleanup). + if self.env is not None: + try: + self.env.close() + except Exception as e: + logger.warning( + "Tried closing the existing env (multi-agent), but failed with " + f"error: {e.args[0]}" + ) + del self.env + + env_ctx = self.config.env_config + if not isinstance(env_ctx, EnvContext): + env_ctx = EnvContext( + env_ctx, + worker_index=self.worker_index, + num_workers=self.config.num_env_runners, + remote=self.config.remote_worker_envs, + ) + + # No env provided -> Error. + if not self.config.env: + raise ValueError( + "`config.env` is not provided! You should provide a valid environment " + "to your config through `config.environment([env descriptor e.g. 
" + "'CartPole-v1'])`." + ) + # Register env for the local context. + # Note, `gym.register` has to be called on each worker. + elif isinstance(self.config.env, str) and _global_registry.contains( + ENV_CREATOR, self.config.env + ): + entry_point = partial( + _global_registry.get(ENV_CREATOR, self.config.env), + env_ctx, + ) + else: + entry_point = partial( + _gym_env_creator, + env_descriptor=self.config.env, + env_context=env_ctx, + ) + gym.register( + "rllib-multi-agent-env-v0", + entry_point=entry_point, + disable_env_checker=True, + ) + + # Perform actual gym.make call. + self.env: MultiAgentEnv = gym.make("rllib-multi-agent-env-v0") + self.num_envs = 1 + # If required, check the created MultiAgentEnv. + if not self.config.disable_env_checking: + try: + check_multiagent_environments(self.env.unwrapped) + except Exception as e: + logger.exception(e.args[0]) + # If not required, still check the type (must be MultiAgentEnv). + else: + assert isinstance(self.env.unwrapped, MultiAgentEnv), ( + "ERROR: When using the `MultiAgentEnvRunner` the environment needs " + "to inherit from `ray.rllib.env.multi_agent_env.MultiAgentEnv`." + ) + + # Set the flag to reset all envs upon the next `sample()` call. + self._needs_initial_reset = True + + # Call the `on_environment_created` callback. + make_callback( + "on_environment_created", + callbacks_objects=self._callbacks, + callbacks_functions=self.config.callbacks_on_environment_created, + kwargs=dict( + env_runner=self, + metrics_logger=self.metrics, + env=self.env.unwrapped, + env_context=env_ctx, + ), + ) + + @override(EnvRunner) + def make_module(self): + try: + module_spec: MultiRLModuleSpec = self.config.get_multi_rl_module_spec( + env=self.env.unwrapped, spaces=self.get_spaces(), inference_only=True + ) + # Build the module from its spec. + self.module = module_spec.build() + # Move the RLModule to our device. 
+ # TODO (sven): In order to make this framework-agnostic, we should maybe + # make the MultiRLModule.build() method accept a device OR create an + # additional `(Multi)RLModule.to()` override. + if torch: + self.module.foreach_module( + lambda mid, mod: ( + mod.to(self._device) + if isinstance(mod, torch.nn.Module) + else mod + ) + ) + + # If `AlgorithmConfig.get_rl_module_spec()` is not implemented, this env runner + # will not have an RLModule, but might still be usable with random actions. + except NotImplementedError: + self.module = None + + @override(EnvRunner) + def stop(self): + # Note, `MultiAgentEnv` inherits `close()`-method from `gym.Env`. + self.env.close() + + def _setup_metrics(self): + self.metrics = MetricsLogger() + + self._done_episodes_for_metrics: List[MultiAgentEpisode] = [] + self._ongoing_episodes_for_metrics: DefaultDict[ + EpisodeID, List[MultiAgentEpisode] + ] = defaultdict(list) + + def _new_episode(self): + return MultiAgentEpisode( + observation_space={ + aid: self.env.unwrapped.get_observation_space(aid) + for aid in self.env.unwrapped.possible_agents + }, + action_space={ + aid: self.env.unwrapped.get_action_space(aid) + for aid in self.env.unwrapped.possible_agents + }, + agent_to_module_mapping_fn=self.config.policy_mapping_fn, + ) + + def _make_on_episode_callback(self, which: str, episode=None): + episode = episode if episode is not None else self._episode + make_callback( + which, + callbacks_objects=self._callbacks, + callbacks_functions=getattr(self.config, f"callbacks_{which}"), + kwargs=dict( + episode=episode, + env_runner=self, + metrics_logger=self.metrics, + env=self.env.unwrapped, + rl_module=self.module, + env_index=0, + ), + ) + + def _increase_sampled_metrics(self, num_steps, next_obs, episode): + # Env steps. 
+ self.metrics.log_value( + NUM_ENV_STEPS_SAMPLED, num_steps, reduce="sum", clear_on_reduce=True + ) + self.metrics.log_value( + NUM_ENV_STEPS_SAMPLED_LIFETIME, + num_steps, + reduce="sum", + with_throughput=True, + ) + # Completed episodes. + if episode.is_done: + self.metrics.log_value(NUM_EPISODES, 1, reduce="sum", clear_on_reduce=True) + self.metrics.log_value(NUM_EPISODES_LIFETIME, 1, reduce="sum") + + # TODO (sven): obs is not-vectorized. Support vectorized MA envs. + for aid in next_obs: + self.metrics.log_value( + (NUM_AGENT_STEPS_SAMPLED, str(aid)), + 1, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + (NUM_AGENT_STEPS_SAMPLED_LIFETIME, str(aid)), + 1, + reduce="sum", + ) + self.metrics.log_value( + (NUM_MODULE_STEPS_SAMPLED, episode.module_for(aid)), + 1, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + (NUM_MODULE_STEPS_SAMPLED_LIFETIME, episode.module_for(aid)), + 1, + reduce="sum", + ) + return num_steps + + def _log_episode_metrics( + self, + length, + ret, + sec, + agents=None, + modules=None, + agent_steps=None, + ): + # Log general episode metrics. + self.metrics.log_dict( + { + EPISODE_LEN_MEAN: length, + EPISODE_RETURN_MEAN: ret, + EPISODE_DURATION_SEC_MEAN: sec, + **( + { + # Per-agent returns. + "agent_episode_returns_mean": agents, + # Per-RLModule returns. + "module_episode_returns_mean": modules, + "agent_steps": agent_steps, + } + if agents is not None + else {} + ), + }, + # To mimick the old API stack behavior, we'll use `window` here for + # these particular stats (instead of the default EMA). + window=self.config.metrics_num_episodes_for_smoothing, + ) + # For some metrics, log min/max as well. 
+ self.metrics.log_dict( + { + EPISODE_LEN_MIN: length, + EPISODE_RETURN_MIN: ret, + }, + reduce="min", + window=self.config.metrics_num_episodes_for_smoothing, + ) + self.metrics.log_dict( + { + EPISODE_LEN_MAX: length, + EPISODE_RETURN_MAX: ret, + }, + reduce="max", + window=self.config.metrics_num_episodes_for_smoothing, + ) + + @staticmethod + def _prune_zero_len_sa_episodes(episode: MultiAgentEpisode): + for agent_id, agent_eps in episode.agent_episodes.copy().items(): + if len(agent_eps) == 0: + del episode.agent_episodes[agent_id] + + @Deprecated( + new="MultiAgentEnvRunner.get_state(components='rl_module')", + error=False, + ) + def get_weights(self, modules=None): + rl_module_state = self.get_state(components=COMPONENT_RL_MODULE)[ + COMPONENT_RL_MODULE + ] + return rl_module_state + + @Deprecated(new="MultiAgentEnvRunner.set_state()", error=False) + def set_weights( + self, + weights: ModelWeights, + global_vars: Optional[Dict] = None, + weights_seq_no: int = 0, + ) -> None: + assert global_vars is None + return self.set_state( + { + COMPONENT_RL_MODULE: weights, + WEIGHTS_SEQ_NO: weights_seq_no, + } + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_episode.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_episode.py new file mode 100644 index 0000000000000000000000000000000000000000..d827b1c55cd9d26682e1b099a741df74810cb3c1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_episode.py @@ -0,0 +1,2611 @@ +from collections import defaultdict +import copy +import time +from typing import ( + Any, + Callable, + Collection, + DefaultDict, + Dict, + List, + Optional, + Set, + Union, +) +import uuid + +import gymnasium as gym + +from ray.rllib.env.single_agent_episode import SingleAgentEpisode +from ray.rllib.env.utils.infinite_lookback_buffer import InfiniteLookbackBuffer +from ray.rllib.policy.sample_batch import MultiAgentBatch +from ray.rllib.utils import force_list +from 
ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.error import MultiAgentEnvError +from ray.rllib.utils.spaces.space_utils import batch +from ray.rllib.utils.typing import AgentID, ModuleID, MultiAgentDict +from ray.util.annotations import PublicAPI + + +# TODO (simon): Include cases in which the number of agents in an +# episode are shrinking or growing during the episode itself. +@PublicAPI(stability="alpha") +class MultiAgentEpisode: + """Stores multi-agent episode data. + + The central attribute of the class is the timestep mapping + `self.env_t_to_agent_t` that maps AgentIDs to their specific environment steps to + the agent's own scale/timesteps. + + Each AgentID in the `MultiAgentEpisode` has its own `SingleAgentEpisode` object + in which this agent's data is stored. Together with the env_t_to_agent_t mapping, + we can extract information either on any individual agent's time scale or from + the (global) multi-agent environment time scale. + + Extraction of data from a MultiAgentEpisode happens via the getter APIs, e.g. + `get_observations()`, which work analogous to the ones implemented in the + `SingleAgentEpisode` class. + + Note that recorded `terminateds`/`truncateds` come as simple + `MultiAgentDict`s mapping AgentID to bools and thus have no assignment to a + certain timestep (analogous to a SingleAgentEpisode's single `terminated/truncated` + boolean flag). Instead we assign it to the last observation recorded. + Theoretically, there could occur edge cases in some environments + where an agent receives partial rewards and then terminates without + a last observation. In these cases, we duplicate the last observation. + + Also, if no initial observation has been received yet for an agent, but + some rewards for this same agent already occurred, we delete the agent's data + up to here, b/c there is nothing to learn from these "premature" rewards. 
+ """ + + __slots__ = ( + "id_", + "agent_to_module_mapping_fn", + "_agent_to_module_mapping", + "observation_space", + "action_space", + "env_t_started", + "env_t", + "agent_t_started", + "env_t_to_agent_t", + "_hanging_actions_end", + "_hanging_extra_model_outputs_end", + "_hanging_rewards_end", + "_hanging_rewards_begin", + "is_terminated", + "is_truncated", + "agent_episodes", + "_last_step_time", + "_len_lookback_buffers", + "_start_time", + "_temporary_timestep_data", + ) + + SKIP_ENV_TS_TAG = "S" + + def __init__( + self, + id_: Optional[str] = None, + *, + observations: Optional[List[MultiAgentDict]] = None, + observation_space: Optional[gym.Space] = None, + infos: Optional[List[MultiAgentDict]] = None, + actions: Optional[List[MultiAgentDict]] = None, + action_space: Optional[gym.Space] = None, + rewards: Optional[List[MultiAgentDict]] = None, + terminateds: Union[MultiAgentDict, bool] = False, + truncateds: Union[MultiAgentDict, bool] = False, + extra_model_outputs: Optional[List[MultiAgentDict]] = None, + env_t_started: Optional[int] = None, + agent_t_started: Optional[Dict[AgentID, int]] = None, + len_lookback_buffer: Union[int, str] = "auto", + agent_episode_ids: Optional[Dict[AgentID, str]] = None, + agent_module_ids: Optional[Dict[AgentID, ModuleID]] = None, + agent_to_module_mapping_fn: Optional[ + Callable[[AgentID, "MultiAgentEpisode"], ModuleID] + ] = None, + ): + """Initializes a `MultiAgentEpisode`. + + Args: + id_: Optional. Either a string to identify an episode or None. + If None, a hexadecimal id is created. In case of providing + a string, make sure that it is unique, as episodes get + concatenated via this string. + observations: A list of dictionaries mapping agent IDs to observations. + Can be None. If provided, should match all other episode data + (actions, rewards, etc.) in terms of list lengths and agent IDs. 
+ observation_space: An optional gym.spaces.Dict mapping agent IDs to + individual agents' spaces, which all (individual agents') observations + should abide to. If not None and this MultiAgentEpisode is numpy'ized + (via the `self.to_numpy()` method), and data is appended or set, the new + data will be checked for correctness. + infos: A list of dictionaries mapping agent IDs to info dicts. + Can be None. If provided, should match all other episode data + (observations, rewards, etc.) in terms of list lengths and agent IDs. + actions: A list of dictionaries mapping agent IDs to actions. + Can be None. If provided, should match all other episode data + (observations, rewards, etc.) in terms of list lengths and agent IDs. + action_space: An optional gym.spaces.Dict mapping agent IDs to + individual agents' spaces, which all (individual agents') actions + should abide to. If not None and this MultiAgentEpisode is numpy'ized + (via the `self.to_numpy()` method), and data is appended or set, the new + data will be checked for correctness. + rewards: A list of dictionaries mapping agent IDs to rewards. + Can be None. If provided, should match all other episode data + (actions, rewards, etc.) in terms of list lengths and agent IDs. + terminateds: A boolean defining if an environment has + terminated OR a MultiAgentDict mapping individual agent ids + to boolean flags indicating whether individual agents have terminated. + A special __all__ key in these dicts indicates, whether the episode + is terminated for all agents. + The default is `False`, i.e. the episode has not been terminated. + truncateds: A boolean defining if the environment has been + truncated OR a MultiAgentDict mapping individual agent ids + to boolean flags indicating whether individual agents have been + truncated. A special __all__ key in these dicts indicates, whether the + episode is truncated for all agents. + The default is `False`, i.e. the episode has not been truncated. 
+ extra_model_outputs: A list of dictionaries mapping agent IDs to their + corresponding extra model outputs. Each of these "outputs" is a dict + mapping keys (str) to model output values, for example for + `key=STATE_OUT`, the values would be the internal state outputs for + that agent. + env_t_started: The env timestep (int) that defines the starting point + of the episode. This is only larger than zero if an already ongoing episode + chunk is being created, for example by slicing an ongoing episode or + by calling the `cut()` method on an ongoing episode. + agent_t_started: A dict mapping AgentIDs to the respective agent's (local) + timestep at which its SingleAgentEpisode chunk started. + len_lookback_buffer: The size of the lookback buffers to keep in + front of this Episode for each type of data (observations, actions, + etc.). If larger than 0, will interpret the first `len_lookback_buffer` + items in each type of data as NOT part of this actual + episode chunk, but instead serve as "historical" record that may be + viewed and used to derive new data from. For example, it might be + necessary to have a lookback buffer of four if you would like to do + observation frame stacking and your episode has been cut and you are now + operating on a new chunk (continuing from the cut one). Then, for the + first 3 items, you would have to be able to look back into the old + chunk's data. + If `len_lookback_buffer` is "auto" (default), will interpret all + provided data in the constructor as part of the lookback buffers. + agent_episode_ids: An optional dict mapping AgentIDs + to their corresponding `SingleAgentEpisode`. If None, each + `SingleAgentEpisode` in `MultiAgentEpisode.agent_episodes` + will generate a hexadecimal code. If a dictionary is provided, + make sure that IDs are unique, because the agents' `SingleAgentEpisode` + instances are concatenated or recreated by it. 
+ agent_module_ids: An optional dict mapping AgentIDs to their respective + ModuleIDs (these mapping are always valid for an entire episode and + thus won't change during the course of this episode). If a mapping from + agent to module has already been provided via this dict, the (optional) + `agent_to_module_mapping_fn` will NOT be used again to map the same + agent (agents do not change their assigned module in the course of + one episode). + agent_to_module_mapping_fn: A callable taking an AgentID and a + MultiAgentEpisode as args and returning a ModuleID. Used to map agents + that have not been mapped yet (because they just entered this episode) + to a ModuleID. The resulting ModuleID is only stored inside the agent's + SingleAgentEpisode object. + """ + self.id_: str = id_ or uuid.uuid4().hex + if agent_to_module_mapping_fn is None: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + + agent_to_module_mapping_fn = ( + AlgorithmConfig.DEFAULT_AGENT_TO_MODULE_MAPPING_FN + ) + self.agent_to_module_mapping_fn = agent_to_module_mapping_fn + # In case a user - e.g. via callbacks - already forces a mapping to happen + # via the `module_for()` API even before the agent has entered the episode + # (and has its SingleAgentEpisode created), we store all aldeary done mappings + # in this dict here. + self._agent_to_module_mapping: Dict[AgentID, ModuleID] = agent_module_ids or {} + + # Lookback buffer length is not provided. Interpret all provided data as + # lookback buffer. + if len_lookback_buffer == "auto": + len_lookback_buffer = len(rewards or []) + self._len_lookback_buffers = len_lookback_buffer + + self.observation_space = observation_space or {} + self.action_space = action_space or {} + + terminateds = terminateds or {} + truncateds = truncateds or {} + + # The global last timestep of the episode and the timesteps when this chunk + # started (excluding a possible lookback buffer). 
+ self.env_t_started = env_t_started or 0 + self.env_t = ( + (len(rewards) if rewards is not None else 0) + - self._len_lookback_buffers + + self.env_t_started + ) + self.agent_t_started = defaultdict(int, agent_t_started or {}) + + # Keeps track of the correspondence between agent steps and environment steps. + # Under each AgentID as key is a InfiniteLookbackBuffer with the following + # data in it: + # The indices of the items in the data represent environment timesteps, + # starting from index=0 for the `env.reset()` and with each `env.step()` call + # increase by 1. + # The values behind these (env timestep) indices represent the agent timesteps + # happening at these env timesteps and the special value of + # `self.SKIP_ENV_TS_TAG` means that the agent did NOT step at the given env + # timestep. + # Thus, agents that are part of the reset obs, will start their mapping data + # with a [0 ...], all other agents will start their mapping data with: + # [self.SKIP_ENV_TS_TAG, ...]. + self.env_t_to_agent_t: DefaultDict[ + AgentID, InfiniteLookbackBuffer + ] = defaultdict(InfiniteLookbackBuffer) + + # Create caches for hanging actions/rewards/extra_model_outputs. + # When an agent gets an observation (and then sends an action), but does not + # receive immediately a next observation, we store the "hanging" action (and + # related rewards and extra model outputs) in the caches postfixed w/ `_end` + # until the next observation is received. + self._hanging_actions_end = {} + self._hanging_extra_model_outputs_end = defaultdict(dict) + self._hanging_rewards_end = defaultdict(float) + + # In case of a `cut()` or `slice()`, we also need to store the hanging actions, + # rewards, and extra model outputs that were already "hanging" in preceeding + # episode slice. 
+ self._hanging_rewards_begin = defaultdict(float) + + # If this is an ongoing episode than the last `__all__` should be `False` + self.is_terminated: bool = ( + terminateds + if isinstance(terminateds, bool) + else terminateds.get("__all__", False) + ) + + # If this is an ongoing episode than the last `__all__` should be `False` + self.is_truncated: bool = ( + truncateds + if isinstance(truncateds, bool) + else truncateds.get("__all__", False) + ) + + # The individual agent SingleAgentEpisode objects. + self.agent_episodes: Dict[AgentID, SingleAgentEpisode] = {} + self._init_single_agent_episodes( + agent_module_ids=agent_module_ids, + agent_episode_ids=agent_episode_ids, + observations=observations, + infos=infos, + actions=actions, + rewards=rewards, + terminateds=terminateds, + truncateds=truncateds, + extra_model_outputs=extra_model_outputs, + ) + + # Caches for temporary per-timestep data. May be used to store custom metrics + # from within a callback for the ongoing episode (e.g. render images). + self._temporary_timestep_data = defaultdict(list) + + # Keep timer stats on deltas between steps. + self._start_time = None + self._last_step_time = None + + # Validate ourselves. + self.validate() + + def add_env_reset( + self, + *, + observations: MultiAgentDict, + infos: Optional[MultiAgentDict] = None, + ) -> None: + """Stores initial observation. + + Args: + observations: A dictionary mapping agent IDs to initial observations. + Note that some agents may not have an initial observation. + infos: A dictionary mapping agent IDs to initial info dicts. + Note that some agents may not have an initial info dict. If not None, + the agent IDs in `infos` must be a subset of those in `observations` + meaning it would not be allowed to have an agent with an info dict, + but not with an observation. + """ + assert not self.is_done + # Assume that this episode is completely empty and has not stepped yet. + # Leave self.env_t (and self.env_t_started) at 0. 
+ assert self.env_t == self.env_t_started == 0 + infos = infos or {} + + # Note, all agents will have an initial observation, some may have an initial + # info dict as well. + for agent_id, agent_obs in observations.items(): + # Update env_t_to_agent_t mapping (all agents that are part of the reset + # obs have their first mapping 0 (env_t) -> 0 (agent_t)). + self.env_t_to_agent_t[agent_id].append(0) + # Create SingleAgentEpisode, if necessary. + if agent_id not in self.agent_episodes: + self.agent_episodes[agent_id] = SingleAgentEpisode( + agent_id=agent_id, + module_id=self.module_for(agent_id), + multi_agent_episode_id=self.id_, + observation_space=self.observation_space.get(agent_id), + action_space=self.action_space.get(agent_id), + ) + # Add initial observations (and infos) to the agent's episode. + self.agent_episodes[agent_id].add_env_reset( + observation=agent_obs, + infos=infos.get(agent_id), + ) + + # Validate our data. + self.validate() + + # Start the timer for this episode. + self._start_time = time.perf_counter() + + def add_env_step( + self, + observations: MultiAgentDict, + actions: MultiAgentDict, + rewards: MultiAgentDict, + infos: Optional[MultiAgentDict] = None, + *, + terminateds: Optional[MultiAgentDict] = None, + truncateds: Optional[MultiAgentDict] = None, + extra_model_outputs: Optional[MultiAgentDict] = None, + ) -> None: + """Adds a timestep to the episode. + + Args: + observations: A dictionary mapping agent IDs to their corresponding + next observations. Note that some agents may not have stepped at this + timestep. + actions: Mandatory. A dictionary mapping agent IDs to their + corresponding actions. Note that some agents may not have stepped at + this timestep. + rewards: Mandatory. A dictionary mapping agent IDs to their + corresponding observations. Note that some agents may not have stepped + at this timestep. + infos: A dictionary mapping agent IDs to their + corresponding info. 
Note that some agents may not have stepped at this + timestep. + terminateds: A dictionary mapping agent IDs to their `terminated` flags, + indicating, whether the environment has been terminated for them. + A special `__all__` key indicates that the episode is terminated for + all agent IDs. + terminateds: A dictionary mapping agent IDs to their `truncated` flags, + indicating, whether the environment has been truncated for them. + A special `__all__` key indicates that the episode is `truncated` for + all agent IDs. + extra_model_outputs: A dictionary mapping agent IDs to their + corresponding specific model outputs (also in a dictionary; e.g. + `vf_preds` for PPO). + """ + # Cannot add data to an already done episode. + if self.is_done: + raise MultiAgentEnvError( + "Cannot call `add_env_step` on a MultiAgentEpisode that is already " + "done!" + ) + + infos = infos or {} + terminateds = terminateds or {} + truncateds = truncateds or {} + extra_model_outputs = extra_model_outputs or {} + + # Increase (global) env step by one. + self.env_t += 1 + + # Find out, whether this episode is terminated/truncated (for all agents). + # Case 1: all agents are terminated or all are truncated. + self.is_terminated = terminateds.get("__all__", False) + self.is_truncated = truncateds.get("__all__", False) + # Find all agents that were done at prior timesteps and add the agents that are + # done at the present timestep. + agents_done = set( + [aid for aid, sa_eps in self.agent_episodes.items() if sa_eps.is_done] + + [aid for aid in terminateds if terminateds[aid]] + + [aid for aid in truncateds if truncateds[aid]] + ) + # Case 2: Some agents are truncated and the others are terminated -> Declare + # this episode as terminated. + if all(aid in set(agents_done) for aid in self.agent_ids): + self.is_terminated = True + + # For all agents that are not stepping in this env step, but that are not done + # yet -> Add a skip tag to their env- to agent-step mappings. 
+ stepped_agent_ids = set(observations.keys()) + for agent_id, env_t_to_agent_t in self.env_t_to_agent_t.items(): + if agent_id not in stepped_agent_ids: + env_t_to_agent_t.append(self.SKIP_ENV_TS_TAG) + + # Loop through all agent IDs that we received data for in this step: + # Those found in observations, actions, and rewards. + agent_ids_with_data = ( + set(observations.keys()) + | set(actions.keys()) + | set(rewards.keys()) + | set(terminateds.keys()) + | set(truncateds.keys()) + | set( + self.agent_episodes.keys() + if terminateds.get("__all__") or truncateds.get("__all__") + else set() + ) + ) - {"__all__"} + for agent_id in agent_ids_with_data: + if agent_id not in self.agent_episodes: + sa_episode = SingleAgentEpisode( + agent_id=agent_id, + module_id=self.module_for(agent_id), + multi_agent_episode_id=self.id_, + observation_space=self.observation_space.get(agent_id), + action_space=self.action_space.get(agent_id), + ) + else: + sa_episode = self.agent_episodes[agent_id] + + # Collect value to be passed (at end of for-loop) into `add_env_step()` + # call. + _observation = observations.get(agent_id) + _action = actions.get(agent_id) + _reward = rewards.get(agent_id) + _infos = infos.get(agent_id) + _terminated = terminateds.get(agent_id, False) or self.is_terminated + _truncated = truncateds.get(agent_id, False) or self.is_truncated + _extra_model_outputs = extra_model_outputs.get(agent_id) + + # The value to place into the env- to agent-step map for this agent ID. + # _agent_step = self.SKIP_ENV_TS_TAG + + # Agents, whose SingleAgentEpisode had already been done before this + # step should NOT have received any data in this step. + if sa_episode.is_done and any( + v is not None + for v in [_observation, _action, _reward, _infos, _extra_model_outputs] + ): + raise MultiAgentEnvError( + f"Agent {agent_id} already had its `SingleAgentEpisode.is_done` " + f"set to True, but still received data in a following step! 
" + f"obs={_observation} act={_action} rew={_reward} info={_infos} " + f"extra_model_outputs={_extra_model_outputs}." + ) + _reward = _reward or 0.0 + + # CASE 1: A complete agent step is available (in one env step). + # ------------------------------------------------------------- + # We have an observation and an action for this agent -> + # Add the agent step to the single agent episode. + # ... action -> next obs + reward ... + if _observation is not None and _action is not None: + if agent_id not in rewards: + raise MultiAgentEnvError( + f"Agent {agent_id} acted (and received next obs), but did NOT " + f"receive any reward from the env!" + ) + + # CASE 2: Step gets completed with a hanging action OR first observation. + # ------------------------------------------------------------------------ + # We have an observation, but no action -> + # a) Action (and extra model outputs) must be hanging already. Also use + # collected hanging rewards and extra_model_outputs. + # b) The observation is the first observation for this agent ID. + elif _observation is not None and _action is None: + _action = self._hanging_actions_end.pop(agent_id, None) + + # We have a hanging action (the agent had acted after the previous + # observation, but the env had not responded - until now - with another + # observation). + # ...[hanging action] ... ... -> next obs + (reward)? ... + if _action is not None: + # Get the extra model output if available. + _extra_model_outputs = self._hanging_extra_model_outputs_end.pop( + agent_id, None + ) + _reward = self._hanging_rewards_end.pop(agent_id, 0.0) + _reward + # First observation for this agent, we have no hanging action. + # ... [done]? ... -> [1st obs for agent ID] + else: + # The agent is already done -> The agent thus has never stepped once + # and we do not have to create a SingleAgentEpisode for it. + if _terminated or _truncated: + self._del_hanging(agent_id) + continue + # This must be the agent's initial observation. 
+ else: + # Prepend n skip tags to this agent's mapping + the initial [0]. + assert agent_id not in self.env_t_to_agent_t + self.env_t_to_agent_t[agent_id].extend( + [self.SKIP_ENV_TS_TAG] * self.env_t + [0] + ) + self.env_t_to_agent_t[ + agent_id + ].lookback = self._len_lookback_buffers + # Make `add_env_reset` call and continue with next agent. + sa_episode.add_env_reset(observation=_observation, infos=_infos) + # Add possible reward to begin cache. + self._hanging_rewards_begin[agent_id] += _reward + # Now that the SAEps is valid, add it to our dict. + self.agent_episodes[agent_id] = sa_episode + continue + + # CASE 3: Step is started (by an action), but not completed (no next obs). + # ------------------------------------------------------------------------ + # We have no observation, but we have a hanging action (used when we receive + # the next obs for this agent in the future). + elif agent_id not in observations and agent_id in actions: + # Agent got truncated -> Error b/c we would need a last (truncation) + # observation for this (otherwise, e.g. bootstrapping would not work). + # [previous obs] [action] (hanging) ... ... [truncated] + if _truncated: + raise MultiAgentEnvError( + f"Agent {agent_id} acted and then got truncated, but did NOT " + "receive a last (truncation) observation, required for e.g. " + "value function bootstrapping!" + ) + # Agent got terminated. + # [previous obs] [action] (hanging) ... ... [terminated] + elif _terminated: + # If the agent was terminated and no observation is provided, + # duplicate the previous one (this is a technical "fix" to properly + # complete the single agent episode; this last observation is never + # used for learning anyway). + _observation = sa_episode._last_added_observation + _infos = sa_episode._last_added_infos + # Agent is still alive. + # [previous obs] [action] (hanging) ... + else: + # Hanging action, reward, and extra_model_outputs. 
+ assert agent_id not in self._hanging_actions_end + self._hanging_actions_end[agent_id] = _action + self._hanging_rewards_end[agent_id] = _reward + self._hanging_extra_model_outputs_end[ + agent_id + ] = _extra_model_outputs + + # CASE 4: Step has started in the past and is still ongoing (no observation, + # no action). + # -------------------------------------------------------------------------- + # Record reward and terminated/truncated flags. + else: + _action = self._hanging_actions_end.get(agent_id) + + # Agent is done. + if _terminated or _truncated: + # If the agent has NOT stepped, we treat it as not being + # part of this episode. + # ... ... [other agents doing stuff] ... ... [agent done] + if _action is None: + self._del_hanging(agent_id) + continue + + # Agent got truncated -> Error b/c we would need a last (truncation) + # observation for this (otherwise, e.g. bootstrapping would not + # work). + if _truncated: + raise MultiAgentEnvError( + f"Agent {agent_id} acted and then got truncated, but did " + "NOT receive a last (truncation) observation, required " + "for e.g. value function bootstrapping!" + ) + + # [obs] ... ... [hanging action] ... ... [done] + # If the agent was terminated and no observation is provided, + # duplicate the previous one (this is a technical "fix" to properly + # complete the single agent episode; this last observation is never + # used for learning anyway). + _observation = sa_episode._last_added_observation + _infos = sa_episode._last_added_infos + # `_action` is already `get` above. We don't need to pop out from + # the cache as it gets wiped out anyway below b/c the agent is + # done. + _extra_model_outputs = self._hanging_extra_model_outputs_end.pop( + agent_id, None + ) + _reward = self._hanging_rewards_end.pop(agent_id, 0.0) + _reward + # The agent is still alive, just add current reward to cache. + else: + # But has never stepped in this episode -> add to begin cache. 
+ if agent_id not in self.agent_episodes: + self._hanging_rewards_begin[agent_id] += _reward + # Otherwise, add to end cache. + else: + self._hanging_rewards_end[agent_id] += _reward + + # If agent is stepping, add timestep to `SingleAgentEpisode`. + if _observation is not None: + sa_episode.add_env_step( + observation=_observation, + action=_action, + reward=_reward, + infos=_infos, + terminated=_terminated, + truncated=_truncated, + extra_model_outputs=_extra_model_outputs, + ) + # Update the env- to agent-step mapping. + self.env_t_to_agent_t[agent_id].append( + len(sa_episode) + sa_episode.observations.lookback + ) + + # Agent is also done. -> Erase all hanging values for this agent + # (they should be empty at this point anyways). + if _terminated or _truncated: + self._del_hanging(agent_id) + + # Validate our data. + self.validate() + + # Step time stats. + self._last_step_time = time.perf_counter() + if self._start_time is None: + self._start_time = self._last_step_time + + def validate(self) -> None: + """Validates the episode's data. + + This function ensures that the data stored to a `MultiAgentEpisode` is + in order (e.g. that the correct number of observations, actions, rewards + are there). + """ + for eps in self.agent_episodes.values(): + eps.validate() + + # TODO (sven): Validate MultiAgentEpisode specifics, like the timestep mappings, + # action/reward caches, etc.. + + @property + def is_reset(self) -> bool: + """Returns True if `self.add_env_reset()` has already been called.""" + return any( + len(sa_episode.observations) > 0 + for sa_episode in self.agent_episodes.values() + ) + + @property + def is_numpy(self) -> bool: + """True, if the data in this episode is already stored as numpy arrays.""" + is_numpy = next(iter(self.agent_episodes.values())).is_numpy + # Make sure that all single agent's episodes' `is_numpy` flags are the same. 
+ if not all(eps.is_numpy is is_numpy for eps in self.agent_episodes.values()): + raise RuntimeError( + f"Only some SingleAgentEpisode objects in {self} are converted to " + f"numpy, others are not!" + ) + return is_numpy + + @property + def is_done(self): + """Whether the episode is actually done (terminated or truncated). + + A done episode cannot be continued via `self.add_env_step()` or being + concatenated on its right-side with another episode chunk or being + succeeded via `self.cut()`. + + Note that in a multi-agent environment this does not necessarily + correspond to single agents having terminated or being truncated. + + `self.is_terminated` should be `True`, if all agents are terminated and + `self.is_truncated` should be `True`, if all agents are truncated. If + only one or more (but not all!) agents are `terminated/truncated the + `MultiAgentEpisode.is_terminated/is_truncated` should be `False`. This + information about single agent's terminated/truncated states can always + be retrieved from the `SingleAgentEpisode`s inside the 'MultiAgentEpisode` + one. + + If all agents are either terminated or truncated, but in a mixed fashion, + i.e. some are terminated and others are truncated: This is currently + undefined and could potentially be a problem (if a user really implemented + such a multi-agent env that behaves this way). + + Returns: + Boolean defining if an episode has either terminated or truncated. + """ + return self.is_terminated or self.is_truncated + + def to_numpy(self) -> "MultiAgentEpisode": + """Converts this Episode's list attributes to numpy arrays. + + This means in particular that this episodes' lists (per single agent) of + (possibly complex) data (e.g. an agent having a dict obs space) will be + converted to (possibly complex) structs, whose leafs are now numpy arrays. + Each of these leaf numpy arrays will have the same length (batch dimension) + as the length of the original lists. 
+ + Note that Columns.INFOS are NEVER numpy'ized and will remain a list + (normally, a list of the original, env-returned dicts). This is due to the + heterogeneous nature of INFOS returned by envs, which would make it unwieldy to + convert this information to numpy arrays. + + After calling this method, no further data may be added to this episode via + the `self.add_env_step()` method. + + Examples: + + .. testcode:: + + import numpy as np + + from ray.rllib.env.multi_agent_episode import MultiAgentEpisode + from ray.rllib.env.tests.test_multi_agent_episode import ( + TestMultiAgentEpisode + ) + + # Create some multi-agent episode data. + ( + observations, + actions, + rewards, + terminateds, + truncateds, + infos, + ) = TestMultiAgentEpisode._mock_multi_agent_records() + # Define the agent ids. + agent_ids = ["agent_1", "agent_2", "agent_3", "agent_4", "agent_5"] + + episode = MultiAgentEpisode( + observations=observations, + infos=infos, + actions=actions, + rewards=rewards, + # Note: terminated/truncated have nothing to do with an episode + # being converted `to_numpy` or not (via the `self.to_numpy()` method)! + terminateds=terminateds, + truncateds=truncateds, + len_lookback_buffer=0, # no lookback; all data is actually "in" episode + ) + + # Episode has not been numpy'ized yet. + assert not episode.is_numpy + # We are still operating on lists. + assert ( + episode.get_observations( + indices=[1], + agent_ids="agent_1", + ) == {"agent_1": [1]} + ) + + # Numpy'ized the episode. + episode.to_numpy() + assert episode.is_numpy + + # Everything is now numpy arrays (with 0-axis of size + # B=[len of requested slice]). + assert ( + isinstance(episode.get_observations( + indices=[1], + agent_ids="agent_1", + )["agent_1"], np.ndarray) + ) + + Returns: + This `MultiAgentEpisode` object with the converted numpy data. 
+ """ + + for agent_id, agent_eps in self.agent_episodes.copy().items(): + agent_eps.to_numpy() + + return self + + def concat_episode(self, other: "MultiAgentEpisode") -> None: + """Adds the given `other` MultiAgentEpisode to the right side of self. + + In order for this to work, both chunks (`self` and `other`) must fit + together. This is checked by the IDs (must be identical), the time step counters + (`self.env_t` must be the same as `episode_chunk.env_t_started`), as well as the + observations/infos of the individual agents at the concatenation boundaries. + Also, `self.is_done` must not be True, meaning `self.is_terminated` and + `self.is_truncated` are both False. + + Args: + other: The other `MultiAgentEpisode` to be concatenated to this one. + + Returns: A `MultiAgentEpisode` instance containing the concatenated data + from both episodes (`self` and `other`). + """ + # Make sure the IDs match. + assert other.id_ == self.id_ + # NOTE (sven): This is what we agreed on. As the replay buffers must be + # able to concatenate. + assert not self.is_done + # Make sure the timesteps match. + assert self.env_t == other.env_t_started + # Validate `other`. + other.validate() + + # Concatenate the individual SingleAgentEpisodes from both chunks. + all_agent_ids = set(self.agent_ids) | set(other.agent_ids) + for agent_id in all_agent_ids: + sa_episode = self.agent_episodes.get(agent_id) + + # If agent is only in the new episode chunk -> Store all the data of `other` + # wrt agent in `self`. + if sa_episode is None: + self.agent_episodes[agent_id] = other.agent_episodes[agent_id] + self.env_t_to_agent_t[agent_id] = other.env_t_to_agent_t[agent_id] + self.agent_t_started[agent_id] = other.agent_t_started[agent_id] + self._copy_hanging(agent_id, other) + + # If the agent was done in `self`, ignore and continue. There should not be + # any data of that agent in `other`. 
+ elif sa_episode.is_done: + continue + + # If the agent has data in both chunks, concatenate on the single-agent + # level, thereby making sure the hanging values (begin and end) match. + elif agent_id in other.agent_episodes: + # If `other` has hanging (end) values -> Add these to `self`'s agent + # SingleAgentEpisode (as a new timestep) and only then concatenate. + # Otherwise, the concatentaion would fail b/c of missing data. + if agent_id in self._hanging_actions_end: + assert agent_id in self._hanging_extra_model_outputs_end + sa_episode.add_env_step( + observation=other.agent_episodes[agent_id].get_observations(0), + infos=other.agent_episodes[agent_id].get_infos(0), + action=self._hanging_actions_end[agent_id], + reward=( + self._hanging_rewards_end[agent_id] + + other._hanging_rewards_begin[agent_id] + ), + extra_model_outputs=( + self._hanging_extra_model_outputs_end[agent_id] + ), + ) + sa_episode.concat_episode(other.agent_episodes[agent_id]) + # Override `self`'s hanging (end) values with `other`'s hanging (end). + if agent_id in other._hanging_actions_end: + self._hanging_actions_end[agent_id] = copy.deepcopy( + other._hanging_actions_end[agent_id] + ) + self._hanging_rewards_end[agent_id] = other._hanging_rewards_end[ + agent_id + ] + self._hanging_extra_model_outputs_end[agent_id] = copy.deepcopy( + other._hanging_extra_model_outputs_end[agent_id] + ) + + # Concatenate the env- to agent-timestep mappings. + j = self.env_t + for i, val in enumerate(other.env_t_to_agent_t[agent_id][1:]): + if val == self.SKIP_ENV_TS_TAG: + self.env_t_to_agent_t[agent_id].append(self.SKIP_ENV_TS_TAG) + else: + self.env_t_to_agent_t[agent_id].append(i + 1 + j) + + # Otherwise, the agent is only in `self` and not done. All data is stored + # already -> skip + # else: pass + + # Update all timestep counters. + self.env_t = other.env_t + # Check, if the episode is terminated or truncated. 
    def cut(self, len_lookback_buffer: int = 0) -> "MultiAgentEpisode":
        """Returns a successor episode chunk (of len=0) continuing from this Episode.

        The successor will have the same ID as `self`.
        If no lookback buffer is requested (len_lookback_buffer=0), the successor's
        observations will be the last observation(s) of `self` and its length will
        therefore be 0 (no further steps taken yet). If `len_lookback_buffer` > 0,
        the returned successor will have `len_lookback_buffer` observations (and
        actions, rewards, etc..) taken from the right side (end) of `self`. For example
        if `len_lookback_buffer=2`, the returned successor's lookback buffer actions
        will be identical to the results of `self.get_actions([-2, -1])`.

        This method is useful if you would like to discontinue building an episode
        chunk (b/c you have to return it from somewhere), but would like to have a new
        episode instance to continue building the actual gym.Env episode at a later
        time. Via the `len_lookback_buffer` argument, the continuing chunk (successor)
        will still be able to "look back" into this predecessor episode's data (at
        least to some extent, depending on the value of `len_lookback_buffer`).

        Args:
            len_lookback_buffer: The number of environment timesteps to take along into
                the new chunk as "lookback buffer". A lookback buffer is additional data
                on the left side of the actual episode data for visibility purposes
                (but without actually being part of the new chunk). For example, if
                `self` ends in actions: agent_1=5,6,7 and agent_2=6,7, and we call
                `self.cut(len_lookback_buffer=2)`, the returned chunk will have
                actions 6 and 7 for both agents already in it, but still
                `t_started`==t==8 (not 7!) and a length of 0.
                If there is not enough data in `self` yet to fulfil the
                `len_lookback_buffer` request, the value of `len_lookback_buffer` is
                automatically adjusted (lowered).

        Returns:
            The successor Episode chunk of this one with the same ID and state and the
            only observation being the last observation in self.
        """
        assert len_lookback_buffer >= 0
        if self.is_done:
            raise RuntimeError(
                "Can't call `MultiAgentEpisode.cut()` when the episode is already done!"
            )

        # If there is hanging data (e.g. actions) in the agents' caches, we might have
        # to re-adjust the lookback len further into the past to make sure that these
        # agents have at least one observation to look back to. Otherwise, the timestep
        # that got cut into will be "lost" for learning from it.
        orig_len_lb = len_lookback_buffer
        # Only the keys (agent IDs) are needed; the hanging actions themselves are
        # not inspected here.
        for agent_id, agent_actions in self._hanging_actions_end.items():
            # A hanging action implies the agent did NOT step at the most recent
            # env timestep.
            assert self.env_t_to_agent_t[agent_id].get(-1) == self.SKIP_ENV_TS_TAG
            # Walk further back until we find a non-skipped env timestep for this
            # agent; widen the lookback buffer up to (just before) that point.
            for i in range(orig_len_lb, len(self.env_t_to_agent_t[agent_id].data) + 1):
                if self.env_t_to_agent_t[agent_id].get(-i) != self.SKIP_ENV_TS_TAG:
                    len_lookback_buffer = max(len_lookback_buffer, i - 1)
                    break

        # Initialize this episode chunk with the most recent observations
        # and infos (even if lookback is zero). Similar to an initial `env.reset()`
        indices_obs_and_infos = slice(-len_lookback_buffer - 1, None)
        indices_rest = (
            slice(-len_lookback_buffer, None)
            if len_lookback_buffer > 0
            else slice(None, 0)  # -> empty slice
        )

        observations = self.get_observations(
            indices=indices_obs_and_infos, return_list=True
        )
        infos = self.get_infos(indices=indices_obs_and_infos, return_list=True)
        actions = self.get_actions(indices=indices_rest, return_list=True)
        rewards = self.get_rewards(indices=indices_rest, return_list=True)
        extra_model_outputs = self.get_extra_model_outputs(
            key=None,  # all keys
            indices=indices_rest,
            return_list=True,
        )
        successor = MultiAgentEpisode(
            # Same ID.
            id_=self.id_,
            observations=observations,
            observation_space=self.observation_space,
            infos=infos,
            actions=actions,
            action_space=self.action_space,
            rewards=rewards,
            # List of MADicts, mapping agent IDs to their respective extra model output
            # dicts.
            extra_model_outputs=extra_model_outputs,
            terminateds=self.get_terminateds(),
            truncateds=self.get_truncateds(),
            # Continue with `self`'s current timesteps.
            env_t_started=self.env_t,
            agent_t_started={
                aid: self.agent_episodes[aid].t
                for aid in self.agent_ids
                if not self.agent_episodes[aid].is_done
            },
            # Same AgentIDs and SingleAgentEpisode IDs.
            agent_episode_ids=self.agent_episode_ids,
            agent_module_ids={
                aid: self.agent_episodes[aid].module_id for aid in self.agent_ids
            },
            agent_to_module_mapping_fn=self.agent_to_module_mapping_fn,
            # All data we provided to the c'tor goes into the lookback buffer.
            len_lookback_buffer="auto",
        )

        # Copy over the hanging (end) values into the hanging (begin) caches of the
        # successor.
        # NOTE(review): only the hanging *rewards* are carried over here; hanging
        # actions/extra-model-outputs appear to stay with `self` — confirm this is
        # intended.
        successor._hanging_rewards_begin = self._hanging_rewards_end.copy()

        return successor
+ """ + if agent_id not in self._agent_to_module_mapping: + module_id = self._agent_to_module_mapping[ + agent_id + ] = self.agent_to_module_mapping_fn(agent_id, self) + return module_id + else: + return self._agent_to_module_mapping[agent_id] + + def get_observations( + self, + indices: Optional[Union[int, List[int], slice]] = None, + agent_ids: Optional[Union[Collection[AgentID], AgentID]] = None, + *, + env_steps: bool = True, + # global_indices: bool = False, + neg_index_as_lookback: bool = False, + fill: Optional[Any] = None, + one_hot_discrete: bool = False, + return_list: bool = False, + ) -> Union[MultiAgentDict, List[MultiAgentDict]]: + """Returns agents' observations or batched ranges thereof from this episode. + + Args: + indices: A single int is interpreted as an index, from which to return the + individual observation stored at this index. + A list of ints is interpreted as a list of indices from which to gather + individual observations in a batch of size len(indices). + A slice object is interpreted as a range of observations to be returned. + Thereby, negative indices by default are interpreted as "before the end" + unless the `neg_index_as_lookback=True` option is used, in which case + negative indices are interpreted as "before ts=0", meaning going back + into the lookback buffer. + If None, will return all observations (from ts=0 to the end). + agent_ids: An optional collection of AgentIDs or a single AgentID to get + observations for. If None, will return observations for all agents in + this episode. + env_steps: Whether `indices` should be interpreted as environment time steps + (True) or per-agent timesteps (False). + neg_index_as_lookback: If True, negative values in `indices` are + interpreted as "before ts=0", meaning going back into the lookback + buffer. 
For example, an episode with agent A's observations + [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range + (ts=0 item is 7), will respond to `get_observations(-1, agent_ids=[A], + neg_index_as_lookback=True)` with {A: `6`} and to + `get_observations(slice(-2, 1), agent_ids=[A], + neg_index_as_lookback=True)` with {A: `[5, 6, 7]`}. + fill: An optional value to use for filling up the returned results at + the boundaries. This filling only happens if the requested index range's + start/stop boundaries exceed the episode's boundaries (including the + lookback buffer on the left side). This comes in very handy, if users + don't want to worry about reaching such boundaries and want to zero-pad. + For example, an episode with agent A' observations [10, 11, 12, 13, 14] + and lookback buffer size of 2 (meaning observations `10` and `11` are + part of the lookback buffer) will respond to + `get_observations(slice(-7, -2), agent_ids=[A], fill=0.0)` with + `{A: [0.0, 0.0, 10, 11, 12]}`. + one_hot_discrete: If True, will return one-hot vectors (instead of + int-values) for those sub-components of a (possibly complex) observation + space that are Discrete or MultiDiscrete. Note that if `fill=0` and the + requested `indices` are out of the range of our data, the returned + one-hot vectors will actually be zero-hot (all slots zero). + return_list: Whether to return a list of multi-agent dicts (instead of + a single multi-agent dict of lists/structs). False by default. This + option can only be used when `env_steps` is True due to the fact the + such a list can only be interpreted as one env step per list item + (would not work with agent steps). + + Returns: + A dictionary mapping agent IDs to observations (at the given + `indices`). If `env_steps` is True, only agents that have stepped + (were ready) at the given env step `indices` are returned (i.e. not all + agent IDs are necessarily in the keys). 
+ If `return_list` is True, returns a list of MultiAgentDicts (mapping agent + IDs to observations) instead. + """ + return self._get( + what="observations", + indices=indices, + agent_ids=agent_ids, + env_steps=env_steps, + neg_index_as_lookback=neg_index_as_lookback, + fill=fill, + one_hot_discrete=one_hot_discrete, + return_list=return_list, + ) + + def get_infos( + self, + indices: Optional[Union[int, List[int], slice]] = None, + agent_ids: Optional[Union[Collection[AgentID], AgentID]] = None, + *, + env_steps: bool = True, + neg_index_as_lookback: bool = False, + fill: Optional[Any] = None, + return_list: bool = False, + ) -> Union[MultiAgentDict, List[MultiAgentDict]]: + """Returns agents' info dicts or list (ranges) thereof from this episode. + + Args: + indices: A single int is interpreted as an index, from which to return the + individual info dict stored at this index. + A list of ints is interpreted as a list of indices from which to gather + individual info dicts in a list of size len(indices). + A slice object is interpreted as a range of info dicts to be returned. + Thereby, negative indices by default are interpreted as "before the end" + unless the `neg_index_as_lookback=True` option is used, in which case + negative indices are interpreted as "before ts=0", meaning going back + into the lookback buffer. + If None, will return all infos (from ts=0 to the end). + agent_ids: An optional collection of AgentIDs or a single AgentID to get + info dicts for. If None, will return info dicts for all agents in + this episode. + env_steps: Whether `indices` should be interpreted as environment time steps + (True) or per-agent timesteps (False). + neg_index_as_lookback: If True, negative values in `indices` are + interpreted as "before ts=0", meaning going back into the lookback + buffer. 
For example, an episode with agent A's info dicts + [{"l":4}, {"l":5}, {"l":6}, {"a":7}, {"b":8}, {"c":9}], where the + first 3 items are the lookback buffer (ts=0 item is {"a": 7}), will + respond to `get_infos(-1, agent_ids=A, neg_index_as_lookback=True)` + with `{A: {"l":6}}` and to + `get_infos(slice(-2, 1), agent_ids=A, neg_index_as_lookback=True)` + with `{A: [{"l":5}, {"l":6}, {"a":7}]}`. + fill: An optional value to use for filling up the returned results at + the boundaries. This filling only happens if the requested index range's + start/stop boundaries exceed the episode's boundaries (including the + lookback buffer on the left side). This comes in very handy, if users + don't want to worry about reaching such boundaries and want to + auto-fill. For example, an episode with agent A's infos being + [{"l":10}, {"l":11}, {"a":12}, {"b":13}, {"c":14}] and lookback buffer + size of 2 (meaning infos {"l":10}, {"l":11} are part of the lookback + buffer) will respond to `get_infos(slice(-7, -2), agent_ids=A, + fill={"o": 0.0})` with + `{A: [{"o":0.0}, {"o":0.0}, {"l":10}, {"l":11}, {"a":12}]}`. + return_list: Whether to return a list of multi-agent dicts (instead of + a single multi-agent dict of lists/structs). False by default. This + option can only be used when `env_steps` is True due to the fact the + such a list can only be interpreted as one env step per list item + (would not work with agent steps). + + Returns: + A dictionary mapping agent IDs to observations (at the given + `indices`). If `env_steps` is True, only agents that have stepped + (were ready) at the given env step `indices` are returned (i.e. not all + agent IDs are necessarily in the keys). + If `return_list` is True, returns a list of MultiAgentDicts (mapping agent + IDs to infos) instead. 
+ """ + return self._get( + what="infos", + indices=indices, + agent_ids=agent_ids, + env_steps=env_steps, + neg_index_as_lookback=neg_index_as_lookback, + fill=fill, + return_list=return_list, + ) + + def get_actions( + self, + indices: Optional[Union[int, List[int], slice]] = None, + agent_ids: Optional[Union[Collection[AgentID], AgentID]] = None, + *, + env_steps: bool = True, + neg_index_as_lookback: bool = False, + fill: Optional[Any] = None, + one_hot_discrete: bool = False, + return_list: bool = False, + ) -> Union[MultiAgentDict, List[MultiAgentDict]]: + """Returns agents' actions or batched ranges thereof from this episode. + + Args: + indices: A single int is interpreted as an index, from which to return the + individual actions stored at this index. + A list of ints is interpreted as a list of indices from which to gather + individual actions in a batch of size len(indices). + A slice object is interpreted as a range of actions to be returned. + Thereby, negative indices by default are interpreted as "before the end" + unless the `neg_index_as_lookback=True` option is used, in which case + negative indices are interpreted as "before ts=0", meaning going back + into the lookback buffer. + If None, will return all actions (from ts=0 to the end). + agent_ids: An optional collection of AgentIDs or a single AgentID to get + actions for. If None, will return actions for all agents in + this episode. + env_steps: Whether `indices` should be interpreted as environment time steps + (True) or per-agent timesteps (False). + neg_index_as_lookback: If True, negative values in `indices` are + interpreted as "before ts=0", meaning going back into the lookback + buffer. 
For example, an episode with agent A's actions + [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range + (ts=0 item is 7), will respond to `get_actions(-1, agent_ids=[A], + neg_index_as_lookback=True)` with {A: `6`} and to + `get_actions(slice(-2, 1), agent_ids=[A], + neg_index_as_lookback=True)` with {A: `[5, 6, 7]`}. + fill: An optional value to use for filling up the returned results at + the boundaries. This filling only happens if the requested index range's + start/stop boundaries exceed the episode's boundaries (including the + lookback buffer on the left side). This comes in very handy, if users + don't want to worry about reaching such boundaries and want to zero-pad. + For example, an episode with agent A' actions [10, 11, 12, 13, 14] + and lookback buffer size of 2 (meaning actions `10` and `11` are + part of the lookback buffer) will respond to + `get_actions(slice(-7, -2), agent_ids=[A], fill=0.0)` with + `{A: [0.0, 0.0, 10, 11, 12]}`. + one_hot_discrete: If True, will return one-hot vectors (instead of + int-values) for those sub-components of a (possibly complex) observation + space that are Discrete or MultiDiscrete. Note that if `fill=0` and the + requested `indices` are out of the range of our data, the returned + one-hot vectors will actually be zero-hot (all slots zero). + return_list: Whether to return a list of multi-agent dicts (instead of + a single multi-agent dict of lists/structs). False by default. This + option can only be used when `env_steps` is True due to the fact the + such a list can only be interpreted as one env step per list item + (would not work with agent steps). + + Returns: + A dictionary mapping agent IDs to actions (at the given + `indices`). If `env_steps` is True, only agents that have stepped + (were ready) at the given env step `indices` are returned (i.e. not all + agent IDs are necessarily in the keys). 
+ If `return_list` is True, returns a list of MultiAgentDicts (mapping agent + IDs to actions) instead. + """ + return self._get( + what="actions", + indices=indices, + agent_ids=agent_ids, + env_steps=env_steps, + neg_index_as_lookback=neg_index_as_lookback, + fill=fill, + one_hot_discrete=one_hot_discrete, + return_list=return_list, + ) + + def get_rewards( + self, + indices: Optional[Union[int, List[int], slice]] = None, + agent_ids: Optional[Union[Collection[AgentID], AgentID]] = None, + *, + env_steps: bool = True, + neg_index_as_lookback: bool = False, + fill: Optional[float] = None, + return_list: bool = False, + ) -> Union[MultiAgentDict, List[MultiAgentDict]]: + """Returns agents' rewards or batched ranges thereof from this episode. + + Args: + indices: A single int is interpreted as an index, from which to return the + individual rewards stored at this index. + A list of ints is interpreted as a list of indices from which to gather + individual rewards in a batch of size len(indices). + A slice object is interpreted as a range of rewards to be returned. + Thereby, negative indices by default are interpreted as "before the end" + unless the `neg_index_as_lookback=True` option is used, in which case + negative indices are interpreted as "before ts=0", meaning going back + into the lookback buffer. + If None, will return all rewards (from ts=0 to the end). + agent_ids: An optional collection of AgentIDs or a single AgentID to get + rewards for. If None, will return rewards for all agents in + this episode. + env_steps: Whether `indices` should be interpreted as environment time steps + (True) or per-agent timesteps (False). + neg_index_as_lookback: If True, negative values in `indices` are + interpreted as "before ts=0", meaning going back into the lookback + buffer. 
For example, an episode with agent A's rewards + [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range + (ts=0 item is 7), will respond to `get_rewards(-1, agent_ids=[A], + neg_index_as_lookback=True)` with {A: `6`} and to + `get_rewards(slice(-2, 1), agent_ids=[A], + neg_index_as_lookback=True)` with {A: `[5, 6, 7]`}. + fill: An optional float value to use for filling up the returned results at + the boundaries. This filling only happens if the requested index range's + start/stop boundaries exceed the episode's boundaries (including the + lookback buffer on the left side). This comes in very handy, if users + don't want to worry about reaching such boundaries and want to zero-pad. + For example, an episode with agent A' rewards [10, 11, 12, 13, 14] + and lookback buffer size of 2 (meaning rewards `10` and `11` are + part of the lookback buffer) will respond to + `get_rewards(slice(-7, -2), agent_ids=[A], fill=0.0)` with + `{A: [0.0, 0.0, 10, 11, 12]}`. + return_list: Whether to return a list of multi-agent dicts (instead of + a single multi-agent dict of lists/structs). False by default. This + option can only be used when `env_steps` is True due to the fact the + such a list can only be interpreted as one env step per list item + (would not work with agent steps). + + Returns: + A dictionary mapping agent IDs to rewards (at the given + `indices`). If `env_steps` is True, only agents that have stepped + (were ready) at the given env step `indices` are returned (i.e. not all + agent IDs are necessarily in the keys). + If `return_list` is True, returns a list of MultiAgentDicts (mapping agent + IDs to rewards) instead. 
+ """ + return self._get( + what="rewards", + indices=indices, + agent_ids=agent_ids, + env_steps=env_steps, + neg_index_as_lookback=neg_index_as_lookback, + fill=fill, + return_list=return_list, + ) + + def get_extra_model_outputs( + self, + key: Optional[str] = None, + indices: Optional[Union[int, List[int], slice]] = None, + agent_ids: Optional[Union[Collection[AgentID], AgentID]] = None, + *, + env_steps: bool = True, + neg_index_as_lookback: bool = False, + fill: Optional[Any] = None, + return_list: bool = False, + ) -> Union[MultiAgentDict, List[MultiAgentDict]]: + """Returns agents' actions or batched ranges thereof from this episode. + + Args: + key: The `key` within each agents' extra_model_outputs dict to extract + data for. If None, return data of all extra model output keys. + indices: A single int is interpreted as an index, from which to return the + individual extra model outputs stored at this index. + A list of ints is interpreted as a list of indices from which to gather + individual extra model outputs in a batch of size len(indices). + A slice object is interpreted as a range of extra model outputs to be + returned. + Thereby, negative indices by default are interpreted as "before the end" + unless the `neg_index_as_lookback=True` option is used, in which case + negative indices are interpreted as "before ts=0", meaning going back + into the lookback buffer. + If None, will return all extra model outputs (from ts=0 to the end). + agent_ids: An optional collection of AgentIDs or a single AgentID to get + extra model outputs for. If None, will return extra model outputs for + all agents in this episode. + env_steps: Whether `indices` should be interpreted as environment time steps + (True) or per-agent timesteps (False). + neg_index_as_lookback: If True, negative values in `indices` are + interpreted as "before ts=0", meaning going back into the lookback + buffer. 
For example, an episode with agent A's actions + [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range + (ts=0 item is 7), will respond to `get_actions(-1, agent_ids=[A], + neg_index_as_lookback=True)` with {A: `6`} and to + `get_actions(slice(-2, 1), agent_ids=[A], + neg_index_as_lookback=True)` with {A: `[5, 6, 7]`}. + fill: An optional value to use for filling up the returned results at + the boundaries. This filling only happens if the requested index range's + start/stop boundaries exceed the episode's boundaries (including the + lookback buffer on the left side). This comes in very handy, if users + don't want to worry about reaching such boundaries and want to zero-pad. + For example, an episode with agent A' actions [10, 11, 12, 13, 14] + and lookback buffer size of 2 (meaning actions `10` and `11` are + part of the lookback buffer) will respond to + `get_actions(slice(-7, -2), agent_ids=[A], fill=0.0)` with + `{A: [0.0, 0.0, 10, 11, 12]}`. + one_hot_discrete: If True, will return one-hot vectors (instead of + int-values) for those sub-components of a (possibly complex) observation + space that are Discrete or MultiDiscrete. Note that if `fill=0` and the + requested `indices` are out of the range of our data, the returned + one-hot vectors will actually be zero-hot (all slots zero). + return_list: Whether to return a list of multi-agent dicts (instead of + a single multi-agent dict of lists/structs). False by default. This + option can only be used when `env_steps` is True due to the fact the + such a list can only be interpreted as one env step per list item + (would not work with agent steps). + + Returns: + A dictionary mapping agent IDs to actions (at the given + `indices`). If `env_steps` is True, only agents that have stepped + (were ready) at the given env step `indices` are returned (i.e. not all + agent IDs are necessarily in the keys). 
+ If `return_list` is True, returns a list of MultiAgentDicts (mapping agent + IDs to extra_model_outputs) instead. + """ + return self._get( + what="extra_model_outputs", + extra_model_outputs_key=key, + indices=indices, + agent_ids=agent_ids, + env_steps=env_steps, + neg_index_as_lookback=neg_index_as_lookback, + fill=fill, + return_list=return_list, + ) + + def get_terminateds(self) -> MultiAgentDict: + """Gets the terminateds at given indices.""" + terminateds = { + agent_id: self.agent_episodes[agent_id].is_terminated + for agent_id in self.agent_ids + } + terminateds.update({"__all__": self.is_terminated}) + return terminateds + + def get_truncateds(self) -> MultiAgentDict: + truncateds = { + agent_id: self.agent_episodes[agent_id].is_truncated + for agent_id in self.agent_ids + } + truncateds.update({"__all__": self.is_terminated}) + return truncateds + + def add_temporary_timestep_data(self, key: str, data: Any) -> None: + """Temporarily adds (until `to_numpy()` called) per-timestep data to self. + + The given `data` is appended to a list (`self._temporary_timestep_data`), which + is cleared upon calling `self.to_numpy()`. To get the thus-far accumulated + temporary timestep data for a certain key, use the `get_temporary_timestep_data` + API. + Note that the size of the per timestep list is NOT checked or validated against + the other, non-temporary data in this episode (like observations). + + Args: + key: The key under which to find the list to append `data` to. If `data` is + the first data to be added for this key, start a new list. + data: The data item (representing a single timestep) to be stored. + """ + if self.is_numpy: + raise ValueError( + "Cannot use the `add_temporary_timestep_data` API on an already " + f"numpy'ized {type(self).__name__}!" + ) + self._temporary_timestep_data[key].append(data) + + def get_temporary_timestep_data(self, key: str) -> List[Any]: + """Returns all temporarily stored data items (list) under the given key. 
    def slice(
        self,
        slice_: slice,
        *,
        len_lookback_buffer: Optional[int] = None,
    ) -> "MultiAgentEpisode":
        """Returns a slice of this episode with the given slice object.

        Works analogously to
        :py:meth:`~ray.rllib.env.single_agent_episode.SingleAgentEpisode.slice`

        However, the important differences are:
        - `slice_` is provided in (global) env steps, not agent steps.
        - In case `slice_` ends - for a certain agent - in an env step, where that
          particular agent does not have an observation, the previous observation will
          be included, but the next action and sum of rewards until this point will
          be stored in the agent's hanging values caches for the returned
          MultiAgentEpisode slice.

        .. testcode::

            from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
            from ray.rllib.utils.test_utils import check

            # Generate a simple multi-agent episode.
            observations = [
                {"a0": 0, "a1": 0},  # 0
                {         "a1": 1},  # 1
                {         "a1": 2},  # 2
                {"a0": 3, "a1": 3},  # 3
                {"a0": 4},           # 4
            ]
            # Actions are the same as observations (except for last obs, which doesn't
            # have an action).
            actions = observations[:-1]
            # Make up a reward for each action.
            rewards = [
                {aid: r / 10 + 0.1 for aid, r in o.items()}
                for o in observations
            ]
            episode = MultiAgentEpisode(
                observations=observations,
                actions=actions,
                rewards=rewards,
                len_lookback_buffer=0,
            )

            # Slice the episode and check results.
            slice = episode[1:3]
            a0 = slice.agent_episodes["a0"]
            a1 = slice.agent_episodes["a1"]
            check((a0.observations, a1.observations), ([3], [1, 2, 3]))
            check((a0.actions, a1.actions), ([], [1, 2]))
            check((a0.rewards, a1.rewards), ([], [0.2, 0.3]))
            check((a0.is_done, a1.is_done), (False, False))

            # If a slice ends in a "gap" for an agent, expect actions and rewards to be
            # cached for this agent.
            slice = episode[:2]
            a0 = slice.agent_episodes["a0"]
            check(a0.observations, [0])
            check(a0.actions, [])
            check(a0.rewards, [])
            check(slice._hanging_actions_end["a0"], 0)
            check(slice._hanging_rewards_end["a0"], 0.1)

        Args:
            slice_: The slice object to use for slicing. This should exclude the
                lookback buffer, which will be prepended automatically to the returned
                slice.
            len_lookback_buffer: If not None, forces the returned slice to try to have
                this number of timesteps in its lookback buffer (if available). If None
                (default), tries to make the returned slice's lookback as large as the
                current lookback buffer of this episode (`self`).

        Returns:
            The new MultiAgentEpisode representing the requested slice.
        """
        # Only contiguous slices are supported (step 1 or None).
        if slice_.step not in [1, None]:
            raise NotImplementedError(
                "Slicing MultiAgentEnv with a step other than 1 (you used"
                f" {slice_.step}) is not supported!"
            )

        # Translate `slice_` into one that only contains 0-or-positive ints and will
        # NOT contain any None.
        start = slice_.start
        stop = slice_.stop

        # Start is None -> 0.
        if start is None:
            start = 0
        # Start is negative -> Interpret index as counting "from end".
        elif start < 0:
            start = max(len(self) + start, 0)
        # Start is larger than len(self) -> Clip to len(self).
        elif start > len(self):
            start = len(self)

        # Stop is None -> Set stop to our len (one ts past last valid index).
        if stop is None:
            stop = len(self)
        # Stop is negative -> Interpret index as counting "from end".
        elif stop < 0:
            stop = max(len(self) + stop, 0)
        # Stop is larger than len(self) -> Clip to len(self).
        elif stop > len(self):
            stop = len(self)

        # All per-agent lookback buffers must share one common size; remember it
        # in `ref_lookback`.
        # NOTE(review): with zero agents, `ref_lookback` stays None and the
        # `_lb` arithmetic below would fail — presumably an empty episode is never
        # sliced; confirm.
        ref_lookback = None
        try:
            for aid, sa_episode in self.agent_episodes.items():
                if ref_lookback is None:
                    ref_lookback = sa_episode.observations.lookback
                assert sa_episode.observations.lookback == ref_lookback
                assert sa_episode.actions.lookback == ref_lookback
                assert sa_episode.rewards.lookback == ref_lookback
                assert all(
                    ilb.lookback == ref_lookback
                    for ilb in sa_episode.extra_model_outputs.values()
                )
        except AssertionError:
            raise ValueError(
                "Can only slice a MultiAgentEpisode if all lookback buffers in this "
                "episode have the exact same size!"
            )

        # Determine terminateds/truncateds and when (in agent timesteps) the
        # single-agent episode slices start.
        terminateds = {}
        truncateds = {}
        agent_t_started = {}
        for aid, sa_episode in self.agent_episodes.items():
            mapping = self.env_t_to_agent_t[aid]
            # If the (agent) timestep directly at the slice stop boundary is equal to
            # the length of the single-agent episode of this agent -> Use the
            # single-agent episode's terminated/truncated flags.
            # If `stop` is already beyond this agent's single-agent episode, then we
            # don't have to keep track of this: The MultiAgentEpisode initializer will
            # automatically determine that this agent must be done (b/c it has no action
            # following its final observation).
            if (
                stop < len(mapping)
                and mapping[stop] != self.SKIP_ENV_TS_TAG
                and len(sa_episode) == mapping[stop]
            ):
                terminateds[aid] = sa_episode.is_terminated
                truncateds[aid] = sa_episode.is_truncated
            # Determine this agent's t_started: first non-skipped env step at or
            # after `start`, translated into this agent's own timestep counting.
            if start < len(mapping):
                for i in range(start, len(mapping)):
                    if mapping[i] != self.SKIP_ENV_TS_TAG:
                        agent_t_started[aid] = sa_episode.t_started + mapping[i]
                        break
        # Note: agents without an entry yield None via `.get()` (falsy), so
        # "__all__" is only True if every agent got an (affirmative) flag above.
        terminateds["__all__"] = all(
            terminateds.get(aid) for aid in self.agent_episodes
        )
        truncateds["__all__"] = all(truncateds.get(aid) for aid in self.agent_episodes)

        # Determine all other slice contents.
        # Requested lookback size; clamp it if `self`'s lookback buffer cannot
        # provide `_lb` steps before `start`.
        _lb = len_lookback_buffer if len_lookback_buffer is not None else ref_lookback
        if start - _lb < 0 and ref_lookback < (_lb - start):
            _lb = ref_lookback + start
        # Observations/infos include one extra item on the right (the obs at `stop`).
        observations = self.get_observations(
            slice(start - _lb, stop + 1),
            neg_index_as_lookback=True,
            return_list=True,
        )
        actions = self.get_actions(
            slice(start - _lb, stop),
            neg_index_as_lookback=True,
            return_list=True,
        )
        rewards = self.get_rewards(
            slice(start - _lb, stop),
            neg_index_as_lookback=True,
            return_list=True,
        )
        extra_model_outputs = self.get_extra_model_outputs(
            indices=slice(start - _lb, stop),
            neg_index_as_lookback=True,
            return_list=True,
        )

        # Create the actual slice to be returned.
        ma_episode = MultiAgentEpisode(
            id_=self.id_,
            # In the following, offset `start`s automatically by lookbacks.
            observations=observations,
            observation_space=self.observation_space,
            actions=actions,
            action_space=self.action_space,
            rewards=rewards,
            extra_model_outputs=extra_model_outputs,
            terminateds=terminateds,
            truncateds=truncateds,
            len_lookback_buffer=_lb,
            env_t_started=self.env_t_started + start,
            agent_episode_ids={
                aid: eid.id_ for aid, eid in self.agent_episodes.items()
            },
            agent_t_started=agent_t_started,
            agent_module_ids=self._agent_to_module_mapping,
            agent_to_module_mapping_fn=self.agent_to_module_mapping_fn,
        )

        # Numpy'ize slice if `self` is also finalized.
        if self.is_numpy:
            ma_episode.to_numpy()

        return ma_episode
+ + Note that the length of an episode is defined by the difference + between its actual timestep and the starting point. + + Returns: An integer defining the length of the episode or an + error if the episode has not yet started. + """ + assert ( + sum(len(agent_map) for agent_map in self.env_t_to_agent_t.values()) > 0 + ), ( + "ERROR: Cannot determine length of episode that hasn't started, yet!" + "Call `MultiAgentEpisode.add_env_reset(observations=)` " + "first (after which `len(MultiAgentEpisode)` will be 0)." + ) + return self.env_t - self.env_t_started + + def __repr__(self): + sa_eps_returns = { + aid: sa_eps.get_return() for aid, sa_eps in self.agent_episodes.items() + } + return ( + f"MAEps(len={len(self)} done={self.is_done} " + f"Rs={sa_eps_returns} id_={self.id_})" + ) + + def print(self) -> None: + """Prints this MultiAgentEpisode as a table of observations for the agents.""" + + # Find the maximum timestep across all agents to determine the grid width. + max_ts = max(ts.len_incl_lookback() for ts in self.env_t_to_agent_t.values()) + lookback = next(iter(self.env_t_to_agent_t.values())).lookback + longest_agent = max(len(aid) for aid in self.agent_ids) + # Construct the header. + header = ( + "ts" + + (" " * longest_agent) + + " ".join(str(i) for i in range(-lookback, max_ts - lookback)) + + "\n" + ) + # Construct each agent's row. + rows = [] + for agent, inf_buffer in self.env_t_to_agent_t.items(): + row = f"{agent} " + (" " * (longest_agent - len(agent))) + for t in inf_buffer.data: + # Two spaces for alignment. + if t == "S": + row += " " + # Mark the step with an x. + else: + row += " x " + # Remove trailing space for alignment. + rows.append(row.rstrip()) + + # Join all components into a final string + print(header + "\n".join(rows)) + + def get_state(self) -> Dict[str, Any]: + """Returns the state of a multi-agent episode. + + Note that from an episode's state the episode itself can + be recreated. 
+ + Returns: A dicitonary containing pickable data for a + `MultiAgentEpisode`. + """ + return { + "id_": self.id_, + "agent_to_module_mapping_fn": self.agent_to_module_mapping_fn, + "_agent_to_module_mapping": self._agent_to_module_mapping, + "observation_space": self.observation_space, + "action_space": self.action_space, + "env_t_started": self.env_t_started, + "env_t": self.env_t, + "agent_t_started": self.agent_t_started, + # TODO (simon): Check, if we can store the `InfiniteLookbackBuffer` + "env_t_to_agent_t": self.env_t_to_agent_t, + "_hanging_actions_end": self._hanging_actions_end, + "_hanging_extra_model_outputs_end": self._hanging_extra_model_outputs_end, + "_hanging_rewards_end": self._hanging_rewards_end, + "_hanging_rewards_begin": self._hanging_rewards_begin, + "is_terminated": self.is_terminated, + "is_truncated": self.is_truncated, + "agent_episodes": list( + { + agent_id: agent_eps.get_state() + for agent_id, agent_eps in self.agent_episodes.items() + }.items() + ), + "_start_time": self._start_time, + "_last_step_time": self._last_step_time, + } + + @staticmethod + def from_state(state: Dict[str, Any]) -> "MultiAgentEpisode": + """Creates a multi-agent episode from a state dictionary. + + See `MultiAgentEpisode.get_state()` for creating a state for + a `MultiAgentEpisode` pickable state. For recreating a + `MultiAgentEpisode` from a state, this state has to be complete, + i.e. all data must have been stored in the state. + + Args: + state: A dict containing all data required to recreate a MultiAgentEpisode`. + See `MultiAgentEpisode.get_state()`. + + Returns: + A `MultiAgentEpisode` instance created from the state data. + """ + # Create an empty `MultiAgentEpisode` instance. + episode = MultiAgentEpisode(id_=state["id_"]) + # Fill the instance with the state data. 
+ episode.agent_to_module_mapping_fn = state["agent_to_module_mapping_fn"] + episode._agent_to_module_mapping = state["_agent_to_module_mapping"] + episode.observation_space = state["observation_space"] + episode.action_space = state["action_space"] + episode.env_t_started = state["env_t_started"] + episode.env_t = state["env_t"] + episode.agent_t_started = state["agent_t_started"] + episode.env_t_to_agent_t = state["env_t_to_agent_t"] + episode._hanging_actions_end = state["_hanging_actions_end"] + episode._hanging_extra_model_outputs_end = state[ + "_hanging_extra_model_outputs_end" + ] + episode._hanging_rewards_end = state["_hanging_rewards_end"] + episode._hanging_rewards_begin = state["_hanging_rewards_begin"] + episode.is_terminated = state["is_terminated"] + episode.is_truncated = state["is_truncated"] + episode.agent_episodes = { + agent_id: SingleAgentEpisode.from_state(agent_state) + for agent_id, agent_state in state["agent_episodes"] + } + episode._start_time = state["_start_time"] + episode._last_step_time = state["_last_step_time"] + + # Validate the episode. + episode.validate() + + return episode + + def get_sample_batch(self) -> MultiAgentBatch: + """Converts this `MultiAgentEpisode` into a `MultiAgentBatch`. + + Each `SingleAgentEpisode` instances in `MultiAgentEpisode.agent_epiosdes` + will be converted into a `SampleBatch` and the environment timestep will be + passed as the returned MultiAgentBatch's `env_steps`. + + Returns: + A MultiAgentBatch containing all of this episode's data. + """ + # TODO (simon): Check, if timesteps should be converted into global + # timesteps instead of agent steps. + # Note, only agents that have stepped are included into the batch. 
+ return MultiAgentBatch( + policy_batches={ + agent_id: agent_eps.get_sample_batch() + for agent_id, agent_eps in self.agent_episodes.items() + if agent_eps.t - agent_eps.t_started > 0 + }, + env_steps=self.env_t - self.env_t_started, + ) + + def get_return( + self, + include_hanging_rewards: bool = False, + ) -> float: + """Returns all-agent return. + + Args: + include_hanging_rewards: Whether we should also consider + hanging rewards wehn calculating the overall return. Agents might + have received partial rewards, i.e. rewards without an + observation. These are stored in the "hanging" caches (begin and end) + for each agent and added up until the next observation is received by + that agent. + + Returns: + The sum of all single-agents' returns (maybe including the hanging + rewards per agent). + """ + env_return = sum( + agent_eps.get_return() for agent_eps in self.agent_episodes.values() + ) + if include_hanging_rewards: + for hanging_r in self._hanging_rewards_begin.values(): + env_return += hanging_r + for hanging_r in self._hanging_rewards_end.values(): + env_return += hanging_r + + return env_return + + def get_agents_to_act(self) -> Set[AgentID]: + """Returns a set of agent IDs required to send an action to `env.step()` next. + + Those are generally the agents that received an observation in the most recent + `env.step()` call. + + Returns: + A set of AgentIDs that are supposed to send actions to the next `env.step()` + call. + """ + return { + aid + for aid in self.get_observations(-1).keys() + if not self.agent_episodes[aid].is_done + } + + def get_agents_that_stepped(self) -> Set[AgentID]: + """Returns a set of agent IDs of those agents that just finished stepping. + + These are all the agents that have an observation logged at the last env + timestep, which may include agents, whose single agent episode just terminated + or truncated. 
+ + Returns: + A set of AgentIDs of those agents that just finished stepping (that have a + most recent observation on the env timestep scale), regardless of whether + their single agent episodes are done or not. + """ + return set(self.get_observations(-1).keys()) + + def get_duration_s(self) -> float: + """Returns the duration of this Episode (chunk) in seconds.""" + if self._last_step_time is None: + return 0.0 + return self._last_step_time - self._start_time + + def env_steps(self) -> int: + """Returns the number of environment steps. + + Note, this episode instance could be a chunk of an actual episode. + + Returns: + An integer that counts the number of environment steps this episode instance + has seen. + """ + return len(self) + + def agent_steps(self) -> int: + """Number of agent steps. + + Note, there are >= 1 agent steps per environment step. + + Returns: + An integer counting the number of agent steps executed during the time this + episode instance records. + """ + return sum(len(eps) for eps in self.agent_episodes.values()) + + def __getitem__(self, item: slice) -> "MultiAgentEpisode": + """Enable squared bracket indexing- and slicing syntax, e.g. episode[-4:].""" + if isinstance(item, slice): + return self.slice(slice_=item) + else: + raise NotImplementedError( + f"MultiAgentEpisode does not support getting item '{item}'! " + "Only slice objects allowed with the syntax: `episode[a:b]`." 
+ ) + + def _init_single_agent_episodes( + self, + *, + agent_module_ids: Optional[Dict[AgentID, ModuleID]] = None, + agent_episode_ids: Optional[Dict[AgentID, str]] = None, + observations: Optional[List[MultiAgentDict]] = None, + actions: Optional[List[MultiAgentDict]] = None, + rewards: Optional[List[MultiAgentDict]] = None, + infos: Optional[List[MultiAgentDict]] = None, + terminateds: Union[MultiAgentDict, bool] = False, + truncateds: Union[MultiAgentDict, bool] = False, + extra_model_outputs: Optional[List[MultiAgentDict]] = None, + ): + if observations is None: + return + if actions is None: + assert not rewards + assert not extra_model_outputs + actions = [] + rewards = [] + extra_model_outputs = [] + + # Infos and `extra_model_outputs` are allowed to be None -> Fill them with + # proper dummy values, if so. + if infos is None: + infos = [{} for _ in range(len(observations))] + if extra_model_outputs is None: + extra_model_outputs = [{} for _ in range(len(actions))] + + observations_per_agent = defaultdict(list) + infos_per_agent = defaultdict(list) + actions_per_agent = defaultdict(list) + rewards_per_agent = defaultdict(list) + extra_model_outputs_per_agent = defaultdict(list) + done_per_agent = defaultdict(bool) + len_lookback_buffer_per_agent = defaultdict(lambda: self._len_lookback_buffers) + + all_agent_ids = set( + agent_episode_ids.keys() if agent_episode_ids is not None else [] + ) + agent_module_ids = agent_module_ids or {} + + # Step through all observations and interpret these as the (global) env steps. + for data_idx, (obs, inf) in enumerate(zip(observations, infos)): + # If we do have actions/extra outs/rewards for this timestep, use the data. + # It may be that these lists have the same length as the observations list, + # in which case the data will be cached (agent did step/send an action, + # but the step has not been concluded yet by the env). 
+ act = actions[data_idx] if len(actions) > data_idx else {} + extra_outs = ( + extra_model_outputs[data_idx] + if len(extra_model_outputs) > data_idx + else {} + ) + rew = rewards[data_idx] if len(rewards) > data_idx else {} + + for agent_id, agent_obs in obs.items(): + all_agent_ids.add(agent_id) + + observations_per_agent[agent_id].append(agent_obs) + infos_per_agent[agent_id].append(inf.get(agent_id, {})) + + # Pull out hanging action (if not first obs for this agent) and + # complete step for agent. + if len(observations_per_agent[agent_id]) > 1: + actions_per_agent[agent_id].append( + self._hanging_actions_end.pop(agent_id) + ) + extra_model_outputs_per_agent[agent_id].append( + self._hanging_extra_model_outputs_end.pop(agent_id) + ) + rewards_per_agent[agent_id].append( + self._hanging_rewards_end.pop(agent_id) + ) + # First obs for this agent. Make sure the agent's mapping is + # appropriately prepended with self.SKIP_ENV_TS_TAG tags. + else: + if agent_id not in self.env_t_to_agent_t: + self.env_t_to_agent_t[agent_id].extend( + [self.SKIP_ENV_TS_TAG] * data_idx + ) + len_lookback_buffer_per_agent[agent_id] -= data_idx + + # Agent is still continuing (has an action for the next step). + if agent_id in act: + # Always push actions/extra outputs into cache, then remove them + # from there, once the next observation comes in. Same for rewards. + self._hanging_actions_end[agent_id] = act[agent_id] + self._hanging_extra_model_outputs_end[agent_id] = extra_outs.get( + agent_id, {} + ) + self._hanging_rewards_end[agent_id] += rew.get(agent_id, 0.0) + # Agent is done (has no action for the next step). + elif terminateds.get(agent_id) or truncateds.get(agent_id): + done_per_agent[agent_id] = True + # There is more (global) action/reward data. This agent must therefore + # be done. Automatically add it to `done_per_agent` and `terminateds`. 
+ elif data_idx < len(observations) - 1: + done_per_agent[agent_id] = terminateds[agent_id] = True + + # Update env_t_to_agent_t mapping. + self.env_t_to_agent_t[agent_id].append( + len(observations_per_agent[agent_id]) - 1 + ) + + # Those agents that did NOT step: + # - Get self.SKIP_ENV_TS_TAG added to their env_t_to_agent_t mapping. + # - Get their reward (if any) added up. + for agent_id in all_agent_ids: + if agent_id not in obs and agent_id not in done_per_agent: + self.env_t_to_agent_t[agent_id].append(self.SKIP_ENV_TS_TAG) + # If we are still in the global lookback buffer segment, deduct 1 + # from this agents' lookback buffer, b/c we don't want the agent + # to use this (missing) obs/data in its single-agent lookback. + if ( + len(self.env_t_to_agent_t[agent_id]) + - self._len_lookback_buffers + <= 0 + ): + len_lookback_buffer_per_agent[agent_id] -= 1 + self._hanging_rewards_end[agent_id] += rew.get(agent_id, 0.0) + + # - Validate per-agent data. + # - Fix lookback buffers of env_t_to_agent_t mappings. + for agent_id in list(self.env_t_to_agent_t.keys()): + # Skip agent if it doesn't seem to have any data. + if agent_id not in observations_per_agent: + del self.env_t_to_agent_t[agent_id] + continue + assert ( + len(observations_per_agent[agent_id]) + == len(infos_per_agent[agent_id]) + == len(actions_per_agent[agent_id]) + 1 + == len(extra_model_outputs_per_agent[agent_id]) + 1 + == len(rewards_per_agent[agent_id]) + 1 + ) + self.env_t_to_agent_t[agent_id].lookback = self._len_lookback_buffers + + # Now create the individual episodes from the collected per-agent data. + for agent_id, agent_obs in observations_per_agent.items(): + # If agent only has a single obs AND is already done, remove all its traces + # from this MultiAgentEpisode. + if len(agent_obs) == 1 and done_per_agent.get(agent_id): + self._del_agent(agent_id) + continue + + # Try to figure out the module ID for this agent. 
+ # If not provided explicitly by the user that initializes this episode + # object, try our mapping function. + module_id = agent_module_ids.get( + agent_id, self.agent_to_module_mapping_fn(agent_id, self) + ) + # Create this agent's SingleAgentEpisode. + sa_episode = SingleAgentEpisode( + id_=( + agent_episode_ids.get(agent_id) + if agent_episode_ids is not None + else None + ), + agent_id=agent_id, + module_id=module_id, + multi_agent_episode_id=self.id_, + observations=agent_obs, + observation_space=self.observation_space.get(agent_id), + infos=infos_per_agent[agent_id], + actions=actions_per_agent[agent_id], + action_space=self.action_space.get(agent_id), + rewards=rewards_per_agent[agent_id], + extra_model_outputs=( + { + k: [i[k] for i in extra_model_outputs_per_agent[agent_id]] + for k in extra_model_outputs_per_agent[agent_id][0].keys() + } + if extra_model_outputs_per_agent[agent_id] + else None + ), + terminated=terminateds.get(agent_id, False), + truncated=truncateds.get(agent_id, False), + t_started=self.agent_t_started[agent_id], + len_lookback_buffer=max(len_lookback_buffer_per_agent[agent_id], 0), + ) + # .. and store it. + self.agent_episodes[agent_id] = sa_episode + + def _get( + self, + *, + what, + indices, + agent_ids=None, + env_steps=True, + neg_index_as_lookback=False, + fill=None, + one_hot_discrete=False, + return_list=False, + extra_model_outputs_key=None, + ): + agent_ids = set(force_list(agent_ids)) or self.agent_ids + + kwargs = dict( + what=what, + indices=indices, + agent_ids=agent_ids, + neg_index_as_lookback=neg_index_as_lookback, + fill=fill, + # Rewards and infos do not support one_hot_discrete option. + one_hot_discrete=dict( + {} if not one_hot_discrete else {"one_hot_discrete": one_hot_discrete} + ), + extra_model_outputs_key=extra_model_outputs_key, + ) + + # User specified agent timesteps (indices) -> Simply delegate everything + # to the individual agents' SingleAgentEpisodes. 
+ if env_steps is False: + if return_list: + raise ValueError( + f"`MultiAgentEpisode.get_{what}()` can't be called with both " + "`env_steps=False` and `return_list=True`!" + ) + return self._get_data_by_agent_steps(**kwargs) + # User specified env timesteps (indices) -> We need to translate them for each + # agent into agent-timesteps. + # Return a list of individual per-env-timestep multi-agent dicts. + elif return_list: + return self._get_data_by_env_steps_as_list(**kwargs) + # Return a single multi-agent dict with lists/arrays as leafs. + else: + return self._get_data_by_env_steps(**kwargs) + + def _get_data_by_agent_steps( + self, + *, + what, + indices, + agent_ids, + neg_index_as_lookback, + fill, + one_hot_discrete, + extra_model_outputs_key, + ): + # Return requested data by agent-steps. + ret = {} + # For each agent, we retrieve the data through passing the given indices into + # the SingleAgentEpisode of that agent. + for agent_id, sa_episode in self.agent_episodes.items(): + if agent_id not in agent_ids: + continue + inf_lookback_buffer = getattr(sa_episode, what) + hanging_val = self._get_hanging_value(what, agent_id) + # User wants a specific `extra_model_outputs` key. 
+ if extra_model_outputs_key is not None: + inf_lookback_buffer = inf_lookback_buffer[extra_model_outputs_key] + hanging_val = hanging_val[extra_model_outputs_key] + agent_value = inf_lookback_buffer.get( + indices=indices, + neg_index_as_lookback=neg_index_as_lookback, + fill=fill, + _add_last_ts_value=hanging_val, + **one_hot_discrete, + ) + if agent_value is None or agent_value == []: + continue + ret[agent_id] = agent_value + return ret + + def _get_data_by_env_steps_as_list( + self, + *, + what: str, + indices: Union[int, slice, List[int]], + agent_ids: Collection[AgentID], + neg_index_as_lookback: bool, + fill: Any, + one_hot_discrete, + extra_model_outputs_key: str, + ) -> List[MultiAgentDict]: + # Collect indices for each agent first, so we can construct the list in + # the next step. + agent_indices = {} + for agent_id in self.agent_episodes.keys(): + if agent_id not in agent_ids: + continue + agent_indices[agent_id] = self.env_t_to_agent_t[agent_id].get( + indices, + neg_index_as_lookback=neg_index_as_lookback, + fill=self.SKIP_ENV_TS_TAG, + # For those records where there is no "hanging" last timestep (all + # other than obs and infos), we have to ignore the last entry in + # the env_t_to_agent_t mappings. 
+ _ignore_last_ts=what not in ["observations", "infos"], + ) + if not agent_indices: + return [] + ret = [] + for i in range(len(next(iter(agent_indices.values())))): + ret2 = {} + for agent_id, idxes in agent_indices.items(): + hanging_val = self._get_hanging_value(what, agent_id) + ( + inf_lookback_buffer, + indices_to_use, + ) = self._get_inf_lookback_buffer_or_dict( + agent_id, + what, + extra_model_outputs_key, + hanging_val, + filter_for_skip_indices=idxes[i], + ) + if ( + what == "extra_model_outputs" + and not inf_lookback_buffer + and not hanging_val + ): + continue + agent_value = self._get_single_agent_data_by_index( + what=what, + inf_lookback_buffer=inf_lookback_buffer, + agent_id=agent_id, + index_incl_lookback=indices_to_use, + fill=fill, + one_hot_discrete=one_hot_discrete, + extra_model_outputs_key=extra_model_outputs_key, + hanging_val=hanging_val, + ) + if agent_value is not None: + ret2[agent_id] = agent_value + ret.append(ret2) + return ret + + def _get_data_by_env_steps( + self, + *, + what: str, + indices: Union[int, slice, List[int]], + agent_ids: Collection[AgentID], + neg_index_as_lookback: bool, + fill: Any, + one_hot_discrete: bool, + extra_model_outputs_key: str, + ) -> MultiAgentDict: + ignore_last_ts = what not in ["observations", "infos"] + ret = {} + for agent_id, sa_episode in self.agent_episodes.items(): + if agent_id not in agent_ids: + continue + hanging_val = self._get_hanging_value(what, agent_id) + agent_indices = self.env_t_to_agent_t[agent_id].get( + indices, + neg_index_as_lookback=neg_index_as_lookback, + fill=self.SKIP_ENV_TS_TAG if fill is not None else None, + # For those records where there is no "hanging" last timestep (all + # other than obs and infos), we have to ignore the last entry in + # the env_t_to_agent_t mappings. 
+ _ignore_last_ts=ignore_last_ts, + ) + inf_lookback_buffer, agent_indices = self._get_inf_lookback_buffer_or_dict( + agent_id, + what, + extra_model_outputs_key, + hanging_val, + filter_for_skip_indices=agent_indices, + ) + if isinstance(agent_indices, list): + agent_values = self._get_single_agent_data_by_env_step_indices( + what=what, + agent_id=agent_id, + indices_incl_lookback=agent_indices, + fill=fill, + one_hot_discrete=one_hot_discrete, + hanging_val=hanging_val, + extra_model_outputs_key=extra_model_outputs_key, + ) + if len(agent_values) > 0: + ret[agent_id] = agent_values + else: + agent_values = self._get_single_agent_data_by_index( + what=what, + inf_lookback_buffer=inf_lookback_buffer, + agent_id=agent_id, + index_incl_lookback=agent_indices, + fill=fill, + one_hot_discrete=one_hot_discrete, + extra_model_outputs_key=extra_model_outputs_key, + hanging_val=hanging_val, + ) + if agent_values is not None: + ret[agent_id] = agent_values + return ret + + def _get_single_agent_data_by_index( + self, + *, + what: str, + inf_lookback_buffer: InfiniteLookbackBuffer, + agent_id: AgentID, + index_incl_lookback: Union[int, str], + fill: Any, + one_hot_discrete: dict, + extra_model_outputs_key: str, + hanging_val: Any, + ) -> Any: + sa_episode = self.agent_episodes[agent_id] + + if index_incl_lookback == self.SKIP_ENV_TS_TAG: + # We don't want to fill -> Skip this agent. + if fill is None: + return + # Provide filled value for this agent. + return getattr(sa_episode, f"get_{what}")( + indices=1000000000000, + neg_index_as_lookback=False, + fill=fill, + **dict( + {} + if extra_model_outputs_key is None + else {"key": extra_model_outputs_key} + ), + **one_hot_discrete, + ) + + # No skip timestep -> Provide value at given index for this agent. + + # Special case: extra_model_outputs and key=None (return all keys as + # a dict). Note that `inf_lookback_buffer` is NOT an infinite lookback + # buffer, but a dict mapping keys to individual infinite lookback + # buffers. 
+ elif what == "extra_model_outputs" and extra_model_outputs_key is None: + assert hanging_val is None or isinstance(hanging_val, dict) + ret = {} + if inf_lookback_buffer: + for key, sub_buffer in inf_lookback_buffer.items(): + ret[key] = sub_buffer.get( + indices=index_incl_lookback - sub_buffer.lookback, + neg_index_as_lookback=True, + fill=fill, + _add_last_ts_value=( + None if hanging_val is None else hanging_val[key] + ), + **one_hot_discrete, + ) + else: + for key in hanging_val.keys(): + ret[key] = InfiniteLookbackBuffer().get( + indices=index_incl_lookback, + neg_index_as_lookback=True, + fill=fill, + _add_last_ts_value=hanging_val[key], + **one_hot_discrete, + ) + return ret + + # Extract data directly from the infinite lookback buffer object. + else: + return inf_lookback_buffer.get( + indices=index_incl_lookback - inf_lookback_buffer.lookback, + neg_index_as_lookback=True, + fill=fill, + _add_last_ts_value=hanging_val, + **one_hot_discrete, + ) + + def _get_single_agent_data_by_env_step_indices( + self, + *, + what: str, + agent_id: AgentID, + indices_incl_lookback: Union[int, str], + fill: Optional[Any] = None, + one_hot_discrete: bool = False, + extra_model_outputs_key: Optional[str] = None, + hanging_val: Optional[Any] = None, + ) -> Any: + """Returns single data item from the episode based on given (env step) indices. + + The returned data item will have a batch size that matches the env timesteps + defined via `indices_incl_lookback`. + + Args: + what: A (str) descriptor of what data to collect. Must be one of + "observations", "infos", "actions", "rewards", or "extra_model_outputs". + indices_incl_lookback: A list of ints specifying, which indices + to pull from the InfiniteLookbackBuffer defined by `agent_id` and `what` + (and maybe `extra_model_outputs_key`). Note that these indices + disregard the special logic of the lookback buffer. 
Meaning if one + index in `indices_incl_lookback` is 0, then the first value in the + lookback buffer should be returned, not the first value after the + lookback buffer (which would be normal behavior for pulling items from + an `InfiniteLookbackBuffer` object). + agent_id: The individual agent ID to pull data for. Used to lookup the + `SingleAgentEpisode` object for this agent in `self`. + fill: An optional float value to use for filling up the returned results at + the boundaries. This filling only happens if the requested index range's + start/stop boundaries exceed the buffer's boundaries (including the + lookback buffer on the left side). This comes in very handy, if users + don't want to worry about reaching such boundaries and want to zero-pad. + For example, a buffer with data [10, 11, 12, 13, 14] and lookback + buffer size of 2 (meaning `10` and `11` are part of the lookback buffer) + will respond to `indices_incl_lookback=[-1, -2, 0]` and `fill=0.0` + with `[0.0, 0.0, 10]`. + one_hot_discrete: If True, will return one-hot vectors (instead of + int-values) for those sub-components of a (possibly complex) space + that are Discrete or MultiDiscrete. Note that if `fill=0` and the + requested `indices_incl_lookback` are out of the range of our data, the + returned one-hot vectors will actually be zero-hot (all slots zero). + extra_model_outputs_key: Only if what is "extra_model_outputs", this + specifies the sub-key (str) inside the extra_model_outputs dict, e.g. + STATE_OUT or ACTION_DIST_INPUTS. + hanging_val: In case we are pulling actions, rewards, or extra_model_outputs + data, there might be information "hanging" (cached). For example, + if an agent receives an observation o0 and then immediately sends an + action a0 back, but then does NOT immediately reveive a next + observation, a0 is now cached (not fully logged yet with this + episode). 
The currently cached value must be provided here to be able + to return it in case the index is -1 (most recent timestep). + + Returns: + A data item corresponding to the provided args. + """ + sa_episode = self.agent_episodes[agent_id] + + inf_lookback_buffer = getattr(sa_episode, what) + if extra_model_outputs_key is not None: + inf_lookback_buffer = inf_lookback_buffer[extra_model_outputs_key] + + # If there are self.SKIP_ENV_TS_TAG items in `indices_incl_lookback` and user + # wants to fill these (together with outside-episode-bounds indices) -> + # Provide these skipped timesteps as filled values. + if self.SKIP_ENV_TS_TAG in indices_incl_lookback and fill is not None: + single_fill_value = inf_lookback_buffer.get( + indices=1000000000000, + neg_index_as_lookback=False, + fill=fill, + **one_hot_discrete, + ) + ret = [] + for i in indices_incl_lookback: + if i == self.SKIP_ENV_TS_TAG: + ret.append(single_fill_value) + else: + ret.append( + inf_lookback_buffer.get( + indices=i - getattr(sa_episode, what).lookback, + neg_index_as_lookback=True, + fill=fill, + _add_last_ts_value=hanging_val, + **one_hot_discrete, + ) + ) + if self.is_numpy: + ret = batch(ret) + else: + # Filter these indices out up front. 
+ indices = [ + i - inf_lookback_buffer.lookback + for i in indices_incl_lookback + if i != self.SKIP_ENV_TS_TAG + ] + ret = inf_lookback_buffer.get( + indices=indices, + neg_index_as_lookback=True, + fill=fill, + _add_last_ts_value=hanging_val, + **one_hot_discrete, + ) + return ret + + def _get_hanging_value(self, what: str, agent_id: AgentID) -> Any: + """Returns the hanging action/reward/extra_model_outputs for given agent.""" + if what == "actions": + return self._hanging_actions_end.get(agent_id) + elif what == "extra_model_outputs": + return self._hanging_extra_model_outputs_end.get(agent_id) + elif what == "rewards": + return self._hanging_rewards_end.get(agent_id) + + def _copy_hanging(self, agent_id: AgentID, other: "MultiAgentEpisode") -> None: + """Copies hanging action, reward, extra_model_outputs from `other` to `self.""" + if agent_id in other._hanging_rewards_begin: + self._hanging_rewards_begin[agent_id] = other._hanging_rewards_begin[ + agent_id + ] + if agent_id in other._hanging_rewards_end: + self._hanging_actions_end[agent_id] = copy.deepcopy( + other._hanging_actions_end[agent_id] + ) + self._hanging_rewards_end[agent_id] = other._hanging_rewards_end[agent_id] + self._hanging_extra_model_outputs_end[agent_id] = copy.deepcopy( + other._hanging_extra_model_outputs_end[agent_id] + ) + + def _del_hanging(self, agent_id: AgentID) -> None: + """Deletes all hanging action, reward, extra_model_outputs of given agent.""" + self._hanging_rewards_begin.pop(agent_id, None) + + self._hanging_actions_end.pop(agent_id, None) + self._hanging_extra_model_outputs_end.pop(agent_id, None) + self._hanging_rewards_end.pop(agent_id, None) + + def _del_agent(self, agent_id: AgentID) -> None: + """Deletes all data of given agent from this episode.""" + self._del_hanging(agent_id) + self.agent_episodes.pop(agent_id, None) + self.agent_ids.discard(agent_id) + self.env_t_to_agent_t.pop(agent_id, None) + self._agent_to_module_mapping.pop(agent_id, None) + 
self.agent_t_started.pop(agent_id, None) + + def _get_inf_lookback_buffer_or_dict( + self, + agent_id: AgentID, + what: str, + extra_model_outputs_key: Optional[str] = None, + hanging_val: Optional[Any] = None, + filter_for_skip_indices=None, + ): + """Returns a single InfiniteLookbackBuffer or a dict of such. + + In case `what` is "extra_model_outputs" AND `extra_model_outputs_key` is None, + a dict is returned. In all other cases, a single InfiniteLookbackBuffer is + returned. + """ + inf_lookback_buffer_or_dict = inf_lookback_buffer = getattr( + self.agent_episodes[agent_id], what + ) + if what == "extra_model_outputs": + if extra_model_outputs_key is not None: + inf_lookback_buffer = inf_lookback_buffer_or_dict[ + extra_model_outputs_key + ] + elif inf_lookback_buffer_or_dict: + inf_lookback_buffer = next(iter(inf_lookback_buffer_or_dict.values())) + elif filter_for_skip_indices is not None: + return inf_lookback_buffer_or_dict, filter_for_skip_indices + else: + return inf_lookback_buffer_or_dict + + if filter_for_skip_indices is not None: + inf_lookback_buffer_len = ( + len(inf_lookback_buffer) + + inf_lookback_buffer.lookback + + (hanging_val is not None) + ) + ignore_last_ts = what not in ["observations", "infos"] + if isinstance(filter_for_skip_indices, list): + filter_for_skip_indices = [ + "S" if ignore_last_ts and i == inf_lookback_buffer_len else i + for i in filter_for_skip_indices + ] + elif ignore_last_ts and filter_for_skip_indices == inf_lookback_buffer_len: + filter_for_skip_indices = "S" + return inf_lookback_buffer_or_dict, filter_for_skip_indices + else: + return inf_lookback_buffer_or_dict + + @Deprecated(new="MultiAgentEpisode.is_numpy()", error=True) + def is_finalized(self): + pass + + @Deprecated(new="MultiAgentEpisode.to_numpy()", error=True) + def finalize(self): + pass diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/policy_client.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/policy_client.py new file mode 100644 
index 0000000000000000000000000000000000000000..2f3791226077a1a2bf6b497ceb8bd297126a72f3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/policy_client.py @@ -0,0 +1,403 @@ +"""REST client to interact with a policy server. + +This client supports both local and remote policy inference modes. Local +inference is faster but causes more compute to be done on the client. +""" + +import logging +import threading +import time +from typing import Union, Optional + +import ray.cloudpickle as pickle +from ray.rllib.env.external_env import ExternalEnv +from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv +from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.policy.sample_batch import MultiAgentBatch +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.typing import ( + MultiAgentDict, + EnvInfoDict, + EnvObsType, + EnvActionType, +) + +# Backward compatibility. +from ray.rllib.env.utils.external_env_protocol import RLlink as Commands + +logger = logging.getLogger(__name__) + +try: + import requests # `requests` is not part of stdlib. +except ImportError: + requests = None + logger.warning( + "Couldn't import `requests` library. Be sure to install it on" + " the client side." + ) + + +@OldAPIStack +class PolicyClient: + """REST client to interact with an RLlib policy server.""" + + def __init__( + self, + address: str, + inference_mode: str = "local", + update_interval: float = 10.0, + session: Optional[requests.Session] = None, + ): + """Create a PolicyClient instance. + + Args: + address: Server to connect to (e.g., "localhost:9090"). + inference_mode: Whether to use 'local' or 'remote' policy + inference for computing actions. + update_interval (float or None): If using 'local' inference mode, + the policy is refreshed after this many seconds have passed, + or None for manual control via client. 
+ session (requests.Session or None): If available the session object + is used to communicate with the policy server. Using a session + can lead to speedups as connections are reused. It is the + responsibility of the creator of the session to close it. + """ + self.address = address + self.session = session + self.env: ExternalEnv = None + if inference_mode == "local": + self.local = True + self._setup_local_rollout_worker(update_interval) + elif inference_mode == "remote": + self.local = False + else: + raise ValueError("inference_mode must be either 'local' or 'remote'") + + def start_episode( + self, episode_id: Optional[str] = None, training_enabled: bool = True + ) -> str: + """Record the start of one or more episode(s). + + Args: + episode_id (Optional[str]): Unique string id for the episode or + None for it to be auto-assigned. + training_enabled: Whether to use experiences for this + episode to improve the policy. + + Returns: + episode_id: Unique string id for the episode. + """ + + if self.local: + self._update_local_policy() + return self.env.start_episode(episode_id, training_enabled) + + return self._send( + { + "episode_id": episode_id, + "command": Commands.START_EPISODE, + "training_enabled": training_enabled, + } + )["episode_id"] + + def get_action( + self, episode_id: str, observation: Union[EnvObsType, MultiAgentDict] + ) -> Union[EnvActionType, MultiAgentDict]: + """Record an observation and get the on-policy action. + + Args: + episode_id: Episode id returned from start_episode(). + observation: Current environment observation. + + Returns: + action: Action from the env action space. 
+ """ + + if self.local: + self._update_local_policy() + if isinstance(episode_id, (list, tuple)): + actions = { + eid: self.env.get_action(eid, observation[eid]) + for eid in episode_id + } + return actions + else: + return self.env.get_action(episode_id, observation) + else: + return self._send( + { + "command": Commands.GET_ACTION, + "observation": observation, + "episode_id": episode_id, + } + )["action"] + + def log_action( + self, + episode_id: str, + observation: Union[EnvObsType, MultiAgentDict], + action: Union[EnvActionType, MultiAgentDict], + ) -> None: + """Record an observation and (off-policy) action taken. + + Args: + episode_id: Episode id returned from start_episode(). + observation: Current environment observation. + action: Action for the observation. + """ + + if self.local: + self._update_local_policy() + return self.env.log_action(episode_id, observation, action) + + self._send( + { + "command": Commands.LOG_ACTION, + "observation": observation, + "action": action, + "episode_id": episode_id, + } + ) + + def log_returns( + self, + episode_id: str, + reward: float, + info: Union[EnvInfoDict, MultiAgentDict] = None, + multiagent_done_dict: Optional[MultiAgentDict] = None, + ) -> None: + """Record returns from the environment. + + The reward will be attributed to the previous action taken by the + episode. Rewards accumulate until the next action. If no reward is + logged before the next action, a reward of 0.0 is assumed. + + Args: + episode_id: Episode id returned from start_episode(). + reward: Reward from the environment. + info: Extra info dict. + multiagent_done_dict: Multi-agent done information. 
+ """ + + if self.local: + self._update_local_policy() + if multiagent_done_dict is not None: + assert isinstance(reward, dict) + return self.env.log_returns( + episode_id, reward, info, multiagent_done_dict + ) + return self.env.log_returns(episode_id, reward, info) + + self._send( + { + "command": Commands.LOG_RETURNS, + "reward": reward, + "info": info, + "episode_id": episode_id, + "done": multiagent_done_dict, + } + ) + + def end_episode( + self, episode_id: str, observation: Union[EnvObsType, MultiAgentDict] + ) -> None: + """Record the end of an episode. + + Args: + episode_id: Episode id returned from start_episode(). + observation: Current environment observation. + """ + + if self.local: + self._update_local_policy() + return self.env.end_episode(episode_id, observation) + + self._send( + { + "command": Commands.END_EPISODE, + "observation": observation, + "episode_id": episode_id, + } + ) + + def update_policy_weights(self) -> None: + """Query the server for new policy weights, if local inference is enabled.""" + self._update_local_policy(force=True) + + def _send(self, data): + payload = pickle.dumps(data) + + if self.session is None: + response = requests.post(self.address, data=payload) + else: + response = self.session.post(self.address, data=payload) + + if response.status_code != 200: + logger.error("Request failed {}: {}".format(response.text, data)) + response.raise_for_status() + parsed = pickle.loads(response.content) + return parsed + + def _setup_local_rollout_worker(self, update_interval): + self.update_interval = update_interval + self.last_updated = 0 + + logger.info("Querying server for rollout worker settings.") + kwargs = self._send( + { + "command": Commands.GET_WORKER_ARGS, + } + )["worker_args"] + (self.rollout_worker, self.inference_thread) = _create_embedded_rollout_worker( + kwargs, self._send + ) + self.env = self.rollout_worker.env + + def _update_local_policy(self, force=False): + assert self.inference_thread.is_alive() + if ( 
+ self.update_interval + and time.time() - self.last_updated > self.update_interval + ) or force: + logger.info("Querying server for new policy weights.") + resp = self._send( + { + "command": Commands.GET_WEIGHTS, + } + ) + weights = resp["weights"] + global_vars = resp["global_vars"] + logger.info( + "Updating rollout worker weights and global vars {}.".format( + global_vars + ) + ) + self.rollout_worker.set_weights(weights, global_vars) + self.last_updated = time.time() + + +class _LocalInferenceThread(threading.Thread): + """Thread that handles experience generation (worker.sample() loop).""" + + def __init__(self, rollout_worker, send_fn): + super().__init__() + self.daemon = True + self.rollout_worker = rollout_worker + self.send_fn = send_fn + + def run(self): + try: + while True: + logger.info("Generating new batch of experiences.") + samples = self.rollout_worker.sample() + metrics = self.rollout_worker.get_metrics() + if isinstance(samples, MultiAgentBatch): + logger.info( + "Sending batch of {} env steps ({} agent steps) to " + "server.".format(samples.env_steps(), samples.agent_steps()) + ) + else: + logger.info( + "Sending batch of {} steps back to server.".format( + samples.count + ) + ) + self.send_fn( + { + "command": Commands.REPORT_SAMPLES, + "samples": samples, + "metrics": metrics, + } + ) + except Exception as e: + logger.error("Error: inference worker thread died!", e) + + +def _auto_wrap_external(real_env_creator): + """Wrap an environment in the ExternalEnv interface if needed. + + Args: + real_env_creator: Create an env given the env_config. + """ + + def wrapped_creator(env_config): + real_env = real_env_creator(env_config) + if not isinstance(real_env, (ExternalEnv, ExternalMultiAgentEnv)): + logger.info( + "The env you specified is not a supported (sub-)type of " + "ExternalEnv. Attempting to convert it automatically to " + "ExternalEnv." 
def _create_embedded_rollout_worker(kwargs, send_fn):
    """Builds a local RolloutWorker plus a sampling thread feeding `send_fn`.

    Args:
        kwargs: Constructor args for the RolloutWorker.
        send_fn: Callable that ships a request dict to the policy server.

    Returns:
        Tuple of the created RolloutWorker and the started inference thread.
    """
    # The server acts as the input datasource on its side; the embedded
    # worker must instead collect experiences itself -> reset the
    # input/output settings to plain env sampling.
    kwargs = kwargs.copy()
    config = kwargs["config"].copy(copy_frozen=False)
    kwargs["config"] = config
    config.output = None
    config.input_ = "sampler"
    config.input_config = {}

    if config.env is not None:
        # The server specified a real env: only wrap its creator so the
        # result is an ExternalEnv.
        kwargs["env_creator"] = _auto_wrap_external(kwargs["env_creator"])
    else:
        # Expected case: the server has no env. Build a dummy
        # (random-acting) env from the given observation/action spaces.
        from ray.rllib.examples.envs.classes.random_env import (
            RandomEnv,
            RandomMultiAgentEnv,
        )

        env_config = {
            "action_space": config.action_space,
            "observation_space": config.observation_space,
        }
        env_cls = RandomMultiAgentEnv if config.is_multi_agent else RandomEnv
        kwargs["env_creator"] = _auto_wrap_external(lambda _: env_cls(env_config))

    logger.info("Creating rollout worker with kwargs={}".format(kwargs))
    from ray.rllib.evaluation.rollout_worker import RolloutWorker

    embedded_worker = RolloutWorker(**kwargs)

    sampling_thread = _LocalInferenceThread(embedded_worker, send_fn)
    sampling_thread.start()

    return embedded_worker, sampling_thread
+ + For an example, run `examples/envs/external_envs/cartpole_server.py` along + with `examples/envs/external_envs/cartpole_client.py --inference-mode=local|remote`. + + WARNING: This class is not meant to be publicly exposed. Anyone that can + communicate with this server can execute arbitary code on the machine. Use + this with caution, in isolated environments, and at your own risk. + + .. testcode:: + :skipif: True + + import gymnasium as gym + from ray.rllib.algorithms.ppo import PPOConfig + from ray.rllib.env.policy_client import PolicyClient + from ray.rllib.env.policy_server_input import PolicyServerInput + addr, port = ... + config = ( + PPOConfig() + .api_stack( + enable_rl_module_and_learner=False, + enable_env_runner_and_connector_v2=False, + ) + .environment("CartPole-v1") + .offline_data( + input_=lambda ioctx: PolicyServerInput(ioctx, addr, port) + ) + # Run just 1 server (in the Algorithm's EnvRunnerGroup). + .env_runners(num_env_runners=0) + ) + algo = config.build() + while True: + algo.train() + client = PolicyClient( + "localhost:9900", inference_mode="local") + eps_id = client.start_episode() + env = gym.make("CartPole-v1") + obs, info = env.reset() + action = client.get_action(eps_id, obs) + _, reward, _, _, _ = env.step(action) + client.log_returns(eps_id, reward) + client.log_returns(eps_id, reward) + algo.stop() + """ + + @PublicAPI + def __init__( + self, + ioctx: IOContext, + address: str, + port: int, + idle_timeout: float = 3.0, + max_sample_queue_size: int = 20, + ): + """Create a PolicyServerInput. + + This class implements rllib.offline.InputReader, and can be used with + any Algorithm by configuring + + [AlgorithmConfig object] + .env_runners(num_env_runners=0) + .offline_data(input_=lambda ioctx: PolicyServerInput(ioctx, addr, port)) + + Note that by setting num_env_runners: 0, the algorithm will only create one + rollout worker / PolicyServerInput. Clients can connect to the launched + server using rllib.env.PolicyClient. 
You can increase the number of available + connections (ports) by setting num_env_runners to a larger number. The ports + used will then be `port` + the worker's index. + + Args: + ioctx: IOContext provided by RLlib. + address: Server addr (e.g., "localhost"). + port: Server port (e.g., 9900). + max_queue_size: The maximum size for the sample queue. Once full, will + purge (throw away) 50% of all samples, oldest first, and continue. + """ + + self.rollout_worker = ioctx.worker + # Protect ourselves from having a bottleneck on the server (learning) side. + # Once the queue (deque) is full, we throw away 50% (oldest + # samples first) of the samples, warn, and continue. + self.samples_queue = deque(maxlen=max_sample_queue_size) + self.metrics_queue = queue.Queue() + self.idle_timeout = idle_timeout + + # Forwards client-reported metrics directly into the local rollout + # worker. + if self.rollout_worker.sampler is not None: + # This is a bit of a hack since it is patching the get_metrics + # function of the sampler. + + def get_metrics(): + completed = [] + while True: + try: + completed.append(self.metrics_queue.get_nowait()) + except queue.Empty: + break + + return completed + + self.rollout_worker.sampler.get_metrics = get_metrics + else: + # If there is no sampler, act like if there would be one to collect + # metrics from + class MetricsDummySampler(SamplerInput): + """This sampler only maintains a queue to get metrics from.""" + + def __init__(self, metrics_queue): + """Initializes a MetricsDummySampler instance. 
+ + Args: + metrics_queue: A queue of metrics + """ + self.metrics_queue = metrics_queue + + def get_data(self) -> SampleBatchType: + raise NotImplementedError + + def get_extra_batches(self) -> List[SampleBatchType]: + raise NotImplementedError + + def get_metrics(self) -> List[RolloutMetrics]: + """Returns metrics computed on a policy client rollout worker.""" + completed = [] + while True: + try: + completed.append(self.metrics_queue.get_nowait()) + except queue.Empty: + break + return completed + + self.rollout_worker.sampler = MetricsDummySampler(self.metrics_queue) + + # Create a request handler that receives commands from the clients + # and sends data and metrics into the queues. + handler = _make_handler( + self.rollout_worker, self.samples_queue, self.metrics_queue + ) + try: + import time + + time.sleep(1) + HTTPServer.__init__(self, (address, port), handler) + except OSError: + print(f"Creating a PolicyServer on {address}:{port} failed!") + import time + + time.sleep(1) + raise + + logger.info( + "Starting connector server at " f"{self.server_name}:{self.server_port}" + ) + + # Start the serving thread, listening on socket and handling commands. + serving_thread = threading.Thread(name="server", target=self.serve_forever) + serving_thread.daemon = True + serving_thread.start() + + # Start a dummy thread that puts empty SampleBatches on the queue, just + # in case we don't receive anything from clients (or there aren't + # any). The latter would block sample collection entirely otherwise, + # even if other workers' PolicyServerInput receive incoming data from + # actual clients. + heart_beat_thread = threading.Thread( + name="heart-beat", target=self._put_empty_sample_batch_every_n_sec + ) + heart_beat_thread.daemon = True + heart_beat_thread.start() + + @override(InputReader) + def next(self): + # Blocking wait until there is something in the deque. 
+ while len(self.samples_queue) == 0: + time.sleep(0.1) + # Utilize last items first in order to remain as closely as possible + # to operating on-policy. + return self.samples_queue.pop() + + def _put_empty_sample_batch_every_n_sec(self): + # Places an empty SampleBatch every `idle_timeout` seconds onto the + # `samples_queue`. This avoids hanging of all RolloutWorkers parallel + # to this one in case this PolicyServerInput does not have incoming + # data (e.g. no client connected) and the driver algorithm uses parallel + # synchronous sampling (e.g. PPO). + while True: + time.sleep(self.idle_timeout) + self.samples_queue.append(SampleBatch()) + + +def _make_handler(rollout_worker, samples_queue, metrics_queue): + # Only used in remote inference mode. We must create a new rollout worker + # then since the original worker doesn't have the env properly wrapped in + # an ExternalEnv interface. + child_rollout_worker = None + inference_thread = None + lock = threading.Lock() + + def setup_child_rollout_worker(): + nonlocal lock + + with lock: + nonlocal child_rollout_worker + nonlocal inference_thread + + if child_rollout_worker is None: + ( + child_rollout_worker, + inference_thread, + ) = _create_embedded_rollout_worker( + rollout_worker.creation_args(), report_data + ) + child_rollout_worker.set_weights(rollout_worker.get_weights()) + + def report_data(data): + nonlocal child_rollout_worker + + batch = data["samples"] + batch.decompress_if_needed() + samples_queue.append(batch) + # Deque is full -> purge 50% (oldest samples) + if len(samples_queue) == samples_queue.maxlen: + logger.warning( + "PolicyServerInput queue is full! Purging half of the samples (oldest)." 
+ ) + for _ in range(samples_queue.maxlen // 2): + samples_queue.popleft() + for rollout_metric in data["metrics"]: + metrics_queue.put(rollout_metric) + + if child_rollout_worker is not None: + child_rollout_worker.set_weights( + rollout_worker.get_weights(), rollout_worker.get_global_vars() + ) + + class Handler(SimpleHTTPRequestHandler): + def __init__(self, *a, **kw): + super().__init__(*a, **kw) + + def do_POST(self): + content_len = int(self.headers.get("Content-Length"), 0) + raw_body = self.rfile.read(content_len) + parsed_input = pickle.loads(raw_body) + try: + response = self.execute_command(parsed_input) + self.send_response(200) + self.end_headers() + self.wfile.write(pickle.dumps(response)) + except Exception: + self.send_error(500, traceback.format_exc()) + + def execute_command(self, args): + command = args["command"] + response = {} + + # Local inference commands: + if command == Commands.GET_WORKER_ARGS: + logger.info("Sending worker creation args to client.") + response["worker_args"] = rollout_worker.creation_args() + elif command == Commands.GET_WEIGHTS: + logger.info("Sending worker weights to client.") + response["weights"] = rollout_worker.get_weights() + response["global_vars"] = rollout_worker.get_global_vars() + elif command == Commands.REPORT_SAMPLES: + logger.info( + "Got sample batch of size {} from client.".format( + args["samples"].count + ) + ) + report_data(args) + + # Remote inference commands: + elif command == Commands.START_EPISODE: + setup_child_rollout_worker() + assert inference_thread.is_alive() + response["episode_id"] = child_rollout_worker.env.start_episode( + args["episode_id"], args["training_enabled"] + ) + elif command == Commands.GET_ACTION: + assert inference_thread.is_alive() + response["action"] = child_rollout_worker.env.get_action( + args["episode_id"], args["observation"] + ) + elif command == Commands.LOG_ACTION: + assert inference_thread.is_alive() + child_rollout_worker.env.log_action( + args["episode_id"], 
args["observation"], args["action"] + ) + elif command == Commands.LOG_RETURNS: + assert inference_thread.is_alive() + if args["done"]: + child_rollout_worker.env.log_returns( + args["episode_id"], args["reward"], args["info"], args["done"] + ) + else: + child_rollout_worker.env.log_returns( + args["episode_id"], args["reward"], args["info"] + ) + elif command == Commands.END_EPISODE: + assert inference_thread.is_alive() + child_rollout_worker.env.end_episode( + args["episode_id"], args["observation"] + ) + else: + raise ValueError("Unknown command: {}".format(command)) + return response + + return Handler diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/remote_base_env.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/remote_base_env.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e388d50bcfbbb663f22e670b4707849400b6f7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/remote_base_env.py @@ -0,0 +1,462 @@ +import gymnasium as gym +import logging +from typing import Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING + +import ray +from ray.util import log_once +from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN +from ray.rllib.utils.annotations import override, OldAPIStack +from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiEnvDict + +if TYPE_CHECKING: + from ray.rllib.evaluation.rollout_worker import RolloutWorker + +logger = logging.getLogger(__name__) + + +@OldAPIStack +class RemoteBaseEnv(BaseEnv): + """BaseEnv that executes its sub environments as @ray.remote actors. + + This provides dynamic batching of inference as observations are returned + from the remote simulator actors. Both single and multi-agent child envs + are supported, and envs can be stepped synchronously or asynchronously. + + NOTE: This class implicitly assumes that the remote envs are gym.Env's + + You shouldn't need to instantiate this class directly. 
It's automatically + inserted when you use the `remote_worker_envs=True` option in your + Algorithm's config. + """ + + def __init__( + self, + make_env: Callable[[int], EnvType], + num_envs: int, + multiagent: bool, + remote_env_batch_wait_ms: int, + existing_envs: Optional[List[ray.actor.ActorHandle]] = None, + worker: Optional["RolloutWorker"] = None, + restart_failed_sub_environments: bool = False, + ): + """Initializes a RemoteVectorEnv instance. + + Args: + make_env: Callable that produces a single (non-vectorized) env, + given the vector env index as only arg. + num_envs: The number of sub-environments to create for the + vectorization. + multiagent: Whether this is a multiagent env or not. + remote_env_batch_wait_ms: Time to wait for (ray.remote) + sub-environments to have new observations available when + polled. Only when none of the sub-environments is ready, + repeat the `ray.wait()` call until at least one sub-env + is ready. Then return only the observations of the ready + sub-environment(s). + existing_envs: Optional list of already created sub-environments. + These will be used as-is and only as many new sub-envs as + necessary (`num_envs - len(existing_envs)`) will be created. + worker: An optional RolloutWorker that owns the env. This is only + used if `remote_worker_envs` is True in your config and the + `on_sub_environment_created` custom callback needs to be + called on each created actor. + restart_failed_sub_environments: If True and any sub-environment (within + a vectorized env) throws any error during env stepping, the + Sampler will try to restart the faulty sub-environment. This is done + without disturbing the other (still intact) sub-environment and without + the RolloutWorker crashing. + """ + + # Could be creating local or remote envs. 
+ self.make_env = make_env + self.num_envs = num_envs + self.multiagent = multiagent + self.poll_timeout = remote_env_batch_wait_ms / 1000 + self.worker = worker + self.restart_failed_sub_environments = restart_failed_sub_environments + + # Already existing env objects (generated by the RolloutWorker). + existing_envs = existing_envs or [] + + # Whether the given `make_env` callable already returns ActorHandles + # (@ray.remote class instances) or not. + self.make_env_creates_actors = False + + self._observation_space = None + self._action_space = None + + # List of ray actor handles (each handle points to one @ray.remote + # sub-environment). + self.actors: Optional[List[ray.actor.ActorHandle]] = None + + # `self.make_env` already produces Actors: Use it directly. + if len(existing_envs) > 0 and isinstance( + existing_envs[0], ray.actor.ActorHandle + ): + self.make_env_creates_actors = True + self.actors = existing_envs + while len(self.actors) < self.num_envs: + self.actors.append(self._make_sub_env(len(self.actors))) + + # `self.make_env` produces gym.Envs (or children thereof, such + # as MultiAgentEnv): Need to auto-wrap it here. The problem with + # this is that custom methods wil get lost. If you would like to + # keep your custom methods in your envs, you should provide the + # env class directly in your config (w/o tune.register_env()), + # such that your class can directly be made a @ray.remote + # (w/o the wrapping via `_Remote[Multi|Single]AgentEnv`). + # Also, if `len(existing_envs) > 0`, we have to throw those away + # as we need to create ray actors here. + else: + self.actors = [self._make_sub_env(i) for i in range(self.num_envs)] + # Utilize existing envs for inferring observation/action spaces. + if len(existing_envs) > 0: + self._observation_space = existing_envs[0].observation_space + self._action_space = existing_envs[0].action_space + # Have to call actors' remote methods to get observation/action spaces. 
+ else: + self._observation_space, self._action_space = ray.get( + [ + self.actors[0].observation_space.remote(), + self.actors[0].action_space.remote(), + ] + ) + + # Dict mapping object refs (return values of @ray.remote calls), + # whose actual values we are waiting for (via ray.wait in + # `self.poll()`) to their corresponding actor handles (the actors + # that created these return values). + # Call `reset()` on all @ray.remote sub-environment actors. + self.pending: Dict[ray.actor.ActorHandle] = { + a.reset.remote(): a for a in self.actors + } + + @override(BaseEnv) + def poll( + self, + ) -> Tuple[ + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + ]: + + # each keyed by env_id in [0, num_remote_envs) + obs, rewards, terminateds, truncateds, infos = {}, {}, {}, {}, {} + ready = [] + + # Wait for at least 1 env to be ready here. + while not ready: + ready, _ = ray.wait( + list(self.pending), + num_returns=len(self.pending), + timeout=self.poll_timeout, + ) + + # Get and return observations for each of the ready envs + env_ids = set() + for obj_ref in ready: + # Get the corresponding actor handle from our dict and remove the + # object ref (we will call `ray.get()` on it and it will no longer + # be "pending"). + actor = self.pending.pop(obj_ref) + env_id = self.actors.index(actor) + env_ids.add(env_id) + # Get the ready object ref (this may be return value(s) of + # `reset()` or `step()`). + try: + ret = ray.get(obj_ref) + except Exception as e: + # Something happened on the actor during stepping/resetting. + # Restart sub-environment (create new actor; close old one). + if self.restart_failed_sub_environments: + logger.exception(e.args[0]) + self.try_restart(env_id) + # Always return multi-agent data. + # Set the observation to the exception, no rewards, + # terminated[__all__]=True (episode will be discarded anyways), + # no infos. 
+ ret = ( + e, + {}, + {"__all__": True}, + {"__all__": False}, + {}, + ) + # Do not try to restart. Just raise the error. + else: + raise e + + # Our sub-envs are simple Actor-turned gym.Envs or MultiAgentEnvs. + if self.make_env_creates_actors: + rew, terminated, truncated, info = None, None, None, None + if self.multiagent: + if isinstance(ret, tuple): + # Gym >= 0.26: `step()` result: Obs, reward, terminated, + # truncated, info. + if len(ret) == 5: + ob, rew, terminated, truncated, info = ret + # Gym >= 0.26: `reset()` result: Obs and infos. + elif len(ret) == 2: + ob = ret[0] + info = ret[1] + # Gym < 0.26? Something went wrong. + else: + raise AssertionError( + "Your gymnasium.Env seems to NOT return the correct " + "number of return values for `step()` (needs to return" + " 5 values: obs, reward, terminated, truncated and " + "info) or `reset()` (needs to return 2 values: obs and " + "info)!" + ) + # Gym < 0.26: `reset()` result: Only obs. + else: + raise AssertionError( + "Your gymnasium.Env seems to only return a single value " + "upon `reset()`! Must return 2 (obs AND infos)." + ) + else: + if isinstance(ret, tuple): + # `step()` result: Obs, reward, terminated, truncated, info. + if len(ret) == 5: + ob = {_DUMMY_AGENT_ID: ret[0]} + rew = {_DUMMY_AGENT_ID: ret[1]} + terminated = {_DUMMY_AGENT_ID: ret[2], "__all__": ret[2]} + truncated = {_DUMMY_AGENT_ID: ret[3], "__all__": ret[3]} + info = {_DUMMY_AGENT_ID: ret[4]} + # `reset()` result: Obs and infos. + elif len(ret) == 2: + ob = {_DUMMY_AGENT_ID: ret[0]} + info = {_DUMMY_AGENT_ID: ret[1]} + # Gym < 0.26? Something went wrong. + else: + raise AssertionError( + "Your gymnasium.Env seems to NOT return the correct " + "number of return values for `step()` (needs to return" + " 5 values: obs, reward, terminated, truncated and " + "info) or `reset()` (needs to return 2 values: obs and " + "info)!" + ) + # Gym < 0.26? 
+ else: + raise AssertionError( + "Your gymnasium.Env seems to only return a single value " + "upon `reset()`! Must return 2 (obs and infos)." + ) + + # If this is a `reset()` return value, we only have the initial + # observations and infos: Set rewards, terminateds, and truncateds to + # dummy values. + if rew is None: + rew = {agent_id: 0 for agent_id in ob.keys()} + terminated = {"__all__": False} + truncated = {"__all__": False} + + # Our sub-envs are auto-wrapped (by `_RemoteSingleAgentEnv` or + # `_RemoteMultiAgentEnv`) and already behave like multi-agent + # envs. + else: + ob, rew, terminated, truncated, info = ret + obs[env_id] = ob + rewards[env_id] = rew + terminateds[env_id] = terminated + truncateds[env_id] = truncated + infos[env_id] = info + + logger.debug(f"Got obs batch for actors {env_ids}") + return obs, rewards, terminateds, truncateds, infos, {} + + @override(BaseEnv) + def send_actions(self, action_dict: MultiEnvDict) -> None: + for env_id, actions in action_dict.items(): + actor = self.actors[env_id] + # `actor` is a simple single-agent (remote) env, e.g. a gym.Env + # that was made a @ray.remote. + if not self.multiagent and self.make_env_creates_actors: + obj_ref = actor.step.remote(actions[_DUMMY_AGENT_ID]) + # `actor` is already a _RemoteSingleAgentEnv or + # _RemoteMultiAgentEnv wrapper + # (handles the multi-agent action_dict automatically). 
+ else: + obj_ref = actor.step.remote(actions) + self.pending[obj_ref] = actor + + @override(BaseEnv) + def try_reset( + self, + env_id: Optional[EnvID] = None, + *, + seed: Optional[int] = None, + options: Optional[dict] = None, + ) -> Tuple[MultiEnvDict, MultiEnvDict]: + actor = self.actors[env_id] + obj_ref = actor.reset.remote(seed=seed, options=options) + + self.pending[obj_ref] = actor + # Because this env type does not support synchronous reset requests (with + # immediate return value), we return ASYNC_RESET_RETURN here to indicate + # that the reset results will be available via the next `poll()` call. + return ASYNC_RESET_RETURN, ASYNC_RESET_RETURN + + @override(BaseEnv) + def try_restart(self, env_id: Optional[EnvID] = None) -> None: + # Try closing down the old (possibly faulty) sub-env, but ignore errors. + try: + # Close the env on the remote side. + self.actors[env_id].close.remote() + except Exception as e: + if log_once("close_sub_env"): + logger.warning( + "Trying to close old and replaced sub-environment (at vector " + f"index={env_id}), but closing resulted in error:\n{e}" + ) + + # Terminate the actor itself to free up its resources. + self.actors[env_id].__ray_terminate__.remote() + + # Re-create a new sub-environment. 
+ self.actors[env_id] = self._make_sub_env(env_id) + + @override(BaseEnv) + def stop(self) -> None: + if self.actors is not None: + for actor in self.actors: + actor.__ray_terminate__.remote() + + @override(BaseEnv) + def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]: + if as_dict: + return {env_id: actor for env_id, actor in enumerate(self.actors)} + return self.actors + + @property + @override(BaseEnv) + def observation_space(self) -> gym.spaces.Dict: + return self._observation_space + + @property + @override(BaseEnv) + def action_space(self) -> gym.Space: + return self._action_space + + def _make_sub_env(self, idx: Optional[int] = None): + """Re-creates a sub-environment at the new index.""" + + # Our `make_env` creates ray actors directly. + if self.make_env_creates_actors: + sub_env = self.make_env(idx) + if self.worker is not None: + self.worker.callbacks.on_sub_environment_created( + worker=self.worker, + sub_environment=self.actors[idx], + env_context=self.worker.env_context.copy_with_overrides( + vector_index=idx + ), + ) + + # Our `make_env` returns actual envs -> Have to convert them into actors + # using our utility wrapper classes. 
+ else: + + def make_remote_env(i): + logger.info("Launching env {} in remote actor".format(i)) + if self.multiagent: + sub_env = _RemoteMultiAgentEnv.remote(self.make_env, i) + else: + sub_env = _RemoteSingleAgentEnv.remote(self.make_env, i) + + if self.worker is not None: + self.worker.callbacks.on_sub_environment_created( + worker=self.worker, + sub_environment=sub_env, + env_context=self.worker.env_context.copy_with_overrides( + vector_index=i + ), + ) + + return sub_env + + sub_env = make_remote_env(idx) + + return sub_env + + @override(BaseEnv) + def get_agent_ids(self) -> Set[AgentID]: + if self.multiagent: + return ray.get(self.actors[0].get_agent_ids.remote()) + else: + return {_DUMMY_AGENT_ID} + + +@ray.remote(num_cpus=0) +class _RemoteMultiAgentEnv: + """Wrapper class for making a multi-agent env a remote actor.""" + + def __init__(self, make_env, i): + self.env = make_env(i) + self.agent_ids = set() + + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): + obs, info = self.env.reset(seed=seed, options=options) + + # each keyed by agent_id in the env + rew = {} + for agent_id in obs.keys(): + self.agent_ids.add(agent_id) + rew[agent_id] = 0.0 + terminated = {"__all__": False} + truncated = {"__all__": False} + return obs, rew, terminated, truncated, info + + def step(self, action_dict): + return self.env.step(action_dict) + + # Defining these 2 functions that way this information can be queried + # with a call to ray.get(). 
+ def observation_space(self): + return self.env.observation_space + + def action_space(self): + return self.env.action_space + + def get_agent_ids(self) -> Set[AgentID]: + return self.agent_ids + + +@ray.remote(num_cpus=0) +class _RemoteSingleAgentEnv: + """Wrapper class for making a gym env a remote actor.""" + + def __init__(self, make_env, i): + self.env = make_env(i) + + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): + obs_and_info = self.env.reset(seed=seed, options=options) + + obs = {_DUMMY_AGENT_ID: obs_and_info[0]} + info = {_DUMMY_AGENT_ID: obs_and_info[1]} + + rew = {_DUMMY_AGENT_ID: 0.0} + terminated = {"__all__": False} + truncated = {"__all__": False} + return obs, rew, terminated, truncated, info + + def step(self, action): + results = self.env.step(action[_DUMMY_AGENT_ID]) + + obs, rew, terminated, truncated, info = [{_DUMMY_AGENT_ID: x} for x in results] + + terminated["__all__"] = terminated[_DUMMY_AGENT_ID] + truncated["__all__"] = truncated[_DUMMY_AGENT_ID] + + return obs, rew, terminated, truncated, info + + # Defining these 2 functions that way this information can be queried + # with a call to ray.get(). 
+ def observation_space(self): + return self.env.observation_space + + def action_space(self): + return self.env.action_space diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/single_agent_env_runner.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/single_agent_env_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..f2c8fc75f9d77d8c818b1df355e36f707fd72108 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/single_agent_env_runner.py @@ -0,0 +1,853 @@ +from collections import defaultdict +from functools import partial +import logging +import time +from typing import Collection, DefaultDict, List, Optional, Union + +import gymnasium as gym +from gymnasium.wrappers.vector import DictInfoToList + +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig +from ray.rllib.callbacks.callbacks import RLlibCallback +from ray.rllib.callbacks.utils import make_callback +from ray.rllib.core import ( + COMPONENT_ENV_TO_MODULE_CONNECTOR, + COMPONENT_MODULE_TO_ENV_CONNECTOR, + COMPONENT_RL_MODULE, + DEFAULT_AGENT_ID, + DEFAULT_MODULE_ID, +) +from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec +from ray.rllib.env import INPUT_ENV_SPACES +from ray.rllib.env.env_context import EnvContext +from ray.rllib.env.env_runner import EnvRunner, ENV_STEP_FAILURE +from ray.rllib.env.single_agent_episode import SingleAgentEpisode +from ray.rllib.env.utils import _gym_env_creator +from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import override +from ray.rllib.utils.checkpoints import Checkpointable +from ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.framework import get_device +from ray.rllib.utils.metrics import ( + EPISODE_DURATION_SEC_MEAN, + EPISODE_LEN_MAX, + EPISODE_LEN_MEAN, + EPISODE_LEN_MIN, + EPISODE_RETURN_MAX, + EPISODE_RETURN_MEAN, + EPISODE_RETURN_MIN, + NUM_AGENT_STEPS_SAMPLED, + 
NUM_AGENT_STEPS_SAMPLED_LIFETIME, + NUM_ENV_STEPS_SAMPLED, + NUM_ENV_STEPS_SAMPLED_LIFETIME, + NUM_EPISODES, + NUM_EPISODES_LIFETIME, + NUM_MODULE_STEPS_SAMPLED, + NUM_MODULE_STEPS_SAMPLED_LIFETIME, + SAMPLE_TIMER, + TIME_BETWEEN_SAMPLING, + WEIGHTS_SEQ_NO, +) +from ray.rllib.utils.metrics.metrics_logger import MetricsLogger +from ray.rllib.utils.spaces.space_utils import unbatch +from ray.rllib.utils.typing import EpisodeID, ResultDict, StateDict +from ray.tune.registry import ENV_CREATOR, _global_registry +from ray.util.annotations import PublicAPI + +logger = logging.getLogger("ray.rllib") + + +# TODO (sven): As soon as RolloutWorker is no longer supported, make `EnvRunner` itself +# a Checkpointable. Currently, only some of its subclasses are Checkpointables. +@PublicAPI(stability="alpha") +class SingleAgentEnvRunner(EnvRunner, Checkpointable): + """The generic environment runner for the single agent case.""" + + @override(EnvRunner) + def __init__(self, *, config: AlgorithmConfig, **kwargs): + """Initializes a SingleAgentEnvRunner instance. + + Args: + config: An `AlgorithmConfig` object containing all settings needed to + build this `EnvRunner` class. + """ + super().__init__(config=config) + + self.worker_index: int = kwargs.get("worker_index") + self.num_workers: int = kwargs.get("num_workers", self.config.num_env_runners) + self.tune_trial_id: str = kwargs.get("tune_trial_id") + + # Create a MetricsLogger object for logging custom stats. + self.metrics = MetricsLogger() + + # Create our callbacks object. + self._callbacks: List[RLlibCallback] = [ + cls() for cls in force_list(self.config.callbacks_class) + ] + + # Set device. + self._device = get_device( + self.config, + 0 if not self.worker_index else self.config.num_gpus_per_env_runner, + ) + + # Create the vectorized gymnasium env. + self.env: Optional[gym.vector.VectorEnvWrapper] = None + self.num_envs: int = 0 + self.make_env() + + # Create the env-to-module connector pipeline. 
+ self._env_to_module = self.config.build_env_to_module_connector( + self.env, device=self._device + ) + # Cached env-to-module results taken at the end of a `_sample_timesteps()` + # call to make sure the final observation (before an episode cut) gets properly + # processed (and maybe postprocessed and re-stored into the episode). + # For example, if we had a connector that normalizes observations and directly + # re-inserts these new obs back into the episode, the last observation in each + # sample call would NOT be processed, which could be very harmful in cases, + # in which value function bootstrapping of those (truncation) observations is + # required in the learning step. + self._cached_to_module = None + + # Create the RLModule. + self.module: Optional[RLModule] = None + self.make_module() + + # Create the module-to-env connector pipeline. + self._module_to_env = self.config.build_module_to_env_connector(self.env) + + # This should be the default. + self._needs_initial_reset: bool = True + self._episodes: List[Optional[SingleAgentEpisode]] = [ + None for _ in range(self.num_envs) + ] + self._shared_data = None + + self._done_episodes_for_metrics: List[SingleAgentEpisode] = [] + self._ongoing_episodes_for_metrics: DefaultDict[ + EpisodeID, List[SingleAgentEpisode] + ] = defaultdict(list) + self._weights_seq_no: int = 0 + + # Measures the time passed between returning from `sample()` + # and receiving the next `sample()` request from the user. + self._time_after_sampling = None + + @override(EnvRunner) + def sample( + self, + *, + num_timesteps: int = None, + num_episodes: int = None, + explore: bool = None, + random_actions: bool = False, + force_reset: bool = False, + ) -> List[SingleAgentEpisode]: + """Runs and returns a sample (n timesteps or m episodes) on the env(s). + + Args: + num_timesteps: The number of timesteps to sample during this call. + Note that only one of `num_timesteps` or `num_episodes` may be provided.
+ num_episodes: The number of episodes to sample during this call. + Note that only one of `num_timesteps` or `num_episodes` may be provided. + explore: If True, will use the RLModule's `forward_exploration()` + method to compute actions. If False, will use the RLModule's + `forward_inference()` method. If None (default), will use the `explore` + boolean setting from `self.config` passed into this EnvRunner's + constructor. You can change this setting in your config via + `config.env_runners(explore=True|False)`. + random_actions: If True, actions will be sampled randomly (from the action + space of the environment). If False (default), actions or action + distribution parameters are computed by the RLModule. + force_reset: Whether to force-reset all (vector) environments before + sampling. Useful if you would like to collect a clean slate of new + episodes via this call. Note that when sampling n episodes + (`num_episodes != None`), this is fixed to True. + + Returns: + A list of `SingleAgentEpisode` instances, carrying the sampled data. + """ + assert not (num_timesteps is not None and num_episodes is not None) + + # Log time between `sample()` requests. + if self._time_after_sampling is not None: + self.metrics.log_value( + key=TIME_BETWEEN_SAMPLING, + value=time.perf_counter() - self._time_after_sampling, + ) + + # Log current weight seq no. + self.metrics.log_value( + key=WEIGHTS_SEQ_NO, + value=self._weights_seq_no, + window=1, + ) + + with self.metrics.log_time(SAMPLE_TIMER): + # If no execution details are provided, use the config to try to infer the + # desired timesteps/episodes to sample and exploration behavior. + if explore is None: + explore = self.config.explore + if ( + num_timesteps is None + and num_episodes is None + and self.config.batch_mode == "truncate_episodes" + ): + num_timesteps = ( + self.config.get_rollout_fragment_length(self.worker_index) + * self.num_envs + ) + + # Sample n timesteps.
+ if num_timesteps is not None: + samples = self._sample( + num_timesteps=num_timesteps, + explore=explore, + random_actions=random_actions, + force_reset=force_reset, + ) + # Sample m episodes. + elif num_episodes is not None: + samples = self._sample( + num_episodes=num_episodes, + explore=explore, + random_actions=random_actions, + ) + # For complete episodes mode, sample as long as the number of timesteps + # done is smaller than the `train_batch_size`. + else: + samples = self._sample( + num_episodes=self.num_envs, + explore=explore, + random_actions=random_actions, + ) + + # Make the `on_sample_end` callback. + make_callback( + "on_sample_end", + callbacks_objects=self._callbacks, + callbacks_functions=self.config.callbacks_on_sample_end, + kwargs=dict( + env_runner=self, + metrics_logger=self.metrics, + samples=samples, + ), + ) + + self._time_after_sampling = time.perf_counter() + + return samples + + def _sample( + self, + *, + num_timesteps: Optional[int] = None, + num_episodes: Optional[int] = None, + explore: bool, + random_actions: bool = False, + force_reset: bool = False, + ) -> List[SingleAgentEpisode]: + """Helper method to sample n timesteps or m episodes.""" + + done_episodes_to_return: List[SingleAgentEpisode] = [] + + # Have to reset the env (on all vector sub_envs). + if force_reset or num_episodes is not None or self._needs_initial_reset: + episodes = self._episodes = [None for _ in range(self.num_envs)] + shared_data = self._shared_data = {} + self._reset_envs(episodes, shared_data, explore) + # We just reset the env. Don't have to force this again in the next + # call to `self._sample_timesteps()`. + self._needs_initial_reset = False + else: + episodes = self._episodes + shared_data = self._shared_data + + if num_episodes is not None: + self._needs_initial_reset = True + + # Loop through `num_timesteps` timesteps or `num_episodes` episodes. 
+ ts = 0 + eps = 0 + while ( + (ts < num_timesteps) if num_timesteps is not None else (eps < num_episodes) + ): + # Act randomly. + if random_actions: + to_env = { + Columns.ACTIONS: self.env.action_space.sample(), + } + # Compute an action using the RLModule. + else: + # Env-to-module connector (already cached). + to_module = self._cached_to_module + assert to_module is not None + self._cached_to_module = None + + # RLModule forward pass: Explore or not. + if explore: + # Global env steps sampled are (roughly) this EnvRunner's lifetime + # count times the number of env runners in the algo. + global_env_steps_lifetime = ( + self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0) + + ts + ) * (self.config.num_env_runners or 1) + to_env = self.module.forward_exploration( + to_module, t=global_env_steps_lifetime + ) + else: + to_env = self.module.forward_inference(to_module) + + # Module-to-env connector. + to_env = self._module_to_env( + rl_module=self.module, + batch=to_env, + episodes=episodes, + explore=explore, + shared_data=shared_data, + metrics=self.metrics, + ) + + # Extract the (vectorized) actions (to be sent to the env) from the + # module/connector output. Note that these actions are fully ready (e.g. + # already unsquashed/clipped) to be sent to the environment) and might not + # be identical to the actions produced by the RLModule/distribution, which + # are the ones stored permanently in the episode objects. + actions = to_env.pop(Columns.ACTIONS) + actions_for_env = to_env.pop(Columns.ACTIONS_FOR_ENV, actions) + # Try stepping the environment. 
+ results = self._try_env_step(actions_for_env) + if results == ENV_STEP_FAILURE: + return self._sample( + num_timesteps=num_timesteps, + num_episodes=num_episodes, + explore=explore, + random_actions=random_actions, + force_reset=True, + ) + observations, rewards, terminateds, truncateds, infos = results + observations, actions = unbatch(observations), unbatch(actions) + + call_on_episode_start = set() + for env_index in range(self.num_envs): + extra_model_output = {k: v[env_index] for k, v in to_env.items()} + extra_model_output[WEIGHTS_SEQ_NO] = self._weights_seq_no + + # Episode has no data in it yet -> Was just reset and needs to be called + # with its `add_env_reset()` method. + if not self._episodes[env_index].is_reset: + episodes[env_index].add_env_reset( + observation=observations[env_index], + infos=infos[env_index], + ) + call_on_episode_start.add(env_index) + + # Call `add_env_step()` method on episode. + else: + # Only increase ts when we actually stepped (not reset'd as a reset + # does not count as a timestep). + ts += 1 + episodes[env_index].add_env_step( + observation=observations[env_index], + action=actions[env_index], + reward=rewards[env_index], + infos=infos[env_index], + terminated=terminateds[env_index], + truncated=truncateds[env_index], + extra_model_outputs=extra_model_output, + ) + + # Env-to-module connector pass (cache results as we will do the RLModule + # forward pass only in the next `while`-iteration. + if self.module is not None: + self._cached_to_module = self._env_to_module( + episodes=episodes, + explore=explore, + rl_module=self.module, + shared_data=shared_data, + metrics=self.metrics, + ) + + for env_index in range(self.num_envs): + # Call `on_episode_start()` callback (always after reset). + if env_index in call_on_episode_start: + self._make_on_episode_callback( + "on_episode_start", env_index, episodes + ) + # Make the `on_episode_step` callbacks. 
+ else: + self._make_on_episode_callback( + "on_episode_step", env_index, episodes + ) + + # Episode is done. + if episodes[env_index].is_done: + eps += 1 + + # Make the `on_episode_end` callbacks (before finalizing the episode + # object). + self._make_on_episode_callback( + "on_episode_end", env_index, episodes + ) + + # Numpy'ize the episode. + if self.config.episodes_to_numpy: + # Any possibly compress observations. + done_episodes_to_return.append(episodes[env_index].to_numpy()) + # Leave episode as lists of individual (obs, action, etc..) items. + else: + done_episodes_to_return.append(episodes[env_index]) + + # Also early-out if we reach the number of episodes within this + # for-loop. + if eps == num_episodes: + break + + # Create a new episode object with no data in it and execute + # `on_episode_created` callback (before the `env.reset()` call). + episodes[env_index] = SingleAgentEpisode( + observation_space=self.env.single_observation_space, + action_space=self.env.single_action_space, + ) + self._make_on_episode_callback( + "on_episode_created", + env_index, + episodes, + ) + + # Return done episodes ... + self._done_episodes_for_metrics.extend(done_episodes_to_return) + # ... and all ongoing episode chunks. + + # Also, make sure we start new episode chunks (continuing the ongoing episodes + # from the to-be-returned chunks). + ongoing_episodes_to_return = [] + # Only if we are doing individual timesteps: We have to maybe cut an ongoing + # episode and continue building it on the next call to `sample()`. + if num_timesteps is not None: + ongoing_episodes_continuations = [ + eps.cut(len_lookback_buffer=self.config.episode_lookback_horizon) + for eps in self._episodes + ] + + for eps in self._episodes: + # Just started Episodes do not have to be returned. There is no data + # in them anyway. + if eps.t == 0: + continue + eps.validate() + self._ongoing_episodes_for_metrics[eps.id_].append(eps) + + # Numpy'ize the episode. 
+ if self.config.episodes_to_numpy: + # Any possibly compress observations. + ongoing_episodes_to_return.append(eps.to_numpy()) + # Leave episode as lists of individual (obs, action, etc..) items. + else: + ongoing_episodes_to_return.append(eps) + + # Continue collecting into the cut Episode chunks. + self._episodes = ongoing_episodes_continuations + + self._increase_sampled_metrics(ts, len(done_episodes_to_return)) + + # Return collected episode data. + return done_episodes_to_return + ongoing_episodes_to_return + + @override(EnvRunner) + def get_spaces(self): + return { + INPUT_ENV_SPACES: (self.env.observation_space, self.env.action_space), + DEFAULT_MODULE_ID: ( + self._env_to_module.observation_space, + self.env.single_action_space, + ), + } + + @override(EnvRunner) + def get_metrics(self) -> ResultDict: + # Compute per-episode metrics (only on already completed episodes). + for eps in self._done_episodes_for_metrics: + assert eps.is_done + episode_length = len(eps) + episode_return = eps.get_return() + episode_duration_s = eps.get_duration_s() + # Don't forget about the already returned chunks of this episode. + if eps.id_ in self._ongoing_episodes_for_metrics: + for eps2 in self._ongoing_episodes_for_metrics[eps.id_]: + episode_length += len(eps2) + episode_return += eps2.get_return() + episode_duration_s += eps2.get_duration_s() + del self._ongoing_episodes_for_metrics[eps.id_] + + self._log_episode_metrics( + episode_length, episode_return, episode_duration_s + ) + + # Now that we have logged everything, clear cache of done episodes. + self._done_episodes_for_metrics.clear() + + # Return reduced metrics. 
+ return self.metrics.reduce() + + @override(Checkpointable) + def get_state( + self, + components: Optional[Union[str, Collection[str]]] = None, + *, + not_components: Optional[Union[str, Collection[str]]] = None, + **kwargs, + ) -> StateDict: + state = { + NUM_ENV_STEPS_SAMPLED_LIFETIME: ( + self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0) + ), + } + + if self._check_component(COMPONENT_RL_MODULE, components, not_components): + state[COMPONENT_RL_MODULE] = self.module.get_state( + components=self._get_subcomponents(COMPONENT_RL_MODULE, components), + not_components=self._get_subcomponents( + COMPONENT_RL_MODULE, not_components + ), + **kwargs, + ) + state[WEIGHTS_SEQ_NO] = self._weights_seq_no + if self._check_component( + COMPONENT_ENV_TO_MODULE_CONNECTOR, components, not_components + ): + state[COMPONENT_ENV_TO_MODULE_CONNECTOR] = self._env_to_module.get_state() + if self._check_component( + COMPONENT_MODULE_TO_ENV_CONNECTOR, components, not_components + ): + state[COMPONENT_MODULE_TO_ENV_CONNECTOR] = self._module_to_env.get_state() + + return state + + @override(Checkpointable) + def set_state(self, state: StateDict) -> None: + if COMPONENT_ENV_TO_MODULE_CONNECTOR in state: + self._env_to_module.set_state(state[COMPONENT_ENV_TO_MODULE_CONNECTOR]) + if COMPONENT_MODULE_TO_ENV_CONNECTOR in state: + self._module_to_env.set_state(state[COMPONENT_MODULE_TO_ENV_CONNECTOR]) + + # Update the RLModule state. + if COMPONENT_RL_MODULE in state: + # A missing value for WEIGHTS_SEQ_NO or a value of 0 means: Force the + # update. + weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0) + + # Only update the weights, if this is the first synchronization or + # if the weights of this `EnvRunner` lag behind the actual ones.
+ if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no: + rl_module_state = state[COMPONENT_RL_MODULE] + if ( + isinstance(rl_module_state, dict) + and DEFAULT_MODULE_ID in rl_module_state + ): + rl_module_state = rl_module_state[DEFAULT_MODULE_ID] + self.module.set_state(rl_module_state) + + # Update our weights_seq_no, if the new one is > 0. + if weights_seq_no > 0: + self._weights_seq_no = weights_seq_no + + # Update our lifetime counters. + if NUM_ENV_STEPS_SAMPLED_LIFETIME in state: + self.metrics.set_value( + key=NUM_ENV_STEPS_SAMPLED_LIFETIME, + value=state[NUM_ENV_STEPS_SAMPLED_LIFETIME], + reduce="sum", + with_throughput=True, + ) + + @override(Checkpointable) + def get_ctor_args_and_kwargs(self): + return ( + (), # *args + {"config": self.config}, # **kwargs + ) + + @override(Checkpointable) + def get_metadata(self): + metadata = Checkpointable.get_metadata(self) + metadata.update( + { + # TODO (sven): Maybe add serialized (JSON-writable) config here? + } + ) + return metadata + + @override(Checkpointable) + def get_checkpointable_components(self): + return [ + (COMPONENT_RL_MODULE, self.module), + (COMPONENT_ENV_TO_MODULE_CONNECTOR, self._env_to_module), + (COMPONENT_MODULE_TO_ENV_CONNECTOR, self._module_to_env), + ] + + @override(EnvRunner) + def assert_healthy(self): + """Checks that self.__init__() has been completed properly. + + Ensures that the instances has a `MultiRLModule` and an + environment defined. + + Raises: + AssertionError: If the EnvRunner Actor has NOT been properly initialized. + """ + # Make sure, we have built our gym.vector.Env and RLModule properly. + assert self.env and hasattr(self, "module") + + @override(EnvRunner) + def make_env(self) -> None: + """Creates a vectorized gymnasium env and stores it in `self.env`. + + Note that users can change the EnvRunner's config (e.g. change + `self.config.env_config`) and then call this method to create new environments + with the updated configuration. 
+ """ + # If an env already exists, try closing it first (to allow it to properly + # cleanup). + if self.env is not None: + try: + self.env.close() + except Exception as e: + logger.warning( + "Tried closing the existing env, but failed with error: " + f"{e.args[0]}" + ) + + env_ctx = self.config.env_config + if not isinstance(env_ctx, EnvContext): + env_ctx = EnvContext( + env_ctx, + worker_index=self.worker_index, + num_workers=self.num_workers, + remote=self.config.remote_worker_envs, + ) + + # No env provided -> Error. + if not self.config.env: + raise ValueError( + "`config.env` is not provided! You should provide a valid environment " + "to your config through `config.environment([env descriptor e.g. " + "'CartPole-v1'])`." + ) + # Register env for the local context. + # Note, `gym.register` has to be called on each worker. + elif isinstance(self.config.env, str) and _global_registry.contains( + ENV_CREATOR, self.config.env + ): + entry_point = partial( + _global_registry.get(ENV_CREATOR, self.config.env), + env_ctx, + ) + else: + entry_point = partial( + _gym_env_creator, + env_descriptor=self.config.env, + env_context=env_ctx, + ) + gym.register("rllib-single-agent-env-v0", entry_point=entry_point) + vectorize_mode = self.config.gym_env_vectorize_mode + + self.env = DictInfoToList( + gym.make_vec( + "rllib-single-agent-env-v0", + num_envs=self.config.num_envs_per_env_runner, + vectorization_mode=( + vectorize_mode + if isinstance(vectorize_mode, gym.envs.registration.VectorizeMode) + else gym.envs.registration.VectorizeMode(vectorize_mode.lower()) + ), + ) + ) + + self.num_envs: int = self.env.num_envs + assert self.num_envs == self.config.num_envs_per_env_runner + + # Set the flag to reset all envs upon the next `sample()` call. + self._needs_initial_reset = True + + # Call the `on_environment_created` callback. 
+ make_callback( + "on_environment_created", + callbacks_objects=self._callbacks, + callbacks_functions=self.config.callbacks_on_environment_created, + kwargs=dict( + env_runner=self, + metrics_logger=self.metrics, + env=self.env.unwrapped, + env_context=env_ctx, + ), + ) + + @override(EnvRunner) + def make_module(self): + try: + module_spec: RLModuleSpec = self.config.get_rl_module_spec( + env=self.env.unwrapped, spaces=self.get_spaces(), inference_only=True + ) + # Build the module from its spec. + self.module = module_spec.build() + + # Move the RLModule to our device. + # TODO (sven): In order to make this framework-agnostic, we should maybe + # make the RLModule.build() method accept a device OR create an additional + # `RLModule.to()` override. + self.module.to(self._device) + + # If `AlgorithmConfig.get_rl_module_spec()` is not implemented, this env runner + # will not have an RLModule, but might still be usable with random actions. + except NotImplementedError: + self.module = None + + @override(EnvRunner) + def stop(self): + # Close our env object via gymnasium's API. + self.env.close() + + def _reset_envs(self, episodes, shared_data, explore): + # Create n new episodes and make the `on_episode_created` callbacks. + for env_index in range(self.num_envs): + self._new_episode(env_index, episodes) + + # Erase all cached ongoing episodes (these will never be completed and + # would thus never be returned/cleaned by `get_metrics` and cause a memory + # leak). + self._ongoing_episodes_for_metrics.clear() + + # Try resetting the environment. + # TODO (simon): Check, if we need here the seed from the config. + observations, infos = self._try_env_reset() + observations = unbatch(observations) + + # Set initial obs and infos in the episodes. 
+ for env_index in range(self.num_envs): + episodes[env_index].add_env_reset( + observation=observations[env_index], + infos=infos[env_index], + ) + + # Run the env-to-module connector to make sure the reset-obs/infos have + # properly been processed (if applicable). + self._cached_to_module = None + if self.module: + self._cached_to_module = self._env_to_module( + rl_module=self.module, + episodes=episodes, + explore=explore, + shared_data=shared_data, + metrics=self.metrics, + ) + + # Call `on_episode_start()` callbacks (always after reset). + for env_index in range(self.num_envs): + self._make_on_episode_callback("on_episode_start", env_index, episodes) + + def _new_episode(self, env_index, episodes=None): + episodes = episodes if episodes is not None else self._episodes + episodes[env_index] = SingleAgentEpisode( + observation_space=self.env.single_observation_space, + action_space=self.env.single_action_space, + ) + self._make_on_episode_callback("on_episode_created", env_index, episodes) + + def _make_on_episode_callback(self, which: str, idx: int, episodes): + make_callback( + which, + callbacks_objects=self._callbacks, + callbacks_functions=getattr(self.config, f"callbacks_{which}"), + kwargs=dict( + episode=episodes[idx], + env_runner=self, + metrics_logger=self.metrics, + env=self.env.unwrapped, + rl_module=self.module, + env_index=idx, + ), + ) + + def _increase_sampled_metrics(self, num_steps, num_episodes_completed): + # Per sample cycle stats. + self.metrics.log_value( + NUM_ENV_STEPS_SAMPLED, num_steps, reduce="sum", clear_on_reduce=True + ) + self.metrics.log_value( + (NUM_AGENT_STEPS_SAMPLED, DEFAULT_AGENT_ID), + num_steps, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + (NUM_MODULE_STEPS_SAMPLED, DEFAULT_MODULE_ID), + num_steps, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + NUM_EPISODES, + num_episodes_completed, + reduce="sum", + clear_on_reduce=True, + ) + # Lifetime stats. 
+ self.metrics.log_value( + NUM_ENV_STEPS_SAMPLED_LIFETIME, + num_steps, + reduce="sum", + with_throughput=True, + ) + self.metrics.log_value( + (NUM_AGENT_STEPS_SAMPLED_LIFETIME, DEFAULT_AGENT_ID), + num_steps, + reduce="sum", + ) + self.metrics.log_value( + (NUM_MODULE_STEPS_SAMPLED_LIFETIME, DEFAULT_MODULE_ID), + num_steps, + reduce="sum", + ) + self.metrics.log_value( + NUM_EPISODES_LIFETIME, + num_episodes_completed, + reduce="sum", + ) + return num_steps + + def _log_episode_metrics(self, length, ret, sec): + # Log general episode metrics. + # To mimic the old API stack behavior, we'll use `window` here for + # these particular stats (instead of the default EMA). + win = self.config.metrics_num_episodes_for_smoothing + self.metrics.log_value(EPISODE_LEN_MEAN, length, window=win) + self.metrics.log_value(EPISODE_RETURN_MEAN, ret, window=win) + self.metrics.log_value(EPISODE_DURATION_SEC_MEAN, sec, window=win) + # Per-agent returns. + self.metrics.log_value( + ("agent_episode_returns_mean", DEFAULT_AGENT_ID), ret, window=win + ) + # Per-RLModule returns. + self.metrics.log_value( + ("module_episode_returns_mean", DEFAULT_MODULE_ID), ret, window=win + ) + + # For some metrics, log min/max as well. 
+ self.metrics.log_value(EPISODE_LEN_MIN, length, reduce="min", window=win) + self.metrics.log_value(EPISODE_RETURN_MIN, ret, reduce="min", window=win) + self.metrics.log_value(EPISODE_LEN_MAX, length, reduce="max", window=win) + self.metrics.log_value(EPISODE_RETURN_MAX, ret, reduce="max", window=win) + + @Deprecated( + new="SingleAgentEnvRunner.get_state(components='rl_module')", + error=True, + ) + def get_weights(self, *args, **kwargs): + pass + + @Deprecated(new="SingleAgentEnvRunner.set_state()", error=True) + def set_weights(self, *args, **kwargs): + pass diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/single_agent_episode.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/single_agent_episode.py new file mode 100644 index 0000000000000000000000000000000000000000..76556b67d95053f07a4d81bcacbba4b598d8f293 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/single_agent_episode.py @@ -0,0 +1,1862 @@ +import functools +from collections import defaultdict +import numpy as np +import time +import uuid + +import gymnasium as gym +from gymnasium.core import ActType, ObsType +from typing import Any, Dict, List, Optional, SupportsFloat, Union + +from ray.rllib.core.columns import Columns +from ray.rllib.env.utils.infinite_lookback_buffer import InfiniteLookbackBuffer +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.serialization import gym_space_from_dict, gym_space_to_dict +from ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.typing import AgentID, ModuleID +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class SingleAgentEpisode: + """A class representing RL environment episodes for individual agents. + + SingleAgentEpisode stores observations, info dicts, actions, rewards, and all + module outputs (e.g. state outs, action logp, etc..) for an individual agent within + some single-agent or multi-agent environment. 
+ The two main APIs to add data to an ongoing episode are the `add_env_reset()` + and `add_env_step()` methods, which should be called passing the outputs of the + respective gym.Env API calls: `env.reset()` and `env.step()`. + + A SingleAgentEpisode might also only represent a chunk of an episode, which is + useful for cases, in which partial (non-complete episode) sampling is performed + and collected episode data has to be returned before the actual gym.Env episode has + finished (see `SingleAgentEpisode.cut()`). In order to still maintain visibility + onto past experiences within such a "cut" episode, SingleAgentEpisode instances + can have a "lookback buffer" of n timesteps at their beginning (left side), which + solely exists for the purpose of compiling extra data (e.g. "prev. reward"), but + is not considered part of the finished/packaged episode (b/c the data in the + lookback buffer is already part of a previous episode chunk). + + Powerful getter methods, such as `get_observations()` help collect different types + of data from the episode at individual time indices or time ranges, including the + "lookback buffer" range described above. For example, to extract the last 4 rewards + of an ongoing episode, one can call `self.get_rewards(slice(-4, None))` or + `self.rewards[-4:]`. This would work, even if the ongoing SingleAgentEpisode is + a continuation chunk from a much earlier started episode, as long as it has a + lookback buffer size of sufficient size. + + Examples: + + .. testcode:: + + import gymnasium as gym + import numpy as np + + from ray.rllib.env.single_agent_episode import SingleAgentEpisode + + # Construct a new episode (without any data in it yet). + episode = SingleAgentEpisode() + assert len(episode) == 0 + + # Fill the episode with some data (10 timesteps). + env = gym.make("CartPole-v1") + obs, infos = env.reset() + episode.add_env_reset(obs, infos) + + # Even with the initial obs/infos, the episode is still considered len=0. 
+ assert len(episode) == 0 + for _ in range(5): + action = env.action_space.sample() + obs, reward, term, trunc, infos = env.step(action) + episode.add_env_step( + observation=obs, + action=action, + reward=reward, + terminated=term, + truncated=trunc, + infos=infos, + ) + assert len(episode) == 5 + + # We can now access information from the episode via the getter APIs. + + # Get the last 3 rewards (in a batch of size 3). + episode.get_rewards(slice(-3, None)) # same as `episode.rewards[-3:]` + + # Get the most recent action (single item, not batched). + # This works regardless of the action space or whether the episode has + # been numpy'ized or not (see below). + episode.get_actions(-1) # same as episode.actions[-1] + + # Looking back from ts=1, get the previous 4 rewards AND fill with 0.0 + # in case we go over the beginning (ts=0). So we would expect + # [0.0, 0.0, 0.0, r0] to be returned here, where r0 is the very first received + # reward in the episode: + episode.get_rewards(slice(-4, 0), neg_index_as_lookback=True, fill=0.0) + + # Note the use of fill=0.0 here (fill everything that's out of range with this + # value) AND the argument `neg_index_as_lookback=True`, which interprets + # negative indices as being left of ts=0 (e.g. -1 being the timestep before + # ts=0). + + # Assuming we had a complex action space (nested gym.spaces.Dict) with one or + # more elements being Discrete or MultiDiscrete spaces: + # 1) The `fill=...` argument would still work, filling all spaces (Boxes, + # Discrete) with that provided value. + # 2) Setting the flag `one_hot_discrete=True` would convert those discrete + # sub-components automatically into one-hot (or multi-one-hot) tensors. 
+ # This simplifies the task of having to provide the previous 4 (nested and + # partially discrete/multi-discrete) actions for each timestep within a training + # batch, thereby filling timesteps before the episode started with 0.0s and + # one-hot'ing the discrete/multi-discrete components in these actions: + episode = SingleAgentEpisode(action_space=gym.spaces.Dict({ + "a": gym.spaces.Discrete(3), + "b": gym.spaces.MultiDiscrete([2, 3]), + "c": gym.spaces.Box(-1.0, 1.0, (2,)), + })) + + # ... fill episode with data ... + episode.add_env_reset(observation=0) + # ... from a few steps. + episode.add_env_step( + observation=1, + action={"a":0, "b":np.array([1, 2]), "c":np.array([.5, -.5], np.float32)}, + reward=1.0, + ) + + # In your connector + prev_4_a = [] + # Note here that len(episode) does NOT include the lookback buffer. + for ts in range(len(episode)): + prev_4_a.append( + episode.get_actions( + indices=slice(ts - 4, ts), + # Make sure negative indices are interpreted as + # "into lookback buffer" + neg_index_as_lookback=True, + # Zero-out everything even further before the lookback buffer. + fill=0.0, + # Take care of discrete components (get ready as NN input). + one_hot_discrete=True, + ) + ) + + # Finally, convert from list of batch items to a struct (same as action space) + # of batched (numpy) arrays, in which all leafs have B==len(prev_4_a). 
+ from ray.rllib.utils.spaces.space_utils import batch + + prev_4_actions_col = batch(prev_4_a) + """ + + __slots__ = ( + "actions", + "agent_id", + "extra_model_outputs", + "id_", + "infos", + "is_terminated", + "is_truncated", + "module_id", + "multi_agent_episode_id", + "observations", + "rewards", + "t", + "t_started", + "_action_space", + "_last_added_observation", + "_last_added_infos", + "_last_step_time", + "_observation_space", + "_start_time", + "_temporary_timestep_data", + ) + + def __init__( + self, + id_: Optional[str] = None, + *, + observations: Optional[Union[List[ObsType], InfiniteLookbackBuffer]] = None, + observation_space: Optional[gym.Space] = None, + infos: Optional[Union[List[Dict], InfiniteLookbackBuffer]] = None, + actions: Optional[Union[List[ActType], InfiniteLookbackBuffer]] = None, + action_space: Optional[gym.Space] = None, + rewards: Optional[Union[List[SupportsFloat], InfiniteLookbackBuffer]] = None, + terminated: bool = False, + truncated: bool = False, + extra_model_outputs: Optional[Dict[str, Any]] = None, + t_started: Optional[int] = None, + len_lookback_buffer: Union[int, str] = "auto", + agent_id: Optional[AgentID] = None, + module_id: Optional[ModuleID] = None, + multi_agent_episode_id: Optional[int] = None, + ): + """Initializes a SingleAgentEpisode instance. + + This constructor can be called with or without already sampled data, part of + which might then go into the lookback buffer. + + Args: + id_: Unique identifier for this episode. If no ID is provided the + constructor generates a unique hexadecimal code for the id. + observations: Either a list of individual observations from a sampling or + an already instantiated `InfiniteLookbackBuffer` object (possibly + with observation data in it). If a list, will construct the buffer + automatically (given the data and the `len_lookback_buffer` argument). + observation_space: An optional gym.Space, which all individual observations + should abide to. 
If not None and this SingleAgentEpisode is numpy'ized + (via the `self.to_numpy()` method), and data is appended or set, the new + data will be checked for correctness. + infos: Either a list of individual info dicts from a sampling or + an already instantiated `InfiniteLookbackBuffer` object (possibly + with info dicts in it). If a list, will construct the buffer + automatically (given the data and the `len_lookback_buffer` argument). + actions: Either a list of individual actions from a sampling or + an already instantiated `InfiniteLookbackBuffer` object (possibly + with action data in it). If a list, will construct the buffer + automatically (given the data and the `len_lookback_buffer` argument). + action_space: An optional gym.Space, which all individual actions + should abide to. If not None and this SingleAgentEpisode is numpy'ized + (via the `self.to_numpy()` method), and data is appended or set, the new + data will be checked for correctness. + rewards: Either a list of individual rewards from a sampling or + an already instantiated `InfiniteLookbackBuffer` object (possibly + with reward data in it). If a list, will construct the buffer + automatically (given the data and the `len_lookback_buffer` argument). + extra_model_outputs: A dict mapping string keys to either lists of + individual extra model output tensors (e.g. `action_logp` or + `state_outs`) from a sampling or to already instantiated + `InfiniteLookbackBuffer` object (possibly with extra model output data + in it). If mapping is to lists, will construct the buffers automatically + (given the data and the `len_lookback_buffer` argument). + terminated: A boolean indicating, if the episode is already terminated. + truncated: A boolean indicating, if the episode has been truncated. + t_started: Optional. The starting timestep of the episode. The default + is zero. If data is provided, the starting point is from the last + observation onwards (i.e. `t_started = len(observations) - 1`). 
If + this parameter is provided the episode starts at the provided value. + len_lookback_buffer: The size of the (optional) lookback buffers to keep in + front of this Episode for each type of data (observations, actions, + etc..). If larger 0, will interpret the first `len_lookback_buffer` + items in each type of data as NOT part of this actual + episode chunk, but instead serve as "historical" record that may be + viewed and used to derive new data from. For example, it might be + necessary to have a lookback buffer of four if you would like to do + observation frame stacking and your episode has been cut and you are now + operating on a new chunk (continuing from the cut one). Then, for the + first 3 items, you would have to be able to look back into the old + chunk's data. + If `len_lookback_buffer` is "auto" (default), will interpret all + provided data in the constructor as part of the lookback buffers. + agent_id: An optional AgentID indicating which agent this episode belongs + to. This information is stored under `self.agent_id` and only serves + reference purposes. + module_id: An optional ModuleID indicating which RLModule this episode + belongs to. Normally, this information is obtained by querying an + `agent_to_module_mapping_fn` with a given agent ID. This information + is stored under `self.module_id` and only serves reference purposes. + multi_agent_episode_id: An optional EpisodeID of the encapsulating + `MultiAgentEpisode` that this `SingleAgentEpisode` belongs to. + """ + self.id_ = id_ or uuid.uuid4().hex + + self.agent_id = agent_id + self.module_id = module_id + self.multi_agent_episode_id = multi_agent_episode_id + + # Lookback buffer length is not provided. Interpret already given data as + # lookback buffer lengths for all data types. 
+ len_rewards = len(rewards) if rewards is not None else 0 + if len_lookback_buffer == "auto" or len_lookback_buffer > len_rewards: + len_lookback_buffer = len_rewards + + infos = infos or [{} for _ in range(len(observations or []))] + + # Observations: t0 (initial obs) to T. + self._observation_space = None + if isinstance(observations, InfiniteLookbackBuffer): + self.observations = observations + else: + self.observations = InfiniteLookbackBuffer( + data=observations, + lookback=len_lookback_buffer, + ) + self.observation_space = observation_space + # Infos: t0 (initial info) to T. + if isinstance(infos, InfiniteLookbackBuffer): + self.infos = infos + else: + self.infos = InfiniteLookbackBuffer( + data=infos, + lookback=len_lookback_buffer, + ) + # Actions: t1 to T. + self._action_space = None + if isinstance(actions, InfiniteLookbackBuffer): + self.actions = actions + else: + self.actions = InfiniteLookbackBuffer( + data=actions, + lookback=len_lookback_buffer, + ) + self.action_space = action_space + # Rewards: t1 to T. + if isinstance(rewards, InfiniteLookbackBuffer): + self.rewards = rewards + else: + self.rewards = InfiniteLookbackBuffer( + data=rewards, + lookback=len_lookback_buffer, + space=gym.spaces.Box(float("-inf"), float("inf"), (), np.float32), + ) + + # obs[-1] is the final observation in the episode. + self.is_terminated = terminated + # obs[-1] is the last obs in a truncated-by-the-env episode (there will no more + # observations in following chunks for this episode). + self.is_truncated = truncated + + # Extra model outputs, e.g. `action_dist_input` needed in the batch. + self.extra_model_outputs = {} + for k, v in (extra_model_outputs or {}).items(): + if isinstance(v, InfiniteLookbackBuffer): + self.extra_model_outputs[k] = v + else: + # We cannot use the defaultdict's own constructor here as this would + # auto-set the lookback buffer to 0 (there is no data passed to that + # constructor). 
Then, when we manually have to set the data property, + # the lookback buffer would still be (incorrectly) 0. + self.extra_model_outputs[k] = InfiniteLookbackBuffer( + data=v, lookback=len_lookback_buffer + ) + + # The (global) timestep when this episode (possibly an episode chunk) started, + # excluding a possible lookback buffer. + self.t_started = t_started or 0 + # The current (global) timestep in the episode (possibly an episode chunk). + self.t = len(self.rewards) + self.t_started + + # Caches for temporary per-timestep data. May be used to store custom metrics + # from within a callback for the ongoing episode (e.g. render images). + self._temporary_timestep_data = defaultdict(list) + + # Keep timer stats on deltas between steps. + self._start_time = None + self._last_step_time = None + + self._last_added_observation = None + self._last_added_infos = None + + # Validate the episode data thus far. + self.validate() + + def add_env_reset( + self, + observation: ObsType, + infos: Optional[Dict] = None, + ) -> None: + """Adds the initial data (after an `env.reset()`) to the episode. + + This data consists of initial observations and initial infos. + + Args: + observation: The initial observation returned by `env.reset()`. + infos: An (optional) info dict returned by `env.reset()`. + """ + assert not self.is_reset + assert not self.is_done + assert len(self.observations) == 0 + # Assume that this episode is completely empty and has not stepped yet. + # Leave self.t (and self.t_started) at 0. + assert self.t == self.t_started == 0 + + infos = infos or {} + + if self.observation_space is not None: + assert self.observation_space.contains(observation), ( + f"`observation` {observation} does NOT fit SingleAgentEpisode's " + f"observation_space: {self.observation_space}!" + ) + + self.observations.append(observation) + self.infos.append(infos) + + self._last_added_observation = observation + self._last_added_infos = infos + + # Validate our data. 
+ self.validate() + + # Start the timer for this episode. + self._start_time = time.perf_counter() + + def add_env_step( + self, + observation: ObsType, + action: ActType, + reward: SupportsFloat, + infos: Optional[Dict[str, Any]] = None, + *, + terminated: bool = False, + truncated: bool = False, + extra_model_outputs: Optional[Dict[str, Any]] = None, + ) -> None: + """Adds results of an `env.step()` call (including the action) to this episode. + + This data consists of an observation and info dict, an action, a reward, + terminated/truncated flags, and extra model outputs (e.g. action probabilities + or RNN internal state outputs). + + Args: + observation: The next observation received from the environment after(!) + taking `action`. + action: The last action used by the agent during the call to `env.step()`. + reward: The last reward received by the agent after taking `action`. + infos: The last info received from the environment after taking `action`. + terminated: A boolean indicating, if the environment has been + terminated (after taking `action`). + truncated: A boolean indicating, if the environment has been + truncated (after taking `action`). + extra_model_outputs: The last timestep's specific model outputs. + These are normally outputs of an RLModule that were computed along with + `action`, e.g. `action_logp` or `action_dist_inputs`. + """ + # Cannot add data to an already done episode. + assert ( + not self.is_done + ), "The agent is already done: no data can be added to its episode." 
+ + self.observations.append(observation) + self.actions.append(action) + self.rewards.append(reward) + infos = infos or {} + self.infos.append(infos) + self.t += 1 + if extra_model_outputs is not None: + for k, v in extra_model_outputs.items(): + if k not in self.extra_model_outputs: + self.extra_model_outputs[k] = InfiniteLookbackBuffer([v]) + else: + self.extra_model_outputs[k].append(v) + self.is_terminated = terminated + self.is_truncated = truncated + + self._last_added_observation = observation + self._last_added_infos = infos + + # Only check spaces if numpy'ized AND every n timesteps. + if self.is_numpy and self.t % 100: + if self.observation_space is not None: + assert self.observation_space.contains(observation), ( + f"`observation` {observation} does NOT fit SingleAgentEpisode's " + f"observation_space: {self.observation_space}!" + ) + if self.action_space is not None: + assert self.action_space.contains(action), ( + f"`action` {action} does NOT fit SingleAgentEpisode's " + f"action_space: {self.action_space}!" + ) + + # Validate our data. + self.validate() + + # Step time stats. + self._last_step_time = time.perf_counter() + if self._start_time is None: + self._start_time = self._last_step_time + + def validate(self) -> None: + """Validates the episode's data. + + This function ensures that the data stored to a `SingleAgentEpisode` is + in order (e.g. that the correct number of observations, actions, rewards + are there). + """ + assert len(self.observations) == len(self.infos) + if len(self.observations) == 0: + assert len(self.infos) == len(self.rewards) == len(self.actions) == 0 + for k, v in self.extra_model_outputs.items(): + assert len(v) == 0, (k, v, v.data, len(v)) + # Make sure we always have one more obs stored than rewards (and actions) + # due to the reset/last-obs logic of an MDP. 
+ else: + assert ( + len(self.observations) + == len(self.infos) + == len(self.rewards) + 1 + == len(self.actions) + 1 + ), ( + len(self.observations), + len(self.infos), + len(self.rewards), + len(self.actions), + ) + for k, v in self.extra_model_outputs.items(): + assert len(v) == len(self.observations) - 1 + + @property + def is_reset(self) -> bool: + """Returns True if `self.add_env_reset()` has already been called.""" + return len(self.observations) > 0 + + @property + def is_numpy(self) -> bool: + """True, if the data in this episode is already stored as numpy arrays.""" + # If rewards are still a list, return False. + # Otherwise, rewards should already be a (1D) numpy array. + return self.rewards.finalized + + @property + def is_done(self) -> bool: + """Whether the episode is actually done (terminated or truncated). + + A done episode cannot be continued via `self.add_timestep()` or being + concatenated on its right-side with another episode chunk or being + succeeded via `self.create_successor()`. + """ + return self.is_terminated or self.is_truncated + + def to_numpy(self) -> "SingleAgentEpisode": + """Converts this Episode's list attributes to numpy arrays. + + This means in particular that this episodes' lists of (possibly complex) + data (e.g. if we have a dict obs space) will be converted to (possibly complex) + structs, whose leafs are now numpy arrays. Each of these leaf numpy arrays will + have the same length (batch dimension) as the length of the original lists. + + Note that the data under the Columns.INFOS are NEVER numpy'ized and will remain + a list (normally, a list of the original, env-returned dicts). This is due to + the heterogeneous nature of INFOS returned by envs, which would make it unwieldy + to convert this information to numpy arrays. + + After calling this method, no further data may be added to this episode via + the `self.add_env_step()` method. + + Examples: + + .. 
testcode:: + + import numpy as np + + from ray.rllib.env.single_agent_episode import SingleAgentEpisode + + episode = SingleAgentEpisode( + observations=[0, 1, 2, 3], + actions=[1, 2, 3], + rewards=[1, 2, 3], + # Note: terminated/truncated have nothing to do with an episode + # being numpy'ized or not (via the `self.to_numpy()` method)! + terminated=False, + len_lookback_buffer=0, # no lookback; all data is actually "in" episode + ) + # Episode has not been numpy'ized yet. + assert not episode.is_numpy + # We are still operating on lists. + assert episode.get_observations([1]) == [1] + assert episode.get_observations(slice(None, 2)) == [0, 1] + # We can still add data (and even add the terminated=True flag). + episode.add_env_step( + observation=4, + action=4, + reward=4, + terminated=True, + ) + # Still NOT numpy'ized. + assert not episode.is_numpy + + # Numpy'ized the episode. + episode.to_numpy() + assert episode.is_numpy + + # We cannot add data anymore. The following would crash. + # episode.add_env_step(observation=5, action=5, reward=5) + + # Everything is now numpy arrays (with 0-axis of size + # B=[len of requested slice]). + assert isinstance(episode.get_observations([1]), np.ndarray) # B=1 + assert isinstance(episode.actions[0:2], np.ndarray) # B=2 + assert isinstance(episode.rewards[1:4], np.ndarray) # B=3 + + Returns: + This `SingleAgentEpisode` object with the converted numpy data. + """ + + self.observations.finalize() + if len(self) > 0: + self.actions.finalize() + self.rewards.finalize() + for k, v in self.extra_model_outputs.items(): + self.extra_model_outputs[k].finalize() + + return self + + def concat_episode(self, other: "SingleAgentEpisode") -> None: + """Adds the given `other` SingleAgentEpisode to the right side of self. + + In order for this to work, both chunks (`self` and `other`) must fit + together. 
This is checked by the IDs (must be identical), the time step counters + (`self.env_t` must be the same as `episode_chunk.env_t_started`), as well as the + observations/infos at the concatenation boundaries. Also, `self.is_done` must + not be True, meaning `self.is_terminated` and `self.is_truncated` are both + False. + + Args: + other: The other `SingleAgentEpisode` to be concatenated to this one. + + Returns: A `SingleAgentEpisode` instance containing the concatenated data + from both episodes (`self` and `other`). + """ + assert other.id_ == self.id_ + # NOTE (sven): This is what we agreed on. As the replay buffers must be + # able to concatenate. + assert not self.is_done + # Make sure the timesteps match. + assert self.t == other.t_started + # Validate `other`. + other.validate() + + # Make sure, end matches other episode chunk's beginning. + assert np.all(other.observations[0] == self.observations[-1]) + # Pop out our last observations and infos (as these are identical + # to the first obs and infos in the next episode). + self.observations.pop() + self.infos.pop() + + # Extend ourselves. In case, episode_chunk is already terminated and numpy'ized + # we need to convert to lists (as we are ourselves still filling up lists). + self.observations.extend(other.get_observations()) + self.actions.extend(other.get_actions()) + self.rewards.extend(other.get_rewards()) + self.infos.extend(other.get_infos()) + self.t = other.t + + if other.is_terminated: + self.is_terminated = True + elif other.is_truncated: + self.is_truncated = True + + for key in other.extra_model_outputs.keys(): + assert key in self.extra_model_outputs + self.extra_model_outputs[key].extend(other.get_extra_model_outputs(key)) + + # Validate. + self.validate() + + def cut(self, len_lookback_buffer: int = 0) -> "SingleAgentEpisode": + """Returns a successor episode chunk (of len=0) continuing from this Episode. + + The successor will have the same ID as `self`. 
+ If no lookback buffer is requested (len_lookback_buffer=0), the successor's + observations will be the last observation(s) of `self` and its length will + therefore be 0 (no further steps taken yet). If `len_lookback_buffer` > 0, + the returned successor will have `len_lookback_buffer` observations (and + actions, rewards, etc..) taken from the right side (end) of `self`. For example + if `len_lookback_buffer=2`, the returned successor's lookback buffer actions + will be identical to `self.actions[-2:]`. + + This method is useful if you would like to discontinue building an episode + chunk (b/c you have to return it from somewhere), but would like to have a new + episode instance to continue building the actual gym.Env episode at a later + time. Via the `len_lookback_buffer` argument, the continuing chunk (successor) + will still be able to "look back" into this predecessor episode's data (at + least to some extent, depending on the value of `len_lookback_buffer`). + + Args: + len_lookback_buffer: The number of timesteps to take along into the new + chunk as "lookback buffer". A lookback buffer is additional data on + the left side of the actual episode data for visibility purposes + (but without actually being part of the new chunk). For example, if + `self` ends in actions 5, 6, 7, and 8, and we call + `self.cut(len_lookback_buffer=2)`, the returned chunk will have + actions 7 and 8 already in it, but still `t_started`==t==8 (not 7!) and + a length of 0. If there is not enough data in `self` yet to fulfil + the `len_lookback_buffer` request, the value of `len_lookback_buffer` + is automatically adjusted (lowered). + + Returns: + The successor Episode chunk of this one with the same ID and state and the + only observation being the last observation in self. + """ + assert not self.is_done and len_lookback_buffer >= 0 + + # Initialize this chunk with the most recent obs and infos (even if lookback is + # 0). Similar to an initial `env.reset()`. 
+ indices_obs_and_infos = slice(-len_lookback_buffer - 1, None) + indices_rest = ( + slice(-len_lookback_buffer, None) + if len_lookback_buffer > 0 + else slice(None, 0) + ) + + # Erase all temporary timestep data caches in `self`. + self._temporary_timestep_data.clear() + + return SingleAgentEpisode( + # Same ID. + id_=self.id_, + observations=self.get_observations(indices=indices_obs_and_infos), + observation_space=self.observation_space, + infos=self.get_infos(indices=indices_obs_and_infos), + actions=self.get_actions(indices=indices_rest), + action_space=self.action_space, + rewards=self.get_rewards(indices=indices_rest), + extra_model_outputs={ + k: self.get_extra_model_outputs(k, indices_rest) + for k in self.extra_model_outputs.keys() + }, + # Continue with self's current timestep. + t_started=self.t, + # Use the length of the provided data as lookback buffer. + len_lookback_buffer="auto", + ) + + # TODO (sven): Distinguish between: + # - global index: This is the absolute, global timestep whose values always + # start from 0 (at the env reset). So doing get_observations(0, global_ts=True) + # should always return the exact 1st observation (reset obs), no matter what. In + # case we are in an episode chunk and `fill` or a sufficient lookback buffer is + # provided, this should yield a result. Otherwise, error. + # - global index=False -> indices are relative to the chunk start. If a chunk has + # t_started=6 and we ask for index=0, then return observation at timestep 6 + # (t_started). + def get_observations( + self, + indices: Optional[Union[int, List[int], slice]] = None, + *, + neg_index_as_lookback: bool = False, + fill: Optional[Any] = None, + one_hot_discrete: bool = False, + ) -> Any: + """Returns individual observations or batched ranges thereof from this episode. + + Args: + indices: A single int is interpreted as an index, from which to return the + individual observation stored at this index. 
+ A list of ints is interpreted as a list of indices from which to gather + individual observations in a batch of size len(indices). + A slice object is interpreted as a range of observations to be returned. + Thereby, negative indices by default are interpreted as "before the end" + unless the `neg_index_as_lookback=True` option is used, in which case + negative indices are interpreted as "before ts=0", meaning going back + into the lookback buffer. + If None, will return all observations (from ts=0 to the end). + neg_index_as_lookback: If True, negative values in `indices` are + interpreted as "before ts=0", meaning going back into the lookback + buffer. For example, an episode with observations [4, 5, 6, 7, 8, 9], + where [4, 5, 6] is the lookback buffer range (ts=0 item is 7), will + respond to `get_observations(-1, neg_index_as_lookback=True)` + with `6` and to + `get_observations(slice(-2, 1), neg_index_as_lookback=True)` with + `[5, 6, 7]`. + fill: An optional value to use for filling up the returned results at + the boundaries. This filling only happens if the requested index range's + start/stop boundaries exceed the episode's boundaries (including the + lookback buffer on the left side). This comes in very handy, if users + don't want to worry about reaching such boundaries and want to zero-pad. + For example, an episode with observations [10, 11, 12, 13, 14] and + lookback buffer size of 2 (meaning observations `10` and `11` are part + of the lookback buffer) will respond to + `get_observations(slice(-7, -2), fill=0.0)` with + `[0.0, 0.0, 10, 11, 12]`. + one_hot_discrete: If True, will return one-hot vectors (instead of + int-values) for those sub-components of a (possibly complex) observation + space that are Discrete or MultiDiscrete. Note that if `fill=0` and the + requested `indices` are out of the range of our data, the returned + one-hot vectors will actually be zero-hot (all slots zero). + + Examples: + + .. 
testcode:: + + import gymnasium as gym + + from ray.rllib.env.single_agent_episode import SingleAgentEpisode + from ray.rllib.utils.test_utils import check + + episode = SingleAgentEpisode( + # Discrete(4) observations (ints between 0 and 4 (excl.)) + observation_space=gym.spaces.Discrete(4), + observations=[0, 1, 2, 3], + actions=[1, 2, 3], rewards=[1, 2, 3], # <- not relevant for this demo + len_lookback_buffer=0, # no lookback; all data is actually "in" episode + ) + # Plain usage (`indices` arg only). + check(episode.get_observations(-1), 3) + check(episode.get_observations(0), 0) + check(episode.get_observations([0, 2]), [0, 2]) + check(episode.get_observations([-1, 0]), [3, 0]) + check(episode.get_observations(slice(None, 2)), [0, 1]) + check(episode.get_observations(slice(-2, None)), [2, 3]) + # Using `fill=...` (requesting slices beyond the boundaries). + check(episode.get_observations(slice(-6, -2), fill=-9), [-9, -9, 0, 1]) + check(episode.get_observations(slice(2, 5), fill=-7), [2, 3, -7]) + # Using `one_hot_discrete=True`. + check(episode.get_observations(2, one_hot_discrete=True), [0, 0, 1, 0]) + check(episode.get_observations(3, one_hot_discrete=True), [0, 0, 0, 1]) + check(episode.get_observations( + slice(0, 3), + one_hot_discrete=True, + ), [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]) + # Special case: Using `fill=0.0` AND `one_hot_discrete=True`. + check(episode.get_observations( + -1, + neg_index_as_lookback=True, # -1 means one left of ts=0 + fill=0.0, + one_hot_discrete=True, + ), [0, 0, 0, 0]) # <- all 0s one-hot tensor (note difference to [1 0 0 0]!) + + Returns: + The collected observations. + As a 0-axis batch, if there are several `indices` or a list of exactly one + index provided OR `indices` is a slice object. + As single item (B=0 -> no additional 0-axis) if `indices` is a single int. 
+ """ + return self.observations.get( + indices=indices, + neg_index_as_lookback=neg_index_as_lookback, + fill=fill, + one_hot_discrete=one_hot_discrete, + ) + + def get_infos( + self, + indices: Optional[Union[int, List[int], slice]] = None, + *, + neg_index_as_lookback: bool = False, + fill: Optional[Any] = None, + ) -> Any: + """Returns individual info dicts or list (ranges) thereof from this episode. + + Args: + indices: A single int is interpreted as an index, from which to return the + individual info dict stored at this index. + A list of ints is interpreted as a list of indices from which to gather + individual info dicts in a list of size len(indices). + A slice object is interpreted as a range of info dicts to be returned. + Thereby, negative indices by default are interpreted as "before the end" + unless the `neg_index_as_lookback=True` option is used, in which case + negative indices are interpreted as "before ts=0", meaning going back + into the lookback buffer. + If None, will return all infos (from ts=0 to the end). + neg_index_as_lookback: If True, negative values in `indices` are + interpreted as "before ts=0", meaning going back into the lookback + buffer. For example, an episode with infos + [{"l":4}, {"l":5}, {"l":6}, {"a":7}, {"b":8}, {"c":9}], where the + first 3 items are the lookback buffer (ts=0 item is {"a": 7}), will + respond to `get_infos(-1, neg_index_as_lookback=True)` with + `{"l":6}` and to + `get_infos(slice(-2, 1), neg_index_as_lookback=True)` with + `[{"l":5}, {"l":6}, {"a":7}]`. + fill: An optional value to use for filling up the returned results at + the boundaries. This filling only happens if the requested index range's + start/stop boundaries exceed the episode's boundaries (including the + lookback buffer on the left side). This comes in very handy, if users + don't want to worry about reaching such boundaries and want to + auto-fill. 
For example, an episode with infos + [{"l":10}, {"l":11}, {"a":12}, {"b":13}, {"c":14}] and lookback buffer + size of 2 (meaning infos {"l":10}, {"l":11} are part of the lookback + buffer) will respond to `get_infos(slice(-7, -2), fill={"o": 0.0})` + with `[{"o":0.0}, {"o":0.0}, {"l":10}, {"l":11}, {"a":12}]`. + + Examples: + + .. testcode:: + + from ray.rllib.env.single_agent_episode import SingleAgentEpisode + + episode = SingleAgentEpisode( + infos=[{"a":0}, {"b":1}, {"c":2}, {"d":3}], + # The following is needed, but not relevant for this demo. + observations=[0, 1, 2, 3], actions=[1, 2, 3], rewards=[1, 2, 3], + len_lookback_buffer=0, # no lookback; all data is actually "in" episode + ) + # Plain usage (`indices` arg only). + episode.get_infos(-1) # {"d":3} + episode.get_infos(0) # {"a":0} + episode.get_infos([0, 2]) # [{"a":0},{"c":2}] + episode.get_infos([-1, 0]) # [{"d":3},{"a":0}] + episode.get_infos(slice(None, 2)) # [{"a":0},{"b":1}] + episode.get_infos(slice(-2, None)) # [{"c":2},{"d":3}] + # Using `fill=...` (requesting slices beyond the boundaries). + # TODO (sven): This would require a space being provided. Maybe we can + # skip this check for infos, which don't have a space anyways. + # episode.get_infos(slice(-5, -3), fill={"o":-1}) # [{"o":-1},{"a":0}] + # episode.get_infos(slice(3, 5), fill={"o":-2}) # [{"d":3},{"o":-2}] + + Returns: + The collected info dicts. + As a 0-axis batch, if there are several `indices` or a list of exactly one + index provided OR `indices` is a slice object. + As single item (B=0 -> no additional 0-axis) if `indices` is a single int. 
+ """ + return self.infos.get( + indices=indices, + neg_index_as_lookback=neg_index_as_lookback, + fill=fill, + ) + + def get_actions( + self, + indices: Optional[Union[int, List[int], slice]] = None, + *, + neg_index_as_lookback: bool = False, + fill: Optional[Any] = None, + one_hot_discrete: bool = False, + ) -> Any: + """Returns individual actions or batched ranges thereof from this episode. + + Args: + indices: A single int is interpreted as an index, from which to return the + individual action stored at this index. + A list of ints is interpreted as a list of indices from which to gather + individual actions in a batch of size len(indices). + A slice object is interpreted as a range of actions to be returned. + Thereby, negative indices by default are interpreted as "before the end" + unless the `neg_index_as_lookback=True` option is used, in which case + negative indices are interpreted as "before ts=0", meaning going back + into the lookback buffer. + If None, will return all actions (from ts=0 to the end). + neg_index_as_lookback: If True, negative values in `indices` are + interpreted as "before ts=0", meaning going back into the lookback + buffer. For example, an episode with actions [4, 5, 6, 7, 8, 9], where + [4, 5, 6] is the lookback buffer range (ts=0 item is 7), will respond + to `get_actions(-1, neg_index_as_lookback=True)` with `6` and + to `get_actions(slice(-2, 1), neg_index_as_lookback=True)` with + `[5, 6, 7]`. + fill: An optional value to use for filling up the returned results at + the boundaries. This filling only happens if the requested index range's + start/stop boundaries exceed the episode's boundaries (including the + lookback buffer on the left side). This comes in very handy, if users + don't want to worry about reaching such boundaries and want to zero-pad. 
+ For example, an episode with actions [10, 11, 12, 13, 14] and + lookback buffer size of 2 (meaning actions `10` and `11` are part + of the lookback buffer) will respond to + `get_actions(slice(-7, -2), fill=0.0)` with `[0.0, 0.0, 10, 11, 12]`. + one_hot_discrete: If True, will return one-hot vectors (instead of + int-values) for those sub-components of a (possibly complex) action + space that are Discrete or MultiDiscrete. Note that if `fill=0` and the + requested `indices` are out of the range of our data, the returned + one-hot vectors will actually be zero-hot (all slots zero). + + Examples: + + .. testcode:: + + import gymnasium as gym + from ray.rllib.env.single_agent_episode import SingleAgentEpisode + + episode = SingleAgentEpisode( + # Discrete(4) actions (ints between 0 and 4 (excl.)) + action_space=gym.spaces.Discrete(4), + actions=[1, 2, 3], + observations=[0, 1, 2, 3], rewards=[1, 2, 3], # <- not relevant here + len_lookback_buffer=0, # no lookback; all data is actually "in" episode + ) + # Plain usage (`indices` arg only). + episode.get_actions(-1) # 3 + episode.get_actions(0) # 1 + episode.get_actions([0, 2]) # [1, 3] + episode.get_actions([-1, 0]) # [3, 1] + episode.get_actions(slice(None, 2)) # [1, 2] + episode.get_actions(slice(-2, None)) # [2, 3] + # Using `fill=...` (requesting slices beyond the boundaries). + episode.get_actions(slice(-5, -2), fill=-9) # [-9, -9, 1, 2] + episode.get_actions(slice(1, 5), fill=-7) # [2, 3, -7, -7] + # Using `one_hot_discrete=True`. + episode.get_actions(1, one_hot_discrete=True) # [0 0 1 0] (action=2) + episode.get_actions(2, one_hot_discrete=True) # [0 0 0 1] (action=3) + episode.get_actions( + slice(0, 2), + one_hot_discrete=True, + ) # [[0 1 0 0], [0 0 0 1]] (actions=1 and 3) + # Special case: Using `fill=0.0` AND `one_hot_discrete=True`. 
+ episode.get_actions( + -1, + neg_index_as_lookback=True, # -1 means one left of ts=0 + fill=0.0, + one_hot_discrete=True, + ) # [0 0 0 0] <- all 0s one-hot tensor (note difference to [1 0 0 0]!) + + Returns: + The collected actions. + As a 0-axis batch, if there are several `indices` or a list of exactly one + index provided OR `indices` is a slice object. + As single item (B=0 -> no additional 0-axis) if `indices` is a single int. + """ + return self.actions.get( + indices=indices, + neg_index_as_lookback=neg_index_as_lookback, + fill=fill, + one_hot_discrete=one_hot_discrete, + ) + + def get_rewards( + self, + indices: Optional[Union[int, List[int], slice]] = None, + *, + neg_index_as_lookback: bool = False, + fill: Optional[float] = None, + ) -> Any: + """Returns individual rewards or batched ranges thereof from this episode. + + Args: + indices: A single int is interpreted as an index, from which to return the + individual reward stored at this index. + A list of ints is interpreted as a list of indices from which to gather + individual rewards in a batch of size len(indices). + A slice object is interpreted as a range of rewards to be returned. + Thereby, negative indices by default are interpreted as "before the end" + unless the `neg_index_as_lookback=True` option is used, in which case + negative indices are interpreted as "before ts=0", meaning going back + into the lookback buffer. + If None, will return all rewards (from ts=0 to the end). + neg_index_as_lookback: Negative values in `indices` are interpreted as + as "before ts=0", meaning going back into the lookback buffer. + For example, an episode with rewards [4, 5, 6, 7, 8, 9], where + [4, 5, 6] is the lookback buffer range (ts=0 item is 7), will respond + to `get_rewards(-1, neg_index_as_lookback=True)` with `6` and + to `get_rewards(slice(-2, 1), neg_index_as_lookback=True)` with + `[5, 6, 7]`. + fill: An optional float value to use for filling up the returned results at + the boundaries. 
def get_extra_model_outputs(
    self,
    key: str,
    indices: Optional[Union[int, List[int], slice]] = None,
    *,
    neg_index_as_lookback: bool = False,
    fill: Optional[Any] = None,
) -> Any:
    """Returns extra model outputs (under given key) from this episode.

    For example, with `extra_model_outputs={"mo": [1, 2, 3]}` (and no
    lookback), `get_extra_model_outputs("mo", 1)` returns 2 and
    `get_extra_model_outputs("mo", [-1, 0])` returns [3, 1].

    Args:
        key: The key within `self.extra_model_outputs` to extract data for.
        indices: A single int (individual item), a list of ints (batch of size
            `len(indices)`), a slice object (range), or None for all items
            from ts=0 to the end. Negative entries count from the end, unless
            `neg_index_as_lookback=True`.
        neg_index_as_lookback: If True, negative `indices` are interpreted as
            "before ts=0", i.e. they reach into the lookback buffer.
        fill: Optional fill value for positions of the requested range that
            lie outside the episode's boundaries.

    Returns:
        The requested extra_model_outputs[`key`] data; batched (0-axis) for
        list/slice `indices`, a single item for a single int.
    """
    value = self.extra_model_outputs[key]
    # Expected case: the stored value is an `InfiniteLookbackBuffer`.
    if isinstance(value, InfiniteLookbackBuffer):
        return value.get(
            indices=indices,
            neg_index_as_lookback=neg_index_as_lookback,
            fill=fill,
        )
    # Guard against raw (non-buffer) values written directly into
    # `self.extra_model_outputs` (users should go through
    # `set_extra_model_outputs()` instead). NOTE: under `python -O` this
    # assert is stripped and the best-effort fallback below runs, wrapping
    # the raw value in a temporary buffer.
    assert False
    return InfiniteLookbackBuffer(value).get(
        indices, fill=fill, neg_index_as_lookback=neg_index_as_lookback
    )

def set_observations(
    self,
    *,
    new_data,
    at_indices: Optional[Union[int, List[int], slice]] = None,
    neg_index_as_lookback: bool = False,
) -> None:
    """Overwrites all or some of this episode's observations with `new_data`.

    Observations are managed by an `InfiniteLookbackBuffer`; use this method
    (rather than writing to the buffer directly) when postprocessing needs to
    rewrite a slice or the entirety of the stored observations.

    Args:
        new_data: Replacement observation data. A list of individual
            observations for a not-yet-numpy'ized episode; otherwise a
            (possibly complex) struct matching the observation space whose
            leaf batch size equals the size of the targeted segment.
        at_indices: A single int, a list of ints, or a slice selecting which
            positions to overwrite. Negative entries count from the end,
            unless `neg_index_as_lookback=True`.
        neg_index_as_lookback: If True, negative `at_indices` are interpreted
            as "before ts=0", i.e. they reach into the lookback buffer.

    Raises:
        IndexError: If `at_indices` does not match the size of `new_data`.
    """
    self.observations.set(
        new_data=new_data,
        at_indices=at_indices,
        neg_index_as_lookback=neg_index_as_lookback,
    )
def set_actions(
    self,
    *,
    new_data,
    at_indices: Optional[Union[int, List[int], slice]] = None,
    neg_index_as_lookback: bool = False,
) -> None:
    """Overwrites all or some of this episode's actions with `new_data`.

    Actions are managed by an `InfiniteLookbackBuffer`; use this method
    (rather than writing to the buffer directly) when postprocessing needs to
    rewrite a slice or the entirety of the stored actions.

    Args:
        new_data: Replacement action data. A list of individual actions for a
            not-yet-numpy'ized episode; otherwise a (possibly complex) struct
            matching the action space whose leaf batch size equals the size of
            the targeted segment.
        at_indices: A single int, a list of ints, or a slice selecting which
            positions to overwrite. Negative entries count from the end,
            unless `neg_index_as_lookback=True`.
        neg_index_as_lookback: If True, negative `at_indices` are interpreted
            as "before ts=0", i.e. they reach into the lookback buffer.

    Raises:
        IndexError: If `at_indices` does not match the size of `new_data`.
    """
    self.actions.set(
        new_data=new_data,
        at_indices=at_indices,
        neg_index_as_lookback=neg_index_as_lookback,
    )

def set_rewards(
    self,
    *,
    new_data,
    at_indices: Optional[Union[int, List[int], slice]] = None,
    neg_index_as_lookback: bool = False,
) -> None:
    """Overwrites all or some of this episode's rewards with `new_data`.

    Rewards are managed by an `InfiniteLookbackBuffer`; use this method
    (rather than writing to the buffer directly) when postprocessing needs to
    rewrite a slice or the entirety of the stored rewards.

    Args:
        new_data: Replacement reward data. A list of individual rewards for a
            not-yet-numpy'ized episode; otherwise a np.ndarray whose length
            equals the size of the targeted segment.
        at_indices: A single int, a list of ints, or a slice selecting which
            positions to overwrite. Negative entries count from the end,
            unless `neg_index_as_lookback=True`.
        neg_index_as_lookback: If True, negative `at_indices` are interpreted
            as "before ts=0", i.e. they reach into the lookback buffer.

    Raises:
        IndexError: If `at_indices` does not match the size of `new_data`.
    """
    self.rewards.set(
        new_data=new_data,
        at_indices=at_indices,
        neg_index_as_lookback=neg_index_as_lookback,
    )
+ """ + self.rewards.set( + new_data=new_data, + at_indices=at_indices, + neg_index_as_lookback=neg_index_as_lookback, + ) + + def set_extra_model_outputs( + self, + *, + key, + new_data, + at_indices: Optional[Union[int, List[int], slice]] = None, + neg_index_as_lookback: bool = False, + ) -> None: + """Overwrites all or some of this Episode's extra model outputs with `new_data`. + + Note that an episode's `extra_model_outputs` data cannot be written to directly + as it is managed by a `InfiniteLookbackBuffer` object. Normally, individual, + current `extra_model_output` values are added to the episode either by calling + `self.add_env_step` or more directly (and manually) via + `self.extra_model_outputs[key].append|extend()`. However, for certain + postprocessing steps, the entirety (or a slice) of an episode's + `extra_model_outputs` might have to be rewritten or a new key (a new type of + `extra_model_outputs`) must be inserted, which is when + `self.set_extra_model_outputs()` should be used. + + Args: + key: The `key` within `self.extra_model_outputs` to override data on or + to insert as a new key into `self.extra_model_outputs`. + new_data: The new data to overwrite existing data with. + This may be a list of individual reward(s) in case this episode + is still not numpy'ized yet. In case this episode has already been + numpy'ized, this should be a np.ndarray with a length exactly + the size of the to-be-overwritten slice or segment (provided by + `at_indices`). + at_indices: A single int is interpreted as one index, which to overwrite + with `new_data` (which is expected to be a single reward). + A list of ints is interpreted as a list of indices, all of which to + overwrite with `new_data` (which is expected to be of the same size + as `len(at_indices)`). + A slice object is interpreted as a range of indices to be overwritten + with `new_data` (which is expected to be of the same size as the + provided slice). 
+ Thereby, negative indices by default are interpreted as "before the end" + unless the `neg_index_as_lookback=True` option is used, in which case + negative indices are interpreted as "before ts=0", meaning going back + into the lookback buffer. + neg_index_as_lookback: If True, negative values in `at_indices` are + interpreted as "before ts=0", meaning going back into the lookback + buffer. For example, an episode with + rewards = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the + lookback buffer range (ts=0 item is 7), will handle a call to + `set_rewards(individual_reward, -1, + neg_index_as_lookback=True)` by overwriting the value of 6 in our + rewards buffer with the provided "individual_reward". + + Raises: + IndexError: If the provided `at_indices` do not match the size of + `new_data`. + """ + # Record already exists -> Set existing record's data to new values. + assert key in self.extra_model_outputs + self.extra_model_outputs[key].set( + new_data=new_data, + at_indices=at_indices, + neg_index_as_lookback=neg_index_as_lookback, + ) + + def add_temporary_timestep_data(self, key: str, data: Any) -> None: + """Temporarily adds (until `to_numpy()` called) per-timestep data to self. + + The given `data` is appended to a list (`self._temporary_timestep_data`), which + is cleared upon calling `self.to_numpy()`. To get the thus-far accumulated + temporary timestep data for a certain key, use the `get_temporary_timestep_data` + API. + Note that the size of the per timestep list is NOT checked or validated against + the other, non-temporary data in this episode (like observations). + + Args: + key: The key under which to find the list to append `data` to. If `data` is + the first data to be added for this key, start a new list. + data: The data item (representing a single timestep) to be stored. + """ + if self.is_numpy: + raise ValueError( + "Cannot use the `add_temporary_timestep_data` API on an already " + f"numpy'ized {type(self).__name__}!" 
def get_temporary_timestep_data(self, key: str) -> List[Any]:
    """Returns the list of temporarily stored data items under `key`.

    All temporary timestep data is erased when `self.to_numpy()` is called.

    Raises:
        ValueError: If this episode has already been numpy'ized.
        KeyError: If `key` is unknown.
    """
    if self.is_numpy:
        raise ValueError(
            "Cannot use the `get_temporary_timestep_data` API on an already "
            f"numpy'ized {type(self).__name__}! All temporary data has been erased "
            f"upon `{type(self).__name__}.to_numpy()`."
        )
    try:
        return self._temporary_timestep_data[key]
    except KeyError:
        raise KeyError(f"Key {key} not found in temporary timestep data!")

def slice(
    self,
    slice_: slice,
    *,
    len_lookback_buffer: Optional[int] = None,
) -> "SingleAgentEpisode":
    """Returns a slice of this episode, given a slice object.

    For a simple single-agent episode with observations o0..o4 and actions
    a1..a4 (len 4), `self.slice(slice(1, 3))` returns a new
    SingleAgentEpisode with observations o1, o2, o3 and actions a2, a3 (an
    episode always holds one more observation than actions/rewards, due to
    the initial reset observation).

    Args:
        slice_: The slice to apply. It addresses the episode proper (not the
            lookback buffer); a lookback buffer is prepended automatically to
            the returned slice.
        len_lookback_buffer: If not None, the returned slice tries to carry
            this many lookback timesteps (if available). If None, it tries to
            match this episode's current lookback size.

    Returns:
        A new SingleAgentEpisode representing the requested slice.
    """
    # Normalize start/stop/step into concrete, non-negative values.
    start = 0 if slice_.start is None else slice_.start
    if start < 0:
        start = len(self) + start
    stop = len(self) if slice_.stop is None else slice_.stop
    if stop < 0:
        stop = len(self) + stop
    step = 1 if slice_.step is None else slice_.step

    # Only keep terminated/truncated flags if the slice reaches the episode's
    # very end.
    keep_done = stop == len(self)
    t_started = self.t_started + start

    def _capped_lookback(current: int) -> int:
        # Requested lookback, capped by how much data the source buffer
        # actually has available left of `start`.
        lb = len_lookback_buffer if len_lookback_buffer is not None else current
        if start >= 0 and start - lb < 0 and current < (lb - start):
            lb = current + start
        return lb

    lb = _capped_lookback(self.observations.lookback)
    observations = InfiniteLookbackBuffer(
        data=self.get_observations(
            slice(start - lb, stop + 1, step),
            neg_index_as_lookback=True,
        ),
        lookback=lb,
        space=self.observation_space,
    )

    lb = _capped_lookback(self.infos.lookback)
    infos = InfiniteLookbackBuffer(
        data=self.get_infos(
            slice(start - lb, stop + 1, step),
            neg_index_as_lookback=True,
        ),
        lookback=lb,
    )

    lb = _capped_lookback(self.actions.lookback)
    actions = InfiniteLookbackBuffer(
        data=self.get_actions(
            slice(start - lb, stop, step),
            neg_index_as_lookback=True,
        ),
        lookback=lb,
        space=self.action_space,
    )

    lb = _capped_lookback(self.rewards.lookback)
    rewards = InfiniteLookbackBuffer(
        data=self.get_rewards(
            slice(start - lb, stop, step),
            neg_index_as_lookback=True,
        ),
        lookback=lb,
    )

    extra_model_outputs = {}
    for key, buf in self.extra_model_outputs.items():
        lb = _capped_lookback(buf.lookback)
        extra_model_outputs[key] = InfiniteLookbackBuffer(
            data=self.get_extra_model_outputs(
                key=key,
                indices=slice(start - lb, stop, step),
                neg_index_as_lookback=True,
            ),
            lookback=lb,
        )

    return SingleAgentEpisode(
        id_=self.id_,
        observations=observations,
        observation_space=self.observation_space,
        infos=infos,
        actions=actions,
        action_space=self.action_space,
        rewards=rewards,
        extra_model_outputs=extra_model_outputs,
        terminated=(self.is_terminated if keep_done else False),
        truncated=(self.is_truncated if keep_done else False),
        t_started=t_started,
    )

def get_data_dict(self):
    """Converts this SingleAgentEpisode into a dict mapping str keys to data.

    Keys used: Columns.EPS_ID, T, OBS, INFOS, ACTIONS, REWARDS, TERMINATEDS,
    TRUNCATEDS, plus those in `self.extra_model_outputs`.
    """
    t = list(range(self.t_started, self.t))
    terminateds = [False] * (len(self) - 1) + [self.is_terminated]
    truncateds = [False] * (len(self) - 1) + [self.is_truncated]
    eps_id = [self.id_] * len(self)

    if self.is_numpy:
        t = np.array(t)
        terminateds = np.array(terminateds)
        truncateds = np.array(truncateds)
        eps_id = np.array(eps_id)

    return dict(
        {
            # Trivial 1D data (compiled above).
            Columns.TERMINATEDS: terminateds,
            Columns.TRUNCATEDS: truncateds,
            Columns.T: t,
            Columns.EPS_ID: eps_id,
            # Obs/infos/actions/rewards via the get_... APIs, which exclude
            # the lookback buffer. The last obs/info is sliced off so the
            # counts match those of actions and rewards.
            Columns.OBS: self.get_observations(slice(None, -1)),
            Columns.INFOS: self.get_infos(slice(None, -1)),
            Columns.ACTIONS: self.get_actions(),
            Columns.REWARDS: self.get_rewards(),
        },
        # All extra model outputs, via the same get_... API.
        **{k: self.get_extra_model_outputs(k) for k in self.extra_model_outputs},
    )
+ **{ + k: self.get_extra_model_outputs(k) + for k in self.extra_model_outputs.keys() + }, + ) + + def get_sample_batch(self) -> SampleBatch: + """Converts this `SingleAgentEpisode` into a `SampleBatch`. + + Returns: + A SampleBatch containing all of this episode's data. + """ + return SampleBatch(self.get_data_dict()) + + def get_return(self) -> float: + """Calculates an episode's return, excluding the lookback buffer's rewards. + + The return is computed by a simple sum, neglecting the discount factor. + Note that if `self` is a continuation chunk (resulting from a call to + `self.cut()`), the previous chunk's rewards are NOT counted and thus NOT + part of the returned reward sum. + + Returns: + The sum of rewards collected during this episode, excluding possible data + inside the lookback buffer and excluding possible data in a predecessor + chunk. + """ + return sum(self.get_rewards()) + + def get_duration_s(self) -> float: + """Returns the duration of this Episode (chunk) in seconds.""" + if self._last_step_time is None: + return 0.0 + return self._last_step_time - self._start_time + + def env_steps(self) -> int: + """Returns the number of environment steps. + + Note, this episode instance could be a chunk of an actual episode. + + Returns: + An integer that counts the number of environment steps this episode instance + has seen. + """ + return len(self) + + def agent_steps(self) -> int: + """Returns the number of agent steps. + + Note, these are identical to the environment steps for a single-agent episode. + + Returns: + An integer counting the number of agent steps executed during the time this + episode instance records. + """ + return self.env_steps() + + def get_state(self) -> Dict[str, Any]: + """Returns the pickable state of an episode. + + The data in the episode is stored into a dictionary. Note that episodes + can also be generated from states (see `SingleAgentEpisode.from_state()`). + + Returns: + A dict containing all the data from the episode. 
+ """ + infos = self.infos.get_state() + infos["data"] = np.array([info if info else None for info in infos["data"]]) + return { + "id_": self.id_, + "agent_id": self.agent_id, + "module_id": self.module_id, + "multi_agent_episode_id": self.multi_agent_episode_id, + # Note, all data is stored in `InfiniteLookbackBuffer`s. + "observations": self.observations.get_state(), + "actions": self.actions.get_state(), + "rewards": self.rewards.get_state(), + "infos": self.infos.get_state(), + "extra_model_outputs": { + k: v.get_state() if v else v + for k, v in self.extra_model_outputs.items() + } + if len(self.extra_model_outputs) > 0 + else None, + "is_terminated": self.is_terminated, + "is_truncated": self.is_truncated, + "t_started": self.t_started, + "t": self.t, + "_observation_space": gym_space_to_dict(self._observation_space) + if self._observation_space + else None, + "_action_space": gym_space_to_dict(self._action_space) + if self._action_space + else None, + "_start_time": self._start_time, + "_last_step_time": self._last_step_time, + "_temporary_timestep_data": dict(self._temporary_timestep_data) + if len(self._temporary_timestep_data) > 0 + else None, + } + + @staticmethod + def from_state(state: Dict[str, Any]) -> "SingleAgentEpisode": + """Creates a new `SingleAgentEpisode` instance from a state dict. + + Args: + state: The state dict, as returned by `self.get_state()`. + + Returns: + A new `SingleAgentEpisode` instance with the data from the state dict. + """ + # Create an empy episode instance. + episode = SingleAgentEpisode(id_=state["id_"]) + # Load all the data from the state dict into the episode. + episode.agent_id = state["agent_id"] + episode.module_id = state["module_id"] + episode.multi_agent_episode_id = state["multi_agent_episode_id"] + # Convert data back to `InfiniteLookbackBuffer`s. 
+ episode.observations = InfiniteLookbackBuffer.from_state(state["observations"]) + episode.actions = InfiniteLookbackBuffer.from_state(state["actions"]) + episode.rewards = InfiniteLookbackBuffer.from_state(state["rewards"]) + episode.infos = InfiniteLookbackBuffer.from_state(state["infos"]) + episode.extra_model_outputs = ( + defaultdict( + functools.partial( + InfiniteLookbackBuffer, lookback=episode.observations.lookback + ), + { + k: InfiniteLookbackBuffer.from_state(v) + for k, v in state["extra_model_outputs"].items() + }, + ) + if state["extra_model_outputs"] + else defaultdict( + functools.partial( + InfiniteLookbackBuffer, lookback=episode.observations.lookback + ), + ) + ) + episode.is_terminated = state["is_terminated"] + episode.is_truncated = state["is_truncated"] + episode.t_started = state["t_started"] + episode.t = state["t"] + # We need to convert the spaces to dictionaries for serialization. + episode._observation_space = ( + gym_space_from_dict(state["_observation_space"]) + if state["_observation_space"] + else None + ) + episode._action_space = ( + gym_space_from_dict(state["_action_space"]) + if state["_action_space"] + else None + ) + episode._start_time = state["_start_time"] + episode._last_step_time = state["_last_step_time"] + episode._temporary_timestep_data = defaultdict( + list, state["_temporary_timestep_data"] or {} + ) + # Validate the episode. + episode.validate() + + return episode + + @property + def observation_space(self): + return self._observation_space + + @observation_space.setter + def observation_space(self, value): + self._observation_space = self.observations.space = value + + @property + def action_space(self): + return self._action_space + + @action_space.setter + def action_space(self, value): + self._action_space = self.actions.space = value + + def __len__(self) -> int: + """Returning the length of an episode. + + The length of an episode is defined by the length of its data, excluding + the lookback buffer data. 
The length is the number of timesteps an agent has + stepped through an environment thus far. + + The length is 0 in case of an episode whose env has NOT been reset yet, but + also 0 right after the `env.reset()` data has been added via + `self.add_env_reset()`. Only after the first call to `env.step()` (and + `self.add_env_step()`, the length will be 1. + + Returns: + An integer, defining the length of an episode. + """ + return self.t - self.t_started + + def __repr__(self): + return ( + f"SAEps(len={len(self)} done={self.is_done} " + f"R={self.get_return()} id_={self.id_})" + ) + + def __getitem__(self, item: slice) -> "SingleAgentEpisode": + """Enable squared bracket indexing- and slicing syntax, e.g. episode[-4:].""" + if isinstance(item, slice): + return self.slice(slice_=item) + else: + raise NotImplementedError( + f"SingleAgentEpisode does not support getting item '{item}'! " + "Only slice objects allowed with the syntax: `episode[a:b]`." + ) + + @Deprecated(new="SingleAgentEpisode.is_numpy()", error=True) + def is_finalized(self): + pass + + @Deprecated(new="SingleAgentEpisode.to_numpy()", error=True) + def finalize(self): + pass diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/tcp_client_inference_env_runner.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/tcp_client_inference_env_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..ab3161ed417bde6620c8b080b8f4892a20221506 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/tcp_client_inference_env_runner.py @@ -0,0 +1,589 @@ +import base64 +from collections import defaultdict +import gzip +import json +import pathlib +import socket +import tempfile +import threading +import time +from typing import Collection, DefaultDict, List, Optional, Union + +import gymnasium as gym +import numpy as np +import onnxruntime + +from ray.rllib.core import ( + Columns, + COMPONENT_RL_MODULE, + DEFAULT_AGENT_ID, + DEFAULT_MODULE_ID, +) +from ray.rllib.env import 
INPUT_ENV_SPACES +from ray.rllib.env.env_runner import EnvRunner +from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner +from ray.rllib.env.single_agent_episode import SingleAgentEpisode +from ray.rllib.env.utils.external_env_protocol import RLlink as rllink +from ray.rllib.utils.annotations import ExperimentalAPI, override +from ray.rllib.utils.checkpoints import Checkpointable +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics import ( + EPISODE_DURATION_SEC_MEAN, + EPISODE_LEN_MAX, + EPISODE_LEN_MEAN, + EPISODE_LEN_MIN, + EPISODE_RETURN_MAX, + EPISODE_RETURN_MEAN, + EPISODE_RETURN_MIN, + WEIGHTS_SEQ_NO, +) +from ray.rllib.utils.metrics.metrics_logger import MetricsLogger +from ray.rllib.utils.numpy import softmax +from ray.rllib.utils.typing import EpisodeID, StateDict + +torch, _ = try_import_torch() + + +@ExperimentalAPI +class TcpClientInferenceEnvRunner(EnvRunner, Checkpointable): + """An EnvRunner communicating with an external env through a TCP socket. + + This implementation assumes: + - Only one external client ever connects to this env runner. + - The external client performs inference locally through an ONNX model. Thus, + samples are sent in bulk once a certain number of timesteps has been executed on the + client's side (no individual action requests). + - A copy of the RLModule is kept at all times on the env runner, but never used + for inference, only as a data (weights) container. + TODO (sven): The above might be inefficient as we have to store basically two + models, one in this EnvRunner, one in the env (as ONNX). + - There is no environment and no connectors on this env runner. The external env + is responsible for generating all the data to create episodes. + """ + + @override(EnvRunner) + def __init__(self, *, config, **kwargs): + """ + Initializes a TcpClientInferenceEnvRunner instance. + + Args: + config: The AlgorithmConfig to use for setup. 
+ + Keyword Args: + port: The base port number. The server socket is then actually bound to + `port` + self.worker_index. + """ + super().__init__(config=config) + + self.worker_index: int = kwargs.get("worker_index", 0) + + self._weights_seq_no = 0 + + # Build the module from its spec. + module_spec = self.config.get_rl_module_spec( + spaces=self.get_spaces(), inference_only=True + ) + self.module = module_spec.build() + + self.host = "localhost" + self.port = int(self.config.env_config.get("port", 5555)) + self.worker_index + self.server_socket = None + self.client_socket = None + self.address = None + + self.metrics = MetricsLogger() + + self._episode_chunks_to_return: Optional[List[SingleAgentEpisode]] = None + self._done_episodes_for_metrics: List[SingleAgentEpisode] = [] + self._ongoing_episodes_for_metrics: DefaultDict[ + EpisodeID, List[SingleAgentEpisode] + ] = defaultdict(list) + + self._sample_lock = threading.Lock() + self._on_policy_lock = threading.Lock() + self._blocked_on_state = False + + # Start a background thread for client communication. + self.thread = threading.Thread( + target=self._client_message_listener, daemon=True + ) + self.thread.start() + + @override(EnvRunner) + def assert_healthy(self): + """Checks that the server socket is open and listening.""" + assert ( + self.server_socket is not None + ), "Server socket is None (not connected, not listening)." 
+ + @override(EnvRunner) + def sample(self, **kwargs): + """Waits for the client to send episodes.""" + while True: + with self._sample_lock: + if self._episode_chunks_to_return is not None: + num_env_steps = 0 + num_episodes_completed = 0 + for eps in self._episode_chunks_to_return: + if eps.is_done: + self._done_episodes_for_metrics.append(eps) + num_episodes_completed += 1 + else: + self._ongoing_episodes_for_metrics[eps.id_].append(eps) + num_env_steps += len(eps) + + ret = self._episode_chunks_to_return + self._episode_chunks_to_return = None + + SingleAgentEnvRunner._increase_sampled_metrics( + self, num_env_steps, num_episodes_completed + ) + + return ret + time.sleep(0.01) + + @override(EnvRunner) + def get_metrics(self): + # TODO (sven): We should probably make this a utility function to be called + # from within Single/MultiAgentEnvRunner and other EnvRunner subclasses, as + # needed. + # Compute per-episode metrics (only on already completed episodes). + for eps in self._done_episodes_for_metrics: + assert eps.is_done + episode_length = len(eps) + episode_return = eps.get_return() + episode_duration_s = eps.get_duration_s() + # Don't forget about the already returned chunks of this episode. + if eps.id_ in self._ongoing_episodes_for_metrics: + for eps2 in self._ongoing_episodes_for_metrics[eps.id_]: + episode_length += len(eps2) + episode_return += eps2.get_return() + episode_duration_s += eps2.get_duration_s() + del self._ongoing_episodes_for_metrics[eps.id_] + + self._log_episode_metrics( + episode_length, episode_return, episode_duration_s + ) + + # Now that we have logged everything, clear cache of done episodes. + self._done_episodes_for_metrics.clear() + + # Return reduced metrics. 
+ return self.metrics.reduce() + + def get_spaces(self): + return { + INPUT_ENV_SPACES: (self.config.observation_space, self.config.action_space), + DEFAULT_MODULE_ID: ( + self.config.observation_space, + self.config.action_space, + ), + } + + @override(EnvRunner) + def stop(self): + """Closes the client and server sockets.""" + self._close_sockets_if_necessary() + + @override(Checkpointable) + def get_ctor_args_and_kwargs(self): + return ( + (), # *args + {"config": self.config}, # **kwargs + ) + + @override(Checkpointable) + def get_checkpointable_components(self): + return [ + (COMPONENT_RL_MODULE, self.module), + ] + + @override(Checkpointable) + def get_state( + self, + components: Optional[Union[str, Collection[str]]] = None, + *, + not_components: Optional[Union[str, Collection[str]]] = None, + **kwargs, + ) -> StateDict: + return {} + + @override(Checkpointable) + def set_state(self, state: StateDict) -> None: + # Update the RLModule state. + if COMPONENT_RL_MODULE in state: + # A missing value for WEIGHTS_SEQ_NO or a value of 0 means: Force the + # update. + weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0) + + # Only update the weigths, if this is the first synchronization or + # if the weights of this `EnvRunner` lacks behind the actual ones. + if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no: + rl_module_state = state[COMPONENT_RL_MODULE] + if ( + isinstance(rl_module_state, dict) + and DEFAULT_MODULE_ID in rl_module_state + ): + rl_module_state = rl_module_state[DEFAULT_MODULE_ID] + self.module.set_state(rl_module_state) + + # Update our weights_seq_no, if the new one is > 0. + if weights_seq_no > 0: + self._weights_seq_no = weights_seq_no + + if self._blocked_on_state is True: + self._send_set_state_message() + self._blocked_on_state = False + + def _client_message_listener(self): + """Entry point for the listener thread.""" + + # Set up the server socket and bind to the specified host and port. 
+ self._recycle_sockets() + + # Enter an endless message receival- and processing loop. + while True: + # As long as we are blocked on a new state, sleep a bit and continue. + # Do NOT process any incoming messages (until we send out the new state + # back to the client). + if self._blocked_on_state is True: + time.sleep(0.01) + continue + + try: + # Blocking call to get next message. + msg_type, msg_body = _get_message(self.client_socket) + + # Process the message received based on its type. + # Initial handshake. + if msg_type == rllink.PING: + self._send_pong_message() + + # Episode data from the client. + elif msg_type in [ + rllink.EPISODES, + rllink.EPISODES_AND_GET_STATE, + ]: + self._process_episodes_message(msg_type, msg_body) + + # Client requests the state (model weights). + elif msg_type == rllink.GET_STATE: + self._send_set_state_message() + + # Clients requests some (relevant) config information. + elif msg_type == rllink.GET_CONFIG: + self._send_set_config_message() + + except ConnectionError as e: + print(f"Messaging/connection error {e}! Recycling sockets ...") + self._recycle_sockets(5.0) + continue + + def _recycle_sockets(self, sleep: float = 0.0): + # Close all old sockets, if they exist. + self._close_sockets_if_necessary() + + time.sleep(sleep) + + # Start listening on the configured port. + self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Allow reuse of the address. + self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.server_socket.bind((self.host, self.port)) + # Listen for a single connection. 
+ self.server_socket.listen(1) + print(f"Waiting for client to connect to port {self.port}...") + + self.client_socket, self.address = self.server_socket.accept() + print(f"Connected to client at {self.address}") + + def _close_sockets_if_necessary(self): + if self.client_socket: + self.client_socket.close() + if self.server_socket: + self.server_socket.close() + + def _send_pong_message(self): + _send_message(self.client_socket, {"type": rllink.PONG.name}) + + def _process_episodes_message(self, msg_type, msg_body): + # On-policy training -> we have to block until we get a new `set_state` call + # (b/c the learning step is done and we can sent new weights back to all + # clients). + if msg_type == rllink.EPISODES_AND_GET_STATE: + self._blocked_on_state = True + + episodes = [] + for episode_data in msg_body["episodes"]: + episode = SingleAgentEpisode( + observation_space=self.config.observation_space, + observations=[np.array(o) for o in episode_data[Columns.OBS]], + action_space=self.config.action_space, + actions=episode_data[Columns.ACTIONS], + rewards=episode_data[Columns.REWARDS], + extra_model_outputs={ + Columns.ACTION_DIST_INPUTS: [ + np.array(a) for a in episode_data[Columns.ACTION_DIST_INPUTS] + ], + Columns.ACTION_LOGP: episode_data[Columns.ACTION_LOGP], + }, + terminated=episode_data["is_terminated"], + truncated=episode_data["is_truncated"], + len_lookback_buffer=0, + ) + episodes.append(episode.to_numpy()) + + # Push episodes into the to-be-returned list (for `sample()` requests). 
+ with self._sample_lock: + if isinstance(self._episode_chunks_to_return, list): + self._episode_chunks_to_return.extend(episodes) + else: + self._episode_chunks_to_return = episodes + + def _send_set_state_message(self): + with tempfile.TemporaryDirectory() as dir: + onnx_file = pathlib.Path(dir) / "_temp_model.onnx" + torch.onnx.export( + self.module, + { + "batch": { + "obs": torch.randn(1, *self.config.observation_space.shape) + } + }, + onnx_file, + export_params=True, + ) + with open(onnx_file, "rb") as f: + compressed = gzip.compress(f.read()) + onnx_binary = base64.b64encode(compressed).decode("utf-8") + _send_message( + self.client_socket, + { + "type": rllink.SET_STATE.name, + "onnx_file": onnx_binary, + WEIGHTS_SEQ_NO: self._weights_seq_no, + }, + ) + + def _send_set_config_message(self): + _send_message( + self.client_socket, + { + "type": rllink.SET_CONFIG.name, + "env_steps_per_sample": self.config.get_rollout_fragment_length( + worker_index=self.worker_index + ), + "force_on_policy": True, + }, + ) + + def _log_episode_metrics(self, length, ret, sec): + # Log general episode metrics. + # To mimic the old API stack behavior, we'll use `window` here for + # these particular stats (instead of the default EMA). + win = self.config.metrics_num_episodes_for_smoothing + self.metrics.log_value(EPISODE_LEN_MEAN, length, window=win) + self.metrics.log_value(EPISODE_RETURN_MEAN, ret, window=win) + self.metrics.log_value(EPISODE_DURATION_SEC_MEAN, sec, window=win) + # Per-agent returns. + self.metrics.log_value( + ("agent_episode_returns_mean", DEFAULT_AGENT_ID), ret, window=win + ) + # Per-RLModule returns. + self.metrics.log_value( + ("module_episode_returns_mean", DEFAULT_MODULE_ID), ret, window=win + ) + + # For some metrics, log min/max as well. 
+ self.metrics.log_value(EPISODE_LEN_MIN, length, reduce="min", window=win) + self.metrics.log_value(EPISODE_RETURN_MIN, ret, reduce="min", window=win) + self.metrics.log_value(EPISODE_LEN_MAX, length, reduce="max", window=win) + self.metrics.log_value(EPISODE_RETURN_MAX, ret, reduce="max", window=win) + + +def _send_message(sock_, message: dict): + """Sends a message to the client with a length header.""" + body = json.dumps(message).encode("utf-8") + header = str(len(body)).zfill(8).encode("utf-8") + try: + sock_.sendall(header + body) + except Exception as e: + raise ConnectionError( + f"Error sending message {message} to server on socket {sock_}! " + f"Original error was: {e}" + ) + + +def _get_message(sock_): + """Receives a message from the client following the length-header protocol.""" + try: + # Read the length header (8 bytes) + header = _get_num_bytes(sock_, 8) + msg_length = int(header.decode("utf-8")) + # Read the message body + body = _get_num_bytes(sock_, msg_length) + # Decode JSON. + message = json.loads(body.decode("utf-8")) + # Check for proper protocol. + if "type" not in message: + raise ConnectionError( + "Protocol Error! Message from peer does not contain `type` " "field." + ) + return rllink(message.pop("type")), message + except Exception as e: + raise ConnectionError( + f"Error receiving message from peer on socket {sock_}! 
" + f"Original error was: {e}" + ) + + +def _get_num_bytes(sock_, num_bytes): + """Helper function to receive a specific number of bytes.""" + data = b"" + while len(data) < num_bytes: + packet = sock_.recv(num_bytes - len(data)) + if not packet: + raise ConnectionError(f"No data received from socket {sock_}!") + data += packet + return data + + +def _dummy_client(port: int = 5556): + """A dummy client that runs CartPole and acts as a testing external env.""" + + def _set_state(msg_body): + with tempfile.TemporaryDirectory(): + with open("_temp_onnx", "wb") as f: + f.write( + gzip.decompress( + base64.b64decode(msg_body["onnx_file"].encode("utf-8")) + ) + ) + onnx_session = onnxruntime.InferenceSession("_temp_onnx") + output_names = [o.name for o in onnx_session.get_outputs()] + return onnx_session, output_names + + # Connect to server. + while True: + try: + print(f"Trying to connect to localhost:{port} ...") + sock_ = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock_.connect(("localhost", port)) + break + except ConnectionRefusedError: + time.sleep(5) + + # Send ping-pong. + _send_message(sock_, {"type": rllink.PING.name}) + msg_type, msg_body = _get_message(sock_) + assert msg_type == rllink.PONG + + # Request config. + _send_message(sock_, {"type": rllink.GET_CONFIG.name}) + msg_type, msg_body = _get_message(sock_) + assert msg_type == rllink.SET_CONFIG + env_steps_per_sample = msg_body["env_steps_per_sample"] + force_on_policy = msg_body["force_on_policy"] + + # Request ONNX weights. + _send_message(sock_, {"type": rllink.GET_STATE.name}) + msg_type, msg_body = _get_message(sock_) + assert msg_type == rllink.SET_STATE + onnx_session, output_names = _set_state(msg_body) + + # Episode collection buckets. + episodes = [] + observations = [] + actions = [] + action_dist_inputs = [] + action_logps = [] + rewards = [] + + timesteps = 0 + episode_return = 0.0 + + # Start actual env loop. 
+ env = gym.make("CartPole-v1") + obs, info = env.reset() + observations.append(obs.tolist()) + + while True: + timesteps += 1 + # Perform action inference using the ONNX model. + logits = onnx_session.run( + output_names, + {"onnx::Gemm_0": np.array([obs], np.float32)}, + )[0][ + 0 + ] # [0]=first return item, [0]=batch size 1 + + # Stochastic sample. + action_probs = softmax(logits) + action = int(np.random.choice(list(range(env.action_space.n)), p=action_probs)) + logp = float(np.log(action_probs[action])) + + # Perform the env step. + obs, reward, terminated, truncated, info = env.step(action) + + # Collect step data. + observations.append(obs.tolist()) + actions.append(action) + action_dist_inputs.append(logits.tolist()) + action_logps.append(logp) + rewards.append(reward) + episode_return += reward + + # We have to create a new episode record. + if timesteps == env_steps_per_sample or terminated or truncated: + episodes.append( + { + Columns.OBS: observations, + Columns.ACTIONS: actions, + Columns.ACTION_DIST_INPUTS: action_dist_inputs, + Columns.ACTION_LOGP: action_logps, + Columns.REWARDS: rewards, + "is_terminated": terminated, + "is_truncated": truncated, + } + ) + # We collected enough samples -> Send them to server. + if timesteps == env_steps_per_sample: + # Make sure the amount of data we collected is correct. + assert sum(len(e["actions"]) for e in episodes) == env_steps_per_sample + + # Send the data to the server. + if force_on_policy: + _send_message( + sock_, + { + "type": rllink.EPISODES_AND_GET_STATE.name, + "episodes": episodes, + "timesteps": timesteps, + }, + ) + # We are forced to sample on-policy. Have to wait for a response + # with the state (weights) in it. + msg_type, msg_body = _get_message(sock_) + assert msg_type == rllink.SET_STATE + onnx_session, output_names = _set_state(msg_body) + + # Sampling doesn't have to be on-policy -> continue collecting + # samples. 
+ else: + raise NotImplementedError + + episodes = [] + timesteps = 0 + + # Set new buckets to empty lists (for next episode). + observations = [observations[-1]] + actions = [] + action_dist_inputs = [] + action_logps = [] + rewards = [] + + # The episode is done -> Reset. + if terminated or truncated: + obs, _ = env.reset() + observations = [obs.tolist()] + episode_return = 0.0 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/utils/__pycache__/external_env_protocol.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/utils/__pycache__/external_env_protocol.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..beafdd901c058a37f28099ef6919067a5a53d773 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/utils/__pycache__/external_env_protocol.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/vector_env.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/vector_env.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e0896ba05e78e559c487ae0bdf858fba230718 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/vector_env.py @@ -0,0 +1,544 @@ +import logging +import gymnasium as gym +import numpy as np +from typing import Callable, List, Optional, Tuple, Union, Set + +from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID +from ray.rllib.utils.annotations import Deprecated, OldAPIStack, override +from ray.rllib.utils.typing import ( + EnvActionType, + EnvID, + EnvInfoDict, + EnvObsType, + EnvType, + MultiEnvDict, + AgentID, +) +from ray.util import log_once + +logger = logging.getLogger(__name__) + + +@OldAPIStack +class VectorEnv: + """An environment that supports batch evaluation using clones of sub-envs.""" + + def __init__( + self, observation_space: gym.Space, action_space: gym.Space, num_envs: int + ): + """Initializes a VectorEnv instance. 
+ + Args: + observation_space: The observation Space of a single + sub-env. + action_space: The action Space of a single sub-env. + num_envs: The number of clones to make of the given sub-env. + """ + self.observation_space = observation_space + self.action_space = action_space + self.num_envs = num_envs + + @staticmethod + def vectorize_gym_envs( + make_env: Optional[Callable[[int], EnvType]] = None, + existing_envs: Optional[List[gym.Env]] = None, + num_envs: int = 1, + action_space: Optional[gym.Space] = None, + observation_space: Optional[gym.Space] = None, + restart_failed_sub_environments: bool = False, + # Deprecated. These seem to have never been used. + env_config=None, + policy_config=None, + ) -> "_VectorizedGymEnv": + """Translates any given gym.Env(s) into a VectorizedEnv object. + + Args: + make_env: Factory that produces a new gym.Env taking the sub-env's + vector index as only arg. Must be defined if the + number of `existing_envs` is less than `num_envs`. + existing_envs: Optional list of already instantiated sub + environments. + num_envs: Total number of sub environments in this VectorEnv. + action_space: The action space. If None, use existing_envs[0]'s + action space. + observation_space: The observation space. If None, use + existing_envs[0]'s observation space. + restart_failed_sub_environments: If True and any sub-environment (within + a vectorized env) throws any error during env stepping, the + Sampler will try to restart the faulty sub-environment. This is done + without disturbing the other (still intact) sub-environment and without + the RolloutWorker crashing. + + Returns: + The resulting _VectorizedGymEnv object (subclass of VectorEnv). 
+ """ + return _VectorizedGymEnv( + make_env=make_env, + existing_envs=existing_envs or [], + num_envs=num_envs, + observation_space=observation_space, + action_space=action_space, + restart_failed_sub_environments=restart_failed_sub_environments, + ) + + def vector_reset( + self, *, seeds: Optional[List[int]] = None, options: Optional[List[dict]] = None + ) -> Tuple[List[EnvObsType], List[EnvInfoDict]]: + """Resets all sub-environments. + + Args: + seed: The list of seeds to be passed to the sub-environments' when resetting + them. If None, will not reset any existing PRNGs. If you pass + integers, the PRNGs will be reset even if they already exists. + options: The list of options dicts to be passed to the sub-environments' + when resetting them. + + Returns: + Tuple consitsing of a list of observations from each environment and + a list of info dicts from each environment. + """ + raise NotImplementedError + + def reset_at( + self, + index: Optional[int] = None, + *, + seed: Optional[int] = None, + options: Optional[dict] = None, + ) -> Union[Tuple[EnvObsType, EnvInfoDict], Exception]: + """Resets a single sub-environment. + + Args: + index: An optional sub-env index to reset. + seed: The seed to be passed to the sub-environment at index `index` when + resetting it. If None, will not reset any existing PRNG. If you pass an + integer, the PRNG will be reset even if it already exists. + options: An options dict to be passed to the sub-environment at index + `index` when resetting it. + + Returns: + Tuple consisting of observations from the reset sub environment and + an info dict of the reset sub environment. Alternatively an Exception + can be returned, indicating that the reset operation on the sub environment + has failed (and why it failed). + """ + raise NotImplementedError + + def restart_at(self, index: Optional[int] = None) -> None: + """Restarts a single sub-environment. + + Args: + index: An optional sub-env index to restart. 
+ """ + raise NotImplementedError + + def vector_step( + self, actions: List[EnvActionType] + ) -> Tuple[ + List[EnvObsType], List[float], List[bool], List[bool], List[EnvInfoDict] + ]: + """Performs a vectorized step on all sub environments using `actions`. + + Args: + actions: List of actions (one for each sub-env). + + Returns: + A tuple consisting of + 1) New observations for each sub-env. + 2) Reward values for each sub-env. + 3) Terminated values for each sub-env. + 4) Truncated values for each sub-env. + 5) Info values for each sub-env. + """ + raise NotImplementedError + + def get_sub_environments(self) -> List[EnvType]: + """Returns the underlying sub environments. + + Returns: + List of all underlying sub environments. + """ + return [] + + # TODO: (sven) Experimental method. Make @PublicAPI at some point. + def try_render_at(self, index: Optional[int] = None) -> Optional[np.ndarray]: + """Renders a single environment. + + Args: + index: An optional sub-env index to render. + + Returns: + Either a numpy RGB image (shape=(w x h x 3) dtype=uint8) or + None in case rendering is handled directly by this method. + """ + pass + + def to_base_env( + self, + make_env: Optional[Callable[[int], EnvType]] = None, + num_envs: int = 1, + remote_envs: bool = False, + remote_env_batch_wait_ms: int = 0, + restart_failed_sub_environments: bool = False, + ) -> "BaseEnv": + """Converts an RLlib MultiAgentEnv into a BaseEnv object. + + The resulting BaseEnv is always vectorized (contains n + sub-environments) to support batched forward passes, where n may + also be 1. BaseEnv also supports async execution via the `poll` and + `send_actions` methods and thus supports external simulators. + + Args: + make_env: A callable taking an int as input (which indicates + the number of individual sub-environments within the final + vectorized BaseEnv) and returning one individual + sub-environment. 
+ num_envs: The number of sub-environments to create in the + resulting (vectorized) BaseEnv. The already existing `env` + will be one of the `num_envs`. + remote_envs: Whether each sub-env should be a @ray.remote + actor. You can set this behavior in your config via the + `remote_worker_envs=True` option. + remote_env_batch_wait_ms: The wait time (in ms) to poll remote + sub-environments for, if applicable. Only used if + `remote_envs` is True. + + Returns: + The resulting BaseEnv object. + """ + env = VectorEnvWrapper(self) + return env + + @Deprecated(new="vectorize_gym_envs", error=True) + def wrap(self, *args, **kwargs) -> "_VectorizedGymEnv": + pass + + @Deprecated(new="get_sub_environments", error=True) + def get_unwrapped(self) -> List[EnvType]: + pass + + +@OldAPIStack +class _VectorizedGymEnv(VectorEnv): + """Internal wrapper to translate any gym.Envs into a VectorEnv object.""" + + def __init__( + self, + make_env: Optional[Callable[[int], EnvType]] = None, + existing_envs: Optional[List[gym.Env]] = None, + num_envs: int = 1, + *, + observation_space: Optional[gym.Space] = None, + action_space: Optional[gym.Space] = None, + restart_failed_sub_environments: bool = False, + # Deprecated. These seem to have never been used. + env_config=None, + policy_config=None, + ): + """Initializes a _VectorizedGymEnv object. + + Args: + make_env: Factory that produces a new gym.Env taking the sub-env's + vector index as only arg. Must be defined if the + number of `existing_envs` is less than `num_envs`. + existing_envs: Optional list of already instantiated sub + environments. + num_envs: Total number of sub environments in this VectorEnv. + action_space: The action space. If None, use existing_envs[0]'s + action space. + observation_space: The observation space. If None, use + existing_envs[0]'s observation space. 
+ restart_failed_sub_environments: If True and any sub-environment (within + a vectorized env) throws any error during env stepping, we will try to + restart the faulty sub-environment. This is done + without disturbing the other (still intact) sub-environments. + """ + self.envs = existing_envs + self.make_env = make_env + self.restart_failed_sub_environments = restart_failed_sub_environments + + # Fill up missing envs (so we have exactly num_envs sub-envs in this + # VectorEnv. + while len(self.envs) < num_envs: + self.envs.append(make_env(len(self.envs))) + + super().__init__( + observation_space=observation_space or self.envs[0].observation_space, + action_space=action_space or self.envs[0].action_space, + num_envs=num_envs, + ) + + @override(VectorEnv) + def vector_reset( + self, *, seeds: Optional[List[int]] = None, options: Optional[List[dict]] = None + ) -> Tuple[List[EnvObsType], List[EnvInfoDict]]: + seeds = seeds or [None] * self.num_envs + options = options or [None] * self.num_envs + # Use reset_at(index) to restart and retry until + # we successfully create a new env. 
+ resetted_obs = [] + resetted_infos = [] + for i in range(len(self.envs)): + while True: + obs, infos = self.reset_at(i, seed=seeds[i], options=options[i]) + if not isinstance(obs, Exception): + break + resetted_obs.append(obs) + resetted_infos.append(infos) + return resetted_obs, resetted_infos + + @override(VectorEnv) + def reset_at( + self, + index: Optional[int] = None, + *, + seed: Optional[int] = None, + options: Optional[dict] = None, + ) -> Tuple[Union[EnvObsType, Exception], Union[EnvInfoDict, Exception]]: + if index is None: + index = 0 + try: + obs_and_infos = self.envs[index].reset(seed=seed, options=options) + + except Exception as e: + if self.restart_failed_sub_environments: + logger.exception(e.args[0]) + self.restart_at(index) + obs_and_infos = e, {} + else: + raise e + + return obs_and_infos + + @override(VectorEnv) + def restart_at(self, index: Optional[int] = None) -> None: + if index is None: + index = 0 + + # Try closing down the old (possibly faulty) sub-env, but ignore errors. + try: + self.envs[index].close() + except Exception as e: + if log_once("close_sub_env"): + logger.warning( + "Trying to close old and replaced sub-environment (at vector " + f"index={index}), but closing resulted in error:\n{e}" + ) + env_to_del = self.envs[index] + self.envs[index] = None + del env_to_del + + # Re-create the sub-env at the new index. 
+ logger.warning(f"Trying to restart sub-environment at index {index}.") + self.envs[index] = self.make_env(index) + logger.warning(f"Sub-environment at index {index} restarted successfully.") + + @override(VectorEnv) + def vector_step( + self, actions: List[EnvActionType] + ) -> Tuple[ + List[EnvObsType], List[float], List[bool], List[bool], List[EnvInfoDict] + ]: + obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = ( + [], + [], + [], + [], + [], + ) + for i in range(self.num_envs): + try: + results = self.envs[i].step(actions[i]) + except Exception as e: + if self.restart_failed_sub_environments: + logger.exception(e.args[0]) + self.restart_at(i) + results = e, 0.0, True, True, {} + else: + raise e + + obs, reward, terminated, truncated, info = results + + if not isinstance(info, dict): + raise ValueError( + "Info should be a dict, got {} ({})".format(info, type(info)) + ) + obs_batch.append(obs) + reward_batch.append(reward) + terminated_batch.append(terminated) + truncated_batch.append(truncated) + info_batch.append(info) + return obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch + + @override(VectorEnv) + def get_sub_environments(self) -> List[EnvType]: + return self.envs + + @override(VectorEnv) + def try_render_at(self, index: Optional[int] = None): + if index is None: + index = 0 + return self.envs[index].render() + + +@OldAPIStack +class VectorEnvWrapper(BaseEnv): + """Internal adapter of VectorEnv to BaseEnv. + + We assume the caller will always send the full vector of actions in each + call to send_actions(), and that they call reset_at() on all completed + environments before calling send_actions(). + """ + + def __init__(self, vector_env: VectorEnv): + self.vector_env = vector_env + self.num_envs = vector_env.num_envs + self._observation_space = vector_env.observation_space + self._action_space = vector_env.action_space + + # Sub-environments' states. 
+ self.new_obs = None + self.cur_rewards = None + self.cur_terminateds = None + self.cur_truncateds = None + self.cur_infos = None + # At first `poll()`, reset everything (all sub-environments). + self.first_reset_done = False + # Initialize sub-environments' state. + self._init_env_state(idx=None) + + @override(BaseEnv) + def poll( + self, + ) -> Tuple[ + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + MultiEnvDict, + ]: + from ray.rllib.env.base_env import with_dummy_agent_id + + if not self.first_reset_done: + self.first_reset_done = True + # TODO(sven): We probably would like to seed this call here as well. + self.new_obs, self.cur_infos = self.vector_env.vector_reset() + new_obs = dict(enumerate(self.new_obs)) + rewards = dict(enumerate(self.cur_rewards)) + terminateds = dict(enumerate(self.cur_terminateds)) + truncateds = dict(enumerate(self.cur_truncateds)) + infos = dict(enumerate(self.cur_infos)) + + # Empty all states (in case `poll()` gets called again). 
+ self.new_obs = [] + self.cur_rewards = [] + self.cur_terminateds = [] + self.cur_truncateds = [] + self.cur_infos = [] + + return ( + with_dummy_agent_id(new_obs), + with_dummy_agent_id(rewards), + with_dummy_agent_id(terminateds, "__all__"), + with_dummy_agent_id(truncateds, "__all__"), + with_dummy_agent_id(infos), + {}, + ) + + @override(BaseEnv) + def send_actions(self, action_dict: MultiEnvDict) -> None: + from ray.rllib.env.base_env import _DUMMY_AGENT_ID + + action_vector = [None] * self.num_envs + for i in range(self.num_envs): + action_vector[i] = action_dict[i][_DUMMY_AGENT_ID] + ( + self.new_obs, + self.cur_rewards, + self.cur_terminateds, + self.cur_truncateds, + self.cur_infos, + ) = self.vector_env.vector_step(action_vector) + + @override(BaseEnv) + def try_reset( + self, + env_id: Optional[EnvID] = None, + *, + seed: Optional[int] = None, + options: Optional[dict] = None, + ) -> Tuple[MultiEnvDict, MultiEnvDict]: + from ray.rllib.env.base_env import _DUMMY_AGENT_ID + + if env_id is None: + env_id = 0 + assert isinstance(env_id, int) + obs, infos = self.vector_env.reset_at(env_id, seed=seed, options=options) + + # If exceptions were returned, return MultiEnvDict mapping env indices to + # these exceptions (for obs and infos). + if isinstance(obs, Exception): + return {env_id: obs}, {env_id: infos} + # Otherwise, return a MultiEnvDict (with single agent ID) and the actual + # obs and info dicts. + else: + return {env_id: {_DUMMY_AGENT_ID: obs}}, {env_id: {_DUMMY_AGENT_ID: infos}} + + @override(BaseEnv) + def try_restart(self, env_id: Optional[EnvID] = None) -> None: + assert env_id is None or isinstance(env_id, int) + # Restart the sub-env at the index. + self.vector_env.restart_at(env_id) + # Auto-reset (get ready for next `poll()`). 
+ self._init_env_state(env_id) + + @override(BaseEnv) + def get_sub_environments(self, as_dict: bool = False) -> Union[List[EnvType], dict]: + if not as_dict: + return self.vector_env.get_sub_environments() + else: + return { + _id: env + for _id, env in enumerate(self.vector_env.get_sub_environments()) + } + + @override(BaseEnv) + def try_render(self, env_id: Optional[EnvID] = None) -> None: + assert env_id is None or isinstance(env_id, int) + return self.vector_env.try_render_at(env_id) + + @property + @override(BaseEnv) + def observation_space(self) -> gym.Space: + return self._observation_space + + @property + @override(BaseEnv) + def action_space(self) -> gym.Space: + return self._action_space + + @override(BaseEnv) + def get_agent_ids(self) -> Set[AgentID]: + return {_DUMMY_AGENT_ID} + + def _init_env_state(self, idx: Optional[int] = None) -> None: + """Resets all or one particular sub-environment's state (by index). + + Args: + idx: The index to reset at. If None, reset all the sub-environments' states. + """ + # If index is None, reset all sub-envs' states: + if idx is None: + self.new_obs = [None for _ in range(self.num_envs)] + self.cur_rewards = [0.0 for _ in range(self.num_envs)] + self.cur_terminateds = [False for _ in range(self.num_envs)] + self.cur_truncateds = [False for _ in range(self.num_envs)] + self.cur_infos = [{} for _ in range(self.num_envs)] + # Index provided, reset only the sub-env's state at the given index. + else: + self.new_obs[idx], self.cur_infos[idx] = self.vector_env.reset_at(idx) + # Reset all other states to null values. 
+ self.cur_rewards[idx] = 0.0 + self.cur_terminateds[idx] = False + self.cur_truncateds[idx] = False diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/atari_wrappers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/atari_wrappers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..255e56ea7758d2316ad8c6cde05fecabcf3ed6c0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/atari_wrappers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/group_agents_wrapper.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/group_agents_wrapper.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e58e69ba71a7ed285785f103c82f0713032cba60 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/group_agents_wrapper.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/multi_agent_env_compatibility.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/multi_agent_env_compatibility.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ef5f40294139d72564d31b7301d3ede1dd5c910 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/multi_agent_env_compatibility.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/open_spiel.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/open_spiel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a0a9f5bed2d52734899db71291f07ae27fe460f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/open_spiel.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/unity3d_env.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/unity3d_env.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d53544e5d59ca82f3335213755c4cbe45851779a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/unity3d_env.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/atari_wrappers.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/atari_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..3bb0f3ff771969c04b60f4d5002bece338668df0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/atari_wrappers.py @@ -0,0 +1,400 @@ +from collections import deque +import gymnasium as gym +from gymnasium import spaces +import numpy as np +from typing import Optional, Union + +from ray.rllib.utils.annotations import PublicAPI +from ray.rllib.utils.images import rgb2gray, resize + + +@PublicAPI +def is_atari(env: Union[gym.Env, str]) -> bool: + """Returns, whether a given env object or env descriptor (str) is an Atari env. + + Args: + env: The gym.Env object or a string descriptor of the env (for example, + "ale_py:ALE/Pong-v5"). + + Returns: + Whether `env` is an Atari environment. 
+ """ + # If a gym.Env, check proper spaces as well as occurrence of the "Atari 0: + # for Qbert sometimes we stay in lives == 0 condtion for a few fr + # so its important to keep lives > 0, so that we only reset once + # the environment advertises `terminated`. + terminated = True + self.lives = lives + return obs, reward, terminated, truncated, info + + def reset(self, **kwargs): + """Reset only when lives are exhausted. + This way all states are still reachable even though lives are episodic, + and the learner need not know about any of this behind-the-scenes. + """ + if self.was_real_terminated: + obs, info = self.env.reset(**kwargs) + else: + # no-op step to advance from terminal/lost life state + obs, _, _, _, info = self.env.step(0) + self.lives = self.env.unwrapped.ale.lives() + return obs, info + + +@PublicAPI +class FireResetEnv(gym.Wrapper): + def __init__(self, env): + """Take action on reset. + + For environments that are fixed until firing.""" + gym.Wrapper.__init__(self, env) + assert env.unwrapped.get_action_meanings()[1] == "FIRE" + assert len(env.unwrapped.get_action_meanings()) >= 3 + + def reset(self, **kwargs): + self.env.reset(**kwargs) + obs, _, terminated, truncated, _ = self.env.step(1) + if terminated or truncated: + self.env.reset(**kwargs) + obs, _, terminated, truncated, info = self.env.step(2) + if terminated or truncated: + self.env.reset(**kwargs) + return obs, info + + def step(self, ac): + return self.env.step(ac) + + +@PublicAPI +class FrameStack(gym.Wrapper): + def __init__(self, env, k): + """Stack k last frames.""" + gym.Wrapper.__init__(self, env) + self.k = k + self.frames = deque([], maxlen=k) + shp = env.observation_space.shape + self.observation_space = spaces.Box( + low=np.repeat(env.observation_space.low, repeats=k, axis=-1), + high=np.repeat(env.observation_space.high, repeats=k, axis=-1), + shape=(shp[0], shp[1], shp[2] * k), + dtype=env.observation_space.dtype, + ) + + def reset(self, *, seed=None, options=None): + 
ob, infos = self.env.reset(seed=seed, options=options) + for _ in range(self.k): + self.frames.append(ob) + return self._get_ob(), infos + + def step(self, action): + ob, reward, terminated, truncated, info = self.env.step(action) + self.frames.append(ob) + return self._get_ob(), reward, terminated, truncated, info + + def _get_ob(self): + assert len(self.frames) == self.k + return np.concatenate(self.frames, axis=2) + + +@PublicAPI +class FrameStackTrajectoryView(gym.ObservationWrapper): + def __init__(self, env): + """No stacking. Trajectory View API takes care of this.""" + gym.Wrapper.__init__(self, env) + shp = env.observation_space.shape + assert shp[2] == 1 + self.observation_space = spaces.Box( + low=0, high=255, shape=(shp[0], shp[1]), dtype=env.observation_space.dtype + ) + + def observation(self, observation): + return np.squeeze(observation, axis=-1) + + +@PublicAPI +class MaxAndSkipEnv(gym.Wrapper): + def __init__(self, env, skip=4): + """Return only every `skip`-th frame""" + gym.Wrapper.__init__(self, env) + # most recent raw observations (for max pooling across time steps) + self._obs_buffer = np.zeros( + (2,) + env.observation_space.shape, dtype=env.observation_space.dtype + ) + self._skip = skip + + def step(self, action): + """Repeat action, sum reward, and max over last observations.""" + total_reward = 0.0 + terminated = truncated = info = None + for i in range(self._skip): + obs, reward, terminated, truncated, info = self.env.step(action) + if i == self._skip - 2: + self._obs_buffer[0] = obs + if i == self._skip - 1: + self._obs_buffer[1] = obs + total_reward += reward + if terminated or truncated: + break + # Note that the observation on the terminated|truncated=True frame + # doesn't matter + max_frame = self._obs_buffer.max(axis=0) + + return max_frame, total_reward, terminated, truncated, info + + def reset(self, **kwargs): + return self.env.reset(**kwargs) + + +@PublicAPI +class MonitorEnv(gym.Wrapper): + def __init__(self, env=None): + 
"""Record episodes stats prior to EpisodicLifeEnv, etc.""" + gym.Wrapper.__init__(self, env) + self._current_reward = None + self._num_steps = None + self._total_steps = None + self._episode_rewards = [] + self._episode_lengths = [] + self._num_episodes = 0 + self._num_returned = 0 + + def reset(self, **kwargs): + obs, info = self.env.reset(**kwargs) + + if self._total_steps is None: + self._total_steps = sum(self._episode_lengths) + + if self._current_reward is not None: + self._episode_rewards.append(self._current_reward) + self._episode_lengths.append(self._num_steps) + self._num_episodes += 1 + + self._current_reward = 0 + self._num_steps = 0 + + return obs, info + + def step(self, action): + obs, rew, terminated, truncated, info = self.env.step(action) + self._current_reward += rew + self._num_steps += 1 + self._total_steps += 1 + return obs, rew, terminated, truncated, info + + def get_episode_rewards(self): + return self._episode_rewards + + def get_episode_lengths(self): + return self._episode_lengths + + def get_total_steps(self): + return self._total_steps + + def next_episode_results(self): + for i in range(self._num_returned, len(self._episode_rewards)): + yield (self._episode_rewards[i], self._episode_lengths[i]) + self._num_returned = len(self._episode_rewards) + + +@PublicAPI +class NoopResetEnv(gym.Wrapper): + def __init__(self, env, noop_max=30): + """Sample initial states by taking random number of no-ops on reset. + No-op is assumed to be action 0. 
+ """ + gym.Wrapper.__init__(self, env) + self.noop_max = noop_max + self.override_num_noops = None + self.noop_action = 0 + assert env.unwrapped.get_action_meanings()[0] == "NOOP" + + def reset(self, **kwargs): + """Do no-op action for a number of steps in [1, noop_max].""" + self.env.reset(**kwargs) + if self.override_num_noops is not None: + noops = self.override_num_noops + else: + # This environment now uses the pcg64 random number generator which + # does not have `randint` as an attribute only has `integers`. + try: + noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) + # Also still support older versions. + except AttributeError: + noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) + assert noops > 0 + obs = None + for _ in range(noops): + obs, _, terminated, truncated, info = self.env.step(self.noop_action) + if terminated or truncated: + obs, info = self.env.reset(**kwargs) + return obs, info + + def step(self, ac): + return self.env.step(ac) + + +@PublicAPI +class NormalizedImageEnv(gym.ObservationWrapper): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.observation_space = gym.spaces.Box( + -1.0, + 1.0, + shape=self.observation_space.shape, + dtype=np.float32, + ) + + # Divide by scale and center around 0.0, such that observations are in the range + # of -1.0 and 1.0. 
+ def observation(self, observation): + return (observation.astype(np.float32) / 128.0) - 1.0 + + +@PublicAPI +class WarpFrame(gym.ObservationWrapper): + def __init__(self, env, dim): + """Warp frames to the specified size (dim x dim).""" + gym.ObservationWrapper.__init__(self, env) + self.width = dim + self.height = dim + self.observation_space = spaces.Box( + low=0, high=255, shape=(self.height, self.width, 1), dtype=np.uint8 + ) + + def observation(self, frame): + frame = rgb2gray(frame) + frame = resize(frame, height=self.height, width=self.width) + return frame[:, :, None] + + +@PublicAPI +def wrap_atari_for_new_api_stack( + env: gym.Env, + dim: int = 64, + frameskip: int = 4, + framestack: Optional[int] = None, + # TODO (sven): Add option to NOT grayscale, in which case framestack must be None + # (b/c we are using the 3 color channels already as stacking frames). +) -> gym.Env: + """Wraps `env` for new-API-stack-friendly RLlib Atari experiments. + + Note that we assume reward clipping is done outside the wrapper. + + Args: + env: The env object to wrap. + dim: Dimension to resize observations to (dim x dim). + frameskip: Whether to skip n frames and max over them (keep brightest pixels). + framestack: Whether to stack the last n (grayscaled) frames. Note that this + step happens after(!) a possible frameskip step, meaning that if + frameskip=4 and framestack=2, we would perform the following over this + trajectory: + actual env timesteps: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 -> ... + frameskip: ( max ) ( max ) ( max ) ( max ) + framestack: ( stack ) (stack ) + + Returns: + The wrapped gym.Env. + """ + # Time limit. + env = gym.wrappers.TimeLimit(env, max_episode_steps=108000) + # Grayscale + resize. + env = WarpFrame(env, dim=dim) + # Normalize the image. + env = NormalizedImageEnv(env) + # Frameskip: Take max over these n frames. 
+ if frameskip > 1: + assert env.spec is not None + env = MaxAndSkipEnv(env, skip=frameskip) + # Send n noop actions into env after reset to increase variance in the + # "start states" of the trajectories. These dummy steps are NOT included in the + # sampled data used for learning. + env = NoopResetEnv(env, noop_max=30) + # Each life is one episode. + env = EpisodicLifeEnv(env) + # Some envs only start playing after pressing fire. Unblock those. + if "FIRE" in env.unwrapped.get_action_meanings(): + env = FireResetEnv(env) + # Framestack. + if framestack: + env = FrameStack(env, k=framestack) + return env + + +@PublicAPI +def wrap_deepmind(env, dim=84, framestack=True, noframeskip=False): + """Configure environment for DeepMind-style Atari. + + Note that we assume reward clipping is done outside the wrapper. + + Args: + env: The env object to wrap. + dim: Dimension to resize observations to (dim x dim). + framestack: Whether to framestack observations. + """ + env = MonitorEnv(env) + env = NoopResetEnv(env, noop_max=30) + if env.spec is not None and noframeskip is True: + env = MaxAndSkipEnv(env, skip=4) + env = EpisodicLifeEnv(env) + if "FIRE" in env.unwrapped.get_action_meanings(): + env = FireResetEnv(env) + env = WarpFrame(env, dim) + # env = ClipRewardEnv(env) # reward clipping is handled by policy eval + # 4x image framestacking. 
+ if framestack is True: + env = FrameStack(env, 4) + return env diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/dm_control_wrapper.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/dm_control_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..8408bbf552ac2396a7332fe653c8f0f4fa669a10 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/dm_control_wrapper.py @@ -0,0 +1,220 @@ +""" +DeepMind Control Suite Wrapper directly sourced from: +https://github.com/denisyarats/dmc2gym + +MIT License + +Copyright (c) 2020 Denis Yarats + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" +from gymnasium import core, spaces + +try: + from dm_env import specs +except ImportError: + specs = None +try: + # Suppress MuJoCo warning (dm_control uses absl logging). 
+ import absl.logging + + absl.logging.set_verbosity("error") + from dm_control import suite +except (ImportError, OSError): + suite = None +import numpy as np + +from ray.rllib.utils.annotations import PublicAPI + + +def _spec_to_box(spec): + def extract_min_max(s): + assert s.dtype == np.float64 or s.dtype == np.float32 + dim = np.int_(np.prod(s.shape)) + if type(s) is specs.Array: + bound = np.inf * np.ones(dim, dtype=np.float32) + return -bound, bound + elif type(s) is specs.BoundedArray: + zeros = np.zeros(dim, dtype=np.float32) + return s.minimum + zeros, s.maximum + zeros + + mins, maxs = [], [] + for s in spec: + mn, mx = extract_min_max(s) + mins.append(mn) + maxs.append(mx) + low = np.concatenate(mins, axis=0) + high = np.concatenate(maxs, axis=0) + assert low.shape == high.shape + return spaces.Box(low, high, dtype=np.float32) + + +def _flatten_obs(obs): + obs_pieces = [] + for v in obs.values(): + flat = np.array([v]) if np.isscalar(v) else v.ravel() + obs_pieces.append(flat) + return np.concatenate(obs_pieces, axis=0) + + +@PublicAPI +class DMCEnv(core.Env): + def __init__( + self, + domain_name, + task_name, + task_kwargs=None, + visualize_reward=False, + from_pixels=False, + height=64, + width=64, + camera_id=0, + frame_skip=2, + environment_kwargs=None, + channels_first=True, + preprocess=True, + ): + self._from_pixels = from_pixels + self._height = height + self._width = width + self._camera_id = camera_id + self._frame_skip = frame_skip + self._channels_first = channels_first + self.preprocess = preprocess + + if specs is None: + raise RuntimeError( + ( + "The `specs` module from `dm_env` was not imported. Make sure " + "`dm_env` is installed and visible in the current python " + "environment." + ) + ) + if suite is None: + raise RuntimeError( + ( + "The `suite` module from `dm_control` was not imported. Make " + "sure `dm_control` is installed and visible in the current " + "python enviornment." 
+ ) + ) + + # create task + self._env = suite.load( + domain_name=domain_name, + task_name=task_name, + task_kwargs=task_kwargs, + visualize_reward=visualize_reward, + environment_kwargs=environment_kwargs, + ) + + # true and normalized action spaces + self._true_action_space = _spec_to_box([self._env.action_spec()]) + self._norm_action_space = spaces.Box( + low=-1.0, high=1.0, shape=self._true_action_space.shape, dtype=np.float32 + ) + + # create observation space + if from_pixels: + shape = [3, height, width] if channels_first else [height, width, 3] + self._observation_space = spaces.Box( + low=0, high=255, shape=shape, dtype=np.uint8 + ) + if preprocess: + self._observation_space = spaces.Box( + low=-0.5, high=0.5, shape=shape, dtype=np.float32 + ) + else: + self._observation_space = _spec_to_box( + self._env.observation_spec().values() + ) + + self._state_space = _spec_to_box(self._env.observation_spec().values()) + + self.current_state = None + + def __getattr__(self, name): + return getattr(self._env, name) + + def _get_obs(self, time_step): + if self._from_pixels: + obs = self.render( + height=self._height, width=self._width, camera_id=self._camera_id + ) + if self._channels_first: + obs = obs.transpose(2, 0, 1).copy() + if self.preprocess: + obs = obs / 255.0 - 0.5 + else: + obs = _flatten_obs(time_step.observation) + return obs.astype(np.float32) + + def _convert_action(self, action): + action = action.astype(np.float64) + true_delta = self._true_action_space.high - self._true_action_space.low + norm_delta = self._norm_action_space.high - self._norm_action_space.low + action = (action - self._norm_action_space.low) / norm_delta + action = action * true_delta + self._true_action_space.low + action = action.astype(np.float32) + return action + + @property + def observation_space(self): + return self._observation_space + + @property + def state_space(self): + return self._state_space + + @property + def action_space(self): + return self._norm_action_space + 
+ def step(self, action): + assert self._norm_action_space.contains(action) + action = self._convert_action(action) + assert self._true_action_space.contains(action) + reward = 0.0 + extra = {"internal_state": self._env.physics.get_state().copy()} + + terminated = truncated = False + for _ in range(self._frame_skip): + time_step = self._env.step(action) + reward += time_step.reward or 0.0 + terminated = False + truncated = time_step.last() + if terminated or truncated: + break + obs = self._get_obs(time_step) + self.current_state = _flatten_obs(time_step.observation) + extra["discount"] = time_step.discount + return obs, reward, terminated, truncated, extra + + def reset(self, *, seed=None, options=None): + time_step = self._env.reset() + self.current_state = _flatten_obs(time_step.observation) + obs = self._get_obs(time_step) + return obs, {} + + def render(self, mode="rgb_array", height=None, width=None, camera_id=0): + assert mode == "rgb_array", "only support for rgb_array mode" + height = height or self._height + width = width or self._width + camera_id = camera_id or self._camera_id + return self._env.physics.render(height=height, width=width, camera_id=camera_id) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/dm_env_wrapper.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/dm_env_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..435251df216b2159a27ad2f28981c5d85b5ab0d0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/dm_env_wrapper.py @@ -0,0 +1,98 @@ +import gymnasium as gym +from gymnasium import spaces + +import numpy as np + +try: + from dm_env import specs +except ImportError: + specs = None + +from ray.rllib.utils.annotations import PublicAPI + + +def _convert_spec_to_space(spec): + if isinstance(spec, dict): + return spaces.Dict({k: _convert_spec_to_space(v) for k, v in spec.items()}) + if isinstance(spec, specs.DiscreteArray): + return 
spaces.Discrete(spec.num_values) + elif isinstance(spec, specs.BoundedArray): + return spaces.Box( + low=np.asscalar(spec.minimum), + high=np.asscalar(spec.maximum), + shape=spec.shape, + dtype=spec.dtype, + ) + elif isinstance(spec, specs.Array): + return spaces.Box( + low=-float("inf"), high=float("inf"), shape=spec.shape, dtype=spec.dtype + ) + + raise NotImplementedError( + ( + "Could not convert `Array` spec of type {} to Gym space. " + "Attempted to convert: {}" + ).format(type(spec), spec) + ) + + +@PublicAPI +class DMEnv(gym.Env): + """A `gym.Env` wrapper for the `dm_env` API.""" + + metadata = {"render.modes": ["rgb_array"]} + + def __init__(self, dm_env): + super(DMEnv, self).__init__() + self._env = dm_env + self._prev_obs = None + + if specs is None: + raise RuntimeError( + ( + "The `specs` module from `dm_env` was not imported. Make sure " + "`dm_env` is installed and visible in the current python " + "environment." + ) + ) + + def step(self, action): + ts = self._env.step(action) + + reward = ts.reward + if reward is None: + reward = 0.0 + + return ts.observation, reward, ts.last(), False, {"discount": ts.discount} + + def reset(self, *, seed=None, options=None): + ts = self._env.reset() + return ts.observation, {} + + def render(self, mode="rgb_array"): + if self._prev_obs is None: + raise ValueError( + "Environment not started. Make sure to reset before rendering." 
+ ) + + if mode == "rgb_array": + return self._prev_obs + else: + raise NotImplementedError("Render mode '{}' is not supported.".format(mode)) + + @property + def action_space(self): + spec = self._env.action_spec() + return _convert_spec_to_space(spec) + + @property + def observation_space(self): + spec = self._env.observation_spec() + return _convert_spec_to_space(spec) + + @property + def reward_range(self): + spec = self._env.reward_spec() + if isinstance(spec, specs.BoundedArray): + return spec.minimum, spec.maximum + return -float("inf"), float("inf") diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/group_agents_wrapper.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/group_agents_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..c9bb592a79d0ae39870926dac2d6406cfbb7e78f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/group_agents_wrapper.py @@ -0,0 +1,157 @@ +from collections import OrderedDict +import gymnasium as gym +from typing import Dict, List, Optional + +from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.typing import AgentID + +# info key for the individual rewards of an agent, for example: +# info: { +# group_1: { +# _group_rewards: [5, -1, 1], # 3 agents in this group +# } +# } +GROUP_REWARDS = "_group_rewards" + +# info key for the individual infos of an agent, for example: +# info: { +# group_1: { +# _group_infos: [{"foo": ...}, {}], # 2 agents in this group +# } +# } +GROUP_INFO = "_group_info" + + +@DeveloperAPI +class GroupAgentsWrapper(MultiAgentEnv): + """Wraps a MultiAgentEnv environment with agents grouped as specified. + + See multi_agent_env.py for the specification of groups. + + This API is experimental. 
+ """ + + def __init__( + self, + env: MultiAgentEnv, + groups: Dict[str, List[AgentID]], + obs_space: Optional[gym.Space] = None, + act_space: Optional[gym.Space] = None, + ): + """Wrap an existing MultiAgentEnv to group agent ID together. + + See `MultiAgentEnv.with_agent_groups()` for more detailed usage info. + + Args: + env: The env to wrap and whose agent IDs to group into new agents. + groups: Mapping from group id to a list of the agent ids + of group members. If an agent id is not present in any group + value, it will be left ungrouped. The group id becomes a new agent ID + in the final environment. + obs_space: Optional observation space for the grouped + env. Must be a tuple space. If not provided, will infer this to be a + Tuple of n individual agents spaces (n=num agents in a group). + act_space: Optional action space for the grouped env. + Must be a tuple space. If not provided, will infer this to be a Tuple + of n individual agents spaces (n=num agents in a group). + """ + super().__init__() + self.env = env + self.groups = groups + self.agent_id_to_group = {} + for group_id, agent_ids in groups.items(): + for agent_id in agent_ids: + if agent_id in self.agent_id_to_group: + raise ValueError( + "Agent id {} is in multiple groups".format(agent_id) + ) + self.agent_id_to_group[agent_id] = group_id + if obs_space is not None: + self.observation_space = obs_space + if act_space is not None: + self.action_space = act_space + for group_id in groups.keys(): + self._agent_ids.add(group_id) + + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): + obs, info = self.env.reset(seed=seed, options=options) + + return ( + self._group_items(obs), + self._group_items( + info, + agg_fn=lambda gvals: {GROUP_INFO: list(gvals.values())}, + ), + ) + + def step(self, action_dict): + # Ungroup and send actions. 
+ action_dict = self._ungroup_items(action_dict) + obs, rewards, terminateds, truncateds, infos = self.env.step(action_dict) + + # Apply grouping transforms to the env outputs + obs = self._group_items(obs) + rewards = self._group_items(rewards, agg_fn=lambda gvals: list(gvals.values())) + # Only if all of the agents are terminated, the group is terminated as well. + terminateds = self._group_items( + terminateds, agg_fn=lambda gvals: all(gvals.values()) + ) + # If all of the agents are truncated, the group is truncated as well. + truncateds = self._group_items( + truncateds, + agg_fn=lambda gvals: all(gvals.values()), + ) + infos = self._group_items( + infos, agg_fn=lambda gvals: {GROUP_INFO: list(gvals.values())} + ) + + # Aggregate rewards, but preserve the original values in infos. + for agent_id, rew in rewards.items(): + if isinstance(rew, list): + rewards[agent_id] = sum(rew) + if agent_id not in infos: + infos[agent_id] = {} + infos[agent_id][GROUP_REWARDS] = rew + + return obs, rewards, terminateds, truncateds, infos + + def _ungroup_items(self, items): + out = {} + for agent_id, value in items.items(): + if agent_id in self.groups: + assert len(value) == len(self.groups[agent_id]), ( + agent_id, + value, + self.groups, + ) + for a, v in zip(self.groups[agent_id], value): + out[a] = v + else: + out[agent_id] = value + return out + + def _group_items(self, items, agg_fn=None): + if agg_fn is None: + agg_fn = lambda gvals: list(gvals.values()) # noqa: E731 + + grouped_items = {} + for agent_id, item in items.items(): + if agent_id in self.agent_id_to_group: + group_id = self.agent_id_to_group[agent_id] + if group_id in grouped_items: + continue # already added + group_out = OrderedDict() + for a in self.groups[group_id]: + if a in items: + group_out[a] = items[a] + else: + raise ValueError( + "Missing member of group {}: {}: {}".format( + group_id, a, items + ) + ) + grouped_items[group_id] = agg_fn(group_out) + else: + grouped_items[agent_id] = item + 
return grouped_items diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/multi_agent_env_compatibility.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/multi_agent_env_compatibility.py new file mode 100644 index 0000000000000000000000000000000000000000..fc8efeda08346e5486f5a51b1fc69b614a7fffab --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/multi_agent_env_compatibility.py @@ -0,0 +1,73 @@ +from typing import Optional, Tuple + +from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.utils.typing import MultiAgentDict + + +class MultiAgentEnvCompatibility(MultiAgentEnv): + """A wrapper converting MultiAgentEnv from old gym API to the new one. + + "Old API" refers to step() method returning (observation, reward, done, info), + and reset() only retuning the observation. + "New API" refers to step() method returning (observation, reward, terminated, + truncated, info) and reset() returning (observation, info). + + Known limitations: + - Environments that use `self.np_random` might not work as expected. + """ + + def __init__(self, old_env, render_mode: Optional[str] = None): + """A wrapper which converts old-style envs to valid modern envs. + + Some information may be lost in the conversion, so we recommend updating your + environment. + + Args: + old_env: The old MultiAgentEnv to wrap. Implemented with the old API. + render_mode: The render mode to use when rendering the environment, + passed automatically to `env.render()`. 
+ """ + super().__init__() + + self.metadata = getattr(old_env, "metadata", {"render_modes": []}) + self.render_mode = render_mode + self.reward_range = getattr(old_env, "reward_range", None) + self.spec = getattr(old_env, "spec", None) + self.env = old_env + + self.observation_space = old_env.observation_space + self.action_space = old_env.action_space + + def reset( + self, *, seed: Optional[int] = None, options: Optional[dict] = None + ) -> Tuple[MultiAgentDict, MultiAgentDict]: + # Use old `seed()` method. + if seed is not None: + self.env.seed(seed) + # Options are ignored + + if self.render_mode == "human": + self.render() + + obs = self.env.reset() + infos = {k: {} for k in obs.keys()} + return obs, infos + + def step( + self, action + ) -> Tuple[ + MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict + ]: + obs, rewards, terminateds, infos = self.env.step(action) + + # Truncated should always be False by default. + truncateds = {k: False for k in terminateds.keys()} + + return obs, rewards, terminateds, truncateds, infos + + def render(self): + # Use the old `render()` API, where we have to pass in the mode to each call. 
+ return self.env.render(mode=self.render_mode) + + def close(self): + self.env.close() diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/open_spiel.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/open_spiel.py new file mode 100644 index 0000000000000000000000000000000000000000..c46c7530098800b43a10d95ee348bb40494aa2bd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/open_spiel.py @@ -0,0 +1,130 @@ +from typing import Optional + +import numpy as np +import gymnasium as gym + +from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.env.utils import try_import_pyspiel + +pyspiel = try_import_pyspiel(error=True) + + +class OpenSpielEnv(MultiAgentEnv): + def __init__(self, env): + super().__init__() + self.env = env + self.agents = self.possible_agents = list(range(self.env.num_players())) + # Store the open-spiel game type. + self.type = self.env.get_type() + # Stores the current open-spiel game state. + self.state = None + + self.observation_space = gym.spaces.Dict( + { + aid: gym.spaces.Box( + float("-inf"), + float("inf"), + (self.env.observation_tensor_size(),), + dtype=np.float32, + ) + for aid in self.possible_agents + } + ) + self.action_space = gym.spaces.Dict( + { + aid: gym.spaces.Discrete(self.env.num_distinct_actions()) + for aid in self.possible_agents + } + ) + + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): + self.state = self.env.new_initial_state() + return self._get_obs(), {} + + def step(self, action): + # Before applying action(s), there could be chance nodes. + # E.g. if env has to figure out, which agent's action should get + # resolved first in a simultaneous node. 
+ self._solve_chance_nodes() + penalties = {} + + # Sequential game: + if str(self.type.dynamics) == "Dynamics.SEQUENTIAL": + curr_player = self.state.current_player() + assert curr_player in action + try: + self.state.apply_action(action[curr_player]) + # TODO: (sven) resolve this hack by publishing legal actions + # with each step. + except pyspiel.SpielError: + self.state.apply_action(np.random.choice(self.state.legal_actions())) + penalties[curr_player] = -0.1 + + # Compile rewards dict. + rewards = {ag: r for ag, r in enumerate(self.state.returns())} + # Simultaneous game. + else: + assert self.state.current_player() == -2 + # Apparently, this works, even if one or more actions are invalid. + self.state.apply_actions([action[ag] for ag in range(self.num_agents)]) + + # Now that we have applied all actions, get the next obs. + obs = self._get_obs() + + # Compile rewards dict and add the accumulated penalties + # (for taking invalid actions). + rewards = {ag: r for ag, r in enumerate(self.state.returns())} + for ag, penalty in penalties.items(): + rewards[ag] += penalty + + # Are we done? + is_terminated = self.state.is_terminal() + terminateds = dict( + {ag: is_terminated for ag in range(self.num_agents)}, + **{"__all__": is_terminated} + ) + truncateds = dict( + {ag: False for ag in range(self.num_agents)}, **{"__all__": False} + ) + + return obs, rewards, terminateds, truncateds, {} + + def render(self, mode=None) -> None: + if mode == "human": + print(self.state) + + def _get_obs(self): + # Before calculating an observation, there could be chance nodes + # (that may have an effect on the actual observations). + # E.g. After reset, figure out initial (random) positions of the + # agents. 
+ self._solve_chance_nodes() + + if self.state.is_terminal(): + return {} + + # Sequential game: + if str(self.type.dynamics) == "Dynamics.SEQUENTIAL": + curr_player = self.state.current_player() + return { + curr_player: np.reshape(self.state.observation_tensor(), [-1]).astype( + np.float32 + ) + } + # Simultaneous game. + else: + assert self.state.current_player() == -2 + return { + ag: np.reshape(self.state.observation_tensor(ag), [-1]).astype( + np.float32 + ) + for ag in range(self.num_agents) + } + + def _solve_chance_nodes(self): + # Chance node(s): Sample a (non-player) action and apply. + while self.state.is_chance_node(): + assert self.state.current_player() == -1 + actions, probs = zip(*self.state.chance_outcomes()) + action = np.random.choice(actions, p=probs) + self.state.apply_action(action) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/pettingzoo_env.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/pettingzoo_env.py new file mode 100644 index 0000000000000000000000000000000000000000..f7ee4cf4d6b2470ba434a5f91537a5e1ff732c53 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/pettingzoo_env.py @@ -0,0 +1,214 @@ +from typing import Optional + +import gymnasium as gym + +from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.utils.annotations import PublicAPI + + +@PublicAPI +class PettingZooEnv(MultiAgentEnv): + """An interface to the PettingZoo MARL environment library. + + See: https://github.com/Farama-Foundation/PettingZoo + + Inherits from MultiAgentEnv and exposes a given AEC + (actor-environment-cycle) game from the PettingZoo project via the + MultiAgentEnv public API. + + Note that the wrapper has the following important limitation: + + Environments are positive sum games (-> Agents are expected to cooperate + to maximize reward). This isn't a hard restriction, it just that + standard algorithms aren't expected to work well in highly competitive + games. 
+ + Also note that the earlier existing restriction of all agents having the same + observation- and action spaces has been lifted. Different agents can now have + different spaces and the entire environment's e.g. `self.action_space` is a Dict + mapping agent IDs to individual agents' spaces. Same for `self.observation_space`. + + .. testcode:: + :skipif: True + + from pettingzoo.butterfly import prison_v3 + from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv + env = PettingZooEnv(prison_v3.env()) + obs, infos = env.reset() + # only returns the observation for the agent which should be stepping + print(obs) + + .. testoutput:: + + { + 'prisoner_0': array([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0], + ..., + [0, 0, 0], + [0, 0, 0], + [0, 0, 0]]], dtype=uint8) + } + + .. testcode:: + :skipif: True + + obs, rewards, terminateds, truncateds, infos = env.step({ + "prisoner_0": 1 + }) + # only returns the observation, reward, info, etc, for + # the agent who's turn is next. + print(obs) + + .. testoutput:: + + { + 'prisoner_1': array([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0], + ..., + [0, 0, 0], + [0, 0, 0], + [0, 0, 0]]], dtype=uint8) + } + + .. testcode:: + :skipif: True + + print(rewards) + + .. testoutput:: + + { + 'prisoner_1': 0 + } + + .. testcode:: + :skipif: True + + print(terminateds) + + .. testoutput:: + + { + 'prisoner_1': False, '__all__': False + } + + .. testcode:: + :skipif: True + + print(truncateds) + + .. testoutput:: + + { + 'prisoner_1': False, '__all__': False + } + + .. testcode:: + :skipif: True + + print(infos) + + .. 
testoutput:: + + { + 'prisoner_1': {'map_tuple': (1, 0)} + } + """ + + def __init__(self, env): + super().__init__() + self.env = env + env.reset() + + self._agent_ids = set(self.env.agents) + + self.observation_space = gym.spaces.Dict( + {aid: self.env.observation_space(aid) for aid in self._agent_ids} + ) + self.action_space = gym.spaces.Dict( + {aid: self.env.action_space(aid) for aid in self._agent_ids} + ) + + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): + info = self.env.reset(seed=seed, options=options) + return ( + {self.env.agent_selection: self.env.observe(self.env.agent_selection)}, + info or {}, + ) + + def step(self, action): + self.env.step(action[self.env.agent_selection]) + obs_d = {} + rew_d = {} + terminated_d = {} + truncated_d = {} + info_d = {} + while self.env.agents: + obs, rew, terminated, truncated, info = self.env.last() + agent_id = self.env.agent_selection + obs_d[agent_id] = obs + rew_d[agent_id] = rew + terminated_d[agent_id] = terminated + truncated_d[agent_id] = truncated + info_d[agent_id] = info + if ( + self.env.terminations[self.env.agent_selection] + or self.env.truncations[self.env.agent_selection] + ): + self.env.step(None) + else: + break + + all_gone = not self.env.agents + terminated_d["__all__"] = all_gone and all(terminated_d.values()) + truncated_d["__all__"] = all_gone and all(truncated_d.values()) + + return obs_d, rew_d, terminated_d, truncated_d, info_d + + def close(self): + self.env.close() + + def render(self): + return self.env.render(self.render_mode) + + @property + def get_sub_environments(self): + return self.env.unwrapped + + +@PublicAPI +class ParallelPettingZooEnv(MultiAgentEnv): + def __init__(self, env): + super().__init__() + self.par_env = env + self.par_env.reset() + self._agent_ids = set(self.par_env.agents) + + self.observation_space = gym.spaces.Dict( + {aid: self.par_env.observation_space(aid) for aid in self._agent_ids} + ) + self.action_space = gym.spaces.Dict( 
+ {aid: self.par_env.action_space(aid) for aid in self._agent_ids} + ) + + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): + obs, info = self.par_env.reset(seed=seed, options=options) + return obs, info or {} + + def step(self, action_dict): + obss, rews, terminateds, truncateds, infos = self.par_env.step(action_dict) + terminateds["__all__"] = all(terminateds.values()) + truncateds["__all__"] = all(truncateds.values()) + return obss, rews, terminateds, truncateds, infos + + def close(self): + self.par_env.close() + + def render(self): + return self.par_env.render(self.render_mode) + + @property + def get_sub_environments(self): + return self.par_env.unwrapped diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/unity3d_env.py b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/unity3d_env.py new file mode 100644 index 0000000000000000000000000000000000000000..45f0f910af923e2dc44598bdbee31b9605fc9bd4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/unity3d_env.py @@ -0,0 +1,381 @@ +from gymnasium.spaces import Box, MultiDiscrete, Tuple as TupleSpace +import logging +import numpy as np +import random +import time +from typing import Callable, Optional, Tuple + +from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.policy.policy import PolicySpec +from ray.rllib.utils.annotations import PublicAPI +from ray.rllib.utils.typing import MultiAgentDict, PolicyID, AgentID + +logger = logging.getLogger(__name__) + + +@PublicAPI +class Unity3DEnv(MultiAgentEnv): + """A MultiAgentEnv representing a single Unity3D game instance. 
+ + For an example on how to use this Env with a running Unity3D editor + or with a compiled game, see: + `rllib/examples/unity3d_env_local.py` + For an example on how to use it inside a Unity game client, which + connects to an RLlib Policy server, see: + `rllib/examples/envs/external_envs/unity3d_[client|server].py` + + Supports all Unity3D (MLAgents) examples, multi- or single-agent and + gets converted automatically into an ExternalMultiAgentEnv, when used + inside an RLlib PolicyClient for cloud/distributed training of Unity games. + """ + + # Default base port when connecting directly to the Editor + _BASE_PORT_EDITOR = 5004 + # Default base port when connecting to a compiled environment + _BASE_PORT_ENVIRONMENT = 5005 + # The worker_id for each environment instance + _WORKER_ID = 0 + + def __init__( + self, + file_name: str = None, + port: Optional[int] = None, + seed: int = 0, + no_graphics: bool = False, + timeout_wait: int = 300, + episode_horizon: int = 1000, + ): + """Initializes a Unity3DEnv object. + + Args: + file_name (Optional[str]): Name of the Unity game binary. + If None, will assume a locally running Unity3D editor + to be used, instead. + port (Optional[int]): Port number to connect to Unity environment. + seed: A random seed value to use for the Unity3D game. + no_graphics: Whether to run the Unity3D simulator in + no-graphics mode. Default: False. + timeout_wait: Time (in seconds) to wait for connection from + the Unity3D instance. + episode_horizon: A hard horizon to abide to. After at most + this many steps (per-agent episode `step()` calls), the + Unity3D game is reset and will start again (finishing the + multi-agent episode that the game represents). + Note: The game itself may contain its own episode length + limits, which are always obeyed (on top of this value here). 
+ """ + super().__init__() + + if file_name is None: + print( + "No game binary provided, will use a running Unity editor " + "instead.\nMake sure you are pressing the Play (|>) button in " + "your editor to start." + ) + + import mlagents_envs + from mlagents_envs.environment import UnityEnvironment + + # Try connecting to the Unity3D game instance. If a port is blocked + port_ = None + while True: + # Sleep for random time to allow for concurrent startup of many + # environments (num_env_runners >> 1). Otherwise, would lead to port + # conflicts sometimes. + if port_ is not None: + time.sleep(random.randint(1, 10)) + port_ = port or ( + self._BASE_PORT_ENVIRONMENT if file_name else self._BASE_PORT_EDITOR + ) + # cache the worker_id and + # increase it for the next environment + worker_id_ = Unity3DEnv._WORKER_ID if file_name else 0 + Unity3DEnv._WORKER_ID += 1 + try: + self.unity_env = UnityEnvironment( + file_name=file_name, + worker_id=worker_id_, + base_port=port_, + seed=seed, + no_graphics=no_graphics, + timeout_wait=timeout_wait, + ) + print("Created UnityEnvironment for port {}".format(port_ + worker_id_)) + except mlagents_envs.exception.UnityWorkerInUseException: + pass + else: + break + + # ML-Agents API version. + self.api_version = self.unity_env.API_VERSION.split(".") + self.api_version = [int(s) for s in self.api_version] + + # Reset entire env every this number of step calls. + self.episode_horizon = episode_horizon + # Keep track of how many times we have called `step` so far. + self.episode_timesteps = 0 + + def step( + self, action_dict: MultiAgentDict + ) -> Tuple[ + MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict + ]: + """Performs one multi-agent step through the game. + + Args: + action_dict: Multi-agent action dict with: + keys=agent identifier consisting of + [MLagents behavior name, e.g. 
"Goalie?team=1"] + "_" + + [Agent index, a unique MLAgent-assigned index per single agent] + + Returns: + tuple: + - obs: Multi-agent observation dict. + Only those observations for which to get new actions are + returned. + - rewards: Rewards dict matching `obs`. + - dones: Done dict with only an __all__ multi-agent entry in + it. __all__=True, if episode is done for all agents. + - infos: An (empty) info dict. + """ + from mlagents_envs.base_env import ActionTuple + + # Set only the required actions (from the DecisionSteps) in Unity3D. + all_agents = [] + for behavior_name in self.unity_env.behavior_specs: + # New ML-Agents API: Set all agents actions at the same time + # via an ActionTuple. Since API v1.4.0. + if self.api_version[0] > 1 or ( + self.api_version[0] == 1 and self.api_version[1] >= 4 + ): + actions = [] + for agent_id in self.unity_env.get_steps(behavior_name)[0].agent_id: + key = behavior_name + "_{}".format(agent_id) + all_agents.append(key) + actions.append(action_dict[key]) + if actions: + if actions[0].dtype == np.float32: + action_tuple = ActionTuple(continuous=np.array(actions)) + else: + action_tuple = ActionTuple(discrete=np.array(actions)) + self.unity_env.set_actions(behavior_name, action_tuple) + # Old behavior: Do not use an ActionTuple and set each agent's + # action individually. + else: + for agent_id in self.unity_env.get_steps(behavior_name)[ + 0 + ].agent_id_to_index.keys(): + key = behavior_name + "_{}".format(agent_id) + all_agents.append(key) + self.unity_env.set_action_for_agent( + behavior_name, agent_id, action_dict[key] + ) + # Do the step. + self.unity_env.step() + + obs, rewards, terminateds, truncateds, infos = self._get_step_results() + + # Global horizon reached? -> Return __all__ truncated=True, so user + # can reset. Set all agents' individual `truncated` to True as well. 
+ self.episode_timesteps += 1 + if self.episode_timesteps > self.episode_horizon: + return ( + obs, + rewards, + terminateds, + dict({"__all__": True}, **{agent_id: True for agent_id in all_agents}), + infos, + ) + + return obs, rewards, terminateds, truncateds, infos + + def reset( + self, *, seed=None, options=None + ) -> Tuple[MultiAgentDict, MultiAgentDict]: + """Resets the entire Unity3D scene (a single multi-agent episode).""" + self.episode_timesteps = 0 + self.unity_env.reset() + obs, _, _, _, infos = self._get_step_results() + return obs, infos + + def _get_step_results(self): + """Collects those agents' obs/rewards that have to act in next `step`. + + Returns: + Tuple: + obs: Multi-agent observation dict. + Only those observations for which to get new actions are + returned. + rewards: Rewards dict matching `obs`. + dones: Done dict with only an __all__ multi-agent entry in it. + __all__=True, if episode is done for all agents. + infos: An (empty) info dict. + """ + obs = {} + rewards = {} + infos = {} + for behavior_name in self.unity_env.behavior_specs: + decision_steps, terminal_steps = self.unity_env.get_steps(behavior_name) + # Important: Only update those sub-envs that are currently + # available within _env_state. + # Loop through all envs ("agents") and fill in, whatever + # information we have. + for agent_id, idx in decision_steps.agent_id_to_index.items(): + key = behavior_name + "_{}".format(agent_id) + os = tuple(o[idx] for o in decision_steps.obs) + os = os[0] if len(os) == 1 else os + obs[key] = os + rewards[key] = ( + decision_steps.reward[idx] + decision_steps.group_reward[idx] + ) + for agent_id, idx in terminal_steps.agent_id_to_index.items(): + key = behavior_name + "_{}".format(agent_id) + # Only overwrite rewards (last reward in episode), b/c obs + # here is the last obs (which doesn't matter anyways). + # Unless key does not exist in obs. 
+ if key not in obs: + os = tuple(o[idx] for o in terminal_steps.obs) + obs[key] = os = os[0] if len(os) == 1 else os + rewards[key] = ( + terminal_steps.reward[idx] + terminal_steps.group_reward[idx] + ) + + # Only use dones if all agents are done, then we should do a reset. + return obs, rewards, {"__all__": False}, {"__all__": False}, infos + + @staticmethod + def get_policy_configs_for_game( + game_name: str, + ) -> Tuple[dict, Callable[[AgentID], PolicyID]]: + + # The RLlib server must know about the Spaces that the Client will be + # using inside Unity3D, up-front. + obs_spaces = { + # 3DBall. + "3DBall": Box(float("-inf"), float("inf"), (8,)), + # 3DBallHard. + "3DBallHard": Box(float("-inf"), float("inf"), (45,)), + # GridFoodCollector + "GridFoodCollector": Box(float("-inf"), float("inf"), (40, 40, 6)), + # Pyramids. + "Pyramids": TupleSpace( + [ + Box(float("-inf"), float("inf"), (56,)), + Box(float("-inf"), float("inf"), (56,)), + Box(float("-inf"), float("inf"), (56,)), + Box(float("-inf"), float("inf"), (4,)), + ] + ), + # SoccerTwos. + "SoccerPlayer": TupleSpace( + [ + Box(-1.0, 1.0, (264,)), + Box(-1.0, 1.0, (72,)), + ] + ), + # SoccerStrikersVsGoalie. + "Goalie": Box(float("-inf"), float("inf"), (738,)), + "Striker": TupleSpace( + [ + Box(float("-inf"), float("inf"), (231,)), + Box(float("-inf"), float("inf"), (63,)), + ] + ), + # Sorter. + "Sorter": TupleSpace( + [ + Box( + float("-inf"), + float("inf"), + ( + 20, + 23, + ), + ), + Box(float("-inf"), float("inf"), (10,)), + Box(float("-inf"), float("inf"), (8,)), + ] + ), + # Tennis. + "Tennis": Box(float("-inf"), float("inf"), (27,)), + # VisualHallway. + "VisualHallway": Box(float("-inf"), float("inf"), (84, 84, 3)), + # Walker. + "Walker": Box(float("-inf"), float("inf"), (212,)), + # FoodCollector. + "FoodCollector": TupleSpace( + [ + Box(float("-inf"), float("inf"), (49,)), + Box(float("-inf"), float("inf"), (4,)), + ] + ), + } + action_spaces = { + # 3DBall. 
+ "3DBall": Box(-1.0, 1.0, (2,), dtype=np.float32), + # 3DBallHard. + "3DBallHard": Box(-1.0, 1.0, (2,), dtype=np.float32), + # GridFoodCollector. + "GridFoodCollector": MultiDiscrete([3, 3, 3, 2]), + # Pyramids. + "Pyramids": MultiDiscrete([5]), + # SoccerStrikersVsGoalie. + "Goalie": MultiDiscrete([3, 3, 3]), + "Striker": MultiDiscrete([3, 3, 3]), + # SoccerTwos. + "SoccerPlayer": MultiDiscrete([3, 3, 3]), + # Sorter. + "Sorter": MultiDiscrete([3, 3, 3]), + # Tennis. + "Tennis": Box(-1.0, 1.0, (3,)), + # VisualHallway. + "VisualHallway": MultiDiscrete([5]), + # Walker. + "Walker": Box(-1.0, 1.0, (39,)), + # FoodCollector. + "FoodCollector": MultiDiscrete([3, 3, 3, 2]), + } + + # Policies (Unity: "behaviors") and agent-to-policy mapping fns. + if game_name == "SoccerStrikersVsGoalie": + policies = { + "Goalie": PolicySpec( + observation_space=obs_spaces["Goalie"], + action_space=action_spaces["Goalie"], + ), + "Striker": PolicySpec( + observation_space=obs_spaces["Striker"], + action_space=action_spaces["Striker"], + ), + } + + def policy_mapping_fn(agent_id, episode, worker, **kwargs): + return "Striker" if "Striker" in agent_id else "Goalie" + + elif game_name == "SoccerTwos": + policies = { + "PurplePlayer": PolicySpec( + observation_space=obs_spaces["SoccerPlayer"], + action_space=action_spaces["SoccerPlayer"], + ), + "BluePlayer": PolicySpec( + observation_space=obs_spaces["SoccerPlayer"], + action_space=action_spaces["SoccerPlayer"], + ), + } + + def policy_mapping_fn(agent_id, episode, worker, **kwargs): + return "BluePlayer" if "1_" in agent_id else "PurplePlayer" + + else: + policies = { + game_name: PolicySpec( + observation_space=obs_spaces[game_name], + action_space=action_spaces[game_name], + ), + } + + def policy_mapping_fn(agent_id, episode, worker, **kwargs): + return game_name + + return policies, policy_mapping_fn diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/actions/__pycache__/nested_action_spaces.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/rllib/examples/actions/__pycache__/nested_action_spaces.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8108cb3d50fbe83382ece7a891b44d1d2faef52a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/examples/actions/__pycache__/nested_action_spaces.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/debugging/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/examples/debugging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/debugging/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/examples/debugging/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2c9edb18150e801a0b0d6bf1e5e7d82fdaa2592 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/examples/debugging/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/debugging/__pycache__/deterministic_training.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/examples/debugging/__pycache__/deterministic_training.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..174815f6fbb0361d87a3d1f9a4867454eea11678 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/examples/debugging/__pycache__/deterministic_training.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/debugging/deterministic_training.py b/.venv/lib/python3.11/site-packages/ray/rllib/examples/debugging/deterministic_training.py new file mode 100644 index 0000000000000000000000000000000000000000..9e7a8960c56e4cabdfefdafe77e8b8218fb2f05b --- /dev/null +++ 
b/.venv/lib/python3.11/site-packages/ray/rllib/examples/debugging/deterministic_training.py @@ -0,0 +1,111 @@ +# @OldAPIStack + +""" +Example of a fully deterministic, repeatable RLlib train run using +the "seed" config key. +""" +import argparse + +import ray +from ray import air, tune +from ray.air.constants import TRAINING_ITERATION +from ray.rllib.core import DEFAULT_MODULE_ID +from ray.rllib.examples.envs.classes.env_using_remote_actor import ( + CartPoleWithRemoteParamServer, + ParameterStorage, +) +from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO +from ray.rllib.utils.test_utils import check +from ray.tune.registry import get_trainable_cls + +parser = argparse.ArgumentParser() +parser.add_argument("--run", type=str, default="PPO") +parser.add_argument("--framework", choices=["tf2", "tf", "torch"], default="torch") +parser.add_argument("--seed", type=int, default=42) +parser.add_argument("--as-test", action="store_true") +parser.add_argument("--stop-iters", type=int, default=2) +parser.add_argument("--num-gpus", type=float, default=0) +parser.add_argument("--num-gpus-per-env-runner", type=float, default=0) + +if __name__ == "__main__": + args = parser.parse_args() + + param_storage = ParameterStorage.options(name="param-server").remote() + + config = ( + get_trainable_cls(args.run) + .get_default_config() + .api_stack( + enable_rl_module_and_learner=False, + enable_env_runner_and_connector_v2=False, + ) + .environment( + CartPoleWithRemoteParamServer, + env_config={"param_server": "param-server"}, + ) + .framework(args.framework) + .env_runners( + num_env_runners=1, + num_envs_per_env_runner=2, + rollout_fragment_length=50, + num_gpus_per_env_runner=args.num_gpus_per_env_runner, + ) + # The new Learner API. + .learners( + num_learners=int(args.num_gpus), + num_gpus_per_learner=int(args.num_gpus > 0), + ) + # Old gpu-training API. 
+ .resources( + num_gpus=args.num_gpus, + ) + # Make sure every environment gets a fixed seed. + .debugging(seed=args.seed) + .training( + train_batch_size=100, + ) + ) + + if args.run == "PPO": + # Simplify to run this example script faster. + config.training(minibatch_size=10, num_epochs=5) + + stop = {TRAINING_ITERATION: args.stop_iters} + + results1 = tune.Tuner( + args.run, + param_space=config.to_dict(), + run_config=air.RunConfig( + stop=stop, verbose=1, failure_config=air.FailureConfig(fail_fast="raise") + ), + ).fit() + results2 = tune.Tuner( + args.run, + param_space=config.to_dict(), + run_config=air.RunConfig( + stop=stop, verbose=1, failure_config=air.FailureConfig(fail_fast="raise") + ), + ).fit() + + if args.as_test: + results1 = results1.get_best_result().metrics + results2 = results2.get_best_result().metrics + # Test rollout behavior. + check( + results1[ENV_RUNNER_RESULTS]["hist_stats"], + results2[ENV_RUNNER_RESULTS]["hist_stats"], + ) + # As well as training behavior (minibatch sequence during SGD + # iterations). 
+ if config.enable_rl_module_and_learner: + check( + results1["info"][LEARNER_INFO][DEFAULT_MODULE_ID], + results2["info"][LEARNER_INFO][DEFAULT_MODULE_ID], + ) + else: + check( + results1["info"][LEARNER_INFO][DEFAULT_MODULE_ID]["learner_stats"], + results2["info"][LEARNER_INFO][DEFAULT_MODULE_ID]["learner_stats"], + ) + ray.shutdown() diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/evaluation/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/examples/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/evaluation/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/examples/evaluation/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cf30b3954529fabc48b6362684620015233003d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/examples/evaluation/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/evaluation/__pycache__/custom_evaluation.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/examples/evaluation/__pycache__/custom_evaluation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60612b070b4ffe352a08bfc596993687906f7e64 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/examples/evaluation/__pycache__/custom_evaluation.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/evaluation/__pycache__/evaluation_parallel_to_training.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/examples/evaluation/__pycache__/evaluation_parallel_to_training.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15bc5a1f9a61b97a36701e1db9e3c0812d228521 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/rllib/examples/evaluation/__pycache__/evaluation_parallel_to_training.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/evaluation/evaluation_parallel_to_training.py b/.venv/lib/python3.11/site-packages/ray/rllib/examples/evaluation/evaluation_parallel_to_training.py new file mode 100644 index 0000000000000000000000000000000000000000..841496c2aca45a455cd2989f45dfde5071943755 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/examples/evaluation/evaluation_parallel_to_training.py @@ -0,0 +1,246 @@ +"""Example showing how one can set up evaluation running in parallel to training. + +Such a setup saves a considerable amount of time during RL Algorithm training, b/c +the next training step does NOT have to wait for the previous evaluation procedure to +finish, but can already start running (in parallel). + +See RLlib's documentation for more details on the effect of the different supported +evaluation configuration options: +https://docs.ray.io/en/latest/rllib/rllib-advanced-api.html#customized-evaluation-during-training # noqa + +For an example of how to write a fully customized evaluation function (which normally +is not necessary as the config options are sufficient and offer maximum flexibility), +see this example script here: + +https://github.com/ray-project/ray/blob/master/rllib/examples/evaluation/custom_evaluation.py # noqa + + +How to run this script +---------------------- +`python [script file name].py --enable-new-api-stack` + +Use the `--evaluation-num-workers` option to scale up the evaluation workers. Note +that the requested evaluation duration (`--evaluation-duration` measured in +`--evaluation-duration-unit`, which is either "timesteps" (default) or "episodes") is +shared between all configured evaluation workers. 
For example, if the evaluation +duration is 10 and the unit is "episodes" and you configured 5 workers, then each of the +evaluation workers will run exactly 2 episodes. + +For debugging, use the following additional command line options +`--no-tune --num-env-runners=0` +which should allow you to set breakpoints anywhere in the RLlib code and +have the execution stop there for inspection and debugging. + +For logging to your WandB account, use: +`--wandb-key=[your WandB API key] --wandb-project=[some project name] +--wandb-run-name=[optional: WandB run name (within the defined project)]` + + +Results to expect +----------------- +You should see the following output (at the end of the experiment) in your console when +running with a fixed number of 100k training timesteps +(`--enable-new-api-stack --evaluation-duration=auto --stop-timesteps=100000 +--stop-reward=100000`): ++-----------------------------+------------+-----------------+--------+ +| Trial name | status | loc | iter | +|-----------------------------+------------+-----------------+--------+ +| PPO_CartPole-v1_1377a_00000 | TERMINATED | 127.0.0.1:73330 | 25 | ++-----------------------------+------------+-----------------+--------+ ++------------------+--------+----------+--------------------+ +| total time (s) | ts | reward | episode_len_mean | +|------------------+--------+----------+--------------------| +| 71.7485 | 100000 | 476.51 | 476.51 | ++------------------+--------+----------+--------------------+ + +When running without parallel evaluation (no `--evaluation-parallel-to-training` flag), +the experiment takes considerably longer (~70sec vs ~80sec): ++-----------------------------+------------+-----------------+--------+ +| Trial name | status | loc | iter | +|-----------------------------+------------+-----------------+--------+ +| PPO_CartPole-v1_f1788_00000 | TERMINATED | 127.0.0.1:75135 | 25 | ++-----------------------------+------------+-----------------+--------+ 
++------------------+--------+----------+--------------------+ +| total time (s) | ts | reward | episode_len_mean | +|------------------+--------+----------+--------------------| +| 81.7371 | 100000 | 494.68 | 494.68 | ++------------------+--------+----------+--------------------+ +""" +from typing import Optional + +from ray.air.constants import TRAINING_ITERATION +from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.callbacks.callbacks import RLlibCallback +from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole +from ray.rllib.utils.metrics import ( + ENV_RUNNER_RESULTS, + EPISODE_RETURN_MEAN, + EVALUATION_RESULTS, + NUM_EPISODES, + NUM_ENV_STEPS_SAMPLED, + NUM_ENV_STEPS_SAMPLED_LIFETIME, +) +from ray.rllib.utils.metrics.metrics_logger import MetricsLogger +from ray.rllib.utils.test_utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, +) +from ray.rllib.utils.typing import ResultDict +from ray.tune.registry import get_trainable_cls, register_env + +parser = add_rllib_example_script_args(default_reward=500.0) +parser.set_defaults( + evaluation_num_env_runners=2, + evaluation_interval=1, + evaluation_duration_unit="timesteps", +) + + +class AssertEvalCallback(RLlibCallback): + def on_train_result( + self, + *, + algorithm: Algorithm, + metrics_logger: Optional[MetricsLogger] = None, + result: ResultDict, + **kwargs, + ): + # The eval results can be found inside the main `result` dict + # (old API stack: "evaluation"). + eval_results = result.get(EVALUATION_RESULTS, {}) + # In there, there is a sub-key: ENV_RUNNER_RESULTS. + eval_env_runner_results = eval_results.get(ENV_RUNNER_RESULTS) + # Make sure we always run exactly the given evaluation duration, + # no matter what the other settings are (such as + # `evaluation_num_env_runners` or `evaluation_parallel_to_training`). 
+ if eval_env_runner_results and NUM_EPISODES in eval_env_runner_results: + num_episodes_done = eval_env_runner_results[NUM_EPISODES] + if algorithm.config.enable_env_runner_and_connector_v2: + num_timesteps_reported = eval_env_runner_results[NUM_ENV_STEPS_SAMPLED] + else: + num_timesteps_reported = eval_results["timesteps_this_iter"] + + # We run for automatic duration (as long as training takes). + if algorithm.config.evaluation_duration == "auto": + # If duration=auto: Expect at least as many timesteps as workers + # (each worker's `sample()` is at least called once). + # UNLESS: All eval workers were completely busy during the auto-time + # with older (async) requests and did NOT return anything from the async + # fetch. + assert ( + num_timesteps_reported == 0 + or num_timesteps_reported + >= algorithm.config.evaluation_num_env_runners + ) + # We count in episodes. + elif algorithm.config.evaluation_duration_unit == "episodes": + # Compare number of entries in episode_lengths (this is the + # number of episodes actually run) with desired number of + # episodes from the config. + assert ( + algorithm.iteration + 1 % algorithm.config.evaluation_interval != 0 + or num_episodes_done == algorithm.config.evaluation_duration + ), (num_episodes_done, algorithm.config.evaluation_duration) + print( + "Number of run evaluation episodes: " f"{num_episodes_done} (ok)!" + ) + # We count in timesteps. + else: + # TODO (sven): This assertion works perfectly fine locally, but breaks + # the CI for no reason. The observed collected timesteps is +500 more + # than desired (~2500 instead of 2011 and ~1250 vs 1011). + # num_timesteps_wanted = algorithm.config.evaluation_duration + # delta = num_timesteps_wanted - num_timesteps_reported + # Expect roughly the same (desired // num-eval-workers). + # assert abs(delta) < 20, ( + # delta, + # num_timesteps_wanted, + # num_timesteps_reported, + # ) + print( + "Number of run evaluation timesteps: " + f"{num_timesteps_reported} (ok?)!" 
+ ) + + +if __name__ == "__main__": + args = parser.parse_args() + + # Register our environment with tune. + if args.num_agents > 0: + register_env( + "env", + lambda _: MultiAgentCartPole(config={"num_agents": args.num_agents}), + ) + + base_config = ( + get_trainable_cls(args.algo) + .get_default_config() + .environment("env" if args.num_agents > 0 else "CartPole-v1") + # Use a custom callback that asserts that we are running the + # configured exact number of episodes per evaluation OR - in auto + # mode - run at least as many episodes as we have eval workers. + .callbacks(AssertEvalCallback) + .evaluation( + # Parallel evaluation+training config. + # Switch on evaluation in parallel with training. + evaluation_parallel_to_training=args.evaluation_parallel_to_training, + # Use two evaluation workers. Must be >0, otherwise, + # evaluation will run on a local worker and block (no parallelism). + evaluation_num_env_runners=args.evaluation_num_env_runners, + # Evaluate every other training iteration (together + # with every other call to Algorithm.train()). + evaluation_interval=args.evaluation_interval, + # Run for n episodes/timesteps (properly distribute load amongst + # all eval workers). The longer it takes to evaluate, the more sense + # it makes to use `evaluation_parallel_to_training=True`. + # Use "auto" to run evaluation for roughly as long as the training + # step takes. + evaluation_duration=args.evaluation_duration, + # "episodes" or "timesteps". + evaluation_duration_unit=args.evaluation_duration_unit, + # Switch off exploratory behavior for better (greedy) results. + evaluation_config={ + "explore": False, + # TODO (sven): Add support for window=float(inf) and reduce=mean for + # evaluation episode_return_mean reductions (identical to old stack + # behavior, which does NOT use a window (100 by default) to reduce + # eval episode returns. + "metrics_num_episodes_for_smoothing": 5, + }, + ) + ) + + # Add a simple multi-agent setup. 
+ if args.num_agents > 0: + base_config.multi_agent( + policies={f"p{i}" for i in range(args.num_agents)}, + policy_mapping_fn=lambda aid, *a, **kw: f"p{aid}", + ) + # Set some PPO-specific tuning settings to learn better in the env (assumed to be + # CartPole-v1). + if args.algo == "PPO": + base_config.training( + lr=0.0003, + num_epochs=6, + vf_loss_coeff=0.01, + ) + + stop = { + TRAINING_ITERATION: args.stop_iters, + f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": ( + args.stop_reward + ), + NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps, + } + + run_rllib_example_script_experiment( + base_config, + args, + stop=stop, + success_metric={ + f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": ( + args.stop_reward + ), + }, + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/hierarchical/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/examples/hierarchical/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/hierarchical/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/examples/hierarchical/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..851acbee99cb0fb022c72c2a91f75a2306de8cd4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/examples/hierarchical/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/hierarchical/__pycache__/hierarchical_training.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/examples/hierarchical/__pycache__/hierarchical_training.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25ca31301498307b08016e8a413b2b7f897901e3 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/rllib/examples/hierarchical/__pycache__/hierarchical_training.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/hierarchical/hierarchical_training.py b/.venv/lib/python3.11/site-packages/ray/rllib/examples/hierarchical/hierarchical_training.py new file mode 100644 index 0000000000000000000000000000000000000000..d7401e262ea612bdc81d41a4747f141ef2d78a34 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/examples/hierarchical/hierarchical_training.py @@ -0,0 +1,182 @@ +"""Example of running a hierarchichal training setup in RLlib using its multi-agent API. + +This example is very loosely based on this paper: +[1] Hierarchical RL Based on Subgoal Discovery and Subpolicy Specialization - +B. Bakker & J. Schmidhuber - 2003 + +The approach features one high level policy, which picks the next target state to be +reached by one of three low level policies as well as the actual low level policy to +take over control. +A low level policy - once chosen by the high level one - has up to 10 primitive +timesteps to reach the given target state. If it reaches it, both high level and low +level policy are rewarded and the high level policy takes another action (choses a new +target state and a new low level policy). +A global goal state must be reached to deem the overall task to be solved. Once one +of the lower level policies reaches that goal state, the high level policy receives +a large reward and the episode ends. +The approach utilizes the possibility for low level policies to specialize in reaching +certain sub-goals and the high level policy to know, which sub goals to pick next and +which "expert" (low level policy) to allow to reach the subgoal. 
+ +This example: + - demonstrates how to write a relatively simple custom multi-agent environment and + have it behave, such that it mimics a hierarchical RL setup with higher- and lower + level agents acting on different abstract time axes (the higher level policy + only acts occasionally, picking a new lower level policy and the lower level + policies have each n primitive timesteps to reach the given target state, after + which control is handed back to the high level policy for the next pick). + - shows how to setup a plain multi-agent RL algo (here: PPO) to learn in this + hierarchical setup and solve tasks that are otherwise very difficult to solve + only with a single, primitive-action picking low level policy. + +We use the `SixRoomEnv` and `HierarchicalSixRoomEnv`, both sharing the same built-in +maps. The envs are similar to the FrozenLake-v1 env, but support walls (inner and outer) +through which the agent cannot walk. + + +How to run this script +---------------------- +`python [script file name].py --enable-new-api-stack --map=large --time-limit=50` + +Use the `--flat` option to disable the hierarchical setup and learn the simple (flat) +SixRoomEnv with only one policy. You should observe that it's much harder for the algo +to reach the global goal state in this setting. + +For debugging, use the following additional command line options +`--no-tune --num-env-runners=0` +which should allow you to set breakpoints anywhere in the RLlib code and +have the execution stop there for inspection and debugging. 
+ +For logging to your WandB account, use: +`--wandb-key=[your WandB API key] --wandb-project=[some project name] +--wandb-run-name=[optional: WandB run name (within the defined project)]` + + +Results to expect +----------------- +In the console output, you can see that only a PPO algorithm that uses hierarchical +training (`--flat` flag is NOT set) can actually learn with the command line options +`--map=large --time-limit=500 --max-steps-low-level=40 --num-low-level-agents=3`. + +4 policies in a hierarchical setup (1 high level "manager", 3 low level "experts"): ++---------------------+----------+--------+------------------+ +| Trial name | status | iter | total time (s) | +| | | | | +|---------------------+----------+--------+------------------+ +| PPO_env_58b78_00000 | RUNNING | 100 | 278.23 | ++---------------------+----------+--------+------------------+ ++-------------------+--------------------------+---------------------------+ ... +| combined return | return high_level_policy | return low_level_policy_0 | +|-------------------+--------------------------+---------------------------+ ... +| -8.4 | -5.2 | -1.19 | ++-------------------+--------------------------+---------------------------+ ... 
+""" +from ray import tune +from ray.rllib.algorithms.ppo import PPOConfig +from ray.rllib.connectors.env_to_module.flatten_observations import FlattenObservations +from ray.rllib.examples.envs.classes.six_room_env import ( + HierarchicalSixRoomEnv, + SixRoomEnv, +) +from ray.rllib.utils.test_utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, +) + +parser = add_rllib_example_script_args( + default_reward=7.0, + default_timesteps=4000000, + default_iters=800, +) +parser.add_argument( + "--flat", + action="store_true", + help="Use the non-hierarchical, single-agent flat `SixRoomEnv` instead.", +) +parser.add_argument( + "--map", + type=str, + choices=["small", "medium", "large"], + default="medium", + help="The built-in map to use.", +) +parser.add_argument( + "--time-limit", + type=int, + default=100, + help="The max. number of (primitive) timesteps per episode.", +) +parser.add_argument( + "--max-steps-low-level", + type=int, + default=15, + help="The max. number of steps a low-level policy can take after having been " + "picked by the high level policy. After this number of timesteps, control is " + "handed back to the high-level policy (to pick a next goal position plus the next " + "low level policy).", +) +parser.add_argument( + "--num-low-level-agents", + type=int, + default=3, + help="The number of low-level agents/policies to use.", +) +parser.set_defaults(enable_new_api_stack=True) + + +if __name__ == "__main__": + args = parser.parse_args() + + # Run the flat (non-hierarchical env). + if args.flat: + cls = SixRoomEnv + # Run in hierarchical mode. 
+ else: + cls = HierarchicalSixRoomEnv + + tune.register_env("env", lambda cfg: cls(config=cfg)) + + base_config = ( + PPOConfig() + .environment( + "env", + env_config={ + "map": args.map, + "max_steps_low_level": args.max_steps_low_level, + "time_limit": args.time_limit, + "num_low_level_agents": args.num_low_level_agents, + }, + ) + .env_runners( + # num_envs_per_env_runner=10, + env_to_module_connector=( + lambda env: FlattenObservations(multi_agent=not args.flat) + ), + ) + .training( + train_batch_size_per_learner=4000, + minibatch_size=512, + lr=0.0003, + num_epochs=20, + entropy_coeff=0.025, + ) + ) + + # Configure a proper multi-agent setup for the hierarchical env. + if not args.flat: + + def policy_mapping_fn(agent_id, episode, **kwargs): + # Map each low level agent to its respective (low-level) policy. + if agent_id.startswith("low_level_"): + return f"low_level_policy_{agent_id[-1]}" + # Map the high level agent to the high level policy. + else: + return "high_level_policy" + + base_config.multi_agent( + policy_mapping_fn=policy_mapping_fn, + policies={"high_level_policy"} + | {f"low_level_policy_{i}" for i in range(args.num_low_level_agents)}, + ) + + run_rllib_example_script_experiment(base_config, args) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/learners/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/examples/learners/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2c73973866ff00b05ef3f7e9df04d597ced2191 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/examples/learners/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/learners/__pycache__/ppo_with_custom_loss_fn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/examples/learners/__pycache__/ppo_with_custom_loss_fn.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3bf727cc54224841654ea285da65af08aa8e4abd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/examples/learners/__pycache__/ppo_with_custom_loss_fn.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/learners/__pycache__/ppo_with_torch_lr_schedulers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/examples/learners/__pycache__/ppo_with_torch_lr_schedulers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee849dea926690eb968e8021f1bb612ab31e840d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/examples/learners/__pycache__/ppo_with_torch_lr_schedulers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/learners/__pycache__/separate_vf_lr_and_optimizer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/examples/learners/__pycache__/separate_vf_lr_and_optimizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b72d255b74255f6e2c4a6f9cb67defc159c0d745 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/examples/learners/__pycache__/separate_vf_lr_and_optimizer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/learners/classes/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/examples/learners/classes/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef6c1be3a39253d8457854ad4800a50fa7c67fc8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/examples/learners/classes/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/learners/classes/__pycache__/intrinsic_curiosity_learners.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/rllib/examples/learners/classes/__pycache__/intrinsic_curiosity_learners.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..746e26028da99b6a46be763735c372fe7c9c3236 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/examples/learners/classes/__pycache__/intrinsic_curiosity_learners.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/metrics/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/examples/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/metrics/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/examples/metrics/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56fc94fa424e2d436c8daf86c85a5f793ad7d66d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/examples/metrics/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/metrics/__pycache__/custom_metrics_in_algorithm_training_step.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/examples/metrics/__pycache__/custom_metrics_in_algorithm_training_step.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4dc8c0b0df977e31333588a8e9769ce76a0efdd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/examples/metrics/__pycache__/custom_metrics_in_algorithm_training_step.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/metrics/__pycache__/custom_metrics_in_env_runners.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/examples/metrics/__pycache__/custom_metrics_in_env_runners.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3e068a25508e84cd5f724cb7e16f644cb7d8c2f6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/examples/metrics/__pycache__/custom_metrics_in_env_runners.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/metrics/custom_metrics_in_algorithm_training_step.py b/.venv/lib/python3.11/site-packages/ray/rllib/examples/metrics/custom_metrics_in_algorithm_training_step.py new file mode 100644 index 0000000000000000000000000000000000000000..357f37a0e3d15b9e5c504d990bdce9a87310961e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/examples/metrics/custom_metrics_in_algorithm_training_step.py @@ -0,0 +1,105 @@ +"""Example of logging custom metrics inside the `Algorithm.training_step()` method. + +RLlib provides a MetricsLogger instance inside most components of an Algorithm, +including the Algorithm itself. + +This example: +- Shows how to subclass a custom Algorithm class (VPG) and override its +`training_step()` method. +- Shows how to use the MetricsLogger instance of Algorithm to log the ratio between +the time spent on sampling over the time spent on the learning update. For on-policy +algorithms, this ratio indicates, where scaling would yield the largest speedups, on +the EnvRunner side or on the Learner side. +- Shows how to access the logged metrics at the end of an iteration through the returned +result dict and how to assert, the new metrics has been properly logged. + + +How to run this script +---------------------- +`python [script file name].py --enable-new-api-stack --wandb-key [your WandB key] +--wandb-project [some project name]` + +For debugging, use the following additional command line options +`--no-tune --num-env-runners=0` +which should allow you to set breakpoints anywhere in the RLlib code and +have the execution stop there for inspection and debugging. 
+ +For logging to your WandB account, use: +`--wandb-key=[your WandB API key] --wandb-project=[some project name] +--wandb-run-name=[optional: WandB run name (within the defined project)]` + + +Results to expect +----------------- +You should see something similar to the below on your terminal when running +this script: + ++-----------------------------------------------+------------+--------+ +| Trial name | status | iter | +| | | | +|-----------------------------------------------+------------+--------+ +| MyVPGWithExtraMetrics_CartPole-v1_d2c5c_00000 | TERMINATED | 100 | ++-----------------------------------------------+------------+--------+ ++------------------+-----------------------+------------------------+ +| total time (s) | episode_return_mean | ratio_time_sampling_ | +| | | over_learning | +|------------------+-----------------------+------------------------| +| 10.0308 | 50.91 | 4.84769 | ++------------------+-----------------------+------------------------+ + +Found logged `ratio_time_sampling_over_learning` in result dict. +""" +from ray.rllib.algorithms import AlgorithmConfig +from ray.rllib.examples.algorithms.classes.vpg import VPG, VPGConfig +from ray.rllib.utils.annotations import override +from ray.rllib.utils.metrics import ( + ENV_RUNNER_SAMPLING_TIMER, + LEARNER_UPDATE_TIMER, + TIMERS, +) +from ray.rllib.utils.test_utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, +) + + +class MyVPGWithExtraMetrics(VPG): + @override(VPG) + def training_step(self) -> None: + # Call the actual VPG training_step. + super().training_step() + + # Look up some already logged metrics through the `peek()` method. + time_spent_on_sampling = self.metrics.peek((TIMERS, ENV_RUNNER_SAMPLING_TIMER)) + time_spent_on_learning = self.metrics.peek((TIMERS, LEARNER_UPDATE_TIMER)) + + # Log extra metrics, still for this training step. 
+ self.metrics.log_value( + "ratio_time_sampling_over_learning", + time_spent_on_sampling / time_spent_on_learning, + ) + + @classmethod + @override(VPG) + def get_default_config(cls) -> AlgorithmConfig: + return VPGConfig(algo_class=cls) + + +parser = add_rllib_example_script_args(default_reward=50.0) +parser.set_defaults(enable_new_api_stack=True) + + +if __name__ == "__main__": + args = parser.parse_args() + + base_config = MyVPGWithExtraMetrics.get_default_config().environment("CartPole-v1") + + results = run_rllib_example_script_experiment(base_config, args) + + # Check, whether the logged metrics are present. + if args.no_tune: + assert "ratio_time_sampling_over_learning" in results + else: + assert "ratio_time_sampling_over_learning" in results[0].metrics + + print("Found logged `ratio_time_sampling_over_learning` in result dict.") diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/examples/metrics/custom_metrics_in_env_runners.py b/.venv/lib/python3.11/site-packages/ray/rllib/examples/metrics/custom_metrics_in_env_runners.py new file mode 100644 index 0000000000000000000000000000000000000000..ba7160c7d65515d874aa2aae38dc6cc60fd4c212 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/examples/metrics/custom_metrics_in_env_runners.py @@ -0,0 +1,340 @@ +"""Example of adding custom metrics to the results returned by `EnvRunner.sample()`. + +We use the `MetricsLogger` class, which RLlib provides inside all its components (only +when using the new API stack through +`config.api_stack(enable_rl_module_and_learner=True, +enable_env_runner_and_connector_v2=True)`), +and which offers a unified API to log individual values per iteration, per episode +timestep, per episode (as a whole), per loss call, etc.. +`MetricsLogger` objects are available in all custom API code, for example inside your +custom `Algorithm.training_step()` methods, custom loss functions, custom callbacks, +and custom EnvRunners. 
+ +This example: + - demonstrates how to write a custom RLlibCallback subclass, which overrides some + EnvRunner-bound methods, such as `on_episode_start`, `on_episode_step`, and + `on_episode_end`. + - shows how to temporarily store per-timestep data inside the currently running + episode within the EnvRunner (and the callback methods). + - shows how to extract this temporary data again when the episode is done in order + to further process the data into a single, reportable metric. + - explains how to use the `MetricsLogger` API to create and log different metrics + to the final Algorithm's iteration output. These include - but are not limited to - + a 2D heatmap (image) per episode, an average per-episode metric (over a sliding + window of 200 episodes), a maximum per-episode metric (over a sliding window of 100 + episodes), and an EMA-smoothed metric. + +In this script, we define a custom `RLlibCallback` class and then override some of +its methods in order to define custom behavior during episode sampling. In particular, +we add custom metrics to the Algorithm's published result dict (once per +iteration) before it is sent back to Ray Tune (and possibly a WandB logger). + +For demonstration purposes only, we log the following custom metrics: +- A 2D heatmap showing the frequency of all accumulated y/x-locations of Ms Pacman +during an episode. We create and log a separate heatmap per episode and limit the number +of heatmaps reported back to the algorithm by each EnvRunner to 10 (`window=10`). +- The maximum per-episode distance travelled by Ms Pacman over a sliding window of 100 +episodes. +- The average per-episode distance travelled by Ms Pacman over a sliding window of 200 +episodes. +- The EMA-smoothed number of lives of Ms Pacman at each timestep (across all episodes). 
+ + +How to run this script +---------------------- +`python [script file name].py --enable-new-api-stack --wandb-key [your WandB key] +--wandb-project [some project name]` + +For debugging, use the following additional command line options +`--no-tune --num-env-runners=0` +which should allow you to set breakpoints anywhere in the RLlib code and +have the execution stop there for inspection and debugging. + +For logging to your WandB account, use: +`--wandb-key=[your WandB API key] --wandb-project=[some project name] +--wandb-run-name=[optional: WandB run name (within the defined project)]` + + +Results to expect +----------------- +This script has not been finetuned to actually learn the environment. Its purpose +is to show how you can create and log custom metrics during episode sampling and +have these stats be sent to WandB for further analysis. + +However, you should see training proceeding over time like this: ++---------------------+----------+----------------+--------+------------------+ +| Trial name | status | loc | iter | total time (s) | +| | | | | | +|---------------------+----------+----------------+--------+------------------+ +| PPO_env_efd16_00000 | RUNNING | 127.0.0.1:6181 | 4 | 72.4725 | ++---------------------+----------+----------------+--------+------------------+ ++------------------------+------------------------+------------------------+ +| episode_return_mean | num_episodes_lifetim | num_env_steps_traine | +| | e | d_lifetime | +|------------------------+------------------------+------------------------| +| 76.4 | 45 | 8053 | ++------------------------+------------------------+------------------------+ +""" +from typing import Optional, Sequence + +import gymnasium as gym +import matplotlib.pyplot as plt +from matplotlib.colors import Normalize +import numpy as np + +from ray.rllib.callbacks.callbacks import RLlibCallback +from ray.rllib.env.wrappers.atari_wrappers import wrap_atari_for_new_api_stack +from ray.rllib.utils.images import 
resize +from ray.rllib.utils.test_utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, +) +from ray.tune.registry import get_trainable_cls, register_env + + +class MsPacmanHeatmapCallback(RLlibCallback): + """A custom callback to extract information from MsPacman and log these. + + This callback logs: + - the positions of MsPacman over an episode to produce heatmaps from this data. + At each episode timestep, the current pacman (y/x)-position is determined and added + to the episode's temporary storage. At the end of an episode, a simple 2D heatmap + is created from this data and the heatmap is logged to the MetricsLogger (to be + viewed in WandB). + - the max distance travelled by MsPacman per episode, then averaging these max + values over a window of size=100. + - the mean distance travelled by MsPacman per episode (over an infinite window). + - the number of lifes of MsPacman EMA-smoothed over time. + + This callback can be setup to only log stats on certain EnvRunner indices through + the `env_runner_indices` c'tor arg. + """ + + def __init__(self, env_runner_indices: Optional[Sequence[int]] = None): + """Initializes an MsPacmanHeatmapCallback instance. + + Args: + env_runner_indices: The (optional) EnvRunner indices, for this callback + should be active. If None, activates the heatmap for all EnvRunners. + If a Sequence type, only logs/heatmaps, if the EnvRunner index is found + in `env_runner_indices`. + """ + super().__init__() + # Only create heatmap on certain EnvRunner indices? + self._env_runner_indices = env_runner_indices + + # Mapping from episode ID to max distance travelled thus far. + self._episode_start_position = {} + + def on_episode_start( + self, + *, + episode, + env_runner, + metrics_logger, + env, + env_index, + rl_module, + **kwargs, + ) -> None: + # Skip, if this EnvRunner's index is not in `self._env_runner_indices`. 
+ if ( + self._env_runner_indices is not None + and env_runner.worker_index not in self._env_runner_indices + ): + return + + yx_pos = self._get_pacman_yx_pos(env) + self._episode_start_position[episode.id_] = yx_pos + + def on_episode_step( + self, + *, + episode, + env_runner, + metrics_logger, + env, + env_index, + rl_module, + **kwargs, + ) -> None: + """Adds current pacman y/x-position to episode's temporary data.""" + + # Skip, if this EnvRunner's index is not in `self._env_runner_indices`. + if ( + self._env_runner_indices is not None + and env_runner.worker_index not in self._env_runner_indices + ): + return + + yx_pos = self._get_pacman_yx_pos(env) + episode.add_temporary_timestep_data("pacman_yx_pos", yx_pos) + + # Compute distance to start position. + dist_travelled = np.sqrt( + np.sum( + np.square( + np.array(self._episode_start_position[episode.id_]) + - np.array(yx_pos) + ) + ) + ) + episode.add_temporary_timestep_data("pacman_dist_travelled", dist_travelled) + + def on_episode_end( + self, + *, + episode, + env_runner, + metrics_logger, + env, + env_index, + rl_module, + **kwargs, + ) -> None: + # Skip, if this EnvRunner's index is not in `self._env_runner_indices`. + if ( + self._env_runner_indices is not None + and env_runner.worker_index not in self._env_runner_indices + ): + return + + # Erase the start position record. + del self._episode_start_position[episode.id_] + + # Get all pacman y/x-positions from the episode. + yx_positions = episode.get_temporary_timestep_data("pacman_yx_pos") + # h x w + heatmap = np.zeros((80, 100), dtype=np.int32) + for yx_pos in yx_positions: + if yx_pos != (-1, -1): + heatmap[yx_pos[0], yx_pos[1]] += 1 + + # Create the actual heatmap image. + # Normalize the heatmap to values between 0 and 1 + norm = Normalize(vmin=heatmap.min(), vmax=heatmap.max()) + # Use a colormap (e.g., 'hot') to map normalized values to RGB + colormap = plt.get_cmap("coolwarm") # try "hot" and "viridis" as well? 
+ # Returns a (64, 64, 4) array (RGBA). + heatmap_rgb = colormap(norm(heatmap)) + # Convert RGBA to RGB by dropping the alpha channel and converting to uint8. + heatmap_rgb = (heatmap_rgb[:, :, :3] * 255).astype(np.uint8) + # Log the image. + metrics_logger.log_value( + "pacman_heatmap", + heatmap_rgb, + reduce=None, + window=10, # Log 10 images at most per EnvRunner/training iteration. + ) + + # Get the max distance travelled for this episode. + dist_travelled = np.max( + episode.get_temporary_timestep_data("pacman_dist_travelled") + ) + + # Log the max. dist travelled in this episode (window=100). + metrics_logger.log_value( + "pacman_max_dist_travelled", + dist_travelled, + # For future reductions (e.g. over n different episodes and all the + # data coming from other env runners), reduce by max. + reduce="max", + # Always keep the last 100 values and max over this window. + # Note that this means that over time, if the values drop to lower + # numbers again, the reported `pacman_max_dist_travelled` might also + # decrease again (meaning `window=100` makes this not a "lifetime max"). + window=100, + ) + + # Log the average dist travelled per episode (window=200). + metrics_logger.log_value( + "pacman_mean_dist_travelled", + dist_travelled, + reduce="mean", # <- default + # Always keep the last 200 values and average over this window. + window=200, + ) + + # Log the number of lifes (as EMA-smoothed; no window). + metrics_logger.log_value( + "pacman_lifes", + episode.get_infos(-1)["lives"], + reduce="mean", # <- default (must be "mean" for EMA smothing) + ema_coeff=0.01, # <- default EMA coefficient (`window` must be None) + ) + + def _get_pacman_yx_pos(self, env): + # If we have a vector env, only render the sub-env at index 0. + if isinstance(env.unwrapped, gym.vector.VectorEnv): + image = env.envs[0].render() + else: + image = env.render() + # Downsize to 100x100 for our utility function to work with. 
+ image = resize(image, 100, 100) + # Crop image at bottom 20% (where lives are shown, which may confuse the pacman + # detector). + image = image[:80] + # Define the yellow color range in RGB (Ms. Pac-Man is yellowish). + # We allow some range around yellow to account for variation. + yellow_lower = np.array([200, 130, 65], dtype=np.uint8) + yellow_upper = np.array([220, 175, 105], dtype=np.uint8) + # Create a mask that highlights the yellow pixels + mask = np.all((image >= yellow_lower) & (image <= yellow_upper), axis=-1) + # Find the coordinates of the yellow pixels + yellow_pixels = np.argwhere(mask) + if yellow_pixels.size == 0: + return (-1, -1) + + # Calculate the centroid of the yellow pixels to get Ms. Pac-Man's position + y, x = yellow_pixels.mean(axis=0).astype(int) + return y, x + + +parser = add_rllib_example_script_args(default_reward=450.0) +parser.set_defaults(enable_new_api_stack=True) + + +if __name__ == "__main__": + args = parser.parse_args() + + # Register our environment with tune. + register_env( + "env", + lambda cfg: wrap_atari_for_new_api_stack( + gym.make("ale_py:ALE/MsPacman-v5", **cfg, **{"render_mode": "rgb_array"}), + framestack=4, + ), + ) + + base_config = ( + get_trainable_cls(args.algo) + .get_default_config() + .environment( + "env", + env_config={ + # Make analogous to old v4 + NoFrameskip. + "frameskip": 1, + "full_action_space": False, + "repeat_action_probability": 0.0, + }, + ) + .callbacks(MsPacmanHeatmapCallback) + .training( + # Make learning time fast, but note that this example may not + # necessarily learn well (its purpose is to demo the + # functionality of callbacks and the MetricsLogger). 
+ train_batch_size_per_learner=2000, + minibatch_size=512, + num_epochs=6, + ) + .rl_module( + model_config_dict={ + "vf_share_layers": True, + "conv_filters": [[16, 4, 2], [32, 4, 2], [64, 4, 2], [128, 4, 2]], + "conv_activation": "relu", + "post_fcnet_hiddens": [256], + } + ) + ) + + run_rllib_example_script_experiment(base_config, args)