koichi12 commited on
Commit
24f659d
·
verified ·
1 Parent(s): 6b42d14

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/ray/rllib/policy/__init__.py +13 -0
  2. .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/__init__.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/tf_policy.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_mixins.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_policy_v2.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/rllib/policy/dynamic_tf_policy.py +1358 -0
  7. .venv/lib/python3.11/site-packages/ray/rllib/policy/dynamic_tf_policy_v2.py +1047 -0
  8. .venv/lib/python3.11/site-packages/ray/rllib/policy/eager_tf_policy.py +1051 -0
  9. .venv/lib/python3.11/site-packages/ray/rllib/policy/eager_tf_policy_v2.py +966 -0
  10. .venv/lib/python3.11/site-packages/ray/rllib/policy/policy.py +1696 -0
  11. .venv/lib/python3.11/site-packages/ray/rllib/policy/policy_map.py +294 -0
  12. .venv/lib/python3.11/site-packages/ray/rllib/policy/policy_template.py +448 -0
  13. .venv/lib/python3.11/site-packages/ray/rllib/policy/rnn_sequencing.py +683 -0
  14. .venv/lib/python3.11/site-packages/ray/rllib/policy/sample_batch.py +1820 -0
  15. .venv/lib/python3.11/site-packages/ray/rllib/policy/tf_mixins.py +389 -0
  16. .venv/lib/python3.11/site-packages/ray/rllib/policy/tf_policy.py +1200 -0
  17. .venv/lib/python3.11/site-packages/ray/rllib/policy/tf_policy_template.py +365 -0
  18. .venv/lib/python3.11/site-packages/ray/rllib/policy/torch_mixins.py +221 -0
  19. .venv/lib/python3.11/site-packages/ray/rllib/policy/torch_policy.py +1201 -0
  20. .venv/lib/python3.11/site-packages/ray/rllib/policy/torch_policy_v2.py +1260 -0
  21. .venv/lib/python3.11/site-packages/ray/rllib/policy/view_requirement.py +152 -0
  22. .venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__init__.py +10 -0
  23. .venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/__init__.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/deterministic.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/memory.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/summary.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/ray/rllib/utils/debug/deterministic.py +56 -0
  28. .venv/lib/python3.11/site-packages/ray/rllib/utils/debug/memory.py +211 -0
  29. .venv/lib/python3.11/site-packages/ray/rllib/utils/debug/summary.py +79 -0
  30. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__init__.py +39 -0
  31. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/__init__.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/curiosity.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/epsilon_greedy.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/exploration.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/gaussian_noise.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/ornstein_uhlenbeck_noise.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/parameter_noise.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_epsilon_greedy.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_gaussian_noise.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_ornstein_uhlenbeck_noise.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/random.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/random_encoder.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/slate_epsilon_greedy.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/slate_soft_q.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/soft_q.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/stochastic_sampling.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/thompson_sampling.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/upper_confidence_bound.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/curiosity.py +444 -0
  50. .venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/epsilon_greedy.py +246 -0
.venv/lib/python3.11/site-packages/ray/rllib/policy/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Public entry points of ``ray.rllib.policy``.

Re-exports the core ``Policy`` base classes and the two policy-building
helper functions so they can be imported directly from this package.
"""
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.torch_policy import TorchPolicy

__all__ = [
    "Policy",
    "TFPolicy",
    "TorchPolicy",
    "build_policy_class",
    "build_tf_policy",
]
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (654 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/tf_policy.cpython-311.pyc ADDED
Binary file (61.8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_mixins.cpython-311.pyc ADDED
Binary file (11.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/policy/__pycache__/torch_policy_v2.cpython-311.pyc ADDED
Binary file (60.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/policy/dynamic_tf_policy.py ADDED
@@ -0,0 +1,1358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple, OrderedDict
2
+ import gymnasium as gym
3
+ import logging
4
+ import re
5
+ import tree # pip install dm_tree
6
+ from typing import Callable, Dict, List, Optional, Tuple, Type, Union
7
+
8
+ from ray.util.debug import log_once
9
+ from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
10
+ from ray.rllib.models.modelv2 import ModelV2
11
+ from ray.rllib.policy.policy import Policy
12
+ from ray.rllib.policy.sample_batch import SampleBatch
13
+ from ray.rllib.policy.tf_policy import TFPolicy
14
+ from ray.rllib.policy.view_requirement import ViewRequirement
15
+ from ray.rllib.models.catalog import ModelCatalog
16
+ from ray.rllib.utils import force_list
17
+ from ray.rllib.utils.annotations import OldAPIStack, override
18
+ from ray.rllib.utils.debug import summarize
19
+ from ray.rllib.utils.deprecation import (
20
+ deprecation_warning,
21
+ DEPRECATED_VALUE,
22
+ )
23
+ from ray.rllib.utils.framework import try_import_tf
24
+ from ray.rllib.utils.metrics import (
25
+ DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
26
+ NUM_GRAD_UPDATES_LIFETIME,
27
+ )
28
+ from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space
29
+ from ray.rllib.utils.tf_utils import get_placeholder
30
+ from ray.rllib.utils.typing import (
31
+ LocalOptimizer,
32
+ ModelGradients,
33
+ TensorType,
34
+ AlgorithmConfigDict,
35
+ )
36
+
37
+ tf1, tf, tfv = try_import_tf()
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+ # Variable scope in which created variables will be placed under.
42
+ TOWER_SCOPE_NAME = "tower"
43
+
44
+
45
+ @OldAPIStack
46
+ class DynamicTFPolicy(TFPolicy):
47
+ """A TFPolicy that auto-defines placeholders dynamically at runtime.
48
+
49
+ Do not sub-class this class directly (neither should you sub-class
50
+ TFPolicy), but rather use rllib.policy.tf_policy_template.build_tf_policy
51
+ to generate your custom tf (graph-mode or eager) Policy classes.
52
+ """
53
+
54
    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
        loss_fn: Callable[
            [Policy, ModelV2, Type[TFActionDistribution], SampleBatch], TensorType
        ],
        *,
        stats_fn: Optional[
            Callable[[Policy, SampleBatch], Dict[str, TensorType]]
        ] = None,
        grad_stats_fn: Optional[
            Callable[[Policy, SampleBatch, ModelGradients], Dict[str, TensorType]]
        ] = None,
        before_loss_init: Optional[
            Callable[
                [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None
            ]
        ] = None,
        make_model: Optional[
            Callable[
                [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict],
                ModelV2,
            ]
        ] = None,
        action_sampler_fn: Optional[
            Callable[
                [TensorType, List[TensorType]],
                Union[
                    Tuple[TensorType, TensorType],
                    Tuple[TensorType, TensorType, TensorType, List[TensorType]],
                ],
            ]
        ] = None,
        action_distribution_fn: Optional[
            Callable[
                [Policy, ModelV2, TensorType, TensorType, TensorType],
                Tuple[TensorType, type, List[TensorType]],
            ]
        ] = None,
        existing_inputs: Optional[Dict[str, "tf1.placeholder"]] = None,
        existing_model: Optional[ModelV2] = None,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
        obs_include_prev_action_reward=DEPRECATED_VALUE,
    ):
        """Initializes a DynamicTFPolicy instance.

        Initialization of this class occurs in two phases and defines the
        static graph.

        Phase 1: The model is created and model variables are initialized.

        Phase 2: A fake batch of data is created, sent to the trajectory
        postprocessor, and then used to create placeholders for the loss
        function. The loss and stats functions are initialized with these
        placeholders.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: Policy-specific configuration data.
            loss_fn: Function that returns a loss tensor for the policy graph.
            stats_fn: Optional callable that - given the policy and batch
                input tensors - returns a dict mapping str to TF ops.
                These ops are fetched from the graph after loss calculations
                and the resulting values can be found in the results dict
                returned by e.g. `Algorithm.train()` or in tensorboard (if TB
                logging is enabled).
            grad_stats_fn: Optional callable that - given the policy, batch
                input tensors, and calculated loss gradient tensors - returns
                a dict mapping str to TF ops. These ops are fetched from the
                graph after loss and gradient calculations and the resulting
                values can be found in the results dict returned by e.g.
                `Algorithm.train()` or in tensorboard (if TB logging is
                enabled).
            before_loss_init: Optional function to run prior to
                loss init that takes the same arguments as __init__.
            make_model: Optional function that returns a ModelV2 object
                given policy, obs_space, action_space, and policy config.
                All policy variables should be created in this function. If not
                specified, a default model will be created.
            action_sampler_fn: A callable returning either a sampled action and
                its log-likelihood or a sampled action, its log-likelihood,
                action distribution inputs and updated state given Policy,
                ModelV2, observation inputs, explore, and is_training.
                Provide `action_sampler_fn` if you would like to have full
                control over the action computation step, including the
                model forward pass, possible sampling from a distribution,
                and exploration logic.
                Note: If `action_sampler_fn` is given, `action_distribution_fn`
                must be None. If both `action_sampler_fn` and
                `action_distribution_fn` are None, RLlib will simply pass
                inputs through `self.model` to get distribution inputs, create
                the distribution object, sample from it, and apply some
                exploration logic to the results.
                The callable takes as inputs: Policy, ModelV2, obs_batch,
                state_batches (optional), seq_lens (optional),
                prev_actions_batch (optional), prev_rewards_batch (optional),
                explore, and is_training.
            action_distribution_fn: A callable returning distribution inputs
                (parameters), a dist-class to generate an action distribution
                object from, and internal-state outputs (or an empty list if
                not applicable).
                Provide `action_distribution_fn` if you would like to only
                customize the model forward pass call. The resulting
                distribution parameters are then used by RLlib to create a
                distribution object, sample from it, and execute any
                exploration logic.
                Note: If `action_distribution_fn` is given, `action_sampler_fn`
                must be None. If both `action_sampler_fn` and
                `action_distribution_fn` are None, RLlib will simply pass
                inputs through `self.model` to get distribution inputs, create
                the distribution object, sample from it, and apply some
                exploration logic to the results.
                The callable takes as inputs: Policy, ModelV2, input_dict,
                explore, timestep, is_training.
            existing_inputs: When copying a policy, this specifies an existing
                dict of placeholders to use instead of defining new ones.
            existing_model: When copying a policy, this specifies an existing
                model to clone and share weights with.
            get_batch_divisibility_req: Optional callable that returns the
                divisibility requirement for sample batches. If None, will
                assume a value of 1.
        """
        if obs_include_prev_action_reward != DEPRECATED_VALUE:
            deprecation_warning(old="obs_include_prev_action_reward", error=True)
        self.observation_space = obs_space
        self.action_space = action_space
        self.config = config
        self.framework = "tf"
        self._loss_fn = loss_fn
        self._stats_fn = stats_fn
        self._grad_stats_fn = grad_stats_fn
        self._seq_lens = None
        # A "tower" is a per-device copy constructed from existing placeholders.
        self._is_tower = existing_inputs is not None

        dist_class = None
        if action_sampler_fn or action_distribution_fn:
            if not make_model:
                raise ValueError(
                    "`make_model` is required if `action_sampler_fn` OR "
                    "`action_distribution_fn` is given"
                )
        else:
            # NOTE: `logit_dim` is only bound on this branch; the only use
            # below (default-model creation) is likewise only reached when
            # neither custom fn is given, so this is safe.
            dist_class, logit_dim = ModelCatalog.get_action_dist(
                action_space, self.config["model"]
            )

        # Setup self.model.
        if existing_model:
            if isinstance(existing_model, list):
                self.model = existing_model[0]
                # TODO: (sven) hack, but works for `target_[q_]?model`.
                for i in range(1, len(existing_model)):
                    setattr(self, existing_model[i][0], existing_model[i][1])
        elif make_model:
            self.model = make_model(self, obs_space, action_space, config)
        else:
            self.model = ModelCatalog.get_model_v2(
                obs_space=obs_space,
                action_space=action_space,
                num_outputs=logit_dim,
                model_config=self.config["model"],
                framework="tf",
            )
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()

        # Input placeholders already given -> Use these.
        if existing_inputs:
            self._state_inputs = [
                v for k, v in existing_inputs.items() if k.startswith("state_in_")
            ]
            # Placeholder for RNN time-chunk valid lengths.
            if self._state_inputs:
                self._seq_lens = existing_inputs[SampleBatch.SEQ_LENS]
        # Create new input placeholders.
        else:
            self._state_inputs = [
                get_placeholder(
                    space=vr.space,
                    time_axis=not isinstance(vr.shift, int),
                    name=k,
                )
                for k, vr in self.model.view_requirements.items()
                if k.startswith("state_in_")
            ]
            # Placeholder for RNN time-chunk valid lengths.
            if self._state_inputs:
                self._seq_lens = tf1.placeholder(
                    dtype=tf.int32, shape=[None], name="seq_lens"
                )

        # Use default settings.
        # Add NEXT_OBS, STATE_IN_0.., and others.
        self.view_requirements = self._get_default_view_requirements()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)
        # Disable env-info placeholder.
        if SampleBatch.INFOS in self.view_requirements:
            self.view_requirements[SampleBatch.INFOS].used_for_training = False

        # Setup standard placeholders.
        if self._is_tower:
            timestep = existing_inputs["timestep"]
            explore = False
            self._input_dict, self._dummy_batch = self._get_input_dict_and_dummy_batch(
                self.view_requirements, existing_inputs
            )
        else:
            if not self.config.get("_disable_action_flattening"):
                action_ph = ModelCatalog.get_action_placeholder(action_space)
                prev_action_ph = {}
                if SampleBatch.PREV_ACTIONS not in self.view_requirements:
                    prev_action_ph = {
                        SampleBatch.PREV_ACTIONS: ModelCatalog.get_action_placeholder(
                            action_space, "prev_action"
                        )
                    }
                (
                    self._input_dict,
                    self._dummy_batch,
                ) = self._get_input_dict_and_dummy_batch(
                    self.view_requirements,
                    dict({SampleBatch.ACTIONS: action_ph}, **prev_action_ph),
                )
            else:
                (
                    self._input_dict,
                    self._dummy_batch,
                ) = self._get_input_dict_and_dummy_batch(self.view_requirements, {})
            # Placeholder for (sampling steps) timestep (int).
            timestep = tf1.placeholder_with_default(
                tf.zeros((), dtype=tf.int64), (), name="timestep"
            )
            # Placeholder for `is_exploring` flag.
            explore = tf1.placeholder_with_default(True, (), name="is_exploring")

        # Placeholder for `is_training` flag.
        self._input_dict.set_training(self._get_is_training_placeholder())

        # Multi-GPU towers do not need any action computing/exploration
        # graphs.
        sampled_action = None
        sampled_action_logp = None
        dist_inputs = None
        extra_action_fetches = {}
        self._state_out = None
        if not self._is_tower:
            # Create the Exploration object to use for this Policy.
            self.exploration = self._create_exploration()

            # Fully customized action generation (e.g., custom policy).
            if action_sampler_fn:
                action_sampler_outputs = action_sampler_fn(
                    self,
                    self.model,
                    obs_batch=self._input_dict[SampleBatch.CUR_OBS],
                    state_batches=self._state_inputs,
                    seq_lens=self._seq_lens,
                    prev_action_batch=self._input_dict.get(SampleBatch.PREV_ACTIONS),
                    prev_reward_batch=self._input_dict.get(SampleBatch.PREV_REWARDS),
                    explore=explore,
                    is_training=self._input_dict.is_training,
                )
                # 4-tuple includes dist inputs and RNN state-outs; 2-tuple is
                # just (action, logp).
                if len(action_sampler_outputs) == 4:
                    (
                        sampled_action,
                        sampled_action_logp,
                        dist_inputs,
                        self._state_out,
                    ) = action_sampler_outputs
                else:
                    dist_inputs = None
                    self._state_out = []
                    sampled_action, sampled_action_logp = action_sampler_outputs
            # Distribution generation is customized, e.g., DQN, DDPG.
            else:
                if action_distribution_fn:
                    # Try new action_distribution_fn signature, supporting
                    # state_batches and seq_lens.
                    in_dict = self._input_dict
                    try:
                        (
                            dist_inputs,
                            dist_class,
                            self._state_out,
                        ) = action_distribution_fn(
                            self,
                            self.model,
                            input_dict=in_dict,
                            state_batches=self._state_inputs,
                            seq_lens=self._seq_lens,
                            explore=explore,
                            timestep=timestep,
                            is_training=in_dict.is_training,
                        )
                    # Trying the old way (to stay backward compatible).
                    # TODO: Remove in future.
                    except TypeError as e:
                        if (
                            "positional argument" in e.args[0]
                            or "unexpected keyword argument" in e.args[0]
                        ):
                            (
                                dist_inputs,
                                dist_class,
                                self._state_out,
                            ) = action_distribution_fn(
                                self,
                                self.model,
                                obs_batch=in_dict[SampleBatch.CUR_OBS],
                                state_batches=self._state_inputs,
                                seq_lens=self._seq_lens,
                                prev_action_batch=in_dict.get(SampleBatch.PREV_ACTIONS),
                                prev_reward_batch=in_dict.get(SampleBatch.PREV_REWARDS),
                                explore=explore,
                                is_training=in_dict.is_training,
                            )
                        else:
                            raise e

                # Default distribution generation behavior:
                # Pass through model. E.g., PG, PPO.
                else:
                    if isinstance(self.model, tf.keras.Model):
                        dist_inputs, self._state_out, extra_action_fetches = self.model(
                            self._input_dict
                        )
                    else:
                        dist_inputs, self._state_out = self.model(self._input_dict)

                action_dist = dist_class(dist_inputs, self.model)

                # Using exploration to get final action (e.g. via sampling).
                (
                    sampled_action,
                    sampled_action_logp,
                ) = self.exploration.get_exploration_action(
                    action_distribution=action_dist, timestep=timestep, explore=explore
                )

        if dist_inputs is not None:
            extra_action_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        if sampled_action_logp is not None:
            extra_action_fetches[SampleBatch.ACTION_LOGP] = sampled_action_logp
            extra_action_fetches[SampleBatch.ACTION_PROB] = tf.exp(
                tf.cast(sampled_action_logp, tf.float32)
            )

        # Phase 1 init.
        sess = tf1.get_default_session() or tf1.Session(
            config=tf1.ConfigProto(**self.config["tf_session_args"])
        )

        batch_divisibility_req = (
            get_batch_divisibility_req(self)
            if callable(get_batch_divisibility_req)
            else (get_batch_divisibility_req or 1)
        )

        # Only wire up prev-action/reward inputs that the graph actually read.
        prev_action_input = (
            self._input_dict[SampleBatch.PREV_ACTIONS]
            if SampleBatch.PREV_ACTIONS in self._input_dict.accessed_keys
            else None
        )
        prev_reward_input = (
            self._input_dict[SampleBatch.PREV_REWARDS]
            if SampleBatch.PREV_REWARDS in self._input_dict.accessed_keys
            else None
        )

        super().__init__(
            observation_space=obs_space,
            action_space=action_space,
            config=config,
            sess=sess,
            obs_input=self._input_dict[SampleBatch.OBS],
            action_input=self._input_dict[SampleBatch.ACTIONS],
            sampled_action=sampled_action,
            sampled_action_logp=sampled_action_logp,
            dist_inputs=dist_inputs,
            dist_class=dist_class,
            loss=None,  # dynamically initialized on run
            loss_inputs=[],
            model=self.model,
            state_inputs=self._state_inputs,
            state_outputs=self._state_out,
            prev_action_input=prev_action_input,
            prev_reward_input=prev_reward_input,
            seq_lens=self._seq_lens,
            max_seq_len=config["model"]["max_seq_len"],
            batch_divisibility_req=batch_divisibility_req,
            explore=explore,
            timestep=timestep,
        )

        # Phase 2 init.
        if before_loss_init is not None:
            before_loss_init(self, obs_space, action_space, config)
        if hasattr(self, "_extra_action_fetches"):
            self._extra_action_fetches.update(extra_action_fetches)
        else:
            self._extra_action_fetches = extra_action_fetches

        # Loss initialization and model/postprocessing test calls.
        if not self._is_tower:
            self._initialize_loss_from_dummy_batch(auto_remove_unneeded_view_reqs=True)

            # Create MultiGPUTowerStacks, if we have at least one actual
            # GPU or >1 CPUs (fake GPUs).
            if len(self.devices) > 1 or any("gpu" in d for d in self.devices):
                # Per-GPU graph copies created here must share vars with the
                # policy. Therefore, `reuse` is set to tf1.AUTO_REUSE because
                # Adam nodes are created after all of the device copies are
                # created.
                with tf1.variable_scope("", reuse=tf1.AUTO_REUSE):
                    self.multi_gpu_tower_stacks = [
                        TFMultiGPUTowerStack(policy=self)
                        for i in range(self.config.get("num_multi_gpu_tower_stacks", 1))
                    ]

            # Initialize again after loss and tower init.
            self.get_session().run(tf1.global_variables_initializer())
480
+
481
    @override(TFPolicy)
    def copy(self, existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> TFPolicy:
        """Creates a copy of self using existing input placeholders.

        Used to build per-device tower copies that share variables with this
        policy (the copy is constructed with `existing_inputs`, which makes it
        a tower; see `self._is_tower`).

        Args:
            existing_inputs: Flat list of placeholders to re-use; by
                convention, loss inputs first, then RNN state inputs, then the
                seq-lens tensor.

        Returns:
            The new (tower) policy instance with its loss initialized.
        """

        flat_loss_inputs = tree.flatten(self._loss_input_dict)
        flat_loss_inputs_no_rnn = tree.flatten(self._loss_input_dict_no_rnn)

        # Note that there might be RNN state inputs at the end of the list
        if len(flat_loss_inputs) != len(existing_inputs):
            raise ValueError(
                "Tensor list mismatch",
                self._loss_input_dict,
                self._state_inputs,
                existing_inputs,
            )
        # Shapes of the non-RNN loss inputs must match position-wise.
        for i, v in enumerate(flat_loss_inputs_no_rnn):
            if v.shape.as_list() != existing_inputs[i].shape.as_list():
                raise ValueError(
                    "Tensor shape mismatch", i, v.shape, existing_inputs[i].shape
                )
        # By convention, the loss inputs are followed by state inputs and then
        # the seq len tensor.
        rnn_inputs = []
        for i in range(len(self._state_inputs)):
            rnn_inputs.append(
                (
                    "state_in_{}".format(i),
                    existing_inputs[len(flat_loss_inputs_no_rnn) + i],
                )
            )
        if rnn_inputs:
            rnn_inputs.append((SampleBatch.SEQ_LENS, existing_inputs[-1]))
        # Restore the (possibly nested) structure of the non-RNN loss inputs.
        existing_inputs_unflattened = tree.unflatten_as(
            self._loss_input_dict_no_rnn,
            existing_inputs[: len(flat_loss_inputs_no_rnn)],
        )
        input_dict = OrderedDict(
            [("is_exploring", self._is_exploring), ("timestep", self._timestep)]
            + [
                (k, existing_inputs_unflattened[k])
                for i, k in enumerate(self._loss_input_dict_no_rnn.keys())
            ]
            + rnn_inputs
        )

        instance = self.__class__(
            self.observation_space,
            self.action_space,
            self.config,
            existing_inputs=input_dict,
            existing_model=[
                self.model,
                # Deprecated: Target models should all reside under
                # `policy.target_model` now.
                ("target_q_model", getattr(self, "target_q_model", None)),
                ("target_model", getattr(self, "target_model", None)),
            ],
        )

        instance._loss_input_dict = input_dict
        losses = instance._do_loss_init(SampleBatch(input_dict))
        loss_inputs = [
            (k, existing_inputs_unflattened[k])
            for i, k in enumerate(self._loss_input_dict_no_rnn.keys())
        ]

        TFPolicy._initialize_loss(instance, losses, loss_inputs)
        if instance._grad_stats_fn:
            instance._stats_fetches.update(
                instance._grad_stats_fn(instance, input_dict, instance._grads)
            )
        return instance
553
+
554
+ @override(Policy)
555
+ def get_initial_state(self) -> List[TensorType]:
556
+ if self.model:
557
+ return self.model.get_initial_state()
558
+ else:
559
+ return []
560
+
561
    @override(Policy)
    def load_batch_into_buffer(
        self,
        batch: SampleBatch,
        buffer_index: int = 0,
    ) -> int:
        """Loads `batch` into the input buffer of the given tower stack.

        Args:
            batch: The SampleBatch to load (marked as training data).
            buffer_index: Which multi-GPU tower stack's buffer to fill.
                Must be 0 on the single-CPU shortcut path.

        Returns:
            The number of samples stored (for the multi-GPU path, whatever
            `TFMultiGPUTowerStack.load_data` returns).
        """
        # Set the is_training flag of the batch.
        batch.set_training(True)

        # Shortcut for 1 CPU only: Store batch in
        # `self._loaded_single_cpu_batch`.
        if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
            assert buffer_index == 0
            self._loaded_single_cpu_batch = batch
            return len(batch)

        input_dict = self._get_loss_inputs_dict(batch, shuffle=False)
        # Flatten the (possibly nested) non-RNN loss-input placeholders; RNN
        # state inputs and seq-lens are fed separately.
        data_keys = tree.flatten(self._loss_input_dict_no_rnn)
        if self._state_inputs:
            state_keys = self._state_inputs + [self._seq_lens]
        else:
            state_keys = []
        inputs = [input_dict[k] for k in data_keys]
        state_inputs = [input_dict[k] for k in state_keys]

        return self.multi_gpu_tower_stacks[buffer_index].load_data(
            sess=self.get_session(),
            inputs=inputs,
            state_inputs=state_inputs,
            num_grad_updates=batch.num_grad_updates,
        )
592
+
593
+ @override(Policy)
594
+ def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
595
+ # Shortcut for 1 CPU only: Batch should already be stored in
596
+ # `self._loaded_single_cpu_batch`.
597
+ if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
598
+ assert buffer_index == 0
599
+ return (
600
+ len(self._loaded_single_cpu_batch)
601
+ if self._loaded_single_cpu_batch is not None
602
+ else 0
603
+ )
604
+
605
+ return self.multi_gpu_tower_stacks[buffer_index].num_tuples_loaded
606
+
607
    @override(Policy)
    def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
        """Runs one SGD step on data previously loaded via `load_batch_into_buffer`.

        Args:
            offset: Sample offset into the loaded data at which to start the
                minibatch.
            buffer_index: Which tower stack's buffer to learn from. Must be 0
                on the single-CPU shortcut path.

        Returns:
            A results dict; on the multi-GPU path it additionally carries the
            grad-update counters.

        Raises:
            ValueError: If (on the single-CPU path) no batch was loaded first.
        """
        # Shortcut for 1 CPU only: Batch should already be stored in
        # `self._loaded_single_cpu_batch`.
        if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
            assert buffer_index == 0
            if self._loaded_single_cpu_batch is None:
                raise ValueError(
                    "Must call Policy.load_batch_into_buffer() before "
                    "Policy.learn_on_loaded_batch()!"
                )
            # Get the correct slice of the already loaded batch to use,
            # based on offset and batch size.
            # Config precedence: `minibatch_size`, then `sgd_minibatch_size`,
            # then `train_batch_size`.
            batch_size = self.config.get("minibatch_size")
            if batch_size is None:
                batch_size = self.config.get(
                    "sgd_minibatch_size", self.config["train_batch_size"]
                )
            if batch_size >= len(self._loaded_single_cpu_batch):
                sliced_batch = self._loaded_single_cpu_batch
            else:
                sliced_batch = self._loaded_single_cpu_batch.slice(
                    start=offset, end=offset + batch_size
                )
            return self.learn_on_batch(sliced_batch)

        tower_stack = self.multi_gpu_tower_stacks[buffer_index]
        results = tower_stack.optimize(self.get_session(), offset)
        self.num_grad_updates += 1

        results.update(
            {
                NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                # -1, b/c we have to measure this diff before we do the update above.
                DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                    self.num_grad_updates - 1 - (tower_stack.num_grad_updates or 0)
                ),
            }
        )

        return results
648
+
649
    def _get_input_dict_and_dummy_batch(self, view_requirements, existing_inputs):
        """Creates input_dict and dummy_batch for loss initialization.

        Used for managing the Policy's input placeholders and for loss
        initialization.
        Input_dict: Str -> tf.placeholders, dummy_batch: str -> np.arrays.

        Args:
            view_requirements: The view requirements dict.
            existing_inputs (Dict[str, tf.placeholder]): A dict of already
                existing placeholders.

        Returns:
            Tuple[Dict[str, tf.placeholder], Dict[str, np.ndarray]]: The
                input_dict/dummy_batch tuple.
        """
        input_dict = {}
        for view_col, view_req in view_requirements.items():
            # Point state_in to the already existing self._state_inputs.
            mo = re.match(r"state_in_(\d+)", view_col)
            if mo is not None:
                input_dict[view_col] = self._state_inputs[int(mo.group(1))]
            # State-outs (no placeholders needed).
            elif view_col.startswith("state_out_"):
                continue
            # Skip action dist inputs placeholder (do later).
            elif view_col == SampleBatch.ACTION_DIST_INPUTS:
                continue
            # This is a tower: Input placeholders already exist.
            elif view_col in existing_inputs:
                input_dict[view_col] = existing_inputs[view_col]
            # All others.
            else:
                # A non-int shift (range or list of ints) means the column
                # carries a time axis.
                time_axis = not isinstance(view_req.shift, int)
                if view_req.used_for_training:
                    # Create a +time-axis placeholder if the shift is not an
                    # int (range or list of ints).
                    # Do not flatten actions if action flattening disabled.
                    if self.config.get("_disable_action_flattening") and view_col in [
                        SampleBatch.ACTIONS,
                        SampleBatch.PREV_ACTIONS,
                    ]:
                        flatten = False
                    # Do not flatten observations if no preprocessor API used.
                    elif (
                        view_col in [SampleBatch.OBS, SampleBatch.NEXT_OBS]
                        and self.config["_disable_preprocessor_api"]
                    ):
                        flatten = False
                    # Flatten everything else.
                    else:
                        flatten = True
                    input_dict[view_col] = get_placeholder(
                        space=view_req.space,
                        name=view_col,
                        time_axis=time_axis,
                        flatten=flatten,
                    )
        # Fixed-size dummy batch for tracing/initialization purposes.
        dummy_batch = self._get_dummy_batch_from_view_requirements(batch_size=32)

        return SampleBatch(input_dict, seq_lens=self._seq_lens), dummy_batch
710
+
711
    @override(Policy)
    def _initialize_loss_from_dummy_batch(
        self, auto_remove_unneeded_view_reqs: bool = True, stats_fn=None
    ) -> None:
        """Traces postprocessing/loss on a dummy batch to finalize the graph.

        Runs `postprocess_trajectory` and the loss function once on a dummy
        batch, recording which SampleBatch columns get accessed/added, then
        builds the TF loss inputs and (optionally) prunes view requirements
        that turned out to be unneeded.

        Args:
            auto_remove_unneeded_view_reqs: Whether to tag/remove view
                requirements that were not accessed during the dummy passes.
            stats_fn: Unused here; kept for signature compatibility with the
                base class.
        """
        # Create the optimizer/exploration optimizer here. Some initialization
        # steps (e.g. exploration postprocessing) may need this.
        if not self._optimizers:
            self._optimizers = force_list(self.optimizer())
            # Backward compatibility.
            self._optimizer = self._optimizers[0]

        # Test calls depend on variable init, so initialize model first.
        self.get_session().run(tf1.global_variables_initializer())

        # Fields that have not been accessed are not needed for action
        # computations -> Tag them as `used_for_compute_actions=False`.
        for key, view_req in self.view_requirements.items():
            if (
                not key.startswith("state_in_")
                and key not in self._input_dict.accessed_keys
            ):
                view_req.used_for_compute_actions = False
        # For every extra action fetch: add dummy data, a placeholder, and a
        # (training-only) view requirement.
        for key, value in self._extra_action_fetches.items():
            self._dummy_batch[key] = get_dummy_batch_for_space(
                gym.spaces.Box(
                    -1.0, 1.0, shape=value.shape.as_list()[1:], dtype=value.dtype.name
                ),
                batch_size=len(self._dummy_batch),
            )
            self._input_dict[key] = get_placeholder(value=value, name=key)
            if key not in self.view_requirements:
                logger.info("Adding extra-action-fetch `{}` to view-reqs.".format(key))
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(
                        -1.0,
                        1.0,
                        shape=value.shape.as_list()[1:],
                        dtype=value.dtype.name,
                    ),
                    used_for_compute_actions=False,
                )
        dummy_batch = self._dummy_batch

        logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
        self.exploration.postprocess_trajectory(self, dummy_batch, self.get_session())
        _ = self.postprocess_trajectory(dummy_batch)
        # Add new columns automatically to (loss) input_dict.
        for key in dummy_batch.added_keys:
            if key not in self._input_dict:
                self._input_dict[key] = get_placeholder(
                    value=dummy_batch[key], name=key
                )
            if key not in self.view_requirements:
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(
                        -1.0,
                        1.0,
                        shape=dummy_batch[key].shape[1:],
                        dtype=dummy_batch[key].dtype,
                    ),
                    used_for_compute_actions=False,
                )

        # Symbolic train batch: maps column names to placeholders.
        train_batch = SampleBatch(
            dict(self._input_dict, **self._loss_input_dict),
            _is_training=True,
        )

        if self._state_inputs:
            train_batch[SampleBatch.SEQ_LENS] = self._seq_lens
            self._loss_input_dict.update(
                {SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]}
            )

        self._loss_input_dict.update({k: v for k, v in train_batch.items()})

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)
                )
            )

        losses = self._do_loss_init(train_batch)

        # Everything touched by either the loss pass, postprocessing, or the
        # model itself counts as "needed".
        all_accessed_keys = (
            train_batch.accessed_keys
            | dummy_batch.accessed_keys
            | dummy_batch.added_keys
            | set(self.model.view_requirements.keys())
        )

        TFPolicy._initialize_loss(
            self,
            losses,
            [(k, v) for k, v in train_batch.items() if k in all_accessed_keys]
            + (
                [(SampleBatch.SEQ_LENS, train_batch[SampleBatch.SEQ_LENS])]
                if SampleBatch.SEQ_LENS in train_batch
                else []
            ),
        )

        if "is_training" in self._loss_input_dict:
            del self._loss_input_dict["is_training"]

        # Call the grads stats fn.
        # TODO: (sven) rename to simply stats_fn to match eager and torch.
        if self._grad_stats_fn:
            self._stats_fetches.update(
                self._grad_stats_fn(self, train_batch, self._grads)
            )

        # Add new columns automatically to view-reqs.
        if auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = train_batch.accessed_keys | dummy_batch.accessed_keys
            # Tag those only needed for post-processing (with some exceptions).
            for key in dummy_batch.accessed_keys:
                if (
                    key not in train_batch.accessed_keys
                    and key not in self.model.view_requirements
                    and key
                    not in [
                        SampleBatch.EPS_ID,
                        SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID,
                        SampleBatch.TERMINATEDS,
                        SampleBatch.TRUNCATEDS,
                        SampleBatch.REWARDS,
                        SampleBatch.INFOS,
                        SampleBatch.T,
                        SampleBatch.OBS_EMBEDS,
                    ]
                ):
                    if key in self.view_requirements:
                        self.view_requirements[key].used_for_training = False
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Remove those not needed at all (leave those that are needed
            # by Sampler to properly execute sample collection).
            # Also always leave TERMINATEDS, TRUNCATEDS, REWARDS, and INFOS,
            # no matter what.
            for key in list(self.view_requirements.keys()):
                if (
                    key not in all_accessed_keys
                    and key
                    not in [
                        SampleBatch.EPS_ID,
                        SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID,
                        SampleBatch.TERMINATEDS,
                        SampleBatch.TRUNCATEDS,
                        SampleBatch.REWARDS,
                        SampleBatch.INFOS,
                        SampleBatch.T,
                    ]
                    and key not in self.model.view_requirements
                ):
                    # If user deleted this key manually in postprocessing
                    # fn, warn about it and do not remove from
                    # view-requirements.
                    if key in dummy_batch.deleted_keys:
                        logger.warning(
                            "SampleBatch key '{}' was deleted manually in "
                            "postprocessing function! RLlib will "
                            "automatically remove non-used items from the "
                            "data stream. Remove the `del` from your "
                            "postprocessing function.".format(key)
                        )
                    # If we are not writing output to disk, safe to erase
                    # this key to save space in the sample batch.
                    elif self.config["output"] is None:
                        del self.view_requirements[key]

                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Add those data_cols (again) that are missing and have
            # dependencies by view_cols.
            for key in list(self.view_requirements.keys()):
                vr = self.view_requirements[key]
                if (
                    vr.data_col is not None
                    and vr.data_col not in self.view_requirements
                ):
                    used_for_training = vr.data_col in train_batch.accessed_keys
                    self.view_requirements[vr.data_col] = ViewRequirement(
                        space=vr.space, used_for_training=used_for_training
                    )

        # Loss inputs minus RNN state-ins and seq-lens (used by the
        # multi-GPU tower stack).
        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }
906
+
907
+ def _do_loss_init(self, train_batch: SampleBatch):
908
+ losses = self._loss_fn(self, self.model, self.dist_class, train_batch)
909
+ losses = force_list(losses)
910
+ if self._stats_fn:
911
+ self._stats_fetches.update(self._stats_fn(self, train_batch))
912
+ # Override the update ops to be those of the model.
913
+ self._update_ops = []
914
+ if not isinstance(self.model, tf.keras.Model):
915
+ self._update_ops = self.model.update_ops()
916
+ return losses
917
+
918
+
919
@OldAPIStack
class TFMultiGPUTowerStack:
    """Optimizer that runs in parallel across multiple local devices.

    TFMultiGPUTowerStack automatically splits up and loads training data
    onto specified local devices (e.g. GPUs) with `load_data()`. During a call
    to `optimize()`, the devices compute gradients over slices of the data in
    parallel. The gradients are then averaged and applied to the shared
    weights.

    The data loaded is pinned in device memory until the next call to
    `load_data`, so you can make multiple passes (possibly in randomized order)
    over the same data once loaded.

    This is similar to tf1.train.SyncReplicasOptimizer, but works within a
    single TensorFlow graph, i.e. implements in-graph replicated training:

    https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer
    """

    def __init__(
        self,
        # Deprecated.
        optimizer=None,
        devices=None,
        input_placeholders=None,
        rnn_inputs=None,
        max_per_device_batch_size=None,
        build_graph=None,
        grad_norm_clipping=None,
        # Use only `policy` argument from here on.
        policy: TFPolicy = None,
    ):
        """Initializes a TFMultiGPUTowerStack instance.

        Args:
            policy: The TFPolicy object that this tower stack
                belongs to.
        """
        # Obsoleted usage, use only `policy` arg from here on.
        # NOTE: `deprecation_warning(error=True)` raises, so the deprecated
        # branch below never continues past the warning call.
        if policy is None:
            deprecation_warning(
                old="TFMultiGPUTowerStack(...)",
                new="TFMultiGPUTowerStack(policy=[Policy])",
                error=True,
            )
            self.policy = None
            self.optimizers = optimizer
            self.devices = devices
            self.max_per_device_batch_size = max_per_device_batch_size
            self.policy_copy = build_graph
        else:
            self.policy: TFPolicy = policy
            self.optimizers: List[LocalOptimizer] = self.policy._optimizers
            self.devices = self.policy.devices
            # Per-device cap: configured minibatch (or train batch) size,
            # divided evenly among all devices.
            self.max_per_device_batch_size = (
                max_per_device_batch_size
                or policy.config.get(
                    "minibatch_size", policy.config.get("train_batch_size", 999999)
                )
            ) // len(self.devices)
            input_placeholders = tree.flatten(self.policy._loss_input_dict_no_rnn)
            rnn_inputs = []
            if self.policy._state_inputs:
                rnn_inputs = self.policy._state_inputs + [self.policy._seq_lens]
            grad_norm_clipping = self.policy.config.get("grad_clip")
            self.policy_copy = self.policy.copy

        # Requires either multiple devices or a single (non-CPU) GPU device.
        assert len(self.devices) > 1 or "gpu" in self.devices[0]
        self.loss_inputs = input_placeholders + rnn_inputs

        # Update ops already present before tower construction; used below to
        # filter out non-tower ops.
        shared_ops = tf1.get_collection(
            tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name
        )

        # Then setup the per-device loss graphs that use the shared weights
        self._batch_index = tf1.placeholder(tf.int32, name="batch_index")

        # Dynamic batch size, which may be shrunk if there isn't enough data
        self._per_device_batch_size = tf1.placeholder(
            tf.int32, name="per_device_batch_size"
        )
        # NOTE(review): in the `policy` code path the local
        # `max_per_device_batch_size` arg may still be None here; `load_data()`
        # overwrites this attribute, presumably before `optimize()` reads it —
        # confirm call order.
        self._loaded_per_device_batch_size = max_per_device_batch_size

        # When loading RNN input, we dynamically determine the max seq len
        self._max_seq_len = tf1.placeholder(tf.int32, name="max_seq_len")
        self._loaded_max_seq_len = 1

        # One list of (split) input tensors per device.
        device_placeholders = [[] for _ in range(len(self.devices))]

        for t in tree.flatten(self.loss_inputs):
            # Split on the CPU in case the data doesn't fit in GPU memory.
            with tf.device("/cpu:0"):
                splits = tf.split(t, len(self.devices))
            for i, d in enumerate(self.devices):
                device_placeholders[i].append(splits[i])

        self._towers = []
        for tower_i, (device, placeholders) in enumerate(
            zip(self.devices, device_placeholders)
        ):
            self._towers.append(
                self._setup_device(
                    tower_i, device, placeholders, len(tree.flatten(input_placeholders))
                )
            )

        if self.policy.config["_tf_policy_handles_more_than_one_loss"]:
            # Multi-loss case: average/clip/apply per optimizer.
            avgs = []
            for i, optim in enumerate(self.optimizers):
                avg = _average_gradients([t.grads[i] for t in self._towers])
                if grad_norm_clipping:
                    clipped = []
                    for grad, _ in avg:
                        clipped.append(grad)
                    clipped, _ = tf.clip_by_global_norm(clipped, grad_norm_clipping)
                    # NOTE(review): this inner `i` shadows the outer loop
                    # index; harmless, since `for` re-binds `i` from
                    # `enumerate` on the next outer iteration.
                    for i, (grad, var) in enumerate(avg):
                        avg[i] = (clipped[i], var)
                avgs.append(avg)

            # Gather update ops for any batch norm layers.
            # TODO(ekl) here we
            # will use all the ops found which won't work for DQN / DDPG, but
            # those aren't supported with multi-gpu right now anyways.
            self._update_ops = tf1.get_collection(
                tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name
            )
            for op in shared_ops:
                self._update_ops.remove(op)  # only care about tower update ops
            if self._update_ops:
                logger.debug(
                    "Update ops to run on apply gradient: {}".format(self._update_ops)
                )

            with tf1.control_dependencies(self._update_ops):
                self._train_op = tf.group(
                    [o.apply_gradients(a) for o, a in zip(self.optimizers, avgs)]
                )
        else:
            # Single-loss case: one averaged/clipped gradient application.
            avg = _average_gradients([t.grads for t in self._towers])
            if grad_norm_clipping:
                clipped = []
                for grad, _ in avg:
                    clipped.append(grad)
                clipped, _ = tf.clip_by_global_norm(clipped, grad_norm_clipping)
                for i, (grad, var) in enumerate(avg):
                    avg[i] = (clipped[i], var)

            # Gather update ops for any batch norm layers.
            # TODO(ekl) here we
            # will use all the ops found which won't work for DQN / DDPG, but
            # those aren't supported with multi-gpu right now anyways.
            self._update_ops = tf1.get_collection(
                tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name
            )
            for op in shared_ops:
                self._update_ops.remove(op)  # only care about tower update ops
            if self._update_ops:
                logger.debug(
                    "Update ops to run on apply gradient: {}".format(self._update_ops)
                )

            with tf1.control_dependencies(self._update_ops):
                self._train_op = self.optimizers[0].apply_gradients(avg)

        # The lifetime number of gradient updates that the policy having sent
        # some data (SampleBatchType) into this tower stack's GPU buffer(s) has already
        # undergone.
        self.num_grad_updates = 0

    def load_data(self, sess, inputs, state_inputs, num_grad_updates=None):
        """Bulk loads the specified inputs into device memory.

        The shape of the inputs must conform to the shapes of the input
        placeholders this optimizer was constructed with.

        The data is split equally across all the devices. If the data is not
        evenly divisible by the batch size, excess data will be discarded.

        Args:
            sess: TensorFlow session.
            inputs: List of arrays matching the input placeholders, of shape
                [BATCH_SIZE, ...].
            state_inputs: List of RNN input arrays. These arrays have size
                [BATCH_SIZE / MAX_SEQ_LEN, ...].
            num_grad_updates: The lifetime number of gradient updates that the
                policy having collected the data has already undergone.

        Returns:
            The number of tuples loaded per device.
        """
        self.num_grad_updates = num_grad_updates

        if log_once("load_data"):
            logger.info(
                "Training on concatenated sample batches:\n\n{}\n".format(
                    summarize(
                        {
                            "placeholders": self.loss_inputs,
                            "inputs": inputs,
                            "state_inputs": state_inputs,
                        }
                    )
                )
            )

        feed_dict = {}
        assert len(self.loss_inputs) == len(inputs + state_inputs), (
            self.loss_inputs,
            inputs,
            state_inputs,
        )

        # Let's suppose we have the following input data, and 2 devices:
        # 1 2 3 4 5 6 7  <- state inputs shape
        # A A A B B B C C C D D D E E E F F F G G G  <- inputs shape
        # The data is truncated and split across devices as follows:
        # |---| seq len = 3
        # |---------------------------------| seq batch size = 6 seqs
        # |----------------| per device batch size = 9 tuples

        if len(state_inputs) > 0:
            smallest_array = state_inputs[0]
            seq_len = len(inputs[0]) // len(state_inputs[0])
            self._loaded_max_seq_len = seq_len
        else:
            smallest_array = inputs[0]
            self._loaded_max_seq_len = 1

        # How many sequences fit into one full (all-device) minibatch.
        sequences_per_minibatch = (
            self.max_per_device_batch_size
            // self._loaded_max_seq_len
            * len(self.devices)
        )
        if sequences_per_minibatch < 1:
            logger.warning(
                (
                    "Target minibatch size is {}, however the rollout sequence "
                    "length is {}, hence the minibatch size will be raised to "
                    "{}."
                ).format(
                    self.max_per_device_batch_size,
                    self._loaded_max_seq_len,
                    self._loaded_max_seq_len * len(self.devices),
                )
            )
            sequences_per_minibatch = 1

        if len(smallest_array) < sequences_per_minibatch:
            # Dynamically shrink the batch size if insufficient data
            sequences_per_minibatch = _make_divisible_by(
                len(smallest_array), len(self.devices)
            )

        if log_once("data_slicing"):
            logger.info(
                (
                    "Divided {} rollout sequences, each of length {}, among "
                    "{} devices."
                ).format(
                    len(smallest_array), self._loaded_max_seq_len, len(self.devices)
                )
            )

        if sequences_per_minibatch < len(self.devices):
            raise ValueError(
                "Must load at least 1 tuple sequence per device. Try "
                "increasing `minibatch_size` or reducing `max_seq_len` "
                "to ensure that at least one sequence fits per device."
            )
        self._loaded_per_device_batch_size = (
            sequences_per_minibatch // len(self.devices) * self._loaded_max_seq_len
        )

        if len(state_inputs) > 0:
            # First truncate the RNN state arrays to the sequences_per_minib.
            state_inputs = [
                _make_divisible_by(arr, sequences_per_minibatch) for arr in state_inputs
            ]
            # Then truncate the data inputs to match
            inputs = [arr[: len(state_inputs[0]) * seq_len] for arr in inputs]
            assert len(state_inputs[0]) * seq_len == len(inputs[0]), (
                len(state_inputs[0]),
                sequences_per_minibatch,
                seq_len,
                len(inputs[0]),
            )
            for ph, arr in zip(self.loss_inputs, inputs + state_inputs):
                feed_dict[ph] = arr
            truncated_len = len(inputs[0])
        else:
            truncated_len = 0
            for ph, arr in zip(self.loss_inputs, inputs):
                truncated_arr = _make_divisible_by(arr, sequences_per_minibatch)
                feed_dict[ph] = truncated_arr
                if truncated_len == 0:
                    truncated_len = len(truncated_arr)

        # Pin the (truncated) data into each tower's device variables.
        sess.run([t.init_op for t in self._towers], feed_dict=feed_dict)

        self.num_tuples_loaded = truncated_len
        samples_per_device = truncated_len // len(self.devices)
        assert samples_per_device > 0, "No data loaded?"
        assert samples_per_device % self._loaded_per_device_batch_size == 0
        # Return loaded samples per-device.
        return samples_per_device

    def optimize(self, sess, batch_index):
        """Run a single step of SGD.

        Runs a SGD step over a slice of the preloaded batch with size given by
        self._loaded_per_device_batch_size and offset given by the batch_index
        argument.

        Updates shared model weights based on the averaged per-device
        gradients.

        Args:
            sess: TensorFlow session.
            batch_index: Offset into the preloaded data. This value must be
                between `0` and `tuples_per_device`. The amount of data to
                process is at most `max_per_device_batch_size`.

        Returns:
            The outputs of extra_ops evaluated over the batch.
        """
        feed_dict = {
            self._batch_index: batch_index,
            self._per_device_batch_size: self._loaded_per_device_batch_size,
            self._max_seq_len: self._loaded_max_seq_len,
        }
        for tower in self._towers:
            feed_dict.update(tower.loss_graph.extra_compute_grad_feed_dict())

        fetches = {"train": self._train_op}
        for tower_num, tower in enumerate(self._towers):
            tower_fetch = tower.loss_graph._get_grad_and_stats_fetches()
            fetches["tower_{}".format(tower_num)] = tower_fetch

        return sess.run(fetches, feed_dict=feed_dict)

    def get_device_losses(self):
        """Returns the per-device policy copies (loss graphs), in tower order."""
        return [t.loss_graph for t in self._towers]

    def _setup_device(self, tower_i, device, device_input_placeholders, num_data_in):
        """Builds one tower (device-pinned data vars + policy copy + grads)."""
        assert num_data_in <= len(device_input_placeholders)
        with tf.device(device):
            with tf1.name_scope(TOWER_SCOPE_NAME + f"_{tower_i}"):
                device_input_batches = []
                device_input_slices = []
                for i, ph in enumerate(device_input_placeholders):
                    # Non-trainable variable that pins the loaded data on the
                    # device; (re-)filled via its initializer in `load_data()`.
                    current_batch = tf1.Variable(
                        ph, trainable=False, validate_shape=False, collections=[]
                    )
                    device_input_batches.append(current_batch)
                    if i < num_data_in:
                        scale = self._max_seq_len
                        granularity = self._max_seq_len
                    else:
                        scale = self._max_seq_len
                        granularity = 1
                    # Minibatch slice of the pinned data, selected dynamically
                    # via the batch-index/per-device-batch-size placeholders.
                    current_slice = tf.slice(
                        current_batch,
                        (
                            [self._batch_index // scale * granularity]
                            + [0] * len(ph.shape[1:])
                        ),
                        (
                            [self._per_device_batch_size // scale * granularity]
                            + [-1] * len(ph.shape[1:])
                        ),
                    )
                    current_slice.set_shape(ph.shape)
                    device_input_slices.append(current_slice)
                graph_obj = self.policy_copy(device_input_slices)
                device_grads = graph_obj.gradients(self.optimizers, graph_obj._losses)
            return _Tower(
                tf.group(*[batch.initializer for batch in device_input_batches]),
                device_grads,
                graph_obj,
            )
1300
+
1301
+
1302
# Each tower is a copy of the loss graph pinned to a specific device.
# Fields:
#   init_op: Op that copies the fed data into the tower's device-pinned
#       input variables (run once per `load_data()` call).
#   grads: The tower's (grad, var) list(s), as returned by `gradients()`.
#   loss_graph: The per-device policy copy holding the loss tensors.
_Tower = namedtuple("Tower", ["init_op", "grads", "loss_graph"])
1304
+
1305
+
1306
+ def _make_divisible_by(a, n):
1307
+ if type(a) is int:
1308
+ return a - a % n
1309
+ return a[0 : a.shape[0] - a.shape[0] % n]
1310
+
1311
+
1312
def _average_gradients(tower_grads):
    """Averages gradients across towers.

    Calculate the average gradient for each shared variable across all towers.
    Note that this function provides a synchronization point across all towers.

    Args:
        tower_grads: List of lists of (gradient, variable) tuples. The outer
            list is over individual gradients. The inner list is over the
            gradient calculation for each tower.

    Returns:
        List of pairs of (gradient, variable) where the gradient has been
        averaged across all towers.

    TODO(ekl): We could use NCCL if this becomes a bottleneck.
    """
    averaged = []
    # zip(*...) regroups the data: one tuple of per-tower (grad, var) pairs
    # per shared variable, e.g. ((grad0_gpu0, var0_gpu0), ..., (grad0_gpuN,
    # var0_gpuN)).
    for grad_and_vars in zip(*tower_grads):
        # Stack the non-None gradients along a new leading 'tower' axis.
        stacked = [tf.expand_dims(g, 0) for g, _ in grad_and_vars if g is not None]
        if not stacked:
            # No tower produced a gradient for this variable -> skip it.
            continue
        # Average over the 'tower' dimension.
        mean_grad = tf.reduce_mean(tf.concat(axis=0, values=stacked), 0)
        # The Variables are redundant (shared across towers), so returning
        # the first tower's pointer suffices.
        averaged.append((mean_grad, grad_and_vars[0][1]))
    return averaged
.venv/lib/python3.11/site-packages/ray/rllib/policy/dynamic_tf_policy_v2.py ADDED
@@ -0,0 +1,1047 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ import gymnasium as gym
3
+ import logging
4
+ import re
5
+ import tree # pip install dm_tree
6
+ from typing import Dict, List, Optional, Tuple, Type, Union
7
+
8
+ from ray.rllib.models.catalog import ModelCatalog
9
+ from ray.rllib.models.modelv2 import ModelV2
10
+ from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
11
+ from ray.rllib.policy.dynamic_tf_policy import TFMultiGPUTowerStack
12
+ from ray.rllib.policy.policy import Policy
13
+ from ray.rllib.policy.sample_batch import SampleBatch
14
+ from ray.rllib.policy.tf_policy import TFPolicy
15
+ from ray.rllib.policy.view_requirement import ViewRequirement
16
+ from ray.rllib.utils import force_list
17
+ from ray.rllib.utils.annotations import (
18
+ OldAPIStack,
19
+ OverrideToImplementCustomLogic,
20
+ OverrideToImplementCustomLogic_CallToSuperRecommended,
21
+ is_overridden,
22
+ override,
23
+ )
24
+ from ray.rllib.utils.debug import summarize
25
+ from ray.rllib.utils.framework import try_import_tf
26
+ from ray.rllib.utils.metrics import (
27
+ DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
28
+ NUM_GRAD_UPDATES_LIFETIME,
29
+ )
30
+ from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
31
+ from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space
32
+ from ray.rllib.utils.tf_utils import get_placeholder
33
+ from ray.rllib.utils.typing import (
34
+ AlgorithmConfigDict,
35
+ LocalOptimizer,
36
+ ModelGradients,
37
+ TensorType,
38
+ )
39
+ from ray.util.debug import log_once
40
+
41
+ tf1, tf, tfv = try_import_tf()
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
+ @OldAPIStack
47
+ class DynamicTFPolicyV2(TFPolicy):
48
+ """A TFPolicy that auto-defines placeholders dynamically at runtime.
49
+
50
+ This class is intended to be used and extended by sub-classing.
51
+ """
52
+
53
    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
        *,
        existing_inputs: Optional[Dict[str, "tf1.placeholder"]] = None,
        existing_model: Optional[ModelV2] = None,
    ):
        """Initializes a DynamicTFPolicyV2 instance.

        Args:
            obs_space: The policy's observation space.
            action_space: The policy's action space.
            config: The Algorithm's config dict.
            existing_inputs: When this policy is created as a (multi-GPU)
                tower copy, the already existing input placeholders to reuse.
            existing_model: An already built model to reuse; may also be a
                list whose first item is the model and whose remaining items
                are (attr_name, model) pairs to set on `self`.
        """
        self.observation_space = obs_space
        self.action_space = action_space
        self.config = config
        self.framework = "tf"
        self._seq_lens = None
        # A "tower" is a copy of another policy and reuses its placeholders.
        self._is_tower = existing_inputs is not None

        self.validate_spaces(obs_space, action_space, config)

        self.dist_class = self._init_dist_class()
        # Setup self.model.
        if existing_model and isinstance(existing_model, list):
            self.model = existing_model[0]
            # TODO: (sven) hack, but works for `target_[q_]?model`.
            for i in range(1, len(existing_model)):
                setattr(self, existing_model[i][0], existing_model[i][1])
        else:
            self.model = self.make_model()
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()

        self._init_state_inputs(existing_inputs)
        self._init_view_requirements()
        timestep, explore = self._init_input_dict_and_dummy_batch(existing_inputs)
        (
            sampled_action,
            sampled_action_logp,
            dist_inputs,
            self._policy_extra_action_fetches,
        ) = self._init_action_fetches(timestep, explore)

        # Phase 1 init.
        # Reuse an already active session if there is one.
        sess = tf1.get_default_session() or tf1.Session(
            config=tf1.ConfigProto(**self.config["tf_session_args"])
        )

        batch_divisibility_req = self.get_batch_divisibility_req()

        # Only pass prev-action/reward placeholders through if they were
        # actually accessed (presumably during the dummy forward pass).
        prev_action_input = (
            self._input_dict[SampleBatch.PREV_ACTIONS]
            if SampleBatch.PREV_ACTIONS in self._input_dict.accessed_keys
            else None
        )
        prev_reward_input = (
            self._input_dict[SampleBatch.PREV_REWARDS]
            if SampleBatch.PREV_REWARDS in self._input_dict.accessed_keys
            else None
        )

        super().__init__(
            observation_space=obs_space,
            action_space=action_space,
            config=config,
            sess=sess,
            obs_input=self._input_dict[SampleBatch.OBS],
            action_input=self._input_dict[SampleBatch.ACTIONS],
            sampled_action=sampled_action,
            sampled_action_logp=sampled_action_logp,
            dist_inputs=dist_inputs,
            dist_class=self.dist_class,
            loss=None,  # dynamically initialized on run
            loss_inputs=[],
            model=self.model,
            state_inputs=self._state_inputs,
            state_outputs=self._state_out,
            prev_action_input=prev_action_input,
            prev_reward_input=prev_reward_input,
            seq_lens=self._seq_lens,
            max_seq_len=config["model"].get("max_seq_len", 20),
            batch_divisibility_req=batch_divisibility_req,
            explore=explore,
            timestep=timestep,
        )
135
+
136
+ @staticmethod
137
+ def enable_eager_execution_if_necessary():
138
+ # This is static graph TF policy.
139
+ # Simply do nothing.
140
+ pass
141
+
142
+ @OverrideToImplementCustomLogic
143
+ def validate_spaces(
144
+ self,
145
+ obs_space: gym.spaces.Space,
146
+ action_space: gym.spaces.Space,
147
+ config: AlgorithmConfigDict,
148
+ ):
149
+ return {}
150
+
151
+ @OverrideToImplementCustomLogic
152
+ @override(Policy)
153
+ def loss(
154
+ self,
155
+ model: Union[ModelV2, "tf.keras.Model"],
156
+ dist_class: Type[TFActionDistribution],
157
+ train_batch: SampleBatch,
158
+ ) -> Union[TensorType, List[TensorType]]:
159
+ """Constructs loss computation graph for this TF1 policy.
160
+
161
+ Args:
162
+ model: The Model to calculate the loss for.
163
+ dist_class: The action distr. class.
164
+ train_batch: The training data.
165
+
166
+ Returns:
167
+ A single loss tensor or a list of loss tensors.
168
+ """
169
+ raise NotImplementedError
170
+
171
+ @OverrideToImplementCustomLogic
172
+ def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
173
+ """Stats function. Returns a dict of statistics.
174
+
175
+ Args:
176
+ train_batch: The SampleBatch (already) used for training.
177
+
178
+ Returns:
179
+ The stats dict.
180
+ """
181
+ return {}
182
+
183
+ @OverrideToImplementCustomLogic
184
+ def grad_stats_fn(
185
+ self, train_batch: SampleBatch, grads: ModelGradients
186
+ ) -> Dict[str, TensorType]:
187
+ """Gradient stats function. Returns a dict of statistics.
188
+
189
+ Args:
190
+ train_batch: The SampleBatch (already) used for training.
191
+
192
+ Returns:
193
+ The stats dict.
194
+ """
195
+ return {}
196
+
197
+ @OverrideToImplementCustomLogic
198
+ def make_model(self) -> ModelV2:
199
+ """Build underlying model for this Policy.
200
+
201
+ Returns:
202
+ The Model for the Policy to use.
203
+ """
204
+ # Default ModelV2 model.
205
+ _, logit_dim = ModelCatalog.get_action_dist(
206
+ self.action_space, self.config["model"]
207
+ )
208
+ return ModelCatalog.get_model_v2(
209
+ obs_space=self.observation_space,
210
+ action_space=self.action_space,
211
+ num_outputs=logit_dim,
212
+ model_config=self.config["model"],
213
+ framework="tf",
214
+ )
215
+
216
+ @OverrideToImplementCustomLogic
217
+ def compute_gradients_fn(
218
+ self, optimizer: LocalOptimizer, loss: TensorType
219
+ ) -> ModelGradients:
220
+ """Gradients computing function (from loss tensor, using local optimizer).
221
+
222
+ Args:
223
+ policy: The Policy object that generated the loss tensor and
224
+ that holds the given local optimizer.
225
+ optimizer: The tf (local) optimizer object to
226
+ calculate the gradients with.
227
+ loss: The loss tensor for which gradients should be
228
+ calculated.
229
+
230
+ Returns:
231
+ ModelGradients: List of the possibly clipped gradients- and variable
232
+ tuples.
233
+ """
234
+ return None
235
+
236
+ @OverrideToImplementCustomLogic
237
+ def apply_gradients_fn(
238
+ self,
239
+ optimizer: "tf.keras.optimizers.Optimizer",
240
+ grads: ModelGradients,
241
+ ) -> "tf.Operation":
242
+ """Gradients computing function (from loss tensor, using local optimizer).
243
+
244
+ Args:
245
+ optimizer: The tf (local) optimizer object to
246
+ calculate the gradients with.
247
+ grads: The gradient tensor to be applied.
248
+
249
+ Returns:
250
+ "tf.Operation": TF operation that applies supplied gradients.
251
+ """
252
+ return None
253
+
254
+ @OverrideToImplementCustomLogic
255
+ def action_sampler_fn(
256
+ self,
257
+ model: ModelV2,
258
+ *,
259
+ obs_batch: TensorType,
260
+ state_batches: TensorType,
261
+ **kwargs,
262
+ ) -> Tuple[TensorType, TensorType, TensorType, List[TensorType]]:
263
+ """Custom function for sampling new actions given policy.
264
+
265
+ Args:
266
+ model: Underlying model.
267
+ obs_batch: Observation tensor batch.
268
+ state_batches: Action sampling state batch.
269
+
270
+ Returns:
271
+ Sampled action
272
+ Log-likelihood
273
+ Action distribution inputs
274
+ Updated state
275
+ """
276
+ return None, None, None, None
277
+
278
+ @OverrideToImplementCustomLogic
279
+ def action_distribution_fn(
280
+ self,
281
+ model: ModelV2,
282
+ *,
283
+ obs_batch: TensorType,
284
+ state_batches: TensorType,
285
+ **kwargs,
286
+ ) -> Tuple[TensorType, type, List[TensorType]]:
287
+ """Action distribution function for this Policy.
288
+
289
+ Args:
290
+ model: Underlying model.
291
+ obs_batch: Observation tensor batch.
292
+ state_batches: Action sampling state batch.
293
+
294
+ Returns:
295
+ Distribution input.
296
+ ActionDistribution class.
297
+ State outs.
298
+ """
299
+ return None, None, None
300
+
301
+ @OverrideToImplementCustomLogic
302
+ def get_batch_divisibility_req(self) -> int:
303
+ """Get batch divisibility request.
304
+
305
+ Returns:
306
+ Size N. A sample batch must be of size K*N.
307
+ """
308
+ # By default, any sized batch is ok, so simply return 1.
309
+ return 1
310
+
311
+ @override(TFPolicy)
312
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
313
+ def extra_action_out_fn(self) -> Dict[str, TensorType]:
314
+ """Extra values to fetch and return from compute_actions().
315
+
316
+ Returns:
317
+ Dict[str, TensorType]: An extra fetch-dict to be passed to and
318
+ returned from the compute_actions() call.
319
+ """
320
+ extra_action_fetches = super().extra_action_out_fn()
321
+ extra_action_fetches.update(self._policy_extra_action_fetches)
322
+ return extra_action_fetches
323
+
324
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
325
+ def extra_learn_fetches_fn(self) -> Dict[str, TensorType]:
326
+ """Extra stats to be reported after gradient computation.
327
+
328
+ Returns:
329
+ Dict[str, TensorType]: An extra fetch-dict.
330
+ """
331
+ return {}
332
+
333
+ @override(TFPolicy)
334
+ def extra_compute_grad_fetches(self):
335
+ return dict({LEARNER_STATS_KEY: {}}, **self.extra_learn_fetches_fn())
336
+
337
+ @override(Policy)
338
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
339
+ def postprocess_trajectory(
340
+ self,
341
+ sample_batch: SampleBatch,
342
+ other_agent_batches: Optional[SampleBatch] = None,
343
+ episode=None,
344
+ ):
345
+ """Post process trajectory in the format of a SampleBatch.
346
+
347
+ Args:
348
+ sample_batch: sample_batch: batch of experiences for the policy,
349
+ which will contain at most one episode trajectory.
350
+ other_agent_batches: In a multi-agent env, this contains a
351
+ mapping of agent ids to (policy, agent_batch) tuples
352
+ containing the policy and experiences of the other agents.
353
+ episode: An optional multi-agent episode object to provide
354
+ access to all of the internal episode state, which may
355
+ be useful for model-based or multi-agent algorithms.
356
+
357
+ Returns:
358
+ The postprocessed sample batch.
359
+ """
360
+ return Policy.postprocess_trajectory(self, sample_batch)
361
+
362
+ @override(TFPolicy)
363
+ @OverrideToImplementCustomLogic
364
+ def optimizer(
365
+ self,
366
+ ) -> Union["tf.keras.optimizers.Optimizer", List["tf.keras.optimizers.Optimizer"]]:
367
+ """TF optimizer to use for policy optimization.
368
+
369
+ Returns:
370
+ A local optimizer or a list of local optimizers to use for this
371
+ Policy's Model.
372
+ """
373
+ return super().optimizer()
374
+
375
+ def _init_dist_class(self):
376
+ if is_overridden(self.action_sampler_fn) or is_overridden(
377
+ self.action_distribution_fn
378
+ ):
379
+ if not is_overridden(self.make_model):
380
+ raise ValueError(
381
+ "`make_model` is required if `action_sampler_fn` OR "
382
+ "`action_distribution_fn` is given"
383
+ )
384
+ return None
385
+ else:
386
+ dist_class, _ = ModelCatalog.get_action_dist(
387
+ self.action_space, self.config["model"]
388
+ )
389
+ return dist_class
390
+
391
+ def _init_view_requirements(self):
392
+ # If ViewRequirements are explicitly specified.
393
+ if getattr(self, "view_requirements", None):
394
+ return
395
+
396
+ # Use default settings.
397
+ # Add NEXT_OBS, STATE_IN_0.., and others.
398
+ self.view_requirements = self._get_default_view_requirements()
399
+ # Combine view_requirements for Model and Policy.
400
+ # TODO(jungong) : models will not carry view_requirements once they
401
+ # are migrated to be organic Keras models.
402
+ self.view_requirements.update(self.model.view_requirements)
403
+ # Disable env-info placeholder.
404
+ if SampleBatch.INFOS in self.view_requirements:
405
+ self.view_requirements[SampleBatch.INFOS].used_for_training = False
406
+
407
+ def _init_state_inputs(self, existing_inputs: Dict[str, "tf1.placeholder"]):
408
+ """Initialize input placeholders.
409
+
410
+ Args:
411
+ existing_inputs: existing placeholders.
412
+ """
413
+ if existing_inputs:
414
+ self._state_inputs = [
415
+ v for k, v in existing_inputs.items() if k.startswith("state_in_")
416
+ ]
417
+ # Placeholder for RNN time-chunk valid lengths.
418
+ if self._state_inputs:
419
+ self._seq_lens = existing_inputs[SampleBatch.SEQ_LENS]
420
+ # Create new input placeholders.
421
+ else:
422
+ self._state_inputs = [
423
+ get_placeholder(
424
+ space=vr.space,
425
+ time_axis=not isinstance(vr.shift, int),
426
+ name=k,
427
+ )
428
+ for k, vr in self.model.view_requirements.items()
429
+ if k.startswith("state_in_")
430
+ ]
431
+ # Placeholder for RNN time-chunk valid lengths.
432
+ if self._state_inputs:
433
+ self._seq_lens = tf1.placeholder(
434
+ dtype=tf.int32, shape=[None], name="seq_lens"
435
+ )
436
+
437
    def _init_input_dict_and_dummy_batch(
        self, existing_inputs: Dict[str, "tf1.placeholder"]
    ) -> Tuple[Union[int, TensorType], Union[bool, TensorType]]:
        """Initializes `self._input_dict` and `self._dummy_batch`.

        Args:
            existing_inputs: When copying a policy, this specifies an existing
                dict of placeholders to use instead of defining new ones.

        Returns:
            timestep: training timestep (placeholder or re-used tensor).
            explore: whether this policy should explore (placeholder, or the
                constant False for tower copies).
        """
        # Setup standard placeholders.
        if self._is_tower:
            # Tower copies re-use the main policy's placeholders and never
            # explore on their own.
            assert existing_inputs is not None
            timestep = existing_inputs["timestep"]
            explore = False
            (
                self._input_dict,
                self._dummy_batch,
            ) = self._create_input_dict_and_dummy_batch(
                self.view_requirements, existing_inputs
            )
        else:
            # Placeholder for (sampling steps) timestep (int).
            timestep = tf1.placeholder_with_default(
                tf.zeros((), dtype=tf.int64), (), name="timestep"
            )
            # Placeholder for `is_exploring` flag.
            explore = tf1.placeholder_with_default(True, (), name="is_exploring")
            (
                self._input_dict,
                self._dummy_batch,
            ) = self._create_input_dict_and_dummy_batch(self.view_requirements, {})

        # Placeholder for `is_training` flag.
        self._input_dict.set_training(self._get_is_training_placeholder())

        return timestep, explore
477
+
478
    def _create_input_dict_and_dummy_batch(self, view_requirements, existing_inputs):
        """Creates input_dict and dummy_batch for loss initialization.

        Used for managing the Policy's input placeholders and for loss
        initialization.
        Input_dict: Str -> tf.placeholders, dummy_batch: str -> np.arrays.

        Args:
            view_requirements: The view requirements dict.
            existing_inputs (Dict[str, tf.placeholder]): A dict of already
                existing placeholders.

        Returns:
            Tuple[Dict[str, tf.placeholder], Dict[str, np.ndarray]]: The
                input_dict/dummy_batch tuple.
        """
        input_dict = {}
        for view_col, view_req in view_requirements.items():
            # Point state_in to the already existing self._state_inputs.
            mo = re.match(r"state_in_(\d+)", view_col)
            if mo is not None:
                input_dict[view_col] = self._state_inputs[int(mo.group(1))]
            # State-outs (no placeholders needed).
            elif view_col.startswith("state_out_"):
                continue
            # Skip action dist inputs placeholder (do later).
            elif view_col == SampleBatch.ACTION_DIST_INPUTS:
                continue
            # This is a tower: Input placeholders already exist.
            elif view_col in existing_inputs:
                input_dict[view_col] = existing_inputs[view_col]
            # All others: create fresh placeholders (training columns only).
            else:
                time_axis = not isinstance(view_req.shift, int)
                if view_req.used_for_training:
                    # Create a +time-axis placeholder if the shift is not an
                    # int (range or list of ints).
                    # Do not flatten actions if action flattening disabled.
                    if self.config.get("_disable_action_flattening") and view_col in [
                        SampleBatch.ACTIONS,
                        SampleBatch.PREV_ACTIONS,
                    ]:
                        flatten = False
                    # Do not flatten observations if no preprocessor API used.
                    elif (
                        view_col in [SampleBatch.OBS, SampleBatch.NEXT_OBS]
                        and self.config["_disable_preprocessor_api"]
                    ):
                        flatten = False
                    # Flatten everything else.
                    else:
                        flatten = True
                    input_dict[view_col] = get_placeholder(
                        space=view_req.space,
                        name=view_col,
                        time_axis=time_axis,
                        flatten=flatten,
                    )
        # Fixed-size dummy batch used for test forward passes and loss init.
        dummy_batch = self._get_dummy_batch_from_view_requirements(batch_size=32)

        return SampleBatch(input_dict, seq_lens=self._seq_lens), dummy_batch
539
+
540
    def _init_action_fetches(
        self, timestep: Union[int, TensorType], explore: Union[bool, TensorType]
    ) -> Tuple[TensorType, TensorType, TensorType, Dict[str, TensorType]]:
        """Creates action-related fetches for base Policy and loss init.

        NOTE(review): the original annotation declared a 5-tuple including a
        `type`; the function in fact returns 4 values, so the annotation has
        been corrected here (code unchanged).

        Args:
            timestep: The (placeholder for the) sampling timestep.
            explore: The (placeholder for the) is-exploring flag.

        Returns:
            A 4-tuple: (sampled action, its log-prob, action-dist inputs,
            extra action fetches dict). All tensors are None for tower copies.
        """
        # Multi-GPU towers do not need any action computing/exploration
        # graphs.
        sampled_action = None
        sampled_action_logp = None
        dist_inputs = None
        extra_action_fetches = {}
        self._state_out = None
        if not self._is_tower:
            # Create the Exploration object to use for this Policy.
            self.exploration = self._create_exploration()

            # Fully customized action generation (e.g., custom policy).
            if is_overridden(self.action_sampler_fn):
                (
                    sampled_action,
                    sampled_action_logp,
                    dist_inputs,
                    self._state_out,
                ) = self.action_sampler_fn(
                    self.model,
                    obs_batch=self._input_dict[SampleBatch.OBS],
                    state_batches=self._state_inputs,
                    seq_lens=self._seq_lens,
                    prev_action_batch=self._input_dict.get(SampleBatch.PREV_ACTIONS),
                    prev_reward_batch=self._input_dict.get(SampleBatch.PREV_REWARDS),
                    explore=explore,
                    is_training=self._input_dict.is_training,
                )
            # Distribution generation is customized, e.g., DQN, DDPG.
            else:
                if is_overridden(self.action_distribution_fn):
                    # Try new action_distribution_fn signature, supporting
                    # state_batches and seq_lens.
                    in_dict = self._input_dict
                    (
                        dist_inputs,
                        self.dist_class,
                        self._state_out,
                    ) = self.action_distribution_fn(
                        self.model,
                        obs_batch=in_dict[SampleBatch.OBS],
                        state_batches=self._state_inputs,
                        seq_lens=self._seq_lens,
                        explore=explore,
                        timestep=timestep,
                        is_training=in_dict.is_training,
                    )
                # Default distribution generation behavior:
                # Pass through model. E.g., PG, PPO.
                else:
                    if isinstance(self.model, tf.keras.Model):
                        # Native Keras models also return extra fetches.
                        dist_inputs, self._state_out, extra_action_fetches = self.model(
                            self._input_dict
                        )
                    else:
                        dist_inputs, self._state_out = self.model(self._input_dict)

                action_dist = self.dist_class(dist_inputs, self.model)

                # Using exploration to get final action (e.g. via sampling).
                (
                    sampled_action,
                    sampled_action_logp,
                ) = self.exploration.get_exploration_action(
                    action_distribution=action_dist, timestep=timestep, explore=explore
                )

            if dist_inputs is not None:
                extra_action_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

            if sampled_action_logp is not None:
                extra_action_fetches[SampleBatch.ACTION_LOGP] = sampled_action_logp
                extra_action_fetches[SampleBatch.ACTION_PROB] = tf.exp(
                    tf.cast(sampled_action_logp, tf.float32)
                )

        return (
            sampled_action,
            sampled_action_logp,
            dist_inputs,
            extra_action_fetches,
        )
626
+
627
+ def _init_optimizers(self):
628
+ # Create the optimizer/exploration optimizer here. Some initialization
629
+ # steps (e.g. exploration postprocessing) may need this.
630
+ optimizers = force_list(self.optimizer())
631
+ if self.exploration:
632
+ optimizers = self.exploration.get_exploration_optimizer(optimizers)
633
+
634
+ # No optimizers produced -> Return.
635
+ if not optimizers:
636
+ return
637
+
638
+ # The list of local (tf) optimizers (one per loss term).
639
+ self._optimizers = optimizers
640
+ # Backward compatibility.
641
+ self._optimizer = optimizers[0]
642
+
643
    def maybe_initialize_optimizer_and_loss(self):
        """Initializes optimizers, loss graph and multi-GPU tower stacks.

        Tower copies skip loss/optimizer setup entirely and only run variable
        initialization.
        """
        # We don't need to initialize loss calculation for MultiGPUTowerStack.
        if self._is_tower:
            self.get_session().run(tf1.global_variables_initializer())
            return

        # Loss initialization and model/postprocessing test calls.
        self._init_optimizers()
        self._initialize_loss_from_dummy_batch(auto_remove_unneeded_view_reqs=True)

        # Create MultiGPUTowerStacks, if we have at least one actual
        # GPU or >1 CPUs (fake GPUs).
        if len(self.devices) > 1 or any("gpu" in d for d in self.devices):
            # Per-GPU graph copies created here must share vars with the
            # policy. Therefore, `reuse` is set to tf1.AUTO_REUSE because
            # Adam nodes are created after all of the device copies are
            # created.
            with tf1.variable_scope("", reuse=tf1.AUTO_REUSE):
                self.multi_gpu_tower_stacks = [
                    TFMultiGPUTowerStack(policy=self)
                    for _ in range(self.config.get("num_multi_gpu_tower_stacks", 1))
                ]

        # Initialize again after loss and tower init.
        self.get_session().run(tf1.global_variables_initializer())
668
+
669
    @override(Policy)
    def _initialize_loss_from_dummy_batch(
        self, auto_remove_unneeded_view_reqs: bool = True
    ) -> None:
        """Builds the loss graph from a dummy batch and prunes view-reqs.

        Runs test forward/postprocessing passes on `self._dummy_batch`,
        constructs the loss via `_do_loss_init`, and (optionally) trims view
        requirements that turn out to be unused by postprocessing/training.

        Args:
            auto_remove_unneeded_view_reqs: Whether to auto-remove
                requirements not needed by postprocessing or the loss.
        """
        # Test calls depend on variable init, so initialize model first.
        self.get_session().run(tf1.global_variables_initializer())

        # Fields that have not been accessed are not needed for action
        # computations -> Tag them as `used_for_compute_actions=False`.
        for key, view_req in self.view_requirements.items():
            if (
                not key.startswith("state_in_")
                and key not in self._input_dict.accessed_keys
            ):
                view_req.used_for_compute_actions = False
        # Register extra-action fetches in dummy batch, input dict, and
        # view requirements (with a synthetic Box space matching the tensor).
        for key, value in self.extra_action_out_fn().items():
            self._dummy_batch[key] = get_dummy_batch_for_space(
                gym.spaces.Box(
                    -1.0, 1.0, shape=value.shape.as_list()[1:], dtype=value.dtype.name
                ),
                batch_size=len(self._dummy_batch),
            )
            self._input_dict[key] = get_placeholder(value=value, name=key)
            if key not in self.view_requirements:
                logger.info("Adding extra-action-fetch `{}` to view-reqs.".format(key))
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(
                        -1.0,
                        1.0,
                        shape=value.shape.as_list()[1:],
                        dtype=value.dtype.name,
                    ),
                    used_for_compute_actions=False,
                )
        dummy_batch = self._dummy_batch

        logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
        self.exploration.postprocess_trajectory(self, dummy_batch, self.get_session())
        _ = self.postprocess_trajectory(dummy_batch)
        # Add new columns automatically to (loss) input_dict.
        for key in dummy_batch.added_keys:
            if key not in self._input_dict:
                self._input_dict[key] = get_placeholder(
                    value=dummy_batch[key], name=key
                )
            if key not in self.view_requirements:
                self.view_requirements[key] = ViewRequirement(
                    space=gym.spaces.Box(
                        -1.0,
                        1.0,
                        shape=dummy_batch[key].shape[1:],
                        dtype=dummy_batch[key].dtype,
                    ),
                    used_for_compute_actions=False,
                )

        # The train batch feeds the loss with all known placeholders.
        train_batch = SampleBatch(
            dict(self._input_dict, **self._loss_input_dict),
            _is_training=True,
        )

        if self._state_inputs:
            train_batch[SampleBatch.SEQ_LENS] = self._seq_lens
            self._loss_input_dict.update(
                {SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]}
            )

        self._loss_input_dict.update({k: v for k, v in train_batch.items()})

        if log_once("loss_init"):
            logger.debug(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(train_batch)
                )
            )

        losses = self._do_loss_init(train_batch)

        all_accessed_keys = (
            train_batch.accessed_keys
            | dummy_batch.accessed_keys
            | dummy_batch.added_keys
            | set(self.model.view_requirements.keys())
        )

        TFPolicy._initialize_loss(
            self,
            losses,
            [(k, v) for k, v in train_batch.items() if k in all_accessed_keys]
            + (
                [(SampleBatch.SEQ_LENS, train_batch[SampleBatch.SEQ_LENS])]
                if SampleBatch.SEQ_LENS in train_batch
                else []
            ),
        )

        if "is_training" in self._loss_input_dict:
            del self._loss_input_dict["is_training"]

        # Call the grads stats fn.
        # TODO: (sven) rename to simply stats_fn to match eager and torch.
        self._stats_fetches.update(self.grad_stats_fn(train_batch, self._grads))

        # Add new columns automatically to view-reqs.
        if auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = train_batch.accessed_keys | dummy_batch.accessed_keys
            # Tag those only needed for post-processing (with some exceptions).
            for key in dummy_batch.accessed_keys:
                if (
                    key not in train_batch.accessed_keys
                    and key not in self.model.view_requirements
                    and key
                    not in [
                        SampleBatch.EPS_ID,
                        SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID,
                        SampleBatch.TERMINATEDS,
                        SampleBatch.TRUNCATEDS,
                        SampleBatch.REWARDS,
                        SampleBatch.INFOS,
                        SampleBatch.T,
                        SampleBatch.OBS_EMBEDS,
                    ]
                ):
                    if key in self.view_requirements:
                        self.view_requirements[key].used_for_training = False
                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Remove those not needed at all (leave those that are needed
            # by Sampler to properly execute sample collection).
            # Also always leave TERMINATEDS, TRUNCATEDS, REWARDS, and INFOS,
            # no matter what.
            for key in list(self.view_requirements.keys()):
                if (
                    key not in all_accessed_keys
                    and key
                    not in [
                        SampleBatch.EPS_ID,
                        SampleBatch.AGENT_INDEX,
                        SampleBatch.UNROLL_ID,
                        SampleBatch.TERMINATEDS,
                        SampleBatch.TRUNCATEDS,
                        SampleBatch.REWARDS,
                        SampleBatch.INFOS,
                        SampleBatch.T,
                    ]
                    and key not in self.model.view_requirements
                ):
                    # If user deleted this key manually in postprocessing
                    # fn, warn about it and do not remove from
                    # view-requirements.
                    if key in dummy_batch.deleted_keys:
                        logger.warning(
                            "SampleBatch key '{}' was deleted manually in "
                            "postprocessing function! RLlib will "
                            "automatically remove non-used items from the "
                            "data stream. Remove the `del` from your "
                            "postprocessing function.".format(key)
                        )
                    # If we are not writing output to disk, safe to erase
                    # this key to save space in the sample batch.
                    elif self.config["output"] is None:
                        del self.view_requirements[key]

                    if key in self._loss_input_dict:
                        del self._loss_input_dict[key]
            # Add those data_cols (again) that are missing and have
            # dependencies by view_cols.
            for key in list(self.view_requirements.keys()):
                vr = self.view_requirements[key]
                if (
                    vr.data_col is not None
                    and vr.data_col not in self.view_requirements
                ):
                    used_for_training = vr.data_col in train_batch.accessed_keys
                    self.view_requirements[vr.data_col] = ViewRequirement(
                        space=vr.space, used_for_training=used_for_training
                    )

        # Non-RNN loss inputs (state-ins and seq-lens excluded); used by
        # `copy()` when building tower replicas.
        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }
854
+
855
+ def _do_loss_init(self, train_batch: SampleBatch):
856
+ losses = self.loss(self.model, self.dist_class, train_batch)
857
+ losses = force_list(losses)
858
+ self._stats_fetches.update(self.stats_fn(train_batch))
859
+ # Override the update ops to be those of the model.
860
+ self._update_ops = []
861
+ if not isinstance(self.model, tf.keras.Model):
862
+ self._update_ops = self.model.update_ops()
863
+ return losses
864
+
865
    @override(TFPolicy)
    def copy(self, existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> TFPolicy:
        """Creates a copy of self using existing input placeholders.

        Args:
            existing_inputs: Flat list of placeholders in the conventional
                order: non-RNN loss inputs, then state-ins, then seq-lens.

        Returns:
            A new, loss-initialized policy instance sharing model/vars with
            this one (used to build multi-GPU tower replicas).
        """

        flat_loss_inputs = tree.flatten(self._loss_input_dict)
        flat_loss_inputs_no_rnn = tree.flatten(self._loss_input_dict_no_rnn)

        # Note that there might be RNN state inputs at the end of the list
        if len(flat_loss_inputs) != len(existing_inputs):
            raise ValueError(
                "Tensor list mismatch",
                self._loss_input_dict,
                self._state_inputs,
                existing_inputs,
            )
        # Shape-check only the non-RNN inputs (they come first by convention).
        for i, v in enumerate(flat_loss_inputs_no_rnn):
            if v.shape.as_list() != existing_inputs[i].shape.as_list():
                raise ValueError(
                    "Tensor shape mismatch", i, v.shape, existing_inputs[i].shape
                )
        # By convention, the loss inputs are followed by state inputs and then
        # the seq len tensor.
        rnn_inputs = []
        for i in range(len(self._state_inputs)):
            rnn_inputs.append(
                (
                    "state_in_{}".format(i),
                    existing_inputs[len(flat_loss_inputs_no_rnn) + i],
                )
            )
        if rnn_inputs:
            rnn_inputs.append((SampleBatch.SEQ_LENS, existing_inputs[-1]))
        # Re-nest the flat placeholder list into the loss-input structure.
        existing_inputs_unflattened = tree.unflatten_as(
            self._loss_input_dict_no_rnn,
            existing_inputs[: len(flat_loss_inputs_no_rnn)],
        )
        input_dict = OrderedDict(
            [("is_exploring", self._is_exploring), ("timestep", self._timestep)]
            + [
                (k, existing_inputs_unflattened[k])
                for i, k in enumerate(self._loss_input_dict_no_rnn.keys())
            ]
            + rnn_inputs
        )

        instance = self.__class__(
            self.observation_space,
            self.action_space,
            self.config,
            existing_inputs=input_dict,
            existing_model=[
                self.model,
                # Deprecated: Target models should all reside under
                # `policy.target_model` now.
                ("target_q_model", getattr(self, "target_q_model", None)),
                ("target_model", getattr(self, "target_model", None)),
            ],
        )

        instance._loss_input_dict = input_dict
        losses = instance._do_loss_init(SampleBatch(input_dict))
        loss_inputs = [
            (k, existing_inputs_unflattened[k])
            for i, k in enumerate(self._loss_input_dict_no_rnn.keys())
        ]

        TFPolicy._initialize_loss(instance, losses, loss_inputs)
        instance._stats_fetches.update(
            instance.grad_stats_fn(input_dict, instance._grads)
        )
        return instance
936
+
937
+ @override(Policy)
938
+ def get_initial_state(self) -> List[TensorType]:
939
+ if self.model:
940
+ return self.model.get_initial_state()
941
+ else:
942
+ return []
943
+
944
    @override(Policy)
    def load_batch_into_buffer(
        self,
        batch: SampleBatch,
        buffer_index: int = 0,
    ) -> int:
        """Loads `batch` into the device buffer at `buffer_index`.

        Args:
            batch: The SampleBatch to load.
            buffer_index: Which tower-stack buffer to load into (must be 0 in
                the single-CPU case).

        Returns:
            The number of tuples loaded.
        """
        # Set the is_training flag of the batch.
        batch.set_training(True)

        # Shortcut for 1 CPU only: Store batch in
        # `self._loaded_single_cpu_batch`.
        if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
            assert buffer_index == 0
            self._loaded_single_cpu_batch = batch
            return len(batch)

        input_dict = self._get_loss_inputs_dict(batch, shuffle=False)
        data_keys = tree.flatten(self._loss_input_dict_no_rnn)
        if self._state_inputs:
            state_keys = self._state_inputs + [self._seq_lens]
        else:
            state_keys = []
        inputs = [input_dict[k] for k in data_keys]
        state_inputs = [input_dict[k] for k in state_keys]

        return self.multi_gpu_tower_stacks[buffer_index].load_data(
            sess=self.get_session(),
            inputs=inputs,
            state_inputs=state_inputs,
            num_grad_updates=batch.num_grad_updates,
        )
975
+
976
+ @override(Policy)
977
+ def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
978
+ # Shortcut for 1 CPU only: Batch should already be stored in
979
+ # `self._loaded_single_cpu_batch`.
980
+ if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
981
+ assert buffer_index == 0
982
+ return (
983
+ len(self._loaded_single_cpu_batch)
984
+ if self._loaded_single_cpu_batch is not None
985
+ else 0
986
+ )
987
+
988
+ return self.multi_gpu_tower_stacks[buffer_index].num_tuples_loaded
989
+
990
    @override(Policy)
    def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
        """Runs one learning update on the pre-loaded batch.

        Args:
            offset: Sample offset into the loaded batch at which to start the
                minibatch.
            buffer_index: Which tower-stack buffer to learn from (must be 0
                in the single-CPU case).

        Returns:
            The fetch dict from the update (learner stats plus grad-update
            counters).
        """
        # Shortcut for 1 CPU only: Batch should already be stored in
        # `self._loaded_single_cpu_batch`.
        if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
            assert buffer_index == 0
            if self._loaded_single_cpu_batch is None:
                raise ValueError(
                    "Must call Policy.load_batch_into_buffer() before "
                    "Policy.learn_on_loaded_batch()!"
                )
            # Get the correct slice of the already loaded batch to use,
            # based on offset and batch size.
            batch_size = self.config.get("minibatch_size")
            if batch_size is None:
                # Fall back to legacy config keys.
                batch_size = self.config.get(
                    "sgd_minibatch_size", self.config["train_batch_size"]
                )

            if batch_size >= len(self._loaded_single_cpu_batch):
                sliced_batch = self._loaded_single_cpu_batch
            else:
                sliced_batch = self._loaded_single_cpu_batch.slice(
                    start=offset, end=offset + batch_size
                )
            return self.learn_on_batch(sliced_batch)

        tower_stack = self.multi_gpu_tower_stacks[buffer_index]
        results = tower_stack.optimize(self.get_session(), offset)
        self.num_grad_updates += 1

        results.update(
            {
                NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                # -1, b/c we have to measure this diff before we do the update above.
                DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                    self.num_grad_updates - 1 - (tower_stack.num_grad_updates or 0)
                ),
            }
        )

        return results
1032
+
1033
+ @override(TFPolicy)
1034
+ def gradients(self, optimizer, loss):
1035
+ optimizers = force_list(optimizer)
1036
+ losses = force_list(loss)
1037
+
1038
+ if is_overridden(self.compute_gradients_fn):
1039
+ # New API: Allow more than one optimizer -> Return a list of
1040
+ # lists of gradients.
1041
+ if self.config["_tf_policy_handles_more_than_one_loss"]:
1042
+ return self.compute_gradients_fn(optimizers, losses)
1043
+ # Old API: Return a single List of gradients.
1044
+ else:
1045
+ return self.compute_gradients_fn(optimizers[0], losses[0])
1046
+ else:
1047
+ return super().gradients(optimizers, losses)
.venv/lib/python3.11/site-packages/ray/rllib/policy/eager_tf_policy.py ADDED
@@ -0,0 +1,1051 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Eager mode TF policy built using build_tf_policy().
2
+
3
+ It supports both traced and non-traced eager execution modes."""
4
+
5
+ import functools
6
+ import logging
7
+ import os
8
+ import threading
9
+ from typing import Dict, List, Optional, Tuple, Union
10
+
11
+ import tree # pip install dm_tree
12
+
13
+ from ray.rllib.models.catalog import ModelCatalog
14
+ from ray.rllib.models.repeated_values import RepeatedValues
15
+ from ray.rllib.policy.policy import Policy, PolicyState
16
+ from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
17
+ from ray.rllib.policy.sample_batch import SampleBatch
18
+ from ray.rllib.utils import add_mixins, force_list
19
+ from ray.rllib.utils.annotations import OldAPIStack, override
20
+ from ray.rllib.utils.deprecation import (
21
+ DEPRECATED_VALUE,
22
+ deprecation_warning,
23
+ )
24
+ from ray.rllib.utils.error import ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL
25
+ from ray.rllib.utils.framework import try_import_tf
26
+ from ray.rllib.utils.metrics import (
27
+ DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
28
+ NUM_AGENT_STEPS_TRAINED,
29
+ NUM_GRAD_UPDATES_LIFETIME,
30
+ )
31
+ from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
32
+ from ray.rllib.utils.numpy import convert_to_numpy
33
+ from ray.rllib.utils.spaces.space_utils import normalize_action
34
+ from ray.rllib.utils.tf_utils import get_gpu_devices
35
+ from ray.rllib.utils.threading import with_lock
36
+ from ray.rllib.utils.typing import (
37
+ LocalOptimizer,
38
+ ModelGradients,
39
+ TensorType,
40
+ TensorStructType,
41
+ )
42
+ from ray.util.debug import log_once
43
+
44
+ tf1, tf, tfv = try_import_tf()
45
+ logger = logging.getLogger(__name__)
46
+
47
+
48
def _convert_to_tf(x, dtype=None):
    """Recursively converts `x` into tf tensors where possible.

    Special cases:
    - SampleBatch: converted as a plain dict, with the INFOS column dropped
      (infos are arbitrary python objects and cannot become tensors).
    - Policy: returned as-is (not convertible).
    - RepeatedValues: rebuilt with converted inner values.
    - Any other (possibly nested) structure: leaves are converted via
      `tf.convert_to_tensor`, skipping leaves that are None or already
      tensors.

    Args:
        x: The input to convert (struct, SampleBatch, Policy, or leaf value).
        dtype: Optional tf dtype to force on converted leaf tensors.

    Returns:
        The converted structure (or `x` unchanged when it is None/a Policy).
    """
    if isinstance(x, SampleBatch):
        # Drop INFOS: its values are not tensor-convertible.
        dict_ = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
        return tree.map_structure(_convert_to_tf, dict_)
    elif isinstance(x, Policy):
        return x
    # Special handling of "Repeated" values.
    elif isinstance(x, RepeatedValues):
        return RepeatedValues(
            tree.map_structure(_convert_to_tf, x.values), x.lengths, x.max_len
        )

    if x is not None:
        d = dtype
        # Per-leaf conversion: recurse into RepeatedValues leaves; convert
        # plain leaves unless they are None or already tf tensors.
        return tree.map_structure(
            lambda f: _convert_to_tf(f, d)
            if isinstance(f, RepeatedValues)
            else tf.convert_to_tensor(f, d)
            if f is not None and not tf.is_tensor(f)
            else f,
            x,
        )

    return x
72
+
73
+
74
def _convert_to_numpy(x):
    """Converts all tf.Tensor leaves in the (nested) structure `x` to numpy.

    Non-tensor leaves are returned unchanged.

    Args:
        x: A (possibly nested) structure containing tf tensors and/or other
            leaf values.

    Returns:
        The same structure with every tf.Tensor replaced by its numpy value.

    Raises:
        TypeError: If some value in the structure cannot be converted.
    """

    def _map(value):
        # Only eager tensors carry a `.numpy()` method; leave anything else.
        if isinstance(value, tf.Tensor):
            return value.numpy()
        return value

    try:
        return tf.nest.map_structure(_map, x)
    except AttributeError as e:
        # Chain the original error so the actual failing access is not lost.
        raise TypeError(
            ("Object of type {} has no method to convert to numpy.").format(type(x))
        ) from e
86
+
87
+
88
def _convert_eager_inputs(func):
    """Decorator converting `func`'s args to tf tensors (eager mode only).

    Keyword args `info_batch` and `episodes` are dropped entirely (they hold
    non-tensor python objects); the `timestep` kwarg is forced to tf.int64.
    Outside eager execution, the call passes through unchanged.
    """

    @functools.wraps(func)
    def _func(*args, **kwargs):
        if not tf.executing_eagerly():
            return func(*args, **kwargs)
        converted_args = [_convert_to_tf(a) for a in args]
        # TODO: (sven) find a way to remove key-specific hacks.
        converted_kwargs = {}
        for key, value in kwargs.items():
            if key in {"info_batch", "episodes"}:
                continue
            forced_dtype = tf.int64 if key == "timestep" else None
            converted_kwargs[key] = _convert_to_tf(value, dtype=forced_dtype)
        return func(*converted_args, **converted_kwargs)

    return _func
104
+
105
+
106
def _convert_eager_outputs(func):
    """Decorator converting `func`'s return struct to numpy (eager mode only)."""

    @functools.wraps(func)
    def _func(*args, **kwargs):
        result = func(*args, **kwargs)
        if not tf.executing_eagerly():
            return result
        # Map every tensor leaf in the returned structure to numpy.
        return tf.nest.map_structure(_convert_to_numpy, result)

    return _func
115
+
116
+
117
+ def _disallow_var_creation(next_creator, **kw):
118
+ v = next_creator(**kw)
119
+ raise ValueError(
120
+ "Detected a variable being created during an eager "
121
+ "forward pass. Variables should only be created during "
122
+ "model initialization: {}".format(v.name)
123
+ )
124
+
125
+
126
+ def _check_too_many_retraces(obj):
127
+ """Asserts that a given number of re-traces is not breached."""
128
+
129
+ def _func(self_, *args, **kwargs):
130
+ if (
131
+ self_.config.get("eager_max_retraces") is not None
132
+ and self_._re_trace_counter > self_.config["eager_max_retraces"]
133
+ ):
134
+ raise RuntimeError(
135
+ "Too many tf-eager re-traces detected! This could lead to"
136
+ " significant slow-downs (even slower than running in "
137
+ "tf-eager mode w/ `eager_tracing=False`). To switch off "
138
+ "these re-trace counting checks, set `eager_max_retraces`"
139
+ " in your config to None."
140
+ )
141
+ return obj(self_, *args, **kwargs)
142
+
143
+ return _func
144
+
145
+
146
@OldAPIStack
class EagerTFPolicy(Policy):
    """Dummy class to recognize any eagerized TFPolicy by its inheritance.

    Serves purely as a marker base class (e.g. for isinstance checks); it
    adds no behavior on top of `Policy`.
    """

    pass
151
+
152
+
153
def _traced_eager_policy(eager_policy_cls):
    """Wrapper class that enables tracing for all eager policy methods.

    This is enabled by the `--trace`/`eager_tracing=True` config when
    framework=tf2.

    The returned subclass lazily wraps each `_..._helper` method in a
    `tf.function` (with `reduce_retracing=True`) the first time the
    corresponding public method is called, then delegates to the parent's
    public method which invokes the (now traced) helper.
    """

    class TracedEagerPolicy(eager_policy_cls):
        def __init__(self, *args, **kwargs):
            # Flags tracking which helper methods have already been traced.
            self._traced_learn_on_batch_helper = False
            self._traced_compute_actions_helper = False
            self._traced_compute_gradients_helper = False
            self._traced_apply_gradients_helper = False
            super(TracedEagerPolicy, self).__init__(*args, **kwargs)

        @_check_too_many_retraces
        @override(Policy)
        def compute_actions_from_input_dict(
            self,
            input_dict: Dict[str, TensorType],
            explore: bool = None,
            timestep: Optional[int] = None,
            episodes=None,
            **kwargs,
        ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
            """Traced version of Policy.compute_actions_from_input_dict."""

            # Create a traced version of `self._compute_actions_helper`.
            if self._traced_compute_actions_helper is False and not self._no_tracing:
                self._compute_actions_helper = _convert_eager_inputs(
                    tf.function(
                        super(TracedEagerPolicy, self)._compute_actions_helper,
                        autograph=False,
                        reduce_retracing=True,
                    )
                )
                self._traced_compute_actions_helper = True

            # Now that the helper method is traced, call super's
            # `compute_actions_from_input_dict()` (which will call the traced helper).
            return super(TracedEagerPolicy, self).compute_actions_from_input_dict(
                input_dict=input_dict,
                explore=explore,
                timestep=timestep,
                episodes=episodes,
                **kwargs,
            )

        @_check_too_many_retraces
        @override(eager_policy_cls)
        def learn_on_batch(self, samples):
            """Traced version of Policy.learn_on_batch."""

            # Create a traced version of `self._learn_on_batch_helper`.
            if self._traced_learn_on_batch_helper is False and not self._no_tracing:
                self._learn_on_batch_helper = _convert_eager_inputs(
                    tf.function(
                        super(TracedEagerPolicy, self)._learn_on_batch_helper,
                        autograph=False,
                        reduce_retracing=True,
                    )
                )
                self._traced_learn_on_batch_helper = True

            # Now that the helper method is traced, call super's
            # apply_gradients (which will call the traced helper).
            return super(TracedEagerPolicy, self).learn_on_batch(samples)

        @_check_too_many_retraces
        @override(eager_policy_cls)
        def compute_gradients(self, samples: SampleBatch) -> ModelGradients:
            """Traced version of Policy.compute_gradients."""

            # Create a traced version of `self._compute_gradients_helper`.
            if self._traced_compute_gradients_helper is False and not self._no_tracing:
                self._compute_gradients_helper = _convert_eager_inputs(
                    tf.function(
                        super(TracedEagerPolicy, self)._compute_gradients_helper,
                        autograph=False,
                        reduce_retracing=True,
                    )
                )
                self._traced_compute_gradients_helper = True

            # Now that the helper method is traced, call super's
            # `compute_gradients()` (which will call the traced helper).
            return super(TracedEagerPolicy, self).compute_gradients(samples)

        @_check_too_many_retraces
        @override(Policy)
        def apply_gradients(self, grads: ModelGradients) -> None:
            """Traced version of Policy.apply_gradients."""

            # Create a traced version of `self._apply_gradients_helper`.
            if self._traced_apply_gradients_helper is False and not self._no_tracing:
                self._apply_gradients_helper = _convert_eager_inputs(
                    tf.function(
                        super(TracedEagerPolicy, self)._apply_gradients_helper,
                        autograph=False,
                        reduce_retracing=True,
                    )
                )
                self._traced_apply_gradients_helper = True

            # Now that the helper method is traced, call super's
            # `apply_gradients()` (which will call the traced helper).
            return super(TracedEagerPolicy, self).apply_gradients(grads)

        @classmethod
        def with_tracing(cls):
            # Already traced -> Return same class.
            return cls

    TracedEagerPolicy.__name__ = eager_policy_cls.__name__ + "_traced"
    TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__ + "_traced"
    return TracedEagerPolicy
269
+
270
+
271
+ class _OptimizerWrapper:
272
+ def __init__(self, tape):
273
+ self.tape = tape
274
+
275
+ def compute_gradients(self, loss, var_list):
276
+ return list(zip(self.tape.gradient(loss, var_list), var_list))
277
+
278
+
279
+ @OldAPIStack
280
+ def _build_eager_tf_policy(
281
+ name,
282
+ loss_fn,
283
+ get_default_config=None,
284
+ postprocess_fn=None,
285
+ stats_fn=None,
286
+ optimizer_fn=None,
287
+ compute_gradients_fn=None,
288
+ apply_gradients_fn=None,
289
+ grad_stats_fn=None,
290
+ extra_learn_fetches_fn=None,
291
+ extra_action_out_fn=None,
292
+ validate_spaces=None,
293
+ before_init=None,
294
+ before_loss_init=None,
295
+ after_init=None,
296
+ make_model=None,
297
+ action_sampler_fn=None,
298
+ action_distribution_fn=None,
299
+ mixins=None,
300
+ get_batch_divisibility_req=None,
301
+ # Deprecated args.
302
+ obs_include_prev_action_reward=DEPRECATED_VALUE,
303
+ extra_action_fetches_fn=None,
304
+ gradients_fn=None,
305
+ ):
306
+ """Build an eager TF policy.
307
+
308
+ An eager policy runs all operations in eager mode, which makes debugging
309
+ much simpler, but has lower performance.
310
+
311
+ You shouldn't need to call this directly. Rather, prefer to build a TF
312
+ graph policy and use set `.framework("tf2", eager_tracing=False) in your
313
+ AlgorithmConfig to have it automatically be converted to an eager policy.
314
+
315
+ This has the same signature as build_tf_policy()."""
316
+
317
+ base = add_mixins(EagerTFPolicy, mixins)
318
+
319
+ if obs_include_prev_action_reward != DEPRECATED_VALUE:
320
+ deprecation_warning(old="obs_include_prev_action_reward", error=True)
321
+
322
+ if extra_action_fetches_fn is not None:
323
+ deprecation_warning(
324
+ old="extra_action_fetches_fn", new="extra_action_out_fn", error=True
325
+ )
326
+
327
+ if gradients_fn is not None:
328
+ deprecation_warning(old="gradients_fn", new="compute_gradients_fn", error=True)
329
+
330
+ class eager_policy_cls(base):
331
def __init__(self, observation_space, action_space, config):
    """Initializes the eager TF policy.

    Builds model, exploration, optimizers, and runs a dummy-batch pass to
    initialize the loss. The various `*_fn` callables referenced here
    (loss_fn, make_model, optimizer_fn, before_init, etc.) are closure
    variables captured from the enclosing `_build_eager_tf_policy()` call.

    Args:
        observation_space: The policy's observation space.
        action_space: The policy's action space.
        config: The full algorithm config dict for this policy.
    """
    # If this class runs as a @ray.remote actor, eager mode may not
    # have been activated yet.
    if not tf1.executing_eagerly():
        tf1.enable_eager_execution()
    self.framework = config.get("framework", "tf2")
    EagerTFPolicy.__init__(self, observation_space, action_space, config)

    # Global timestep should be a tensor.
    self.global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64)
    self.explore = tf.Variable(
        self.config["explore"], trainable=False, dtype=tf.bool
    )

    # Log device and worker index.
    num_gpus = self._get_num_gpus_for_policy()
    if num_gpus > 0:
        gpu_ids = get_gpu_devices()
        logger.info(f"Found {len(gpu_ids)} visible cuda devices.")

    self._is_training = False

    # Only for `config.eager_tracing=True`: A counter to keep track of
    # how many times an eager-traced method (e.g.
    # `self._compute_actions_helper`) has been re-traced by tensorflow.
    # We will raise an error if more than n re-tracings have been
    # detected, since this would considerably slow down execution.
    # The variable below should only get incremented during the
    # tf.function trace operations, never when calling the already
    # traced function after that.
    self._re_trace_counter = 0

    self._loss_initialized = False
    # To ensure backward compatibility:
    # Old way: If `loss` provided here, use as-is (as a function).
    if loss_fn is not None:
        self._loss = loss_fn
    # New way: Convert the overridden `self.loss` into a plain
    # function, so it can be called the same way as `loss` would
    # be, ensuring backward compatibility.
    elif self.loss.__func__.__qualname__ != "Policy.loss":
        self._loss = self.loss.__func__
    # `loss` not provided nor overridden from Policy -> Set to None.
    else:
        self._loss = None

    self.batch_divisibility_req = (
        get_batch_divisibility_req(self)
        if callable(get_batch_divisibility_req)
        else (get_batch_divisibility_req or 1)
    )
    self._max_seq_len = config["model"]["max_seq_len"]

    if validate_spaces:
        validate_spaces(self, observation_space, action_space, config)

    if before_init:
        before_init(self, observation_space, action_space, config)

    self.config = config
    self.dist_class = None
    if action_sampler_fn or action_distribution_fn:
        # Custom sampling/distribution fns require a custom model maker.
        if not make_model:
            raise ValueError(
                "`make_model` is required if `action_sampler_fn` OR "
                "`action_distribution_fn` is given"
            )
    else:
        self.dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"]
        )

    if make_model:
        self.model = make_model(self, observation_space, action_space, config)
    else:
        self.model = ModelCatalog.get_model_v2(
            observation_space,
            action_space,
            logit_dim,
            config["model"],
            framework=self.framework,
        )
    # Lock used for locking some methods on the object-level.
    # This prevents possible race conditions when calling the model
    # first, then its value function (e.g. in a loss function), in
    # between of which another model call is made (e.g. to compute an
    # action).
    self._lock = threading.RLock()

    # Auto-update model's inference view requirements, if recurrent.
    self._update_model_view_requirements_from_init_state()
    # Combine view_requirements for Model and Policy.
    self.view_requirements.update(self.model.view_requirements)

    self.exploration = self._create_exploration()
    self._state_inputs = self.model.get_initial_state()
    self._is_recurrent = len(self._state_inputs) > 0

    if before_loss_init:
        before_loss_init(self, observation_space, action_space, config)

    if optimizer_fn:
        optimizers = optimizer_fn(self, config)
    else:
        optimizers = tf.keras.optimizers.Adam(config["lr"])
    optimizers = force_list(optimizers)
    if self.exploration:
        # Exploration may add/replace optimizers (e.g. for intrinsic rewards).
        optimizers = self.exploration.get_exploration_optimizer(optimizers)

    # The list of local (tf) optimizers (one per loss term).
    self._optimizers: List[LocalOptimizer] = optimizers
    # Backward compatibility: A user's policy may only support a single
    # loss term and optimizer (no lists).
    self._optimizer: LocalOptimizer = optimizers[0] if optimizers else None

    self._initialize_loss_from_dummy_batch(
        auto_remove_unneeded_view_reqs=True,
        stats_fn=stats_fn,
    )
    self._loss_initialized = True

    if after_init:
        after_init(self, observation_space, action_space, config)

    # Got to reset global_timestep again after fake run-throughs.
    self.global_timestep.assign(0)
458
@override(Policy)
def compute_actions_from_input_dict(
    self,
    input_dict: Dict[str, TensorType],
    explore: bool = None,
    timestep: Optional[int] = None,
    episodes=None,
    **kwargs,
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
    """Computes actions for the given input dict (inference path).

    Delegates the actual forward pass to `self._compute_actions_helper`
    and advances the policy's global timestep by the batch size.

    Returns:
        Tuple of (actions, state_outs, extra_fetches), converted to numpy.
    """
    if not self.config.get("eager_tracing") and not tf1.executing_eagerly():
        tf1.enable_eager_execution()

    self._is_training = False

    explore = explore if explore is not None else self.explore
    timestep = timestep if timestep is not None else self.global_timestep
    if isinstance(timestep, tf.Tensor):
        timestep = int(timestep.numpy())

    # Pass lazy (eager) tensor dict to Model as `input_dict`.
    input_dict = self._lazy_tensor_dict(input_dict)
    input_dict.set_training(False)

    # Pack internal state inputs into (separate) list.
    state_batches = [
        input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
    ]
    self._state_in = state_batches
    self._is_recurrent = state_batches != []

    # Call the exploration before_compute_actions hook.
    self.exploration.before_compute_actions(
        timestep=timestep, explore=explore, tf_sess=self.get_session()
    )

    ret = self._compute_actions_helper(
        input_dict,
        state_batches,
        # TODO: Passing episodes into a traced method does not work.
        None if self.config["eager_tracing"] else episodes,
        explore,
        timestep,
    )
    # Update our global timestep by the batch size (taken from the first
    # flattened action component's leading dim).
    self.global_timestep.assign_add(tree.flatten(ret[0])[0].shape.as_list()[0])
    return convert_to_numpy(ret)
504
+
505
@override(Policy)
def compute_actions(
    self,
    obs_batch: Union[List[TensorStructType], TensorStructType],
    state_batches: Optional[List[TensorType]] = None,
    prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
    prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
    info_batch: Optional[Dict[str, list]] = None,
    episodes: Optional[List] = None,
    explore: Optional[bool] = None,
    timestep: Optional[int] = None,
    **kwargs,
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
    """Builds an input dict from the given args and delegates.

    Thin wrapper that packs all per-column args into a SampleBatch and
    forwards the entire call to `self.compute_actions_from_input_dict()`.
    """
    input_dict = SampleBatch(
        {SampleBatch.CUR_OBS: obs_batch},
        _is_training=tf.constant(False),
    )
    # Internal RNN states (if any) go in as separate `state_in_*` columns.
    for i, state in enumerate(state_batches or []):
        input_dict[f"state_in_{i}"] = state
    # Optional columns: only add those actually provided.
    optional_columns = (
        (SampleBatch.PREV_ACTIONS, prev_action_batch),
        (SampleBatch.PREV_REWARDS, prev_reward_batch),
        (SampleBatch.INFOS, info_batch),
    )
    for column, value in optional_columns:
        if value is not None:
            input_dict[column] = value

    return self.compute_actions_from_input_dict(
        input_dict=input_dict,
        explore=explore,
        timestep=timestep,
        episodes=episodes,
        **kwargs,
    )
543
+
544
@with_lock
@override(Policy)
def compute_log_likelihoods(
    self,
    actions,
    obs_batch,
    state_batches=None,
    prev_action_batch=None,
    prev_reward_batch=None,
    actions_normalized=True,
    **kwargs,
):
    """Computes log-likelihoods of the given actions under the policy.

    Requires either the default dist-class path or a user-provided
    `action_distribution_fn`; a pure `action_sampler_fn` (without a
    distribution fn) cannot produce likelihoods.
    """
    if action_sampler_fn and action_distribution_fn is None:
        raise ValueError(
            "Cannot compute log-prob/likelihood w/o an "
            "`action_distribution_fn` and a provided "
            "`action_sampler_fn`!"
        )

    # One timestep per row (no real sequences here).
    seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
    input_batch = SampleBatch(
        {SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch)},
        _is_training=False,
    )
    if prev_action_batch is not None:
        input_batch[SampleBatch.PREV_ACTIONS] = tf.convert_to_tensor(
            prev_action_batch
        )
    if prev_reward_batch is not None:
        input_batch[SampleBatch.PREV_REWARDS] = tf.convert_to_tensor(
            prev_reward_batch
        )

    if self.exploration:
        # Exploration hook before each forward pass.
        self.exploration.before_compute_actions(explore=False)

    # Action dist class and inputs are generated via custom function.
    if action_distribution_fn:
        dist_inputs, dist_class, _ = action_distribution_fn(
            self, self.model, input_batch, explore=False, is_training=False
        )
    # Default log-likelihood calculation.
    else:
        dist_inputs, _ = self.model(input_batch, state_batches, seq_lens)
        dist_class = self.dist_class

    action_dist = dist_class(dist_inputs, self.model)

    # Normalize actions if necessary.
    if not actions_normalized and self.config["normalize_actions"]:
        actions = normalize_action(actions, self.action_space_struct)

    log_likelihoods = action_dist.logp(actions)

    return log_likelihoods
600
+
601
@override(Policy)
def postprocess_trajectory(
    self, sample_batch, other_agent_batches=None, episode=None
):
    """Post-processes a trajectory, delegating to the user's postprocess fn.

    Base-class post-processing always runs first; the builder-provided
    `postprocess_fn` (if any) then sees the already-processed batch.
    """
    assert tf.executing_eagerly()
    # Call super's postprocess_trajectory first.
    processed = EagerTFPolicy.postprocess_trajectory(self, sample_batch)
    if not postprocess_fn:
        return processed
    return postprocess_fn(self, processed, other_agent_batches, episode)
611
+
612
@with_lock
@override(Policy)
def learn_on_batch(self, postprocessed_batch):
    """Performs one gradient update on the given (postprocessed) batch.

    Runs the on_learn_on_batch callback, pads the batch for same-size
    sequences, delegates the actual update to `_learn_on_batch_helper`,
    and returns learner stats (as numpy) incl. grad-update counters.
    """
    # Callback handling.
    learn_stats = {}
    self.callbacks.on_learn_on_batch(
        policy=self, train_batch=postprocessed_batch, result=learn_stats
    )

    pad_batch_to_sequences_of_same_size(
        postprocessed_batch,
        max_seq_len=self._max_seq_len,
        shuffle=False,
        batch_divisibility_req=self.batch_divisibility_req,
        view_requirements=self.view_requirements,
    )

    self._is_training = True
    postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch)
    postprocessed_batch.set_training(True)
    stats = self._learn_on_batch_helper(postprocessed_batch)
    self.num_grad_updates += 1

    stats.update(
        {
            "custom_metrics": learn_stats,
            NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
            NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
            # -1, b/c we have to measure this diff before we do the update
            # above.
            DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                self.num_grad_updates
                - 1
                - (postprocessed_batch.num_grad_updates or 0)
            ),
        }
    )
    return convert_to_numpy(stats)
650
+
651
@override(Policy)
def compute_gradients(
    self, postprocessed_batch: SampleBatch
) -> Tuple[ModelGradients, Dict[str, TensorType]]:
    """Computes (but does not apply) gradients for the given batch.

    Returns:
        Tuple of (gradients, stats dict), both converted to numpy.
    """
    pad_batch_to_sequences_of_same_size(
        postprocessed_batch,
        shuffle=False,
        max_seq_len=self._max_seq_len,
        batch_divisibility_req=self.batch_divisibility_req,
        view_requirements=self.view_requirements,
    )

    self._is_training = True
    # Wraps the batch's data as (lazy) tf tensors in place.
    self._lazy_tensor_dict(postprocessed_batch)
    postprocessed_batch.set_training(True)
    grads_and_vars, grads, stats = self._compute_gradients_helper(
        postprocessed_batch
    )
    return convert_to_numpy((grads, stats))
670
+
671
@override(Policy)
def apply_gradients(self, gradients: ModelGradients) -> None:
    """Applies the given gradients (one per trainable model variable)."""
    # Convert plain gradients to tensors, keeping None entries as-is.
    grad_tensors = [
        None if g is None else tf.convert_to_tensor(g) for g in gradients
    ]
    # Pair each gradient with its corresponding trainable variable.
    grads_and_vars = list(zip(grad_tensors, self.model.trainable_variables()))
    self._apply_gradients_helper(grads_and_vars)
684
+
685
@override(Policy)
def get_weights(self, as_dict=False):
    """Returns the policy's variable values as numpy (list or name-dict)."""
    all_vars = self.variables()
    if not as_dict:
        return [var.numpy() for var in all_vars]
    return {var.name: var.numpy() for var in all_vars}
691
+
692
@override(Policy)
def set_weights(self, weights):
    """Assigns the given weight values to the policy's variables (in order)."""
    all_vars = self.variables()
    # One weight per variable, in matching order.
    assert len(weights) == len(all_vars), (len(weights), len(all_vars))
    for var, value in zip(all_vars, weights):
        var.assign(value)
698
+
699
@override(Policy)
def get_exploration_state(self):
    """Returns the exploration component's state, converted to numpy."""
    return convert_to_numpy(self.exploration.get_state())
702
+
703
@override(Policy)
def is_recurrent(self):
    """Whether this policy's model is recurrent (has internal state inputs)."""
    return self._is_recurrent
706
+
707
@override(Policy)
def num_state_tensors(self):
    """Returns the number of internal RNN state tensors of the model."""
    return len(self._state_inputs)
710
+
711
@override(Policy)
def get_initial_state(self):
    """Returns the model's initial RNN state (empty list if no model yet)."""
    # The model may not exist yet (e.g. during very early construction).
    if not hasattr(self, "model"):
        return []
    return self.model.get_initial_state()
716
+
717
@override(Policy)
def get_state(self) -> PolicyState:
    """Returns this policy's full state dict (weights, optimizer, etc.)."""
    # Legacy Policy state (w/o keras model and w/o PolicySpec).
    state = super().get_state()

    # Convert the global-timestep tf.Variable to a plain numpy scalar.
    state["global_timestep"] = state["global_timestep"].numpy()
    if self._optimizer and len(self._optimizer.variables()) > 0:
        state["_optimizer_variables"] = self._optimizer.variables()
    # Add exploration state.
    if self.exploration:
        # This is not compatible with RLModules, which have a method
        # `forward_exploration` to specify custom exploration behavior.
        state["_exploration_state"] = self.exploration.get_state()
    return state
731
+
732
@override(Policy)
def set_state(self, state: PolicyState) -> None:
    """Restores this policy's state from a `PolicyState` dict.

    Restore order: optimizer variables, then exploration state, then the
    global timestep, and finally the NN weights/connectors (super call).
    """
    # Set optimizer vars first.
    optimizer_vars = state.get("_optimizer_variables", None)
    if optimizer_vars and self._optimizer.variables():
        if not type(self).__name__.endswith("_traced") and log_once(
            "set_state_optimizer_vars_tf_eager_policy_v2"
        ):
            logger.warning(
                "Cannot restore an optimizer's state for tf eager! Keras "
                "is not able to save the v1.x optimizers (from "
                "tf.compat.v1.train) since they aren't compatible with "
                "checkpoints."
            )
        for opt_var, value in zip(self._optimizer.variables(), optimizer_vars):
            opt_var.assign(value)
    # Set exploration's state.
    if hasattr(self, "exploration") and "_exploration_state" in state:
        self.exploration.set_state(state=state["_exploration_state"])

    # Restore global timestep (tf vars).
    self.global_timestep.assign(state["global_timestep"])

    # Then the Policy's (NN) weights and connectors.
    super().set_state(state)
757
+
758
@override(Policy)
def export_model(self, export_dir, onnx: Optional[int] = None) -> None:
    """Exports the Policy's Model to local directory for serving.

    Note: Since the TfModelV2 class that EagerTfPolicy uses is-NOT-a
    tf.keras.Model, we need to assume that there is a `base_model` property
    within this TfModelV2 class that is-a tf.keras.Model. This base model
    will be used here for the export.
    TODO (kourosh): This restriction will be resolved once we move Policy and
    ModelV2 to the new Learner/RLModule APIs.

    Args:
        export_dir: Local writable directory.
        onnx: If given, will export model in ONNX format. The
            value of this parameter set the ONNX OpSet version to use.
    """
    if (
        hasattr(self, "model")
        and hasattr(self.model, "base_model")
        and isinstance(self.model.base_model, tf.keras.Model)
    ):
        # Store model in ONNX format.
        if onnx:
            try:
                import tf2onnx
            except ImportError as e:
                raise RuntimeError(
                    "Converting a TensorFlow model to ONNX requires "
                    "`tf2onnx` to be installed. Install with "
                    "`pip install tf2onnx`."
                ) from e

            model_proto, external_tensor_storage = tf2onnx.convert.from_keras(
                self.model.base_model,
                output_path=os.path.join(export_dir, "model.onnx"),
            )
        # Save the tf.keras.Model (architecture and weights, so it can be
        # retrieved w/o access to the original (custom) Model or Policy code).
        else:
            try:
                self.model.base_model.save(export_dir, save_format="tf")
            except Exception:
                # Best-effort: some models cannot be keras-saved; warn only.
                logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
    else:
        logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
803
+
804
def variables(self):
    """Return the list of all savable variables for this policy."""
    model = self.model
    # Keras models expose `variables` as a property; ModelV2 as a method.
    return model.variables if isinstance(model, tf.keras.Model) else model.variables()
810
+
811
def loss_initialized(self):
    """Whether the loss has been initialized (dummy-batch pass completed)."""
    return self._loss_initialized
813
+
814
@with_lock
def _compute_actions_helper(
    self, input_dict, state_batches, episodes, explore, timestep
):
    """Run one inference forward pass and sample actions.

    Returns a (actions, state_out, extra_fetches) triple. Depending on
    which (closure-provided) hooks are set, actions come either from
    `action_sampler_fn`, from `action_distribution_fn` + exploration, or
    from a plain model forward pass + exploration.
    """
    # Increase the tracing counter to make sure we don't re-trace too
    # often. If eager_tracing=True, this counter should only get
    # incremented during the @tf.function trace operations, never when
    # calling the already traced function after that.
    self._re_trace_counter += 1

    # Calculate RNN sequence lengths.
    # NOTE(review): assumes all-ones seq-lens (one timestep per batch row)
    # during inference — confirm against recurrent callers.
    batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]
    seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches else None

    # Add default and custom fetches.
    extra_fetches = {}

    # Use Exploration object.
    # No variables may be created during inference (model is built already).
    with tf.variable_creator_scope(_disallow_var_creation):
        if action_sampler_fn:
            action_sampler_outputs = action_sampler_fn(
                self,
                self.model,
                input_dict[SampleBatch.CUR_OBS],
                explore=explore,
                timestep=timestep,
                episodes=episodes,
            )
            # 4-tuple: (actions, logp, dist_inputs, state_out); otherwise
            # only (actions, logp) is returned by the user fn.
            if len(action_sampler_outputs) == 4:
                actions, logp, dist_inputs, state_out = action_sampler_outputs
            else:
                dist_inputs = None
                state_out = []
                actions, logp = action_sampler_outputs
        else:
            if action_distribution_fn:
                # Try new action_distribution_fn signature, supporting
                # state_batches and seq_lens.
                try:
                    (
                        dist_inputs,
                        self.dist_class,
                        state_out,
                    ) = action_distribution_fn(
                        self,
                        self.model,
                        input_dict=input_dict,
                        state_batches=state_batches,
                        seq_lens=seq_lens,
                        explore=explore,
                        timestep=timestep,
                        is_training=False,
                    )
                # Trying the old way (to stay backward compatible).
                # TODO: Remove in future.
                except TypeError as e:
                    # Only fall back if the error looks like a signature
                    # mismatch; re-raise anything else.
                    if (
                        "positional argument" in e.args[0]
                        or "unexpected keyword argument" in e.args[0]
                    ):
                        (
                            dist_inputs,
                            self.dist_class,
                            state_out,
                        ) = action_distribution_fn(
                            self,
                            self.model,
                            input_dict[SampleBatch.OBS],
                            explore=explore,
                            timestep=timestep,
                            is_training=False,
                        )
                    else:
                        raise e
            elif isinstance(self.model, tf.keras.Model):
                # Keras path: feed a lazily tf-converted SampleBatch with
                # state tensors merged in under "state_in_<i>" keys.
                input_dict = SampleBatch(input_dict, seq_lens=seq_lens)
                if state_batches and "state_in_0" not in input_dict:
                    for i, s in enumerate(state_batches):
                        input_dict[f"state_in_{i}"] = s
                self._lazy_tensor_dict(input_dict)
                dist_inputs, state_out, extra_fetches = self.model(input_dict)
            else:
                # Default ModelV2 forward pass.
                dist_inputs, state_out = self.model(
                    input_dict, state_batches, seq_lens
                )

            action_dist = self.dist_class(dist_inputs, self.model)

            # Get the exploration action from the forward results.
            actions, logp = self.exploration.get_exploration_action(
                action_distribution=action_dist,
                timestep=timestep,
                explore=explore,
            )

    # Action-logp and action-prob.
    if logp is not None:
        extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
        extra_fetches[SampleBatch.ACTION_LOGP] = logp
    # Action-dist inputs.
    if dist_inputs is not None:
        extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
    # Custom extra fetches.
    if extra_action_out_fn:
        extra_fetches.update(extra_action_out_fn(self))

    return actions, state_out, extra_fetches
921
+
922
# TODO: Figure out, why _ray_trace_ctx=None helps to prevent a crash in
# AlphaStar w/ framework=tf2; eager_tracing=True on the policy learner actors.
# It seems there may be a clash between the traced-by-tf function and the
# traced-by-ray functions (for making the policy class a ray actor).
def _learn_on_batch_helper(self, samples, _ray_trace_ctx=None):
    """Compute and apply gradients for `samples`; return the stats dict."""
    # Increase the tracing counter to make sure we don't re-trace too
    # often. If eager_tracing=True, this counter should only get
    # incremented during the @tf.function trace operations, never when
    # calling the already traced function after that.
    self._re_trace_counter += 1

    # No new variables may be created during an update pass.
    with tf.variable_creator_scope(_disallow_var_creation):
        g_and_v, _, fetches = self._compute_gradients_helper(samples)
        self._apply_gradients_helper(g_and_v)
    return fetches
937
+
938
def _get_is_training_placeholder(self):
    """Return the current training flag, converted to an eager tensor."""
    return tf.convert_to_tensor(self._is_training)
940
+
941
@with_lock
def _compute_gradients_helper(self, samples):
    """Computes and returns grads as eager tensors.

    Returns a (grads_and_vars, grads, stats) triple. With
    `_tf_policy_handles_more_than_one_loss`, the first two are lists
    (one entry per loss/optimizer); otherwise they are flat lists of
    (grad, var) tuples / grad tensors.
    """

    # Increase the tracing counter to make sure we don't re-trace too
    # often. If eager_tracing=True, this counter should only get
    # incremented during the @tf.function trace operations, never when
    # calling the already traced function after that.
    self._re_trace_counter += 1

    # Gather all variables for which to calculate losses.
    if isinstance(self.model, tf.keras.Model):
        variables = self.model.trainable_variables
    else:
        variables = self.model.trainable_variables()

    # Calculate the loss(es) inside a tf GradientTape.
    # Tape must be persistent when a custom compute_gradients_fn may
    # differentiate it more than once.
    with tf.GradientTape(persistent=compute_gradients_fn is not None) as tape:
        losses = self._loss(self, self.model, self.dist_class, samples)
    losses = force_list(losses)

    # User provided a compute_gradients_fn.
    if compute_gradients_fn:
        # Wrap our tape inside a wrapper, such that the resulting
        # object looks like a "classic" tf.optimizer. This way, custom
        # compute_gradients_fn will work on both tf static graph
        # and tf-eager.
        optimizer = _OptimizerWrapper(tape)
        # More than one loss terms/optimizers.
        if self.config["_tf_policy_handles_more_than_one_loss"]:
            grads_and_vars = compute_gradients_fn(
                self, [optimizer] * len(losses), losses
            )
        # Only one loss and one optimizer.
        else:
            grads_and_vars = [compute_gradients_fn(self, optimizer, losses[0])]
    # Default: Compute gradients using the above tape.
    else:
        grads_and_vars = [
            list(zip(tape.gradient(loss, variables), variables))
            for loss in losses
        ]

    # Log (once) which variables are being optimized.
    if log_once("grad_vars"):
        for g_and_v in grads_and_vars:
            for g, v in g_and_v:
                if g is not None:
                    logger.info(f"Optimizing variable {v.name}")

    # `grads_and_vars` is returned a list (len=num optimizers/losses)
    # of lists of (grad, var) tuples.
    if self.config["_tf_policy_handles_more_than_one_loss"]:
        grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars]
    # `grads_and_vars` is returned as a list of (grad, var) tuples.
    else:
        grads_and_vars = grads_and_vars[0]
        grads = [g for g, _ in grads_and_vars]

    stats = self._stats(self, samples, grads)
    return grads_and_vars, grads, stats
1001
+
1002
def _apply_gradients_helper(self, grads_and_vars):
    """Apply the given (grad, var) pairs via the local optimizer(s)."""
    # Increase the tracing counter to make sure we don't re-trace too
    # often. If eager_tracing=True, this counter should only get
    # incremented during the @tf.function trace operations, never when
    # calling the already traced function after that.
    self._re_trace_counter += 1

    multi_loss = self.config["_tf_policy_handles_more_than_one_loss"]

    # User-supplied application hook takes precedence.
    if apply_gradients_fn:
        if multi_loss:
            apply_gradients_fn(self, self._optimizers, grads_and_vars)
        else:
            apply_gradients_fn(self, self._optimizer, grads_and_vars)
        return

    # Default: apply per optimizer, dropping None-gradients.
    if multi_loss:
        for i, opt in enumerate(self._optimizers):
            opt.apply_gradients(
                [gv for gv in grads_and_vars[i] if gv[0] is not None]
            )
    else:
        self._optimizer.apply_gradients(
            [gv for gv in grads_and_vars if gv[0] is not None]
        )
1024
+
1025
def _stats(self, outputs, samples, grads):
    """Assemble the fetch dict of learner statistics for one update."""
    learner_stats = dict(stats_fn(outputs, samples)) if stats_fn else {}
    fetches = {LEARNER_STATS_KEY: learner_stats}

    # Merge in custom learn-fetches and gradient statistics, if provided.
    if extra_learn_fetches_fn:
        fetches.update(dict(extra_learn_fetches_fn(self)))
    if grad_stats_fn:
        fetches.update(dict(grad_stats_fn(self, samples, grads)))
    return fetches
1037
+
1038
def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch):
    """Wrap a batch so its values are lazily converted to tf tensors."""
    # TODO: (sven): Keep for a while to ensure backward compatibility.
    batch = (
        postprocessed_batch
        if isinstance(postprocessed_batch, SampleBatch)
        else SampleBatch(postprocessed_batch)
    )
    batch.set_get_interceptor(_convert_to_tf)
    return batch
1044
+
1045
@classmethod
def with_tracing(cls):
    """Return a tf.function-traced wrapper class of this policy class."""
    return _traced_eager_policy(cls)
1048
+
1049
+ eager_policy_cls.__name__ = name + "_eager"
1050
+ eager_policy_cls.__qualname__ = name + "_eager"
1051
+ return eager_policy_cls
.venv/lib/python3.11/site-packages/ray/rllib/policy/eager_tf_policy_v2.py ADDED
@@ -0,0 +1,966 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Eager mode TF policy built using build_tf_policy().
2
+
3
+ It supports both traced and non-traced eager execution modes.
4
+ """
5
+
6
+ import logging
7
+ import os
8
+ import threading
9
+ from typing import Dict, List, Optional, Tuple, Type, Union
10
+
11
+ import gymnasium as gym
12
+ import tree # pip install dm_tree
13
+
14
+ from ray.rllib.utils.numpy import convert_to_numpy
15
+ from ray.rllib.models.catalog import ModelCatalog
16
+ from ray.rllib.models.modelv2 import ModelV2
17
+ from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
18
+ from ray.rllib.policy.eager_tf_policy import (
19
+ _convert_to_tf,
20
+ _disallow_var_creation,
21
+ _OptimizerWrapper,
22
+ _traced_eager_policy,
23
+ )
24
+ from ray.rllib.policy.policy import Policy, PolicyState
25
+ from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
26
+ from ray.rllib.policy.sample_batch import SampleBatch
27
+ from ray.rllib.utils import force_list
28
+ from ray.rllib.utils.annotations import (
29
+ is_overridden,
30
+ OldAPIStack,
31
+ OverrideToImplementCustomLogic,
32
+ OverrideToImplementCustomLogic_CallToSuperRecommended,
33
+ override,
34
+ )
35
+ from ray.rllib.utils.error import ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL
36
+ from ray.rllib.utils.framework import try_import_tf
37
+ from ray.rllib.utils.metrics import (
38
+ DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
39
+ NUM_AGENT_STEPS_TRAINED,
40
+ NUM_GRAD_UPDATES_LIFETIME,
41
+ )
42
+ from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
43
+ from ray.rllib.utils.spaces.space_utils import normalize_action
44
+ from ray.rllib.utils.tf_utils import get_gpu_devices
45
+ from ray.rllib.utils.threading import with_lock
46
+ from ray.rllib.utils.typing import (
47
+ AlgorithmConfigDict,
48
+ LocalOptimizer,
49
+ ModelGradients,
50
+ TensorType,
51
+ )
52
+ from ray.util.debug import log_once
53
+
54
+ tf1, tf, tfv = try_import_tf()
55
+ logger = logging.getLogger(__name__)
56
+
57
+
58
+ @OldAPIStack
59
+ class EagerTFPolicyV2(Policy):
60
+ """A TF-eager / TF2 based tensorflow policy.
61
+
62
+ This class is intended to be used and extended by sub-classing.
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ observation_space: gym.spaces.Space,
68
+ action_space: gym.spaces.Space,
69
+ config: AlgorithmConfigDict,
70
+ **kwargs,
71
+ ):
72
+ self.framework = config.get("framework", "tf2")
73
+
74
+ # Log device.
75
+ logger.info(
76
+ "Creating TF-eager policy running on {}.".format(
77
+ "GPU" if get_gpu_devices() else "CPU"
78
+ )
79
+ )
80
+
81
+ Policy.__init__(self, observation_space, action_space, config)
82
+
83
+ self._is_training = False
84
+ # Global timestep should be a tensor.
85
+ self.global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64)
86
+ self.explore = tf.Variable(
87
+ self.config["explore"], trainable=False, dtype=tf.bool
88
+ )
89
+
90
+ # Log device and worker index.
91
+ num_gpus = self._get_num_gpus_for_policy()
92
+ if num_gpus > 0:
93
+ gpu_ids = get_gpu_devices()
94
+ logger.info(f"Found {len(gpu_ids)} visible cuda devices.")
95
+
96
+ self._is_training = False
97
+
98
+ self._loss_initialized = False
99
+ # Backward compatibility workaround so Policy will call self.loss() directly.
100
+ # TODO(jungong): clean up after all policies are migrated to new sub-class
101
+ # implementation.
102
+ self._loss = None
103
+
104
+ self.batch_divisibility_req = self.get_batch_divisibility_req()
105
+ self._max_seq_len = self.config["model"]["max_seq_len"]
106
+
107
+ self.validate_spaces(observation_space, action_space, self.config)
108
+
109
+ # If using default make_model(), dist_class will get updated when
110
+ # the model is created next.
111
+ self.dist_class = self._init_dist_class()
112
+ self.model = self.make_model()
113
+
114
+ self._init_view_requirements()
115
+
116
+ self.exploration = self._create_exploration()
117
+ self._state_inputs = self.model.get_initial_state()
118
+ self._is_recurrent = len(self._state_inputs) > 0
119
+
120
+ # Got to reset global_timestep again after fake run-throughs.
121
+ self.global_timestep.assign(0)
122
+
123
+ # Lock used for locking some methods on the object-level.
124
+ # This prevents possible race conditions when calling the model
125
+ # first, then its value function (e.g. in a loss function), in
126
+ # between of which another model call is made (e.g. to compute an
127
+ # action).
128
+ self._lock = threading.RLock()
129
+
130
+ # Only for `config.eager_tracing=True`: A counter to keep track of
131
+ # how many times an eager-traced method (e.g.
132
+ # `self._compute_actions_helper`) has been re-traced by tensorflow.
133
+ # We will raise an error if more than n re-tracings have been
134
+ # detected, since this would considerably slow down execution.
135
+ # The variable below should only get incremented during the
136
+ # tf.function trace operations, never when calling the already
137
+ # traced function after that.
138
+ self._re_trace_counter = 0
139
+
140
+ @staticmethod
141
+ def enable_eager_execution_if_necessary():
142
+ # If this class runs as a @ray.remote actor, eager mode may not
143
+ # have been activated yet.
144
+ if tf1 and not tf1.executing_eagerly():
145
+ tf1.enable_eager_execution()
146
+
147
+ @OverrideToImplementCustomLogic
148
+ def validate_spaces(
149
+ self,
150
+ obs_space: gym.spaces.Space,
151
+ action_space: gym.spaces.Space,
152
+ config: AlgorithmConfigDict,
153
+ ):
154
+ return {}
155
+
156
+ @OverrideToImplementCustomLogic
157
+ @override(Policy)
158
+ def loss(
159
+ self,
160
+ model: Union[ModelV2, "tf.keras.Model"],
161
+ dist_class: Type[TFActionDistribution],
162
+ train_batch: SampleBatch,
163
+ ) -> Union[TensorType, List[TensorType]]:
164
+ """Compute loss for this policy using model, dist_class and a train_batch.
165
+
166
+ Args:
167
+ model: The Model to calculate the loss for.
168
+ dist_class: The action distr. class.
169
+ train_batch: The training data.
170
+
171
+ Returns:
172
+ A single loss tensor or a list of loss tensors.
173
+ """
174
+ raise NotImplementedError
175
+
176
+ @OverrideToImplementCustomLogic
177
+ def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
178
+ """Stats function. Returns a dict of statistics.
179
+
180
+ Args:
181
+ train_batch: The SampleBatch (already) used for training.
182
+
183
+ Returns:
184
+ The stats dict.
185
+ """
186
+ return {}
187
+
188
+ @OverrideToImplementCustomLogic
189
+ def grad_stats_fn(
190
+ self, train_batch: SampleBatch, grads: ModelGradients
191
+ ) -> Dict[str, TensorType]:
192
+ """Gradient stats function. Returns a dict of statistics.
193
+
194
+ Args:
195
+ train_batch: The SampleBatch (already) used for training.
196
+
197
+ Returns:
198
+ The stats dict.
199
+ """
200
+ return {}
201
+
202
+ @OverrideToImplementCustomLogic
203
+ def make_model(self) -> ModelV2:
204
+ """Build underlying model for this Policy.
205
+
206
+ Returns:
207
+ The Model for the Policy to use.
208
+ """
209
+ # Default ModelV2 model.
210
+ _, logit_dim = ModelCatalog.get_action_dist(
211
+ self.action_space, self.config["model"]
212
+ )
213
+ return ModelCatalog.get_model_v2(
214
+ self.observation_space,
215
+ self.action_space,
216
+ logit_dim,
217
+ self.config["model"],
218
+ framework=self.framework,
219
+ )
220
+
221
+ @OverrideToImplementCustomLogic
222
+ def compute_gradients_fn(
223
+ self, policy: Policy, optimizer: LocalOptimizer, loss: TensorType
224
+ ) -> ModelGradients:
225
+ """Gradients computing function (from loss tensor, using local optimizer).
226
+
227
+ Args:
228
+ policy: The Policy object that generated the loss tensor and
229
+ that holds the given local optimizer.
230
+ optimizer: The tf (local) optimizer object to
231
+ calculate the gradients with.
232
+ loss: The loss tensor for which gradients should be
233
+ calculated.
234
+
235
+ Returns:
236
+ ModelGradients: List of the possibly clipped gradients- and variable
237
+ tuples.
238
+ """
239
+ return None
240
+
241
+ @OverrideToImplementCustomLogic
242
+ def apply_gradients_fn(
243
+ self,
244
+ optimizer: "tf.keras.optimizers.Optimizer",
245
+ grads: ModelGradients,
246
+ ) -> "tf.Operation":
247
+ """Gradients computing function (from loss tensor, using local optimizer).
248
+
249
+ Args:
250
+ optimizer: The tf (local) optimizer object to
251
+ calculate the gradients with.
252
+ grads: The gradient tensor to be applied.
253
+
254
+ Returns:
255
+ "tf.Operation": TF operation that applies supplied gradients.
256
+ """
257
+ return None
258
+
259
+ @OverrideToImplementCustomLogic
260
+ def action_sampler_fn(
261
+ self,
262
+ model: ModelV2,
263
+ *,
264
+ obs_batch: TensorType,
265
+ state_batches: TensorType,
266
+ **kwargs,
267
+ ) -> Tuple[TensorType, TensorType, TensorType, List[TensorType]]:
268
+ """Custom function for sampling new actions given policy.
269
+
270
+ Args:
271
+ model: Underlying model.
272
+ obs_batch: Observation tensor batch.
273
+ state_batches: Action sampling state batch.
274
+
275
+ Returns:
276
+ Sampled action
277
+ Log-likelihood
278
+ Action distribution inputs
279
+ Updated state
280
+ """
281
+ return None, None, None, None
282
+
283
+ @OverrideToImplementCustomLogic
284
+ def action_distribution_fn(
285
+ self,
286
+ model: ModelV2,
287
+ *,
288
+ obs_batch: TensorType,
289
+ state_batches: TensorType,
290
+ **kwargs,
291
+ ) -> Tuple[TensorType, type, List[TensorType]]:
292
+ """Action distribution function for this Policy.
293
+
294
+ Args:
295
+ model: Underlying model.
296
+ obs_batch: Observation tensor batch.
297
+ state_batches: Action sampling state batch.
298
+
299
+ Returns:
300
+ Distribution input.
301
+ ActionDistribution class.
302
+ State outs.
303
+ """
304
+ return None, None, None
305
+
306
+ @OverrideToImplementCustomLogic
307
+ def get_batch_divisibility_req(self) -> int:
308
+ """Get batch divisibility request.
309
+
310
+ Returns:
311
+ Size N. A sample batch must be of size K*N.
312
+ """
313
+ # By default, any sized batch is ok, so simply return 1.
314
+ return 1
315
+
316
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
317
+ def extra_action_out_fn(self) -> Dict[str, TensorType]:
318
+ """Extra values to fetch and return from compute_actions().
319
+
320
+ Returns:
321
+ Dict[str, TensorType]: An extra fetch-dict to be passed to and
322
+ returned from the compute_actions() call.
323
+ """
324
+ return {}
325
+
326
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
327
+ def extra_learn_fetches_fn(self) -> Dict[str, TensorType]:
328
+ """Extra stats to be reported after gradient computation.
329
+
330
+ Returns:
331
+ Dict[str, TensorType]: An extra fetch-dict.
332
+ """
333
+ return {}
334
+
335
+ @override(Policy)
336
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
337
+ def postprocess_trajectory(
338
+ self,
339
+ sample_batch: SampleBatch,
340
+ other_agent_batches: Optional[SampleBatch] = None,
341
+ episode=None,
342
+ ):
343
+ """Post process trajectory in the format of a SampleBatch.
344
+
345
+ Args:
346
+ sample_batch: sample_batch: batch of experiences for the policy,
347
+ which will contain at most one episode trajectory.
348
+ other_agent_batches: In a multi-agent env, this contains a
349
+ mapping of agent ids to (policy, agent_batch) tuples
350
+ containing the policy and experiences of the other agents.
351
+ episode: An optional multi-agent episode object to provide
352
+ access to all of the internal episode state, which may
353
+ be useful for model-based or multi-agent algorithms.
354
+
355
+ Returns:
356
+ The postprocessed sample batch.
357
+ """
358
+ assert tf.executing_eagerly()
359
+ return Policy.postprocess_trajectory(self, sample_batch)
360
+
361
+ @OverrideToImplementCustomLogic
362
+ def optimizer(
363
+ self,
364
+ ) -> Union["tf.keras.optimizers.Optimizer", List["tf.keras.optimizers.Optimizer"]]:
365
+ """TF optimizer to use for policy optimization.
366
+
367
+ Returns:
368
+ A local optimizer or a list of local optimizers to use for this
369
+ Policy's Model.
370
+ """
371
+ return tf.keras.optimizers.Adam(self.config["lr"])
372
+
373
+ def _init_dist_class(self):
374
+ if is_overridden(self.action_sampler_fn) or is_overridden(
375
+ self.action_distribution_fn
376
+ ):
377
+ if not is_overridden(self.make_model):
378
+ raise ValueError(
379
+ "`make_model` is required if `action_sampler_fn` OR "
380
+ "`action_distribution_fn` is given"
381
+ )
382
+ return None
383
+ else:
384
+ dist_class, _ = ModelCatalog.get_action_dist(
385
+ self.action_space, self.config["model"]
386
+ )
387
+ return dist_class
388
+
389
+ def _init_view_requirements(self):
390
+ # Auto-update model's inference view requirements, if recurrent.
391
+ self._update_model_view_requirements_from_init_state()
392
+ # Combine view_requirements for Model and Policy.
393
+ self.view_requirements.update(self.model.view_requirements)
394
+
395
+ # Disable env-info placeholder.
396
+ if SampleBatch.INFOS in self.view_requirements:
397
+ self.view_requirements[SampleBatch.INFOS].used_for_training = False
398
+
399
+ def maybe_initialize_optimizer_and_loss(self):
400
+ optimizers = force_list(self.optimizer())
401
+ if self.exploration:
402
+ # Policies with RLModules don't have an exploration object.
403
+ optimizers = self.exploration.get_exploration_optimizer(optimizers)
404
+
405
+ # The list of local (tf) optimizers (one per loss term).
406
+ self._optimizers: List[LocalOptimizer] = optimizers
407
+ # Backward compatibility: A user's policy may only support a single
408
+ # loss term and optimizer (no lists).
409
+ self._optimizer: LocalOptimizer = optimizers[0] if optimizers else None
410
+
411
+ self._initialize_loss_from_dummy_batch(
412
+ auto_remove_unneeded_view_reqs=True,
413
+ )
414
+ self._loss_initialized = True
415
+
416
+ @override(Policy)
417
+ def compute_actions_from_input_dict(
418
+ self,
419
+ input_dict: Dict[str, TensorType],
420
+ explore: bool = None,
421
+ timestep: Optional[int] = None,
422
+ episodes=None,
423
+ **kwargs,
424
+ ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
425
+ self._is_training = False
426
+
427
+ explore = explore if explore is not None else self.explore
428
+ timestep = timestep if timestep is not None else self.global_timestep
429
+ if isinstance(timestep, tf.Tensor):
430
+ timestep = int(timestep.numpy())
431
+
432
+ # Pass lazy (eager) tensor dict to Model as `input_dict`.
433
+ input_dict = self._lazy_tensor_dict(input_dict)
434
+ input_dict.set_training(False)
435
+
436
+ # Pack internal state inputs into (separate) list.
437
+ state_batches = [
438
+ input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
439
+ ]
440
+ self._state_in = state_batches
441
+ self._is_recurrent = len(tree.flatten(self._state_in)) > 0
442
+
443
+ # Call the exploration before_compute_actions hook.
444
+ if self.exploration:
445
+ # Policies with RLModules don't have an exploration object.
446
+ self.exploration.before_compute_actions(
447
+ timestep=timestep, explore=explore, tf_sess=self.get_session()
448
+ )
449
+
450
+ ret = self._compute_actions_helper(
451
+ input_dict,
452
+ state_batches,
453
+ # TODO: Passing episodes into a traced method does not work.
454
+ None if self.config["eager_tracing"] else episodes,
455
+ explore,
456
+ timestep,
457
+ )
458
+ # Update our global timestep by the batch size.
459
+ self.global_timestep.assign_add(tree.flatten(ret[0])[0].shape.as_list()[0])
460
+ return convert_to_numpy(ret)
461
+
462
# TODO(jungong) : deprecate this API and make compute_actions_from_input_dict the
# only canonical entry point for inference.
@override(Policy)
def compute_actions(
    self,
    obs_batch,
    state_batches=None,
    prev_action_batch=None,
    prev_reward_batch=None,
    info_batch=None,
    episodes=None,
    explore=None,
    timestep=None,
    **kwargs,
):
    """Compute actions for a batch of observations (legacy API).

    Builds an input dict from the individual args and forwards the entire
    call to `self.compute_actions_from_input_dict()`.
    """
    # Create input dict to simply pass the entire call to
    # self.compute_actions_from_input_dict().
    input_dict = SampleBatch(
        {
            SampleBatch.CUR_OBS: obs_batch,
        },
        _is_training=tf.constant(False),
    )
    if state_batches is not None:
        # Bugfix: original code did `for s in enumerate(state_batches)`
        # (binding the (index, state) tuple to `s`) and wrote to the
        # literal key "state_in_{i}" (missing f-string prefix), so every
        # state tensor overwrote one wrong key. Unpack the index and use
        # an f-string so each state lands under "state_in_<i>".
        for i, s in enumerate(state_batches):
            input_dict[f"state_in_{i}"] = s
    if prev_action_batch is not None:
        input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
    if prev_reward_batch is not None:
        input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
    if info_batch is not None:
        input_dict[SampleBatch.INFOS] = info_batch

    return self.compute_actions_from_input_dict(
        input_dict=input_dict,
        explore=explore,
        timestep=timestep,
        episodes=episodes,
        **kwargs,
    )
502
+
503
+ @with_lock
504
+ @override(Policy)
505
+ def compute_log_likelihoods(
506
+ self,
507
+ actions: Union[List[TensorType], TensorType],
508
+ obs_batch: Union[List[TensorType], TensorType],
509
+ state_batches: Optional[List[TensorType]] = None,
510
+ prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
511
+ prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
512
+ actions_normalized: bool = True,
513
+ in_training: bool = True,
514
+ ) -> TensorType:
515
+ if is_overridden(self.action_sampler_fn) and not is_overridden(
516
+ self.action_distribution_fn
517
+ ):
518
+ raise ValueError(
519
+ "Cannot compute log-prob/likelihood w/o an "
520
+ "`action_distribution_fn` and a provided "
521
+ "`action_sampler_fn`!"
522
+ )
523
+
524
+ seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
525
+ input_batch = SampleBatch(
526
+ {
527
+ SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
528
+ SampleBatch.ACTIONS: actions,
529
+ },
530
+ _is_training=False,
531
+ )
532
+ if prev_action_batch is not None:
533
+ input_batch[SampleBatch.PREV_ACTIONS] = tf.convert_to_tensor(
534
+ prev_action_batch
535
+ )
536
+ if prev_reward_batch is not None:
537
+ input_batch[SampleBatch.PREV_REWARDS] = tf.convert_to_tensor(
538
+ prev_reward_batch
539
+ )
540
+
541
+ # Exploration hook before each forward pass.
542
+ if self.exploration:
543
+ # Policies with RLModules don't have an exploration object.
544
+ self.exploration.before_compute_actions(explore=False)
545
+
546
+ # Action dist class and inputs are generated via custom function.
547
+ if is_overridden(self.action_distribution_fn):
548
+ dist_inputs, self.dist_class, _ = self.action_distribution_fn(
549
+ self, self.model, input_batch, explore=False, is_training=False
550
+ )
551
+ action_dist = self.dist_class(dist_inputs, self.model)
552
+ # Default log-likelihood calculation.
553
+ else:
554
+ dist_inputs, _ = self.model(input_batch, state_batches, seq_lens)
555
+ action_dist = self.dist_class(dist_inputs, self.model)
556
+
557
+ # Normalize actions if necessary.
558
+ if not actions_normalized and self.config["normalize_actions"]:
559
+ actions = normalize_action(actions, self.action_space_struct)
560
+
561
+ log_likelihoods = action_dist.logp(actions)
562
+
563
+ return log_likelihoods
564
+
565
+ @with_lock
566
+ @override(Policy)
567
+ def learn_on_batch(self, postprocessed_batch):
568
+ # Callback handling.
569
+ learn_stats = {}
570
+ self.callbacks.on_learn_on_batch(
571
+ policy=self, train_batch=postprocessed_batch, result=learn_stats
572
+ )
573
+
574
+ pad_batch_to_sequences_of_same_size(
575
+ postprocessed_batch,
576
+ max_seq_len=self._max_seq_len,
577
+ shuffle=False,
578
+ batch_divisibility_req=self.batch_divisibility_req,
579
+ view_requirements=self.view_requirements,
580
+ )
581
+
582
+ self._is_training = True
583
+ postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch)
584
+ postprocessed_batch.set_training(True)
585
+ stats = self._learn_on_batch_helper(postprocessed_batch)
586
+ self.num_grad_updates += 1
587
+
588
+ stats.update(
589
+ {
590
+ "custom_metrics": learn_stats,
591
+ NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
592
+ NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
593
+ # -1, b/c we have to measure this diff before we do the update above.
594
+ DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
595
+ self.num_grad_updates
596
+ - 1
597
+ - (postprocessed_batch.num_grad_updates or 0)
598
+ ),
599
+ }
600
+ )
601
+
602
+ return convert_to_numpy(stats)
603
+
604
+ @override(Policy)
605
+ def compute_gradients(
606
+ self, postprocessed_batch: SampleBatch
607
+ ) -> Tuple[ModelGradients, Dict[str, TensorType]]:
608
+
609
+ pad_batch_to_sequences_of_same_size(
610
+ postprocessed_batch,
611
+ shuffle=False,
612
+ max_seq_len=self._max_seq_len,
613
+ batch_divisibility_req=self.batch_divisibility_req,
614
+ view_requirements=self.view_requirements,
615
+ )
616
+
617
+ self._is_training = True
618
+ self._lazy_tensor_dict(postprocessed_batch)
619
+ postprocessed_batch.set_training(True)
620
+ grads_and_vars, grads, stats = self._compute_gradients_helper(
621
+ postprocessed_batch
622
+ )
623
+ return convert_to_numpy((grads, stats))
624
+
625
+ @override(Policy)
626
+ def apply_gradients(self, gradients: ModelGradients) -> None:
627
+ self._apply_gradients_helper(
628
+ list(
629
+ zip(
630
+ [
631
+ (tf.convert_to_tensor(g) if g is not None else None)
632
+ for g in gradients
633
+ ],
634
+ self.model.trainable_variables(),
635
+ )
636
+ )
637
+ )
638
+
639
@override(Policy)
def get_weights(self, as_dict=False):
    """Return current variable values as a list (or name->array dict)."""
    all_vars = self.variables()
    if as_dict:
        return {v.name: v.numpy() for v in all_vars}
    return [v.numpy() for v in all_vars]
645
+
646
@override(Policy)
def set_weights(self, weights):
    """Assign the given list of weight arrays to this policy's variables."""
    all_vars = self.variables()
    # Weight list must match the variable list 1:1.
    assert len(weights) == len(all_vars), (len(weights), len(all_vars))
    for var, value in zip(all_vars, weights):
        var.assign(value)
652
+
653
+ @override(Policy)
654
+ def get_exploration_state(self):
655
+ return convert_to_numpy(self.exploration.get_state())
656
+
657
+ @override(Policy)
658
+ def is_recurrent(self):
659
+ return self._is_recurrent
660
+
661
+ @override(Policy)
662
+ def num_state_tensors(self):
663
+ return len(self._state_inputs)
664
+
665
+ @override(Policy)
666
+ def get_initial_state(self):
667
+ if hasattr(self, "model"):
668
+ return self.model.get_initial_state()
669
+ return []
670
+
671
    @override(Policy)
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def get_state(self) -> PolicyState:
        """Returns this policy's full state dict.

        Extends the superclass state with the numpy'ized global timestep,
        the optimizer's variables (if any), and the exploration state.
        """
        # Legacy Policy state (w/o keras model and w/o PolicySpec).
        state = super().get_state()

        # `global_timestep` is a tf variable here -> store its numpy value.
        state["global_timestep"] = state["global_timestep"].numpy()
        # In the new Learner API stack, the optimizers live in the learner.
        state["_optimizer_variables"] = []
        if self._optimizer and len(self._optimizer.variables()) > 0:
            state["_optimizer_variables"] = self._optimizer.variables()

        # Add exploration state.
        if self.exploration:
            # This is not compatible with RLModules, which have a method
            # `forward_exploration` to specify custom exploration behavior.
            state["_exploration_state"] = self.exploration.get_state()

        return state
690
+
691
    @override(Policy)
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def set_state(self, state: PolicyState) -> None:
        """Restores this policy from a state dict (see `get_state`)."""
        # Set optimizer vars.
        optimizer_vars = state.get("_optimizer_variables", None)
        if optimizer_vars and self._optimizer.variables():
            # Warn (once) that v1-style optimizer state cannot be reliably
            # restored in tf-eager; traced subclasses skip the warning.
            if not type(self).__name__.endswith("_traced") and log_once(
                "set_state_optimizer_vars_tf_eager_policy_v2"
            ):
                logger.warning(
                    "Cannot restore an optimizer's state for tf eager! Keras "
                    "is not able to save the v1.x optimizers (from "
                    "tf.compat.v1.train) since they aren't compatible with "
                    "checkpoints."
                )
            for opt_var, value in zip(self._optimizer.variables(), optimizer_vars):
                opt_var.assign(value)
        # Set exploration's state.
        if hasattr(self, "exploration") and "_exploration_state" in state:
            self.exploration.set_state(state=state["_exploration_state"])

        # Restore global timestep (tf vars).
        self.global_timestep.assign(state["global_timestep"])

        # Then the Policy's (NN) weights and connectors.
        super().set_state(state)
717
+
718
    @override(Policy)
    def export_model(self, export_dir, onnx: Optional[int] = None) -> None:
        """Exports this policy's keras base model to `export_dir`.

        Args:
            export_dir: Target directory for the exported model.
            onnx: If set (truthy), export in ONNX format via tf2onnx instead
                of the tf SavedModel format.

        Raises:
            RuntimeError: If `onnx` export is requested but tf2onnx is not
                installed.
        """
        if onnx:
            try:
                import tf2onnx
            except ImportError as e:
                raise RuntimeError(
                    "Converting a TensorFlow model to ONNX requires "
                    "`tf2onnx` to be installed. Install with "
                    "`pip install tf2onnx`."
                ) from e

            model_proto, external_tensor_storage = tf2onnx.convert.from_keras(
                self.model.base_model,
                output_path=os.path.join(export_dir, "model.onnx"),
            )
        # Save the tf.keras.Model (architecture and weights, so it can be retrieved
        # w/o access to the original (custom) Model or Policy code).
        elif (
            hasattr(self, "model")
            and hasattr(self.model, "base_model")
            and isinstance(self.model.base_model, tf.keras.Model)
        ):
            try:
                self.model.base_model.save(export_dir, save_format="tf")
            except Exception:
                # Saving may fail for subclassed/custom models -> warn only.
                logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
        else:
            logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
747
+
748
+ def variables(self):
749
+ """Return the list of all savable variables for this policy."""
750
+ if isinstance(self.model, tf.keras.Model):
751
+ return self.model.variables
752
+ else:
753
+ return self.model.variables()
754
+
755
    def loss_initialized(self):
        """Whether this policy's loss has been initialized already."""
        return self._loss_initialized
757
+
758
    @with_lock
    def _compute_actions_helper(
        self,
        input_dict,
        state_batches,
        episodes,
        explore,
        timestep,
        _ray_trace_ctx=None,
    ):
        """Computes actions, state-outs, and extra fetches for `input_dict`.

        Returns:
            Tuple of (actions, state_out, extra_fetches).
        """
        # Increase the tracing counter to make sure we don't re-trace too
        # often. If eager_tracing=True, this counter should only get
        # incremented during the @tf.function trace operations, never when
        # calling the already traced function after that.
        self._re_trace_counter += 1

        # Calculate RNN sequence lengths.
        if SampleBatch.SEQ_LENS in input_dict:
            seq_lens = input_dict[SampleBatch.SEQ_LENS]
        else:
            # No explicit seq-lens: derive batch size from the (possibly
            # nested) obs; use all-ones seq-lens only if state inputs exist.
            batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]
            seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches else None

        # Add default and custom fetches.
        extra_fetches = {}

        # Disallow any variable creation inside this (possibly traced) call.
        with tf.variable_creator_scope(_disallow_var_creation):

            if is_overridden(self.action_sampler_fn):
                # Custom sampler: yields actions/logp/dist-inputs/state directly.
                actions, logp, dist_inputs, state_out = self.action_sampler_fn(
                    self.model,
                    input_dict[SampleBatch.OBS],
                    explore=explore,
                    timestep=timestep,
                    episodes=episodes,
                )
            else:
                # Try `action_distribution_fn`.
                if is_overridden(self.action_distribution_fn):
                    (
                        dist_inputs,
                        self.dist_class,
                        state_out,
                    ) = self.action_distribution_fn(
                        self.model,
                        obs_batch=input_dict[SampleBatch.OBS],
                        state_batches=state_batches,
                        seq_lens=seq_lens,
                        explore=explore,
                        timestep=timestep,
                        is_training=False,
                    )
                elif isinstance(self.model, tf.keras.Model):
                    # Keras models consume the full input dict (incl. states
                    # as `state_in_{i}` keys).
                    if state_batches and "state_in_0" not in input_dict:
                        for i, s in enumerate(state_batches):
                            input_dict[f"state_in_{i}"] = s
                    self._lazy_tensor_dict(input_dict)
                    dist_inputs, state_out, extra_fetches = self.model(input_dict)
                else:
                    # Default ModelV2 call signature.
                    dist_inputs, state_out = self.model(
                        input_dict, state_batches, seq_lens
                    )

                action_dist = self.dist_class(dist_inputs, self.model)

                # Get the exploration action from the forward results.
                actions, logp = self.exploration.get_exploration_action(
                    action_distribution=action_dist,
                    timestep=timestep,
                    explore=explore,
                )

        # Action-logp and action-prob.
        if logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
            extra_fetches[SampleBatch.ACTION_LOGP] = logp
        # Action-dist inputs.
        if dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
        # Custom extra fetches.
        extra_fetches.update(self.extra_action_out_fn())

        return actions, state_out, extra_fetches
841
+
842
    # TODO: Figure out, why _ray_trace_ctx=None helps to prevent a crash in
    # AlphaStar w/ framework=tf2; eager_tracing=True on the policy learner actors.
    # It seems there may be a clash between the traced-by-tf function and the
    # traced-by-ray functions (for making the policy class a ray actor).
    def _learn_on_batch_helper(self, samples, _ray_trace_ctx=None):
        """Computes and applies gradients on `samples`; returns the stats dict."""
        # Increase the tracing counter to make sure we don't re-trace too
        # often. If eager_tracing=True, this counter should only get
        # incremented during the @tf.function trace operations, never when
        # calling the already traced function after that.
        self._re_trace_counter += 1

        # No variable creation allowed while computing/applying gradients.
        with tf.variable_creator_scope(_disallow_var_creation):
            grads_and_vars, _, stats = self._compute_gradients_helper(samples)
            self._apply_gradients_helper(grads_and_vars)
        return stats
857
+
858
    def _get_is_training_placeholder(self):
        """Returns the current `_is_training` flag as a tf tensor."""
        return tf.convert_to_tensor(self._is_training)
860
+
861
    @with_lock
    def _compute_gradients_helper(self, samples):
        """Computes and returns grads as eager tensors.

        Returns:
            Tuple of (grads_and_vars, grads, stats). With
            `_tf_policy_handles_more_than_one_loss=True`, the first two are
            lists with one entry per loss/optimizer.
        """

        # Increase the tracing counter to make sure we don't re-trace too
        # often. If eager_tracing=True, this counter should only get
        # incremented during the @tf.function trace operations, never when
        # calling the already traced function after that.
        self._re_trace_counter += 1

        # Gather all variables for which to calculate losses.
        if isinstance(self.model, tf.keras.Model):
            variables = self.model.trainable_variables
        else:
            variables = self.model.trainable_variables()

        # Calculate the loss(es) inside a tf GradientTape.
        # The tape must be persistent if a custom `compute_gradients_fn`
        # may pull gradients from it more than once.
        with tf.GradientTape(
            persistent=is_overridden(self.compute_gradients_fn)
        ) as tape:
            losses = self.loss(self.model, self.dist_class, samples)
            losses = force_list(losses)

        # User provided a custom compute_gradients_fn.
        if is_overridden(self.compute_gradients_fn):
            # Wrap our tape inside a wrapper, such that the resulting
            # object looks like a "classic" tf.optimizer. This way, custom
            # compute_gradients_fn will work on both tf static graph
            # and tf-eager.
            optimizer = _OptimizerWrapper(tape)
            # More than one loss terms/optimizers.
            if self.config["_tf_policy_handles_more_than_one_loss"]:
                grads_and_vars = self.compute_gradients_fn(
                    [optimizer] * len(losses), losses
                )
            # Only one loss and one optimizer.
            else:
                grads_and_vars = [self.compute_gradients_fn(optimizer, losses[0])]
        # Default: Compute gradients using the above tape.
        else:
            grads_and_vars = [
                list(zip(tape.gradient(loss, variables), variables)) for loss in losses
            ]

        if log_once("grad_vars"):
            for g_and_v in grads_and_vars:
                for g, v in g_and_v:
                    if g is not None:
                        logger.info(f"Optimizing variable {v.name}")

        # `grads_and_vars` is returned a list (len=num optimizers/losses)
        # of lists of (grad, var) tuples.
        if self.config["_tf_policy_handles_more_than_one_loss"]:
            grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars]
        # `grads_and_vars` is returned as a list of (grad, var) tuples.
        else:
            grads_and_vars = grads_and_vars[0]
            grads = [g for g, _ in grads_and_vars]

        stats = self._stats(samples, grads)
        return grads_and_vars, grads, stats
922
+
923
    def _apply_gradients_helper(self, grads_and_vars):
        """Applies (grad, var) pairs via the configured optimizer(s)."""
        # Increase the tracing counter to make sure we don't re-trace too
        # often. If eager_tracing=True, this counter should only get
        # incremented during the @tf.function trace operations, never when
        # calling the already traced function after that.
        self._re_trace_counter += 1

        if is_overridden(self.apply_gradients_fn):
            # Custom apply fn: hand the optimizer(s) and grads over to it.
            if self.config["_tf_policy_handles_more_than_one_loss"]:
                self.apply_gradients_fn(self._optimizers, grads_and_vars)
            else:
                self.apply_gradients_fn(self._optimizer, grads_and_vars)
        else:
            if self.config["_tf_policy_handles_more_than_one_loss"]:
                # One optimizer per loss term; drop None gradients.
                for i, o in enumerate(self._optimizers):
                    o.apply_gradients(
                        [(g, v) for g, v in grads_and_vars[i] if g is not None]
                    )
            else:
                self._optimizer.apply_gradients(
                    [(g, v) for g, v in grads_and_vars if g is not None]
                )
945
+
946
+ def _stats(self, samples, grads):
947
+ fetches = {}
948
+ if is_overridden(self.stats_fn):
949
+ fetches[LEARNER_STATS_KEY] = dict(self.stats_fn(samples))
950
+ else:
951
+ fetches[LEARNER_STATS_KEY] = {}
952
+
953
+ fetches.update(dict(self.extra_learn_fetches_fn()))
954
+ fetches.update(dict(self.grad_stats_fn(samples, grads)))
955
+ return fetches
956
+
957
+ def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch):
958
+ # TODO: (sven): Keep for a while to ensure backward compatibility.
959
+ if not isinstance(postprocessed_batch, SampleBatch):
960
+ postprocessed_batch = SampleBatch(postprocessed_batch)
961
+ postprocessed_batch.set_get_interceptor(_convert_to_tf)
962
+ return postprocessed_batch
963
+
964
    @classmethod
    def with_tracing(cls):
        """Returns a traced variant of this policy class (via `_traced_eager_policy`)."""
        return _traced_eager_policy(cls)
.venv/lib/python3.11/site-packages/ray/rllib/policy/policy.py ADDED
@@ -0,0 +1,1696 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import platform
5
+ from abc import ABCMeta, abstractmethod
6
+ from typing import (
7
+ Any,
8
+ Callable,
9
+ Collection,
10
+ Dict,
11
+ List,
12
+ Optional,
13
+ Tuple,
14
+ Type,
15
+ Union,
16
+ )
17
+
18
+ import gymnasium as gym
19
+ import numpy as np
20
+ import tree # pip install dm_tree
21
+ from gymnasium.spaces import Box
22
+ from packaging import version
23
+
24
+ import ray
25
+ import ray.cloudpickle as pickle
26
+ from ray.actor import ActorHandle
27
+ from ray.train import Checkpoint
28
+ from ray.rllib.models.action_dist import ActionDistribution
29
+ from ray.rllib.models.catalog import ModelCatalog
30
+ from ray.rllib.models.modelv2 import ModelV2
31
+ from ray.rllib.policy.sample_batch import SampleBatch
32
+ from ray.rllib.policy.view_requirement import ViewRequirement
33
+ from ray.rllib.utils.annotations import (
34
+ OldAPIStack,
35
+ OverrideToImplementCustomLogic,
36
+ OverrideToImplementCustomLogic_CallToSuperRecommended,
37
+ is_overridden,
38
+ )
39
+ from ray.rllib.utils.checkpoints import (
40
+ CHECKPOINT_VERSION,
41
+ get_checkpoint_info,
42
+ try_import_msgpack,
43
+ )
44
+ from ray.rllib.utils.deprecation import (
45
+ DEPRECATED_VALUE,
46
+ deprecation_warning,
47
+ )
48
+ from ray.rllib.utils.exploration.exploration import Exploration
49
+ from ray.rllib.utils.framework import try_import_tf, try_import_torch
50
+ from ray.rllib.utils.from_config import from_config
51
+ from ray.rllib.utils.numpy import convert_to_numpy
52
+ from ray.rllib.utils.serialization import (
53
+ deserialize_type,
54
+ space_from_dict,
55
+ space_to_dict,
56
+ )
57
+ from ray.rllib.utils.spaces.space_utils import (
58
+ get_base_struct_from_space,
59
+ get_dummy_batch_for_space,
60
+ unbatch,
61
+ )
62
+ from ray.rllib.utils.tensor_dtype import get_np_dtype
63
+ from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
64
+ from ray.rllib.utils.typing import (
65
+ AgentID,
66
+ AlgorithmConfigDict,
67
+ ModelGradients,
68
+ ModelWeights,
69
+ PolicyID,
70
+ PolicyState,
71
+ T,
72
+ TensorStructType,
73
+ TensorType,
74
+ )
75
+
76
+ tf1, tf, tfv = try_import_tf()
77
+ torch, _ = try_import_torch()
78
+
79
+
80
+ logger = logging.getLogger(__name__)
81
+
82
+
83
@OldAPIStack
class PolicySpec:
    """A policy spec used in the "config.multiagent.policies" specification dict.

    As values (keys are the policy IDs (str)). E.g.:
    config:
        multiagent:
            policies: {
                "pol1": PolicySpec(None, Box, Discrete(2), {"lr": 0.0001}),
                "pol2": PolicySpec(config={"lr": 0.001}),
            }
    """

    def __init__(
        self, policy_class=None, observation_space=None, action_space=None, config=None
    ):
        # If None, use the Algorithm's default policy class stored under
        # `Algorithm._policy_class`.
        self.policy_class = policy_class
        # If None, use the env's observation space. If None and there is no Env
        # (e.g. offline RL), an error is thrown.
        self.observation_space = observation_space
        # If None, use the env's action space. If None and there is no Env
        # (e.g. offline RL), an error is thrown.
        self.action_space = action_space
        # Overrides defined keys in the main Algorithm config.
        # If None, use {}.
        self.config = config

    def __eq__(self, other: "PolicySpec"):
        return (
            self.policy_class == other.policy_class
            and self.observation_space == other.observation_space
            and self.action_space == other.action_space
            and self.config == other.config
        )

    def get_state(self) -> Dict[str, Any]:
        """Returns the state of a `PolicySpec` as a dict.

        Keys match the attribute names so that `from_state` can restore the
        spec via `__dict__.update(state)`.
        """
        # Bug fix: previously returned a 4-tuple, which broke the
        # `from_state(get_state())` roundtrip (`dict.update` requires a
        # mapping or key/value pairs) and contradicted the declared
        # `Dict[str, Any]` return type.
        return {
            "policy_class": self.policy_class,
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "config": self.config,
        }

    @classmethod
    def from_state(cls, state: Dict[str, Any]) -> "PolicySpec":
        """Builds a `PolicySpec` from a state dict (see `get_state`)."""
        policy_spec = PolicySpec()
        policy_spec.__dict__.update(state)

        return policy_spec

    def serialize(self) -> Dict:
        """Returns a JSON-friendlier dict representation of this spec."""
        from ray.rllib.algorithms.registry import get_policy_class_name

        # Try to figure out a durable name for this policy.
        cls = get_policy_class_name(self.policy_class)
        if cls is None:
            logger.warning(
                f"Can not figure out a durable policy name for {self.policy_class}. "
                f"You are probably trying to checkpoint a custom policy. "
                f"Raw policy class may cause problems when the checkpoint needs to "
                "be loaded in the future. To fix this, make sure you add your "
                "custom policy in rllib.algorithms.registry.POLICIES."
            )
            cls = self.policy_class

        return {
            "policy_class": cls,
            "observation_space": space_to_dict(self.observation_space),
            "action_space": space_to_dict(self.action_space),
            # TODO(jungong) : try making the config dict durable by maybe
            # getting rid of all the fields that are not JSON serializable.
            "config": self.config,
        }

    @classmethod
    def deserialize(cls, spec: Dict) -> "PolicySpec":
        """Inverse of `serialize`: rebuilds a `PolicySpec` from a dict.

        Raises:
            AttributeError: If `spec["policy_class"]` is neither a durable
                name (str) nor a class type.
        """
        if isinstance(spec["policy_class"], str):
            # Try to recover the actual policy class from durable name.
            from ray.rllib.algorithms.registry import get_policy_class

            policy_class = get_policy_class(spec["policy_class"])
        elif isinstance(spec["policy_class"], type):
            # Policy spec is already a class type. Simply use it.
            policy_class = spec["policy_class"]
        else:
            raise AttributeError(f"Unknown policy class spec {spec['policy_class']}")

        return cls(
            policy_class=policy_class,
            observation_space=space_from_dict(spec["observation_space"]),
            action_space=space_from_dict(spec["action_space"]),
            config=spec["config"],
        )
180
+
181
+
182
+ @OldAPIStack
183
+ class Policy(metaclass=ABCMeta):
184
+ """RLlib's base class for all Policy implementations.
185
+
186
+ Policy is the abstract superclass for all DL-framework specific sub-classes
187
+ (e.g. TFPolicy or TorchPolicy). It exposes APIs to
188
+
189
+ 1. Compute actions from observation (and possibly other) inputs.
190
+
191
+ 2. Manage the Policy's NN model(s), like exporting and loading their weights.
192
+
193
+ 3. Postprocess a given trajectory from the environment or other input via the
194
+ `postprocess_trajectory` method.
195
+
196
+ 4. Compute losses from a train batch.
197
+
198
+ 5. Perform updates from a train batch on the NN-models (this normally includes loss
199
+ calculations) either:
200
+
201
+ a. in one monolithic step (`learn_on_batch`)
202
+
203
+ b. via batch pre-loading, then n steps of actual loss computations and updates
204
+ (`load_batch_into_buffer` + `learn_on_loaded_batch`).
205
+ """
206
+
207
+ def __init__(
208
+ self,
209
+ observation_space: gym.Space,
210
+ action_space: gym.Space,
211
+ config: AlgorithmConfigDict,
212
+ ):
213
+ """Initializes a Policy instance.
214
+
215
+ Args:
216
+ observation_space: Observation space of the policy.
217
+ action_space: Action space of the policy.
218
+ config: A complete Algorithm/Policy config dict. For the default
219
+ config keys and values, see rllib/algorithm/algorithm.py.
220
+ """
221
+ self.observation_space: gym.Space = observation_space
222
+ self.action_space: gym.Space = action_space
223
+ # the policy id in the global context.
224
+ self.__policy_id = config.get("__policy_id")
225
+ # The base struct of the observation/action spaces.
226
+ # E.g. action-space = gym.spaces.Dict({"a": Discrete(2)}) ->
227
+ # action_space_struct = {"a": Discrete(2)}
228
+ self.observation_space_struct = get_base_struct_from_space(observation_space)
229
+ self.action_space_struct = get_base_struct_from_space(action_space)
230
+
231
+ self.config: AlgorithmConfigDict = config
232
+ self.framework = self.config.get("framework")
233
+
234
+ # Create the callbacks object to use for handling custom callbacks.
235
+ from ray.rllib.callbacks.callbacks import RLlibCallback
236
+
237
+ callbacks = self.config.get("callbacks")
238
+ if isinstance(callbacks, RLlibCallback):
239
+ self.callbacks = callbacks()
240
+ elif isinstance(callbacks, (str, type)):
241
+ try:
242
+ self.callbacks: "RLlibCallback" = deserialize_type(
243
+ self.config.get("callbacks")
244
+ )()
245
+ except Exception:
246
+ pass # TEST
247
+ else:
248
+ self.callbacks: "RLlibCallback" = RLlibCallback()
249
+
250
+ # The global timestep, broadcast down from time to time from the
251
+ # local worker to all remote workers.
252
+ self.global_timestep: int = 0
253
+ # The number of gradient updates this policy has undergone.
254
+ self.num_grad_updates: int = 0
255
+
256
+ # The action distribution class to use for action sampling, if any.
257
+ # Child classes may set this.
258
+ self.dist_class: Optional[Type] = None
259
+
260
+ # Initialize view requirements.
261
+ self.init_view_requirements()
262
+
263
+ # Whether the Model's initial state (method) has been added
264
+ # automatically based on the given view requirements of the model.
265
+ self._model_init_state_automatically_added = False
266
+
267
+ # Connectors.
268
+ self.agent_connectors = None
269
+ self.action_connectors = None
270
+
271
    @staticmethod
    def from_checkpoint(
        checkpoint: Union[str, Checkpoint],
        policy_ids: Optional[Collection[PolicyID]] = None,
    ) -> Union["Policy", Dict[PolicyID, "Policy"]]:
        """Creates new Policy instance(s) from a given Policy or Algorithm checkpoint.

        Note: This method must remain backward compatible from 2.1.0 on, wrt.
        checkpoints created with Ray 2.0.0 or later.

        Args:
            checkpoint: The path (str) to a Policy or Algorithm checkpoint directory
                or an AIR Checkpoint (Policy or Algorithm) instance to restore
                from.
                If checkpoint is a Policy checkpoint, `policy_ids` must be None
                and only the Policy in that checkpoint is restored and returned.
                If checkpoint is an Algorithm checkpoint and `policy_ids` is None,
                will return a list of all Policy objects found in
                the checkpoint, otherwise a list of those policies in `policy_ids`.
            policy_ids: List of policy IDs to extract from a given Algorithm checkpoint.
                If None and an Algorithm checkpoint is provided, will restore all
                policies found in that checkpoint. If a Policy checkpoint is given,
                this arg must be None.

        Returns:
            An instantiated Policy, if `checkpoint` is a Policy checkpoint. A dict
            mapping PolicyID to Policies, if `checkpoint` is an Algorithm checkpoint.
            In the latter case, returns all policies within the Algorithm if
            `policy_ids` is None, else a dict of only those Policies that are in
            `policy_ids`.
        """
        checkpoint_info = get_checkpoint_info(checkpoint)

        # Algorithm checkpoint: Extract one or more policies from it and return them
        # in a dict (mapping PolicyID to Policy instances).
        if checkpoint_info["type"] == "Algorithm":
            from ray.rllib.algorithms.algorithm import Algorithm

            policies = {}

            # Old Algorithm checkpoints: State must be completely retrieved from:
            # algo state file -> worker -> "state".
            if checkpoint_info["checkpoint_version"] < version.Version("1.0"):
                with open(checkpoint_info["state_file"], "rb") as f:
                    state = pickle.load(f)
                # In older checkpoint versions, the policy states are stored under
                # "state" within the worker state (which is pickled in itself).
                worker_state = pickle.loads(state["worker"])
                policy_states = worker_state["state"]
                for pid, policy_state in policy_states.items():
                    # Get spec and config, merge config with the worker-level
                    # policy config, and store the spec inside the state so
                    # `Policy.from_state` can rebuild the policy.
                    serialized_policy_spec = worker_state["policy_specs"][pid]
                    policy_config = Algorithm.merge_algorithm_configs(
                        worker_state["policy_config"], serialized_policy_spec["config"]
                    )
                    serialized_policy_spec.update({"config": policy_config})
                    policy_state.update({"policy_spec": serialized_policy_spec})
                    policies[pid] = Policy.from_state(policy_state)
            # Newer versions: Get policy states from "policies/" sub-dirs.
            elif checkpoint_info["policy_ids"] is not None:
                for policy_id in checkpoint_info["policy_ids"]:
                    # Restore either all policies or only the requested subset.
                    if policy_ids is None or policy_id in policy_ids:
                        policy_checkpoint_info = get_checkpoint_info(
                            os.path.join(
                                checkpoint_info["checkpoint_dir"],
                                "policies",
                                policy_id,
                            )
                        )
                        assert policy_checkpoint_info["type"] == "Policy"
                        with open(policy_checkpoint_info["state_file"], "rb") as f:
                            policy_state = pickle.load(f)
                        policies[policy_id] = Policy.from_state(policy_state)
            return policies

        # Policy checkpoint: Return a single Policy instance.
        else:
            # Msgpack-formatted checkpoints need the msgpack package to load.
            msgpack = None
            if checkpoint_info.get("format") == "msgpack":
                msgpack = try_import_msgpack(error=True)

            with open(checkpoint_info["state_file"], "rb") as f:
                if msgpack is not None:
                    state = msgpack.load(f)
                else:
                    state = pickle.load(f)
            return Policy.from_state(state)
358
+
359
    @staticmethod
    def from_state(state: PolicyState) -> "Policy":
        """Recovers a Policy from a state object.

        The `state` of an instantiated Policy can be retrieved by calling its
        `get_state` method. This only works for the V2 Policy classes (EagerTFPolicyV2,
        DynamicTFPolicyV2, and TorchPolicyV2). It contains all information necessary
        to create the Policy. No access to the original code (e.g. configs, knowledge of
        the policy's class, etc..) is needed.

        Args:
            state: The state to recover a new Policy instance from.

        Returns:
            A new Policy instance.

        Raises:
            ValueError: If `state` contains no "policy_spec" key.
        """
        serialized_pol_spec: Optional[dict] = state.get("policy_spec")
        if serialized_pol_spec is None:
            raise ValueError(
                "No `policy_spec` key was found in given `state`! "
                "Cannot create new Policy."
            )
        pol_spec = PolicySpec.deserialize(serialized_pol_spec)
        # May swap in the eager/traced tf2 class for the given tf policy class.
        actual_class = get_tf_eager_cls_if_necessary(
            pol_spec.policy_class,
            pol_spec.config,
        )

        if pol_spec.config["framework"] == "tf":
            from ray.rllib.policy.tf_policy import TFPolicy

            # tf1 static-graph policies need a special construction path.
            return TFPolicy._tf1_from_state_helper(state)

        # Create the new policy.
        new_policy = actual_class(
            # Note(jungong) : we are intentionally not using keyword arguments here
            # because some policies name the observation space parameter obs_space,
            # and some others name it observation_space.
            pol_spec.observation_space,
            pol_spec.action_space,
            pol_spec.config,
        )

        # Set the new policy's state (weights, optimizer vars, exploration state,
        # etc..).
        new_policy.set_state(state)
        # Return the new policy.
        return new_policy
407
+
408
+ def init_view_requirements(self):
409
+ """Maximal view requirements dict for `learn_on_batch()` and
410
+ `compute_actions` calls.
411
+ Specific policies can override this function to provide custom
412
+ list of view requirements.
413
+ """
414
+ # Maximal view requirements dict for `learn_on_batch()` and
415
+ # `compute_actions` calls.
416
+ # View requirements will be automatically filtered out later based
417
+ # on the postprocessing and loss functions to ensure optimal data
418
+ # collection and transfer performance.
419
+ view_reqs = self._get_default_view_requirements()
420
+ if not hasattr(self, "view_requirements"):
421
+ self.view_requirements = view_reqs
422
+ else:
423
+ for k, v in view_reqs.items():
424
+ if k not in self.view_requirements:
425
+ self.view_requirements[k] = v
426
+
427
+ def get_connector_metrics(self) -> Dict:
428
+ """Get metrics on timing from connectors."""
429
+ return {
430
+ "agent_connectors": {
431
+ name + "_ms": 1000 * timer.mean
432
+ for name, timer in self.agent_connectors.timers.items()
433
+ },
434
+ "action_connectors": {
435
+ name + "_ms": 1000 * timer.mean
436
+ for name, timer in self.agent_connectors.timers.items()
437
+ },
438
+ }
439
+
440
    def reset_connectors(self, env_id) -> None:
        """Reset action- and agent-connectors for this policy.

        Args:
            env_id: Forwarded as-is to both connector pipelines' `reset()`
                calls (presumably a sub-environment id — confirm at caller).
        """
        self.agent_connectors.reset(env_id=env_id)
        self.action_connectors.reset(env_id=env_id)
444
+
445
    def compute_single_action(
        self,
        obs: Optional[TensorStructType] = None,
        state: Optional[List[TensorType]] = None,
        *,
        prev_action: Optional[TensorStructType] = None,
        prev_reward: Optional[TensorStructType] = None,
        info: dict = None,
        input_dict: Optional[SampleBatch] = None,
        episode=None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        # Kwargs placeholder for future compatibility.
        **kwargs,
    ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
        """Computes and returns a single (B=1) action value.

        Takes an input dict (usually a SampleBatch) as its main data input.
        This allows for using this method in case a more complex input pattern
        (view requirements) is needed, for example when the Model requires the
        last n observations, the last m actions/rewards, or a combination
        of any of these.
        Alternatively, in case no complex inputs are required, takes a single
        `obs` values (and possibly single state values, prev-action/reward
        values, etc..).

        Args:
            obs: Single observation.
            state: List of RNN state inputs, if any.
            prev_action: Previous action value, if any.
            prev_reward: Previous reward, if any.
            info: Info object, if any.
            input_dict: A SampleBatch or input dict containing the
                single (unbatched) Tensors to compute actions. If given, it'll
                be used instead of `obs`, `state`, `prev_action|reward`, and
                `info`.
            episode: This provides access to all of the internal episode state,
                which may be useful for model-based or multi-agent algorithms.
            explore: Whether to pick an exploitation or
                exploration action
                (default: None -> use self.config["explore"]).
            timestep: The current (sampling) time step.

        Keyword Args:
            kwargs: Forward compatibility placeholder.

        Returns:
            Tuple consisting of the action, the list of RNN state outputs (if
            any), and a dictionary of extra features (if any).
        """
        # Build the input-dict used for the call to
        # `self.compute_actions_from_input_dict()`. Only assemble one from the
        # individual args if the caller did not provide one directly.
        if input_dict is None:
            input_dict = {SampleBatch.OBS: obs}
            if state is not None:
                for i, s in enumerate(state):
                    input_dict[f"state_in_{i}"] = s
            if prev_action is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action
            if prev_reward is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward
            if info is not None:
                input_dict[SampleBatch.INFOS] = info

        # Batch all data in input dict: add a leading B=1 dim to every leaf
        # (torch tensors via `unsqueeze(0)`, everything else via
        # `np.expand_dims`); `seq_lens` entries are passed through unchanged.
        input_dict = tree.map_structure_with_path(
            lambda p, s: (
                s
                if p == "seq_lens"
                else s.unsqueeze(0)
                if torch and isinstance(s, torch.Tensor)
                else np.expand_dims(s, 0)
            ),
            input_dict,
        )

        episodes = None
        if episode is not None:
            episodes = [episode]

        out = self.compute_actions_from_input_dict(
            input_dict=SampleBatch(input_dict),
            episodes=episodes,
            explore=explore,
            timestep=timestep,
        )

        # Some policies don't return a tuple, but always just a single action.
        # E.g. ES and ARS.
        if not isinstance(out, tuple):
            single_action = out
            state_out = []
            info = {}
        # Normal case: Policy should return (action, state, info) tuple.
        else:
            batched_action, state_out, info = out
            # Strip the B=1 batch dim from the action again.
            single_action = unbatch(batched_action)
            assert len(single_action) == 1
            single_action = single_action[0]

        # Return action, internal state(s), infos — each with the B=1 batch
        # dimension removed.
        return (
            single_action,
            tree.map_structure(lambda x: x[0], state_out),
            tree.map_structure(lambda x: x[0], info),
        )
551
+
552
+ def compute_actions_from_input_dict(
553
+ self,
554
+ input_dict: Union[SampleBatch, Dict[str, TensorStructType]],
555
+ explore: Optional[bool] = None,
556
+ timestep: Optional[int] = None,
557
+ episodes=None,
558
+ **kwargs,
559
+ ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
560
+ """Computes actions from collected samples (across multiple-agents).
561
+
562
+ Takes an input dict (usually a SampleBatch) as its main data input.
563
+ This allows for using this method in case a more complex input pattern
564
+ (view requirements) is needed, for example when the Model requires the
565
+ last n observations, the last m actions/rewards, or a combination
566
+ of any of these.
567
+
568
+ Args:
569
+ input_dict: A SampleBatch or input dict containing the Tensors
570
+ to compute actions. `input_dict` already abides to the
571
+ Policy's as well as the Model's view requirements and can
572
+ thus be passed to the Model as-is.
573
+ explore: Whether to pick an exploitation or exploration
574
+ action (default: None -> use self.config["explore"]).
575
+ timestep: The current (sampling) time step.
576
+ episodes: This provides access to all of the internal episodes'
577
+ state, which may be useful for model-based or multi-agent
578
+ algorithms.
579
+
580
+ Keyword Args:
581
+ kwargs: Forward compatibility placeholder.
582
+
583
+ Returns:
584
+ actions: Batch of output actions, with shape like
585
+ [BATCH_SIZE, ACTION_SHAPE].
586
+ state_outs: List of RNN state output
587
+ batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].
588
+ info: Dictionary of extra feature batches, if any, with shape like
589
+ {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
590
+ """
591
+ # Default implementation just passes obs, prev-a/r, and states on to
592
+ # `self.compute_actions()`.
593
+ state_batches = [s for k, s in input_dict.items() if k.startswith("state_in")]
594
+ return self.compute_actions(
595
+ input_dict[SampleBatch.OBS],
596
+ state_batches,
597
+ prev_action_batch=input_dict.get(SampleBatch.PREV_ACTIONS),
598
+ prev_reward_batch=input_dict.get(SampleBatch.PREV_REWARDS),
599
+ info_batch=input_dict.get(SampleBatch.INFOS),
600
+ explore=explore,
601
+ timestep=timestep,
602
+ episodes=episodes,
603
+ **kwargs,
604
+ )
605
+
606
+ @abstractmethod
607
+ def compute_actions(
608
+ self,
609
+ obs_batch: Union[List[TensorStructType], TensorStructType],
610
+ state_batches: Optional[List[TensorType]] = None,
611
+ prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
612
+ prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
613
+ info_batch: Optional[Dict[str, list]] = None,
614
+ episodes: Optional[List] = None,
615
+ explore: Optional[bool] = None,
616
+ timestep: Optional[int] = None,
617
+ **kwargs,
618
+ ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
619
+ """Computes actions for the current policy.
620
+
621
+ Args:
622
+ obs_batch: Batch of observations.
623
+ state_batches: List of RNN state input batches, if any.
624
+ prev_action_batch: Batch of previous action values.
625
+ prev_reward_batch: Batch of previous rewards.
626
+ info_batch: Batch of info objects.
627
+ episodes: List of Episode objects, one for each obs in
628
+ obs_batch. This provides access to all of the internal
629
+ episode state, which may be useful for model-based or
630
+ multi-agent algorithms.
631
+ explore: Whether to pick an exploitation or exploration action.
632
+ Set to None (default) for using the value of
633
+ `self.config["explore"]`.
634
+ timestep: The current (sampling) time step.
635
+
636
+ Keyword Args:
637
+ kwargs: Forward compatibility placeholder
638
+
639
+ Returns:
640
+ actions: Batch of output actions, with shape like
641
+ [BATCH_SIZE, ACTION_SHAPE].
642
+ state_outs (List[TensorType]): List of RNN state output
643
+ batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].
644
+ info (List[dict]): Dictionary of extra feature batches, if any,
645
+ with shape like
646
+ {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
647
+ """
648
+ raise NotImplementedError
649
+
650
+ def compute_log_likelihoods(
651
+ self,
652
+ actions: Union[List[TensorType], TensorType],
653
+ obs_batch: Union[List[TensorType], TensorType],
654
+ state_batches: Optional[List[TensorType]] = None,
655
+ prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
656
+ prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
657
+ actions_normalized: bool = True,
658
+ in_training: bool = True,
659
+ ) -> TensorType:
660
+ """Computes the log-prob/likelihood for a given action and observation.
661
+
662
+ The log-likelihood is calculated using this Policy's action
663
+ distribution class (self.dist_class).
664
+
665
+ Args:
666
+ actions: Batch of actions, for which to retrieve the
667
+ log-probs/likelihoods (given all other inputs: obs,
668
+ states, ..).
669
+ obs_batch: Batch of observations.
670
+ state_batches: List of RNN state input batches, if any.
671
+ prev_action_batch: Batch of previous action values.
672
+ prev_reward_batch: Batch of previous rewards.
673
+ actions_normalized: Is the given `actions` already normalized
674
+ (between -1.0 and 1.0) or not? If not and
675
+ `normalize_actions=True`, we need to normalize the given
676
+ actions first, before calculating log likelihoods.
677
+ in_training: Whether to use the forward_train() or forward_exploration() of
678
+ the underlying RLModule.
679
+ Returns:
680
+ Batch of log probs/likelihoods, with shape: [BATCH_SIZE].
681
+ """
682
+ raise NotImplementedError
683
+
684
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
685
+ def postprocess_trajectory(
686
+ self,
687
+ sample_batch: SampleBatch,
688
+ other_agent_batches: Optional[
689
+ Dict[AgentID, Tuple["Policy", SampleBatch]]
690
+ ] = None,
691
+ episode=None,
692
+ ) -> SampleBatch:
693
+ """Implements algorithm-specific trajectory postprocessing.
694
+
695
+ This will be called on each trajectory fragment computed during policy
696
+ evaluation. Each fragment is guaranteed to be only from one episode.
697
+ The given fragment may or may not contain the end of this episode,
698
+ depending on the `batch_mode=truncate_episodes|complete_episodes`,
699
+ `rollout_fragment_length`, and other settings.
700
+
701
+ Args:
702
+ sample_batch: batch of experiences for the policy,
703
+ which will contain at most one episode trajectory.
704
+ other_agent_batches: In a multi-agent env, this contains a
705
+ mapping of agent ids to (policy, agent_batch) tuples
706
+ containing the policy and experiences of the other agents.
707
+ episode: An optional multi-agent episode object to provide
708
+ access to all of the internal episode state, which may
709
+ be useful for model-based or multi-agent algorithms.
710
+
711
+ Returns:
712
+ The postprocessed sample batch.
713
+ """
714
+ # The default implementation just returns the same, unaltered batch.
715
+ return sample_batch
716
+
717
+ @OverrideToImplementCustomLogic
718
+ def loss(
719
+ self, model: ModelV2, dist_class: ActionDistribution, train_batch: SampleBatch
720
+ ) -> Union[TensorType, List[TensorType]]:
721
+ """Loss function for this Policy.
722
+
723
+ Override this method in order to implement custom loss computations.
724
+
725
+ Args:
726
+ model: The model to calculate the loss(es).
727
+ dist_class: The action distribution class to sample actions
728
+ from the model's outputs.
729
+ train_batch: The input batch on which to calculate the loss.
730
+
731
+ Returns:
732
+ Either a single loss tensor or a list of loss tensors.
733
+ """
734
+ raise NotImplementedError
735
+
736
+ def learn_on_batch(self, samples: SampleBatch) -> Dict[str, TensorType]:
737
+ """Perform one learning update, given `samples`.
738
+
739
+ Either this method or the combination of `compute_gradients` and
740
+ `apply_gradients` must be implemented by subclasses.
741
+
742
+ Args:
743
+ samples: The SampleBatch object to learn from.
744
+
745
+ Returns:
746
+ Dictionary of extra metadata from `compute_gradients()`.
747
+
748
+ .. testcode::
749
+ :skipif: True
750
+
751
+ policy, sample_batch = ...
752
+ policy.learn_on_batch(sample_batch)
753
+ """
754
+ # The default implementation is simply a fused `compute_gradients` plus
755
+ # `apply_gradients` call.
756
+ grads, grad_info = self.compute_gradients(samples)
757
+ self.apply_gradients(grads)
758
+ return grad_info
759
+
760
+ def learn_on_batch_from_replay_buffer(
761
+ self, replay_actor: ActorHandle, policy_id: PolicyID
762
+ ) -> Dict[str, TensorType]:
763
+ """Samples a batch from given replay actor and performs an update.
764
+
765
+ Args:
766
+ replay_actor: The replay buffer actor to sample from.
767
+ policy_id: The ID of this policy.
768
+
769
+ Returns:
770
+ Dictionary of extra metadata from `compute_gradients()`.
771
+ """
772
+ # Sample a batch from the given replay actor.
773
+ # Note that for better performance (less data sent through the
774
+ # network), this policy should be co-located on the same node
775
+ # as `replay_actor`. Such a co-location step is usually done during
776
+ # the Algorithm's `setup()` phase.
777
+ batch = ray.get(replay_actor.replay.remote(policy_id=policy_id))
778
+ if batch is None:
779
+ return {}
780
+
781
+ # Send to own learn_on_batch method for updating.
782
+ # TODO: hack w/ `hasattr`
783
+ if hasattr(self, "devices") and len(self.devices) > 1:
784
+ self.load_batch_into_buffer(batch, buffer_index=0)
785
+ return self.learn_on_loaded_batch(offset=0, buffer_index=0)
786
+ else:
787
+ return self.learn_on_batch(batch)
788
+
789
+ def load_batch_into_buffer(self, batch: SampleBatch, buffer_index: int = 0) -> int:
790
+ """Bulk-loads the given SampleBatch into the devices' memories.
791
+
792
+ The data is split equally across all the Policy's devices.
793
+ If the data is not evenly divisible by the batch size, excess data
794
+ should be discarded.
795
+
796
+ Args:
797
+ batch: The SampleBatch to load.
798
+ buffer_index: The index of the buffer (a MultiGPUTowerStack) to use
799
+ on the devices. The number of buffers on each device depends
800
+ on the value of the `num_multi_gpu_tower_stacks` config key.
801
+
802
+ Returns:
803
+ The number of tuples loaded per device.
804
+ """
805
+ raise NotImplementedError
806
+
807
+ def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
808
+ """Returns the number of currently loaded samples in the given buffer.
809
+
810
+ Args:
811
+ buffer_index: The index of the buffer (a MultiGPUTowerStack)
812
+ to use on the devices. The number of buffers on each device
813
+ depends on the value of the `num_multi_gpu_tower_stacks` config
814
+ key.
815
+
816
+ Returns:
817
+ The number of tuples loaded per device.
818
+ """
819
+ raise NotImplementedError
820
+
821
+ def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
822
+ """Runs a single step of SGD on an already loaded data in a buffer.
823
+
824
+ Runs an SGD step over a slice of the pre-loaded batch, offset by
825
+ the `offset` argument (useful for performing n minibatch SGD
826
+ updates repeatedly on the same, already pre-loaded data).
827
+
828
+ Updates the model weights based on the averaged per-device gradients.
829
+
830
+ Args:
831
+ offset: Offset into the preloaded data. Used for pre-loading
832
+ a train-batch once to a device, then iterating over
833
+ (subsampling through) this batch n times doing minibatch SGD.
834
+ buffer_index: The index of the buffer (a MultiGPUTowerStack)
835
+ to take the already pre-loaded data from. The number of buffers
836
+ on each device depends on the value of the
837
+ `num_multi_gpu_tower_stacks` config key.
838
+
839
+ Returns:
840
+ The outputs of extra_ops evaluated over the batch.
841
+ """
842
+ raise NotImplementedError
843
+
844
+ def compute_gradients(
845
+ self, postprocessed_batch: SampleBatch
846
+ ) -> Tuple[ModelGradients, Dict[str, TensorType]]:
847
+ """Computes gradients given a batch of experiences.
848
+
849
+ Either this in combination with `apply_gradients()` or
850
+ `learn_on_batch()` must be implemented by subclasses.
851
+
852
+ Args:
853
+ postprocessed_batch: The SampleBatch object to use
854
+ for calculating gradients.
855
+
856
+ Returns:
857
+ grads: List of gradient output values.
858
+ grad_info: Extra policy-specific info values.
859
+ """
860
+ raise NotImplementedError
861
+
862
+ def apply_gradients(self, gradients: ModelGradients) -> None:
863
+ """Applies the (previously) computed gradients.
864
+
865
+ Either this in combination with `compute_gradients()` or
866
+ `learn_on_batch()` must be implemented by subclasses.
867
+
868
+ Args:
869
+ gradients: The already calculated gradients to apply to this
870
+ Policy.
871
+ """
872
+ raise NotImplementedError
873
+
874
+ def get_weights(self) -> ModelWeights:
875
+ """Returns model weights.
876
+
877
+ Note: The return value of this method will reside under the "weights"
878
+ key in the return value of Policy.get_state(). Model weights are only
879
+ one part of a Policy's state. Other state information contains:
880
+ optimizer variables, exploration state, and global state vars such as
881
+ the sampling timestep.
882
+
883
+ Returns:
884
+ Serializable copy or view of model weights.
885
+ """
886
+ raise NotImplementedError
887
+
888
+ def set_weights(self, weights: ModelWeights) -> None:
889
+ """Sets this Policy's model's weights.
890
+
891
+ Note: Model weights are only one part of a Policy's state. Other
892
+ state information contains: optimizer variables, exploration state,
893
+ and global state vars such as the sampling timestep.
894
+
895
+ Args:
896
+ weights: Serializable copy or view of model weights.
897
+ """
898
+ raise NotImplementedError
899
+
900
+ def get_exploration_state(self) -> Dict[str, TensorType]:
901
+ """Returns the state of this Policy's exploration component.
902
+
903
+ Returns:
904
+ Serializable information on the `self.exploration` object.
905
+ """
906
+ return self.exploration.get_state()
907
+
908
+ def is_recurrent(self) -> bool:
909
+ """Whether this Policy holds a recurrent Model.
910
+
911
+ Returns:
912
+ True if this Policy has-a RNN-based Model.
913
+ """
914
+ return False
915
+
916
+ def num_state_tensors(self) -> int:
917
+ """The number of internal states needed by the RNN-Model of the Policy.
918
+
919
+ Returns:
920
+ int: The number of RNN internal states kept by this Policy's Model.
921
+ """
922
+ return 0
923
+
924
+ def get_initial_state(self) -> List[TensorType]:
925
+ """Returns initial RNN state for the current policy.
926
+
927
+ Returns:
928
+ List[TensorType]: Initial RNN state for the current policy.
929
+ """
930
+ return []
931
+
932
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
933
+ def get_state(self) -> PolicyState:
934
+ """Returns the entire current state of this Policy.
935
+
936
+ Note: Not to be confused with an RNN model's internal state.
937
+ State includes the Model(s)' weights, optimizer weights,
938
+ the exploration component's state, as well as global variables, such
939
+ as sampling timesteps.
940
+
941
+ Note that the state may contain references to the original variables.
942
+ This means that you may need to deepcopy() the state before mutating it.
943
+
944
+ Returns:
945
+ Serialized local state.
946
+ """
947
+ state = {
948
+ # All the policy's weights.
949
+ "weights": self.get_weights(),
950
+ # The current global timestep.
951
+ "global_timestep": self.global_timestep,
952
+ # The current num_grad_updates counter.
953
+ "num_grad_updates": self.num_grad_updates,
954
+ }
955
+
956
+ # Add this Policy's spec so it can be retreived w/o access to the original
957
+ # code.
958
+ policy_spec = PolicySpec(
959
+ policy_class=type(self),
960
+ observation_space=self.observation_space,
961
+ action_space=self.action_space,
962
+ config=self.config,
963
+ )
964
+ state["policy_spec"] = policy_spec.serialize()
965
+
966
+ # Checkpoint connectors state as well if enabled.
967
+ connector_configs = {}
968
+ if self.agent_connectors:
969
+ connector_configs["agent"] = self.agent_connectors.to_state()
970
+ if self.action_connectors:
971
+ connector_configs["action"] = self.action_connectors.to_state()
972
+ state["connector_configs"] = connector_configs
973
+
974
+ return state
975
+
976
+ def restore_connectors(self, state: PolicyState):
977
+ """Restore agent and action connectors if configs available.
978
+
979
+ Args:
980
+ state: The new state to set this policy to. Can be
981
+ obtained by calling `self.get_state()`.
982
+ """
983
+ # To avoid a circular dependency problem cause by SampleBatch.
984
+ from ray.rllib.connectors.util import restore_connectors_for_policy
985
+
986
+ connector_configs = state.get("connector_configs", {})
987
+ if "agent" in connector_configs:
988
+ self.agent_connectors = restore_connectors_for_policy(
989
+ self, connector_configs["agent"]
990
+ )
991
+ logger.debug("restoring agent connectors:")
992
+ logger.debug(self.agent_connectors.__str__(indentation=4))
993
+ if "action" in connector_configs:
994
+ self.action_connectors = restore_connectors_for_policy(
995
+ self, connector_configs["action"]
996
+ )
997
+ logger.debug("restoring action connectors:")
998
+ logger.debug(self.action_connectors.__str__(indentation=4))
999
+
1000
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
1001
+ def set_state(self, state: PolicyState) -> None:
1002
+ """Restores the entire current state of this Policy from `state`.
1003
+
1004
+ Args:
1005
+ state: The new state to set this policy to. Can be
1006
+ obtained by calling `self.get_state()`.
1007
+ """
1008
+ if "policy_spec" in state:
1009
+ policy_spec = PolicySpec.deserialize(state["policy_spec"])
1010
+ # Assert spaces remained the same.
1011
+ if (
1012
+ policy_spec.observation_space is not None
1013
+ and policy_spec.observation_space != self.observation_space
1014
+ ):
1015
+ logger.warning(
1016
+ "`observation_space` in given policy state ("
1017
+ f"{policy_spec.observation_space}) does not match this Policy's "
1018
+ f"observation space ({self.observation_space})."
1019
+ )
1020
+ if (
1021
+ policy_spec.action_space is not None
1022
+ and policy_spec.action_space != self.action_space
1023
+ ):
1024
+ logger.warning(
1025
+ "`action_space` in given policy state ("
1026
+ f"{policy_spec.action_space}) does not match this Policy's "
1027
+ f"action space ({self.action_space})."
1028
+ )
1029
+ # Override config, if part of the spec.
1030
+ if policy_spec.config:
1031
+ self.config = policy_spec.config
1032
+
1033
+ # Override NN weights.
1034
+ self.set_weights(state["weights"])
1035
+ self.restore_connectors(state)
1036
+
1037
+ def apply(
1038
+ self,
1039
+ func: Callable[["Policy", Optional[Any], Optional[Any]], T],
1040
+ *args,
1041
+ **kwargs,
1042
+ ) -> T:
1043
+ """Calls the given function with this Policy instance.
1044
+
1045
+ Useful for when the Policy class has been converted into a ActorHandle
1046
+ and the user needs to execute some functionality (e.g. add a property)
1047
+ on the underlying policy object.
1048
+
1049
+ Args:
1050
+ func: The function to call, with this Policy as first
1051
+ argument, followed by args, and kwargs.
1052
+ args: Optional additional args to pass to the function call.
1053
+ kwargs: Optional additional kwargs to pass to the function call.
1054
+
1055
+ Returns:
1056
+ The return value of the function call.
1057
+ """
1058
+ return func(self, *args, **kwargs)
1059
+
1060
+ def on_global_var_update(self, global_vars: Dict[str, TensorType]) -> None:
1061
+ """Called on an update to global vars.
1062
+
1063
+ Args:
1064
+ global_vars: Global variables by str key, broadcast from the
1065
+ driver.
1066
+ """
1067
+ # Store the current global time step (sum over all policies' sample
1068
+ # steps).
1069
+ # Make sure, we keep global_timestep as a Tensor for tf-eager
1070
+ # (leads to memory leaks if not doing so).
1071
+ if self.framework == "tf2":
1072
+ self.global_timestep.assign(global_vars["timestep"])
1073
+ else:
1074
+ self.global_timestep = global_vars["timestep"]
1075
+ # Update our lifetime gradient update counter.
1076
+ num_grad_updates = global_vars.get("num_grad_updates")
1077
+ if num_grad_updates is not None:
1078
+ self.num_grad_updates = num_grad_updates
1079
+
1080
    def export_checkpoint(
        self,
        export_dir: str,
        filename_prefix=DEPRECATED_VALUE,
        *,
        policy_state: Optional[PolicyState] = None,
        checkpoint_format: str = "cloudpickle",
    ) -> None:
        """Exports Policy checkpoint to a local directory and returns an AIR Checkpoint.

        Args:
            export_dir: Local writable directory to store the AIR Checkpoint
                information into.
            policy_state: An optional PolicyState to write to disk. Used by
                `Algorithm.save_checkpoint()` to save on the additional
                `self.get_state()` calls of its different Policies.
            checkpoint_format: Either one of 'cloudpickle' or 'msgpack'.

        Raises:
            ValueError: If `checkpoint_format` is neither 'cloudpickle' nor
                'msgpack'.

        .. testcode::
            :skipif: True

            from ray.rllib.algorithms.ppo import PPOTorchPolicy
            policy = PPOTorchPolicy(...)
            policy.export_checkpoint("/tmp/export_dir")
        """
        # `filename_prefix` should no longer be used, as new Policy checkpoints
        # contain more than one file with a fixed filename structure.
        if filename_prefix != DEPRECATED_VALUE:
            deprecation_warning(
                old="Policy.export_checkpoint(filename_prefix=...)",
                error=True,
            )
        # Validate the requested format before touching the filesystem.
        if checkpoint_format not in ["cloudpickle", "msgpack"]:
            raise ValueError(
                f"Value of `checkpoint_format` ({checkpoint_format}) must either be "
                "'cloudpickle' or 'msgpack'!"
            )

        # Fall back to this Policy's own current state if none was provided.
        if policy_state is None:
            policy_state = self.get_state()

        # Write main policy state file.
        os.makedirs(export_dir, exist_ok=True)
        if checkpoint_format == "cloudpickle":
            policy_state["checkpoint_version"] = CHECKPOINT_VERSION
            state_file = "policy_state.pkl"
            with open(os.path.join(export_dir, state_file), "w+b") as f:
                pickle.dump(policy_state, f)
        else:
            # Local import to avoid a circular dependency with the config
            # module.
            from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

            msgpack = try_import_msgpack(error=True)
            # msgpack cannot serialize the version object directly -> str.
            policy_state["checkpoint_version"] = str(CHECKPOINT_VERSION)
            # Serialize the config for msgpack dump'ing.
            policy_state["policy_spec"]["config"] = AlgorithmConfig._serialize_dict(
                policy_state["policy_spec"]["config"]
            )
            state_file = "policy_state.msgpck"
            with open(os.path.join(export_dir, state_file), "w+b") as f:
                msgpack.dump(policy_state, f)

        # Write RLlib checkpoint json (metadata describing how to read the
        # state file back, plus the Ray version/commit that produced it).
        with open(os.path.join(export_dir, "rllib_checkpoint.json"), "w") as f:
            json.dump(
                {
                    "type": "Policy",
                    "checkpoint_version": str(policy_state["checkpoint_version"]),
                    "format": checkpoint_format,
                    "state_file": state_file,
                    "ray_version": ray.__version__,
                    "ray_commit": ray.__commit__,
                },
                f,
            )

        # Add external model files, if required.
        if self.config["export_native_model_files"]:
            self.export_model(os.path.join(export_dir, "model"))
1158
+
1159
+ def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None:
1160
+ """Exports the Policy's Model to local directory for serving.
1161
+
1162
+ Note: The file format will depend on the deep learning framework used.
1163
+ See the child classed of Policy and their `export_model`
1164
+ implementations for more details.
1165
+
1166
+ Args:
1167
+ export_dir: Local writable directory.
1168
+ onnx: If given, will export model in ONNX format. The
1169
+ value of this parameter set the ONNX OpSet version to use.
1170
+
1171
+ Raises:
1172
+ ValueError: If a native DL-framework based model (e.g. a keras Model)
1173
+ cannot be saved to disk for various reasons.
1174
+ """
1175
+ raise NotImplementedError
1176
+
1177
+ def import_model_from_h5(self, import_file: str) -> None:
1178
+ """Imports Policy from local file.
1179
+
1180
+ Args:
1181
+ import_file: Local readable file.
1182
+ """
1183
+ raise NotImplementedError
1184
+
1185
+ def get_session(self) -> Optional["tf1.Session"]:
1186
+ """Returns tf.Session object to use for computing actions or None.
1187
+
1188
+ Note: This method only applies to TFPolicy sub-classes. All other
1189
+ sub-classes should expect a None to be returned from this method.
1190
+
1191
+ Returns:
1192
+ The tf Session to use for computing actions and losses with
1193
+ this policy or None.
1194
+ """
1195
+ return None
1196
+
1197
+ def get_host(self) -> str:
1198
+ """Returns the computer's network name.
1199
+
1200
+ Returns:
1201
+ The computer's networks name or an empty string, if the network
1202
+ name could not be determined.
1203
+ """
1204
+ return platform.node()
1205
+
1206
+ def _get_num_gpus_for_policy(self) -> int:
1207
+ """Decide on the number of CPU/GPU nodes this policy should run on.
1208
+
1209
+ Return:
1210
+ 0 if policy should run on CPU. >0 if policy should run on 1 or
1211
+ more GPUs.
1212
+ """
1213
+ worker_idx = self.config.get("worker_index", 0)
1214
+ fake_gpus = self.config.get("_fake_gpus", False)
1215
+
1216
+ if (
1217
+ ray._private.worker._mode() == ray._private.worker.LOCAL_MODE
1218
+ and not fake_gpus
1219
+ ):
1220
+ # If in local debugging mode, and _fake_gpus is not on.
1221
+ num_gpus = 0
1222
+ elif worker_idx == 0:
1223
+ # If head node, take num_gpus.
1224
+ num_gpus = self.config["num_gpus"]
1225
+ else:
1226
+ # If worker node, take `num_gpus_per_env_runner`.
1227
+ num_gpus = self.config["num_gpus_per_env_runner"]
1228
+
1229
+ if num_gpus == 0:
1230
+ dev = "CPU"
1231
+ else:
1232
+ dev = "{} {}".format(num_gpus, "fake-GPUs" if fake_gpus else "GPUs")
1233
+
1234
+ logger.info(
1235
+ "Policy (worker={}) running on {}.".format(
1236
+ worker_idx if worker_idx > 0 else "local", dev
1237
+ )
1238
+ )
1239
+
1240
+ return num_gpus
1241
+
1242
+ def _create_exploration(self) -> Exploration:
1243
+ """Creates the Policy's Exploration object.
1244
+
1245
+ This method only exists b/c some Algorithms do not use TfPolicy nor
1246
+ TorchPolicy, but inherit directly from Policy. Others inherit from
1247
+ TfPolicy w/o using DynamicTFPolicy.
1248
+
1249
+ Returns:
1250
+ Exploration: The Exploration object to be used by this Policy.
1251
+ """
1252
+ if getattr(self, "exploration", None) is not None:
1253
+ return self.exploration
1254
+
1255
+ exploration = from_config(
1256
+ Exploration,
1257
+ self.config.get("exploration_config", {"type": "StochasticSampling"}),
1258
+ action_space=self.action_space,
1259
+ policy_config=self.config,
1260
+ model=getattr(self, "model", None),
1261
+ num_workers=self.config.get("num_env_runners", 0),
1262
+ worker_index=self.config.get("worker_index", 0),
1263
+ framework=getattr(self, "framework", self.config.get("framework", "tf")),
1264
+ )
1265
+ return exploration
1266
+
1267
    def _get_default_view_requirements(self):
        """Returns a default ViewRequirements dict.

        Note: This is the base/maximum requirement dict, from which later
        some requirements will be subtracted again automatically to streamline
        data collection, batch creation, and data transfer.

        Returns:
            ViewReqDict: The default view requirements dict.
        """

        # Default view requirements (equal to those that we would use before
        # the trajectory view API was introduced).
        return {
            SampleBatch.OBS: ViewRequirement(space=self.observation_space),
            # NEXT_OBS is a shifted (t+1) view on the OBS column; only needed
            # for training, not for action computation.
            SampleBatch.NEXT_OBS: ViewRequirement(
                data_col=SampleBatch.OBS,
                shift=1,
                space=self.observation_space,
                used_for_compute_actions=False,
            ),
            SampleBatch.ACTIONS: ViewRequirement(
                space=self.action_space, used_for_compute_actions=False
            ),
            # For backward compatibility with custom Models that don't specify
            # these explicitly (will be removed by Policy if not used).
            SampleBatch.PREV_ACTIONS: ViewRequirement(
                data_col=SampleBatch.ACTIONS, shift=-1, space=self.action_space
            ),
            SampleBatch.REWARDS: ViewRequirement(),
            # For backward compatibility with custom Models that don't specify
            # these explicitly (will be removed by Policy if not used).
            SampleBatch.PREV_REWARDS: ViewRequirement(
                data_col=SampleBatch.REWARDS, shift=-1
            ),
            SampleBatch.TERMINATEDS: ViewRequirement(),
            SampleBatch.TRUNCATEDS: ViewRequirement(),
            # Infos are only collected for postprocessing/training, never fed
            # back into `compute_actions`.
            SampleBatch.INFOS: ViewRequirement(used_for_compute_actions=False),
            SampleBatch.EPS_ID: ViewRequirement(),
            SampleBatch.UNROLL_ID: ViewRequirement(),
            SampleBatch.AGENT_INDEX: ViewRequirement(),
            SampleBatch.T: ViewRequirement(),
        }
1310
+
1311
    def _initialize_loss_from_dummy_batch(
        self,
        auto_remove_unneeded_view_reqs: bool = True,
        stats_fn=None,
    ) -> None:
        """Performs test calls through policy's model and loss.

        NOTE: This base method should work for define-by-run Policies such as
        torch and tf-eager policies.

        If required, will thereby detect automatically, which data views are
        required by a) the forward pass, b) the postprocessing, and c) the loss
        functions, and remove those from self.view_requirements that are not
        necessary for these computations (to save data storage and transfer).

        Args:
            auto_remove_unneeded_view_reqs: Whether to automatically
                remove those ViewRequirements records from
                self.view_requirements that are not needed.
            stats_fn (Optional[Callable[[Policy, SampleBatch], Dict[str,
                TensorType]]]): An optional stats function to be called after
                the loss.
        """

        # Allow users/configs to skip this whole procedure.
        if self.config.get("_disable_initialize_loss_from_dummy_batch", False):
            return
        # Signal Policy that currently we do not like to eager/jit trace
        # any function calls. This is to be able to track, which columns
        # in the dummy batch are accessed by the different function (e.g.
        # loss) such that we can then adjust our view requirements.
        self._no_tracing = True
        # Save for later so that loss init does not change global timestep
        global_ts_before_init = int(convert_to_numpy(self.global_timestep))

        # Test batch size: a small multiple of the divisibility requirement,
        # capped by the configured train batch size.
        sample_batch_size = min(
            max(self.batch_divisibility_req * 4, 32),
            self.config["train_batch_size"],  # Don't go over the asked batch size.
        )
        self._dummy_batch = self._get_dummy_batch_from_view_requirements(
            sample_batch_size
        )
        self._lazy_tensor_dict(self._dummy_batch)
        explore = False
        # Test-run a forward pass. `actions` itself is unused below; only the
        # side effects (accessed keys, `state_outs`, `extra_outs`) matter here.
        actions, state_outs, extra_outs = self.compute_actions_from_input_dict(
            self._dummy_batch, explore=explore
        )
        # Any view requirement that the forward pass did not touch is not
        # needed for action computation.
        for key, view_req in self.view_requirements.items():
            if key not in self._dummy_batch.accessed_keys:
                view_req.used_for_compute_actions = False
        # Add all extra action outputs to view reqirements (these may be
        # filtered out later again, if not needed for postprocessing or loss).
        for key, value in extra_outs.items():
            self._dummy_batch[key] = value
            if key not in self.view_requirements:
                if isinstance(value, (dict, np.ndarray)):
                    # the assumption is that value is a nested_dict of np.arrays leaves
                    space = get_gym_space_from_struct_of_tensors(value)
                    self.view_requirements[key] = ViewRequirement(
                        space=space, used_for_compute_actions=False
                    )
                else:
                    raise ValueError(
                        "policy.compute_actions_from_input_dict() returns an "
                        "extra action output that is neither a numpy array nor a dict."
                    )

        # Register keys the forward pass accessed that have no view
        # requirement yet (not needed for future action computations).
        for key in self._dummy_batch.accessed_keys:
            if key not in self.view_requirements:
                self.view_requirements[key] = ViewRequirement()
                self.view_requirements[key].used_for_compute_actions = False
        # TODO (kourosh) Why did we use to make used_for_compute_actions True here?
        new_batch = self._get_dummy_batch_from_view_requirements(sample_batch_size)
        # Make sure the dummy_batch will return numpy arrays when accessed
        self._dummy_batch.set_get_interceptor(None)

        # try to re-use the output of the previous run to avoid overriding things that
        # would break (e.g. scale = 0 of Normal distribution cannot be zero)
        for k in new_batch:
            if k not in self._dummy_batch:
                self._dummy_batch[k] = new_batch[k]

        # Make sure the book-keeping of dummy_batch keys are reset to correcly track
        # what is accessed, what is added and what's deleted from now on.
        self._dummy_batch.accessed_keys.clear()
        self._dummy_batch.deleted_keys.clear()
        self._dummy_batch.added_keys.clear()

        if self.exploration:
            # Policies with RLModules don't have an exploration object.
            self.exploration.postprocess_trajectory(self, self._dummy_batch)

        postprocessed_batch = self.postprocess_trajectory(self._dummy_batch)
        seq_lens = None
        if state_outs:
            B = 4  # For RNNs, have B=4, T=[depends on sample_batch_size]
            i = 0
            # Truncate every RNN state column to the smaller batch dim B.
            while "state_in_{}".format(i) in postprocessed_batch:
                postprocessed_batch["state_in_{}".format(i)] = postprocessed_batch[
                    "state_in_{}".format(i)
                ][:B]
                if "state_out_{}".format(i) in postprocessed_batch:
                    postprocessed_batch["state_out_{}".format(i)] = postprocessed_batch[
                        "state_out_{}".format(i)
                    ][:B]
                i += 1

            seq_len = sample_batch_size // B
            seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32)
            postprocessed_batch[SampleBatch.SEQ_LENS] = seq_lens

        # Switch on lazy to-tensor conversion on `postprocessed_batch`.
        train_batch = self._lazy_tensor_dict(postprocessed_batch)
        # Calling loss, so set `is_training` to True.
        train_batch.set_training(True)
        if seq_lens is not None:
            train_batch[SampleBatch.SEQ_LENS] = seq_lens
        train_batch.count = self._dummy_batch.count

        # Call the loss function, if it exists.
        # TODO(jungong) : clean up after all agents get migrated.
        # We should simply do self.loss(...) here.
        if self._loss is not None:
            self._loss(self, self.model, self.dist_class, train_batch)
        elif is_overridden(self.loss) and not self.config["in_evaluation"]:
            self.loss(self.model, self.dist_class, train_batch)
        # Call the stats fn, if given.
        # TODO(jungong) : clean up after all agents get migrated.
        # We should simply do self.stats_fn(train_batch) here.
        if stats_fn is not None:
            stats_fn(self, train_batch)
        if hasattr(self, "stats_fn") and not self.config["in_evaluation"]:
            self.stats_fn(train_batch)

        # Re-enable tracing.
        self._no_tracing = False

        # Add new columns automatically to view-reqs.
        if auto_remove_unneeded_view_reqs:
            # Add those needed for postprocessing and training.
            all_accessed_keys = (
                train_batch.accessed_keys
                | self._dummy_batch.accessed_keys
                | self._dummy_batch.added_keys
            )
            for key in all_accessed_keys:
                if key not in self.view_requirements and key != SampleBatch.SEQ_LENS:
                    self.view_requirements[key] = ViewRequirement(
                        used_for_compute_actions=False
                    )
            if self._loss or is_overridden(self.loss):
                # Tag those only needed for post-processing (with some
                # exceptions).
                for key in self._dummy_batch.accessed_keys:
                    if (
                        key not in train_batch.accessed_keys
                        and key in self.view_requirements
                        and key not in self.model.view_requirements
                        and key
                        not in [
                            SampleBatch.EPS_ID,
                            SampleBatch.AGENT_INDEX,
                            SampleBatch.UNROLL_ID,
                            SampleBatch.TERMINATEDS,
                            SampleBatch.TRUNCATEDS,
                            SampleBatch.REWARDS,
                            SampleBatch.INFOS,
                            SampleBatch.T,
                        ]
                    ):
                        self.view_requirements[key].used_for_training = False
                # Remove those not needed at all (leave those that are needed
                # by Sampler to properly execute sample collection). Also always leave
                # TERMINATEDS, TRUNCATEDS, REWARDS, INFOS, no matter what.
                for key in list(self.view_requirements.keys()):
                    if (
                        key not in all_accessed_keys
                        and key
                        not in [
                            SampleBatch.EPS_ID,
                            SampleBatch.AGENT_INDEX,
                            SampleBatch.UNROLL_ID,
                            SampleBatch.TERMINATEDS,
                            SampleBatch.TRUNCATEDS,
                            SampleBatch.REWARDS,
                            SampleBatch.INFOS,
                            SampleBatch.T,
                        ]
                        and key not in self.model.view_requirements
                    ):
                        # If user deleted this key manually in postprocessing
                        # fn, warn about it and do not remove from
                        # view-requirements.
                        if key in self._dummy_batch.deleted_keys:
                            logger.warning(
                                "SampleBatch key '{}' was deleted manually in "
                                "postprocessing function! RLlib will "
                                "automatically remove non-used items from the "
                                "data stream. Remove the `del` from your "
                                "postprocessing function.".format(key)
                            )
                        # If we are not writing output to disk, save to erase
                        # this key to save space in the sample batch.
                        elif self.config["output"] is None:
                            del self.view_requirements[key]

        # Restore the global timestep saved above (the test calls must not
        # advance it).
        if type(self.global_timestep) is int:
            self.global_timestep = global_ts_before_init
        elif isinstance(self.global_timestep, tf.Variable):
            self.global_timestep.assign(global_ts_before_init)
        else:
            raise ValueError(
                "Variable self.global_timestep of policy {} needs to be "
                "either of type `int` or `tf.Variable`, "
                "but is of type {}.".format(self, type(self.global_timestep))
            )
1526
+
1527
    def maybe_remove_time_dimension(self, input_dict: Dict[str, TensorType]):
        """Removes a time dimension for recurrent RLModules.

        Abstract hook in this base class: always raises. Presumably
        framework-specific subclasses override it (not visible in this file) —
        TODO confirm against the torch/tf policy implementations.

        Args:
            input_dict: The input dict.

        Returns:
            The input dict with a possibly removed time dimension.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
1537
+
1538
+ def _get_dummy_batch_from_view_requirements(
1539
+ self, batch_size: int = 1
1540
+ ) -> SampleBatch:
1541
+ """Creates a numpy dummy batch based on the Policy's view requirements.
1542
+
1543
+ Args:
1544
+ batch_size: The size of the batch to create.
1545
+
1546
+ Returns:
1547
+ Dict[str, TensorType]: The dummy batch containing all zero values.
1548
+ """
1549
+ ret = {}
1550
+ for view_col, view_req in self.view_requirements.items():
1551
+ data_col = view_req.data_col or view_col
1552
+ # Flattened dummy batch.
1553
+ if (isinstance(view_req.space, (gym.spaces.Tuple, gym.spaces.Dict))) and (
1554
+ (
1555
+ data_col == SampleBatch.OBS
1556
+ and not self.config["_disable_preprocessor_api"]
1557
+ )
1558
+ or (
1559
+ data_col == SampleBatch.ACTIONS
1560
+ and not self.config.get("_disable_action_flattening")
1561
+ )
1562
+ ):
1563
+ _, shape = ModelCatalog.get_action_shape(
1564
+ view_req.space, framework=self.config["framework"]
1565
+ )
1566
+ ret[view_col] = np.zeros((batch_size,) + shape[1:], np.float32)
1567
+ # Non-flattened dummy batch.
1568
+ else:
1569
+ # Range of indices on time-axis, e.g. "-50:-1".
1570
+ if isinstance(view_req.space, gym.spaces.Space):
1571
+ time_size = (
1572
+ len(view_req.shift_arr) if len(view_req.shift_arr) > 1 else None
1573
+ )
1574
+ ret[view_col] = get_dummy_batch_for_space(
1575
+ view_req.space, batch_size=batch_size, time_size=time_size
1576
+ )
1577
+ else:
1578
+ ret[view_col] = [view_req.space for _ in range(batch_size)]
1579
+
1580
+ # Due to different view requirements for the different columns,
1581
+ # columns in the resulting batch may not all have the same batch size.
1582
+ return SampleBatch(ret)
1583
+
1584
    def _update_model_view_requirements_from_init_state(self):
        """Uses Model's (or this Policy's) init state to add needed ViewReqs.

        Can be called from within a Policy to make sure RNNs automatically
        update their internal state-related view requirements.
        Changes the `self.view_requirements` dict.
        """
        self._model_init_state_automatically_added = True
        model = getattr(self, "model", None)

        # Work on the model's view requirements if a model exists, else on
        # the Policy's own.
        obj = model or self
        if model and not hasattr(model, "view_requirements"):
            model.view_requirements = {
                SampleBatch.OBS: ViewRequirement(space=self.observation_space)
            }
        view_reqs = obj.view_requirements
        # Add state-ins to this model's view.
        init_state = []
        if hasattr(obj, "get_initial_state") and callable(obj.get_initial_state):
            init_state = obj.get_initial_state()
        else:
            # Add this functionality automatically for new native model API.
            if (
                tf
                and isinstance(model, tf.keras.Model)
                and "state_in_0" not in view_reqs
            ):
                # Derive initial state (zeros) from the model's own
                # state_in_* view requirements.
                obj.get_initial_state = lambda: [
                    np.zeros_like(view_req.space.sample())
                    for k, view_req in model.view_requirements.items()
                    if k.startswith("state_in_")
                ]
            else:
                obj.get_initial_state = lambda: []
                # User-provided state_in_0 view requirement implies a
                # recurrent model.
                if "state_in_0" in view_reqs:
                    self.is_recurrent = lambda: True

        # Make sure auto-generated init-state view requirements get added
        # to both Policy and Model, no matter what.
        view_reqs = [view_reqs] + (
            [self.view_requirements] if hasattr(self, "view_requirements") else []
        )

        for i, state in enumerate(init_state):
            # Allow `state` to be either a Space (use zeros as initial values)
            # or any value (e.g. a dict or a non-zero tensor).
            fw = (
                np
                if isinstance(state, np.ndarray)
                else torch
                if torch and torch.is_tensor(state)
                else None
            )
            if fw:
                # All-zeros tensor -> wrap in a generic Box space; otherwise
                # keep the concrete value as the "space".
                space = (
                    Box(-1.0, 1.0, shape=state.shape) if fw.all(state == 0.0) else state
                )
            else:
                space = state
            for vr in view_reqs:
                # Only override if user has not already provided
                # custom view-requirements for state_in_n.
                if "state_in_{}".format(i) not in vr:
                    vr["state_in_{}".format(i)] = ViewRequirement(
                        "state_out_{}".format(i),
                        shift=-1,
                        used_for_compute_actions=True,
                        batch_repeat_value=self.config.get("model", {}).get(
                            "max_seq_len", 1
                        ),
                        space=space,
                    )
                # Only override if user has not already provided
                # custom view-requirements for state_out_n.
                if "state_out_{}".format(i) not in vr:
                    vr["state_out_{}".format(i)] = ViewRequirement(
                        space=space, used_for_training=True
                    )
1662
+
1663
+ def __repr__(self):
1664
+ return type(self).__name__
1665
+
1666
+
1667
@OldAPIStack
def get_gym_space_from_struct_of_tensors(
    value: Union[Dict, Tuple, List, TensorType],
    batched_input=True,
) -> gym.Space:
    """Derives a gym Space mirroring the structure of a tensor struct.

    Each tensor leaf is mapped to a `Box(-1, 1)` of the leaf's shape and
    dtype; the nested dict/tuple structure is then converted into the
    corresponding `gym.spaces.Dict`/`gym.spaces.Tuple`.

    Args:
        value: A (possibly nested) struct with tensor/array leaves.
        batched_input: If True, the leading axis of each leaf is treated as
            the batch axis and excluded from the leaf space's shape.
    """
    # Index from which to take each leaf's shape (skip batch axis if batched).
    shape_start = 1 if batched_input else 0

    def _leaf_to_box(tensor):
        return gym.spaces.Box(
            -1.0, 1.0, shape=tensor.shape[shape_start:], dtype=get_np_dtype(tensor)
        )

    struct_of_spaces = tree.map_structure(_leaf_to_box, value)
    return get_gym_space_from_struct_of_spaces(struct_of_spaces)
1681
+
1682
+
1683
@OldAPIStack
def get_gym_space_from_struct_of_spaces(value: Union[Dict, Tuple, List]) -> gym.Space:
    """Recursively converts a struct of gym spaces into one nested gym Space.

    Args:
        value: A nested structure (dicts, tuples, or lists) whose leaves are
            primitive gym spaces.

    Returns:
        A `gym.spaces.Dict` for dict inputs, a `gym.spaces.Tuple` for
        tuple/list inputs, or the leaf space itself.

    Raises:
        AssertionError: If a leaf is not a `gym.spaces.Space`.
    """
    if isinstance(value, dict):
        return gym.spaces.Dict(
            {k: get_gym_space_from_struct_of_spaces(v) for k, v in value.items()}
        )
    elif isinstance(value, (tuple, list)):
        return gym.spaces.Tuple([get_gym_space_from_struct_of_spaces(v) for v in value])
    else:
        # Leaf: must already be a primitive gym space.
        # Fixed typo in the message ("tiples" -> "tuples").
        assert isinstance(value, gym.spaces.Space), (
            f"The struct of spaces should only contain dicts, tuples and primitive "
            f"gym spaces. Space is of type {type(value)}"
        )
        return value
.venv/lib/python3.11/site-packages/ray/rllib/policy/policy_map.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ import threading
3
+ from typing import Dict, Set
4
+ import logging
5
+
6
+ import ray
7
+ from ray.rllib.policy.policy import Policy
8
+ from ray.rllib.utils.annotations import OldAPIStack, override
9
+ from ray.rllib.utils.deprecation import deprecation_warning
10
+ from ray.rllib.utils.framework import try_import_tf
11
+ from ray.rllib.utils.threading import with_lock
12
+ from ray.rllib.utils.typing import PolicyID
13
+
14
+ tf1, tf, tfv = try_import_tf()
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
@OldAPIStack
class PolicyMap(dict):
    """Maps policy IDs to Policy objects.

    Thereby, keeps n policies in memory and - when capacity is reached -
    writes the least recently used to disk. This allows adding 100s of
    policies to a Algorithm for league-based setups w/o running out of memory.
    """

    def __init__(
        self,
        *,
        capacity: int = 100,
        policy_states_are_swappable: bool = False,
        # Deprecated args.
        worker_index=None,
        num_workers=None,
        policy_config=None,
        session_creator=None,
        seed=None,
    ):
        """Initializes a PolicyMap instance.

        Args:
            capacity: The size of the Policy object cache. This is the maximum number
                of policies that are held in RAM memory. When reaching this capacity,
                the least recently used Policy's state will be stored in the Ray object
                store and recovered from there when being accessed again.
            policy_states_are_swappable: Whether all Policy objects in this map can be
                "swapped out" via a simple `state = A.get_state(); B.set_state(state)`,
                where `A` and `B` are policy instances in this map. You should set
                this to True for significantly speeding up the PolicyMap's cache lookup
                times, iff your policies all share the same neural network
                architecture and optimizer types. If True, the PolicyMap will not
                have to garbage collect old, least recently used policies, but instead
                keep them in memory and simply override their state with the state of
                the most recently accessed one.
                For example, in a league-based training setup, you might have 100s of
                the same policies in your map (playing against each other in various
                combinations), but all of them share the same state structure
                (are "swappable").
        """
        # `policy_config` is hard-deprecated: passing it raises immediately.
        if policy_config is not None:
            deprecation_warning(
                old="PolicyMap(policy_config=..)",
                error=True,
            )

        super().__init__()

        self.capacity = capacity

        # Remaining deprecated args are soft-deprecated: warn, then ignore.
        if any(
            i is not None
            for i in [policy_config, worker_index, num_workers, session_creator, seed]
        ):
            deprecation_warning(
                old="PolicyMap([deprecated args]...)",
                new="PolicyMap(capacity=..., policy_states_are_swappable=...)",
                error=False,
            )

        self.policy_states_are_swappable = policy_states_are_swappable

        # The actual cache with the in-memory policy objects.
        self.cache: Dict[str, Policy] = {}

        # Set of keys that may be looked up (cached or not).
        self._valid_keys: Set[str] = set()
        # The doubly-linked list holding the currently in-memory objects.
        # Leftmost = least recently used, rightmost = most recently used.
        self._deque = deque()

        # Ray object store references to the stashed Policy states.
        self._policy_state_refs = {}

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when accessing the map
        # and the underlying structures, like self._deque and others.
        self._lock = threading.RLock()

    @with_lock
    @override(dict)
    def __getitem__(self, item: PolicyID) -> Policy:
        """Returns the Policy for `item`, restoring it from the store if stashed."""
        # Never seen this key -> Error.
        if item not in self._valid_keys:
            raise KeyError(
                f"PolicyID '{item}' not found in this PolicyMap! "
                f"IDs stored in this map: {self._valid_keys}."
            )

        # Item already in cache -> Rearrange deque (promote `item` to
        # "most recently used") and return it.
        if item in self.cache:
            self._deque.remove(item)
            self._deque.append(item)
            return self.cache[item]

        # Item not currently in cache -> Get from stash and - if at capacity -
        # remove leftmost one.
        if item not in self._policy_state_refs:
            raise AssertionError(
                f"PolicyID {item} not found in internal Ray object store cache!"
            )
        policy_state = ray.get(self._policy_state_refs[item])

        policy = None
        # We are at capacity: Remove the oldest policy from deque as well as the
        # cache and return it.
        if len(self._deque) == self.capacity:
            policy = self._stash_least_used_policy()

        # All our policies have same NN-architecture (are "swappable").
        # -> Load new policy's state into the one that just got removed from the cache.
        # This way, we save the costly re-creation step.
        if policy is not None and self.policy_states_are_swappable:
            logger.debug(f"restoring policy: {item}")
            policy.set_state(policy_state)
        else:
            logger.debug(f"creating new policy: {item}")
            policy = Policy.from_state(policy_state)

        self.cache[item] = policy
        # Promote the item to most recently one.
        self._deque.append(item)

        return policy

    @with_lock
    @override(dict)
    def __setitem__(self, key: PolicyID, value: Policy):
        """Stores `value` under `key`, evicting the LRU policy if at capacity."""
        # Item already in cache -> Rearrange deque.
        if key in self.cache:
            self._deque.remove(key)

        # Item not currently in cache -> store new value and - if at capacity -
        # remove leftmost one.
        else:
            # Cache at capacity -> Drop leftmost item.
            if len(self._deque) == self.capacity:
                self._stash_least_used_policy()

        # Promote `key` to "most recently used".
        self._deque.append(key)

        # Update our cache.
        self.cache[key] = value
        self._valid_keys.add(key)

    @with_lock
    @override(dict)
    def __delitem__(self, key: PolicyID):
        """Removes `key` from the map, the cache, and the object-store stash."""
        # Make key invalid.
        self._valid_keys.remove(key)
        # Remove policy from deque if contained
        if key in self._deque:
            self._deque.remove(key)
        # Remove policy from memory if currently cached.
        if key in self.cache:
            policy = self.cache[key]
            self._close_session(policy)
            del self.cache[key]
        # Remove Ray object store reference (if this ID has already been stored
        # there), so the item gets garbage collected.
        if key in self._policy_state_refs:
            del self._policy_state_refs[key]

    @override(dict)
    def __iter__(self):
        # Iterate over all valid keys (cached AND stashed).
        return iter(self.keys())

    @override(dict)
    def items(self):
        """Iterates over all policies, even the stashed ones."""

        # NOTE: Yielding `self[key]` may trigger restore-from-stash per item.
        def gen():
            for key in self._valid_keys:
                yield (key, self[key])

        return gen()

    @override(dict)
    def keys(self):
        """Returns all valid keys, even the stashed ones."""
        # Snapshot keys under the lock; the generator itself runs unlocked.
        self._lock.acquire()
        ks = list(self._valid_keys)
        self._lock.release()

        def gen():
            for key in ks:
                yield key

        return gen()

    @override(dict)
    def values(self):
        """Returns all valid values, even the stashed ones."""
        # Materialize all values under the lock (may restore stashed policies).
        self._lock.acquire()
        vs = [self[k] for k in self._valid_keys]
        self._lock.release()

        def gen():
            for value in vs:
                yield value

        return gen()

    @with_lock
    @override(dict)
    def update(self, __m, **kwargs):
        """Updates the map with the given dict and/or kwargs."""
        # NOTE(review): unlike `dict.update`, the positional mapping `__m`
        # is required here and must support `.items()`.
        for k, v in __m.items():
            self[k] = v
        for k, v in kwargs.items():
            self[k] = v

    @with_lock
    @override(dict)
    def get(self, key: PolicyID):
        """Returns the value for the given key or None if not found."""
        # NOTE(review): no `default` argument, unlike `dict.get`.
        if key not in self._valid_keys:
            return None
        return self[key]

    @with_lock
    @override(dict)
    def __len__(self) -> int:
        """Returns number of all policies, including the stashed-to-disk ones."""
        return len(self._valid_keys)

    @with_lock
    @override(dict)
    def __contains__(self, item: PolicyID):
        # Membership covers stashed policies as well.
        return item in self._valid_keys

    @override(dict)
    def __str__(self) -> str:
        # Only print out our keys (policy IDs), not values as this could trigger
        # the LRU caching.
        return (
            f"<PolicyMap lru-caching-capacity={self.capacity} policy-IDs="
            f"{list(self.keys())}>"
        )

    def _stash_least_used_policy(self) -> Policy:
        """Writes the least-recently used policy's state to the Ray object store.

        Also closes the session - if applicable - of the stashed policy.

        Returns:
            The least-recently used policy, that just got removed from the cache.
        """
        # Get policy's state for writing to object store.
        dropped_policy_id = self._deque.popleft()
        assert dropped_policy_id in self.cache
        policy = self.cache[dropped_policy_id]
        policy_state = policy.get_state()

        # If we don't simply swap out vs an existing policy:
        # Close the tf session, if any.
        if not self.policy_states_are_swappable:
            self._close_session(policy)

        # Remove from memory. This will clear the tf Graph as well.
        del self.cache[dropped_policy_id]

        # Store state in Ray object store.
        self._policy_state_refs[dropped_policy_id] = ray.put(policy_state)

        # Return the just removed policy, in case it's needed by the caller.
        return policy

    @staticmethod
    def _close_session(policy: Policy):
        """Closes the policy's tf session, if it has one."""
        sess = policy.get_session()
        # Closes the tf session, if any.
        if sess is not None:
            sess.close()
+ sess.close()
.venv/lib/python3.11/site-packages/ray/rllib/policy/policy_template.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import (
2
+ Any,
3
+ Callable,
4
+ Dict,
5
+ List,
6
+ Optional,
7
+ Tuple,
8
+ Type,
9
+ Union,
10
+ )
11
+
12
+ import gymnasium as gym
13
+
14
+ from ray.rllib.models.catalog import ModelCatalog
15
+ from ray.rllib.models.modelv2 import ModelV2
16
+ from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
17
+ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
18
+ from ray.rllib.policy.policy import Policy
19
+ from ray.rllib.policy.sample_batch import SampleBatch
20
+ from ray.rllib.policy.torch_policy import TorchPolicy
21
+ from ray.rllib.utils import add_mixins, NullContextManager
22
+ from ray.rllib.utils.annotations import OldAPIStack, override
23
+ from ray.rllib.utils.framework import try_import_torch, try_import_jax
24
+ from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
25
+ from ray.rllib.utils.numpy import convert_to_numpy
26
+ from ray.rllib.utils.typing import ModelGradients, TensorType, AlgorithmConfigDict
27
+
28
+ jax, _ = try_import_jax()
29
+ torch, _ = try_import_torch()
30
+
31
+
32
+ @OldAPIStack
33
+ def build_policy_class(
34
+ name: str,
35
+ framework: str,
36
+ *,
37
+ loss_fn: Optional[
38
+ Callable[
39
+ [Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch],
40
+ Union[TensorType, List[TensorType]],
41
+ ]
42
+ ],
43
+ get_default_config: Optional[Callable[[], AlgorithmConfigDict]] = None,
44
+ stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[str, TensorType]]] = None,
45
+ postprocess_fn: Optional[
46
+ Callable[
47
+ [
48
+ Policy,
49
+ SampleBatch,
50
+ Optional[Dict[Any, SampleBatch]],
51
+ Optional[Any],
52
+ ],
53
+ SampleBatch,
54
+ ]
55
+ ] = None,
56
+ extra_action_out_fn: Optional[
57
+ Callable[
58
+ [
59
+ Policy,
60
+ Dict[str, TensorType],
61
+ List[TensorType],
62
+ ModelV2,
63
+ TorchDistributionWrapper,
64
+ ],
65
+ Dict[str, TensorType],
66
+ ]
67
+ ] = None,
68
+ extra_grad_process_fn: Optional[
69
+ Callable[[Policy, "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]
70
+ ] = None,
71
+ # TODO: (sven) Replace "fetches" with "process".
72
+ extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[str, TensorType]]] = None,
73
+ optimizer_fn: Optional[
74
+ Callable[[Policy, AlgorithmConfigDict], "torch.optim.Optimizer"]
75
+ ] = None,
76
+ validate_spaces: Optional[
77
+ Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
78
+ ] = None,
79
+ before_init: Optional[
80
+ Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
81
+ ] = None,
82
+ before_loss_init: Optional[
83
+ Callable[
84
+ [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None
85
+ ]
86
+ ] = None,
87
+ after_init: Optional[
88
+ Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
89
+ ] = None,
90
+ _after_loss_init: Optional[
91
+ Callable[
92
+ [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None
93
+ ]
94
+ ] = None,
95
+ action_sampler_fn: Optional[
96
+ Callable[[TensorType, List[TensorType]], Tuple[TensorType, TensorType]]
97
+ ] = None,
98
+ action_distribution_fn: Optional[
99
+ Callable[
100
+ [Policy, ModelV2, TensorType, TensorType, TensorType],
101
+ Tuple[TensorType, type, List[TensorType]],
102
+ ]
103
+ ] = None,
104
+ make_model: Optional[
105
+ Callable[
106
+ [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], ModelV2
107
+ ]
108
+ ] = None,
109
+ make_model_and_action_dist: Optional[
110
+ Callable[
111
+ [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict],
112
+ Tuple[ModelV2, Type[TorchDistributionWrapper]],
113
+ ]
114
+ ] = None,
115
+ compute_gradients_fn: Optional[
116
+ Callable[[Policy, SampleBatch], Tuple[ModelGradients, dict]]
117
+ ] = None,
118
+ apply_gradients_fn: Optional[
119
+ Callable[[Policy, "torch.optim.Optimizer"], None]
120
+ ] = None,
121
+ mixins: Optional[List[type]] = None,
122
+ get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
123
+ ) -> Type[TorchPolicy]:
124
+ """Helper function for creating a new Policy class at runtime.
125
+
126
+ Supports frameworks JAX and PyTorch.
127
+
128
+ Args:
129
+ name: name of the policy (e.g., "PPOTorchPolicy")
130
+ framework: Either "jax" or "torch".
131
+ loss_fn (Optional[Callable[[Policy, ModelV2,
132
+ Type[TorchDistributionWrapper], SampleBatch], Union[TensorType,
133
+ List[TensorType]]]]): Callable that returns a loss tensor.
134
+ get_default_config (Optional[Callable[[None], AlgorithmConfigDict]]):
135
+ Optional callable that returns the default config to merge with any
136
+ overrides. If None, uses only(!) the user-provided
137
+ PartialAlgorithmConfigDict as dict for this Policy.
138
+ postprocess_fn (Optional[Callable[[Policy, SampleBatch,
139
+ Optional[Dict[Any, SampleBatch]], Optional[Any]],
140
+ SampleBatch]]): Optional callable for post-processing experience
141
+ batches (called after the super's `postprocess_trajectory` method).
142
+ stats_fn (Optional[Callable[[Policy, SampleBatch],
143
+ Dict[str, TensorType]]]): Optional callable that returns a dict of
144
+ values given the policy and training batch. If None,
145
+ will use `TorchPolicy.extra_grad_info()` instead. The stats dict is
146
+ used for logging (e.g. in TensorBoard).
147
+ extra_action_out_fn (Optional[Callable[[Policy, Dict[str, TensorType],
148
+ List[TensorType], ModelV2, TorchDistributionWrapper]], Dict[str,
149
+ TensorType]]]): Optional callable that returns a dict of extra
150
+ values to include in experiences. If None, no extra computations
151
+ will be performed.
152
+ extra_grad_process_fn (Optional[Callable[[Policy,
153
+ "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]]):
154
+ Optional callable that is called after gradients are computed and
155
+ returns a processing info dict. If None, will call the
156
+ `TorchPolicy.extra_grad_process()` method instead.
157
+ # TODO: (sven) dissolve naming mismatch between "learn" and "compute.."
158
+ extra_learn_fetches_fn (Optional[Callable[[Policy],
159
+ Dict[str, TensorType]]]): Optional callable that returns a dict of
160
+ extra tensors from the policy after loss evaluation. If None,
161
+ will call the `TorchPolicy.extra_compute_grad_fetches()` method
162
+ instead.
163
+ optimizer_fn (Optional[Callable[[Policy, AlgorithmConfigDict],
164
+ "torch.optim.Optimizer"]]): Optional callable that returns a
165
+ torch optimizer given the policy and config. If None, will call
166
+ the `TorchPolicy.optimizer()` method instead (which returns a
167
+ torch Adam optimizer).
168
+ validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
169
+ AlgorithmConfigDict], None]]): Optional callable that takes the
170
+ Policy, observation_space, action_space, and config to check for
171
+ correctness. If None, no spaces checking will be done.
172
+ before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
173
+ AlgorithmConfigDict], None]]): Optional callable to run at the
174
+ beginning of `Policy.__init__` that takes the same arguments as
175
+ the Policy constructor. If None, this step will be skipped.
176
+ before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
177
+ gym.spaces.Space, AlgorithmConfigDict], None]]): Optional callable to
178
+ run prior to loss init. If None, this step will be skipped.
179
+ after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
180
+ AlgorithmConfigDict], None]]): DEPRECATED: Use `before_loss_init`
181
+ instead.
182
+ _after_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
183
+ gym.spaces.Space, AlgorithmConfigDict], None]]): Optional callable to
184
+ run after the loss init. If None, this step will be skipped.
185
+ This will be deprecated at some point and renamed into `after_init`
186
+ to match `build_tf_policy()` behavior.
187
+ action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
188
+ Tuple[TensorType, TensorType]]]): Optional callable returning a
189
+ sampled action and its log-likelihood given some (obs and state)
190
+ inputs. If None, will either use `action_distribution_fn` or
191
+ compute actions by calling self.model, then sampling from the
192
+ so parameterized action distribution.
193
+ action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
194
+ TensorType, TensorType], Tuple[TensorType,
195
+ Type[TorchDistributionWrapper], List[TensorType]]]]): A callable
196
+ that takes the Policy, Model, the observation batch, an
197
+ explore-flag, a timestep, and an is_training flag and returns a
198
+ tuple of a) distribution inputs (parameters), b) a dist-class to
199
+ generate an action distribution object from, and c) internal-state
200
+ outputs (empty list if not applicable). If None, will either use
201
+ `action_sampler_fn` or compute actions by calling self.model,
202
+ then sampling from the parameterized action distribution.
203
+ make_model (Optional[Callable[[Policy, gym.spaces.Space,
204
+ gym.spaces.Space, AlgorithmConfigDict], ModelV2]]): Optional callable
205
+ that takes the same arguments as Policy.__init__ and returns a
206
+ model instance. The distribution class will be determined
207
+ automatically. Note: Only one of `make_model` or
208
+ `make_model_and_action_dist` should be provided. If both are None,
209
+ a default Model will be created.
210
+ make_model_and_action_dist (Optional[Callable[[Policy,
211
+ gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict],
212
+ Tuple[ModelV2, Type[TorchDistributionWrapper]]]]): Optional
213
+ callable that takes the same arguments as Policy.__init__ and
214
+ returns a tuple of model instance and torch action distribution
215
+ class.
216
+ Note: Only one of `make_model` or `make_model_and_action_dist`
217
+ should be provided. If both are None, a default Model will be
218
+ created.
219
+ compute_gradients_fn (Optional[Callable[
220
+ [Policy, SampleBatch], Tuple[ModelGradients, dict]]]): Optional
221
+ callable that the sampled batch an computes the gradients w.r.
222
+ to the loss function.
223
+ If None, will call the `TorchPolicy.compute_gradients()` method
224
+ instead.
225
+ apply_gradients_fn (Optional[Callable[[Policy,
226
+ "torch.optim.Optimizer"], None]]): Optional callable that
227
+ takes a grads list and applies these to the Model's parameters.
228
+ If None, will call the `TorchPolicy.apply_gradients()` method
229
+ instead.
230
+ mixins (Optional[List[type]]): Optional list of any class mixins for
231
+ the returned policy class. These mixins will be applied in order
232
+ and will have higher precedence than the TorchPolicy class.
233
+ get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
234
+ Optional callable that returns the divisibility requirement for
235
+ sample batches. If None, will assume a value of 1.
236
+
237
+ Returns:
238
+ Type[TorchPolicy]: TorchPolicy child class constructed from the
239
+ specified args.
240
+ """
241
+
242
+ original_kwargs = locals().copy()
243
+ parent_cls = TorchPolicy
244
+ base = add_mixins(parent_cls, mixins)
245
+
246
+ class policy_cls(base):
247
+ def __init__(self, obs_space, action_space, config):
248
+ self.config = config
249
+
250
+ # Set the DL framework for this Policy.
251
+ self.framework = self.config["framework"] = framework
252
+
253
+ # Validate observation- and action-spaces.
254
+ if validate_spaces:
255
+ validate_spaces(self, obs_space, action_space, self.config)
256
+
257
+ # Do some pre-initialization steps.
258
+ if before_init:
259
+ before_init(self, obs_space, action_space, self.config)
260
+
261
+ # Model is customized (use default action dist class).
262
+ if make_model:
263
+ assert make_model_and_action_dist is None, (
264
+ "Either `make_model` or `make_model_and_action_dist`"
265
+ " must be None!"
266
+ )
267
+ self.model = make_model(self, obs_space, action_space, config)
268
+ dist_class, _ = ModelCatalog.get_action_dist(
269
+ action_space, self.config["model"], framework=framework
270
+ )
271
+ # Model and action dist class are customized.
272
+ elif make_model_and_action_dist:
273
+ self.model, dist_class = make_model_and_action_dist(
274
+ self, obs_space, action_space, config
275
+ )
276
+ # Use default model and default action dist.
277
+ else:
278
+ dist_class, logit_dim = ModelCatalog.get_action_dist(
279
+ action_space, self.config["model"], framework=framework
280
+ )
281
+ self.model = ModelCatalog.get_model_v2(
282
+ obs_space=obs_space,
283
+ action_space=action_space,
284
+ num_outputs=logit_dim,
285
+ model_config=self.config["model"],
286
+ framework=framework,
287
+ )
288
+
289
+ # Make sure, we passed in a correct Model factory.
290
+ model_cls = TorchModelV2
291
+ assert isinstance(
292
+ self.model, model_cls
293
+ ), "ERROR: Generated Model must be a TorchModelV2 object!"
294
+
295
+ # Call the framework-specific Policy constructor.
296
+ self.parent_cls = parent_cls
297
+ self.parent_cls.__init__(
298
+ self,
299
+ observation_space=obs_space,
300
+ action_space=action_space,
301
+ config=config,
302
+ model=self.model,
303
+ loss=None if self.config["in_evaluation"] else loss_fn,
304
+ action_distribution_class=dist_class,
305
+ action_sampler_fn=action_sampler_fn,
306
+ action_distribution_fn=action_distribution_fn,
307
+ max_seq_len=config["model"]["max_seq_len"],
308
+ get_batch_divisibility_req=get_batch_divisibility_req,
309
+ )
310
+
311
+ # Merge Model's view requirements into Policy's.
312
+ self.view_requirements.update(self.model.view_requirements)
313
+
314
+ _before_loss_init = before_loss_init or after_init
315
+ if _before_loss_init:
316
+ _before_loss_init(
317
+ self, self.observation_space, self.action_space, config
318
+ )
319
+
320
+ # Perform test runs through postprocessing- and loss functions.
321
+ self._initialize_loss_from_dummy_batch(
322
+ auto_remove_unneeded_view_reqs=True,
323
+ stats_fn=None if self.config["in_evaluation"] else stats_fn,
324
+ )
325
+
326
+ if _after_loss_init:
327
+ _after_loss_init(self, obs_space, action_space, config)
328
+
329
+ # Got to reset global_timestep again after this fake run-through.
330
+ self.global_timestep = 0
331
+
332
+ @override(Policy)
333
+ def postprocess_trajectory(
334
+ self, sample_batch, other_agent_batches=None, episode=None
335
+ ):
336
+ # Do all post-processing always with no_grad().
337
+ # Not using this here will introduce a memory leak
338
+ # in torch (issue #6962).
339
+ with self._no_grad_context():
340
+ # Call super's postprocess_trajectory first.
341
+ sample_batch = super().postprocess_trajectory(
342
+ sample_batch, other_agent_batches, episode
343
+ )
344
+ if postprocess_fn:
345
+ return postprocess_fn(
346
+ self, sample_batch, other_agent_batches, episode
347
+ )
348
+
349
+ return sample_batch
350
+
351
+ @override(parent_cls)
352
+ def extra_grad_process(self, optimizer, loss):
353
+ """Called after optimizer.zero_grad() and loss.backward() calls.
354
+
355
+ Allows for gradient processing before optimizer.step() is called.
356
+ E.g. for gradient clipping.
357
+ """
358
+ if extra_grad_process_fn:
359
+ return extra_grad_process_fn(self, optimizer, loss)
360
+ else:
361
+ return parent_cls.extra_grad_process(self, optimizer, loss)
362
+
363
+ @override(parent_cls)
364
+ def extra_compute_grad_fetches(self):
365
+ if extra_learn_fetches_fn:
366
+ fetches = convert_to_numpy(extra_learn_fetches_fn(self))
367
+ # Auto-add empty learner stats dict if needed.
368
+ return dict({LEARNER_STATS_KEY: {}}, **fetches)
369
+ else:
370
+ return parent_cls.extra_compute_grad_fetches(self)
371
+
372
+ @override(parent_cls)
373
+ def compute_gradients(self, batch):
374
+ if compute_gradients_fn:
375
+ return compute_gradients_fn(self, batch)
376
+ else:
377
+ return parent_cls.compute_gradients(self, batch)
378
+
379
+ @override(parent_cls)
380
+ def apply_gradients(self, gradients):
381
+ if apply_gradients_fn:
382
+ apply_gradients_fn(self, gradients)
383
+ else:
384
+ parent_cls.apply_gradients(self, gradients)
385
+
386
+ @override(parent_cls)
387
+ def extra_action_out(self, input_dict, state_batches, model, action_dist):
388
+ with self._no_grad_context():
389
+ if extra_action_out_fn:
390
+ stats_dict = extra_action_out_fn(
391
+ self, input_dict, state_batches, model, action_dist
392
+ )
393
+ else:
394
+ stats_dict = parent_cls.extra_action_out(
395
+ self, input_dict, state_batches, model, action_dist
396
+ )
397
+ return self._convert_to_numpy(stats_dict)
398
+
399
+ @override(parent_cls)
400
+ def optimizer(self):
401
+ if optimizer_fn:
402
+ optimizers = optimizer_fn(self, self.config)
403
+ else:
404
+ optimizers = parent_cls.optimizer(self)
405
+ return optimizers
406
+
407
+ @override(parent_cls)
408
+ def extra_grad_info(self, train_batch):
409
+ with self._no_grad_context():
410
+ if stats_fn:
411
+ stats_dict = stats_fn(self, train_batch)
412
+ else:
413
+ stats_dict = self.parent_cls.extra_grad_info(self, train_batch)
414
+ return self._convert_to_numpy(stats_dict)
415
+
416
+ def _no_grad_context(self):
417
+ if self.framework == "torch":
418
+ return torch.no_grad()
419
+ return NullContextManager()
420
+
421
+ def _convert_to_numpy(self, data):
422
+ if self.framework == "torch":
423
+ return convert_to_numpy(data)
424
+ return data
425
+
426
+ def with_updates(**overrides):
427
+ """Creates a Torch|JAXPolicy cls based on settings of another one.
428
+
429
+ Keyword Args:
430
+ **overrides: The settings (passed into `build_torch_policy`) that
431
+ should be different from the class that this method is called
432
+ on.
433
+
434
+ Returns:
435
+ type: A new Torch|JAXPolicy sub-class.
436
+
437
+ Examples:
438
+ >> MySpecialDQNPolicyClass = DQNTorchPolicy.with_updates(
439
+ .. name="MySpecialDQNPolicyClass",
440
+ .. loss_function=[some_new_loss_function],
441
+ .. )
442
+ """
443
+ return build_policy_class(**dict(original_kwargs, **overrides))
444
+
445
+ policy_cls.with_updates = staticmethod(with_updates)
446
+ policy_cls.__name__ = name
447
+ policy_cls.__qualname__ = name
448
+ return policy_cls
.venv/lib/python3.11/site-packages/ray/rllib/policy/rnn_sequencing.py ADDED
@@ -0,0 +1,683 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """RNN utils for RLlib.
2
+
3
+ The main trick here is that we add the time dimension at the last moment.
4
+ The non-LSTM layers of the model see their inputs as one flat batch. Before
5
+ the LSTM cell, we reshape the input to add the expected time dimension. During
6
+ postprocessing, we dynamically pad the experience batches so that this
7
+ reshaping is possible.
8
+
9
+ Note that this padding strategy only works out if we assume zero inputs don't
10
+ meaningfully affect the loss function. This happens to be true for all the
11
+ current algorithms: https://github.com/ray-project/ray/issues/2992
12
+ """
13
+
14
+ import logging
15
+ import numpy as np
16
+ import tree # pip install dm_tree
17
+ from typing import List, Optional
18
+ import functools
19
+
20
+ from ray.rllib.policy.sample_batch import SampleBatch
21
+ from ray.rllib.utils.annotations import OldAPIStack
22
+ from ray.rllib.utils.debug import summarize
23
+ from ray.rllib.utils.framework import try_import_tf, try_import_torch
24
+ from ray.rllib.utils.typing import TensorType, ViewRequirementsDict
25
+ from ray.util import log_once
26
+ from ray.rllib.utils.typing import SampleBatchType
27
+
28
+ tf1, tf, tfv = try_import_tf()
29
+ torch, _ = try_import_torch()
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
@OldAPIStack
def pad_batch_to_sequences_of_same_size(
    batch: SampleBatch,
    max_seq_len: int,
    shuffle: bool = False,
    batch_divisibility_req: int = 1,
    feature_keys: Optional[List[str]] = None,
    view_requirements: Optional[ViewRequirementsDict] = None,
    _enable_new_api_stack: bool = False,
    padding: str = "zero",
):
    """Applies padding to `batch` so it's choppable into same-size sequences.

    Shuffles `batch` (if desired), makes sure divisibility requirement is met,
    then pads the batch ([B, ...]) into same-size chunks ([B, ...]) w/o
    adding a time dimension (yet).
    Padding depends on episodes found in batch and `max_seq_len`.

    Note: This mutates `batch` in place (feature columns get padded,
    state columns get reduced to their per-sequence init values, and
    `batch.zero_padded` / `batch.max_seq_len` get set) and returns None.

    Args:
        batch: The SampleBatch object. All values in here have
            the shape [B, ...].
        max_seq_len: The max. sequence length to use for chopping.
        shuffle: Whether to shuffle batch sequences. Shuffle may
            be done in-place. This only makes sense if you're further
            applying minibatch SGD after getting the outputs.
        batch_divisibility_req: The int by which the batch dimension
            must be dividable.
        feature_keys: An optional list of keys to apply sequence-chopping
            to. If None, use all keys in batch that are not
            "state_in/out_"-type keys.
        view_requirements: An optional Policy ViewRequirements dict to
            be able to infer whether e.g. dynamic max'ing should be
            applied over the seq_lens.
        _enable_new_api_stack: This is a temporary flag to enable the new RLModule API.
            After a complete rollout of the new API, this flag will be removed.
        padding: Padding type to use. Either "zero" or "last". Zero padding
            will pad with zeros, last padding will pad with the last value.
    """
    # If already zero-padded, skip (makes repeated calls on the same batch
    # no-ops).
    if batch.zero_padded:
        return

    batch.zero_padded = True

    if batch_divisibility_req > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0
            # not multiagent
            and max(batch[SampleBatch.AGENT_INDEX]) == 0
        )
    else:
        meets_divisibility_reqs = True

    states_already_reduced_to_init = False

    # RNN/attention net case. Figure out whether we should apply dynamic
    # max'ing over the list of sequence lengths.
    if _enable_new_api_stack and ("state_in" in batch or "state_out" in batch):
        # TODO (Kourosh): This is a temporary fix to enable the new RLModule API.
        # We should think of a more elegant solution once we have confirmed that other
        # parts of the API are stable and user-friendly.
        seq_lens = batch.get(SampleBatch.SEQ_LENS)

        # state_in is a nested dict of tensors of states. We need to retrieve the
        # length of the innermost tensor (which should be already the same as the
        # length of other tensors) and compare it to len(seq_lens).
        state_ins = tree.flatten(batch["state_in"])
        if state_ins:
            assert all(
                len(state_in) == len(state_ins[0]) for state_in in state_ins
            ), "All state_in tensors should have the same batch_dim size."

            # if the batch dim of states is the same as the number of sequences
            if len(state_ins[0]) == len(seq_lens):
                states_already_reduced_to_init = True

            # TODO (Kourosh): What is the use-case of DynamicMax functionality?
            dynamic_max = True
        else:
            dynamic_max = False

    elif not _enable_new_api_stack and (
        "state_in_0" in batch or "state_out_0" in batch
    ):
        # Check, whether the state inputs have already been reduced to their
        # init values at the beginning of each max_seq_len chunk.
        if batch.get(SampleBatch.SEQ_LENS) is not None and len(
            batch["state_in_0"]
        ) == len(batch[SampleBatch.SEQ_LENS]):
            states_already_reduced_to_init = True

        # RNN (or single timestep state-in): Set the max dynamically.
        if view_requirements and view_requirements["state_in_0"].shift_from is None:
            dynamic_max = True
        # Attention Nets (state inputs are over some range): No dynamic maxing
        # possible.
        else:
            dynamic_max = False
    # Multi-agent case.
    elif not meets_divisibility_reqs:
        max_seq_len = batch_divisibility_req
        dynamic_max = False
        batch.max_seq_len = max_seq_len
    # Simple case: No RNN/attention net, nor do we need to pad.
    else:
        if shuffle:
            batch.shuffle()
        return

    # RNN, attention net, or multi-agent case.
    # Partition the batch's keys into state columns vs feature columns.
    state_keys = []
    feature_keys_ = feature_keys or []
    for k, v in batch.items():
        if k.startswith("state_in"):
            state_keys.append(k)
        elif (
            not feature_keys
            and (not k.startswith("state_out") if not _enable_new_api_stack else True)
            and k not in [SampleBatch.SEQ_LENS]
        ):
            feature_keys_.append(k)

    feature_sequences, initial_states, seq_lens = chop_into_sequences(
        feature_columns=[batch[k] for k in feature_keys_],
        state_columns=[batch[k] for k in state_keys],
        episode_ids=batch.get(SampleBatch.EPS_ID),
        unroll_ids=batch.get(SampleBatch.UNROLL_ID),
        agent_indices=batch.get(SampleBatch.AGENT_INDEX),
        seq_lens=batch.get(SampleBatch.SEQ_LENS),
        max_seq_len=max_seq_len,
        dynamic_max=dynamic_max,
        states_already_reduced_to_init=states_already_reduced_to_init,
        shuffle=shuffle,
        handle_nested_data=True,
        padding=padding,
        pad_infos_with_empty_dicts=_enable_new_api_stack,
    )

    # Write the padded columns back into the batch, restoring any nested
    # structure the original values had.
    for i, k in enumerate(feature_keys_):
        batch[k] = tree.unflatten_as(batch[k], feature_sequences[i])
    for i, k in enumerate(state_keys):
        batch[k] = initial_states[i]
    batch[SampleBatch.SEQ_LENS] = np.array(seq_lens)
    if dynamic_max:
        batch.max_seq_len = max(seq_lens)

    # Log the padded layout once (first call only).
    if log_once("rnn_ma_feed_dict"):
        logger.info(
            "Padded input for RNN/Attn.Nets/MA:\n\n{}\n".format(
                summarize(
                    {
                        "features": feature_sequences,
                        "initial_states": initial_states,
                        "seq_lens": seq_lens,
                        "max_seq_len": max_seq_len,
                    }
                )
            )
        )
191
+
192
+
193
+ @OldAPIStack
194
def add_time_dimension(
    padded_inputs: "TensorType",
    *,
    seq_lens: "TensorType",
    framework: str = "tf",
    time_major: bool = False,
):
    """Adds a time dimension to padded inputs.

    Reshapes a flat, padded batch of shape [B*T, ...] into [B, T, ...]
    (or [T, B, ...] if `time_major=True`), where B = len(seq_lens).
    The input must already be padded so that its first dimension is
    divisible by len(seq_lens).

    Args:
        padded_inputs: a padded batch of sequences. That is,
            for seq_lens=[1, 2, 2], then inputs=[A, *, B, B, C, C], where
            A, B, C are sequence elements and * denotes padding.
        seq_lens: A 1D tensor of sequence lengths, denoting the non-padded length
            in timesteps of each rollout in the batch.
        framework: The framework string ("tf2", "tf", "torch", "np").
        time_major: Whether data should be returned in time-major (TxB)
            format or not (BxT). Not supported for tf.

    Returns:
        TensorType: Reshaped tensor of shape [B, T, ...] or [T, B, ...].
    """

    # Sequence lengths have to be specified for LSTM batch inputs. The
    # input batch must be padded to the max seq length given here. That is,
    # batch_size == len(seq_lens) * max(seq_lens)
    if framework in ["tf2", "tf"]:
        assert time_major is False, "time-major not supported yet for tf!"
        padded_inputs = tf.convert_to_tensor(padded_inputs)
        padded_batch_size = tf.shape(padded_inputs)[0]
        # Dynamically reshape the padded batch to introduce a time dimension.
        new_batch_size = tf.shape(seq_lens)[0]
        time_size = padded_batch_size // new_batch_size
        new_shape = tf.concat(
            [
                tf.expand_dims(new_batch_size, axis=0),
                tf.expand_dims(time_size, axis=0),
                tf.shape(padded_inputs)[1:],
            ],
            axis=0,
        )
        return tf.reshape(padded_inputs, new_shape)
    elif framework == "torch":
        padded_inputs = torch.as_tensor(padded_inputs)
        padded_batch_size = padded_inputs.shape[0]

        # Dynamically reshape the padded batch to introduce a time dimension.
        new_batch_size = seq_lens.shape[0]
        time_size = padded_batch_size // new_batch_size
        batch_major_shape = (new_batch_size, time_size) + padded_inputs.shape[1:]
        padded_outputs = padded_inputs.view(batch_major_shape)

        if time_major:
            # Swap the batch and time dimensions (torch's `transpose` swaps
            # exactly two dims, so this is valid for any rank).
            padded_outputs = padded_outputs.transpose(0, 1)
        return padded_outputs
    else:
        assert framework == "np", "Unknown framework: {}".format(framework)
        padded_inputs = np.asarray(padded_inputs)
        padded_batch_size = padded_inputs.shape[0]

        # Dynamically reshape the padded batch to introduce a time dimension.
        new_batch_size = seq_lens.shape[0]
        time_size = padded_batch_size // new_batch_size
        batch_major_shape = (new_batch_size, time_size) + padded_inputs.shape[1:]
        padded_outputs = padded_inputs.reshape(batch_major_shape)

        if time_major:
            # Swap the batch and time dimensions. Bug fix: numpy's
            # `ndarray.transpose(0, 1)` requires a *complete* axis
            # permutation and raises ValueError for arrays with ndim > 2
            # (i.e. any input with feature dims). `np.swapaxes` swaps just
            # the two leading (batch/time) axes for any rank, matching the
            # torch branch's semantics.
            padded_outputs = np.swapaxes(padded_outputs, 0, 1)
        return padded_outputs
265
+
266
+
267
@OldAPIStack
def chop_into_sequences(
    *,
    feature_columns,
    state_columns,
    max_seq_len,
    episode_ids=None,
    unroll_ids=None,
    agent_indices=None,
    dynamic_max=True,
    shuffle=False,
    seq_lens=None,
    states_already_reduced_to_init=False,
    handle_nested_data=False,
    _extra_padding=0,
    padding: str = "zero",
    pad_infos_with_empty_dicts: bool = False,
):
    """Truncate and pad experiences into fixed-length sequences.

    Args:
        feature_columns: List of arrays containing features.
        state_columns: List of arrays containing LSTM state values.
        max_seq_len: Max length of sequences. Sequences longer than max_seq_len
            will be split into subsequences that span the batch dimension
            and sum to max_seq_len.
        episode_ids (List[EpisodeID]): List of episode ids for each step.
        unroll_ids (List[UnrollID]): List of identifiers for the sample batch.
            This is used to make sure sequences are cut between sample batches.
        agent_indices (List[AgentID]): List of agent ids for each step. Note
            that this has to be combined with episode_ids for uniqueness.
        dynamic_max: Whether to dynamically shrink the max seq len.
            For example, if max len is 20 and the actual max seq len in the
            data is 7, it will be shrunk to 7.
        shuffle: Whether to shuffle the sequence outputs.
        handle_nested_data: If True, assume that the data in
            `feature_columns` could be nested structures (of data).
            If False, assumes that all items in `feature_columns` are
            only np.ndarrays (no nested structured of np.ndarrays).
        _extra_padding: Add extra padding to the end of sequences.
        padding: Padding type to use. Either "zero" or "last". Zero padding
            will pad with zeros, last padding will pad with the last value.
        pad_infos_with_empty_dicts: If True, will zero-pad INFOs with empty
            dicts (instead of None). Used by the new API stack in the meantime,
            however, as soon as the new ConnectorV2 API will be activated (as
            part of the new API stack), we will no longer use this utility function
            anyway.

    Returns:
        f_pad: Padded feature columns. These will be of shape
            [NUM_SEQUENCES * MAX_SEQ_LEN, ...].
        s_init: Initial states for each sequence, of shape
            [NUM_SEQUENCES, ...].
        seq_lens: List of sequence lengths, of shape [NUM_SEQUENCES].

    .. testcode::
        :skipif: True

        from ray.rllib.policy.rnn_sequencing import chop_into_sequences
        f_pad, s_init, seq_lens = chop_into_sequences(
            episode_ids=[1, 1, 5, 5, 5, 5],
            unroll_ids=[4, 4, 4, 4, 4, 4],
            agent_indices=[0, 0, 0, 0, 0, 0],
            feature_columns=[[4, 4, 8, 8, 8, 8],
                             [1, 1, 0, 1, 1, 0]],
            state_columns=[[4, 5, 4, 5, 5, 5]],
            max_seq_len=3)
        print(f_pad)
        print(s_init)
        print(seq_lens)


    .. testoutput::

        [[4, 4, 0, 8, 8, 8, 8, 0, 0],
         [1, 1, 0, 0, 1, 1, 0, 0, 0]]
        [[4, 4, 5]]
        [2, 3, 1]
    """

    # No sequence lengths given -> derive them: a new sequence starts
    # whenever the (episode, agent, unroll) identity changes or the current
    # run reaches `max_seq_len`.
    if seq_lens is None or len(seq_lens) == 0:
        prev_id = None
        seq_lens = []
        seq_len = 0
        # Build per-timestep unique ids; the unroll id is shifted into the
        # high 32 bits so it cannot collide with episode_id + agent_index.
        unique_ids = np.add(
            np.add(episode_ids, agent_indices),
            np.array(unroll_ids, dtype=np.int64) << 32,
        )
        for uid in unique_ids:
            if (prev_id is not None and uid != prev_id) or seq_len >= max_seq_len:
                seq_lens.append(seq_len)
                seq_len = 0
            seq_len += 1
            prev_id = uid
        if seq_len:
            seq_lens.append(seq_len)
        seq_lens = np.array(seq_lens, dtype=np.int32)

    # Dynamically shrink max len as needed to optimize memory usage
    if dynamic_max:
        max_seq_len = max(seq_lens) + _extra_padding

    # Total number of (padded) timesteps in the output columns.
    length = len(seq_lens) * max_seq_len

    feature_sequences = []
    for col in feature_columns:
        if isinstance(col, list):
            col = np.array(col)
        feature_sequences.append([])

        # `col` may be a nested structure -> pad each flat leaf separately.
        for f in tree.flatten(col):
            # Save unnecessary copy.
            if not isinstance(f, np.ndarray):
                f = np.array(f)

            # New stack behavior (temporarily until we move to ConnectorV2 API, where
            # this (admittedly convoluted) function will no longer be used at all).
            if (
                f.dtype == object
                and pad_infos_with_empty_dicts
                and isinstance(f[0], dict)
            ):
                f_pad = [{} for _ in range(length)]
            # Old stack behavior: Pad INFOs with None.
            elif f.dtype == object or f.dtype.type is np.str_:
                f_pad = [None] * length
            # Pad everything else with zeros.
            else:
                # Make sure type doesn't change.
                f_pad = np.zeros((length,) + np.shape(f)[1:], dtype=f.dtype)
            # Copy each sequence's timesteps into its max_seq_len-sized slot.
            seq_base = 0
            i = 0
            for len_ in seq_lens:
                for seq_offset in range(len_):
                    f_pad[seq_base + seq_offset] = f[i]
                    i += 1

                # "last" padding: repeat the sequence's final value instead of
                # leaving the zero/None filler.
                if padding == "last":
                    for seq_offset in range(len_, max_seq_len):
                        f_pad[seq_base + seq_offset] = f[i - 1]

                seq_base += max_seq_len

            # Sanity check: all input timesteps must have been consumed.
            assert i == len(f), f
            feature_sequences[-1].append(f_pad)

    if states_already_reduced_to_init:
        initial_states = state_columns
    else:
        # Reduce state columns to one entry per sequence (the state at the
        # first timestep of each sequence).
        initial_states = []
        for state_column in state_columns:
            if isinstance(state_column, list):
                state_column = np.array(state_column)
            initial_state_flat = []
            # state_column may have a nested structure (e.g. LSTM state).
            for s in tree.flatten(state_column):
                # Skip unnecessary copy.
                if not isinstance(s, np.ndarray):
                    s = np.array(s)
                s_init = []
                i = 0
                for len_ in seq_lens:
                    s_init.append(s[i])
                    i += len_
                initial_state_flat.append(np.array(s_init))
            initial_states.append(tree.unflatten_as(state_column, initial_state_flat))

    if shuffle:
        # Permute whole sequences (not individual timesteps), keeping
        # features, initial states, and seq_lens aligned.
        permutation = np.random.permutation(len(seq_lens))
        for i, f in enumerate(tree.flatten(feature_sequences)):
            orig_shape = f.shape
            f = np.reshape(f, (len(seq_lens), -1) + f.shape[1:])
            f = f[permutation]
            f = np.reshape(f, orig_shape)
            feature_sequences[i] = f
        for i, s in enumerate(initial_states):
            s = s[permutation]
            initial_states[i] = s
        seq_lens = seq_lens[permutation]

    # Classic behavior: Don't assume data in feature_columns are nested
    # structs. Don't return them as flattened lists, but as is (index 0).
    if not handle_nested_data:
        feature_sequences = [f[0] for f in feature_sequences]

    return feature_sequences, initial_states, seq_lens
453
+
454
+
455
@OldAPIStack
def timeslice_along_seq_lens_with_overlap(
    sample_batch: SampleBatchType,
    seq_lens: Optional[List[int]] = None,
    zero_pad_max_seq_len: int = 0,
    pre_overlap: int = 0,
    zero_init_states: bool = True,
) -> List["SampleBatch"]:
    """Slices batch along `seq_lens` (each seq-len item produces one batch).

    Args:
        sample_batch: The SampleBatch to timeslice.
        seq_lens (Optional[List[int]]): An optional list of seq_lens to slice
            at. If None, use `sample_batch[SampleBatch.SEQ_LENS]`.
        zero_pad_max_seq_len: If >0, already zero-pad the resulting
            slices up to this length. NOTE: This max-len will include the
            additional timesteps gained via setting pre_overlap (see Example).
        pre_overlap: If >0, will overlap each two consecutive slices by
            this many timesteps (toward the left side). This will cause
            zero-padding at the very beginning of the batch.
        zero_init_states: Whether initial states should always be
            zero'd. If False, will use the state_outs of the batch to
            populate state_in values.

    Returns:
        List[SampleBatch]: The list of (new) SampleBatches.

    Examples:
        assert seq_lens == [5, 5, 2]
        assert sample_batch.count == 12
        # self = 0 1 2 3 4 | 5 6 7 8 9 | 10 11 <- timesteps
        slices = timeslice_along_seq_lens_with_overlap(
            sample_batch=sample_batch.
            zero_pad_max_seq_len=10,
            pre_overlap=3)
        # Z = zero padding (at beginning or end).
        # |pre (3)| seq | max-seq-len (up to 10)
        # slices[0] = | Z Z Z | 0 1 2 3 4 | Z Z
        # slices[1] = | 2 3 4 | 5 6 7 8 9 | Z Z
        # slices[2] = | 7 8 9 | 10 11 Z Z Z | Z Z
        # Note that `zero_pad_max_seq_len=10` includes the 3 pre-overlaps
        # count (makes sure each slice has exactly length 10).
    """
    if seq_lens is None:
        seq_lens = sample_batch.get(SampleBatch.SEQ_LENS)
    else:
        # Caller supplied explicit seq_lens that will shadow any sequencing
        # info already stored in the batch -> warn (once).
        if sample_batch.get(SampleBatch.SEQ_LENS) is not None and log_once(
            "overriding_sequencing_information"
        ):
            logger.warning(
                "Found sequencing information in a batch that will be "
                "ignored when slicing. Ignore this warning if you know "
                "what you are doing."
            )

    if seq_lens is None:
        # No sequencing info at all -> fall back to fixed-size chunks of
        # length `zero_pad_max_seq_len - pre_overlap`.
        max_seq_len = zero_pad_max_seq_len - pre_overlap
        if log_once("no_sequence_lengths_available_for_time_slicing"):
            logger.warning(
                "Trying to slice a batch along sequences without "
                "sequence lengths being provided in the batch. Batch will "
                "be sliced into slices of size "
                "{} = {} - {} = zero_pad_max_seq_len - pre_overlap.".format(
                    max_seq_len, zero_pad_max_seq_len, pre_overlap
                )
            )
        num_seq_lens, last_seq_len = divmod(len(sample_batch), max_seq_len)
        seq_lens = [zero_pad_max_seq_len] * num_seq_lens + (
            [last_seq_len] if last_seq_len else []
        )

    assert (
        seq_lens is not None and len(seq_lens) > 0
    ), "Cannot timeslice along `seq_lens` when `seq_lens` is empty or None!"
    # Generate n slices based on seq_lens. Each slice is described by
    # (pre_begin, slice_begin, end): `pre_begin` may reach `pre_overlap`
    # timesteps to the left of the actual sequence start (`slice_begin`).
    start = 0
    slices = []
    for seq_len in seq_lens:
        pre_begin = start - pre_overlap
        slice_begin = start
        end = start + seq_len
        slices.append((pre_begin, slice_begin, end))
        start += seq_len

    timeslices = []
    for begin, slice_begin, end in slices:
        zero_length = None
        data_begin = 0
        zero_init_states_ = zero_init_states
        if begin < 0:
            # Slice reaches before the start of the batch -> left-zero-pad by
            # `pre_overlap` timesteps and force zero'd initial states.
            zero_length = pre_overlap
            data_begin = slice_begin
            zero_init_states_ = True
        else:
            # `begin` is >= 0 in this branch; the `begin >= 0` guard in the
            # slice expression below is purely defensive.
            eps_ids = sample_batch[SampleBatch.EPS_ID][begin if begin >= 0 else 0 : end]
            is_last_episode_ids = eps_ids == eps_ids[-1]
            if not is_last_episode_ids[0]:
                # The pre-overlap region crosses an episode boundary ->
                # zero-pad the part belonging to the previous episode and
                # force zero'd initial states.
                zero_length = int(sum(1.0 - is_last_episode_ids))
                data_begin = begin + zero_length
                zero_init_states_ = True

        if zero_length is not None:
            # Left-zero-pad all columns (except SEQ_LENS) before the actual
            # data segment.
            data = {
                k: np.concatenate(
                    [
                        np.zeros(shape=(zero_length,) + v.shape[1:], dtype=v.dtype),
                        v[data_begin:end],
                    ]
                )
                for k, v in sample_batch.items()
                if k != SampleBatch.SEQ_LENS
            }
        else:
            data = {
                k: v[begin:end]
                for k, v in sample_batch.items()
                if k != SampleBatch.SEQ_LENS
            }

        if zero_init_states_:
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                # Zero-valued initial state with the correct shape/dtype
                # (taken from the first row of the corresponding column).
                data[key] = np.zeros_like(sample_batch[key][0:1])
                # Del state_out_n from data if exists.
                data.pop("state_out_{}".format(i), None)
                i += 1
                key = "state_in_{}".format(i)
        # TODO: This will not work with attention nets as their state_outs are
        # not compatible with state_ins.
        else:
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                # Use the state output of the timestep right before this slice
                # as its initial state input.
                data[key] = sample_batch["state_out_{}".format(i)][begin - 1 : begin]
                del data["state_out_{}".format(i)]
                i += 1
                key = "state_in_{}".format(i)

        timeslices.append(SampleBatch(data, seq_lens=[end - begin]))

    # Zero-pad each slice if necessary.
    if zero_pad_max_seq_len > 0:
        for ts in timeslices:
            ts.right_zero_pad(max_seq_len=zero_pad_max_seq_len, exclude_states=True)

    return timeslices
602
+
603
+
604
@OldAPIStack
def get_fold_unfold_fns(b_dim: int, t_dim: int, framework: str):
    """Produces two functions to fold/unfold any Tensors in a struct.

    Args:
        b_dim: The batch dimension to use for folding.
        t_dim: The time dimension to use for folding.
        framework: The framework to use for folding. One of "tf2" or "torch".

    Returns:
        fold: A function that takes a struct of Tensors and reshapes
            them to have a first dimension of `b_dim * t_dim`.
        unfold: A function that takes a struct of Tensors and reshapes
            them to have a first dimension of `b_dim` and a second dimension
            of `t_dim`.

    Raises:
        ValueError: If `framework` is none of the supported values.
    """
    # Bug fix: the previous check `framework in "tf2"` was a *substring* test
    # and thus also (accidentally) matched "t", "f", "2", and "". Use an
    # explicit tuple membership test; "tf" is kept for backward compatibility
    # since the substring test accepted it as well.
    if framework in ("tf", "tf2"):
        # TensorFlow traced eager complains if we don't convert these to
        # tensors here.
        b_dim = tf.convert_to_tensor(b_dim)
        t_dim = tf.convert_to_tensor(t_dim)

        def fold_mapping(item):
            if item is None:
                # `None` has no tensor representation -> pass through as-is.
                return item
            item = tf.convert_to_tensor(item)
            shape = tf.shape(item)
            other_dims = shape[2:]
            # (B, T, ...) -> (B*T, ...)
            return tf.reshape(item, tf.concat([[b_dim * t_dim], other_dims], axis=0))

        def unfold_mapping(item):
            if item is None:
                return item
            item = tf.convert_to_tensor(item)
            shape = item.shape
            other_dims = shape[1:]
            # (B*T, ...) -> (B, T, ...)
            return tf.reshape(item, tf.concat([[b_dim], [t_dim], other_dims], axis=0))

    elif framework == "torch":

        def fold_mapping(item):
            if item is None:
                # Torch has no representation for `None`, so we return None
                return item
            item = torch.as_tensor(item)
            size = list(item.size())
            current_b_dim, current_t_dim = list(size[:2])

            assert (b_dim, t_dim) == (current_b_dim, current_t_dim), (
                "All tensors in the struct must have the same batch and time "
                "dimensions. Got {} and {}.".format(
                    (b_dim, t_dim), (current_b_dim, current_t_dim)
                )
            )

            other_dims = size[2:]
            # (B, T, ...) -> (B*T, ...)
            return item.reshape([b_dim * t_dim] + other_dims)

        def unfold_mapping(item):
            if item is None:
                return item
            item = torch.as_tensor(item)
            size = list(item.size())
            current_b_dim = size[0]
            other_dims = size[1:]
            assert current_b_dim == b_dim * t_dim, (
                "The first dimension of the tensor must be equal to the product of "
                "the desired batch and time dimensions. Got {} and {}.".format(
                    current_b_dim, b_dim * t_dim
                )
            )
            # (B*T, ...) -> (B, T, ...)
            return item.reshape([b_dim, t_dim] + other_dims)

    else:
        raise ValueError(f"framework {framework} not implemented!")

    # Apply the mappings to every leaf of an arbitrarily nested struct.
    return functools.partial(tree.map_structure, fold_mapping), functools.partial(
        tree.map_structure, unfold_mapping
    )
.venv/lib/python3.11/site-packages/ray/rllib/policy/sample_batch.py ADDED
@@ -0,0 +1,1820 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ from functools import partial
3
+ import itertools
4
+ import sys
5
+ from numbers import Number
6
+ from typing import Dict, Iterator, Set, Union
7
+ from typing import List, Optional
8
+
9
+ import numpy as np
10
+ import tree # pip install dm_tree
11
+
12
+ from ray.rllib.core.columns import Columns
13
+ from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI, PublicAPI
14
+ from ray.rllib.utils.compression import pack, unpack, is_compressed
15
+ from ray.rllib.utils.deprecation import Deprecated, deprecation_warning
16
+ from ray.rllib.utils.framework import try_import_tf, try_import_torch
17
+ from ray.rllib.utils.torch_utils import convert_to_torch_tensor
18
+ from ray.rllib.utils.typing import (
19
+ ModuleID,
20
+ PolicyID,
21
+ TensorType,
22
+ SampleBatchType,
23
+ ViewRequirementsDict,
24
+ )
25
+ from ray.util import log_once
26
+
27
+ tf1, tf, tfv = try_import_tf()
28
+ torch, _ = try_import_torch()
29
+
30
+ # Default policy id for single agent environments
31
+ DEFAULT_POLICY_ID = "default_policy"
32
+
33
+
34
@DeveloperAPI
def attempt_count_timesteps(tensor_dict: dict):
    """Attempt to count timesteps based on dimensions of individual elements.

    Returns the first successfully counted number of timesteps.
    We do not attempt to count on INFOS or any state_in_* and state_out_* keys. The
    number of timesteps we count in cases where we are unable to count is zero.

    Args:
        tensor_dict: A SampleBatch or another dict.

    Returns:
        count: The inferred number of timesteps >= 0.
    """
    # Try to infer the "length" of the SampleBatch by finding the first
    # value that is actually a ndarray/tensor.
    # Skip manual counting routine if we can directly infer count from sequence
    # lengths. Symbolic (non-eager) tf tensors have no `.numpy` attribute and
    # cannot be summed here -> fall through to the manual routine for those.
    seq_lens = tensor_dict.get(SampleBatch.SEQ_LENS)
    if (
        seq_lens is not None
        and not (tf and tf.is_tensor(seq_lens) and not hasattr(seq_lens, "numpy"))
        and len(seq_lens) > 0
    ):
        if torch and torch.is_tensor(seq_lens):
            # `.item()` converts the 0-dim torch tensor to a plain Python int.
            return seq_lens.sum().item()
        else:
            return int(sum(seq_lens))

    for k, v in tensor_dict.items():
        if k == SampleBatch.SEQ_LENS:
            continue

        assert isinstance(k, str), tensor_dict

        if (
            k == SampleBatch.INFOS
            or k.startswith("state_in_")
            or k.startswith("state_out_")
        ):
            # Don't attempt to count on infos since we make no assumptions
            # about its content
            # Don't attempt to count on state since nesting can potentially mess
            # things up
            continue

        # If this is a nested dict (for example a nested observation),
        # try to flatten it, assert that all elements have the same length (batch
        # dimension)
        v_list = tree.flatten(v) if isinstance(v, (dict, tuple)) else [v]
        # TODO: Drop support for lists and Numbers as values.
        # If v_list contains lists or Numbers, convert them to arrays, too.
        v_list = [
            np.array(_v) if isinstance(_v, (Number, list)) else _v for _v in v_list
        ]
        try:
            # Add one of the elements' length, since they are all the same
            _len = len(v_list[0])
            # Only return a non-zero length; a zero-length column is treated
            # the same as an un-countable one (keep looking).
            if _len:
                return _len
        except Exception:
            # `len()` may fail e.g. on 0-dim arrays or length-less tensors ->
            # try the next column.
            pass

    # Return zero if we are unable to count
    return 0
98
+
99
+
100
+ @PublicAPI
101
+ class SampleBatch(dict):
102
+ """Wrapper around a dictionary with string keys and array-like values.
103
+
104
+ For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three
105
+ samples, each with an "obs" and "reward" attribute.
106
+ """
107
+
108
+ # On rows in SampleBatch:
109
+ # Each comment signifies how values relate to each other within a given row.
110
+ # A row generally signifies one timestep. Most importantly, at t=0, SampleBatch.OBS
111
+ # will usually be the reset-observation, while SampleBatch.ACTIONS will be the
112
+ # action based on the reset-observation and so on. This scheme is derived from
113
+ # RLlib's sampling logic.
114
+
115
+ # The following fields have all been moved to `Columns` and are only left here
116
+ # for backward compatibility.
117
+ OBS = Columns.OBS
118
+ ACTIONS = Columns.ACTIONS
119
+ REWARDS = Columns.REWARDS
120
+ TERMINATEDS = Columns.TERMINATEDS
121
+ TRUNCATEDS = Columns.TRUNCATEDS
122
+ INFOS = Columns.INFOS
123
+ SEQ_LENS = Columns.SEQ_LENS
124
+ T = Columns.T
125
+ ACTION_DIST_INPUTS = Columns.ACTION_DIST_INPUTS
126
+ ACTION_PROB = Columns.ACTION_PROB
127
+ ACTION_LOGP = Columns.ACTION_LOGP
128
+ VF_PREDS = Columns.VF_PREDS
129
+ VALUES_BOOTSTRAPPED = Columns.VALUES_BOOTSTRAPPED
130
+ EPS_ID = Columns.EPS_ID
131
+ NEXT_OBS = Columns.NEXT_OBS
132
+
133
+ # Action distribution object.
134
+ ACTION_DIST = "action_dist"
135
+ # Action chosen before SampleBatch.ACTIONS.
136
+ PREV_ACTIONS = "prev_actions"
137
+ # Reward received before SampleBatch.REWARDS.
138
+ PREV_REWARDS = "prev_rewards"
139
+ ENV_ID = "env_id" # An env ID (e.g. the index for a vectorized sub-env).
140
+ AGENT_INDEX = "agent_index" # Uniquely identifies an agent within an episode.
141
+ # Uniquely identifies a sample batch. This is important to distinguish RNN
142
+ # sequences from the same episode when multiple sample batches are
143
+ # concatenated (fusing sequences across batches can be unsafe).
144
+ UNROLL_ID = "unroll_id"
145
+
146
+ # RE 3
147
+ # This is only computed and used when RE3 exploration strategy is enabled.
148
+ OBS_EMBEDS = "obs_embeds"
149
+ # Decision Transformer
150
+ RETURNS_TO_GO = "returns_to_go"
151
+ ATTENTION_MASKS = "attention_masks"
152
+ # Do not set this key directly. Instead, the values under this key are
153
+ # auto-computed via the values of the TERMINATEDS and TRUNCATEDS keys.
154
+ DONES = "dones"
155
+ # Use SampleBatch.OBS instead.
156
+ CUR_OBS = "obs"
157
+
158
    @PublicAPI
    def __init__(self, *args, **kwargs):
        """Constructs a sample batch (same params as dict constructor).

        Note: All args and those kwargs not listed below will be passed
        as-is to the parent dict constructor.

        Args:
            _time_major: Whether data in this sample batch
                is time-major. This is False by default and only relevant
                if the data contains sequences.
            _max_seq_len: The max sequence chunk length
                if the data contains sequences.
            _zero_padded: Whether the data in this batch
                contains sequences AND these sequences are right-zero-padded
                according to the `_max_seq_len` setting.
            _is_training: Whether this batch is used for
                training. If False, batch may be used for e.g. action
                computations (inference).

        Raises:
            KeyError: If a `dones` key is passed in; callers must provide
                `terminateds`/`truncateds` instead.
        """

        if SampleBatch.DONES in kwargs:
            raise KeyError(
                "SampleBatch cannot be constructed anymore with a `DONES` key! "
                "Instead, set the new TERMINATEDS and TRUNCATEDS keys. The values under"
                " DONES will then be automatically computed using terminated|truncated."
            )

        # Possible seq_lens (TxB or BxT) setup.
        self.time_major = kwargs.pop("_time_major", None)
        # Maximum seq len value.
        self.max_seq_len = kwargs.pop("_max_seq_len", None)
        # Is already right-zero-padded?
        self.zero_padded = kwargs.pop("_zero_padded", False)
        # Whether this batch is used for training (vs inference).
        self._is_training = kwargs.pop("_is_training", None)
        # Weighted average number of grad updates that have been performed on the
        # policy/ies that were used to collect this batch.
        # E.g.: Two rollout workers collect samples of 50ts each
        # (rollout_fragment_length=50). One of them has a policy that has undergone
        # 2 updates thus far, the other worker uses a policy that has undergone 3
        # updates thus far. The train batch size is 100, so we concatenate these 2
        # batches to a new one that's 100ts long. This new 100ts batch will have its
        # `num_gradient_updates` property set to 2.5 as it's the weighted average
        # (both original batches contribute 50%).
        self.num_grad_updates: Optional[float] = kwargs.pop("_num_grad_updates", None)

        # Call super constructor. This will make the actual data accessible
        # by column name (str) via e.g. self["some-col"].
        dict.__init__(self, *args, **kwargs)

        # Indicates whether, for this batch, sequence lengths should be sliced by
        # their index in the batch or by their index as a sequence.
        # This is useful if a batch contains tensors of shape (B, T, ...), where each
        # index of B indicates one sequence. In this case, when slicing the batch,
        # we want one sequence to be sliced out per index in B
        # (`_slice_seq_lens_in_B=True`). However, if the padded batch
        # contains tensors of shape (B*T, ...), where each index of B*T indicates
        # one timestep, we want one sequence to be sliced per T steps in B*T
        # (`self._slice_seq_lens_in_B=False`).
        # `._slice_seq_lens_in_B = True` is only meant to be used for batches that we
        # feed into Learner._update(), all other places in RLlib are not expected to
        # need this.
        self._slice_seq_lens_in_B = False

        # Book-keeping for key access/mutation and lazy value interception
        # (see `set_get_interceptor`).
        self.accessed_keys = set()
        self.added_keys = set()
        self.deleted_keys = set()
        self.intercepted_values = {}
        self.get_interceptor = None

        # Clear out None seq-lens.
        seq_lens_ = self.get(SampleBatch.SEQ_LENS)
        if seq_lens_ is None or (isinstance(seq_lens_, list) and len(seq_lens_) == 0):
            self.pop(SampleBatch.SEQ_LENS, None)
        # Numpyfy seq_lens if list.
        elif isinstance(seq_lens_, list):
            self[SampleBatch.SEQ_LENS] = seq_lens_ = np.array(seq_lens_, dtype=np.int32)
        # Framework (torch/tf) tensors are stored back as-is.
        elif (torch and torch.is_tensor(seq_lens_)) or (tf and tf.is_tensor(seq_lens_)):
            self[SampleBatch.SEQ_LENS] = seq_lens_

        # Infer `max_seq_len` from the data if not given. Symbolic tf tensors
        # are skipped since their length cannot be evaluated eagerly.
        if (
            self.max_seq_len is None
            and seq_lens_ is not None
            and not (tf and tf.is_tensor(seq_lens_))
            and len(seq_lens_) > 0
        ):
            if torch and torch.is_tensor(seq_lens_):
                self.max_seq_len = seq_lens_.max().item()
            else:
                self.max_seq_len = max(seq_lens_)

        if self._is_training is None:
            # Backward compat: allow "is_training" to arrive as a regular
            # dict entry instead of the `_is_training` kwarg.
            self._is_training = self.pop("is_training", False)

        for k, v in self.items():
            # TODO: Drop support for lists and Numbers as values.
            # Convert lists of int|float into numpy arrays make sure all data
            # has same length.
            if isinstance(v, (Number, list)) and not k == SampleBatch.INFOS:
                self[k] = np.array(v)

        # Infer the number of timesteps contained in this batch.
        self.count = attempt_count_timesteps(self)

        # A convenience map for slicing this batch into sub-batches along
        # the time axis. This helps reduce repeated iterations through the
        # batch's seq_lens array to find good slicing points. Built lazily
        # when needed.
        self._slice_map = []
267
+
268
+ @PublicAPI
269
+ def __len__(self) -> int:
270
+ """Returns the amount of samples in the sample batch."""
271
+ return self.count
272
+
273
+ @PublicAPI
274
+ def agent_steps(self) -> int:
275
+ """Returns the same as len(self) (number of steps in this batch).
276
+
277
+ To make this compatible with `MultiAgentBatch.agent_steps()`.
278
+ """
279
+ return len(self)
280
+
281
+ @PublicAPI
282
+ def env_steps(self) -> int:
283
+ """Returns the same as len(self) (number of steps in this batch).
284
+
285
+ To make this compatible with `MultiAgentBatch.env_steps()`.
286
+ """
287
+ return len(self)
288
+
289
    @DeveloperAPI
    def enable_slicing_by_batch_id(self):
        """Makes slicing treat seq-lens as batch (B) indices of (B, T, ...) data."""
        self._slice_seq_lens_in_B = True
292
+
293
    @DeveloperAPI
    def disable_slicing_by_batch_id(self):
        """Restores default slicing: seq-lens index timesteps of (B*T, ...) data."""
        self._slice_seq_lens_in_B = False
296
+
297
+ @ExperimentalAPI
298
+ def is_terminated_or_truncated(self) -> bool:
299
+ """Returns True if `self` is either terminated or truncated at idx -1."""
300
+ return self[SampleBatch.TERMINATEDS][-1] or (
301
+ SampleBatch.TRUNCATEDS in self and self[SampleBatch.TRUNCATEDS][-1]
302
+ )
303
+
304
+ @ExperimentalAPI
305
+ def is_single_trajectory(self) -> bool:
306
+ """Returns True if this SampleBatch only contains one trajectory.
307
+
308
+ This is determined by checking all timesteps (except for the last) for being
309
+ not terminated AND (if applicable) not truncated.
310
+ """
311
+ return not any(self[SampleBatch.TERMINATEDS][:-1]) and (
312
+ SampleBatch.TRUNCATEDS not in self
313
+ or not any(self[SampleBatch.TRUNCATEDS][:-1])
314
+ )
315
+
316
    @staticmethod
    @PublicAPI
    @Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=True)
    def concat_samples(samples):
        # Deprecated with `error=True`: the decorator raises on every call,
        # so this body is never reached.
        pass
321
+
322
    @PublicAPI
    def concat(self, other: "SampleBatch") -> "SampleBatch":
        """Concatenates `other` to this one and returns a new SampleBatch.

        Args:
            other: The other SampleBatch object to concat to this one.

        Returns:
            The new SampleBatch, resulting from concating `other` to `self`.

        .. testcode::
            :skipif: True

            import numpy as np
            from ray.rllib.policy.sample_batch import SampleBatch
            b1 = SampleBatch({"a": np.array([1, 2])})
            b2 = SampleBatch({"a": np.array([3, 4, 5])})
            print(b1.concat(b2))

        .. testoutput::

            {"a": np.array([1, 2, 3, 4, 5])}
        """
        # Delegates to the module-level `concat_samples()` utility.
        return concat_samples([self, other])
346
+
347
    @PublicAPI
    def copy(self, shallow: bool = False) -> "SampleBatch":
        """Creates a deep or shallow copy of this SampleBatch and returns it.

        Args:
            shallow: Whether the copying should be done shallowly.

        Returns:
            A deep or shallow copy of this SampleBatch object.
        """
        copy_ = dict(self)
        # Copy ndarray leaves (deep copy unless `shallow`); non-ndarray leaves
        # are passed through by reference either way.
        data = tree.map_structure(
            lambda v: (
                np.array(v, copy=not shallow) if isinstance(v, np.ndarray) else v
            ),
            copy_,
        )
        copy_ = SampleBatch(
            data,
            _time_major=self.time_major,
            _zero_padded=self.zero_padded,
            _max_seq_len=self.max_seq_len,
            _num_grad_updates=self.num_grad_updates,
        )
        copy_.set_get_interceptor(self.get_interceptor)
        # NOTE(review): these book-keeping sets are shared by reference (not
        # copied), so mutations on the copy are visible in the original —
        # presumably intentional; confirm before changing.
        copy_.added_keys = self.added_keys
        copy_.deleted_keys = self.deleted_keys
        copy_.accessed_keys = self.accessed_keys
        return copy_
376
+
377
    @PublicAPI
    def rows(self) -> Iterator[Dict[str, TensorType]]:
        """Returns an iterator over data rows, i.e. dicts with column values.

        Note that if `seq_lens` is set in self, we set it to 1 in the rows.

        Yields:
            The column values of the row in this iteration.

        .. testcode::
            :skipif: True

            from ray.rllib.policy.sample_batch import SampleBatch
            batch = SampleBatch({
                "a": [1, 2, 3],
                "b": [4, 5, 6],
                "seq_lens": [1, 2]
            })
            for row in batch.rows():
                print(row)

        .. testoutput::

            {"a": 1, "b": 4, "seq_lens": 1}
            {"a": 2, "b": 5, "seq_lens": 1}
            {"a": 3, "b": 6, "seq_lens": 1}
        """

        # Per-row seq-len value: 1 if SEQ_LENS is present or absent (default=1
        # is non-None), None only if SEQ_LENS is explicitly stored as None.
        seq_lens = None if self.get(SampleBatch.SEQ_LENS, 1) is None else 1

        self_as_dict = dict(self)

        for i in range(self.count):
            # Pick index `i` from every (possibly nested) column; the SEQ_LENS
            # column is replaced by the constant per-row value computed above.
            # (`i=i` binds the loop variable to avoid late-binding issues.)
            yield tree.map_structure_with_path(
                lambda p, v, i=i: v[i] if p[0] != self.SEQ_LENS else seq_lens,
                self_as_dict,
            )
414
+
415
+ @PublicAPI
416
+ def columns(self, keys: List[str]) -> List[any]:
417
+ """Returns a list of the batch-data in the specified columns.
418
+
419
+ Args:
420
+ keys: List of column names fo which to return the data.
421
+
422
+ Returns:
423
+ The list of data items ordered by the order of column
424
+ names in `keys`.
425
+
426
+ .. testcode::
427
+ :skipif: True
428
+
429
+ from ray.rllib.policy.sample_batch import SampleBatch
430
+ batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
431
+ print(batch.columns(["a", "b"]))
432
+
433
+ .. testoutput::
434
+
435
+ [[1], [2]]
436
+ """
437
+
438
+ # TODO: (sven) Make this work for nested data as well.
439
+ out = []
440
+ for k in keys:
441
+ out.append(self[k])
442
+ return out
443
+
444
+ @PublicAPI
445
+ def shuffle(self) -> "SampleBatch":
446
+ """Shuffles the rows of this batch in-place.
447
+
448
+ Returns:
449
+ This very (now shuffled) SampleBatch.
450
+
451
+ Raises:
452
+ ValueError: If self[SampleBatch.SEQ_LENS] is defined.
453
+
454
+ .. testcode::
455
+ :skipif: True
456
+
457
+ from ray.rllib.policy.sample_batch import SampleBatch
458
+ batch = SampleBatch({"a": [1, 2, 3, 4]})
459
+ print(batch.shuffle())
460
+
461
+ .. testoutput::
462
+
463
+ {"a": [4, 1, 3, 2]}
464
+ """
465
+ has_time_rank = self.get(SampleBatch.SEQ_LENS) is not None
466
+
467
+ # Shuffling the data when we have `seq_lens` defined is probably
468
+ # a bad idea!
469
+ if has_time_rank and not self.zero_padded:
470
+ raise ValueError(
471
+ "SampleBatch.shuffle not possible when your data has "
472
+ "`seq_lens` defined AND is not zero-padded yet!"
473
+ )
474
+
475
+ # Get a permutation over the single items once and use the same
476
+ # permutation for all the data (otherwise, data would become
477
+ # meaningless).
478
+ # - Shuffle by individual item.
479
+ if not has_time_rank:
480
+ permutation = np.random.permutation(self.count)
481
+ # - Shuffle along batch axis (leave axis=1/time-axis as-is).
482
+ else:
483
+ permutation = np.random.permutation(len(self[SampleBatch.SEQ_LENS]))
484
+
485
+ self_as_dict = dict(self)
486
+ infos = self_as_dict.pop(Columns.INFOS, None)
487
+ shuffled = tree.map_structure(lambda v: v[permutation], self_as_dict)
488
+ if infos is not None:
489
+ self_as_dict[Columns.INFOS] = [infos[i] for i in permutation]
490
+
491
+ self.update(shuffled)
492
+
493
+ # Flush cache such that intercepted values are recalculated after the
494
+ # shuffling.
495
+ self.intercepted_values = {}
496
+ return self
497
+
498
+ @PublicAPI
499
+ def split_by_episode(self, key: Optional[str] = None) -> List["SampleBatch"]:
500
+ """Splits by `eps_id` column and returns list of new batches.
501
+ If `eps_id` is not present, splits by `dones` instead.
502
+
503
+ Args:
504
+ key: If specified, overwrite default and use key to split.
505
+
506
+ Returns:
507
+ List of batches, one per distinct episode.
508
+
509
+ Raises:
510
+ KeyError: If the `eps_id` AND `dones` columns are not present.
511
+
512
+ .. testcode::
513
+ :skipif: True
514
+
515
+ from ray.rllib.policy.sample_batch import SampleBatch
516
+ # "eps_id" is present
517
+ batch = SampleBatch(
518
+ {"a": [1, 2, 3], "eps_id": [0, 0, 1]})
519
+ print(batch.split_by_episode())
520
+
521
+ # "eps_id" not present, split by "dones" instead
522
+ batch = SampleBatch(
523
+ {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 1, 0, 1]})
524
+ print(batch.split_by_episode())
525
+
526
+ # The last episode is appended even if it does not end with done
527
+ batch = SampleBatch(
528
+ {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 1, 0, 0]})
529
+ print(batch.split_by_episode())
530
+
531
+ batch = SampleBatch(
532
+ {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]})
533
+ print(batch.split_by_episode())
534
+
535
+
536
+ .. testoutput::
537
+
538
+ [{"a": [1, 2], "eps_id": [0, 0]}, {"a": [3], "eps_id": [1]}]
539
+ [{"a": [1, 2, 3], "dones": [0, 0, 1]}, {"a": [4, 5], "dones": [0, 1]}]
540
+ [{"a": [1, 2, 3], "dones": [0, 0, 1]}, {"a": [4, 5], "dones": [0, 0]}]
541
+ [{"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]}]
542
+
543
+
544
+ """
545
+
546
+ assert key is None or key in [SampleBatch.EPS_ID, SampleBatch.DONES], (
547
+ f"`SampleBatch.split_by_episode(key={key})` invalid! "
548
+ f"Must be [None|'dones'|'eps_id']."
549
+ )
550
+
551
+ def slice_by_eps_id():
552
+ slices = []
553
+ # Produce a new slice whenever we find a new episode ID.
554
+ cur_eps_id = self[SampleBatch.EPS_ID][0]
555
+ offset = 0
556
+ for i in range(self.count):
557
+ next_eps_id = self[SampleBatch.EPS_ID][i]
558
+ if next_eps_id != cur_eps_id:
559
+ slices.append(self[offset:i])
560
+ offset = i
561
+ cur_eps_id = next_eps_id
562
+ # Add final slice.
563
+ slices.append(self[offset : self.count])
564
+ return slices
565
+
566
+ def slice_by_terminateds_or_truncateds():
567
+ slices = []
568
+ offset = 0
569
+ for i in range(self.count):
570
+ if self[SampleBatch.TERMINATEDS][i] or (
571
+ SampleBatch.TRUNCATEDS in self and self[SampleBatch.TRUNCATEDS][i]
572
+ ):
573
+ # Since self[i] is the last timestep of the episode,
574
+ # append it to the batch, then set offset to the start
575
+ # of the next batch
576
+ slices.append(self[offset : i + 1])
577
+ offset = i + 1
578
+ # Add final slice.
579
+ if offset != self.count:
580
+ slices.append(self[offset:])
581
+ return slices
582
+
583
+ key_to_method = {
584
+ SampleBatch.EPS_ID: slice_by_eps_id,
585
+ SampleBatch.DONES: slice_by_terminateds_or_truncateds,
586
+ }
587
+
588
+ # If key not specified, default to this order.
589
+ key_resolve_order = [SampleBatch.EPS_ID, SampleBatch.DONES]
590
+
591
+ slices = None
592
+ if key is not None:
593
+ # If key specified, directly use it.
594
+ if key == SampleBatch.EPS_ID and key not in self:
595
+ raise KeyError(f"{self} does not have key `{key}`!")
596
+ slices = key_to_method[key]()
597
+ else:
598
+ # If key not specified, go in order.
599
+ for key in key_resolve_order:
600
+ if key == SampleBatch.DONES or key in self:
601
+ slices = key_to_method[key]()
602
+ break
603
+ if slices is None:
604
+ raise KeyError(f"{self} does not have keys {key_resolve_order}!")
605
+
606
+ assert (
607
+ sum(s.count for s in slices) == self.count
608
+ ), f"Calling split_by_episode on {self} returns {slices}"
609
+ f"which should in total have {self.count} timesteps!"
610
+ return slices
611
+
612
    def slice(
        self, start: int, end: int, state_start=None, state_end=None
    ) -> "SampleBatch":
        """Returns a slice of the row data of this batch (w/o copying).

        Args:
            start: Starting index. If < 0, will left-zero-pad.
            end: Ending index.
            state_start: Optional starting index (in sequence units) into the
                `state_in_x` columns and SEQ_LENS. If given, `state_end` must
                be given as well and no automatic state-range inference is
                performed.
            state_end: Optional ending index (in sequence units) into the
                `state_in_x` columns and SEQ_LENS.

        Returns:
            A new SampleBatch, which has a slice of this batch's data.
        """
        if (
            self.get(SampleBatch.SEQ_LENS) is not None
            and len(self[SampleBatch.SEQ_LENS]) > 0
        ):
            if start < 0:
                # Left-zero-pad: prepend -start rows of zeros to each
                # (non-state, non-seq-lens) column, then take [0:end].
                # NOTE(review): assumes column values are np.ndarrays
                # (uses `.shape`/`.dtype`) — nested structures would not
                # survive this branch; confirm callers only pass flat columns.
                data = {
                    k: np.concatenate(
                        [
                            np.zeros(shape=(-start,) + v.shape[1:], dtype=v.dtype),
                            v[0:end],
                        ]
                    )
                    for k, v in self.items()
                    if k != SampleBatch.SEQ_LENS and not k.startswith("state_in_")
                }
            else:
                # Plain timestep slice of all non-state, non-seq-lens columns.
                data = {
                    k: tree.map_structure(lambda s: s[start:end], v)
                    for k, v in self.items()
                    if k != SampleBatch.SEQ_LENS and not k.startswith("state_in_")
                }
            if state_start is not None:
                # Caller provided the sequence-unit range explicitly.
                assert state_end is not None
                state_idx = 0
                state_key = "state_in_{}".format(state_idx)
                # Slice every state_in_x column by the given sequence range.
                while state_key in self:
                    data[state_key] = self[state_key][state_start:state_end]
                    state_idx += 1
                    state_key = "state_in_{}".format(state_idx)
                seq_lens = list(self[SampleBatch.SEQ_LENS][state_start:state_end])
                # Adjust seq_lens if necessary: the last sequence may have
                # been cut short by the timestep slice above.
                data_len = len(data[next(iter(data))])
                if sum(seq_lens) != data_len:
                    assert sum(seq_lens) > data_len
                    seq_lens[-1] = data_len - sum(seq_lens[:-1])
            else:
                # Fix state_in_x data: infer which sequences the timestep
                # range [start, end) covers by walking SEQ_LENS.
                count = 0
                state_start = None
                seq_lens = None
                for i, seq_len in enumerate(self[SampleBatch.SEQ_LENS]):
                    count += seq_len
                    if count >= end:
                        # Sequence i is the last one touched by the slice.
                        state_idx = 0
                        state_key = "state_in_{}".format(state_idx)
                        if state_start is None:
                            state_start = i
                        while state_key in self:
                            data[state_key] = self[state_key][state_start : i + 1]
                            state_idx += 1
                            state_key = "state_in_{}".format(state_idx)
                        # Last seq-len is truncated to end exactly at `end`.
                        seq_lens = list(self[SampleBatch.SEQ_LENS][state_start:i]) + [
                            seq_len - (count - end)
                        ]
                        if start < 0:
                            # Account for the left-zero-padding added above.
                            seq_lens[0] += -start
                        diff = sum(seq_lens) - (end - start)
                        if diff > 0:
                            seq_lens[0] -= diff
                        assert sum(seq_lens) == (end - start)
                        break
                    elif state_start is None and count > start:
                        # First sequence overlapping `start`.
                        state_start = i

            return SampleBatch(
                data,
                seq_lens=seq_lens,
                _is_training=self.is_training,
                _time_major=self.time_major,
                _num_grad_updates=self.num_grad_updates,
            )
        else:
            # No sequence information -> simple per-column timestep slice.
            return SampleBatch(
                tree.map_structure(lambda value: value[start:end], self),
                _is_training=self.is_training,
                _time_major=self.time_major,
                _num_grad_updates=self.num_grad_updates,
            )
702
+
703
+ def _batch_slice(self, slice_: slice) -> "SampleBatch":
704
+ """Helper method to handle SampleBatch slicing using a slice object.
705
+
706
+ The returned SampleBatch uses the same underlying data object as
707
+ `self`, so changing the slice will also change `self`.
708
+
709
+ Note that only zero or positive bounds are allowed for both start
710
+ and stop values. The slice step must be 1 (or None, which is the
711
+ same).
712
+
713
+ Args:
714
+ slice_: The python slice object to slice by.
715
+
716
+ Returns:
717
+ A new SampleBatch, however "linking" into the same data
718
+ (sliced) as self.
719
+ """
720
+ start = slice_.start or 0
721
+ stop = slice_.stop or len(self[SampleBatch.SEQ_LENS])
722
+ # If stop goes beyond the length of this batch -> Make it go till the
723
+ # end only (including last item).
724
+ # Analogous to `l = [0, 1, 2]; l[:100] -> [0, 1, 2];`.
725
+ if stop > len(self):
726
+ stop = len(self)
727
+ assert start >= 0 and stop >= 0 and slice_.step in [1, None]
728
+
729
+ # Exclude INFOs from regular array slicing as the data under this column might
730
+ # be a list (not good for `tree.map_structure` call).
731
+ # Furthermore, slicing does not work when the data in the column is
732
+ # singular (not a list or array).
733
+ infos = self.pop(SampleBatch.INFOS, None)
734
+ data = tree.map_structure(lambda value: value[start:stop], self)
735
+ if infos is not None:
736
+ # Slice infos according to SEQ_LENS.
737
+ info_slice_start = int(sum(self[SampleBatch.SEQ_LENS][:start]))
738
+ info_slice_stop = int(sum(self[SampleBatch.SEQ_LENS][start:stop]))
739
+ data[SampleBatch.INFOS] = infos[info_slice_start:info_slice_stop]
740
+ # Put infos back into `self`.
741
+ self[Columns.INFOS] = infos
742
+
743
+ return SampleBatch(
744
+ data,
745
+ _is_training=self.is_training,
746
+ _time_major=self.time_major,
747
+ _num_grad_updates=self.num_grad_updates,
748
+ )
749
+
750
+ @PublicAPI
751
+ def timeslices(
752
+ self,
753
+ size: Optional[int] = None,
754
+ num_slices: Optional[int] = None,
755
+ k: Optional[int] = None,
756
+ ) -> List["SampleBatch"]:
757
+ """Returns SampleBatches, each one representing a k-slice of this one.
758
+
759
+ Will start from timestep 0 and produce slices of size=k.
760
+
761
+ Args:
762
+ size: The size (in timesteps) of each returned SampleBatch.
763
+ num_slices: The number of slices to produce.
764
+ k: Deprecated: Use size or num_slices instead. The size
765
+ (in timesteps) of each returned SampleBatch.
766
+
767
+ Returns:
768
+ The list of `num_slices` (new) SampleBatches or n (new)
769
+ SampleBatches each one of size `size`.
770
+ """
771
+ if size is None and num_slices is None:
772
+ deprecation_warning("k", "size or num_slices")
773
+ assert k is not None
774
+ size = k
775
+
776
+ if size is None:
777
+ assert isinstance(num_slices, int)
778
+
779
+ slices = []
780
+ left = len(self)
781
+ start = 0
782
+ while left:
783
+ len_ = left // (num_slices - len(slices))
784
+ stop = start + len_
785
+ slices.append(self[start:stop])
786
+ left -= len_
787
+ start = stop
788
+
789
+ return slices
790
+
791
+ else:
792
+ assert isinstance(size, int)
793
+
794
+ slices = []
795
+ left = len(self)
796
+ start = 0
797
+ while left:
798
+ stop = start + size
799
+ slices.append(self[start:stop])
800
+ left -= size
801
+ start = stop
802
+
803
+ return slices
804
+
805
    @Deprecated(new="SampleBatch.right_zero_pad", error=True)
    def zero_pad(self, max_seq_len, exclude_states=True):
        # Deprecated stub: the `@Deprecated(error=True)` decorator raises on
        # call, so this body is never executed. Use `right_zero_pad()`.
        pass
808
+
809
    def right_zero_pad(
        self, max_seq_len: int, exclude_states: bool = True
    ) -> "SampleBatch":
        """Right (adding zeros at end) zero-pads this SampleBatch in-place.

        Each sequence (as per SEQ_LENS) is padded out to `max_seq_len`
        timesteps, so the total padded length of every column becomes
        `len(seq_lens) * max_seq_len`.

        This will set the `self.zero_padded` flag to True and
        `self.max_seq_len` to the given `max_seq_len` value.

        Args:
            max_seq_len: The max (total) length to zero pad to.
            exclude_states: If False, also right-zero-pad all
                `state_in_x` data. If True, leave `state_in_x` keys
                as-is.

        Returns:
            This very (now right-zero-padded) SampleBatch.

        Raises:
            ValueError: If self[SampleBatch.SEQ_LENS] is None (not defined).

        .. testcode::
            :skipif: True

            from ray.rllib.policy.sample_batch import SampleBatch
            batch = SampleBatch(
                {"a": [1, 2, 3], "seq_lens": [1, 2]})
            print(batch.right_zero_pad(max_seq_len=4))

            batch = SampleBatch({"a": [1, 2, 3],
                                 "state_in_0": [1.0, 3.0],
                                 "seq_lens": [1, 2]})
            print(batch.right_zero_pad(max_seq_len=5))

        .. testoutput::

            {"a": [1, 0, 0, 0, 2, 3, 0, 0], "seq_lens": [1, 2]}
            {"a": [1, 0, 0, 0, 0, 2, 3, 0, 0, 0],
             "state_in_0": [1.0, 3.0],  # <- all state-ins remain as-is
             "seq_lens": [1, 2]}

        """
        seq_lens = self.get(SampleBatch.SEQ_LENS)
        if seq_lens is None:
            raise ValueError(
                "Cannot right-zero-pad SampleBatch if no `seq_lens` field "
                f"present! SampleBatch={self}"
            )

        # Total length of every padded column.
        length = len(seq_lens) * max_seq_len

        def _zero_pad_in_place(path, value):
            # Pads a single (possibly nested) column in-place.
            # Skip "state_in_..." columns and "seq_lens".
            if (exclude_states is True and path[0].startswith("state_in_")) or path[
                0
            ] == SampleBatch.SEQ_LENS:
                return
            # Generate zero-filled primer of len=max_seq_len.
            if value.dtype == object or value.dtype.type is np.str_:
                # Object/string data cannot be zero-filled -> use None.
                f_pad = [None] * length
            else:
                # Make sure type doesn't change.
                f_pad = np.zeros((length,) + np.shape(value)[1:], dtype=value.dtype)
            # Fill primer with data: copy each sequence to the start of its
            # own max_seq_len-sized chunk, leaving the tail zeroed.
            f_pad_base = f_base = 0
            for len_ in self[SampleBatch.SEQ_LENS]:
                f_pad[f_pad_base : f_pad_base + len_] = value[f_base : f_base + len_]
                f_pad_base += max_seq_len
                f_base += len_
            assert f_base == len(value), value

            # Update our data in-place: walk `path` down the nested dict and
            # replace the leaf with the padded array.
            curr = self
            for i, p in enumerate(path):
                if i == len(path) - 1:
                    curr[p] = f_pad
                curr = curr[p]

        # Wrap in a plain dict so map_structure_with_path does not recurse
        # into SampleBatch-specific behavior.
        self_as_dict = dict(self)
        tree.map_structure_with_path(_zero_pad_in_place, self_as_dict)

        # Set flags to indicate, we are now zero-padded (and to what extend).
        self.zero_padded = True
        self.max_seq_len = max_seq_len

        return self
892
+
893
+ @ExperimentalAPI
894
+ def to_device(self, device, framework="torch"):
895
+ """TODO: transfer batch to given device as framework tensor."""
896
+ if framework == "torch":
897
+ assert torch is not None
898
+ for k, v in self.items():
899
+ self[k] = convert_to_torch_tensor(v, device)
900
+ else:
901
+ raise NotImplementedError
902
+ return self
903
+
904
+ @PublicAPI
905
+ def size_bytes(self) -> int:
906
+ """Returns sum over number of bytes of all data buffers.
907
+
908
+ For numpy arrays, we use ``.nbytes``. For all other value types, we use
909
+ sys.getsizeof(...).
910
+
911
+ Returns:
912
+ The overall size in bytes of the data buffer (all columns).
913
+ """
914
+ return sum(
915
+ v.nbytes if isinstance(v, np.ndarray) else sys.getsizeof(v)
916
+ for v in tree.flatten(self)
917
+ )
918
+
919
+ def get(self, key, default=None):
920
+ """Returns one column (by key) from the data or a default value."""
921
+ try:
922
+ return self.__getitem__(key)
923
+ except KeyError:
924
+ return default
925
+
926
+ @PublicAPI
927
+ def as_multi_agent(self, module_id: Optional[ModuleID] = None) -> "MultiAgentBatch":
928
+ """Returns the respective MultiAgentBatch
929
+
930
+ Note, if `module_id` is not provided uses `DEFAULT_POLICY`_ID`.
931
+
932
+ Args;
933
+ module_id: An optional module ID. If `None` the `DEFAULT_POLICY_ID`
934
+ is used.
935
+
936
+ Returns:
937
+ The MultiAgentBatch (using DEFAULT_POLICY_ID) corresponding
938
+ to this SampleBatch.
939
+ """
940
+ return MultiAgentBatch({module_id or DEFAULT_POLICY_ID: self}, self.count)
941
+
942
    @PublicAPI
    def __getitem__(self, key: Union[str, slice]) -> TensorType:
        """Returns one column (by key) from the data or a sliced new batch.

        Args:
            key: The key (column name) to return or
                a slice object for slicing this SampleBatch.

        Returns:
            The data under the given key or a sliced version of this batch.

        Raises:
            KeyError: If `key` (a string) is not present in the batch.
        """
        # Slice objects delegate to the (timestep- or batch-unit) slicer.
        if isinstance(key, slice):
            return self._slice(key)

        # Special key DONES -> Translate to `TERMINATEDS | TRUNCATEDS` to reflect
        # the old meaning of DONES.
        # NOTE(review): despite the comment above, only TERMINATEDS is
        # returned here (TRUNCATEDS is not OR'd in) — confirm intended.
        if key == SampleBatch.DONES:
            return self[SampleBatch.TERMINATEDS]
        # Backward compatibility for when "input-dicts" were used.
        elif key == "is_training":
            if log_once("SampleBatch['is_training']"):
                deprecation_warning(
                    old="SampleBatch['is_training']",
                    new="SampleBatch.is_training",
                    error=False,
                )
            return self.is_training

        # Track accessed keys (used e.g. for view-requirement pruning);
        # skip keys that also name an attribute/property on this object.
        if not hasattr(self, key) and key in self:
            self.accessed_keys.add(key)

        # Raw (un-intercepted) lookup via the underlying dict.
        value = dict.__getitem__(self, key)
        if self.get_interceptor is not None:
            # Lazily apply the interceptor and cache the result per key.
            if key not in self.intercepted_values:
                self.intercepted_values[key] = self.get_interceptor(value)
            value = self.intercepted_values[key]
        return value
979
+
980
    @PublicAPI
    def __setitem__(self, key, item) -> None:
        """Inserts (overrides) an entire column (by key) in the data buffer.

        Args:
            key: The column name to set a value for.
            item: The data to insert.

        Raises:
            KeyError: If trying to set the (now read-only) DONES key.
        """
        # Disallow setting DONES key directly.
        if key == SampleBatch.DONES:
            raise KeyError(
                "Cannot set `DONES` anymore in a SampleBatch! "
                "Instead, set the new TERMINATEDS and TRUNCATEDS keys. The values under"
                " DONES will then be automatically computed using terminated|truncated."
            )
        # Defend against creating SampleBatch via pickle (no property
        # `added_keys` and first item is already set).
        elif not hasattr(self, "added_keys"):
            dict.__setitem__(self, key, item)
            return

        # Backward compatibility for when "input-dicts" were used.
        if key == "is_training":
            if log_once("SampleBatch['is_training']"):
                deprecation_warning(
                    old="SampleBatch['is_training']",
                    new="SampleBatch.is_training",
                    error=False,
                )
            # Store on the instance attribute, not in the dict.
            self._is_training = item
            return

        # Track newly added keys (mirrors `accessed_keys`/`deleted_keys`).
        if key not in self:
            self.added_keys.add(key)

        dict.__setitem__(self, key, item)
        # Keep any cached interceptor output consistent with the new value.
        if key in self.intercepted_values:
            self.intercepted_values[key] = item
1018
+
1019
+ @property
1020
+ def is_training(self):
1021
+ if self.get_interceptor is not None and isinstance(self._is_training, bool):
1022
+ if "_is_training" not in self.intercepted_values:
1023
+ self.intercepted_values["_is_training"] = self.get_interceptor(
1024
+ self._is_training
1025
+ )
1026
+ return self.intercepted_values["_is_training"]
1027
+ return self._is_training
1028
+
1029
+ def set_training(self, training: Union[bool, "tf1.placeholder"] = True):
1030
+ """Sets the `is_training` flag for this SampleBatch."""
1031
+ self._is_training = training
1032
+ self.intercepted_values.pop("_is_training", None)
1033
+
1034
    @PublicAPI
    def __delitem__(self, key):
        """Deletes a column from the batch and records the deletion.

        Note: `key` is added to `deleted_keys` *before* the actual delete, so
        it is recorded even if `dict.__delitem__` raises KeyError.
        """
        self.deleted_keys.add(key)
        dict.__delitem__(self, key)
1038
+
1039
    @DeveloperAPI
    def compress(
        self, bulk: bool = False, columns: Set[str] = frozenset(["obs", "new_obs"])
    ) -> "SampleBatch":
        """Compresses the data buffers (by column) in place.

        Args:
            bulk: Whether to compress across the batch dimension (0)
                as well. If False will compress n separate list items, where n
                is the batch size.
            columns: The columns to compress. Default: Only
                compress the obs and new_obs columns.

        Returns:
            This very (now compressed) SampleBatch.
        """

        def _compress_in_place(path, value):
            # Compress one (possibly nested) leaf; `path` is the key path
            # from the batch root down to this leaf.
            if path[0] not in columns:
                return
            # Walk down to the leaf's parent and replace the leaf with its
            # packed (compressed) version.
            curr = self
            for i, p in enumerate(path):
                if i == len(path) - 1:
                    if bulk:
                        # One compressed blob for the entire column.
                        curr[p] = pack(value)
                    else:
                        # One compressed blob per row.
                        curr[p] = np.array([pack(o) for o in value])
                curr = curr[p]

        tree.map_structure_with_path(_compress_in_place, self)

        return self
1071
+
1072
    @DeveloperAPI
    def decompress_if_needed(
        self, columns: Set[str] = frozenset(["obs", "new_obs"])
    ) -> "SampleBatch":
        """Decompresses data buffers (per column if not compressed) in place.

        The inverse of `compress()`; handles both bulk-compressed columns
        and per-row-compressed columns. Already-uncompressed columns are
        left untouched.

        Args:
            columns: The columns to decompress. Default: Only
                decompress the obs and new_obs columns.

        Returns:
            This very (now uncompressed) SampleBatch.
        """

        def _decompress_in_place(path, value):
            # Decompress one (possibly nested) leaf in-place.
            if path[0] not in columns:
                return
            # Walk down to the leaf's parent dict.
            curr = self
            for p in path[:-1]:
                curr = curr[p]
            # Bulk compressed.
            if is_compressed(value):
                curr[path[-1]] = unpack(value)
            # Non bulk compressed.
            elif len(value) > 0 and is_compressed(value[0]):
                curr[path[-1]] = np.array([unpack(o) for o in value])

        tree.map_structure_with_path(_decompress_in_place, self)

        return self
1102
+
1103
+ @DeveloperAPI
1104
+ def set_get_interceptor(self, fn):
1105
+ """Sets a function to be called on every getitem."""
1106
+ # If get-interceptor changes, must erase old intercepted values.
1107
+ if fn is not self.get_interceptor:
1108
+ self.intercepted_values = {}
1109
+ self.get_interceptor = fn
1110
+
1111
+ def __repr__(self):
1112
+ keys = list(self.keys())
1113
+ if self.get(SampleBatch.SEQ_LENS) is None:
1114
+ return f"SampleBatch({self.count}: {keys})"
1115
+ else:
1116
+ keys.remove(SampleBatch.SEQ_LENS)
1117
+ return (
1118
+ f"SampleBatch({self.count} " f"(seqs={len(self['seq_lens'])}): {keys})"
1119
+ )
1120
+
1121
    def _slice(self, slice_: slice) -> "SampleBatch":
        """Helper method to handle SampleBatch slicing using a slice object.

        The returned SampleBatch uses the same underlying data object as
        `self`, so changing the slice will also change `self`.

        Note that only zero or positive bounds are allowed for both start
        and stop values. The slice step must be 1 (or None, which is the
        same).

        Args:
            slice_: The python slice object to slice by.

        Returns:
            A new SampleBatch, however "linking" into the same data
            (sliced) as self.
        """
        # Alternative mode: interpret slices in batch (sequence) units.
        if self._slice_seq_lens_in_B:
            return self._batch_slice(slice_)

        start = slice_.start or 0
        stop = slice_.stop or len(self)
        # If stop goes beyond the length of this batch -> Make it go till the
        # end only (including last item).
        # Analogous to `l = [0, 1, 2]; l[:100] -> [0, 1, 2];`.
        if stop > len(self):
            stop = len(self)

        if (
            self.get(SampleBatch.SEQ_LENS) is not None
            and len(self[SampleBatch.SEQ_LENS]) > 0
        ):
            # Build our slice-map, if not done already: one entry per
            # timestep mapping to (sequence index, start-of-seq offset),
            # so we can never cut a sequence in half.
            if not self._slice_map:
                sum_ = 0
                for i, l in enumerate(map(int, self[SampleBatch.SEQ_LENS])):
                    self._slice_map.extend([(i, sum_)] * l)
                    sum_ = sum_ + l
                # In case `stop` points to the very end (lengths of this
                # batch), return the last sequence (the -1 here makes sure we
                # never go beyond it; would result in an index error below).
                self._slice_map.append((len(self[SampleBatch.SEQ_LENS]), sum_))

            # Snap start/stop back to sequence boundaries.
            start_seq_len, start_unpadded = self._slice_map[start]
            stop_seq_len, stop_unpadded = self._slice_map[stop]
            start_padded = start_unpadded
            stop_padded = stop_unpadded
            if self.zero_padded:
                # In a zero-padded batch, every sequence occupies exactly
                # `max_seq_len` rows.
                start_padded = start_seq_len * self.max_seq_len
                stop_padded = stop_seq_len * self.max_seq_len

            def map_(path, value):
                # Timestep columns are sliced in (padded) row units;
                # seq_lens and state-in columns in sequence units.
                if path[0] != SampleBatch.SEQ_LENS and not path[0].startswith(
                    "state_in_"
                ):
                    return value[start_padded:stop_padded]
                else:
                    return value[start_seq_len:stop_seq_len]

            # INFOS may be a plain list (or singular) -> handle separately
            # from the tree-map over array columns.
            infos = self.pop(SampleBatch.INFOS, None)
            data = tree.map_structure_with_path(map_, self)
            if infos is not None and isinstance(infos, (list, np.ndarray)):
                self[SampleBatch.INFOS] = infos
                # INFOS is never zero-padded -> use unpadded indices.
                data[SampleBatch.INFOS] = infos[start_unpadded:stop_unpadded]

            return SampleBatch(
                data,
                _is_training=self.is_training,
                _time_major=self.time_major,
                _zero_padded=self.zero_padded,
                _max_seq_len=self.max_seq_len if self.zero_padded else None,
                _num_grad_updates=self.num_grad_updates,
            )
        else:
            # No sequence info: simple per-column row slice (INFOS handled
            # separately, see above).
            infos = self.pop(SampleBatch.INFOS, None)
            data = tree.map_structure(lambda s: s[start:stop], self)
            if infos is not None and isinstance(infos, (list, np.ndarray)):
                self[SampleBatch.INFOS] = infos
                data[SampleBatch.INFOS] = infos[start:stop]

            return SampleBatch(
                data,
                _is_training=self.is_training,
                _time_major=self.time_major,
                _num_grad_updates=self.num_grad_updates,
            )
1207
+
1208
    @Deprecated(error=False)
    def _get_slice_indices(self, slice_size):
        """Deprecated: Computes (start, stop) index pairs for minibatching.

        Returns:
            Tuple of (data_slices, data_slices_states): row-index ranges for
            timestep columns and (when SEQ_LENS is present) sequence-index
            ranges for state columns.
        """
        data_slices = []
        data_slices_states = []
        if (
            self.get(SampleBatch.SEQ_LENS) is not None
            and len(self[SampleBatch.SEQ_LENS]) > 0
        ):
            assert np.all(self[SampleBatch.SEQ_LENS] < slice_size), (
                "ERROR: `slice_size` must be larger than the max. seq-len "
                "in the batch!"
            )
            start_pos = 0
            current_slize_size = 0
            actual_slice_idx = 0
            start_idx = 0
            idx = 0
            # Accumulate whole sequences until a minibatch of at least
            # `slice_size` timesteps is reached.
            while idx < len(self[SampleBatch.SEQ_LENS]):
                seq_len = self[SampleBatch.SEQ_LENS][idx]
                current_slize_size += seq_len
                # Row position differs depending on zero-padding: padded
                # batches advance by max_seq_len per sequence.
                actual_slice_idx += (
                    seq_len if not self.zero_padded else self.max_seq_len
                )
                # Complete minibatch -> Append to data_slices.
                if current_slize_size >= slice_size:
                    end_idx = idx + 1
                    # We are not zero-padded yet; all sequences are
                    # back-to-back.
                    if not self.zero_padded:
                        data_slices.append((start_pos, start_pos + slice_size))
                        start_pos += slice_size
                        if current_slize_size > slice_size:
                            # Overshot: re-process the last sequence in the
                            # next minibatch.
                            overhead = current_slize_size - slice_size
                            start_pos -= seq_len - overhead
                            idx -= 1
                    # We are already zero-padded: Cut in chunks of max_seq_len.
                    else:
                        data_slices.append((start_pos, actual_slice_idx))
                        start_pos = actual_slice_idx

                    data_slices_states.append((start_idx, end_idx))
                    current_slize_size = 0
                    start_idx = idx + 1
                idx += 1
        else:
            # No sequence info: simple fixed-size row chunks.
            i = 0
            while i < self.count:
                data_slices.append((i, i + slice_size))
                i += slice_size
        return data_slices, data_slices_states
1258
+
1259
    @ExperimentalAPI
    def get_single_step_input_dict(
        self,
        view_requirements: ViewRequirementsDict,
        index: Union[str, int] = "last",
    ) -> "SampleBatch":
        """Creates single ts SampleBatch at given index from `self`.

        For usage as input-dict for model (action or value function) calls.

        Args:
            view_requirements: A view requirements dict from the model for
                which to produce the input_dict.
            index: An integer index value indicating the
                position in the trajectory for which to generate the
                compute_actions input dict. Set to "last" to generate the dict
                at the very end of the trajectory (e.g. for value estimation).
                Note that "last" is different from -1, as "last" will use the
                final NEXT_OBS as observation input.

        Returns:
            The (single-timestep) input dict for ModelV2 calls.
        """
        # When indexing the "last" timestep, these columns stand in for
        # their shifted-by-one counterparts.
        last_mappings = {
            SampleBatch.OBS: SampleBatch.NEXT_OBS,
            SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS,
            SampleBatch.PREV_REWARDS: SampleBatch.REWARDS,
        }

        input_dict = {}
        for view_col, view_req in view_requirements.items():
            # Skip columns not needed for action computation.
            if view_req.used_for_compute_actions is False:
                continue

            # Create batches of size 1 (single-agent input-dict).
            data_col = view_req.data_col or view_col
            if index == "last":
                data_col = last_mappings.get(data_col, data_col)
                # Range needed.
                if view_req.shift_from is not None:
                    # Batch repeat value > 1: We have single frames in the
                    # batch at each timestep (for the `data_col`).
                    data = self[view_col][-1]
                    traj_len = len(self[data_col])
                    missing_at_end = traj_len % view_req.batch_repeat_value
                    # NOTE(review): when `missing_at_end == 0`, the slice
                    # `[-missing_at_end:]` below selects the *entire* column
                    # rather than nothing — confirm this case cannot occur
                    # or is intended.
                    # Index into the observations column must be shifted by
                    # -1 b/c index=0 for observations means the current (last
                    # seen) observation (after having taken an action).
                    obs_shift = (
                        -1 if data_col in [SampleBatch.OBS, SampleBatch.NEXT_OBS] else 0
                    )
                    from_ = view_req.shift_from + obs_shift
                    to_ = view_req.shift_to + obs_shift + 1
                    if to_ == 0:
                        # A stop index of 0 would yield an empty slice; None
                        # means "till the end".
                        to_ = None
                    input_dict[view_col] = np.array(
                        [
                            np.concatenate([data, self[data_col][-missing_at_end:]])[
                                from_:to_
                            ]
                        ]
                    )
                # Single index.
                else:
                    input_dict[view_col] = tree.map_structure(
                        lambda v: v[-1:],  # keep as array (w/ 1 element)
                        self[data_col],
                    )
            # Single index somewhere inside the trajectory (non-last).
            else:
                input_dict[view_col] = self[data_col][
                    index : index + 1 if index != -1 else None
                ]

        return SampleBatch(input_dict, seq_lens=np.array([1], dtype=np.int32))
1334
+
1335
+
1336
+ @PublicAPI
1337
+ class MultiAgentBatch:
1338
+ """A batch of experiences from multiple agents in the environment.
1339
+
1340
+ Attributes:
1341
+ policy_batches (Dict[PolicyID, SampleBatch]): Dict mapping policy IDs to
1342
+ SampleBatches of experiences.
1343
+ count: The number of env steps in this batch.
1344
+ """
1345
+
1346
+ @PublicAPI
1347
+ def __init__(self, policy_batches: Dict[PolicyID, SampleBatch], env_steps: int):
1348
+ """Initialize a MultiAgentBatch instance.
1349
+
1350
+ Args:
1351
+ policy_batches: Dict mapping policy IDs to SampleBatches of experiences.
1352
+ env_steps: The number of environment steps in the environment
1353
+ this batch contains. This will be less than the number of
1354
+ transitions this batch contains across all policies in total.
1355
+ """
1356
+
1357
+ for v in policy_batches.values():
1358
+ assert isinstance(v, SampleBatch)
1359
+ self.policy_batches = policy_batches
1360
+ # Called "count" for uniformity with SampleBatch.
1361
+ # Prefer to access this via the `env_steps()` method when possible
1362
+ # for clarity.
1363
+ self.count = env_steps
1364
+
1365
+ @PublicAPI
1366
+ def env_steps(self) -> int:
1367
+ """The number of env steps (there are >= 1 agent steps per env step).
1368
+
1369
+ Returns:
1370
+ The number of environment steps contained in this batch.
1371
+ """
1372
+ return self.count
1373
+
1374
+ @PublicAPI
1375
+ def __len__(self) -> int:
1376
+ """Same as `self.env_steps()`."""
1377
+ return self.count
1378
+
1379
+ @PublicAPI
1380
+ def agent_steps(self) -> int:
1381
+ """The number of agent steps (there are >= 1 agent steps per env step).
1382
+
1383
+ Returns:
1384
+ The number of agent steps total in this batch.
1385
+ """
1386
+ ct = 0
1387
+ for batch in self.policy_batches.values():
1388
+ ct += batch.count
1389
+ return ct
1390
+
1391
+ @PublicAPI
1392
+ def timeslices(self, k: int) -> List["MultiAgentBatch"]:
1393
+ """Returns k-step batches holding data for each agent at those steps.
1394
+
1395
+ For examples, suppose we have agent1 observations [a1t1, a1t2, a1t3],
1396
+ for agent2, [a2t1, a2t3], and for agent3, [a3t3] only.
1397
+
1398
+ Calling timeslices(1) would return three MultiAgentBatches containing
1399
+ [a1t1, a2t1], [a1t2], and [a1t3, a2t3, a3t3].
1400
+
1401
+ Calling timeslices(2) would return two MultiAgentBatches containing
1402
+ [a1t1, a1t2, a2t1], and [a1t3, a2t3, a3t3].
1403
+
1404
+ This method is used to implement "lockstep" replay mode. Note that this
1405
+ method does not guarantee each batch contains only data from a single
1406
+ unroll. Batches might contain data from multiple different envs.
1407
+ """
1408
+ from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder
1409
+
1410
+ # Build a sorted set of (eps_id, t, policy_id, data...)
1411
+ steps = []
1412
+ for policy_id, batch in self.policy_batches.items():
1413
+ for row in batch.rows():
1414
+ steps.append(
1415
+ (
1416
+ row[SampleBatch.EPS_ID],
1417
+ row[SampleBatch.T],
1418
+ row[SampleBatch.AGENT_INDEX],
1419
+ policy_id,
1420
+ row,
1421
+ )
1422
+ )
1423
+ steps.sort()
1424
+
1425
+ finished_slices = []
1426
+ cur_slice = collections.defaultdict(SampleBatchBuilder)
1427
+ cur_slice_size = 0
1428
+
1429
+ def finish_slice():
1430
+ nonlocal cur_slice_size
1431
+ assert cur_slice_size > 0
1432
+ batch = MultiAgentBatch(
1433
+ {k: v.build_and_reset() for k, v in cur_slice.items()}, cur_slice_size
1434
+ )
1435
+ cur_slice_size = 0
1436
+ cur_slice.clear()
1437
+ finished_slices.append(batch)
1438
+
1439
+ # For each unique env timestep.
1440
+ for _, group in itertools.groupby(steps, lambda x: x[:2]):
1441
+ # Accumulate into the current slice.
1442
+ for _, _, _, policy_id, row in group:
1443
+ cur_slice[policy_id].add_values(**row)
1444
+ cur_slice_size += 1
1445
+ # Slice has reached target number of env steps.
1446
+ if cur_slice_size >= k:
1447
+ finish_slice()
1448
+ assert cur_slice_size == 0
1449
+
1450
+ if cur_slice_size > 0:
1451
+ finish_slice()
1452
+
1453
+ assert len(finished_slices) > 0, finished_slices
1454
+ return finished_slices
1455
+
1456
+ @staticmethod
1457
+ @PublicAPI
1458
+ def wrap_as_needed(
1459
+ policy_batches: Dict[PolicyID, SampleBatch], env_steps: int
1460
+ ) -> Union[SampleBatch, "MultiAgentBatch"]:
1461
+ """Returns SampleBatch or MultiAgentBatch, depending on given policies.
1462
+ If policy_batches is empty (i.e. {}) it returns an empty MultiAgentBatch.
1463
+
1464
+ Args:
1465
+ policy_batches: Mapping from policy ids to SampleBatch.
1466
+ env_steps: Number of env steps in the batch.
1467
+
1468
+ Returns:
1469
+ The single default policy's SampleBatch or a MultiAgentBatch
1470
+ (more than one policy).
1471
+ """
1472
+ if len(policy_batches) == 1 and DEFAULT_POLICY_ID in policy_batches:
1473
+ return policy_batches[DEFAULT_POLICY_ID]
1474
+ return MultiAgentBatch(policy_batches=policy_batches, env_steps=env_steps)
1475
+
1476
+ @staticmethod
1477
+ @PublicAPI
1478
+ @Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=True)
1479
+ def concat_samples(samples: List["MultiAgentBatch"]) -> "MultiAgentBatch":
1480
+ return concat_samples_into_ma_batch(samples)
1481
+
1482
+ @PublicAPI
1483
+ def copy(self) -> "MultiAgentBatch":
1484
+ """Deep-copies self into a new MultiAgentBatch.
1485
+
1486
+ Returns:
1487
+ The copy of self with deep-copied data.
1488
+ """
1489
+ return MultiAgentBatch(
1490
+ {k: v.copy() for (k, v) in self.policy_batches.items()}, self.count
1491
+ )
1492
+
1493
+ @ExperimentalAPI
1494
+ def to_device(self, device, framework="torch"):
1495
+ """TODO: transfer batch to given device as framework tensor."""
1496
+ if framework == "torch":
1497
+ assert torch is not None
1498
+ for pid, policy_batch in self.policy_batches.items():
1499
+ self.policy_batches[pid] = policy_batch.to_device(
1500
+ device, framework=framework
1501
+ )
1502
+ else:
1503
+ raise NotImplementedError
1504
+ return self
1505
+
1506
+ @PublicAPI
1507
+ def size_bytes(self) -> int:
1508
+ """
1509
+ Returns:
1510
+ The overall size in bytes of all policy batches (all columns).
1511
+ """
1512
+ return sum(b.size_bytes() for b in self.policy_batches.values())
1513
+
1514
+ @DeveloperAPI
1515
+ def compress(
1516
+ self, bulk: bool = False, columns: Set[str] = frozenset(["obs", "new_obs"])
1517
+ ) -> None:
1518
+ """Compresses each policy batch (per column) in place.
1519
+
1520
+ Args:
1521
+ bulk: Whether to compress across the batch dimension (0)
1522
+ as well. If False will compress n separate list items, where n
1523
+ is the batch size.
1524
+ columns: Set of column names to compress.
1525
+ """
1526
+ for batch in self.policy_batches.values():
1527
+ batch.compress(bulk=bulk, columns=columns)
1528
+
1529
+ @DeveloperAPI
1530
+ def decompress_if_needed(
1531
+ self, columns: Set[str] = frozenset(["obs", "new_obs"])
1532
+ ) -> "MultiAgentBatch":
1533
+ """Decompresses each policy batch (per column), if already compressed.
1534
+
1535
+ Args:
1536
+ columns: Set of column names to decompress.
1537
+
1538
+ Returns:
1539
+ Self.
1540
+ """
1541
+ for batch in self.policy_batches.values():
1542
+ batch.decompress_if_needed(columns)
1543
+ return self
1544
+
1545
+ @DeveloperAPI
1546
+ def as_multi_agent(self) -> "MultiAgentBatch":
1547
+ """Simply returns `self` (already a MultiAgentBatch).
1548
+
1549
+ Returns:
1550
+ This very instance of MultiAgentBatch.
1551
+ """
1552
+ return self
1553
+
1554
    def __getitem__(self, key: str) -> SampleBatch:
        """Returns the SampleBatch stored under the given policy id.

        Args:
            key: The policy ID whose batch to return.

        Raises:
            KeyError: If no batch exists for `key`.
        """
        return self.policy_batches[key]
1557
+
1558
+ def __str__(self):
1559
+ return "MultiAgentBatch({}, env_steps={})".format(
1560
+ str(self.policy_batches), self.count
1561
+ )
1562
+
1563
+ def __repr__(self):
1564
+ return "MultiAgentBatch({}, env_steps={})".format(
1565
+ str(self.policy_batches), self.count
1566
+ )
1567
+
1568
+
1569
@PublicAPI
def concat_samples(samples: List[SampleBatchType]) -> SampleBatchType:
    """Concatenates a list of SampleBatches or MultiAgentBatches.

    If all items in the list are SampleBatch types, the output will be
    a SampleBatch type. Otherwise, the output will be a MultiAgentBatch type.
    If input is a mixture of SampleBatch and MultiAgentBatch types, it will treat
    SampleBatch objects as MultiAgentBatch types with 'default_policy' key and
    concatenate it with the rest of the MultiAgentBatch objects.
    Empty samples are simply ignored.

    Args:
        samples: List of SampleBatches or MultiAgentBatches to be
            concatenated.

    Returns:
        A new (concatenated) SampleBatch or MultiAgentBatch.

    .. testcode::
        :skipif: True

        import numpy as np
        from ray.rllib.policy.sample_batch import SampleBatch
        b1 = SampleBatch({"a": np.array([1, 2]),
                          "b": np.array([10, 11])})
        b2 = SampleBatch({"a": np.array([3]),
                          "b": np.array([12])})
        print(concat_samples([b1, b2]))


        c1 = MultiAgentBatch({'default_policy': {
            "a": np.array([1, 2]),
            "b": np.array([10, 11])
        }}, env_steps=2)
        c2 = SampleBatch({"a": np.array([3]),
                          "b": np.array([12])})
        print(concat_samples([c1, c2]))

    .. testoutput::

        {"a": np.array([1, 2, 3]), "b": np.array([10, 11, 12])}
        MultiAgentBatch = {'default_policy': {"a": np.array([1, 2, 3]),
                                              "b": np.array([10, 11, 12])}}

    """

    # Any multi-agent input forces the (more general) multi-agent output path.
    if any(isinstance(s, MultiAgentBatch) for s in samples):
        return concat_samples_into_ma_batch(samples)

    # the output is a SampleBatch type
    concatd_seq_lens = []
    concatd_num_grad_updates = [0, 0.0]  # [0]=count; [1]=weighted sum values
    concated_samples = []
    # Make sure these settings are consistent amongst all batches.
    zero_padded = max_seq_len = time_major = None
    for s in samples:
        # Skip empty batches entirely.
        if s.count <= 0:
            continue

        # First non-empty batch: adopt its settings as the reference.
        if max_seq_len is None:
            zero_padded = s.zero_padded
            max_seq_len = s.max_seq_len
            time_major = s.time_major

        # Make sure these settings are consistent amongst all batches.
        if s.zero_padded != zero_padded or s.time_major != time_major:
            raise ValueError(
                "All SampleBatches' `zero_padded` and `time_major` settings "
                "must be consistent!"
            )
        if (
            s.max_seq_len is None or max_seq_len is None
        ) and s.max_seq_len != max_seq_len:
            raise ValueError(
                "Samples must consistently either provide or omit " "`max_seq_len`!"
            )
        elif zero_padded and s.max_seq_len != max_seq_len:
            raise ValueError(
                "For `zero_padded` SampleBatches, the values of `max_seq_len` "
                "must be consistent!"
            )

        # For non-zero-padded batches, the result's max_seq_len is the max
        # over all inputs.
        if max_seq_len is not None:
            max_seq_len = max(max_seq_len, s.max_seq_len)
        if s.get(SampleBatch.SEQ_LENS) is not None:
            concatd_seq_lens.extend(s[SampleBatch.SEQ_LENS])
        # Accumulate (count, count-weighted sum) for the weighted average below.
        if s.num_grad_updates is not None:
            concatd_num_grad_updates[0] += s.count
            concatd_num_grad_updates[1] += s.num_grad_updates * s.count

        concated_samples.append(s)

    # If we don't have any samples (0 or only empty SampleBatches),
    # return an empty SampleBatch here.
    if len(concated_samples) == 0:
        return SampleBatch()

    # Collect the concat'd data.
    concatd_data = {}

    for k in concated_samples[0].keys():
        if k == SampleBatch.INFOS:
            # INFOS columns are concatenated as-is (no tree-mapping over
            # nested structures).
            concatd_data[k] = _concat_values(
                *[s[k] for s in concated_samples],
                time_major=time_major,
            )
        else:
            # All other columns may hold (nested) structures -> concat leafwise.
            values_to_concat = [c[k] for c in concated_samples]
            _concat_values_w_time = partial(_concat_values, time_major=time_major)
            concatd_data[k] = tree.map_structure(
                _concat_values_w_time, *values_to_concat
            )

    # If seq-lens were framework tensors, re-wrap the collected python list
    # into the matching framework tensor type.
    if concatd_seq_lens != [] and torch and torch.is_tensor(concatd_seq_lens[0]):
        concatd_seq_lens = torch.Tensor(concatd_seq_lens)
    elif concatd_seq_lens != [] and tf and tf.is_tensor(concatd_seq_lens[0]):
        concatd_seq_lens = tf.convert_to_tensor(concatd_seq_lens)

    # Return a new (concat'd) SampleBatch.
    return SampleBatch(
        concatd_data,
        seq_lens=concatd_seq_lens,
        _time_major=time_major,
        _zero_padded=zero_padded,
        _max_seq_len=max_seq_len,
        # Compute weighted average of the num_grad_updates for the batches
        # (assuming they all come from the same policy).
        _num_grad_updates=(
            concatd_num_grad_updates[1] / (concatd_num_grad_updates[0] or 1.0)
        ),
    )
1700
+
1701
+
1702
@PublicAPI
def concat_samples_into_ma_batch(samples: List[SampleBatchType]) -> "MultiAgentBatch":
    """Concatenates SampleBatches/MultiAgentBatches into one MultiAgentBatch.

    As opposed to `concat_samples()`, this always produces a MultiAgentBatch
    (the more general type), even if all inputs are plain SampleBatches.
    Empty SampleBatches are ignored.

    Args:
        samples: The list of SampleBatches or MultiAgentBatches to
            concatenate.

    Returns:
        A new (concatenated) MultiAgentBatch.

    Raises:
        ValueError: If an item is neither a SampleBatch nor a
            MultiAgentBatch.
    """
    per_policy_batches = collections.defaultdict(list)
    total_env_steps = 0

    for sample in samples:
        # Some batches in `samples` may be plain SampleBatches.
        if isinstance(sample, SampleBatch):
            # Empty SampleBatch: ok (just ignore).
            if len(sample) <= 0:
                continue
            # Non-empty: promote to multi-agent format and continue below.
            sample = sample.as_multi_agent()
        elif not isinstance(sample, MultiAgentBatch):
            raise ValueError(
                "`concat_samples_into_ma_batch` can only concat "
                f"SampleBatch|MultiAgentBatch objects, not {type(sample).__name__}!"
            )

        # Group per policy id; track total env steps.
        for policy_id, batch in sample.policy_batches.items():
            per_policy_batches[policy_id].append(batch)
        total_env_steps += sample.env_steps()

    # Concat each policy's list of batches into a single batch.
    concatd = {
        policy_id: concat_samples(batches)
        for policy_id, batches in per_policy_batches.items()
    }
    return MultiAgentBatch(concatd, total_env_steps)
1763
+
1764
+
1765
+ def _concat_values(*values, time_major=None) -> TensorType:
1766
+ """Concatenates a list of values.
1767
+
1768
+ Args:
1769
+ values: The values to concatenate.
1770
+ time_major: Whether to concatenate along the first axis
1771
+ (time_major=False) or the second axis (time_major=True).
1772
+ """
1773
+ if torch and torch.is_tensor(values[0]):
1774
+ return torch.cat(values, dim=1 if time_major else 0)
1775
+ elif isinstance(values[0], np.ndarray):
1776
+ return np.concatenate(values, axis=1 if time_major else 0)
1777
+ elif tf and tf.is_tensor(values[0]):
1778
+ return tf.concat(values, axis=1 if time_major else 0)
1779
+ elif isinstance(values[0], list):
1780
+ concatenated_list = []
1781
+ for sublist in values:
1782
+ concatenated_list.extend(sublist)
1783
+ return concatenated_list
1784
+ else:
1785
+ raise ValueError(
1786
+ f"Unsupported type for concatenation: {type(values[0])} "
1787
+ f"first element: {values[0]}"
1788
+ )
1789
+
1790
+
1791
@DeveloperAPI
def convert_ma_batch_to_sample_batch(batch: SampleBatchType) -> SampleBatch:
    """Converts a MultiAgentBatch to a SampleBatch, if necessary.

    Args:
        batch: The batch to convert. Non-MultiAgentBatch inputs are
            returned unchanged.

    Returns:
        The (possibly unwrapped) SampleBatch.

    Raises:
        ValueError: If the MultiAgentBatch has more than one policy_id
            or if the policy_id is not `DEFAULT_POLICY_ID`.
    """
    if not isinstance(batch, MultiAgentBatch):
        return batch

    policy_keys = batch.policy_batches.keys()
    if len(policy_keys) != 1 or DEFAULT_POLICY_ID not in policy_keys:
        raise ValueError(
            "RLlib tried to convert a multi agent-batch with data from more "
            "than one policy to a single-agent batch. This is not supported and "
            "may be due to a number of issues. Here are two possible ones:"
            "1) Off-Policy Estimation is not implemented for "
            "multi-agent batches. You can set `off_policy_estimation_methods: {}` "
            "to resolve this."
            "2) Loading multi-agent data for offline training is not implemented."
            "Load single-agent data instead to resolve this."
        )
    return batch.policy_batches[DEFAULT_POLICY_ID]
.venv/lib/python3.11/site-packages/ray/rllib/policy/tf_mixins.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Dict, List
3
+
4
+ import numpy as np
5
+
6
+
7
+ from ray.rllib.models.modelv2 import ModelV2
8
+ from ray.rllib.policy.eager_tf_policy import EagerTFPolicy
9
+ from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
10
+ from ray.rllib.policy.policy import PolicyState
11
+ from ray.rllib.policy.sample_batch import SampleBatch
12
+ from ray.rllib.policy.tf_policy import TFPolicy
13
+ from ray.rllib.utils.annotations import OldAPIStack
14
+ from ray.rllib.utils.framework import get_variable, try_import_tf
15
+ from ray.rllib.utils.schedules import PiecewiseSchedule
16
+ from ray.rllib.utils.tf_utils import make_tf_callable
17
+ from ray.rllib.utils.typing import (
18
+ AlgorithmConfigDict,
19
+ LocalOptimizer,
20
+ ModelGradients,
21
+ TensorType,
22
+ )
23
+
24
+
25
+ logger = logging.getLogger(__name__)
26
+ tf1, tf, tfv = try_import_tf()
27
+
28
+
29
@OldAPIStack
class LearningRateSchedule:
    """Mixin for TFPolicy that adds a learning rate schedule."""

    def __init__(self, lr, lr_schedule):
        # Args:
        #     lr: The constant learning rate (used if no schedule given).
        #     lr_schedule: Optional list of [timestep, lr-value] pairs
        #         defining a piecewise schedule, or None for a constant lr.
        self._lr_schedule = None
        if lr_schedule is None:
            # Constant lr: held in a non-trainable tf variable.
            self.cur_lr = tf1.get_variable("lr", initializer=lr, trainable=False)
        else:
            # Past the last schedule point, keep returning the final value.
            self._lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1], framework=None
            )
            self.cur_lr = tf1.get_variable(
                "lr", initializer=self._lr_schedule.value(0), trainable=False
            )
            if self.framework == "tf":
                # Static-graph mode: a placeholder + assign op is needed so
                # the variable can be updated via session.run().
                self._lr_placeholder = tf1.placeholder(dtype=tf.float32, name="lr")
                self._lr_update = self.cur_lr.assign(
                    self._lr_placeholder, read_value=False
                )

    def on_global_var_update(self, global_vars):
        # Recomputes the current lr from the schedule whenever global vars
        # (in particular `timestep`) are updated.
        super().on_global_var_update(global_vars)
        if self._lr_schedule is not None:
            new_val = self._lr_schedule.value(global_vars["timestep"])
            if self.framework == "tf":
                self.get_session().run(
                    self._lr_update, feed_dict={self._lr_placeholder: new_val}
                )
            else:
                self.cur_lr.assign(new_val, read_value=False)
                # This property (self._optimizer) is (still) accessible for
                # both TFPolicy and any TFPolicy_eager.
                self._optimizer.learning_rate.assign(self.cur_lr)

    def optimizer(self):
        # Returns an Adam optimizer wired to the (possibly scheduled) lr.
        if self.framework == "tf":
            return tf1.train.AdamOptimizer(learning_rate=self.cur_lr)
        else:
            return tf.keras.optimizers.Adam(self.cur_lr)
69
+
70
+
71
@OldAPIStack
class EntropyCoeffSchedule:
    """Mixin for TFPolicy that adds entropy coeff decay."""

    def __init__(self, entropy_coeff, entropy_coeff_schedule):
        # Args:
        #     entropy_coeff: The (initial) entropy coefficient value.
        #     entropy_coeff_schedule: None (constant coeff), a list of
        #         [timestep, value] pairs (piecewise schedule), or a single
        #         number interpreted as the final timestep of a linear decay
        #         from `entropy_coeff` down to 0.0.
        self._entropy_coeff_schedule = None
        if entropy_coeff_schedule is None:
            self.entropy_coeff = get_variable(
                entropy_coeff, framework="tf", tf_name="entropy_coeff", trainable=False
            )
        else:
            # Allows for custom schedule similar to lr_schedule format
            if isinstance(entropy_coeff_schedule, list):
                self._entropy_coeff_schedule = PiecewiseSchedule(
                    entropy_coeff_schedule,
                    outside_value=entropy_coeff_schedule[-1][-1],
                    framework=None,
                )
            else:
                # Implements previous version but enforces outside_value
                self._entropy_coeff_schedule = PiecewiseSchedule(
                    [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
                    outside_value=0.0,
                    framework=None,
                )

            self.entropy_coeff = get_variable(
                self._entropy_coeff_schedule.value(0),
                framework="tf",
                tf_name="entropy_coeff",
                trainable=False,
            )
            if self.framework == "tf":
                # Static-graph mode: placeholder + assign op for updating
                # the coeff via session.run().
                self._entropy_coeff_placeholder = tf1.placeholder(
                    dtype=tf.float32, name="entropy_coeff"
                )
                self._entropy_coeff_update = self.entropy_coeff.assign(
                    self._entropy_coeff_placeholder, read_value=False
                )

    def on_global_var_update(self, global_vars):
        # Recomputes the entropy coeff from the schedule whenever global
        # vars (in particular `timestep`) are updated.
        super().on_global_var_update(global_vars)
        if self._entropy_coeff_schedule is not None:
            new_val = self._entropy_coeff_schedule.value(global_vars["timestep"])
            if self.framework == "tf":
                self.get_session().run(
                    self._entropy_coeff_update,
                    feed_dict={self._entropy_coeff_placeholder: new_val},
                )
            else:
                self.entropy_coeff.assign(new_val, read_value=False)
122
+
123
+
124
@OldAPIStack
class KLCoeffMixin:
    """Assigns the `update_kl()` and other KL-related methods to a TFPolicy.

    This is used in Algorithms to update the KL coefficient after each
    learning step based on `config.kl_target` and the measured KL value
    (from the train_batch).
    """

    def __init__(self, config: AlgorithmConfigDict):
        # The current KL value (as python float).
        self.kl_coeff_val = config["kl_coeff"]
        # The current KL value (as tf Variable for in-graph operations).
        self.kl_coeff = get_variable(
            float(self.kl_coeff_val),
            tf_name="kl_coeff",
            trainable=False,
            framework=config["framework"],
        )
        # Constant target value.
        self.kl_target = config["kl_target"]
        if self.framework == "tf":
            # Static-graph mode: placeholder + assign op for updating the
            # coeff via session.run().
            self._kl_coeff_placeholder = tf1.placeholder(
                dtype=tf.float32, name="kl_coeff"
            )
            self._kl_coeff_update = self.kl_coeff.assign(
                self._kl_coeff_placeholder, read_value=False
            )

    def update_kl(self, sampled_kl):
        # Update the current KL value based on the recently measured value.
        # Increase.
        if sampled_kl > 2.0 * self.kl_target:
            self.kl_coeff_val *= 1.5
        # Decrease.
        elif sampled_kl < 0.5 * self.kl_target:
            self.kl_coeff_val *= 0.5
        # No change -> return early (skips the graph-variable sync below,
        # which is only needed when the value actually changed).
        else:
            return self.kl_coeff_val

        # Make sure, new value is also stored in graph/tf variable.
        self._set_kl_coeff(self.kl_coeff_val)

        # Return the current KL value.
        return self.kl_coeff_val

    def _set_kl_coeff(self, new_kl_coeff):
        # Set the (off graph) value.
        self.kl_coeff_val = new_kl_coeff

        # Update the tf/tf2 Variable (via session call for tf or `assign`).
        if self.framework == "tf":
            self.get_session().run(
                self._kl_coeff_update,
                feed_dict={self._kl_coeff_placeholder: self.kl_coeff_val},
            )
        else:
            self.kl_coeff.assign(self.kl_coeff_val, read_value=False)

    def get_state(self) -> PolicyState:
        state = super().get_state()
        # Add current kl-coeff value so it survives checkpointing.
        state["current_kl_coeff"] = self.kl_coeff_val
        return state

    def set_state(self, state: PolicyState) -> None:
        # Set current kl-coeff value first (falling back to the config's
        # initial value if the checkpoint predates this state entry).
        self._set_kl_coeff(state.pop("current_kl_coeff", self.config["kl_coeff"]))
        # Call super's set_state with rest of the state dict.
        super().set_state(state)
195
+
196
+
197
@OldAPIStack
class TargetNetworkMixin:
    """Assign the `update_target` method to the policy.

    The function is called every `target_network_update_freq` steps by the
    master learner.
    """

    def __init__(self):
        model_vars = self.model.trainable_variables()
        target_model_vars = self.target_model.trainable_variables()

        @make_tf_callable(self.get_session())
        def update_target_fn(tau):
            # Soft update: target <- tau * model + (1 - tau) * target.
            # tau=1.0 is a hard copy of the model's weights.
            tau = tf.convert_to_tensor(tau, dtype=tf.float32)
            update_target_expr = []
            assert len(model_vars) == len(target_model_vars), (
                model_vars,
                target_model_vars,
            )
            for var, var_target in zip(model_vars, target_model_vars):
                update_target_expr.append(
                    var_target.assign(tau * var + (1.0 - tau) * var_target)
                )
                logger.debug("Update target op {}".format(var_target))
            return tf.group(*update_target_expr)

        # Hard initial update.
        self._do_update = update_target_fn
        # TODO: The previous SAC implementation does an update(1.0) here.
        # If this is changed to tau != 1.0 the sac_loss_function test fails. Why?
        # Also the test is not very maintainable, we need to change that unittest
        # anyway.
        self.update_target(tau=1.0)  # self.config.get("tau", 1.0))

    @property
    def q_func_vars(self):
        # Lazily cached list of the (online) model's variables.
        if not hasattr(self, "_q_func_vars"):
            self._q_func_vars = self.model.variables()
        return self._q_func_vars

    @property
    def target_q_func_vars(self):
        # Lazily cached list of the target model's variables.
        if not hasattr(self, "_target_q_func_vars"):
            self._target_q_func_vars = self.target_model.variables()
        return self._target_q_func_vars

    # Support both hard and soft sync.
    def update_target(self, tau: float = None) -> None:
        # NOTE(review): `tau or ...` makes an explicit tau=0.0 fall back to
        # the config value — presumably intentional (0.0 would be a no-op
        # update anyway); confirm before relying on tau=0.0.
        self._do_update(np.float32(tau or self.config.get("tau", 1.0)))

    def variables(self) -> List[TensorType]:
        # Only the online model's variables (target vars excluded).
        return self.model.variables()

    def set_weights(self, weights):
        # Dispatch to the concrete policy base class's set_weights ...
        if isinstance(self, TFPolicy):
            TFPolicy.set_weights(self, weights)
        elif isinstance(self, EagerTFPolicyV2):  # Handle TF2V2 policies.
            EagerTFPolicyV2.set_weights(self, weights)
        elif isinstance(self, EagerTFPolicy):  # Handle TF2 policies.
            EagerTFPolicy.set_weights(self, weights)
        # ... then sync the target net with the configured tau.
        self.update_target(self.config.get("tau", 1.0))
259
+
260
+
261
@OldAPIStack
class ValueNetworkMixin:
    """Assigns the `_value()` method to a TFPolicy.

    This way, Policy can call `_value()` to get the current VF estimate on a
    single(!) observation (as done in `postprocess_trajectory_fn`).
    Note: When doing this, an actual forward pass is being performed.
    This is different from only calling `model.value_function()`, where
    the result of the most recent forward pass is being used to return an
    already calculated tensor.
    """

    def __init__(self, config):
        # When doing GAE or vtrace, we need the value function estimate on the
        # observation.
        if config.get("use_gae") or config.get("vtrace"):
            # Input dict is provided to us automatically via the Model's
            # requirements. It's a single-timestep (last one in trajectory)
            # input_dict.
            @make_tf_callable(self.get_session())
            def value(**input_dict):
                input_dict = SampleBatch(input_dict)
                if isinstance(self.model, tf.keras.Model):
                    # Keras models return extra fetches as third output (dict).
                    _, _, extra_outs = self.model(input_dict)
                    return extra_outs[SampleBatch.VF_PREDS][0]
                else:
                    model_out, _ = self.model(input_dict)
                    # [0] = remove the batch dim.
                    return self.model.value_function()[0]

        # When not doing GAE, we do not require the value function's output.
        else:

            @make_tf_callable(self.get_session())
            def value(*args, **kwargs):
                return tf.constant(0.0)

        self._value = value
        # Only TF1 static graph needs the extra-action-fetch caching below.
        self._should_cache_extra_action = config["framework"] == "tf"
        self._cached_extra_action_fetches = None

    def _extra_action_out_impl(self) -> Dict[str, TensorType]:
        # Collects the base class's extra action fetches and adds VF_PREDS.
        extra_action_out = super().extra_action_out_fn()
        # Keras models return values for each call in third return argument
        # (dict).
        if isinstance(self.model, tf.keras.Model):
            return extra_action_out
        # Return value function outputs. VF estimates will hence be added to the
        # SampleBatches produced by the sampler(s) to generate the train batches
        # going into the loss function.
        extra_action_out.update(
            {
                SampleBatch.VF_PREDS: self.model.value_function(),
            }
        )
        return extra_action_out

    def extra_action_out_fn(self) -> Dict[str, TensorType]:
        if not self._should_cache_extra_action:
            return self._extra_action_out_impl()

        # Note: there are 2 reasons we are caching the extra_action_fetches for
        # TF1 static graph here.
        # 1. for better performance, so we don't query base class and model for
        # extra fetches every single time.
        # 2. for correctness. TF1 is special because the static graph may contain
        # two logical graphs. One created by DynamicTFPolicy for action
        # computation, and one created by MultiGPUTower for GPU training.
        # Depending on which logical graph ran last time,
        # self.model.value_function() will point to the output tensor
        # of the specific logical graph, causing problem if we try to
        # fetch action (run inference) using the training output tensor.
        # For that reason, we cache the action output tensor from the
        # vanilla DynamicTFPolicy once and call it a day.
        if self._cached_extra_action_fetches is not None:
            return self._cached_extra_action_fetches

        self._cached_extra_action_fetches = self._extra_action_out_impl()
        return self._cached_extra_action_fetches
340
+
341
+
342
@OldAPIStack
class GradStatsMixin:
    """Mixin adding a `grad_stats_fn()` that reports gradient global norm(s)."""

    def __init__(self):
        pass

    def grad_stats_fn(
        self, train_batch: SampleBatch, grads: ModelGradients
    ) -> Dict[str, TensorType]:
        """Returns a stats dict with the global norm of the given gradients."""
        if self.config.get("_tf_policy_handles_more_than_one_loss"):
            # More than one loss term: `grads` is a list of grad-lists, one
            # per loss -> report one global norm per loss.
            gnorm = [tf.linalg.global_norm(per_loss_grads) for per_loss_grads in grads]
        else:
            # Single loss term: `grads` is one flat list of gradients.
            gnorm = tf.linalg.global_norm(grads)

        return {"grad_gnorm": gnorm}
361
+
362
+
363
def compute_gradients(
    policy, optimizer: LocalOptimizer, loss: TensorType
) -> ModelGradients:
    """Computes (and optionally clips) gradients of `loss` for `policy`'s model.

    Args:
        policy: The Policy whose model variables the loss is differentiated
            with respect to.
        optimizer: The local (tf1-style) optimizer used to compute gradients.
        loss: The loss tensor to differentiate.

    Returns:
        A list of (gradient, variable) tuples. If `grad_clip` is set in the
        policy's config, gradients are clipped by global norm and NaN
        gradients (resulting from an inf global norm, i.e. super large
        losses) are replaced with zeros.
    """
    # Compute the gradients.
    variables = policy.model.trainable_variables
    if isinstance(policy.model, ModelV2):
        # ModelV2 exposes trainable_variables as a method, not an attribute.
        variables = variables()
    grads_and_vars = optimizer.compute_gradients(loss, variables)

    # No clipping requested -> return the raw (grad, var) pairs.
    if policy.config.get("grad_clip") is None:
        return grads_and_vars

    # Clip by global norm.
    grads = [g for (g, v) in grads_and_vars]
    grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"])
    # If the global_norm is inf -> All grads will be NaN. Stabilize this
    # here by setting them to 0.0. This will simply ignore destructive loss
    # calculations. Also stored on `policy.grads` (side effect kept for
    # backward compatibility).
    policy.grads = [
        tf.where(tf.math.is_nan(g), tf.zeros_like(g), g) if g is not None else None
        for g in grads
    ]
    return list(zip(policy.grads, variables))
.venv/lib/python3.11/site-packages/ray/rllib/policy/tf_policy.py ADDED
@@ -0,0 +1,1200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ from typing import Dict, List, Optional, Tuple, Union
4
+
5
+ import gymnasium as gym
6
+ import numpy as np
7
+ import tree # pip install dm_tree
8
+
9
+ import ray
10
+ import ray.experimental.tf_utils
11
+ from ray.rllib.models.modelv2 import ModelV2
12
+ from ray.rllib.policy.policy import Policy, PolicyState, PolicySpec
13
+ from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
14
+ from ray.rllib.policy.sample_batch import SampleBatch
15
+ from ray.rllib.utils import force_list
16
+ from ray.rllib.utils.annotations import OldAPIStack, override
17
+ from ray.rllib.utils.debug import summarize
18
+ from ray.rllib.utils.deprecation import Deprecated
19
+ from ray.rllib.utils.error import ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL
20
+ from ray.rllib.utils.framework import try_import_tf
21
+ from ray.rllib.utils.metrics import (
22
+ DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
23
+ NUM_AGENT_STEPS_TRAINED,
24
+ NUM_GRAD_UPDATES_LIFETIME,
25
+ )
26
+ from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
27
+ from ray.rllib.utils.spaces.space_utils import normalize_action
28
+ from ray.rllib.utils.tf_run_builder import _TFRunBuilder
29
+ from ray.rllib.utils.tf_utils import get_gpu_devices
30
+ from ray.rllib.utils.typing import (
31
+ AlgorithmConfigDict,
32
+ LocalOptimizer,
33
+ ModelGradients,
34
+ TensorType,
35
+ )
36
+ from ray.util.debug import log_once
37
+
38
+ tf1, tf, tfv = try_import_tf()
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
+ @OldAPIStack
43
+ class TFPolicy(Policy):
44
+ """An agent policy and loss implemented in TensorFlow.
45
+
46
+ Do not sub-class this class directly (neither should you sub-class
47
+ DynamicTFPolicy), but rather use
48
+ rllib.policy.tf_policy_template.build_tf_policy
49
+ to generate your custom tf (graph-mode or eager) Policy classes.
50
+
51
+ Extending this class enables RLlib to perform TensorFlow specific
52
+ optimizations on the policy, e.g., parallelization across gpus or
53
+ fusing multiple graphs together in the multi-agent setting.
54
+
55
+ Input tensors are typically shaped like [BATCH_SIZE, ...].
56
+
57
+ .. testcode::
58
+ :skipif: True
59
+
60
+ from ray.rllib.policy import TFPolicy
61
+ class TFPolicySubclass(TFPolicy):
62
+ ...
63
+
64
+ sess, obs_input, sampled_action, loss, loss_inputs = ...
65
+ policy = TFPolicySubclass(
66
+ sess, obs_input, sampled_action, loss, loss_inputs)
67
+ print(policy.compute_actions([1, 0, 2]))
68
+ print(policy.postprocess_trajectory(SampleBatch({...})))
69
+
70
+ .. testoutput::
71
+
72
+ (array([0, 1, 1]), [], {})
73
+ SampleBatch({"action": ..., "advantages": ..., ...})
74
+
75
+ """
76
+
77
+ # In order to create tf_policies from checkpoints, this class needs to separate
78
+ # variables into their own scopes. Normally, we would do this in the model
79
+ # catalog, but since Policy.from_state() can be called anywhere, we need to
80
+ # keep track of it here to not break the from_state API.
81
+ tf_var_creation_scope_counter = 0
82
+
83
+ @staticmethod
84
+ def next_tf_var_scope_name():
85
+ # Tracks multiple instances that are spawned from this policy via .from_state()
86
+ TFPolicy.tf_var_creation_scope_counter += 1
87
+ return f"var_scope_{TFPolicy.tf_var_creation_scope_counter}"
88
+
89
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
        sess: "tf1.Session",
        obs_input: TensorType,
        sampled_action: TensorType,
        loss: Union[TensorType, List[TensorType]],
        loss_inputs: List[Tuple[str, TensorType]],
        model: Optional[ModelV2] = None,
        sampled_action_logp: Optional[TensorType] = None,
        action_input: Optional[TensorType] = None,
        log_likelihood: Optional[TensorType] = None,
        dist_inputs: Optional[TensorType] = None,
        dist_class: Optional[type] = None,
        state_inputs: Optional[List[TensorType]] = None,
        state_outputs: Optional[List[TensorType]] = None,
        prev_action_input: Optional[TensorType] = None,
        prev_reward_input: Optional[TensorType] = None,
        seq_lens: Optional[TensorType] = None,
        max_seq_len: int = 20,
        batch_divisibility_req: int = 1,
        update_ops: List[TensorType] = None,
        explore: Optional[TensorType] = None,
        timestep: Optional[TensorType] = None,
    ):
        """Initializes a Policy object.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: Policy-specific configuration data.
            sess: The TensorFlow session to use.
            obs_input: Input placeholder for observations, of shape
                [BATCH_SIZE, obs...].
            sampled_action: Tensor for sampling an action, of shape
                [BATCH_SIZE, action...]
            loss: Scalar policy loss output tensor or a list thereof
                (in case there is more than one loss).
            loss_inputs: A (name, placeholder) tuple for each loss input
                argument. Each placeholder name must
                correspond to a SampleBatch column key returned by
                postprocess_trajectory(), and has shape [BATCH_SIZE, data...].
                These keys will be read from postprocessed sample batches and
                fed into the specified placeholders during loss computation.
            model: The optional ModelV2 to use for calculating actions and
                losses. If not None, TFPolicy will provide functionality for
                getting variables, calling the model's custom loss (if
                provided), and importing weights into the model.
            sampled_action_logp: log probability of the sampled action.
            action_input: Input placeholder for actions for
                logp/log-likelihood calculations.
            log_likelihood: Tensor to calculate the log_likelihood (given
                action_input and obs_input).
            dist_class: An optional ActionDistribution class to use for
                generating a dist object from distribution inputs.
            dist_inputs: Tensor to calculate the distribution
                inputs/parameters.
            state_inputs: List of RNN state input Tensors.
            state_outputs: List of RNN state output Tensors.
            prev_action_input: placeholder for previous actions.
            prev_reward_input: placeholder for previous rewards.
            seq_lens: Placeholder for RNN sequence lengths, of shape
                [NUM_SEQUENCES].
                Note that NUM_SEQUENCES << BATCH_SIZE. See
                policy/rnn_sequencing.py for more information.
            max_seq_len: Max sequence length for LSTM training.
            batch_divisibility_req: pad all agent experiences batches to
                multiples of this value. This only has an effect if not using
                a LSTM model.
            update_ops: override the batchnorm update ops
                to run when applying gradients. Otherwise we run all update
                ops found in the current variable scope.
            explore: Placeholder for `explore` parameter into call to
                Exploration.get_exploration_action. Explicitly set this to
                False for not creating any Exploration component.
            timestep: Placeholder for the global sampling timestep.
        """
        # Note: assigned before the base-class constructor runs.
        self.framework = "tf"
        super().__init__(observation_space, action_space, config)

        # Get devices to build the graph on.
        num_gpus = self._get_num_gpus_for_policy()
        gpu_ids = get_gpu_devices()
        logger.info(f"Found {len(gpu_ids)} visible cuda devices.")

        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - no GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            self.devices = ["/cpu:0" for _ in range(int(math.ceil(num_gpus)) or 1)]
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray._private.worker._mode() == ray._private.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TFPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}."
                )

            self.devices = [f"/gpu:{i}" for i, _ in enumerate(gpu_ids) if i < num_gpus]

        # Disable env-info placeholder.
        if SampleBatch.INFOS in self.view_requirements:
            self.view_requirements[SampleBatch.INFOS].used_for_compute_actions = False
            self.view_requirements[SampleBatch.INFOS].used_for_training = False
            # Optionally add `infos` to the output dataset
            if self.config["output_config"].get("store_infos", False):
                self.view_requirements[SampleBatch.INFOS].used_for_training = True

        assert model is None or isinstance(model, (ModelV2, tf.keras.Model)), (
            "Model classes for TFPolicy other than `ModelV2|tf.keras.Model` "
            "not allowed! You passed in {}.".format(model)
        )
        self.model = model
        # Auto-update model's inference view requirements, if recurrent.
        if self.model is not None:
            self._update_model_view_requirements_from_init_state()

        # If `explore` is explicitly set to False, don't create an exploration
        # component.
        self.exploration = self._create_exploration() if explore is not False else None

        self._sess = sess
        self._obs_input = obs_input
        self._prev_action_input = prev_action_input
        self._prev_reward_input = prev_reward_input
        self._sampled_action = sampled_action
        self._is_training = self._get_is_training_placeholder()
        # If no `explore` placeholder given, fall back to a default-True one.
        self._is_exploring = (
            explore
            if explore is not None
            else tf1.placeholder_with_default(True, (), name="is_exploring")
        )
        self._sampled_action_logp = sampled_action_logp
        # prob = exp(logp); only defined if logp was given.
        self._sampled_action_prob = (
            tf.math.exp(self._sampled_action_logp)
            if self._sampled_action_logp is not None
            else None
        )
        self._action_input = action_input  # For logp calculations.
        self._dist_inputs = dist_inputs
        self.dist_class = dist_class
        self._cached_extra_action_out = None
        self._state_inputs = state_inputs or []
        self._state_outputs = state_outputs or []
        self._seq_lens = seq_lens
        self._max_seq_len = max_seq_len

        if self._state_inputs and self._seq_lens is None:
            raise ValueError(
                "seq_lens tensor must be given if state inputs are defined"
            )

        self._batch_divisibility_req = batch_divisibility_req
        self._update_ops = update_ops
        self._apply_op = None
        self._stats_fetches = {}
        # If no `timestep` placeholder given, fall back to a default-0 one.
        self._timestep = (
            timestep
            if timestep is not None
            else tf1.placeholder_with_default(
                tf.zeros((), dtype=tf.int64), (), name="timestep"
            )
        )

        self._optimizers: List[LocalOptimizer] = []
        # Backward compatibility and for some code shared with tf-eager Policy.
        self._optimizer = None

        self._grads_and_vars: Union[ModelGradients, List[ModelGradients]] = []
        self._grads: Union[ModelGradients, List[ModelGradients]] = []
        # Policy tf-variables (weights), whose values to get/set via
        # get_weights/set_weights.
        self._variables = None
        # Local optimizer(s)' tf-variables (e.g. state vars for Adam).
        # Will be stored alongside `self._variables` when checkpointing.
        # NOTE: Stays None until `_initialize_loss` has run.
        self._optimizer_variables: Optional[
            ray.experimental.tf_utils.TensorFlowVariables
        ] = None

        # The loss tf-op(s). Number of losses must match number of optimizers.
        self._losses = []
        # Backward compatibility (in case custom child TFPolicies access this
        # property).
        self._loss = None
        # A batch dict passed into loss function as input.
        self._loss_input_dict = {}
        losses = force_list(loss)
        if len(losses) > 0:
            self._initialize_loss(losses, loss_inputs)

        # The log-likelihood calculator op.
        self._log_likelihood = log_likelihood
        # If not provided, derive the logp op from dist inputs + dist class.
        if (
            self._log_likelihood is None
            and self._dist_inputs is not None
            and self.dist_class is not None
        ):
            self._log_likelihood = self.dist_class(self._dist_inputs, self.model).logp(
                self._action_input
            )
301
+
302
+ @override(Policy)
303
    def compute_actions_from_input_dict(
        self,
        input_dict: Union[SampleBatch, Dict[str, TensorType]],
        explore: bool = None,
        timestep: Optional[int] = None,
        episode=None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Computes actions via one session run over the built action graph.

        Also advances `self.global_timestep` by the inferred batch size.
        """
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        # Switch off is_training flag in our batch.
        if isinstance(input_dict, SampleBatch):
            input_dict.set_training(False)
        else:
            # Deprecated dict input.
            input_dict["is_training"] = False

        builder = _TFRunBuilder(self.get_session(), "compute_actions_from_input_dict")
        obs_batch = input_dict[SampleBatch.OBS]
        to_fetch = self._build_compute_actions(
            builder, input_dict=input_dict, explore=explore, timestep=timestep
        )

        # Execute session run to get action (and other fetches).
        fetched = builder.get(to_fetch)

        # Update our global timestep by the batch size.
        self.global_timestep += (
            len(obs_batch)
            if isinstance(obs_batch, list)
            else len(input_dict)
            if isinstance(input_dict, SampleBatch)
            else obs_batch.shape[0]
        )

        return fetched
340
+
341
+ @override(Policy)
342
    def compute_actions(
        self,
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorType], TensorType] = None,
        prev_reward_batch: Union[List[TensorType], TensorType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes=None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ):
        """Computes actions for the given individual input batches.

        Assembles an input dict from the separate batches, runs the action
        graph once, and advances `self.global_timestep` by the batch size.
        Note: `info_batch` and `episodes` are accepted but unused here.
        """
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        builder = _TFRunBuilder(self.get_session(), "compute_actions")

        input_dict = {SampleBatch.OBS: obs_batch, "is_training": False}
        if state_batches:
            for i, s in enumerate(state_batches):
                input_dict[f"state_in_{i}"] = s
        if prev_action_batch is not None:
            input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
        if prev_reward_batch is not None:
            input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch

        to_fetch = self._build_compute_actions(
            builder, input_dict=input_dict, explore=explore, timestep=timestep
        )

        # Execute session run to get action (and other fetches).
        fetched = builder.get(to_fetch)

        # Update our global timestep by the batch size.
        self.global_timestep += (
            len(obs_batch)
            if isinstance(obs_batch, list)
            else tree.flatten(obs_batch)[0].shape[0]
        )

        return fetched
383
+
384
+ @override(Policy)
385
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorType], TensorType],
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
        prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
        actions_normalized: bool = True,
        **kwargs,
    ) -> TensorType:
        """Computes log-likelihoods of the given actions via the logp graph op.

        Raises:
            ValueError: If no `_log_likelihood` op was built, or the number of
                given `state_batches` doesn't match the state placeholders.
        """
        if self._log_likelihood is None:
            raise ValueError(
                "Cannot compute log-prob/likelihood w/o a self._log_likelihood op!"
            )

        # Exploration hook before each forward pass.
        self.exploration.before_compute_actions(
            explore=False, tf_sess=self.get_session()
        )

        builder = _TFRunBuilder(self.get_session(), "compute_log_likelihoods")

        # Normalize actions if necessary.
        if actions_normalized is False and self.config["normalize_actions"]:
            actions = normalize_action(actions, self.action_space_struct)

        # Feed actions (for which we want logp values) into graph.
        builder.add_feed_dict({self._action_input: actions})
        # Feed observations.
        builder.add_feed_dict({self._obs_input: obs_batch})
        # Internal states.
        state_batches = state_batches or []
        if len(self._state_inputs) != len(state_batches):
            raise ValueError(
                "Must pass in RNN state batches for placeholders {}, got {}".format(
                    self._state_inputs, state_batches
                )
            )
        builder.add_feed_dict({k: v for k, v in zip(self._state_inputs, state_batches)})
        if state_batches:
            # Each input row is fed as its own sequence of length 1.
            builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
        # Prev-a and r.
        if self._prev_action_input is not None and prev_action_batch is not None:
            builder.add_feed_dict({self._prev_action_input: prev_action_batch})
        if self._prev_reward_input is not None and prev_reward_batch is not None:
            builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
        # Fetch the log_likelihoods output and return.
        fetches = builder.add_fetches([self._log_likelihood])
        return builder.get(fetches)[0]
434
+
435
+ @override(Policy)
436
    def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
        """Runs one gradient update on the batch and returns learner stats.

        Also increments `self.num_grad_updates` and attaches grad-update
        counters plus any `on_learn_on_batch` callback metrics to the stats.
        """
        assert self.loss_initialized()

        # Switch on is_training flag in our batch.
        postprocessed_batch.set_training(True)

        builder = _TFRunBuilder(self.get_session(), "learn_on_batch")

        # Callback handling.
        learn_stats = {}
        self.callbacks.on_learn_on_batch(
            policy=self, train_batch=postprocessed_batch, result=learn_stats
        )

        fetches = self._build_learn_on_batch(builder, postprocessed_batch)
        stats = builder.get(fetches)
        self.num_grad_updates += 1

        stats.update(
            {
                "custom_metrics": learn_stats,
                NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
                NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                # -1, b/c we have to measure this diff before we do the update above.
                DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                    self.num_grad_updates
                    - 1
                    - (postprocessed_batch.num_grad_updates or 0)
                ),
            }
        )

        return stats
469
+
470
+ @override(Policy)
471
    def compute_gradients(
        self, postprocessed_batch: SampleBatch
    ) -> Tuple[ModelGradients, Dict[str, TensorType]]:
        """Computes (but does not apply) gradients for the given train batch."""
        assert self.loss_initialized()
        # Switch on is_training flag in our batch.
        postprocessed_batch.set_training(True)
        builder = _TFRunBuilder(self.get_session(), "compute_gradients")
        fetches = self._build_compute_gradients(builder, postprocessed_batch)
        return builder.get(fetches)
480
+
481
+ @staticmethod
482
    def _tf1_from_state_helper(state: PolicyState) -> "Policy":
        """Recovers a TFPolicy from a state object.

        The `state` of an instantiated TFPolicy can be retrieved by calling its
        `get_state` method. Is meant to be used by the Policy.from_state() method to
        aid with tracking variable creation.

        Args:
            state: The state to recover a new TFPolicy instance from.

        Returns:
            A new TFPolicy instance.

        Raises:
            ValueError: If `state` contains no "policy_spec" key.
        """
        serialized_pol_spec: Optional[dict] = state.get("policy_spec")
        if serialized_pol_spec is None:
            raise ValueError(
                "No `policy_spec` key was found in given `state`! "
                "Cannot create new Policy."
            )
        pol_spec = PolicySpec.deserialize(serialized_pol_spec)

        # Build each recovered policy in its own, fresh variable scope (see
        # `next_tf_var_scope_name`), so several policies can share one graph.
        with tf1.variable_scope(TFPolicy.next_tf_var_scope_name()):
            # Create the new policy.
            new_policy = pol_spec.policy_class(
                # Note(jungong) : we are intentionally not using keyword arguments
                # here because some policies name the observation space parameter
                # obs_space, and some others name it observation_space.
                pol_spec.observation_space,
                pol_spec.action_space,
                pol_spec.config,
            )

        # Set the new policy's state (weights, optimizer vars, exploration state,
        # etc..).
        new_policy.set_state(state)

        # Return the new policy.
        return new_policy
520
+
521
+ @override(Policy)
522
+ def apply_gradients(self, gradients: ModelGradients) -> None:
523
+ assert self.loss_initialized()
524
+ builder = _TFRunBuilder(self.get_session(), "apply_gradients")
525
+ fetches = self._build_apply_gradients(builder, gradients)
526
+ builder.get(fetches)
527
+
528
+ @override(Policy)
529
+ def get_weights(self) -> Union[Dict[str, TensorType], List[TensorType]]:
530
+ return self._variables.get_weights()
531
+
532
+ @override(Policy)
533
+ def set_weights(self, weights) -> None:
534
+ return self._variables.set_weights(weights)
535
+
536
+ @override(Policy)
537
    def get_exploration_state(self) -> Dict[str, TensorType]:
        """Returns the current state of this policy's exploration component."""
        return self.exploration.get_state(sess=self.get_session())
539
+
540
+ @Deprecated(new="get_exploration_state", error=True)
541
    def get_exploration_info(self) -> Dict[str, TensorType]:
        # Deprecated (decorator raises): use `get_exploration_state` instead.
        return self.get_exploration_state()
543
+
544
+ @override(Policy)
545
+ def is_recurrent(self) -> bool:
546
+ return len(self._state_inputs) > 0
547
+
548
+ @override(Policy)
549
+ def num_state_tensors(self) -> int:
550
+ return len(self._state_inputs)
551
+
552
+ @override(Policy)
553
+ def get_state(self) -> PolicyState:
554
+ # For tf Policies, return Policy weights and optimizer var values.
555
+ state = super().get_state()
556
+
557
+ if len(self._optimizer_variables.variables) > 0:
558
+ state["_optimizer_variables"] = self.get_session().run(
559
+ self._optimizer_variables.variables
560
+ )
561
+ # Add exploration state.
562
+ state["_exploration_state"] = self.exploration.get_state(self.get_session())
563
+ return state
564
+
565
+ @override(Policy)
566
    def set_state(self, state: PolicyState) -> None:
        """Restores optimizer vars, exploration state, timestep, then weights."""
        # Set optimizer vars first.
        optimizer_vars = state.get("_optimizer_variables", None)
        if optimizer_vars is not None:
            self._optimizer_variables.set_weights(optimizer_vars)
        # Set exploration's state.
        if hasattr(self, "exploration") and "_exploration_state" in state:
            self.exploration.set_state(
                state=state["_exploration_state"], sess=self.get_session()
            )

        # Restore global timestep.
        self.global_timestep = state["global_timestep"]

        # Then the Policy's (NN) weights and connectors.
        super().set_state(state)
582
+
583
+ @override(Policy)
584
    def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None:
        """Export tensorflow graph to export_dir for serving.

        Args:
            export_dir: Directory to write the exported model to.
            onnx: If truthy, export via `tf2onnx` (ONNX format) instead of
                saving the keras base model in TF SavedModel format.
        """
        if onnx:
            try:
                import tf2onnx
            except ImportError as e:
                raise RuntimeError(
                    "Converting a TensorFlow model to ONNX requires "
                    "`tf2onnx` to be installed. Install with "
                    "`pip install tf2onnx`."
                ) from e

            with self.get_session().graph.as_default():
                signature_def_map = self._build_signature_def()

                sd = signature_def_map[
                    tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY  # noqa: E501
                ]
                inputs = [v.name for k, v in sd.inputs.items()]
                outputs = [v.name for k, v in sd.outputs.items()]

                from tf2onnx import tf_loader

                # Freeze the session's graph (variables -> constants) so it
                # can be converted standalone.
                frozen_graph_def = tf_loader.freeze_session(
                    self.get_session(), input_names=inputs, output_names=outputs
                )

            with tf1.Session(graph=tf.Graph()) as session:
                tf.import_graph_def(frozen_graph_def, name="")

                g = tf2onnx.tfonnx.process_tf_graph(
                    session.graph,
                    input_names=inputs,
                    output_names=outputs,
                    inputs_as_nchw=inputs,
                )

                model_proto = g.make_model("onnx_model")
                tf2onnx.utils.save_onnx_model(
                    export_dir, "model", feed_dict={}, model_proto=model_proto
                )
        # Save the tf.keras.Model (architecture and weights, so it can be retrieved
        # w/o access to the original (custom) Model or Policy code).
        elif (
            hasattr(self, "model")
            and hasattr(self.model, "base_model")
            and isinstance(self.model.base_model, tf.keras.Model)
        ):
            with self.get_session().graph.as_default():
                try:
                    self.model.base_model.save(filepath=export_dir, save_format="tf")
                except Exception:
                    logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
        else:
            logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
639
+
640
+ @override(Policy)
641
    def import_model_from_h5(self, import_file: str) -> None:
        """Imports weights into tf model.

        Raises:
            NotImplementedError: If this policy has no `self.model`.
        """
        if self.model is None:
            raise NotImplementedError("No `self.model` to import into!")

        # Make sure the session is the right one (see issue #7046).
        with self.get_session().graph.as_default():
            with self.get_session().as_default():
                return self.model.import_from_h5(import_file)
650
+
651
+ @override(Policy)
652
+ def get_session(self) -> Optional["tf1.Session"]:
653
+ """Returns a reference to the TF session for this policy."""
654
+ return self._sess
655
+
656
+ def variables(self):
657
+ """Return the list of all savable variables for this policy."""
658
+ if self.model is None:
659
+ raise NotImplementedError("No `self.model` to get variables for!")
660
+ elif isinstance(self.model, tf.keras.Model):
661
+ return self.model.variables
662
+ else:
663
+ return self.model.variables()
664
+
665
+ def get_placeholder(self, name) -> "tf1.placeholder":
666
+ """Returns the given action or loss input placeholder by name.
667
+
668
+ If the loss has not been initialized and a loss input placeholder is
669
+ requested, an error is raised.
670
+
671
+ Args:
672
+ name: The name of the placeholder to return. One of
673
+ SampleBatch.CUR_OBS|PREV_ACTION/REWARD or a valid key from
674
+ `self._loss_input_dict`.
675
+
676
+ Returns:
677
+ tf1.placeholder: The placeholder under the given str key.
678
+ """
679
+ if name == SampleBatch.CUR_OBS:
680
+ return self._obs_input
681
+ elif name == SampleBatch.PREV_ACTIONS:
682
+ return self._prev_action_input
683
+ elif name == SampleBatch.PREV_REWARDS:
684
+ return self._prev_reward_input
685
+
686
+ assert self._loss_input_dict, (
687
+ "You need to populate `self._loss_input_dict` before "
688
+ "`get_placeholder()` can be called"
689
+ )
690
+ return self._loss_input_dict[name]
691
+
692
+ def loss_initialized(self) -> bool:
693
+ """Returns whether the loss term(s) have been initialized."""
694
+ return len(self._losses) > 0
695
+
696
    def _initialize_loss(
        self, losses: List[TensorType], loss_inputs: List[Tuple[str, TensorType]]
    ) -> None:
        """Initializes the loss op(s) from given loss tensors and placeholders.

        Args:
            losses: The list of loss ops returned by some loss function.
            loss_inputs: The list of Tuples:
                (name, tf1.placeholders) needed for calculating the loss.
        """
        self._loss_input_dict = dict(loss_inputs)
        # Same inputs, minus RNN state placeholders and the seq-lens tensor.
        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }
        for i, ph in enumerate(self._state_inputs):
            self._loss_input_dict["state_in_{}".format(i)] = ph

        # Give a (non-keras) model the chance to add its own custom loss term.
        if self.model and not isinstance(self.model, tf.keras.Model):
            self._losses = force_list(
                self.model.custom_loss(losses, self._loss_input_dict)
            )
            self._stats_fetches.update({"model": self.model.metrics()})
        else:
            self._losses = losses
        # Backward compatibility.
        self._loss = self._losses[0] if self._losses is not None else None

        if not self._optimizers:
            self._optimizers = force_list(self.optimizer())
            # Backward compatibility.
            self._optimizer = self._optimizers[0] if self._optimizers else None

        # Supporting more than one loss/optimizer.
        if self.config["_tf_policy_handles_more_than_one_loss"]:
            self._grads_and_vars = []
            self._grads = []
            for group in self.gradients(self._optimizers, self._losses):
                # Drop entries with no gradient (non-trainable vars).
                g_and_v = [(g, v) for (g, v) in group if g is not None]
                self._grads_and_vars.append(g_and_v)
                self._grads.append([g for (g, _) in g_and_v])
        # Only one optimizer and loss term.
        else:
            self._grads_and_vars = [
                (g, v)
                for (g, v) in self.gradients(self._optimizer, self._loss)
                if g is not None
            ]
            self._grads = [g for (g, _) in self._grads_and_vars]

        if self.model:
            self._variables = ray.experimental.tf_utils.TensorFlowVariables(
                [], self.get_session(), self.variables()
            )

        # Gather update ops for any batch norm layers.
        if len(self.devices) <= 1:
            if not self._update_ops:
                self._update_ops = tf1.get_collection(
                    tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name
                )
            if self._update_ops:
                logger.info(
                    "Update ops to run on apply gradient: {}".format(self._update_ops)
                )
            with tf1.control_dependencies(self._update_ops):
                self._apply_op = self.build_apply_op(
                    optimizer=self._optimizers
                    if self.config["_tf_policy_handles_more_than_one_loss"]
                    else self._optimizer,
                    grads_and_vars=self._grads_and_vars,
                )

        if log_once("loss_used"):
            logger.debug(
                "These tensors were used in the loss functions:"
                f"\n{summarize(self._loss_input_dict)}\n"
            )

        self.get_session().run(tf1.global_variables_initializer())

        # TensorFlowVariables holding a flat list of all our optimizers'
        # variables.
        self._optimizer_variables = ray.experimental.tf_utils.TensorFlowVariables(
            [v for o in self._optimizers for v in o.variables()], self.get_session()
        )
784
+
785
    def copy(self, existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> "TFPolicy":
        """Creates a copy of self using existing input placeholders.

        Optional: Only required to work with the multi-GPU optimizer.

        Args:
            existing_inputs (List[Tuple[str, tf1.placeholder]]): Dict mapping
                names (str) to tf1.placeholders to re-use (share) with the
                returned copy of self.

        Returns:
            TFPolicy: A copy of self.

        Raises:
            NotImplementedError: Always; subclasses must override to support
                multi-GPU training.
        """
        raise NotImplementedError
799
+
800
+ def extra_compute_action_feed_dict(self) -> Dict[TensorType, TensorType]:
801
+ """Extra dict to pass to the compute actions session run.
802
+
803
+ Returns:
804
+ Dict[TensorType, TensorType]: A feed dict to be added to the
805
+ feed_dict passed to the compute_actions session.run() call.
806
+ """
807
+ return {}
808
+
809
+ def extra_compute_action_fetches(self) -> Dict[str, TensorType]:
810
+ # Cache graph fetches for action computation for better
811
+ # performance.
812
+ # This function is called every time the static graph is run
813
+ # to compute actions.
814
+ if not self._cached_extra_action_out:
815
+ self._cached_extra_action_out = self.extra_action_out_fn()
816
+ return self._cached_extra_action_out
817
+
818
+ def extra_action_out_fn(self) -> Dict[str, TensorType]:
819
+ """Extra values to fetch and return from compute_actions().
820
+
821
+ By default we return action probability/log-likelihood info
822
+ and action distribution inputs (if present).
823
+
824
+ Returns:
825
+ Dict[str, TensorType]: An extra fetch-dict to be passed to and
826
+ returned from the compute_actions() call.
827
+ """
828
+ extra_fetches = {}
829
+ # Action-logp and action-prob.
830
+ if self._sampled_action_logp is not None:
831
+ extra_fetches[SampleBatch.ACTION_PROB] = self._sampled_action_prob
832
+ extra_fetches[SampleBatch.ACTION_LOGP] = self._sampled_action_logp
833
+ # Action-dist inputs.
834
+ if self._dist_inputs is not None:
835
+ extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = self._dist_inputs
836
+ return extra_fetches
837
+
838
+ def extra_compute_grad_feed_dict(self) -> Dict[TensorType, TensorType]:
839
+ """Extra dict to pass to the compute gradients session run.
840
+
841
+ Returns:
842
+ Dict[TensorType, TensorType]: Extra feed_dict to be passed to the
843
+ compute_gradients Session.run() call.
844
+ """
845
+ return {} # e.g, kl_coeff
846
+
847
+ def extra_compute_grad_fetches(self) -> Dict[str, any]:
848
+ """Extra values to fetch and return from compute_gradients().
849
+
850
+ Returns:
851
+ Dict[str, any]: Extra fetch dict to be added to the fetch dict
852
+ of the compute_gradients Session.run() call.
853
+ """
854
+ return {LEARNER_STATS_KEY: {}} # e.g, stats, td error, etc.
855
+
856
+ def optimizer(self) -> "tf.keras.optimizers.Optimizer":
857
+ """TF optimizer to use for policy optimization.
858
+
859
+ Returns:
860
+ tf.keras.optimizers.Optimizer: The local optimizer to use for this
861
+ Policy's Model.
862
+ """
863
+ if hasattr(self, "config") and "lr" in self.config:
864
+ return tf1.train.AdamOptimizer(learning_rate=self.config["lr"])
865
+ else:
866
+ return tf1.train.AdamOptimizer()
867
+
868
+ def gradients(
869
+ self,
870
+ optimizer: Union[LocalOptimizer, List[LocalOptimizer]],
871
+ loss: Union[TensorType, List[TensorType]],
872
+ ) -> Union[List[ModelGradients], List[List[ModelGradients]]]:
873
+ """Override this for a custom gradient computation behavior.
874
+
875
+ Args:
876
+ optimizer (Union[LocalOptimizer, List[LocalOptimizer]]): A single
877
+ LocalOptimizer of a list thereof to use for gradient
878
+ calculations. If more than one optimizer given, the number of
879
+ optimizers must match the number of losses provided.
880
+ loss (Union[TensorType, List[TensorType]]): A single loss term
881
+ or a list thereof to use for gradient calculations.
882
+ If more than one loss given, the number of loss terms must
883
+ match the number of optimizers provided.
884
+
885
+ Returns:
886
+ Union[List[ModelGradients], List[List[ModelGradients]]]: List of
887
+ ModelGradients (grads and vars OR just grads) OR List of List
888
+ of ModelGradients in case we have more than one
889
+ optimizer/loss.
890
+ """
891
+ optimizers = force_list(optimizer)
892
+ losses = force_list(loss)
893
+
894
+ # We have more than one optimizers and loss terms.
895
+ if self.config["_tf_policy_handles_more_than_one_loss"]:
896
+ grads = []
897
+ for optim, loss_ in zip(optimizers, losses):
898
+ grads.append(optim.compute_gradients(loss_))
899
+ # We have only one optimizer and one loss term.
900
+ else:
901
+ return optimizers[0].compute_gradients(losses[0])
902
+
903
    def build_apply_op(
        self,
        optimizer: Union[LocalOptimizer, List[LocalOptimizer]],
        grads_and_vars: Union[ModelGradients, List[ModelGradients]],
    ) -> "tf.Operation":
        """Override this for a custom gradient apply computation behavior.

        Args:
            optimizer (Union[LocalOptimizer, List[LocalOptimizer]]): The local
                tf optimizer to use for applying the grads and vars.
            grads_and_vars (Union[ModelGradients, List[ModelGradients]]): List
                of tuples with grad values and the grad-value's corresponding
                tf.variable in it.

        Returns:
            tf.Operation: The tf op that applies all computed gradients
            (`grads_and_vars`) to the model(s) via the given optimizer(s).
        """
        optimizers = force_list(optimizer)

        # We have more than one optimizer and loss term.
        if self.config["_tf_policy_handles_more_than_one_loss"]:
            ops = []
            for i, optim in enumerate(optimizers):
                # Specify global_step (e.g. for TD3 which needs to count the
                # num updates that have happened).
                ops.append(
                    optim.apply_gradients(
                        grads_and_vars[i],
                        global_step=tf1.train.get_or_create_global_step(),
                    )
                )
            # Group the per-optimizer apply ops into a single op.
            return tf.group(ops)
        # We have only one optimizer and one loss term.
        else:
            return optimizers[0].apply_gradients(
                grads_and_vars, global_step=tf1.train.get_or_create_global_step()
            )
941
+
942
+ def _get_is_training_placeholder(self):
943
+ """Get the placeholder for _is_training, i.e., for batch norm layers.
944
+
945
+ This can be called safely before __init__ has run.
946
+ """
947
+ if not hasattr(self, "_is_training"):
948
+ self._is_training = tf1.placeholder_with_default(
949
+ False, (), name="is_training"
950
+ )
951
+ return self._is_training
952
+
953
+ def _debug_vars(self):
954
+ if log_once("grad_vars"):
955
+ if self.config["_tf_policy_handles_more_than_one_loss"]:
956
+ for group in self._grads_and_vars:
957
+ for _, v in group:
958
+ logger.info("Optimizing variable {}".format(v))
959
+ else:
960
+ for _, v in self._grads_and_vars:
961
+ logger.info("Optimizing variable {}".format(v))
962
+
963
+ def _extra_input_signature_def(self):
964
+ """Extra input signatures to add when exporting tf model.
965
+ Inferred from extra_compute_action_feed_dict()
966
+ """
967
+ feed_dict = self.extra_compute_action_feed_dict()
968
+ return {
969
+ k.name: tf1.saved_model.utils.build_tensor_info(k) for k in feed_dict.keys()
970
+ }
971
+
972
+ def _extra_output_signature_def(self):
973
+ """Extra output signatures to add when exporting tf model.
974
+ Inferred from extra_compute_action_fetches()
975
+ """
976
+ fetches = self.extra_compute_action_fetches()
977
+ return {
978
+ k: tf1.saved_model.utils.build_tensor_info(fetches[k])
979
+ for k in fetches.keys()
980
+ }
981
+
982
+ def _build_signature_def(self):
983
+ """Build signature def map for tensorflow SavedModelBuilder."""
984
+ # build input signatures
985
+ input_signature = self._extra_input_signature_def()
986
+ input_signature["observations"] = tf1.saved_model.utils.build_tensor_info(
987
+ self._obs_input
988
+ )
989
+
990
+ if self._seq_lens is not None:
991
+ input_signature[
992
+ SampleBatch.SEQ_LENS
993
+ ] = tf1.saved_model.utils.build_tensor_info(self._seq_lens)
994
+ if self._prev_action_input is not None:
995
+ input_signature["prev_action"] = tf1.saved_model.utils.build_tensor_info(
996
+ self._prev_action_input
997
+ )
998
+ if self._prev_reward_input is not None:
999
+ input_signature["prev_reward"] = tf1.saved_model.utils.build_tensor_info(
1000
+ self._prev_reward_input
1001
+ )
1002
+
1003
+ input_signature["is_training"] = tf1.saved_model.utils.build_tensor_info(
1004
+ self._is_training
1005
+ )
1006
+
1007
+ if self._timestep is not None:
1008
+ input_signature["timestep"] = tf1.saved_model.utils.build_tensor_info(
1009
+ self._timestep
1010
+ )
1011
+
1012
+ for state_input in self._state_inputs:
1013
+ input_signature[state_input.name] = tf1.saved_model.utils.build_tensor_info(
1014
+ state_input
1015
+ )
1016
+
1017
+ # build output signatures
1018
+ output_signature = self._extra_output_signature_def()
1019
+ for i, a in enumerate(tf.nest.flatten(self._sampled_action)):
1020
+ output_signature[
1021
+ "actions_{}".format(i)
1022
+ ] = tf1.saved_model.utils.build_tensor_info(a)
1023
+
1024
+ for state_output in self._state_outputs:
1025
+ output_signature[
1026
+ state_output.name
1027
+ ] = tf1.saved_model.utils.build_tensor_info(state_output)
1028
+ signature_def = tf1.saved_model.signature_def_utils.build_signature_def(
1029
+ input_signature,
1030
+ output_signature,
1031
+ tf1.saved_model.signature_constants.PREDICT_METHOD_NAME,
1032
+ )
1033
+ signature_def_key = (
1034
+ tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
1035
+ )
1036
+ signature_def_map = {signature_def_key: signature_def}
1037
+ return signature_def_map
1038
+
1039
+ def _build_compute_actions(
1040
+ self,
1041
+ builder,
1042
+ *,
1043
+ input_dict=None,
1044
+ obs_batch=None,
1045
+ state_batches=None,
1046
+ prev_action_batch=None,
1047
+ prev_reward_batch=None,
1048
+ episodes=None,
1049
+ explore=None,
1050
+ timestep=None,
1051
+ ):
1052
+ explore = explore if explore is not None else self.config["explore"]
1053
+ timestep = timestep if timestep is not None else self.global_timestep
1054
+
1055
+ # Call the exploration before_compute_actions hook.
1056
+ self.exploration.before_compute_actions(
1057
+ timestep=timestep, explore=explore, tf_sess=self.get_session()
1058
+ )
1059
+
1060
+ builder.add_feed_dict(self.extra_compute_action_feed_dict())
1061
+
1062
+ # `input_dict` given: Simply build what's in that dict.
1063
+ if hasattr(self, "_input_dict"):
1064
+ for key, value in input_dict.items():
1065
+ if key in self._input_dict:
1066
+ # Handle complex/nested spaces as well.
1067
+ tree.map_structure(
1068
+ lambda k, v: builder.add_feed_dict({k: v}),
1069
+ self._input_dict[key],
1070
+ value,
1071
+ )
1072
+ # For policies that inherit directly from TFPolicy.
1073
+ else:
1074
+ builder.add_feed_dict({self._obs_input: input_dict[SampleBatch.OBS]})
1075
+ if SampleBatch.PREV_ACTIONS in input_dict:
1076
+ builder.add_feed_dict(
1077
+ {self._prev_action_input: input_dict[SampleBatch.PREV_ACTIONS]}
1078
+ )
1079
+ if SampleBatch.PREV_REWARDS in input_dict:
1080
+ builder.add_feed_dict(
1081
+ {self._prev_reward_input: input_dict[SampleBatch.PREV_REWARDS]}
1082
+ )
1083
+ state_batches = []
1084
+ i = 0
1085
+ while "state_in_{}".format(i) in input_dict:
1086
+ state_batches.append(input_dict["state_in_{}".format(i)])
1087
+ i += 1
1088
+ builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
1089
+
1090
+ if "state_in_0" in input_dict and SampleBatch.SEQ_LENS not in input_dict:
1091
+ builder.add_feed_dict(
1092
+ {self._seq_lens: np.ones(len(input_dict["state_in_0"]))}
1093
+ )
1094
+
1095
+ builder.add_feed_dict({self._is_exploring: explore})
1096
+ if timestep is not None:
1097
+ builder.add_feed_dict({self._timestep: timestep})
1098
+
1099
+ # Determine, what exactly to fetch from the graph.
1100
+ to_fetch = (
1101
+ [self._sampled_action]
1102
+ + self._state_outputs
1103
+ + [self.extra_compute_action_fetches()]
1104
+ )
1105
+
1106
+ # Add the ops to fetch for the upcoming session call.
1107
+ fetches = builder.add_fetches(to_fetch)
1108
+ return fetches[0], fetches[1:-1], fetches[-1]
1109
+
1110
+ def _build_compute_gradients(self, builder, postprocessed_batch):
1111
+ self._debug_vars()
1112
+ builder.add_feed_dict(self.extra_compute_grad_feed_dict())
1113
+ builder.add_feed_dict(
1114
+ self._get_loss_inputs_dict(postprocessed_batch, shuffle=False)
1115
+ )
1116
+ fetches = builder.add_fetches([self._grads, self._get_grad_and_stats_fetches()])
1117
+ return fetches[0], fetches[1]
1118
+
1119
+ def _build_apply_gradients(self, builder, gradients):
1120
+ if len(gradients) != len(self._grads):
1121
+ raise ValueError(
1122
+ "Unexpected number of gradients to apply, got {} for {}".format(
1123
+ gradients, self._grads
1124
+ )
1125
+ )
1126
+ builder.add_feed_dict({self._is_training: True})
1127
+ builder.add_feed_dict(dict(zip(self._grads, gradients)))
1128
+ fetches = builder.add_fetches([self._apply_op])
1129
+ return fetches[0]
1130
+
1131
+ def _build_learn_on_batch(self, builder, postprocessed_batch):
1132
+ self._debug_vars()
1133
+
1134
+ builder.add_feed_dict(self.extra_compute_grad_feed_dict())
1135
+ builder.add_feed_dict(
1136
+ self._get_loss_inputs_dict(postprocessed_batch, shuffle=False)
1137
+ )
1138
+ fetches = builder.add_fetches(
1139
+ [
1140
+ self._apply_op,
1141
+ self._get_grad_and_stats_fetches(),
1142
+ ]
1143
+ )
1144
+ return fetches[1]
1145
+
1146
+ def _get_grad_and_stats_fetches(self):
1147
+ fetches = self.extra_compute_grad_fetches()
1148
+ if LEARNER_STATS_KEY not in fetches:
1149
+ raise ValueError("Grad fetches should contain 'stats': {...} entry")
1150
+ if self._stats_fetches:
1151
+ fetches[LEARNER_STATS_KEY] = dict(
1152
+ self._stats_fetches, **fetches[LEARNER_STATS_KEY]
1153
+ )
1154
+ return fetches
1155
+
1156
+ def _get_loss_inputs_dict(self, train_batch: SampleBatch, shuffle: bool):
1157
+ """Return a feed dict from a batch.
1158
+
1159
+ Args:
1160
+ train_batch: batch of data to derive inputs from.
1161
+ shuffle: whether to shuffle batch sequences. Shuffle may
1162
+ be done in-place. This only makes sense if you're further
1163
+ applying minibatch SGD after getting the outputs.
1164
+
1165
+ Returns:
1166
+ Feed dict of data.
1167
+ """
1168
+
1169
+ # Get batch ready for RNNs, if applicable.
1170
+ if not isinstance(train_batch, SampleBatch) or not train_batch.zero_padded:
1171
+ pad_batch_to_sequences_of_same_size(
1172
+ train_batch,
1173
+ max_seq_len=self._max_seq_len,
1174
+ shuffle=shuffle,
1175
+ batch_divisibility_req=self._batch_divisibility_req,
1176
+ feature_keys=list(self._loss_input_dict_no_rnn.keys()),
1177
+ view_requirements=self.view_requirements,
1178
+ )
1179
+
1180
+ # Mark the batch as "is_training" so the Model can use this
1181
+ # information.
1182
+ train_batch.set_training(True)
1183
+
1184
+ # Build the feed dict from the batch.
1185
+ feed_dict = {}
1186
+ for key, placeholders in self._loss_input_dict.items():
1187
+ a = tree.map_structure(
1188
+ lambda ph, v: feed_dict.__setitem__(ph, v),
1189
+ placeholders,
1190
+ train_batch[key],
1191
+ )
1192
+ del a
1193
+
1194
+ state_keys = ["state_in_{}".format(i) for i in range(len(self._state_inputs))]
1195
+ for key in state_keys:
1196
+ feed_dict[self._loss_input_dict[key]] = train_batch[key]
1197
+ if state_keys:
1198
+ feed_dict[self._seq_lens] = train_batch[SampleBatch.SEQ_LENS]
1199
+
1200
+ return feed_dict
.venv/lib/python3.11/site-packages/ray/rllib/policy/tf_policy_template.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ from typing import Callable, Dict, List, Optional, Tuple, Type, Union
3
+
4
+ from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
5
+ from ray.rllib.models.modelv2 import ModelV2
6
+ from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
7
+ from ray.rllib.policy import eager_tf_policy
8
+ from ray.rllib.policy.policy import Policy
9
+ from ray.rllib.policy.sample_batch import SampleBatch
10
+ from ray.rllib.policy.tf_policy import TFPolicy
11
+ from ray.rllib.utils import add_mixins, force_list
12
+ from ray.rllib.utils.annotations import OldAPIStack, override
13
+ from ray.rllib.utils.deprecation import (
14
+ deprecation_warning,
15
+ DEPRECATED_VALUE,
16
+ )
17
+ from ray.rllib.utils.framework import try_import_tf
18
+ from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
19
+ from ray.rllib.utils.typing import (
20
+ ModelGradients,
21
+ TensorType,
22
+ AlgorithmConfigDict,
23
+ )
24
+
25
+ tf1, tf, tfv = try_import_tf()
26
+
27
+
28
+ @OldAPIStack
29
+ def build_tf_policy(
30
+ name: str,
31
+ *,
32
+ loss_fn: Callable[
33
+ [Policy, ModelV2, Type[TFActionDistribution], SampleBatch],
34
+ Union[TensorType, List[TensorType]],
35
+ ],
36
+ get_default_config: Optional[Callable[[None], AlgorithmConfigDict]] = None,
37
+ postprocess_fn=None,
38
+ stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[str, TensorType]]] = None,
39
+ optimizer_fn: Optional[
40
+ Callable[[Policy, AlgorithmConfigDict], "tf.keras.optimizers.Optimizer"]
41
+ ] = None,
42
+ compute_gradients_fn: Optional[
43
+ Callable[[Policy, "tf.keras.optimizers.Optimizer", TensorType], ModelGradients]
44
+ ] = None,
45
+ apply_gradients_fn: Optional[
46
+ Callable[
47
+ [Policy, "tf.keras.optimizers.Optimizer", ModelGradients], "tf.Operation"
48
+ ]
49
+ ] = None,
50
+ grad_stats_fn: Optional[
51
+ Callable[[Policy, SampleBatch, ModelGradients], Dict[str, TensorType]]
52
+ ] = None,
53
+ extra_action_out_fn: Optional[Callable[[Policy], Dict[str, TensorType]]] = None,
54
+ extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[str, TensorType]]] = None,
55
+ validate_spaces: Optional[
56
+ Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
57
+ ] = None,
58
+ before_init: Optional[
59
+ Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
60
+ ] = None,
61
+ before_loss_init: Optional[
62
+ Callable[
63
+ [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None
64
+ ]
65
+ ] = None,
66
+ after_init: Optional[
67
+ Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
68
+ ] = None,
69
+ make_model: Optional[
70
+ Callable[
71
+ [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], ModelV2
72
+ ]
73
+ ] = None,
74
+ action_sampler_fn: Optional[
75
+ Callable[[TensorType, List[TensorType]], Tuple[TensorType, TensorType]]
76
+ ] = None,
77
+ action_distribution_fn: Optional[
78
+ Callable[
79
+ [Policy, ModelV2, TensorType, TensorType, TensorType],
80
+ Tuple[TensorType, type, List[TensorType]],
81
+ ]
82
+ ] = None,
83
+ mixins: Optional[List[type]] = None,
84
+ get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
85
+ # Deprecated args.
86
+ obs_include_prev_action_reward=DEPRECATED_VALUE,
87
+ extra_action_fetches_fn=None, # Use `extra_action_out_fn`.
88
+ gradients_fn=None, # Use `compute_gradients_fn`.
89
+ ) -> Type[DynamicTFPolicy]:
90
+ """Helper function for creating a dynamic tf policy at runtime.
91
+
92
+ Functions will be run in this order to initialize the policy:
93
+ 1. Placeholder setup: postprocess_fn
94
+ 2. Loss init: loss_fn, stats_fn
95
+ 3. Optimizer init: optimizer_fn, gradients_fn, apply_gradients_fn,
96
+ grad_stats_fn
97
+
98
+ This means that you can e.g., depend on any policy attributes created in
99
+ the running of `loss_fn` in later functions such as `stats_fn`.
100
+
101
+ In eager mode, the following functions will be run repeatedly on each
102
+ eager execution: loss_fn, stats_fn, gradients_fn, apply_gradients_fn,
103
+ and grad_stats_fn.
104
+
105
+ This means that these functions should not define any variables internally,
106
+ otherwise they will fail in eager mode execution. Variable should only
107
+ be created in make_model (if defined).
108
+
109
+ Args:
110
+ name: Name of the policy (e.g., "PPOTFPolicy").
111
+ loss_fn (Callable[[
112
+ Policy, ModelV2, Type[TFActionDistribution], SampleBatch],
113
+ Union[TensorType, List[TensorType]]]): Callable for calculating a
114
+ loss tensor.
115
+ get_default_config (Optional[Callable[[None], AlgorithmConfigDict]]):
116
+ Optional callable that returns the default config to merge with any
117
+ overrides. If None, uses only(!) the user-provided
118
+ PartialAlgorithmConfigDict as dict for this Policy.
119
+ postprocess_fn (Optional[Callable[[Policy, SampleBatch,
120
+ Optional[Dict[AgentID, SampleBatch]], Episode], None]]):
121
+ Optional callable for post-processing experience batches (called
122
+ after the parent class' `postprocess_trajectory` method).
123
+ stats_fn (Optional[Callable[[Policy, SampleBatch],
124
+ Dict[str, TensorType]]]): Optional callable that returns a dict of
125
+ TF tensors to fetch given the policy and batch input tensors. If
126
+ None, will not compute any stats.
127
+ optimizer_fn (Optional[Callable[[Policy, AlgorithmConfigDict],
128
+ "tf.keras.optimizers.Optimizer"]]): Optional callable that returns
129
+ a tf.Optimizer given the policy and config. If None, will call
130
+ the base class' `optimizer()` method instead (which returns a
131
+ tf1.train.AdamOptimizer).
132
+ compute_gradients_fn (Optional[Callable[[Policy,
133
+ "tf.keras.optimizers.Optimizer", TensorType], ModelGradients]]):
134
+ Optional callable that returns a list of gradients. If None,
135
+ this defaults to optimizer.compute_gradients([loss]).
136
+ apply_gradients_fn (Optional[Callable[[Policy,
137
+ "tf.keras.optimizers.Optimizer", ModelGradients],
138
+ "tf.Operation"]]): Optional callable that returns an apply
139
+ gradients op given policy, tf-optimizer, and grads_and_vars. If
140
+ None, will call the base class' `build_apply_op()` method instead.
141
+ grad_stats_fn (Optional[Callable[[Policy, SampleBatch, ModelGradients],
142
+ Dict[str, TensorType]]]): Optional callable that returns a dict of
143
+ TF fetches given the policy, batch input, and gradient tensors. If
144
+ None, will not collect any gradient stats.
145
+ extra_action_out_fn (Optional[Callable[[Policy],
146
+ Dict[str, TensorType]]]): Optional callable that returns
147
+ a dict of TF fetches given the policy object. If None, will not
148
+ perform any extra fetches.
149
+ extra_learn_fetches_fn (Optional[Callable[[Policy],
150
+ Dict[str, TensorType]]]): Optional callable that returns a dict of
151
+ extra values to fetch and return when learning on a batch. If None,
152
+ will call the base class' `extra_compute_grad_fetches()` method
153
+ instead.
154
+ validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
155
+ AlgorithmConfigDict], None]]): Optional callable that takes the
156
+ Policy, observation_space, action_space, and config to check
157
+ the spaces for correctness. If None, no spaces checking will be
158
+ done.
159
+ before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
160
+ AlgorithmConfigDict], None]]): Optional callable to run at the
161
+ beginning of policy init that takes the same arguments as the
162
+ policy constructor. If None, this step will be skipped.
163
+ before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
164
+ gym.spaces.Space, AlgorithmConfigDict], None]]): Optional callable to
165
+ run prior to loss init. If None, this step will be skipped.
166
+ after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
167
+ AlgorithmConfigDict], None]]): Optional callable to run at the end of
168
+ policy init. If None, this step will be skipped.
169
+ make_model (Optional[Callable[[Policy, gym.spaces.Space,
170
+ gym.spaces.Space, AlgorithmConfigDict], ModelV2]]): Optional callable
171
+ that returns a ModelV2 object.
172
+ All policy variables should be created in this function. If None,
173
+ a default ModelV2 object will be created.
174
+ action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
175
+ Tuple[TensorType, TensorType]]]): A callable returning a sampled
176
+ action and its log-likelihood given observation and state inputs.
177
+ If None, will either use `action_distribution_fn` or
178
+ compute actions by calling self.model, then sampling from the
179
+ so parameterized action distribution.
180
+ action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
181
+ TensorType, TensorType],
182
+ Tuple[TensorType, type, List[TensorType]]]]): Optional callable
183
+ returning distribution inputs (parameters), a dist-class to
184
+ generate an action distribution object from, and internal-state
185
+ outputs (or an empty list if not applicable). If None, will either
186
+ use `action_sampler_fn` or compute actions by calling self.model,
187
+ then sampling from the so parameterized action distribution.
188
+ mixins (Optional[List[type]]): Optional list of any class mixins for
189
+ the returned policy class. These mixins will be applied in order
190
+ and will have higher precedence than the DynamicTFPolicy class.
191
+ get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
192
+ Optional callable that returns the divisibility requirement for
193
+ sample batches. If None, will assume a value of 1.
194
+
195
+ Returns:
196
+ Type[DynamicTFPolicy]: A child class of DynamicTFPolicy based on the
197
+ specified args.
198
+ """
199
+ original_kwargs = locals().copy()
200
+ base = add_mixins(DynamicTFPolicy, mixins)
201
+
202
+ if obs_include_prev_action_reward != DEPRECATED_VALUE:
203
+ deprecation_warning(old="obs_include_prev_action_reward", error=True)
204
+
205
+ if extra_action_fetches_fn is not None:
206
+ deprecation_warning(
207
+ old="extra_action_fetches_fn", new="extra_action_out_fn", error=True
208
+ )
209
+
210
+ if gradients_fn is not None:
211
+ deprecation_warning(old="gradients_fn", new="compute_gradients_fn", error=True)
212
+
213
+ class policy_cls(base):
214
+ def __init__(
215
+ self,
216
+ obs_space,
217
+ action_space,
218
+ config,
219
+ existing_model=None,
220
+ existing_inputs=None,
221
+ ):
222
+ if validate_spaces:
223
+ validate_spaces(self, obs_space, action_space, config)
224
+
225
+ if before_init:
226
+ before_init(self, obs_space, action_space, config)
227
+
228
+ def before_loss_init_wrapper(policy, obs_space, action_space, config):
229
+ if before_loss_init:
230
+ before_loss_init(policy, obs_space, action_space, config)
231
+
232
+ if extra_action_out_fn is None or policy._is_tower:
233
+ extra_action_fetches = {}
234
+ else:
235
+ extra_action_fetches = extra_action_out_fn(policy)
236
+
237
+ if hasattr(policy, "_extra_action_fetches"):
238
+ policy._extra_action_fetches.update(extra_action_fetches)
239
+ else:
240
+ policy._extra_action_fetches = extra_action_fetches
241
+
242
+ DynamicTFPolicy.__init__(
243
+ self,
244
+ obs_space=obs_space,
245
+ action_space=action_space,
246
+ config=config,
247
+ loss_fn=loss_fn,
248
+ stats_fn=stats_fn,
249
+ grad_stats_fn=grad_stats_fn,
250
+ before_loss_init=before_loss_init_wrapper,
251
+ make_model=make_model,
252
+ action_sampler_fn=action_sampler_fn,
253
+ action_distribution_fn=action_distribution_fn,
254
+ existing_inputs=existing_inputs,
255
+ existing_model=existing_model,
256
+ get_batch_divisibility_req=get_batch_divisibility_req,
257
+ )
258
+
259
+ if after_init:
260
+ after_init(self, obs_space, action_space, config)
261
+
262
+ # Got to reset global_timestep again after this fake run-through.
263
+ self.global_timestep = 0
264
+
265
+ @override(Policy)
266
+ def postprocess_trajectory(
267
+ self, sample_batch, other_agent_batches=None, episode=None
268
+ ):
269
+ # Call super's postprocess_trajectory first.
270
+ sample_batch = Policy.postprocess_trajectory(self, sample_batch)
271
+ if postprocess_fn:
272
+ return postprocess_fn(self, sample_batch, other_agent_batches, episode)
273
+ return sample_batch
274
+
275
+ @override(TFPolicy)
276
+ def optimizer(self):
277
+ if optimizer_fn:
278
+ optimizers = optimizer_fn(self, self.config)
279
+ else:
280
+ optimizers = base.optimizer(self)
281
+ optimizers = force_list(optimizers)
282
+ if self.exploration:
283
+ optimizers = self.exploration.get_exploration_optimizer(optimizers)
284
+
285
+ # No optimizers produced -> Return None.
286
+ if not optimizers:
287
+ return None
288
+ # New API: Allow more than one optimizer to be returned.
289
+ # -> Return list.
290
+ elif self.config["_tf_policy_handles_more_than_one_loss"]:
291
+ return optimizers
292
+ # Old API: Return a single LocalOptimizer.
293
+ else:
294
+ return optimizers[0]
295
+
296
+ @override(TFPolicy)
297
+ def gradients(self, optimizer, loss):
298
+ optimizers = force_list(optimizer)
299
+ losses = force_list(loss)
300
+
301
+ if compute_gradients_fn:
302
+ # New API: Allow more than one optimizer -> Return a list of
303
+ # lists of gradients.
304
+ if self.config["_tf_policy_handles_more_than_one_loss"]:
305
+ return compute_gradients_fn(self, optimizers, losses)
306
+ # Old API: Return a single List of gradients.
307
+ else:
308
+ return compute_gradients_fn(self, optimizers[0], losses[0])
309
+ else:
310
+ return base.gradients(self, optimizers, losses)
311
+
312
+ @override(TFPolicy)
313
+ def build_apply_op(self, optimizer, grads_and_vars):
314
+ if apply_gradients_fn:
315
+ return apply_gradients_fn(self, optimizer, grads_and_vars)
316
+ else:
317
+ return base.build_apply_op(self, optimizer, grads_and_vars)
318
+
319
+ @override(TFPolicy)
320
+ def extra_compute_action_fetches(self):
321
+ return dict(
322
+ base.extra_compute_action_fetches(self), **self._extra_action_fetches
323
+ )
324
+
325
+ @override(TFPolicy)
326
+ def extra_compute_grad_fetches(self):
327
+ if extra_learn_fetches_fn:
328
+ # TODO: (sven) in torch, extra_learn_fetches do not exist.
329
+ # Hence, things like td_error are returned by the stats_fn
330
+ # and end up under the LEARNER_STATS_KEY. We should
331
+ # change tf to do this as well. However, this will confilct
332
+ # the handling of LEARNER_STATS_KEY inside the multi-GPU
333
+ # train op.
334
+ # Auto-add empty learner stats dict if needed.
335
+ return dict({LEARNER_STATS_KEY: {}}, **extra_learn_fetches_fn(self))
336
+ else:
337
+ return base.extra_compute_grad_fetches(self)
338
+
339
+ def with_updates(**overrides):
340
+ """Allows creating a TFPolicy cls based on settings of another one.
341
+
342
+ Keyword Args:
343
+ **overrides: The settings (passed into `build_tf_policy`) that
344
+ should be different from the class that this method is called
345
+ on.
346
+
347
+ Returns:
348
+ type: A new TFPolicy sub-class.
349
+
350
+ Examples:
351
+ >> MySpecialDQNPolicyClass = DQNTFPolicy.with_updates(
352
+ .. name="MySpecialDQNPolicyClass",
353
+ .. loss_function=[some_new_loss_function],
354
+ .. )
355
+ """
356
+ return build_tf_policy(**dict(original_kwargs, **overrides))
357
+
358
+ def as_eager():
359
+ return eager_tf_policy._build_eager_tf_policy(**original_kwargs)
360
+
361
+ policy_cls.with_updates = staticmethod(with_updates)
362
+ policy_cls.as_eager = staticmethod(as_eager)
363
+ policy_cls.__name__ = name
364
+ policy_cls.__qualname__ = name
365
+ return policy_cls
.venv/lib/python3.11/site-packages/ray/rllib/policy/torch_mixins.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.policy.policy import PolicyState
2
+ from ray.rllib.policy.sample_batch import SampleBatch
3
+ from ray.rllib.policy.torch_policy import TorchPolicy
4
+ from ray.rllib.utils.annotations import OldAPIStack
5
+ from ray.rllib.utils.framework import try_import_torch
6
+ from ray.rllib.utils.schedules import PiecewiseSchedule
7
+
8
+ torch, nn = try_import_torch()
9
+
10
+
11
+ @OldAPIStack
12
+ class LearningRateSchedule:
13
+ """Mixin for TorchPolicy that adds a learning rate schedule."""
14
+
15
+ def __init__(self, lr, lr_schedule, lr2=None, lr2_schedule=None):
16
+ self._lr_schedule = None
17
+ self._lr2_schedule = None
18
+ # Disable any scheduling behavior related to learning if Learner API is active.
19
+ # Schedules are handled by Learner class.
20
+ if lr_schedule is None:
21
+ self.cur_lr = lr
22
+ else:
23
+ self._lr_schedule = PiecewiseSchedule(
24
+ lr_schedule, outside_value=lr_schedule[-1][-1], framework=None
25
+ )
26
+ self.cur_lr = self._lr_schedule.value(0)
27
+ if lr2_schedule is None:
28
+ self.cur_lr2 = lr2
29
+ else:
30
+ self._lr2_schedule = PiecewiseSchedule(
31
+ lr2_schedule, outside_value=lr2_schedule[-1][-1], framework=None
32
+ )
33
+ self.cur_lr2 = self._lr2_schedule.value(0)
34
+
35
+ def on_global_var_update(self, global_vars):
36
+ super().on_global_var_update(global_vars)
37
+ if self._lr_schedule:
38
+ self.cur_lr = self._lr_schedule.value(global_vars["timestep"])
39
+ for opt in self._optimizers:
40
+ for p in opt.param_groups:
41
+ p["lr"] = self.cur_lr
42
+ if self._lr2_schedule:
43
+ assert len(self._optimizers) == 2
44
+ self.cur_lr2 = self._lr2_schedule.value(global_vars["timestep"])
45
+ opt = self._optimizers[1]
46
+ for p in opt.param_groups:
47
+ p["lr"] = self.cur_lr2
48
+
49
+
50
+ @OldAPIStack
51
+ class EntropyCoeffSchedule:
52
+ """Mixin for TorchPolicy that adds entropy coeff decay."""
53
+
54
+ def __init__(self, entropy_coeff, entropy_coeff_schedule):
55
+ self._entropy_coeff_schedule = None
56
+ # Disable any scheduling behavior related to learning if Learner API is active.
57
+ # Schedules are handled by Learner class.
58
+ if entropy_coeff_schedule is None:
59
+ self.entropy_coeff = entropy_coeff
60
+ else:
61
+ # Allows for custom schedule similar to lr_schedule format
62
+ if isinstance(entropy_coeff_schedule, list):
63
+ self._entropy_coeff_schedule = PiecewiseSchedule(
64
+ entropy_coeff_schedule,
65
+ outside_value=entropy_coeff_schedule[-1][-1],
66
+ framework=None,
67
+ )
68
+ else:
69
+ # Implements previous version but enforces outside_value
70
+ self._entropy_coeff_schedule = PiecewiseSchedule(
71
+ [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
72
+ outside_value=0.0,
73
+ framework=None,
74
+ )
75
+ self.entropy_coeff = self._entropy_coeff_schedule.value(0)
76
+
77
+ def on_global_var_update(self, global_vars):
78
+ super(EntropyCoeffSchedule, self).on_global_var_update(global_vars)
79
+ if self._entropy_coeff_schedule is not None:
80
+ self.entropy_coeff = self._entropy_coeff_schedule.value(
81
+ global_vars["timestep"]
82
+ )
83
+
84
+
85
+ @OldAPIStack
86
+ class KLCoeffMixin:
87
+ """Assigns the `update_kl()` method to a TorchPolicy.
88
+
89
+ This is used by Algorithms to update the KL coefficient
90
+ after each learning step based on `config.kl_target` and
91
+ the measured KL value (from the train_batch).
92
+ """
93
+
94
+ def __init__(self, config):
95
+ # The current KL value (as python float).
96
+ self.kl_coeff = config["kl_coeff"]
97
+ # Constant target value.
98
+ self.kl_target = config["kl_target"]
99
+
100
+ def update_kl(self, sampled_kl):
101
+ # Update the current KL value based on the recently measured value.
102
+ if sampled_kl > 2.0 * self.kl_target:
103
+ self.kl_coeff *= 1.5
104
+ elif sampled_kl < 0.5 * self.kl_target:
105
+ self.kl_coeff *= 0.5
106
+ # Return the current KL value.
107
+ return self.kl_coeff
108
+
109
+ def get_state(self) -> PolicyState:
110
+ state = super().get_state()
111
+ # Add current kl-coeff value.
112
+ state["current_kl_coeff"] = self.kl_coeff
113
+ return state
114
+
115
+ def set_state(self, state: PolicyState) -> None:
116
+ # Set current kl-coeff value first.
117
+ self.kl_coeff = state.pop("current_kl_coeff", self.config["kl_coeff"])
118
+ # Call super's set_state with rest of the state dict.
119
+ super().set_state(state)
120
+
121
+
122
+ @OldAPIStack
123
+ class ValueNetworkMixin:
124
+ """Assigns the `_value()` method to a TorchPolicy.
125
+
126
+ This way, Policy can call `_value()` to get the current VF estimate on a
127
+ single(!) observation (as done in `postprocess_trajectory_fn`).
128
+ Note: When doing this, an actual forward pass is being performed.
129
+ This is different from only calling `model.value_function()`, where
130
+ the result of the most recent forward pass is being used to return an
131
+ already calculated tensor.
132
+ """
133
+
134
+ def __init__(self, config):
135
+ # When doing GAE, we need the value function estimate on the
136
+ # observation.
137
+ if config.get("use_gae") or config.get("vtrace"):
138
+ # Input dict is provided to us automatically via the Model's
139
+ # requirements. It's a single-timestep (last one in trajectory)
140
+ # input_dict.
141
+
142
+ def value(**input_dict):
143
+ input_dict = SampleBatch(input_dict)
144
+ input_dict = self._lazy_tensor_dict(input_dict)
145
+ model_out, _ = self.model(input_dict)
146
+ # [0] = remove the batch dim.
147
+ return self.model.value_function()[0].item()
148
+
149
+ # When not doing GAE, we do not require the value function's output.
150
+ else:
151
+
152
+ def value(*args, **kwargs):
153
+ return 0.0
154
+
155
+ self._value = value
156
+
157
+ def extra_action_out(self, input_dict, state_batches, model, action_dist):
158
+ """Defines extra fetches per action computation.
159
+
160
+ Args:
161
+ input_dict (Dict[str, TensorType]): The input dict used for the action
162
+ computing forward pass.
163
+ state_batches (List[TensorType]): List of state tensors (empty for
164
+ non-RNNs).
165
+ model (ModelV2): The Model object of the Policy.
166
+ action_dist: The instantiated distribution
167
+ object, resulting from the model's outputs and the given
168
+ distribution class.
169
+
170
+ Returns:
171
+ Dict[str, TensorType]: Dict with extra tf fetches to perform per
172
+ action computation.
173
+ """
174
+ # Return value function outputs. VF estimates will hence be added to
175
+ # the SampleBatches produced by the sampler(s) to generate the train
176
+ # batches going into the loss function.
177
+ return {
178
+ SampleBatch.VF_PREDS: model.value_function(),
179
+ }
180
+
181
+
182
+ @OldAPIStack
183
+ class TargetNetworkMixin:
184
+ """Mixin class adding a method for (soft) target net(s) synchronizations.
185
+
186
+ - Adds the `update_target` method to the policy.
187
+ Calling `update_target` updates all target Q-networks' weights from their
188
+ respective "main" Q-networks, based on tau (smooth, partial updating).
189
+ """
190
+
191
+ def __init__(self):
192
+ # Hard initial update from Q-net(s) to target Q-net(s).
193
+ tau = self.config.get("tau", 1.0)
194
+ self.update_target(tau=tau)
195
+
196
+ def update_target(self, tau=None):
197
+ # Update_target_fn will be called periodically to copy Q network to
198
+ # target Q network, using (soft) tau-synching.
199
+ tau = tau or self.config.get("tau", 1.0)
200
+
201
+ model_state_dict = self.model.state_dict()
202
+
203
+ # Support partial (soft) synching.
204
+ # If tau == 1.0: Full sync from Q-model to target Q-model.
205
+ # Support partial (soft) synching.
206
+ # If tau == 1.0: Full sync from Q-model to target Q-model.
207
+ target_state_dict = next(iter(self.target_models.values())).state_dict()
208
+ model_state_dict = {
209
+ k: tau * model_state_dict[k] + (1 - tau) * v
210
+ for k, v in target_state_dict.items()
211
+ }
212
+
213
+ for target in self.target_models.values():
214
+ target.load_state_dict(model_state_dict)
215
+
216
+ def set_weights(self, weights):
217
+ # Makes sure that whenever we restore weights for this policy's
218
+ # model, we sync the target network (from the main model)
219
+ # at the same time.
220
+ TorchPolicy.set_weights(self, weights)
221
+ self.update_target()
.venv/lib/python3.11/site-packages/ray/rllib/policy/torch_policy.py ADDED
@@ -0,0 +1,1201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import functools
3
+ import logging
4
+ import math
5
+ import os
6
+ import threading
7
+ import time
8
+ from typing import (
9
+ Any,
10
+ Callable,
11
+ Dict,
12
+ List,
13
+ Optional,
14
+ Set,
15
+ Tuple,
16
+ Type,
17
+ Union,
18
+ )
19
+
20
+ import gymnasium as gym
21
+ import numpy as np
22
+ import tree # pip install dm_tree
23
+
24
+ import ray
25
+ from ray.rllib.models.catalog import ModelCatalog
26
+ from ray.rllib.models.modelv2 import ModelV2
27
+ from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
28
+ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
29
+ from ray.rllib.policy.policy import Policy, PolicyState
30
+ from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
31
+ from ray.rllib.policy.sample_batch import SampleBatch
32
+ from ray.rllib.utils import NullContextManager, force_list
33
+ from ray.rllib.utils.annotations import OldAPIStack, override
34
+ from ray.rllib.utils.error import ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL
35
+ from ray.rllib.utils.framework import try_import_torch
36
+ from ray.rllib.utils.metrics import (
37
+ DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
38
+ NUM_AGENT_STEPS_TRAINED,
39
+ NUM_GRAD_UPDATES_LIFETIME,
40
+ )
41
+ from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
42
+ from ray.rllib.utils.numpy import convert_to_numpy
43
+ from ray.rllib.utils.spaces.space_utils import normalize_action
44
+ from ray.rllib.utils.threading import with_lock
45
+ from ray.rllib.utils.torch_utils import convert_to_torch_tensor
46
+ from ray.rllib.utils.typing import (
47
+ AlgorithmConfigDict,
48
+ GradInfoDict,
49
+ ModelGradients,
50
+ ModelWeights,
51
+ TensorStructType,
52
+ TensorType,
53
+ )
54
+
55
+ torch, nn = try_import_torch()
56
+
57
+ logger = logging.getLogger(__name__)
58
+
59
+
60
+ @OldAPIStack
61
+ class TorchPolicy(Policy):
62
+ """PyTorch specific Policy class to use with RLlib."""
63
+
64
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
        *,
        model: Optional[TorchModelV2] = None,
        loss: Optional[
            Callable[
                [Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch],
                Union[TensorType, List[TensorType]],
            ]
        ] = None,
        action_distribution_class: Optional[Type[TorchDistributionWrapper]] = None,
        action_sampler_fn: Optional[
            Callable[
                [TensorType, List[TensorType]],
                Union[
                    Tuple[TensorType, TensorType, List[TensorType]],
                    Tuple[TensorType, TensorType, TensorType, List[TensorType]],
                ],
            ]
        ] = None,
        action_distribution_fn: Optional[
            Callable[
                [Policy, ModelV2, TensorType, TensorType, TensorType],
                Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]],
            ]
        ] = None,
        max_seq_len: int = 20,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
    ):
        """Initializes a TorchPolicy instance.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: The Policy's config dict.
            model: PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
            loss: Callable that returns one or more (a list of) scalar loss
                terms.
            action_distribution_class: Class for a torch action distribution.
            action_sampler_fn: A callable returning either a sampled action,
                its log-likelihood and updated state or a sampled action, its
                log-likelihood, updated state and action distribution inputs
                given Policy, ModelV2, input_dict, state batches (optional),
                explore, and timestep. Provide `action_sampler_fn` if you would
                like to have full control over the action computation step,
                including the model forward pass, possible sampling from a
                distribution, and exploration logic.
                Note: If `action_sampler_fn` is given, `action_distribution_fn`
                must be None. If both `action_sampler_fn` and
                `action_distribution_fn` are None, RLlib will simply pass
                inputs through `self.model` to get distribution inputs, create
                the distribution object, sample from it, and apply some
                exploration logic to the results.
                The callable takes as inputs: Policy, ModelV2, input_dict
                (SampleBatch), state_batches (optional), explore, and timestep.
            action_distribution_fn: A callable returning distribution inputs
                (parameters), a dist-class to generate an action distribution
                object from, and internal-state outputs (or an empty list if
                not applicable).
                Provide `action_distribution_fn` if you would like to only
                customize the model forward pass call. The resulting
                distribution parameters are then used by RLlib to create a
                distribution object, sample from it, and execute any
                exploration logic.
                Note: If `action_distribution_fn` is given, `action_sampler_fn`
                must be None. If both `action_sampler_fn` and
                `action_distribution_fn` are None, RLlib will simply pass
                inputs through `self.model` to get distribution inputs, create
                the distribution object, sample from it, and apply some
                exploration logic to the results.
                The callable takes as inputs: Policy, ModelV2, ModelInputDict,
                explore, timestep, is_training.
            max_seq_len: Max sequence length for LSTM training.
            get_batch_divisibility_req: Optional callable that returns the
                divisibility requirement for sample batches given the Policy.
        """
        # NOTE: This also mutates the passed-in `config` dict (sets its
        # "framework" key to "torch") before handing it to the base class.
        self.framework = config["framework"] = "torch"
        self._loss_initialized = False
        super().__init__(observation_space, action_space, config)

        # Create multi-GPU model towers, if necessary.
        # - The central main model will be stored under self.model, residing
        #   on self.device (normally, a CPU).
        # - Each GPU will have a copy of that model under
        #   self.model_gpu_towers, matching the devices in self.devices.
        # - Parallelization is done by splitting the train batch and passing
        #   it through the model copies in parallel, then averaging over the
        #   resulting gradients, applying these averages on the main model and
        #   updating all towers' weights from the main model.
        # - In case of just one device (1 (fake or real) GPU or 1 CPU), no
        #   parallelization will be done.

        # If no Model is provided, build a default one here.
        if model is None:
            dist_class, logit_dim = ModelCatalog.get_action_dist(
                action_space, self.config["model"], framework=self.framework
            )
            model = ModelCatalog.get_model_v2(
                obs_space=self.observation_space,
                action_space=self.action_space,
                num_outputs=logit_dim,
                model_config=self.config["model"],
                framework=self.framework,
            )
            # The catalog's dist class is only used if the caller did not
            # supply one explicitly.
            if action_distribution_class is None:
                action_distribution_class = dist_class

        # Get devices to build the graph on.
        num_gpus = self._get_num_gpus_for_policy()
        gpu_ids = list(range(torch.cuda.device_count()))
        logger.info(f"Found {len(gpu_ids)} visible cuda devices.")

        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - No GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            self.device = torch.device("cpu")
            # In fake-GPU mode, num_gpus may be fractional/positive; build
            # that many CPU "towers" (at least 1).
            self.devices = [self.device for _ in range(int(math.ceil(num_gpus)) or 1)]
            self.model_gpu_towers = [
                model if i == 0 else copy.deepcopy(model)
                for i in range(int(math.ceil(num_gpus)) or 1)
            ]
            # `target_model` is set by subclasses (e.g. Q-learning policies)
            # before calling this __init__; on CPU all towers share it.
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: self.target_model for m in self.model_gpu_towers
                }
            self.model = model
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray._private.worker._mode() == ray._private.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TorchPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}."
                )

            self.devices = [
                torch.device("cuda:{}".format(i))
                for i, id_ in enumerate(gpu_ids)
                if i < num_gpus
            ]
            self.device = self.devices[0]
            ids = [id_ for i, id_ in enumerate(gpu_ids) if i < num_gpus]
            self.model_gpu_towers = []
            for i, _ in enumerate(ids):
                model_copy = copy.deepcopy(model)
                self.model_gpu_towers.append(model_copy.to(self.devices[i]))
            # One deep target-model copy per tower, placed on that tower's
            # device.
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: copy.deepcopy(self.target_model).to(self.devices[i])
                    for i, m in enumerate(self.model_gpu_towers)
                }
            # Tower-0 is the "main" model.
            self.model = self.model_gpu_towers[0]

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when calling the model
        # first, then its value function (e.g. in a loss function), in
        # between of which another model call is made (e.g. to compute an
        # action).
        self._lock = threading.RLock()

        self._state_inputs = self.model.get_initial_state()
        # Recurrent iff the model declares any initial internal state.
        self._is_recurrent = len(self._state_inputs) > 0
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)

        self.exploration = self._create_exploration()
        self.unwrapped_model = model  # used to support DistributedDataParallel
        # To ensure backward compatibility:
        # Old way: If `loss` provided here, use as-is (as a function).
        if loss is not None:
            self._loss = loss
        # New way: Convert the overridden `self.loss` into a plain function,
        # so it can be called the same way as `loss` would be, ensuring
        # backward compatibility.
        elif self.loss.__func__.__qualname__ != "Policy.loss":
            self._loss = self.loss.__func__
        # `loss` not provided nor overridden from Policy -> Set to None.
        else:
            self._loss = None
        self._optimizers = force_list(self.optimizer())
        # Store, which params (by index within the model's list of
        # parameters) should be updated per optimizer.
        # Maps optimizer idx to set or param indices.
        self.multi_gpu_param_groups: List[Set[int]] = []
        # Map each main-model parameter object to its position in
        # `self.model.parameters()`.
        main_params = {p: i for i, p in enumerate(self.model.parameters())}
        for o in self._optimizers:
            param_indices = []
            for pg_idx, pg in enumerate(o.param_groups):
                for p in pg["params"]:
                    param_indices.append(main_params[p])
            self.multi_gpu_param_groups.append(set(param_indices))

        # Create n sample-batch buffers (num_multi_gpu_tower_stacks), each
        # one with m towers (num_gpus).
        num_buffers = self.config.get("num_multi_gpu_tower_stacks", 1)
        self._loaded_batches = [[] for _ in range(num_buffers)]

        self.dist_class = action_distribution_class
        self.action_sampler_fn = action_sampler_fn
        self.action_distribution_fn = action_distribution_fn

        # If set, means we are using distributed allreduce during learning.
        self.distributed_world_size = None

        self.max_seq_len = max_seq_len
        # The divisibility requirement may be given as a callable (taking the
        # policy) or as a plain int (0/None -> 1).
        self.batch_divisibility_req = (
            get_batch_divisibility_req(self)
            if callable(get_batch_divisibility_req)
            else (get_batch_divisibility_req or 1)
        )
291
+
292
+ @override(Policy)
293
+ def compute_actions_from_input_dict(
294
+ self,
295
+ input_dict: Dict[str, TensorType],
296
+ explore: bool = None,
297
+ timestep: Optional[int] = None,
298
+ **kwargs,
299
+ ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
300
+ with torch.no_grad():
301
+ # Pass lazy (torch) tensor dict to Model as `input_dict`.
302
+ input_dict = self._lazy_tensor_dict(input_dict)
303
+ input_dict.set_training(True)
304
+ # Pack internal state inputs into (separate) list.
305
+ state_batches = [
306
+ input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
307
+ ]
308
+ # Calculate RNN sequence lengths.
309
+ seq_lens = (
310
+ torch.tensor(
311
+ [1] * len(state_batches[0]),
312
+ dtype=torch.long,
313
+ device=state_batches[0].device,
314
+ )
315
+ if state_batches
316
+ else None
317
+ )
318
+
319
+ return self._compute_action_helper(
320
+ input_dict, state_batches, seq_lens, explore, timestep
321
+ )
322
+
323
+ @override(Policy)
324
+ def compute_actions(
325
+ self,
326
+ obs_batch: Union[List[TensorStructType], TensorStructType],
327
+ state_batches: Optional[List[TensorType]] = None,
328
+ prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
329
+ prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
330
+ info_batch: Optional[Dict[str, list]] = None,
331
+ episodes=None,
332
+ explore: Optional[bool] = None,
333
+ timestep: Optional[int] = None,
334
+ **kwargs,
335
+ ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
336
+ with torch.no_grad():
337
+ seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
338
+ input_dict = self._lazy_tensor_dict(
339
+ {
340
+ SampleBatch.CUR_OBS: obs_batch,
341
+ "is_training": False,
342
+ }
343
+ )
344
+ if prev_action_batch is not None:
345
+ input_dict[SampleBatch.PREV_ACTIONS] = np.asarray(prev_action_batch)
346
+ if prev_reward_batch is not None:
347
+ input_dict[SampleBatch.PREV_REWARDS] = np.asarray(prev_reward_batch)
348
+ state_batches = [
349
+ convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
350
+ ]
351
+ return self._compute_action_helper(
352
+ input_dict, state_batches, seq_lens, explore, timestep
353
+ )
354
+
355
    @with_lock
    @override(Policy)
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorStructType], TensorStructType],
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[
            Union[List[TensorStructType], TensorStructType]
        ] = None,
        prev_reward_batch: Optional[
            Union[List[TensorStructType], TensorStructType]
        ] = None,
        actions_normalized: bool = True,
        **kwargs,
    ) -> TensorType:
        """Computes the log-likelihoods of `actions`, given `obs_batch`.

        Builds the action distribution for the given observations (either via
        `self.action_distribution_fn` or the model + `self.dist_class`) and
        returns `dist.logp(actions)`.
        """
        # An `action_sampler_fn` bypasses distribution creation entirely, so
        # log-likelihoods can only be computed if an `action_distribution_fn`
        # is also provided.
        if self.action_sampler_fn and self.action_distribution_fn is None:
            raise ValueError(
                "Cannot compute log-prob/likelihood w/o an "
                "`action_distribution_fn` and a provided "
                "`action_sampler_fn`!"
            )

        with torch.no_grad():
            input_dict = self._lazy_tensor_dict(
                {SampleBatch.CUR_OBS: obs_batch, SampleBatch.ACTIONS: actions}
            )
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            # Every row is treated as a length-1 sequence.
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            state_batches = [
                convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
            ]

            # Exploration hook before each forward pass.
            self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if self.action_distribution_fn:
                # Try new action_distribution_fn signature, supporting
                # state_batches and seq_lens.
                try:
                    dist_inputs, dist_class, state_out = self.action_distribution_fn(
                        self,
                        self.model,
                        input_dict=input_dict,
                        state_batches=state_batches,
                        seq_lens=seq_lens,
                        explore=False,
                        is_training=False,
                    )
                # Trying the old way (to stay backward compatible).
                # NOTE: This inspects the TypeError message to distinguish a
                # signature mismatch from a genuine error inside the fn.
                # TODO: Remove in future.
                except TypeError as e:
                    if (
                        "positional argument" in e.args[0]
                        or "unexpected keyword argument" in e.args[0]
                    ):
                        dist_inputs, dist_class, _ = self.action_distribution_fn(
                            policy=self,
                            model=self.model,
                            obs_batch=input_dict[SampleBatch.CUR_OBS],
                            explore=False,
                            is_training=False,
                        )
                    else:
                        raise e

            # Default action-dist inputs calculation.
            else:
                dist_class = self.dist_class
                dist_inputs, _ = self.model(input_dict, state_batches, seq_lens)

            action_dist = dist_class(dist_inputs, self.model)

            # Normalize actions if necessary.
            actions = input_dict[SampleBatch.ACTIONS]
            if not actions_normalized and self.config["normalize_actions"]:
                actions = normalize_action(actions, self.action_space_struct)

            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods
440
+
441
+ @with_lock
442
+ @override(Policy)
443
+ def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
444
+ # Set Model to train mode.
445
+ if self.model:
446
+ self.model.train()
447
+ # Callback handling.
448
+ learn_stats = {}
449
+ self.callbacks.on_learn_on_batch(
450
+ policy=self, train_batch=postprocessed_batch, result=learn_stats
451
+ )
452
+
453
+ # Compute gradients (will calculate all losses and `backward()`
454
+ # them to get the grads).
455
+ grads, fetches = self.compute_gradients(postprocessed_batch)
456
+
457
+ # Step the optimizers.
458
+ self.apply_gradients(_directStepOptimizerSingleton)
459
+
460
+ self.num_grad_updates += 1
461
+
462
+ if self.model:
463
+ fetches["model"] = self.model.metrics()
464
+
465
+ fetches.update(
466
+ {
467
+ "custom_metrics": learn_stats,
468
+ NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
469
+ NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
470
+ # -1, b/c we have to measure this diff before we do the update above.
471
+ DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
472
+ self.num_grad_updates
473
+ - 1
474
+ - (postprocessed_batch.num_grad_updates or 0)
475
+ ),
476
+ }
477
+ )
478
+
479
+ return fetches
480
+
481
+ @override(Policy)
482
+ def load_batch_into_buffer(
483
+ self,
484
+ batch: SampleBatch,
485
+ buffer_index: int = 0,
486
+ ) -> int:
487
+ # Set the is_training flag of the batch.
488
+ batch.set_training(True)
489
+
490
+ # Shortcut for 1 CPU only: Store batch in `self._loaded_batches`.
491
+ if len(self.devices) == 1 and self.devices[0].type == "cpu":
492
+ assert buffer_index == 0
493
+ pad_batch_to_sequences_of_same_size(
494
+ batch=batch,
495
+ max_seq_len=self.max_seq_len,
496
+ shuffle=False,
497
+ batch_divisibility_req=self.batch_divisibility_req,
498
+ view_requirements=self.view_requirements,
499
+ )
500
+ self._lazy_tensor_dict(batch)
501
+ self._loaded_batches[0] = [batch]
502
+ return len(batch)
503
+
504
+ # Batch (len=28, seq-lens=[4, 7, 4, 10, 3]):
505
+ # 0123 0123456 0123 0123456789ABC
506
+
507
+ # 1) split into n per-GPU sub batches (n=2).
508
+ # [0123 0123456] [012] [3 0123456789 ABC]
509
+ # (len=14, 14 seq-lens=[4, 7, 3] [1, 10, 3])
510
+ slices = batch.timeslices(num_slices=len(self.devices))
511
+
512
+ # 2) zero-padding (max-seq-len=10).
513
+ # - [0123000000 0123456000 0120000000]
514
+ # - [3000000000 0123456789 ABC0000000]
515
+ for slice in slices:
516
+ pad_batch_to_sequences_of_same_size(
517
+ batch=slice,
518
+ max_seq_len=self.max_seq_len,
519
+ shuffle=False,
520
+ batch_divisibility_req=self.batch_divisibility_req,
521
+ view_requirements=self.view_requirements,
522
+ )
523
+
524
+ # 3) Load splits into the given buffer (consisting of n GPUs).
525
+ slices = [slice.to_device(self.devices[i]) for i, slice in enumerate(slices)]
526
+ self._loaded_batches[buffer_index] = slices
527
+
528
+ # Return loaded samples per-device.
529
+ return len(slices[0])
530
+
531
+ @override(Policy)
532
+ def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
533
+ if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
534
+ assert buffer_index == 0
535
+ return sum(len(b) for b in self._loaded_batches[buffer_index])
536
+
537
    @override(Policy)
    def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
        """Runs one learning step on the pre-loaded batch in the given buffer.

        `load_batch_into_buffer()` must have been called first for this
        buffer index.
        """
        if not self._loaded_batches[buffer_index]:
            raise ValueError(
                "Must call Policy.load_batch_into_buffer() before "
                "Policy.learn_on_loaded_batch()!"
            )

        # Get the correct slice of the already loaded batch to use,
        # based on offset and batch size.
        # Prefer the newer "minibatch_size" config key; fall back to
        # "sgd_minibatch_size", then "train_batch_size".
        device_batch_size = self.config.get("minibatch_size")
        if device_batch_size is None:
            device_batch_size = self.config.get(
                "sgd_minibatch_size",
                self.config["train_batch_size"],
            )
        # The minibatch is split evenly across all devices.
        device_batch_size //= len(self.devices)

        # Set Model to train mode.
        if self.model_gpu_towers:
            for t in self.model_gpu_towers:
                t.train()

        # Shortcut for 1 CPU only: Batch should already be stored in
        # `self._loaded_batches`.
        if len(self.devices) == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
            if device_batch_size >= len(self._loaded_batches[0][0]):
                batch = self._loaded_batches[0][0]
            else:
                batch = self._loaded_batches[0][0][offset : offset + device_batch_size]
            return self.learn_on_batch(batch)

        if len(self.devices) > 1:
            # Copy weights of main model (tower-0) to all other towers.
            state_dict = self.model.state_dict()
            # Just making sure tower-0 is really the same as self.model.
            assert self.model_gpu_towers[0] is self.model
            for tower in self.model_gpu_towers[1:]:
                tower.load_state_dict(state_dict)

        # Use the full per-device batches if the minibatch covers everything;
        # otherwise slice each device's batch at the given offset.
        if device_batch_size >= sum(len(s) for s in self._loaded_batches[buffer_index]):
            device_batches = self._loaded_batches[buffer_index]
        else:
            device_batches = [
                b[offset : offset + device_batch_size]
                for b in self._loaded_batches[buffer_index]
            ]

        # Callback handling.
        batch_fetches = {}
        for i, batch in enumerate(device_batches):
            custom_metrics = {}
            self.callbacks.on_learn_on_batch(
                policy=self, train_batch=batch, result=custom_metrics
            )
            batch_fetches[f"tower_{i}"] = {"custom_metrics": custom_metrics}

        # Do the (maybe parallelized) gradient calculation step.
        tower_outputs = self._multi_gpu_parallel_grad_calc(device_batches)

        # Mean-reduce gradients over GPU-towers (do this on CPU: self.device).
        all_grads = []
        for i in range(len(tower_outputs[0][0])):
            if tower_outputs[0][0][i] is not None:
                all_grads.append(
                    torch.mean(
                        torch.stack([t[0][i].to(self.device) for t in tower_outputs]),
                        dim=0,
                    )
                )
            else:
                all_grads.append(None)
        # Set main model's grads to mean-reduced values.
        for i, p in enumerate(self.model.parameters()):
            p.grad = all_grads[i]

        # Step the optimizers on the freshly set gradients.
        self.apply_gradients(_directStepOptimizerSingleton)

        self.num_grad_updates += 1

        # Collect per-tower stats/metrics into the result dict.
        for i, (model, batch) in enumerate(zip(self.model_gpu_towers, device_batches)):
            batch_fetches[f"tower_{i}"].update(
                {
                    LEARNER_STATS_KEY: self.extra_grad_info(batch),
                    "model": model.metrics(),
                    NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                    # -1, b/c we have to measure this diff before we do the update
                    # above.
                    DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                        self.num_grad_updates - 1 - (batch.num_grad_updates or 0)
                    ),
                }
            )
        batch_fetches.update(self.extra_compute_grad_fetches())

        return batch_fetches
634
+
635
+ @with_lock
636
+ @override(Policy)
637
+ def compute_gradients(self, postprocessed_batch: SampleBatch) -> ModelGradients:
638
+ assert len(self.devices) == 1
639
+
640
+ # If not done yet, see whether we have to zero-pad this batch.
641
+ if not postprocessed_batch.zero_padded:
642
+ pad_batch_to_sequences_of_same_size(
643
+ batch=postprocessed_batch,
644
+ max_seq_len=self.max_seq_len,
645
+ shuffle=False,
646
+ batch_divisibility_req=self.batch_divisibility_req,
647
+ view_requirements=self.view_requirements,
648
+ )
649
+
650
+ postprocessed_batch.set_training(True)
651
+ self._lazy_tensor_dict(postprocessed_batch, device=self.devices[0])
652
+
653
+ # Do the (maybe parallelized) gradient calculation step.
654
+ tower_outputs = self._multi_gpu_parallel_grad_calc([postprocessed_batch])
655
+
656
+ all_grads, grad_info = tower_outputs[0]
657
+
658
+ grad_info["allreduce_latency"] /= len(self._optimizers)
659
+ grad_info.update(self.extra_grad_info(postprocessed_batch))
660
+
661
+ fetches = self.extra_compute_grad_fetches()
662
+
663
+ return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})
664
+
665
+ @override(Policy)
666
+ def apply_gradients(self, gradients: ModelGradients) -> None:
667
+ if gradients == _directStepOptimizerSingleton:
668
+ for i, opt in enumerate(self._optimizers):
669
+ opt.step()
670
+ else:
671
+ # TODO(sven): Not supported for multiple optimizers yet.
672
+ assert len(self._optimizers) == 1
673
+ for g, p in zip(gradients, self.model.parameters()):
674
+ if g is not None:
675
+ if torch.is_tensor(g):
676
+ p.grad = g.to(self.device)
677
+ else:
678
+ p.grad = torch.from_numpy(g).to(self.device)
679
+
680
+ self._optimizers[0].step()
681
+
682
+ def get_tower_stats(self, stats_name: str) -> List[TensorStructType]:
683
+ """Returns list of per-tower stats, copied to this Policy's device.
684
+
685
+ Args:
686
+ stats_name: The name of the stats to average over (this str
687
+ must exist as a key inside each tower's `tower_stats` dict).
688
+
689
+ Returns:
690
+ The list of stats tensor (structs) of all towers, copied to this
691
+ Policy's device.
692
+
693
+ Raises:
694
+ AssertionError: If the `stats_name` cannot be found in any one
695
+ of the tower's `tower_stats` dicts.
696
+ """
697
+ data = []
698
+ for tower in self.model_gpu_towers:
699
+ if stats_name in tower.tower_stats:
700
+ data.append(
701
+ tree.map_structure(
702
+ lambda s: s.to(self.device), tower.tower_stats[stats_name]
703
+ )
704
+ )
705
+ assert len(data) > 0, (
706
+ f"Stats `{stats_name}` not found in any of the towers (you have "
707
+ f"{len(self.model_gpu_towers)} towers in total)! Make "
708
+ "sure you call the loss function on at least one of the towers."
709
+ )
710
+ return data
711
+
712
+ @override(Policy)
713
+ def get_weights(self) -> ModelWeights:
714
+ return {k: v.cpu().detach().numpy() for k, v in self.model.state_dict().items()}
715
+
716
+ @override(Policy)
717
+ def set_weights(self, weights: ModelWeights) -> None:
718
+ weights = convert_to_torch_tensor(weights, device=self.device)
719
+ self.model.load_state_dict(weights)
720
+
721
+ @override(Policy)
722
+ def is_recurrent(self) -> bool:
723
+ return self._is_recurrent
724
+
725
+ @override(Policy)
726
+ def num_state_tensors(self) -> int:
727
+ return len(self.model.get_initial_state())
728
+
729
+ @override(Policy)
730
+ def get_initial_state(self) -> List[TensorType]:
731
+ return [s.detach().cpu().numpy() for s in self.model.get_initial_state()]
732
+
733
+ @override(Policy)
734
+ def get_state(self) -> PolicyState:
735
+ state = super().get_state()
736
+
737
+ state["_optimizer_variables"] = []
738
+ for i, o in enumerate(self._optimizers):
739
+ optim_state_dict = convert_to_numpy(o.state_dict())
740
+ state["_optimizer_variables"].append(optim_state_dict)
741
+ # Add exploration state.
742
+ if self.exploration:
743
+ # This is not compatible with RLModules, which have a method
744
+ # `forward_exploration` to specify custom exploration behavior.
745
+ state["_exploration_state"] = self.exploration.get_state()
746
+ return state
747
+
748
+ @override(Policy)
749
+ def set_state(self, state: PolicyState) -> None:
750
+ # Set optimizer vars first.
751
+ optimizer_vars = state.get("_optimizer_variables", None)
752
+ if optimizer_vars:
753
+ assert len(optimizer_vars) == len(self._optimizers)
754
+ for o, s in zip(self._optimizers, optimizer_vars):
755
+ # Torch optimizer param_groups include things like beta, etc. These
756
+ # parameters should be left as scalar and not converted to tensors.
757
+ # otherwise, torch.optim.step() will start to complain.
758
+ optim_state_dict = {"param_groups": s["param_groups"]}
759
+ optim_state_dict["state"] = convert_to_torch_tensor(
760
+ s["state"], device=self.device
761
+ )
762
+ o.load_state_dict(optim_state_dict)
763
+ # Set exploration's state.
764
+ if hasattr(self, "exploration") and "_exploration_state" in state:
765
+ self.exploration.set_state(state=state["_exploration_state"])
766
+
767
+ # Restore global timestep.
768
+ self.global_timestep = state["global_timestep"]
769
+
770
+ # Then the Policy's (NN) weights and connectors.
771
+ super().set_state(state)
772
+
773
def extra_grad_process(
    self, optimizer: "torch.optim.Optimizer", loss: TensorType
) -> Dict[str, TensorType]:
    """Hook run after each `optimizer.zero_grad()` + `loss.backward()` pair.

    Invoked once per (optimizer, loss) pair, giving subclasses a chance to
    post-process gradients (e.g. clipping) before `optimizer.step()`.

    Args:
        optimizer: The torch optimizer whose gradients were just computed.
        loss: The loss tensor associated with that optimizer.

    Returns:
        A (possibly empty) info dict describing the gradient processing.
    """
    # Default: no gradient post-processing, nothing to report.
    return {}
790
+
791
def extra_compute_grad_fetches(self) -> Dict[str, Any]:
    """Extra values to fetch and return from `compute_gradients()`.

    Returns:
        Extra fetch dict to merge into the fetch dict of the
        `compute_gradients` call.
    """
    # Default: an empty learner-stats entry (stats, td-error, etc. would
    # typically be added here by subclasses).
    return {LEARNER_STATS_KEY: {}}
799
+
800
def extra_action_out(
    self,
    input_dict: Dict[str, TensorType],
    state_batches: List[TensorType],
    model: TorchModelV2,
    action_dist: TorchDistributionWrapper,
) -> Dict[str, TensorType]:
    """Returns a dict of extra info to include in the experience batch.

    Args:
        input_dict: Dict of model input tensors.
        state_batches: List of state tensors.
        model: Reference to the model object.
        action_dist: Torch action dist object, e.g. to get log-probs
            for already-sampled actions.

    Returns:
        Extra outputs to return from `compute_actions_from_input_dict()`
        (3rd return value).
    """
    # Default: nothing extra to record.
    return {}
821
+
822
def extra_grad_info(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Returns a dict of extra gradient/stats info.

    Args:
        train_batch: The training batch to produce extra grad info for.

    Returns:
        Info dict carrying grad info per string key.
    """
    # Default: no extra gradient info.
    return {}
833
+
834
def optimizer(
    self,
) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
    """Returns the local PyTorch optimizer(s) to use for this Policy.

    Uses Adam with the configured learning rate when a config is present,
    otherwise Adam with its default lr. The exploration component may then
    add to or replace the optimizer list.
    """
    if hasattr(self, "config"):
        opts = [torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])]
    else:
        # No config available -> fall back to Adam's default lr.
        opts = [torch.optim.Adam(self.model.parameters())]
    # Let the exploration component register its own optimizer(s), if any.
    if self.exploration:
        opts = self.exploration.get_exploration_optimizer(opts)
    return opts
851
+
852
@override(Policy)
def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None:
    """Exports the Policy's Model to a local directory for serving.

    Args:
        export_dir: Local writable directory or filename.
        onnx: If given, export the model in ONNX format; the value sets
            the ONNX OpSet version to use. Otherwise the full torch model
            is saved via `torch.save`.
    """
    os.makedirs(export_dir, exist_ok=True)

    if onnx:
        self._lazy_tensor_dict(self._dummy_batch)
        # Non-RNN models return an empty internal-states list, which torch
        # cannot jit/trace - provide dummy state (and seq-lens) inputs.
        if "state_in_0" not in self._dummy_batch:
            self._dummy_batch["state_in_0"] = self._dummy_batch[
                SampleBatch.SEQ_LENS
            ] = np.array([1.0])
        seq_lens = self._dummy_batch[SampleBatch.SEQ_LENS]

        # Collect all state-in tensors in index order.
        state_ins = []
        idx = 0
        while f"state_in_{idx}" in self._dummy_batch:
            state_ins.append(self._dummy_batch[f"state_in_{idx}"])
            idx += 1

        # Every column except "is_training" acts as a named model input.
        dummy_inputs = {
            k: self._dummy_batch[k]
            for k in self._dummy_batch.keys()
            if k != "is_training"
        }

        # All named inputs (incl. states and seq-lens) share a variable
        # batch dimension at axis 0.
        input_names = list(dummy_inputs.keys()) + [
            "state_ins",
            SampleBatch.SEQ_LENS,
        ]
        torch.onnx.export(
            self.model,
            (dummy_inputs, state_ins, seq_lens),
            os.path.join(export_dir, "model.onnx"),
            export_params=True,
            opset_version=onnx,
            do_constant_folding=True,
            input_names=input_names,
            output_names=["output", "state_outs"],
            dynamic_axes={k: {0: "batch_size"} for k in input_names},
        )
    else:
        # Save the full torch.Model (architecture AND weights), so it can
        # be restored without access to the original (custom) Model or
        # Policy code.
        filename = os.path.join(export_dir, "model.pt")
        try:
            torch.save(self.model, f=filename)
        except Exception:
            # Remove a partially written file before warning.
            if os.path.exists(filename):
                os.remove(filename)
            logger.warning(ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL)
913
+
914
@override(Policy)
def import_model_from_h5(self, import_file: str) -> None:
    """Imports weights from an h5 file into the underlying torch model.

    Args:
        import_file: Path of the h5 file to load weights from.
    """
    # Delegates entirely to the model's own h5-import logic.
    return self.model.import_from_h5(import_file)
918
+
919
@with_lock
def _compute_action_helper(
    self, input_dict, state_batches, seq_lens, explore, timestep
):
    """Shared forward-pass logic (w/ and w/o trajectory view API).

    Returns:
        A tuple consisting of a) actions, b) state_out, c) extra_fetches
        (all converted to numpy).
    """
    # Fall back to config / policy defaults when not explicitly given.
    explore = explore if explore is not None else self.config["explore"]
    timestep = timestep if timestep is not None else self.global_timestep
    self._is_recurrent = state_batches is not None and state_batches != []

    # Switch the model to eval mode for inference.
    if self.model:
        self.model.eval()

    if self.action_sampler_fn:
        # Fully custom sampling path: the user-provided fn produces
        # actions (and optionally dist inputs) directly.
        action_dist = dist_inputs = None
        action_sampler_outputs = self.action_sampler_fn(
            self,
            self.model,
            input_dict,
            state_batches,
            explore=explore,
            timestep=timestep,
        )
        if len(action_sampler_outputs) == 4:
            actions, logp, dist_inputs, state_out = action_sampler_outputs
        else:
            actions, logp, state_out = action_sampler_outputs
    else:
        # Call the exploration's before_compute_actions hook.
        self.exploration.before_compute_actions(explore=explore, timestep=timestep)
        if self.action_distribution_fn:
            # Try the new action_distribution_fn signature first
            # (supports state_batches and seq_lens).
            try:
                dist_inputs, dist_class, state_out = self.action_distribution_fn(
                    self,
                    self.model,
                    input_dict=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=explore,
                    timestep=timestep,
                    is_training=False,
                )
            except TypeError as e:
                # Old (positional) signature, for backward compatibility.
                # TODO: Remove in future.
                if (
                    "positional argument" in e.args[0]
                    or "unexpected keyword argument" in e.args[0]
                ):
                    (
                        dist_inputs,
                        dist_class,
                        state_out,
                    ) = self.action_distribution_fn(
                        self,
                        self.model,
                        input_dict[SampleBatch.CUR_OBS],
                        explore=explore,
                        timestep=timestep,
                        is_training=False,
                    )
                else:
                    raise e
        else:
            # Default path: model forward pass + policy's dist class.
            dist_class = self.dist_class
            dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)

        if not (
            isinstance(dist_class, functools.partial)
            or issubclass(dist_class, TorchDistributionWrapper)
        ):
            raise ValueError(
                "`dist_class` ({}) not a TorchDistributionWrapper "
                "subclass! Make sure your `action_distribution_fn` or "
                "`make_model_and_action_dist` return a correct "
                "distribution class.".format(dist_class.__name__)
            )
        action_dist = dist_class(dist_inputs, self.model)

        # Get the exploration action from the forward results.
        actions, logp = self.exploration.get_exploration_action(
            action_distribution=action_dist, timestep=timestep, explore=explore
        )

    input_dict[SampleBatch.ACTIONS] = actions

    # Add default and custom fetches.
    extra_fetches = self.extra_action_out(
        input_dict, state_batches, self.model, action_dist
    )

    # Action-dist inputs.
    if dist_inputs is not None:
        extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

    # Action-logp and action-prob.
    if logp is not None:
        extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp.float())
        extra_fetches[SampleBatch.ACTION_LOGP] = logp

    # Update our global timestep by the batch size.
    self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

    return convert_to_numpy((actions, state_out, extra_fetches))
1029
+
1030
def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None):
    """Attaches a lazy torch-conversion interceptor to the given batch.

    Columns are converted to torch tensors (on `device`, defaulting to
    this policy's device) only upon first access.
    """
    # TODO: (sven): Keep for a while to ensure backward compatibility.
    if not isinstance(postprocessed_batch, SampleBatch):
        postprocessed_batch = SampleBatch(postprocessed_batch)
    postprocessed_batch.set_get_interceptor(
        functools.partial(convert_to_torch_tensor, device=device or self.device)
    )
    return postprocessed_batch
1038
+
1039
def _multi_gpu_parallel_grad_calc(
    self, sample_batches: List[SampleBatch]
) -> List[Tuple[List[TensorType], GradInfoDict]]:
    """Performs a parallelized loss and gradient calculation over the batch.

    Splits up the given train batch into n shards (n=number of this
    Policy's devices) and passes each data shard (in parallel) through
    the loss function using the individual devices' models
    (self.model_gpu_towers). Then returns each tower's outputs.

    Args:
        sample_batches: A list of SampleBatch shards to calculate loss and
            gradients for.

    Returns:
        A list (one item per device) of 2-tuples, each with 1) gradient
        list and 2) grad info dict.
    """
    assert len(self.model_gpu_towers) == len(sample_batches)
    lock = threading.Lock()
    results = {}
    # The grad-enabled flag must be propagated explicitly into worker
    # threads (it is thread-local in torch).
    grad_enabled = torch.is_grad_enabled()

    def _worker(shard_idx, model, sample_batch, device):
        torch.set_grad_enabled(grad_enabled)
        try:
            with NullContextManager() if device.type == "cpu" else torch.cuda.device(  # noqa: E501
                device
            ):
                loss_out = force_list(
                    self._loss(self, model, self.dist_class, sample_batch)
                )

                # Call Model's custom-loss with Policy loss outputs and
                # train_batch.
                loss_out = model.custom_loss(loss_out, sample_batch)

                assert len(loss_out) == len(self._optimizers)

                # Loop through all optimizers.
                grad_info = {"allreduce_latency": 0.0}

                parameters = list(model.parameters())
                all_grads = [None for _ in range(len(parameters))]
                for opt_idx, opt in enumerate(self._optimizers):
                    # Erase gradients in all vars of the tower that this
                    # optimizer would affect.
                    param_indices = self.multi_gpu_param_groups[opt_idx]
                    for param_idx, param in enumerate(parameters):
                        if param_idx in param_indices and param.grad is not None:
                            param.grad.data.zero_()
                    # Recompute gradients of loss over all variables.
                    loss_out[opt_idx].backward(retain_graph=True)
                    grad_info.update(
                        self.extra_grad_process(opt, loss_out[opt_idx])
                    )

                    grads = []
                    # Note that return values are just references;
                    # calling zero_grad would modify the values.
                    for param_idx, param in enumerate(parameters):
                        if param_idx in param_indices:
                            if param.grad is not None:
                                grads.append(param.grad)
                            all_grads[param_idx] = param.grad

                    if self.distributed_world_size:
                        start = time.time()
                        if torch.cuda.is_available():
                            # Sadly, allreduce_coalesced does not work with
                            # CUDA yet.
                            for g in grads:
                                torch.distributed.all_reduce(
                                    g, op=torch.distributed.ReduceOp.SUM
                                )
                        else:
                            torch.distributed.all_reduce_coalesced(
                                grads, op=torch.distributed.ReduceOp.SUM
                            )

                        # Average the summed gradients over all workers.
                        for param_group in opt.param_groups:
                            for p in param_group["params"]:
                                if p.grad is not None:
                                    p.grad /= self.distributed_world_size

                        grad_info["allreduce_latency"] += time.time() - start

                with lock:
                    results[shard_idx] = (all_grads, grad_info)
        except Exception as e:
            import traceback

            # Record the error (with its traceback) instead of raising in
            # the worker thread; re-raised by the caller below.
            with lock:
                results[shard_idx] = (
                    ValueError(
                        f"Error In tower {shard_idx} on device "
                        f"{device} during multi GPU parallel gradient "
                        f"calculation:"
                        f": {e}\n"
                        f"Traceback: \n"
                        f"{traceback.format_exc()}\n"
                    ),
                    e,
                )

    # Single device (GPU) or fake-GPU case (serialize for better
    # debugging).
    if len(self.devices) == 1 or self.config["_fake_gpus"]:
        for shard_idx, (model, sample_batch, device) in enumerate(
            zip(self.model_gpu_towers, sample_batches, self.devices)
        ):
            _worker(shard_idx, model, sample_batch, device)
            # Raise errors right away for better debugging.
            last_result = results[len(results) - 1]
            if isinstance(last_result[0], ValueError):
                raise last_result[0] from last_result[1]
    # Multi device (GPU) case: Parallelize via threads.
    else:
        threads = [
            threading.Thread(
                target=_worker, args=(shard_idx, model, sample_batch, device)
            )
            for shard_idx, (model, sample_batch, device) in enumerate(
                zip(self.model_gpu_towers, sample_batches, self.devices)
            )
        ]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

    # Gather all threads' outputs and return.
    outputs = []
    for shard_idx in range(len(sample_batches)):
        output = results[shard_idx]
        if isinstance(output[0], Exception):
            raise output[0] from output[1]
        outputs.append(results[shard_idx])
    return outputs
1179
+
1180
+
1181
@OldAPIStack
class DirectStepOptimizer:
    """Typesafe sentinel indicating `apply_gradients` can directly step the
    optimizers with in-place gradients.

    Implemented as a singleton: every instantiation returns the same
    object, and equality is by type, so independently created instances
    compare equal.
    """

    _instance = None

    def __new__(cls):
        # Lazily create and cache the single shared instance.
        if DirectStepOptimizer._instance is None:
            DirectStepOptimizer._instance = super().__new__(cls)
        return DirectStepOptimizer._instance

    def __eq__(self, other):
        return type(self) is type(other)

    def __hash__(self):
        # Fix: defining __eq__ without __hash__ sets __hash__ to None,
        # making the sentinel unhashable (unusable as a dict/set key).
        # Hash by type to stay consistent with __eq__ (equal objects must
        # hash equal).
        return hash(type(self))

    def __repr__(self):
        return "DirectStepOptimizer"


_directStepOptimizerSingleton = DirectStepOptimizer()
.venv/lib/python3.11/site-packages/ray/rllib/policy/torch_policy_v2.py ADDED
@@ -0,0 +1,1260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import functools
3
+ import logging
4
+ import math
5
+ import os
6
+ import threading
7
+ import time
8
+ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
9
+
10
+ import gymnasium as gym
11
+ import numpy as np
12
+ from packaging import version
13
+ import tree # pip install dm_tree
14
+
15
+ import ray
16
+ from ray.rllib.models.catalog import ModelCatalog
17
+ from ray.rllib.models.modelv2 import ModelV2
18
+ from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
19
+ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
20
+ from ray.rllib.policy.policy import Policy
21
+ from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
22
+ from ray.rllib.policy.sample_batch import SampleBatch
23
+ from ray.rllib.policy.torch_policy import _directStepOptimizerSingleton
24
+ from ray.rllib.utils import NullContextManager, force_list
25
+ from ray.rllib.utils.annotations import (
26
+ OldAPIStack,
27
+ OverrideToImplementCustomLogic,
28
+ OverrideToImplementCustomLogic_CallToSuperRecommended,
29
+ is_overridden,
30
+ override,
31
+ )
32
+ from ray.rllib.utils.error import ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL
33
+ from ray.rllib.utils.framework import try_import_torch
34
+ from ray.rllib.utils.metrics import (
35
+ DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
36
+ NUM_AGENT_STEPS_TRAINED,
37
+ NUM_GRAD_UPDATES_LIFETIME,
38
+ )
39
+ from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
40
+ from ray.rllib.utils.numpy import convert_to_numpy
41
+ from ray.rllib.utils.spaces.space_utils import normalize_action
42
+ from ray.rllib.utils.threading import with_lock
43
+ from ray.rllib.utils.torch_utils import (
44
+ convert_to_torch_tensor,
45
+ TORCH_COMPILE_REQUIRED_VERSION,
46
+ )
47
+ from ray.rllib.utils.typing import (
48
+ AlgorithmConfigDict,
49
+ GradInfoDict,
50
+ ModelGradients,
51
+ ModelWeights,
52
+ PolicyState,
53
+ TensorStructType,
54
+ TensorType,
55
+ )
56
+
57
+ torch, nn = try_import_torch()
58
+
59
+ logger = logging.getLogger(__name__)
60
+
61
+
62
+ @OldAPIStack
63
+ class TorchPolicyV2(Policy):
64
+ """PyTorch specific Policy class to use with RLlib."""
65
+
66
def __init__(
    self,
    observation_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
    *,
    max_seq_len: int = 20,
):
    """Initializes a TorchPolicy instance.

    Args:
        observation_space: Observation space of the policy.
        action_space: Action space of the policy.
        config: The Policy's config dict.
        max_seq_len: Max sequence length for LSTM training.
    """
    self.framework = config["framework"] = "torch"

    self._loss_initialized = False
    super().__init__(observation_space, action_space, config)

    # Create model.
    model, dist_class = self._init_model_and_dist_class()

    # Create multi-GPU model towers, if necessary.
    # - The central main model will be stored under self.model, residing
    #   on self.device (normally, a CPU).
    # - Each GPU will have a copy of that model under
    #   self.model_gpu_towers, matching the devices in self.devices.
    # - Parallelization is done by splitting the train batch and passing
    #   it through the model copies in parallel, then averaging over the
    #   resulting gradients, applying these averages on the main model and
    #   updating all towers' weights from the main model.
    # - In case of just one device (1 (fake or real) GPU or 1 CPU), no
    #   parallelization will be done.

    # Get devices to build the graph on.
    num_gpus = self._get_num_gpus_for_policy()
    gpu_ids = list(range(torch.cuda.device_count()))
    logger.info(f"Found {len(gpu_ids)} visible cuda devices.")

    # Place on one or more CPU(s) when either:
    # - Fake GPU mode.
    # - num_gpus=0 (either set by user or we are in local_mode=True).
    # - No GPUs available.
    if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
        self.device = torch.device("cpu")
        self.devices = [self.device for _ in range(int(math.ceil(num_gpus)) or 1)]
        self.model_gpu_towers = [
            model if i == 0 else copy.deepcopy(model)
            for i in range(int(math.ceil(num_gpus)) or 1)
        ]
        if hasattr(self, "target_model"):
            self.target_models = {
                m: self.target_model for m in self.model_gpu_towers
            }
        self.model = model
    # Place on one or more actual GPU(s), when:
    # - num_gpus > 0 (set by user) AND
    # - local_mode=False AND
    # - actual GPUs available AND
    # - non-fake GPU mode.
    else:
        # We are a remote worker (WORKER_MODE=1):
        # GPUs should be assigned to us by ray.
        if ray._private.worker._mode() == ray._private.worker.WORKER_MODE:
            gpu_ids = ray.get_gpu_ids()

        if len(gpu_ids) < num_gpus:
            raise ValueError(
                "TorchPolicy was not able to find enough GPU IDs! Found "
                f"{gpu_ids}, but num_gpus={num_gpus}."
            )

        self.devices = [
            torch.device("cuda:{}".format(i))
            for i, id_ in enumerate(gpu_ids)
            if i < num_gpus
        ]
        self.device = self.devices[0]
        ids = [id_ for i, id_ in enumerate(gpu_ids) if i < num_gpus]
        self.model_gpu_towers = []
        for i, _ in enumerate(ids):
            model_copy = copy.deepcopy(model)
            self.model_gpu_towers.append(model_copy.to(self.devices[i]))
        if hasattr(self, "target_model"):
            self.target_models = {
                m: copy.deepcopy(self.target_model).to(self.devices[i])
                for i, m in enumerate(self.model_gpu_towers)
            }
        self.model = self.model_gpu_towers[0]

    self.dist_class = dist_class
    self.unwrapped_model = model  # used to support DistributedDataParallel

    # Lock used for locking some methods on the object-level.
    # This prevents possible race conditions when calling the model
    # first, then its value function (e.g. in a loss function), in
    # between of which another model call is made (e.g. to compute an
    # action).
    self._lock = threading.RLock()

    self._state_inputs = self.model.get_initial_state()
    self._is_recurrent = len(tree.flatten(self._state_inputs)) > 0
    # Auto-update model's inference view requirements, if recurrent.
    self._update_model_view_requirements_from_init_state()
    # Combine view_requirements for Model and Policy.
    self.view_requirements.update(self.model.view_requirements)

    self.exploration = self._create_exploration()
    self._optimizers = force_list(self.optimizer())

    # Backward compatibility workaround so Policy will call self.loss()
    # directly.
    # TODO (jungong): clean up after all policies are migrated to new
    #  sub-class implementation.
    self._loss = None

    # Store, which params (by index within the model's list of
    # parameters) should be updated per optimizer.
    # Maps optimizer idx to set or param indices.
    self.multi_gpu_param_groups: List[Set[int]] = []
    main_params = {p: i for i, p in enumerate(self.model.parameters())}
    for o in self._optimizers:
        param_indices = []
        for pg_idx, pg in enumerate(o.param_groups):
            for p in pg["params"]:
                param_indices.append(main_params[p])
        self.multi_gpu_param_groups.append(set(param_indices))

    # Create n sample-batch buffers (num_multi_gpu_tower_stacks), each
    # one with m towers (num_gpus).
    num_buffers = self.config.get("num_multi_gpu_tower_stacks", 1)
    self._loaded_batches = [[] for _ in range(num_buffers)]

    # If set, means we are using distributed allreduce during learning.
    self.distributed_world_size = None

    self.batch_divisibility_req = self.get_batch_divisibility_req()
    self.max_seq_len = max_seq_len

    # If model is an RLModule it won't have tower_stats instead there
    # will be a self.tower_state[model] -> dict for each tower.
    self.tower_stats = {}
    if not hasattr(self.model, "tower_stats"):
        for model in self.model_gpu_towers:
            self.tower_stats[model] = {}
213
+
214
def loss_initialized(self):
    """Returns True once the loss function has been initialized."""
    return self._loss_initialized
216
+
217
@OverrideToImplementCustomLogic
@override(Policy)
def loss(
    self,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss function.

    Args:
        model: The Model to calculate the loss for.
        dist_class: The action distr. class.
        train_batch: The training data.

    Returns:
        Loss tensor given the input batch.

    Raises:
        NotImplementedError: Always - subclasses must implement this.
    """
    raise NotImplementedError
236
+
237
@OverrideToImplementCustomLogic
def action_sampler_fn(
    self,
    model: ModelV2,
    *,
    obs_batch: TensorType,
    state_batches: TensorType,
    **kwargs,
) -> Tuple[TensorType, TensorType, TensorType, List[TensorType]]:
    """Custom action-sampling hook for this policy.

    Args:
        model: Underlying model.
        obs_batch: Observation tensor batch.
        state_batches: Action sampling state batch.

    Returns:
        A 4-tuple of (sampled action, log-likelihood, action distribution
        inputs, updated state). All-None by default, which signals "no
        custom sampler" to the policy.
    """
    # Default: no custom sampling behavior.
    return None, None, None, None
260
+
261
@OverrideToImplementCustomLogic
def action_distribution_fn(
    self,
    model: ModelV2,
    *,
    obs_batch: TensorType,
    state_batches: TensorType,
    **kwargs,
) -> Tuple[TensorType, type, List[TensorType]]:
    """Custom action-distribution hook for this Policy.

    Args:
        model: Underlying model.
        obs_batch: Observation tensor batch.
        state_batches: Action sampling state batch.

    Returns:
        A 3-tuple of (distribution inputs, ActionDistribution class,
        state outs). All-None by default, which signals "no custom
        distribution fn" to the policy.
    """
    # Default: no custom action-distribution behavior.
    return None, None, None
283
+
284
@OverrideToImplementCustomLogic
def make_model(self) -> ModelV2:
    """Creates this policy's model.

    Note: only one of `make_model` or `make_model_and_action_dist`
    can be overridden.

    Returns:
        ModelV2 model, or None (default) to let the catalog build one.
    """
    return None
295
+
296
@OverrideToImplementCustomLogic
def make_model_and_action_dist(
    self,
) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]:
    """Creates this policy's model and action distribution class.

    Returns:
        A 2-tuple of (ModelV2 model, ActionDistribution class), or
        (None, None) (default) to let the catalog build them.
    """
    return None, None
307
+
308
@OverrideToImplementCustomLogic
def get_batch_divisibility_req(self) -> int:
    """Returns the required batch divisibility.

    Returns:
        Size N. A sample batch must be of size K*N.
    """
    # By default, any sized batch is ok, so simply return 1.
    return 1
317
+
318
@OverrideToImplementCustomLogic
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Returns a dict of training statistics.

    Args:
        train_batch: The SampleBatch (already) used for training.

    Returns:
        The stats dict (empty by default).
    """
    return {}
329
+
330
@OverrideToImplementCustomLogic_CallToSuperRecommended
def extra_grad_process(
    self, optimizer: "torch.optim.Optimizer", loss: TensorType
) -> Dict[str, TensorType]:
    """Hook run after each `optimizer.zero_grad()` + `loss.backward()` pair.

    Invoked once per (optimizer, loss) pair, giving subclasses a chance to
    post-process gradients (e.g. clipping) before `optimizer.step()`.

    Args:
        optimizer: The torch optimizer whose gradients were just computed.
        loss: The loss tensor associated with that optimizer.

    Returns:
        A (possibly empty) info dict describing the gradient processing.
    """
    # Default: no gradient post-processing, nothing to report.
    return {}
348
+
349
@OverrideToImplementCustomLogic_CallToSuperRecommended
def extra_compute_grad_fetches(self) -> Dict[str, Any]:
    """Extra values to fetch and return from `compute_gradients()`.

    Returns:
        Extra fetch dict to merge into the fetch dict of the
        `compute_gradients` call.
    """
    # Default: an empty learner-stats entry (stats, td-error, etc. would
    # typically be added here by subclasses).
    return {LEARNER_STATS_KEY: {}}
358
+
359
@OverrideToImplementCustomLogic_CallToSuperRecommended
def extra_action_out(
    self,
    input_dict: Dict[str, TensorType],
    state_batches: List[TensorType],
    model: TorchModelV2,
    action_dist: TorchDistributionWrapper,
) -> Dict[str, TensorType]:
    """Returns a dict of extra info to include in the experience batch.

    Args:
        input_dict: Dict of model input tensors.
        state_batches: List of state tensors.
        model: Reference to the model object.
        action_dist: Torch action dist object, e.g. to get log-probs
            for already-sampled actions.

    Returns:
        Extra outputs to return from `compute_actions_from_input_dict()`
        (3rd return value).
    """
    # Default: nothing extra to record.
    return {}
381
+
382
@override(Policy)
@OverrideToImplementCustomLogic_CallToSuperRecommended
def postprocess_trajectory(
    self,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
    episode=None,
) -> SampleBatch:
    """Postprocesses a trajectory and returns the processed trajectory.

    The trajectory contains only data from one episode and from one agent.
    - If `config.batch_mode=truncate_episodes` (default), sample_batch may
      contain a truncated (at-the-end) episode, in case the
      `config.rollout_fragment_length` was reached by the sampler.
    - If `config.batch_mode=complete_episodes`, sample_batch will contain
      exactly one episode (no matter how long).
    New columns can be added to sample_batch and existing ones may be
    altered.

    Args:
        sample_batch: The SampleBatch to postprocess.
        other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
            dict of AgentIDs mapping to other agents' trajectory data (from
            the same episode). NOTE: The other agents use the same policy.
        episode (Optional[Episode]): Optional multi-agent episode
            object in which the agents operated.

    Returns:
        SampleBatch: The postprocessed, modified SampleBatch (or a new one).
    """
    # Default: pass the batch through unchanged.
    return sample_batch
412
+
413
@OverrideToImplementCustomLogic
def optimizer(
    self,
) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
    """Returns the local PyTorch optimizer(s) to use for this Policy.

    Uses Adam with the configured learning rate when a config is present,
    otherwise Adam with its default lr. The exploration component may then
    add to or replace the optimizer list.
    """
    if hasattr(self, "config"):
        opts = [torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])]
    else:
        # No config available -> fall back to Adam's default lr.
        opts = [torch.optim.Adam(self.model.parameters())]
    # Let the exploration component register its own optimizer(s), if any.
    if self.exploration:
        opts = self.exploration.get_exploration_optimizer(opts)
    return opts
431
+
432
def _init_model_and_dist_class(self):
    """Builds this policy's model and action-distribution class.

    Resolution order: `make_model()` override -> `make_model_and_action_dist()`
    override -> default ModelCatalog construction. Optionally wraps the
    model in `torch.compile` when requested via the config.

    Returns:
        A 2-tuple of (model, dist_class).

    Raises:
        ValueError: If both `make_model` and `make_model_and_action_dist`
            are overridden, or if `torch.compile` is requested on an
            unsupported torch version.
    """
    if is_overridden(self.make_model) and is_overridden(
        self.make_model_and_action_dist
    ):
        raise ValueError(
            "Only one of make_model or make_model_and_action_dist "
            "can be overridden."
        )

    if is_overridden(self.make_model):
        # Custom model; dist class still comes from the catalog.
        model = self.make_model()
        dist_class, _ = ModelCatalog.get_action_dist(
            self.action_space, self.config["model"], framework=self.framework
        )
    elif is_overridden(self.make_model_and_action_dist):
        # Custom model AND dist class.
        model, dist_class = self.make_model_and_action_dist()
    else:
        # Fully default construction via the catalog.
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            self.action_space, self.config["model"], framework=self.framework
        )
        model = ModelCatalog.get_model_v2(
            obs_space=self.observation_space,
            action_space=self.action_space,
            num_outputs=logit_dim,
            model_config=self.config["model"],
            framework=self.framework,
        )

    # Compile the model, if requested by the user.
    if self.config.get("torch_compile_learner"):
        if (
            torch is not None
            and version.parse(torch.__version__) < TORCH_COMPILE_REQUIRED_VERSION
        ):
            raise ValueError("`torch.compile` is not supported for torch < 2.0.0!")

        lw = "learner" if self.config.get("worker_index") else "worker"
        model = torch.compile(
            model,
            backend=self.config.get(
                f"torch_compile_{lw}_dynamo_backend", "inductor"
            ),
            dynamic=False,
            mode=self.config.get(f"torch_compile_{lw}_dynamo_mode"),
        )
    return model, dist_class
478
+
479
    @override(Policy)
    def compute_actions_from_input_dict(
        self,
        input_dict: Dict[str, TensorType],
        explore: bool = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Computes actions from a full SampleBatch-style input dict.

        Runs entirely under `torch.no_grad()` (pure inference) and delegates
        the forward pass to `_compute_action_helper()`.

        Args:
            input_dict: Dict holding at least observations; may also contain
                `state_in_*` columns for recurrent models.
            explore: Whether to apply exploration (None: helper falls back to
                the config default).
            timestep: Current global timestep (None: helper falls back to
                `self.global_timestep`).

        Returns:
            Tuple of (actions, RNN state-outs, extra action fetches).
        """
        seq_lens = None
        with torch.no_grad():
            # Pass lazy (torch) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            input_dict.set_training(True)
            # Pack internal state inputs into (separate) list.
            # Matches keys starting with "state_in" (prefix check on k[:8]).
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]
            # Calculate RNN sequence lengths: each batch row is treated as an
            # independent single step (seq-len 1) during inference.
            if state_batches:
                seq_lens = torch.tensor(
                    [1] * len(state_batches[0]),
                    dtype=torch.long,
                    device=state_batches[0].device,
                )

            return self._compute_action_helper(
                input_dict, state_batches, seq_lens, explore, timestep
            )
508
+
509
    @override(Policy)
    def compute_actions(
        self,
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
        prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes=None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
        """Computes actions from separate obs/state/prev-action/-reward args.

        Bundles the separate arguments into a lazy tensor input dict, then
        delegates to `_compute_action_helper()` under `torch.no_grad()`.
        `info_batch` and `episodes` are accepted but not used here.

        Returns:
            Tuple of (actions, RNN state-outs, extra action fetches).
        """
        with torch.no_grad():
            # Inference: every batch row counts as one (independent) timestep.
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            input_dict = self._lazy_tensor_dict(
                {
                    SampleBatch.CUR_OBS: obs_batch,
                    "is_training": False,
                }
            )
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = np.asarray(prev_action_batch)
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = np.asarray(prev_reward_batch)
            # Move any given internal states to this Policy's device.
            state_batches = [
                convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
            ]
            return self._compute_action_helper(
                input_dict, state_batches, seq_lens, explore, timestep
            )
541
+
542
    @with_lock
    @override(Policy)
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorStructType], TensorStructType],
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[
            Union[List[TensorStructType], TensorStructType]
        ] = None,
        prev_reward_batch: Optional[
            Union[List[TensorStructType], TensorStructType]
        ] = None,
        actions_normalized: bool = True,
        in_training: bool = True,
    ) -> TensorType:
        """Computes the log-prob of the given actions under current policy.

        Builds the action distribution (via `action_distribution_fn` if
        overridden, else via the model + `self.dist_class`) and evaluates
        `logp(actions)` with exploration disabled and no gradients.

        Raises:
            ValueError: If only `action_sampler_fn` is overridden (sampling
                alone gives no distribution to evaluate log-probs with).
        """
        if is_overridden(self.action_sampler_fn) and not is_overridden(
            self.action_distribution_fn
        ):
            raise ValueError(
                "Cannot compute log-prob/likelihood w/o an "
                "`action_distribution_fn` and a provided "
                "`action_sampler_fn`!"
            )

        with torch.no_grad():
            input_dict = self._lazy_tensor_dict(
                {SampleBatch.CUR_OBS: obs_batch, SampleBatch.ACTIONS: actions}
            )
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            # Each row is one independent step.
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            state_batches = [
                convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
            ]

            if self.exploration:
                # Exploration hook before each forward pass.
                self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if is_overridden(self.action_distribution_fn):
                dist_inputs, dist_class, state_out = self.action_distribution_fn(
                    self.model,
                    obs_batch=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=False,
                    is_training=False,
                )
                action_dist = dist_class(dist_inputs, self.model)
            # Default action-dist inputs calculation.
            else:
                dist_class = self.dist_class
                dist_inputs, _ = self.model(input_dict, state_batches, seq_lens)

                action_dist = dist_class(dist_inputs, self.model)

            # Normalize actions if necessary (caller passed raw env actions).
            actions = input_dict[SampleBatch.ACTIONS]
            if not actions_normalized and self.config["normalize_actions"]:
                actions = normalize_action(actions, self.action_space_struct)

            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods
611
+
612
    @with_lock
    @override(Policy)
    def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
        """Performs one gradient update on the given (postprocessed) batch.

        Computes gradients via `compute_gradients()` (which runs the loss and
        `backward()`), then steps the optimizers directly.

        Args:
            postprocessed_batch: The train batch to learn on.

        Returns:
            Dict of learner stats fetches (incl. custom metrics and
            grad-update counters).
        """
        # Set Model to train mode.
        if self.model:
            self.model.train()
        # Callback handling (callbacks may write into `learn_stats`).
        learn_stats = {}
        self.callbacks.on_learn_on_batch(
            policy=self, train_batch=postprocessed_batch, result=learn_stats
        )

        # Compute gradients (will calculate all losses and `backward()`
        # them to get the grads).
        grads, fetches = self.compute_gradients(postprocessed_batch)

        # Step the optimizers. The sentinel tells `apply_gradients` that the
        # grads already live on the model params (no copy needed).
        self.apply_gradients(_directStepOptimizerSingleton)

        self.num_grad_updates += 1
        if self.model and hasattr(self.model, "metrics"):
            fetches["model"] = self.model.metrics()
        else:
            fetches["model"] = {}

        fetches.update(
            {
                "custom_metrics": learn_stats,
                NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
                NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                # -1, b/c we have to measure this diff before we do the update above.
                DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                    self.num_grad_updates
                    - 1
                    - (postprocessed_batch.num_grad_updates or 0)
                ),
            }
        )

        return fetches
653
+
654
    @override(Policy)
    def load_batch_into_buffer(
        self,
        batch: SampleBatch,
        buffer_index: int = 0,
    ) -> int:
        """Loads (zero-padded) batch shards onto this Policy's device(s).

        Splits the batch across all devices, pads each shard to uniform
        sequence lengths, and stores the shards in
        `self._loaded_batches[buffer_index]` for later
        `learn_on_loaded_batch()` calls.

        Args:
            batch: The SampleBatch to load.
            buffer_index: Which buffer slot to load into (must be 0 in the
                single-CPU case).

        Returns:
            Number of samples loaded per device.
        """
        # Set the is_training flag of the batch.
        batch.set_training(True)

        # Shortcut for 1 CPU only: Store batch in `self._loaded_batches`.
        if len(self.devices) == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
            pad_batch_to_sequences_of_same_size(
                batch=batch,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
                _enable_new_api_stack=False,
                padding="zero",
            )
            self._lazy_tensor_dict(batch)
            self._loaded_batches[0] = [batch]
            return len(batch)

        # Batch (len=28, seq-lens=[4, 7, 4, 10, 3]):
        # 0123 0123456 0123 0123456789ABC

        # 1) split into n per-GPU sub batches (n=2).
        # [0123 0123456] [012] [3 0123456789 ABC]
        # (len=14, 14 seq-lens=[4, 7, 3] [1, 10, 3])
        slices = batch.timeslices(num_slices=len(self.devices))

        # 2) zero-padding (max-seq-len=10).
        # - [0123000000 0123456000 0120000000]
        # - [3000000000 0123456789 ABC0000000]
        for slice in slices:
            pad_batch_to_sequences_of_same_size(
                batch=slice,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
                _enable_new_api_stack=False,
                padding="zero",
            )

        # 3) Load splits into the given buffer (consisting of n GPUs).
        slices = [slice.to_device(self.devices[i]) for i, slice in enumerate(slices)]
        self._loaded_batches[buffer_index] = slices

        # Return loaded samples per-device.
        return len(slices[0])
707
+
708
+ @override(Policy)
709
+ def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
710
+ if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
711
+ assert buffer_index == 0
712
+ return sum(len(b) for b in self._loaded_batches[buffer_index])
713
+
714
    @override(Policy)
    def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
        """Runs one (multi-GPU) update on a previously loaded batch slice.

        Slices `device_batch_size` samples starting at `offset` from each
        per-device shard in `self._loaded_batches[buffer_index]`, computes
        gradients on each tower in parallel, mean-reduces them onto the main
        model, and steps the optimizers.

        Args:
            offset: Start index (per device shard) of the minibatch slice.
            buffer_index: Which loaded buffer to train on.

        Returns:
            Dict of per-tower fetches (stats, model metrics, counters).

        Raises:
            ValueError: If no batch was loaded into the given buffer.
        """
        if not self._loaded_batches[buffer_index]:
            raise ValueError(
                "Must call Policy.load_batch_into_buffer() before "
                "Policy.learn_on_loaded_batch()!"
            )

        # Get the correct slice of the already loaded batch to use,
        # based on offset and batch size.
        # Minibatch size precedence: minibatch_size > sgd_minibatch_size >
        # train_batch_size; then divide evenly across devices.
        device_batch_size = self.config.get("minibatch_size")
        if device_batch_size is None:
            device_batch_size = self.config.get(
                "sgd_minibatch_size",
                self.config["train_batch_size"],
            )
        device_batch_size //= len(self.devices)

        # Set Model to train mode.
        if self.model_gpu_towers:
            for t in self.model_gpu_towers:
                t.train()

        # Shortcut for 1 CPU only: Batch should already be stored in
        # `self._loaded_batches`.
        if len(self.devices) == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
            if device_batch_size >= len(self._loaded_batches[0][0]):
                batch = self._loaded_batches[0][0]
            else:
                batch = self._loaded_batches[0][0][offset : offset + device_batch_size]

            return self.learn_on_batch(batch)

        if len(self.devices) > 1:
            # Copy weights of main model (tower-0) to all other towers.
            state_dict = self.model.state_dict()
            # Just making sure tower-0 is really the same as self.model.
            assert self.model_gpu_towers[0] is self.model
            for tower in self.model_gpu_towers[1:]:
                tower.load_state_dict(state_dict)

        # Use the full shards if the minibatch covers them; else slice.
        if device_batch_size >= sum(len(s) for s in self._loaded_batches[buffer_index]):
            device_batches = self._loaded_batches[buffer_index]
        else:
            device_batches = [
                b[offset : offset + device_batch_size]
                for b in self._loaded_batches[buffer_index]
            ]

        # Callback handling.
        batch_fetches = {}
        for i, batch in enumerate(device_batches):
            custom_metrics = {}
            self.callbacks.on_learn_on_batch(
                policy=self, train_batch=batch, result=custom_metrics
            )
            batch_fetches[f"tower_{i}"] = {"custom_metrics": custom_metrics}

        # Do the (maybe parallelized) gradient calculation step.
        tower_outputs = self._multi_gpu_parallel_grad_calc(device_batches)

        # Mean-reduce gradients over GPU-towers (do this on CPU: self.device).
        all_grads = []
        for i in range(len(tower_outputs[0][0])):
            if tower_outputs[0][0][i] is not None:
                all_grads.append(
                    torch.mean(
                        torch.stack([t[0][i].to(self.device) for t in tower_outputs]),
                        dim=0,
                    )
                )
            else:
                # Param had no gradient on any tower.
                all_grads.append(None)
        # Set main model's grads to mean-reduced values.
        for i, p in enumerate(self.model.parameters()):
            p.grad = all_grads[i]

        self.apply_gradients(_directStepOptimizerSingleton)

        self.num_grad_updates += 1

        for i, (model, batch) in enumerate(zip(self.model_gpu_towers, device_batches)):
            batch_fetches[f"tower_{i}"].update(
                {
                    LEARNER_STATS_KEY: self.stats_fn(batch),
                    "model": model.metrics(),
                    NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                    # -1, b/c we have to measure this diff before we do the update
                    # above.
                    DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                        self.num_grad_updates - 1 - (batch.num_grad_updates or 0)
                    ),
                }
            )
        batch_fetches.update(self.extra_compute_grad_fetches())

        return batch_fetches
812
+
813
    @with_lock
    @override(Policy)
    def compute_gradients(self, postprocessed_batch: SampleBatch) -> ModelGradients:
        """Computes gradients (single-device) for the given train batch.

        Pads the batch to uniform sequence lengths if needed, runs the loss
        and `backward()` via `_multi_gpu_parallel_grad_calc` (with a single
        shard), and assembles the grad-info fetches.

        Args:
            postprocessed_batch: The train batch to compute gradients for.

        Returns:
            Tuple of (gradients list, fetch dict incl. learner stats).
        """
        # Only valid in the single-device setup.
        assert len(self.devices) == 1

        # If not done yet, see whether we have to zero-pad this batch.
        if not postprocessed_batch.zero_padded:
            pad_batch_to_sequences_of_same_size(
                batch=postprocessed_batch,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
                _enable_new_api_stack=False,
                padding="zero",
            )

        postprocessed_batch.set_training(True)
        self._lazy_tensor_dict(postprocessed_batch, device=self.devices[0])

        # Do the (maybe parallelized) gradient calculation step.
        tower_outputs = self._multi_gpu_parallel_grad_calc([postprocessed_batch])

        all_grads, grad_info = tower_outputs[0]

        # Average the accumulated allreduce latency over the optimizers.
        grad_info["allreduce_latency"] /= len(self._optimizers)
        grad_info.update(self.stats_fn(postprocessed_batch))

        fetches = self.extra_compute_grad_fetches()

        return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})
845
+
846
+ @override(Policy)
847
+ def apply_gradients(self, gradients: ModelGradients) -> None:
848
+ if gradients == _directStepOptimizerSingleton:
849
+ for i, opt in enumerate(self._optimizers):
850
+ opt.step()
851
+ else:
852
+ # TODO(sven): Not supported for multiple optimizers yet.
853
+ assert len(self._optimizers) == 1
854
+ for g, p in zip(gradients, self.model.parameters()):
855
+ if g is not None:
856
+ if torch.is_tensor(g):
857
+ p.grad = g.to(self.device)
858
+ else:
859
+ p.grad = torch.from_numpy(g).to(self.device)
860
+
861
+ self._optimizers[0].step()
862
+
863
+ def get_tower_stats(self, stats_name: str) -> List[TensorStructType]:
864
+ """Returns list of per-tower stats, copied to this Policy's device.
865
+
866
+ Args:
867
+ stats_name: The name of the stats to average over (this str
868
+ must exist as a key inside each tower's `tower_stats` dict).
869
+
870
+ Returns:
871
+ The list of stats tensor (structs) of all towers, copied to this
872
+ Policy's device.
873
+
874
+ Raises:
875
+ AssertionError: If the `stats_name` cannot be found in any one
876
+ of the tower's `tower_stats` dicts.
877
+ """
878
+ data = []
879
+ for model in self.model_gpu_towers:
880
+ if self.tower_stats:
881
+ tower_stats = self.tower_stats[model]
882
+ else:
883
+ tower_stats = model.tower_stats
884
+
885
+ if stats_name in tower_stats:
886
+ data.append(
887
+ tree.map_structure(
888
+ lambda s: s.to(self.device), tower_stats[stats_name]
889
+ )
890
+ )
891
+
892
+ assert len(data) > 0, (
893
+ f"Stats `{stats_name}` not found in any of the towers (you have "
894
+ f"{len(self.model_gpu_towers)} towers in total)! Make "
895
+ "sure you call the loss function on at least one of the towers."
896
+ )
897
+ return data
898
+
899
+ @override(Policy)
900
+ def get_weights(self) -> ModelWeights:
901
+ return {k: v.cpu().detach().numpy() for k, v in self.model.state_dict().items()}
902
+
903
+ @override(Policy)
904
+ def set_weights(self, weights: ModelWeights) -> None:
905
+ weights = convert_to_torch_tensor(weights, device=self.device)
906
+ self.model.load_state_dict(weights)
907
+
908
+ @override(Policy)
909
+ def is_recurrent(self) -> bool:
910
+ return self._is_recurrent
911
+
912
+ @override(Policy)
913
+ def num_state_tensors(self) -> int:
914
+ return len(self.model.get_initial_state())
915
+
916
+ @override(Policy)
917
+ def get_initial_state(self) -> List[TensorType]:
918
+ return [s.detach().cpu().numpy() for s in self.model.get_initial_state()]
919
+
920
+ @override(Policy)
921
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
922
+ def get_state(self) -> PolicyState:
923
+ # Legacy Policy state (w/o torch.nn.Module and w/o PolicySpec).
924
+ state = super().get_state()
925
+
926
+ state["_optimizer_variables"] = []
927
+ for i, o in enumerate(self._optimizers):
928
+ optim_state_dict = convert_to_numpy(o.state_dict())
929
+ state["_optimizer_variables"].append(optim_state_dict)
930
+ # Add exploration state.
931
+ if self.exploration:
932
+ # This is not compatible with RLModules, which have a method
933
+ # `forward_exploration` to specify custom exploration behavior.
934
+ state["_exploration_state"] = self.exploration.get_state()
935
+ return state
936
+
937
    @override(Policy)
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def set_state(self, state: PolicyState) -> None:
        """Restores this Policy's state from a `get_state()` dict.

        Restore order matters: optimizer variables and exploration state are
        set first, then the global timestep, then the base-class restore
        (NN weights and connectors) via `super().set_state()`.

        Args:
            state: The state dict (as produced by `get_state()`).
        """
        # Set optimizer vars first.
        optimizer_vars = state.get("_optimizer_variables", None)
        if optimizer_vars:
            # One saved state dict per local optimizer is expected.
            assert len(optimizer_vars) == len(self._optimizers)
            for o, s in zip(self._optimizers, optimizer_vars):
                # Torch optimizer param_groups include things like beta, etc. These
                # parameters should be left as scalar and not converted to tensors.
                # otherwise, torch.optim.step() will start to complain.
                optim_state_dict = {"param_groups": s["param_groups"]}
                optim_state_dict["state"] = convert_to_torch_tensor(
                    s["state"], device=self.device
                )
                o.load_state_dict(optim_state_dict)
        # Set exploration's state.
        if hasattr(self, "exploration") and "_exploration_state" in state:
            self.exploration.set_state(state=state["_exploration_state"])

        # Restore global timestep.
        self.global_timestep = state["global_timestep"]

        # Then the Policy's (NN) weights and connectors.
        super().set_state(state)
962
+
963
    @override(Policy)
    def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None:
        """Exports the Policy's Model to local directory for serving.

        Creates a TorchScript model and saves it.

        Args:
            export_dir: Local writable directory or filename.
            onnx: If given, will export model in ONNX format. The
                value of this parameter set the ONNX OpSet version to use.
        """

        os.makedirs(export_dir, exist_ok=True)

        if onnx:
            # Build model inputs from the Policy's dummy batch.
            self._lazy_tensor_dict(self._dummy_batch)
            # Provide dummy state inputs if not an RNN (torch cannot jit with
            # returned empty internal states list).
            if "state_in_0" not in self._dummy_batch:
                self._dummy_batch["state_in_0"] = self._dummy_batch[
                    SampleBatch.SEQ_LENS
                ] = np.array([1.0])
            seq_lens = self._dummy_batch[SampleBatch.SEQ_LENS]

            # Collect all state_in_* columns in index order.
            state_ins = []
            i = 0
            while "state_in_{}".format(i) in self._dummy_batch:
                state_ins.append(self._dummy_batch["state_in_{}".format(i)])
                i += 1
            # All non-flag columns become named ONNX inputs.
            dummy_inputs = {
                k: self._dummy_batch[k]
                for k in self._dummy_batch.keys()
                if k != "is_training"
            }

            file_name = os.path.join(export_dir, "model.onnx")
            torch.onnx.export(
                self.model,
                (dummy_inputs, state_ins, seq_lens),
                file_name,
                export_params=True,
                opset_version=onnx,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys())
                + ["state_ins", SampleBatch.SEQ_LENS],
                output_names=["output", "state_outs"],
                # Allow variable batch sizes on every input.
                dynamic_axes={
                    k: {0: "batch_size"}
                    for k in list(dummy_inputs.keys())
                    + ["state_ins", SampleBatch.SEQ_LENS]
                },
            )
        # Save the torch.Model (architecture and weights, so it can be retrieved
        # w/o access to the original (custom) Model or Policy code).
        else:
            filename = os.path.join(export_dir, "model.pt")
            try:
                torch.save(self.model, f=filename)
            except Exception:
                # Clean up a partially written file, then warn (best-effort
                # export; do not crash the caller).
                if os.path.exists(filename):
                    os.remove(filename)
                logger.warning(ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL)
1025
+
1026
+ @override(Policy)
1027
+ def import_model_from_h5(self, import_file: str) -> None:
1028
+ """Imports weights into torch model."""
1029
+ return self.model.import_from_h5(import_file)
1030
+
1031
    @with_lock
    def _compute_action_helper(
        self, input_dict, state_batches, seq_lens, explore, timestep
    ):
        """Shared forward pass logic (w/ and w/o trajectory view API).

        Returns:
            A tuple consisting of a) actions, b) state_out, c) extra_fetches.
            The input_dict is modified in-place to include a numpy copy of the computed
            actions under `SampleBatch.ACTIONS`.
        """
        # Fall back to config/global defaults when not explicitly given.
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        # Switch to eval mode.
        if self.model:
            self.model.eval()

        extra_fetches = dist_inputs = logp = None

        # Custom sampler: produces actions (and optionally logp/dist inputs)
        # directly, bypassing the exploration component.
        if is_overridden(self.action_sampler_fn):
            action_dist = None
            actions, logp, dist_inputs, state_out = self.action_sampler_fn(
                self.model,
                obs_batch=input_dict,
                state_batches=state_batches,
                explore=explore,
                timestep=timestep,
            )
        else:
            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(explore=explore, timestep=timestep)
            if is_overridden(self.action_distribution_fn):
                dist_inputs, dist_class, state_out = self.action_distribution_fn(
                    self.model,
                    obs_batch=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=explore,
                    timestep=timestep,
                    is_training=False,
                )
            else:
                # Default: model forward pass + configured dist class.
                dist_class = self.dist_class
                dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)

            if not (
                isinstance(dist_class, functools.partial)
                or issubclass(dist_class, TorchDistributionWrapper)
            ):
                raise ValueError(
                    "`dist_class` ({}) not a TorchDistributionWrapper "
                    "subclass! Make sure your `action_distribution_fn` or "
                    "`make_model_and_action_dist` return a correct "
                    "distribution class.".format(dist_class.__name__)
                )
            action_dist = dist_class(dist_inputs, self.model)

            # Get the exploration action from the forward results.
            actions, logp = self.exploration.get_exploration_action(
                action_distribution=action_dist, timestep=timestep, explore=explore
            )

        # Add default and custom fetches.
        if extra_fetches is None:
            extra_fetches = self.extra_action_out(
                input_dict, state_batches, self.model, action_dist
            )

        # Action-dist inputs.
        if dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        # Action-logp and action-prob.
        if logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp.float())
            extra_fetches[SampleBatch.ACTION_LOGP] = logp

        # Update our global timestep by the batch size.
        self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])
        return convert_to_numpy((actions, state_out, extra_fetches))
1112
+
1113
+ def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None):
1114
+ if not isinstance(postprocessed_batch, SampleBatch):
1115
+ postprocessed_batch = SampleBatch(postprocessed_batch)
1116
+ postprocessed_batch.set_get_interceptor(
1117
+ functools.partial(convert_to_torch_tensor, device=device or self.device)
1118
+ )
1119
+ return postprocessed_batch
1120
+
1121
    def _multi_gpu_parallel_grad_calc(
        self, sample_batches: List[SampleBatch]
    ) -> List[Tuple[List[TensorType], GradInfoDict]]:
        """Performs a parallelized loss and gradient calculation over the batch.

        Splits up the given train batch into n shards (n=number of this
        Policy's devices) and passes each data shard (in parallel) through
        the loss function using the individual devices' models
        (self.model_gpu_towers). Then returns each tower's outputs.

        Args:
            sample_batches: A list of SampleBatch shards to
                calculate loss and gradients for.

        Returns:
            A list (one item per device) of 2-tuples, each with 1) gradient
            list and 2) grad info dict.
        """
        assert len(self.model_gpu_towers) == len(sample_batches)
        # Lock guards the shared `results` dict across worker threads.
        lock = threading.Lock()
        results = {}
        # Propagate the caller's grad-enabled state into worker threads
        # (torch's grad mode is thread-local).
        grad_enabled = torch.is_grad_enabled()

        def _worker(shard_idx, model, sample_batch, device):
            # Computes loss + grads for one tower; stores (grads, info) or
            # (exception, cause) into `results[shard_idx]`.
            torch.set_grad_enabled(grad_enabled)
            try:
                with NullContextManager() if device.type == "cpu" else torch.cuda.device(  # noqa: E501
                    device
                ):
                    loss_out = force_list(
                        self.loss(model, self.dist_class, sample_batch)
                    )

                    # Call Model's custom-loss with Policy loss outputs and
                    # train_batch.
                    if hasattr(model, "custom_loss"):
                        loss_out = model.custom_loss(loss_out, sample_batch)

                    # One loss term per optimizer is expected.
                    assert len(loss_out) == len(self._optimizers)

                    # Loop through all optimizers.
                    grad_info = {"allreduce_latency": 0.0}

                    parameters = list(model.parameters())
                    all_grads = [None for _ in range(len(parameters))]
                    for opt_idx, opt in enumerate(self._optimizers):
                        # Erase gradients in all vars of the tower that this
                        # optimizer would affect.
                        param_indices = self.multi_gpu_param_groups[opt_idx]
                        for param_idx, param in enumerate(parameters):
                            if param_idx in param_indices and param.grad is not None:
                                param.grad.data.zero_()
                        # Recompute gradients of loss over all variables.
                        loss_out[opt_idx].backward(retain_graph=True)
                        grad_info.update(
                            self.extra_grad_process(opt, loss_out[opt_idx])
                        )

                        grads = []
                        # Note that return values are just references;
                        # Calling zero_grad would modify the values.
                        for param_idx, param in enumerate(parameters):
                            if param_idx in param_indices:
                                if param.grad is not None:
                                    grads.append(param.grad)
                                all_grads[param_idx] = param.grad

                        # Optional distributed (multi-process) all-reduce.
                        if self.distributed_world_size:
                            start = time.time()
                            if torch.cuda.is_available():
                                # Sadly, allreduce_coalesced does not work with
                                # CUDA yet.
                                for g in grads:
                                    torch.distributed.all_reduce(
                                        g, op=torch.distributed.ReduceOp.SUM
                                    )
                            else:
                                torch.distributed.all_reduce_coalesced(
                                    grads, op=torch.distributed.ReduceOp.SUM
                                )

                            # SUM -> mean: divide by world size.
                            for param_group in opt.param_groups:
                                for p in param_group["params"]:
                                    if p.grad is not None:
                                        p.grad /= self.distributed_world_size

                            grad_info["allreduce_latency"] += time.time() - start

                with lock:
                    results[shard_idx] = (all_grads, grad_info)
            except Exception as e:
                import traceback

                # Store the error (with traceback and tower info) instead of
                # raising inside the thread; re-raised by the caller below.
                with lock:
                    results[shard_idx] = (
                        ValueError(
                            e.args[0]
                            + "\n traceback"
                            + traceback.format_exc()
                            + "\n"
                            + "In tower {} on device {}".format(shard_idx, device)
                        ),
                        e,
                    )

        # Single device (GPU) or fake-GPU case (serialize for better
        # debugging).
        if len(self.devices) == 1 or self.config["_fake_gpus"]:
            for shard_idx, (model, sample_batch, device) in enumerate(
                zip(self.model_gpu_towers, sample_batches, self.devices)
            ):
                _worker(shard_idx, model, sample_batch, device)
                # Raise errors right away for better debugging.
                last_result = results[len(results) - 1]
                if isinstance(last_result[0], ValueError):
                    raise last_result[0] from last_result[1]
        # Multi device (GPU) case: Parallelize via threads.
        else:
            threads = [
                threading.Thread(
                    target=_worker, args=(shard_idx, model, sample_batch, device)
                )
                for shard_idx, (model, sample_batch, device) in enumerate(
                    zip(self.model_gpu_towers, sample_batches, self.devices)
                )
            ]

            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

        # Gather all threads' outputs and return.
        outputs = []
        for shard_idx in range(len(sample_batches)):
            output = results[shard_idx]
            if isinstance(output[0], Exception):
                raise output[0] from output[1]
            outputs.append(results[shard_idx])
        return outputs
.venv/lib/python3.11/site-packages/ray/rllib/policy/view_requirement.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import gymnasium as gym
3
+ from typing import Dict, List, Optional, Union
4
+ import numpy as np
5
+
6
+ from ray.rllib.utils.annotations import OldAPIStack
7
+ from ray.rllib.utils.framework import try_import_torch
8
+ from ray.rllib.utils.serialization import (
9
+ gym_space_to_dict,
10
+ gym_space_from_dict,
11
+ )
12
+
13
+ torch, _ = try_import_torch()
14
+
15
+
16
@OldAPIStack
@dataclasses.dataclass
class ViewRequirement:
    """Single view requirement (for one column in an SampleBatch/input_dict).

    Policies and ModelV2s return a Dict[str, ViewRequirement] upon calling
    their `[train|inference]_view_requirements()` methods, where the str key
    represents the column name (C) under which the view is available in the
    input_dict/SampleBatch and ViewRequirement specifies the actual underlying
    column names (in the original data buffer), timestep shifts, and other
    options to build the view.

    .. testcode::
        :skipif: True

        from ray.rllib.models.modelv2 import ModelV2
        # The default ViewRequirement for a Model is:
        req = ModelV2(...).view_requirements
        print(req)

    .. testoutput::

        {"obs": ViewRequirement(shift=0)}

    Args:
        data_col: The data column name from the SampleBatch
            (str key). If None, use the dict key under which this
            ViewRequirement resides.
        space: The gym Space used in case we need to pad data
            in inaccessible areas of the trajectory (t<0 or t>H).
            Default: Simple box space, e.g. rewards.
        shift: Single shift value or
            list of relative positions to use (relative to the underlying
            `data_col`).
            Example: For a view column "prev_actions", you can set
            `data_col="actions"` and `shift=-1`.
            Example: For a view column "obs" in an Atari framestacking
            fashion, you can set `data_col="obs"` and
            `shift=[-3, -2, -1, 0]`.
            Example: For the obs input to an attention net, you can specify
            a range via a str: `shift="-100:0"`, which will pass in
            the past 100 observations plus the current one.
        index: An optional absolute position arg,
            used e.g. for the location of a requested inference dict within
            the trajectory. Negative values refer to counting from the end
            of a trajectory. (#TODO: Is this still used?)
        batch_repeat_value: determines how many time steps we should skip
            before we repeat the view indexing for the next timestep. For RNNs this
            number is usually the sequence length that we will rollout over.
            Example:
            view_col = "state_in_0", data_col = "state_out_0"
            batch_repeat_value = 5, shift = -1
            buffer["state_out_0"] = [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            output["state_in_0"] = [-1, 4, 9]
            Explanation: For t=0, we output buffer["state_out_0"][-1]. We then skip 5
            time steps and repeat the view. for t=5, we output buffer["state_out_0"][4]
            . Continuing on this pattern, for t=10, we output buffer["state_out_0"][9].
        used_for_compute_actions: Whether the data will be used for
            creating input_dicts for `Policy.compute_actions()` calls (or
            `Policy.compute_actions_from_input_dict()`).
        used_for_training: Whether the data will be used for
            training. If False, the column will not be copied into the
            final train batch.
    """

    data_col: Optional[str] = None
    space: gym.Space = None
    shift: Union[int, str, List[int]] = 0
    index: Optional[int] = None
    batch_repeat_value: int = 1
    used_for_compute_actions: bool = True
    used_for_training: bool = True
    # Computed in `__post_init__` (not an init arg).
    shift_arr: Optional[np.ndarray] = dataclasses.field(init=False)

    def __post_init__(self):
        """Initializes a ViewRequirement object.

        shift_arr is inferred from the shift value.

        For example:
        - if shift is -1, then shift_arr is np.array([-1]).
        - if shift is [-1, -2], then shift_arr is np.array([-2, -1]).
        - if shift is "-2:2", then shift_arr is np.array([-2, -1, 0, 1, 2]).

        Raises:
            ValueError: If `shift` is neither an int, a list of ints, nor a
                valid range string.
        """

        if self.space is None:
            self.space = gym.spaces.Box(float("-inf"), float("inf"), shape=())

        # TODO: ideally we won't need shift_from and shift_to, and shift_step.
        # all of them should be captured within shift_arr.
        # Special case: Providing a (probably larger) range of indices, e.g.
        # "-100:0" (past 100 timesteps plus current one).
        self.shift_from = self.shift_to = self.shift_step = None
        if isinstance(self.shift, str):
            split = self.shift.split(":")
            assert len(split) in [2, 3], f"Invalid shift str format: {self.shift}"
            if len(split) == 2:
                f, t = split
                self.shift_step = 1
            else:
                f, t, s = split
                self.shift_step = int(s)

            self.shift_from = int(f)
            self.shift_to = int(t)

        shift = self.shift
        # Bug fix: this used to assign to a misspelled attribute
        # (`self.shfit_arr`), leaving `shift_arr` unset on the error path.
        self.shift_arr = None
        # Bug fix: test `is not None` (not truthiness) so a range starting
        # at 0 (e.g. "0:5") is handled by the arange branch instead of
        # falling through to the int/list branches.
        if self.shift_from is not None:
            self.shift_arr = np.arange(
                self.shift_from, self.shift_to + 1, self.shift_step
            )
        else:
            if isinstance(shift, int):
                self.shift_arr = np.array([shift])
            elif isinstance(shift, list):
                self.shift_arr = np.array(shift)
            else:
                # Bug fix: the ValueError was constructed but never raised,
                # silently leaving `shift_arr` invalid.
                raise ValueError(f'unrecognized shift type: "{shift}"')

    def to_dict(self) -> Dict:
        """Return a dict for this ViewRequirement that can be JSON serialized."""
        return {
            "data_col": self.data_col,
            "space": gym_space_to_dict(self.space),
            "shift": self.shift,
            "index": self.index,
            "batch_repeat_value": self.batch_repeat_value,
            "used_for_training": self.used_for_training,
            "used_for_compute_actions": self.used_for_compute_actions,
        }

    @classmethod
    def from_dict(cls, d: Dict):
        """Construct a ViewRequirement instance from JSON deserialized dict."""
        # Shallow-copy so the caller's dict is not mutated.
        d = dict(d)
        d["space"] = gym_space_from_dict(d["space"])
        return cls(**d)
.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.utils.debug.deterministic import update_global_seed_if_necessary
2
+ from ray.rllib.utils.debug.memory import check_memory_leaks
3
+ from ray.rllib.utils.debug.summary import summarize
4
+
5
+
6
+ __all__ = [
7
+ "check_memory_leaks",
8
+ "summarize",
9
+ "update_global_seed_if_necessary",
10
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (514 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/deterministic.cpython-311.pyc ADDED
Binary file (2.47 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/memory.cpython-311.pyc ADDED
Binary file (8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/__pycache__/summary.cpython-311.pyc ADDED
Binary file (4.97 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/deterministic.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import random
4
+ from typing import Optional
5
+
6
+ from ray.rllib.utils.annotations import DeveloperAPI
7
+ from ray.rllib.utils.framework import try_import_tf, try_import_torch
8
+
9
+
10
@DeveloperAPI
def update_global_seed_if_necessary(
    framework: Optional[str] = None, seed: Optional[int] = None
) -> None:
    """Seed global modules such as random, numpy, torch, or tf.

    This is useful for debugging and testing.

    Args:
        framework: The framework specifier, e.g. "torch" or "tf2"
            (may be None, in which case only `random` and numpy are seeded).
        seed: An optional int seed. If None, this function is a no-op.
    """
    if seed is None:
        return

    # Python random module.
    random.seed(seed)
    # Numpy.
    np.random.seed(seed)

    # Torch.
    if framework == "torch":
        torch, _ = try_import_torch()
        torch.manual_seed(seed)
        # See https://github.com/pytorch/pytorch/issues/47672.
        cuda_version = torch.version.cuda
        if cuda_version is not None and float(cuda_version) >= 10.2:
            # Per the PyTorch reproducibility notes and cuBLAS docs, the
            # accepted values are ":16:8" and ":4096:8" (leading colon
            # required). The previous value "4096:8" was malformed.
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        else:
            from packaging.version import Version

            if Version(torch.__version__) >= Version("1.8.0"):
                # Not all Operations support this.
                torch.use_deterministic_algorithms(True)
            else:
                torch.set_deterministic(True)
        # This is only for Convolution no problem.
        torch.backends.cudnn.deterministic = True
    elif framework == "tf2":
        tf1, tf, tfv = try_import_tf()
        # Tf2.x.
        if tfv == 2:
            tf.random.set_seed(seed)
        # Tf1.x.
        else:
            tf1.set_random_seed(seed)
.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/memory.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import numpy as np
3
+ import tree # pip install dm_tree
4
+ from typing import DefaultDict, List, Optional, Set
5
+
6
+ from ray.rllib.utils.annotations import DeveloperAPI
7
+ from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
8
+ from ray.util.debug import _test_some_code_for_memory_leaks, Suspect
9
+
10
+
11
@DeveloperAPI
def check_memory_leaks(
    algorithm,
    to_check: Optional[Set[str]] = None,
    repeats: Optional[int] = None,
    max_num_trials: int = 3,
) -> DefaultDict[str, List[Suspect]]:
    """Diagnoses the given Algorithm for possible memory leaks.

    Isolates single components inside the Algorithm's local worker, e.g. the env,
    policy, etc.. and calls some of their methods repeatedly, while checking
    the memory footprints and keeping track of which lines in the code add
    un-GC'd items to memory.

    Args:
        algorithm: The Algorithm instance to test.
        to_check: Set of strings to identify components to test. Allowed strings
            are: "env", "policy", "model", "rollout_worker". By default, check all
            of these.
        repeats: Number of times the test code block should get executed (per trial).
            If a trial fails, a new trial may get started with a larger number of
            repeats: actual_repeats = `repeats` * (trial + 1) (1st trial == 0).
        max_num_trials: The maximum number of trials to run each check for.

    Returns:
        A defaultdict(list) with keys being the `to_check` strings and values being
        lists of Suspect instances that were found.
    """
    local_worker = algorithm.env_runner

    # Which components should we test?
    to_check = to_check or {"env", "model", "policy", "rollout_worker"}

    results_per_category = defaultdict(list)

    # Test a single sub-env (first in the VectorEnv)?
    if "env" in to_check:
        assert local_worker.async_env is not None, (
            "ERROR: Cannot test 'env' since given Algorithm does not have one "
            "in its local worker. Try setting `create_env_on_driver=True`."
        )

        # Isolate the first sub-env in the vectorized setup and test it.
        env = local_worker.async_env.get_sub_environments()[0]
        action_space = env.action_space
        # Always use same action to avoid numpy random caused memory leaks.
        action_sample = action_space.sample()

        # Run exactly one episode per call; a fixed action keeps the
        # workload deterministic across repeats.
        def code():
            ts = 0
            env.reset()
            while True:
                # If masking is used, try something like this:
                # np.random.choice(
                #     action_space.n, p=(obs["action_mask"] / sum(obs["action_mask"])))
                _, _, done, _, _ = env.step(action_sample)
                ts += 1
                if done:
                    break

        test = _test_some_code_for_memory_leaks(
            desc="Looking for leaks in env, running through episodes.",
            init=None,
            code=code,
            # How many times to repeat the function call?
            repeats=repeats or 200,
            max_num_trials=max_num_trials,
        )
        if test:
            results_per_category["env"].extend(test)

    # Test the policy (single-agent case only so far).
    if "policy" in to_check:
        policy = local_worker.policy_map[DEFAULT_POLICY_ID]

        # Get a fixed obs (B=10).
        obs = tree.map_structure(
            lambda s: np.stack([s] * 10, axis=0), policy.observation_space.sample()
        )

        print("Looking for leaks in Policy")

        def code():
            policy.compute_actions_from_input_dict(
                {
                    "obs": obs,
                }
            )

        # Call `compute_actions_from_input_dict()` n times.
        test = _test_some_code_for_memory_leaks(
            desc="Calling `compute_actions_from_input_dict()`.",
            init=None,
            code=code,
            # How many times to repeat the function call?
            repeats=repeats or 400,
            # How many times to re-try if we find a suspicious memory
            # allocation?
            max_num_trials=max_num_trials,
        )
        if test:
            results_per_category["policy"].extend(test)

        # Testing this only makes sense if the learner API is disabled.
        if not policy.config.get("enable_rl_module_and_learner", False):
            # Call `learn_on_batch()` n times.
            dummy_batch = policy._get_dummy_batch_from_view_requirements(batch_size=16)

            test = _test_some_code_for_memory_leaks(
                desc="Calling `learn_on_batch()`.",
                init=None,
                code=lambda: policy.learn_on_batch(dummy_batch),
                # How many times to repeat the function call?
                repeats=repeats or 100,
                max_num_trials=max_num_trials,
            )
            if test:
                results_per_category["policy"].extend(test)

    # Test only the model.
    if "model" in to_check:
        policy = local_worker.policy_map[DEFAULT_POLICY_ID]

        # Get a fixed obs.
        obs = tree.map_structure(lambda s: s[None], policy.observation_space.sample())

        print("Looking for leaks in Model")

        # Call the model (forward pass) n times.
        test = _test_some_code_for_memory_leaks(
            desc="Calling `[model]()`.",
            init=None,
            code=lambda: policy.model({SampleBatch.OBS: obs}),
            # How many times to repeat the function call?
            repeats=repeats or 400,
            # How many times to re-try if we find a suspicious memory
            # allocation?
            max_num_trials=max_num_trials,
        )
        if test:
            results_per_category["model"].extend(test)

    # Test the RolloutWorker.
    if "rollout_worker" in to_check:
        print("Looking for leaks in local RolloutWorker")

        def code():
            local_worker.sample()
            local_worker.get_metrics()

        # Call `sample()` and `get_metrics()` n times.
        test = _test_some_code_for_memory_leaks(
            desc="Calling `sample()` and `get_metrics()`.",
            init=None,
            code=code,
            # How many times to repeat the function call?
            repeats=repeats or 50,
            # How many times to re-try if we find a suspicious memory
            # allocation?
            max_num_trials=max_num_trials,
        )
        if test:
            results_per_category["rollout_worker"].extend(test)

    # Test the Learner (only applicable on the new RLModule/Learner stack).
    if "learner" in to_check and algorithm.config.get(
        "enable_rl_module_and_learner", False
    ):
        learner_group = algorithm.learner_group
        assert learner_group._is_local, (
            "This test will miss leaks hidden in remote "
            "workers. Please make sure that there is a "
            "local learner inside the learner group for "
            "this test."
        )

        dummy_batch = (
            algorithm.get_policy()
            ._get_dummy_batch_from_view_requirements(batch_size=16)
            .as_multi_agent()
        )

        print("Looking for leaks in Learner")

        def code():
            learner_group.update(dummy_batch)

        # Call `LearnerGroup.update()` n times.
        test = _test_some_code_for_memory_leaks(
            desc="Calling `LearnerGroup.update()`.",
            init=None,
            code=code,
            # How many times to repeat the function call?
            repeats=repeats or 400,
            # How many times to re-try if we find a suspicious memory
            # allocation?
            max_num_trials=max_num_trials,
        )
        if test:
            results_per_category["learner"].extend(test)

    return results_per_category
.venv/lib/python3.11/site-packages/ray/rllib/utils/debug/summary.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pprint
3
+ from typing import Any
4
+
5
+ from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
6
+ from ray.rllib.utils.annotations import DeveloperAPI
7
+
8
+ _printer = pprint.PrettyPrinter(indent=2, width=60)
9
+
10
+
11
@DeveloperAPI
def summarize(obj: Any) -> Any:
    """Return a pretty-formatted string for an object.

    This has special handling for pretty-formatting of commonly used data types
    in RLlib, such as SampleBatch, numpy arrays, etc.

    Args:
        obj: The object to format.

    Returns:
        The summarized object, rendered as a string by the module-level
        pretty-printer.
    """

    # `_summarize` compacts large leaves (e.g. numpy arrays) into short
    # descriptive strings before pretty-printing.
    return _printer.pformat(_summarize(obj))
26
+
27
+
28
def _summarize(obj):
    """Recursively replace large/opaque leaves of `obj` with short summaries.

    Containers (dict/list/tuple/namedtuple) are traversed; numpy arrays and
    RLlib batch types are collapsed into `_StringValue` descriptions; anything
    else is returned unchanged.

    NOTE: The order of the isinstance checks matters — do not reorder the
    branches without verifying behavior for types that match more than one.
    """
    if isinstance(obj, dict):
        return {k: _summarize(v) for k, v in obj.items()}
    # Namedtuples expose `_asdict`; summarize them as type + field dict.
    elif hasattr(obj, "_asdict"):
        return {
            "type": obj.__class__.__name__,
            "data": _summarize(obj._asdict()),
        }
    elif isinstance(obj, list):
        return [_summarize(x) for x in obj]
    elif isinstance(obj, tuple):
        return tuple(_summarize(x) for x in obj)
    elif isinstance(obj, np.ndarray):
        # Empty arrays: just shape/dtype (no stats possible).
        if obj.size == 0:
            return _StringValue("np.ndarray({}, dtype={})".format(obj.shape, obj.dtype))
        # Object/string arrays: show the first element instead of numeric stats.
        elif obj.dtype == object or obj.dtype.type is np.str_:
            return _StringValue(
                "np.ndarray({}, dtype={}, head={})".format(
                    obj.shape, obj.dtype, _summarize(obj[0])
                )
            )
        # Numeric arrays: show shape/dtype plus min/max/mean rounded to 3 places.
        else:
            return _StringValue(
                "np.ndarray({}, dtype={}, min={}, max={}, mean={})".format(
                    obj.shape,
                    obj.dtype,
                    round(float(np.min(obj)), 3),
                    round(float(np.max(obj)), 3),
                    round(float(np.mean(obj)), 3),
                )
            )
    elif isinstance(obj, MultiAgentBatch):
        return {
            "type": "MultiAgentBatch",
            "policy_batches": _summarize(obj.policy_batches),
            "count": obj.count,
        }
    elif isinstance(obj, SampleBatch):
        return {
            "type": "SampleBatch",
            "data": {k: _summarize(v) for k, v in obj.items()},
        }
    else:
        return obj


class _StringValue:
    """Wrapper whose repr() is the raw string (so pprint prints it unquoted)."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return self.value
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.utils.exploration.curiosity import Curiosity
2
+ from ray.rllib.utils.exploration.exploration import Exploration
3
+ from ray.rllib.utils.exploration.epsilon_greedy import EpsilonGreedy
4
+ from ray.rllib.utils.exploration.gaussian_noise import GaussianNoise
5
+ from ray.rllib.utils.exploration.ornstein_uhlenbeck_noise import OrnsteinUhlenbeckNoise
6
+ from ray.rllib.utils.exploration.parameter_noise import ParameterNoise
7
+ from ray.rllib.utils.exploration.per_worker_epsilon_greedy import PerWorkerEpsilonGreedy
8
+ from ray.rllib.utils.exploration.per_worker_gaussian_noise import PerWorkerGaussianNoise
9
+ from ray.rllib.utils.exploration.per_worker_ornstein_uhlenbeck_noise import (
10
+ PerWorkerOrnsteinUhlenbeckNoise,
11
+ )
12
+ from ray.rllib.utils.exploration.random import Random
13
+ from ray.rllib.utils.exploration.random_encoder import RE3
14
+ from ray.rllib.utils.exploration.slate_epsilon_greedy import SlateEpsilonGreedy
15
+ from ray.rllib.utils.exploration.slate_soft_q import SlateSoftQ
16
+ from ray.rllib.utils.exploration.soft_q import SoftQ
17
+ from ray.rllib.utils.exploration.stochastic_sampling import StochasticSampling
18
+ from ray.rllib.utils.exploration.thompson_sampling import ThompsonSampling
19
+ from ray.rllib.utils.exploration.upper_confidence_bound import UpperConfidenceBound
20
+
21
+ __all__ = [
22
+ "Curiosity",
23
+ "Exploration",
24
+ "EpsilonGreedy",
25
+ "GaussianNoise",
26
+ "OrnsteinUhlenbeckNoise",
27
+ "ParameterNoise",
28
+ "PerWorkerEpsilonGreedy",
29
+ "PerWorkerGaussianNoise",
30
+ "PerWorkerOrnsteinUhlenbeckNoise",
31
+ "Random",
32
+ "RE3",
33
+ "SlateEpsilonGreedy",
34
+ "SlateSoftQ",
35
+ "SoftQ",
36
+ "StochasticSampling",
37
+ "ThompsonSampling",
38
+ "UpperConfidenceBound",
39
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.05 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/curiosity.cpython-311.pyc ADDED
Binary file (21 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/epsilon_greedy.cpython-311.pyc ADDED
Binary file (12.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/exploration.cpython-311.pyc ADDED
Binary file (8.84 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/gaussian_noise.cpython-311.pyc ADDED
Binary file (12 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/ornstein_uhlenbeck_noise.cpython-311.pyc ADDED
Binary file (13.8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/parameter_noise.cpython-311.pyc ADDED
Binary file (20 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_epsilon_greedy.cpython-311.pyc ADDED
Binary file (2.58 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_gaussian_noise.cpython-311.pyc ADDED
Binary file (2.39 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/per_worker_ornstein_uhlenbeck_noise.cpython-311.pyc ADDED
Binary file (2.48 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/random.cpython-311.pyc ADDED
Binary file (9.35 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/random_encoder.cpython-311.pyc ADDED
Binary file (13.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/slate_epsilon_greedy.cpython-311.pyc ADDED
Binary file (5.8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/slate_soft_q.cpython-311.pyc ADDED
Binary file (2.31 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/soft_q.cpython-311.pyc ADDED
Binary file (3.27 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/stochastic_sampling.cpython-311.pyc ADDED
Binary file (8 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/thompson_sampling.cpython-311.pyc ADDED
Binary file (3.16 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/__pycache__/upper_confidence_bound.cpython-311.pyc ADDED
Binary file (3.02 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/curiosity.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gymnasium.spaces import Discrete, MultiDiscrete, Space
2
+ import numpy as np
3
+ from typing import Optional, Tuple, Union
4
+
5
+ from ray.rllib.models.action_dist import ActionDistribution
6
+ from ray.rllib.models.catalog import ModelCatalog
7
+ from ray.rllib.models.modelv2 import ModelV2
8
+ from ray.rllib.models.tf.tf_action_dist import Categorical, MultiCategorical
9
+ from ray.rllib.models.torch.misc import SlimFC
10
+ from ray.rllib.models.torch.torch_action_dist import (
11
+ TorchCategorical,
12
+ TorchMultiCategorical,
13
+ )
14
+ from ray.rllib.models.utils import get_activation_fn
15
+ from ray.rllib.policy.sample_batch import SampleBatch
16
+ from ray.rllib.utils import NullContextManager
17
+ from ray.rllib.utils.annotations import OldAPIStack, override
18
+ from ray.rllib.utils.exploration.exploration import Exploration
19
+ from ray.rllib.utils.framework import try_import_tf, try_import_torch
20
+ from ray.rllib.utils.from_config import from_config
21
+ from ray.rllib.utils.tf_utils import get_placeholder, one_hot as tf_one_hot
22
+ from ray.rllib.utils.torch_utils import one_hot
23
+ from ray.rllib.utils.typing import FromConfigSpec, ModelConfigDict, TensorType
24
+
25
+ tf1, tf, tfv = try_import_tf()
26
+ torch, nn = try_import_torch()
27
+ F = None
28
+ if nn is not None:
29
+ F = nn.functional
30
+
31
+
32
+ @OldAPIStack
33
+ class Curiosity(Exploration):
34
+ """Implementation of:
35
+ [1] Curiosity-driven Exploration by Self-supervised Prediction
36
+ Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017.
37
+ https://arxiv.org/pdf/1705.05363.pdf
38
+
39
+ Learns a simplified model of the environment based on three networks:
40
+ 1) Embedding observations into latent space ("feature" network).
41
+ 2) Predicting the action, given two consecutive embedded observations
42
+ ("inverse" network).
43
+ 3) Predicting the next embedded obs, given an obs and action
44
+ ("forward" network).
45
+
46
+ The less the agent is able to predict the actually observed next feature
47
+ vector, given obs and action (through the forwards network), the larger the
48
+ "intrinsic reward", which will be added to the extrinsic reward.
49
+ Therefore, if a state transition was unexpected, the agent becomes
50
+ "curious" and will further explore this transition leading to better
51
+ exploration in sparse rewards environments.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ action_space: Space,
57
+ *,
58
+ framework: str,
59
+ model: ModelV2,
60
+ feature_dim: int = 288,
61
+ feature_net_config: Optional[ModelConfigDict] = None,
62
+ inverse_net_hiddens: Tuple[int] = (256,),
63
+ inverse_net_activation: str = "relu",
64
+ forward_net_hiddens: Tuple[int] = (256,),
65
+ forward_net_activation: str = "relu",
66
+ beta: float = 0.2,
67
+ eta: float = 1.0,
68
+ lr: float = 1e-3,
69
+ sub_exploration: Optional[FromConfigSpec] = None,
70
+ **kwargs
71
+ ):
72
+ """Initializes a Curiosity object.
73
+
74
+ Uses as defaults the hyperparameters described in [1].
75
+
76
+ Args:
77
+ feature_dim: The dimensionality of the feature (phi)
78
+ vectors.
79
+ feature_net_config: Optional model
80
+ configuration for the feature network, producing feature
81
+ vectors (phi) from observations. This can be used to configure
82
+ fcnet- or conv_net setups to properly process any observation
83
+ space.
84
+ inverse_net_hiddens: Tuple of the layer sizes of the
85
+ inverse (action predicting) NN head (on top of the feature
86
+ outputs for phi and phi').
87
+ inverse_net_activation: Activation specifier for the inverse
88
+ net.
89
+ forward_net_hiddens: Tuple of the layer sizes of the
90
+ forward (phi' predicting) NN head.
91
+ forward_net_activation: Activation specifier for the forward
92
+ net.
93
+ beta: Weight for the forward loss (over the inverse loss,
94
+ which gets weight=1.0-beta) in the common loss term.
95
+ eta: Weight for intrinsic rewards before being added to
96
+ extrinsic ones.
97
+ lr: The learning rate for the curiosity-specific
98
+ optimizer, optimizing feature-, inverse-, and forward nets.
99
+ sub_exploration: The config dict for
100
+ the underlying Exploration to use (e.g. epsilon-greedy for
101
+ DQN). If None, uses the FromSpecDict provided in the Policy's
102
+ default config.
103
+ """
104
+ if not isinstance(action_space, (Discrete, MultiDiscrete)):
105
+ raise ValueError(
106
+ "Only (Multi)Discrete action spaces supported for Curiosity so far!"
107
+ )
108
+
109
+ super().__init__(action_space, model=model, framework=framework, **kwargs)
110
+
111
+ if self.policy_config["num_env_runners"] != 0:
112
+ raise ValueError(
113
+ "Curiosity exploration currently does not support parallelism."
114
+ " `num_workers` must be 0!"
115
+ )
116
+
117
+ self.feature_dim = feature_dim
118
+ if feature_net_config is None:
119
+ feature_net_config = self.policy_config["model"].copy()
120
+ self.feature_net_config = feature_net_config
121
+ self.inverse_net_hiddens = inverse_net_hiddens
122
+ self.inverse_net_activation = inverse_net_activation
123
+ self.forward_net_hiddens = forward_net_hiddens
124
+ self.forward_net_activation = forward_net_activation
125
+
126
+ self.action_dim = (
127
+ self.action_space.n
128
+ if isinstance(self.action_space, Discrete)
129
+ else np.sum(self.action_space.nvec)
130
+ )
131
+
132
+ self.beta = beta
133
+ self.eta = eta
134
+ self.lr = lr
135
+ # TODO: (sven) if sub_exploration is None, use Algorithm's default
136
+ # Exploration config.
137
+ if sub_exploration is None:
138
+ raise NotImplementedError
139
+ self.sub_exploration = sub_exploration
140
+
141
+ # Creates modules/layers inside the actual ModelV2.
142
+ self._curiosity_feature_net = ModelCatalog.get_model_v2(
143
+ self.model.obs_space,
144
+ self.action_space,
145
+ self.feature_dim,
146
+ model_config=self.feature_net_config,
147
+ framework=self.framework,
148
+ name="feature_net",
149
+ )
150
+
151
+ self._curiosity_inverse_fcnet = self._create_fc_net(
152
+ [2 * self.feature_dim] + list(self.inverse_net_hiddens) + [self.action_dim],
153
+ self.inverse_net_activation,
154
+ name="inverse_net",
155
+ )
156
+
157
+ self._curiosity_forward_fcnet = self._create_fc_net(
158
+ [self.feature_dim + self.action_dim]
159
+ + list(self.forward_net_hiddens)
160
+ + [self.feature_dim],
161
+ self.forward_net_activation,
162
+ name="forward_net",
163
+ )
164
+
165
+ # This is only used to select the correct action
166
+ self.exploration_submodule = from_config(
167
+ cls=Exploration,
168
+ config=self.sub_exploration,
169
+ action_space=self.action_space,
170
+ framework=self.framework,
171
+ policy_config=self.policy_config,
172
+ model=self.model,
173
+ num_workers=self.num_workers,
174
+ worker_index=self.worker_index,
175
+ )
176
+
177
+ @override(Exploration)
178
+ def get_exploration_action(
179
+ self,
180
+ *,
181
+ action_distribution: ActionDistribution,
182
+ timestep: Union[int, TensorType],
183
+ explore: bool = True
184
+ ):
185
+ # Simply delegate to sub-Exploration module.
186
+ return self.exploration_submodule.get_exploration_action(
187
+ action_distribution=action_distribution, timestep=timestep, explore=explore
188
+ )
189
+
190
+ @override(Exploration)
191
+ def get_exploration_optimizer(self, optimizers):
192
+ # Create, but don't add Adam for curiosity NN updating to the policy.
193
+ # If we added and returned it here, it would be used in the policy's
194
+ # update loop, which we don't want (curiosity updating happens inside
195
+ # `postprocess_trajectory`).
196
+ if self.framework == "torch":
197
+ feature_params = list(self._curiosity_feature_net.parameters())
198
+ inverse_params = list(self._curiosity_inverse_fcnet.parameters())
199
+ forward_params = list(self._curiosity_forward_fcnet.parameters())
200
+
201
+ # Now that the Policy's own optimizer(s) have been created (from
202
+ # the Model parameters (IMPORTANT: w/o(!) the curiosity params),
203
+ # we can add our curiosity sub-modules to the Policy's Model.
204
+ self.model._curiosity_feature_net = self._curiosity_feature_net.to(
205
+ self.device
206
+ )
207
+ self.model._curiosity_inverse_fcnet = self._curiosity_inverse_fcnet.to(
208
+ self.device
209
+ )
210
+ self.model._curiosity_forward_fcnet = self._curiosity_forward_fcnet.to(
211
+ self.device
212
+ )
213
+ self._optimizer = torch.optim.Adam(
214
+ forward_params + inverse_params + feature_params, lr=self.lr
215
+ )
216
+ else:
217
+ self.model._curiosity_feature_net = self._curiosity_feature_net
218
+ self.model._curiosity_inverse_fcnet = self._curiosity_inverse_fcnet
219
+ self.model._curiosity_forward_fcnet = self._curiosity_forward_fcnet
220
+ # Feature net is a RLlib ModelV2, the other 2 are keras Models.
221
+ self._optimizer_var_list = (
222
+ self._curiosity_feature_net.base_model.variables
223
+ + self._curiosity_inverse_fcnet.variables
224
+ + self._curiosity_forward_fcnet.variables
225
+ )
226
+ self._optimizer = tf1.train.AdamOptimizer(learning_rate=self.lr)
227
+ # Create placeholders and initialize the loss.
228
+ if self.framework == "tf":
229
+ self._obs_ph = get_placeholder(
230
+ space=self.model.obs_space, name="_curiosity_obs"
231
+ )
232
+ self._next_obs_ph = get_placeholder(
233
+ space=self.model.obs_space, name="_curiosity_next_obs"
234
+ )
235
+ self._action_ph = get_placeholder(
236
+ space=self.model.action_space, name="_curiosity_action"
237
+ )
238
+ (
239
+ self._forward_l2_norm_sqared,
240
+ self._update_op,
241
+ ) = self._postprocess_helper_tf(
242
+ self._obs_ph, self._next_obs_ph, self._action_ph
243
+ )
244
+
245
+ return optimizers
246
+
247
+ @override(Exploration)
248
+ def postprocess_trajectory(self, policy, sample_batch, tf_sess=None):
249
+ """Calculates phi values (obs, obs', and predicted obs') and ri.
250
+
251
+ Also calculates forward and inverse losses and updates the curiosity
252
+ module on the provided batch using our optimizer.
253
+ """
254
+ if self.framework != "torch":
255
+ self._postprocess_tf(policy, sample_batch, tf_sess)
256
+ else:
257
+ self._postprocess_torch(policy, sample_batch)
258
+
259
+ def _postprocess_tf(self, policy, sample_batch, tf_sess):
260
+ # tf1 static-graph: Perform session call on our loss and update ops.
261
+ if self.framework == "tf":
262
+ forward_l2_norm_sqared, _ = tf_sess.run(
263
+ [self._forward_l2_norm_sqared, self._update_op],
264
+ feed_dict={
265
+ self._obs_ph: sample_batch[SampleBatch.OBS],
266
+ self._next_obs_ph: sample_batch[SampleBatch.NEXT_OBS],
267
+ self._action_ph: sample_batch[SampleBatch.ACTIONS],
268
+ },
269
+ )
270
+ # tf-eager: Perform model calls, loss calculations, and optimizer
271
+ # stepping on the fly.
272
+ else:
273
+ forward_l2_norm_sqared, _ = self._postprocess_helper_tf(
274
+ sample_batch[SampleBatch.OBS],
275
+ sample_batch[SampleBatch.NEXT_OBS],
276
+ sample_batch[SampleBatch.ACTIONS],
277
+ )
278
+ # Scale intrinsic reward by eta hyper-parameter.
279
+ sample_batch[SampleBatch.REWARDS] = (
280
+ sample_batch[SampleBatch.REWARDS] + self.eta * forward_l2_norm_sqared
281
+ )
282
+
283
+ return sample_batch
284
+
285
+ def _postprocess_helper_tf(self, obs, next_obs, actions):
286
+ with (
287
+ tf.GradientTape() if self.framework != "tf" else NullContextManager()
288
+ ) as tape:
289
+ # Push both observations through feature net to get both phis.
290
+ phis, _ = self.model._curiosity_feature_net(
291
+ {SampleBatch.OBS: tf.concat([obs, next_obs], axis=0)}
292
+ )
293
+ phi, next_phi = tf.split(phis, 2)
294
+
295
+ # Predict next phi with forward model.
296
+ predicted_next_phi = self.model._curiosity_forward_fcnet(
297
+ tf.concat([phi, tf_one_hot(actions, self.action_space)], axis=-1)
298
+ )
299
+
300
+ # Forward loss term (predicted phi', given phi and action vs
301
+ # actually observed phi').
302
+ forward_l2_norm_sqared = 0.5 * tf.reduce_sum(
303
+ tf.square(predicted_next_phi - next_phi), axis=-1
304
+ )
305
+ forward_loss = tf.reduce_mean(forward_l2_norm_sqared)
306
+
307
+ # Inverse loss term (prediced action that led from phi to phi' vs
308
+ # actual action taken).
309
+ phi_cat_next_phi = tf.concat([phi, next_phi], axis=-1)
310
+ dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi)
311
+ action_dist = (
312
+ Categorical(dist_inputs, self.model)
313
+ if isinstance(self.action_space, Discrete)
314
+ else MultiCategorical(dist_inputs, self.model, self.action_space.nvec)
315
+ )
316
+ # Neg log(p); p=probability of observed action given the inverse-NN
317
+ # predicted action distribution.
318
+ inverse_loss = -action_dist.logp(tf.convert_to_tensor(actions))
319
+ inverse_loss = tf.reduce_mean(inverse_loss)
320
+
321
+ # Calculate the ICM loss.
322
+ loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss
323
+
324
+ # Step the optimizer.
325
+ if self.framework != "tf":
326
+ grads = tape.gradient(loss, self._optimizer_var_list)
327
+ grads_and_vars = [
328
+ (g, v) for g, v in zip(grads, self._optimizer_var_list) if g is not None
329
+ ]
330
+ update_op = self._optimizer.apply_gradients(grads_and_vars)
331
+ else:
332
+ update_op = self._optimizer.minimize(
333
+ loss, var_list=self._optimizer_var_list
334
+ )
335
+
336
+ # Return the squared l2 norm and the optimizer update op.
337
+ return forward_l2_norm_sqared, update_op
338
+
339
+ def _postprocess_torch(self, policy, sample_batch):
340
+ # Push both observations through feature net to get both phis.
341
+ phis, _ = self.model._curiosity_feature_net(
342
+ {
343
+ SampleBatch.OBS: torch.cat(
344
+ [
345
+ torch.from_numpy(sample_batch[SampleBatch.OBS]).to(
346
+ policy.device
347
+ ),
348
+ torch.from_numpy(sample_batch[SampleBatch.NEXT_OBS]).to(
349
+ policy.device
350
+ ),
351
+ ]
352
+ )
353
+ }
354
+ )
355
+ phi, next_phi = torch.chunk(phis, 2)
356
+ actions_tensor = (
357
+ torch.from_numpy(sample_batch[SampleBatch.ACTIONS]).long().to(policy.device)
358
+ )
359
+
360
+ # Predict next phi with forward model.
361
+ predicted_next_phi = self.model._curiosity_forward_fcnet(
362
+ torch.cat([phi, one_hot(actions_tensor, self.action_space).float()], dim=-1)
363
+ )
364
+
365
+ # Forward loss term (predicted phi', given phi and action vs actually
366
+ # observed phi').
367
+ forward_l2_norm_sqared = 0.5 * torch.sum(
368
+ torch.pow(predicted_next_phi - next_phi, 2.0), dim=-1
369
+ )
370
+ forward_loss = torch.mean(forward_l2_norm_sqared)
371
+
372
+ # Scale intrinsic reward by eta hyper-parameter.
373
+ sample_batch[SampleBatch.REWARDS] = (
374
+ sample_batch[SampleBatch.REWARDS]
375
+ + self.eta * forward_l2_norm_sqared.detach().cpu().numpy()
376
+ )
377
+
378
+ # Inverse loss term (prediced action that led from phi to phi' vs
379
+ # actual action taken).
380
+ phi_cat_next_phi = torch.cat([phi, next_phi], dim=-1)
381
+ dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi)
382
+ action_dist = (
383
+ TorchCategorical(dist_inputs, self.model)
384
+ if isinstance(self.action_space, Discrete)
385
+ else TorchMultiCategorical(dist_inputs, self.model, self.action_space.nvec)
386
+ )
387
+ # Neg log(p); p=probability of observed action given the inverse-NN
388
+ # predicted action distribution.
389
+ inverse_loss = -action_dist.logp(actions_tensor)
390
+ inverse_loss = torch.mean(inverse_loss)
391
+
392
+ # Calculate the ICM loss.
393
+ loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss
394
+ # Perform an optimizer step.
395
+ self._optimizer.zero_grad()
396
+ loss.backward()
397
+ self._optimizer.step()
398
+
399
+ # Return the postprocessed sample batch (with the corrected rewards).
400
+ return sample_batch
401
+
402
+ def _create_fc_net(self, layer_dims, activation, name=None):
403
+ """Given a list of layer dimensions (incl. input-dim), creates FC-net.
404
+
405
+ Args:
406
+ layer_dims (Tuple[int]): Tuple of layer dims, including the input
407
+ dimension.
408
+ activation: An activation specifier string (e.g. "relu").
409
+
410
+ Examples:
411
+ If layer_dims is [4,8,6] we'll have a two layer net: 4->8 (8 nodes)
412
+ and 8->6 (6 nodes), where the second layer (6 nodes) does not have
413
+ an activation anymore. 4 is the input dimension.
414
+ """
415
+ layers = (
416
+ [tf.keras.layers.Input(shape=(layer_dims[0],), name="{}_in".format(name))]
417
+ if self.framework != "torch"
418
+ else []
419
+ )
420
+
421
+ for i in range(len(layer_dims) - 1):
422
+ act = activation if i < len(layer_dims) - 2 else None
423
+ if self.framework == "torch":
424
+ layers.append(
425
+ SlimFC(
426
+ in_size=layer_dims[i],
427
+ out_size=layer_dims[i + 1],
428
+ initializer=torch.nn.init.xavier_uniform_,
429
+ activation_fn=act,
430
+ )
431
+ )
432
+ else:
433
+ layers.append(
434
+ tf.keras.layers.Dense(
435
+ units=layer_dims[i + 1],
436
+ activation=get_activation_fn(act),
437
+ name="{}_{}".format(name, i),
438
+ )
439
+ )
440
+
441
+ if self.framework == "torch":
442
+ return nn.Sequential(*layers)
443
+ else:
444
+ return tf.keras.Sequential(layers)
.venv/lib/python3.11/site-packages/ray/rllib/utils/exploration/epsilon_greedy.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ import numpy as np
3
+ import tree # pip install dm_tree
4
+ import random
5
+ from typing import Union, Optional
6
+
7
+ from ray.rllib.models.torch.torch_action_dist import TorchMultiActionDistribution
8
+ from ray.rllib.models.action_dist import ActionDistribution
9
+ from ray.rllib.utils.annotations import override, OldAPIStack
10
+ from ray.rllib.utils.exploration.exploration import Exploration, TensorType
11
+ from ray.rllib.utils.framework import try_import_tf, try_import_torch, get_variable
12
+ from ray.rllib.utils.from_config import from_config
13
+ from ray.rllib.utils.numpy import convert_to_numpy
14
+ from ray.rllib.utils.schedules import Schedule, PiecewiseSchedule
15
+ from ray.rllib.utils.torch_utils import FLOAT_MIN
16
+
17
+ tf1, tf, tfv = try_import_tf()
18
+ torch, _ = try_import_torch()
19
+
20
+
21
@OldAPIStack
class EpsilonGreedy(Exploration):
    """Epsilon-greedy Exploration class that produces exploration actions.

    When given a Model's output and a current epsilon value (based on some
    Schedule), it produces a random action (if rand(1) < eps) or
    uses the model-computed one (if rand(1) >= eps).
    """

    def __init__(
        self,
        action_space: gym.spaces.Space,
        *,
        framework: str,
        initial_epsilon: float = 1.0,
        final_epsilon: float = 0.05,
        warmup_timesteps: int = 0,
        epsilon_timesteps: int = int(1e5),
        epsilon_schedule: Optional[Schedule] = None,
        **kwargs,
    ):
        """Create an EpsilonGreedy exploration class.

        Args:
            action_space: The action space the exploration should occur in.
            framework: The framework specifier.
            initial_epsilon: The initial epsilon value to use.
            final_epsilon: The final epsilon value to use.
            warmup_timesteps: The timesteps over which to not change epsilon in the
                beginning.
            epsilon_timesteps: The timesteps (additional to `warmup_timesteps`)
                after which epsilon should always be `final_epsilon`.
                E.g.: warmup_timesteps=20k epsilon_timesteps=50k -> After 70k timesteps,
                epsilon will reach its final value.
            epsilon_schedule: An optional Schedule object
                to use (instead of constructing one from the given parameters).
        """
        assert framework is not None
        super().__init__(action_space=action_space, framework=framework, **kwargs)

        # Prefer a user-supplied schedule; otherwise build a piecewise-linear
        # one: flat at `initial_epsilon` during warmup, then annealed to
        # `final_epsilon`, where it stays (`outside_value`) forever after.
        self.epsilon_schedule = from_config(
            Schedule, epsilon_schedule, framework=framework
        ) or PiecewiseSchedule(
            endpoints=[
                (0, initial_epsilon),
                (warmup_timesteps, initial_epsilon),
                (warmup_timesteps + epsilon_timesteps, final_epsilon),
            ],
            outside_value=final_epsilon,
            framework=self.framework,
        )

        # The current timestep value (tf-var or python int).
        self.last_timestep = get_variable(
            np.array(0, np.int64),
            framework=framework,
            tf_name="timestep",
            dtype=np.int64,
        )

        # Build the tf-info-op.
        # NOTE(review): only built for static-graph "tf" so `get_state(sess)`
        # can `sess.run` it later.
        if self.framework == "tf":
            self._tf_state_op = self.get_state()

    @override(Exploration)
    def get_exploration_action(
        self,
        *,
        action_distribution: ActionDistribution,
        timestep: Union[int, TensorType],
        explore: Optional[Union[bool, TensorType]] = True,
    ):
        # Dispatch to the framework-specific implementation.
        if self.framework in ["tf2", "tf"]:
            return self._get_tf_exploration_action_op(
                action_distribution, explore, timestep
            )
        else:
            return self._get_torch_exploration_action(
                action_distribution, explore, timestep
            )

    def _get_tf_exploration_action_op(
        self,
        action_distribution: ActionDistribution,
        explore: Union[bool, TensorType],
        timestep: Union[int, TensorType],
    ) -> "tf.Tensor":
        """TF method to produce the tf op for an epsilon exploration action.

        Args:
            action_distribution: The instantiated ActionDistribution object
                to work with when creating exploration actions.
            explore: Whether to produce exploratory (epsilon-greedy) actions
                (True) or always-greedy ones (False); may be a tf tensor.
            timestep: The current timestep, used to query the epsilon
                schedule; falls back to `self.last_timestep` if None.

        Returns:
            A tuple of (action op, all-zeros logp op).
        """
        # TODO: Support MultiActionDistr for tf.
        # The distribution's inputs are interpreted as per-action Q-values.
        q_values = action_distribution.inputs
        epsilon = self.epsilon_schedule(
            timestep if timestep is not None else self.last_timestep
        )

        # Get the exploit action as the one with the highest logit value.
        exploit_action = tf.argmax(q_values, axis=1)

        batch_size = tf.shape(q_values)[0]
        # Mask out actions with q-value=-inf so that we don't even consider
        # them for exploration.
        # (Valid actions get uniform logit 1.0, masked ones float32-min.)
        random_valid_action_logits = tf.where(
            tf.equal(q_values, tf.float32.min),
            tf.ones_like(q_values) * tf.float32.min,
            tf.ones_like(q_values),
        )
        random_actions = tf.squeeze(
            tf.random.categorical(random_valid_action_logits, 1), axis=1
        )

        # Per-batch-item coin flip: True -> use the random action.
        chose_random = (
            tf.random.uniform(
                tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32
            )
            < epsilon
        )

        # If `explore` is a python bool, wrap it in a constant for `tf.cond`.
        action = tf.cond(
            pred=tf.constant(explore, dtype=tf.bool)
            if isinstance(explore, bool)
            else explore,
            true_fn=(lambda: tf.where(chose_random, random_actions, exploit_action)),
            false_fn=lambda: exploit_action,
        )

        if self.framework == "tf2" and not self.policy_config["eager_tracing"]:
            # Eager (non-traced): update the timestep directly in python.
            self.last_timestep = timestep
            return action, tf.zeros_like(action, dtype=tf.float32)
        else:
            # Graph mode: make sure the timestep variable is updated before
            # the action op is returned.
            assign_op = tf1.assign(self.last_timestep, tf.cast(timestep, tf.int64))
            with tf1.control_dependencies([assign_op]):
                return action, tf.zeros_like(action, dtype=tf.float32)

    def _get_torch_exploration_action(
        self,
        action_distribution: ActionDistribution,
        explore: bool,
        timestep: Union[int, TensorType],
    ) -> "torch.Tensor":
        """Torch method to produce an epsilon exploration action.

        Args:
            action_distribution: The instantiated
                ActionDistribution object to work with when creating
                exploration actions.
            explore: Whether to produce exploratory (epsilon-greedy) actions
                (True) or always-greedy ones (False).
            timestep: The current timestep, stored in `self.last_timestep`
                and used to query the epsilon schedule.

        Returns:
            A tuple of (action tensor, all-zeros logp tensor).
        """
        q_values = action_distribution.inputs
        self.last_timestep = timestep
        exploit_action = action_distribution.deterministic_sample()
        batch_size = q_values.size()[0]
        # Exploration actions are reported with logp=0.0.
        action_logp = torch.zeros(batch_size, dtype=torch.float)

        # Explore.
        if explore:
            # Get the current epsilon.
            epsilon = self.epsilon_schedule(self.last_timestep)
            if isinstance(action_distribution, TorchMultiActionDistribution):
                # Multi-action case: per batch item, replace the entire
                # (flattened) action struct with a random sample w/ prob eps.
                exploit_action = tree.flatten(exploit_action)
                for i in range(batch_size):
                    if random.random() < epsilon:
                        # TODO: (bcahlit) Mask out actions
                        random_action = tree.flatten(self.action_space.sample())
                        for j in range(len(exploit_action)):
                            exploit_action[j][i] = torch.tensor(random_action[j])
                exploit_action = tree.unflatten_as(
                    action_distribution.action_space_struct, exploit_action
                )

                return exploit_action, action_logp

            else:
                # Mask out actions, whose Q-values are -inf, so that we don't
                # even consider them for exploration.
                # (Weight 0.0 excludes an action from `torch.multinomial`.)
                random_valid_action_logits = torch.where(
                    q_values <= FLOAT_MIN,
                    torch.ones_like(q_values) * 0.0,
                    torch.ones_like(q_values),
                )
                # A random action.
                random_actions = torch.squeeze(
                    torch.multinomial(random_valid_action_logits, 1), axis=1
                )

                # Pick either random or greedy.
                action = torch.where(
                    torch.empty((batch_size,)).uniform_().to(self.device) < epsilon,
                    random_actions,
                    exploit_action,
                )

                return action, action_logp
        # Return the deterministic "sample" (argmax) over the logits.
        else:
            return exploit_action, action_logp

    @override(Exploration)
    def get_state(self, sess: Optional["tf.Session"] = None):
        # Static-graph tf: run the pre-built state op in the given session.
        if sess:
            return sess.run(self._tf_state_op)
        eps = self.epsilon_schedule(self.last_timestep)
        return {
            "cur_epsilon": convert_to_numpy(eps) if self.framework != "tf" else eps,
            "last_timestep": convert_to_numpy(self.last_timestep)
            if self.framework != "tf"
            else self.last_timestep,
        }

    @override(Exploration)
    def set_state(self, state: dict, sess: Optional["tf.Session"] = None) -> None:
        # `last_timestep` may be a tf1 variable, a python int, or a tf2
        # variable, depending on the framework; restore it accordingly.
        if self.framework == "tf":
            self.last_timestep.load(state["last_timestep"], session=sess)
        elif isinstance(self.last_timestep, int):
            self.last_timestep = state["last_timestep"]
        else:
            self.last_timestep.assign(state["last_timestep"])