koichi12 commited on
Commit
30f24c0
·
verified ·
1 Parent(s): adce983

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. .venv/lib/python3.11/site-packages/ray/_private/__pycache__/test_utils.cpython-311.pyc +3 -0
  3. .venv/lib/python3.11/site-packages/ray/_private/__pycache__/worker.cpython-311.pyc +3 -0
  4. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/__init__.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/callbacks.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/mock.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/registry.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/utils.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm_config.py +0 -0
  10. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__pycache__/default_bc_torch_rl_module.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/callbacks.py +8 -0
  12. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__init__.py +9 -0
  13. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/__init__.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql_tf_policy.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql_torch_policy.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql.py +388 -0
  18. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql_tf_policy.py +426 -0
  19. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql_torch_policy.py +406 -0
  20. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__init__.py +0 -0
  21. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/__init__.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/cql_torch_learner.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/default_cql_torch_rl_module.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/cql_torch_learner.py +275 -0
  25. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/default_cql_torch_rl_module.py +206 -0
  26. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__init__.py +15 -0
  27. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3.py +750 -0
  28. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3_catalog.py +80 -0
  29. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3_learner.py +31 -0
  30. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py +153 -0
  31. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/__init__.py +0 -0
  32. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/__pycache__/__init__.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/__pycache__/dreamerv3_tf_learner.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/__pycache__/dreamerv3_tf_rl_module.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_learner.py +915 -0
  36. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_rl_module.py +23 -0
  37. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__init__.py +0 -0
  38. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/actor_network.py +203 -0
  39. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/cnn_atari.py +112 -0
  40. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/conv_transpose_atari.py +187 -0
  41. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/vector_decoder.py +98 -0
  42. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/critic_network.py +177 -0
  43. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/disagree_networks.py +94 -0
  44. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/dreamer_model.py +606 -0
  45. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/world_model.py +407 -0
  46. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__init__.py +12 -0
  47. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/__init__.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/default_ppo_rl_module.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_catalog.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -173,3 +173,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
173
  .venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
174
  .venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/propcache/_helpers_c.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
175
  .venv/lib/python3.11/site-packages/ray/jars/ray_dist.jar filter=lfs diff=lfs merge=lfs -text
 
 
 
173
  .venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
174
  .venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/propcache/_helpers_c.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
175
  .venv/lib/python3.11/site-packages/ray/jars/ray_dist.jar filter=lfs diff=lfs merge=lfs -text
176
+ .venv/lib/python3.11/site-packages/ray/_private/__pycache__/worker.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
177
+ .venv/lib/python3.11/site-packages/ray/_private/__pycache__/test_utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/ray/_private/__pycache__/test_utils.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc96e86e5e36ee78f9cfcd3d87220524f3cb583ba7b0472482fe408fbc1c57fa
3
+ size 114677
.venv/lib/python3.11/site-packages/ray/_private/__pycache__/worker.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e715fb00f3b4360472455b9c5d37eb8337c42bc50fea95d2d75fa67bebdcb096
3
+ size 158454
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.39 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/callbacks.cpython-311.pyc ADDED
Binary file (424 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/mock.cpython-311.pyc ADDED
Binary file (8.29 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/registry.cpython-311.pyc ADDED
Binary file (6.39 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/utils.cpython-311.pyc ADDED
Binary file (5.86 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm_config.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__pycache__/default_bc_torch_rl_module.cpython-311.pyc ADDED
Binary file (3.16 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/callbacks.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # @OldAPIStack
2
+ from ray.rllib.callbacks.callbacks import RLlibCallback
3
+ from ray.rllib.callbacks.utils import _make_multi_callbacks
4
+
5
+
6
+ # Backward compatibility
7
+ DefaultCallbacks = RLlibCallback
8
+ make_multi_callbacks = _make_multi_callbacks
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from ray.rllib.algorithms.cql.cql import CQL, CQLConfig
2
+ from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
3
+
4
+ __all__ = [
5
+ "CQL",
6
+ "CQLConfig",
7
+ # @OldAPIStack
8
+ "CQLTorchPolicy",
9
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (438 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql.cpython-311.pyc ADDED
Binary file (17.9 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql_tf_policy.cpython-311.pyc ADDED
Binary file (20.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql_torch_policy.cpython-311.pyc ADDED
Binary file (19.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Optional, Type, Union
3
+
4
+ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
5
+ from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
6
+ from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
7
+ from ray.rllib.algorithms.sac.sac import (
8
+ SAC,
9
+ SACConfig,
10
+ )
11
+ from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
12
+ AddObservationsFromEpisodesToBatch,
13
+ )
14
+ from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
15
+ AddNextObservationsFromEpisodesToTrainBatch,
16
+ )
17
+ from ray.rllib.core.learner.learner import Learner
18
+ from ray.rllib.core.rl_module.rl_module import RLModuleSpec
19
+ from ray.rllib.execution.rollout_ops import (
20
+ synchronous_parallel_sample,
21
+ )
22
+ from ray.rllib.execution.train_ops import (
23
+ multi_gpu_train_one_step,
24
+ train_one_step,
25
+ )
26
+ from ray.rllib.policy.policy import Policy
27
+ from ray.rllib.utils.annotations import OldAPIStack, override
28
+ from ray.rllib.utils.deprecation import (
29
+ DEPRECATED_VALUE,
30
+ deprecation_warning,
31
+ )
32
+ from ray.rllib.utils.framework import try_import_tf, try_import_tfp
33
+ from ray.rllib.utils.metrics import (
34
+ ALL_MODULES,
35
+ LEARNER_RESULTS,
36
+ LEARNER_UPDATE_TIMER,
37
+ LAST_TARGET_UPDATE_TS,
38
+ NUM_AGENT_STEPS_SAMPLED,
39
+ NUM_AGENT_STEPS_TRAINED,
40
+ NUM_ENV_STEPS_SAMPLED,
41
+ NUM_ENV_STEPS_TRAINED,
42
+ NUM_TARGET_UPDATES,
43
+ OFFLINE_SAMPLING_TIMER,
44
+ TARGET_NET_UPDATE_TIMER,
45
+ SYNCH_WORKER_WEIGHTS_TIMER,
46
+ SAMPLE_TIMER,
47
+ TIMERS,
48
+ )
49
+ from ray.rllib.utils.typing import ResultDict, RLModuleSpecType
50
+
51
+ tf1, tf, tfv = try_import_tf()
52
+ tfp = try_import_tfp()
53
+ logger = logging.getLogger(__name__)
54
+
55
+
56
+ class CQLConfig(SACConfig):
57
+ """Defines a configuration class from which a CQL can be built.
58
+
59
+ .. testcode::
60
+ :skipif: True
61
+
62
+ from ray.rllib.algorithms.cql import CQLConfig
63
+ config = CQLConfig().training(gamma=0.9, lr=0.01)
64
+ config = config.resources(num_gpus=0)
65
+ config = config.env_runners(num_env_runners=4)
66
+ print(config.to_dict())
67
+ # Build a Algorithm object from the config and run 1 training iteration.
68
+ algo = config.build(env="CartPole-v1")
69
+ algo.train()
70
+ """
71
+
72
+ def __init__(self, algo_class=None):
73
+ super().__init__(algo_class=algo_class or CQL)
74
+
75
+ # fmt: off
76
+ # __sphinx_doc_begin__
77
+ # CQL-specific config settings:
78
+ self.bc_iters = 20000
79
+ self.temperature = 1.0
80
+ self.num_actions = 10
81
+ self.lagrangian = False
82
+ self.lagrangian_thresh = 5.0
83
+ self.min_q_weight = 5.0
84
+ self.deterministic_backup = True
85
+ self.lr = 3e-4
86
+ # Note, the new stack defines learning rates for each component.
87
+ # The base learning rate `lr` has to be set to `None`, if using
88
+ # the new stack.
89
+ self.actor_lr = 1e-4
90
+ self.critic_lr = 1e-3
91
+ self.alpha_lr = 1e-3
92
+
93
+ self.replay_buffer_config = {
94
+ "_enable_replay_buffer_api": True,
95
+ "type": "MultiAgentPrioritizedReplayBuffer",
96
+ "capacity": int(1e6),
97
+ # If True prioritized replay buffer will be used.
98
+ "prioritized_replay": False,
99
+ "prioritized_replay_alpha": 0.6,
100
+ "prioritized_replay_beta": 0.4,
101
+ "prioritized_replay_eps": 1e-6,
102
+ # Whether to compute priorities already on the remote worker side.
103
+ "worker_side_prioritization": False,
104
+ }
105
+
106
+ # Changes to Algorithm's/SACConfig's default:
107
+
108
+ # .reporting()
109
+ self.min_sample_timesteps_per_iteration = 0
110
+ self.min_train_timesteps_per_iteration = 100
111
+ # fmt: on
112
+ # __sphinx_doc_end__
113
+
114
+ self.timesteps_per_iteration = DEPRECATED_VALUE
115
+
116
+ @override(SACConfig)
117
+ def training(
118
+ self,
119
+ *,
120
+ bc_iters: Optional[int] = NotProvided,
121
+ temperature: Optional[float] = NotProvided,
122
+ num_actions: Optional[int] = NotProvided,
123
+ lagrangian: Optional[bool] = NotProvided,
124
+ lagrangian_thresh: Optional[float] = NotProvided,
125
+ min_q_weight: Optional[float] = NotProvided,
126
+ deterministic_backup: Optional[bool] = NotProvided,
127
+ **kwargs,
128
+ ) -> "CQLConfig":
129
+ """Sets the training-related configuration.
130
+
131
+ Args:
132
+ bc_iters: Number of iterations with Behavior Cloning pretraining.
133
+ temperature: CQL loss temperature.
134
+ num_actions: Number of actions to sample for CQL loss
135
+ lagrangian: Whether to use the Lagrangian for Alpha Prime (in CQL loss).
136
+ lagrangian_thresh: Lagrangian threshold.
137
+ min_q_weight: in Q weight multiplier.
138
+ deterministic_backup: If the target in the Bellman update should have an
139
+ entropy backup. Defaults to `True`.
140
+
141
+ Returns:
142
+ This updated AlgorithmConfig object.
143
+ """
144
+ # Pass kwargs onto super's `training()` method.
145
+ super().training(**kwargs)
146
+
147
+ if bc_iters is not NotProvided:
148
+ self.bc_iters = bc_iters
149
+ if temperature is not NotProvided:
150
+ self.temperature = temperature
151
+ if num_actions is not NotProvided:
152
+ self.num_actions = num_actions
153
+ if lagrangian is not NotProvided:
154
+ self.lagrangian = lagrangian
155
+ if lagrangian_thresh is not NotProvided:
156
+ self.lagrangian_thresh = lagrangian_thresh
157
+ if min_q_weight is not NotProvided:
158
+ self.min_q_weight = min_q_weight
159
+ if deterministic_backup is not NotProvided:
160
+ self.deterministic_backup = deterministic_backup
161
+
162
+ return self
163
+
164
+ @override(AlgorithmConfig)
165
+ def offline_data(self, **kwargs) -> "CQLConfig":
166
+
167
+ super().offline_data(**kwargs)
168
+
169
+ # Check, if the passed in class incorporates the `OfflinePreLearner`
170
+ # interface.
171
+ if "prelearner_class" in kwargs:
172
+ from ray.rllib.offline.offline_data import OfflinePreLearner
173
+
174
+ if not issubclass(kwargs.get("prelearner_class"), OfflinePreLearner):
175
+ raise ValueError(
176
+ f"`prelearner_class` {kwargs.get('prelearner_class')} is not a "
177
+ "subclass of `OfflinePreLearner`. Any class passed to "
178
+ "`prelearner_class` needs to implement the interface given by "
179
+ "`OfflinePreLearner`."
180
+ )
181
+
182
+ return self
183
+
184
+ @override(SACConfig)
185
+ def get_default_learner_class(self) -> Union[Type["Learner"], str]:
186
+ if self.framework_str == "torch":
187
+ from ray.rllib.algorithms.cql.torch.cql_torch_learner import CQLTorchLearner
188
+
189
+ return CQLTorchLearner
190
+ else:
191
+ raise ValueError(
192
+ f"The framework {self.framework_str} is not supported. "
193
+ "Use `'torch'` instead."
194
+ )
195
+
196
+ @override(AlgorithmConfig)
197
+ def build_learner_connector(
198
+ self,
199
+ input_observation_space,
200
+ input_action_space,
201
+ device=None,
202
+ ):
203
+ pipeline = super().build_learner_connector(
204
+ input_observation_space=input_observation_space,
205
+ input_action_space=input_action_space,
206
+ device=device,
207
+ )
208
+
209
+ # Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right
210
+ # after the corresponding "add-OBS-..." default piece).
211
+ pipeline.insert_after(
212
+ AddObservationsFromEpisodesToBatch,
213
+ AddNextObservationsFromEpisodesToTrainBatch(),
214
+ )
215
+
216
+ return pipeline
217
+
218
+ @override(SACConfig)
219
+ def validate(self) -> None:
220
+ # First check, whether old `timesteps_per_iteration` is used.
221
+ if self.timesteps_per_iteration != DEPRECATED_VALUE:
222
+ deprecation_warning(
223
+ old="timesteps_per_iteration",
224
+ new="min_train_timesteps_per_iteration",
225
+ error=True,
226
+ )
227
+
228
+ # Call super's validation method.
229
+ super().validate()
230
+
231
+ # CQL-torch performs the optimizer steps inside the loss function.
232
+ # Using the multi-GPU optimizer will therefore not work (see multi-GPU
233
+ # check above) and we must use the simple optimizer for now.
234
+ if self.simple_optimizer is not True and self.framework_str == "torch":
235
+ self.simple_optimizer = True
236
+
237
+ if self.framework_str in ["tf", "tf2"] and tfp is None:
238
+ logger.warning(
239
+ "You need `tensorflow_probability` in order to run CQL! "
240
+ "Install it via `pip install tensorflow_probability`. Your "
241
+ f"tf.__version__={tf.__version__ if tf else None}."
242
+ "Trying to import tfp results in the following error:"
243
+ )
244
+ try_import_tfp(error=True)
245
+
246
+ # Assert that for a local learner the number of iterations is 1. Note,
247
+ # this is needed because we have no iterators, but instead a single
248
+ # batch returned directly from the `OfflineData.sample` method.
249
+ if (
250
+ self.num_learners == 0
251
+ and not self.dataset_num_iters_per_learner
252
+ and self.enable_rl_module_and_learner
253
+ ):
254
+ self._value_error(
255
+ "When using a single local learner the number of iterations "
256
+ "per learner, `dataset_num_iters_per_learner` has to be defined. "
257
+ "Set this hyperparameter in the `AlgorithmConfig.offline_data`."
258
+ )
259
+
260
+ @override(SACConfig)
261
+ def get_default_rl_module_spec(self) -> RLModuleSpecType:
262
+ if self.framework_str == "torch":
263
+ from ray.rllib.algorithms.cql.torch.default_cql_torch_rl_module import (
264
+ DefaultCQLTorchRLModule,
265
+ )
266
+
267
+ return RLModuleSpec(module_class=DefaultCQLTorchRLModule)
268
+ else:
269
+ raise ValueError(
270
+ f"The framework {self.framework_str} is not supported. " "Use `torch`."
271
+ )
272
+
273
+ @property
274
+ def _model_config_auto_includes(self):
275
+ return super()._model_config_auto_includes | {
276
+ "num_actions": self.num_actions,
277
+ }
278
+
279
+
280
+ class CQL(SAC):
281
+ """CQL (derived from SAC)."""
282
+
283
+ @classmethod
284
+ @override(SAC)
285
+ def get_default_config(cls) -> AlgorithmConfig:
286
+ return CQLConfig()
287
+
288
+ @classmethod
289
+ @override(SAC)
290
+ def get_default_policy_class(
291
+ cls, config: AlgorithmConfig
292
+ ) -> Optional[Type[Policy]]:
293
+ if config["framework"] == "torch":
294
+ return CQLTorchPolicy
295
+ else:
296
+ return CQLTFPolicy
297
+
298
+ @override(SAC)
299
+ def training_step(self) -> None:
300
+ # Old API stack (Policy, RolloutWorker, Connector).
301
+ if not self.config.enable_env_runner_and_connector_v2:
302
+ return self._training_step_old_api_stack()
303
+
304
+ # Sampling from offline data.
305
+ with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
306
+ # Return an iterator in case we are using remote learners.
307
+ batch_or_iterator = self.offline_data.sample(
308
+ num_samples=self.config.train_batch_size_per_learner,
309
+ num_shards=self.config.num_learners,
310
+ return_iterator=self.config.num_learners > 1,
311
+ )
312
+
313
+ # Updating the policy.
314
+ with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
315
+ # TODO (simon, sven): Check, if we should execute directly s.th. like
316
+ # `LearnerGroup.update_from_iterator()`.
317
+ learner_results = self.learner_group._update(
318
+ batch=batch_or_iterator,
319
+ minibatch_size=self.config.train_batch_size_per_learner,
320
+ num_iters=self.config.dataset_num_iters_per_learner,
321
+ )
322
+
323
+ # Log training results.
324
+ self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS)
325
+
326
+ # Synchronize weights.
327
+ # As the results contain for each policy the loss and in addition the
328
+ # total loss over all policies is returned, this total loss has to be
329
+ # removed.
330
+ modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES}
331
+
332
+ if self.eval_env_runner_group:
333
+ # Update weights - after learning on the local worker -
334
+ # on all remote workers. Note, we only have the local `EnvRunner`,
335
+ # but from this `EnvRunner` the evaulation `EnvRunner`s get updated.
336
+ with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
337
+ self.eval_env_runner_group.sync_weights(
338
+ # Sync weights from learner_group to all EnvRunners.
339
+ from_worker_or_learner_group=self.learner_group,
340
+ policies=modules_to_update,
341
+ inference_only=True,
342
+ )
343
+
344
+ @OldAPIStack
345
+ def _training_step_old_api_stack(self) -> ResultDict:
346
+ # Collect SampleBatches from sample workers.
347
+ with self._timers[SAMPLE_TIMER]:
348
+ train_batch = synchronous_parallel_sample(worker_set=self.env_runner_group)
349
+ train_batch = train_batch.as_multi_agent()
350
+ self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
351
+ self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
352
+
353
+ # Postprocess batch before we learn on it.
354
+ post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
355
+ train_batch = post_fn(train_batch, self.env_runner_group, self.config)
356
+
357
+ # Learn on training batch.
358
+ # Use simple optimizer (only for multi-agent or tf-eager; all other
359
+ # cases should use the multi-GPU optimizer, even if only using 1 GPU)
360
+ if self.config.get("simple_optimizer") is True:
361
+ train_results = train_one_step(self, train_batch)
362
+ else:
363
+ train_results = multi_gpu_train_one_step(self, train_batch)
364
+
365
+ # Update target network every `target_network_update_freq` training steps.
366
+ cur_ts = self._counters[
367
+ NUM_AGENT_STEPS_TRAINED
368
+ if self.config.count_steps_by == "agent_steps"
369
+ else NUM_ENV_STEPS_TRAINED
370
+ ]
371
+ last_update = self._counters[LAST_TARGET_UPDATE_TS]
372
+ if cur_ts - last_update >= self.config.target_network_update_freq:
373
+ with self._timers[TARGET_NET_UPDATE_TIMER]:
374
+ to_update = self.env_runner.get_policies_to_train()
375
+ self.env_runner.foreach_policy_to_train(
376
+ lambda p, pid: pid in to_update and p.update_target()
377
+ )
378
+ self._counters[NUM_TARGET_UPDATES] += 1
379
+ self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
380
+
381
+ # Update remote workers's weights after learning on local worker
382
+ # (only those policies that were actually trained).
383
+ if self.env_runner_group.num_remote_workers() > 0:
384
+ with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
385
+ self.env_runner_group.sync_weights(policies=list(train_results.keys()))
386
+
387
+ # Return all collected metrics for the iteration.
388
+ return train_results
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql_tf_policy.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TensorFlow policy class used for CQL.
3
+ """
4
+ from functools import partial
5
+ import numpy as np
6
+ import gymnasium as gym
7
+ import logging
8
+ import tree
9
+ from typing import Dict, List, Type, Union
10
+
11
+ import ray
12
+ import ray.experimental.tf_utils
13
+ from ray.rllib.algorithms.sac.sac_tf_policy import (
14
+ apply_gradients as sac_apply_gradients,
15
+ compute_and_clip_gradients as sac_compute_and_clip_gradients,
16
+ get_distribution_inputs_and_class,
17
+ _get_dist_class,
18
+ build_sac_model,
19
+ postprocess_trajectory,
20
+ setup_late_mixins,
21
+ stats,
22
+ validate_spaces,
23
+ ActorCriticOptimizerMixin as SACActorCriticOptimizerMixin,
24
+ ComputeTDErrorMixin,
25
+ )
26
+ from ray.rllib.models.modelv2 import ModelV2
27
+ from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
28
+ from ray.rllib.policy.tf_mixins import TargetNetworkMixin
29
+ from ray.rllib.policy.tf_policy_template import build_tf_policy
30
+ from ray.rllib.policy.policy import Policy
31
+ from ray.rllib.policy.sample_batch import SampleBatch
32
+ from ray.rllib.utils.exploration.random import Random
33
+ from ray.rllib.utils.framework import get_variable, try_import_tf, try_import_tfp
34
+ from ray.rllib.utils.typing import (
35
+ LocalOptimizer,
36
+ ModelGradients,
37
+ TensorType,
38
+ AlgorithmConfigDict,
39
+ )
40
+
41
+ tf1, tf, tfv = try_import_tf()
42
+ tfp = try_import_tfp()
43
+
44
+ logger = logging.getLogger(__name__)
45
+
46
+ MEAN_MIN = -9.0
47
+ MEAN_MAX = 9.0
48
+
49
+
50
+ def _repeat_tensor(t: TensorType, n: int):
51
+ # Insert new axis at position 1 into tensor t
52
+ t_rep = tf.expand_dims(t, 1)
53
+ # Repeat tensor t_rep along new axis n times
54
+ multiples = tf.concat([[1, n], tf.tile([1], tf.expand_dims(tf.rank(t) - 1, 0))], 0)
55
+ t_rep = tf.tile(t_rep, multiples)
56
+ # Merge new axis into batch axis
57
+ t_rep = tf.reshape(t_rep, tf.concat([[-1], tf.shape(t)[1:]], 0))
58
+ return t_rep
59
+
60
+
61
+ # Returns policy tiled actions and log probabilities for CQL Loss
62
+ def policy_actions_repeat(model, action_dist, obs, num_repeat=1):
63
+ batch_size = tf.shape(tree.flatten(obs)[0])[0]
64
+ obs_temp = tree.map_structure(lambda t: _repeat_tensor(t, num_repeat), obs)
65
+ logits, _ = model.get_action_model_outputs(obs_temp)
66
+ policy_dist = action_dist(logits, model)
67
+ actions, logp_ = policy_dist.sample_logp()
68
+ logp = tf.expand_dims(logp_, -1)
69
+ return actions, tf.reshape(logp, [batch_size, num_repeat, 1])
70
+
71
+
72
+ def q_values_repeat(model, obs, actions, twin=False):
73
+ action_shape = tf.shape(actions)[0]
74
+ obs_shape = tf.shape(tree.flatten(obs)[0])[0]
75
+ num_repeat = action_shape // obs_shape
76
+ obs_temp = tree.map_structure(lambda t: _repeat_tensor(t, num_repeat), obs)
77
+ if not twin:
78
+ preds_, _ = model.get_q_values(obs_temp, actions)
79
+ else:
80
+ preds_, _ = model.get_twin_q_values(obs_temp, actions)
81
+ preds = tf.reshape(preds_, [obs_shape, num_repeat, 1])
82
+ return preds
83
+
84
+
85
+ def cql_loss(
86
+ policy: Policy,
87
+ model: ModelV2,
88
+ dist_class: Type[TFActionDistribution],
89
+ train_batch: SampleBatch,
90
+ ) -> Union[TensorType, List[TensorType]]:
91
+ logger.info(f"Current iteration = {policy.cur_iter}")
92
+ policy.cur_iter += 1
93
+
94
+ # For best performance, turn deterministic off
95
+ deterministic = policy.config["_deterministic_loss"]
96
+ assert not deterministic
97
+ twin_q = policy.config["twin_q"]
98
+ discount = policy.config["gamma"]
99
+
100
+ # CQL Parameters
101
+ bc_iters = policy.config["bc_iters"]
102
+ cql_temp = policy.config["temperature"]
103
+ num_actions = policy.config["num_actions"]
104
+ min_q_weight = policy.config["min_q_weight"]
105
+ use_lagrange = policy.config["lagrangian"]
106
+ target_action_gap = policy.config["lagrangian_thresh"]
107
+
108
+ obs = train_batch[SampleBatch.CUR_OBS]
109
+ actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32)
110
+ rewards = tf.cast(train_batch[SampleBatch.REWARDS], tf.float32)
111
+ next_obs = train_batch[SampleBatch.NEXT_OBS]
112
+ terminals = train_batch[SampleBatch.TERMINATEDS]
113
+
114
+ model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None)
115
+
116
+ model_out_tp1, _ = model(SampleBatch(obs=next_obs, _is_training=True), [], None)
117
+
118
+ target_model_out_tp1, _ = policy.target_model(
119
+ SampleBatch(obs=next_obs, _is_training=True), [], None
120
+ )
121
+
122
+ action_dist_class = _get_dist_class(policy, policy.config, policy.action_space)
123
+ action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
124
+ action_dist_t = action_dist_class(action_dist_inputs_t, model)
125
+ policy_t, log_pis_t = action_dist_t.sample_logp()
126
+ log_pis_t = tf.expand_dims(log_pis_t, -1)
127
+
128
+ # Unlike original SAC, Alpha and Actor Loss are computed first.
129
+ # Alpha Loss
130
+ alpha_loss = -tf.reduce_mean(
131
+ model.log_alpha * tf.stop_gradient(log_pis_t + model.target_entropy)
132
+ )
133
+
134
+ # Policy Loss (Either Behavior Clone Loss or SAC Loss)
135
+ alpha = tf.math.exp(model.log_alpha)
136
+ if policy.cur_iter >= bc_iters:
137
+ min_q, _ = model.get_q_values(model_out_t, policy_t)
138
+ if twin_q:
139
+ twin_q_, _ = model.get_twin_q_values(model_out_t, policy_t)
140
+ min_q = tf.math.minimum(min_q, twin_q_)
141
+ actor_loss = tf.reduce_mean(tf.stop_gradient(alpha) * log_pis_t - min_q)
142
+ else:
143
+ bc_logp = action_dist_t.logp(actions)
144
+ actor_loss = tf.reduce_mean(tf.stop_gradient(alpha) * log_pis_t - bc_logp)
145
+ # actor_loss = -tf.reduce_mean(bc_logp)
146
+
147
+ # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
148
+ # SAC Loss:
149
+ # Q-values for the batched actions.
150
+ action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1)
151
+ action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model)
152
+ policy_tp1, _ = action_dist_tp1.sample_logp()
153
+
154
+ q_t, _ = model.get_q_values(model_out_t, actions)
155
+ q_t_selected = tf.squeeze(q_t, axis=-1)
156
+ if twin_q:
157
+ twin_q_t, _ = model.get_twin_q_values(model_out_t, actions)
158
+ twin_q_t_selected = tf.squeeze(twin_q_t, axis=-1)
159
+
160
+ # Target q network evaluation.
161
+ q_tp1, _ = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
162
+ if twin_q:
163
+ twin_q_tp1, _ = policy.target_model.get_twin_q_values(
164
+ target_model_out_tp1, policy_tp1
165
+ )
166
+ # Take min over both twin-NNs.
167
+ q_tp1 = tf.math.minimum(q_tp1, twin_q_tp1)
168
+
169
+ q_tp1_best = tf.squeeze(input=q_tp1, axis=-1)
170
+ q_tp1_best_masked = (1.0 - tf.cast(terminals, tf.float32)) * q_tp1_best
171
+
172
+ # compute RHS of bellman equation
173
+ q_t_target = tf.stop_gradient(
174
+ rewards + (discount ** policy.config["n_step"]) * q_tp1_best_masked
175
+ )
176
+
177
+ # Compute the TD-error (potentially clipped), for priority replay buffer
178
+ base_td_error = tf.math.abs(q_t_selected - q_t_target)
179
+ if twin_q:
180
+ twin_td_error = tf.math.abs(twin_q_t_selected - q_t_target)
181
+ td_error = 0.5 * (base_td_error + twin_td_error)
182
+ else:
183
+ td_error = base_td_error
184
+
185
+ critic_loss_1 = tf.keras.losses.MSE(q_t_selected, q_t_target)
186
+ if twin_q:
187
+ critic_loss_2 = tf.keras.losses.MSE(twin_q_t_selected, q_t_target)
188
+
189
+ # CQL Loss (We are using Entropy version of CQL (the best version))
190
+ rand_actions, _ = policy._random_action_generator.get_exploration_action(
191
+ action_distribution=action_dist_class(
192
+ tf.tile(action_dist_tp1.inputs, (num_actions, 1)), model
193
+ ),
194
+ timestep=0,
195
+ explore=True,
196
+ )
197
+ curr_actions, curr_logp = policy_actions_repeat(
198
+ model, action_dist_class, model_out_t, num_actions
199
+ )
200
+ next_actions, next_logp = policy_actions_repeat(
201
+ model, action_dist_class, model_out_tp1, num_actions
202
+ )
203
+
204
+ q1_rand = q_values_repeat(model, model_out_t, rand_actions)
205
+ q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
206
+ q1_next_actions = q_values_repeat(model, model_out_t, next_actions)
207
+
208
+ if twin_q:
209
+ q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
210
+ q2_curr_actions = q_values_repeat(model, model_out_t, curr_actions, twin=True)
211
+ q2_next_actions = q_values_repeat(model, model_out_t, next_actions, twin=True)
212
+
213
+ random_density = np.log(0.5 ** int(curr_actions.shape[-1]))
214
+ cat_q1 = tf.concat(
215
+ [
216
+ q1_rand - random_density,
217
+ q1_next_actions - tf.stop_gradient(next_logp),
218
+ q1_curr_actions - tf.stop_gradient(curr_logp),
219
+ ],
220
+ 1,
221
+ )
222
+ if twin_q:
223
+ cat_q2 = tf.concat(
224
+ [
225
+ q2_rand - random_density,
226
+ q2_next_actions - tf.stop_gradient(next_logp),
227
+ q2_curr_actions - tf.stop_gradient(curr_logp),
228
+ ],
229
+ 1,
230
+ )
231
+
232
+ min_qf1_loss_ = (
233
+ tf.reduce_mean(tf.reduce_logsumexp(cat_q1 / cql_temp, axis=1))
234
+ * min_q_weight
235
+ * cql_temp
236
+ )
237
+ min_qf1_loss = min_qf1_loss_ - (tf.reduce_mean(q_t) * min_q_weight)
238
+ if twin_q:
239
+ min_qf2_loss_ = (
240
+ tf.reduce_mean(tf.reduce_logsumexp(cat_q2 / cql_temp, axis=1))
241
+ * min_q_weight
242
+ * cql_temp
243
+ )
244
+ min_qf2_loss = min_qf2_loss_ - (tf.reduce_mean(twin_q_t) * min_q_weight)
245
+
246
+ if use_lagrange:
247
+ alpha_prime = tf.clip_by_value(model.log_alpha_prime.exp(), 0.0, 1000000.0)[0]
248
+ min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
249
+ if twin_q:
250
+ min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
251
+ alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
252
+ else:
253
+ alpha_prime_loss = -min_qf1_loss
254
+
255
+ cql_loss = [min_qf1_loss]
256
+ if twin_q:
257
+ cql_loss.append(min_qf2_loss)
258
+
259
+ critic_loss = [critic_loss_1 + min_qf1_loss]
260
+ if twin_q:
261
+ critic_loss.append(critic_loss_2 + min_qf2_loss)
262
+
263
+ # Save for stats function.
264
+ policy.q_t = q_t_selected
265
+ policy.policy_t = policy_t
266
+ policy.log_pis_t = log_pis_t
267
+ policy.td_error = td_error
268
+ policy.actor_loss = actor_loss
269
+ policy.critic_loss = critic_loss
270
+ policy.alpha_loss = alpha_loss
271
+ policy.log_alpha_value = model.log_alpha
272
+ policy.alpha_value = alpha
273
+ policy.target_entropy = model.target_entropy
274
+ # CQL Stats
275
+ policy.cql_loss = cql_loss
276
+ if use_lagrange:
277
+ policy.log_alpha_prime_value = model.log_alpha_prime[0]
278
+ policy.alpha_prime_value = alpha_prime
279
+ policy.alpha_prime_loss = alpha_prime_loss
280
+
281
+ # Return all loss terms corresponding to our optimizers.
282
+ if use_lagrange:
283
+ return actor_loss + tf.math.add_n(critic_loss) + alpha_loss + alpha_prime_loss
284
+ return actor_loss + tf.math.add_n(critic_loss) + alpha_loss
285
+
286
+
287
def cql_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Return the SAC training stats, extended by CQL-specific terms.

    Args:
        policy: The Policy to collect stats for.
        train_batch: The SampleBatch the stats refer to.

    Returns:
        A dict mapping stat names to TF tensors/values.
    """
    # Start from the standard SAC stats dict and extend it in place.
    metrics = stats(policy, train_batch)
    metrics["cql_loss"] = tf.reduce_mean(tf.stack(policy.cql_loss))
    # Dual-variable stats only exist when the Lagrangian CQL variant is on.
    if policy.config["lagrangian"]:
        metrics["log_alpha_prime_value"] = policy.log_alpha_prime_value
        metrics["alpha_prime_value"] = policy.alpha_prime_value
        metrics["alpha_prime_loss"] = policy.alpha_prime_loss
    return metrics
295
+
296
+
297
class ActorCriticOptimizerMixin(SACActorCriticOptimizerMixin):
    """SAC's optimizer mixin, extended with an alpha-prime optimizer.

    For the Lagrangian CQL variant, an additional Adam optimizer is created
    for the dual variable `log_alpha_prime`, using the critic learning rate.
    """

    def __init__(self, config):
        super().__init__(config)
        if not config["lagrangian"]:
            return
        lr = config["optimization"]["critic_learning_rate"]
        if config["framework"] == "tf2":
            # Eager mode.
            self._alpha_prime_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
        else:
            # Static graph mode.
            self._alpha_prime_optimizer = tf1.train.AdamOptimizer(learning_rate=lr)
311
+
312
+
313
def setup_early_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
) -> None:
    """Call mixin classes' constructors before Policy's initialization.

    Adds the necessary optimizers to the given Policy.

    Args:
        policy: The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config: The Policy's config.
    """
    # Iteration counter; the loss uses it to switch from behavior cloning to
    # the SAC actor loss after `bc_iters` iterations.
    policy.cur_iter = 0
    ActorCriticOptimizerMixin.__init__(policy, config)
    if config["lagrangian"]:
        # Trainable dual variable log(alpha_prime) for the Lagrangian variant.
        policy.model.log_alpha_prime = get_variable(
            0.0, framework="tf", trainable=True, tf_name="log_alpha_prime"
        )
        policy.alpha_prime_optim = tf.keras.optimizers.Adam(
            learning_rate=config["optimization"]["critic_learning_rate"],
        )
    # Generic random action generator for calculating CQL-loss.
    policy._random_action_generator = Random(
        action_space,
        model=None,
        framework="tf2",
        policy_config=config,
        num_workers=0,
        worker_index=0,
    )
347
+
348
+
349
def compute_gradients_fn(
    policy: Policy, optimizer: LocalOptimizer, loss: TensorType
) -> ModelGradients:
    """Compute and clip all gradients, incl. the alpha-prime ones.

    Delegates the actor/critic/alpha terms to SAC's gradient helper and, for
    the Lagrangian CQL variant, appends the gradients of `alpha_prime_loss`
    w.r.t. `log_alpha_prime`.

    Args:
        policy: The Policy being optimized.
        optimizer: The local optimizer (or eager OptimizerWrapper).
        loss: The total loss tensor returned by the loss fn.

    Returns:
        The combined list of (gradient, variable) pairs.
    """
    grads_and_vars = sac_compute_and_clip_gradients(policy, optimizer, loss)

    if policy.config["lagrangian"]:
        # Eager: Use GradientTape (which is a property of the `optimizer`
        # object (an OptimizerWrapper): see rllib/policy/eager_tf_policy.py).
        if policy.config["framework"] == "tf2":
            tape = optimizer.tape
            log_alpha_prime = [policy.model.log_alpha_prime]
            alpha_prime_grads_and_vars = list(
                zip(
                    tape.gradient(policy.alpha_prime_loss, log_alpha_prime),
                    log_alpha_prime,
                )
            )
        # Tf1.x: Use optimizer.compute_gradients()
        else:
            alpha_prime_grads_and_vars = (
                policy._alpha_prime_optimizer.compute_gradients(
                    policy.alpha_prime_loss, var_list=[policy.model.log_alpha_prime]
                )
            )

        # Clip if necessary.
        if policy.config["grad_clip"]:
            clip_func = partial(tf.clip_by_norm, clip_norm=policy.config["grad_clip"])
        else:
            clip_func = tf.identity

        # Save grads and vars for later use in `build_apply_op`.
        # Drop None-gradients (variables not part of this loss).
        policy._alpha_prime_grads_and_vars = [
            (clip_func(g), v) for (g, v) in alpha_prime_grads_and_vars if g is not None
        ]

        grads_and_vars += policy._alpha_prime_grads_and_vars
    return grads_and_vars
387
+
388
+
389
def apply_gradients_fn(policy, optimizer, grads_and_vars):
    """Apply SAC's gradients plus - if Lagrangian - the alpha-prime gradients.

    Returns None in eager mode (application is immediate) or a grouped TF op
    in static-graph mode.
    """
    sac_results = sac_apply_gradients(policy, optimizer, grads_and_vars)

    if policy.config["lagrangian"]:
        # Eager mode -> Just apply and return None.
        if policy.config["framework"] == "tf2":
            policy._alpha_prime_optimizer.apply_gradients(
                policy._alpha_prime_grads_and_vars
            )
            return
        # Tf static graph -> Return grouped op.
        else:
            alpha_prime_apply_op = policy._alpha_prime_optimizer.apply_gradients(
                policy._alpha_prime_grads_and_vars,
                global_step=tf1.train.get_or_create_global_step(),
            )
            return tf.group([sac_results, alpha_prime_apply_op])
    return sac_results
407
+
408
+
409
# Build a child class of `TFPolicy`, given the custom functions defined
# above. Model construction, action distribution and most setup are shared
# with SAC; CQL only swaps in its own loss, stats, and gradient handling.
CQLTFPolicy = build_tf_policy(
    name="CQLTFPolicy",
    loss_fn=cql_loss,
    get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQLConfig(),
    validate_spaces=validate_spaces,
    stats_fn=cql_stats,
    postprocess_fn=postprocess_trajectory,
    before_init=setup_early_mixins,
    after_init=setup_late_mixins,
    make_model=build_sac_model,
    extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
    mixins=[ActorCriticOptimizerMixin, TargetNetworkMixin, ComputeTDErrorMixin],
    action_distribution_fn=get_distribution_inputs_and_class,
    compute_gradients_fn=compute_gradients_fn,
    apply_gradients_fn=apply_gradients_fn,
)
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql_torch_policy.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch policy class used for CQL.
3
+ """
4
+ import numpy as np
5
+ import gymnasium as gym
6
+ import logging
7
+ import tree
8
+ from typing import Dict, List, Tuple, Type, Union
9
+
10
+ import ray
11
+ import ray.experimental.tf_utils
12
+ from ray.rllib.algorithms.sac.sac_tf_policy import (
13
+ postprocess_trajectory,
14
+ validate_spaces,
15
+ )
16
+ from ray.rllib.algorithms.sac.sac_torch_policy import (
17
+ _get_dist_class,
18
+ stats,
19
+ build_sac_model_and_action_dist,
20
+ optimizer_fn,
21
+ ComputeTDErrorMixin,
22
+ setup_late_mixins,
23
+ action_distribution_fn,
24
+ )
25
+ from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
26
+ from ray.rllib.models.modelv2 import ModelV2
27
+ from ray.rllib.policy.policy_template import build_policy_class
28
+ from ray.rllib.policy.policy import Policy
29
+ from ray.rllib.policy.torch_mixins import TargetNetworkMixin
30
+ from ray.rllib.policy.sample_batch import SampleBatch
31
+ from ray.rllib.utils.framework import try_import_torch
32
+ from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
33
+ from ray.rllib.utils.typing import LocalOptimizer, TensorType, AlgorithmConfigDict
34
+ from ray.rllib.utils.torch_utils import (
35
+ apply_grad_clipping,
36
+ convert_to_torch_tensor,
37
+ concat_multi_gpu_td_errors,
38
+ )
39
+
40
+ torch, nn = try_import_torch()
41
+ F = nn.functional
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+ MEAN_MIN = -9.0
46
+ MEAN_MAX = 9.0
47
+
48
+
49
+ def _repeat_tensor(t: TensorType, n: int):
50
+ # Insert new dimension at posotion 1 into tensor t
51
+ t_rep = t.unsqueeze(1)
52
+ # Repeat tensor t_rep along new dimension n times
53
+ t_rep = torch.repeat_interleave(t_rep, n, dim=1)
54
+ # Merge new dimension into batch dimension
55
+ t_rep = t_rep.view(-1, *t.shape[1:])
56
+ return t_rep
57
+
58
+
59
def policy_actions_repeat(model, action_dist, obs, num_repeat=1):
    """Sample actions and log-probs from the policy on a tiled obs batch.

    The observation (struct) is tiled `num_repeat` times along the batch
    axis; the sampled log-probs are returned with shape
    (batch, num_repeat, 1) for use in the CQL regularizer.
    """
    batch_size = tree.flatten(obs)[0].shape[0]
    tiled_obs = tree.map_structure(lambda x: _repeat_tensor(x, num_repeat), obs)
    dist_inputs, _ = model.get_action_model_outputs(tiled_obs)
    sampled_actions, sampled_logp = action_dist(dist_inputs, model).sample_logp()
    logp = sampled_logp.unsqueeze(-1).view(batch_size, num_repeat, 1)
    return sampled_actions, logp
68
+
69
+
70
def q_values_repeat(model, obs, actions, twin=False):
    """Evaluate (twin-)Q-values on repeated observations.

    `actions` must contain a whole multiple of the obs batch size; the
    observations are tiled to match and the predictions reshaped to
    (batch, num_repeat, 1).
    """
    obs_batch = tree.flatten(obs)[0].shape[0]
    num_repeat = int(actions.shape[0] / obs_batch)
    tiled_obs = tree.map_structure(lambda x: _repeat_tensor(x, num_repeat), obs)
    # Pick the plain or twin Q-head once, then call it.
    q_fn = model.get_twin_q_values if twin else model.get_q_values
    raw_preds, _ = q_fn(tiled_obs, actions)
    return raw_preds.view(obs_batch, num_repeat, 1)
81
+
82
+
83
def cql_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Compute the CQL losses: SAC losses plus the conservative Q regularizer.

    Unlike a typical RLlib loss fn, this one also steps the individual
    optimizers itself (zero_grad/backward/step) whenever a full-size train
    batch comes in; the policy's `compute_gradients_fn`/`apply_gradients_fn`
    are therefore effectively no-ops w.r.t. gradient application.

    Args:
        policy: The Policy being optimized.
        model: The model (tower) to compute the loss for.
        dist_class: Unused here; the action dist class is re-derived from
            the policy config via `_get_dist_class`.
        train_batch: The SampleBatch of (offline) training data.

    Returns:
        A tuple of loss tensors: (actor, critic(s)..., alpha[, alpha_prime]).
    """
    logger.info(f"Current iteration = {policy.cur_iter}")
    policy.cur_iter += 1

    # Look up the target model (tower) using the model tower.
    target_model = policy.target_models[model]

    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    assert not deterministic
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]
    # NOTE(review): only the first dim's bounds are used for the uniform
    # random actions below - assumes a Box space with identical bounds per
    # dim; confirm for non-uniform action spaces.
    action_low = model.action_space.low[0]
    action_high = model.action_space.high[0]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = train_batch[SampleBatch.ACTIONS]
    rewards = train_batch[SampleBatch.REWARDS].float()
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.TERMINATEDS]

    model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None)

    model_out_tp1, _ = model(SampleBatch(obs=next_obs, _is_training=True), [], None)

    target_model_out_tp1, _ = target_model(
        SampleBatch(obs=next_obs, _is_training=True), [], None
    )

    action_dist_class = _get_dist_class(policy, policy.config, policy.action_space)
    action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
    action_dist_t = action_dist_class(action_dist_inputs_t, model)
    policy_t, log_pis_t = action_dist_t.sample_logp()
    log_pis_t = torch.unsqueeze(log_pis_t, -1)

    # Unlike original SAC, Alpha and Actor Loss are computed first.
    # Alpha Loss
    alpha_loss = -(model.log_alpha * (log_pis_t + model.target_entropy).detach()).mean()

    batch_size = tree.flatten(obs)[0].shape[0]
    # Only step the optimizers on full-size batches - presumably to skip
    # RLlib's initial dummy-batch loss call; TODO confirm.
    if batch_size == policy.config["train_batch_size"]:
        policy.alpha_optim.zero_grad()
        alpha_loss.backward()
        policy.alpha_optim.step()

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = torch.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        min_q, _ = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_, _ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = torch.min(min_q, twin_q_)
        actor_loss = (alpha.detach() * log_pis_t - min_q).mean()
    else:
        # First `bc_iters` iterations: clone the behavior policy's actions.
        bc_logp = action_dist_t.logp(actions)
        actor_loss = (alpha.detach() * log_pis_t - bc_logp).mean()
        # actor_loss = -bc_logp.mean()

    if batch_size == policy.config["train_batch_size"]:
        policy.actor_optim.zero_grad()
        # Graph is reused below for the critic/CQL terms.
        actor_loss.backward(retain_graph=True)
        policy.actor_optim.step()

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss:
    # Q-values for the batched actions.
    action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1)
    action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model)
    policy_tp1, _ = action_dist_tp1.sample_logp()

    q_t, _ = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
    q_t_selected = torch.squeeze(q_t, dim=-1)
    if twin_q:
        twin_q_t, _ = model.get_twin_q_values(
            model_out_t, train_batch[SampleBatch.ACTIONS]
        )
        twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)

    # Target q network evaluation.
    q_tp1, _ = target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1, _ = target_model.get_twin_q_values(target_model_out_tp1, policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = torch.min(q_tp1, twin_q_tp1)

    q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
    # Zero out bootstrapped values for terminal next-states.
    q_tp1_best_masked = (1.0 - terminals.float()) * q_tp1_best

    # compute RHS of bellman equation
    q_t_target = (
        rewards + (discount ** policy.config["n_step"]) * q_tp1_best_masked
    ).detach()

    # Compute the TD-error (potentially clipped), for priority replay buffer
    base_td_error = torch.abs(q_t_selected - q_t_target)
    if twin_q:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss_1 = nn.functional.mse_loss(q_t_selected, q_t_target)
    if twin_q:
        critic_loss_2 = nn.functional.mse_loss(twin_q_t_selected, q_t_target)

    # CQL Loss (We are using Entropy version of CQL (the best version))
    # Uniform random actions over the action space ("mu" distribution).
    rand_actions = convert_to_torch_tensor(
        torch.FloatTensor(actions.shape[0] * num_actions, actions.shape[-1]).uniform_(
            action_low, action_high
        ),
        policy.device,
    )
    curr_actions, curr_logp = policy_actions_repeat(
        model, action_dist_class, model_out_t, num_actions
    )
    next_actions, next_logp = policy_actions_repeat(
        model, action_dist_class, model_out_tp1, num_actions
    )

    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(model, model_out_t, curr_actions, twin=True)
        q2_next_actions = q_values_repeat(model, model_out_t, next_actions, twin=True)

    # Log-density of the uniform random proposal (importance correction).
    random_density = np.log(0.5 ** curr_actions.shape[-1])
    cat_q1 = torch.cat(
        [
            q1_rand - random_density,
            q1_next_actions - next_logp.detach(),
            q1_curr_actions - curr_logp.detach(),
        ],
        1,
    )
    if twin_q:
        cat_q2 = torch.cat(
            [
                q2_rand - random_density,
                q2_next_actions - next_logp.detach(),
                q2_curr_actions - curr_logp.detach(),
            ],
            1,
        )

    # Conservative term: logsumexp over sampled Q-values minus data Q-values.
    min_qf1_loss_ = (
        torch.logsumexp(cat_q1 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
    )
    min_qf1_loss = min_qf1_loss_ - (q_t.mean() * min_q_weight)
    if twin_q:
        min_qf2_loss_ = (
            torch.logsumexp(cat_q2 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
        )
        min_qf2_loss = min_qf2_loss_ - (twin_q_t.mean() * min_q_weight)

    if use_lagrange:
        # Scale the conservative term by the (clamped) dual variable.
        alpha_prime = torch.clamp(model.log_alpha_prime.exp(), min=0.0, max=1000000.0)[
            0
        ]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss = [critic_loss_1 + min_qf1_loss]
    if twin_q:
        critic_loss.append(critic_loss_2 + min_qf2_loss)

    if batch_size == policy.config["train_batch_size"]:
        policy.critic_optims[0].zero_grad()
        critic_loss[0].backward(retain_graph=True)
        policy.critic_optims[0].step()

        if twin_q:
            policy.critic_optims[1].zero_grad()
            critic_loss[1].backward(retain_graph=False)
            policy.critic_optims[1].step()

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    # SAC stats.
    model.tower_stats["q_t"] = q_t_selected
    model.tower_stats["policy_t"] = policy_t
    model.tower_stats["log_pis_t"] = log_pis_t
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss
    model.tower_stats["log_alpha_value"] = model.log_alpha
    model.tower_stats["alpha_value"] = alpha
    model.tower_stats["target_entropy"] = model.target_entropy
    # CQL stats.
    model.tower_stats["cql_loss"] = cql_loss

    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    if use_lagrange:
        model.tower_stats["log_alpha_prime_value"] = model.log_alpha_prime[0]
        model.tower_stats["alpha_prime_value"] = alpha_prime
        model.tower_stats["alpha_prime_loss"] = alpha_prime_loss

        if batch_size == policy.config["train_batch_size"]:
            policy.alpha_prime_optim.zero_grad()
            alpha_prime_loss.backward()
            policy.alpha_prime_optim.step()

    # Return all loss terms corresponding to our optimizers.
    return tuple(
        [actor_loss]
        + critic_loss
        + [alpha_loss]
        + ([alpha_prime_loss] if use_lagrange else [])
    )
318
+
319
+
320
def cql_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Return SAC's training stats extended by the CQL-specific terms.

    Args:
        policy: The Policy to collect tower stats from.
        train_batch: The SampleBatch the stats refer to.

    Returns:
        A dict mapping stat names to (mean-reduced) tensors.
    """
    # Get SAC loss stats.
    stats_dict = stats(policy, train_batch)

    # Add CQL loss stats to the dict.
    # NOTE(review): `torch.stack(*...)` unpacks the per-tower list; since each
    # tower's "cql_loss" is itself a list of tensors, this only works with a
    # single tower - confirm multi-GPU behavior.
    stats_dict["cql_loss"] = torch.mean(
        torch.stack(*policy.get_tower_stats("cql_loss"))
    )

    if policy.config["lagrangian"]:
        stats_dict["log_alpha_prime_value"] = torch.mean(
            torch.stack(policy.get_tower_stats("log_alpha_prime_value"))
        )
        stats_dict["alpha_prime_value"] = torch.mean(
            torch.stack(policy.get_tower_stats("alpha_prime_value"))
        )
        stats_dict["alpha_prime_loss"] = torch.mean(
            torch.stack(policy.get_tower_stats("alpha_prime_loss"))
        )
    return stats_dict
340
+
341
+
342
def cql_optimizer_fn(
    policy: Policy, config: AlgorithmConfigDict
) -> Tuple[LocalOptimizer]:
    """Create SAC's optimizers plus - if Lagrangian - an alpha-prime optimizer.

    Also initializes `policy.cur_iter`, which the loss uses to decide when to
    switch from behavior cloning to the SAC actor loss.

    Args:
        policy: The Policy to create optimizers for.
        config: The Policy's config.

    Returns:
        A tuple of torch optimizers, one per loss term returned by `cql_loss`.
    """
    policy.cur_iter = 0
    opt_list = optimizer_fn(policy, config)
    if config["lagrangian"]:
        # Trainable dual variable log(alpha_prime); registered on the model so
        # `cql_setup_late_mixins` can move it to the policy's device.
        log_alpha_prime = nn.Parameter(torch.zeros(1, requires_grad=True).float())
        policy.model.register_parameter("log_alpha_prime", log_alpha_prime)
        policy.alpha_prime_optim = torch.optim.Adam(
            params=[policy.model.log_alpha_prime],
            lr=config["optimization"]["critic_learning_rate"],
            eps=1e-7,  # to match tf.keras.optimizers.Adam's epsilon default
        )
        return tuple(
            [policy.actor_optim]
            + policy.critic_optims
            + [policy.alpha_optim]
            + [policy.alpha_prime_optim]
        )
    return opt_list
362
+
363
+
364
def cql_setup_late_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
) -> None:
    """Run SAC's late mixin setup, then place `log_alpha_prime` on the device.

    Args:
        policy: The Policy object.
        obs_space: The Policy's observation space.
        action_space: The Policy's action space.
        config: The Policy's config.
    """
    setup_late_mixins(policy, obs_space, action_space, config)
    if config["lagrangian"]:
        # The dual variable was created on CPU in `cql_optimizer_fn`; move it
        # to the policy's compute device.
        policy.model.log_alpha_prime = policy.model.log_alpha_prime.to(policy.device)
373
+
374
+
375
def compute_gradients_fn(policy, postprocessed_batch):
    """Run the loss (which steps the optimizers itself) and return stats only.

    CQL's torch loss applies gradients internally via per-optimizer
    backward()/step() calls, so no gradients are returned here - only the
    learner stats dict.
    """
    batches = [policy._lazy_tensor_dict(postprocessed_batch)]
    model = policy.model
    policy._loss(policy, model, policy.dist_class, batches[0])
    stats = {LEARNER_STATS_KEY: policy._convert_to_numpy(cql_stats(policy, batches[0]))}
    # No gradients to hand back; optimization already happened in the loss.
    return [None, stats]
381
+
382
+
383
def apply_gradients_fn(policy, gradients):
    """No-op gradient application.

    Gradients are already applied inside the loss function via the
    per-optimizer backward()/step() calls, so there is nothing left to do.
    """
    return None
385
+
386
+
387
# Build a child class of `TorchPolicy`, given the custom functions defined
# above. Model and action distribution are shared with SAC; CQL swaps in its
# own loss, stats, optimizers, and (no-op) gradient handling.
CQLTorchPolicy = build_policy_class(
    name="CQLTorchPolicy",
    framework="torch",
    loss_fn=cql_loss,
    get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQLConfig(),
    stats_fn=cql_stats,
    postprocess_fn=postprocess_trajectory,
    extra_grad_process_fn=apply_grad_clipping,
    optimizer_fn=cql_optimizer_fn,
    validate_spaces=validate_spaces,
    before_loss_init=cql_setup_late_mixins,
    make_model_and_action_dist=build_sac_model_and_action_dist,
    extra_learn_fetches_fn=concat_multi_gpu_td_errors,
    mixins=[TargetNetworkMixin, ComputeTDErrorMixin],
    action_distribution_fn=action_distribution_fn,
    compute_gradients_fn=compute_gradients_fn,
    apply_gradients_fn=apply_gradients_fn,
)
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (203 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/cql_torch_learner.cpython-311.pyc ADDED
Binary file (9.88 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/default_cql_torch_rl_module.cpython-311.pyc ADDED
Binary file (8.66 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/cql_torch_learner.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ from ray.air.constants import TRAINING_ITERATION
4
+ from ray.rllib.algorithms.sac.sac_learner import (
5
+ LOGPS_KEY,
6
+ QF_LOSS_KEY,
7
+ QF_MEAN_KEY,
8
+ QF_MAX_KEY,
9
+ QF_MIN_KEY,
10
+ QF_PREDS,
11
+ QF_TWIN_LOSS_KEY,
12
+ QF_TWIN_PREDS,
13
+ TD_ERROR_MEAN_KEY,
14
+ )
15
+ from ray.rllib.algorithms.cql.cql import CQLConfig
16
+ from ray.rllib.algorithms.sac.torch.sac_torch_learner import SACTorchLearner
17
+ from ray.rllib.core.columns import Columns
18
+ from ray.rllib.core.learner.learner import (
19
+ POLICY_LOSS_KEY,
20
+ )
21
+ from ray.rllib.utils.annotations import override
22
+ from ray.rllib.utils.metrics import ALL_MODULES
23
+ from ray.rllib.utils.framework import try_import_torch
24
+ from ray.rllib.utils.typing import ModuleID, ParamDict, TensorType
25
+
26
+ torch, nn = try_import_torch()
27
+
28
+
29
+ class CQLTorchLearner(SACTorchLearner):
30
+ @override(SACTorchLearner)
31
+ def compute_loss_for_module(
32
+ self,
33
+ *,
34
+ module_id: ModuleID,
35
+ config: CQLConfig,
36
+ batch: Dict,
37
+ fwd_out: Dict[str, TensorType],
38
+ ) -> TensorType:
39
+
40
+ # TODO (simon, sven): Add upstream information pieces into this timesteps
41
+ # call arg to Learner.update_...().
42
+ self.metrics.log_value(
43
+ (ALL_MODULES, TRAINING_ITERATION),
44
+ 1,
45
+ reduce="sum",
46
+ )
47
+ # Get the train action distribution for the current policy and current state.
48
+ # This is needed for the policy (actor) loss and the `alpha`` loss.
49
+ action_dist_class = self.module[module_id].get_train_action_dist_cls()
50
+ action_dist_curr = action_dist_class.from_logits(
51
+ fwd_out[Columns.ACTION_DIST_INPUTS]
52
+ )
53
+
54
+ # Optimize also the hyperparameter `alpha` by using the current policy
55
+ # evaluated at the current state (from offline data). Note, in contrast
56
+ # to the original SAC loss, here the `alpha` and actor losses are
57
+ # calculated first.
58
+ # TODO (simon): Check, why log(alpha) is used, prob. just better
59
+ # to optimize and monotonic function. Original equation uses alpha.
60
+ alpha_loss = -torch.mean(
61
+ self.curr_log_alpha[module_id]
62
+ * (fwd_out["logp_resampled"].detach() + self.target_entropy[module_id])
63
+ )
64
+
65
+ # Get the current alpha.
66
+ alpha = torch.exp(self.curr_log_alpha[module_id])
67
+ # Start training with behavior cloning and turn to the classic Soft-Actor Critic
68
+ # after `bc_iters` of training iterations.
69
+ if (
70
+ self.metrics.peek((ALL_MODULES, TRAINING_ITERATION), default=0)
71
+ >= config.bc_iters
72
+ ):
73
+ actor_loss = torch.mean(
74
+ alpha.detach() * fwd_out["logp_resampled"] - fwd_out["q_curr"]
75
+ )
76
+ else:
77
+ # Use log-probabilities of the current action distribution to clone
78
+ # the behavior policy (selected actions in data) in the first `bc_iters`
79
+ # training iterations.
80
+ bc_logps_curr = action_dist_curr.logp(batch[Columns.ACTIONS])
81
+ actor_loss = torch.mean(
82
+ alpha.detach() * fwd_out["logp_resampled"] - bc_logps_curr
83
+ )
84
+
85
+ # The critic loss is composed of the standard SAC Critic L2 loss and the
86
+ # CQL entropy loss.
87
+
88
+ # Get the Q-values for the actually selected actions in the offline data.
89
+ # In the critic loss we use these as predictions.
90
+ q_selected = fwd_out[QF_PREDS]
91
+ if config.twin_q:
92
+ q_twin_selected = fwd_out[QF_TWIN_PREDS]
93
+
94
+ if not config.deterministic_backup:
95
+ q_next = (
96
+ fwd_out["q_target_next"]
97
+ - alpha.detach() * fwd_out["logp_next_resampled"]
98
+ )
99
+ else:
100
+ q_next = fwd_out["q_target_next"]
101
+
102
+ # Now mask all Q-values with terminating next states in the targets.
103
+ q_next_masked = (1.0 - batch[Columns.TERMINATEDS].float()) * q_next
104
+
105
+ # Compute the right hand side of the Bellman equation. Detach this node
106
+ # from the computation graph as we do not want to backpropagate through
107
+ # the target network when optimizing the Q loss.
108
+ # TODO (simon, sven): Kumar et al. (2020) use here also a reward scaler.
109
+ q_selected_target = (
110
+ # TODO (simon): Add an `n_step` option to the `AddNextObsToBatch` connector.
111
+ batch[Columns.REWARDS]
112
+ # TODO (simon): Implement n_step.
113
+ + (config.gamma) * q_next_masked
114
+ ).detach()
115
+
116
+ # Calculate the TD error.
117
+ td_error = torch.abs(q_selected - q_selected_target)
118
+ # Calculate a TD-error for twin-Q values, if needed.
119
+ if config.twin_q:
120
+ td_error += torch.abs(q_twin_selected - q_selected_target)
121
+ # Rescale the TD error
122
+ td_error *= 0.5
123
+
124
+ # MSBE loss for the critic(s) (i.e. Q, see eqs. (7-8) Haarnoja et al. (2018)).
125
+ # Note, this needs a sample from the current policy given the next state.
126
+ # Note further, we could also use here the Huber loss instead of the MSE.
127
+ # TODO (simon): Add the huber loss as an alternative (SAC uses it).
128
+ sac_critic_loss = torch.nn.MSELoss(reduction="mean")(
129
+ q_selected,
130
+ q_selected_target,
131
+ )
132
+ if config.twin_q:
133
+ sac_critic_twin_loss = torch.nn.MSELoss(reduction="mean")(
134
+ q_twin_selected,
135
+ q_selected_target,
136
+ )
137
+
138
+ # Now calculate the CQL loss (we use the entropy version of the CQL algorithm).
139
+ # Note, the entropy version performs best in shown experiments.
140
+
141
+ # Compute the log-probabilities for the random actions (note, we generate random
142
+ # actions (from the mu distribution as named in Kumar et al. (2020))).
143
+ # Note, all actions, action log-probabilities and Q-values are already computed
144
+ # by the module's `_forward_train` method.
145
+ # TODO (simon): This is the density for a discrete uniform, however, actions
146
+ # come from a continuous one. So actually this density should use (1/(high-low))
147
+ # instead of (1/2).
148
+ random_density = torch.log(
149
+ torch.pow(
150
+ 0.5,
151
+ torch.tensor(
152
+ fwd_out["actions_curr_repeat"].shape[-1],
153
+ device=fwd_out["actions_curr_repeat"].device,
154
+ ),
155
+ )
156
+ )
157
+ # Merge all Q-values and subtract the log-probabilities (note, we use the
158
+ # entropy version of CQL).
159
+ q_repeat = torch.cat(
160
+ [
161
+ fwd_out["q_rand_repeat"] - random_density,
162
+ fwd_out["q_next_repeat"] - fwd_out["logps_next_repeat"].detach(),
163
+ fwd_out["q_curr_repeat"] - fwd_out["logps_curr_repeat"].detach(),
164
+ ],
165
+ dim=1,
166
+ )
167
+ cql_loss = (
168
+ torch.logsumexp(q_repeat / config.temperature, dim=1).mean()
169
+ * config.min_q_weight
170
+ * config.temperature
171
+ )
172
+ cql_loss -= q_selected.mean() * config.min_q_weight
173
+ # Add the CQL loss term to the SAC loss term.
174
+ critic_loss = sac_critic_loss + cql_loss
175
+
176
+ # If a twin Q-value function is implemented calculated its CQL loss.
177
+ if config.twin_q:
178
+ q_twin_repeat = torch.cat(
179
+ [
180
+ fwd_out["q_twin_rand_repeat"] - random_density,
181
+ fwd_out["q_twin_next_repeat"]
182
+ - fwd_out["logps_next_repeat"].detach(),
183
+ fwd_out["q_twin_curr_repeat"]
184
+ - fwd_out["logps_curr_repeat"].detach(),
185
+ ],
186
+ dim=1,
187
+ )
188
+ cql_twin_loss = (
189
+ torch.logsumexp(q_twin_repeat / config.temperature, dim=1).mean()
190
+ * config.min_q_weight
191
+ * config.temperature
192
+ )
193
+ cql_twin_loss -= q_twin_selected.mean() * config.min_q_weight
194
+ # Add the CQL loss term to the SAC loss term.
195
+ critic_twin_loss = sac_critic_twin_loss + cql_twin_loss
196
+
197
+ # TODO (simon): Check, if we need to implement here also a Lagrangian
198
+ # loss.
199
+
200
+ total_loss = actor_loss + critic_loss + alpha_loss
201
+
202
+ # Add the twin critic loss to the total loss, if needed.
203
+ if config.twin_q:
204
+ # Reweigh the critic loss terms in the total loss.
205
+ total_loss += 0.5 * critic_twin_loss - 0.5 * critic_loss
206
+
207
+ # Log important loss stats (reduce=mean (default), but with window=1
208
+ # in order to keep them history free).
209
+ self.metrics.log_dict(
210
+ {
211
+ POLICY_LOSS_KEY: actor_loss,
212
+ QF_LOSS_KEY: critic_loss,
213
+ # TODO (simon): Add these keys to SAC Learner.
214
+ "cql_loss": cql_loss,
215
+ "alpha_loss": alpha_loss,
216
+ "alpha_value": alpha,
217
+ "log_alpha_value": torch.log(alpha),
218
+ "target_entropy": self.target_entropy[module_id],
219
+ LOGPS_KEY: torch.mean(
220
+ fwd_out["logp_resampled"]
221
+ ), # torch.mean(logps_curr),
222
+ QF_MEAN_KEY: torch.mean(fwd_out["q_curr_repeat"]),
223
+ QF_MAX_KEY: torch.max(fwd_out["q_curr_repeat"]),
224
+ QF_MIN_KEY: torch.min(fwd_out["q_curr_repeat"]),
225
+ TD_ERROR_MEAN_KEY: torch.mean(td_error),
226
+ },
227
+ key=module_id,
228
+ window=1, # <- single items (should not be mean/ema-reduced over time).
229
+ )
230
+ # TODO (simon): Add loss keys for langrangian, if needed.
231
+ # TODO (simon): Add only here then the Langrange parameter optimization.
232
+ if config.twin_q:
233
+ self.metrics.log_dict(
234
+ {
235
+ QF_TWIN_LOSS_KEY: critic_twin_loss,
236
+ },
237
+ key=module_id,
238
+ window=1, # <- single items (should not be mean/ema-reduced over time).
239
+ )
240
+
241
+ # Return the total loss.
242
+ return total_loss
243
+
244
+ @override(SACTorchLearner)
245
+ def compute_gradients(
246
+ self, loss_per_module: Dict[ModuleID, TensorType], **kwargs
247
+ ) -> ParamDict:
248
+
249
+ grads = {}
250
+ for module_id in set(loss_per_module.keys()) - {ALL_MODULES}:
251
+ # Loop through optimizers registered for this module.
252
+ for optim_name, optim in self.get_optimizers_for_module(module_id):
253
+ # Zero the gradients. Note, we need to reset the gradients b/c
254
+ # each component for a module operates on the same graph.
255
+ optim.zero_grad(set_to_none=True)
256
+
257
+ # Compute the gradients for the component and module.
258
+ self.metrics.peek((module_id, optim_name + "_loss")).backward(
259
+ retain_graph=False if optim_name in ["policy", "alpha"] else True
260
+ )
261
+ # Store the gradients for the component and module.
262
+ # TODO (simon): Check another time the graph for overlapping
263
+ # gradients.
264
+ grads.update(
265
+ {
266
+ pid: grads[pid] + p.grad.clone()
267
+ if pid in grads
268
+ else p.grad.clone()
269
+ for pid, p in self.filter_param_dict_for_optimizer(
270
+ self._params, optim
271
+ ).items()
272
+ }
273
+ )
274
+
275
+ return grads
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/default_cql_torch_rl_module.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tree
2
+ from typing import Any, Dict, Optional
3
+
4
+ from ray.rllib.algorithms.sac.sac_learner import (
5
+ QF_PREDS,
6
+ QF_TWIN_PREDS,
7
+ )
8
+ from ray.rllib.algorithms.sac.sac_catalog import SACCatalog
9
+ from ray.rllib.algorithms.sac.torch.default_sac_torch_rl_module import (
10
+ DefaultSACTorchRLModule,
11
+ )
12
+ from ray.rllib.core.columns import Columns
13
+ from ray.rllib.core.models.base import ENCODER_OUT
14
+ from ray.rllib.utils.annotations import override
15
+ from ray.rllib.utils.framework import try_import_torch
16
+ from ray.rllib.utils.typing import TensorType
17
+
18
+ torch, nn = try_import_torch()
19
+
20
+
21
+ class DefaultCQLTorchRLModule(DefaultSACTorchRLModule):
22
+ def __init__(self, *args, **kwargs):
23
+ catalog_class = kwargs.pop("catalog_class", None)
24
+ if catalog_class is None:
25
+ catalog_class = SACCatalog
26
+ super().__init__(*args, **kwargs, catalog_class=catalog_class)
27
+
28
+ @override(DefaultSACTorchRLModule)
29
+ def _forward_train(self, batch: Dict) -> Dict[str, Any]:
30
+ # Call the super method.
31
+ fwd_out = super()._forward_train(batch)
32
+
33
+ # Make sure we perform a "straight-through gradient" pass here,
34
+ # ignoring the gradients of the q-net, however, still recording
35
+ # the gradients of the policy net (which was used to rsample the actions used
36
+ # here). This is different from doing `.detach()` or `with torch.no_grads()`,
37
+ # as these two methds would fully block all gradient recordings, including
38
+ # the needed policy ones.
39
+ all_params = list(self.pi_encoder.parameters()) + list(self.pi.parameters())
40
+ # if self.twin_q:
41
+ # all_params += list(self.qf_twin.parameters()) + list(
42
+ # self.qf_twin_encoder.parameters()
43
+ # )
44
+
45
+ for param in all_params:
46
+ param.requires_grad = False
47
+
48
+ # Compute the repeated actions, action log-probabilites and Q-values for all
49
+ # observations.
50
+ # First for the random actions (from the mu-distribution as named by Kumar et
51
+ # al. (2020)).
52
+ low = torch.tensor(
53
+ self.action_space.low,
54
+ device=fwd_out[QF_PREDS].device,
55
+ )
56
+ high = torch.tensor(
57
+ self.action_space.high,
58
+ device=fwd_out[QF_PREDS].device,
59
+ )
60
+ num_samples = batch[Columns.ACTIONS].shape[0] * self.model_config["num_actions"]
61
+ actions_rand_repeat = low + (high - low) * torch.rand(
62
+ (num_samples, low.shape[0]), device=fwd_out[QF_PREDS].device
63
+ )
64
+
65
+ # First for the random actions (from the mu-distribution as named in Kumar
66
+ # et al. (2020)) using repeated observations.
67
+ rand_repeat_out = self._repeat_actions(batch[Columns.OBS], actions_rand_repeat)
68
+ (fwd_out["actions_rand_repeat"], fwd_out["q_rand_repeat"]) = (
69
+ rand_repeat_out[Columns.ACTIONS],
70
+ rand_repeat_out[QF_PREDS],
71
+ )
72
+ # Sample current and next actions (from the pi distribution as named in Kumar
73
+ # et al. (2020)) using repeated observations
74
+ # Second for the current observations and the current action distribution.
75
+ curr_repeat_out = self._repeat_actions(batch[Columns.OBS])
76
+ (
77
+ fwd_out["actions_curr_repeat"],
78
+ fwd_out["logps_curr_repeat"],
79
+ fwd_out["q_curr_repeat"],
80
+ ) = (
81
+ curr_repeat_out[Columns.ACTIONS],
82
+ curr_repeat_out[Columns.ACTION_LOGP],
83
+ curr_repeat_out[QF_PREDS],
84
+ )
85
+ # Then, for the next observations and the current action distribution.
86
+ next_repeat_out = self._repeat_actions(batch[Columns.NEXT_OBS])
87
+ (
88
+ fwd_out["actions_next_repeat"],
89
+ fwd_out["logps_next_repeat"],
90
+ fwd_out["q_next_repeat"],
91
+ ) = (
92
+ next_repeat_out[Columns.ACTIONS],
93
+ next_repeat_out[Columns.ACTION_LOGP],
94
+ next_repeat_out[QF_PREDS],
95
+ )
96
+ if self.twin_q:
97
+ # First for the random actions from the mu-distribution.
98
+ fwd_out["q_twin_rand_repeat"] = rand_repeat_out[QF_TWIN_PREDS]
99
+ # Second for the current observations and the current action distribution.
100
+ fwd_out["q_twin_curr_repeat"] = curr_repeat_out[QF_TWIN_PREDS]
101
+ # Then, for the next observations and the current action distribution.
102
+ fwd_out["q_twin_next_repeat"] = next_repeat_out[QF_TWIN_PREDS]
103
+ # Reset the gradient requirements for all Q-function parameters.
104
+ for param in all_params:
105
+ param.requires_grad = True
106
+
107
+ return fwd_out
108
+
109
+ def _repeat_tensor(self, tensor: TensorType, repeat: int) -> TensorType:
110
+ """Generates a repeated version of a tensor.
111
+
112
+ The repetition is done similar `np.repeat` and repeats each value
113
+ instead of the complete vector.
114
+
115
+ Args:
116
+ tensor: The tensor to be repeated.
117
+ repeat: How often each value in the tensor should be repeated.
118
+
119
+ Returns:
120
+ A tensor holding `repeat` repeated values of the input `tensor`
121
+ """
122
+ # Insert the new dimension at axis 1 into the tensor.
123
+ t_repeat = tensor.unsqueeze(1)
124
+ # Repeat the tensor along the new dimension.
125
+ t_repeat = torch.repeat_interleave(t_repeat, repeat, dim=1)
126
+ # Stack the repeated values into the batch dimension.
127
+ t_repeat = t_repeat.view(-1, *tensor.shape[1:])
128
+ # Return the repeated tensor.
129
+ return t_repeat
130
+
131
+ def _repeat_actions(
132
+ self, obs: TensorType, actions: Optional[TensorType] = None
133
+ ) -> Dict[str, TensorType]:
134
+ """Generated actions and Q-values for repeated observations.
135
+
136
+ The `self.model_config["num_actions"]` define a multiplier
137
+ used for generating `num_actions` as many actions as the batch size.
138
+ Observations are repeated and then a model forward pass is made.
139
+
140
+ Args:
141
+ obs: A batched observation tensor.
142
+ actions: An optional batched actions tensor.
143
+
144
+ Returns:
145
+ A dictionary holding the (sampled or passed-in actions), the log
146
+ probabilities (of sampled actions), the Q-values and if available
147
+ the twin-Q values.
148
+ """
149
+ output = {}
150
+ # Receive the batch size.
151
+ batch_size = obs.shape[0]
152
+ # Receive the number of action to sample.
153
+ num_actions = self.model_config["num_actions"]
154
+ # Repeat the observations `num_actions` times.
155
+ obs_repeat = tree.map_structure(
156
+ lambda t: self._repeat_tensor(t, num_actions), obs
157
+ )
158
+ # Generate a batch for the forward pass.
159
+ temp_batch = {Columns.OBS: obs_repeat}
160
+ if actions is None:
161
+ # TODO (simon): Run the forward pass in inference mode.
162
+ # Compute the action logits.
163
+ pi_encoder_outs = self.pi_encoder(temp_batch)
164
+ action_logits = self.pi(pi_encoder_outs[ENCODER_OUT])
165
+ # Generate the squashed Gaussian from the model's logits.
166
+ action_dist = self.get_train_action_dist_cls().from_logits(action_logits)
167
+ # Sample the actions. Note, we want to make a backward pass through
168
+ # these actions.
169
+ output[Columns.ACTIONS] = action_dist.rsample()
170
+ # Compute the action log-probabilities.
171
+ output[Columns.ACTION_LOGP] = action_dist.logp(
172
+ output[Columns.ACTIONS]
173
+ ).view(batch_size, num_actions, 1)
174
+ else:
175
+ output[Columns.ACTIONS] = actions
176
+
177
+ # Compute all Q-values.
178
+ temp_batch.update(
179
+ {
180
+ Columns.ACTIONS: output[Columns.ACTIONS],
181
+ }
182
+ )
183
+ output.update(
184
+ {
185
+ QF_PREDS: self._qf_forward_train_helper(
186
+ temp_batch,
187
+ self.qf_encoder,
188
+ self.qf,
189
+ ).view(batch_size, num_actions, 1)
190
+ }
191
+ )
192
+ # If we have a twin-Q network, compute its Q-values, too.
193
+ if self.twin_q:
194
+ output.update(
195
+ {
196
+ QF_TWIN_PREDS: self._qf_forward_train_helper(
197
+ temp_batch,
198
+ self.qf_twin_encoder,
199
+ self.qf_twin,
200
+ ).view(batch_size, num_actions, 1)
201
+ }
202
+ )
203
+ del temp_batch
204
+
205
+ # Return
206
+ return output
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ [1] Mastering Diverse Domains through World Models - 2023
3
+ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
4
+ https://arxiv.org/pdf/2301.04104v1.pdf
5
+
6
+ [2] Mastering Atari with Discrete World Models - 2021
7
+ D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
8
+ https://arxiv.org/pdf/2010.02193.pdf
9
+ """
10
+ from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3, DreamerV3Config
11
+
12
+ __all__ = [
13
+ "DreamerV3",
14
+ "DreamerV3Config",
15
+ ]
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3.py ADDED
@@ -0,0 +1,750 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ [1] Mastering Diverse Domains through World Models - 2023
3
+ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
4
+ https://arxiv.org/pdf/2301.04104v1.pdf
5
+
6
+ [2] Mastering Atari with Discrete World Models - 2021
7
+ D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
8
+ https://arxiv.org/pdf/2010.02193.pdf
9
+ """
10
+
11
+ import gc
12
+ import logging
13
+ from typing import Any, Dict, Optional, Union
14
+
15
+ import gymnasium as gym
16
+
17
+ from ray.rllib.algorithms.algorithm import Algorithm
18
+ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
19
+ from ray.rllib.algorithms.dreamerv3.dreamerv3_catalog import DreamerV3Catalog
20
+ from ray.rllib.algorithms.dreamerv3.utils import do_symlog_obs
21
+ from ray.rllib.algorithms.dreamerv3.utils.env_runner import DreamerV3EnvRunner
22
+ from ray.rllib.algorithms.dreamerv3.utils.summaries import (
23
+ report_dreamed_eval_trajectory_vs_samples,
24
+ report_predicted_vs_sampled_obs,
25
+ report_sampling_and_replay_buffer,
26
+ )
27
+ from ray.rllib.core import DEFAULT_MODULE_ID
28
+ from ray.rllib.core.columns import Columns
29
+ from ray.rllib.core.rl_module.rl_module import RLModuleSpec
30
+ from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
31
+ from ray.rllib.policy.sample_batch import SampleBatch
32
+ from ray.rllib.utils import deep_update
33
+ from ray.rllib.utils.annotations import override, PublicAPI
34
+ from ray.rllib.utils.framework import try_import_tf
35
+ from ray.rllib.utils.numpy import one_hot
36
+ from ray.rllib.utils.metrics import (
37
+ ENV_RUNNER_RESULTS,
38
+ GARBAGE_COLLECTION_TIMER,
39
+ LEARN_ON_BATCH_TIMER,
40
+ LEARNER_RESULTS,
41
+ NUM_ENV_STEPS_SAMPLED_LIFETIME,
42
+ NUM_ENV_STEPS_TRAINED_LIFETIME,
43
+ NUM_GRAD_UPDATES_LIFETIME,
44
+ NUM_SYNCH_WORKER_WEIGHTS,
45
+ SAMPLE_TIMER,
46
+ SYNCH_WORKER_WEIGHTS_TIMER,
47
+ TIMERS,
48
+ )
49
+ from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer
50
+ from ray.rllib.utils.typing import LearningRateOrSchedule
51
+
52
+
53
+ logger = logging.getLogger(__name__)
54
+
55
+ _, tf, _ = try_import_tf()
56
+
57
+
58
+ class DreamerV3Config(AlgorithmConfig):
59
+ """Defines a configuration class from which a DreamerV3 can be built.
60
+
61
+ .. testcode::
62
+
63
+ from ray.rllib.algorithms.dreamerv3 import DreamerV3Config
64
+ config = (
65
+ DreamerV3Config()
66
+ .environment("CartPole-v1")
67
+ .training(
68
+ model_size="XS",
69
+ training_ratio=1,
70
+ # TODO
71
+ model={
72
+ "batch_size_B": 1,
73
+ "batch_length_T": 1,
74
+ "horizon_H": 1,
75
+ "gamma": 0.997,
76
+ "model_size": "XS",
77
+ },
78
+ )
79
+ )
80
+
81
+ config = config.learners(num_learners=0)
82
+ # Build a Algorithm object from the config and run 1 training iteration.
83
+ algo = config.build()
84
+ # algo.train()
85
+ del algo
86
+
87
+ .. testoutput::
88
+ :hide:
89
+
90
+ ...
91
+ """
92
+
93
    def __init__(self, algo_class=None):
        """Initializes a DreamerV3Config instance."""
        super().__init__(algo_class=algo_class or DreamerV3)

        # fmt: off
        # __sphinx_doc_begin__

        # DreamerV3 specific settings:
        self.model_size = "XS"
        self.training_ratio = 1024

        # Replay buffer; the paper ([1]) uses a capacity of 1M for all benchmarks.
        self.replay_buffer_config = {
            "type": "EpisodeReplayBuffer",
            "capacity": int(1e6),
        }
        # Learning rates for the three separate optimizers (world model, actor,
        # critic).
        self.world_model_lr = 1e-4
        self.actor_lr = 3e-5
        self.critic_lr = 3e-5
        # Train-batch dimensions: B rows of length T each.
        self.batch_size_B = 16
        self.batch_length_T = 64
        # Dream-rollout horizon for actor/critic training.
        self.horizon_H = 15
        self.gae_lambda = 0.95  # [1] eq. 7.
        self.entropy_scale = 3e-4  # [1] eq. 11.
        self.return_normalization_decay = 0.99  # [1] eq. 11 and 12.
        self.train_critic = True
        self.train_actor = True
        self.intrinsic_rewards_scale = 0.1
        # Per-optimizer gradient clipping (by global norm).
        self.world_model_grad_clip_by_global_norm = 1000.0
        self.critic_grad_clip_by_global_norm = 100.0
        self.actor_grad_clip_by_global_norm = 100.0
        # "auto" -> symlog observations unless they come from an image space.
        self.symlog_obs = "auto"
        self.use_float16 = False
        self.use_curiosity = False

        # Reporting.
        # DreamerV3 is super sample efficient and only needs very few episodes
        # (normally) to learn. Leaving this at its default value would gravely
        # underestimate the learning performance over the course of an experiment.
        self.metrics_num_episodes_for_smoothing = 1
        self.report_individual_batch_item_stats = False
        self.report_dream_data = False
        self.report_images_and_videos = False
        self.gc_frequency_train_steps = 100

        # Override some of AlgorithmConfig's default values with DreamerV3-specific
        # values.
        self.lr = None
        self.framework_str = "tf2"
        self.gamma = 0.997  # [1] eq. 7.
        # Do not use! Set `batch_size_B` and `batch_length_T` instead.
        self.train_batch_size = None
        self.env_runner_cls = DreamerV3EnvRunner
        self.num_env_runners = 0
        self.rollout_fragment_length = 1
        # Dreamer only runs on the new API stack.
        self.enable_rl_module_and_learner = True
        self.enable_env_runner_and_connector_v2 = True
        # TODO (sven): DreamerV3 still uses its own EnvRunner class. This env-runner
        #  does not use connectors. We therefore should not attempt to merge/broadcast
        #  the connector states between EnvRunners (if >0). Note that this is only
        #  relevant if num_env_runners > 0, which is normally not the case when using
        #  this algo.
        self.use_worker_filter_stats = False
        # __sphinx_doc_end__
        # fmt: on
158
+
159
+ @property
160
+ def batch_size_B_per_learner(self):
161
+ """Returns the batch_size_B per Learner worker.
162
+
163
+ Needed by some of the DreamerV3 loss math."""
164
+ return self.batch_size_B // (self.num_learners or 1)
165
+
166
    @override(AlgorithmConfig)
    def training(
        self,
        *,
        model_size: Optional[str] = NotProvided,
        training_ratio: Optional[float] = NotProvided,
        gc_frequency_train_steps: Optional[int] = NotProvided,
        batch_size_B: Optional[int] = NotProvided,
        batch_length_T: Optional[int] = NotProvided,
        horizon_H: Optional[int] = NotProvided,
        gae_lambda: Optional[float] = NotProvided,
        entropy_scale: Optional[float] = NotProvided,
        return_normalization_decay: Optional[float] = NotProvided,
        train_critic: Optional[bool] = NotProvided,
        train_actor: Optional[bool] = NotProvided,
        intrinsic_rewards_scale: Optional[float] = NotProvided,
        world_model_lr: Optional[LearningRateOrSchedule] = NotProvided,
        actor_lr: Optional[LearningRateOrSchedule] = NotProvided,
        critic_lr: Optional[LearningRateOrSchedule] = NotProvided,
        world_model_grad_clip_by_global_norm: Optional[float] = NotProvided,
        critic_grad_clip_by_global_norm: Optional[float] = NotProvided,
        actor_grad_clip_by_global_norm: Optional[float] = NotProvided,
        symlog_obs: Optional[Union[bool, str]] = NotProvided,
        use_float16: Optional[bool] = NotProvided,
        replay_buffer_config: Optional[dict] = NotProvided,
        use_curiosity: Optional[bool] = NotProvided,
        **kwargs,
    ) -> "DreamerV3Config":
        """Sets the training related configuration.

        Args:
            model_size: The main switch for adjusting the overall model size. See [1]
                (table B) for more information on the effects of this setting on the
                model architecture.
                Supported values are "XS", "S", "M", "L", "XL" (as per the paper), as
                well as, "nano", "micro", "mini", and "XXS" (for RLlib's
                implementation). See ray.rllib.algorithms.dreamerv3.utils.
                __init__.py for the details on what exactly each size does to the layer
                sizes, number of layers, etc..
            training_ratio: The ratio of total steps trained (sum of the sizes of all
                batches ever sampled from the replay buffer) over the total env steps
                taken (in the actual environment, not the dreamed one). For example,
                if the training_ratio is 1024 and the batch size is 1024, we would take
                1 env step for every training update: 1024 / 1. If the training ratio
                is 512 and the batch size is 1024, we would take 2 env steps and then
                perform a single training update (on a 1024 batch): 1024 / 2.
            gc_frequency_train_steps: The frequency (in training iterations) with which
                we perform a `gc.collect()` calls at the end of a `training_step`
                iteration. Doing this more often adds a (albeit very small) performance
                overhead, but prevents memory leaks from becoming harmful.
                TODO (sven): This might not be necessary anymore, but needs to be
                confirmed experimentally.
            batch_size_B: The batch size (B) interpreted as number of rows (each of
                length `batch_length_T`) to sample from the replay buffer in each
                iteration.
            batch_length_T: The batch length (T) interpreted as the length of each row
                sampled from the replay buffer in each iteration. Note that
                `batch_size_B` rows will be sampled in each iteration. Rows normally
                contain consecutive data (consecutive timesteps from the same episode),
                but there might be episode boundaries in a row as well.
            horizon_H: The horizon (in timesteps) used to create dreamed data from the
                world model, which in turn is used to train/update both actor- and
                critic networks.
            gae_lambda: The lambda parameter used for computing the GAE-style
                value targets for the actor- and critic losses.
            entropy_scale: The factor with which to multiply the entropy loss term
                inside the actor loss.
            return_normalization_decay: The decay value to use when computing the
                running EMA values for return normalization (used in the actor loss).
            train_critic: Whether to train the critic network. If False, `train_actor`
                must also be False (cannot train actor w/o training the critic).
            train_actor: Whether to train the actor network. If True, `train_critic`
                must also be True (cannot train actor w/o training the critic).
            intrinsic_rewards_scale: The factor to multiply intrinsic rewards with
                before adding them to the extrinsic (environment) rewards.
            world_model_lr: The learning rate or schedule for the world model optimizer.
            actor_lr: The learning rate or schedule for the actor optimizer.
            critic_lr: The learning rate or schedule for the critic optimizer.
            world_model_grad_clip_by_global_norm: World model grad clipping value
                (by global norm).
            critic_grad_clip_by_global_norm: Critic grad clipping value
                (by global norm).
            actor_grad_clip_by_global_norm: Actor grad clipping value (by global norm).
            symlog_obs: Whether to symlog observations or not. If set to "auto"
                (default), will check for the environment's observation space and then
                only symlog if not an image space.
            use_float16: Whether to train with mixed float16 precision. In this mode,
                model parameters are stored as float32, but all computations are
                performed in float16 space (except for losses and distribution params
                and outputs).
            replay_buffer_config: Replay buffer config.
                Only serves in DreamerV3 to set the capacity of the replay buffer.
                Note though that in the paper ([1]) a size of 1M is used for all
                benchmarks and there doesn't seem to be a good reason to change this
                parameter.
                Examples:
                {
                "type": "EpisodeReplayBuffer",
                "capacity": 100000,
                }
            use_curiosity: Not supported yet; passing anything other than the
                default raises a ValueError.

        Returns:
            This updated AlgorithmConfig object.

        Raises:
            ValueError: If `use_curiosity` is explicitly provided (feature is
                not fully supported/tested yet).
        """
        # Not fully supported/tested yet.
        if use_curiosity is not NotProvided:
            raise ValueError(
                "`DreamerV3Config.curiosity` is not fully supported and tested yet! "
                "It thus remains disabled for now."
            )

        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        # Only overwrite settings that were explicitly provided (NotProvided
        # sentinel keeps the existing value untouched).
        if model_size is not NotProvided:
            self.model_size = model_size
        if training_ratio is not NotProvided:
            self.training_ratio = training_ratio
        if gc_frequency_train_steps is not NotProvided:
            self.gc_frequency_train_steps = gc_frequency_train_steps
        if batch_size_B is not NotProvided:
            self.batch_size_B = batch_size_B
        if batch_length_T is not NotProvided:
            self.batch_length_T = batch_length_T
        if horizon_H is not NotProvided:
            self.horizon_H = horizon_H
        if gae_lambda is not NotProvided:
            self.gae_lambda = gae_lambda
        if entropy_scale is not NotProvided:
            self.entropy_scale = entropy_scale
        if return_normalization_decay is not NotProvided:
            self.return_normalization_decay = return_normalization_decay
        if train_critic is not NotProvided:
            self.train_critic = train_critic
        if train_actor is not NotProvided:
            self.train_actor = train_actor
        if intrinsic_rewards_scale is not NotProvided:
            self.intrinsic_rewards_scale = intrinsic_rewards_scale
        if world_model_lr is not NotProvided:
            self.world_model_lr = world_model_lr
        if actor_lr is not NotProvided:
            self.actor_lr = actor_lr
        if critic_lr is not NotProvided:
            self.critic_lr = critic_lr
        if world_model_grad_clip_by_global_norm is not NotProvided:
            self.world_model_grad_clip_by_global_norm = (
                world_model_grad_clip_by_global_norm
            )
        if critic_grad_clip_by_global_norm is not NotProvided:
            self.critic_grad_clip_by_global_norm = critic_grad_clip_by_global_norm
        if actor_grad_clip_by_global_norm is not NotProvided:
            self.actor_grad_clip_by_global_norm = actor_grad_clip_by_global_norm
        if symlog_obs is not NotProvided:
            self.symlog_obs = symlog_obs
        if use_float16 is not NotProvided:
            self.use_float16 = use_float16
        if replay_buffer_config is not NotProvided:
            # Override entire `replay_buffer_config` if `type` key changes.
            # Update, if `type` key remains the same or is not specified.
            new_replay_buffer_config = deep_update(
                {"replay_buffer_config": self.replay_buffer_config},
                {"replay_buffer_config": replay_buffer_config},
                False,
                ["replay_buffer_config"],
                ["replay_buffer_config"],
            )
            self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"]

        return self
335
+
336
    @override(AlgorithmConfig)
    def reporting(
        self,
        *,
        report_individual_batch_item_stats: Optional[bool] = NotProvided,
        report_dream_data: Optional[bool] = NotProvided,
        report_images_and_videos: Optional[bool] = NotProvided,
        **kwargs,
    ):
        """Sets the reporting related configuration.

        Args:
            report_individual_batch_item_stats: Whether to include loss and other stats
                per individual timestep inside the training batch in the result dict
                returned by `training_step()`. If True, besides the `CRITIC_L_total`,
                the individual critic loss values per batch row and time axis step
                in the train batch (CRITIC_L_total_B_T) will also be part of the
                results.
            report_dream_data: Whether to include the dreamed trajectory data in the
                result dict returned by `training_step()`. If True, however, will
                slice each reported item in the dream data down to the shape.
                (H, B, t=0, ...), where H is the horizon and B is the batch size. The
                original time axis will only be represented by the first timestep
                to not make this data too large to handle.
            report_images_and_videos: Whether to include any image/video data in the
                result dict returned by `training_step()`.
            **kwargs: Forwarded to `AlgorithmConfig.reporting()`.

        Returns:
            This updated AlgorithmConfig object.
        """
        super().reporting(**kwargs)

        # Only overwrite settings that were explicitly provided.
        if report_individual_batch_item_stats is not NotProvided:
            self.report_individual_batch_item_stats = report_individual_batch_item_stats
        if report_dream_data is not NotProvided:
            self.report_dream_data = report_dream_data
        if report_images_and_videos is not NotProvided:
            self.report_images_and_videos = report_images_and_videos

        return self
377
+
378
    @override(AlgorithmConfig)
    def validate(self) -> None:
        """Validates this config; raises via `self._value_error` on any violation."""
        # Call the super class' validation method first.
        super().validate()

        # Make sure users are not using DreamerV3 for multi-agent yet:
        if self.is_multi_agent:
            self._value_error("DreamerV3 does NOT support multi-agent setups yet!")

        # Make sure we are configured for the new API stack.
        if not self.enable_rl_module_and_learner:
            self._value_error(
                "DreamerV3 must be run with `config.api_stack("
                "enable_rl_module_and_learner=True)`!"
            )

        # If run on several Learners, the provided batch_size_B must be a multiple
        # of `num_learners`.
        if self.num_learners > 1 and (self.batch_size_B % self.num_learners != 0):
            self._value_error(
                f"Your `batch_size_B` ({self.batch_size_B}) must be a multiple of "
                f"`num_learners` ({self.num_learners}) in order for "
                "DreamerV3 to be able to split batches evenly across your Learner "
                "processes."
            )

        # Cannot train actor w/o critic.
        if self.train_actor and not self.train_critic:
            self._value_error(
                "Cannot train actor network (`train_actor=True`) w/o training critic! "
                "Make sure you either set `train_critic=True` or `train_actor=False`."
            )
        # Use DreamerV3 specific batch size settings.
        if self.train_batch_size is not None:
            self._value_error(
                "`train_batch_size` should NOT be set! Use `batch_size_B` and "
                "`batch_length_T` instead."
            )
        # Must be run with `EpisodeReplayBuffer` type.
        if self.replay_buffer_config.get("type") != "EpisodeReplayBuffer":
            self._value_error(
                "DreamerV3 must be run with the `EpisodeReplayBuffer` type! None "
                "other supported."
            )
422
+
423
+ @override(AlgorithmConfig)
424
+ def get_default_learner_class(self):
425
+ if self.framework_str == "tf2":
426
+ from ray.rllib.algorithms.dreamerv3.tf.dreamerv3_tf_learner import (
427
+ DreamerV3TfLearner,
428
+ )
429
+
430
+ return DreamerV3TfLearner
431
+ else:
432
+ raise ValueError(f"The framework {self.framework_str} is not supported.")
433
+
434
+ @override(AlgorithmConfig)
435
+ def get_default_rl_module_spec(self) -> RLModuleSpec:
436
+ if self.framework_str == "tf2":
437
+ from ray.rllib.algorithms.dreamerv3.tf.dreamerv3_tf_rl_module import (
438
+ DreamerV3TfRLModule,
439
+ )
440
+
441
+ return RLModuleSpec(
442
+ module_class=DreamerV3TfRLModule, catalog_class=DreamerV3Catalog
443
+ )
444
+ else:
445
+ raise ValueError(f"The framework {self.framework_str} is not supported.")
446
+
447
+ @property
448
+ def share_module_between_env_runner_and_learner(self) -> bool:
449
+ # If we only have one local Learner (num_learners=0) and only
450
+ # one local EnvRunner (num_env_runners=0), share the RLModule
451
+ # between these two to avoid having to sync weights, ever.
452
+ return self.num_learners == 0 and self.num_env_runners == 0
453
+
454
+ @property
455
+ @override(AlgorithmConfig)
456
+ def _model_config_auto_includes(self) -> Dict[str, Any]:
457
+ return super()._model_config_auto_includes | {
458
+ "gamma": self.gamma,
459
+ "horizon_H": self.horizon_H,
460
+ "model_size": self.model_size,
461
+ "symlog_obs": self.symlog_obs,
462
+ "use_float16": self.use_float16,
463
+ "batch_length_T": self.batch_length_T,
464
+ }
465
+
466
+
467
class DreamerV3(Algorithm):
    """Implementation of the model-based DreamerV3 RL algorithm described in [1]."""

    # TODO (sven): Deprecate/do-over the Algorithm.compute_single_action() API.
    @override(Algorithm)
    def compute_single_action(self, *args, **kwargs):
        """Not supported by DreamerV3; always raises NotImplementedError."""
        raise NotImplementedError(
            "DreamerV3 does not support the `compute_single_action()` API. Refer to the"
            " README here (https://github.com/ray-project/ray/tree/master/rllib/"
            "algorithms/dreamerv3) to find more information on how to run action "
            "inference with this algorithm."
        )

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfig:
        """Returns a fresh DreamerV3Config as this algorithm's default config."""
        return DreamerV3Config()

    @override(Algorithm)
    def setup(self, config: AlgorithmConfig):
        """Sets up module sharing (if enabled) and the episode replay buffer."""
        super().setup(config)

        # Share RLModule between EnvRunner and single (local) Learner instance.
        # To avoid possibly expensive weight synching step.
        if self.config.share_module_between_env_runner_and_learner:
            assert self.env_runner.module is None
            self.env_runner.module = self.learner_group._learner.module[
                DEFAULT_MODULE_ID
            ]

        # Summarize (single-agent) RLModule (only once) here.
        if self.config.framework_str == "tf2":
            self.env_runner.module.dreamer_model.summary(expand_nested=True)

        # Create a replay buffer for storing actual env samples.
        self.replay_buffer = EpisodeReplayBuffer(
            capacity=self.config.replay_buffer_config["capacity"],
            batch_size_B=self.config.batch_size_B,
            batch_length_T=self.config.batch_length_T,
        )

    @override(Algorithm)
    def training_step(self) -> None:
        """Performs one DreamerV3 training iteration: sample, replay-train, sync."""
        # Push enough samples into buffer initially before we start training.
        if self.training_iteration == 0:
            logger.info(
                "Filling replay buffer so it contains at least "
                f"{self.config.batch_size_B * self.config.batch_length_T} timesteps "
                "(required for a single train batch)."
            )

        # Have we sampled yet in this `training_step()` call?
        have_sampled = False
        with self.metrics.log_time((TIMERS, SAMPLE_TIMER)):
            # Continue sampling from the actual environment (and add collected samples
            # to our replay buffer) as long as we:
            while (
                # a) Don't have at least batch_size_B x batch_length_T timesteps stored
                # in the buffer. This is the minimum needed to train.
                self.replay_buffer.get_num_timesteps()
                < (self.config.batch_size_B * self.config.batch_length_T)
                # b) The computed `training_ratio` is >= the configured (desired)
                # training ratio (meaning we should continue sampling).
                or self.training_ratio >= self.config.training_ratio
                # c) we have not sampled at all yet in this `training_step()` call.
                or not have_sampled
            ):
                # Sample using the env runner's module.
                episodes, env_runner_results = synchronous_parallel_sample(
                    worker_set=self.env_runner_group,
                    max_agent_steps=(
                        self.config.rollout_fragment_length
                        * self.config.num_envs_per_env_runner
                    ),
                    sample_timeout_s=self.config.sample_timeout_s,
                    _uses_new_env_runners=True,
                    _return_metrics=True,
                )
                self.metrics.merge_and_log_n_dicts(
                    env_runner_results, key=ENV_RUNNER_RESULTS
                )
                # Add ongoing and finished episodes into buffer. The buffer will
                # automatically take care of properly concatenating (by episode IDs)
                # the different chunks of the same episodes, even if they come in via
                # separate `add()` calls.
                self.replay_buffer.add(episodes=episodes)
                have_sampled = True

                # We took B x T env steps.
                env_steps_last_regular_sample = sum(len(eps) for eps in episodes)
                total_sampled = env_steps_last_regular_sample

                # If we have never sampled before (just started the algo and not
                # recovered from a checkpoint), sample B random actions first.
                if (
                    self.metrics.peek(
                        (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME),
                        default=0,
                    )
                    == 0
                ):
                    _episodes, _env_runner_results = synchronous_parallel_sample(
                        worker_set=self.env_runner_group,
                        max_agent_steps=(
                            self.config.batch_size_B * self.config.batch_length_T
                            - env_steps_last_regular_sample
                        ),
                        sample_timeout_s=self.config.sample_timeout_s,
                        random_actions=True,
                        _uses_new_env_runners=True,
                        _return_metrics=True,
                    )
                    self.metrics.merge_and_log_n_dicts(
                        _env_runner_results, key=ENV_RUNNER_RESULTS
                    )
                    self.replay_buffer.add(episodes=_episodes)
                    total_sampled += sum(len(eps) for eps in _episodes)

        # Summarize environment interaction and buffer data.
        report_sampling_and_replay_buffer(
            metrics=self.metrics, replay_buffer=self.replay_buffer
        )

        # Continue sampling batch_size_B x batch_length_T sized batches from the buffer
        # and using these to update our models (`LearnerGroup.update_from_batch()`)
        # until the computed `training_ratio` is larger than the configured one, meaning
        # we should go back and collect more samples again from the actual environment.
        # However, when calculating the `training_ratio` here, we use only the
        # trained steps in this very `training_step()` call over the most recent sample
        # amount (`env_steps_last_regular_sample`), not the global values. This is to
        # avoid a heavy overtraining at the very beginning when we have just pre-filled
        # the buffer with the minimum amount of samples.
        replayed_steps_this_iter = sub_iter = 0
        while (
            replayed_steps_this_iter / env_steps_last_regular_sample
        ) < self.config.training_ratio:
            # Time individual batch updates.
            with self.metrics.log_time((TIMERS, LEARN_ON_BATCH_TIMER)):
                logger.info(f"\tSub-iteration {self.training_iteration}/{sub_iter})")

                # Draw a new sample from the replay buffer.
                sample = self.replay_buffer.sample(
                    batch_size_B=self.config.batch_size_B,
                    batch_length_T=self.config.batch_length_T,
                )
                replayed_steps = self.config.batch_size_B * self.config.batch_length_T
                replayed_steps_this_iter += replayed_steps

                # Discrete actions: keep the int version around under
                # "actions_ints" and one-hot encode the actions column.
                if isinstance(
                    self.env_runner.env.single_action_space, gym.spaces.Discrete
                ):
                    sample["actions_ints"] = sample[Columns.ACTIONS]
                    sample[Columns.ACTIONS] = one_hot(
                        sample["actions_ints"],
                        depth=self.env_runner.env.single_action_space.n,
                    )

                # Perform the actual update via our learner group.
                learner_results = self.learner_group.update_from_batch(
                    batch=SampleBatch(sample).as_multi_agent(),
                    # TODO(sven): Maybe we should do this broadcast of global timesteps
                    # at the end, like for EnvRunner global env step counts. Maybe when
                    # we request the state from the Learners, we can - at the same
                    # time - send the current globally summed/reduced-timesteps.
                    timesteps={
                        NUM_ENV_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
                            (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME),
                            default=0,
                        )
                    },
                )
                self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS)

                sub_iter += 1
                self.metrics.log_value(NUM_GRAD_UPDATES_LIFETIME, 1, reduce="sum")

        # Log videos showing how the decoder produces observation predictions
        # from the posterior states.
        # Only every n iterations and only for the first sampled batch row
        # (videos are `config.batch_length_T` frames long).
        report_predicted_vs_sampled_obs(
            # TODO (sven): DreamerV3 is single-agent only.
            metrics=self.metrics,
            sample=sample,
            batch_size_B=self.config.batch_size_B,
            batch_length_T=self.config.batch_length_T,
            symlog_obs=do_symlog_obs(
                self.env_runner.env.single_observation_space,
                self.config.symlog_obs,
            ),
            do_report=(
                self.config.report_images_and_videos
                and self.training_iteration % 100 == 0
            ),
        )

        # Log videos showing some of the dreamed trajectories and compare them with the
        # actual trajectories from the train batch.
        # Only every n iterations and only for the first sampled batch row AND first ts.
        # (videos are `config.horizon_H` frames long originating from the observation
        # at B=0 and T=0 in the train batch).
        report_dreamed_eval_trajectory_vs_samples(
            metrics=self.metrics,
            sample=sample,
            burn_in_T=0,
            dreamed_T=self.config.horizon_H + 1,
            dreamer_model=self.env_runner.module.dreamer_model,
            symlog_obs=do_symlog_obs(
                self.env_runner.env.single_observation_space,
                self.config.symlog_obs,
            ),
            do_report=(
                self.config.report_dream_data and self.training_iteration % 100 == 0
            ),
            framework=self.config.framework_str,
        )

        # Update weights - after learning on the LearnerGroup - on all EnvRunner
        # workers.
        with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
            # Only necessary if RLModule is not shared between (local) EnvRunner and
            # (local) Learner.
            if not self.config.share_module_between_env_runner_and_learner:
                self.metrics.log_value(NUM_SYNCH_WORKER_WEIGHTS, 1, reduce="sum")
                self.env_runner_group.sync_weights(
                    from_worker_or_learner_group=self.learner_group,
                    inference_only=True,
                )

        # Try trick from https://medium.com/dive-into-ml-ai/dealing-with-memory-leak-
        # issue-in-keras-model-training-e703907a6501
        if self.config.gc_frequency_train_steps and (
            self.training_iteration % self.config.gc_frequency_train_steps == 0
        ):
            with self.metrics.log_time((TIMERS, GARBAGE_COLLECTION_TIMER)):
                gc.collect()

        # Add train results and the actual training ratio to stats. The latter should
        # be close to the configured `training_ratio`.
        self.metrics.log_value("actual_training_ratio", self.training_ratio, window=1)

    @property
    def training_ratio(self) -> float:
        """Returns the actual training ratio of this Algorithm (not the configured one).

        The training ratio is computed by dividing the total number of steps
        trained thus far (replayed from the buffer) over the total number of actual
        env steps taken thus far.
        """
        # `eps` avoids a division by zero before any env steps were sampled.
        eps = 0.0001
        return self.metrics.peek(NUM_ENV_STEPS_TRAINED_LIFETIME, default=0) / (
            (
                self.metrics.peek(
                    (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME),
                    default=eps,
                )
                or eps
            )
        )

    # TODO (sven): Remove this once DreamerV3 is on the new SingleAgentEnvRunner.
    @PublicAPI
    def __setstate__(self, state) -> None:
        """Sets the algorithm to the provided state.

        Args:
            state: The state dictionary to restore this `DreamerV3` instance to.
                `state` may have been returned by a call to an `Algorithm`'s
                `__getstate__()` method.
        """
        # Call the `Algorithm`'s `__setstate__()` method.
        super().__setstate__(state=state)

        # Assign the module to the local `EnvRunner` if sharing is enabled.
        # Note, in `Learner.restore_from_path()` the module is first deleted
        # and then a new one is built - therefore the worker has no
        # longer a copy of the learner.
        if self.config.share_module_between_env_runner_and_learner:
            assert id(self.env_runner.module) != id(
                self.learner_group._learner.module[DEFAULT_MODULE_ID]
            )
            self.env_runner.module = self.learner_group._learner.module[
                DEFAULT_MODULE_ID
            ]
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3_catalog.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+
3
+ from ray.rllib.core.models.catalog import Catalog
4
+ from ray.rllib.core.models.base import Encoder, Model
5
+ from ray.rllib.utils import override
6
+
7
+
8
class DreamerV3Catalog(Catalog):
    """The Catalog class used to build all the models needed for DreamerV3 training."""

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        model_config_dict: dict,
    ):
        """Initializes a DreamerV3Catalog instance.

        Args:
            observation_space: The observation space of the environment.
            action_space: The action space of the environment.
            model_config_dict: The model config to use.
        """
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            model_config_dict=model_config_dict,
        )

        self.model_size = self._model_config_dict["model_size"]
        # Rank-2 (gray-scale) and rank-3 (e.g. RGB) observations are image spaces;
        # a rank-2 image space is gray-scale by definition.
        obs_rank = len(self.observation_space.shape)
        self.is_img_space = obs_rank in (2, 3)
        self.is_gray_scale = obs_rank == 2

        # TODO (sven): We should work with sub-component configurations here,
        #  and even try replacing all current Dreamer model components with
        #  our default primitives. But for now, we'll construct the DreamerV3Model
        #  directly in our `build_...()` methods.

    @override(Catalog)
    def build_encoder(self, framework: str) -> Encoder:
        """Builds the World-Model's encoder network depending on the obs space."""
        if framework != "tf2":
            raise NotImplementedError

        if not self.is_img_space:
            from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP

            return MLP(model_size=self.model_size, name="vector_encoder")

        from ray.rllib.algorithms.dreamerv3.tf.models.components.cnn_atari import (
            CNNAtari,
        )

        return CNNAtari(model_size=self.model_size)

    def build_decoder(self, framework: str) -> Model:
        """Builds the World-Model's decoder network depending on the obs space."""
        if framework != "tf2":
            raise NotImplementedError

        if not self.is_img_space:
            from ray.rllib.algorithms.dreamerv3.tf.models.components import (
                vector_decoder,
            )

            return vector_decoder.VectorDecoder(
                model_size=self.model_size,
                observation_space=self.observation_space,
            )

        from ray.rllib.algorithms.dreamerv3.tf.models.components import (
            conv_transpose_atari,
        )

        return conv_transpose_atari.ConvTransposeAtari(
            model_size=self.model_size,
            gray_scaled=self.is_gray_scale,
        )
+ )
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3_learner.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ [1] Mastering Diverse Domains through World Models - 2023
3
+ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
4
+ https://arxiv.org/pdf/2301.04104v1.pdf
5
+
6
+ [2] Mastering Atari with Discrete World Models - 2021
7
+ D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
8
+ https://arxiv.org/pdf/2010.02193.pdf
9
+ """
10
+ from ray.rllib.core.learner.learner import Learner
11
+ from ray.rllib.utils.annotations import (
12
+ override,
13
+ OverrideToImplementCustomLogic_CallToSuperRecommended,
14
+ )
15
+
16
+
17
+ class DreamerV3Learner(Learner):
18
+ """DreamerV3 specific Learner class.
19
+
20
+ Only implements the `after_gradient_based_update()` method to define the logic
21
+ for updating the critic EMA-copy after each training step.
22
+ """
23
+
24
+ @OverrideToImplementCustomLogic_CallToSuperRecommended
25
+ @override(Learner)
26
+ def after_gradient_based_update(self, *, timesteps):
27
+ super().after_gradient_based_update(timesteps=timesteps)
28
+
29
+ # Update EMA weights of the critic.
30
+ for module_id, module in self.module._rl_modules.items():
31
+ module.critic.update_ema()
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file holds framework-agnostic components for DreamerV3's RLModule.
3
+ """
4
+
5
+ import abc
6
+ from typing import Any, Dict
7
+
8
+ import gymnasium as gym
9
+ import numpy as np
10
+
11
+ from ray.rllib.algorithms.dreamerv3.utils import do_symlog_obs
12
+ from ray.rllib.algorithms.dreamerv3.tf.models.actor_network import ActorNetwork
13
+ from ray.rllib.algorithms.dreamerv3.tf.models.critic_network import CriticNetwork
14
+ from ray.rllib.algorithms.dreamerv3.tf.models.dreamer_model import DreamerModel
15
+ from ray.rllib.algorithms.dreamerv3.tf.models.world_model import WorldModel
16
+ from ray.rllib.core.columns import Columns
17
+ from ray.rllib.core.rl_module.rl_module import RLModule
18
+ from ray.rllib.policy.eager_tf_policy import _convert_to_tf
19
+ from ray.rllib.utils.annotations import override
20
+ from ray.rllib.utils.framework import try_import_tf
21
+ from ray.rllib.utils.numpy import one_hot
22
+ from ray.util.annotations import DeveloperAPI
23
+
24
+
25
+ _, tf, _ = try_import_tf()
26
+
27
+
28
@DeveloperAPI(stability="alpha")
class DreamerV3RLModule(RLModule, abc.ABC):
    """Framework-agnostic RLModule for DreamerV3 (world model, actor, critic)."""

    @override(RLModule)
    def setup(self):
        """Builds world model, actor, critic, and the combined `dreamer_model`."""
        super().setup()

        # Gather model-relevant settings.
        B = 1
        T = self.model_config["batch_length_T"]
        horizon_H = self.model_config["horizon_H"]
        gamma = self.model_config["gamma"]
        symlog_obs = do_symlog_obs(
            self.observation_space,
            self.model_config.get("symlog_obs", "auto"),
        )
        model_size = self.model_config["model_size"]

        # Optionally switch Keras to mixed float16 precision globally.
        if self.model_config["use_float16"]:
            tf.compat.v1.keras.layers.enable_v2_dtype_behavior()
            tf.keras.mixed_precision.set_global_policy("mixed_float16")

        # Build encoder and decoder from catalog.
        self.encoder = self.catalog.build_encoder(framework=self.framework)
        self.decoder = self.catalog.build_decoder(framework=self.framework)

        # Build the world model (containing encoder and decoder).
        self.world_model = WorldModel(
            model_size=model_size,
            observation_space=self.observation_space,
            action_space=self.action_space,
            batch_length_T=T,
            encoder=self.encoder,
            decoder=self.decoder,
            symlog_obs=symlog_obs,
        )
        self.actor = ActorNetwork(
            action_space=self.action_space,
            model_size=model_size,
        )
        self.critic = CriticNetwork(
            model_size=model_size,
        )
        # Build the final dreamer model (containing the world model).
        self.dreamer_model = DreamerModel(
            model_size=self.model_config["model_size"],
            action_space=self.action_space,
            world_model=self.world_model,
            actor=self.actor,
            critic=self.critic,
            horizon=horizon_H,
            gamma=gamma,
        )
        self.action_dist_cls = self.catalog.get_action_dist_cls(
            framework=self.framework
        )

        # Perform a test `call()` to force building the dreamer model's variables.
        if self.framework == "tf2":
            # Fake (B, T)-shaped observation batch from the observation space.
            test_obs = np.tile(
                np.expand_dims(self.observation_space.sample(), (0, 1)),
                reps=(B, T) + (1,) * len(self.observation_space.shape),
            )
            # Discrete actions are one-hot encoded before being fed to the model.
            if isinstance(self.action_space, gym.spaces.Discrete):
                test_actions = np.tile(
                    np.expand_dims(
                        one_hot(
                            self.action_space.sample(),
                            depth=self.action_space.n,
                        ),
                        (0, 1),
                    ),
                    reps=(B, T, 1),
                )
            else:
                test_actions = np.tile(
                    np.expand_dims(self.action_space.sample(), (0, 1)),
                    reps=(B, T, 1),
                )

            self.dreamer_model(
                inputs=None,
                observations=_convert_to_tf(test_obs, dtype=tf.float32),
                actions=_convert_to_tf(test_actions, dtype=tf.float32),
                is_first=_convert_to_tf(np.ones((B, T)), dtype=tf.bool),
                start_is_terminated_BxT=_convert_to_tf(
                    np.zeros((B * T,)), dtype=tf.bool
                ),
                gamma=gamma,
            )

        # Initialize the critic EMA net:
        self.critic.init_ema()

    @override(RLModule)
    def get_initial_state(self) -> Dict:
        # Use `DreamerModel`'s `get_initial_state` method.
        return self.dreamer_model.get_initial_state()

    @override(RLModule)
    def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Call the Dreamer-Model's forward_inference method and return a dict.
        actions, next_state = self.dreamer_model.forward_inference(
            observations=batch[Columns.OBS],
            previous_states=batch[Columns.STATE_IN],
            is_first=batch["is_first"],
        )
        return {Columns.ACTIONS: actions, Columns.STATE_OUT: next_state}

    @override(RLModule)
    def _forward_exploration(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Call the Dreamer-Model's forward_exploration method and return a dict.
        actions, next_state = self.dreamer_model.forward_exploration(
            observations=batch[Columns.OBS],
            previous_states=batch[Columns.STATE_IN],
            is_first=batch["is_first"],
        )
        return {Columns.ACTIONS: actions, Columns.STATE_OUT: next_state}

    @override(RLModule)
    def _forward_train(self, batch: Dict[str, Any]):
        # Call the Dreamer-Model's forward_train method and return its outputs as-is.
        return self.dreamer_model.forward_train(
            observations=batch[Columns.OBS],
            actions=batch[Columns.ACTIONS],
            is_first=batch["is_first"],
        )
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (206 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/__pycache__/dreamerv3_tf_learner.cpython-311.pyc ADDED
Binary file (32.1 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/__pycache__/dreamerv3_tf_rl_module.cpython-311.pyc ADDED
Binary file (1.29 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_learner.py ADDED
@@ -0,0 +1,915 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ [1] Mastering Diverse Domains through World Models - 2023
3
+ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
4
+ https://arxiv.org/pdf/2301.04104v1.pdf
5
+
6
+ [2] Mastering Atari with Discrete World Models - 2021
7
+ D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
8
+ https://arxiv.org/pdf/2010.02193.pdf
9
+ """
10
+ from typing import Any, Dict, Tuple
11
+
12
+ import gymnasium as gym
13
+
14
+ from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config
15
+ from ray.rllib.algorithms.dreamerv3.dreamerv3_learner import DreamerV3Learner
16
+ from ray.rllib.core import DEFAULT_MODULE_ID
17
+ from ray.rllib.core.columns import Columns
18
+ from ray.rllib.core.learner.learner import ParamDict
19
+ from ray.rllib.core.learner.tf.tf_learner import TfLearner
20
+ from ray.rllib.utils.annotations import override
21
+ from ray.rllib.utils.framework import try_import_tf, try_import_tfp
22
+ from ray.rllib.utils.tf_utils import symlog, two_hot, clip_gradients
23
+ from ray.rllib.utils.typing import ModuleID, TensorType
24
+
25
+ _, tf, _ = try_import_tf()
26
+ tfp = try_import_tfp()
27
+
28
+
29
+ class DreamerV3TfLearner(DreamerV3Learner, TfLearner):
30
+ """Implements DreamerV3 losses and gradient-based update logic in TensorFlow.
31
+
32
+ The critic EMA-copy update step can be found in the `DreamerV3Learner` base class,
33
+ as it is framework independent.
34
+
35
+ We define 3 local TensorFlow optimizers for the sub components "world_model",
36
+ "actor", and "critic". Each of these optimizers might use a different learning rate,
37
+ epsilon parameter, and gradient clipping thresholds and procedures.
38
+ """
39
+
40
+ @override(TfLearner)
41
+ def configure_optimizers_for_module(
42
+ self, module_id: ModuleID, config: DreamerV3Config = None
43
+ ):
44
+ """Create the 3 optimizers for Dreamer learning: world_model, actor, critic.
45
+
46
+ The learning rates used are described in [1] and the epsilon values used here
47
+ - albeit probably not that important - are used by the author's own
48
+ implementation.
49
+ """
50
+
51
+ dreamerv3_module = self._module[module_id]
52
+
53
+ # World Model optimizer.
54
+ optim_world_model = tf.keras.optimizers.Adam(epsilon=1e-8)
55
+ optim_world_model.build(dreamerv3_module.world_model.trainable_variables)
56
+ params_world_model = self.get_parameters(dreamerv3_module.world_model)
57
+ self.register_optimizer(
58
+ module_id=module_id,
59
+ optimizer_name="world_model",
60
+ optimizer=optim_world_model,
61
+ params=params_world_model,
62
+ lr_or_lr_schedule=config.world_model_lr,
63
+ )
64
+
65
+ # Actor optimizer.
66
+ optim_actor = tf.keras.optimizers.Adam(epsilon=1e-5)
67
+ optim_actor.build(dreamerv3_module.actor.trainable_variables)
68
+ params_actor = self.get_parameters(dreamerv3_module.actor)
69
+ self.register_optimizer(
70
+ module_id=module_id,
71
+ optimizer_name="actor",
72
+ optimizer=optim_actor,
73
+ params=params_actor,
74
+ lr_or_lr_schedule=config.actor_lr,
75
+ )
76
+
77
+ # Critic optimizer.
78
+ optim_critic = tf.keras.optimizers.Adam(epsilon=1e-5)
79
+ optim_critic.build(dreamerv3_module.critic.trainable_variables)
80
+ params_critic = self.get_parameters(dreamerv3_module.critic)
81
+ self.register_optimizer(
82
+ module_id=module_id,
83
+ optimizer_name="critic",
84
+ optimizer=optim_critic,
85
+ params=params_critic,
86
+ lr_or_lr_schedule=config.critic_lr,
87
+ )
88
+
89
+ @override(TfLearner)
90
+ def postprocess_gradients_for_module(
91
+ self,
92
+ *,
93
+ module_id: ModuleID,
94
+ config: DreamerV3Config,
95
+ module_gradients_dict: Dict[str, Any],
96
+ ) -> ParamDict:
97
+ """Performs gradient clipping on the 3 module components' computed grads.
98
+
99
+ Note that different grad global-norm clip values are used for the 3
100
+ module components: world model, actor, and critic.
101
+ """
102
+ for optimizer_name, optimizer in self.get_optimizers_for_module(
103
+ module_id=module_id
104
+ ):
105
+ grads_sub_dict = self.filter_param_dict_for_optimizer(
106
+ module_gradients_dict, optimizer
107
+ )
108
+ # Figure out, which grad clip setting to use.
109
+ grad_clip = (
110
+ config.world_model_grad_clip_by_global_norm
111
+ if optimizer_name == "world_model"
112
+ else config.actor_grad_clip_by_global_norm
113
+ if optimizer_name == "actor"
114
+ else config.critic_grad_clip_by_global_norm
115
+ )
116
+ global_norm = clip_gradients(
117
+ grads_sub_dict,
118
+ grad_clip=grad_clip,
119
+ grad_clip_by="global_norm",
120
+ )
121
+ module_gradients_dict.update(grads_sub_dict)
122
+
123
+ # DreamerV3 stats have the format: [WORLD_MODEL|ACTOR|CRITIC]_[stats name].
124
+ self.metrics.log_dict(
125
+ {
126
+ optimizer_name.upper() + "_gradients_global_norm": global_norm,
127
+ optimizer_name.upper()
128
+ + "_gradients_maxabs_after_clipping": (
129
+ tf.reduce_max(
130
+ [
131
+ tf.reduce_max(tf.math.abs(g))
132
+ for g in grads_sub_dict.values()
133
+ ]
134
+ )
135
+ ),
136
+ },
137
+ key=module_id,
138
+ window=1, # <- single items (should not be mean/ema-reduced over time).
139
+ )
140
+
141
+ return module_gradients_dict
142
+
143
+ @override(TfLearner)
144
+ def compute_gradients(
145
+ self,
146
+ loss_per_module,
147
+ gradient_tape,
148
+ **kwargs,
149
+ ):
150
+ # Override of the default gradient computation method.
151
+ # For DreamerV3, we need to compute gradients over the individual loss terms
152
+ # as otherwise, the world model's parameters would have their gradients also
153
+ # be influenced by the actor- and critic loss terms/gradient computations.
154
+ grads = {}
155
+ for component in ["world_model", "actor", "critic"]:
156
+ grads.update(
157
+ gradient_tape.gradient(
158
+ # Take individual loss term from the registered metrics for
159
+ # the main module.
160
+ self.metrics.peek(
161
+ (DEFAULT_MODULE_ID, component.upper() + "_L_total")
162
+ ),
163
+ self.filter_param_dict_for_optimizer(
164
+ self._params, self.get_optimizer(optimizer_name=component)
165
+ ),
166
+ )
167
+ )
168
+ del gradient_tape
169
+ return grads
170
+
171
+ @override(TfLearner)
172
+ def compute_loss_for_module(
173
+ self,
174
+ module_id: ModuleID,
175
+ config: DreamerV3Config,
176
+ batch: Dict[str, TensorType],
177
+ fwd_out: Dict[str, TensorType],
178
+ ) -> TensorType:
179
+ # World model losses.
180
+ prediction_losses = self._compute_world_model_prediction_losses(
181
+ config=config,
182
+ rewards_B_T=batch[Columns.REWARDS],
183
+ continues_B_T=(1.0 - tf.cast(batch["is_terminated"], tf.float32)),
184
+ fwd_out=fwd_out,
185
+ )
186
+
187
+ (
188
+ L_dyn_B_T,
189
+ L_rep_B_T,
190
+ ) = self._compute_world_model_dynamics_and_representation_loss(
191
+ config=config, fwd_out=fwd_out
192
+ )
193
+ L_dyn = tf.reduce_mean(L_dyn_B_T)
194
+ L_rep = tf.reduce_mean(L_rep_B_T)
195
+ # Make sure values for L_rep and L_dyn are the same (they only differ in their
196
+ # gradients).
197
+ tf.assert_equal(L_dyn, L_rep)
198
+
199
+ # Compute the actual total loss using fixed weights described in [1] eq. 4.
200
+ L_world_model_total_B_T = (
201
+ 1.0 * prediction_losses["L_prediction_B_T"]
202
+ + 0.5 * L_dyn_B_T
203
+ + 0.1 * L_rep_B_T
204
+ )
205
+
206
+ # In the paper, it says to sum up timesteps, and average over
207
+ # batch (see eq. 4 in [1]). But Danijar's implementation only does
208
+ # averaging (over B and T), so we'll do this here as well. This is generally
209
+ # true for all other loss terms as well (we'll always just average, no summing
210
+ # over T axis!).
211
+ L_world_model_total = tf.reduce_mean(L_world_model_total_B_T)
212
+
213
+ # Log world model loss stats.
214
+ self.metrics.log_dict(
215
+ {
216
+ "WORLD_MODEL_learned_initial_h": (
217
+ self.module[module_id].world_model.initial_h
218
+ ),
219
+ # Prediction losses.
220
+ # Decoder (obs) loss.
221
+ "WORLD_MODEL_L_decoder": prediction_losses["L_decoder"],
222
+ # Reward loss.
223
+ "WORLD_MODEL_L_reward": prediction_losses["L_reward"],
224
+ # Continue loss.
225
+ "WORLD_MODEL_L_continue": prediction_losses["L_continue"],
226
+ # Total.
227
+ "WORLD_MODEL_L_prediction": prediction_losses["L_prediction"],
228
+ # Dynamics loss.
229
+ "WORLD_MODEL_L_dynamics": L_dyn,
230
+ # Representation loss.
231
+ "WORLD_MODEL_L_representation": L_rep,
232
+ # Total loss.
233
+ "WORLD_MODEL_L_total": L_world_model_total,
234
+ },
235
+ key=module_id,
236
+ window=1, # <- single items (should not be mean/ema-reduced over time).
237
+ )
238
+
239
+ # Add the predicted obs distributions for possible (video) summarization.
240
+ if config.report_images_and_videos:
241
+ self.metrics.log_value(
242
+ (module_id, "WORLD_MODEL_fwd_out_obs_distribution_means_b0xT"),
243
+ fwd_out["obs_distribution_means_BxT"][: self.config.batch_length_T],
244
+ reduce=None, # No reduction, we want the tensor to stay in-tact.
245
+ window=1, # <- single items (should not be mean/ema-reduced over time).
246
+ )
247
+
248
+ if config.report_individual_batch_item_stats:
249
+ # Log important world-model loss stats.
250
+ self.metrics.log_dict(
251
+ {
252
+ "WORLD_MODEL_L_decoder_B_T": prediction_losses["L_decoder_B_T"],
253
+ "WORLD_MODEL_L_reward_B_T": prediction_losses["L_reward_B_T"],
254
+ "WORLD_MODEL_L_continue_B_T": prediction_losses["L_continue_B_T"],
255
+ "WORLD_MODEL_L_prediction_B_T": (
256
+ prediction_losses["L_prediction_B_T"]
257
+ ),
258
+ "WORLD_MODEL_L_dynamics_B_T": L_dyn_B_T,
259
+ "WORLD_MODEL_L_representation_B_T": L_rep_B_T,
260
+ "WORLD_MODEL_L_total_B_T": L_world_model_total_B_T,
261
+ },
262
+ key=module_id,
263
+ window=1, # <- single items (should not be mean/ema-reduced over time).
264
+ )
265
+
266
+ # Dream trajectories starting in all internal states (h + z_posterior) that were
267
+ # computed during world model training.
268
+ # Everything goes in as BxT: We are starting a new dream trajectory at every
269
+ # actually encountered timestep in the batch, so we are creating B*T
270
+ # trajectories of len `horizon_H`.
271
+ dream_data = self.module[module_id].dreamer_model.dream_trajectory(
272
+ start_states={
273
+ "h": fwd_out["h_states_BxT"],
274
+ "z": fwd_out["z_posterior_states_BxT"],
275
+ },
276
+ start_is_terminated=tf.reshape(batch["is_terminated"], [-1]), # -> BxT
277
+ )
278
+ if config.report_dream_data:
279
+ # To reduce this massive amount of data a little, slice out a T=1 piece
280
+ # from each stats that has the shape (H, BxT), meaning convert e.g.
281
+ # `rewards_dreamed_t0_to_H_BxT` into `rewards_dreamed_t0_to_H_Bx1`.
282
+ # This will reduce the amount of data to be transferred and reported
283
+ # by the factor of `batch_length_T`.
284
+ self.metrics.log_dict(
285
+ {
286
+ # Replace 'T' with '1'.
287
+ key[:-1] + "1": value[:, :: config.batch_length_T]
288
+ for key, value in dream_data.items()
289
+ if key.endswith("H_BxT")
290
+ },
291
+ key=(module_id, "dream_data"),
292
+ reduce=None,
293
+ window=1, # <- single items (should not be mean/ema-reduced over time).
294
+ )
295
+
296
+ value_targets_t0_to_Hm1_BxT = self._compute_value_targets(
297
+ config=config,
298
+ # Learn critic in symlog'd space.
299
+ rewards_t0_to_H_BxT=dream_data["rewards_dreamed_t0_to_H_BxT"],
300
+ intrinsic_rewards_t1_to_H_BxT=(
301
+ dream_data["rewards_intrinsic_t1_to_H_B"]
302
+ if config.use_curiosity
303
+ else None
304
+ ),
305
+ continues_t0_to_H_BxT=dream_data["continues_dreamed_t0_to_H_BxT"],
306
+ value_predictions_t0_to_H_BxT=dream_data["values_dreamed_t0_to_H_BxT"],
307
+ )
308
+ self.metrics.log_value(
309
+ key=(module_id, "VALUE_TARGETS_H_BxT"),
310
+ value=value_targets_t0_to_Hm1_BxT,
311
+ window=1, # <- single items (should not be mean/ema-reduced over time).
312
+ )
313
+
314
+ CRITIC_L_total = self._compute_critic_loss(
315
+ module_id=module_id,
316
+ config=config,
317
+ dream_data=dream_data,
318
+ value_targets_t0_to_Hm1_BxT=value_targets_t0_to_Hm1_BxT,
319
+ )
320
+ if config.train_actor:
321
+ ACTOR_L_total = self._compute_actor_loss(
322
+ module_id=module_id,
323
+ config=config,
324
+ dream_data=dream_data,
325
+ value_targets_t0_to_Hm1_BxT=value_targets_t0_to_Hm1_BxT,
326
+ )
327
+ else:
328
+ ACTOR_L_total = 0.0
329
+
330
+ # Return the total loss as a sum of all individual losses.
331
+ return L_world_model_total + CRITIC_L_total + ACTOR_L_total
332
+
333
+ def _compute_world_model_prediction_losses(
334
+ self,
335
+ *,
336
+ config: DreamerV3Config,
337
+ rewards_B_T: TensorType,
338
+ continues_B_T: TensorType,
339
+ fwd_out: Dict[str, TensorType],
340
+ ) -> Dict[str, TensorType]:
341
+ """Helper method computing all world-model related prediction losses.
342
+
343
+ Prediction losses are used to train the predictors of the world model, which
344
+ are: Reward predictor, continue predictor, and the decoder (which predicts
345
+ observations).
346
+
347
+ Args:
348
+ config: The DreamerV3Config to use.
349
+ rewards_B_T: The rewards batch in the shape (B, T) and of type float32.
350
+ continues_B_T: The continues batch in the shape (B, T) and of type float32
351
+ (1.0 -> continue; 0.0 -> end of episode).
352
+ fwd_out: The `forward_train` outputs of the DreamerV3RLModule.
353
+ """
354
+
355
+ # Learn to produce symlog'd observation predictions.
356
+ # If symlog is disabled (e.g. for uint8 image inputs), `obs_symlog_BxT` is the
357
+ # same as `obs_BxT`.
358
+ obs_BxT = fwd_out["sampled_obs_symlog_BxT"]
359
+ obs_distr_means = fwd_out["obs_distribution_means_BxT"]
360
+ # In case we wanted to construct a distribution object from the fwd out data,
361
+ # we would have to do it like this:
362
+ # obs_distr = tfp.distributions.MultivariateNormalDiag(
363
+ # loc=obs_distr_means,
364
+ # # Scale == 1.0.
365
+ # # [2]: "Distributions The image predictor outputs the mean of a diagonal
366
+ # # Gaussian likelihood with **unit variance** ..."
367
+ # scale_diag=tf.ones_like(obs_distr_means),
368
+ # )
369
+
370
+ # Leave time dim folded (BxT) and flatten all other (e.g. image) dims.
371
+ obs_BxT = tf.reshape(obs_BxT, shape=[-1, tf.reduce_prod(obs_BxT.shape[1:])])
372
+
373
+ # Squared diff loss w/ sum(!) over all (already folded) obs dims.
374
+ # decoder_loss_BxT = SUM[ (obs_distr.loc - observations)^2 ]
375
+ # Note: This is described strangely in the paper (stating a neglogp loss here),
376
+ # but the author's own implementation actually uses simple MSE with the loc
377
+ # of the Gaussian.
378
+ decoder_loss_BxT = tf.reduce_sum(
379
+ tf.math.square(obs_distr_means - obs_BxT), axis=-1
380
+ )
381
+
382
+ # Unfold time rank back in.
383
+ decoder_loss_B_T = tf.reshape(
384
+ decoder_loss_BxT, (config.batch_size_B_per_learner, config.batch_length_T)
385
+ )
386
+ L_decoder = tf.reduce_mean(decoder_loss_B_T)
387
+
388
+ # The FiniteDiscrete reward bucket distribution computed by our reward
389
+ # predictor.
390
+ # [B x num_buckets].
391
+ reward_logits_BxT = fwd_out["reward_logits_BxT"]
392
+ # Learn to produce symlog'd reward predictions.
393
+ rewards_symlog_B_T = symlog(tf.cast(rewards_B_T, tf.float32))
394
+ # Fold time dim.
395
+ rewards_symlog_BxT = tf.reshape(rewards_symlog_B_T, shape=[-1])
396
+
397
+ # Two-hot encode.
398
+ two_hot_rewards_symlog_BxT = two_hot(rewards_symlog_BxT)
399
+ # two_hot_rewards_symlog_BxT=[B*T, num_buckets]
400
+ reward_log_pred_BxT = reward_logits_BxT - tf.math.reduce_logsumexp(
401
+ reward_logits_BxT, axis=-1, keepdims=True
402
+ )
403
+ # Multiply with two-hot targets and neg.
404
+ reward_loss_two_hot_BxT = -tf.reduce_sum(
405
+ reward_log_pred_BxT * two_hot_rewards_symlog_BxT, axis=-1
406
+ )
407
+ # Unfold time rank back in.
408
+ reward_loss_two_hot_B_T = tf.reshape(
409
+ reward_loss_two_hot_BxT,
410
+ (config.batch_size_B_per_learner, config.batch_length_T),
411
+ )
412
+ L_reward_two_hot = tf.reduce_mean(reward_loss_two_hot_B_T)
413
+
414
+ # Probabilities that episode continues, computed by our continue predictor.
415
+ # [B]
416
+ continue_distr = fwd_out["continue_distribution_BxT"]
417
+ # -log(p) loss
418
+ # Fold time dim.
419
+ continues_BxT = tf.reshape(continues_B_T, shape=[-1])
420
+ continue_loss_BxT = -continue_distr.log_prob(continues_BxT)
421
+ # Unfold time rank back in.
422
+ continue_loss_B_T = tf.reshape(
423
+ continue_loss_BxT, (config.batch_size_B_per_learner, config.batch_length_T)
424
+ )
425
+ L_continue = tf.reduce_mean(continue_loss_B_T)
426
+
427
+ # Sum all losses together as the "prediction" loss.
428
+ L_pred_B_T = decoder_loss_B_T + reward_loss_two_hot_B_T + continue_loss_B_T
429
+ L_pred = tf.reduce_mean(L_pred_B_T)
430
+
431
+ return {
432
+ "L_decoder_B_T": decoder_loss_B_T,
433
+ "L_decoder": L_decoder,
434
+ "L_reward": L_reward_two_hot,
435
+ "L_reward_B_T": reward_loss_two_hot_B_T,
436
+ "L_continue": L_continue,
437
+ "L_continue_B_T": continue_loss_B_T,
438
+ "L_prediction": L_pred,
439
+ "L_prediction_B_T": L_pred_B_T,
440
+ }
441
+
442
+ def _compute_world_model_dynamics_and_representation_loss(
443
+ self, *, config: DreamerV3Config, fwd_out: Dict[str, Any]
444
+ ) -> Tuple[TensorType, TensorType]:
445
+ """Helper method computing the world-model's dynamics and representation losses.
446
+
447
+ Args:
448
+ config: The DreamerV3Config to use.
449
+ fwd_out: The `forward_train` outputs of the DreamerV3RLModule.
450
+
451
+ Returns:
452
+ Tuple consisting of a) dynamics loss: Trains the prior network, predicting
453
+ z^ prior states from h-states and b) representation loss: Trains posterior
454
+ network, predicting z posterior states from h-states and (encoded)
455
+ observations.
456
+ """
457
+
458
+ # Actual distribution over stochastic internal states (z) produced by the
459
+ # encoder.
460
+ z_posterior_probs_BxT = fwd_out["z_posterior_probs_BxT"]
461
+ z_posterior_distr_BxT = tfp.distributions.Independent(
462
+ tfp.distributions.OneHotCategorical(probs=z_posterior_probs_BxT),
463
+ reinterpreted_batch_ndims=1,
464
+ )
465
+
466
+ # Actual distribution over stochastic internal states (z) produced by the
467
+ # dynamics network.
468
+ z_prior_probs_BxT = fwd_out["z_prior_probs_BxT"]
469
+ z_prior_distr_BxT = tfp.distributions.Independent(
470
+ tfp.distributions.OneHotCategorical(probs=z_prior_probs_BxT),
471
+ reinterpreted_batch_ndims=1,
472
+ )
473
+
474
+ # Stop gradient for encoder's z-outputs:
475
+ sg_z_posterior_distr_BxT = tfp.distributions.Independent(
476
+ tfp.distributions.OneHotCategorical(
477
+ probs=tf.stop_gradient(z_posterior_probs_BxT)
478
+ ),
479
+ reinterpreted_batch_ndims=1,
480
+ )
481
+ # Stop gradient for dynamics model's z-outputs:
482
+ sg_z_prior_distr_BxT = tfp.distributions.Independent(
483
+ tfp.distributions.OneHotCategorical(
484
+ probs=tf.stop_gradient(z_prior_probs_BxT)
485
+ ),
486
+ reinterpreted_batch_ndims=1,
487
+ )
488
+
489
+ # Implement free bits. According to [1]:
490
+ # "To avoid a degenerate solution where the dynamics are trivial to predict but
491
+ # contain not enough information about the inputs, we employ free bits by
492
+ # clipping the dynamics and representation losses below the value of
493
+ # 1 nat ≈ 1.44 bits. This disables them while they are already minimized well to
494
+ # focus the world model on its prediction loss"
495
+ L_dyn_BxT = tf.math.maximum(
496
+ 1.0,
497
+ tfp.distributions.kl_divergence(
498
+ sg_z_posterior_distr_BxT, z_prior_distr_BxT
499
+ ),
500
+ )
501
+ # Unfold time rank back in.
502
+ L_dyn_B_T = tf.reshape(
503
+ L_dyn_BxT, (config.batch_size_B_per_learner, config.batch_length_T)
504
+ )
505
+
506
+ L_rep_BxT = tf.math.maximum(
507
+ 1.0,
508
+ tfp.distributions.kl_divergence(
509
+ z_posterior_distr_BxT, sg_z_prior_distr_BxT
510
+ ),
511
+ )
512
+ # Unfold time rank back in.
513
+ L_rep_B_T = tf.reshape(
514
+ L_rep_BxT, (config.batch_size_B_per_learner, config.batch_length_T)
515
+ )
516
+
517
+ return L_dyn_B_T, L_rep_B_T
518
+
519
+ def _compute_actor_loss(
520
+ self,
521
+ *,
522
+ module_id: ModuleID,
523
+ config: DreamerV3Config,
524
+ dream_data: Dict[str, TensorType],
525
+ value_targets_t0_to_Hm1_BxT: TensorType,
526
+ ) -> TensorType:
527
+ """Helper method computing the actor's loss terms.
528
+
529
+ Args:
530
+ module_id: The module_id for which to compute the actor loss.
531
+ config: The DreamerV3Config to use.
532
+ dream_data: The data generated by dreaming for H steps (horizon) starting
533
+ from any BxT state (sampled from the buffer for the train batch).
534
+ value_targets_t0_to_Hm1_BxT: The computed value function targets of the
535
+ shape (t0 to H-1, BxT).
536
+
537
+ Returns:
538
+ The total actor loss tensor.
539
+ """
540
+ actor = self.module[module_id].actor
541
+
542
+ # Note: `scaled_value_targets_t0_to_Hm1_B` are NOT stop_gradient'd yet.
543
+ scaled_value_targets_t0_to_Hm1_B = self._compute_scaled_value_targets(
544
+ module_id=module_id,
545
+ config=config,
546
+ value_targets_t0_to_Hm1_BxT=value_targets_t0_to_Hm1_BxT,
547
+ value_predictions_t0_to_Hm1_BxT=dream_data["values_dreamed_t0_to_H_BxT"][
548
+ :-1
549
+ ],
550
+ )
551
+
552
+ # Actions actually taken in the dream.
553
+ actions_dreamed = tf.stop_gradient(dream_data["actions_dreamed_t0_to_H_BxT"])[
554
+ :-1
555
+ ]
556
+ actions_dreamed_dist_params_t0_to_Hm1_B = dream_data[
557
+ "actions_dreamed_dist_params_t0_to_H_BxT"
558
+ ][:-1]
559
+
560
+ dist_t0_to_Hm1_B = actor.get_action_dist_object(
561
+ actions_dreamed_dist_params_t0_to_Hm1_B
562
+ )
563
+
564
+ # Compute log(p)s of all possible actions in the dream.
565
+ if isinstance(self.module[module_id].actor.action_space, gym.spaces.Discrete):
566
+ # Note that when we create the Categorical action distributions, we compute
567
+ # unimix probs, then math.log these and provide these log(p) as "logits" to
568
+ # the Categorical. So here, we'll continue to work with log(p)s (not
569
+ # really "logits")!
570
+ logp_actions_t0_to_Hm1_B = actions_dreamed_dist_params_t0_to_Hm1_B
571
+
572
+ # Log probs of actions actually taken in the dream.
573
+ logp_actions_dreamed_t0_to_Hm1_B = tf.reduce_sum(
574
+ actions_dreamed * logp_actions_t0_to_Hm1_B,
575
+ axis=-1,
576
+ )
577
+ # First term of loss function. [1] eq. 11.
578
+ logp_loss_H_B = logp_actions_dreamed_t0_to_Hm1_B * tf.stop_gradient(
579
+ scaled_value_targets_t0_to_Hm1_B
580
+ )
581
+ # Box space.
582
+ else:
583
+ logp_actions_dreamed_t0_to_Hm1_B = dist_t0_to_Hm1_B.log_prob(
584
+ actions_dreamed
585
+ )
586
+ # First term of loss function. [1] eq. 11.
587
+ logp_loss_H_B = scaled_value_targets_t0_to_Hm1_B
588
+
589
+ assert len(logp_loss_H_B.shape) == 2
590
+
591
+ # Add entropy loss term (second term [1] eq. 11).
592
+ entropy_H_B = dist_t0_to_Hm1_B.entropy()
593
+ assert len(entropy_H_B.shape) == 2
594
+ entropy = tf.reduce_mean(entropy_H_B)
595
+
596
+ L_actor_reinforce_term_H_B = -logp_loss_H_B
597
+ L_actor_action_entropy_term_H_B = -config.entropy_scale * entropy_H_B
598
+
599
+ L_actor_H_B = L_actor_reinforce_term_H_B + L_actor_action_entropy_term_H_B
600
+ # Mask out everything that goes beyond a predicted continue=False boundary.
601
+ L_actor_H_B *= tf.stop_gradient(dream_data["dream_loss_weights_t0_to_H_BxT"])[
602
+ :-1
603
+ ]
604
+ L_actor = tf.reduce_mean(L_actor_H_B)
605
+
606
+ # Log important actor loss stats.
607
+ self.metrics.log_dict(
608
+ {
609
+ "ACTOR_L_total": L_actor,
610
+ "ACTOR_value_targets_pct95_ema": actor.ema_value_target_pct95,
611
+ "ACTOR_value_targets_pct5_ema": actor.ema_value_target_pct5,
612
+ "ACTOR_action_entropy": entropy,
613
+ # Individual loss terms.
614
+ "ACTOR_L_neglogp_reinforce_term": tf.reduce_mean(
615
+ L_actor_reinforce_term_H_B
616
+ ),
617
+ "ACTOR_L_neg_entropy_term": tf.reduce_mean(
618
+ L_actor_action_entropy_term_H_B
619
+ ),
620
+ },
621
+ key=module_id,
622
+ window=1, # <- single items (should not be mean/ema-reduced over time).
623
+ )
624
+ if config.report_individual_batch_item_stats:
625
+ self.metrics.log_dict(
626
+ {
627
+ "ACTOR_L_total_H_BxT": L_actor_H_B,
628
+ "ACTOR_logp_actions_dreamed_H_BxT": (
629
+ logp_actions_dreamed_t0_to_Hm1_B
630
+ ),
631
+ "ACTOR_scaled_value_targets_H_BxT": (
632
+ scaled_value_targets_t0_to_Hm1_B
633
+ ),
634
+ "ACTOR_action_entropy_H_BxT": entropy_H_B,
635
+ # Individual loss terms.
636
+ "ACTOR_L_neglogp_reinforce_term_H_BxT": L_actor_reinforce_term_H_B,
637
+ "ACTOR_L_neg_entropy_term_H_BxT": L_actor_action_entropy_term_H_B,
638
+ },
639
+ key=module_id,
640
+ window=1, # <- single items (should not be mean/ema-reduced over time).
641
+ )
642
+
643
+ return L_actor
644
+
645
+ def _compute_critic_loss(
646
+ self,
647
+ *,
648
+ module_id: ModuleID,
649
+ config: DreamerV3Config,
650
+ dream_data: Dict[str, TensorType],
651
+ value_targets_t0_to_Hm1_BxT: TensorType,
652
+ ) -> TensorType:
653
+ """Helper method computing the critic's loss terms.
654
+
655
+ Args:
656
+ module_id: The ModuleID for which to compute the critic loss.
657
+ config: The DreamerV3Config to use.
658
+ dream_data: The data generated by dreaming for H steps (horizon) starting
659
+ from any BxT state (sampled from the buffer for the train batch).
660
+ value_targets_t0_to_Hm1_BxT: The computed value function targets of the
661
+ shape (t0 to H-1, BxT).
662
+
663
+ Returns:
664
+ The total critic loss tensor.
665
+ """
666
+ # B=BxT
667
+ H, B = dream_data["rewards_dreamed_t0_to_H_BxT"].shape[:2]
668
+ Hm1 = H - 1
669
+
670
+ # Note that value targets are NOT symlog'd and go from t0 to H-1, not H, like
671
+ # all the other dream data.
672
+
673
+ # From here on: B=BxT
674
+ value_targets_t0_to_Hm1_B = tf.stop_gradient(value_targets_t0_to_Hm1_BxT)
675
+ value_symlog_targets_t0_to_Hm1_B = symlog(value_targets_t0_to_Hm1_B)
676
+ # Fold time rank (for two_hot'ing).
677
+ value_symlog_targets_HxB = tf.reshape(value_symlog_targets_t0_to_Hm1_B, (-1,))
678
+ value_symlog_targets_two_hot_HxB = two_hot(value_symlog_targets_HxB)
679
+ # Unfold time rank.
680
+ value_symlog_targets_two_hot_t0_to_Hm1_B = tf.reshape(
681
+ value_symlog_targets_two_hot_HxB,
682
+ shape=[Hm1, B, value_symlog_targets_two_hot_HxB.shape[-1]],
683
+ )
684
+
685
+ # Get (B x T x probs) tensor from return distributions.
686
+ value_symlog_logits_HxB = dream_data["values_symlog_dreamed_logits_t0_to_HxBxT"]
687
+ # Unfold time rank and cut last time index to match value targets.
688
+ value_symlog_logits_t0_to_Hm1_B = tf.reshape(
689
+ value_symlog_logits_HxB,
690
+ shape=[H, B, value_symlog_logits_HxB.shape[-1]],
691
+ )[:-1]
692
+
693
+ values_log_pred_Hm1_B = (
694
+ value_symlog_logits_t0_to_Hm1_B
695
+ - tf.math.reduce_logsumexp(
696
+ value_symlog_logits_t0_to_Hm1_B, axis=-1, keepdims=True
697
+ )
698
+ )
699
+ # Multiply with two-hot targets and neg.
700
+ value_loss_two_hot_H_B = -tf.reduce_sum(
701
+ values_log_pred_Hm1_B * value_symlog_targets_two_hot_t0_to_Hm1_B, axis=-1
702
+ )
703
+
704
+ # Compute EMA regularization loss.
705
+ # Expected values (dreamed) from the EMA (slow critic) net.
706
+ # Note: Slow critic (EMA) outputs are already stop_gradient'd.
707
+ value_symlog_ema_t0_to_Hm1_B = tf.stop_gradient(
708
+ dream_data["v_symlog_dreamed_ema_t0_to_H_BxT"]
709
+ )[:-1]
710
+ # Fold time rank (for two_hot'ing).
711
+ value_symlog_ema_HxB = tf.reshape(value_symlog_ema_t0_to_Hm1_B, (-1,))
712
+ value_symlog_ema_two_hot_HxB = two_hot(value_symlog_ema_HxB)
713
+ # Unfold time rank.
714
+ value_symlog_ema_two_hot_t0_to_Hm1_B = tf.reshape(
715
+ value_symlog_ema_two_hot_HxB,
716
+ shape=[Hm1, B, value_symlog_ema_two_hot_HxB.shape[-1]],
717
+ )
718
+
719
+ # Compute ema regularizer loss.
720
+ # In the paper, it is not described how exactly to form this regularizer term
721
+ # and how to weigh it.
722
+ # So we follow Danijar's repo here:
723
+ # `reg = -dist.log_prob(sg(self.slow(traj).mean()))`
724
+ # with a weight of 1.0, where dist is the bucket'ized distribution output by the
725
+ # fast critic. sg=stop gradient; mean() -> use the expected EMA values.
726
+ # Multiply with two-hot targets and neg.
727
+ ema_regularization_loss_H_B = -tf.reduce_sum(
728
+ values_log_pred_Hm1_B * value_symlog_ema_two_hot_t0_to_Hm1_B, axis=-1
729
+ )
730
+
731
+ L_critic_H_B = value_loss_two_hot_H_B + ema_regularization_loss_H_B
732
+
733
+ # Mask out everything that goes beyond a predicted continue=False boundary.
734
+ L_critic_H_B *= tf.stop_gradient(dream_data["dream_loss_weights_t0_to_H_BxT"])[
735
+ :-1
736
+ ]
737
+
738
+ # Reduce over both H- (time) axis and B-axis (mean).
739
+ L_critic = tf.reduce_mean(L_critic_H_B)
740
+
741
+ # Log important critic loss stats.
742
+ self.metrics.log_dict(
743
+ {
744
+ "CRITIC_L_total": L_critic,
745
+ "CRITIC_L_neg_logp_of_value_targets": tf.reduce_mean(
746
+ value_loss_two_hot_H_B
747
+ ),
748
+ "CRITIC_L_slow_critic_regularization": tf.reduce_mean(
749
+ ema_regularization_loss_H_B
750
+ ),
751
+ },
752
+ key=module_id,
753
+ window=1, # <- single items (should not be mean/ema-reduced over time).
754
+ )
755
+ if config.report_individual_batch_item_stats:
756
+ # Log important critic loss stats.
757
+ self.metrics.log_dict(
758
+ {
759
+ # Symlog'd value targets. Critic learns to predict symlog'd values.
760
+ "VALUE_TARGETS_symlog_H_BxT": value_symlog_targets_t0_to_Hm1_B,
761
+ # Critic loss terms.
762
+ "CRITIC_L_total_H_BxT": L_critic_H_B,
763
+ "CRITIC_L_neg_logp_of_value_targets_H_BxT": value_loss_two_hot_H_B,
764
+ "CRITIC_L_slow_critic_regularization_H_BxT": (
765
+ ema_regularization_loss_H_B
766
+ ),
767
+ },
768
+ key=module_id,
769
+ window=1, # <- single items (should not be mean/ema-reduced over time).
770
+ )
771
+
772
+ return L_critic
773
+
774
+ def _compute_value_targets(
775
+ self,
776
+ *,
777
+ config: DreamerV3Config,
778
+ rewards_t0_to_H_BxT: TensorType,
779
+ intrinsic_rewards_t1_to_H_BxT: TensorType,
780
+ continues_t0_to_H_BxT: TensorType,
781
+ value_predictions_t0_to_H_BxT: TensorType,
782
+ ) -> TensorType:
783
+ """Helper method computing the value targets.
784
+
785
+ All args are (H, BxT, ...) and in non-symlog'd (real) reward space.
786
+ Non-symlog is important b/c log(a+b) != log(a) + log(b).
787
+ See [1] eq. 8 and 10.
788
+ Thus, targets are always returned in real (non-symlog'd space).
789
+ They need to be re-symlog'd before computing the critic loss from them (b/c the
790
+ critic produces predictions in symlog space).
791
+ Note that the original B and T ranks together form the new batch dimension
792
+ (folded into BxT) and the new time rank is the dream horizon (hence: [H, BxT]).
793
+
794
+ Variable names nomenclature:
795
+ `H`=1+horizon_H (start state + H steps dreamed),
796
+ `BxT`=batch_size * batch_length (meaning the original trajectory time rank has
797
+ been folded).
798
+
799
+ Rewards, continues, and value predictions are all of shape [t0-H, BxT]
800
+ (time-major), whereas returned targets are [t0 to H-1, B] (last timestep missing
801
+ b/c the target value equals vf prediction in that location anyways.
802
+
803
+ Args:
804
+ config: The DreamerV3Config to use.
805
+ rewards_t0_to_H_BxT: The reward predictor's predictions over the
806
+ dreamed trajectory t0 to H (and for the batch BxT).
807
+ intrinsic_rewards_t1_to_H_BxT: The predicted intrinsic rewards over the
808
+ dreamed trajectory t0 to H (and for the batch BxT).
809
+ continues_t0_to_H_BxT: The continue predictor's predictions over the
810
+ dreamed trajectory t0 to H (and for the batch BxT).
811
+ value_predictions_t0_to_H_BxT: The critic's value predictions over the
812
+ dreamed trajectory t0 to H (and for the batch BxT).
813
+
814
+ Returns:
815
+ The value targets in the shape: [t0toH-1, BxT]. Note that the last step (H)
816
+ does not require a value target as it matches the critic's value prediction
817
+ anyways.
818
+ """
819
+ # The first reward is irrelevant (not used for any VF target).
820
+ rewards_t1_to_H_BxT = rewards_t0_to_H_BxT[1:]
821
+ if intrinsic_rewards_t1_to_H_BxT is not None:
822
+ rewards_t1_to_H_BxT += intrinsic_rewards_t1_to_H_BxT
823
+
824
+ # In all the following, when building value targets for t=1 to T=H,
825
+ # exclude rewards & continues for t=1 b/c we don't need r1 or c1.
826
+ # The target (R1) for V1 is built from r2, c2, and V2/R2.
827
+ discount = continues_t0_to_H_BxT[1:] * config.gamma # shape=[2-16, BxT]
828
+ Rs = [value_predictions_t0_to_H_BxT[-1]] # Rs indices=[16]
829
+ intermediates = (
830
+ rewards_t1_to_H_BxT
831
+ + discount * (1 - config.gae_lambda) * value_predictions_t0_to_H_BxT[1:]
832
+ )
833
+ # intermediates.shape=[2-16, BxT]
834
+
835
+ # Loop through reversed timesteps (axis=1) from T+1 to t=2.
836
+ for t in reversed(range(discount.shape[0])):
837
+ Rs.append(intermediates[t] + discount[t] * config.gae_lambda * Rs[-1])
838
+
839
+ # Reverse along time axis and cut the last entry (value estimate at very end
840
+ # cannot be learnt from as it's the same as the ... well ... value estimate).
841
+ targets_t0toHm1_BxT = tf.stack(list(reversed(Rs))[:-1], axis=0)
842
+ # targets.shape=[t0 to H-1,BxT]
843
+
844
+ return targets_t0toHm1_BxT
845
+
846
+ def _compute_scaled_value_targets(
847
+ self,
848
+ *,
849
+ module_id: ModuleID,
850
+ config: DreamerV3Config,
851
+ value_targets_t0_to_Hm1_BxT: TensorType,
852
+ value_predictions_t0_to_Hm1_BxT: TensorType,
853
+ ) -> TensorType:
854
+ """Helper method computing the scaled value targets.
855
+
856
+ Args:
857
+ module_id: The module_id to compute value targets for.
858
+ config: The DreamerV3Config to use.
859
+ value_targets_t0_to_Hm1_BxT: The value targets computed by
860
+ `self._compute_value_targets` in the shape of (t0 to H-1, BxT)
861
+ and of type float32.
862
+ value_predictions_t0_to_Hm1_BxT: The critic's value predictions over the
863
+ dreamed trajectories (w/o the last timestep). The shape of this
864
+ tensor is (t0 to H-1, BxT) and the type is float32.
865
+
866
+ Returns:
867
+ The scaled value targets used by the actor for REINFORCE policy updates
868
+ (using scaled advantages). See [1] eq. 12 for more details.
869
+ """
870
+ actor = self.module[module_id].actor
871
+
872
+ value_targets_H_B = value_targets_t0_to_Hm1_BxT
873
+ value_predictions_H_B = value_predictions_t0_to_Hm1_BxT
874
+
875
+ # Compute S: [1] eq. 12.
876
+ Per_R_5 = tfp.stats.percentile(value_targets_H_B, 5)
877
+ Per_R_95 = tfp.stats.percentile(value_targets_H_B, 95)
878
+
879
+ # Update EMA values for 5 and 95 percentile, stored as tf variables under actor
880
+ # network.
881
+ # 5 percentile
882
+ new_val_pct5 = tf.where(
883
+ tf.math.is_nan(actor.ema_value_target_pct5),
884
+ # is NaN: Initial values: Just set.
885
+ Per_R_5,
886
+ # Later update (something already stored in EMA variable): Update EMA.
887
+ (
888
+ config.return_normalization_decay * actor.ema_value_target_pct5
889
+ + (1.0 - config.return_normalization_decay) * Per_R_5
890
+ ),
891
+ )
892
+ actor.ema_value_target_pct5.assign(new_val_pct5)
893
+ # 95 percentile
894
+ new_val_pct95 = tf.where(
895
+ tf.math.is_nan(actor.ema_value_target_pct95),
896
+ # is NaN: Initial values: Just set.
897
+ Per_R_95,
898
+ # Later update (something already stored in EMA variable): Update EMA.
899
+ (
900
+ config.return_normalization_decay * actor.ema_value_target_pct95
901
+ + (1.0 - config.return_normalization_decay) * Per_R_95
902
+ ),
903
+ )
904
+ actor.ema_value_target_pct95.assign(new_val_pct95)
905
+
906
+ # [1] eq. 11 (first term).
907
+ offset = actor.ema_value_target_pct5
908
+ invscale = tf.math.maximum(
909
+ 1e-8, actor.ema_value_target_pct95 - actor.ema_value_target_pct5
910
+ )
911
+ scaled_value_targets_H_B = (value_targets_H_B - offset) / invscale
912
+ scaled_value_predictions_H_B = (value_predictions_H_B - offset) / invscale
913
+
914
+ # Return advantages.
915
+ return scaled_value_targets_H_B - scaled_value_predictions_H_B
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_rl_module.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ [1] Mastering Diverse Domains through World Models - 2023
3
+ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
4
+ https://arxiv.org/pdf/2301.04104v1.pdf
5
+
6
+ [2] Mastering Atari with Discrete World Models - 2021
7
+ D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
8
+ https://arxiv.org/pdf/2010.02193.pdf
9
+ """
10
+ from ray.rllib.algorithms.dreamerv3.dreamerv3_rl_module import DreamerV3RLModule
11
+ from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule
12
+ from ray.rllib.utils.framework import try_import_tf
13
+
14
+ tf1, tf, _ = try_import_tf()
15
+
16
+
17
class DreamerV3TfRLModule(TfRLModule, DreamerV3RLModule):
    """The tf-specific RLModule class for DreamerV3.

    Serves mainly as a thin-wrapper around the `DreamerModel` (a
    tf.keras.Model) class.
    """

    # Framework tag used by RLlib to identify this module's DL framework.
    framework = "tf2"
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/actor_network.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ [1] Mastering Diverse Domains through World Models - 2023
3
+ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
4
+ https://arxiv.org/pdf/2301.04104v1.pdf
5
+ """
6
+ import gymnasium as gym
7
+ from gymnasium.spaces import Box, Discrete
8
+ import numpy as np
9
+
10
+ from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP
11
+ from ray.rllib.algorithms.dreamerv3.utils import (
12
+ get_gru_units,
13
+ get_num_z_categoricals,
14
+ get_num_z_classes,
15
+ )
16
+ from ray.rllib.utils.framework import try_import_tf, try_import_tfp
17
+
18
+ _, tf, _ = try_import_tf()
19
+ tfp = try_import_tfp()
20
+
21
+
22
class ActorNetwork(tf.keras.Model):
    """The `actor` (policy net) of DreamerV3.

    Consists of a simple MLP for Discrete actions and two MLPs for cont. actions (mean
    and stddev).
    Also contains two scalar variables to keep track of the percentile-5 and
    percentile-95 values of the computed value targets within a batch. This is used to
    compute the "scaled value targets" for actor learning. These two variables decay
    over time exponentially (see [1] for more details).
    """

    def __init__(
        self,
        *,
        model_size: str = "XS",
        action_space: gym.Space,
    ):
        """Initializes an ActorNetwork instance.

        Args:
            model_size: The "Model Size" used according to [1] Appendix B.
                Use None for manually setting the different network sizes.
            action_space: The action space of the environment used.

        Raises:
            ValueError: If `action_space` is neither Discrete nor Box.
        """
        super().__init__(name="actor")

        self.model_size = model_size
        self.action_space = action_space

        # The EMA decay variables used for the [Percentile(R, 95%) - Percentile(R, 5%)]
        # diff to scale value targets for the actor loss.
        # NaN-initialized so the loss code can detect "nothing stored yet" and
        # assign directly on the first update instead of EMA-blending.
        self.ema_value_target_pct5 = tf.Variable(
            np.nan, trainable=False, name="value_target_pct5"
        )
        self.ema_value_target_pct95 = tf.Variable(
            np.nan, trainable=False, name="value_target_pct95"
        )

        # For discrete actions, use a single MLP that computes logits.
        if isinstance(self.action_space, Discrete):
            self.mlp = MLP(
                model_size=self.model_size,
                output_layer_size=self.action_space.n,
                name="actor_mlp",
            )
        # For cont. actions, use separate MLPs for Gaussian mean and stddev.
        # TODO (sven): In the author's original code repo, this is NOT the case,
        # inputs are pushed through a shared MLP, then only the two output linear
        # layers are separate for std- and mean logits.
        elif isinstance(action_space, Box):
            output_layer_size = np.prod(action_space.shape)
            self.mlp = MLP(
                model_size=self.model_size,
                output_layer_size=output_layer_size,
                name="actor_mlp_mean",
            )
            self.std_mlp = MLP(
                model_size=self.model_size,
                output_layer_size=output_layer_size,
                name="actor_mlp_std",
            )
        else:
            raise ValueError(f"Invalid action space: {action_space}")

        # Trace self.call.
        # Fix the input signature to [B, gru_units] for h and
        # [B, num_categoricals, num_classes] for z so tf.function traces only once.
        dl_type = tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32
        self.call = tf.function(
            input_signature=[
                tf.TensorSpec(shape=[None, get_gru_units(model_size)], dtype=dl_type),
                tf.TensorSpec(
                    shape=[
                        None,
                        get_num_z_categoricals(model_size),
                        get_num_z_classes(model_size),
                    ],
                    dtype=dl_type,
                ),
            ]
        )(self.call)

    def call(self, h, z):
        """Performs a forward pass through this policy network.

        Args:
            h: The deterministic hidden state of the sequence model. [B, dim(h)].
            z: The stochastic discrete representations of the original
                observation input. [B, num_categoricals, num_classes].

        Returns:
            Tuple of (sampled action, action distribution parameters).
        """
        # Flatten last two dims of z.
        assert len(z.shape) == 3
        z_shape = tf.shape(z)
        z = tf.reshape(z, shape=(z_shape[0], -1))
        assert len(z.shape) == 2
        out = tf.concat([h, z], axis=-1)
        # Restore the statically-known feature dimension (lost by the dynamic
        # reshape above) so downstream dense layers can build.
        out.set_shape(
            [
                None,
                (
                    get_num_z_categoricals(self.model_size)
                    * get_num_z_classes(self.model_size)
                    + get_gru_units(self.model_size)
                ),
            ]
        )
        # Send h-cat-z through MLP.
        action_logits = tf.cast(self.mlp(out), tf.float32)

        if isinstance(self.action_space, Discrete):
            action_probs = tf.nn.softmax(action_logits)

            # Add the unimix weighting (1% uniform) to the probs.
            # See [1]: "Unimix categoricals: We parameterize the categorical
            # distributions for the world model representations and dynamics, as well as
            # for the actor network, as mixtures of 1% uniform and 99% neural network
            # output to ensure a minimal amount of probability mass on every class and
            # thus keep log probabilities and KL divergences well behaved."
            action_probs = 0.99 * action_probs + 0.01 * (1.0 / self.action_space.n)

            # Danijar's code does: distr = [Distr class](logits=tf.log(probs)).
            # Not sure why we don't directly use the already available probs instead.
            action_logits = tf.math.log(action_probs)

            # Distribution parameters are the log(probs) directly.
            distr_params = action_logits
            distr = self.get_action_dist_object(distr_params)

            # Straight-through gradient trick: the one-hot sample itself is
            # non-differentiable; adding (probs - stop_grad(probs)) keeps the
            # forward value identical while routing gradients through the probs.
            action = tf.stop_gradient(distr.sample()) + (
                action_probs - tf.stop_gradient(action_probs)
            )

        elif isinstance(self.action_space, Box):
            # Send h-cat-z through MLP to compute stddev logits for Normal dist
            std_logits = tf.cast(self.std_mlp(out), tf.float32)
            # minstd, maxstd taken from [1] from configs.yaml
            minstd = 0.1
            maxstd = 1.0

            # Distribution parameters are the squashed std_logits and the tanh'd
            # mean logits.
            # squash std_logits from (-inf, inf) to (minstd, maxstd)
            std_logits = (maxstd - minstd) * tf.sigmoid(std_logits + 2.0) + minstd
            mean_logits = tf.tanh(action_logits)

            distr_params = tf.concat([mean_logits, std_logits], axis=-1)
            distr = self.get_action_dist_object(distr_params)

            action = distr.sample()

        return action, distr_params

    def get_action_dist_object(self, action_dist_params_T_B):
        """Helper method to create an action distribution object from (T, B, ..) params.

        Args:
            action_dist_params_T_B: The time-major action distribution parameters.
                This could be simply the logits (discrete) or a to-be-split-in-2
                tensor for mean and stddev (continuous).

        Returns:
            The tfp action distribution object, from which one can sample, compute
            log probs, entropy, etc..

        Raises:
            ValueError: If `self.action_space` is neither Discrete nor Box.
        """
        if isinstance(self.action_space, gym.spaces.Discrete):
            # Create the distribution object using the unimix'd logits.
            distr = tfp.distributions.OneHotCategorical(
                logits=action_dist_params_T_B,
                dtype=tf.float32,
            )

        elif isinstance(self.action_space, gym.spaces.Box):
            # Compute Normal distribution from action_logits and std_logits
            loc, scale = tf.split(action_dist_params_T_B, 2, axis=-1)
            distr = tfp.distributions.Normal(loc=loc, scale=scale)

            # If action_space is a box with multiple dims, make individual dims
            # independent.
            distr = tfp.distributions.Independent(distr, len(self.action_space.shape))

        else:
            raise ValueError(f"Action space {self.action_space} not supported!")

        return distr
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/cnn_atari.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ [1] Mastering Diverse Domains through World Models - 2023
3
+ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
4
+ https://arxiv.org/pdf/2301.04104v1.pdf
5
+ """
6
+ from typing import Optional
7
+
8
+ from ray.rllib.algorithms.dreamerv3.utils import get_cnn_multiplier
9
+ from ray.rllib.utils.framework import try_import_tf
10
+
11
+ _, tf, _ = try_import_tf()
12
+
13
+
14
+ class CNNAtari(tf.keras.Model):
15
+ """An image encoder mapping 64x64 RGB images via 4 CNN layers into a 1D space."""
16
+
17
+ def __init__(
18
+ self,
19
+ *,
20
+ model_size: Optional[str] = "XS",
21
+ cnn_multiplier: Optional[int] = None,
22
+ ):
23
+ """Initializes a CNNAtari instance.
24
+
25
+ Args:
26
+ model_size: The "Model Size" used according to [1] Appendix B.
27
+ Use None for manually setting the `cnn_multiplier`.
28
+ cnn_multiplier: Optional override for the additional factor used to multiply
29
+ the number of filters with each CNN layer. Starting with
30
+ 1 * `cnn_multiplier` filters in the first CNN layer, the number of
31
+ filters then increases via `2*cnn_multiplier`, `4*cnn_multiplier`, till
32
+ `8*cnn_multiplier`.
33
+ """
34
+ super().__init__(name="image_encoder")
35
+
36
+ cnn_multiplier = get_cnn_multiplier(model_size, override=cnn_multiplier)
37
+
38
+ # See appendix C in [1]:
39
+ # "We use a similar network architecture but employ layer normalization and
40
+ # SiLU as the activation function. For better framework support, we use
41
+ # same-padded convolutions with stride 2 and kernel size 3 instead of
42
+ # valid-padded convolutions with larger kernels ..."
43
+ # HOWEVER: In Danijar's DreamerV3 repo, kernel size=4 is used, so we use it
44
+ # here, too.
45
+ self.conv_layers = [
46
+ tf.keras.layers.Conv2D(
47
+ filters=1 * cnn_multiplier,
48
+ kernel_size=4,
49
+ strides=(2, 2),
50
+ padding="same",
51
+ # No bias or activation due to layernorm.
52
+ activation=None,
53
+ use_bias=False,
54
+ ),
55
+ tf.keras.layers.Conv2D(
56
+ filters=2 * cnn_multiplier,
57
+ kernel_size=4,
58
+ strides=(2, 2),
59
+ padding="same",
60
+ # No bias or activation due to layernorm.
61
+ activation=None,
62
+ use_bias=False,
63
+ ),
64
+ tf.keras.layers.Conv2D(
65
+ filters=4 * cnn_multiplier,
66
+ kernel_size=4,
67
+ strides=(2, 2),
68
+ padding="same",
69
+ # No bias or activation due to layernorm.
70
+ activation=None,
71
+ use_bias=False,
72
+ ),
73
+ # .. until output is 4 x 4 x [num_filters].
74
+ tf.keras.layers.Conv2D(
75
+ filters=8 * cnn_multiplier,
76
+ kernel_size=4,
77
+ strides=(2, 2),
78
+ padding="same",
79
+ # No bias or activation due to layernorm.
80
+ activation=None,
81
+ use_bias=False,
82
+ ),
83
+ ]
84
+ self.layer_normalizations = []
85
+ for _ in range(len(self.conv_layers)):
86
+ self.layer_normalizations.append(tf.keras.layers.LayerNormalization())
87
+ # -> 4 x 4 x num_filters -> now flatten.
88
+ self.flatten_layer = tf.keras.layers.Flatten(data_format="channels_last")
89
+
90
+ @tf.function(
91
+ input_signature=[
92
+ tf.TensorSpec(
93
+ shape=[None, 64, 64, 3],
94
+ dtype=tf.keras.mixed_precision.global_policy().compute_dtype
95
+ or tf.float32,
96
+ )
97
+ ]
98
+ )
99
+ def call(self, inputs):
100
+ """Performs a forward pass through the CNN Atari encoder.
101
+
102
+ Args:
103
+ inputs: The image inputs of shape (B, 64, 64, 3).
104
+ """
105
+ # [B, h, w] -> grayscale.
106
+ if len(inputs.shape) == 3:
107
+ inputs = tf.expand_dims(inputs, -1)
108
+ out = inputs
109
+ for conv_2d, layer_norm in zip(self.conv_layers, self.layer_normalizations):
110
+ out = tf.nn.silu(layer_norm(inputs=conv_2d(out)))
111
+ assert out.shape[1] == 4 and out.shape[2] == 4
112
+ return self.flatten_layer(out)
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/conv_transpose_atari.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ [1] Mastering Diverse Domains through World Models - 2023
3
+ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
4
+ https://arxiv.org/pdf/2301.04104v1.pdf
5
+
6
+ [2] Mastering Atari with Discrete World Models - 2021
7
+ D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
8
+ https://arxiv.org/pdf/2010.02193.pdf
9
+ """
10
+ from typing import Optional
11
+
12
+ import numpy as np
13
+
14
+ from ray.rllib.algorithms.dreamerv3.utils import (
15
+ get_cnn_multiplier,
16
+ get_gru_units,
17
+ get_num_z_categoricals,
18
+ get_num_z_classes,
19
+ )
20
+ from ray.rllib.utils.framework import try_import_tf
21
+
22
+ _, tf, _ = try_import_tf()
23
+
24
+
25
+ class ConvTransposeAtari(tf.keras.Model):
26
+ """A Conv2DTranspose decoder to generate Atari images from a latent space.
27
+
28
+ Wraps an initial single linear layer with a stack of 4 Conv2DTranspose layers (with
29
+ layer normalization) and a diag Gaussian, from which we then sample the final image.
30
+ Sampling is done with a fixed stddev=1.0 and using the mean values coming from the
31
+ last Conv2DTranspose layer.
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ *,
37
+ model_size: Optional[str] = "XS",
38
+ cnn_multiplier: Optional[int] = None,
39
+ gray_scaled: bool,
40
+ ):
41
+ """Initializes a ConvTransposeAtari instance.
42
+
43
+ Args:
44
+ model_size: The "Model Size" used according to [1] Appendinx B.
45
+ Use None for manually setting the `cnn_multiplier`.
46
+ cnn_multiplier: Optional override for the additional factor used to multiply
47
+ the number of filters with each CNN transpose layer. Starting with
48
+ 8 * `cnn_multiplier` filters in the first CNN transpose layer, the
49
+ number of filters then decreases via `4*cnn_multiplier`,
50
+ `2*cnn_multiplier`, till `1*cnn_multiplier`.
51
+ gray_scaled: Whether the last Conv2DTranspose layer's output has only 1
52
+ color channel (gray_scaled=True) or 3 RGB channels (gray_scaled=False).
53
+ """
54
+ super().__init__(name="image_decoder")
55
+
56
+ self.model_size = model_size
57
+ cnn_multiplier = get_cnn_multiplier(self.model_size, override=cnn_multiplier)
58
+
59
+ # The shape going into the first Conv2DTranspose layer.
60
+ # We start with a 4x4 channels=8 "image".
61
+ self.input_dims = (4, 4, 8 * cnn_multiplier)
62
+
63
+ self.gray_scaled = gray_scaled
64
+
65
+ # See appendix B in [1]:
66
+ # "The decoder starts with a dense layer, followed by reshaping
67
+ # to 4 × 4 × C and then inverts the encoder architecture. ..."
68
+ self.dense_layer = tf.keras.layers.Dense(
69
+ units=int(np.prod(self.input_dims)),
70
+ activation=None,
71
+ use_bias=True,
72
+ )
73
+ # Inverse conv2d stack. See cnn_atari.py for corresponding Conv2D stack.
74
+ self.conv_transpose_layers = [
75
+ tf.keras.layers.Conv2DTranspose(
76
+ filters=4 * cnn_multiplier,
77
+ kernel_size=4,
78
+ strides=(2, 2),
79
+ padding="same",
80
+ # No bias or activation due to layernorm.
81
+ activation=None,
82
+ use_bias=False,
83
+ ),
84
+ tf.keras.layers.Conv2DTranspose(
85
+ filters=2 * cnn_multiplier,
86
+ kernel_size=4,
87
+ strides=(2, 2),
88
+ padding="same",
89
+ # No bias or activation due to layernorm.
90
+ activation=None,
91
+ use_bias=False,
92
+ ),
93
+ tf.keras.layers.Conv2DTranspose(
94
+ filters=1 * cnn_multiplier,
95
+ kernel_size=4,
96
+ strides=(2, 2),
97
+ padding="same",
98
+ # No bias or activation due to layernorm.
99
+ activation=None,
100
+ use_bias=False,
101
+ ),
102
+ ]
103
+ # Create one LayerNorm layer for each of the Conv2DTranspose layers.
104
+ self.layer_normalizations = []
105
+ for _ in range(len(self.conv_transpose_layers)):
106
+ self.layer_normalizations.append(tf.keras.layers.LayerNormalization())
107
+
108
+ # Important! No activation or layer norm for last layer as the outputs of
109
+ # this one go directly into the diag-gaussian as parameters.
110
+ self.output_conv2d_transpose = tf.keras.layers.Conv2DTranspose(
111
+ filters=1 if self.gray_scaled else 3,
112
+ kernel_size=4,
113
+ strides=(2, 2),
114
+ padding="same",
115
+ activation=None,
116
+ use_bias=True, # Last layer does use bias (b/c has no LayerNorm).
117
+ )
118
+ # .. until output is 64 x 64 x 3 (or 1 for self.gray_scaled=True).
119
+
120
+ # Trace self.call.
121
+ dl_type = tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32
122
+ self.call = tf.function(
123
+ input_signature=[
124
+ tf.TensorSpec(shape=[None, get_gru_units(model_size)], dtype=dl_type),
125
+ tf.TensorSpec(
126
+ shape=[
127
+ None,
128
+ get_num_z_categoricals(model_size),
129
+ get_num_z_classes(model_size),
130
+ ],
131
+ dtype=dl_type,
132
+ ),
133
+ ]
134
+ )(self.call)
135
+
136
+ def call(self, h, z):
137
+ """Performs a forward pass through the Conv2D transpose decoder.
138
+
139
+ Args:
140
+ h: The deterministic hidden state of the sequence model.
141
+ z: The sequence of stochastic discrete representations of the original
142
+ observation input. Note: `z` is not used for the dynamics predictor
143
+ model (which predicts z from h).
144
+ """
145
+ # Flatten last two dims of z.
146
+ assert len(z.shape) == 3
147
+ z_shape = tf.shape(z)
148
+ z = tf.reshape(z, shape=(z_shape[0], -1))
149
+ assert len(z.shape) == 2
150
+ input_ = tf.concat([h, z], axis=-1)
151
+ input_.set_shape(
152
+ [
153
+ None,
154
+ (
155
+ get_num_z_categoricals(self.model_size)
156
+ * get_num_z_classes(self.model_size)
157
+ + get_gru_units(self.model_size)
158
+ ),
159
+ ]
160
+ )
161
+
162
+ # Feed through initial dense layer to get the right number of input nodes
163
+ # for the first conv2dtranspose layer.
164
+ out = self.dense_layer(input_)
165
+ # Reshape to image format.
166
+ out = tf.reshape(out, shape=(-1,) + self.input_dims)
167
+
168
+ # Pass through stack of Conv2DTransport layers (and layer norms).
169
+ for conv_transpose_2d, layer_norm in zip(
170
+ self.conv_transpose_layers, self.layer_normalizations
171
+ ):
172
+ out = tf.nn.silu(layer_norm(inputs=conv_transpose_2d(out)))
173
+ # Last output conv2d-transpose layer:
174
+ out = self.output_conv2d_transpose(out)
175
+ out += 0.5 # See Danijar's code
176
+ out_shape = tf.shape(out)
177
+
178
+ # Interpret output as means of a diag-Gaussian with std=1.0:
179
+ # From [2]:
180
+ # "Distributions: The image predictor outputs the mean of a diagonal Gaussian
181
+ # likelihood with unit variance, ..."
182
+
183
+ # Reshape `out` for the diagonal multi-variate Gaussian (each pixel is its own
184
+ # independent (b/c diagonal co-variance matrix) variable).
185
+ loc = tf.reshape(out, shape=(out_shape[0], -1))
186
+
187
+ return loc
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/vector_decoder.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ [1] Mastering Diverse Domains through World Models - 2023
3
+ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
4
+ https://arxiv.org/pdf/2301.04104v1.pdf
5
+ """
6
+ import gymnasium as gym
7
+
8
+ from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP
9
+ from ray.rllib.algorithms.dreamerv3.utils import (
10
+ get_gru_units,
11
+ get_num_z_categoricals,
12
+ get_num_z_classes,
13
+ )
14
+ from ray.rllib.utils.framework import try_import_tf
15
+
16
+ _, tf, _ = try_import_tf()
17
+
18
+
19
class VectorDecoder(tf.keras.Model):
    """A simple vector decoder to reproduce non-image (1D vector) observations.

    Wraps an MLP for mean parameter computations and a Gaussian distribution,
    from which we then sample using these mean values and a fixed stddev of 1.0.
    """

    def __init__(
        self,
        *,
        model_size: str = "XS",
        observation_space: gym.Space,
    ):
        """Initializes a VectorDecoder instance.

        Args:
            model_size: The "Model Size" used according to [1] Appendinx B.
                Determines the exact size of the underlying MLP.
            observation_space: The observation space to decode back into. This must
                be a Box of shape (d,), where d >= 1.
        """
        super().__init__(name="vector_decoder")

        self.model_size = model_size

        # Only flat (rank-1) Box observation spaces are supported.
        assert (
            isinstance(observation_space, gym.spaces.Box)
            and len(observation_space.shape) == 1
        )

        self.mlp = MLP(
            model_size=model_size,
            output_layer_size=observation_space.shape[0],
        )

        # Trace self.call with a fixed input signature (single trace for
        # [B, gru_units] h and [B, num_categoricals, num_classes] z).
        dl_type = tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32
        self.call = tf.function(
            input_signature=[
                tf.TensorSpec(shape=[None, get_gru_units(model_size)], dtype=dl_type),
                tf.TensorSpec(
                    shape=[
                        None,
                        get_num_z_categoricals(model_size),
                        get_num_z_classes(model_size),
                    ],
                    dtype=dl_type,
                ),
            ]
        )(self.call)

    def call(self, h, z):
        """Performs a forward pass through the vector encoder.

        Args:
            h: The deterministic hidden state of the sequence model. [B, dim(h)].
            z: The stochastic discrete representations of the original
                observation input. [B, num_categoricals, num_classes].

        Returns:
            The predicted observation means (no sampling), shape (B, d).
        """
        # Flatten last two dims of z.
        assert len(z.shape) == 3
        z = tf.reshape(z, shape=(tf.shape(z)[0], -1))
        assert len(z.shape) == 2
        hz = tf.concat([h, z], axis=-1)
        # Restore the statically known feature dim lost by the dynamic reshape.
        flat_dim = (
            get_num_z_categoricals(self.model_size)
            * get_num_z_classes(self.model_size)
            + get_gru_units(self.model_size)
        )
        hz.set_shape([None, flat_dim])

        # Send h-cat-z through MLP to get mean values of diag gaussian and
        # return only the predicted observations (mean, no sample).
        return self.mlp(hz)
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/critic_network.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ [1] Mastering Diverse Domains through World Models - 2023
3
+ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
4
+ https://arxiv.org/pdf/2301.04104v1.pdf
5
+ """
6
+ from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP
7
+ from ray.rllib.algorithms.dreamerv3.tf.models.components.reward_predictor_layer import (
8
+ RewardPredictorLayer,
9
+ )
10
+ from ray.rllib.algorithms.dreamerv3.utils import (
11
+ get_gru_units,
12
+ get_num_z_categoricals,
13
+ get_num_z_classes,
14
+ )
15
+ from ray.rllib.utils.framework import try_import_tf
16
+
17
+ _, tf, _ = try_import_tf()
18
+
19
+
20
class CriticNetwork(tf.keras.Model):
    """The critic network described in [1], predicting values for policy learning.

    Contains a copy of itself (EMA net) for weight regularization.
    The EMA net is updated after each train step via EMA (using the `ema_decay`
    parameter and the actual critic's weights). The EMA net is NOT used for target
    computations (we use the actual critic for that), its only purpose is to compute a
    weights regularizer term for the critic's loss such that the actual critic does not
    move too quickly.
    """

    def __init__(
        self,
        *,
        model_size: str = "XS",
        num_buckets: int = 255,
        lower_bound: float = -20.0,
        upper_bound: float = 20.0,
        ema_decay: float = 0.98,
    ):
        """Initializes a CriticNetwork instance.

        Args:
            model_size: The "Model Size" used according to [1] Appendinx B.
                Use None for manually setting the different network sizes.
            num_buckets: The number of buckets to create. Note that the number of
                possible symlog'd outcomes from the used distribution is
                `num_buckets` + 1:
                lower_bound --bucket-- o[1] --bucket-- o[2] ... --bucket-- upper_bound
                o=outcomes
                lower_bound=o[0]
                upper_bound=o[num_buckets]
            lower_bound: The symlog'd lower bound for a possible reward value.
                Note that a value of -20.0 here already allows individual (actual env)
                rewards to be as low as -400M. Buckets will be created between
                `lower_bound` and `upper_bound`.
            upper_bound: The symlog'd upper bound for a possible reward value.
                Note that a value of +20.0 here already allows individual (actual env)
                rewards to be as high as 400M. Buckets will be created between
                `lower_bound` and `upper_bound`.
            ema_decay: The weight to use for updating the weights of the critic's copy
                vs the actual critic. After each training update, the EMA copy of the
                critic gets updated according to:
                ema_net=(`ema_decay`*ema_net) + (1.0-`ema_decay`)*critic_net
                The EMA copy of the critic is used inside the critic loss function only
                to produce a regularizer term against the current critic's weights, NOT
                to compute any target values.
        """
        super().__init__(name="critic")

        self.model_size = model_size
        self.ema_decay = ema_decay

        # "Fast" critic network(s) (mlp + reward-pred-layer). This is the network
        # we actually train with our critic loss.
        # IMPORTANT: We also use this to compute the return-targets, BUT we regularize
        # the critic loss term such that the weights of this fast critic stay close
        # to the EMA weights (see below).
        self.mlp = MLP(
            model_size=self.model_size,
            output_layer_size=None,
        )
        self.return_layer = RewardPredictorLayer(
            num_buckets=num_buckets,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
        )

        # Weights-EMA (EWMA) containing networks for critic loss (similar to a
        # target net, BUT not used to compute anything, just for the
        # weights regularizer term inside the critic loss).
        self.mlp_ema = MLP(
            model_size=self.model_size,
            output_layer_size=None,
            trainable=False,
        )
        self.return_layer_ema = RewardPredictorLayer(
            num_buckets=num_buckets,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            trainable=False,
        )

        # Trace self.call with a fixed input signature: [B, gru_units] h,
        # [B, num_categoricals, num_classes] z, and a scalar bool `use_ema`.
        dl_type = tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32
        self.call = tf.function(
            input_signature=[
                tf.TensorSpec(shape=[None, get_gru_units(model_size)], dtype=dl_type),
                tf.TensorSpec(
                    shape=[
                        None,
                        get_num_z_categoricals(model_size),
                        get_num_z_classes(model_size),
                    ],
                    dtype=dl_type,
                ),
                tf.TensorSpec(shape=[], dtype=tf.bool),
            ]
        )(self.call)

    def call(self, h, z, use_ema):
        """Performs a forward pass through the critic network.

        Args:
            h: The deterministic hidden state of the sequence model. [B, dim(h)].
            z: The stochastic discrete representations of the original
                observation input. [B, num_categoricals, num_classes].
            use_ema: Whether to use the EMA-copy of the critic instead of the actual
                critic to perform this computation.
        """
        # Flatten last two dims of z.
        assert len(z.shape) == 3
        z = tf.reshape(z, shape=(tf.shape(z)[0], -1))
        assert len(z.shape) == 2
        hz = tf.concat([h, z], axis=-1)
        # Restore the statically known feature dim lost by the dynamic reshape.
        flat_dim = (
            get_num_z_categoricals(self.model_size)
            * get_num_z_classes(self.model_size)
            + get_gru_units(self.model_size)
        )
        hz.set_shape([None, flat_dim])

        # Pick the EMA-copy or the fast critic and send h-cat-z through it.
        # Returns expected return OR (expected return, probs of bucket values).
        if use_ema:
            return self.return_layer_ema(self.mlp_ema(hz))
        else:
            return self.return_layer(self.mlp(hz))

    def _get_var_pairs(self):
        """Returns the list of (fast-critic var, EMA-copy var) pairs.

        Fast vars are the trainable variables of the critic's MLP and return
        layer; EMA vars are the (non-trainable) variables of their EMA copies.
        """
        fast_vars = self.mlp.trainable_variables + self.return_layer.trainable_variables
        ema_vars = self.mlp_ema.variables + self.return_layer_ema.variables
        assert len(fast_vars) == len(ema_vars) and len(fast_vars) > 0
        return list(zip(fast_vars, ema_vars))

    def init_ema(self) -> None:
        """Initializes the EMA-copy of the critic from the critic's weights.

        After calling this method, the two networks have identical weights.
        """
        for fast_var, ema_var in self._get_var_pairs():
            assert fast_var is not ema_var
            ema_var.assign(fast_var)

    def update_ema(self) -> None:
        """Updates the EMA-copy of the critic according to the update formula:

        ema_net=(`ema_decay`*ema_net) + (1.0-`ema_decay`)*critic_net
        """
        for fast_var, ema_var in self._get_var_pairs():
            ema_var.assign(
                self.ema_decay * ema_var + (1.0 - self.ema_decay) * fast_var
            )
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/disagree_networks.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ [1] Mastering Diverse Domains through World Models - 2023
3
+ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
4
+ https://arxiv.org/pdf/2301.04104v1.pdf
5
+ """
6
+
7
+ from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP
8
+ from ray.rllib.algorithms.dreamerv3.tf.models.components.representation_layer import (
9
+ RepresentationLayer,
10
+ )
11
+ from ray.rllib.utils.framework import try_import_tf, try_import_tfp
12
+
13
+ _, tf, _ = try_import_tf()
14
+ tfp = try_import_tfp()
15
+
16
+
17
class DisagreeNetworks(tf.keras.Model):
    """Predict the RSSM's z^(t+1), given h(t), z^(t), and a(t).

    Disagreement (stddev) between the N networks in this model on what the next z^ would
    be are used to produce intrinsic rewards for enhanced, curiosity-based exploration.

    TODO
    """

    def __init__(self, *, num_networks, model_size, intrinsic_rewards_scale):
        """Initializes a DisagreeNetworks instance.

        Args:
            num_networks: Number of independent (MLP + representation layer)
                predictor nets whose disagreement is measured.
            model_size: The "Model Size" used according to [1] Appendix B.
            intrinsic_rewards_scale: Scale factor applied to the computed
                disagreement before returning it as intrinsic reward.
        """
        super().__init__(name="disagree_networks")

        self.model_size = model_size
        self.num_networks = num_networks
        self.intrinsic_rewards_scale = intrinsic_rewards_scale

        # N independent predictor nets: each is an MLP trunk followed by a
        # representation layer producing z-probabilities.
        self.mlps = [
            MLP(
                model_size=self.model_size,
                output_layer_size=None,
                trainable=True,
            )
            for _ in range(self.num_networks)
        ]
        self.representation_layers = [
            RepresentationLayer(model_size=self.model_size, name="disagree")
            for _ in range(self.num_networks)
        ]

    def call(self, inputs, z, a, training=None):
        """Keras entry point; delegates to `forward_train` (h passed as `inputs`)."""
        return self.forward_train(a=a, h=inputs, z=z)

    def compute_intrinsic_rewards(self, h, z, a):
        """Computes intrinsic rewards from the disagreement between the N nets.

        Args:
            h: The deterministic hidden states. [B, dim(h)].
            z: The stochastic discrete representations.
            a: The actions taken.

        Returns:
            Dict with "rewards_intrinsic" (scaled disagreement per batch item)
            and "forward_train_outs" (the raw forward pass outputs).
        """
        forward_train_outs = self.forward_train(a=a, h=h, z=z)
        batch_size = tf.shape(h)[0]

        # Intrinsic rewards are computed as:
        # Stddev (between the different nets) of the 32x32 discrete, stochastic
        # probabilities. Meaning that if the larger the disagreement
        # (stddev) between the nets on what the probabilities for the different
        # classes should be, the higher the intrinsic reward.
        probs_per_net = forward_train_outs["z_predicted_probs_N_HxB"]
        num_nets = len(probs_per_net)
        stacked_probs = tf.stack(probs_per_net, axis=0)
        # Flatten z-dims (num_categoricals x num_classes).
        stacked_probs = tf.reshape(stacked_probs, shape=(num_nets, batch_size, -1))

        # Compute stddevs over all disagree nets (axis=0), then mean over the
        # folded ([num categoricals] x [num classes]) axis.
        disagreement = tf.reduce_mean(
            tf.math.reduce_std(stacked_probs, axis=0),
            axis=-1,
        )
        # TEST:
        disagreement -= tf.reduce_mean(disagreement)
        # END TEST
        return {
            "rewards_intrinsic": disagreement * self.intrinsic_rewards_scale,
            "forward_train_outs": forward_train_outs,
        }

    def forward_train(self, a, h, z):
        """Runs all N predictor nets on stop-gradient'ed concat([h, z, a]).

        Args:
            a: The actions. [HxB, ...].
            h: The deterministic hidden states. [HxB, dim(h)].
            z: The stochastic discrete representations (z-dims get folded).

        Returns:
            Dict with key "z_predicted_probs_N_HxB": a list (len N) of predicted
            z-probability tensors, one per disagree net.
        """
        folded_size = tf.shape(h)[0]
        # Fold z-dims.
        z = tf.reshape(z, shape=(folded_size, -1))
        # Concat all input components (h, z, and a); gradients must not flow
        # back into the world model from the disagree nets.
        net_inputs = tf.stop_gradient(tf.concat([h, z, a], axis=-1))

        z_predicted_probs_N_HxB = [
            repr_layer(mlp(net_inputs))[1]  # [0]=sample; [1]=returned probs
            for mlp, repr_layer in zip(self.mlps, self.representation_layers)
        ]
        # shape=(N, HxB, [num categoricals], [num classes]); N=number of disagree nets.
        # HxB -> folded horizon_H x batch_size_B (from dreamed data).

        return {"z_predicted_probs_N_HxB": z_predicted_probs_N_HxB}
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/dreamer_model.py ADDED
@@ -0,0 +1,606 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ [1] Mastering Diverse Domains through World Models - 2023
3
+ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
4
+ https://arxiv.org/pdf/2301.04104v1.pdf
5
+ """
6
+ import re
7
+
8
+ import gymnasium as gym
9
+ import numpy as np
10
+
11
+ from ray.rllib.algorithms.dreamerv3.tf.models.disagree_networks import DisagreeNetworks
12
+ from ray.rllib.algorithms.dreamerv3.tf.models.actor_network import ActorNetwork
13
+ from ray.rllib.algorithms.dreamerv3.tf.models.critic_network import CriticNetwork
14
+ from ray.rllib.algorithms.dreamerv3.tf.models.world_model import WorldModel
15
+ from ray.rllib.algorithms.dreamerv3.utils import (
16
+ get_gru_units,
17
+ get_num_z_categoricals,
18
+ get_num_z_classes,
19
+ )
20
+ from ray.rllib.utils.framework import try_import_tf
21
+ from ray.rllib.utils.tf_utils import inverse_symlog
22
+
23
+ _, tf, _ = try_import_tf()
24
+
25
+
26
+ class DreamerModel(tf.keras.Model):
27
+ """The main tf-keras model containing all necessary components for DreamerV3.
28
+
29
+ Includes:
30
+ - The world model with encoder, decoder, sequence-model (RSSM), dynamics
31
+ (generates prior z-state), and "posterior" model (generates posterior z-state).
32
+ Predicts env dynamics and produces dreamed trajectories for actor- and critic
33
+ learning.
34
+ - The actor network (policy).
35
+ - The critic network for value function prediction.
36
+ """
37
+
38
    def __init__(
        self,
        *,
        model_size: str = "XS",
        action_space: gym.Space,
        world_model: WorldModel,
        actor: ActorNetwork,
        critic: CriticNetwork,
        horizon: int,
        gamma: float,
        use_curiosity: bool = False,
        intrinsic_rewards_scale: float = 0.1,
    ):
        """Initializes a DreamerModel instance.

        Args:
            model_size: The "Model Size" used according to [1] Appendix B.
                Use None for manually setting the different network sizes.
            action_space: The action space of the environment used.
            world_model: The WorldModel component.
            actor: The ActorNetwork component.
            critic: The CriticNetwork component.
            horizon: The dream horizon to use when creating dreamed trajectories.
            gamma: The discount factor; used in `dream_trajectory` to turn the
                predicted `continue` flags into per-timestep loss weights.
            use_curiosity: Whether to compute disagreement-based intrinsic
                rewards via an ensemble of `DisagreeNetworks`.
            intrinsic_rewards_scale: Scale for the intrinsic rewards (only
                used if `use_curiosity` is True).
        """
        super().__init__(name="dreamer_model")

        self.model_size = model_size
        self.action_space = action_space
        self.use_curiosity = use_curiosity

        self.world_model = world_model
        self.actor = actor
        self.critic = critic

        self.horizon = horizon
        self.gamma = gamma
        # Compute dtype follows the global keras mixed-precision policy
        # (falls back to float32 if no policy is configured).
        self._comp_dtype = (
            tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32
        )

        self.disagree_nets = None
        if self.use_curiosity:
            self.disagree_nets = DisagreeNetworks(
                num_networks=8,
                model_size=self.model_size,
                intrinsic_rewards_scale=intrinsic_rewards_scale,
            )

        # Pre-trace `dream_trajectory` with a fixed input signature
        # (start_states dict of h: (B, gru_units) and
        # z: (B, num_categoricals, num_classes), plus a bool (B,) tensor for
        # `start_is_terminated`) to avoid tf.function retracing.
        self.dream_trajectory = tf.function(
            input_signature=[
                {
                    "h": tf.TensorSpec(
                        shape=[
                            None,
                            get_gru_units(self.model_size),
                        ],
                        dtype=self._comp_dtype,
                    ),
                    "z": tf.TensorSpec(
                        shape=[
                            None,
                            get_num_z_categoricals(self.model_size),
                            get_num_z_classes(self.model_size),
                        ],
                        dtype=self._comp_dtype,
                    ),
                },
                tf.TensorSpec(shape=[None], dtype=tf.bool),
            ]
        )(self.dream_trajectory)
108
+
109
    def call(
        self,
        inputs,
        observations,
        actions,
        is_first,
        start_is_terminated_BxT,
        gamma,
    ):
        """Main call method for building this model in order to generate its variables.

        Note: This method should NOT be used by users directly. Its purpose is only to
        perform all forward passes necessary to define all variables of the DreamerV3.

        Note: `inputs` and `gamma` are accepted for keras-`call` interface
        compatibility but are not used inside this method.
        """

        # Forward passes through all models are enough to build all trainable and
        # non-trainable variables:

        # World model.
        results = self.world_model.forward_train(
            observations,
            actions,
            is_first,
        )
        # Actor (builds actor variables; action-dist params are discarded).
        _, distr_params = self.actor(
            h=results["h_states_BxT"],
            z=results["z_posterior_states_BxT"],
        )
        # Critic.
        values, _ = self.critic(
            h=results["h_states_BxT"],
            z=results["z_posterior_states_BxT"],
            use_ema=tf.convert_to_tensor(False),
        )

        # Dream pipeline (builds everything reached via dreaming, e.g. the
        # EMA critic and - if enabled - the disagreement nets).
        dream_data = self.dream_trajectory(
            start_states={
                "h": results["h_states_BxT"],
                "z": results["z_posterior_states_BxT"],
            },
            start_is_terminated=start_is_terminated_BxT,
        )

        return {
            "world_model_fwd": results,
            "dream_data": dream_data,
            "actions": actions,
            "values": values,
        }
160
+
161
    @tf.function
    def forward_inference(self, observations, previous_states, is_first, training=None):
        """Performs a (non-exploring) action computation step given obs and states.

        Note that all input data should not have a time rank (only a batch dimension).

        Args:
            observations: The current environment observation with shape (B, ...).
            previous_states: Dict with keys `a`, `h`, and `z` used as input to the RSSM
                to produce the next h-state, from which then to compute the action
                using the actor network. All values in the dict should have shape
                (B, ...) (no time rank).
            is_first: Batch of is_first flags. These should be True if a new episode
                has been started at the current timestep (meaning `observations` is the
                reset observation from the environment).
            training: Unused; present for keras API compatibility.

        Returns:
            Tuple of (actions, next-states dict with keys `h`, `z`, and `a`).
        """
        # Perform one step in the world model (starting from `previous_state` and
        # using the observations to yield a current (posterior) state).
        states = self.world_model.forward_inference(
            observations=observations,
            previous_states=previous_states,
            is_first=is_first,
        )
        # Compute action using our actor network and the current states.
        _, distr_params = self.actor(h=states["h"], z=states["z"])
        # Use the mode of the distribution (Discrete=argmax, Normal=mean),
        # i.e. deterministic/greedy action selection (no exploration).
        distr = self.actor.get_action_dist_object(distr_params)
        actions = distr.mode()
        return actions, {"h": states["h"], "z": states["z"], "a": actions}
190
+
191
    @tf.function
    def forward_exploration(
        self, observations, previous_states, is_first, training=None
    ):
        """Performs an exploratory action computation step given obs and states.

        Note that all input data should not have a time rank (only a batch dimension).

        Args:
            observations: The current environment observation with shape (B, ...).
            previous_states: Dict with keys `a`, `h`, and `z` used as input to the RSSM
                to produce the next h-state, from which then to compute the action
                using the actor network. All values in the dict should have shape
                (B, ...) (no time rank).
            is_first: Batch of is_first flags. These should be True if a new episode
                has been started at the current timestep (meaning `observations` is the
                reset observation from the environment).
            training: Unused; present for keras API compatibility.

        Returns:
            Tuple of (actions, next-states dict with keys `h`, `z`, and `a`).
        """
        # Perform one step in the world model (starting from `previous_state` and
        # using the observations to yield a current (posterior) state).
        states = self.world_model.forward_inference(
            observations=observations,
            previous_states=previous_states,
            is_first=is_first,
        )
        # Compute action using our actor network and the current states.
        # Unlike `forward_inference`, the action here is the actor's (stochastic)
        # sample, which provides the exploration.
        actions, _ = self.actor(h=states["h"], z=states["z"])
        return actions, {"h": states["h"], "z": states["z"], "a": actions}
219
+
220
+ def forward_train(self, observations, actions, is_first):
221
+ """Performs a training forward pass given observations and actions.
222
+
223
+ Note that all input data must have a time rank (batch-major: [B, T, ...]).
224
+
225
+ Args:
226
+ observations: The environment observations with shape (B, T, ...). Thus,
227
+ the batch has B rows of T timesteps each. Note that it's ok to have
228
+ episode boundaries (is_first=True) within a batch row. DreamerV3 will
229
+ simply insert an initial state before these locations and continue the
230
+ sequence modelling (with the RSSM). Hence, there will be no zero
231
+ padding.
232
+ actions: The actions actually taken in the environment with shape
233
+ (B, T, ...). See `observations` docstring for details on how B and T are
234
+ handled.
235
+ is_first: Batch of is_first flags. These should be True:
236
+ - if a new episode has been started at the current timestep (meaning
237
+ `observations` is the reset observation from the environment).
238
+ - in each batch row at T=0 (first timestep of each of the B batch
239
+ rows), regardless of whether the actual env had an episode boundary
240
+ there or not.
241
+ """
242
+ return self.world_model.forward_train(
243
+ observations=observations,
244
+ actions=actions,
245
+ is_first=is_first,
246
+ )
247
+
248
    @tf.function
    def get_initial_state(self):
        """Returns the (current) initial state of the dreamer model (a, h-, z-states).

        An initial state is generated using the previous action, the tanh of the
        (learned) h-state variable and the dynamics predictor (or "prior net") to
        compute z^0 from h0. In this last step, it is important that we do NOT sample
        the z^-state (as we would usually do during dreaming), but rather take the mode
        (argmax, then one-hot again).

        Returns:
            Dict with keys `h`, `z` (from the world model) and `a` (a zeroed
            one-hot/flat action vector of batch size 1).
        """
        # h- and z-states come from the world model's learned initial state.
        states = self.world_model.get_initial_state()

        # Initial action is the all-zeros "one hot" vector (Discrete) or a
        # zeroed flat action vector (continuous spaces).
        action_dim = (
            self.action_space.n
            if isinstance(self.action_space, gym.spaces.Discrete)
            else np.prod(self.action_space.shape)
        )
        states["a"] = tf.zeros(
            (
                1,
                action_dim,
            ),
            dtype=tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32,
        )
        return states
273
+
274
    def dream_trajectory(self, start_states, start_is_terminated):
        """Dreams trajectories of length H from batch of h- and z-states.

        Note that incoming data will have the shapes (BxT, ...), where the original
        batch- and time-dimensions are already folded together. Beginning from this
        new batch dim (BxT), we will unroll `timesteps_H` timesteps in a time-major
        fashion, such that the dreamed data will have shape (H, BxT, ...).

        Args:
            start_states: Dict of `h` and `z` states in the shape of (B, ...) and
                (B, num_categoricals, num_classes), respectively, as
                computed by a train forward pass. From each individual h-/z-state pair
                in the given batch, we will branch off a dreamed trajectory of len
                `timesteps_H`.
            start_is_terminated: Float flags of shape (B,) indicating whether the
                first timesteps of each batch row is already a terminated timestep
                (given by the actual environment).

        Returns:
            Dict of time-major (H+1, BxT, ...) dreamed tensors: h-states, prior
            z-states, rewards, continues, actions (plus dist params and - for
            Discrete spaces - int actions), value estimates (plus symlog'd
            logits and EMA values), and per-timestep dream loss weights. If
            curiosity is enabled, also intrinsic rewards (t1..H) and the
            disagree-nets' forward-train outputs.
        """
        # Dreamed actions (one-hot encoded for discrete actions).
        a_dreamed_t0_to_H = []
        a_dreamed_dist_params_t0_to_H = []

        h = start_states["h"]
        z = start_states["z"]

        # GRU outputs.
        h_states_t0_to_H = [h]
        # Dynamics model outputs.
        z_states_prior_t0_to_H = [z]

        # Compute `a` using actor network (already the first step uses a dreamed action,
        # not a sampled one).
        a, a_dist_params = self.actor(
            # We have to stop the gradients through the states. B/c we are using a
            # differentiable Discrete action distribution (straight through gradients
            # with `a = stop_gradient(sample(probs)) + probs - stop_gradient(probs)`,
            # we otherwise would add dependencies of the `-log(pi(a|s))` REINFORCE loss
            # term on actions further back in the trajectory.
            h=tf.stop_gradient(h),
            z=tf.stop_gradient(z),
        )
        a_dreamed_t0_to_H.append(a)
        a_dreamed_dist_params_t0_to_H.append(a_dist_params)

        for i in range(self.horizon):
            # Move one step in the dream using the RSSM.
            h = self.world_model.sequence_model(a=a, h=h, z=z)
            h_states_t0_to_H.append(h)

            # Compute prior z using dynamics model (no observations available
            # while dreaming).
            z, _ = self.world_model.dynamics_predictor(h=h)
            z_states_prior_t0_to_H.append(z)

            # Compute `a` using actor network.
            a, a_dist_params = self.actor(
                h=tf.stop_gradient(h),
                z=tf.stop_gradient(z),
            )
            a_dreamed_t0_to_H.append(a)
            a_dreamed_dist_params_t0_to_H.append(a_dist_params)

        h_states_H_B = tf.stack(h_states_t0_to_H, axis=0)  # (T, B, ...)
        h_states_HxB = tf.reshape(h_states_H_B, [-1] + h_states_H_B.shape.as_list()[2:])

        z_states_prior_H_B = tf.stack(z_states_prior_t0_to_H, axis=0)  # (T, B, ...)
        z_states_prior_HxB = tf.reshape(
            z_states_prior_H_B, [-1] + z_states_prior_H_B.shape.as_list()[2:]
        )

        a_dreamed_H_B = tf.stack(a_dreamed_t0_to_H, axis=0)  # (T, B, ...)
        a_dreamed_dist_params_H_B = tf.stack(a_dreamed_dist_params_t0_to_H, axis=0)

        # Compute r using reward predictor (predictions are in symlog space and
        # inverted back to the real reward scale here).
        r_dreamed_HxB, _ = self.world_model.reward_predictor(
            h=h_states_HxB, z=z_states_prior_HxB
        )
        r_dreamed_H_B = tf.reshape(
            inverse_symlog(r_dreamed_HxB), shape=[self.horizon + 1, -1]
        )

        # Compute intrinsic rewards.
        if self.use_curiosity:
            results_HxB = self.disagree_nets.compute_intrinsic_rewards(
                h=h_states_HxB,
                z=z_states_prior_HxB,
                a=tf.reshape(a_dreamed_H_B, [-1] + a_dreamed_H_B.shape.as_list()[2:]),
            )
            # TODO (sven): Wrong? -> Cut out last timestep as we always predict z-states
            # for the NEXT timestep and derive ri (for the NEXT timestep) from the
            # disagreement between our N disagreee nets.
            r_intrinsic_H_B = tf.reshape(
                results_HxB["rewards_intrinsic"], shape=[self.horizon + 1, -1]
            )[
                1:
            ]  # cut out first ts instead
            curiosity_forward_train_outs = results_HxB["forward_train_outs"]
            del results_HxB

        # Compute continues using continue predictor.
        c_dreamed_HxB, _ = self.world_model.continue_predictor(
            h=h_states_HxB,
            z=z_states_prior_HxB,
        )
        c_dreamed_H_B = tf.reshape(c_dreamed_HxB, [self.horizon + 1, -1])
        # Force-set first `continue` flags to False iff `start_is_terminated`.
        # Note: This will cause the loss-weights for this row in the batch to be
        # completely zero'd out. In general, we don't use dreamed data past any
        # predicted (or actual first) continue=False flags.
        c_dreamed_H_B = tf.concat(
            [
                1.0
                - tf.expand_dims(
                    tf.cast(start_is_terminated, tf.float32),
                    0,
                ),
                c_dreamed_H_B[1:],
            ],
            axis=0,
        )

        # Loss weights for each individual dreamed timestep. Zero-out all timesteps
        # that lie past continue=False flags. B/c our world model does NOT learn how
        # to skip terminal/reset episode boundaries, dreamed data crossing such a
        # boundary should not be used for critic/actor learning either.
        # (The division by gamma makes the very first step's weight equal to its
        # continue flag rather than gamma * continue.)
        dream_loss_weights_H_B = (
            tf.math.cumprod(self.gamma * c_dreamed_H_B, axis=0) / self.gamma
        )

        # Compute the value estimates (main critic, not the EMA copy).
        v, v_symlog_dreamed_logits_HxB = self.critic(
            h=h_states_HxB,
            z=z_states_prior_HxB,
            use_ema=False,
        )
        v_dreamed_HxB = inverse_symlog(v)
        v_dreamed_H_B = tf.reshape(v_dreamed_HxB, shape=[self.horizon + 1, -1])

        # Also compute values with the EMA (slow-moving) critic weights.
        v_symlog_dreamed_ema_HxB, _ = self.critic(
            h=h_states_HxB,
            z=z_states_prior_HxB,
            use_ema=True,
        )
        v_symlog_dreamed_ema_H_B = tf.reshape(
            v_symlog_dreamed_ema_HxB, shape=[self.horizon + 1, -1]
        )

        ret = {
            "h_states_t0_to_H_BxT": h_states_H_B,
            "z_states_prior_t0_to_H_BxT": z_states_prior_H_B,
            "rewards_dreamed_t0_to_H_BxT": r_dreamed_H_B,
            "continues_dreamed_t0_to_H_BxT": c_dreamed_H_B,
            "actions_dreamed_t0_to_H_BxT": a_dreamed_H_B,
            "actions_dreamed_dist_params_t0_to_H_BxT": a_dreamed_dist_params_H_B,
            "values_dreamed_t0_to_H_BxT": v_dreamed_H_B,
            "values_symlog_dreamed_logits_t0_to_HxBxT": v_symlog_dreamed_logits_HxB,
            "v_symlog_dreamed_ema_t0_to_H_BxT": v_symlog_dreamed_ema_H_B,
            # Loss weights for critic- and actor losses.
            "dream_loss_weights_t0_to_H_BxT": dream_loss_weights_H_B,
        }

        if self.use_curiosity:
            ret["rewards_intrinsic_t1_to_H_B"] = r_intrinsic_H_B
            ret.update(curiosity_forward_train_outs)

        if isinstance(self.action_space, gym.spaces.Discrete):
            ret["actions_ints_dreamed_t0_to_H_B"] = tf.argmax(a_dreamed_H_B, axis=-1)

        return ret
442
+
443
+ def dream_trajectory_with_burn_in(
444
+ self,
445
+ *,
446
+ start_states,
447
+ timesteps_burn_in: int,
448
+ timesteps_H: int,
449
+ observations, # [B, >=timesteps_burn_in]
450
+ actions, # [B, timesteps_burn_in (+timesteps_H)?]
451
+ use_sampled_actions_in_dream: bool = False,
452
+ use_random_actions_in_dream: bool = False,
453
+ ):
454
+ """Dreams trajectory from N initial observations and initial states.
455
+
456
+ Note: This is only used for reporting and debugging, not for actual world-model
457
+ or policy training.
458
+
459
+ Args:
460
+ start_states: The batch of start states (dicts with `a`, `h`, and `z` keys)
461
+ to begin dreaming with. These are used to compute the first h-state
462
+ using the sequence model.
463
+ timesteps_burn_in: For how many timesteps should be use the posterior
464
+ z-states (computed by the posterior net and actual observations from
465
+ the env)?
466
+ timesteps_H: For how many timesteps should we dream using the prior
467
+ z-states (computed by the dynamics (prior) net and h-states only)?
468
+ Note that the total length of the returned trajectories will
469
+ be `timesteps_burn_in` + `timesteps_H`.
470
+ observations: The batch (B, T, ...) of observations (to be used only during
471
+ burn-in over `timesteps_burn_in` timesteps).
472
+ actions: The batch (B, T, ...) of actions to use during a) burn-in over the
473
+ first `timesteps_burn_in` timesteps and - possibly - b) during
474
+ actual dreaming, iff use_sampled_actions_in_dream=True.
475
+ If applicable, actions must already be one-hot'd.
476
+ use_sampled_actions_in_dream: If True, instead of using our actor network
477
+ to compute fresh actions, we will use the one provided via the `actions`
478
+ argument. Note that in the latter case, the `actions` time dimension
479
+ must be at least `timesteps_burn_in` + `timesteps_H` long.
480
+ use_random_actions_in_dream: Whether to use randomly sampled actions in the
481
+ dream. Note that this does not apply to the burn-in phase, during which
482
+ we will always use the actions given in the `actions` argument.
483
+ """
484
+ assert not (use_sampled_actions_in_dream and use_random_actions_in_dream)
485
+
486
+ B = observations.shape[0]
487
+
488
+ # Produce initial N internal posterior states (burn-in) using the given
489
+ # observations:
490
+ states = start_states
491
+ for i in range(timesteps_burn_in):
492
+ states = self.world_model.forward_inference(
493
+ observations=observations[:, i],
494
+ previous_states=states,
495
+ is_first=tf.fill((B,), 1.0 if i == 0 else 0.0),
496
+ )
497
+ states["a"] = actions[:, i]
498
+
499
+ # Start producing the actual dream, using prior states and either the given
500
+ # actions, dreamed, or random ones.
501
+ h_states_t0_to_H = [states["h"]]
502
+ z_states_prior_t0_to_H = [states["z"]]
503
+ a_t0_to_H = [states["a"]]
504
+
505
+ for j in range(timesteps_H):
506
+ # Compute next h using sequence model.
507
+ h = self.world_model.sequence_model(
508
+ a=states["a"],
509
+ h=states["h"],
510
+ z=states["z"],
511
+ )
512
+ h_states_t0_to_H.append(h)
513
+ # Compute z from h, using the dynamics model (we don't have an actual
514
+ # observation at this timestep).
515
+ z, _ = self.world_model.dynamics_predictor(h=h)
516
+ z_states_prior_t0_to_H.append(z)
517
+
518
+ # Compute next dreamed action or use sampled one or random one.
519
+ if use_sampled_actions_in_dream:
520
+ a = actions[:, timesteps_burn_in + j]
521
+ elif use_random_actions_in_dream:
522
+ if isinstance(self.action_space, gym.spaces.Discrete):
523
+ a = tf.random.randint((B,), 0, self.action_space.n, tf.int64)
524
+ a = tf.one_hot(
525
+ a,
526
+ depth=self.action_space.n,
527
+ dtype=tf.keras.mixed_precision.global_policy().compute_dtype
528
+ or tf.float32,
529
+ )
530
+ # TODO: Support cont. action spaces with bound other than 0.0 and 1.0.
531
+ else:
532
+ a = tf.random.uniform(
533
+ shape=(B,) + self.action_space.shape,
534
+ dtype=self.action_space.dtype,
535
+ )
536
+ else:
537
+ a, _ = self.actor(h=h, z=z)
538
+ a_t0_to_H.append(a)
539
+
540
+ states = {"h": h, "z": z, "a": a}
541
+
542
+ # Fold time-rank for upcoming batch-predictions (no sequences needed anymore).
543
+ h_states_t0_to_H_B = tf.stack(h_states_t0_to_H, axis=0)
544
+ h_states_t0_to_HxB = tf.reshape(
545
+ h_states_t0_to_H_B, shape=[-1] + h_states_t0_to_H_B.shape.as_list()[2:]
546
+ )
547
+
548
+ z_states_prior_t0_to_H_B = tf.stack(z_states_prior_t0_to_H, axis=0)
549
+ z_states_prior_t0_to_HxB = tf.reshape(
550
+ z_states_prior_t0_to_H_B,
551
+ shape=[-1] + z_states_prior_t0_to_H_B.shape.as_list()[2:],
552
+ )
553
+
554
+ a_t0_to_H_B = tf.stack(a_t0_to_H, axis=0)
555
+
556
+ # Compute o using decoder.
557
+ o_dreamed_t0_to_HxB = self.world_model.decoder(
558
+ h=h_states_t0_to_HxB,
559
+ z=z_states_prior_t0_to_HxB,
560
+ )
561
+ if self.world_model.symlog_obs:
562
+ o_dreamed_t0_to_HxB = inverse_symlog(o_dreamed_t0_to_HxB)
563
+
564
+ # Compute r using reward predictor.
565
+ r_dreamed_t0_to_HxB, _ = self.world_model.reward_predictor(
566
+ h=h_states_t0_to_HxB,
567
+ z=z_states_prior_t0_to_HxB,
568
+ )
569
+ r_dreamed_t0_to_HxB = inverse_symlog(r_dreamed_t0_to_HxB)
570
+ # Compute continues using continue predictor.
571
+ c_dreamed_t0_to_HxB, _ = self.world_model.continue_predictor(
572
+ h=h_states_t0_to_HxB,
573
+ z=z_states_prior_t0_to_HxB,
574
+ )
575
+
576
+ # Return everything as time-major (H, B, ...), where H is the timesteps dreamed
577
+ # (NOT burn-in'd) and B is a batch dimension (this might or might not include
578
+ # an original time dimension from the real env, from all of which we then branch
579
+ # out our dream trajectories).
580
+ ret = {
581
+ "h_states_t0_to_H_BxT": h_states_t0_to_H_B,
582
+ "z_states_prior_t0_to_H_BxT": z_states_prior_t0_to_H_B,
583
+ # Unfold time-ranks in predictions.
584
+ "observations_dreamed_t0_to_H_BxT": tf.reshape(
585
+ o_dreamed_t0_to_HxB, [-1, B] + list(observations.shape)[2:]
586
+ ),
587
+ "rewards_dreamed_t0_to_H_BxT": tf.reshape(r_dreamed_t0_to_HxB, (-1, B)),
588
+ "continues_dreamed_t0_to_H_BxT": tf.reshape(c_dreamed_t0_to_HxB, (-1, B)),
589
+ }
590
+
591
+ # Figure out action key (random, sampled from env, dreamed?).
592
+ if use_sampled_actions_in_dream:
593
+ key = "actions_sampled_t0_to_H_BxT"
594
+ elif use_random_actions_in_dream:
595
+ key = "actions_random_t0_to_H_BxT"
596
+ else:
597
+ key = "actions_dreamed_t0_to_H_BxT"
598
+ ret[key] = a_t0_to_H_B
599
+
600
+ # Also provide int-actions, if discrete action space.
601
+ if isinstance(self.action_space, gym.spaces.Discrete):
602
+ ret[re.sub("^actions_", "actions_ints_", key)] = tf.argmax(
603
+ a_t0_to_H_B, axis=-1
604
+ )
605
+
606
+ return ret
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/world_model.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ [1] Mastering Diverse Domains through World Models - 2023
3
+ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
4
+ https://arxiv.org/pdf/2301.04104v1.pdf
5
+ """
6
+ from typing import Optional
7
+
8
+ import gymnasium as gym
9
+ import tree # pip install dm_tree
10
+
11
+ from ray.rllib.algorithms.dreamerv3.tf.models.components.continue_predictor import (
12
+ ContinuePredictor,
13
+ )
14
+ from ray.rllib.algorithms.dreamerv3.tf.models.components.dynamics_predictor import (
15
+ DynamicsPredictor,
16
+ )
17
+ from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP
18
+ from ray.rllib.algorithms.dreamerv3.tf.models.components.representation_layer import (
19
+ RepresentationLayer,
20
+ )
21
+ from ray.rllib.algorithms.dreamerv3.tf.models.components.reward_predictor import (
22
+ RewardPredictor,
23
+ )
24
+ from ray.rllib.algorithms.dreamerv3.tf.models.components.sequence_model import (
25
+ SequenceModel,
26
+ )
27
+ from ray.rllib.algorithms.dreamerv3.utils import get_gru_units
28
+ from ray.rllib.utils.framework import try_import_tf
29
+ from ray.rllib.utils.tf_utils import symlog
30
+
31
+
32
+ _, tf, _ = try_import_tf()
33
+
34
+
35
+ class WorldModel(tf.keras.Model):
36
+ """WorldModel component of [1] w/ encoder, decoder, RSSM, reward/cont. predictors.
37
+
38
+ See eq. 3 of [1] for all components and their respective in- and outputs.
39
+ Note that in the paper, the "encoder" includes both the raw encoder plus the
40
+ "posterior net", which produces posterior z-states from observations and h-states.
41
+
42
+ Note: The "internal state" of the world model always consists of:
43
+ The actions `a` (initially, this is a zeroed-out action), `h`-states (deterministic,
44
+ continuous), and `z`-states (stochastic, discrete).
45
+ There are two versions of z-states: "posterior" for world model training and "prior"
46
+ for creating the dream data.
47
+
48
+ Initial internal state values (`a`, `h`, and `z`) are inserted where ever a new
49
+ episode starts within a batch row OR at the beginning of each train batch's B rows,
50
+ regardless of whether there was an actual episode boundary or not. Thus, internal
51
+ states are not required to be stored in or retrieved from the replay buffer AND
52
+ retrieved batches from the buffer must not be zero padded.
53
+
54
+ Initial `a` is the zero "one hot" action, e.g. [0.0, 0.0] for Discrete(2), initial
55
+ `h` is a separate learned variable, and initial `z` are computed by the "dynamics"
56
+ (or "prior") net, using only the initial-h state as input.
57
+ """
58
+
59
    def __init__(
        self,
        *,
        model_size: str = "XS",
        observation_space: gym.Space,
        action_space: gym.Space,
        batch_length_T: int = 64,
        encoder: tf.keras.Model,
        decoder: tf.keras.Model,
        num_gru_units: Optional[int] = None,
        symlog_obs: bool = True,
    ):
        """Initializes a WorldModel instance.

        Args:
            model_size: The "Model Size" used according to [1] Appendix B.
                Use None for manually setting the different network sizes.
            observation_space: The observation space of the environment used.
            action_space: The action space of the environment used.
            batch_length_T: The length (T) of the sequences used for training. The
                actual shape of the input data (e.g. rewards) is then: [B, T, ...],
                where B is the "batch size", T is the "batch length" (this arg) and
                "..." is the dimension of the data (e.g. (64, 64, 3) for Atari image
                observations). Note that a single row (within a batch) may contain data
                from different episodes, but an already on-going episode is always
                finished, before a new one starts within the same row.
            encoder: The encoder Model taking observations as inputs and
                outputting a 1D latent vector that will be used as input into the
                posterior net (z-posterior state generating layer). Inputs are symlogged
                if inputs are NOT images. For images, we use normalization between -1.0
                and 1.0 (x / 128 - 1.0)
            decoder: The decoder Model taking h- and z-states as inputs and generating
                a (possibly symlogged) predicted observation. Note that for images,
                the last decoder layer produces the exact, normalized pixel values
                (not a Gaussian as described in [1]!).
            num_gru_units: The number of GRU units to use. If None, use
                `model_size` to figure out this parameter.
            symlog_obs: Whether to predict decoded observations in symlog space.
                This should be False for image based observations.
                According to the paper [1] Appendix E: "NoObsSymlog: This ablation
                removes the symlog encoding of inputs to the world model and also
                changes the symlog MSE loss in the decoder to a simple MSE loss.
                *Because symlog encoding is only used for vector observations*, this
                ablation is equivalent to DreamerV3 on purely image-based environments".
        """
        super().__init__(name="world_model")

        self.model_size = model_size
        self.batch_length_T = batch_length_T
        self.symlog_obs = symlog_obs
        self.observation_space = observation_space
        self.action_space = action_space
        # Compute dtype follows the global keras mixed-precision policy
        # (falls back to float32 if no policy is configured).
        self._comp_dtype = (
            tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32
        )

        # Encoder (latent 1D vector generator) (xt -> lt).
        self.encoder = encoder

        # Posterior predictor consisting of an MLP and a RepresentationLayer:
        # [ht, lt] -> zt.
        self.posterior_mlp = MLP(
            model_size=self.model_size,
            output_layer_size=None,
            # In Danijar's code, the posterior predictor only has a single layer,
            # no matter the model size:
            num_dense_layers=1,
            name="posterior_mlp",
        )
        # The (posterior) z-state generating layer.
        self.posterior_representation_layer = RepresentationLayer(
            model_size=self.model_size,
        )

        # Dynamics (prior z-state) predictor: ht -> z^t
        self.dynamics_predictor = DynamicsPredictor(model_size=self.model_size)

        # GRU for the RSSM: [at, ht, zt] -> ht+1
        self.num_gru_units = get_gru_units(
            model_size=self.model_size,
            override=num_gru_units,
        )
        # Initial h-state variable (learnt).
        # -> tanh(self.initial_h) -> deterministic state
        # Use our Dynamics predictor for initial stochastic state, BUT with greedy
        # (mode) instead of sampling.
        self.initial_h = tf.Variable(
            tf.zeros(shape=(self.num_gru_units,)),
            trainable=True,
            name="initial_h",
        )
        # The actual sequence model containing the GRU layer.
        self.sequence_model = SequenceModel(
            model_size=self.model_size,
            action_space=self.action_space,
            num_gru_units=self.num_gru_units,
        )

        # Reward Predictor: [ht, zt] -> rt.
        self.reward_predictor = RewardPredictor(model_size=self.model_size)
        # Continue Predictor: [ht, zt] -> ct.
        self.continue_predictor = ContinuePredictor(model_size=self.model_size)

        # Decoder: [ht, zt] -> x^t.
        self.decoder = decoder

        # Trace self.call.
        # Signature: observations (B, T, obs-shape), actions (B, T, act-dim;
        # one-hot size for Discrete spaces), is_first (B, T) bool.
        self.forward_train = tf.function(
            input_signature=[
                tf.TensorSpec(shape=[None, None] + list(self.observation_space.shape)),
                tf.TensorSpec(
                    shape=[None, None]
                    + (
                        [self.action_space.n]
                        if isinstance(action_space, gym.spaces.Discrete)
                        else list(self.action_space.shape)
                    )
                ),
                tf.TensorSpec(shape=[None, None], dtype=tf.bool),
            ]
        )(self.forward_train)
180
+
181
+ @tf.function
182
+ def get_initial_state(self):
183
+ """Returns the (current) initial state of the world model (h- and z-states).
184
+
185
+ An initial state is generated using the tanh of the (learned) h-state variable
186
+ and the dynamics predictor (or "prior net") to compute z^0 from h0. In this last
187
+ step, it is important that we do NOT sample the z^-state (as we would usually
188
+ do during dreaming), but rather take the mode (argmax, then one-hot again).
189
+ """
190
+ h = tf.expand_dims(tf.math.tanh(tf.cast(self.initial_h, self._comp_dtype)), 0)
191
+ # Use the mode, NOT a sample for the initial z-state.
192
+ _, z_probs = self.dynamics_predictor(h)
193
+ z = tf.argmax(z_probs, axis=-1)
194
+ z = tf.one_hot(z, depth=z_probs.shape[-1], dtype=self._comp_dtype)
195
+
196
+ return {"h": h, "z": z}
197
+
198
    def forward_inference(self, observations, previous_states, is_first, training=None):
        """Performs a forward step for inference (e.g. environment stepping).

        Works analogous to `forward_train`, except that all inputs are provided
        for a single timestep in the shape of [B, ...] (no time dimension!).

        Args:
            observations: The batch (B, ...) of observations to be passed through
                the encoder network to yield the inputs to the representation layer
                (which then can compute the z-states).
            previous_states: A dict with `h`, `z`, and `a` keys mapping to the
                respective previous states/actions. All of the shape (B, ...), no time
                rank.
            is_first: The batch (B) of `is_first` flags.
            training: Unused; present only for Keras call-signature compatibility.

        Returns:
            A dict with keys "h" (the next deterministic state, as predicted by the
            sequence model) and "z" (the posterior stochastic state, sampled from
            `observations` conditioned on the new h).
        """
        observations = tf.cast(observations, self._comp_dtype)

        # Tile the (batch-size 1) learnt initial state up to the actual batch size.
        initial_states = tree.map_structure(
            lambda s: tf.repeat(s, tf.shape(observations)[0], axis=0),
            self.get_initial_state(),
        )

        # Rows flagged is_first start a new episode: replace their previous h-/z-
        # states with the learnt initial state (zero out, then add the init values).
        previous_h = self._mask(previous_states["h"], 1.0 - is_first)  # zero out
        previous_h = previous_h + self._mask(initial_states["h"], is_first)  # add init

        previous_z = self._mask(previous_states["z"], 1.0 - is_first)  # zero out
        previous_z = previous_z + self._mask(initial_states["z"], is_first)  # add init

        # Zero out actions (no special learnt initial state).
        previous_a = self._mask(previous_states["a"], 1.0 - is_first)

        # Compute new states.
        h = self.sequence_model(a=previous_a, h=previous_h, z=previous_z)
        z = self.compute_posterior_z(observations=observations, initial_h=h)

        return {"h": h, "z": z}
238
+
239
    def forward_train(self, observations, actions, is_first):
        """Performs a forward step for training.

        1) Forwards all observations [B, T, ...] through the encoder network to yield
        o_processed[B, T, ...].
        2) Uses initial state (h0/z^0/a0[B, 0, ...]) and sequence model (RSSM) to
        compute the first internal state (h1 and z^1).
        3) Uses action a[B, 1, ...], z[B, 1, ...] and h[B, 1, ...] to compute the
        next h-state (h[B, 2, ...]), etc..
        4) Repeats 2) and 3) until t=T.
        5) Uses all h[B, T, ...] and z[B, T, ...] to compute predicted/reconstructed
        observations, rewards, and continue signals.
        6) Returns predictions from 5) along with all z-states z[B, T, ...] and
        the final h-state (h[B, ...] for t=T).

        Should we encounter is_first=True flags in the middle of a batch row (somewhere
        within an ongoing sequence of length T), we insert this world model's initial
        state again (zero-action, learned init h-state, and prior-computed z^) and
        simply continue (no zero-padding).

        Args:
            observations: The batch (B, T, ...) of observations to be passed through
                the encoder network to yield the inputs to the representation layer
                (which then can compute the posterior z-states).
            actions: The batch (B, T, ...) of actions to be used in combination with
                h-states and computed z-states to yield the next h-states.
            is_first: The batch (B, T) of `is_first` flags.

        Returns:
            A dict of loss-ready tensors, each with batch and time ranks already
            folded into one ([B*T, ...]): the (possibly symlog'd) input observations
            and the decoder's reconstruction means, reward logits/values, continue
            distribution/samples, all h-states, the sampled posterior z-states with
            their probs, and the prior z-probs.
        """
        if self.symlog_obs:
            observations = symlog(observations)

        # Compute bare encoder outs (not z; this is done later with involvement of the
        # sequence model and the h-states).
        # Fold time dimension for CNN pass.
        shape = tf.shape(observations)
        B, T = shape[0], shape[1]
        observations = tf.reshape(
            observations, shape=tf.concat([[-1], shape[2:]], axis=0)
        )

        encoder_out = self.encoder(tf.cast(observations, self._comp_dtype))
        # Unfold time dimension.
        encoder_out = tf.reshape(
            encoder_out, shape=tf.concat([[B, T], tf.shape(encoder_out)[1:]], axis=0)
        )
        # Make time major for faster upcoming loop.
        encoder_out = tf.transpose(
            encoder_out, perm=[1, 0] + list(range(2, len(encoder_out.shape.as_list())))
        )
        # encoder_out=[T, B, ...]

        # Tile the (batch-size 1) learnt initial state up to batch size B.
        initial_states = tree.map_structure(
            lambda s: tf.repeat(s, B, axis=0), self.get_initial_state()
        )

        # Make actions and `is_first` time-major.
        # `tf.shape(actions).shape.as_list()[0]` is the (static) rank of `actions`.
        actions = tf.transpose(
            tf.cast(actions, self._comp_dtype),
            perm=[1, 0] + list(range(2, tf.shape(actions).shape.as_list()[0])),
        )
        is_first = tf.transpose(tf.cast(is_first, self._comp_dtype), perm=[1, 0])

        # Loop through the T-axis of our samples and perform one computation step at
        # a time. This is necessary because the sequence model's output (h(t+1)) depends
        # on the current z(t), but z(t) depends on the current sequence model's output
        # h(t).
        z_t0_to_T = [initial_states["z"]]
        z_posterior_probs = []
        z_prior_probs = []
        h_t0_to_T = [initial_states["h"]]
        for t in range(self.batch_length_T):
            # If first, mask it with initial state/actions.
            h_tm1 = self._mask(h_t0_to_T[-1], 1.0 - is_first[t])  # zero out
            h_tm1 = h_tm1 + self._mask(initial_states["h"], is_first[t])  # add init

            z_tm1 = self._mask(z_t0_to_T[-1], 1.0 - is_first[t])  # zero out
            z_tm1 = z_tm1 + self._mask(initial_states["z"], is_first[t])  # add init

            # Zero out actions (no special learnt initial state).
            # NOTE(review): at t=0 this reads actions[-1] (wrap-around to the LAST
            # timestep's action); presumably is_first[0] is always set at sequence
            # starts so this wrapped value gets zeroed out — confirm with the
            # episode/batch sampling code.
            a_tm1 = self._mask(actions[t - 1], 1.0 - is_first[t])

            # Perform one RSSM (sequence model) step to get the current h.
            h_t = self.sequence_model(a=a_tm1, h=h_tm1, z=z_tm1)
            h_t0_to_T.append(h_t)

            posterior_mlp_input = tf.concat([encoder_out[t], h_t], axis=-1)
            repr_input = self.posterior_mlp(posterior_mlp_input)
            # Draw one z-sample (z(t)) and also get the z-distribution for dynamics and
            # representation loss computations.
            z_t, z_probs = self.posterior_representation_layer(repr_input)
            # z_t=[B, num_categoricals, num_classes]
            z_posterior_probs.append(z_probs)
            z_t0_to_T.append(z_t)

            # Compute the predicted z_t (z^) using the dynamics model.
            _, z_probs = self.dynamics_predictor(h_t)
            z_prior_probs.append(z_probs)

        # Stack at time dimension to yield: [B, T, ...].
        # Note: index 0 (the initial state) is dropped; only t=1..T is returned.
        h_t1_to_T = tf.stack(h_t0_to_T[1:], axis=1)
        z_t1_to_T = tf.stack(z_t0_to_T[1:], axis=1)

        # Fold time axis to retrieve the final (loss ready) Independent distribution
        # (over `num_categoricals` Categoricals).
        z_posterior_probs = tf.stack(z_posterior_probs, axis=1)
        z_posterior_probs = tf.reshape(
            z_posterior_probs,
            shape=[-1] + z_posterior_probs.shape.as_list()[2:],
        )
        # Fold time axis to retrieve the final (loss ready) Independent distribution
        # (over `num_categoricals` Categoricals).
        z_prior_probs = tf.stack(z_prior_probs, axis=1)
        z_prior_probs = tf.reshape(
            z_prior_probs,
            shape=[-1] + z_prior_probs.shape.as_list()[2:],
        )

        # Fold time dimension for parallelization of all dependent predictions:
        # observations (reproduction via decoder), rewards, continues.
        h_BxT = tf.reshape(h_t1_to_T, shape=[-1] + h_t1_to_T.shape.as_list()[2:])
        z_BxT = tf.reshape(z_t1_to_T, shape=[-1] + z_t1_to_T.shape.as_list()[2:])

        # Decoder means are cast to float32 (loss computation precision), regardless
        # of the mixed-precision compute dtype.
        obs_distribution_means = tf.cast(self.decoder(h=h_BxT, z=z_BxT), tf.float32)

        # Compute (predicted) reward distributions.
        rewards, reward_logits = self.reward_predictor(h=h_BxT, z=z_BxT)

        # Compute (predicted) continue distributions.
        continues, continue_distribution = self.continue_predictor(h=h_BxT, z=z_BxT)

        # Return outputs for loss computation.
        # Note that all shapes are [BxT, ...] (time axis already folded).
        return {
            # Obs.
            "sampled_obs_symlog_BxT": observations,
            "obs_distribution_means_BxT": obs_distribution_means,
            # Rewards.
            "reward_logits_BxT": reward_logits,
            "rewards_BxT": rewards,
            # Continues.
            "continue_distribution_BxT": continue_distribution,
            "continues_BxT": continues,
            # Deterministic, continuous h-states (t1 to T).
            "h_states_BxT": h_BxT,
            # Sampled, discrete posterior z-states and their probs (t1 to T).
            "z_posterior_states_BxT": z_BxT,
            "z_posterior_probs_BxT": z_posterior_probs,
            # Probs of the prior z-states (t1 to T).
            "z_prior_probs_BxT": z_prior_probs,
        }
389
+
390
+ def compute_posterior_z(self, observations, initial_h):
391
+ # Compute bare encoder outputs (not including z, which is computed in next step
392
+ # with involvement of the previous output (initial_h) of the sequence model).
393
+ # encoder_outs=[B, ...]
394
+ if self.symlog_obs:
395
+ observations = symlog(observations)
396
+ encoder_out = self.encoder(observations)
397
+ # Concat encoder outs with the h-states.
398
+ posterior_mlp_input = tf.concat([encoder_out, initial_h], axis=-1)
399
+ # Compute z.
400
+ repr_input = self.posterior_mlp(posterior_mlp_input)
401
+ # Draw a z-sample.
402
+ z_t, _ = self.posterior_representation_layer(repr_input)
403
+ return z_t
404
+
405
+ @staticmethod
406
+ def _mask(value, mask):
407
+ return tf.einsum("b...,b->b...", value, tf.cast(mask, value.dtype))
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Public entry points for RLlib's PPO algorithm.

Re-exports the PPO algorithm/config plus the (old-API-stack) TF and Torch
policy classes so users can import them directly from this package.
"""

from ray.rllib.algorithms.ppo.ppo import PPOConfig, PPO
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy, PPOTF2Policy
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy

__all__ = [
    "PPO",
    "PPOConfig",
    # @OldAPIStack
    "PPOTF1Policy",
    "PPOTF2Policy",
    "PPOTorchPolicy",
]
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (568 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/default_ppo_rl_module.cpython-311.pyc ADDED
Binary file (3.41 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo.cpython-311.pyc ADDED
Binary file (23.6 kB). View file
 
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_catalog.cpython-311.pyc ADDED
Binary file (8.73 kB). View file