Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- .venv/lib/python3.11/site-packages/ray/_private/__pycache__/test_utils.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/ray/_private/__pycache__/worker.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/callbacks.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/mock.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/registry.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm_config.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__pycache__/default_bc_torch_rl_module.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/callbacks.py +8 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__init__.py +9 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql_tf_policy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql_torch_policy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql.py +388 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql_tf_policy.py +426 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql_torch_policy.py +406 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/cql_torch_learner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/default_cql_torch_rl_module.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/cql_torch_learner.py +275 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/default_cql_torch_rl_module.py +206 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__init__.py +15 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3.py +750 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3_catalog.py +80 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3_learner.py +31 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py +153 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/__pycache__/dreamerv3_tf_learner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/__pycache__/dreamerv3_tf_rl_module.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_learner.py +915 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_rl_module.py +23 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/actor_network.py +203 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/cnn_atari.py +112 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/conv_transpose_atari.py +187 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/vector_decoder.py +98 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/critic_network.py +177 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/disagree_networks.py +94 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/dreamer_model.py +606 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/world_model.py +407 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__init__.py +12 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/default_ppo_rl_module.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_catalog.cpython-311.pyc +0 -0
.gitattributes
CHANGED
|
@@ -173,3 +173,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 173 |
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 174 |
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/propcache/_helpers_c.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 175 |
.venv/lib/python3.11/site-packages/ray/jars/ray_dist.jar filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 173 |
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 174 |
.venv/lib/python3.11/site-packages/ray/_private/runtime_env/agent/thirdparty_files/propcache/_helpers_c.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 175 |
.venv/lib/python3.11/site-packages/ray/jars/ray_dist.jar filter=lfs diff=lfs merge=lfs -text
|
| 176 |
+
.venv/lib/python3.11/site-packages/ray/_private/__pycache__/worker.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 177 |
+
.venv/lib/python3.11/site-packages/ray/_private/__pycache__/test_utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/ray/_private/__pycache__/test_utils.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cc96e86e5e36ee78f9cfcd3d87220524f3cb583ba7b0472482fe408fbc1c57fa
|
| 3 |
+
size 114677
|
.venv/lib/python3.11/site-packages/ray/_private/__pycache__/worker.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e715fb00f3b4360472455b9c5d37eb8337c42bc50fea95d2d75fa67bebdcb096
|
| 3 |
+
size 158454
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.39 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/callbacks.cpython-311.pyc
ADDED
|
Binary file (424 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/mock.cpython-311.pyc
ADDED
|
Binary file (8.29 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/registry.cpython-311.pyc
ADDED
|
Binary file (6.39 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (5.86 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm_config.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/bc/torch/__pycache__/default_bc_torch_rl_module.cpython-311.pyc
ADDED
|
Binary file (3.16 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/callbacks.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @OldAPIStack
|
| 2 |
+
from ray.rllib.callbacks.callbacks import RLlibCallback
|
| 3 |
+
from ray.rllib.callbacks.utils import _make_multi_callbacks
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# Backward compatibility
|
| 7 |
+
DefaultCallbacks = RLlibCallback
|
| 8 |
+
make_multi_callbacks = _make_multi_callbacks
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.algorithms.cql.cql import CQL, CQLConfig
|
| 2 |
+
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"CQL",
|
| 6 |
+
"CQLConfig",
|
| 7 |
+
# @OldAPIStack
|
| 8 |
+
"CQLTorchPolicy",
|
| 9 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (438 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql.cpython-311.pyc
ADDED
|
Binary file (17.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql_tf_policy.cpython-311.pyc
ADDED
|
Binary file (20.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/__pycache__/cql_torch_policy.cpython-311.pyc
ADDED
|
Binary file (19.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Optional, Type, Union
|
| 3 |
+
|
| 4 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
|
| 5 |
+
from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
|
| 6 |
+
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
|
| 7 |
+
from ray.rllib.algorithms.sac.sac import (
|
| 8 |
+
SAC,
|
| 9 |
+
SACConfig,
|
| 10 |
+
)
|
| 11 |
+
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
|
| 12 |
+
AddObservationsFromEpisodesToBatch,
|
| 13 |
+
)
|
| 14 |
+
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
|
| 15 |
+
AddNextObservationsFromEpisodesToTrainBatch,
|
| 16 |
+
)
|
| 17 |
+
from ray.rllib.core.learner.learner import Learner
|
| 18 |
+
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
|
| 19 |
+
from ray.rllib.execution.rollout_ops import (
|
| 20 |
+
synchronous_parallel_sample,
|
| 21 |
+
)
|
| 22 |
+
from ray.rllib.execution.train_ops import (
|
| 23 |
+
multi_gpu_train_one_step,
|
| 24 |
+
train_one_step,
|
| 25 |
+
)
|
| 26 |
+
from ray.rllib.policy.policy import Policy
|
| 27 |
+
from ray.rllib.utils.annotations import OldAPIStack, override
|
| 28 |
+
from ray.rllib.utils.deprecation import (
|
| 29 |
+
DEPRECATED_VALUE,
|
| 30 |
+
deprecation_warning,
|
| 31 |
+
)
|
| 32 |
+
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
|
| 33 |
+
from ray.rllib.utils.metrics import (
|
| 34 |
+
ALL_MODULES,
|
| 35 |
+
LEARNER_RESULTS,
|
| 36 |
+
LEARNER_UPDATE_TIMER,
|
| 37 |
+
LAST_TARGET_UPDATE_TS,
|
| 38 |
+
NUM_AGENT_STEPS_SAMPLED,
|
| 39 |
+
NUM_AGENT_STEPS_TRAINED,
|
| 40 |
+
NUM_ENV_STEPS_SAMPLED,
|
| 41 |
+
NUM_ENV_STEPS_TRAINED,
|
| 42 |
+
NUM_TARGET_UPDATES,
|
| 43 |
+
OFFLINE_SAMPLING_TIMER,
|
| 44 |
+
TARGET_NET_UPDATE_TIMER,
|
| 45 |
+
SYNCH_WORKER_WEIGHTS_TIMER,
|
| 46 |
+
SAMPLE_TIMER,
|
| 47 |
+
TIMERS,
|
| 48 |
+
)
|
| 49 |
+
from ray.rllib.utils.typing import ResultDict, RLModuleSpecType
|
| 50 |
+
|
| 51 |
+
tf1, tf, tfv = try_import_tf()
|
| 52 |
+
tfp = try_import_tfp()
|
| 53 |
+
logger = logging.getLogger(__name__)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class CQLConfig(SACConfig):
|
| 57 |
+
"""Defines a configuration class from which a CQL can be built.
|
| 58 |
+
|
| 59 |
+
.. testcode::
|
| 60 |
+
:skipif: True
|
| 61 |
+
|
| 62 |
+
from ray.rllib.algorithms.cql import CQLConfig
|
| 63 |
+
config = CQLConfig().training(gamma=0.9, lr=0.01)
|
| 64 |
+
config = config.resources(num_gpus=0)
|
| 65 |
+
config = config.env_runners(num_env_runners=4)
|
| 66 |
+
print(config.to_dict())
|
| 67 |
+
# Build a Algorithm object from the config and run 1 training iteration.
|
| 68 |
+
algo = config.build(env="CartPole-v1")
|
| 69 |
+
algo.train()
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
def __init__(self, algo_class=None):
|
| 73 |
+
super().__init__(algo_class=algo_class or CQL)
|
| 74 |
+
|
| 75 |
+
# fmt: off
|
| 76 |
+
# __sphinx_doc_begin__
|
| 77 |
+
# CQL-specific config settings:
|
| 78 |
+
self.bc_iters = 20000
|
| 79 |
+
self.temperature = 1.0
|
| 80 |
+
self.num_actions = 10
|
| 81 |
+
self.lagrangian = False
|
| 82 |
+
self.lagrangian_thresh = 5.0
|
| 83 |
+
self.min_q_weight = 5.0
|
| 84 |
+
self.deterministic_backup = True
|
| 85 |
+
self.lr = 3e-4
|
| 86 |
+
# Note, the new stack defines learning rates for each component.
|
| 87 |
+
# The base learning rate `lr` has to be set to `None`, if using
|
| 88 |
+
# the new stack.
|
| 89 |
+
self.actor_lr = 1e-4
|
| 90 |
+
self.critic_lr = 1e-3
|
| 91 |
+
self.alpha_lr = 1e-3
|
| 92 |
+
|
| 93 |
+
self.replay_buffer_config = {
|
| 94 |
+
"_enable_replay_buffer_api": True,
|
| 95 |
+
"type": "MultiAgentPrioritizedReplayBuffer",
|
| 96 |
+
"capacity": int(1e6),
|
| 97 |
+
# If True prioritized replay buffer will be used.
|
| 98 |
+
"prioritized_replay": False,
|
| 99 |
+
"prioritized_replay_alpha": 0.6,
|
| 100 |
+
"prioritized_replay_beta": 0.4,
|
| 101 |
+
"prioritized_replay_eps": 1e-6,
|
| 102 |
+
# Whether to compute priorities already on the remote worker side.
|
| 103 |
+
"worker_side_prioritization": False,
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
# Changes to Algorithm's/SACConfig's default:
|
| 107 |
+
|
| 108 |
+
# .reporting()
|
| 109 |
+
self.min_sample_timesteps_per_iteration = 0
|
| 110 |
+
self.min_train_timesteps_per_iteration = 100
|
| 111 |
+
# fmt: on
|
| 112 |
+
# __sphinx_doc_end__
|
| 113 |
+
|
| 114 |
+
self.timesteps_per_iteration = DEPRECATED_VALUE
|
| 115 |
+
|
| 116 |
+
@override(SACConfig)
|
| 117 |
+
def training(
|
| 118 |
+
self,
|
| 119 |
+
*,
|
| 120 |
+
bc_iters: Optional[int] = NotProvided,
|
| 121 |
+
temperature: Optional[float] = NotProvided,
|
| 122 |
+
num_actions: Optional[int] = NotProvided,
|
| 123 |
+
lagrangian: Optional[bool] = NotProvided,
|
| 124 |
+
lagrangian_thresh: Optional[float] = NotProvided,
|
| 125 |
+
min_q_weight: Optional[float] = NotProvided,
|
| 126 |
+
deterministic_backup: Optional[bool] = NotProvided,
|
| 127 |
+
**kwargs,
|
| 128 |
+
) -> "CQLConfig":
|
| 129 |
+
"""Sets the training-related configuration.
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
bc_iters: Number of iterations with Behavior Cloning pretraining.
|
| 133 |
+
temperature: CQL loss temperature.
|
| 134 |
+
num_actions: Number of actions to sample for CQL loss
|
| 135 |
+
lagrangian: Whether to use the Lagrangian for Alpha Prime (in CQL loss).
|
| 136 |
+
lagrangian_thresh: Lagrangian threshold.
|
| 137 |
+
min_q_weight: in Q weight multiplier.
|
| 138 |
+
deterministic_backup: If the target in the Bellman update should have an
|
| 139 |
+
entropy backup. Defaults to `True`.
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
This updated AlgorithmConfig object.
|
| 143 |
+
"""
|
| 144 |
+
# Pass kwargs onto super's `training()` method.
|
| 145 |
+
super().training(**kwargs)
|
| 146 |
+
|
| 147 |
+
if bc_iters is not NotProvided:
|
| 148 |
+
self.bc_iters = bc_iters
|
| 149 |
+
if temperature is not NotProvided:
|
| 150 |
+
self.temperature = temperature
|
| 151 |
+
if num_actions is not NotProvided:
|
| 152 |
+
self.num_actions = num_actions
|
| 153 |
+
if lagrangian is not NotProvided:
|
| 154 |
+
self.lagrangian = lagrangian
|
| 155 |
+
if lagrangian_thresh is not NotProvided:
|
| 156 |
+
self.lagrangian_thresh = lagrangian_thresh
|
| 157 |
+
if min_q_weight is not NotProvided:
|
| 158 |
+
self.min_q_weight = min_q_weight
|
| 159 |
+
if deterministic_backup is not NotProvided:
|
| 160 |
+
self.deterministic_backup = deterministic_backup
|
| 161 |
+
|
| 162 |
+
return self
|
| 163 |
+
|
| 164 |
+
@override(AlgorithmConfig)
|
| 165 |
+
def offline_data(self, **kwargs) -> "CQLConfig":
|
| 166 |
+
|
| 167 |
+
super().offline_data(**kwargs)
|
| 168 |
+
|
| 169 |
+
# Check, if the passed in class incorporates the `OfflinePreLearner`
|
| 170 |
+
# interface.
|
| 171 |
+
if "prelearner_class" in kwargs:
|
| 172 |
+
from ray.rllib.offline.offline_data import OfflinePreLearner
|
| 173 |
+
|
| 174 |
+
if not issubclass(kwargs.get("prelearner_class"), OfflinePreLearner):
|
| 175 |
+
raise ValueError(
|
| 176 |
+
f"`prelearner_class` {kwargs.get('prelearner_class')} is not a "
|
| 177 |
+
"subclass of `OfflinePreLearner`. Any class passed to "
|
| 178 |
+
"`prelearner_class` needs to implement the interface given by "
|
| 179 |
+
"`OfflinePreLearner`."
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
return self
|
| 183 |
+
|
| 184 |
+
@override(SACConfig)
|
| 185 |
+
def get_default_learner_class(self) -> Union[Type["Learner"], str]:
|
| 186 |
+
if self.framework_str == "torch":
|
| 187 |
+
from ray.rllib.algorithms.cql.torch.cql_torch_learner import CQLTorchLearner
|
| 188 |
+
|
| 189 |
+
return CQLTorchLearner
|
| 190 |
+
else:
|
| 191 |
+
raise ValueError(
|
| 192 |
+
f"The framework {self.framework_str} is not supported. "
|
| 193 |
+
"Use `'torch'` instead."
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
@override(AlgorithmConfig)
|
| 197 |
+
def build_learner_connector(
|
| 198 |
+
self,
|
| 199 |
+
input_observation_space,
|
| 200 |
+
input_action_space,
|
| 201 |
+
device=None,
|
| 202 |
+
):
|
| 203 |
+
pipeline = super().build_learner_connector(
|
| 204 |
+
input_observation_space=input_observation_space,
|
| 205 |
+
input_action_space=input_action_space,
|
| 206 |
+
device=device,
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
# Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right
|
| 210 |
+
# after the corresponding "add-OBS-..." default piece).
|
| 211 |
+
pipeline.insert_after(
|
| 212 |
+
AddObservationsFromEpisodesToBatch,
|
| 213 |
+
AddNextObservationsFromEpisodesToTrainBatch(),
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
return pipeline
|
| 217 |
+
|
| 218 |
+
@override(SACConfig)
|
| 219 |
+
def validate(self) -> None:
|
| 220 |
+
# First check, whether old `timesteps_per_iteration` is used.
|
| 221 |
+
if self.timesteps_per_iteration != DEPRECATED_VALUE:
|
| 222 |
+
deprecation_warning(
|
| 223 |
+
old="timesteps_per_iteration",
|
| 224 |
+
new="min_train_timesteps_per_iteration",
|
| 225 |
+
error=True,
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
# Call super's validation method.
|
| 229 |
+
super().validate()
|
| 230 |
+
|
| 231 |
+
# CQL-torch performs the optimizer steps inside the loss function.
|
| 232 |
+
# Using the multi-GPU optimizer will therefore not work (see multi-GPU
|
| 233 |
+
# check above) and we must use the simple optimizer for now.
|
| 234 |
+
if self.simple_optimizer is not True and self.framework_str == "torch":
|
| 235 |
+
self.simple_optimizer = True
|
| 236 |
+
|
| 237 |
+
if self.framework_str in ["tf", "tf2"] and tfp is None:
|
| 238 |
+
logger.warning(
|
| 239 |
+
"You need `tensorflow_probability` in order to run CQL! "
|
| 240 |
+
"Install it via `pip install tensorflow_probability`. Your "
|
| 241 |
+
f"tf.__version__={tf.__version__ if tf else None}."
|
| 242 |
+
"Trying to import tfp results in the following error:"
|
| 243 |
+
)
|
| 244 |
+
try_import_tfp(error=True)
|
| 245 |
+
|
| 246 |
+
# Assert that for a local learner the number of iterations is 1. Note,
|
| 247 |
+
# this is needed because we have no iterators, but instead a single
|
| 248 |
+
# batch returned directly from the `OfflineData.sample` method.
|
| 249 |
+
if (
|
| 250 |
+
self.num_learners == 0
|
| 251 |
+
and not self.dataset_num_iters_per_learner
|
| 252 |
+
and self.enable_rl_module_and_learner
|
| 253 |
+
):
|
| 254 |
+
self._value_error(
|
| 255 |
+
"When using a single local learner the number of iterations "
|
| 256 |
+
"per learner, `dataset_num_iters_per_learner` has to be defined. "
|
| 257 |
+
"Set this hyperparameter in the `AlgorithmConfig.offline_data`."
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
@override(SACConfig)
|
| 261 |
+
def get_default_rl_module_spec(self) -> RLModuleSpecType:
|
| 262 |
+
if self.framework_str == "torch":
|
| 263 |
+
from ray.rllib.algorithms.cql.torch.default_cql_torch_rl_module import (
|
| 264 |
+
DefaultCQLTorchRLModule,
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
return RLModuleSpec(module_class=DefaultCQLTorchRLModule)
|
| 268 |
+
else:
|
| 269 |
+
raise ValueError(
|
| 270 |
+
f"The framework {self.framework_str} is not supported. " "Use `torch`."
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
@property
|
| 274 |
+
def _model_config_auto_includes(self):
|
| 275 |
+
return super()._model_config_auto_includes | {
|
| 276 |
+
"num_actions": self.num_actions,
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class CQL(SAC):
|
| 281 |
+
"""CQL (derived from SAC)."""
|
| 282 |
+
|
| 283 |
+
@classmethod
|
| 284 |
+
@override(SAC)
|
| 285 |
+
def get_default_config(cls) -> AlgorithmConfig:
|
| 286 |
+
return CQLConfig()
|
| 287 |
+
|
| 288 |
+
@classmethod
|
| 289 |
+
@override(SAC)
|
| 290 |
+
def get_default_policy_class(
|
| 291 |
+
cls, config: AlgorithmConfig
|
| 292 |
+
) -> Optional[Type[Policy]]:
|
| 293 |
+
if config["framework"] == "torch":
|
| 294 |
+
return CQLTorchPolicy
|
| 295 |
+
else:
|
| 296 |
+
return CQLTFPolicy
|
| 297 |
+
|
| 298 |
+
@override(SAC)
|
| 299 |
+
def training_step(self) -> None:
|
| 300 |
+
# Old API stack (Policy, RolloutWorker, Connector).
|
| 301 |
+
if not self.config.enable_env_runner_and_connector_v2:
|
| 302 |
+
return self._training_step_old_api_stack()
|
| 303 |
+
|
| 304 |
+
# Sampling from offline data.
|
| 305 |
+
with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
|
| 306 |
+
# Return an iterator in case we are using remote learners.
|
| 307 |
+
batch_or_iterator = self.offline_data.sample(
|
| 308 |
+
num_samples=self.config.train_batch_size_per_learner,
|
| 309 |
+
num_shards=self.config.num_learners,
|
| 310 |
+
return_iterator=self.config.num_learners > 1,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# Updating the policy.
|
| 314 |
+
with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
|
| 315 |
+
# TODO (simon, sven): Check, if we should execute directly s.th. like
|
| 316 |
+
# `LearnerGroup.update_from_iterator()`.
|
| 317 |
+
learner_results = self.learner_group._update(
|
| 318 |
+
batch=batch_or_iterator,
|
| 319 |
+
minibatch_size=self.config.train_batch_size_per_learner,
|
| 320 |
+
num_iters=self.config.dataset_num_iters_per_learner,
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
# Log training results.
|
| 324 |
+
self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS)
|
| 325 |
+
|
| 326 |
+
# Synchronize weights.
|
| 327 |
+
# As the results contain for each policy the loss and in addition the
|
| 328 |
+
# total loss over all policies is returned, this total loss has to be
|
| 329 |
+
# removed.
|
| 330 |
+
modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES}
|
| 331 |
+
|
| 332 |
+
if self.eval_env_runner_group:
|
| 333 |
+
# Update weights - after learning on the local worker -
|
| 334 |
+
# on all remote workers. Note, we only have the local `EnvRunner`,
|
| 335 |
+
# but from this `EnvRunner` the evaulation `EnvRunner`s get updated.
|
| 336 |
+
with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
|
| 337 |
+
self.eval_env_runner_group.sync_weights(
|
| 338 |
+
# Sync weights from learner_group to all EnvRunners.
|
| 339 |
+
from_worker_or_learner_group=self.learner_group,
|
| 340 |
+
policies=modules_to_update,
|
| 341 |
+
inference_only=True,
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
@OldAPIStack
|
| 345 |
+
def _training_step_old_api_stack(self) -> ResultDict:
|
| 346 |
+
# Collect SampleBatches from sample workers.
|
| 347 |
+
with self._timers[SAMPLE_TIMER]:
|
| 348 |
+
train_batch = synchronous_parallel_sample(worker_set=self.env_runner_group)
|
| 349 |
+
train_batch = train_batch.as_multi_agent()
|
| 350 |
+
self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
|
| 351 |
+
self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
|
| 352 |
+
|
| 353 |
+
# Postprocess batch before we learn on it.
|
| 354 |
+
post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
|
| 355 |
+
train_batch = post_fn(train_batch, self.env_runner_group, self.config)
|
| 356 |
+
|
| 357 |
+
# Learn on training batch.
|
| 358 |
+
# Use simple optimizer (only for multi-agent or tf-eager; all other
|
| 359 |
+
# cases should use the multi-GPU optimizer, even if only using 1 GPU)
|
| 360 |
+
if self.config.get("simple_optimizer") is True:
|
| 361 |
+
train_results = train_one_step(self, train_batch)
|
| 362 |
+
else:
|
| 363 |
+
train_results = multi_gpu_train_one_step(self, train_batch)
|
| 364 |
+
|
| 365 |
+
# Update target network every `target_network_update_freq` training steps.
|
| 366 |
+
cur_ts = self._counters[
|
| 367 |
+
NUM_AGENT_STEPS_TRAINED
|
| 368 |
+
if self.config.count_steps_by == "agent_steps"
|
| 369 |
+
else NUM_ENV_STEPS_TRAINED
|
| 370 |
+
]
|
| 371 |
+
last_update = self._counters[LAST_TARGET_UPDATE_TS]
|
| 372 |
+
if cur_ts - last_update >= self.config.target_network_update_freq:
|
| 373 |
+
with self._timers[TARGET_NET_UPDATE_TIMER]:
|
| 374 |
+
to_update = self.env_runner.get_policies_to_train()
|
| 375 |
+
self.env_runner.foreach_policy_to_train(
|
| 376 |
+
lambda p, pid: pid in to_update and p.update_target()
|
| 377 |
+
)
|
| 378 |
+
self._counters[NUM_TARGET_UPDATES] += 1
|
| 379 |
+
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
|
| 380 |
+
|
| 381 |
+
# Update remote workers's weights after learning on local worker
|
| 382 |
+
# (only those policies that were actually trained).
|
| 383 |
+
if self.env_runner_group.num_remote_workers() > 0:
|
| 384 |
+
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
|
| 385 |
+
self.env_runner_group.sync_weights(policies=list(train_results.keys()))
|
| 386 |
+
|
| 387 |
+
# Return all collected metrics for the iteration.
|
| 388 |
+
return train_results
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql_tf_policy.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TensorFlow policy class used for CQL.
|
| 3 |
+
"""
|
| 4 |
+
from functools import partial
|
| 5 |
+
import numpy as np
|
| 6 |
+
import gymnasium as gym
|
| 7 |
+
import logging
|
| 8 |
+
import tree
|
| 9 |
+
from typing import Dict, List, Type, Union
|
| 10 |
+
|
| 11 |
+
import ray
|
| 12 |
+
import ray.experimental.tf_utils
|
| 13 |
+
from ray.rllib.algorithms.sac.sac_tf_policy import (
|
| 14 |
+
apply_gradients as sac_apply_gradients,
|
| 15 |
+
compute_and_clip_gradients as sac_compute_and_clip_gradients,
|
| 16 |
+
get_distribution_inputs_and_class,
|
| 17 |
+
_get_dist_class,
|
| 18 |
+
build_sac_model,
|
| 19 |
+
postprocess_trajectory,
|
| 20 |
+
setup_late_mixins,
|
| 21 |
+
stats,
|
| 22 |
+
validate_spaces,
|
| 23 |
+
ActorCriticOptimizerMixin as SACActorCriticOptimizerMixin,
|
| 24 |
+
ComputeTDErrorMixin,
|
| 25 |
+
)
|
| 26 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 27 |
+
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
|
| 28 |
+
from ray.rllib.policy.tf_mixins import TargetNetworkMixin
|
| 29 |
+
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
| 30 |
+
from ray.rllib.policy.policy import Policy
|
| 31 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 32 |
+
from ray.rllib.utils.exploration.random import Random
|
| 33 |
+
from ray.rllib.utils.framework import get_variable, try_import_tf, try_import_tfp
|
| 34 |
+
from ray.rllib.utils.typing import (
|
| 35 |
+
LocalOptimizer,
|
| 36 |
+
ModelGradients,
|
| 37 |
+
TensorType,
|
| 38 |
+
AlgorithmConfigDict,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
tf1, tf, tfv = try_import_tf()
|
| 42 |
+
tfp = try_import_tfp()
|
| 43 |
+
|
| 44 |
+
logger = logging.getLogger(__name__)
|
| 45 |
+
|
| 46 |
+
MEAN_MIN = -9.0
|
| 47 |
+
MEAN_MAX = 9.0
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _repeat_tensor(t: TensorType, n: int):
    """Repeat each batch entry of `t` n consecutive times along the batch axis.

    Equivalent to inserting a fresh axis at position 1, tiling it n times and
    folding it back into the batch axis: [B, ...] -> [B * n, ...].
    """
    # A 1-D tensor of ones, one entry per trailing (non-batch) dim of `t`.
    trailing_ones = tf.tile([1], tf.expand_dims(tf.rank(t) - 1, 0))
    # Tile counts: keep batch, repeat the new axis n times, leave the rest.
    reps = tf.concat([[1, n], trailing_ones], 0)
    tiled = tf.tile(tf.expand_dims(t, 1), reps)
    # Collapse the repeat axis back into the batch axis.
    return tf.reshape(tiled, tf.concat([[-1], tf.shape(t)[1:]], 0))
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# Returns policy tiled actions and log probabilities for CQL Loss
|
| 62 |
+
# Returns policy tiled actions and log probabilities for CQL Loss
def policy_actions_repeat(model, action_dist, obs, num_repeat=1):
    """Sample `num_repeat` policy actions per observation.

    Returns the flattened sampled actions plus their log-probs reshaped to
    [batch, num_repeat, 1].
    """
    n_obs = tf.shape(tree.flatten(obs)[0])[0]
    tiled_obs = tree.map_structure(lambda x: _repeat_tensor(x, num_repeat), obs)
    dist_inputs, _ = model.get_action_model_outputs(tiled_obs)
    dist = action_dist(dist_inputs, model)
    sampled_actions, sampled_logp = dist.sample_logp()
    logp_col = tf.expand_dims(sampled_logp, -1)
    return sampled_actions, tf.reshape(logp_col, [n_obs, num_repeat, 1])
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def q_values_repeat(model, obs, actions, twin=False):
    """Compute (twin-)Q-values for repeated actions.

    `actions` carries `repeats` actions per observation; the observations are
    tiled to match and the predictions reshaped to [batch, repeats, 1].
    """
    n_actions = tf.shape(actions)[0]
    n_obs = tf.shape(tree.flatten(obs)[0])[0]
    repeats = n_actions // n_obs
    tiled_obs = tree.map_structure(lambda x: _repeat_tensor(x, repeats), obs)
    # Select the main or the twin Q-head.
    q_fn = model.get_twin_q_values if twin else model.get_q_values
    raw_preds, _ = q_fn(tiled_obs, actions)
    return tf.reshape(raw_preds, [n_obs, repeats, 1])
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def cql_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TFActionDistribution],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Construct the CQL loss: SAC losses plus the conservative-Q regularizer.

    Computes, in order: the entropy-coefficient (alpha) loss, the actor loss
    (behavior cloning for the first `bc_iters` iterations, SAC-style after),
    the (twin-)Q critic losses (standard TD loss plus the CQL "min-Q"
    logsumexp regularizer), and - if `lagrangian` is enabled - the
    alpha-prime loss that adapts the regularizer weight.

    Args:
        policy: The Policy being optimized (carries config, `target_model`
            and `_random_action_generator`).
        model: The SAC ModelV2 holding policy-, Q- and alpha-variables.
        dist_class: Action distribution class (unused here; the actual class
            is re-derived via `_get_dist_class`).
        train_batch: The (offline) SampleBatch to learn from.

    Returns:
        The total loss (sum of all individual loss terms).
    """
    logger.info(f"Current iteration = {policy.cur_iter}")
    # Counts loss invocations; used below to switch from BC to SAC actor loss.
    policy.cur_iter += 1

    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    assert not deterministic
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32)
    rewards = tf.cast(train_batch[SampleBatch.REWARDS], tf.float32)
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.TERMINATEDS]

    # Forward passes: current obs / next obs through the online model, next
    # obs through the target model (for the Bellman target).
    model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None)

    model_out_tp1, _ = model(SampleBatch(obs=next_obs, _is_training=True), [], None)

    target_model_out_tp1, _ = policy.target_model(
        SampleBatch(obs=next_obs, _is_training=True), [], None
    )

    action_dist_class = _get_dist_class(policy, policy.config, policy.action_space)
    action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
    action_dist_t = action_dist_class(action_dist_inputs_t, model)
    policy_t, log_pis_t = action_dist_t.sample_logp()
    log_pis_t = tf.expand_dims(log_pis_t, -1)

    # Unlike original SAC, Alpha and Actor Loss are computed first.
    # Alpha Loss
    alpha_loss = -tf.reduce_mean(
        model.log_alpha * tf.stop_gradient(log_pis_t + model.target_entropy)
    )

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = tf.math.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        # SAC-style actor loss against the (min over twins) Q-estimate.
        min_q, _ = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_, _ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = tf.math.minimum(min_q, twin_q_)
        actor_loss = tf.reduce_mean(tf.stop_gradient(alpha) * log_pis_t - min_q)
    else:
        # Warm-up: behavior-clone the dataset actions.
        bc_logp = action_dist_t.logp(actions)
        actor_loss = tf.reduce_mean(tf.stop_gradient(alpha) * log_pis_t - bc_logp)
        # actor_loss = -tf.reduce_mean(bc_logp)

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss:
    # Q-values for the batched actions.
    action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1)
    action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model)
    policy_tp1, _ = action_dist_tp1.sample_logp()

    q_t, _ = model.get_q_values(model_out_t, actions)
    q_t_selected = tf.squeeze(q_t, axis=-1)
    if twin_q:
        twin_q_t, _ = model.get_twin_q_values(model_out_t, actions)
        twin_q_t_selected = tf.squeeze(twin_q_t, axis=-1)

    # Target q network evaluation.
    q_tp1, _ = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1, _ = policy.target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1
        )
        # Take min over both twin-NNs.
        q_tp1 = tf.math.minimum(q_tp1, twin_q_tp1)

    q_tp1_best = tf.squeeze(input=q_tp1, axis=-1)
    # Zero out bootstrap values at terminal transitions.
    q_tp1_best_masked = (1.0 - tf.cast(terminals, tf.float32)) * q_tp1_best

    # compute RHS of bellman equation
    q_t_target = tf.stop_gradient(
        rewards + (discount ** policy.config["n_step"]) * q_tp1_best_masked
    )

    # Compute the TD-error (potentially clipped), for priority replay buffer
    base_td_error = tf.math.abs(q_t_selected - q_t_target)
    if twin_q:
        twin_td_error = tf.math.abs(twin_q_t_selected - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss_1 = tf.keras.losses.MSE(q_t_selected, q_t_target)
    if twin_q:
        critic_loss_2 = tf.keras.losses.MSE(twin_q_t_selected, q_t_target)

    # CQL Loss (We are using Entropy version of CQL (the best version))
    # Three action sources per observation: uniform-random, current-policy on
    # obs_t, and current-policy on obs_tp1.
    rand_actions, _ = policy._random_action_generator.get_exploration_action(
        action_distribution=action_dist_class(
            tf.tile(action_dist_tp1.inputs, (num_actions, 1)), model
        ),
        timestep=0,
        explore=True,
    )
    curr_actions, curr_logp = policy_actions_repeat(
        model, action_dist_class, model_out_t, num_actions
    )
    next_actions, next_logp = policy_actions_repeat(
        model, action_dist_class, model_out_tp1, num_actions
    )

    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(model, model_out_t, curr_actions, twin=True)
        q2_next_actions = q_values_repeat(model, model_out_t, next_actions, twin=True)

    # Importance weight for uniform samples: log of the uniform density
    # (0.5 per action dim, assuming actions normalized to [-1, 1]).
    random_density = np.log(0.5 ** int(curr_actions.shape[-1]))
    cat_q1 = tf.concat(
        [
            q1_rand - random_density,
            q1_next_actions - tf.stop_gradient(next_logp),
            q1_curr_actions - tf.stop_gradient(curr_logp),
        ],
        1,
    )
    if twin_q:
        cat_q2 = tf.concat(
            [
                q2_rand - random_density,
                q2_next_actions - tf.stop_gradient(next_logp),
                q2_curr_actions - tf.stop_gradient(curr_logp),
            ],
            1,
        )

    # Conservative regularizer: logsumexp over OOD actions minus mean dataset Q.
    min_qf1_loss_ = (
        tf.reduce_mean(tf.reduce_logsumexp(cat_q1 / cql_temp, axis=1))
        * min_q_weight
        * cql_temp
    )
    min_qf1_loss = min_qf1_loss_ - (tf.reduce_mean(q_t) * min_q_weight)
    if twin_q:
        min_qf2_loss_ = (
            tf.reduce_mean(tf.reduce_logsumexp(cat_q2 / cql_temp, axis=1))
            * min_q_weight
            * cql_temp
        )
        min_qf2_loss = min_qf2_loss_ - (tf.reduce_mean(twin_q_t) * min_q_weight)

    if use_lagrange:
        # NOTE(review): `.exp()` is a torch-style method; TF variables/tensors
        # have no `.exp()` - this Lagrangian code path looks like it would
        # fail at runtime under TF. Confirm and, if so, use
        # tf.math.exp(model.log_alpha_prime) instead.
        alpha_prime = tf.clip_by_value(model.log_alpha_prime.exp(), 0.0, 1000000.0)[0]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss = [critic_loss_1 + min_qf1_loss]
    if twin_q:
        critic_loss.append(critic_loss_2 + min_qf2_loss)

    # Save for stats function.
    policy.q_t = q_t_selected
    policy.policy_t = policy_t
    policy.log_pis_t = log_pis_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.log_alpha_value = model.log_alpha
    policy.alpha_value = alpha
    policy.target_entropy = model.target_entropy
    # CQL Stats
    policy.cql_loss = cql_loss
    if use_lagrange:
        policy.log_alpha_prime_value = model.log_alpha_prime[0]
        policy.alpha_prime_value = alpha_prime
        policy.alpha_prime_loss = alpha_prime_loss

    # Return all loss terms corresponding to our optimizers.
    if use_lagrange:
        return actor_loss + tf.math.add_n(critic_loss) + alpha_loss + alpha_prime_loss
    return actor_loss + tf.math.add_n(critic_loss) + alpha_loss
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def cql_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Return SAC's stats dict, extended with the CQL-specific metrics."""
    stats_dict = stats(policy, train_batch)
    stats_dict["cql_loss"] = tf.reduce_mean(tf.stack(policy.cql_loss))
    if policy.config["lagrangian"]:
        # Lagrangian-mode extras computed in `cql_loss`.
        stats_dict.update(
            log_alpha_prime_value=policy.log_alpha_prime_value,
            alpha_prime_value=policy.alpha_prime_value,
            alpha_prime_loss=policy.alpha_prime_loss,
        )
    return stats_dict
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class ActorCriticOptimizerMixin(SACActorCriticOptimizerMixin):
    """SAC's optimizer mixin, extended with an alpha-prime optimizer.

    The extra Adam optimizer is only created when the Lagrangian variant of
    CQL is enabled; it reuses the critic's learning rate.
    """

    def __init__(self, config):
        super().__init__(config)
        if not config["lagrangian"]:
            return
        lr = config["optimization"]["critic_learning_rate"]
        if config["framework"] == "tf2":
            # Eager mode.
            self._alpha_prime_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
        else:
            # Static graph mode.
            self._alpha_prime_optimizer = tf1.train.AdamOptimizer(learning_rate=lr)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def setup_early_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
) -> None:
    """Call mixin classes' constructors before Policy's initialization.

    Adds the necessary optimizers to the given Policy, creates the
    `log_alpha_prime` variable for the Lagrangian variant, and installs the
    random-action generator used by the CQL loss.

    Args:
        policy: The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config: The Policy's config.
    """
    # Iteration counter consumed by `cql_loss` (BC warm-up switch).
    policy.cur_iter = 0
    ActorCriticOptimizerMixin.__init__(policy, config)
    if config["lagrangian"]:
        policy.model.log_alpha_prime = get_variable(
            0.0, framework="tf", trainable=True, tf_name="log_alpha_prime"
        )
        # NOTE(review): the mixin above already created
        # `policy._alpha_prime_optimizer` (the one used by the gradient fns);
        # this second keras Adam appears unused by the TF code path - confirm
        # whether it is needed.
        policy.alpha_prime_optim = tf.keras.optimizers.Adam(
            learning_rate=config["optimization"]["critic_learning_rate"],
        )
    # Generic random action generator for calculating CQL-loss.
    policy._random_action_generator = Random(
        action_space,
        model=None,
        framework="tf2",
        policy_config=config,
        num_workers=0,
        worker_index=0,
    )
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def compute_gradients_fn(
    policy: Policy, optimizer: LocalOptimizer, loss: TensorType
) -> ModelGradients:
    """Compute (and clip) gradients for all CQL loss terms.

    Delegates actor/critic/alpha gradients to SAC's implementation and, when
    the Lagrangian variant is enabled, additionally computes the gradient of
    `alpha_prime_loss` w.r.t. `log_alpha_prime`.

    Args:
        policy: The Policy holding the loss terms (set in `cql_loss`).
        optimizer: The local optimizer (an OptimizerWrapper in eager mode).
        loss: The total loss tensor returned by `cql_loss`.

    Returns:
        A list of (gradient, variable) tuples.
    """
    grads_and_vars = sac_compute_and_clip_gradients(policy, optimizer, loss)

    if policy.config["lagrangian"]:
        # Eager: Use GradientTape (which is a property of the `optimizer`
        # object (an OptimizerWrapper): see rllib/policy/eager_tf_policy.py).
        if policy.config["framework"] == "tf2":
            tape = optimizer.tape
            log_alpha_prime = [policy.model.log_alpha_prime]
            alpha_prime_grads_and_vars = list(
                zip(
                    tape.gradient(policy.alpha_prime_loss, log_alpha_prime),
                    log_alpha_prime,
                )
            )
        # Tf1.x: Use optimizer.compute_gradients()
        else:
            alpha_prime_grads_and_vars = (
                policy._alpha_prime_optimizer.compute_gradients(
                    policy.alpha_prime_loss, var_list=[policy.model.log_alpha_prime]
                )
            )

        # Clip if necessary.
        if policy.config["grad_clip"]:
            clip_func = partial(tf.clip_by_norm, clip_norm=policy.config["grad_clip"])
        else:
            clip_func = tf.identity

        # Save grads and vars for later use in `build_apply_op`.
        policy._alpha_prime_grads_and_vars = [
            (clip_func(g), v) for (g, v) in alpha_prime_grads_and_vars if g is not None
        ]

        grads_and_vars += policy._alpha_prime_grads_and_vars
    return grads_and_vars
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def apply_gradients_fn(policy, optimizer, grads_and_vars):
    """Apply the previously computed gradients.

    Delegates to SAC's apply function, then additionally applies the stored
    alpha-prime gradients when the Lagrangian variant is enabled.

    Returns:
        None in eager mode; a grouped apply-op in static-graph mode.
    """
    sac_results = sac_apply_gradients(policy, optimizer, grads_and_vars)

    if policy.config["lagrangian"]:
        # Eager mode -> Just apply and return None.
        if policy.config["framework"] == "tf2":
            policy._alpha_prime_optimizer.apply_gradients(
                policy._alpha_prime_grads_and_vars
            )
            return
        # Tf static graph -> Return grouped op.
        else:
            alpha_prime_apply_op = policy._alpha_prime_optimizer.apply_gradients(
                policy._alpha_prime_grads_and_vars,
                global_step=tf1.train.get_or_create_global_step(),
            )
            return tf.group([sac_results, alpha_prime_apply_op])
    return sac_results
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
# Build a child class of `TFPolicy`, given the custom functions defined
|
| 410 |
+
# above.
|
| 411 |
+
# Build a child class of `TFPolicy`, given the custom functions defined
# above (loss, stats, optimizer mixins and gradient handling).
CQLTFPolicy = build_tf_policy(
    name="CQLTFPolicy",
    loss_fn=cql_loss,
    get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQLConfig(),
    validate_spaces=validate_spaces,
    stats_fn=cql_stats,
    postprocess_fn=postprocess_trajectory,
    before_init=setup_early_mixins,
    after_init=setup_late_mixins,
    make_model=build_sac_model,
    # Expose per-item TD-errors, e.g. for prioritized replay.
    extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
    mixins=[ActorCriticOptimizerMixin, TargetNetworkMixin, ComputeTDErrorMixin],
    action_distribution_fn=get_distribution_inputs_and_class,
    compute_gradients_fn=compute_gradients_fn,
    apply_gradients_fn=apply_gradients_fn,
)
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/cql_torch_policy.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PyTorch policy class used for CQL.
|
| 3 |
+
"""
|
| 4 |
+
import numpy as np
|
| 5 |
+
import gymnasium as gym
|
| 6 |
+
import logging
|
| 7 |
+
import tree
|
| 8 |
+
from typing import Dict, List, Tuple, Type, Union
|
| 9 |
+
|
| 10 |
+
import ray
|
| 11 |
+
import ray.experimental.tf_utils
|
| 12 |
+
from ray.rllib.algorithms.sac.sac_tf_policy import (
|
| 13 |
+
postprocess_trajectory,
|
| 14 |
+
validate_spaces,
|
| 15 |
+
)
|
| 16 |
+
from ray.rllib.algorithms.sac.sac_torch_policy import (
|
| 17 |
+
_get_dist_class,
|
| 18 |
+
stats,
|
| 19 |
+
build_sac_model_and_action_dist,
|
| 20 |
+
optimizer_fn,
|
| 21 |
+
ComputeTDErrorMixin,
|
| 22 |
+
setup_late_mixins,
|
| 23 |
+
action_distribution_fn,
|
| 24 |
+
)
|
| 25 |
+
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
|
| 26 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 27 |
+
from ray.rllib.policy.policy_template import build_policy_class
|
| 28 |
+
from ray.rllib.policy.policy import Policy
|
| 29 |
+
from ray.rllib.policy.torch_mixins import TargetNetworkMixin
|
| 30 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 31 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 32 |
+
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
| 33 |
+
from ray.rllib.utils.typing import LocalOptimizer, TensorType, AlgorithmConfigDict
|
| 34 |
+
from ray.rllib.utils.torch_utils import (
|
| 35 |
+
apply_grad_clipping,
|
| 36 |
+
convert_to_torch_tensor,
|
| 37 |
+
concat_multi_gpu_td_errors,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
torch, nn = try_import_torch()
|
| 41 |
+
F = nn.functional
|
| 42 |
+
|
| 43 |
+
logger = logging.getLogger(__name__)
|
| 44 |
+
|
| 45 |
+
MEAN_MIN = -9.0
|
| 46 |
+
MEAN_MAX = 9.0
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _repeat_tensor(t: TensorType, n: int):
|
| 50 |
+
# Insert new dimension at posotion 1 into tensor t
|
| 51 |
+
t_rep = t.unsqueeze(1)
|
| 52 |
+
# Repeat tensor t_rep along new dimension n times
|
| 53 |
+
t_rep = torch.repeat_interleave(t_rep, n, dim=1)
|
| 54 |
+
# Merge new dimension into batch dimension
|
| 55 |
+
t_rep = t_rep.view(-1, *t.shape[1:])
|
| 56 |
+
return t_rep
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Returns policy tiled actions and log probabilities for CQL Loss
|
| 60 |
+
# Returns policy tiled actions and log probabilities for CQL Loss
def policy_actions_repeat(model, action_dist, obs, num_repeat=1):
    """Sample `num_repeat` policy actions per observation.

    Returns the flattened sampled actions plus their log-probs reshaped to
    [batch, num_repeat, 1].
    """
    n_obs = tree.flatten(obs)[0].shape[0]
    tiled_obs = tree.map_structure(lambda x: _repeat_tensor(x, num_repeat), obs)
    dist_inputs, _ = model.get_action_model_outputs(tiled_obs)
    dist = action_dist(dist_inputs, model)
    sampled_actions, sampled_logp = dist.sample_logp()
    logp_col = sampled_logp.unsqueeze(-1)
    return sampled_actions, logp_col.view(n_obs, num_repeat, 1)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def q_values_repeat(model, obs, actions, twin=False):
    """Compute (twin-)Q-values for repeated actions.

    `actions` carries `repeats` actions per observation; observations are
    tiled to match and the predictions reshaped to [batch, repeats, 1].
    """
    n_actions = actions.shape[0]
    n_obs = tree.flatten(obs)[0].shape[0]
    repeats = int(n_actions / n_obs)
    tiled_obs = tree.map_structure(lambda x: _repeat_tensor(x, repeats), obs)
    # Select the main or the twin Q-head.
    q_fn = model.get_twin_q_values if twin else model.get_q_values
    raw_preds, _ = q_fn(tiled_obs, actions)
    return raw_preds.view(n_obs, repeats, 1)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def cql_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Construct the CQL loss: SAC losses plus the conservative-Q regularizer.

    Unlike a pure loss function, this implementation also PERFORMS the
    optimizer steps itself (alpha, actor, critics, alpha-prime) whenever the
    batch is full-sized, since the individual losses must be applied in a
    specific order with retained graphs.

    Args:
        policy: The Policy being optimized (carries config, target models and
            the per-loss optimizers).
        model: The SAC ModelV2 (tower) holding policy-, Q- and
            alpha-variables.
        dist_class: Action distribution class (unused here; the actual class
            is re-derived via `_get_dist_class`).
        train_batch: The (offline) SampleBatch to learn from.

    Returns:
        Tuple of all individual loss terms, matching the optimizer order
        returned by `cql_optimizer_fn`.
    """
    logger.info(f"Current iteration = {policy.cur_iter}")
    # Counts loss invocations; used below to switch from BC to SAC actor loss.
    policy.cur_iter += 1

    # Look up the target model (tower) using the model tower.
    target_model = policy.target_models[model]

    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    assert not deterministic
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]
    # Bounds for uniform random-action sampling (first action dim's bounds
    # are applied to all dims).
    action_low = model.action_space.low[0]
    action_high = model.action_space.high[0]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = train_batch[SampleBatch.ACTIONS]
    rewards = train_batch[SampleBatch.REWARDS].float()
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.TERMINATEDS]

    model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None)

    model_out_tp1, _ = model(SampleBatch(obs=next_obs, _is_training=True), [], None)

    target_model_out_tp1, _ = target_model(
        SampleBatch(obs=next_obs, _is_training=True), [], None
    )

    action_dist_class = _get_dist_class(policy, policy.config, policy.action_space)
    action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
    action_dist_t = action_dist_class(action_dist_inputs_t, model)
    policy_t, log_pis_t = action_dist_t.sample_logp()
    log_pis_t = torch.unsqueeze(log_pis_t, -1)

    # Unlike original SAC, Alpha and Actor Loss are computed first.
    # Alpha Loss
    alpha_loss = -(model.log_alpha * (log_pis_t + model.target_entropy).detach()).mean()

    # Only step the optimizers on full-sized batches (skips e.g. the dummy
    # initialization batch).
    batch_size = tree.flatten(obs)[0].shape[0]
    if batch_size == policy.config["train_batch_size"]:
        policy.alpha_optim.zero_grad()
        alpha_loss.backward()
        policy.alpha_optim.step()

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = torch.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        # SAC-style actor loss against the (min over twins) Q-estimate.
        min_q, _ = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_, _ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = torch.min(min_q, twin_q_)
        actor_loss = (alpha.detach() * log_pis_t - min_q).mean()
    else:
        # Warm-up: behavior-clone the dataset actions.
        bc_logp = action_dist_t.logp(actions)
        actor_loss = (alpha.detach() * log_pis_t - bc_logp).mean()
        # actor_loss = -bc_logp.mean()

    if batch_size == policy.config["train_batch_size"]:
        policy.actor_optim.zero_grad()
        # retain_graph: the shared forward pass is re-used by critic losses.
        actor_loss.backward(retain_graph=True)
        policy.actor_optim.step()

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss:
    # Q-values for the batched actions.
    action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1)
    action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model)
    policy_tp1, _ = action_dist_tp1.sample_logp()

    q_t, _ = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
    q_t_selected = torch.squeeze(q_t, dim=-1)
    if twin_q:
        twin_q_t, _ = model.get_twin_q_values(
            model_out_t, train_batch[SampleBatch.ACTIONS]
        )
        twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)

    # Target q network evaluation.
    q_tp1, _ = target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1, _ = target_model.get_twin_q_values(target_model_out_tp1, policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = torch.min(q_tp1, twin_q_tp1)

    q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
    # Zero out bootstrap values at terminal transitions.
    q_tp1_best_masked = (1.0 - terminals.float()) * q_tp1_best

    # compute RHS of bellman equation
    q_t_target = (
        rewards + (discount ** policy.config["n_step"]) * q_tp1_best_masked
    ).detach()

    # Compute the TD-error (potentially clipped), for priority replay buffer
    base_td_error = torch.abs(q_t_selected - q_t_target)
    if twin_q:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss_1 = nn.functional.mse_loss(q_t_selected, q_t_target)
    if twin_q:
        critic_loss_2 = nn.functional.mse_loss(twin_q_t_selected, q_t_target)

    # CQL Loss (We are using Entropy version of CQL (the best version))
    # Three action sources per observation: uniform-random, current-policy on
    # obs_t, and current-policy on obs_tp1.
    rand_actions = convert_to_torch_tensor(
        torch.FloatTensor(actions.shape[0] * num_actions, actions.shape[-1]).uniform_(
            action_low, action_high
        ),
        policy.device,
    )
    curr_actions, curr_logp = policy_actions_repeat(
        model, action_dist_class, model_out_t, num_actions
    )
    next_actions, next_logp = policy_actions_repeat(
        model, action_dist_class, model_out_tp1, num_actions
    )

    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(model, model_out_t, curr_actions, twin=True)
        q2_next_actions = q_values_repeat(model, model_out_t, next_actions, twin=True)

    # Importance weight for uniform samples: log of the uniform density
    # (0.5 per action dim, assuming actions normalized to [-1, 1]).
    random_density = np.log(0.5 ** curr_actions.shape[-1])
    cat_q1 = torch.cat(
        [
            q1_rand - random_density,
            q1_next_actions - next_logp.detach(),
            q1_curr_actions - curr_logp.detach(),
        ],
        1,
    )
    if twin_q:
        cat_q2 = torch.cat(
            [
                q2_rand - random_density,
                q2_next_actions - next_logp.detach(),
                q2_curr_actions - curr_logp.detach(),
            ],
            1,
        )

    # Conservative regularizer: logsumexp over OOD actions minus mean dataset Q.
    min_qf1_loss_ = (
        torch.logsumexp(cat_q1 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
    )
    min_qf1_loss = min_qf1_loss_ - (q_t.mean() * min_q_weight)
    if twin_q:
        min_qf2_loss_ = (
            torch.logsumexp(cat_q2 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
        )
        min_qf2_loss = min_qf2_loss_ - (twin_q_t.mean() * min_q_weight)

    if use_lagrange:
        # Adaptive regularizer weight (clamped to a sane range).
        alpha_prime = torch.clamp(model.log_alpha_prime.exp(), min=0.0, max=1000000.0)[
            0
        ]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss = [critic_loss_1 + min_qf1_loss]
    if twin_q:
        critic_loss.append(critic_loss_2 + min_qf2_loss)

    if batch_size == policy.config["train_batch_size"]:
        policy.critic_optims[0].zero_grad()
        critic_loss[0].backward(retain_graph=True)
        policy.critic_optims[0].step()

        if twin_q:
            policy.critic_optims[1].zero_grad()
            critic_loss[1].backward(retain_graph=False)
            policy.critic_optims[1].step()

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    # SAC stats.
    model.tower_stats["q_t"] = q_t_selected
    model.tower_stats["policy_t"] = policy_t
    model.tower_stats["log_pis_t"] = log_pis_t
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss
    model.tower_stats["log_alpha_value"] = model.log_alpha
    model.tower_stats["alpha_value"] = alpha
    model.tower_stats["target_entropy"] = model.target_entropy
    # CQL stats.
    model.tower_stats["cql_loss"] = cql_loss

    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    if use_lagrange:
        model.tower_stats["log_alpha_prime_value"] = model.log_alpha_prime[0]
        model.tower_stats["alpha_prime_value"] = alpha_prime
        model.tower_stats["alpha_prime_loss"] = alpha_prime_loss

        if batch_size == policy.config["train_batch_size"]:
            policy.alpha_prime_optim.zero_grad()
            alpha_prime_loss.backward()
            policy.alpha_prime_optim.step()

    # Return all loss terms corresponding to our optimizers.
    return tuple(
        [actor_loss]
        + critic_loss
        + [alpha_loss]
        + ([alpha_prime_loss] if use_lagrange else [])
    )
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def cql_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
|
| 321 |
+
# Get SAC loss stats.
|
| 322 |
+
stats_dict = stats(policy, train_batch)
|
| 323 |
+
|
| 324 |
+
# Add CQL loss stats to the dict.
|
| 325 |
+
stats_dict["cql_loss"] = torch.mean(
|
| 326 |
+
torch.stack(*policy.get_tower_stats("cql_loss"))
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
if policy.config["lagrangian"]:
|
| 330 |
+
stats_dict["log_alpha_prime_value"] = torch.mean(
|
| 331 |
+
torch.stack(policy.get_tower_stats("log_alpha_prime_value"))
|
| 332 |
+
)
|
| 333 |
+
stats_dict["alpha_prime_value"] = torch.mean(
|
| 334 |
+
torch.stack(policy.get_tower_stats("alpha_prime_value"))
|
| 335 |
+
)
|
| 336 |
+
stats_dict["alpha_prime_loss"] = torch.mean(
|
| 337 |
+
torch.stack(policy.get_tower_stats("alpha_prime_loss"))
|
| 338 |
+
)
|
| 339 |
+
return stats_dict
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def cql_optimizer_fn(
|
| 343 |
+
policy: Policy, config: AlgorithmConfigDict
|
| 344 |
+
) -> Tuple[LocalOptimizer]:
|
| 345 |
+
policy.cur_iter = 0
|
| 346 |
+
opt_list = optimizer_fn(policy, config)
|
| 347 |
+
if config["lagrangian"]:
|
| 348 |
+
log_alpha_prime = nn.Parameter(torch.zeros(1, requires_grad=True).float())
|
| 349 |
+
policy.model.register_parameter("log_alpha_prime", log_alpha_prime)
|
| 350 |
+
policy.alpha_prime_optim = torch.optim.Adam(
|
| 351 |
+
params=[policy.model.log_alpha_prime],
|
| 352 |
+
lr=config["optimization"]["critic_learning_rate"],
|
| 353 |
+
eps=1e-7, # to match tf.keras.optimizers.Adam's epsilon default
|
| 354 |
+
)
|
| 355 |
+
return tuple(
|
| 356 |
+
[policy.actor_optim]
|
| 357 |
+
+ policy.critic_optims
|
| 358 |
+
+ [policy.alpha_optim]
|
| 359 |
+
+ [policy.alpha_prime_optim]
|
| 360 |
+
)
|
| 361 |
+
return opt_list
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def cql_setup_late_mixins(
|
| 365 |
+
policy: Policy,
|
| 366 |
+
obs_space: gym.spaces.Space,
|
| 367 |
+
action_space: gym.spaces.Space,
|
| 368 |
+
config: AlgorithmConfigDict,
|
| 369 |
+
) -> None:
|
| 370 |
+
setup_late_mixins(policy, obs_space, action_space, config)
|
| 371 |
+
if config["lagrangian"]:
|
| 372 |
+
policy.model.log_alpha_prime = policy.model.log_alpha_prime.to(policy.device)
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def compute_gradients_fn(policy, postprocessed_batch):
|
| 376 |
+
batches = [policy._lazy_tensor_dict(postprocessed_batch)]
|
| 377 |
+
model = policy.model
|
| 378 |
+
policy._loss(policy, model, policy.dist_class, batches[0])
|
| 379 |
+
stats = {LEARNER_STATS_KEY: policy._convert_to_numpy(cql_stats(policy, batches[0]))}
|
| 380 |
+
return [None, stats]
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def apply_gradients_fn(policy, gradients):
|
| 384 |
+
return
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
# Build a child class of `TorchPolicy`, given the custom functions defined
|
| 388 |
+
# above.
|
| 389 |
+
CQLTorchPolicy = build_policy_class(
|
| 390 |
+
name="CQLTorchPolicy",
|
| 391 |
+
framework="torch",
|
| 392 |
+
loss_fn=cql_loss,
|
| 393 |
+
get_default_config=lambda: ray.rllib.algorithms.cql.cql.CQLConfig(),
|
| 394 |
+
stats_fn=cql_stats,
|
| 395 |
+
postprocess_fn=postprocess_trajectory,
|
| 396 |
+
extra_grad_process_fn=apply_grad_clipping,
|
| 397 |
+
optimizer_fn=cql_optimizer_fn,
|
| 398 |
+
validate_spaces=validate_spaces,
|
| 399 |
+
before_loss_init=cql_setup_late_mixins,
|
| 400 |
+
make_model_and_action_dist=build_sac_model_and_action_dist,
|
| 401 |
+
extra_learn_fetches_fn=concat_multi_gpu_td_errors,
|
| 402 |
+
mixins=[TargetNetworkMixin, ComputeTDErrorMixin],
|
| 403 |
+
action_distribution_fn=action_distribution_fn,
|
| 404 |
+
compute_gradients_fn=compute_gradients_fn,
|
| 405 |
+
apply_gradients_fn=apply_gradients_fn,
|
| 406 |
+
)
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (203 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/cql_torch_learner.cpython-311.pyc
ADDED
|
Binary file (9.88 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/__pycache__/default_cql_torch_rl_module.cpython-311.pyc
ADDED
|
Binary file (8.66 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/cql_torch_learner.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
|
| 3 |
+
from ray.air.constants import TRAINING_ITERATION
|
| 4 |
+
from ray.rllib.algorithms.sac.sac_learner import (
|
| 5 |
+
LOGPS_KEY,
|
| 6 |
+
QF_LOSS_KEY,
|
| 7 |
+
QF_MEAN_KEY,
|
| 8 |
+
QF_MAX_KEY,
|
| 9 |
+
QF_MIN_KEY,
|
| 10 |
+
QF_PREDS,
|
| 11 |
+
QF_TWIN_LOSS_KEY,
|
| 12 |
+
QF_TWIN_PREDS,
|
| 13 |
+
TD_ERROR_MEAN_KEY,
|
| 14 |
+
)
|
| 15 |
+
from ray.rllib.algorithms.cql.cql import CQLConfig
|
| 16 |
+
from ray.rllib.algorithms.sac.torch.sac_torch_learner import SACTorchLearner
|
| 17 |
+
from ray.rllib.core.columns import Columns
|
| 18 |
+
from ray.rllib.core.learner.learner import (
|
| 19 |
+
POLICY_LOSS_KEY,
|
| 20 |
+
)
|
| 21 |
+
from ray.rllib.utils.annotations import override
|
| 22 |
+
from ray.rllib.utils.metrics import ALL_MODULES
|
| 23 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 24 |
+
from ray.rllib.utils.typing import ModuleID, ParamDict, TensorType
|
| 25 |
+
|
| 26 |
+
torch, nn = try_import_torch()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class CQLTorchLearner(SACTorchLearner):
|
| 30 |
+
@override(SACTorchLearner)
|
| 31 |
+
def compute_loss_for_module(
|
| 32 |
+
self,
|
| 33 |
+
*,
|
| 34 |
+
module_id: ModuleID,
|
| 35 |
+
config: CQLConfig,
|
| 36 |
+
batch: Dict,
|
| 37 |
+
fwd_out: Dict[str, TensorType],
|
| 38 |
+
) -> TensorType:
|
| 39 |
+
|
| 40 |
+
# TODO (simon, sven): Add upstream information pieces into this timesteps
|
| 41 |
+
# call arg to Learner.update_...().
|
| 42 |
+
self.metrics.log_value(
|
| 43 |
+
(ALL_MODULES, TRAINING_ITERATION),
|
| 44 |
+
1,
|
| 45 |
+
reduce="sum",
|
| 46 |
+
)
|
| 47 |
+
# Get the train action distribution for the current policy and current state.
|
| 48 |
+
# This is needed for the policy (actor) loss and the `alpha`` loss.
|
| 49 |
+
action_dist_class = self.module[module_id].get_train_action_dist_cls()
|
| 50 |
+
action_dist_curr = action_dist_class.from_logits(
|
| 51 |
+
fwd_out[Columns.ACTION_DIST_INPUTS]
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
# Optimize also the hyperparameter `alpha` by using the current policy
|
| 55 |
+
# evaluated at the current state (from offline data). Note, in contrast
|
| 56 |
+
# to the original SAC loss, here the `alpha` and actor losses are
|
| 57 |
+
# calculated first.
|
| 58 |
+
# TODO (simon): Check, why log(alpha) is used, prob. just better
|
| 59 |
+
# to optimize and monotonic function. Original equation uses alpha.
|
| 60 |
+
alpha_loss = -torch.mean(
|
| 61 |
+
self.curr_log_alpha[module_id]
|
| 62 |
+
* (fwd_out["logp_resampled"].detach() + self.target_entropy[module_id])
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# Get the current alpha.
|
| 66 |
+
alpha = torch.exp(self.curr_log_alpha[module_id])
|
| 67 |
+
# Start training with behavior cloning and turn to the classic Soft-Actor Critic
|
| 68 |
+
# after `bc_iters` of training iterations.
|
| 69 |
+
if (
|
| 70 |
+
self.metrics.peek((ALL_MODULES, TRAINING_ITERATION), default=0)
|
| 71 |
+
>= config.bc_iters
|
| 72 |
+
):
|
| 73 |
+
actor_loss = torch.mean(
|
| 74 |
+
alpha.detach() * fwd_out["logp_resampled"] - fwd_out["q_curr"]
|
| 75 |
+
)
|
| 76 |
+
else:
|
| 77 |
+
# Use log-probabilities of the current action distribution to clone
|
| 78 |
+
# the behavior policy (selected actions in data) in the first `bc_iters`
|
| 79 |
+
# training iterations.
|
| 80 |
+
bc_logps_curr = action_dist_curr.logp(batch[Columns.ACTIONS])
|
| 81 |
+
actor_loss = torch.mean(
|
| 82 |
+
alpha.detach() * fwd_out["logp_resampled"] - bc_logps_curr
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# The critic loss is composed of the standard SAC Critic L2 loss and the
|
| 86 |
+
# CQL entropy loss.
|
| 87 |
+
|
| 88 |
+
# Get the Q-values for the actually selected actions in the offline data.
|
| 89 |
+
# In the critic loss we use these as predictions.
|
| 90 |
+
q_selected = fwd_out[QF_PREDS]
|
| 91 |
+
if config.twin_q:
|
| 92 |
+
q_twin_selected = fwd_out[QF_TWIN_PREDS]
|
| 93 |
+
|
| 94 |
+
if not config.deterministic_backup:
|
| 95 |
+
q_next = (
|
| 96 |
+
fwd_out["q_target_next"]
|
| 97 |
+
- alpha.detach() * fwd_out["logp_next_resampled"]
|
| 98 |
+
)
|
| 99 |
+
else:
|
| 100 |
+
q_next = fwd_out["q_target_next"]
|
| 101 |
+
|
| 102 |
+
# Now mask all Q-values with terminating next states in the targets.
|
| 103 |
+
q_next_masked = (1.0 - batch[Columns.TERMINATEDS].float()) * q_next
|
| 104 |
+
|
| 105 |
+
# Compute the right hand side of the Bellman equation. Detach this node
|
| 106 |
+
# from the computation graph as we do not want to backpropagate through
|
| 107 |
+
# the target network when optimizing the Q loss.
|
| 108 |
+
# TODO (simon, sven): Kumar et al. (2020) use here also a reward scaler.
|
| 109 |
+
q_selected_target = (
|
| 110 |
+
# TODO (simon): Add an `n_step` option to the `AddNextObsToBatch` connector.
|
| 111 |
+
batch[Columns.REWARDS]
|
| 112 |
+
# TODO (simon): Implement n_step.
|
| 113 |
+
+ (config.gamma) * q_next_masked
|
| 114 |
+
).detach()
|
| 115 |
+
|
| 116 |
+
# Calculate the TD error.
|
| 117 |
+
td_error = torch.abs(q_selected - q_selected_target)
|
| 118 |
+
# Calculate a TD-error for twin-Q values, if needed.
|
| 119 |
+
if config.twin_q:
|
| 120 |
+
td_error += torch.abs(q_twin_selected - q_selected_target)
|
| 121 |
+
# Rescale the TD error
|
| 122 |
+
td_error *= 0.5
|
| 123 |
+
|
| 124 |
+
# MSBE loss for the critic(s) (i.e. Q, see eqs. (7-8) Haarnoja et al. (2018)).
|
| 125 |
+
# Note, this needs a sample from the current policy given the next state.
|
| 126 |
+
# Note further, we could also use here the Huber loss instead of the MSE.
|
| 127 |
+
# TODO (simon): Add the huber loss as an alternative (SAC uses it).
|
| 128 |
+
sac_critic_loss = torch.nn.MSELoss(reduction="mean")(
|
| 129 |
+
q_selected,
|
| 130 |
+
q_selected_target,
|
| 131 |
+
)
|
| 132 |
+
if config.twin_q:
|
| 133 |
+
sac_critic_twin_loss = torch.nn.MSELoss(reduction="mean")(
|
| 134 |
+
q_twin_selected,
|
| 135 |
+
q_selected_target,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# Now calculate the CQL loss (we use the entropy version of the CQL algorithm).
|
| 139 |
+
# Note, the entropy version performs best in shown experiments.
|
| 140 |
+
|
| 141 |
+
# Compute the log-probabilities for the random actions (note, we generate random
|
| 142 |
+
# actions (from the mu distribution as named in Kumar et al. (2020))).
|
| 143 |
+
# Note, all actions, action log-probabilities and Q-values are already computed
|
| 144 |
+
# by the module's `_forward_train` method.
|
| 145 |
+
# TODO (simon): This is the density for a discrete uniform, however, actions
|
| 146 |
+
# come from a continuous one. So actually this density should use (1/(high-low))
|
| 147 |
+
# instead of (1/2).
|
| 148 |
+
random_density = torch.log(
|
| 149 |
+
torch.pow(
|
| 150 |
+
0.5,
|
| 151 |
+
torch.tensor(
|
| 152 |
+
fwd_out["actions_curr_repeat"].shape[-1],
|
| 153 |
+
device=fwd_out["actions_curr_repeat"].device,
|
| 154 |
+
),
|
| 155 |
+
)
|
| 156 |
+
)
|
| 157 |
+
# Merge all Q-values and subtract the log-probabilities (note, we use the
|
| 158 |
+
# entropy version of CQL).
|
| 159 |
+
q_repeat = torch.cat(
|
| 160 |
+
[
|
| 161 |
+
fwd_out["q_rand_repeat"] - random_density,
|
| 162 |
+
fwd_out["q_next_repeat"] - fwd_out["logps_next_repeat"].detach(),
|
| 163 |
+
fwd_out["q_curr_repeat"] - fwd_out["logps_curr_repeat"].detach(),
|
| 164 |
+
],
|
| 165 |
+
dim=1,
|
| 166 |
+
)
|
| 167 |
+
cql_loss = (
|
| 168 |
+
torch.logsumexp(q_repeat / config.temperature, dim=1).mean()
|
| 169 |
+
* config.min_q_weight
|
| 170 |
+
* config.temperature
|
| 171 |
+
)
|
| 172 |
+
cql_loss -= q_selected.mean() * config.min_q_weight
|
| 173 |
+
# Add the CQL loss term to the SAC loss term.
|
| 174 |
+
critic_loss = sac_critic_loss + cql_loss
|
| 175 |
+
|
| 176 |
+
# If a twin Q-value function is implemented calculated its CQL loss.
|
| 177 |
+
if config.twin_q:
|
| 178 |
+
q_twin_repeat = torch.cat(
|
| 179 |
+
[
|
| 180 |
+
fwd_out["q_twin_rand_repeat"] - random_density,
|
| 181 |
+
fwd_out["q_twin_next_repeat"]
|
| 182 |
+
- fwd_out["logps_next_repeat"].detach(),
|
| 183 |
+
fwd_out["q_twin_curr_repeat"]
|
| 184 |
+
- fwd_out["logps_curr_repeat"].detach(),
|
| 185 |
+
],
|
| 186 |
+
dim=1,
|
| 187 |
+
)
|
| 188 |
+
cql_twin_loss = (
|
| 189 |
+
torch.logsumexp(q_twin_repeat / config.temperature, dim=1).mean()
|
| 190 |
+
* config.min_q_weight
|
| 191 |
+
* config.temperature
|
| 192 |
+
)
|
| 193 |
+
cql_twin_loss -= q_twin_selected.mean() * config.min_q_weight
|
| 194 |
+
# Add the CQL loss term to the SAC loss term.
|
| 195 |
+
critic_twin_loss = sac_critic_twin_loss + cql_twin_loss
|
| 196 |
+
|
| 197 |
+
# TODO (simon): Check, if we need to implement here also a Lagrangian
|
| 198 |
+
# loss.
|
| 199 |
+
|
| 200 |
+
total_loss = actor_loss + critic_loss + alpha_loss
|
| 201 |
+
|
| 202 |
+
# Add the twin critic loss to the total loss, if needed.
|
| 203 |
+
if config.twin_q:
|
| 204 |
+
# Reweigh the critic loss terms in the total loss.
|
| 205 |
+
total_loss += 0.5 * critic_twin_loss - 0.5 * critic_loss
|
| 206 |
+
|
| 207 |
+
# Log important loss stats (reduce=mean (default), but with window=1
|
| 208 |
+
# in order to keep them history free).
|
| 209 |
+
self.metrics.log_dict(
|
| 210 |
+
{
|
| 211 |
+
POLICY_LOSS_KEY: actor_loss,
|
| 212 |
+
QF_LOSS_KEY: critic_loss,
|
| 213 |
+
# TODO (simon): Add these keys to SAC Learner.
|
| 214 |
+
"cql_loss": cql_loss,
|
| 215 |
+
"alpha_loss": alpha_loss,
|
| 216 |
+
"alpha_value": alpha,
|
| 217 |
+
"log_alpha_value": torch.log(alpha),
|
| 218 |
+
"target_entropy": self.target_entropy[module_id],
|
| 219 |
+
LOGPS_KEY: torch.mean(
|
| 220 |
+
fwd_out["logp_resampled"]
|
| 221 |
+
), # torch.mean(logps_curr),
|
| 222 |
+
QF_MEAN_KEY: torch.mean(fwd_out["q_curr_repeat"]),
|
| 223 |
+
QF_MAX_KEY: torch.max(fwd_out["q_curr_repeat"]),
|
| 224 |
+
QF_MIN_KEY: torch.min(fwd_out["q_curr_repeat"]),
|
| 225 |
+
TD_ERROR_MEAN_KEY: torch.mean(td_error),
|
| 226 |
+
},
|
| 227 |
+
key=module_id,
|
| 228 |
+
window=1, # <- single items (should not be mean/ema-reduced over time).
|
| 229 |
+
)
|
| 230 |
+
# TODO (simon): Add loss keys for langrangian, if needed.
|
| 231 |
+
# TODO (simon): Add only here then the Langrange parameter optimization.
|
| 232 |
+
if config.twin_q:
|
| 233 |
+
self.metrics.log_dict(
|
| 234 |
+
{
|
| 235 |
+
QF_TWIN_LOSS_KEY: critic_twin_loss,
|
| 236 |
+
},
|
| 237 |
+
key=module_id,
|
| 238 |
+
window=1, # <- single items (should not be mean/ema-reduced over time).
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
# Return the total loss.
|
| 242 |
+
return total_loss
|
| 243 |
+
|
| 244 |
+
@override(SACTorchLearner)
|
| 245 |
+
def compute_gradients(
|
| 246 |
+
self, loss_per_module: Dict[ModuleID, TensorType], **kwargs
|
| 247 |
+
) -> ParamDict:
|
| 248 |
+
|
| 249 |
+
grads = {}
|
| 250 |
+
for module_id in set(loss_per_module.keys()) - {ALL_MODULES}:
|
| 251 |
+
# Loop through optimizers registered for this module.
|
| 252 |
+
for optim_name, optim in self.get_optimizers_for_module(module_id):
|
| 253 |
+
# Zero the gradients. Note, we need to reset the gradients b/c
|
| 254 |
+
# each component for a module operates on the same graph.
|
| 255 |
+
optim.zero_grad(set_to_none=True)
|
| 256 |
+
|
| 257 |
+
# Compute the gradients for the component and module.
|
| 258 |
+
self.metrics.peek((module_id, optim_name + "_loss")).backward(
|
| 259 |
+
retain_graph=False if optim_name in ["policy", "alpha"] else True
|
| 260 |
+
)
|
| 261 |
+
# Store the gradients for the component and module.
|
| 262 |
+
# TODO (simon): Check another time the graph for overlapping
|
| 263 |
+
# gradients.
|
| 264 |
+
grads.update(
|
| 265 |
+
{
|
| 266 |
+
pid: grads[pid] + p.grad.clone()
|
| 267 |
+
if pid in grads
|
| 268 |
+
else p.grad.clone()
|
| 269 |
+
for pid, p in self.filter_param_dict_for_optimizer(
|
| 270 |
+
self._params, optim
|
| 271 |
+
).items()
|
| 272 |
+
}
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
return grads
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/cql/torch/default_cql_torch_rl_module.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tree
|
| 2 |
+
from typing import Any, Dict, Optional
|
| 3 |
+
|
| 4 |
+
from ray.rllib.algorithms.sac.sac_learner import (
|
| 5 |
+
QF_PREDS,
|
| 6 |
+
QF_TWIN_PREDS,
|
| 7 |
+
)
|
| 8 |
+
from ray.rllib.algorithms.sac.sac_catalog import SACCatalog
|
| 9 |
+
from ray.rllib.algorithms.sac.torch.default_sac_torch_rl_module import (
|
| 10 |
+
DefaultSACTorchRLModule,
|
| 11 |
+
)
|
| 12 |
+
from ray.rllib.core.columns import Columns
|
| 13 |
+
from ray.rllib.core.models.base import ENCODER_OUT
|
| 14 |
+
from ray.rllib.utils.annotations import override
|
| 15 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 16 |
+
from ray.rllib.utils.typing import TensorType
|
| 17 |
+
|
| 18 |
+
torch, nn = try_import_torch()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DefaultCQLTorchRLModule(DefaultSACTorchRLModule):
|
| 22 |
+
def __init__(self, *args, **kwargs):
|
| 23 |
+
catalog_class = kwargs.pop("catalog_class", None)
|
| 24 |
+
if catalog_class is None:
|
| 25 |
+
catalog_class = SACCatalog
|
| 26 |
+
super().__init__(*args, **kwargs, catalog_class=catalog_class)
|
| 27 |
+
|
| 28 |
+
@override(DefaultSACTorchRLModule)
|
| 29 |
+
def _forward_train(self, batch: Dict) -> Dict[str, Any]:
|
| 30 |
+
# Call the super method.
|
| 31 |
+
fwd_out = super()._forward_train(batch)
|
| 32 |
+
|
| 33 |
+
# Make sure we perform a "straight-through gradient" pass here,
|
| 34 |
+
# ignoring the gradients of the q-net, however, still recording
|
| 35 |
+
# the gradients of the policy net (which was used to rsample the actions used
|
| 36 |
+
# here). This is different from doing `.detach()` or `with torch.no_grads()`,
|
| 37 |
+
# as these two methds would fully block all gradient recordings, including
|
| 38 |
+
# the needed policy ones.
|
| 39 |
+
all_params = list(self.pi_encoder.parameters()) + list(self.pi.parameters())
|
| 40 |
+
# if self.twin_q:
|
| 41 |
+
# all_params += list(self.qf_twin.parameters()) + list(
|
| 42 |
+
# self.qf_twin_encoder.parameters()
|
| 43 |
+
# )
|
| 44 |
+
|
| 45 |
+
for param in all_params:
|
| 46 |
+
param.requires_grad = False
|
| 47 |
+
|
| 48 |
+
# Compute the repeated actions, action log-probabilites and Q-values for all
|
| 49 |
+
# observations.
|
| 50 |
+
# First for the random actions (from the mu-distribution as named by Kumar et
|
| 51 |
+
# al. (2020)).
|
| 52 |
+
low = torch.tensor(
|
| 53 |
+
self.action_space.low,
|
| 54 |
+
device=fwd_out[QF_PREDS].device,
|
| 55 |
+
)
|
| 56 |
+
high = torch.tensor(
|
| 57 |
+
self.action_space.high,
|
| 58 |
+
device=fwd_out[QF_PREDS].device,
|
| 59 |
+
)
|
| 60 |
+
num_samples = batch[Columns.ACTIONS].shape[0] * self.model_config["num_actions"]
|
| 61 |
+
actions_rand_repeat = low + (high - low) * torch.rand(
|
| 62 |
+
(num_samples, low.shape[0]), device=fwd_out[QF_PREDS].device
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# First for the random actions (from the mu-distribution as named in Kumar
|
| 66 |
+
# et al. (2020)) using repeated observations.
|
| 67 |
+
rand_repeat_out = self._repeat_actions(batch[Columns.OBS], actions_rand_repeat)
|
| 68 |
+
(fwd_out["actions_rand_repeat"], fwd_out["q_rand_repeat"]) = (
|
| 69 |
+
rand_repeat_out[Columns.ACTIONS],
|
| 70 |
+
rand_repeat_out[QF_PREDS],
|
| 71 |
+
)
|
| 72 |
+
# Sample current and next actions (from the pi distribution as named in Kumar
|
| 73 |
+
# et al. (2020)) using repeated observations
|
| 74 |
+
# Second for the current observations and the current action distribution.
|
| 75 |
+
curr_repeat_out = self._repeat_actions(batch[Columns.OBS])
|
| 76 |
+
(
|
| 77 |
+
fwd_out["actions_curr_repeat"],
|
| 78 |
+
fwd_out["logps_curr_repeat"],
|
| 79 |
+
fwd_out["q_curr_repeat"],
|
| 80 |
+
) = (
|
| 81 |
+
curr_repeat_out[Columns.ACTIONS],
|
| 82 |
+
curr_repeat_out[Columns.ACTION_LOGP],
|
| 83 |
+
curr_repeat_out[QF_PREDS],
|
| 84 |
+
)
|
| 85 |
+
# Then, for the next observations and the current action distribution.
|
| 86 |
+
next_repeat_out = self._repeat_actions(batch[Columns.NEXT_OBS])
|
| 87 |
+
(
|
| 88 |
+
fwd_out["actions_next_repeat"],
|
| 89 |
+
fwd_out["logps_next_repeat"],
|
| 90 |
+
fwd_out["q_next_repeat"],
|
| 91 |
+
) = (
|
| 92 |
+
next_repeat_out[Columns.ACTIONS],
|
| 93 |
+
next_repeat_out[Columns.ACTION_LOGP],
|
| 94 |
+
next_repeat_out[QF_PREDS],
|
| 95 |
+
)
|
| 96 |
+
if self.twin_q:
|
| 97 |
+
# First for the random actions from the mu-distribution.
|
| 98 |
+
fwd_out["q_twin_rand_repeat"] = rand_repeat_out[QF_TWIN_PREDS]
|
| 99 |
+
# Second for the current observations and the current action distribution.
|
| 100 |
+
fwd_out["q_twin_curr_repeat"] = curr_repeat_out[QF_TWIN_PREDS]
|
| 101 |
+
# Then, for the next observations and the current action distribution.
|
| 102 |
+
fwd_out["q_twin_next_repeat"] = next_repeat_out[QF_TWIN_PREDS]
|
| 103 |
+
# Reset the gradient requirements for all Q-function parameters.
|
| 104 |
+
for param in all_params:
|
| 105 |
+
param.requires_grad = True
|
| 106 |
+
|
| 107 |
+
return fwd_out
|
| 108 |
+
|
| 109 |
+
def _repeat_tensor(self, tensor: TensorType, repeat: int) -> TensorType:
|
| 110 |
+
"""Generates a repeated version of a tensor.
|
| 111 |
+
|
| 112 |
+
The repetition is done similar `np.repeat` and repeats each value
|
| 113 |
+
instead of the complete vector.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
tensor: The tensor to be repeated.
|
| 117 |
+
repeat: How often each value in the tensor should be repeated.
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
A tensor holding `repeat` repeated values of the input `tensor`
|
| 121 |
+
"""
|
| 122 |
+
# Insert the new dimension at axis 1 into the tensor.
|
| 123 |
+
t_repeat = tensor.unsqueeze(1)
|
| 124 |
+
# Repeat the tensor along the new dimension.
|
| 125 |
+
t_repeat = torch.repeat_interleave(t_repeat, repeat, dim=1)
|
| 126 |
+
# Stack the repeated values into the batch dimension.
|
| 127 |
+
t_repeat = t_repeat.view(-1, *tensor.shape[1:])
|
| 128 |
+
# Return the repeated tensor.
|
| 129 |
+
return t_repeat
|
| 130 |
+
|
| 131 |
+
def _repeat_actions(
|
| 132 |
+
self, obs: TensorType, actions: Optional[TensorType] = None
|
| 133 |
+
) -> Dict[str, TensorType]:
|
| 134 |
+
"""Generated actions and Q-values for repeated observations.
|
| 135 |
+
|
| 136 |
+
The `self.model_config["num_actions"]` define a multiplier
|
| 137 |
+
used for generating `num_actions` as many actions as the batch size.
|
| 138 |
+
Observations are repeated and then a model forward pass is made.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
obs: A batched observation tensor.
|
| 142 |
+
actions: An optional batched actions tensor.
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
A dictionary holding the (sampled or passed-in actions), the log
|
| 146 |
+
probabilities (of sampled actions), the Q-values and if available
|
| 147 |
+
the twin-Q values.
|
| 148 |
+
"""
|
| 149 |
+
output = {}
|
| 150 |
+
# Receive the batch size.
|
| 151 |
+
batch_size = obs.shape[0]
|
| 152 |
+
# Receive the number of action to sample.
|
| 153 |
+
num_actions = self.model_config["num_actions"]
|
| 154 |
+
# Repeat the observations `num_actions` times.
|
| 155 |
+
obs_repeat = tree.map_structure(
|
| 156 |
+
lambda t: self._repeat_tensor(t, num_actions), obs
|
| 157 |
+
)
|
| 158 |
+
# Generate a batch for the forward pass.
|
| 159 |
+
temp_batch = {Columns.OBS: obs_repeat}
|
| 160 |
+
if actions is None:
|
| 161 |
+
# TODO (simon): Run the forward pass in inference mode.
|
| 162 |
+
# Compute the action logits.
|
| 163 |
+
pi_encoder_outs = self.pi_encoder(temp_batch)
|
| 164 |
+
action_logits = self.pi(pi_encoder_outs[ENCODER_OUT])
|
| 165 |
+
# Generate the squashed Gaussian from the model's logits.
|
| 166 |
+
action_dist = self.get_train_action_dist_cls().from_logits(action_logits)
|
| 167 |
+
# Sample the actions. Note, we want to make a backward pass through
|
| 168 |
+
# these actions.
|
| 169 |
+
output[Columns.ACTIONS] = action_dist.rsample()
|
| 170 |
+
# Compute the action log-probabilities.
|
| 171 |
+
output[Columns.ACTION_LOGP] = action_dist.logp(
|
| 172 |
+
output[Columns.ACTIONS]
|
| 173 |
+
).view(batch_size, num_actions, 1)
|
| 174 |
+
else:
|
| 175 |
+
output[Columns.ACTIONS] = actions
|
| 176 |
+
|
| 177 |
+
# Compute all Q-values.
|
| 178 |
+
temp_batch.update(
|
| 179 |
+
{
|
| 180 |
+
Columns.ACTIONS: output[Columns.ACTIONS],
|
| 181 |
+
}
|
| 182 |
+
)
|
| 183 |
+
output.update(
|
| 184 |
+
{
|
| 185 |
+
QF_PREDS: self._qf_forward_train_helper(
|
| 186 |
+
temp_batch,
|
| 187 |
+
self.qf_encoder,
|
| 188 |
+
self.qf,
|
| 189 |
+
).view(batch_size, num_actions, 1)
|
| 190 |
+
}
|
| 191 |
+
)
|
| 192 |
+
# If we have a twin-Q network, compute its Q-values, too.
|
| 193 |
+
if self.twin_q:
|
| 194 |
+
output.update(
|
| 195 |
+
{
|
| 196 |
+
QF_TWIN_PREDS: self._qf_forward_train_helper(
|
| 197 |
+
temp_batch,
|
| 198 |
+
self.qf_twin_encoder,
|
| 199 |
+
self.qf_twin,
|
| 200 |
+
).view(batch_size, num_actions, 1)
|
| 201 |
+
}
|
| 202 |
+
)
|
| 203 |
+
del temp_batch
|
| 204 |
+
|
| 205 |
+
# Return
|
| 206 |
+
return output
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
[1] Mastering Diverse Domains through World Models - 2023
|
| 3 |
+
D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
|
| 4 |
+
https://arxiv.org/pdf/2301.04104v1.pdf
|
| 5 |
+
|
| 6 |
+
[2] Mastering Atari with Discrete World Models - 2021
|
| 7 |
+
D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
|
| 8 |
+
https://arxiv.org/pdf/2010.02193.pdf
|
| 9 |
+
"""
|
| 10 |
+
from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3, DreamerV3Config
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"DreamerV3",
|
| 14 |
+
"DreamerV3Config",
|
| 15 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3.py
ADDED
|
@@ -0,0 +1,750 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
[1] Mastering Diverse Domains through World Models - 2023
|
| 3 |
+
D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
|
| 4 |
+
https://arxiv.org/pdf/2301.04104v1.pdf
|
| 5 |
+
|
| 6 |
+
[2] Mastering Atari with Discrete World Models - 2021
|
| 7 |
+
D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
|
| 8 |
+
https://arxiv.org/pdf/2010.02193.pdf
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import gc
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Any, Dict, Optional, Union
|
| 14 |
+
|
| 15 |
+
import gymnasium as gym
|
| 16 |
+
|
| 17 |
+
from ray.rllib.algorithms.algorithm import Algorithm
|
| 18 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
|
| 19 |
+
from ray.rllib.algorithms.dreamerv3.dreamerv3_catalog import DreamerV3Catalog
|
| 20 |
+
from ray.rllib.algorithms.dreamerv3.utils import do_symlog_obs
|
| 21 |
+
from ray.rllib.algorithms.dreamerv3.utils.env_runner import DreamerV3EnvRunner
|
| 22 |
+
from ray.rllib.algorithms.dreamerv3.utils.summaries import (
|
| 23 |
+
report_dreamed_eval_trajectory_vs_samples,
|
| 24 |
+
report_predicted_vs_sampled_obs,
|
| 25 |
+
report_sampling_and_replay_buffer,
|
| 26 |
+
)
|
| 27 |
+
from ray.rllib.core import DEFAULT_MODULE_ID
|
| 28 |
+
from ray.rllib.core.columns import Columns
|
| 29 |
+
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
|
| 30 |
+
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
|
| 31 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 32 |
+
from ray.rllib.utils import deep_update
|
| 33 |
+
from ray.rllib.utils.annotations import override, PublicAPI
|
| 34 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 35 |
+
from ray.rllib.utils.numpy import one_hot
|
| 36 |
+
from ray.rllib.utils.metrics import (
|
| 37 |
+
ENV_RUNNER_RESULTS,
|
| 38 |
+
GARBAGE_COLLECTION_TIMER,
|
| 39 |
+
LEARN_ON_BATCH_TIMER,
|
| 40 |
+
LEARNER_RESULTS,
|
| 41 |
+
NUM_ENV_STEPS_SAMPLED_LIFETIME,
|
| 42 |
+
NUM_ENV_STEPS_TRAINED_LIFETIME,
|
| 43 |
+
NUM_GRAD_UPDATES_LIFETIME,
|
| 44 |
+
NUM_SYNCH_WORKER_WEIGHTS,
|
| 45 |
+
SAMPLE_TIMER,
|
| 46 |
+
SYNCH_WORKER_WEIGHTS_TIMER,
|
| 47 |
+
TIMERS,
|
| 48 |
+
)
|
| 49 |
+
from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer
|
| 50 |
+
from ray.rllib.utils.typing import LearningRateOrSchedule
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
logger = logging.getLogger(__name__)
|
| 54 |
+
|
| 55 |
+
_, tf, _ = try_import_tf()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class DreamerV3Config(AlgorithmConfig):
|
| 59 |
+
"""Defines a configuration class from which a DreamerV3 can be built.
|
| 60 |
+
|
| 61 |
+
.. testcode::
|
| 62 |
+
|
| 63 |
+
from ray.rllib.algorithms.dreamerv3 import DreamerV3Config
|
| 64 |
+
config = (
|
| 65 |
+
DreamerV3Config()
|
| 66 |
+
.environment("CartPole-v1")
|
| 67 |
+
.training(
|
| 68 |
+
model_size="XS",
|
| 69 |
+
training_ratio=1,
|
| 70 |
+
# TODO
|
| 71 |
+
model={
|
| 72 |
+
"batch_size_B": 1,
|
| 73 |
+
"batch_length_T": 1,
|
| 74 |
+
"horizon_H": 1,
|
| 75 |
+
"gamma": 0.997,
|
| 76 |
+
"model_size": "XS",
|
| 77 |
+
},
|
| 78 |
+
)
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
config = config.learners(num_learners=0)
|
| 82 |
+
# Build a Algorithm object from the config and run 1 training iteration.
|
| 83 |
+
algo = config.build()
|
| 84 |
+
# algo.train()
|
| 85 |
+
del algo
|
| 86 |
+
|
| 87 |
+
.. testoutput::
|
| 88 |
+
:hide:
|
| 89 |
+
|
| 90 |
+
...
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
def __init__(self, algo_class=None):
|
| 94 |
+
"""Initializes a DreamerV3Config instance."""
|
| 95 |
+
super().__init__(algo_class=algo_class or DreamerV3)
|
| 96 |
+
|
| 97 |
+
# fmt: off
|
| 98 |
+
# __sphinx_doc_begin__
|
| 99 |
+
|
| 100 |
+
# DreamerV3 specific settings:
|
| 101 |
+
self.model_size = "XS"
|
| 102 |
+
self.training_ratio = 1024
|
| 103 |
+
|
| 104 |
+
self.replay_buffer_config = {
|
| 105 |
+
"type": "EpisodeReplayBuffer",
|
| 106 |
+
"capacity": int(1e6),
|
| 107 |
+
}
|
| 108 |
+
self.world_model_lr = 1e-4
|
| 109 |
+
self.actor_lr = 3e-5
|
| 110 |
+
self.critic_lr = 3e-5
|
| 111 |
+
self.batch_size_B = 16
|
| 112 |
+
self.batch_length_T = 64
|
| 113 |
+
self.horizon_H = 15
|
| 114 |
+
self.gae_lambda = 0.95 # [1] eq. 7.
|
| 115 |
+
self.entropy_scale = 3e-4 # [1] eq. 11.
|
| 116 |
+
self.return_normalization_decay = 0.99 # [1] eq. 11 and 12.
|
| 117 |
+
self.train_critic = True
|
| 118 |
+
self.train_actor = True
|
| 119 |
+
self.intrinsic_rewards_scale = 0.1
|
| 120 |
+
self.world_model_grad_clip_by_global_norm = 1000.0
|
| 121 |
+
self.critic_grad_clip_by_global_norm = 100.0
|
| 122 |
+
self.actor_grad_clip_by_global_norm = 100.0
|
| 123 |
+
self.symlog_obs = "auto"
|
| 124 |
+
self.use_float16 = False
|
| 125 |
+
self.use_curiosity = False
|
| 126 |
+
|
| 127 |
+
# Reporting.
|
| 128 |
+
# DreamerV3 is super sample efficient and only needs very few episodes
|
| 129 |
+
# (normally) to learn. Leaving this at its default value would gravely
|
| 130 |
+
# underestimate the learning performance over the course of an experiment.
|
| 131 |
+
self.metrics_num_episodes_for_smoothing = 1
|
| 132 |
+
self.report_individual_batch_item_stats = False
|
| 133 |
+
self.report_dream_data = False
|
| 134 |
+
self.report_images_and_videos = False
|
| 135 |
+
self.gc_frequency_train_steps = 100
|
| 136 |
+
|
| 137 |
+
# Override some of AlgorithmConfig's default values with DreamerV3-specific
|
| 138 |
+
# values.
|
| 139 |
+
self.lr = None
|
| 140 |
+
self.framework_str = "tf2"
|
| 141 |
+
self.gamma = 0.997 # [1] eq. 7.
|
| 142 |
+
# Do not use! Set `batch_size_B` and `batch_length_T` instead.
|
| 143 |
+
self.train_batch_size = None
|
| 144 |
+
self.env_runner_cls = DreamerV3EnvRunner
|
| 145 |
+
self.num_env_runners = 0
|
| 146 |
+
self.rollout_fragment_length = 1
|
| 147 |
+
# Dreamer only runs on the new API stack.
|
| 148 |
+
self.enable_rl_module_and_learner = True
|
| 149 |
+
self.enable_env_runner_and_connector_v2 = True
|
| 150 |
+
# TODO (sven): DreamerV3 still uses its own EnvRunner class. This env-runner
|
| 151 |
+
# does not use connectors. We therefore should not attempt to merge/broadcast
|
| 152 |
+
# the connector states between EnvRunners (if >0). Note that this is only
|
| 153 |
+
# relevant if num_env_runners > 0, which is normally not the case when using
|
| 154 |
+
# this algo.
|
| 155 |
+
self.use_worker_filter_stats = False
|
| 156 |
+
# __sphinx_doc_end__
|
| 157 |
+
# fmt: on
|
| 158 |
+
|
| 159 |
+
@property
|
| 160 |
+
def batch_size_B_per_learner(self):
|
| 161 |
+
"""Returns the batch_size_B per Learner worker.
|
| 162 |
+
|
| 163 |
+
Needed by some of the DreamerV3 loss math."""
|
| 164 |
+
return self.batch_size_B // (self.num_learners or 1)
|
| 165 |
+
|
| 166 |
+
@override(AlgorithmConfig)
|
| 167 |
+
def training(
|
| 168 |
+
self,
|
| 169 |
+
*,
|
| 170 |
+
model_size: Optional[str] = NotProvided,
|
| 171 |
+
training_ratio: Optional[float] = NotProvided,
|
| 172 |
+
gc_frequency_train_steps: Optional[int] = NotProvided,
|
| 173 |
+
batch_size_B: Optional[int] = NotProvided,
|
| 174 |
+
batch_length_T: Optional[int] = NotProvided,
|
| 175 |
+
horizon_H: Optional[int] = NotProvided,
|
| 176 |
+
gae_lambda: Optional[float] = NotProvided,
|
| 177 |
+
entropy_scale: Optional[float] = NotProvided,
|
| 178 |
+
return_normalization_decay: Optional[float] = NotProvided,
|
| 179 |
+
train_critic: Optional[bool] = NotProvided,
|
| 180 |
+
train_actor: Optional[bool] = NotProvided,
|
| 181 |
+
intrinsic_rewards_scale: Optional[float] = NotProvided,
|
| 182 |
+
world_model_lr: Optional[LearningRateOrSchedule] = NotProvided,
|
| 183 |
+
actor_lr: Optional[LearningRateOrSchedule] = NotProvided,
|
| 184 |
+
critic_lr: Optional[LearningRateOrSchedule] = NotProvided,
|
| 185 |
+
world_model_grad_clip_by_global_norm: Optional[float] = NotProvided,
|
| 186 |
+
critic_grad_clip_by_global_norm: Optional[float] = NotProvided,
|
| 187 |
+
actor_grad_clip_by_global_norm: Optional[float] = NotProvided,
|
| 188 |
+
symlog_obs: Optional[Union[bool, str]] = NotProvided,
|
| 189 |
+
use_float16: Optional[bool] = NotProvided,
|
| 190 |
+
replay_buffer_config: Optional[dict] = NotProvided,
|
| 191 |
+
use_curiosity: Optional[bool] = NotProvided,
|
| 192 |
+
**kwargs,
|
| 193 |
+
) -> "DreamerV3Config":
|
| 194 |
+
"""Sets the training related configuration.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
model_size: The main switch for adjusting the overall model size. See [1]
|
| 198 |
+
(table B) for more information on the effects of this setting on the
|
| 199 |
+
model architecture.
|
| 200 |
+
Supported values are "XS", "S", "M", "L", "XL" (as per the paper), as
|
| 201 |
+
well as, "nano", "micro", "mini", and "XXS" (for RLlib's
|
| 202 |
+
implementation). See ray.rllib.algorithms.dreamerv3.utils.
|
| 203 |
+
__init__.py for the details on what exactly each size does to the layer
|
| 204 |
+
sizes, number of layers, etc..
|
| 205 |
+
training_ratio: The ratio of total steps trained (sum of the sizes of all
|
| 206 |
+
batches ever sampled from the replay buffer) over the total env steps
|
| 207 |
+
taken (in the actual environment, not the dreamed one). For example,
|
| 208 |
+
if the training_ratio is 1024 and the batch size is 1024, we would take
|
| 209 |
+
1 env step for every training update: 1024 / 1. If the training ratio
|
| 210 |
+
is 512 and the batch size is 1024, we would take 2 env steps and then
|
| 211 |
+
perform a single training update (on a 1024 batch): 1024 / 2.
|
| 212 |
+
gc_frequency_train_steps: The frequency (in training iterations) with which
|
| 213 |
+
we perform a `gc.collect()` calls at the end of a `training_step`
|
| 214 |
+
iteration. Doing this more often adds a (albeit very small) performance
|
| 215 |
+
overhead, but prevents memory leaks from becoming harmful.
|
| 216 |
+
TODO (sven): This might not be necessary anymore, but needs to be
|
| 217 |
+
confirmed experimentally.
|
| 218 |
+
batch_size_B: The batch size (B) interpreted as number of rows (each of
|
| 219 |
+
length `batch_length_T`) to sample from the replay buffer in each
|
| 220 |
+
iteration.
|
| 221 |
+
batch_length_T: The batch length (T) interpreted as the length of each row
|
| 222 |
+
sampled from the replay buffer in each iteration. Note that
|
| 223 |
+
`batch_size_B` rows will be sampled in each iteration. Rows normally
|
| 224 |
+
contain consecutive data (consecutive timesteps from the same episode),
|
| 225 |
+
but there might be episode boundaries in a row as well.
|
| 226 |
+
horizon_H: The horizon (in timesteps) used to create dreamed data from the
|
| 227 |
+
world model, which in turn is used to train/update both actor- and
|
| 228 |
+
critic networks.
|
| 229 |
+
gae_lambda: The lambda parameter used for computing the GAE-style
|
| 230 |
+
value targets for the actor- and critic losses.
|
| 231 |
+
entropy_scale: The factor with which to multiply the entropy loss term
|
| 232 |
+
inside the actor loss.
|
| 233 |
+
return_normalization_decay: The decay value to use when computing the
|
| 234 |
+
running EMA values for return normalization (used in the actor loss).
|
| 235 |
+
train_critic: Whether to train the critic network. If False, `train_actor`
|
| 236 |
+
must also be False (cannot train actor w/o training the critic).
|
| 237 |
+
train_actor: Whether to train the actor network. If True, `train_critic`
|
| 238 |
+
must also be True (cannot train actor w/o training the critic).
|
| 239 |
+
intrinsic_rewards_scale: The factor to multiply intrinsic rewards with
|
| 240 |
+
before adding them to the extrinsic (environment) rewards.
|
| 241 |
+
world_model_lr: The learning rate or schedule for the world model optimizer.
|
| 242 |
+
actor_lr: The learning rate or schedule for the actor optimizer.
|
| 243 |
+
critic_lr: The learning rate or schedule for the critic optimizer.
|
| 244 |
+
world_model_grad_clip_by_global_norm: World model grad clipping value
|
| 245 |
+
(by global norm).
|
| 246 |
+
critic_grad_clip_by_global_norm: Critic grad clipping value
|
| 247 |
+
(by global norm).
|
| 248 |
+
actor_grad_clip_by_global_norm: Actor grad clipping value (by global norm).
|
| 249 |
+
symlog_obs: Whether to symlog observations or not. If set to "auto"
|
| 250 |
+
(default), will check for the environment's observation space and then
|
| 251 |
+
only symlog if not an image space.
|
| 252 |
+
use_float16: Whether to train with mixed float16 precision. In this mode,
|
| 253 |
+
model parameters are stored as float32, but all computations are
|
| 254 |
+
performed in float16 space (except for losses and distribution params
|
| 255 |
+
and outputs).
|
| 256 |
+
replay_buffer_config: Replay buffer config.
|
| 257 |
+
Only serves in DreamerV3 to set the capacity of the replay buffer.
|
| 258 |
+
Note though that in the paper ([1]) a size of 1M is used for all
|
| 259 |
+
benchmarks and there doesn't seem to be a good reason to change this
|
| 260 |
+
parameter.
|
| 261 |
+
Examples:
|
| 262 |
+
{
|
| 263 |
+
"type": "EpisodeReplayBuffer",
|
| 264 |
+
"capacity": 100000,
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
Returns:
|
| 268 |
+
This updated AlgorithmConfig object.
|
| 269 |
+
"""
|
| 270 |
+
# Not fully supported/tested yet.
|
| 271 |
+
if use_curiosity is not NotProvided:
|
| 272 |
+
raise ValueError(
|
| 273 |
+
"`DreamerV3Config.curiosity` is not fully supported and tested yet! "
|
| 274 |
+
"It thus remains disabled for now."
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
# Pass kwargs onto super's `training()` method.
|
| 278 |
+
super().training(**kwargs)
|
| 279 |
+
|
| 280 |
+
if model_size is not NotProvided:
|
| 281 |
+
self.model_size = model_size
|
| 282 |
+
if training_ratio is not NotProvided:
|
| 283 |
+
self.training_ratio = training_ratio
|
| 284 |
+
if gc_frequency_train_steps is not NotProvided:
|
| 285 |
+
self.gc_frequency_train_steps = gc_frequency_train_steps
|
| 286 |
+
if batch_size_B is not NotProvided:
|
| 287 |
+
self.batch_size_B = batch_size_B
|
| 288 |
+
if batch_length_T is not NotProvided:
|
| 289 |
+
self.batch_length_T = batch_length_T
|
| 290 |
+
if horizon_H is not NotProvided:
|
| 291 |
+
self.horizon_H = horizon_H
|
| 292 |
+
if gae_lambda is not NotProvided:
|
| 293 |
+
self.gae_lambda = gae_lambda
|
| 294 |
+
if entropy_scale is not NotProvided:
|
| 295 |
+
self.entropy_scale = entropy_scale
|
| 296 |
+
if return_normalization_decay is not NotProvided:
|
| 297 |
+
self.return_normalization_decay = return_normalization_decay
|
| 298 |
+
if train_critic is not NotProvided:
|
| 299 |
+
self.train_critic = train_critic
|
| 300 |
+
if train_actor is not NotProvided:
|
| 301 |
+
self.train_actor = train_actor
|
| 302 |
+
if intrinsic_rewards_scale is not NotProvided:
|
| 303 |
+
self.intrinsic_rewards_scale = intrinsic_rewards_scale
|
| 304 |
+
if world_model_lr is not NotProvided:
|
| 305 |
+
self.world_model_lr = world_model_lr
|
| 306 |
+
if actor_lr is not NotProvided:
|
| 307 |
+
self.actor_lr = actor_lr
|
| 308 |
+
if critic_lr is not NotProvided:
|
| 309 |
+
self.critic_lr = critic_lr
|
| 310 |
+
if world_model_grad_clip_by_global_norm is not NotProvided:
|
| 311 |
+
self.world_model_grad_clip_by_global_norm = (
|
| 312 |
+
world_model_grad_clip_by_global_norm
|
| 313 |
+
)
|
| 314 |
+
if critic_grad_clip_by_global_norm is not NotProvided:
|
| 315 |
+
self.critic_grad_clip_by_global_norm = critic_grad_clip_by_global_norm
|
| 316 |
+
if actor_grad_clip_by_global_norm is not NotProvided:
|
| 317 |
+
self.actor_grad_clip_by_global_norm = actor_grad_clip_by_global_norm
|
| 318 |
+
if symlog_obs is not NotProvided:
|
| 319 |
+
self.symlog_obs = symlog_obs
|
| 320 |
+
if use_float16 is not NotProvided:
|
| 321 |
+
self.use_float16 = use_float16
|
| 322 |
+
if replay_buffer_config is not NotProvided:
|
| 323 |
+
# Override entire `replay_buffer_config` if `type` key changes.
|
| 324 |
+
# Update, if `type` key remains the same or is not specified.
|
| 325 |
+
new_replay_buffer_config = deep_update(
|
| 326 |
+
{"replay_buffer_config": self.replay_buffer_config},
|
| 327 |
+
{"replay_buffer_config": replay_buffer_config},
|
| 328 |
+
False,
|
| 329 |
+
["replay_buffer_config"],
|
| 330 |
+
["replay_buffer_config"],
|
| 331 |
+
)
|
| 332 |
+
self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"]
|
| 333 |
+
|
| 334 |
+
return self
|
| 335 |
+
|
| 336 |
+
@override(AlgorithmConfig)
|
| 337 |
+
def reporting(
|
| 338 |
+
self,
|
| 339 |
+
*,
|
| 340 |
+
report_individual_batch_item_stats: Optional[bool] = NotProvided,
|
| 341 |
+
report_dream_data: Optional[bool] = NotProvided,
|
| 342 |
+
report_images_and_videos: Optional[bool] = NotProvided,
|
| 343 |
+
**kwargs,
|
| 344 |
+
):
|
| 345 |
+
"""Sets the reporting related configuration.
|
| 346 |
+
|
| 347 |
+
Args:
|
| 348 |
+
report_individual_batch_item_stats: Whether to include loss and other stats
|
| 349 |
+
per individual timestep inside the training batch in the result dict
|
| 350 |
+
returned by `training_step()`. If True, besides the `CRITIC_L_total`,
|
| 351 |
+
the individual critic loss values per batch row and time axis step
|
| 352 |
+
in the train batch (CRITIC_L_total_B_T) will also be part of the
|
| 353 |
+
results.
|
| 354 |
+
report_dream_data: Whether to include the dreamed trajectory data in the
|
| 355 |
+
result dict returned by `training_step()`. If True, however, will
|
| 356 |
+
slice each reported item in the dream data down to the shape.
|
| 357 |
+
(H, B, t=0, ...), where H is the horizon and B is the batch size. The
|
| 358 |
+
original time axis will only be represented by the first timestep
|
| 359 |
+
to not make this data too large to handle.
|
| 360 |
+
report_images_and_videos: Whether to include any image/video data in the
|
| 361 |
+
result dict returned by `training_step()`.
|
| 362 |
+
**kwargs:
|
| 363 |
+
|
| 364 |
+
Returns:
|
| 365 |
+
This updated AlgorithmConfig object.
|
| 366 |
+
"""
|
| 367 |
+
super().reporting(**kwargs)
|
| 368 |
+
|
| 369 |
+
if report_individual_batch_item_stats is not NotProvided:
|
| 370 |
+
self.report_individual_batch_item_stats = report_individual_batch_item_stats
|
| 371 |
+
if report_dream_data is not NotProvided:
|
| 372 |
+
self.report_dream_data = report_dream_data
|
| 373 |
+
if report_images_and_videos is not NotProvided:
|
| 374 |
+
self.report_images_and_videos = report_images_and_videos
|
| 375 |
+
|
| 376 |
+
return self
|
| 377 |
+
|
| 378 |
+
@override(AlgorithmConfig)
|
| 379 |
+
def validate(self) -> None:
|
| 380 |
+
# Call the super class' validation method first.
|
| 381 |
+
super().validate()
|
| 382 |
+
|
| 383 |
+
# Make sure, users are not using DreamerV3 yet for multi-agent:
|
| 384 |
+
if self.is_multi_agent:
|
| 385 |
+
self._value_error("DreamerV3 does NOT support multi-agent setups yet!")
|
| 386 |
+
|
| 387 |
+
# Make sure, we are configure for the new API stack.
|
| 388 |
+
if not self.enable_rl_module_and_learner:
|
| 389 |
+
self._value_error(
|
| 390 |
+
"DreamerV3 must be run with `config.api_stack("
|
| 391 |
+
"enable_rl_module_and_learner=True)`!"
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
# If run on several Learners, the provided batch_size_B must be a multiple
|
| 395 |
+
# of `num_learners`.
|
| 396 |
+
if self.num_learners > 1 and (self.batch_size_B % self.num_learners != 0):
|
| 397 |
+
self._value_error(
|
| 398 |
+
f"Your `batch_size_B` ({self.batch_size_B}) must be a multiple of "
|
| 399 |
+
f"`num_learners` ({self.num_learners}) in order for "
|
| 400 |
+
"DreamerV3 to be able to split batches evenly across your Learner "
|
| 401 |
+
"processes."
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
# Cannot train actor w/o critic.
|
| 405 |
+
if self.train_actor and not self.train_critic:
|
| 406 |
+
self._value_error(
|
| 407 |
+
"Cannot train actor network (`train_actor=True`) w/o training critic! "
|
| 408 |
+
"Make sure you either set `train_critic=True` or `train_actor=False`."
|
| 409 |
+
)
|
| 410 |
+
# Use DreamerV3 specific batch size settings.
|
| 411 |
+
if self.train_batch_size is not None:
|
| 412 |
+
self._value_error(
|
| 413 |
+
"`train_batch_size` should NOT be set! Use `batch_size_B` and "
|
| 414 |
+
"`batch_length_T` instead."
|
| 415 |
+
)
|
| 416 |
+
# Must be run with `EpisodeReplayBuffer` type.
|
| 417 |
+
if self.replay_buffer_config.get("type") != "EpisodeReplayBuffer":
|
| 418 |
+
self._value_error(
|
| 419 |
+
"DreamerV3 must be run with the `EpisodeReplayBuffer` type! None "
|
| 420 |
+
"other supported."
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
@override(AlgorithmConfig)
|
| 424 |
+
def get_default_learner_class(self):
|
| 425 |
+
if self.framework_str == "tf2":
|
| 426 |
+
from ray.rllib.algorithms.dreamerv3.tf.dreamerv3_tf_learner import (
|
| 427 |
+
DreamerV3TfLearner,
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
return DreamerV3TfLearner
|
| 431 |
+
else:
|
| 432 |
+
raise ValueError(f"The framework {self.framework_str} is not supported.")
|
| 433 |
+
|
| 434 |
+
@override(AlgorithmConfig)
|
| 435 |
+
def get_default_rl_module_spec(self) -> RLModuleSpec:
|
| 436 |
+
if self.framework_str == "tf2":
|
| 437 |
+
from ray.rllib.algorithms.dreamerv3.tf.dreamerv3_tf_rl_module import (
|
| 438 |
+
DreamerV3TfRLModule,
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
return RLModuleSpec(
|
| 442 |
+
module_class=DreamerV3TfRLModule, catalog_class=DreamerV3Catalog
|
| 443 |
+
)
|
| 444 |
+
else:
|
| 445 |
+
raise ValueError(f"The framework {self.framework_str} is not supported.")
|
| 446 |
+
|
| 447 |
+
@property
|
| 448 |
+
def share_module_between_env_runner_and_learner(self) -> bool:
|
| 449 |
+
# If we only have one local Learner (num_learners=0) and only
|
| 450 |
+
# one local EnvRunner (num_env_runners=0), share the RLModule
|
| 451 |
+
# between these two to avoid having to sync weights, ever.
|
| 452 |
+
return self.num_learners == 0 and self.num_env_runners == 0
|
| 453 |
+
|
| 454 |
+
@property
|
| 455 |
+
@override(AlgorithmConfig)
|
| 456 |
+
def _model_config_auto_includes(self) -> Dict[str, Any]:
|
| 457 |
+
return super()._model_config_auto_includes | {
|
| 458 |
+
"gamma": self.gamma,
|
| 459 |
+
"horizon_H": self.horizon_H,
|
| 460 |
+
"model_size": self.model_size,
|
| 461 |
+
"symlog_obs": self.symlog_obs,
|
| 462 |
+
"use_float16": self.use_float16,
|
| 463 |
+
"batch_length_T": self.batch_length_T,
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
class DreamerV3(Algorithm):
|
| 468 |
+
"""Implementation of the model-based DreamerV3 RL algorithm described in [1]."""
|
| 469 |
+
|
| 470 |
+
# TODO (sven): Deprecate/do-over the Algorithm.compute_single_action() API.
|
| 471 |
+
@override(Algorithm)
|
| 472 |
+
def compute_single_action(self, *args, **kwargs):
|
| 473 |
+
raise NotImplementedError(
|
| 474 |
+
"DreamerV3 does not support the `compute_single_action()` API. Refer to the"
|
| 475 |
+
" README here (https://github.com/ray-project/ray/tree/master/rllib/"
|
| 476 |
+
"algorithms/dreamerv3) to find more information on how to run action "
|
| 477 |
+
"inference with this algorithm."
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
@classmethod
|
| 481 |
+
@override(Algorithm)
|
| 482 |
+
def get_default_config(cls) -> AlgorithmConfig:
|
| 483 |
+
return DreamerV3Config()
|
| 484 |
+
|
| 485 |
+
@override(Algorithm)
|
| 486 |
+
def setup(self, config: AlgorithmConfig):
|
| 487 |
+
super().setup(config)
|
| 488 |
+
|
| 489 |
+
# Share RLModule between EnvRunner and single (local) Learner instance.
|
| 490 |
+
# To avoid possibly expensive weight synching step.
|
| 491 |
+
if self.config.share_module_between_env_runner_and_learner:
|
| 492 |
+
assert self.env_runner.module is None
|
| 493 |
+
self.env_runner.module = self.learner_group._learner.module[
|
| 494 |
+
DEFAULT_MODULE_ID
|
| 495 |
+
]
|
| 496 |
+
|
| 497 |
+
# Summarize (single-agent) RLModule (only once) here.
|
| 498 |
+
if self.config.framework_str == "tf2":
|
| 499 |
+
self.env_runner.module.dreamer_model.summary(expand_nested=True)
|
| 500 |
+
|
| 501 |
+
# Create a replay buffer for storing actual env samples.
|
| 502 |
+
self.replay_buffer = EpisodeReplayBuffer(
|
| 503 |
+
capacity=self.config.replay_buffer_config["capacity"],
|
| 504 |
+
batch_size_B=self.config.batch_size_B,
|
| 505 |
+
batch_length_T=self.config.batch_length_T,
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
@override(Algorithm)
|
| 509 |
+
def training_step(self) -> None:
|
| 510 |
+
# Push enough samples into buffer initially before we start training.
|
| 511 |
+
if self.training_iteration == 0:
|
| 512 |
+
logger.info(
|
| 513 |
+
"Filling replay buffer so it contains at least "
|
| 514 |
+
f"{self.config.batch_size_B * self.config.batch_length_T} timesteps "
|
| 515 |
+
"(required for a single train batch)."
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
# Have we sampled yet in this `training_step()` call?
|
| 519 |
+
have_sampled = False
|
| 520 |
+
with self.metrics.log_time((TIMERS, SAMPLE_TIMER)):
|
| 521 |
+
# Continue sampling from the actual environment (and add collected samples
|
| 522 |
+
# to our replay buffer) as long as we:
|
| 523 |
+
while (
|
| 524 |
+
# a) Don't have at least batch_size_B x batch_length_T timesteps stored
|
| 525 |
+
# in the buffer. This is the minimum needed to train.
|
| 526 |
+
self.replay_buffer.get_num_timesteps()
|
| 527 |
+
< (self.config.batch_size_B * self.config.batch_length_T)
|
| 528 |
+
# b) The computed `training_ratio` is >= the configured (desired)
|
| 529 |
+
# training ratio (meaning we should continue sampling).
|
| 530 |
+
or self.training_ratio >= self.config.training_ratio
|
| 531 |
+
# c) we have not sampled at all yet in this `training_step()` call.
|
| 532 |
+
or not have_sampled
|
| 533 |
+
):
|
| 534 |
+
# Sample using the env runner's module.
|
| 535 |
+
episodes, env_runner_results = synchronous_parallel_sample(
|
| 536 |
+
worker_set=self.env_runner_group,
|
| 537 |
+
max_agent_steps=(
|
| 538 |
+
self.config.rollout_fragment_length
|
| 539 |
+
* self.config.num_envs_per_env_runner
|
| 540 |
+
),
|
| 541 |
+
sample_timeout_s=self.config.sample_timeout_s,
|
| 542 |
+
_uses_new_env_runners=True,
|
| 543 |
+
_return_metrics=True,
|
| 544 |
+
)
|
| 545 |
+
self.metrics.merge_and_log_n_dicts(
|
| 546 |
+
env_runner_results, key=ENV_RUNNER_RESULTS
|
| 547 |
+
)
|
| 548 |
+
# Add ongoing and finished episodes into buffer. The buffer will
|
| 549 |
+
# automatically take care of properly concatenating (by episode IDs)
|
| 550 |
+
# the different chunks of the same episodes, even if they come in via
|
| 551 |
+
# separate `add()` calls.
|
| 552 |
+
self.replay_buffer.add(episodes=episodes)
|
| 553 |
+
have_sampled = True
|
| 554 |
+
|
| 555 |
+
# We took B x T env steps.
|
| 556 |
+
env_steps_last_regular_sample = sum(len(eps) for eps in episodes)
|
| 557 |
+
total_sampled = env_steps_last_regular_sample
|
| 558 |
+
|
| 559 |
+
# If we have never sampled before (just started the algo and not
|
| 560 |
+
# recovered from a checkpoint), sample B random actions first.
|
| 561 |
+
if (
|
| 562 |
+
self.metrics.peek(
|
| 563 |
+
(ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME),
|
| 564 |
+
default=0,
|
| 565 |
+
)
|
| 566 |
+
== 0
|
| 567 |
+
):
|
| 568 |
+
_episodes, _env_runner_results = synchronous_parallel_sample(
|
| 569 |
+
worker_set=self.env_runner_group,
|
| 570 |
+
max_agent_steps=(
|
| 571 |
+
self.config.batch_size_B * self.config.batch_length_T
|
| 572 |
+
- env_steps_last_regular_sample
|
| 573 |
+
),
|
| 574 |
+
sample_timeout_s=self.config.sample_timeout_s,
|
| 575 |
+
random_actions=True,
|
| 576 |
+
_uses_new_env_runners=True,
|
| 577 |
+
_return_metrics=True,
|
| 578 |
+
)
|
| 579 |
+
self.metrics.merge_and_log_n_dicts(
|
| 580 |
+
_env_runner_results, key=ENV_RUNNER_RESULTS
|
| 581 |
+
)
|
| 582 |
+
self.replay_buffer.add(episodes=_episodes)
|
| 583 |
+
total_sampled += sum(len(eps) for eps in _episodes)
|
| 584 |
+
|
| 585 |
+
# Summarize environment interaction and buffer data.
|
| 586 |
+
report_sampling_and_replay_buffer(
|
| 587 |
+
metrics=self.metrics, replay_buffer=self.replay_buffer
|
| 588 |
+
)
|
| 589 |
+
|
| 590 |
+
# Continue sampling batch_size_B x batch_length_T sized batches from the buffer
|
| 591 |
+
# and using these to update our models (`LearnerGroup.update_from_batch()`)
|
| 592 |
+
# until the computed `training_ratio` is larger than the configured one, meaning
|
| 593 |
+
# we should go back and collect more samples again from the actual environment.
|
| 594 |
+
# However, when calculating the `training_ratio` here, we use only the
|
| 595 |
+
# trained steps in this very `training_step()` call over the most recent sample
|
| 596 |
+
# amount (`env_steps_last_regular_sample`), not the global values. This is to
|
| 597 |
+
# avoid a heavy overtraining at the very beginning when we have just pre-filled
|
| 598 |
+
# the buffer with the minimum amount of samples.
|
| 599 |
+
replayed_steps_this_iter = sub_iter = 0
|
| 600 |
+
while (
|
| 601 |
+
replayed_steps_this_iter / env_steps_last_regular_sample
|
| 602 |
+
) < self.config.training_ratio:
|
| 603 |
+
# Time individual batch updates.
|
| 604 |
+
with self.metrics.log_time((TIMERS, LEARN_ON_BATCH_TIMER)):
|
| 605 |
+
logger.info(f"\tSub-iteration {self.training_iteration}/{sub_iter})")
|
| 606 |
+
|
| 607 |
+
# Draw a new sample from the replay buffer.
|
| 608 |
+
sample = self.replay_buffer.sample(
|
| 609 |
+
batch_size_B=self.config.batch_size_B,
|
| 610 |
+
batch_length_T=self.config.batch_length_T,
|
| 611 |
+
)
|
| 612 |
+
replayed_steps = self.config.batch_size_B * self.config.batch_length_T
|
| 613 |
+
replayed_steps_this_iter += replayed_steps
|
| 614 |
+
|
| 615 |
+
if isinstance(
|
| 616 |
+
self.env_runner.env.single_action_space, gym.spaces.Discrete
|
| 617 |
+
):
|
| 618 |
+
sample["actions_ints"] = sample[Columns.ACTIONS]
|
| 619 |
+
sample[Columns.ACTIONS] = one_hot(
|
| 620 |
+
sample["actions_ints"],
|
| 621 |
+
depth=self.env_runner.env.single_action_space.n,
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
# Perform the actual update via our learner group.
|
| 625 |
+
learner_results = self.learner_group.update_from_batch(
|
| 626 |
+
batch=SampleBatch(sample).as_multi_agent(),
|
| 627 |
+
# TODO(sven): Maybe we should do this broadcase of global timesteps
|
| 628 |
+
# at the end, like for EnvRunner global env step counts. Maybe when
|
| 629 |
+
# we request the state from the Learners, we can - at the same
|
| 630 |
+
# time - send the current globally summed/reduced-timesteps.
|
| 631 |
+
timesteps={
|
| 632 |
+
NUM_ENV_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
|
| 633 |
+
(ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME),
|
| 634 |
+
default=0,
|
| 635 |
+
)
|
| 636 |
+
},
|
| 637 |
+
)
|
| 638 |
+
self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS)
|
| 639 |
+
|
| 640 |
+
sub_iter += 1
|
| 641 |
+
self.metrics.log_value(NUM_GRAD_UPDATES_LIFETIME, 1, reduce="sum")
|
| 642 |
+
|
| 643 |
+
# Log videos showing how the decoder produces observation predictions
|
| 644 |
+
# from the posterior states.
|
| 645 |
+
# Only every n iterations and only for the first sampled batch row
|
| 646 |
+
# (videos are `config.batch_length_T` frames long).
|
| 647 |
+
report_predicted_vs_sampled_obs(
|
| 648 |
+
# TODO (sven): DreamerV3 is single-agent only.
|
| 649 |
+
metrics=self.metrics,
|
| 650 |
+
sample=sample,
|
| 651 |
+
batch_size_B=self.config.batch_size_B,
|
| 652 |
+
batch_length_T=self.config.batch_length_T,
|
| 653 |
+
symlog_obs=do_symlog_obs(
|
| 654 |
+
self.env_runner.env.single_observation_space,
|
| 655 |
+
self.config.symlog_obs,
|
| 656 |
+
),
|
| 657 |
+
do_report=(
|
| 658 |
+
self.config.report_images_and_videos
|
| 659 |
+
and self.training_iteration % 100 == 0
|
| 660 |
+
),
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
# Log videos showing some of the dreamed trajectories and compare them with the
|
| 664 |
+
# actual trajectories from the train batch.
|
| 665 |
+
# Only every n iterations and only for the first sampled batch row AND first ts.
|
| 666 |
+
# (videos are `config.horizon_H` frames long originating from the observation
|
| 667 |
+
# at B=0 and T=0 in the train batch).
|
| 668 |
+
report_dreamed_eval_trajectory_vs_samples(
|
| 669 |
+
metrics=self.metrics,
|
| 670 |
+
sample=sample,
|
| 671 |
+
burn_in_T=0,
|
| 672 |
+
dreamed_T=self.config.horizon_H + 1,
|
| 673 |
+
dreamer_model=self.env_runner.module.dreamer_model,
|
| 674 |
+
symlog_obs=do_symlog_obs(
|
| 675 |
+
self.env_runner.env.single_observation_space,
|
| 676 |
+
self.config.symlog_obs,
|
| 677 |
+
),
|
| 678 |
+
do_report=(
|
| 679 |
+
self.config.report_dream_data and self.training_iteration % 100 == 0
|
| 680 |
+
),
|
| 681 |
+
framework=self.config.framework_str,
|
| 682 |
+
)
|
| 683 |
+
|
| 684 |
+
# Update weights - after learning on the LearnerGroup - on all EnvRunner
|
| 685 |
+
# workers.
|
| 686 |
+
with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
|
| 687 |
+
# Only necessary if RLModule is not shared between (local) EnvRunner and
|
| 688 |
+
# (local) Learner.
|
| 689 |
+
if not self.config.share_module_between_env_runner_and_learner:
|
| 690 |
+
self.metrics.log_value(NUM_SYNCH_WORKER_WEIGHTS, 1, reduce="sum")
|
| 691 |
+
self.env_runner_group.sync_weights(
|
| 692 |
+
from_worker_or_learner_group=self.learner_group,
|
| 693 |
+
inference_only=True,
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
# Try trick from https://medium.com/dive-into-ml-ai/dealing-with-memory-leak-
|
| 697 |
+
# issue-in-keras-model-training-e703907a6501
|
| 698 |
+
if self.config.gc_frequency_train_steps and (
|
| 699 |
+
self.training_iteration % self.config.gc_frequency_train_steps == 0
|
| 700 |
+
):
|
| 701 |
+
with self.metrics.log_time((TIMERS, GARBAGE_COLLECTION_TIMER)):
|
| 702 |
+
gc.collect()
|
| 703 |
+
|
| 704 |
+
# Add train results and the actual training ratio to stats. The latter should
|
| 705 |
+
# be close to the configured `training_ratio`.
|
| 706 |
+
self.metrics.log_value("actual_training_ratio", self.training_ratio, window=1)
|
| 707 |
+
|
| 708 |
+
@property
|
| 709 |
+
def training_ratio(self) -> float:
|
| 710 |
+
"""Returns the actual training ratio of this Algorithm (not the configured one).
|
| 711 |
+
|
| 712 |
+
The training ratio is copmuted by dividing the total number of steps
|
| 713 |
+
trained thus far (replayed from the buffer) over the total number of actual
|
| 714 |
+
env steps taken thus far.
|
| 715 |
+
"""
|
| 716 |
+
eps = 0.0001
|
| 717 |
+
return self.metrics.peek(NUM_ENV_STEPS_TRAINED_LIFETIME, default=0) / (
|
| 718 |
+
(
|
| 719 |
+
self.metrics.peek(
|
| 720 |
+
(ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME),
|
| 721 |
+
default=eps,
|
| 722 |
+
)
|
| 723 |
+
or eps
|
| 724 |
+
)
|
| 725 |
+
)
|
| 726 |
+
|
| 727 |
+
# TODO (sven): Remove this once DreamerV3 is on the new SingleAgentEnvRunner.
|
| 728 |
+
@PublicAPI
|
| 729 |
+
def __setstate__(self, state) -> None:
|
| 730 |
+
"""Sts the algorithm to the provided state
|
| 731 |
+
|
| 732 |
+
Args:
|
| 733 |
+
state: The state dictionary to restore this `DreamerV3` instance to.
|
| 734 |
+
`state` may have been returned by a call to an `Algorithm`'s
|
| 735 |
+
`__getstate__()` method.
|
| 736 |
+
"""
|
| 737 |
+
# Call the `Algorithm`'s `__setstate__()` method.
|
| 738 |
+
super().__setstate__(state=state)
|
| 739 |
+
|
| 740 |
+
# Assign the module to the local `EnvRunner` if sharing is enabled.
|
| 741 |
+
# Note, in `Learner.restore_from_path()` the module is first deleted
|
| 742 |
+
# and then a new one is built - therefore the worker has no
|
| 743 |
+
# longer a copy of the learner.
|
| 744 |
+
if self.config.share_module_between_env_runner_and_learner:
|
| 745 |
+
assert id(self.env_runner.module) != id(
|
| 746 |
+
self.learner_group._learner.module[DEFAULT_MODULE_ID]
|
| 747 |
+
)
|
| 748 |
+
self.env_runner.module = self.learner_group._learner.module[
|
| 749 |
+
DEFAULT_MODULE_ID
|
| 750 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3_catalog.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
|
| 3 |
+
from ray.rllib.core.models.catalog import Catalog
|
| 4 |
+
from ray.rllib.core.models.base import Encoder, Model
|
| 5 |
+
from ray.rllib.utils import override
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DreamerV3Catalog(Catalog):
|
| 9 |
+
"""The Catalog class used to build all the models needed for DreamerV3 training."""
|
| 10 |
+
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
observation_space: gym.Space,
|
| 14 |
+
action_space: gym.Space,
|
| 15 |
+
model_config_dict: dict,
|
| 16 |
+
):
|
| 17 |
+
"""Initializes a DreamerV3Catalog instance.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
observation_space: The observation space of the environment.
|
| 21 |
+
action_space: The action space of the environment.
|
| 22 |
+
model_config_dict: The model config to use.
|
| 23 |
+
"""
|
| 24 |
+
super().__init__(
|
| 25 |
+
observation_space=observation_space,
|
| 26 |
+
action_space=action_space,
|
| 27 |
+
model_config_dict=model_config_dict,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
self.model_size = self._model_config_dict["model_size"]
|
| 31 |
+
self.is_img_space = len(self.observation_space.shape) in [2, 3]
|
| 32 |
+
self.is_gray_scale = (
|
| 33 |
+
self.is_img_space and len(self.observation_space.shape) == 2
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# TODO (sven): We should work with sub-component configurations here,
|
| 37 |
+
# and even try replacing all current Dreamer model components with
|
| 38 |
+
# our default primitives. But for now, we'll construct the DreamerV3Model
|
| 39 |
+
# directly in our `build_...()` methods.
|
| 40 |
+
|
| 41 |
+
@override(Catalog)
|
| 42 |
+
def build_encoder(self, framework: str) -> Encoder:
|
| 43 |
+
"""Builds the World-Model's encoder network depending on the obs space."""
|
| 44 |
+
if framework != "tf2":
|
| 45 |
+
raise NotImplementedError
|
| 46 |
+
|
| 47 |
+
if self.is_img_space:
|
| 48 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.components.cnn_atari import (
|
| 49 |
+
CNNAtari,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
return CNNAtari(model_size=self.model_size)
|
| 53 |
+
else:
|
| 54 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP
|
| 55 |
+
|
| 56 |
+
return MLP(model_size=self.model_size, name="vector_encoder")
|
| 57 |
+
|
| 58 |
+
def build_decoder(self, framework: str) -> Model:
|
| 59 |
+
"""Builds the World-Model's decoder network depending on the obs space."""
|
| 60 |
+
if framework != "tf2":
|
| 61 |
+
raise NotImplementedError
|
| 62 |
+
|
| 63 |
+
if self.is_img_space:
|
| 64 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.components import (
|
| 65 |
+
conv_transpose_atari,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
return conv_transpose_atari.ConvTransposeAtari(
|
| 69 |
+
model_size=self.model_size,
|
| 70 |
+
gray_scaled=self.is_gray_scale,
|
| 71 |
+
)
|
| 72 |
+
else:
|
| 73 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.components import (
|
| 74 |
+
vector_decoder,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
return vector_decoder.VectorDecoder(
|
| 78 |
+
model_size=self.model_size,
|
| 79 |
+
observation_space=self.observation_space,
|
| 80 |
+
)
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3_learner.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
[1] Mastering Diverse Domains through World Models - 2023
|
| 3 |
+
D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
|
| 4 |
+
https://arxiv.org/pdf/2301.04104v1.pdf
|
| 5 |
+
|
| 6 |
+
[2] Mastering Atari with Discrete World Models - 2021
|
| 7 |
+
D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
|
| 8 |
+
https://arxiv.org/pdf/2010.02193.pdf
|
| 9 |
+
"""
|
| 10 |
+
from ray.rllib.core.learner.learner import Learner
|
| 11 |
+
from ray.rllib.utils.annotations import (
|
| 12 |
+
override,
|
| 13 |
+
OverrideToImplementCustomLogic_CallToSuperRecommended,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DreamerV3Learner(Learner):
|
| 18 |
+
"""DreamerV3 specific Learner class.
|
| 19 |
+
|
| 20 |
+
Only implements the `after_gradient_based_update()` method to define the logic
|
| 21 |
+
for updating the critic EMA-copy after each training step.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
| 25 |
+
@override(Learner)
|
| 26 |
+
def after_gradient_based_update(self, *, timesteps):
|
| 27 |
+
super().after_gradient_based_update(timesteps=timesteps)
|
| 28 |
+
|
| 29 |
+
# Update EMA weights of the critic.
|
| 30 |
+
for module_id, module in self.module._rl_modules.items():
|
| 31 |
+
module.critic.update_ema()
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file holds framework-agnostic components for DreamerV3's RLModule.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import abc
|
| 6 |
+
from typing import Any, Dict
|
| 7 |
+
|
| 8 |
+
import gymnasium as gym
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from ray.rllib.algorithms.dreamerv3.utils import do_symlog_obs
|
| 12 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.actor_network import ActorNetwork
|
| 13 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.critic_network import CriticNetwork
|
| 14 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.dreamer_model import DreamerModel
|
| 15 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.world_model import WorldModel
|
| 16 |
+
from ray.rllib.core.columns import Columns
|
| 17 |
+
from ray.rllib.core.rl_module.rl_module import RLModule
|
| 18 |
+
from ray.rllib.policy.eager_tf_policy import _convert_to_tf
|
| 19 |
+
from ray.rllib.utils.annotations import override
|
| 20 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 21 |
+
from ray.rllib.utils.numpy import one_hot
|
| 22 |
+
from ray.util.annotations import DeveloperAPI
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
_, tf, _ = try_import_tf()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@DeveloperAPI(stability="alpha")
|
| 29 |
+
class DreamerV3RLModule(RLModule, abc.ABC):
|
| 30 |
+
@override(RLModule)
|
| 31 |
+
def setup(self):
|
| 32 |
+
super().setup()
|
| 33 |
+
|
| 34 |
+
# Gather model-relevant settings.
|
| 35 |
+
B = 1
|
| 36 |
+
T = self.model_config["batch_length_T"]
|
| 37 |
+
horizon_H = self.model_config["horizon_H"]
|
| 38 |
+
gamma = self.model_config["gamma"]
|
| 39 |
+
symlog_obs = do_symlog_obs(
|
| 40 |
+
self.observation_space,
|
| 41 |
+
self.model_config.get("symlog_obs", "auto"),
|
| 42 |
+
)
|
| 43 |
+
model_size = self.model_config["model_size"]
|
| 44 |
+
|
| 45 |
+
if self.model_config["use_float16"]:
|
| 46 |
+
tf.compat.v1.keras.layers.enable_v2_dtype_behavior()
|
| 47 |
+
tf.keras.mixed_precision.set_global_policy("mixed_float16")
|
| 48 |
+
|
| 49 |
+
# Build encoder and decoder from catalog.
|
| 50 |
+
self.encoder = self.catalog.build_encoder(framework=self.framework)
|
| 51 |
+
self.decoder = self.catalog.build_decoder(framework=self.framework)
|
| 52 |
+
|
| 53 |
+
# Build the world model (containing encoder and decoder).
|
| 54 |
+
self.world_model = WorldModel(
|
| 55 |
+
model_size=model_size,
|
| 56 |
+
observation_space=self.observation_space,
|
| 57 |
+
action_space=self.action_space,
|
| 58 |
+
batch_length_T=T,
|
| 59 |
+
encoder=self.encoder,
|
| 60 |
+
decoder=self.decoder,
|
| 61 |
+
symlog_obs=symlog_obs,
|
| 62 |
+
)
|
| 63 |
+
self.actor = ActorNetwork(
|
| 64 |
+
action_space=self.action_space,
|
| 65 |
+
model_size=model_size,
|
| 66 |
+
)
|
| 67 |
+
self.critic = CriticNetwork(
|
| 68 |
+
model_size=model_size,
|
| 69 |
+
)
|
| 70 |
+
# Build the final dreamer model (containing the world model).
|
| 71 |
+
self.dreamer_model = DreamerModel(
|
| 72 |
+
model_size=self.model_config["model_size"],
|
| 73 |
+
action_space=self.action_space,
|
| 74 |
+
world_model=self.world_model,
|
| 75 |
+
actor=self.actor,
|
| 76 |
+
critic=self.critic,
|
| 77 |
+
horizon=horizon_H,
|
| 78 |
+
gamma=gamma,
|
| 79 |
+
)
|
| 80 |
+
self.action_dist_cls = self.catalog.get_action_dist_cls(
|
| 81 |
+
framework=self.framework
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
# Perform a test `call()` to force building the dreamer model's variables.
|
| 85 |
+
if self.framework == "tf2":
|
| 86 |
+
test_obs = np.tile(
|
| 87 |
+
np.expand_dims(self.observation_space.sample(), (0, 1)),
|
| 88 |
+
reps=(B, T) + (1,) * len(self.observation_space.shape),
|
| 89 |
+
)
|
| 90 |
+
if isinstance(self.action_space, gym.spaces.Discrete):
|
| 91 |
+
test_actions = np.tile(
|
| 92 |
+
np.expand_dims(
|
| 93 |
+
one_hot(
|
| 94 |
+
self.action_space.sample(),
|
| 95 |
+
depth=self.action_space.n,
|
| 96 |
+
),
|
| 97 |
+
(0, 1),
|
| 98 |
+
),
|
| 99 |
+
reps=(B, T, 1),
|
| 100 |
+
)
|
| 101 |
+
else:
|
| 102 |
+
test_actions = np.tile(
|
| 103 |
+
np.expand_dims(self.action_space.sample(), (0, 1)),
|
| 104 |
+
reps=(B, T, 1),
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
self.dreamer_model(
|
| 108 |
+
inputs=None,
|
| 109 |
+
observations=_convert_to_tf(test_obs, dtype=tf.float32),
|
| 110 |
+
actions=_convert_to_tf(test_actions, dtype=tf.float32),
|
| 111 |
+
is_first=_convert_to_tf(np.ones((B, T)), dtype=tf.bool),
|
| 112 |
+
start_is_terminated_BxT=_convert_to_tf(
|
| 113 |
+
np.zeros((B * T,)), dtype=tf.bool
|
| 114 |
+
),
|
| 115 |
+
gamma=gamma,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
# Initialize the critic EMA net:
|
| 119 |
+
self.critic.init_ema()
|
| 120 |
+
|
| 121 |
+
@override(RLModule)
|
| 122 |
+
def get_initial_state(self) -> Dict:
|
| 123 |
+
# Use `DreamerModel`'s `get_initial_state` method.
|
| 124 |
+
return self.dreamer_model.get_initial_state()
|
| 125 |
+
|
| 126 |
+
@override(RLModule)
|
| 127 |
+
def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]:
|
| 128 |
+
# Call the Dreamer-Model's forward_inference method and return a dict.
|
| 129 |
+
actions, next_state = self.dreamer_model.forward_inference(
|
| 130 |
+
observations=batch[Columns.OBS],
|
| 131 |
+
previous_states=batch[Columns.STATE_IN],
|
| 132 |
+
is_first=batch["is_first"],
|
| 133 |
+
)
|
| 134 |
+
return {Columns.ACTIONS: actions, Columns.STATE_OUT: next_state}
|
| 135 |
+
|
| 136 |
+
@override(RLModule)
|
| 137 |
+
def _forward_exploration(self, batch: Dict[str, Any]) -> Dict[str, Any]:
|
| 138 |
+
# Call the Dreamer-Model's forward_exploration method and return a dict.
|
| 139 |
+
actions, next_state = self.dreamer_model.forward_exploration(
|
| 140 |
+
observations=batch[Columns.OBS],
|
| 141 |
+
previous_states=batch[Columns.STATE_IN],
|
| 142 |
+
is_first=batch["is_first"],
|
| 143 |
+
)
|
| 144 |
+
return {Columns.ACTIONS: actions, Columns.STATE_OUT: next_state}
|
| 145 |
+
|
| 146 |
+
@override(RLModule)
|
| 147 |
+
def _forward_train(self, batch: Dict[str, Any]):
|
| 148 |
+
# Call the Dreamer-Model's forward_train method and return its outputs as-is.
|
| 149 |
+
return self.dreamer_model.forward_train(
|
| 150 |
+
observations=batch[Columns.OBS],
|
| 151 |
+
actions=batch[Columns.ACTIONS],
|
| 152 |
+
is_first=batch["is_first"],
|
| 153 |
+
)
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (206 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/__pycache__/dreamerv3_tf_learner.cpython-311.pyc
ADDED
|
Binary file (32.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/__pycache__/dreamerv3_tf_rl_module.cpython-311.pyc
ADDED
|
Binary file (1.29 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_learner.py
ADDED
|
@@ -0,0 +1,915 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
[1] Mastering Diverse Domains through World Models - 2023
|
| 3 |
+
D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
|
| 4 |
+
https://arxiv.org/pdf/2301.04104v1.pdf
|
| 5 |
+
|
| 6 |
+
[2] Mastering Atari with Discrete World Models - 2021
|
| 7 |
+
D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
|
| 8 |
+
https://arxiv.org/pdf/2010.02193.pdf
|
| 9 |
+
"""
|
| 10 |
+
from typing import Any, Dict, Tuple
|
| 11 |
+
|
| 12 |
+
import gymnasium as gym
|
| 13 |
+
|
| 14 |
+
from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config
|
| 15 |
+
from ray.rllib.algorithms.dreamerv3.dreamerv3_learner import DreamerV3Learner
|
| 16 |
+
from ray.rllib.core import DEFAULT_MODULE_ID
|
| 17 |
+
from ray.rllib.core.columns import Columns
|
| 18 |
+
from ray.rllib.core.learner.learner import ParamDict
|
| 19 |
+
from ray.rllib.core.learner.tf.tf_learner import TfLearner
|
| 20 |
+
from ray.rllib.utils.annotations import override
|
| 21 |
+
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
|
| 22 |
+
from ray.rllib.utils.tf_utils import symlog, two_hot, clip_gradients
|
| 23 |
+
from ray.rllib.utils.typing import ModuleID, TensorType
|
| 24 |
+
|
| 25 |
+
_, tf, _ = try_import_tf()
|
| 26 |
+
tfp = try_import_tfp()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class DreamerV3TfLearner(DreamerV3Learner, TfLearner):
|
| 30 |
+
"""Implements DreamerV3 losses and gradient-based update logic in TensorFlow.
|
| 31 |
+
|
| 32 |
+
The critic EMA-copy update step can be found in the `DreamerV3Learner` base class,
|
| 33 |
+
as it is framework independent.
|
| 34 |
+
|
| 35 |
+
We define 3 local TensorFlow optimizers for the sub components "world_model",
|
| 36 |
+
"actor", and "critic". Each of these optimizers might use a different learning rate,
|
| 37 |
+
epsilon parameter, and gradient clipping thresholds and procedures.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
@override(TfLearner)
|
| 41 |
+
def configure_optimizers_for_module(
|
| 42 |
+
self, module_id: ModuleID, config: DreamerV3Config = None
|
| 43 |
+
):
|
| 44 |
+
"""Create the 3 optimizers for Dreamer learning: world_model, actor, critic.
|
| 45 |
+
|
| 46 |
+
The learning rates used are described in [1] and the epsilon values used here
|
| 47 |
+
- albeit probably not that important - are used by the author's own
|
| 48 |
+
implementation.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
dreamerv3_module = self._module[module_id]
|
| 52 |
+
|
| 53 |
+
# World Model optimizer.
|
| 54 |
+
optim_world_model = tf.keras.optimizers.Adam(epsilon=1e-8)
|
| 55 |
+
optim_world_model.build(dreamerv3_module.world_model.trainable_variables)
|
| 56 |
+
params_world_model = self.get_parameters(dreamerv3_module.world_model)
|
| 57 |
+
self.register_optimizer(
|
| 58 |
+
module_id=module_id,
|
| 59 |
+
optimizer_name="world_model",
|
| 60 |
+
optimizer=optim_world_model,
|
| 61 |
+
params=params_world_model,
|
| 62 |
+
lr_or_lr_schedule=config.world_model_lr,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# Actor optimizer.
|
| 66 |
+
optim_actor = tf.keras.optimizers.Adam(epsilon=1e-5)
|
| 67 |
+
optim_actor.build(dreamerv3_module.actor.trainable_variables)
|
| 68 |
+
params_actor = self.get_parameters(dreamerv3_module.actor)
|
| 69 |
+
self.register_optimizer(
|
| 70 |
+
module_id=module_id,
|
| 71 |
+
optimizer_name="actor",
|
| 72 |
+
optimizer=optim_actor,
|
| 73 |
+
params=params_actor,
|
| 74 |
+
lr_or_lr_schedule=config.actor_lr,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Critic optimizer.
|
| 78 |
+
optim_critic = tf.keras.optimizers.Adam(epsilon=1e-5)
|
| 79 |
+
optim_critic.build(dreamerv3_module.critic.trainable_variables)
|
| 80 |
+
params_critic = self.get_parameters(dreamerv3_module.critic)
|
| 81 |
+
self.register_optimizer(
|
| 82 |
+
module_id=module_id,
|
| 83 |
+
optimizer_name="critic",
|
| 84 |
+
optimizer=optim_critic,
|
| 85 |
+
params=params_critic,
|
| 86 |
+
lr_or_lr_schedule=config.critic_lr,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
@override(TfLearner)
|
| 90 |
+
def postprocess_gradients_for_module(
|
| 91 |
+
self,
|
| 92 |
+
*,
|
| 93 |
+
module_id: ModuleID,
|
| 94 |
+
config: DreamerV3Config,
|
| 95 |
+
module_gradients_dict: Dict[str, Any],
|
| 96 |
+
) -> ParamDict:
|
| 97 |
+
"""Performs gradient clipping on the 3 module components' computed grads.
|
| 98 |
+
|
| 99 |
+
Note that different grad global-norm clip values are used for the 3
|
| 100 |
+
module components: world model, actor, and critic.
|
| 101 |
+
"""
|
| 102 |
+
for optimizer_name, optimizer in self.get_optimizers_for_module(
|
| 103 |
+
module_id=module_id
|
| 104 |
+
):
|
| 105 |
+
grads_sub_dict = self.filter_param_dict_for_optimizer(
|
| 106 |
+
module_gradients_dict, optimizer
|
| 107 |
+
)
|
| 108 |
+
# Figure out, which grad clip setting to use.
|
| 109 |
+
grad_clip = (
|
| 110 |
+
config.world_model_grad_clip_by_global_norm
|
| 111 |
+
if optimizer_name == "world_model"
|
| 112 |
+
else config.actor_grad_clip_by_global_norm
|
| 113 |
+
if optimizer_name == "actor"
|
| 114 |
+
else config.critic_grad_clip_by_global_norm
|
| 115 |
+
)
|
| 116 |
+
global_norm = clip_gradients(
|
| 117 |
+
grads_sub_dict,
|
| 118 |
+
grad_clip=grad_clip,
|
| 119 |
+
grad_clip_by="global_norm",
|
| 120 |
+
)
|
| 121 |
+
module_gradients_dict.update(grads_sub_dict)
|
| 122 |
+
|
| 123 |
+
# DreamerV3 stats have the format: [WORLD_MODEL|ACTOR|CRITIC]_[stats name].
|
| 124 |
+
self.metrics.log_dict(
|
| 125 |
+
{
|
| 126 |
+
optimizer_name.upper() + "_gradients_global_norm": global_norm,
|
| 127 |
+
optimizer_name.upper()
|
| 128 |
+
+ "_gradients_maxabs_after_clipping": (
|
| 129 |
+
tf.reduce_max(
|
| 130 |
+
[
|
| 131 |
+
tf.reduce_max(tf.math.abs(g))
|
| 132 |
+
for g in grads_sub_dict.values()
|
| 133 |
+
]
|
| 134 |
+
)
|
| 135 |
+
),
|
| 136 |
+
},
|
| 137 |
+
key=module_id,
|
| 138 |
+
window=1, # <- single items (should not be mean/ema-reduced over time).
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
return module_gradients_dict
|
| 142 |
+
|
| 143 |
+
@override(TfLearner)
|
| 144 |
+
def compute_gradients(
|
| 145 |
+
self,
|
| 146 |
+
loss_per_module,
|
| 147 |
+
gradient_tape,
|
| 148 |
+
**kwargs,
|
| 149 |
+
):
|
| 150 |
+
# Override of the default gradient computation method.
|
| 151 |
+
# For DreamerV3, we need to compute gradients over the individual loss terms
|
| 152 |
+
# as otherwise, the world model's parameters would have their gradients also
|
| 153 |
+
# be influenced by the actor- and critic loss terms/gradient computations.
|
| 154 |
+
grads = {}
|
| 155 |
+
for component in ["world_model", "actor", "critic"]:
|
| 156 |
+
grads.update(
|
| 157 |
+
gradient_tape.gradient(
|
| 158 |
+
# Take individual loss term from the registered metrics for
|
| 159 |
+
# the main module.
|
| 160 |
+
self.metrics.peek(
|
| 161 |
+
(DEFAULT_MODULE_ID, component.upper() + "_L_total")
|
| 162 |
+
),
|
| 163 |
+
self.filter_param_dict_for_optimizer(
|
| 164 |
+
self._params, self.get_optimizer(optimizer_name=component)
|
| 165 |
+
),
|
| 166 |
+
)
|
| 167 |
+
)
|
| 168 |
+
del gradient_tape
|
| 169 |
+
return grads
|
| 170 |
+
|
| 171 |
+
@override(TfLearner)
|
| 172 |
+
def compute_loss_for_module(
|
| 173 |
+
self,
|
| 174 |
+
module_id: ModuleID,
|
| 175 |
+
config: DreamerV3Config,
|
| 176 |
+
batch: Dict[str, TensorType],
|
| 177 |
+
fwd_out: Dict[str, TensorType],
|
| 178 |
+
) -> TensorType:
|
| 179 |
+
# World model losses.
|
| 180 |
+
prediction_losses = self._compute_world_model_prediction_losses(
|
| 181 |
+
config=config,
|
| 182 |
+
rewards_B_T=batch[Columns.REWARDS],
|
| 183 |
+
continues_B_T=(1.0 - tf.cast(batch["is_terminated"], tf.float32)),
|
| 184 |
+
fwd_out=fwd_out,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
(
|
| 188 |
+
L_dyn_B_T,
|
| 189 |
+
L_rep_B_T,
|
| 190 |
+
) = self._compute_world_model_dynamics_and_representation_loss(
|
| 191 |
+
config=config, fwd_out=fwd_out
|
| 192 |
+
)
|
| 193 |
+
L_dyn = tf.reduce_mean(L_dyn_B_T)
|
| 194 |
+
L_rep = tf.reduce_mean(L_rep_B_T)
|
| 195 |
+
# Make sure values for L_rep and L_dyn are the same (they only differ in their
|
| 196 |
+
# gradients).
|
| 197 |
+
tf.assert_equal(L_dyn, L_rep)
|
| 198 |
+
|
| 199 |
+
# Compute the actual total loss using fixed weights described in [1] eq. 4.
|
| 200 |
+
L_world_model_total_B_T = (
|
| 201 |
+
1.0 * prediction_losses["L_prediction_B_T"]
|
| 202 |
+
+ 0.5 * L_dyn_B_T
|
| 203 |
+
+ 0.1 * L_rep_B_T
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# In the paper, it says to sum up timesteps, and average over
|
| 207 |
+
# batch (see eq. 4 in [1]). But Danijar's implementation only does
|
| 208 |
+
# averaging (over B and T), so we'll do this here as well. This is generally
|
| 209 |
+
# true for all other loss terms as well (we'll always just average, no summing
|
| 210 |
+
# over T axis!).
|
| 211 |
+
L_world_model_total = tf.reduce_mean(L_world_model_total_B_T)
|
| 212 |
+
|
| 213 |
+
# Log world model loss stats.
|
| 214 |
+
self.metrics.log_dict(
|
| 215 |
+
{
|
| 216 |
+
"WORLD_MODEL_learned_initial_h": (
|
| 217 |
+
self.module[module_id].world_model.initial_h
|
| 218 |
+
),
|
| 219 |
+
# Prediction losses.
|
| 220 |
+
# Decoder (obs) loss.
|
| 221 |
+
"WORLD_MODEL_L_decoder": prediction_losses["L_decoder"],
|
| 222 |
+
# Reward loss.
|
| 223 |
+
"WORLD_MODEL_L_reward": prediction_losses["L_reward"],
|
| 224 |
+
# Continue loss.
|
| 225 |
+
"WORLD_MODEL_L_continue": prediction_losses["L_continue"],
|
| 226 |
+
# Total.
|
| 227 |
+
"WORLD_MODEL_L_prediction": prediction_losses["L_prediction"],
|
| 228 |
+
# Dynamics loss.
|
| 229 |
+
"WORLD_MODEL_L_dynamics": L_dyn,
|
| 230 |
+
# Representation loss.
|
| 231 |
+
"WORLD_MODEL_L_representation": L_rep,
|
| 232 |
+
# Total loss.
|
| 233 |
+
"WORLD_MODEL_L_total": L_world_model_total,
|
| 234 |
+
},
|
| 235 |
+
key=module_id,
|
| 236 |
+
window=1, # <- single items (should not be mean/ema-reduced over time).
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
# Add the predicted obs distributions for possible (video) summarization.
|
| 240 |
+
if config.report_images_and_videos:
|
| 241 |
+
self.metrics.log_value(
|
| 242 |
+
(module_id, "WORLD_MODEL_fwd_out_obs_distribution_means_b0xT"),
|
| 243 |
+
fwd_out["obs_distribution_means_BxT"][: self.config.batch_length_T],
|
| 244 |
+
reduce=None, # No reduction, we want the tensor to stay in-tact.
|
| 245 |
+
window=1, # <- single items (should not be mean/ema-reduced over time).
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
if config.report_individual_batch_item_stats:
|
| 249 |
+
# Log important world-model loss stats.
|
| 250 |
+
self.metrics.log_dict(
|
| 251 |
+
{
|
| 252 |
+
"WORLD_MODEL_L_decoder_B_T": prediction_losses["L_decoder_B_T"],
|
| 253 |
+
"WORLD_MODEL_L_reward_B_T": prediction_losses["L_reward_B_T"],
|
| 254 |
+
"WORLD_MODEL_L_continue_B_T": prediction_losses["L_continue_B_T"],
|
| 255 |
+
"WORLD_MODEL_L_prediction_B_T": (
|
| 256 |
+
prediction_losses["L_prediction_B_T"]
|
| 257 |
+
),
|
| 258 |
+
"WORLD_MODEL_L_dynamics_B_T": L_dyn_B_T,
|
| 259 |
+
"WORLD_MODEL_L_representation_B_T": L_rep_B_T,
|
| 260 |
+
"WORLD_MODEL_L_total_B_T": L_world_model_total_B_T,
|
| 261 |
+
},
|
| 262 |
+
key=module_id,
|
| 263 |
+
window=1, # <- single items (should not be mean/ema-reduced over time).
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# Dream trajectories starting in all internal states (h + z_posterior) that were
|
| 267 |
+
# computed during world model training.
|
| 268 |
+
# Everything goes in as BxT: We are starting a new dream trajectory at every
|
| 269 |
+
# actually encountered timestep in the batch, so we are creating B*T
|
| 270 |
+
# trajectories of len `horizon_H`.
|
| 271 |
+
dream_data = self.module[module_id].dreamer_model.dream_trajectory(
|
| 272 |
+
start_states={
|
| 273 |
+
"h": fwd_out["h_states_BxT"],
|
| 274 |
+
"z": fwd_out["z_posterior_states_BxT"],
|
| 275 |
+
},
|
| 276 |
+
start_is_terminated=tf.reshape(batch["is_terminated"], [-1]), # -> BxT
|
| 277 |
+
)
|
| 278 |
+
if config.report_dream_data:
|
| 279 |
+
# To reduce this massive amount of data a little, slice out a T=1 piece
|
| 280 |
+
# from each stats that has the shape (H, BxT), meaning convert e.g.
|
| 281 |
+
# `rewards_dreamed_t0_to_H_BxT` into `rewards_dreamed_t0_to_H_Bx1`.
|
| 282 |
+
# This will reduce the amount of data to be transferred and reported
|
| 283 |
+
# by the factor of `batch_length_T`.
|
| 284 |
+
self.metrics.log_dict(
|
| 285 |
+
{
|
| 286 |
+
# Replace 'T' with '1'.
|
| 287 |
+
key[:-1] + "1": value[:, :: config.batch_length_T]
|
| 288 |
+
for key, value in dream_data.items()
|
| 289 |
+
if key.endswith("H_BxT")
|
| 290 |
+
},
|
| 291 |
+
key=(module_id, "dream_data"),
|
| 292 |
+
reduce=None,
|
| 293 |
+
window=1, # <- single items (should not be mean/ema-reduced over time).
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
value_targets_t0_to_Hm1_BxT = self._compute_value_targets(
|
| 297 |
+
config=config,
|
| 298 |
+
# Learn critic in symlog'd space.
|
| 299 |
+
rewards_t0_to_H_BxT=dream_data["rewards_dreamed_t0_to_H_BxT"],
|
| 300 |
+
intrinsic_rewards_t1_to_H_BxT=(
|
| 301 |
+
dream_data["rewards_intrinsic_t1_to_H_B"]
|
| 302 |
+
if config.use_curiosity
|
| 303 |
+
else None
|
| 304 |
+
),
|
| 305 |
+
continues_t0_to_H_BxT=dream_data["continues_dreamed_t0_to_H_BxT"],
|
| 306 |
+
value_predictions_t0_to_H_BxT=dream_data["values_dreamed_t0_to_H_BxT"],
|
| 307 |
+
)
|
| 308 |
+
self.metrics.log_value(
|
| 309 |
+
key=(module_id, "VALUE_TARGETS_H_BxT"),
|
| 310 |
+
value=value_targets_t0_to_Hm1_BxT,
|
| 311 |
+
window=1, # <- single items (should not be mean/ema-reduced over time).
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
CRITIC_L_total = self._compute_critic_loss(
|
| 315 |
+
module_id=module_id,
|
| 316 |
+
config=config,
|
| 317 |
+
dream_data=dream_data,
|
| 318 |
+
value_targets_t0_to_Hm1_BxT=value_targets_t0_to_Hm1_BxT,
|
| 319 |
+
)
|
| 320 |
+
if config.train_actor:
|
| 321 |
+
ACTOR_L_total = self._compute_actor_loss(
|
| 322 |
+
module_id=module_id,
|
| 323 |
+
config=config,
|
| 324 |
+
dream_data=dream_data,
|
| 325 |
+
value_targets_t0_to_Hm1_BxT=value_targets_t0_to_Hm1_BxT,
|
| 326 |
+
)
|
| 327 |
+
else:
|
| 328 |
+
ACTOR_L_total = 0.0
|
| 329 |
+
|
| 330 |
+
# Return the total loss as a sum of all individual losses.
|
| 331 |
+
return L_world_model_total + CRITIC_L_total + ACTOR_L_total
|
| 332 |
+
|
| 333 |
+
def _compute_world_model_prediction_losses(
    self,
    *,
    config: DreamerV3Config,
    rewards_B_T: TensorType,
    continues_B_T: TensorType,
    fwd_out: Dict[str, TensorType],
) -> Dict[str, TensorType]:
    """Helper method computing all world-model related prediction losses.

    Prediction losses are used to train the predictors of the world model, which
    are: Reward predictor, continue predictor, and the decoder (which predicts
    observations).

    Args:
        config: The DreamerV3Config to use.
        rewards_B_T: The rewards batch in the shape (B, T) and of type float32.
        continues_B_T: The continues batch in the shape (B, T) and of type float32
            (1.0 -> continue; 0.0 -> end of episode).
        fwd_out: The `forward_train` outputs of the DreamerV3RLModule.

    Returns:
        A dict with the individual loss terms, each both as a per-batch-item
        tensor (keys ending in `_B_T`, shape (B, T)) and as a scalar mean
        (`L_decoder`, `L_reward`, `L_continue`, `L_prediction`).
    """

    # Learn to produce symlog'd observation predictions.
    # If symlog is disabled (e.g. for uint8 image inputs), `obs_symlog_BxT` is the
    # same as `obs_BxT`.
    obs_BxT = fwd_out["sampled_obs_symlog_BxT"]
    obs_distr_means = fwd_out["obs_distribution_means_BxT"]
    # In case we wanted to construct a distribution object from the fwd out data,
    # we would have to do it like this:
    # obs_distr = tfp.distributions.MultivariateNormalDiag(
    #     loc=obs_distr_means,
    #     # Scale == 1.0.
    #     # [2]: "Distributions The image predictor outputs the mean of a diagonal
    #     # Gaussian likelihood with **unit variance** ..."
    #     scale_diag=tf.ones_like(obs_distr_means),
    # )

    # Leave time dim folded (BxT) and flatten all other (e.g. image) dims.
    obs_BxT = tf.reshape(obs_BxT, shape=[-1, tf.reduce_prod(obs_BxT.shape[1:])])

    # Squared diff loss w/ sum(!) over all (already folded) obs dims.
    # decoder_loss_BxT = SUM[ (obs_distr.loc - observations)^2 ]
    # Note: This is described strangely in the paper (stating a neglogp loss here),
    # but the author's own implementation actually uses simple MSE with the loc
    # of the Gaussian.
    decoder_loss_BxT = tf.reduce_sum(
        tf.math.square(obs_distr_means - obs_BxT), axis=-1
    )

    # Unfold time rank back in.
    decoder_loss_B_T = tf.reshape(
        decoder_loss_BxT, (config.batch_size_B_per_learner, config.batch_length_T)
    )
    L_decoder = tf.reduce_mean(decoder_loss_B_T)

    # The FiniteDiscrete reward bucket distribution computed by our reward
    # predictor.
    # [B x num_buckets].
    reward_logits_BxT = fwd_out["reward_logits_BxT"]
    # Learn to produce symlog'd reward predictions.
    rewards_symlog_B_T = symlog(tf.cast(rewards_B_T, tf.float32))
    # Fold time dim.
    rewards_symlog_BxT = tf.reshape(rewards_symlog_B_T, shape=[-1])

    # Two-hot encode.
    two_hot_rewards_symlog_BxT = two_hot(rewards_symlog_BxT)
    # two_hot_rewards_symlog_BxT=[B*T, num_buckets]
    # Normalize logits into log-probabilities (stable log-softmax via logsumexp).
    reward_log_pred_BxT = reward_logits_BxT - tf.math.reduce_logsumexp(
        reward_logits_BxT, axis=-1, keepdims=True
    )
    # Multiply with two-hot targets and neg (-> cross-entropy vs the two-hot
    # target distribution).
    reward_loss_two_hot_BxT = -tf.reduce_sum(
        reward_log_pred_BxT * two_hot_rewards_symlog_BxT, axis=-1
    )
    # Unfold time rank back in.
    reward_loss_two_hot_B_T = tf.reshape(
        reward_loss_two_hot_BxT,
        (config.batch_size_B_per_learner, config.batch_length_T),
    )
    L_reward_two_hot = tf.reduce_mean(reward_loss_two_hot_B_T)

    # Probabilities that episode continues, computed by our continue predictor.
    # [B]
    continue_distr = fwd_out["continue_distribution_BxT"]
    # -log(p) loss
    # Fold time dim.
    continues_BxT = tf.reshape(continues_B_T, shape=[-1])
    continue_loss_BxT = -continue_distr.log_prob(continues_BxT)
    # Unfold time rank back in.
    continue_loss_B_T = tf.reshape(
        continue_loss_BxT, (config.batch_size_B_per_learner, config.batch_length_T)
    )
    L_continue = tf.reduce_mean(continue_loss_B_T)

    # Sum all losses together as the "prediction" loss.
    L_pred_B_T = decoder_loss_B_T + reward_loss_two_hot_B_T + continue_loss_B_T
    L_pred = tf.reduce_mean(L_pred_B_T)

    return {
        "L_decoder_B_T": decoder_loss_B_T,
        "L_decoder": L_decoder,
        "L_reward": L_reward_two_hot,
        "L_reward_B_T": reward_loss_two_hot_B_T,
        "L_continue": L_continue,
        "L_continue_B_T": continue_loss_B_T,
        "L_prediction": L_pred,
        "L_prediction_B_T": L_pred_B_T,
    }
|
| 441 |
+
|
| 442 |
+
def _compute_world_model_dynamics_and_representation_loss(
    self, *, config: DreamerV3Config, fwd_out: Dict[str, Any]
) -> Tuple[TensorType, TensorType]:
    """Helper method computing the world-model's dynamics and representation losses.

    Args:
        config: The DreamerV3Config to use.
        fwd_out: The `forward_train` outputs of the DreamerV3RLModule.

    Returns:
        Tuple consisting of a) dynamics loss: Trains the prior network, predicting
        z^ prior states from h-states and b) representation loss: Trains posterior
        network, predicting z posterior states from h-states and (encoded)
        observations. Both are per-batch-item tensors of shape (B, T).
    """

    # Actual distribution over stochastic internal states (z) produced by the
    # encoder.
    z_posterior_probs_BxT = fwd_out["z_posterior_probs_BxT"]
    # `Independent` sums the log-probs/KLs over the categoricals dimension
    # (one categorical per z-slot).
    z_posterior_distr_BxT = tfp.distributions.Independent(
        tfp.distributions.OneHotCategorical(probs=z_posterior_probs_BxT),
        reinterpreted_batch_ndims=1,
    )

    # Actual distribution over stochastic internal states (z) produced by the
    # dynamics network.
    z_prior_probs_BxT = fwd_out["z_prior_probs_BxT"]
    z_prior_distr_BxT = tfp.distributions.Independent(
        tfp.distributions.OneHotCategorical(probs=z_prior_probs_BxT),
        reinterpreted_batch_ndims=1,
    )

    # Stop gradient for encoder's z-outputs:
    sg_z_posterior_distr_BxT = tfp.distributions.Independent(
        tfp.distributions.OneHotCategorical(
            probs=tf.stop_gradient(z_posterior_probs_BxT)
        ),
        reinterpreted_batch_ndims=1,
    )
    # Stop gradient for dynamics model's z-outputs:
    sg_z_prior_distr_BxT = tfp.distributions.Independent(
        tfp.distributions.OneHotCategorical(
            probs=tf.stop_gradient(z_prior_probs_BxT)
        ),
        reinterpreted_batch_ndims=1,
    )

    # Implement free bits. According to [1]:
    # "To avoid a degenerate solution where the dynamics are trivial to predict but
    # contain not enough information about the inputs, we employ free bits by
    # clipping the dynamics and representation losses below the value of
    # 1 nat ≈ 1.44 bits. This disables them while they are already minimized well to
    # focus the world model on its prediction loss"
    # Dynamics loss: KL(sg(posterior) || prior) -> only the prior (dynamics net)
    # receives gradients.
    L_dyn_BxT = tf.math.maximum(
        1.0,
        tfp.distributions.kl_divergence(
            sg_z_posterior_distr_BxT, z_prior_distr_BxT
        ),
    )
    # Unfold time rank back in.
    L_dyn_B_T = tf.reshape(
        L_dyn_BxT, (config.batch_size_B_per_learner, config.batch_length_T)
    )

    # Representation loss: KL(posterior || sg(prior)) -> only the posterior
    # (encoder) receives gradients.
    L_rep_BxT = tf.math.maximum(
        1.0,
        tfp.distributions.kl_divergence(
            z_posterior_distr_BxT, sg_z_prior_distr_BxT
        ),
    )
    # Unfold time rank back in.
    L_rep_B_T = tf.reshape(
        L_rep_BxT, (config.batch_size_B_per_learner, config.batch_length_T)
    )

    return L_dyn_B_T, L_rep_B_T
|
| 518 |
+
|
| 519 |
+
def _compute_actor_loss(
    self,
    *,
    module_id: ModuleID,
    config: DreamerV3Config,
    dream_data: Dict[str, TensorType],
    value_targets_t0_to_Hm1_BxT: TensorType,
) -> TensorType:
    """Helper method computing the actor's loss terms.

    Args:
        module_id: The module_id for which to compute the actor loss.
        config: The DreamerV3Config to use.
        dream_data: The data generated by dreaming for H steps (horizon) starting
            from any BxT state (sampled from the buffer for the train batch).
        value_targets_t0_to_Hm1_BxT: The computed value function targets of the
            shape (t0 to H-1, BxT).

    Returns:
        The total actor loss tensor (scalar mean over horizon and batch).
    """
    actor = self.module[module_id].actor

    # Note: `scaled_value_targets_t0_to_Hm1_B` are NOT stop_gradient'd yet.
    scaled_value_targets_t0_to_Hm1_B = self._compute_scaled_value_targets(
        module_id=module_id,
        config=config,
        value_targets_t0_to_Hm1_BxT=value_targets_t0_to_Hm1_BxT,
        value_predictions_t0_to_Hm1_BxT=dream_data["values_dreamed_t0_to_H_BxT"][
            :-1
        ],
    )

    # Actions actually taken in the dream.
    # Cut the last dreamed step ([:-1]) to align with the H-1 value targets.
    actions_dreamed = tf.stop_gradient(dream_data["actions_dreamed_t0_to_H_BxT"])[
        :-1
    ]
    actions_dreamed_dist_params_t0_to_Hm1_B = dream_data[
        "actions_dreamed_dist_params_t0_to_H_BxT"
    ][:-1]

    dist_t0_to_Hm1_B = actor.get_action_dist_object(
        actions_dreamed_dist_params_t0_to_Hm1_B
    )

    # Compute log(p)s of all possible actions in the dream.
    if isinstance(self.module[module_id].actor.action_space, gym.spaces.Discrete):
        # Note that when we create the Categorical action distributions, we compute
        # unimix probs, then math.log these and provide these log(p) as "logits" to
        # the Categorical. So here, we'll continue to work with log(p)s (not
        # really "logits")!
        logp_actions_t0_to_Hm1_B = actions_dreamed_dist_params_t0_to_Hm1_B

        # Log probs of actions actually taken in the dream.
        # (`actions_dreamed` are one-hot, so this selects the taken action's
        # log(p) per timestep.)
        logp_actions_dreamed_t0_to_Hm1_B = tf.reduce_sum(
            actions_dreamed * logp_actions_t0_to_Hm1_B,
            axis=-1,
        )
        # First term of loss function. [1] eq. 11.
        logp_loss_H_B = logp_actions_dreamed_t0_to_Hm1_B * tf.stop_gradient(
            scaled_value_targets_t0_to_Hm1_B
        )
    # Box space.
    else:
        logp_actions_dreamed_t0_to_Hm1_B = dist_t0_to_Hm1_B.log_prob(
            actions_dreamed
        )
        # First term of loss function. [1] eq. 11.
        # (No logp multiplication here; the scaled targets carry gradients
        # directly -> dynamics backprop for continuous actions.)
        logp_loss_H_B = scaled_value_targets_t0_to_Hm1_B

    assert len(logp_loss_H_B.shape) == 2

    # Add entropy loss term (second term [1] eq. 11).
    entropy_H_B = dist_t0_to_Hm1_B.entropy()
    assert len(entropy_H_B.shape) == 2
    entropy = tf.reduce_mean(entropy_H_B)

    L_actor_reinforce_term_H_B = -logp_loss_H_B
    L_actor_action_entropy_term_H_B = -config.entropy_scale * entropy_H_B

    L_actor_H_B = L_actor_reinforce_term_H_B + L_actor_action_entropy_term_H_B
    # Mask out everything that goes beyond a predicted continue=False boundary.
    L_actor_H_B *= tf.stop_gradient(dream_data["dream_loss_weights_t0_to_H_BxT"])[
        :-1
    ]
    L_actor = tf.reduce_mean(L_actor_H_B)

    # Log important actor loss stats.
    self.metrics.log_dict(
        {
            "ACTOR_L_total": L_actor,
            "ACTOR_value_targets_pct95_ema": actor.ema_value_target_pct95,
            "ACTOR_value_targets_pct5_ema": actor.ema_value_target_pct5,
            "ACTOR_action_entropy": entropy,
            # Individual loss terms.
            "ACTOR_L_neglogp_reinforce_term": tf.reduce_mean(
                L_actor_reinforce_term_H_B
            ),
            "ACTOR_L_neg_entropy_term": tf.reduce_mean(
                L_actor_action_entropy_term_H_B
            ),
        },
        key=module_id,
        window=1,  # <- single items (should not be mean/ema-reduced over time).
    )
    if config.report_individual_batch_item_stats:
        self.metrics.log_dict(
            {
                "ACTOR_L_total_H_BxT": L_actor_H_B,
                "ACTOR_logp_actions_dreamed_H_BxT": (
                    logp_actions_dreamed_t0_to_Hm1_B
                ),
                "ACTOR_scaled_value_targets_H_BxT": (
                    scaled_value_targets_t0_to_Hm1_B
                ),
                "ACTOR_action_entropy_H_BxT": entropy_H_B,
                # Individual loss terms.
                "ACTOR_L_neglogp_reinforce_term_H_BxT": L_actor_reinforce_term_H_B,
                "ACTOR_L_neg_entropy_term_H_BxT": L_actor_action_entropy_term_H_B,
            },
            key=module_id,
            window=1,  # <- single items (should not be mean/ema-reduced over time).
        )

    return L_actor
|
| 644 |
+
|
| 645 |
+
def _compute_critic_loss(
    self,
    *,
    module_id: ModuleID,
    config: DreamerV3Config,
    dream_data: Dict[str, TensorType],
    value_targets_t0_to_Hm1_BxT: TensorType,
) -> TensorType:
    """Helper method computing the critic's loss terms.

    Args:
        module_id: The ModuleID for which to compute the critic loss.
        config: The DreamerV3Config to use.
        dream_data: The data generated by dreaming for H steps (horizon) starting
            from any BxT state (sampled from the buffer for the train batch).
        value_targets_t0_to_Hm1_BxT: The computed value function targets of the
            shape (t0 to H-1, BxT).

    Returns:
        The total critic loss tensor (scalar mean over horizon and batch).
    """
    # B=BxT
    H, B = dream_data["rewards_dreamed_t0_to_H_BxT"].shape[:2]
    Hm1 = H - 1

    # Note that value targets are NOT symlog'd and go from t0 to H-1, not H, like
    # all the other dream data.

    # From here on: B=BxT
    # Targets are constants for the critic loss -> stop gradients.
    value_targets_t0_to_Hm1_B = tf.stop_gradient(value_targets_t0_to_Hm1_BxT)
    # Critic predicts in symlog space -> symlog the (real-space) targets.
    value_symlog_targets_t0_to_Hm1_B = symlog(value_targets_t0_to_Hm1_B)
    # Fold time rank (for two_hot'ing).
    value_symlog_targets_HxB = tf.reshape(value_symlog_targets_t0_to_Hm1_B, (-1,))
    value_symlog_targets_two_hot_HxB = two_hot(value_symlog_targets_HxB)
    # Unfold time rank.
    value_symlog_targets_two_hot_t0_to_Hm1_B = tf.reshape(
        value_symlog_targets_two_hot_HxB,
        shape=[Hm1, B, value_symlog_targets_two_hot_HxB.shape[-1]],
    )

    # Get (B x T x probs) tensor from return distributions.
    value_symlog_logits_HxB = dream_data["values_symlog_dreamed_logits_t0_to_HxBxT"]
    # Unfold time rank and cut last time index to match value targets.
    value_symlog_logits_t0_to_Hm1_B = tf.reshape(
        value_symlog_logits_HxB,
        shape=[H, B, value_symlog_logits_HxB.shape[-1]],
    )[:-1]

    # Stable log-softmax over the bucket dimension.
    values_log_pred_Hm1_B = (
        value_symlog_logits_t0_to_Hm1_B
        - tf.math.reduce_logsumexp(
            value_symlog_logits_t0_to_Hm1_B, axis=-1, keepdims=True
        )
    )
    # Multiply with two-hot targets and neg (-> cross-entropy vs two-hot targets).
    value_loss_two_hot_H_B = -tf.reduce_sum(
        values_log_pred_Hm1_B * value_symlog_targets_two_hot_t0_to_Hm1_B, axis=-1
    )

    # Compute EMA regularization loss.
    # Expected values (dreamed) from the EMA (slow critic) net.
    # Note: Slow critic (EMA) outputs are already stop_gradient'd.
    value_symlog_ema_t0_to_Hm1_B = tf.stop_gradient(
        dream_data["v_symlog_dreamed_ema_t0_to_H_BxT"]
    )[:-1]
    # Fold time rank (for two_hot'ing).
    value_symlog_ema_HxB = tf.reshape(value_symlog_ema_t0_to_Hm1_B, (-1,))
    value_symlog_ema_two_hot_HxB = two_hot(value_symlog_ema_HxB)
    # Unfold time rank.
    value_symlog_ema_two_hot_t0_to_Hm1_B = tf.reshape(
        value_symlog_ema_two_hot_HxB,
        shape=[Hm1, B, value_symlog_ema_two_hot_HxB.shape[-1]],
    )

    # Compute ema regularizer loss.
    # In the paper, it is not described how exactly to form this regularizer term
    # and how to weigh it.
    # So we follow Danijar's repo here:
    # `reg = -dist.log_prob(sg(self.slow(traj).mean()))`
    # with a weight of 1.0, where dist is the bucket'ized distribution output by the
    # fast critic. sg=stop gradient; mean() -> use the expected EMA values.
    # Multiply with two-hot targets and neg.
    ema_regularization_loss_H_B = -tf.reduce_sum(
        values_log_pred_Hm1_B * value_symlog_ema_two_hot_t0_to_Hm1_B, axis=-1
    )

    L_critic_H_B = value_loss_two_hot_H_B + ema_regularization_loss_H_B

    # Mask out everything that goes beyond a predicted continue=False boundary.
    L_critic_H_B *= tf.stop_gradient(dream_data["dream_loss_weights_t0_to_H_BxT"])[
        :-1
    ]

    # Reduce over both H- (time) axis and B-axis (mean).
    L_critic = tf.reduce_mean(L_critic_H_B)

    # Log important critic loss stats.
    self.metrics.log_dict(
        {
            "CRITIC_L_total": L_critic,
            "CRITIC_L_neg_logp_of_value_targets": tf.reduce_mean(
                value_loss_two_hot_H_B
            ),
            "CRITIC_L_slow_critic_regularization": tf.reduce_mean(
                ema_regularization_loss_H_B
            ),
        },
        key=module_id,
        window=1,  # <- single items (should not be mean/ema-reduced over time).
    )
    if config.report_individual_batch_item_stats:
        # Log important critic loss stats.
        self.metrics.log_dict(
            {
                # Symlog'd value targets. Critic learns to predict symlog'd values.
                "VALUE_TARGETS_symlog_H_BxT": value_symlog_targets_t0_to_Hm1_B,
                # Critic loss terms.
                "CRITIC_L_total_H_BxT": L_critic_H_B,
                "CRITIC_L_neg_logp_of_value_targets_H_BxT": value_loss_two_hot_H_B,
                "CRITIC_L_slow_critic_regularization_H_BxT": (
                    ema_regularization_loss_H_B
                ),
            },
            key=module_id,
            window=1,  # <- single items (should not be mean/ema-reduced over time).
        )

    return L_critic
|
| 773 |
+
|
| 774 |
+
def _compute_value_targets(
    self,
    *,
    config: DreamerV3Config,
    rewards_t0_to_H_BxT: TensorType,
    intrinsic_rewards_t1_to_H_BxT: TensorType,
    continues_t0_to_H_BxT: TensorType,
    value_predictions_t0_to_H_BxT: TensorType,
) -> TensorType:
    """Helper method computing the value targets.

    All args are (H, BxT, ...) and in non-symlog'd (real) reward space.
    Non-symlog is important b/c log(a+b) != log(a) + log(b).
    See [1] eq. 8 and 10.
    Thus, targets are always returned in real (non-symlog'd space).
    They need to be re-symlog'd before computing the critic loss from them (b/c the
    critic produces predictions in symlog space).
    Note that the original B and T ranks together form the new batch dimension
    (folded into BxT) and the new time rank is the dream horizon (hence: [H, BxT]).

    Variable names nomenclature:
    `H`=1+horizon_H (start state + H steps dreamed),
    `BxT`=batch_size * batch_length (meaning the original trajectory time rank has
    been folded).

    Rewards, continues, and value predictions are all of shape [t0-H, BxT]
    (time-major), whereas returned targets are [t0 to H-1, B] (last timestep missing
    b/c the target value equals vf prediction in that location anyways.

    Args:
        config: The DreamerV3Config to use.
        rewards_t0_to_H_BxT: The reward predictor's predictions over the
            dreamed trajectory t0 to H (and for the batch BxT).
        intrinsic_rewards_t1_to_H_BxT: The predicted intrinsic rewards over the
            dreamed trajectory t0 to H (and for the batch BxT). May be None
            (no curiosity-based intrinsic rewards).
        continues_t0_to_H_BxT: The continue predictor's predictions over the
            dreamed trajectory t0 to H (and for the batch BxT).
        value_predictions_t0_to_H_BxT: The critic's value predictions over the
            dreamed trajectory t0 to H (and for the batch BxT).

    Returns:
        The value targets in the shape: [t0toH-1, BxT]. Note that the last step (H)
        does not require a value target as it matches the critic's value prediction
        anyways.
    """
    # The first reward is irrelevant (not used for any VF target).
    rewards_t1_to_H_BxT = rewards_t0_to_H_BxT[1:]
    if intrinsic_rewards_t1_to_H_BxT is not None:
        rewards_t1_to_H_BxT += intrinsic_rewards_t1_to_H_BxT

    # In all the following, when building value targets for t=1 to T=H,
    # exclude rewards & continues for t=1 b/c we don't need r1 or c1.
    # The target (R1) for V1 is built from r2, c2, and V2/R2.
    discount = continues_t0_to_H_BxT[1:] * config.gamma  # shape=[2-16, BxT]
    # Bootstrap the recursion with the critic's value prediction at the horizon.
    Rs = [value_predictions_t0_to_H_BxT[-1]]  # Rs indices=[16]
    # Non-recursive part of the lambda-return ([1] eq. 8): reward plus the
    # (1 - lambda)-weighted one-step value bootstrap.
    intermediates = (
        rewards_t1_to_H_BxT
        + discount * (1 - config.gae_lambda) * value_predictions_t0_to_H_BxT[1:]
    )
    # intermediates.shape=[2-16, BxT]

    # Loop through reversed timesteps (axis=1) from T+1 to t=2.
    for t in reversed(range(discount.shape[0])):
        Rs.append(intermediates[t] + discount[t] * config.gae_lambda * Rs[-1])

    # Reverse along time axis and cut the last entry (value estimate at very end
    # cannot be learnt from as it's the same as the ... well ... value estimate).
    targets_t0toHm1_BxT = tf.stack(list(reversed(Rs))[:-1], axis=0)
    # targets.shape=[t0 to H-1,BxT]

    return targets_t0toHm1_BxT
|
| 845 |
+
|
| 846 |
+
def _compute_scaled_value_targets(
    self,
    *,
    module_id: ModuleID,
    config: DreamerV3Config,
    value_targets_t0_to_Hm1_BxT: TensorType,
    value_predictions_t0_to_Hm1_BxT: TensorType,
) -> TensorType:
    """Helper method computing the scaled value targets.

    Side effect: Updates the actor network's `ema_value_target_pct5` and
    `ema_value_target_pct95` tf.Variables with the EMA-smoothed 5/95
    percentiles of this batch's value targets.

    Args:
        module_id: The module_id to compute value targets for.
        config: The DreamerV3Config to use.
        value_targets_t0_to_Hm1_BxT: The value targets computed by
            `self._compute_value_targets` in the shape of (t0 to H-1, BxT)
            and of type float32.
        value_predictions_t0_to_Hm1_BxT: The critic's value predictions over the
            dreamed trajectories (w/o the last timestep). The shape of this
            tensor is (t0 to H-1, BxT) and the type is float32.

    Returns:
        The scaled value targets used by the actor for REINFORCE policy updates
        (using scaled advantages). See [1] eq. 12 for more details.
    """
    actor = self.module[module_id].actor

    value_targets_H_B = value_targets_t0_to_Hm1_BxT
    value_predictions_H_B = value_predictions_t0_to_Hm1_BxT

    # Compute S: [1] eq. 12.
    Per_R_5 = tfp.stats.percentile(value_targets_H_B, 5)
    Per_R_95 = tfp.stats.percentile(value_targets_H_B, 95)

    # Update EMA values for 5 and 95 percentile, stored as tf variables under actor
    # network.
    # 5 percentile
    new_val_pct5 = tf.where(
        tf.math.is_nan(actor.ema_value_target_pct5),
        # is NaN: Initial values: Just set.
        Per_R_5,
        # Later update (something already stored in EMA variable): Update EMA.
        (
            config.return_normalization_decay * actor.ema_value_target_pct5
            + (1.0 - config.return_normalization_decay) * Per_R_5
        ),
    )
    actor.ema_value_target_pct5.assign(new_val_pct5)
    # 95 percentile
    new_val_pct95 = tf.where(
        tf.math.is_nan(actor.ema_value_target_pct95),
        # is NaN: Initial values: Just set.
        Per_R_95,
        # Later update (something already stored in EMA variable): Update EMA.
        (
            config.return_normalization_decay * actor.ema_value_target_pct95
            + (1.0 - config.return_normalization_decay) * Per_R_95
        ),
    )
    actor.ema_value_target_pct95.assign(new_val_pct95)

    # [1] eq. 11 (first term).
    offset = actor.ema_value_target_pct5
    # Clamp the scale's denominator away from 0 to avoid division blow-up when
    # the percentile range collapses.
    invscale = tf.math.maximum(
        1e-8, actor.ema_value_target_pct95 - actor.ema_value_target_pct5
    )
    scaled_value_targets_H_B = (value_targets_H_B - offset) / invscale
    scaled_value_predictions_H_B = (value_predictions_H_B - offset) / invscale

    # Return advantages.
    return scaled_value_targets_H_B - scaled_value_predictions_H_B
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_rl_module.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
[1] Mastering Diverse Domains through World Models - 2023
|
| 3 |
+
D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
|
| 4 |
+
https://arxiv.org/pdf/2301.04104v1.pdf
|
| 5 |
+
|
| 6 |
+
[2] Mastering Atari with Discrete World Models - 2021
|
| 7 |
+
D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
|
| 8 |
+
https://arxiv.org/pdf/2010.02193.pdf
|
| 9 |
+
"""
|
| 10 |
+
from ray.rllib.algorithms.dreamerv3.dreamerv3_rl_module import DreamerV3RLModule
|
| 11 |
+
from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule
|
| 12 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 13 |
+
|
| 14 |
+
tf1, tf, _ = try_import_tf()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DreamerV3TfRLModule(TfRLModule, DreamerV3RLModule):
|
| 18 |
+
"""The tf-specific RLModule class for DreamerV3.
|
| 19 |
+
|
| 20 |
+
Serves mainly as a thin-wrapper around the `DreamerModel` (a tf.keras.Model) class.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
framework = "tf2"
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/actor_network.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
[1] Mastering Diverse Domains through World Models - 2023
|
| 3 |
+
D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
|
| 4 |
+
https://arxiv.org/pdf/2301.04104v1.pdf
|
| 5 |
+
"""
|
| 6 |
+
import gymnasium as gym
|
| 7 |
+
from gymnasium.spaces import Box, Discrete
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP
|
| 11 |
+
from ray.rllib.algorithms.dreamerv3.utils import (
|
| 12 |
+
get_gru_units,
|
| 13 |
+
get_num_z_categoricals,
|
| 14 |
+
get_num_z_classes,
|
| 15 |
+
)
|
| 16 |
+
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
|
| 17 |
+
|
| 18 |
+
_, tf, _ = try_import_tf()
|
| 19 |
+
tfp = try_import_tfp()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ActorNetwork(tf.keras.Model):
    """The `actor` (policy net) of DreamerV3.

    Consists of a simple MLP for Discrete actions and two MLPs for cont. actions (mean
    and stddev).
    Also contains two scalar variables to keep track of the percentile-5 and
    percentile-95 values of the computed value targets within a batch. This is used to
    compute the "scaled value targets" for actor learning. These two variables decay
    over time exponentially (see [1] for more details).
    """

    def __init__(
        self,
        *,
        model_size: str = "XS",
        action_space: gym.Space,
    ):
        """Initializes an ActorNetwork instance.

        Args:
            model_size: The "Model Size" used according to [1] Appendix B.
                Use None for manually setting the different network sizes.
            action_space: The action space of the environment used. Must be either
                `gym.spaces.Discrete` or `gym.spaces.Box`; anything else raises
                a `ValueError` below.
        """
        super().__init__(name="actor")

        self.model_size = model_size
        self.action_space = action_space

        # The EMA decay variables used for the [Percentile(R, 95%) - Percentile(R, 5%)]
        # diff to scale value targets for the actor loss.
        # Initialized to NaN; presumably filled in on the first train step by the
        # loss code that owns these EMAs (not visible in this file) - TODO confirm.
        self.ema_value_target_pct5 = tf.Variable(
            np.nan, trainable=False, name="value_target_pct5"
        )
        self.ema_value_target_pct95 = tf.Variable(
            np.nan, trainable=False, name="value_target_pct95"
        )

        # For discrete actions, use a single MLP that computes logits.
        if isinstance(self.action_space, Discrete):
            self.mlp = MLP(
                model_size=self.model_size,
                output_layer_size=self.action_space.n,
                name="actor_mlp",
            )
        # For cont. actions, use separate MLPs for Gaussian mean and stddev.
        # TODO (sven): In the author's original code repo, this is NOT the case,
        #  inputs are pushed through a shared MLP, then only the two output linear
        #  layers are separate for std- and mean logits.
        elif isinstance(action_space, Box):
            output_layer_size = np.prod(action_space.shape)
            self.mlp = MLP(
                model_size=self.model_size,
                output_layer_size=output_layer_size,
                name="actor_mlp_mean",
            )
            self.std_mlp = MLP(
                model_size=self.model_size,
                output_layer_size=output_layer_size,
                name="actor_mlp_std",
            )
        else:
            raise ValueError(f"Invalid action space: {action_space}")

        # Trace self.call: wrap the python method in a tf.function with a fixed
        # input signature ((B, gru_units) for h, (B, categoricals, classes) for z)
        # so retracing is avoided across calls.
        dl_type = tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32
        self.call = tf.function(
            input_signature=[
                tf.TensorSpec(shape=[None, get_gru_units(model_size)], dtype=dl_type),
                tf.TensorSpec(
                    shape=[
                        None,
                        get_num_z_categoricals(model_size),
                        get_num_z_classes(model_size),
                    ],
                    dtype=dl_type,
                ),
            ]
        )(self.call)

    def call(self, h, z):
        """Performs a forward pass through this policy network.

        Args:
            h: The deterministic hidden state of the sequence model. [B, dim(h)].
            z: The stochastic discrete representations of the original
                observation input. [B, num_categoricals, num_classes].

        Returns:
            Tuple of the sampled action and the distribution parameters used to
            sample it (logits for Discrete; concatenated [mean, std] for Box).
        """
        # Flatten last two dims of z.
        assert len(z.shape) == 3
        z_shape = tf.shape(z)
        z = tf.reshape(z, shape=(z_shape[0], -1))
        assert len(z.shape) == 2
        out = tf.concat([h, z], axis=-1)
        # Restore the (statically known) last-dim size lost by the dynamic reshape
        # above, so downstream dense layers can build.
        out.set_shape(
            [
                None,
                (
                    get_num_z_categoricals(self.model_size)
                    * get_num_z_classes(self.model_size)
                    + get_gru_units(self.model_size)
                ),
            ]
        )
        # Send h-cat-z through MLP.
        action_logits = tf.cast(self.mlp(out), tf.float32)

        if isinstance(self.action_space, Discrete):
            action_probs = tf.nn.softmax(action_logits)

            # Add the unimix weighting (1% uniform) to the probs.
            # See [1]: "Unimix categoricals: We parameterize the categorical
            # distributions for the world model representations and dynamics, as well as
            # for the actor network, as mixtures of 1% uniform and 99% neural network
            # output to ensure a minimal amount of probability mass on every class and
            # thus keep log probabilities and KL divergences well behaved."
            action_probs = 0.99 * action_probs + 0.01 * (1.0 / self.action_space.n)

            # Danijar's code does: distr = [Distr class](logits=tf.log(probs)).
            # Not sure why we don't directly use the already available probs instead.
            action_logits = tf.math.log(action_probs)

            # Distribution parameters are the log(probs) directly.
            distr_params = action_logits
            distr = self.get_action_dist_object(distr_params)

            # Straight-through estimator: the value equals the sampled one-hot
            # (the two `action_probs` terms cancel out), but gradients flow
            # through `action_probs` since the sample itself has none.
            action = tf.stop_gradient(distr.sample()) + (
                action_probs - tf.stop_gradient(action_probs)
            )

        elif isinstance(self.action_space, Box):
            # Send h-cat-z through MLP to compute stddev logits for Normal dist
            std_logits = tf.cast(self.std_mlp(out), tf.float32)
            # minstd, maxstd taken from [1] from configs.yaml
            minstd = 0.1
            maxstd = 1.0

            # Distribution parameters are the squashed std_logits and the tanh'd
            # mean logits.
            # squash std_logits from (-inf, inf) to (minstd, maxstd); the +2.0
            # shift biases the initial (near-zero) logits toward larger stddevs.
            std_logits = (maxstd - minstd) * tf.sigmoid(std_logits + 2.0) + minstd
            mean_logits = tf.tanh(action_logits)

            distr_params = tf.concat([mean_logits, std_logits], axis=-1)
            distr = self.get_action_dist_object(distr_params)

            action = distr.sample()

        return action, distr_params

    def get_action_dist_object(self, action_dist_params_T_B):
        """Helper method to create an action distribution object from (T, B, ..) params.

        Args:
            action_dist_params_T_B: The time-major action distribution parameters.
                This could be simply the logits (discrete) or a to-be-split-in-2
                tensor for mean and stddev (continuous).

        Returns:
            The tfp action distribution object, from which one can sample, compute
            log probs, entropy, etc..

        Raises:
            ValueError: If `self.action_space` is neither Discrete nor Box.
        """
        if isinstance(self.action_space, gym.spaces.Discrete):
            # Create the distribution object using the unimix'd logits.
            # dtype=tf.float32 makes samples float one-hot vectors (needed for the
            # straight-through addition in `call`).
            distr = tfp.distributions.OneHotCategorical(
                logits=action_dist_params_T_B,
                dtype=tf.float32,
            )

        elif isinstance(self.action_space, gym.spaces.Box):
            # Compute Normal distribution from action_logits and std_logits
            # (mean and std were concatenated along the last axis in `call`).
            loc, scale = tf.split(action_dist_params_T_B, 2, axis=-1)
            distr = tfp.distributions.Normal(loc=loc, scale=scale)

            # If action_space is a box with multiple dims, make individual dims
            # independent.
            distr = tfp.distributions.Independent(distr, len(self.action_space.shape))

        else:
            raise ValueError(f"Action space {self.action_space} not supported!")

        return distr
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/cnn_atari.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
[1] Mastering Diverse Domains through World Models - 2023
|
| 3 |
+
D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
|
| 4 |
+
https://arxiv.org/pdf/2301.04104v1.pdf
|
| 5 |
+
"""
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
from ray.rllib.algorithms.dreamerv3.utils import get_cnn_multiplier
|
| 9 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 10 |
+
|
| 11 |
+
_, tf, _ = try_import_tf()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class CNNAtari(tf.keras.Model):
    """An image encoder mapping 64x64 RGB images via 4 CNN layers into a 1D space."""

    def __init__(
        self,
        *,
        model_size: Optional[str] = "XS",
        cnn_multiplier: Optional[int] = None,
    ):
        """Initializes a CNNAtari instance.

        Args:
            model_size: The "Model Size" used according to [1] Appendix B.
                Use None for manually setting the `cnn_multiplier`.
            cnn_multiplier: Optional override for the additional factor used to multiply
                the number of filters with each CNN layer. Starting with
                1 * `cnn_multiplier` filters in the first CNN layer, the number of
                filters then increases via `2*cnn_multiplier`, `4*cnn_multiplier`, till
                `8*cnn_multiplier`.
        """
        super().__init__(name="image_encoder")

        cnn_multiplier = get_cnn_multiplier(model_size, override=cnn_multiplier)

        # See appendix C in [1]:
        # "We use a similar network architecture but employ layer normalization and
        # SiLU as the activation function. For better framework support, we use
        # same-padded convolutions with stride 2 and kernel size 3 instead of
        # valid-padded convolutions with larger kernels ..."
        # HOWEVER: In Danijar's DreamerV3 repo, kernel size=4 is used, so we use it
        # here, too.
        # Each stride-2 "same" conv halves the spatial dims: 64 -> 32 -> 16 -> 8 -> 4.
        self.conv_layers = [
            tf.keras.layers.Conv2D(
                filters=1 * cnn_multiplier,
                kernel_size=4,
                strides=(2, 2),
                padding="same",
                # No bias or activation due to layernorm.
                activation=None,
                use_bias=False,
            ),
            tf.keras.layers.Conv2D(
                filters=2 * cnn_multiplier,
                kernel_size=4,
                strides=(2, 2),
                padding="same",
                # No bias or activation due to layernorm.
                activation=None,
                use_bias=False,
            ),
            tf.keras.layers.Conv2D(
                filters=4 * cnn_multiplier,
                kernel_size=4,
                strides=(2, 2),
                padding="same",
                # No bias or activation due to layernorm.
                activation=None,
                use_bias=False,
            ),
            # .. until output is 4 x 4 x [num_filters].
            tf.keras.layers.Conv2D(
                filters=8 * cnn_multiplier,
                kernel_size=4,
                strides=(2, 2),
                padding="same",
                # No bias or activation due to layernorm.
                activation=None,
                use_bias=False,
            ),
        ]
        # One LayerNorm per conv layer (applied between conv and SiLU in `call`).
        self.layer_normalizations = []
        for _ in range(len(self.conv_layers)):
            self.layer_normalizations.append(tf.keras.layers.LayerNormalization())
        # -> 4 x 4 x num_filters -> now flatten.
        self.flatten_layer = tf.keras.layers.Flatten(data_format="channels_last")

    @tf.function(
        input_signature=[
            tf.TensorSpec(
                shape=[None, 64, 64, 3],
                dtype=tf.keras.mixed_precision.global_policy().compute_dtype
                or tf.float32,
            )
        ]
    )
    def call(self, inputs):
        """Performs a forward pass through the CNN Atari encoder.

        Args:
            inputs: The image inputs of shape (B, 64, 64, 3).

        Returns:
            The flattened (B, 4*4*8*cnn_multiplier) embedding of the input images.
        """
        # [B, h, w] -> grayscale: add a trailing channel dim.
        # NOTE(review): under the traced (B, 64, 64, 3) signature above this branch
        # never triggers; it only matters for eager/untraced calls - confirm.
        if len(inputs.shape) == 3:
            inputs = tf.expand_dims(inputs, -1)
        out = inputs
        # conv -> layernorm -> SiLU, four times.
        for conv_2d, layer_norm in zip(self.conv_layers, self.layer_normalizations):
            out = tf.nn.silu(layer_norm(inputs=conv_2d(out)))
        assert out.shape[1] == 4 and out.shape[2] == 4
        return self.flatten_layer(out)
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/conv_transpose_atari.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
[1] Mastering Diverse Domains through World Models - 2023
|
| 3 |
+
D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
|
| 4 |
+
https://arxiv.org/pdf/2301.04104v1.pdf
|
| 5 |
+
|
| 6 |
+
[2] Mastering Atari with Discrete World Models - 2021
|
| 7 |
+
D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
|
| 8 |
+
https://arxiv.org/pdf/2010.02193.pdf
|
| 9 |
+
"""
|
| 10 |
+
from typing import Optional
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
from ray.rllib.algorithms.dreamerv3.utils import (
|
| 15 |
+
get_cnn_multiplier,
|
| 16 |
+
get_gru_units,
|
| 17 |
+
get_num_z_categoricals,
|
| 18 |
+
get_num_z_classes,
|
| 19 |
+
)
|
| 20 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 21 |
+
|
| 22 |
+
_, tf, _ = try_import_tf()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ConvTransposeAtari(tf.keras.Model):
    """A Conv2DTranspose decoder to generate Atari images from a latent space.

    Wraps an initial single linear layer with a stack of 4 Conv2DTranspose layers (with
    layer normalization) and a diag Gaussian, from which we then sample the final image.
    Sampling is done with a fixed stddev=1.0 and using the mean values coming from the
    last Conv2DTranspose layer.
    """

    def __init__(
        self,
        *,
        model_size: Optional[str] = "XS",
        cnn_multiplier: Optional[int] = None,
        gray_scaled: bool,
    ):
        """Initializes a ConvTransposeAtari instance.

        Args:
            model_size: The "Model Size" used according to [1] Appendix B.
                Use None for manually setting the `cnn_multiplier`.
            cnn_multiplier: Optional override for the additional factor used to multiply
                the number of filters with each CNN transpose layer. Starting with
                8 * `cnn_multiplier` filters in the first CNN transpose layer, the
                number of filters then decreases via `4*cnn_multiplier`,
                `2*cnn_multiplier`, till `1*cnn_multiplier`.
            gray_scaled: Whether the last Conv2DTranspose layer's output has only 1
                color channel (gray_scaled=True) or 3 RGB channels (gray_scaled=False).
        """
        super().__init__(name="image_decoder")

        self.model_size = model_size
        cnn_multiplier = get_cnn_multiplier(self.model_size, override=cnn_multiplier)

        # The shape going into the first Conv2DTranspose layer.
        # We start with a 4x4 channels=8 "image".
        self.input_dims = (4, 4, 8 * cnn_multiplier)

        self.gray_scaled = gray_scaled

        # See appendix B in [1]:
        # "The decoder starts with a dense layer, followed by reshaping
        # to 4 × 4 × C and then inverts the encoder architecture. ..."
        self.dense_layer = tf.keras.layers.Dense(
            units=int(np.prod(self.input_dims)),
            activation=None,
            use_bias=True,
        )
        # Inverse conv2d stack. See cnn_atari.py for corresponding Conv2D stack.
        # Each stride-2 transpose doubles the spatial dims: 4 -> 8 -> 16 -> 32
        # (-> 64 via the output layer below).
        self.conv_transpose_layers = [
            tf.keras.layers.Conv2DTranspose(
                filters=4 * cnn_multiplier,
                kernel_size=4,
                strides=(2, 2),
                padding="same",
                # No bias or activation due to layernorm.
                activation=None,
                use_bias=False,
            ),
            tf.keras.layers.Conv2DTranspose(
                filters=2 * cnn_multiplier,
                kernel_size=4,
                strides=(2, 2),
                padding="same",
                # No bias or activation due to layernorm.
                activation=None,
                use_bias=False,
            ),
            tf.keras.layers.Conv2DTranspose(
                filters=1 * cnn_multiplier,
                kernel_size=4,
                strides=(2, 2),
                padding="same",
                # No bias or activation due to layernorm.
                activation=None,
                use_bias=False,
            ),
        ]
        # Create one LayerNorm layer for each of the Conv2DTranspose layers.
        self.layer_normalizations = []
        for _ in range(len(self.conv_transpose_layers)):
            self.layer_normalizations.append(tf.keras.layers.LayerNormalization())

        # Important! No activation or layer norm for last layer as the outputs of
        # this one go directly into the diag-gaussian as parameters.
        self.output_conv2d_transpose = tf.keras.layers.Conv2DTranspose(
            filters=1 if self.gray_scaled else 3,
            kernel_size=4,
            strides=(2, 2),
            padding="same",
            activation=None,
            use_bias=True,  # Last layer does use bias (b/c has no LayerNorm).
        )
        # .. until output is 64 x 64 x 3 (or 1 for self.gray_scaled=True).

        # Trace self.call: wrap the python method in a tf.function with a fixed
        # input signature to avoid retracing across calls.
        dl_type = tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32
        self.call = tf.function(
            input_signature=[
                tf.TensorSpec(shape=[None, get_gru_units(model_size)], dtype=dl_type),
                tf.TensorSpec(
                    shape=[
                        None,
                        get_num_z_categoricals(model_size),
                        get_num_z_classes(model_size),
                    ],
                    dtype=dl_type,
                ),
            ]
        )(self.call)

    def call(self, h, z):
        """Performs a forward pass through the Conv2D transpose decoder.

        Args:
            h: The deterministic hidden state of the sequence model.
            z: The sequence of stochastic discrete representations of the original
                observation input. Note: `z` is not used for the dynamics predictor
                model (which predicts z from h).

        Returns:
            The flattened (B, num_pixels*channels) mean values of the diag-Gaussian
            over the reconstructed image (no sampling is done here).
        """
        # Flatten last two dims of z.
        assert len(z.shape) == 3
        z_shape = tf.shape(z)
        z = tf.reshape(z, shape=(z_shape[0], -1))
        assert len(z.shape) == 2
        input_ = tf.concat([h, z], axis=-1)
        # Restore the statically known last-dim size lost by the dynamic reshape.
        input_.set_shape(
            [
                None,
                (
                    get_num_z_categoricals(self.model_size)
                    * get_num_z_classes(self.model_size)
                    + get_gru_units(self.model_size)
                ),
            ]
        )

        # Feed through initial dense layer to get the right number of input nodes
        # for the first conv2dtranspose layer.
        out = self.dense_layer(input_)
        # Reshape to image format.
        out = tf.reshape(out, shape=(-1,) + self.input_dims)

        # Pass through stack of Conv2DTranspose layers (and layer norms).
        for conv_transpose_2d, layer_norm in zip(
            self.conv_transpose_layers, self.layer_normalizations
        ):
            out = tf.nn.silu(layer_norm(inputs=conv_transpose_2d(out)))
        # Last output conv2d-transpose layer:
        out = self.output_conv2d_transpose(out)
        out += 0.5  # See Danijar's code
        out_shape = tf.shape(out)

        # Interpret output as means of a diag-Gaussian with std=1.0:
        # From [2]:
        # "Distributions: The image predictor outputs the mean of a diagonal Gaussian
        # likelihood with unit variance, ..."

        # Reshape `out` for the diagonal multi-variate Gaussian (each pixel is its own
        # independent (b/c diagonal co-variance matrix) variable).
        loc = tf.reshape(out, shape=(out_shape[0], -1))

        return loc
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/vector_decoder.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
[1] Mastering Diverse Domains through World Models - 2023
|
| 3 |
+
D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
|
| 4 |
+
https://arxiv.org/pdf/2301.04104v1.pdf
|
| 5 |
+
"""
|
| 6 |
+
import gymnasium as gym
|
| 7 |
+
|
| 8 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP
|
| 9 |
+
from ray.rllib.algorithms.dreamerv3.utils import (
|
| 10 |
+
get_gru_units,
|
| 11 |
+
get_num_z_categoricals,
|
| 12 |
+
get_num_z_classes,
|
| 13 |
+
)
|
| 14 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 15 |
+
|
| 16 |
+
_, tf, _ = try_import_tf()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class VectorDecoder(tf.keras.Model):
    """A simple vector decoder to reproduce non-image (1D vector) observations.

    Wraps an MLP for mean parameter computations and a Gaussian distribution,
    from which we then sample using these mean values and a fixed stddev of 1.0.
    """

    def __init__(
        self,
        *,
        model_size: str = "XS",
        observation_space: gym.Space,
    ):
        """Initializes a VectorDecoder instance.

        Args:
            model_size: The "Model Size" used according to [1] Appendix B.
                Determines the exact size of the underlying MLP.
            observation_space: The observation space to decode back into. This must
                be a Box of shape (d,), where d >= 1.
        """
        super().__init__(name="vector_decoder")

        self.model_size = model_size

        # Only flat (1D) Box observation spaces can be decoded by this model.
        assert (
            isinstance(observation_space, gym.spaces.Box)
            and len(observation_space.shape) == 1
        )

        # MLP output size == observation dim, so the MLP's output can directly
        # serve as the per-dimension Gaussian means.
        self.mlp = MLP(
            model_size=model_size,
            output_layer_size=observation_space.shape[0],
        )

        # Trace self.call: wrap the python method in a tf.function with a fixed
        # input signature to avoid retracing across calls.
        dl_type = tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32
        self.call = tf.function(
            input_signature=[
                tf.TensorSpec(shape=[None, get_gru_units(model_size)], dtype=dl_type),
                tf.TensorSpec(
                    shape=[
                        None,
                        get_num_z_categoricals(model_size),
                        get_num_z_classes(model_size),
                    ],
                    dtype=dl_type,
                ),
            ]
        )(self.call)

    def call(self, h, z):
        """Performs a forward pass through the vector encoder.

        Args:
            h: The deterministic hidden state of the sequence model. [B, dim(h)].
            z: The stochastic discrete representations of the original
                observation input. [B, num_categoricals, num_classes].

        Returns:
            The predicted observation means of shape [B, obs_dim] (no sampling).
        """
        # Flatten last two dims of z.
        assert len(z.shape) == 3
        z_shape = tf.shape(z)
        z = tf.reshape(z, shape=(z_shape[0], -1))
        assert len(z.shape) == 2
        out = tf.concat([h, z], axis=-1)
        # Restore the statically known last-dim size lost by the dynamic reshape.
        out.set_shape(
            [
                None,
                (
                    get_num_z_categoricals(self.model_size)
                    * get_num_z_classes(self.model_size)
                    + get_gru_units(self.model_size)
                ),
            ]
        )
        # Send h-cat-z through MLP to get mean values of diag gaussian.
        loc = self.mlp(out)

        # Return only the predicted observations (mean, no sample).
        return loc
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/critic_network.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
[1] Mastering Diverse Domains through World Models - 2023
|
| 3 |
+
D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
|
| 4 |
+
https://arxiv.org/pdf/2301.04104v1.pdf
|
| 5 |
+
"""
|
| 6 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP
|
| 7 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.components.reward_predictor_layer import (
|
| 8 |
+
RewardPredictorLayer,
|
| 9 |
+
)
|
| 10 |
+
from ray.rllib.algorithms.dreamerv3.utils import (
|
| 11 |
+
get_gru_units,
|
| 12 |
+
get_num_z_categoricals,
|
| 13 |
+
get_num_z_classes,
|
| 14 |
+
)
|
| 15 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 16 |
+
|
| 17 |
+
_, tf, _ = try_import_tf()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class CriticNetwork(tf.keras.Model):
|
| 21 |
+
"""The critic network described in [1], predicting values for policy learning.
|
| 22 |
+
|
| 23 |
+
Contains a copy of itself (EMA net) for weight regularization.
|
| 24 |
+
The EMA net is updated after each train step via EMA (using the `ema_decay`
|
| 25 |
+
parameter and the actual critic's weights). The EMA net is NOT used for target
|
| 26 |
+
computations (we use the actual critic for that), its only purpose is to compute a
|
| 27 |
+
weights regularizer term for the critic's loss such that the actual critic does not
|
| 28 |
+
move too quickly.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
*,
|
| 34 |
+
model_size: str = "XS",
|
| 35 |
+
num_buckets: int = 255,
|
| 36 |
+
lower_bound: float = -20.0,
|
| 37 |
+
upper_bound: float = 20.0,
|
| 38 |
+
ema_decay: float = 0.98,
|
| 39 |
+
):
|
| 40 |
+
"""Initializes a CriticNetwork instance.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
model_size: The "Model Size" used according to [1] Appendinx B.
|
| 44 |
+
Use None for manually setting the different network sizes.
|
| 45 |
+
num_buckets: The number of buckets to create. Note that the number of
|
| 46 |
+
possible symlog'd outcomes from the used distribution is
|
| 47 |
+
`num_buckets` + 1:
|
| 48 |
+
lower_bound --bucket-- o[1] --bucket-- o[2] ... --bucket-- upper_bound
|
| 49 |
+
o=outcomes
|
| 50 |
+
lower_bound=o[0]
|
| 51 |
+
upper_bound=o[num_buckets]
|
| 52 |
+
lower_bound: The symlog'd lower bound for a possible reward value.
|
| 53 |
+
Note that a value of -20.0 here already allows individual (actual env)
|
| 54 |
+
rewards to be as low as -400M. Buckets will be created between
|
| 55 |
+
`lower_bound` and `upper_bound`.
|
| 56 |
+
upper_bound: The symlog'd upper bound for a possible reward value.
|
| 57 |
+
Note that a value of +20.0 here already allows individual (actual env)
|
| 58 |
+
rewards to be as high as 400M. Buckets will be created between
|
| 59 |
+
`lower_bound` and `upper_bound`.
|
| 60 |
+
ema_decay: The weight to use for updating the weights of the critic's copy
|
| 61 |
+
vs the actual critic. After each training update, the EMA copy of the
|
| 62 |
+
critic gets updated according to:
|
| 63 |
+
ema_net=(`ema_decay`*ema_net) + (1.0-`ema_decay`)*critic_net
|
| 64 |
+
The EMA copy of the critic is used inside the critic loss function only
|
| 65 |
+
to produce a regularizer term against the current critic's weights, NOT
|
| 66 |
+
to compute any target values.
|
| 67 |
+
"""
|
| 68 |
+
super().__init__(name="critic")
|
| 69 |
+
|
| 70 |
+
self.model_size = model_size
|
| 71 |
+
self.ema_decay = ema_decay
|
| 72 |
+
|
| 73 |
+
# "Fast" critic network(s) (mlp + reward-pred-layer). This is the network
|
| 74 |
+
# we actually train with our critic loss.
|
| 75 |
+
# IMPORTANT: We also use this to compute the return-targets, BUT we regularize
|
| 76 |
+
# the critic loss term such that the weights of this fast critic stay close
|
| 77 |
+
# to the EMA weights (see below).
|
| 78 |
+
self.mlp = MLP(
|
| 79 |
+
model_size=self.model_size,
|
| 80 |
+
output_layer_size=None,
|
| 81 |
+
)
|
| 82 |
+
self.return_layer = RewardPredictorLayer(
|
| 83 |
+
num_buckets=num_buckets,
|
| 84 |
+
lower_bound=lower_bound,
|
| 85 |
+
upper_bound=upper_bound,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
# Weights-EMA (EWMA) containing networks for critic loss (similar to a
|
| 89 |
+
# target net, BUT not used to compute anything, just for the
|
| 90 |
+
# weights regularizer term inside the critic loss).
|
| 91 |
+
self.mlp_ema = MLP(
|
| 92 |
+
model_size=self.model_size,
|
| 93 |
+
output_layer_size=None,
|
| 94 |
+
trainable=False,
|
| 95 |
+
)
|
| 96 |
+
self.return_layer_ema = RewardPredictorLayer(
|
| 97 |
+
num_buckets=num_buckets,
|
| 98 |
+
lower_bound=lower_bound,
|
| 99 |
+
upper_bound=upper_bound,
|
| 100 |
+
trainable=False,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# Trace self.call.
|
| 104 |
+
dl_type = tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32
|
| 105 |
+
self.call = tf.function(
|
| 106 |
+
input_signature=[
|
| 107 |
+
tf.TensorSpec(shape=[None, get_gru_units(model_size)], dtype=dl_type),
|
| 108 |
+
tf.TensorSpec(
|
| 109 |
+
shape=[
|
| 110 |
+
None,
|
| 111 |
+
get_num_z_categoricals(model_size),
|
| 112 |
+
get_num_z_classes(model_size),
|
| 113 |
+
],
|
| 114 |
+
dtype=dl_type,
|
| 115 |
+
),
|
| 116 |
+
tf.TensorSpec(shape=[], dtype=tf.bool),
|
| 117 |
+
]
|
| 118 |
+
)(self.call)
|
| 119 |
+
|
| 120 |
+
def call(self, h, z, use_ema):
|
| 121 |
+
"""Performs a forward pass through the critic network.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
h: The deterministic hidden state of the sequence model. [B, dim(h)].
|
| 125 |
+
z: The stochastic discrete representations of the original
|
| 126 |
+
observation input. [B, num_categoricals, num_classes].
|
| 127 |
+
use_ema: Whether to use the EMA-copy of the critic instead of the actual
|
| 128 |
+
critic to perform this computation.
|
| 129 |
+
"""
|
| 130 |
+
# Flatten last two dims of z.
|
| 131 |
+
assert len(z.shape) == 3
|
| 132 |
+
z_shape = tf.shape(z)
|
| 133 |
+
z = tf.reshape(z, shape=(z_shape[0], -1))
|
| 134 |
+
assert len(z.shape) == 2
|
| 135 |
+
out = tf.concat([h, z], axis=-1)
|
| 136 |
+
out.set_shape(
|
| 137 |
+
[
|
| 138 |
+
None,
|
| 139 |
+
(
|
| 140 |
+
get_num_z_categoricals(self.model_size)
|
| 141 |
+
* get_num_z_classes(self.model_size)
|
| 142 |
+
+ get_gru_units(self.model_size)
|
| 143 |
+
),
|
| 144 |
+
]
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
if not use_ema:
|
| 148 |
+
# Send h-cat-z through MLP.
|
| 149 |
+
out = self.mlp(out)
|
| 150 |
+
# Return expected return OR (expected return, probs of bucket values).
|
| 151 |
+
return self.return_layer(out)
|
| 152 |
+
else:
|
| 153 |
+
out = self.mlp_ema(out)
|
| 154 |
+
return self.return_layer_ema(out)
|
| 155 |
+
|
| 156 |
+
def init_ema(self) -> None:
|
| 157 |
+
"""Initializes the EMA-copy of the critic from the critic's weights.
|
| 158 |
+
|
| 159 |
+
After calling this method, the two networks have identical weights.
|
| 160 |
+
"""
|
| 161 |
+
vars = self.mlp.trainable_variables + self.return_layer.trainable_variables
|
| 162 |
+
vars_ema = self.mlp_ema.variables + self.return_layer_ema.variables
|
| 163 |
+
assert len(vars) == len(vars_ema) and len(vars) > 0
|
| 164 |
+
for var, var_ema in zip(vars, vars_ema):
|
| 165 |
+
assert var is not var_ema
|
| 166 |
+
var_ema.assign(var)
|
| 167 |
+
|
| 168 |
+
def update_ema(self) -> None:
|
| 169 |
+
"""Updates the EMA-copy of the critic according to the update formula:
|
| 170 |
+
|
| 171 |
+
ema_net=(`ema_decay`*ema_net) + (1.0-`ema_decay`)*critic_net
|
| 172 |
+
"""
|
| 173 |
+
vars = self.mlp.trainable_variables + self.return_layer.trainable_variables
|
| 174 |
+
vars_ema = self.mlp_ema.variables + self.return_layer_ema.variables
|
| 175 |
+
assert len(vars) == len(vars_ema) and len(vars) > 0
|
| 176 |
+
for var, var_ema in zip(vars, vars_ema):
|
| 177 |
+
var_ema.assign(self.ema_decay * var_ema + (1.0 - self.ema_decay) * var)
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/disagree_networks.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
[1] Mastering Diverse Domains through World Models - 2023
|
| 3 |
+
D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
|
| 4 |
+
https://arxiv.org/pdf/2301.04104v1.pdf
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP
|
| 8 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.components.representation_layer import (
|
| 9 |
+
RepresentationLayer,
|
| 10 |
+
)
|
| 11 |
+
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
|
| 12 |
+
|
| 13 |
+
_, tf, _ = try_import_tf()
|
| 14 |
+
tfp = try_import_tfp()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DisagreeNetworks(tf.keras.Model):
|
| 18 |
+
"""Predict the RSSM's z^(t+1), given h(t), z^(t), and a(t).
|
| 19 |
+
|
| 20 |
+
Disagreement (stddev) between the N networks in this model on what the next z^ would
|
| 21 |
+
be are used to produce intrinsic rewards for enhanced, curiosity-based exploration.
|
| 22 |
+
|
| 23 |
+
TODO
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, *, num_networks, model_size, intrinsic_rewards_scale):
|
| 27 |
+
super().__init__(name="disagree_networks")
|
| 28 |
+
|
| 29 |
+
self.model_size = model_size
|
| 30 |
+
self.num_networks = num_networks
|
| 31 |
+
self.intrinsic_rewards_scale = intrinsic_rewards_scale
|
| 32 |
+
|
| 33 |
+
self.mlps = []
|
| 34 |
+
self.representation_layers = []
|
| 35 |
+
|
| 36 |
+
for _ in range(self.num_networks):
|
| 37 |
+
self.mlps.append(
|
| 38 |
+
MLP(
|
| 39 |
+
model_size=self.model_size,
|
| 40 |
+
output_layer_size=None,
|
| 41 |
+
trainable=True,
|
| 42 |
+
)
|
| 43 |
+
)
|
| 44 |
+
self.representation_layers.append(
|
| 45 |
+
RepresentationLayer(model_size=self.model_size, name="disagree")
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
def call(self, inputs, z, a, training=None):
|
| 49 |
+
return self.forward_train(a=a, h=inputs, z=z)
|
| 50 |
+
|
| 51 |
+
def compute_intrinsic_rewards(self, h, z, a):
|
| 52 |
+
forward_train_outs = self.forward_train(a=a, h=h, z=z)
|
| 53 |
+
B = tf.shape(h)[0]
|
| 54 |
+
|
| 55 |
+
# Intrinsic rewards are computed as:
|
| 56 |
+
# Stddev (between the different nets) of the 32x32 discrete, stochastic
|
| 57 |
+
# probabilities. Meaning that if the larger the disagreement
|
| 58 |
+
# (stddev) between the nets on what the probabilities for the different
|
| 59 |
+
# classes should be, the higher the intrinsic reward.
|
| 60 |
+
z_predicted_probs_N_B = forward_train_outs["z_predicted_probs_N_HxB"]
|
| 61 |
+
N = len(z_predicted_probs_N_B)
|
| 62 |
+
z_predicted_probs_N_B = tf.stack(z_predicted_probs_N_B, axis=0)
|
| 63 |
+
# Flatten z-dims (num_categoricals x num_classes).
|
| 64 |
+
z_predicted_probs_N_B = tf.reshape(z_predicted_probs_N_B, shape=(N, B, -1))
|
| 65 |
+
|
| 66 |
+
# Compute stddevs over all disagree nets (axis=0).
|
| 67 |
+
# Mean over last axis ([num categoricals] x [num classes] folded axis).
|
| 68 |
+
stddevs_B_mean = tf.reduce_mean(
|
| 69 |
+
tf.math.reduce_std(z_predicted_probs_N_B, axis=0),
|
| 70 |
+
axis=-1,
|
| 71 |
+
)
|
| 72 |
+
# TEST:
|
| 73 |
+
stddevs_B_mean -= tf.reduce_mean(stddevs_B_mean)
|
| 74 |
+
# END TEST
|
| 75 |
+
return {
|
| 76 |
+
"rewards_intrinsic": stddevs_B_mean * self.intrinsic_rewards_scale,
|
| 77 |
+
"forward_train_outs": forward_train_outs,
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
def forward_train(self, a, h, z):
|
| 81 |
+
HxB = tf.shape(h)[0]
|
| 82 |
+
# Fold z-dims.
|
| 83 |
+
z = tf.reshape(z, shape=(HxB, -1))
|
| 84 |
+
# Concat all input components (h, z, and a).
|
| 85 |
+
inputs_ = tf.stop_gradient(tf.concat([h, z, a], axis=-1))
|
| 86 |
+
|
| 87 |
+
z_predicted_probs_N_HxB = [
|
| 88 |
+
repr(mlp(inputs_))[1] # [0]=sample; [1]=returned probs
|
| 89 |
+
for mlp, repr in zip(self.mlps, self.representation_layers)
|
| 90 |
+
]
|
| 91 |
+
# shape=(N, HxB, [num categoricals], [num classes]); N=number of disagree nets.
|
| 92 |
+
# HxB -> folded horizon_H x batch_size_B (from dreamed data).
|
| 93 |
+
|
| 94 |
+
return {"z_predicted_probs_N_HxB": z_predicted_probs_N_HxB}
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/dreamer_model.py
ADDED
|
@@ -0,0 +1,606 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
[1] Mastering Diverse Domains through World Models - 2023
|
| 3 |
+
D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
|
| 4 |
+
https://arxiv.org/pdf/2301.04104v1.pdf
|
| 5 |
+
"""
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
import gymnasium as gym
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.disagree_networks import DisagreeNetworks
|
| 12 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.actor_network import ActorNetwork
|
| 13 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.critic_network import CriticNetwork
|
| 14 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.world_model import WorldModel
|
| 15 |
+
from ray.rllib.algorithms.dreamerv3.utils import (
|
| 16 |
+
get_gru_units,
|
| 17 |
+
get_num_z_categoricals,
|
| 18 |
+
get_num_z_classes,
|
| 19 |
+
)
|
| 20 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 21 |
+
from ray.rllib.utils.tf_utils import inverse_symlog
|
| 22 |
+
|
| 23 |
+
_, tf, _ = try_import_tf()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DreamerModel(tf.keras.Model):
|
| 27 |
+
"""The main tf-keras model containing all necessary components for DreamerV3.
|
| 28 |
+
|
| 29 |
+
Includes:
|
| 30 |
+
- The world model with encoder, decoder, sequence-model (RSSM), dynamics
|
| 31 |
+
(generates prior z-state), and "posterior" model (generates posterior z-state).
|
| 32 |
+
Predicts env dynamics and produces dreamed trajectories for actor- and critic
|
| 33 |
+
learning.
|
| 34 |
+
- The actor network (policy).
|
| 35 |
+
- The critic network for value function prediction.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
*,
|
| 41 |
+
model_size: str = "XS",
|
| 42 |
+
action_space: gym.Space,
|
| 43 |
+
world_model: WorldModel,
|
| 44 |
+
actor: ActorNetwork,
|
| 45 |
+
critic: CriticNetwork,
|
| 46 |
+
horizon: int,
|
| 47 |
+
gamma: float,
|
| 48 |
+
use_curiosity: bool = False,
|
| 49 |
+
intrinsic_rewards_scale: float = 0.1,
|
| 50 |
+
):
|
| 51 |
+
"""Initializes a DreamerModel instance.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
model_size: The "Model Size" used according to [1] Appendinx B.
|
| 55 |
+
Use None for manually setting the different network sizes.
|
| 56 |
+
action_space: The action space of the environment used.
|
| 57 |
+
world_model: The WorldModel component.
|
| 58 |
+
actor: The ActorNetwork component.
|
| 59 |
+
critic: The CriticNetwork component.
|
| 60 |
+
horizon: The dream horizon to use when creating dreamed trajectories.
|
| 61 |
+
"""
|
| 62 |
+
super().__init__(name="dreamer_model")
|
| 63 |
+
|
| 64 |
+
self.model_size = model_size
|
| 65 |
+
self.action_space = action_space
|
| 66 |
+
self.use_curiosity = use_curiosity
|
| 67 |
+
|
| 68 |
+
self.world_model = world_model
|
| 69 |
+
self.actor = actor
|
| 70 |
+
self.critic = critic
|
| 71 |
+
|
| 72 |
+
self.horizon = horizon
|
| 73 |
+
self.gamma = gamma
|
| 74 |
+
self._comp_dtype = (
|
| 75 |
+
tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
self.disagree_nets = None
|
| 79 |
+
if self.use_curiosity:
|
| 80 |
+
self.disagree_nets = DisagreeNetworks(
|
| 81 |
+
num_networks=8,
|
| 82 |
+
model_size=self.model_size,
|
| 83 |
+
intrinsic_rewards_scale=intrinsic_rewards_scale,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
self.dream_trajectory = tf.function(
|
| 87 |
+
input_signature=[
|
| 88 |
+
{
|
| 89 |
+
"h": tf.TensorSpec(
|
| 90 |
+
shape=[
|
| 91 |
+
None,
|
| 92 |
+
get_gru_units(self.model_size),
|
| 93 |
+
],
|
| 94 |
+
dtype=self._comp_dtype,
|
| 95 |
+
),
|
| 96 |
+
"z": tf.TensorSpec(
|
| 97 |
+
shape=[
|
| 98 |
+
None,
|
| 99 |
+
get_num_z_categoricals(self.model_size),
|
| 100 |
+
get_num_z_classes(self.model_size),
|
| 101 |
+
],
|
| 102 |
+
dtype=self._comp_dtype,
|
| 103 |
+
),
|
| 104 |
+
},
|
| 105 |
+
tf.TensorSpec(shape=[None], dtype=tf.bool),
|
| 106 |
+
]
|
| 107 |
+
)(self.dream_trajectory)
|
| 108 |
+
|
| 109 |
+
def call(
|
| 110 |
+
self,
|
| 111 |
+
inputs,
|
| 112 |
+
observations,
|
| 113 |
+
actions,
|
| 114 |
+
is_first,
|
| 115 |
+
start_is_terminated_BxT,
|
| 116 |
+
gamma,
|
| 117 |
+
):
|
| 118 |
+
"""Main call method for building this model in order to generate its variables.
|
| 119 |
+
|
| 120 |
+
Note: This method should NOT be used by users directly. It's purpose is only to
|
| 121 |
+
perform all forward passes necessary to define all variables of the DreamerV3.
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
# Forward passes through all models are enough to build all trainable and
|
| 125 |
+
# non-trainable variables:
|
| 126 |
+
|
| 127 |
+
# World model.
|
| 128 |
+
results = self.world_model.forward_train(
|
| 129 |
+
observations,
|
| 130 |
+
actions,
|
| 131 |
+
is_first,
|
| 132 |
+
)
|
| 133 |
+
# Actor.
|
| 134 |
+
_, distr_params = self.actor(
|
| 135 |
+
h=results["h_states_BxT"],
|
| 136 |
+
z=results["z_posterior_states_BxT"],
|
| 137 |
+
)
|
| 138 |
+
# Critic.
|
| 139 |
+
values, _ = self.critic(
|
| 140 |
+
h=results["h_states_BxT"],
|
| 141 |
+
z=results["z_posterior_states_BxT"],
|
| 142 |
+
use_ema=tf.convert_to_tensor(False),
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# Dream pipeline.
|
| 146 |
+
dream_data = self.dream_trajectory(
|
| 147 |
+
start_states={
|
| 148 |
+
"h": results["h_states_BxT"],
|
| 149 |
+
"z": results["z_posterior_states_BxT"],
|
| 150 |
+
},
|
| 151 |
+
start_is_terminated=start_is_terminated_BxT,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
return {
|
| 155 |
+
"world_model_fwd": results,
|
| 156 |
+
"dream_data": dream_data,
|
| 157 |
+
"actions": actions,
|
| 158 |
+
"values": values,
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
@tf.function
|
| 162 |
+
def forward_inference(self, observations, previous_states, is_first, training=None):
|
| 163 |
+
"""Performs a (non-exploring) action computation step given obs and states.
|
| 164 |
+
|
| 165 |
+
Note that all input data should not have a time rank (only a batch dimension).
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
observations: The current environment observation with shape (B, ...).
|
| 169 |
+
previous_states: Dict with keys `a`, `h`, and `z` used as input to the RSSM
|
| 170 |
+
to produce the next h-state, from which then to compute the action
|
| 171 |
+
using the actor network. All values in the dict should have shape
|
| 172 |
+
(B, ...) (no time rank).
|
| 173 |
+
is_first: Batch of is_first flags. These should be True if a new episode
|
| 174 |
+
has been started at the current timestep (meaning `observations` is the
|
| 175 |
+
reset observation from the environment).
|
| 176 |
+
"""
|
| 177 |
+
# Perform one step in the world model (starting from `previous_state` and
|
| 178 |
+
# using the observations to yield a current (posterior) state).
|
| 179 |
+
states = self.world_model.forward_inference(
|
| 180 |
+
observations=observations,
|
| 181 |
+
previous_states=previous_states,
|
| 182 |
+
is_first=is_first,
|
| 183 |
+
)
|
| 184 |
+
# Compute action using our actor network and the current states.
|
| 185 |
+
_, distr_params = self.actor(h=states["h"], z=states["z"])
|
| 186 |
+
# Use the mode of the distribution (Discrete=argmax, Normal=mean).
|
| 187 |
+
distr = self.actor.get_action_dist_object(distr_params)
|
| 188 |
+
actions = distr.mode()
|
| 189 |
+
return actions, {"h": states["h"], "z": states["z"], "a": actions}
|
| 190 |
+
|
| 191 |
+
@tf.function
|
| 192 |
+
def forward_exploration(
|
| 193 |
+
self, observations, previous_states, is_first, training=None
|
| 194 |
+
):
|
| 195 |
+
"""Performs an exploratory action computation step given obs and states.
|
| 196 |
+
|
| 197 |
+
Note that all input data should not have a time rank (only a batch dimension).
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
observations: The current environment observation with shape (B, ...).
|
| 201 |
+
previous_states: Dict with keys `a`, `h`, and `z` used as input to the RSSM
|
| 202 |
+
to produce the next h-state, from which then to compute the action
|
| 203 |
+
using the actor network. All values in the dict should have shape
|
| 204 |
+
(B, ...) (no time rank).
|
| 205 |
+
is_first: Batch of is_first flags. These should be True if a new episode
|
| 206 |
+
has been started at the current timestep (meaning `observations` is the
|
| 207 |
+
reset observation from the environment).
|
| 208 |
+
"""
|
| 209 |
+
# Perform one step in the world model (starting from `previous_state` and
|
| 210 |
+
# using the observations to yield a current (posterior) state).
|
| 211 |
+
states = self.world_model.forward_inference(
|
| 212 |
+
observations=observations,
|
| 213 |
+
previous_states=previous_states,
|
| 214 |
+
is_first=is_first,
|
| 215 |
+
)
|
| 216 |
+
# Compute action using our actor network and the current states.
|
| 217 |
+
actions, _ = self.actor(h=states["h"], z=states["z"])
|
| 218 |
+
return actions, {"h": states["h"], "z": states["z"], "a": actions}
|
| 219 |
+
|
| 220 |
+
def forward_train(self, observations, actions, is_first):
|
| 221 |
+
"""Performs a training forward pass given observations and actions.
|
| 222 |
+
|
| 223 |
+
Note that all input data must have a time rank (batch-major: [B, T, ...]).
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
observations: The environment observations with shape (B, T, ...). Thus,
|
| 227 |
+
the batch has B rows of T timesteps each. Note that it's ok to have
|
| 228 |
+
episode boundaries (is_first=True) within a batch row. DreamerV3 will
|
| 229 |
+
simply insert an initial state before these locations and continue the
|
| 230 |
+
sequence modelling (with the RSSM). Hence, there will be no zero
|
| 231 |
+
padding.
|
| 232 |
+
actions: The actions actually taken in the environment with shape
|
| 233 |
+
(B, T, ...). See `observations` docstring for details on how B and T are
|
| 234 |
+
handled.
|
| 235 |
+
is_first: Batch of is_first flags. These should be True:
|
| 236 |
+
- if a new episode has been started at the current timestep (meaning
|
| 237 |
+
`observations` is the reset observation from the environment).
|
| 238 |
+
- in each batch row at T=0 (first timestep of each of the B batch
|
| 239 |
+
rows), regardless of whether the actual env had an episode boundary
|
| 240 |
+
there or not.
|
| 241 |
+
"""
|
| 242 |
+
return self.world_model.forward_train(
|
| 243 |
+
observations=observations,
|
| 244 |
+
actions=actions,
|
| 245 |
+
is_first=is_first,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
@tf.function
|
| 249 |
+
def get_initial_state(self):
|
| 250 |
+
"""Returns the (current) initial state of the dreamer model (a, h-, z-states).
|
| 251 |
+
|
| 252 |
+
An initial state is generated using the previous action, the tanh of the
|
| 253 |
+
(learned) h-state variable and the dynamics predictor (or "prior net") to
|
| 254 |
+
compute z^0 from h0. In this last step, it is important that we do NOT sample
|
| 255 |
+
the z^-state (as we would usually do during dreaming), but rather take the mode
|
| 256 |
+
(argmax, then one-hot again).
|
| 257 |
+
"""
|
| 258 |
+
states = self.world_model.get_initial_state()
|
| 259 |
+
|
| 260 |
+
action_dim = (
|
| 261 |
+
self.action_space.n
|
| 262 |
+
if isinstance(self.action_space, gym.spaces.Discrete)
|
| 263 |
+
else np.prod(self.action_space.shape)
|
| 264 |
+
)
|
| 265 |
+
states["a"] = tf.zeros(
|
| 266 |
+
(
|
| 267 |
+
1,
|
| 268 |
+
action_dim,
|
| 269 |
+
),
|
| 270 |
+
dtype=tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32,
|
| 271 |
+
)
|
| 272 |
+
return states
|
| 273 |
+
|
| 274 |
+
def dream_trajectory(self, start_states, start_is_terminated):
|
| 275 |
+
"""Dreams trajectories of length H from batch of h- and z-states.
|
| 276 |
+
|
| 277 |
+
Note that incoming data will have the shapes (BxT, ...), where the original
|
| 278 |
+
batch- and time-dimensions are already folded together. Beginning from this
|
| 279 |
+
new batch dim (BxT), we will unroll `timesteps_H` timesteps in a time-major
|
| 280 |
+
fashion, such that the dreamed data will have shape (H, BxT, ...).
|
| 281 |
+
|
| 282 |
+
Args:
|
| 283 |
+
start_states: Dict of `h` and `z` states in the shape of (B, ...) and
|
| 284 |
+
(B, num_categoricals, num_classes), respectively, as
|
| 285 |
+
computed by a train forward pass. From each individual h-/z-state pair
|
| 286 |
+
in the given batch, we will branch off a dreamed trajectory of len
|
| 287 |
+
`timesteps_H`.
|
| 288 |
+
start_is_terminated: Float flags of shape (B,) indicating whether the
|
| 289 |
+
first timesteps of each batch row is already a terminated timestep
|
| 290 |
+
(given by the actual environment).
|
| 291 |
+
"""
|
| 292 |
+
# Dreamed actions (one-hot encoded for discrete actions).
|
| 293 |
+
a_dreamed_t0_to_H = []
|
| 294 |
+
a_dreamed_dist_params_t0_to_H = []
|
| 295 |
+
|
| 296 |
+
h = start_states["h"]
|
| 297 |
+
z = start_states["z"]
|
| 298 |
+
|
| 299 |
+
# GRU outputs.
|
| 300 |
+
h_states_t0_to_H = [h]
|
| 301 |
+
# Dynamics model outputs.
|
| 302 |
+
z_states_prior_t0_to_H = [z]
|
| 303 |
+
|
| 304 |
+
# Compute `a` using actor network (already the first step uses a dreamed action,
|
| 305 |
+
# not a sampled one).
|
| 306 |
+
a, a_dist_params = self.actor(
|
| 307 |
+
# We have to stop the gradients through the states. B/c we are using a
|
| 308 |
+
# differentiable Discrete action distribution (straight through gradients
|
| 309 |
+
# with `a = stop_gradient(sample(probs)) + probs - stop_gradient(probs)`,
|
| 310 |
+
# we otherwise would add dependencies of the `-log(pi(a|s))` REINFORCE loss
|
| 311 |
+
# term on actions further back in the trajectory.
|
| 312 |
+
h=tf.stop_gradient(h),
|
| 313 |
+
z=tf.stop_gradient(z),
|
| 314 |
+
)
|
| 315 |
+
a_dreamed_t0_to_H.append(a)
|
| 316 |
+
a_dreamed_dist_params_t0_to_H.append(a_dist_params)
|
| 317 |
+
|
| 318 |
+
for i in range(self.horizon):
|
| 319 |
+
# Move one step in the dream using the RSSM.
|
| 320 |
+
h = self.world_model.sequence_model(a=a, h=h, z=z)
|
| 321 |
+
h_states_t0_to_H.append(h)
|
| 322 |
+
|
| 323 |
+
# Compute prior z using dynamics model.
|
| 324 |
+
z, _ = self.world_model.dynamics_predictor(h=h)
|
| 325 |
+
z_states_prior_t0_to_H.append(z)
|
| 326 |
+
|
| 327 |
+
# Compute `a` using actor network.
|
| 328 |
+
a, a_dist_params = self.actor(
|
| 329 |
+
h=tf.stop_gradient(h),
|
| 330 |
+
z=tf.stop_gradient(z),
|
| 331 |
+
)
|
| 332 |
+
a_dreamed_t0_to_H.append(a)
|
| 333 |
+
a_dreamed_dist_params_t0_to_H.append(a_dist_params)
|
| 334 |
+
|
| 335 |
+
h_states_H_B = tf.stack(h_states_t0_to_H, axis=0) # (T, B, ...)
|
| 336 |
+
h_states_HxB = tf.reshape(h_states_H_B, [-1] + h_states_H_B.shape.as_list()[2:])
|
| 337 |
+
|
| 338 |
+
z_states_prior_H_B = tf.stack(z_states_prior_t0_to_H, axis=0) # (T, B, ...)
|
| 339 |
+
z_states_prior_HxB = tf.reshape(
|
| 340 |
+
z_states_prior_H_B, [-1] + z_states_prior_H_B.shape.as_list()[2:]
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
a_dreamed_H_B = tf.stack(a_dreamed_t0_to_H, axis=0) # (T, B, ...)
|
| 344 |
+
a_dreamed_dist_params_H_B = tf.stack(a_dreamed_dist_params_t0_to_H, axis=0)
|
| 345 |
+
|
| 346 |
+
# Compute r using reward predictor.
|
| 347 |
+
r_dreamed_HxB, _ = self.world_model.reward_predictor(
|
| 348 |
+
h=h_states_HxB, z=z_states_prior_HxB
|
| 349 |
+
)
|
| 350 |
+
r_dreamed_H_B = tf.reshape(
|
| 351 |
+
inverse_symlog(r_dreamed_HxB), shape=[self.horizon + 1, -1]
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
# Compute intrinsic rewards.
|
| 355 |
+
if self.use_curiosity:
|
| 356 |
+
results_HxB = self.disagree_nets.compute_intrinsic_rewards(
|
| 357 |
+
h=h_states_HxB,
|
| 358 |
+
z=z_states_prior_HxB,
|
| 359 |
+
a=tf.reshape(a_dreamed_H_B, [-1] + a_dreamed_H_B.shape.as_list()[2:]),
|
| 360 |
+
)
|
| 361 |
+
# TODO (sven): Wrong? -> Cut out last timestep as we always predict z-states
|
| 362 |
+
# for the NEXT timestep and derive ri (for the NEXT timestep) from the
|
| 363 |
+
# disagreement between our N disagreee nets.
|
| 364 |
+
r_intrinsic_H_B = tf.reshape(
|
| 365 |
+
results_HxB["rewards_intrinsic"], shape=[self.horizon + 1, -1]
|
| 366 |
+
)[
|
| 367 |
+
1:
|
| 368 |
+
] # cut out first ts instead
|
| 369 |
+
curiosity_forward_train_outs = results_HxB["forward_train_outs"]
|
| 370 |
+
del results_HxB
|
| 371 |
+
|
| 372 |
+
# Compute continues using continue predictor.
|
| 373 |
+
c_dreamed_HxB, _ = self.world_model.continue_predictor(
|
| 374 |
+
h=h_states_HxB,
|
| 375 |
+
z=z_states_prior_HxB,
|
| 376 |
+
)
|
| 377 |
+
c_dreamed_H_B = tf.reshape(c_dreamed_HxB, [self.horizon + 1, -1])
|
| 378 |
+
# Force-set first `continue` flags to False iff `start_is_terminated`.
|
| 379 |
+
# Note: This will cause the loss-weights for this row in the batch to be
|
| 380 |
+
# completely zero'd out. In general, we don't use dreamed data past any
|
| 381 |
+
# predicted (or actual first) continue=False flags.
|
| 382 |
+
c_dreamed_H_B = tf.concat(
|
| 383 |
+
[
|
| 384 |
+
1.0
|
| 385 |
+
- tf.expand_dims(
|
| 386 |
+
tf.cast(start_is_terminated, tf.float32),
|
| 387 |
+
0,
|
| 388 |
+
),
|
| 389 |
+
c_dreamed_H_B[1:],
|
| 390 |
+
],
|
| 391 |
+
axis=0,
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
# Loss weights for each individual dreamed timestep. Zero-out all timesteps
|
| 395 |
+
# that lie past continue=False flags. B/c our world model does NOT learn how
|
| 396 |
+
# to skip terminal/reset episode boundaries, dreamed data crossing such a
|
| 397 |
+
# boundary should not be used for critic/actor learning either.
|
| 398 |
+
dream_loss_weights_H_B = (
|
| 399 |
+
tf.math.cumprod(self.gamma * c_dreamed_H_B, axis=0) / self.gamma
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
# Compute the value estimates.
|
| 403 |
+
v, v_symlog_dreamed_logits_HxB = self.critic(
|
| 404 |
+
h=h_states_HxB,
|
| 405 |
+
z=z_states_prior_HxB,
|
| 406 |
+
use_ema=False,
|
| 407 |
+
)
|
| 408 |
+
v_dreamed_HxB = inverse_symlog(v)
|
| 409 |
+
v_dreamed_H_B = tf.reshape(v_dreamed_HxB, shape=[self.horizon + 1, -1])
|
| 410 |
+
|
| 411 |
+
v_symlog_dreamed_ema_HxB, _ = self.critic(
|
| 412 |
+
h=h_states_HxB,
|
| 413 |
+
z=z_states_prior_HxB,
|
| 414 |
+
use_ema=True,
|
| 415 |
+
)
|
| 416 |
+
v_symlog_dreamed_ema_H_B = tf.reshape(
|
| 417 |
+
v_symlog_dreamed_ema_HxB, shape=[self.horizon + 1, -1]
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
ret = {
|
| 421 |
+
"h_states_t0_to_H_BxT": h_states_H_B,
|
| 422 |
+
"z_states_prior_t0_to_H_BxT": z_states_prior_H_B,
|
| 423 |
+
"rewards_dreamed_t0_to_H_BxT": r_dreamed_H_B,
|
| 424 |
+
"continues_dreamed_t0_to_H_BxT": c_dreamed_H_B,
|
| 425 |
+
"actions_dreamed_t0_to_H_BxT": a_dreamed_H_B,
|
| 426 |
+
"actions_dreamed_dist_params_t0_to_H_BxT": a_dreamed_dist_params_H_B,
|
| 427 |
+
"values_dreamed_t0_to_H_BxT": v_dreamed_H_B,
|
| 428 |
+
"values_symlog_dreamed_logits_t0_to_HxBxT": v_symlog_dreamed_logits_HxB,
|
| 429 |
+
"v_symlog_dreamed_ema_t0_to_H_BxT": v_symlog_dreamed_ema_H_B,
|
| 430 |
+
# Loss weights for critic- and actor losses.
|
| 431 |
+
"dream_loss_weights_t0_to_H_BxT": dream_loss_weights_H_B,
|
| 432 |
+
}
|
| 433 |
+
|
| 434 |
+
if self.use_curiosity:
|
| 435 |
+
ret["rewards_intrinsic_t1_to_H_B"] = r_intrinsic_H_B
|
| 436 |
+
ret.update(curiosity_forward_train_outs)
|
| 437 |
+
|
| 438 |
+
if isinstance(self.action_space, gym.spaces.Discrete):
|
| 439 |
+
ret["actions_ints_dreamed_t0_to_H_B"] = tf.argmax(a_dreamed_H_B, axis=-1)
|
| 440 |
+
|
| 441 |
+
return ret
|
| 442 |
+
|
| 443 |
+
def dream_trajectory_with_burn_in(
|
| 444 |
+
self,
|
| 445 |
+
*,
|
| 446 |
+
start_states,
|
| 447 |
+
timesteps_burn_in: int,
|
| 448 |
+
timesteps_H: int,
|
| 449 |
+
observations, # [B, >=timesteps_burn_in]
|
| 450 |
+
actions, # [B, timesteps_burn_in (+timesteps_H)?]
|
| 451 |
+
use_sampled_actions_in_dream: bool = False,
|
| 452 |
+
use_random_actions_in_dream: bool = False,
|
| 453 |
+
):
|
| 454 |
+
"""Dreams trajectory from N initial observations and initial states.
|
| 455 |
+
|
| 456 |
+
Note: This is only used for reporting and debugging, not for actual world-model
|
| 457 |
+
or policy training.
|
| 458 |
+
|
| 459 |
+
Args:
|
| 460 |
+
start_states: The batch of start states (dicts with `a`, `h`, and `z` keys)
|
| 461 |
+
to begin dreaming with. These are used to compute the first h-state
|
| 462 |
+
using the sequence model.
|
| 463 |
+
timesteps_burn_in: For how many timesteps should be use the posterior
|
| 464 |
+
z-states (computed by the posterior net and actual observations from
|
| 465 |
+
the env)?
|
| 466 |
+
timesteps_H: For how many timesteps should we dream using the prior
|
| 467 |
+
z-states (computed by the dynamics (prior) net and h-states only)?
|
| 468 |
+
Note that the total length of the returned trajectories will
|
| 469 |
+
be `timesteps_burn_in` + `timesteps_H`.
|
| 470 |
+
observations: The batch (B, T, ...) of observations (to be used only during
|
| 471 |
+
burn-in over `timesteps_burn_in` timesteps).
|
| 472 |
+
actions: The batch (B, T, ...) of actions to use during a) burn-in over the
|
| 473 |
+
first `timesteps_burn_in` timesteps and - possibly - b) during
|
| 474 |
+
actual dreaming, iff use_sampled_actions_in_dream=True.
|
| 475 |
+
If applicable, actions must already be one-hot'd.
|
| 476 |
+
use_sampled_actions_in_dream: If True, instead of using our actor network
|
| 477 |
+
to compute fresh actions, we will use the one provided via the `actions`
|
| 478 |
+
argument. Note that in the latter case, the `actions` time dimension
|
| 479 |
+
must be at least `timesteps_burn_in` + `timesteps_H` long.
|
| 480 |
+
use_random_actions_in_dream: Whether to use randomly sampled actions in the
|
| 481 |
+
dream. Note that this does not apply to the burn-in phase, during which
|
| 482 |
+
we will always use the actions given in the `actions` argument.
|
| 483 |
+
"""
|
| 484 |
+
assert not (use_sampled_actions_in_dream and use_random_actions_in_dream)
|
| 485 |
+
|
| 486 |
+
B = observations.shape[0]
|
| 487 |
+
|
| 488 |
+
# Produce initial N internal posterior states (burn-in) using the given
|
| 489 |
+
# observations:
|
| 490 |
+
states = start_states
|
| 491 |
+
for i in range(timesteps_burn_in):
|
| 492 |
+
states = self.world_model.forward_inference(
|
| 493 |
+
observations=observations[:, i],
|
| 494 |
+
previous_states=states,
|
| 495 |
+
is_first=tf.fill((B,), 1.0 if i == 0 else 0.0),
|
| 496 |
+
)
|
| 497 |
+
states["a"] = actions[:, i]
|
| 498 |
+
|
| 499 |
+
# Start producing the actual dream, using prior states and either the given
|
| 500 |
+
# actions, dreamed, or random ones.
|
| 501 |
+
h_states_t0_to_H = [states["h"]]
|
| 502 |
+
z_states_prior_t0_to_H = [states["z"]]
|
| 503 |
+
a_t0_to_H = [states["a"]]
|
| 504 |
+
|
| 505 |
+
for j in range(timesteps_H):
|
| 506 |
+
# Compute next h using sequence model.
|
| 507 |
+
h = self.world_model.sequence_model(
|
| 508 |
+
a=states["a"],
|
| 509 |
+
h=states["h"],
|
| 510 |
+
z=states["z"],
|
| 511 |
+
)
|
| 512 |
+
h_states_t0_to_H.append(h)
|
| 513 |
+
# Compute z from h, using the dynamics model (we don't have an actual
|
| 514 |
+
# observation at this timestep).
|
| 515 |
+
z, _ = self.world_model.dynamics_predictor(h=h)
|
| 516 |
+
z_states_prior_t0_to_H.append(z)
|
| 517 |
+
|
| 518 |
+
# Compute next dreamed action or use sampled one or random one.
|
| 519 |
+
if use_sampled_actions_in_dream:
|
| 520 |
+
a = actions[:, timesteps_burn_in + j]
|
| 521 |
+
elif use_random_actions_in_dream:
|
| 522 |
+
if isinstance(self.action_space, gym.spaces.Discrete):
|
| 523 |
+
a = tf.random.randint((B,), 0, self.action_space.n, tf.int64)
|
| 524 |
+
a = tf.one_hot(
|
| 525 |
+
a,
|
| 526 |
+
depth=self.action_space.n,
|
| 527 |
+
dtype=tf.keras.mixed_precision.global_policy().compute_dtype
|
| 528 |
+
or tf.float32,
|
| 529 |
+
)
|
| 530 |
+
# TODO: Support cont. action spaces with bound other than 0.0 and 1.0.
|
| 531 |
+
else:
|
| 532 |
+
a = tf.random.uniform(
|
| 533 |
+
shape=(B,) + self.action_space.shape,
|
| 534 |
+
dtype=self.action_space.dtype,
|
| 535 |
+
)
|
| 536 |
+
else:
|
| 537 |
+
a, _ = self.actor(h=h, z=z)
|
| 538 |
+
a_t0_to_H.append(a)
|
| 539 |
+
|
| 540 |
+
states = {"h": h, "z": z, "a": a}
|
| 541 |
+
|
| 542 |
+
# Fold time-rank for upcoming batch-predictions (no sequences needed anymore).
|
| 543 |
+
h_states_t0_to_H_B = tf.stack(h_states_t0_to_H, axis=0)
|
| 544 |
+
h_states_t0_to_HxB = tf.reshape(
|
| 545 |
+
h_states_t0_to_H_B, shape=[-1] + h_states_t0_to_H_B.shape.as_list()[2:]
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
z_states_prior_t0_to_H_B = tf.stack(z_states_prior_t0_to_H, axis=0)
|
| 549 |
+
z_states_prior_t0_to_HxB = tf.reshape(
|
| 550 |
+
z_states_prior_t0_to_H_B,
|
| 551 |
+
shape=[-1] + z_states_prior_t0_to_H_B.shape.as_list()[2:],
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
a_t0_to_H_B = tf.stack(a_t0_to_H, axis=0)
|
| 555 |
+
|
| 556 |
+
# Compute o using decoder.
|
| 557 |
+
o_dreamed_t0_to_HxB = self.world_model.decoder(
|
| 558 |
+
h=h_states_t0_to_HxB,
|
| 559 |
+
z=z_states_prior_t0_to_HxB,
|
| 560 |
+
)
|
| 561 |
+
if self.world_model.symlog_obs:
|
| 562 |
+
o_dreamed_t0_to_HxB = inverse_symlog(o_dreamed_t0_to_HxB)
|
| 563 |
+
|
| 564 |
+
# Compute r using reward predictor.
|
| 565 |
+
r_dreamed_t0_to_HxB, _ = self.world_model.reward_predictor(
|
| 566 |
+
h=h_states_t0_to_HxB,
|
| 567 |
+
z=z_states_prior_t0_to_HxB,
|
| 568 |
+
)
|
| 569 |
+
r_dreamed_t0_to_HxB = inverse_symlog(r_dreamed_t0_to_HxB)
|
| 570 |
+
# Compute continues using continue predictor.
|
| 571 |
+
c_dreamed_t0_to_HxB, _ = self.world_model.continue_predictor(
|
| 572 |
+
h=h_states_t0_to_HxB,
|
| 573 |
+
z=z_states_prior_t0_to_HxB,
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
# Return everything as time-major (H, B, ...), where H is the timesteps dreamed
|
| 577 |
+
# (NOT burn-in'd) and B is a batch dimension (this might or might not include
|
| 578 |
+
# an original time dimension from the real env, from all of which we then branch
|
| 579 |
+
# out our dream trajectories).
|
| 580 |
+
ret = {
|
| 581 |
+
"h_states_t0_to_H_BxT": h_states_t0_to_H_B,
|
| 582 |
+
"z_states_prior_t0_to_H_BxT": z_states_prior_t0_to_H_B,
|
| 583 |
+
# Unfold time-ranks in predictions.
|
| 584 |
+
"observations_dreamed_t0_to_H_BxT": tf.reshape(
|
| 585 |
+
o_dreamed_t0_to_HxB, [-1, B] + list(observations.shape)[2:]
|
| 586 |
+
),
|
| 587 |
+
"rewards_dreamed_t0_to_H_BxT": tf.reshape(r_dreamed_t0_to_HxB, (-1, B)),
|
| 588 |
+
"continues_dreamed_t0_to_H_BxT": tf.reshape(c_dreamed_t0_to_HxB, (-1, B)),
|
| 589 |
+
}
|
| 590 |
+
|
| 591 |
+
# Figure out action key (random, sampled from env, dreamed?).
|
| 592 |
+
if use_sampled_actions_in_dream:
|
| 593 |
+
key = "actions_sampled_t0_to_H_BxT"
|
| 594 |
+
elif use_random_actions_in_dream:
|
| 595 |
+
key = "actions_random_t0_to_H_BxT"
|
| 596 |
+
else:
|
| 597 |
+
key = "actions_dreamed_t0_to_H_BxT"
|
| 598 |
+
ret[key] = a_t0_to_H_B
|
| 599 |
+
|
| 600 |
+
# Also provide int-actions, if discrete action space.
|
| 601 |
+
if isinstance(self.action_space, gym.spaces.Discrete):
|
| 602 |
+
ret[re.sub("^actions_", "actions_ints_", key)] = tf.argmax(
|
| 603 |
+
a_t0_to_H_B, axis=-1
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
return ret
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/world_model.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
[1] Mastering Diverse Domains through World Models - 2023
|
| 3 |
+
D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
|
| 4 |
+
https://arxiv.org/pdf/2301.04104v1.pdf
|
| 5 |
+
"""
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
import gymnasium as gym
|
| 9 |
+
import tree # pip install dm_tree
|
| 10 |
+
|
| 11 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.components.continue_predictor import (
|
| 12 |
+
ContinuePredictor,
|
| 13 |
+
)
|
| 14 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.components.dynamics_predictor import (
|
| 15 |
+
DynamicsPredictor,
|
| 16 |
+
)
|
| 17 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP
|
| 18 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.components.representation_layer import (
|
| 19 |
+
RepresentationLayer,
|
| 20 |
+
)
|
| 21 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.components.reward_predictor import (
|
| 22 |
+
RewardPredictor,
|
| 23 |
+
)
|
| 24 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.components.sequence_model import (
|
| 25 |
+
SequenceModel,
|
| 26 |
+
)
|
| 27 |
+
from ray.rllib.algorithms.dreamerv3.utils import get_gru_units
|
| 28 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 29 |
+
from ray.rllib.utils.tf_utils import symlog
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
_, tf, _ = try_import_tf()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class WorldModel(tf.keras.Model):
|
| 36 |
+
"""WorldModel component of [1] w/ encoder, decoder, RSSM, reward/cont. predictors.
|
| 37 |
+
|
| 38 |
+
See eq. 3 of [1] for all components and their respective in- and outputs.
|
| 39 |
+
Note that in the paper, the "encoder" includes both the raw encoder plus the
|
| 40 |
+
"posterior net", which produces posterior z-states from observations and h-states.
|
| 41 |
+
|
| 42 |
+
Note: The "internal state" of the world model always consists of:
|
| 43 |
+
The actions `a` (initially, this is a zeroed-out action), `h`-states (deterministic,
|
| 44 |
+
continuous), and `z`-states (stochastic, discrete).
|
| 45 |
+
There are two versions of z-states: "posterior" for world model training and "prior"
|
| 46 |
+
for creating the dream data.
|
| 47 |
+
|
| 48 |
+
Initial internal state values (`a`, `h`, and `z`) are inserted where ever a new
|
| 49 |
+
episode starts within a batch row OR at the beginning of each train batch's B rows,
|
| 50 |
+
regardless of whether there was an actual episode boundary or not. Thus, internal
|
| 51 |
+
states are not required to be stored in or retrieved from the replay buffer AND
|
| 52 |
+
retrieved batches from the buffer must not be zero padded.
|
| 53 |
+
|
| 54 |
+
Initial `a` is the zero "one hot" action, e.g. [0.0, 0.0] for Discrete(2), initial
|
| 55 |
+
`h` is a separate learned variable, and initial `z` are computed by the "dynamics"
|
| 56 |
+
(or "prior") net, using only the initial-h state as input.
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
*,
|
| 62 |
+
model_size: str = "XS",
|
| 63 |
+
observation_space: gym.Space,
|
| 64 |
+
action_space: gym.Space,
|
| 65 |
+
batch_length_T: int = 64,
|
| 66 |
+
encoder: tf.keras.Model,
|
| 67 |
+
decoder: tf.keras.Model,
|
| 68 |
+
num_gru_units: Optional[int] = None,
|
| 69 |
+
symlog_obs: bool = True,
|
| 70 |
+
):
|
| 71 |
+
"""Initializes a WorldModel instance.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
model_size: The "Model Size" used according to [1] Appendinx B.
|
| 75 |
+
Use None for manually setting the different network sizes.
|
| 76 |
+
observation_space: The observation space of the environment used.
|
| 77 |
+
action_space: The action space of the environment used.
|
| 78 |
+
batch_length_T: The length (T) of the sequences used for training. The
|
| 79 |
+
actual shape of the input data (e.g. rewards) is then: [B, T, ...],
|
| 80 |
+
where B is the "batch size", T is the "batch length" (this arg) and
|
| 81 |
+
"..." is the dimension of the data (e.g. (64, 64, 3) for Atari image
|
| 82 |
+
observations). Note that a single row (within a batch) may contain data
|
| 83 |
+
from different episodes, but an already on-going episode is always
|
| 84 |
+
finished, before a new one starts within the same row.
|
| 85 |
+
encoder: The encoder Model taking observations as inputs and
|
| 86 |
+
outputting a 1D latent vector that will be used as input into the
|
| 87 |
+
posterior net (z-posterior state generating layer). Inputs are symlogged
|
| 88 |
+
if inputs are NOT images. For images, we use normalization between -1.0
|
| 89 |
+
and 1.0 (x / 128 - 1.0)
|
| 90 |
+
decoder: The decoder Model taking h- and z-states as inputs and generating
|
| 91 |
+
a (possibly symlogged) predicted observation. Note that for images,
|
| 92 |
+
the last decoder layer produces the exact, normalized pixel values
|
| 93 |
+
(not a Gaussian as described in [1]!).
|
| 94 |
+
num_gru_units: The number of GRU units to use. If None, use
|
| 95 |
+
`model_size` to figure out this parameter.
|
| 96 |
+
symlog_obs: Whether to predict decoded observations in symlog space.
|
| 97 |
+
This should be False for image based observations.
|
| 98 |
+
According to the paper [1] Appendix E: "NoObsSymlog: This ablation
|
| 99 |
+
removes the symlog encoding of inputs to the world model and also
|
| 100 |
+
changes the symlog MSE loss in the decoder to a simple MSE loss.
|
| 101 |
+
*Because symlog encoding is only used for vector observations*, this
|
| 102 |
+
ablation is equivalent to DreamerV3 on purely image-based environments".
|
| 103 |
+
"""
|
| 104 |
+
super().__init__(name="world_model")
|
| 105 |
+
|
| 106 |
+
self.model_size = model_size
|
| 107 |
+
self.batch_length_T = batch_length_T
|
| 108 |
+
self.symlog_obs = symlog_obs
|
| 109 |
+
self.observation_space = observation_space
|
| 110 |
+
self.action_space = action_space
|
| 111 |
+
self._comp_dtype = (
|
| 112 |
+
tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
# Encoder (latent 1D vector generator) (xt -> lt).
|
| 116 |
+
self.encoder = encoder
|
| 117 |
+
|
| 118 |
+
# Posterior predictor consisting of an MLP and a RepresentationLayer:
|
| 119 |
+
# [ht, lt] -> zt.
|
| 120 |
+
self.posterior_mlp = MLP(
|
| 121 |
+
model_size=self.model_size,
|
| 122 |
+
output_layer_size=None,
|
| 123 |
+
# In Danijar's code, the posterior predictor only has a single layer,
|
| 124 |
+
# no matter the model size:
|
| 125 |
+
num_dense_layers=1,
|
| 126 |
+
name="posterior_mlp",
|
| 127 |
+
)
|
| 128 |
+
# The (posterior) z-state generating layer.
|
| 129 |
+
self.posterior_representation_layer = RepresentationLayer(
|
| 130 |
+
model_size=self.model_size,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# Dynamics (prior z-state) predictor: ht -> z^t
|
| 134 |
+
self.dynamics_predictor = DynamicsPredictor(model_size=self.model_size)
|
| 135 |
+
|
| 136 |
+
# GRU for the RSSM: [at, ht, zt] -> ht+1
|
| 137 |
+
self.num_gru_units = get_gru_units(
|
| 138 |
+
model_size=self.model_size,
|
| 139 |
+
override=num_gru_units,
|
| 140 |
+
)
|
| 141 |
+
# Initial h-state variable (learnt).
|
| 142 |
+
# -> tanh(self.initial_h) -> deterministic state
|
| 143 |
+
# Use our Dynamics predictor for initial stochastic state, BUT with greedy
|
| 144 |
+
# (mode) instead of sampling.
|
| 145 |
+
self.initial_h = tf.Variable(
|
| 146 |
+
tf.zeros(shape=(self.num_gru_units,)),
|
| 147 |
+
trainable=True,
|
| 148 |
+
name="initial_h",
|
| 149 |
+
)
|
| 150 |
+
# The actual sequence model containing the GRU layer.
|
| 151 |
+
self.sequence_model = SequenceModel(
|
| 152 |
+
model_size=self.model_size,
|
| 153 |
+
action_space=self.action_space,
|
| 154 |
+
num_gru_units=self.num_gru_units,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
# Reward Predictor: [ht, zt] -> rt.
|
| 158 |
+
self.reward_predictor = RewardPredictor(model_size=self.model_size)
|
| 159 |
+
# Continue Predictor: [ht, zt] -> ct.
|
| 160 |
+
self.continue_predictor = ContinuePredictor(model_size=self.model_size)
|
| 161 |
+
|
| 162 |
+
# Decoder: [ht, zt] -> x^t.
|
| 163 |
+
self.decoder = decoder
|
| 164 |
+
|
| 165 |
+
# Trace self.call.
|
| 166 |
+
self.forward_train = tf.function(
|
| 167 |
+
input_signature=[
|
| 168 |
+
tf.TensorSpec(shape=[None, None] + list(self.observation_space.shape)),
|
| 169 |
+
tf.TensorSpec(
|
| 170 |
+
shape=[None, None]
|
| 171 |
+
+ (
|
| 172 |
+
[self.action_space.n]
|
| 173 |
+
if isinstance(action_space, gym.spaces.Discrete)
|
| 174 |
+
else list(self.action_space.shape)
|
| 175 |
+
)
|
| 176 |
+
),
|
| 177 |
+
tf.TensorSpec(shape=[None, None], dtype=tf.bool),
|
| 178 |
+
]
|
| 179 |
+
)(self.forward_train)
|
| 180 |
+
|
| 181 |
+
@tf.function
|
| 182 |
+
def get_initial_state(self):
|
| 183 |
+
"""Returns the (current) initial state of the world model (h- and z-states).
|
| 184 |
+
|
| 185 |
+
An initial state is generated using the tanh of the (learned) h-state variable
|
| 186 |
+
and the dynamics predictor (or "prior net") to compute z^0 from h0. In this last
|
| 187 |
+
step, it is important that we do NOT sample the z^-state (as we would usually
|
| 188 |
+
do during dreaming), but rather take the mode (argmax, then one-hot again).
|
| 189 |
+
"""
|
| 190 |
+
h = tf.expand_dims(tf.math.tanh(tf.cast(self.initial_h, self._comp_dtype)), 0)
|
| 191 |
+
# Use the mode, NOT a sample for the initial z-state.
|
| 192 |
+
_, z_probs = self.dynamics_predictor(h)
|
| 193 |
+
z = tf.argmax(z_probs, axis=-1)
|
| 194 |
+
z = tf.one_hot(z, depth=z_probs.shape[-1], dtype=self._comp_dtype)
|
| 195 |
+
|
| 196 |
+
return {"h": h, "z": z}
|
| 197 |
+
|
| 198 |
+
def forward_inference(self, observations, previous_states, is_first, training=None):
|
| 199 |
+
"""Performs a forward step for inference (e.g. environment stepping).
|
| 200 |
+
|
| 201 |
+
Works analogous to `forward_train`, except that all inputs are provided
|
| 202 |
+
for a single timestep in the shape of [B, ...] (no time dimension!).
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
observations: The batch (B, ...) of observations to be passed through
|
| 206 |
+
the encoder network to yield the inputs to the representation layer
|
| 207 |
+
(which then can compute the z-states).
|
| 208 |
+
previous_states: A dict with `h`, `z`, and `a` keys mapping to the
|
| 209 |
+
respective previous states/actions. All of the shape (B, ...), no time
|
| 210 |
+
rank.
|
| 211 |
+
is_first: The batch (B) of `is_first` flags.
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
The next deterministic h-state (h(t+1)) as predicted by the sequence model.
|
| 215 |
+
"""
|
| 216 |
+
observations = tf.cast(observations, self._comp_dtype)
|
| 217 |
+
|
| 218 |
+
initial_states = tree.map_structure(
|
| 219 |
+
lambda s: tf.repeat(s, tf.shape(observations)[0], axis=0),
|
| 220 |
+
self.get_initial_state(),
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# If first, mask it with initial state/actions.
|
| 224 |
+
previous_h = self._mask(previous_states["h"], 1.0 - is_first) # zero out
|
| 225 |
+
previous_h = previous_h + self._mask(initial_states["h"], is_first) # add init
|
| 226 |
+
|
| 227 |
+
previous_z = self._mask(previous_states["z"], 1.0 - is_first) # zero out
|
| 228 |
+
previous_z = previous_z + self._mask(initial_states["z"], is_first) # add init
|
| 229 |
+
|
| 230 |
+
# Zero out actions (no special learnt initial state).
|
| 231 |
+
previous_a = self._mask(previous_states["a"], 1.0 - is_first)
|
| 232 |
+
|
| 233 |
+
# Compute new states.
|
| 234 |
+
h = self.sequence_model(a=previous_a, h=previous_h, z=previous_z)
|
| 235 |
+
z = self.compute_posterior_z(observations=observations, initial_h=h)
|
| 236 |
+
|
| 237 |
+
return {"h": h, "z": z}
|
| 238 |
+
|
| 239 |
+
def forward_train(self, observations, actions, is_first):
|
| 240 |
+
"""Performs a forward step for training.
|
| 241 |
+
|
| 242 |
+
1) Forwards all observations [B, T, ...] through the encoder network to yield
|
| 243 |
+
o_processed[B, T, ...].
|
| 244 |
+
2) Uses initial state (h0/z^0/a0[B, 0, ...]) and sequence model (RSSM) to
|
| 245 |
+
compute the first internal state (h1 and z^1).
|
| 246 |
+
3) Uses action a[B, 1, ...], z[B, 1, ...] and h[B, 1, ...] to compute the
|
| 247 |
+
next h-state (h[B, 2, ...]), etc..
|
| 248 |
+
4) Repeats 2) and 3) until t=T.
|
| 249 |
+
5) Uses all h[B, T, ...] and z[B, T, ...] to compute predicted/reconstructed
|
| 250 |
+
observations, rewards, and continue signals.
|
| 251 |
+
6) Returns predictions from 5) along with all z-states z[B, T, ...] and
|
| 252 |
+
the final h-state (h[B, ...] for t=T).
|
| 253 |
+
|
| 254 |
+
Should we encounter is_first=True flags in the middle of a batch row (somewhere
|
| 255 |
+
within an ongoing sequence of length T), we insert this world model's initial
|
| 256 |
+
state again (zero-action, learned init h-state, and prior-computed z^) and
|
| 257 |
+
simply continue (no zero-padding).
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
observations: The batch (B, T, ...) of observations to be passed through
|
| 261 |
+
the encoder network to yield the inputs to the representation layer
|
| 262 |
+
(which then can compute the posterior z-states).
|
| 263 |
+
actions: The batch (B, T, ...) of actions to be used in combination with
|
| 264 |
+
h-states and computed z-states to yield the next h-states.
|
| 265 |
+
is_first: The batch (B, T) of `is_first` flags.
|
| 266 |
+
"""
|
| 267 |
+
if self.symlog_obs:
|
| 268 |
+
observations = symlog(observations)
|
| 269 |
+
|
| 270 |
+
# Compute bare encoder outs (not z; this is done later with involvement of the
|
| 271 |
+
# sequence model and the h-states).
|
| 272 |
+
# Fold time dimension for CNN pass.
|
| 273 |
+
shape = tf.shape(observations)
|
| 274 |
+
B, T = shape[0], shape[1]
|
| 275 |
+
observations = tf.reshape(
|
| 276 |
+
observations, shape=tf.concat([[-1], shape[2:]], axis=0)
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
encoder_out = self.encoder(tf.cast(observations, self._comp_dtype))
|
| 280 |
+
# Unfold time dimension.
|
| 281 |
+
encoder_out = tf.reshape(
|
| 282 |
+
encoder_out, shape=tf.concat([[B, T], tf.shape(encoder_out)[1:]], axis=0)
|
| 283 |
+
)
|
| 284 |
+
# Make time major for faster upcoming loop.
|
| 285 |
+
encoder_out = tf.transpose(
|
| 286 |
+
encoder_out, perm=[1, 0] + list(range(2, len(encoder_out.shape.as_list())))
|
| 287 |
+
)
|
| 288 |
+
# encoder_out=[T, B, ...]
|
| 289 |
+
|
| 290 |
+
initial_states = tree.map_structure(
|
| 291 |
+
lambda s: tf.repeat(s, B, axis=0), self.get_initial_state()
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
# Make actions and `is_first` time-major.
|
| 295 |
+
actions = tf.transpose(
|
| 296 |
+
tf.cast(actions, self._comp_dtype),
|
| 297 |
+
perm=[1, 0] + list(range(2, tf.shape(actions).shape.as_list()[0])),
|
| 298 |
+
)
|
| 299 |
+
is_first = tf.transpose(tf.cast(is_first, self._comp_dtype), perm=[1, 0])
|
| 300 |
+
|
| 301 |
+
# Loop through the T-axis of our samples and perform one computation step at
|
| 302 |
+
# a time. This is necessary because the sequence model's output (h(t+1)) depends
|
| 303 |
+
# on the current z(t), but z(t) depends on the current sequence model's output
|
| 304 |
+
# h(t).
|
| 305 |
+
z_t0_to_T = [initial_states["z"]]
|
| 306 |
+
z_posterior_probs = []
|
| 307 |
+
z_prior_probs = []
|
| 308 |
+
h_t0_to_T = [initial_states["h"]]
|
| 309 |
+
for t in range(self.batch_length_T):
|
| 310 |
+
# If first, mask it with initial state/actions.
|
| 311 |
+
h_tm1 = self._mask(h_t0_to_T[-1], 1.0 - is_first[t]) # zero out
|
| 312 |
+
h_tm1 = h_tm1 + self._mask(initial_states["h"], is_first[t]) # add init
|
| 313 |
+
|
| 314 |
+
z_tm1 = self._mask(z_t0_to_T[-1], 1.0 - is_first[t]) # zero out
|
| 315 |
+
z_tm1 = z_tm1 + self._mask(initial_states["z"], is_first[t]) # add init
|
| 316 |
+
|
| 317 |
+
# Zero out actions (no special learnt initial state).
|
| 318 |
+
a_tm1 = self._mask(actions[t - 1], 1.0 - is_first[t])
|
| 319 |
+
|
| 320 |
+
# Perform one RSSM (sequence model) step to get the current h.
|
| 321 |
+
h_t = self.sequence_model(a=a_tm1, h=h_tm1, z=z_tm1)
|
| 322 |
+
h_t0_to_T.append(h_t)
|
| 323 |
+
|
| 324 |
+
posterior_mlp_input = tf.concat([encoder_out[t], h_t], axis=-1)
|
| 325 |
+
repr_input = self.posterior_mlp(posterior_mlp_input)
|
| 326 |
+
# Draw one z-sample (z(t)) and also get the z-distribution for dynamics and
|
| 327 |
+
# representation loss computations.
|
| 328 |
+
z_t, z_probs = self.posterior_representation_layer(repr_input)
|
| 329 |
+
# z_t=[B, num_categoricals, num_classes]
|
| 330 |
+
z_posterior_probs.append(z_probs)
|
| 331 |
+
z_t0_to_T.append(z_t)
|
| 332 |
+
|
| 333 |
+
# Compute the predicted z_t (z^) using the dynamics model.
|
| 334 |
+
_, z_probs = self.dynamics_predictor(h_t)
|
| 335 |
+
z_prior_probs.append(z_probs)
|
| 336 |
+
|
| 337 |
+
# Stack at time dimension to yield: [B, T, ...].
|
| 338 |
+
h_t1_to_T = tf.stack(h_t0_to_T[1:], axis=1)
|
| 339 |
+
z_t1_to_T = tf.stack(z_t0_to_T[1:], axis=1)
|
| 340 |
+
|
| 341 |
+
# Fold time axis to retrieve the final (loss ready) Independent distribution
|
| 342 |
+
# (over `num_categoricals` Categoricals).
|
| 343 |
+
z_posterior_probs = tf.stack(z_posterior_probs, axis=1)
|
| 344 |
+
z_posterior_probs = tf.reshape(
|
| 345 |
+
z_posterior_probs,
|
| 346 |
+
shape=[-1] + z_posterior_probs.shape.as_list()[2:],
|
| 347 |
+
)
|
| 348 |
+
# Fold time axis to retrieve the final (loss ready) Independent distribution
|
| 349 |
+
# (over `num_categoricals` Categoricals).
|
| 350 |
+
z_prior_probs = tf.stack(z_prior_probs, axis=1)
|
| 351 |
+
z_prior_probs = tf.reshape(
|
| 352 |
+
z_prior_probs,
|
| 353 |
+
shape=[-1] + z_prior_probs.shape.as_list()[2:],
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
# Fold time dimension for parallelization of all dependent predictions:
|
| 357 |
+
# observations (reproduction via decoder), rewards, continues.
|
| 358 |
+
h_BxT = tf.reshape(h_t1_to_T, shape=[-1] + h_t1_to_T.shape.as_list()[2:])
|
| 359 |
+
z_BxT = tf.reshape(z_t1_to_T, shape=[-1] + z_t1_to_T.shape.as_list()[2:])
|
| 360 |
+
|
| 361 |
+
obs_distribution_means = tf.cast(self.decoder(h=h_BxT, z=z_BxT), tf.float32)
|
| 362 |
+
|
| 363 |
+
# Compute (predicted) reward distributions.
|
| 364 |
+
rewards, reward_logits = self.reward_predictor(h=h_BxT, z=z_BxT)
|
| 365 |
+
|
| 366 |
+
# Compute (predicted) continue distributions.
|
| 367 |
+
continues, continue_distribution = self.continue_predictor(h=h_BxT, z=z_BxT)
|
| 368 |
+
|
| 369 |
+
# Return outputs for loss computation.
|
| 370 |
+
# Note that all shapes are [BxT, ...] (time axis already folded).
|
| 371 |
+
return {
|
| 372 |
+
# Obs.
|
| 373 |
+
"sampled_obs_symlog_BxT": observations,
|
| 374 |
+
"obs_distribution_means_BxT": obs_distribution_means,
|
| 375 |
+
# Rewards.
|
| 376 |
+
"reward_logits_BxT": reward_logits,
|
| 377 |
+
"rewards_BxT": rewards,
|
| 378 |
+
# Continues.
|
| 379 |
+
"continue_distribution_BxT": continue_distribution,
|
| 380 |
+
"continues_BxT": continues,
|
| 381 |
+
# Deterministic, continuous h-states (t1 to T).
|
| 382 |
+
"h_states_BxT": h_BxT,
|
| 383 |
+
# Sampled, discrete posterior z-states and their probs (t1 to T).
|
| 384 |
+
"z_posterior_states_BxT": z_BxT,
|
| 385 |
+
"z_posterior_probs_BxT": z_posterior_probs,
|
| 386 |
+
# Probs of the prior z-states (t1 to T).
|
| 387 |
+
"z_prior_probs_BxT": z_prior_probs,
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
def compute_posterior_z(self, observations, initial_h):
|
| 391 |
+
# Compute bare encoder outputs (not including z, which is computed in next step
|
| 392 |
+
# with involvement of the previous output (initial_h) of the sequence model).
|
| 393 |
+
# encoder_outs=[B, ...]
|
| 394 |
+
if self.symlog_obs:
|
| 395 |
+
observations = symlog(observations)
|
| 396 |
+
encoder_out = self.encoder(observations)
|
| 397 |
+
# Concat encoder outs with the h-states.
|
| 398 |
+
posterior_mlp_input = tf.concat([encoder_out, initial_h], axis=-1)
|
| 399 |
+
# Compute z.
|
| 400 |
+
repr_input = self.posterior_mlp(posterior_mlp_input)
|
| 401 |
+
# Draw a z-sample.
|
| 402 |
+
z_t, _ = self.posterior_representation_layer(repr_input)
|
| 403 |
+
return z_t
|
| 404 |
+
|
| 405 |
+
@staticmethod
|
| 406 |
+
def _mask(value, mask):
|
| 407 |
+
return tf.einsum("b...,b->b...", value, tf.cast(mask, value.dtype))
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.algorithms.ppo.ppo import PPOConfig, PPO
|
| 2 |
+
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy, PPOTF2Policy
|
| 3 |
+
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"PPO",
|
| 7 |
+
"PPOConfig",
|
| 8 |
+
# @OldAPIStack
|
| 9 |
+
"PPOTF1Policy",
|
| 10 |
+
"PPOTF2Policy",
|
| 11 |
+
"PPOTorchPolicy",
|
| 12 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (568 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/default_ppo_rl_module.cpython-311.pyc
ADDED
|
Binary file (3.41 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo.cpython-311.pyc
ADDED
|
Binary file (23.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/__pycache__/ppo_catalog.cpython-311.pyc
ADDED
|
Binary file (8.73 kB). View file
|
|
|