koichi12 commited on
Commit
5f20f96
·
verified ·
1 Parent(s): 30f24c0

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/ray/_private/thirdparty/pynvml/__pycache__/pynvml.cpython-311.pyc +3 -0
  3. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__init__.py +12 -0
  4. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/__init__.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo_learner.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo_rl_module.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo_tf_policy.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo_torch_policy.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/default_appo_rl_module.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/utils.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo.py +434 -0
  13. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo_learner.py +147 -0
  14. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo_rl_module.py +11 -0
  15. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo_tf_policy.py +393 -0
  16. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo_torch_policy.py +412 -0
  17. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/default_appo_rl_module.py +59 -0
  18. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__init__.py +0 -0
  19. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__pycache__/__init__.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__pycache__/appo_torch_learner.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__pycache__/appo_torch_rl_module.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__pycache__/default_appo_torch_rl_module.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/appo_torch_learner.py +234 -0
  24. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/appo_torch_rl_module.py +13 -0
  25. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/default_appo_torch_rl_module.py +10 -0
  26. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/utils.py +133 -0
  27. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/__init__.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/dreamerv3.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/dreamerv3_catalog.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/dreamerv3_learner.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/dreamerv3_rl_module.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/__init__.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/actor_network.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/critic_network.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/disagree_networks.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/dreamer_model.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/world_model.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__init__.py +0 -0
  39. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/__init__.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/cnn_atari.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/continue_predictor.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/conv_transpose_atari.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/dynamics_predictor.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/mlp.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/representation_layer.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/reward_predictor.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/reward_predictor_layer.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/sequence_model.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/vector_decoder.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/continue_predictor.py +94 -0
.gitattributes CHANGED
@@ -175,3 +175,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
175
  .venv/lib/python3.11/site-packages/ray/jars/ray_dist.jar filter=lfs diff=lfs merge=lfs -text
176
  .venv/lib/python3.11/site-packages/ray/_private/__pycache__/worker.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
177
  .venv/lib/python3.11/site-packages/ray/_private/__pycache__/test_utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
175
  .venv/lib/python3.11/site-packages/ray/jars/ray_dist.jar filter=lfs diff=lfs merge=lfs -text
176
  .venv/lib/python3.11/site-packages/ray/_private/__pycache__/worker.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
177
  .venv/lib/python3.11/site-packages/ray/_private/__pycache__/test_utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
178
+ .venv/lib/python3.11/site-packages/ray/_private/thirdparty/pynvml/__pycache__/pynvml.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/ray/_private/thirdparty/pynvml/__pycache__/pynvml.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ebb34d8a5e73fa6657fb50dde3c5afc10ca55bef89431f9fbe15555295f4da0e
3
+ size 168124
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.algorithms.appo.appo import APPO, APPOConfig
2
+ from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF1Policy, APPOTF2Policy
3
+ from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy
4
+
5
+ __all__ = [
6
+ "APPO",
7
+ "APPOConfig",
8
+ # @OldAPIStack
9
+ "APPOTF1Policy",
10
+ "APPOTF2Policy",
11
+ "APPOTorchPolicy",
12
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (580 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo.cpython-311.pyc ADDED
Binary file (18.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo_learner.cpython-311.pyc ADDED
Binary file (8.25 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo_rl_module.cpython-311.pyc ADDED
Binary file (637 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo_tf_policy.cpython-311.pyc ADDED
Binary file (17 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo_torch_policy.cpython-311.pyc ADDED
Binary file (19.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/default_appo_rl_module.cpython-311.pyc ADDED
Binary file (3.85 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/utils.cpython-311.pyc ADDED
Binary file (5.31 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Asynchronous Proximal Policy Optimization (APPO)
2
+
3
+ The algorithm is described in [1] (under the name of "IMPACT"):
4
+
5
+ Detailed documentation:
6
+ https://docs.ray.io/en/master/rllib-algorithms.html#appo
7
+
8
+ [1] IMPACT: Importance Weighted Asynchronous Architectures with Clipped Target Networks.
9
+ Luo et al. 2020
10
+ https://arxiv.org/pdf/1912.00167
11
+ """
12
+
13
+ from typing import Optional, Type
14
+ import logging
15
+
16
+ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
17
+ from ray.rllib.algorithms.impala.impala import IMPALA, IMPALAConfig
18
+ from ray.rllib.core.rl_module.rl_module import RLModuleSpec
19
+ from ray.rllib.policy.policy import Policy
20
+ from ray.rllib.utils.annotations import override
21
+ from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
22
+ from ray.rllib.utils.metrics import (
23
+ LAST_TARGET_UPDATE_TS,
24
+ NUM_AGENT_STEPS_SAMPLED,
25
+ NUM_ENV_STEPS_SAMPLED,
26
+ NUM_TARGET_UPDATES,
27
+ )
28
+ from ray.rllib.utils.metrics import LEARNER_STATS_KEY
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ LEARNER_RESULTS_KL_KEY = "mean_kl_loss"
34
+ LEARNER_RESULTS_CURR_KL_COEFF_KEY = "curr_kl_coeff"
35
+ OLD_ACTION_DIST_KEY = "old_action_dist"
36
+
37
+
38
+ class APPOConfig(IMPALAConfig):
39
+ """Defines a configuration class from which an APPO Algorithm can be built.
40
+
41
+ .. testcode::
42
+
43
+ from ray.rllib.algorithms.appo import APPOConfig
44
+ config = (
45
+ APPOConfig()
46
+ .training(lr=0.01, grad_clip=30.0, train_batch_size_per_learner=50)
47
+ )
48
+ config = config.learners(num_learners=1)
49
+ config = config.env_runners(num_env_runners=1)
50
+ config = config.environment("CartPole-v1")
51
+
52
+ # Build an Algorithm object from the config and run 1 training iteration.
53
+ algo = config.build()
54
+ algo.train()
55
+ del algo
56
+
57
+ .. testcode::
58
+
59
+ from ray.rllib.algorithms.appo import APPOConfig
60
+ from ray import air
61
+ from ray import tune
62
+
63
+ config = APPOConfig()
64
+ # Update the config object.
65
+ config = config.training(lr=tune.grid_search([0.001,]))
66
+ # Set the config object's env.
67
+ config = config.environment(env="CartPole-v1")
68
+ # Use to_dict() to get the old-style python config dict when running with tune.
69
+ tune.Tuner(
70
+ "APPO",
71
+ run_config=air.RunConfig(
72
+ stop={"training_iteration": 1},
73
+ verbose=0,
74
+ ),
75
+ param_space=config.to_dict(),
76
+
77
+ ).fit()
78
+
79
+ .. testoutput::
80
+ :hide:
81
+
82
+ ...
83
+ """
84
+
85
+ def __init__(self, algo_class=None):
86
+ """Initializes a APPOConfig instance."""
87
+ self.exploration_config = {
88
+ # The Exploration class to use. In the simplest case, this is the name
89
+ # (str) of any class present in the `rllib.utils.exploration` package.
90
+ # You can also provide the python class directly or the full location
91
+ # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
92
+ # EpsilonGreedy").
93
+ "type": "StochasticSampling",
94
+ # Add constructor kwargs here (if any).
95
+ }
96
+
97
+ super().__init__(algo_class=algo_class or APPO)
98
+
99
+ # fmt: off
100
+ # __sphinx_doc_begin__
101
+ # APPO specific settings:
102
+ self.vtrace = True
103
+ self.use_gae = True
104
+ self.lambda_ = 1.0
105
+ self.clip_param = 0.4
106
+ self.use_kl_loss = False
107
+ self.kl_coeff = 1.0
108
+ self.kl_target = 0.01
109
+ self.target_worker_clipping = 2.0
110
+
111
+ # Circular replay buffer settings.
112
+ # Used in [1] for discrete action tasks:
113
+ # `circular_buffer_num_batches=4` and `circular_buffer_iterations_per_batch=2`
114
+ # For cont. action tasks:
115
+ # `circular_buffer_num_batches=16` and `circular_buffer_iterations_per_batch=20`
116
+ self.circular_buffer_num_batches = 4
117
+ self.circular_buffer_iterations_per_batch = 2
118
+
119
+ # Override some of IMPALAConfig's default values with APPO-specific values.
120
+ self.num_env_runners = 2
121
+ self.target_network_update_freq = 2
122
+ self.broadcast_interval = 1
123
+ self.grad_clip = 40.0
124
+ # Note: Only when using enable_rl_module_and_learner=True can the clipping mode
125
+ # be configured by the user. On the old API stack, RLlib will always clip by
126
+ # global_norm, no matter the value of `grad_clip_by`.
127
+ self.grad_clip_by = "global_norm"
128
+
129
+ self.opt_type = "adam"
130
+ self.lr = 0.0005
131
+ self.decay = 0.99
132
+ self.momentum = 0.0
133
+ self.epsilon = 0.1
134
+ self.vf_loss_coeff = 0.5
135
+ self.entropy_coeff = 0.01
136
+ self.tau = 1.0
137
+ # __sphinx_doc_end__
138
+ # fmt: on
139
+
140
+ self.lr_schedule = None # @OldAPIStack
141
+ self.entropy_coeff_schedule = None # @OldAPIStack
142
+ self.num_gpus = 0 # @OldAPIStack
143
+ self.num_multi_gpu_tower_stacks = 1 # @OldAPIStack
144
+ self.minibatch_buffer_size = 1 # @OldAPIStack
145
+ self.replay_proportion = 0.0 # @OldAPIStack
146
+ self.replay_buffer_num_slots = 100 # @OldAPIStack
147
+ self.learner_queue_size = 16 # @OldAPIStack
148
+ self.learner_queue_timeout = 300 # @OldAPIStack
149
+
150
+ # Deprecated keys.
151
+ self.target_update_frequency = DEPRECATED_VALUE
152
+ self.use_critic = DEPRECATED_VALUE
153
+
154
+ @override(IMPALAConfig)
155
+ def training(
156
+ self,
157
+ *,
158
+ vtrace: Optional[bool] = NotProvided,
159
+ use_gae: Optional[bool] = NotProvided,
160
+ lambda_: Optional[float] = NotProvided,
161
+ clip_param: Optional[float] = NotProvided,
162
+ use_kl_loss: Optional[bool] = NotProvided,
163
+ kl_coeff: Optional[float] = NotProvided,
164
+ kl_target: Optional[float] = NotProvided,
165
+ target_network_update_freq: Optional[int] = NotProvided,
166
+ tau: Optional[float] = NotProvided,
167
+ target_worker_clipping: Optional[float] = NotProvided,
168
+ circular_buffer_num_batches: Optional[int] = NotProvided,
169
+ circular_buffer_iterations_per_batch: Optional[int] = NotProvided,
170
+ # Deprecated keys.
171
+ target_update_frequency=DEPRECATED_VALUE,
172
+ use_critic=DEPRECATED_VALUE,
173
+ **kwargs,
174
+ ) -> "APPOConfig":
175
+ """Sets the training related configuration.
176
+
177
+ Args:
178
+ vtrace: Whether to use V-trace weighted advantages. If false, PPO GAE
179
+ advantages will be used instead.
180
+ use_gae: If true, use the Generalized Advantage Estimator (GAE)
181
+ with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
182
+ Only applies if vtrace=False.
183
+ lambda_: GAE (lambda) parameter.
184
+ clip_param: PPO surrogate slipping parameter.
185
+ use_kl_loss: Whether to use the KL-term in the loss function.
186
+ kl_coeff: Coefficient for weighting the KL-loss term.
187
+ kl_target: Target term for the KL-term to reach (via adjusting the
188
+ `kl_coeff` automatically).
189
+ target_network_update_freq: NOTE: This parameter is only applicable on
190
+ the new API stack. The frequency with which to update the target
191
+ policy network from the main trained policy network. The metric
192
+ used is `NUM_ENV_STEPS_TRAINED_LIFETIME` and the unit is `n` (see [1]
193
+ 4.1.1), where: `n = [circular_buffer_num_batches (N)] *
194
+ [circular_buffer_iterations_per_batch (K)] * [train batch size]`
195
+ For example, if you set `target_network_update_freq=2`, and N=4, K=2,
196
+ and `train_batch_size_per_learner=500`, then the target net is updated
197
+ every 2*4*2*500=8000 trained env steps (every 16 batch updates on each
198
+ learner).
199
+ The authors in [1] suggests that this setting is robust to a range of
200
+ choices (try values between 0.125 and 4).
201
+ target_network_update_freq: The frequency to update the target policy and
202
+ tune the kl loss coefficients that are used during training. After
203
+ setting this parameter, the algorithm waits for at least
204
+ `target_network_update_freq` number of environment samples to be trained
205
+ on before updating the target networks and tune the kl loss
206
+ coefficients. NOTE: This parameter is only applicable when using the
207
+ Learner API (enable_rl_module_and_learner=True).
208
+ tau: The factor by which to update the target policy network towards
209
+ the current policy network. Can range between 0 and 1.
210
+ e.g. updated_param = tau * current_param + (1 - tau) * target_param
211
+ target_worker_clipping: The maximum value for the target-worker-clipping
212
+ used for computing the IS ratio, described in [1]
213
+ IS = min(π(i) / π(target), ρ) * (π / π(i))
214
+ circular_buffer_num_batches: The number of train batches that fit
215
+ into the circular buffer. Each such train batch can be sampled for
216
+ training max. `circular_buffer_iterations_per_batch` times.
217
+ circular_buffer_iterations_per_batch: The number of times any train
218
+ batch in the circular buffer can be sampled for training. A batch gets
219
+ evicted from the buffer either if it's the oldest batch in the buffer
220
+ and a new batch is added OR if the batch reaches this max. number of
221
+ being sampled.
222
+
223
+ Returns:
224
+ This updated AlgorithmConfig object.
225
+ """
226
+ if target_update_frequency != DEPRECATED_VALUE:
227
+ deprecation_warning(
228
+ old="target_update_frequency",
229
+ new="target_network_update_freq",
230
+ error=True,
231
+ )
232
+ if use_critic != DEPRECATED_VALUE:
233
+ deprecation_warning(
234
+ old="use_critic",
235
+ help="`use_critic` no longer supported! APPO always uses a value "
236
+ "function (critic).",
237
+ error=True,
238
+ )
239
+
240
+ # Pass kwargs onto super's `training()` method.
241
+ super().training(**kwargs)
242
+
243
+ if vtrace is not NotProvided:
244
+ self.vtrace = vtrace
245
+ if use_gae is not NotProvided:
246
+ self.use_gae = use_gae
247
+ if lambda_ is not NotProvided:
248
+ self.lambda_ = lambda_
249
+ if clip_param is not NotProvided:
250
+ self.clip_param = clip_param
251
+ if use_kl_loss is not NotProvided:
252
+ self.use_kl_loss = use_kl_loss
253
+ if kl_coeff is not NotProvided:
254
+ self.kl_coeff = kl_coeff
255
+ if kl_target is not NotProvided:
256
+ self.kl_target = kl_target
257
+ if target_network_update_freq is not NotProvided:
258
+ self.target_network_update_freq = target_network_update_freq
259
+ if tau is not NotProvided:
260
+ self.tau = tau
261
+ if target_worker_clipping is not NotProvided:
262
+ self.target_worker_clipping = target_worker_clipping
263
+ if circular_buffer_num_batches is not NotProvided:
264
+ self.circular_buffer_num_batches = circular_buffer_num_batches
265
+ if circular_buffer_iterations_per_batch is not NotProvided:
266
+ self.circular_buffer_iterations_per_batch = (
267
+ circular_buffer_iterations_per_batch
268
+ )
269
+
270
+ return self
271
+
272
+ @override(IMPALAConfig)
273
+ def validate(self) -> None:
274
+ super().validate()
275
+
276
+ # On new API stack, circular buffer should be used, not `minibatch_buffer_size`.
277
+ if self.enable_rl_module_and_learner:
278
+ if self.minibatch_buffer_size != 1 or self.replay_proportion != 0.0:
279
+ self._value_error(
280
+ "`minibatch_buffer_size/replay_proportion` not valid on new API "
281
+ "stack with APPO! "
282
+ "Use `circular_buffer_num_batches` for the number of train batches "
283
+ "in the circular buffer. To change the maximum number of times "
284
+ "any batch may be sampled, set "
285
+ "`circular_buffer_iterations_per_batch`."
286
+ )
287
+ if self.num_multi_gpu_tower_stacks != 1:
288
+ self._value_error(
289
+ "`num_multi_gpu_tower_stacks` not supported on new API stack with "
290
+ "APPO! In order to train on multi-GPU, use "
291
+ "`config.learners(num_learners=[number of GPUs], "
292
+ "num_gpus_per_learner=1)`. To scale the throughput of batch-to-GPU-"
293
+ "pre-loading on each of your `Learners`, set "
294
+ "`num_gpu_loader_threads` to a higher number (recommended values: "
295
+ "1-8)."
296
+ )
297
+ if self.learner_queue_size != 16:
298
+ self._value_error(
299
+ "`learner_queue_size` not supported on new API stack with "
300
+ "APPO! In order set the size of the circular buffer (which acts as "
301
+ "a 'learner queue'), use "
302
+ "`config.training(circular_buffer_num_batches=..)`. To change the "
303
+ "maximum number of times any batch may be sampled, set "
304
+ "`config.training(circular_buffer_iterations_per_batch=..)`."
305
+ )
306
+
307
+ @override(IMPALAConfig)
308
+ def get_default_learner_class(self):
309
+ if self.framework_str == "torch":
310
+ from ray.rllib.algorithms.appo.torch.appo_torch_learner import (
311
+ APPOTorchLearner,
312
+ )
313
+
314
+ return APPOTorchLearner
315
+ elif self.framework_str in ["tf2", "tf"]:
316
+ raise ValueError(
317
+ "TensorFlow is no longer supported on the new API stack! "
318
+ "Use `framework='torch'`."
319
+ )
320
+ else:
321
+ raise ValueError(
322
+ f"The framework {self.framework_str} is not supported. "
323
+ "Use `framework='torch'`."
324
+ )
325
+
326
+ @override(IMPALAConfig)
327
+ def get_default_rl_module_spec(self) -> RLModuleSpec:
328
+ if self.framework_str == "torch":
329
+ from ray.rllib.algorithms.appo.torch.appo_torch_rl_module import (
330
+ APPOTorchRLModule as RLModule,
331
+ )
332
+ else:
333
+ raise ValueError(
334
+ f"The framework {self.framework_str} is not supported. "
335
+ "Use either 'torch' or 'tf2'."
336
+ )
337
+
338
+ return RLModuleSpec(module_class=RLModule)
339
+
340
+ @property
341
+ @override(AlgorithmConfig)
342
+ def _model_config_auto_includes(self):
343
+ return super()._model_config_auto_includes | {"vf_share_layers": False}
344
+
345
+
346
+ class APPO(IMPALA):
347
+ def __init__(self, config, *args, **kwargs):
348
+ """Initializes an APPO instance."""
349
+ super().__init__(config, *args, **kwargs)
350
+
351
+ # After init: Initialize target net.
352
+
353
+ # TODO(avnishn): Does this need to happen in __init__? I think we can move it
354
+ # to setup()
355
+ if not self.config.enable_rl_module_and_learner:
356
+ self.env_runner.foreach_policy_to_train(lambda p, _: p.update_target())
357
+
358
+ @override(IMPALA)
359
+ def training_step(self) -> None:
360
+ if self.config.enable_rl_module_and_learner:
361
+ return super().training_step()
362
+
363
+ train_results = super().training_step()
364
+ # Update the target network and the KL coefficient for the APPO-loss.
365
+ # The target network update frequency is calculated automatically by the product
366
+ # of `num_epochs` setting (usually 1 for APPO) and `minibatch_buffer_size`.
367
+ last_update = self._counters[LAST_TARGET_UPDATE_TS]
368
+ cur_ts = self._counters[
369
+ (
370
+ NUM_AGENT_STEPS_SAMPLED
371
+ if self.config.count_steps_by == "agent_steps"
372
+ else NUM_ENV_STEPS_SAMPLED
373
+ )
374
+ ]
375
+ target_update_freq = self.config.num_epochs * self.config.minibatch_buffer_size
376
+ if cur_ts - last_update > target_update_freq:
377
+ self._counters[NUM_TARGET_UPDATES] += 1
378
+ self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
379
+
380
+ # Update our target network.
381
+ self.env_runner.foreach_policy_to_train(lambda p, _: p.update_target())
382
+
383
+ # Also update the KL-coefficient for the APPO loss, if necessary.
384
+ if self.config.use_kl_loss:
385
+
386
+ def update(pi, pi_id):
387
+ assert LEARNER_STATS_KEY not in train_results, (
388
+ "{} should be nested under policy id key".format(
389
+ LEARNER_STATS_KEY
390
+ ),
391
+ train_results,
392
+ )
393
+ if pi_id in train_results:
394
+ kl = train_results[pi_id][LEARNER_STATS_KEY].get("kl")
395
+ assert kl is not None, (train_results, pi_id)
396
+ # Make the actual `Policy.update_kl()` call.
397
+ pi.update_kl(kl)
398
+ else:
399
+ logger.warning("No data for {}, not updating kl".format(pi_id))
400
+
401
+ # Update KL on all trainable policies within the local (trainer)
402
+ # Worker.
403
+ self.env_runner.foreach_policy_to_train(update)
404
+
405
+ return train_results
406
+
407
+ @classmethod
408
+ @override(IMPALA)
409
+ def get_default_config(cls) -> AlgorithmConfig:
410
+ return APPOConfig()
411
+
412
+ @classmethod
413
+ @override(IMPALA)
414
+ def get_default_policy_class(
415
+ cls, config: AlgorithmConfig
416
+ ) -> Optional[Type[Policy]]:
417
+ if config["framework"] == "torch":
418
+ from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy
419
+
420
+ return APPOTorchPolicy
421
+ elif config["framework"] == "tf":
422
+ if config.enable_rl_module_and_learner:
423
+ raise ValueError(
424
+ "RLlib's RLModule and Learner API is not supported for"
425
+ " tf1. Use "
426
+ "framework='tf2' instead."
427
+ )
428
+ from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF1Policy
429
+
430
+ return APPOTF1Policy
431
+ else:
432
+ from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF2Policy
433
+
434
+ return APPOTF2Policy
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo_learner.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import Any, Dict, Optional
3
+
4
+ from ray.rllib.algorithms.appo.appo import APPOConfig
5
+ from ray.rllib.algorithms.appo.utils import CircularBuffer
6
+ from ray.rllib.algorithms.impala.impala_learner import IMPALALearner
7
+ from ray.rllib.core.learner.learner import Learner
8
+ from ray.rllib.core.learner.utils import update_target_network
9
+ from ray.rllib.core.rl_module.apis import TargetNetworkAPI, ValueFunctionAPI
10
+ from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
11
+ from ray.rllib.core.rl_module.rl_module import RLModuleSpec
12
+ from ray.rllib.utils.annotations import override
13
+ from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict
14
+ from ray.rllib.utils.metrics import (
15
+ LAST_TARGET_UPDATE_TS,
16
+ NUM_ENV_STEPS_TRAINED_LIFETIME,
17
+ NUM_MODULE_STEPS_TRAINED,
18
+ NUM_TARGET_UPDATES,
19
+ )
20
+ from ray.rllib.utils.schedules.scheduler import Scheduler
21
+ from ray.rllib.utils.typing import ModuleID, ShouldModuleBeUpdatedFn
22
+
23
+
24
+ class APPOLearner(IMPALALearner):
25
+ """Adds KL coeff updates via `after_gradient_based_update()` to IMPALA logic.
26
+
27
+ Framework-specific subclasses must override `_update_module_kl_coeff()`.
28
+ """
29
+
30
+ @override(IMPALALearner)
31
+ def build(self):
32
+ self._learner_thread_in_queue = CircularBuffer(
33
+ num_batches=self.config.circular_buffer_num_batches,
34
+ iterations_per_batch=self.config.circular_buffer_iterations_per_batch,
35
+ )
36
+
37
+ super().build()
38
+
39
+ # Make target networks.
40
+ self.module.foreach_module(
41
+ lambda mid, mod: (
42
+ mod.make_target_networks()
43
+ if isinstance(mod, TargetNetworkAPI)
44
+ else None
45
+ )
46
+ )
47
+
48
+ # The current kl coefficients per module as (framework specific) tensor
49
+ # variables.
50
+ self.curr_kl_coeffs_per_module: LambdaDefaultDict[
51
+ ModuleID, Scheduler
52
+ ] = LambdaDefaultDict(
53
+ lambda module_id: self._get_tensor_variable(
54
+ self.config.get_config_for_module(module_id).kl_coeff
55
+ )
56
+ )
57
+
58
+ @override(Learner)
59
+ def add_module(
60
+ self,
61
+ *,
62
+ module_id: ModuleID,
63
+ module_spec: RLModuleSpec,
64
+ config_overrides: Optional[Dict] = None,
65
+ new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
66
+ ) -> MultiRLModuleSpec:
67
+ marl_spec = super().add_module(
68
+ module_id=module_id,
69
+ module_spec=module_spec,
70
+ config_overrides=config_overrides,
71
+ new_should_module_be_updated=new_should_module_be_updated,
72
+ )
73
+ # Create target networks for added Module, if applicable.
74
+ if isinstance(self.module[module_id].unwrapped(), TargetNetworkAPI):
75
+ self.module[module_id].unwrapped().make_target_networks()
76
+ return marl_spec
77
+
78
+ @override(IMPALALearner)
79
+ def remove_module(self, module_id: str) -> MultiRLModuleSpec:
80
+ marl_spec = super().remove_module(module_id)
81
+ self.curr_kl_coeffs_per_module.pop(module_id)
82
+ return marl_spec
83
+
84
+ @override(Learner)
85
+ def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
86
+ """Updates the target Q Networks."""
87
+ super().after_gradient_based_update(timesteps=timesteps)
88
+
89
+ # TODO (sven): Maybe we should have a `after_gradient_based_update`
90
+ # method per module?
91
+ curr_timestep = timesteps.get(NUM_ENV_STEPS_TRAINED_LIFETIME, 0)
92
+ for module_id, module in self.module._rl_modules.items():
93
+ config = self.config.get_config_for_module(module_id)
94
+
95
+ last_update_ts_key = (module_id, LAST_TARGET_UPDATE_TS)
96
+ if isinstance(module.unwrapped(), TargetNetworkAPI) and (
97
+ curr_timestep - self.metrics.peek(last_update_ts_key, default=0)
98
+ >= (
99
+ config.target_network_update_freq
100
+ * config.circular_buffer_num_batches
101
+ * config.circular_buffer_iterations_per_batch
102
+ * config.train_batch_size_per_learner
103
+ )
104
+ ):
105
+ for (
106
+ main_net,
107
+ target_net,
108
+ ) in module.unwrapped().get_target_network_pairs():
109
+ update_target_network(
110
+ main_net=main_net,
111
+ target_net=target_net,
112
+ tau=config.tau,
113
+ )
114
+ # Increase lifetime target network update counter by one.
115
+ self.metrics.log_value((module_id, NUM_TARGET_UPDATES), 1, reduce="sum")
116
+ # Update the (single-value -> window=1) last updated timestep metric.
117
+ self.metrics.log_value(last_update_ts_key, curr_timestep, window=1)
118
+
119
+ if (
120
+ config.use_kl_loss
121
+ and self.metrics.peek((module_id, NUM_MODULE_STEPS_TRAINED), default=0)
122
+ > 0
123
+ ):
124
+ self._update_module_kl_coeff(module_id=module_id, config=config)
125
+
126
+ @classmethod
127
+ @override(Learner)
128
+ def rl_module_required_apis(cls) -> list[type]:
129
+ # In order for a PPOLearner to update an RLModule, it must implement the
130
+ # following APIs:
131
+ return [TargetNetworkAPI, ValueFunctionAPI]
132
+
133
+ @abc.abstractmethod
134
+ def _update_module_kl_coeff(self, module_id: ModuleID, config: APPOConfig) -> None:
135
+ """Dynamically update the KL loss coefficients of each module.
136
+
137
+ The update is completed using the mean KL divergence between the action
138
+ distributions current policy and old policy of each module. That action
139
+ distribution is computed during the most recent update/call to `compute_loss`.
140
+
141
+ Args:
142
+ module_id: The module whose KL loss coefficient to update.
143
+ config: The AlgorithmConfig specific to the given `module_id`.
144
+ """
145
+
146
+
147
+ AppoLearner = APPOLearner
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo_rl_module.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Backward compat import.
2
+ from ray.rllib.algorithms.appo.default_appo_rl_module import ( # noqa
3
+ DefaultAPPORLModule as APPORLModule,
4
+ )
5
+ from ray.rllib.utils.deprecation import deprecation_warning
6
+
7
+ deprecation_warning(
8
+ old="ray.rllib.algorithms.appo.appo_rl_module.APPORLModule",
9
+ new="ray.rllib.algorithms.appo.default_appo_rl_module.DefaultAPPORLModule",
10
+ error=False,
11
+ )
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo_tf_policy.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TensorFlow policy class used for APPO.
3
+
4
+ Adapted from VTraceTFPolicy to use the PPO surrogate loss.
5
+ Keep in sync with changes to VTraceTFPolicy.
6
+ """
7
+
8
+ import numpy as np
9
+ import logging
10
+ import gymnasium as gym
11
+ from typing import Dict, List, Optional, Type, Union
12
+
13
+ from ray.rllib.algorithms.appo.utils import make_appo_models
14
+ from ray.rllib.algorithms.impala import vtrace_tf as vtrace
15
+ from ray.rllib.algorithms.impala.impala_tf_policy import (
16
+ _make_time_major,
17
+ VTraceClipGradients,
18
+ VTraceOptimizer,
19
+ )
20
+ from ray.rllib.evaluation.postprocessing import (
21
+ compute_bootstrap_value,
22
+ compute_gae_for_sample_batch,
23
+ Postprocessing,
24
+ )
25
+ from ray.rllib.models.tf.tf_action_dist import Categorical
26
+ from ray.rllib.policy.sample_batch import SampleBatch
27
+ from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
28
+ from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
29
+ from ray.rllib.policy.tf_mixins import (
30
+ EntropyCoeffSchedule,
31
+ LearningRateSchedule,
32
+ KLCoeffMixin,
33
+ ValueNetworkMixin,
34
+ GradStatsMixin,
35
+ TargetNetworkMixin,
36
+ )
37
+ from ray.rllib.models.modelv2 import ModelV2
38
+ from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
39
+ from ray.rllib.utils.annotations import (
40
+ override,
41
+ )
42
+ from ray.rllib.utils.framework import try_import_tf
43
+ from ray.rllib.utils.tf_utils import explained_variance
44
+ from ray.rllib.utils.typing import TensorType
45
+
46
+ tf1, tf, tfv = try_import_tf()
47
+
48
+ logger = logging.getLogger(__name__)
49
+
50
+
51
+ # TODO (sven): Deprecate once APPO and IMPALA fully on RLModules/Learner APIs.
52
+ def get_appo_tf_policy(name: str, base: type) -> type:
53
+ """Construct an APPOTFPolicy inheriting either dynamic or eager base policies.
54
+
55
+ Args:
56
+ base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.
57
+
58
+ Returns:
59
+ A TF Policy to be used with Impala.
60
+ """
61
+
62
+ class APPOTFPolicy(
63
+ VTraceClipGradients,
64
+ VTraceOptimizer,
65
+ LearningRateSchedule,
66
+ KLCoeffMixin,
67
+ EntropyCoeffSchedule,
68
+ ValueNetworkMixin,
69
+ TargetNetworkMixin,
70
+ GradStatsMixin,
71
+ base,
72
+ ):
73
+ def __init__(
74
+ self,
75
+ observation_space,
76
+ action_space,
77
+ config,
78
+ existing_model=None,
79
+ existing_inputs=None,
80
+ ):
81
+ # First thing first, enable eager execution if necessary.
82
+ base.enable_eager_execution_if_necessary()
83
+
84
+ # Although this is a no-op, we call __init__ here to make it clear
85
+ # that base.__init__ will use the make_model() call.
86
+ VTraceClipGradients.__init__(self)
87
+ VTraceOptimizer.__init__(self)
88
+
89
+ # Initialize base class.
90
+ base.__init__(
91
+ self,
92
+ observation_space,
93
+ action_space,
94
+ config,
95
+ existing_inputs=existing_inputs,
96
+ existing_model=existing_model,
97
+ )
98
+
99
+ # TF LearningRateSchedule depends on self.framework, so initialize
100
+ # after base.__init__() is called.
101
+ LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"])
102
+ EntropyCoeffSchedule.__init__(
103
+ self, config["entropy_coeff"], config["entropy_coeff_schedule"]
104
+ )
105
+ ValueNetworkMixin.__init__(self, config)
106
+ KLCoeffMixin.__init__(self, config)
107
+
108
+ GradStatsMixin.__init__(self)
109
+
110
+ # Note: this is a bit ugly, but loss and optimizer initialization must
111
+ # happen after all the MixIns are initialized.
112
+ self.maybe_initialize_optimizer_and_loss()
113
+
114
+ # Initiate TargetNetwork ops after loss initialization.
115
+ TargetNetworkMixin.__init__(self)
116
+
117
+ @override(base)
118
+ def make_model(self) -> ModelV2:
119
+ return make_appo_models(self)
120
+
121
+ @override(base)
122
+ def loss(
123
+ self,
124
+ model: Union[ModelV2, "tf.keras.Model"],
125
+ dist_class: Type[TFActionDistribution],
126
+ train_batch: SampleBatch,
127
+ ) -> Union[TensorType, List[TensorType]]:
128
+ model_out, _ = model(train_batch)
129
+ action_dist = dist_class(model_out, model)
130
+
131
+ if isinstance(self.action_space, gym.spaces.Discrete):
132
+ is_multidiscrete = False
133
+ output_hidden_shape = [self.action_space.n]
134
+ elif isinstance(self.action_space, gym.spaces.multi_discrete.MultiDiscrete):
135
+ is_multidiscrete = True
136
+ output_hidden_shape = self.action_space.nvec.astype(np.int32)
137
+ else:
138
+ is_multidiscrete = False
139
+ output_hidden_shape = 1
140
+
141
+ def make_time_major(*args, **kw):
142
+ return _make_time_major(
143
+ self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw
144
+ )
145
+
146
+ actions = train_batch[SampleBatch.ACTIONS]
147
+ dones = train_batch[SampleBatch.TERMINATEDS]
148
+ rewards = train_batch[SampleBatch.REWARDS]
149
+ behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
150
+
151
+ target_model_out, _ = self.target_model(train_batch)
152
+ prev_action_dist = dist_class(behaviour_logits, self.model)
153
+ values = self.model.value_function()
154
+ values_time_major = make_time_major(values)
155
+ bootstrap_values_time_major = make_time_major(
156
+ train_batch[SampleBatch.VALUES_BOOTSTRAPPED]
157
+ )
158
+ bootstrap_value = bootstrap_values_time_major[-1]
159
+
160
+ if self.is_recurrent():
161
+ max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
162
+ mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
163
+ mask = tf.reshape(mask, [-1])
164
+ mask = make_time_major(mask)
165
+
166
+ def reduce_mean_valid(t):
167
+ return tf.reduce_mean(tf.boolean_mask(t, mask))
168
+
169
+ else:
170
+ reduce_mean_valid = tf.reduce_mean
171
+
172
+ if self.config["vtrace"]:
173
+ logger.debug("Using V-Trace surrogate loss (vtrace=True)")
174
+
175
+ # Prepare actions for loss.
176
+ loss_actions = (
177
+ actions if is_multidiscrete else tf.expand_dims(actions, axis=1)
178
+ )
179
+
180
+ old_policy_behaviour_logits = tf.stop_gradient(target_model_out)
181
+ old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)
182
+
183
+ # Prepare KL for Loss
184
+ mean_kl = make_time_major(old_policy_action_dist.multi_kl(action_dist))
185
+
186
+ unpacked_behaviour_logits = tf.split(
187
+ behaviour_logits, output_hidden_shape, axis=1
188
+ )
189
+ unpacked_old_policy_behaviour_logits = tf.split(
190
+ old_policy_behaviour_logits, output_hidden_shape, axis=1
191
+ )
192
+
193
+ # Compute vtrace on the CPU for better perf.
194
+ with tf.device("/cpu:0"):
195
+ vtrace_returns = vtrace.multi_from_logits(
196
+ behaviour_policy_logits=make_time_major(
197
+ unpacked_behaviour_logits
198
+ ),
199
+ target_policy_logits=make_time_major(
200
+ unpacked_old_policy_behaviour_logits
201
+ ),
202
+ actions=tf.unstack(make_time_major(loss_actions), axis=2),
203
+ discounts=tf.cast(
204
+ ~make_time_major(tf.cast(dones, tf.bool)),
205
+ tf.float32,
206
+ )
207
+ * self.config["gamma"],
208
+ rewards=make_time_major(rewards),
209
+ values=values_time_major,
210
+ bootstrap_value=bootstrap_value,
211
+ dist_class=Categorical if is_multidiscrete else dist_class,
212
+ model=model,
213
+ clip_rho_threshold=tf.cast(
214
+ self.config["vtrace_clip_rho_threshold"], tf.float32
215
+ ),
216
+ clip_pg_rho_threshold=tf.cast(
217
+ self.config["vtrace_clip_pg_rho_threshold"], tf.float32
218
+ ),
219
+ )
220
+
221
+ actions_logp = make_time_major(action_dist.logp(actions))
222
+ prev_actions_logp = make_time_major(prev_action_dist.logp(actions))
223
+ old_policy_actions_logp = make_time_major(
224
+ old_policy_action_dist.logp(actions)
225
+ )
226
+
227
+ is_ratio = tf.clip_by_value(
228
+ tf.math.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0
229
+ )
230
+ logp_ratio = is_ratio * tf.exp(actions_logp - prev_actions_logp)
231
+ self._is_ratio = is_ratio
232
+
233
+ advantages = vtrace_returns.pg_advantages
234
+ surrogate_loss = tf.minimum(
235
+ advantages * logp_ratio,
236
+ advantages
237
+ * tf.clip_by_value(
238
+ logp_ratio,
239
+ 1 - self.config["clip_param"],
240
+ 1 + self.config["clip_param"],
241
+ ),
242
+ )
243
+
244
+ action_kl = (
245
+ tf.reduce_mean(mean_kl, axis=0) if is_multidiscrete else mean_kl
246
+ )
247
+ mean_kl_loss = reduce_mean_valid(action_kl)
248
+ mean_policy_loss = -reduce_mean_valid(surrogate_loss)
249
+
250
+ # The value function loss.
251
+ value_targets = vtrace_returns.vs
252
+ delta = values_time_major - value_targets
253
+ mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))
254
+
255
+ # The entropy loss.
256
+ actions_entropy = make_time_major(action_dist.multi_entropy())
257
+ mean_entropy = reduce_mean_valid(actions_entropy)
258
+
259
+ else:
260
+ logger.debug("Using PPO surrogate loss (vtrace=False)")
261
+
262
+ # Prepare KL for Loss
263
+ mean_kl = make_time_major(prev_action_dist.multi_kl(action_dist))
264
+
265
+ logp_ratio = tf.math.exp(
266
+ make_time_major(action_dist.logp(actions))
267
+ - make_time_major(prev_action_dist.logp(actions))
268
+ )
269
+
270
+ advantages = make_time_major(train_batch[Postprocessing.ADVANTAGES])
271
+ surrogate_loss = tf.minimum(
272
+ advantages * logp_ratio,
273
+ advantages
274
+ * tf.clip_by_value(
275
+ logp_ratio,
276
+ 1 - self.config["clip_param"],
277
+ 1 + self.config["clip_param"],
278
+ ),
279
+ )
280
+
281
+ action_kl = (
282
+ tf.reduce_mean(mean_kl, axis=0) if is_multidiscrete else mean_kl
283
+ )
284
+ mean_kl_loss = reduce_mean_valid(action_kl)
285
+ mean_policy_loss = -reduce_mean_valid(surrogate_loss)
286
+
287
+ # The value function loss.
288
+ value_targets = make_time_major(
289
+ train_batch[Postprocessing.VALUE_TARGETS]
290
+ )
291
+ delta = values_time_major - value_targets
292
+ mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))
293
+
294
+ # The entropy loss.
295
+ mean_entropy = reduce_mean_valid(
296
+ make_time_major(action_dist.multi_entropy())
297
+ )
298
+
299
+ # The summed weighted loss.
300
+ total_loss = mean_policy_loss - mean_entropy * self.entropy_coeff
301
+ # Optional KL loss.
302
+ if self.config["use_kl_loss"]:
303
+ total_loss += self.kl_coeff * mean_kl_loss
304
+ # Optional vf loss (or in a separate term due to separate
305
+ # optimizers/networks).
306
+ loss_wo_vf = total_loss
307
+ if not self.config["_separate_vf_optimizer"]:
308
+ total_loss += mean_vf_loss * self.config["vf_loss_coeff"]
309
+
310
+ # Store stats in policy for stats_fn.
311
+ self._total_loss = total_loss
312
+ self._loss_wo_vf = loss_wo_vf
313
+ self._mean_policy_loss = mean_policy_loss
314
+ # Backward compatibility: Deprecate policy._mean_kl.
315
+ self._mean_kl_loss = self._mean_kl = mean_kl_loss
316
+ self._mean_vf_loss = mean_vf_loss
317
+ self._mean_entropy = mean_entropy
318
+ self._value_targets = value_targets
319
+
320
+ # Return one total loss or two losses: vf vs rest (policy + kl).
321
+ if self.config["_separate_vf_optimizer"]:
322
+ return loss_wo_vf, mean_vf_loss
323
+ else:
324
+ return total_loss
325
+
326
+ @override(base)
327
+ def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
328
+ values_batched = _make_time_major(
329
+ self,
330
+ train_batch.get(SampleBatch.SEQ_LENS),
331
+ self.model.value_function(),
332
+ )
333
+
334
+ stats_dict = {
335
+ "cur_lr": tf.cast(self.cur_lr, tf.float64),
336
+ "total_loss": self._total_loss,
337
+ "policy_loss": self._mean_policy_loss,
338
+ "entropy": self._mean_entropy,
339
+ "var_gnorm": tf.linalg.global_norm(self.model.trainable_variables()),
340
+ "vf_loss": self._mean_vf_loss,
341
+ "vf_explained_var": explained_variance(
342
+ tf.reshape(self._value_targets, [-1]),
343
+ tf.reshape(values_batched, [-1]),
344
+ ),
345
+ "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64),
346
+ }
347
+
348
+ if self.config["vtrace"]:
349
+ is_stat_mean, is_stat_var = tf.nn.moments(self._is_ratio, [0, 1])
350
+ stats_dict["mean_IS"] = is_stat_mean
351
+ stats_dict["var_IS"] = is_stat_var
352
+
353
+ if self.config["use_kl_loss"]:
354
+ stats_dict["kl"] = self._mean_kl_loss
355
+ stats_dict["KL_Coeff"] = self.kl_coeff
356
+
357
+ return stats_dict
358
+
359
+ @override(base)
360
+ def postprocess_trajectory(
361
+ self,
362
+ sample_batch: SampleBatch,
363
+ other_agent_batches: Optional[SampleBatch] = None,
364
+ episode=None,
365
+ ):
366
+ # Call super's postprocess_trajectory first.
367
+ # sample_batch = super().postprocess_trajectory(
368
+ # sample_batch, other_agent_batches, episode
369
+ # )
370
+
371
+ if not self.config["vtrace"]:
372
+ sample_batch = compute_gae_for_sample_batch(
373
+ self, sample_batch, other_agent_batches, episode
374
+ )
375
+ else:
376
+ # Add the Columns.VALUES_BOOTSTRAPPED column, which we'll need
377
+ # inside the loss for vtrace calculations.
378
+ sample_batch = compute_bootstrap_value(sample_batch, self)
379
+
380
+ return sample_batch
381
+
382
+ @override(base)
383
+ def get_batch_divisibility_req(self) -> int:
384
+ return self.config["rollout_fragment_length"]
385
+
386
+ APPOTFPolicy.__name__ = name
387
+ APPOTFPolicy.__qualname__ = name
388
+
389
+ return APPOTFPolicy
390
+
391
+
392
+ APPOTF1Policy = get_appo_tf_policy("APPOTF1Policy", DynamicTFPolicyV2)
393
+ APPOTF2Policy = get_appo_tf_policy("APPOTF2Policy", EagerTFPolicyV2)
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo_torch_policy.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch policy class used for APPO.
3
+
4
+ Adapted from VTraceTFPolicy to use the PPO surrogate loss.
5
+ Keep in sync with changes to VTraceTFPolicy.
6
+ """
7
+
8
+ import gymnasium as gym
9
+ import numpy as np
10
+ import logging
11
+ from typing import Any, Dict, List, Optional, Type, Union
12
+
13
+ import ray
14
+ from ray.rllib.algorithms.appo.utils import make_appo_models
15
+ import ray.rllib.algorithms.impala.vtrace_torch as vtrace
16
+ from ray.rllib.algorithms.impala.impala_torch_policy import (
17
+ make_time_major,
18
+ VTraceOptimizer,
19
+ )
20
+ from ray.rllib.evaluation.postprocessing import (
21
+ compute_bootstrap_value,
22
+ compute_gae_for_sample_batch,
23
+ Postprocessing,
24
+ )
25
+ from ray.rllib.models.action_dist import ActionDistribution
26
+ from ray.rllib.models.modelv2 import ModelV2
27
+ from ray.rllib.models.torch.torch_action_dist import (
28
+ TorchDistributionWrapper,
29
+ TorchCategorical,
30
+ )
31
+ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
32
+ from ray.rllib.policy.sample_batch import SampleBatch
33
+ from ray.rllib.policy.torch_mixins import (
34
+ EntropyCoeffSchedule,
35
+ LearningRateSchedule,
36
+ KLCoeffMixin,
37
+ ValueNetworkMixin,
38
+ TargetNetworkMixin,
39
+ )
40
+ from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
41
+ from ray.rllib.utils.annotations import override
42
+ from ray.rllib.utils.framework import try_import_torch
43
+ from ray.rllib.utils.numpy import convert_to_numpy
44
+ from ray.rllib.utils.torch_utils import (
45
+ apply_grad_clipping,
46
+ explained_variance,
47
+ global_norm,
48
+ sequence_mask,
49
+ )
50
+ from ray.rllib.utils.typing import TensorType
51
+
52
+ torch, nn = try_import_torch()
53
+
54
+ logger = logging.getLogger(__name__)
55
+
56
+
57
+ # TODO (sven): Deprecate once APPO and IMPALA fully on RLModules/Learner APIs.
58
+ class APPOTorchPolicy(
59
+ VTraceOptimizer,
60
+ LearningRateSchedule,
61
+ EntropyCoeffSchedule,
62
+ KLCoeffMixin,
63
+ ValueNetworkMixin,
64
+ TargetNetworkMixin,
65
+ TorchPolicyV2,
66
+ ):
67
+ """PyTorch policy class used with APPO."""
68
+
69
+ def __init__(self, observation_space, action_space, config):
70
+ config = dict(ray.rllib.algorithms.appo.appo.APPOConfig().to_dict(), **config)
71
+ config["enable_rl_module_and_learner"] = False
72
+ config["enable_env_runner_and_connector_v2"] = False
73
+
74
+ # Although this is a no-op, we call __init__ here to make it clear
75
+ # that base.__init__ will use the make_model() call.
76
+ VTraceOptimizer.__init__(self)
77
+
78
+ lr_schedule_additional_args = []
79
+ if config.get("_separate_vf_optimizer"):
80
+ lr_schedule_additional_args = (
81
+ [config["_lr_vf"][0][1], config["_lr_vf"]]
82
+ if isinstance(config["_lr_vf"], (list, tuple))
83
+ else [config["_lr_vf"], None]
84
+ )
85
+ LearningRateSchedule.__init__(
86
+ self, config["lr"], config["lr_schedule"], *lr_schedule_additional_args
87
+ )
88
+
89
+ TorchPolicyV2.__init__(
90
+ self,
91
+ observation_space,
92
+ action_space,
93
+ config,
94
+ max_seq_len=config["model"]["max_seq_len"],
95
+ )
96
+
97
+ EntropyCoeffSchedule.__init__(
98
+ self, config["entropy_coeff"], config["entropy_coeff_schedule"]
99
+ )
100
+ ValueNetworkMixin.__init__(self, config)
101
+ KLCoeffMixin.__init__(self, config)
102
+
103
+ self._initialize_loss_from_dummy_batch()
104
+
105
+ # Initiate TargetNetwork ops after loss initialization.
106
+ TargetNetworkMixin.__init__(self)
107
+
108
+ @override(TorchPolicyV2)
109
+ def init_view_requirements(self):
110
+ self.view_requirements = self._get_default_view_requirements()
111
+
112
+ @override(TorchPolicyV2)
113
+ def make_model(self) -> ModelV2:
114
+ return make_appo_models(self)
115
+
116
+ @override(TorchPolicyV2)
117
+ def loss(
118
+ self,
119
+ model: ModelV2,
120
+ dist_class: Type[ActionDistribution],
121
+ train_batch: SampleBatch,
122
+ ) -> Union[TensorType, List[TensorType]]:
123
+ """Constructs the loss for APPO.
124
+
125
+ With IS modifications and V-trace for Advantage Estimation.
126
+
127
+ Args:
128
+ model (ModelV2): The Model to calculate the loss for.
129
+ dist_class (Type[ActionDistribution]): The action distr. class.
130
+ train_batch: The training data.
131
+
132
+ Returns:
133
+ Union[TensorType, List[TensorType]]: A single loss tensor or a list
134
+ of loss tensors.
135
+ """
136
+ target_model = self.target_models[model]
137
+
138
+ model_out, _ = model(train_batch)
139
+ action_dist = dist_class(model_out, model)
140
+
141
+ if isinstance(self.action_space, gym.spaces.Discrete):
142
+ is_multidiscrete = False
143
+ output_hidden_shape = [self.action_space.n]
144
+ elif isinstance(self.action_space, gym.spaces.multi_discrete.MultiDiscrete):
145
+ is_multidiscrete = True
146
+ output_hidden_shape = self.action_space.nvec.astype(np.int32)
147
+ else:
148
+ is_multidiscrete = False
149
+ output_hidden_shape = 1
150
+
151
+ def _make_time_major(*args, **kwargs):
152
+ return make_time_major(
153
+ self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kwargs
154
+ )
155
+
156
+ actions = train_batch[SampleBatch.ACTIONS]
157
+ dones = train_batch[SampleBatch.TERMINATEDS]
158
+ rewards = train_batch[SampleBatch.REWARDS]
159
+ behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
160
+
161
+ target_model_out, _ = target_model(train_batch)
162
+
163
+ prev_action_dist = dist_class(behaviour_logits, model)
164
+ values = model.value_function()
165
+ values_time_major = _make_time_major(values)
166
+ bootstrap_values_time_major = _make_time_major(
167
+ train_batch[SampleBatch.VALUES_BOOTSTRAPPED]
168
+ )
169
+ bootstrap_value = bootstrap_values_time_major[-1]
170
+
171
+ if self.is_recurrent():
172
+ max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
173
+ mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
174
+ mask = torch.reshape(mask, [-1])
175
+ mask = _make_time_major(mask)
176
+ num_valid = torch.sum(mask)
177
+
178
+ def reduce_mean_valid(t):
179
+ return torch.sum(t[mask]) / num_valid
180
+
181
+ else:
182
+ reduce_mean_valid = torch.mean
183
+
184
+ if self.config["vtrace"]:
185
+ logger.debug("Using V-Trace surrogate loss (vtrace=True)")
186
+
187
+ old_policy_behaviour_logits = target_model_out.detach()
188
+ old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)
189
+
190
+ if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
191
+ unpacked_behaviour_logits = torch.split(
192
+ behaviour_logits, list(output_hidden_shape), dim=1
193
+ )
194
+ unpacked_old_policy_behaviour_logits = torch.split(
195
+ old_policy_behaviour_logits, list(output_hidden_shape), dim=1
196
+ )
197
+ else:
198
+ unpacked_behaviour_logits = torch.chunk(
199
+ behaviour_logits, output_hidden_shape, dim=1
200
+ )
201
+ unpacked_old_policy_behaviour_logits = torch.chunk(
202
+ old_policy_behaviour_logits, output_hidden_shape, dim=1
203
+ )
204
+
205
+ # Prepare actions for loss.
206
+ loss_actions = (
207
+ actions if is_multidiscrete else torch.unsqueeze(actions, dim=1)
208
+ )
209
+
210
+ # Prepare KL for loss.
211
+ action_kl = _make_time_major(old_policy_action_dist.kl(action_dist))
212
+
213
+ # Compute vtrace on the CPU for better perf.
214
+ vtrace_returns = vtrace.multi_from_logits(
215
+ behaviour_policy_logits=_make_time_major(unpacked_behaviour_logits),
216
+ target_policy_logits=_make_time_major(
217
+ unpacked_old_policy_behaviour_logits
218
+ ),
219
+ actions=torch.unbind(_make_time_major(loss_actions), dim=2),
220
+ discounts=(1.0 - _make_time_major(dones).float())
221
+ * self.config["gamma"],
222
+ rewards=_make_time_major(rewards),
223
+ values=values_time_major,
224
+ bootstrap_value=bootstrap_value,
225
+ dist_class=TorchCategorical if is_multidiscrete else dist_class,
226
+ model=model,
227
+ clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
228
+ clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"],
229
+ )
230
+
231
+ actions_logp = _make_time_major(action_dist.logp(actions))
232
+ prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
233
+ old_policy_actions_logp = _make_time_major(
234
+ old_policy_action_dist.logp(actions)
235
+ )
236
+ is_ratio = torch.clamp(
237
+ torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0
238
+ )
239
+ logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp)
240
+ self._is_ratio = is_ratio
241
+
242
+ advantages = vtrace_returns.pg_advantages.to(logp_ratio.device)
243
+ surrogate_loss = torch.min(
244
+ advantages * logp_ratio,
245
+ advantages
246
+ * torch.clamp(
247
+ logp_ratio,
248
+ 1 - self.config["clip_param"],
249
+ 1 + self.config["clip_param"],
250
+ ),
251
+ )
252
+
253
+ mean_kl_loss = reduce_mean_valid(action_kl)
254
+ mean_policy_loss = -reduce_mean_valid(surrogate_loss)
255
+
256
+ # The value function loss.
257
+ value_targets = vtrace_returns.vs.to(values_time_major.device)
258
+ delta = values_time_major - value_targets
259
+ mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))
260
+
261
+ # The entropy loss.
262
+ mean_entropy = reduce_mean_valid(_make_time_major(action_dist.entropy()))
263
+
264
+ else:
265
+ logger.debug("Using PPO surrogate loss (vtrace=False)")
266
+
267
+ # Prepare KL for Loss
268
+ action_kl = _make_time_major(prev_action_dist.kl(action_dist))
269
+
270
+ actions_logp = _make_time_major(action_dist.logp(actions))
271
+ prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
272
+ logp_ratio = torch.exp(actions_logp - prev_actions_logp)
273
+
274
+ advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES])
275
+ surrogate_loss = torch.min(
276
+ advantages * logp_ratio,
277
+ advantages
278
+ * torch.clamp(
279
+ logp_ratio,
280
+ 1 - self.config["clip_param"],
281
+ 1 + self.config["clip_param"],
282
+ ),
283
+ )
284
+
285
+ mean_kl_loss = reduce_mean_valid(action_kl)
286
+ mean_policy_loss = -reduce_mean_valid(surrogate_loss)
287
+
288
+ # The value function loss.
289
+ value_targets = _make_time_major(train_batch[Postprocessing.VALUE_TARGETS])
290
+ delta = values_time_major - value_targets
291
+ mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))
292
+
293
+ # The entropy loss.
294
+ mean_entropy = reduce_mean_valid(_make_time_major(action_dist.entropy()))
295
+
296
+ # The summed weighted loss.
297
+ total_loss = mean_policy_loss - mean_entropy * self.entropy_coeff
298
+ # Optional additional KL Loss
299
+ if self.config["use_kl_loss"]:
300
+ total_loss += self.kl_coeff * mean_kl_loss
301
+
302
+ # Optional vf loss (or in a separate term due to separate
303
+ # optimizers/networks).
304
+ loss_wo_vf = total_loss
305
+ if not self.config["_separate_vf_optimizer"]:
306
+ total_loss += mean_vf_loss * self.config["vf_loss_coeff"]
307
+
308
+ # Store values for stats function in model (tower), such that for
309
+ # multi-GPU, we do not override them during the parallel loss phase.
310
+ model.tower_stats["total_loss"] = total_loss
311
+ model.tower_stats["mean_policy_loss"] = mean_policy_loss
312
+ model.tower_stats["mean_kl_loss"] = mean_kl_loss
313
+ model.tower_stats["mean_vf_loss"] = mean_vf_loss
314
+ model.tower_stats["mean_entropy"] = mean_entropy
315
+ model.tower_stats["value_targets"] = value_targets
316
+ model.tower_stats["vf_explained_var"] = explained_variance(
317
+ torch.reshape(value_targets, [-1]),
318
+ torch.reshape(values_time_major, [-1]),
319
+ )
320
+
321
+ # Return one total loss or two losses: vf vs rest (policy + kl).
322
+ if self.config["_separate_vf_optimizer"]:
323
+ return loss_wo_vf, mean_vf_loss
324
+ else:
325
+ return total_loss
326
+
327
+ @override(TorchPolicyV2)
328
+ def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
329
+ """Stats function for APPO. Returns a dict with important loss stats.
330
+
331
+ Args:
332
+ policy: The Policy to generate stats for.
333
+ train_batch: The SampleBatch (already) used for training.
334
+
335
+ Returns:
336
+ Dict[str, TensorType]: The stats dict.
337
+ """
338
+ stats_dict = {
339
+ "cur_lr": self.cur_lr,
340
+ "total_loss": torch.mean(torch.stack(self.get_tower_stats("total_loss"))),
341
+ "policy_loss": torch.mean(
342
+ torch.stack(self.get_tower_stats("mean_policy_loss"))
343
+ ),
344
+ "entropy": torch.mean(torch.stack(self.get_tower_stats("mean_entropy"))),
345
+ "entropy_coeff": self.entropy_coeff,
346
+ "var_gnorm": global_norm(self.model.trainable_variables()),
347
+ "vf_loss": torch.mean(torch.stack(self.get_tower_stats("mean_vf_loss"))),
348
+ "vf_explained_var": torch.mean(
349
+ torch.stack(self.get_tower_stats("vf_explained_var"))
350
+ ),
351
+ }
352
+
353
+ if self.config["vtrace"]:
354
+ is_stat_mean = torch.mean(self._is_ratio, [0, 1])
355
+ is_stat_var = torch.var(self._is_ratio, [0, 1])
356
+ stats_dict["mean_IS"] = is_stat_mean
357
+ stats_dict["var_IS"] = is_stat_var
358
+
359
+ if self.config["use_kl_loss"]:
360
+ stats_dict["kl"] = torch.mean(
361
+ torch.stack(self.get_tower_stats("mean_kl_loss"))
362
+ )
363
+ stats_dict["KL_Coeff"] = self.kl_coeff
364
+
365
+ return convert_to_numpy(stats_dict)
366
+
367
+ @override(TorchPolicyV2)
368
+ def extra_action_out(
369
+ self,
370
+ input_dict: Dict[str, TensorType],
371
+ state_batches: List[TensorType],
372
+ model: TorchModelV2,
373
+ action_dist: TorchDistributionWrapper,
374
+ ) -> Dict[str, TensorType]:
375
+ return {SampleBatch.VF_PREDS: model.value_function()}
376
+
377
+ @override(TorchPolicyV2)
378
+ def postprocess_trajectory(
379
+ self,
380
+ sample_batch: SampleBatch,
381
+ other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
382
+ episode=None,
383
+ ):
384
+ # Call super's postprocess_trajectory first.
385
+ # sample_batch = super().postprocess_trajectory(
386
+ # sample_batch, other_agent_batches, episode
387
+ # )
388
+
389
+ # Do all post-processing always with no_grad().
390
+ # Not using this here will introduce a memory leak
391
+ # in torch (issue #6962).
392
+ with torch.no_grad():
393
+ if not self.config["vtrace"]:
394
+ sample_batch = compute_gae_for_sample_batch(
395
+ self, sample_batch, other_agent_batches, episode
396
+ )
397
+ else:
398
+ # Add the SampleBatch.VALUES_BOOTSTRAPPED column, which we'll need
399
+ # inside the loss for vtrace calculations.
400
+ sample_batch = compute_bootstrap_value(sample_batch, self)
401
+
402
+ return sample_batch
403
+
404
+ @override(TorchPolicyV2)
405
+ def extra_grad_process(
406
+ self, optimizer: "torch.optim.Optimizer", loss: TensorType
407
+ ) -> Dict[str, TensorType]:
408
+ return apply_grad_clipping(self, optimizer, loss)
409
+
410
+ @override(TorchPolicyV2)
411
+ def get_batch_divisibility_req(self) -> int:
412
+ return self.config["rollout_fragment_length"]
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/default_appo_rl_module.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import Any, Dict, List, Tuple
3
+
4
+ from ray.rllib.algorithms.ppo.default_ppo_rl_module import DefaultPPORLModule
5
+ from ray.rllib.core.learner.utils import make_target_network
6
+ from ray.rllib.core.models.base import ACTOR
7
+ from ray.rllib.core.models.tf.encoder import ENCODER_OUT
8
+ from ray.rllib.core.rl_module.apis import (
9
+ TARGET_NETWORK_ACTION_DIST_INPUTS,
10
+ TargetNetworkAPI,
11
+ )
12
+ from ray.rllib.utils.typing import NetworkType
13
+
14
+ from ray.rllib.utils.annotations import (
15
+ override,
16
+ OverrideToImplementCustomLogic_CallToSuperRecommended,
17
+ )
18
+ from ray.util.annotations import DeveloperAPI
19
+
20
+
21
+ @DeveloperAPI
22
+ class DefaultAPPORLModule(DefaultPPORLModule, TargetNetworkAPI, abc.ABC):
23
+ """Default RLModule used by APPO, if user does not specify a custom RLModule.
24
+
25
+ Users who want to train their RLModules with APPO may implement any RLModule
26
+ (or TorchRLModule) subclass as long as the custom class also implements the
27
+ `ValueFunctionAPI` (see ray.rllib.core.rl_module.apis.value_function_api.py)
28
+ and the `TargetNetworkAPI` (see
29
+ ray.rllib.core.rl_module.apis.target_network_api.py).
30
+ """
31
+
32
+ @override(TargetNetworkAPI)
33
+ def make_target_networks(self):
34
+ self._old_encoder = make_target_network(self.encoder)
35
+ self._old_pi = make_target_network(self.pi)
36
+
37
+ @override(TargetNetworkAPI)
38
+ def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
39
+ return [
40
+ (self.encoder, self._old_encoder),
41
+ (self.pi, self._old_pi),
42
+ ]
43
+
44
+ @override(TargetNetworkAPI)
45
+ def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
46
+ old_pi_inputs_encoded = self._old_encoder(batch)[ENCODER_OUT][ACTOR]
47
+ old_action_dist_logits = self._old_pi(old_pi_inputs_encoded)
48
+ return {TARGET_NETWORK_ACTION_DIST_INPUTS: old_action_dist_logits}
49
+
50
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
51
+ @override(DefaultPPORLModule)
52
+ def get_non_inference_attributes(self) -> List[str]:
53
+ # Get the NON inference-only attributes from the parent class
54
+ # `PPOTorchRLModule`.
55
+ ret = super().get_non_inference_attributes()
56
+ # Add the two (APPO) target networks to it (NOT needed in
57
+ # inference-only mode).
58
+ ret += ["_old_encoder", "_old_pi"]
59
+ return ret
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (204 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__pycache__/appo_torch_learner.cpython-311.pyc ADDED
Binary file (9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__pycache__/appo_torch_rl_module.cpython-311.pyc ADDED
Binary file (709 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__pycache__/default_appo_torch_rl_module.cpython-311.pyc ADDED
Binary file (818 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/appo_torch_learner.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Asynchronous Proximal Policy Optimization (APPO)
2
+
3
+ The algorithm is described in [1] (under the name of "IMPACT"):
4
+
5
+ Detailed documentation:
6
+ https://docs.ray.io/en/master/rllib-algorithms.html#appo
7
+
8
+ [1] IMPACT: Importance Weighted Asynchronous Architectures with Clipped Target Networks.
9
+ Luo et al. 2020
10
+ https://arxiv.org/pdf/1912.00167
11
+ """
12
+ from typing import Dict
13
+
14
+ from ray.rllib.algorithms.appo.appo import (
15
+ APPOConfig,
16
+ LEARNER_RESULTS_CURR_KL_COEFF_KEY,
17
+ LEARNER_RESULTS_KL_KEY,
18
+ )
19
+ from ray.rllib.algorithms.appo.appo_learner import APPOLearner
20
+ from ray.rllib.algorithms.impala.torch.impala_torch_learner import IMPALATorchLearner
21
+ from ray.rllib.algorithms.impala.torch.vtrace_torch_v2 import (
22
+ make_time_major,
23
+ vtrace_torch,
24
+ )
25
+ from ray.rllib.core.columns import Columns
26
+ from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY
27
+ from ray.rllib.core.rl_module.apis import (
28
+ TARGET_NETWORK_ACTION_DIST_INPUTS,
29
+ TargetNetworkAPI,
30
+ ValueFunctionAPI,
31
+ )
32
+ from ray.rllib.utils.annotations import override
33
+ from ray.rllib.utils.framework import try_import_torch
34
+ from ray.rllib.utils.numpy import convert_to_numpy
35
+ from ray.rllib.utils.typing import ModuleID, TensorType
36
+
37
+ torch, nn = try_import_torch()
38
+
39
+
40
+ class APPOTorchLearner(APPOLearner, IMPALATorchLearner):
41
+ """Implements APPO loss / update logic on top of IMPALATorchLearner."""
42
+
43
+ @override(IMPALATorchLearner)
44
+ def compute_loss_for_module(
45
+ self,
46
+ *,
47
+ module_id: ModuleID,
48
+ config: APPOConfig,
49
+ batch: Dict,
50
+ fwd_out: Dict[str, TensorType],
51
+ ) -> TensorType:
52
+ module = self.module[module_id].unwrapped()
53
+ assert isinstance(module, TargetNetworkAPI)
54
+ assert isinstance(module, ValueFunctionAPI)
55
+
56
+ # TODO (sven): Now that we do the +1ts trick to be less vulnerable about
57
+ # bootstrap values at the end of rollouts in the new stack, we might make
58
+ # this a more flexible, configurable parameter for users, e.g.
59
+ # `v_trace_seq_len` (independent of `rollout_fragment_length`). Separation
60
+ # of concerns (sampling vs learning).
61
+ rollout_frag_or_episode_len = config.get_rollout_fragment_length()
62
+ recurrent_seq_len = batch.get("seq_lens")
63
+
64
+ loss_mask = batch[Columns.LOSS_MASK].float()
65
+ loss_mask_time_major = make_time_major(
66
+ loss_mask,
67
+ trajectory_len=rollout_frag_or_episode_len,
68
+ recurrent_seq_len=recurrent_seq_len,
69
+ )
70
+ size_loss_mask = torch.sum(loss_mask)
71
+
72
+ values = module.compute_values(
73
+ batch, embeddings=fwd_out.get(Columns.EMBEDDINGS)
74
+ )
75
+
76
+ action_dist_cls_train = module.get_train_action_dist_cls()
77
+ target_policy_dist = action_dist_cls_train.from_logits(
78
+ fwd_out[Columns.ACTION_DIST_INPUTS]
79
+ )
80
+
81
+ old_target_policy_dist = action_dist_cls_train.from_logits(
82
+ module.forward_target(batch)[TARGET_NETWORK_ACTION_DIST_INPUTS]
83
+ )
84
+ old_target_policy_actions_logp = old_target_policy_dist.logp(
85
+ batch[Columns.ACTIONS]
86
+ )
87
+ behaviour_actions_logp = batch[Columns.ACTION_LOGP]
88
+ target_actions_logp = target_policy_dist.logp(batch[Columns.ACTIONS])
89
+
90
+ behaviour_actions_logp_time_major = make_time_major(
91
+ behaviour_actions_logp,
92
+ trajectory_len=rollout_frag_or_episode_len,
93
+ recurrent_seq_len=recurrent_seq_len,
94
+ )
95
+ target_actions_logp_time_major = make_time_major(
96
+ target_actions_logp,
97
+ trajectory_len=rollout_frag_or_episode_len,
98
+ recurrent_seq_len=recurrent_seq_len,
99
+ )
100
+ old_actions_logp_time_major = make_time_major(
101
+ old_target_policy_actions_logp,
102
+ trajectory_len=rollout_frag_or_episode_len,
103
+ recurrent_seq_len=recurrent_seq_len,
104
+ )
105
+ rewards_time_major = make_time_major(
106
+ batch[Columns.REWARDS],
107
+ trajectory_len=rollout_frag_or_episode_len,
108
+ recurrent_seq_len=recurrent_seq_len,
109
+ )
110
+ values_time_major = make_time_major(
111
+ values,
112
+ trajectory_len=rollout_frag_or_episode_len,
113
+ recurrent_seq_len=recurrent_seq_len,
114
+ )
115
+ assert Columns.VALUES_BOOTSTRAPPED not in batch
116
+ # Use as bootstrap values the vf-preds in the next "batch row", except
117
+ # for the very last row (which doesn't have a next row), for which the
118
+ # bootstrap value does not matter b/c it has a +1ts value at its end
119
+ # anyways. So we chose an arbitrary item (for simplicity of not having to
120
+ # move new data to the device).
121
+ bootstrap_values = torch.cat(
122
+ [
123
+ values_time_major[0][1:], # 0th ts values from "next row"
124
+ values_time_major[0][0:1], # <- can use any arbitrary value here
125
+ ],
126
+ dim=0,
127
+ )
128
+
129
+ # The discount factor that is used should be gamma except for timesteps where
130
+ # the episode is terminated. In that case, the discount factor should be 0.
131
+ discounts_time_major = (
132
+ 1.0
133
+ - make_time_major(
134
+ batch[Columns.TERMINATEDS],
135
+ trajectory_len=rollout_frag_or_episode_len,
136
+ recurrent_seq_len=recurrent_seq_len,
137
+ ).float()
138
+ ) * config.gamma
139
+
140
+ # Note that vtrace will compute the main loop on the CPU for better performance.
141
+ vtrace_adjusted_target_values, pg_advantages = vtrace_torch(
142
+ target_action_log_probs=old_actions_logp_time_major,
143
+ behaviour_action_log_probs=behaviour_actions_logp_time_major,
144
+ discounts=discounts_time_major,
145
+ rewards=rewards_time_major,
146
+ values=values_time_major,
147
+ bootstrap_values=bootstrap_values,
148
+ clip_pg_rho_threshold=config.vtrace_clip_pg_rho_threshold,
149
+ clip_rho_threshold=config.vtrace_clip_rho_threshold,
150
+ )
151
+ pg_advantages = pg_advantages * loss_mask_time_major
152
+
153
+ # The policy gradients loss.
154
+ is_ratio = torch.clip(
155
+ torch.exp(behaviour_actions_logp_time_major - old_actions_logp_time_major),
156
+ 0.0,
157
+ 2.0,
158
+ )
159
+ logp_ratio = is_ratio * torch.exp(
160
+ target_actions_logp_time_major - behaviour_actions_logp_time_major
161
+ )
162
+
163
+ surrogate_loss = torch.minimum(
164
+ pg_advantages * logp_ratio,
165
+ pg_advantages
166
+ * torch.clip(logp_ratio, 1 - config.clip_param, 1 + config.clip_param),
167
+ )
168
+
169
+ if config.use_kl_loss:
170
+ action_kl = old_target_policy_dist.kl(target_policy_dist) * loss_mask
171
+ mean_kl_loss = torch.sum(action_kl) / size_loss_mask
172
+ else:
173
+ mean_kl_loss = 0.0
174
+ mean_pi_loss = -(torch.sum(surrogate_loss) / size_loss_mask)
175
+
176
+ # The baseline loss.
177
+ delta = values_time_major - vtrace_adjusted_target_values
178
+ vf_loss = 0.5 * torch.sum(torch.pow(delta, 2.0) * loss_mask_time_major)
179
+ mean_vf_loss = vf_loss / size_loss_mask
180
+
181
+ # The entropy loss.
182
+ mean_entropy_loss = (
183
+ -torch.sum(target_policy_dist.entropy() * loss_mask) / size_loss_mask
184
+ )
185
+
186
+ # The summed weighted loss.
187
+ total_loss = (
188
+ mean_pi_loss
189
+ + (mean_vf_loss * config.vf_loss_coeff)
190
+ + (
191
+ mean_entropy_loss
192
+ * self.entropy_coeff_schedulers_per_module[
193
+ module_id
194
+ ].get_current_value()
195
+ )
196
+ + (mean_kl_loss * self.curr_kl_coeffs_per_module[module_id])
197
+ )
198
+
199
+ # Log important loss stats.
200
+ self.metrics.log_dict(
201
+ {
202
+ POLICY_LOSS_KEY: mean_pi_loss,
203
+ VF_LOSS_KEY: mean_vf_loss,
204
+ ENTROPY_KEY: -mean_entropy_loss,
205
+ LEARNER_RESULTS_KL_KEY: mean_kl_loss,
206
+ LEARNER_RESULTS_CURR_KL_COEFF_KEY: (
207
+ self.curr_kl_coeffs_per_module[module_id]
208
+ ),
209
+ },
210
+ key=module_id,
211
+ window=1, # <- single items (should not be mean/ema-reduced over time).
212
+ )
213
+ # Return the total loss.
214
+ return total_loss
215
+
216
+ @override(APPOLearner)
217
+ def _update_module_kl_coeff(self, module_id: ModuleID, config: APPOConfig) -> None:
218
+ # Update the current KL value based on the recently measured value.
219
+ # Increase.
220
+ kl = convert_to_numpy(self.metrics.peek((module_id, LEARNER_RESULTS_KL_KEY)))
221
+ kl_coeff_var = self.curr_kl_coeffs_per_module[module_id]
222
+
223
+ if kl > 2.0 * config.kl_target:
224
+ # TODO (Kourosh) why not *2.0?
225
+ kl_coeff_var.data *= 1.5
226
+ # Decrease.
227
+ elif kl < 0.5 * config.kl_target:
228
+ kl_coeff_var.data *= 0.5
229
+
230
+ self.metrics.log_value(
231
+ (module_id, LEARNER_RESULTS_CURR_KL_COEFF_KEY),
232
+ kl_coeff_var.item(),
233
+ window=1,
234
+ )
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/appo_torch_rl_module.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Backward compat import.
2
+ from ray.rllib.algorithms.appo.torch.default_appo_torch_rl_module import ( # noqa
3
+ DefaultAPPOTorchRLModule as APPOTorchRLModule,
4
+ )
5
+ from ray.rllib.utils.deprecation import deprecation_warning
6
+
7
+
8
+ deprecation_warning(
9
+ old="ray.rllib.algorithms.appo.torch.appo_torch_rl_module.APPOTorchRLModule",
10
+ new="ray.rllib.algorithms.appo.torch.default_appo_torch_rl_module."
11
+ "DefaultAPPOTorchRLModule",
12
+ error=False,
13
+ )
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/default_appo_torch_rl_module.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.algorithms.appo.default_appo_rl_module import DefaultAPPORLModule
2
+ from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
3
+ DefaultPPOTorchRLModule,
4
+ )
5
+ from ray.util.annotations import DeveloperAPI
6
+
7
+
8
+ @DeveloperAPI
9
+ class DefaultAPPOTorchRLModule(DefaultPPOTorchRLModule, DefaultAPPORLModule):
10
+ pass
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/utils.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ [1] IMPACT: Importance Weighted Asynchronous Architectures with Clipped Target Networks.
3
+ Luo et al. 2020
4
+ https://arxiv.org/pdf/1912.00167
5
+ """
6
+ from collections import deque
7
+ import random
8
+ import threading
9
+ import time
10
+
11
+ from ray.rllib.models.catalog import ModelCatalog
12
+ from ray.rllib.models.modelv2 import ModelV2
13
+ from ray.rllib.utils.annotations import OldAPIStack
14
+
15
+
16
+ POLICY_SCOPE = "func"
17
+ TARGET_POLICY_SCOPE = "target_func"
18
+
19
+
20
+ class CircularBuffer:
21
+ """A circular batch-wise buffer as described in [1] for APPO.
22
+
23
+ The buffer holds at most N batches, which are sampled at random (uniformly).
24
+ If full and a new batch is added, the oldest batch is discarded. Also, each batch
25
+ currently in the buffer can be sampled at most K times (after which it is also
26
+ discarded).
27
+ """
28
+
29
+ def __init__(self, num_batches: int, iterations_per_batch: int):
30
+ # N from the paper (buffer size).
31
+ self.num_batches = num_batches
32
+ # K ("replay coefficient") from the paper.
33
+ self.iterations_per_batch = iterations_per_batch
34
+
35
+ self._buffer = deque(maxlen=self.num_batches)
36
+ self._lock = threading.Lock()
37
+
38
+ # The number of valid (not expired) entries in this buffer.
39
+ self._num_valid_batches = 0
40
+
41
+ def add(self, batch):
42
+ dropped_entry = None
43
+ dropped_ts = 0
44
+
45
+ # Add buffer and k=0 information to the deque.
46
+ with self._lock:
47
+ len_ = len(self._buffer)
48
+ if len_ == self.num_batches:
49
+ dropped_entry = self._buffer[0]
50
+ self._buffer.append([batch, 0])
51
+ self._num_valid_batches += 1
52
+
53
+ # A valid entry (w/ a batch whose k has not been reach K yet) was dropped.
54
+ if dropped_entry is not None and dropped_entry[0] is not None:
55
+ dropped_ts += dropped_entry[0].env_steps() * (
56
+ self.iterations_per_batch - dropped_entry[1]
57
+ )
58
+ self._num_valid_batches -= 1
59
+
60
+ return dropped_ts
61
+
62
+ def sample(self):
63
+ k = entry = batch = None
64
+
65
+ while True:
66
+ # Only initially, the buffer may be empty -> Just wait for some time.
67
+ if len(self) == 0:
68
+ time.sleep(0.001)
69
+ continue
70
+ # Sample a random buffer index.
71
+ with self._lock:
72
+ entry = self._buffer[random.randint(0, len(self._buffer) - 1)]
73
+ batch, k = entry
74
+ # Ignore batches that have already been invalidated.
75
+ if batch is not None:
76
+ break
77
+
78
+ # Increase k += 1 for this batch.
79
+ assert k is not None
80
+ entry[1] += 1
81
+
82
+ # This batch has been exhausted (k == K) -> Invalidate it in the buffer.
83
+ if k == self.iterations_per_batch - 1:
84
+ entry[0] = None
85
+ entry[1] = None
86
+ self._num_valid_batches += 1
87
+
88
+ # Return the sampled batch.
89
+ return batch
90
+
91
+ def __len__(self) -> int:
92
+ """Returns the number of actually valid (non-expired) batches in the buffer."""
93
+ return self._num_valid_batches
94
+
95
+
96
+ @OldAPIStack
97
+ def make_appo_models(policy) -> ModelV2:
98
+ """Builds model and target model for APPO.
99
+
100
+ Returns:
101
+ ModelV2: The Model for the Policy to use.
102
+ Note: The target model will not be returned, just assigned to
103
+ `policy.target_model`.
104
+ """
105
+ # Get the num_outputs for the following model construction calls.
106
+ _, logit_dim = ModelCatalog.get_action_dist(
107
+ policy.action_space, policy.config["model"]
108
+ )
109
+
110
+ # Construct the (main) model.
111
+ policy.model = ModelCatalog.get_model_v2(
112
+ policy.observation_space,
113
+ policy.action_space,
114
+ logit_dim,
115
+ policy.config["model"],
116
+ name=POLICY_SCOPE,
117
+ framework=policy.framework,
118
+ )
119
+ policy.model_variables = policy.model.variables()
120
+
121
+ # Construct the target model.
122
+ policy.target_model = ModelCatalog.get_model_v2(
123
+ policy.observation_space,
124
+ policy.action_space,
125
+ logit_dim,
126
+ policy.config["model"],
127
+ name=TARGET_POLICY_SCOPE,
128
+ framework=policy.framework,
129
+ )
130
+ policy.target_model_variables = policy.target_model.variables()
131
+
132
+ # Return only the model (not the target model).
133
+ return policy.model
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (672 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/dreamerv3.cpython-311.pyc ADDED
Binary file (32.2 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/dreamerv3_catalog.cpython-311.pyc ADDED
Binary file (3.65 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/dreamerv3_learner.cpython-311.pyc ADDED
Binary file (1.93 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/dreamerv3_rl_module.cpython-311.pyc ADDED
Binary file (8.11 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (213 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/actor_network.cpython-311.pyc ADDED
Binary file (8.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/critic_network.cpython-311.pyc ADDED
Binary file (8.85 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/disagree_networks.cpython-311.pyc ADDED
Binary file (4.82 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/dreamer_model.cpython-311.pyc ADDED
Binary file (25 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/world_model.cpython-311.pyc ADDED
Binary file (20.5 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (224 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/cnn_atari.cpython-311.pyc ADDED
Binary file (4.84 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/continue_predictor.cpython-311.pyc ADDED
Binary file (4.91 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/conv_transpose_atari.cpython-311.pyc ADDED
Binary file (7.44 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/dynamics_predictor.cpython-311.pyc ADDED
Binary file (4.23 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/mlp.cpython-311.pyc ADDED
Binary file (4.66 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/representation_layer.cpython-311.pyc ADDED
Binary file (5.96 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/reward_predictor.cpython-311.pyc ADDED
Binary file (5.52 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/reward_predictor_layer.cpython-311.pyc ADDED
Binary file (4.96 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/sequence_model.cpython-311.pyc ADDED
Binary file (6.51 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/vector_decoder.cpython-311.pyc ADDED
Binary file (4.59 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/continue_predictor.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ [1] Mastering Diverse Domains through World Models - 2023
3
+ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
4
+ https://arxiv.org/pdf/2301.04104v1.pdf
5
+ """
6
+ from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP
7
+ from ray.rllib.algorithms.dreamerv3.utils import (
8
+ get_gru_units,
9
+ get_num_z_classes,
10
+ get_num_z_categoricals,
11
+ )
12
+ from ray.rllib.utils.framework import try_import_tf, try_import_tfp
13
+
14
+ _, tf, _ = try_import_tf()
15
+ tfp = try_import_tfp()
16
+
17
+
18
+ class ContinuePredictor(tf.keras.Model):
19
+ """The world-model network sub-component used to predict the `continue` flags .
20
+
21
+ Predicted continue flags are used to produce "dream data" to learn the policy in.
22
+
23
+ The continue flags are predicted via a linear output used to parameterize a
24
+ Bernoulli distribution, from which simply the mode is used (no stochastic
25
+ sampling!). In other words, if the sigmoid of the output of the linear layer is
26
+ >0.5, we predict a continuation of the episode, otherwise we predict an episode
27
+ terminal.
28
+ """
29
+
30
+ def __init__(self, *, model_size: str = "XS"):
31
+ """Initializes a ContinuePredictor instance.
32
+
33
+ Args:
34
+ model_size: The "Model Size" used according to [1] Appendinx B.
35
+ Determines the exact size of the underlying MLP.
36
+ """
37
+ super().__init__(name="continue_predictor")
38
+ self.model_size = model_size
39
+ self.mlp = MLP(model_size=model_size, output_layer_size=1)
40
+
41
+ # Trace self.call.
42
+ dl_type = tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32
43
+ self.call = tf.function(
44
+ input_signature=[
45
+ tf.TensorSpec(shape=[None, get_gru_units(model_size)], dtype=dl_type),
46
+ tf.TensorSpec(
47
+ shape=[
48
+ None,
49
+ get_num_z_categoricals(model_size),
50
+ get_num_z_classes(model_size),
51
+ ],
52
+ dtype=dl_type,
53
+ ),
54
+ ]
55
+ )(self.call)
56
+
57
+ def call(self, h, z):
58
+ """Performs a forward pass through the continue predictor.
59
+
60
+ Args:
61
+ h: The deterministic hidden state of the sequence model. [B, dim(h)].
62
+ z: The stochastic discrete representations of the original
63
+ observation input. [B, num_categoricals, num_classes].
64
+ """
65
+ # Flatten last two dims of z.
66
+ assert len(z.shape) == 3
67
+ z_shape = tf.shape(z)
68
+ z = tf.reshape(z, shape=(z_shape[0], -1))
69
+ assert len(z.shape) == 2
70
+ out = tf.concat([h, z], axis=-1)
71
+ out.set_shape(
72
+ [
73
+ None,
74
+ (
75
+ get_num_z_categoricals(self.model_size)
76
+ * get_num_z_classes(self.model_size)
77
+ + get_gru_units(self.model_size)
78
+ ),
79
+ ]
80
+ )
81
+ # Send h-cat-z through MLP.
82
+ out = self.mlp(out)
83
+ # Remove the extra [B, 1] dimension at the end to get a proper Bernoulli
84
+ # distribution. Otherwise, tfp will think that the batch dims are [B, 1]
85
+ # where they should be just [B].
86
+ logits = tf.cast(tf.squeeze(out, axis=-1), tf.float32)
87
+ # Create the Bernoulli distribution object.
88
+ bernoulli = tfp.distributions.Bernoulli(logits=logits, dtype=tf.float32)
89
+
90
+ # Take the mode (greedy, deterministic "sample").
91
+ continue_ = bernoulli.mode()
92
+
93
+ # Return Bernoulli sample (whether to continue) OR (continue?, Bernoulli prob).
94
+ return continue_, bernoulli