koichi12 commited on
Commit
747c195
·
verified ·
1 Parent(s): 697a7f6

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/observation_function.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/sample_batch_builder.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/worker_set.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__init__.py +0 -0
  5. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__pycache__/__init__.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__pycache__/agent_collector.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__pycache__/sample_collector.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__pycache__/simple_list_collector.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/agent_collector.py +688 -0
  10. .venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/sample_collector.py +298 -0
  11. .venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/complex_input_net.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/torch_modelv2.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__init__.py +13 -0
  14. .venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/gru_gate.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/multi_head_attention.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/noisy_layer.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/relative_multi_head_attention.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/skip_connection.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/skip_connection.py +43 -0
  21. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/__init__.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/d4rl_reader.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/input_reader.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/resource.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/shuffled_input.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__init__.py +15 -0
  27. .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/doubly_robust.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/feature_importance.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/importance_sampling.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/weighted_importance_sampling.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/direct_method.py +180 -0
  32. .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/doubly_robust.py +253 -0
  33. .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/fqe_torch_model.py +297 -0
  34. .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/importance_sampling.py +126 -0
  35. .venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/off_policy_estimator.py +248 -0
  36. .venv/lib/python3.11/site-packages/ray/rllib/utils/__init__.py +141 -0
  37. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/checkpoints.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/compression.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/deprecation.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/from_config.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/lambda_defaultdict.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/memory.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/serialization.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/torch_utils.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/ray/rllib/utils/actors.py +258 -0
  46. .venv/lib/python3.11/site-packages/ray/rllib/utils/annotations.py +213 -0
  47. .venv/lib/python3.11/site-packages/ray/rllib/utils/checkpoints.py +1045 -0
  48. .venv/lib/python3.11/site-packages/ray/rllib/utils/deprecation.py +134 -0
  49. .venv/lib/python3.11/site-packages/ray/rllib/utils/error.py +128 -0
  50. .venv/lib/python3.11/site-packages/ray/rllib/utils/filter_manager.py +82 -0
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/observation_function.cpython-311.pyc ADDED
Binary file (3.98 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/sample_batch_builder.cpython-311.pyc ADDED
Binary file (14.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/worker_set.cpython-311.pyc ADDED
Binary file (713 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (204 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__pycache__/agent_collector.cpython-311.pyc ADDED
Binary file (27.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__pycache__/sample_collector.cpython-311.pyc ADDED
Binary file (14 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/__pycache__/simple_list_collector.cpython-311.pyc ADDED
Binary file (28.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/agent_collector.py ADDED
@@ -0,0 +1,688 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import math
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import numpy as np
7
+ import tree # pip install dm_tree
8
+ from gymnasium.spaces import Space
9
+
10
+ from ray.rllib.policy.sample_batch import SampleBatch
11
+ from ray.rllib.policy.view_requirement import ViewRequirement
12
+ from ray.rllib.utils.annotations import OldAPIStack
13
+ from ray.rllib.utils.framework import try_import_torch
14
+ from ray.rllib.utils.spaces.space_utils import (
15
+ flatten_to_single_ndarray,
16
+ get_dummy_batch_for_space,
17
+ )
18
+ from ray.rllib.utils.typing import (
19
+ EpisodeID,
20
+ EnvID,
21
+ TensorType,
22
+ ViewRequirementsDict,
23
+ )
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+ torch, _ = try_import_torch()
28
+
29
+
30
+ def _to_float_np_array(v: List[Any]) -> np.ndarray:
31
+ if torch and torch.is_tensor(v[0]):
32
+ raise ValueError
33
+ arr = np.array(v)
34
+ if arr.dtype == np.float64:
35
+ return arr.astype(np.float32) # save some memory
36
+ return arr
37
+
38
+
39
+ def _get_buffered_slice_with_paddings(d, inds):
40
+ element_at_t = []
41
+ for index in inds:
42
+ if index < len(d):
43
+ element_at_t.append(d[index])
44
+ else:
45
+ # zero pad similar to the last element.
46
+ element_at_t.append(tree.map_structure(np.zeros_like, d[-1]))
47
+ return element_at_t
48
+
49
+
50
@OldAPIStack
class AgentCollector:
    """Collects samples for one agent in one trajectory (episode).

    The agent may be part of a multi-agent environment. Samples are stored in
    lists including some possible automatic "shift" buffer at the beginning to
    be able to save memory when storing things like NEXT_OBS, PREV_REWARDS,
    etc.., which are specified using the trajectory view API.
    """

    _next_unroll_id = 0  # disambiguates unrolls within a single episode

    # TODO: @kourosh add different types of padding. e.g. zeros vs. same
    def __init__(
        self,
        view_reqs: ViewRequirementsDict,
        *,
        max_seq_len: int = 1,
        disable_action_flattening: bool = True,
        # NOTE: kwarg name is misspelled ("intial") but kept for backward
        # compatibility with existing callers using the keyword.
        intial_states: Optional[List[TensorType]] = None,
        is_policy_recurrent: bool = False,
        is_training: bool = True,
        _enable_new_api_stack: bool = False,
    ):
        """Initialize an AgentCollector.

        Args:
            view_reqs: A dict of view requirements for the agent.
            max_seq_len: The maximum sequence length to store.
            disable_action_flattening: If True, don't flatten the action.
            intial_states: The initial states from the policy.get_initial_states()
            is_policy_recurrent: If True, the policy is recurrent.
            is_training: Sets the is_training flag for the buffers. if True, all the
                timesteps are stored in the buffers until explicitly
                build_for_training() is called. if False, only the content required
                for the last time step is stored in the buffers. This will save
                memory during inference.
                You can change the behavior at runtime by calling is_training(mode).
        """
        self.max_seq_len = max_seq_len
        self.disable_action_flattening = disable_action_flattening
        self.view_requirements = view_reqs
        # The initial_states can be an np array
        self.initial_states = intial_states if intial_states is not None else []
        self.is_policy_recurrent = is_policy_recurrent
        self._is_training = is_training
        self._enable_new_api_stack = _enable_new_api_stack

        # Determine the size of the buffer we need for data before the actual
        # episode starts. This is used for 0-buffering of e.g. prev-actions,
        # or internal state inputs.
        view_req_shifts = [
            min(vr.shift_arr)
            - int((vr.data_col or k) in [SampleBatch.OBS, SampleBatch.INFOS])
            for k, vr in view_reqs.items()
        ]
        self.shift_before = -min(view_req_shifts)

        # The actual data buffers. Keys are column names, values are lists
        # that contain the sub-components (e.g. for complex obs spaces) with
        # each sub-component holding a list of per-timestep tensors.
        # E.g.: obs-space = Dict(a=Discrete(2), b=Box((2,)))
        # buffers["obs"] = [
        #    [0, 1],  # <- 1st sub-component of observation
        #    [np.array([.2, .3]), np.array([.0, -.2])]  # <- 2nd sub-component
        # ]
        # NOTE: infos and state_out... are not flattened due to them often
        # using custom dict values whose structure may vary from timestep to
        # timestep.
        self.buffers: Dict[str, List[List[TensorType]]] = {}
        # Maps column names to an example data item, which may be deeply
        # nested. These are used such that we'll know how to unflatten
        # the flattened data inside self.buffers when building the
        # SampleBatch.
        self.buffer_structs: Dict[str, Any] = {}
        # The episode ID for the agent for which we collect data.
        self.episode_id = None
        # The unroll ID, unique across all rollouts (within a RolloutWorker).
        self.unroll_id = None
        # The simple timestep count for this agent. Gets increased by one
        # each time a (non-initial!) observation is added.
        self.agent_steps = 0
        # Keep track of view requirements that have a view on columns that we gain
        # from inference and also need for inference. These have dummy values
        # appended in buffers to account for the missing value when building for
        # inference.
        # Example: We have one 'state_in' view requirement that has a view on our
        # state_outs at t=[-10, ..., -1]. At any given build_for_inference()-call,
        # the buffer must contain eleven values from t=[-10, ..., 0] for us to index
        # properly. Since state_out at t=0 is missing, we substitute it with a
        # buffer value that should never make it into batches built for training.
        self.data_cols_with_dummy_values = set()

    @property
    def training(self) -> bool:
        """Whether this collector keeps the full trajectory (training mode)."""
        return self._is_training

    def is_training(self, is_training: bool) -> None:
        """Sets the training mode flag at runtime."""
        self._is_training = is_training

    def is_empty(self) -> bool:
        """Returns True if this collector has no data."""
        return not self.buffers or all(len(item) == 0 for item in self.buffers.values())

    def add_init_obs(
        self,
        episode_id: EpisodeID,
        agent_index: int,
        env_id: EnvID,
        init_obs: TensorType,
        init_infos: Optional[Dict[str, TensorType]] = None,
        t: int = -1,
    ) -> None:
        """Adds an initial observation (after reset) to the Agent's trajectory.

        Args:
            episode_id: Unique ID for the episode we are adding the
                initial observation for.
            agent_index: Unique int index (starting from 0) for the agent
                within its episode. Not to be confused with AGENT_ID (Any).
            env_id: The environment index (in a vectorized setup).
            init_obs: The initial observation tensor (after `env.reset()`).
            init_infos: The initial infos dict (after `env.reset()`).
            t: The time step (episode length - 1). The initial obs has
                ts=-1(!), then an action/reward/next-obs at t=0, etc..
        """
        # Store episode ID + unroll ID, which will be constant throughout this
        # AgentCollector's lifecycle.
        self.episode_id = episode_id
        if self.unroll_id is None:
            self.unroll_id = AgentCollector._next_unroll_id
            AgentCollector._next_unroll_id += 1

        # convert init_obs to np.array (in case it is a list)
        if isinstance(init_obs, list):
            init_obs = np.array(init_obs)

        if SampleBatch.OBS not in self.buffers:
            single_row = {
                SampleBatch.OBS: init_obs,
                SampleBatch.INFOS: init_infos or {},
                SampleBatch.AGENT_INDEX: agent_index,
                SampleBatch.ENV_ID: env_id,
                SampleBatch.T: t,
                SampleBatch.EPS_ID: self.episode_id,
                SampleBatch.UNROLL_ID: self.unroll_id,
            }

            # TODO (Artur): Remove when PREV_ACTIONS and PREV_REWARDS get deprecated.
            # Note (Artur): As long as we have these in our default view requirements,
            # we should build buffers with neutral elements instead of building them
            # on the first AgentCollector.build_for_inference call if present.
            # This prevents us from accidentally building buffers with duplicates of
            # the first incoming value.
            if SampleBatch.PREV_REWARDS in self.view_requirements:
                single_row[SampleBatch.REWARDS] = get_dummy_batch_for_space(
                    space=self.view_requirements[SampleBatch.REWARDS].space,
                    batch_size=0,
                    fill_value=0.0,
                )
            if SampleBatch.PREV_ACTIONS in self.view_requirements:
                potentially_flattened_batch = get_dummy_batch_for_space(
                    space=self.view_requirements[SampleBatch.ACTIONS].space,
                    batch_size=0,
                    fill_value=0.0,
                )
                if not self.disable_action_flattening:
                    potentially_flattened_batch = flatten_to_single_ndarray(
                        potentially_flattened_batch
                    )
                single_row[SampleBatch.ACTIONS] = potentially_flattened_batch
            self._build_buffers(single_row)

        # Append data to existing buffers.
        flattened = tree.flatten(init_obs)
        for i, sub_obs in enumerate(flattened):
            self.buffers[SampleBatch.OBS][i].append(sub_obs)
        self.buffers[SampleBatch.INFOS][0].append(init_infos or {})
        self.buffers[SampleBatch.AGENT_INDEX][0].append(agent_index)
        self.buffers[SampleBatch.ENV_ID][0].append(env_id)
        self.buffers[SampleBatch.T][0].append(t)
        self.buffers[SampleBatch.EPS_ID][0].append(self.episode_id)
        self.buffers[SampleBatch.UNROLL_ID][0].append(self.unroll_id)

    def add_action_reward_next_obs(self, input_values: Dict[str, TensorType]) -> None:
        """Adds the given dictionary (row) of values to the Agent's trajectory.

        Args:
            input_values: Data dict (interpreted as a single row) to be added to
                buffer. Must contain keys:
                SampleBatch.ACTIONS, REWARDS, TERMINATEDS, TRUNCATEDS, and NEXT_OBS.
        """
        if self.unroll_id is None:
            self.unroll_id = AgentCollector._next_unroll_id
            AgentCollector._next_unroll_id += 1

        # Next obs -> obs.
        values = copy.copy(input_values)
        assert SampleBatch.OBS not in values
        values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS]
        del values[SampleBatch.NEXT_OBS]

        # convert obs to np.array (in case it is a list)
        if isinstance(values[SampleBatch.OBS], list):
            values[SampleBatch.OBS] = np.array(values[SampleBatch.OBS])

        # Default to next timestep if not provided in input values
        if SampleBatch.T not in input_values:
            values[SampleBatch.T] = self.buffers[SampleBatch.T][0][-1] + 1

        # Make sure EPS_ID/UNROLL_ID stay the same for this agent.
        if SampleBatch.EPS_ID in values:
            assert values[SampleBatch.EPS_ID] == self.episode_id
            del values[SampleBatch.EPS_ID]
        self.buffers[SampleBatch.EPS_ID][0].append(self.episode_id)
        if SampleBatch.UNROLL_ID in values:
            assert values[SampleBatch.UNROLL_ID] == self.unroll_id
            del values[SampleBatch.UNROLL_ID]
        self.buffers[SampleBatch.UNROLL_ID][0].append(self.unroll_id)

        for k, v in values.items():
            if k not in self.buffers:
                if self.training and k.startswith("state_out"):
                    vr = self.view_requirements[k]
                    data_col = vr.data_col or k
                    self._fill_buffer_with_initial_values(
                        data_col, vr, build_for_inference=False
                    )
                else:
                    self._build_buffers({k: v})
            # Do not flatten infos, state_out and (if configured) actions.
            # Infos/state-outs may be structs that change from timestep to
            # timestep.
            should_flatten_action_key = (
                k == SampleBatch.ACTIONS and not self.disable_action_flattening
            )
            # Note (Artur) RL Modules's states need no flattening
            should_flatten_state_key = (
                k.startswith("state_out") and not self._enable_new_api_stack
            )
            if (
                k == SampleBatch.INFOS
                or should_flatten_state_key
                or should_flatten_action_key
            ):
                if should_flatten_action_key:
                    v = flatten_to_single_ndarray(v)
                # Briefly remove dummy value to add to buffer
                if k in self.data_cols_with_dummy_values:
                    dummy = self.buffers[k][0].pop(-1)
                self.buffers[k][0].append(v)
                # Add back dummy value
                if k in self.data_cols_with_dummy_values:
                    self.buffers[k][0].append(dummy)
            # Flatten all other columns.
            else:
                flattened = tree.flatten(v)
                for i, sub_list in enumerate(self.buffers[k]):
                    # Briefly remove dummy value to add to buffer
                    if k in self.data_cols_with_dummy_values:
                        dummy = sub_list.pop(-1)
                    sub_list.append(flattened[i])
                    # Add back dummy value
                    if k in self.data_cols_with_dummy_values:
                        sub_list.append(dummy)

        # In inference mode, we don't need to keep all of trajectory in memory
        # we only need to keep the steps required. We can pop from the beginning to
        # create room for new data.
        if not self.training:
            for k in self.buffers:
                for sub_list in self.buffers[k]:
                    if sub_list:
                        sub_list.pop(0)

        self.agent_steps += 1

    def build_for_inference(self) -> SampleBatch:
        """During inference, we will build a SampleBatch with a batch size of 1 that
        can then be used to run the forward pass of a policy. This data will only
        include the environment context for running the policy at the last timestep.

        Returns:
            A SampleBatch with a batch size of 1.
        """

        batch_data = {}
        np_data = {}
        for view_col, view_req in self.view_requirements.items():
            # Create the batch of data from the different buffers.
            data_col = view_req.data_col or view_col

            # if this view is not for inference, skip it.
            if not view_req.used_for_compute_actions:
                continue

            if np.any(view_req.shift_arr > 0):
                raise ValueError(
                    f"During inference the agent can only use past observations to "
                    f"respect causality. However, view_col = {view_col} seems to "
                    f"depend on future indices {view_req.shift_arr}, while the "
                    f"used_for_compute_actions flag is set to True. Please fix the "
                    f"discrepancy. Hint: If you are using a custom model, make sure "
                    f"its view_requirements are initialized properly and point "
                    f"only to past timesteps during inference."
                )

            # Some columns don't exist yet
            # (get created during postprocessing or depend on state_out).
            if data_col not in self.buffers:
                self._fill_buffer_with_initial_values(
                    data_col, view_req, build_for_inference=True
                )
                self._prepare_for_data_cols_with_dummy_values(data_col)

            # Keep an np-array cache, so we don't have to regenerate the
            # np-array for different view_cols using to the same data_col.
            self._cache_in_np(np_data, data_col)

            data = []
            for d in np_data[data_col]:
                # if shift_arr = [0] the data will be just the last time step
                # (len(d) - 1), if shift_arr = [-1] the data will be just the
                # timestep before the last one (len(d) - 2) and so on.
                element_at_t = d[view_req.shift_arr + len(d) - 1]
                if element_at_t.shape[0] == 1:
                    # We'd normally squeeze here to remove the time dim, but we'll
                    # simply use the time dim as the batch dim.
                    data.append(element_at_t)
                    continue
                # add the batch dimension with [None]
                data.append(element_at_t[None])

            # We unflatten even if data is empty here, because the structure might be
            # nested with empty leafs and so we still need to reconstruct it.
            # This is useful because we spec-check states in RLModules and these
            # states can sometimes be nested dicts with empty leafs.
            batch_data[view_col] = self._unflatten_as_buffer_struct(data, data_col)

        batch = self._get_sample_batch(batch_data)
        return batch

    # TODO: @kouorsh we don't really need view_requirements anymore since it's
    # already an attribute of the class
    def build_for_training(
        self, view_requirements: ViewRequirementsDict
    ) -> SampleBatch:
        """Builds a SampleBatch from the thus-far collected agent data.

        If the episode/trajectory has no TERMINATED|TRUNCATED=True at the end, will
        copy the necessary n timesteps at the end of the trajectory back to the
        beginning of the buffers and wait for new samples coming in.
        SampleBatches created by this method will be ready for postprocessing
        by a Policy.

        Args:
            view_requirements: The viewrequirements dict needed to build the
                SampleBatch from the raw buffers (which may have data shifts as well
                as mappings from view-col to data-col in them).

        Returns:
            SampleBatch: The built SampleBatch for this agent, ready to go into
            postprocessing.
        """
        batch_data = {}
        np_data = {}
        for view_col, view_req in view_requirements.items():
            # Create the batch of data from the different buffers.
            data_col = view_req.data_col or view_col

            if data_col not in self.buffers:
                is_state = self._fill_buffer_with_initial_values(
                    data_col, view_req, build_for_inference=False
                )

                # We need to skip this view_col if it does not exist in the buffers
                # and is not an RNN state because it could be the special keys that
                # gets added by policy's postprocessing function for training.
                if not is_state:
                    continue

            # OBS and INFOS are already shifted by -1 (the initial obs/info starts
            # one ts before all other data columns).
            obs_shift = -1 if data_col in [SampleBatch.OBS, SampleBatch.INFOS] else 0

            # Keep an np-array cache so we don't have to regenerate the
            # np-array for different view_cols using to the same data_col.
            self._cache_in_np(np_data, data_col)

            # Go through each time-step in the buffer and construct the view
            # accordingly.
            data = []
            for d in np_data[data_col]:
                shifted_data = []

                # batch_repeat_value determines how many time steps should we skip
                # before we repeat indexing the data.
                # Example: batch_repeat_value=10, shift_arr = [-3, -2, -1],
                # shift_before = 3
                # buffer = [-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
                # resulting_data = [[-3, -2, -1], [7, 8, 9]]
                # explanation: For t=0, we output [-3, -2, -1]. We then skip 10 time
                # steps ahead and get to t=10. For t=10, we output [7, 8, 9]. We skip
                # 10 more time steps and get to t=20. but since t=20 is out of bound
                # we stop.

                # count computes the number of time steps that we need to consider.
                # if batch_repeat_value = 1, this number should be the length of
                # episode so far, which is len(buffer) - shift_before (-1 if this
                # value was gained during inference. This is because we keep a dummy
                # value at the last position of the buffer that makes it one longer).
                count = int(
                    math.ceil(
                        (
                            len(d)
                            - int(data_col in self.data_cols_with_dummy_values)
                            - self.shift_before
                        )
                        / view_req.batch_repeat_value
                    )
                )
                for i in range(count):
                    # the indices for time step t
                    inds = (
                        self.shift_before
                        + obs_shift
                        + view_req.shift_arr
                        + (i * view_req.batch_repeat_value)
                    )

                    # handle the case where the inds are out of bounds from the end.
                    # if during the indexing any of the indices are out of bounds, we
                    # need to use padding on the end to fill in the missing indices.
                    # Create padding first time we encounter data
                    if max(inds) < len(d):
                        # Simple case where we can simply pick slices from buffer
                        element_at_t = d[inds]
                    else:
                        # Case in which we have to pad because buffer has
                        # insufficient length. This branch takes more time than
                        # simply picking slices we try to avoid it.
                        element_at_t = _get_buffered_slice_with_paddings(d, inds)
                        element_at_t = np.stack(element_at_t)

                    if element_at_t.shape[0] == 1:
                        # Remove the T dimension if it is 1.
                        element_at_t = element_at_t[0]
                    shifted_data.append(element_at_t)

                # in some multi-agent cases shifted_data may be an empty list.
                # In this case we should just create an empty array and return it.
                if shifted_data:
                    shifted_data_np = np.stack(shifted_data, 0)
                else:
                    shifted_data_np = np.array(shifted_data)
                data.append(shifted_data_np)

            # We unflatten even if data is empty here, because the structure might be
            # nested with empty leafs and so we still need to reconstruct it.
            # This is useful because we spec-check states in RLModules and these
            # states can sometimes be nested dicts with empty leafs.
            batch_data[view_col] = self._unflatten_as_buffer_struct(data, data_col)

        batch = self._get_sample_batch(batch_data)

        # This trajectory is continuing -> Copy data at the end (in the size of
        # self.shift_before) to the beginning of buffers and erase everything
        # else.
        if (
            SampleBatch.TERMINATEDS in self.buffers
            and not self.buffers[SampleBatch.TERMINATEDS][0][-1]
            and SampleBatch.TRUNCATEDS in self.buffers
            and not self.buffers[SampleBatch.TRUNCATEDS][0][-1]
        ):
            # Copy data to beginning of buffer and cut lists.
            if self.shift_before > 0:
                for k, data in self.buffers.items():
                    # Loop through
                    for i in range(len(data)):
                        self.buffers[k][i] = data[i][-self.shift_before :]
            self.agent_steps = 0

        # Reset our unroll_id.
        self.unroll_id = None

        return batch

    def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
        """Builds the buffers for sample collection, given an example data row.

        Args:
            single_row (Dict[str, TensorType]): A single row (keys=column
                names) of data to base the buffers on.
        """
        for col, data in single_row.items():
            if col in self.buffers:
                continue

            # Columns that get their first value at t=-1 (init obs et al.) need
            # one less pre-buffer slot than the other columns.
            shift = self.shift_before - (
                1
                if col
                in [
                    SampleBatch.OBS,
                    SampleBatch.INFOS,
                    SampleBatch.EPS_ID,
                    SampleBatch.AGENT_INDEX,
                    SampleBatch.ENV_ID,
                    SampleBatch.T,
                    SampleBatch.UNROLL_ID,
                ]
                else 0
            )

            # Store all data as flattened lists, except INFOS and state-out
            # lists. These are monolithic items (infos is a dict that
            # should not be further split, same for state-out items, which
            # could be custom dicts as well).
            should_flatten_action_key = (
                col == SampleBatch.ACTIONS and not self.disable_action_flattening
            )
            # Note (Artur) RL Modules's states need no flattening
            should_flatten_state_key = (
                col.startswith("state_out") and not self._enable_new_api_stack
            )
            if (
                col == SampleBatch.INFOS
                or should_flatten_state_key
                or should_flatten_action_key
            ):
                if should_flatten_action_key:
                    data = flatten_to_single_ndarray(data)
                self.buffers[col] = [[data for _ in range(shift)]]
            else:
                self.buffers[col] = [
                    [v for _ in range(shift)] for v in tree.flatten(data)
                ]
                # Store an example data struct so we know, how to unflatten
                # each data col.
                self.buffer_structs[col] = data

    def _get_sample_batch(self, batch_data: Dict[str, TensorType]) -> SampleBatch:
        """Returns a SampleBatch from the given data dictionary. Also updates the
        sequence information based on the max_seq_len."""

        # Due to possible batch-repeats > 1, columns in the resulting batch
        # may not all have the same batch size.
        batch = SampleBatch(batch_data, is_training=self.training)

        # Adjust the seq-lens array depending on the incoming agent sequences.
        if self.is_policy_recurrent:
            seq_lens = []
            max_seq_len = self.max_seq_len
            count = batch.count
            while count > 0:
                seq_lens.append(min(count, max_seq_len))
                count -= max_seq_len
            batch["seq_lens"] = np.array(seq_lens)
            batch.max_seq_len = max_seq_len

        return batch

    def _cache_in_np(self, cache_dict: Dict[str, List[np.ndarray]], key: str) -> None:
        """Caches the numpy version of the key in the buffer dict."""
        if key not in cache_dict:
            cache_dict[key] = [_to_float_np_array(d) for d in self.buffers[key]]

    def _unflatten_as_buffer_struct(
        self, data: List[np.ndarray], key: str
    ) -> np.ndarray:
        """Unflattens the given data to match the buffer struct format for that
        key."""
        if key not in self.buffer_structs:
            return data[0]

        return tree.unflatten_as(self.buffer_structs[key], data)

    def _fill_buffer_with_initial_values(
        self,
        data_col: str,
        view_requirement: ViewRequirement,
        build_for_inference: bool = False,
    ) -> bool:
        """Fills the buffer with the initial values for the given data column.

        For data_col starting with `state_out`, use the initial states of the
        policy, but for other data columns, create a dummy value based on the view
        requirement space.

        Args:
            data_col: The data column to fill the buffer with.
            view_requirement: The view requirement for the view_col. Normally the
                view requirement for the data column is used and if it does not
                exist for some reason the view requirement for view column is used
                instead.
            build_for_inference: Whether this is getting called for inference or
                not.

        Returns:
            is_state: True if the data_col is an RNN state, False otherwise.
        """
        try:
            space = self.view_requirements[data_col].space
        except KeyError:
            space = view_requirement.space

        # special treatment for state_out
        # add them to the buffer in case they don't exist yet
        is_state = True
        if data_col.startswith("state_out"):
            if self._enable_new_api_stack:
                self._build_buffers({data_col: self.initial_states})
            else:
                if not self.is_policy_recurrent:
                    raise ValueError(
                        f"{data_col} is not available, because the given policy "
                        "is not recurrent according to the input "
                        "model_initial_states. Have you forgotten to return "
                        "non-empty lists in policy.get_initial_states()?"
                    )
                state_ind = int(data_col.split("_")[-1])
                self._build_buffers({data_col: self.initial_states[state_ind]})
        else:
            is_state = False
            # only create dummy data during inference
            if build_for_inference:
                if isinstance(space, Space):
                    # state_out assumes the values do not have a batch dimension
                    # (i.e. instead of being (1, d) it is of shape (d,).
                    fill_value = get_dummy_batch_for_space(
                        space,
                        batch_size=0,
                    )
                else:
                    fill_value = space

                self._build_buffers({data_col: fill_value})

        return is_state

    def _prepare_for_data_cols_with_dummy_values(self, data_col):
        """Marks `data_col` as dummy-padded and appends the dummy to its buffers.

        For items gained during inference, we append a dummy value here so
        that view requirements viewing these is not shifted by 1.
        """
        self.data_cols_with_dummy_values.add(data_col)
        for b in self.buffers[data_col]:
            b.append(b[-1])
.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/sample_collector.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from abc import ABCMeta, abstractmethod
3
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union
4
+
5
+ from ray.rllib.policy.policy_map import PolicyMap
6
+ from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
7
+ from ray.rllib.utils.annotations import OldAPIStack
8
+ from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, TensorType
9
+
10
+ if TYPE_CHECKING:
11
+ from ray.rllib.callbacks.callbacks import RLlibCallback
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ # fmt: off
17
+ # __sphinx_doc_begin__
18
+ @OldAPIStack
19
+ class SampleCollector(metaclass=ABCMeta):
20
+ """Collects samples for all policies and agents from a multi-agent env.
21
+
22
+ This API is controlled by RolloutWorker objects to store all data
23
+ generated by Environments and Policies/Models during rollout and
24
+ postprocessing. It's purposes are to a) make data collection and
25
+ SampleBatch/input_dict generation from this data faster, b) to unify
26
+ the way we collect samples from environments and model (outputs), thereby
27
+ allowing for possible user customizations, c) to allow for more complex
28
+ inputs fed into different policies (e.g. multi-agent case with inter-agent
29
+ communication channel).
30
+ """
31
+
32
+ def __init__(self,
33
+ policy_map: PolicyMap,
34
+ clip_rewards: Union[bool, float],
35
+ callbacks: "RLlibCallback",
36
+ multiple_episodes_in_batch: bool = True,
37
+ rollout_fragment_length: int = 200,
38
+ count_steps_by: str = "env_steps"):
39
+ """Initializes a SampleCollector instance.
40
+
41
+ Args:
42
+ policy_map: Maps policy ids to policy instances.
43
+ clip_rewards (Union[bool, float]): Whether to clip rewards before
44
+ postprocessing (at +/-1.0) or the actual value to +/- clip.
45
+ callbacks: RLlib callbacks.
46
+ multiple_episodes_in_batch: Whether it's allowed to pack
47
+ multiple episodes into the same built batch.
48
+ rollout_fragment_length: The
49
+
50
+ """
51
+
52
+ self.policy_map = policy_map
53
+ self.clip_rewards = clip_rewards
54
+ self.callbacks = callbacks
55
+ self.multiple_episodes_in_batch = multiple_episodes_in_batch
56
+ self.rollout_fragment_length = rollout_fragment_length
57
+ self.count_steps_by = count_steps_by
58
+
59
+ @abstractmethod
60
+ def add_init_obs(
61
+ self,
62
+ *,
63
+ episode,
64
+ agent_id: AgentID,
65
+ policy_id: PolicyID,
66
+ init_obs: TensorType,
67
+ init_infos: Optional[Dict[str, TensorType]] = None,
68
+ t: int = -1,
69
+ ) -> None:
70
+ """Adds an initial obs (after reset) to this collector.
71
+
72
+ Since the very first observation in an environment is collected w/o
73
+ additional data (w/o actions, w/o reward) after env.reset() is called,
74
+ this method initializes a new trajectory for a given agent.
75
+ `add_init_obs()` has to be called first for each agent/episode-ID
76
+ combination. After this, only `add_action_reward_next_obs()` must be
77
+ called for that same agent/episode-pair.
78
+
79
+ Args:
80
+ episode: The Episode, for which we
81
+ are adding an Agent's initial observation.
82
+ agent_id: Unique id for the agent we are adding
83
+ values for.
84
+ env_id: The environment index (in a vectorized setup).
85
+ policy_id: Unique id for policy controlling the agent.
86
+ init_obs: Initial observation (after env.reset()).
87
+ init_obs: Initial observation (after env.reset()).
88
+ init_infos: Initial infos dict (after env.reset()).
89
+ t: The time step (episode length - 1). The initial obs has
90
+ ts=-1(!), then an action/reward/next-obs at t=0, etc..
91
+
92
+ .. testcode::
93
+ :skipif: True
94
+
95
+ obs, infos = env.reset()
96
+ collector.add_init_obs(
97
+ episode=my_episode,
98
+ agent_id=0,
99
+ policy_id="pol0",
100
+ t=-1,
101
+ init_obs=obs,
102
+ init_infos=infos,
103
+ )
104
+ obs, r, terminated, truncated, info = env.step(action)
105
+ collector.add_action_reward_next_obs(12345, 0, "pol0", False, {
106
+ "action": action, "obs": obs, "reward": r, "terminated": terminated,
107
+ "truncated": truncated, "info": info
108
+ })
109
+ """
110
+ raise NotImplementedError
111
+
112
+ @abstractmethod
113
+ def add_action_reward_next_obs(
114
+ self,
115
+ episode_id: EpisodeID,
116
+ agent_id: AgentID,
117
+ env_id: EnvID,
118
+ policy_id: PolicyID,
119
+ agent_done: bool,
120
+ values: Dict[str, TensorType],
121
+ ) -> None:
122
+ """Add the given dictionary (row) of values to this collector.
123
+
124
+ The incoming data (`values`) must include action, reward, terminated, truncated,
125
+ and next_obs information and may include any other information.
126
+ For the initial observation (after Env.reset()) of the given agent/
127
+ episode-ID combination, `add_initial_obs()` must be called instead.
128
+
129
+ Args:
130
+ episode_id: Unique id for the episode we are adding
131
+ values for.
132
+ agent_id: Unique id for the agent we are adding
133
+ values for.
134
+ env_id: The environment index (in a vectorized setup).
135
+ policy_id: Unique id for policy controlling the agent.
136
+ agent_done: Whether the given agent is done (terminated or truncated) with
137
+ its trajectory (the multi-agent episode may still be ongoing).
138
+ values (Dict[str, TensorType]): Row of values to add for this
139
+ agent. This row must contain the keys SampleBatch.ACTION,
140
+ REWARD, NEW_OBS, TERMINATED, and TRUNCATED.
141
+
142
+ .. testcode::
143
+ :skipif: True
144
+
145
+ obs, info = env.reset()
146
+ collector.add_init_obs(12345, 0, "pol0", obs)
147
+ obs, r, terminated, truncated, info = env.step(action)
148
+ collector.add_action_reward_next_obs(
149
+ 12345,
150
+ 0,
151
+ "pol0",
152
+ agent_done=False,
153
+ values={
154
+ "action": action, "obs": obs, "reward": r,
155
+ "terminated": terminated, "truncated": truncated
156
+ },
157
+ )
158
+ """
159
+ raise NotImplementedError
160
+
161
+ @abstractmethod
162
+ def episode_step(self, episode) -> None:
163
+ """Increases the episode step counter (across all agents) by one.
164
+
165
+ Args:
166
+ episode: Episode we are stepping through.
167
+ Useful for handling counting b/c it is called once across
168
+ all agents that are inside this episode.
169
+ """
170
+ raise NotImplementedError
171
+
172
+ @abstractmethod
173
+ def total_env_steps(self) -> int:
174
+ """Returns total number of env-steps taken so far.
175
+
176
+ Thereby, a step in an N-agent multi-agent environment counts as only 1
177
+ for this metric. The returned count contains everything that has not
178
+ been built yet (and returned as MultiAgentBatches by the
179
+ `try_build_truncated_episode_multi_agent_batch` or
180
+ `postprocess_episode(build=True)` methods). After such build, this
181
+ counter is reset to 0.
182
+
183
+ Returns:
184
+ int: The number of env-steps taken in total in the environment(s)
185
+ so far.
186
+ """
187
+ raise NotImplementedError
188
+
189
+ @abstractmethod
190
+ def total_agent_steps(self) -> int:
191
+ """Returns total number of (individual) agent-steps taken so far.
192
+
193
+ Thereby, a step in an N-agent multi-agent environment counts as N.
194
+ If less than N agents have stepped (because some agents were not
195
+ required to send actions), the count will be increased by less than N.
196
+ The returned count contains everything that has not been built yet
197
+ (and returned as MultiAgentBatches by the
198
+ `try_build_truncated_episode_multi_agent_batch` or
199
+ `postprocess_episode(build=True)` methods). After such build, this
200
+ counter is reset to 0.
201
+
202
+ Returns:
203
+ int: The number of (individual) agent-steps taken in total in the
204
+ environment(s) so far.
205
+ """
206
+ raise NotImplementedError
207
+
208
+ # TODO(jungong) : Remove this API call once we completely move to
209
+ # connector based sample collection.
210
+ @abstractmethod
211
+ def get_inference_input_dict(self, policy_id: PolicyID) -> \
212
+ Dict[str, TensorType]:
213
+ """Returns an input_dict for an (inference) forward pass from our data.
214
+
215
+ The input_dict can then be used for action computations inside a
216
+ Policy via `Policy.compute_actions_from_input_dict()`.
217
+
218
+ Args:
219
+ policy_id: The Policy ID to get the input dict for.
220
+
221
+ Returns:
222
+ Dict[str, TensorType]: The input_dict to be passed into the ModelV2
223
+ for inference/training.
224
+
225
+ .. testcode::
226
+ :skipif: True
227
+
228
+ obs, r, terminated, truncated, info = env.step(action)
229
+ collector.add_action_reward_next_obs(12345, 0, "pol0", False, {
230
+ "action": action, "obs": obs, "reward": r,
231
+ "terminated": terminated, "truncated", truncated
232
+ })
233
+ input_dict = collector.get_inference_input_dict(policy.model)
234
+ action = policy.compute_actions_from_input_dict(input_dict)
235
+ # repeat
236
+ """
237
+ raise NotImplementedError
238
+
239
+ @abstractmethod
240
+ def postprocess_episode(
241
+ self,
242
+ episode,
243
+ is_done: bool = False,
244
+ check_dones: bool = False,
245
+ build: bool = False,
246
+ ) -> Optional[MultiAgentBatch]:
247
+ """Postprocesses all agents' trajectories in a given episode.
248
+
249
+ Generates (single-trajectory) SampleBatches for all Policies/Agents and
250
+ calls Policy.postprocess_trajectory on each of these. Postprocessing
251
+ may happens in-place, meaning any changes to the viewed data columns
252
+ are directly reflected inside this collector's buffers.
253
+ Also makes sure that additional (newly created) data columns are
254
+ correctly added to the buffers.
255
+
256
+ Args:
257
+ episode: The Episode object for which
258
+ to post-process data.
259
+ is_done: Whether the given episode is actually terminated
260
+ (all agents are terminated OR truncated). If True, the
261
+ episode will no longer be used/continued and we may need to
262
+ recycle/erase it internally. If a soft-horizon is hit, the
263
+ episode will continue to be used and `is_done` should be set
264
+ to False here.
265
+ check_dones: Whether we need to check that all agents'
266
+ trajectories have dones=True at the end.
267
+ build: Whether to build a MultiAgentBatch from the given
268
+ episode (and only that episode!) and return that
269
+ MultiAgentBatch. Used for batch_mode=`complete_episodes`.
270
+
271
+ Returns:
272
+ Optional[MultiAgentBatch]: If `build` is True, the
273
+ SampleBatch or MultiAgentBatch built from `episode` (either
274
+ just from that episde or from the `_PolicyCollectorGroup`
275
+ in the `episode.batch_builder` property).
276
+ """
277
+ raise NotImplementedError
278
+
279
+ @abstractmethod
280
+ def try_build_truncated_episode_multi_agent_batch(self) -> \
281
+ List[Union[MultiAgentBatch, SampleBatch]]:
282
+ """Tries to build an MA-batch, if `rollout_fragment_length` is reached.
283
+
284
+ Any unprocessed data will be first postprocessed with a policy
285
+ postprocessor.
286
+ This is usually called to collect samples for policy training.
287
+ If not enough data has been collected yet (`rollout_fragment_length`),
288
+ returns an empty list.
289
+
290
+ Returns:
291
+ List[Union[MultiAgentBatch, SampleBatch]]: Returns a (possibly
292
+ empty) list of MultiAgentBatches (containing the accumulated
293
+ SampleBatches for each policy or a simple SampleBatch if only
294
+ one policy). The list will be empty if
295
+ `self.rollout_fragment_length` has not been reached yet.
296
+ """
297
+ raise NotImplementedError
298
+ # __sphinx_doc_end__
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/complex_input_net.cpython-311.pyc ADDED
Binary file (9.81 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/torch_modelv2.cpython-311.pyc ADDED
Binary file (4.45 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.models.torch.modules.gru_gate import GRUGate
2
+ from ray.rllib.models.torch.modules.multi_head_attention import MultiHeadAttention
3
+ from ray.rllib.models.torch.modules.relative_multi_head_attention import (
4
+ RelativeMultiHeadAttention,
5
+ )
6
+ from ray.rllib.models.torch.modules.skip_connection import SkipConnection
7
+
8
+ __all__ = [
9
+ "GRUGate",
10
+ "RelativeMultiHeadAttention",
11
+ "SkipConnection",
12
+ "MultiHeadAttention",
13
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (695 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/gru_gate.cpython-311.pyc ADDED
Binary file (4.93 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/multi_head_attention.cpython-311.pyc ADDED
Binary file (4.13 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/noisy_layer.cpython-311.pyc ADDED
Binary file (6.18 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/relative_multi_head_attention.cpython-311.pyc ADDED
Binary file (9.77 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/__pycache__/skip_connection.cpython-311.pyc ADDED
Binary file (2.23 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/models/torch/modules/skip_connection.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.utils.annotations import OldAPIStack
2
+ from ray.rllib.utils.framework import try_import_torch
3
+ from ray.rllib.utils.typing import TensorType
4
+ from typing import Optional
5
+
6
+ torch, nn = try_import_torch()
7
+
8
+
9
+ @OldAPIStack
10
+ class SkipConnection(nn.Module):
11
+ """Skip connection layer.
12
+
13
+ Adds the original input to the output (regular residual layer) OR uses
14
+ input as hidden state input to a given fan_in_layer.
15
+ """
16
+
17
+ def __init__(
18
+ self, layer: nn.Module, fan_in_layer: Optional[nn.Module] = None, **kwargs
19
+ ):
20
+ """Initializes a SkipConnection nn Module object.
21
+
22
+ Args:
23
+ layer (nn.Module): Any layer processing inputs.
24
+ fan_in_layer (Optional[nn.Module]): An optional
25
+ layer taking two inputs: The original input and the output
26
+ of `layer`.
27
+ """
28
+ super().__init__(**kwargs)
29
+ self._layer = layer
30
+ self._fan_in_layer = fan_in_layer
31
+
32
+ def forward(self, inputs: TensorType, **kwargs) -> TensorType:
33
+ # del kwargs
34
+ outputs = self._layer(inputs, **kwargs)
35
+ # Residual case, just add inputs to outputs.
36
+ if self._fan_in_layer is None:
37
+ outputs = outputs + inputs
38
+ # Fan-in e.g. RNN: Call fan-in with `inputs` and `outputs`.
39
+ else:
40
+ # NOTE: In the GRU case, `inputs` is the state input.
41
+ outputs = self._fan_in_layer((inputs, outputs))
42
+
43
+ return outputs
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.37 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/d4rl_reader.cpython-311.pyc ADDED
Binary file (3.08 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/input_reader.cpython-311.pyc ADDED
Binary file (8.95 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/resource.cpython-311.pyc ADDED
Binary file (1.49 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/shuffled_input.cpython-311.pyc ADDED
Binary file (2.89 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
2
+ from ray.rllib.offline.estimators.weighted_importance_sampling import (
3
+ WeightedImportanceSampling,
4
+ )
5
+ from ray.rllib.offline.estimators.direct_method import DirectMethod
6
+ from ray.rllib.offline.estimators.doubly_robust import DoublyRobust
7
+ from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
8
+
9
+ __all__ = [
10
+ "OffPolicyEstimator",
11
+ "ImportanceSampling",
12
+ "WeightedImportanceSampling",
13
+ "DirectMethod",
14
+ "DoublyRobust",
15
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/doubly_robust.cpython-311.pyc ADDED
Binary file (11.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/feature_importance.cpython-311.pyc ADDED
Binary file (618 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/importance_sampling.cpython-311.pyc ADDED
Binary file (6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/weighted_importance_sampling.cpython-311.pyc ADDED
Binary file (9.14 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/direct_method.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Dict, Any, Optional, List
3
+ import math
4
+ import numpy as np
5
+
6
+ from ray.data import Dataset
7
+
8
+ from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
9
+ from ray.rllib.offline.offline_evaluation_utils import compute_q_and_v_values
10
+ from ray.rllib.offline.offline_evaluator import OfflineEvaluator
11
+ from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
12
+ from ray.rllib.policy import Policy
13
+ from ray.rllib.policy.sample_batch import convert_ma_batch_to_sample_batch
14
+ from ray.rllib.policy.sample_batch import SampleBatch
15
+ from ray.rllib.utils.annotations import DeveloperAPI, override
16
+ from ray.rllib.utils.typing import SampleBatchType
17
+ from ray.rllib.utils.numpy import convert_to_numpy
18
+
19
+ logger = logging.getLogger()
20
+
21
+
22
+ @DeveloperAPI
23
+ class DirectMethod(OffPolicyEstimator):
24
+ r"""The Direct Method estimator.
25
+
26
+ Let s_t, a_t, and r_t be the state, action, and reward at timestep t.
27
+
28
+ This method trains a Q-model for the evaluation policy \pi_e on behavior
29
+ data generated by \pi_b. Currently, RLlib implements this using
30
+ Fitted-Q Evaluation (FQE). You can also implement your own model
31
+ and pass it in as `q_model_config = {"type": your_model_class, **your_kwargs}`.
32
+
33
+ This estimator computes the expected return for \pi_e for an episode as:
34
+ V^{\pi_e}(s_0) = \sum_{a \in A} \pi_e(a | s_0) Q(s_0, a)
35
+ and returns the mean and standard deviation over episodes.
36
+
37
+ For more information refer to https://arxiv.org/pdf/1911.06854.pdf"""
38
+
39
+ @override(OffPolicyEstimator)
40
+ def __init__(
41
+ self,
42
+ policy: Policy,
43
+ gamma: float,
44
+ epsilon_greedy: float = 0.0,
45
+ q_model_config: Optional[Dict] = None,
46
+ ):
47
+ """Initializes a Direct Method OPE Estimator.
48
+
49
+ Args:
50
+ policy: Policy to evaluate.
51
+ gamma: Discount factor of the environment.
52
+ epsilon_greedy: The probability by which we act acording to a fully random
53
+ policy during deployment. With 1-epsilon_greedy we act according the
54
+ target policy.
55
+ q_model_config: Arguments to specify the Q-model. Must specify
56
+ a `type` key pointing to the Q-model class.
57
+ This Q-model is trained in the train() method and is used
58
+ to compute the state-value estimates for the DirectMethod estimator.
59
+ It must implement `train` and `estimate_v`.
60
+ TODO (Rohan138): Unify this with RLModule API.
61
+ """
62
+
63
+ super().__init__(policy, gamma, epsilon_greedy)
64
+
65
+ # Some dummy policies and ones that are not based on a tensor framework
66
+ # backend can come without a config or without a framework key.
67
+ if hasattr(policy, "config"):
68
+ assert (
69
+ policy.config.get("framework", "torch") == "torch"
70
+ ), "Framework must be torch to use DirectMethod."
71
+
72
+ q_model_config = q_model_config or {}
73
+ model_cls = q_model_config.pop("type", FQETorchModel)
74
+ self.model = model_cls(
75
+ policy=policy,
76
+ gamma=gamma,
77
+ **q_model_config,
78
+ )
79
+ assert hasattr(
80
+ self.model, "estimate_v"
81
+ ), "self.model must implement `estimate_v`!"
82
+
83
+ @override(OffPolicyEstimator)
84
+ def estimate_on_single_episode(self, episode: SampleBatch) -> Dict[str, Any]:
85
+ estimates_per_epsiode = {}
86
+ rewards = episode["rewards"]
87
+
88
+ v_behavior = 0.0
89
+ for t in range(episode.count):
90
+ v_behavior += rewards[t] * self.gamma**t
91
+
92
+ v_target = self._compute_v_target(episode[:1])
93
+
94
+ estimates_per_epsiode["v_behavior"] = v_behavior
95
+ estimates_per_epsiode["v_target"] = v_target
96
+
97
+ return estimates_per_epsiode
98
+
99
+ @override(OffPolicyEstimator)
100
+ def estimate_on_single_step_samples(
101
+ self, batch: SampleBatch
102
+ ) -> Dict[str, List[float]]:
103
+ estimates_per_epsiode = {}
104
+ rewards = batch["rewards"]
105
+
106
+ v_behavior = rewards
107
+ v_target = self._compute_v_target(batch)
108
+
109
+ estimates_per_epsiode["v_behavior"] = v_behavior
110
+ estimates_per_epsiode["v_target"] = v_target
111
+
112
+ return estimates_per_epsiode
113
+
114
+ def _compute_v_target(self, init_step):
115
+ v_target = self.model.estimate_v(init_step)
116
+ v_target = convert_to_numpy(v_target)
117
+ return v_target
118
+
119
+ @override(OffPolicyEstimator)
120
+ def train(self, batch: SampleBatchType) -> Dict[str, Any]:
121
+ """Trains self.model on the given batch.
122
+
123
+ Args:
124
+ batch: A SampleBatchType to train on
125
+
126
+ Returns:
127
+ A dict with key "loss" and value as the mean training loss.
128
+ """
129
+ batch = convert_ma_batch_to_sample_batch(batch)
130
+ losses = self.model.train(batch)
131
+ return {"loss": np.mean(losses)}
132
+
133
+ @override(OfflineEvaluator)
134
+ def estimate_on_dataset(
135
+ self, dataset: Dataset, *, n_parallelism: int = ...
136
+ ) -> Dict[str, Any]:
137
+ """Calculates the Direct Method estimate on the given dataset.
138
+
139
+ Note: This estimate works for only discrete action spaces for now.
140
+
141
+ Args:
142
+ dataset: Dataset to compute the estimate on. Each record in dataset should
143
+ include the following columns: `obs`, `actions`, `action_prob` and
144
+ `rewards`. The `obs` on each row shoud be a vector of D dimensions.
145
+ n_parallelism: The number of parallel workers to use.
146
+
147
+ Returns:
148
+ Dictionary with the following keys:
149
+ v_target: The estimated value of the target policy.
150
+ v_behavior: The estimated value of the behavior policy.
151
+ v_gain: The estimated gain of the target policy over the behavior
152
+ policy.
153
+ v_std: The standard deviation of the estimated value of the target.
154
+ """
155
+ # compute v_values
156
+ batch_size = max(dataset.count() // n_parallelism, 1)
157
+ updated_ds = dataset.map_batches(
158
+ compute_q_and_v_values,
159
+ batch_size=batch_size,
160
+ batch_format="pandas",
161
+ fn_kwargs={
162
+ "model_class": self.model.__class__,
163
+ "model_state": self.model.get_state(),
164
+ "compute_q_values": False,
165
+ },
166
+ )
167
+
168
+ v_behavior = updated_ds.mean("rewards")
169
+ v_target = updated_ds.mean("v_values")
170
+ v_gain_mean = v_target / v_behavior
171
+ v_gain_ste = (
172
+ updated_ds.std("v_values") / v_behavior / math.sqrt(dataset.count())
173
+ )
174
+
175
+ return {
176
+ "v_behavior": v_behavior,
177
+ "v_target": v_target,
178
+ "v_gain_mean": v_gain_mean,
179
+ "v_gain_ste": v_gain_ste,
180
+ }
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/doubly_robust.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import numpy as np
3
+ import math
4
+ import pandas as pd
5
+
6
+ from typing import Dict, Any, Optional, List
7
+
8
+ from ray.data import Dataset
9
+
10
+ from ray.rllib.policy import Policy
11
+ from ray.rllib.policy.sample_batch import SampleBatch, convert_ma_batch_to_sample_batch
12
+ from ray.rllib.utils.annotations import DeveloperAPI, override
13
+ from ray.rllib.utils.typing import SampleBatchType
14
+ from ray.rllib.utils.numpy import convert_to_numpy
15
+
16
+ from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
17
+ from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
18
+ from ray.rllib.offline.offline_evaluator import OfflineEvaluator
19
+ from ray.rllib.offline.offline_evaluation_utils import (
20
+ compute_is_weights,
21
+ compute_q_and_v_values,
22
+ )
23
+
24
+ logger = logging.getLogger()
25
+
26
+
27
+ @DeveloperAPI
28
+ class DoublyRobust(OffPolicyEstimator):
29
+ """The Doubly Robust estimator.
30
+
31
+ Let s_t, a_t, and r_t be the state, action, and reward at timestep t.
32
+
33
+ This method trains a Q-model for the evaluation policy \pi_e on behavior
34
+ data generated by \pi_b. Currently, RLlib implements this using
35
+ Fitted-Q Evaluation (FQE). You can also implement your own model
36
+ and pass it in as `q_model_config = {"type": your_model_class, **your_kwargs}`.
37
+
38
+ For behavior policy \pi_b and evaluation policy \pi_e, define the
39
+ cumulative importance ratio at timestep t as:
40
+ p_t = \sum_{t'=0}^t (\pi_e(a_{t'} | s_{t'}) / \pi_b(a_{t'} | s_{t'})).
41
+
42
+ Consider an episode with length T. Let V_T = 0.
43
+ For all t in {0, T - 1}, use the following recursive update:
44
+ V_t^DR = (\sum_{a \in A} \pi_e(a | s_t) Q(s_t, a))
45
+ + p_t * (r_t + \gamma * V_{t+1}^DR - Q(s_t, a_t))
46
+
47
+ This estimator computes the expected return for \pi_e for an episode as:
48
+ V^{\pi_e}(s_0) = V_0^DR
49
+ and returns the mean and standard deviation over episodes.
50
+
51
+ For more information refer to https://arxiv.org/pdf/1911.06854.pdf"""
52
+
53
+ @override(OffPolicyEstimator)
54
+ def __init__(
55
+ self,
56
+ policy: Policy,
57
+ gamma: float,
58
+ epsilon_greedy: float = 0.0,
59
+ normalize_weights: bool = True,
60
+ q_model_config: Optional[Dict] = None,
61
+ ):
62
+ """Initializes a Doubly Robust OPE Estimator.
63
+
64
+ Args:
65
+ policy: Policy to evaluate.
66
+ gamma: Discount factor of the environment.
67
+ epsilon_greedy: The probability by which we act acording to a fully random
68
+ policy during deployment. With 1-epsilon_greedy we act
69
+ according the target policy.
70
+ normalize_weights: If True, the inverse propensity scores are normalized to
71
+ their sum across the entire dataset. The effect of this is similar to
72
+ weighted importance sampling compared to standard importance sampling.
73
+ q_model_config: Arguments to specify the Q-model. Must specify
74
+ a `type` key pointing to the Q-model class.
75
+ This Q-model is trained in the train() method and is used
76
+ to compute the state-value and Q-value estimates
77
+ for the DoublyRobust estimator.
78
+ It must implement `train`, `estimate_q`, and `estimate_v`.
79
+ TODO (Rohan138): Unify this with RLModule API.
80
+ """
81
+
82
+ super().__init__(policy, gamma, epsilon_greedy)
83
+ q_model_config = q_model_config or {}
84
+ q_model_config["gamma"] = gamma
85
+
86
+ self._model_cls = q_model_config.pop("type", FQETorchModel)
87
+ self._model_configs = q_model_config
88
+ self._normalize_weights = normalize_weights
89
+
90
+ self.model = self._model_cls(
91
+ policy=policy,
92
+ **q_model_config,
93
+ )
94
+ assert hasattr(
95
+ self.model, "estimate_v"
96
+ ), "self.model must implement `estimate_v`!"
97
+ assert hasattr(
98
+ self.model, "estimate_q"
99
+ ), "self.model must implement `estimate_q`!"
100
+
101
+ @override(OffPolicyEstimator)
102
+ def estimate_on_single_episode(self, episode: SampleBatch) -> Dict[str, Any]:
103
+ estimates_per_epsiode = {}
104
+
105
+ rewards, old_prob = episode["rewards"], episode["action_prob"]
106
+ new_prob = self.compute_action_probs(episode)
107
+
108
+ weight = new_prob / old_prob
109
+
110
+ v_behavior = 0.0
111
+ v_target = 0.0
112
+ q_values = self.model.estimate_q(episode)
113
+ q_values = convert_to_numpy(q_values)
114
+ v_values = self.model.estimate_v(episode)
115
+ v_values = convert_to_numpy(v_values)
116
+ assert q_values.shape == v_values.shape == (episode.count,)
117
+
118
+ for t in reversed(range(episode.count)):
119
+ v_behavior = rewards[t] + self.gamma * v_behavior
120
+ v_target = v_values[t] + weight[t] * (
121
+ rewards[t] + self.gamma * v_target - q_values[t]
122
+ )
123
+ v_target = v_target.item()
124
+
125
+ estimates_per_epsiode["v_behavior"] = v_behavior
126
+ estimates_per_epsiode["v_target"] = v_target
127
+
128
+ return estimates_per_epsiode
129
+
130
+ @override(OffPolicyEstimator)
131
+ def estimate_on_single_step_samples(
132
+ self, batch: SampleBatch
133
+ ) -> Dict[str, List[float]]:
134
+ estimates_per_epsiode = {}
135
+
136
+ rewards, old_prob = batch["rewards"], batch["action_prob"]
137
+ new_prob = self.compute_action_probs(batch)
138
+
139
+ q_values = self.model.estimate_q(batch)
140
+ q_values = convert_to_numpy(q_values)
141
+ v_values = self.model.estimate_v(batch)
142
+ v_values = convert_to_numpy(v_values)
143
+
144
+ v_behavior = rewards
145
+
146
+ weight = new_prob / old_prob
147
+ v_target = v_values + weight * (rewards - q_values)
148
+
149
+ estimates_per_epsiode["v_behavior"] = v_behavior
150
+ estimates_per_epsiode["v_target"] = v_target
151
+
152
+ return estimates_per_epsiode
153
+
154
+ @override(OffPolicyEstimator)
155
+ def train(self, batch: SampleBatchType) -> Dict[str, Any]:
156
+ """Trains self.model on the given batch.
157
+
158
+ Args:
159
+ batch: A SampleBatch or MultiAgentbatch to train on
160
+
161
+ Returns:
162
+ A dict with key "loss" and value as the mean training loss.
163
+ """
164
+ batch = convert_ma_batch_to_sample_batch(batch)
165
+ losses = self.model.train(batch)
166
+ return {"loss": np.mean(losses)}
167
+
168
+ @override(OfflineEvaluator)
169
+ def estimate_on_dataset(
170
+ self, dataset: Dataset, *, n_parallelism: int = ...
171
+ ) -> Dict[str, Any]:
172
+ """Estimates the policy value using the Doubly Robust estimator.
173
+
174
+ The doubly robust estimator uses normalization of importance sampling weights
175
+ (aka. propensity ratios) to the average of the importance weights across the
176
+ entire dataset. This is done to reduce the variance of the estimate (similar to
177
+ weighted importance sampling). You can disable this by setting
178
+ `normalize_weights=False` in the constructor.
179
+
180
+ Note: This estimate works for only discrete action spaces for now.
181
+
182
+ Args:
183
+ dataset: Dataset to compute the estimate on. Each record in dataset should
184
+ include the following columns: `obs`, `actions`, `action_prob` and
185
+ `rewards`. The `obs` on each row shoud be a vector of D dimensions.
186
+ n_parallelism: Number of parallelism to use for the computation.
187
+
188
+ Returns:
189
+ A dict with the following keys:
190
+ v_target: The estimated value of the target policy.
191
+ v_behavior: The estimated value of the behavior policy.
192
+ v_gain: The estimated gain of the target policy over the behavior
193
+ policy.
194
+ v_std: The standard deviation of the estimated value of the target.
195
+ """
196
+
197
+ # step 1: compute the weights and weighted rewards
198
+ batch_size = max(dataset.count() // n_parallelism, 1)
199
+ updated_ds = dataset.map_batches(
200
+ compute_is_weights,
201
+ batch_size=batch_size,
202
+ batch_format="pandas",
203
+ fn_kwargs={
204
+ "policy_state": self.policy.get_state(),
205
+ "estimator_class": self.__class__,
206
+ },
207
+ )
208
+
209
+ # step 2: compute q_values and v_values
210
+ batch_size = max(updated_ds.count() // n_parallelism, 1)
211
+ updated_ds = updated_ds.map_batches(
212
+ compute_q_and_v_values,
213
+ batch_size=batch_size,
214
+ batch_format="pandas",
215
+ fn_kwargs={
216
+ "model_class": self.model.__class__,
217
+ "model_state": self.model.get_state(),
218
+ },
219
+ )
220
+
221
+ # step 3: compute the v_target
222
+ def compute_v_target(batch: pd.DataFrame, normalizer: float = 1.0):
223
+ weights = batch["weights"] / normalizer
224
+ batch["v_target"] = batch["v_values"] + weights * (
225
+ batch["rewards"] - batch["q_values"]
226
+ )
227
+ batch["v_behavior"] = batch["rewards"]
228
+ return batch
229
+
230
+ normalizer = updated_ds.mean("weights") if self._normalize_weights else 1.0
231
+ updated_ds = updated_ds.map_batches(
232
+ compute_v_target,
233
+ batch_size=batch_size,
234
+ batch_format="pandas",
235
+ fn_kwargs={"normalizer": normalizer},
236
+ )
237
+
238
+ v_behavior = updated_ds.mean("v_behavior")
239
+ v_target = updated_ds.mean("v_target")
240
+ v_gain_mean = v_target / v_behavior
241
+ v_gain_ste = (
242
+ updated_ds.std("v_target")
243
+ / normalizer
244
+ / v_behavior
245
+ / math.sqrt(dataset.count())
246
+ )
247
+
248
+ return {
249
+ "v_behavior": v_behavior,
250
+ "v_target": v_target,
251
+ "v_gain_mean": v_gain_mean,
252
+ "v_gain_ste": v_gain_ste,
253
+ }
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/fqe_torch_model.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ from ray.rllib.models.utils import get_initializer
3
+ from ray.rllib.policy import Policy
4
+
5
+ from ray.rllib.models.catalog import ModelCatalog
6
+ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
7
+ from ray.rllib.policy.sample_batch import SampleBatch
8
+ from ray.rllib.utils.annotations import DeveloperAPI
9
+ from ray.rllib.utils.framework import try_import_torch
10
+ from ray.rllib.utils.annotations import is_overridden
11
+ from ray.rllib.utils.typing import ModelConfigDict, TensorType
12
+ from gymnasium.spaces import Discrete
13
+
14
+ torch, nn = try_import_torch()
15
+
16
+ # TODO: Create a config object for FQE and unify it with the RLModule API
17
+
18
+
19
@DeveloperAPI
class FQETorchModel:
    """PyTorch implementation of Fitted Q-Evaluation (FQE).

    FQE learns a Q-function for the policy under evaluation from off-policy
    data by repeatedly regressing Q(s, a) onto Bellman targets computed with
    a (soft-updated) target network.
    Reference: https://arxiv.org/abs/1911.06854
    """

    def __init__(
        self,
        policy: Policy,
        gamma: float,
        model_config: ModelConfigDict = None,
        n_iters: int = 1,
        lr: float = 1e-3,
        min_loss_threshold: float = 1e-4,
        clip_grad_norm: float = 100.0,
        minibatch_size: int = None,
        polyak_coef: float = 1.0,
    ) -> None:
        """Initializes an FQETorchModel.

        Args:
            policy: Policy to evaluate; must have a discrete action space.
            gamma: Discount factor of the environment.
            model_config: The ModelConfigDict for self.q_model; defaults to a
                3x32 fully-connected ReLU net (see body).
            n_iters: Number of gradient epochs to run per `train()` call.
            lr: Learning rate for the Adam optimizer.
            min_loss_threshold: Stop training early once the mean epoch loss
                drops below this value.
            clip_grad_norm: Clip gradients to this maximum global norm.
            minibatch_size: Minibatch size for training the Q-function;
                if None, train on the whole batch at once.
            polyak_coef: Polyak averaging factor for target Q-network updates
                (1.0 = hard copy).
        """
        self.policy = policy
        assert isinstance(
            policy.action_space, Discrete
        ), f"{self.__class__.__name__} only supports discrete action spaces!"
        self.gamma = gamma
        self.observation_space = policy.observation_space
        self.action_space = policy.action_space

        if model_config is None:
            model_config = {
                "fcnet_hiddens": [32, 32, 32],
                "fcnet_activation": "relu",
                "vf_share_layers": True,
            }
        self.model_config = model_config

        # Build online and target Q-networks on the policy's device; each
        # outputs one Q-value per discrete action.
        self.device = self.policy.device
        self.q_model: TorchModelV2 = ModelCatalog.get_model_v2(
            self.observation_space,
            self.action_space,
            self.action_space.n,
            model_config,
            framework="torch",
            name="TorchQModel",
        ).to(self.device)

        self.target_q_model: TorchModelV2 = ModelCatalog.get_model_v2(
            self.observation_space,
            self.action_space,
            self.action_space.n,
            model_config,
            framework="torch",
            name="TargetTorchQModel",
        ).to(self.device)

        self.n_iters = n_iters
        self.lr = lr
        self.min_loss_threshold = min_loss_threshold
        self.clip_grad_norm = clip_grad_norm
        self.minibatch_size = minibatch_size
        self.polyak_coef = polyak_coef
        self.optimizer = torch.optim.Adam(self.q_model.variables(), self.lr)
        initializer = get_initializer("xavier_uniform", framework="torch")
        # Hard-copy the online weights into the target network once.
        self.update_target(polyak_coef=1.0)

        def f(m):
            if isinstance(m, nn.Linear):
                initializer(m.weight)

        # NOTE(review): `self.initializer` is stored here but never applied to
        # the models within this class — confirm callers are expected to use it.
        self.initializer = f

    def train(self, batch: SampleBatch) -> TensorType:
        """Trains self.q_model using the FQE loss on the given batch.

        Args:
            batch: A SampleBatch of episodes to train on.

        Returns:
            A list of mean losses, one entry per training iteration.
        """
        losses = []
        minibatch_size = self.minibatch_size or batch.count
        # Shallow-copy so shuffling does not reorder the caller's batch.
        batch = batch.copy(shallow=True)
        for _ in range(self.n_iters):
            minibatch_losses = []
            batch.shuffle()
            for idx in range(0, batch.count, minibatch_size):
                minibatch = batch[idx : idx + minibatch_size]
                obs = torch.tensor(minibatch[SampleBatch.OBS], device=self.device)
                actions = torch.tensor(
                    minibatch[SampleBatch.ACTIONS],
                    device=self.device,
                    dtype=int,
                )
                rewards = torch.tensor(
                    minibatch[SampleBatch.REWARDS], device=self.device
                )
                next_obs = torch.tensor(
                    minibatch[SampleBatch.NEXT_OBS], device=self.device
                )
                dones = torch.tensor(
                    minibatch[SampleBatch.TERMINATEDS], device=self.device, dtype=float
                )

                # Q-values of the actions actually taken in the batch.
                q_values, _ = self.q_model({"obs": obs}, [], None)
                q_acts = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze(-1)

                next_action_probs = self._compute_action_probs(next_obs)

                # Bootstrap targets come from the frozen target network.
                with torch.no_grad():
                    next_q_values, _ = self.target_q_model({"obs": next_obs}, [], None)

                # Estimated state value next_v = E_{a ~ pi(s)} [Q(next_obs, a)]
                next_v = torch.sum(next_q_values * next_action_probs, axis=-1)
                targets = rewards + (1 - dones) * self.gamma * next_v
                loss = (targets - q_acts) ** 2
                loss = torch.mean(loss)
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad.clip_grad_norm_(
                    self.q_model.variables(), self.clip_grad_norm
                )
                self.optimizer.step()
                minibatch_losses.append(loss.item())
            iter_loss = sum(minibatch_losses) / len(minibatch_losses)
            losses.append(iter_loss)
            if iter_loss < self.min_loss_threshold:
                break
            self.update_target()
        return losses

    def estimate_q(self, batch: SampleBatch) -> TensorType:
        """Returns Q(s, a) for the observations/actions in `batch` (no grads)."""
        obs = torch.tensor(batch[SampleBatch.OBS], device=self.device)
        with torch.no_grad():
            q_values, _ = self.q_model({"obs": obs}, [], None)
            actions = torch.tensor(
                batch[SampleBatch.ACTIONS], device=self.device, dtype=int
            )
            q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze(-1)
        return q_values

    def estimate_v(self, batch: SampleBatch) -> TensorType:
        """Returns V(s) = E_{a ~ pi(s)}[Q(s, a)] for observations in `batch`."""
        obs = torch.tensor(batch[SampleBatch.OBS], device=self.device)
        with torch.no_grad():
            q_values, _ = self.q_model({"obs": obs}, [], None)
            # Compute pi(a | s) for each action a in policy.action_space.
            action_probs = self._compute_action_probs(obs)
            v_values = torch.sum(q_values * action_probs, axis=-1)
        return v_values

    def update_target(self, polyak_coef=None):
        """Polyak-averages the Q-network weights into the target network.

        Args:
            polyak_coef: Fraction of the online Q-model to blend into the
                target (1.0 = full hard sync). Defaults to self.polyak_coef.
        """
        # Bug fix: use an explicit None-check instead of
        # `polyak_coef or self.polyak_coef`, which silently replaced an
        # explicitly passed 0.0 with the instance default.
        if polyak_coef is None:
            polyak_coef = self.polyak_coef
        q_state = self.q_model.state_dict()
        target_state = self.target_q_model.state_dict()
        # Soft sync; if polyak_coef == 1.0 this is a full copy of q_state.
        blended_state = {
            k: polyak_coef * q_state[k] + (1 - polyak_coef) * v
            for k, v in target_state.items()
        }
        self.target_q_model.load_state_dict(blended_state)

    def _compute_action_probs(self, obs: TensorType) -> TensorType:
        """Compute action distribution over the action space.

        Args:
            obs: A tensor of observations of shape (batch_size * obs_dim)

        Returns:
            action_probs: A tensor of action probabilities
            of shape (batch_size * action_dim)
        """
        input_dict = {SampleBatch.OBS: obs}
        seq_lens = torch.ones(len(obs), device=self.device, dtype=int)
        state_batches = []
        if is_overridden(self.policy.action_distribution_fn):
            try:
                # TorchPolicyV2 function signature
                dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
                    self.policy.model,
                    obs_batch=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=False,
                    is_training=False,
                )
            except TypeError:
                # TorchPolicyV1 function signature for compatibility with DQN
                # TODO: Remove this once DQNTorchPolicy is migrated to PolicyV2
                dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
                    self.policy,
                    self.policy.model,
                    input_dict=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=False,
                    is_training=False,
                )
        else:
            dist_class = self.policy.dist_class
            dist_inputs, _ = self.policy.model(input_dict, state_batches, seq_lens)
        action_dist = dist_class(dist_inputs, self.policy.model)
        assert isinstance(
            action_dist.dist, torch.distributions.categorical.Categorical
        ), "FQE only supports Categorical or MultiCategorical distributions!"
        action_probs = action_dist.dist.probs
        return action_probs

    def get_state(self) -> Dict[str, Any]:
        """Returns the current state of the FQE Model."""
        return {
            "policy_state": self.policy.get_state(),
            "model_config": self.model_config,
            "n_iters": self.n_iters,
            "lr": self.lr,
            "min_loss_threshold": self.min_loss_threshold,
            "clip_grad_norm": self.clip_grad_norm,
            "minibatch_size": self.minibatch_size,
            "polyak_coef": self.polyak_coef,
            "gamma": self.gamma,
            "q_model_state": self.q_model.state_dict(),
            "target_q_model_state": self.target_q_model.state_dict(),
        }

    def set_state(self, state: Dict[str, Any]) -> None:
        """Sets the current state of the FQE Model.

        Args:
            state: A state dict returned by `get_state()`.
        """
        self.n_iters = state["n_iters"]
        self.lr = state["lr"]
        self.min_loss_threshold = state["min_loss_threshold"]
        self.clip_grad_norm = state["clip_grad_norm"]
        self.minibatch_size = state["minibatch_size"]
        self.polyak_coef = state["polyak_coef"]
        self.gamma = state["gamma"]
        self.policy.set_state(state["policy_state"])
        self.q_model.load_state_dict(state["q_model_state"])
        self.target_q_model.load_state_dict(state["target_q_model_state"])

    @classmethod
    def from_state(cls, state: Dict[str, Any]) -> "FQETorchModel":
        """Creates an FQE Model from a state dict.

        Args:
            state: A state dict returned by `get_state`.

        Returns:
            An instance of the FQETorchModel.
        """
        policy = Policy.from_state(state["policy_state"])
        model = cls(
            policy=policy, gamma=state["gamma"], model_config=state["model_config"]
        )
        model.set_state(state)
        return model
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/importance_sampling.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import math
3
+
4
+ from ray.data import Dataset
5
+
6
+ from ray.rllib.utils.annotations import override, DeveloperAPI
7
+ from ray.rllib.offline.offline_evaluator import OfflineEvaluator
8
+ from ray.rllib.offline.offline_evaluation_utils import (
9
+ remove_time_dim,
10
+ compute_is_weights,
11
+ )
12
+ from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
13
+ from ray.rllib.policy.sample_batch import SampleBatch
14
+
15
+
16
@DeveloperAPI
class ImportanceSampling(OffPolicyEstimator):
    r"""The step-wise IS estimator.

    Let s_t, a_t, and r_t be the state, action, and reward at timestep t.

    For behavior policy \pi_b and evaluation policy \pi_e, define the
    cumulative importance ratio at timestep t as:
    p_t = \sum_{t'=0}^t (\pi_e(a_{t'} | s_{t'}) / \pi_b(a_{t'} | s_{t'})).

    This estimator computes the expected return for \pi_e for an episode as:
    V^{\pi_e}(s_0) = \sum_t \gamma ^ {t} * p_t * r_t
    and returns the mean and standard deviation over episodes.

    For more information refer to https://arxiv.org/pdf/1911.06854.pdf"""

    @override(OffPolicyEstimator)
    def estimate_on_single_episode(self, episode: SampleBatch) -> Dict[str, float]:
        """Step-wise IS estimate of behavior and target returns for one episode."""
        results = {}

        rewards, old_prob = episode["rewards"], episode["action_prob"]
        new_prob = self.compute_action_probs(episode)

        # Running product of per-step likelihood ratios pi_e / pi_b.
        cum_ratios = []
        running = 1.0
        for t in range(episode.count):
            running = running * new_prob[t] / old_prob[t]
            cum_ratios.append(running)

        # Discounted returns: behavior uses raw rewards, target reweights
        # each step by its cumulative importance ratio.
        v_behavior = 0.0
        v_target = 0.0
        for t, reward in enumerate(rewards):
            discount = self.gamma**t
            v_behavior += reward * discount
            v_target += cum_ratios[t] * reward * discount

        results["v_behavior"] = v_behavior
        results["v_target"] = v_target

        return results

    @override(OffPolicyEstimator)
    def estimate_on_single_step_samples(
        self, batch: SampleBatch
    ) -> Dict[str, List[float]]:
        """Per-record IS estimate for a batch of single-timestep episodes."""
        rewards, old_prob = batch["rewards"], batch["action_prob"]
        new_prob = self.compute_action_probs(batch)

        # With one step per episode the cumulative ratio is just the
        # single-step ratio.
        ratios = new_prob / old_prob

        return {
            "v_behavior": rewards,
            "v_target": ratios * rewards,
        }

    @override(OfflineEvaluator)
    def estimate_on_dataset(
        self, dataset: Dataset, *, n_parallelism: int = ...
    ) -> Dict[str, Any]:
        """Computes the Importance sampling estimate on the given dataset.

        Note: This estimate works for both continuous and discrete action spaces.

        Args:
            dataset: Dataset to compute the estimate on. Each record in dataset should
                include the following columns: `obs`, `actions`, `action_prob` and
                `rewards`. The `obs` on each row should be a vector of D dimensions.
            n_parallelism: The number of parallel workers to use.

        Returns:
            A dictionary containing the following keys:
                v_target: The estimated value of the target policy.
                v_behavior: The estimated value of the behavior policy.
                v_gain_mean: The mean of the gain of the target policy over the
                    behavior policy.
                v_gain_ste: The standard error of the gain of the target policy over
                    the behavior policy.
        """
        batch_size = max(dataset.count() // n_parallelism, 1)
        # Drop the time dimension, then attach per-record IS weights and
        # weighted rewards computed from the target policy.
        dataset = dataset.map_batches(
            remove_time_dim, batch_size=batch_size, batch_format="pandas"
        )
        weighted_ds = dataset.map_batches(
            compute_is_weights,
            batch_size=batch_size,
            batch_format="pandas",
            fn_kwargs={
                "policy_state": self.policy.get_state(),
                "estimator_class": self.__class__,
            },
        )

        v_target = weighted_ds.mean("weighted_rewards")
        v_behavior = weighted_ds.mean("rewards")
        v_gain_mean = v_target / v_behavior
        # Standard error of the gain across all records.
        v_gain_ste = (
            weighted_ds.std("weighted_rewards") / v_behavior / math.sqrt(dataset.count())
        )

        return {
            "v_target": v_target,
            "v_behavior": v_behavior,
            "v_gain_mean": v_gain_mean,
            "v_gain_ste": v_gain_ste,
        }
.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/off_policy_estimator.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ import numpy as np
3
+ import tree
4
+ from typing import Dict, Any, List
5
+
6
+ import logging
7
+ from ray.rllib.policy.sample_batch import SampleBatch
8
+ from ray.rllib.policy import Policy
9
+ from ray.rllib.policy.sample_batch import convert_ma_batch_to_sample_batch
10
+ from ray.rllib.utils.policy import compute_log_likelihoods_from_input_dict
11
+ from ray.rllib.utils.annotations import (
12
+ DeveloperAPI,
13
+ ExperimentalAPI,
14
+ OverrideToImplementCustomLogic,
15
+ )
16
+ from ray.rllib.utils.deprecation import Deprecated
17
+ from ray.rllib.utils.numpy import convert_to_numpy
18
+ from ray.rllib.utils.typing import TensorType, SampleBatchType
19
+ from ray.rllib.offline.offline_evaluator import OfflineEvaluator
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
@DeveloperAPI
class OffPolicyEstimator(OfflineEvaluator):
    """Interface for an off policy estimator for counterfactual evaluation."""

    @DeveloperAPI
    def __init__(
        self,
        policy: Policy,
        gamma: float = 0.0,
        epsilon_greedy: float = 0.0,
    ):
        """Initializes an OffPolicyEstimator instance.

        Args:
            policy: Policy to evaluate.
            gamma: Discount factor of the environment.
            epsilon_greedy: The probability by which we act according to a fully
                random policy during deployment. With 1-epsilon_greedy we act
                according to the target policy.
        # TODO (kourosh): convert the input parameters to a config dict.
        """
        super().__init__(policy)
        self.gamma = gamma
        self.epsilon_greedy = epsilon_greedy

    @DeveloperAPI
    def estimate_on_single_episode(self, episode: SampleBatch) -> Dict[str, Any]:
        """Returns off-policy estimates for the given single episode.

        Args:
            episode: The episode to calculate the off-policy estimates (OPE) on.
                Must be a sample batch type containing the fields "obs",
                "actions", and "action_prob", and must represent a complete
                trajectory.

        Returns:
            The off-policy estimates (OPE) calculated on the given episode. The
            returned dict can be any arbitrary mapping of strings to metrics.
        """
        raise NotImplementedError

    @DeveloperAPI
    def estimate_on_single_step_samples(
        self,
        batch: SampleBatch,
    ) -> Dict[str, List[float]]:
        """Returns off-policy estimates for a batch of single timesteps.

        Highly optimized for bandits, where each episode is a single timestep.

        Args:
            batch: The batch to calculate the off-policy estimates (OPE) on.
                Must be a sample batch type containing the fields "obs",
                "actions", and "action_prob".

        Returns:
            The off-policy estimates (OPE) calculated on the given batch of
            single time step samples. The returned dict can be any arbitrary
            mapping of strings to a list of floats (one value per record).
        """
        raise NotImplementedError

    def on_before_split_batch_by_episode(
        self, sample_batch: SampleBatch
    ) -> SampleBatch:
        """Hook called before the batch is split by episode.

        Perform any preprocessing on the batch here, e.g. adding done flags,
        or resetting per-episode stats tracked later during estimation.

        Args:
            sample_batch: The batch to split by episode. This contains multiple
                episodes.

        Returns:
            The modified batch before calling split_by_episode().
        """
        return sample_batch

    @OverrideToImplementCustomLogic
    def on_after_split_batch_by_episode(
        self, all_episodes: List[SampleBatch]
    ) -> List[SampleBatch]:
        """Hook called after the batch is split by episode.

        Perform any per-episode postprocessing here, e.g. computing advantages.

        Args:
            all_episodes: The list of episodes in the original batch. Each
                element is a sample batch type that is a single episode.
        """

        return all_episodes

    @OverrideToImplementCustomLogic
    def peek_on_single_episode(self, episode: SampleBatch) -> None:
        """Called on each episode before it reaches estimate_on_single_episode().

        This gives a peek at the entire validation dataset before running the
        estimation. For example, if any normalization is needed, the
        normalization parameters can be computed here.

        Args:
            episode: The episode that is split from the original batch. This is
                a sample batch type that is a single episode.
        """
        pass

    @DeveloperAPI
    def estimate(
        self, batch: SampleBatchType, split_batch_by_episode: bool = True
    ) -> Dict[str, Any]:
        """Compute off-policy estimates.

        Args:
            batch: The batch to calculate the off-policy estimates (OPE) on. The
                batch must contain the fields "obs", "actions", and "action_prob".
            split_batch_by_episode: Whether to split the batch by episode.

        Returns:
            The off-policy estimates (OPE) calculated on the given batch. The
            dict consists of the following metrics:
            - v_behavior: The discounted return averaged over episodes in the batch
            - v_behavior_std: The standard deviation corresponding to v_behavior
            - v_target: The estimated discounted return for `self.policy`,
              averaged over episodes in the batch
            - v_target_std: The standard deviation corresponding to v_target
            - v_gain: v_target / max(v_behavior, 1e-8)
            - v_delta: The difference between v_target and v_behavior.
        """
        batch = convert_ma_batch_to_sample_batch(batch)
        self.check_action_prob_in_batch(batch)
        if split_batch_by_episode:
            batch = self.on_before_split_batch_by_episode(batch)
            all_episodes = self.on_after_split_batch_by_episode(
                batch.split_by_episode()
            )
            # First pass: sanity-check episode boundaries and let subclasses
            # peek at the full dataset (e.g. to compute normalizers).
            for episode in all_episodes:
                assert len(set(episode[SampleBatch.EPS_ID])) == 1, (
                    "The episode must contain only one episode id. For some reason "
                    "the split_by_episode() method could not successfully split "
                    "the batch by episodes. Each row in the dataset should be "
                    "one episode. Check your evaluation dataset for errors."
                )
                self.peek_on_single_episode(episode)

            # Second pass: run the actual per-episode estimation.
            per_episode_results = [
                self.estimate_on_single_episode(episode) for episode in all_episodes
            ]
            # Turn a list of identically-keyed dicts into a dict of lists.
            per_episode_results = tree.map_structure(
                lambda *x: list(x), *per_episode_results
            )
        else:
            # The returned dict maps strings to a list of floats.
            per_episode_results = self.estimate_on_single_step_samples(batch)

        estimates = {
            "v_behavior": np.mean(per_episode_results["v_behavior"]),
            "v_behavior_std": np.std(per_episode_results["v_behavior"]),
            "v_target": np.mean(per_episode_results["v_target"]),
            "v_target_std": np.std(per_episode_results["v_target"]),
        }
        estimates["v_gain"] = estimates["v_target"] / max(estimates["v_behavior"], 1e-8)
        estimates["v_delta"] = estimates["v_target"] - estimates["v_behavior"]

        return estimates

    @DeveloperAPI
    def check_action_prob_in_batch(self, batch: SampleBatchType) -> None:
        """Checks if we support off policy estimation (OPE) on given batch.

        Args:
            batch: The batch to check.

        Raises:
            ValueError: In case `action_prob` key is not in batch
        """

        if "action_prob" not in batch:
            raise ValueError(
                "Off-policy estimation is not possible unless the inputs "
                "include action probabilities (i.e., the policy is stochastic "
                "and emits the 'action_prob' key). For DQN this means using "
                "`exploration_config: {type: 'SoftQ'}`. You can also set "
                "`off_policy_estimation_methods: {}` to disable estimation."
            )

    @ExperimentalAPI
    def compute_action_probs(self, batch: SampleBatch):
        """Returns the target policy's probability of the logged actions.

        Optionally mixes in epsilon-greedy exploration probability mass.
        """
        log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, batch)
        new_prob = np.exp(convert_to_numpy(log_likelihoods))

        if self.epsilon_greedy > 0.0:
            if not isinstance(self.policy.action_space, gym.spaces.Discrete):
                raise ValueError(
                    "Evaluation with epsilon-greedy exploration is only supported "
                    "with discrete action spaces."
                )
            eps = self.epsilon_greedy
            new_prob = new_prob * (1 - eps) + eps / self.policy.action_space.n

        return new_prob

    @DeveloperAPI
    def train(self, batch: SampleBatchType) -> Dict[str, Any]:
        """Train a model for Off-Policy Estimation.

        Args:
            batch: SampleBatch to train on

        Returns:
            Any optional metrics to return from the estimator
        """
        return {}

    @Deprecated(
        old="OffPolicyEstimator.action_log_likelihood",
        new="ray.rllib.utils.policy.compute_log_likelihoods_from_input_dict",
        error=True,
    )
    def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
        log_likelihoods = compute_log_likelihoods_from_input_dict(self.policy, batch)
        return convert_to_numpy(log_likelihoods)
.venv/lib/python3.11/site-packages/ray/rllib/utils/__init__.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ from functools import partial
3
+
4
+ from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
5
+ from ray.rllib.utils.deprecation import deprecation_warning
6
+ from ray.rllib.utils.filter import Filter
7
+ from ray.rllib.utils.filter_manager import FilterManager
8
+ from ray.rllib.utils.framework import (
9
+ try_import_jax,
10
+ try_import_tf,
11
+ try_import_tfp,
12
+ try_import_torch,
13
+ )
14
+ from ray.rllib.utils.numpy import (
15
+ sigmoid,
16
+ softmax,
17
+ relu,
18
+ one_hot,
19
+ fc,
20
+ lstm,
21
+ SMALL_NUMBER,
22
+ LARGE_INTEGER,
23
+ MIN_LOG_NN_OUTPUT,
24
+ MAX_LOG_NN_OUTPUT,
25
+ )
26
+ from ray.rllib.utils.schedules import (
27
+ LinearSchedule,
28
+ PiecewiseSchedule,
29
+ PolynomialSchedule,
30
+ ExponentialSchedule,
31
+ ConstantSchedule,
32
+ )
33
+ from ray.rllib.utils.test_utils import (
34
+ check,
35
+ check_compute_single_action,
36
+ check_train_results,
37
+ )
38
+ from ray.tune.utils import merge_dicts, deep_update
39
+
40
+
41
@DeveloperAPI
def add_mixins(base, mixins, reversed=False):
    """Returns a new class with mixins applied in priority order.

    Mixins are consumed from the end of the list; each one is combined with
    the accumulated base via a fresh subclass. With `reversed=False` the
    mixin takes MRO precedence over the base; with `reversed=True` the base
    takes precedence.
    """
    remaining = list(mixins or [])

    while remaining:
        mixin = remaining.pop()
        # Same class name as the original inline `class new_base` statements.
        if reversed:
            base = type("new_base", (base, mixin), {})
        else:
            base = type("new_base", (mixin, base), {})

    return base
61
+
62
+
63
@DeveloperAPI
def force_list(elements=None, to_tuple=False):
    """
    Makes sure `elements` is returned as a list, whether `elements` is a single
    item, already a list, or a tuple.

    Args:
        elements (Optional[any]): The inputs as single item, list, or tuple to
            be converted into a list/tuple. If None, returns empty list/tuple.
        to_tuple: Whether to use tuple (instead of list).

    Returns:
        Union[list,tuple]: All given elements in a list/tuple depending on
            `to_tuple`'s value. If elements is None,
            returns an empty list/tuple.
    """
    # Only the literal True selects tuple, matching the original `is True` test.
    container = tuple if to_tuple is True else list
    if elements is None:
        return container()
    if type(elements) in [list, set, tuple]:
        return container(elements)
    # Single non-collection item: wrap it.
    return container([elements])
89
+
90
+
91
@DeveloperAPI
class NullContextManager(contextlib.AbstractContextManager):
    """No-op context manager: enters and exits without any side effects."""

    def __init__(self):
        # Nothing to set up.
        pass

    def __enter__(self):
        # Yields nothing to the `with ... as` target.
        pass

    def __exit__(self, *args):
        # Returning None propagates any exception raised in the body.
        pass
103
+
104
+
105
# Tuple-returning variant of `force_list`.
force_tuple = partial(force_list, to_tuple=True)

# Public API of this module.
__all__ = [
    "add_mixins",
    "check",
    "check_compute_single_action",
    "check_train_results",
    "deep_update",
    "deprecation_warning",
    "fc",
    "force_list",
    "force_tuple",
    "lstm",
    "merge_dicts",
    "one_hot",
    "override",
    "relu",
    "sigmoid",
    "softmax",
    "try_import_jax",
    "try_import_tf",
    "try_import_tfp",
    "try_import_torch",
    "ConstantSchedule",
    "DeveloperAPI",
    "ExponentialSchedule",
    "Filter",
    "FilterManager",
    "LARGE_INTEGER",
    "LinearSchedule",
    "MAX_LOG_NN_OUTPUT",
    "MIN_LOG_NN_OUTPUT",
    "PiecewiseSchedule",
    "PolynomialSchedule",
    "PublicAPI",
    "SMALL_NUMBER",
]
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/checkpoints.cpython-311.pyc ADDED
Binary file (42.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/compression.cpython-311.pyc ADDED
Binary file (4.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/deprecation.cpython-311.pyc ADDED
Binary file (5.27 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/from_config.cpython-311.pyc ADDED
Binary file (11.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/lambda_defaultdict.cpython-311.pyc ADDED
Binary file (2.79 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/memory.cpython-311.pyc ADDED
Binary file (523 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/serialization.cpython-311.pyc ADDED
Binary file (20.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/__pycache__/torch_utils.cpython-311.pyc ADDED
Binary file (32 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/actors.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict, deque
2
+ import logging
3
+ import platform
4
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
5
+
6
+ import ray
7
+ from ray.actor import ActorClass, ActorHandle
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class TaskPool:
13
+ """Helper class for tracking the status of many in-flight actor tasks."""
14
+
15
+ def __init__(self):
16
+ self._tasks = {}
17
+ self._objects = {}
18
+ self._fetching = deque()
19
+
20
+ def add(self, worker, all_obj_refs):
21
+ if isinstance(all_obj_refs, list):
22
+ obj_ref = all_obj_refs[0]
23
+ else:
24
+ obj_ref = all_obj_refs
25
+ self._tasks[obj_ref] = worker
26
+ self._objects[obj_ref] = all_obj_refs
27
+
28
+ def completed(self, blocking_wait=False):
29
+ pending = list(self._tasks)
30
+ if pending:
31
+ ready, _ = ray.wait(pending, num_returns=len(pending), timeout=0)
32
+ if not ready and blocking_wait:
33
+ ready, _ = ray.wait(pending, num_returns=1, timeout=10.0)
34
+ for obj_ref in ready:
35
+ yield (self._tasks.pop(obj_ref), self._objects.pop(obj_ref))
36
+
37
+ def completed_prefetch(self, blocking_wait=False, max_yield=999):
38
+ """Similar to completed but only returns once the object is local.
39
+
40
+ Assumes obj_ref only is one id."""
41
+
42
+ for worker, obj_ref in self.completed(blocking_wait=blocking_wait):
43
+ self._fetching.append((worker, obj_ref))
44
+
45
+ for _ in range(max_yield):
46
+ if not self._fetching:
47
+ break
48
+
49
+ yield self._fetching.popleft()
50
+
51
+ def reset_workers(self, workers):
52
+ """Notify that some workers may be removed."""
53
+ for obj_ref, ev in self._tasks.copy().items():
54
+ if ev not in workers:
55
+ del self._tasks[obj_ref]
56
+ del self._objects[obj_ref]
57
+
58
+ # We want to keep the same deque reference so that we don't suffer from
59
+ # stale references in generators that are still in flight
60
+ for _ in range(len(self._fetching)):
61
+ ev, obj_ref = self._fetching.popleft()
62
+ if ev in workers:
63
+ # Re-queue items that are still valid
64
+ self._fetching.append((ev, obj_ref))
65
+
66
+ @property
67
+ def count(self):
68
+ return len(self._tasks)
69
+
70
+
71
+ def create_colocated_actors(
72
+ actor_specs: Sequence[Tuple[Type, Any, Any, int]],
73
+ node: Optional[str] = "localhost",
74
+ max_attempts: int = 10,
75
+ ) -> Dict[Type, List[ActorHandle]]:
76
+ """Create co-located actors of any type(s) on any node.
77
+
78
+ Args:
79
+ actor_specs: Tuple/list with tuples consisting of: 1) The
80
+ (already @ray.remote) class(es) to construct, 2) c'tor args,
81
+ 3) c'tor kwargs, and 4) the number of actors of that class with
82
+ given args/kwargs to construct.
83
+ node: The node to co-locate the actors on. By default ("localhost"),
84
+ place the actors on the node the caller of this function is
85
+ located on. Use None for indicating that any (resource fulfilling)
86
+ node in the cluster may be used.
87
+ max_attempts: The maximum number of co-location attempts to
88
+ perform before throwing an error.
89
+
90
+ Returns:
91
+ A dict mapping the created types to the list of n ActorHandles
92
+ created (and co-located) for that type.
93
+ """
94
+ if node == "localhost":
95
+ node = platform.node()
96
+
97
+ # Maps each entry in `actor_specs` to lists of already co-located actors.
98
+ ok = [[] for _ in range(len(actor_specs))]
99
+
100
+ # Try n times to co-locate all given actor types (`actor_specs`).
101
+ # With each (failed) attempt, increase the number of actors we try to
102
+ # create (on the same node), then kill the ones that have been created in
103
+ # excess.
104
+ for attempt in range(max_attempts):
105
+ # If any attempt to co-locate fails, set this to False and we'll do
106
+ # another attempt.
107
+ all_good = True
108
+ # Process all `actor_specs` in sequence.
109
+ for i, (typ, args, kwargs, count) in enumerate(actor_specs):
110
+ args = args or [] # Allow None.
111
+ kwargs = kwargs or {} # Allow None.
112
+ # We don't have enough actors yet of this spec co-located on
113
+ # the desired node.
114
+ if len(ok[i]) < count:
115
+ co_located = try_create_colocated(
116
+ cls=typ,
117
+ args=args,
118
+ kwargs=kwargs,
119
+ count=count * (attempt + 1),
120
+ node=node,
121
+ )
122
+ # If node did not matter (None), from here on, use the host
123
+ # that the first actor(s) are already co-located on.
124
+ if node is None:
125
+ node = ray.get(co_located[0].get_host.remote())
126
+ # Add the newly co-located actors to the `ok` list.
127
+ ok[i].extend(co_located)
128
+ # If we still don't have enough -> We'll have to do another
129
+ # attempt.
130
+ if len(ok[i]) < count:
131
+ all_good = False
132
+ # We created too many actors for this spec -> Kill/truncate
133
+ # the excess ones.
134
+ if len(ok[i]) > count:
135
+ for a in ok[i][count:]:
136
+ a.__ray_terminate__.remote()
137
+ ok[i] = ok[i][:count]
138
+
139
+ # All `actor_specs` have been fulfilled, return lists of
140
+ # co-located actors.
141
+ if all_good:
142
+ return ok
143
+
144
+ raise Exception("Unable to create enough colocated actors -> aborting.")
145
+
146
+
147
+ def try_create_colocated(
148
+ cls: Type[ActorClass],
149
+ args: List[Any],
150
+ count: int,
151
+ kwargs: Optional[List[Any]] = None,
152
+ node: Optional[str] = "localhost",
153
+ ) -> List[ActorHandle]:
154
+ """Tries to co-locate (same node) a set of Actors of the same type.
155
+
156
+ Returns a list of successfully co-located actors. All actors that could
157
+ not be co-located (with the others on the given node) will not be in this
158
+ list.
159
+
160
+ Creates each actor via it's remote() constructor and then checks, whether
161
+ it has been co-located (on the same node) with the other (already created)
162
+ ones. If not, terminates the just created actor.
163
+
164
+ Args:
165
+ cls: The Actor class to use (already @ray.remote "converted").
166
+ args: List of args to pass to the Actor's constructor. One item
167
+ per to-be-created actor (`count`).
168
+ count: Number of actors of the given `cls` to construct.
169
+ kwargs: Optional list of kwargs to pass to the Actor's constructor.
170
+ One item per to-be-created actor (`count`).
171
+ node: The node to co-locate the actors on. By default ("localhost"),
172
+ place the actors on the node the caller of this function is
173
+ located on. If None, will try to co-locate all actors on
174
+ any available node.
175
+
176
+ Returns:
177
+ List containing all successfully co-located actor handles.
178
+ """
179
+ if node == "localhost":
180
+ node = platform.node()
181
+
182
+ kwargs = kwargs or {}
183
+ actors = [cls.remote(*args, **kwargs) for _ in range(count)]
184
+ co_located, non_co_located = split_colocated(actors, node=node)
185
+ logger.info("Got {} colocated actors of {}".format(len(co_located), count))
186
+ for a in non_co_located:
187
+ a.__ray_terminate__.remote()
188
+ return co_located
189
+
190
+
191
+ def split_colocated(
192
+ actors: List[ActorHandle],
193
+ node: Optional[str] = "localhost",
194
+ ) -> Tuple[List[ActorHandle], List[ActorHandle]]:
195
+ """Splits up given actors into colocated (on same node) and non colocated.
196
+
197
+ The co-location criterion depends on the `node` given:
198
+ If given (or default: platform.node()): Consider all actors that are on
199
+ that node "colocated".
200
+ If None: Consider the largest sub-set of actors that are all located on
201
+ the same node (whatever that node is) as "colocated".
202
+
203
+ Args:
204
+ actors: The list of actor handles to split into "colocated" and
205
+ "non colocated".
206
+ node: The node defining "colocation" criterion. If provided, consider
207
+ thos actors "colocated" that sit on this node. If None, use the
208
+ largest subset within `actors` that are sitting on the same
209
+ (any) node.
210
+
211
+ Returns:
212
+ Tuple of two lists: 1) Co-located ActorHandles, 2) non co-located
213
+ ActorHandles.
214
+ """
215
+ if node == "localhost":
216
+ node = platform.node()
217
+
218
+ # Get nodes of all created actors.
219
+ hosts = ray.get([a.get_host.remote() for a in actors])
220
+
221
+ # If `node` not provided, use the largest group of actors that sit on the
222
+ # same node, regardless of what that node is.
223
+ if node is None:
224
+ node_groups = defaultdict(set)
225
+ for host, actor in zip(hosts, actors):
226
+ node_groups[host].add(actor)
227
+ max_ = -1
228
+ largest_group = None
229
+ for host in node_groups:
230
+ if max_ < len(node_groups[host]):
231
+ max_ = len(node_groups[host])
232
+ largest_group = host
233
+ non_co_located = []
234
+ for host in node_groups:
235
+ if host != largest_group:
236
+ non_co_located.extend(list(node_groups[host]))
237
+ return list(node_groups[largest_group]), non_co_located
238
+ # Node provided (or default: localhost): Consider those actors "colocated"
239
+ # that were placed on `node`.
240
+ else:
241
+ # Split into co-located (on `node) and non-co-located (not on `node`).
242
+ co_located = []
243
+ non_co_located = []
244
+ for host, a in zip(hosts, actors):
245
+ # This actor has been placed on the correct node.
246
+ if host == node:
247
+ co_located.append(a)
248
+ # This actor has been placed on a different node.
249
+ else:
250
+ non_co_located.append(a)
251
+ return co_located, non_co_located
252
+
253
+
254
+ def drop_colocated(actors: List[ActorHandle]) -> List[ActorHandle]:
255
+ colocated, non_colocated = split_colocated(actors)
256
+ for a in colocated:
257
+ a.__ray_terminate__.remote()
258
+ return non_colocated
.venv/lib/python3.11/site-packages/ray/rllib/utils/annotations.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.utils.deprecation import Deprecated
2
+ from ray.util.annotations import _mark_annotated
3
+
4
+
5
+ def override(parent_cls):
6
+ """Decorator for documenting method overrides.
7
+
8
+ Args:
9
+ parent_cls: The superclass that provides the overridden method. If
10
+ `parent_class` does not actually have the method or the class, in which
11
+ method is defined is not a subclass of `parent_class`, an error is raised.
12
+
13
+ .. testcode::
14
+ :skipif: True
15
+
16
+ from ray.rllib.policy import Policy
17
+ class TorchPolicy(Policy):
18
+ ...
19
+ # Indicates that `TorchPolicy.loss()` overrides the parent
20
+ # Policy class' own `loss method. Leads to an error if Policy
21
+ # does not have a `loss` method.
22
+
23
+ @override(Policy)
24
+ def loss(self, model, action_dist, train_batch):
25
+ ...
26
+
27
+ """
28
+
29
+ class OverrideCheck:
30
+ def __init__(self, func, expected_parent_cls):
31
+ self.func = func
32
+ self.expected_parent_cls = expected_parent_cls
33
+
34
+ def __set_name__(self, owner, name):
35
+ # Check if the owner (the class) is a subclass of the expected base class
36
+ if not issubclass(owner, self.expected_parent_cls):
37
+ raise TypeError(
38
+ f"When using the @override decorator, {owner.__name__} must be a "
39
+ f"subclass of {parent_cls.__name__}!"
40
+ )
41
+ # Set the function as a regular method on the class.
42
+ setattr(owner, name, self.func)
43
+
44
+ def decorator(method):
45
+ # Check, whether `method` is actually defined by the parent class.
46
+ if method.__name__ not in dir(parent_cls):
47
+ raise NameError(
48
+ f"When using the @override decorator, {method.__name__} must override "
49
+ f"the respective method (with the same name) of {parent_cls.__name__}!"
50
+ )
51
+
52
+ # Check if the class is a subclass of the expected base class
53
+ OverrideCheck(method, parent_cls)
54
+ return method
55
+
56
+ return decorator
57
+
58
+
59
+ def PublicAPI(obj):
60
+ """Decorator for documenting public APIs.
61
+
62
+ Public APIs are classes and methods exposed to end users of RLlib. You
63
+ can expect these APIs to remain stable across RLlib releases.
64
+
65
+ Subclasses that inherit from a ``@PublicAPI`` base class can be
66
+ assumed part of the RLlib public API as well (e.g., all Algorithm classes
67
+ are in public API because Algorithm is ``@PublicAPI``).
68
+
69
+ In addition, you can assume all algo configurations are part of their
70
+ public API as well.
71
+
72
+ .. testcode::
73
+ :skipif: True
74
+
75
+ # Indicates that the `Algorithm` class is exposed to end users
76
+ # of RLlib and will remain stable across RLlib releases.
77
+ from ray import tune
78
+ @PublicAPI
79
+ class Algorithm(tune.Trainable):
80
+ ...
81
+ """
82
+
83
+ _mark_annotated(obj)
84
+ return obj
85
+
86
+
87
+ def DeveloperAPI(obj):
88
+ """Decorator for documenting developer APIs.
89
+
90
+ Developer APIs are classes and methods explicitly exposed to developers
91
+ for the purposes of building custom algorithms or advanced training
92
+ strategies on top of RLlib internals. You can generally expect these APIs
93
+ to be stable sans minor changes (but less stable than public APIs).
94
+
95
+ Subclasses that inherit from a ``@DeveloperAPI`` base class can be
96
+ assumed part of the RLlib developer API as well.
97
+
98
+ .. testcode::
99
+ :skipif: True
100
+
101
+ # Indicates that the `TorchPolicy` class is exposed to end users
102
+ # of RLlib and will remain (relatively) stable across RLlib
103
+ # releases.
104
+ from ray.rllib.policy import Policy
105
+ @DeveloperAPI
106
+ class TorchPolicy(Policy):
107
+ ...
108
+ """
109
+
110
+ _mark_annotated(obj)
111
+ return obj
112
+
113
+
114
+ def ExperimentalAPI(obj):
115
+ """Decorator for documenting experimental APIs.
116
+
117
+ Experimental APIs are classes and methods that are in development and may
118
+ change at any time in their development process. You should not expect
119
+ these APIs to be stable until their tag is changed to `DeveloperAPI` or
120
+ `PublicAPI`.
121
+
122
+ Subclasses that inherit from a ``@ExperimentalAPI`` base class can be
123
+ assumed experimental as well.
124
+
125
+ .. testcode::
126
+ :skipif: True
127
+
128
+ from ray.rllib.policy import Policy
129
+ class TorchPolicy(Policy):
130
+ ...
131
+ # Indicates that the `TorchPolicy.loss` method is a new and
132
+ # experimental API and may change frequently in future
133
+ # releases.
134
+ @ExperimentalAPI
135
+ def loss(self, model, action_dist, train_batch):
136
+ ...
137
+ """
138
+
139
+ _mark_annotated(obj)
140
+ return obj
141
+
142
+
143
+ def OldAPIStack(obj):
144
+ """Decorator for classes/methods/functions belonging to the old API stack.
145
+
146
+ These should be deprecated at some point after Ray 3.0 (RLlib GA).
147
+ It is recommended for users to start exploring (and coding against) the new API
148
+ stack instead.
149
+ """
150
+ # No effect yet.
151
+
152
+ _mark_annotated(obj)
153
+ return obj
154
+
155
+
156
+ def OverrideToImplementCustomLogic(obj):
157
+ """Users should override this in their sub-classes to implement custom logic.
158
+
159
+ Used in Algorithm and Policy to tag methods that need overriding, e.g.
160
+ `Policy.loss()`.
161
+
162
+ .. testcode::
163
+ :skipif: True
164
+
165
+ from ray.rllib.policy.torch_policy import TorchPolicy
166
+ @overrides(TorchPolicy)
167
+ @OverrideToImplementCustomLogic
168
+ def loss(self, ...):
169
+ # implement custom loss function here ...
170
+ # ... w/o calling the corresponding `super().loss()` method.
171
+ ...
172
+
173
+ """
174
+ obj.__is_overridden__ = False
175
+ return obj
176
+
177
+
178
+ def OverrideToImplementCustomLogic_CallToSuperRecommended(obj):
179
+ """Users should override this in their sub-classes to implement custom logic.
180
+
181
+ Thereby, it is recommended (but not required) to call the super-class'
182
+ corresponding method.
183
+
184
+ Used in Algorithm and Policy to tag methods that need overriding, but the
185
+ super class' method should still be called, e.g.
186
+ `Algorithm.setup()`.
187
+
188
+ .. testcode::
189
+ :skipif: True
190
+
191
+ from ray import tune
192
+ @overrides(tune.Trainable)
193
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
194
+ def setup(self, config):
195
+ # implement custom setup logic here ...
196
+ super().setup(config)
197
+ # ... or here (after having called super()'s setup method.
198
+ """
199
+ obj.__is_overridden__ = False
200
+ return obj
201
+
202
+
203
+ def is_overridden(obj):
204
+ """Check whether a function has been overridden.
205
+
206
+ Note, this only works for API calls decorated with OverrideToImplementCustomLogic
207
+ or OverrideToImplementCustomLogic_CallToSuperRecommended.
208
+ """
209
+ return getattr(obj, "__is_overridden__", True)
210
+
211
+
212
+ # Backward compatibility.
213
+ Deprecated = Deprecated
.venv/lib/python3.11/site-packages/ray/rllib/utils/checkpoints.py ADDED
@@ -0,0 +1,1045 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import inspect
3
+ import json
4
+ import logging
5
+ import os
6
+ from packaging import version
7
+ import pathlib
8
+ import re
9
+ import tempfile
10
+ from types import MappingProxyType
11
+ from typing import Any, Collection, Dict, List, Optional, Tuple, Union
12
+
13
+ import pyarrow.fs
14
+
15
+ import ray
16
+ import ray.cloudpickle as pickle
17
+ from ray.rllib.core import (
18
+ COMPONENT_LEARNER,
19
+ COMPONENT_LEARNER_GROUP,
20
+ COMPONENT_RL_MODULE,
21
+ )
22
+ from ray.rllib.utils import force_list
23
+ from ray.rllib.utils.actor_manager import FaultTolerantActorManager
24
+ from ray.rllib.utils.annotations import (
25
+ OldAPIStack,
26
+ OverrideToImplementCustomLogic_CallToSuperRecommended,
27
+ )
28
+ from ray.rllib.utils.serialization import NOT_SERIALIZABLE, serialize_type
29
+ from ray.rllib.utils.typing import StateDict
30
+ from ray.train import Checkpoint
31
+ from ray.tune.utils.file_transfer import sync_dir_between_nodes
32
+ from ray.util import log_once
33
+ from ray.util.annotations import PublicAPI
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+ # The current checkpoint version used by RLlib for Algorithm and Policy checkpoints.
38
+ # History:
39
+ # 0.1: Ray 2.0.0
40
+ # A single `checkpoint-[iter num]` file for Algorithm checkpoints
41
+ # within the checkpoint directory. Policy checkpoints not supported across all
42
+ # DL frameworks.
43
+
44
+ # 1.0: Ray >=2.1.0
45
+ # An algorithm_state.pkl file for the state of the Algorithm (excluding
46
+ # individual policy states).
47
+ # One sub-dir inside the "policies" sub-dir for each policy with a
48
+ # dedicated policy_state.pkl in it for the policy state.
49
+
50
+ # 1.1: Same as 1.0, but has a new "format" field in the rllib_checkpoint.json file
51
+ # indicating, whether the checkpoint is `cloudpickle` (default) or `msgpack`.
52
+
53
+ # 1.2: Introduces the checkpoint for the new Learner API if the Learner API is enabled.
54
+
55
+ # 2.0: Introduces the Checkpointable API for all components on the new API stack
56
+ # (if the Learner-, RLModule, EnvRunner, and ConnectorV2 APIs are enabled).
57
+
58
+ CHECKPOINT_VERSION = version.Version("1.1")
59
+ CHECKPOINT_VERSION_LEARNER_AND_ENV_RUNNER = version.Version("2.1")
60
+
61
+
62
+ @PublicAPI(stability="alpha")
63
+ class Checkpointable(abc.ABC):
64
+ """Abstract base class for a component of RLlib that can be checkpointed to disk.
65
+
66
+ Subclasses must implement the following APIs:
67
+ - save_to_path()
68
+ - restore_from_path()
69
+ - from_checkpoint()
70
+ - get_state()
71
+ - set_state()
72
+ - get_ctor_args_and_kwargs()
73
+ - get_metadata()
74
+ - get_checkpointable_components()
75
+ """
76
+
77
+ # The state file for the implementing class.
78
+ # This file contains any state information that does NOT belong to any subcomponent
79
+ # of the implementing class (which are `Checkpointable` themselves and thus should
80
+ # have their own state- and metadata files).
81
+ # After a `save_to_path([path])` this file can be found directly in: `path/`.
82
+ STATE_FILE_NAME = "state"
83
+
84
+ # The filename of the pickle file that contains the class information of the
85
+ # Checkpointable as well as all constructor args to be passed to such a class in
86
+ # order to construct a new instance.
87
+ CLASS_AND_CTOR_ARGS_FILE_NAME = "class_and_ctor_args.pkl"
88
+
89
+ # Subclasses may set this to their own metadata filename.
90
+ # The dict returned by self.get_metadata() is stored in this JSON file.
91
+ METADATA_FILE_NAME = "metadata.json"
92
+
93
+ def save_to_path(
94
+ self,
95
+ path: Optional[Union[str, pathlib.Path]] = None,
96
+ *,
97
+ state: Optional[StateDict] = None,
98
+ filesystem: Optional["pyarrow.fs.FileSystem"] = None,
99
+ use_msgpack: bool = False,
100
+ ) -> str:
101
+ """Saves the state of the implementing class (or `state`) to `path`.
102
+
103
+ The state of the implementing class is always saved in the following format:
104
+
105
+ .. testcode::
106
+ :skipif: True
107
+
108
+ path/
109
+ [component1]/
110
+ [component1 subcomponentA]/
111
+ ...
112
+ [component1 subcomponentB]/
113
+ ...
114
+ [component2]/
115
+ ...
116
+ [cls.METADATA_FILE_NAME] (json)
117
+ [cls.STATE_FILE_NAME] (pkl|msgpack)
118
+
119
+ The main logic is to loop through all subcomponents of this Checkpointable
120
+ and call their respective `save_to_path` methods. Then save the remaining
121
+ (non subcomponent) state to this Checkpointable's STATE_FILE_NAME.
122
+ In the exception that a component is a FaultTolerantActorManager instance,
123
+ instead of calling `save_to_path` directly on that manager, the first healthy
124
+ actor is interpreted as the component and its `save_to_path` method is called.
125
+ Even if that actor is located on another node, the created file is automatically
126
+ synced to the local node.
127
+
128
+ Args:
129
+ path: The path to the directory to save the state of the implementing class
130
+ to. If `path` doesn't exist or is None, then a new directory will be
131
+ created (and returned).
132
+ state: An optional state dict to be used instead of getting a new state of
133
+ the implementing class through `self.get_state()`.
134
+ filesystem: PyArrow FileSystem to use to access data at the `path`.
135
+ If not specified, this is inferred from the URI scheme of `path`.
136
+ use_msgpack: Whether the state file should be written using msgpack and
137
+ msgpack_numpy (file extension is `.msgpack`), rather than pickle (file
138
+ extension is `.pkl`).
139
+
140
+ Returns:
141
+ The path (str) where the state has been saved.
142
+ """
143
+
144
+ # If no path is given create a local temporary directory.
145
+ if path is None:
146
+ import uuid
147
+
148
+ # Get the location of the temporary directory on the OS.
149
+ tmp_dir = pathlib.Path(tempfile.gettempdir())
150
+ # Create a random directory name.
151
+ random_dir_name = str(uuid.uuid4())
152
+ # Create the path, but do not craet the directory on the
153
+ # filesystem, yet. This is done by `PyArrow`.
154
+ path = path or tmp_dir / random_dir_name
155
+
156
+ # We need a string path for `pyarrow.fs.FileSystem.from_uri`.
157
+ path = path if isinstance(path, str) else path.as_posix()
158
+
159
+ # If we have no filesystem, figure it out.
160
+ if path and not filesystem:
161
+ # Note the path needs to be a path that is relative to the
162
+ # filesystem (e.g. `gs://tmp/...` -> `tmp/...`).
163
+ filesystem, path = pyarrow.fs.FileSystem.from_uri(path)
164
+
165
+ # Make sure, path exists.
166
+ filesystem.create_dir(path, recursive=True)
167
+
168
+ # Convert to `pathlib.Path` for easy handling.
169
+ path = pathlib.Path(path)
170
+
171
+ # Write metadata file to disk.
172
+ metadata = self.get_metadata()
173
+ if "checkpoint_version" not in metadata:
174
+ metadata["checkpoint_version"] = str(
175
+ CHECKPOINT_VERSION_LEARNER_AND_ENV_RUNNER
176
+ )
177
+ with filesystem.open_output_stream(
178
+ (path / self.METADATA_FILE_NAME).as_posix()
179
+ ) as f:
180
+ f.write(json.dumps(metadata).encode("utf-8"))
181
+
182
+ # Write the class and constructor args information to disk. Always use pickle
183
+ # for this, because this information contains classes and maybe other
184
+ # non-serializable data.
185
+ with filesystem.open_output_stream(
186
+ (path / self.CLASS_AND_CTOR_ARGS_FILE_NAME).as_posix()
187
+ ) as f:
188
+ pickle.dump(
189
+ {
190
+ "class": type(self),
191
+ "ctor_args_and_kwargs": self.get_ctor_args_and_kwargs(),
192
+ },
193
+ f,
194
+ )
195
+
196
+ # Get the entire state of this Checkpointable, or use provided `state`.
197
+ _state_provided = state is not None
198
+ state = state or self.get_state(
199
+ not_components=[c[0] for c in self.get_checkpointable_components()]
200
+ )
201
+
202
+ # Write components of `self` that themselves are `Checkpointable`.
203
+ for comp_name, comp in self.get_checkpointable_components():
204
+ # If subcomponent's name is not in `state`, ignore it and don't write this
205
+ # subcomponent's state to disk.
206
+ if _state_provided and comp_name not in state:
207
+ continue
208
+ comp_path = path / comp_name
209
+
210
+ # If component is an ActorManager, save the manager's first healthy
211
+ # actor's state to disk (even if it's on another node, in which case, we'll
212
+ # sync the generated file(s) back to this node).
213
+ if isinstance(comp, FaultTolerantActorManager):
214
+ actor_to_use = comp.healthy_actor_ids()[0]
215
+
216
+ def _get_ip(_=None):
217
+ import ray
218
+
219
+ return ray.util.get_node_ip_address()
220
+
221
+ _result = next(
222
+ iter(
223
+ comp.foreach_actor(
224
+ _get_ip,
225
+ remote_actor_ids=[actor_to_use],
226
+ )
227
+ )
228
+ )
229
+ if not _result.ok:
230
+ raise _result.get()
231
+ worker_ip_addr = _result.get()
232
+ self_ip_addr = _get_ip()
233
+
234
+ # Save the state to a temporary location on the `actor_to_use`'s
235
+ # node.
236
+ comp_state_ref = None
237
+ if _state_provided:
238
+ comp_state_ref = ray.put(state.pop(comp_name))
239
+
240
+ if worker_ip_addr == self_ip_addr:
241
+ comp.foreach_actor(
242
+ lambda w, _path=comp_path, _state=comp_state_ref, _use_msgpack=use_msgpack: ( # noqa
243
+ w.save_to_path(
244
+ _path,
245
+ state=(
246
+ ray.get(_state)
247
+ if _state is not None
248
+ else w.get_state()
249
+ ),
250
+ use_msgpack=_use_msgpack,
251
+ )
252
+ ),
253
+ remote_actor_ids=[actor_to_use],
254
+ )
255
+ else:
256
+ # Save the checkpoint to the temporary directory on the worker.
257
+ def _save(w, _state=comp_state_ref, _use_msgpack=use_msgpack):
258
+ import tempfile
259
+
260
+ # Create a temporary directory on the worker.
261
+ tmpdir = tempfile.mkdtemp()
262
+ w.save_to_path(
263
+ tmpdir,
264
+ state=(
265
+ ray.get(_state) if _state is not None else w.get_state()
266
+ ),
267
+ use_msgpack=_use_msgpack,
268
+ )
269
+ return tmpdir
270
+
271
+ _result = next(
272
+ iter(comp.foreach_actor(_save, remote_actor_ids=[actor_to_use]))
273
+ )
274
+ if not _result.ok:
275
+ raise _result.get()
276
+ worker_temp_dir = _result.get()
277
+
278
+ # Sync the temporary directory from the worker to this node.
279
+ sync_dir_between_nodes(
280
+ worker_ip_addr,
281
+ worker_temp_dir,
282
+ self_ip_addr,
283
+ str(comp_path),
284
+ )
285
+
286
+ # Remove the temporary directory on the worker.
287
+ def _rmdir(_, _dir=worker_temp_dir):
288
+ import shutil
289
+
290
+ shutil.rmtree(_dir)
291
+
292
+ comp.foreach_actor(_rmdir, remote_actor_ids=[actor_to_use])
293
+
294
+ # Local component (instance stored in a property of `self`).
295
+ else:
296
+ if _state_provided:
297
+ comp_state = state.pop(comp_name)
298
+ else:
299
+ comp_state = self.get_state(components=comp_name)[comp_name]
300
+ # By providing the `state` arg, we make sure that the component does not
301
+ # have to call its own `get_state()` anymore, but uses what's provided
302
+ # here.
303
+ comp.save_to_path(
304
+ comp_path,
305
+ filesystem=filesystem,
306
+ state=comp_state,
307
+ use_msgpack=use_msgpack,
308
+ )
309
+
310
+ # Write all the remaining state to disk.
311
+ filename = path / (
312
+ self.STATE_FILE_NAME + (".msgpack" if use_msgpack else ".pkl")
313
+ )
314
+ with filesystem.open_output_stream(filename.as_posix()) as f:
315
+ if use_msgpack:
316
+ msgpack = try_import_msgpack(error=True)
317
+ msgpack.dump(state, f)
318
+ else:
319
+ pickle.dump(state, f)
320
+
321
+ return str(path)
322
+
323
def restore_from_path(
    self,
    path: Union[str, pathlib.Path],
    *,
    component: Optional[str] = None,
    filesystem: Optional["pyarrow.fs.FileSystem"] = None,
    **kwargs,
) -> None:
    """Restores the state of the implementing class from the given path.

    If the `component` arg is provided, `path` refers to a checkpoint of a
    subcomponent of `self`, thus allowing the user to load only the subcomponent's
    state into `self` without affecting any of the other state information (for
    example, loading only the NN state into a Checkpointable, which contains such
    an NN, but also has other state information that should NOT be changed by
    calling this method).

    The given `path` should have the following structure and contain the following
    files:

    .. testcode::
        :skipif: True

        path/
            [component1]/
                [component1 subcomponentA]/
                    ...
                [component1 subcomponentB]/
                    ...
            [component2]/
                ...
            [cls.METADATA_FILE_NAME] (json)
            [cls.STATE_FILE_NAME] (pkl|msgpack)

    Note that the self.METADATA_FILE_NAME file is not required to restore the state.

    Args:
        path: The path to load the implementing class' state from or to load the
            state of only one subcomponent's state of the implementing class (if
            `component` is provided).
        component: If provided, `path` is interpreted as the checkpoint path of only
            the subcomponent and thus, only that subcomponent's state is
            restored/loaded. All other state of `self` remains unchanged in this
            case.
        filesystem: PyArrow FileSystem to use to access data at the `path`. If not
            specified, this is inferred from the URI scheme of `path`.
        **kwargs: Forward compatibility kwargs.

    Raises:
        FileNotFoundError: If `path` does not exist in the (inferred) filesystem.
    """
    # We need a string path for `pyarrow.fs.FileSystem.from_uri()`.
    path = path if isinstance(path, str) else path.as_posix()

    if path and not filesystem:
        # Note the path needs to be a path that is relative to the
        # filesystem (e.g. `gs://tmp/...` -> `tmp/...`).
        filesystem, path = pyarrow.fs.FileSystem.from_uri(path)
    # Only here convert to a `Path` instance b/c otherwise the cloud path gets
    # broken (i.e. 'gs://' -> 'gs:/'). BUGFIX: This conversion must also happen
    # when a `filesystem` WAS passed in; previously `path` then remained a plain
    # `str` and the `/`-joins and `.as_posix()` calls below raised TypeError/
    # AttributeError.
    path = pathlib.Path(path)

    if not _exists_at_fs_path(filesystem, path.as_posix()):
        raise FileNotFoundError(f"`path` ({path}) not found!")

    # Restore components of `self` that themselves are `Checkpointable`.
    # Remember the current set of component names so we can detect components
    # newly created by `set_state()` further below.
    orig_comp_names = {c[0] for c in self.get_checkpointable_components()}
    self._restore_all_subcomponents_from_path(
        path, filesystem, component=component, **kwargs
    )

    # Restore the "base" state (not individual subcomponents).
    if component is None:
        filename = path / self.STATE_FILE_NAME
        # Prefer a msgpack state file, if present; otherwise fall back to pickle.
        if filename.with_suffix(".msgpack").is_file():
            msgpack = try_import_msgpack(error=True)
            with filesystem.open_input_stream(
                filename.with_suffix(".msgpack").as_posix()
            ) as f:
                state = msgpack.load(f, strict_map_key=False)
        else:
            with filesystem.open_input_stream(
                filename.with_suffix(".pkl").as_posix()
            ) as f:
                state = pickle.load(f)
        self.set_state(state)

        # `set_state()` may have created new subcomponents (e.g. a new module
        # added to a MultiRLModule). Restore these newly appeared components
        # from disk as well.
        new_comp_names = {c[0] for c in self.get_checkpointable_components()}
        diff_comp_names = new_comp_names - orig_comp_names
        if diff_comp_names:
            self._restore_all_subcomponents_from_path(
                path, filesystem, only_comp_names=diff_comp_names, **kwargs
            )
412
+
413
@classmethod
def from_checkpoint(
    cls,
    path: Union[str, pathlib.Path],
    filesystem: Optional["pyarrow.fs.FileSystem"] = None,
    **kwargs,
) -> "Checkpointable":
    """Creates a new Checkpointable instance from the given location and returns it.

    Args:
        path: The checkpoint path to load (a) the information on how to construct
            a new instance of the implementing class and (b) the state to restore
            the created instance to.
        filesystem: PyArrow FileSystem to use to access data at the `path`. If not
            specified, this is inferred from the URI scheme of `path`.
        kwargs: Forward compatibility kwargs. Note that these kwargs are sent to
            each subcomponent's `from_checkpoint()` call.

    Returns:
        A new instance of the implementing class, already set to the state stored
        under `path`.

    Raises:
        ValueError: If the class stored in the checkpoint is not a subclass of
            `cls` or not a `Checkpointable` implementer.
    """
    # We need a string path for the `PyArrow` filesystem.
    path = path if isinstance(path, str) else path.as_posix()

    # If no filesystem is passed in create one.
    if path and not filesystem:
        # Note the path needs to be a path that is relative to the
        # filesystem (e.g. `gs://tmp/...` -> `tmp/...`).
        filesystem, path = pyarrow.fs.FileSystem.from_uri(path)
    # Only here convert to a `Path` instance b/c otherwise the cloud path gets
    # broken (i.e. 'gs://' -> 'gs:/'). BUGFIX: Also convert when a `filesystem`
    # WAS passed in; previously `path` then remained a plain `str` and the
    # `/`-join below raised a TypeError.
    path = pathlib.Path(path)

    # Get the class constructor to call and its args/kwargs.
    # Try reading the pickle file first.
    try:
        with filesystem.open_input_stream(
            (path / cls.CLASS_AND_CTOR_ARGS_FILE_NAME).as_posix()
        ) as f:
            ctor_info = pickle.load(f)
        ctor = ctor_info["class"]
        ctor_args = force_list(ctor_info["ctor_args_and_kwargs"][0])
        ctor_kwargs = ctor_info["ctor_args_and_kwargs"][1]

        # Inspect the ctor to see, which arguments in ctor_info should be replaced
        # with the user provided **kwargs.
        for i, (param_name, param) in enumerate(
            inspect.signature(ctor).parameters.items()
        ):
            if param_name in kwargs:
                val = kwargs.pop(param_name)
                # Positional-or-keyword params that were stored positionally get
                # overridden in place; everything else goes through kwargs.
                if (
                    param.kind == inspect._ParameterKind.POSITIONAL_OR_KEYWORD
                    and len(ctor_args) > i
                ):
                    ctor_args[i] = val
                else:
                    ctor_kwargs[param_name] = val

    # If the pickle file is from another python version (or unreadable), use
    # the user-provided args instead.
    except Exception:
        # Use class that this method was called on.
        ctor = cls
        # Use only user provided **kwargs.
        ctor_args = []
        ctor_kwargs = kwargs

    # Check, whether the constructor actually goes together with `cls`.
    if not issubclass(ctor, cls):
        raise ValueError(
            f"The class ({ctor}) stored in checkpoint ({path}) does not seem to be "
            f"a subclass of `cls` ({cls})!"
        )
    elif not issubclass(ctor, Checkpointable):
        raise ValueError(
            f"The class ({ctor}) stored in checkpoint ({path}) does not seem to be "
            "an implementer of the `Checkpointable` API!"
        )

    # Construct the initial object (without any particular state).
    obj = ctor(*ctor_args, **ctor_kwargs)
    # Restore the state of the constructed object.
    obj.restore_from_path(path, filesystem=filesystem, **kwargs)
    # Return the new object.
    return obj
500
+
501
@abc.abstractmethod
def get_state(
    self,
    components: Optional[Union[str, Collection[str]]] = None,
    *,
    not_components: Optional[Union[str, Collection[str]]] = None,
    **kwargs,
) -> StateDict:
    """Returns the implementing class's current state as a dict.

    The returned dict must only contain msgpack-serializable data if you want to
    use the `AlgorithmConfig._msgpack_checkpoints` option. Consider returning your
    non msgpack-serializable data from the `Checkpointable.get_ctor_args_and_kwargs`
    method, instead.

    Args:
        components: An optional collection of string keys to be included in the
            returned state. This might be useful, if getting certain components
            of the state is expensive (e.g. reading/compiling the weights of a large
            NN) and at the same time, these components are not required by the
            caller. Nested components are addressed via "/"-separated paths,
            e.g. "comp/subcomp" (see `_check_component`/`_get_subcomponents`).
        not_components: An optional list of string keys to be excluded in the
            returned state, even if the same string is part of `components`.
            This is useful to get the complete state of the class, except
            one or a few components.
        kwargs: Forward-compatibility kwargs.

    Returns:
        The current state of the implementing class (or only the `components`
        specified, w/o those in `not_components`).
    """
532
+
533
@abc.abstractmethod
def set_state(self, state: StateDict) -> None:
    """Sets the implementing class' state to the given state dict.

    If component keys are missing in `state`, these components of the implementing
    class will not be updated/set (partial updates are therefore allowed).

    Args:
        state: The state dict to restore the state from. Maps component keys
            to the corresponding subcomponent's own state.
    """
544
+
545
@abc.abstractmethod
def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
    """Returns the args/kwargs used to create `self` from its constructor.

    These are pickled into `self.CLASS_AND_CTOR_ARGS_FILE_NAME` by
    `save_to_path()` and used by `from_checkpoint()` to re-instantiate the class.

    Returns:
        A tuple of the args (as a tuple) and kwargs (as a Dict[str, Any]) used to
        construct `self` from its class constructor.
    """
553
+
554
@OverrideToImplementCustomLogic_CallToSuperRecommended
def get_metadata(self) -> Dict:
    """Returns JSON writable metadata further describing the implementing class.

    Note that this metadata is NOT part of any state and is thus NOT needed to
    restore the state of a Checkpointable instance from a directory. Rather, the
    metadata is written into `self.METADATA_FILE_NAME` when calling
    `self.save_to_path()`, purely for the user's convenience.

    Returns:
        A JSON-encodable dict of metadata information.
    """
    # Record the file names used by this checkpoint plus the exact Ray build
    # that produced it.
    metadata = {
        "class_and_ctor_args_file": self.CLASS_AND_CTOR_ARGS_FILE_NAME,
        "state_file": self.STATE_FILE_NAME,
        "ray_version": ray.__version__,
        "ray_commit": ray.__commit__,
    }
    return metadata
572
+
573
def get_checkpointable_components(self) -> List[Tuple[str, "Checkpointable"]]:
    """Returns the implementing class's own Checkpointable subcomponents.

    Returns:
        A list of 2-tuples (name, subcomponent) describing the implementing class'
        subcomponents, all of which have to be `Checkpointable` themselves and
        whose state is therefore written into subdirectories (rather than the main
        state file (self.STATE_FILE_NAME) when calling `self.save_to_path()`).
        The default implementation declares no subcomponents.
    """
    return []
583
+
584
def _check_component(self, name, components, not_components) -> bool:
    """Decides whether component `name` passes the include/exclude filters.

    A component is included if `components` is None (no filter), if its name
    appears literally in `components`, or if any entry in `components` addresses
    one of its subcomponents via a "name/..." path -- unless the name is listed
    in `not_components`.
    """
    include_list = force_list(components)
    exclude_list = force_list(not_components)

    # Explicit exclusion wins.
    if not_components is not None and name in exclude_list:
        return False
    # No include filter -> everything (not excluded) passes.
    if components is None:
        return True
    # Included either directly or through a "name/sub..." path.
    return name in include_list or any(
        c.startswith(name + "/") for c in include_list
    )
594
+
595
def _get_subcomponents(self, name, components):
    """Extracts all entries of `components` addressing subcomponents of `name`.

    Entries of the form "name/rest" are returned with the "name/" prefix
    stripped. Returns None if `components` is None or no entry matches.
    """
    if components is None:
        return None

    prefix = name + "/"
    matches = [
        c[len(prefix):] for c in force_list(components) if c.startswith(prefix)
    ]
    return matches or None
606
+
607
def _restore_all_subcomponents_from_path(
    self, path, filesystem, only_comp_names=None, component=None, **kwargs
):
    """Restores all (or selected) Checkpointable subcomponents from `path`.

    Args:
        path: Root checkpoint directory (a pathlib.Path).
        filesystem: PyArrow FileSystem used to access `path`.
        only_comp_names: If given, only components whose names are in this
            collection are restored (used for components newly created by
            `set_state()`).
        component: If given, restore only this (possibly "/"-nested) component;
            `path` is then interpreted as that component's own checkpoint dir.
        **kwargs: Forwarded to each subcomponent's `restore_from_path()`.
    """
    for comp_name, comp in self.get_checkpointable_components():
        if only_comp_names is not None and comp_name not in only_comp_names:
            continue

        # The value of the `component` argument for the upcoming
        # `[subcomponent].restore_from_path(.., component=..)` call.
        comp_arg = None

        if component is None:
            comp_dir = path / comp_name
            # If subcomponent's dir is not in path, ignore it and don't restore this
            # subcomponent's state from disk.
            if not _exists_at_fs_path(filesystem, comp_dir.as_posix()):
                continue
        else:
            comp_dir = path

            # `component` is a path that starts with `comp` -> Remove the name of
            # `comp` from the `component` arg in the upcoming call to `restore_..`.
            if component.startswith(comp_name + "/"):
                comp_arg = component[len(comp_name) + 1 :]
            # `component` has nothing to do with `comp` -> Skip.
            elif component != comp_name:
                continue

        # If component is an ActorManager, restore all the manager's healthy
        # actors' states from disk (even if they are on another node, in which case,
        # we'll sync checkpoint file(s) to the respective node).
        if isinstance(comp, FaultTolerantActorManager):
            head_node_ip = ray.util.get_node_ip_address()
            all_healthy_actors = comp.healthy_actor_ids()

            # NOTE: All per-call values are bound as defaults so the closure is
            # safe to ship to remote actors; `kwargs` is wrapped in a read-only
            # MappingProxyType to avoid accidental mutation on the workers.
            def _restore(
                w,
                _kwargs=MappingProxyType(kwargs),
                _path=comp_dir,
                _head_ip=head_node_ip,
                _comp_arg=comp_arg,
            ):
                # Imports done remotely (this function runs on the actor).
                import ray
                import tempfile

                worker_node_ip = ray.util.get_node_ip_address()
                # If the worker is on the same node as the head, load the checkpoint
                # directly from the path otherwise sync the checkpoint from the head
                # to the worker and load it from there.
                if worker_node_ip == _head_ip:
                    w.restore_from_path(_path, component=_comp_arg, **_kwargs)
                else:
                    with tempfile.TemporaryDirectory() as temp_dir:
                        sync_dir_between_nodes(
                            _head_ip, _path, worker_node_ip, temp_dir
                        )
                        w.restore_from_path(
                            temp_dir, component=_comp_arg, **_kwargs
                        )

            comp.foreach_actor(_restore, remote_actor_ids=all_healthy_actors)

        # Call `restore_from_path()` on local subcomponent, thereby passing in the
        # **kwargs.
        else:
            comp.restore_from_path(
                comp_dir, filesystem=filesystem, component=comp_arg, **kwargs
            )
675
+
676
+
677
def _exists_at_fs_path(fs: pyarrow.fs.FileSystem, path: str) -> bool:
    """Returns `True` if the path can be found in the filesystem."""
    # `get_file_info()` reports FileType.NotFound (rather than raising) for
    # missing paths.
    return fs.get_file_info(path).type != pyarrow.fs.FileType.NotFound
681
+
682
+
683
def _is_dir(file_info: pyarrow.fs.FileInfo) -> bool:
    """Returns `True`, if the given file info describes a directory."""
    return file_info.type == pyarrow.fs.FileType.Directory
686
+
687
+
688
@OldAPIStack
def get_checkpoint_info(
    checkpoint: Union[str, Checkpoint],
    filesystem: Optional["pyarrow.fs.FileSystem"] = None,
) -> Dict[str, Any]:
    """Returns a dict with information about an Algorithm/Policy checkpoint.

    If the given checkpoint is a >=v1.0 checkpoint directory, try reading all
    information from the contained `rllib_checkpoint.json` file.

    Args:
        checkpoint: The checkpoint directory (str) or an AIR Checkpoint object.
        filesystem: PyArrow FileSystem to use to access data at the `checkpoint`. If not
            specified, this is inferred from the URI scheme provided by `checkpoint`.

    Returns:
        A dict containing the keys:
        "type": One of "Policy" or "Algorithm".
        "checkpoint_version": A version tuple, e.g. v1.0, indicating the checkpoint
        version. This will help RLlib to remain backward compatible wrt. future
        Ray and checkpoint versions.
        "checkpoint_dir": The directory with all the checkpoint files in it. This might
        be the same as the incoming `checkpoint` arg.
        "state_file": The main file with the Algorithm/Policy's state information in it.
        This is usually a pickle-encoded file.
        "policy_ids": An optional set of PolicyIDs in case we are dealing with an
        Algorithm checkpoint. None if `checkpoint` is a Policy checkpoint.

    Raises:
        ValueError: If `checkpoint` does not exist or contains no recognizable
            state file.
    """
    # Default checkpoint info.
    info = {
        "type": "Algorithm",
        "format": "cloudpickle",
        "checkpoint_version": CHECKPOINT_VERSION,
        "checkpoint_dir": None,
        "state_file": None,
        "policy_ids": None,
        "module_ids": None,
    }

    # `checkpoint` is a Checkpoint instance: Translate to directory and continue.
    if isinstance(checkpoint, Checkpoint):
        checkpoint = checkpoint.to_directory()

    if checkpoint and not filesystem:
        # Note the path needs to be a path that is relative to the
        # filesystem (e.g. `gs://tmp/...` -> `tmp/...`).
        filesystem, checkpoint = pyarrow.fs.FileSystem.from_uri(checkpoint)
    # Only here convert to a `Path` instance b/c otherwise the cloud path gets
    # broken (i.e. 'gs://' -> 'gs:/'). BUGFIX: Also convert when a `filesystem`
    # WAS passed in; previously `checkpoint` then remained a plain `str` and the
    # `.as_posix()` calls below raised AttributeError.
    checkpoint = pathlib.Path(checkpoint)

    # Checkpoint is dir.
    if _exists_at_fs_path(filesystem, checkpoint.as_posix()) and _is_dir(
        filesystem.get_file_info(checkpoint.as_posix())
    ):
        info.update({"checkpoint_dir": str(checkpoint)})

        # Figure out whether this is an older checkpoint format
        # (with a `checkpoint-\d+` file in it).
        file_info_list = filesystem.get_file_info(
            pyarrow.fs.FileSelector(checkpoint.as_posix(), recursive=False)
        )
        for file_info in file_info_list:
            if file_info.is_file:
                if re.match("checkpoint-\\d+", file_info.base_name):
                    info.update(
                        {
                            "checkpoint_version": version.Version("0.1"),
                            "state_file": str(file_info.base_name),
                        }
                    )
                    return info

        # No old checkpoint file found.

        # If rllib_checkpoint.json file present, read available information from it
        # and then continue with the checkpoint analysis (possibly overriding further
        # information).
        if _exists_at_fs_path(
            filesystem, (checkpoint / "rllib_checkpoint.json").as_posix()
        ):
            with filesystem.open_input_stream(
                (checkpoint / "rllib_checkpoint.json").as_posix()
            ) as f:
                rllib_checkpoint_info = json.load(fp=f)
            # Stored version is a plain string -> parse into a Version object.
            if "checkpoint_version" in rllib_checkpoint_info:
                rllib_checkpoint_info["checkpoint_version"] = version.Version(
                    rllib_checkpoint_info["checkpoint_version"]
                )
            info.update(rllib_checkpoint_info)
        else:
            # No rllib_checkpoint.json file present: Warn and continue trying to figure
            # out checkpoint info ourselves.
            if log_once("no_rllib_checkpoint_json_file"):
                logger.warning(
                    "No `rllib_checkpoint.json` file found in checkpoint directory "
                    f"{checkpoint}! Trying to extract checkpoint info from other files "
                    f"found in that dir."
                )

        # Policy checkpoint file found.
        for extension in ["pkl", "msgpck"]:
            if _exists_at_fs_path(
                filesystem, (checkpoint / ("policy_state." + extension)).as_posix()
            ):
                info.update(
                    {
                        "type": "Policy",
                        "format": "cloudpickle" if extension == "pkl" else "msgpack",
                        "checkpoint_version": CHECKPOINT_VERSION,
                        "state_file": str(checkpoint / f"policy_state.{extension}"),
                    }
                )
                return info

        # Valid Algorithm checkpoint >v0 file found?
        format = None
        for extension in ["pkl", "msgpck", "msgpack"]:
            state_file = checkpoint / f"algorithm_state.{extension}"
            if (
                _exists_at_fs_path(filesystem, state_file.as_posix())
                and filesystem.get_file_info(state_file.as_posix()).is_file
            ):
                format = "cloudpickle" if extension == "pkl" else "msgpack"
                break
        if format is None:
            raise ValueError(
                "Given checkpoint does not seem to be valid! No file with the name "
                "`algorithm_state.[pkl|msgpack|msgpck]` (or `checkpoint-[0-9]+`) found."
            )

        info.update(
            {
                "format": format,
                "state_file": str(state_file),
            }
        )

        # Collect all policy IDs in the sub-dir "policies/".
        policies_dir = checkpoint / "policies"
        if _exists_at_fs_path(filesystem, policies_dir.as_posix()) and _is_dir(
            filesystem.get_file_info(policies_dir.as_posix())
        ):
            policy_ids = set()
            file_info_list = filesystem.get_file_info(
                pyarrow.fs.FileSelector(policies_dir.as_posix(), recursive=False)
            )
            for file_info in file_info_list:
                policy_ids.add(file_info.base_name)
            info.update({"policy_ids": policy_ids})

        # Collect all module IDs in the sub-dir "learner/module_state/".
        modules_dir = (
            checkpoint
            / COMPONENT_LEARNER_GROUP
            / COMPONENT_LEARNER
            / COMPONENT_RL_MODULE
        )
        # BUGFIX: Check existence of `modules_dir` (not of `checkpoint`, which
        # trivially exists at this point), mirroring the `policies_dir` check
        # above.
        if _exists_at_fs_path(filesystem, modules_dir.as_posix()) and _is_dir(
            filesystem.get_file_info(modules_dir.as_posix())
        ):
            module_ids = set()
            file_info_list = filesystem.get_file_info(
                pyarrow.fs.FileSelector(modules_dir.as_posix(), recursive=False)
            )
            for file_info in file_info_list:
                # Only add subdirs (those are the ones where the RLModule data
                # is stored, not files (could be json metadata files).
                module_dir = modules_dir / file_info.base_name
                if _is_dir(filesystem.get_file_info(module_dir.as_posix())):
                    module_ids.add(file_info.base_name)
            info.update({"module_ids": module_ids})

    # Checkpoint is a file: Use as-is (interpreting it as old Algorithm checkpoint
    # version).
    elif (
        _exists_at_fs_path(filesystem, checkpoint.as_posix())
        and filesystem.get_file_info(checkpoint.as_posix()).is_file
    ):
        info.update(
            {
                "checkpoint_version": version.Version("0.1"),
                "checkpoint_dir": str(checkpoint.parent),
                "state_file": str(checkpoint),
            }
        )

    else:
        raise ValueError(
            f"Given checkpoint ({str(checkpoint)}) not found! Must be a "
            "checkpoint directory (or a file for older checkpoint versions)."
        )

    return info
885
+
886
+
887
@OldAPIStack
def convert_to_msgpack_checkpoint(
    checkpoint: Union[str, Checkpoint],
    msgpack_checkpoint_dir: str,
) -> str:
    """Converts an Algorithm checkpoint (pickle based) to a msgpack based one.

    Msgpack has the advantage of being python version independent.

    Args:
        checkpoint: The directory, in which to find the Algorithm checkpoint (pickle
            based).
        msgpack_checkpoint_dir: The directory, in which to create the new msgpack
            based checkpoint. Created (incl. parents) if it does not exist yet.

    Returns:
        The directory in which the msgpack checkpoint has been created. Note that
        this is the same as `msgpack_checkpoint_dir`.
    """
    from ray.rllib.algorithms import Algorithm
    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
    from ray.rllib.core.rl_module import validate_module_id

    # Try to import msgpack and msgpack_numpy.
    msgpack = try_import_msgpack(error=True)

    # Restore the Algorithm using the python version dependent checkpoint.
    algo = Algorithm.from_checkpoint(checkpoint)
    state = algo.__getstate__()

    # Convert all code in state into serializable data.
    # Serialize the algorithm class.
    state["algorithm_class"] = serialize_type(state["algorithm_class"])
    # Serialize the algorithm's config object.
    if not isinstance(state["config"], dict):
        state["config"] = state["config"].serialize()
    else:
        state["config"] = AlgorithmConfig._serialize_dict(state["config"])

    # Extract policy states from worker state (Policies get their own
    # checkpoint sub-dirs).
    policy_states = {}
    if "worker" in state and "policy_states" in state["worker"]:
        policy_states = state["worker"].pop("policy_states", {})

    # BUGFIX: Only blank out the worker's callables if a "worker" entry exists
    # at all. Previously these two assignments were unconditional and raised a
    # KeyError for states without a "worker" key (the guarded pop above shows
    # that such states are expected).
    if "worker" in state:
        # Policy mapping fn.
        state["worker"]["policy_mapping_fn"] = NOT_SERIALIZABLE
        # Is Policy to train function.
        state["worker"]["is_policy_to_train"] = NOT_SERIALIZABLE

    # Add RLlib checkpoint version (as string).
    state["checkpoint_version"] = str(CHECKPOINT_VERSION)

    # BUGFIX: Make sure the target directory exists before writing into it
    # (matches `convert_to_msgpack_policy_checkpoint`, which already does this).
    os.makedirs(msgpack_checkpoint_dir, exist_ok=True)

    # Write state (w/o policies) to disk.
    state_file = os.path.join(msgpack_checkpoint_dir, "algorithm_state.msgpck")
    with open(state_file, "wb") as f:
        msgpack.dump(state, f)

    # Write rllib_checkpoint.json.
    with open(os.path.join(msgpack_checkpoint_dir, "rllib_checkpoint.json"), "w") as f:
        json.dump(
            {
                "type": "Algorithm",
                "checkpoint_version": state["checkpoint_version"],
                "format": "msgpack",
                "state_file": state_file,
                "policy_ids": list(policy_states.keys()),
                "ray_version": ray.__version__,
                "ray_commit": ray.__commit__,
            },
            f,
        )

    # Write individual policies to disk, each in their own subdirectory.
    for pid, policy_state in policy_states.items():
        # From here on, disallow policyIDs that would not work as directory names.
        validate_module_id(pid, error=True)
        policy_dir = os.path.join(msgpack_checkpoint_dir, "policies", pid)
        os.makedirs(policy_dir, exist_ok=True)
        policy = algo.get_policy(pid)
        policy.export_checkpoint(
            policy_dir,
            policy_state=policy_state,
            checkpoint_format="msgpack",
        )

    # Release all resources used by the Algorithm.
    algo.stop()

    return msgpack_checkpoint_dir
977
+
978
+
979
@OldAPIStack
def convert_to_msgpack_policy_checkpoint(
    policy_checkpoint: Union[str, Checkpoint],
    msgpack_checkpoint_dir: str,
) -> str:
    """Converts a Policy checkpoint (pickle based) to a msgpack based one.

    Msgpack has the advantage of being python version independent.

    Args:
        policy_checkpoint: The directory, in which to find the Policy checkpoint (pickle
            based).
        msgpack_checkpoint_dir: The directory, in which to create the new msgpack
            based checkpoint. Created (incl. parents) if it does not exist yet.

    Returns:
        The directory in which the msgpack checkpoint has been created. Note that
        this is the same as `msgpack_checkpoint_dir`.
    """
    # Local import to avoid a circular dependency at module-import time.
    from ray.rllib.policy.policy import Policy

    # Rebuild the Policy from the (pickle-based) checkpoint, then re-export it
    # in msgpack format.
    policy = Policy.from_checkpoint(policy_checkpoint)

    os.makedirs(msgpack_checkpoint_dir, exist_ok=True)
    policy.export_checkpoint(
        msgpack_checkpoint_dir,
        policy_state=policy.get_state(),
        checkpoint_format="msgpack",
    )

    # Release all resources used by the Policy.
    del policy

    return msgpack_checkpoint_dir
1013
+
1014
+
1015
@PublicAPI
def try_import_msgpack(error: bool = False):
    """Tries importing msgpack and msgpack_numpy and returns the patched msgpack module.

    Returns None if error is False and msgpack or msgpack_numpy is not installed.
    Raises an error, if error is True and the modules could not be imported.

    Args:
        error: Whether to raise an error if msgpack/msgpack_numpy cannot be imported.

    Returns:
        The `msgpack` module (with numpy support patched in), or None.

    Raises:
        ImportError: If error=True and msgpack/msgpack_numpy is not installed.
    """
    try:
        import msgpack
        import msgpack_numpy

        # Teach msgpack how to (de)serialize numpy arrays.
        msgpack_numpy.patch()
        return msgpack
    except Exception:
        # Import or patching failed: either raise or signal "unavailable".
        if not error:
            return None
        raise ImportError(
            "Could not import or setup msgpack and msgpack_numpy! "
            "Try running `pip install msgpack msgpack_numpy` first."
        )
.venv/lib/python3.11/site-packages/ray/rllib/utils/deprecation.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import logging
3
+ from typing import Optional, Union
4
+
5
+ from ray.util import log_once
6
+ from ray.util.annotations import _mark_annotated
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ # A constant to use for any configuration that should be deprecated
11
+ # (to check, whether this config has actually been assigned a proper value or
12
+ # not).
13
+ DEPRECATED_VALUE = -1
14
+
15
+
16
def deprecation_warning(
    old: str,
    new: Optional[str] = None,
    *,
    help: Optional[str] = None,
    error: Optional[Union[bool, Exception]] = None,
) -> None:
    """Warns (via the `logger` object) or throws a deprecation warning/error.

    Args:
        old: A description of the "thing" that is to be deprecated.
        new: A description of the new "thing" that replaces it.
        help: An optional help text to tell the user, what to
            do instead of using `old`.
        error: Whether or which exception to raise. If True, raise ValueError.
            If False, just warn. If `error` is-a subclass of Exception,
            raise that Exception.

    Raises:
        ValueError: If `error=True`.
        Exception: Of type `error`, iff `error` is a sub-class of `Exception`.
    """
    # Build the message: prefer the `new` hint over the free-form `help` text.
    if new:
        suffix = " Use `{}` instead.".format(new)
    elif help:
        suffix = f" {help}"
    else:
        suffix = ""
    msg = "`{}` has been deprecated.{}".format(old, suffix)

    # No error requested -> just log a warning and return.
    if not error:
        logger.warning(
            "DeprecationWarning: " + msg + " This will raise an error in the future!"
        )
        return

    # A custom exception class was handed in -> raise that one.
    if not isinstance(error, bool) and issubclass(error, Exception):
        raise error(msg)
    # error is a boolean, construct ValueError ourselves.
    raise ValueError(msg)
53
+
54
+
55
def Deprecated(old=None, *, new=None, help=None, error):
    """Decorator for documenting a deprecated class, method, or function.

    Automatically adds a `deprecation.deprecation_warning(old=...,
    error=False)` to not break existing code at this point to the decorated
    class' constructor, method, or function.

    In a next major release, this warning should then be made an error
    (by setting error=True), which means at this point that the
    class/method/function is no longer supported, but will still inform
    the user about the deprecation event.

    In a further major release, the class, method, function should be erased
    entirely from the codebase.

    Args:
        old: Name of the deprecated API (defaults to the decorated object's
            `__name__`).
        new: Name of the replacement API (if any).
        help: Free-form help text used when there is no direct replacement.
        error: Whether (or which exception) to raise when the deprecated API
            is used. See `deprecation_warning()`.

    .. testcode::
        :skipif: True

        from ray.rllib.utils.deprecation import Deprecated
        # Deprecated class: Patches the constructor to warn if the class is
        # used.
        @Deprecated(new="NewAndMuchCoolerClass", error=False)
        class OldAndUncoolClass:
            ...

        # Deprecated class method: Patches the method to warn if called.
        class StillCoolClass:
            ...
            @Deprecated(new="StillCoolClass.new_and_much_cooler_method()",
                        error=False)
            def old_and_uncool_method(self, uncool_arg):
                ...

        # Deprecated function: Patches the function to warn if called.
        @Deprecated(new="new_and_much_cooler_function", error=False)
        def old_and_uncool_function(*uncool_args):
            ...
    """
    # Local import: keeps this module's public import surface unchanged.
    from functools import wraps

    def _inner(obj):
        # A deprecated class.
        if inspect.isclass(obj):
            # Patch the class' init method to raise the warning/error.
            obj_init = obj.__init__

            # `wraps` preserves the original __init__'s metadata (name,
            # docstring, signature) on the patched version.
            @wraps(obj_init)
            def patched_init(*args, **kwargs):
                if log_once(old or obj.__name__):
                    deprecation_warning(
                        old=old or obj.__name__,
                        new=new,
                        help=help,
                        error=error,
                    )
                return obj_init(*args, **kwargs)

            obj.__init__ = patched_init
            _mark_annotated(obj)
            # Return the patched class (with the warning/error when
            # instantiated).
            return obj

        # A deprecated class method or function.
        # Patch with the warning/error at the beginning. `wraps` preserves the
        # wrapped callable's metadata so tools like `help()`/`inspect` and
        # docs keep working (previously lost, since the bare wrapper shadowed
        # `__name__`, `__doc__`, etc.).
        @wraps(obj)
        def _ctor(*args, **kwargs):
            if log_once(old or obj.__name__):
                deprecation_warning(
                    old=old or obj.__name__,
                    new=new,
                    help=help,
                    error=error,
                )
            # Call the deprecated method/function.
            return obj(*args, **kwargs)

        # Return the patched class method/function.
        return _ctor

    # Return the prepared decorator.
    return _inner
.venv/lib/python3.11/site-packages/ray/rllib/utils/error.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.utils.annotations import PublicAPI
2
+
3
+
4
+ @PublicAPI
5
+ class UnsupportedSpaceException(Exception):
6
+ """Error for an unsupported action or observation space."""
7
+
8
+ pass
9
+
10
+
11
+ @PublicAPI
12
+ class EnvError(Exception):
13
+ """Error if we encounter an error during RL environment validation."""
14
+
15
+ pass
16
+
17
+
18
+ @PublicAPI
19
+ class MultiAgentEnvError(Exception):
20
+ """Error if we encounter an error during MultiAgentEnv stepping/validation."""
21
+
22
+ pass
23
+
24
+
25
+ @PublicAPI
26
+ class NotSerializable(Exception):
27
+ """Error if we encounter objects that can't be serialized by ray."""
28
+
29
+ pass
30
+
31
+
32
+ # -------
33
+ # Error messages
34
+ # -------
35
+
36
+ # Message explaining there are no GPUs available for the
37
+ # num_gpus=n or num_gpus_per_env_runner=m settings.
38
+ ERR_MSG_NO_GPUS = """Found {} GPUs on your machine (GPU devices found: {})! If your
39
+ machine does not have any GPUs, you should set the config keys
40
+ `num_gpus_per_learner` and `num_gpus_per_env_runner` to 0. They may be set to
41
+ 1 by default for your particular RL algorithm."""
42
+
43
+ ERR_MSG_INVALID_ENV_DESCRIPTOR = """The env string you provided ('{}') is:
44
+ a) Not a supported or -installed environment.
45
+ b) Not a tune-registered environment creator.
46
+ c) Not a valid env class string.
47
+
48
+ Try one of the following:
49
+ a) For Atari support: `pip install gym[atari] autorom[accept-rom-license]`.
50
+ For PyBullet support: `pip install pybullet`.
51
+ b) To register your custom env, do `from ray import tune;
52
+ tune.register('[name]', lambda cfg: [return env obj from here using cfg])`.
53
+ Then in your config, do `config['env'] = [name]`.
54
+ c) Make sure you provide a fully qualified classpath, e.g.:
55
+ `ray.rllib.examples.envs.classes.repeat_after_me_env.RepeatAfterMeEnv`
56
+ """
57
+
58
+
59
+ ERR_MSG_OLD_GYM_API = """Your environment ({}) does not abide to the new gymnasium-style API!
60
+ From Ray 2.3 on, RLlib only supports the new (gym>=0.26 or gymnasium) Env APIs.
61
+ {}
62
+ Learn more about the most important changes here:
63
+ https://github.com/openai/gym and here: https://github.com/Farama-Foundation/Gymnasium
64
+
65
+ In order to fix this problem, do the following:
66
+
67
+ 1) Run `pip install gymnasium` on your command line.
68
+ 2) Change all your import statements in your code from
69
+ `import gym` -> `import gymnasium as gym` OR
70
+ `from gym.spaces import Discrete` -> `from gymnasium.spaces import Discrete`
71
+
72
+ For your custom (single agent) gym.Env classes:
73
+ 3.1) Either wrap your old Env class via the provided `from gymnasium.wrappers import
74
+ EnvCompatibility` wrapper class.
75
+ 3.2) Alternatively to 3.1:
76
+ - Change your `reset()` method to have the call signature 'def reset(self, *,
77
+ seed=None, options=None)'
78
+ - Return an additional info dict (empty dict should be fine) from your `reset()`
79
+ method.
80
+ - Return an additional `truncated` flag from your `step()` method (between `done` and
81
+ `info`). This flag should indicate, whether the episode was terminated prematurely
82
+ due to some time constraint or other kind of horizon setting.
83
+
84
+ For your custom RLlib `MultiAgentEnv` classes:
85
+ 4.1) Either wrap your old MultiAgentEnv via the provided
86
+ `from ray.rllib.env.wrappers.multi_agent_env_compatibility import
87
+ MultiAgentEnvCompatibility` wrapper class.
88
+ 4.2) Alternatively to 4.1:
89
+ - Change your `reset()` method to have the call signature
90
+ 'def reset(self, *, seed=None, options=None)'
91
+ - Return an additional per-agent info dict (empty dict should be fine) from your
92
+ `reset()` method.
93
+ - Rename `dones` into `terminateds` and only set this to True, if the episode is really
94
+ done (as opposed to has been terminated prematurely due to some horizon/time-limit
95
+ setting).
96
+ - Return an additional `truncateds` per-agent dictionary flag from your `step()`
97
+ method, including the `__all__` key (100% analogous to your `dones/terminateds`
98
+ per-agent dict).
99
+ Return this new `truncateds` dict between `dones/terminateds` and `infos`. This
100
+ flag should indicate, whether the episode (for some agent or all agents) was
101
+ terminated prematurely due to some time constraint or other kind of horizon setting.
102
+ """ # noqa
103
+
104
+
105
+ ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL = """Could not save keras model under self[TfPolicy].model.base_model!
106
+ This is either due to ..
107
+ a) .. this Policy's ModelV2 not having any `base_model` (tf.keras.Model) property
108
+ b) .. the ModelV2's `base_model` not being used by the Algorithm and thus its
109
+ variables not being properly initialized.
110
+ """ # noqa
111
+
112
+ ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL = """Could not save torch model under self[TorchPolicy].model!
113
+ This is most likely due to the fact that you are using an Algorithm that
114
+ uses a Catalog-generated TorchModelV2 subclass, which is torch.save() cannot pickle.
115
+ """ # noqa
116
+
117
+ # -------
118
+ # HOWTO_ strings can be added to any error/warning/into message
119
+ # to eplain to the user, how to actually fix the encountered problem.
120
+ # -------
121
+
122
+ # HOWTO change the RLlib config, depending on how user runs the job.
123
+ HOWTO_CHANGE_CONFIG = """
124
+ To change the config for `tune.Tuner().fit()` in a script: Modify the python dict
125
+ passed to `tune.Tuner(param_space=[...]).fit()`.
126
+ To change the config for an RLlib Algorithm instance: Modify the python dict
127
+ passed to the Algorithm's constructor, e.g. `PPO(config=[...])`.
128
+ """
.venv/lib/python3.11/site-packages/ray/rllib/utils/filter_manager.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Optional
3
+
4
+ import ray
5
+ from ray.rllib.utils.annotations import OldAPIStack
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
+ @OldAPIStack
11
+ class FilterManager:
12
+ """Manages filters and coordination across remote evaluators that expose
13
+ `get_filters` and `sync_filters`.
14
+ """
15
+
16
+ @staticmethod
17
+ def synchronize(
18
+ local_filters,
19
+ worker_set,
20
+ update_remote=True,
21
+ timeout_seconds: Optional[float] = None,
22
+ use_remote_data_for_update: bool = True,
23
+ ):
24
+ """Aggregates filters from remote workers (if use_remote_data_for_update=True).
25
+
26
+ Local copy is updated and then broadcasted to all remote evaluators
27
+ (if `update_remote` is True).
28
+
29
+ Args:
30
+ local_filters: Filters to be synchronized.
31
+ worker_set: EnvRunnerGroup with remote EnvRunners with filters.
32
+ update_remote: Whether to push updates from the local filters to the remote
33
+ workers' filters.
34
+ timeout_seconds: How long to wait for filter to get or set filters
35
+ use_remote_data_for_update: Whether to use the `worker_set`'s remote workers
36
+ to update the local filters. If False, stats from the remote workers
37
+ will not be used and discarded.
38
+ """
39
+ # No sync/update required in either direction -> Early out.
40
+ if not (update_remote or use_remote_data_for_update):
41
+ return
42
+
43
+ logger.debug(f"Synchronizing filters: {local_filters}")
44
+
45
+ # Get the filters from the remote workers.
46
+ remote_filters = worker_set.foreach_env_runner(
47
+ func=lambda worker: worker.get_filters(flush_after=True),
48
+ local_env_runner=False,
49
+ timeout_seconds=timeout_seconds,
50
+ )
51
+ if len(remote_filters) != worker_set.num_healthy_remote_workers():
52
+ logger.error(
53
+ "Failed to get remote filters from a rollout worker in "
54
+ "FilterManager! "
55
+ "Filtered metrics may be computed, but filtered wrong."
56
+ )
57
+
58
+ # Should we utilize the remote workers' filter stats to update the local
59
+ # filters?
60
+ if use_remote_data_for_update:
61
+ for rf in remote_filters:
62
+ for k in local_filters:
63
+ local_filters[k].apply_changes(rf[k], with_buffer=False)
64
+
65
+ # Should we update the remote workers' filters from the (now possibly synched)
66
+ # local filters?
67
+ if update_remote:
68
+ copies = {k: v.as_serializable() for k, v in local_filters.items()}
69
+ remote_copy = ray.put(copies)
70
+
71
+ logger.debug("Updating remote filters ...")
72
+ results = worker_set.foreach_env_runner(
73
+ func=lambda worker: worker.sync_filters(ray.get(remote_copy)),
74
+ local_env_runner=False,
75
+ timeout_seconds=timeout_seconds,
76
+ )
77
+ if len(results) != worker_set.num_healthy_remote_workers():
78
+ logger.error(
79
+ "Failed to set remote filters to a rollout worker in "
80
+ "FilterManager. "
81
+ "Filtered metrics may be computed, but filtered wrong."
82
+ )