Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/__pycache__/vtrace_torch.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/__pycache__/impala_torch_learner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/__pycache__/vtrace_torch_v2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/impala_torch_learner.py +164 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/vtrace_torch_v2.py +168 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/centralized_critic.py +319 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/compute_adapted_gae_on_postprocess_trajectory.py +157 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/agents_act_in_sequence.py +87 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/agents_act_simultaneously.py +108 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/async_gym_env_vectorization.py +142 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/action_mask_env.py +42 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_crashing.py +182 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_sparse_rewards.py +51 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_with_dict_observation_space.py +74 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_with_large_observation_space.py +69 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_with_protobuf_observation_space.py +79 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cliff_walking_wall_env.py +71 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/correlated_actions_env.py +79 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/d4rl_env.py +46 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/debug_counter_env.py +92 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/deterministic_envs.py +13 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/dm_control_suite.py +131 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/env_using_remote_actor.py +63 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/env_with_subprocess.py +42 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/fast_image_env.py +20 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/gpu_requiring_env.py +37 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/look_and_push.py +65 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/memory_leaking_env.py +35 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/mock_env.py +220 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/bandit_envs_discrete.py +206 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/guess_the_number_game.py +89 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/pettingzoo_chess.py +227 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/pettingzoo_connect4.py +213 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/rock_paper_scissors.py +125 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/tic_tac_toe.py +144 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/two_step_game.py +123 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/nested_space_repeat_after_me_env.py +50 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/parametric_actions_cartpole.py +145 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/random_env.py +125 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/recommender_system_envs_with_recsim.py +108 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/repeat_after_me_env.py +47 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/repeat_initial_obs_env.py +32 -0
- .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/simple_corridor.py +42 -0
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (709 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (5.03 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/__pycache__/vtrace_torch.cpython-311.pyc
ADDED
|
Binary file (15 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (206 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/__pycache__/impala_torch_learner.cpython-311.pyc
ADDED
|
Binary file (5.79 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/__pycache__/vtrace_torch_v2.cpython-311.pyc
ADDED
|
Binary file (8.13 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/impala_torch_learner.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
|
| 3 |
+
from ray.rllib.algorithms.impala.impala import IMPALAConfig
|
| 4 |
+
from ray.rllib.algorithms.impala.impala_learner import IMPALALearner
|
| 5 |
+
from ray.rllib.algorithms.impala.torch.vtrace_torch_v2 import (
|
| 6 |
+
vtrace_torch,
|
| 7 |
+
make_time_major,
|
| 8 |
+
)
|
| 9 |
+
from ray.rllib.core.columns import Columns
|
| 10 |
+
from ray.rllib.core.learner.learner import ENTROPY_KEY
|
| 11 |
+
from ray.rllib.core.learner.torch.torch_learner import TorchLearner
|
| 12 |
+
from ray.rllib.utils.annotations import override
|
| 13 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 14 |
+
from ray.rllib.utils.typing import ModuleID, TensorType
|
| 15 |
+
|
| 16 |
+
torch, nn = try_import_torch()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class IMPALATorchLearner(IMPALALearner, TorchLearner):
|
| 20 |
+
"""Implements the IMPALA loss function in torch."""
|
| 21 |
+
|
| 22 |
+
@override(TorchLearner)
|
| 23 |
+
def compute_loss_for_module(
|
| 24 |
+
self,
|
| 25 |
+
*,
|
| 26 |
+
module_id: ModuleID,
|
| 27 |
+
config: IMPALAConfig,
|
| 28 |
+
batch: Dict,
|
| 29 |
+
fwd_out: Dict[str, TensorType],
|
| 30 |
+
) -> TensorType:
|
| 31 |
+
module = self.module[module_id].unwrapped()
|
| 32 |
+
|
| 33 |
+
# TODO (sven): Now that we do the +1ts trick to be less vulnerable about
|
| 34 |
+
# bootstrap values at the end of rollouts in the new stack, we might make
|
| 35 |
+
# this a more flexible, configurable parameter for users, e.g.
|
| 36 |
+
# `v_trace_seq_len` (independent of `rollout_fragment_length`). Separation
|
| 37 |
+
# of concerns (sampling vs learning).
|
| 38 |
+
rollout_frag_or_episode_len = config.get_rollout_fragment_length()
|
| 39 |
+
recurrent_seq_len = batch.get("seq_lens")
|
| 40 |
+
|
| 41 |
+
loss_mask = batch[Columns.LOSS_MASK].float()
|
| 42 |
+
loss_mask_time_major = make_time_major(
|
| 43 |
+
loss_mask,
|
| 44 |
+
trajectory_len=rollout_frag_or_episode_len,
|
| 45 |
+
recurrent_seq_len=recurrent_seq_len,
|
| 46 |
+
)
|
| 47 |
+
size_loss_mask = torch.sum(loss_mask)
|
| 48 |
+
|
| 49 |
+
# Behavior actions logp and target actions logp.
|
| 50 |
+
behaviour_actions_logp = batch[Columns.ACTION_LOGP]
|
| 51 |
+
target_policy_dist = module.get_train_action_dist_cls().from_logits(
|
| 52 |
+
fwd_out[Columns.ACTION_DIST_INPUTS]
|
| 53 |
+
)
|
| 54 |
+
target_actions_logp = target_policy_dist.logp(batch[Columns.ACTIONS])
|
| 55 |
+
|
| 56 |
+
# Values and bootstrap values.
|
| 57 |
+
values = module.compute_values(
|
| 58 |
+
batch, embeddings=fwd_out.get(Columns.EMBEDDINGS)
|
| 59 |
+
)
|
| 60 |
+
values_time_major = make_time_major(
|
| 61 |
+
values,
|
| 62 |
+
trajectory_len=rollout_frag_or_episode_len,
|
| 63 |
+
recurrent_seq_len=recurrent_seq_len,
|
| 64 |
+
)
|
| 65 |
+
assert Columns.VALUES_BOOTSTRAPPED not in batch
|
| 66 |
+
# Use as bootstrap values the vf-preds in the next "batch row", except
|
| 67 |
+
# for the very last row (which doesn't have a next row), for which the
|
| 68 |
+
# bootstrap value does not matter b/c it has a +1ts value at its end
|
| 69 |
+
# anyways. So we chose an arbitrary item (for simplicity of not having to
|
| 70 |
+
# move new data to the device).
|
| 71 |
+
bootstrap_values = torch.cat(
|
| 72 |
+
[
|
| 73 |
+
values_time_major[0][1:], # 0th ts values from "next row"
|
| 74 |
+
values_time_major[0][0:1], # <- can use any arbitrary value here
|
| 75 |
+
],
|
| 76 |
+
dim=0,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
# TODO(Artur): In the old impala code, actions were unsqueezed if they were
|
| 80 |
+
# multi_discrete. Find out why and if we need to do the same here.
|
| 81 |
+
# actions = actions if is_multidiscrete else torch.unsqueeze(actions, dim=1)
|
| 82 |
+
target_actions_logp_time_major = make_time_major(
|
| 83 |
+
target_actions_logp,
|
| 84 |
+
trajectory_len=rollout_frag_or_episode_len,
|
| 85 |
+
recurrent_seq_len=recurrent_seq_len,
|
| 86 |
+
)
|
| 87 |
+
behaviour_actions_logp_time_major = make_time_major(
|
| 88 |
+
behaviour_actions_logp,
|
| 89 |
+
trajectory_len=rollout_frag_or_episode_len,
|
| 90 |
+
recurrent_seq_len=recurrent_seq_len,
|
| 91 |
+
)
|
| 92 |
+
rewards_time_major = make_time_major(
|
| 93 |
+
batch[Columns.REWARDS],
|
| 94 |
+
trajectory_len=rollout_frag_or_episode_len,
|
| 95 |
+
recurrent_seq_len=recurrent_seq_len,
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# the discount factor that is used should be gamma except for timesteps where
|
| 99 |
+
# the episode is terminated. In that case, the discount factor should be 0.
|
| 100 |
+
discounts_time_major = (
|
| 101 |
+
1.0
|
| 102 |
+
- make_time_major(
|
| 103 |
+
batch[Columns.TERMINATEDS],
|
| 104 |
+
trajectory_len=rollout_frag_or_episode_len,
|
| 105 |
+
recurrent_seq_len=recurrent_seq_len,
|
| 106 |
+
).type(dtype=torch.float32)
|
| 107 |
+
) * config.gamma
|
| 108 |
+
|
| 109 |
+
# Note that vtrace will compute the main loop on the CPU for better performance.
|
| 110 |
+
vtrace_adjusted_target_values, pg_advantages = vtrace_torch(
|
| 111 |
+
target_action_log_probs=target_actions_logp_time_major,
|
| 112 |
+
behaviour_action_log_probs=behaviour_actions_logp_time_major,
|
| 113 |
+
discounts=discounts_time_major,
|
| 114 |
+
rewards=rewards_time_major,
|
| 115 |
+
values=values_time_major,
|
| 116 |
+
bootstrap_values=bootstrap_values,
|
| 117 |
+
clip_rho_threshold=config.vtrace_clip_rho_threshold,
|
| 118 |
+
clip_pg_rho_threshold=config.vtrace_clip_pg_rho_threshold,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
# The policy gradients loss.
|
| 122 |
+
pi_loss = -torch.sum(
|
| 123 |
+
target_actions_logp_time_major * pg_advantages * loss_mask_time_major
|
| 124 |
+
)
|
| 125 |
+
mean_pi_loss = pi_loss / size_loss_mask
|
| 126 |
+
|
| 127 |
+
# The baseline loss.
|
| 128 |
+
delta = values_time_major - vtrace_adjusted_target_values
|
| 129 |
+
vf_loss = 0.5 * torch.sum(torch.pow(delta, 2.0) * loss_mask_time_major)
|
| 130 |
+
mean_vf_loss = vf_loss / size_loss_mask
|
| 131 |
+
|
| 132 |
+
# The entropy loss.
|
| 133 |
+
entropy_loss = -torch.sum(target_policy_dist.entropy() * loss_mask)
|
| 134 |
+
mean_entropy_loss = entropy_loss / size_loss_mask
|
| 135 |
+
|
| 136 |
+
# The summed weighted loss.
|
| 137 |
+
total_loss = (
|
| 138 |
+
mean_pi_loss
|
| 139 |
+
+ mean_vf_loss * config.vf_loss_coeff
|
| 140 |
+
+ (
|
| 141 |
+
mean_entropy_loss
|
| 142 |
+
* self.entropy_coeff_schedulers_per_module[
|
| 143 |
+
module_id
|
| 144 |
+
].get_current_value()
|
| 145 |
+
)
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
# Log important loss stats.
|
| 149 |
+
self.metrics.log_dict(
|
| 150 |
+
{
|
| 151 |
+
"pi_loss": pi_loss,
|
| 152 |
+
"mean_pi_loss": mean_pi_loss,
|
| 153 |
+
"vf_loss": vf_loss,
|
| 154 |
+
"mean_vf_loss": mean_vf_loss,
|
| 155 |
+
ENTROPY_KEY: -mean_entropy_loss,
|
| 156 |
+
},
|
| 157 |
+
key=module_id,
|
| 158 |
+
window=1, # <- single items (should not be mean/ema-reduced over time).
|
| 159 |
+
)
|
| 160 |
+
# Return the total loss.
|
| 161 |
+
return total_loss
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
ImpalaTorchLearner = IMPALATorchLearner
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/vtrace_torch_v2.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Union
|
| 2 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 3 |
+
|
| 4 |
+
torch, nn = try_import_torch()
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def make_time_major(
|
| 8 |
+
tensor: Union["torch.Tensor", List["torch.Tensor"]],
|
| 9 |
+
*,
|
| 10 |
+
trajectory_len: int = None,
|
| 11 |
+
recurrent_seq_len: int = None,
|
| 12 |
+
):
|
| 13 |
+
"""Swaps batch and trajectory axis.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
tensor: A tensor or list of tensors to swap the axis of.
|
| 17 |
+
NOTE: Each tensor must have the shape [B * T] where B is the batch size and
|
| 18 |
+
T is the trajectory length.
|
| 19 |
+
trajectory_len: The length of each trajectory being transformed.
|
| 20 |
+
If None then `recurrent_seq_len` must be set.
|
| 21 |
+
recurrent_seq_len: Sequence lengths if recurrent.
|
| 22 |
+
If None then `trajectory_len` must be set.
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
res: A tensor with swapped axes or a list of tensors with
|
| 26 |
+
swapped axes.
|
| 27 |
+
"""
|
| 28 |
+
if isinstance(tensor, (list, tuple)):
|
| 29 |
+
return [
|
| 30 |
+
make_time_major(_tensor, trajectory_len, recurrent_seq_len)
|
| 31 |
+
for _tensor in tensor
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
assert (
|
| 35 |
+
trajectory_len is not None or recurrent_seq_len is not None
|
| 36 |
+
), "Either trajectory_len or recurrent_seq_len must be set."
|
| 37 |
+
|
| 38 |
+
# Figure out the sizes of the final B and T axes.
|
| 39 |
+
if recurrent_seq_len is not None:
|
| 40 |
+
assert len(tensor.shape) == 2
|
| 41 |
+
# Swap B and T axes.
|
| 42 |
+
tensor = torch.transpose(tensor, 1, 0)
|
| 43 |
+
return tensor
|
| 44 |
+
else:
|
| 45 |
+
T = trajectory_len
|
| 46 |
+
# Zero-pad, if necessary.
|
| 47 |
+
tensor_0 = tensor.shape[0]
|
| 48 |
+
B = tensor_0 // T
|
| 49 |
+
if B != (tensor_0 / T):
|
| 50 |
+
assert len(tensor.shape) == 1
|
| 51 |
+
tensor = torch.cat(
|
| 52 |
+
[
|
| 53 |
+
tensor,
|
| 54 |
+
torch.zeros(
|
| 55 |
+
trajectory_len - tensor_0 % T,
|
| 56 |
+
dtype=tensor.dtype,
|
| 57 |
+
device=tensor.device,
|
| 58 |
+
),
|
| 59 |
+
]
|
| 60 |
+
)
|
| 61 |
+
B += 1
|
| 62 |
+
|
| 63 |
+
# Reshape tensor (break up B axis into 2 axes: B and T).
|
| 64 |
+
tensor = torch.reshape(tensor, [B, T] + list(tensor.shape[1:]))
|
| 65 |
+
|
| 66 |
+
# Swap B and T axes.
|
| 67 |
+
tensor = torch.transpose(tensor, 1, 0)
|
| 68 |
+
|
| 69 |
+
return tensor
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def vtrace_torch(
|
| 73 |
+
*,
|
| 74 |
+
target_action_log_probs: "torch.Tensor",
|
| 75 |
+
behaviour_action_log_probs: "torch.Tensor",
|
| 76 |
+
discounts: "torch.Tensor",
|
| 77 |
+
rewards: "torch.Tensor",
|
| 78 |
+
values: "torch.Tensor",
|
| 79 |
+
bootstrap_values: "torch.Tensor",
|
| 80 |
+
clip_rho_threshold: Union[float, "torch.Tensor"] = 1.0,
|
| 81 |
+
clip_pg_rho_threshold: Union[float, "torch.Tensor"] = 1.0,
|
| 82 |
+
):
|
| 83 |
+
"""V-trace for softmax policies implemented with torch.
|
| 84 |
+
|
| 85 |
+
Calculates V-trace actor critic targets for softmax polices as described in
|
| 86 |
+
"IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner
|
| 87 |
+
Architectures" by Espeholt, Soyer, Munos et al. (https://arxiv.org/abs/1802.01561)
|
| 88 |
+
|
| 89 |
+
The V-trace implementation used here closely resembles the one found in the
|
| 90 |
+
scalable-agent repository by Google DeepMind, available at
|
| 91 |
+
https://github.com/deepmind/scalable_agent. This version has been optimized to
|
| 92 |
+
minimize the number of floating-point operations required per V-Trace
|
| 93 |
+
calculation, achieved through the use of dynamic programming techniques. It's
|
| 94 |
+
important to note that the mathematical expressions used in this implementation
|
| 95 |
+
may appear quite different from those presented in the IMPALA paper.
|
| 96 |
+
|
| 97 |
+
The following terminology applies:
|
| 98 |
+
- `target policy` refers to the policy we are interested in improving.
|
| 99 |
+
- `behaviour policy` refers to the policy that generated the given
|
| 100 |
+
rewards and actions.
|
| 101 |
+
- `T` refers to the time dimension. This is usually either the length of the
|
| 102 |
+
trajectory or the length of the sequence if recurrent.
|
| 103 |
+
- `B` refers to the batch size.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
target_action_log_probs: Action log probs from the target policy. A float32
|
| 107 |
+
tensor of shape [T, B].
|
| 108 |
+
behaviour_action_log_probs: Action log probs from the behaviour policy. A
|
| 109 |
+
float32 tensor of shape [T, B].
|
| 110 |
+
discounts: A float32 tensor of shape [T, B] with the discount encountered when
|
| 111 |
+
following the behaviour policy. This will be 0 for terminal timesteps
|
| 112 |
+
(done=True) and gamma (the discount factor) otherwise.
|
| 113 |
+
rewards: A float32 tensor of shape [T, B] with the rewards generated by
|
| 114 |
+
following the behaviour policy.
|
| 115 |
+
values: A float32 tensor of shape [T, B] with the value function estimates
|
| 116 |
+
wrt. the target policy.
|
| 117 |
+
bootstrap_values: A float32 of shape [B] with the value function estimate at
|
| 118 |
+
time T.
|
| 119 |
+
clip_rho_threshold: A scalar float32 tensor with the clipping threshold for
|
| 120 |
+
importance weights (rho) when calculating the baseline targets (vs).
|
| 121 |
+
rho^bar in the paper.
|
| 122 |
+
clip_pg_rho_threshold: A scalar float32 tensor with the clipping threshold
|
| 123 |
+
on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).
|
| 124 |
+
"""
|
| 125 |
+
log_rhos = target_action_log_probs - behaviour_action_log_probs
|
| 126 |
+
|
| 127 |
+
rhos = torch.exp(log_rhos)
|
| 128 |
+
if clip_rho_threshold is not None:
|
| 129 |
+
clipped_rhos = torch.clamp(rhos, max=clip_rho_threshold)
|
| 130 |
+
else:
|
| 131 |
+
clipped_rhos = rhos
|
| 132 |
+
|
| 133 |
+
cs = torch.clamp(rhos, max=1.0)
|
| 134 |
+
# Append bootstrapped value to get [v1, ..., v_t+1]
|
| 135 |
+
values_t_plus_1 = torch.cat(
|
| 136 |
+
[values[1:], torch.unsqueeze(bootstrap_values, 0)], axis=0
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)
|
| 140 |
+
|
| 141 |
+
# Only move the for-loop to CPU.
|
| 142 |
+
discounts_cpu = discounts.to("cpu")
|
| 143 |
+
cs_cpu = cs.to("cpu")
|
| 144 |
+
deltas_cpu = deltas.to("cpu")
|
| 145 |
+
vs_minus_v_xs_cpu = [torch.zeros_like(bootstrap_values, device="cpu")]
|
| 146 |
+
for i in reversed(range(len(discounts_cpu))):
|
| 147 |
+
discount_t, c_t, delta_t = discounts_cpu[i], cs_cpu[i], deltas_cpu[i]
|
| 148 |
+
vs_minus_v_xs_cpu.append(delta_t + discount_t * c_t * vs_minus_v_xs_cpu[-1])
|
| 149 |
+
vs_minus_v_xs_cpu = torch.stack(vs_minus_v_xs_cpu[1:])
|
| 150 |
+
# Move results back to GPU - if applicable.
|
| 151 |
+
vs_minus_v_xs = vs_minus_v_xs_cpu.to(deltas.device)
|
| 152 |
+
|
| 153 |
+
# Reverse the results back to original order.
|
| 154 |
+
vs_minus_v_xs = torch.flip(vs_minus_v_xs, dims=[0])
|
| 155 |
+
|
| 156 |
+
# Add V(x_s) to get v_s.
|
| 157 |
+
vs = torch.add(vs_minus_v_xs, values)
|
| 158 |
+
|
| 159 |
+
# Advantage for policy gradient.
|
| 160 |
+
vs_t_plus_1 = torch.cat([vs[1:], torch.unsqueeze(bootstrap_values, 0)], axis=0)
|
| 161 |
+
if clip_pg_rho_threshold is not None:
|
| 162 |
+
clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold)
|
| 163 |
+
else:
|
| 164 |
+
clipped_pg_rhos = rhos
|
| 165 |
+
pg_advantages = clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values)
|
| 166 |
+
|
| 167 |
+
# Make sure no gradients backpropagated through the returned values.
|
| 168 |
+
return torch.detach(vs), torch.detach(pg_advantages)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/centralized_critic.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @OldAPIStack
|
| 2 |
+
|
| 3 |
+
# ***********************************************************************************
|
| 4 |
+
# IMPORTANT NOTE: This script uses the old API stack and will soon be replaced by
|
| 5 |
+
# `ray.rllib.examples.multi_agent.pettingzoo_shared_value_function.py`!
|
| 6 |
+
# ***********************************************************************************
|
| 7 |
+
|
| 8 |
+
"""An example of customizing PPO to leverage a centralized critic.
|
| 9 |
+
|
| 10 |
+
Here the model and policy are hard-coded to implement a centralized critic
|
| 11 |
+
for TwoStepGame, but you can adapt this for your own use cases.
|
| 12 |
+
|
| 13 |
+
Compared to simply running `rllib/examples/two_step_game.py --run=PPO`,
|
| 14 |
+
this centralized critic version reaches vf_explained_variance=1.0 more stably
|
| 15 |
+
since it takes into account the opponent actions as well as the policy's.
|
| 16 |
+
Note that this is also using two independent policies instead of weight-sharing
|
| 17 |
+
with one.
|
| 18 |
+
|
| 19 |
+
See also: centralized_critic_2.py for a simpler approach that instead
|
| 20 |
+
modifies the environment.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
from gymnasium.spaces import Discrete
|
| 25 |
+
import numpy as np
|
| 26 |
+
import os
|
| 27 |
+
|
| 28 |
+
import ray
|
| 29 |
+
from ray import air, tune
|
| 30 |
+
from ray.air.constants import TRAINING_ITERATION
|
| 31 |
+
from ray.rllib.algorithms.ppo.ppo import PPO, PPOConfig
|
| 32 |
+
from ray.rllib.algorithms.ppo.ppo_tf_policy import (
|
| 33 |
+
PPOTF1Policy,
|
| 34 |
+
PPOTF2Policy,
|
| 35 |
+
)
|
| 36 |
+
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
|
| 37 |
+
from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing
|
| 38 |
+
from ray.rllib.examples.envs.classes.multi_agent.two_step_game import TwoStepGame
|
| 39 |
+
from ray.rllib.examples._old_api_stack.models.centralized_critic_models import (
|
| 40 |
+
CentralizedCriticModel,
|
| 41 |
+
TorchCentralizedCriticModel,
|
| 42 |
+
)
|
| 43 |
+
from ray.rllib.models import ModelCatalog
|
| 44 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 45 |
+
from ray.rllib.utils.annotations import override
|
| 46 |
+
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
| 47 |
+
from ray.rllib.utils.metrics import (
|
| 48 |
+
ENV_RUNNER_RESULTS,
|
| 49 |
+
EPISODE_RETURN_MEAN,
|
| 50 |
+
NUM_ENV_STEPS_SAMPLED_LIFETIME,
|
| 51 |
+
)
|
| 52 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 53 |
+
from ray.rllib.utils.test_utils import check_learning_achieved
|
| 54 |
+
from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable
|
| 55 |
+
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
|
| 56 |
+
|
| 57 |
+
tf1, tf, tfv = try_import_tf()
|
| 58 |
+
torch, nn = try_import_torch()
|
| 59 |
+
|
| 60 |
+
OPPONENT_OBS = "opponent_obs"
|
| 61 |
+
OPPONENT_ACTION = "opponent_action"
|
| 62 |
+
|
| 63 |
+
parser = argparse.ArgumentParser()
|
| 64 |
+
parser.add_argument(
|
| 65 |
+
"--framework",
|
| 66 |
+
choices=["tf", "tf2", "torch"],
|
| 67 |
+
default="torch",
|
| 68 |
+
help="The DL framework specifier.",
|
| 69 |
+
)
|
| 70 |
+
parser.add_argument(
|
| 71 |
+
"--as-test",
|
| 72 |
+
action="store_true",
|
| 73 |
+
help="Whether this script should be run as a test: --stop-reward must "
|
| 74 |
+
"be achieved within --stop-timesteps AND --stop-iters.",
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--stop-iters", type=int, default=100, help="Number of iterations to train."
|
| 78 |
+
)
|
| 79 |
+
parser.add_argument(
|
| 80 |
+
"--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
|
| 81 |
+
)
|
| 82 |
+
parser.add_argument(
|
| 83 |
+
"--stop-reward", type=float, default=7.99, help="Reward at which we stop training."
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class CentralizedValueMixin:
|
| 88 |
+
"""Add method to evaluate the central value function from the model."""
|
| 89 |
+
|
| 90 |
+
def __init__(self):
|
| 91 |
+
if self.config["framework"] != "torch":
|
| 92 |
+
self.compute_central_vf = make_tf_callable(self.get_session())(
|
| 93 |
+
self.model.central_value_function
|
| 94 |
+
)
|
| 95 |
+
else:
|
| 96 |
+
self.compute_central_vf = self.model.central_value_function
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# Grabs the opponent obs/act and includes it in the experience train_batch,
|
| 100 |
+
# and computes GAE using the central vf predictions.
|
| 101 |
+
def centralized_critic_postprocessing(
|
| 102 |
+
policy, sample_batch, other_agent_batches=None, episode=None
|
| 103 |
+
):
|
| 104 |
+
pytorch = policy.config["framework"] == "torch"
|
| 105 |
+
if (pytorch and hasattr(policy, "compute_central_vf")) or (
|
| 106 |
+
not pytorch and policy.loss_initialized()
|
| 107 |
+
):
|
| 108 |
+
assert other_agent_batches is not None
|
| 109 |
+
[(_, _, opponent_batch)] = list(other_agent_batches.values())
|
| 110 |
+
|
| 111 |
+
# also record the opponent obs and actions in the trajectory
|
| 112 |
+
sample_batch[OPPONENT_OBS] = opponent_batch[SampleBatch.CUR_OBS]
|
| 113 |
+
sample_batch[OPPONENT_ACTION] = opponent_batch[SampleBatch.ACTIONS]
|
| 114 |
+
|
| 115 |
+
# overwrite default VF prediction with the central VF
|
| 116 |
+
if args.framework == "torch":
|
| 117 |
+
sample_batch[SampleBatch.VF_PREDS] = (
|
| 118 |
+
policy.compute_central_vf(
|
| 119 |
+
convert_to_torch_tensor(
|
| 120 |
+
sample_batch[SampleBatch.CUR_OBS], policy.device
|
| 121 |
+
),
|
| 122 |
+
convert_to_torch_tensor(sample_batch[OPPONENT_OBS], policy.device),
|
| 123 |
+
convert_to_torch_tensor(
|
| 124 |
+
sample_batch[OPPONENT_ACTION], policy.device
|
| 125 |
+
),
|
| 126 |
+
)
|
| 127 |
+
.cpu()
|
| 128 |
+
.detach()
|
| 129 |
+
.numpy()
|
| 130 |
+
)
|
| 131 |
+
else:
|
| 132 |
+
sample_batch[SampleBatch.VF_PREDS] = convert_to_numpy(
|
| 133 |
+
policy.compute_central_vf(
|
| 134 |
+
sample_batch[SampleBatch.CUR_OBS],
|
| 135 |
+
sample_batch[OPPONENT_OBS],
|
| 136 |
+
sample_batch[OPPONENT_ACTION],
|
| 137 |
+
)
|
| 138 |
+
)
|
| 139 |
+
else:
|
| 140 |
+
# Policy hasn't been initialized yet, use zeros.
|
| 141 |
+
sample_batch[OPPONENT_OBS] = np.zeros_like(sample_batch[SampleBatch.CUR_OBS])
|
| 142 |
+
sample_batch[OPPONENT_ACTION] = np.zeros_like(sample_batch[SampleBatch.ACTIONS])
|
| 143 |
+
sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(
|
| 144 |
+
sample_batch[SampleBatch.REWARDS], dtype=np.float32
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
completed = sample_batch[SampleBatch.TERMINATEDS][-1]
|
| 148 |
+
if completed:
|
| 149 |
+
last_r = 0.0
|
| 150 |
+
else:
|
| 151 |
+
last_r = sample_batch[SampleBatch.VF_PREDS][-1]
|
| 152 |
+
|
| 153 |
+
train_batch = compute_advantages(
|
| 154 |
+
sample_batch,
|
| 155 |
+
last_r,
|
| 156 |
+
policy.config["gamma"],
|
| 157 |
+
policy.config["lambda"],
|
| 158 |
+
use_gae=policy.config["use_gae"],
|
| 159 |
+
)
|
| 160 |
+
return train_batch
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# Copied from PPO but optimizing the central value function.
|
| 164 |
+
def loss_with_central_critic(policy, base_policy, model, dist_class, train_batch):
    """Compute the base PPO loss, but using the centralized value function.

    Temporarily monkey-patches `model.value_function` with a closure that
    queries the central critic on (own obs, opponent obs, opponent action),
    runs the base policy's loss, then restores the original value function.

    Also caches the central VF output on `policy._central_value_out` so that
    `central_vf_stats` can report its explained variance.

    Args:
        policy: The policy owning the model with `central_value_function`.
        base_policy: The (super()) PPO policy whose `loss` is delegated to.
        model: The model whose `value_function` is temporarily replaced.
        dist_class: The action distribution class.
        train_batch: The train batch (must contain opponent obs/actions).

    Returns:
        The loss value produced by `base_policy.loss`.
    """
    # Save original value function.
    vf_saved = model.value_function

    # Patch in the central value function.
    model.value_function = lambda: policy.model.central_value_function(
        train_batch[SampleBatch.CUR_OBS],
        train_batch[OPPONENT_OBS],
        train_batch[OPPONENT_ACTION],
    )
    # try/finally ensures the original value function is restored even if the
    # loss computation raises; otherwise the model would stay patched and
    # corrupt every subsequent loss/inference call.
    try:
        policy._central_value_out = model.value_function()
        loss = base_policy.loss(model, dist_class, train_batch)
    finally:
        # Restore original value function.
        model.value_function = vf_saved

    return loss
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def central_vf_stats(policy, train_batch):
    """Return stats reporting the central VF's explained variance.

    Compares the cached central value function output against the
    GAE value targets in `train_batch`.
    """
    value_targets = train_batch[Postprocessing.VALUE_TARGETS]
    central_vf_out = policy._central_value_out
    return {"vf_explained_var": explained_variance(value_targets, central_vf_out)}
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def get_ccppo_policy(base):
    """Build a TF centralized-critic PPO policy class on top of `base`.

    Args:
        base: The TF PPO policy class to derive from (used below with
            PPOTF1Policy and PPOTF2Policy).

    Returns:
        A new policy class mixing `CentralizedValueMixin` into `base`.
    """

    class CCPPOTFPolicy(CentralizedValueMixin, base):
        def __init__(self, observation_space, action_space, config):
            base.__init__(self, observation_space, action_space, config)
            CentralizedValueMixin.__init__(self)

        @override(base)
        def loss(self, model, dist_class, train_batch):
            # Use super() to get to the base PPO policy.
            # This special loss function utilizes a shared
            # value function defined on self, and the loss function
            # defined on PPO policies.
            return loss_with_central_critic(
                self, super(), model, dist_class, train_batch
            )

        @override(base)
        def postprocess_trajectory(
            self, sample_batch, other_agent_batches=None, episode=None
        ):
            # Attach opponent obs/actions and recompute advantages with the
            # central VF's predictions.
            return centralized_critic_postprocessing(
                self, sample_batch, other_agent_batches, episode
            )

        @override(base)
        def stats_fn(self, train_batch: SampleBatch):
            # Extend the base PPO stats with the central-VF explained variance.
            stats = super().stats_fn(train_batch)
            stats.update(central_vf_stats(self, train_batch))
            return stats

    return CCPPOTFPolicy
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
CCPPOStaticGraphTFPolicy = get_ccppo_policy(PPOTF1Policy)
|
| 226 |
+
CCPPOEagerTFPolicy = get_ccppo_policy(PPOTF2Policy)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class CCPPOTorchPolicy(CentralizedValueMixin, PPOTorchPolicy):
    """Torch PPO policy that trains against the centralized value function."""

    def __init__(self, observation_space, action_space, config):
        PPOTorchPolicy.__init__(self, observation_space, action_space, config)
        CentralizedValueMixin.__init__(self)

    @override(PPOTorchPolicy)
    def loss(self, model, dist_class, train_batch):
        # Delegate to the PPO loss with the central VF temporarily patched in.
        return loss_with_central_critic(self, super(), model, dist_class, train_batch)

    @override(PPOTorchPolicy)
    def postprocess_trajectory(
        self, sample_batch, other_agent_batches=None, episode=None
    ):
        # Attach opponent obs/actions and recompute advantages with the
        # central VF's predictions.
        return centralized_critic_postprocessing(
            self, sample_batch, other_agent_batches, episode
        )
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class CentralizedCritic(PPO):
    """PPO Algorithm whose default policies use the centralized critic."""

    @classmethod
    @override(PPO)
    def get_default_policy_class(cls, config):
        # Map the framework string to the matching centralized-critic policy.
        # Any framework other than "torch"/"tf" falls back to the eager
        # (TF2) policy, exactly as the original if/elif chain did.
        policy_by_framework = {
            "torch": CCPPOTorchPolicy,
            "tf": CCPPOStaticGraphTFPolicy,
        }
        return policy_by_framework.get(config["framework"], CCPPOEagerTFPolicy)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
if __name__ == "__main__":
    # Local mode: run everything in a single process (useful for debugging).
    ray.init(local_mode=True)
    args = parser.parse_args()

    # Register the framework-matching centralized-critic model under the name
    # referenced below via `.training(model={"custom_model": "cc_model"})`.
    ModelCatalog.register_custom_model(
        "cc_model",
        TorchCentralizedCriticModel
        if args.framework == "torch"
        else CentralizedCriticModel,
    )

    config = (
        PPOConfig()
        # Stay on the old API stack; this example relies on Policy subclasses.
        .api_stack(
            enable_env_runner_and_connector_v2=False,
            enable_rl_module_and_learner=False,
        )
        .environment(TwoStepGame)
        .framework(args.framework)
        .env_runners(batch_mode="complete_episodes", num_env_runners=0)
        .training(model={"custom_model": "cc_model"})
        .multi_agent(
            # One policy per agent; both share the same obs/action spaces.
            policies={
                "pol1": (
                    None,
                    Discrete(6),
                    TwoStepGame.action_space,
                    # `framework` would also be ok here.
                    PPOConfig.overrides(framework_str=args.framework),
                ),
                "pol2": (
                    None,
                    Discrete(6),
                    TwoStepGame.action_space,
                    # `framework` would also be ok here.
                    PPOConfig.overrides(framework_str=args.framework),
                ),
            },
            # Agent 0 -> "pol1", every other agent -> "pol2".
            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "pol1"
            if agent_id == 0
            else "pol2",
        )
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
    )

    # Stop on whichever comes first: iteration count, sampled env steps, or
    # mean episode return.
    stop = {
        TRAINING_ITERATION: args.stop_iters,
        NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps,
        f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward,
    }

    tuner = tune.Tuner(
        CentralizedCritic,
        param_space=config.to_dict(),
        run_config=air.RunConfig(stop=stop, verbose=1),
    )
    results = tuner.fit()

    # Optionally fail loudly if the target reward was not reached.
    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/compute_adapted_gae_on_postprocess_trajectory.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @OldAPIStack
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
Adapted (time-dependent) GAE for PPO algorithm that you can activate by setting
|
| 5 |
+
use_adapted_gae=True in the policy config. Additionally, it's required that
|
| 6 |
+
"callbacks" include the custom callback class in the Algorithm's config.
|
| 7 |
+
Furthermore, the env must return in its info dictionary a key-value pair of
|
| 8 |
+
the form "d_ts": ... where the value is the length (time) of recent agent step.
|
| 9 |
+
|
| 10 |
+
This adapted, time-dependent computation of advantages may be useful in cases
|
| 11 |
+
where agent's actions take various times and thus time steps are not
|
| 12 |
+
equidistant (https://docdro.id/400TvlR)
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from ray.rllib.callbacks.callbacks import RLlibCallback
|
| 16 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 17 |
+
from ray.rllib.evaluation.postprocessing import Postprocessing
|
| 18 |
+
from ray.rllib.utils.annotations import override
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class MyCallbacks(RLlibCallback):
    """Callback that recomputes GAE with time-step-length-aware discounting.

    Activated per policy via `use_adapted_gae=True` in the policy config
    (which also requires `use_gae=True`). The env must report the length of
    each agent step in its info dict under the key "d_ts".
    """

    @override(RLlibCallback)
    def on_postprocess_trajectory(
        self,
        *,
        worker,
        episode,
        agent_id,
        policy_id,
        policies,
        postprocessed_batch,
        original_batches,
        **kwargs
    ):
        super().on_postprocess_trajectory(
            worker=worker,
            episode=episode,
            agent_id=agent_id,
            policy_id=policy_id,
            policies=policies,
            postprocessed_batch=postprocessed_batch,
            original_batches=original_batches,
            **kwargs
        )

        if policies[policy_id].config.get("use_adapted_gae", False):
            policy = policies[policy_id]
            assert policy.config[
                "use_gae"
            ], "Can't use adapted gae without use_gae=True!"

            info_dicts = postprocessed_batch[SampleBatch.INFOS]
            assert np.all(
                ["d_ts" in info_dict for info_dict in info_dicts]
            ), "Info dicts in sample batch must contain data 'd_ts' \
                (=ts[i+1]-ts[i] length of time steps)!"

            # BUGFIX: `np.float` was deprecated in NumPy 1.20 and removed in
            # NumPy 1.24 -- it raises AttributeError there. The builtin
            # `float` is the documented replacement and behaves identically.
            d_ts = np.array(
                [float(info_dict.get("d_ts")) for info_dict in info_dicts]
            )
            assert np.all(
                [e.is_integer() for e in d_ts]
            ), "Elements of 'd_ts' (length of time steps) must be integer!"

            # Trajectory is actually complete -> last r=0.0.
            if postprocessed_batch[SampleBatch.TERMINATEDS][-1]:
                last_r = 0.0
            # Trajectory has been truncated -> last r=VF estimate of last obs.
            else:
                # Create a single-timestep (last one in trajectory) input dict
                # according to the Model's view requirements.
                input_dict = postprocessed_batch.get_single_step_input_dict(
                    policy.model.view_requirements, index="last"
                )
                last_r = policy._value(**input_dict)

            gamma = policy.config["gamma"]
            lambda_ = policy.config["lambda"]

            # VF predictions with the bootstrap value appended.
            vpred_t = np.concatenate(
                [postprocessed_batch[SampleBatch.VF_PREDS], np.array([last_r])]
            )
            # One-step TD residuals; the next-state value is discounted by
            # gamma raised to the (variable) step length.
            delta_t = (
                postprocessed_batch[SampleBatch.REWARDS]
                + gamma**d_ts * vpred_t[1:]
                - vpred_t[:-1]
            )
            # This formula for the advantage is an adaption of
            # "Generalized Advantage Estimation"
            # (https://arxiv.org/abs/1506.02438), which accounts for time
            # steps of irregular length.
            # NOTE: last time step delta is not required.
            postprocessed_batch[
                Postprocessing.ADVANTAGES
            ] = generalized_discount_cumsum(delta_t, d_ts[:-1], gamma * lambda_)
            postprocessed_batch[Postprocessing.VALUE_TARGETS] = (
                postprocessed_batch[Postprocessing.ADVANTAGES]
                + postprocessed_batch[SampleBatch.VF_PREDS]
            ).astype(np.float32)

            postprocessed_batch[Postprocessing.ADVANTAGES] = postprocessed_batch[
                Postprocessing.ADVANTAGES
            ].astype(np.float32)
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def generalized_discount_cumsum(
    x: np.ndarray, deltas: np.ndarray, gamma: float
) -> np.ndarray:
    """Discounted cumulative sum over `x` with variable-length time steps.

    Solves the backward recursion:

        y[T-1] = x[T-1]
        y[t]   = x[t] + gamma**deltas[t] * y[t+1]

    where `deltas[t]` is the (possibly non-unit) length of time step t, so
    discounts accumulate multiplicatively over the elapsed time.

    Args:
        x: A sequence of rewards or one-step TD residuals (length T).
        deltas: Time-step lengths; element t discounts from step t to step
            t+1 (length T-1).
        gamma: The discount factor.

    Returns:
        np.ndarray: Array of the same length as `x` holding the
        time-dependent discounted cumulative sum for each position until the
        end of the trajectory.
    """
    rev_x = x[::-1]
    rev_deltas = deltas[::-1]
    out = np.empty_like(x)
    # Seed with the final element, then accumulate from the back of the
    # trajectory toward the front.
    out[0] = rev_x[0]
    for i, (x_i, d_prev) in enumerate(zip(rev_x[1:], rev_deltas), start=1):
        out[i] = x_i + gamma**d_prev * out[i - 1]
    return out[::-1]
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/agents_act_in_sequence.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Example of running a multi-agent experiment w/ agents taking turns (sequence).
|
| 2 |
+
|
| 3 |
+
This example:
|
| 4 |
+
- demonstrates how to write your own (multi-agent) environment using RLlib's
|
| 5 |
+
MultiAgentEnv API.
|
| 6 |
+
- shows how to implement the `reset()` and `step()` methods of the env such that
|
| 7 |
+
the agents act in a fixed sequence (taking turns).
|
| 8 |
+
- shows how to configure and setup this environment class within an RLlib
|
| 9 |
+
Algorithm config.
|
| 10 |
+
- runs the experiment with the configured algo, trying to solve the environment.
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
How to run this script
|
| 14 |
+
----------------------
|
| 15 |
+
`python [script file name].py --enable-new-api-stack`
|
| 16 |
+
|
| 17 |
+
For debugging, use the following additional command line options
|
| 18 |
+
`--no-tune --num-env-runners=0`
|
| 19 |
+
which should allow you to set breakpoints anywhere in the RLlib code and
|
| 20 |
+
have the execution stop there for inspection and debugging.
|
| 21 |
+
|
| 22 |
+
For logging to your WandB account, use:
|
| 23 |
+
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
|
| 24 |
+
--wandb-run-name=[optional: WandB run name (within the defined project)]`
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
Results to expect
|
| 28 |
+
-----------------
|
| 29 |
+
You should see results similar to the following in your console output:
|
| 30 |
+
+---------------------------+----------+--------+------------------+--------+
|
| 31 |
+
| Trial name | status | iter | total time (s) | ts |
|
| 32 |
+
|---------------------------+----------+--------+------------------+--------+
|
| 33 |
+
| PPO_TicTacToe_957aa_00000 | RUNNING | 25 | 96.7452 | 100000 |
|
| 34 |
+
+---------------------------+----------+--------+------------------+--------+
|
| 35 |
+
+-------------------+------------------+------------------+
|
| 36 |
+
| combined return | return player2 | return player1 |
|
| 37 |
+
|-------------------+------------------+------------------|
|
| 38 |
+
| -2 | 1.15 | -0.85 |
|
| 39 |
+
+-------------------+------------------+------------------+
|
| 40 |
+
|
| 41 |
+
Note that even though we are playing a zero-sum game, the overall return should start
|
| 42 |
+
at some negative values due to the misplacement penalty of our (simplified) TicTacToe
|
| 43 |
+
game.
|
| 44 |
+
"""
|
| 45 |
+
from ray.rllib.examples.envs.classes.multi_agent.tic_tac_toe import TicTacToe
|
| 46 |
+
from ray.rllib.utils.test_utils import (
|
| 47 |
+
add_rllib_example_script_args,
|
| 48 |
+
run_rllib_example_script_experiment,
|
| 49 |
+
)
|
| 50 |
+
from ray.tune.registry import get_trainable_cls, register_env # noqa
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# CLI parser with the standard RLlib example-script arguments; the negative
# default reward reflects the misplacement penalty in this TicTacToe variant.
parser = add_rllib_example_script_args(
    default_reward=-4.0, default_iters=50, default_timesteps=100000
)
parser.set_defaults(
    enable_new_api_stack=True,
    num_agents=2,
)


if __name__ == "__main__":
    args = parser.parse_args()

    # TicTacToe has exactly two players.
    assert args.num_agents == 2, "Must set --num-agents=2 when running this script!"

    # You can also register the env creator function explicitly with:
    # register_env("tic_tac_toe", lambda cfg: TicTacToe())

    # Or allow the RLlib user to set more c'tor options via their algo config:
    # config.environment(env_config={[c'tor arg name]: [value]})
    # register_env("tic_tac_toe", lambda cfg: TicTacToe(cfg))

    base_config = (
        get_trainable_cls(args.algo)
        .get_default_config()
        .environment(TicTacToe)
        .multi_agent(
            # Define two policies.
            policies={"player1", "player2"},
            # Map agent "player1" to policy "player1" and agent "player2" to policy
            # "player2".
            policy_mapping_fn=lambda agent_id, episode, **kw: agent_id,
        )
    )

    # Run via Tune (or directly, depending on the parsed CLI options).
    run_rllib_example_script_experiment(base_config, args)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/agents_act_simultaneously.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Example of running a multi-agent experiment w/ agents always acting simultaneously.
|
| 2 |
+
|
| 3 |
+
This example:
|
| 4 |
+
- demonstrates how to write your own (multi-agent) environment using RLlib's
|
| 5 |
+
MultiAgentEnv API.
|
| 6 |
+
- shows how to implement the `reset()` and `step()` methods of the env such that
|
| 7 |
+
the agents act simultaneously.
|
| 8 |
+
- shows how to configure and setup this environment class within an RLlib
|
| 9 |
+
Algorithm config.
|
| 10 |
+
- runs the experiment with the configured algo, trying to solve the environment.
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
How to run this script
|
| 14 |
+
----------------------
|
| 15 |
+
`python [script file name].py --enable-new-api-stack --sheldon-cooper-mode`
|
| 16 |
+
|
| 17 |
+
For debugging, use the following additional command line options
|
| 18 |
+
`--no-tune --num-env-runners=0`
|
| 19 |
+
which should allow you to set breakpoints anywhere in the RLlib code and
|
| 20 |
+
have the execution stop there for inspection and debugging.
|
| 21 |
+
|
| 22 |
+
For logging to your WandB account, use:
|
| 23 |
+
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
|
| 24 |
+
--wandb-run-name=[optional: WandB run name (within the defined project)]`
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
Results to expect
|
| 28 |
+
-----------------
|
| 29 |
+
You should see results similar to the following in your console output:
|
| 30 |
+
|
| 31 |
+
+-----------------------------------+----------+--------+------------------+-------+
|
| 32 |
+
| Trial name | status | iter | total time (s) | ts |
|
| 33 |
+
|-----------------------------------+----------+--------+------------------+-------+
|
| 34 |
+
| PPO_RockPaperScissors_8cef7_00000 | RUNNING | 3 | 16.5348 | 12000 |
|
| 35 |
+
+-----------------------------------+----------+--------+------------------+-------+
|
| 36 |
+
+-------------------+------------------+------------------+
|
| 37 |
+
| combined return | return player2 | return player1 |
|
| 38 |
+
|-------------------+------------------+------------------|
|
| 39 |
+
| 0 | -0.15 | 0.15 |
|
| 40 |
+
+-------------------+------------------+------------------+
|
| 41 |
+
|
| 42 |
+
Note that b/c we are playing a zero-sum game, the overall return remains 0.0 at
|
| 43 |
+
all times.
|
| 44 |
+
"""
|
| 45 |
+
from ray.rllib.examples.envs.classes.multi_agent.rock_paper_scissors import (
|
| 46 |
+
RockPaperScissors,
|
| 47 |
+
)
|
| 48 |
+
from ray.rllib.connectors.env_to_module.flatten_observations import FlattenObservations
|
| 49 |
+
from ray.rllib.utils.test_utils import (
|
| 50 |
+
add_rllib_example_script_args,
|
| 51 |
+
run_rllib_example_script_experiment,
|
| 52 |
+
)
|
| 53 |
+
from ray.tune.registry import get_trainable_cls, register_env # noqa
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# CLI parser with the standard RLlib example-script arguments.
parser = add_rllib_example_script_args(
    default_reward=0.9, default_iters=50, default_timesteps=100000
)
parser.set_defaults(
    enable_new_api_stack=True,
    num_agents=2,
)
# Optional extension of the game to five actions (rock/paper/scissors/
# lizard/Spock); forwarded to the env via `env_config` below.
parser.add_argument(
    "--sheldon-cooper-mode",
    action="store_true",
    help="Whether to add two more actions to the game: Lizard and Spock. "
    "Watch here for more details :) https://www.youtube.com/watch?v=x5Q6-wMx-K8",
)


if __name__ == "__main__":
    args = parser.parse_args()

    # Rock-paper-scissors is strictly a two-player game.
    assert args.num_agents == 2, "Must set --num-agents=2 when running this script!"

    # You can also register the env creator function explicitly with:
    # register_env("env", lambda cfg: RockPaperScissors({"sheldon_cooper_mode": False}))

    # Or you can hard code certain settings into the Env's constructor (`config`).
    # register_env(
    #     "rock-paper-scissors-w-sheldon-mode-activated",
    #     lambda config: RockPaperScissors({**config, **{"sheldon_cooper_mode": True}}),
    # )

    # Or allow the RLlib user to set more c'tor options via their algo config:
    # config.environment(env_config={[c'tor arg name]: [value]})
    # register_env("rock-paper-scissors", lambda cfg: RockPaperScissors(cfg))

    base_config = (
        get_trainable_cls(args.algo)
        .get_default_config()
        .environment(
            RockPaperScissors,
            env_config={"sheldon_cooper_mode": args.sheldon_cooper_mode},
        )
        .env_runners(
            # Flatten the (discrete) observations before they reach the module.
            env_to_module_connector=lambda env: FlattenObservations(multi_agent=True),
        )
        .multi_agent(
            # Define two policies.
            policies={"player1", "player2"},
            # Map agent "player1" to policy "player1" and agent "player2" to policy
            # "player2".
            policy_mapping_fn=lambda agent_id, episode, **kw: agent_id,
        )
    )

    # Run via Tune (or directly, depending on the parsed CLI options).
    run_rllib_example_script_experiment(base_config, args)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/async_gym_env_vectorization.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Example demo'ing async gym vector envs, in which sub-envs have their own process.
|
| 2 |
+
|
| 3 |
+
Setting up env vectorization works through setting the `config.num_envs_per_env_runner`
|
| 4 |
+
value to > 1. However, by default the n sub-environments are stepped through
|
| 5 |
+
sequentially, rather than in parallel.
|
| 6 |
+
|
| 7 |
+
This script shows the effect of setting the `config.gym_env_vectorize_mode` from its
|
| 8 |
+
default value of "SYNC" (all sub envs are located in the same EnvRunner process)
|
| 9 |
+
to "ASYNC" (all sub envs in each EnvRunner get their own process).
|
| 10 |
+
|
| 11 |
+
This example:
|
| 12 |
+
- shows, which config settings to change in order to switch from sub-envs being
|
| 13 |
+
stepped in sequence to each sub-envs owning its own process (and compute resource)
|
| 14 |
+
and thus the vector being stepped in parallel.
|
| 15 |
+
- shows, how this setup can increase EnvRunner performance significantly, especially
|
| 16 |
+
for heavier, slower environments.
|
| 17 |
+
- uses an artificially slow CartPole-v1 environment for demonstration purposes.
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
How to run this script
|
| 21 |
+
----------------------
|
| 22 |
+
`python [script file name].py --enable-new-api-stack `
|
| 23 |
+
|
| 24 |
+
Use the `--vectorize-mode=BOTH` option to run both modes (SYNC and ASYNC)
|
| 25 |
+
through Tune at the same time and get a better comparison of the throughputs
|
| 26 |
+
achieved.
|
| 27 |
+
|
| 28 |
+
For debugging, use the following additional command line options
|
| 29 |
+
`--no-tune --num-env-runners=0`
|
| 30 |
+
which should allow you to set breakpoints anywhere in the RLlib code and
|
| 31 |
+
have the execution stop there for inspection and debugging.
|
| 32 |
+
|
| 33 |
+
For logging to your WandB account, use:
|
| 34 |
+
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
|
| 35 |
+
--wandb-run-name=[optional: WandB run name (within the defined project)]`
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
Results to expect
|
| 39 |
+
-----------------
|
| 40 |
+
You should see results similar to the following in your console output
|
| 41 |
+
when using the
|
| 42 |
+
|
| 43 |
+
+--------------------------+------------+------------------------+------+
|
| 44 |
+
| Trial name | status | gym_env_vectorize_mode | iter |
|
| 45 |
+
| | | | |
|
| 46 |
+
|--------------------------+------------+------------------------+------+
|
| 47 |
+
| PPO_slow-env_6ddf4_00000 | TERMINATED | SYNC | 4 |
|
| 48 |
+
| PPO_slow-env_6ddf4_00001 | TERMINATED | ASYNC | 4 |
|
| 49 |
+
+--------------------------+------------+------------------------+------+
|
| 50 |
+
+------------------+----------------------+------------------------+
|
| 51 |
+
| total time (s) | episode_return_mean | num_env_steps_sample |
|
| 52 |
+
| | | d_lifetime |
|
| 53 |
+
|------------------+----------------------+------------------------+
|
| 54 |
+
| 60.8794 | 73.53 | 16040 |
|
| 55 |
+
| 19.1203 | 73.86 | 16037 |
|
| 56 |
+
+------------------+----------------------+------------------------+
|
| 57 |
+
|
| 58 |
+
You can see that the ASYNC mode, given that the env is sufficiently slow,
|
| 59 |
+
achieves much better results when using vectorization.
|
| 60 |
+
|
| 61 |
+
You should see no difference, however, when only using
|
| 62 |
+
`--num-envs-per-env-runner=1`.
|
| 63 |
+
"""
|
| 64 |
+
import time
|
| 65 |
+
|
| 66 |
+
import gymnasium as gym
|
| 67 |
+
|
| 68 |
+
from ray.rllib.algorithms.ppo import PPOConfig
|
| 69 |
+
from ray.rllib.utils.test_utils import (
|
| 70 |
+
add_rllib_example_script_args,
|
| 71 |
+
run_rllib_example_script_experiment,
|
| 72 |
+
)
|
| 73 |
+
from ray import tune
|
| 74 |
+
|
| 75 |
+
# CLI parser; defaults to 6 sub-envs per EnvRunner so the vectorization-mode
# difference is actually visible.
parser = add_rllib_example_script_args(default_reward=60.0)
parser.set_defaults(
    enable_new_api_stack=True,
    env="CartPole-v1",
    num_envs_per_env_runner=6,
)
parser.add_argument(
    "--vectorize-mode",
    type=str,
    default="ASYNC",
    help="The value `gym.envs.registration.VectorizeMode` to use for env "
    "vectorization. SYNC steps through all sub-envs in sequence. ASYNC (default) "
    "parallelizes sub-envs through multiprocessing and can speed up EnvRunners "
    "significantly. Use the special value `BOTH` to run both ASYNC and SYNC through a "
    "Tune grid-search.",
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class SlowEnv(gym.ObservationWrapper):
    """Wrapper that makes the underlying env artificially slow.

    Sleeps 5ms on every observation to emulate a heavyweight environment,
    which is what makes the SYNC-vs-ASYNC throughput difference measurable.
    """

    def observation(self, observation):
        # Simulate expensive per-step work, then pass the obs through unchanged.
        time.sleep(0.005)
        return observation
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
if __name__ == "__main__":
    args = parser.parse_args()

    # BOTH requires a Tune grid-search, which --no-tune disables.
    if args.no_tune and args.vectorize_mode == "BOTH":
        raise ValueError(
            "Can't run this script with both --no-tune and --vectorize-mode=BOTH!"
        )

    # Wrap the env with the slowness wrapper.
    def _env_creator(cfg):
        return SlowEnv(gym.make(args.env, **cfg))

    tune.register_env("slow-env", _env_creator)

    # NOTE(review): this repeats the BOTH/--no-tune check from above (only the
    # message differs); one of the two checks looks redundant.
    if args.vectorize_mode == "BOTH" and args.no_tune:
        raise ValueError(
            "`--vectorize-mode=BOTH` and `--no-tune` not allowed in combination!"
        )

    base_config = (
        PPOConfig()
        .environment("slow-env")
        .env_runners(
            # Either a grid-search over both modes or the single chosen one.
            gym_env_vectorize_mode=(
                tune.grid_search(["SYNC", "ASYNC"])
                if args.vectorize_mode == "BOTH"
                else args.vectorize_mode
            ),
        )
    )

    results = run_rllib_example_script_experiment(base_config, args)

    # Compare the throughputs and assert that ASYNC is much faster than SYNC.
    if args.vectorize_mode == "BOTH":
        throughput_sync = (
            results[0].metrics["num_env_steps_sampled_lifetime"]
            / results[0].metrics["time_total_s"]
        )
        throughput_async = (
            results[1].metrics["num_env_steps_sampled_lifetime"]
            / results[1].metrics["time_total_s"]
        )
        assert throughput_async > throughput_sync
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/action_mask_env.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from gymnasium.spaces import Box, Dict, Discrete
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from ray.rllib.examples.envs.classes.random_env import RandomEnv
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ActionMaskEnv(RandomEnv):
    """A randomly acting environment that publishes an action-mask each step.

    Observations are dicts with two keys:
    - "action_mask": a 0/1 vector of length `action_space.n` marking which
      actions are currently valid.
    - "observations": the underlying (random) observation.

    `step()` raises a ValueError when called with an action that the most
    recently published mask marked as invalid.
    """

    def __init__(self, config):
        super().__init__(config)
        # Masking only works for Discrete actions.
        assert isinstance(self.action_space, Discrete)
        # Wrap the original observation space into a Dict that also carries
        # the per-step action mask.
        self.observation_space = Dict(
            {
                "action_mask": Box(0.0, 1.0, shape=(self.action_space.n,)),
                "observations": self.observation_space,
            }
        )
        # 0/1 vector of currently valid actions; set on every reset/step.
        self.valid_actions = None

    def reset(self, *, seed=None, options=None):
        # Bug fix: forward `seed` and `options` to the base env (previously
        # they were accepted but silently dropped, so seeded resets were a
        # no-op).
        obs, info = super().reset(seed=seed, options=options)
        self._fix_action_mask(obs)
        return obs, info

    def step(self, action):
        # Check whether the action is valid under the last published mask.
        if not self.valid_actions[action]:
            raise ValueError(
                f"Invalid action ({action}) sent to env! "
                f"valid_actions={self.valid_actions}"
            )
        obs, rew, done, truncated, info = super().step(action)
        self._fix_action_mask(obs)
        return obs, rew, done, truncated, info

    def _fix_action_mask(self, obs):
        # Binarize the randomly sampled mask: everything larger 0.5 becomes
        # 1.0, everything else 0.0, and remember it for the validity check.
        self.valid_actions = np.round(obs["action_mask"])
        obs["action_mask"] = self.valid_actions
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_crashing.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from gymnasium.envs.classic_control import CartPoleEnv
|
| 3 |
+
import numpy as np
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
from ray.rllib.examples.envs.classes.multi_agent import make_multi_agent
|
| 7 |
+
from ray.rllib.utils.annotations import override
|
| 8 |
+
from ray.rllib.utils.error import EnvError
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class CartPoleCrashing(CartPoleEnv):
    """A CartPole env that crashes (or stalls) from time to time.

    Useful for testing faulty sub-env (within a vectorized env) handling by
    EnvRunners.

    After crashing, the env expects a `reset()` call next (calling `step()` will
    result in yet another error), which may or may not take a very long time to
    complete. This simulates the env having to reinitialize some sub-processes, e.g.
    an external connection.

    The env can also be configured to stall (and do nothing during a call to `step()`)
    from time to time for a configurable amount of time.
    """

    def __init__(self, config=None):
        super().__init__()

        # Normalize `config` once and use the normalized object everywhere
        # below. Bug fix: previously the raw `config` argument was used after
        # this normalization, so passing `config=None` crashed with an
        # AttributeError on the first `config.get()` call.
        self.config = config = config if config is not None else {}

        # Crash probability (in each `step()`).
        self.p_crash = config.get("p_crash", 0.005)
        # Crash probability when `reset()` is called.
        self.p_crash_reset = config.get("p_crash_reset", 0.0)
        # Crash exactly after every n steps. If a 2-tuple, will uniformly sample
        # crash timesteps from in between the two given values.
        self.crash_after_n_steps = config.get("crash_after_n_steps")
        self._crash_after_n_steps = None
        assert (
            self.crash_after_n_steps is None
            or isinstance(self.crash_after_n_steps, int)
            or (
                isinstance(self.crash_after_n_steps, tuple)
                and len(self.crash_after_n_steps) == 2
            )
        )
        # Only ever crash, if on certain worker indices.
        # NOTE(review): `config.worker_index` assumes `config` is an RLlib
        # EnvContext; with a plain dict this attribute access is only reached
        # when "crash_on_worker_indices" was explicitly provided.
        faulty_indices = config.get("crash_on_worker_indices", None)
        if faulty_indices and config.worker_index not in faulty_indices:
            self.p_crash = 0.0
            self.p_crash_reset = 0.0
            self.crash_after_n_steps = None

        # Stall probability (in each `step()`).
        self.p_stall = config.get("p_stall", 0.0)
        # Stall probability when `reset()` is called.
        self.p_stall_reset = config.get("p_stall_reset", 0.0)
        # Stall exactly after every n steps.
        self.stall_after_n_steps = config.get("stall_after_n_steps")
        self._stall_after_n_steps = None
        # Amount of time to stall. If a 2-tuple, will uniformly sample from in between
        # the two given values.
        self.stall_time_sec = config.get("stall_time_sec")
        assert (
            self.stall_time_sec is None
            or isinstance(self.stall_time_sec, (int, float))
            or (
                isinstance(self.stall_time_sec, tuple) and len(self.stall_time_sec) == 2
            )
        )

        # Only ever stall, if on certain worker indices.
        faulty_indices = config.get("stall_on_worker_indices", None)
        if faulty_indices and config.worker_index not in faulty_indices:
            self.p_stall = 0.0
            self.p_stall_reset = 0.0
            self.stall_after_n_steps = None

        # Timestep counter for the ongoing episode.
        self.timesteps = 0

        # Time in seconds to initialize (in this c'tor).
        sample = 0.0
        if "init_time_s" in config:
            sample = (
                config["init_time_s"]
                if not isinstance(config["init_time_s"], tuple)
                else np.random.uniform(
                    config["init_time_s"][0], config["init_time_s"][1]
                )
            )

        print(f"Initializing crashing env (with init-delay of {sample}sec) ...")
        time.sleep(sample)

        # Unseeded, per-instance RNG, so sub-envs don't crash at the same time.
        self._rng = np.random.RandomState()

    @override(CartPoleEnv)
    def reset(self, *, seed=None, options=None):
        # Reset timestep counter for the new episode.
        self.timesteps = 0
        self._crash_after_n_steps = None

        # Should we crash?
        if self._should_crash(p=self.p_crash_reset):
            raise EnvError(
                f"Simulated env crash on worker={self.config.worker_index} "
                f"env-idx={self.config.vector_index} during `reset()`! "
                "Feel free to use any other exception type here instead."
            )
        # Should we stall for a while?
        self._stall_if_necessary(p=self.p_stall_reset)

        # Bug fix: forward `seed`/`options` to the base env (previously they
        # were accepted but dropped, so seeded resets were a no-op).
        return super().reset(seed=seed, options=options)

    @override(CartPoleEnv)
    def step(self, action):
        # Increase timestep counter for the ongoing episode.
        self.timesteps += 1

        # Should we crash?
        if self._should_crash(p=self.p_crash):
            raise EnvError(
                f"Simulated env crash on worker={self.config.worker_index} "
                f"env-idx={self.config.vector_index} during `step()`! "
                "Feel free to use any other exception type here instead."
            )
        # Should we stall for a while?
        self._stall_if_necessary(p=self.p_stall)

        return super().step(action)

    def _should_crash(self, p):
        """Returns True if the env should crash now (probability or step-based)."""
        rnd = self._rng.rand()
        if rnd < p:
            print("Crashing due to p(crash)!")
            return True
        elif self.crash_after_n_steps is not None:
            # Lazily resolve a concrete crash timestep from the configured
            # int or (low, high) tuple.
            if self._crash_after_n_steps is None:
                self._crash_after_n_steps = (
                    self.crash_after_n_steps
                    if not isinstance(self.crash_after_n_steps, tuple)
                    else np.random.randint(
                        self.crash_after_n_steps[0], self.crash_after_n_steps[1]
                    )
                )
            if self._crash_after_n_steps == self.timesteps:
                print("Crashing due to n timesteps reached!")
                return True

        return False

    def _stall_if_necessary(self, p):
        """Sleeps for `self.stall_time_sec` if a stall is due (probability or
        step-based)."""
        stall = False
        if self._rng.rand() < p:
            stall = True
        elif self.stall_after_n_steps is not None:
            # Lazily resolve a concrete stall timestep from the configured
            # int or (low, high) tuple.
            if self._stall_after_n_steps is None:
                self._stall_after_n_steps = (
                    self.stall_after_n_steps
                    if not isinstance(self.stall_after_n_steps, tuple)
                    else np.random.randint(
                        self.stall_after_n_steps[0], self.stall_after_n_steps[1]
                    )
                )
            if self._stall_after_n_steps == self.timesteps:
                stall = True

        if stall:
            # NOTE(review): if stalling is enabled but `stall_time_sec` was
            # left None, `time.sleep(None)` raises a TypeError — confirm
            # callers always configure `stall_time_sec` alongside stalling.
            sec = (
                self.stall_time_sec
                if not isinstance(self.stall_time_sec, tuple)
                else np.random.uniform(self.stall_time_sec[0], self.stall_time_sec[1])
            )
            print(f" -> will stall for {sec}sec ...")
            time.sleep(sec)


MultiAgentCartPoleCrashing = make_multi_agent(lambda config: CartPoleCrashing(config))
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_sparse_rewards.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import deepcopy
|
| 2 |
+
|
| 3 |
+
import gymnasium as gym
|
| 4 |
+
import numpy as np
|
| 5 |
+
from gymnasium.spaces import Box, Dict, Discrete
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class CartPoleSparseRewards(gym.Env):
    """Wrapper for gym CartPole environment where reward is accumulated to the end.

    Per-step rewards are summed internally and only paid out (as one lump sum)
    on the step in which the wrapped episode terminates; all other steps return
    a reward of 0. Observations are dicts with the raw CartPole observation
    under "obs" and an always-all-ones "action_mask".
    """

    def __init__(self, config=None):
        self.env = gym.make("CartPole-v1")
        self.action_space = Discrete(2)
        self.observation_space = Dict(
            {
                "obs": self.env.observation_space,
                "action_mask": Box(
                    low=0, high=1, shape=(self.action_space.n,), dtype=np.int8
                ),
            }
        )
        # Running sum of rewards for the current episode.
        self.running_reward = 0

    def reset(self, *, seed=None, options=None):
        self.running_reward = 0
        # Bug fix: forward `seed`/`options` to the wrapped env (previously
        # they were accepted but dropped, so seeded resets were a no-op).
        obs, infos = self.env.reset(seed=seed, options=options)
        return {
            "obs": obs,
            "action_mask": np.array([1, 1], dtype=np.int8),
        }, infos

    def step(self, action):
        obs, rew, terminated, truncated, info = self.env.step(action)
        self.running_reward += rew
        # Pay out the accumulated reward only on termination; 0 otherwise.
        # Note: a truncated (time-limited) episode pays out nothing.
        score = self.running_reward if terminated else 0
        return (
            {"obs": obs, "action_mask": np.array([1, 1], dtype=np.int8)},
            score,
            terminated,
            truncated,
            info,
        )

    def set_state(self, state):
        """Restores env and running reward from a `get_state()` snapshot."""
        self.running_reward = state[1]
        self.env = deepcopy(state[0])
        obs = np.array(list(self.env.unwrapped.state))
        return {"obs": obs, "action_mask": np.array([1, 1], dtype=np.int8)}

    def get_state(self):
        """Returns a deep-copied snapshot of (wrapped env, running reward)."""
        return deepcopy(self.env), self.running_reward
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_with_dict_observation_space.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
from gymnasium.envs.classic_control import CartPoleEnv
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class CartPoleWithDictObservationSpace(CartPoleEnv):
    """CartPole gym environment that has a dict observation space.

    However, otherwise, the information content in each observation remains the same.

    https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/classic_control/cartpole.py # noqa

    The new observation space looks as follows (a little quirky, but this is
    for testing purposes only):

    gym.spaces.Dict({
        "x-pos": [x-pos],
        "angular-pos": gym.spaces.Dict({"test": [angular-pos]}),
        "velocs": gym.spaces.Tuple([x-veloc, angular-veloc]),
    })
    """

    def __init__(self, config=None):
        super().__init__()

        # Bounds of the original (flat) CartPole observation space.
        lows = self.observation_space.low
        highs = self.observation_space.high

        # Deliberately quirky space for testing: a Dict inside a Dict, a
        # Tuple inside a Dict, both (1,)- and ()-shaped Boxes, plus a random
        # (non-essential) Discrete.
        self.observation_space = gym.spaces.Dict(
            {
                "x-pos": gym.spaces.Box(lows[0], highs[0], (1,), dtype=np.float32),
                "angular-pos": gym.spaces.Dict(
                    {
                        "value": gym.spaces.Box(
                            lows[2], highs[2], (), dtype=np.float32
                        ),
                        # Some random, non-essential information.
                        "some_random_stuff": gym.spaces.Discrete(3),
                    }
                ),
                "velocs": gym.spaces.Tuple(
                    [
                        # x-veloc
                        gym.spaces.Box(lows[1], highs[1], (1,), dtype=np.float32),
                        # angular-veloc
                        gym.spaces.Box(lows[3], highs[3], (), dtype=np.float32),
                    ]
                ),
            }
        )

    def reset(self, *, seed=None, options=None):
        flat_obs, info = super().reset(seed=seed, options=options)
        return self._compile_current_obs(flat_obs), info

    def step(self, action):
        flat_obs, reward, done, truncated, info = super().step(action)
        return self._compile_current_obs(flat_obs), reward, done, truncated, info

    def _compile_current_obs(self, original_cartpole_obs):
        """Converts a flat [x-pos, x-veloc, angle, angle-veloc] obs to a dict."""
        x_pos, x_veloc, angle, angle_veloc = original_cartpole_obs
        return {
            "x-pos": np.array([x_pos], np.float32),
            "angular-pos": {
                "value": angle,
                "some_random_stuff": np.random.randint(3),
            },
            "velocs": (
                np.array([x_veloc], np.float32),
                np.array(angle_veloc, np.float32),
            ),
        }
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_with_large_observation_space.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
from gymnasium.envs.classic_control import CartPoleEnv
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class CartPoleWithLargeObservationSpace(CartPoleEnv):
    """CartPole gym environment that has a large dict observation space.

    However, otherwise, the information content in each observation remains the same.

    https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/classic_control/cartpole.py # noqa

    The new observation space looks as follows (a little quirky, but this is
    for testing purposes only):

    gym.spaces.Dict({
        "1": gym.spaces.Tuple((
            gym.spaces.Discrete(100),
            gym.spaces.Box(0, 256, shape=(30,), dtype=float32),
        )),
        "2": gym.spaces.Tuple((
            gym.spaces.Discrete(100),
            gym.spaces.Box(0, 256, shape=(30,), dtype=float32),
        )),
        "3": ...
        "actual-obs": gym.spaces.Box(-inf, inf, (4,), float32),
    })
    """

    def __init__(self, config=None):
        super().__init__()

        # Bounds of the original (flat) CartPole observation space.
        low = self.observation_space.low
        high = self.observation_space.high

        # Build 100 filler sub-spaces ("0".."99"), each a (Discrete, Box)
        # tuple of pure noise, plus the actual 4D CartPole Box under
        # "actually-useful-stuff".
        space_dict = {}
        for idx in range(100):
            space_dict[str(idx)] = gym.spaces.Tuple(
                (
                    gym.spaces.Discrete(100),
                    gym.spaces.Box(0, 256, shape=(30,), dtype=np.float32),
                )
            )
        space_dict["actually-useful-stuff"] = gym.spaces.Box(
            low[0], high[0], (4,), np.float32
        )
        self.observation_space = gym.spaces.Dict(space_dict)

    def reset(self, *, seed=None, options=None):
        flat_obs, info = super().reset(seed=seed, options=options)
        return self._compile_current_obs(flat_obs), info

    def step(self, action):
        flat_obs, reward, done, truncated, info = super().step(action)
        return self._compile_current_obs(flat_obs), reward, done, truncated, info

    def _compile_current_obs(self, original_cartpole_obs):
        """Pads the real obs with freshly sampled noise for the filler keys."""
        padded = {
            str(idx): self.observation_space.spaces[str(idx)].sample()
            for idx in range(100)
        }
        padded["actually-useful-stuff"] = original_cartpole_obs
        return padded
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_with_protobuf_observation_space.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
from gymnasium.envs.classic_control import CartPoleEnv
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from ray.rllib.examples.envs.classes.utils.cartpole_observations_proto import (
|
| 6 |
+
CartPoleObservation,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CartPoleWithProtobufObservationSpace(CartPoleEnv):
    """CartPole gym environment that has a protobuf observation space.

    Sometimes, it is more performant for an environment to publish its observations
    as a protobuf message (instead of a heavily nested Dict).

    The protobuf message used here is originally defined in the
    `./utils/cartpole_observations.proto` file. We converted this file into a python
    importable module by compiling it with:

    `protoc --python_out=. cartpole_observations.proto`

    .. which yielded the `cartpole_observations_proto.py` file in the same directory
    (we import this file's `CartPoleObservation` message here).

    The new observation space is a (binary) Box(0, 255, ([len of protobuf],), uint8).

    A ConnectorV2 pipeline or simpler gym.Wrapper will have to be used to convert this
    observation format into an NN-readable (e.g. float32) 1D tensor.
    """

    def __init__(self, config=None):
        super().__init__()
        # Serialize a dummy observation once to learn the (fixed) byte length
        # of the protobuf message; that length defines the binary Box space.
        probe = self._convert_observation_to_protobuf(
            np.array([1.0, 1.0, 1.0, 1.0])
        )
        self.observation_space = gym.spaces.Box(0, 255, (len(probe),), np.uint8)

    def step(self, action):
        raw_obs, reward, terminated, truncated, info = super().step(action)
        return (
            self._convert_observation_to_protobuf(raw_obs),
            reward,
            terminated,
            truncated,
            info,
        )

    def reset(self, **kwargs):
        raw_obs, info = super().reset(**kwargs)
        return self._convert_observation_to_protobuf(raw_obs), info

    def _convert_observation_to_protobuf(self, observation):
        """Serializes a raw 4D CartPole observation into a uint8 ndarray."""
        x_pos, x_veloc, angle_pos, angle_veloc = observation

        # Fill the protobuf message field by field.
        msg = CartPoleObservation()
        msg.x_pos = x_pos
        msg.x_veloc = x_veloc
        msg.angle_pos = angle_pos
        msg.angle_veloc = angle_veloc

        # Serialize to a binary string and view it as a uint8 array.
        return np.frombuffer(msg.SerializeToString(), np.uint8)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
if __name__ == "__main__":
    env = CartPoleWithProtobufObservationSpace()
    obs, info = env.reset()

    # Round-trip check: load a protobuf object with data from the obs binary
    # string (uint8 ndarray) and print it.
    decoded = CartPoleObservation()
    decoded.ParseFromString(obs.tobytes())
    print(decoded)

    # Roll out one episode with randomly sampled actions.
    terminated = truncated = False
    while not (terminated or truncated):
        obs, reward, terminated, truncated, info = env.step(
            env.action_space.sample()
        )

    print(obs)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cliff_walking_wall_env.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
from gymnasium import spaces
|
| 3 |
+
|
| 4 |
+
ACTION_UP = 0
|
| 5 |
+
ACTION_RIGHT = 1
|
| 6 |
+
ACTION_DOWN = 2
|
| 7 |
+
ACTION_LEFT = 3
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CliffWalkingWallEnv(gym.Env):
    """Modified version of the CliffWalking environment from Farama-Foundation's
    Gymnasium with walls instead of a cliff.

    ### Description
    The board is a 4x12 matrix, with (using NumPy matrix indexing):
    - [3, 0] or obs==36 as the start at bottom-left
    - [3, 11] or obs==47 as the goal at bottom-right
    - [3, 1..10] or obs==37...46 as the cliff at bottom-center

    An episode terminates when the agent reaches the goal.

    ### Actions
    There are 4 discrete deterministic actions:
    - 0: move up
    - 1: move right
    - 2: move down
    - 3: move left
    You can also use the constants ACTION_UP, ACTION_RIGHT, ... defined above.

    ### Observations
    There are 3x12 + 2 possible states, not including the walls. If an action
    would move an agent into one of the walls, it simply stays in the same position.

    ### Reward
    Each time step incurs -1 reward, except reaching the goal which gives +10 reward.
    """

    def __init__(self, seed=42) -> None:
        self.observation_space = spaces.Discrete(48)
        self.action_space = spaces.Discrete(4)
        # Seed the spaces so `sample()` calls are reproducible.
        self.observation_space.seed(seed)
        self.action_space.seed(seed)

    def reset(self, *, seed=None, options=None):
        # Episodes always begin in the bottom-left start cell (obs 36).
        self.position = 36
        return self.position, {}

    def step(self, action):
        # Decode the flat position into (row, col) on the 4x12 grid.
        row, col = divmod(self.position, 12)
        if action == ACTION_UP:
            row = max(row - 1, 0)
        elif action == ACTION_RIGHT:
            # Moving right from the start cell is blocked by the wall.
            if self.position != 36:
                col = min(col + 1, 11)
        elif action == ACTION_DOWN:
            # Moving down from cells 25..34 (directly above the wall row) is
            # blocked.
            if not (25 <= self.position <= 34):
                row = min(row + 1, 3)
        elif action == ACTION_LEFT:
            # Moving left from the goal cell is blocked by the wall.
            if self.position != 47:
                col = max(col - 1, 0)
        else:
            raise ValueError(f"action {action} not in {self.action_space}")
        self.position = row * 12 + col
        done = self.position == 47
        return self.position, (10 if done else -1), done, False, {}
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/correlated_actions_env.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Optional
|
| 2 |
+
|
| 3 |
+
import gymnasium as gym
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class CorrelatedActionsEnv(gym.Env):
    """Environment that can only be solved through an autoregressive action model.

    In each step, the agent observes a random number (between -1 and 1) and has
    to choose two actions, a1 (discrete, 0, 1, or 2) and a2 (cont. between -1 and 1).

    The reward is constructed such that actions need to be correlated to succeed. It's
    impossible for the network to learn each action head separately.

    There are two reward components:
    The first is the negative absolute value of the delta between 1.0 and the sum of
    obs + a1. For example, if obs is -0.3 and a1 was sampled to be 1, then the value of
    the first reward component is:
    r1 = -abs(1.0 - [obs+a1]) = -abs(1.0 - (-0.3 + 1)) = -abs(0.3) = -0.3
    The second reward component is computed as the negative absolute value
    of `obs + a1 + a2`. For example, if obs is 0.5, a1 was sampled to be 0,
    and a2 was sampled to be -0.7, then the value of the second reward component is:
    r2 = -abs(obs + a1 + a2) = -abs(0.5 + 0 - 0.7)) = -abs(-0.2) = -0.2

    Because of this specific reward function, the agent must learn to optimally sample
    a1 based on the observation and to optimally sample a2, based on the observation
    AND the sampled value of a1.

    One way to effectively learn this is through correlated action
    distributions, e.g., in examples/actions/auto_regressive_actions.py

    The game ends after the first step.
    """

    def __init__(self, config=None):
        super().__init__()
        # Observation space (single continuous value between -1. and 1.).
        self.observation_space = gym.spaces.Box(-1.0, 1.0, shape=(1,), dtype=np.float32)

        # Action space (discrete action a1 and continuous action a2).
        self.action_space = gym.spaces.Tuple(
            [gym.spaces.Discrete(3), gym.spaces.Box(-2.0, 2.0, (1,), np.float32)]
        )

        # Current observation; set in `reset()`.
        self.obs = None

    def reset(
        self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
    ):
        """Reset the environment to an initial state."""
        super().reset(seed=seed, options=options)

        # Bug fixes: draw from the env's own seeded RNG (`self.np_random`,
        # initialized by `super().reset(seed=...)`) instead of the global
        # `np.random`, so seeded resets are actually reproducible; and cast to
        # float32 so the returned obs lies within `self.observation_space`
        # (which is declared float32).
        self.obs = self.np_random.uniform(-1, 1, size=(1,)).astype(np.float32)

        return self.obs, {}

    def step(self, action):
        """Apply the autoregressive action and return step information."""

        # Extract individual action components, a1 and a2.
        a1, a2 = action
        a2 = a2[0]  # dissolve shape=(1,)

        # r1 depends on how well a1 is aligned to obs:
        r1 = -abs(1.0 - (self.obs[0] + a1))
        # r2 depends on how well a2 is aligned to both, obs and a1.
        r2 = -abs(self.obs[0] + a1 + a2)

        reward = r1 + r2

        # Terminate after each step (no episode length in this simple example).
        return self.obs, reward, True, False, {}
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/d4rl_env.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
8 Environments from D4RL Environment.
|
| 3 |
+
Use fully qualified class-path in your configs:
|
| 4 |
+
e.g. "env": "ray.rllib.examples.envs.classes.d4rl_env.halfcheetah_random".
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import gymnasium as gym
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
import d4rl
|
| 11 |
+
|
| 12 |
+
d4rl.__name__ # Fool LINTer.
|
| 13 |
+
except ImportError:
|
| 14 |
+
d4rl = None
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def halfcheetah_random():
    """Creates the D4RL `halfcheetah-random-v0` environment."""
    return gym.make("halfcheetah-random-v0")


def halfcheetah_medium():
    """Creates the D4RL `halfcheetah-medium-v0` environment."""
    return gym.make("halfcheetah-medium-v0")


def halfcheetah_expert():
    """Creates the D4RL `halfcheetah-expert-v0` environment."""
    return gym.make("halfcheetah-expert-v0")


def halfcheetah_medium_replay():
    """Creates the D4RL `halfcheetah-medium-replay-v0` environment."""
    return gym.make("halfcheetah-medium-replay-v0")


def hopper_random():
    """Creates the D4RL `hopper-random-v0` environment."""
    return gym.make("hopper-random-v0")


def hopper_medium():
    """Creates the D4RL `hopper-medium-v0` environment."""
    return gym.make("hopper-medium-v0")


def hopper_expert():
    """Creates the D4RL `hopper-expert-v0` environment."""
    return gym.make("hopper-expert-v0")


def hopper_medium_replay():
    """Creates the D4RL `hopper-medium-replay-v0` environment."""
    return gym.make("hopper-medium-replay-v0")
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/debug_counter_env.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class DebugCounterEnv(gym.Env):
    """Simple Env that yields a timestep counter as observation (0-based).

    Actions have no effect. The episode length is always 15 and reward at
    each step is `current ts % 3`.
    """

    def __init__(self, config=None):
        cfg = config or {}
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(0, 100, (1,), dtype=np.float32)
        # Optional offset at which the counter (re)starts.
        self.start_at_t = int(cfg.get("start_at_t", 0))
        self.i = self.start_at_t

    def reset(self, *, seed=None, options=None):
        # Rewind the counter to its configured offset.
        self.i = self.start_at_t
        return self._get_obs(), {}

    def step(self, action):
        # The action is ignored; only the internal counter advances.
        self.i += 1
        # Episodes never terminate naturally; they truncate after 15 steps.
        truncated = self.i >= 15 + self.start_at_t
        return self._get_obs(), float(self.i % 3), False, truncated, {}

    def _get_obs(self):
        # Single-element float32 vector holding the current counter value.
        return np.array([self.i], dtype=np.float32)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class MultiAgentDebugCounterEnv(MultiAgentEnv):
    """Multi-agent counter env echoing agent ID, action, and timestep.

    Observation dims:
    0=agent ID.
    1=episode ID (0.0 for obs after reset).
    2=env ID (0.0 for obs after reset).
    3=ts (of the agent).
    """

    def __init__(self, config):
        super().__init__()
        self.num_agents = config["num_agents"]
        self.base_episode_len = config.get("base_episode_len", 103)

        agent_ids = list(range(self.num_agents))
        self.observation_space = gym.spaces.Dict(
            {
                aid: gym.spaces.Box(float("-inf"), float("inf"), (4,))
                for aid in agent_ids
            }
        )
        # Actions are always: (episodeID, envID) as floats.
        self.action_space = gym.spaces.Dict(
            {
                aid: gym.spaces.Box(-float("inf"), float("inf"), shape=(2,))
                for aid in agent_ids
            }
        )

        self.timesteps = [0] * self.num_agents
        self.terminateds = set()
        self.truncateds = set()

    def reset(self, *, seed=None, options=None):
        # Clear all per-agent bookkeeping for a fresh episode.
        self.timesteps = [0] * self.num_agents
        self.terminateds = set()
        self.truncateds = set()
        first_obs = {
            aid: np.array([aid, 0.0, 0.0, 0.0], dtype=np.float32)
            for aid in range(self.num_agents)
        }
        return first_obs, {}

    def step(self, action_dict):
        obs, rew, terminated, truncated = {}, {}, {}, {}
        for aid, action in action_dict.items():
            self.timesteps[aid] += 1
            ts = self.timesteps[aid]
            obs[aid] = np.array([aid, action[0], action[1], ts])
            rew[aid] = ts % 3
            # Agents never terminate; each truncates after its individual
            # horizon (base length + agent ID).
            terminated[aid] = False
            truncated[aid] = ts > self.base_episode_len + aid
            if terminated[aid]:
                self.terminateds.add(aid)
            if truncated[aid]:
                self.truncateds.add(aid)
        terminated["__all__"] = len(self.terminateds) == self.num_agents
        truncated["__all__"] = len(self.truncateds) == self.num_agents
        return obs, rew, terminated, truncated, {}
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/deterministic_envs.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def create_cartpole_deterministic(config):
    """Create a CartPole-v1 env seeded once via `config["seed"]` (default 0).

    NOTE(review): only this initial `reset(seed=...)` is seeded; later resets
    by the caller reuse the env's internal RNG state - confirm this is the
    intended notion of "deterministic" for callers.
    """
    env = gym.make("CartPole-v1")
    env.reset(seed=config.get("seed", 0))
    return env


def create_pendulum_deterministic(config):
    """Create a Pendulum-v1 env seeded once via `config["seed"]` (default 0).

    NOTE(review): only this initial `reset(seed=...)` is seeded; later resets
    by the caller reuse the env's internal RNG state.
    """
    env = gym.make("Pendulum-v1")
    env.reset(seed=config.get("seed", 0))
    return env
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/dm_control_suite.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.env.wrappers.dm_control_wrapper import DMCEnv
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
8 Environments from Deepmind Control Suite
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _make_dmc_env(domain, task, from_pixels, height, width, frame_skip, channels_first):
    """Shared factory: build a `DMCEnv` for the given domain/task pair.

    Extracted because every public factory below passed the exact same
    keyword arguments through to `DMCEnv`.
    """
    return DMCEnv(
        domain,
        task,
        from_pixels=from_pixels,
        height=height,
        width=width,
        frame_skip=frame_skip,
        channels_first=channels_first,
    )


def acrobot_swingup(
    from_pixels=True, height=64, width=64, frame_skip=2, channels_first=True
):
    """Acrobot 'swingup' task of the DeepMind Control Suite."""
    return _make_dmc_env(
        "acrobot", "swingup", from_pixels, height, width, frame_skip, channels_first
    )


def walker_walk(
    from_pixels=True, height=64, width=64, frame_skip=2, channels_first=True
):
    """Walker 'walk' task of the DeepMind Control Suite."""
    return _make_dmc_env(
        "walker", "walk", from_pixels, height, width, frame_skip, channels_first
    )


def hopper_hop(
    from_pixels=True, height=64, width=64, frame_skip=2, channels_first=True
):
    """Hopper 'hop' task of the DeepMind Control Suite."""
    return _make_dmc_env(
        "hopper", "hop", from_pixels, height, width, frame_skip, channels_first
    )


def hopper_stand(
    from_pixels=True, height=64, width=64, frame_skip=2, channels_first=True
):
    """Hopper 'stand' task of the DeepMind Control Suite."""
    return _make_dmc_env(
        "hopper", "stand", from_pixels, height, width, frame_skip, channels_first
    )


def cheetah_run(
    from_pixels=True, height=64, width=64, frame_skip=2, channels_first=True
):
    """Cheetah 'run' task of the DeepMind Control Suite."""
    return _make_dmc_env(
        "cheetah", "run", from_pixels, height, width, frame_skip, channels_first
    )


def walker_run(
    from_pixels=True, height=64, width=64, frame_skip=2, channels_first=True
):
    """Walker 'run' task of the DeepMind Control Suite."""
    return _make_dmc_env(
        "walker", "run", from_pixels, height, width, frame_skip, channels_first
    )


def pendulum_swingup(
    from_pixels=True, height=64, width=64, frame_skip=2, channels_first=True
):
    """Pendulum 'swingup' task of the DeepMind Control Suite."""
    return _make_dmc_env(
        "pendulum", "swingup", from_pixels, height, width, frame_skip, channels_first
    )


def cartpole_swingup(
    from_pixels=True, height=64, width=64, frame_skip=2, channels_first=True
):
    """Cartpole 'swingup' task of the DeepMind Control Suite."""
    return _make_dmc_env(
        "cartpole", "swingup", from_pixels, height, width, frame_skip, channels_first
    )


def humanoid_walk(
    from_pixels=True, height=64, width=64, frame_skip=2, channels_first=True
):
    """Humanoid 'walk' task of the DeepMind Control Suite."""
    return _make_dmc_env(
        "humanoid", "walk", from_pixels, height, width, frame_skip, channels_first
    )
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/env_using_remote_actor.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Example of an environment that uses a named remote actor as parameter
|
| 3 |
+
server.
|
| 4 |
+
|
| 5 |
+
"""
|
| 6 |
+
from gymnasium.envs.classic_control.cartpole import CartPoleEnv
|
| 7 |
+
from gymnasium.utils import seeding
|
| 8 |
+
|
| 9 |
+
import ray
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@ray.remote
class ParameterStorage:
    """Named remote actor acting as a central parameter server for envs."""

    def get_params(self, rng):
        """Sample a fresh set of env parameters from the caller-provided RNG.

        Using the caller's RNG (instead of actor-local state) keeps sampling
        reproducible per env and avoids RNG clashes between envs sharing
        this actor.
        """
        return {
            "MASSCART": rng.uniform(low=0.5, high=2.0),
        }
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class CartPoleWithRemoteParamServer(CartPoleEnv):
    """CartPoleMassEnv varies the weights of the cart and the pole.

    On every `reset()`, the cart mass is drawn from a named remote
    `ParameterStorage` actor (looked up via `env_config["param_server"]`).
    """

    def __init__(self, env_config):
        # `env_config` is expected to provide `worker_index` and (optionally)
        # the name of the param-server actor - TODO confirm against callers.
        self.env_config = env_config
        super().__init__()
        # Get our param server (remote actor) by name.
        self._handler = ray.get_actor(env_config.get("param_server", "param-server"))
        self.rng_seed = None
        self.np_random, _ = seeding.np_random(self.rng_seed)

    def reset(self, *, seed=None, options=None):
        # An explicit seed re-seeds our local RNG and is remembered so
        # subsequent resets stay deterministic.
        if seed is not None:
            self.rng_seed = int(seed)
            self.np_random, _ = seeding.np_random(seed)
            print(
                f"Seeding env (worker={self.env_config.worker_index}) " f"with {seed}"
            )

        # Pass in our RNG to guarantee no race conditions.
        # If `self._handler` had its own RNG, this may clash with other
        # envs trying to use the same param-server.
        params = ray.get(self._handler.get_params.remote(self.np_random))

        # IMPORTANT: Advance the state of our RNG (self._rng was passed
        # above via ray (serialized) and thus not altered locally here!).
        # Or create a new RNG from another random number:
        # Seed the RNG with a deterministic seed if set, otherwise, create
        # a random one.
        new_seed = int(
            self.np_random.integers(0, 1000000) if not self.rng_seed else self.rng_seed
        )
        self.np_random, _ = seeding.np_random(new_seed)

        print(
            f"Env worker-idx={self.env_config.worker_index} "
            f"mass={params['MASSCART']}"
        )

        # Apply the sampled cart mass and refresh CartPole's derived
        # quantities that depend on it.
        self.masscart = params["MASSCART"]
        self.total_mass = self.masspole + self.masscart
        self.polemass_length = self.masspole * self.length

        # NOTE(review): the seed is deliberately NOT forwarded here - the
        # parent reset would otherwise replace our freshly advanced RNG.
        return super().reset()
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/env_with_subprocess.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import atexit
|
| 2 |
+
import gymnasium as gym
|
| 3 |
+
from gymnasium.spaces import Discrete
|
| 4 |
+
import os
|
| 5 |
+
import subprocess
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class EnvWithSubprocess(gym.Env):
    """An env that spawns a subprocess.

    Test fixture: observations/rewards are trivial; the point is the
    subprocess and the tmp files whose cleanup is verified externally.
    """

    # Dummy command to run as a subprocess with a unique name
    UNIQUE_CMD = "sleep 20"

    def __init__(self, config):
        # Four externally created tmp files; which ones get cleaned up
        # depends on the worker index (see below and `close()`).
        self.UNIQUE_FILE_0 = config["tmp_file1"]
        self.UNIQUE_FILE_1 = config["tmp_file2"]
        self.UNIQUE_FILE_2 = config["tmp_file3"]
        self.UNIQUE_FILE_3 = config["tmp_file4"]

        self.action_space = Discrete(2)
        self.observation_space = Discrete(2)
        # Subprocess that should be cleaned up.
        self.subproc = subprocess.Popen(self.UNIQUE_CMD.split(" "), shell=False)
        self.config = config
        # Exit handler should be called.
        # NOTE(review): these lambdas capture `self`, keeping the env alive
        # until interpreter exit - intentional for this test fixture.
        atexit.register(lambda: self.subproc.kill())
        if config.worker_index == 0:
            atexit.register(lambda: os.unlink(self.UNIQUE_FILE_0))
        else:
            atexit.register(lambda: os.unlink(self.UNIQUE_FILE_1))

    def close(self):
        # Remove the worker-specific file; raises FileNotFoundError if it
        # was already removed (callers are expected to close only once).
        if self.config.worker_index == 0:
            os.unlink(self.UNIQUE_FILE_2)
        else:
            os.unlink(self.UNIQUE_FILE_3)

    def reset(self, *, seed=None, options=None):
        # Constant observation; no state to reset.
        return 0, {}

    def step(self, action):
        # One-step episodes: terminate immediately with zero reward.
        return 0, 0, True, False, {}
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/fast_image_env.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
from gymnasium.spaces import Box, Discrete
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class FastImageEnv(gym.Env):
    """Trivial image-observation env for measuring sampling throughput.

    Always returns the same all-zero 84x84x4 image, reward 1 per step;
    episodes truncate after 1000 steps. Actions are ignored.
    """

    def __init__(self, config):
        # Bug fix: allocate the constant observation with the same dtype as
        # the declared observation space. `np.zeros` defaults to float64,
        # which does not match the float32 `Box` (and fails gymnasium's
        # `Box.contains` dtype check).
        self.zeros = np.zeros((84, 84, 4), dtype=np.float32)
        self.action_space = Discrete(2)
        self.observation_space = Box(0.0, 1.0, shape=(84, 84, 4), dtype=np.float32)
        # Step counter within the current episode.
        self.i = 0

    def reset(self, *, seed=None, options=None):
        self.i = 0
        return self.zeros, {}

    def step(self, action):
        self.i += 1
        # Terminate and truncate together once the step limit is exceeded.
        done = truncated = self.i > 1000
        return self.zeros, 1, done, truncated, {}
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/gpu_requiring_env.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
import ray
|
| 4 |
+
from ray.rllib.examples.envs.classes.simple_corridor import SimpleCorridor
|
| 5 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 6 |
+
|
| 7 |
+
torch, _ = try_import_torch()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GPURequiringEnv(SimpleCorridor):
    """A dummy env that requires a GPU in order to work.

    The env here is a simple corridor env that additionally simulates a GPU
    check in its constructor via `ray.get_gpu_ids()`. If this returns an
    empty list, we raise an error.

    To make this env work, use `num_gpus_per_env_runner > 0` (RolloutWorkers
    requesting this many GPUs each) and - maybe - `num_gpus > 0` in case
    your local worker/driver must have an env as well. However, this is
    only the case if `create_env_on_driver`=True (default is False).
    """

    def __init__(self, config=None):
        super().__init__(config)

        # Fake-require some GPUs (at least one).
        # If your local worker's env (`create_env_on_driver`=True) does not
        # necessarily require a GPU, you can perform the below assertion only
        # if `config.worker_index != 0`.
        gpus_available = ray.get_gpu_ids()
        print(f"{type(self).__name__} can see GPUs={gpus_available}")

        # Create a dummy tensor on the GPU.
        # NOTE(review): despite the class docstring, no error is actually
        # raised when no GPU is visible - the allocation is simply skipped.
        if len(gpus_available) > 0 and torch:
            self._tensor = torch.from_numpy(np.random.random_sample(size=(42, 42))).to(
                f"cuda:{gpus_available[0]}"
            )
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/look_and_push.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class LookAndPush(gym.Env):
    """Memory-requiring Env: Best sequence of actions depends on prev. states.

    Optimal behavior:
    0) a=0 -> observe next state (s'), which is the "hidden" state.
       If a=1 here, the hidden state is not observed.
    1) a=1 to always jump to s=2 (not matter what the prev. state was).
    2) a=1 to move to s=3.
    3) a=1 to move to s=4.
    4) a=0 OR 1 depending on s' observed after 0): +10 reward and done.
       otherwise: -10 reward and done.
    """

    def __init__(self):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Discrete(5)
        # Current (observable) state in [0, 4].
        self._state = None
        # Hidden binary case deciding which final action pays off.
        self._case = None

    def reset(self, *, seed=None, options=None):
        # NOTE(review): `seed` is ignored; the hidden case is drawn from the
        # global numpy RNG - confirm whether per-env seeding is needed.
        self._state = 2
        self._case = np.random.choice(2)
        return self._state, {}

    def step(self, action):
        assert self.action_space.contains(action)

        if self._state == 4:
            # Terminal decision step: payoff depends on the hidden case.
            # Bug fix: return the full gymnasium 5-tuple
            # (obs, reward, terminated, truncated, info) - the `truncated`
            # flag was missing in these two branches.
            if action and self._case:
                return self._state, 10.0, True, False, {}
            else:
                return self._state, -10.0, True, False, {}
        else:
            if action:
                # "Push": from s=0 jump back to s=2; otherwise advance by 1.
                if self._state == 0:
                    self._state = 2
                else:
                    self._state += 1
            elif self._state == 2:
                # "Look" at s=2: reveal the hidden case as the next state.
                self._state = self._case

            return self._state, -1, False, False, {}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class OneHot(gym.Wrapper):
    """Wrapper that one-hot encodes the wrapped env's discrete observations."""

    def __init__(self, env):
        super(OneHot, self).__init__(env)
        # Replace the Discrete(n) space with an n-dim Box in [0, 1].
        self.observation_space = gym.spaces.Box(0.0, 1.0, (env.observation_space.n,))

    def reset(self, *, seed=None, options=None):
        obs, info = self.env.reset(seed=seed, options=options)
        return self._encode_obs(obs), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self._encode_obs(obs), reward, terminated, truncated, info

    def _encode_obs(self, obs):
        # Bug fix: start from zeros so the result is a true one-hot vector;
        # starting from `np.ones` made every component 1.0 regardless of
        # which observation was being encoded.
        new_obs = np.zeros(self.env.observation_space.n)
        new_obs[obs] = 1.0
        return new_obs
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/memory_leaking_env.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import uuid
|
| 3 |
+
|
| 4 |
+
from ray.rllib.examples.envs.classes.random_env import RandomEnv
|
| 5 |
+
from ray.rllib.utils.annotations import override
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class MemoryLeakingEnv(RandomEnv):
    """An env that leaks very little memory.

    Useful for proving that our memory-leak tests can catch the
    slightest leaks.
    """

    def __init__(self, config=None):
        super().__init__(config)
        # Dict that grows by one entry per episode and is never cleared ->
        # this is the deliberate "leak".
        self._leak = {}
        # Steps taken since the last `reset()`.
        self._steps_after_reset = 0

    @override(RandomEnv)
    def reset(self, *, seed=None, options=None):
        self._steps_after_reset = 0
        return super().reset(seed=seed, options=options)

    @override(RandomEnv)
    def step(self, action):
        self._steps_after_reset += 1

        # Only leak once an episode.
        # A random hex key guarantees each leaked entry is unique (never
        # overwritten), so memory grows monotonically with episode count.
        if self._steps_after_reset == 2:
            self._leak[uuid.uuid4().hex.upper()] = 1

        return super().step(action)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/mock_env.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
from ray.rllib.env.vector_env import VectorEnv
|
| 6 |
+
from ray.rllib.utils.annotations import override
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class MockEnv(gym.Env):
    """Mock environment for testing purposes.

    Observation=0, reward=1.0, episode-len is configurable.
    Actions are ignored.
    """

    def __init__(self, episode_length, config=None):
        self.episode_length = episode_length
        self.config = config
        self.i = 0
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self, *, seed=None, options=None):
        # Rewind the internal step counter; obs is always 0.
        self.i = 0
        return 0, {}

    def step(self, action):
        # Ignore the action; end (terminate AND truncate) once the
        # configured episode length is reached.
        self.i += 1
        done = self.i >= self.episode_length
        return 0, 1.0, done, done, {}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class MockEnv2(gym.Env):
    """Mock environment for testing purposes.

    Observation=ts (discrete space!), reward=100.0, episode-len is
    configurable. Actions are ignored.
    """

    metadata = {
        "render.modes": ["rgb_array"],
    }
    render_mode: Optional[str] = "rgb_array"

    def __init__(self, episode_length):
        self.episode_length = episode_length
        self.i = 0
        self.observation_space = gym.spaces.Discrete(self.episode_length + 1)
        self.action_space = gym.spaces.Discrete(2)
        self.rng_seed = None

    def reset(self, *, seed=None, options=None):
        self.i = 0
        # Remember the last explicit seed (no RNG is actually consumed).
        if seed is not None:
            self.rng_seed = seed
        return self.i, {}

    def step(self, action):
        # Ignore the action; observation is simply the step counter.
        self.i += 1
        done = self.i >= self.episode_length
        return self.i, 100.0, done, done, {}

    def render(self):
        # Just a random RGB frame for demonstration purposes; see
        # `gym/envs/classic_control/cartpole.py` for a real Viewer example.
        return np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class MockEnv3(gym.Env):
    """Mock environment for testing purposes.

    Observation=ts (discrete space!), reward=ts, episode-len is
    configurable. Actions are ignored. Infos carry the current timestep.
    """

    def __init__(self, episode_length):
        self.episode_length = episode_length
        self.i = 0
        self.observation_space = gym.spaces.Discrete(100)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self, *, seed=None, options=None):
        self.i = 0
        return self.i, {"timestep": 0}

    def step(self, action):
        # Ignore the action; obs and reward both equal the step counter.
        self.i += 1
        limit_reached = self.i >= self.episode_length
        return self.i, self.i, limit_reached, limit_reached, {"timestep": self.i}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class VectorizedMockEnv(VectorEnv):
    """Vectorized version of the MockEnv.

    Contains `num_envs` MockEnv instances, each one having its own
    `episode_length` horizon.
    """

    def __init__(self, episode_length, num_envs):
        super().__init__(
            observation_space=gym.spaces.Discrete(1),
            action_space=gym.spaces.Discrete(2),
            num_envs=num_envs,
        )
        self.envs = [MockEnv(episode_length) for _ in range(num_envs)]

    @override(VectorEnv)
    def vector_reset(self, *, seeds=None, options=None):
        # Default to per-sub-env `None` seeds/options when not provided.
        seeds = seeds or [None] * self.num_envs
        options = options or [None] * self.num_envs
        observations, infos = [], []
        for idx, sub_env in enumerate(self.envs):
            o, i = sub_env.reset(seed=seeds[idx], options=options[idx])
            observations.append(o)
            infos.append(i)
        return observations, infos

    @override(VectorEnv)
    def reset_at(self, index, *, seed=None, options=None):
        # Reset only the sub-env at `index`.
        return self.envs[index].reset(seed=seed, options=options)

    @override(VectorEnv)
    def vector_step(self, actions):
        # Step every sub-env with its own action, then transpose the
        # per-env 5-tuples into five parallel batch lists.
        results = [env.step(actions[i]) for i, env in enumerate(self.envs)]
        if not results:
            return [], [], [], [], []
        obs, rewards, terminateds, truncateds, infos = map(list, zip(*results))
        return obs, rewards, terminateds, truncateds, infos

    @override(VectorEnv)
    def get_sub_environments(self):
        return self.envs
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class MockVectorEnv(VectorEnv):
|
| 145 |
+
"""A custom vector env that uses a single(!) CartPole sub-env.
|
| 146 |
+
|
| 147 |
+
However, this env pretends to be a vectorized one to illustrate how one
|
| 148 |
+
could create custom VectorEnvs w/o the need for actual vectorizations of
|
| 149 |
+
sub-envs under the hood.
|
| 150 |
+
"""
|
| 151 |
+
|
| 152 |
+
def __init__(self, episode_length, mocked_num_envs):
|
| 153 |
+
self.env = gym.make("CartPole-v1")
|
| 154 |
+
super().__init__(
|
| 155 |
+
observation_space=self.env.observation_space,
|
| 156 |
+
action_space=self.env.action_space,
|
| 157 |
+
num_envs=mocked_num_envs,
|
| 158 |
+
)
|
| 159 |
+
self.episode_len = episode_length
|
| 160 |
+
self.ts = 0
|
| 161 |
+
|
| 162 |
+
@override(VectorEnv)
|
| 163 |
+
def vector_reset(self, *, seeds=None, options=None):
|
| 164 |
+
# Since we only have one underlying sub-environment, just use the first seed
|
| 165 |
+
# and the first options dict (the user of this env thinks, there are
|
| 166 |
+
# `self.num_envs` sub-environments and sends that many seeds/options).
|
| 167 |
+
seeds = seeds or [None]
|
| 168 |
+
options = options or [None]
|
| 169 |
+
obs, infos = self.env.reset(seed=seeds[0], options=options[0])
|
| 170 |
+
# Simply repeat the single obs/infos to pretend we really have
|
| 171 |
+
# `self.num_envs` sub-environments.
|
| 172 |
+
return (
|
| 173 |
+
[obs for _ in range(self.num_envs)],
|
| 174 |
+
[infos for _ in range(self.num_envs)],
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
@override(VectorEnv)
|
| 178 |
+
def reset_at(self, index, *, seed=None, options=None):
|
| 179 |
+
self.ts = 0
|
| 180 |
+
return self.env.reset(seed=seed, options=options)
|
| 181 |
+
|
| 182 |
+
@override(VectorEnv)
|
| 183 |
+
def vector_step(self, actions):
|
| 184 |
+
self.ts += 1
|
| 185 |
+
# Apply all actions sequentially to the same env.
|
| 186 |
+
# Whether this would make a lot of sense is debatable.
|
| 187 |
+
obs_batch, rew_batch, terminated_batch, truncated_batch, info_batch = (
|
| 188 |
+
[],
|
| 189 |
+
[],
|
| 190 |
+
[],
|
| 191 |
+
[],
|
| 192 |
+
[],
|
| 193 |
+
)
|
| 194 |
+
for i in range(self.num_envs):
|
| 195 |
+
obs, rew, terminated, truncated, info = self.env.step(actions[i])
|
| 196 |
+
# Artificially truncate once time step limit has been reached.
|
| 197 |
+
# Note: Also terminate/truncate, when underlying CartPole is
|
| 198 |
+
# terminated/truncated.
|
| 199 |
+
if self.ts >= self.episode_len:
|
| 200 |
+
truncated = True
|
| 201 |
+
obs_batch.append(obs)
|
| 202 |
+
rew_batch.append(rew)
|
| 203 |
+
terminated_batch.append(terminated)
|
| 204 |
+
truncated_batch.append(truncated)
|
| 205 |
+
info_batch.append(info)
|
| 206 |
+
if terminated or truncated:
|
| 207 |
+
remaining = self.num_envs - (i + 1)
|
| 208 |
+
obs_batch.extend([obs for _ in range(remaining)])
|
| 209 |
+
rew_batch.extend([rew for _ in range(remaining)])
|
| 210 |
+
terminated_batch.extend([terminated for _ in range(remaining)])
|
| 211 |
+
truncated_batch.extend([truncated for _ in range(remaining)])
|
| 212 |
+
info_batch.extend([info for _ in range(remaining)])
|
| 213 |
+
break
|
| 214 |
+
return obs_batch, rew_batch, terminated_batch, truncated_batch, info_batch
|
| 215 |
+
|
| 216 |
+
@override(VectorEnv)
def get_sub_environments(self):
    """Returns `num_envs` references to the single underlying env.

    You may also leave this method unimplemented, in which case the base
    class would return an empty list.
    """
    return [self.env] * self.num_envs
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/bandit_envs_discrete.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import gymnasium as gym
|
| 3 |
+
from gymnasium.spaces import Box, Discrete
|
| 4 |
+
import numpy as np
|
| 5 |
+
import random
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SimpleContextualBandit(gym.Env):
    """Simple env w/ 2 states and 3 actions (arms): 0, 1, and 2.

    Episodes last only for one timestep, possible observations are:
    [-1.0, 1.0] and [1.0, -1.0], where the first element is the "current context".
    The highest reward (+10.0) is received for selecting arm 0 for context=1.0
    and arm 2 for context=-1.0. Action 1 always yields 0.0 reward.
    """

    def __init__(self, config=None):
        self.action_space = Discrete(3)
        self.observation_space = Box(low=-1.0, high=1.0, shape=(2,))
        # Set by `reset()`; either -1.0 or +1.0.
        self.cur_context = None

    def reset(self, *, seed=None, options=None):
        # Draw a fresh context uniformly from {-1.0, +1.0}.
        self.cur_context = random.choice([-1.0, 1.0])
        first_obs = np.array([self.cur_context, -self.cur_context])
        return first_obs, {}

    def step(self, action):
        # Per-context arm payouts: the best arm flips with the context's sign.
        arm_payouts = {
            -1.0: [-10, 0, 10],
            1.0: [10, 0, -10],
        }
        reward = arm_payouts[self.cur_context][action]
        next_obs = np.array([-self.cur_context, self.cur_context])
        # One-step episodes: always terminated, never truncated. Regret is
        # the gap to the best achievable payout (+10).
        return next_obs, reward, True, False, {"regret": 10 - reward}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class LinearDiscreteEnv(gym.Env):
    """Samples data from linearly parameterized arms.

    The reward for context X and arm i is given by X^T * theta_i, for some
    latent set of parameters {theta_i : i = 1, ..., k}.
    The thetas are sampled uniformly at random, the contexts are Gaussian,
    and Gaussian noise is added to the rewards.
    """

    DEFAULT_CONFIG_LINEAR = {
        "feature_dim": 8,
        "num_actions": 4,
        "reward_noise_std": 0.01,
    }

    def __init__(self, config=None):
        # Start from the defaults and overlay any user-provided dict config.
        self.config = copy.copy(self.DEFAULT_CONFIG_LINEAR)
        if config is not None and type(config) is dict:
            self.config.update(config)

        self.feature_dim = self.config["feature_dim"]
        self.num_actions = self.config["num_actions"]
        self.sigma = self.config["reward_noise_std"]

        self.action_space = Discrete(self.num_actions)
        self.observation_space = Box(low=-10, high=10, shape=(self.feature_dim,))

        # One latent parameter vector per arm, normalized to unit length.
        self.thetas = np.random.uniform(-1, 1, (self.num_actions, self.feature_dim))
        self.thetas /= np.linalg.norm(self.thetas, axis=1, keepdims=True)

        self._elapsed_steps = 0
        self._current_context = None

    def _sample_context(self):
        # Gaussian contexts; std=1/3 keeps most mass well inside the obs bounds.
        return np.random.normal(scale=1 / 3, size=(self.feature_dim,))

    def reset(self, *, seed=None, options=None):
        self._current_context = self._sample_context()
        return self._current_context, {}

    def step(self, action):
        # Fix: message previously read "beforecalling" (missing space);
        # now consistent with WheelBanditEnv's identical assertion.
        assert (
            self._elapsed_steps is not None
        ), "Cannot call env.step() before calling reset()"
        assert action < self.num_actions, "Invalid action."

        action = int(action)
        # Fix: step counter was never incremented (unlike WheelBanditEnv).
        self._elapsed_steps += 1
        context = self._current_context
        rewards = self.thetas.dot(context)

        opt_action = rewards.argmax()

        # Regret is computed on the noise-free expected rewards.
        regret = rewards.max() - rewards[action]

        # Add Gaussian noise
        rewards += np.random.normal(scale=self.sigma, size=rewards.shape)

        reward = rewards[action]
        # Bandit setting: each step is its own (terminated) episode; a fresh
        # context is sampled and returned as the next observation.
        self._current_context = self._sample_context()
        return (
            self._current_context,
            reward,
            True,
            False,
            {"regret": regret, "opt_action": opt_action},
        )

    def render(self, mode="human"):
        raise NotImplementedError
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class WheelBanditEnv(gym.Env):
    """Wheel bandit environment for 2D contexts
    (see https://arxiv.org/abs/1802.09127).

    Contexts are sampled uniformly from the unit disk. Inside the inner
    disk (norm < delta) arm 0 (mean mu_1) is optimal; in the outer ring
    one of arms 1-4 (mean mu_3 >> mu_2) is optimal, depending on the
    context's quadrant.
    """

    # Default hyperparameters; overridable via a dict `config`.
    DEFAULT_CONFIG_WHEEL = {
        "delta": 0.5,
        "mu_1": 1.2,
        "mu_2": 1,
        "mu_3": 50,
        "std": 0.01,
    }

    # Fixed by the wheel-bandit problem definition.
    feature_dim = 2
    num_actions = 5

    def __init__(self, config=None):
        # Start from the defaults and overlay any user-provided dict config.
        self.config = copy.copy(self.DEFAULT_CONFIG_WHEEL)
        if config is not None and type(config) is dict:
            self.config.update(config)

        self.delta = self.config["delta"]
        self.mu_1 = self.config["mu_1"]
        self.mu_2 = self.config["mu_2"]
        self.mu_3 = self.config["mu_3"]
        self.std = self.config["std"]

        self.action_space = Discrete(self.num_actions)
        self.observation_space = Box(low=-1, high=1, shape=(self.feature_dim,))

        # Base mean rewards: arm 0 -> mu_1; arms 1..4 -> mu_2.
        self.means = [self.mu_1] + 4 * [self.mu_2]
        self._elapsed_steps = 0
        self._current_context = None

    def _sample_context(self):
        # Rejection-sample a point uniformly from the unit disk.
        while True:
            state = np.random.uniform(-1, 1, self.feature_dim)
            if np.linalg.norm(state) <= 1:
                return state

    def reset(self, *, seed=None, options=None):
        self._current_context = self._sample_context()
        return self._current_context, {}

    def step(self, action):
        assert (
            self._elapsed_steps is not None
        ), "Cannot call env.step() before calling reset()"

        action = int(action)
        self._elapsed_steps += 1
        # Draw noisy rewards around each arm's base mean.
        rewards = [
            np.random.normal(self.means[j], self.std) for j in range(self.num_actions)
        ]
        context = self._current_context
        # High payout used for the quadrant-optimal arm in the outer ring.
        r_big = np.random.normal(self.mu_3, self.std)

        if np.linalg.norm(context) >= self.delta:
            # Outer ring: the optimal arm depends on the context's quadrant.
            if context[0] > 0:
                if context[1] > 0:
                    # First quadrant
                    rewards[1] = r_big
                    opt_action = 1
                else:
                    # Fourth quadrant
                    rewards[4] = r_big
                    opt_action = 4
            else:
                if context[1] > 0:
                    # Second quadrant
                    rewards[2] = r_big
                    opt_action = 2
                else:
                    # Third quadrant
                    rewards[3] = r_big
                    opt_action = 3
        else:
            # Smaller region where action 0 is optimal
            opt_action = 0

        reward = rewards[action]

        # Regret vs. the realized reward of the optimal arm for this context.
        regret = rewards[opt_action] - reward

        # Bandit setting: each step is its own (terminated) episode.
        self._current_context = self._sample_context()
        return (
            self._current_context,
            reward,
            True,
            False,
            {"regret": regret, "opt_action": opt_action},
        )

    def render(self, mode="human"):
        raise NotImplementedError
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/guess_the_number_game.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
|
| 3 |
+
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class GuessTheNumberGame(MultiAgentEnv):
    """
    We have two players, 0 and 1. Agent 0 has to pick a number between 0, MAX-1
    at reset. Agent 1 has to guess the number by asking N questions of whether
    of the form of "a <number> is higher|lower|equal to the picked number. The
    action space is MultiDiscrete [3, MAX]. For the first index 0 means lower,
    1 means higher and 2 means equal. The environment answers with yes (1) or
    no (0) on the reward function. Every time step that agent 1 wastes agent 0
    gets a reward of 1. After N steps the game is terminated. If agent 1
    guesses the number correctly, it gets a reward of 100 points, otherwise it
    gets a reward of 0. On the other hand if agent 0 wins they win 100 points.
    The optimal policy controlling agent 1 should converge to a binary search
    strategy.
    """

    MAX_NUMBER = 3
    MAX_STEPS = 20

    def __init__(self, config=None):
        super().__init__()
        self._agent_ids = {0, 1}

        # Bug fix: `config` defaults to None, but was dereferenced via
        # `config.get(...)` unconditionally -> AttributeError when the env
        # was constructed without a config. Fall back to an empty dict.
        config = config or {}
        self.max_number = config.get("max_number", self.MAX_NUMBER)
        self.max_steps = config.get("max_steps", self.MAX_STEPS)

        # The number agent 0 picks on its first move; None until then.
        self._number = None
        # Observation is a dummy binary signal; the action is a pair of
        # (direction in {lower, higher, equal}, number in [0, max_number)).
        self.observation_space = gym.spaces.Discrete(2)
        self.action_space = gym.spaces.MultiDiscrete([3, self.max_number])

    def reset(self, *, seed=None, options=None):
        self._step = 0
        self._number = None
        # agent 0 has to pick a number. So the returned obs does not matter.
        return {0: 0}, {}

    def step(self, action_dict):
        # get agent 0's action
        agent_0_action = action_dict.get(0)

        if agent_0_action is not None:
            # ignore the first part of the action and look at the number
            self._number = agent_0_action[1]
            # next obs should tell agent 1 to start guessing.
            # the returned reward and dones should be on agent 0 who picked a
            # number.
            return (
                {1: 0},
                {0: 0},
                {0: False, "__all__": False},
                {0: False, "__all__": False},
                {},
            )

        if self._number is None:
            raise ValueError(
                "No number is selected by agent 0. Have you restarted "
                "the environment?"
            )

        # get agent 1's action
        direction, number = action_dict.get(1)
        info = {}
        # always the same, we don't need agent 0 to act ever again, agent 1 should keep
        # guessing.
        obs = {1: 0}
        guessed_correctly = False
        terminated = {1: False, "__all__": False}
        truncated = {1: False, "__all__": False}
        # everytime agent 1 does not guess correctly agent 0 gets a reward of 1.
        if direction == 0:  # lower
            reward = {1: int(number > self._number), 0: 1}
        elif direction == 1:  # higher
            reward = {1: int(number < self._number), 0: 1}
        else:  # equal
            guessed_correctly = number == self._number
            reward = {1: guessed_correctly * 100, 0: guessed_correctly * -100}
            terminated = {1: guessed_correctly, "__all__": guessed_correctly}

        self._step += 1
        if self._step >= self.max_steps:  # max number of steps episode is over
            truncated["__all__"] = True
            if not guessed_correctly:
                reward[0] = 100  # agent 0 wins
        return obs, reward, terminated, truncated, info
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/pettingzoo_chess.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pettingzoo import AECEnv
|
| 2 |
+
from pettingzoo.classic.chess.chess import raw_env as chess_v5
|
| 3 |
+
import copy
|
| 4 |
+
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
| 5 |
+
from typing import Dict, Any
|
| 6 |
+
import chess as ch
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class MultiAgentChess(MultiAgentEnv):
    """An interface to the PettingZoo MARL environment library.

    See: https://github.com/Farama-Foundation/PettingZoo
    Inherits from MultiAgentEnv and exposes a given AEC
    (actor-environment-cycle) game from the PettingZoo project via the
    MultiAgentEnv public API.

    Note that the wrapper has some important limitations:
    1. All agents have the same action_spaces and observation_spaces.
       Note: If, within your aec game, agents do not have homogeneous action /
       observation spaces, apply SuperSuit wrappers
       to apply padding functionality: https://github.com/Farama-Foundation/
       SuperSuit#built-in-multi-agent-only-functions
    2. Environments are positive sum games (-> Agents are expected to cooperate
       to maximize reward). This isn't a hard restriction, it's just that
       standard algorithms aren't expected to work well in highly competitive
       games.

    .. testcode::
        :skipif: True

        from pettingzoo.butterfly import prison_v3
        from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
        env = PettingZooEnv(prison_v3.env())
        obs = env.reset()
        print(obs)
        # only returns the observation for the agent which should be stepping

    .. testoutput::

        {
            'prisoner_0': array([[[0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                ...,
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0]]], dtype=uint8)
        }

    .. testcode::
        :skipif: True

        obs, rewards, dones, infos = env.step({
            "prisoner_0": 1
        })
        # only returns the observation, reward, info, etc, for
        # the agent whose turn is next.
        print(obs)

    .. testoutput::

        {
            'prisoner_1': array([[[0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                ...,
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0]]], dtype=uint8)
        }

    .. testcode::
        :skipif: True

        print(rewards)

    .. testoutput::

        {
            'prisoner_1': 0
        }

    .. testcode::
        :skipif: True

        print(dones)

    .. testoutput::

        {
            'prisoner_1': False, '__all__': False
        }

    .. testcode::
        :skipif: True

        print(infos)

    .. testoutput::

        {
            'prisoner_1': {'map_tuple': (1, 0)}
        }
    """

    def __init__(
        self,
        config: Dict[Any, Any] = None,
        env: AECEnv = None,
    ):
        super().__init__()
        # Default to a fresh PettingZoo chess env if none is provided.
        if env is None:
            self.env = chess_v5()
        else:
            self.env = env
        # AEC envs must be reset before their spaces/agents can be queried.
        self.env.reset()

        self.config = config
        if self.config is None:
            self.config = {}
        # `random_start`: number of random opening plies (defaults to 4).
        try:
            self.config["random_start"] = self.config["random_start"]
        except KeyError:
            self.config["random_start"] = 4
        # Get first observation space, assuming all agents have equal space
        self.observation_space = self.env.observation_space(self.env.agents[0])

        # Get first action space, assuming all agents have equal space
        self.action_space = self.env.action_space(self.env.agents[0])

        assert all(
            self.env.observation_space(agent) == self.observation_space
            for agent in self.env.agents
        ), (
            "Observation spaces for all agents must be identical. Perhaps "
            "SuperSuit's pad_observations wrapper can help (useage: "
            "`supersuit.aec_wrappers.pad_observations(env)`"
        )

        assert all(
            self.env.action_space(agent) == self.action_space
            for agent in self.env.agents
        ), (
            "Action spaces for all agents must be identical. Perhaps "
            "SuperSuit's pad_action_space wrapper can help (usage: "
            "`supersuit.aec_wrappers.pad_action_space(env)`"
        )
        self._agent_ids = set(self.env.agents)

    def random_start(self, random_moves):
        """Replaces the board with one reached by `random_moves` random plies."""
        self.env.board = ch.Board()
        for i in range(random_moves):
            # Push a uniformly random legal move onto the board.
            self.env.board.push(np.random.choice(list(self.env.board.legal_moves)))
        return self.env.board

    def observe(self):
        """Returns the current agent's observation plus a deep-copied state."""
        return {
            self.env.agent_selection: self.env.observe(self.env.agent_selection),
            "state": self.get_state(),
        }

    def reset(self, *args, **kwargs):
        """Resets the game, optionally applying random opening moves."""
        self.env.reset()
        if self.config["random_start"] > 0:
            self.random_start(self.config["random_start"])
        # MultiAgentEnv reset contract: (obs dict, info dict), keyed by the
        # agent whose turn it is.
        return (
            {self.env.agent_selection: self.env.observe(self.env.agent_selection)},
            {self.env.agent_selection: {}},
        )

    def step(self, action):
        """Applies the acting agent's move and returns the next agent's data."""
        # Accept either {agent_id: action} dicts or a raw action value.
        try:
            self.env.step(action[self.env.agent_selection])
        except (KeyError, IndexError):
            self.env.step(action)
        except AssertionError:
            # Illegal action
            print(action)
            raise AssertionError("Illegal action")

        obs_d = {}
        rew_d = {}
        done_d = {}
        truncated_d = {}
        info_d = {}
        # AEC loop: collect `last()` for the selected agent; terminated agents
        # are stepped with None until a live agent is found (or none remain).
        while self.env.agents:
            obs, rew, done, trunc, info = self.env.last()
            a = self.env.agent_selection
            obs_d[a] = obs
            rew_d[a] = rew
            done_d[a] = done
            truncated_d[a] = trunc
            info_d[a] = info
            if self.env.terminations[self.env.agent_selection]:
                self.env.step(None)
                done_d["__all__"] = True
                truncated_d["__all__"] = True
            else:
                done_d["__all__"] = False
                truncated_d["__all__"] = False
                break

        return obs_d, rew_d, done_d, truncated_d, info_d

    def close(self):
        self.env.close()

    def seed(self, seed=None):
        self.env.seed(seed)

    def render(self, mode="human"):
        return self.env.render(mode)

    @property
    def agent_selection(self):
        # The agent whose turn it currently is.
        return self.env.agent_selection

    @property
    def get_sub_environments(self):
        return self.env.unwrapped

    def get_state(self):
        """Returns a deep copy of the wrapped env (used as a state snapshot)."""
        state = copy.deepcopy(self.env)
        return state

    def set_state(self, state):
        """Restores a snapshot produced by `get_state` and returns the obs."""
        self.env = copy.deepcopy(state)
        return self.env.observe(self.env.agent_selection)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/pettingzoo_connect4.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from typing import Dict, Any
|
| 3 |
+
|
| 4 |
+
from pettingzoo import AECEnv
|
| 5 |
+
from pettingzoo.classic.connect_four_v3 import raw_env as connect_four_v3
|
| 6 |
+
|
| 7 |
+
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class MultiAgentConnect4(MultiAgentEnv):
    """An interface to the PettingZoo MARL environment library.

    See: https://github.com/Farama-Foundation/PettingZoo
    Inherits from MultiAgentEnv and exposes a given AEC
    (actor-environment-cycle) game from the PettingZoo project via the
    MultiAgentEnv public API.

    Note that the wrapper has some important limitations:
    1. All agents have the same action_spaces and observation_spaces.
       Note: If, within your aec game, agents do not have homogeneous action /
       observation spaces, apply SuperSuit wrappers
       to apply padding functionality: https://github.com/Farama-Foundation/
       SuperSuit#built-in-multi-agent-only-functions
    2. Environments are positive sum games (-> Agents are expected to cooperate
       to maximize reward). This isn't a hard restriction, it's just that
       standard algorithms aren't expected to work well in highly competitive
       games.

    .. testcode::
        :skipif: True

        from pettingzoo.butterfly import prison_v3
        from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
        env = PettingZooEnv(prison_v3.env())
        obs = env.reset()
        print(obs)

    .. testoutput::

        # only returns the observation for the agent which should be stepping
        {
            'prisoner_0': array([[[0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                ...,
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0]]], dtype=uint8)
        }

    .. testcode::
        :skipif: True

        obs, rewards, dones, infos = env.step({
            "prisoner_0": 1
        })
        # only returns the observation, reward, info, etc, for
        # the agent whose turn is next.
        print(obs)

    .. testoutput::

        {
            'prisoner_1': array([[[0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                ...,
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0]]], dtype=uint8)
        }

    .. testcode::
        :skipif: True

        print(rewards)

    .. testoutput::

        {
            'prisoner_1': 0
        }

    .. testcode::
        :skipif: True

        print(dones)

    .. testoutput::

        {
            'prisoner_1': False, '__all__': False
        }

    .. testcode::
        :skipif: True

        print(infos)

    .. testoutput::

        {
            'prisoner_1': {'map_tuple': (1, 0)}
        }
    """

    def __init__(
        self,
        config: Dict[Any, Any] = None,
        env: AECEnv = None,
    ):
        super().__init__()
        # Default to a fresh PettingZoo connect-four env if none is provided.
        if env is None:
            self.env = connect_four_v3()
        else:
            self.env = env
        # AEC envs must be reset before their spaces/agents can be queried.
        self.env.reset()

        self.config = config
        # Get first observation space, assuming all agents have equal space
        self.observation_space = self.env.observation_space(self.env.agents[0])

        # Get first action space, assuming all agents have equal space
        self.action_space = self.env.action_space(self.env.agents[0])

        assert all(
            self.env.observation_space(agent) == self.observation_space
            for agent in self.env.agents
        ), (
            "Observation spaces for all agents must be identical. Perhaps "
            "SuperSuit's pad_observations wrapper can help (useage: "
            "`supersuit.aec_wrappers.pad_observations(env)`"
        )

        assert all(
            self.env.action_space(agent) == self.action_space
            for agent in self.env.agents
        ), (
            "Action spaces for all agents must be identical. Perhaps "
            "SuperSuit's pad_action_space wrapper can help (usage: "
            "`supersuit.aec_wrappers.pad_action_space(env)`"
        )
        self._agent_ids = set(self.env.agents)

    def observe(self):
        """Returns the current agent's observation plus a deep-copied state."""
        return {
            self.env.agent_selection: self.env.observe(self.env.agent_selection),
            "state": self.get_state(),
        }

    def reset(self, *args, **kwargs):
        """Resets the game and returns (obs dict, info dict) for the first agent."""
        self.env.reset()
        return (
            {self.env.agent_selection: self.env.observe(self.env.agent_selection)},
            {self.env.agent_selection: {}},
        )

    def step(self, action):
        """Applies the acting agent's move and returns the next agent's data."""
        # Accept either {agent_id: action} dicts or a raw action value.
        try:
            self.env.step(action[self.env.agent_selection])
        except (KeyError, IndexError):
            self.env.step(action)
        except AssertionError:
            # Illegal action
            print(action)
            raise AssertionError("Illegal action")

        obs_d = {}
        rew_d = {}
        done_d = {}
        trunc_d = {}
        info_d = {}
        # AEC loop: collect `last()` for the selected agent; terminated agents
        # are stepped with None until a live agent is found (or none remain).
        while self.env.agents:
            obs, rew, done, trunc, info = self.env.last()
            a = self.env.agent_selection
            obs_d[a] = obs
            rew_d[a] = rew
            done_d[a] = done
            trunc_d[a] = trunc
            info_d[a] = info
            if self.env.terminations[self.env.agent_selection]:
                self.env.step(None)
                done_d["__all__"] = True
                trunc_d["__all__"] = True
            else:
                done_d["__all__"] = False
                trunc_d["__all__"] = False
                break

        return obs_d, rew_d, done_d, trunc_d, info_d

    def close(self):
        self.env.close()

    def seed(self, seed=None):
        self.env.seed(seed)

    def render(self, mode="human"):
        return self.env.render(mode)

    @property
    def agent_selection(self):
        # The agent whose turn it currently is.
        return self.env.agent_selection

    @property
    def get_sub_environments(self):
        return self.env.unwrapped

    def get_state(self):
        """Returns a deep copy of the wrapped env (used as a state snapshot)."""
        state = copy.deepcopy(self.env)
        return state

    def set_state(self, state):
        """Restores a snapshot produced by `get_state` and returns the obs."""
        self.env = copy.deepcopy(state)
        return self.env.observe(self.env.agent_selection)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/rock_paper_scissors.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# __sphinx_doc_1_begin__
|
| 2 |
+
import gymnasium as gym
|
| 3 |
+
|
| 4 |
+
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class RockPaperScissors(MultiAgentEnv):
    """Two-player environment for the famous rock paper scissors game.

    # __sphinx_doc_1_end__
    Optionally, the "Sheldon Cooper extension" can be activated by passing
    `sheldon_cooper_mode=True` into the constructor, in which case two more moves
    are allowed: Spock and Lizard. Spock is poisoned by Lizard, disproven by Paper, but
    crushes Rock and smashes Scissors. Lizard poisons Spock and eats Paper, but is
    decapitated by Scissors and crushed by Rock.

    # __sphinx_doc_2_begin__
    Both players always move simultaneously over a course of 10 timesteps in total.
    The winner of each timestep receives reward of +1, the losing player -1.0.

    The observation of each player is the last opponent action.
    """

    ROCK = 0
    PAPER = 1
    SCISSORS = 2
    LIZARD = 3
    SPOCK = 4

    # Maps (move1, move2) -> (reward player1, reward player2).
    WIN_MATRIX = {
        (ROCK, ROCK): (0, 0),
        (ROCK, PAPER): (-1, 1),
        (ROCK, SCISSORS): (1, -1),
        (PAPER, ROCK): (1, -1),
        (PAPER, PAPER): (0, 0),
        (PAPER, SCISSORS): (-1, 1),
        (SCISSORS, ROCK): (-1, 1),
        (SCISSORS, PAPER): (1, -1),
        (SCISSORS, SCISSORS): (0, 0),
    }
    # __sphinx_doc_2_end__

    WIN_MATRIX.update(
        {
            # Sheldon Cooper mode:
            (LIZARD, LIZARD): (0, 0),
            (LIZARD, SPOCK): (1, -1),  # Lizard poisons Spock
            (LIZARD, ROCK): (-1, 1),  # Rock crushes lizard
            (LIZARD, PAPER): (1, -1),  # Lizard eats paper
            (LIZARD, SCISSORS): (-1, 1),  # Scissors decapitate lizard
            (ROCK, LIZARD): (1, -1),  # Rock crushes lizard
            (PAPER, LIZARD): (-1, 1),  # Lizard eats paper
            (SCISSORS, LIZARD): (1, -1),  # Scissors decapitate lizard
            (SPOCK, SPOCK): (0, 0),
            (SPOCK, LIZARD): (-1, 1),  # Lizard poisons Spock
            (SPOCK, ROCK): (1, -1),  # Spock vaporizes rock
            (SPOCK, PAPER): (-1, 1),  # Paper disproves Spock
            (SPOCK, SCISSORS): (1, -1),  # Spock smashes scissors
            (ROCK, SPOCK): (-1, 1),  # Spock vaporizes rock
            (PAPER, SPOCK): (1, -1),  # Paper disproves Spock
            (SCISSORS, SPOCK): (-1, 1),  # Spock smashes scissors
        }
    )

    # __sphinx_doc_3_begin__
    def __init__(self, config=None):
        super().__init__()

        # FIX: `config` defaults to None but was dereferenced via `.get()`
        # below, so `RockPaperScissors()` crashed with an AttributeError.
        config = config or {}

        self.agents = self.possible_agents = ["player1", "player2"]

        # The observations are always the last taken actions. Hence observation- and
        # action spaces are identical.
        self.observation_spaces = self.action_spaces = {
            "player1": gym.spaces.Discrete(3),
            "player2": gym.spaces.Discrete(3),
        }
        self.last_move = None
        self.num_moves = 0
        # __sphinx_doc_3_end__

        self.sheldon_cooper_mode = False
        if config.get("sheldon_cooper_mode"):
            self.sheldon_cooper_mode = True
            # Two extra moves (Lizard and Spock) -> Discrete(5) spaces.
            self.action_spaces = self.observation_spaces = {
                "player1": gym.spaces.Discrete(5),
                "player2": gym.spaces.Discrete(5),
            }

    # __sphinx_doc_4_begin__
    def reset(self, *, seed=None, options=None):
        self.num_moves = 0

        # The first observation should not matter (none of the agents has moved yet).
        # Set them to 0.
        return {
            "player1": 0,
            "player2": 0,
        }, {}  # <- empty infos dict

    # __sphinx_doc_4_end__

    # __sphinx_doc_5_begin__
    def step(self, action_dict):
        self.num_moves += 1

        move1 = action_dict["player1"]
        move2 = action_dict["player2"]

        # Set the next observations (simply use the other player's action).
        # Note that because we are publishing both players in the observations dict,
        # we expect both players to act in the next `step()` (simultaneous stepping).
        observations = {"player1": move2, "player2": move1}

        # Compute rewards for each player based on the win-matrix.
        r1, r2 = self.WIN_MATRIX[move1, move2]
        rewards = {"player1": r1, "player2": r2}

        # Terminate the entire episode (for all agents) once 10 moves have been made.
        terminateds = {"__all__": self.num_moves >= 10}

        # Leave truncateds and infos empty.
        return observations, rewards, terminateds, {}, {}

    # __sphinx_doc_5_end__
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/tic_tac_toe.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# __sphinx_doc_1_begin__
|
| 2 |
+
import gymnasium as gym
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TicTacToe(MultiAgentEnv):
    """A two-player game in which any player tries to complete one row in a 3x3 field.

    The observation space is Box(-1.0, 1.0, (9,)), where each index represents a
    distinct field on a 3x3 board: 0.0 means the field is empty, -1.0 that player2
    occupies it, and 1.0 that player1 occupies it:
    ----------
    | 0| 1| 2|
    ----------
    | 3| 4| 5|
    ----------
    | 6| 7| 8|
    ----------

    The action space is Discrete(9). Picking an already occupied field leaves the
    board unchanged and costs the acting player -5.0 reward.

    Completing a row yields +5.0 for the winner and -5.0 for the loser; a full
    board without a winner ends the episode with no extra reward.
    """

    # __sphinx_doc_1_end__

    # All index-triples that constitute a winning line on the board.
    _WIN_LINES = (
        (0, 1, 2),
        (3, 4, 5),
        (6, 7, 8),  # horizontals
        (0, 3, 6),
        (1, 4, 7),
        (2, 5, 8),  # verticals
        (0, 4, 8),
        (2, 4, 6),  # diagonals
    )

    # __sphinx_doc_2_begin__
    def __init__(self, config=None):
        super().__init__()

        # Define the agents in the game.
        self.agents = self.possible_agents = ["player1", "player2"]

        # Each agent observes a 9D tensor, representing the 3x3 fields of the board.
        # A 0 means an empty field, a 1 a piece of player 1, a -1 a piece of player 2.
        self.observation_spaces = {
            agent_id: gym.spaces.Box(-1.0, 1.0, (9,), np.float32)
            for agent_id in self.agents
        }
        # Nine actions per player, one per board field a piece may be placed on.
        self.action_spaces = {
            agent_id: gym.spaces.Discrete(9) for agent_id in self.agents
        }

        self.board = None
        self.current_player = None

    # __sphinx_doc_2_end__

    # __sphinx_doc_3_begin__
    def reset(self, *, seed=None, options=None):
        # Empty 3x3 board, flattened to 9 fields.
        self.board = [0] * 9
        # Pick a random player to start the game.
        self.current_player = np.random.choice(["player1", "player2"])
        # Return observations dict (only with the starting player, which is the one
        # we expect to act next).
        return {
            self.current_player: np.array(self.board, np.float32),
        }, {}

    # __sphinx_doc_3_end__

    # __sphinx_doc_4_begin__
    def step(self, action_dict):
        action = action_dict[self.current_player]

        # Rewards-dict for the agent that just acted.
        rewards = {self.current_player: 0.0}
        # Terminateds-dict with the special `__all__` agent ID: if True, the
        # episode ends for all agents.
        terminateds = {"__all__": False}

        opponent = "player2" if self.current_player == "player1" else "player1"

        if self.board[action] != 0:
            # Penalize trying to place a piece on an already occupied field.
            rewards[self.current_player] -= 5.0
        else:
            # Valid move: place the acting player's mark on the board.
            mark = 1 if self.current_player == "player1" else -1
            self.board[action] = mark

            # After having placed a new piece, figure out whether the current
            # player completed any winning line.
            if any(
                all(self.board[i] == mark for i in line) for line in self._WIN_LINES
            ):
                # Final reward is +5 for victory and -5 for a loss.
                rewards[self.current_player] += 5.0
                rewards[opponent] = -5.0
                # Episode is done and needs to be reset for a new game.
                terminateds["__all__"] = True
            # The board might also be full w/o any player having won/lost; end the
            # episode without extra rewards in that case.
            elif 0 not in self.board:
                terminateds["__all__"] = True

        # Flip players and return an observations dict with only the next player
        # to make a move in it.
        self.current_player = opponent

        return (
            {self.current_player: np.array(self.board, np.float32)},
            rewards,
            terminateds,
            {},
            {},
        )

    # __sphinx_doc_4_end__
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/two_step_game.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from gymnasium.spaces import Dict, Discrete, MultiDiscrete, Tuple
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from ray.rllib.env.multi_agent_env import MultiAgentEnv, ENV_STATE
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class TwoStepGame(MultiAgentEnv):
    """Cooperative two-agent, two-step matrix game (QMIX paper).

    State machine over three one-hot states [1,0,0] -> ([0,1,0] | [0,0,1]) -> end.
    Agent 1's first action selects the second-step payoff matrix; the shared
    reward of the second step is split evenly between both agents.
    """

    action_space = Discrete(2)

    def __init__(self, env_config):
        super().__init__()
        self.action_space = Discrete(2)
        self.state = None
        self.agent_1 = 0
        self.agent_2 = 1
        # MADDPG emits action logits instead of actual discrete actions.
        self.actions_are_logits = env_config.get("actions_are_logits", False)
        self.one_hot_state_encoding = env_config.get("one_hot_state_encoding", False)
        self.with_state = env_config.get("separate_state_space", False)
        self._agent_ids = {0, 1}
        if not self.one_hot_state_encoding:
            self.observation_space = Discrete(6)
            self.with_state = False
        else:
            # Each agent gets the full state (one-hot encoding of which of the
            # three states are active) as input with the receiving agent's
            # ID (1 or 2) concatenated onto the end.
            if self.with_state:
                self.observation_space = Dict(
                    {
                        "obs": MultiDiscrete([2, 2, 2, 3]),
                        ENV_STATE: MultiDiscrete([2, 2, 2]),
                    }
                )
            else:
                self.observation_space = MultiDiscrete([2, 2, 2, 3])

    def reset(self, *, seed=None, options=None):
        if seed is not None:
            np.random.seed(seed)
        self.state = np.array([1, 0, 0])
        return self._obs(), {}

    def step(self, action_dict):
        if self.actions_are_logits:
            # Sample discrete actions from the provided per-agent logits.
            action_dict = {
                k: np.random.choice([0, 1], p=v) for k, v in action_dict.items()
            }

        # FIX: `np.flatnonzero` returns a 1-element ndarray; comparing/branching
        # on it as a scalar is deprecated in modern NumPy. Extract a plain int.
        state_index = int(np.flatnonzero(self.state)[0])
        if state_index == 0:
            # First step: agent 1's action picks the second-step payoff matrix.
            action = action_dict[self.agent_1]
            assert action in [0, 1], action
            if action == 0:
                self.state = np.array([0, 1, 0])
            else:
                self.state = np.array([0, 0, 1])
            global_rew = 0
            terminated = False
        elif state_index == 1:
            # Second step, easy branch: fixed payoff regardless of actions.
            global_rew = 7
            terminated = True
        else:
            # Second step, coordination branch: payoff depends on joint action.
            if action_dict[self.agent_1] == 0 and action_dict[self.agent_2] == 0:
                global_rew = 0
            elif action_dict[self.agent_1] == 1 and action_dict[self.agent_2] == 1:
                global_rew = 8
            else:
                global_rew = 1
            terminated = True

        # Split the shared reward evenly between both agents.
        rewards = {self.agent_1: global_rew / 2.0, self.agent_2: global_rew / 2.0}
        obs = self._obs()
        terminateds = {"__all__": terminated}
        truncateds = {"__all__": False}
        infos = {
            self.agent_1: {"done": terminateds["__all__"]},
            self.agent_2: {"done": terminateds["__all__"]},
        }
        return obs, rewards, terminateds, truncateds, infos

    def _obs(self):
        """Build the per-agent observations dict for the current state."""
        if self.with_state:
            return {
                self.agent_1: {"obs": self.agent_1_obs(), ENV_STATE: self.state},
                self.agent_2: {"obs": self.agent_2_obs(), ENV_STATE: self.state},
            }
        else:
            return {self.agent_1: self.agent_1_obs(), self.agent_2: self.agent_2_obs()}

    def agent_1_obs(self):
        """Agent 1's observation: one-hot state + ID tag, or a flat state index."""
        if self.one_hot_state_encoding:
            return np.concatenate([self.state, [1]])
        else:
            return np.flatnonzero(self.state)[0]

    def agent_2_obs(self):
        """Agent 2's observation: one-hot state + ID tag, or state index + 3."""
        if self.one_hot_state_encoding:
            return np.concatenate([self.state, [2]])
        else:
            return np.flatnonzero(self.state)[0] + 3
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class TwoStepGameWithGroupedAgents(MultiAgentEnv):
    """`TwoStepGame` with both agents merged into one "agents" group.

    Observations and actions of the group are tuples holding the two
    individual agents' observations/actions.
    """

    def __init__(self, env_config):
        super().__init__()
        inner = TwoStepGame(env_config)
        grouped_obs_space = Tuple([inner.observation_space, inner.observation_space])
        grouped_act_space = Tuple([inner.action_space, inner.action_space])
        self._agent_ids = {"agents"}
        self.env = inner.with_agent_groups(
            groups={"agents": [0, 1]},
            obs_space=grouped_obs_space,
            act_space=grouped_act_space,
        )
        self.observation_space = Dict({"agents": self.env.observation_space})
        self.action_space = Dict({"agents": self.env.action_space})

    def reset(self, *, seed=None, options=None):
        """Delegate to the grouped env."""
        return self.env.reset(seed=seed, options=options)

    def step(self, actions):
        """Delegate to the grouped env."""
        return self.env.step(actions)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/nested_space_repeat_after_me_env.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
from gymnasium.spaces import Box, Dict, Discrete, Tuple
|
| 3 |
+
import numpy as np
|
| 4 |
+
import tree # pip install dm_tree
|
| 5 |
+
|
| 6 |
+
from ray.rllib.utils.spaces.space_utils import flatten_space
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class NestedSpaceRepeatAfterMeEnv(gym.Env):
    """Env for which policy has to repeat the (possibly complex) observation.

    The action space and observation spaces are always the same and may be
    arbitrarily nested Dict/Tuple Spaces.
    Rewards are given for exactly matching Discrete sub-actions and for being
    as close as possible for Box sub-actions.
    """

    def __init__(self, config=None):
        cfg = config or {}
        self.observation_space = cfg.get(
            "space", Tuple([Discrete(2), Dict({"a": Box(-1.0, 1.0, (2,))})])
        )
        # Action space mirrors the observation space exactly.
        self.action_space = self.observation_space
        self.flattened_action_space = flatten_space(self.action_space)
        self.episode_len = cfg.get("episode_len", 100)

    def reset(self, *, seed=None, options=None):
        self.steps = 0
        return self._next_obs(), {}

    def step(self, action):
        self.steps += 1
        flat_action = tree.flatten(action)
        reward = 0.0
        # Walk the flattened action/observation/space triplets in lockstep.
        for act, obs, sub_space in zip(
            flat_action, self.current_obs_flattened, self.flattened_action_space
        ):
            # Box: -abs(diff).
            if isinstance(sub_space, gym.spaces.Box):
                reward -= np.sum(np.abs(act - obs))
            # Discrete: +1.0 if exact match.
            if isinstance(sub_space, gym.spaces.Discrete):
                reward += 1.0 if act == obs else 0.0
        truncated = self.steps >= self.episode_len
        return self._next_obs(), reward, False, truncated, {}

    def _next_obs(self):
        """Sample, remember (also flattened), and return a fresh observation."""
        self.current_obs = self.observation_space.sample()
        self.current_obs_flattened = tree.flatten(self.current_obs)
        return self.current_obs
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/parametric_actions_cartpole.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import gymnasium as gym
|
| 4 |
+
import numpy as np
|
| 5 |
+
from gymnasium.spaces import Box, Dict, Discrete
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ParametricActionsCartPole(gym.Env):
    """Parametric action version of CartPole.

    In this env there are only ever two valid actions, but we pretend there are
    actually up to `max_avail_actions` actions that can be taken, and the two
    valid actions are randomly hidden among this set.

    At each step, we emit a dict of:
        - the actual cart observation
        - a mask of valid actions (e.g., [0, 0, 1, 0, 0, 1] for 6 max avail)
        - the list of action embeddings (w/ zeroes for invalid actions) (e.g.,
            [[0, 0],
             [0, 0],
             [-0.2322, -0.2569],
             [0, 0],
             [0, 0],
             [0.7878, 1.2297]] for max_avail_actions=6)

    In a real environment, the actions embeddings would be larger than two
    units of course, and also there would be a variable number of valid actions
    per step instead of always [LEFT, RIGHT].
    """

    def __init__(self, max_avail_actions):
        # Use simple random 2-unit action embeddings for [LEFT, RIGHT].
        self.left_action_embed = np.random.randn(2)
        self.right_action_embed = np.random.randn(2)
        self.action_space = Discrete(max_avail_actions)
        self.wrapped = gym.make("CartPole-v1")
        self.observation_space = Dict(
            {
                "action_mask": Box(0, 1, shape=(max_avail_actions,), dtype=np.int8),
                "avail_actions": Box(-10, 10, shape=(max_avail_actions, 2)),
                "cart": self.wrapped.observation_space,
            }
        )

    def update_avail_actions(self):
        """Re-hide the two valid actions at fresh random slot indices."""
        self.action_assignments = np.array(
            [[0.0, 0.0]] * self.action_space.n, dtype=np.float32
        )
        self.action_mask = np.array([0.0] * self.action_space.n, dtype=np.int8)
        # NOTE(review): uses the global `random` module RNG; this shuffle is not
        # governed by the seed passed to `reset()`.
        self.left_idx, self.right_idx = random.sample(range(self.action_space.n), 2)
        self.action_assignments[self.left_idx] = self.left_action_embed
        self.action_assignments[self.right_idx] = self.right_action_embed
        self.action_mask[self.left_idx] = 1
        self.action_mask[self.right_idx] = 1

    def reset(self, *, seed=None, options=None):
        self.update_avail_actions()
        # FIX: forward `seed` and `options` to the wrapped CartPole env; they
        # were previously accepted but silently dropped, making seeded resets
        # non-reproducible.
        obs, infos = self.wrapped.reset(seed=seed, options=options)
        return {
            "action_mask": self.action_mask,
            "avail_actions": self.action_assignments,
            "cart": obs,
        }, infos

    def step(self, action):
        # Translate the parametric action index back into CartPole's 0/1 action.
        if action == self.left_idx:
            actual_action = 0
        elif action == self.right_idx:
            actual_action = 1
        else:
            raise ValueError(
                "Chosen action was not one of the non-zero action embeddings",
                action,
                self.action_assignments,
                self.action_mask,
                self.left_idx,
                self.right_idx,
            )
        orig_obs, rew, done, truncated, info = self.wrapped.step(actual_action)
        self.update_avail_actions()
        self.action_mask = self.action_mask.astype(np.int8)
        obs = {
            "action_mask": self.action_mask,
            "avail_actions": self.action_assignments,
            "cart": orig_obs,
        }
        return obs, rew, done, truncated, info
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class ParametricActionsCartPoleNoEmbeddings(gym.Env):
    """Same as the above ParametricActionsCartPole.

    However, action embeddings are not published inside observations,
    but will be learnt by the model.

    At each step, we emit a dict of:
        - the actual cart observation
        - a mask of valid actions (e.g., [0, 0, 1, 0, 0, 1] for 6 max avail)
        - action embeddings (w/ "dummy embedding" for invalid actions) are
          outsourced in the model and will be learned.
    """

    def __init__(self, max_avail_actions):
        # Randomly set which two actions are valid and available.
        # NOTE(review): chosen once at construction via the global `random` RNG
        # and never re-randomized between episodes.
        self.left_idx, self.right_idx = random.sample(range(max_avail_actions), 2)
        self.valid_avail_actions_mask = np.array(
            [0.0] * max_avail_actions, dtype=np.int8
        )
        self.valid_avail_actions_mask[self.left_idx] = 1
        self.valid_avail_actions_mask[self.right_idx] = 1
        self.action_space = Discrete(max_avail_actions)
        self.wrapped = gym.make("CartPole-v1")
        self.observation_space = Dict(
            {
                "valid_avail_actions_mask": Box(0, 1, shape=(max_avail_actions,)),
                "cart": self.wrapped.observation_space,
            }
        )

    def reset(self, *, seed=None, options=None):
        # FIX: forward `seed` and `options` to the wrapped CartPole env; they
        # were previously accepted but silently dropped.
        obs, infos = self.wrapped.reset(seed=seed, options=options)
        return {
            "valid_avail_actions_mask": self.valid_avail_actions_mask,
            "cart": obs,
        }, infos

    def step(self, action):
        # Translate the parametric action index back into CartPole's 0/1 action.
        if action == self.left_idx:
            actual_action = 0
        elif action == self.right_idx:
            actual_action = 1
        else:
            raise ValueError(
                "Chosen action was not one of the non-zero action embeddings",
                action,
                self.valid_avail_actions_mask,
                self.left_idx,
                self.right_idx,
            )
        orig_obs, rew, done, truncated, info = self.wrapped.step(actual_action)
        obs = {
            "valid_avail_actions_mask": self.valid_avail_actions_mask,
            "cart": orig_obs,
        }
        return obs, rew, done, truncated, info
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/random_env.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import gymnasium as gym
|
| 3 |
+
from gymnasium.spaces import Discrete, Tuple
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from ray.rllib.examples.envs.classes.multi_agent import make_multi_agent
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class RandomEnv(gym.Env):
    """A randomly acting environment.

    Can be instantiated with arbitrary action-, observation-, and reward
    spaces. Observations and rewards are generated by simply sampling from the
    observation/reward spaces. The probability of a `terminated=True` after each
    action can be configured, as well as the max episode length.
    """

    def __init__(self, config=None):
        config = config or {}

        # Action space.
        self.action_space = config.get("action_space", Discrete(2))
        # Observation space from which to sample.
        self.observation_space = config.get("observation_space", Discrete(2))
        # Reward space from which to sample.
        self.reward_space = config.get(
            "reward_space",
            gym.spaces.Box(low=-1.0, high=1.0, shape=(), dtype=np.float32),
        )
        # If True, sample obs/reward once and return (deep) copies forever after.
        self.static_samples = config.get("static_samples", False)
        if self.static_samples:
            self.observation_sample = self.observation_space.sample()
            self.reward_sample = self.reward_space.sample()

        # Chance that an episode ends at any step.
        # Note that a max episode length can be specified via `max_episode_len`.
        self.p_terminated = config.get("p_terminated")
        if self.p_terminated is None:
            # Backward compat: older configs used the key `p_done`.
            self.p_terminated = config.get("p_done", 0.1)
        # A max episode length. Even if the `p_terminated` sampling does not lead
        # to a terminus, the episode will end after at most this many timesteps.
        # Set to 0 or None for using no limit on the episode length.
        self.max_episode_len = config.get("max_episode_len", None)
        # Whether to check action bounds.
        self.check_action_bounds = config.get("check_action_bounds", False)
        # Steps taken so far (after last reset).
        self.steps = 0

    def reset(self, *, seed=None, options=None):
        # FIX: `seed` was previously accepted but silently ignored. Seed the
        # sampling spaces so observation/reward streams become reproducible.
        # NOTE(review): the `p_terminated` coin flip in `step()` still uses
        # NumPy's global RNG and is not covered by this seed.
        if seed is not None:
            self.observation_space.seed(seed)
            self.reward_space.seed(seed)
        self.steps = 0
        if not self.static_samples:
            return self.observation_space.sample(), {}
        else:
            return copy.deepcopy(self.observation_sample), {}

    def step(self, action):
        if self.check_action_bounds and not self.action_space.contains(action):
            raise ValueError(
                "Illegal action for {}: {}".format(self.action_space, action)
            )
        if isinstance(self.action_space, Tuple) and len(action) != len(
            self.action_space.spaces
        ):
            raise ValueError(
                "Illegal action for {}: {}".format(self.action_space, action)
            )

        self.steps += 1
        terminated = False
        truncated = False
        # We are `truncated` as per our max-episode-len.
        if self.max_episode_len and self.steps >= self.max_episode_len:
            truncated = True
        # Max episode length not reached yet -> Sample `terminated` via `p_terminated`.
        elif self.p_terminated > 0.0:
            terminated = bool(
                np.random.choice(
                    [True, False], p=[self.p_terminated, 1.0 - self.p_terminated]
                )
            )

        if not self.static_samples:
            return (
                self.observation_space.sample(),
                self.reward_space.sample(),
                terminated,
                truncated,
                {},
            )
        else:
            return (
                copy.deepcopy(self.observation_sample),
                copy.deepcopy(self.reward_sample),
                terminated,
                truncated,
                {},
            )
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# Multi-agent version of the RandomEnv.
RandomMultiAgentEnv = make_multi_agent(lambda config: RandomEnv(config))
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# Large observation space "pre-compiled" random env (for testing).
|
| 107 |
+
class RandomLargeObsSpaceEnv(RandomEnv):
    """RandomEnv preset with a large (5000-dim float) observation space."""

    def __init__(self, config=None):
        # FIX: copy first - the original code mutated the caller's `config`
        # dict in place via `config.update()`.
        config = dict(config or {})
        config["observation_space"] = gym.spaces.Box(-1.0, 1.0, (5000,))
        super().__init__(config=config)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# Large observation space + cont. actions "pre-compiled" random env
# (for testing).
class RandomLargeObsSpaceEnvContActions(RandomEnv):
    """RandomEnv preset with a large Box obs space and continuous actions."""

    def __init__(self, config=None):
        cfg = config or {}
        # Force the large obs space and a 5-dim continuous action space.
        cfg["observation_space"] = gym.spaces.Box(-1.0, 1.0, (5000,))
        cfg["action_space"] = gym.spaces.Box(-1.0, 1.0, (5,))
        super().__init__(config=cfg)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/recommender_system_envs_with_recsim.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Examples for RecSim envs ready to be used by RLlib Algorithms.
|
| 2 |
+
|
| 3 |
+
RecSim is a configurable recommender systems simulation platform.
|
| 4 |
+
Source: https://github.com/google-research/recsim
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from recsim import choice_model
|
| 8 |
+
from recsim.environments import (
|
| 9 |
+
long_term_satisfaction as lts,
|
| 10 |
+
interest_evolution as iev,
|
| 11 |
+
interest_exploration as iex,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
from ray.rllib.env.wrappers.recsim import make_recsim_env
|
| 15 |
+
from ray.tune import register_env
|
| 16 |
+
|
| 17 |
+
# Some built-in RecSim envs to test with.
|
| 18 |
+
# ---------------------------------------
|
| 19 |
+
|
| 20 |
+
# Long-term satisfaction env: User has to pick from items that are either
|
| 21 |
+
# a) unhealthy, but taste good, or b) healthy, but have bad taste.
|
| 22 |
+
# Best strategy is to pick a mix of both to ensure long-term
|
| 23 |
+
# engagement.
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def lts_user_model_creator(env_ctx):
    """Build the LTS user model from an env-config dict (needs "slate_size")."""
    slate_size = env_ctx["slate_size"]
    return lts.LTSUserModel(
        slate_size,
        user_state_ctor=lts.LTSUserState,
        response_model_ctor=lts.LTSResponse,
    )
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def lts_document_sampler_creator(env_ctx):
    """Build the document sampler for the long-term-satisfaction env.

    The env config `env_ctx` is accepted for interface uniformity with the
    other creator functions but is not used here.
    """
    sampler = lts.LTSDocumentSampler()
    return sampler
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# RLlib-ready "long-term satisfaction" RecSim env, assembled from the LTS
# creator functions via `make_recsim_env`; reward = clicked engagement.
LongTermSatisfactionRecSimEnv = make_recsim_env(
    recsim_user_model_creator=lts_user_model_creator,
    recsim_document_sampler_creator=lts_document_sampler_creator,
    reward_aggregator=lts.clicked_engagement_reward,
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Interest exploration env: Models the problem of active exploration
# of user interests. It is meant to illustrate popularity bias in
# recommender systems, where myopic maximization of engagement leads
# to bias towards documents that have wider appeal,
# whereas niche user interests remain unexplored.
def iex_user_model_creator(env_ctx):
    """Build the IE user model from an env config ("slate_size", "seed")."""
    slate_size = env_ctx["slate_size"]
    seed = env_ctx["seed"]
    return iex.IEUserModel(
        slate_size,
        user_state_ctor=iex.IEUserState,
        response_model_ctor=iex.IEResponse,
        seed=seed,
    )
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def iex_document_sampler_creator(env_ctx):
    """Build the topic-based document sampler, seeded from the env config."""
    seed = env_ctx["seed"]
    return iex.IETopicDocumentSampler(seed=seed)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# RLlib-ready "interest exploration" RecSim env; reward = total clicks.
InterestExplorationRecSimEnv = make_recsim_env(
    recsim_user_model_creator=iex_user_model_creator,
    recsim_document_sampler_creator=iex_document_sampler_creator,
    reward_aggregator=iex.total_clicks_reward,
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# Interest evolution env: See https://github.com/google-research/recsim
# for more information.
def iev_user_model_creator(env_ctx):
    """Build the IEv user model (multinomial-proportional choice model)."""
    slate_size = env_ctx["slate_size"]
    seed = env_ctx["seed"]
    return iev.IEvUserModel(
        slate_size,
        choice_model_ctor=choice_model.MultinomialProportionalChoiceModel,
        response_model_ctor=iev.IEvResponse,
        user_state_ctor=iev.IEvUserState,
        seed=seed,
    )
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# Extend IEvVideo to fix a bug caused by None cluster_ids.
class SingleClusterIEvVideo(iev.IEvVideo):
    """IEvVideo variant that pins every document to cluster 0.

    Works around the upstream `IEvVideo` issue triggered by `None`
    cluster ids by always passing a fixed, single cluster id.
    """

    def __init__(self, doc_id, features, video_length=None, quality=None):
        # Use the Python-3 zero-argument `super()` form (the explicit
        # `super(SingleClusterIEvVideo, self)` spelling is legacy).
        super().__init__(
            doc_id=doc_id,
            features=features,
            cluster_id=0,  # single cluster.
            video_length=video_length,
            quality=quality,
        )
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def iev_document_sampler_creator(env_ctx):
    """Build the utility-model video sampler, seeded from the env config."""
    seed = env_ctx["seed"]
    return iev.UtilityModelVideoSampler(doc_ctor=iev.IEvVideo, seed=seed)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# RLlib-ready "interest evolution" RecSim env; reward = clicked watch time.
InterestEvolutionRecSimEnv = make_recsim_env(
    recsim_user_model_creator=iev_user_model_creator,
    recsim_document_sampler_creator=iev_document_sampler_creator,
    reward_aggregator=iev.clicked_watchtime_reward,
)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# Backward compatibility: keep the old "RecSim-v1" Tune registration name
# pointing at the interest-evolution env.
register_env(
    name="RecSim-v1", env_creator=lambda env_ctx: InterestEvolutionRecSimEnv(env_ctx)
)
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/repeat_after_me_env.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
from gymnasium.spaces import Box, Discrete
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class RepeatAfterMeEnv(gym.Env):
    """Env in which the observation at timestep minus n must be repeated.

    The agent sees a token each step and must output the token it saw
    `repeat_delay` steps earlier (0 = repeat the current token).
    """

    def __init__(self, config=None):
        config = config or {}
        # Continuous mode uses a 2-dim Box, otherwise a binary Discrete space.
        self.observation_space = (
            Box(-1.0, 1.0, (2,)) if config.get("continuous") else Discrete(2)
        )
        # Actions mirror observations: the agent answers in the same space.
        self.action_space = self.observation_space
        # Note: Set `repeat_delay` to 0 for simply repeating the seen
        # observation (no delay).
        self.delay = config.get("repeat_delay", 1)
        self.episode_len = config.get("episode_len", 100)
        self.history = []

    def reset(self, *, seed=None, options=None):
        # Pre-fill the history with zeros so a repeat target exists
        # right from the first step.
        self.history = [0] * self.delay
        return self._next_obs(), {}

    def step(self, action):
        # The token the agent was supposed to repeat (`delay` steps back).
        target = self.history[-(1 + self.delay)]

        reward = 0.0
        if isinstance(self.action_space, Box):
            # Box: -abs(diff).
            reward = -np.sum(np.abs(action - target))
        elif isinstance(self.action_space, Discrete):
            # Discrete: +1.0 if exact match, -1.0 otherwise.
            reward = 1.0 if action == target else -1.0

        terminated = truncated = len(self.history) > self.episode_len
        return self._next_obs(), reward, terminated, truncated, {}

    def _next_obs(self):
        """Sample a new token, record it in the history, and return it."""
        if isinstance(self.observation_space, Box):
            new_token = np.random.random(size=(2,))
        else:
            new_token = np.random.choice([0, 1])
        self.history.append(new_token)
        return new_token
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/repeat_initial_obs_env.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
from gymnasium.spaces import Discrete
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class RepeatInitialObsEnv(gym.Env):
    """Env in which the initial observation has to be repeated all the time.

    Runs for n steps.
    r=1 if action correct, -1 otherwise (max. R=100).
    """

    def __init__(self, episode_len=100):
        self.observation_space = Discrete(2)
        self.action_space = Discrete(2)
        # The token drawn at reset time that must be repeated.
        self.token = None
        self.episode_len = episode_len
        self.num_steps = 0

    def reset(self, *, seed=None, options=None):
        # Draw the token the agent must keep repeating this episode.
        self.token = random.choice([0, 1])
        self.num_steps = 0
        return self.token, {}

    def step(self, action):
        # +1 for matching the initial token, -1 otherwise.
        reward = 1 if action == self.token else -1
        self.num_steps += 1
        terminated = truncated = self.num_steps >= self.episode_len
        # After the reset observation, the agent always observes 0.
        return 0, reward, terminated, truncated, {}
|
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/simple_corridor.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
from gymnasium.spaces import Box, Discrete
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SimpleCorridor(gym.Env):
    """Example of a custom env in which you have to walk down a corridor.

    You can configure the length of the corridor via the env config.
    Action 0 moves left (clipped at position 0), action 1 moves right;
    reaching `end_pos` terminates the episode with reward 1.0, every
    other step costs -0.01.
    """

    def __init__(self, config=None):
        config = config or {}
        self.action_space = Discrete(2)
        self.observation_space = Box(0.0, 999.0, shape=(1,), dtype=np.float32)
        self.set_corridor_length(config.get("corridor_length", 10))
        self._cur_pos = 0

    def set_corridor_length(self, length):
        """Set the corridor's end (goal) position to `length` (max 999)."""
        self.end_pos = length
        print(f"Set corridor length to {self.end_pos}")
        assert self.end_pos <= 999, "The maximum `corridor_length` allowed is 999!"

    def reset(self, *, seed=None, options=None):
        self._cur_pos = 0.0
        return self._get_obs(), {}

    def step(self, action):
        assert action in [0, 1], action
        # 1 = walk right; 0 = walk left (no-op when already at the start).
        if action == 1:
            self._cur_pos += 1.0
        elif self._cur_pos > 0:
            self._cur_pos -= 1.0
        terminated = self._cur_pos >= self.end_pos
        # Small per-step penalty, +1.0 on reaching the corridor's end.
        reward = 1.0 if terminated else -0.01
        return self._get_obs(), reward, terminated, False, {}

    def _get_obs(self):
        """Return the current position as a 1-dim float32 array."""
        return np.array([self._cur_pos], np.float32)
|