koichi12 commited on
Commit
697a7f6
·
verified ·
1 Parent(s): b1f8d86

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_episode.cpython-311.pyc +3 -0
  3. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__init__.py +20 -0
  4. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/__init__.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/env_runner_v2.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/episode_v2.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/metrics.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/postprocessing.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/rollout_worker.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/sampler.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/simple_list_collector.py +698 -0
  12. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/env_runner_v2.py +1232 -0
  13. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/episode_v2.py +378 -0
  14. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/metrics.py +266 -0
  15. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/observation_function.py +87 -0
  16. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/postprocessing.py +328 -0
  17. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py +2004 -0
  18. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/sample_batch_builder.py +264 -0
  19. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/sampler.py +253 -0
  20. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py +10 -0
  21. .venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/__init__.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/attention_net.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/fcnet.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/mingpt.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/recurrent_net.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/torch_action_dist.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/torch_distributions.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/ray/rllib/offline/__init__.py +30 -0
  29. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/dataset_reader.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/dataset_writer.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/feature_importance.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/io_context.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/is_estimator.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/json_reader.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/json_writer.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/mixed_input.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/off_policy_estimator.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_data.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_env_runner.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_evaluation_utils.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_evaluator.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_prelearner.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/output_writer.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/wis_estimator.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/ray/rllib/offline/d4rl_reader.py +51 -0
  46. .venv/lib/python3.11/site-packages/ray/rllib/offline/dataset_reader.py +289 -0
  47. .venv/lib/python3.11/site-packages/ray/rllib/offline/dataset_writer.py +82 -0
  48. .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/__init__.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/direct_method.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/fqe_torch_model.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -178,3 +178,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
178
  .venv/lib/python3.11/site-packages/ray/_private/thirdparty/pynvml/__pycache__/pynvml.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
179
  .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/algorithm.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
180
  .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/algorithm_config.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
178
  .venv/lib/python3.11/site-packages/ray/_private/thirdparty/pynvml/__pycache__/pynvml.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
179
  .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/algorithm.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
180
  .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/algorithm_config.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
181
+ .venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_episode.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_episode.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70ee04d5ba78d502ad5d58d83cd6ec52ed3635c4af63ccc12837f71debf75e54
3
+ size 115849
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.evaluation.rollout_worker import RolloutWorker
2
+ from ray.rllib.evaluation.sample_batch_builder import (
3
+ SampleBatchBuilder,
4
+ MultiAgentSampleBatchBuilder,
5
+ )
6
+ from ray.rllib.evaluation.sampler import SyncSampler
7
+ from ray.rllib.evaluation.postprocessing import compute_advantages
8
+ from ray.rllib.evaluation.metrics import collect_metrics
9
+ from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
10
+
11
+ __all__ = [
12
+ "RolloutWorker",
13
+ "SampleBatch",
14
+ "MultiAgentBatch",
15
+ "SampleBatchBuilder",
16
+ "MultiAgentSampleBatchBuilder",
17
+ "SyncSampler",
18
+ "compute_advantages",
19
+ "collect_metrics",
20
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (888 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/env_runner_v2.cpython-311.pyc ADDED
Binary file (44.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/episode_v2.cpython-311.pyc ADDED
Binary file (15.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/metrics.cpython-311.pyc ADDED
Binary file (11.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/postprocessing.cpython-311.pyc ADDED
Binary file (13.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/rollout_worker.cpython-311.pyc ADDED
Binary file (85.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/sampler.cpython-311.pyc ADDED
Binary file (12.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/simple_list_collector.py ADDED
@@ -0,0 +1,698 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ from gymnasium.spaces import Space
3
+ import logging
4
+ import numpy as np
5
+ import tree # pip install dm_tree
6
+ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
7
+
8
+ from ray.rllib.env.base_env import _DUMMY_AGENT_ID
9
+ from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
10
+ from ray.rllib.evaluation.collectors.agent_collector import AgentCollector
11
+ from ray.rllib.policy.policy import Policy
12
+ from ray.rllib.policy.policy_map import PolicyMap
13
+ from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch, concat_samples
14
+ from ray.rllib.utils.annotations import OldAPIStack, override
15
+ from ray.rllib.utils.debug import summarize
16
+ from ray.rllib.utils.framework import try_import_tf, try_import_torch
17
+ from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space
18
+ from ray.rllib.utils.typing import (
19
+ AgentID,
20
+ EpisodeID,
21
+ EnvID,
22
+ PolicyID,
23
+ TensorType,
24
+ ViewRequirementsDict,
25
+ )
26
+ from ray.util.debug import log_once
27
+
28
+ _, tf, _ = try_import_tf()
29
+ torch, _ = try_import_torch()
30
+
31
+ if TYPE_CHECKING:
32
+ from ray.rllib.callbacks.callbacks import RLlibCallback
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ @OldAPIStack
38
+ class _PolicyCollector:
39
+ """Collects already postprocessed (single agent) samples for one policy.
40
+
41
+ Samples come in through already postprocessed SampleBatches, which
42
+ contain single episode/trajectory data for a single agent and are then
43
+ appended to this policy's buffers.
44
+ """
45
+
46
+ def __init__(self, policy: Policy):
47
+ """Initializes a _PolicyCollector instance.
48
+
49
+ Args:
50
+ policy: The policy object.
51
+ """
52
+
53
+ self.batches = []
54
+ self.policy = policy
55
+ # The total timestep count for all agents that use this policy.
56
+ # NOTE: This is not an env-step count (across n agents). AgentA and
57
+ # agentB, both using this policy, acting in the same episode and both
58
+ # doing n steps would increase the count by 2*n.
59
+ self.agent_steps = 0
60
+
61
+ def add_postprocessed_batch_for_training(
62
+ self, batch: SampleBatch, view_requirements: ViewRequirementsDict
63
+ ) -> None:
64
+ """Adds a postprocessed SampleBatch (single agent) to our buffers.
65
+
66
+ Args:
67
+ batch: An individual agent's (one trajectory)
68
+ SampleBatch to be added to the Policy's buffers.
69
+ view_requirements: The view
70
+ requirements for the policy. This is so we know, whether a
71
+ view-column needs to be copied at all (not needed for
72
+ training).
73
+ """
74
+ # Add the agent's trajectory length to our count.
75
+ self.agent_steps += batch.count
76
+ # And remove columns not needed for training.
77
+ for view_col, view_req in view_requirements.items():
78
+ if view_col in batch and not view_req.used_for_training:
79
+ del batch[view_col]
80
+ self.batches.append(batch)
81
+
82
+ def build(self):
83
+ """Builds a SampleBatch for this policy from the collected data.
84
+
85
+ Also resets all buffers for further sample collection for this policy.
86
+
87
+ Returns:
88
+ SampleBatch: The SampleBatch with all thus-far collected data for
89
+ this policy.
90
+ """
91
+ # Create batch from our buffers.
92
+ batch = concat_samples(self.batches)
93
+ # Clear batches for future samples.
94
+ self.batches = []
95
+ # Reset agent steps to 0.
96
+ self.agent_steps = 0
97
+ # Add num_grad_updates counter to the policy's batch.
98
+ batch.num_grad_updates = self.policy.num_grad_updates
99
+
100
+ return batch
101
+
102
+
103
+ class _PolicyCollectorGroup:
104
+ def __init__(self, policy_map):
105
+ self.policy_collectors = {}
106
+ # Total env-steps (1 env-step=up to N agents stepped).
107
+ self.env_steps = 0
108
+ # Total agent steps (1 agent-step=1 individual agent (out of N)
109
+ # stepped).
110
+ self.agent_steps = 0
111
+
112
+
113
+ @OldAPIStack
114
+ class SimpleListCollector(SampleCollector):
115
+ """Util to build SampleBatches for each policy in a multi-agent env.
116
+
117
+ Input data is per-agent, while output data is per-policy. There is an M:N
118
+ mapping between agents and policies. We retain one local batch builder
119
+ per agent. When an agent is done, then its local batch is appended into the
120
+ corresponding policy batch for the agent's policy.
121
+ """
122
+
123
+ def __init__(
124
+ self,
125
+ policy_map: PolicyMap,
126
+ clip_rewards: Union[bool, float],
127
+ callbacks: "RLlibCallback",
128
+ multiple_episodes_in_batch: bool = True,
129
+ rollout_fragment_length: int = 200,
130
+ count_steps_by: str = "env_steps",
131
+ ):
132
+ """Initializes a SimpleListCollector instance."""
133
+
134
+ super().__init__(
135
+ policy_map,
136
+ clip_rewards,
137
+ callbacks,
138
+ multiple_episodes_in_batch,
139
+ rollout_fragment_length,
140
+ count_steps_by,
141
+ )
142
+
143
+ self.large_batch_threshold: int = (
144
+ max(1000, self.rollout_fragment_length * 10)
145
+ if self.rollout_fragment_length != float("inf")
146
+ else 5000
147
+ )
148
+
149
+ # Whenever we observe a new episode+agent, add a new
150
+ # _SingleTrajectoryCollector.
151
+ self.agent_collectors: Dict[Tuple[EpisodeID, AgentID], AgentCollector] = {}
152
+ # Internal agent-key-to-policy-id map.
153
+ self.agent_key_to_policy_id = {}
154
+ # Pool of used/unused PolicyCollectorGroups (attached to episodes for
155
+ # across-episode multi-agent sample collection).
156
+ self.policy_collector_groups = []
157
+
158
+ # Agents to collect data from for the next forward pass (per policy).
159
+ self.forward_pass_agent_keys = {pid: [] for pid in self.policy_map.keys()}
160
+ self.forward_pass_size = {pid: 0 for pid in self.policy_map.keys()}
161
+
162
+ # Maps episode ID to the (non-built) env steps taken in this episode.
163
+ self.episode_steps: Dict[EpisodeID, int] = collections.defaultdict(int)
164
+ # Maps episode ID to the (non-built) individual agent steps in this
165
+ # episode.
166
+ self.agent_steps: Dict[EpisodeID, int] = collections.defaultdict(int)
167
+ # Maps episode ID to Episode.
168
+ self.episodes = {}
169
+
170
+ @override(SampleCollector)
171
+ def episode_step(self, episode) -> None:
172
+ episode_id = episode.episode_id
173
+ # In the rase case that an "empty" step is taken at the beginning of
174
+ # the episode (none of the agents has an observation in the obs-dict
175
+ # and thus does not take an action), we have seen the episode before
176
+ # and have to add it here to our registry.
177
+ if episode_id not in self.episodes:
178
+ self.episodes[episode_id] = episode
179
+ else:
180
+ assert episode is self.episodes[episode_id]
181
+ self.episode_steps[episode_id] += 1
182
+ episode.length += 1
183
+
184
+ # In case of "empty" env steps (no agent is stepping), the builder
185
+ # object may still be None.
186
+ if episode.batch_builder:
187
+ env_steps = episode.batch_builder.env_steps
188
+ num_individual_observations = sum(
189
+ c.agent_steps for c in episode.batch_builder.policy_collectors.values()
190
+ )
191
+
192
+ if num_individual_observations > self.large_batch_threshold and log_once(
193
+ "large_batch_warning"
194
+ ):
195
+ logger.warning(
196
+ "More than {} observations in {} env steps for "
197
+ "episode {} ".format(
198
+ num_individual_observations, env_steps, episode_id
199
+ )
200
+ + "are buffered in the sampler. If this is more than you "
201
+ "expected, check that that you set a horizon on your "
202
+ "environment correctly and that it terminates at some "
203
+ "point. Note: In multi-agent environments, "
204
+ "`rollout_fragment_length` sets the batch size based on "
205
+ "(across-agents) environment steps, not the steps of "
206
+ "individual agents, which can result in unexpectedly "
207
+ "large batches."
208
+ + (
209
+ "Also, you may be waiting for your Env to "
210
+ "terminate (batch_mode=`complete_episodes`). Make sure "
211
+ "it does at some point."
212
+ if not self.multiple_episodes_in_batch
213
+ else ""
214
+ )
215
+ )
216
+
217
+ @override(SampleCollector)
218
+ def add_init_obs(
219
+ self,
220
+ *,
221
+ episode,
222
+ agent_id: AgentID,
223
+ env_id: EnvID,
224
+ policy_id: PolicyID,
225
+ init_obs: TensorType,
226
+ init_infos: Optional[Dict[str, TensorType]] = None,
227
+ t: int = -1,
228
+ ) -> None:
229
+ # Make sure our mappings are up to date.
230
+ agent_key = (episode.episode_id, agent_id)
231
+ self.agent_key_to_policy_id[agent_key] = policy_id
232
+ policy = self.policy_map[policy_id]
233
+
234
+ # Add initial obs to Trajectory.
235
+ assert agent_key not in self.agent_collectors
236
+ # TODO: determine exact shift-before based on the view-req shifts.
237
+
238
+ # get max_seq_len value (Default is 1)
239
+ try:
240
+ max_seq_len = policy.config["model"]["max_seq_len"]
241
+ except KeyError:
242
+ max_seq_len = 1
243
+
244
+ self.agent_collectors[agent_key] = AgentCollector(
245
+ policy.view_requirements,
246
+ max_seq_len=max_seq_len,
247
+ disable_action_flattening=policy.config.get(
248
+ "_disable_action_flattening", False
249
+ ),
250
+ intial_states=policy.get_initial_state(),
251
+ is_policy_recurrent=policy.is_recurrent(),
252
+ )
253
+ self.agent_collectors[agent_key].add_init_obs(
254
+ episode_id=episode.episode_id,
255
+ agent_index=episode._agent_index(agent_id),
256
+ env_id=env_id,
257
+ init_obs=init_obs,
258
+ init_infos=init_infos or {},
259
+ t=t,
260
+ )
261
+
262
+ self.episodes[episode.episode_id] = episode
263
+ if episode.batch_builder is None:
264
+ episode.batch_builder = (
265
+ self.policy_collector_groups.pop()
266
+ if self.policy_collector_groups
267
+ else _PolicyCollectorGroup(self.policy_map)
268
+ )
269
+
270
+ self._add_to_next_inference_call(agent_key)
271
+
272
+ @override(SampleCollector)
273
+ def add_action_reward_next_obs(
274
+ self,
275
+ episode_id: EpisodeID,
276
+ agent_id: AgentID,
277
+ env_id: EnvID,
278
+ policy_id: PolicyID,
279
+ agent_done: bool,
280
+ values: Dict[str, TensorType],
281
+ ) -> None:
282
+ # Make sure, episode/agent already has some (at least init) data.
283
+ agent_key = (episode_id, agent_id)
284
+ assert self.agent_key_to_policy_id[agent_key] == policy_id
285
+ assert agent_key in self.agent_collectors
286
+
287
+ self.agent_steps[episode_id] += 1
288
+
289
+ # Include the current agent id for multi-agent algorithms.
290
+ if agent_id != _DUMMY_AGENT_ID:
291
+ values["agent_id"] = agent_id
292
+
293
+ # Add action/reward/next-obs (and other data) to Trajectory.
294
+ self.agent_collectors[agent_key].add_action_reward_next_obs(values)
295
+
296
+ if not agent_done:
297
+ self._add_to_next_inference_call(agent_key)
298
+
299
+ @override(SampleCollector)
300
+ def total_env_steps(self) -> int:
301
+ # Add the non-built ongoing-episode env steps + the already built
302
+ # env-steps.
303
+ return sum(self.episode_steps.values()) + sum(
304
+ pg.env_steps for pg in self.policy_collector_groups.values()
305
+ )
306
+
307
+ @override(SampleCollector)
308
+ def total_agent_steps(self) -> int:
309
+ # Add the non-built ongoing-episode agent steps (still in the agent
310
+ # collectors) + the already built agent steps.
311
+ return sum(a.agent_steps for a in self.agent_collectors.values()) + sum(
312
+ pg.agent_steps for pg in self.policy_collector_groups.values()
313
+ )
314
+
315
+ @override(SampleCollector)
316
+ def get_inference_input_dict(self, policy_id: PolicyID) -> Dict[str, TensorType]:
317
+ policy = self.policy_map[policy_id]
318
+ keys = self.forward_pass_agent_keys[policy_id]
319
+ batch_size = len(keys)
320
+
321
+ # Return empty batch, if no forward pass to do.
322
+ if batch_size == 0:
323
+ return SampleBatch()
324
+
325
+ buffers = {}
326
+ for k in keys:
327
+ collector = self.agent_collectors[k]
328
+ buffers[k] = collector.buffers
329
+ # Use one agent's buffer_structs (they should all be the same).
330
+ buffer_structs = self.agent_collectors[keys[0]].buffer_structs
331
+
332
+ input_dict = {}
333
+ for view_col, view_req in policy.view_requirements.items():
334
+ # Not used for action computations.
335
+ if not view_req.used_for_compute_actions:
336
+ continue
337
+
338
+ # Create the batch of data from the different buffers.
339
+ data_col = view_req.data_col or view_col
340
+ delta = (
341
+ -1
342
+ if data_col
343
+ in [
344
+ SampleBatch.OBS,
345
+ SampleBatch.INFOS,
346
+ SampleBatch.ENV_ID,
347
+ SampleBatch.EPS_ID,
348
+ SampleBatch.AGENT_INDEX,
349
+ SampleBatch.T,
350
+ ]
351
+ else 0
352
+ )
353
+ # Range of shifts, e.g. "-100:0". Note: This includes index 0!
354
+ if view_req.shift_from is not None:
355
+ time_indices = (view_req.shift_from + delta, view_req.shift_to + delta)
356
+ # Single shift (e.g. -1) or list of shifts, e.g. [-4, -1, 0].
357
+ else:
358
+ time_indices = view_req.shift + delta
359
+
360
+ # Loop through agents and add up their data (batch).
361
+ data = None
362
+ for k in keys:
363
+ # Buffer for the data does not exist yet: Create dummy
364
+ # (zero) data.
365
+ if data_col not in buffers[k]:
366
+ if view_req.data_col is not None:
367
+ space = policy.view_requirements[view_req.data_col].space
368
+ else:
369
+ space = view_req.space
370
+
371
+ if isinstance(space, Space):
372
+ fill_value = get_dummy_batch_for_space(
373
+ space,
374
+ batch_size=0,
375
+ )
376
+ else:
377
+ fill_value = space
378
+
379
+ self.agent_collectors[k]._build_buffers({data_col: fill_value})
380
+
381
+ if data is None:
382
+ data = [[] for _ in range(len(buffers[keys[0]][data_col]))]
383
+
384
+ # `shift_from` and `shift_to` are defined: User wants a
385
+ # view with some time-range.
386
+ if isinstance(time_indices, tuple):
387
+ # `shift_to` == -1: Until the end (including(!) the
388
+ # last item).
389
+ if time_indices[1] == -1:
390
+ for d, b in zip(data, buffers[k][data_col]):
391
+ d.append(b[time_indices[0] :])
392
+ # `shift_to` != -1: "Normal" range.
393
+ else:
394
+ for d, b in zip(data, buffers[k][data_col]):
395
+ d.append(b[time_indices[0] : time_indices[1] + 1])
396
+ # Single index.
397
+ else:
398
+ for d, b in zip(data, buffers[k][data_col]):
399
+ d.append(b[time_indices])
400
+
401
+ np_data = [np.array(d) for d in data]
402
+ if data_col in buffer_structs:
403
+ input_dict[view_col] = tree.unflatten_as(
404
+ buffer_structs[data_col], np_data
405
+ )
406
+ else:
407
+ input_dict[view_col] = np_data[0]
408
+
409
+ self._reset_inference_calls(policy_id)
410
+
411
+ return SampleBatch(
412
+ input_dict,
413
+ seq_lens=np.ones(batch_size, dtype=np.int32)
414
+ if "state_in_0" in input_dict
415
+ else None,
416
+ )
417
+
418
+ @override(SampleCollector)
419
+ def postprocess_episode(
420
+ self,
421
+ episode,
422
+ is_done: bool = False,
423
+ check_dones: bool = False,
424
+ build: bool = False,
425
+ ) -> Union[None, SampleBatch, MultiAgentBatch]:
426
+ episode_id = episode.episode_id
427
+ policy_collector_group = episode.batch_builder
428
+
429
+ # Build SampleBatches for the given episode.
430
+ pre_batches = {}
431
+ for (eps_id, agent_id), collector in self.agent_collectors.items():
432
+ # Build only if there is data and agent is part of given episode.
433
+ if collector.agent_steps == 0 or eps_id != episode_id:
434
+ continue
435
+ pid = self.agent_key_to_policy_id[(eps_id, agent_id)]
436
+ policy = self.policy_map[pid]
437
+ pre_batch = collector.build_for_training(policy.view_requirements)
438
+ pre_batches[agent_id] = (policy, pre_batch)
439
+
440
+ # Apply reward clipping before calling postprocessing functions.
441
+ if self.clip_rewards is True:
442
+ for _, (_, pre_batch) in pre_batches.items():
443
+ pre_batch["rewards"] = np.sign(pre_batch["rewards"])
444
+ elif self.clip_rewards:
445
+ for _, (_, pre_batch) in pre_batches.items():
446
+ pre_batch["rewards"] = np.clip(
447
+ pre_batch["rewards"],
448
+ a_min=-self.clip_rewards,
449
+ a_max=self.clip_rewards,
450
+ )
451
+
452
+ post_batches = {}
453
+ for agent_id, (_, pre_batch) in pre_batches.items():
454
+ # Entire episode is said to be done.
455
+ # Error if no DONE at end of this agent's trajectory.
456
+ if is_done and check_dones and not pre_batch.is_terminated_or_truncated():
457
+ raise ValueError(
458
+ "Episode {} terminated for all agents, but we still "
459
+ "don't have a last observation for agent {} (policy "
460
+ "{}). ".format(
461
+ episode_id,
462
+ agent_id,
463
+ self.agent_key_to_policy_id[(episode_id, agent_id)],
464
+ )
465
+ + "Please ensure that you include the last observations "
466
+ "of all live agents when setting truncated[__all__] or "
467
+ "terminated[__all__] to True."
468
+ )
469
+
470
+ # Skip a trajectory's postprocessing (and thus using it for training),
471
+ # if its agent's info exists and contains the training_enabled=False
472
+ # setting (used by our PolicyClients).
473
+ last_info = episode.last_info_for(agent_id)
474
+ if last_info and not last_info.get("training_enabled", True):
475
+ if is_done:
476
+ agent_key = (episode_id, agent_id)
477
+ del self.agent_key_to_policy_id[agent_key]
478
+ del self.agent_collectors[agent_key]
479
+ continue
480
+
481
+ if len(pre_batches) > 1:
482
+ other_batches = pre_batches.copy()
483
+ del other_batches[agent_id]
484
+ else:
485
+ other_batches = {}
486
+ pid = self.agent_key_to_policy_id[(episode_id, agent_id)]
487
+ policy = self.policy_map[pid]
488
+ if not pre_batch.is_single_trajectory():
489
+ raise ValueError(
490
+ "Batches sent to postprocessing must be from a single trajectory! "
491
+ "TERMINATED & TRUNCATED need to be False everywhere, except the "
492
+ "last timestep, which can be either True or False for those keys)!",
493
+ pre_batch,
494
+ )
495
+ elif len(set(pre_batch[SampleBatch.EPS_ID])) > 1:
496
+ episode_ids = set(pre_batch[SampleBatch.EPS_ID])
497
+ raise ValueError(
498
+ "Batches sent to postprocessing must only contain steps "
499
+ "from a single episode! Your trajectory contains data from "
500
+ f"{len(episode_ids)} episodes ({list(episode_ids)}).",
501
+ pre_batch,
502
+ )
503
+ # Call the Policy's Exploration's postprocess method.
504
+ post_batches[agent_id] = pre_batch
505
+ if getattr(policy, "exploration", None) is not None:
506
+ policy.exploration.postprocess_trajectory(
507
+ policy, post_batches[agent_id], policy.get_session()
508
+ )
509
+ post_batches[agent_id].set_get_interceptor(None)
510
+ post_batches[agent_id] = policy.postprocess_trajectory(
511
+ post_batches[agent_id], other_batches, episode
512
+ )
513
+
514
+ if log_once("after_post"):
515
+ logger.info(
516
+ "Trajectory fragment after postprocess_trajectory():\n\n{}\n".format(
517
+ summarize(post_batches)
518
+ )
519
+ )
520
+
521
+ # Append into policy batches and reset.
522
+ from ray.rllib.evaluation.rollout_worker import get_global_worker
523
+
524
+ for agent_id, post_batch in sorted(post_batches.items()):
525
+ agent_key = (episode_id, agent_id)
526
+ pid = self.agent_key_to_policy_id[agent_key]
527
+ policy = self.policy_map[pid]
528
+ self.callbacks.on_postprocess_trajectory(
529
+ worker=get_global_worker(),
530
+ episode=episode,
531
+ agent_id=agent_id,
532
+ policy_id=pid,
533
+ policies=self.policy_map,
534
+ postprocessed_batch=post_batch,
535
+ original_batches=pre_batches,
536
+ )
537
+
538
+ # Add the postprocessed SampleBatch to the policy collectors for
539
+ # training.
540
+ # PID may be a newly added policy. Just confirm we have it in our
541
+ # policy map before proceeding with adding a new _PolicyCollector()
542
+ # to the group.
543
+ if pid not in policy_collector_group.policy_collectors:
544
+ assert pid in self.policy_map
545
+ policy_collector_group.policy_collectors[pid] = _PolicyCollector(policy)
546
+ policy_collector_group.policy_collectors[
547
+ pid
548
+ ].add_postprocessed_batch_for_training(post_batch, policy.view_requirements)
549
+
550
+ if is_done:
551
+ del self.agent_key_to_policy_id[agent_key]
552
+ del self.agent_collectors[agent_key]
553
+
554
+ if policy_collector_group:
555
+ env_steps = self.episode_steps[episode_id]
556
+ policy_collector_group.env_steps += env_steps
557
+ agent_steps = self.agent_steps[episode_id]
558
+ policy_collector_group.agent_steps += agent_steps
559
+
560
+ if is_done:
561
+ del self.episode_steps[episode_id]
562
+ del self.episodes[episode_id]
563
+
564
+ if episode_id in self.agent_steps:
565
+ del self.agent_steps[episode_id]
566
+ else:
567
+ assert (
568
+ len(pre_batches) == 0
569
+ ), "Expected the batch to be empty since the episode_id is missing."
570
+ # if the key does not exist it means that throughout the episode all
571
+ # observations were empty (i.e. there was no agent in the env)
572
+ msg = (
573
+ f"Data from episode {episode_id} does not show any agent "
574
+ f"interactions. Hint: Make sure for at least one timestep in the "
575
+ f"episode, env.step() returns non-empty values."
576
+ )
577
+ raise ValueError(msg)
578
+
579
+ # Make PolicyCollectorGroup available for more agent batches in
580
+ # other episodes. Do not reset count to 0.
581
+ if policy_collector_group:
582
+ self.policy_collector_groups.append(policy_collector_group)
583
+ else:
584
+ self.episode_steps[episode_id] = self.agent_steps[episode_id] = 0
585
+
586
+ # Build a MultiAgentBatch from the episode and return.
587
+ if build:
588
+ return self._build_multi_agent_batch(episode)
589
+
590
+ def _build_multi_agent_batch(self, episode) -> Union[MultiAgentBatch, SampleBatch]:
591
+
592
+ ma_batch = {}
593
+ for pid, collector in episode.batch_builder.policy_collectors.items():
594
+ if collector.agent_steps > 0:
595
+ ma_batch[pid] = collector.build()
596
+
597
+ # TODO(sven): We should always return the same type here (MultiAgentBatch),
598
+ # no matter what. Just have to unify our `training_step` methods, then. This
599
+ # will reduce a lot of confusion about what comes out of the sampling process.
600
+ # Create the batch.
601
+ ma_batch = MultiAgentBatch.wrap_as_needed(
602
+ ma_batch, env_steps=episode.batch_builder.env_steps
603
+ )
604
+
605
+ # PolicyCollectorGroup is empty.
606
+ episode.batch_builder.env_steps = 0
607
+ episode.batch_builder.agent_steps = 0
608
+
609
+ return ma_batch
610
+
611
+ @override(SampleCollector)
612
+ def try_build_truncated_episode_multi_agent_batch(
613
+ self,
614
+ ) -> List[Union[MultiAgentBatch, SampleBatch]]:
615
+ batches = []
616
+ # Loop through ongoing episodes and see whether their length plus
617
+ # what's already in the policy collectors reaches the fragment-len
618
+ # (abiding to the unit used: env-steps or agent-steps).
619
+ for episode_id, episode in self.episodes.items():
620
+ # Measure batch size in env-steps.
621
+ if self.count_steps_by == "env_steps":
622
+ built_steps = (
623
+ episode.batch_builder.env_steps if episode.batch_builder else 0
624
+ )
625
+ ongoing_steps = self.episode_steps[episode_id]
626
+ # Measure batch-size in agent-steps.
627
+ else:
628
+ built_steps = (
629
+ episode.batch_builder.agent_steps if episode.batch_builder else 0
630
+ )
631
+ ongoing_steps = self.agent_steps[episode_id]
632
+
633
+ # Reached the fragment-len -> We should build an MA-Batch.
634
+ if built_steps + ongoing_steps >= self.rollout_fragment_length:
635
+ if self.count_steps_by == "env_steps":
636
+ assert built_steps + ongoing_steps == self.rollout_fragment_length
637
+ # If we reached the fragment-len only because of `episode_id`
638
+ # (still ongoing) -> postprocess `episode_id` first.
639
+ if built_steps < self.rollout_fragment_length:
640
+ self.postprocess_episode(episode, is_done=False)
641
+ # If there is a builder for this episode,
642
+ # build the MA-batch and add to return values.
643
+ if episode.batch_builder:
644
+ batch = self._build_multi_agent_batch(episode=episode)
645
+ batches.append(batch)
646
+ # No batch-builder:
647
+ # We have reached the rollout-fragment length w/o any agent
648
+ # steps! Warn that the environment may never request any
649
+ # actions from any agents.
650
+ elif log_once("no_agent_steps"):
651
+ logger.warning(
652
+ "Your environment seems to be stepping w/o ever "
653
+ "emitting agent observations (agents are never "
654
+ "requested to act)!"
655
+ )
656
+
657
+ return batches
658
+
659
+ def _add_to_next_inference_call(self, agent_key: Tuple[EpisodeID, AgentID]) -> None:
660
+ """Adds an Agent key (episode+agent IDs) to the next inference call.
661
+
662
+ This makes sure that the agent's current data (in the trajectory) is
663
+ used for generating the next input_dict for a
664
+ `Policy.compute_actions()` call.
665
+
666
+ Args:
667
+ agent_key (Tuple[EpisodeID, AgentID]: A unique agent key (across
668
+ vectorized environments).
669
+ """
670
+ pid = self.agent_key_to_policy_id[agent_key]
671
+
672
+ # PID may be a newly added policy (added on the fly during training).
673
+ # Just confirm we have it in our policy map before proceeding with
674
+ # forward_pass_size=0.
675
+ if pid not in self.forward_pass_size:
676
+ assert pid in self.policy_map
677
+ self.forward_pass_size[pid] = 0
678
+ self.forward_pass_agent_keys[pid] = []
679
+
680
+ idx = self.forward_pass_size[pid]
681
+ assert idx >= 0
682
+ if idx == 0:
683
+ self.forward_pass_agent_keys[pid].clear()
684
+
685
+ self.forward_pass_agent_keys[pid].append(agent_key)
686
+ self.forward_pass_size[pid] += 1
687
+
688
+ def _reset_inference_calls(self, policy_id: PolicyID) -> None:
689
+ """Resets internal inference input-dict registries.
690
+
691
+ Calling `self.get_inference_input_dict()` after this method is called
692
+ would return an empty input-dict.
693
+
694
+ Args:
695
+ policy_id: The policy ID for which to reset the
696
+ inference pointers.
697
+ """
698
+ self.forward_pass_size[policy_id] = 0
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/env_runner_v2.py ADDED
@@ -0,0 +1,1232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import logging
3
+ import time
4
+ import tree # pip install dm_tree
5
+ from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Set, Tuple, Union
6
+ import numpy as np
7
+
8
+ from ray.rllib.env.base_env import ASYNC_RESET_RETURN, BaseEnv
9
+ from ray.rllib.env.external_env import ExternalEnvWrapper
10
+ from ray.rllib.env.wrappers.atari_wrappers import MonitorEnv, get_wrapper_by_cls
11
+ from ray.rllib.evaluation.collectors.simple_list_collector import _PolicyCollectorGroup
12
+ from ray.rllib.evaluation.episode_v2 import EpisodeV2
13
+ from ray.rllib.evaluation.metrics import RolloutMetrics
14
+ from ray.rllib.models.preprocessors import Preprocessor
15
+ from ray.rllib.policy.policy import Policy
16
+ from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, concat_samples
17
+ from ray.rllib.utils.annotations import OldAPIStack
18
+ from ray.rllib.utils.filter import Filter
19
+ from ray.rllib.utils.numpy import convert_to_numpy
20
+ from ray.rllib.utils.spaces.space_utils import unbatch, get_original_space
21
+ from ray.rllib.utils.typing import (
22
+ ActionConnectorDataType,
23
+ AgentConnectorDataType,
24
+ AgentID,
25
+ EnvActionType,
26
+ EnvID,
27
+ EnvInfoDict,
28
+ EnvObsType,
29
+ MultiAgentDict,
30
+ MultiEnvDict,
31
+ PolicyID,
32
+ PolicyOutputType,
33
+ SampleBatchType,
34
+ StateBatches,
35
+ TensorStructType,
36
+ )
37
+ from ray.util.debug import log_once
38
+
39
+ if TYPE_CHECKING:
40
+ from gymnasium.envs.classic_control.rendering import SimpleImageViewer
41
+
42
+ from ray.rllib.callbacks.callbacks import RLlibCallback
43
+ from ray.rllib.evaluation.rollout_worker import RolloutWorker
44
+
45
+
46
+ logger = logging.getLogger(__name__)
47
+
48
+
49
+ MIN_LARGE_BATCH_THRESHOLD = 1000
50
+ DEFAULT_LARGE_BATCH_THRESHOLD = 5000
51
+ MS_TO_SEC = 1000.0
52
+
53
+
54
@OldAPIStack
class _PerfStats:
    """Sampler perf stats that will be included in rollout metrics.

    Two accumulation modes:
      * global average (``ema_coef is None``): raw running sums, divided by
        the iteration count in `get()`.
      * exponential moving average: each update does
        ``updated = (1 - ema_coef) * old + ema_coef * new``.
    """

    def __init__(self, ema_coef: Optional[float] = None):
        # If not None, enable Exponential Moving Average mode, which in
        # general gives more responsive stats about sampler performance.
        # TODO(jungong) : make ema the default (only) mode if it works well.
        self.ema_coef = ema_coef

        self.iters = 0
        # All timing counters are kept in seconds.
        self.raw_obs_processing_time = 0.0
        self.inference_time = 0.0
        self.action_processing_time = 0.0
        self.env_wait_time = 0.0
        self.env_render_time = 0.0

    def incr(self, field: str, value: Union[int, float]):
        """Accumulate `value` into counter `field` (mode-dependent)."""
        # Iteration count is a plain sum regardless of mode.
        if field == "iters":
            self.iters += value
            return

        # All the other fields support either global average or ema mode.
        if self.ema_coef is None:
            # Global average: running sum.
            self.__dict__[field] += value
        else:
            # EMA update.
            self.__dict__[field] = (1.0 - self.ema_coef) * self.__dict__[
                field
            ] + self.ema_coef * value

    def _metric_dict(self, factor: float):
        """Map the raw second-counters to their millisecond metric names."""
        return {
            # Raw observation preprocessing.
            "mean_raw_obs_processing_ms": self.raw_obs_processing_time * factor,
            # Computing actions through policy.
            "mean_inference_ms": self.inference_time * factor,
            # Processing actions (to be sent to env, e.g. clipping).
            "mean_action_processing_ms": self.action_processing_time * factor,
            # Waiting for environment (during poll).
            "mean_env_wait_ms": self.env_wait_time * factor,
            # Environment rendering (False by default).
            "mean_env_render_ms": self.env_render_time * factor,
        }

    def _get_avg(self):
        # Global-average mode: divide the sums by the iteration count and
        # convert sec -> ms. (Note: raises ZeroDivisionError if called before
        # any "iters" increment — same as the historical behavior.)
        return self._metric_dict(MS_TO_SEC / self.iters)

    def _get_ema(self):
        # In EMA mode, stats are already (exponentially) averaged,
        # hence we only need to do the sec -> ms conversion here.
        return self._metric_dict(MS_TO_SEC)

    def get(self):
        """Return the current perf metrics dict (values in milliseconds)."""
        return self._get_avg() if self.ema_coef is None else self._get_ema()
124
+
125
+
126
@OldAPIStack
class _NewDefaultDict(defaultdict):
    """defaultdict variant whose factory receives the missing key.

    Unlike a plain defaultdict (whose default_factory takes no arguments),
    the missing env_id is forwarded to the factory.
    """

    def __missing__(self, env_id):
        # Build the default value for this env_id, cache it, and return it.
        value = self.default_factory(env_id)
        self[env_id] = value
        return value
131
+
132
+
133
@OldAPIStack
def _build_multi_agent_batch(
    episode_id: int,
    batch_builder: _PolicyCollectorGroup,
    large_batch_threshold: int,
    multiple_episodes_in_batch: bool,
) -> MultiAgentBatch:
    """Build MultiAgentBatch from a group of _PolicyCollectors.

    Args:
        episode_id: ID of the episode that produced this data (only used in
            the large-batch warning message).
        batch_builder: The _PolicyCollectorGroup holding the per-policy
            collectors to build from.
        large_batch_threshold: Number of buffered agent steps above which a
            one-time warning about unusually large batches is logged.
        multiple_episodes_in_batch: Whether several episodes may be packed
            into one batch (False implies batch_mode="complete_episodes").

    Returns:
        Always returns a sample batch in MultiAgentBatch format.
    """
    ma_batch = {}
    for pid, collector in batch_builder.policy_collectors.items():
        # Skip policies that collected no agent steps.
        if collector.agent_steps <= 0:
            continue

        if batch_builder.agent_steps > large_batch_threshold and log_once(
            "large_batch_warning"
        ):
            # Message fixes vs. previous version: removed duplicated "that"
            # and added the missing space before the appended "Also, ..."
            # sentence.
            logger.warning(
                "More than {} observations in {} env steps for "
                "episode {} ".format(
                    batch_builder.agent_steps, batch_builder.env_steps, episode_id
                )
                + "are buffered in the sampler. If this is more than you "
                "expected, check that you set a horizon on your "
                "environment correctly and that it terminates at some "
                "point. Note: In multi-agent environments, "
                "`rollout_fragment_length` sets the batch size based on "
                "(across-agents) environment steps, not the steps of "
                "individual agents, which can result in unexpectedly "
                "large batches. "
                + (
                    "Also, you may be waiting for your Env to "
                    "terminate (batch_mode=`complete_episodes`). Make sure "
                    "it does at some point."
                    if not multiple_episodes_in_batch
                    else ""
                )
            )

        batch = collector.build()

        ma_batch[pid] = batch

    # Create the multi agent batch.
    return MultiAgentBatch(policy_batches=ma_batch, env_steps=batch_builder.env_steps)
185
+
186
+
187
@OldAPIStack
def _batch_inference_sample_batches(eval_data: List[SampleBatch]) -> SampleBatch:
    """Concatenate a list of per-agent SampleBatches into one inference batch.

    Args:
        eval_data: list of SampleBatches.

    Returns:
        single batched SampleBatch.
    """
    merged = concat_samples(eval_data)
    # RNN case: each input batch represents exactly one timestep, so every
    # sequence in the merged batch has length 1.
    if "state_in_0" in merged:
        merged[SampleBatch.SEQ_LENS] = np.ones(len(eval_data), dtype=np.int32)
    return merged
202
+
203
+
204
+ @OldAPIStack
205
+ class EnvRunnerV2:
206
+ """Collect experiences from user environment using Connectors."""
207
+
208
    def __init__(
        self,
        worker: "RolloutWorker",
        base_env: BaseEnv,
        multiple_episodes_in_batch: bool,
        callbacks: "RLlibCallback",
        perf_stats: _PerfStats,
        rollout_fragment_length: int = 200,
        count_steps_by: str = "env_steps",
        render: bool = None,
    ):
        """
        Args:
            worker: Reference to the current rollout worker.
            base_env: Env implementing BaseEnv.
            multiple_episodes_in_batch: Whether to pack multiple
                episodes into each batch. This guarantees batches will be exactly
                `rollout_fragment_length` in size.
            callbacks: User callbacks to run on episode events.
            perf_stats: Record perf stats into this object.
            rollout_fragment_length: The length of a fragment to collect
                before building a SampleBatch from the data and resetting
                the SampleBatchBuilder object.
            count_steps_by: One of "env_steps" (default) or "agent_steps".
                Use "agent_steps", if you want rollout lengths to be counted
                by individual agent steps. In a multi-agent env,
                a single env_step contains one or more agent_steps, depending
                on how many agents are present at any given time in the
                ongoing episode.
            render: Whether to try to render the environment after each
                step.
        """
        self._worker = worker
        # ExternalEnvs invert control (they push data instead of being
        # polled), which this connector-based runner cannot drive.
        if isinstance(base_env, ExternalEnvWrapper):
            raise ValueError(
                "Policies using the new Connector API do not support ExternalEnv."
            )
        self._base_env = base_env
        self._multiple_episodes_in_batch = multiple_episodes_in_batch
        self._callbacks = callbacks
        self._perf_stats = perf_stats
        self._rollout_fragment_length = rollout_fragment_length
        self._count_steps_by = count_steps_by
        self._render = render

        # May be populated for image rendering.
        self._simple_image_viewer: Optional[
            "SimpleImageViewer"
        ] = self._get_simple_image_viewer()

        # Keeps track of active episodes.
        self._active_episodes: Dict[EnvID, EpisodeV2] = {}
        # Per-env batch builders, created lazily on first access; the
        # _NewDefaultDict factory receives the missing env_id.
        self._batch_builders: Dict[EnvID, _PolicyCollectorGroup] = _NewDefaultDict(
            self._new_batch_builder
        )

        # Buffered-agent-steps threshold above which a one-time "large batch"
        # warning is logged. An infinite fragment length (complete episodes)
        # falls back to the fixed default threshold.
        self._large_batch_threshold: int = (
            max(MIN_LARGE_BATCH_THRESHOLD, self._rollout_fragment_length * 10)
            if self._rollout_fragment_length != float("inf")
            else DEFAULT_LARGE_BATCH_THRESHOLD
        )
269
+
270
+ def _get_simple_image_viewer(self):
271
+ """Maybe construct a SimpleImageViewer instance for episode rendering."""
272
+ # Try to render the env, if required.
273
+ if not self._render:
274
+ return None
275
+
276
+ try:
277
+ from gymnasium.envs.classic_control.rendering import SimpleImageViewer
278
+
279
+ return SimpleImageViewer()
280
+ except (ImportError, ModuleNotFoundError):
281
+ self._render = False # disable rendering
282
+ logger.warning(
283
+ "Could not import gymnasium.envs.classic_control."
284
+ "rendering! Try `pip install gymnasium[all]`."
285
+ )
286
+
287
+ return None
288
+
289
+ def _call_on_episode_start(self, episode, env_id):
290
+ # Call each policy's Exploration.on_episode_start method.
291
+ # Note: This may break the exploration (e.g. ParameterNoise) of
292
+ # policies in the `policy_map` that have not been recently used
293
+ # (and are therefore stashed to disk). However, we certainly do not
294
+ # want to loop through all (even stashed) policies here as that
295
+ # would counter the purpose of the LRU policy caching.
296
+ for p in self._worker.policy_map.cache.values():
297
+ if getattr(p, "exploration", None) is not None:
298
+ p.exploration.on_episode_start(
299
+ policy=p,
300
+ environment=self._base_env,
301
+ episode=episode,
302
+ tf_sess=p.get_session(),
303
+ )
304
+ # Call `on_episode_start()` callback.
305
+ self._callbacks.on_episode_start(
306
+ worker=self._worker,
307
+ base_env=self._base_env,
308
+ policies=self._worker.policy_map,
309
+ env_index=env_id,
310
+ episode=episode,
311
+ )
312
+
313
+ def _new_batch_builder(self, _) -> _PolicyCollectorGroup:
314
+ """Create a new batch builder.
315
+
316
+ We create a _PolicyCollectorGroup based on the full policy_map
317
+ as the batch builder.
318
+ """
319
+ return _PolicyCollectorGroup(self._worker.policy_map)
320
+
321
+ def run(self) -> Iterator[SampleBatchType]:
322
+ """Samples and yields training episodes continuously.
323
+
324
+ Yields:
325
+ Object containing state, action, reward, terminal condition,
326
+ and other fields as dictated by `policy`.
327
+ """
328
+ while True:
329
+ outputs = self.step()
330
+ for o in outputs:
331
+ yield o
332
+
333
    def step(self) -> List[SampleBatchType]:
        """Samples training episodes by stepping through environments.

        One full poll/eval/send cycle:
          1. Poll all ready sub-environments for new observations.
          2. Process observations (agent connectors, episode bookkeeping).
          3. Batched policy evaluation across vectorized envs.
          4. Post-process eval results into per-env action dicts and send
             them back to the environments.

        Returns:
            Outputs collected this step: finished SampleBatches and/or
            RolloutMetrics.
        """

        self._perf_stats.incr("iters", 1)

        t0 = time.time()
        # Get observations from all ready agents.
        # types: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
        (
            unfiltered_obs,
            rewards,
            terminateds,
            truncateds,
            infos,
            off_policy_actions,
        ) = self._base_env.poll()
        env_poll_time = time.time() - t0

        # Process observations and prepare for policy evaluation.
        t1 = time.time()
        # types: Set[EnvID], Dict[PolicyID, List[AgentConnectorDataType]],
        # List[Union[RolloutMetrics, SampleBatchType]]
        active_envs, to_eval, outputs = self._process_observations(
            unfiltered_obs=unfiltered_obs,
            rewards=rewards,
            terminateds=terminateds,
            truncateds=truncateds,
            infos=infos,
        )
        self._perf_stats.incr("raw_obs_processing_time", time.time() - t1)

        # Do batched policy eval (across vectorized envs).
        t2 = time.time()
        # types: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
        eval_results = self._do_policy_eval(to_eval=to_eval)
        self._perf_stats.incr("inference_time", time.time() - t2)

        # Process results and update episode state.
        t3 = time.time()
        actions_to_send: Dict[
            EnvID, Dict[AgentID, EnvActionType]
        ] = self._process_policy_eval_results(
            active_envs=active_envs,
            to_eval=to_eval,
            eval_results=eval_results,
            off_policy_actions=off_policy_actions,
        )
        self._perf_stats.incr("action_processing_time", time.time() - t3)

        # Return computed actions to ready envs. We also send to envs that have
        # taken off-policy actions; those envs are free to ignore the action.
        t4 = time.time()
        self._base_env.send_actions(actions_to_send)
        # Env wait time accounts for both the initial poll and the send.
        self._perf_stats.incr("env_wait_time", env_poll_time + time.time() - t4)

        self._maybe_render()

        return outputs
391
+
392
+ def _get_rollout_metrics(
393
+ self, episode: EpisodeV2, policy_map: Dict[str, Policy]
394
+ ) -> List[RolloutMetrics]:
395
+ """Get rollout metrics from completed episode."""
396
+ # TODO(jungong) : why do we need to handle atari metrics differently?
397
+ # Can we unify atari and normal env metrics?
398
+ atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(self._base_env)
399
+ if atari_metrics is not None:
400
+ for m in atari_metrics:
401
+ m._replace(custom_metrics=episode.custom_metrics)
402
+ return atari_metrics
403
+ # Create connector metrics
404
+ connector_metrics = {}
405
+ active_agents = episode.get_agents()
406
+ for agent in active_agents:
407
+ policy_id = episode.policy_for(agent)
408
+ policy = episode.policy_map[policy_id]
409
+ connector_metrics[policy_id] = policy.get_connector_metrics()
410
+ # Otherwise, return RolloutMetrics for the episode.
411
+ return [
412
+ RolloutMetrics(
413
+ episode_length=episode.length,
414
+ episode_reward=episode.total_reward,
415
+ agent_rewards=dict(episode.agent_rewards),
416
+ custom_metrics=episode.custom_metrics,
417
+ perf_stats={},
418
+ hist_data=episode.hist_data,
419
+ media=episode.media,
420
+ connector_metrics=connector_metrics,
421
+ )
422
+ ]
423
+
424
    def _process_observations(
        self,
        unfiltered_obs: MultiEnvDict,
        rewards: MultiEnvDict,
        terminateds: MultiEnvDict,
        truncateds: MultiEnvDict,
        infos: MultiEnvDict,
    ) -> Tuple[
        Set[EnvID],
        Dict[PolicyID, List[AgentConnectorDataType]],
        List[Union[RolloutMetrics, SampleBatchType]],
    ]:
        """Process raw obs from env.

        Group data for active agents by policy. Reset environments that are done.

        Args:
            unfiltered_obs: The unfiltered, raw observations from the BaseEnv
                (vectorized, possibly multi-agent). Dict of dict: By env index,
                then agent ID, then mapped to actual obs.
            rewards: The rewards MultiEnvDict of the BaseEnv.
            terminateds: The `terminated` flags MultiEnvDict of the BaseEnv.
            truncateds: The `truncated` flags MultiEnvDict of the BaseEnv.
            infos: The MultiEnvDict of infos dicts of the BaseEnv.

        Returns:
            A tuple of:
                A list of envs that were active during this step.
                AgentConnectorDataType for active agents for policy evaluation.
                SampleBatches and RolloutMetrics for completed agents for output.
        """
        # Output objects.
        # Note that we need to track envs that are active during this round explicitly,
        # just to be confident which envs require us to send at least an empty action
        # dict to.
        # We can not get this from the _active_episode or to_eval lists because
        # 1. All envs are not required to step during every single step. And
        # 2. to_eval only contains data for the agents that are still active. An env may
        # be active but all agents are done during the step.
        active_envs: Set[EnvID] = set()
        to_eval: Dict[PolicyID, List[AgentConnectorDataType]] = defaultdict(list)
        outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

        # For each (vectorized) sub-environment.
        # types: EnvID, Dict[AgentID, EnvObsType]
        for env_id, env_obs in unfiltered_obs.items():
            # Check for env_id having returned an error instead of a multi-agent
            # obs dict. This is how our BaseEnv can tell the caller to `poll()` that
            # one of its sub-environments is faulty and should be restarted (and the
            # ongoing episode should not be used for training).
            if isinstance(env_obs, Exception):
                assert terminateds[env_id]["__all__"] is True, (
                    f"ERROR: When a sub-environment (env-id {env_id}) returns an error "
                    "as observation, the terminateds[__all__] flag must also be set to "
                    "True!"
                )
                # all_agents_obs is an Exception here.
                # Drop this episode and skip to next.
                self._handle_done_episode(
                    env_id=env_id,
                    env_obs_or_exception=env_obs,
                    is_done=True,
                    active_envs=active_envs,
                    to_eval=to_eval,
                    outputs=outputs,
                )
                continue

            # Fetch (or lazily create) the episode tracker for this sub-env.
            if env_id not in self._active_episodes:
                episode: EpisodeV2 = self.create_episode(env_id)
                self._active_episodes[env_id] = episode
            else:
                episode: EpisodeV2 = self._active_episodes[env_id]
            # If this episode is brand-new, call the episode start callback(s).
            # Note: EpisodeV2s are initialized with length=-1 (before the reset).
            if not episode.has_init_obs():
                self._call_on_episode_start(episode, env_id)

            # Check episode termination conditions.
            if terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"]:
                all_agents_done = True
            else:
                all_agents_done = False
                active_envs.add(env_id)

            # Special handling of common info dict.
            episode.set_last_info("__common__", infos[env_id].get("__common__", {}))

            # Agent sample batches grouped by policy. Each set of sample batches will
            # go through agent connectors together.
            sample_batches_by_policy = defaultdict(list)
            # Whether an agent is terminated or truncated.
            agent_terminateds = {}
            agent_truncateds = {}
            for agent_id, obs in env_obs.items():
                assert agent_id != "__all__"

                policy_id: PolicyID = episode.policy_for(agent_id)

                # Per-agent done flags also honor the env-wide "__all__" flag.
                agent_terminated = bool(
                    terminateds[env_id]["__all__"] or terminateds[env_id].get(agent_id)
                )
                agent_terminateds[agent_id] = agent_terminated
                agent_truncated = bool(
                    truncateds[env_id]["__all__"]
                    or truncateds[env_id].get(agent_id, False)
                )
                agent_truncateds[agent_id] = agent_truncated

                # A completely new agent is already done -> Skip entirely.
                if not episode.has_init_obs(agent_id) and (
                    agent_terminated or agent_truncated
                ):
                    continue

                values_dict = {
                    SampleBatch.T: episode.length,  # Episodes start at -1 before we
                    # add the initial obs. After that, we infer from initial obs at
                    # t=0 since that will be our new episode.length.
                    SampleBatch.ENV_ID: env_id,
                    SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
                    # Last action (SampleBatch.ACTIONS) column will be populated by
                    # StateBufferConnector.
                    # Reward received after taking action at timestep t.
                    SampleBatch.REWARDS: rewards[env_id].get(agent_id, 0.0),
                    # After taking action=a, did we reach terminal?
                    SampleBatch.TERMINATEDS: agent_terminated,
                    # Was the episode truncated artificially
                    # (e.g. b/c of some time limit)?
                    SampleBatch.TRUNCATEDS: agent_truncated,
                    SampleBatch.INFOS: infos[env_id].get(agent_id, {}),
                    SampleBatch.NEXT_OBS: obs,
                }

                # Queue this obs sample for connector preprocessing.
                sample_batches_by_policy[policy_id].append((agent_id, values_dict))

            # The entire episode is done.
            if all_agents_done:
                # Let's check to see if there are any agents that haven't got the
                # last obs yet. If there are, we have to create fake-last
                # observations for them. (the environment is not required to do so if
                # terminateds[__all__]==True or truncateds[__all__]==True).
                for agent_id in episode.get_agents():
                    # If the latest obs we got for this agent is done, or if its
                    # episode state is already done, nothing to do.
                    if (
                        agent_terminateds.get(agent_id, False)
                        or agent_truncateds.get(agent_id, False)
                        or episode.is_done(agent_id)
                    ):
                        continue

                    policy_id: PolicyID = episode.policy_for(agent_id)
                    policy = self._worker.policy_map[policy_id]

                    # Create a fake observation by sampling the original env
                    # observation space.
                    obs_space = get_original_space(policy.observation_space)
                    # Although there is no obs for this agent, there may be
                    # good rewards and info dicts for it.
                    # This is the case for e.g. OpenSpiel games, where a reward
                    # is only earned with the last step, but the obs for that
                    # step is {}.
                    reward = rewards[env_id].get(agent_id, 0.0)
                    info = infos[env_id].get(agent_id, {})
                    values_dict = {
                        SampleBatch.T: episode.length,
                        SampleBatch.ENV_ID: env_id,
                        SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
                        # TODO(sven): These should be the summed-up(!) rewards since the
                        # last observation received for this agent.
                        SampleBatch.REWARDS: reward,
                        SampleBatch.TERMINATEDS: True,
                        SampleBatch.TRUNCATEDS: truncateds[env_id].get(agent_id, False),
                        SampleBatch.INFOS: info,
                        SampleBatch.NEXT_OBS: obs_space.sample(),
                    }

                    # Queue these fake obs for connector preprocessing too.
                    sample_batches_by_policy[policy_id].append((agent_id, values_dict))

            # Run agent connectors.
            for policy_id, batches in sample_batches_by_policy.items():
                policy: Policy = self._worker.policy_map[policy_id]
                # Collected full MultiAgentDicts for this environment.
                # Run agent connectors.
                assert (
                    policy.agent_connectors
                ), "EnvRunnerV2 requires agent connectors to work."

                acd_list: List[AgentConnectorDataType] = [
                    AgentConnectorDataType(env_id, agent_id, data)
                    for agent_id, data in batches
                ]

                # For all agents mapped to policy_id, run their data
                # through agent_connectors.
                processed = policy.agent_connectors(acd_list)

                for d in processed:
                    # Record transition info if applicable.
                    if not episode.has_init_obs(d.agent_id):
                        episode.add_init_obs(
                            agent_id=d.agent_id,
                            init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
                            init_infos=d.data.raw_dict[SampleBatch.INFOS],
                            t=d.data.raw_dict[SampleBatch.T],
                        )
                    else:
                        episode.add_action_reward_done_next_obs(
                            d.agent_id, d.data.raw_dict
                        )

                    # Need to evaluate next actions.
                    if not (
                        all_agents_done
                        or agent_terminateds.get(d.agent_id, False)
                        or agent_truncateds.get(d.agent_id, False)
                        or episode.is_done(d.agent_id)
                    ):
                        # Add to eval set if env is not done and this particular agent
                        # is also not done.
                        item = AgentConnectorDataType(d.env_id, d.agent_id, d.data)
                        to_eval[policy_id].append(item)

            # Finished advancing episode by 1 step, mark it so.
            episode.step()

            # Exception: The very first env.poll() call causes the env to get reset
            # (no step taken yet, just a single starting observation logged).
            # We need to skip this callback in this case.
            if episode.length > 0:
                # Invoke the `on_episode_step` callback after the step is logged
                # to the episode.
                self._callbacks.on_episode_step(
                    worker=self._worker,
                    base_env=self._base_env,
                    policies=self._worker.policy_map,
                    episode=episode,
                    env_index=env_id,
                )

            # Episode is terminated/truncated for all agents
            # (terminateds[__all__] == True or truncateds[__all__] == True).
            if all_agents_done:
                # _handle_done_episode will build a MultiAgentBatch for all
                # the agents that are done during this step of rollout in
                # the case of _multiple_episodes_in_batch=False.
                self._handle_done_episode(
                    env_id,
                    env_obs,
                    terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"],
                    active_envs,
                    to_eval,
                    outputs,
                )

            # Try to build something.
            if self._multiple_episodes_in_batch:
                sample_batch = self._try_build_truncated_episode_multi_agent_batch(
                    self._batch_builders[env_id], episode
                )
                if sample_batch:
                    outputs.append(sample_batch)

                    # SampleBatch built from data collected by batch_builder.
                    # Clean up and delete the batch_builder.
                    del self._batch_builders[env_id]

        return active_envs, to_eval, outputs
695
+
696
    def _build_done_episode(
        self,
        env_id: EnvID,
        is_done: bool,
        outputs: List[SampleBatchType],
    ):
        """Builds a MultiAgentSampleBatch from the episode and adds it to outputs.

        Args:
            env_id: The env id.
            is_done: Whether the env is done.
            outputs: The list of outputs to add the built batch to (only
                appended to when not packing multiple episodes per batch).
        """
        episode: EpisodeV2 = self._active_episodes[env_id]
        batch_builder = self._batch_builders[env_id]

        # Post-process (e.g. advantage computation) and flush the episode's
        # collected data into the per-policy collectors.
        episode.postprocess_episode(
            batch_builder=batch_builder,
            is_done=is_done,
            check_dones=is_done,
        )

        # If, we are not allowed to pack the next episode into the same
        # SampleBatch (batch_mode=complete_episodes) -> Build the
        # MultiAgentBatch from a single episode and add it to "outputs".
        # Otherwise, just postprocess and continue collecting across
        # episodes.
        if not self._multiple_episodes_in_batch:
            ma_sample_batch = _build_multi_agent_batch(
                episode.episode_id,
                batch_builder,
                self._large_batch_threshold,
                self._multiple_episodes_in_batch,
            )
            if ma_sample_batch:
                outputs.append(ma_sample_batch)

            # SampleBatch built from data collected by batch_builder.
            # Clean up and delete the batch_builder.
            del self._batch_builders[env_id]
736
+
737
    def __process_resetted_obs_for_eval(
        self,
        env_id: EnvID,
        obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
        infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
        episode: EpisodeV2,
        to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
    ):
        """Process resetted obs through agent connectors for policy eval.

        Args:
            env_id: The env id.
            obs: The Resetted obs (keyed by env, then agent).
            infos: The infos dicts accompanying the reset obs.
            episode: New episode.
            to_eval: List of agent connector data for policy eval.
        """
        # Group the per-agent reset observations by the policy serving them.
        per_policy_resetted_obs: Dict[PolicyID, List] = defaultdict(list)
        # types: AgentID, EnvObsType
        for agent_id, raw_obs in obs[env_id].items():
            policy_id: PolicyID = episode.policy_for(agent_id)
            per_policy_resetted_obs[policy_id].append((agent_id, raw_obs))

        for policy_id, agents_obs in per_policy_resetted_obs.items():
            policy = self._worker.policy_map[policy_id]
            # NOTE: the comprehension variable `obs` below shadows the `obs`
            # parameter: NEXT_OBS receives the per-agent raw obs. INFOS,
            # however, receives the full `infos` structure — presumably
            # downstream connectors index into it; TODO confirm.
            acd_list: List[AgentConnectorDataType] = [
                AgentConnectorDataType(
                    env_id,
                    agent_id,
                    {
                        SampleBatch.NEXT_OBS: obs,
                        SampleBatch.INFOS: infos,
                        SampleBatch.T: episode.length,
                        SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
                    },
                )
                for agent_id, obs in agents_obs
            ]
            # Call agent connectors on these initial obs.
            processed = policy.agent_connectors(acd_list)

            for d in processed:
                episode.add_init_obs(
                    agent_id=d.agent_id,
                    init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
                    init_infos=d.data.raw_dict[SampleBatch.INFOS],
                    t=d.data.raw_dict[SampleBatch.T],
                )
                to_eval[policy_id].append(d)
785
+
786
    def _handle_done_episode(
        self,
        env_id: EnvID,
        env_obs_or_exception: MultiAgentDict,
        is_done: bool,
        active_envs: Set[EnvID],
        to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
        outputs: List[SampleBatchType],
    ) -> None:
        """Handle an all-finished episode.

        Adds collected SampleBatch data to the batch builder, emits rollout
        metrics, ends the finished episode, and resets the corresponding
        sub-environment (starting a fresh EpisodeV2 for it).

        Args:
            env_id: Environment ID.
            env_obs_or_exception: Last per-environment observation, or the
                Exception raised by a faulty sub-environment.
            is_done: If all agents are done.
            active_envs: Set of active env ids.
            to_eval: Output container for policy eval data.
            outputs: Output container for collected sample batches.
        """
        if isinstance(env_obs_or_exception, Exception):
            episode_or_exception: Exception = env_obs_or_exception
            # Tell the sampler we have got a faulty episode.
            outputs.append(RolloutMetrics(episode_faulty=True))
        else:
            episode_or_exception: EpisodeV2 = self._active_episodes[env_id]
            # Add rollout metrics.
            outputs.extend(
                self._get_rollout_metrics(
                    episode_or_exception, policy_map=self._worker.policy_map
                )
            )
        # Output the collected episode after adding rollout metrics so that we
        # always fetch metrics with RolloutWorker before we fetch samples.
        # This is because we need to behave like env_runner() for now.
        self._build_done_episode(env_id, is_done, outputs)

        # Clean up and delete the post-processed episode now that we have
        # collected its data.
        self.end_episode(env_id, episode_or_exception)
        # Create a new episode instance (before we reset the sub-environment).
        new_episode: EpisodeV2 = self.create_episode(env_id)

        # The sub environment at index `env_id` might throw an exception
        # during the following `try_reset()` attempt. If configured with
        # `restart_failed_sub_environments=True`, the BaseEnv will restart
        # the affected sub environment (create a new one using its c'tor) and
        # must reset the recreated sub env right after that.
        # Should the sub environment fail indefinitely during these
        # repeated reset attempts, the entire worker will be blocked.
        # This would be ok, b/c the alternative would be the worker crashing
        # entirely.
        while True:
            resetted_obs, resetted_infos = self._base_env.try_reset(env_id)

            # Success (or async reset pending) -> stop retrying.
            if (
                resetted_obs is None
                or resetted_obs == ASYNC_RESET_RETURN
                or not isinstance(resetted_obs[env_id], Exception)
            ):
                break
            else:
                # Report a faulty episode.
                outputs.append(RolloutMetrics(episode_faulty=True))

                # Reset connector state if this is a hard reset.
                for p in self._worker.policy_map.cache.values():
                    p.agent_connectors.reset(env_id)

        # Creates a new episode if this is not async return.
        # If reset is async, we will get its result in some future poll.
        if resetted_obs is not None and resetted_obs != ASYNC_RESET_RETURN:
            self._active_episodes[env_id] = new_episode
            self._call_on_episode_start(new_episode, env_id)

            self.__process_resetted_obs_for_eval(
                env_id,
                resetted_obs,
                resetted_infos,
                new_episode,
                to_eval,
            )

            # Step after adding initial obs. This will give us 0 env and agent step.
            new_episode.step()

        active_envs.add(env_id)
874
+
875
+ def create_episode(self, env_id: EnvID) -> EpisodeV2:
876
+ """Creates a new EpisodeV2 instance and returns it.
877
+
878
+ Calls `on_episode_created` callbacks, but does NOT reset the respective
879
+ sub-environment yet.
880
+
881
+ Args:
882
+ env_id: Env ID.
883
+
884
+ Returns:
885
+ The newly created EpisodeV2 instance.
886
+ """
887
+ # Make sure we currently don't have an active episode under this env ID.
888
+ assert env_id not in self._active_episodes
889
+
890
+ # Create a new episode under the same `env_id` and call the
891
+ # `on_episode_created` callbacks.
892
+ new_episode = EpisodeV2(
893
+ env_id,
894
+ self._worker.policy_map,
895
+ self._worker.policy_mapping_fn,
896
+ worker=self._worker,
897
+ callbacks=self._callbacks,
898
+ )
899
+
900
+ # Call `on_episode_created()` callback.
901
+ self._callbacks.on_episode_created(
902
+ worker=self._worker,
903
+ base_env=self._base_env,
904
+ policies=self._worker.policy_map,
905
+ env_index=env_id,
906
+ episode=new_episode,
907
+ )
908
+ return new_episode
909
+
910
    def end_episode(
        self, env_id: EnvID, episode_or_exception: Union[EpisodeV2, Exception]
    ):
        """Cleans up an episode that has finished.

        Fires the `on_episode_end` callback and each cached policy's
        Exploration end-of-episode hook, validates that the episode recorded
        at least one agent step, then drops it from the active-episodes map.

        Args:
            env_id: Env ID.
            episode_or_exception: Instance of an episode if it finished
                successfully. Otherwise, the exception that was thrown.

        Raises:
            ValueError: If a successfully finished episode never recorded any
                agent interactions.
        """
        # Signal the end of an episode, either successfully with an Episode or
        # unsuccessfully with an Exception.
        self._callbacks.on_episode_end(
            worker=self._worker,
            base_env=self._base_env,
            policies=self._worker.policy_map,
            episode=episode_or_exception,
            env_index=env_id,
        )

        # Call each (in-memory) policy's Exploration.on_episode_end
        # method.
        # Note: This may break the exploration (e.g. ParameterNoise) of
        # policies in the `policy_map` that have not been recently used
        # (and are therefore stashed to disk). However, we certainly do not
        # want to loop through all (even stashed) policies here as that
        # would counter the purpose of the LRU policy caching.
        for p in self._worker.policy_map.cache.values():
            if getattr(p, "exploration", None) is not None:
                p.exploration.on_episode_end(
                    policy=p,
                    environment=self._base_env,
                    episode=episode_or_exception,
                    tf_sess=p.get_session(),
                )

        if isinstance(episode_or_exception, EpisodeV2):
            episode = episode_or_exception
            if episode.total_agent_steps == 0:
                # if the key does not exist it means that throughout the episode all
                # observations were empty (i.e. there was no agent in the env)
                msg = (
                    f"Data from episode {episode.episode_id} does not show any agent "
                    f"interactions. Hint: Make sure for at least one timestep in the "
                    f"episode, env.step() returns non-empty values."
                )
                raise ValueError(msg)

        # Clean up the episode and batch_builder for this env id.
        if env_id in self._active_episodes:
            del self._active_episodes[env_id]
961
+
962
    def _try_build_truncated_episode_multi_agent_batch(
        self, batch_builder: _PolicyCollectorGroup, episode: EpisodeV2
    ) -> Union[None, SampleBatch, MultiAgentBatch]:
        """Builds a multi-agent batch once the rollout fragment length is hit.

        Args:
            batch_builder: Per-policy collector group holding the already
                postprocessed data for the batch under construction.
            episode: The (possibly still ongoing) episode whose active steps
                count toward the fragment length.

        Returns:
            The built batch if `rollout_fragment_length` has been reached and
            any agent data was collected; None otherwise.
        """
        # Measure batch size in env-steps.
        if self._count_steps_by == "env_steps":
            built_steps = batch_builder.env_steps
            ongoing_steps = episode.active_env_steps
        # Measure batch-size in agent-steps.
        else:
            built_steps = batch_builder.agent_steps
            ongoing_steps = episode.active_agent_steps

        # Reached the fragment-len -> We should build an MA-Batch.
        if built_steps + ongoing_steps >= self._rollout_fragment_length:
            # In env-steps mode, steps accrue one env-step at a time, so we
            # should land exactly on the fragment length (never overshoot).
            if self._count_steps_by != "agent_steps":
                assert built_steps + ongoing_steps == self._rollout_fragment_length, (
                    f"built_steps ({built_steps}) + ongoing_steps ({ongoing_steps}) != "
                    f"rollout_fragment_length ({self._rollout_fragment_length})."
                )

            # If we reached the fragment-len only because of `episode_id`
            # (still ongoing) -> postprocess `episode_id` first.
            if built_steps < self._rollout_fragment_length:
                episode.postprocess_episode(batch_builder=batch_builder, is_done=False)

            # If builder has collected some data,
            # build the MA-batch and add to return values.
            if batch_builder.agent_steps > 0:
                return _build_multi_agent_batch(
                    episode.episode_id,
                    batch_builder,
                    self._large_batch_threshold,
                    self._multiple_episodes_in_batch,
                )
            # No batch-builder:
            # We have reached the rollout-fragment length w/o any agent
            # steps! Warn that the environment may never request any
            # actions from any agents.
            elif log_once("no_agent_steps"):
                logger.warning(
                    "Your environment seems to be stepping w/o ever "
                    "emitting agent observations (agents are never "
                    "requested to act)!"
                )

        return None
1008
+
1009
+ def _do_policy_eval(
1010
+ self,
1011
+ to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
1012
+ ) -> Dict[PolicyID, PolicyOutputType]:
1013
+ """Call compute_actions on collected episode data to get next action.
1014
+
1015
+ Args:
1016
+ to_eval: Mapping of policy IDs to lists of AgentConnectorDataType objects
1017
+ (items in these lists will be the batch's items for the model
1018
+ forward pass).
1019
+
1020
+ Returns:
1021
+ Dict mapping PolicyIDs to compute_actions_from_input_dict() outputs.
1022
+ """
1023
+ policies = self._worker.policy_map
1024
+
1025
+ # In case policy map has changed, try to find the new policy that
1026
+ # should handle all these per-agent eval data.
1027
+ # Throws exception if these agents are mapped to multiple different
1028
+ # policies now.
1029
+ def _try_find_policy_again(eval_data: AgentConnectorDataType):
1030
+ policy_id = None
1031
+ for d in eval_data:
1032
+ episode = self._active_episodes[d.env_id]
1033
+ # Force refresh policy mapping on the episode.
1034
+ pid = episode.policy_for(d.agent_id, refresh=True)
1035
+ if policy_id is not None and pid != policy_id:
1036
+ raise ValueError(
1037
+ "Policy map changed. The list of eval data that was handled "
1038
+ f"by a same policy is now handled by policy {pid} "
1039
+ "and {policy_id}. "
1040
+ "Please don't do this in the middle of an episode."
1041
+ )
1042
+ policy_id = pid
1043
+ return _get_or_raise(self._worker.policy_map, policy_id)
1044
+
1045
+ eval_results: Dict[PolicyID, TensorStructType] = {}
1046
+ for policy_id, eval_data in to_eval.items():
1047
+ # In case the policyID has been removed from this worker, we need to
1048
+ # re-assign policy_id and re-lookup the Policy object to use.
1049
+ try:
1050
+ policy: Policy = _get_or_raise(policies, policy_id)
1051
+ except ValueError:
1052
+ # policy_mapping_fn from the worker may have already been
1053
+ # changed (mapping fn not staying constant within one episode).
1054
+ policy: Policy = _try_find_policy_again(eval_data)
1055
+
1056
+ input_dict = _batch_inference_sample_batches(
1057
+ [d.data.sample_batch for d in eval_data]
1058
+ )
1059
+
1060
+ eval_results[policy_id] = policy.compute_actions_from_input_dict(
1061
+ input_dict,
1062
+ timestep=policy.global_timestep,
1063
+ episodes=[self._active_episodes[t.env_id] for t in eval_data],
1064
+ )
1065
+
1066
+ return eval_results
1067
+
1068
    def _process_policy_eval_results(
        self,
        active_envs: Set[EnvID],
        to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
        eval_results: Dict[PolicyID, PolicyOutputType],
        off_policy_actions: MultiEnvDict,
    ):
        """Process the output of policy neural network evaluation.

        Records policy evaluation results into agent connectors and
        returns replies to send back to agents in the env.

        Args:
            active_envs: Set of env IDs that are still active.
            to_eval: Mapping of policy IDs to lists of AgentConnectorDataType objects.
            eval_results: Mapping of policy IDs to list of
                actions, rnn-out states, extra-action-fetches dicts.
            off_policy_actions: Doubly keyed dict of env-ids -> agent ids ->
                off-policy-action, returned by a `BaseEnv.poll()` call.

        Returns:
            Nested dict of env id -> agent id -> actions to be sent to
            Env (np.ndarrays).
        """
        actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = defaultdict(dict)

        for env_id in active_envs:
            actions_to_send[env_id] = {}  # at minimum send empty dict

        # types: PolicyID, List[AgentConnectorDataType]
        for policy_id, eval_data in to_eval.items():
            actions: TensorStructType = eval_results[policy_id][0]
            actions = convert_to_numpy(actions)

            rnn_out: StateBatches = eval_results[policy_id][1]
            extra_action_out: dict = eval_results[policy_id][2]

            # In case actions is a list (representing the 0th dim of a batch of
            # primitive actions), try converting it first.
            if isinstance(actions, list):
                actions = np.array(actions)
            # Split action-component batches into single action rows.
            actions: List[EnvActionType] = unbatch(actions)

            policy: Policy = _get_or_raise(self._worker.policy_map, policy_id)
            assert (
                policy.agent_connectors and policy.action_connectors
            ), "EnvRunnerV2 requires action connectors to work."

            # types: int, EnvActionType
            for i, action in enumerate(actions):
                env_id: int = eval_data[i].env_id
                agent_id: AgentID = eval_data[i].agent_id
                input_dict: TensorStructType = eval_data[i].data.raw_dict

                # Pick row i out of the (possibly nested) batched RNN states.
                rnn_states: List[StateBatches] = tree.map_structure(
                    lambda x, i=i: x[i], rnn_out
                )

                # extra_action_out could be a nested dict
                fetches: Dict = tree.map_structure(
                    lambda x, i=i: x[i], extra_action_out
                )

                # Post-process policy output by running them through action connectors.
                ac_data = ActionConnectorDataType(
                    env_id, agent_id, input_dict, (action, rnn_states, fetches)
                )

                action_to_send, rnn_states, fetches = policy.action_connectors(
                    ac_data
                ).output

                # The action we want to buffer is the direct output of
                # compute_actions_from_input_dict() here. This is because we want to
                # send the unsquashed actions to the environment while learning and
                # possibly basing subsequent actions on the squashed actions.
                action_to_buffer = (
                    action
                    if env_id not in off_policy_actions
                    or agent_id not in off_policy_actions[env_id]
                    else off_policy_actions[env_id][agent_id]
                )

                # Notify agent connectors with this new policy output.
                # Necessary for state buffering agent connectors, for example.
                ac_data: ActionConnectorDataType = ActionConnectorDataType(
                    env_id,
                    agent_id,
                    input_dict,
                    (action_to_buffer, rnn_states, fetches),
                )
                policy.agent_connectors.on_policy_output(ac_data)

                # Each agent may only receive one action per env step.
                assert agent_id not in actions_to_send[env_id]
                actions_to_send[env_id][agent_id] = action_to_send

        return actions_to_send
1166
+
1167
+ def _maybe_render(self):
1168
+ """Visualize environment."""
1169
+ # Check if we should render.
1170
+ if not self._render or not self._simple_image_viewer:
1171
+ return
1172
+
1173
+ t5 = time.time()
1174
+
1175
+ # Render can either return an RGB image (uint8 [w x h x 3] numpy
1176
+ # array) or take care of rendering itself (returning True).
1177
+ rendered = self._base_env.try_render()
1178
+ # Rendering returned an image -> Display it in a SimpleImageViewer.
1179
+ if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3:
1180
+ self._simple_image_viewer.imshow(rendered)
1181
+ elif rendered not in [True, False, None]:
1182
+ raise ValueError(
1183
+ f"The env's ({self._base_env}) `try_render()` method returned an"
1184
+ " unsupported value! Make sure you either return a "
1185
+ "uint8/w x h x 3 (RGB) image or handle rendering in a "
1186
+ "window and then return `True`."
1187
+ )
1188
+
1189
+ self._perf_stats.incr("env_render_time", time.time() - t5)
1190
+
1191
+
1192
def _fetch_atari_metrics(base_env: BaseEnv) -> List[RolloutMetrics]:
    """Atari games have multiple logical episodes, one per life.

    However, for metrics reporting we count full episodes, all lives included.

    Returns None if `base_env` has no sub-environments or any sub-environment
    lacks a MonitorEnv wrapper; otherwise one RolloutMetrics per finished
    full episode.
    """
    sub_envs = base_env.get_sub_environments()
    if not sub_envs:
        return None

    metrics = []
    for sub_env in sub_envs:
        # Full-episode stats are tracked by the MonitorEnv wrapper.
        monitor = get_wrapper_by_cls(sub_env, MonitorEnv)
        if not monitor:
            return None
        metrics.extend(
            RolloutMetrics(eps_len, eps_rew)
            for eps_rew, eps_len in monitor.next_episode_results()
        )
    return metrics
1208
+
1209
+
1210
def _get_or_raise(
    mapping: Dict[PolicyID, Union[Policy, Preprocessor, Filter]], policy_id: PolicyID
) -> Union[Policy, Preprocessor, Filter]:
    """Returns the object stored under `policy_id` in `mapping`.

    Args:
        mapping (Dict[PolicyID, Union[Policy, Preprocessor, Filter]]): The
            mapping dict from policy id (str) to actual object (Policy,
            Preprocessor, etc.).
        policy_id: The policy ID to lookup.

    Returns:
        Union[Policy, Preprocessor, Filter]: The found object.

    Raises:
        ValueError: If `policy_id` cannot be found in `mapping`.
    """
    # Guard clause: known ID -> hand the object straight back.
    if policy_id in mapping:
        return mapping[policy_id]
    raise ValueError(
        "Could not find policy for agent: PolicyID `{}` not found "
        "in policy map, whose keys are `{}`.".format(policy_id, mapping.keys())
    )
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/episode_v2.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from collections import defaultdict
3
+ import numpy as np
4
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
5
+
6
+ from ray.rllib.env.base_env import _DUMMY_AGENT_ID
7
+ from ray.rllib.evaluation.collectors.simple_list_collector import (
8
+ _PolicyCollector,
9
+ _PolicyCollectorGroup,
10
+ )
11
+ from ray.rllib.evaluation.collectors.agent_collector import AgentCollector
12
+ from ray.rllib.policy.policy_map import PolicyMap
13
+ from ray.rllib.policy.sample_batch import SampleBatch
14
+ from ray.rllib.utils.annotations import OldAPIStack
15
+ from ray.rllib.utils.typing import AgentID, EnvID, EnvInfoDict, PolicyID, TensorType
16
+
17
+ if TYPE_CHECKING:
18
+ from ray.rllib.callbacks.callbacks import RLlibCallback
19
+ from ray.rllib.evaluation.rollout_worker import RolloutWorker
20
+
21
+
22
@OldAPIStack
class EpisodeV2:
    """Tracks the current state of a (possibly multi-agent) episode.

    Pure bookkeeping: agent->policy bindings, per-agent AgentCollectors
    holding raw trajectory data, reward sums, step counters, and the last
    terminated/truncated/info signals seen per agent.
    """

    def __init__(
        self,
        env_id: EnvID,
        policies: PolicyMap,
        policy_mapping_fn: Callable[[AgentID, "EpisodeV2", "RolloutWorker"], PolicyID],
        *,
        worker: Optional["RolloutWorker"] = None,
        callbacks: Optional["RLlibCallback"] = None,
    ):
        """Initializes an Episode instance.

        Args:
            env_id: The environment's ID in which this episode runs.
            policies: The PolicyMap object (mapping PolicyIDs to Policy
                objects) to use for determining, which policy is used for
                which agent.
            policy_mapping_fn: The mapping function mapping AgentIDs to
                PolicyIDs.
            worker: The RolloutWorker instance, in which this episode runs.
            callbacks: Callback object; used by `postprocess_episode` to fire
                `on_postprocess_trajectory`.
        """
        # Unique id identifying this trajectory.
        self.episode_id: int = random.randrange(int(1e18))
        # ID of the environment this episode is tracking.
        self.env_id = env_id
        # Summed reward across all agents in this episode.
        self.total_reward: float = 0.0
        # Active (uncollected) # of env steps taken by this episode.
        # Start from -1. After add_init_obs(), we will be at 0 step.
        self.active_env_steps: int = -1
        # Total # of env steps taken by this episode.
        # Start from -1, After add_init_obs(), we will be at 0 step.
        self.total_env_steps: int = -1
        # Active (uncollected) agent steps.
        self.active_agent_steps: int = 0
        # Total # of steps take by all agents in this env.
        self.total_agent_steps: int = 0
        # Dict for user to add custom metrics.
        # TODO (sven): We should probably unify custom_metrics, user_data,
        # and hist_data into a single data container for user to track per-step.
        # metrics and states.
        self.custom_metrics: Dict[str, float] = {}
        # Temporary storage. E.g. storing data in between two custom
        # callbacks referring to the same episode.
        self.user_data: Dict[str, Any] = {}
        # Dict mapping str keys to List[float] for storage of
        # per-timestep float data throughout the episode.
        self.hist_data: Dict[str, List[float]] = {}
        # Arbitrary media (e.g. rendered videos) attached by user callbacks.
        self.media: Dict[str, Any] = {}

        self.worker = worker
        self.callbacks = callbacks

        self.policy_map: PolicyMap = policies
        self.policy_mapping_fn: Callable[
            [AgentID, "EpisodeV2", "RolloutWorker"], PolicyID
        ] = policy_mapping_fn
        # Per-agent data collectors.
        self._agent_to_policy: Dict[AgentID, PolicyID] = {}
        self._agent_collectors: Dict[AgentID, AgentCollector] = {}

        # Monotonically increasing index handed out to newly seen agents.
        self._next_agent_index: int = 0
        self._agent_to_index: Dict[AgentID, int] = {}

        # Summed rewards broken down by agent.
        self.agent_rewards: Dict[Tuple[AgentID, PolicyID], float] = defaultdict(float)
        self._agent_reward_history: Dict[AgentID, List[float]] = defaultdict(list)

        self._has_init_obs: Dict[AgentID, bool] = {}
        self._last_terminateds: Dict[AgentID, bool] = {}
        self._last_truncateds: Dict[AgentID, bool] = {}
        # Keep last info dict around, in case an environment tries to signal
        # us something.
        self._last_infos: Dict[AgentID, Dict] = {}

    def policy_for(
        self, agent_id: AgentID = _DUMMY_AGENT_ID, refresh: bool = False
    ) -> PolicyID:
        """Returns and stores the policy ID for the specified agent.

        If the agent is new, the policy mapping fn will be called to bind the
        agent to a policy for the duration of the entire episode (even if the
        policy_mapping_fn is changed in the meantime!).

        Args:
            agent_id: The agent ID to lookup the policy ID for.
            refresh: If True, force a re-run of the policy mapping fn even if
                this agent was already bound to a policy.

        Returns:
            The policy ID for the specified agent.

        Raises:
            KeyError: If the mapped policy ID is not in `self.policy_map`.
        """

        # Perform a new policy_mapping_fn lookup and bind AgentID for the
        # duration of this episode to the returned PolicyID.
        if agent_id not in self._agent_to_policy or refresh:
            policy_id = self._agent_to_policy[agent_id] = self.policy_mapping_fn(
                agent_id,  # agent_id
                self,  # episode
                worker=self.worker,
            )
        # Use already determined PolicyID.
        else:
            policy_id = self._agent_to_policy[agent_id]

        # PolicyID not found in policy map -> Error.
        if policy_id not in self.policy_map:
            raise KeyError(
                "policy_mapping_fn returned invalid policy id " f"'{policy_id}'!"
            )
        return policy_id

    def get_agents(self) -> List[AgentID]:
        """Returns list of agent IDs that have appeared in this episode.

        Returns:
            The list of all agent IDs that have appeared so far in this
            episode.
        """
        return list(self._agent_to_index.keys())

    def agent_index(self, agent_id: AgentID) -> int:
        """Get the index of an agent among its environment.

        A new index will be created if an agent is seen for the first time.

        Args:
            agent_id: ID of an agent.

        Returns:
            The index of this agent.
        """
        if agent_id not in self._agent_to_index:
            self._agent_to_index[agent_id] = self._next_agent_index
            self._next_agent_index += 1
        return self._agent_to_index[agent_id]

    def step(self) -> None:
        """Advance the episode forward by one step."""
        self.active_env_steps += 1
        self.total_env_steps += 1

    def add_init_obs(
        self,
        *,
        agent_id: AgentID,
        init_obs: TensorType,
        init_infos: Dict[str, TensorType],
        t: int = -1,
    ) -> None:
        """Add initial env obs at the start of a new episode.

        Creates this agent's AgentCollector and seeds it with the reset
        observation/infos. Must be called at most once per agent.

        Args:
            agent_id: Agent ID.
            init_obs: Initial observations.
            init_infos: Initial infos dicts.
            t: timestamp.
        """
        policy = self.policy_map[self.policy_for(agent_id)]

        # Add initial obs to Trajectory.
        assert agent_id not in self._agent_collectors

        self._agent_collectors[agent_id] = AgentCollector(
            policy.view_requirements,
            max_seq_len=policy.config["model"]["max_seq_len"],
            disable_action_flattening=policy.config.get(
                "_disable_action_flattening", False
            ),
            is_policy_recurrent=policy.is_recurrent(),
            # (sic) "intial_states" is the keyword AgentCollector expects.
            intial_states=policy.get_initial_state(),
            _enable_new_api_stack=False,
        )
        self._agent_collectors[agent_id].add_init_obs(
            episode_id=self.episode_id,
            agent_index=self.agent_index(agent_id),
            env_id=self.env_id,
            init_obs=init_obs,
            init_infos=init_infos,
            t=t,
        )

        self._has_init_obs[agent_id] = True

    def add_action_reward_done_next_obs(
        self,
        agent_id: AgentID,
        values: Dict[str, TensorType],
    ) -> None:
        """Add action, reward, info, and next_obs as a new step.

        Also updates reward totals/history and the last
        terminated/truncated/info trackers for this agent.

        Args:
            agent_id: Agent ID.
            values: Dict of action, reward, info, and next_obs.
        """
        # Make sure, agent already has some (at least init) data.
        assert agent_id in self._agent_collectors

        self.active_agent_steps += 1
        self.total_agent_steps += 1

        # Include the current agent id for multi-agent algorithms.
        if agent_id != _DUMMY_AGENT_ID:
            values["agent_id"] = agent_id

        # Add action/reward/next-obs (and other data) to Trajectory.
        self._agent_collectors[agent_id].add_action_reward_next_obs(values)

        # Keep track of agent reward history.
        reward = values[SampleBatch.REWARDS]
        self.total_reward += reward
        self.agent_rewards[(agent_id, self.policy_for(agent_id))] += reward
        self._agent_reward_history[agent_id].append(reward)

        # Keep track of last terminated info for agent.
        if SampleBatch.TERMINATEDS in values:
            self._last_terminateds[agent_id] = values[SampleBatch.TERMINATEDS]
        # Keep track of last truncated info for agent.
        if SampleBatch.TRUNCATEDS in values:
            self._last_truncateds[agent_id] = values[SampleBatch.TRUNCATEDS]

        # Keep track of last info dict if available.
        if SampleBatch.INFOS in values:
            self.set_last_info(agent_id, values[SampleBatch.INFOS])

    def postprocess_episode(
        self,
        batch_builder: _PolicyCollectorGroup,
        is_done: bool = False,
        check_dones: bool = False,
    ) -> None:
        """Build and return currently collected training samples by policies.

        Clear agent collector states if this episode is done.

        Args:
            batch_builder: _PolicyCollectorGroup for saving the collected per-agent
                sample batches.
            is_done: If this episode is done (terminated or truncated).
            check_dones: Whether to make sure per-agent trajectories are actually done.

        Raises:
            ValueError: If `is_done` and `check_dones` but an agent trajectory
                lacks a final done signal, or if a pre-batch spans more than
                one trajectory.
        """
        # TODO: (sven) Once we implement multi-agent communication channels,
        # we have to resolve the restriction of only sending other agent
        # batches from the same policy to the postprocess methods.
        # Build SampleBatches for the given episode.
        pre_batches = {}
        for agent_id, collector in self._agent_collectors.items():
            # Build only if there is data and agent is part of given episode.
            if collector.agent_steps == 0:
                continue
            pid = self.policy_for(agent_id)
            policy = self.policy_map[pid]
            pre_batch = collector.build_for_training(policy.view_requirements)
            pre_batches[agent_id] = (pid, policy, pre_batch)

        for agent_id, (pid, policy, pre_batch) in pre_batches.items():
            # Entire episode is said to be done.
            # Error if no DONE at end of this agent's trajectory.
            if is_done and check_dones and not pre_batch.is_terminated_or_truncated():
                raise ValueError(
                    "Episode {} terminated for all agents, but we still "
                    "don't have a last observation for agent {} (policy "
                    "{}). ".format(self.episode_id, agent_id, self.policy_for(agent_id))
                    + "Please ensure that you include the last observations "
                    "of all live agents when setting done[__all__] to "
                    "True."
                )

            # Skip a trajectory's postprocessing (and thus using it for training),
            # if its agent's info exists and contains the training_enabled=False
            # setting (used by our PolicyClients).
            if not self._last_infos.get(agent_id, {}).get("training_enabled", True):
                continue

            if (
                not pre_batch.is_single_trajectory()
                or len(np.unique(pre_batch[SampleBatch.EPS_ID])) > 1
            ):
                raise ValueError(
                    "Batches sent to postprocessing must only contain steps "
                    "from a single trajectory.",
                    pre_batch,
                )

            # Other agents' batches may be fed to this agent's postprocessing.
            if len(pre_batches) > 1:
                other_batches = pre_batches.copy()
                del other_batches[agent_id]
            else:
                other_batches = {}

            # Call the Policy's Exploration's postprocess method.
            post_batch = pre_batch
            if getattr(policy, "exploration", None) is not None:
                policy.exploration.postprocess_trajectory(
                    policy, post_batch, policy.get_session()
                )
            post_batch.set_get_interceptor(None)
            post_batch = policy.postprocess_trajectory(post_batch, other_batches, self)

            # Imported here to avoid a circular import at module load time.
            from ray.rllib.evaluation.rollout_worker import get_global_worker

            self.callbacks.on_postprocess_trajectory(
                worker=get_global_worker(),
                episode=self,
                agent_id=agent_id,
                policy_id=pid,
                policies=self.policy_map,
                postprocessed_batch=post_batch,
                original_batches=pre_batches,
            )

            # Append post_batch for return.
            if pid not in batch_builder.policy_collectors:
                batch_builder.policy_collectors[pid] = _PolicyCollector(policy)
            batch_builder.policy_collectors[pid].add_postprocessed_batch_for_training(
                post_batch, policy.view_requirements
            )

        batch_builder.agent_steps += self.active_agent_steps
        batch_builder.env_steps += self.active_env_steps

        # AgentCollector cleared.
        self.active_agent_steps = 0
        self.active_env_steps = 0

    def has_init_obs(self, agent_id: AgentID = None) -> bool:
        """Returns whether this episode has initial obs for an agent.

        If agent_id is None, return whether we have received any initial obs,
        in other words, whether this episode is completely fresh.
        """
        if agent_id is not None:
            return agent_id in self._has_init_obs and self._has_init_obs[agent_id]
        else:
            return any(list(self._has_init_obs.values()))

    def is_done(self, agent_id: AgentID) -> bool:
        """Returns True if the agent is either terminated or truncated."""
        return self.is_terminated(agent_id) or self.is_truncated(agent_id)

    def is_terminated(self, agent_id: AgentID) -> bool:
        """Returns the last terminated flag recorded for `agent_id` (default False)."""
        return self._last_terminateds.get(agent_id, False)

    def is_truncated(self, agent_id: AgentID) -> bool:
        """Returns the last truncated flag recorded for `agent_id` (default False)."""
        return self._last_truncateds.get(agent_id, False)

    def set_last_info(self, agent_id: AgentID, info: Dict):
        """Stores `info` as the most recent info dict for `agent_id`."""
        self._last_infos[agent_id] = info

    def last_info_for(
        self, agent_id: AgentID = _DUMMY_AGENT_ID
    ) -> Optional[EnvInfoDict]:
        """Returns the most recent info dict for `agent_id`, or None."""
        return self._last_infos.get(agent_id)

    @property
    def length(self):
        """Episode length in total env steps (0 right after the initial obs)."""
        return self.total_env_steps
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/metrics.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import logging
3
+ import numpy as np
4
+ from typing import List, Optional, TYPE_CHECKING
5
+
6
+ from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
7
+ from ray.rllib.utils.annotations import OldAPIStack
8
+ from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
9
+ from ray.rllib.utils.typing import GradInfoDict, LearnerStatsDict, ResultDict
10
+
11
+ if TYPE_CHECKING:
12
+ from ray.rllib.env.env_runner_group import EnvRunnerGroup
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ RolloutMetrics = OldAPIStack(
17
+ collections.namedtuple(
18
+ "RolloutMetrics",
19
+ [
20
+ "episode_length",
21
+ "episode_reward",
22
+ "agent_rewards",
23
+ "custom_metrics",
24
+ "perf_stats",
25
+ "hist_data",
26
+ "media",
27
+ "episode_faulty",
28
+ "connector_metrics",
29
+ ],
30
+ )
31
+ )
32
+ RolloutMetrics.__new__.__defaults__ = (0, 0, {}, {}, {}, {}, {}, False, {})
33
+
34
+
35
+ @OldAPIStack
36
+ def get_learner_stats(grad_info: GradInfoDict) -> LearnerStatsDict:
37
+ """Return optimization stats reported from the policy.
38
+
39
+ .. testcode::
40
+ :skipif: True
41
+
42
+ grad_info = worker.learn_on_batch(samples)
43
+
44
+ # {"td_error": [...], "learner_stats": {"vf_loss": ..., ...}}
45
+
46
+ print(get_stats(grad_info))
47
+
48
+ .. testoutput::
49
+
50
+ {"vf_loss": ..., "policy_loss": ...}
51
+ """
52
+ if LEARNER_STATS_KEY in grad_info:
53
+ return grad_info[LEARNER_STATS_KEY]
54
+
55
+ multiagent_stats = {}
56
+ for k, v in grad_info.items():
57
+ if type(v) is dict:
58
+ if LEARNER_STATS_KEY in v:
59
+ multiagent_stats[k] = v[LEARNER_STATS_KEY]
60
+
61
+ return multiagent_stats
62
+
63
+
64
+ @OldAPIStack
65
+ def collect_metrics(
66
+ workers: "EnvRunnerGroup",
67
+ remote_worker_ids: Optional[List[int]] = None,
68
+ timeout_seconds: int = 180,
69
+ keep_custom_metrics: bool = False,
70
+ ) -> ResultDict:
71
+ """Gathers episode metrics from rollout worker set.
72
+
73
+ Args:
74
+ workers: EnvRunnerGroup.
75
+ remote_worker_ids: Optional list of IDs of remote workers to collect
76
+ metrics from.
77
+ timeout_seconds: Timeout in seconds for collecting metrics from remote workers.
78
+ keep_custom_metrics: Whether to keep custom metrics in the result dict as
79
+ they are (True) or to aggregate them (False).
80
+
81
+ Returns:
82
+ A result dict of metrics.
83
+ """
84
+ episodes = collect_episodes(
85
+ workers, remote_worker_ids, timeout_seconds=timeout_seconds
86
+ )
87
+ metrics = summarize_episodes(
88
+ episodes, episodes, keep_custom_metrics=keep_custom_metrics
89
+ )
90
+ return metrics
91
+
92
+
93
+ @OldAPIStack
94
+ def collect_episodes(
95
+ workers: "EnvRunnerGroup",
96
+ remote_worker_ids: Optional[List[int]] = None,
97
+ timeout_seconds: int = 180,
98
+ ) -> List[RolloutMetrics]:
99
+ """Gathers new episodes metrics tuples from the given RolloutWorkers.
100
+
101
+ Args:
102
+ workers: EnvRunnerGroup.
103
+ remote_worker_ids: Optional list of IDs of remote workers to collect
104
+ metrics from.
105
+ timeout_seconds: Timeout in seconds for collecting metrics from remote workers.
106
+
107
+ Returns:
108
+ List of RolloutMetrics.
109
+ """
110
+ # This will drop get_metrics() calls that are too slow.
111
+ # We can potentially make this an asynchronous call if this turns
112
+ # out to be a problem.
113
+ metric_lists = workers.foreach_env_runner(
114
+ lambda w: w.get_metrics(),
115
+ local_env_runner=True,
116
+ remote_worker_ids=remote_worker_ids,
117
+ timeout_seconds=timeout_seconds,
118
+ )
119
+ if len(metric_lists) == 0:
120
+ logger.warning("WARNING: collected no metrics.")
121
+
122
+ episodes = []
123
+ for metrics in metric_lists:
124
+ episodes.extend(metrics)
125
+
126
+ return episodes
127
+
128
+
129
+ @OldAPIStack
130
+ def summarize_episodes(
131
+ episodes: List[RolloutMetrics],
132
+ new_episodes: List[RolloutMetrics] = None,
133
+ keep_custom_metrics: bool = False,
134
+ ) -> ResultDict:
135
+ """Summarizes a set of episode metrics tuples.
136
+
137
+ Args:
138
+ episodes: List of most recent n episodes. This may include historical ones
139
+ (not newly collected in this iteration) in order to achieve the size of
140
+ the smoothing window.
141
+ new_episodes: All the episodes that were completed in this iteration.
142
+ keep_custom_metrics: Whether to keep custom metrics in the result dict as
143
+ they are (True) or to aggregate them (False).
144
+
145
+ Returns:
146
+ A result dict of metrics.
147
+ """
148
+
149
+ if new_episodes is None:
150
+ new_episodes = episodes
151
+
152
+ episode_rewards = []
153
+ episode_lengths = []
154
+ policy_rewards = collections.defaultdict(list)
155
+ custom_metrics = collections.defaultdict(list)
156
+ perf_stats = collections.defaultdict(list)
157
+ hist_stats = collections.defaultdict(list)
158
+ episode_media = collections.defaultdict(list)
159
+ connector_metrics = collections.defaultdict(list)
160
+ num_faulty_episodes = 0
161
+
162
+ for episode in episodes:
163
+ # Faulty episodes may still carry perf_stats data.
164
+ for k, v in episode.perf_stats.items():
165
+ perf_stats[k].append(v)
166
+ # Continue if this is a faulty episode.
167
+ # There should be other meaningful stats to be collected.
168
+ if episode.episode_faulty:
169
+ num_faulty_episodes += 1
170
+ continue
171
+
172
+ episode_lengths.append(episode.episode_length)
173
+ episode_rewards.append(episode.episode_reward)
174
+ for k, v in episode.custom_metrics.items():
175
+ custom_metrics[k].append(v)
176
+ is_multi_agent = (
177
+ len(episode.agent_rewards) > 1
178
+ or DEFAULT_POLICY_ID not in episode.agent_rewards
179
+ )
180
+ if is_multi_agent:
181
+ for (_, policy_id), reward in episode.agent_rewards.items():
182
+ policy_rewards[policy_id].append(reward)
183
+ for k, v in episode.hist_data.items():
184
+ hist_stats[k] += v
185
+ for k, v in episode.media.items():
186
+ episode_media[k].append(v)
187
+ if hasattr(episode, "connector_metrics"):
188
+ # Group connector metrics by connector_metric name for all policies
189
+ for per_pipeline_metrics in episode.connector_metrics.values():
190
+ for per_connector_metrics in per_pipeline_metrics.values():
191
+ for connector_metric_name, val in per_connector_metrics.items():
192
+ connector_metrics[connector_metric_name].append(val)
193
+
194
+ if episode_rewards:
195
+ min_reward = min(episode_rewards)
196
+ max_reward = max(episode_rewards)
197
+ avg_reward = np.mean(episode_rewards)
198
+ else:
199
+ min_reward = float("nan")
200
+ max_reward = float("nan")
201
+ avg_reward = float("nan")
202
+ if episode_lengths:
203
+ avg_length = np.mean(episode_lengths)
204
+ else:
205
+ avg_length = float("nan")
206
+
207
+ # Show as histogram distributions.
208
+ hist_stats["episode_reward"] = episode_rewards
209
+ hist_stats["episode_lengths"] = episode_lengths
210
+
211
+ policy_reward_min = {}
212
+ policy_reward_mean = {}
213
+ policy_reward_max = {}
214
+ for policy_id, rewards in policy_rewards.copy().items():
215
+ policy_reward_min[policy_id] = np.min(rewards)
216
+ policy_reward_mean[policy_id] = np.mean(rewards)
217
+ policy_reward_max[policy_id] = np.max(rewards)
218
+
219
+ # Show as histogram distributions.
220
+ hist_stats["policy_{}_reward".format(policy_id)] = rewards
221
+
222
+ for k, v_list in custom_metrics.copy().items():
223
+ filt = [v for v in v_list if not np.any(np.isnan(v))]
224
+ if keep_custom_metrics:
225
+ custom_metrics[k] = filt
226
+ else:
227
+ custom_metrics[k + "_mean"] = np.mean(filt)
228
+ if filt:
229
+ custom_metrics[k + "_min"] = np.min(filt)
230
+ custom_metrics[k + "_max"] = np.max(filt)
231
+ else:
232
+ custom_metrics[k + "_min"] = float("nan")
233
+ custom_metrics[k + "_max"] = float("nan")
234
+ del custom_metrics[k]
235
+
236
+ for k, v_list in perf_stats.copy().items():
237
+ perf_stats[k] = np.mean(v_list)
238
+
239
+ mean_connector_metrics = dict()
240
+ for k, v_list in connector_metrics.items():
241
+ mean_connector_metrics[k] = np.mean(v_list)
242
+
243
+ return dict(
244
+ episode_reward_max=max_reward,
245
+ episode_reward_min=min_reward,
246
+ episode_reward_mean=avg_reward,
247
+ episode_len_mean=avg_length,
248
+ episode_media=dict(episode_media),
249
+ episodes_timesteps_total=sum(episode_lengths),
250
+ policy_reward_min=policy_reward_min,
251
+ policy_reward_max=policy_reward_max,
252
+ policy_reward_mean=policy_reward_mean,
253
+ custom_metrics=dict(custom_metrics),
254
+ hist_stats=dict(hist_stats),
255
+ sampler_perf=dict(perf_stats),
256
+ num_faulty_episodes=num_faulty_episodes,
257
+ connector_metrics=mean_connector_metrics,
258
+ # Added these (duplicate) values here for forward compatibility with the new API
259
+ # stack's metrics structure. This allows us to unify our test cases and keeping
260
+ # the new API stack clean of backward-compatible keys.
261
+ num_episodes=len(new_episodes),
262
+ episode_return_max=max_reward,
263
+ episode_return_min=min_reward,
264
+ episode_return_mean=avg_reward,
265
+ episodes_this_iter=len(new_episodes), # deprecate in favor of `num_epsodes_...`
266
+ )
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/observation_function.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ from ray.rllib.env import BaseEnv
4
+ from ray.rllib.policy import Policy
5
+ from ray.rllib.evaluation import RolloutWorker
6
+ from ray.rllib.utils.annotations import OldAPIStack
7
+ from ray.rllib.utils.framework import TensorType
8
+ from ray.rllib.utils.typing import AgentID, PolicyID
9
+
10
+
11
+ @OldAPIStack
12
+ class ObservationFunction:
13
+ """Interceptor function for rewriting observations from the environment.
14
+
15
+ These callbacks can be used for preprocessing of observations, especially
16
+ in multi-agent scenarios.
17
+
18
+ Observation functions can be specified in the multi-agent config by
19
+ specifying ``{"observation_fn": your_obs_func}``. Note that
20
+ ``your_obs_func`` can be a plain Python function.
21
+
22
+ This API is **experimental**.
23
+ """
24
+
25
+ def __call__(
26
+ self,
27
+ agent_obs: Dict[AgentID, TensorType],
28
+ worker: RolloutWorker,
29
+ base_env: BaseEnv,
30
+ policies: Dict[PolicyID, Policy],
31
+ episode,
32
+ **kw
33
+ ) -> Dict[AgentID, TensorType]:
34
+ """Callback run on each environment step to observe the environment.
35
+
36
+ This method takes in the original agent observation dict returned by
37
+ a MultiAgentEnv, and returns a possibly modified one. It can be
38
+ thought of as a "wrapper" around the environment.
39
+
40
+ TODO(ekl): allow end-to-end differentiation through the observation
41
+ function and policy losses.
42
+
43
+ TODO(ekl): enable batch processing.
44
+
45
+ Args:
46
+ agent_obs: Dictionary of default observations from the
47
+ environment. The default implementation of observe() simply
48
+ returns this dict.
49
+ worker: Reference to the current rollout worker.
50
+ base_env: BaseEnv running the episode. The underlying
51
+ sub environment objects (BaseEnvs are vectorized) can be
52
+ retrieved by calling `base_env.get_sub_environments()`.
53
+ policies: Mapping of policy id to policy objects. In single
54
+ agent mode there will only be a single "default" policy.
55
+ episode: Episode state object.
56
+ kwargs: Forward compatibility placeholder.
57
+
58
+ Returns:
59
+ new_agent_obs: copy of agent obs with updates. You can
60
+ rewrite or drop data from the dict if needed (e.g., the env
61
+ can have a dummy "global" observation, and the observer can
62
+ merge the global state into individual observations.
63
+
64
+ .. testcode::
65
+ :skipif: True
66
+
67
+ # Observer that merges global state into individual obs. It is
68
+ # rewriting the discrete obs into a tuple with global state.
69
+ example_obs_fn1({"a": 1, "b": 2, "global_state": 101}, ...)
70
+
71
+ .. testoutput::
72
+
73
+ {"a": [1, 101], "b": [2, 101]}
74
+
75
+ .. testcode::
76
+ :skipif: True
77
+
78
+ # Observer for e.g., custom centralized critic model. It is
79
+ # rewriting the discrete obs into a dict with more data.
80
+ example_obs_fn2({"a": 1, "b": 2}, ...)
81
+
82
+ .. testoutput::
83
+
84
+ {"a": {"self": 1, "other": 2}, "b": {"self": 2, "other": 1}}
85
+ """
86
+
87
+ return agent_obs
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/postprocessing.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy.signal
3
+ from typing import Dict, Optional
4
+
5
+ from ray.rllib.policy.policy import Policy
6
+ from ray.rllib.policy.sample_batch import SampleBatch
7
+ from ray.rllib.utils.annotations import DeveloperAPI, OldAPIStack
8
+ from ray.rllib.utils.numpy import convert_to_numpy
9
+ from ray.rllib.utils.typing import AgentID
10
+ from ray.rllib.utils.typing import TensorType
11
+
12
+
13
+ @DeveloperAPI
14
+ class Postprocessing:
15
+ """Constant definitions for postprocessing."""
16
+
17
+ ADVANTAGES = "advantages"
18
+ VALUE_TARGETS = "value_targets"
19
+
20
+
21
+ @OldAPIStack
22
+ def adjust_nstep(n_step: int, gamma: float, batch: SampleBatch) -> None:
23
+ """Rewrites `batch` to encode n-step rewards, terminateds, truncateds, and next-obs.
24
+
25
+ Observations and actions remain unaffected. At the end of the trajectory,
26
+ n is truncated to fit in the traj length.
27
+
28
+ Args:
29
+ n_step: The number of steps to look ahead and adjust.
30
+ gamma: The discount factor.
31
+ batch: The SampleBatch to adjust (in place).
32
+
33
+ Examples:
34
+ n-step=3
35
+ Trajectory=o0 r0 d0, o1 r1 d1, o2 r2 d2, o3 r3 d3, o4 r4 d4=True o5
36
+ gamma=0.9
37
+ Returned trajectory:
38
+ 0: o0 [r0 + 0.9*r1 + 0.9^2*r2 + 0.9^3*r3] d3 o0'=o3
39
+ 1: o1 [r1 + 0.9*r2 + 0.9^2*r3 + 0.9^3*r4] d4 o1'=o4
40
+ 2: o2 [r2 + 0.9*r3 + 0.9^2*r4] d4 o1'=o5
41
+ 3: o3 [r3 + 0.9*r4] d4 o3'=o5
42
+ 4: o4 r4 d4 o4'=o5
43
+ """
44
+
45
+ assert (
46
+ batch.is_single_trajectory()
47
+ ), "Unexpected terminated|truncated in middle of trajectory!"
48
+
49
+ len_ = len(batch)
50
+
51
+ # Shift NEXT_OBS, TERMINATEDS, and TRUNCATEDS.
52
+ batch[SampleBatch.NEXT_OBS] = np.concatenate(
53
+ [
54
+ batch[SampleBatch.OBS][n_step:],
55
+ np.stack([batch[SampleBatch.NEXT_OBS][-1]] * min(n_step, len_)),
56
+ ],
57
+ axis=0,
58
+ )
59
+ batch[SampleBatch.TERMINATEDS] = np.concatenate(
60
+ [
61
+ batch[SampleBatch.TERMINATEDS][n_step - 1 :],
62
+ np.tile(batch[SampleBatch.TERMINATEDS][-1], min(n_step - 1, len_)),
63
+ ],
64
+ axis=0,
65
+ )
66
+ # Only fix `truncateds`, if present in the batch.
67
+ if SampleBatch.TRUNCATEDS in batch:
68
+ batch[SampleBatch.TRUNCATEDS] = np.concatenate(
69
+ [
70
+ batch[SampleBatch.TRUNCATEDS][n_step - 1 :],
71
+ np.tile(batch[SampleBatch.TRUNCATEDS][-1], min(n_step - 1, len_)),
72
+ ],
73
+ axis=0,
74
+ )
75
+
76
+ # Change rewards in place.
77
+ for i in range(len_):
78
+ for j in range(1, n_step):
79
+ if i + j < len_:
80
+ batch[SampleBatch.REWARDS][i] += (
81
+ gamma**j * batch[SampleBatch.REWARDS][i + j]
82
+ )
83
+
84
+
85
+ @OldAPIStack
86
+ def compute_advantages(
87
+ rollout: SampleBatch,
88
+ last_r: float,
89
+ gamma: float = 0.9,
90
+ lambda_: float = 1.0,
91
+ use_gae: bool = True,
92
+ use_critic: bool = True,
93
+ rewards: TensorType = None,
94
+ vf_preds: TensorType = None,
95
+ ):
96
+ """Given a rollout, compute its value targets and the advantages.
97
+
98
+ Args:
99
+ rollout: SampleBatch of a single trajectory.
100
+ last_r: Value estimation for last observation.
101
+ gamma: Discount factor.
102
+ lambda_: Parameter for GAE.
103
+ use_gae: Using Generalized Advantage Estimation.
104
+ use_critic: Whether to use critic (value estimates). Setting
105
+ this to False will use 0 as baseline.
106
+ rewards: Override the reward values in rollout.
107
+ vf_preds: Override the value function predictions in rollout.
108
+
109
+ Returns:
110
+ SampleBatch with experience from rollout and processed rewards.
111
+ """
112
+ assert (
113
+ SampleBatch.VF_PREDS in rollout or not use_critic
114
+ ), "use_critic=True but values not found"
115
+ assert use_critic or not use_gae, "Can't use gae without using a value function"
116
+ last_r = convert_to_numpy(last_r)
117
+
118
+ if rewards is None:
119
+ rewards = rollout[SampleBatch.REWARDS]
120
+ if vf_preds is None and use_critic:
121
+ vf_preds = rollout[SampleBatch.VF_PREDS]
122
+
123
+ if use_gae:
124
+ vpred_t = np.concatenate([vf_preds, np.array([last_r])])
125
+ delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
126
+ # This formula for the advantage comes from:
127
+ # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
128
+ rollout[Postprocessing.ADVANTAGES] = discount_cumsum(delta_t, gamma * lambda_)
129
+ rollout[Postprocessing.VALUE_TARGETS] = (
130
+ rollout[Postprocessing.ADVANTAGES] + vf_preds
131
+ ).astype(np.float32)
132
+ else:
133
+ rewards_plus_v = np.concatenate([rewards, np.array([last_r])])
134
+ discounted_returns = discount_cumsum(rewards_plus_v, gamma)[:-1].astype(
135
+ np.float32
136
+ )
137
+
138
+ if use_critic:
139
+ rollout[Postprocessing.ADVANTAGES] = discounted_returns - vf_preds
140
+ rollout[Postprocessing.VALUE_TARGETS] = discounted_returns
141
+ else:
142
+ rollout[Postprocessing.ADVANTAGES] = discounted_returns
143
+ rollout[Postprocessing.VALUE_TARGETS] = np.zeros_like(
144
+ rollout[Postprocessing.ADVANTAGES]
145
+ )
146
+
147
+ rollout[Postprocessing.ADVANTAGES] = rollout[Postprocessing.ADVANTAGES].astype(
148
+ np.float32
149
+ )
150
+
151
+ return rollout
152
+
153
+
154
+ @OldAPIStack
155
+ def compute_gae_for_sample_batch(
156
+ policy: Policy,
157
+ sample_batch: SampleBatch,
158
+ other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
159
+ episode=None,
160
+ ) -> SampleBatch:
161
+ """Adds GAE (generalized advantage estimations) to a trajectory.
162
+
163
+ The trajectory contains only data from one episode and from one agent.
164
+ - If `config.batch_mode=truncate_episodes` (default), sample_batch may
165
+ contain a truncated (at-the-end) episode, in case the
166
+ `config.rollout_fragment_length` was reached by the sampler.
167
+ - If `config.batch_mode=complete_episodes`, sample_batch will contain
168
+ exactly one episode (no matter how long).
169
+ New columns can be added to sample_batch and existing ones may be altered.
170
+
171
+ Args:
172
+ policy: The Policy used to generate the trajectory (`sample_batch`)
173
+ sample_batch: The SampleBatch to postprocess.
174
+ other_agent_batches: Optional dict of AgentIDs mapping to other
175
+ agents' trajectory data (from the same episode).
176
+ NOTE: The other agents use the same policy.
177
+ episode: Optional multi-agent episode object in which the agents
178
+ operated.
179
+
180
+ Returns:
181
+ The postprocessed, modified SampleBatch (or a new one).
182
+ """
183
+ # Compute the SampleBatch.VALUES_BOOTSTRAPPED column, which we'll need for the
184
+ # following `last_r` arg in `compute_advantages()`.
185
+ sample_batch = compute_bootstrap_value(sample_batch, policy)
186
+
187
+ vf_preds = np.array(sample_batch[SampleBatch.VF_PREDS])
188
+ rewards = np.array(sample_batch[SampleBatch.REWARDS])
189
+ # We need to squeeze out the time dimension if there is one
190
+ # Sanity check that both have the same shape
191
+ if len(vf_preds.shape) == 2:
192
+ assert vf_preds.shape == rewards.shape
193
+ vf_preds = np.squeeze(vf_preds, axis=1)
194
+ rewards = np.squeeze(rewards, axis=1)
195
+ squeezed = True
196
+ else:
197
+ squeezed = False
198
+
199
+ # Adds the policy logits, VF preds, and advantages to the batch,
200
+ # using GAE ("generalized advantage estimation") or not.
201
+ batch = compute_advantages(
202
+ rollout=sample_batch,
203
+ last_r=sample_batch[SampleBatch.VALUES_BOOTSTRAPPED][-1],
204
+ gamma=policy.config["gamma"],
205
+ lambda_=policy.config["lambda"],
206
+ use_gae=policy.config["use_gae"],
207
+ use_critic=policy.config.get("use_critic", True),
208
+ vf_preds=vf_preds,
209
+ rewards=rewards,
210
+ )
211
+
212
+ if squeezed:
213
+ # If we needed to squeeze rewards and vf_preds, we need to unsqueeze
214
+ # advantages again for it to have the same shape
215
+ batch[Postprocessing.ADVANTAGES] = np.expand_dims(
216
+ batch[Postprocessing.ADVANTAGES], axis=1
217
+ )
218
+
219
+ return batch
220
+
221
+
222
+ @OldAPIStack
223
+ def compute_bootstrap_value(sample_batch: SampleBatch, policy: Policy) -> SampleBatch:
224
+ """Performs a value function computation at the end of a trajectory.
225
+
226
+ If the trajectory is terminated (not truncated), will not use the value function,
227
+ but assume that the value of the last timestep is 0.0.
228
+ In all other cases, will use the given policy's value function to compute the
229
+ "bootstrapped" value estimate at the end of the given trajectory. To do so, the
230
+ very last observation (sample_batch[NEXT_OBS][-1]) and - if applicable -
231
+ the very last state output (sample_batch[STATE_OUT][-1]) wil be used as inputs to
232
+ the value function.
233
+
234
+ The thus computed value estimate will be stored in a new column of the
235
+ `sample_batch`: SampleBatch.VALUES_BOOTSTRAPPED. Thereby, values at all timesteps
236
+ in this column are set to 0.0, except or the last timestep, which receives the
237
+ computed bootstrapped value.
238
+ This is done, such that in any loss function (which processes raw, intact
239
+ trajectories, such as those of IMPALA and APPO) can use this new column as follows:
240
+
241
+ Example: numbers=ts in episode, '|'=episode boundary (terminal),
242
+ X=bootstrapped value (!= 0.0 b/c ts=12 is not a terminal).
243
+ ts=5 is NOT a terminal.
244
+ T: 8 9 10 11 12 <- no terminal
245
+ VF_PREDS: . . . . .
246
+ VALUES_BOOTSTRAPPED: 0 0 0 0 X
247
+
248
+ Args:
249
+ sample_batch: The SampleBatch (single trajectory) for which to compute the
250
+ bootstrap value at the end. This SampleBatch will be altered in place
251
+ (by adding a new column: SampleBatch.VALUES_BOOTSTRAPPED).
252
+ policy: The Policy object, whose value function to use.
253
+
254
+ Returns:
255
+ The altered SampleBatch (with the extra SampleBatch.VALUES_BOOTSTRAPPED
256
+ column).
257
+ """
258
+ # Trajectory is actually complete -> last r=0.0.
259
+ if sample_batch[SampleBatch.TERMINATEDS][-1]:
260
+ last_r = 0.0
261
+ # Trajectory has been truncated -> last r=VF estimate of last obs.
262
+ else:
263
+ # Input dict is provided to us automatically via the Model's
264
+ # requirements. It's a single-timestep (last one in trajectory)
265
+ # input_dict.
266
+ # Create an input dict according to the Policy's requirements.
267
+ input_dict = sample_batch.get_single_step_input_dict(
268
+ policy.view_requirements, index="last"
269
+ )
270
+ last_r = policy._value(**input_dict)
271
+
272
+ vf_preds = np.array(sample_batch[SampleBatch.VF_PREDS])
273
+ # We need to squeeze out the time dimension if there is one
274
+ if len(vf_preds.shape) == 2:
275
+ vf_preds = np.squeeze(vf_preds, axis=1)
276
+ squeezed = True
277
+ else:
278
+ squeezed = False
279
+
280
+ # Set the SampleBatch.VALUES_BOOTSTRAPPED field to VF_PREDS[1:] + the
281
+ # very last timestep (where this bootstrapping value is actually needed), which
282
+ # we set to the computed `last_r`.
283
+ sample_batch[SampleBatch.VALUES_BOOTSTRAPPED] = np.concatenate(
284
+ [
285
+ convert_to_numpy(vf_preds[1:]),
286
+ np.array([convert_to_numpy(last_r)], dtype=np.float32),
287
+ ],
288
+ axis=0,
289
+ )
290
+
291
+ if squeezed:
292
+ sample_batch[SampleBatch.VF_PREDS] = np.expand_dims(vf_preds, axis=1)
293
+ sample_batch[SampleBatch.VALUES_BOOTSTRAPPED] = np.expand_dims(
294
+ sample_batch[SampleBatch.VALUES_BOOTSTRAPPED], axis=1
295
+ )
296
+
297
+ return sample_batch
298
+
299
+
300
+ @OldAPIStack
301
+ def discount_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
302
+ """Calculates the discounted cumulative sum over a reward sequence `x`.
303
+
304
+ y[t] - discount*y[t+1] = x[t]
305
+ reversed(y)[t] - discount*reversed(y)[t-1] = reversed(x)[t]
306
+
307
+ Args:
308
+ gamma: The discount factor gamma.
309
+
310
+ Returns:
311
+ The sequence containing the discounted cumulative sums
312
+ for each individual reward in `x` till the end of the trajectory.
313
+
314
+ .. testcode::
315
+ :skipif: True
316
+
317
+ x = np.array([0.0, 1.0, 2.0, 3.0])
318
+ gamma = 0.9
319
+ discount_cumsum(x, gamma)
320
+
321
+ .. testoutput::
322
+
323
+ array([0.0 + 0.9*1.0 + 0.9^2*2.0 + 0.9^3*3.0,
324
+ 1.0 + 0.9*2.0 + 0.9^2*3.0,
325
+ 2.0 + 0.9*3.0,
326
+ 3.0])
327
+ """
328
+ return scipy.signal.lfilter([1], [1, float(-gamma)], x[::-1], axis=0)[::-1]
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py ADDED
@@ -0,0 +1,2004 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import importlib.util
3
+ import logging
4
+ import os
5
+ import platform
6
+ import threading
7
+ from collections import defaultdict
8
+ from types import FunctionType
9
+ from typing import (
10
+ TYPE_CHECKING,
11
+ Any,
12
+ Callable,
13
+ Collection,
14
+ Dict,
15
+ List,
16
+ Optional,
17
+ Set,
18
+ Tuple,
19
+ Type,
20
+ Union,
21
+ )
22
+
23
+ from gymnasium.spaces import Space
24
+
25
+ import ray
26
+ from ray import ObjectRef
27
+ from ray import cloudpickle as pickle
28
+ from ray.rllib.connectors.util import (
29
+ create_connectors_for_policy,
30
+ maybe_get_filters_for_syncing,
31
+ )
32
+ from ray.rllib.core.rl_module import validate_module_id
33
+ from ray.rllib.core.rl_module.rl_module import RLModuleSpec
34
+ from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
35
+ from ray.rllib.env.env_context import EnvContext
36
+ from ray.rllib.env.env_runner import EnvRunner
37
+ from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
38
+ from ray.rllib.env.multi_agent_env import MultiAgentEnv
39
+ from ray.rllib.env.wrappers.atari_wrappers import is_atari, wrap_deepmind
40
+ from ray.rllib.evaluation.metrics import RolloutMetrics
41
+ from ray.rllib.evaluation.sampler import SyncSampler
42
+ from ray.rllib.models import ModelCatalog
43
+ from ray.rllib.models.preprocessors import Preprocessor
44
+ from ray.rllib.offline import (
45
+ D4RLReader,
46
+ DatasetReader,
47
+ DatasetWriter,
48
+ InputReader,
49
+ IOContext,
50
+ JsonReader,
51
+ JsonWriter,
52
+ MixedInput,
53
+ NoopOutput,
54
+ OutputWriter,
55
+ ShuffledInput,
56
+ )
57
+ from ray.rllib.policy.policy import Policy, PolicySpec
58
+ from ray.rllib.policy.policy_map import PolicyMap
59
+ from ray.rllib.policy.sample_batch import (
60
+ DEFAULT_POLICY_ID,
61
+ MultiAgentBatch,
62
+ concat_samples,
63
+ convert_ma_batch_to_sample_batch,
64
+ )
65
+ from ray.rllib.policy.torch_policy import TorchPolicy
66
+ from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
67
+ from ray.rllib.utils import force_list
68
+ from ray.rllib.utils.annotations import OldAPIStack, override
69
+ from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
70
+ from ray.rllib.utils.error import ERR_MSG_NO_GPUS, HOWTO_CHANGE_CONFIG
71
+ from ray.rllib.utils.filter import Filter, NoFilter
72
+ from ray.rllib.utils.framework import try_import_tf, try_import_torch
73
+ from ray.rllib.utils.from_config import from_config
74
+ from ray.rllib.utils.policy import create_policy_for_framework
75
+ from ray.rllib.utils.sgd import do_minibatch_sgd
76
+ from ray.rllib.utils.tf_run_builder import _TFRunBuilder
77
+ from ray.rllib.utils.tf_utils import get_gpu_devices as get_tf_gpu_devices
78
+ from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
79
+ from ray.rllib.utils.typing import (
80
+ AgentID,
81
+ EnvCreator,
82
+ EnvType,
83
+ ModelGradients,
84
+ ModelWeights,
85
+ MultiAgentPolicyConfigDict,
86
+ PartialAlgorithmConfigDict,
87
+ PolicyID,
88
+ PolicyState,
89
+ SampleBatchType,
90
+ T,
91
+ )
92
+ from ray.tune.registry import registry_contains_input, registry_get_input
93
+ from ray.util.annotations import PublicAPI
94
+ from ray.util.debug import disable_log_once_globally, enable_periodic_logging, log_once
95
+ from ray.util.iter import ParallelIteratorWorker
96
+
97
+ if TYPE_CHECKING:
98
+ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
99
+ from ray.rllib.callbacks.callbacks import RLlibCallback
100
+
101
+ tf1, tf, tfv = try_import_tf()
102
+ torch, _ = try_import_torch()
103
+
104
+ logger = logging.getLogger(__name__)
105
+
106
+ # Handle to the current rollout worker, which will be set to the most recently
107
+ # created RolloutWorker in this process. This can be helpful to access in
108
+ # custom env or policy classes for debugging or advanced use cases.
109
+ _global_worker: Optional["RolloutWorker"] = None
110
+
111
+
112
@OldAPIStack
def get_global_worker() -> "RolloutWorker":
    """Returns a handle to the active rollout worker in this process.

    The handle points at the most recently constructed RolloutWorker in this
    process (set by `RolloutWorker.__init__`), or None if none was created yet.
    """
    # Plain read of the module-level handle; a `global` declaration is only
    # required for writes, so none is needed here.
    return _global_worker
118
+
119
+
120
def _update_env_seed_if_necessary(
    env: EnvType, seed: int, worker_idx: int, vector_idx: int
):
    """Set a deterministic random seed on environment.

    The effective seed combines the base ``seed`` with the worker and vector
    indices so every sub-environment of every worker is seeded distinctly and
    reproducibly.

    NOTE: this may not work with remote environments (issue #18154).
    """
    # No seeding requested at all.
    if seed is None:
        return

    # A single RL job is unlikely to have more than 10K
    # rollout workers.
    seeds_per_worker: int = 1000
    assert (
        worker_idx < seeds_per_worker
    ), "Too many envs per worker. Random seeds may collide."
    target_seed: int = worker_idx * seeds_per_worker + vector_idx + seed

    # Gymnasium.env.
    # This will silently fail for most Farama-foundation gymnasium environments.
    # (they do nothing and return None per default)
    if not hasattr(env, "reset"):
        if log_once("env_has_no_reset_method"):
            logger.info(f"Env {env} doesn't have a `reset()` method. Cannot seed.")
        return

    try:
        env.reset(seed=target_seed)
    except Exception:
        logger.info(
            f"Env {env} doesn't support setting a seed via its `reset()` "
            "method! Implement this method as `reset(self, *, seed=None, "
            "options=None)` for it to abide to the correct API. Cannot seed."
        )
153
+
154
+
155
+ @OldAPIStack
156
+ class RolloutWorker(ParallelIteratorWorker, EnvRunner):
157
+ """Common experience collection class.
158
+
159
+ This class wraps a policy instance and an environment class to
160
+ collect experiences from the environment. You can create many replicas of
161
+ this class as Ray actors to scale RL training.
162
+
163
+ This class supports vectorized and multi-agent policy evaluation (e.g.,
164
+ VectorEnv, MultiAgentEnv, etc.)
165
+
166
+ .. testcode::
167
+ :skipif: True
168
+
169
+ # Create a rollout worker and using it to collect experiences.
170
+ import gymnasium as gym
171
+ from ray.rllib.evaluation.rollout_worker import RolloutWorker
172
+ from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
173
+ worker = RolloutWorker(
174
+ env_creator=lambda _: gym.make("CartPole-v1"),
175
+ default_policy_class=PPOTF1Policy)
176
+ print(worker.sample())
177
+
178
+ # Creating a multi-agent rollout worker
179
+ from gymnasium.spaces import Discrete, Box
180
+ import random
181
+ MultiAgentTrafficGrid = ...
182
+ worker = RolloutWorker(
183
+ env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
184
+ config=AlgorithmConfig().multi_agent(
185
+ policies={
186
+ # Use an ensemble of two policies for car agents
187
+ "car_policy1":
188
+ (PGTFPolicy, Box(...), Discrete(...),
189
+ AlgorithmConfig.overrides(gamma=0.99)),
190
+ "car_policy2":
191
+ (PGTFPolicy, Box(...), Discrete(...),
192
+ AlgorithmConfig.overrides(gamma=0.95)),
193
+ # Use a single shared policy for all traffic lights
194
+ "traffic_light_policy":
195
+ (PGTFPolicy, Box(...), Discrete(...), {}),
196
+ },
197
+ policy_mapping_fn=(
198
+ lambda agent_id, episode, **kwargs:
199
+ random.choice(["car_policy1", "car_policy2"])
200
+ if agent_id.startswith("car_") else "traffic_light_policy"),
201
+ ),
202
+ )
203
+ print(worker.sample())
204
+
205
+ .. testoutput::
206
+
207
+ SampleBatch({
208
+ "obs": [[...]], "actions": [[...]], "rewards": [[...]],
209
+ "terminateds": [[...]], "truncateds": [[...]], "new_obs": [[...]]}
210
+ )
211
+
212
+ MultiAgentBatch({
213
+ "car_policy1": SampleBatch(...),
214
+ "car_policy2": SampleBatch(...),
215
+ "traffic_light_policy": SampleBatch(...)}
216
+ )
217
+
218
+ """
219
+
220
    def __init__(
        self,
        *,
        env_creator: EnvCreator,
        validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None,
        config: Optional["AlgorithmConfig"] = None,
        worker_index: int = 0,
        num_workers: Optional[int] = None,
        recreated_worker: bool = False,
        log_dir: Optional[str] = None,
        spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None,
        default_policy_class: Optional[Type[Policy]] = None,
        dataset_shards: Optional[List[ray.data.Dataset]] = None,
        **kwargs,
    ):
        """Initializes a RolloutWorker instance.

        Args:
            env_creator: Function that returns a gym.Env given an EnvContext
                wrapped configuration.
            validate_env: Optional callable to validate the generated
                environment (only on worker=0).
            config: The AlgorithmConfig (or plain config dict) to use. If None
                or a dict, a fresh AlgorithmConfig is built (and updated from
                the dict). The config gets frozen here.
            worker_index: For remote workers, this should be set to a
                non-zero and unique value. This index is passed to created envs
                through EnvContext so that envs can be configured per worker.
            num_workers: Total number of rollout workers. Defaults to
                `config.num_env_runners` if not given.
            recreated_worker: Whether this worker is a recreated one. Workers are
                recreated by an Algorithm (via EnvRunnerGroup) in case
                `restart_failed_env_runners=True` and one of the original workers (or
                an already recreated one) has failed. They don't differ from original
                workers other than the value of this flag (`self.recreated_worker`).
            log_dir: Directory where logs can be placed.
            spaces: An optional space dict mapping policy IDs
                to (obs_space, action_space)-tuples. This is used in case no
                Env is created on this RolloutWorker.
            default_policy_class: Optional Policy class to use for policies
                whose spec does not define one.
            dataset_shards: Optional list of distributed dataset shards; this
                worker later picks its shard by `worker_index`.
            **kwargs: Ignored; accepted for forward-compatibility.
        """
        # Remember the ctor args so the worker can later be re-created 1:1.
        self._original_kwargs: dict = locals().copy()
        del self._original_kwargs["self"]

        # Publish this instance as the process-wide "current" worker (see
        # `get_global_worker()`).
        global _global_worker
        _global_worker = self

        from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

        # Default config needed?
        if config is None or isinstance(config, dict):
            config = AlgorithmConfig().update_from_dict(config or {})
        # Freeze config, so no one else can alter it from here on.
        config.freeze()

        # Set extra python env variables before calling super constructor.
        # Driver (index 0) and remote workers (index > 0) get separate sets.
        if config.extra_python_environs_for_driver and worker_index == 0:
            for key, value in config.extra_python_environs_for_driver.items():
                os.environ[key] = str(value)
        elif config.extra_python_environs_for_worker and worker_index > 0:
            for key, value in config.extra_python_environs_for_worker.items():
                os.environ[key] = str(value)

        # Infinite generator of sample batches, used by ParallelIteratorWorker.
        def gen_rollouts():
            while True:
                yield self.sample()

        ParallelIteratorWorker.__init__(self, gen_rollouts, False)
        EnvRunner.__init__(self, config=config)

        self.num_workers = (
            num_workers if num_workers is not None else self.config.num_env_runners
        )
        # In case we are reading from distributed datasets, store the shards here
        # and pick our shard by our worker-index.
        self._ds_shards = dataset_shards
        self.worker_index: int = worker_index

        # Lock to be able to lock this entire worker
        # (via `self.lock()` and `self.unlock()`).
        # This might be crucial to prevent a race condition in case
        # `config.policy_states_are_swappable=True` and you are using an Algorithm
        # with a learner thread. In this case, the thread might update a policy
        # that is being swapped (during the update) by the Algorithm's
        # training_step's `RolloutWorker.get_weights()` call (to sync back the
        # new weights to all remote workers).
        self._lock = threading.Lock()

        if (
            tf1
            and (config.framework_str == "tf2" or config.enable_tf1_exec_eagerly)
            # This eager check is necessary for certain all-framework tests
            # that use tf's eager_mode() context generator.
            and not tf1.executing_eagerly()
        ):
            tf1.enable_eager_execution()

        if self.config.log_level:
            logging.getLogger("ray.rllib").setLevel(self.config.log_level)

        if self.worker_index > 1:
            disable_log_once_globally()  # only need 1 worker to log
        elif self.config.log_level == "DEBUG":
            enable_periodic_logging()

        # Env context passed into `env_creator` so envs can self-configure
        # based on worker/vector index.
        env_context = EnvContext(
            self.config.env_config,
            worker_index=self.worker_index,
            vector_index=0,
            num_workers=self.num_workers,
            remote=self.config.remote_worker_envs,
            recreated_worker=recreated_worker,
        )
        self.env_context = env_context
        self.config: AlgorithmConfig = config
        self.callbacks: RLlibCallback = self.config.callbacks_class()
        self.recreated_worker: bool = recreated_worker

        # Setup current policy_mapping_fn. Start with the one from the config, which
        # might be None in older checkpoints (nowadays AlgorithmConfig has a proper
        # default for this); Need to cover this situation via the backup lambda here.
        self.policy_mapping_fn = (
            lambda agent_id, episode, worker, **kw: DEFAULT_POLICY_ID
        )
        self.set_policy_mapping_fn(self.config.policy_mapping_fn)

        self.env_creator: EnvCreator = env_creator
        # Resolve possible auto-fragment length.
        configured_rollout_fragment_length = self.config.get_rollout_fragment_length(
            worker_index=self.worker_index
        )
        # Per-`sample()` target size across all sub-envs of this worker.
        self.total_rollout_fragment_length: int = (
            configured_rollout_fragment_length * self.config.num_envs_per_env_runner
        )
        self.preprocessing_enabled: bool = not config._disable_preprocessor_api
        self.last_batch: Optional[SampleBatchType] = None
        self.global_vars: dict = {
            # TODO(sven): Make this per-policy!
            "timestep": 0,
            # Counter for performed gradient updates per policy in `self.policy_map`.
            # Allows for compiling metrics on the off-policy'ness of an update given
            # that the number of gradient updates of the sampling policies are known
            # to the learner (and can be compared to the learner version of the same
            # policy).
            "num_grad_updates_per_policy": defaultdict(int),
        }

        # If seed is provided, add worker index to it and 10k iff evaluation worker.
        self.seed = (
            None
            if self.config.seed is None
            else self.config.seed
            + self.worker_index
            + self.config.in_evaluation * 10000
        )

        # Update the global seed for numpy/random/tf-eager/torch if we are not
        # the local worker, otherwise, this was already done in the Algorithm
        # object itself.
        if self.worker_index > 0:
            update_global_seed_if_necessary(self.config.framework_str, self.seed)

        # A single environment provided by the user (via config.env). This may
        # also remain None.
        # 1) Create the env using the user provided env_creator. This may
        #    return a gym.Env (incl. MultiAgentEnv), an already vectorized
        #    VectorEnv, BaseEnv, ExternalEnv, or an ActorHandle (remote env).
        # 2) Wrap - if applicable - with Atari/rendering wrappers.
        # 3) Seed the env, if necessary.
        # 4) Vectorize the existing single env by creating more clones of
        #    this env and wrapping it with the RLlib BaseEnv class.
        self.env = self.make_sub_env_fn = None

        # Create a (single) env for this worker.
        if not (
            self.worker_index == 0
            and self.num_workers > 0
            and not self.config.create_env_on_local_worker
        ):
            # Run the `env_creator` function passing the EnvContext.
            self.env = env_creator(copy.deepcopy(self.env_context))

        clip_rewards = self.config.clip_rewards

        if self.env is not None:
            # Custom validation function given, typically a function attribute of the
            # Algorithm.
            if validate_env is not None:
                validate_env(self.env, self.env_context)

            # We can't auto-wrap a BaseEnv.
            if isinstance(self.env, (BaseEnv, ray.actor.ActorHandle)):

                def wrap(env):
                    return env

            # Atari type env and "deepmind" preprocessor pref.
            elif is_atari(self.env) and self.config.preprocessor_pref == "deepmind":
                # Deepmind wrappers already handle all preprocessing.
                self.preprocessing_enabled = False

                # If clip_rewards not explicitly set to False, switch it
                # on here (clip between -1.0 and 1.0).
                if self.config.clip_rewards is None:
                    clip_rewards = True

                # Framestacking is used.
                use_framestack = self.config.model.get("framestack") is True

                def wrap(env):
                    env = wrap_deepmind(
                        env,
                        dim=self.config.model.get("dim"),
                        framestack=use_framestack,
                        noframeskip=self.config.env_config.get("frameskip", 0) == 1,
                    )
                    return env

            elif self.config.preprocessor_pref is None:
                # Only turn off preprocessing
                self.preprocessing_enabled = False

                def wrap(env):
                    return env

            else:
                # Default: no extra wrapping needed.

                def wrap(env):
                    return env

            # Wrap env through the correct wrapper.
            self.env: EnvType = wrap(self.env)
            # Ideally, we would use the same make_sub_env() function below
            # to create self.env, but wrap(env) and self.env has a cyclic
            # dependency on each other right now, so we would settle on
            # duplicating the random seed setting logic for now.
            _update_env_seed_if_necessary(self.env, self.seed, self.worker_index, 0)
            # Call custom callback function `on_sub_environment_created`.
            self.callbacks.on_sub_environment_created(
                worker=self,
                sub_environment=self.env,
                env_context=self.env_context,
            )

            # Factory for further sub-env clones (used for vectorization below).
            self.make_sub_env_fn = self._get_make_sub_env_fn(
                env_creator, env_context, validate_env, wrap, self.seed
            )

        self.spaces = spaces
        self.default_policy_class = default_policy_class
        self.policy_dict, self.is_policy_to_train = self.config.get_multi_agent_setup(
            env=self.env,
            spaces=self.spaces,
            default_policy_class=self.default_policy_class,
        )

        self.policy_map: Optional[PolicyMap] = None
        # TODO(jungong) : clean up after non-connector env_runner is fully deprecated.
        self.preprocessors: Dict[PolicyID, Preprocessor] = None

        # Check available number of GPUs.
        num_gpus = (
            self.config.num_gpus
            if self.worker_index == 0
            else self.config.num_gpus_per_env_runner
        )

        # Error if we don't find enough GPUs.
        if (
            ray.is_initialized()
            and ray._private.worker._mode() != ray._private.worker.LOCAL_MODE
            and not config._fake_gpus
        ):
            devices = []
            if self.config.framework_str in ["tf2", "tf"]:
                devices = get_tf_gpu_devices()
            elif self.config.framework_str == "torch":
                devices = list(range(torch.cuda.device_count()))

            if len(devices) < num_gpus:
                raise RuntimeError(
                    ERR_MSG_NO_GPUS.format(len(devices), devices) + HOWTO_CHANGE_CONFIG
                )
        # Warn, if running in local-mode and actual GPUs (not faked) are
        # requested.
        elif (
            ray.is_initialized()
            and ray._private.worker._mode() == ray._private.worker.LOCAL_MODE
            and num_gpus > 0
            and not self.config._fake_gpus
        ):
            logger.warning(
                "You are running ray with `local_mode=True`, but have "
                f"configured {num_gpus} GPUs to be used! In local mode, "
                f"Policies are placed on the CPU and the `num_gpus` setting "
                f"is ignored."
            )

        # Per-policy observation filters; default is a pass-through NoFilter.
        self.filters: Dict[PolicyID, Filter] = defaultdict(NoFilter)

        # If RLModule API is enabled, multi_rl_module_spec holds the specs of the
        # RLModules.
        self.multi_rl_module_spec = None
        self._update_policy_map(policy_dict=self.policy_dict)

        # Update Policy's view requirements from Model, only if Policy directly
        # inherited from base `Policy` class. At this point here, the Policy
        # must have its Model (if any) defined and ready to output an initial
        # state.
        for pol in self.policy_map.values():
            if not pol._model_init_state_automatically_added:
                pol._update_model_view_requirements_from_init_state()

        if (
            self.config.is_multi_agent
            and self.env is not None
            and not isinstance(
                self.env,
                (BaseEnv, ExternalMultiAgentEnv, MultiAgentEnv, ray.actor.ActorHandle),
            )
        ):
            raise ValueError(
                f"You are running a multi-agent setup, but the env {self.env} is not a "
                f"subclass of BaseEnv, MultiAgentEnv, ActorHandle, or "
                f"ExternalMultiAgentEnv!"
            )

        if self.worker_index == 0:
            logger.info("Built filter map: {}".format(self.filters))

        # This RolloutWorker has no env.
        if self.env is None:
            self.async_env = None
        # Use a custom env-vectorizer and call it providing self.env.
        elif "custom_vector_env" in self.config:
            self.async_env = self.config.custom_vector_env(self.env)
        # Default: Vectorize self.env via the make_sub_env function. This adds
        # further clones of self.env and creates a RLlib BaseEnv (which is
        # vectorized under the hood).
        else:
            # Always use vector env for consistency even if num_envs_per_env_runner=1.
            self.async_env: BaseEnv = convert_to_base_env(
                self.env,
                make_env=self.make_sub_env_fn,
                num_envs=self.config.num_envs_per_env_runner,
                remote_envs=self.config.remote_worker_envs,
                remote_env_batch_wait_ms=self.config.remote_env_batch_wait_ms,
                worker=self,
                restart_failed_sub_environments=(
                    self.config.restart_failed_sub_environments
                ),
            )

        # `truncate_episodes`: Allow a batch to contain more than one episode
        # (fragments) and always make the batch `rollout_fragment_length`
        # long.
        rollout_fragment_length_for_sampler = configured_rollout_fragment_length
        if self.config.batch_mode == "truncate_episodes":
            pack = True
        # `complete_episodes`: Never cut episodes and sampler will return
        # exactly one (complete) episode per poll.
        else:
            assert self.config.batch_mode == "complete_episodes"
            rollout_fragment_length_for_sampler = float("inf")
            pack = False

        # Create the IOContext for this worker.
        self.io_context: IOContext = IOContext(
            log_dir, self.config, self.worker_index, self
        )

        # Only render on the local worker (or worker #1 when remote workers
        # exist), and only if explicitly enabled.
        render = False
        if self.config.render_env is True and (
            self.num_workers == 0 or self.worker_index == 1
        ):
            render = True

        if self.env is None:
            self.sampler = None
        else:
            self.sampler = SyncSampler(
                worker=self,
                env=self.async_env,
                clip_rewards=clip_rewards,
                rollout_fragment_length=rollout_fragment_length_for_sampler,
                count_steps_by=self.config.count_steps_by,
                callbacks=self.callbacks,
                multiple_episodes_in_batch=pack,
                normalize_actions=self.config.normalize_actions,
                clip_actions=self.config.clip_actions,
                observation_fn=self.config.observation_fn,
                sample_collector_class=self.config.sample_collector,
                render=render,
            )

        # Input/output readers/writers (sampler, offline files, datasets, ...)
        # as configured; both are built from the IOContext above.
        self.input_reader: InputReader = self._get_input_creator_from_config()(
            self.io_context
        )
        self.output_writer: OutputWriter = self._get_output_creator_from_config()(
            self.io_context
        )

        # The current weights sequence number (version). May remain None for when
        # not tracking weights versions.
        self.weights_seq_no: Optional[int] = None
619
+
620
+ @override(EnvRunner)
621
+ def make_env(self):
622
+ # Override this method, b/c it's abstract and must be overridden.
623
+ # However, we see no point in implementing it for the old API stack any longer
624
+ # (the RolloutWorker class will be deprecated soon).
625
+ raise NotImplementedError
626
+
627
+ @override(EnvRunner)
628
+ def assert_healthy(self):
629
+ is_healthy = self.policy_map and self.input_reader and self.output_writer
630
+ assert is_healthy, (
631
+ f"RolloutWorker {self} (idx={self.worker_index}; "
632
+ f"num_workers={self.num_workers}) not healthy!"
633
+ )
634
+
635
+ @override(EnvRunner)
636
+ def sample(self, **kwargs) -> SampleBatchType:
637
+ """Returns a batch of experience sampled from this worker.
638
+
639
+ This method must be implemented by subclasses.
640
+
641
+ Returns:
642
+ A columnar batch of experiences (e.g., tensors) or a MultiAgentBatch.
643
+
644
+ .. testcode::
645
+ :skipif: True
646
+
647
+ import gymnasium as gym
648
+ from ray.rllib.evaluation.rollout_worker import RolloutWorker
649
+ from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
650
+ worker = RolloutWorker(
651
+ env_creator=lambda _: gym.make("CartPole-v1"),
652
+ default_policy_class=PPOTF1Policy,
653
+ config=AlgorithmConfig(),
654
+ )
655
+ print(worker.sample())
656
+
657
+ .. testoutput::
658
+
659
+ SampleBatch({"obs": [...], "action": [...], ...})
660
+ """
661
+ if self.config.fake_sampler and self.last_batch is not None:
662
+ return self.last_batch
663
+ elif self.input_reader is None:
664
+ raise ValueError(
665
+ "RolloutWorker has no `input_reader` object! "
666
+ "Cannot call `sample()`. You can try setting "
667
+ "`create_env_on_driver` to True."
668
+ )
669
+
670
+ if log_once("sample_start"):
671
+ logger.info(
672
+ "Generating sample batch of size {}".format(
673
+ self.total_rollout_fragment_length
674
+ )
675
+ )
676
+
677
+ batches = [self.input_reader.next()]
678
+ steps_so_far = (
679
+ batches[0].count
680
+ if self.config.count_steps_by == "env_steps"
681
+ else batches[0].agent_steps()
682
+ )
683
+
684
+ # In truncate_episodes mode, never pull more than 1 batch per env.
685
+ # This avoids over-running the target batch size.
686
+ if (
687
+ self.config.batch_mode == "truncate_episodes"
688
+ and not self.config.offline_sampling
689
+ ):
690
+ max_batches = self.config.num_envs_per_env_runner
691
+ else:
692
+ max_batches = float("inf")
693
+ while steps_so_far < self.total_rollout_fragment_length and (
694
+ len(batches) < max_batches
695
+ ):
696
+ batch = self.input_reader.next()
697
+ steps_so_far += (
698
+ batch.count
699
+ if self.config.count_steps_by == "env_steps"
700
+ else batch.agent_steps()
701
+ )
702
+ batches.append(batch)
703
+
704
+ batch = concat_samples(batches)
705
+
706
+ self.callbacks.on_sample_end(worker=self, samples=batch)
707
+
708
+ # Always do writes prior to compression for consistency and to allow
709
+ # for better compression inside the writer.
710
+ self.output_writer.write(batch)
711
+
712
+ if log_once("sample_end"):
713
+ logger.info("Completed sample batch:\n\n{}\n".format(summarize(batch)))
714
+
715
+ if self.config.compress_observations:
716
+ batch.compress(bulk=self.config.compress_observations == "bulk")
717
+
718
+ if self.config.fake_sampler:
719
+ self.last_batch = batch
720
+
721
+ return batch
722
+
723
+ @override(EnvRunner)
724
+ def get_spaces(self) -> Dict[str, Tuple[Space, Space]]:
725
+ spaces = self.foreach_policy(
726
+ lambda p, pid: (pid, p.observation_space, p.action_space)
727
+ )
728
+ spaces = {e[0]: (getattr(e[1], "original_space", e[1]), e[2]) for e in spaces}
729
+ # Try to add the actual env's obs/action spaces.
730
+ env_spaces = self.foreach_env(
731
+ lambda env: (env.observation_space, env.action_space)
732
+ )
733
+ if env_spaces:
734
+ from ray.rllib.env import INPUT_ENV_SPACES
735
+
736
+ spaces[INPUT_ENV_SPACES] = env_spaces[0]
737
+ return spaces
738
+
739
+ @ray.method(num_returns=2)
740
+ def sample_with_count(self) -> Tuple[SampleBatchType, int]:
741
+ """Same as sample() but returns the count as a separate value.
742
+
743
+ Returns:
744
+ A columnar batch of experiences (e.g., tensors) and the
745
+ size of the collected batch.
746
+
747
+ .. testcode::
748
+ :skipif: True
749
+
750
+ import gymnasium as gym
751
+ from ray.rllib.evaluation.rollout_worker import RolloutWorker
752
+ from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
753
+ worker = RolloutWorker(
754
+ env_creator=lambda _: gym.make("CartPole-v1"),
755
+ default_policy_class=PPOTFPolicy)
756
+ print(worker.sample_with_count())
757
+
758
+ .. testoutput::
759
+
760
+ (SampleBatch({"obs": [...], "action": [...], ...}), 3)
761
+ """
762
+ batch = self.sample()
763
+ return batch, batch.count
764
+
765
+ def learn_on_batch(self, samples: SampleBatchType) -> Dict:
766
+ """Update policies based on the given batch.
767
+
768
+ This is the equivalent to apply_gradients(compute_gradients(samples)),
769
+ but can be optimized to avoid pulling gradients into CPU memory.
770
+
771
+ Args:
772
+ samples: The SampleBatch or MultiAgentBatch to learn on.
773
+
774
+ Returns:
775
+ Dictionary of extra metadata from compute_gradients().
776
+
777
+ .. testcode::
778
+ :skipif: True
779
+
780
+ import gymnasium as gym
781
+ from ray.rllib.evaluation.rollout_worker import RolloutWorker
782
+ from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
783
+ worker = RolloutWorker(
784
+ env_creator=lambda _: gym.make("CartPole-v1"),
785
+ default_policy_class=PPOTF1Policy)
786
+ batch = worker.sample()
787
+ info = worker.learn_on_batch(samples)
788
+ """
789
+ if log_once("learn_on_batch"):
790
+ logger.info(
791
+ "Training on concatenated sample batches:\n\n{}\n".format(
792
+ summarize(samples)
793
+ )
794
+ )
795
+
796
+ info_out = {}
797
+ if isinstance(samples, MultiAgentBatch):
798
+ builders = {}
799
+ to_fetch = {}
800
+ for pid, batch in samples.policy_batches.items():
801
+ if self.is_policy_to_train is not None and not self.is_policy_to_train(
802
+ pid, samples
803
+ ):
804
+ continue
805
+ # Decompress SampleBatch, in case some columns are compressed.
806
+ batch.decompress_if_needed()
807
+
808
+ policy = self.policy_map[pid]
809
+ tf_session = policy.get_session()
810
+ if tf_session and hasattr(policy, "_build_learn_on_batch"):
811
+ builders[pid] = _TFRunBuilder(tf_session, "learn_on_batch")
812
+ to_fetch[pid] = policy._build_learn_on_batch(builders[pid], batch)
813
+ else:
814
+ info_out[pid] = policy.learn_on_batch(batch)
815
+
816
+ info_out.update({pid: builders[pid].get(v) for pid, v in to_fetch.items()})
817
+ else:
818
+ if self.is_policy_to_train is None or self.is_policy_to_train(
819
+ DEFAULT_POLICY_ID, samples
820
+ ):
821
+ info_out.update(
822
+ {
823
+ DEFAULT_POLICY_ID: self.policy_map[
824
+ DEFAULT_POLICY_ID
825
+ ].learn_on_batch(samples)
826
+ }
827
+ )
828
+ if log_once("learn_out"):
829
+ logger.debug("Training out:\n\n{}\n".format(summarize(info_out)))
830
+ return info_out
831
+
832
+ def sample_and_learn(
833
+ self,
834
+ expected_batch_size: int,
835
+ num_sgd_iter: int,
836
+ sgd_minibatch_size: str,
837
+ standardize_fields: List[str],
838
+ ) -> Tuple[dict, int]:
839
+ """Sample and batch and learn on it.
840
+
841
+ This is typically used in combination with distributed allreduce.
842
+
843
+ Args:
844
+ expected_batch_size: Expected number of samples to learn on.
845
+ num_sgd_iter: Number of SGD iterations.
846
+ sgd_minibatch_size: SGD minibatch size.
847
+ standardize_fields: List of sample fields to normalize.
848
+
849
+ Returns:
850
+ A tuple consisting of a dictionary of extra metadata returned from
851
+ the policies' `learn_on_batch()` and the number of samples
852
+ learned on.
853
+ """
854
+ batch = self.sample()
855
+ assert batch.count == expected_batch_size, (
856
+ "Batch size possibly out of sync between workers, expected:",
857
+ expected_batch_size,
858
+ "got:",
859
+ batch.count,
860
+ )
861
+ logger.info(
862
+ "Executing distributed minibatch SGD "
863
+ "with epoch size {}, minibatch size {}".format(
864
+ batch.count, sgd_minibatch_size
865
+ )
866
+ )
867
+ info = do_minibatch_sgd(
868
+ batch,
869
+ self.policy_map,
870
+ self,
871
+ num_sgd_iter,
872
+ sgd_minibatch_size,
873
+ standardize_fields,
874
+ )
875
+ return info, batch.count
876
+
877
+ def compute_gradients(
878
+ self,
879
+ samples: SampleBatchType,
880
+ single_agent: bool = None,
881
+ ) -> Tuple[ModelGradients, dict]:
882
+ """Returns a gradient computed w.r.t the specified samples.
883
+
884
+ Uses the Policy's/ies' compute_gradients method(s) to perform the
885
+ calculations. Skips policies that are not trainable as per
886
+ `self.is_policy_to_train()`.
887
+
888
+ Args:
889
+ samples: The SampleBatch or MultiAgentBatch to compute gradients
890
+ for using this worker's trainable policies.
891
+
892
+ Returns:
893
+ In the single-agent case, a tuple consisting of ModelGradients and
894
+ info dict of the worker's policy.
895
+ In the multi-agent case, a tuple consisting of a dict mapping
896
+ PolicyID to ModelGradients and a dict mapping PolicyID to extra
897
+ metadata info.
898
+ Note that the first return value (grads) can be applied as is to a
899
+ compatible worker using the worker's `apply_gradients()` method.
900
+
901
+ .. testcode::
902
+ :skipif: True
903
+
904
+ import gymnasium as gym
905
+ from ray.rllib.evaluation.rollout_worker import RolloutWorker
906
+ from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
907
+ worker = RolloutWorker(
908
+ env_creator=lambda _: gym.make("CartPole-v1"),
909
+ default_policy_class=PPOTF1Policy)
910
+ batch = worker.sample()
911
+ grads, info = worker.compute_gradients(samples)
912
+ """
913
+ if log_once("compute_gradients"):
914
+ logger.info("Compute gradients on:\n\n{}\n".format(summarize(samples)))
915
+
916
+ if single_agent is True:
917
+ samples = convert_ma_batch_to_sample_batch(samples)
918
+ grad_out, info_out = self.policy_map[DEFAULT_POLICY_ID].compute_gradients(
919
+ samples
920
+ )
921
+ info_out["batch_count"] = samples.count
922
+ return grad_out, info_out
923
+
924
+ # Treat everything as is multi-agent.
925
+ samples = samples.as_multi_agent()
926
+
927
+ # Calculate gradients for all policies.
928
+ grad_out, info_out = {}, {}
929
+ if self.config.framework_str == "tf":
930
+ for pid, batch in samples.policy_batches.items():
931
+ if self.is_policy_to_train is not None and not self.is_policy_to_train(
932
+ pid, samples
933
+ ):
934
+ continue
935
+ policy = self.policy_map[pid]
936
+ builder = _TFRunBuilder(policy.get_session(), "compute_gradients")
937
+ grad_out[pid], info_out[pid] = policy._build_compute_gradients(
938
+ builder, batch
939
+ )
940
+ grad_out = {k: builder.get(v) for k, v in grad_out.items()}
941
+ info_out = {k: builder.get(v) for k, v in info_out.items()}
942
+ else:
943
+ for pid, batch in samples.policy_batches.items():
944
+ if self.is_policy_to_train is not None and not self.is_policy_to_train(
945
+ pid, samples
946
+ ):
947
+ continue
948
+ grad_out[pid], info_out[pid] = self.policy_map[pid].compute_gradients(
949
+ batch
950
+ )
951
+
952
+ info_out["batch_count"] = samples.count
953
+ if log_once("grad_out"):
954
+ logger.info("Compute grad info:\n\n{}\n".format(summarize(info_out)))
955
+
956
+ return grad_out, info_out
957
+
958
def apply_gradients(
    self,
    grads: Union[ModelGradients, Dict[PolicyID, ModelGradients]],
) -> None:
    """Applies the given gradients to this worker's models.

    Delegates to the individual Policies' `apply_gradients()` methods,
    skipping policies for which `self.is_policy_to_train` returns False.

    Args:
        grads: Either a single ModelGradients struct (single-agent case)
            or a dict mapping PolicyIDs to their respective gradients
            (multi-agent case).
    """
    if log_once("apply_gradients"):
        logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))

    def _trainable(pid):
        # No filter configured -> everything is trainable.
        return self.is_policy_to_train is None or self.is_policy_to_train(pid, None)

    # Multi-agent case: `grads` maps PolicyIDs to ModelGradients.
    if isinstance(grads, dict):
        for pid, policy_grads in grads.items():
            if _trainable(pid):
                self.policy_map[pid].apply_gradients(policy_grads)
    # Single-agent case: `grads` is a plain ModelGradients struct.
    elif _trainable(DEFAULT_POLICY_ID):
        self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)
1000
+
1001
@override(EnvRunner)
def get_metrics(self) -> List[RolloutMetrics]:
    """Returns the thus-far collected metrics from this worker's rollouts.

    Returns:
        List of RolloutMetrics collected thus-far (empty if this worker
        has no sampler).
    """
    # Without a sampler there is nothing to report.
    if self.sampler is None:
        return []
    return self.sampler.get_metrics()
1015
+
1016
def foreach_env(self, func: Callable[[EnvType], T]) -> List[T]:
    """Calls the given function with each sub-environment as arg.

    Args:
        func: The function to call for each underlying
            sub-environment (as only arg).

    Returns:
        The list of return values of all calls to `func([env])`.
    """
    if self.async_env is None:
        return []

    sub_envs = self.async_env.get_sub_environments()
    if sub_envs:
        # Vectorized case: apply `func` to every sub-environment.
        return [func(sub_env) for sub_env in sub_envs]
    # `get_sub_environments()` not implemented (empty list) -> apply
    # `func` directly to the BaseEnv.
    return [func(self.async_env)]
1038
+
1039
def foreach_env_with_context(
    self, func: Callable[[EnvType, EnvContext], T]
) -> List[T]:
    """Calls given function with each sub-env plus env_ctx as args.

    Args:
        func: The function to call for each underlying
            sub-environment and its EnvContext (as the args).

    Returns:
        The list of return values of all calls to `func([env, ctx])`.
    """
    if self.async_env is None:
        return []

    sub_envs = self.async_env.get_sub_environments()
    if not sub_envs:
        # `get_sub_environments()` not implemented (empty list) -> apply
        # `func` to the BaseEnv itself, with this worker's own context.
        return [func(self.async_env, self.env_context)]

    # Vectorized case: each sub-env gets a context carrying its own
    # vector index.
    results = []
    for idx, sub_env in enumerate(sub_envs):
        sub_ctx = self.env_context.copy_with_overrides(vector_index=idx)
        results.append(func(sub_env, sub_ctx))
    return results
1067
+
1068
def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Optional[Policy]:
    """Return policy for the specified id, or None.

    Args:
        policy_id: ID of the policy to return. None for DEFAULT_POLICY_ID
            (in the single agent case).

    Returns:
        The policy under the given ID (or None if not found).
    """
    # `dict.get` yields None for unknown IDs, matching the contract above.
    return self.policy_map.get(policy_id)
1079
+
1080
def add_policy(
    self,
    policy_id: PolicyID,
    policy_cls: Optional[Type[Policy]] = None,
    policy: Optional[Policy] = None,
    *,
    observation_space: Optional[Space] = None,
    action_space: Optional[Space] = None,
    config: Optional[PartialAlgorithmConfigDict] = None,
    policy_state: Optional[PolicyState] = None,
    policy_mapping_fn=None,
    policies_to_train: Optional[
        Union[Collection[PolicyID], Callable[[PolicyID, SampleBatchType], bool]]
    ] = None,
    module_spec: Optional[RLModuleSpec] = None,
) -> Policy:
    """Adds a new policy to this RolloutWorker.

    Args:
        policy_id: ID of the policy to add.
        policy_cls: The Policy class to use for constructing the new Policy.
            Note: Only one of `policy_cls` or `policy` must be provided.
        policy: The Policy instance to add to this worker.
            Note: Only one of `policy_cls` or `policy` must be provided.
        observation_space: The observation space of the policy to add.
        action_space: The action space of the policy to add.
        config: The config overrides for the policy to add.
        policy_state: Optional state dict to apply to the new policy
            instance, right after its construction.
        policy_mapping_fn: An optional (updated) policy mapping function to
            use from here on. Already ongoing episodes keep their old
            mapping until they end.
        policies_to_train: An optional collection of policy IDs to be
            trained, or a callable taking PolicyID (and optionally a
            SampleBatchType) and returning a bool. If None, keeps the
            existing setup in place.
        module_spec: RLModuleSpec for the new module (new RLModule API
            only; must be None here).

    Returns:
        The newly added policy.

    Raises:
        ValueError: If both or neither of `policy_cls` / `policy` are
            provided, or if `module_spec` is given.
        KeyError: If `policy_id` already exists in this worker's PolicyMap.
    """
    # Warn (don't raise) on unconventional policy IDs.
    validate_module_id(policy_id, error=False)

    # `module_spec` only makes sense with the RLModule API enabled.
    if module_spec is not None:
        raise ValueError(
            "If you pass in module_spec to the policy, the RLModule API needs "
            "to be enabled."
        )

    if policy_id in self.policy_map:
        raise KeyError(
            f"Policy ID '{policy_id}' already exists in policy map! "
            "Make sure you use a Policy ID that has not been taken yet."
            " Policy IDs that are already in your policy map: "
            f"{list(self.policy_map.keys())}"
        )

    # Exactly one of `policy_cls` / `policy` must be given.
    if (policy_cls is None) == (policy is None):
        raise ValueError(
            "Only one of `policy_cls` or `policy` must be provided to "
            "RolloutWorker.add_policy()!"
        )

    if policy is not None:
        # Derive the spec directly from the ready-made Policy object.
        new_policy_dict = {
            policy_id: PolicySpec(
                type(policy),
                policy.observation_space,
                policy.action_space,
                policy.config,
            )
        }
    else:
        # Build the spec from class + spaces + config overrides, letting the
        # algorithm config fill in any missing pieces.
        new_policy_dict, _ = self.config.get_multi_agent_setup(
            policies={
                policy_id: PolicySpec(
                    policy_cls, observation_space, action_space, config
                )
            },
            env=self.env,
            spaces=self.spaces,
            default_policy_class=self.default_policy_class,
        )

    self.policy_dict.update(new_policy_dict)
    self._update_policy_map(
        policy_dict=new_policy_dict,
        policy=policy,
        policy_states={policy_id: policy_state},
        single_agent_rl_module_spec=module_spec,
    )

    self.set_policy_mapping_fn(policy_mapping_fn)
    if policies_to_train is not None:
        self.set_is_policy_to_train(policies_to_train)

    return self.policy_map[policy_id]
1186
+
1187
def remove_policy(
    self,
    *,
    policy_id: PolicyID = DEFAULT_POLICY_ID,
    policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
    policies_to_train: Optional[
        Union[Collection[PolicyID], Callable[[PolicyID, SampleBatchType], bool]]
    ] = None,
) -> None:
    """Removes a policy from this RolloutWorker.

    Args:
        policy_id: ID of the policy to be removed. None for
            DEFAULT_POLICY_ID.
        policy_mapping_fn: An optional (updated) policy mapping function to
            use from here on. Already ongoing episodes keep their old
            mapping until they end.
        policies_to_train: An optional collection of policy IDs to be
            trained, or a callable taking PolicyID (and optionally a
            SampleBatchType) and returning a bool. If None, keeps the
            existing setup in place.

    Raises:
        ValueError: If `policy_id` is not in this worker's PolicyMap.
    """
    if policy_id not in self.policy_map:
        raise ValueError(f"Policy ID '{policy_id}' not in policy map!")
    # Drop both the policy itself and its observation preprocessor.
    del self.policy_map[policy_id]
    del self.preprocessors[policy_id]
    self.set_policy_mapping_fn(policy_mapping_fn)
    if policies_to_train is not None:
        self.set_is_policy_to_train(policies_to_train)
1219
+
1220
def set_policy_mapping_fn(
    self,
    policy_mapping_fn: Optional[Callable[[AgentID, Any], PolicyID]] = None,
) -> None:
    """Sets `self.policy_mapping_fn` to a new callable (if provided).

    Args:
        policy_mapping_fn: The new mapping function to use. If None,
            will keep the existing mapping function in place.

    Raises:
        ValueError: If the resulting mapping function is not callable.
    """
    # None -> keep whatever mapping fn is currently installed.
    if policy_mapping_fn is not None:
        self.policy_mapping_fn = policy_mapping_fn
    # Validate that whatever is now in place is actually callable.
    if not callable(self.policy_mapping_fn):
        raise ValueError("`policy_mapping_fn` must be a callable!")
1234
+
1235
def set_is_policy_to_train(
    self,
    is_policy_to_train: Union[
        Collection[PolicyID], Callable[[PolicyID, Optional[SampleBatchType]], bool]
    ],
) -> None:
    """Sets `self.is_policy_to_train()` to a new callable.

    Args:
        is_policy_to_train: A collection of policy IDs to be
            trained or a callable taking PolicyID and - optionally -
            SampleBatchType and returning a bool (trainable or not?).
            If None, will keep the existing setup in place.
            Policies, whose IDs are not in the list (or for which the
            callable returns False) will not be updated.
    """
    # If a collection is given, wrap it in a simple default callable that
    # checks membership of the PolicyID in the given IDs.
    if not callable(is_policy_to_train):
        # Fixed: missing space in the original message
        # ("`is_policy_to_train`must" -> "`is_policy_to_train` must").
        assert isinstance(is_policy_to_train, (list, set, tuple)), (
            "ERROR: `is_policy_to_train` must be a [list|set|tuple] or a "
            "callable taking PolicyID and SampleBatch and returning "
            "True|False (trainable or not?)."
        )
        # Freeze the IDs into a set for O(1) membership checks.
        trainable_ids = set(is_policy_to_train)

        def is_policy_to_train(pid, batch=None):
            return pid in trainable_ids

    self.is_policy_to_train = is_policy_to_train
1265
+
1266
@PublicAPI(stability="alpha")
def get_policies_to_train(
    self, batch: Optional[SampleBatchType] = None
) -> Set[PolicyID]:
    """Returns all policies-to-train, given an optional batch.

    Loops through all policies currently in `self.policy_map` and checks
    the return value of `self.is_policy_to_train(pid, batch)`.

    Args:
        batch: An optional SampleBatchType for the
            `self.is_policy_to_train(pid, [batch]?)` check.

    Returns:
        The set of currently trainable policy IDs, given the optional
        `batch`.
    """
    trainable_ids = set()
    for pid in self.policy_map.keys():
        # No filter configured -> everything is trainable.
        if self.is_policy_to_train is None or self.is_policy_to_train(pid, batch):
            trainable_ids.add(pid)
    return trainable_ids
1288
+
1289
def for_policy(
    self,
    func: Callable[[Policy, Optional[Any]], T],
    policy_id: Optional[PolicyID] = DEFAULT_POLICY_ID,
    **kwargs,
) -> T:
    """Calls the given function with the specified policy as first arg.

    Args:
        func: The function to call with the policy as first arg.
        policy_id: The PolicyID of the policy to call the function with.

    Keyword Args:
        kwargs: Additional kwargs to be passed to the call.

    Returns:
        The return value of the function call.
    """
    target_policy = self.policy_map[policy_id]
    return func(target_policy, **kwargs)
1309
+
1310
def foreach_policy(
    self, func: Callable[[Policy, PolicyID, Optional[Any]], T], **kwargs
) -> List[T]:
    """Calls the given function with each (policy, policy_id) tuple.

    Args:
        func: The function to call with each (policy, policy ID) tuple.

    Keyword Args:
        kwargs: Additional kwargs to be passed to the call.

    Returns:
        The list of return values of all calls to
        `func([policy, pid, **kwargs])`.
    """
    results = []
    for pid, policy in self.policy_map.items():
        results.append(func(policy, pid, **kwargs))
    return results
1326
+
1327
def foreach_policy_to_train(
    self, func: Callable[[Policy, PolicyID, Optional[Any]], T], **kwargs
) -> List[T]:
    """Calls the given function with each trainable (policy, policy_id).

    Only those policies/IDs will be called on, for which
    `self.is_policy_to_train()` returns True.

    Args:
        func: The function to call with each (policy, policy ID) tuple,
            for only those policies that `self.is_policy_to_train`
            returns True.

    Keyword Args:
        kwargs: Additional kwargs to be passed to the call.

    Returns:
        The list of return values of all calls to
        `func([policy, pid, **kwargs])`.
    """
    # Iterate over keys() only and fetch a policy just before calling
    # `func`: touching policy_map values for skipped pids could trigger
    # disk access for policies that were offloaded to disk.
    results = []
    for pid in self.policy_map.keys():
        if self.is_policy_to_train is not None and not self.is_policy_to_train(
            pid, None
        ):
            continue
        results.append(func(self.policy_map[pid], pid, **kwargs))
    return results
1359
+
1360
def sync_filters(self, new_filters: dict) -> None:
    """Changes self's filter to given and rebases any accumulated delta.

    Args:
        new_filters: Filters with new state to update local copy.
    """
    # Every locally known filter must have a counterpart in `new_filters`.
    assert all(key in new_filters for key in self.filters)
    for key, local_filter in self.filters.items():
        local_filter.sync(new_filters[key])
1369
+
1370
def get_filters(self, flush_after: bool = False) -> Dict:
    """Returns a snapshot of filters.

    Args:
        flush_after: Clears the filter buffer state.

    Returns:
        Dict for serializable filters
    """
    snapshot = {}
    for key, flt in self.filters.items():
        snapshot[key] = flt.as_serializable()
        if flush_after:
            # Drop accumulated deltas now that they have been captured.
            flt.reset_buffer()
    return snapshot
1385
+
1386
def get_state(self) -> dict:
    """Returns this worker's serializable state.

    Returns:
        A state dict with keys: "policy_ids", "policy_states",
        "policy_mapping_fn", "is_policy_to_train", and "filters".
    """
    filters = self.get_filters(flush_after=True)
    policy_states = {}
    for pid in self.policy_map.keys():
        # If required by the user, only capture policies that are actually
        # trainable. Otherwise, capture all policies (for saving to disk).
        if (
            not self.config.checkpoint_trainable_policies_only
            or self.is_policy_to_train is None
            # Fixed: pass the batch arg explicitly (as None) for consistency
            # with all other `is_policy_to_train` call sites; this also
            # supports user callables that don't declare a default for
            # their `batch` parameter.
            or self.is_policy_to_train(pid, None)
        ):
            policy_states[pid] = self.policy_map[pid].get_state()

    return {
        # List all known policy IDs here for convenience. When an Algorithm gets
        # restored from a checkpoint, it will not have access to the list of
        # possible IDs as each policy is stored in its own sub-dir
        # (see "policy_states").
        "policy_ids": list(self.policy_map.keys()),
        # Note that this field will not be stored in the algorithm checkpoint's
        # state file, but each policy will get its own state file generated in
        # a sub-dir within the algo's checkpoint dir.
        "policy_states": policy_states,
        # Also store current mapping fn and which policies to train.
        "policy_mapping_fn": self.policy_mapping_fn,
        "is_policy_to_train": self.is_policy_to_train,
        # TODO: Filters will be replaced by connectors.
        "filters": filters,
    }
1415
+
1416
def set_state(self, state: dict) -> None:
    """Restores this worker's state from a state dict (see `get_state()`).

    Args:
        state: The state dict as returned by a previous `get_state()` call,
            or a legacy pickled-bytes blob from old checkpoints.
    """
    # Backward compatibility (old checkpoints' states would have the local
    # worker state as a bytes object, not a dict).
    if isinstance(state, bytes):
        state = pickle.loads(state)

    # TODO: Once filters are handled by connectors, get rid of the "filters"
    #  key in `state` entirely (will be part of the policies then).
    self.sync_filters(state["filters"])

    # Support older checkpoint versions (< 1.0), in which the policy_map
    # was stored under the "state" key, not "policy_states".
    policy_states = (
        state["policy_states"] if "policy_states" in state else state["state"]
    )
    for pid, policy_state in policy_states.items():
        # If - for some reason - we have an invalid PolicyID in the state,
        # this might be from an older checkpoint (pre v1.0). Just warn here.
        validate_module_id(pid, error=False)

        if pid not in self.policy_map:
            spec = policy_state.get("policy_spec", None)
            if spec is None:
                # Fixed typo in log message: "multagent" -> "multiagent".
                logger.warning(
                    f"PolicyID '{pid}' was probably added on-the-fly (not"
                    " part of the static `multiagent.policies` config) and"
                    " no PolicySpec objects found in the pickled policy "
                    f"state. Will not add `{pid}`, but ignore it for now."
                )
            else:
                policy_spec = (
                    PolicySpec.deserialize(spec) if isinstance(spec, dict) else spec
                )
                self.add_policy(
                    policy_id=pid,
                    policy_cls=policy_spec.policy_class,
                    observation_space=policy_spec.observation_space,
                    action_space=policy_spec.action_space,
                    config=policy_spec.config,
                )
        if pid in self.policy_map:
            self.policy_map[pid].set_state(policy_state)

    # Also restore mapping fn and which policies to train.
    if "policy_mapping_fn" in state:
        self.set_policy_mapping_fn(state["policy_mapping_fn"])
    if state.get("is_policy_to_train") is not None:
        self.set_is_policy_to_train(state["is_policy_to_train"])
1464
+
1465
def get_weights(
    self,
    policies: Optional[Collection[PolicyID]] = None,
    inference_only: bool = False,
) -> Dict[PolicyID, ModelWeights]:
    """Returns each policies' model weights of this worker.

    Args:
        policies: List of PolicyIDs to get the weights from.
            Use None for all policies.
        inference_only: This argument is only added for interface
            consistency with the new api stack (unused here).

    Returns:
        Dict mapping PolicyIDs to ModelWeights.
    """
    if policies is None:
        policies = list(self.policy_map.keys())
    # Use a set for O(1) membership checks below (the original list would
    # be scanned once per policy in the map).
    requested = set(force_list(policies))

    return {
        # Make sure to only iterate over keys() and not items(). Iterating over
        # items will access policy_map elements even for pids that we do not need,
        # i.e. those that are not in policies. Access to policy_map elements can
        # cause disk access for policies that were offloaded to disk. Since these
        # policies will be skipped in the for-loop accessing them is unnecessary,
        # making subsequent disk access unnecessary.
        pid: self.policy_map[pid].get_weights()
        for pid in self.policy_map.keys()
        if pid in requested
    }
1509
+
1510
def set_weights(
    self,
    weights: Dict[PolicyID, ModelWeights],
    global_vars: Optional[Dict] = None,
    weights_seq_no: Optional[int] = None,
) -> None:
    """Sets each policies' model weights of this worker.

    Args:
        weights: Dict mapping PolicyIDs to the new weights to be used.
        global_vars: An optional global vars dict to set this
            worker to. If None, do not update the global_vars.
        weights_seq_no: If needed, a sequence number for the weights version
            can be passed into this method. If not None, will store this seq no
            (in self.weights_seq_no) and in future calls - if the seq no did not
            change wrt. the last call - will ignore the call to save on
            performance.
    """
    # Skip the actual weight update if a seq no is given and it matches the
    # one we already have (weights are already up-to-date).
    if weights_seq_no is None or weights_seq_no != self.weights_seq_no:
        # If per-policy weights are object refs, `ray.get()` them first.
        if weights and isinstance(next(iter(weights.values())), ObjectRef):
            resolved = ray.get(list(weights.values()))
            weights = dict(zip(weights.keys(), resolved))

        for pid, policy_weights in weights.items():
            if pid in self.policy_map:
                self.policy_map[pid].set_weights(policy_weights)
            elif log_once("set_weights_on_non_existent_policy"):
                logger.warning(
                    "`RolloutWorker.set_weights()` used with weights from "
                    f"policyID={pid}, but this policy cannot be found on this "
                    f"worker! Skipping ..."
                )

        # Remember the new seq no for future short-circuiting.
        self.weights_seq_no = weights_seq_no

    if global_vars:
        self.set_global_vars(global_vars)
1561
+
1562
def get_global_vars(self) -> dict:
    """Returns the current `self.global_vars` dict of this RolloutWorker.

    Returns:
        The current `self.global_vars` dict of this RolloutWorker,
        e.g. `{"timestep": 424242}`.
    """
    return self.global_vars
1582
+
1583
def set_global_vars(
    self,
    global_vars: dict,
    policy_ids: Optional[List[PolicyID]] = None,
) -> None:
    """Updates this worker's and all its policies' global vars.

    Updates are done using the dict's update method.

    Args:
        global_vars: The global_vars dict to update the `self.global_vars` dict
            from.
        policy_ids: Optional list of Policy IDs to update. If None, will update
            all policies on the to-be-updated workers.
    """
    # Split the per-policy gradient-update counters off from the rest.
    remaining_vars = global_vars.copy()
    per_policy_grad_updates = remaining_vars.pop("num_grad_updates_per_policy", {})
    self.global_vars["num_grad_updates_per_policy"].update(per_policy_grad_updates)

    # Only update explicitly provided policies or those that are being
    # trained, in order to avoid superfluous access of policies, which might
    # have been offloaded to the object store.
    # Important b/c global vars are constantly being updated.
    target_pids = policy_ids if policy_ids is not None else self.policy_map.keys()
    for pid in target_pids:
        if self.is_policy_to_train is None or self.is_policy_to_train(pid, None):
            self.policy_map[pid].on_global_var_update(
                dict(
                    remaining_vars,
                    # If count is None, Policy won't update the counter.
                    num_grad_updates=per_policy_grad_updates.get(pid),
                )
            )

    # Update all other global vars.
    self.global_vars.update(remaining_vars)
1629
+
1630
@override(EnvRunner)
def stop(self) -> None:
    """Releases all resources used by this RolloutWorker."""
    # If we have an env -> Release its resources.
    if self.env is not None:
        self.async_env.stop()

    # Close the sessions of all currently cached (in-memory) policies;
    # only static-graph tf policies actually carry a session.
    for cached_policy in self.policy_map.cache.values():
        session = cached_policy.get_session()
        if session is not None:
            session.close()
1644
+
1645
def lock(self) -> None:
    """Locks this RolloutWorker via its own threading.Lock."""
    # Blocks until the lock is acquired.
    self._lock.acquire()
1648
+
1649
def unlock(self) -> None:
    """Unlocks this RolloutWorker via its own threading.Lock."""
    # Releases the lock previously acquired via `self.lock()`.
    self._lock.release()
1652
+
1653
def setup_torch_data_parallel(
    self, url: str, world_rank: int, world_size: int, backend: str
) -> None:
    """Join a torch process group for distributed SGD.

    Args:
        url: The init-method URL used to rendezvous the process group.
        world_rank: This worker's rank within the process group.
        world_size: Total number of participants in the process group.
        backend: The torch.distributed backend to use.
    """
    logger.info(
        "Joining process group, url={}, world_rank={}, "
        "world_size={}, backend={}".format(url, world_rank, world_size, backend)
    )
    torch.distributed.init_process_group(
        backend=backend, init_method=url, rank=world_rank, world_size=world_size
    )

    # All policies must be torch-based to take part in distributed SGD.
    for pid, policy in self.policy_map.items():
        if not isinstance(policy, (TorchPolicy, TorchPolicyV2)):
            raise ValueError(
                "This policy does not support torch distributed", policy
            )
        policy.distributed_world_size = world_size
1672
+
1673
def creation_args(self) -> dict:
    """Returns the kwargs dict used to create this worker."""
    # Captured verbatim at construction time.
    return self._original_kwargs
1676
+
1677
def get_host(self) -> str:
    """Returns the hostname of the process running this evaluator."""
    # Delegates to the stdlib; returns an empty string if undeterminable.
    return platform.node()
1680
+
1681
def get_node_ip(self) -> str:
    """Returns the IP address of the node that this worker runs on."""
    # Resolved via Ray's own node-address utility.
    return ray.util.get_node_ip_address()
1684
+
1685
def find_free_port(self) -> int:
    """Finds a free port on the node that this worker runs on."""
    # Local import keeps the dependency off the module-import path.
    from ray.air._internal.util import find_free_port

    return find_free_port()
1690
+
1691
def _update_policy_map(
    self,
    *,
    policy_dict: MultiAgentPolicyConfigDict,
    policy: Optional[Policy] = None,
    policy_states: Optional[Dict[PolicyID, PolicyState]] = None,
    single_agent_rl_module_spec: Optional[RLModuleSpec] = None,
) -> None:
    """Updates the policy map (and other stuff) on this worker.

    Steps performed:
    1. Update the observation preprocessors and the policy_specs with the
       postprocessed observation spaces and merged algorithm configs.
    2. Build/extend `self.policy_map` with the new policies.
    3. Update the filter dict.
    4. Call the `on_create_policy()` callback hook for newly created
       policies.

    Args:
        policy_dict: The policy dict to update the policy map with.
        policy: The policy to update the policy map with.
        policy_states: The policy states to update the policy map with.
        single_agent_rl_module_spec: The RLModuleSpec to add to the
            MultiRLModuleSpec. If None, the config's
            `get_default_rl_module_spec` method's output will be used to
            create the policy with.
    """
    # Merge configs / postprocess observation spaces. As a side effect,
    # this also fills in the preprocessors dict.
    complete_policy_dict = self._get_complete_policy_specs_dict(policy_dict)

    # Build (or extend) the self.policy_map dict.
    self._build_policy_map(
        policy_dict=complete_policy_dict,
        policy=policy,
        policy_states=policy_states,
    )

    # Initialize the filter dict.
    self._update_filter_dict(complete_policy_dict)

    # Call callback policy init hooks (only if the added policy did not exist
    # before).
    if policy is None:
        self._call_callbacks_on_create_policy()

    # Only the local worker (index 0) logs the resulting maps.
    if self.worker_index == 0:
        logger.info(f"Built policy map: {self.policy_map}")
        logger.info(f"Built preprocessor map: {self.preprocessors}")
1744
+
1745
def _get_complete_policy_specs_dict(
    self, policy_dict: MultiAgentPolicyConfigDict
) -> MultiAgentPolicyConfigDict:
    """Processes the policy dict and creates a new copy with the processed attrs.

    This processes the observation_space and prepares them for passing to rl
    module construction. It also merges the policy configs with the algorithm
    config. During this processing, we will also construct the preprocessors
    dict.
    """
    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

    processed_dict = copy.deepcopy(policy_dict)
    # Lazily create the preprocessors dict, if it does not exist yet.
    self.preprocessors = self.preprocessors or {}

    # Process specs in deterministic (sorted-by-name) order.
    for policy_name, spec in sorted(processed_dict.items()):
        logger.debug("Creating policy for {}".format(policy_name))

        if isinstance(spec.config, AlgorithmConfig):
            # Policy brings its own complete AlgorithmConfig -> use as-is.
            merged_config = spec.config
        else:
            # Merge the policy's (partial) overrides into a mutable copy of
            # this worker's algorithm config.
            merged_config: "AlgorithmConfig" = self.config.copy(copy_frozen=False)
            merged_config.update_from_dict(spec.config or {})

        # Stamp this worker's index into the per-policy config.
        merged_config.worker_index = self.worker_index

        # Preprocessors: default to None for this policy.
        obs_space = spec.observation_space
        self.preprocessors[policy_name] = None
        if self.preprocessing_enabled:
            # Policies should deal with preprocessed (automatically flattened)
            # observations if preprocessing is enabled.
            preprocessor = ModelCatalog.get_preprocessor_for_space(
                obs_space,
                merged_config.model,
                include_multi_binary=False,
            )
            # Original observation space should be accessible at
            # obs_space.original_space after this step.
            if preprocessor is not None:
                obs_space = preprocessor.observation_space

        spec.config = merged_config
        spec.observation_space = obs_space

    return processed_dict
1796
+
1797
def _update_policy_dict_with_multi_rl_module(
    self, policy_dict: MultiAgentPolicyConfigDict
) -> MultiAgentPolicyConfigDict:
    """Injects this worker's MultiRLModuleSpec into every policy spec's config.

    Mutates the specs in `policy_dict` in place and returns the same dict.
    """
    module_spec = self.multi_rl_module_spec
    for policy_spec in policy_dict.values():
        policy_spec.config["__multi_rl_module_spec"] = module_spec
    return policy_dict
1803
+
1804
def _build_policy_map(
    self,
    *,
    policy_dict: MultiAgentPolicyConfigDict,
    policy: Optional[Policy] = None,
    policy_states: Optional[Dict[PolicyID, PolicyState]] = None,
) -> None:
    """Adds the given policy_dict to `self.policy_map`.

    Args:
        policy_dict: The MultiAgentPolicyConfigDict to be added to this
            worker's PolicyMap.
        policy: If the policy to add already exists, user can provide it here.
            Note that if given, this same object is used for EVERY entry in
            `policy_dict`; otherwise a fresh policy is constructed per entry.
        policy_states: Optional dict from PolicyIDs to PolicyStates to
            restore the states of the policies being built.
    """

    # If our policy_map does not exist yet, create it here.
    self.policy_map = self.policy_map or PolicyMap(
        capacity=self.config.policy_map_capacity,
        policy_states_are_swappable=self.config.policy_states_are_swappable,
    )

    # Loop through given policy-dict and add each entry to our map.
    # Sorted iteration keeps creation order deterministic across workers.
    for name, policy_spec in sorted(policy_dict.items()):
        # Create the actual policy object.
        if policy is None:
            new_policy = create_policy_for_framework(
                policy_id=name,
                policy_class=get_tf_eager_cls_if_necessary(
                    policy_spec.policy_class, policy_spec.config
                ),
                merged_config=policy_spec.config,
                observation_space=policy_spec.observation_space,
                action_space=policy_spec.action_space,
                worker_index=self.worker_index,
                seed=self.seed,
            )
        else:
            new_policy = policy

        self.policy_map[name] = new_policy

        restore_states = (policy_states or {}).get(name, None)
        # Set the state of the newly created policy before syncing filters, etc.
        if restore_states:
            new_policy.set_state(restore_states)
1851
+
1852
def _update_filter_dict(self, policy_dict: MultiAgentPolicyConfigDict) -> None:
    """Updates the filter dict for the given policy_dict.

    For each policy in `policy_dict`, lazily creates its agent/action
    connectors (only if both are missing) and registers its filters for
    syncing.
    """

    for name, policy_spec in sorted(policy_dict.items()):
        new_policy = self.policy_map[name]
        # Note(jungong) : We should only create new connectors for the
        # policy iff we are creating a new policy from scratch. i.e,
        # we should NOT create new connectors when we already have the
        # policy object created before this function call or have the
        # restoring states from the caller.
        # Also note that we cannot just check the existence of connectors
        # to decide whether we should create connectors because we may be
        # restoring a policy that has 0 connectors configured.
        if (
            new_policy.agent_connectors is None
            or new_policy.action_connectors is None
        ):
            # TODO(jungong) : revisit this. It will be nicer to create
            # connectors as the last step of Policy.__init__().
            create_connectors_for_policy(new_policy, policy_spec.config)
        maybe_get_filters_for_syncing(self, name)
1873
+
1874
+ def _call_callbacks_on_create_policy(self):
1875
+ """Calls the on_create_policy callback for each policy in the policy map."""
1876
+ for name, policy in self.policy_map.items():
1877
+ self.callbacks.on_create_policy(policy_id=name, policy=policy)
1878
+
1879
+ def _get_input_creator_from_config(self):
1880
+ def valid_module(class_path):
1881
+ if (
1882
+ isinstance(class_path, str)
1883
+ and not os.path.isfile(class_path)
1884
+ and "." in class_path
1885
+ ):
1886
+ module_path, class_name = class_path.rsplit(".", 1)
1887
+ try:
1888
+ spec = importlib.util.find_spec(module_path)
1889
+ if spec is not None:
1890
+ return True
1891
+ except (ModuleNotFoundError, ValueError):
1892
+ print(
1893
+ f"module {module_path} not found while trying to get "
1894
+ f"input {class_path}"
1895
+ )
1896
+ return False
1897
+
1898
+ # A callable returning an InputReader object to use.
1899
+ if isinstance(self.config.input_, FunctionType):
1900
+ return self.config.input_
1901
+ # Use RLlib's Sampler classes (SyncSampler).
1902
+ elif self.config.input_ == "sampler":
1903
+ return lambda ioctx: ioctx.default_sampler_input()
1904
+ # Ray Dataset input -> Use `config.input_config` to construct DatasetReader.
1905
+ elif self.config.input_ == "dataset":
1906
+ assert self._ds_shards is not None
1907
+ # Input dataset shards should have already been prepared.
1908
+ # We just need to take the proper shard here.
1909
+ return lambda ioctx: DatasetReader(
1910
+ self._ds_shards[self.worker_index], ioctx
1911
+ )
1912
+ # Dict: Mix of different input methods with different ratios.
1913
+ elif isinstance(self.config.input_, dict):
1914
+ return lambda ioctx: ShuffledInput(
1915
+ MixedInput(self.config.input_, ioctx), self.config.shuffle_buffer_size
1916
+ )
1917
+ # A pre-registered input descriptor (str).
1918
+ elif isinstance(self.config.input_, str) and registry_contains_input(
1919
+ self.config.input_
1920
+ ):
1921
+ return registry_get_input(self.config.input_)
1922
+ # D4RL input.
1923
+ elif "d4rl" in self.config.input_:
1924
+ env_name = self.config.input_.split(".")[-1]
1925
+ return lambda ioctx: D4RLReader(env_name, ioctx)
1926
+ # Valid python module (class path) -> Create using `from_config`.
1927
+ elif valid_module(self.config.input_):
1928
+ return lambda ioctx: ShuffledInput(
1929
+ from_config(self.config.input_, ioctx=ioctx)
1930
+ )
1931
+ # JSON file or list of JSON files -> Use JsonReader (shuffled).
1932
+ else:
1933
+ return lambda ioctx: ShuffledInput(
1934
+ JsonReader(self.config.input_, ioctx), self.config.shuffle_buffer_size
1935
+ )
1936
+
1937
+ def _get_output_creator_from_config(self):
1938
+ if isinstance(self.config.output, FunctionType):
1939
+ return self.config.output
1940
+ elif self.config.output is None:
1941
+ return lambda ioctx: NoopOutput()
1942
+ elif self.config.output == "dataset":
1943
+ return lambda ioctx: DatasetWriter(
1944
+ ioctx, compress_columns=self.config.output_compress_columns
1945
+ )
1946
+ elif self.config.output == "logdir":
1947
+ return lambda ioctx: JsonWriter(
1948
+ ioctx.log_dir,
1949
+ ioctx,
1950
+ max_file_size=self.config.output_max_file_size,
1951
+ compress_columns=self.config.output_compress_columns,
1952
+ )
1953
+ else:
1954
+ return lambda ioctx: JsonWriter(
1955
+ self.config.output,
1956
+ ioctx,
1957
+ max_file_size=self.config.output_max_file_size,
1958
+ compress_columns=self.config.output_compress_columns,
1959
+ )
1960
+
1961
def _get_make_sub_env_fn(
    self, env_creator, env_context, validate_env, env_wrapper, seed
):
    """Returns a function that creates one wrapped, seeded sub-environment.

    Args:
        env_creator: Callable mapping an EnvContext to a new env instance.
        env_context: Base EnvContext, copied per sub-env with its vector_index.
        validate_env: Optional user validation callable `(env, env_ctx)`.
        env_wrapper: Wrapper applied to every created sub-env.
        seed: Seed forwarded to `_update_env_seed_if_necessary`.
    """

    def _make_sub_env_local(vector_index):
        # Used to created additional environments during environment
        # vectorization.

        # Create the env context (config dict + meta-data) for
        # this particular sub-env within the vectorized one.
        env_ctx = env_context.copy_with_overrides(vector_index=vector_index)
        # Create the sub-env.
        env = env_creator(env_ctx)
        # Custom validation function given by user.
        if validate_env is not None:
            validate_env(env, env_ctx)
        # Use our wrapper, defined above.
        env = env_wrapper(env)

        # Make sure a deterministic random seed is set on
        # all the sub-environments if specified.
        _update_env_seed_if_necessary(
            env, seed, env_context.worker_index, vector_index
        )
        return env

    # NOTE(review): despite its name, `_make_sub_env_remote` is the variant
    # returned when the env context is NOT remote; it additionally fires the
    # `on_sub_environment_created` callback. For remote contexts, the plain
    # factory without the callback is returned — confirm this inversion of
    # the names vs. the condition is intentional.
    if not env_context.remote:

        def _make_sub_env_remote(vector_index):
            sub_env = _make_sub_env_local(vector_index)
            self.callbacks.on_sub_environment_created(
                worker=self,
                sub_environment=sub_env,
                env_context=env_context.copy_with_overrides(
                    worker_index=env_context.worker_index,
                    vector_index=vector_index,
                    remote=False,
                ),
            )
            return sub_env

        return _make_sub_env_remote

    else:
        return _make_sub_env_local
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/sample_batch_builder.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import logging
3
+ import numpy as np
4
+ from typing import List, Any, Dict, TYPE_CHECKING
5
+
6
+ from ray.rllib.env.base_env import _DUMMY_AGENT_ID
7
+ from ray.rllib.policy.policy import Policy
8
+ from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
9
+ from ray.rllib.utils.annotations import OldAPIStack
10
+ from ray.rllib.utils.debug import summarize
11
+ from ray.rllib.utils.deprecation import deprecation_warning
12
+ from ray.rllib.utils.typing import PolicyID, AgentID
13
+ from ray.util.debug import log_once
14
+
15
+ if TYPE_CHECKING:
16
+ from ray.rllib.callbacks.callbacks import RLlibCallback
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ def _to_float_array(v: List[Any]) -> np.ndarray:
22
+ arr = np.array(v)
23
+ if arr.dtype == np.float64:
24
+ return arr.astype(np.float32) # save some memory
25
+ return arr
26
+
27
+
28
@OldAPIStack
class SampleBatchBuilder:
    """Incrementally accumulates rows and emits them as a SampleBatch.

    SampleBatches hold data column-wise (as arrays); this helper lets callers
    append one row (dict) at a time and materializes the columns on demand.
    """

    # Monotonically increasing id that disambiguates unrolls within a single
    # episode.
    _next_unroll_id = 0

    def __init__(self):
        # Column name -> list of per-row values collected so far.
        self.buffers: Dict[str, List] = collections.defaultdict(list)
        # Number of rows added since the last build/reset.
        self.count = 0

    def add_values(self, **values: Any) -> None:
        """Appends a single row, given as keyword args, to this builder."""
        for key, value in values.items():
            self.buffers[key].append(value)
        self.count += 1

    def add_batch(self, batch: SampleBatch) -> None:
        """Appends all rows of an existing SampleBatch to this builder."""
        for key, column in batch.items():
            self.buffers[key].extend(column)
        self.count += batch.count

    def build_and_reset(self) -> SampleBatch:
        """Materializes all accumulated rows as a SampleBatch and clears state."""
        columns = {key: _to_float_array(vals) for key, vals in self.buffers.items()}
        batch = SampleBatch(columns)
        if SampleBatch.UNROLL_ID not in batch:
            # Stamp every row with the current unroll id, then bump the id.
            batch[SampleBatch.UNROLL_ID] = np.repeat(
                SampleBatchBuilder._next_unroll_id, batch.count
            )
            SampleBatchBuilder._next_unroll_id += 1
        self.buffers.clear()
        self.count = 0
        return batch
68
+
69
+
70
@OldAPIStack
class MultiAgentSampleBatchBuilder:
    """Util to build SampleBatches for each policy in a multi-agent env.

    Input data is per-agent, while output data is per-policy. There is an M:N
    mapping between agents and policies. We retain one local batch builder
    per agent. When an agent is done, then its local batch is appended into the
    corresponding policy batch for the agent's policy.
    """

    def __init__(
        self,
        policy_map: Dict[PolicyID, Policy],
        clip_rewards: bool,
        callbacks: "RLlibCallback",
    ):
        """Initialize a MultiAgentSampleBatchBuilder.

        Args:
            policy_map (Dict[str,Policy]): Maps policy ids to policy instances.
            clip_rewards (Union[bool,float]): Whether to clip rewards before
                postprocessing (at +/-1.0) or the actual value to +/- clip.
            callbacks: RLlib callbacks.
        """
        # This class is deprecated; warn once per process.
        if log_once("MultiAgentSampleBatchBuilder"):
            deprecation_warning(old="MultiAgentSampleBatchBuilder", error=False)
        self.policy_map = policy_map
        self.clip_rewards = clip_rewards
        # Build the Policies' SampleBatchBuilders.
        self.policy_builders = {k: SampleBatchBuilder() for k in policy_map.keys()}
        # Whenever we observe a new agent, add a new SampleBatchBuilder for
        # this agent.
        self.agent_builders = {}
        # Internal agent-to-policy map.
        self.agent_to_policy = {}
        self.callbacks = callbacks
        # Number of "inference" steps taken in the environment.
        # Regardless of the number of agents involved in each of these steps.
        self.count = 0

    def total(self) -> int:
        """Returns the total number of steps taken in the env (all agents).

        Returns:
            int: The number of steps taken in total in the environment over all
                agents.
        """

        return sum(a.count for a in self.agent_builders.values())

    def has_pending_agent_data(self) -> bool:
        """Returns whether there is pending unprocessed data.

        Returns:
            bool: True if there is at least one per-agent builder (with data
                in it).
        """

        return len(self.agent_builders) > 0

    def add_values(self, agent_id: AgentID, policy_id: PolicyID, **values: Any) -> None:
        """Add the given dictionary (row) of values to this batch.

        Args:
            agent_id: Unique id for the agent we are adding values for.
            policy_id: Unique id for policy controlling the agent.
            values: Row of values to add for this agent.
        """

        # First time we see this agent: create its builder and remember which
        # policy controls it.
        if agent_id not in self.agent_builders:
            self.agent_builders[agent_id] = SampleBatchBuilder()
            self.agent_to_policy[agent_id] = policy_id

        # Include the current agent id for multi-agent algorithms.
        if agent_id != _DUMMY_AGENT_ID:
            values["agent_id"] = agent_id

        self.agent_builders[agent_id].add_values(**values)

    def postprocess_batch_so_far(self, episode=None) -> None:
        """Apply policy postprocessors to any unprocessed rows.

        This pushes the postprocessed per-agent batches onto the per-policy
        builders, clearing per-agent state.

        Args:
            episode (Optional[Episode]): The Episode object that
                holds this MultiAgentBatchBuilder object.
        """

        # Materialize the batches so far.
        # Maps agent_id -> (controlling policy, that agent's raw batch).
        pre_batches = {}
        for agent_id, builder in self.agent_builders.items():
            pre_batches[agent_id] = (
                self.policy_map[self.agent_to_policy[agent_id]],
                builder.build_and_reset(),
            )

        # Apply postprocessor.
        post_batches = {}
        # Reward clipping happens in-place on the raw batches, BEFORE any
        # policy postprocessing.
        if self.clip_rewards is True:
            # `True` means sign-clipping to +/-1.0.
            for _, (_, pre_batch) in pre_batches.items():
                pre_batch["rewards"] = np.sign(pre_batch["rewards"])
        elif self.clip_rewards:
            # A numeric value means clipping to +/- that value.
            for _, (_, pre_batch) in pre_batches.items():
                pre_batch["rewards"] = np.clip(
                    pre_batch["rewards"],
                    a_min=-self.clip_rewards,
                    a_max=self.clip_rewards,
                )
        for agent_id, (_, pre_batch) in pre_batches.items():
            # Each agent's postprocessor also receives all OTHER agents'
            # batches (e.g. for centralized critics).
            other_batches = pre_batches.copy()
            del other_batches[agent_id]
            policy = self.policy_map[self.agent_to_policy[agent_id]]
            if (
                not pre_batch.is_single_trajectory()
                or len(set(pre_batch[SampleBatch.EPS_ID])) > 1
            ):
                raise ValueError(
                    "Batches sent to postprocessing must only contain steps "
                    "from a single trajectory.",
                    pre_batch,
                )
            # Call the Policy's Exploration's postprocess method.
            post_batches[agent_id] = pre_batch
            if getattr(policy, "exploration", None) is not None:
                policy.exploration.postprocess_trajectory(
                    policy, post_batches[agent_id], policy.get_session()
                )
            post_batches[agent_id] = policy.postprocess_trajectory(
                post_batches[agent_id], other_batches, episode
            )

        if log_once("after_post"):
            logger.info(
                "Trajectory fragment after postprocess_trajectory():\n\n{}\n".format(
                    summarize(post_batches)
                )
            )

        # Append into policy batches and reset
        # Local import avoids a circular import at module load time.
        from ray.rllib.evaluation.rollout_worker import get_global_worker

        for agent_id, post_batch in sorted(post_batches.items()):
            self.callbacks.on_postprocess_trajectory(
                worker=get_global_worker(),
                episode=episode,
                agent_id=agent_id,
                policy_id=self.agent_to_policy[agent_id],
                policies=self.policy_map,
                postprocessed_batch=post_batch,
                original_batches=pre_batches,
            )
            self.policy_builders[self.agent_to_policy[agent_id]].add_batch(post_batch)

        self.agent_builders.clear()
        self.agent_to_policy.clear()

    def check_missing_dones(self) -> None:
        """Raises if the episode ended without final data for every agent.

        NOTE(review): `builder.buffers` is the plain defaultdict of a
        SampleBatchBuilder, which does not define
        `is_terminated_or_truncated()` — this call looks like it was written
        against a collector-style buffer API; confirm the attribute actually
        exists at runtime.
        """
        for agent_id, builder in self.agent_builders.items():
            if not builder.buffers.is_terminated_or_truncated():
                raise ValueError(
                    "The environment terminated for all agents, but we still "
                    "don't have a last observation for "
                    "agent {} (policy {}). ".format(
                        agent_id, self.agent_to_policy[agent_id]
                    )
                    + "Please ensure that you include the last observations "
                    "of all live agents when setting '__all__' terminated|truncated "
                    "to True. "
                )

    def build_and_reset(self, episode=None) -> MultiAgentBatch:
        """Returns the accumulated sample batches for each policy.

        Any unprocessed rows will be first postprocessed with a policy
        postprocessor. The internal state of this builder will be reset.

        Args:
            episode (Optional[Episode]): The Episode object that
                holds this MultiAgentBatchBuilder object or None.

        Returns:
            MultiAgentBatch: Returns the accumulated sample batches for each
                policy.
        """

        self.postprocess_batch_so_far(episode)
        policy_batches = {}
        for policy_id, builder in self.policy_builders.items():
            # Skip policies that collected no data in this fragment.
            if builder.count > 0:
                policy_batches[policy_id] = builder.build_and_reset()
        old_count = self.count
        self.count = 0
        return MultiAgentBatch.wrap_as_needed(policy_batches, old_count)
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/sampler.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import queue
3
+ from abc import ABCMeta, abstractmethod
4
+ from collections import defaultdict, namedtuple
5
+ from typing import (
6
+ TYPE_CHECKING,
7
+ Any,
8
+ List,
9
+ Optional,
10
+ Type,
11
+ Union,
12
+ )
13
+
14
+ from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
15
+ from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
16
+ from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector
17
+ from ray.rllib.evaluation.env_runner_v2 import EnvRunnerV2, _PerfStats
18
+ from ray.rllib.evaluation.metrics import RolloutMetrics
19
+ from ray.rllib.offline import InputReader
20
+ from ray.rllib.policy.sample_batch import concat_samples
21
+ from ray.rllib.utils.annotations import OldAPIStack, override
22
+ from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
23
+ from ray.rllib.utils.framework import try_import_tf
24
+ from ray.rllib.utils.typing import SampleBatchType
25
+ from ray.util.debug import log_once
26
+
27
+ if TYPE_CHECKING:
28
+ from ray.rllib.callbacks.callbacks import RLlibCallback
29
+ from ray.rllib.evaluation.observation_function import ObservationFunction
30
+ from ray.rllib.evaluation.rollout_worker import RolloutWorker
31
+
32
+ tf1, tf, _ = try_import_tf()
33
+ logger = logging.getLogger(__name__)
34
+
35
+ _PolicyEvalData = namedtuple(
36
+ "_PolicyEvalData",
37
+ ["env_id", "agent_id", "obs", "info", "rnn_state", "prev_action", "prev_reward"],
38
+ )
39
+
40
+ # A batch of RNN states with dimensions [state_index, batch, state_object].
41
+ StateBatch = List[List[Any]]
42
+
43
+
44
+ class _NewEpisodeDefaultDict(defaultdict):
45
+ def __missing__(self, env_id):
46
+ if self.default_factory is None:
47
+ raise KeyError(env_id)
48
+ else:
49
+ ret = self[env_id] = self.default_factory(env_id)
50
+ return ret
51
+
52
+
53
@OldAPIStack
class SamplerInput(InputReader, metaclass=ABCMeta):
    """Reads input experiences from an existing sampler."""

    @override(InputReader)
    def next(self) -> SampleBatchType:
        """Returns the next batch: sampled data plus any queued extra batches.

        Fix: `batches` always contains at least the `get_data()` result, so
        the previous `len(batches) == 0` RuntimeError check was unreachable;
        it is replaced with a fast path that skips concatenation when there
        are no extra batches.
        """
        batches = [self.get_data()]
        batches.extend(self.get_extra_batches())
        if len(batches) == 1:
            # No extra batches queued -> return the sampled batch as-is.
            return batches[0]
        return concat_samples(batches)

    @abstractmethod
    def get_data(self) -> SampleBatchType:
        """Called by `self.next()` to return the next batch of data.

        Override this in child classes.

        Returns:
            The next batch of data.
        """
        raise NotImplementedError

    @abstractmethod
    def get_metrics(self) -> List[RolloutMetrics]:
        """Returns list of episode metrics since the last call to this method.

        The list will contain one RolloutMetrics object per completed episode.

        Returns:
            List of RolloutMetrics objects, one per completed episode since
            the last call to this method.
        """
        raise NotImplementedError

    @abstractmethod
    def get_extra_batches(self) -> List[SampleBatchType]:
        """Returns list of extra batches since the last call to this method.

        The list will contain all SampleBatches or
        MultiAgentBatches that the user has provided thus-far. Users can
        add these "extra batches" to an episode by calling the episode's
        `add_extra_batch([SampleBatchType])` method. This can be done from
        inside an overridden `Policy.compute_actions_from_input_dict(...,
        episodes)` or from a custom callback's `on_episode_[start|step|end]()`
        methods.

        Returns:
            List of SamplesBatches or MultiAgentBatches provided thus-far by
            the user since the last call to this method.
        """
        raise NotImplementedError
105
+
106
+
107
@OldAPIStack
class SyncSampler(SamplerInput):
    """Sync SamplerInput that collects experiences when `get_data()` is called."""

    def __init__(
        self,
        *,
        worker: "RolloutWorker",
        env: BaseEnv,
        clip_rewards: Union[bool, float],
        rollout_fragment_length: int,
        count_steps_by: str = "env_steps",
        callbacks: "RLlibCallback",
        multiple_episodes_in_batch: bool = False,
        normalize_actions: bool = True,
        clip_actions: bool = False,
        observation_fn: Optional["ObservationFunction"] = None,
        sample_collector_class: Optional[Type[SampleCollector]] = None,
        render: bool = False,
        # Obsolete.
        policies=None,
        policy_mapping_fn=None,
        preprocessors=None,
        obs_filters=None,
        tf_sess=None,
        horizon=DEPRECATED_VALUE,
        soft_horizon=DEPRECATED_VALUE,
        no_done_at_end=DEPRECATED_VALUE,
    ):
        """Initializes a SyncSampler instance.

        Args:
            worker: The RolloutWorker that will use this Sampler for sampling.
            env: Any Env object. Will be converted into an RLlib BaseEnv.
            clip_rewards: True for +/-1.0 clipping,
                actual float value for +/- value clipping. False for no
                clipping.
            rollout_fragment_length: The length of a fragment to collect
                before building a SampleBatch from the data and resetting
                the SampleBatchBuilder object.
            count_steps_by: One of "env_steps" (default) or "agent_steps".
                Use "agent_steps", if you want rollout lengths to be counted
                by individual agent steps. In a multi-agent env,
                a single env_step contains one or more agent_steps, depending
                on how many agents are present at any given time in the
                ongoing episode.
            callbacks: The RLlibCallback object to use when episode
                events happen during rollout.
            multiple_episodes_in_batch: Whether to pack multiple
                episodes into each batch. This guarantees batches will be
                exactly `rollout_fragment_length` in size.
            normalize_actions: Whether to normalize actions to the
                action space's bounds.
            clip_actions: Whether to clip actions according to the
                given action_space's bounds.
            observation_fn: Optional multi-agent observation func to use for
                preprocessing observations.
            sample_collector_class: An optional SampleCollector sub-class to
                use to collect, store, and retrieve environment-, model-,
                and sampler data.
            render: Whether to try to render the environment after each step.
        """
        # All of the following arguments are deprecated. They will instead be
        # provided via the passed in `worker` arg, e.g. `worker.policy_map`.
        # Only warn once per process for each deprecated kwarg; the horizon
        # family hard-errors.
        if log_once("deprecated_sync_sampler_args"):
            if policies is not None:
                deprecation_warning(old="policies")
            if policy_mapping_fn is not None:
                deprecation_warning(old="policy_mapping_fn")
            if preprocessors is not None:
                deprecation_warning(old="preprocessors")
            if obs_filters is not None:
                deprecation_warning(old="obs_filters")
            if tf_sess is not None:
                deprecation_warning(old="tf_sess")
            if horizon != DEPRECATED_VALUE:
                deprecation_warning(old="horizon", error=True)
            if soft_horizon != DEPRECATED_VALUE:
                deprecation_warning(old="soft_horizon", error=True)
            if no_done_at_end != DEPRECATED_VALUE:
                deprecation_warning(old="no_done_at_end", error=True)

        self.base_env = convert_to_base_env(env)
        self.rollout_fragment_length = rollout_fragment_length
        # Queue of user-provided "extra" batches (see get_extra_batches()).
        self.extra_batches = queue.Queue()
        self.perf_stats = _PerfStats(
            ema_coef=worker.config.sampler_perf_stats_ema_coef,
        )
        if not sample_collector_class:
            sample_collector_class = SimpleListCollector
        self.sample_collector = sample_collector_class(
            worker.policy_map,
            clip_rewards,
            callbacks,
            multiple_episodes_in_batch,
            rollout_fragment_length,
            count_steps_by=count_steps_by,
        )
        self.render = render

        # Keep a reference to the underlying EnvRunnerV2 instance for
        # unit testing purpose.
        self._env_runner_obj = EnvRunnerV2(
            worker=worker,
            base_env=self.base_env,
            multiple_episodes_in_batch=multiple_episodes_in_batch,
            callbacks=callbacks,
            perf_stats=self.perf_stats,
            rollout_fragment_length=rollout_fragment_length,
            count_steps_by=count_steps_by,
            render=self.render,
        )
        # `run()` returns a generator yielding SampleBatches interleaved
        # with RolloutMetrics; `get_data()` demultiplexes them.
        self._env_runner = self._env_runner_obj.run()
        self.metrics_queue = queue.Queue()

    @override(SamplerInput)
    def get_data(self) -> SampleBatchType:
        # Pull items off the env-runner generator until a data batch appears;
        # any RolloutMetrics encountered on the way are queued for
        # get_metrics().
        while True:
            item = next(self._env_runner)
            if isinstance(item, RolloutMetrics):
                self.metrics_queue.put(item)
            else:
                return item

    @override(SamplerInput)
    def get_metrics(self) -> List[RolloutMetrics]:
        # Drain the metrics queue, attaching the current perf stats snapshot
        # to each completed episode's metrics.
        completed = []
        while True:
            try:
                completed.append(
                    self.metrics_queue.get_nowait()._replace(
                        perf_stats=self.perf_stats.get()
                    )
                )
            except queue.Empty:
                break
        return completed

    @override(SamplerInput)
    def get_extra_batches(self) -> List[SampleBatchType]:
        # Drain and return all user-queued extra batches.
        extra = []
        while True:
            try:
                extra.append(self.extra_batches.get_nowait())
            except queue.Empty:
                break
        return extra
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.utils.deprecation import Deprecated
2
+
3
+
4
@Deprecated(
    new="ray.rllib.env.env_runner_group.EnvRunnerGroup",
    help="The class has only be renamed w/o any changes in functionality.",
    error=True,
)
class WorkerSet:
    # Deprecated alias stub: `WorkerSet` was renamed to `EnvRunnerGroup`.
    # With `error=True`, any use of this class raises and points callers at
    # the new location.
    pass
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (195 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/attention_net.cpython-311.pyc ADDED
Binary file (20.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/fcnet.cpython-311.pyc ADDED
Binary file (6.78 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/mingpt.cpython-311.pyc ADDED
Binary file (16.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/recurrent_net.cpython-311.pyc ADDED
Binary file (14.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/torch_action_dist.cpython-311.pyc ADDED
Binary file (45.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/torch_distributions.cpython-311.pyc ADDED
Binary file (41.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.offline.d4rl_reader import D4RLReader
2
+ from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards
3
+ from ray.rllib.offline.dataset_writer import DatasetWriter
4
+ from ray.rllib.offline.io_context import IOContext
5
+ from ray.rllib.offline.input_reader import InputReader
6
+ from ray.rllib.offline.mixed_input import MixedInput
7
+ from ray.rllib.offline.json_reader import JsonReader
8
+ from ray.rllib.offline.json_writer import JsonWriter
9
+ from ray.rllib.offline.output_writer import OutputWriter, NoopOutput
10
+ from ray.rllib.offline.resource import get_offline_io_resource_bundles
11
+ from ray.rllib.offline.shuffled_input import ShuffledInput
12
+ from ray.rllib.offline.feature_importance import FeatureImportance
13
+
14
+
15
+ __all__ = [
16
+ "IOContext",
17
+ "JsonReader",
18
+ "JsonWriter",
19
+ "NoopOutput",
20
+ "OutputWriter",
21
+ "InputReader",
22
+ "MixedInput",
23
+ "ShuffledInput",
24
+ "D4RLReader",
25
+ "DatasetReader",
26
+ "DatasetWriter",
27
+ "get_dataset_and_shards",
28
+ "get_offline_io_resource_bundles",
29
+ "FeatureImportance",
30
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/dataset_reader.cpython-311.pyc ADDED
Binary file (14.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/dataset_writer.cpython-311.pyc ADDED
Binary file (4.25 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/feature_importance.cpython-311.pyc ADDED
Binary file (14.3 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/io_context.cpython-311.pyc ADDED
Binary file (3.65 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/is_estimator.cpython-311.pyc ADDED
Binary file (789 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/json_reader.cpython-311.pyc ADDED
Binary file (22.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/json_writer.cpython-311.pyc ADDED
Binary file (8.15 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/mixed_input.cpython-311.pyc ADDED
Binary file (3.77 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/off_policy_estimator.cpython-311.pyc ADDED
Binary file (587 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_data.cpython-311.pyc ADDED
Binary file (8.28 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_env_runner.cpython-311.pyc ADDED
Binary file (13.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_evaluation_utils.cpython-311.pyc ADDED
Binary file (6.62 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_evaluator.cpython-311.pyc ADDED
Binary file (3.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_prelearner.cpython-311.pyc ADDED
Binary file (23.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/output_writer.cpython-311.pyc ADDED
Binary file (1.59 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/wis_estimator.cpython-311.pyc ADDED
Binary file (848 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/d4rl_reader.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import gymnasium as gym
3
+
4
+ from ray.rllib.offline.input_reader import InputReader
5
+ from ray.rllib.offline.io_context import IOContext
6
+ from ray.rllib.policy.sample_batch import SampleBatch
7
+ from ray.rllib.utils.annotations import override, PublicAPI
8
+ from ray.rllib.utils.typing import SampleBatchType
9
+ from typing import Dict
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
@PublicAPI
class D4RLReader(InputReader):
    """Reader object that loads the dataset from the D4RL dataset."""

    @PublicAPI
    def __init__(self, inputs: str, ioctx: IOContext = None):
        """Initializes a D4RLReader instance.

        Args:
            inputs: String corresponding to the D4RL environment name
                (passed straight to `gym.make`). Requires the `d4rl`
                package to be installed.
            ioctx: Current IO context object.
        """
        # Lazy import: d4rl is an optional dependency, only needed when
        # this reader is actually instantiated.
        import d4rl

        self.env = gym.make(inputs)
        self.dataset = _convert_to_batch(d4rl.qlearning_dataset(self.env))
        assert self.dataset.count >= 1
        # Index of the next dataset row to emit.
        self.counter = 0

    @override(InputReader)
    def next(self) -> SampleBatchType:
        """Returns the next single-row slice of the dataset, cycling forever.

        Bug fix vs. the original implementation: the counter was incremented
        *before* slicing, so row 0 was never returned and one out-of-range
        slice `(count, count + 1)` was emitted per pass over the data. We now
        slice at the current index, then advance.
        """
        # Wrap around once every row has been emitted.
        if self.counter >= self.dataset.count:
            self.counter = 0

        idx = self.counter
        self.counter += 1
        # Exactly one row: the half-open range [idx, idx + 1).
        return self.dataset.slice(start=idx, end=idx + 1)
40
+
41
+
42
def _convert_to_batch(dataset: Dict) -> SampleBatchType:
    """Converts a D4RL q-learning dataset dict into an RLlib SampleBatch.

    Args:
        dataset: Dict as returned by `d4rl.qlearning_dataset`, with the keys
            "observations", "actions", "next_observations", "rewards", and
            "terminals".

    Returns:
        A SampleBatch holding the same arrays under RLlib's column names.
    """
    # Map RLlib column names -> D4RL dataset keys, then copy them over in one
    # comprehension.
    column_to_key = {
        SampleBatch.OBS: "observations",
        SampleBatch.ACTIONS: "actions",
        SampleBatch.NEXT_OBS: "next_observations",
        SampleBatch.REWARDS: "rewards",
        SampleBatch.TERMINATEDS: "terminals",
    }
    return SampleBatch({col: dataset[key] for col, key in column_to_key.items()})
.venv/lib/python3.11/site-packages/ray/rllib/offline/dataset_reader.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ from pathlib import Path
4
+ import re
5
+ import numpy as np
6
+ from typing import List, Tuple, TYPE_CHECKING, Optional
7
+ import zipfile
8
+
9
+ import ray.data
10
+ from ray.rllib.offline.input_reader import InputReader
11
+ from ray.rllib.offline.io_context import IOContext
12
+ from ray.rllib.offline.json_reader import from_json_data, postprocess_actions
13
+ from ray.rllib.policy.sample_batch import concat_samples, SampleBatch, DEFAULT_POLICY_ID
14
+ from ray.rllib.utils.annotations import override, PublicAPI
15
+ from ray.rllib.utils.typing import SampleBatchType
16
+
17
+ if TYPE_CHECKING:
18
+ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
19
+
20
+ DEFAULT_NUM_CPUS_PER_TASK = 0.5
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def _unzip_this_path(fpath: Path, extract_path: str):
26
+ with zipfile.ZipFile(str(fpath), "r") as zip_ref:
27
+ zip_ref.extractall(extract_path)
28
+
29
+
30
+ def _unzip_if_needed(paths: List[str], format: str):
31
+ """If a path in paths is a zip file, unzip it and use path of the unzipped file"""
32
+ ret_paths = []
33
+ for path in paths:
34
+ if re.search("\\.zip$", str(path)):
35
+ # TODO: We need to add unzip support for s3
36
+ if str(path).startswith("s3://"):
37
+ raise ValueError(
38
+ "unzip_if_needed currently does not support remote paths from s3"
39
+ )
40
+ extract_path = "./"
41
+ try:
42
+ _unzip_this_path(str(path), extract_path)
43
+ except FileNotFoundError:
44
+ # intrepreted as a relative path to rllib folder
45
+ try:
46
+ # TODO: remove this later when we replace all tests with s3 paths
47
+ _unzip_this_path(Path(__file__).parent.parent / path, extract_path)
48
+ except FileNotFoundError:
49
+ raise FileNotFoundError(f"File not found: {path}")
50
+
51
+ unzipped_path = str(
52
+ Path(extract_path).absolute() / f"{Path(path).stem}.{format}"
53
+ )
54
+ ret_paths.append(unzipped_path)
55
+ else:
56
+ # TODO: We can get rid of this logic when we replace all tests with s3 paths
57
+ if str(path).startswith("s3://"):
58
+ ret_paths.append(path)
59
+ else:
60
+ if not Path(path).exists():
61
+ relative_path = str(Path(__file__).parent.parent / path)
62
+ if not Path(relative_path).exists():
63
+ raise FileNotFoundError(f"File not found: {path}")
64
+ path = relative_path
65
+ ret_paths.append(path)
66
+ return ret_paths
67
+
68
+
69
@PublicAPI
def get_dataset_and_shards(
    config: "AlgorithmConfig", num_workers: int = 0
) -> Tuple[ray.data.Dataset, List[ray.data.Dataset]]:
    """Returns a dataset and a list of shards.

    This function uses algorithm configs to create a dataset and a list of shards.
    The following config keys are used to create the dataset:
        input: The input type should be "dataset".
        input_config: A dict containing the following key and values:
            `format`: str, specifies the format of the input data. This will be the
            format that ray dataset supports. See ray.data.Dataset for
            supported formats. Only "parquet" or "json" are supported for now.
            `paths`: str, a single string or a list of strings. Each string is a path
            to a file or a directory holding the dataset. It can be either a local path
            or a remote path (e.g. to an s3 bucket).
            `loader_fn`: Callable[None, ray.data.Dataset], Instead of
            specifying paths and format, you can specify a function to load the dataset.
            `parallelism`: int, The number of tasks to use for loading the dataset.
            If not specified, it will be set to the number of workers.
            `num_cpus_per_read_task`: float, The number of CPUs to use for each read
            task. If not specified, it will be set to 0.5.

    Args:
        config: The config dict for the algorithm.
        num_workers: The number of shards to create for remote workers.

    Returns:
        dataset: The dataset object.
        shards: A list of dataset shards. For num_workers > 0 the first returned
            shard would be a dummy None shard for local_worker.

    Raises:
        ValueError: If the input_config is inconsistent (unsupported format,
            both/neither of `loader_fn` and `format`+`paths` given, or `paths`
            of the wrong type).
    """
    # check input and input config keys
    assert config.input_ == "dataset", (
        f"Must specify config.input_ as 'dataset' if"
        f" calling `get_dataset_and_shards`. Got {config.input_}"
    )

    # check input config format
    input_config = config.input_config
    format = input_config.get("format")

    # `format` may legitimately be None when a `loader_fn` is used instead.
    supported_fmts = ["json", "parquet"]
    if format is not None and format not in supported_fmts:
        raise ValueError(
            f"Unsupported format {format}. Supported formats are {supported_fmts}"
        )

    # check paths and loader_fn since only one of them is required.
    paths = input_config.get("paths")
    loader_fn = input_config.get("loader_fn")
    if loader_fn and (format or paths):
        raise ValueError(
            "When using a `loader_fn`, you cannot specify a `format` or `path`."
        )

    # check if at least loader_fn or format + path is specified.
    if not (format and paths) and not loader_fn:
        raise ValueError(
            "Must specify either a `loader_fn` or a `format` and `path` in "
            "`input_config`."
        )

    # check paths to be a str or list[str] if not None
    if paths is not None:
        if isinstance(paths, str):
            # Normalize a single path string to a one-element list.
            paths = [paths]
        elif isinstance(paths, list):
            # NOTE(review): only the first element is type-checked here.
            assert isinstance(paths[0], str), "Paths must be a list of path strings."
        else:
            raise ValueError("Paths must be a path string or a list of path strings.")
        # Transparently extract any local .zip archives among the paths.
        paths = _unzip_if_needed(paths, format)

    # TODO (Kourosh): num_workers is not necessary since we can use parallelism for
    # everything. Having two parameters is confusing here. Remove num_workers later.
    parallelism = input_config.get("parallelism", num_workers or 1)
    cpus_per_task = input_config.get(
        "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
    )

    if loader_fn:
        # User-supplied loader takes precedence over format/paths reading.
        dataset = loader_fn()
    elif format == "json":
        dataset = ray.data.read_json(
            paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
        )
    elif format == "parquet":
        dataset = ray.data.read_parquet(
            paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
        )
    else:
        # Unreachable given the validation above, kept as a safety net.
        raise ValueError("Un-supported Ray dataset format: ", format)

    # Local worker will be responsible for sampling.
    if num_workers == 0:
        # Dataset is the only shard we need.
        return dataset, [dataset]
    # Remote workers are responsible for sampling:
    else:
        # Each remote worker gets 1 shard.
        remote_shards = dataset.repartition(
            num_blocks=num_workers, shuffle=False
        ).split(num_workers)

        # The first None shard is for the local worker, which
        # shouldn't be doing rollout work anyways.
        return dataset, [None] + remote_shards
176
+
177
+
178
@PublicAPI
class DatasetReader(InputReader):
    """Reader object that loads data from Ray Dataset.

    Examples:
        config = {
            "input": "dataset",
            "input_config": {
                "format": "json",
                # A single data file, a directory, or anything
                # that ray.data.dataset recognizes.
                "paths": "/tmp/sample_batches/",
                # By default, parallelism=num_workers.
                "parallelism": 3,
                # Dataset allocates 0.5 CPU for each reader by default.
                # Adjust this value based on the size of your offline dataset.
                "num_cpus_per_read_task": 0.5,
            }
        }
    """

    @PublicAPI
    def __init__(self, ds: ray.data.Dataset, ioctx: Optional[IOContext] = None):
        """Initializes a DatasetReader instance.

        Args:
            ds: Ray dataset to sample from. May be None/empty, in which case a
                non-functioning reader is created (see comment below).
            ioctx: Current IO context object; a default-constructed IOContext
                is used if not given.
        """
        self._ioctx = ioctx or IOContext()
        # Initialize BOTH the private and the public policy-map attribute so
        # each always exists. The original code only set `self.policy_map`
        # here but assigned `self._policy_map` in the worker branch below,
        # leaving one of the two undefined depending on the code path.
        self._default_policy = self._policy_map = self.policy_map = None
        self.preprocessor = None
        self._dataset = ds
        # Total number of rows, or None for the non-functioning reader.
        self.count = None if not self._dataset else self._dataset.count()
        # do this to disable the ray data stdout logging
        ray.data.DataContext.get_current().enable_progress_bars = False

        # the number of steps to return per call to next()
        self.batch_size = self._ioctx.config.get("train_batch_size", 1)
        num_workers = self._ioctx.config.get("num_env_runners", 0)
        seed = self._ioctx.config.get("seed", None)
        if num_workers:
            # Split the configured train batch size evenly across workers.
            self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
        # We allow the creation of a non-functioning None DatasetReader.
        # It's useful for example for a non-rollout local worker.
        if ds:
            if self._ioctx.worker is not None:
                # Keep private and public attribute in sync.
                self._policy_map = self.policy_map = self._ioctx.worker.policy_map
                self._default_policy = self._policy_map.get(DEFAULT_POLICY_ID)
                self.preprocessor = (
                    self._ioctx.worker.preprocessors.get(DEFAULT_POLICY_ID)
                    if not self._ioctx.config.get("_disable_preprocessors", False)
                    else None
                )
            # Use the module logger (lazy %-formatting) instead of a bare
            # print(), so users can control verbosity via logging config.
            logger.info(
                "DatasetReader %s has %s samples.",
                self._ioctx.worker_index,
                ds.count(),
            )

            def iterator():
                # Infinite row iterator: reshuffle the dataset every epoch.
                while True:
                    ds = self._dataset.random_shuffle(seed=seed)
                    yield from ds.iter_rows()

            self._iter = iterator()
        else:
            self._iter = None

    @override(InputReader)
    def next(self) -> SampleBatchType:
        """Returns a batch containing at least `self.batch_size` timesteps.

        Pulls rows off the infinite shuffled iterator, decodes, preprocesses
        and postprocesses each one, and concatenates them until the batch
        size is reached or exceeded.
        """
        # next() should not get called on None DatasetReader.
        assert self._iter is not None
        ret = []
        count = 0
        while count < self.batch_size:
            d = next(self._iter)
            # Columns like obs are compressed when written by DatasetWriter.
            d = from_json_data(d, self._ioctx.worker)
            count += d.count
            d = self._preprocess_if_needed(d)
            d = postprocess_actions(d, self._ioctx)
            d = self._postprocess_if_needed(d)
            ret.append(d)
        ret = concat_samples(ret)
        return ret

    def _preprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
        """Applies the policy's observation preprocessor to obs columns, if set."""
        # TODO: @kourosh, preprocessor is only supported for single agent case.
        if self.preprocessor:
            for key in (SampleBatch.CUR_OBS, SampleBatch.NEXT_OBS):
                if key in batch:
                    batch[key] = np.stack(
                        [self.preprocessor.transform(s) for s in batch[key]]
                    )
        return batch

    def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
        """Runs the default policy's trajectory postprocessing per episode.

        Only active when `postprocess_inputs` is set in the config; multi-agent
        batches are not supported.
        """
        if not self._ioctx.config.get("postprocess_inputs"):
            return batch

        if isinstance(batch, SampleBatch):
            out = []
            for sub_batch in batch.split_by_episode():
                if self._default_policy is not None:
                    out.append(self._default_policy.postprocess_trajectory(sub_batch))
                else:
                    out.append(sub_batch)
            return concat_samples(out)
        else:
            # TODO(ekl) this is trickier since the alignments between agent
            # trajectories in the episode are not available any more.
            raise NotImplementedError(
                "Postprocessing of multi-agent data not implemented yet."
            )
.venv/lib/python3.11/site-packages/ray/rllib/offline/dataset_writer.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import time
4
+
5
+ from ray import data
6
+ from ray.rllib.offline.io_context import IOContext
7
+ from ray.rllib.offline.json_writer import _to_json_dict
8
+ from ray.rllib.offline.output_writer import OutputWriter
9
+ from ray.rllib.utils.annotations import override, PublicAPI
10
+ from ray.rllib.utils.typing import SampleBatchType
11
+ from typing import Dict, List
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
@PublicAPI
class DatasetWriter(OutputWriter):
    """Writer object that saves experiences using Datasets."""

    @PublicAPI
    def __init__(
        self,
        ioctx: IOContext = None,
        compress_columns: List[str] = frozenset(["obs", "new_obs"]),
    ):
        """Initializes a DatasetWriter instance.

        Examples:
            config = {
                "output": "dataset",
                "output_config": {
                    "format": "json",
                    "path": "/tmp/test_samples/",
                    "max_num_samples_per_file": 100000,
                }
            }

        Args:
            ioctx: current IO context object.
            compress_columns: list of sample batch columns to compress.
        """
        self.ioctx = ioctx or IOContext()

        # Read the output config from `self.ioctx` (never None), not from the
        # raw `ioctx` argument: the original code dereferenced `ioctx` here
        # and crashed with an AttributeError whenever no IOContext was passed,
        # defeating the `ioctx or IOContext()` default above.
        output_config: Dict = self.ioctx.output_config
        assert (
            "format" in output_config
        ), "output_config.format must be specified when using Dataset output."
        assert (
            "path" in output_config
        ), "output_config.path must be specified when using Dataset output."

        self.format = output_config["format"]
        self.path = os.path.abspath(os.path.expanduser(output_config["path"]))
        # Flush the in-memory buffer to disk once it reaches this many batches.
        self.max_num_samples_per_file = output_config.get(
            "max_num_samples_per_file", 100000
        )
        self.compress_columns = compress_columns

        # Buffer of JSON-ready sample dicts awaiting a flush.
        self.samples = []

    @override(OutputWriter)
    def write(self, sample_batch: SampleBatchType):
        """Buffers one sample batch; writes all buffered batches once full.

        Args:
            sample_batch: The batch of experiences to buffer/write.

        Raises:
            ValueError: If `self.format` is neither "json" nor "parquet".
        """
        start = time.time()

        # Make sure columns like obs are compressed and writable.
        d = _to_json_dict(sample_batch, self.compress_columns)
        self.samples.append(d)

        # Todo: We should flush at the end of sampling even if this
        # condition was not reached.
        if len(self.samples) >= self.max_num_samples_per_file:
            ds = data.from_items(self.samples).repartition(num_blocks=1, shuffle=False)
            if self.format == "json":
                ds.write_json(self.path, try_create_dir=True)
            elif self.format == "parquet":
                ds.write_parquet(self.path, try_create_dir=True)
            else:
                raise ValueError("Unknown output type: ", self.format)
            self.samples = []
        logger.debug("Wrote dataset in {}s".format(time.time() - start))
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (802 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/direct_method.cpython-311.pyc ADDED
Binary file (8.83 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/fqe_torch_model.cpython-311.pyc ADDED
Binary file (15.9 kB). View file