Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/ray/rllib/core/__init__.py +35 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/__init__.py +37 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/base_env.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_context.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_runner_group.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/external_env.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/external_multi_agent_env.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_env.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_env_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/policy_client.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/policy_server_input.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/remote_base_env.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/single_agent_env_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/single_agent_episode.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/tcp_client_inference_env_runner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/vector_env.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/base_env.py +428 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/env_context.py +128 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/env_runner.py +187 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/env_runner_group.py +1262 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/external_env.py +481 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/external_multi_agent_env.py +161 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_env.py +799 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_env_runner.py +1107 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_episode.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/policy_client.py +403 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/policy_server_input.py +341 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/remote_base_env.py +462 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/single_agent_env_runner.py +853 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/single_agent_episode.py +1862 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/tcp_client_inference_env_runner.py +589 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/utils/__pycache__/external_env_protocol.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/vector_env.py +544 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/atari_wrappers.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/group_agents_wrapper.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/multi_agent_env_compatibility.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/open_spiel.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/unity3d_env.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/atari_wrappers.py +400 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/dm_control_wrapper.py +220 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/dm_env_wrapper.py +98 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/group_agents_wrapper.py +157 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/multi_agent_env_compatibility.py +73 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/open_spiel.py +130 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/pettingzoo_env.py +214 -0
- .venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/unity3d_env.py +381 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/actions/__pycache__/nested_action_spaces.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/ray/rllib/core/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.core.columns import Columns
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
DEFAULT_AGENT_ID = "default_agent"
|
| 5 |
+
DEFAULT_POLICY_ID = "default_policy"
|
| 6 |
+
# TODO (sven): Change this to "default_module"
|
| 7 |
+
DEFAULT_MODULE_ID = DEFAULT_POLICY_ID
|
| 8 |
+
ALL_MODULES = "__all_modules__"
|
| 9 |
+
|
| 10 |
+
COMPONENT_ENV_RUNNER = "env_runner"
|
| 11 |
+
COMPONENT_ENV_TO_MODULE_CONNECTOR = "env_to_module_connector"
|
| 12 |
+
COMPONENT_EVAL_ENV_RUNNER = "eval_env_runner"
|
| 13 |
+
COMPONENT_LEARNER = "learner"
|
| 14 |
+
COMPONENT_LEARNER_GROUP = "learner_group"
|
| 15 |
+
COMPONENT_METRICS_LOGGER = "metrics_logger"
|
| 16 |
+
COMPONENT_MODULE_TO_ENV_CONNECTOR = "module_to_env_connector"
|
| 17 |
+
COMPONENT_OPTIMIZER = "optimizer"
|
| 18 |
+
COMPONENT_RL_MODULE = "rl_module"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
__all__ = [
|
| 22 |
+
"Columns",
|
| 23 |
+
"COMPONENT_ENV_RUNNER",
|
| 24 |
+
"COMPONENT_ENV_TO_MODULE_CONNECTOR",
|
| 25 |
+
"COMPONENT_EVAL_ENV_RUNNER",
|
| 26 |
+
"COMPONENT_LEARNER",
|
| 27 |
+
"COMPONENT_LEARNER_GROUP",
|
| 28 |
+
"COMPONENT_METRICS_LOGGER",
|
| 29 |
+
"COMPONENT_MODULE_TO_ENV_CONNECTOR",
|
| 30 |
+
"COMPONENT_OPTIMIZER",
|
| 31 |
+
"COMPONENT_RL_MODULE",
|
| 32 |
+
"DEFAULT_AGENT_ID",
|
| 33 |
+
"DEFAULT_MODULE_ID",
|
| 34 |
+
"DEFAULT_POLICY_ID",
|
| 35 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/env/__init__.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.env.base_env import BaseEnv
|
| 2 |
+
from ray.rllib.env.env_context import EnvContext
|
| 3 |
+
from ray.rllib.env.external_env import ExternalEnv
|
| 4 |
+
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
|
| 5 |
+
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
| 6 |
+
from ray.rllib.env.policy_client import PolicyClient
|
| 7 |
+
from ray.rllib.env.policy_server_input import PolicyServerInput
|
| 8 |
+
from ray.rllib.env.remote_base_env import RemoteBaseEnv
|
| 9 |
+
from ray.rllib.env.vector_env import VectorEnv
|
| 10 |
+
|
| 11 |
+
from ray.rllib.env.wrappers.dm_env_wrapper import DMEnv
|
| 12 |
+
from ray.rllib.env.wrappers.dm_control_wrapper import DMCEnv
|
| 13 |
+
from ray.rllib.env.wrappers.group_agents_wrapper import GroupAgentsWrapper
|
| 14 |
+
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
|
| 15 |
+
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
|
| 16 |
+
from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv
|
| 17 |
+
|
| 18 |
+
INPUT_ENV_SPACES = "__env__"
|
| 19 |
+
|
| 20 |
+
__all__ = [
|
| 21 |
+
"BaseEnv",
|
| 22 |
+
"DMEnv",
|
| 23 |
+
"DMCEnv",
|
| 24 |
+
"EnvContext",
|
| 25 |
+
"ExternalEnv",
|
| 26 |
+
"ExternalMultiAgentEnv",
|
| 27 |
+
"GroupAgentsWrapper",
|
| 28 |
+
"MultiAgentEnv",
|
| 29 |
+
"PettingZooEnv",
|
| 30 |
+
"ParallelPettingZooEnv",
|
| 31 |
+
"PolicyClient",
|
| 32 |
+
"PolicyServerInput",
|
| 33 |
+
"RemoteBaseEnv",
|
| 34 |
+
"Unity3DEnv",
|
| 35 |
+
"VectorEnv",
|
| 36 |
+
"INPUT_ENV_SPACES",
|
| 37 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.58 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/base_env.cpython-311.pyc
ADDED
|
Binary file (17.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_context.cpython-311.pyc
ADDED
|
Binary file (6.18 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_runner.cpython-311.pyc
ADDED
|
Binary file (9.09 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/env_runner_group.cpython-311.pyc
ADDED
|
Binary file (54.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/external_env.cpython-311.pyc
ADDED
|
Binary file (21.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/external_multi_agent_env.cpython-311.pyc
ADDED
|
Binary file (7.48 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_env.cpython-311.pyc
ADDED
|
Binary file (37.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_env_runner.cpython-311.pyc
ADDED
|
Binary file (40 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/policy_client.cpython-311.pyc
ADDED
|
Binary file (18.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/policy_server_input.cpython-311.pyc
ADDED
|
Binary file (17.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/remote_base_env.cpython-311.pyc
ADDED
|
Binary file (20.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/single_agent_env_runner.cpython-311.pyc
ADDED
|
Binary file (32.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/single_agent_episode.cpython-311.pyc
ADDED
|
Binary file (92 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/tcp_client_inference_env_runner.cpython-311.pyc
ADDED
|
Binary file (28.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/vector_env.cpython-311.pyc
ADDED
|
Binary file (27.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/base_env.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING, Union, Set
|
| 3 |
+
|
| 4 |
+
import gymnasium as gym
|
| 5 |
+
import ray
|
| 6 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 7 |
+
from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiEnvDict
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
| 11 |
+
|
| 12 |
+
ASYNC_RESET_RETURN = "async_reset_return"
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@OldAPIStack
|
| 18 |
+
class BaseEnv:
|
| 19 |
+
"""The lowest-level env interface used by RLlib for sampling.
|
| 20 |
+
|
| 21 |
+
BaseEnv models multiple agents executing asynchronously in multiple
|
| 22 |
+
vectorized sub-environments. A call to `poll()` returns observations from
|
| 23 |
+
ready agents keyed by their sub-environment ID and agent IDs, and
|
| 24 |
+
actions for those agents can be sent back via `send_actions()`.
|
| 25 |
+
|
| 26 |
+
All other RLlib supported env types can be converted to BaseEnv.
|
| 27 |
+
RLlib handles these conversions internally in RolloutWorker, for example:
|
| 28 |
+
|
| 29 |
+
gym.Env => rllib.VectorEnv => rllib.BaseEnv
|
| 30 |
+
rllib.MultiAgentEnv (is-a gym.Env) => rllib.VectorEnv => rllib.BaseEnv
|
| 31 |
+
rllib.ExternalEnv => rllib.BaseEnv
|
| 32 |
+
|
| 33 |
+
.. testcode::
|
| 34 |
+
:skipif: True
|
| 35 |
+
|
| 36 |
+
MyBaseEnv = ...
|
| 37 |
+
env = MyBaseEnv()
|
| 38 |
+
obs, rewards, terminateds, truncateds, infos, off_policy_actions = (
|
| 39 |
+
env.poll()
|
| 40 |
+
)
|
| 41 |
+
print(obs)
|
| 42 |
+
|
| 43 |
+
env.send_actions({
|
| 44 |
+
"env_0": {
|
| 45 |
+
"car_0": 0,
|
| 46 |
+
"car_1": 1,
|
| 47 |
+
}, ...
|
| 48 |
+
})
|
| 49 |
+
obs, rewards, terminateds, truncateds, infos, off_policy_actions = (
|
| 50 |
+
env.poll()
|
| 51 |
+
)
|
| 52 |
+
print(obs)
|
| 53 |
+
|
| 54 |
+
print(terminateds)
|
| 55 |
+
|
| 56 |
+
.. testoutput::
|
| 57 |
+
|
| 58 |
+
{
|
| 59 |
+
"env_0": {
|
| 60 |
+
"car_0": [2.4, 1.6],
|
| 61 |
+
"car_1": [3.4, -3.2],
|
| 62 |
+
},
|
| 63 |
+
"env_1": {
|
| 64 |
+
"car_0": [8.0, 4.1],
|
| 65 |
+
},
|
| 66 |
+
"env_2": {
|
| 67 |
+
"car_0": [2.3, 3.3],
|
| 68 |
+
"car_1": [1.4, -0.2],
|
| 69 |
+
"car_3": [1.2, 0.1],
|
| 70 |
+
},
|
| 71 |
+
}
|
| 72 |
+
{
|
| 73 |
+
"env_0": {
|
| 74 |
+
"car_0": [4.1, 1.7],
|
| 75 |
+
"car_1": [3.2, -4.2],
|
| 76 |
+
}, ...
|
| 77 |
+
}
|
| 78 |
+
{
|
| 79 |
+
"env_0": {
|
| 80 |
+
"__all__": False,
|
| 81 |
+
"car_0": False,
|
| 82 |
+
"car_1": True,
|
| 83 |
+
}, ...
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
def to_base_env(
|
| 89 |
+
self,
|
| 90 |
+
make_env: Optional[Callable[[int], EnvType]] = None,
|
| 91 |
+
num_envs: int = 1,
|
| 92 |
+
remote_envs: bool = False,
|
| 93 |
+
remote_env_batch_wait_ms: int = 0,
|
| 94 |
+
restart_failed_sub_environments: bool = False,
|
| 95 |
+
) -> "BaseEnv":
|
| 96 |
+
"""Converts an RLlib-supported env into a BaseEnv object.
|
| 97 |
+
|
| 98 |
+
Supported types for the `env` arg are gym.Env, BaseEnv,
|
| 99 |
+
VectorEnv, MultiAgentEnv, ExternalEnv, or ExternalMultiAgentEnv.
|
| 100 |
+
|
| 101 |
+
The resulting BaseEnv is always vectorized (contains n
|
| 102 |
+
sub-environments) to support batched forward passes, where n may also
|
| 103 |
+
be 1. BaseEnv also supports async execution via the `poll` and
|
| 104 |
+
`send_actions` methods and thus supports external simulators.
|
| 105 |
+
|
| 106 |
+
TODO: Support gym3 environments, which are already vectorized.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
env: An already existing environment of any supported env type
|
| 110 |
+
to convert/wrap into a BaseEnv. Supported types are gym.Env,
|
| 111 |
+
BaseEnv, VectorEnv, MultiAgentEnv, ExternalEnv, and
|
| 112 |
+
ExternalMultiAgentEnv.
|
| 113 |
+
make_env: A callable taking an int as input (which indicates the
|
| 114 |
+
number of individual sub-environments within the final
|
| 115 |
+
vectorized BaseEnv) and returning one individual
|
| 116 |
+
sub-environment.
|
| 117 |
+
num_envs: The number of sub-environments to create in the
|
| 118 |
+
resulting (vectorized) BaseEnv. The already existing `env`
|
| 119 |
+
will be one of the `num_envs`.
|
| 120 |
+
remote_envs: Whether each sub-env should be a @ray.remote actor.
|
| 121 |
+
You can set this behavior in your config via the
|
| 122 |
+
`remote_worker_envs=True` option.
|
| 123 |
+
remote_env_batch_wait_ms: The wait time (in ms) to poll remote
|
| 124 |
+
sub-environments for, if applicable. Only used if
|
| 125 |
+
`remote_envs` is True.
|
| 126 |
+
policy_config: Optional policy config dict.
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
The resulting BaseEnv object.
|
| 130 |
+
"""
|
| 131 |
+
return self
|
| 132 |
+
|
| 133 |
+
def poll(
|
| 134 |
+
self,
|
| 135 |
+
) -> Tuple[
|
| 136 |
+
MultiEnvDict,
|
| 137 |
+
MultiEnvDict,
|
| 138 |
+
MultiEnvDict,
|
| 139 |
+
MultiEnvDict,
|
| 140 |
+
MultiEnvDict,
|
| 141 |
+
MultiEnvDict,
|
| 142 |
+
]:
|
| 143 |
+
"""Returns observations from ready agents.
|
| 144 |
+
|
| 145 |
+
All return values are two-level dicts mapping from EnvID to dicts
|
| 146 |
+
mapping from AgentIDs to (observation/reward/etc..) values.
|
| 147 |
+
The number of agents and sub-environments may vary over time.
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
Tuple consisting of:
|
| 151 |
+
New observations for each ready agent.
|
| 152 |
+
Reward values for each ready agent. If the episode is just started,
|
| 153 |
+
the value will be None.
|
| 154 |
+
Terminated values for each ready agent. The special key "__all__" is used to
|
| 155 |
+
indicate episode termination.
|
| 156 |
+
Truncated values for each ready agent. The special key "__all__"
|
| 157 |
+
is used to indicate episode truncation.
|
| 158 |
+
Info values for each ready agent.
|
| 159 |
+
Agents may take off-policy actions, in which case, there will be an entry
|
| 160 |
+
in this dict that contains the taken action. There is no need to
|
| 161 |
+
`send_actions()` for agents that have already chosen off-policy actions.
|
| 162 |
+
"""
|
| 163 |
+
raise NotImplementedError
|
| 164 |
+
|
| 165 |
+
def send_actions(self, action_dict: MultiEnvDict) -> None:
|
| 166 |
+
"""Called to send actions back to running agents in this env.
|
| 167 |
+
|
| 168 |
+
Actions should be sent for each ready agent that returned observations
|
| 169 |
+
in the previous poll() call.
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
action_dict: Actions values keyed by env_id and agent_id.
|
| 173 |
+
"""
|
| 174 |
+
raise NotImplementedError
|
| 175 |
+
|
| 176 |
+
def try_reset(
|
| 177 |
+
self,
|
| 178 |
+
env_id: Optional[EnvID] = None,
|
| 179 |
+
*,
|
| 180 |
+
seed: Optional[int] = None,
|
| 181 |
+
options: Optional[dict] = None,
|
| 182 |
+
) -> Tuple[Optional[MultiEnvDict], Optional[MultiEnvDict]]:
|
| 183 |
+
"""Attempt to reset the sub-env with the given id or all sub-envs.
|
| 184 |
+
|
| 185 |
+
If the environment does not support synchronous reset, a tuple of
|
| 186 |
+
(ASYNC_RESET_REQUEST, ASYNC_RESET_REQUEST) can be returned here.
|
| 187 |
+
|
| 188 |
+
Note: A MultiAgentDict is returned when using the deprecated wrapper
|
| 189 |
+
classes such as `ray.rllib.env.base_env._MultiAgentEnvToBaseEnv`,
|
| 190 |
+
however for consistency with the poll() method, a `MultiEnvDict` is
|
| 191 |
+
returned from the new wrapper classes, such as
|
| 192 |
+
`ray.rllib.env.multi_agent_env.MultiAgentEnvWrapper`.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
env_id: The sub-environment's ID if applicable. If None, reset
|
| 196 |
+
the entire Env (i.e. all sub-environments).
|
| 197 |
+
seed: The seed to be passed to the sub-environment(s) when
|
| 198 |
+
resetting it. If None, will not reset any existing PRNG. If you pass an
|
| 199 |
+
integer, the PRNG will be reset even if it already exists.
|
| 200 |
+
options: An options dict to be passed to the sub-environment(s) when
|
| 201 |
+
resetting it.
|
| 202 |
+
|
| 203 |
+
Returns:
|
| 204 |
+
A tuple consisting of a) the reset (multi-env/multi-agent) observation
|
| 205 |
+
dict and b) the reset (multi-env/multi-agent) infos dict. Returns the
|
| 206 |
+
(ASYNC_RESET_REQUEST, ASYNC_RESET_REQUEST) tuple, if not supported.
|
| 207 |
+
"""
|
| 208 |
+
return None, None
|
| 209 |
+
|
| 210 |
+
def try_restart(self, env_id: Optional[EnvID] = None) -> None:
|
| 211 |
+
"""Attempt to restart the sub-env with the given id or all sub-envs.
|
| 212 |
+
|
| 213 |
+
This could result in the sub-env being completely removed (gc'd) and recreated.
|
| 214 |
+
|
| 215 |
+
Args:
|
| 216 |
+
env_id: The sub-environment's ID, if applicable. If None, restart
|
| 217 |
+
the entire Env (i.e. all sub-environments).
|
| 218 |
+
"""
|
| 219 |
+
return None
|
| 220 |
+
|
| 221 |
+
def get_sub_environments(self, as_dict: bool = False) -> Union[List[EnvType], dict]:
|
| 222 |
+
"""Return a reference to the underlying sub environments, if any.
|
| 223 |
+
|
| 224 |
+
Args:
|
| 225 |
+
as_dict: If True, return a dict mapping from env_id to env.
|
| 226 |
+
|
| 227 |
+
Returns:
|
| 228 |
+
List or dictionary of the underlying sub environments or [] / {}.
|
| 229 |
+
"""
|
| 230 |
+
if as_dict:
|
| 231 |
+
return {}
|
| 232 |
+
return []
|
| 233 |
+
|
| 234 |
+
def get_agent_ids(self) -> Set[AgentID]:
|
| 235 |
+
"""Return the agent ids for the sub_environment.
|
| 236 |
+
|
| 237 |
+
Returns:
|
| 238 |
+
All agent ids for each the environment.
|
| 239 |
+
"""
|
| 240 |
+
return {}
|
| 241 |
+
|
| 242 |
+
def try_render(self, env_id: Optional[EnvID] = None) -> None:
|
| 243 |
+
"""Tries to render the sub-environment with the given id or all.
|
| 244 |
+
|
| 245 |
+
Args:
|
| 246 |
+
env_id: The sub-environment's ID, if applicable.
|
| 247 |
+
If None, renders the entire Env (i.e. all sub-environments).
|
| 248 |
+
"""
|
| 249 |
+
|
| 250 |
+
# By default, do nothing.
|
| 251 |
+
pass
|
| 252 |
+
|
| 253 |
+
def stop(self) -> None:
|
| 254 |
+
"""Releases all resources used."""
|
| 255 |
+
|
| 256 |
+
# Try calling `close` on all sub-environments.
|
| 257 |
+
for env in self.get_sub_environments():
|
| 258 |
+
if hasattr(env, "close"):
|
| 259 |
+
env.close()
|
| 260 |
+
|
| 261 |
+
@property
|
| 262 |
+
def observation_space(self) -> gym.Space:
|
| 263 |
+
"""Returns the observation space for each agent.
|
| 264 |
+
|
| 265 |
+
Note: samples from the observation space need to be preprocessed into a
|
| 266 |
+
`MultiEnvDict` before being used by a policy.
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
The observation space for each environment.
|
| 270 |
+
"""
|
| 271 |
+
raise NotImplementedError
|
| 272 |
+
|
| 273 |
+
@property
|
| 274 |
+
def action_space(self) -> gym.Space:
|
| 275 |
+
"""Returns the action space for each agent.
|
| 276 |
+
|
| 277 |
+
Note: samples from the action space need to be preprocessed into a
|
| 278 |
+
`MultiEnvDict` before being passed to `send_actions`.
|
| 279 |
+
|
| 280 |
+
Returns:
|
| 281 |
+
The observation space for each environment.
|
| 282 |
+
"""
|
| 283 |
+
raise NotImplementedError
|
| 284 |
+
|
| 285 |
+
def last(
|
| 286 |
+
self,
|
| 287 |
+
) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
|
| 288 |
+
"""Returns the last observations, rewards, done- truncated flags and infos ...
|
| 289 |
+
|
| 290 |
+
that were returned by the environment.
|
| 291 |
+
|
| 292 |
+
Returns:
|
| 293 |
+
The last observations, rewards, done- and truncated flags, and infos
|
| 294 |
+
for each sub-environment.
|
| 295 |
+
"""
|
| 296 |
+
logger.warning("last has not been implemented for this environment.")
|
| 297 |
+
return {}, {}, {}, {}, {}
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
# Fixed agent identifier when there is only the single agent in the env
|
| 301 |
+
_DUMMY_AGENT_ID = "agent0"
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
@OldAPIStack
|
| 305 |
+
def with_dummy_agent_id(
|
| 306 |
+
env_id_to_values: Dict[EnvID, Any], dummy_id: "AgentID" = _DUMMY_AGENT_ID
|
| 307 |
+
) -> MultiEnvDict:
|
| 308 |
+
ret = {}
|
| 309 |
+
for (env_id, value) in env_id_to_values.items():
|
| 310 |
+
# If the value (e.g. the observation) is an Exception, publish this error
|
| 311 |
+
# under the env ID so the caller of `poll()` knows that the entire episode
|
| 312 |
+
# (sub-environment) has crashed.
|
| 313 |
+
ret[env_id] = value if isinstance(value, Exception) else {dummy_id: value}
|
| 314 |
+
return ret
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
@OldAPIStack
|
| 318 |
+
def convert_to_base_env(
|
| 319 |
+
env: EnvType,
|
| 320 |
+
make_env: Callable[[int], EnvType] = None,
|
| 321 |
+
num_envs: int = 1,
|
| 322 |
+
remote_envs: bool = False,
|
| 323 |
+
remote_env_batch_wait_ms: int = 0,
|
| 324 |
+
worker: Optional["RolloutWorker"] = None,
|
| 325 |
+
restart_failed_sub_environments: bool = False,
|
| 326 |
+
) -> "BaseEnv":
|
| 327 |
+
"""Converts an RLlib-supported env into a BaseEnv object.
|
| 328 |
+
|
| 329 |
+
Supported types for the `env` arg are gym.Env, BaseEnv,
|
| 330 |
+
VectorEnv, MultiAgentEnv, ExternalEnv, or ExternalMultiAgentEnv.
|
| 331 |
+
|
| 332 |
+
The resulting BaseEnv is always vectorized (contains n
|
| 333 |
+
sub-environments) to support batched forward passes, where n may also
|
| 334 |
+
be 1. BaseEnv also supports async execution via the `poll` and
|
| 335 |
+
`send_actions` methods and thus supports external simulators.
|
| 336 |
+
|
| 337 |
+
TODO: Support gym3 environments, which are already vectorized.
|
| 338 |
+
|
| 339 |
+
Args:
|
| 340 |
+
env: An already existing environment of any supported env type
|
| 341 |
+
to convert/wrap into a BaseEnv. Supported types are gym.Env,
|
| 342 |
+
BaseEnv, VectorEnv, MultiAgentEnv, ExternalEnv, and
|
| 343 |
+
ExternalMultiAgentEnv.
|
| 344 |
+
make_env: A callable taking an int as input (which indicates the
|
| 345 |
+
number of individual sub-environments within the final
|
| 346 |
+
vectorized BaseEnv) and returning one individual
|
| 347 |
+
sub-environment.
|
| 348 |
+
num_envs: The number of sub-environments to create in the
|
| 349 |
+
resulting (vectorized) BaseEnv. The already existing `env`
|
| 350 |
+
will be one of the `num_envs`.
|
| 351 |
+
remote_envs: Whether each sub-env should be a @ray.remote actor.
|
| 352 |
+
You can set this behavior in your config via the
|
| 353 |
+
`remote_worker_envs=True` option.
|
| 354 |
+
remote_env_batch_wait_ms: The wait time (in ms) to poll remote
|
| 355 |
+
sub-environments for, if applicable. Only used if
|
| 356 |
+
`remote_envs` is True.
|
| 357 |
+
worker: An optional RolloutWorker that owns the env. This is only
|
| 358 |
+
used if `remote_worker_envs` is True in your config and the
|
| 359 |
+
`on_sub_environment_created` custom callback needs to be called
|
| 360 |
+
on each created actor.
|
| 361 |
+
restart_failed_sub_environments: If True and any sub-environment (within
|
| 362 |
+
a vectorized env) throws any error during env stepping, the
|
| 363 |
+
Sampler will try to restart the faulty sub-environment. This is done
|
| 364 |
+
without disturbing the other (still intact) sub-environment and without
|
| 365 |
+
the RolloutWorker crashing.
|
| 366 |
+
|
| 367 |
+
Returns:
|
| 368 |
+
The resulting BaseEnv object.
|
| 369 |
+
"""
|
| 370 |
+
|
| 371 |
+
from ray.rllib.env.remote_base_env import RemoteBaseEnv
|
| 372 |
+
from ray.rllib.env.external_env import ExternalEnv
|
| 373 |
+
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
| 374 |
+
from ray.rllib.env.vector_env import VectorEnv, VectorEnvWrapper
|
| 375 |
+
|
| 376 |
+
if remote_envs and num_envs == 1:
|
| 377 |
+
raise ValueError(
|
| 378 |
+
"Remote envs only make sense to use if num_envs > 1 "
|
| 379 |
+
"(i.e. environment vectorization is enabled)."
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
# Given `env` has a `to_base_env` method -> Call that to convert to a BaseEnv type.
|
| 383 |
+
if isinstance(env, (BaseEnv, MultiAgentEnv, VectorEnv, ExternalEnv)):
|
| 384 |
+
return env.to_base_env(
|
| 385 |
+
make_env=make_env,
|
| 386 |
+
num_envs=num_envs,
|
| 387 |
+
remote_envs=remote_envs,
|
| 388 |
+
remote_env_batch_wait_ms=remote_env_batch_wait_ms,
|
| 389 |
+
restart_failed_sub_environments=restart_failed_sub_environments,
|
| 390 |
+
)
|
| 391 |
+
# `env` is not a BaseEnv yet -> Need to convert/vectorize.
|
| 392 |
+
else:
|
| 393 |
+
# Sub-environments are ray.remote actors:
|
| 394 |
+
if remote_envs:
|
| 395 |
+
# Determine, whether the already existing sub-env (could
|
| 396 |
+
# be a ray.actor) is multi-agent or not.
|
| 397 |
+
multiagent = (
|
| 398 |
+
ray.get(env._is_multi_agent.remote())
|
| 399 |
+
if hasattr(env, "_is_multi_agent")
|
| 400 |
+
else False
|
| 401 |
+
)
|
| 402 |
+
env = RemoteBaseEnv(
|
| 403 |
+
make_env,
|
| 404 |
+
num_envs,
|
| 405 |
+
multiagent=multiagent,
|
| 406 |
+
remote_env_batch_wait_ms=remote_env_batch_wait_ms,
|
| 407 |
+
existing_envs=[env],
|
| 408 |
+
worker=worker,
|
| 409 |
+
restart_failed_sub_environments=restart_failed_sub_environments,
|
| 410 |
+
)
|
| 411 |
+
# Sub-environments are not ray.remote actors.
|
| 412 |
+
else:
|
| 413 |
+
# Convert gym.Env to VectorEnv ...
|
| 414 |
+
env = VectorEnv.vectorize_gym_envs(
|
| 415 |
+
make_env=make_env,
|
| 416 |
+
existing_envs=[env],
|
| 417 |
+
num_envs=num_envs,
|
| 418 |
+
action_space=env.action_space,
|
| 419 |
+
observation_space=env.observation_space,
|
| 420 |
+
restart_failed_sub_environments=restart_failed_sub_environments,
|
| 421 |
+
)
|
| 422 |
+
# ... then the resulting VectorEnv to a BaseEnv.
|
| 423 |
+
env = VectorEnvWrapper(env)
|
| 424 |
+
|
| 425 |
+
# Make sure conversion went well.
|
| 426 |
+
assert isinstance(env, BaseEnv), env
|
| 427 |
+
|
| 428 |
+
return env
|
.venv/lib/python3.11/site-packages/ray/rllib/env/env_context.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 5 |
+
from ray.rllib.utils.typing import EnvConfigDict
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@OldAPIStack
class EnvContext(dict):
    """A dict subclass wrapping an env config plus extra RLlib metadata.

    Besides behaving exactly like the plain env config dict it was built
    from, an EnvContext carries per-process attributes (such as
    `worker_index`) that environments can use to parameterize themselves,
    e.g. to decide which data file to read on initialization.

    RLlib fills in these metadata attributes automatically when constructing
    registered envs.
    """

    def __init__(
        self,
        env_config: EnvConfigDict,
        worker_index: int,
        vector_index: int = 0,
        remote: bool = False,
        num_workers: Optional[int] = None,
        recreated_worker: bool = False,
    ):
        """Initializes an EnvContext instance.

        Args:
            env_config: The env's configuration defined under the
                "env_config" key in the Algorithm's config.
            worker_index: When there are multiple workers created, this
                uniquely identifies the worker the env is created in.
                0 for local worker, >0 for remote workers.
            vector_index: When there are multiple envs per worker, this
                uniquely identifies the env index within the worker.
                Starts from 0.
            remote: Whether individual sub-environments (in a vectorized
                env) should be @ray.remote actors or not.
            num_workers: The total number of (remote) workers in the set.
                0 if only a local worker exists.
            recreated_worker: Whether the worker that holds this env is a
                recreated one, meaning it replaced a previous (failed) worker
                when `restart_failed_env_runners=True` in the Algorithm's
                config.
        """
        # Copy the plain env config key/value pairs into the underlying dict.
        super().__init__(env_config)

        # Attach the RLlib-specific metadata as plain attributes.
        self.worker_index = worker_index
        self.vector_index = vector_index
        self.remote = remote
        self.num_workers = num_workers
        self.recreated_worker = recreated_worker

    def copy_with_overrides(
        self,
        env_config: Optional[EnvConfigDict] = None,
        worker_index: Optional[int] = None,
        vector_index: Optional[int] = None,
        remote: Optional[bool] = None,
        num_workers: Optional[int] = None,
        recreated_worker: Optional[bool] = None,
    ) -> "EnvContext":
        """Returns a copy of this EnvContext with some attributes overridden.

        For every argument, passing None means "keep the value from the
        source (self)".

        Args:
            env_config: Optional env config to use (deep-copied before use).
            worker_index: Optional worker index to use.
            vector_index: Optional vector index to use.
            remote: Optional remote setting to use.
            num_workers: Optional num_workers to use.
            recreated_worker: Optional flag, indicating, whether the worker
                that holds the env is a recreated one. This means that it
                replaced a previous (failed) worker when
                `restart_failed_env_runners=True` in the Algorithm's config.

        Returns:
            A new EnvContext object as a copy of self plus the provided
            overrides.
        """
        # When no new env config is given, the constructor copies self's
        # key/value pairs; otherwise the provided config is deep-copied to
        # decouple the new context from the caller's object.
        if env_config is None:
            config_source = self
        else:
            config_source = copy.deepcopy(env_config)

        return EnvContext(
            config_source,
            self.worker_index if worker_index is None else worker_index,
            self.vector_index if vector_index is None else vector_index,
            self.remote if remote is None else remote,
            self.num_workers if num_workers is None else num_workers,
            self.recreated_worker if recreated_worker is None else recreated_worker,
        )

    def set_defaults(self, defaults: dict) -> None:
        """Adds each key/value pair of `defaults` that is not yet in self.

        Keys already present in self keep their current values; only keys
        missing from self are filled in from `defaults`.

        Args:
            defaults: The key/value pairs to add to self, but only for those
                keys in `defaults` that don't exist yet in self.

        .. testcode::
            :skipif: True

            from ray.rllib.env.env_context import EnvContext
            env_ctx = EnvContext({"a": 1, "b": 2}, worker_index=0)
            env_ctx.set_defaults({"a": -42, "c": 3})
            print(env_ctx)

        .. testoutput::

            {'a': 1, 'b': 2, 'c': 3}
        """
        for key, value in defaults.items():
            # `setdefault` only inserts when `key` is absent -> existing
            # values are never overridden.
            self.setdefault(key, value)

    def __str__(self):
        # Render like a plain dict, but splice the metadata fields in right
        # before the closing brace.
        dict_repr = super().__str__()
        metadata = (
            f", worker={self.worker_index}/{self.num_workers}, "
            f"vector_idx={self.vector_index}, remote={self.remote}"
        )
        return dict_repr[:-1] + metadata + "}"
|
.venv/lib/python3.11/site-packages/ray/rllib/env/env_runner.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Any, Dict, Tuple, TYPE_CHECKING
|
| 4 |
+
|
| 5 |
+
import gymnasium as gym
|
| 6 |
+
import tree # pip install dm_tree
|
| 7 |
+
|
| 8 |
+
from ray.rllib.utils.actor_manager import FaultAwareApply
|
| 9 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 10 |
+
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
|
| 11 |
+
from ray.rllib.utils.typing import TensorType
|
| 12 |
+
from ray.util.annotations import PublicAPI
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger("ray.rllib")
|
| 18 |
+
|
| 19 |
+
tf1, tf, _ = try_import_tf()
|
| 20 |
+
|
| 21 |
+
ENV_RESET_FAILURE = "env_reset_failure"
|
| 22 |
+
ENV_STEP_FAILURE = "env_step_failure"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# TODO (sven): As soon as RolloutWorker is no longer supported, make this base class
#  a Checkpointable. Currently, only some of its subclasses are Checkpointables.
@PublicAPI(stability="alpha")
class EnvRunner(FaultAwareApply, metaclass=abc.ABCMeta):
    """Base class for distributed RL-style data collection from an environment.

    The EnvRunner API's core functionalities can be summarized as:
    - Gets configured via passing a AlgorithmConfig object to the constructor.
    Normally, subclasses of EnvRunner then construct their own environment (possibly
    vectorized) copies and RLModules/Policies and use the latter to step through the
    environment in order to collect training data.
    - Clients of EnvRunner can use the `sample()` method to collect data for training
    from the environment(s).
    - EnvRunner offers parallelism via creating n remote Ray Actors based on this class.
    Use `ray.remote([resources])(EnvRunner)` method to create the corresponding Ray
    remote class. Then instantiate n Actors using the Ray `[ctor].remote(...)` syntax.
    - EnvRunner clients can get information about the server/node on which the
    individual Actors are running.
    """

    def __init__(self, *, config: "AlgorithmConfig", **kwargs):
        """Initializes an EnvRunner instance.

        Args:
            config: The AlgorithmConfig to use to setup this EnvRunner.
            **kwargs: Forward compatibility kwargs.
        """
        # Keep an unfrozen copy so subclasses may mutate their own config
        # (e.g. `self.config.env_config`) without affecting the caller's object.
        self.config = config.copy(copy_frozen=False)
        self.env = None

        super().__init__(**kwargs)

        # This eager check is necessary for certain all-framework tests
        # that use tf's eager_mode() context generator.
        if (
            tf1
            and (self.config.framework_str == "tf2" or config.enable_tf1_exec_eagerly)
            and not tf1.executing_eagerly()
        ):
            tf1.enable_eager_execution()

    @abc.abstractmethod
    def assert_healthy(self):
        """Checks that self.__init__() has been completed properly.

        Useful in case an `EnvRunner` is run as @ray.remote (Actor) and the owner
        would like to make sure the Ray Actor has been properly initialized.

        Raises:
            AssertionError: If the EnvRunner Actor has NOT been properly initialized.
        """

    # TODO: Make this an abstract method that must be implemented.
    def make_env(self):
        """Creates the RL environment for this EnvRunner and assigns it to `self.env`.

        Note that users should be able to change the EnvRunner's config (e.g. change
        `self.config.env_config`) and then call this method to create new environments
        with the updated configuration.
        It should also be called after a failure of an earlier env in order to clean up
        the existing env (for example `close()` it), re-create a new one, and then
        continue sampling with that new env.
        """
        pass

    # TODO: Make this an abstract method that must be implemented.
    def make_module(self):
        """Creates the RLModule for this EnvRunner and assigns it to `self.module`.

        Note that users should be able to change the EnvRunner's config (e.g. change
        `self.config.rl_module_spec`) and then call this method to create a new RLModule
        with the updated configuration.
        """
        pass

    @abc.abstractmethod
    def sample(self, **kwargs) -> Any:
        """Returns experiences (of any form) sampled from this EnvRunner.

        The exact nature and size of collected data are defined via the EnvRunner's
        config and may be overridden by the given arguments.

        Args:
            **kwargs: Forward compatibility kwargs.

        Returns:
            The collected experience in any form.
        """

    # TODO (sven): Make this an abstract method that must be overridden.
    def get_metrics(self) -> Any:
        """Returns metrics (in any form) of the thus far collected, completed episodes.

        Returns:
            Metrics of any form.
        """
        pass

    @abc.abstractmethod
    def get_spaces(self) -> Dict[str, Tuple[gym.Space, gym.Space]]:
        """Returns a dict mapping ModuleIDs to 2-tuples of obs- and action space."""

    def stop(self) -> None:
        """Releases all resources used by this EnvRunner.

        For example, when using a gym.Env in this EnvRunner, you should make sure
        that its `close()` method is called.
        """
        pass

    def __del__(self) -> None:
        """If this Actor is deleted, clears all resources used by it."""
        pass

    def _try_env_reset(self):
        """Tries resetting the env and - if an error occurs - handles it gracefully."""
        # Try to reset.
        try:
            obs, infos = self.env.reset()
            # Everything ok -> return.
            return obs, infos
        # Error.
        except Exception as e:
            # If user wants to simply restart the env -> recreate env and try again
            # (calling this method recursively until success).
            if self.config.restart_failed_sub_environments:
                # Log `e` itself (not `e.args[0]`): exceptions raised without args
                # would otherwise trigger an IndexError here, masking the original
                # env failure.
                logger.exception(
                    "Resetting the env resulted in an error! The original error "
                    f"is: {e}"
                )
                # Recreate the env and simply try again.
                self.make_env()
                return self._try_env_reset()
            else:
                raise e

    def _try_env_step(self, actions):
        """Tries stepping the env and - if an error occurs - handles it gracefully."""
        try:
            results = self.env.step(actions)
            return results
        except Exception as e:
            if self.config.restart_failed_sub_environments:
                # See `_try_env_reset()` on why `e` (not `e.args[0]`) is logged.
                logger.exception(
                    "Stepping the env resulted in an error! The original error "
                    f"is: {e}"
                )
                # Recreate the env.
                self.make_env()
                # And return that the stepping failed. The caller will then handle
                # specific cleanup operations (for example discarding thus-far collected
                # data and repeating the step attempt).
                return ENV_STEP_FAILURE
            else:
                raise e

    def _convert_to_tensor(self, struct) -> TensorType:
        """Converts structs to a framework-specific tensor."""

        if self.config.framework_str == "torch":
            return convert_to_torch_tensor(struct)
        else:
            # NOTE(review): assumes tf was successfully imported when the
            # framework is not torch -- `tf` is None if TensorFlow is absent.
            return tree.map_structure(tf.convert_to_tensor, struct)
|
.venv/lib/python3.11/site-packages/ray/rllib/env/env_runner_group.py
ADDED
|
@@ -0,0 +1,1262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import gymnasium as gym
|
| 3 |
+
import logging
|
| 4 |
+
import importlib.util
|
| 5 |
+
import os
|
| 6 |
+
from typing import (
|
| 7 |
+
Any,
|
| 8 |
+
Callable,
|
| 9 |
+
Collection,
|
| 10 |
+
Dict,
|
| 11 |
+
List,
|
| 12 |
+
Optional,
|
| 13 |
+
Tuple,
|
| 14 |
+
Type,
|
| 15 |
+
TYPE_CHECKING,
|
| 16 |
+
TypeVar,
|
| 17 |
+
Union,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
import ray
|
| 21 |
+
from ray.actor import ActorHandle
|
| 22 |
+
from ray.exceptions import RayActorError
|
| 23 |
+
from ray.rllib.core import (
|
| 24 |
+
COMPONENT_ENV_TO_MODULE_CONNECTOR,
|
| 25 |
+
COMPONENT_LEARNER,
|
| 26 |
+
COMPONENT_MODULE_TO_ENV_CONNECTOR,
|
| 27 |
+
COMPONENT_RL_MODULE,
|
| 28 |
+
)
|
| 29 |
+
from ray.rllib.core.learner import LearnerGroup
|
| 30 |
+
from ray.rllib.core.rl_module import validate_module_id
|
| 31 |
+
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
|
| 32 |
+
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
| 33 |
+
from ray.rllib.env.base_env import BaseEnv
|
| 34 |
+
from ray.rllib.env.env_context import EnvContext
|
| 35 |
+
from ray.rllib.env.env_runner import EnvRunner
|
| 36 |
+
from ray.rllib.offline import get_dataset_and_shards
|
| 37 |
+
from ray.rllib.policy.policy import Policy, PolicyState
|
| 38 |
+
from ray.rllib.utils.actor_manager import FaultTolerantActorManager
|
| 39 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 40 |
+
from ray.rllib.utils.deprecation import (
|
| 41 |
+
Deprecated,
|
| 42 |
+
deprecation_warning,
|
| 43 |
+
DEPRECATED_VALUE,
|
| 44 |
+
)
|
| 45 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 46 |
+
from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED_LIFETIME, WEIGHTS_SEQ_NO
|
| 47 |
+
from ray.rllib.utils.typing import (
|
| 48 |
+
AgentID,
|
| 49 |
+
EnvCreator,
|
| 50 |
+
EnvType,
|
| 51 |
+
EpisodeID,
|
| 52 |
+
PartialAlgorithmConfigDict,
|
| 53 |
+
PolicyID,
|
| 54 |
+
SampleBatchType,
|
| 55 |
+
TensorType,
|
| 56 |
+
)
|
| 57 |
+
from ray.util.annotations import DeveloperAPI
|
| 58 |
+
|
| 59 |
+
if TYPE_CHECKING:
|
| 60 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 61 |
+
|
| 62 |
+
tf1, tf, tfv = try_import_tf()
|
| 63 |
+
|
| 64 |
+
logger = logging.getLogger(__name__)
|
| 65 |
+
|
| 66 |
+
# Generic type var for foreach_* methods.
|
| 67 |
+
T = TypeVar("T")
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@DeveloperAPI
|
| 71 |
+
class EnvRunnerGroup:
|
| 72 |
+
"""Set of EnvRunners with n @ray.remote workers and zero or one local worker.
|
| 73 |
+
|
| 74 |
+
Where: n >= 0.
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
    def __init__(
        self,
        *,
        env_creator: Optional[EnvCreator] = None,
        validate_env: Optional[Callable[[EnvType], None]] = None,
        default_policy_class: Optional[Type[Policy]] = None,
        config: Optional["AlgorithmConfig"] = None,
        local_env_runner: bool = True,
        logdir: Optional[str] = None,
        _setup: bool = True,
        tune_trial_id: Optional[str] = None,
        # Deprecated args.
        num_env_runners: Optional[int] = None,
        num_workers=DEPRECATED_VALUE,
        local_worker=DEPRECATED_VALUE,
    ):
        """Initializes an EnvRunnerGroup instance.

        Args:
            env_creator: Function that returns env given env config.
            validate_env: Optional callable to validate the generated
                environment (only on worker=0). This callable should raise
                an exception if the environment is invalid.
            default_policy_class: An optional default Policy class to use inside
                the (multi-agent) `policies` dict. In case the PolicySpecs in there
                have no class defined, use this `default_policy_class`.
                If None, PolicySpecs will be using the Algorithm's default Policy
                class.
            config: Optional AlgorithmConfig (or config dict).
            local_env_runner: Whether to create a local (non @ray.remote) EnvRunner
                in the returned set as well (default: True). If `num_env_runners`
                is 0, always create a local EnvRunner.
            logdir: Optional logging directory for workers.
            _setup: Whether to actually set up workers. This is only for testing.
            tune_trial_id: The Ray Tune trial ID, if this EnvRunnerGroup is part of
                an Algorithm run as a Tune trial. None, otherwise.
            num_env_runners: Number of remote EnvRunner workers to create; falls
                back to `config.num_env_runners` when None.
            num_workers: Deprecated; setting it raises an error.
            local_worker: Deprecated; setting it raises an error.
        """
        # Old WorkerSet-era kwarg names hard-error (error=True raises).
        if num_workers != DEPRECATED_VALUE or local_worker != DEPRECATED_VALUE:
            deprecation_warning(
                old="WorkerSet(num_workers=..., local_worker=...)",
                new="EnvRunnerGroup(num_env_runners=..., local_env_runner=...)",
                error=True,
            )

        # Imported locally (presumably to avoid a circular import at module
        # load time -- TODO confirm).
        from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

        # Make sure `config` is an AlgorithmConfig object.
        if not config:
            config = AlgorithmConfig()
        elif isinstance(config, dict):
            config = AlgorithmConfig.from_dict(config)

        self._env_creator = env_creator
        self._policy_class = default_policy_class
        self._remote_config = config
        # Ray actor options used when spawning each remote EnvRunner.
        self._remote_args = {
            "num_cpus": self._remote_config.num_cpus_per_env_runner,
            "num_gpus": self._remote_config.num_gpus_per_env_runner,
            "resources": self._remote_config.custom_resources_per_env_runner,
            # Automatic actor restarts only when the user opted into
            # restarting failed env runners.
            "max_restarts": (
                config.max_num_env_runner_restarts
                if config.restart_failed_env_runners
                else 0
            ),
        }
        self._tune_trial_id = tune_trial_id

        # Set the EnvRunner subclass to be used as "workers". Default: RolloutWorker.
        self.env_runner_cls = config.env_runner_cls
        if self.env_runner_cls is None:
            if config.enable_env_runner_and_connector_v2:
                # If experiences should be recorded, use the `
                # OfflineSingleAgentEnvRunner`.
                if config.output:
                    # No multi-agent support.
                    if config.is_multi_agent:
                        raise ValueError("Multi-agent recording is not supported, yet.")
                    # Otherwise, load the single-agent env runner for
                    # recording.
                    else:
                        from ray.rllib.offline.offline_env_runner import (
                            OfflineSingleAgentEnvRunner,
                        )

                        self.env_runner_cls = OfflineSingleAgentEnvRunner
                else:
                    if config.is_multi_agent:
                        from ray.rllib.env.multi_agent_env_runner import (
                            MultiAgentEnvRunner,
                        )

                        self.env_runner_cls = MultiAgentEnvRunner
                    else:
                        from ray.rllib.env.single_agent_env_runner import (
                            SingleAgentEnvRunner,
                        )

                        self.env_runner_cls = SingleAgentEnvRunner
            else:
                # Old API stack: fall back to RolloutWorker.
                self.env_runner_cls = RolloutWorker
        # Pre-bind the remote-actor constructor for the chosen EnvRunner class.
        self._cls = ray.remote(**self._remote_args)(self.env_runner_cls).remote

        self._logdir = logdir
        self._ignore_ray_errors_on_env_runners = (
            config.ignore_env_runner_failures or config.restart_failed_env_runners
        )

        # Create remote worker manager.
        # ID=0 is used by the local worker.
        # Starting remote workers from ID=1 to avoid conflicts.
        self._worker_manager = FaultTolerantActorManager(
            max_remote_requests_in_flight_per_actor=(
                config.max_requests_in_flight_per_env_runner
            ),
            init_id=1,
        )

        if _setup:
            try:
                self._setup(
                    validate_env=validate_env,
                    config=config,
                    num_env_runners=(
                        num_env_runners
                        if num_env_runners is not None
                        else config.num_env_runners
                    ),
                    local_env_runner=local_env_runner,
                )
            # EnvRunnerGroup creation possibly fails, if some (remote) workers cannot
            # be initialized properly (due to some errors in the EnvRunners's
            # constructor).
            except RayActorError as e:
                # In case of an actor (remote worker) init failure, the remote worker
                # may still exist and will be accessible, however, e.g. calling
                # its `sample.remote()` would result in strange "property not found"
                # errors.
                if e.actor_init_failed:
                    # Raise the original error here that the EnvRunners raised
                    # during its construction process. This is to enforce transparency
                    # for the user (better to understand the real reason behind the
                    # failure).
                    # - e.args[0]: The RayTaskError (inside the caught RayActorError).
                    # - e.args[0].args[2]: The original Exception (e.g. a ValueError due
                    #   to a config mismatch) thrown inside the actor.
                    # NOTE(review): relies on Ray's internal RayTaskError args layout
                    # -- verify against the installed Ray version.
                    raise e.args[0].args[2]
                # In any other case, raise the RayActorError as-is.
                else:
                    raise e
| 227 |
+
def _setup(
    self,
    *,
    validate_env: Optional[Callable[[EnvType], None]] = None,
    config: Optional["AlgorithmConfig"] = None,
    num_env_runners: int = 0,
    local_env_runner: bool = True,
):
    """Sets up an EnvRunnerGroup instance.

    Args:
        validate_env: Optional callable to validate the generated
            environment (only on worker=0).
        config: Optional dict that extends the common config of
            the Algorithm class.
        num_env_runners: Number of remote EnvRunner workers to create.
        local_env_runner: Whether to create a local (non @ray.remote) EnvRunner
            in the returned set as well (default: True). If `num_env_runners`
            is 0, always create a local EnvRunner.
    """
    # Force a local worker if num_env_runners == 0 (no remote workers).
    # Otherwise, this EnvRunnerGroup would be empty.
    self._local_env_runner = None
    if num_env_runners == 0:
        local_env_runner = True
    # Create a local (learner) version of the config for the local worker.
    # The only difference is the tf_session_args, which - for the local worker -
    # will be `config.tf_session_args` updated/overridden with
    # `config.local_tf_session_args`.
    local_tf_session_args = config.tf_session_args.copy()
    local_tf_session_args.update(config.local_tf_session_args)
    self._local_config = config.copy(copy_frozen=False).framework(
        tf_session_args=local_tf_session_args
    )

    if config.input_ == "dataset":
        # Create the set of dataset readers to be shared by all the
        # rollout workers.
        self._ds, self._ds_shards = get_dataset_and_shards(config, num_env_runners)
    else:
        self._ds = None
        self._ds_shards = None

    # Create a number of @ray.remote workers.
    self.add_workers(
        num_env_runners,
        validate=config.validate_env_runners_after_construction,
    )

    # If num_env_runners > 0 and we don't have an env on the local worker,
    # get the observation- and action spaces for each policy from
    # the first remote worker (which does have an env).
    if (
        local_env_runner
        and self._worker_manager.num_actors() > 0
        and not config.enable_env_runner_and_connector_v2
        and not config.create_env_on_local_worker
        and (not config.observation_space or not config.action_space)
    ):
        spaces = self.get_spaces()
    else:
        spaces = None

    # Create a local worker, if needed.
    if local_env_runner:
        # worker_index=0 marks this as the local (non-remote) EnvRunner.
        self._local_env_runner = self._make_worker(
            cls=self.env_runner_cls,
            env_creator=self._env_creator,
            validate_env=validate_env,
            worker_index=0,
            num_workers=num_env_runners,
            config=self._local_config,
            spaces=spaces,
        )
|
| 300 |
+
|
| 301 |
+
def get_spaces(self):
    """Infer observation and action spaces from one (local or remote) EnvRunner.

    Returns:
        A dict mapping from ModuleID to a 2-tuple containing obs- and action-space.
    """
    # Get ID of the first remote worker.
    remote_worker_ids = (
        [self._worker_manager.actor_ids()[0]]
        if self._worker_manager.actor_ids()
        else []
    )

    # Query exactly one EnvRunner: the first remote one if any exist,
    # otherwise the local one.
    spaces = self.foreach_env_runner(
        lambda env_runner: env_runner.get_spaces(),
        remote_worker_ids=remote_worker_ids,
        local_env_runner=not remote_worker_ids,
    )[0]

    logger.info(
        "Inferred observation/action spaces from remote "
        f"worker (local worker has no env): {spaces}"
    )

    return spaces
|
| 326 |
+
|
| 327 |
+
@property
def local_env_runner(self) -> EnvRunner:
    """Returns the local (non @ray.remote) EnvRunner, or None if none exists."""
    return self._local_env_runner
|
| 331 |
+
|
| 332 |
+
def healthy_env_runner_ids(self) -> List[int]:
    """Returns the list of IDs of all currently healthy remote workers."""
    return self._worker_manager.healthy_actor_ids()
|
| 335 |
+
|
| 336 |
+
def healthy_worker_ids(self) -> List[int]:
    """Returns the IDs of all currently healthy remote workers.

    Alias of `healthy_env_runner_ids`.
    """
    healthy_ids = self.healthy_env_runner_ids()
    return healthy_ids
|
| 339 |
+
|
| 340 |
+
def num_remote_env_runners(self) -> int:
    """Returns how many remote EnvRunner actors belong to this group."""
    manager = self._worker_manager
    return manager.num_actors()
|
| 343 |
+
|
| 344 |
+
def num_remote_workers(self) -> int:
    """Returns the number of remote EnvRunners (alias of `num_remote_env_runners`)."""
    return self.num_remote_env_runners()
|
| 347 |
+
|
| 348 |
+
def num_healthy_remote_env_runners(self) -> int:
    """Returns the number of healthy remote workers."""
    return self._worker_manager.num_healthy_actors()
|
| 351 |
+
|
| 352 |
+
def num_healthy_remote_workers(self) -> int:
    """Returns the number of healthy remote workers (alias of `num_healthy_remote_env_runners`)."""
    return self.num_healthy_remote_env_runners()
|
| 355 |
+
|
| 356 |
+
def num_healthy_env_runners(self) -> int:
    """Returns the number of all healthy workers, including the local worker."""
    has_local = bool(self._local_env_runner)
    return (1 if has_local else 0) + self.num_healthy_remote_workers()
|
| 359 |
+
|
| 360 |
+
def num_healthy_workers(self) -> int:
    """Returns the number of all healthy workers, including the local worker."""
    return self.num_healthy_env_runners()
|
| 363 |
+
|
| 364 |
+
def num_in_flight_async_reqs(self) -> int:
    """Returns the number of in-flight async requests across all remote workers."""
    return self._worker_manager.num_outstanding_async_reqs()
|
| 367 |
+
|
| 368 |
+
def num_remote_worker_restarts(self) -> int:
    """Total number of times managed remote workers have been restarted."""
    return self._worker_manager.total_num_restarts()
|
| 371 |
+
|
| 372 |
+
def sync_env_runner_states(
    self,
    *,
    config: "AlgorithmConfig",
    from_worker: Optional[EnvRunner] = None,
    env_steps_sampled: Optional[int] = None,
    connector_states: Optional[List[Dict[str, Any]]] = None,
    rl_module_state: Optional[Dict[str, Any]] = None,
    env_runner_indices_to_update: Optional[List[int]] = None,
) -> None:
    """Synchronizes the connectors of this EnvRunnerGroup's EnvRunners.

    The exact procedure works as follows:
    - If `from_worker` is None, set `from_worker=self.local_env_runner`.
    - If `config.use_worker_filter_stats` is True, gather all remote EnvRunners'
    ConnectorV2 states. Otherwise, only use the ConnectorV2 states of `from_worker`.
    - Merge all gathered states into one resulting state.
    - Broadcast the resulting state back to all remote EnvRunners AND the local
    EnvRunner.

    Args:
        config: The AlgorithmConfig object to use to determine, in which
            direction(s) we need to synch and what the timeouts are.
        from_worker: The EnvRunner from which to synch. If None, will use the local
            worker of this EnvRunnerGroup.
        env_steps_sampled: The total number of env steps taken thus far by all
            workers combined. Used to broadcast this number to all remote workers
            if `update_worker_filter_stats` is True in `config`.
        connector_states: Optional list of already-gathered per-EnvRunner
            connector states. If None, states are fetched from the remote
            EnvRunners here; if an empty list, no connector states are merged.
        rl_module_state: Optional RLModule state dict to broadcast along with
            the connector states.
        env_runner_indices_to_update: The indices of those EnvRunners to update
            with the merged state. Use None (default) to update all remote
            EnvRunners.
    """
    from_worker = from_worker or self.local_env_runner

    # Early out if the number of (healthy) remote workers is 0. In this case, the
    # local worker is the only operating worker and thus of course always holds
    # the reference connector state.
    if self.num_healthy_remote_workers() == 0:
        self.local_env_runner.set_state(
            {
                **(
                    {NUM_ENV_STEPS_SAMPLED_LIFETIME: env_steps_sampled}
                    if env_steps_sampled is not None
                    else {}
                ),
                **(rl_module_state if rl_module_state is not None else {}),
            }
        )
        return

    # Also early out, if we a) don't use the remote states AND b) don't want to
    # broadcast back from `from_worker` to all remote workers.
    # TODO (sven): Rename these to proper "..env_runner_states.." containing names.
    if not config.update_worker_filter_stats and not config.use_worker_filter_stats:
        return

    # Use states from all remote EnvRunners.
    if config.use_worker_filter_stats:
        if connector_states == []:
            env_runner_states = {}
        else:
            if connector_states is None:
                # Fetch connector states from all remote EnvRunners.
                connector_states = self.foreach_env_runner(
                    lambda w: w.get_state(
                        components=[
                            COMPONENT_ENV_TO_MODULE_CONNECTOR,
                            COMPONENT_MODULE_TO_ENV_CONNECTOR,
                        ]
                    ),
                    local_env_runner=False,
                    timeout_seconds=(
                        config.sync_filters_on_rollout_workers_timeout_s
                    ),
                )
            env_to_module_states = [
                s[COMPONENT_ENV_TO_MODULE_CONNECTOR]
                for s in connector_states
                if COMPONENT_ENV_TO_MODULE_CONNECTOR in s
            ]
            module_to_env_states = [
                s[COMPONENT_MODULE_TO_ENV_CONNECTOR]
                for s in connector_states
                if COMPONENT_MODULE_TO_ENV_CONNECTOR in s
            ]

            # Merge the per-EnvRunner connector states into one reference state.
            env_runner_states = {}
            if env_to_module_states:
                env_runner_states.update(
                    {
                        COMPONENT_ENV_TO_MODULE_CONNECTOR: (
                            self.local_env_runner._env_to_module.merge_states(
                                env_to_module_states
                            )
                        ),
                    }
                )
            if module_to_env_states:
                env_runner_states.update(
                    {
                        COMPONENT_MODULE_TO_ENV_CONNECTOR: (
                            self.local_env_runner._module_to_env.merge_states(
                                module_to_env_states
                            )
                        ),
                    }
                )
    # Ignore states from remote EnvRunners (use the current `from_worker` states
    # only).
    else:
        env_runner_states = from_worker.get_state(
            components=[
                COMPONENT_ENV_TO_MODULE_CONNECTOR,
                COMPONENT_MODULE_TO_ENV_CONNECTOR,
            ]
        )

    # Update the global number of environment steps, if necessary.
    # Make sure to divide by the number of env runners (such that each EnvRunner
    # knows (roughly) its own(!) lifetime count and can infer the global lifetime
    # count from it).
    if env_steps_sampled is not None:
        env_runner_states[NUM_ENV_STEPS_SAMPLED_LIFETIME] = env_steps_sampled // (
            config.num_env_runners or 1
        )

    # Update the rl_module component of the EnvRunner states, if necessary:
    if rl_module_state:
        env_runner_states.update(rl_module_state)

    # If we do NOT want remote EnvRunners to get their Connector states updated,
    # only update the local worker here (with all state components) and then remove
    # the connector components.
    if not config.update_worker_filter_stats:
        self.local_env_runner.set_state(env_runner_states)
        env_runner_states.pop(COMPONENT_ENV_TO_MODULE_CONNECTOR, None)
        env_runner_states.pop(COMPONENT_MODULE_TO_ENV_CONNECTOR, None)

    # If there are components in the state left -> Update remote workers with these
    # state components (and maybe the local worker, if it hasn't been updated yet).
    if env_runner_states:
        # Put the state dictionary into Ray's object store to avoid having to make n
        # pickled copies of the state dict.
        ref_env_runner_states = ray.put(env_runner_states)

        def _update(_env_runner: EnvRunner) -> None:
            _env_runner.set_state(ray.get(ref_env_runner_states))

        # Broadcast updated states back to all workers.
        self.foreach_env_runner(
            _update,
            remote_worker_ids=env_runner_indices_to_update,
            local_env_runner=config.update_worker_filter_stats,
            timeout_seconds=0.0,  # This is a state update -> Fire-and-forget.
        )
|
| 526 |
+
|
| 527 |
+
def sync_weights(
    self,
    policies: Optional[List[PolicyID]] = None,
    from_worker_or_learner_group: Optional[Union[EnvRunner, "LearnerGroup"]] = None,
    to_worker_indices: Optional[List[int]] = None,
    global_vars: Optional[Dict[str, TensorType]] = None,
    timeout_seconds: Optional[float] = 0.0,
    inference_only: Optional[bool] = False,
) -> None:
    """Syncs model weights from the given weight source to all remote workers.

    Weight source can be either a (local) rollout worker or a learner_group. It
    should just implement a `get_weights` method.

    Args:
        policies: Optional list of PolicyIDs to sync weights for.
            If None (default), sync weights to/from all policies.
        from_worker_or_learner_group: Optional (local) EnvRunner instance or
            LearnerGroup instance to sync from. If None (default),
            sync from this EnvRunnerGroup's local worker.
        to_worker_indices: Optional list of worker indices to sync the
            weights to. If None (default), sync to all remote workers.
        global_vars: An optional global vars dict to set this
            worker to. If None, do not update the global_vars.
        timeout_seconds: Timeout in seconds to wait for the sync weights
            calls to complete. Default is 0.0 (fire-and-forget, do not wait
            for any sync calls to finish). Setting this to 0.0 might significantly
            improve algorithm performance, depending on the algo's `training_step`
            logic.
        inference_only: Sync weights with workers that keep inference-only
            modules. This is needed for algorithms in the new stack that
            use inference-only modules. In this case only a part of the
            parameters are synced to the workers. Default is False.

    Raises:
        TypeError: If there is neither a local EnvRunner nor a
            `from_worker_or_learner_group` to take weights from.
    """
    if self.local_env_runner is None and from_worker_or_learner_group is None:
        raise TypeError(
            "No `local_env_runner` in EnvRunnerGroup! Must provide "
            "`from_worker_or_learner_group` arg in `sync_weights()`!"
        )

    # Only sync if we have remote workers or `from_worker_or_trainer` is provided.
    rl_module_state = None
    if self.num_remote_workers() or from_worker_or_learner_group is not None:
        weights_src = from_worker_or_learner_group or self.local_env_runner

        if weights_src is None:
            raise ValueError(
                "`from_worker_or_trainer` is None. In this case, EnvRunnerGroup "
                "should have local_env_runner. But local_env_runner is also None."
            )

        modules = (
            [COMPONENT_RL_MODULE + "/" + p for p in policies]
            if policies is not None
            else [COMPONENT_RL_MODULE]
        )
        # LearnerGroup has-a Learner has-a RLModule.
        if isinstance(weights_src, LearnerGroup):
            rl_module_state = weights_src.get_state(
                components=[COMPONENT_LEARNER + "/" + m for m in modules],
                inference_only=inference_only,
            )[COMPONENT_LEARNER]
        # EnvRunner has-a RLModule.
        elif self._remote_config.enable_env_runner_and_connector_v2:
            rl_module_state = weights_src.get_state(
                components=modules,
                inference_only=inference_only,
            )
        else:
            rl_module_state = weights_src.get_weights(
                policies=policies,
                inference_only=inference_only,
            )

        if self._remote_config.enable_env_runner_and_connector_v2:

            # Make sure `rl_module_state` only contains the weights and the
            # weight seq no, nothing else.
            rl_module_state = {
                k: v
                for k, v in rl_module_state.items()
                if k in [COMPONENT_RL_MODULE, WEIGHTS_SEQ_NO]
            }

            # Move weights to the object store to avoid having to make n pickled
            # copies of the weights dict for each worker.
            rl_module_state_ref = ray.put(rl_module_state)

            def _set_weights(env_runner):
                env_runner.set_state(ray.get(rl_module_state_ref))

        else:
            rl_module_state_ref = ray.put(rl_module_state)

            def _set_weights(env_runner):
                env_runner.set_weights(ray.get(rl_module_state_ref), global_vars)

        # Sync to specified remote workers in this EnvRunnerGroup.
        self.foreach_env_runner(
            func=_set_weights,
            local_env_runner=False,  # Do not sync back to local worker.
            remote_worker_ids=to_worker_indices,
            timeout_seconds=timeout_seconds,
        )

    # If `from_worker_or_learner_group` is provided, also sync to this
    # EnvRunnerGroup's local worker.
    if self.local_env_runner is not None:
        if from_worker_or_learner_group is not None:
            if self._remote_config.enable_env_runner_and_connector_v2:
                self.local_env_runner.set_state(rl_module_state)
            else:
                self.local_env_runner.set_weights(rl_module_state)
        # If `global_vars` is provided and local worker exists -> Update its
        # global_vars.
        if global_vars is not None:
            self.local_env_runner.set_global_vars(global_vars)
|
| 644 |
+
|
| 645 |
+
@OldAPIStack
def add_policy(
    self,
    policy_id: PolicyID,
    policy_cls: Optional[Type[Policy]] = None,
    policy: Optional[Policy] = None,
    *,
    observation_space: Optional[gym.spaces.Space] = None,
    action_space: Optional[gym.spaces.Space] = None,
    config: Optional[Union["AlgorithmConfig", PartialAlgorithmConfigDict]] = None,
    policy_state: Optional[PolicyState] = None,
    policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None,
    policies_to_train: Optional[
        Union[
            Collection[PolicyID],
            Callable[[PolicyID, Optional[SampleBatchType]], bool],
        ]
    ] = None,
    module_spec: Optional[RLModuleSpec] = None,
    # Deprecated.
    workers: Optional[List[Union[EnvRunner, ActorHandle]]] = DEPRECATED_VALUE,
) -> None:
    """Adds a policy to this EnvRunnerGroup's workers or a specific list of workers.

    Args:
        policy_id: ID of the policy to add.
        policy_cls: The Policy class to use for constructing the new Policy.
            Note: Only one of `policy_cls` or `policy` must be provided.
        policy: The Policy instance to add to this EnvRunnerGroup. If not None, the
            given Policy object will be directly inserted into the
            local worker and clones of that Policy will be created on all remote
            workers.
            Note: Only one of `policy_cls` or `policy` must be provided.
        observation_space: The observation space of the policy to add.
            If None, try to infer this space from the environment.
        action_space: The action space of the policy to add.
            If None, try to infer this space from the environment.
        config: The config object or overrides for the policy to add.
        policy_state: Optional state dict to apply to the new
            policy instance, right after its construction.
        policy_mapping_fn: An optional (updated) policy mapping function
            to use from here on. Note that already ongoing episodes will
            not change their mapping but will use the old mapping till
            the end of the episode.
        policies_to_train: An optional list of policy IDs to be trained
            or a callable taking PolicyID and SampleBatchType and
            returning a bool (trainable or not?).
            If None, will keep the existing setup in place. Policies,
            whose IDs are not in the list (or for which the callable
            returns False) will not be updated.
        module_spec: In the new RLModule API we need to pass in the module_spec for
            the new module that is supposed to be added. Knowing the policy spec is
            not sufficient.
        workers: A list of EnvRunner/ActorHandles (remote
            EnvRunners) to add this policy to. If defined, will only
            add the given policy to these workers.

    Raises:
        KeyError: If the given `policy_id` already exists in this EnvRunnerGroup.
    """
    if self.local_env_runner and policy_id in self.local_env_runner.policy_map:
        raise KeyError(
            f"Policy ID '{policy_id}' already exists in policy map! "
            "Make sure you use a Policy ID that has not been taken yet."
            " Policy IDs that are already in your policy map: "
            f"{list(self.local_env_runner.policy_map.keys())}"
        )

    if workers is not DEPRECATED_VALUE:
        deprecation_warning(
            old="EnvRunnerGroup.add_policy(.., workers=..)",
            help=(
                "The `workers` argument to `EnvRunnerGroup.add_policy()` is "
                "deprecated! Please do not use it anymore."
            ),
            error=True,
        )

    if (policy_cls is None) == (policy is None):
        raise ValueError(
            "Only one of `policy_cls` or `policy` must be provided to "
            "staticmethod: `EnvRunnerGroup.add_policy()`!"
        )
    validate_module_id(policy_id, error=False)

    # Policy instance not provided: Use the information given here.
    if policy_cls is not None:
        new_policy_instance_kwargs = dict(
            policy_id=policy_id,
            policy_cls=policy_cls,
            observation_space=observation_space,
            action_space=action_space,
            config=config,
            policy_state=policy_state,
            policy_mapping_fn=policy_mapping_fn,
            policies_to_train=list(policies_to_train)
            if policies_to_train
            else None,
            module_spec=module_spec,
        )
    # Policy instance provided: Create clones of this very policy on the different
    # workers (copy all its properties here for the calls to add_policy on the
    # remote workers).
    else:
        new_policy_instance_kwargs = dict(
            policy_id=policy_id,
            policy_cls=type(policy),
            observation_space=policy.observation_space,
            action_space=policy.action_space,
            config=policy.config,
            policy_state=policy.get_state(),
            policy_mapping_fn=policy_mapping_fn,
            policies_to_train=list(policies_to_train)
            if policies_to_train
            else None,
            module_spec=module_spec,
        )

    def _create_new_policy_fn(worker):
        # `foreach_env_runner` function: Adds the policy to the worker (and
        # maybe changes its policy_mapping_fn - if provided here).
        worker.add_policy(**new_policy_instance_kwargs)

    if self.local_env_runner is not None:
        # Add policy directly by (already instantiated) object.
        if policy is not None:
            self.local_env_runner.add_policy(
                policy_id=policy_id,
                policy=policy,
                policy_mapping_fn=policy_mapping_fn,
                policies_to_train=policies_to_train,
                module_spec=module_spec,
            )
        # Add policy by constructor kwargs.
        else:
            self.local_env_runner.add_policy(**new_policy_instance_kwargs)

    # Add the policy to all remote workers.
    self.foreach_env_runner(_create_new_policy_fn, local_env_runner=False)
|
| 784 |
+
|
| 785 |
+
def add_workers(self, num_workers: int, validate: bool = False) -> None:
    """Creates and adds a number of remote workers to this worker set.

    Can be called several times on the same EnvRunnerGroup to add more
    EnvRunners to the set.

    Args:
        num_workers: The number of remote Workers to add to this
            EnvRunnerGroup.
        validate: Whether to validate remote workers after their construction
            process.

    Raises:
        RayError: If any of the constructed remote workers is not up and running
            properly.
    """
    old_num_workers = self._worker_manager.num_actors()
    # Worker indices continue after the already-existing actors (index 0 is
    # reserved for the local worker).
    new_workers = [
        self._make_worker(
            cls=self._cls,
            env_creator=self._env_creator,
            validate_env=None,
            worker_index=old_num_workers + i + 1,
            num_workers=old_num_workers + num_workers,
            config=self._remote_config,
        )
        for i in range(num_workers)
    ]
    self._worker_manager.add_actors(new_workers)

    # Validate here, whether all remote workers have been constructed properly
    # and are "up and running". Establish initial states.
    if validate:
        for result in self._worker_manager.foreach_actor(
            lambda w: w.assert_healthy()
        ):
            # Simply raise the error, which will get handled by the try-except
            # clause around the _setup().
            if not result.ok:
                e = result.get()
                if self._ignore_ray_errors_on_env_runners:
                    logger.error(f"Validation of EnvRunner failed! Error={str(e)}")
                else:
                    raise e
|
| 829 |
+
|
| 830 |
+
def reset(self, new_remote_workers: List[ActorHandle]) -> None:
    """Hard overrides the remote EnvRunners in this set with the provided ones.

    Args:
        new_remote_workers: A list of new EnvRunners (as `ActorHandles`) to use as
            new remote workers.
    """
    manager = self._worker_manager
    # Drop all currently managed actors, then register the replacements.
    manager.clear()
    manager.add_actors(new_remote_workers)
|
| 839 |
+
|
| 840 |
+
def stop(self) -> None:
    """Calls `stop` on all EnvRunners (including the local one)."""
    try:
        # Make sure we stop all EnvRunners, include the ones that were just
        # restarted / recovered or that are tagged unhealthy (at least, we should
        # try).
        self.foreach_env_runner(
            lambda w: w.stop(), healthy_only=False, local_env_runner=True
        )
    except Exception:
        logger.exception("Failed to stop workers!")
    finally:
        # Always release the managed actors, even if stopping them failed.
        self._worker_manager.clear()
|
| 853 |
+
|
| 854 |
+
def is_policy_to_train(
    self, policy_id: PolicyID, batch: Optional[SampleBatchType] = None
) -> bool:
    """Whether the given PolicyID (optionally inside some batch) is trainable."""
    local_runner = self.local_env_runner
    if not local_runner:
        raise NotImplementedError
    predicate = local_runner.is_policy_to_train
    # No predicate configured -> every policy counts as trainable.
    if predicate is None:
        return True
    return predicate(policy_id, batch)
|
| 864 |
+
|
| 865 |
+
def foreach_env_runner(
    self,
    func: Callable[[EnvRunner], T],
    *,
    local_env_runner: bool = True,
    healthy_only: bool = True,
    remote_worker_ids: List[int] = None,
    timeout_seconds: Optional[float] = None,
    return_obj_refs: bool = False,
    mark_healthy: bool = False,
) -> List[T]:
    """Calls the given function with each EnvRunner as its argument.

    Args:
        func: The function to call for each EnvRunners. The only call argument is
            the respective EnvRunner instance.
        local_env_runner: Whether to apply `func` to local EnvRunner, too.
            Default is True.
        healthy_only: Apply `func` on known-to-be healthy EnvRunners only.
        remote_worker_ids: Apply `func` on a selected set of remote EnvRunners.
            Use None (default) for all remote EnvRunners.
        timeout_seconds: Time to wait (in seconds) for results. Set this to 0.0 for
            fire-and-forget. Set this to None (default) to wait infinitely (i.e. for
            synchronous execution).
        return_obj_refs: Whether to return ObjectRef instead of actual results.
            Note, for fault tolerance reasons, these returned ObjectRefs should
            never be resolved with ray.get() outside of this EnvRunnerGroup.
        mark_healthy: Whether to mark all those EnvRunners healthy again that are
            currently marked unhealthy AND that returned results from the remote
            call (within the given `timeout_seconds`).
            Note that EnvRunners are NOT set unhealthy, if they simply time out
            (only if they return a RayActorError).
            Also note that this setting is ignored if `healthy_only=True` (b/c
            `mark_healthy` only affects EnvRunners that are currently tagged as
            unhealthy).

    Returns:
        The list of return values of all calls to `func([worker])`.
    """
    assert (
        not return_obj_refs or not local_env_runner
    ), "Can not return ObjectRef from local worker."

    # Apply `func` to the local EnvRunner first (synchronously, in-process).
    local_result = []
    if local_env_runner and self.local_env_runner is not None:
        local_result = [func(self.local_env_runner)]

    # No remote actors at all -> Done.
    if not self._worker_manager.actor_ids():
        return local_result

    remote_results = self._worker_manager.foreach_actor(
        func,
        healthy_only=healthy_only,
        remote_actor_ids=remote_worker_ids,
        timeout_seconds=timeout_seconds,
        return_obj_refs=return_obj_refs,
        mark_healthy=mark_healthy,
    )

    FaultTolerantActorManager.handle_remote_call_result_errors(
        remote_results, ignore_ray_errors=self._ignore_ray_errors_on_env_runners
    )

    # With application errors handled, return good results.
    remote_results = [r.get() for r in remote_results.ignore_errors()]

    return local_result + remote_results
|
| 932 |
+
|
| 933 |
+
def foreach_env_runner_with_id(
|
| 934 |
+
self,
|
| 935 |
+
func: Callable[[int, EnvRunner], T],
|
| 936 |
+
*,
|
| 937 |
+
local_env_runner: bool = True,
|
| 938 |
+
healthy_only: bool = True,
|
| 939 |
+
remote_worker_ids: List[int] = None,
|
| 940 |
+
timeout_seconds: Optional[float] = None,
|
| 941 |
+
return_obj_refs: bool = False,
|
| 942 |
+
mark_healthy: bool = False,
|
| 943 |
+
# Deprecated args.
|
| 944 |
+
local_worker=DEPRECATED_VALUE,
|
| 945 |
+
) -> List[T]:
|
| 946 |
+
"""Calls the given function with each EnvRunner and its ID as its arguments.
|
| 947 |
+
|
| 948 |
+
Args:
|
| 949 |
+
func: The function to call for each EnvRunners. The call arguments are
|
| 950 |
+
the EnvRunner's index (int) and the respective EnvRunner instance
|
| 951 |
+
itself.
|
| 952 |
+
local_env_runner: Whether to apply `func` to the local EnvRunner, too.
|
| 953 |
+
Default is True.
|
| 954 |
+
healthy_only: Apply `func` on known-to-be healthy EnvRunners only.
|
| 955 |
+
remote_worker_ids: Apply `func` on a selected set of remote EnvRunners.
|
| 956 |
+
timeout_seconds: Time to wait for results. Default is None.
|
| 957 |
+
return_obj_refs: Whether to return ObjectRef instead of actual results.
|
| 958 |
+
Note, for fault tolerance reasons, these returned ObjectRefs should
|
| 959 |
+
never be resolved with ray.get() outside of this EnvRunnerGroup.
|
| 960 |
+
mark_healthy: Whether to mark all those EnvRunners healthy again that are
|
| 961 |
+
currently marked unhealthy AND that returned results from the remote
|
| 962 |
+
call (within the given `timeout_seconds`).
|
| 963 |
+
Note that workers are NOT set unhealthy, if they simply time out
|
| 964 |
+
(only if they return a RayActorError).
|
| 965 |
+
Also note that this setting is ignored if `healthy_only=True` (b/c
|
| 966 |
+
`mark_healthy` only affects EnvRunners that are currently tagged as
|
| 967 |
+
unhealthy).
|
| 968 |
+
|
| 969 |
+
Returns:
|
| 970 |
+
The list of return values of all calls to `func([worker, id])`.
|
| 971 |
+
"""
|
| 972 |
+
local_result = []
|
| 973 |
+
if local_env_runner and self.local_env_runner is not None:
|
| 974 |
+
local_result = [func(0, self.local_env_runner)]
|
| 975 |
+
|
| 976 |
+
if not remote_worker_ids:
|
| 977 |
+
remote_worker_ids = self._worker_manager.actor_ids()
|
| 978 |
+
|
| 979 |
+
funcs = [functools.partial(func, i) for i in remote_worker_ids]
|
| 980 |
+
|
| 981 |
+
remote_results = self._worker_manager.foreach_actor(
|
| 982 |
+
funcs,
|
| 983 |
+
healthy_only=healthy_only,
|
| 984 |
+
remote_actor_ids=remote_worker_ids,
|
| 985 |
+
timeout_seconds=timeout_seconds,
|
| 986 |
+
return_obj_refs=return_obj_refs,
|
| 987 |
+
mark_healthy=mark_healthy,
|
| 988 |
+
)
|
| 989 |
+
|
| 990 |
+
FaultTolerantActorManager.handle_remote_call_result_errors(
|
| 991 |
+
remote_results,
|
| 992 |
+
ignore_ray_errors=self._ignore_ray_errors_on_env_runners,
|
| 993 |
+
)
|
| 994 |
+
|
| 995 |
+
remote_results = [r.get() for r in remote_results.ignore_errors()]
|
| 996 |
+
|
| 997 |
+
return local_result + remote_results
|
| 998 |
+
|
| 999 |
+
def foreach_env_runner_async(
|
| 1000 |
+
self,
|
| 1001 |
+
func: Callable[[EnvRunner], T],
|
| 1002 |
+
*,
|
| 1003 |
+
healthy_only: bool = True,
|
| 1004 |
+
remote_worker_ids: List[int] = None,
|
| 1005 |
+
) -> int:
|
| 1006 |
+
"""Calls the given function asynchronously with each EnvRunner as the argument.
|
| 1007 |
+
|
| 1008 |
+
Does not return results directly. Instead, `fetch_ready_async_reqs()` can be
|
| 1009 |
+
used to pull results in an async manner whenever they are available.
|
| 1010 |
+
|
| 1011 |
+
Args:
|
| 1012 |
+
func: The function to call for each EnvRunners. The only call argument is
|
| 1013 |
+
the respective EnvRunner instance.
|
| 1014 |
+
healthy_only: Apply `func` on known-to-be healthy EnvRunners only.
|
| 1015 |
+
remote_worker_ids: Apply `func` on a selected set of remote EnvRunners.
|
| 1016 |
+
|
| 1017 |
+
Returns:
|
| 1018 |
+
The number of async requests that have actually been made. This is the
|
| 1019 |
+
length of `remote_worker_ids` (or self.num_remote_workers()` if
|
| 1020 |
+
`remote_worker_ids` is None) minus the number of requests that were NOT
|
| 1021 |
+
made b/c a remote EnvRunner already had its
|
| 1022 |
+
`max_remote_requests_in_flight_per_actor` counter reached.
|
| 1023 |
+
"""
|
| 1024 |
+
return self._worker_manager.foreach_actor_async(
|
| 1025 |
+
func,
|
| 1026 |
+
healthy_only=healthy_only,
|
| 1027 |
+
remote_actor_ids=remote_worker_ids,
|
| 1028 |
+
)
|
| 1029 |
+
|
| 1030 |
+
def fetch_ready_async_reqs(
|
| 1031 |
+
self,
|
| 1032 |
+
*,
|
| 1033 |
+
timeout_seconds: Optional[float] = 0.0,
|
| 1034 |
+
return_obj_refs: bool = False,
|
| 1035 |
+
mark_healthy: bool = False,
|
| 1036 |
+
) -> List[Tuple[int, T]]:
|
| 1037 |
+
"""Get esults from outstanding asynchronous requests that are ready.
|
| 1038 |
+
|
| 1039 |
+
Args:
|
| 1040 |
+
timeout_seconds: Time to wait for results. Default is 0, meaning
|
| 1041 |
+
those requests that are already ready.
|
| 1042 |
+
return_obj_refs: Whether to return ObjectRef instead of actual results.
|
| 1043 |
+
mark_healthy: Whether to mark all those workers healthy again that are
|
| 1044 |
+
currently marked unhealthy AND that returned results from the remote
|
| 1045 |
+
call (within the given `timeout_seconds`).
|
| 1046 |
+
Note that workers are NOT set unhealthy, if they simply time out
|
| 1047 |
+
(only if they return a RayActorError).
|
| 1048 |
+
Also note that this setting is ignored if `healthy_only=True` (b/c
|
| 1049 |
+
`mark_healthy` only affects workers that are currently tagged as
|
| 1050 |
+
unhealthy).
|
| 1051 |
+
|
| 1052 |
+
Returns:
|
| 1053 |
+
A list of results successfully returned from outstanding remote calls,
|
| 1054 |
+
paired with the indices of the callee workers.
|
| 1055 |
+
"""
|
| 1056 |
+
remote_results = self._worker_manager.fetch_ready_async_reqs(
|
| 1057 |
+
timeout_seconds=timeout_seconds,
|
| 1058 |
+
return_obj_refs=return_obj_refs,
|
| 1059 |
+
mark_healthy=mark_healthy,
|
| 1060 |
+
)
|
| 1061 |
+
|
| 1062 |
+
FaultTolerantActorManager.handle_remote_call_result_errors(
|
| 1063 |
+
remote_results,
|
| 1064 |
+
ignore_ray_errors=self._ignore_ray_errors_on_env_runners,
|
| 1065 |
+
)
|
| 1066 |
+
|
| 1067 |
+
return [(r.actor_id, r.get()) for r in remote_results.ignore_errors()]
|
| 1068 |
+
|
| 1069 |
+
def foreach_policy(self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
|
| 1070 |
+
"""Calls `func` with each worker's (policy, PolicyID) tuple.
|
| 1071 |
+
|
| 1072 |
+
Note that in the multi-agent case, each worker may have more than one
|
| 1073 |
+
policy.
|
| 1074 |
+
|
| 1075 |
+
Args:
|
| 1076 |
+
func: A function - taking a Policy and its ID - that is
|
| 1077 |
+
called on all workers' Policies.
|
| 1078 |
+
|
| 1079 |
+
Returns:
|
| 1080 |
+
The list of return values of func over all workers' policies. The
|
| 1081 |
+
length of this list is:
|
| 1082 |
+
(num_workers + 1 (local-worker)) *
|
| 1083 |
+
[num policies in the multi-agent config dict].
|
| 1084 |
+
The local workers' results are first, followed by all remote
|
| 1085 |
+
workers' results
|
| 1086 |
+
"""
|
| 1087 |
+
results = []
|
| 1088 |
+
for r in self.foreach_env_runner(
|
| 1089 |
+
lambda w: w.foreach_policy(func), local_env_runner=True
|
| 1090 |
+
):
|
| 1091 |
+
results.extend(r)
|
| 1092 |
+
return results
|
| 1093 |
+
|
| 1094 |
+
def foreach_policy_to_train(self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
|
| 1095 |
+
"""Apply `func` to all workers' Policies iff in `policies_to_train`.
|
| 1096 |
+
|
| 1097 |
+
Args:
|
| 1098 |
+
func: A function - taking a Policy and its ID - that is
|
| 1099 |
+
called on all workers' Policies, for which
|
| 1100 |
+
`worker.is_policy_to_train()` returns True.
|
| 1101 |
+
|
| 1102 |
+
Returns:
|
| 1103 |
+
List[any]: The list of n return values of all
|
| 1104 |
+
`func([trainable policy], [ID])`-calls.
|
| 1105 |
+
"""
|
| 1106 |
+
results = []
|
| 1107 |
+
for r in self.foreach_env_runner(
|
| 1108 |
+
lambda w: w.foreach_policy_to_train(func), local_env_runner=True
|
| 1109 |
+
):
|
| 1110 |
+
results.extend(r)
|
| 1111 |
+
return results
|
| 1112 |
+
|
| 1113 |
+
def foreach_env(self, func: Callable[[EnvType], List[T]]) -> List[List[T]]:
|
| 1114 |
+
"""Calls `func` with all workers' sub-environments as args.
|
| 1115 |
+
|
| 1116 |
+
An "underlying sub environment" is a single clone of an env within
|
| 1117 |
+
a vectorized environment.
|
| 1118 |
+
`func` takes a single underlying sub environment as arg, e.g. a
|
| 1119 |
+
gym.Env object.
|
| 1120 |
+
|
| 1121 |
+
Args:
|
| 1122 |
+
func: A function - taking an EnvType (normally a gym.Env object)
|
| 1123 |
+
as arg and returning a list of lists of return values, one
|
| 1124 |
+
value per underlying sub-environment per each worker.
|
| 1125 |
+
|
| 1126 |
+
Returns:
|
| 1127 |
+
The list (workers) of lists (sub environments) of results.
|
| 1128 |
+
"""
|
| 1129 |
+
return list(
|
| 1130 |
+
self.foreach_env_runner(
|
| 1131 |
+
lambda w: w.foreach_env(func),
|
| 1132 |
+
local_env_runner=True,
|
| 1133 |
+
)
|
| 1134 |
+
)
|
| 1135 |
+
|
| 1136 |
+
def foreach_env_with_context(
|
| 1137 |
+
self, func: Callable[[BaseEnv, EnvContext], List[T]]
|
| 1138 |
+
) -> List[List[T]]:
|
| 1139 |
+
"""Calls `func` with all workers' sub-environments and env_ctx as args.
|
| 1140 |
+
|
| 1141 |
+
An "underlying sub environment" is a single clone of an env within
|
| 1142 |
+
a vectorized environment.
|
| 1143 |
+
`func` takes a single underlying sub environment and the env_context
|
| 1144 |
+
as args.
|
| 1145 |
+
|
| 1146 |
+
Args:
|
| 1147 |
+
func: A function - taking a BaseEnv object and an EnvContext as
|
| 1148 |
+
arg - and returning a list of lists of return values over envs
|
| 1149 |
+
of the worker.
|
| 1150 |
+
|
| 1151 |
+
Returns:
|
| 1152 |
+
The list (1 item per workers) of lists (1 item per sub-environment)
|
| 1153 |
+
of results.
|
| 1154 |
+
"""
|
| 1155 |
+
return list(
|
| 1156 |
+
self.foreach_env_runner(
|
| 1157 |
+
lambda w: w.foreach_env_with_context(func),
|
| 1158 |
+
local_env_runner=True,
|
| 1159 |
+
)
|
| 1160 |
+
)
|
| 1161 |
+
|
| 1162 |
+
def probe_unhealthy_env_runners(self) -> List[int]:
|
| 1163 |
+
"""Checks for unhealthy workers and tries restoring their states.
|
| 1164 |
+
|
| 1165 |
+
Returns:
|
| 1166 |
+
List of IDs of the workers that were restored.
|
| 1167 |
+
"""
|
| 1168 |
+
return self._worker_manager.probe_unhealthy_actors(
|
| 1169 |
+
timeout_seconds=self._remote_config.env_runner_health_probe_timeout_s,
|
| 1170 |
+
mark_healthy=True,
|
| 1171 |
+
)
|
| 1172 |
+
|
| 1173 |
+
def _make_worker(
|
| 1174 |
+
self,
|
| 1175 |
+
*,
|
| 1176 |
+
cls: Callable,
|
| 1177 |
+
env_creator: EnvCreator,
|
| 1178 |
+
validate_env: Optional[Callable[[EnvType], None]],
|
| 1179 |
+
worker_index: int,
|
| 1180 |
+
num_workers: int,
|
| 1181 |
+
recreated_worker: bool = False,
|
| 1182 |
+
config: "AlgorithmConfig",
|
| 1183 |
+
spaces: Optional[
|
| 1184 |
+
Dict[PolicyID, Tuple[gym.spaces.Space, gym.spaces.Space]]
|
| 1185 |
+
] = None,
|
| 1186 |
+
) -> Union[EnvRunner, ActorHandle]:
|
| 1187 |
+
worker = cls(
|
| 1188 |
+
env_creator=env_creator,
|
| 1189 |
+
validate_env=validate_env,
|
| 1190 |
+
default_policy_class=self._policy_class,
|
| 1191 |
+
config=config,
|
| 1192 |
+
worker_index=worker_index,
|
| 1193 |
+
num_workers=num_workers,
|
| 1194 |
+
recreated_worker=recreated_worker,
|
| 1195 |
+
log_dir=self._logdir,
|
| 1196 |
+
spaces=spaces,
|
| 1197 |
+
dataset_shards=self._ds_shards,
|
| 1198 |
+
tune_trial_id=self._tune_trial_id,
|
| 1199 |
+
)
|
| 1200 |
+
|
| 1201 |
+
return worker
|
| 1202 |
+
|
| 1203 |
+
@classmethod
|
| 1204 |
+
def _valid_module(cls, class_path):
|
| 1205 |
+
del cls
|
| 1206 |
+
if (
|
| 1207 |
+
isinstance(class_path, str)
|
| 1208 |
+
and not os.path.isfile(class_path)
|
| 1209 |
+
and "." in class_path
|
| 1210 |
+
):
|
| 1211 |
+
module_path, class_name = class_path.rsplit(".", 1)
|
| 1212 |
+
try:
|
| 1213 |
+
spec = importlib.util.find_spec(module_path)
|
| 1214 |
+
if spec is not None:
|
| 1215 |
+
return True
|
| 1216 |
+
except (ModuleNotFoundError, ValueError):
|
| 1217 |
+
print(
|
| 1218 |
+
f"module {module_path} not found while trying to get "
|
| 1219 |
+
f"input {class_path}"
|
| 1220 |
+
)
|
| 1221 |
+
return False
|
| 1222 |
+
|
| 1223 |
+
@Deprecated(new="EnvRunnerGroup.probe_unhealthy_env_runners", error=False)
|
| 1224 |
+
def probe_unhealthy_workers(self, *args, **kwargs):
|
| 1225 |
+
return self.probe_unhealthy_env_runners(*args, **kwargs)
|
| 1226 |
+
|
| 1227 |
+
@Deprecated(new="EnvRunnerGroup.foreach_env_runner", error=False)
|
| 1228 |
+
def foreach_worker(self, *args, **kwargs):
|
| 1229 |
+
return self.foreach_env_runner(*args, **kwargs)
|
| 1230 |
+
|
| 1231 |
+
@Deprecated(new="EnvRunnerGroup.foreach_env_runner_with_id", error=False)
|
| 1232 |
+
def foreach_worker_with_id(self, *args, **kwargs):
|
| 1233 |
+
return self.foreach_env_runner_with_id(*args, **kwargs)
|
| 1234 |
+
|
| 1235 |
+
@Deprecated(new="EnvRunnerGroup.foreach_env_runner_async", error=False)
|
| 1236 |
+
def foreach_worker_async(self, *args, **kwargs):
|
| 1237 |
+
return self.foreach_env_runner_async(*args, **kwargs)
|
| 1238 |
+
|
| 1239 |
+
@Deprecated(new="EnvRunnerGroup.local_env_runner", error=True)
|
| 1240 |
+
def local_worker(self) -> EnvRunner:
|
| 1241 |
+
pass
|
| 1242 |
+
|
| 1243 |
+
@property
|
| 1244 |
+
@Deprecated(
|
| 1245 |
+
old="_remote_workers",
|
| 1246 |
+
new="Use either the `foreach_env_runner()`, `foreach_env_runner_with_id()`, or "
|
| 1247 |
+
"`foreach_env_runner_async()` APIs of `EnvRunnerGroup`, which all handle fault "
|
| 1248 |
+
"tolerance.",
|
| 1249 |
+
error=True,
|
| 1250 |
+
)
|
| 1251 |
+
def _remote_workers(self):
|
| 1252 |
+
pass
|
| 1253 |
+
|
| 1254 |
+
@Deprecated(
|
| 1255 |
+
old="remote_workers()",
|
| 1256 |
+
new="Use either the `foreach_env_runner()`, `foreach_env_runner_with_id()`, or "
|
| 1257 |
+
"`foreach_env_runner_async()` APIs of `EnvRunnerGroup`, which all handle fault "
|
| 1258 |
+
"tolerance.",
|
| 1259 |
+
error=True,
|
| 1260 |
+
)
|
| 1261 |
+
def remote_workers(self):
|
| 1262 |
+
pass
|
.venv/lib/python3.11/site-packages/ray/rllib/env/external_env.py
ADDED
|
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
import queue
|
| 3 |
+
import threading
|
| 4 |
+
import uuid
|
| 5 |
+
from typing import Callable, Tuple, Optional, TYPE_CHECKING
|
| 6 |
+
|
| 7 |
+
from ray.rllib.env.base_env import BaseEnv
|
| 8 |
+
from ray.rllib.utils.annotations import override, OldAPIStack
|
| 9 |
+
from ray.rllib.utils.typing import (
|
| 10 |
+
EnvActionType,
|
| 11 |
+
EnvInfoDict,
|
| 12 |
+
EnvObsType,
|
| 13 |
+
EnvType,
|
| 14 |
+
MultiEnvDict,
|
| 15 |
+
)
|
| 16 |
+
from ray.rllib.utils.deprecation import deprecation_warning
|
| 17 |
+
|
| 18 |
+
if TYPE_CHECKING:
|
| 19 |
+
from ray.rllib.models.preprocessors import Preprocessor
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@OldAPIStack
|
| 23 |
+
class ExternalEnv(threading.Thread):
|
| 24 |
+
"""An environment that interfaces with external agents.
|
| 25 |
+
|
| 26 |
+
Unlike simulator envs, control is inverted: The environment queries the
|
| 27 |
+
policy to obtain actions and in return logs observations and rewards for
|
| 28 |
+
training. This is in contrast to gym.Env, where the algorithm drives the
|
| 29 |
+
simulation through env.step() calls.
|
| 30 |
+
|
| 31 |
+
You can use ExternalEnv as the backend for policy serving (by serving HTTP
|
| 32 |
+
requests in the run loop), for ingesting offline logs data (by reading
|
| 33 |
+
offline transitions in the run loop), or other custom use cases not easily
|
| 34 |
+
expressed through gym.Env.
|
| 35 |
+
|
| 36 |
+
ExternalEnv supports both on-policy actions (through self.get_action()),
|
| 37 |
+
and off-policy actions (through self.log_action()).
|
| 38 |
+
|
| 39 |
+
This env is thread-safe, but individual episodes must be executed serially.
|
| 40 |
+
|
| 41 |
+
.. testcode::
|
| 42 |
+
:skipif: True
|
| 43 |
+
|
| 44 |
+
from ray.tune import register_env
|
| 45 |
+
from ray.rllib.algorithms.dqn import DQN
|
| 46 |
+
YourExternalEnv = ...
|
| 47 |
+
register_env("my_env", lambda config: YourExternalEnv(config))
|
| 48 |
+
algo = DQN(env="my_env")
|
| 49 |
+
while True:
|
| 50 |
+
print(algo.train())
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
action_space: gym.Space,
|
| 56 |
+
observation_space: gym.Space,
|
| 57 |
+
max_concurrent: int = None,
|
| 58 |
+
):
|
| 59 |
+
"""Initializes an ExternalEnv instance.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
action_space: Action space of the env.
|
| 63 |
+
observation_space: Observation space of the env.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
threading.Thread.__init__(self)
|
| 67 |
+
|
| 68 |
+
self.daemon = True
|
| 69 |
+
self.action_space = action_space
|
| 70 |
+
self.observation_space = observation_space
|
| 71 |
+
self._episodes = {}
|
| 72 |
+
self._finished = set()
|
| 73 |
+
self._results_avail_condition = threading.Condition()
|
| 74 |
+
if max_concurrent is not None:
|
| 75 |
+
deprecation_warning(
|
| 76 |
+
"The `max_concurrent` argument has been deprecated. Please configure"
|
| 77 |
+
"the number of episodes using the `rollout_fragment_length` and"
|
| 78 |
+
"`batch_mode` arguments. Please raise an issue on the Ray Github if "
|
| 79 |
+
"these arguments do not support your expected use case for ExternalEnv",
|
| 80 |
+
error=True,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
def run(self):
|
| 84 |
+
"""Override this to implement the run loop.
|
| 85 |
+
|
| 86 |
+
Your loop should continuously:
|
| 87 |
+
1. Call self.start_episode(episode_id)
|
| 88 |
+
2. Call self.[get|log]_action(episode_id, obs, [action]?)
|
| 89 |
+
3. Call self.log_returns(episode_id, reward)
|
| 90 |
+
4. Call self.end_episode(episode_id, obs)
|
| 91 |
+
5. Wait if nothing to do.
|
| 92 |
+
|
| 93 |
+
Multiple episodes may be started at the same time.
|
| 94 |
+
"""
|
| 95 |
+
raise NotImplementedError
|
| 96 |
+
|
| 97 |
+
def start_episode(
|
| 98 |
+
self, episode_id: Optional[str] = None, training_enabled: bool = True
|
| 99 |
+
) -> str:
|
| 100 |
+
"""Record the start of an episode.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
episode_id: Unique string id for the episode or
|
| 104 |
+
None for it to be auto-assigned and returned.
|
| 105 |
+
training_enabled: Whether to use experiences for this
|
| 106 |
+
episode to improve the policy.
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
Unique string id for the episode.
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
if episode_id is None:
|
| 113 |
+
episode_id = uuid.uuid4().hex
|
| 114 |
+
|
| 115 |
+
if episode_id in self._finished:
|
| 116 |
+
raise ValueError("Episode {} has already completed.".format(episode_id))
|
| 117 |
+
|
| 118 |
+
if episode_id in self._episodes:
|
| 119 |
+
raise ValueError("Episode {} is already started".format(episode_id))
|
| 120 |
+
|
| 121 |
+
self._episodes[episode_id] = _ExternalEnvEpisode(
|
| 122 |
+
episode_id, self._results_avail_condition, training_enabled
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
return episode_id
|
| 126 |
+
|
| 127 |
+
def get_action(self, episode_id: str, observation: EnvObsType) -> EnvActionType:
|
| 128 |
+
"""Record an observation and get the on-policy action.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
episode_id: Episode id returned from start_episode().
|
| 132 |
+
observation: Current environment observation.
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
Action from the env action space.
|
| 136 |
+
"""
|
| 137 |
+
|
| 138 |
+
episode = self._get(episode_id)
|
| 139 |
+
return episode.wait_for_action(observation)
|
| 140 |
+
|
| 141 |
+
def log_action(
|
| 142 |
+
self, episode_id: str, observation: EnvObsType, action: EnvActionType
|
| 143 |
+
) -> None:
|
| 144 |
+
"""Record an observation and (off-policy) action taken.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
episode_id: Episode id returned from start_episode().
|
| 148 |
+
observation: Current environment observation.
|
| 149 |
+
action: Action for the observation.
|
| 150 |
+
"""
|
| 151 |
+
|
| 152 |
+
episode = self._get(episode_id)
|
| 153 |
+
episode.log_action(observation, action)
|
| 154 |
+
|
| 155 |
+
def log_returns(
|
| 156 |
+
self, episode_id: str, reward: float, info: Optional[EnvInfoDict] = None
|
| 157 |
+
) -> None:
|
| 158 |
+
"""Records returns (rewards and infos) from the environment.
|
| 159 |
+
|
| 160 |
+
The reward will be attributed to the previous action taken by the
|
| 161 |
+
episode. Rewards accumulate until the next action. If no reward is
|
| 162 |
+
logged before the next action, a reward of 0.0 is assumed.
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
episode_id: Episode id returned from start_episode().
|
| 166 |
+
reward: Reward from the environment.
|
| 167 |
+
info: Optional info dict.
|
| 168 |
+
"""
|
| 169 |
+
|
| 170 |
+
episode = self._get(episode_id)
|
| 171 |
+
episode.cur_reward += reward
|
| 172 |
+
|
| 173 |
+
if info:
|
| 174 |
+
episode.cur_info = info or {}
|
| 175 |
+
|
| 176 |
+
def end_episode(self, episode_id: str, observation: EnvObsType) -> None:
|
| 177 |
+
"""Records the end of an episode.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
episode_id: Episode id returned from start_episode().
|
| 181 |
+
observation: Current environment observation.
|
| 182 |
+
"""
|
| 183 |
+
|
| 184 |
+
episode = self._get(episode_id)
|
| 185 |
+
self._finished.add(episode.episode_id)
|
| 186 |
+
episode.done(observation)
|
| 187 |
+
|
| 188 |
+
def _get(self, episode_id: str) -> "_ExternalEnvEpisode":
|
| 189 |
+
"""Get a started episode by its ID or raise an error."""
|
| 190 |
+
|
| 191 |
+
if episode_id in self._finished:
|
| 192 |
+
raise ValueError("Episode {} has already completed.".format(episode_id))
|
| 193 |
+
|
| 194 |
+
if episode_id not in self._episodes:
|
| 195 |
+
raise ValueError("Episode {} not found.".format(episode_id))
|
| 196 |
+
|
| 197 |
+
return self._episodes[episode_id]
|
| 198 |
+
|
| 199 |
+
def to_base_env(
|
| 200 |
+
self,
|
| 201 |
+
make_env: Optional[Callable[[int], EnvType]] = None,
|
| 202 |
+
num_envs: int = 1,
|
| 203 |
+
remote_envs: bool = False,
|
| 204 |
+
remote_env_batch_wait_ms: int = 0,
|
| 205 |
+
restart_failed_sub_environments: bool = False,
|
| 206 |
+
) -> "BaseEnv":
|
| 207 |
+
"""Converts an RLlib MultiAgentEnv into a BaseEnv object.
|
| 208 |
+
|
| 209 |
+
The resulting BaseEnv is always vectorized (contains n
|
| 210 |
+
sub-environments) to support batched forward passes, where n may
|
| 211 |
+
also be 1. BaseEnv also supports async execution via the `poll` and
|
| 212 |
+
`send_actions` methods and thus supports external simulators.
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
make_env: A callable taking an int as input (which indicates
|
| 216 |
+
the number of individual sub-environments within the final
|
| 217 |
+
vectorized BaseEnv) and returning one individual
|
| 218 |
+
sub-environment.
|
| 219 |
+
num_envs: The number of sub-environments to create in the
|
| 220 |
+
resulting (vectorized) BaseEnv. The already existing `env`
|
| 221 |
+
will be one of the `num_envs`.
|
| 222 |
+
remote_envs: Whether each sub-env should be a @ray.remote
|
| 223 |
+
actor. You can set this behavior in your config via the
|
| 224 |
+
`remote_worker_envs=True` option.
|
| 225 |
+
remote_env_batch_wait_ms: The wait time (in ms) to poll remote
|
| 226 |
+
sub-environments for, if applicable. Only used if
|
| 227 |
+
`remote_envs` is True.
|
| 228 |
+
|
| 229 |
+
Returns:
|
| 230 |
+
The resulting BaseEnv object.
|
| 231 |
+
"""
|
| 232 |
+
if num_envs != 1:
|
| 233 |
+
raise ValueError(
|
| 234 |
+
"External(MultiAgent)Env does not currently support "
|
| 235 |
+
"num_envs > 1. One way of solving this would be to "
|
| 236 |
+
"treat your Env as a MultiAgentEnv hosting only one "
|
| 237 |
+
"type of agent but with several copies."
|
| 238 |
+
)
|
| 239 |
+
env = ExternalEnvWrapper(self)
|
| 240 |
+
|
| 241 |
+
return env
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
@OldAPIStack
|
| 245 |
+
class _ExternalEnvEpisode:
|
| 246 |
+
"""Tracked state for each active episode."""
|
| 247 |
+
|
| 248 |
+
def __init__(
|
| 249 |
+
self,
|
| 250 |
+
episode_id: str,
|
| 251 |
+
results_avail_condition: threading.Condition,
|
| 252 |
+
training_enabled: bool,
|
| 253 |
+
multiagent: bool = False,
|
| 254 |
+
):
|
| 255 |
+
self.episode_id = episode_id
|
| 256 |
+
self.results_avail_condition = results_avail_condition
|
| 257 |
+
self.training_enabled = training_enabled
|
| 258 |
+
self.multiagent = multiagent
|
| 259 |
+
self.data_queue = queue.Queue()
|
| 260 |
+
self.action_queue = queue.Queue()
|
| 261 |
+
if multiagent:
|
| 262 |
+
self.new_observation_dict = None
|
| 263 |
+
self.new_action_dict = None
|
| 264 |
+
self.cur_reward_dict = {}
|
| 265 |
+
self.cur_terminated_dict = {"__all__": False}
|
| 266 |
+
self.cur_truncated_dict = {"__all__": False}
|
| 267 |
+
self.cur_info_dict = {}
|
| 268 |
+
else:
|
| 269 |
+
self.new_observation = None
|
| 270 |
+
self.new_action = None
|
| 271 |
+
self.cur_reward = 0.0
|
| 272 |
+
self.cur_terminated = False
|
| 273 |
+
self.cur_truncated = False
|
| 274 |
+
self.cur_info = {}
|
| 275 |
+
|
| 276 |
+
def get_data(self):
|
| 277 |
+
if self.data_queue.empty():
|
| 278 |
+
return None
|
| 279 |
+
return self.data_queue.get_nowait()
|
| 280 |
+
|
| 281 |
+
def log_action(self, observation, action):
|
| 282 |
+
if self.multiagent:
|
| 283 |
+
self.new_observation_dict = observation
|
| 284 |
+
self.new_action_dict = action
|
| 285 |
+
else:
|
| 286 |
+
self.new_observation = observation
|
| 287 |
+
self.new_action = action
|
| 288 |
+
self._send()
|
| 289 |
+
self.action_queue.get(True, timeout=60.0)
|
| 290 |
+
|
| 291 |
+
def wait_for_action(self, observation):
|
| 292 |
+
if self.multiagent:
|
| 293 |
+
self.new_observation_dict = observation
|
| 294 |
+
else:
|
| 295 |
+
self.new_observation = observation
|
| 296 |
+
self._send()
|
| 297 |
+
return self.action_queue.get(True, timeout=300.0)
|
| 298 |
+
|
| 299 |
+
def done(self, observation):
|
| 300 |
+
if self.multiagent:
|
| 301 |
+
self.new_observation_dict = observation
|
| 302 |
+
self.cur_terminated_dict = {"__all__": True}
|
| 303 |
+
# TODO(sven): External env API does not currently support truncated,
|
| 304 |
+
# but we should deprecate external Env anyways in favor of a client-only
|
| 305 |
+
# approach.
|
| 306 |
+
self.cur_truncated_dict = {"__all__": False}
|
| 307 |
+
else:
|
| 308 |
+
self.new_observation = observation
|
| 309 |
+
self.cur_terminated = True
|
| 310 |
+
self.cur_truncated = False
|
| 311 |
+
self._send()
|
| 312 |
+
|
| 313 |
+
def _send(self):
|
| 314 |
+
if self.multiagent:
|
| 315 |
+
if not self.training_enabled:
|
| 316 |
+
for agent_id in self.cur_info_dict:
|
| 317 |
+
self.cur_info_dict[agent_id]["training_enabled"] = False
|
| 318 |
+
item = {
|
| 319 |
+
"obs": self.new_observation_dict,
|
| 320 |
+
"reward": self.cur_reward_dict,
|
| 321 |
+
"terminated": self.cur_terminated_dict,
|
| 322 |
+
"truncated": self.cur_truncated_dict,
|
| 323 |
+
"info": self.cur_info_dict,
|
| 324 |
+
}
|
| 325 |
+
if self.new_action_dict is not None:
|
| 326 |
+
item["off_policy_action"] = self.new_action_dict
|
| 327 |
+
self.new_observation_dict = None
|
| 328 |
+
self.new_action_dict = None
|
| 329 |
+
self.cur_reward_dict = {}
|
| 330 |
+
else:
|
| 331 |
+
item = {
|
| 332 |
+
"obs": self.new_observation,
|
| 333 |
+
"reward": self.cur_reward,
|
| 334 |
+
"terminated": self.cur_terminated,
|
| 335 |
+
"truncated": self.cur_truncated,
|
| 336 |
+
"info": self.cur_info,
|
| 337 |
+
}
|
| 338 |
+
if self.new_action is not None:
|
| 339 |
+
item["off_policy_action"] = self.new_action
|
| 340 |
+
self.new_observation = None
|
| 341 |
+
self.new_action = None
|
| 342 |
+
self.cur_reward = 0.0
|
| 343 |
+
if not self.training_enabled:
|
| 344 |
+
item["info"]["training_enabled"] = False
|
| 345 |
+
|
| 346 |
+
with self.results_avail_condition:
|
| 347 |
+
self.data_queue.put_nowait(item)
|
| 348 |
+
self.results_avail_condition.notify()
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
@OldAPIStack
class ExternalEnvWrapper(BaseEnv):
    """Internal adapter of ExternalEnv to BaseEnv.

    Wraps a (threaded) ExternalEnv so RLlib can drive it through the
    async BaseEnv `poll()` / `send_actions()` protocol.
    """

    def __init__(
        self, external_env: "ExternalEnv", preprocessor: "Preprocessor" = None
    ):
        """Initializes the wrapper and starts the ExternalEnv's serving thread.

        Args:
            external_env: The ExternalEnv (or ExternalMultiAgentEnv) to wrap.
            preprocessor: Optional observation preprocessor. If given, its
                `observation_space` becomes this env's observation space and
                every polled observation is passed through its `transform()`.
        """
        from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv

        self.external_env = external_env
        self.prep = preprocessor
        # True if the wrapped env is the multi-agent variant; this switches
        # `send_actions()`/`_poll()` between per-agent dicts and single values.
        self.multiagent = issubclass(type(external_env), ExternalMultiAgentEnv)
        self._action_space = external_env.action_space
        if preprocessor:
            self._observation_space = preprocessor.observation_space
        else:
            self._observation_space = external_env.observation_space
        # ExternalEnv is a Thread subclass; kick off its `run()` loop.
        external_env.start()

    @override(BaseEnv)
    def poll(
        self,
    ) -> Tuple[
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
    ]:
        """Blocks until at least one episode has data, then returns it.

        Returns:
            The 6-tuple produced by `_poll()`: (obs, rewards, terminateds,
            truncateds, infos, off_policy_actions), each a MultiEnvDict.
            (Annotation widened to 6 entries to match the actual return.)

        Raises:
            Exception: If the serving thread died while we were waiting.
        """
        with self.external_env._results_avail_condition:
            results = self._poll()
            # Wait on the env's condition variable until some episode has
            # produced an observation (results[0] is the obs dict).
            while len(results[0]) == 0:
                self.external_env._results_avail_condition.wait()
                results = self._poll()
                if not self.external_env.is_alive():
                    raise Exception("Serving thread has stopped.")
        return results

    @override(BaseEnv)
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        """Forwards computed actions to the waiting episodes' action queues.

        Args:
            action_dict: Mapping env_id -> (agent_id -> action). In the
                single-agent case the inner dict has only the dummy agent key.
        """
        from ray.rllib.env.base_env import _DUMMY_AGENT_ID

        if self.multiagent:
            # Multi-agent: hand the full per-agent action dict to the episode.
            for env_id, actions in action_dict.items():
                self.external_env._episodes[env_id].action_queue.put(actions)
        else:
            # Single-agent: unwrap the dummy-agent entry before enqueuing.
            for env_id, action in action_dict.items():
                self.external_env._episodes[env_id].action_queue.put(
                    action[_DUMMY_AGENT_ID]
                )

    def _poll(
        self,
    ) -> Tuple[
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
    ]:
        """Drains available data from all episodes (non-blocking).

        Returns:
            (obs, rewards, terminateds, truncateds, infos,
            off_policy_actions), each keyed by episode id. Finished episodes
            are removed from the wrapped env's episode registry.
        """
        from ray.rllib.env.base_env import with_dummy_agent_id

        all_obs, all_rewards, all_terminateds, all_truncateds, all_infos = (
            {},
            {},
            {},
            {},
            {},
        )
        off_policy_actions = {}
        # Iterate over a copy: finished episodes are deleted from the dict
        # while we iterate.
        for eid, episode in self.external_env._episodes.copy().items():
            data = episode.get_data()
            cur_terminated = (
                episode.cur_terminated_dict["__all__"]
                if self.multiagent
                else episode.cur_terminated
            )
            cur_truncated = (
                episode.cur_truncated_dict["__all__"]
                if self.multiagent
                else episode.cur_truncated
            )
            if cur_terminated or cur_truncated:
                # Episode over -> drop it from the registry (its final data,
                # if any, is still emitted below).
                del self.external_env._episodes[eid]
            if data:
                if self.prep:
                    all_obs[eid] = self.prep.transform(data["obs"])
                else:
                    all_obs[eid] = data["obs"]
                all_rewards[eid] = data["reward"]
                all_terminateds[eid] = data["terminated"]
                all_truncateds[eid] = data["truncated"]
                all_infos[eid] = data["info"]
                if "off_policy_action" in data:
                    off_policy_actions[eid] = data["off_policy_action"]
        if self.multiagent:
            # Ensure a consistent set of keys
            # rely on all_obs having all possible keys for now.
            for eid, eid_dict in all_obs.items():
                for agent_id in eid_dict.keys():

                    # Backfill neutral defaults for agents that have an obs
                    # but no reward/done/info this step. `fix` closes over the
                    # current `eid`/`agent_id` and is called immediately, so
                    # late binding is not an issue here.
                    def fix(d, zero_val):
                        if agent_id not in d[eid]:
                            d[eid][agent_id] = zero_val

                    fix(all_rewards, 0.0)
                    fix(all_terminateds, False)
                    fix(all_truncateds, False)
                    fix(all_infos, {})
            return (
                all_obs,
                all_rewards,
                all_terminateds,
                all_truncateds,
                all_infos,
                off_policy_actions,
            )
        else:
            # Single-agent: wrap each per-episode value in a dummy-agent dict;
            # done flags additionally get the "__all__" key.
            return (
                with_dummy_agent_id(all_obs),
                with_dummy_agent_id(all_rewards),
                with_dummy_agent_id(all_terminateds, "__all__"),
                with_dummy_agent_id(all_truncateds, "__all__"),
                with_dummy_agent_id(all_infos),
                with_dummy_agent_id(off_policy_actions),
            )

    @property
    @override(BaseEnv)
    def observation_space(self) -> gym.spaces.Dict:
        # Preprocessor's space if one was given, else the env's own space.
        return self._observation_space

    @property
    @override(BaseEnv)
    def action_space(self) -> gym.Space:
        return self._action_space
.venv/lib/python3.11/site-packages/ray/rllib/env/external_multi_agent_env.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uuid
|
| 2 |
+
import gymnasium as gym
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
from ray.rllib.utils.annotations import override, OldAPIStack
|
| 6 |
+
from ray.rllib.env.external_env import ExternalEnv, _ExternalEnvEpisode
|
| 7 |
+
from ray.rllib.utils.typing import MultiAgentDict
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@OldAPIStack
class ExternalMultiAgentEnv(ExternalEnv):
    """This is the multi-agent version of ExternalEnv.

    All per-step data (observations, actions, rewards, dones, infos) are
    dicts mapping AgentID -> value instead of single values.
    """

    def __init__(
        self,
        action_space: gym.Space,
        observation_space: gym.Space,
    ):
        """Initializes an ExternalMultiAgentEnv instance.

        Args:
            action_space: Action space of the env.
            observation_space: Observation space of the env.
        """
        ExternalEnv.__init__(self, action_space, observation_space)

        # We require to know all agents' spaces.
        # NOTE(review): this check only fires for plain-dict spaces; a
        # `gym.spaces.Dict` is not an instance of `dict`, so per-agent
        # `gym.spaces.Dict` spaces bypass it — confirm whether intended.
        if isinstance(self.action_space, dict) or isinstance(
            self.observation_space, dict
        ):
            if not (self.action_space.keys() == self.observation_space.keys()):
                raise ValueError(
                    "Agent ids disagree for action space and obs "
                    "space dict: {} {}".format(
                        self.action_space.keys(), self.observation_space.keys()
                    )
                )

    def run(self):
        """Override this to implement the multi-agent run loop.

        Your loop should continuously:
            1. Call self.start_episode(episode_id)
            2. Call self.get_action(episode_id, obs_dict)
                    -or-
                self.log_action(episode_id, obs_dict, action_dict)
            3. Call self.log_returns(episode_id, reward_dict)
            4. Call self.end_episode(episode_id, obs_dict)
            5. Wait if nothing to do.

        Multiple episodes may be started at the same time.
        """
        raise NotImplementedError

    @override(ExternalEnv)
    def start_episode(
        self, episode_id: Optional[str] = None, training_enabled: bool = True
    ) -> str:
        """Starts a new (multi-agent) episode and returns its id.

        Args:
            episode_id: Optional explicit id; auto-generated if None.
            training_enabled: Whether experiences from this episode should be
                used for training.

        Returns:
            The (possibly auto-generated) episode id.

        Raises:
            ValueError: If the episode already completed or is already running.
        """
        if episode_id is None:
            episode_id = uuid.uuid4().hex

        if episode_id in self._finished:
            raise ValueError("Episode {} has already completed.".format(episode_id))

        if episode_id in self._episodes:
            raise ValueError("Episode {} is already started".format(episode_id))

        # multiagent=True makes the episode track per-agent dicts
        # (cur_reward_dict, cur_terminated_dict, ...).
        self._episodes[episode_id] = _ExternalEnvEpisode(
            episode_id, self._results_avail_condition, training_enabled, multiagent=True
        )

        return episode_id

    @override(ExternalEnv)
    def get_action(
        self, episode_id: str, observation_dict: MultiAgentDict
    ) -> MultiAgentDict:
        """Record an observation and get the on-policy action.

        Thereby, observation_dict is expected to contain the observation
        of all agents acting in this episode step.

        Args:
            episode_id: Episode id returned from start_episode().
            observation_dict: Current environment observation.

        Returns:
            action: Action from the env action space.
        """

        episode = self._get(episode_id)
        # Blocks until the serving thread computed an action dict.
        return episode.wait_for_action(observation_dict)

    @override(ExternalEnv)
    def log_action(
        self,
        episode_id: str,
        observation_dict: MultiAgentDict,
        action_dict: MultiAgentDict,
    ) -> None:
        """Record an observation and (off-policy) action taken.

        Args:
            episode_id: Episode id returned from start_episode().
            observation_dict: Current environment observation.
            action_dict: Action for the observation.
        """

        episode = self._get(episode_id)
        episode.log_action(observation_dict, action_dict)

    @override(ExternalEnv)
    def log_returns(
        self,
        episode_id: str,
        reward_dict: MultiAgentDict,
        info_dict: MultiAgentDict = None,
        multiagent_done_dict: MultiAgentDict = None,
    ) -> None:
        """Record returns from the environment.

        The reward will be attributed to the previous action taken by the
        episode. Rewards accumulate until the next action. If no reward is
        logged before the next action, a reward of 0.0 is assumed.

        Args:
            episode_id: Episode id returned from start_episode().
            reward_dict: Reward from the environment agents.
            info_dict: Optional info dict.
            multiagent_done_dict: Optional done dict for agents.
        """

        episode = self._get(episode_id)

        # Accumulate reward by agent.
        # For existing agents, we want to add the reward up.
        for agent, rew in reward_dict.items():
            if agent in episode.cur_reward_dict:
                episode.cur_reward_dict[agent] += rew
            else:
                episode.cur_reward_dict[agent] = rew

        if multiagent_done_dict:
            for agent, done in multiagent_done_dict.items():
                # Fix: the multi-agent `_ExternalEnvEpisode` tracks dones in
                # `cur_terminated_dict` (and `cur_truncated_dict`); the old
                # `cur_done_dict` attribute no longer exists, so writing to it
                # broke whenever a done dict was supplied.
                episode.cur_terminated_dict[agent] = done

        if info_dict:
            # `info_dict` is known truthy here, so assign it directly
            # (the former `info_dict or {}` was redundant).
            episode.cur_info_dict = info_dict

    @override(ExternalEnv)
    def end_episode(self, episode_id: str, observation_dict: MultiAgentDict) -> None:
        """Record the end of an episode.

        Args:
            episode_id: Episode id returned from start_episode().
            observation_dict: Current environment observation.
        """

        episode = self._get(episode_id)
        # Mark finished first so the id can never be restarted, then emit the
        # terminal observation.
        self._finished.add(episode.episode_id)
        episode.done(observation_dict)
|
.venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_env.py
ADDED
|
@@ -0,0 +1,799 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Callable, Dict, List, Tuple, Optional, Union, Set, Type
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from ray.rllib.env.base_env import BaseEnv
|
| 8 |
+
from ray.rllib.env.env_context import EnvContext
|
| 9 |
+
from ray.rllib.utils.annotations import OldAPIStack, override
|
| 10 |
+
from ray.rllib.utils.deprecation import Deprecated
|
| 11 |
+
from ray.rllib.utils.typing import (
|
| 12 |
+
AgentID,
|
| 13 |
+
EnvCreator,
|
| 14 |
+
EnvID,
|
| 15 |
+
EnvType,
|
| 16 |
+
MultiAgentDict,
|
| 17 |
+
MultiEnvDict,
|
| 18 |
+
)
|
| 19 |
+
from ray.util import log_once
|
| 20 |
+
from ray.util.annotations import DeveloperAPI, PublicAPI
|
| 21 |
+
|
| 22 |
+
# If the obs space is Dict type, look for the global state under this key.
|
| 23 |
+
ENV_STATE = "state"
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@PublicAPI(stability="beta")
class MultiAgentEnv(gym.Env):
    """An environment that hosts multiple independent agents.

    Agents are identified by AgentIDs (string).
    """

    # Optional mappings from AgentID to individual agents' spaces.
    # Set this to an "exhaustive" dictionary, mapping all possible AgentIDs to
    # individual agents' spaces. Alternatively, override
    # `get_observation_space(agent_id=...)` and `get_action_space(agent_id=...)`, which
    # is the API that RLlib uses to get individual spaces and whose default
    # implementation is to simply look up `agent_id` in these dicts.
    observation_spaces: Optional[Dict[AgentID, gym.Space]] = None
    action_spaces: Optional[Dict[AgentID, gym.Space]] = None

    # All agents currently active in the environment. This attribute may change during
    # the lifetime of the env or even during an individual episode.
    # NOTE(review): class-level mutable list — shared across instances until a
    # subclass/instance reassigns it (as `__init__` below does); confirm
    # subclasses always set their own list.
    agents: List[AgentID] = []
    # All agents that may appear in the environment, ever.
    # This attribute should not be changed during the lifetime of this env.
    possible_agents: List[AgentID] = []

    # @OldAPIStack, use `observation_spaces` and `action_spaces`, instead.
    observation_space: Optional[gym.Space] = None
    action_space: Optional[gym.Space] = None

    def __init__(self):
        super().__init__()

        # @OldAPIStack
        if not hasattr(self, "_agent_ids"):
            self._agent_ids = set()

        # If these important attributes are not set, try to infer them.
        if not self.agents:
            self.agents = list(self._agent_ids)
        if not self.possible_agents:
            self.possible_agents = self.agents.copy()

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> Tuple[MultiAgentDict, MultiAgentDict]:  # type: ignore
        """Resets the env and returns observations from ready agents.

        Args:
            seed: An optional seed to use for the new episode.

        Returns:
            New observations for each ready agent.

        .. testcode::
            :skipif: True

            from ray.rllib.env.multi_agent_env import MultiAgentEnv
            class MyMultiAgentEnv(MultiAgentEnv):
                # Define your env here.
            env = MyMultiAgentEnv()
            obs, infos = env.reset(seed=42, options={})
            print(obs)

        .. testoutput::

            {
                "car_0": [2.4, 1.6],
                "car_1": [3.4, -3.2],
                "traffic_light_1": [0, 3, 5, 1],
            }
        """
        # Call super's `reset()` method to (maybe) set the given `seed`.
        # NOTE(review): this base implementation returns None despite the
        # annotated (obs, infos) tuple — subclasses are expected to override
        # `reset()` and return the actual observations/infos.
        super().reset(seed=seed, options=options)

    def step(
        self, action_dict: MultiAgentDict
    ) -> Tuple[
        MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict
    ]:
        """Returns observations from ready agents.

        The returns are dicts mapping from agent_id strings to values. The
        number of agents in the env can vary over time.

        Returns:
            Tuple containing 1) new observations for
            each ready agent, 2) reward values for each ready agent. If
            the episode is just started, the value will be None.
            3) Terminated values for each ready agent. The special key
            "__all__" (required) is used to indicate env termination.
            4) Truncated values for each ready agent.
            5) Info values for each agent id (may be empty dicts).

        .. testcode::
            :skipif: True

            env = ...
            obs, rewards, terminateds, truncateds, infos = env.step(action_dict={
                "car_0": 1, "car_1": 0, "traffic_light_1": 2,
            })
            print(rewards)

            print(terminateds)

            print(infos)

        .. testoutput::

            {
                "car_0": 3,
                "car_1": -1,
                "traffic_light_1": 0,
            }
            {
                "car_0": False,  # car_0 is still running
                "car_1": True,   # car_1 is terminated
                "__all__": False,  # the env is not terminated
            }
            {
                "car_0": {},  # info for car_0
                "car_1": {},  # info for car_1
            }

        """
        raise NotImplementedError

    def render(self) -> None:
        """Tries to render the environment."""

        # By default, do nothing.
        pass

    def get_observation_space(self, agent_id: AgentID) -> gym.Space:
        """Returns the observation space for the given agent.

        Default: look up `agent_id` in `self.observation_spaces`; falls back
        to the old-API-stack single `observation_space` attribute.
        """
        if self.observation_spaces is not None:
            return self.observation_spaces[agent_id]

        # @OldAPIStack behavior.
        # `self.observation_space` is a `gym.spaces.Dict` AND contains `agent_id`.
        if (
            isinstance(self.observation_space, gym.spaces.Dict)
            and agent_id in self.observation_space.spaces
        ):
            return self.observation_space[agent_id]
        # `self.observation_space` is not a `gym.spaces.Dict` OR doesn't contain
        # `agent_id` -> The defined space is most likely meant to be the space
        # for all agents.
        else:
            return self.observation_space

    def get_action_space(self, agent_id: AgentID) -> gym.Space:
        """Returns the action space for the given agent.

        Default: look up `agent_id` in `self.action_spaces`; falls back to the
        old-API-stack single `action_space` attribute.
        """
        if self.action_spaces is not None:
            return self.action_spaces[agent_id]

        # @OldAPIStack behavior.
        # `self.action_space` is a `gym.spaces.Dict` AND contains `agent_id`.
        if (
            isinstance(self.action_space, gym.spaces.Dict)
            and agent_id in self.action_space.spaces
        ):
            return self.action_space[agent_id]
        # `self.action_space` is not a `gym.spaces.Dict` OR doesn't contain
        # `agent_id` -> The defined space is most likely meant to be the space
        # for all agents.
        else:
            return self.action_space

    @property
    def num_agents(self) -> int:
        # Number of currently active agents.
        return len(self.agents)

    @property
    def max_num_agents(self) -> int:
        # Number of agents that could ever appear in this env.
        return len(self.possible_agents)

    # fmt: off
    # __grouping_doc_begin__
    def with_agent_groups(
        self,
        groups: Dict[str, List[AgentID]],
        obs_space: gym.Space = None,
        act_space: gym.Space = None,
    ) -> "MultiAgentEnv":
        """Convenience method for grouping together agents in this env.

        An agent group is a list of agent IDs that are mapped to a single
        logical agent. All agents of the group must act at the same time in the
        environment. The grouped agent exposes Tuple action and observation
        spaces that are the concatenated action and obs spaces of the
        individual agents.

        The rewards of all the agents in a group are summed. The individual
        agent rewards are available under the "individual_rewards" key of the
        group info return.

        Agent grouping is required to leverage algorithms such as Q-Mix.

        Args:
            groups: Mapping from group id to a list of the agent ids
                of group members. If an agent id is not present in any group
                value, it will be left ungrouped. The group id becomes a new agent ID
                in the final environment.
            obs_space: Optional observation space for the grouped
                env. Must be a tuple space. If not provided, will infer this to be a
                Tuple of n individual agents spaces (n=num agents in a group).
            act_space: Optional action space for the grouped env.
                Must be a tuple space. If not provided, will infer this to be a Tuple
                of n individual agents spaces (n=num agents in a group).

        .. testcode::
            :skipif: True

            from ray.rllib.env.multi_agent_env import MultiAgentEnv
            class MyMultiAgentEnv(MultiAgentEnv):
                # define your env here
                ...
            env = MyMultiAgentEnv(...)
            grouped_env = env.with_agent_groups(env, {
                "group1": ["agent1", "agent2", "agent3"],
                "group2": ["agent4", "agent5"],
            })

        """

        from ray.rllib.env.wrappers.group_agents_wrapper import \
            GroupAgentsWrapper
        return GroupAgentsWrapper(self, groups, obs_space, act_space)

    # __grouping_doc_end__
    # fmt: on

    @OldAPIStack
    @Deprecated(new="MultiAgentEnv.possible_agents", error=False)
    def get_agent_ids(self) -> Set[AgentID]:
        # Deprecated accessor kept for old-API-stack users; normalizes the
        # legacy `_agent_ids` attribute to a set.
        if not hasattr(self, "_agent_ids"):
            self._agent_ids = set()
        if not isinstance(self._agent_ids, set):
            self._agent_ids = set(self._agent_ids)
        # Make this backward compatible as much as possible.
        return self._agent_ids if self._agent_ids else set(self.agents)

    @OldAPIStack
    def to_base_env(
        self,
        make_env: Optional[Callable[[int], EnvType]] = None,
        num_envs: int = 1,
        remote_envs: bool = False,
        remote_env_batch_wait_ms: int = 0,
        restart_failed_sub_environments: bool = False,
    ) -> "BaseEnv":
        """Converts an RLlib MultiAgentEnv into a BaseEnv object.

        The resulting BaseEnv is always vectorized (contains n
        sub-environments) to support batched forward passes, where n may
        also be 1. BaseEnv also supports async execution via the `poll` and
        `send_actions` methods and thus supports external simulators.

        Args:
            make_env: A callable taking an int as input (which indicates
                the number of individual sub-environments within the final
                vectorized BaseEnv) and returning one individual
                sub-environment.
            num_envs: The number of sub-environments to create in the
                resulting (vectorized) BaseEnv. The already existing `env`
                will be one of the `num_envs`.
            remote_envs: Whether each sub-env should be a @ray.remote
                actor. You can set this behavior in your config via the
                `remote_worker_envs=True` option.
            remote_env_batch_wait_ms: The wait time (in ms) to poll remote
                sub-environments for, if applicable. Only used if
                `remote_envs` is True.
            restart_failed_sub_environments: If True and any sub-environment (within
                a vectorized env) throws any error during env stepping, we will try to
                restart the faulty sub-environment. This is done
                without disturbing the other (still intact) sub-environments.

        Returns:
            The resulting BaseEnv object.
        """
        from ray.rllib.env.remote_base_env import RemoteBaseEnv

        if remote_envs:
            # Each sub-env lives in its own @ray.remote actor.
            env = RemoteBaseEnv(
                make_env,
                num_envs,
                multiagent=True,
                remote_env_batch_wait_ms=remote_env_batch_wait_ms,
                restart_failed_sub_environments=restart_failed_sub_environments,
            )
        # Sub-environments are not ray.remote actors.
        else:
            env = MultiAgentEnvWrapper(
                make_env=make_env,
                existing_envs=[self],
                num_envs=num_envs,
                restart_failed_sub_environments=restart_failed_sub_environments,
            )

        return env
| 328 |
+
|
| 329 |
+
@DeveloperAPI
|
| 330 |
+
def make_multi_agent(
|
| 331 |
+
env_name_or_creator: Union[str, EnvCreator],
|
| 332 |
+
) -> Type["MultiAgentEnv"]:
|
| 333 |
+
"""Convenience wrapper for any single-agent env to be converted into MA.
|
| 334 |
+
|
| 335 |
+
Allows you to convert a simple (single-agent) `gym.Env` class
|
| 336 |
+
into a `MultiAgentEnv` class. This function simply stacks n instances
|
| 337 |
+
of the given ```gym.Env``` class into one unified ``MultiAgentEnv`` class
|
| 338 |
+
and returns this class, thus pretending the agents act together in the
|
| 339 |
+
same environment, whereas - under the hood - they live separately from
|
| 340 |
+
each other in n parallel single-agent envs.
|
| 341 |
+
|
| 342 |
+
Agent IDs in the resulting and are int numbers starting from 0
|
| 343 |
+
(first agent).
|
| 344 |
+
|
| 345 |
+
Args:
|
| 346 |
+
env_name_or_creator: String specifier or env_maker function taking
|
| 347 |
+
an EnvContext object as only arg and returning a gym.Env.
|
| 348 |
+
|
| 349 |
+
Returns:
|
| 350 |
+
New MultiAgentEnv class to be used as env.
|
| 351 |
+
The constructor takes a config dict with `num_agents` key
|
| 352 |
+
(default=1). The rest of the config dict will be passed on to the
|
| 353 |
+
underlying single-agent env's constructor.
|
| 354 |
+
|
| 355 |
+
.. testcode::
|
| 356 |
+
:skipif: True
|
| 357 |
+
|
| 358 |
+
from ray.rllib.env.multi_agent_env import make_multi_agent
|
| 359 |
+
# By gym string:
|
| 360 |
+
ma_cartpole_cls = make_multi_agent("CartPole-v1")
|
| 361 |
+
# Create a 2 agent multi-agent cartpole.
|
| 362 |
+
ma_cartpole = ma_cartpole_cls({"num_agents": 2})
|
| 363 |
+
obs = ma_cartpole.reset()
|
| 364 |
+
print(obs)
|
| 365 |
+
|
| 366 |
+
# By env-maker callable:
|
| 367 |
+
from ray.rllib.examples.envs.classes.stateless_cartpole import StatelessCartPole
|
| 368 |
+
ma_stateless_cartpole_cls = make_multi_agent(
|
| 369 |
+
lambda config: StatelessCartPole(config))
|
| 370 |
+
# Create a 3 agent multi-agent stateless cartpole.
|
| 371 |
+
ma_stateless_cartpole = ma_stateless_cartpole_cls(
|
| 372 |
+
{"num_agents": 3})
|
| 373 |
+
print(obs)
|
| 374 |
+
|
| 375 |
+
.. testoutput::
|
| 376 |
+
|
| 377 |
+
{0: [...], 1: [...]}
|
| 378 |
+
{0: [...], 1: [...], 2: [...]}
|
| 379 |
+
"""
|
| 380 |
+
|
| 381 |
+
class MultiEnv(MultiAgentEnv):
|
| 382 |
+
def __init__(self, config: EnvContext = None):
|
| 383 |
+
super().__init__()
|
| 384 |
+
|
| 385 |
+
# Note: Explicitly check for None here, because config
|
| 386 |
+
# can have an empty dict but meaningful data fields (worker_index,
|
| 387 |
+
# vector_index) etc.
|
| 388 |
+
# TODO (sven): Clean this up, so we are not mixing up dict fields
|
| 389 |
+
# with data fields.
|
| 390 |
+
if config is None:
|
| 391 |
+
config = {}
|
| 392 |
+
num = config.pop("num_agents", 1)
|
| 393 |
+
if isinstance(env_name_or_creator, str):
|
| 394 |
+
self.envs = [gym.make(env_name_or_creator) for _ in range(num)]
|
| 395 |
+
else:
|
| 396 |
+
self.envs = [env_name_or_creator(config) for _ in range(num)]
|
| 397 |
+
self.terminateds = set()
|
| 398 |
+
self.truncateds = set()
|
| 399 |
+
self.observation_spaces = {
|
| 400 |
+
i: self.envs[i].observation_space for i in range(num)
|
| 401 |
+
}
|
| 402 |
+
self.action_spaces = {i: self.envs[i].action_space for i in range(num)}
|
| 403 |
+
self.agents = list(range(num))
|
| 404 |
+
self.possible_agents = self.agents.copy()
|
| 405 |
+
|
| 406 |
+
@override(MultiAgentEnv)
|
| 407 |
+
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
| 408 |
+
self.terminateds = set()
|
| 409 |
+
self.truncateds = set()
|
| 410 |
+
obs, infos = {}, {}
|
| 411 |
+
for i, env in enumerate(self.envs):
|
| 412 |
+
obs[i], infos[i] = env.reset(seed=seed, options=options)
|
| 413 |
+
|
| 414 |
+
return obs, infos
|
| 415 |
+
|
| 416 |
+
@override(MultiAgentEnv)
|
| 417 |
+
def step(self, action_dict):
|
| 418 |
+
obs, rew, terminated, truncated, info = {}, {}, {}, {}, {}
|
| 419 |
+
|
| 420 |
+
# The environment is expecting an action for at least one agent.
|
| 421 |
+
if len(action_dict) == 0:
|
| 422 |
+
raise ValueError(
|
| 423 |
+
"The environment is expecting an action for at least one agent."
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
for i, action in action_dict.items():
|
| 427 |
+
obs[i], rew[i], terminated[i], truncated[i], info[i] = self.envs[
|
| 428 |
+
i
|
| 429 |
+
].step(action)
|
| 430 |
+
if terminated[i]:
|
| 431 |
+
self.terminateds.add(i)
|
| 432 |
+
if truncated[i]:
|
| 433 |
+
self.truncateds.add(i)
|
| 434 |
+
# TODO: Flaw in our MultiAgentEnv API wrt. new gymnasium: Need to return
|
| 435 |
+
# an additional episode_done bool that covers cases where all agents are
|
| 436 |
+
# either terminated or truncated, but not all are truncated and not all are
|
| 437 |
+
# terminated. We can then get rid of the aweful `__all__` special keys!
|
| 438 |
+
terminated["__all__"] = len(self.terminateds) + len(self.truncateds) == len(
|
| 439 |
+
self.envs
|
| 440 |
+
)
|
| 441 |
+
truncated["__all__"] = len(self.truncateds) == len(self.envs)
|
| 442 |
+
return obs, rew, terminated, truncated, info
|
| 443 |
+
|
| 444 |
+
@override(MultiAgentEnv)
|
| 445 |
+
def render(self):
|
| 446 |
+
# This render method simply renders all n underlying individual single-agent
|
| 447 |
+
# envs and concatenates their images (on top of each other if the returned
|
| 448 |
+
# images have dims where [width] > [height], otherwise next to each other).
|
| 449 |
+
render_images = [e.render() for e in self.envs]
|
| 450 |
+
if render_images[0].shape[1] > render_images[0].shape[0]:
|
| 451 |
+
concat_dim = 0
|
| 452 |
+
else:
|
| 453 |
+
concat_dim = 1
|
| 454 |
+
return np.concatenate(render_images, axis=concat_dim)
|
| 455 |
+
|
| 456 |
+
return MultiEnv
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
@OldAPIStack
|
| 460 |
+
class MultiAgentEnvWrapper(BaseEnv):
|
| 461 |
+
"""Internal adapter of MultiAgentEnv to BaseEnv.
|
| 462 |
+
|
| 463 |
+
This also supports vectorization if num_envs > 1.
|
| 464 |
+
"""
|
| 465 |
+
|
| 466 |
+
def __init__(
|
| 467 |
+
self,
|
| 468 |
+
make_env: Callable[[int], EnvType],
|
| 469 |
+
existing_envs: List["MultiAgentEnv"],
|
| 470 |
+
num_envs: int,
|
| 471 |
+
restart_failed_sub_environments: bool = False,
|
| 472 |
+
):
|
| 473 |
+
"""Wraps MultiAgentEnv(s) into the BaseEnv API.
|
| 474 |
+
|
| 475 |
+
Args:
|
| 476 |
+
make_env: Factory that produces a new MultiAgentEnv instance taking the
|
| 477 |
+
vector index as only call argument.
|
| 478 |
+
Must be defined, if the number of existing envs is less than num_envs.
|
| 479 |
+
existing_envs: List of already existing multi-agent envs.
|
| 480 |
+
num_envs: Desired num multiagent envs to have at the end in
|
| 481 |
+
total. This will include the given (already created)
|
| 482 |
+
`existing_envs`.
|
| 483 |
+
restart_failed_sub_environments: If True and any sub-environment (within
|
| 484 |
+
this vectorized env) throws any error during env stepping, we will try
|
| 485 |
+
to restart the faulty sub-environment. This is done
|
| 486 |
+
without disturbing the other (still intact) sub-environments.
|
| 487 |
+
"""
|
| 488 |
+
self.make_env = make_env
|
| 489 |
+
self.envs = existing_envs
|
| 490 |
+
self.num_envs = num_envs
|
| 491 |
+
self.restart_failed_sub_environments = restart_failed_sub_environments
|
| 492 |
+
|
| 493 |
+
self.terminateds = set()
|
| 494 |
+
self.truncateds = set()
|
| 495 |
+
while len(self.envs) < self.num_envs:
|
| 496 |
+
self.envs.append(self.make_env(len(self.envs)))
|
| 497 |
+
for env in self.envs:
|
| 498 |
+
assert isinstance(env, MultiAgentEnv)
|
| 499 |
+
self._init_env_state(idx=None)
|
| 500 |
+
self._unwrapped_env = self.envs[0].unwrapped
|
| 501 |
+
|
| 502 |
+
@override(BaseEnv)
|
| 503 |
+
def poll(
|
| 504 |
+
self,
|
| 505 |
+
) -> Tuple[
|
| 506 |
+
MultiEnvDict,
|
| 507 |
+
MultiEnvDict,
|
| 508 |
+
MultiEnvDict,
|
| 509 |
+
MultiEnvDict,
|
| 510 |
+
MultiEnvDict,
|
| 511 |
+
MultiEnvDict,
|
| 512 |
+
]:
|
| 513 |
+
obs, rewards, terminateds, truncateds, infos = {}, {}, {}, {}, {}
|
| 514 |
+
for i, env_state in enumerate(self.env_states):
|
| 515 |
+
(
|
| 516 |
+
obs[i],
|
| 517 |
+
rewards[i],
|
| 518 |
+
terminateds[i],
|
| 519 |
+
truncateds[i],
|
| 520 |
+
infos[i],
|
| 521 |
+
) = env_state.poll()
|
| 522 |
+
return obs, rewards, terminateds, truncateds, infos, {}
|
| 523 |
+
|
| 524 |
+
@override(BaseEnv)
|
| 525 |
+
def send_actions(self, action_dict: MultiEnvDict) -> None:
|
| 526 |
+
for env_id, agent_dict in action_dict.items():
|
| 527 |
+
if env_id in self.terminateds or env_id in self.truncateds:
|
| 528 |
+
raise ValueError(
|
| 529 |
+
f"Env {env_id} is already done and cannot accept new actions"
|
| 530 |
+
)
|
| 531 |
+
env = self.envs[env_id]
|
| 532 |
+
try:
|
| 533 |
+
obs, rewards, terminateds, truncateds, infos = env.step(agent_dict)
|
| 534 |
+
except Exception as e:
|
| 535 |
+
if self.restart_failed_sub_environments:
|
| 536 |
+
logger.exception(e.args[0])
|
| 537 |
+
self.try_restart(env_id=env_id)
|
| 538 |
+
obs = e
|
| 539 |
+
rewards = {}
|
| 540 |
+
terminateds = {"__all__": True}
|
| 541 |
+
truncateds = {"__all__": False}
|
| 542 |
+
infos = {}
|
| 543 |
+
else:
|
| 544 |
+
raise e
|
| 545 |
+
|
| 546 |
+
assert isinstance(
|
| 547 |
+
obs, (dict, Exception)
|
| 548 |
+
), "Not a multi-agent obs dict or an Exception!"
|
| 549 |
+
assert isinstance(rewards, dict), "Not a multi-agent reward dict!"
|
| 550 |
+
assert isinstance(terminateds, dict), "Not a multi-agent terminateds dict!"
|
| 551 |
+
assert isinstance(truncateds, dict), "Not a multi-agent truncateds dict!"
|
| 552 |
+
assert isinstance(infos, dict), "Not a multi-agent info dict!"
|
| 553 |
+
if isinstance(obs, dict):
|
| 554 |
+
info_diff = set(infos).difference(set(obs))
|
| 555 |
+
if info_diff and info_diff != {"__common__"}:
|
| 556 |
+
raise ValueError(
|
| 557 |
+
"Key set for infos must be a subset of obs (plus optionally "
|
| 558 |
+
"the '__common__' key for infos concerning all/no agents): "
|
| 559 |
+
"{} vs {}".format(infos.keys(), obs.keys())
|
| 560 |
+
)
|
| 561 |
+
if "__all__" not in terminateds:
|
| 562 |
+
raise ValueError(
|
| 563 |
+
"In multi-agent environments, '__all__': True|False must "
|
| 564 |
+
"be included in the 'terminateds' dict: got {}.".format(terminateds)
|
| 565 |
+
)
|
| 566 |
+
elif "__all__" not in truncateds:
|
| 567 |
+
raise ValueError(
|
| 568 |
+
"In multi-agent environments, '__all__': True|False must "
|
| 569 |
+
"be included in the 'truncateds' dict: got {}.".format(truncateds)
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
if terminateds["__all__"]:
|
| 573 |
+
self.terminateds.add(env_id)
|
| 574 |
+
if truncateds["__all__"]:
|
| 575 |
+
self.truncateds.add(env_id)
|
| 576 |
+
self.env_states[env_id].observe(
|
| 577 |
+
obs, rewards, terminateds, truncateds, infos
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
@override(BaseEnv)
|
| 581 |
+
def try_reset(
|
| 582 |
+
self,
|
| 583 |
+
env_id: Optional[EnvID] = None,
|
| 584 |
+
*,
|
| 585 |
+
seed: Optional[int] = None,
|
| 586 |
+
options: Optional[dict] = None,
|
| 587 |
+
) -> Optional[Tuple[MultiEnvDict, MultiEnvDict]]:
|
| 588 |
+
ret_obs = {}
|
| 589 |
+
ret_infos = {}
|
| 590 |
+
if isinstance(env_id, int):
|
| 591 |
+
env_id = [env_id]
|
| 592 |
+
if env_id is None:
|
| 593 |
+
env_id = list(range(len(self.envs)))
|
| 594 |
+
for idx in env_id:
|
| 595 |
+
obs, infos = self.env_states[idx].reset(seed=seed, options=options)
|
| 596 |
+
|
| 597 |
+
if isinstance(obs, Exception):
|
| 598 |
+
if self.restart_failed_sub_environments:
|
| 599 |
+
self.env_states[idx].env = self.envs[idx] = self.make_env(idx)
|
| 600 |
+
else:
|
| 601 |
+
raise obs
|
| 602 |
+
else:
|
| 603 |
+
assert isinstance(obs, dict), "Not a multi-agent obs dict!"
|
| 604 |
+
if obs is not None:
|
| 605 |
+
if idx in self.terminateds:
|
| 606 |
+
self.terminateds.remove(idx)
|
| 607 |
+
if idx in self.truncateds:
|
| 608 |
+
self.truncateds.remove(idx)
|
| 609 |
+
ret_obs[idx] = obs
|
| 610 |
+
ret_infos[idx] = infos
|
| 611 |
+
return ret_obs, ret_infos
|
| 612 |
+
|
| 613 |
+
@override(BaseEnv)
|
| 614 |
+
def try_restart(self, env_id: Optional[EnvID] = None) -> None:
|
| 615 |
+
if isinstance(env_id, int):
|
| 616 |
+
env_id = [env_id]
|
| 617 |
+
if env_id is None:
|
| 618 |
+
env_id = list(range(len(self.envs)))
|
| 619 |
+
for idx in env_id:
|
| 620 |
+
# Try closing down the old (possibly faulty) sub-env, but ignore errors.
|
| 621 |
+
try:
|
| 622 |
+
self.envs[idx].close()
|
| 623 |
+
except Exception as e:
|
| 624 |
+
if log_once("close_sub_env"):
|
| 625 |
+
logger.warning(
|
| 626 |
+
"Trying to close old and replaced sub-environment (at vector "
|
| 627 |
+
f"index={idx}), but closing resulted in error:\n{e}"
|
| 628 |
+
)
|
| 629 |
+
# Try recreating the sub-env.
|
| 630 |
+
logger.warning(f"Trying to restart sub-environment at index {idx}.")
|
| 631 |
+
self.env_states[idx].env = self.envs[idx] = self.make_env(idx)
|
| 632 |
+
logger.warning(f"Sub-environment at index {idx} restarted successfully.")
|
| 633 |
+
|
| 634 |
+
@override(BaseEnv)
|
| 635 |
+
def get_sub_environments(
|
| 636 |
+
self, as_dict: bool = False
|
| 637 |
+
) -> Union[Dict[str, EnvType], List[EnvType]]:
|
| 638 |
+
if as_dict:
|
| 639 |
+
return {_id: env_state.env for _id, env_state in enumerate(self.env_states)}
|
| 640 |
+
return [state.env for state in self.env_states]
|
| 641 |
+
|
| 642 |
+
@override(BaseEnv)
|
| 643 |
+
def try_render(self, env_id: Optional[EnvID] = None) -> None:
|
| 644 |
+
if env_id is None:
|
| 645 |
+
env_id = 0
|
| 646 |
+
assert isinstance(env_id, int)
|
| 647 |
+
return self.envs[env_id].render()
|
| 648 |
+
|
| 649 |
+
@property
|
| 650 |
+
@override(BaseEnv)
|
| 651 |
+
def observation_space(self) -> gym.spaces.Dict:
|
| 652 |
+
return self.envs[0].observation_space
|
| 653 |
+
|
| 654 |
+
@property
|
| 655 |
+
@override(BaseEnv)
|
| 656 |
+
def action_space(self) -> gym.Space:
|
| 657 |
+
return self.envs[0].action_space
|
| 658 |
+
|
| 659 |
+
@override(BaseEnv)
|
| 660 |
+
def get_agent_ids(self) -> Set[AgentID]:
|
| 661 |
+
return self.envs[0].get_agent_ids()
|
| 662 |
+
|
| 663 |
+
def _init_env_state(self, idx: Optional[int] = None) -> None:
|
| 664 |
+
"""Resets all or one particular sub-environment's state (by index).
|
| 665 |
+
|
| 666 |
+
Args:
|
| 667 |
+
idx: The index to reset at. If None, reset all the sub-environments' states.
|
| 668 |
+
"""
|
| 669 |
+
# If index is None, reset all sub-envs' states:
|
| 670 |
+
if idx is None:
|
| 671 |
+
self.env_states = [
|
| 672 |
+
_MultiAgentEnvState(env, self.restart_failed_sub_environments)
|
| 673 |
+
for env in self.envs
|
| 674 |
+
]
|
| 675 |
+
# Index provided, reset only the sub-env's state at the given index.
|
| 676 |
+
else:
|
| 677 |
+
assert isinstance(idx, int)
|
| 678 |
+
self.env_states[idx] = _MultiAgentEnvState(
|
| 679 |
+
self.envs[idx], self.restart_failed_sub_environments
|
| 680 |
+
)
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
@OldAPIStack
|
| 684 |
+
class _MultiAgentEnvState:
|
| 685 |
+
def __init__(self, env: MultiAgentEnv, return_error_as_obs: bool = False):
|
| 686 |
+
assert isinstance(env, MultiAgentEnv)
|
| 687 |
+
self.env = env
|
| 688 |
+
self.return_error_as_obs = return_error_as_obs
|
| 689 |
+
|
| 690 |
+
self.initialized = False
|
| 691 |
+
self.last_obs = {}
|
| 692 |
+
self.last_rewards = {}
|
| 693 |
+
self.last_terminateds = {"__all__": False}
|
| 694 |
+
self.last_truncateds = {"__all__": False}
|
| 695 |
+
self.last_infos = {}
|
| 696 |
+
|
| 697 |
+
def poll(
|
| 698 |
+
self,
|
| 699 |
+
) -> Tuple[
|
| 700 |
+
MultiAgentDict,
|
| 701 |
+
MultiAgentDict,
|
| 702 |
+
MultiAgentDict,
|
| 703 |
+
MultiAgentDict,
|
| 704 |
+
MultiAgentDict,
|
| 705 |
+
]:
|
| 706 |
+
if not self.initialized:
|
| 707 |
+
# TODO(sven): Should we make it possible to pass in a seed here?
|
| 708 |
+
self.reset()
|
| 709 |
+
self.initialized = True
|
| 710 |
+
|
| 711 |
+
observations = self.last_obs
|
| 712 |
+
rewards = {}
|
| 713 |
+
terminateds = {"__all__": self.last_terminateds["__all__"]}
|
| 714 |
+
truncateds = {"__all__": self.last_truncateds["__all__"]}
|
| 715 |
+
infos = self.last_infos
|
| 716 |
+
|
| 717 |
+
# If episode is done or we have an error, release everything we have.
|
| 718 |
+
if (
|
| 719 |
+
terminateds["__all__"]
|
| 720 |
+
or truncateds["__all__"]
|
| 721 |
+
or isinstance(observations, Exception)
|
| 722 |
+
):
|
| 723 |
+
rewards = self.last_rewards
|
| 724 |
+
self.last_rewards = {}
|
| 725 |
+
terminateds = self.last_terminateds
|
| 726 |
+
if isinstance(observations, Exception):
|
| 727 |
+
terminateds["__all__"] = True
|
| 728 |
+
truncateds["__all__"] = False
|
| 729 |
+
self.last_terminateds = {}
|
| 730 |
+
truncateds = self.last_truncateds
|
| 731 |
+
self.last_truncateds = {}
|
| 732 |
+
self.last_obs = {}
|
| 733 |
+
infos = self.last_infos
|
| 734 |
+
self.last_infos = {}
|
| 735 |
+
# Only release those agents' rewards/terminateds/truncateds/infos, whose
|
| 736 |
+
# observations we have.
|
| 737 |
+
else:
|
| 738 |
+
for ag in observations.keys():
|
| 739 |
+
if ag in self.last_rewards:
|
| 740 |
+
rewards[ag] = self.last_rewards[ag]
|
| 741 |
+
del self.last_rewards[ag]
|
| 742 |
+
if ag in self.last_terminateds:
|
| 743 |
+
terminateds[ag] = self.last_terminateds[ag]
|
| 744 |
+
del self.last_terminateds[ag]
|
| 745 |
+
if ag in self.last_truncateds:
|
| 746 |
+
truncateds[ag] = self.last_truncateds[ag]
|
| 747 |
+
del self.last_truncateds[ag]
|
| 748 |
+
|
| 749 |
+
self.last_terminateds["__all__"] = False
|
| 750 |
+
self.last_truncateds["__all__"] = False
|
| 751 |
+
return observations, rewards, terminateds, truncateds, infos
|
| 752 |
+
|
| 753 |
+
def observe(
|
| 754 |
+
self,
|
| 755 |
+
obs: MultiAgentDict,
|
| 756 |
+
rewards: MultiAgentDict,
|
| 757 |
+
terminateds: MultiAgentDict,
|
| 758 |
+
truncateds: MultiAgentDict,
|
| 759 |
+
infos: MultiAgentDict,
|
| 760 |
+
):
|
| 761 |
+
self.last_obs = obs
|
| 762 |
+
for ag, r in rewards.items():
|
| 763 |
+
if ag in self.last_rewards:
|
| 764 |
+
self.last_rewards[ag] += r
|
| 765 |
+
else:
|
| 766 |
+
self.last_rewards[ag] = r
|
| 767 |
+
for ag, d in terminateds.items():
|
| 768 |
+
if ag in self.last_terminateds:
|
| 769 |
+
self.last_terminateds[ag] = self.last_terminateds[ag] or d
|
| 770 |
+
else:
|
| 771 |
+
self.last_terminateds[ag] = d
|
| 772 |
+
for ag, t in truncateds.items():
|
| 773 |
+
if ag in self.last_truncateds:
|
| 774 |
+
self.last_truncateds[ag] = self.last_truncateds[ag] or t
|
| 775 |
+
else:
|
| 776 |
+
self.last_truncateds[ag] = t
|
| 777 |
+
self.last_infos = infos
|
| 778 |
+
|
| 779 |
+
def reset(
|
| 780 |
+
self,
|
| 781 |
+
*,
|
| 782 |
+
seed: Optional[int] = None,
|
| 783 |
+
options: Optional[dict] = None,
|
| 784 |
+
) -> Tuple[MultiAgentDict, MultiAgentDict]:
|
| 785 |
+
try:
|
| 786 |
+
obs_and_infos = self.env.reset(seed=seed, options=options)
|
| 787 |
+
except Exception as e:
|
| 788 |
+
if self.return_error_as_obs:
|
| 789 |
+
logger.exception(e.args[0])
|
| 790 |
+
obs_and_infos = e, e
|
| 791 |
+
else:
|
| 792 |
+
raise e
|
| 793 |
+
|
| 794 |
+
self.last_obs, self.last_infos = obs_and_infos
|
| 795 |
+
self.last_rewards = {}
|
| 796 |
+
self.last_terminateds = {"__all__": False}
|
| 797 |
+
self.last_truncateds = {"__all__": False}
|
| 798 |
+
|
| 799 |
+
return self.last_obs, self.last_infos
|
.venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_env_runner.py
ADDED
|
@@ -0,0 +1,1107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
from functools import partial
|
| 3 |
+
import logging
|
| 4 |
+
import time
|
| 5 |
+
from typing import Collection, DefaultDict, Dict, List, Optional, Union
|
| 6 |
+
|
| 7 |
+
import gymnasium as gym
|
| 8 |
+
|
| 9 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 10 |
+
from ray.rllib.callbacks.utils import make_callback
|
| 11 |
+
from ray.rllib.core import (
|
| 12 |
+
COMPONENT_ENV_TO_MODULE_CONNECTOR,
|
| 13 |
+
COMPONENT_MODULE_TO_ENV_CONNECTOR,
|
| 14 |
+
COMPONENT_RL_MODULE,
|
| 15 |
+
)
|
| 16 |
+
from ray.rllib.core.columns import Columns
|
| 17 |
+
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule, MultiRLModuleSpec
|
| 18 |
+
from ray.rllib.env.env_context import EnvContext
|
| 19 |
+
from ray.rllib.env.env_runner import EnvRunner, ENV_STEP_FAILURE
|
| 20 |
+
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
| 21 |
+
from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
|
| 22 |
+
from ray.rllib.env.utils import _gym_env_creator
|
| 23 |
+
from ray.rllib.utils import force_list
|
| 24 |
+
from ray.rllib.utils.annotations import override
|
| 25 |
+
from ray.rllib.utils.checkpoints import Checkpointable
|
| 26 |
+
from ray.rllib.utils.deprecation import Deprecated
|
| 27 |
+
from ray.rllib.utils.framework import get_device, try_import_torch
|
| 28 |
+
from ray.rllib.utils.metrics import (
|
| 29 |
+
EPISODE_DURATION_SEC_MEAN,
|
| 30 |
+
EPISODE_LEN_MAX,
|
| 31 |
+
EPISODE_LEN_MEAN,
|
| 32 |
+
EPISODE_LEN_MIN,
|
| 33 |
+
EPISODE_RETURN_MAX,
|
| 34 |
+
EPISODE_RETURN_MEAN,
|
| 35 |
+
EPISODE_RETURN_MIN,
|
| 36 |
+
NUM_AGENT_STEPS_SAMPLED,
|
| 37 |
+
NUM_AGENT_STEPS_SAMPLED_LIFETIME,
|
| 38 |
+
NUM_ENV_STEPS_SAMPLED,
|
| 39 |
+
NUM_ENV_STEPS_SAMPLED_LIFETIME,
|
| 40 |
+
NUM_EPISODES,
|
| 41 |
+
NUM_EPISODES_LIFETIME,
|
| 42 |
+
NUM_MODULE_STEPS_SAMPLED,
|
| 43 |
+
NUM_MODULE_STEPS_SAMPLED_LIFETIME,
|
| 44 |
+
TIME_BETWEEN_SAMPLING,
|
| 45 |
+
WEIGHTS_SEQ_NO,
|
| 46 |
+
)
|
| 47 |
+
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
|
| 48 |
+
from ray.rllib.utils.pre_checks.env import check_multiagent_environments
|
| 49 |
+
from ray.rllib.utils.typing import EpisodeID, ModelWeights, ResultDict, StateDict
|
| 50 |
+
from ray.tune.registry import ENV_CREATOR, _global_registry
|
| 51 |
+
from ray.util.annotations import PublicAPI
|
| 52 |
+
|
| 53 |
+
torch, _ = try_import_torch()
|
| 54 |
+
logger = logging.getLogger("ray.rllib")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# TODO (sven): As soon as RolloutWorker is no longer supported, make `EnvRunner` itself
|
| 58 |
+
# a Checkpointable. Currently, only some of its subclasses are Checkpointables.
|
| 59 |
+
@PublicAPI(stability="alpha")
|
| 60 |
+
class MultiAgentEnvRunner(EnvRunner, Checkpointable):
|
| 61 |
+
"""The genetic environment runner for the multi-agent case."""
|
| 62 |
+
|
| 63 |
+
@override(EnvRunner)
|
| 64 |
+
def __init__(self, config: AlgorithmConfig, **kwargs):
|
| 65 |
+
"""Initializes a MultiAgentEnvRunner instance.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
config: An `AlgorithmConfig` object containing all settings needed to
|
| 69 |
+
build this `EnvRunner` class.
|
| 70 |
+
"""
|
| 71 |
+
super().__init__(config=config)
|
| 72 |
+
|
| 73 |
+
# Raise an Error, if the provided config is not a multi-agent one.
|
| 74 |
+
if not self.config.is_multi_agent:
|
| 75 |
+
raise ValueError(
|
| 76 |
+
f"Cannot use this EnvRunner class ({type(self).__name__}), if your "
|
| 77 |
+
"setup is not multi-agent! Try adding multi-agent information to your "
|
| 78 |
+
"AlgorithmConfig via calling the `config.multi_agent(policies=..., "
|
| 79 |
+
"policy_mapping_fn=...)`."
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Get the worker index on which this instance is running.
|
| 83 |
+
self.worker_index: int = kwargs.get("worker_index")
|
| 84 |
+
self.tune_trial_id: str = kwargs.get("tune_trial_id")
|
| 85 |
+
|
| 86 |
+
# Set up all metrics-related structures and counters.
|
| 87 |
+
self.metrics: Optional[MetricsLogger] = None
|
| 88 |
+
self._setup_metrics()
|
| 89 |
+
|
| 90 |
+
# Create our callbacks object.
|
| 91 |
+
self._callbacks = [cls() for cls in force_list(self.config.callbacks_class)]
|
| 92 |
+
|
| 93 |
+
# Set device.
|
| 94 |
+
self._device = get_device(
|
| 95 |
+
self.config,
|
| 96 |
+
0 if not self.worker_index else self.config.num_gpus_per_env_runner,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# Create the vectorized gymnasium env.
|
| 100 |
+
self.env: Optional[gym.Wrapper] = None
|
| 101 |
+
self.num_envs: int = 0
|
| 102 |
+
self.make_env()
|
| 103 |
+
|
| 104 |
+
# Create the env-to-module connector pipeline.
|
| 105 |
+
self._env_to_module = self.config.build_env_to_module_connector(
|
| 106 |
+
self.env.unwrapped, device=self._device
|
| 107 |
+
)
|
| 108 |
+
# Cached env-to-module results taken at the end of a `_sample_timesteps()`
|
| 109 |
+
# call to make sure the final observation (before an episode cut) gets properly
|
| 110 |
+
# processed (and maybe postprocessed and re-stored into the episode).
|
| 111 |
+
# For example, if we had a connector that normalizes observations and directly
|
| 112 |
+
# re-inserts these new obs back into the episode, the last observation in each
|
| 113 |
+
# sample call would NOT be processed, which could be very harmful in cases,
|
| 114 |
+
# in which value function bootstrapping of those (truncation) observations is
|
| 115 |
+
# required in the learning step.
|
| 116 |
+
self._cached_to_module = None
|
| 117 |
+
|
| 118 |
+
# Construct the MultiRLModule.
|
| 119 |
+
self.module: Optional[MultiRLModule] = None
|
| 120 |
+
self.make_module()
|
| 121 |
+
|
| 122 |
+
# Create the module-to-env connector pipeline.
|
| 123 |
+
self._module_to_env = self.config.build_module_to_env_connector(
|
| 124 |
+
self.env.unwrapped
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
self._needs_initial_reset: bool = True
|
| 128 |
+
self._episode: Optional[MultiAgentEpisode] = None
|
| 129 |
+
self._shared_data = None
|
| 130 |
+
|
| 131 |
+
self._weights_seq_no: int = 0
|
| 132 |
+
|
| 133 |
+
# Measures the time passed between returning from `sample()`
|
| 134 |
+
# and receiving the next `sample()` request from the user.
|
| 135 |
+
self._time_after_sampling = None
|
| 136 |
+
|
| 137 |
+
@override(EnvRunner)
def sample(
    self,
    *,
    num_timesteps: int = None,
    num_episodes: int = None,
    explore: bool = None,
    random_actions: bool = False,
    force_reset: bool = False,
) -> List[MultiAgentEpisode]:
    """Runs and returns a sample (n timesteps or m episodes) on the env(s).

    Args:
        num_timesteps: The number of timesteps to sample during this call.
            Note that only one of `num_timesteps` or `num_episodes` may be provided.
        num_episodes: The number of episodes to sample during this call.
            Note that only one of `num_timesteps` or `num_episodes` may be provided.
        explore: If True, will use the RLModule's `forward_exploration()`
            method to compute actions. If False, will use the RLModule's
            `forward_inference()` method. If None (default), will use the `explore`
            boolean setting from `self.config` passed into this EnvRunner's
            constructor. You can change this setting in your config via
            `config.env_runners(explore=True|False)`.
        random_actions: If True, actions will be sampled randomly (from the action
            space of the environment). If False (default), actions or action
            distribution parameters are computed by the RLModule.
        force_reset: Whether to force-reset all (vector) environments before
            sampling. Useful if you would like to collect a clean slate of new
            episodes via this call. Note that when sampling n episodes
            (`num_episodes != None`), this is fixed to True.

    Returns:
        A list of `MultiAgentEpisode` instances, carrying the sampled data.
    """
    assert not (num_timesteps is not None and num_episodes is not None)

    # Log time between `sample()` requests.
    if self._time_after_sampling is not None:
        self.metrics.log_value(
            key=TIME_BETWEEN_SAMPLING,
            value=time.perf_counter() - self._time_after_sampling,
        )

    # If no execution details are provided, use the config to try to infer the
    # desired timesteps/episodes to sample and the exploration behavior.
    if explore is None:
        explore = self.config.explore
    if num_timesteps is None and num_episodes is None:
        if self.config.batch_mode == "truncate_episodes":
            num_timesteps = self.config.get_rollout_fragment_length(
                worker_index=self.worker_index,
            )
        else:
            num_episodes = 1

    # Sample n timesteps.
    if num_timesteps is not None:
        samples = self._sample_timesteps(
            num_timesteps=num_timesteps,
            explore=explore,
            random_actions=random_actions,
            force_reset=force_reset,
        )
    # Sample m episodes.
    else:
        samples = self._sample_episodes(
            num_episodes=num_episodes,
            explore=explore,
            random_actions=random_actions,
        )

    # Make the `on_sample_end` callback.
    make_callback(
        "on_sample_end",
        callbacks_objects=self._callbacks,
        callbacks_functions=self.config.callbacks_on_sample_end,
        kwargs=dict(
            env_runner=self,
            metrics_logger=self.metrics,
            samples=samples,
        ),
    )

    self._time_after_sampling = time.perf_counter()

    return samples
def _sample_timesteps(
    self,
    num_timesteps: int,
    explore: bool,
    random_actions: bool = False,
    force_reset: bool = False,
) -> List[MultiAgentEpisode]:
    """Helper method to sample n timesteps.

    Args:
        num_timesteps: Number of timesteps to sample during rollout.
        explore: Whether to act in exploration or inference mode. Exploration
            mode might for some algorithms provide extra model outputs that
            are redundant in inference mode.
        random_actions: If actions should be sampled from the action
            space. In default mode (i.e. `False`) we sample actions from the
            policy.
        force_reset: Whether to force-reset the env before sampling.

    Returns:
        Lists of `MultiAgentEpisode` instances, carrying the collected sample data.
    """
    done_episodes_to_return: List[MultiAgentEpisode] = []

    # Have to reset the env.
    if force_reset or self._needs_initial_reset:
        # Create a new episode and make the `on_episode_created` callback.
        self._episode = self._new_episode()
        self._make_on_episode_callback("on_episode_created")

        # Erase all cached ongoing episodes (these will never be completed and
        # would thus never be returned/cleaned by `get_metrics` and cause a memory
        # leak).
        self._ongoing_episodes_for_metrics.clear()

        # Try resetting the environment.
        # TODO (simon): Check, if we need here the seed from the config.
        obs, infos = self._try_env_reset()

        self._cached_to_module = None

        # Call `on_episode_start()` callbacks.
        self._make_on_episode_callback("on_episode_start")

        # We just reset the env. Don't have to force this again in the next
        # call to `self._sample_timesteps()`.
        self._needs_initial_reset = False

        # Set the initial observations in the episode.
        self._episode.add_env_reset(observations=obs, infos=infos)

        self._shared_data = {
            "agent_to_module_mapping_fn": self.config.policy_mapping_fn,
        }

    # Loop through timesteps.
    ts = 0

    while ts < num_timesteps:
        # Act randomly.
        if random_actions:
            # Only act (randomly) for those agents that had an observation.
            to_env = {
                Columns.ACTIONS: [
                    {
                        aid: self.env.unwrapped.get_action_space(aid).sample()
                        for aid in self._episode.get_agents_to_act()
                    }
                ]
            }
        # Compute an action using the RLModule.
        else:
            # Env-to-module connector (use cached results from the end of the
            # previous `_sample_timesteps()` call, if available).
            to_module = self._cached_to_module or self._env_to_module(
                rl_module=self.module,
                episodes=[self._episode],
                explore=explore,
                shared_data=self._shared_data,
                metrics=self.metrics,
            )
            self._cached_to_module = None

            # MultiRLModule forward pass: Explore or not.
            if explore:
                env_steps_lifetime = (
                    self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0)
                    + self.metrics.peek(NUM_ENV_STEPS_SAMPLED, default=0)
                ) * (self.config.num_env_runners or 1)
                to_env = self.module.forward_exploration(
                    to_module, t=env_steps_lifetime
                )
            else:
                to_env = self.module.forward_inference(to_module)

            # Module-to-env connector.
            to_env = self._module_to_env(
                rl_module=self.module,
                batch=to_env,
                episodes=[self._episode],
                explore=explore,
                shared_data=self._shared_data,
                metrics=self.metrics,
            )

        # Extract the (vectorized) actions (to be sent to the env) from the
        # module/connector output. Note that these actions are fully ready (e.g.
        # already unsquashed/clipped) to be sent to the environment) and might not
        # be identical to the actions produced by the RLModule/distribution, which
        # are the ones stored permanently in the episode objects.
        actions = to_env.pop(Columns.ACTIONS)
        actions_for_env = to_env.pop(Columns.ACTIONS_FOR_ENV, actions)

        # Try stepping the environment.
        # TODO (sven): [0] = actions is vectorized, but env is NOT a vector Env.
        #  Support vectorized multi-agent envs.
        results = self._try_env_step(actions_for_env[0])
        # If any failure occurs during stepping -> Throw away all data collected
        # thus far and restart sampling procedure.
        if results == ENV_STEP_FAILURE:
            return self._sample_timesteps(
                num_timesteps=num_timesteps,
                explore=explore,
                random_actions=random_actions,
                force_reset=True,
            )
        obs, rewards, terminateds, truncateds, infos = results

        # TODO (sven): This simple approach to re-map `to_env` from a
        #  dict[col, List[MADict]] to a dict[agentID, MADict] would not work for
        #  a vectorized env.
        extra_model_outputs = defaultdict(dict)
        for col, ma_dict_list in to_env.items():
            # TODO (sven): Support vectorized MA env.
            ma_dict = ma_dict_list[0]
            for agent_id, val in ma_dict.items():
                extra_model_outputs[agent_id][col] = val
                extra_model_outputs[agent_id][WEIGHTS_SEQ_NO] = self._weights_seq_no
        extra_model_outputs = dict(extra_model_outputs)

        # Record the timestep in the episode instance.
        self._episode.add_env_step(
            obs,
            actions[0],
            rewards,
            infos=infos,
            terminateds=terminateds,
            truncateds=truncateds,
            extra_model_outputs=extra_model_outputs,
        )

        ts += self._increase_sampled_metrics(self.num_envs, obs, self._episode)

        # Make the `on_episode_step` callback (before finalizing the episode
        # object).
        self._make_on_episode_callback("on_episode_step")

        # Episode is done for all agents. Wrap up the old one and create a new
        # one (and reset it) to continue.
        if self._episode.is_done:
            # We have to perform an extra env-to-module pass here, just in case
            # the user's connector pipeline performs (permanent) transforms
            # on each observation (including this final one here). Without such
            # a call and in case the structure of the observations change
            # sufficiently, the following `to_numpy()` call on the episode will
            # fail.
            if self.module is not None:
                self._env_to_module(
                    episodes=[self._episode],
                    explore=explore,
                    rl_module=self.module,
                    shared_data=self._shared_data,
                    metrics=self.metrics,
                )

            # Make the `on_episode_end` callback (before finalizing the episode,
            # but after(!) the last env-to-module connector call has been made.
            # -> All obs (even the terminal one) should have been processed now (by
            # the connector, if applicable).
            self._make_on_episode_callback("on_episode_end")

            self._prune_zero_len_sa_episodes(self._episode)

            # Numpy'ize the episode.
            if self.config.episodes_to_numpy:
                done_episodes_to_return.append(self._episode.to_numpy())
            # Leave episode as lists of individual (obs, action, etc..) items.
            else:
                done_episodes_to_return.append(self._episode)

            # Create a new episode instance.
            self._episode = self._new_episode()
            self._make_on_episode_callback("on_episode_created")

            # Reset the environment.
            obs, infos = self._try_env_reset()
            # Add initial observations and infos.
            self._episode.add_env_reset(observations=obs, infos=infos)

            # Make the `on_episode_start` callback.
            self._make_on_episode_callback("on_episode_start")

    # Already perform env-to-module connector call for next call to
    # `_sample_timesteps()`. See comment in c'tor for `self._cached_to_module`.
    if self.module is not None:
        self._cached_to_module = self._env_to_module(
            rl_module=self.module,
            episodes=[self._episode],
            explore=explore,
            shared_data=self._shared_data,
            metrics=self.metrics,
        )

    # Store done episodes for metrics.
    self._done_episodes_for_metrics.extend(done_episodes_to_return)

    # Also, make sure we start new episode chunks (continuing the ongoing episodes
    # from the to-be-returned chunks).
    ongoing_episode_continuation = self._episode.cut(
        len_lookback_buffer=self.config.episode_lookback_horizon
    )

    ongoing_episodes_to_return = []
    # Just started episodes do not have to be returned. There is no data
    # in them anyway.
    if self._episode.env_t > 0:
        self._episode.validate()
        self._ongoing_episodes_for_metrics[self._episode.id_].append(self._episode)

        self._prune_zero_len_sa_episodes(self._episode)

        # Numpy'ize the episode.
        if self.config.episodes_to_numpy:
            ongoing_episodes_to_return.append(self._episode.to_numpy())
        # Leave episode as lists of individual (obs, action, etc..) items.
        else:
            ongoing_episodes_to_return.append(self._episode)

    # Continue collecting into the cut episode chunk.
    self._episode = ongoing_episode_continuation

    # Return collected episode data.
    return done_episodes_to_return + ongoing_episodes_to_return
def _sample_episodes(
    self,
    num_episodes: int,
    explore: bool,
    random_actions: bool = False,
) -> List[MultiAgentEpisode]:
    """Helper method to run n episodes.

    See docstring of `self.sample()` for more details.
    """
    # If user calls sample(num_timesteps=..) after this, we must reset again
    # at the beginning.
    self._needs_initial_reset = True

    done_episodes_to_return: List[MultiAgentEpisode] = []

    # Create a new multi-agent episode.
    _episode = self._new_episode()
    self._make_on_episode_callback("on_episode_created", _episode)
    _shared_data = {
        "agent_to_module_mapping_fn": self.config.policy_mapping_fn,
    }

    # Try resetting the environment.
    # TODO (simon): Check, if we need here the seed from the config.
    obs, infos = self._try_env_reset()
    # Set initial obs and infos in the episode.
    _episode.add_env_reset(observations=obs, infos=infos)
    self._make_on_episode_callback("on_episode_start", _episode)

    # Loop over episodes.
    eps = 0
    ts = 0
    while eps < num_episodes:
        # Act randomly.
        if random_actions:
            # Only act (randomly) for those agents that had an observation.
            # NOTE: Fixed a bug here: this method tracks its episode in the
            # local `_episode` variable; `self._episode` belongs to
            # `_sample_timesteps()` and may be stale or None at this point.
            to_env = {
                Columns.ACTIONS: [
                    {
                        aid: self.env.unwrapped.get_action_space(aid).sample()
                        for aid in _episode.get_agents_to_act()
                    }
                ]
            }
        # Compute an action using the RLModule.
        else:
            # Env-to-module connector.
            to_module = self._env_to_module(
                rl_module=self.module,
                episodes=[_episode],
                explore=explore,
                shared_data=_shared_data,
                metrics=self.metrics,
            )

            # MultiRLModule forward pass: Explore or not.
            if explore:
                env_steps_lifetime = (
                    self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0)
                    + self.metrics.peek(NUM_ENV_STEPS_SAMPLED, default=0)
                ) * (self.config.num_env_runners or 1)
                to_env = self.module.forward_exploration(
                    to_module, t=env_steps_lifetime
                )
            else:
                to_env = self.module.forward_inference(to_module)

            # Module-to-env connector.
            to_env = self._module_to_env(
                rl_module=self.module,
                batch=to_env,
                episodes=[_episode],
                explore=explore,
                shared_data=_shared_data,
                metrics=self.metrics,
            )

        # Extract the (vectorized) actions (to be sent to the env) from the
        # module/connector output. Note that these actions are fully ready (e.g.
        # already unsquashed/clipped) to be sent to the environment) and might not
        # be identical to the actions produced by the RLModule/distribution, which
        # are the ones stored permanently in the episode objects.
        actions = to_env.pop(Columns.ACTIONS)
        actions_for_env = to_env.pop(Columns.ACTIONS_FOR_ENV, actions)

        # Try stepping the environment.
        # TODO (sven): [0] = actions is vectorized, but env is NOT a vector Env.
        #  Support vectorized multi-agent envs.
        results = self._try_env_step(actions_for_env[0])
        # If any failure occurs during stepping -> Throw away all data collected
        # thus far and restart sampling procedure.
        if results == ENV_STEP_FAILURE:
            return self._sample_episodes(
                num_episodes=num_episodes,
                explore=explore,
                random_actions=random_actions,
            )
        obs, rewards, terminateds, truncateds, infos = results

        # TODO (sven): This simple approach to re-map `to_env` from a
        #  dict[col, List[MADict]] to a dict[agentID, MADict] would not work for
        #  a vectorized env.
        extra_model_outputs = defaultdict(dict)
        for col, ma_dict_list in to_env.items():
            # TODO (sven): Support vectorized MA env.
            ma_dict = ma_dict_list[0]
            for agent_id, val in ma_dict.items():
                extra_model_outputs[agent_id][col] = val
                extra_model_outputs[agent_id][WEIGHTS_SEQ_NO] = self._weights_seq_no
        extra_model_outputs = dict(extra_model_outputs)

        # Record the timestep in the episode instance.
        _episode.add_env_step(
            obs,
            actions[0],
            rewards,
            infos=infos,
            terminateds=terminateds,
            truncateds=truncateds,
            extra_model_outputs=extra_model_outputs,
        )

        ts += self._increase_sampled_metrics(self.num_envs, obs, _episode)

        # Make `on_episode_step` callback before finalizing the episode.
        self._make_on_episode_callback("on_episode_step", _episode)

        # TODO (sven, simon): We have to check, if we need this elaborate
        #  function here or if the `MultiAgentEnv` defines the cases that
        #  can happen.
        # Right now we have:
        #   1. Most times only agents that step get `terminated`, `truncated`
        #      i.e. the rest we have to check in the episode.
        #   2. There are edge cases like, some agents terminated, all others
        #      truncated and vice versa.
        # See also `MultiAgentEpisode` for handling the `__all__`.
        if _episode.is_done:
            # Increase episode count.
            eps += 1

            # We have to perform an extra env-to-module pass here, just in case
            # the user's connector pipeline performs (permanent) transforms
            # on each observation (including this final one here). Without such
            # a call and in case the structure of the observations change
            # sufficiently, the following `to_numpy()` call on the episode will
            # fail.
            if self.module is not None:
                self._env_to_module(
                    episodes=[_episode],
                    explore=explore,
                    rl_module=self.module,
                    shared_data=_shared_data,
                    metrics=self.metrics,
                )

            # Make the `on_episode_end` callback (before finalizing the episode,
            # but after(!) the last env-to-module connector call has been made.
            # -> All obs (even the terminal one) should have been processed now (by
            # the connector, if applicable).
            self._make_on_episode_callback("on_episode_end", _episode)

            self._prune_zero_len_sa_episodes(_episode)

            # Numpy'ize the episode.
            if self.config.episodes_to_numpy:
                done_episodes_to_return.append(_episode.to_numpy())
            # Leave episode as lists of individual (obs, action, etc..) items.
            else:
                done_episodes_to_return.append(_episode)

            # Also early-out if we reach the number of episodes within this
            # for-loop.
            if eps == num_episodes:
                break

            # Create a new episode instance.
            _episode = self._new_episode()
            self._make_on_episode_callback("on_episode_created", _episode)

            # Try resetting the environment.
            obs, infos = self._try_env_reset()
            # Add initial observations and infos.
            _episode.add_env_reset(observations=obs, infos=infos)

            # Make `on_episode_start` callback.
            self._make_on_episode_callback("on_episode_start", _episode)

    self._done_episodes_for_metrics.extend(done_episodes_to_return)

    return done_episodes_to_return
@override(EnvRunner)
def get_spaces(self):
    """Returns the module-ID-mapped (observation, action) space pairs.

    The spaces returned are the already agent-to-module translated spaces
    taken from our env-to-module connector pipeline.
    """
    # Simplified from the redundant `{**{...}}` double-dict construction.
    return {
        mid: (o, self._env_to_module.action_space[mid])
        for mid, o in self._env_to_module.observation_space.spaces.items()
    }
@override(EnvRunner)
def get_metrics(self) -> ResultDict:
    """Computes and returns (reduced) metrics over all completed episodes.

    Aggregates per-episode length/return/duration (including previously
    returned chunks of the same episodes), logs them, then clears the
    done-episodes cache and returns the reduced metrics dict.
    """
    # Compute per-episode metrics (only on already completed episodes).
    for eps in self._done_episodes_for_metrics:
        assert eps.is_done
        episode_length = len(eps)
        agent_steps = defaultdict(
            int,
            {str(aid): len(sa_eps) for aid, sa_eps in eps.agent_episodes.items()},
        )
        episode_return = eps.get_return()
        episode_duration_s = eps.get_duration_s()

        agent_episode_returns = defaultdict(
            float,
            {
                str(sa_eps.agent_id): sa_eps.get_return()
                for sa_eps in eps.agent_episodes.values()
            },
        )
        module_episode_returns = defaultdict(
            float,
            {
                sa_eps.module_id: sa_eps.get_return()
                for sa_eps in eps.agent_episodes.values()
            },
        )

        # Don't forget about the already returned chunks of this episode.
        if eps.id_ in self._ongoing_episodes_for_metrics:
            for eps2 in self._ongoing_episodes_for_metrics[eps.id_]:
                return_eps2 = eps2.get_return()
                episode_length += len(eps2)
                episode_return += return_eps2
                episode_duration_s += eps2.get_duration_s()

                for sa_eps in eps2.agent_episodes.values():
                    return_sa = sa_eps.get_return()
                    agent_steps[str(sa_eps.agent_id)] += len(sa_eps)
                    agent_episode_returns[str(sa_eps.agent_id)] += return_sa
                    module_episode_returns[sa_eps.module_id] += return_sa

            del self._ongoing_episodes_for_metrics[eps.id_]

        self._log_episode_metrics(
            episode_length,
            episode_return,
            episode_duration_s,
            agent_episode_returns,
            module_episode_returns,
            dict(agent_steps),
        )

    # Now that we have logged everything, clear cache of done episodes.
    self._done_episodes_for_metrics.clear()

    # Return reduced metrics.
    return self.metrics.reduce()
@override(Checkpointable)
def get_state(
    self,
    components: Optional[Union[str, Collection[str]]] = None,
    *,
    not_components: Optional[Union[str, Collection[str]]] = None,
    **kwargs,
) -> StateDict:
    """Returns this EnvRunner's state (module weights, connectors, counters).

    Args:
        components: Which components to include (None for all).
        not_components: Which components to exclude.
        **kwargs: Forwarded to `self.module.get_state()`.

    Returns:
        A state dict with the requested components.
    """
    # Basic state dict.
    state = {
        NUM_ENV_STEPS_SAMPLED_LIFETIME: (
            self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0)
        ),
    }

    # RLModule (MultiRLModule) component.
    if self._check_component(COMPONENT_RL_MODULE, components, not_components):
        state[COMPONENT_RL_MODULE] = self.module.get_state(
            components=self._get_subcomponents(COMPONENT_RL_MODULE, components),
            not_components=self._get_subcomponents(
                COMPONENT_RL_MODULE, not_components
            ),
            **kwargs,
        )
        state[WEIGHTS_SEQ_NO] = self._weights_seq_no

    # Env-to-module connector.
    if self._check_component(
        COMPONENT_ENV_TO_MODULE_CONNECTOR, components, not_components
    ):
        state[COMPONENT_ENV_TO_MODULE_CONNECTOR] = self._env_to_module.get_state()
    # Module-to-env connector.
    if self._check_component(
        COMPONENT_MODULE_TO_ENV_CONNECTOR, components, not_components
    ):
        state[COMPONENT_MODULE_TO_ENV_CONNECTOR] = self._module_to_env.get_state()

    return state
@override(Checkpointable)
def set_state(self, state: StateDict) -> None:
    """Restores this EnvRunner's state from a state dict.

    Updates connector states, RLModule weights (respecting the weights
    sequence number), and the lifetime env-step counter.
    """
    if COMPONENT_ENV_TO_MODULE_CONNECTOR in state:
        self._env_to_module.set_state(state[COMPONENT_ENV_TO_MODULE_CONNECTOR])
    if COMPONENT_MODULE_TO_ENV_CONNECTOR in state:
        self._module_to_env.set_state(state[COMPONENT_MODULE_TO_ENV_CONNECTOR])

    # Update RLModule state.
    if COMPONENT_RL_MODULE in state:
        # A missing value for WEIGHTS_SEQ_NO or a value of 0 means: Force the
        # update.
        weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0)

        # Only update the weights, if this is the first synchronization or
        # if the weights of this `EnvRunner` lag behind the actual ones.
        if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no:
            self.module.set_state(state[COMPONENT_RL_MODULE])

        # Update weights_seq_no, if the new one is > 0.
        if weights_seq_no > 0:
            self._weights_seq_no = weights_seq_no

    # Update lifetime counters.
    if NUM_ENV_STEPS_SAMPLED_LIFETIME in state:
        self.metrics.set_value(
            key=NUM_ENV_STEPS_SAMPLED_LIFETIME,
            value=state[NUM_ENV_STEPS_SAMPLED_LIFETIME],
            reduce="sum",
            with_throughput=True,
        )
@override(Checkpointable)
def get_ctor_args_and_kwargs(self):
    """Returns the (args, kwargs) needed to re-construct this EnvRunner."""
    return (
        (),  # *args
        {"config": self.config},  # **kwargs
    )
@override(Checkpointable)
def get_metadata(self):
    """Returns checkpoint metadata (base metadata, possibly extended later)."""
    metadata = Checkpointable.get_metadata(self)
    metadata.update(
        {
            # TODO (sven): Maybe add serialized (JSON-writable) config here?
        }
    )
    return metadata
@override(Checkpointable)
def get_checkpointable_components(self):
    """Returns the (name, component) pairs checkpointed with this EnvRunner."""
    return [
        (COMPONENT_RL_MODULE, self.module),
        (COMPONENT_ENV_TO_MODULE_CONNECTOR, self._env_to_module),
        (COMPONENT_MODULE_TO_ENV_CONNECTOR, self._module_to_env),
    ]
@override(EnvRunner)
def assert_healthy(self):
    """Checks that self.__init__() has been completed properly.

    Ensures that the instance has a `MultiRLModule` and an
    environment defined.

    Raises:
        AssertionError: If the EnvRunner Actor has NOT been properly initialized.
    """
    # Make sure, we have built our gym.vector.Env and RLModule properly.
    assert self.env and self.module
@override(EnvRunner)
def make_env(self):
    """(Re-)creates the multi-agent environment for this EnvRunner.

    Closes any pre-existing env, registers the env creator with gymnasium,
    performs the `gym.make()` call, optionally checks the resulting env, and
    fires the `on_environment_created` callback.
    """
    # If an env already exists, try closing it first (to allow it to properly
    # cleanup).
    if self.env is not None:
        try:
            self.env.close()
        except Exception as e:
            logger.warning(
                "Tried closing the existing env (multi-agent), but failed with "
                f"error: {e.args[0]}"
            )
        del self.env

    env_ctx = self.config.env_config
    if not isinstance(env_ctx, EnvContext):
        env_ctx = EnvContext(
            env_ctx,
            worker_index=self.worker_index,
            num_workers=self.config.num_env_runners,
            remote=self.config.remote_worker_envs,
        )

    # No env provided -> Error.
    if not self.config.env:
        raise ValueError(
            "`config.env` is not provided! You should provide a valid environment "
            "to your config through `config.environment([env descriptor e.g. "
            "'CartPole-v1'])`."
        )
    # Register env for the local context.
    # Note, `gym.register` has to be called on each worker.
    elif isinstance(self.config.env, str) and _global_registry.contains(
        ENV_CREATOR, self.config.env
    ):
        entry_point = partial(
            _global_registry.get(ENV_CREATOR, self.config.env),
            env_ctx,
        )
    else:
        entry_point = partial(
            _gym_env_creator,
            env_descriptor=self.config.env,
            env_context=env_ctx,
        )
    gym.register(
        "rllib-multi-agent-env-v0",
        entry_point=entry_point,
        disable_env_checker=True,
    )

    # Perform actual gym.make call.
    self.env: MultiAgentEnv = gym.make("rllib-multi-agent-env-v0")
    self.num_envs = 1
    # If required, check the created MultiAgentEnv.
    if not self.config.disable_env_checking:
        try:
            check_multiagent_environments(self.env.unwrapped)
        except Exception as e:
            logger.exception(e.args[0])
    # If not required, still check the type (must be MultiAgentEnv).
    else:
        assert isinstance(self.env.unwrapped, MultiAgentEnv), (
            "ERROR: When using the `MultiAgentEnvRunner` the environment needs "
            "to inherit from `ray.rllib.env.multi_agent_env.MultiAgentEnv`."
        )

    # Set the flag to reset all envs upon the next `sample()` call.
    self._needs_initial_reset = True

    # Call the `on_environment_created` callback.
    make_callback(
        "on_environment_created",
        callbacks_objects=self._callbacks,
        callbacks_functions=self.config.callbacks_on_environment_created,
        kwargs=dict(
            env_runner=self,
            metrics_logger=self.metrics,
            env=self.env.unwrapped,
            env_context=env_ctx,
        ),
    )
@override(EnvRunner)
def make_module(self):
    """Builds this EnvRunner's (inference-only) MultiRLModule.

    If `AlgorithmConfig.get_rl_module_spec()` is not implemented,
    `self.module` is set to None; the runner may then still be usable
    with random actions.
    """
    try:
        spec: MultiRLModuleSpec = self.config.get_multi_rl_module_spec(
            env=self.env.unwrapped,
            spaces=self.get_spaces(),
            inference_only=True,
        )
        # Build the module from its spec.
        self.module = spec.build()

        # Move all torch submodules to this EnvRunner's device.
        # TODO (sven): In order to make this framework-agnostic, we should maybe
        #  make the MultiRLModule.build() method accept a device OR create an
        #  additional `(Multi)RLModule.to()` override.
        if torch:
            device = self._device

            def _maybe_to_device(mid, mod):
                # Only torch modules can be moved; leave others untouched.
                return mod.to(device) if isinstance(mod, torch.nn.Module) else mod

            self.module.foreach_module(_maybe_to_device)

    # No RLModule spec available -> run without a module (e.g. random-action
    # sampling).
    except NotImplementedError:
        self.module = None
|
| 944 |
+
|
| 945 |
+
@override(EnvRunner)
def stop(self):
    """Releases all resources held by the wrapped environment."""
    # `MultiAgentEnv` inherits the `close()` method from `gym.Env`.
    self.env.close()
|
| 949 |
+
|
| 950 |
+
def _setup_metrics(self):
    """Initializes the MetricsLogger and episode bookkeeping containers."""
    self.metrics = MetricsLogger()

    # Episodes that finished since the last metrics reduction.
    self._done_episodes_for_metrics: List[MultiAgentEpisode] = []
    # Maps an ongoing episode's ID to the list of its collected chunks.
    self._ongoing_episodes_for_metrics: DefaultDict[
        EpisodeID, List[MultiAgentEpisode]
    ] = defaultdict(list)
|
| 957 |
+
|
| 958 |
+
def _new_episode(self):
    """Returns a fresh `MultiAgentEpisode` for this runner's environment.

    The episode is constructed with the per-agent observation/action
    spaces of all possible agents and the config's policy-mapping fn.
    """
    env = self.env.unwrapped
    agent_ids = env.possible_agents
    return MultiAgentEpisode(
        observation_space={
            agent_id: env.get_observation_space(agent_id) for agent_id in agent_ids
        },
        action_space={
            agent_id: env.get_action_space(agent_id) for agent_id in agent_ids
        },
        agent_to_module_mapping_fn=self.config.policy_mapping_fn,
    )
|
| 970 |
+
|
| 971 |
+
def _make_on_episode_callback(self, which: str, episode=None):
    """Fires the given episode callback hook (e.g. "on_episode_start").

    Args:
        which: Name of the callback hook to trigger. The matching
            user-provided functions are looked up as
            `self.config.callbacks_{which}`.
        episode: Episode to pass to the callback; defaults to
            `self._episode` if None.
    """
    if episode is None:
        episode = self._episode
    make_callback(
        which,
        callbacks_objects=self._callbacks,
        callbacks_functions=getattr(self.config, f"callbacks_{which}"),
        kwargs=dict(
            episode=episode,
            env_runner=self,
            metrics_logger=self.metrics,
            env=self.env.unwrapped,
            rl_module=self.module,
            env_index=0,
        ),
    )
|
| 986 |
+
|
| 987 |
+
def _increase_sampled_metrics(self, num_steps, next_obs, episode):
    """Logs env-, agent-, and module-step counters after a sampling step.

    Args:
        num_steps: Number of env steps just taken.
        next_obs: Per-agent observation dict; its keys determine which
            agents stepped.
        episode: The currently running `MultiAgentEpisode`.

    Returns:
        The (unchanged) `num_steps`.
    """
    # Env steps: windowed counter + lifetime counter (w/ throughput).
    self.metrics.log_value(
        NUM_ENV_STEPS_SAMPLED, num_steps, reduce="sum", clear_on_reduce=True
    )
    self.metrics.log_value(
        NUM_ENV_STEPS_SAMPLED_LIFETIME,
        num_steps,
        reduce="sum",
        with_throughput=True,
    )
    # Completed episodes.
    if episode.is_done:
        self.metrics.log_value(NUM_EPISODES, 1, reduce="sum", clear_on_reduce=True)
        self.metrics.log_value(NUM_EPISODES_LIFETIME, 1, reduce="sum")

    # TODO (sven): obs is not-vectorized. Support vectorized MA envs.
    for agent_id in next_obs:
        module_id = episode.module_for(agent_id)
        self.metrics.log_value(
            (NUM_AGENT_STEPS_SAMPLED, str(agent_id)),
            1,
            reduce="sum",
            clear_on_reduce=True,
        )
        self.metrics.log_value(
            (NUM_AGENT_STEPS_SAMPLED_LIFETIME, str(agent_id)),
            1,
            reduce="sum",
        )
        self.metrics.log_value(
            (NUM_MODULE_STEPS_SAMPLED, module_id),
            1,
            reduce="sum",
            clear_on_reduce=True,
        )
        self.metrics.log_value(
            (NUM_MODULE_STEPS_SAMPLED_LIFETIME, module_id),
            1,
            reduce="sum",
        )
    return num_steps
|
| 1028 |
+
|
| 1029 |
+
def _log_episode_metrics(
    self,
    length,
    ret,
    sec,
    agents=None,
    modules=None,
    agent_steps=None,
):
    """Logs episode length/return/duration stats to `self.metrics`.

    Means are logged with a fixed window; length and return additionally
    get windowed min/max entries.

    Args:
        length: Episode length (in env steps).
        ret: Total episode return.
        sec: Episode wall-clock duration in seconds.
        agents: Optional per-agent returns dict.
        modules: Optional per-RLModule returns dict.
        agent_steps: Optional per-agent step counts.
    """
    window = self.config.metrics_num_episodes_for_smoothing

    per_agent_stats = {}
    if agents is not None:
        per_agent_stats = {
            # Per-agent returns.
            "agent_episode_returns_mean": agents,
            # Per-RLModule returns.
            "module_episode_returns_mean": modules,
            "agent_steps": agent_steps,
        }

    # To mimic the old API stack behavior, we'll use `window` here for
    # these particular stats (instead of the default EMA).
    self.metrics.log_dict(
        {
            EPISODE_LEN_MEAN: length,
            EPISODE_RETURN_MEAN: ret,
            EPISODE_DURATION_SEC_MEAN: sec,
            **per_agent_stats,
        },
        window=window,
    )
    # For some metrics, log min/max as well.
    self.metrics.log_dict(
        {
            EPISODE_LEN_MIN: length,
            EPISODE_RETURN_MIN: ret,
        },
        reduce="min",
        window=window,
    )
    self.metrics.log_dict(
        {
            EPISODE_LEN_MAX: length,
            EPISODE_RETURN_MAX: ret,
        },
        reduce="max",
        window=window,
    )
|
| 1077 |
+
|
| 1078 |
+
@staticmethod
def _prune_zero_len_sa_episodes(episode: MultiAgentEpisode):
    """Removes all zero-length single-agent episodes from `episode`.

    Args:
        episode: The `MultiAgentEpisode` whose `agent_episodes` dict is
            pruned in place.
    """
    # Collect IDs first, then delete (can't mutate while iterating).
    empty_agent_ids = [
        agent_id
        for agent_id, sa_episode in episode.agent_episodes.items()
        if len(sa_episode) == 0
    ]
    for agent_id in empty_agent_ids:
        del episode.agent_episodes[agent_id]
|
| 1083 |
+
|
| 1084 |
+
@Deprecated(
    new="MultiAgentEnvRunner.get_state(components='rl_module')",
    error=False,
)
def get_weights(self, modules=None):
    """Deprecated: Returns the RLModule's state via `get_state()`."""
    state = self.get_state(components=COMPONENT_RL_MODULE)
    return state[COMPONENT_RL_MODULE]
|
| 1093 |
+
|
| 1094 |
+
@Deprecated(new="MultiAgentEnvRunner.set_state()", error=False)
def set_weights(
    self,
    weights: ModelWeights,
    global_vars: Optional[Dict] = None,
    weights_seq_no: int = 0,
) -> None:
    """Deprecated: Sets the RLModule's state via `set_state()`."""
    assert global_vars is None
    state = {
        COMPONENT_RL_MODULE: weights,
        WEIGHTS_SEQ_NO: weights_seq_no,
    }
    return self.set_state(state)
|
.venv/lib/python3.11/site-packages/ray/rllib/env/multi_agent_episode.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/policy_client.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""REST client to interact with a policy server.
|
| 2 |
+
|
| 3 |
+
This client supports both local and remote policy inference modes. Local
|
| 4 |
+
inference is faster but causes more compute to be done on the client.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import threading
|
| 9 |
+
import time
|
| 10 |
+
from typing import Union, Optional
|
| 11 |
+
|
| 12 |
+
import ray.cloudpickle as pickle
|
| 13 |
+
from ray.rllib.env.external_env import ExternalEnv
|
| 14 |
+
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
|
| 15 |
+
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
| 16 |
+
from ray.rllib.policy.sample_batch import MultiAgentBatch
|
| 17 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 18 |
+
from ray.rllib.utils.typing import (
|
| 19 |
+
MultiAgentDict,
|
| 20 |
+
EnvInfoDict,
|
| 21 |
+
EnvObsType,
|
| 22 |
+
EnvActionType,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
# Backward compatibility.
|
| 26 |
+
from ray.rllib.env.utils.external_env_protocol import RLlink as Commands
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
import requests # `requests` is not part of stdlib.
|
| 32 |
+
except ImportError:
|
| 33 |
+
requests = None
|
| 34 |
+
logger.warning(
|
| 35 |
+
"Couldn't import `requests` library. Be sure to install it on"
|
| 36 |
+
" the client side."
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@OldAPIStack
class PolicyClient:
    """REST client to interact with an RLlib policy server.

    Supports 'local' inference (policy weights are periodically synced into
    an embedded rollout worker on the client) and 'remote' inference (every
    action is computed server-side).
    """

    def __init__(
        self,
        address: str,
        inference_mode: str = "local",
        update_interval: float = 10.0,
        # NOTE: The annotation is a string on purpose. `requests` may be
        # None here (see the warn-only import guard at module top); with an
        # eager `Optional[requests.Session]` annotation, evaluating
        # `requests.Session` at class-definition time would raise an
        # AttributeError and crash the import the guard tried to survive.
        session: Optional["requests.Session"] = None,
    ):
        """Create a PolicyClient instance.

        Args:
            address: Server to connect to (e.g., "localhost:9090").
            inference_mode: Whether to use 'local' or 'remote' policy
                inference for computing actions.
            update_interval (float or None): If using 'local' inference mode,
                the policy is refreshed after this many seconds have passed,
                or None for manual control via client.
            session (requests.Session or None): If available the session object
                is used to communicate with the policy server. Using a session
                can lead to speedups as connections are reused. It is the
                responsibility of the creator of the session to close it.

        Raises:
            ValueError: If `inference_mode` is neither 'local' nor 'remote'.
        """
        self.address = address
        self.session = session
        self.env: ExternalEnv = None
        if inference_mode == "local":
            self.local = True
            self._setup_local_rollout_worker(update_interval)
        elif inference_mode == "remote":
            self.local = False
        else:
            raise ValueError("inference_mode must be either 'local' or 'remote'")

    def start_episode(
        self, episode_id: Optional[str] = None, training_enabled: bool = True
    ) -> str:
        """Record the start of one or more episode(s).

        Args:
            episode_id (Optional[str]): Unique string id for the episode or
                None for it to be auto-assigned.
            training_enabled: Whether to use experiences for this
                episode to improve the policy.

        Returns:
            episode_id: Unique string id for the episode.
        """
        if self.local:
            self._update_local_policy()
            return self.env.start_episode(episode_id, training_enabled)

        return self._send(
            {
                "episode_id": episode_id,
                "command": Commands.START_EPISODE,
                "training_enabled": training_enabled,
            }
        )["episode_id"]

    def get_action(
        self, episode_id: str, observation: Union[EnvObsType, MultiAgentDict]
    ) -> Union[EnvActionType, MultiAgentDict]:
        """Record an observation and get the on-policy action.

        Args:
            episode_id: Episode id returned from start_episode(). In local
                mode, a list/tuple of episode ids is also accepted (with
                `observation` then being a dict keyed by episode id).
            observation: Current environment observation.

        Returns:
            action: Action from the env action space.
        """
        if self.local:
            self._update_local_policy()
            if isinstance(episode_id, (list, tuple)):
                # Batched querying: one action per given episode id.
                actions = {
                    eid: self.env.get_action(eid, observation[eid])
                    for eid in episode_id
                }
                return actions
            else:
                return self.env.get_action(episode_id, observation)
        else:
            return self._send(
                {
                    "command": Commands.GET_ACTION,
                    "observation": observation,
                    "episode_id": episode_id,
                }
            )["action"]

    def log_action(
        self,
        episode_id: str,
        observation: Union[EnvObsType, MultiAgentDict],
        action: Union[EnvActionType, MultiAgentDict],
    ) -> None:
        """Record an observation and (off-policy) action taken.

        Args:
            episode_id: Episode id returned from start_episode().
            observation: Current environment observation.
            action: Action for the observation.
        """
        if self.local:
            self._update_local_policy()
            return self.env.log_action(episode_id, observation, action)

        self._send(
            {
                "command": Commands.LOG_ACTION,
                "observation": observation,
                "action": action,
                "episode_id": episode_id,
            }
        )

    def log_returns(
        self,
        episode_id: str,
        reward: float,
        info: Union[EnvInfoDict, MultiAgentDict] = None,
        multiagent_done_dict: Optional[MultiAgentDict] = None,
    ) -> None:
        """Record returns from the environment.

        The reward will be attributed to the previous action taken by the
        episode. Rewards accumulate until the next action. If no reward is
        logged before the next action, a reward of 0.0 is assumed.

        Args:
            episode_id: Episode id returned from start_episode().
            reward: Reward from the environment (a per-agent dict in the
                multi-agent case).
            info: Extra info dict.
            multiagent_done_dict: Multi-agent done information.
        """
        if self.local:
            self._update_local_policy()
            if multiagent_done_dict is not None:
                # Multi-agent case: reward must be a per-agent dict.
                assert isinstance(reward, dict)
                return self.env.log_returns(
                    episode_id, reward, info, multiagent_done_dict
                )
            return self.env.log_returns(episode_id, reward, info)

        self._send(
            {
                "command": Commands.LOG_RETURNS,
                "reward": reward,
                "info": info,
                "episode_id": episode_id,
                "done": multiagent_done_dict,
            }
        )

    def end_episode(
        self, episode_id: str, observation: Union[EnvObsType, MultiAgentDict]
    ) -> None:
        """Record the end of an episode.

        Args:
            episode_id: Episode id returned from start_episode().
            observation: Current environment observation.
        """
        if self.local:
            self._update_local_policy()
            return self.env.end_episode(episode_id, observation)

        self._send(
            {
                "command": Commands.END_EPISODE,
                "observation": observation,
                "episode_id": episode_id,
            }
        )

    def update_policy_weights(self) -> None:
        """Query the server for new policy weights, if local inference is enabled."""
        self._update_local_policy(force=True)

    def _send(self, data):
        """Pickles `data`, POSTs it to the server, and unpickles the reply.

        Raises:
            requests.HTTPError: If the server responds with a non-200 status.
        """
        payload = pickle.dumps(data)

        if self.session is None:
            response = requests.post(self.address, data=payload)
        else:
            # Reuse the caller-provided session (connection pooling).
            response = self.session.post(self.address, data=payload)

        if response.status_code != 200:
            logger.error("Request failed {}: {}".format(response.text, data))
        response.raise_for_status()
        parsed = pickle.loads(response.content)
        return parsed

    def _setup_local_rollout_worker(self, update_interval):
        """Creates the embedded rollout worker for local inference mode.

        Args:
            update_interval: Seconds between automatic weight refreshes, or
                None for manual control.
        """
        self.update_interval = update_interval
        self.last_updated = 0

        logger.info("Querying server for rollout worker settings.")
        kwargs = self._send(
            {
                "command": Commands.GET_WORKER_ARGS,
            }
        )["worker_args"]
        (self.rollout_worker, self.inference_thread) = _create_embedded_rollout_worker(
            kwargs, self._send
        )
        self.env = self.rollout_worker.env

    def _update_local_policy(self, force=False):
        """Fetches fresh weights from the server if due (or `force` is set)."""
        assert self.inference_thread.is_alive()
        if (
            self.update_interval
            and time.time() - self.last_updated > self.update_interval
        ) or force:
            logger.info("Querying server for new policy weights.")
            resp = self._send(
                {
                    "command": Commands.GET_WEIGHTS,
                }
            )
            weights = resp["weights"]
            global_vars = resp["global_vars"]
            logger.info(
                "Updating rollout worker weights and global vars {}.".format(
                    global_vars
                )
            )
            self.rollout_worker.set_weights(weights, global_vars)
            self.last_updated = time.time()
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class _LocalInferenceThread(threading.Thread):
|
| 280 |
+
"""Thread that handles experience generation (worker.sample() loop)."""
|
| 281 |
+
|
| 282 |
+
def __init__(self, rollout_worker, send_fn):
|
| 283 |
+
super().__init__()
|
| 284 |
+
self.daemon = True
|
| 285 |
+
self.rollout_worker = rollout_worker
|
| 286 |
+
self.send_fn = send_fn
|
| 287 |
+
|
| 288 |
+
def run(self):
|
| 289 |
+
try:
|
| 290 |
+
while True:
|
| 291 |
+
logger.info("Generating new batch of experiences.")
|
| 292 |
+
samples = self.rollout_worker.sample()
|
| 293 |
+
metrics = self.rollout_worker.get_metrics()
|
| 294 |
+
if isinstance(samples, MultiAgentBatch):
|
| 295 |
+
logger.info(
|
| 296 |
+
"Sending batch of {} env steps ({} agent steps) to "
|
| 297 |
+
"server.".format(samples.env_steps(), samples.agent_steps())
|
| 298 |
+
)
|
| 299 |
+
else:
|
| 300 |
+
logger.info(
|
| 301 |
+
"Sending batch of {} steps back to server.".format(
|
| 302 |
+
samples.count
|
| 303 |
+
)
|
| 304 |
+
)
|
| 305 |
+
self.send_fn(
|
| 306 |
+
{
|
| 307 |
+
"command": Commands.REPORT_SAMPLES,
|
| 308 |
+
"samples": samples,
|
| 309 |
+
"metrics": metrics,
|
| 310 |
+
}
|
| 311 |
+
)
|
| 312 |
+
except Exception as e:
|
| 313 |
+
logger.error("Error: inference worker thread died!", e)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def _auto_wrap_external(real_env_creator):
|
| 317 |
+
"""Wrap an environment in the ExternalEnv interface if needed.
|
| 318 |
+
|
| 319 |
+
Args:
|
| 320 |
+
real_env_creator: Create an env given the env_config.
|
| 321 |
+
"""
|
| 322 |
+
|
| 323 |
+
def wrapped_creator(env_config):
|
| 324 |
+
real_env = real_env_creator(env_config)
|
| 325 |
+
if not isinstance(real_env, (ExternalEnv, ExternalMultiAgentEnv)):
|
| 326 |
+
logger.info(
|
| 327 |
+
"The env you specified is not a supported (sub-)type of "
|
| 328 |
+
"ExternalEnv. Attempting to convert it automatically to "
|
| 329 |
+
"ExternalEnv."
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
if isinstance(real_env, MultiAgentEnv):
|
| 333 |
+
external_cls = ExternalMultiAgentEnv
|
| 334 |
+
else:
|
| 335 |
+
external_cls = ExternalEnv
|
| 336 |
+
|
| 337 |
+
class _ExternalEnvWrapper(external_cls):
|
| 338 |
+
def __init__(self, real_env):
|
| 339 |
+
super().__init__(
|
| 340 |
+
observation_space=real_env.observation_space,
|
| 341 |
+
action_space=real_env.action_space,
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
def run(self):
|
| 345 |
+
# Since we are calling methods on this class in the
|
| 346 |
+
# client, run doesn't need to do anything.
|
| 347 |
+
time.sleep(999999)
|
| 348 |
+
|
| 349 |
+
return _ExternalEnvWrapper(real_env)
|
| 350 |
+
return real_env
|
| 351 |
+
|
| 352 |
+
return wrapped_creator
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def _create_embedded_rollout_worker(kwargs, send_fn):
    """Creates a local rollout worker plus a sampling thread feeding `send_fn`.

    Args:
        kwargs: Args for the RolloutWorker constructor (as provided by the
            policy server).
        send_fn: Function to send a JSON request to the server.

    Returns:
        Tuple of (RolloutWorker, already-started _LocalInferenceThread).
    """
    # The server acts as an input datasource only, so reset the input
    # config on our local copy back to the default (env rollouts).
    kwargs = kwargs.copy()
    config = kwargs["config"].copy(copy_frozen=False)
    kwargs["config"] = config
    config.output = None
    config.input_ = "sampler"
    config.input_config = {}

    if config.env is None:
        # Expected case: the server has no env. Generate a dummy
        # ExternalEnv from RandomEnv using the given obs/action spaces.
        from ray.rllib.examples.envs.classes.random_env import (
            RandomEnv,
            RandomMultiAgentEnv,
        )

        env_config = {
            "action_space": config.action_space,
            "observation_space": config.observation_space,
        }
        is_ma = config.is_multi_agent
        kwargs["env_creator"] = _auto_wrap_external(
            lambda _: (RandomMultiAgentEnv if is_ma else RandomEnv)(env_config)
        )
        # kwargs["config"].env = True
    else:
        # Otherwise, wrap the env specified by the server args.
        kwargs["env_creator"] = _auto_wrap_external(kwargs["env_creator"])

    logger.info("Creating rollout worker with kwargs={}".format(kwargs))
    from ray.rllib.evaluation.rollout_worker import RolloutWorker

    rollout_worker = RolloutWorker(**kwargs)

    inference_thread = _LocalInferenceThread(rollout_worker, send_fn)
    inference_thread.start()

    return rollout_worker, inference_thread
|
.venv/lib/python3.11/site-packages/ray/rllib/env/policy_server_input.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import deque
|
| 2 |
+
from http.server import HTTPServer, SimpleHTTPRequestHandler
|
| 3 |
+
import logging
|
| 4 |
+
import queue
|
| 5 |
+
from socketserver import ThreadingMixIn
|
| 6 |
+
import threading
|
| 7 |
+
import time
|
| 8 |
+
import traceback
|
| 9 |
+
|
| 10 |
+
from typing import List
|
| 11 |
+
import ray.cloudpickle as pickle
|
| 12 |
+
from ray.rllib.env.policy_client import (
|
| 13 |
+
_create_embedded_rollout_worker,
|
| 14 |
+
Commands,
|
| 15 |
+
)
|
| 16 |
+
from ray.rllib.offline.input_reader import InputReader
|
| 17 |
+
from ray.rllib.offline.io_context import IOContext
|
| 18 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 19 |
+
from ray.rllib.utils.annotations import override, PublicAPI
|
| 20 |
+
from ray.rllib.evaluation.metrics import RolloutMetrics
|
| 21 |
+
from ray.rllib.evaluation.sampler import SamplerInput
|
| 22 |
+
from ray.rllib.utils.typing import SampleBatchType
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@PublicAPI
|
| 28 |
+
class PolicyServerInput(ThreadingMixIn, HTTPServer, InputReader):
|
| 29 |
+
"""REST policy server that acts as an offline data source.
|
| 30 |
+
|
| 31 |
+
This launches a multi-threaded server that listens on the specified host
|
| 32 |
+
and port to serve policy requests and forward experiences to RLlib. For
|
| 33 |
+
high performance experience collection, it implements InputReader.
|
| 34 |
+
|
| 35 |
+
For an example, run `examples/envs/external_envs/cartpole_server.py` along
|
| 36 |
+
with `examples/envs/external_envs/cartpole_client.py --inference-mode=local|remote`.
|
| 37 |
+
|
| 38 |
+
WARNING: This class is not meant to be publicly exposed. Anyone that can
|
| 39 |
+
communicate with this server can execute arbitary code on the machine. Use
|
| 40 |
+
this with caution, in isolated environments, and at your own risk.
|
| 41 |
+
|
| 42 |
+
.. testcode::
|
| 43 |
+
:skipif: True
|
| 44 |
+
|
| 45 |
+
import gymnasium as gym
|
| 46 |
+
from ray.rllib.algorithms.ppo import PPOConfig
|
| 47 |
+
from ray.rllib.env.policy_client import PolicyClient
|
| 48 |
+
from ray.rllib.env.policy_server_input import PolicyServerInput
|
| 49 |
+
addr, port = ...
|
| 50 |
+
config = (
|
| 51 |
+
PPOConfig()
|
| 52 |
+
.api_stack(
|
| 53 |
+
enable_rl_module_and_learner=False,
|
| 54 |
+
enable_env_runner_and_connector_v2=False,
|
| 55 |
+
)
|
| 56 |
+
.environment("CartPole-v1")
|
| 57 |
+
.offline_data(
|
| 58 |
+
input_=lambda ioctx: PolicyServerInput(ioctx, addr, port)
|
| 59 |
+
)
|
| 60 |
+
# Run just 1 server (in the Algorithm's EnvRunnerGroup).
|
| 61 |
+
.env_runners(num_env_runners=0)
|
| 62 |
+
)
|
| 63 |
+
algo = config.build()
|
| 64 |
+
while True:
|
| 65 |
+
algo.train()
|
| 66 |
+
client = PolicyClient(
|
| 67 |
+
"localhost:9900", inference_mode="local")
|
| 68 |
+
eps_id = client.start_episode()
|
| 69 |
+
env = gym.make("CartPole-v1")
|
| 70 |
+
obs, info = env.reset()
|
| 71 |
+
action = client.get_action(eps_id, obs)
|
| 72 |
+
_, reward, _, _, _ = env.step(action)
|
| 73 |
+
client.log_returns(eps_id, reward)
|
| 74 |
+
client.log_returns(eps_id, reward)
|
| 75 |
+
algo.stop()
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
@PublicAPI
|
| 79 |
+
def __init__(
|
| 80 |
+
self,
|
| 81 |
+
ioctx: IOContext,
|
| 82 |
+
address: str,
|
| 83 |
+
port: int,
|
| 84 |
+
idle_timeout: float = 3.0,
|
| 85 |
+
max_sample_queue_size: int = 20,
|
| 86 |
+
):
|
| 87 |
+
"""Create a PolicyServerInput.
|
| 88 |
+
|
| 89 |
+
This class implements rllib.offline.InputReader, and can be used with
|
| 90 |
+
any Algorithm by configuring
|
| 91 |
+
|
| 92 |
+
[AlgorithmConfig object]
|
| 93 |
+
.env_runners(num_env_runners=0)
|
| 94 |
+
.offline_data(input_=lambda ioctx: PolicyServerInput(ioctx, addr, port))
|
| 95 |
+
|
| 96 |
+
Note that by setting num_env_runners: 0, the algorithm will only create one
|
| 97 |
+
rollout worker / PolicyServerInput. Clients can connect to the launched
|
| 98 |
+
server using rllib.env.PolicyClient. You can increase the number of available
|
| 99 |
+
connections (ports) by setting num_env_runners to a larger number. The ports
|
| 100 |
+
used will then be `port` + the worker's index.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
ioctx: IOContext provided by RLlib.
|
| 104 |
+
address: Server addr (e.g., "localhost").
|
| 105 |
+
port: Server port (e.g., 9900).
|
| 106 |
+
max_queue_size: The maximum size for the sample queue. Once full, will
|
| 107 |
+
purge (throw away) 50% of all samples, oldest first, and continue.
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
self.rollout_worker = ioctx.worker
|
| 111 |
+
# Protect ourselves from having a bottleneck on the server (learning) side.
|
| 112 |
+
# Once the queue (deque) is full, we throw away 50% (oldest
|
| 113 |
+
# samples first) of the samples, warn, and continue.
|
| 114 |
+
self.samples_queue = deque(maxlen=max_sample_queue_size)
|
| 115 |
+
self.metrics_queue = queue.Queue()
|
| 116 |
+
self.idle_timeout = idle_timeout
|
| 117 |
+
|
| 118 |
+
# Forwards client-reported metrics directly into the local rollout
|
| 119 |
+
# worker.
|
| 120 |
+
if self.rollout_worker.sampler is not None:
|
| 121 |
+
# This is a bit of a hack since it is patching the get_metrics
|
| 122 |
+
# function of the sampler.
|
| 123 |
+
|
| 124 |
+
def get_metrics():
|
| 125 |
+
completed = []
|
| 126 |
+
while True:
|
| 127 |
+
try:
|
| 128 |
+
completed.append(self.metrics_queue.get_nowait())
|
| 129 |
+
except queue.Empty:
|
| 130 |
+
break
|
| 131 |
+
|
| 132 |
+
return completed
|
| 133 |
+
|
| 134 |
+
self.rollout_worker.sampler.get_metrics = get_metrics
|
| 135 |
+
else:
|
| 136 |
+
# If there is no sampler, act like if there would be one to collect
|
| 137 |
+
# metrics from
|
| 138 |
+
class MetricsDummySampler(SamplerInput):
|
| 139 |
+
"""This sampler only maintains a queue to get metrics from."""
|
| 140 |
+
|
| 141 |
+
def __init__(self, metrics_queue):
|
| 142 |
+
"""Initializes a MetricsDummySampler instance.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
metrics_queue: A queue of metrics
|
| 146 |
+
"""
|
| 147 |
+
self.metrics_queue = metrics_queue
|
| 148 |
+
|
| 149 |
+
def get_data(self) -> SampleBatchType:
|
| 150 |
+
raise NotImplementedError
|
| 151 |
+
|
| 152 |
+
def get_extra_batches(self) -> List[SampleBatchType]:
|
| 153 |
+
raise NotImplementedError
|
| 154 |
+
|
| 155 |
+
def get_metrics(self) -> List[RolloutMetrics]:
|
| 156 |
+
"""Returns metrics computed on a policy client rollout worker."""
|
| 157 |
+
completed = []
|
| 158 |
+
while True:
|
| 159 |
+
try:
|
| 160 |
+
completed.append(self.metrics_queue.get_nowait())
|
| 161 |
+
except queue.Empty:
|
| 162 |
+
break
|
| 163 |
+
return completed
|
| 164 |
+
|
| 165 |
+
self.rollout_worker.sampler = MetricsDummySampler(self.metrics_queue)
|
| 166 |
+
|
| 167 |
+
# Create a request handler that receives commands from the clients
|
| 168 |
+
# and sends data and metrics into the queues.
|
| 169 |
+
handler = _make_handler(
|
| 170 |
+
self.rollout_worker, self.samples_queue, self.metrics_queue
|
| 171 |
+
)
|
| 172 |
+
try:
|
| 173 |
+
import time
|
| 174 |
+
|
| 175 |
+
time.sleep(1)
|
| 176 |
+
HTTPServer.__init__(self, (address, port), handler)
|
| 177 |
+
except OSError:
|
| 178 |
+
print(f"Creating a PolicyServer on {address}:{port} failed!")
|
| 179 |
+
import time
|
| 180 |
+
|
| 181 |
+
time.sleep(1)
|
| 182 |
+
raise
|
| 183 |
+
|
| 184 |
+
logger.info(
|
| 185 |
+
"Starting connector server at " f"{self.server_name}:{self.server_port}"
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# Start the serving thread, listening on socket and handling commands.
|
| 189 |
+
serving_thread = threading.Thread(name="server", target=self.serve_forever)
|
| 190 |
+
serving_thread.daemon = True
|
| 191 |
+
serving_thread.start()
|
| 192 |
+
|
| 193 |
+
# Start a dummy thread that puts empty SampleBatches on the queue, just
|
| 194 |
+
# in case we don't receive anything from clients (or there aren't
|
| 195 |
+
# any). The latter would block sample collection entirely otherwise,
|
| 196 |
+
# even if other workers' PolicyServerInput receive incoming data from
|
| 197 |
+
# actual clients.
|
| 198 |
+
heart_beat_thread = threading.Thread(
|
| 199 |
+
name="heart-beat", target=self._put_empty_sample_batch_every_n_sec
|
| 200 |
+
)
|
| 201 |
+
heart_beat_thread.daemon = True
|
| 202 |
+
heart_beat_thread.start()
|
| 203 |
+
|
| 204 |
+
@override(InputReader)
|
| 205 |
+
def next(self):
|
| 206 |
+
# Blocking wait until there is something in the deque.
|
| 207 |
+
while len(self.samples_queue) == 0:
|
| 208 |
+
time.sleep(0.1)
|
| 209 |
+
# Utilize last items first in order to remain as closely as possible
|
| 210 |
+
# to operating on-policy.
|
| 211 |
+
return self.samples_queue.pop()
|
| 212 |
+
|
| 213 |
+
def _put_empty_sample_batch_every_n_sec(self):
|
| 214 |
+
# Places an empty SampleBatch every `idle_timeout` seconds onto the
|
| 215 |
+
# `samples_queue`. This avoids hanging of all RolloutWorkers parallel
|
| 216 |
+
# to this one in case this PolicyServerInput does not have incoming
|
| 217 |
+
# data (e.g. no client connected) and the driver algorithm uses parallel
|
| 218 |
+
# synchronous sampling (e.g. PPO).
|
| 219 |
+
while True:
|
| 220 |
+
time.sleep(self.idle_timeout)
|
| 221 |
+
self.samples_queue.append(SampleBatch())
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def _make_handler(rollout_worker, samples_queue, metrics_queue):
|
| 225 |
+
# Only used in remote inference mode. We must create a new rollout worker
|
| 226 |
+
# then since the original worker doesn't have the env properly wrapped in
|
| 227 |
+
# an ExternalEnv interface.
|
| 228 |
+
child_rollout_worker = None
|
| 229 |
+
inference_thread = None
|
| 230 |
+
lock = threading.Lock()
|
| 231 |
+
|
| 232 |
+
def setup_child_rollout_worker():
|
| 233 |
+
nonlocal lock
|
| 234 |
+
|
| 235 |
+
with lock:
|
| 236 |
+
nonlocal child_rollout_worker
|
| 237 |
+
nonlocal inference_thread
|
| 238 |
+
|
| 239 |
+
if child_rollout_worker is None:
|
| 240 |
+
(
|
| 241 |
+
child_rollout_worker,
|
| 242 |
+
inference_thread,
|
| 243 |
+
) = _create_embedded_rollout_worker(
|
| 244 |
+
rollout_worker.creation_args(), report_data
|
| 245 |
+
)
|
| 246 |
+
child_rollout_worker.set_weights(rollout_worker.get_weights())
|
| 247 |
+
|
| 248 |
+
def report_data(data):
|
| 249 |
+
nonlocal child_rollout_worker
|
| 250 |
+
|
| 251 |
+
batch = data["samples"]
|
| 252 |
+
batch.decompress_if_needed()
|
| 253 |
+
samples_queue.append(batch)
|
| 254 |
+
# Deque is full -> purge 50% (oldest samples)
|
| 255 |
+
if len(samples_queue) == samples_queue.maxlen:
|
| 256 |
+
logger.warning(
|
| 257 |
+
"PolicyServerInput queue is full! Purging half of the samples (oldest)."
|
| 258 |
+
)
|
| 259 |
+
for _ in range(samples_queue.maxlen // 2):
|
| 260 |
+
samples_queue.popleft()
|
| 261 |
+
for rollout_metric in data["metrics"]:
|
| 262 |
+
metrics_queue.put(rollout_metric)
|
| 263 |
+
|
| 264 |
+
if child_rollout_worker is not None:
|
| 265 |
+
child_rollout_worker.set_weights(
|
| 266 |
+
rollout_worker.get_weights(), rollout_worker.get_global_vars()
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
class Handler(SimpleHTTPRequestHandler):
|
| 270 |
+
def __init__(self, *a, **kw):
|
| 271 |
+
super().__init__(*a, **kw)
|
| 272 |
+
|
| 273 |
+
def do_POST(self):
|
| 274 |
+
content_len = int(self.headers.get("Content-Length"), 0)
|
| 275 |
+
raw_body = self.rfile.read(content_len)
|
| 276 |
+
parsed_input = pickle.loads(raw_body)
|
| 277 |
+
try:
|
| 278 |
+
response = self.execute_command(parsed_input)
|
| 279 |
+
self.send_response(200)
|
| 280 |
+
self.end_headers()
|
| 281 |
+
self.wfile.write(pickle.dumps(response))
|
| 282 |
+
except Exception:
|
| 283 |
+
self.send_error(500, traceback.format_exc())
|
| 284 |
+
|
| 285 |
+
def execute_command(self, args):
|
| 286 |
+
command = args["command"]
|
| 287 |
+
response = {}
|
| 288 |
+
|
| 289 |
+
# Local inference commands:
|
| 290 |
+
if command == Commands.GET_WORKER_ARGS:
|
| 291 |
+
logger.info("Sending worker creation args to client.")
|
| 292 |
+
response["worker_args"] = rollout_worker.creation_args()
|
| 293 |
+
elif command == Commands.GET_WEIGHTS:
|
| 294 |
+
logger.info("Sending worker weights to client.")
|
| 295 |
+
response["weights"] = rollout_worker.get_weights()
|
| 296 |
+
response["global_vars"] = rollout_worker.get_global_vars()
|
| 297 |
+
elif command == Commands.REPORT_SAMPLES:
|
| 298 |
+
logger.info(
|
| 299 |
+
"Got sample batch of size {} from client.".format(
|
| 300 |
+
args["samples"].count
|
| 301 |
+
)
|
| 302 |
+
)
|
| 303 |
+
report_data(args)
|
| 304 |
+
|
| 305 |
+
# Remote inference commands:
|
| 306 |
+
elif command == Commands.START_EPISODE:
|
| 307 |
+
setup_child_rollout_worker()
|
| 308 |
+
assert inference_thread.is_alive()
|
| 309 |
+
response["episode_id"] = child_rollout_worker.env.start_episode(
|
| 310 |
+
args["episode_id"], args["training_enabled"]
|
| 311 |
+
)
|
| 312 |
+
elif command == Commands.GET_ACTION:
|
| 313 |
+
assert inference_thread.is_alive()
|
| 314 |
+
response["action"] = child_rollout_worker.env.get_action(
|
| 315 |
+
args["episode_id"], args["observation"]
|
| 316 |
+
)
|
| 317 |
+
elif command == Commands.LOG_ACTION:
|
| 318 |
+
assert inference_thread.is_alive()
|
| 319 |
+
child_rollout_worker.env.log_action(
|
| 320 |
+
args["episode_id"], args["observation"], args["action"]
|
| 321 |
+
)
|
| 322 |
+
elif command == Commands.LOG_RETURNS:
|
| 323 |
+
assert inference_thread.is_alive()
|
| 324 |
+
if args["done"]:
|
| 325 |
+
child_rollout_worker.env.log_returns(
|
| 326 |
+
args["episode_id"], args["reward"], args["info"], args["done"]
|
| 327 |
+
)
|
| 328 |
+
else:
|
| 329 |
+
child_rollout_worker.env.log_returns(
|
| 330 |
+
args["episode_id"], args["reward"], args["info"]
|
| 331 |
+
)
|
| 332 |
+
elif command == Commands.END_EPISODE:
|
| 333 |
+
assert inference_thread.is_alive()
|
| 334 |
+
child_rollout_worker.env.end_episode(
|
| 335 |
+
args["episode_id"], args["observation"]
|
| 336 |
+
)
|
| 337 |
+
else:
|
| 338 |
+
raise ValueError("Unknown command: {}".format(command))
|
| 339 |
+
return response
|
| 340 |
+
|
| 341 |
+
return Handler
|
.venv/lib/python3.11/site-packages/ray/rllib/env/remote_base_env.py
ADDED
|
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
|
| 4 |
+
|
| 5 |
+
import ray
|
| 6 |
+
from ray.util import log_once
|
| 7 |
+
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
|
| 8 |
+
from ray.rllib.utils.annotations import override, OldAPIStack
|
| 9 |
+
from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiEnvDict
|
| 10 |
+
|
| 11 |
+
if TYPE_CHECKING:
|
| 12 |
+
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@OldAPIStack
class RemoteBaseEnv(BaseEnv):
    """BaseEnv that executes its sub environments as @ray.remote actors.

    This provides dynamic batching of inference as observations are returned
    from the remote simulator actors. Both single and multi-agent child envs
    are supported, and envs can be stepped synchronously or asynchronously.

    NOTE: This class implicitly assumes that the remote envs are gym.Env's

    You shouldn't need to instantiate this class directly. It's automatically
    inserted when you use the `remote_worker_envs=True` option in your
    Algorithm's config.
    """

    def __init__(
        self,
        make_env: Callable[[int], EnvType],
        num_envs: int,
        multiagent: bool,
        remote_env_batch_wait_ms: int,
        existing_envs: Optional[List[ray.actor.ActorHandle]] = None,
        worker: Optional["RolloutWorker"] = None,
        restart_failed_sub_environments: bool = False,
    ):
        """Initializes a RemoteVectorEnv instance.

        Args:
            make_env: Callable that produces a single (non-vectorized) env,
                given the vector env index as only arg.
            num_envs: The number of sub-environments to create for the
                vectorization.
            multiagent: Whether this is a multiagent env or not.
            remote_env_batch_wait_ms: Time to wait for (ray.remote)
                sub-environments to have new observations available when
                polled. Only when none of the sub-environments is ready,
                repeat the `ray.wait()` call until at least one sub-env
                is ready. Then return only the observations of the ready
                sub-environment(s).
            existing_envs: Optional list of already created sub-environments.
                These will be used as-is and only as many new sub-envs as
                necessary (`num_envs - len(existing_envs)`) will be created.
            worker: An optional RolloutWorker that owns the env. This is only
                used if `remote_worker_envs` is True in your config and the
                `on_sub_environment_created` custom callback needs to be
                called on each created actor.
            restart_failed_sub_environments: If True and any sub-environment (within
                a vectorized env) throws any error during env stepping, the
                Sampler will try to restart the faulty sub-environment. This is done
                without disturbing the other (still intact) sub-environment and without
                the RolloutWorker crashing.
        """

        # Could be creating local or remote envs.
        self.make_env = make_env
        self.num_envs = num_envs
        self.multiagent = multiagent
        # Convert ms to seconds for `ray.wait()`.
        self.poll_timeout = remote_env_batch_wait_ms / 1000
        self.worker = worker
        self.restart_failed_sub_environments = restart_failed_sub_environments

        # Already existing env objects (generated by the RolloutWorker).
        existing_envs = existing_envs or []

        # Whether the given `make_env` callable already returns ActorHandles
        # (@ray.remote class instances) or not.
        self.make_env_creates_actors = False

        self._observation_space = None
        self._action_space = None

        # List of ray actor handles (each handle points to one @ray.remote
        # sub-environment).
        self.actors: Optional[List[ray.actor.ActorHandle]] = None

        # `self.make_env` already produces Actors: Use it directly.
        if len(existing_envs) > 0 and isinstance(
            existing_envs[0], ray.actor.ActorHandle
        ):
            self.make_env_creates_actors = True
            self.actors = existing_envs
            while len(self.actors) < self.num_envs:
                self.actors.append(self._make_sub_env(len(self.actors)))

        # `self.make_env` produces gym.Envs (or children thereof, such
        # as MultiAgentEnv): Need to auto-wrap it here. The problem with
        # this is that custom methods wil get lost. If you would like to
        # keep your custom methods in your envs, you should provide the
        # env class directly in your config (w/o tune.register_env()),
        # such that your class can directly be made a @ray.remote
        # (w/o the wrapping via `_Remote[Multi|Single]AgentEnv`).
        # Also, if `len(existing_envs) > 0`, we have to throw those away
        # as we need to create ray actors here.
        else:
            self.actors = [self._make_sub_env(i) for i in range(self.num_envs)]
            # Utilize existing envs for inferring observation/action spaces.
            if len(existing_envs) > 0:
                self._observation_space = existing_envs[0].observation_space
                self._action_space = existing_envs[0].action_space
            # Have to call actors' remote methods to get observation/action spaces.
            else:
                self._observation_space, self._action_space = ray.get(
                    [
                        self.actors[0].observation_space.remote(),
                        self.actors[0].action_space.remote(),
                    ]
                )

        # Dict mapping object refs (return values of @ray.remote calls),
        # whose actual values we are waiting for (via ray.wait in
        # `self.poll()`) to their corresponding actor handles (the actors
        # that created these return values).
        # Call `reset()` on all @ray.remote sub-environment actors.
        # BUGFIX: annotation was `Dict[ray.actor.ActorHandle]` (missing the
        # key type); the mapping is ObjectRef -> ActorHandle.
        self.pending: Dict[ray.ObjectRef, ray.actor.ActorHandle] = {
            a.reset.remote(): a for a in self.actors
        }

    @override(BaseEnv)
    def poll(
        self,
    ) -> Tuple[
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
    ]:

        # each keyed by env_id in [0, num_remote_envs)
        obs, rewards, terminateds, truncateds, infos = {}, {}, {}, {}, {}
        ready = []

        # Wait for at least 1 env to be ready here.
        while not ready:
            ready, _ = ray.wait(
                list(self.pending),
                num_returns=len(self.pending),
                timeout=self.poll_timeout,
            )

        # Get and return observations for each of the ready envs
        env_ids = set()
        for obj_ref in ready:
            # Get the corresponding actor handle from our dict and remove the
            # object ref (we will call `ray.get()` on it and it will no longer
            # be "pending").
            actor = self.pending.pop(obj_ref)
            env_id = self.actors.index(actor)
            env_ids.add(env_id)
            # Get the ready object ref (this may be return value(s) of
            # `reset()` or `step()`).
            try:
                ret = ray.get(obj_ref)
            except Exception as e:
                # Something happened on the actor during stepping/resetting.
                # Restart sub-environment (create new actor; close old one).
                if self.restart_failed_sub_environments:
                    # BUGFIX: log the exception object itself; `e.args[0]`
                    # raised an IndexError for exceptions constructed without
                    # args, crashing the restart path.
                    logger.exception(e)
                    self.try_restart(env_id)
                    # Always return multi-agent data.
                    # Set the observation to the exception, no rewards,
                    # terminated[__all__]=True (episode will be discarded anyways),
                    # no infos.
                    ret = (
                        e,
                        {},
                        {"__all__": True},
                        {"__all__": False},
                        {},
                    )
                # Do not try to restart. Just raise the error.
                else:
                    raise e

            # Our sub-envs are simple Actor-turned gym.Envs or MultiAgentEnvs.
            if self.make_env_creates_actors:
                rew, terminated, truncated, info = None, None, None, None
                if self.multiagent:
                    if isinstance(ret, tuple):
                        # Gym >= 0.26: `step()` result: Obs, reward, terminated,
                        # truncated, info.
                        if len(ret) == 5:
                            ob, rew, terminated, truncated, info = ret
                        # Gym >= 0.26: `reset()` result: Obs and infos.
                        elif len(ret) == 2:
                            ob = ret[0]
                            info = ret[1]
                        # Gym < 0.26? Something went wrong.
                        else:
                            raise AssertionError(
                                "Your gymnasium.Env seems to NOT return the correct "
                                "number of return values for `step()` (needs to return"
                                " 5 values: obs, reward, terminated, truncated and "
                                "info) or `reset()` (needs to return 2 values: obs and "
                                "info)!"
                            )
                    # Gym < 0.26: `reset()` result: Only obs.
                    else:
                        raise AssertionError(
                            "Your gymnasium.Env seems to only return a single value "
                            "upon `reset()`! Must return 2 (obs AND infos)."
                        )
                else:
                    if isinstance(ret, tuple):
                        # `step()` result: Obs, reward, terminated, truncated, info.
                        if len(ret) == 5:
                            ob = {_DUMMY_AGENT_ID: ret[0]}
                            rew = {_DUMMY_AGENT_ID: ret[1]}
                            terminated = {_DUMMY_AGENT_ID: ret[2], "__all__": ret[2]}
                            truncated = {_DUMMY_AGENT_ID: ret[3], "__all__": ret[3]}
                            info = {_DUMMY_AGENT_ID: ret[4]}
                        # `reset()` result: Obs and infos.
                        elif len(ret) == 2:
                            ob = {_DUMMY_AGENT_ID: ret[0]}
                            info = {_DUMMY_AGENT_ID: ret[1]}
                        # Gym < 0.26? Something went wrong.
                        else:
                            raise AssertionError(
                                "Your gymnasium.Env seems to NOT return the correct "
                                "number of return values for `step()` (needs to return"
                                " 5 values: obs, reward, terminated, truncated and "
                                "info) or `reset()` (needs to return 2 values: obs and "
                                "info)!"
                            )
                    # Gym < 0.26?
                    else:
                        raise AssertionError(
                            "Your gymnasium.Env seems to only return a single value "
                            "upon `reset()`! Must return 2 (obs and infos)."
                        )

                # If this is a `reset()` return value, we only have the initial
                # observations and infos: Set rewards, terminateds, and truncateds to
                # dummy values.
                if rew is None:
                    rew = {agent_id: 0 for agent_id in ob.keys()}
                    terminated = {"__all__": False}
                    truncated = {"__all__": False}

            # Our sub-envs are auto-wrapped (by `_RemoteSingleAgentEnv` or
            # `_RemoteMultiAgentEnv`) and already behave like multi-agent
            # envs.
            else:
                ob, rew, terminated, truncated, info = ret
            obs[env_id] = ob
            rewards[env_id] = rew
            terminateds[env_id] = terminated
            truncateds[env_id] = truncated
            infos[env_id] = info

        logger.debug(f"Got obs batch for actors {env_ids}")
        return obs, rewards, terminateds, truncateds, infos, {}

    @override(BaseEnv)
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        for env_id, actions in action_dict.items():
            actor = self.actors[env_id]
            # `actor` is a simple single-agent (remote) env, e.g. a gym.Env
            # that was made a @ray.remote.
            if not self.multiagent and self.make_env_creates_actors:
                obj_ref = actor.step.remote(actions[_DUMMY_AGENT_ID])
            # `actor` is already a _RemoteSingleAgentEnv or
            # _RemoteMultiAgentEnv wrapper
            # (handles the multi-agent action_dict automatically).
            else:
                obj_ref = actor.step.remote(actions)
            self.pending[obj_ref] = actor

    @override(BaseEnv)
    def try_reset(
        self,
        env_id: Optional[EnvID] = None,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> Tuple[MultiEnvDict, MultiEnvDict]:
        actor = self.actors[env_id]
        obj_ref = actor.reset.remote(seed=seed, options=options)

        self.pending[obj_ref] = actor
        # Because this env type does not support synchronous reset requests (with
        # immediate return value), we return ASYNC_RESET_RETURN here to indicate
        # that the reset results will be available via the next `poll()` call.
        return ASYNC_RESET_RETURN, ASYNC_RESET_RETURN

    @override(BaseEnv)
    def try_restart(self, env_id: Optional[EnvID] = None) -> None:
        # Try closing down the old (possibly faulty) sub-env, but ignore errors.
        try:
            # Close the env on the remote side.
            self.actors[env_id].close.remote()
        except Exception as e:
            if log_once("close_sub_env"):
                logger.warning(
                    "Trying to close old and replaced sub-environment (at vector "
                    f"index={env_id}), but closing resulted in error:\n{e}"
                )

        # Terminate the actor itself to free up its resources.
        self.actors[env_id].__ray_terminate__.remote()

        # Re-create a new sub-environment.
        self.actors[env_id] = self._make_sub_env(env_id)

    @override(BaseEnv)
    def stop(self) -> None:
        if self.actors is not None:
            for actor in self.actors:
                actor.__ray_terminate__.remote()

    @override(BaseEnv)
    def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
        if as_dict:
            return {env_id: actor for env_id, actor in enumerate(self.actors)}
        return self.actors

    @property
    @override(BaseEnv)
    def observation_space(self) -> gym.spaces.Dict:
        return self._observation_space

    @property
    @override(BaseEnv)
    def action_space(self) -> gym.Space:
        return self._action_space

    def _make_sub_env(self, idx: Optional[int] = None):
        """Re-creates a sub-environment at the new index."""

        # Our `make_env` creates ray actors directly.
        if self.make_env_creates_actors:
            sub_env = self.make_env(idx)
            if self.worker is not None:
                self.worker.callbacks.on_sub_environment_created(
                    worker=self.worker,
                    # BUGFIX: pass the freshly created actor, not
                    # `self.actors[idx]`: during `__init__`'s append-loop that
                    # index does not exist yet (IndexError), and during
                    # `try_restart()` it still points at the old, terminated
                    # actor.
                    sub_environment=sub_env,
                    env_context=self.worker.env_context.copy_with_overrides(
                        vector_index=idx
                    ),
                )

        # Our `make_env` returns actual envs -> Have to convert them into actors
        # using our utility wrapper classes.
        else:

            def make_remote_env(i):
                logger.info("Launching env {} in remote actor".format(i))
                if self.multiagent:
                    sub_env = _RemoteMultiAgentEnv.remote(self.make_env, i)
                else:
                    sub_env = _RemoteSingleAgentEnv.remote(self.make_env, i)

                if self.worker is not None:
                    self.worker.callbacks.on_sub_environment_created(
                        worker=self.worker,
                        sub_environment=sub_env,
                        env_context=self.worker.env_context.copy_with_overrides(
                            vector_index=i
                        ),
                    )

                return sub_env

            sub_env = make_remote_env(idx)

        return sub_env

    @override(BaseEnv)
    def get_agent_ids(self) -> Set[AgentID]:
        if self.multiagent:
            return ray.get(self.actors[0].get_agent_ids.remote())
        else:
            return {_DUMMY_AGENT_ID}
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
@ray.remote(num_cpus=0)
class _RemoteMultiAgentEnv:
    """Actor wrapper exposing a multi-agent env through remote method calls."""

    def __init__(self, make_env, i):
        # Build the wrapped env from the factory and the vector index.
        self.env = make_env(i)
        # All agent IDs observed so far across `reset()` calls.
        self.agent_ids = set()

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Resets the wrapped env and returns a full multi-agent 5-tuple."""
        obs, info = self.env.reset(seed=seed, options=options)

        # Remember which agents are present and pair each with a zero reward
        # (there is no reward at episode start).
        self.agent_ids.update(obs.keys())
        rew = {agent_id: 0.0 for agent_id in obs}
        return obs, rew, {"__all__": False}, {"__all__": False}, info

    def step(self, action_dict):
        # The wrapped env already speaks multi-agent; pass through.
        return self.env.step(action_dict)

    # Defining these 2 functions that way this information can be queried
    # with a call to ray.get().
    def observation_space(self):
        return self.env.observation_space

    def action_space(self):
        return self.env.action_space

    def get_agent_ids(self) -> Set[AgentID]:
        return self.agent_ids
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
@ray.remote(num_cpus=0)
class _RemoteSingleAgentEnv:
    """Actor wrapper that presents a single-agent gym env as multi-agent."""

    def __init__(self, make_env, i):
        # Build the wrapped env from the factory and the vector index.
        self.env = make_env(i)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Resets the wrapped env; returns multi-agent-shaped dicts."""
        ob, inf = self.env.reset(seed=seed, options=options)

        # Everything keyed under the dummy agent ID; zero reward and
        # not-done flags at episode start.
        return (
            {_DUMMY_AGENT_ID: ob},
            {_DUMMY_AGENT_ID: 0.0},
            {"__all__": False},
            {"__all__": False},
            {_DUMMY_AGENT_ID: inf},
        )

    def step(self, action):
        """Steps the wrapped env with the dummy agent's action."""
        ob, rew, term, trunc, inf = self.env.step(action[_DUMMY_AGENT_ID])

        # Mirror the per-agent done flags under "__all__" as required by the
        # multi-agent convention.
        return (
            {_DUMMY_AGENT_ID: ob},
            {_DUMMY_AGENT_ID: rew},
            {_DUMMY_AGENT_ID: term, "__all__": term},
            {_DUMMY_AGENT_ID: trunc, "__all__": trunc},
            {_DUMMY_AGENT_ID: inf},
        )

    # Defining these 2 functions that way this information can be queried
    # with a call to ray.get().
    def observation_space(self):
        return self.env.observation_space

    def action_space(self):
        return self.env.action_space
|
.venv/lib/python3.11/site-packages/ray/rllib/env/single_agent_env_runner.py
ADDED
|
@@ -0,0 +1,853 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
from functools import partial
|
| 3 |
+
import logging
|
| 4 |
+
import time
|
| 5 |
+
from typing import Collection, DefaultDict, List, Optional, Union
|
| 6 |
+
|
| 7 |
+
import gymnasium as gym
|
| 8 |
+
from gymnasium.wrappers.vector import DictInfoToList
|
| 9 |
+
|
| 10 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
| 11 |
+
from ray.rllib.callbacks.callbacks import RLlibCallback
|
| 12 |
+
from ray.rllib.callbacks.utils import make_callback
|
| 13 |
+
from ray.rllib.core import (
|
| 14 |
+
COMPONENT_ENV_TO_MODULE_CONNECTOR,
|
| 15 |
+
COMPONENT_MODULE_TO_ENV_CONNECTOR,
|
| 16 |
+
COMPONENT_RL_MODULE,
|
| 17 |
+
DEFAULT_AGENT_ID,
|
| 18 |
+
DEFAULT_MODULE_ID,
|
| 19 |
+
)
|
| 20 |
+
from ray.rllib.core.columns import Columns
|
| 21 |
+
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
|
| 22 |
+
from ray.rllib.env import INPUT_ENV_SPACES
|
| 23 |
+
from ray.rllib.env.env_context import EnvContext
|
| 24 |
+
from ray.rllib.env.env_runner import EnvRunner, ENV_STEP_FAILURE
|
| 25 |
+
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
|
| 26 |
+
from ray.rllib.env.utils import _gym_env_creator
|
| 27 |
+
from ray.rllib.utils import force_list
|
| 28 |
+
from ray.rllib.utils.annotations import override
|
| 29 |
+
from ray.rllib.utils.checkpoints import Checkpointable
|
| 30 |
+
from ray.rllib.utils.deprecation import Deprecated
|
| 31 |
+
from ray.rllib.utils.framework import get_device
|
| 32 |
+
from ray.rllib.utils.metrics import (
|
| 33 |
+
EPISODE_DURATION_SEC_MEAN,
|
| 34 |
+
EPISODE_LEN_MAX,
|
| 35 |
+
EPISODE_LEN_MEAN,
|
| 36 |
+
EPISODE_LEN_MIN,
|
| 37 |
+
EPISODE_RETURN_MAX,
|
| 38 |
+
EPISODE_RETURN_MEAN,
|
| 39 |
+
EPISODE_RETURN_MIN,
|
| 40 |
+
NUM_AGENT_STEPS_SAMPLED,
|
| 41 |
+
NUM_AGENT_STEPS_SAMPLED_LIFETIME,
|
| 42 |
+
NUM_ENV_STEPS_SAMPLED,
|
| 43 |
+
NUM_ENV_STEPS_SAMPLED_LIFETIME,
|
| 44 |
+
NUM_EPISODES,
|
| 45 |
+
NUM_EPISODES_LIFETIME,
|
| 46 |
+
NUM_MODULE_STEPS_SAMPLED,
|
| 47 |
+
NUM_MODULE_STEPS_SAMPLED_LIFETIME,
|
| 48 |
+
SAMPLE_TIMER,
|
| 49 |
+
TIME_BETWEEN_SAMPLING,
|
| 50 |
+
WEIGHTS_SEQ_NO,
|
| 51 |
+
)
|
| 52 |
+
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
|
| 53 |
+
from ray.rllib.utils.spaces.space_utils import unbatch
|
| 54 |
+
from ray.rllib.utils.typing import EpisodeID, ResultDict, StateDict
|
| 55 |
+
from ray.tune.registry import ENV_CREATOR, _global_registry
|
| 56 |
+
from ray.util.annotations import PublicAPI
|
| 57 |
+
|
| 58 |
+
logger = logging.getLogger("ray.rllib")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# TODO (sven): As soon as RolloutWorker is no longer supported, make `EnvRunner` itself
|
| 62 |
+
# a Checkpointable. Currently, only some of its subclasses are Checkpointables.
|
| 63 |
+
@PublicAPI(stability="alpha")
|
| 64 |
+
class SingleAgentEnvRunner(EnvRunner, Checkpointable):
|
| 65 |
+
"""The generic environment runner for the single agent case."""
|
| 66 |
+
|
| 67 |
+
@override(EnvRunner)
|
| 68 |
+
def __init__(self, *, config: AlgorithmConfig, **kwargs):
|
| 69 |
+
"""Initializes a SingleAgentEnvRunner instance.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
config: An `AlgorithmConfig` object containing all settings needed to
|
| 73 |
+
build this `EnvRunner` class.
|
| 74 |
+
"""
|
| 75 |
+
super().__init__(config=config)
|
| 76 |
+
|
| 77 |
+
self.worker_index: int = kwargs.get("worker_index")
|
| 78 |
+
self.num_workers: int = kwargs.get("num_workers", self.config.num_env_runners)
|
| 79 |
+
self.tune_trial_id: str = kwargs.get("tune_trial_id")
|
| 80 |
+
|
| 81 |
+
# Create a MetricsLogger object for logging custom stats.
|
| 82 |
+
self.metrics = MetricsLogger()
|
| 83 |
+
|
| 84 |
+
# Create our callbacks object.
|
| 85 |
+
self._callbacks: List[RLlibCallback] = [
|
| 86 |
+
cls() for cls in force_list(self.config.callbacks_class)
|
| 87 |
+
]
|
| 88 |
+
|
| 89 |
+
# Set device.
|
| 90 |
+
self._device = get_device(
|
| 91 |
+
self.config,
|
| 92 |
+
0 if not self.worker_index else self.config.num_gpus_per_env_runner,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# Create the vectorized gymnasium env.
|
| 96 |
+
self.env: Optional[gym.vector.VectorEnvWrapper] = None
|
| 97 |
+
self.num_envs: int = 0
|
| 98 |
+
self.make_env()
|
| 99 |
+
|
| 100 |
+
# Create the env-to-module connector pipeline.
|
| 101 |
+
self._env_to_module = self.config.build_env_to_module_connector(
|
| 102 |
+
self.env, device=self._device
|
| 103 |
+
)
|
| 104 |
+
# Cached env-to-module results taken at the end of a `_sample_timesteps()`
|
| 105 |
+
# call to make sure the final observation (before an episode cut) gets properly
|
| 106 |
+
# processed (and maybe postprocessed and re-stored into the episode).
|
| 107 |
+
# For example, if we had a connector that normalizes observations and directly
|
| 108 |
+
# re-inserts these new obs back into the episode, the last observation in each
|
| 109 |
+
# sample call would NOT be processed, which could be very harmful in cases,
|
| 110 |
+
# in which value function bootstrapping of those (truncation) observations is
|
| 111 |
+
# required in the learning step.
|
| 112 |
+
self._cached_to_module = None
|
| 113 |
+
|
| 114 |
+
# Create the RLModule.
|
| 115 |
+
self.module: Optional[RLModule] = None
|
| 116 |
+
self.make_module()
|
| 117 |
+
|
| 118 |
+
# Create the module-to-env connector pipeline.
|
| 119 |
+
self._module_to_env = self.config.build_module_to_env_connector(self.env)
|
| 120 |
+
|
| 121 |
+
# This should be the default.
|
| 122 |
+
self._needs_initial_reset: bool = True
|
| 123 |
+
self._episodes: List[Optional[SingleAgentEpisode]] = [
|
| 124 |
+
None for _ in range(self.num_envs)
|
| 125 |
+
]
|
| 126 |
+
self._shared_data = None
|
| 127 |
+
|
| 128 |
+
self._done_episodes_for_metrics: List[SingleAgentEpisode] = []
|
| 129 |
+
self._ongoing_episodes_for_metrics: DefaultDict[
|
| 130 |
+
EpisodeID, List[SingleAgentEpisode]
|
| 131 |
+
] = defaultdict(list)
|
| 132 |
+
self._weights_seq_no: int = 0
|
| 133 |
+
|
| 134 |
+
# Measures the time passed between returning from `sample()`
|
| 135 |
+
# and receiving the next `sample()` request from the user.
|
| 136 |
+
self._time_after_sampling = None
|
| 137 |
+
|
| 138 |
+
@override(EnvRunner)
|
| 139 |
+
def sample(
|
| 140 |
+
self,
|
| 141 |
+
*,
|
| 142 |
+
num_timesteps: int = None,
|
| 143 |
+
num_episodes: int = None,
|
| 144 |
+
explore: bool = None,
|
| 145 |
+
random_actions: bool = False,
|
| 146 |
+
force_reset: bool = False,
|
| 147 |
+
) -> List[SingleAgentEpisode]:
|
| 148 |
+
"""Runs and returns a sample (n timesteps or m episodes) on the env(s).
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
num_timesteps: The number of timesteps to sample during this call.
|
| 152 |
+
Note that only one of `num_timetseps` or `num_episodes` may be provided.
|
| 153 |
+
num_episodes: The number of episodes to sample during this call.
|
| 154 |
+
Note that only one of `num_timetseps` or `num_episodes` may be provided.
|
| 155 |
+
explore: If True, will use the RLModule's `forward_exploration()`
|
| 156 |
+
method to compute actions. If False, will use the RLModule's
|
| 157 |
+
`forward_inference()` method. If None (default), will use the `explore`
|
| 158 |
+
boolean setting from `self.config` passed into this EnvRunner's
|
| 159 |
+
constructor. You can change this setting in your config via
|
| 160 |
+
`config.env_runners(explore=True|False)`.
|
| 161 |
+
random_actions: If True, actions will be sampled randomly (from the action
|
| 162 |
+
space of the environment). If False (default), actions or action
|
| 163 |
+
distribution parameters are computed by the RLModule.
|
| 164 |
+
force_reset: Whether to force-reset all (vector) environments before
|
| 165 |
+
sampling. Useful if you would like to collect a clean slate of new
|
| 166 |
+
episodes via this call. Note that when sampling n episodes
|
| 167 |
+
(`num_episodes != None`), this is fixed to True.
|
| 168 |
+
|
| 169 |
+
Returns:
|
| 170 |
+
A list of `SingleAgentEpisode` instances, carrying the sampled data.
|
| 171 |
+
"""
|
| 172 |
+
assert not (num_timesteps is not None and num_episodes is not None)
|
| 173 |
+
|
| 174 |
+
# Log time between `sample()` requests.
|
| 175 |
+
if self._time_after_sampling is not None:
|
| 176 |
+
self.metrics.log_value(
|
| 177 |
+
key=TIME_BETWEEN_SAMPLING,
|
| 178 |
+
value=time.perf_counter() - self._time_after_sampling,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# Log current weight seq no.
|
| 182 |
+
self.metrics.log_value(
|
| 183 |
+
key=WEIGHTS_SEQ_NO,
|
| 184 |
+
value=self._weights_seq_no,
|
| 185 |
+
window=1,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
with self.metrics.log_time(SAMPLE_TIMER):
|
| 189 |
+
# If no execution details are provided, use the config to try to infer the
|
| 190 |
+
# desired timesteps/episodes to sample and exploration behavior.
|
| 191 |
+
if explore is None:
|
| 192 |
+
explore = self.config.explore
|
| 193 |
+
if (
|
| 194 |
+
num_timesteps is None
|
| 195 |
+
and num_episodes is None
|
| 196 |
+
and self.config.batch_mode == "truncate_episodes"
|
| 197 |
+
):
|
| 198 |
+
num_timesteps = (
|
| 199 |
+
self.config.get_rollout_fragment_length(self.worker_index)
|
| 200 |
+
* self.num_envs
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
# Sample n timesteps.
|
| 204 |
+
if num_timesteps is not None:
|
| 205 |
+
samples = self._sample(
|
| 206 |
+
num_timesteps=num_timesteps,
|
| 207 |
+
explore=explore,
|
| 208 |
+
random_actions=random_actions,
|
| 209 |
+
force_reset=force_reset,
|
| 210 |
+
)
|
| 211 |
+
# Sample m episodes.
|
| 212 |
+
elif num_episodes is not None:
|
| 213 |
+
samples = self._sample(
|
| 214 |
+
num_episodes=num_episodes,
|
| 215 |
+
explore=explore,
|
| 216 |
+
random_actions=random_actions,
|
| 217 |
+
)
|
| 218 |
+
# For complete episodes mode, sample as long as the number of timesteps
|
| 219 |
+
# done is smaller than the `train_batch_size`.
|
| 220 |
+
else:
|
| 221 |
+
samples = self._sample(
|
| 222 |
+
num_episodes=self.num_envs,
|
| 223 |
+
explore=explore,
|
| 224 |
+
random_actions=random_actions,
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
# Make the `on_sample_end` callback.
|
| 228 |
+
make_callback(
|
| 229 |
+
"on_sample_end",
|
| 230 |
+
callbacks_objects=self._callbacks,
|
| 231 |
+
callbacks_functions=self.config.callbacks_on_sample_end,
|
| 232 |
+
kwargs=dict(
|
| 233 |
+
env_runner=self,
|
| 234 |
+
metrics_logger=self.metrics,
|
| 235 |
+
samples=samples,
|
| 236 |
+
),
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
self._time_after_sampling = time.perf_counter()
|
| 240 |
+
|
| 241 |
+
return samples
|
| 242 |
+
|
| 243 |
+
def _sample(
|
| 244 |
+
self,
|
| 245 |
+
*,
|
| 246 |
+
num_timesteps: Optional[int] = None,
|
| 247 |
+
num_episodes: Optional[int] = None,
|
| 248 |
+
explore: bool,
|
| 249 |
+
random_actions: bool = False,
|
| 250 |
+
force_reset: bool = False,
|
| 251 |
+
) -> List[SingleAgentEpisode]:
|
| 252 |
+
"""Helper method to sample n timesteps or m episodes."""
|
| 253 |
+
|
| 254 |
+
done_episodes_to_return: List[SingleAgentEpisode] = []
|
| 255 |
+
|
| 256 |
+
# Have to reset the env (on all vector sub_envs).
|
| 257 |
+
if force_reset or num_episodes is not None or self._needs_initial_reset:
|
| 258 |
+
episodes = self._episodes = [None for _ in range(self.num_envs)]
|
| 259 |
+
shared_data = self._shared_data = {}
|
| 260 |
+
self._reset_envs(episodes, shared_data, explore)
|
| 261 |
+
# We just reset the env. Don't have to force this again in the next
|
| 262 |
+
# call to `self._sample_timesteps()`.
|
| 263 |
+
self._needs_initial_reset = False
|
| 264 |
+
else:
|
| 265 |
+
episodes = self._episodes
|
| 266 |
+
shared_data = self._shared_data
|
| 267 |
+
|
| 268 |
+
if num_episodes is not None:
|
| 269 |
+
self._needs_initial_reset = True
|
| 270 |
+
|
| 271 |
+
# Loop through `num_timesteps` timesteps or `num_episodes` episodes.
|
| 272 |
+
ts = 0
|
| 273 |
+
eps = 0
|
| 274 |
+
while (
|
| 275 |
+
(ts < num_timesteps) if num_timesteps is not None else (eps < num_episodes)
|
| 276 |
+
):
|
| 277 |
+
# Act randomly.
|
| 278 |
+
if random_actions:
|
| 279 |
+
to_env = {
|
| 280 |
+
Columns.ACTIONS: self.env.action_space.sample(),
|
| 281 |
+
}
|
| 282 |
+
# Compute an action using the RLModule.
|
| 283 |
+
else:
|
| 284 |
+
# Env-to-module connector (already cached).
|
| 285 |
+
to_module = self._cached_to_module
|
| 286 |
+
assert to_module is not None
|
| 287 |
+
self._cached_to_module = None
|
| 288 |
+
|
| 289 |
+
# RLModule forward pass: Explore or not.
|
| 290 |
+
if explore:
|
| 291 |
+
# Global env steps sampled are (roughly) this EnvRunner's lifetime
|
| 292 |
+
# count times the number of env runners in the algo.
|
| 293 |
+
global_env_steps_lifetime = (
|
| 294 |
+
self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0)
|
| 295 |
+
+ ts
|
| 296 |
+
) * (self.config.num_env_runners or 1)
|
| 297 |
+
to_env = self.module.forward_exploration(
|
| 298 |
+
to_module, t=global_env_steps_lifetime
|
| 299 |
+
)
|
| 300 |
+
else:
|
| 301 |
+
to_env = self.module.forward_inference(to_module)
|
| 302 |
+
|
| 303 |
+
# Module-to-env connector.
|
| 304 |
+
to_env = self._module_to_env(
|
| 305 |
+
rl_module=self.module,
|
| 306 |
+
batch=to_env,
|
| 307 |
+
episodes=episodes,
|
| 308 |
+
explore=explore,
|
| 309 |
+
shared_data=shared_data,
|
| 310 |
+
metrics=self.metrics,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# Extract the (vectorized) actions (to be sent to the env) from the
|
| 314 |
+
# module/connector output. Note that these actions are fully ready (e.g.
|
| 315 |
+
# already unsquashed/clipped) to be sent to the environment) and might not
|
| 316 |
+
# be identical to the actions produced by the RLModule/distribution, which
|
| 317 |
+
# are the ones stored permanently in the episode objects.
|
| 318 |
+
actions = to_env.pop(Columns.ACTIONS)
|
| 319 |
+
actions_for_env = to_env.pop(Columns.ACTIONS_FOR_ENV, actions)
|
| 320 |
+
# Try stepping the environment.
|
| 321 |
+
results = self._try_env_step(actions_for_env)
|
| 322 |
+
if results == ENV_STEP_FAILURE:
|
| 323 |
+
return self._sample(
|
| 324 |
+
num_timesteps=num_timesteps,
|
| 325 |
+
num_episodes=num_episodes,
|
| 326 |
+
explore=explore,
|
| 327 |
+
random_actions=random_actions,
|
| 328 |
+
force_reset=True,
|
| 329 |
+
)
|
| 330 |
+
observations, rewards, terminateds, truncateds, infos = results
|
| 331 |
+
observations, actions = unbatch(observations), unbatch(actions)
|
| 332 |
+
|
| 333 |
+
call_on_episode_start = set()
|
| 334 |
+
for env_index in range(self.num_envs):
|
| 335 |
+
extra_model_output = {k: v[env_index] for k, v in to_env.items()}
|
| 336 |
+
extra_model_output[WEIGHTS_SEQ_NO] = self._weights_seq_no
|
| 337 |
+
|
| 338 |
+
# Episode has no data in it yet -> Was just reset and needs to be called
|
| 339 |
+
# with its `add_env_reset()` method.
|
| 340 |
+
if not self._episodes[env_index].is_reset:
|
| 341 |
+
episodes[env_index].add_env_reset(
|
| 342 |
+
observation=observations[env_index],
|
| 343 |
+
infos=infos[env_index],
|
| 344 |
+
)
|
| 345 |
+
call_on_episode_start.add(env_index)
|
| 346 |
+
|
| 347 |
+
# Call `add_env_step()` method on episode.
|
| 348 |
+
else:
|
| 349 |
+
# Only increase ts when we actually stepped (not reset'd as a reset
|
| 350 |
+
# does not count as a timestep).
|
| 351 |
+
ts += 1
|
| 352 |
+
episodes[env_index].add_env_step(
|
| 353 |
+
observation=observations[env_index],
|
| 354 |
+
action=actions[env_index],
|
| 355 |
+
reward=rewards[env_index],
|
| 356 |
+
infos=infos[env_index],
|
| 357 |
+
terminated=terminateds[env_index],
|
| 358 |
+
truncated=truncateds[env_index],
|
| 359 |
+
extra_model_outputs=extra_model_output,
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
# Env-to-module connector pass (cache results as we will do the RLModule
|
| 363 |
+
# forward pass only in the next `while`-iteration.
|
| 364 |
+
if self.module is not None:
|
| 365 |
+
self._cached_to_module = self._env_to_module(
|
| 366 |
+
episodes=episodes,
|
| 367 |
+
explore=explore,
|
| 368 |
+
rl_module=self.module,
|
| 369 |
+
shared_data=shared_data,
|
| 370 |
+
metrics=self.metrics,
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
for env_index in range(self.num_envs):
|
| 374 |
+
# Call `on_episode_start()` callback (always after reset).
|
| 375 |
+
if env_index in call_on_episode_start:
|
| 376 |
+
self._make_on_episode_callback(
|
| 377 |
+
"on_episode_start", env_index, episodes
|
| 378 |
+
)
|
| 379 |
+
# Make the `on_episode_step` callbacks.
|
| 380 |
+
else:
|
| 381 |
+
self._make_on_episode_callback(
|
| 382 |
+
"on_episode_step", env_index, episodes
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
# Episode is done.
|
| 386 |
+
if episodes[env_index].is_done:
|
| 387 |
+
eps += 1
|
| 388 |
+
|
| 389 |
+
# Make the `on_episode_end` callbacks (before finalizing the episode
|
| 390 |
+
# object).
|
| 391 |
+
self._make_on_episode_callback(
|
| 392 |
+
"on_episode_end", env_index, episodes
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
# Numpy'ize the episode.
|
| 396 |
+
if self.config.episodes_to_numpy:
|
| 397 |
+
# Any possibly compress observations.
|
| 398 |
+
done_episodes_to_return.append(episodes[env_index].to_numpy())
|
| 399 |
+
# Leave episode as lists of individual (obs, action, etc..) items.
|
| 400 |
+
else:
|
| 401 |
+
done_episodes_to_return.append(episodes[env_index])
|
| 402 |
+
|
| 403 |
+
# Also early-out if we reach the number of episodes within this
|
| 404 |
+
# for-loop.
|
| 405 |
+
if eps == num_episodes:
|
| 406 |
+
break
|
| 407 |
+
|
| 408 |
+
# Create a new episode object with no data in it and execute
|
| 409 |
+
# `on_episode_created` callback (before the `env.reset()` call).
|
| 410 |
+
episodes[env_index] = SingleAgentEpisode(
|
| 411 |
+
observation_space=self.env.single_observation_space,
|
| 412 |
+
action_space=self.env.single_action_space,
|
| 413 |
+
)
|
| 414 |
+
self._make_on_episode_callback(
|
| 415 |
+
"on_episode_created",
|
| 416 |
+
env_index,
|
| 417 |
+
episodes,
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
# Return done episodes ...
|
| 421 |
+
self._done_episodes_for_metrics.extend(done_episodes_to_return)
|
| 422 |
+
# ... and all ongoing episode chunks.
|
| 423 |
+
|
| 424 |
+
# Also, make sure we start new episode chunks (continuing the ongoing episodes
|
| 425 |
+
# from the to-be-returned chunks).
|
| 426 |
+
ongoing_episodes_to_return = []
|
| 427 |
+
# Only if we are doing individual timesteps: We have to maybe cut an ongoing
|
| 428 |
+
# episode and continue building it on the next call to `sample()`.
|
| 429 |
+
if num_timesteps is not None:
|
| 430 |
+
ongoing_episodes_continuations = [
|
| 431 |
+
eps.cut(len_lookback_buffer=self.config.episode_lookback_horizon)
|
| 432 |
+
for eps in self._episodes
|
| 433 |
+
]
|
| 434 |
+
|
| 435 |
+
for eps in self._episodes:
|
| 436 |
+
# Just started Episodes do not have to be returned. There is no data
|
| 437 |
+
# in them anyway.
|
| 438 |
+
if eps.t == 0:
|
| 439 |
+
continue
|
| 440 |
+
eps.validate()
|
| 441 |
+
self._ongoing_episodes_for_metrics[eps.id_].append(eps)
|
| 442 |
+
|
| 443 |
+
# Numpy'ize the episode.
|
| 444 |
+
if self.config.episodes_to_numpy:
|
| 445 |
+
# Any possibly compress observations.
|
| 446 |
+
ongoing_episodes_to_return.append(eps.to_numpy())
|
| 447 |
+
# Leave episode as lists of individual (obs, action, etc..) items.
|
| 448 |
+
else:
|
| 449 |
+
ongoing_episodes_to_return.append(eps)
|
| 450 |
+
|
| 451 |
+
# Continue collecting into the cut Episode chunks.
|
| 452 |
+
self._episodes = ongoing_episodes_continuations
|
| 453 |
+
|
| 454 |
+
self._increase_sampled_metrics(ts, len(done_episodes_to_return))
|
| 455 |
+
|
| 456 |
+
# Return collected episode data.
|
| 457 |
+
return done_episodes_to_return + ongoing_episodes_to_return
|
| 458 |
+
|
| 459 |
+
@override(EnvRunner)
|
| 460 |
+
def get_spaces(self):
|
| 461 |
+
return {
|
| 462 |
+
INPUT_ENV_SPACES: (self.env.observation_space, self.env.action_space),
|
| 463 |
+
DEFAULT_MODULE_ID: (
|
| 464 |
+
self._env_to_module.observation_space,
|
| 465 |
+
self.env.single_action_space,
|
| 466 |
+
),
|
| 467 |
+
}
|
| 468 |
+
|
| 469 |
+
@override(EnvRunner)
|
| 470 |
+
def get_metrics(self) -> ResultDict:
|
| 471 |
+
# Compute per-episode metrics (only on already completed episodes).
|
| 472 |
+
for eps in self._done_episodes_for_metrics:
|
| 473 |
+
assert eps.is_done
|
| 474 |
+
episode_length = len(eps)
|
| 475 |
+
episode_return = eps.get_return()
|
| 476 |
+
episode_duration_s = eps.get_duration_s()
|
| 477 |
+
# Don't forget about the already returned chunks of this episode.
|
| 478 |
+
if eps.id_ in self._ongoing_episodes_for_metrics:
|
| 479 |
+
for eps2 in self._ongoing_episodes_for_metrics[eps.id_]:
|
| 480 |
+
episode_length += len(eps2)
|
| 481 |
+
episode_return += eps2.get_return()
|
| 482 |
+
episode_duration_s += eps2.get_duration_s()
|
| 483 |
+
del self._ongoing_episodes_for_metrics[eps.id_]
|
| 484 |
+
|
| 485 |
+
self._log_episode_metrics(
|
| 486 |
+
episode_length, episode_return, episode_duration_s
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
# Now that we have logged everything, clear cache of done episodes.
|
| 490 |
+
self._done_episodes_for_metrics.clear()
|
| 491 |
+
|
| 492 |
+
# Return reduced metrics.
|
| 493 |
+
return self.metrics.reduce()
|
| 494 |
+
|
| 495 |
+
@override(Checkpointable)
|
| 496 |
+
def get_state(
|
| 497 |
+
self,
|
| 498 |
+
components: Optional[Union[str, Collection[str]]] = None,
|
| 499 |
+
*,
|
| 500 |
+
not_components: Optional[Union[str, Collection[str]]] = None,
|
| 501 |
+
**kwargs,
|
| 502 |
+
) -> StateDict:
|
| 503 |
+
state = {
|
| 504 |
+
NUM_ENV_STEPS_SAMPLED_LIFETIME: (
|
| 505 |
+
self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0)
|
| 506 |
+
),
|
| 507 |
+
}
|
| 508 |
+
|
| 509 |
+
if self._check_component(COMPONENT_RL_MODULE, components, not_components):
|
| 510 |
+
state[COMPONENT_RL_MODULE] = self.module.get_state(
|
| 511 |
+
components=self._get_subcomponents(COMPONENT_RL_MODULE, components),
|
| 512 |
+
not_components=self._get_subcomponents(
|
| 513 |
+
COMPONENT_RL_MODULE, not_components
|
| 514 |
+
),
|
| 515 |
+
**kwargs,
|
| 516 |
+
)
|
| 517 |
+
state[WEIGHTS_SEQ_NO] = self._weights_seq_no
|
| 518 |
+
if self._check_component(
|
| 519 |
+
COMPONENT_ENV_TO_MODULE_CONNECTOR, components, not_components
|
| 520 |
+
):
|
| 521 |
+
state[COMPONENT_ENV_TO_MODULE_CONNECTOR] = self._env_to_module.get_state()
|
| 522 |
+
if self._check_component(
|
| 523 |
+
COMPONENT_MODULE_TO_ENV_CONNECTOR, components, not_components
|
| 524 |
+
):
|
| 525 |
+
state[COMPONENT_MODULE_TO_ENV_CONNECTOR] = self._module_to_env.get_state()
|
| 526 |
+
|
| 527 |
+
return state
|
| 528 |
+
|
| 529 |
+
@override(Checkpointable)
|
| 530 |
+
def set_state(self, state: StateDict) -> None:
|
| 531 |
+
if COMPONENT_ENV_TO_MODULE_CONNECTOR in state:
|
| 532 |
+
self._env_to_module.set_state(state[COMPONENT_ENV_TO_MODULE_CONNECTOR])
|
| 533 |
+
if COMPONENT_MODULE_TO_ENV_CONNECTOR in state:
|
| 534 |
+
self._module_to_env.set_state(state[COMPONENT_MODULE_TO_ENV_CONNECTOR])
|
| 535 |
+
|
| 536 |
+
# Update the RLModule state.
|
| 537 |
+
if COMPONENT_RL_MODULE in state:
|
| 538 |
+
# A missing value for WEIGHTS_SEQ_NO or a value of 0 means: Force the
|
| 539 |
+
# update.
|
| 540 |
+
weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0)
|
| 541 |
+
|
| 542 |
+
# Only update the weigths, if this is the first synchronization or
|
| 543 |
+
# if the weights of this `EnvRunner` lacks behind the actual ones.
|
| 544 |
+
if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no:
|
| 545 |
+
rl_module_state = state[COMPONENT_RL_MODULE]
|
| 546 |
+
if (
|
| 547 |
+
isinstance(rl_module_state, dict)
|
| 548 |
+
and DEFAULT_MODULE_ID in rl_module_state
|
| 549 |
+
):
|
| 550 |
+
rl_module_state = rl_module_state[DEFAULT_MODULE_ID]
|
| 551 |
+
self.module.set_state(rl_module_state)
|
| 552 |
+
|
| 553 |
+
# Update our weights_seq_no, if the new one is > 0.
|
| 554 |
+
if weights_seq_no > 0:
|
| 555 |
+
self._weights_seq_no = weights_seq_no
|
| 556 |
+
|
| 557 |
+
# Update our lifetime counters.
|
| 558 |
+
if NUM_ENV_STEPS_SAMPLED_LIFETIME in state:
|
| 559 |
+
self.metrics.set_value(
|
| 560 |
+
key=NUM_ENV_STEPS_SAMPLED_LIFETIME,
|
| 561 |
+
value=state[NUM_ENV_STEPS_SAMPLED_LIFETIME],
|
| 562 |
+
reduce="sum",
|
| 563 |
+
with_throughput=True,
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
@override(Checkpointable)
|
| 567 |
+
def get_ctor_args_and_kwargs(self):
|
| 568 |
+
return (
|
| 569 |
+
(), # *args
|
| 570 |
+
{"config": self.config}, # **kwargs
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
@override(Checkpointable)
|
| 574 |
+
def get_metadata(self):
|
| 575 |
+
metadata = Checkpointable.get_metadata(self)
|
| 576 |
+
metadata.update(
|
| 577 |
+
{
|
| 578 |
+
# TODO (sven): Maybe add serialized (JSON-writable) config here?
|
| 579 |
+
}
|
| 580 |
+
)
|
| 581 |
+
return metadata
|
| 582 |
+
|
| 583 |
+
@override(Checkpointable)
|
| 584 |
+
def get_checkpointable_components(self):
|
| 585 |
+
return [
|
| 586 |
+
(COMPONENT_RL_MODULE, self.module),
|
| 587 |
+
(COMPONENT_ENV_TO_MODULE_CONNECTOR, self._env_to_module),
|
| 588 |
+
(COMPONENT_MODULE_TO_ENV_CONNECTOR, self._module_to_env),
|
| 589 |
+
]
|
| 590 |
+
|
| 591 |
+
@override(EnvRunner)
|
| 592 |
+
def assert_healthy(self):
|
| 593 |
+
"""Checks that self.__init__() has been completed properly.
|
| 594 |
+
|
| 595 |
+
Ensures that the instances has a `MultiRLModule` and an
|
| 596 |
+
environment defined.
|
| 597 |
+
|
| 598 |
+
Raises:
|
| 599 |
+
AssertionError: If the EnvRunner Actor has NOT been properly initialized.
|
| 600 |
+
"""
|
| 601 |
+
# Make sure, we have built our gym.vector.Env and RLModule properly.
|
| 602 |
+
assert self.env and hasattr(self, "module")
|
| 603 |
+
|
| 604 |
+
@override(EnvRunner)
|
| 605 |
+
def make_env(self) -> None:
    """Creates a vectorized gymnasium env and stores it in `self.env`.

    Note that users can change the EnvRunner's config (e.g. change
    `self.config.env_config`) and then call this method to create new environments
    with the updated configuration.

    Raises:
        ValueError: If `config.env` has not been provided.
    """
    # If an env already exists, try closing it first (to allow it to properly
    # cleanup).
    if self.env is not None:
        try:
            self.env.close()
        except Exception as e:
            # Fix: format the exception itself (not `e.args[0]`), which would
            # raise an IndexError for exceptions constructed without args.
            logger.warning(
                f"Tried closing the existing env, but failed with error: {e}"
            )

    # Wrap the raw env config dict in an EnvContext (carries worker index /
    # num-workers info to the env creator), if not already one.
    env_ctx = self.config.env_config
    if not isinstance(env_ctx, EnvContext):
        env_ctx = EnvContext(
            env_ctx,
            worker_index=self.worker_index,
            num_workers=self.num_workers,
            remote=self.config.remote_worker_envs,
        )

    # No env provided -> Error.
    if not self.config.env:
        raise ValueError(
            "`config.env` is not provided! You should provide a valid environment "
            "to your config through `config.environment([env descriptor e.g. "
            "'CartPole-v1'])`."
        )
    # Register env for the local context.
    # Note, `gym.register` has to be called on each worker.
    elif isinstance(self.config.env, str) and _global_registry.contains(
        ENV_CREATOR, self.config.env
    ):
        entry_point = partial(
            _global_registry.get(ENV_CREATOR, self.config.env),
            env_ctx,
        )
    else:
        entry_point = partial(
            _gym_env_creator,
            env_descriptor=self.config.env,
            env_context=env_ctx,
        )
    gym.register("rllib-single-agent-env-v0", entry_point=entry_point)
    vectorize_mode = self.config.gym_env_vectorize_mode

    # Build the vector env (n sub-envs); DictInfoToList converts gymnasium's
    # dict-of-lists infos into a list of per-sub-env info dicts.
    self.env = DictInfoToList(
        gym.make_vec(
            "rllib-single-agent-env-v0",
            num_envs=self.config.num_envs_per_env_runner,
            vectorization_mode=(
                vectorize_mode
                if isinstance(vectorize_mode, gym.envs.registration.VectorizeMode)
                else gym.envs.registration.VectorizeMode(vectorize_mode.lower())
            ),
        )
    )

    self.num_envs: int = self.env.num_envs
    assert self.num_envs == self.config.num_envs_per_env_runner

    # Set the flag to reset all envs upon the next `sample()` call.
    self._needs_initial_reset = True

    # Call the `on_environment_created` callback.
    make_callback(
        "on_environment_created",
        callbacks_objects=self._callbacks,
        callbacks_functions=self.config.callbacks_on_environment_created,
        kwargs=dict(
            env_runner=self,
            metrics_logger=self.metrics,
            env=self.env.unwrapped,
            env_context=env_ctx,
        ),
    )
|
| 687 |
+
|
| 688 |
+
@override(EnvRunner)
|
| 689 |
+
def make_module(self):
    """Builds this EnvRunner's (inference-only) RLModule and moves it to device.

    If the config does not implement `get_rl_module_spec()` (raises
    NotImplementedError anywhere during build), `self.module` is set to None;
    the runner might still be usable with random actions.
    """
    try:
        spec = self.config.get_rl_module_spec(
            env=self.env.unwrapped, spaces=self.get_spaces(), inference_only=True
        )
        # Build the module from its spec, then move it to our device.
        # TODO (sven): In order to make this framework-agnostic, we should maybe
        # make the RLModule.build() method accept a device OR create an additional
        # `RLModule.to()` override.
        self.module = spec.build()
        self.module.to(self._device)
    except NotImplementedError:
        # No RLModule spec available -> run module-less.
        self.module = None
|
| 707 |
+
|
| 708 |
+
@override(EnvRunner)
|
| 709 |
+
def stop(self):
    """Shuts this EnvRunner down by closing its gymnasium (vector) env."""
    # gymnasium's `close()` API lets all sub-envs clean up their resources.
    self.env.close()
|
| 712 |
+
|
| 713 |
+
def _reset_envs(self, episodes, shared_data, explore):
    """Resets all vector sub-envs and re-initializes the given episode slots.

    Order matters here: fresh episodes are created first (triggering
    `on_episode_created`), then the env is reset, then the env-to-module
    connector is pre-run on the reset observations, and only then are the
    `on_episode_start` callbacks fired.

    Args:
        episodes: List of length `self.num_envs`; each slot is replaced with a
            fresh `SingleAgentEpisode` holding the reset obs/infos.
        shared_data: Dict forwarded to the env-to-module connector
            (presumably for sharing data between connector pieces -- TODO confirm).
        explore: Whether exploration is enabled; forwarded to the connector.
    """
    # Create n new episodes and make the `on_episode_created` callbacks.
    for env_index in range(self.num_envs):
        self._new_episode(env_index, episodes)

    # Erase all cached ongoing episodes (these will never be completed and
    # would thus never be returned/cleaned by `get_metrics` and cause a memory
    # leak).
    self._ongoing_episodes_for_metrics.clear()

    # Try resetting the environment.
    # TODO (simon): Check, if we need here the seed from the config.
    observations, infos = self._try_env_reset()
    # `unbatch` turns the batched vector-env obs into a list of per-sub-env obs.
    observations = unbatch(observations)

    # Set initial obs and infos in the episodes.
    for env_index in range(self.num_envs):
        episodes[env_index].add_env_reset(
            observation=observations[env_index],
            infos=infos[env_index],
        )

    # Run the env-to-module connector to make sure the reset-obs/infos have
    # properly been processed (if applicable). The result is cached so the
    # next `sample()` step can reuse it instead of re-running the connector.
    self._cached_to_module = None
    if self.module:
        self._cached_to_module = self._env_to_module(
            rl_module=self.module,
            episodes=episodes,
            explore=explore,
            shared_data=shared_data,
            metrics=self.metrics,
        )

    # Call `on_episode_start()` callbacks (always after reset).
    for env_index in range(self.num_envs):
        self._make_on_episode_callback("on_episode_start", env_index, episodes)
|
| 750 |
+
|
| 751 |
+
def _new_episode(self, env_index, episodes=None):
    """Replaces the episode at `env_index` with a fresh SingleAgentEpisode.

    Also fires the `on_episode_created` callback for the new episode.

    Args:
        env_index: Index of the vector sub-env whose episode gets replaced.
        episodes: Optional list to write into; defaults to `self._episodes`.
    """
    if episodes is None:
        episodes = self._episodes
    fresh_episode = SingleAgentEpisode(
        observation_space=self.env.single_observation_space,
        action_space=self.env.single_action_space,
    )
    episodes[env_index] = fresh_episode
    self._make_on_episode_callback("on_episode_created", env_index, episodes)
|
| 758 |
+
|
| 759 |
+
def _make_on_episode_callback(self, which: str, idx: int, episodes):
    """Fires the episode callback hook `which` for one sub-env's episode.

    Args:
        which: Name of the callback hook (e.g. "on_episode_start"); the
            matching user-provided functions are looked up on the config as
            `callbacks_{which}`.
        idx: Index of the vector sub-env / episode the callback refers to.
        episodes: List of current SingleAgentEpisode objects.
    """
    hook_kwargs = dict(
        episode=episodes[idx],
        env_runner=self,
        metrics_logger=self.metrics,
        env=self.env.unwrapped,
        rl_module=self.module,
        env_index=idx,
    )
    make_callback(
        which,
        callbacks_objects=self._callbacks,
        callbacks_functions=getattr(self.config, f"callbacks_{which}"),
        kwargs=hook_kwargs,
    )
|
| 773 |
+
|
| 774 |
+
def _increase_sampled_metrics(self, num_steps, num_episodes_completed):
    """Logs per-sample-cycle and lifetime sampling counters.

    Args:
        num_steps: Number of env steps sampled during this cycle (used for
            env-, agent-, and module-step counters alike in this
            single-agent setup).
        num_episodes_completed: Number of episodes finished during this cycle.

    Returns:
        The `num_steps` value passed in.
    """
    log = self.metrics.log_value
    # Per sample cycle stats (reset on each metrics reduction).
    log(NUM_ENV_STEPS_SAMPLED, num_steps, reduce="sum", clear_on_reduce=True)
    log(
        (NUM_AGENT_STEPS_SAMPLED, DEFAULT_AGENT_ID),
        num_steps,
        reduce="sum",
        clear_on_reduce=True,
    )
    log(
        (NUM_MODULE_STEPS_SAMPLED, DEFAULT_MODULE_ID),
        num_steps,
        reduce="sum",
        clear_on_reduce=True,
    )
    log(
        NUM_EPISODES,
        num_episodes_completed,
        reduce="sum",
        clear_on_reduce=True,
    )
    # Lifetime stats (never cleared; env-step lifetime also tracks throughput).
    log(
        NUM_ENV_STEPS_SAMPLED_LIFETIME,
        num_steps,
        reduce="sum",
        with_throughput=True,
    )
    log(
        (NUM_AGENT_STEPS_SAMPLED_LIFETIME, DEFAULT_AGENT_ID),
        num_steps,
        reduce="sum",
    )
    log(
        (NUM_MODULE_STEPS_SAMPLED_LIFETIME, DEFAULT_MODULE_ID),
        num_steps,
        reduce="sum",
    )
    log(
        NUM_EPISODES_LIFETIME,
        num_episodes_completed,
        reduce="sum",
    )
    return num_steps
|
| 820 |
+
|
| 821 |
+
def _log_episode_metrics(self, length, ret, sec):
    """Logs one finished episode's length, return, and duration metrics.

    Args:
        length: The episode's length (number of env steps).
        ret: The episode's total (undiscounted) return.
        sec: The episode's wall-clock duration in seconds.
    """
    # To mimic the old API stack behavior, we'll use `window` here for
    # these particular stats (instead of the default EMA).
    window = self.config.metrics_num_episodes_for_smoothing
    log = self.metrics.log_value
    # General episode metrics (windowed means).
    log(EPISODE_LEN_MEAN, length, window=window)
    log(EPISODE_RETURN_MEAN, ret, window=window)
    log(EPISODE_DURATION_SEC_MEAN, sec, window=window)
    # Per-agent and per-RLModule returns (single-agent -> default IDs).
    log(("agent_episode_returns_mean", DEFAULT_AGENT_ID), ret, window=window)
    log(("module_episode_returns_mean", DEFAULT_MODULE_ID), ret, window=window)

    # For some metrics, log min/max as well.
    log(EPISODE_LEN_MIN, length, reduce="min", window=window)
    log(EPISODE_RETURN_MIN, ret, reduce="min", window=window)
    log(EPISODE_LEN_MAX, length, reduce="max", window=window)
    log(EPISODE_RETURN_MAX, ret, reduce="max", window=window)
|
| 843 |
+
|
| 844 |
+
@Deprecated(
|
| 845 |
+
new="SingleAgentEnvRunner.get_state(components='rl_module')",
|
| 846 |
+
error=True,
|
| 847 |
+
)
|
| 848 |
+
def get_weights(self, *args, **kwargs):
    # Deprecated stub: the `@Deprecated(error=True)` decorator above raises
    # before this body can run. Use `get_state(components='rl_module')` instead.
    pass
|
| 850 |
+
|
| 851 |
+
@Deprecated(new="SingleAgentEnvRunner.set_state()", error=True)
|
| 852 |
+
def set_weights(self, *args, **kwargs):
    # Deprecated stub: the `@Deprecated(error=True)` decorator above raises
    # before this body can run. Use `set_state()` instead.
    pass
|
.venv/lib/python3.11/site-packages/ray/rllib/env/single_agent_episode.py
ADDED
|
@@ -0,0 +1,1862 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
import numpy as np
|
| 4 |
+
import time
|
| 5 |
+
import uuid
|
| 6 |
+
|
| 7 |
+
import gymnasium as gym
|
| 8 |
+
from gymnasium.core import ActType, ObsType
|
| 9 |
+
from typing import Any, Dict, List, Optional, SupportsFloat, Union
|
| 10 |
+
|
| 11 |
+
from ray.rllib.core.columns import Columns
|
| 12 |
+
from ray.rllib.env.utils.infinite_lookback_buffer import InfiniteLookbackBuffer
|
| 13 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 14 |
+
from ray.rllib.utils.serialization import gym_space_from_dict, gym_space_to_dict
|
| 15 |
+
from ray.rllib.utils.deprecation import Deprecated
|
| 16 |
+
from ray.rllib.utils.typing import AgentID, ModuleID
|
| 17 |
+
from ray.util.annotations import PublicAPI
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@PublicAPI(stability="alpha")
|
| 21 |
+
class SingleAgentEpisode:
|
| 22 |
+
"""A class representing RL environment episodes for individual agents.
|
| 23 |
+
|
| 24 |
+
SingleAgentEpisode stores observations, info dicts, actions, rewards, and all
|
| 25 |
+
module outputs (e.g. state outs, action logp, etc..) for an individual agent within
|
| 26 |
+
some single-agent or multi-agent environment.
|
| 27 |
+
The two main APIs to add data to an ongoing episode are the `add_env_reset()`
|
| 28 |
+
and `add_env_step()` methods, which should be called passing the outputs of the
|
| 29 |
+
respective gym.Env API calls: `env.reset()` and `env.step()`.
|
| 30 |
+
|
| 31 |
+
A SingleAgentEpisode might also only represent a chunk of an episode, which is
|
| 32 |
+
useful for cases, in which partial (non-complete episode) sampling is performed
|
| 33 |
+
and collected episode data has to be returned before the actual gym.Env episode has
|
| 34 |
+
finished (see `SingleAgentEpisode.cut()`). In order to still maintain visibility
|
| 35 |
+
onto past experiences within such a "cut" episode, SingleAgentEpisode instances
|
| 36 |
+
can have a "lookback buffer" of n timesteps at their beginning (left side), which
|
| 37 |
+
solely exists for the purpose of compiling extra data (e.g. "prev. reward"), but
|
| 38 |
+
is not considered part of the finished/packaged episode (b/c the data in the
|
| 39 |
+
lookback buffer is already part of a previous episode chunk).
|
| 40 |
+
|
| 41 |
+
Powerful getter methods, such as `get_observations()` help collect different types
|
| 42 |
+
of data from the episode at individual time indices or time ranges, including the
|
| 43 |
+
"lookback buffer" range described above. For example, to extract the last 4 rewards
|
| 44 |
+
of an ongoing episode, one can call `self.get_rewards(slice(-4, None))` or
|
| 45 |
+
`self.rewards[-4:]`. This would work, even if the ongoing SingleAgentEpisode is
|
| 46 |
+
a continuation chunk from a much earlier started episode, as long as it has a
|
| 47 |
+
lookback buffer size of sufficient size.
|
| 48 |
+
|
| 49 |
+
Examples:
|
| 50 |
+
|
| 51 |
+
.. testcode::
|
| 52 |
+
|
| 53 |
+
import gymnasium as gym
|
| 54 |
+
import numpy as np
|
| 55 |
+
|
| 56 |
+
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
|
| 57 |
+
|
| 58 |
+
# Construct a new episode (without any data in it yet).
|
| 59 |
+
episode = SingleAgentEpisode()
|
| 60 |
+
assert len(episode) == 0
|
| 61 |
+
|
| 62 |
+
# Fill the episode with some data (10 timesteps).
|
| 63 |
+
env = gym.make("CartPole-v1")
|
| 64 |
+
obs, infos = env.reset()
|
| 65 |
+
episode.add_env_reset(obs, infos)
|
| 66 |
+
|
| 67 |
+
# Even with the initial obs/infos, the episode is still considered len=0.
|
| 68 |
+
assert len(episode) == 0
|
| 69 |
+
for _ in range(5):
|
| 70 |
+
action = env.action_space.sample()
|
| 71 |
+
obs, reward, term, trunc, infos = env.step(action)
|
| 72 |
+
episode.add_env_step(
|
| 73 |
+
observation=obs,
|
| 74 |
+
action=action,
|
| 75 |
+
reward=reward,
|
| 76 |
+
terminated=term,
|
| 77 |
+
truncated=trunc,
|
| 78 |
+
infos=infos,
|
| 79 |
+
)
|
| 80 |
+
assert len(episode) == 5
|
| 81 |
+
|
| 82 |
+
# We can now access information from the episode via the getter APIs.
|
| 83 |
+
|
| 84 |
+
# Get the last 3 rewards (in a batch of size 3).
|
| 85 |
+
episode.get_rewards(slice(-3, None)) # same as `episode.rewards[-3:]`
|
| 86 |
+
|
| 87 |
+
# Get the most recent action (single item, not batched).
|
| 88 |
+
# This works regardless of the action space or whether the episode has
|
| 89 |
+
# been numpy'ized or not (see below).
|
| 90 |
+
episode.get_actions(-1) # same as episode.actions[-1]
|
| 91 |
+
|
| 92 |
+
# Looking back from ts=1, get the previous 4 rewards AND fill with 0.0
|
| 93 |
+
# in case we go over the beginning (ts=0). So we would expect
|
| 94 |
+
# [0.0, 0.0, 0.0, r0] to be returned here, where r0 is the very first received
|
| 95 |
+
# reward in the episode:
|
| 96 |
+
episode.get_rewards(slice(-4, 0), neg_index_as_lookback=True, fill=0.0)
|
| 97 |
+
|
| 98 |
+
# Note the use of fill=0.0 here (fill everything that's out of range with this
|
| 99 |
+
# value) AND the argument `neg_index_as_lookback=True`, which interprets
|
| 100 |
+
# negative indices as being left of ts=0 (e.g. -1 being the timestep before
|
| 101 |
+
# ts=0).
|
| 102 |
+
|
| 103 |
+
# Assuming we had a complex action space (nested gym.spaces.Dict) with one or
|
| 104 |
+
# more elements being Discrete or MultiDiscrete spaces:
|
| 105 |
+
# 1) The `fill=...` argument would still work, filling all spaces (Boxes,
|
| 106 |
+
# Discrete) with that provided value.
|
| 107 |
+
# 2) Setting the flag `one_hot_discrete=True` would convert those discrete
|
| 108 |
+
# sub-components automatically into one-hot (or multi-one-hot) tensors.
|
| 109 |
+
# This simplifies the task of having to provide the previous 4 (nested and
|
| 110 |
+
# partially discrete/multi-discrete) actions for each timestep within a training
|
| 111 |
+
# batch, thereby filling timesteps before the episode started with 0.0s and
|
| 112 |
+
# one-hot'ing the discrete/multi-discrete components in these actions:
|
| 113 |
+
episode = SingleAgentEpisode(action_space=gym.spaces.Dict({
|
| 114 |
+
"a": gym.spaces.Discrete(3),
|
| 115 |
+
"b": gym.spaces.MultiDiscrete([2, 3]),
|
| 116 |
+
"c": gym.spaces.Box(-1.0, 1.0, (2,)),
|
| 117 |
+
}))
|
| 118 |
+
|
| 119 |
+
# ... fill episode with data ...
|
| 120 |
+
episode.add_env_reset(observation=0)
|
| 121 |
+
# ... from a few steps.
|
| 122 |
+
episode.add_env_step(
|
| 123 |
+
observation=1,
|
| 124 |
+
action={"a":0, "b":np.array([1, 2]), "c":np.array([.5, -.5], np.float32)},
|
| 125 |
+
reward=1.0,
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# In your connector
|
| 129 |
+
prev_4_a = []
|
| 130 |
+
# Note here that len(episode) does NOT include the lookback buffer.
|
| 131 |
+
for ts in range(len(episode)):
|
| 132 |
+
prev_4_a.append(
|
| 133 |
+
episode.get_actions(
|
| 134 |
+
indices=slice(ts - 4, ts),
|
| 135 |
+
# Make sure negative indices are interpreted as
|
| 136 |
+
# "into lookback buffer"
|
| 137 |
+
neg_index_as_lookback=True,
|
| 138 |
+
# Zero-out everything even further before the lookback buffer.
|
| 139 |
+
fill=0.0,
|
| 140 |
+
# Take care of discrete components (get ready as NN input).
|
| 141 |
+
one_hot_discrete=True,
|
| 142 |
+
)
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# Finally, convert from list of batch items to a struct (same as action space)
|
| 146 |
+
# of batched (numpy) arrays, in which all leafs have B==len(prev_4_a).
|
| 147 |
+
from ray.rllib.utils.spaces.space_utils import batch
|
| 148 |
+
|
| 149 |
+
prev_4_actions_col = batch(prev_4_a)
|
| 150 |
+
"""
|
| 151 |
+
|
| 152 |
+
__slots__ = (
|
| 153 |
+
"actions",
|
| 154 |
+
"agent_id",
|
| 155 |
+
"extra_model_outputs",
|
| 156 |
+
"id_",
|
| 157 |
+
"infos",
|
| 158 |
+
"is_terminated",
|
| 159 |
+
"is_truncated",
|
| 160 |
+
"module_id",
|
| 161 |
+
"multi_agent_episode_id",
|
| 162 |
+
"observations",
|
| 163 |
+
"rewards",
|
| 164 |
+
"t",
|
| 165 |
+
"t_started",
|
| 166 |
+
"_action_space",
|
| 167 |
+
"_last_added_observation",
|
| 168 |
+
"_last_added_infos",
|
| 169 |
+
"_last_step_time",
|
| 170 |
+
"_observation_space",
|
| 171 |
+
"_start_time",
|
| 172 |
+
"_temporary_timestep_data",
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
def __init__(
    self,
    id_: Optional[str] = None,
    *,
    observations: Optional[Union[List[ObsType], InfiniteLookbackBuffer]] = None,
    observation_space: Optional[gym.Space] = None,
    infos: Optional[Union[List[Dict], InfiniteLookbackBuffer]] = None,
    actions: Optional[Union[List[ActType], InfiniteLookbackBuffer]] = None,
    action_space: Optional[gym.Space] = None,
    rewards: Optional[Union[List[SupportsFloat], InfiniteLookbackBuffer]] = None,
    terminated: bool = False,
    truncated: bool = False,
    extra_model_outputs: Optional[Dict[str, Any]] = None,
    t_started: Optional[int] = None,
    len_lookback_buffer: Union[int, str] = "auto",
    agent_id: Optional[AgentID] = None,
    module_id: Optional[ModuleID] = None,
    multi_agent_episode_id: Optional[int] = None,
):
    """Initializes a SingleAgentEpisode instance.

    This constructor can be called with or without already sampled data, part of
    which might then go into the lookback buffer.

    Args:
        id_: Unique identifier for this episode. If no ID is provided the
            constructor generates a unique hexadecimal code for the id.
        observations: Either a list of individual observations from a sampling or
            an already instantiated `InfiniteLookbackBuffer` object (possibly
            with observation data in it). If a list, will construct the buffer
            automatically (given the data and the `len_lookback_buffer` argument).
        observation_space: An optional gym.Space, which all individual observations
            should abide to. If not None and this SingleAgentEpisode is numpy'ized
            (via the `self.to_numpy()` method), and data is appended or set, the new
            data will be checked for correctness.
        infos: Either a list of individual info dicts from a sampling or
            an already instantiated `InfiniteLookbackBuffer` object (possibly
            with info dicts in it). If a list, will construct the buffer
            automatically (given the data and the `len_lookback_buffer` argument).
        actions: Either a list of individual actions from a sampling or
            an already instantiated `InfiniteLookbackBuffer` object (possibly
            with action data in it). If a list, will construct the buffer
            automatically (given the data and the `len_lookback_buffer` argument).
        action_space: An optional gym.Space, which all individual actions
            should abide to. If not None and this SingleAgentEpisode is numpy'ized
            (via the `self.to_numpy()` method), and data is appended or set, the new
            data will be checked for correctness.
        rewards: Either a list of individual rewards from a sampling or
            an already instantiated `InfiniteLookbackBuffer` object (possibly
            with reward data in it). If a list, will construct the buffer
            automatically (given the data and the `len_lookback_buffer` argument).
        extra_model_outputs: A dict mapping string keys to either lists of
            individual extra model output tensors (e.g. `action_logp` or
            `state_outs`) from a sampling or to already instantiated
            `InfiniteLookbackBuffer` object (possibly with extra model output data
            in it). If mapping is to lists, will construct the buffers automatically
            (given the data and the `len_lookback_buffer` argument).
        terminated: A boolean indicating, if the episode is already terminated.
        truncated: A boolean indicating, if the episode has been truncated.
        t_started: Optional. The starting timestep of the episode. The default
            is zero. If data is provided, the starting point is from the last
            observation onwards (i.e. `t_started = len(observations) - 1`). If
            this parameter is provided the episode starts at the provided value.
        len_lookback_buffer: The size of the (optional) lookback buffers to keep in
            front of this Episode for each type of data (observations, actions,
            etc..). If larger than 0, will interpret the first `len_lookback_buffer`
            items in each type of data as NOT part of this actual
            episode chunk, but instead serve as "historical" record that may be
            viewed and used to derive new data from. For example, it might be
            necessary to have a lookback buffer of four if you would like to do
            observation frame stacking and your episode has been cut and you are now
            operating on a new chunk (continuing from the cut one). Then, for the
            first 3 items, you would have to be able to look back into the old
            chunk's data.
            If `len_lookback_buffer` is "auto" (default), will interpret all
            provided data in the constructor as part of the lookback buffers.
        agent_id: An optional AgentID indicating which agent this episode belongs
            to. This information is stored under `self.agent_id` and only serves
            reference purposes.
        module_id: An optional ModuleID indicating which RLModule this episode
            belongs to. Normally, this information is obtained by querying an
            `agent_to_module_mapping_fn` with a given agent ID. This information
            is stored under `self.module_id` and only serves reference purposes.
        multi_agent_episode_id: An optional EpisodeID of the encapsulating
            `MultiAgentEpisode` that this `SingleAgentEpisode` belongs to.
    """
    self.id_ = id_ or uuid.uuid4().hex

    self.agent_id = agent_id
    self.module_id = module_id
    self.multi_agent_episode_id = multi_agent_episode_id

    # Lookback buffer length is not provided. Interpret already given data as
    # lookback buffer lengths for all data types.
    # NOTE: `len(rewards)` (not observations) serves as the reference length,
    # because rewards count actual env steps (obs have one extra, the reset obs).
    len_rewards = len(rewards) if rewards is not None else 0
    if len_lookback_buffer == "auto" or len_lookback_buffer > len_rewards:
        len_lookback_buffer = len_rewards

    # Default infos to one empty dict per provided observation.
    infos = infos or [{} for _ in range(len(observations or []))]

    # Observations: t0 (initial obs) to T.
    self._observation_space = None
    if isinstance(observations, InfiniteLookbackBuffer):
        self.observations = observations
    else:
        self.observations = InfiniteLookbackBuffer(
            data=observations,
            lookback=len_lookback_buffer,
        )
    self.observation_space = observation_space
    # Infos: t0 (initial info) to T.
    if isinstance(infos, InfiniteLookbackBuffer):
        self.infos = infos
    else:
        self.infos = InfiniteLookbackBuffer(
            data=infos,
            lookback=len_lookback_buffer,
        )
    # Actions: t1 to T.
    self._action_space = None
    if isinstance(actions, InfiniteLookbackBuffer):
        self.actions = actions
    else:
        self.actions = InfiniteLookbackBuffer(
            data=actions,
            lookback=len_lookback_buffer,
        )
    self.action_space = action_space
    # Rewards: t1 to T.
    if isinstance(rewards, InfiniteLookbackBuffer):
        self.rewards = rewards
    else:
        self.rewards = InfiniteLookbackBuffer(
            data=rewards,
            lookback=len_lookback_buffer,
            # Rewards are always scalar floats.
            space=gym.spaces.Box(float("-inf"), float("inf"), (), np.float32),
        )

    # obs[-1] is the final observation in the episode.
    self.is_terminated = terminated
    # obs[-1] is the last obs in a truncated-by-the-env episode (there will no more
    # observations in following chunks for this episode).
    self.is_truncated = truncated

    # Extra model outputs, e.g. `action_dist_input` needed in the batch.
    self.extra_model_outputs = {}
    for k, v in (extra_model_outputs or {}).items():
        if isinstance(v, InfiniteLookbackBuffer):
            self.extra_model_outputs[k] = v
        else:
            # We cannot use the defaultdict's own constructor here as this would
            # auto-set the lookback buffer to 0 (there is no data passed to that
            # constructor). Then, when we manually have to set the data property,
            # the lookback buffer would still be (incorrectly) 0.
            self.extra_model_outputs[k] = InfiniteLookbackBuffer(
                data=v, lookback=len_lookback_buffer
            )

    # The (global) timestep when this episode (possibly an episode chunk) started,
    # excluding a possible lookback buffer.
    self.t_started = t_started or 0
    # The current (global) timestep in the episode (possibly an episode chunk).
    self.t = len(self.rewards) + self.t_started

    # Caches for temporary per-timestep data. May be used to store custom metrics
    # from within a callback for the ongoing episode (e.g. render images).
    self._temporary_timestep_data = defaultdict(list)

    # Keep timer stats on deltas between steps.
    self._start_time = None
    self._last_step_time = None

    self._last_added_observation = None
    self._last_added_infos = None

    # Validate the episode data thus far.
    self.validate()
|
| 352 |
+
|
| 353 |
+
def add_env_reset(
    self,
    observation: ObsType,
    infos: Optional[Dict] = None,
) -> None:
    """Adds the initial data (after an `env.reset()`) to the episode.

    This data consists of initial observations and initial infos.

    Args:
        observation: The initial observation returned by `env.reset()`.
        infos: An (optional) info dict returned by `env.reset()`.
    """
    # A reset may only be recorded once, on a completely fresh, non-done
    # episode (no observations yet, timestep counters still at 0).
    assert not self.is_reset
    assert not self.is_done
    assert len(self.observations) == 0
    assert self.t == self.t_started == 0

    infos = infos or {}

    # If a space was configured, enforce that the reset obs fits it.
    if self.observation_space is not None:
        assert self.observation_space.contains(observation), (
            f"`observation` {observation} does NOT fit SingleAgentEpisode's "
            f"observation_space: {self.observation_space}!"
        )

    # Record the reset data.
    self.observations.append(observation)
    self.infos.append(infos)

    # Remember what was added last (for callbacks/inspection).
    self._last_added_observation = observation
    self._last_added_infos = infos

    # Validate our data.
    self.validate()

    # Start the timer for this episode.
    self._start_time = time.perf_counter()
|
| 392 |
+
|
| 393 |
+
def add_env_step(
    self,
    observation: ObsType,
    action: ActType,
    reward: SupportsFloat,
    infos: Optional[Dict[str, Any]] = None,
    *,
    terminated: bool = False,
    truncated: bool = False,
    extra_model_outputs: Optional[Dict[str, Any]] = None,
) -> None:
    """Adds results of an `env.step()` call (including the action) to this episode.

    This data consists of an observation and info dict, an action, a reward,
    terminated/truncated flags, and extra model outputs (e.g. action probabilities
    or RNN internal state outputs).

    Args:
        observation: The next observation received from the environment after(!)
            taking `action`.
        action: The last action used by the agent during the call to `env.step()`.
        reward: The last reward received by the agent after taking `action`.
        infos: The last info received from the environment after taking `action`.
        terminated: A boolean indicating, if the environment has been
            terminated (after taking `action`).
        truncated: A boolean indicating, if the environment has been
            truncated (after taking `action`).
        extra_model_outputs: The last timestep's specific model outputs.
            These are normally outputs of an RLModule that were computed along with
            `action`, e.g. `action_logp` or `action_dist_inputs`.
    """
    # Cannot add data to an already done episode.
    assert (
        not self.is_done
    ), "The agent is already done: no data can be added to its episode."

    self.observations.append(observation)
    self.actions.append(action)
    self.rewards.append(reward)
    infos = infos or {}
    self.infos.append(infos)
    self.t += 1
    if extra_model_outputs is not None:
        for k, v in extra_model_outputs.items():
            # Lazily create a buffer for keys not seen before.
            if k not in self.extra_model_outputs:
                self.extra_model_outputs[k] = InfiniteLookbackBuffer([v])
            else:
                self.extra_model_outputs[k].append(v)
    self.is_terminated = terminated
    self.is_truncated = truncated

    self._last_added_observation = observation
    self._last_added_infos = infos

    # Only check spaces if numpy'ized AND every n timesteps.
    # BUGFIX: This used to be `self.t % 100` (truthy on every timestep EXCEPT
    # multiples of 100), which inverted the intent of sampling the (expensive)
    # space checks only occasionally. `== 0` checks once every 100 steps.
    if self.is_numpy and self.t % 100 == 0:
        if self.observation_space is not None:
            assert self.observation_space.contains(observation), (
                f"`observation` {observation} does NOT fit SingleAgentEpisode's "
                f"observation_space: {self.observation_space}!"
            )
        if self.action_space is not None:
            assert self.action_space.contains(action), (
                f"`action` {action} does NOT fit SingleAgentEpisode's "
                f"action_space: {self.action_space}!"
            )

    # Validate our data.
    self.validate()

    # Step time stats.
    self._last_step_time = time.perf_counter()
    if self._start_time is None:
        self._start_time = self._last_step_time
|
| 467 |
+
|
| 468 |
+
def validate(self) -> None:
    """Validates the episode's data.

    This function ensures that the data stored to a `SingleAgentEpisode` is
    in order (e.g. that the correct number of observations, actions, rewards
    are there).
    """
    num_obs = len(self.observations)
    # Observations and infos must always come in pairs.
    assert num_obs == len(self.infos)
    if num_obs == 0:
        # Completely empty episode: no actions/rewards/extra outputs either.
        assert len(self.infos) == len(self.rewards) == len(self.actions) == 0
        for key, buf in self.extra_model_outputs.items():
            assert len(buf) == 0, (key, buf, buf.data, len(buf))
    else:
        # Non-empty: exactly one more obs (and info) than actions/rewards, due
        # to the reset/last-obs logic of an MDP.
        assert (
            num_obs
            == len(self.infos)
            == len(self.rewards) + 1
            == len(self.actions) + 1
        ), (
            num_obs,
            len(self.infos),
            len(self.rewards),
            len(self.actions),
        )
        # Every extra-model-output buffer tracks the action count.
        for key, buf in self.extra_model_outputs.items():
            assert len(buf) == num_obs - 1
|
| 496 |
+
|
| 497 |
+
@property
def is_reset(self) -> bool:
    """Returns True if `self.add_env_reset()` has already been called."""
    # A reset leaves at least the initial observation in the buffer.
    return bool(len(self.observations))
|
| 501 |
+
|
| 502 |
+
@property
def is_numpy(self) -> bool:
    """True, if the data in this episode is already stored as numpy arrays."""
    # The rewards buffer's `finalized` flag serves as the indicator for the
    # whole episode (all buffers get finalized together in `to_numpy()`).
    finalized = self.rewards.finalized
    return finalized
|
| 508 |
+
|
| 509 |
+
@property
def is_done(self) -> bool:
    """Whether the episode is actually done (terminated or truncated).

    A done episode cannot be continued via `self.add_timestep()` or being
    concatenated on its right-side with another episode chunk or being
    succeeded via `self.create_successor()`.
    """
    # Either flag alone marks the episode as finished.
    if self.is_terminated:
        return True
    return self.is_truncated
|
| 518 |
+
|
| 519 |
+
def to_numpy(self) -> "SingleAgentEpisode":
    """Converts this Episode's list attributes to numpy arrays.

    Each list of (possibly complex) data (e.g. for a dict obs space) is turned
    into a (possibly complex) struct whose leafs are numpy arrays, each with the
    same length (batch dimension) as the original list.

    Note that the data under the Columns.INFOS are NEVER numpy'ized and will
    remain a list (normally, a list of the original, env-returned dicts), due to
    the heterogeneous nature of INFOS returned by envs.

    After calling this method, no further data may be added to this episode via
    the `self.add_env_step()` method.

    Returns:
        This `SingleAgentEpisode` object with the converted numpy data.
    """
    # Observations exist even for a length-0 episode (the reset obs).
    self.observations.finalize()
    # Actions/rewards/extra outputs only exist once at least one step was taken.
    if len(self) > 0:
        self.actions.finalize()
        self.rewards.finalize()
        for buffer in self.extra_model_outputs.values():
            buffer.finalize()

    return self
|
| 592 |
+
|
| 593 |
+
def concat_episode(self, other: "SingleAgentEpisode") -> None:
    """Adds the given `other` SingleAgentEpisode to the right side of self.

    In order for this to work, both chunks (`self` and `other`) must fit
    together. This is checked by the IDs (must be identical), the time step
    counters (`self.t` must be the same as `other.t_started`), as well as the
    observations at the concatenation boundary. Also, `self.is_done` must
    not be True, meaning `self.is_terminated` and `self.is_truncated` are both
    False.

    Args:
        other: The other `SingleAgentEpisode` to be concatenated to this one.

    Returns:
        None. `self` is mutated in place and afterwards contains the
        concatenated data from both episodes (`self` and `other`).
    """
    assert other.id_ == self.id_
    # NOTE (sven): This is what we agreed on. As the replay buffers must be
    # able to concatenate.
    assert not self.is_done
    # Make sure the timesteps match.
    assert self.t == other.t_started
    # Validate `other`.
    other.validate()

    # Make sure, end matches other episode chunk's beginning.
    assert np.all(other.observations[0] == self.observations[-1])
    # Pop out our last observations and infos (as these are identical
    # to the first obs and infos in the next episode).
    self.observations.pop()
    self.infos.pop()

    # Extend ourselves. In case, episode_chunk is already terminated and numpy'ized
    # we need to convert to lists (as we are ourselves still filling up lists).
    self.observations.extend(other.get_observations())
    self.actions.extend(other.get_actions())
    self.rewards.extend(other.get_rewards())
    self.infos.extend(other.get_infos())
    self.t = other.t

    # Carry over the done-flags from the appended (right-side) chunk.
    if other.is_terminated:
        self.is_terminated = True
    elif other.is_truncated:
        self.is_truncated = True

    # NOTE(review): only keys present in `other` are checked/extended here; if
    # `self` had extra keys NOT in `other`, the subsequent `validate()` call
    # would fail on the length mismatch.
    for key in other.extra_model_outputs.keys():
        assert key in self.extra_model_outputs
        self.extra_model_outputs[key].extend(other.get_extra_model_outputs(key))

    # Validate.
    self.validate()
|
| 644 |
+
|
| 645 |
+
def cut(self, len_lookback_buffer: int = 0) -> "SingleAgentEpisode":
    """Returns a successor episode chunk (of len=0) continuing from this Episode.

    The successor will have the same ID as `self`.
    If no lookback buffer is requested (len_lookback_buffer=0), the successor's
    observations will be the last observation(s) of `self` and its length will
    therefore be 0 (no further steps taken yet). If `len_lookback_buffer` > 0,
    the returned successor will have `len_lookback_buffer` observations (and
    actions, rewards, etc..) taken from the right side (end) of `self`. For example
    if `len_lookback_buffer=2`, the returned successor's lookback buffer actions
    will be identical to `self.actions[-2:]`.

    This method is useful if you would like to discontinue building an episode
    chunk (b/c you have to return it from somewhere), but would like to have a new
    episode instance to continue building the actual gym.Env episode at a later
    time. Via the `len_lookback_buffer` argument, the continuing chunk (successor)
    will still be able to "look back" into this predecessor episode's data (at
    least to some extent, depending on the value of `len_lookback_buffer`).

    Args:
        len_lookback_buffer: The number of timesteps to take along into the new
            chunk as "lookback buffer". A lookback buffer is additional data on
            the left side of the actual episode data for visibility purposes
            (but without actually being part of the new chunk). For example, if
            `self` ends in actions 5, 6, 7, and 8, and we call
            `self.cut(len_lookback_buffer=2)`, the returned chunk will have
            actions 7 and 8 already in it, but still `t_started`==t==8 (not 7!) and
            a length of 0. If there is not enough data in `self` yet to fulfil
            the `len_lookback_buffer` request, the value of `len_lookback_buffer`
            is automatically adjusted (lowered).

    Returns:
        The successor Episode chunk of this one with the same ID and state and the
        only observation being the last observation in self.
    """
    assert not self.is_done and len_lookback_buffer >= 0

    # Initialize this chunk with the most recent obs and infos (even if lookback is
    # 0). Similar to an initial `env.reset()`.
    indices_obs_and_infos = slice(-len_lookback_buffer - 1, None)
    # Actions/rewards/etc. have one fewer item than obs/infos; `slice(None, 0)`
    # deliberately selects nothing when no lookback is requested.
    indices_rest = (
        slice(-len_lookback_buffer, None)
        if len_lookback_buffer > 0
        else slice(None, 0)
    )

    # Erase all temporary timestep data caches in `self`.
    self._temporary_timestep_data.clear()

    return SingleAgentEpisode(
        # Same ID.
        id_=self.id_,
        observations=self.get_observations(indices=indices_obs_and_infos),
        observation_space=self.observation_space,
        infos=self.get_infos(indices=indices_obs_and_infos),
        actions=self.get_actions(indices=indices_rest),
        action_space=self.action_space,
        rewards=self.get_rewards(indices=indices_rest),
        extra_model_outputs={
            k: self.get_extra_model_outputs(k, indices_rest)
            for k in self.extra_model_outputs.keys()
        },
        # Continue with self's current timestep.
        t_started=self.t,
        # Use the length of the provided data as lookback buffer.
        len_lookback_buffer="auto",
    )
|
| 712 |
+
|
| 713 |
+
# TODO (sven): Distinguish between:
|
| 714 |
+
# - global index: This is the absolute, global timestep whose values always
|
| 715 |
+
# start from 0 (at the env reset). So doing get_observations(0, global_ts=True)
|
| 716 |
+
# should always return the exact 1st observation (reset obs), no matter what. In
|
| 717 |
+
# case we are in an episode chunk and `fill` or a sufficient lookback buffer is
|
| 718 |
+
# provided, this should yield a result. Otherwise, error.
|
| 719 |
+
# - global index=False -> indices are relative to the chunk start. If a chunk has
|
| 720 |
+
# t_started=6 and we ask for index=0, then return observation at timestep 6
|
| 721 |
+
# (t_started).
|
| 722 |
+
def get_observations(
    self,
    indices: Optional[Union[int, List[int], slice]] = None,
    *,
    neg_index_as_lookback: bool = False,
    fill: Optional[Any] = None,
    one_hot_discrete: bool = False,
) -> Any:
    """Returns individual observations or batched ranges thereof from this episode.

    Args:
        indices: A single int (returns the individual observation at that index),
            a list of ints (returns a batch of size len(indices)), or a slice
            object (returns the corresponding range). Negative indices are by
            default interpreted as "before the end", unless
            `neg_index_as_lookback=True`, in which case they mean "before ts=0"
            (i.e. reaching into the lookback buffer). If None, returns all
            observations (from ts=0 to the end).
        neg_index_as_lookback: If True, negative values in `indices` are
            interpreted as "before ts=0", meaning going back into the lookback
            buffer.
        fill: An optional value to use for filling up the returned results at
            the boundaries, whenever the requested index range exceeds the
            episode's boundaries (including the lookback buffer on the left
            side). Handy for zero-padding without boundary bookkeeping.
        one_hot_discrete: If True, returns one-hot vectors (instead of
            int-values) for those sub-components of a (possibly complex)
            observation space that are Discrete or MultiDiscrete. Note that with
            `fill=0` and out-of-range `indices`, the returned one-hot vectors
            are actually zero-hot (all slots zero).

    Returns:
        The collected observations: a 0-axis batch for list/slice `indices`,
        a single item (no extra 0-axis) for a single int.
    """
    # Pure delegation: the observation buffer implements all of the
    # index/fill/one-hot semantics described above.
    buffer = self.observations
    return buffer.get(
        indices=indices,
        neg_index_as_lookback=neg_index_as_lookback,
        fill=fill,
        one_hot_discrete=one_hot_discrete,
    )
|
| 820 |
+
|
| 821 |
+
def get_infos(
    self,
    indices: Optional[Union[int, List[int], slice]] = None,
    *,
    neg_index_as_lookback: bool = False,
    fill: Optional[Any] = None,
) -> Any:
    """Returns individual info dicts or list (ranges) thereof from this episode.

    Args:
        indices: What to extract: a single int returns the one info dict stored
            at that index; a list of ints returns a list of the info dicts at
            those indices (size `len(indices)`); a slice object returns the
            corresponding range of info dicts. Negative values count from the
            end by default; with `neg_index_as_lookback=True` they instead
            address items *before* ts=0, i.e. inside the lookback buffer.
            If None, all info dicts from ts=0 to the end are returned.
        neg_index_as_lookback: If True, negative entries in `indices` reach
            into the lookback buffer ("before ts=0") instead of counting
            backward from the episode's end. E.g. with infos
            [{"l":4}, {"l":5}, {"l":6}, {"a":7}, {"b":8}, {"c":9}] and the
            first three items forming the lookback buffer (ts=0 item is
            {"a":7}), `get_infos(-1, neg_index_as_lookback=True)` yields
            `{"l":6}` and `get_infos(slice(-2, 1), neg_index_as_lookback=True)`
            yields `[{"l":5}, {"l":6}, {"a":7}]`.
        fill: Optional value used to pad the returned result wherever the
            requested index range reaches beyond the episode's boundaries
            (including the lookback buffer on the left), so callers need not
            worry about out-of-range requests. E.g. with infos
            [{"l":10}, {"l":11}, {"a":12}, {"b":13}, {"c":14}] and a lookback
            buffer of size 2, `get_infos(slice(-7, -2), fill={"o": 0.0})`
            yields `[{"o":0.0}, {"o":0.0}, {"l":10}, {"l":11}, {"a":12}]`.

    Examples:

        .. testcode::

            from ray.rllib.env.single_agent_episode import SingleAgentEpisode

            episode = SingleAgentEpisode(
                infos=[{"a":0}, {"b":1}, {"c":2}, {"d":3}],
                # The following is needed, but not relevant for this demo.
                observations=[0, 1, 2, 3], actions=[1, 2, 3], rewards=[1, 2, 3],
                len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
            )
            # Plain usage (`indices` arg only).
            episode.get_infos(-1)  # {"d":3}
            episode.get_infos(0)  # {"a":0}
            episode.get_infos([0, 2])  # [{"a":0},{"c":2}]
            episode.get_infos([-1, 0])  # [{"d":3},{"a":0}]
            episode.get_infos(slice(None, 2))  # [{"a":0},{"b":1}]
            episode.get_infos(slice(-2, None))  # [{"c":2},{"d":3}]
            # Using `fill=...` (requesting slices beyond the boundaries).
            # TODO (sven): This would require a space being provided. Maybe we can
            # skip this check for infos, which don't have a space anyways.
            # episode.get_infos(slice(-5, -3), fill={"o":-1})  # [{"o":-1},{"a":0}]
            # episode.get_infos(slice(3, 5), fill={"o":-2})  # [{"d":3},{"o":-2}]

    Returns:
        The collected info dicts: a 0-axis batch when `indices` is a list (even
        of one element) or a slice; a single item (no extra 0-axis) when
        `indices` is a single int.
    """
    # Pure delegation: the infos live in an infinite-lookback buffer that
    # implements all indexing/fill semantics documented above.
    infos_buffer = self.infos
    return infos_buffer.get(
        fill=fill,
        indices=indices,
        neg_index_as_lookback=neg_index_as_lookback,
    )
|
| 897 |
+
|
| 898 |
+
def get_actions(
    self,
    indices: Optional[Union[int, List[int], slice]] = None,
    *,
    neg_index_as_lookback: bool = False,
    fill: Optional[Any] = None,
    one_hot_discrete: bool = False,
) -> Any:
    """Returns individual actions or batched ranges thereof from this episode.

    Args:
        indices: What to extract: a single int returns the one action stored at
            that index; a list of ints returns a batch of size `len(indices)`
            gathered from those indices; a slice object returns the
            corresponding range of actions. Negative values count from the end
            by default; with `neg_index_as_lookback=True` they instead address
            items *before* ts=0, i.e. inside the lookback buffer.
            If None, all actions from ts=0 to the end are returned.
        neg_index_as_lookback: If True, negative entries in `indices` reach
            into the lookback buffer ("before ts=0") instead of counting
            backward from the episode's end. E.g. with actions
            [4, 5, 6, 7, 8, 9] and [4, 5, 6] forming the lookback buffer
            (ts=0 item is 7), `get_actions(-1, neg_index_as_lookback=True)`
            yields `6` and `get_actions(slice(-2, 1),
            neg_index_as_lookback=True)` yields `[5, 6, 7]`.
        fill: Optional value used to pad the returned result wherever the
            requested index range reaches beyond the episode's boundaries
            (including the lookback buffer on the left), so callers can
            zero-pad without worrying about out-of-range requests. E.g. with
            actions [10, 11, 12, 13, 14] and a lookback buffer of size 2,
            `get_actions(slice(-7, -2), fill=0.0)` yields
            `[0.0, 0.0, 10, 11, 12]`.
        one_hot_discrete: If True, returns one-hot vectors (instead of
            int-values) for those sub-components of a (possibly complex)
            action space that are Discrete or MultiDiscrete. Note that if
            `fill=0` and the requested `indices` lie outside our data, the
            returned one-hot vectors are actually zero-hot (all slots zero).

    Examples:

        .. testcode::

            import gymnasium as gym
            from ray.rllib.env.single_agent_episode import SingleAgentEpisode

            episode = SingleAgentEpisode(
                # Discrete(4) actions (ints between 0 and 4 (excl.))
                action_space=gym.spaces.Discrete(4),
                actions=[1, 2, 3],
                observations=[0, 1, 2, 3], rewards=[1, 2, 3],  # <- not relevant here
                len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
            )
            # Plain usage (`indices` arg only).
            episode.get_actions(-1)  # 3
            episode.get_actions(0)  # 1
            episode.get_actions([0, 2])  # [1, 3]
            episode.get_actions([-1, 0])  # [3, 1]
            episode.get_actions(slice(None, 2))  # [1, 2]
            episode.get_actions(slice(-2, None))  # [2, 3]
            # Using `fill=...` (requesting slices beyond the boundaries).
            episode.get_actions(slice(-5, -2), fill=-9)  # [-9, -9, 1, 2]
            episode.get_actions(slice(1, 5), fill=-7)  # [2, 3, -7, -7]
            # Using `one_hot_discrete=True`.
            episode.get_actions(1, one_hot_discrete=True)  # [0 0 1 0] (action=2)
            episode.get_actions(2, one_hot_discrete=True)  # [0 0 0 1] (action=3)
            episode.get_actions(
                slice(0, 2),
                one_hot_discrete=True,
            )  # [[0 1 0 0], [0 0 0 1]] (actions=1 and 3)
            # Special case: Using `fill=0.0` AND `one_hot_discrete=True`.
            episode.get_actions(
                -1,
                neg_index_as_lookback=True,  # -1 means one left of ts=0
                fill=0.0,
                one_hot_discrete=True,
            )  # [0 0 0 0] <- all 0s one-hot tensor (note difference to [1 0 0 0]!)

    Returns:
        The collected actions: a 0-axis batch when `indices` is a list (even of
        one element) or a slice; a single item (no extra 0-axis) when `indices`
        is a single int.
    """
    # Pure delegation: the actions live in an infinite-lookback buffer that
    # implements all indexing/fill/one-hot semantics documented above.
    actions_buffer = self.actions
    return actions_buffer.get(
        one_hot_discrete=one_hot_discrete,
        fill=fill,
        indices=indices,
        neg_index_as_lookback=neg_index_as_lookback,
    )
|
| 992 |
+
|
| 993 |
+
def get_rewards(
    self,
    indices: Optional[Union[int, List[int], slice]] = None,
    *,
    neg_index_as_lookback: bool = False,
    fill: Optional[float] = None,
) -> Any:
    """Returns individual rewards or batched ranges thereof from this episode.

    Args:
        indices: What to extract: a single int returns the one reward stored at
            that index; a list of ints returns a batch of size `len(indices)`
            gathered from those indices; a slice object returns the
            corresponding range of rewards. Negative values count from the end
            by default; with `neg_index_as_lookback=True` they instead address
            items *before* ts=0, i.e. inside the lookback buffer.
            If None, all rewards from ts=0 to the end are returned.
        neg_index_as_lookback: If True, negative entries in `indices` reach
            into the lookback buffer ("before ts=0") instead of counting
            backward from the episode's end. E.g. with rewards
            [4, 5, 6, 7, 8, 9] and [4, 5, 6] forming the lookback buffer
            (ts=0 item is 7), `get_rewards(-1, neg_index_as_lookback=True)`
            yields `6` and `get_rewards(slice(-2, 1),
            neg_index_as_lookback=True)` yields `[5, 6, 7]`.
        fill: Optional float used to pad the returned result wherever the
            requested index range reaches beyond the episode's boundaries
            (including the lookback buffer on the left), so callers can
            zero-pad without worrying about out-of-range requests. E.g. with
            rewards [10, 11, 12, 13, 14] and a lookback buffer of size 2,
            `get_rewards(slice(-7, -2), fill=0.0)` yields
            `[0.0, 0.0, 10, 11, 12]`.

    Examples:

        .. testcode::

            from ray.rllib.env.single_agent_episode import SingleAgentEpisode

            episode = SingleAgentEpisode(
                rewards=[1.0, 2.0, 3.0],
                observations=[0, 1, 2, 3], actions=[1, 2, 3],  # <- not relevant here
                len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
            )
            # Plain usage (`indices` arg only).
            episode.get_rewards(-1)  # 3.0
            episode.get_rewards(0)  # 1.0
            episode.get_rewards([0, 2])  # [1.0, 3.0]
            episode.get_rewards([-1, 0])  # [3.0, 1.0]
            episode.get_rewards(slice(None, 2))  # [1.0, 2.0]
            episode.get_rewards(slice(-2, None))  # [2.0, 3.0]
            # Using `fill=...` (requesting slices beyond the boundaries).
            episode.get_rewards(slice(-5, -2), fill=0.0)  # [0.0, 0.0, 1.0, 2.0]
            episode.get_rewards(slice(1, 5), fill=0.0)  # [2.0, 3.0, 0.0, 0.0]

    Returns:
        The collected rewards: a 0-axis batch when `indices` is a list (even of
        one element) or a slice; a single item (no extra 0-axis) when `indices`
        is a single int.
    """
    # Pure delegation: the rewards live in an infinite-lookback buffer that
    # implements all indexing/fill semantics documented above.
    rewards_buffer = self.rewards
    return rewards_buffer.get(
        fill=fill,
        indices=indices,
        neg_index_as_lookback=neg_index_as_lookback,
    )
|
| 1063 |
+
|
| 1064 |
+
def get_extra_model_outputs(
    self,
    key: str,
    indices: Optional[Union[int, List[int], slice]] = None,
    *,
    neg_index_as_lookback: bool = False,
    fill: Optional[Any] = None,
) -> Any:
    """Returns extra model outputs (under given key) from this episode.

    Args:
        key: The `key` within `self.extra_model_outputs` to extract data for.
        indices: A single int is interpreted as an index, from which to return an
            individual extra model output stored under `key` at index.
            A list of ints is interpreted as a list of indices from which to gather
            individual items in a batch of size len(indices).
            A slice object is interpreted as a range of extra model outputs to be
            returned. Thereby, negative indices by default are interpreted as
            "before the end" unless the `neg_index_as_lookback=True` option is
            used, in which case negative indices are interpreted as "before ts=0",
            meaning going back into the lookback buffer.
            If None, will return all extra model outputs (from ts=0 to the end).
        neg_index_as_lookback: If True, negative values in `indices` are
            interpreted as "before ts=0", meaning going back into the lookback
            buffer. For example, an episode with
            extra_model_outputs['a'] = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the
            lookback buffer range (ts=0 item is 7), will respond to
            `get_extra_model_outputs("a", -1, neg_index_as_lookback=True)` with
            `6` and to `get_extra_model_outputs("a", slice(-2, 1),
            neg_index_as_lookback=True)` with `[5, 6, 7]`.
        fill: An optional value to use for filling up the returned results at
            the boundaries. This filling only happens if the requested index
            range's start/stop boundaries exceed the episode's boundaries
            (including the lookback buffer on the left side). For example, an
            episode with extra_model_outputs["b"] = [10, 11, 12, 13, 14] and
            lookback buffer size of 2 (meaning `10` and `11` are part of the
            lookback buffer) will respond to
            `get_extra_model_outputs("b", slice(-7, -2), fill=0.0)` with
            `[0.0, 0.0, 10, 11, 12]`.
            TODO (sven): This would require a space being provided. Maybe we can
            automatically infer the space from existing data?

    Examples:

        .. testcode::

            from ray.rllib.env.single_agent_episode import SingleAgentEpisode

            episode = SingleAgentEpisode(
                extra_model_outputs={"mo": [1, 2, 3]},
                len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
                # The following is needed, but not relevant for this demo.
                observations=[0, 1, 2, 3], actions=[1, 2, 3], rewards=[1, 2, 3],
            )

            # Plain usage (`indices` arg only).
            episode.get_extra_model_outputs("mo", -1)  # 3
            episode.get_extra_model_outputs("mo", 1)  # 0
            episode.get_extra_model_outputs("mo", [0, 2])  # [1, 3]
            episode.get_extra_model_outputs("mo", [-1, 0])  # [3, 1]
            episode.get_extra_model_outputs("mo", slice(None, 2))  # [1, 2]
            episode.get_extra_model_outputs("mo", slice(-2, None))  # [2, 3]
            # Using `fill=...` (requesting slices beyond the boundaries).
            # TODO (sven): This would require a space being provided. Maybe we can
            # automatically infer the space from existing data?
            # episode.get_extra_model_outputs("mo", slice(-5, -2), fill=0)  # [0, 0, 1]
            # episode.get_extra_model_outputs("mo", slice(2, 5), fill=-1)  # [3, -1, -1]

    Returns:
        The collected extra_model_outputs[`key`].
        As a 0-axis batch, if there are several `indices` or a list of exactly one
        index provided OR `indices` is a slice object.
        As single item (B=0 -> no additional 0-axis) if `indices` is a single int.

    Raises:
        KeyError: If `key` is not present in `self.extra_model_outputs`.
        AssertionError: If the data stored under `key` is not an
            `InfiniteLookbackBuffer` (i.e. something was written into
            `self.extra_model_outputs` directly, bypassing the episode APIs).
    """
    value = self.extra_model_outputs[key]
    # The expected case is: `value` is a `InfiniteLookbackBuffer`.
    if isinstance(value, InfiniteLookbackBuffer):
        return value.get(
            indices=indices,
            neg_index_as_lookback=neg_index_as_lookback,
            fill=fill,
        )
    # TODO (sven): This does not seem to be solid yet. Users should NOT be able
    # to just write directly into our buffers. Instead, use:
    # `self.set_extra_model_outputs(key, new_data, at_indices=...)` and if key
    # is not known, add a new buffer to the `extra_model_outputs` dict.
    # NOTE: Raise explicitly instead of the previous bare `assert False`:
    # `python -O` strips assert statements, which would have silently enabled
    # the unsafe auto-wrap fallback (`InfiniteLookbackBuffer(value).get(...)`)
    # that the TODO above deems not solid. The exception type (AssertionError)
    # is kept for backward compatibility with any callers catching it.
    raise AssertionError(
        f"Value under key '{key}' in `self.extra_model_outputs` is not an "
        "`InfiniteLookbackBuffer`. Do not write into `extra_model_outputs` "
        "directly; use `set_extra_model_outputs()` instead."
    )
|
| 1160 |
+
|
| 1161 |
+
def set_observations(
    self,
    *,
    new_data,
    at_indices: Optional[Union[int, List[int], slice]] = None,
    neg_index_as_lookback: bool = False,
) -> None:
    """Overwrites all or some of this Episode's observations with the provided data.

    An episode's observation data cannot be written to directly: it is managed
    by an `InfiniteLookbackBuffer`. Individual, current observations normally
    enter the episode via `self.add_env_step` (or manually via
    `self.observations.append|extend()`). This method exists for
    postprocessing steps that must rewrite the entirety (or a slice) of the
    episode's observations.

    Args:
        new_data: The replacement observation data. Before the episode has been
            numpy'ized, this may be a list of individual observation(s); after
            numpy'ization it should be a (possibly complex) struct matching the
            observation space whose leafs have a batch size exactly matching
            the slice/segment selected by `at_indices`.
        at_indices: Which positions to overwrite: a single int (one
            observation), a list of ints (`new_data` must have size
            `len(at_indices)`), or a slice object (`new_data` must match the
            slice's size). Negative values count from the end by default;
            with `neg_index_as_lookback=True` they instead address items
            before ts=0 (inside the lookback buffer).
        neg_index_as_lookback: If True, negative values in `at_indices` reach
            into the lookback buffer ("before ts=0"). E.g. with
            observations = [4, 5, 6, 7, 8, 9] and [4, 5, 6] forming the
            lookback buffer (ts=0 item is 7),
            `set_observations(individual_observation, -1,
            neg_index_as_lookback=True)` overwrites the stored 6 with the
            provided "individual_observation".

    Raises:
        IndexError: If the provided `at_indices` do not match the size of
            `new_data`.
    """
    # Pure delegation: the buffer implements all overwrite semantics above.
    obs_buffer = self.observations
    obs_buffer.set(
        neg_index_as_lookback=neg_index_as_lookback,
        at_indices=at_indices,
        new_data=new_data,
    )
|
| 1215 |
+
|
| 1216 |
+
def set_actions(
    self,
    *,
    new_data,
    at_indices: Optional[Union[int, List[int], slice]] = None,
    neg_index_as_lookback: bool = False,
) -> None:
    """Overwrites all or some of this Episode's actions with the provided data.

    An episode's action data cannot be written to directly: it is managed by an
    `InfiniteLookbackBuffer`. Individual, current actions normally enter the
    episode via `self.add_env_step` (or manually via
    `self.actions.append|extend()`). This method exists for postprocessing
    steps that must rewrite the entirety (or a slice) of the episode's actions.

    Args:
        new_data: The replacement action data. Before the episode has been
            numpy'ized, this may be a list of individual action(s); after
            numpy'ization it should be a (possibly complex) struct matching the
            action space whose leafs have a batch size exactly matching the
            slice/segment selected by `at_indices`.
        at_indices: Which positions to overwrite: a single int (one action), a
            list of ints (`new_data` must have size `len(at_indices)`), or a
            slice object (`new_data` must match the slice's size). Negative
            values count from the end by default; with
            `neg_index_as_lookback=True` they instead address items before
            ts=0 (inside the lookback buffer).
        neg_index_as_lookback: If True, negative values in `at_indices` reach
            into the lookback buffer ("before ts=0"). E.g. with
            actions = [4, 5, 6, 7, 8, 9] and [4, 5, 6] forming the lookback
            buffer (ts=0 item is 7), `set_actions(individual_action, -1,
            neg_index_as_lookback=True)` overwrites the stored 6 with the
            provided "individual_action".

    Raises:
        IndexError: If the provided `at_indices` do not match the size of
            `new_data`.
    """
    # Pure delegation: the buffer implements all overwrite semantics above.
    actions_buffer = self.actions
    actions_buffer.set(
        neg_index_as_lookback=neg_index_as_lookback,
        at_indices=at_indices,
        new_data=new_data,
    )
|
| 1270 |
+
|
| 1271 |
+
def set_rewards(
    self,
    *,
    new_data,
    at_indices: Optional[Union[int, List[int], slice]] = None,
    neg_index_as_lookback: bool = False,
) -> None:
    """Overwrites all or some of this Episode's rewards with the provided data.

    An episode's reward data cannot be written to directly: it is managed by an
    `InfiniteLookbackBuffer`. Individual, current rewards normally enter the
    episode via `self.add_env_step` (or manually via
    `self.rewards.append|extend()`). This method exists for postprocessing
    steps that must rewrite the entirety (or a slice) of the episode's rewards.

    Args:
        new_data: The replacement reward data. Before the episode has been
            numpy'ized, this may be a list of individual reward(s); after
            numpy'ization it should be a np.ndarray with a length exactly
            matching the slice/segment selected by `at_indices`.
        at_indices: Which positions to overwrite: a single int (one reward), a
            list of ints (`new_data` must have size `len(at_indices)`), or a
            slice object (`new_data` must match the slice's size). Negative
            values count from the end by default; with
            `neg_index_as_lookback=True` they instead address items before
            ts=0 (inside the lookback buffer).
        neg_index_as_lookback: If True, negative values in `at_indices` reach
            into the lookback buffer ("before ts=0"). E.g. with
            rewards = [4, 5, 6, 7, 8, 9] and [4, 5, 6] forming the lookback
            buffer (ts=0 item is 7), `set_rewards(individual_reward, -1,
            neg_index_as_lookback=True)` overwrites the stored 6 with the
            provided "individual_reward".

    Raises:
        IndexError: If the provided `at_indices` do not match the size of
            `new_data`.
    """
    # Pure delegation: the buffer implements all overwrite semantics above.
    rewards_buffer = self.rewards
    rewards_buffer.set(
        neg_index_as_lookback=neg_index_as_lookback,
        at_indices=at_indices,
        new_data=new_data,
    )
|
| 1325 |
+
|
| 1326 |
+
def set_extra_model_outputs(
    self,
    *,
    key,
    new_data,
    at_indices: Optional[Union[int, List[int], slice]] = None,
    neg_index_as_lookback: bool = False,
) -> None:
    """Overwrites all or some of this Episode's extra model outputs with `new_data`.

    An episode's `extra_model_outputs` data cannot be written to directly: it
    is managed by `InfiniteLookbackBuffer` objects. Individual, current
    `extra_model_output` values normally enter the episode via
    `self.add_env_step` (or manually via
    `self.extra_model_outputs[key].append|extend()`). This method exists for
    postprocessing steps that must rewrite the entirety (or a slice) of an
    episode's `extra_model_outputs` under an existing key.

    NOTE(review): Inserting a brand-new key is NOT currently supported here;
    `key` must already exist in `self.extra_model_outputs` (otherwise an
    AssertionError is raised). TODO: support adding new keys by creating a new
    buffer in the `extra_model_outputs` dict.

    Args:
        key: The `key` within `self.extra_model_outputs` whose data to
            override. Must already exist.
        new_data: The new data to overwrite existing data with.
            This may be a list of individual value(s) in case this episode
            is still not numpy'ized yet. In case this episode has already been
            numpy'ized, this should be a np.ndarray with a length exactly
            the size of the to-be-overwritten slice or segment (provided by
            `at_indices`).
        at_indices: A single int is interpreted as one index, which to overwrite
            with `new_data` (which is expected to be a single value).
            A list of ints is interpreted as a list of indices, all of which to
            overwrite with `new_data` (which is expected to be of the same size
            as `len(at_indices)`).
            A slice object is interpreted as a range of indices to be overwritten
            with `new_data` (which is expected to be of the same size as the
            provided slice).
            Thereby, negative indices by default are interpreted as "before the
            end" unless the `neg_index_as_lookback=True` option is used, in
            which case negative indices are interpreted as "before ts=0",
            meaning going back into the lookback buffer.
        neg_index_as_lookback: If True, negative values in `at_indices` are
            interpreted as "before ts=0", meaning going back into the lookback
            buffer. For example, with data [4, 5, 6, 7, 8, 9] under `key`,
            where [4, 5, 6] is the lookback buffer range (ts=0 item is 7), a
            call with `at_indices=-1` and `neg_index_as_lookback=True`
            overwrites the stored 6 with the provided value.

    Raises:
        AssertionError: If `key` is not already present in
            `self.extra_model_outputs`.
        IndexError: If the provided `at_indices` do not match the size of
            `new_data`.
    """
    # Validate explicitly rather than with a bare `assert`, which `python -O`
    # strips; the explicit raise keeps the check active in optimized runs while
    # preserving the exception type (AssertionError) for existing callers.
    if key not in self.extra_model_outputs:
        raise AssertionError(
            f"`key` ({key!r}) not found in `self.extra_model_outputs`; "
            "inserting new keys via `set_extra_model_outputs()` is not "
            "supported yet."
        )
    # Record already exists -> Set existing record's data to new values.
    self.extra_model_outputs[key].set(
        new_data=new_data,
        at_indices=at_indices,
        neg_index_as_lookback=neg_index_as_lookback,
    )
|
| 1387 |
+
|
| 1388 |
+
def add_temporary_timestep_data(self, key: str, data: Any) -> None:
|
| 1389 |
+
"""Temporarily adds (until `to_numpy()` called) per-timestep data to self.
|
| 1390 |
+
|
| 1391 |
+
The given `data` is appended to a list (`self._temporary_timestep_data`), which
|
| 1392 |
+
is cleared upon calling `self.to_numpy()`. To get the thus-far accumulated
|
| 1393 |
+
temporary timestep data for a certain key, use the `get_temporary_timestep_data`
|
| 1394 |
+
API.
|
| 1395 |
+
Note that the size of the per timestep list is NOT checked or validated against
|
| 1396 |
+
the other, non-temporary data in this episode (like observations).
|
| 1397 |
+
|
| 1398 |
+
Args:
|
| 1399 |
+
key: The key under which to find the list to append `data` to. If `data` is
|
| 1400 |
+
the first data to be added for this key, start a new list.
|
| 1401 |
+
data: The data item (representing a single timestep) to be stored.
|
| 1402 |
+
"""
|
| 1403 |
+
if self.is_numpy:
|
| 1404 |
+
raise ValueError(
|
| 1405 |
+
"Cannot use the `add_temporary_timestep_data` API on an already "
|
| 1406 |
+
f"numpy'ized {type(self).__name__}!"
|
| 1407 |
+
)
|
| 1408 |
+
self._temporary_timestep_data[key].append(data)
|
| 1409 |
+
|
| 1410 |
+
def get_temporary_timestep_data(self, key: str) -> List[Any]:
|
| 1411 |
+
"""Returns all temporarily stored data items (list) under the given key.
|
| 1412 |
+
|
| 1413 |
+
Note that all temporary timestep data is erased/cleared when calling
|
| 1414 |
+
`self.to_numpy()`.
|
| 1415 |
+
|
| 1416 |
+
Returns:
|
| 1417 |
+
The current list storing temporary timestep data under `key`.
|
| 1418 |
+
"""
|
| 1419 |
+
if self.is_numpy:
|
| 1420 |
+
raise ValueError(
|
| 1421 |
+
"Cannot use the `get_temporary_timestep_data` API on an already "
|
| 1422 |
+
f"numpy'ized {type(self).__name__}! All temporary data has been erased "
|
| 1423 |
+
f"upon `{type(self).__name__}.to_numpy()`."
|
| 1424 |
+
)
|
| 1425 |
+
try:
|
| 1426 |
+
return self._temporary_timestep_data[key]
|
| 1427 |
+
except KeyError:
|
| 1428 |
+
raise KeyError(f"Key {key} not found in temporary timestep data!")
|
| 1429 |
+
|
| 1430 |
+
    def slice(
        self,
        slice_: slice,
        *,
        len_lookback_buffer: Optional[int] = None,
    ) -> "SingleAgentEpisode":
        """Returns a slice of this episode with the given slice object.

        For example, if `self` contains o0 (the reset observation), o1, o2, o3, and o4
        and the actions a1, a2, a3, and a4 (len of `self` is 4), then a call to
        `self.slice(slice(1, 3))` would return a new SingleAgentEpisode with
        observations o1, o2, and o3, and actions a2 and a3. Note here that there is
        always one observation more in an episode than there are actions (and rewards
        and extra model outputs) due to the initial observation received after an env
        reset.

        .. testcode::

            from ray.rllib.env.single_agent_episode import SingleAgentEpisode
            from ray.rllib.utils.test_utils import check

            # Generate a simple multi-agent episode.
            observations = [0, 1, 2, 3, 4, 5]
            actions = [1, 2, 3, 4, 5]
            rewards = [0.1, 0.2, 0.3, 0.4, 0.5]
            episode = SingleAgentEpisode(
                observations=observations,
                actions=actions,
                rewards=rewards,
                len_lookback_buffer=0,  # all given data is part of the episode
            )
            slice_1 = episode[:1]
            check(slice_1.observations, [0, 1])
            check(slice_1.actions, [1])
            check(slice_1.rewards, [0.1])

            slice_2 = episode[-2:]
            check(slice_2.observations, [3, 4, 5])
            check(slice_2.actions, [4, 5])
            check(slice_2.rewards, [0.4, 0.5])

        Args:
            slice_: The slice object to use for slicing. This should exclude the
                lookback buffer, which will be prepended automatically to the returned
                slice.
            len_lookback_buffer: If not None, forces the returned slice to try to have
                this number of timesteps in its lookback buffer (if available). If None
                (default), tries to make the returned slice's lookback as large as the
                current lookback buffer of this episode (`self`).

        Returns:
            The new SingleAgentEpisode representing the requested slice.
        """
        # Translate `slice_` into one that only contains 0-or-positive ints and will
        # NOT contain any None.
        start = slice_.start
        stop = slice_.stop

        # Start is None -> 0.
        if start is None:
            start = 0
        # Start is negative -> Interpret index as counting "from end".
        elif start < 0:
            start = len(self) + start

        # Stop is None -> Set stop to our len (one ts past last valid index).
        if stop is None:
            stop = len(self)
        # Stop is negative -> Interpret index as counting "from end".
        elif stop < 0:
            stop = len(self) + stop

        # A None step means step size 1.
        step = slice_.step if slice_.step is not None else 1

        # Figure out, whether slicing stops at the very end of this episode to know
        # whether `self.is_terminated/is_truncated` should be kept as-is.
        keep_done = stop == len(self)
        # Provide correct timestep- and pre-buffer information.
        t_started = self.t_started + start

        # For each buffer below, `_lb` is the lookback length the new slice should
        # have: the caller-requested `len_lookback_buffer`, or (by default) this
        # buffer's own current lookback. If that lookback would reach back further
        # than the data actually available before `start` (episode data before
        # `start` plus the existing lookback), clamp it down.
        _lb = (
            len_lookback_buffer
            if len_lookback_buffer is not None
            else self.observations.lookback
        )
        if (
            start >= 0
            and start - _lb < 0
            and self.observations.lookback < (_lb - start)
        ):
            _lb = self.observations.lookback + start
        # Observations: one extra item (stop + 1) b/c episodes always hold one more
        # observation than actions/rewards.
        observations = InfiniteLookbackBuffer(
            data=self.get_observations(
                slice(start - _lb, stop + 1, step),
                neg_index_as_lookback=True,
            ),
            lookback=_lb,
            space=self.observation_space,
        )

        _lb = (
            len_lookback_buffer
            if len_lookback_buffer is not None
            else self.infos.lookback
        )
        if start >= 0 and start - _lb < 0 and self.infos.lookback < (_lb - start):
            _lb = self.infos.lookback + start
        # Infos: like observations, there is one info dict per observation
        # (stop + 1).
        infos = InfiniteLookbackBuffer(
            data=self.get_infos(
                slice(start - _lb, stop + 1, step),
                neg_index_as_lookback=True,
            ),
            lookback=_lb,
        )

        _lb = (
            len_lookback_buffer
            if len_lookback_buffer is not None
            else self.actions.lookback
        )
        if start >= 0 and start - _lb < 0 and self.actions.lookback < (_lb - start):
            _lb = self.actions.lookback + start
        # Actions: exactly one per timestep -> slice ends at `stop` (no +1).
        actions = InfiniteLookbackBuffer(
            data=self.get_actions(
                slice(start - _lb, stop, step),
                neg_index_as_lookback=True,
            ),
            lookback=_lb,
            space=self.action_space,
        )

        _lb = (
            len_lookback_buffer
            if len_lookback_buffer is not None
            else self.rewards.lookback
        )
        if start >= 0 and start - _lb < 0 and self.rewards.lookback < (_lb - start):
            _lb = self.rewards.lookback + start
        # Rewards: exactly one per timestep -> slice ends at `stop` (no +1).
        rewards = InfiniteLookbackBuffer(
            data=self.get_rewards(
                slice(start - _lb, stop, step),
                neg_index_as_lookback=True,
            ),
            lookback=_lb,
        )

        # Extra model outputs: one entry per timestep (like actions/rewards); clamp
        # each key's lookback individually.
        extra_model_outputs = {}
        for k, v in self.extra_model_outputs.items():
            _lb = len_lookback_buffer if len_lookback_buffer is not None else v.lookback
            if start >= 0 and start - _lb < 0 and v.lookback < (_lb - start):
                _lb = v.lookback + start
            extra_model_outputs[k] = InfiniteLookbackBuffer(
                data=self.get_extra_model_outputs(
                    key=k,
                    indices=slice(start - _lb, stop, step),
                    neg_index_as_lookback=True,
                ),
                lookback=_lb,
            )

        return SingleAgentEpisode(
            id_=self.id_,
            # In the following, offset `start`s automatically by lookbacks.
            observations=observations,
            observation_space=self.observation_space,
            infos=infos,
            actions=actions,
            action_space=self.action_space,
            rewards=rewards,
            extra_model_outputs=extra_model_outputs,
            # Only keep the done flags if the slice reaches the episode's very end.
            terminated=(self.is_terminated if keep_done else False),
            truncated=(self.is_truncated if keep_done else False),
            t_started=t_started,
        )
|
| 1604 |
+
|
| 1605 |
+
    def get_data_dict(self):
        """Converts a SingleAgentEpisode into a data dict mapping str keys to data.

        The keys used are:
        Columns.EPS_ID, T, OBS, INFOS, ACTIONS, REWARDS, TERMINATEDS, TRUNCATEDS,
        and those in `self.extra_model_outputs`.

        Returns:
            A data dict mapping str keys to data records.
        """
        # One timestep entry per env step of this chunk (lookback excluded).
        t = list(range(self.t_started, self.t))
        # Done flags are False everywhere except (possibly) at the last timestep.
        terminateds = [False] * (len(self) - 1) + [self.is_terminated]
        truncateds = [False] * (len(self) - 1) + [self.is_truncated]
        eps_id = [self.id_] * len(self)

        # If this episode was numpy'ized, return np.ndarrays for the trivial
        # 1D columns as well (to match the other, already converted data).
        if self.is_numpy:
            t = np.array(t)
            terminateds = np.array(terminateds)
            truncateds = np.array(truncateds)
            eps_id = np.array(eps_id)

        return dict(
            {
                # Trivial 1D data (compiled above).
                Columns.TERMINATEDS: terminateds,
                Columns.TRUNCATEDS: truncateds,
                Columns.T: t,
                Columns.EPS_ID: eps_id,
                # Retrieve obs, infos, actions, rewards using our get_... APIs,
                # which return all relevant timesteps (excluding the lookback
                # buffer!). Slice off last obs and infos to have the same number
                # of them as we have actions and rewards.
                Columns.OBS: self.get_observations(slice(None, -1)),
                Columns.INFOS: self.get_infos(slice(None, -1)),
                Columns.ACTIONS: self.get_actions(),
                Columns.REWARDS: self.get_rewards(),
            },
            # All `extra_model_outs`: Same as obs: Use get_... API.
            **{
                k: self.get_extra_model_outputs(k)
                for k in self.extra_model_outputs.keys()
            },
        )
|
| 1648 |
+
|
| 1649 |
+
def get_sample_batch(self) -> SampleBatch:
|
| 1650 |
+
"""Converts this `SingleAgentEpisode` into a `SampleBatch`.
|
| 1651 |
+
|
| 1652 |
+
Returns:
|
| 1653 |
+
A SampleBatch containing all of this episode's data.
|
| 1654 |
+
"""
|
| 1655 |
+
return SampleBatch(self.get_data_dict())
|
| 1656 |
+
|
| 1657 |
+
def get_return(self) -> float:
|
| 1658 |
+
"""Calculates an episode's return, excluding the lookback buffer's rewards.
|
| 1659 |
+
|
| 1660 |
+
The return is computed by a simple sum, neglecting the discount factor.
|
| 1661 |
+
Note that if `self` is a continuation chunk (resulting from a call to
|
| 1662 |
+
`self.cut()`), the previous chunk's rewards are NOT counted and thus NOT
|
| 1663 |
+
part of the returned reward sum.
|
| 1664 |
+
|
| 1665 |
+
Returns:
|
| 1666 |
+
The sum of rewards collected during this episode, excluding possible data
|
| 1667 |
+
inside the lookback buffer and excluding possible data in a predecessor
|
| 1668 |
+
chunk.
|
| 1669 |
+
"""
|
| 1670 |
+
return sum(self.get_rewards())
|
| 1671 |
+
|
| 1672 |
+
def get_duration_s(self) -> float:
|
| 1673 |
+
"""Returns the duration of this Episode (chunk) in seconds."""
|
| 1674 |
+
if self._last_step_time is None:
|
| 1675 |
+
return 0.0
|
| 1676 |
+
return self._last_step_time - self._start_time
|
| 1677 |
+
|
| 1678 |
+
def env_steps(self) -> int:
|
| 1679 |
+
"""Returns the number of environment steps.
|
| 1680 |
+
|
| 1681 |
+
Note, this episode instance could be a chunk of an actual episode.
|
| 1682 |
+
|
| 1683 |
+
Returns:
|
| 1684 |
+
An integer that counts the number of environment steps this episode instance
|
| 1685 |
+
has seen.
|
| 1686 |
+
"""
|
| 1687 |
+
return len(self)
|
| 1688 |
+
|
| 1689 |
+
def agent_steps(self) -> int:
|
| 1690 |
+
"""Returns the number of agent steps.
|
| 1691 |
+
|
| 1692 |
+
Note, these are identical to the environment steps for a single-agent episode.
|
| 1693 |
+
|
| 1694 |
+
Returns:
|
| 1695 |
+
An integer counting the number of agent steps executed during the time this
|
| 1696 |
+
episode instance records.
|
| 1697 |
+
"""
|
| 1698 |
+
return self.env_steps()
|
| 1699 |
+
|
| 1700 |
+
def get_state(self) -> Dict[str, Any]:
|
| 1701 |
+
"""Returns the pickable state of an episode.
|
| 1702 |
+
|
| 1703 |
+
The data in the episode is stored into a dictionary. Note that episodes
|
| 1704 |
+
can also be generated from states (see `SingleAgentEpisode.from_state()`).
|
| 1705 |
+
|
| 1706 |
+
Returns:
|
| 1707 |
+
A dict containing all the data from the episode.
|
| 1708 |
+
"""
|
| 1709 |
+
infos = self.infos.get_state()
|
| 1710 |
+
infos["data"] = np.array([info if info else None for info in infos["data"]])
|
| 1711 |
+
return {
|
| 1712 |
+
"id_": self.id_,
|
| 1713 |
+
"agent_id": self.agent_id,
|
| 1714 |
+
"module_id": self.module_id,
|
| 1715 |
+
"multi_agent_episode_id": self.multi_agent_episode_id,
|
| 1716 |
+
# Note, all data is stored in `InfiniteLookbackBuffer`s.
|
| 1717 |
+
"observations": self.observations.get_state(),
|
| 1718 |
+
"actions": self.actions.get_state(),
|
| 1719 |
+
"rewards": self.rewards.get_state(),
|
| 1720 |
+
"infos": self.infos.get_state(),
|
| 1721 |
+
"extra_model_outputs": {
|
| 1722 |
+
k: v.get_state() if v else v
|
| 1723 |
+
for k, v in self.extra_model_outputs.items()
|
| 1724 |
+
}
|
| 1725 |
+
if len(self.extra_model_outputs) > 0
|
| 1726 |
+
else None,
|
| 1727 |
+
"is_terminated": self.is_terminated,
|
| 1728 |
+
"is_truncated": self.is_truncated,
|
| 1729 |
+
"t_started": self.t_started,
|
| 1730 |
+
"t": self.t,
|
| 1731 |
+
"_observation_space": gym_space_to_dict(self._observation_space)
|
| 1732 |
+
if self._observation_space
|
| 1733 |
+
else None,
|
| 1734 |
+
"_action_space": gym_space_to_dict(self._action_space)
|
| 1735 |
+
if self._action_space
|
| 1736 |
+
else None,
|
| 1737 |
+
"_start_time": self._start_time,
|
| 1738 |
+
"_last_step_time": self._last_step_time,
|
| 1739 |
+
"_temporary_timestep_data": dict(self._temporary_timestep_data)
|
| 1740 |
+
if len(self._temporary_timestep_data) > 0
|
| 1741 |
+
else None,
|
| 1742 |
+
}
|
| 1743 |
+
|
| 1744 |
+
    @staticmethod
    def from_state(state: Dict[str, Any]) -> "SingleAgentEpisode":
        """Creates a new `SingleAgentEpisode` instance from a state dict.

        Args:
            state: The state dict, as returned by `self.get_state()`.

        Returns:
            A new `SingleAgentEpisode` instance with the data from the state dict.
        """
        # Create an empty episode instance.
        episode = SingleAgentEpisode(id_=state["id_"])
        # Load all the data from the state dict into the episode.
        episode.agent_id = state["agent_id"]
        episode.module_id = state["module_id"]
        episode.multi_agent_episode_id = state["multi_agent_episode_id"]
        # Convert data back to `InfiniteLookbackBuffer`s.
        episode.observations = InfiniteLookbackBuffer.from_state(state["observations"])
        episode.actions = InfiniteLookbackBuffer.from_state(state["actions"])
        episode.rewards = InfiniteLookbackBuffer.from_state(state["rewards"])
        episode.infos = InfiniteLookbackBuffer.from_state(state["infos"])
        # Restore extra model outputs as a defaultdict whose factory creates new
        # buffers with the same lookback as the observations buffer. If no extra
        # model outputs were stored (None), start with an empty defaultdict.
        episode.extra_model_outputs = (
            defaultdict(
                functools.partial(
                    InfiniteLookbackBuffer, lookback=episode.observations.lookback
                ),
                {
                    k: InfiniteLookbackBuffer.from_state(v)
                    for k, v in state["extra_model_outputs"].items()
                },
            )
            if state["extra_model_outputs"]
            else defaultdict(
                functools.partial(
                    InfiniteLookbackBuffer, lookback=episode.observations.lookback
                ),
            )
        )
        episode.is_terminated = state["is_terminated"]
        episode.is_truncated = state["is_truncated"]
        episode.t_started = state["t_started"]
        episode.t = state["t"]
        # Spaces were serialized as dicts -> convert back to gym spaces (or None).
        episode._observation_space = (
            gym_space_from_dict(state["_observation_space"])
            if state["_observation_space"]
            else None
        )
        episode._action_space = (
            gym_space_from_dict(state["_action_space"])
            if state["_action_space"]
            else None
        )
        episode._start_time = state["_start_time"]
        episode._last_step_time = state["_last_step_time"]
        # Temporary timestep data may have been stored as None (when empty).
        episode._temporary_timestep_data = defaultdict(
            list, state["_temporary_timestep_data"] or {}
        )
        # Validate the episode.
        episode.validate()

        return episode
|
| 1806 |
+
|
| 1807 |
+
    @property
    def observation_space(self):
        # The (possibly None) gymnasium observation space of this episode.
        return self._observation_space

    @observation_space.setter
    def observation_space(self, value):
        # Keep the observations buffer's `space` in sync with the episode's.
        self._observation_space = self.observations.space = value
|
| 1814 |
+
|
| 1815 |
+
    @property
    def action_space(self):
        # The (possibly None) gymnasium action space of this episode.
        return self._action_space

    @action_space.setter
    def action_space(self, value):
        # Keep the actions buffer's `space` in sync with the episode's.
        self._action_space = self.actions.space = value
|
| 1822 |
+
|
| 1823 |
+
    def __len__(self) -> int:
        """Returning the length of an episode.

        The length of an episode is defined by the length of its data, excluding
        the lookback buffer data. The length is the number of timesteps an agent has
        stepped through an environment thus far.

        The length is 0 in case of an episode whose env has NOT been reset yet, but
        also 0 right after the `env.reset()` data has been added via
        `self.add_env_reset()`. Only after the first call to `env.step()` (and
        `self.add_env_step()`, the length will be 1.

        Returns:
            An integer, defining the length of an episode.
        """
        # `t` advances on each env step; `t_started` marks where this chunk began.
        return self.t - self.t_started
|
| 1839 |
+
|
| 1840 |
+
    def __repr__(self):
        """Returns a compact debug string: length, done flag, return, and ID."""
        return (
            f"SAEps(len={len(self)} done={self.is_done} "
            f"R={self.get_return()} id_={self.id_})"
        )
|
| 1845 |
+
|
| 1846 |
+
    def __getitem__(self, item: slice) -> "SingleAgentEpisode":
        """Enable squared bracket indexing- and slicing syntax, e.g. episode[-4:].

        Only slice objects are supported (delegates to `self.slice()`); single-int
        indexing raises NotImplementedError.
        """
        if isinstance(item, slice):
            return self.slice(slice_=item)
        else:
            raise NotImplementedError(
                f"SingleAgentEpisode does not support getting item '{item}'! "
                "Only slice objects allowed with the syntax: `episode[a:b]`."
            )
|
| 1855 |
+
|
| 1856 |
+
    # Hard-deprecated (error=True): calling this raises; use `is_numpy` instead.
    @Deprecated(new="SingleAgentEpisode.is_numpy()", error=True)
    def is_finalized(self):
        pass
|
| 1859 |
+
|
| 1860 |
+
    # Hard-deprecated (error=True): calling this raises; use `to_numpy()` instead.
    @Deprecated(new="SingleAgentEpisode.to_numpy()", error=True)
    def finalize(self):
        pass
|
.venv/lib/python3.11/site-packages/ray/rllib/env/tcp_client_inference_env_runner.py
ADDED
|
@@ -0,0 +1,589 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
import gzip
|
| 4 |
+
import json
|
| 5 |
+
import pathlib
|
| 6 |
+
import socket
|
| 7 |
+
import tempfile
|
| 8 |
+
import threading
|
| 9 |
+
import time
|
| 10 |
+
from typing import Collection, DefaultDict, List, Optional, Union
|
| 11 |
+
|
| 12 |
+
import gymnasium as gym
|
| 13 |
+
import numpy as np
|
| 14 |
+
import onnxruntime
|
| 15 |
+
|
| 16 |
+
from ray.rllib.core import (
|
| 17 |
+
Columns,
|
| 18 |
+
COMPONENT_RL_MODULE,
|
| 19 |
+
DEFAULT_AGENT_ID,
|
| 20 |
+
DEFAULT_MODULE_ID,
|
| 21 |
+
)
|
| 22 |
+
from ray.rllib.env import INPUT_ENV_SPACES
|
| 23 |
+
from ray.rllib.env.env_runner import EnvRunner
|
| 24 |
+
from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner
|
| 25 |
+
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
|
| 26 |
+
from ray.rllib.env.utils.external_env_protocol import RLlink as rllink
|
| 27 |
+
from ray.rllib.utils.annotations import ExperimentalAPI, override
|
| 28 |
+
from ray.rllib.utils.checkpoints import Checkpointable
|
| 29 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 30 |
+
from ray.rllib.utils.metrics import (
|
| 31 |
+
EPISODE_DURATION_SEC_MEAN,
|
| 32 |
+
EPISODE_LEN_MAX,
|
| 33 |
+
EPISODE_LEN_MEAN,
|
| 34 |
+
EPISODE_LEN_MIN,
|
| 35 |
+
EPISODE_RETURN_MAX,
|
| 36 |
+
EPISODE_RETURN_MEAN,
|
| 37 |
+
EPISODE_RETURN_MIN,
|
| 38 |
+
WEIGHTS_SEQ_NO,
|
| 39 |
+
)
|
| 40 |
+
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
|
| 41 |
+
from ray.rllib.utils.numpy import softmax
|
| 42 |
+
from ray.rllib.utils.typing import EpisodeID, StateDict
|
| 43 |
+
|
| 44 |
+
torch, _ = try_import_torch()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@ExperimentalAPI
|
| 48 |
+
class TcpClientInferenceEnvRunner(EnvRunner, Checkpointable):
|
| 49 |
+
"""An EnvRunner communicating with an external env through a TCP socket.
|
| 50 |
+
|
| 51 |
+
This implementation assumes:
|
| 52 |
+
- Only one external client ever connects to this env runner.
|
| 53 |
+
- The external client performs inference locally through an ONNX model. Thus,
|
| 54 |
+
samples are sent in bulk once a certain number of timesteps has been executed on the
|
| 55 |
+
client's side (no individual action requests).
|
| 56 |
+
- A copy of the RLModule is kept at all times on the env runner, but never used
|
| 57 |
+
for inference, only as a data (weights) container.
|
| 58 |
+
TODO (sven): The above might be inefficient as we have to store basically two
|
| 59 |
+
models, one in this EnvRunner, one in the env (as ONNX).
|
| 60 |
+
- There is no environment and no connectors on this env runner. The external env
|
| 61 |
+
is responsible for generating all the data to create episodes.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
    @override(EnvRunner)
    def __init__(self, *, config, **kwargs):
        """
        Initializes a TcpClientInferenceEnvRunner instance.

        Args:
            config: The AlgorithmConfig to use for setup.

        Keyword Args:
            port: The base port number. The server socket is then actually bound to
                `port` + self.worker_index.
        """
        super().__init__(config=config)

        # Index of this worker (0 = local); also used as port offset.
        self.worker_index: int = kwargs.get("worker_index", 0)

        # Sequence number of the weights currently held by `self.module`.
        self._weights_seq_no = 0

        # Build the module from its spec.
        module_spec = self.config.get_rl_module_spec(
            spaces=self.get_spaces(), inference_only=True
        )
        self.module = module_spec.build()

        # TCP server settings; sockets are created later (presumably by the
        # listener thread via `_recycle_sockets` -- defined outside this view).
        self.host = "localhost"
        self.port = int(self.config.env_config.get("port", 5555)) + self.worker_index
        self.server_socket = None
        self.client_socket = None
        self.address = None

        self.metrics = MetricsLogger()

        # Episode chunks received from the client, handed over to `sample()`.
        self._episode_chunks_to_return: Optional[List[SingleAgentEpisode]] = None
        self._done_episodes_for_metrics: List[SingleAgentEpisode] = []
        self._ongoing_episodes_for_metrics: DefaultDict[
            EpisodeID, List[SingleAgentEpisode]
        ] = defaultdict(list)

        # Locks guarding the episode handoff and on-policy state exchange between
        # the listener thread and the main thread.
        self._sample_lock = threading.Lock()
        self._on_policy_lock = threading.Lock()
        # True while the listener thread waits for fresh weights via `set_state`.
        self._blocked_on_state = False

        # Start a background thread for client communication.
        self.thread = threading.Thread(
            target=self._client_message_listener, daemon=True
        )
        self.thread.start()
|
| 111 |
+
|
| 112 |
+
@override(EnvRunner)
|
| 113 |
+
def assert_healthy(self):
|
| 114 |
+
"""Checks that the server socket is open and listening."""
|
| 115 |
+
assert (
|
| 116 |
+
self.server_socket is not None
|
| 117 |
+
), "Server socket is None (not connected, not listening)."
|
| 118 |
+
|
| 119 |
+
    @override(EnvRunner)
    def sample(self, **kwargs):
        """Waits for the client to send episodes.

        Busy-polls (10ms sleep per iteration) until the background listener thread
        has deposited episode chunks into `self._episode_chunks_to_return`, then
        takes ownership of them, books them for metrics, and returns them.
        """
        while True:
            # Only touch the handoff slot while holding the sample lock (it is
            # written by the listener thread).
            with self._sample_lock:
                if self._episode_chunks_to_return is not None:
                    num_env_steps = 0
                    num_episodes_completed = 0
                    for eps in self._episode_chunks_to_return:
                        # Sort chunks into done vs ongoing for later metrics
                        # aggregation in `get_metrics()`.
                        if eps.is_done:
                            self._done_episodes_for_metrics.append(eps)
                            num_episodes_completed += 1
                        else:
                            self._ongoing_episodes_for_metrics[eps.id_].append(eps)
                        num_env_steps += len(eps)

                    # Clear the slot so the listener thread can deposit the next
                    # batch of chunks.
                    ret = self._episode_chunks_to_return
                    self._episode_chunks_to_return = None

                    # Reuse SingleAgentEnvRunner's metrics helper (called unbound
                    # with `self`, as this class does not inherit from it).
                    SingleAgentEnvRunner._increase_sampled_metrics(
                        self, num_env_steps, num_episodes_completed
                    )

                    return ret
            # Nothing available yet -> yield the CPU briefly and retry.
            time.sleep(0.01)
|
| 144 |
+
|
| 145 |
+
    @override(EnvRunner)
    def get_metrics(self):
        """Aggregates and returns per-episode metrics for all completed episodes."""
        # TODO (sven): We should probably make this a utility function to be called
        #  from within Single/MultiAgentEnvRunner and other EnvRunner subclasses, as
        #  needed.
        # Compute per-episode metrics (only on already completed episodes).
        for eps in self._done_episodes_for_metrics:
            assert eps.is_done
            episode_length = len(eps)
            episode_return = eps.get_return()
            episode_duration_s = eps.get_duration_s()
            # Don't forget about the already returned chunks of this episode.
            if eps.id_ in self._ongoing_episodes_for_metrics:
                for eps2 in self._ongoing_episodes_for_metrics[eps.id_]:
                    episode_length += len(eps2)
                    episode_return += eps2.get_return()
                    episode_duration_s += eps2.get_duration_s()
                # All chunks of this episode are accounted for -> drop them.
                del self._ongoing_episodes_for_metrics[eps.id_]

            # `_log_episode_metrics` is defined outside this view (presumably logs
            # len/return/duration into `self.metrics` -- confirm in full source).
            self._log_episode_metrics(
                episode_length, episode_return, episode_duration_s
            )

        # Now that we have logged everything, clear cache of done episodes.
        self._done_episodes_for_metrics.clear()

        # Return reduced metrics.
        return self.metrics.reduce()
|
| 173 |
+
|
| 174 |
+
def get_spaces(self):
|
| 175 |
+
return {
|
| 176 |
+
INPUT_ENV_SPACES: (self.config.observation_space, self.config.action_space),
|
| 177 |
+
DEFAULT_MODULE_ID: (
|
| 178 |
+
self.config.observation_space,
|
| 179 |
+
self.config.action_space,
|
| 180 |
+
),
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
    @override(EnvRunner)
    def stop(self):
        """Closes the client and server sockets."""
        # Delegates to the socket-cleanup helper (defined outside this view).
        self._close_sockets_if_necessary()
|
| 187 |
+
|
| 188 |
+
@override(Checkpointable)
|
| 189 |
+
def get_ctor_args_and_kwargs(self):
|
| 190 |
+
return (
|
| 191 |
+
(), # *args
|
| 192 |
+
{"config": self.config}, # **kwargs
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
@override(Checkpointable)
def get_checkpointable_components(self):
    """Returns the sub-components to be checkpointed along with this runner.

    The only checkpointable sub-component here is the RLModule.
    """
    components = [(COMPONENT_RL_MODULE, self.module)]
    return components
|
| 201 |
+
@override(Checkpointable)
def get_state(
    self,
    components: Optional[Union[str, Collection[str]]] = None,
    *,
    not_components: Optional[Union[str, Collection[str]]] = None,
    **kwargs,
) -> StateDict:
    """Returns this EnvRunner's state (always an empty dict).

    This runner holds no state that needs to be retrieved by callers; the
    `components`/`not_components` filters are therefore ignored.

    Args:
        components: Ignored (no retrievable components).
        not_components: Ignored (no retrievable components).
        **kwargs: Ignored.

    Returns:
        An empty StateDict.
    """
    return {}
|
| 211 |
+
@override(Checkpointable)
def set_state(self, state: StateDict) -> None:
    """Applies a new state (RLModule weights) to this runner.

    If the incoming `state` carries an RLModule component, the module weights
    are applied, provided they are newer than the ones currently held (or the
    update is forced via a 0/missing WEIGHTS_SEQ_NO). If a client is currently
    blocked waiting for fresh on-policy weights, they are sent out afterwards.

    Args:
        state: The state dict, possibly containing COMPONENT_RL_MODULE and
            a WEIGHTS_SEQ_NO entry.
    """
    # Update the RLModule state.
    if COMPONENT_RL_MODULE in state:
        # A missing value for WEIGHTS_SEQ_NO or a value of 0 means: Force the
        # update.
        weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0)

        # Only update the weights, if this is the first synchronization or
        # if the weights of this `EnvRunner` lag behind the actual ones.
        if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no:
            rl_module_state = state[COMPONENT_RL_MODULE]
            # Unwrap a possible per-module-ID dict down to the default
            # module's state.
            if (
                isinstance(rl_module_state, dict)
                and DEFAULT_MODULE_ID in rl_module_state
            ):
                rl_module_state = rl_module_state[DEFAULT_MODULE_ID]
            self.module.set_state(rl_module_state)

            # Update our weights_seq_no, if the new one is > 0.
            if weights_seq_no > 0:
                self._weights_seq_no = weights_seq_no

        # A client blocked on-policy-style on new weights can now be served
        # and unblocked (the message listener loop resumes processing).
        if self._blocked_on_state is True:
            self._send_set_state_message()
            self._blocked_on_state = False
|
| 238 |
+
def _client_message_listener(self):
    """Entry point for the listener thread.

    Sets up the server socket, accepts a client connection, then loops
    forever: receives one message at a time from the client and dispatches
    it by type (PING handshake, episode data, state request, config
    request). While `self._blocked_on_state` is set (on-policy mode waiting
    for a learner update), incoming messages are NOT processed. Any
    ConnectionError leads to recycling both sockets and re-entering the
    loop.
    """
    # Set up the server socket and bind to the specified host and port.
    self._recycle_sockets()

    # Enter an endless message receival- and processing loop.
    while True:
        # As long as we are blocked on a new state, sleep a bit and continue.
        # Do NOT process any incoming messages (until we send out the new state
        # back to the client).
        if self._blocked_on_state is True:
            time.sleep(0.01)
            continue

        try:
            # Blocking call to get next message.
            msg_type, msg_body = _get_message(self.client_socket)

            # Process the message received based on its type.
            # Initial handshake.
            if msg_type == rllink.PING:
                self._send_pong_message()

            # Episode data from the client.
            elif msg_type in [
                rllink.EPISODES,
                rllink.EPISODES_AND_GET_STATE,
            ]:
                self._process_episodes_message(msg_type, msg_body)

            # Client requests the state (model weights).
            elif msg_type == rllink.GET_STATE:
                self._send_set_state_message()

            # Clients requests some (relevant) config information.
            elif msg_type == rllink.GET_CONFIG:
                self._send_set_config_message()

        except ConnectionError as e:
            # Connection broke (or protocol violation) -> tear down both
            # sockets, wait 5s, and wait for the client to re-connect.
            print(f"Messaging/connection error {e}! Recycling sockets ...")
            self._recycle_sockets(5.0)
            continue
|
| 282 |
+
def _recycle_sockets(self, sleep: float = 0.0):
    """Tears down any existing sockets and blocks until a client re-connects.

    Args:
        sleep: Seconds to sleep between closing the old sockets and opening
            the new server socket (used to back off after connection errors).
    """
    # Close all old sockets, if they exist.
    self._close_sockets_if_necessary()

    time.sleep(sleep)

    # Start listening on the configured port.
    self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Allow reuse of the address (avoids bind errors on quick restarts).
    self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    self.server_socket.bind((self.host, self.port))
    # Listen for a single connection (this runner serves exactly one client).
    self.server_socket.listen(1)
    print(f"Waiting for client to connect to port {self.port}...")

    # Blocks until the client connects.
    self.client_socket, self.address = self.server_socket.accept()
    print(f"Connected to client at {self.address}")
|
| 300 |
+
def _close_sockets_if_necessary(self):
    """Closes the client- and server sockets, if they currently exist."""
    # Close client first, then server (mirrors teardown order elsewhere).
    for sock_ in (self.client_socket, self.server_socket):
        if sock_:
            sock_.close()
|
| 306 |
+
def _send_pong_message(self):
    """Replies to a client PING with a PONG message (handshake)."""
    _send_message(self.client_socket, {"type": rllink.PONG.name})
|
| 309 |
+
def _process_episodes_message(self, msg_type, msg_body):
    """Converts raw episode data from the client into SingleAgentEpisodes.

    Builds one SingleAgentEpisode per entry in `msg_body["episodes"]`
    (each entry carries obs/actions/rewards/action-dist-inputs/logps and
    terminated/truncated flags) and appends the numpy-converted episodes to
    the list consumed by `sample()` requests (under `self._sample_lock`).

    Args:
        msg_type: The rllink message type (EPISODES or
            EPISODES_AND_GET_STATE).
        msg_body: The decoded message dict; must contain an "episodes" list.
    """
    # On-policy training -> we have to block until we get a new `set_state` call
    # (b/c the learning step is done and we can sent new weights back to all
    # clients).
    if msg_type == rllink.EPISODES_AND_GET_STATE:
        self._blocked_on_state = True

    episodes = []
    for episode_data in msg_body["episodes"]:
        episode = SingleAgentEpisode(
            observation_space=self.config.observation_space,
            observations=[np.array(o) for o in episode_data[Columns.OBS]],
            action_space=self.config.action_space,
            actions=episode_data[Columns.ACTIONS],
            rewards=episode_data[Columns.REWARDS],
            extra_model_outputs={
                Columns.ACTION_DIST_INPUTS: [
                    np.array(a) for a in episode_data[Columns.ACTION_DIST_INPUTS]
                ],
                Columns.ACTION_LOGP: episode_data[Columns.ACTION_LOGP],
            },
            terminated=episode_data["is_terminated"],
            truncated=episode_data["is_truncated"],
            # Client data carries no lookback; episodes start fresh.
            len_lookback_buffer=0,
        )
        episodes.append(episode.to_numpy())

    # Push episodes into the to-be-returned list (for `sample()` requests).
    with self._sample_lock:
        if isinstance(self._episode_chunks_to_return, list):
            self._episode_chunks_to_return.extend(episodes)
        else:
            self._episode_chunks_to_return = episodes
|
| 343 |
+
def _send_set_state_message(self):
    """Exports the RLModule to ONNX and sends it to the client as SET_STATE.

    The module is ONNX-exported into a temp file, gzip-compressed,
    base64-encoded (so it fits into the JSON message protocol), and sent
    together with the current weights sequence number.
    """
    with tempfile.TemporaryDirectory() as dir:
        onnx_file = pathlib.Path(dir) / "_temp_model.onnx"
        # Export with a dummy single-row observation batch to trace the
        # forward pass. Assumes a flat (Box) observation space with a
        # well-defined `.shape` -- TODO confirm for non-Box spaces.
        torch.onnx.export(
            self.module,
            {
                "batch": {
                    "obs": torch.randn(1, *self.config.observation_space.shape)
                }
            },
            onnx_file,
            export_params=True,
        )
        # gzip + base64 so the binary model fits in a JSON/text message.
        with open(onnx_file, "rb") as f:
            compressed = gzip.compress(f.read())
        onnx_binary = base64.b64encode(compressed).decode("utf-8")
    _send_message(
        self.client_socket,
        {
            "type": rllink.SET_STATE.name,
            "onnx_file": onnx_binary,
            WEIGHTS_SEQ_NO: self._weights_seq_no,
        },
    )
|
| 368 |
+
def _send_set_config_message(self):
    """Sends the client the config values it needs for sampling.

    Currently: the rollout fragment length (env steps per sample, resolved
    for this worker index) and the `force_on_policy` flag (always True
    here, i.e. the client must wait for fresh weights after each sample).
    """
    _send_message(
        self.client_socket,
        {
            "type": rllink.SET_CONFIG.name,
            "env_steps_per_sample": self.config.get_rollout_fragment_length(
                worker_index=self.worker_index
            ),
            "force_on_policy": True,
        },
    )
|
| 380 |
+
def _log_episode_metrics(self, length, ret, sec):
    """Logs one completed episode's length, return, and duration.

    Uses a fixed `window` (instead of the default EMA) to mimic the old
    API stack's smoothing behavior. Also logs per-agent and per-module
    returns as well as min/max stats for length and return.

    Args:
        length: The episode's length (number of env steps).
        ret: The episode's total return.
        sec: The episode's wall-clock duration in seconds.
    """
    # To mimic the old API stack behavior, we'll use `window` here for
    # these particular stats (instead of the default EMA).
    win = self.config.metrics_num_episodes_for_smoothing
    log_value = self.metrics.log_value

    # Log general episode metrics.
    log_value(EPISODE_LEN_MEAN, length, window=win)
    log_value(EPISODE_RETURN_MEAN, ret, window=win)
    log_value(EPISODE_DURATION_SEC_MEAN, sec, window=win)
    # Per-agent returns.
    log_value(("agent_episode_returns_mean", DEFAULT_AGENT_ID), ret, window=win)
    # Per-RLModule returns.
    log_value(("module_episode_returns_mean", DEFAULT_MODULE_ID), ret, window=win)

    # For some metrics, log min/max as well.
    log_value(EPISODE_LEN_MIN, length, reduce="min", window=win)
    log_value(EPISODE_RETURN_MIN, ret, reduce="min", window=win)
    log_value(EPISODE_LEN_MAX, length, reduce="max", window=win)
    log_value(EPISODE_RETURN_MAX, ret, reduce="max", window=win)
|
| 403 |
+
|
| 404 |
+
def _send_message(sock_, message: dict):
|
| 405 |
+
"""Sends a message to the client with a length header."""
|
| 406 |
+
body = json.dumps(message).encode("utf-8")
|
| 407 |
+
header = str(len(body)).zfill(8).encode("utf-8")
|
| 408 |
+
try:
|
| 409 |
+
sock_.sendall(header + body)
|
| 410 |
+
except Exception as e:
|
| 411 |
+
raise ConnectionError(
|
| 412 |
+
f"Error sending message {message} to server on socket {sock_}! "
|
| 413 |
+
f"Original error was: {e}"
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def _get_message(sock_):
    """Receives one length-prefixed JSON message from the peer.

    Reads the 8-byte decimal length header, then the JSON body, and
    converts the mandatory "type" field into an `rllink` enum value.

    Args:
        sock_: The connected socket to read from.

    Returns:
        Tuple of (rllink message type, dict with the remaining message
        fields).

    Raises:
        ConnectionError: On socket/decoding errors or protocol violations.
    """
    try:
        # Read the length header (8 bytes).
        header = _get_num_bytes(sock_, 8)
        msg_length = int(header.decode("utf-8"))
        # Read the message body.
        body = _get_num_bytes(sock_, msg_length)
        # Decode JSON.
        message = json.loads(body.decode("utf-8"))
    except Exception as e:
        # Chain the original exception so the root cause is preserved.
        raise ConnectionError(
            f"Error receiving message from peer on socket {sock_}! "
            f"Original error was: {e}"
        ) from e

    # Check for proper protocol. NOTE: This check lives OUTSIDE the `try`
    # above; in the original code the protocol ConnectionError was itself
    # caught by `except Exception` and re-wrapped as a generic receive
    # error, garbling the message.
    if "type" not in message:
        raise ConnectionError(
            "Protocol Error! Message from peer does not contain `type` field."
        )
    try:
        msg_type = rllink(message.pop("type"))
    except ValueError as e:
        raise ConnectionError(
            f"Protocol Error! Unknown message type. Original error was: {e}"
        ) from e
    return msg_type, message
|
| 439 |
+
|
| 440 |
+
def _get_num_bytes(sock_, num_bytes):
|
| 441 |
+
"""Helper function to receive a specific number of bytes."""
|
| 442 |
+
data = b""
|
| 443 |
+
while len(data) < num_bytes:
|
| 444 |
+
packet = sock_.recv(num_bytes - len(data))
|
| 445 |
+
if not packet:
|
| 446 |
+
raise ConnectionError(f"No data received from socket {sock_}!")
|
| 447 |
+
data += packet
|
| 448 |
+
return data
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def _dummy_client(port: int = 5556):
    """A dummy client that runs CartPole and acts as a testing external env.

    Connects to the TCP server (retrying until it is up), performs the
    PING/PONG handshake, fetches config and initial ONNX weights, then runs
    an endless CartPole loop: infer actions via ONNX, step the env, collect
    per-step data, and send completed sample batches back to the server
    (blocking for fresh weights in on-policy mode).

    Args:
        port: The server port to connect to on localhost.
    """

    def _set_state(msg_body):
        # Decode, decompress, and load the received ONNX model.
        # BUGFIX: The original opened a TemporaryDirectory but then wrote
        # "_temp_onnx" into the CWD (leaking a file on every weight update);
        # now the model file actually lives inside the temp dir.
        with tempfile.TemporaryDirectory() as dir_:
            onnx_path = pathlib.Path(dir_) / "_temp_model.onnx"
            onnx_path.write_bytes(
                gzip.decompress(
                    base64.b64decode(msg_body["onnx_file"].encode("utf-8"))
                )
            )
            # The session loads the model at construction time, so the temp
            # dir may be deleted right after.
            onnx_session = onnxruntime.InferenceSession(str(onnx_path))
        output_names = [o.name for o in onnx_session.get_outputs()]
        return onnx_session, output_names

    # Connect to server (retry until it's listening).
    while True:
        try:
            print(f"Trying to connect to localhost:{port} ...")
            sock_ = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock_.connect(("localhost", port))
            break
        except ConnectionRefusedError:
            time.sleep(5)

    # Send ping-pong (handshake).
    _send_message(sock_, {"type": rllink.PING.name})
    msg_type, msg_body = _get_message(sock_)
    assert msg_type == rllink.PONG

    # Request config.
    _send_message(sock_, {"type": rllink.GET_CONFIG.name})
    msg_type, msg_body = _get_message(sock_)
    assert msg_type == rllink.SET_CONFIG
    env_steps_per_sample = msg_body["env_steps_per_sample"]
    force_on_policy = msg_body["force_on_policy"]

    # Request ONNX weights.
    _send_message(sock_, {"type": rllink.GET_STATE.name})
    msg_type, msg_body = _get_message(sock_)
    assert msg_type == rllink.SET_STATE
    onnx_session, output_names = _set_state(msg_body)

    # Episode collection buckets.
    episodes = []
    observations = []
    actions = []
    action_dist_inputs = []
    action_logps = []
    rewards = []

    timesteps = 0
    episode_return = 0.0

    # Start actual env loop.
    env = gym.make("CartPole-v1")
    obs, info = env.reset()
    observations.append(obs.tolist())

    while True:
        timesteps += 1
        # Perform action inference using the ONNX model.
        # NOTE(review): "onnx::Gemm_0" is the auto-generated input name of
        # the exported model -- confirm it stays stable across torch
        # versions.
        logits = onnx_session.run(
            output_names,
            {"onnx::Gemm_0": np.array([obs], np.float32)},
        )[0][
            0
        ]  # [0]=first return item, [0]=batch size 1

        # Stochastic sample.
        action_probs = softmax(logits)
        action = int(np.random.choice(list(range(env.action_space.n)), p=action_probs))
        logp = float(np.log(action_probs[action]))

        # Perform the env step.
        obs, reward, terminated, truncated, info = env.step(action)

        # Collect step data.
        observations.append(obs.tolist())
        actions.append(action)
        action_dist_inputs.append(logits.tolist())
        action_logps.append(logp)
        rewards.append(reward)
        episode_return += reward

        # We have to create a new episode record.
        if timesteps == env_steps_per_sample or terminated or truncated:
            episodes.append(
                {
                    Columns.OBS: observations,
                    Columns.ACTIONS: actions,
                    Columns.ACTION_DIST_INPUTS: action_dist_inputs,
                    Columns.ACTION_LOGP: action_logps,
                    Columns.REWARDS: rewards,
                    "is_terminated": terminated,
                    "is_truncated": truncated,
                }
            )
            # We collected enough samples -> Send them to server.
            if timesteps == env_steps_per_sample:
                # Make sure the amount of data we collected is correct.
                assert sum(len(e["actions"]) for e in episodes) == env_steps_per_sample

                # Send the data to the server.
                if force_on_policy:
                    _send_message(
                        sock_,
                        {
                            "type": rllink.EPISODES_AND_GET_STATE.name,
                            "episodes": episodes,
                            "timesteps": timesteps,
                        },
                    )
                    # We are forced to sample on-policy. Have to wait for a response
                    # with the state (weights) in it.
                    msg_type, msg_body = _get_message(sock_)
                    assert msg_type == rllink.SET_STATE
                    onnx_session, output_names = _set_state(msg_body)

                # Sampling doesn't have to be on-policy -> continue collecting
                # samples.
                else:
                    raise NotImplementedError

                episodes = []
                timesteps = 0

            # Set new buckets to empty lists (for next episode).
            observations = [observations[-1]]
            actions = []
            action_dist_inputs = []
            action_logps = []
            rewards = []

            # The episode is done -> Reset.
            if terminated or truncated:
                obs, _ = env.reset()
                observations = [obs.tolist()]
                episode_return = 0.0
|
.venv/lib/python3.11/site-packages/ray/rllib/env/utils/__pycache__/external_env_protocol.cpython-311.pyc
ADDED
|
Binary file (1.28 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/vector_env.py
ADDED
|
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import gymnasium as gym
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import Callable, List, Optional, Tuple, Union, Set
|
| 5 |
+
|
| 6 |
+
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID
|
| 7 |
+
from ray.rllib.utils.annotations import Deprecated, OldAPIStack, override
|
| 8 |
+
from ray.rllib.utils.typing import (
|
| 9 |
+
EnvActionType,
|
| 10 |
+
EnvID,
|
| 11 |
+
EnvInfoDict,
|
| 12 |
+
EnvObsType,
|
| 13 |
+
EnvType,
|
| 14 |
+
MultiEnvDict,
|
| 15 |
+
AgentID,
|
| 16 |
+
)
|
| 17 |
+
from ray.util import log_once
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@OldAPIStack
class VectorEnv:
    """An environment that supports batch evaluation using clones of sub-envs."""

    def __init__(
        self, observation_space: gym.Space, action_space: gym.Space, num_envs: int
    ):
        """Initializes a VectorEnv instance.

        Args:
            observation_space: The observation Space of a single
                sub-env.
            action_space: The action Space of a single sub-env.
            num_envs: The number of clones to make of the given sub-env.
        """
        self.observation_space = observation_space
        self.action_space = action_space
        self.num_envs = num_envs

    @staticmethod
    def vectorize_gym_envs(
        make_env: Optional[Callable[[int], EnvType]] = None,
        existing_envs: Optional[List[gym.Env]] = None,
        num_envs: int = 1,
        action_space: Optional[gym.Space] = None,
        observation_space: Optional[gym.Space] = None,
        restart_failed_sub_environments: bool = False,
        # Deprecated. These seem to have never been used.
        env_config=None,
        policy_config=None,
    ) -> "_VectorizedGymEnv":
        """Translates any given gym.Env(s) into a VectorizedEnv object.

        Args:
            make_env: Factory that produces a new gym.Env taking the sub-env's
                vector index as only arg. Must be defined if the
                number of `existing_envs` is less than `num_envs`.
            existing_envs: Optional list of already instantiated sub
                environments.
            num_envs: Total number of sub environments in this VectorEnv.
            action_space: The action space. If None, use existing_envs[0]'s
                action space.
            observation_space: The observation space. If None, use
                existing_envs[0]'s observation space.
            restart_failed_sub_environments: If True and any sub-environment (within
                a vectorized env) throws any error during env stepping, the
                Sampler will try to restart the faulty sub-environment. This is done
                without disturbing the other (still intact) sub-environment and without
                the RolloutWorker crashing.

        Returns:
            The resulting _VectorizedGymEnv object (subclass of VectorEnv).
        """
        return _VectorizedGymEnv(
            make_env=make_env,
            existing_envs=existing_envs or [],
            num_envs=num_envs,
            observation_space=observation_space,
            action_space=action_space,
            restart_failed_sub_environments=restart_failed_sub_environments,
        )

    def vector_reset(
        self, *, seeds: Optional[List[int]] = None, options: Optional[List[dict]] = None
    ) -> Tuple[List[EnvObsType], List[EnvInfoDict]]:
        """Resets all sub-environments.

        Args:
            seeds: The list of seeds to be passed to the sub-environments' when
                resetting them. If None, will not reset any existing PRNGs. If you
                pass integers, the PRNGs will be reset even if they already exist.
            options: The list of options dicts to be passed to the sub-environments'
                when resetting them.

        Returns:
            Tuple consisting of a list of observations from each environment and
            a list of info dicts from each environment.
        """
        raise NotImplementedError

    def reset_at(
        self,
        index: Optional[int] = None,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> Union[Tuple[EnvObsType, EnvInfoDict], Exception]:
        """Resets a single sub-environment.

        Args:
            index: An optional sub-env index to reset.
            seed: The seed to be passed to the sub-environment at index `index` when
                resetting it. If None, will not reset any existing PRNG. If you pass an
                integer, the PRNG will be reset even if it already exists.
            options: An options dict to be passed to the sub-environment at index
                `index` when resetting it.

        Returns:
            Tuple consisting of observations from the reset sub environment and
            an info dict of the reset sub environment. Alternatively an Exception
            can be returned, indicating that the reset operation on the sub environment
            has failed (and why it failed).
        """
        raise NotImplementedError

    def restart_at(self, index: Optional[int] = None) -> None:
        """Restarts a single sub-environment.

        Args:
            index: An optional sub-env index to restart.
        """
        raise NotImplementedError

    def vector_step(
        self, actions: List[EnvActionType]
    ) -> Tuple[
        List[EnvObsType], List[float], List[bool], List[bool], List[EnvInfoDict]
    ]:
        """Performs a vectorized step on all sub environments using `actions`.

        Args:
            actions: List of actions (one for each sub-env).

        Returns:
            A tuple consisting of
            1) New observations for each sub-env.
            2) Reward values for each sub-env.
            3) Terminated values for each sub-env.
            4) Truncated values for each sub-env.
            5) Info values for each sub-env.
        """
        raise NotImplementedError

    def get_sub_environments(self) -> List[EnvType]:
        """Returns the underlying sub environments.

        Returns:
            List of all underlying sub environments.
        """
        return []

    # TODO: (sven) Experimental method. Make @PublicAPI at some point.
    def try_render_at(self, index: Optional[int] = None) -> Optional[np.ndarray]:
        """Renders a single environment.

        Args:
            index: An optional sub-env index to render.

        Returns:
            Either a numpy RGB image (shape=(w x h x 3) dtype=uint8) or
            None in case rendering is handled directly by this method.
        """
        pass

    def to_base_env(
        self,
        make_env: Optional[Callable[[int], EnvType]] = None,
        num_envs: int = 1,
        remote_envs: bool = False,
        remote_env_batch_wait_ms: int = 0,
        restart_failed_sub_environments: bool = False,
    ) -> "BaseEnv":
        """Converts this VectorEnv into a BaseEnv object.

        The resulting BaseEnv is always vectorized (contains n
        sub-environments) to support batched forward passes, where n may
        also be 1. BaseEnv also supports async execution via the `poll` and
        `send_actions` methods and thus supports external simulators.

        Args:
            make_env: A callable taking an int as input (which indicates
                the number of individual sub-environments within the final
                vectorized BaseEnv) and returning one individual
                sub-environment.
            num_envs: The number of sub-environments to create in the
                resulting (vectorized) BaseEnv. The already existing `env`
                will be one of the `num_envs`.
            remote_envs: Whether each sub-env should be a @ray.remote
                actor. You can set this behavior in your config via the
                `remote_worker_envs=True` option.
            remote_env_batch_wait_ms: The wait time (in ms) to poll remote
                sub-environments for, if applicable. Only used if
                `remote_envs` is True.

        Returns:
            The resulting BaseEnv object.
        """
        env = VectorEnvWrapper(self)
        return env

    @Deprecated(new="vectorize_gym_envs", error=True)
    def wrap(self, *args, **kwargs) -> "_VectorizedGymEnv":
        pass

    @Deprecated(new="get_sub_environments", error=True)
    def get_unwrapped(self) -> List[EnvType]:
        pass
|
| 220 |
+
|
| 221 |
+
@OldAPIStack
|
| 222 |
+
class _VectorizedGymEnv(VectorEnv):
|
| 223 |
+
"""Internal wrapper to translate any gym.Envs into a VectorEnv object."""
|
| 224 |
+
|
| 225 |
+
def __init__(
    self,
    make_env: Optional[Callable[[int], EnvType]] = None,
    existing_envs: Optional[List[gym.Env]] = None,
    num_envs: int = 1,
    *,
    observation_space: Optional[gym.Space] = None,
    action_space: Optional[gym.Space] = None,
    restart_failed_sub_environments: bool = False,
    # Deprecated. These seem to have never been used.
    env_config=None,
    policy_config=None,
):
    """Initializes a _VectorizedGymEnv object.

    Args:
        make_env: Factory that produces a new gym.Env taking the sub-env's
            vector index as only arg. Must be defined if the
            number of `existing_envs` is less than `num_envs`.
        existing_envs: Optional list of already instantiated sub
            environments.
        num_envs: Total number of sub environments in this VectorEnv.
        action_space: The action space. If None, use existing_envs[0]'s
            action space.
        observation_space: The observation space. If None, use
            existing_envs[0]'s observation space.
        restart_failed_sub_environments: If True and any sub-environment (within
            a vectorized env) throws any error during env stepping, we will try to
            restart the faulty sub-environment. This is done
            without disturbing the other (still intact) sub-environments.
    """
    # BUGFIX: Guard against the default `existing_envs=None` (the original
    # aliased it directly and `len(None)` below raised a TypeError). When a
    # list IS passed, keep aliasing it (callers may rely on that).
    self.envs = existing_envs if existing_envs is not None else []
    self.make_env = make_env
    self.restart_failed_sub_environments = restart_failed_sub_environments

    # Fill up missing envs (so we have exactly num_envs sub-envs in this
    # VectorEnv).
    while len(self.envs) < num_envs:
        self.envs.append(make_env(len(self.envs)))

    super().__init__(
        observation_space=observation_space or self.envs[0].observation_space,
        action_space=action_space or self.envs[0].action_space,
        num_envs=num_envs,
    )
|
| 271 |
+
@override(VectorEnv)
def vector_reset(
    self, *, seeds: Optional[List[int]] = None, options: Optional[List[dict]] = None
) -> Tuple[List[EnvObsType], List[EnvInfoDict]]:
    """Resets all sub-envs, retrying each one until its reset succeeds.

    `reset_at` may return an Exception (when
    `restart_failed_sub_environments` is enabled); in that case the sub-env
    is reset again until actual observations come back.
    """
    seed_list = seeds or [None] * self.num_envs
    options_list = options or [None] * self.num_envs

    obs_list: List[EnvObsType] = []
    info_list: List[EnvInfoDict] = []
    for idx in range(len(self.envs)):
        # Use reset_at(index) to restart and retry until we successfully
        # create a new env.
        obs, infos = self.reset_at(idx, seed=seed_list[idx], options=options_list[idx])
        while isinstance(obs, Exception):
            obs, infos = self.reset_at(
                idx, seed=seed_list[idx], options=options_list[idx]
            )
        obs_list.append(obs)
        info_list.append(infos)
    return obs_list, info_list
|
| 290 |
+
@override(VectorEnv)
def reset_at(
    self,
    index: Optional[int] = None,
    *,
    seed: Optional[int] = None,
    options: Optional[dict] = None,
) -> Tuple[Union[EnvObsType, Exception], Union[EnvInfoDict, Exception]]:
    """Resets the sub-env at `index` (default 0).

    If the reset raises and `restart_failed_sub_environments` is set, the
    faulty sub-env is restarted and the caught Exception is returned (in
    place of the observations) so the caller can detect the failure;
    otherwise the exception propagates.

    Args:
        index: The sub-env index to reset (None -> 0).
        seed: Optional PRNG seed to pass to the sub-env's `reset()`.
        options: Optional options dict to pass to the sub-env's `reset()`.

    Returns:
        Either (obs, infos) from the reset sub-env, or (Exception, {}) if
        the reset failed and the sub-env was restarted.
    """
    if index is None:
        index = 0
    try:
        obs_and_infos = self.envs[index].reset(seed=seed, options=options)

    except Exception as e:
        if self.restart_failed_sub_environments:
            # BUGFIX: Log the exception itself, not `e.args[0]` (which
            # raised an IndexError for exceptions constructed w/o args).
            logger.exception(e)
            self.restart_at(index)
            obs_and_infos = e, {}
        else:
            raise e

    return obs_and_infos
|
| 313 |
+
@override(VectorEnv)
def restart_at(self, index: Optional[int] = None) -> None:
    """Closes (best-effort) and re-creates the sub-env at `index` (default 0)."""
    if index is None:
        index = 0

    # Try closing down the old (possibly faulty) sub-env, but ignore errors.
    try:
        self.envs[index].close()
    except Exception as e:
        if log_once("close_sub_env"):
            logger.warning(
                "Trying to close old and replaced sub-environment (at vector "
                f"index={index}), but closing resulted in error:\n{e}"
            )
    # Drop the reference first so the old env can be garbage-collected
    # before the replacement is created.
    env_to_del = self.envs[index]
    self.envs[index] = None
    del env_to_del

    # Re-create the sub-env at the new index.
    logger.warning(f"Trying to restart sub-environment at index {index}.")
    self.envs[index] = self.make_env(index)
    logger.warning(f"Sub-environment at index {index} restarted successfully.")
|
| 335 |
+
|
| 336 |
+
@override(VectorEnv)
def vector_step(
    self, actions: List[EnvActionType]
) -> Tuple[
    List[EnvObsType], List[float], List[bool], List[bool], List[EnvInfoDict]
]:
    """Steps all sub-envs with the given per-sub-env actions.

    Args:
        actions: One action per sub-environment, in vector order.

    Returns:
        Per-sub-env batched lists: (obs, rewards, terminateds, truncateds,
        infos). If a sub-env raises during stepping and
        `restart_failed_sub_environments` is True, that sub-env is restarted
        and the exception object is returned in place of its observation
        (reward 0.0, terminated=truncated=True, empty info); otherwise the
        error is re-raised.

    Raises:
        ValueError: If a sub-env returns a non-dict info object.
    """
    obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = (
        [],
        [],
        [],
        [],
        [],
    )
    for i in range(self.num_envs):
        try:
            results = self.envs[i].step(actions[i])
        except Exception as e:
            if self.restart_failed_sub_environments:
                # Log the exception object itself. Indexing `e.args[0]` (as
                # done previously) raises an IndexError for exceptions
                # constructed without arguments.
                logger.exception(e)
                self.restart_at(i)
                results = e, 0.0, True, True, {}
            else:
                raise e

        obs, reward, terminated, truncated, info = results

        if not isinstance(info, dict):
            raise ValueError(
                "Info should be a dict, got {} ({})".format(info, type(info))
            )
        obs_batch.append(obs)
        reward_batch.append(reward)
        terminated_batch.append(terminated)
        truncated_batch.append(truncated)
        info_batch.append(info)
    return obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch
|
| 372 |
+
|
| 373 |
+
@override(VectorEnv)
def get_sub_environments(self) -> List[EnvType]:
    """Returns the list of underlying sub-environments."""
    return self.envs
|
| 376 |
+
|
| 377 |
+
@override(VectorEnv)
def try_render_at(self, index: Optional[int] = None):
    """Renders the sub-environment at the given index (defaults to 0)."""
    sub_env = self.envs[0 if index is None else index]
    return sub_env.render()
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
@OldAPIStack
class VectorEnvWrapper(BaseEnv):
    """Internal adapter of VectorEnv to BaseEnv.

    We assume the caller will always send the full vector of actions in each
    call to send_actions(), and that they call reset_at() on all completed
    environments before calling send_actions().
    """

    def __init__(self, vector_env: VectorEnv):
        self.vector_env = vector_env
        self.num_envs = vector_env.num_envs
        self._observation_space = vector_env.observation_space
        self._action_space = vector_env.action_space

        # Sub-environments' states.
        self.new_obs = None
        self.cur_rewards = None
        self.cur_terminateds = None
        self.cur_truncateds = None
        self.cur_infos = None
        # At first `poll()`, reset everything (all sub-environments).
        self.first_reset_done = False
        # Initialize sub-environments' state.
        self._init_env_state(idx=None)

    @override(BaseEnv)
    def poll(
        self,
    ) -> Tuple[
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
        MultiEnvDict,
    ]:
        """Returns the last step's (or initial reset's) data as MultiEnvDicts.

        Consumes the state buffers filled by `send_actions()` (or by the
        initial `vector_reset()`); the buffers are emptied afterwards so a
        second `poll()` without an intervening `send_actions()` yields
        empty dicts.
        """
        from ray.rllib.env.base_env import with_dummy_agent_id

        if not self.first_reset_done:
            self.first_reset_done = True
            # TODO(sven): We probably would like to seed this call here as well.
            self.new_obs, self.cur_infos = self.vector_env.vector_reset()
        new_obs = dict(enumerate(self.new_obs))
        rewards = dict(enumerate(self.cur_rewards))
        terminateds = dict(enumerate(self.cur_terminateds))
        truncateds = dict(enumerate(self.cur_truncateds))
        infos = dict(enumerate(self.cur_infos))

        # Empty all states (in case `poll()` gets called again).
        self.new_obs = []
        self.cur_rewards = []
        self.cur_terminateds = []
        self.cur_truncateds = []
        self.cur_infos = []

        return (
            with_dummy_agent_id(new_obs),
            with_dummy_agent_id(rewards),
            with_dummy_agent_id(terminateds, "__all__"),
            with_dummy_agent_id(truncateds, "__all__"),
            with_dummy_agent_id(infos),
            {},
        )

    @override(BaseEnv)
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        """Steps all sub-envs; results are buffered for the next `poll()`."""
        from ray.rllib.env.base_env import _DUMMY_AGENT_ID

        action_vector = [None] * self.num_envs
        for i in range(self.num_envs):
            action_vector[i] = action_dict[i][_DUMMY_AGENT_ID]
        (
            self.new_obs,
            self.cur_rewards,
            self.cur_terminateds,
            self.cur_truncateds,
            self.cur_infos,
        ) = self.vector_env.vector_step(action_vector)

    @override(BaseEnv)
    def try_reset(
        self,
        env_id: Optional[EnvID] = None,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> Tuple[MultiEnvDict, MultiEnvDict]:
        """Resets the sub-env at `env_id` (default 0) and returns its data."""
        from ray.rllib.env.base_env import _DUMMY_AGENT_ID

        if env_id is None:
            env_id = 0
        assert isinstance(env_id, int)
        obs, infos = self.vector_env.reset_at(env_id, seed=seed, options=options)

        # If exceptions were returned, return MultiEnvDict mapping env indices to
        # these exceptions (for obs and infos).
        if isinstance(obs, Exception):
            return {env_id: obs}, {env_id: infos}
        # Otherwise, return a MultiEnvDict (with single agent ID) and the actual
        # obs and info dicts.
        else:
            return {env_id: {_DUMMY_AGENT_ID: obs}}, {env_id: {_DUMMY_AGENT_ID: infos}}

    @override(BaseEnv)
    def try_restart(self, env_id: Optional[EnvID] = None) -> None:
        """Restarts and re-initializes the sub-env at `env_id` (default 0)."""
        assert env_id is None or isinstance(env_id, int)
        # Restart the sub-env at the index.
        self.vector_env.restart_at(env_id)
        # Auto-reset (get ready for next `poll()`).
        self._init_env_state(env_id)

    @override(BaseEnv)
    def get_sub_environments(self, as_dict: bool = False) -> Union[List[EnvType], dict]:
        """Returns sub-envs as a list or (if `as_dict`) an index-keyed dict."""
        if not as_dict:
            return self.vector_env.get_sub_environments()
        else:
            return {
                _id: env
                for _id, env in enumerate(self.vector_env.get_sub_environments())
            }

    @override(BaseEnv)
    def try_render(self, env_id: Optional[EnvID] = None) -> None:
        """Renders the sub-env at `env_id` (default 0)."""
        assert env_id is None or isinstance(env_id, int)
        return self.vector_env.try_render_at(env_id)

    @property
    @override(BaseEnv)
    def observation_space(self) -> gym.Space:
        # Single-agent observation space of the wrapped VectorEnv.
        return self._observation_space

    @property
    @override(BaseEnv)
    def action_space(self) -> gym.Space:
        # Single-agent action space of the wrapped VectorEnv.
        return self._action_space

    @override(BaseEnv)
    def get_agent_ids(self) -> Set[AgentID]:
        # Single-agent adapter -> only the dummy agent ID exists.
        return {_DUMMY_AGENT_ID}

    def _init_env_state(self, idx: Optional[int] = None) -> None:
        """Resets all or one particular sub-environment's state (by index).

        Args:
            idx: The index to reset at. If None, reset all the sub-environments' states.
        """
        # If index is None, reset all sub-envs' states:
        if idx is None:
            self.new_obs = [None for _ in range(self.num_envs)]
            self.cur_rewards = [0.0 for _ in range(self.num_envs)]
            self.cur_terminateds = [False for _ in range(self.num_envs)]
            self.cur_truncateds = [False for _ in range(self.num_envs)]
            self.cur_infos = [{} for _ in range(self.num_envs)]
        # Index provided, reset only the sub-env's state at the given index.
        else:
            self.new_obs[idx], self.cur_infos[idx] = self.vector_env.reset_at(idx)
            # Reset all other states to null values.
            self.cur_rewards[idx] = 0.0
            self.cur_terminateds[idx] = False
            self.cur_truncateds[idx] = False
|
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/atari_wrappers.cpython-311.pyc
ADDED
|
Binary file (22.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/group_agents_wrapper.cpython-311.pyc
ADDED
|
Binary file (7.97 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/multi_agent_env_compatibility.cpython-311.pyc
ADDED
|
Binary file (4.45 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/open_spiel.cpython-311.pyc
ADDED
|
Binary file (8.82 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/__pycache__/unity3d_env.cpython-311.pyc
ADDED
|
Binary file (16.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/atari_wrappers.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import deque
|
| 2 |
+
import gymnasium as gym
|
| 3 |
+
from gymnasium import spaces
|
| 4 |
+
import numpy as np
|
| 5 |
+
from typing import Optional, Union
|
| 6 |
+
|
| 7 |
+
from ray.rllib.utils.annotations import PublicAPI
|
| 8 |
+
from ray.rllib.utils.images import rgb2gray, resize
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@PublicAPI
def is_atari(env: Union[gym.Env, str]) -> bool:
    """Returns, whether a given env object or env descriptor (str) is an Atari env.

    Args:
        env: The gym.Env object or a string descriptor of the env (for example,
            "ale_py:ALE/Pong-v5").

    Returns:
        Whether `env` is an Atari environment.
    """
    # String descriptor: recognize the known ALE prefixes.
    if isinstance(env, str):
        return env.startswith(("ALE/", "ale_py:"))
    # Env object: an observation space with <= 2 dims cannot be an Atari image
    # frame; otherwise rely on the "AtariEnv<ALE" marker in the env's repr.
    obs_shape = getattr(env.observation_space, "shape", None)
    if obs_shape is not None and len(obs_shape) <= 2:
        return False
    return "AtariEnv<ALE" in str(env)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@PublicAPI
def get_wrapper_by_cls(env, cls):
    """Returns the gym env wrapper of the given class, or None."""
    wrapper = env
    # Walk down the wrapper chain until a `cls` instance is found or the
    # chain ends (innermost env reached).
    while not isinstance(wrapper, cls):
        if not isinstance(wrapper, gym.Wrapper):
            return None
        wrapper = wrapper.env
    return wrapper
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@PublicAPI
class ClipRewardEnv(gym.RewardWrapper):
    """Clips every reward to its sign: +1, 0, or -1."""

    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)

    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@PublicAPI
class EpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env):
        """Make end-of-life == end-of-episode, but only reset on true game over.
        Done by DeepMind for the DQN and co. since it helps value estimation.
        """
        gym.Wrapper.__init__(self, env)
        # Lives observed after the most recent step.
        self.lives = 0
        # True if the last `terminated` came from the underlying env (real
        # game over) rather than from a lost life.
        self.was_real_terminated = True

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.was_real_terminated = terminated
        # check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped.ale.lives()
        if lives < self.lives and lives > 0:
            # for Qbert sometimes we stay in lives == 0 condtion for a few fr
            # so its important to keep lives > 0, so that we only reset once
            # the environment advertises `terminated`.
            terminated = True
        self.lives = lives
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_terminated:
            obs, info = self.env.reset(**kwargs)
        else:
            # no-op step to advance from terminal/lost life state
            obs, _, _, _, info = self.env.step(0)
        self.lives = self.env.unwrapped.ale.lives()
        return obs, info
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@PublicAPI
class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take action on reset.

        For environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        # Action 1 must be FIRE and at least 3 actions must exist.
        assert env.unwrapped.get_action_meanings()[1] == "FIRE"
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self, **kwargs):
        self.env.reset(**kwargs)
        obs, _, terminated, truncated, _ = self.env.step(1)
        if terminated or truncated:
            self.env.reset(**kwargs)
        obs, _, terminated, truncated, info = self.env.step(2)
        if terminated or truncated:
            # NOTE(review): the env is reset here, but the (stale) obs/info
            # from the preceding step(2) are still returned — confirm this
            # matches the intended upstream behavior.
            self.env.reset(**kwargs)
        return obs, info

    def step(self, ac):
        return self.env.step(ac)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@PublicAPI
class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack k last frames."""
        gym.Wrapper.__init__(self, env)
        self.k = k
        # Ring buffer holding the k most recent observations.
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        # Bounds are repeated k times along the last (channel) axis; the
        # stacked shape becomes (H, W, C * k).
        self.observation_space = spaces.Box(
            low=np.repeat(env.observation_space.low, repeats=k, axis=-1),
            high=np.repeat(env.observation_space.high, repeats=k, axis=-1),
            shape=(shp[0], shp[1], shp[2] * k),
            dtype=env.observation_space.dtype,
        )

    def reset(self, *, seed=None, options=None):
        ob, infos = self.env.reset(seed=seed, options=options)
        # Fill the entire buffer with the initial observation.
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob(), infos

    def step(self, action):
        ob, reward, terminated, truncated, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, terminated, truncated, info

    def _get_ob(self):
        # Concatenate the k buffered frames along the channel axis.
        assert len(self.frames) == self.k
        return np.concatenate(self.frames, axis=2)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
@PublicAPI
class FrameStackTrajectoryView(gym.ObservationWrapper):
    def __init__(self, env):
        """No stacking. Trajectory View API takes care of this."""
        gym.Wrapper.__init__(self, env)
        shp = env.observation_space.shape
        # Expects single-channel (already grayscaled) frames.
        assert shp[2] == 1
        # Drop the trailing channel axis: (H, W, 1) -> (H, W).
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(shp[0], shp[1]), dtype=env.observation_space.dtype
        )

    def observation(self, observation):
        return np.squeeze(observation, axis=-1)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
@PublicAPI
class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        gym.Wrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = np.zeros(
            (2,) + env.observation_space.shape, dtype=env.observation_space.dtype
        )
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        terminated = truncated = info = None
        for i in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            # Only the last two raw frames are buffered for max-pooling.
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += reward
            if terminated or truncated:
                break
        # Note that the observation on the terminated|truncated=True frame
        # doesn't matter
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, terminated, truncated, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
@PublicAPI
class MonitorEnv(gym.Wrapper):
    def __init__(self, env=None):
        """Record episodes stats prior to EpisodicLifeEnv, etc."""
        gym.Wrapper.__init__(self, env)
        # Running reward/length of the current (unfinished) episode.
        self._current_reward = None
        self._num_steps = None
        self._total_steps = None
        # Stats of finished episodes.
        self._episode_rewards = []
        self._episode_lengths = []
        self._num_episodes = 0
        # Count of episodes already yielded by `next_episode_results()`.
        self._num_returned = 0

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)

        if self._total_steps is None:
            self._total_steps = sum(self._episode_lengths)

        # Finalize the previous episode's stats (if one was running).
        if self._current_reward is not None:
            self._episode_rewards.append(self._current_reward)
            self._episode_lengths.append(self._num_steps)
            self._num_episodes += 1

        self._current_reward = 0
        self._num_steps = 0

        return obs, info

    def step(self, action):
        obs, rew, terminated, truncated, info = self.env.step(action)
        self._current_reward += rew
        self._num_steps += 1
        self._total_steps += 1
        return obs, rew, terminated, truncated, info

    def get_episode_rewards(self):
        return self._episode_rewards

    def get_episode_lengths(self):
        return self._episode_lengths

    def get_total_steps(self):
        return self._total_steps

    def next_episode_results(self):
        # Yield only episodes not returned by a previous call.
        for i in range(self._num_returned, len(self._episode_rewards)):
            yield (self._episode_rewards[i], self._episode_lengths[i])
        self._num_returned = len(self._episode_rewards)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
@PublicAPI
class NoopResetEnv(gym.Wrapper):
    def __init__(self, env, noop_max=30):
        """Sample initial states by taking random number of no-ops on reset.
        No-op is assumed to be action 0.
        """
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        # If set, always perform exactly this many no-ops (determinism hook).
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"

    def reset(self, **kwargs):
        """Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset(**kwargs)
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            # This environment now uses the pcg64 random number generator which
            # does not have `randint` as an attribute only has `integers`.
            try:
                noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
            # Also still support older versions.
            except AttributeError:
                noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
        assert noops > 0
        obs = None
        for _ in range(noops):
            obs, _, terminated, truncated, info = self.env.step(self.noop_action)
            # If the episode ends during the no-ops, start over from a fresh
            # reset and keep stepping.
            if terminated or truncated:
                obs, info = self.env.reset(**kwargs)
        return obs, info

    def step(self, ac):
        return self.env.step(ac)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
@PublicAPI
class NormalizedImageEnv(gym.ObservationWrapper):
    """Scales uint8 image observations into the float32 range [-1.0, 1.0]."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Same shape as before, but now a float32 space bounded by [-1, 1].
        self.observation_space = gym.spaces.Box(
            low=-1.0,
            high=1.0,
            shape=self.observation_space.shape,
            dtype=np.float32,
        )

    # Divide by scale and center around 0.0, such that observations are in the range
    # of -1.0 and 1.0.
    def observation(self, observation):
        scaled = observation.astype(np.float32) / 128.0
        return scaled - 1.0
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
@PublicAPI
class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env, dim):
        """Warp frames to the specified size (dim x dim)."""
        gym.ObservationWrapper.__init__(self, env)
        self.width = dim
        self.height = dim
        # Output: a single-channel (grayscale) dim x dim uint8 image.
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(self.height, self.width, 1), dtype=np.uint8
        )

    def observation(self, frame):
        frame = rgb2gray(frame)
        frame = resize(frame, height=self.height, width=self.width)
        # Re-add the trailing channel axis: (H, W) -> (H, W, 1).
        return frame[:, :, None]
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
@PublicAPI
def wrap_atari_for_new_api_stack(
    env: gym.Env,
    dim: int = 64,
    frameskip: int = 4,
    framestack: Optional[int] = None,
    # TODO (sven): Add option to NOT grayscale, in which case framestack must be None
    # (b/c we are using the 3 color channels already as stacking frames).
) -> gym.Env:
    """Wraps `env` for new-API-stack-friendly RLlib Atari experiments.

    Note that we assume reward clipping is done outside the wrapper.

    Args:
        env: The env object to wrap.
        dim: Dimension to resize observations to (dim x dim).
        frameskip: Whether to skip n frames and max over them (keep brightest pixels).
        framestack: Whether to stack the last n (grayscaled) frames. Note that this
            step happens after(!) a possible frameskip step, meaning that if
            frameskip=4 and framestack=2, we would perform the following over this
            trajectory:
            actual env timesteps: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 -> ...
            frameskip: ( max ) ( max ) ( max ) ( max )
            framestack: ( stack ) (stack )

    Returns:
        The wrapped gym.Env.
    """
    # Time limit.
    env = gym.wrappers.TimeLimit(env, max_episode_steps=108000)
    # Grayscale + resize.
    env = WarpFrame(env, dim=dim)
    # Normalize the image.
    env = NormalizedImageEnv(env)
    # Frameskip: Take max over these n frames.
    if frameskip > 1:
        # NOTE(review): presumably requires a registered (spec-carrying) env
        # here — confirm why the assert is only needed in the frameskip case.
        assert env.spec is not None
        env = MaxAndSkipEnv(env, skip=frameskip)
    # Send n noop actions into env after reset to increase variance in the
    # "start states" of the trajectories. These dummy steps are NOT included in the
    # sampled data used for learning.
    env = NoopResetEnv(env, noop_max=30)
    # Each life is one episode.
    env = EpisodicLifeEnv(env)
    # Some envs only start playing after pressing fire. Unblock those.
    if "FIRE" in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    # Framestack.
    if framestack:
        env = FrameStack(env, k=framestack)
    return env
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
@PublicAPI
def wrap_deepmind(env, dim=84, framestack=True, noframeskip=False):
    """Configure environment for DeepMind-style Atari.

    Note that we assume reward clipping is done outside the wrapper.

    Args:
        env: The env object to wrap.
        dim: Dimension to resize observations to (dim x dim).
        framestack: Whether to framestack observations.
        noframeskip: If True (and the env has a spec), apply MaxAndSkipEnv
            with skip=4 here — presumably because a "NoFrameskip" env variant
            does no frame skipping itself; confirm with callers.
    """
    env = MonitorEnv(env)
    env = NoopResetEnv(env, noop_max=30)
    if env.spec is not None and noframeskip is True:
        env = MaxAndSkipEnv(env, skip=4)
    env = EpisodicLifeEnv(env)
    if "FIRE" in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = WarpFrame(env, dim)
    # env = ClipRewardEnv(env)  # reward clipping is handled by policy eval
    # 4x image framestacking.
    if framestack is True:
        env = FrameStack(env, 4)
    return env
|
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/dm_control_wrapper.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DeepMind Control Suite Wrapper directly sourced from:
|
| 3 |
+
https://github.com/denisyarats/dmc2gym
|
| 4 |
+
|
| 5 |
+
MIT License
|
| 6 |
+
|
| 7 |
+
Copyright (c) 2020 Denis Yarats
|
| 8 |
+
|
| 9 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 10 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 11 |
+
in the Software without restriction, including without limitation the rights
|
| 12 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 13 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 14 |
+
furnished to do so, subject to the following conditions:
|
| 15 |
+
|
| 16 |
+
The above copyright notice and this permission notice shall be included in all
|
| 17 |
+
copies or substantial portions of the Software.
|
| 18 |
+
|
| 19 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 20 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 21 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 22 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 23 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 24 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 25 |
+
SOFTWARE.
|
| 26 |
+
"""
|
| 27 |
+
from gymnasium import core, spaces
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
from dm_env import specs
|
| 31 |
+
except ImportError:
|
| 32 |
+
specs = None
|
| 33 |
+
try:
|
| 34 |
+
# Suppress MuJoCo warning (dm_control uses absl logging).
|
| 35 |
+
import absl.logging
|
| 36 |
+
|
| 37 |
+
absl.logging.set_verbosity("error")
|
| 38 |
+
from dm_control import suite
|
| 39 |
+
except (ImportError, OSError):
|
| 40 |
+
suite = None
|
| 41 |
+
import numpy as np
|
| 42 |
+
|
| 43 |
+
from ray.rllib.utils.annotations import PublicAPI
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _spec_to_box(spec):
    """Convert an iterable of dm_env specs into one flattened gym Box space.

    Args:
        spec: Iterable of `dm_env.specs.Array` / `specs.BoundedArray` objects
            (e.g. `[env.action_spec()]` or `env.observation_spec().values()`).

    Returns:
        A float32 `gymnasium.spaces.Box` whose low/high bounds are the
        concatenation of the flattened bounds of all given specs. Unbounded
        `Array` specs contribute (-inf, +inf) bounds.

    Raises:
        ValueError: If a spec is neither `Array` nor `BoundedArray`.
    """

    def extract_min_max(s):
        assert s.dtype == np.float64 or s.dtype == np.float32
        dim = np.int_(np.prod(s.shape))
        if type(s) is specs.Array:
            # Unbounded spec -> +/- infinity bounds.
            bound = np.inf * np.ones(dim, dtype=np.float32)
            return -bound, bound
        elif type(s) is specs.BoundedArray:
            # Broadcast (possibly scalar) min/max to the flattened dimension.
            zeros = np.zeros(dim, dtype=np.float32)
            return s.minimum + zeros, s.maximum + zeros
        # Previously, unhandled spec types fell through and implicitly
        # returned None, which crashed later with a confusing unpacking
        # error. Fail loudly instead.
        raise ValueError("Unsupported spec type: {}".format(type(s)))

    mins, maxs = [], []
    for s in spec:
        mn, mx = extract_min_max(s)
        mins.append(mn)
        maxs.append(mx)
    low = np.concatenate(mins, axis=0)
    high = np.concatenate(maxs, axis=0)
    assert low.shape == high.shape
    return spaces.Box(low, high, dtype=np.float32)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _flatten_obs(obs):
|
| 69 |
+
obs_pieces = []
|
| 70 |
+
for v in obs.values():
|
| 71 |
+
flat = np.array([v]) if np.isscalar(v) else v.ravel()
|
| 72 |
+
obs_pieces.append(flat)
|
| 73 |
+
return np.concatenate(obs_pieces, axis=0)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@PublicAPI
class DMCEnv(core.Env):
    """gymnasium-API wrapper around a DeepMind Control Suite task.

    Observations are either the flattened dm_control state vector or rendered
    camera pixels (``from_pixels=True``). Actions are exposed through a
    normalized [-1, 1] Box and rescaled to the task's true action bounds.
    """

    def __init__(
        self,
        domain_name,
        task_name,
        task_kwargs=None,
        visualize_reward=False,
        from_pixels=False,
        height=64,
        width=64,
        camera_id=0,
        frame_skip=2,
        environment_kwargs=None,
        channels_first=True,
        preprocess=True,
    ):
        """Initializes a DMCEnv.

        Args:
            domain_name: dm_control domain name, e.g. "walker".
            task_name: dm_control task name, e.g. "walk".
            task_kwargs: Optional kwargs forwarded to `suite.load()`.
            visualize_reward: Whether dm_control should visualize rewards.
            from_pixels: If True, observe rendered camera images instead of
                the flattened state vector.
            height: Render height (pixels) for pixel observations.
            width: Render width (pixels) for pixel observations.
            camera_id: MuJoCo camera used for pixel observations.
            frame_skip: How many times each action is repeated per `step()`.
            environment_kwargs: Optional kwargs forwarded to `suite.load()`.
            channels_first: If True, pixel obs are (C, H, W), else (H, W, C).
            preprocess: If True, pixel obs are scaled into [-0.5, 0.5] floats.

        Raises:
            RuntimeError: If `dm_env` or `dm_control` are not installed.
        """
        self._from_pixels = from_pixels
        self._height = height
        self._width = width
        self._camera_id = camera_id
        self._frame_skip = frame_skip
        self._channels_first = channels_first
        self.preprocess = preprocess

        if specs is None:
            raise RuntimeError(
                (
                    "The `specs` module from `dm_env` was not imported. Make sure "
                    "`dm_env` is installed and visible in the current python "
                    "environment."
                )
            )
        if suite is None:
            raise RuntimeError(
                (
                    # Fixed typo in the original message ("enviornment").
                    "The `suite` module from `dm_control` was not imported. Make "
                    "sure `dm_control` is installed and visible in the current "
                    "python environment."
                )
            )

        # Create the underlying dm_control task.
        self._env = suite.load(
            domain_name=domain_name,
            task_name=task_name,
            task_kwargs=task_kwargs,
            visualize_reward=visualize_reward,
            environment_kwargs=environment_kwargs,
        )

        # True and normalized ([-1, 1]) action spaces.
        self._true_action_space = _spec_to_box([self._env.action_spec()])
        self._norm_action_space = spaces.Box(
            low=-1.0, high=1.0, shape=self._true_action_space.shape, dtype=np.float32
        )

        # Create the observation space.
        if from_pixels:
            shape = [3, height, width] if channels_first else [height, width, 3]
            self._observation_space = spaces.Box(
                low=0, high=255, shape=shape, dtype=np.uint8
            )
            if preprocess:
                # Preprocessed pixels are scaled into [-0.5, 0.5] floats.
                self._observation_space = spaces.Box(
                    low=-0.5, high=0.5, shape=shape, dtype=np.float32
                )
        else:
            self._observation_space = _spec_to_box(
                self._env.observation_spec().values()
            )

        # Full (non-pixel) state space, regardless of observation mode.
        self._state_space = _spec_to_box(self._env.observation_spec().values())

        # Last flattened state vector (set by `reset()`/`step()`).
        self.current_state = None

    def __getattr__(self, name):
        # Delegate any unknown attribute access to the wrapped dm_control env.
        return getattr(self._env, name)

    def _get_obs(self, time_step):
        """Builds the observation (pixels or flat state) for a dm_env TimeStep."""
        if self._from_pixels:
            obs = self.render(
                height=self._height, width=self._width, camera_id=self._camera_id
            )
            if self._channels_first:
                # (H, W, C) -> (C, H, W).
                obs = obs.transpose(2, 0, 1).copy()
            if self.preprocess:
                # Scale uint8 pixels into [-0.5, 0.5].
                obs = obs / 255.0 - 0.5
        else:
            obs = _flatten_obs(time_step.observation)
        # NOTE(review): The float32 cast also applies to un-preprocessed
        # pixels, although the declared observation space is uint8 in that
        # case — confirm whether downstream code relies on this.
        return obs.astype(np.float32)

    def _convert_action(self, action):
        """Rescales a normalized [-1, 1] action to the true action bounds."""
        action = action.astype(np.float64)
        true_delta = self._true_action_space.high - self._true_action_space.low
        norm_delta = self._norm_action_space.high - self._norm_action_space.low
        action = (action - self._norm_action_space.low) / norm_delta
        action = action * true_delta + self._true_action_space.low
        action = action.astype(np.float32)
        return action

    @property
    def observation_space(self):
        return self._observation_space

    @property
    def state_space(self):
        return self._state_space

    @property
    def action_space(self):
        # Agents act in the normalized [-1, 1] space.
        return self._norm_action_space

    def step(self, action):
        """Repeats `action` for `frame_skip` frames and accumulates rewards.

        Returns:
            Tuple of (obs, summed reward, terminated, truncated, info).
            Episode ends reported by dm_control (`time_step.last()`) are
            surfaced as `truncated`; `terminated` is always False here.
        """
        assert self._norm_action_space.contains(action)
        action = self._convert_action(action)
        assert self._true_action_space.contains(action)
        reward = 0.0
        extra = {"internal_state": self._env.physics.get_state().copy()}

        terminated = truncated = False
        for _ in range(self._frame_skip):
            time_step = self._env.step(action)
            # Reward may be None (e.g. on the first step of an episode).
            reward += time_step.reward or 0.0
            truncated = time_step.last()
            if terminated or truncated:
                break
        obs = self._get_obs(time_step)
        self.current_state = _flatten_obs(time_step.observation)
        extra["discount"] = time_step.discount
        return obs, reward, terminated, truncated, extra

    def reset(self, *, seed=None, options=None):
        """Resets the env. `seed` and `options` are ignored by dm_control."""
        time_step = self._env.reset()
        self.current_state = _flatten_obs(time_step.observation)
        obs = self._get_obs(time_step)
        return obs, {}

    def render(self, mode="rgb_array", height=None, width=None, camera_id=None):
        """Renders the current MuJoCo scene as an RGB array."""
        assert mode == "rgb_array", "only support for rgb_array mode"
        height = height or self._height
        width = width or self._width
        # Explicit `is None` check: camera id 0 is valid but falsy, so the
        # former `camera_id or self._camera_id` silently ignored an
        # explicitly passed 0.
        camera_id = camera_id if camera_id is not None else self._camera_id
        return self._env.physics.render(height=height, width=width, camera_id=camera_id)
|
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/dm_env_wrapper.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
from gymnasium import spaces
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
from dm_env import specs
|
| 8 |
+
except ImportError:
|
| 9 |
+
specs = None
|
| 10 |
+
|
| 11 |
+
from ray.rllib.utils.annotations import PublicAPI
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _convert_spec_to_space(spec):
    """Translate a dm_env spec (or dict of specs) into a gymnasium space.

    Args:
        spec: A `dm_env.specs.Array` subclass instance, or a dict mapping
            names to such specs (converted recursively to a `spaces.Dict`).

    Returns:
        The equivalent `gymnasium.spaces` space.

    Raises:
        NotImplementedError: If the spec type has no gym equivalent.
    """
    if isinstance(spec, dict):
        return spaces.Dict({k: _convert_spec_to_space(v) for k, v in spec.items()})
    if isinstance(spec, specs.DiscreteArray):
        return spaces.Discrete(spec.num_values)
    elif isinstance(spec, specs.BoundedArray):
        return spaces.Box(
            # `np.asscalar` was removed in NumPy 1.23; `.item()` on a size-1
            # array is the supported equivalent.
            low=np.asarray(spec.minimum).item(),
            high=np.asarray(spec.maximum).item(),
            shape=spec.shape,
            dtype=spec.dtype,
        )
    elif isinstance(spec, specs.Array):
        # Unbounded array spec -> unbounded Box.
        return spaces.Box(
            low=-float("inf"), high=float("inf"), shape=spec.shape, dtype=spec.dtype
        )

    raise NotImplementedError(
        (
            "Could not convert `Array` spec of type {} to Gym space. "
            "Attempted to convert: {}"
        ).format(type(spec), spec)
    )
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@PublicAPI
class DMEnv(gym.Env):
    """A `gym.Env` wrapper for the `dm_env` API."""

    metadata = {"render.modes": ["rgb_array"]}

    def __init__(self, dm_env):
        """Initializes the wrapper around a `dm_env` environment.

        Args:
            dm_env: The underlying `dm_env`-API environment to wrap.

        Raises:
            RuntimeError: If the `dm_env` package is not installed.
        """
        super(DMEnv, self).__init__()
        self._env = dm_env
        # NOTE(review): `_prev_obs` is never assigned anywhere in this class,
        # so `render()` as written always raises — confirm whether callers
        # are expected to set it externally.
        self._prev_obs = None

        if specs is None:
            raise RuntimeError(
                (
                    "The `specs` module from `dm_env` was not imported. Make sure "
                    "`dm_env` is installed and visible in the current python "
                    "environment."
                )
            )

    def step(self, action):
        """Steps the underlying env; maps dm_env's TimeStep to the gym API."""
        ts = self._env.step(action)

        # dm_env may return None as reward (e.g. on the first TimeStep).
        reward = ts.reward
        if reward is None:
            reward = 0.0

        # `ts.last()` marks the episode end (-> terminated); truncated is
        # always False here. The dm_env discount is forwarded via the infos.
        return ts.observation, reward, ts.last(), False, {"discount": ts.discount}

    def reset(self, *, seed=None, options=None):
        """Resets the underlying env. `seed` and `options` are ignored."""
        ts = self._env.reset()
        return ts.observation, {}

    def render(self, mode="rgb_array"):
        """Returns the last stored observation as an RGB array.

        Raises:
            ValueError: If no observation has been stored yet.
            NotImplementedError: For any mode other than "rgb_array".
        """
        if self._prev_obs is None:
            raise ValueError(
                "Environment not started. Make sure to reset before rendering."
            )

        if mode == "rgb_array":
            return self._prev_obs
        else:
            raise NotImplementedError("Render mode '{}' is not supported.".format(mode))

    @property
    def action_space(self):
        # Derived lazily from the dm_env action spec on each access.
        spec = self._env.action_spec()
        return _convert_spec_to_space(spec)

    @property
    def observation_space(self):
        # Derived lazily from the dm_env observation spec on each access.
        spec = self._env.observation_spec()
        return _convert_spec_to_space(spec)

    @property
    def reward_range(self):
        # Bounded reward specs map directly; otherwise unbounded.
        spec = self._env.reward_spec()
        if isinstance(spec, specs.BoundedArray):
            return spec.minimum, spec.maximum
        return -float("inf"), float("inf")
|
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/group_agents_wrapper.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
import gymnasium as gym
|
| 3 |
+
from typing import Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
| 6 |
+
from ray.rllib.utils.annotations import DeveloperAPI
|
| 7 |
+
from ray.rllib.utils.typing import AgentID
|
| 8 |
+
|
| 9 |
+
# info key for the individual rewards of an agent, for example:
|
| 10 |
+
# info: {
|
| 11 |
+
# group_1: {
|
| 12 |
+
# _group_rewards: [5, -1, 1], # 3 agents in this group
|
| 13 |
+
# }
|
| 14 |
+
# }
|
| 15 |
+
GROUP_REWARDS = "_group_rewards"
|
| 16 |
+
|
| 17 |
+
# info key for the individual infos of an agent, for example:
|
| 18 |
+
# info: {
|
| 19 |
+
# group_1: {
|
| 20 |
+
# _group_infos: [{"foo": ...}, {}], # 2 agents in this group
|
| 21 |
+
# }
|
| 22 |
+
# }
|
| 23 |
+
GROUP_INFO = "_group_info"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@DeveloperAPI
class GroupAgentsWrapper(MultiAgentEnv):
    """Wraps a MultiAgentEnv environment with agents grouped as specified.

    See multi_agent_env.py for the specification of groups.

    This API is experimental.
    """

    def __init__(
        self,
        env: MultiAgentEnv,
        groups: Dict[str, List[AgentID]],
        obs_space: Optional[gym.Space] = None,
        act_space: Optional[gym.Space] = None,
    ):
        """Wrap an existing MultiAgentEnv to group agent ID together.

        See `MultiAgentEnv.with_agent_groups()` for more detailed usage info.

        Args:
            env: The env to wrap and whose agent IDs to group into new agents.
            groups: Mapping from group id to a list of the agent ids
                of group members. If an agent id is not present in any group
                value, it will be left ungrouped. The group id becomes a new agent ID
                in the final environment.
            obs_space: Optional observation space for the grouped
                env. Must be a tuple space. If not provided, will infer this to be a
                Tuple of n individual agents spaces (n=num agents in a group).
            act_space: Optional action space for the grouped env.
                Must be a tuple space. If not provided, will infer this to be a Tuple
                of n individual agents spaces (n=num agents in a group).

        Raises:
            ValueError: If an agent id appears in more than one group.
        """
        super().__init__()
        self.env = env
        self.groups = groups
        # Reverse lookup: individual agent id -> its group id.
        self.agent_id_to_group = {}
        for group_id, agent_ids in groups.items():
            for agent_id in agent_ids:
                if agent_id in self.agent_id_to_group:
                    raise ValueError(
                        "Agent id {} is in multiple groups".format(agent_id)
                    )
                self.agent_id_to_group[agent_id] = group_id
        if obs_space is not None:
            self.observation_space = obs_space
        if act_space is not None:
            self.action_space = act_space
        # Each group id becomes an agent id of this (grouped) env.
        for group_id in groups.keys():
            self._agent_ids.add(group_id)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Resets the wrapped env and returns grouped observations and infos."""
        obs, info = self.env.reset(seed=seed, options=options)

        return (
            self._group_items(obs),
            # Per-member infos of a group are collected under GROUP_INFO.
            self._group_items(
                info,
                agg_fn=lambda gvals: {GROUP_INFO: list(gvals.values())},
            ),
        )

    def step(self, action_dict):
        """Steps the wrapped env with ungrouped actions; regroups all outputs.

        Args:
            action_dict: Mapping from group id (or ungrouped agent id) to the
                group's per-member actions (or the single agent's action).

        Returns:
            Grouped (obs, rewards, terminateds, truncateds, infos) dicts.
        """
        # Ungroup and send actions.
        action_dict = self._ungroup_items(action_dict)
        obs, rewards, terminateds, truncateds, infos = self.env.step(action_dict)

        # Apply grouping transforms to the env outputs
        obs = self._group_items(obs)
        rewards = self._group_items(rewards, agg_fn=lambda gvals: list(gvals.values()))
        # Only if all of the agents are terminated, the group is terminated as well.
        terminateds = self._group_items(
            terminateds, agg_fn=lambda gvals: all(gvals.values())
        )
        # If all of the agents are truncated, the group is truncated as well.
        truncateds = self._group_items(
            truncateds,
            agg_fn=lambda gvals: all(gvals.values()),
        )
        infos = self._group_items(
            infos, agg_fn=lambda gvals: {GROUP_INFO: list(gvals.values())}
        )

        # Aggregate rewards, but preserve the original values in infos.
        for agent_id, rew in rewards.items():
            if isinstance(rew, list):
                rewards[agent_id] = sum(rew)
                if agent_id not in infos:
                    infos[agent_id] = {}
                infos[agent_id][GROUP_REWARDS] = rew

        return obs, rewards, terminateds, truncateds, infos

    def _ungroup_items(self, items):
        """Expands each grouped entry back into one entry per member agent."""
        out = {}
        for agent_id, value in items.items():
            if agent_id in self.groups:
                # The grouped value must have exactly one entry per member.
                assert len(value) == len(self.groups[agent_id]), (
                    agent_id,
                    value,
                    self.groups,
                )
                for a, v in zip(self.groups[agent_id], value):
                    out[a] = v
            else:
                # Ungrouped agents pass through unchanged.
                out[agent_id] = value
        return out

    def _group_items(self, items, agg_fn=None):
        """Collapses per-agent entries into one entry per group via `agg_fn`.

        Args:
            items: Mapping from individual agent id to a value.
            agg_fn: Aggregator applied to the OrderedDict of one group's
                member values; defaults to collecting them into a list.

        Raises:
            ValueError: If `items` contains only some members of a group.
        """
        if agg_fn is None:
            agg_fn = lambda gvals: list(gvals.values())  # noqa: E731

        grouped_items = {}
        for agent_id, item in items.items():
            if agent_id in self.agent_id_to_group:
                group_id = self.agent_id_to_group[agent_id]
                if group_id in grouped_items:
                    continue  # already added
                # Collect all member values in the group's declared order.
                group_out = OrderedDict()
                for a in self.groups[group_id]:
                    if a in items:
                        group_out[a] = items[a]
                    else:
                        raise ValueError(
                            "Missing member of group {}: {}: {}".format(
                                group_id, a, items
                            )
                        )
                grouped_items[group_id] = agg_fn(group_out)
            else:
                grouped_items[agent_id] = item
        return grouped_items
|
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/multi_agent_env_compatibility.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple
|
| 2 |
+
|
| 3 |
+
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
| 4 |
+
from ray.rllib.utils.typing import MultiAgentDict
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class MultiAgentEnvCompatibility(MultiAgentEnv):
    """A wrapper converting MultiAgentEnv from old gym API to the new one.

    "Old API" refers to step() method returning (observation, reward, done, info),
    and reset() only returning the observation.
    "New API" refers to step() method returning (observation, reward, terminated,
    truncated, info) and reset() returning (observation, info).

    Known limitations:
    - Environments that use `self.np_random` might not work as expected.
    """

    def __init__(self, old_env, render_mode: Optional[str] = None):
        """A wrapper which converts old-style envs to valid modern envs.

        Some information may be lost in the conversion, so we recommend updating your
        environment.

        Args:
            old_env: The old MultiAgentEnv to wrap. Implemented with the old API.
            render_mode: The render mode to use when rendering the environment,
                passed automatically to `env.render()`.
        """
        super().__init__()

        # Fall back to sensible defaults for attributes the old env may lack.
        self.metadata = getattr(old_env, "metadata", {"render_modes": []})
        self.render_mode = render_mode
        self.reward_range = getattr(old_env, "reward_range", None)
        self.spec = getattr(old_env, "spec", None)
        self.env = old_env

        self.observation_space = old_env.observation_space
        self.action_space = old_env.action_space

    def reset(
        self, *, seed: Optional[int] = None, options: Optional[dict] = None
    ) -> Tuple[MultiAgentDict, MultiAgentDict]:
        """Resets the old-API env; returns (obs, per-agent empty infos)."""
        # Use old `seed()` method.
        if seed is not None:
            self.env.seed(seed)
        # Options are ignored

        obs = self.env.reset()

        # Render only AFTER the env has actually been reset: the previous
        # code rendered before `reset()`, showing the stale pre-reset state
        # (and potentially failing before the very first reset).
        if self.render_mode == "human":
            self.render()

        infos = {k: {} for k in obs.keys()}
        return obs, infos

    def step(
        self, action
    ) -> Tuple[
        MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict
    ]:
        """Steps the old-API env; adds an all-False `truncateds` dict."""
        obs, rewards, terminateds, infos = self.env.step(action)

        # Truncated should always be False by default.
        truncateds = {k: False for k in terminateds.keys()}

        return obs, rewards, terminateds, truncateds, infos

    def render(self):
        # Use the old `render()` API, where we have to pass in the mode to each call.
        return self.env.render(mode=self.render_mode)

    def close(self):
        self.env.close()
|
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/open_spiel.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import gymnasium as gym
|
| 5 |
+
|
| 6 |
+
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
| 7 |
+
from ray.rllib.env.utils import try_import_pyspiel
|
| 8 |
+
|
| 9 |
+
pyspiel = try_import_pyspiel(error=True)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class OpenSpielEnv(MultiAgentEnv):
    """RLlib MultiAgentEnv adapter for OpenSpiel (pyspiel) games.

    Handles both sequential- and simultaneous-move games. Integer player
    indices (0 .. num_players-1) serve as the agent IDs.
    """

    def __init__(self, env):
        """Initializes the wrapper.

        Args:
            env: The loaded pyspiel game to wrap.
        """
        super().__init__()
        self.env = env
        # Integer player indices are the (fixed) agent IDs.
        self.agents = self.possible_agents = list(range(self.env.num_players()))
        # Store the open-spiel game type.
        self.type = self.env.get_type()
        # Stores the current open-spiel game state.
        self.state = None

        # Every agent observes the game's flat observation tensor (unbounded).
        self.observation_space = gym.spaces.Dict(
            {
                aid: gym.spaces.Box(
                    float("-inf"),
                    float("inf"),
                    (self.env.observation_tensor_size(),),
                    dtype=np.float32,
                )
                for aid in self.possible_agents
            }
        )
        self.action_space = gym.spaces.Dict(
            {
                aid: gym.spaces.Discrete(self.env.num_distinct_actions())
                for aid in self.possible_agents
            }
        )

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Starts a new game. `seed` and `options` are ignored."""
        self.state = self.env.new_initial_state()
        return self._get_obs(), {}

    def step(self, action):
        """Applies the given per-agent action dict and returns the next state.

        For sequential games, only the current player's action is applied; an
        illegal action is replaced by a random legal one and penalized with
        -0.1. For simultaneous games, all agents' actions are applied at once.
        """
        # Before applying action(s), there could be chance nodes.
        # E.g. if env has to figure out, which agent's action should get
        # resolved first in a simultaneous node.
        self._solve_chance_nodes()
        penalties = {}

        # Sequential game:
        if str(self.type.dynamics) == "Dynamics.SEQUENTIAL":
            curr_player = self.state.current_player()
            assert curr_player in action
            try:
                self.state.apply_action(action[curr_player])
            # TODO: (sven) resolve this hack by publishing legal actions
            # with each step.
            except pyspiel.SpielError:
                self.state.apply_action(np.random.choice(self.state.legal_actions()))
                penalties[curr_player] = -0.1

            # Compile rewards dict.
            # (NOTE(review): this value is recomputed below and thus
            # effectively unused here.)
            rewards = {ag: r for ag, r in enumerate(self.state.returns())}
        # Simultaneous game.
        else:
            assert self.state.current_player() == -2
            # Apparently, this works, even if one or more actions are invalid.
            self.state.apply_actions([action[ag] for ag in range(self.num_agents)])

        # Now that we have applied all actions, get the next obs.
        obs = self._get_obs()

        # Compile rewards dict and add the accumulated penalties
        # (for taking invalid actions).
        rewards = {ag: r for ag, r in enumerate(self.state.returns())}
        for ag, penalty in penalties.items():
            rewards[ag] += penalty

        # Are we done?
        is_terminated = self.state.is_terminal()
        terminateds = dict(
            {ag: is_terminated for ag in range(self.num_agents)},
            **{"__all__": is_terminated}
        )
        truncateds = dict(
            {ag: False for ag in range(self.num_agents)}, **{"__all__": False}
        )

        return obs, rewards, terminateds, truncateds, {}

    def render(self, mode=None) -> None:
        """Prints the current game state when mode == "human"."""
        if mode == "human":
            print(self.state)

    def _get_obs(self):
        """Returns the current observation dict (empty if the game is over).

        Sequential games: only the current player's observation; simultaneous
        games: one observation per agent.
        """
        # Before calculating an observation, there could be chance nodes
        # (that may have an effect on the actual observations).
        # E.g. After reset, figure out initial (random) positions of the
        # agents.
        self._solve_chance_nodes()

        if self.state.is_terminal():
            return {}

        # Sequential game:
        if str(self.type.dynamics) == "Dynamics.SEQUENTIAL":
            curr_player = self.state.current_player()
            return {
                curr_player: np.reshape(self.state.observation_tensor(), [-1]).astype(
                    np.float32
                )
            }
        # Simultaneous game.
        else:
            assert self.state.current_player() == -2
            return {
                ag: np.reshape(self.state.observation_tensor(ag), [-1]).astype(
                    np.float32
                )
                for ag in range(self.num_agents)
            }

    def _solve_chance_nodes(self):
        """Samples and applies outcomes for any pending chance nodes."""
        # Chance node(s): Sample a (non-player) action and apply.
        while self.state.is_chance_node():
            assert self.state.current_player() == -1
            actions, probs = zip(*self.state.chance_outcomes())
            action = np.random.choice(actions, p=probs)
            self.state.apply_action(action)
|
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/pettingzoo_env.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import gymnasium as gym
|
| 4 |
+
|
| 5 |
+
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
| 6 |
+
from ray.rllib.utils.annotations import PublicAPI
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@PublicAPI
|
| 10 |
+
class PettingZooEnv(MultiAgentEnv):
|
| 11 |
+
"""An interface to the PettingZoo MARL environment library.
|
| 12 |
+
|
| 13 |
+
See: https://github.com/Farama-Foundation/PettingZoo
|
| 14 |
+
|
| 15 |
+
Inherits from MultiAgentEnv and exposes a given AEC
|
| 16 |
+
(actor-environment-cycle) game from the PettingZoo project via the
|
| 17 |
+
MultiAgentEnv public API.
|
| 18 |
+
|
| 19 |
+
Note that the wrapper has the following important limitation:
|
| 20 |
+
|
| 21 |
+
Environments are positive sum games (-> Agents are expected to cooperate
|
| 22 |
+
to maximize reward). This isn't a hard restriction, it just that
|
| 23 |
+
standard algorithms aren't expected to work well in highly competitive
|
| 24 |
+
games.
|
| 25 |
+
|
| 26 |
+
Also note that the earlier existing restriction of all agents having the same
|
| 27 |
+
observation- and action spaces has been lifted. Different agents can now have
|
| 28 |
+
different spaces and the entire environment's e.g. `self.action_space` is a Dict
|
| 29 |
+
mapping agent IDs to individual agents' spaces. Same for `self.observation_space`.
|
| 30 |
+
|
| 31 |
+
.. testcode::
|
| 32 |
+
:skipif: True
|
| 33 |
+
|
| 34 |
+
from pettingzoo.butterfly import prison_v3
|
| 35 |
+
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
|
| 36 |
+
env = PettingZooEnv(prison_v3.env())
|
| 37 |
+
obs, infos = env.reset()
|
| 38 |
+
# only returns the observation for the agent which should be stepping
|
| 39 |
+
print(obs)
|
| 40 |
+
|
| 41 |
+
.. testoutput::
|
| 42 |
+
|
| 43 |
+
{
|
| 44 |
+
'prisoner_0': array([[[0, 0, 0],
|
| 45 |
+
[0, 0, 0],
|
| 46 |
+
[0, 0, 0],
|
| 47 |
+
...,
|
| 48 |
+
[0, 0, 0],
|
| 49 |
+
[0, 0, 0],
|
| 50 |
+
[0, 0, 0]]], dtype=uint8)
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
.. testcode::
|
| 54 |
+
:skipif: True
|
| 55 |
+
|
| 56 |
+
obs, rewards, terminateds, truncateds, infos = env.step({
|
| 57 |
+
"prisoner_0": 1
|
| 58 |
+
})
|
| 59 |
+
# only returns the observation, reward, info, etc, for
|
| 60 |
+
# the agent who's turn is next.
|
| 61 |
+
print(obs)
|
| 62 |
+
|
| 63 |
+
.. testoutput::
|
| 64 |
+
|
| 65 |
+
{
|
| 66 |
+
'prisoner_1': array([[[0, 0, 0],
|
| 67 |
+
[0, 0, 0],
|
| 68 |
+
[0, 0, 0],
|
| 69 |
+
...,
|
| 70 |
+
[0, 0, 0],
|
| 71 |
+
[0, 0, 0],
|
| 72 |
+
[0, 0, 0]]], dtype=uint8)
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
.. testcode::
|
| 76 |
+
:skipif: True
|
| 77 |
+
|
| 78 |
+
print(rewards)
|
| 79 |
+
|
| 80 |
+
.. testoutput::
|
| 81 |
+
|
| 82 |
+
{
|
| 83 |
+
'prisoner_1': 0
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
.. testcode::
|
| 87 |
+
:skipif: True
|
| 88 |
+
|
| 89 |
+
print(terminateds)
|
| 90 |
+
|
| 91 |
+
.. testoutput::
|
| 92 |
+
|
| 93 |
+
{
|
| 94 |
+
'prisoner_1': False, '__all__': False
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
.. testcode::
|
| 98 |
+
:skipif: True
|
| 99 |
+
|
| 100 |
+
print(truncateds)
|
| 101 |
+
|
| 102 |
+
.. testoutput::
|
| 103 |
+
|
| 104 |
+
{
|
| 105 |
+
'prisoner_1': False, '__all__': False
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
.. testcode::
|
| 109 |
+
:skipif: True
|
| 110 |
+
|
| 111 |
+
print(infos)
|
| 112 |
+
|
| 113 |
+
.. testoutput::
|
| 114 |
+
|
| 115 |
+
{
|
| 116 |
+
'prisoner_1': {'map_tuple': (1, 0)}
|
| 117 |
+
}
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
def __init__(self, env):
    """Wrap an AEC (turn-based) PettingZoo env for use as an RLlib MultiAgentEnv.

    Args:
        env: The PettingZoo AEC environment instance to wrap.
    """
    super().__init__()
    self.env = env
    # A reset is required before the agent list and per-agent spaces
    # can be queried from the PettingZoo env.
    env.reset()

    self._agent_ids = set(self.env.agents)

    # Collect each agent's observation/action space into Dict spaces.
    obs_spaces = {}
    act_spaces = {}
    for aid in self._agent_ids:
        obs_spaces[aid] = self.env.observation_space(aid)
        act_spaces[aid] = self.env.action_space(aid)
    self.observation_space = gym.spaces.Dict(obs_spaces)
    self.action_space = gym.spaces.Dict(act_spaces)
|
| 133 |
+
|
| 134 |
+
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
    """Reset the wrapped AEC env.

    Returns:
        A 2-tuple of (obs_dict, info_dict). The obs dict contains only the
        observation of the single agent whose turn it is to act next.
    """
    reset_info = self.env.reset(seed=seed, options=options)
    current_agent = self.env.agent_selection
    first_obs = {current_agent: self.env.observe(current_agent)}
    # AEC `reset()` may return None; normalize to an empty info dict.
    return first_obs, reset_info or {}
|
| 140 |
+
|
| 141 |
+
def step(self, action):
    """Steps the AEC env with the currently acting agent's action.

    Args:
        action: Multi-agent action dict; only the entry for
            `self.env.agent_selection` (the agent whose turn it is) is used.

    Returns:
        A 5-tuple (obs, rewards, terminateds, truncateds, infos) of dicts,
        each containing entries for the agents visited while advancing to
        the next live agent, plus an "__all__" flag in the terminated /
        truncated dicts.
    """
    # Apply only the acting agent's action.
    self.env.step(action[self.env.agent_selection])
    obs_d = {}
    rew_d = {}
    terminated_d = {}
    truncated_d = {}
    info_d = {}
    # Advance past agents that are already terminated/truncated (PettingZoo
    # requires stepping them with None) until a live agent is found or no
    # agents remain.
    while self.env.agents:
        obs, rew, terminated, truncated, info = self.env.last()
        agent_id = self.env.agent_selection
        obs_d[agent_id] = obs
        rew_d[agent_id] = rew
        terminated_d[agent_id] = terminated
        truncated_d[agent_id] = truncated
        info_d[agent_id] = info
        if (
            self.env.terminations[self.env.agent_selection]
            or self.env.truncations[self.env.agent_selection]
        ):
            # Dead agent: step with None to advance the turn order.
            self.env.step(None)
        else:
            # Found the next live agent; stop advancing.
            break

    # "__all__" is only True once every agent has left the env AND all
    # collected per-agent flags agree.
    all_gone = not self.env.agents
    terminated_d["__all__"] = all_gone and all(terminated_d.values())
    truncated_d["__all__"] = all_gone and all(truncated_d.values())

    return obs_d, rew_d, terminated_d, truncated_d, info_d
|
| 169 |
+
|
| 170 |
+
def close(self):
    """Shut down the underlying PettingZoo environment and free its resources."""
    self.env.close()
|
| 172 |
+
|
| 173 |
+
def render(self):
    """Render via the wrapped env, forwarding this wrapper's render mode."""
    # NOTE(review): modern PettingZoo `render()` takes no arguments (the mode
    # is fixed at construction); this forwards a mode per the legacy API —
    # confirm against the installed pettingzoo version.
    mode = self.render_mode
    return self.env.render(mode)
|
| 175 |
+
|
| 176 |
+
@property
def get_sub_environments(self):
    """The unwrapped, underlying PettingZoo environment object."""
    return self.env.unwrapped
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
@PublicAPI
class ParallelPettingZooEnv(MultiAgentEnv):
    """RLlib wrapper around a PettingZoo "parallel" (simultaneous-move) env."""

    def __init__(self, env):
        """Wrap a PettingZoo parallel env.

        Args:
            env: The PettingZoo parallel environment instance to wrap.
        """
        super().__init__()
        self.par_env = env
        # Reset once so agent IDs and per-agent spaces can be queried.
        self.par_env.reset()
        self._agent_ids = set(self.par_env.agents)

        obs_spaces = {}
        act_spaces = {}
        for aid in self._agent_ids:
            obs_spaces[aid] = self.par_env.observation_space(aid)
            act_spaces[aid] = self.par_env.action_space(aid)
        self.observation_space = gym.spaces.Dict(obs_spaces)
        self.action_space = gym.spaces.Dict(act_spaces)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Reset all agents at once; returns (obs_dict, info_dict)."""
        observations, reset_info = self.par_env.reset(seed=seed, options=options)
        # Normalize a possibly-None info return to an empty dict.
        return observations, reset_info or {}

    def step(self, action_dict):
        """Step all agents simultaneously with the given action dict."""
        observations, rewards, terminateds, truncateds, infos = self.par_env.step(
            action_dict
        )
        # RLlib's multi-agent protocol requires an explicit "__all__" flag.
        terminateds["__all__"] = all(terminateds.values())
        truncateds["__all__"] = all(truncateds.values())
        return observations, rewards, terminateds, truncateds, infos

    def close(self):
        """Shut down the underlying parallel env."""
        self.par_env.close()

    def render(self):
        """Render via the wrapped env, forwarding this wrapper's render mode."""
        # NOTE(review): forwards `render_mode` positionally (legacy PettingZoo
        # API); modern `render()` takes no arguments — confirm.
        mode = self.render_mode
        return self.par_env.render(mode)

    @property
    def get_sub_environments(self):
        """The unwrapped, underlying PettingZoo parallel environment."""
        return self.par_env.unwrapped
|
.venv/lib/python3.11/site-packages/ray/rllib/env/wrappers/unity3d_env.py
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from gymnasium.spaces import Box, MultiDiscrete, Tuple as TupleSpace
|
| 2 |
+
import logging
|
| 3 |
+
import numpy as np
|
| 4 |
+
import random
|
| 5 |
+
import time
|
| 6 |
+
from typing import Callable, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
| 9 |
+
from ray.rllib.policy.policy import PolicySpec
|
| 10 |
+
from ray.rllib.utils.annotations import PublicAPI
|
| 11 |
+
from ray.rllib.utils.typing import MultiAgentDict, PolicyID, AgentID
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@PublicAPI
class Unity3DEnv(MultiAgentEnv):
    """A MultiAgentEnv representing a single Unity3D game instance.

    For an example on how to use this Env with a running Unity3D editor
    or with a compiled game, see:
    `rllib/examples/unity3d_env_local.py`
    For an example on how to use it inside a Unity game client, which
    connects to an RLlib Policy server, see:
    `rllib/examples/envs/external_envs/unity3d_[client|server].py`

    Supports all Unity3D (MLAgents) examples, multi- or single-agent and
    gets converted automatically into an ExternalMultiAgentEnv, when used
    inside an RLlib PolicyClient for cloud/distributed training of Unity games.
    """

    # Default base port when connecting directly to the Editor
    _BASE_PORT_EDITOR = 5004
    # Default base port when connecting to a compiled environment
    _BASE_PORT_ENVIRONMENT = 5005
    # The worker_id for each environment instance
    # (class-level counter, incremented per created instance).
    _WORKER_ID = 0

    def __init__(
        self,
        file_name: Optional[str] = None,
        port: Optional[int] = None,
        seed: int = 0,
        no_graphics: bool = False,
        timeout_wait: int = 300,
        episode_horizon: int = 1000,
    ):
        """Initializes a Unity3DEnv object.

        Args:
            file_name (Optional[str]): Name of the Unity game binary.
                If None, will assume a locally running Unity3D editor
                to be used, instead.
            port (Optional[int]): Port number to connect to Unity environment.
            seed: A random seed value to use for the Unity3D game.
            no_graphics: Whether to run the Unity3D simulator in
                no-graphics mode. Default: False.
            timeout_wait: Time (in seconds) to wait for connection from
                the Unity3D instance.
            episode_horizon: A hard horizon to abide to. After at most
                this many steps (per-agent episode `step()` calls), the
                Unity3D game is reset and will start again (finishing the
                multi-agent episode that the game represents).
                Note: The game itself may contain its own episode length
                limits, which are always obeyed (on top of this value here).
        """
        super().__init__()

        if file_name is None:
            # NOTE(review): module defines a `logger`, but user-facing
            # startup messages here use print() — confirm this is intended.
            print(
                "No game binary provided, will use a running Unity editor "
                "instead.\nMake sure you are pressing the Play (|>) button in "
                "your editor to start."
            )

        # Imported lazily so the module can be imported without
        # mlagents_envs installed.
        import mlagents_envs
        from mlagents_envs.environment import UnityEnvironment

        # Try connecting to the Unity3D game instance. If a port is blocked
        # (UnityWorkerInUseException), retry with a new worker_id.
        port_ = None
        while True:
            # Sleep for random time to allow for concurrent startup of many
            # environments (num_env_runners >> 1). Otherwise, would lead to port
            # conflicts sometimes.
            if port_ is not None:
                time.sleep(random.randint(1, 10))
            port_ = port or (
                self._BASE_PORT_ENVIRONMENT if file_name else self._BASE_PORT_EDITOR
            )
            # cache the worker_id and
            # increase it for the next environment
            # (editor connections always use worker_id 0; NOTE(review): in
            # that case the retry loop re-tries the same port forever —
            # confirm acceptable).
            worker_id_ = Unity3DEnv._WORKER_ID if file_name else 0
            Unity3DEnv._WORKER_ID += 1
            try:
                self.unity_env = UnityEnvironment(
                    file_name=file_name,
                    worker_id=worker_id_,
                    base_port=port_,
                    seed=seed,
                    no_graphics=no_graphics,
                    timeout_wait=timeout_wait,
                )
                print("Created UnityEnvironment for port {}".format(port_ + worker_id_))
            except mlagents_envs.exception.UnityWorkerInUseException:
                # Port/worker already taken -> loop and retry.
                pass
            else:
                break

        # ML-Agents API version, parsed into a list of ints (e.g. [1, 4, 0])
        # for comparison against the v1.4.0 ActionTuple cutoff in `step()`.
        self.api_version = self.unity_env.API_VERSION.split(".")
        self.api_version = [int(s) for s in self.api_version]

        # Reset entire env every this number of step calls.
        self.episode_horizon = episode_horizon
        # Keep track of how many times we have called `step` so far.
        self.episode_timesteps = 0

    def step(
        self, action_dict: MultiAgentDict
    ) -> Tuple[
        MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict
    ]:
        """Performs one multi-agent step through the game.

        Args:
            action_dict: Multi-agent action dict with:
                keys=agent identifier consisting of
                [MLagents behavior name, e.g. "Goalie?team=1"] + "_" +
                [Agent index, a unique MLAgent-assigned index per single agent]

        Returns:
            tuple:
                - obs: Multi-agent observation dict.
                    Only those observations for which to get new actions are
                    returned.
                - rewards: Rewards dict matching `obs`.
                - dones: Done dict with only an __all__ multi-agent entry in
                    it. __all__=True, if episode is done for all agents.
                - infos: An (empty) info dict.
        """
        from mlagents_envs.base_env import ActionTuple

        # Set only the required actions (from the DecisionSteps) in Unity3D.
        all_agents = []
        for behavior_name in self.unity_env.behavior_specs:
            # New ML-Agents API: Set all agents actions at the same time
            # via an ActionTuple. Since API v1.4.0.
            if self.api_version[0] > 1 or (
                self.api_version[0] == 1 and self.api_version[1] >= 4
            ):
                actions = []
                for agent_id in self.unity_env.get_steps(behavior_name)[0].agent_id:
                    key = behavior_name + "_{}".format(agent_id)
                    all_agents.append(key)
                    actions.append(action_dict[key])
                if actions:
                    # float32 actions -> continuous space; anything else is
                    # treated as discrete.
                    if actions[0].dtype == np.float32:
                        action_tuple = ActionTuple(continuous=np.array(actions))
                    else:
                        action_tuple = ActionTuple(discrete=np.array(actions))
                    self.unity_env.set_actions(behavior_name, action_tuple)
            # Old behavior: Do not use an ActionTuple and set each agent's
            # action individually.
            else:
                for agent_id in self.unity_env.get_steps(behavior_name)[
                    0
                ].agent_id_to_index.keys():
                    key = behavior_name + "_{}".format(agent_id)
                    all_agents.append(key)
                    self.unity_env.set_action_for_agent(
                        behavior_name, agent_id, action_dict[key]
                    )
        # Do the step.
        self.unity_env.step()

        obs, rewards, terminateds, truncateds, infos = self._get_step_results()

        # Global horizon reached? -> Return __all__ truncated=True, so user
        # can reset. Set all agents' individual `truncated` to True as well.
        self.episode_timesteps += 1
        if self.episode_timesteps > self.episode_horizon:
            return (
                obs,
                rewards,
                terminateds,
                dict({"__all__": True}, **{agent_id: True for agent_id in all_agents}),
                infos,
            )

        return obs, rewards, terminateds, truncateds, infos

    def reset(
        self, *, seed=None, options=None
    ) -> Tuple[MultiAgentDict, MultiAgentDict]:
        """Resets the entire Unity3D scene (a single multi-agent episode)."""
        # NOTE(review): `seed` and `options` are accepted for gymnasium API
        # compatibility but not forwarded to the Unity env (the seed is fixed
        # at construction time) — confirm intended.
        self.episode_timesteps = 0
        self.unity_env.reset()
        obs, _, _, _, infos = self._get_step_results()
        return obs, infos

    def _get_step_results(self):
        """Collects those agents' obs/rewards that have to act in next `step`.

        Returns:
            Tuple:
                obs: Multi-agent observation dict.
                    Only those observations for which to get new actions are
                    returned.
                rewards: Rewards dict matching `obs`.
                dones: Done dict with only an __all__ multi-agent entry in it.
                    __all__=True, if episode is done for all agents.
                infos: An (empty) info dict.
        """
        obs = {}
        rewards = {}
        infos = {}
        for behavior_name in self.unity_env.behavior_specs:
            decision_steps, terminal_steps = self.unity_env.get_steps(behavior_name)
            # Important: Only update those sub-envs that are currently
            # available within _env_state.
            # Loop through all envs ("agents") and fill in, whatever
            # information we have.
            for agent_id, idx in decision_steps.agent_id_to_index.items():
                key = behavior_name + "_{}".format(agent_id)
                # Single observation -> unwrap from the tuple.
                os = tuple(o[idx] for o in decision_steps.obs)
                os = os[0] if len(os) == 1 else os
                obs[key] = os
                # Per-agent reward plus the agent's group (team) reward.
                rewards[key] = (
                    decision_steps.reward[idx] + decision_steps.group_reward[idx]
                )
            for agent_id, idx in terminal_steps.agent_id_to_index.items():
                key = behavior_name + "_{}".format(agent_id)
                # Only overwrite rewards (last reward in episode), b/c obs
                # here is the last obs (which doesn't matter anyways).
                # Unless key does not exist in obs.
                if key not in obs:
                    os = tuple(o[idx] for o in terminal_steps.obs)
                    obs[key] = os = os[0] if len(os) == 1 else os
                rewards[key] = (
                    terminal_steps.reward[idx] + terminal_steps.group_reward[idx]
                )

        # Only use dones if all agents are done, then we should do a reset.
        return obs, rewards, {"__all__": False}, {"__all__": False}, infos

    @staticmethod
    def get_policy_configs_for_game(
        game_name: str,
    ) -> Tuple[dict, Callable[[AgentID], PolicyID]]:
        """Returns canonical policy specs and an agent->policy mapping fn
        for one of the supported example Unity3D games (spaces are
        hard-coded to match the corresponding MLAgents behaviors).
        """

        # The RLlib server must know about the Spaces that the Client will be
        # using inside Unity3D, up-front.
        obs_spaces = {
            # 3DBall.
            "3DBall": Box(float("-inf"), float("inf"), (8,)),
            # 3DBallHard.
            "3DBallHard": Box(float("-inf"), float("inf"), (45,)),
            # GridFoodCollector
            "GridFoodCollector": Box(float("-inf"), float("inf"), (40, 40, 6)),
            # Pyramids.
            "Pyramids": TupleSpace(
                [
                    Box(float("-inf"), float("inf"), (56,)),
                    Box(float("-inf"), float("inf"), (56,)),
                    Box(float("-inf"), float("inf"), (56,)),
                    Box(float("-inf"), float("inf"), (4,)),
                ]
            ),
            # SoccerTwos.
            "SoccerPlayer": TupleSpace(
                [
                    Box(-1.0, 1.0, (264,)),
                    Box(-1.0, 1.0, (72,)),
                ]
            ),
            # SoccerStrikersVsGoalie.
            "Goalie": Box(float("-inf"), float("inf"), (738,)),
            "Striker": TupleSpace(
                [
                    Box(float("-inf"), float("inf"), (231,)),
                    Box(float("-inf"), float("inf"), (63,)),
                ]
            ),
            # Sorter.
            "Sorter": TupleSpace(
                [
                    Box(
                        float("-inf"),
                        float("inf"),
                        (
                            20,
                            23,
                        ),
                    ),
                    Box(float("-inf"), float("inf"), (10,)),
                    Box(float("-inf"), float("inf"), (8,)),
                ]
            ),
            # Tennis.
            "Tennis": Box(float("-inf"), float("inf"), (27,)),
            # VisualHallway.
            "VisualHallway": Box(float("-inf"), float("inf"), (84, 84, 3)),
            # Walker.
            "Walker": Box(float("-inf"), float("inf"), (212,)),
            # FoodCollector.
            "FoodCollector": TupleSpace(
                [
                    Box(float("-inf"), float("inf"), (49,)),
                    Box(float("-inf"), float("inf"), (4,)),
                ]
            ),
        }
        action_spaces = {
            # 3DBall.
            "3DBall": Box(-1.0, 1.0, (2,), dtype=np.float32),
            # 3DBallHard.
            "3DBallHard": Box(-1.0, 1.0, (2,), dtype=np.float32),
            # GridFoodCollector.
            "GridFoodCollector": MultiDiscrete([3, 3, 3, 2]),
            # Pyramids.
            "Pyramids": MultiDiscrete([5]),
            # SoccerStrikersVsGoalie.
            "Goalie": MultiDiscrete([3, 3, 3]),
            "Striker": MultiDiscrete([3, 3, 3]),
            # SoccerTwos.
            "SoccerPlayer": MultiDiscrete([3, 3, 3]),
            # Sorter.
            "Sorter": MultiDiscrete([3, 3, 3]),
            # Tennis.
            "Tennis": Box(-1.0, 1.0, (3,)),
            # VisualHallway.
            "VisualHallway": MultiDiscrete([5]),
            # Walker.
            "Walker": Box(-1.0, 1.0, (39,)),
            # FoodCollector.
            "FoodCollector": MultiDiscrete([3, 3, 3, 2]),
        }

        # Policies (Unity: "behaviors") and agent-to-policy mapping fns.
        if game_name == "SoccerStrikersVsGoalie":
            policies = {
                "Goalie": PolicySpec(
                    observation_space=obs_spaces["Goalie"],
                    action_space=action_spaces["Goalie"],
                ),
                "Striker": PolicySpec(
                    observation_space=obs_spaces["Striker"],
                    action_space=action_spaces["Striker"],
                ),
            }

            def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                return "Striker" if "Striker" in agent_id else "Goalie"

        elif game_name == "SoccerTwos":
            policies = {
                "PurplePlayer": PolicySpec(
                    observation_space=obs_spaces["SoccerPlayer"],
                    action_space=action_spaces["SoccerPlayer"],
                ),
                "BluePlayer": PolicySpec(
                    observation_space=obs_spaces["SoccerPlayer"],
                    action_space=action_spaces["SoccerPlayer"],
                ),
            }

            def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                return "BluePlayer" if "1_" in agent_id else "PurplePlayer"

        else:
            # Single-behavior games: one policy named after the game.
            policies = {
                game_name: PolicySpec(
                    observation_space=obs_spaces[game_name],
                    action_space=action_spaces[game_name],
                ),
            }

            def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                return game_name

        return policies, policy_mapping_fn
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/actions/__pycache__/nested_action_spaces.cpython-311.pyc
ADDED
|
Binary file (3.77 kB). View file
|
|
|