koichi12 commited on
Commit
ead37c9
·
verified ·
1 Parent(s): 5f20f96

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/__pycache__/__init__.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/__pycache__/utils.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/__pycache__/vtrace_torch.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/__init__.py +0 -0
  5. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/__pycache__/__init__.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/__pycache__/impala_torch_learner.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/__pycache__/vtrace_torch_v2.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/impala_torch_learner.py +164 -0
  9. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/vtrace_torch_v2.py +168 -0
  10. .venv/lib/python3.11/site-packages/ray/rllib/examples/__init__.py +0 -0
  11. .venv/lib/python3.11/site-packages/ray/rllib/examples/centralized_critic.py +319 -0
  12. .venv/lib/python3.11/site-packages/ray/rllib/examples/compute_adapted_gae_on_postprocess_trajectory.py +157 -0
  13. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/__init__.py +0 -0
  14. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/agents_act_in_sequence.py +87 -0
  15. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/agents_act_simultaneously.py +108 -0
  16. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/async_gym_env_vectorization.py +142 -0
  17. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/__init__.py +0 -0
  18. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/action_mask_env.py +42 -0
  19. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_crashing.py +182 -0
  20. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_sparse_rewards.py +51 -0
  21. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_with_dict_observation_space.py +74 -0
  22. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_with_large_observation_space.py +69 -0
  23. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_with_protobuf_observation_space.py +79 -0
  24. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cliff_walking_wall_env.py +71 -0
  25. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/correlated_actions_env.py +79 -0
  26. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/d4rl_env.py +46 -0
  27. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/debug_counter_env.py +92 -0
  28. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/deterministic_envs.py +13 -0
  29. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/dm_control_suite.py +131 -0
  30. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/env_using_remote_actor.py +63 -0
  31. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/env_with_subprocess.py +42 -0
  32. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/fast_image_env.py +20 -0
  33. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/gpu_requiring_env.py +37 -0
  34. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/look_and_push.py +65 -0
  35. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/memory_leaking_env.py +35 -0
  36. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/mock_env.py +220 -0
  37. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/bandit_envs_discrete.py +206 -0
  38. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/guess_the_number_game.py +89 -0
  39. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/pettingzoo_chess.py +227 -0
  40. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/pettingzoo_connect4.py +213 -0
  41. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/rock_paper_scissors.py +125 -0
  42. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/tic_tac_toe.py +144 -0
  43. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/two_step_game.py +123 -0
  44. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/nested_space_repeat_after_me_env.py +50 -0
  45. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/parametric_actions_cartpole.py +145 -0
  46. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/random_env.py +125 -0
  47. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/recommender_system_envs_with_recsim.py +108 -0
  48. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/repeat_after_me_env.py +47 -0
  49. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/repeat_initial_obs_env.py +32 -0
  50. .venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/simple_corridor.py +42 -0
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (709 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/__pycache__/utils.cpython-311.pyc ADDED
Binary file (5.03 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/__pycache__/vtrace_torch.cpython-311.pyc ADDED
Binary file (15 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (206 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/__pycache__/impala_torch_learner.cpython-311.pyc ADDED
Binary file (5.79 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/__pycache__/vtrace_torch_v2.cpython-311.pyc ADDED
Binary file (8.13 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/impala_torch_learner.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ from ray.rllib.algorithms.impala.impala import IMPALAConfig
4
+ from ray.rllib.algorithms.impala.impala_learner import IMPALALearner
5
+ from ray.rllib.algorithms.impala.torch.vtrace_torch_v2 import (
6
+ vtrace_torch,
7
+ make_time_major,
8
+ )
9
+ from ray.rllib.core.columns import Columns
10
+ from ray.rllib.core.learner.learner import ENTROPY_KEY
11
+ from ray.rllib.core.learner.torch.torch_learner import TorchLearner
12
+ from ray.rllib.utils.annotations import override
13
+ from ray.rllib.utils.framework import try_import_torch
14
+ from ray.rllib.utils.typing import ModuleID, TensorType
15
+
16
+ torch, nn = try_import_torch()
17
+
18
+
19
class IMPALATorchLearner(IMPALALearner, TorchLearner):
    """Implements the IMPALA loss function in torch."""

    @override(TorchLearner)
    def compute_loss_for_module(
        self,
        *,
        module_id: ModuleID,
        config: IMPALAConfig,
        batch: Dict,
        fwd_out: Dict[str, TensorType],
    ) -> TensorType:
        """Computes the total IMPALA (v-trace) loss for one RLModule.

        Args:
            module_id: The ID of the RLModule to compute the loss for.
            config: The IMPALAConfig applying to `module_id`.
            batch: The train batch of flat [B*T] tensors (plus "seq_lens"
                when recurrent).
            fwd_out: The module's `forward_train()` output for `batch`.

        Returns:
            The total loss (scalar tensor): mean policy loss
            + vf_loss_coeff * mean value loss
            + scheduled entropy coeff * mean entropy loss.
        """
        module = self.module[module_id].unwrapped()

        # TODO (sven): Now that we do the +1ts trick to be less vulnerable about
        # bootstrap values at the end of rollouts in the new stack, we might make
        # this a more flexible, configurable parameter for users, e.g.
        # `v_trace_seq_len` (independent of `rollout_fragment_length`). Separation
        # of concerns (sampling vs learning).
        rollout_frag_or_episode_len = config.get_rollout_fragment_length()
        recurrent_seq_len = batch.get("seq_lens")

        # Mask out invalid (padded) timesteps. `size_loss_mask` (the number of
        # valid steps) is the denominator for all the means below.
        loss_mask = batch[Columns.LOSS_MASK].float()
        loss_mask_time_major = make_time_major(
            loss_mask,
            trajectory_len=rollout_frag_or_episode_len,
            recurrent_seq_len=recurrent_seq_len,
        )
        size_loss_mask = torch.sum(loss_mask)

        # Behavior actions logp and target actions logp.
        behaviour_actions_logp = batch[Columns.ACTION_LOGP]
        target_policy_dist = module.get_train_action_dist_cls().from_logits(
            fwd_out[Columns.ACTION_DIST_INPUTS]
        )
        target_actions_logp = target_policy_dist.logp(batch[Columns.ACTIONS])

        # Values and bootstrap values.
        values = module.compute_values(
            batch, embeddings=fwd_out.get(Columns.EMBEDDINGS)
        )
        values_time_major = make_time_major(
            values,
            trajectory_len=rollout_frag_or_episode_len,
            recurrent_seq_len=recurrent_seq_len,
        )
        assert Columns.VALUES_BOOTSTRAPPED not in batch
        # Use as bootstrap values the vf-preds in the next "batch row", except
        # for the very last row (which doesn't have a next row), for which the
        # bootstrap value does not matter b/c it has a +1ts value at its end
        # anyways. So we chose an arbitrary item (for simplicity of not having to
        # move new data to the device).
        bootstrap_values = torch.cat(
            [
                values_time_major[0][1:],  # 0th ts values from "next row"
                values_time_major[0][0:1],  # <- can use any arbitrary value here
            ],
            dim=0,
        )

        # TODO(Artur): In the old impala code, actions were unsqueezed if they were
        # multi_discrete. Find out why and if we need to do the same here.
        # actions = actions if is_multidiscrete else torch.unsqueeze(actions, dim=1)
        target_actions_logp_time_major = make_time_major(
            target_actions_logp,
            trajectory_len=rollout_frag_or_episode_len,
            recurrent_seq_len=recurrent_seq_len,
        )
        behaviour_actions_logp_time_major = make_time_major(
            behaviour_actions_logp,
            trajectory_len=rollout_frag_or_episode_len,
            recurrent_seq_len=recurrent_seq_len,
        )
        rewards_time_major = make_time_major(
            batch[Columns.REWARDS],
            trajectory_len=rollout_frag_or_episode_len,
            recurrent_seq_len=recurrent_seq_len,
        )

        # The discount factor that is used should be gamma except for timesteps
        # where the episode is terminated. In that case, the discount factor
        # should be 0.
        discounts_time_major = (
            1.0
            - make_time_major(
                batch[Columns.TERMINATEDS],
                trajectory_len=rollout_frag_or_episode_len,
                recurrent_seq_len=recurrent_seq_len,
            ).type(dtype=torch.float32)
        ) * config.gamma

        # Note that vtrace will compute the main loop on the CPU for better
        # performance.
        vtrace_adjusted_target_values, pg_advantages = vtrace_torch(
            target_action_log_probs=target_actions_logp_time_major,
            behaviour_action_log_probs=behaviour_actions_logp_time_major,
            discounts=discounts_time_major,
            rewards=rewards_time_major,
            values=values_time_major,
            bootstrap_values=bootstrap_values,
            clip_rho_threshold=config.vtrace_clip_rho_threshold,
            clip_pg_rho_threshold=config.vtrace_clip_pg_rho_threshold,
        )

        # The policy gradients loss.
        pi_loss = -torch.sum(
            target_actions_logp_time_major * pg_advantages * loss_mask_time_major
        )
        mean_pi_loss = pi_loss / size_loss_mask

        # The baseline loss (squared error against the v-trace targets).
        delta = values_time_major - vtrace_adjusted_target_values
        vf_loss = 0.5 * torch.sum(torch.pow(delta, 2.0) * loss_mask_time_major)
        mean_vf_loss = vf_loss / size_loss_mask

        # The entropy loss (uses the flat, non-time-major mask).
        entropy_loss = -torch.sum(target_policy_dist.entropy() * loss_mask)
        mean_entropy_loss = entropy_loss / size_loss_mask

        # The summed weighted loss.
        total_loss = (
            mean_pi_loss
            + mean_vf_loss * config.vf_loss_coeff
            + (
                mean_entropy_loss
                * self.entropy_coeff_schedulers_per_module[
                    module_id
                ].get_current_value()
            )
        )

        # Log important loss stats.
        self.metrics.log_dict(
            {
                "pi_loss": pi_loss,
                "mean_pi_loss": mean_pi_loss,
                "vf_loss": vf_loss,
                "mean_vf_loss": mean_vf_loss,
                ENTROPY_KEY: -mean_entropy_loss,
            },
            key=module_id,
            window=1,  # <- single items (should not be mean/ema-reduced over time).
        )
        # Return the total loss.
        return total_loss
162


# Backward-compatible alias for the older class-name capitalization.
ImpalaTorchLearner = IMPALATorchLearner
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/impala/torch/vtrace_torch_v2.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union
2
+ from ray.rllib.utils.framework import try_import_torch
3
+
4
+ torch, nn = try_import_torch()
5
+
6
+
7
def make_time_major(
    tensor: Union["torch.Tensor", List["torch.Tensor"]],
    *,
    trajectory_len: int = None,
    recurrent_seq_len: int = None,
):
    """Swaps batch and trajectory axis.

    Args:
        tensor: A tensor or list of tensors to swap the axis of.
            NOTE: Each tensor must have the shape [B * T] where B is the
            batch size and T is the trajectory length.
        trajectory_len: The length of each trajectory being transformed.
            If None then `recurrent_seq_len` must be set.
        recurrent_seq_len: Sequence lengths if recurrent.
            If None then `trajectory_len` must be set.

    Returns:
        A [T, B]-shaped tensor with swapped axes, or a list of such tensors
        (one per input tensor).
    """
    if isinstance(tensor, (list, tuple)):
        # Fix: `trajectory_len` and `recurrent_seq_len` are keyword-only
        # parameters; the previous positional recursion
        # (`make_time_major(_tensor, trajectory_len, recurrent_seq_len)`)
        # raised a TypeError for every list/tuple input.
        return [
            make_time_major(
                _tensor,
                trajectory_len=trajectory_len,
                recurrent_seq_len=recurrent_seq_len,
            )
            for _tensor in tensor
        ]

    assert (
        trajectory_len is not None or recurrent_seq_len is not None
    ), "Either trajectory_len or recurrent_seq_len must be set."

    # Recurrent case: tensor is already [B, T]; just swap B and T.
    if recurrent_seq_len is not None:
        assert len(tensor.shape) == 2
        return torch.transpose(tensor, 1, 0)

    T = trajectory_len
    tensor_0 = tensor.shape[0]
    B = tensor_0 // T
    # Zero-pad the (flat) tensor if its length is not a multiple of T.
    # (Integer modulo replaces the previous float comparison
    # `B != tensor_0 / T`.)
    if tensor_0 % T != 0:
        assert len(tensor.shape) == 1
        tensor = torch.cat(
            [
                tensor,
                torch.zeros(
                    T - tensor_0 % T,
                    dtype=tensor.dtype,
                    device=tensor.device,
                ),
            ]
        )
        B += 1

    # Reshape tensor (break up B axis into 2 axes: B and T).
    tensor = torch.reshape(tensor, [B, T] + list(tensor.shape[1:]))

    # Swap B and T axes -> [T, B, ...].
    return torch.transpose(tensor, 1, 0)
70
+
71
+
72
def vtrace_torch(
    *,
    target_action_log_probs: "torch.Tensor",
    behaviour_action_log_probs: "torch.Tensor",
    discounts: "torch.Tensor",
    rewards: "torch.Tensor",
    values: "torch.Tensor",
    bootstrap_values: "torch.Tensor",
    clip_rho_threshold: Union[float, "torch.Tensor"] = 1.0,
    clip_pg_rho_threshold: Union[float, "torch.Tensor"] = 1.0,
):
    """Computes v-trace value targets and policy-gradient advantages.

    Implements the off-policy correction from "IMPALA: Scalable Distributed
    Deep-RL with Importance Weighted Actor-Learner Architectures"
    (Espeholt et al., https://arxiv.org/abs/1802.01561), using the
    dynamic-programming formulation found in DeepMind's scalable_agent
    repository (https://github.com/deepmind/scalable_agent), which minimizes
    floating-point work per step.

    All time-major tensors have shape [T, B], where T is the trajectory (or
    recurrent sequence) length and B is the batch size. "Target policy" is
    the policy being improved; "behaviour policy" generated the data.

    Args:
        target_action_log_probs: [T, B] action log-probs under the target
            policy.
        behaviour_action_log_probs: [T, B] action log-probs under the
            behaviour policy.
        discounts: [T, B] per-step discounts: 0.0 on terminal steps, gamma
            otherwise.
        rewards: [T, B] rewards collected by the behaviour policy.
        values: [T, B] value-function estimates w.r.t. the target policy.
        bootstrap_values: [B] value estimates at time T.
        clip_rho_threshold: rho-bar in the paper; clip threshold for the
            importance weights entering the value targets (None = no clip).
        clip_pg_rho_threshold: Clip threshold for the importance weights
            entering the policy-gradient advantages (None = no clip).

    Returns:
        Tuple (vs, pg_advantages), both [T, B] and detached from the graph.
    """
    # Importance-sampling ratios rho_t = pi_target / pi_behaviour.
    log_is_ratios = target_action_log_probs - behaviour_action_log_probs
    is_ratios = torch.exp(log_is_ratios)

    if clip_rho_threshold is None:
        clipped_rhos = is_ratios
    else:
        clipped_rhos = torch.clamp(is_ratios, max=clip_rho_threshold)

    # The trace coefficients c_t are always clipped at 1.0.
    cs = torch.clamp(is_ratios, max=1.0)

    # Values shifted one step into the future: [v_1, ..., v_{T-1}, v_boot].
    next_values = torch.cat([values[1:], bootstrap_values.unsqueeze(0)], dim=0)

    # rho-weighted one-step TD errors.
    td_errors = clipped_rhos * (rewards + discounts * next_values - values)

    # The backward recursion is a Python loop; run it on CPU for speed.
    discounts_cpu = discounts.to("cpu")
    cs_cpu = cs.to("cpu")
    td_errors_cpu = td_errors.to("cpu")

    acc = torch.zeros_like(bootstrap_values, device="cpu")
    corrections_reversed = []
    for t in reversed(range(len(discounts_cpu))):
        acc = td_errors_cpu[t] + discounts_cpu[t] * cs_cpu[t] * acc
        corrections_reversed.append(acc)

    # Restore time order and move back to the original device, if any.
    vs_minus_v_xs = torch.stack(corrections_reversed[::-1]).to(td_errors.device)

    # Add V(x_s) to obtain the v-trace targets v_s.
    vs = vs_minus_v_xs + values

    # Policy-gradient advantages use the targets shifted one step ahead.
    next_vs = torch.cat([vs[1:], bootstrap_values.unsqueeze(0)], dim=0)
    if clip_pg_rho_threshold is None:
        clipped_pg_rhos = is_ratios
    else:
        clipped_pg_rhos = torch.clamp(is_ratios, max=clip_pg_rho_threshold)
    pg_advantages = clipped_pg_rhos * (rewards + discounts * next_vs - values)

    # Make sure no gradients are backpropagated through the returned values.
    return vs.detach(), pg_advantages.detach()
.venv/lib/python3.11/site-packages/ray/rllib/examples/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/rllib/examples/centralized_critic.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @OldAPIStack
2
+
3
+ # ***********************************************************************************
4
+ # IMPORTANT NOTE: This script uses the old API stack and will soon be replaced by
5
+ # `ray.rllib.examples.multi_agent.pettingzoo_shared_value_function.py`!
6
+ # ***********************************************************************************
7
+
8
+ """An example of customizing PPO to leverage a centralized critic.
9
+
10
+ Here the model and policy are hard-coded to implement a centralized critic
11
+ for TwoStepGame, but you can adapt this for your own use cases.
12
+
13
+ Compared to simply running `rllib/examples/two_step_game.py --run=PPO`,
14
+ this centralized critic version reaches vf_explained_variance=1.0 more stably
15
+ since it takes into account the opponent actions as well as the policy's.
16
+ Note that this is also using two independent policies instead of weight-sharing
17
+ with one.
18
+
19
+ See also: centralized_critic_2.py for a simpler approach that instead
20
+ modifies the environment.
21
+ """
22
+
23
+ import argparse
24
+ from gymnasium.spaces import Discrete
25
+ import numpy as np
26
+ import os
27
+
28
+ import ray
29
+ from ray import air, tune
30
+ from ray.air.constants import TRAINING_ITERATION
31
+ from ray.rllib.algorithms.ppo.ppo import PPO, PPOConfig
32
+ from ray.rllib.algorithms.ppo.ppo_tf_policy import (
33
+ PPOTF1Policy,
34
+ PPOTF2Policy,
35
+ )
36
+ from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
37
+ from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing
38
+ from ray.rllib.examples.envs.classes.multi_agent.two_step_game import TwoStepGame
39
+ from ray.rllib.examples._old_api_stack.models.centralized_critic_models import (
40
+ CentralizedCriticModel,
41
+ TorchCentralizedCriticModel,
42
+ )
43
+ from ray.rllib.models import ModelCatalog
44
+ from ray.rllib.policy.sample_batch import SampleBatch
45
+ from ray.rllib.utils.annotations import override
46
+ from ray.rllib.utils.framework import try_import_tf, try_import_torch
47
+ from ray.rllib.utils.metrics import (
48
+ ENV_RUNNER_RESULTS,
49
+ EPISODE_RETURN_MEAN,
50
+ NUM_ENV_STEPS_SAMPLED_LIFETIME,
51
+ )
52
+ from ray.rllib.utils.numpy import convert_to_numpy
53
+ from ray.rllib.utils.test_utils import check_learning_achieved
54
+ from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable
55
+ from ray.rllib.utils.torch_utils import convert_to_torch_tensor
56
+
57
tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

# Extra SampleBatch columns under which each agent stores its opponent's
# observation and action.
OPPONENT_OBS = "opponent_obs"
OPPONENT_ACTION = "opponent_action"

# Command-line interface for this example script.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "torch"],
    default="torch",
    help="The DL framework specifier.",
)
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.",
)
parser.add_argument(
    "--stop-iters", type=int, default=100, help="Number of iterations to train."
)
parser.add_argument(
    "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
)
parser.add_argument(
    "--stop-reward", type=float, default=7.99, help="Reward at which we stop training."
)
+
86
+
87
class CentralizedValueMixin:
    """Adds `compute_central_vf` for evaluating the central value function."""

    def __init__(self):
        # Torch models can be called directly; TF models must be wrapped
        # into a session-backed callable first.
        if self.config["framework"] == "torch":
            self.compute_central_vf = self.model.central_value_function
        else:
            self.compute_central_vf = make_tf_callable(self.get_session())(
                self.model.central_value_function
            )
97
+
98
+
99
# Grabs the opponent obs/act and includes it in the experience train_batch,
# and computes GAE using the central vf predictions.
def centralized_critic_postprocessing(
    policy, sample_batch, other_agent_batches=None, episode=None
):
    """Adds opponent data to the batch and computes GAE with the central VF.

    Args:
        policy: The policy whose trajectory is being postprocessed.
        sample_batch: The (single-agent) trajectory SampleBatch.
        other_agent_batches: Mapping from the other agent's ID to a
            3-tuple whose last element is that agent's SampleBatch.
            Exactly one opponent is assumed (as in TwoStepGame).
        episode: The episode object (unused here).

    Returns:
        The train batch with OPPONENT_OBS / OPPONENT_ACTION columns filled
        in and advantages computed from the centralized VF predictions.
    """
    pytorch = policy.config["framework"] == "torch"
    if (pytorch and hasattr(policy, "compute_central_vf")) or (
        not pytorch and policy.loss_initialized()
    ):
        assert other_agent_batches is not None
        # Destructuring assumes exactly one opponent batch is present.
        [(_, _, opponent_batch)] = list(other_agent_batches.values())

        # Also record the opponent obs and actions in the trajectory.
        sample_batch[OPPONENT_OBS] = opponent_batch[SampleBatch.CUR_OBS]
        sample_batch[OPPONENT_ACTION] = opponent_batch[SampleBatch.ACTIONS]

        # Overwrite default VF prediction with the central VF.
        # Fix: branch on the policy's own config (`pytorch`) instead of the
        # module-global `args`, which only exists when this file is run as a
        # script — using `args` here raised a NameError on library import.
        if pytorch:
            sample_batch[SampleBatch.VF_PREDS] = (
                policy.compute_central_vf(
                    convert_to_torch_tensor(
                        sample_batch[SampleBatch.CUR_OBS], policy.device
                    ),
                    convert_to_torch_tensor(sample_batch[OPPONENT_OBS], policy.device),
                    convert_to_torch_tensor(
                        sample_batch[OPPONENT_ACTION], policy.device
                    ),
                )
                .cpu()
                .detach()
                .numpy()
            )
        else:
            sample_batch[SampleBatch.VF_PREDS] = convert_to_numpy(
                policy.compute_central_vf(
                    sample_batch[SampleBatch.CUR_OBS],
                    sample_batch[OPPONENT_OBS],
                    sample_batch[OPPONENT_ACTION],
                )
            )
    else:
        # Policy hasn't been initialized yet, use zeros.
        sample_batch[OPPONENT_OBS] = np.zeros_like(sample_batch[SampleBatch.CUR_OBS])
        sample_batch[OPPONENT_ACTION] = np.zeros_like(sample_batch[SampleBatch.ACTIONS])
        sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(
            sample_batch[SampleBatch.REWARDS], dtype=np.float32
        )

    # Bootstrap with 0.0 on terminated episodes, else with the last central
    # VF prediction.
    completed = sample_batch[SampleBatch.TERMINATEDS][-1]
    if completed:
        last_r = 0.0
    else:
        last_r = sample_batch[SampleBatch.VF_PREDS][-1]

    train_batch = compute_advantages(
        sample_batch,
        last_r,
        policy.config["gamma"],
        policy.config["lambda"],
        use_gae=policy.config["use_gae"],
    )
    return train_batch
161
+
162
+
163
# Copied from PPO but optimizing the central value function.
def loss_with_central_critic(policy, base_policy, model, dist_class, train_batch):
    """Computes the base PPO loss with the centralized value function.

    Temporarily patches `model.value_function` to return the central VF
    output, runs the base policy's loss, then restores the original
    value-function method.
    """
    original_vf = model.value_function

    # Patch in the central VF (fed with own obs plus opponent obs/action).
    model.value_function = lambda: policy.model.central_value_function(
        train_batch[SampleBatch.CUR_OBS],
        train_batch[OPPONENT_OBS],
        train_batch[OPPONENT_ACTION],
    )
    # Cache the central VF output so stats_fn can report on it.
    policy._central_value_out = model.value_function()

    loss = base_policy.loss(model, dist_class, train_batch)

    # Undo the patch.
    model.value_function = original_vf

    return loss
181
+
182
+
183
def central_vf_stats(policy, train_batch):
    """Returns the explained variance of the central value function."""
    explained_var = explained_variance(
        train_batch[Postprocessing.VALUE_TARGETS], policy._central_value_out
    )
    return {"vf_explained_var": explained_var}
190
+
191
+
192
def get_ccppo_policy(base):
    """Builds a centralized-critic PPO policy class on a given TF base.

    Args:
        base: The TF PPO policy class (PPOTF1Policy or PPOTF2Policy) to
            subclass.

    Returns:
        A new policy class with CentralizedValueMixin mixed into `base`.
    """

    class CCPPOTFPolicy(CentralizedValueMixin, base):
        def __init__(self, observation_space, action_space, config):
            base.__init__(self, observation_space, action_space, config)
            CentralizedValueMixin.__init__(self)

        @override(base)
        def loss(self, model, dist_class, train_batch):
            # Use super() to get to the base PPO policy.
            # This special loss function utilizes a shared
            # value function defined on self, and the loss function
            # defined on PPO policies.
            return loss_with_central_critic(
                self, super(), model, dist_class, train_batch
            )

        @override(base)
        def postprocess_trajectory(
            self, sample_batch, other_agent_batches=None, episode=None
        ):
            # Record opponent obs/actions and recompute GAE centrally.
            return centralized_critic_postprocessing(
                self, sample_batch, other_agent_batches, episode
            )

        @override(base)
        def stats_fn(self, train_batch: SampleBatch):
            # Extend base stats with the central-VF explained variance.
            stats = super().stats_fn(train_batch)
            stats.update(central_vf_stats(self, train_batch))
            return stats

    return CCPPOTFPolicy
223
+
224
+
225
# Concrete TF policy classes: static-graph (tf1) and eager (tf2).
CCPPOStaticGraphTFPolicy = get_ccppo_policy(PPOTF1Policy)
CCPPOEagerTFPolicy = get_ccppo_policy(PPOTF2Policy)


class CCPPOTorchPolicy(CentralizedValueMixin, PPOTorchPolicy):
    """Torch PPO policy that trains against the centralized critic."""

    def __init__(self, observation_space, action_space, config):
        PPOTorchPolicy.__init__(self, observation_space, action_space, config)
        CentralizedValueMixin.__init__(self)

    @override(PPOTorchPolicy)
    def loss(self, model, dist_class, train_batch):
        # Swap in the central VF for the base PPO loss computation.
        return loss_with_central_critic(self, super(), model, dist_class, train_batch)

    @override(PPOTorchPolicy)
    def postprocess_trajectory(
        self, sample_batch, other_agent_batches=None, episode=None
    ):
        # Record opponent obs/actions and recompute GAE centrally.
        return centralized_critic_postprocessing(
            self, sample_batch, other_agent_batches, episode
        )
245
+
246
+
247
class CentralizedCritic(PPO):
    """PPO algorithm whose default policies use the centralized critic."""

    @classmethod
    @override(PPO)
    def get_default_policy_class(cls, config):
        # Pick the policy implementation matching the configured framework.
        framework = config["framework"]
        if framework == "torch":
            return CCPPOTorchPolicy
        if framework == "tf":
            return CCPPOStaticGraphTFPolicy
        # Anything else (i.e. "tf2") -> eager TF policy.
        return CCPPOEagerTFPolicy
257
+
258
+
259
if __name__ == "__main__":
    # NOTE(review): local_mode=True runs all Ray tasks serially in-process;
    # presumably a debugging convenience in this example.
    ray.init(local_mode=True)
    args = parser.parse_args()

    # Register the framework-appropriate centralized-critic model under the
    # name referenced by `custom_model` in the training config below.
    ModelCatalog.register_custom_model(
        "cc_model",
        TorchCentralizedCriticModel
        if args.framework == "torch"
        else CentralizedCriticModel,
    )

    config = (
        PPOConfig()
        # Explicitly stay on the old API stack (this example is @OldAPIStack).
        .api_stack(
            enable_env_runner_and_connector_v2=False,
            enable_rl_module_and_learner=False,
        )
        .environment(TwoStepGame)
        .framework(args.framework)
        .env_runners(batch_mode="complete_episodes", num_env_runners=0)
        .training(model={"custom_model": "cc_model"})
        .multi_agent(
            # Two independent policies (no weight sharing).
            policies={
                "pol1": (
                    None,
                    Discrete(6),
                    TwoStepGame.action_space,
                    # `framework` would also be ok here.
                    PPOConfig.overrides(framework_str=args.framework),
                ),
                "pol2": (
                    None,
                    Discrete(6),
                    TwoStepGame.action_space,
                    # `framework` would also be ok here.
                    PPOConfig.overrides(framework_str=args.framework),
                ),
            },
            # Agent 0 -> "pol1"; every other agent -> "pol2".
            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "pol1"
            if agent_id == 0
            else "pol2",
        )
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
    )

    # Stop when any one of these thresholds is reached.
    stop = {
        TRAINING_ITERATION: args.stop_iters,
        NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps,
        f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward,
    }

    tuner = tune.Tuner(
        CentralizedCritic,
        param_space=config.to_dict(),
        run_config=air.RunConfig(stop=stop, verbose=1),
    )
    results = tuner.fit()

    # In test mode, raise if the stop-reward was never reached.
    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
.venv/lib/python3.11/site-packages/ray/rllib/examples/compute_adapted_gae_on_postprocess_trajectory.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @OldAPIStack
2
+
3
+ """
4
+ Adapted (time-dependent) GAE for PPO algorithm that you can activate by setting
5
+ use_adapted_gae=True in the policy config. Additionally, it's required that
6
+ "callbacks" include the custom callback class in the Algorithm's config.
7
+ Furthermore, the env must return in its info dictionary a key-value pair of
8
+ the form "d_ts": ... where the value is the length (time) of recent agent step.
9
+
10
+ This adapted, time-dependent computation of advantages may be useful in cases
11
+ where agent's actions take various times and thus time steps are not
12
+ equidistant (https://docdro.id/400TvlR)
13
+ """
14
+
15
+ from ray.rllib.callbacks.callbacks import RLlibCallback
16
+ from ray.rllib.policy.sample_batch import SampleBatch
17
+ from ray.rllib.evaluation.postprocessing import Postprocessing
18
+ from ray.rllib.utils.annotations import override
19
+ import numpy as np
20
+
21
+
22
+ class MyCallbacks(RLlibCallback):
23
+ @override(RLlibCallback)
24
+ def on_postprocess_trajectory(
25
+ self,
26
+ *,
27
+ worker,
28
+ episode,
29
+ agent_id,
30
+ policy_id,
31
+ policies,
32
+ postprocessed_batch,
33
+ original_batches,
34
+ **kwargs
35
+ ):
36
+ super().on_postprocess_trajectory(
37
+ worker=worker,
38
+ episode=episode,
39
+ agent_id=agent_id,
40
+ policy_id=policy_id,
41
+ policies=policies,
42
+ postprocessed_batch=postprocessed_batch,
43
+ original_batches=original_batches,
44
+ **kwargs
45
+ )
46
+
47
+ if policies[policy_id].config.get("use_adapted_gae", False):
48
+ policy = policies[policy_id]
49
+ assert policy.config[
50
+ "use_gae"
51
+ ], "Can't use adapted gae without use_gae=True!"
52
+
53
+ info_dicts = postprocessed_batch[SampleBatch.INFOS]
54
+ assert np.all(
55
+ ["d_ts" in info_dict for info_dict in info_dicts]
56
+ ), "Info dicts in sample batch must contain data 'd_ts' \
57
+ (=ts[i+1]-ts[i] length of time steps)!"
58
+
59
+ d_ts = np.array(
60
+ [np.float(info_dict.get("d_ts")) for info_dict in info_dicts]
61
+ )
62
+ assert np.all(
63
+ [e.is_integer() for e in d_ts]
64
+ ), "Elements of 'd_ts' (length of time steps) must be integer!"
65
+
66
+ # Trajectory is actually complete -> last r=0.0.
67
+ if postprocessed_batch[SampleBatch.TERMINATEDS][-1]:
68
+ last_r = 0.0
69
+ # Trajectory has been truncated -> last r=VF estimate of last obs.
70
+ else:
71
+ # Input dict is provided to us automatically via the Model's
72
+ # requirements. It's a single-timestep (last one in trajectory)
73
+ # input_dict.
74
+ # Create an input dict according to the Model's requirements.
75
+ input_dict = postprocessed_batch.get_single_step_input_dict(
76
+ policy.model.view_requirements, index="last"
77
+ )
78
+ last_r = policy._value(**input_dict)
79
+
80
+ gamma = policy.config["gamma"]
81
+ lambda_ = policy.config["lambda"]
82
+
83
+ vpred_t = np.concatenate(
84
+ [postprocessed_batch[SampleBatch.VF_PREDS], np.array([last_r])]
85
+ )
86
+ delta_t = (
87
+ postprocessed_batch[SampleBatch.REWARDS]
88
+ + gamma**d_ts * vpred_t[1:]
89
+ - vpred_t[:-1]
90
+ )
91
+ # This formula for the advantage is an adaption of
92
+ # "Generalized Advantage Estimation"
93
+ # (https://arxiv.org/abs/1506.02438) which accounts for time steps
94
+ # of irregular length (see proposal here ).
95
+ # NOTE: last time step delta is not required
96
+ postprocessed_batch[
97
+ Postprocessing.ADVANTAGES
98
+ ] = generalized_discount_cumsum(delta_t, d_ts[:-1], gamma * lambda_)
99
+ postprocessed_batch[Postprocessing.VALUE_TARGETS] = (
100
+ postprocessed_batch[Postprocessing.ADVANTAGES]
101
+ + postprocessed_batch[SampleBatch.VF_PREDS]
102
+ ).astype(np.float32)
103
+
104
+ postprocessed_batch[Postprocessing.ADVANTAGES] = postprocessed_batch[
105
+ Postprocessing.ADVANTAGES
106
+ ].astype(np.float32)
107
+
108
+
109
+ def generalized_discount_cumsum(
110
+ x: np.ndarray, deltas: np.ndarray, gamma: float
111
+ ) -> np.ndarray:
112
+ """Calculates the 'time-dependent' discounted cumulative sum over a
113
+ (reward) sequence `x`.
114
+
115
+ Recursive equations:
116
+
117
+ y[t] - gamma**deltas[t+1]*y[t+1] = x[t]
118
+
119
+ reversed(y)[t] - gamma**reversed(deltas)[t-1]*reversed(y)[t-1] =
120
+ reversed(x)[t]
121
+
122
+ Args:
123
+ x (np.ndarray): A sequence of rewards or one-step TD residuals.
124
+ deltas (np.ndarray): A sequence of time step deltas (length of time
125
+ steps).
126
+ gamma: The discount factor gamma.
127
+
128
+ Returns:
129
+ np.ndarray: The sequence containing the 'time-dependent' discounted
130
+ cumulative sums for each individual element in `x` till the end of
131
+ the trajectory.
132
+
133
+ .. testcode::
134
+ :skipif: True
135
+
136
+ x = np.array([0.0, 1.0, 2.0, 3.0])
137
+ deltas = np.array([1.0, 4.0, 15.0])
138
+ gamma = 0.9
139
+ generalized_discount_cumsum(x, deltas, gamma)
140
+
141
+ .. testoutput::
142
+
143
+ array([0.0 + 0.9^1.0*1.0 + 0.9^4.0*2.0 + 0.9^15.0*3.0,
144
+ 1.0 + 0.9^4.0*2.0 + 0.9^15.0*3.0,
145
+ 2.0 + 0.9^15.0*3.0,
146
+ 3.0])
147
+ """
148
+ reversed_x = x[::-1]
149
+ reversed_deltas = deltas[::-1]
150
+ reversed_y = np.empty_like(x)
151
+ reversed_y[0] = reversed_x[0]
152
+ for i in range(1, x.size):
153
+ reversed_y[i] = (
154
+ reversed_x[i] + gamma ** reversed_deltas[i - 1] * reversed_y[i - 1]
155
+ )
156
+
157
+ return reversed_y[::-1]
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/agents_act_in_sequence.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Example of running a multi-agent experiment w/ agents taking turns (sequence).
2
+
3
+ This example:
4
+ - demonstrates how to write your own (multi-agent) environment using RLlib's
5
+ MultiAgentEnv API.
6
+ - shows how to implement the `reset()` and `step()` methods of the env such that
7
+ the agents act in a fixed sequence (taking turns).
8
+ - shows how to configure and setup this environment class within an RLlib
9
+ Algorithm config.
10
+ - runs the experiment with the configured algo, trying to solve the environment.
11
+
12
+
13
+ How to run this script
14
+ ----------------------
15
+ `python [script file name].py --enable-new-api-stack`
16
+
17
+ For debugging, use the following additional command line options
18
+ `--no-tune --num-env-runners=0`
19
+ which should allow you to set breakpoints anywhere in the RLlib code and
20
+ have the execution stop there for inspection and debugging.
21
+
22
+ For logging to your WandB account, use:
23
+ `--wandb-key=[your WandB API key] --wandb-project=[some project name]
24
+ --wandb-run-name=[optional: WandB run name (within the defined project)]`
25
+
26
+
27
+ Results to expect
28
+ -----------------
29
+ You should see results similar to the following in your console output:
30
+ +---------------------------+----------+--------+------------------+--------+
31
+ | Trial name | status | iter | total time (s) | ts |
32
+ |---------------------------+----------+--------+------------------+--------+
33
+ | PPO_TicTacToe_957aa_00000 | RUNNING | 25 | 96.7452 | 100000 |
34
+ +---------------------------+----------+--------+------------------+--------+
35
+ +-------------------+------------------+------------------+
36
+ | combined return | return player2 | return player1 |
37
+ |-------------------+------------------+------------------|
38
+ | -2 | 1.15 | -0.85 |
39
+ +-------------------+------------------+------------------+
40
+
41
+ Note that even though we are playing a zero-sum game, the overall return should start
42
+ at some negative values due to the misplacement penalty of our (simplified) TicTacToe
43
+ game.
44
+ """
45
+ from ray.rllib.examples.envs.classes.multi_agent.tic_tac_toe import TicTacToe
46
+ from ray.rllib.utils.test_utils import (
47
+ add_rllib_example_script_args,
48
+ run_rllib_example_script_experiment,
49
+ )
50
+ from ray.tune.registry import get_trainable_cls, register_env # noqa
51
+
52
+
53
+ parser = add_rllib_example_script_args(
54
+ default_reward=-4.0, default_iters=50, default_timesteps=100000
55
+ )
56
+ parser.set_defaults(
57
+ enable_new_api_stack=True,
58
+ num_agents=2,
59
+ )
60
+
61
+
62
+ if __name__ == "__main__":
63
+ args = parser.parse_args()
64
+
65
+ assert args.num_agents == 2, "Must set --num-agents=2 when running this script!"
66
+
67
+ # You can also register the env creator function explicitly with:
68
+ # register_env("tic_tac_toe", lambda cfg: TicTacToe())
69
+
70
+ # Or allow the RLlib user to set more c'tor options via their algo config:
71
+ # config.environment(env_config={[c'tor arg name]: [value]})
72
+ # register_env("tic_tac_toe", lambda cfg: TicTacToe(cfg))
73
+
74
+ base_config = (
75
+ get_trainable_cls(args.algo)
76
+ .get_default_config()
77
+ .environment(TicTacToe)
78
+ .multi_agent(
79
+ # Define two policies.
80
+ policies={"player1", "player2"},
81
+ # Map agent "player1" to policy "player1" and agent "player2" to policy
82
+ # "player2".
83
+ policy_mapping_fn=lambda agent_id, episode, **kw: agent_id,
84
+ )
85
+ )
86
+
87
+ run_rllib_example_script_experiment(base_config, args)
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/agents_act_simultaneously.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Example of running a multi-agent experiment w/ agents always acting simultaneously.
2
+
3
+ This example:
4
+ - demonstrates how to write your own (multi-agent) environment using RLlib's
5
+ MultiAgentEnv API.
6
+ - shows how to implement the `reset()` and `step()` methods of the env such that
7
+ the agents act simultaneously.
8
+ - shows how to configure and setup this environment class within an RLlib
9
+ Algorithm config.
10
+ - runs the experiment with the configured algo, trying to solve the environment.
11
+
12
+
13
+ How to run this script
14
+ ----------------------
15
+ `python [script file name].py --enable-new-api-stack --sheldon-cooper-mode`
16
+
17
+ For debugging, use the following additional command line options
18
+ `--no-tune --num-env-runners=0`
19
+ which should allow you to set breakpoints anywhere in the RLlib code and
20
+ have the execution stop there for inspection and debugging.
21
+
22
+ For logging to your WandB account, use:
23
+ `--wandb-key=[your WandB API key] --wandb-project=[some project name]
24
+ --wandb-run-name=[optional: WandB run name (within the defined project)]`
25
+
26
+
27
+ Results to expect
28
+ -----------------
29
+ You should see results similar to the following in your console output:
30
+
31
+ +-----------------------------------+----------+--------+------------------+-------+
32
+ | Trial name | status | iter | total time (s) | ts |
33
+ |-----------------------------------+----------+--------+------------------+-------+
34
+ | PPO_RockPaperScissors_8cef7_00000 | RUNNING | 3 | 16.5348 | 12000 |
35
+ +-----------------------------------+----------+--------+------------------+-------+
36
+ +-------------------+------------------+------------------+
37
+ | combined return | return player2 | return player1 |
38
+ |-------------------+------------------+------------------|
39
+ | 0 | -0.15 | 0.15 |
40
+ +-------------------+------------------+------------------+
41
+
42
+ Note that b/c we are playing a zero-sum game, the overall return remains 0.0 at
43
+ all times.
44
+ """
45
+ from ray.rllib.examples.envs.classes.multi_agent.rock_paper_scissors import (
46
+ RockPaperScissors,
47
+ )
48
+ from ray.rllib.connectors.env_to_module.flatten_observations import FlattenObservations
49
+ from ray.rllib.utils.test_utils import (
50
+ add_rllib_example_script_args,
51
+ run_rllib_example_script_experiment,
52
+ )
53
+ from ray.tune.registry import get_trainable_cls, register_env # noqa
54
+
55
+
56
+ parser = add_rllib_example_script_args(
57
+ default_reward=0.9, default_iters=50, default_timesteps=100000
58
+ )
59
+ parser.set_defaults(
60
+ enable_new_api_stack=True,
61
+ num_agents=2,
62
+ )
63
+ parser.add_argument(
64
+ "--sheldon-cooper-mode",
65
+ action="store_true",
66
+ help="Whether to add two more actions to the game: Lizard and Spock. "
67
+ "Watch here for more details :) https://www.youtube.com/watch?v=x5Q6-wMx-K8",
68
+ )
69
+
70
+
71
+ if __name__ == "__main__":
72
+ args = parser.parse_args()
73
+
74
+ assert args.num_agents == 2, "Must set --num-agents=2 when running this script!"
75
+
76
+ # You can also register the env creator function explicitly with:
77
+ # register_env("env", lambda cfg: RockPaperScissors({"sheldon_cooper_mode": False}))
78
+
79
+ # Or you can hard code certain settings into the Env's constructor (`config`).
80
+ # register_env(
81
+ # "rock-paper-scissors-w-sheldon-mode-activated",
82
+ # lambda config: RockPaperScissors({**config, **{"sheldon_cooper_mode": True}}),
83
+ # )
84
+
85
+ # Or allow the RLlib user to set more c'tor options via their algo config:
86
+ # config.environment(env_config={[c'tor arg name]: [value]})
87
+ # register_env("rock-paper-scissors", lambda cfg: RockPaperScissors(cfg))
88
+
89
+ base_config = (
90
+ get_trainable_cls(args.algo)
91
+ .get_default_config()
92
+ .environment(
93
+ RockPaperScissors,
94
+ env_config={"sheldon_cooper_mode": args.sheldon_cooper_mode},
95
+ )
96
+ .env_runners(
97
+ env_to_module_connector=lambda env: FlattenObservations(multi_agent=True),
98
+ )
99
+ .multi_agent(
100
+ # Define two policies.
101
+ policies={"player1", "player2"},
102
+ # Map agent "player1" to policy "player1" and agent "player2" to policy
103
+ # "player2".
104
+ policy_mapping_fn=lambda agent_id, episode, **kw: agent_id,
105
+ )
106
+ )
107
+
108
+ run_rllib_example_script_experiment(base_config, args)
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/async_gym_env_vectorization.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Example demo'ing async gym vector envs, in which sub-envs have their own process.
2
+
3
+ Setting up env vectorization works through setting the `config.num_envs_per_env_runner`
4
+ value to > 1. However, by default the n sub-environments are stepped through
5
+ sequentially, rather than in parallel.
6
+
7
+ This script shows the effect of setting the `config.gym_env_vectorize_mode` from its
8
+ default value of "SYNC" (all sub envs are located in the same EnvRunner process)
9
+ to "ASYNC" (all sub envs in each EnvRunner get their own process).
10
+
11
+ This example:
12
+ - shows, which config settings to change in order to switch from sub-envs being
13
+ stepped in sequence to each sub-envs owning its own process (and compute resource)
14
+ and thus the vector being stepped in parallel.
15
+ - shows, how this setup can increase EnvRunner performance significantly, especially
16
+ for heavier, slower environments.
17
+ - uses an artificially slow CartPole-v1 environment for demonstration purposes.
18
+
19
+
20
+ How to run this script
21
+ ----------------------
22
+ `python [script file name].py --enable-new-api-stack `
23
+
24
+ Use the `--vectorize-mode=BOTH` option to run both modes (SYNC and ASYNC)
25
+ through Tune at the same time and get a better comparison of the throughputs
26
+ achieved.
27
+
28
+ For debugging, use the following additional command line options
29
+ `--no-tune --num-env-runners=0`
30
+ which should allow you to set breakpoints anywhere in the RLlib code and
31
+ have the execution stop there for inspection and debugging.
32
+
33
+ For logging to your WandB account, use:
34
+ `--wandb-key=[your WandB API key] --wandb-project=[some project name]
35
+ --wandb-run-name=[optional: WandB run name (within the defined project)]`
36
+
37
+
38
+ Results to expect
39
+ -----------------
40
+ You should see results similar to the following in your console output
41
+ when using the
42
+
43
+ +--------------------------+------------+------------------------+------+
44
+ | Trial name | status | gym_env_vectorize_mode | iter |
45
+ | | | | |
46
+ |--------------------------+------------+------------------------+------+
47
+ | PPO_slow-env_6ddf4_00000 | TERMINATED | SYNC | 4 |
48
+ | PPO_slow-env_6ddf4_00001 | TERMINATED | ASYNC | 4 |
49
+ +--------------------------+------------+------------------------+------+
50
+ +------------------+----------------------+------------------------+
51
+ | total time (s) | episode_return_mean | num_env_steps_sample |
52
+ | | | d_lifetime |
53
+ |------------------+----------------------+------------------------+
54
+ | 60.8794 | 73.53 | 16040 |
55
+ | 19.1203 | 73.86 | 16037 |
56
+ +------------------+----------------------+------------------------+
57
+
58
+ You can see that the ASYNC mode, given that the env is sufficiently slow,
59
+ achieves much better results when using vectorization.
60
+
61
+ You should see no difference, however, when only using
62
+ `--num-envs-per-env-runner=1`.
63
+ """
64
+ import time
65
+
66
+ import gymnasium as gym
67
+
68
+ from ray.rllib.algorithms.ppo import PPOConfig
69
+ from ray.rllib.utils.test_utils import (
70
+ add_rllib_example_script_args,
71
+ run_rllib_example_script_experiment,
72
+ )
73
+ from ray import tune
74
+
75
+ parser = add_rllib_example_script_args(default_reward=60.0)
76
+ parser.set_defaults(
77
+ enable_new_api_stack=True,
78
+ env="CartPole-v1",
79
+ num_envs_per_env_runner=6,
80
+ )
81
+ parser.add_argument(
82
+ "--vectorize-mode",
83
+ type=str,
84
+ default="ASYNC",
85
+ help="The value `gym.envs.registration.VectorizeMode` to use for env "
86
+ "vectorization. SYNC steps through all sub-envs in sequence. ASYNC (default) "
87
+ "parallelizes sub-envs through multiprocessing and can speed up EnvRunners "
88
+ "significantly. Use the special value `BOTH` to run both ASYNC and SYNC through a "
89
+ "Tune grid-search.",
90
+ )
91
+
92
+
93
+ class SlowEnv(gym.ObservationWrapper):
94
+ def observation(self, observation):
95
+ time.sleep(0.005)
96
+ return observation
97
+
98
+
99
+ if __name__ == "__main__":
100
+ args = parser.parse_args()
101
+
102
+ if args.no_tune and args.vectorize_mode == "BOTH":
103
+ raise ValueError(
104
+ "Can't run this script with both --no-tune and --vectorize-mode=BOTH!"
105
+ )
106
+
107
+ # Wrap the env with the slowness wrapper.
108
+ def _env_creator(cfg):
109
+ return SlowEnv(gym.make(args.env, **cfg))
110
+
111
+ tune.register_env("slow-env", _env_creator)
112
+
113
+ if args.vectorize_mode == "BOTH" and args.no_tune:
114
+ raise ValueError(
115
+ "`--vectorize-mode=BOTH` and `--no-tune` not allowed in combination!"
116
+ )
117
+
118
+ base_config = (
119
+ PPOConfig()
120
+ .environment("slow-env")
121
+ .env_runners(
122
+ gym_env_vectorize_mode=(
123
+ tune.grid_search(["SYNC", "ASYNC"])
124
+ if args.vectorize_mode == "BOTH"
125
+ else args.vectorize_mode
126
+ ),
127
+ )
128
+ )
129
+
130
+ results = run_rllib_example_script_experiment(base_config, args)
131
+
132
+ # Compare the throughputs and assert that ASYNC is much faster than SYNC.
133
+ if args.vectorize_mode == "BOTH":
134
+ throughput_sync = (
135
+ results[0].metrics["num_env_steps_sampled_lifetime"]
136
+ / results[0].metrics["time_total_s"]
137
+ )
138
+ throughput_async = (
139
+ results[1].metrics["num_env_steps_sampled_lifetime"]
140
+ / results[1].metrics["time_total_s"]
141
+ )
142
+ assert throughput_async > throughput_sync
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/action_mask_env.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gymnasium.spaces import Box, Dict, Discrete
2
+ import numpy as np
3
+
4
+ from ray.rllib.examples.envs.classes.random_env import RandomEnv
5
+
6
+
7
+ class ActionMaskEnv(RandomEnv):
8
+ """A randomly acting environment that publishes an action-mask each step."""
9
+
10
+ def __init__(self, config):
11
+ super().__init__(config)
12
+ # Masking only works for Discrete actions.
13
+ assert isinstance(self.action_space, Discrete)
14
+ # Add action_mask to observations.
15
+ self.observation_space = Dict(
16
+ {
17
+ "action_mask": Box(0.0, 1.0, shape=(self.action_space.n,)),
18
+ "observations": self.observation_space,
19
+ }
20
+ )
21
+ self.valid_actions = None
22
+
23
+ def reset(self, *, seed=None, options=None):
24
+ obs, info = super().reset()
25
+ self._fix_action_mask(obs)
26
+ return obs, info
27
+
28
+ def step(self, action):
29
+ # Check whether action is valid.
30
+ if not self.valid_actions[action]:
31
+ raise ValueError(
32
+ f"Invalid action ({action}) sent to env! "
33
+ f"valid_actions={self.valid_actions}"
34
+ )
35
+ obs, rew, done, truncated, info = super().step(action)
36
+ self._fix_action_mask(obs)
37
+ return obs, rew, done, truncated, info
38
+
39
+ def _fix_action_mask(self, obs):
40
+ # Fix action-mask: Everything larger 0.5 is 1.0, everything else 0.0.
41
+ self.valid_actions = np.round(obs["action_mask"])
42
+ obs["action_mask"] = self.valid_actions
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_crashing.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from gymnasium.envs.classic_control import CartPoleEnv
3
+ import numpy as np
4
+ import time
5
+
6
+ from ray.rllib.examples.envs.classes.multi_agent import make_multi_agent
7
+ from ray.rllib.utils.annotations import override
8
+ from ray.rllib.utils.error import EnvError
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class CartPoleCrashing(CartPoleEnv):
14
+ """A CartPole env that crashes (or stalls) from time to time.
15
+
16
+ Useful for testing faulty sub-env (within a vectorized env) handling by
17
+ EnvRunners.
18
+
19
+ After crashing, the env expects a `reset()` call next (calling `step()` will
20
+ result in yet another error), which may or may not take a very long time to
21
+ complete. This simulates the env having to reinitialize some sub-processes, e.g.
22
+ an external connection.
23
+
24
+ The env can also be configured to stall (and do nothing during a call to `step()`)
25
+ from time to time for a configurable amount of time.
26
+ """
27
+
28
+ def __init__(self, config=None):
29
+ super().__init__()
30
+
31
+ self.config = config if config is not None else {}
32
+
33
+ # Crash probability (in each `step()`).
34
+ self.p_crash = config.get("p_crash", 0.005)
35
+ # Crash probability when `reset()` is called.
36
+ self.p_crash_reset = config.get("p_crash_reset", 0.0)
37
+ # Crash exactly after every n steps. If a 2-tuple, will uniformly sample
38
+ # crash timesteps from in between the two given values.
39
+ self.crash_after_n_steps = config.get("crash_after_n_steps")
40
+ self._crash_after_n_steps = None
41
+ assert (
42
+ self.crash_after_n_steps is None
43
+ or isinstance(self.crash_after_n_steps, int)
44
+ or (
45
+ isinstance(self.crash_after_n_steps, tuple)
46
+ and len(self.crash_after_n_steps) == 2
47
+ )
48
+ )
49
+ # Only ever crash, if on certain worker indices.
50
+ faulty_indices = config.get("crash_on_worker_indices", None)
51
+ if faulty_indices and config.worker_index not in faulty_indices:
52
+ self.p_crash = 0.0
53
+ self.p_crash_reset = 0.0
54
+ self.crash_after_n_steps = None
55
+
56
+ # Stall probability (in each `step()`).
57
+ self.p_stall = config.get("p_stall", 0.0)
58
+ # Stall probability when `reset()` is called.
59
+ self.p_stall_reset = config.get("p_stall_reset", 0.0)
60
+ # Stall exactly after every n steps.
61
+ self.stall_after_n_steps = config.get("stall_after_n_steps")
62
+ self._stall_after_n_steps = None
63
+ # Amount of time to stall. If a 2-tuple, will uniformly sample from in between
64
+ # the two given values.
65
+ self.stall_time_sec = config.get("stall_time_sec")
66
+ assert (
67
+ self.stall_time_sec is None
68
+ or isinstance(self.stall_time_sec, (int, float))
69
+ or (
70
+ isinstance(self.stall_time_sec, tuple) and len(self.stall_time_sec) == 2
71
+ )
72
+ )
73
+
74
+ # Only ever stall, if on certain worker indices.
75
+ faulty_indices = config.get("stall_on_worker_indices", None)
76
+ if faulty_indices and config.worker_index not in faulty_indices:
77
+ self.p_stall = 0.0
78
+ self.p_stall_reset = 0.0
79
+ self.stall_after_n_steps = None
80
+
81
+ # Timestep counter for the ongoing episode.
82
+ self.timesteps = 0
83
+
84
+ # Time in seconds to initialize (in this c'tor).
85
+ sample = 0.0
86
+ if "init_time_s" in config:
87
+ sample = (
88
+ config["init_time_s"]
89
+ if not isinstance(config["init_time_s"], tuple)
90
+ else np.random.uniform(
91
+ config["init_time_s"][0], config["init_time_s"][1]
92
+ )
93
+ )
94
+
95
+ print(f"Initializing crashing env (with init-delay of {sample}sec) ...")
96
+ time.sleep(sample)
97
+
98
+ # Make sure envs don't crash at the same time.
99
+ self._rng = np.random.RandomState()
100
+
101
+ @override(CartPoleEnv)
102
+ def reset(self, *, seed=None, options=None):
103
+ # Reset timestep counter for the new episode.
104
+ self.timesteps = 0
105
+ self._crash_after_n_steps = None
106
+
107
+ # Should we crash?
108
+ if self._should_crash(p=self.p_crash_reset):
109
+ raise EnvError(
110
+ f"Simulated env crash on worker={self.config.worker_index} "
111
+ f"env-idx={self.config.vector_index} during `reset()`! "
112
+ "Feel free to use any other exception type here instead."
113
+ )
114
+ # Should we stall for a while?
115
+ self._stall_if_necessary(p=self.p_stall_reset)
116
+
117
+ return super().reset()
118
+
119
+ @override(CartPoleEnv)
120
+ def step(self, action):
121
+ # Increase timestep counter for the ongoing episode.
122
+ self.timesteps += 1
123
+
124
+ # Should we crash?
125
+ if self._should_crash(p=self.p_crash):
126
+ raise EnvError(
127
+ f"Simulated env crash on worker={self.config.worker_index} "
128
+ f"env-idx={self.config.vector_index} during `step()`! "
129
+ "Feel free to use any other exception type here instead."
130
+ )
131
+ # Should we stall for a while?
132
+ self._stall_if_necessary(p=self.p_stall)
133
+
134
+ return super().step(action)
135
+
136
+ def _should_crash(self, p):
137
+ rnd = self._rng.rand()
138
+ if rnd < p:
139
+ print("Crashing due to p(crash)!")
140
+ return True
141
+ elif self.crash_after_n_steps is not None:
142
+ if self._crash_after_n_steps is None:
143
+ self._crash_after_n_steps = (
144
+ self.crash_after_n_steps
145
+ if not isinstance(self.crash_after_n_steps, tuple)
146
+ else np.random.randint(
147
+ self.crash_after_n_steps[0], self.crash_after_n_steps[1]
148
+ )
149
+ )
150
+ if self._crash_after_n_steps == self.timesteps:
151
+ print("Crashing due to n timesteps reached!")
152
+ return True
153
+
154
+ return False
155
+
156
+ def _stall_if_necessary(self, p):
157
+ stall = False
158
+ if self._rng.rand() < p:
159
+ stall = True
160
+ elif self.stall_after_n_steps is not None:
161
+ if self._stall_after_n_steps is None:
162
+ self._stall_after_n_steps = (
163
+ self.stall_after_n_steps
164
+ if not isinstance(self.stall_after_n_steps, tuple)
165
+ else np.random.randint(
166
+ self.stall_after_n_steps[0], self.stall_after_n_steps[1]
167
+ )
168
+ )
169
+ if self._stall_after_n_steps == self.timesteps:
170
+ stall = True
171
+
172
+ if stall:
173
+ sec = (
174
+ self.stall_time_sec
175
+ if not isinstance(self.stall_time_sec, tuple)
176
+ else np.random.uniform(self.stall_time_sec[0], self.stall_time_sec[1])
177
+ )
178
+ print(f" -> will stall for {sec}sec ...")
179
+ time.sleep(sec)
180
+
181
+
182
+ MultiAgentCartPoleCrashing = make_multi_agent(lambda config: CartPoleCrashing(config))
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_sparse_rewards.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+
3
+ import gymnasium as gym
4
+ import numpy as np
5
+ from gymnasium.spaces import Box, Dict, Discrete
6
+
7
+
8
+ class CartPoleSparseRewards(gym.Env):
9
+ """Wrapper for gym CartPole environment where reward is accumulated to the end."""
10
+
11
+ def __init__(self, config=None):
12
+ self.env = gym.make("CartPole-v1")
13
+ self.action_space = Discrete(2)
14
+ self.observation_space = Dict(
15
+ {
16
+ "obs": self.env.observation_space,
17
+ "action_mask": Box(
18
+ low=0, high=1, shape=(self.action_space.n,), dtype=np.int8
19
+ ),
20
+ }
21
+ )
22
+ self.running_reward = 0
23
+
24
+ def reset(self, *, seed=None, options=None):
25
+ self.running_reward = 0
26
+ obs, infos = self.env.reset()
27
+ return {
28
+ "obs": obs,
29
+ "action_mask": np.array([1, 1], dtype=np.int8),
30
+ }, infos
31
+
32
+ def step(self, action):
33
+ obs, rew, terminated, truncated, info = self.env.step(action)
34
+ self.running_reward += rew
35
+ score = self.running_reward if terminated else 0
36
+ return (
37
+ {"obs": obs, "action_mask": np.array([1, 1], dtype=np.int8)},
38
+ score,
39
+ terminated,
40
+ truncated,
41
+ info,
42
+ )
43
+
44
+ def set_state(self, state):
45
+ self.running_reward = state[1]
46
+ self.env = deepcopy(state[0])
47
+ obs = np.array(list(self.env.unwrapped.state))
48
+ return {"obs": obs, "action_mask": np.array([1, 1], dtype=np.int8)}
49
+
50
+ def get_state(self):
51
+ return deepcopy(self.env), self.running_reward
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_with_dict_observation_space.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ from gymnasium.envs.classic_control import CartPoleEnv
3
+ import numpy as np
4
+
5
+
6
+ class CartPoleWithDictObservationSpace(CartPoleEnv):
7
+ """CartPole gym environment that has a dict observation space.
8
+
9
+ However, otherwise, the information content in each observation remains the same.
10
+
11
+ https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/classic_control/cartpole.py # noqa
12
+
13
+ The new observation space looks as follows (a little quirky, but this is
14
+ for testing purposes only):
15
+
16
+ gym.spaces.Dict({
17
+ "x-pos": [x-pos],
18
+ "angular-pos": gym.spaces.Dict({"test": [angular-pos]}),
19
+ "velocs": gym.spaces.Tuple([x-veloc, angular-veloc]),
20
+ })
21
+ """
22
+
23
+ def __init__(self, config=None):
24
+ super().__init__()
25
+
26
+ # Fix our observation-space as described above.
27
+ low = self.observation_space.low
28
+ high = self.observation_space.high
29
+
30
+ # Test as many quirks and oddities as possible: Dict, Dict inside a Dict,
31
+ # Tuple inside a Dict, and both (1,)-shapes as well as ()-shapes for Boxes.
32
+ # Also add a random discrete variable here.
33
+ self.observation_space = gym.spaces.Dict(
34
+ {
35
+ "x-pos": gym.spaces.Box(low[0], high[0], (1,), dtype=np.float32),
36
+ "angular-pos": gym.spaces.Dict(
37
+ {
38
+ "value": gym.spaces.Box(low[2], high[2], (), dtype=np.float32),
39
+ # Add some random non-essential information.
40
+ "some_random_stuff": gym.spaces.Discrete(3),
41
+ }
42
+ ),
43
+ "velocs": gym.spaces.Tuple(
44
+ [
45
+ # x-veloc
46
+ gym.spaces.Box(low[1], high[1], (1,), dtype=np.float32),
47
+ # angular-veloc
48
+ gym.spaces.Box(low[3], high[3], (), dtype=np.float32),
49
+ ]
50
+ ),
51
+ }
52
+ )
53
+
54
+ def step(self, action):
55
+ next_obs, reward, done, truncated, info = super().step(action)
56
+ return self._compile_current_obs(next_obs), reward, done, truncated, info
57
+
58
+ def reset(self, *, seed=None, options=None):
59
+ init_obs, init_info = super().reset(seed=seed, options=options)
60
+ return self._compile_current_obs(init_obs), init_info
61
+
62
+ def _compile_current_obs(self, original_cartpole_obs):
63
+ # original_cartpole_obs is [x-pos, x-veloc, angle, angle-veloc]
64
+ return {
65
+ "x-pos": np.array([original_cartpole_obs[0]], np.float32),
66
+ "angular-pos": {
67
+ "value": original_cartpole_obs[2],
68
+ "some_random_stuff": np.random.randint(3),
69
+ },
70
+ "velocs": (
71
+ np.array([original_cartpole_obs[1]], np.float32),
72
+ np.array(original_cartpole_obs[3], np.float32),
73
+ ),
74
+ }
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_with_large_observation_space.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ from gymnasium.envs.classic_control import CartPoleEnv
3
+ import numpy as np
4
+
5
+
6
class CartPoleWithLargeObservationSpace(CartPoleEnv):
    """CartPole with an artificially bloated Dict observation space.

    The information content of each observation is unchanged: the actual
    CartPole observation lives under the "actually-useful-stuff" key, while
    100 additional Tuple(Discrete, Box) entries contain pure noise.

    https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/classic_control/cartpole.py # noqa

    The new observation space looks as follows:

    gym.spaces.Dict({
        "0": gym.spaces.Tuple((
            gym.spaces.Discrete(100),
            gym.spaces.Box(0, 256, shape=(30,), dtype=float32),
        )),
        "1": ... (repeated up to "99")
        "actually-useful-stuff": gym.spaces.Box(low, high, (4,), float32),
    })
    """

    def __init__(self, config=None):
        super().__init__()

        # Per-dimension bounds of the original flat CartPole observation
        # ([x-pos, x-veloc, angle, angular-veloc]; velocities are +/-inf).
        low = self.observation_space.low
        high = self.observation_space.high

        # 100 pure-noise entries, each a (Discrete, Box) tuple.
        spaces = {
            str(i): gym.spaces.Tuple(
                (
                    gym.spaces.Discrete(100),
                    gym.spaces.Box(0, 256, shape=(30,), dtype=np.float32),
                )
            )
            for i in range(100)
        }
        # The actual CartPole observation. Bug fix: use the full per-dimension
        # bounds instead of broadcasting low[0]/high[0] (the x-pos bounds) to
        # all 4 dims -- velocity dims are unbounded, so observations could
        # otherwise fall outside the declared space.
        spaces["actually-useful-stuff"] = gym.spaces.Box(low, high, (4,), np.float32)
        self.observation_space = gym.spaces.Dict(spaces)

    def step(self, action):
        next_obs, reward, done, truncated, info = super().step(action)
        return self._compile_current_obs(next_obs), reward, done, truncated, info

    def reset(self, *, seed=None, options=None):
        init_obs, init_info = super().reset(seed=seed, options=options)
        return self._compile_current_obs(init_obs), init_info

    def _compile_current_obs(self, original_cartpole_obs):
        # Random samples for the noise entries; the real obs goes under
        # "actually-useful-stuff".
        return {
            str(i): self.observation_space.spaces[str(i)].sample() for i in range(100)
        } | {"actually-useful-stuff": original_cartpole_obs}
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cartpole_with_protobuf_observation_space.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ from gymnasium.envs.classic_control import CartPoleEnv
3
+ import numpy as np
4
+
5
+ from ray.rllib.examples.envs.classes.utils.cartpole_observations_proto import (
6
+ CartPoleObservation,
7
+ )
8
+
9
+
10
class CartPoleWithProtobufObservationSpace(CartPoleEnv):
    """CartPole env that publishes observations as serialized protobuf bytes.

    Publishing observations as a protobuf message (instead of a heavily
    nested Dict) can be more performant for some environments.

    The protobuf message is defined in `./utils/cartpole_observations.proto`
    and was compiled into the importable `cartpole_observations_proto.py`
    module via:

    `protoc --python_out=. cartpole_observations.proto`

    The observation space is a binary Box(0, 255, ([serialized length],), uint8).
    A ConnectorV2 pipeline or a simple gym.Wrapper must convert this format
    into an NN-readable (e.g. float32) 1D tensor.
    """

    def __init__(self, config=None):
        super().__init__()
        # Serialize one dummy observation to learn the fixed byte-length of
        # an encoded message; that length defines the obs space.
        dummy = self._convert_observation_to_protobuf(np.array([1.0, 1.0, 1.0, 1.0]))
        self.observation_space = gym.spaces.Box(0, 255, (len(dummy),), np.uint8)

    def step(self, action):
        obs, reward, terminated, truncated, info = super().step(action)
        encoded = self._convert_observation_to_protobuf(obs)
        return encoded, reward, terminated, truncated, info

    def reset(self, **kwargs):
        obs, info = super().reset(**kwargs)
        return self._convert_observation_to_protobuf(obs), info

    def _convert_observation_to_protobuf(self, observation):
        """Serialize [x_pos, x_veloc, angle_pos, angle_veloc] into uint8 bytes."""
        message = CartPoleObservation()
        (
            message.x_pos,
            message.x_veloc,
            message.angle_pos,
            message.angle_veloc,
        ) = observation

        # Serialize to a binary string and view it as a uint8 ndarray.
        return np.frombuffer(message.SerializeToString(), np.uint8)
61
+
62
+
63
+ if __name__ == "__main__":
64
+ env = CartPoleWithProtobufObservationSpace()
65
+ obs, info = env.reset()
66
+
67
+ # Test loading a protobuf object with data from the obs binary string
68
+ # (uint8 ndarray).
69
+ byte_str = obs.tobytes()
70
+ obs_protobuf = CartPoleObservation()
71
+ obs_protobuf.ParseFromString(byte_str)
72
+ print(obs_protobuf)
73
+
74
+ terminated = truncated = False
75
+ while not terminated and not truncated:
76
+ action = env.action_space.sample()
77
+ obs, reward, terminated, truncated, info = env.step(action)
78
+
79
+ print(obs)
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/cliff_walking_wall_env.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ from gymnasium import spaces
3
+
4
+ ACTION_UP = 0
5
+ ACTION_RIGHT = 1
6
+ ACTION_DOWN = 2
7
+ ACTION_LEFT = 3
8
+
9
+
10
class CliffWalkingWallEnv(gym.Env):
    """Modified version of Farama-Foundation Gymnasium's CliffWalking env,
    with walls instead of a cliff.

    ### Description
    The board is a 4x12 matrix (NumPy matrix indexing):
    - [3, 0] or obs==36: start at bottom-left
    - [3, 11] or obs==47: goal at bottom-right
    - [3, 1..10] or obs==37...46: the walled-off cells at bottom-center

    An episode terminates when the agent reaches the goal.

    ### Actions
    4 discrete deterministic actions (constants ACTION_UP, ACTION_RIGHT, ...
    defined above):
    - 0: move up
    - 1: move right
    - 2: move down
    - 3: move left

    ### Observations
    There are 3x12 + 2 possible states, not including the walls. An action
    that would move the agent into a wall leaves it in place.

    ### Reward
    -1 per time step; +10 for reaching the goal.
    """

    def __init__(self, seed=42) -> None:
        self.observation_space = spaces.Discrete(48)
        self.action_space = spaces.Discrete(4)
        self.observation_space.seed(seed)
        self.action_space.seed(seed)

    def reset(self, *, seed=None, options=None):
        # Always start at the bottom-left cell ([3, 0] == 36).
        self.position = 36
        return self.position, {}

    def step(self, action):
        row, col = divmod(self.position, 12)
        if action == ACTION_UP:
            row = max(row - 1, 0)
        elif action == ACTION_RIGHT:
            # Wall directly to the right of the start cell.
            if self.position != 36:
                col = min(col + 1, 11)
        elif action == ACTION_DOWN:
            # Cells 25..34 sit directly above the wall; can't move down there.
            if not (25 <= self.position <= 34):
                row = min(row + 1, 3)
        elif action == ACTION_LEFT:
            # Wall directly to the left of the goal cell.
            if self.position != 47:
                col = max(col - 1, 0)
        else:
            raise ValueError(f"action {action} not in {self.action_space}")
        self.position = row * 12 + col
        done = self.position == 47
        return self.position, 10 if done else -1, done, False, {}
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/correlated_actions_env.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+
3
+ import gymnasium as gym
4
+ import numpy as np
5
+
6
+
7
class CorrelatedActionsEnv(gym.Env):
    """Environment that can only be solved through an autoregressive action model.

    In each step, the agent observes a random number (between -1 and 1) and has
    to choose two actions, a1 (discrete, 0, 1, or 2) and a2 (cont. between -1 and 1).

    The reward is constructed such that actions need to be correlated to succeed. It's
    impossible for the network to learn each action head separately.

    There are two reward components:
    The first is the negative absolute value of the delta between 1.0 and the sum of
    obs + a1. For example, if obs is -0.3 and a1 was sampled to be 1, then the value of
    the first reward component is:
    r1 = -abs(1.0 - [obs+a1]) = -abs(1.0 - (-0.3 + 1)) = -abs(0.3) = -0.3
    The second reward component is computed as the negative absolute value
    of `obs + a1 + a2`. For example, if obs is 0.5, a1 was sampled to be 0,
    and a2 was sampled to be -0.7, then the value of the second reward component is:
    r2 = -abs(obs + a1 + a2) = -abs(0.5 + 0 - 0.7)) = -abs(-0.2) = -0.2

    Because of this specific reward function, the agent must learn to optimally sample
    a1 based on the observation and to optimally sample a2, based on the observation
    AND the sampled value of a1.

    One way to effectively learn this is through correlated action
    distributions, e.g., in examples/actions/auto_regressive_actions.py

    The game ends after the first step.
    """

    def __init__(self, config=None):
        super().__init__()
        # Observation space (single continuous value between -1. and 1.).
        self.observation_space = gym.spaces.Box(-1.0, 1.0, shape=(1,), dtype=np.float32)

        # Action space (discrete action a1 and continuous action a2).
        self.action_space = gym.spaces.Tuple(
            [gym.spaces.Discrete(3), gym.spaces.Box(-2.0, 2.0, (1,), np.float32)]
        )

        # The current observation (set on `reset()`).
        self.obs = None

    def reset(
        self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
    ):
        """Reset the environment to an initial state."""
        super().reset(seed=seed, options=options)

        # Randomly initialize the observation between -1 and 1.
        # Bug fix: sample from the seeded `self.np_random` (installed by
        # `super().reset(seed=...)`) instead of the global `np.random`, so
        # seeding actually makes episodes reproducible. Also cast to float32
        # to match the observation space's dtype.
        self.obs = self.np_random.uniform(-1, 1, size=(1,)).astype(np.float32)

        return self.obs, {}

    def step(self, action):
        """Apply the autoregressive action and return step information."""

        # Extract individual action components, a1 and a2.
        a1, a2 = action
        a2 = a2[0]  # dissolve shape=(1,)

        # r1 depends on how well a1 is aligned to obs:
        r1 = -abs(1.0 - (self.obs[0] + a1))
        # r2 depends on how well a2 is aligned to both, obs and a1.
        r2 = -abs(self.obs[0] + a1 + a2)

        reward = r1 + r2

        # Terminate after each step (no episode length in this simple example).
        return self.obs, reward, True, False, {}
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/d4rl_env.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 8 Environments from D4RL Environment.
3
+ Use fully qualified class-path in your configs:
4
+ e.g. "env": "ray.rllib.examples.envs.classes.d4rl_env.halfcheetah_random".
5
+ """
6
+
7
+ import gymnasium as gym
8
+
9
+ try:
10
+ import d4rl
11
+
12
+ d4rl.__name__ # Fool LINTer.
13
+ except ImportError:
14
+ d4rl = None
15
+
16
+
17
def halfcheetah_random():
    """Create the D4RL `halfcheetah-random-v0` env."""
    return gym.make("halfcheetah-random-v0")


def halfcheetah_medium():
    """Create the D4RL `halfcheetah-medium-v0` env."""
    return gym.make("halfcheetah-medium-v0")


def halfcheetah_expert():
    """Create the D4RL `halfcheetah-expert-v0` env."""
    return gym.make("halfcheetah-expert-v0")


def halfcheetah_medium_replay():
    """Create the D4RL `halfcheetah-medium-replay-v0` env."""
    return gym.make("halfcheetah-medium-replay-v0")


def hopper_random():
    """Create the D4RL `hopper-random-v0` env."""
    return gym.make("hopper-random-v0")


def hopper_medium():
    """Create the D4RL `hopper-medium-v0` env."""
    return gym.make("hopper-medium-v0")


def hopper_expert():
    """Create the D4RL `hopper-expert-v0` env."""
    return gym.make("hopper-expert-v0")


def hopper_medium_replay():
    """Create the D4RL `hopper-medium-replay-v0` env."""
    return gym.make("hopper-medium-replay-v0")
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/debug_counter_env.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ import numpy as np
3
+
4
+ from ray.rllib.env.multi_agent_env import MultiAgentEnv
5
+
6
+
7
class DebugCounterEnv(gym.Env):
    """Env whose observation is a (0-based) timestep counter.

    Actions have no effect. Episodes always truncate after 15 steps.
    The reward at each step is `ts % 3`.
    """

    def __init__(self, config=None):
        config = config or {}
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(0, 100, (1,), dtype=np.float32)
        # The counter starts at a configurable offset.
        self.start_at_t = int(config.get("start_at_t", 0))
        self.i = self.start_at_t

    def reset(self, *, seed=None, options=None):
        self.i = self.start_at_t
        return self._get_obs(), {}

    def step(self, action):
        self.i += 1
        truncated = self.i >= 15 + self.start_at_t
        return self._get_obs(), float(self.i % 3), False, truncated, {}

    def _get_obs(self):
        # Wrap the counter into the (1,)-shaped float32 Box format.
        return np.array([self.i], dtype=np.float32)
34
+
35
+
36
class MultiAgentDebugCounterEnv(MultiAgentEnv):
    """Multi-agent debugging env whose obs encode agent/episode/env IDs and ts.

    Each agent's observation is a (4,)-float32 vector:
    0=agent ID; 1=episode ID (0.0 for the reset obs); 2=env ID (0.0 for the
    reset obs); 3=the agent's own timestep.

    Actions are (episodeID, envID) float pairs, echoed back in the next obs.
    Reward is `ts % 3`; agent i truncates after `base_episode_len + i` steps.
    """

    def __init__(self, config):
        super().__init__()
        self.num_agents = config["num_agents"]
        self.base_episode_len = config.get("base_episode_len", 103)

        # Observation dims:
        # 0=agent ID.
        # 1=episode ID (0.0 for obs after reset).
        # 2=env ID (0.0 for obs after reset).
        # 3=ts (of the agent).
        self.observation_space = gym.spaces.Dict(
            {
                aid: gym.spaces.Box(float("-inf"), float("inf"), (4,))
                for aid in range(self.num_agents)
            }
        )

        # Actions are always:
        # (episodeID, envID) as floats.
        self.action_space = gym.spaces.Dict(
            {
                aid: gym.spaces.Box(-float("inf"), float("inf"), shape=(2,))
                for aid in range(self.num_agents)
            }
        )

        self.timesteps = [0] * self.num_agents
        self.terminateds = set()
        self.truncateds = set()

    def reset(self, *, seed=None, options=None):
        self.timesteps = [0] * self.num_agents
        self.terminateds = set()
        self.truncateds = set()
        return {
            i: np.array([i, 0.0, 0.0, 0.0], dtype=np.float32)
            for i in range(self.num_agents)
        }, {}

    def step(self, action_dict):
        obs, rew, terminated, truncated = {}, {}, {}, {}
        for i, action in action_dict.items():
            self.timesteps[i] += 1
            # Bug fix: return float32 obs to match both the float32 obs
            # returned from `reset()` and the observation space's dtype
            # (the previous np.array default was float64).
            obs[i] = np.array(
                [i, action[0], action[1], self.timesteps[i]], dtype=np.float32
            )
            rew[i] = self.timesteps[i] % 3
            terminated[i] = False
            # Agent i runs `base_episode_len + i` steps before truncating.
            truncated[i] = self.timesteps[i] > self.base_episode_len + i
            if terminated[i]:
                self.terminateds.add(i)
            if truncated[i]:
                self.truncateds.add(i)
        terminated["__all__"] = len(self.terminateds) == self.num_agents
        truncated["__all__"] = len(self.truncateds) == self.num_agents
        return obs, rew, terminated, truncated, {}
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/deterministic_envs.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+
3
+
4
def create_cartpole_deterministic(config):
    """Create a CartPole-v1 env pre-seeded for deterministic rollouts."""
    seed = config.get("seed", 0)
    env = gym.make("CartPole-v1")
    env.reset(seed=seed)
    return env
8
+
9
+
10
def create_pendulum_deterministic(config):
    """Create a Pendulum-v1 env pre-seeded for deterministic rollouts."""
    seed = config.get("seed", 0)
    env = gym.make("Pendulum-v1")
    env.reset(seed=seed)
    return env
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/dm_control_suite.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.env.wrappers.dm_control_wrapper import DMCEnv
2
+
3
+ """
4
+ 8 Environments from Deepmind Control Suite
5
+ """
6
+
7
+
8
def _make_dmc_env(domain, task, from_pixels, height, width, frame_skip, channels_first):
    """Construct a DMCEnv for the given domain/task with common settings.

    Private helper de-duplicating the 8 public creator functions below, which
    previously repeated the identical DMCEnv(...) call verbatim.
    """
    return DMCEnv(
        domain,
        task,
        from_pixels=from_pixels,
        height=height,
        width=width,
        frame_skip=frame_skip,
        channels_first=channels_first,
    )


def acrobot_swingup(
    from_pixels=True, height=64, width=64, frame_skip=2, channels_first=True
):
    return _make_dmc_env(
        "acrobot", "swingup", from_pixels, height, width, frame_skip, channels_first
    )


def walker_walk(
    from_pixels=True, height=64, width=64, frame_skip=2, channels_first=True
):
    return _make_dmc_env(
        "walker", "walk", from_pixels, height, width, frame_skip, channels_first
    )


def hopper_hop(
    from_pixels=True, height=64, width=64, frame_skip=2, channels_first=True
):
    return _make_dmc_env(
        "hopper", "hop", from_pixels, height, width, frame_skip, channels_first
    )


def hopper_stand(
    from_pixels=True, height=64, width=64, frame_skip=2, channels_first=True
):
    return _make_dmc_env(
        "hopper", "stand", from_pixels, height, width, frame_skip, channels_first
    )


def cheetah_run(
    from_pixels=True, height=64, width=64, frame_skip=2, channels_first=True
):
    return _make_dmc_env(
        "cheetah", "run", from_pixels, height, width, frame_skip, channels_first
    )


def walker_run(
    from_pixels=True, height=64, width=64, frame_skip=2, channels_first=True
):
    return _make_dmc_env(
        "walker", "run", from_pixels, height, width, frame_skip, channels_first
    )


def pendulum_swingup(
    from_pixels=True, height=64, width=64, frame_skip=2, channels_first=True
):
    return _make_dmc_env(
        "pendulum", "swingup", from_pixels, height, width, frame_skip, channels_first
    )


def cartpole_swingup(
    from_pixels=True, height=64, width=64, frame_skip=2, channels_first=True
):
    return _make_dmc_env(
        "cartpole", "swingup", from_pixels, height, width, frame_skip, channels_first
    )


def humanoid_walk(
    from_pixels=True, height=64, width=64, frame_skip=2, channels_first=True
):
    return _make_dmc_env(
        "humanoid", "walk", from_pixels, height, width, frame_skip, channels_first
    )
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/env_using_remote_actor.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Example of an environment that uses a named remote actor as parameter
3
+ server.
4
+
5
+ """
6
+ from gymnasium.envs.classic_control.cartpole import CartPoleEnv
7
+ from gymnasium.utils import seeding
8
+
9
+ import ray
10
+
11
+
12
@ray.remote
class ParameterStorage:
    """Named remote actor acting as a central parameter server for envs."""

    def get_params(self, rng):
        """Sample a fresh env-parameter dict using the caller-provided RNG.

        The RNG is passed in (rather than owned by this actor) so that
        multiple envs querying this server do not race on shared RNG state.
        """
        return {
            "MASSCART": rng.uniform(low=0.5, high=2.0),
        }
18
+
19
+
20
class CartPoleWithRemoteParamServer(CartPoleEnv):
    """CartPole variant that fetches a randomized cart mass on every reset.

    The mass is sampled by a named remote `ParameterStorage` actor (see
    above), using this env's own RNG to avoid cross-env race conditions.
    """

    def __init__(self, env_config):
        # `env_config` is expected to provide `worker_index` and (optionally)
        # the name of the param-server actor under key "param_server".
        self.env_config = env_config
        super().__init__()
        # Get our param server (remote actor) by name.
        self._handler = ray.get_actor(env_config.get("param_server", "param-server"))
        # Last explicit seed passed to `reset()` (None -> unseeded).
        self.rng_seed = None
        self.np_random, _ = seeding.np_random(self.rng_seed)

    def reset(self, *, seed=None, options=None):
        # Re-seed our RNG only when an explicit seed is given.
        if seed is not None:
            self.rng_seed = int(seed)
            self.np_random, _ = seeding.np_random(seed)
            print(
                f"Seeding env (worker={self.env_config.worker_index}) " f"with {seed}"
            )

        # Pass in our RNG to guarantee no race conditions.
        # If `self._handler` had its own RNG, this may clash with other
        # envs trying to use the same param-server.
        params = ray.get(self._handler.get_params.remote(self.np_random))

        # IMPORTANT: Advance the state of our RNG (self._rng was passed
        # above via ray (serialized) and thus not altered locally here!).
        # Or create a new RNG from another random number:
        # Seed the RNG with a deterministic seed if set, otherwise, create
        # a random one.
        new_seed = int(
            self.np_random.integers(0, 1000000) if not self.rng_seed else self.rng_seed
        )
        self.np_random, _ = seeding.np_random(new_seed)

        print(
            f"Env worker-idx={self.env_config.worker_index} "
            f"mass={params['MASSCART']}"
        )

        # Install the sampled cart mass and update the derived quantities
        # CartPoleEnv's physics use.
        self.masscart = params["MASSCART"]
        self.total_mass = self.masspole + self.masscart
        self.polemass_length = self.masspole * self.length

        return super().reset()
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/env_with_subprocess.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import atexit
2
+ import gymnasium as gym
3
+ from gymnasium.spaces import Discrete
4
+ import os
5
+ import subprocess
6
+
7
+
8
class EnvWithSubprocess(gym.Env):
    """Test env that spawns a long-running subprocess on construction.

    Used to verify that subprocesses and temp files get cleaned up when
    env workers exit.
    """

    # Dummy command to run as a subprocess with a unique name.
    UNIQUE_CMD = "sleep 20"

    def __init__(self, config):
        self.UNIQUE_FILE_0 = config["tmp_file1"]
        self.UNIQUE_FILE_1 = config["tmp_file2"]
        self.UNIQUE_FILE_2 = config["tmp_file3"]
        self.UNIQUE_FILE_3 = config["tmp_file4"]

        self.action_space = Discrete(2)
        self.observation_space = Discrete(2)
        # The subprocess that must be cleaned up on exit.
        self.subproc = subprocess.Popen(self.UNIQUE_CMD.split(" "), shell=False)
        self.config = config
        # Register exit handlers: kill the subprocess and remove the temp
        # file belonging to this worker.
        atexit.register(self.subproc.kill)
        if config.worker_index == 0:
            atexit.register(os.unlink, self.UNIQUE_FILE_0)
        else:
            atexit.register(os.unlink, self.UNIQUE_FILE_1)

    def close(self):
        # On explicit close, remove this worker's second temp file.
        file_to_remove = (
            self.UNIQUE_FILE_2 if self.config.worker_index == 0 else self.UNIQUE_FILE_3
        )
        os.unlink(file_to_remove)

    def reset(self, *, seed=None, options=None):
        return 0, {}

    def step(self, action):
        return 0, 0, True, False, {}
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/fast_image_env.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ from gymnasium.spaces import Box, Discrete
3
+ import numpy as np
4
+
5
+
6
class FastImageEnv(gym.Env):
    """Trivial image-observation env for throughput testing.

    Always returns an all-zeros 84x84x4 frame and reward 1; episodes end
    after 1000 steps. Actions are ignored.
    """

    def __init__(self, config):
        # Bug fix: allocate the frame as float32 so observations match the
        # observation space's declared dtype (np.zeros defaults to float64).
        self.zeros = np.zeros((84, 84, 4), dtype=np.float32)
        self.action_space = Discrete(2)
        self.observation_space = Box(0.0, 1.0, shape=(84, 84, 4), dtype=np.float32)
        self.i = 0

    def reset(self, *, seed=None, options=None):
        self.i = 0
        return self.zeros, {}

    def step(self, action):
        self.i += 1
        done = truncated = self.i > 1000
        return self.zeros, 1, done, truncated, {}
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/gpu_requiring_env.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ import ray
4
+ from ray.rllib.examples.envs.classes.simple_corridor import SimpleCorridor
5
+ from ray.rllib.utils.framework import try_import_torch
6
+
7
+ torch, _ = try_import_torch()
8
+
9
+
10
class GPURequiringEnv(SimpleCorridor):
    """A dummy env that requires a GPU in order to work.

    The env here is a simple corridor env that additionally simulates a GPU
    check in its constructor via `ray.get_gpu_ids()`. If this returns an
    empty list, we raise an error.

    To make this env work, use `num_gpus_per_env_runner > 0` (RolloutWorkers
    requesting this many GPUs each) and - maybe - `num_gpus > 0` in case
    your local worker/driver must have an env as well. However, this is
    only the case if `create_env_on_driver`=True (default is False).
    """

    def __init__(self, config=None):
        super().__init__(config)

        # Fake-require some GPUs (at least one).
        # If your local worker's env (`create_env_on_driver`=True) does not
        # necessarily require a GPU, you can perform the below assertion only
        # if `config.worker_index != 0`.
        gpus_available = ray.get_gpu_ids()
        print(f"{type(self).__name__} can see GPUs={gpus_available}")

        # Create a dummy tensor on the GPU to prove the device is usable
        # (only if a GPU is visible AND torch could be imported).
        if len(gpus_available) > 0 and torch:
            self._tensor = torch.from_numpy(np.random.random_sample(size=(42, 42))).to(
                f"cuda:{gpus_available[0]}"
            )
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/look_and_push.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ import numpy as np
3
+
4
+
5
class LookAndPush(gym.Env):
    """Memory-requiring Env: Best sequence of actions depends on prev. states.

    Optimal behavior:
    0) a=0 -> observe next state (s'), which is the "hidden" state.
       If a=1 here, the hidden state is not observed.
    1) a=1 to always jump to s=2 (not matter what the prev. state was).
    2) a=1 to move to s=3.
    3) a=1 to move to s=4.
    4) a=0 OR 1 depending on s' observed after 0): +10 reward and done.
       otherwise: -10 reward and done.
    """

    def __init__(self):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Discrete(5)
        self._state = None
        self._case = None

    def reset(self, *, seed=None, options=None):
        self._state = 2
        # The hidden state (0 or 1), observable only via action 0 in state 2.
        self._case = np.random.choice(2)
        return self._state, {}

    def step(self, action):
        assert self.action_space.contains(action)

        if self._state == 4:
            # Terminal step. Bug fix: return the full gymnasium 5-tuple
            # (obs, reward, terminated, truncated, info); the original
            # returned only 4 elements in both branches here.
            if action and self._case:
                return self._state, 10.0, True, False, {}
            else:
                return self._state, -10, True, False, {}
        else:
            if action:
                if self._state == 0:
                    self._state = 2
                else:
                    self._state += 1
            elif self._state == 2:
                # "Look": reveal the hidden case as the next observation.
                self._state = self._case

            return self._state, -1, False, False, {}
47
+
48
+
49
class OneHot(gym.Wrapper):
    """Wrapper one-hot encoding a Discrete observation into a float Box."""

    def __init__(self, env):
        super(OneHot, self).__init__(env)
        self.observation_space = gym.spaces.Box(0.0, 1.0, (env.observation_space.n,))

    def reset(self, *, seed=None, options=None):
        obs, info = self.env.reset(seed=seed, options=options)
        return self._encode_obs(obs), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self._encode_obs(obs), reward, terminated, truncated, info

    def _encode_obs(self, obs):
        """Return the one-hot vector for discrete observation `obs`.

        Bug fix: the original used `np.ones(...)`, which produced an
        all-ones vector (not a one-hot encoding). Start from zeros instead;
        float32 matches the Box space's default dtype.
        """
        new_obs = np.zeros(self.env.observation_space.n, dtype=np.float32)
        new_obs[obs] = 1.0
        return new_obs
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/memory_leaking_env.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import uuid
3
+
4
+ from ray.rllib.examples.envs.classes.random_env import RandomEnv
5
+ from ray.rllib.utils.annotations import override
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
class MemoryLeakingEnv(RandomEnv):
    """An env that leaks a tiny amount of memory once per episode.

    Useful for proving that our memory-leak tests can catch even the
    slightest leaks.
    """

    def __init__(self, config=None):
        super().__init__(config)
        # Ever-growing dict: the actual "leak".
        self._leak = {}
        self._steps_after_reset = 0

    @override(RandomEnv)
    def reset(self, *, seed=None, options=None):
        self._steps_after_reset = 0
        return super().reset(seed=seed, options=options)

    @override(RandomEnv)
    def step(self, action):
        self._steps_after_reset += 1

        # Leak exactly once per episode (on the second step): add one entry
        # with a unique key that is never removed.
        if self._steps_after_reset == 2:
            leak_key = uuid.uuid4().hex.upper()
            self._leak[leak_key] = 1

        return super().step(action)
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/mock_env.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ import numpy as np
3
+ from typing import Optional
4
+
5
+ from ray.rllib.env.vector_env import VectorEnv
6
+ from ray.rllib.utils.annotations import override
7
+
8
+
9
class MockEnv(gym.Env):
    """Mock environment for testing purposes.

    Observation is always 0, reward always 1.0; actions are ignored.
    The episode length is configurable via the constructor.
    """

    def __init__(self, episode_length, config=None):
        self.episode_length = episode_length
        self.config = config
        self.i = 0
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self, *, seed=None, options=None):
        self.i = 0
        return 0, {}

    def step(self, action):
        self.i += 1
        episode_over = self.i >= self.episode_length
        return 0, 1.0, episode_over, episode_over, {}
31
+
32
+
33
class MockEnv2(gym.Env):
    """Mock environment for testing purposes.

    Observation equals the current timestep (discrete space!), reward is
    always 100.0; the episode length is configurable. Actions are ignored.
    """

    metadata = {
        "render.modes": ["rgb_array"],
    }
    render_mode: Optional[str] = "rgb_array"

    def __init__(self, episode_length):
        self.episode_length = episode_length
        self.i = 0
        self.observation_space = gym.spaces.Discrete(self.episode_length + 1)
        self.action_space = gym.spaces.Discrete(2)
        # Remember the last explicit seed passed to `reset()`.
        self.rng_seed = None

    def reset(self, *, seed=None, options=None):
        self.i = 0
        if seed is not None:
            self.rng_seed = seed
        return self.i, {}

    def step(self, action):
        self.i += 1
        episode_over = self.i >= self.episode_length
        return self.i, 100.0, episode_over, episode_over, {}

    def render(self):
        # Just produce a random RGB frame for demonstration purposes.
        # See `gym/envs/classic_control/cartpole.py` for an example of using
        # a real Viewer object.
        return np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)
68
+
69
+
70
+ class MockEnv3(gym.Env):
71
+ """Mock environment for testing purposes.
72
+
73
+ Observation=ts (discrete space!), reward=100.0, episode-len is
74
+ configurable. Actions are ignored.
75
+ """
76
+
77
+ def __init__(self, episode_length):
78
+ self.episode_length = episode_length
79
+ self.i = 0
80
+ self.observation_space = gym.spaces.Discrete(100)
81
+ self.action_space = gym.spaces.Discrete(2)
82
+
83
+ def reset(self, *, seed=None, options=None):
84
+ self.i = 0
85
+ return self.i, {"timestep": 0}
86
+
87
+ def step(self, action):
88
+ self.i += 1
89
+ terminated = truncated = self.i >= self.episode_length
90
+ return self.i, self.i, terminated, truncated, {"timestep": self.i}
91
+
92
+
93
+ class VectorizedMockEnv(VectorEnv):
94
+ """Vectorized version of the MockEnv.
95
+
96
+ Contains `num_envs` MockEnv instances, each one having its own
97
+ `episode_length` horizon.
98
+ """
99
+
100
+ def __init__(self, episode_length, num_envs):
101
+ super().__init__(
102
+ observation_space=gym.spaces.Discrete(1),
103
+ action_space=gym.spaces.Discrete(2),
104
+ num_envs=num_envs,
105
+ )
106
+ self.envs = [MockEnv(episode_length) for _ in range(num_envs)]
107
+
108
+ @override(VectorEnv)
109
+ def vector_reset(self, *, seeds=None, options=None):
110
+ seeds = seeds or [None] * self.num_envs
111
+ options = options or [None] * self.num_envs
112
+ obs_and_infos = [
113
+ e.reset(seed=seeds[i], options=options[i]) for i, e in enumerate(self.envs)
114
+ ]
115
+ return [oi[0] for oi in obs_and_infos], [oi[1] for oi in obs_and_infos]
116
+
117
+ @override(VectorEnv)
118
+ def reset_at(self, index, *, seed=None, options=None):
119
+ return self.envs[index].reset(seed=seed, options=options)
120
+
121
+ @override(VectorEnv)
122
+ def vector_step(self, actions):
123
+ obs_batch, rew_batch, terminated_batch, truncated_batch, info_batch = (
124
+ [],
125
+ [],
126
+ [],
127
+ [],
128
+ [],
129
+ )
130
+ for i in range(len(self.envs)):
131
+ obs, rew, terminated, truncated, info = self.envs[i].step(actions[i])
132
+ obs_batch.append(obs)
133
+ rew_batch.append(rew)
134
+ terminated_batch.append(terminated)
135
+ truncated_batch.append(truncated)
136
+ info_batch.append(info)
137
+ return obs_batch, rew_batch, terminated_batch, truncated_batch, info_batch
138
+
139
+ @override(VectorEnv)
140
+ def get_sub_environments(self):
141
+ return self.envs
142
+
143
+
144
+ class MockVectorEnv(VectorEnv):
145
+ """A custom vector env that uses a single(!) CartPole sub-env.
146
+
147
+ However, this env pretends to be a vectorized one to illustrate how one
148
+ could create custom VectorEnvs w/o the need for actual vectorizations of
149
+ sub-envs under the hood.
150
+ """
151
+
152
+ def __init__(self, episode_length, mocked_num_envs):
153
+ self.env = gym.make("CartPole-v1")
154
+ super().__init__(
155
+ observation_space=self.env.observation_space,
156
+ action_space=self.env.action_space,
157
+ num_envs=mocked_num_envs,
158
+ )
159
+ self.episode_len = episode_length
160
+ self.ts = 0
161
+
162
+ @override(VectorEnv)
163
+ def vector_reset(self, *, seeds=None, options=None):
164
+ # Since we only have one underlying sub-environment, just use the first seed
165
+ # and the first options dict (the user of this env thinks, there are
166
+ # `self.num_envs` sub-environments and sends that many seeds/options).
167
+ seeds = seeds or [None]
168
+ options = options or [None]
169
+ obs, infos = self.env.reset(seed=seeds[0], options=options[0])
170
+ # Simply repeat the single obs/infos to pretend we really have
171
+ # `self.num_envs` sub-environments.
172
+ return (
173
+ [obs for _ in range(self.num_envs)],
174
+ [infos for _ in range(self.num_envs)],
175
+ )
176
+
177
+ @override(VectorEnv)
178
+ def reset_at(self, index, *, seed=None, options=None):
179
+ self.ts = 0
180
+ return self.env.reset(seed=seed, options=options)
181
+
182
+ @override(VectorEnv)
183
+ def vector_step(self, actions):
184
+ self.ts += 1
185
+ # Apply all actions sequentially to the same env.
186
+ # Whether this would make a lot of sense is debatable.
187
+ obs_batch, rew_batch, terminated_batch, truncated_batch, info_batch = (
188
+ [],
189
+ [],
190
+ [],
191
+ [],
192
+ [],
193
+ )
194
+ for i in range(self.num_envs):
195
+ obs, rew, terminated, truncated, info = self.env.step(actions[i])
196
+ # Artificially truncate once time step limit has been reached.
197
+ # Note: Also terminate/truncate, when underlying CartPole is
198
+ # terminated/truncated.
199
+ if self.ts >= self.episode_len:
200
+ truncated = True
201
+ obs_batch.append(obs)
202
+ rew_batch.append(rew)
203
+ terminated_batch.append(terminated)
204
+ truncated_batch.append(truncated)
205
+ info_batch.append(info)
206
+ if terminated or truncated:
207
+ remaining = self.num_envs - (i + 1)
208
+ obs_batch.extend([obs for _ in range(remaining)])
209
+ rew_batch.extend([rew for _ in range(remaining)])
210
+ terminated_batch.extend([terminated for _ in range(remaining)])
211
+ truncated_batch.extend([truncated for _ in range(remaining)])
212
+ info_batch.extend([info for _ in range(remaining)])
213
+ break
214
+ return obs_batch, rew_batch, terminated_batch, truncated_batch, info_batch
215
+
216
+ @override(VectorEnv)
217
+ def get_sub_environments(self):
218
+ # You may also leave this method as-is, in which case, it would
219
+ # return an empty list.
220
+ return [self.env for _ in range(self.num_envs)]
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/bandit_envs_discrete.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import gymnasium as gym
3
+ from gymnasium.spaces import Box, Discrete
4
+ import numpy as np
5
+ import random
6
+
7
+
8
+ class SimpleContextualBandit(gym.Env):
9
+ """Simple env w/ 2 states and 3 actions (arms): 0, 1, and 2.
10
+
11
+ Episodes last only for one timestep, possible observations are:
12
+ [-1.0, 1.0] and [1.0, -1.0], where the first element is the "current context".
13
+ The highest reward (+10.0) is received for selecting arm 0 for context=1.0
14
+ and arm 2 for context=-1.0. Action 1 always yields 0.0 reward.
15
+ """
16
+
17
+ def __init__(self, config=None):
18
+ self.action_space = Discrete(3)
19
+ self.observation_space = Box(low=-1.0, high=1.0, shape=(2,))
20
+ self.cur_context = None
21
+
22
+ def reset(self, *, seed=None, options=None):
23
+ self.cur_context = random.choice([-1.0, 1.0])
24
+ return np.array([self.cur_context, -self.cur_context]), {}
25
+
26
+ def step(self, action):
27
+ rewards_for_context = {
28
+ -1.0: [-10, 0, 10],
29
+ 1.0: [10, 0, -10],
30
+ }
31
+ reward = rewards_for_context[self.cur_context][action]
32
+ return (
33
+ np.array([-self.cur_context, self.cur_context]),
34
+ reward,
35
+ True,
36
+ False,
37
+ {"regret": 10 - reward},
38
+ )
39
+
40
+
41
+ class LinearDiscreteEnv(gym.Env):
42
+ """Samples data from linearly parameterized arms.
43
+
44
+ The reward for context X and arm i is given by X^T * theta_i, for some
45
+ latent set of parameters {theta_i : i = 1, ..., k}.
46
+ The thetas are sampled uniformly at random, the contexts are Gaussian,
47
+ and Gaussian noise is added to the rewards.
48
+ """
49
+
50
+ DEFAULT_CONFIG_LINEAR = {
51
+ "feature_dim": 8,
52
+ "num_actions": 4,
53
+ "reward_noise_std": 0.01,
54
+ }
55
+
56
+ def __init__(self, config=None):
57
+ self.config = copy.copy(self.DEFAULT_CONFIG_LINEAR)
58
+ if config is not None and type(config) is dict:
59
+ self.config.update(config)
60
+
61
+ self.feature_dim = self.config["feature_dim"]
62
+ self.num_actions = self.config["num_actions"]
63
+ self.sigma = self.config["reward_noise_std"]
64
+
65
+ self.action_space = Discrete(self.num_actions)
66
+ self.observation_space = Box(low=-10, high=10, shape=(self.feature_dim,))
67
+
68
+ self.thetas = np.random.uniform(-1, 1, (self.num_actions, self.feature_dim))
69
+ self.thetas /= np.linalg.norm(self.thetas, axis=1, keepdims=True)
70
+
71
+ self._elapsed_steps = 0
72
+ self._current_context = None
73
+
74
+ def _sample_context(self):
75
+ return np.random.normal(scale=1 / 3, size=(self.feature_dim,))
76
+
77
+ def reset(self, *, seed=None, options=None):
78
+ self._current_context = self._sample_context()
79
+ return self._current_context, {}
80
+
81
+ def step(self, action):
82
+ assert (
83
+ self._elapsed_steps is not None
84
+ ), "Cannot call env.step() beforecalling reset()"
85
+ assert action < self.num_actions, "Invalid action."
86
+
87
+ action = int(action)
88
+ context = self._current_context
89
+ rewards = self.thetas.dot(context)
90
+
91
+ opt_action = rewards.argmax()
92
+
93
+ regret = rewards.max() - rewards[action]
94
+
95
+ # Add Gaussian noise
96
+ rewards += np.random.normal(scale=self.sigma, size=rewards.shape)
97
+
98
+ reward = rewards[action]
99
+ self._current_context = self._sample_context()
100
+ return (
101
+ self._current_context,
102
+ reward,
103
+ True,
104
+ False,
105
+ {"regret": regret, "opt_action": opt_action},
106
+ )
107
+
108
+ def render(self, mode="human"):
109
+ raise NotImplementedError
110
+
111
+
112
+ class WheelBanditEnv(gym.Env):
113
+ """Wheel bandit environment for 2D contexts
114
+ (see https://arxiv.org/abs/1802.09127).
115
+ """
116
+
117
+ DEFAULT_CONFIG_WHEEL = {
118
+ "delta": 0.5,
119
+ "mu_1": 1.2,
120
+ "mu_2": 1,
121
+ "mu_3": 50,
122
+ "std": 0.01,
123
+ }
124
+
125
+ feature_dim = 2
126
+ num_actions = 5
127
+
128
+ def __init__(self, config=None):
129
+ self.config = copy.copy(self.DEFAULT_CONFIG_WHEEL)
130
+ if config is not None and type(config) is dict:
131
+ self.config.update(config)
132
+
133
+ self.delta = self.config["delta"]
134
+ self.mu_1 = self.config["mu_1"]
135
+ self.mu_2 = self.config["mu_2"]
136
+ self.mu_3 = self.config["mu_3"]
137
+ self.std = self.config["std"]
138
+
139
+ self.action_space = Discrete(self.num_actions)
140
+ self.observation_space = Box(low=-1, high=1, shape=(self.feature_dim,))
141
+
142
+ self.means = [self.mu_1] + 4 * [self.mu_2]
143
+ self._elapsed_steps = 0
144
+ self._current_context = None
145
+
146
+ def _sample_context(self):
147
+ while True:
148
+ state = np.random.uniform(-1, 1, self.feature_dim)
149
+ if np.linalg.norm(state) <= 1:
150
+ return state
151
+
152
+ def reset(self, *, seed=None, options=None):
153
+ self._current_context = self._sample_context()
154
+ return self._current_context, {}
155
+
156
+ def step(self, action):
157
+ assert (
158
+ self._elapsed_steps is not None
159
+ ), "Cannot call env.step() before calling reset()"
160
+
161
+ action = int(action)
162
+ self._elapsed_steps += 1
163
+ rewards = [
164
+ np.random.normal(self.means[j], self.std) for j in range(self.num_actions)
165
+ ]
166
+ context = self._current_context
167
+ r_big = np.random.normal(self.mu_3, self.std)
168
+
169
+ if np.linalg.norm(context) >= self.delta:
170
+ if context[0] > 0:
171
+ if context[1] > 0:
172
+ # First quadrant
173
+ rewards[1] = r_big
174
+ opt_action = 1
175
+ else:
176
+ # Fourth quadrant
177
+ rewards[4] = r_big
178
+ opt_action = 4
179
+ else:
180
+ if context[1] > 0:
181
+ # Second quadrant
182
+ rewards[2] = r_big
183
+ opt_action = 2
184
+ else:
185
+ # Third quadrant
186
+ rewards[3] = r_big
187
+ opt_action = 3
188
+ else:
189
+ # Smaller region where action 0 is optimal
190
+ opt_action = 0
191
+
192
+ reward = rewards[action]
193
+
194
+ regret = rewards[opt_action] - reward
195
+
196
+ self._current_context = self._sample_context()
197
+ return (
198
+ self._current_context,
199
+ reward,
200
+ True,
201
+ False,
202
+ {"regret": regret, "opt_action": opt_action},
203
+ )
204
+
205
+ def render(self, mode="human"):
206
+ raise NotImplementedError
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/guess_the_number_game.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+
3
+ from ray.rllib.env.multi_agent_env import MultiAgentEnv
4
+
5
+
6
+ class GuessTheNumberGame(MultiAgentEnv):
7
+ """
8
+ We have two players, 0 and 1. Agent 0 has to pick a number between 0, MAX-1
9
+ at reset. Agent 1 has to guess the number by asking N questions of whether
10
+ of the form of "a <number> is higher|lower|equal to the picked number. The
11
+ action space is MultiDiscrete [3, MAX]. For the first index 0 means lower,
12
+ 1 means higher and 2 means equal. The environment answers with yes (1) or
13
+ no (0) on the reward function. Every time step that agent 1 wastes agent 0
14
+ gets a reward of 1. After N steps the game is terminated. If agent 1
15
+ guesses the number correctly, it gets a reward of 100 points, otherwise it
16
+ gets a reward of 0. On the other hand if agent 0 wins they win 100 points.
17
+ The optimal policy controlling agent 1 should converge to a binary search
18
+ strategy.
19
+ """
20
+
21
+ MAX_NUMBER = 3
22
+ MAX_STEPS = 20
23
+
24
+ def __init__(self, config=None):
25
+ super().__init__()
26
+ self._agent_ids = {0, 1}
27
+
28
+ self.max_number = config.get("max_number", self.MAX_NUMBER)
29
+ self.max_steps = config.get("max_steps", self.MAX_STEPS)
30
+
31
+ self._number = None
32
+ self.observation_space = gym.spaces.Discrete(2)
33
+ self.action_space = gym.spaces.MultiDiscrete([3, self.max_number])
34
+
35
+ def reset(self, *, seed=None, options=None):
36
+ self._step = 0
37
+ self._number = None
38
+ # agent 0 has to pick a number. So the returned obs does not matter.
39
+ return {0: 0}, {}
40
+
41
+ def step(self, action_dict):
42
+ # get agent 0's action
43
+ agent_0_action = action_dict.get(0)
44
+
45
+ if agent_0_action is not None:
46
+ # ignore the first part of the action and look at the number
47
+ self._number = agent_0_action[1]
48
+ # next obs should tell agent 1 to start guessing.
49
+ # the returned reward and dones should be on agent 0 who picked a
50
+ # number.
51
+ return (
52
+ {1: 0},
53
+ {0: 0},
54
+ {0: False, "__all__": False},
55
+ {0: False, "__all__": False},
56
+ {},
57
+ )
58
+
59
+ if self._number is None:
60
+ raise ValueError(
61
+ "No number is selected by agent 0. Have you restarted "
62
+ "the environment?"
63
+ )
64
+
65
+ # get agent 1's action
66
+ direction, number = action_dict.get(1)
67
+ info = {}
68
+ # always the same, we don't need agent 0 to act ever again, agent 1 should keep
69
+ # guessing.
70
+ obs = {1: 0}
71
+ guessed_correctly = False
72
+ terminated = {1: False, "__all__": False}
73
+ truncated = {1: False, "__all__": False}
74
+ # everytime agent 1 does not guess correctly agent 0 gets a reward of 1.
75
+ if direction == 0: # lower
76
+ reward = {1: int(number > self._number), 0: 1}
77
+ elif direction == 1: # higher
78
+ reward = {1: int(number < self._number), 0: 1}
79
+ else: # equal
80
+ guessed_correctly = number == self._number
81
+ reward = {1: guessed_correctly * 100, 0: guessed_correctly * -100}
82
+ terminated = {1: guessed_correctly, "__all__": guessed_correctly}
83
+
84
+ self._step += 1
85
+ if self._step >= self.max_steps: # max number of steps episode is over
86
+ truncated["__all__"] = True
87
+ if not guessed_correctly:
88
+ reward[0] = 100 # agent 0 wins
89
+ return obs, reward, terminated, truncated, info
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/pettingzoo_chess.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pettingzoo import AECEnv
2
+ from pettingzoo.classic.chess.chess import raw_env as chess_v5
3
+ import copy
4
+ from ray.rllib.env.multi_agent_env import MultiAgentEnv
5
+ from typing import Dict, Any
6
+ import chess as ch
7
+ import numpy as np
8
+
9
+
10
+ class MultiAgentChess(MultiAgentEnv):
11
+ """An interface to the PettingZoo MARL environment library.
12
+ See: https://github.com/Farama-Foundation/PettingZoo
13
+ Inherits from MultiAgentEnv and exposes a given AEC
14
+ (actor-environment-cycle) game from the PettingZoo project via the
15
+ MultiAgentEnv public API.
16
+ Note that the wrapper has some important limitations:
17
+ 1. All agents have the same action_spaces and observation_spaces.
18
+ Note: If, within your aec game, agents do not have homogeneous action /
19
+ observation spaces, apply SuperSuit wrappers
20
+ to apply padding functionality: https://github.com/Farama-Foundation/
21
+ SuperSuit#built-in-multi-agent-only-functions
22
+ 2. Environments are positive sum games (-> Agents are expected to cooperate
23
+ to maximize reward). This isn't a hard restriction, it just that
24
+ standard algorithms aren't expected to work well in highly competitive
25
+ games.
26
+
27
+ .. testcode::
28
+ :skipif: True
29
+
30
+ from pettingzoo.butterfly import prison_v3
31
+ from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
32
+ env = PettingZooEnv(prison_v3.env())
33
+ obs = env.reset()
34
+ print(obs)
35
+ # only returns the observation for the agent which should be stepping
36
+
37
+ .. testoutput::
38
+
39
+ {
40
+ 'prisoner_0': array([[[0, 0, 0],
41
+ [0, 0, 0],
42
+ [0, 0, 0],
43
+ ...,
44
+ [0, 0, 0],
45
+ [0, 0, 0],
46
+ [0, 0, 0]]], dtype=uint8)
47
+ }
48
+
49
+ .. testcode::
50
+ :skipif: True
51
+
52
+ obs, rewards, dones, infos = env.step({
53
+ "prisoner_0": 1
54
+ })
55
+ # only returns the observation, reward, info, etc, for
56
+ # the agent who's turn is next.
57
+ print(obs)
58
+
59
+ .. testoutput::
60
+
61
+ {
62
+ 'prisoner_1': array([[[0, 0, 0],
63
+ [0, 0, 0],
64
+ [0, 0, 0],
65
+ ...,
66
+ [0, 0, 0],
67
+ [0, 0, 0],
68
+ [0, 0, 0]]], dtype=uint8)
69
+ }
70
+
71
+ .. testcode::
72
+ :skipif: True
73
+
74
+ print(rewards)
75
+
76
+ .. testoutput::
77
+
78
+ {
79
+ 'prisoner_1': 0
80
+ }
81
+
82
+ .. testcode::
83
+ :skipif: True
84
+
85
+ print(dones)
86
+
87
+ .. testoutput::
88
+
89
+ {
90
+ 'prisoner_1': False, '__all__': False
91
+ }
92
+
93
+ .. testcode::
94
+ :skipif: True
95
+
96
+ print(infos)
97
+
98
+ .. testoutput::
99
+
100
+ {
101
+ 'prisoner_1': {'map_tuple': (1, 0)}
102
+ }
103
+ """
104
+
105
+ def __init__(
106
+ self,
107
+ config: Dict[Any, Any] = None,
108
+ env: AECEnv = None,
109
+ ):
110
+ super().__init__()
111
+ if env is None:
112
+ self.env = chess_v5()
113
+ else:
114
+ self.env = env
115
+ self.env.reset()
116
+
117
+ self.config = config
118
+ if self.config is None:
119
+ self.config = {}
120
+ try:
121
+ self.config["random_start"] = self.config["random_start"]
122
+ except KeyError:
123
+ self.config["random_start"] = 4
124
+ # Get first observation space, assuming all agents have equal space
125
+ self.observation_space = self.env.observation_space(self.env.agents[0])
126
+
127
+ # Get first action space, assuming all agents have equal space
128
+ self.action_space = self.env.action_space(self.env.agents[0])
129
+
130
+ assert all(
131
+ self.env.observation_space(agent) == self.observation_space
132
+ for agent in self.env.agents
133
+ ), (
134
+ "Observation spaces for all agents must be identical. Perhaps "
135
+ "SuperSuit's pad_observations wrapper can help (useage: "
136
+ "`supersuit.aec_wrappers.pad_observations(env)`"
137
+ )
138
+
139
+ assert all(
140
+ self.env.action_space(agent) == self.action_space
141
+ for agent in self.env.agents
142
+ ), (
143
+ "Action spaces for all agents must be identical. Perhaps "
144
+ "SuperSuit's pad_action_space wrapper can help (usage: "
145
+ "`supersuit.aec_wrappers.pad_action_space(env)`"
146
+ )
147
+ self._agent_ids = set(self.env.agents)
148
+
149
+ def random_start(self, random_moves):
150
+ self.env.board = ch.Board()
151
+ for i in range(random_moves):
152
+ self.env.board.push(np.random.choice(list(self.env.board.legal_moves)))
153
+ return self.env.board
154
+
155
+ def observe(self):
156
+ return {
157
+ self.env.agent_selection: self.env.observe(self.env.agent_selection),
158
+ "state": self.get_state(),
159
+ }
160
+
161
+ def reset(self, *args, **kwargs):
162
+ self.env.reset()
163
+ if self.config["random_start"] > 0:
164
+ self.random_start(self.config["random_start"])
165
+ return (
166
+ {self.env.agent_selection: self.env.observe(self.env.agent_selection)},
167
+ {self.env.agent_selection: {}},
168
+ )
169
+
170
+ def step(self, action):
171
+ try:
172
+ self.env.step(action[self.env.agent_selection])
173
+ except (KeyError, IndexError):
174
+ self.env.step(action)
175
+ except AssertionError:
176
+ # Illegal action
177
+ print(action)
178
+ raise AssertionError("Illegal action")
179
+
180
+ obs_d = {}
181
+ rew_d = {}
182
+ done_d = {}
183
+ truncated_d = {}
184
+ info_d = {}
185
+ while self.env.agents:
186
+ obs, rew, done, trunc, info = self.env.last()
187
+ a = self.env.agent_selection
188
+ obs_d[a] = obs
189
+ rew_d[a] = rew
190
+ done_d[a] = done
191
+ truncated_d[a] = trunc
192
+ info_d[a] = info
193
+ if self.env.terminations[self.env.agent_selection]:
194
+ self.env.step(None)
195
+ done_d["__all__"] = True
196
+ truncated_d["__all__"] = True
197
+ else:
198
+ done_d["__all__"] = False
199
+ truncated_d["__all__"] = False
200
+ break
201
+
202
+ return obs_d, rew_d, done_d, truncated_d, info_d
203
+
204
+ def close(self):
205
+ self.env.close()
206
+
207
+ def seed(self, seed=None):
208
+ self.env.seed(seed)
209
+
210
+ def render(self, mode="human"):
211
+ return self.env.render(mode)
212
+
213
+ @property
214
+ def agent_selection(self):
215
+ return self.env.agent_selection
216
+
217
+ @property
218
+ def get_sub_environments(self):
219
+ return self.env.unwrapped
220
+
221
+ def get_state(self):
222
+ state = copy.deepcopy(self.env)
223
+ return state
224
+
225
+ def set_state(self, state):
226
+ self.env = copy.deepcopy(state)
227
+ return self.env.observe(self.env.agent_selection)
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/pettingzoo_connect4.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import Dict, Any
3
+
4
+ from pettingzoo import AECEnv
5
+ from pettingzoo.classic.connect_four_v3 import raw_env as connect_four_v3
6
+
7
+ from ray.rllib.env.multi_agent_env import MultiAgentEnv
8
+
9
+
10
+ class MultiAgentConnect4(MultiAgentEnv):
11
+ """An interface to the PettingZoo MARL environment library.
12
+ See: https://github.com/Farama-Foundation/PettingZoo
13
+ Inherits from MultiAgentEnv and exposes a given AEC
14
+ (actor-environment-cycle) game from the PettingZoo project via the
15
+ MultiAgentEnv public API.
16
+ Note that the wrapper has some important limitations:
17
+ 1. All agents have the same action_spaces and observation_spaces.
18
+ Note: If, within your aec game, agents do not have homogeneous action /
19
+ observation spaces, apply SuperSuit wrappers
20
+ to apply padding functionality: https://github.com/Farama-Foundation/
21
+ SuperSuit#built-in-multi-agent-only-functions
22
+ 2. Environments are positive sum games (-> Agents are expected to cooperate
23
+ to maximize reward). This isn't a hard restriction, it just that
24
+ standard algorithms aren't expected to work well in highly competitive
25
+ games.
26
+
27
+ .. testcode::
28
+ :skipif: True
29
+
30
+ from pettingzoo.butterfly import prison_v3
31
+ from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
32
+ env = PettingZooEnv(prison_v3.env())
33
+ obs = env.reset()
34
+ print(obs)
35
+
36
+ .. testoutput::
37
+
38
+ # only returns the observation for the agent which should be stepping
39
+ {
40
+ 'prisoner_0': array([[[0, 0, 0],
41
+ [0, 0, 0],
42
+ [0, 0, 0],
43
+ ...,
44
+ [0, 0, 0],
45
+ [0, 0, 0],
46
+ [0, 0, 0]]], dtype=uint8)
47
+ }
48
+
49
+ .. testcode::
50
+ :skipif: True
51
+
52
+ obs, rewards, dones, infos = env.step({
53
+ "prisoner_0": 1
54
+ })
55
+ # only returns the observation, reward, info, etc, for
56
+ # the agent who's turn is next.
57
+ print(obs)
58
+
59
+ .. testoutput::
60
+
61
+ {
62
+ 'prisoner_1': array([[[0, 0, 0],
63
+ [0, 0, 0],
64
+ [0, 0, 0],
65
+ ...,
66
+ [0, 0, 0],
67
+ [0, 0, 0],
68
+ [0, 0, 0]]], dtype=uint8)
69
+ }
70
+
71
+ .. testcode::
72
+ :skipif: True
73
+
74
+ print(rewards)
75
+
76
+ .. testoutput::
77
+
78
+ {
79
+ 'prisoner_1': 0
80
+ }
81
+
82
+ .. testcode::
83
+ :skipif: True
84
+
85
+ print(dones)
86
+
87
+ .. testoutput::
88
+
89
+ {
90
+ 'prisoner_1': False, '__all__': False
91
+ }
92
+
93
+ .. testcode::
94
+ :skipif: True
95
+
96
+ print(infos)
97
+
98
+ .. testoutput::
99
+
100
+ {
101
+ 'prisoner_1': {'map_tuple': (1, 0)}
102
+ }
103
+ """
104
+
105
+ def __init__(
106
+ self,
107
+ config: Dict[Any, Any] = None,
108
+ env: AECEnv = None,
109
+ ):
110
+ super().__init__()
111
+ if env is None:
112
+ self.env = connect_four_v3()
113
+ else:
114
+ self.env = env
115
+ self.env.reset()
116
+
117
+ self.config = config
118
+ # Get first observation space, assuming all agents have equal space
119
+ self.observation_space = self.env.observation_space(self.env.agents[0])
120
+
121
+ # Get first action space, assuming all agents have equal space
122
+ self.action_space = self.env.action_space(self.env.agents[0])
123
+
124
+ assert all(
125
+ self.env.observation_space(agent) == self.observation_space
126
+ for agent in self.env.agents
127
+ ), (
128
+ "Observation spaces for all agents must be identical. Perhaps "
129
+ "SuperSuit's pad_observations wrapper can help (useage: "
130
+ "`supersuit.aec_wrappers.pad_observations(env)`"
131
+ )
132
+
133
+ assert all(
134
+ self.env.action_space(agent) == self.action_space
135
+ for agent in self.env.agents
136
+ ), (
137
+ "Action spaces for all agents must be identical. Perhaps "
138
+ "SuperSuit's pad_action_space wrapper can help (usage: "
139
+ "`supersuit.aec_wrappers.pad_action_space(env)`"
140
+ )
141
+ self._agent_ids = set(self.env.agents)
142
+
143
+ def observe(self):
144
+ return {
145
+ self.env.agent_selection: self.env.observe(self.env.agent_selection),
146
+ "state": self.get_state(),
147
+ }
148
+
149
+ def reset(self, *args, **kwargs):
150
+ self.env.reset()
151
+ return (
152
+ {self.env.agent_selection: self.env.observe(self.env.agent_selection)},
153
+ {self.env.agent_selection: {}},
154
+ )
155
+
156
+ def step(self, action):
157
+ try:
158
+ self.env.step(action[self.env.agent_selection])
159
+ except (KeyError, IndexError):
160
+ self.env.step(action)
161
+ except AssertionError:
162
+ # Illegal action
163
+ print(action)
164
+ raise AssertionError("Illegal action")
165
+
166
+ obs_d = {}
167
+ rew_d = {}
168
+ done_d = {}
169
+ trunc_d = {}
170
+ info_d = {}
171
+ while self.env.agents:
172
+ obs, rew, done, trunc, info = self.env.last()
173
+ a = self.env.agent_selection
174
+ obs_d[a] = obs
175
+ rew_d[a] = rew
176
+ done_d[a] = done
177
+ trunc_d[a] = trunc
178
+ info_d[a] = info
179
+ if self.env.terminations[self.env.agent_selection]:
180
+ self.env.step(None)
181
+ done_d["__all__"] = True
182
+ trunc_d["__all__"] = True
183
+ else:
184
+ done_d["__all__"] = False
185
+ trunc_d["__all__"] = False
186
+ break
187
+
188
+ return obs_d, rew_d, done_d, trunc_d, info_d
189
+
190
+ def close(self):
191
+ self.env.close()
192
+
193
+ def seed(self, seed=None):
194
+ self.env.seed(seed)
195
+
196
+ def render(self, mode="human"):
197
+ return self.env.render(mode)
198
+
199
+ @property
200
+ def agent_selection(self):
201
+ return self.env.agent_selection
202
+
203
+ @property
204
+ def get_sub_environments(self):
205
+ return self.env.unwrapped
206
+
207
+ def get_state(self):
208
+ state = copy.deepcopy(self.env)
209
+ return state
210
+
211
+ def set_state(self, state):
212
+ self.env = copy.deepcopy(state)
213
+ return self.env.observe(self.env.agent_selection)
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/rock_paper_scissors.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # __sphinx_doc_1_begin__
2
+ import gymnasium as gym
3
+
4
+ from ray.rllib.env.multi_agent_env import MultiAgentEnv
5
+
6
+
7
+ class RockPaperScissors(MultiAgentEnv):
8
+ """Two-player environment for the famous rock paper scissors game.
9
+
10
+ # __sphinx_doc_1_end__
11
+ Optionally, the "Sheldon Cooper extension" can be activated by passing
12
+ `sheldon_cooper_mode=True` into the constructor, in which case two more moves
13
+ are allowed: Spock and Lizard. Spock is poisoned by Lizard, disproven by Paper, but
14
+ crushes Rock and smashes Scissors. Lizard poisons Spock and eats Paper, but is
15
+ decapitated by Scissors and crushed by Rock.
16
+
17
+ # __sphinx_doc_2_begin__
18
+ Both players always move simultaneously over a course of 10 timesteps in total.
19
+ The winner of each timestep receives reward of +1, the losing player -1.0.
20
+
21
+ The observation of each player is the last opponent action.
22
+ """
23
+
24
+ ROCK = 0
25
+ PAPER = 1
26
+ SCISSORS = 2
27
+ LIZARD = 3
28
+ SPOCK = 4
29
+
30
+ WIN_MATRIX = {
31
+ (ROCK, ROCK): (0, 0),
32
+ (ROCK, PAPER): (-1, 1),
33
+ (ROCK, SCISSORS): (1, -1),
34
+ (PAPER, ROCK): (1, -1),
35
+ (PAPER, PAPER): (0, 0),
36
+ (PAPER, SCISSORS): (-1, 1),
37
+ (SCISSORS, ROCK): (-1, 1),
38
+ (SCISSORS, PAPER): (1, -1),
39
+ (SCISSORS, SCISSORS): (0, 0),
40
+ }
41
+ # __sphinx_doc_2_end__
42
+
43
+ WIN_MATRIX.update(
44
+ {
45
+ # Sheldon Cooper mode:
46
+ (LIZARD, LIZARD): (0, 0),
47
+ (LIZARD, SPOCK): (1, -1), # Lizard poisons Spock
48
+ (LIZARD, ROCK): (-1, 1), # Rock crushes lizard
49
+ (LIZARD, PAPER): (1, -1), # Lizard eats paper
50
+ (LIZARD, SCISSORS): (-1, 1), # Scissors decapitate lizard
51
+ (ROCK, LIZARD): (1, -1), # Rock crushes lizard
52
+ (PAPER, LIZARD): (-1, 1), # Lizard eats paper
53
+ (SCISSORS, LIZARD): (1, -1), # Scissors decapitate lizard
54
+ (SPOCK, SPOCK): (0, 0),
55
+ (SPOCK, LIZARD): (-1, 1), # Lizard poisons Spock
56
+ (SPOCK, ROCK): (1, -1), # Spock vaporizes rock
57
+ (SPOCK, PAPER): (-1, 1), # Paper disproves Spock
58
+ (SPOCK, SCISSORS): (1, -1), # Spock smashes scissors
59
+ (ROCK, SPOCK): (-1, 1), # Spock vaporizes rock
60
+ (PAPER, SPOCK): (1, -1), # Paper disproves Spock
61
+ (SCISSORS, SPOCK): (-1, 1), # Spock smashes scissors
62
+ }
63
+ )
64
+
65
+ # __sphinx_doc_3_begin__
66
+ def __init__(self, config=None):
67
+ super().__init__()
68
+
69
+ self.agents = self.possible_agents = ["player1", "player2"]
70
+
71
+ # The observations are always the last taken actions. Hence observation- and
72
+ # action spaces are identical.
73
+ self.observation_spaces = self.action_spaces = {
74
+ "player1": gym.spaces.Discrete(3),
75
+ "player2": gym.spaces.Discrete(3),
76
+ }
77
+ self.last_move = None
78
+ self.num_moves = 0
79
+ # __sphinx_doc_3_end__
80
+
81
+ self.sheldon_cooper_mode = False
82
+ if config.get("sheldon_cooper_mode"):
83
+ self.sheldon_cooper_mode = True
84
+ self.action_spaces = self.observation_spaces = {
85
+ "player1": gym.spaces.Discrete(5),
86
+ "player2": gym.spaces.Discrete(5),
87
+ }
88
+
89
+ # __sphinx_doc_4_begin__
90
+ def reset(self, *, seed=None, options=None):
91
+ self.num_moves = 0
92
+
93
+ # The first observation should not matter (none of the agents has moved yet).
94
+ # Set them to 0.
95
+ return {
96
+ "player1": 0,
97
+ "player2": 0,
98
+ }, {} # <- empty infos dict
99
+
100
+ # __sphinx_doc_4_end__
101
+
102
+ # __sphinx_doc_5_begin__
103
+ def step(self, action_dict):
104
+ self.num_moves += 1
105
+
106
+ move1 = action_dict["player1"]
107
+ move2 = action_dict["player2"]
108
+
109
+ # Set the next observations (simply use the other player's action).
110
+ # Note that because we are publishing both players in the observations dict,
111
+ # we expect both players to act in the next `step()` (simultaneous stepping).
112
+ observations = {"player1": move2, "player2": move1}
113
+
114
+ # Compute rewards for each player based on the win-matrix.
115
+ r1, r2 = self.WIN_MATRIX[move1, move2]
116
+ rewards = {"player1": r1, "player2": r2}
117
+
118
+ # Terminate the entire episode (for all agents) once 10 moves have been made.
119
+ terminateds = {"__all__": self.num_moves >= 10}
120
+
121
+ # Leave truncateds and infos empty.
122
+ return observations, rewards, terminateds, {}, {}
123
+
124
+
125
+ # __sphinx_doc_5_end__
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/tic_tac_toe.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # __sphinx_doc_1_begin__
2
+ import gymnasium as gym
3
+ import numpy as np
4
+
5
+ from ray.rllib.env.multi_agent_env import MultiAgentEnv
6
+
7
+
8
+ class TicTacToe(MultiAgentEnv):
9
+ """A two-player game in which any player tries to complete one row in a 3x3 field.
10
+
11
+ The observation space is Box(-1.0, 1.0, (9,)), where each index represents a distinct
12
+ field on a 3x3 board and values of 0.0 mean the field is empty, -1.0 means
13
+ the opponent owns the field, and 1.0 means we occupy the field:
14
+ ----------
15
+ | 0| 1| 2|
16
+ ----------
17
+ | 3| 4| 5|
18
+ ----------
19
+ | 6| 7| 8|
20
+ ----------
21
+
22
+ The action space is Discrete(9) and actions landing on an already occupied field
23
+ are simply ignored (and thus useless to the player taking these actions).
24
+
25
+ Once a player completes a row, they receive +5.0 reward, the losing player receives
26
+ -5.0 reward. In all other cases, both players receive 0.0 reward.
27
+ """
28
+
29
+ # __sphinx_doc_1_end__
30
+
31
+ # __sphinx_doc_2_begin__
32
+ def __init__(self, config=None):
33
+ super().__init__()
34
+
35
+ # Define the agents in the game.
36
+ self.agents = self.possible_agents = ["player1", "player2"]
37
+
38
+ # Each agent observes a 9D tensor, representing the 3x3 fields of the board.
39
+ # A 0 means an empty field, a 1 represents a piece of player 1, a -1 a piece of
40
+ # player 2.
41
+ self.observation_spaces = {
42
+ "player1": gym.spaces.Box(-1.0, 1.0, (9,), np.float32),
43
+ "player2": gym.spaces.Box(-1.0, 1.0, (9,), np.float32),
44
+ }
45
+ # Each player has 9 actions, encoding the 9 fields each player can place a piece
46
+ # on during their turn.
47
+ self.action_spaces = {
48
+ "player1": gym.spaces.Discrete(9),
49
+ "player2": gym.spaces.Discrete(9),
50
+ }
51
+
52
+ self.board = None
53
+ self.current_player = None
54
+
55
+ # __sphinx_doc_2_end__
56
+
57
+ # __sphinx_doc_3_begin__
58
+ def reset(self, *, seed=None, options=None):
59
+ self.board = [
60
+ 0,
61
+ 0,
62
+ 0,
63
+ 0,
64
+ 0,
65
+ 0,
66
+ 0,
67
+ 0,
68
+ 0,
69
+ ]
70
+ # Pick a random player to start the game.
71
+ self.current_player = np.random.choice(["player1", "player2"])
72
+ # Return observations dict (only with the starting player, which is the one
73
+ # we expect to act next).
74
+ return {
75
+ self.current_player: np.array(self.board, np.float32),
76
+ }, {}
77
+
78
+ # __sphinx_doc_3_end__
79
+
80
+ # __sphinx_doc_4_begin__
81
+ def step(self, action_dict):
82
+ action = action_dict[self.current_player]
83
+
84
+ # Create a rewards-dict (containing the rewards of the agent that just acted).
85
+ rewards = {self.current_player: 0.0}
86
+ # Create a terminateds-dict with the special `__all__` agent ID, indicating that
87
+ # if True, the episode ends for all agents.
88
+ terminateds = {"__all__": False}
89
+
90
+ opponent = "player1" if self.current_player == "player2" else "player2"
91
+
92
+ # Penalize trying to place a piece on an already occupied field.
93
+ if self.board[action] != 0:
94
+ rewards[self.current_player] -= 5.0
95
+ # Change the board according to the (valid) action taken.
96
+ else:
97
+ self.board[action] = 1 if self.current_player == "player1" else -1
98
+
99
+ # After having placed a new piece, figure out whether the current player
100
+ # won or not.
101
+ if self.current_player == "player1":
102
+ win_val = [1, 1, 1]
103
+ else:
104
+ win_val = [-1, -1, -1]
105
+ if (
106
+ # Horizontal win.
107
+ self.board[:3] == win_val
108
+ or self.board[3:6] == win_val
109
+ or self.board[6:] == win_val
110
+ # Vertical win.
111
+ or self.board[0:7:3] == win_val
112
+ or self.board[1:8:3] == win_val
113
+ or self.board[2:9:3] == win_val
114
+ # Diagonal win.
115
+ or self.board[::3] == win_val
116
+ or self.board[2:7:2] == win_val
117
+ ):
118
+ # Final reward is +5 for victory and -5 for a loss.
119
+ rewards[self.current_player] += 5.0
120
+ rewards[opponent] = -5.0
121
+
122
+ # Episode is done and needs to be reset for a new game.
123
+ terminateds["__all__"] = True
124
+
125
+ # The board might also be full w/o any player having won/lost.
126
+ # In this case, we simply end the episode and none of the players receives
127
+ # +5 or -5 reward.
128
+ elif 0 not in self.board:
129
+ terminateds["__all__"] = True
130
+
131
+ # Flip players and return an observations dict with only the next player to
132
+ # make a move in it.
133
+ self.current_player = opponent
134
+
135
+ return (
136
+ {self.current_player: np.array(self.board, np.float32)},
137
+ rewards,
138
+ terminateds,
139
+ {},
140
+ {},
141
+ )
142
+
143
+
144
+ # __sphinx_doc_4_end__
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/multi_agent/two_step_game.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gymnasium.spaces import Dict, Discrete, MultiDiscrete, Tuple
2
+ import numpy as np
3
+
4
+ from ray.rllib.env.multi_agent_env import MultiAgentEnv, ENV_STATE
5
+
6
+
7
+ class TwoStepGame(MultiAgentEnv):
8
+ action_space = Discrete(2)
9
+
10
+ def __init__(self, env_config):
11
+ super().__init__()
12
+ self.action_space = Discrete(2)
13
+ self.state = None
14
+ self.agent_1 = 0
15
+ self.agent_2 = 1
16
+ # MADDPG emits action logits instead of actual discrete actions
17
+ self.actions_are_logits = env_config.get("actions_are_logits", False)
18
+ self.one_hot_state_encoding = env_config.get("one_hot_state_encoding", False)
19
+ self.with_state = env_config.get("separate_state_space", False)
20
+ self._agent_ids = {0, 1}
21
+ if not self.one_hot_state_encoding:
22
+ self.observation_space = Discrete(6)
23
+ self.with_state = False
24
+ else:
25
+ # Each agent gets the full state (one-hot encoding of which of the
26
+ # three states are active) as input with the receiving agent's
27
+ # ID (1 or 2) concatenated onto the end.
28
+ if self.with_state:
29
+ self.observation_space = Dict(
30
+ {
31
+ "obs": MultiDiscrete([2, 2, 2, 3]),
32
+ ENV_STATE: MultiDiscrete([2, 2, 2]),
33
+ }
34
+ )
35
+ else:
36
+ self.observation_space = MultiDiscrete([2, 2, 2, 3])
37
+
38
+ def reset(self, *, seed=None, options=None):
39
+ if seed is not None:
40
+ np.random.seed(seed)
41
+ self.state = np.array([1, 0, 0])
42
+ return self._obs(), {}
43
+
44
+ def step(self, action_dict):
45
+ if self.actions_are_logits:
46
+ action_dict = {
47
+ k: np.random.choice([0, 1], p=v) for k, v in action_dict.items()
48
+ }
49
+
50
+ state_index = np.flatnonzero(self.state)
51
+ if state_index == 0:
52
+ action = action_dict[self.agent_1]
53
+ assert action in [0, 1], action
54
+ if action == 0:
55
+ self.state = np.array([0, 1, 0])
56
+ else:
57
+ self.state = np.array([0, 0, 1])
58
+ global_rew = 0
59
+ terminated = False
60
+ elif state_index == 1:
61
+ global_rew = 7
62
+ terminated = True
63
+ else:
64
+ if action_dict[self.agent_1] == 0 and action_dict[self.agent_2] == 0:
65
+ global_rew = 0
66
+ elif action_dict[self.agent_1] == 1 and action_dict[self.agent_2] == 1:
67
+ global_rew = 8
68
+ else:
69
+ global_rew = 1
70
+ terminated = True
71
+
72
+ rewards = {self.agent_1: global_rew / 2.0, self.agent_2: global_rew / 2.0}
73
+ obs = self._obs()
74
+ terminateds = {"__all__": terminated}
75
+ truncateds = {"__all__": False}
76
+ infos = {
77
+ self.agent_1: {"done": terminateds["__all__"]},
78
+ self.agent_2: {"done": terminateds["__all__"]},
79
+ }
80
+ return obs, rewards, terminateds, truncateds, infos
81
+
82
+ def _obs(self):
83
+ if self.with_state:
84
+ return {
85
+ self.agent_1: {"obs": self.agent_1_obs(), ENV_STATE: self.state},
86
+ self.agent_2: {"obs": self.agent_2_obs(), ENV_STATE: self.state},
87
+ }
88
+ else:
89
+ return {self.agent_1: self.agent_1_obs(), self.agent_2: self.agent_2_obs()}
90
+
91
+ def agent_1_obs(self):
92
+ if self.one_hot_state_encoding:
93
+ return np.concatenate([self.state, [1]])
94
+ else:
95
+ return np.flatnonzero(self.state)[0]
96
+
97
+ def agent_2_obs(self):
98
+ if self.one_hot_state_encoding:
99
+ return np.concatenate([self.state, [2]])
100
+ else:
101
+ return np.flatnonzero(self.state)[0] + 3
102
+
103
+
104
+ class TwoStepGameWithGroupedAgents(MultiAgentEnv):
105
+ def __init__(self, env_config):
106
+ super().__init__()
107
+ env = TwoStepGame(env_config)
108
+ tuple_obs_space = Tuple([env.observation_space, env.observation_space])
109
+ tuple_act_space = Tuple([env.action_space, env.action_space])
110
+ self._agent_ids = {"agents"}
111
+ self.env = env.with_agent_groups(
112
+ groups={"agents": [0, 1]},
113
+ obs_space=tuple_obs_space,
114
+ act_space=tuple_act_space,
115
+ )
116
+ self.observation_space = Dict({"agents": self.env.observation_space})
117
+ self.action_space = Dict({"agents": self.env.action_space})
118
+
119
+ def reset(self, *, seed=None, options=None):
120
+ return self.env.reset(seed=seed, options=options)
121
+
122
+ def step(self, actions):
123
+ return self.env.step(actions)
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/nested_space_repeat_after_me_env.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ from gymnasium.spaces import Box, Dict, Discrete, Tuple
3
+ import numpy as np
4
+ import tree # pip install dm_tree
5
+
6
+ from ray.rllib.utils.spaces.space_utils import flatten_space
7
+
8
+
9
+ class NestedSpaceRepeatAfterMeEnv(gym.Env):
10
+ """Env for which policy has to repeat the (possibly complex) observation.
11
+
12
+ The action space and observation spaces are always the same and may be
13
+ arbitrarily nested Dict/Tuple Spaces.
14
+ Rewards are given for exactly matching Discrete sub-actions and for being
15
+ as close as possible for Box sub-actions.
16
+ """
17
+
18
+ def __init__(self, config=None):
19
+ config = config or {}
20
+ self.observation_space = config.get(
21
+ "space", Tuple([Discrete(2), Dict({"a": Box(-1.0, 1.0, (2,))})])
22
+ )
23
+ self.action_space = self.observation_space
24
+ self.flattened_action_space = flatten_space(self.action_space)
25
+ self.episode_len = config.get("episode_len", 100)
26
+
27
+ def reset(self, *, seed=None, options=None):
28
+ self.steps = 0
29
+ return self._next_obs(), {}
30
+
31
+ def step(self, action):
32
+ self.steps += 1
33
+ action = tree.flatten(action)
34
+ reward = 0.0
35
+ for a, o, space in zip(
36
+ action, self.current_obs_flattened, self.flattened_action_space
37
+ ):
38
+ # Box: -abs(diff).
39
+ if isinstance(space, gym.spaces.Box):
40
+ reward -= np.sum(np.abs(a - o))
41
+ # Discrete: +1.0 if exact match.
42
+ if isinstance(space, gym.spaces.Discrete):
43
+ reward += 1.0 if a == o else 0.0
44
+ truncated = self.steps >= self.episode_len
45
+ return self._next_obs(), reward, False, truncated, {}
46
+
47
+ def _next_obs(self):
48
+ self.current_obs = self.observation_space.sample()
49
+ self.current_obs_flattened = tree.flatten(self.current_obs)
50
+ return self.current_obs
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/parametric_actions_cartpole.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import gymnasium as gym
4
+ import numpy as np
5
+ from gymnasium.spaces import Box, Dict, Discrete
6
+
7
+
8
+ class ParametricActionsCartPole(gym.Env):
9
+ """Parametric action version of CartPole.
10
+
11
+ In this env there are only ever two valid actions, but we pretend there are
12
+ actually up to `max_avail_actions` actions that can be taken, and the two
13
+ valid actions are randomly hidden among this set.
14
+
15
+ At each step, we emit a dict of:
16
+ - the actual cart observation
17
+ - a mask of valid actions (e.g., [0, 0, 1, 0, 0, 1] for 6 max avail)
18
+ - the list of action embeddings (w/ zeroes for invalid actions) (e.g.,
19
+ [[0, 0],
20
+ [0, 0],
21
+ [-0.2322, -0.2569],
22
+ [0, 0],
23
+ [0, 0],
24
+ [0.7878, 1.2297]] for max_avail_actions=6)
25
+
26
+ In a real environment, the actions embeddings would be larger than two
27
+ units of course, and also there would be a variable number of valid actions
28
+ per step instead of always [LEFT, RIGHT].
29
+ """
30
+
31
+ def __init__(self, max_avail_actions):
32
+ # Use simple random 2-unit action embeddings for [LEFT, RIGHT]
33
+ self.left_action_embed = np.random.randn(2)
34
+ self.right_action_embed = np.random.randn(2)
35
+ self.action_space = Discrete(max_avail_actions)
36
+ self.wrapped = gym.make("CartPole-v1")
37
+ self.observation_space = Dict(
38
+ {
39
+ "action_mask": Box(0, 1, shape=(max_avail_actions,), dtype=np.int8),
40
+ "avail_actions": Box(-10, 10, shape=(max_avail_actions, 2)),
41
+ "cart": self.wrapped.observation_space,
42
+ }
43
+ )
44
+
45
+ def update_avail_actions(self):
46
+ self.action_assignments = np.array(
47
+ [[0.0, 0.0]] * self.action_space.n, dtype=np.float32
48
+ )
49
+ self.action_mask = np.array([0.0] * self.action_space.n, dtype=np.int8)
50
+ self.left_idx, self.right_idx = random.sample(range(self.action_space.n), 2)
51
+ self.action_assignments[self.left_idx] = self.left_action_embed
52
+ self.action_assignments[self.right_idx] = self.right_action_embed
53
+ self.action_mask[self.left_idx] = 1
54
+ self.action_mask[self.right_idx] = 1
55
+
56
+ def reset(self, *, seed=None, options=None):
57
+ self.update_avail_actions()
58
+ obs, infos = self.wrapped.reset()
59
+ return {
60
+ "action_mask": self.action_mask,
61
+ "avail_actions": self.action_assignments,
62
+ "cart": obs,
63
+ }, infos
64
+
65
+ def step(self, action):
66
+ if action == self.left_idx:
67
+ actual_action = 0
68
+ elif action == self.right_idx:
69
+ actual_action = 1
70
+ else:
71
+ raise ValueError(
72
+ "Chosen action was not one of the non-zero action embeddings",
73
+ action,
74
+ self.action_assignments,
75
+ self.action_mask,
76
+ self.left_idx,
77
+ self.right_idx,
78
+ )
79
+ orig_obs, rew, done, truncated, info = self.wrapped.step(actual_action)
80
+ self.update_avail_actions()
81
+ self.action_mask = self.action_mask.astype(np.int8)
82
+ obs = {
83
+ "action_mask": self.action_mask,
84
+ "avail_actions": self.action_assignments,
85
+ "cart": orig_obs,
86
+ }
87
+ return obs, rew, done, truncated, info
88
+
89
+
90
+ class ParametricActionsCartPoleNoEmbeddings(gym.Env):
91
+ """Same as the above ParametricActionsCartPole.
92
+
93
+ However, action embeddings are not published inside observations,
94
+ but will be learnt by the model.
95
+
96
+ At each step, we emit a dict of:
97
+ - the actual cart observation
98
+ - a mask of valid actions (e.g., [0, 0, 1, 0, 0, 1] for 6 max avail)
99
+ - action embeddings (w/ "dummy embedding" for invalid actions) are
100
+ outsourced in the model and will be learned.
101
+ """
102
+
103
+ def __init__(self, max_avail_actions):
104
+ # Randomly set which two actions are valid and available.
105
+ self.left_idx, self.right_idx = random.sample(range(max_avail_actions), 2)
106
+ self.valid_avail_actions_mask = np.array(
107
+ [0.0] * max_avail_actions, dtype=np.int8
108
+ )
109
+ self.valid_avail_actions_mask[self.left_idx] = 1
110
+ self.valid_avail_actions_mask[self.right_idx] = 1
111
+ self.action_space = Discrete(max_avail_actions)
112
+ self.wrapped = gym.make("CartPole-v1")
113
+ self.observation_space = Dict(
114
+ {
115
+ "valid_avail_actions_mask": Box(0, 1, shape=(max_avail_actions,)),
116
+ "cart": self.wrapped.observation_space,
117
+ }
118
+ )
119
+
120
+ def reset(self, *, seed=None, options=None):
121
+ obs, infos = self.wrapped.reset()
122
+ return {
123
+ "valid_avail_actions_mask": self.valid_avail_actions_mask,
124
+ "cart": obs,
125
+ }, infos
126
+
127
+ def step(self, action):
128
+ if action == self.left_idx:
129
+ actual_action = 0
130
+ elif action == self.right_idx:
131
+ actual_action = 1
132
+ else:
133
+ raise ValueError(
134
+ "Chosen action was not one of the non-zero action embeddings",
135
+ action,
136
+ self.valid_avail_actions_mask,
137
+ self.left_idx,
138
+ self.right_idx,
139
+ )
140
+ orig_obs, rew, done, truncated, info = self.wrapped.step(actual_action)
141
+ obs = {
142
+ "valid_avail_actions_mask": self.valid_avail_actions_mask,
143
+ "cart": orig_obs,
144
+ }
145
+ return obs, rew, done, truncated, info
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/random_env.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import gymnasium as gym
3
+ from gymnasium.spaces import Discrete, Tuple
4
+ import numpy as np
5
+
6
+ from ray.rllib.examples.envs.classes.multi_agent import make_multi_agent
7
+
8
+
9
+ class RandomEnv(gym.Env):
10
+ """A randomly acting environment.
11
+
12
+ Can be instantiated with arbitrary action-, observation-, and reward
13
+ spaces. Observations and rewards are generated by simply sampling from the
14
+ observation/reward spaces. The probability of a `terminated=True` after each
15
+ action can be configured, as well as the max episode length.
16
+ """
17
+
18
+ def __init__(self, config=None):
19
+ config = config or {}
20
+
21
+ # Action space.
22
+ self.action_space = config.get("action_space", Discrete(2))
23
+ # Observation space from which to sample.
24
+ self.observation_space = config.get("observation_space", Discrete(2))
25
+ # Reward space from which to sample.
26
+ self.reward_space = config.get(
27
+ "reward_space",
28
+ gym.spaces.Box(low=-1.0, high=1.0, shape=(), dtype=np.float32),
29
+ )
30
+ self.static_samples = config.get("static_samples", False)
31
+ if self.static_samples:
32
+ self.observation_sample = self.observation_space.sample()
33
+ self.reward_sample = self.reward_space.sample()
34
+
35
+ # Chance that an episode ends at any step.
36
+ # Note that a max episode length can be specified via
37
+ # `max_episode_len`.
38
+ self.p_terminated = config.get("p_terminated")
39
+ if self.p_terminated is None:
40
+ self.p_terminated = config.get("p_done", 0.1)
41
+ # A max episode length. Even if the `p_terminated` sampling does not lead
42
+ # to a terminus, the episode will end after at most this many
43
+ # timesteps.
44
+ # Set to 0 or None for using no limit on the episode length.
45
+ self.max_episode_len = config.get("max_episode_len", None)
46
+ # Whether to check action bounds.
47
+ self.check_action_bounds = config.get("check_action_bounds", False)
48
+ # Steps taken so far (after last reset).
49
+ self.steps = 0
50
+
51
+ def reset(self, *, seed=None, options=None):
52
+ self.steps = 0
53
+ if not self.static_samples:
54
+ return self.observation_space.sample(), {}
55
+ else:
56
+ return copy.deepcopy(self.observation_sample), {}
57
+
58
+ def step(self, action):
59
+ if self.check_action_bounds and not self.action_space.contains(action):
60
+ raise ValueError(
61
+ "Illegal action for {}: {}".format(self.action_space, action)
62
+ )
63
+ if isinstance(self.action_space, Tuple) and len(action) != len(
64
+ self.action_space.spaces
65
+ ):
66
+ raise ValueError(
67
+ "Illegal action for {}: {}".format(self.action_space, action)
68
+ )
69
+
70
+ self.steps += 1
71
+ terminated = False
72
+ truncated = False
73
+ # We are `truncated` as per our max-episode-len.
74
+ if self.max_episode_len and self.steps >= self.max_episode_len:
75
+ truncated = True
76
+ # Max episode length not reached yet -> Sample `terminated` via `p_terminated`.
77
+ elif self.p_terminated > 0.0:
78
+ terminated = bool(
79
+ np.random.choice(
80
+ [True, False], p=[self.p_terminated, 1.0 - self.p_terminated]
81
+ )
82
+ )
83
+
84
+ if not self.static_samples:
85
+ return (
86
+ self.observation_space.sample(),
87
+ self.reward_space.sample(),
88
+ terminated,
89
+ truncated,
90
+ {},
91
+ )
92
+ else:
93
+ return (
94
+ copy.deepcopy(self.observation_sample),
95
+ copy.deepcopy(self.reward_sample),
96
+ terminated,
97
+ truncated,
98
+ {},
99
+ )
100
+
101
+
102
+ # Multi-agent version of the RandomEnv.
103
+ RandomMultiAgentEnv = make_multi_agent(lambda c: RandomEnv(c))
104
+
105
+
106
+ # Large observation space "pre-compiled" random env (for testing).
107
+ class RandomLargeObsSpaceEnv(RandomEnv):
108
+ def __init__(self, config=None):
109
+ config = config or {}
110
+ config.update({"observation_space": gym.spaces.Box(-1.0, 1.0, (5000,))})
111
+ super().__init__(config=config)
112
+
113
+
114
+ # Large observation space + cont. actions "pre-compiled" random env
115
+ # (for testing).
116
+ class RandomLargeObsSpaceEnvContActions(RandomEnv):
117
+ def __init__(self, config=None):
118
+ config = config or {}
119
+ config.update(
120
+ {
121
+ "observation_space": gym.spaces.Box(-1.0, 1.0, (5000,)),
122
+ "action_space": gym.spaces.Box(-1.0, 1.0, (5,)),
123
+ }
124
+ )
125
+ super().__init__(config=config)
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/recommender_system_envs_with_recsim.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Examples for RecSim envs ready to be used by RLlib Algorithms.
2
+
3
+ RecSim is a configurable recommender systems simulation platform.
4
+ Source: https://github.com/google-research/recsim
5
+ """
6
+
7
+ from recsim import choice_model
8
+ from recsim.environments import (
9
+ long_term_satisfaction as lts,
10
+ interest_evolution as iev,
11
+ interest_exploration as iex,
12
+ )
13
+
14
+ from ray.rllib.env.wrappers.recsim import make_recsim_env
15
+ from ray.tune import register_env
16
+
17
+ # Some built-in RecSim envs to test with.
18
+ # ---------------------------------------
19
+
20
+ # Long-term satisfaction env: User has to pick from items that are either
21
+ # a) unhealthy, but taste good, or b) healthy, but have bad taste.
22
+ # Best strategy is to pick a mix of both to ensure long-term
23
+ # engagement.
24
+
25
+
26
+ def lts_user_model_creator(env_ctx):
27
+ return lts.LTSUserModel(
28
+ env_ctx["slate_size"],
29
+ user_state_ctor=lts.LTSUserState,
30
+ response_model_ctor=lts.LTSResponse,
31
+ )
32
+
33
+
34
+ def lts_document_sampler_creator(env_ctx):
35
+ return lts.LTSDocumentSampler()
36
+
37
+
38
+ LongTermSatisfactionRecSimEnv = make_recsim_env(
39
+ recsim_user_model_creator=lts_user_model_creator,
40
+ recsim_document_sampler_creator=lts_document_sampler_creator,
41
+ reward_aggregator=lts.clicked_engagement_reward,
42
+ )
43
+
44
+
45
+ # Interest exploration env: Models the problem of active exploration
46
+ # of user interests. It is meant to illustrate popularity bias in
47
+ # recommender systems, where myopic maximization of engagement leads
48
+ # to bias towards documents that have wider appeal,
49
+ # whereas niche user interests remain unexplored.
50
+ def iex_user_model_creator(env_ctx):
51
+ return iex.IEUserModel(
52
+ env_ctx["slate_size"],
53
+ user_state_ctor=iex.IEUserState,
54
+ response_model_ctor=iex.IEResponse,
55
+ seed=env_ctx["seed"],
56
+ )
57
+
58
+
59
+ def iex_document_sampler_creator(env_ctx):
60
+ return iex.IETopicDocumentSampler(seed=env_ctx["seed"])
61
+
62
+
63
+ InterestExplorationRecSimEnv = make_recsim_env(
64
+ recsim_user_model_creator=iex_user_model_creator,
65
+ recsim_document_sampler_creator=iex_document_sampler_creator,
66
+ reward_aggregator=iex.total_clicks_reward,
67
+ )
68
+
69
+
70
+ # Interest evolution env: See https://github.com/google-research/recsim
71
+ # for more information.
72
+ def iev_user_model_creator(env_ctx):
73
+ return iev.IEvUserModel(
74
+ env_ctx["slate_size"],
75
+ choice_model_ctor=choice_model.MultinomialProportionalChoiceModel,
76
+ response_model_ctor=iev.IEvResponse,
77
+ user_state_ctor=iev.IEvUserState,
78
+ seed=env_ctx["seed"],
79
+ )
80
+
81
+
82
+ # Extend IEvVideo to fix a bug caused by None cluster_ids.
83
+ class SingleClusterIEvVideo(iev.IEvVideo):
84
+ def __init__(self, doc_id, features, video_length=None, quality=None):
85
+ super(SingleClusterIEvVideo, self).__init__(
86
+ doc_id=doc_id,
87
+ features=features,
88
+ cluster_id=0, # single cluster.
89
+ video_length=video_length,
90
+ quality=quality,
91
+ )
92
+
93
+
94
+ def iev_document_sampler_creator(env_ctx):
95
+ return iev.UtilityModelVideoSampler(doc_ctor=iev.IEvVideo, seed=env_ctx["seed"])
96
+
97
+
98
+ InterestEvolutionRecSimEnv = make_recsim_env(
99
+ recsim_user_model_creator=iev_user_model_creator,
100
+ recsim_document_sampler_creator=iev_document_sampler_creator,
101
+ reward_aggregator=iev.clicked_watchtime_reward,
102
+ )
103
+
104
+
105
+ # Backward compatibility.
106
+ register_env(
107
+ name="RecSim-v1", env_creator=lambda env_ctx: InterestEvolutionRecSimEnv(env_ctx)
108
+ )
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/repeat_after_me_env.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ from gymnasium.spaces import Box, Discrete
3
+ import numpy as np
4
+
5
+
6
+ class RepeatAfterMeEnv(gym.Env):
7
+ """Env in which the observation at timestep minus n must be repeated."""
8
+
9
+ def __init__(self, config=None):
10
+ config = config or {}
11
+ if config.get("continuous"):
12
+ self.observation_space = Box(-1.0, 1.0, (2,))
13
+ else:
14
+ self.observation_space = Discrete(2)
15
+
16
+ self.action_space = self.observation_space
17
+ # Note: Set `repeat_delay` to 0 for simply repeating the seen
18
+ # observation (no delay).
19
+ self.delay = config.get("repeat_delay", 1)
20
+ self.episode_len = config.get("episode_len", 100)
21
+ self.history = []
22
+
23
+ def reset(self, *, seed=None, options=None):
24
+ self.history = [0] * self.delay
25
+ return self._next_obs(), {}
26
+
27
+ def step(self, action):
28
+ obs = self.history[-(1 + self.delay)]
29
+
30
+ reward = 0.0
31
+ # Box: -abs(diff).
32
+ if isinstance(self.action_space, Box):
33
+ reward = -np.sum(np.abs(action - obs))
34
+ # Discrete: +1.0 if exact match, -1.0 otherwise.
35
+ if isinstance(self.action_space, Discrete):
36
+ reward = 1.0 if action == obs else -1.0
37
+
38
+ done = truncated = len(self.history) > self.episode_len
39
+ return self._next_obs(), reward, done, truncated, {}
40
+
41
+ def _next_obs(self):
42
+ if isinstance(self.observation_space, Box):
43
+ token = np.random.random(size=(2,))
44
+ else:
45
+ token = np.random.choice([0, 1])
46
+ self.history.append(token)
47
+ return token
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/repeat_initial_obs_env.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ from gymnasium.spaces import Discrete
3
+ import random
4
+
5
+
6
+ class RepeatInitialObsEnv(gym.Env):
7
+ """Env in which the initial observation has to be repeated all the time.
8
+
9
+ Runs for n steps.
10
+ r=1 if action correct, -1 otherwise (max. R=100).
11
+ """
12
+
13
+ def __init__(self, episode_len=100):
14
+ self.observation_space = Discrete(2)
15
+ self.action_space = Discrete(2)
16
+ self.token = None
17
+ self.episode_len = episode_len
18
+ self.num_steps = 0
19
+
20
+ def reset(self, *, seed=None, options=None):
21
+ self.token = random.choice([0, 1])
22
+ self.num_steps = 0
23
+ return self.token, {}
24
+
25
+ def step(self, action):
26
+ if action == self.token:
27
+ reward = 1
28
+ else:
29
+ reward = -1
30
+ self.num_steps += 1
31
+ done = truncated = self.num_steps >= self.episode_len
32
+ return 0, reward, done, truncated, {}
.venv/lib/python3.11/site-packages/ray/rllib/examples/envs/classes/simple_corridor.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ from gymnasium.spaces import Box, Discrete
3
+ import numpy as np
4
+
5
+
6
+ class SimpleCorridor(gym.Env):
7
+ """Example of a custom env in which you have to walk down a corridor.
8
+
9
+ You can configure the length of the corridor via the env config."""
10
+
11
+ def __init__(self, config=None):
12
+ config = config or {}
13
+
14
+ self.action_space = Discrete(2)
15
+ self.observation_space = Box(0.0, 999.0, shape=(1,), dtype=np.float32)
16
+
17
+ self.set_corridor_length(config.get("corridor_length", 10))
18
+
19
+ self._cur_pos = 0
20
+
21
+ def set_corridor_length(self, length):
22
+ self.end_pos = length
23
+ print(f"Set corridor length to {self.end_pos}")
24
+ assert self.end_pos <= 999, "The maximum `corridor_length` allowed is 999!"
25
+
26
+ def reset(self, *, seed=None, options=None):
27
+ self._cur_pos = 0.0
28
+ return self._get_obs(), {}
29
+
30
+ def step(self, action):
31
+ assert action in [0, 1], action
32
+ if action == 0 and self._cur_pos > 0:
33
+ self._cur_pos -= 1.0
34
+ elif action == 1:
35
+ self._cur_pos += 1.0
36
+ terminated = self._cur_pos >= self.end_pos
37
+ truncated = False
38
+ reward = 1.0 if terminated else -0.01
39
+ return self._get_obs(), reward, terminated, truncated, {}
40
+
41
+ def _get_obs(self):
42
+ return np.array([self._cur_pos], np.float32)