Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .venv/lib/python3.11/site-packages/ray/_private/thirdparty/pynvml/__pycache__/pynvml.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__init__.py +12 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo_learner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo_rl_module.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo_tf_policy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo_torch_policy.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/default_appo_rl_module.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo.py +434 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo_learner.py +147 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo_rl_module.py +11 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo_tf_policy.py +393 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo_torch_policy.py +412 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/default_appo_rl_module.py +59 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__pycache__/appo_torch_learner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__pycache__/appo_torch_rl_module.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__pycache__/default_appo_torch_rl_module.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/appo_torch_learner.py +234 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/appo_torch_rl_module.py +13 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/default_appo_torch_rl_module.py +10 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/utils.py +133 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/dreamerv3.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/dreamerv3_catalog.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/dreamerv3_learner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/dreamerv3_rl_module.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/actor_network.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/critic_network.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/disagree_networks.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/dreamer_model.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/world_model.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/cnn_atari.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/continue_predictor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/conv_transpose_atari.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/dynamics_predictor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/mlp.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/representation_layer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/reward_predictor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/reward_predictor_layer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/sequence_model.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/vector_decoder.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/continue_predictor.py +94 -0
.gitattributes
CHANGED
|
@@ -175,3 +175,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 175 |
.venv/lib/python3.11/site-packages/ray/jars/ray_dist.jar filter=lfs diff=lfs merge=lfs -text
|
| 176 |
.venv/lib/python3.11/site-packages/ray/_private/__pycache__/worker.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 177 |
.venv/lib/python3.11/site-packages/ray/_private/__pycache__/test_utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 175 |
.venv/lib/python3.11/site-packages/ray/jars/ray_dist.jar filter=lfs diff=lfs merge=lfs -text
|
| 176 |
.venv/lib/python3.11/site-packages/ray/_private/__pycache__/worker.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 177 |
.venv/lib/python3.11/site-packages/ray/_private/__pycache__/test_utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 178 |
+
.venv/lib/python3.11/site-packages/ray/_private/thirdparty/pynvml/__pycache__/pynvml.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/ray/_private/thirdparty/pynvml/__pycache__/pynvml.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ebb34d8a5e73fa6657fb50dde3c5afc10ca55bef89431f9fbe15555295f4da0e
|
| 3 |
+
size 168124
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.algorithms.appo.appo import APPO, APPOConfig
|
| 2 |
+
from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF1Policy, APPOTF2Policy
|
| 3 |
+
from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"APPO",
|
| 7 |
+
"APPOConfig",
|
| 8 |
+
# @OldAPIStack
|
| 9 |
+
"APPOTF1Policy",
|
| 10 |
+
"APPOTF2Policy",
|
| 11 |
+
"APPOTorchPolicy",
|
| 12 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (580 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo.cpython-311.pyc
ADDED
|
Binary file (18.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo_learner.cpython-311.pyc
ADDED
|
Binary file (8.25 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo_rl_module.cpython-311.pyc
ADDED
|
Binary file (637 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo_tf_policy.cpython-311.pyc
ADDED
|
Binary file (17 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/appo_torch_policy.cpython-311.pyc
ADDED
|
Binary file (19.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/default_appo_rl_module.cpython-311.pyc
ADDED
|
Binary file (3.85 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (5.31 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Asynchronous Proximal Policy Optimization (APPO)
|
| 2 |
+
|
| 3 |
+
The algorithm is described in [1] (under the name of "IMPACT"):
|
| 4 |
+
|
| 5 |
+
Detailed documentation:
|
| 6 |
+
https://docs.ray.io/en/master/rllib-algorithms.html#appo
|
| 7 |
+
|
| 8 |
+
[1] IMPACT: Importance Weighted Asynchronous Architectures with Clipped Target Networks.
|
| 9 |
+
Luo et al. 2020
|
| 10 |
+
https://arxiv.org/pdf/1912.00167
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from typing import Optional, Type
|
| 14 |
+
import logging
|
| 15 |
+
|
| 16 |
+
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
|
| 17 |
+
from ray.rllib.algorithms.impala.impala import IMPALA, IMPALAConfig
|
| 18 |
+
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
|
| 19 |
+
from ray.rllib.policy.policy import Policy
|
| 20 |
+
from ray.rllib.utils.annotations import override
|
| 21 |
+
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
|
| 22 |
+
from ray.rllib.utils.metrics import (
|
| 23 |
+
LAST_TARGET_UPDATE_TS,
|
| 24 |
+
NUM_AGENT_STEPS_SAMPLED,
|
| 25 |
+
NUM_ENV_STEPS_SAMPLED,
|
| 26 |
+
NUM_TARGET_UPDATES,
|
| 27 |
+
)
|
| 28 |
+
from ray.rllib.utils.metrics import LEARNER_STATS_KEY
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
LEARNER_RESULTS_KL_KEY = "mean_kl_loss"
|
| 34 |
+
LEARNER_RESULTS_CURR_KL_COEFF_KEY = "curr_kl_coeff"
|
| 35 |
+
OLD_ACTION_DIST_KEY = "old_action_dist"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class APPOConfig(IMPALAConfig):
|
| 39 |
+
"""Defines a configuration class from which an APPO Algorithm can be built.
|
| 40 |
+
|
| 41 |
+
.. testcode::
|
| 42 |
+
|
| 43 |
+
from ray.rllib.algorithms.appo import APPOConfig
|
| 44 |
+
config = (
|
| 45 |
+
APPOConfig()
|
| 46 |
+
.training(lr=0.01, grad_clip=30.0, train_batch_size_per_learner=50)
|
| 47 |
+
)
|
| 48 |
+
config = config.learners(num_learners=1)
|
| 49 |
+
config = config.env_runners(num_env_runners=1)
|
| 50 |
+
config = config.environment("CartPole-v1")
|
| 51 |
+
|
| 52 |
+
# Build an Algorithm object from the config and run 1 training iteration.
|
| 53 |
+
algo = config.build()
|
| 54 |
+
algo.train()
|
| 55 |
+
del algo
|
| 56 |
+
|
| 57 |
+
.. testcode::
|
| 58 |
+
|
| 59 |
+
from ray.rllib.algorithms.appo import APPOConfig
|
| 60 |
+
from ray import air
|
| 61 |
+
from ray import tune
|
| 62 |
+
|
| 63 |
+
config = APPOConfig()
|
| 64 |
+
# Update the config object.
|
| 65 |
+
config = config.training(lr=tune.grid_search([0.001,]))
|
| 66 |
+
# Set the config object's env.
|
| 67 |
+
config = config.environment(env="CartPole-v1")
|
| 68 |
+
# Use to_dict() to get the old-style python config dict when running with tune.
|
| 69 |
+
tune.Tuner(
|
| 70 |
+
"APPO",
|
| 71 |
+
run_config=air.RunConfig(
|
| 72 |
+
stop={"training_iteration": 1},
|
| 73 |
+
verbose=0,
|
| 74 |
+
),
|
| 75 |
+
param_space=config.to_dict(),
|
| 76 |
+
|
| 77 |
+
).fit()
|
| 78 |
+
|
| 79 |
+
.. testoutput::
|
| 80 |
+
:hide:
|
| 81 |
+
|
| 82 |
+
...
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
def __init__(self, algo_class=None):
|
| 86 |
+
"""Initializes a APPOConfig instance."""
|
| 87 |
+
self.exploration_config = {
|
| 88 |
+
# The Exploration class to use. In the simplest case, this is the name
|
| 89 |
+
# (str) of any class present in the `rllib.utils.exploration` package.
|
| 90 |
+
# You can also provide the python class directly or the full location
|
| 91 |
+
# of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
|
| 92 |
+
# EpsilonGreedy").
|
| 93 |
+
"type": "StochasticSampling",
|
| 94 |
+
# Add constructor kwargs here (if any).
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
super().__init__(algo_class=algo_class or APPO)
|
| 98 |
+
|
| 99 |
+
# fmt: off
|
| 100 |
+
# __sphinx_doc_begin__
|
| 101 |
+
# APPO specific settings:
|
| 102 |
+
self.vtrace = True
|
| 103 |
+
self.use_gae = True
|
| 104 |
+
self.lambda_ = 1.0
|
| 105 |
+
self.clip_param = 0.4
|
| 106 |
+
self.use_kl_loss = False
|
| 107 |
+
self.kl_coeff = 1.0
|
| 108 |
+
self.kl_target = 0.01
|
| 109 |
+
self.target_worker_clipping = 2.0
|
| 110 |
+
|
| 111 |
+
# Circular replay buffer settings.
|
| 112 |
+
# Used in [1] for discrete action tasks:
|
| 113 |
+
# `circular_buffer_num_batches=4` and `circular_buffer_iterations_per_batch=2`
|
| 114 |
+
# For cont. action tasks:
|
| 115 |
+
# `circular_buffer_num_batches=16` and `circular_buffer_iterations_per_batch=20`
|
| 116 |
+
self.circular_buffer_num_batches = 4
|
| 117 |
+
self.circular_buffer_iterations_per_batch = 2
|
| 118 |
+
|
| 119 |
+
# Override some of IMPALAConfig's default values with APPO-specific values.
|
| 120 |
+
self.num_env_runners = 2
|
| 121 |
+
self.target_network_update_freq = 2
|
| 122 |
+
self.broadcast_interval = 1
|
| 123 |
+
self.grad_clip = 40.0
|
| 124 |
+
# Note: Only when using enable_rl_module_and_learner=True can the clipping mode
|
| 125 |
+
# be configured by the user. On the old API stack, RLlib will always clip by
|
| 126 |
+
# global_norm, no matter the value of `grad_clip_by`.
|
| 127 |
+
self.grad_clip_by = "global_norm"
|
| 128 |
+
|
| 129 |
+
self.opt_type = "adam"
|
| 130 |
+
self.lr = 0.0005
|
| 131 |
+
self.decay = 0.99
|
| 132 |
+
self.momentum = 0.0
|
| 133 |
+
self.epsilon = 0.1
|
| 134 |
+
self.vf_loss_coeff = 0.5
|
| 135 |
+
self.entropy_coeff = 0.01
|
| 136 |
+
self.tau = 1.0
|
| 137 |
+
# __sphinx_doc_end__
|
| 138 |
+
# fmt: on
|
| 139 |
+
|
| 140 |
+
self.lr_schedule = None # @OldAPIStack
|
| 141 |
+
self.entropy_coeff_schedule = None # @OldAPIStack
|
| 142 |
+
self.num_gpus = 0 # @OldAPIStack
|
| 143 |
+
self.num_multi_gpu_tower_stacks = 1 # @OldAPIStack
|
| 144 |
+
self.minibatch_buffer_size = 1 # @OldAPIStack
|
| 145 |
+
self.replay_proportion = 0.0 # @OldAPIStack
|
| 146 |
+
self.replay_buffer_num_slots = 100 # @OldAPIStack
|
| 147 |
+
self.learner_queue_size = 16 # @OldAPIStack
|
| 148 |
+
self.learner_queue_timeout = 300 # @OldAPIStack
|
| 149 |
+
|
| 150 |
+
# Deprecated keys.
|
| 151 |
+
self.target_update_frequency = DEPRECATED_VALUE
|
| 152 |
+
self.use_critic = DEPRECATED_VALUE
|
| 153 |
+
|
| 154 |
+
@override(IMPALAConfig)
|
| 155 |
+
def training(
|
| 156 |
+
self,
|
| 157 |
+
*,
|
| 158 |
+
vtrace: Optional[bool] = NotProvided,
|
| 159 |
+
use_gae: Optional[bool] = NotProvided,
|
| 160 |
+
lambda_: Optional[float] = NotProvided,
|
| 161 |
+
clip_param: Optional[float] = NotProvided,
|
| 162 |
+
use_kl_loss: Optional[bool] = NotProvided,
|
| 163 |
+
kl_coeff: Optional[float] = NotProvided,
|
| 164 |
+
kl_target: Optional[float] = NotProvided,
|
| 165 |
+
target_network_update_freq: Optional[int] = NotProvided,
|
| 166 |
+
tau: Optional[float] = NotProvided,
|
| 167 |
+
target_worker_clipping: Optional[float] = NotProvided,
|
| 168 |
+
circular_buffer_num_batches: Optional[int] = NotProvided,
|
| 169 |
+
circular_buffer_iterations_per_batch: Optional[int] = NotProvided,
|
| 170 |
+
# Deprecated keys.
|
| 171 |
+
target_update_frequency=DEPRECATED_VALUE,
|
| 172 |
+
use_critic=DEPRECATED_VALUE,
|
| 173 |
+
**kwargs,
|
| 174 |
+
) -> "APPOConfig":
|
| 175 |
+
"""Sets the training related configuration.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
vtrace: Whether to use V-trace weighted advantages. If false, PPO GAE
|
| 179 |
+
advantages will be used instead.
|
| 180 |
+
use_gae: If true, use the Generalized Advantage Estimator (GAE)
|
| 181 |
+
with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
|
| 182 |
+
Only applies if vtrace=False.
|
| 183 |
+
lambda_: GAE (lambda) parameter.
|
| 184 |
+
clip_param: PPO surrogate slipping parameter.
|
| 185 |
+
use_kl_loss: Whether to use the KL-term in the loss function.
|
| 186 |
+
kl_coeff: Coefficient for weighting the KL-loss term.
|
| 187 |
+
kl_target: Target term for the KL-term to reach (via adjusting the
|
| 188 |
+
`kl_coeff` automatically).
|
| 189 |
+
target_network_update_freq: NOTE: This parameter is only applicable on
|
| 190 |
+
the new API stack. The frequency with which to update the target
|
| 191 |
+
policy network from the main trained policy network. The metric
|
| 192 |
+
used is `NUM_ENV_STEPS_TRAINED_LIFETIME` and the unit is `n` (see [1]
|
| 193 |
+
4.1.1), where: `n = [circular_buffer_num_batches (N)] *
|
| 194 |
+
[circular_buffer_iterations_per_batch (K)] * [train batch size]`
|
| 195 |
+
For example, if you set `target_network_update_freq=2`, and N=4, K=2,
|
| 196 |
+
and `train_batch_size_per_learner=500`, then the target net is updated
|
| 197 |
+
every 2*4*2*500=8000 trained env steps (every 16 batch updates on each
|
| 198 |
+
learner).
|
| 199 |
+
The authors in [1] suggests that this setting is robust to a range of
|
| 200 |
+
choices (try values between 0.125 and 4).
|
| 201 |
+
target_network_update_freq: The frequency to update the target policy and
|
| 202 |
+
tune the kl loss coefficients that are used during training. After
|
| 203 |
+
setting this parameter, the algorithm waits for at least
|
| 204 |
+
`target_network_update_freq` number of environment samples to be trained
|
| 205 |
+
on before updating the target networks and tune the kl loss
|
| 206 |
+
coefficients. NOTE: This parameter is only applicable when using the
|
| 207 |
+
Learner API (enable_rl_module_and_learner=True).
|
| 208 |
+
tau: The factor by which to update the target policy network towards
|
| 209 |
+
the current policy network. Can range between 0 and 1.
|
| 210 |
+
e.g. updated_param = tau * current_param + (1 - tau) * target_param
|
| 211 |
+
target_worker_clipping: The maximum value for the target-worker-clipping
|
| 212 |
+
used for computing the IS ratio, described in [1]
|
| 213 |
+
IS = min(π(i) / π(target), ρ) * (π / π(i))
|
| 214 |
+
circular_buffer_num_batches: The number of train batches that fit
|
| 215 |
+
into the circular buffer. Each such train batch can be sampled for
|
| 216 |
+
training max. `circular_buffer_iterations_per_batch` times.
|
| 217 |
+
circular_buffer_iterations_per_batch: The number of times any train
|
| 218 |
+
batch in the circular buffer can be sampled for training. A batch gets
|
| 219 |
+
evicted from the buffer either if it's the oldest batch in the buffer
|
| 220 |
+
and a new batch is added OR if the batch reaches this max. number of
|
| 221 |
+
being sampled.
|
| 222 |
+
|
| 223 |
+
Returns:
|
| 224 |
+
This updated AlgorithmConfig object.
|
| 225 |
+
"""
|
| 226 |
+
if target_update_frequency != DEPRECATED_VALUE:
|
| 227 |
+
deprecation_warning(
|
| 228 |
+
old="target_update_frequency",
|
| 229 |
+
new="target_network_update_freq",
|
| 230 |
+
error=True,
|
| 231 |
+
)
|
| 232 |
+
if use_critic != DEPRECATED_VALUE:
|
| 233 |
+
deprecation_warning(
|
| 234 |
+
old="use_critic",
|
| 235 |
+
help="`use_critic` no longer supported! APPO always uses a value "
|
| 236 |
+
"function (critic).",
|
| 237 |
+
error=True,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
# Pass kwargs onto super's `training()` method.
|
| 241 |
+
super().training(**kwargs)
|
| 242 |
+
|
| 243 |
+
if vtrace is not NotProvided:
|
| 244 |
+
self.vtrace = vtrace
|
| 245 |
+
if use_gae is not NotProvided:
|
| 246 |
+
self.use_gae = use_gae
|
| 247 |
+
if lambda_ is not NotProvided:
|
| 248 |
+
self.lambda_ = lambda_
|
| 249 |
+
if clip_param is not NotProvided:
|
| 250 |
+
self.clip_param = clip_param
|
| 251 |
+
if use_kl_loss is not NotProvided:
|
| 252 |
+
self.use_kl_loss = use_kl_loss
|
| 253 |
+
if kl_coeff is not NotProvided:
|
| 254 |
+
self.kl_coeff = kl_coeff
|
| 255 |
+
if kl_target is not NotProvided:
|
| 256 |
+
self.kl_target = kl_target
|
| 257 |
+
if target_network_update_freq is not NotProvided:
|
| 258 |
+
self.target_network_update_freq = target_network_update_freq
|
| 259 |
+
if tau is not NotProvided:
|
| 260 |
+
self.tau = tau
|
| 261 |
+
if target_worker_clipping is not NotProvided:
|
| 262 |
+
self.target_worker_clipping = target_worker_clipping
|
| 263 |
+
if circular_buffer_num_batches is not NotProvided:
|
| 264 |
+
self.circular_buffer_num_batches = circular_buffer_num_batches
|
| 265 |
+
if circular_buffer_iterations_per_batch is not NotProvided:
|
| 266 |
+
self.circular_buffer_iterations_per_batch = (
|
| 267 |
+
circular_buffer_iterations_per_batch
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
return self
|
| 271 |
+
|
| 272 |
+
@override(IMPALAConfig)
|
| 273 |
+
def validate(self) -> None:
|
| 274 |
+
super().validate()
|
| 275 |
+
|
| 276 |
+
# On new API stack, circular buffer should be used, not `minibatch_buffer_size`.
|
| 277 |
+
if self.enable_rl_module_and_learner:
|
| 278 |
+
if self.minibatch_buffer_size != 1 or self.replay_proportion != 0.0:
|
| 279 |
+
self._value_error(
|
| 280 |
+
"`minibatch_buffer_size/replay_proportion` not valid on new API "
|
| 281 |
+
"stack with APPO! "
|
| 282 |
+
"Use `circular_buffer_num_batches` for the number of train batches "
|
| 283 |
+
"in the circular buffer. To change the maximum number of times "
|
| 284 |
+
"any batch may be sampled, set "
|
| 285 |
+
"`circular_buffer_iterations_per_batch`."
|
| 286 |
+
)
|
| 287 |
+
if self.num_multi_gpu_tower_stacks != 1:
|
| 288 |
+
self._value_error(
|
| 289 |
+
"`num_multi_gpu_tower_stacks` not supported on new API stack with "
|
| 290 |
+
"APPO! In order to train on multi-GPU, use "
|
| 291 |
+
"`config.learners(num_learners=[number of GPUs], "
|
| 292 |
+
"num_gpus_per_learner=1)`. To scale the throughput of batch-to-GPU-"
|
| 293 |
+
"pre-loading on each of your `Learners`, set "
|
| 294 |
+
"`num_gpu_loader_threads` to a higher number (recommended values: "
|
| 295 |
+
"1-8)."
|
| 296 |
+
)
|
| 297 |
+
if self.learner_queue_size != 16:
|
| 298 |
+
self._value_error(
|
| 299 |
+
"`learner_queue_size` not supported on new API stack with "
|
| 300 |
+
"APPO! In order set the size of the circular buffer (which acts as "
|
| 301 |
+
"a 'learner queue'), use "
|
| 302 |
+
"`config.training(circular_buffer_num_batches=..)`. To change the "
|
| 303 |
+
"maximum number of times any batch may be sampled, set "
|
| 304 |
+
"`config.training(circular_buffer_iterations_per_batch=..)`."
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
@override(IMPALAConfig)
|
| 308 |
+
def get_default_learner_class(self):
|
| 309 |
+
if self.framework_str == "torch":
|
| 310 |
+
from ray.rllib.algorithms.appo.torch.appo_torch_learner import (
|
| 311 |
+
APPOTorchLearner,
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
return APPOTorchLearner
|
| 315 |
+
elif self.framework_str in ["tf2", "tf"]:
|
| 316 |
+
raise ValueError(
|
| 317 |
+
"TensorFlow is no longer supported on the new API stack! "
|
| 318 |
+
"Use `framework='torch'`."
|
| 319 |
+
)
|
| 320 |
+
else:
|
| 321 |
+
raise ValueError(
|
| 322 |
+
f"The framework {self.framework_str} is not supported. "
|
| 323 |
+
"Use `framework='torch'`."
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
@override(IMPALAConfig)
|
| 327 |
+
def get_default_rl_module_spec(self) -> RLModuleSpec:
|
| 328 |
+
if self.framework_str == "torch":
|
| 329 |
+
from ray.rllib.algorithms.appo.torch.appo_torch_rl_module import (
|
| 330 |
+
APPOTorchRLModule as RLModule,
|
| 331 |
+
)
|
| 332 |
+
else:
|
| 333 |
+
raise ValueError(
|
| 334 |
+
f"The framework {self.framework_str} is not supported. "
|
| 335 |
+
"Use either 'torch' or 'tf2'."
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
return RLModuleSpec(module_class=RLModule)
|
| 339 |
+
|
| 340 |
+
@property
|
| 341 |
+
@override(AlgorithmConfig)
|
| 342 |
+
def _model_config_auto_includes(self):
|
| 343 |
+
return super()._model_config_auto_includes | {"vf_share_layers": False}
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
class APPO(IMPALA):
|
| 347 |
+
def __init__(self, config, *args, **kwargs):
|
| 348 |
+
"""Initializes an APPO instance."""
|
| 349 |
+
super().__init__(config, *args, **kwargs)
|
| 350 |
+
|
| 351 |
+
# After init: Initialize target net.
|
| 352 |
+
|
| 353 |
+
# TODO(avnishn): Does this need to happen in __init__? I think we can move it
|
| 354 |
+
# to setup()
|
| 355 |
+
if not self.config.enable_rl_module_and_learner:
|
| 356 |
+
self.env_runner.foreach_policy_to_train(lambda p, _: p.update_target())
|
| 357 |
+
|
| 358 |
+
@override(IMPALA)
|
| 359 |
+
def training_step(self) -> None:
|
| 360 |
+
if self.config.enable_rl_module_and_learner:
|
| 361 |
+
return super().training_step()
|
| 362 |
+
|
| 363 |
+
train_results = super().training_step()
|
| 364 |
+
# Update the target network and the KL coefficient for the APPO-loss.
|
| 365 |
+
# The target network update frequency is calculated automatically by the product
|
| 366 |
+
# of `num_epochs` setting (usually 1 for APPO) and `minibatch_buffer_size`.
|
| 367 |
+
last_update = self._counters[LAST_TARGET_UPDATE_TS]
|
| 368 |
+
cur_ts = self._counters[
|
| 369 |
+
(
|
| 370 |
+
NUM_AGENT_STEPS_SAMPLED
|
| 371 |
+
if self.config.count_steps_by == "agent_steps"
|
| 372 |
+
else NUM_ENV_STEPS_SAMPLED
|
| 373 |
+
)
|
| 374 |
+
]
|
| 375 |
+
target_update_freq = self.config.num_epochs * self.config.minibatch_buffer_size
|
| 376 |
+
if cur_ts - last_update > target_update_freq:
|
| 377 |
+
self._counters[NUM_TARGET_UPDATES] += 1
|
| 378 |
+
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
|
| 379 |
+
|
| 380 |
+
# Update our target network.
|
| 381 |
+
self.env_runner.foreach_policy_to_train(lambda p, _: p.update_target())
|
| 382 |
+
|
| 383 |
+
# Also update the KL-coefficient for the APPO loss, if necessary.
|
| 384 |
+
if self.config.use_kl_loss:
|
| 385 |
+
|
| 386 |
+
def update(pi, pi_id):
|
| 387 |
+
assert LEARNER_STATS_KEY not in train_results, (
|
| 388 |
+
"{} should be nested under policy id key".format(
|
| 389 |
+
LEARNER_STATS_KEY
|
| 390 |
+
),
|
| 391 |
+
train_results,
|
| 392 |
+
)
|
| 393 |
+
if pi_id in train_results:
|
| 394 |
+
kl = train_results[pi_id][LEARNER_STATS_KEY].get("kl")
|
| 395 |
+
assert kl is not None, (train_results, pi_id)
|
| 396 |
+
# Make the actual `Policy.update_kl()` call.
|
| 397 |
+
pi.update_kl(kl)
|
| 398 |
+
else:
|
| 399 |
+
logger.warning("No data for {}, not updating kl".format(pi_id))
|
| 400 |
+
|
| 401 |
+
# Update KL on all trainable policies within the local (trainer)
|
| 402 |
+
# Worker.
|
| 403 |
+
self.env_runner.foreach_policy_to_train(update)
|
| 404 |
+
|
| 405 |
+
return train_results
|
| 406 |
+
|
| 407 |
+
@classmethod
|
| 408 |
+
@override(IMPALA)
|
| 409 |
+
def get_default_config(cls) -> AlgorithmConfig:
|
| 410 |
+
return APPOConfig()
|
| 411 |
+
|
| 412 |
+
@classmethod
|
| 413 |
+
@override(IMPALA)
|
| 414 |
+
def get_default_policy_class(
|
| 415 |
+
cls, config: AlgorithmConfig
|
| 416 |
+
) -> Optional[Type[Policy]]:
|
| 417 |
+
if config["framework"] == "torch":
|
| 418 |
+
from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy
|
| 419 |
+
|
| 420 |
+
return APPOTorchPolicy
|
| 421 |
+
elif config["framework"] == "tf":
|
| 422 |
+
if config.enable_rl_module_and_learner:
|
| 423 |
+
raise ValueError(
|
| 424 |
+
"RLlib's RLModule and Learner API is not supported for"
|
| 425 |
+
" tf1. Use "
|
| 426 |
+
"framework='tf2' instead."
|
| 427 |
+
)
|
| 428 |
+
from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF1Policy
|
| 429 |
+
|
| 430 |
+
return APPOTF1Policy
|
| 431 |
+
else:
|
| 432 |
+
from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF2Policy
|
| 433 |
+
|
| 434 |
+
return APPOTF2Policy
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo_learner.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from typing import Any, Dict, Optional
|
| 3 |
+
|
| 4 |
+
from ray.rllib.algorithms.appo.appo import APPOConfig
|
| 5 |
+
from ray.rllib.algorithms.appo.utils import CircularBuffer
|
| 6 |
+
from ray.rllib.algorithms.impala.impala_learner import IMPALALearner
|
| 7 |
+
from ray.rllib.core.learner.learner import Learner
|
| 8 |
+
from ray.rllib.core.learner.utils import update_target_network
|
| 9 |
+
from ray.rllib.core.rl_module.apis import TargetNetworkAPI, ValueFunctionAPI
|
| 10 |
+
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
|
| 11 |
+
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
|
| 12 |
+
from ray.rllib.utils.annotations import override
|
| 13 |
+
from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict
|
| 14 |
+
from ray.rllib.utils.metrics import (
|
| 15 |
+
LAST_TARGET_UPDATE_TS,
|
| 16 |
+
NUM_ENV_STEPS_TRAINED_LIFETIME,
|
| 17 |
+
NUM_MODULE_STEPS_TRAINED,
|
| 18 |
+
NUM_TARGET_UPDATES,
|
| 19 |
+
)
|
| 20 |
+
from ray.rllib.utils.schedules.scheduler import Scheduler
|
| 21 |
+
from ray.rllib.utils.typing import ModuleID, ShouldModuleBeUpdatedFn
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class APPOLearner(IMPALALearner):
    """Adds KL coeff updates via `after_gradient_based_update()` to IMPALA logic.

    Framework-specific subclasses must override `_update_module_kl_coeff()`.
    """

    @override(IMPALALearner)
    def build(self):
        # Replace the plain learner queue with APPO's circular buffer BEFORE
        # calling super().build(), so the IMPALA build logic picks it up.
        self._learner_thread_in_queue = CircularBuffer(
            num_batches=self.config.circular_buffer_num_batches,
            iterations_per_batch=self.config.circular_buffer_iterations_per_batch,
        )

        super().build()

        # Make target networks (only for modules implementing TargetNetworkAPI).
        self.module.foreach_module(
            lambda mid, mod: (
                mod.make_target_networks()
                if isinstance(mod, TargetNetworkAPI)
                else None
            )
        )

        # The current kl coefficients per module as (framework specific) tensor
        # variables, created lazily on first access per module_id.
        self.curr_kl_coeffs_per_module: LambdaDefaultDict[
            ModuleID, Scheduler
        ] = LambdaDefaultDict(
            lambda module_id: self._get_tensor_variable(
                self.config.get_config_for_module(module_id).kl_coeff
            )
        )

    @override(Learner)
    def add_module(
        self,
        *,
        module_id: ModuleID,
        module_spec: RLModuleSpec,
        config_overrides: Optional[Dict] = None,
        new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
    ) -> MultiRLModuleSpec:
        """Adds a new RLModule and creates its target networks, if applicable."""
        marl_spec = super().add_module(
            module_id=module_id,
            module_spec=module_spec,
            config_overrides=config_overrides,
            new_should_module_be_updated=new_should_module_be_updated,
        )
        # Create target networks for added Module, if applicable.
        if isinstance(self.module[module_id].unwrapped(), TargetNetworkAPI):
            self.module[module_id].unwrapped().make_target_networks()
        return marl_spec

    @override(IMPALALearner)
    def remove_module(self, module_id: str) -> MultiRLModuleSpec:
        """Removes an RLModule and drops its (lazily created) KL coefficient."""
        marl_spec = super().remove_module(module_id)
        self.curr_kl_coeffs_per_module.pop(module_id)
        return marl_spec

    @override(Learner)
    def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
        """Updates the target networks and (optionally) the KL coefficients.

        Target nets are soft-updated (tau) once enough env steps have been
        trained since the last update; KL coeffs are adjusted for modules with
        `use_kl_loss=True` that trained at least one step.
        """
        super().after_gradient_based_update(timesteps=timesteps)

        # TODO (sven): Maybe we should have a `after_gradient_based_update`
        # method per module?
        curr_timestep = timesteps.get(NUM_ENV_STEPS_TRAINED_LIFETIME, 0)
        for module_id, module in self.module._rl_modules.items():
            config = self.config.get_config_for_module(module_id)

            last_update_ts_key = (module_id, LAST_TARGET_UPDATE_TS)
            # The effective update interval scales with how often each batch is
            # re-used by the circular buffer (num_batches * iterations_per_batch).
            if isinstance(module.unwrapped(), TargetNetworkAPI) and (
                curr_timestep - self.metrics.peek(last_update_ts_key, default=0)
                >= (
                    config.target_network_update_freq
                    * config.circular_buffer_num_batches
                    * config.circular_buffer_iterations_per_batch
                    * config.train_batch_size_per_learner
                )
            ):
                for (
                    main_net,
                    target_net,
                ) in module.unwrapped().get_target_network_pairs():
                    update_target_network(
                        main_net=main_net,
                        target_net=target_net,
                        tau=config.tau,
                    )
                # Increase lifetime target network update counter by one.
                self.metrics.log_value((module_id, NUM_TARGET_UPDATES), 1, reduce="sum")
                # Update the (single-value -> window=1) last updated timestep metric.
                self.metrics.log_value(last_update_ts_key, curr_timestep, window=1)

            if (
                config.use_kl_loss
                and self.metrics.peek((module_id, NUM_MODULE_STEPS_TRAINED), default=0)
                > 0
            ):
                self._update_module_kl_coeff(module_id=module_id, config=config)

    @classmethod
    @override(Learner)
    def rl_module_required_apis(cls) -> list[type]:
        # In order for an APPOLearner to update an RLModule, it must implement
        # the following APIs:
        return [TargetNetworkAPI, ValueFunctionAPI]

    @abc.abstractmethod
    def _update_module_kl_coeff(self, module_id: ModuleID, config: APPOConfig) -> None:
        """Dynamically update the KL loss coefficient of a single module.

        The update is completed using the mean KL divergence between the action
        distributions of the current policy and the old policy of each module.
        That action distribution is computed during the most recent update/call
        to `compute_loss`.

        Args:
            module_id: The module whose KL loss coefficient to update.
            config: The AlgorithmConfig specific to the given `module_id`.
        """


# Backward-compat alias for the old class name.
AppoLearner = APPOLearner
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo_rl_module.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Backward compat import.
# This module only re-exports `DefaultAPPORLModule` under its old name
# (`APPORLModule`) and emits a deprecation warning on import.
from ray.rllib.algorithms.appo.default_appo_rl_module import (  # noqa
    DefaultAPPORLModule as APPORLModule,
)
from ray.rllib.utils.deprecation import deprecation_warning

# Warn (without raising, error=False) anyone importing via the old path.
deprecation_warning(
    old="ray.rllib.algorithms.appo.appo_rl_module.APPORLModule",
    new="ray.rllib.algorithms.appo.default_appo_rl_module.DefaultAPPORLModule",
    error=False,
)
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo_tf_policy.py
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TensorFlow policy class used for APPO.
|
| 3 |
+
|
| 4 |
+
Adapted from VTraceTFPolicy to use the PPO surrogate loss.
|
| 5 |
+
Keep in sync with changes to VTraceTFPolicy.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import logging
|
| 10 |
+
import gymnasium as gym
|
| 11 |
+
from typing import Dict, List, Optional, Type, Union
|
| 12 |
+
|
| 13 |
+
from ray.rllib.algorithms.appo.utils import make_appo_models
|
| 14 |
+
from ray.rllib.algorithms.impala import vtrace_tf as vtrace
|
| 15 |
+
from ray.rllib.algorithms.impala.impala_tf_policy import (
|
| 16 |
+
_make_time_major,
|
| 17 |
+
VTraceClipGradients,
|
| 18 |
+
VTraceOptimizer,
|
| 19 |
+
)
|
| 20 |
+
from ray.rllib.evaluation.postprocessing import (
|
| 21 |
+
compute_bootstrap_value,
|
| 22 |
+
compute_gae_for_sample_batch,
|
| 23 |
+
Postprocessing,
|
| 24 |
+
)
|
| 25 |
+
from ray.rllib.models.tf.tf_action_dist import Categorical
|
| 26 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 27 |
+
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
|
| 28 |
+
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
|
| 29 |
+
from ray.rllib.policy.tf_mixins import (
|
| 30 |
+
EntropyCoeffSchedule,
|
| 31 |
+
LearningRateSchedule,
|
| 32 |
+
KLCoeffMixin,
|
| 33 |
+
ValueNetworkMixin,
|
| 34 |
+
GradStatsMixin,
|
| 35 |
+
TargetNetworkMixin,
|
| 36 |
+
)
|
| 37 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 38 |
+
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
|
| 39 |
+
from ray.rllib.utils.annotations import (
|
| 40 |
+
override,
|
| 41 |
+
)
|
| 42 |
+
from ray.rllib.utils.framework import try_import_tf
|
| 43 |
+
from ray.rllib.utils.tf_utils import explained_variance
|
| 44 |
+
from ray.rllib.utils.typing import TensorType
|
| 45 |
+
|
| 46 |
+
tf1, tf, tfv = try_import_tf()
|
| 47 |
+
|
| 48 |
+
logger = logging.getLogger(__name__)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# TODO (sven): Deprecate once APPO and IMPALA fully on RLModules/Learner APIs.
|
| 52 |
+
# TODO (sven): Deprecate once APPO and IMPALA fully on RLModules/Learner APIs.
def get_appo_tf_policy(name: str, base: type) -> type:
    """Construct an APPOTFPolicy inheriting either dynamic or eager base policies.

    Args:
        name: The __name__/__qualname__ to assign to the built class.
        base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.

    Returns:
        A TF Policy to be used with Impala.
    """

    class APPOTFPolicy(
        VTraceClipGradients,
        VTraceOptimizer,
        LearningRateSchedule,
        KLCoeffMixin,
        EntropyCoeffSchedule,
        ValueNetworkMixin,
        TargetNetworkMixin,
        GradStatsMixin,
        base,
    ):
        def __init__(
            self,
            observation_space,
            action_space,
            config,
            existing_model=None,
            existing_inputs=None,
        ):
            # First thing first, enable eager execution if necessary.
            base.enable_eager_execution_if_necessary()

            # Although this is a no-op, we call __init__ here to make it clear
            # that base.__init__ will use the make_model() call.
            VTraceClipGradients.__init__(self)
            VTraceOptimizer.__init__(self)

            # Initialize base class.
            base.__init__(
                self,
                observation_space,
                action_space,
                config,
                existing_inputs=existing_inputs,
                existing_model=existing_model,
            )

            # TF LearningRateSchedule depends on self.framework, so initialize
            # after base.__init__() is called.
            LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"])
            EntropyCoeffSchedule.__init__(
                self, config["entropy_coeff"], config["entropy_coeff_schedule"]
            )
            ValueNetworkMixin.__init__(self, config)
            KLCoeffMixin.__init__(self, config)

            GradStatsMixin.__init__(self)

            # Note: this is a bit ugly, but loss and optimizer initialization must
            # happen after all the MixIns are initialized.
            self.maybe_initialize_optimizer_and_loss()

            # Initiate TargetNetwork ops after loss initialization.
            TargetNetworkMixin.__init__(self)

        @override(base)
        def make_model(self) -> ModelV2:
            # Builds the APPO model (incl. its target network counterpart).
            return make_appo_models(self)

        @override(base)
        def loss(
            self,
            model: Union[ModelV2, "tf.keras.Model"],
            dist_class: Type[TFActionDistribution],
            train_batch: SampleBatch,
        ) -> Union[TensorType, List[TensorType]]:
            """Computes the APPO loss: PPO surrogate with optional V-trace.

            Depending on `config["vtrace"]`, advantages/value targets either
            come from V-trace (computed here from logits) or from the GAE
            columns added in `postprocess_trajectory`.
            """
            model_out, _ = model(train_batch)
            action_dist = dist_class(model_out, model)

            # Derive action-space structure: MultiDiscrete needs per-subspace
            # logit splits; Discrete/other use a single logit block.
            if isinstance(self.action_space, gym.spaces.Discrete):
                is_multidiscrete = False
                output_hidden_shape = [self.action_space.n]
            elif isinstance(self.action_space, gym.spaces.multi_discrete.MultiDiscrete):
                is_multidiscrete = True
                output_hidden_shape = self.action_space.nvec.astype(np.int32)
            else:
                is_multidiscrete = False
                output_hidden_shape = 1

            def make_time_major(*args, **kw):
                # Reshape [B*T, ...] batch-major data to time-major [T, B, ...].
                return _make_time_major(
                    self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw
                )

            actions = train_batch[SampleBatch.ACTIONS]
            dones = train_batch[SampleBatch.TERMINATEDS]
            rewards = train_batch[SampleBatch.REWARDS]
            behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

            target_model_out, _ = self.target_model(train_batch)
            prev_action_dist = dist_class(behaviour_logits, self.model)
            values = self.model.value_function()
            values_time_major = make_time_major(values)
            bootstrap_values_time_major = make_time_major(
                train_batch[SampleBatch.VALUES_BOOTSTRAPPED]
            )
            bootstrap_value = bootstrap_values_time_major[-1]

            if self.is_recurrent():
                # Mask out invalid (padded) timesteps of the RNN sequences.
                max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
                mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
                mask = tf.reshape(mask, [-1])
                mask = make_time_major(mask)

                def reduce_mean_valid(t):
                    return tf.reduce_mean(tf.boolean_mask(t, mask))

            else:
                reduce_mean_valid = tf.reduce_mean

            if self.config["vtrace"]:
                logger.debug("Using V-Trace surrogate loss (vtrace=True)")

                # Prepare actions for loss.
                loss_actions = (
                    actions if is_multidiscrete else tf.expand_dims(actions, axis=1)
                )

                # "Old policy" = frozen target network (no gradients through it).
                old_policy_behaviour_logits = tf.stop_gradient(target_model_out)
                old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

                # Prepare KL for Loss
                mean_kl = make_time_major(old_policy_action_dist.multi_kl(action_dist))

                unpacked_behaviour_logits = tf.split(
                    behaviour_logits, output_hidden_shape, axis=1
                )
                unpacked_old_policy_behaviour_logits = tf.split(
                    old_policy_behaviour_logits, output_hidden_shape, axis=1
                )

                # Compute vtrace on the CPU for better perf.
                with tf.device("/cpu:0"):
                    vtrace_returns = vtrace.multi_from_logits(
                        behaviour_policy_logits=make_time_major(
                            unpacked_behaviour_logits
                        ),
                        target_policy_logits=make_time_major(
                            unpacked_old_policy_behaviour_logits
                        ),
                        actions=tf.unstack(make_time_major(loss_actions), axis=2),
                        discounts=tf.cast(
                            ~make_time_major(tf.cast(dones, tf.bool)),
                            tf.float32,
                        )
                        * self.config["gamma"],
                        rewards=make_time_major(rewards),
                        values=values_time_major,
                        bootstrap_value=bootstrap_value,
                        dist_class=Categorical if is_multidiscrete else dist_class,
                        model=model,
                        clip_rho_threshold=tf.cast(
                            self.config["vtrace_clip_rho_threshold"], tf.float32
                        ),
                        clip_pg_rho_threshold=tf.cast(
                            self.config["vtrace_clip_pg_rho_threshold"], tf.float32
                        ),
                    )

                actions_logp = make_time_major(action_dist.logp(actions))
                prev_actions_logp = make_time_major(prev_action_dist.logp(actions))
                old_policy_actions_logp = make_time_major(
                    old_policy_action_dist.logp(actions)
                )

                # Importance ratio between sampling policy and (frozen) old
                # policy, clipped to [0, 2] to bound the correction.
                is_ratio = tf.clip_by_value(
                    tf.math.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0
                )
                logp_ratio = is_ratio * tf.exp(actions_logp - prev_actions_logp)
                self._is_ratio = is_ratio

                advantages = vtrace_returns.pg_advantages
                # Standard PPO clipped-surrogate objective.
                surrogate_loss = tf.minimum(
                    advantages * logp_ratio,
                    advantages
                    * tf.clip_by_value(
                        logp_ratio,
                        1 - self.config["clip_param"],
                        1 + self.config["clip_param"],
                    ),
                )

                action_kl = (
                    tf.reduce_mean(mean_kl, axis=0) if is_multidiscrete else mean_kl
                )
                mean_kl_loss = reduce_mean_valid(action_kl)
                mean_policy_loss = -reduce_mean_valid(surrogate_loss)

                # The value function loss.
                value_targets = vtrace_returns.vs
                delta = values_time_major - value_targets
                mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

                # The entropy loss.
                actions_entropy = make_time_major(action_dist.multi_entropy())
                mean_entropy = reduce_mean_valid(actions_entropy)

            else:
                logger.debug("Using PPO surrogate loss (vtrace=False)")

                # Prepare KL for Loss
                mean_kl = make_time_major(prev_action_dist.multi_kl(action_dist))

                logp_ratio = tf.math.exp(
                    make_time_major(action_dist.logp(actions))
                    - make_time_major(prev_action_dist.logp(actions))
                )

                advantages = make_time_major(train_batch[Postprocessing.ADVANTAGES])
                surrogate_loss = tf.minimum(
                    advantages * logp_ratio,
                    advantages
                    * tf.clip_by_value(
                        logp_ratio,
                        1 - self.config["clip_param"],
                        1 + self.config["clip_param"],
                    ),
                )

                action_kl = (
                    tf.reduce_mean(mean_kl, axis=0) if is_multidiscrete else mean_kl
                )
                mean_kl_loss = reduce_mean_valid(action_kl)
                mean_policy_loss = -reduce_mean_valid(surrogate_loss)

                # The value function loss.
                value_targets = make_time_major(
                    train_batch[Postprocessing.VALUE_TARGETS]
                )
                delta = values_time_major - value_targets
                mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

                # The entropy loss.
                mean_entropy = reduce_mean_valid(
                    make_time_major(action_dist.multi_entropy())
                )

            # The summed weighted loss.
            total_loss = mean_policy_loss - mean_entropy * self.entropy_coeff
            # Optional KL loss.
            if self.config["use_kl_loss"]:
                total_loss += self.kl_coeff * mean_kl_loss
            # Optional vf loss (or in a separate term due to separate
            # optimizers/networks).
            loss_wo_vf = total_loss
            if not self.config["_separate_vf_optimizer"]:
                total_loss += mean_vf_loss * self.config["vf_loss_coeff"]

            # Store stats in policy for stats_fn.
            self._total_loss = total_loss
            self._loss_wo_vf = loss_wo_vf
            self._mean_policy_loss = mean_policy_loss
            # Backward compatibility: Deprecate policy._mean_kl.
            self._mean_kl_loss = self._mean_kl = mean_kl_loss
            self._mean_vf_loss = mean_vf_loss
            self._mean_entropy = mean_entropy
            self._value_targets = value_targets

            # Return one total loss or two losses: vf vs rest (policy + kl).
            if self.config["_separate_vf_optimizer"]:
                return loss_wo_vf, mean_vf_loss
            else:
                return total_loss

        @override(base)
        def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
            # Collects the loss terms stashed on `self` by `loss()` above.
            values_batched = _make_time_major(
                self,
                train_batch.get(SampleBatch.SEQ_LENS),
                self.model.value_function(),
            )

            stats_dict = {
                "cur_lr": tf.cast(self.cur_lr, tf.float64),
                "total_loss": self._total_loss,
                "policy_loss": self._mean_policy_loss,
                "entropy": self._mean_entropy,
                "var_gnorm": tf.linalg.global_norm(self.model.trainable_variables()),
                "vf_loss": self._mean_vf_loss,
                "vf_explained_var": explained_variance(
                    tf.reshape(self._value_targets, [-1]),
                    tf.reshape(values_batched, [-1]),
                ),
                "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64),
            }

            if self.config["vtrace"]:
                # Moments of the importance-sampling ratio (only set when
                # vtrace=True in loss()).
                is_stat_mean, is_stat_var = tf.nn.moments(self._is_ratio, [0, 1])
                stats_dict["mean_IS"] = is_stat_mean
                stats_dict["var_IS"] = is_stat_var

            if self.config["use_kl_loss"]:
                stats_dict["kl"] = self._mean_kl_loss
                stats_dict["KL_Coeff"] = self.kl_coeff

            return stats_dict

        @override(base)
        def postprocess_trajectory(
            self,
            sample_batch: SampleBatch,
            other_agent_batches: Optional[SampleBatch] = None,
            episode=None,
        ):
            # Call super's postprocess_trajectory first.
            # sample_batch = super().postprocess_trajectory(
            #     sample_batch, other_agent_batches, episode
            # )

            if not self.config["vtrace"]:
                # No V-trace -> pre-compute GAE advantages/value targets here.
                sample_batch = compute_gae_for_sample_batch(
                    self, sample_batch, other_agent_batches, episode
                )
            else:
                # Add the Columns.VALUES_BOOTSTRAPPED column, which we'll need
                # inside the loss for vtrace calculations.
                sample_batch = compute_bootstrap_value(sample_batch, self)

            return sample_batch

        @override(base)
        def get_batch_divisibility_req(self) -> int:
            # Train batches must be divisible by the rollout fragment length.
            return self.config["rollout_fragment_length"]

    # Rename the dynamically built class so tf1/tf2 variants are
    # distinguishable in logs and pickles.
    APPOTFPolicy.__name__ = name
    APPOTFPolicy.__qualname__ = name

    return APPOTFPolicy


APPOTF1Policy = get_appo_tf_policy("APPOTF1Policy", DynamicTFPolicyV2)
APPOTF2Policy = get_appo_tf_policy("APPOTF2Policy", EagerTFPolicyV2)
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/appo_torch_policy.py
ADDED
|
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PyTorch policy class used for APPO.
|
| 3 |
+
|
| 4 |
+
Adapted from VTraceTFPolicy to use the PPO surrogate loss.
|
| 5 |
+
Keep in sync with changes to VTraceTFPolicy.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import gymnasium as gym
|
| 9 |
+
import numpy as np
|
| 10 |
+
import logging
|
| 11 |
+
from typing import Any, Dict, List, Optional, Type, Union
|
| 12 |
+
|
| 13 |
+
import ray
|
| 14 |
+
from ray.rllib.algorithms.appo.utils import make_appo_models
|
| 15 |
+
import ray.rllib.algorithms.impala.vtrace_torch as vtrace
|
| 16 |
+
from ray.rllib.algorithms.impala.impala_torch_policy import (
|
| 17 |
+
make_time_major,
|
| 18 |
+
VTraceOptimizer,
|
| 19 |
+
)
|
| 20 |
+
from ray.rllib.evaluation.postprocessing import (
|
| 21 |
+
compute_bootstrap_value,
|
| 22 |
+
compute_gae_for_sample_batch,
|
| 23 |
+
Postprocessing,
|
| 24 |
+
)
|
| 25 |
+
from ray.rllib.models.action_dist import ActionDistribution
|
| 26 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 27 |
+
from ray.rllib.models.torch.torch_action_dist import (
|
| 28 |
+
TorchDistributionWrapper,
|
| 29 |
+
TorchCategorical,
|
| 30 |
+
)
|
| 31 |
+
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
| 32 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 33 |
+
from ray.rllib.policy.torch_mixins import (
|
| 34 |
+
EntropyCoeffSchedule,
|
| 35 |
+
LearningRateSchedule,
|
| 36 |
+
KLCoeffMixin,
|
| 37 |
+
ValueNetworkMixin,
|
| 38 |
+
TargetNetworkMixin,
|
| 39 |
+
)
|
| 40 |
+
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
|
| 41 |
+
from ray.rllib.utils.annotations import override
|
| 42 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 43 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 44 |
+
from ray.rllib.utils.torch_utils import (
|
| 45 |
+
apply_grad_clipping,
|
| 46 |
+
explained_variance,
|
| 47 |
+
global_norm,
|
| 48 |
+
sequence_mask,
|
| 49 |
+
)
|
| 50 |
+
from ray.rllib.utils.typing import TensorType
|
| 51 |
+
|
| 52 |
+
torch, nn = try_import_torch()
|
| 53 |
+
|
| 54 |
+
logger = logging.getLogger(__name__)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# TODO (sven): Deprecate once APPO and IMPALA fully on RLModules/Learner APIs.
|
| 58 |
+
class APPOTorchPolicy(
|
| 59 |
+
VTraceOptimizer,
|
| 60 |
+
LearningRateSchedule,
|
| 61 |
+
EntropyCoeffSchedule,
|
| 62 |
+
KLCoeffMixin,
|
| 63 |
+
ValueNetworkMixin,
|
| 64 |
+
TargetNetworkMixin,
|
| 65 |
+
TorchPolicyV2,
|
| 66 |
+
):
|
| 67 |
+
"""PyTorch policy class used with APPO."""
|
| 68 |
+
|
| 69 |
+
    def __init__(self, observation_space, action_space, config):
        """Initializes the torch APPO policy and all its mixins.

        NOTE: The mixin `__init__` call order below is significant (e.g. the
        LR schedule must exist before `TorchPolicyV2.__init__` builds the
        optimizers) — do not reorder.
        """
        # Merge user config on top of the full APPO default config.
        config = dict(ray.rllib.algorithms.appo.appo.APPOConfig().to_dict(), **config)
        # This policy belongs to the old API stack; force the new-stack
        # switches off.
        config["enable_rl_module_and_learner"] = False
        config["enable_env_runner_and_connector_v2"] = False

        # Although this is a no-op, we call __init__ here to make it clear
        # that base.__init__ will use the make_model() call.
        VTraceOptimizer.__init__(self)

        lr_schedule_additional_args = []
        if config.get("_separate_vf_optimizer"):
            # With a separate vf optimizer, `_lr_vf` is either a schedule
            # (list/tuple) or a fixed value; normalize to (initial_lr, schedule).
            lr_schedule_additional_args = (
                [config["_lr_vf"][0][1], config["_lr_vf"]]
                if isinstance(config["_lr_vf"], (list, tuple))
                else [config["_lr_vf"], None]
            )
        LearningRateSchedule.__init__(
            self, config["lr"], config["lr_schedule"], *lr_schedule_additional_args
        )

        TorchPolicyV2.__init__(
            self,
            observation_space,
            action_space,
            config,
            max_seq_len=config["model"]["max_seq_len"],
        )

        EntropyCoeffSchedule.__init__(
            self, config["entropy_coeff"], config["entropy_coeff_schedule"]
        )
        ValueNetworkMixin.__init__(self, config)
        KLCoeffMixin.__init__(self, config)

        # Build loss/view-requirements from a dummy batch.
        self._initialize_loss_from_dummy_batch()

        # Initiate TargetNetwork ops after loss initialization.
        TargetNetworkMixin.__init__(self)
|
| 107 |
+
|
| 108 |
+
@override(TorchPolicyV2)
|
| 109 |
+
def init_view_requirements(self):
|
| 110 |
+
self.view_requirements = self._get_default_view_requirements()
|
| 111 |
+
|
| 112 |
+
@override(TorchPolicyV2)
|
| 113 |
+
def make_model(self) -> ModelV2:
|
| 114 |
+
return make_appo_models(self)
|
| 115 |
+
|
| 116 |
+
@override(TorchPolicyV2)
|
| 117 |
+
def loss(
|
| 118 |
+
self,
|
| 119 |
+
model: ModelV2,
|
| 120 |
+
dist_class: Type[ActionDistribution],
|
| 121 |
+
train_batch: SampleBatch,
|
| 122 |
+
) -> Union[TensorType, List[TensorType]]:
|
| 123 |
+
"""Constructs the loss for APPO.
|
| 124 |
+
|
| 125 |
+
With IS modifications and V-trace for Advantage Estimation.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
model (ModelV2): The Model to calculate the loss for.
|
| 129 |
+
dist_class (Type[ActionDistribution]): The action distr. class.
|
| 130 |
+
train_batch: The training data.
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
Union[TensorType, List[TensorType]]: A single loss tensor or a list
|
| 134 |
+
of loss tensors.
|
| 135 |
+
"""
|
| 136 |
+
target_model = self.target_models[model]
|
| 137 |
+
|
| 138 |
+
model_out, _ = model(train_batch)
|
| 139 |
+
action_dist = dist_class(model_out, model)
|
| 140 |
+
|
| 141 |
+
if isinstance(self.action_space, gym.spaces.Discrete):
|
| 142 |
+
is_multidiscrete = False
|
| 143 |
+
output_hidden_shape = [self.action_space.n]
|
| 144 |
+
elif isinstance(self.action_space, gym.spaces.multi_discrete.MultiDiscrete):
|
| 145 |
+
is_multidiscrete = True
|
| 146 |
+
output_hidden_shape = self.action_space.nvec.astype(np.int32)
|
| 147 |
+
else:
|
| 148 |
+
is_multidiscrete = False
|
| 149 |
+
output_hidden_shape = 1
|
| 150 |
+
|
| 151 |
+
def _make_time_major(*args, **kwargs):
|
| 152 |
+
return make_time_major(
|
| 153 |
+
self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kwargs
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
actions = train_batch[SampleBatch.ACTIONS]
|
| 157 |
+
dones = train_batch[SampleBatch.TERMINATEDS]
|
| 158 |
+
rewards = train_batch[SampleBatch.REWARDS]
|
| 159 |
+
behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
|
| 160 |
+
|
| 161 |
+
target_model_out, _ = target_model(train_batch)
|
| 162 |
+
|
| 163 |
+
prev_action_dist = dist_class(behaviour_logits, model)
|
| 164 |
+
values = model.value_function()
|
| 165 |
+
values_time_major = _make_time_major(values)
|
| 166 |
+
bootstrap_values_time_major = _make_time_major(
|
| 167 |
+
train_batch[SampleBatch.VALUES_BOOTSTRAPPED]
|
| 168 |
+
)
|
| 169 |
+
bootstrap_value = bootstrap_values_time_major[-1]
|
| 170 |
+
|
| 171 |
+
if self.is_recurrent():
|
| 172 |
+
max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
|
| 173 |
+
mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
|
| 174 |
+
mask = torch.reshape(mask, [-1])
|
| 175 |
+
mask = _make_time_major(mask)
|
| 176 |
+
num_valid = torch.sum(mask)
|
| 177 |
+
|
| 178 |
+
def reduce_mean_valid(t):
|
| 179 |
+
return torch.sum(t[mask]) / num_valid
|
| 180 |
+
|
| 181 |
+
else:
|
| 182 |
+
reduce_mean_valid = torch.mean
|
| 183 |
+
|
| 184 |
+
if self.config["vtrace"]:
|
| 185 |
+
logger.debug("Using V-Trace surrogate loss (vtrace=True)")
|
| 186 |
+
|
| 187 |
+
old_policy_behaviour_logits = target_model_out.detach()
|
| 188 |
+
old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)
|
| 189 |
+
|
| 190 |
+
if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
|
| 191 |
+
unpacked_behaviour_logits = torch.split(
|
| 192 |
+
behaviour_logits, list(output_hidden_shape), dim=1
|
| 193 |
+
)
|
| 194 |
+
unpacked_old_policy_behaviour_logits = torch.split(
|
| 195 |
+
old_policy_behaviour_logits, list(output_hidden_shape), dim=1
|
| 196 |
+
)
|
| 197 |
+
else:
|
| 198 |
+
unpacked_behaviour_logits = torch.chunk(
|
| 199 |
+
behaviour_logits, output_hidden_shape, dim=1
|
| 200 |
+
)
|
| 201 |
+
unpacked_old_policy_behaviour_logits = torch.chunk(
|
| 202 |
+
old_policy_behaviour_logits, output_hidden_shape, dim=1
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
# Prepare actions for loss.
|
| 206 |
+
loss_actions = (
|
| 207 |
+
actions if is_multidiscrete else torch.unsqueeze(actions, dim=1)
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# Prepare KL for loss.
|
| 211 |
+
action_kl = _make_time_major(old_policy_action_dist.kl(action_dist))
|
| 212 |
+
|
| 213 |
+
# Compute vtrace on the CPU for better perf.
|
| 214 |
+
vtrace_returns = vtrace.multi_from_logits(
|
| 215 |
+
behaviour_policy_logits=_make_time_major(unpacked_behaviour_logits),
|
| 216 |
+
target_policy_logits=_make_time_major(
|
| 217 |
+
unpacked_old_policy_behaviour_logits
|
| 218 |
+
),
|
| 219 |
+
actions=torch.unbind(_make_time_major(loss_actions), dim=2),
|
| 220 |
+
discounts=(1.0 - _make_time_major(dones).float())
|
| 221 |
+
* self.config["gamma"],
|
| 222 |
+
rewards=_make_time_major(rewards),
|
| 223 |
+
values=values_time_major,
|
| 224 |
+
bootstrap_value=bootstrap_value,
|
| 225 |
+
dist_class=TorchCategorical if is_multidiscrete else dist_class,
|
| 226 |
+
model=model,
|
| 227 |
+
clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
|
| 228 |
+
clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"],
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
actions_logp = _make_time_major(action_dist.logp(actions))
|
| 232 |
+
prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
|
| 233 |
+
old_policy_actions_logp = _make_time_major(
|
| 234 |
+
old_policy_action_dist.logp(actions)
|
| 235 |
+
)
|
| 236 |
+
is_ratio = torch.clamp(
|
| 237 |
+
torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0
|
| 238 |
+
)
|
| 239 |
+
logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp)
|
| 240 |
+
self._is_ratio = is_ratio
|
| 241 |
+
|
| 242 |
+
advantages = vtrace_returns.pg_advantages.to(logp_ratio.device)
|
| 243 |
+
surrogate_loss = torch.min(
|
| 244 |
+
advantages * logp_ratio,
|
| 245 |
+
advantages
|
| 246 |
+
* torch.clamp(
|
| 247 |
+
logp_ratio,
|
| 248 |
+
1 - self.config["clip_param"],
|
| 249 |
+
1 + self.config["clip_param"],
|
| 250 |
+
),
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
mean_kl_loss = reduce_mean_valid(action_kl)
|
| 254 |
+
mean_policy_loss = -reduce_mean_valid(surrogate_loss)
|
| 255 |
+
|
| 256 |
+
# The value function loss.
|
| 257 |
+
value_targets = vtrace_returns.vs.to(values_time_major.device)
|
| 258 |
+
delta = values_time_major - value_targets
|
| 259 |
+
mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))
|
| 260 |
+
|
| 261 |
+
# The entropy loss.
|
| 262 |
+
mean_entropy = reduce_mean_valid(_make_time_major(action_dist.entropy()))
|
| 263 |
+
|
| 264 |
+
else:
|
| 265 |
+
logger.debug("Using PPO surrogate loss (vtrace=False)")
|
| 266 |
+
|
| 267 |
+
# Prepare KL for Loss
|
| 268 |
+
action_kl = _make_time_major(prev_action_dist.kl(action_dist))
|
| 269 |
+
|
| 270 |
+
actions_logp = _make_time_major(action_dist.logp(actions))
|
| 271 |
+
prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
|
| 272 |
+
logp_ratio = torch.exp(actions_logp - prev_actions_logp)
|
| 273 |
+
|
| 274 |
+
advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES])
|
| 275 |
+
surrogate_loss = torch.min(
|
| 276 |
+
advantages * logp_ratio,
|
| 277 |
+
advantages
|
| 278 |
+
* torch.clamp(
|
| 279 |
+
logp_ratio,
|
| 280 |
+
1 - self.config["clip_param"],
|
| 281 |
+
1 + self.config["clip_param"],
|
| 282 |
+
),
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
mean_kl_loss = reduce_mean_valid(action_kl)
|
| 286 |
+
mean_policy_loss = -reduce_mean_valid(surrogate_loss)
|
| 287 |
+
|
| 288 |
+
# The value function loss.
|
| 289 |
+
value_targets = _make_time_major(train_batch[Postprocessing.VALUE_TARGETS])
|
| 290 |
+
delta = values_time_major - value_targets
|
| 291 |
+
mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))
|
| 292 |
+
|
| 293 |
+
# The entropy loss.
|
| 294 |
+
mean_entropy = reduce_mean_valid(_make_time_major(action_dist.entropy()))
|
| 295 |
+
|
| 296 |
+
# The summed weighted loss.
|
| 297 |
+
total_loss = mean_policy_loss - mean_entropy * self.entropy_coeff
|
| 298 |
+
# Optional additional KL Loss
|
| 299 |
+
if self.config["use_kl_loss"]:
|
| 300 |
+
total_loss += self.kl_coeff * mean_kl_loss
|
| 301 |
+
|
| 302 |
+
# Optional vf loss (or in a separate term due to separate
|
| 303 |
+
# optimizers/networks).
|
| 304 |
+
loss_wo_vf = total_loss
|
| 305 |
+
if not self.config["_separate_vf_optimizer"]:
|
| 306 |
+
total_loss += mean_vf_loss * self.config["vf_loss_coeff"]
|
| 307 |
+
|
| 308 |
+
# Store values for stats function in model (tower), such that for
|
| 309 |
+
# multi-GPU, we do not override them during the parallel loss phase.
|
| 310 |
+
model.tower_stats["total_loss"] = total_loss
|
| 311 |
+
model.tower_stats["mean_policy_loss"] = mean_policy_loss
|
| 312 |
+
model.tower_stats["mean_kl_loss"] = mean_kl_loss
|
| 313 |
+
model.tower_stats["mean_vf_loss"] = mean_vf_loss
|
| 314 |
+
model.tower_stats["mean_entropy"] = mean_entropy
|
| 315 |
+
model.tower_stats["value_targets"] = value_targets
|
| 316 |
+
model.tower_stats["vf_explained_var"] = explained_variance(
|
| 317 |
+
torch.reshape(value_targets, [-1]),
|
| 318 |
+
torch.reshape(values_time_major, [-1]),
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
# Return one total loss or two losses: vf vs rest (policy + kl).
|
| 322 |
+
if self.config["_separate_vf_optimizer"]:
|
| 323 |
+
return loss_wo_vf, mean_vf_loss
|
| 324 |
+
else:
|
| 325 |
+
return total_loss
|
| 326 |
+
|
| 327 |
+
@override(TorchPolicyV2)
|
| 328 |
+
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
|
| 329 |
+
"""Stats function for APPO. Returns a dict with important loss stats.
|
| 330 |
+
|
| 331 |
+
Args:
|
| 332 |
+
policy: The Policy to generate stats for.
|
| 333 |
+
train_batch: The SampleBatch (already) used for training.
|
| 334 |
+
|
| 335 |
+
Returns:
|
| 336 |
+
Dict[str, TensorType]: The stats dict.
|
| 337 |
+
"""
|
| 338 |
+
stats_dict = {
|
| 339 |
+
"cur_lr": self.cur_lr,
|
| 340 |
+
"total_loss": torch.mean(torch.stack(self.get_tower_stats("total_loss"))),
|
| 341 |
+
"policy_loss": torch.mean(
|
| 342 |
+
torch.stack(self.get_tower_stats("mean_policy_loss"))
|
| 343 |
+
),
|
| 344 |
+
"entropy": torch.mean(torch.stack(self.get_tower_stats("mean_entropy"))),
|
| 345 |
+
"entropy_coeff": self.entropy_coeff,
|
| 346 |
+
"var_gnorm": global_norm(self.model.trainable_variables()),
|
| 347 |
+
"vf_loss": torch.mean(torch.stack(self.get_tower_stats("mean_vf_loss"))),
|
| 348 |
+
"vf_explained_var": torch.mean(
|
| 349 |
+
torch.stack(self.get_tower_stats("vf_explained_var"))
|
| 350 |
+
),
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
if self.config["vtrace"]:
|
| 354 |
+
is_stat_mean = torch.mean(self._is_ratio, [0, 1])
|
| 355 |
+
is_stat_var = torch.var(self._is_ratio, [0, 1])
|
| 356 |
+
stats_dict["mean_IS"] = is_stat_mean
|
| 357 |
+
stats_dict["var_IS"] = is_stat_var
|
| 358 |
+
|
| 359 |
+
if self.config["use_kl_loss"]:
|
| 360 |
+
stats_dict["kl"] = torch.mean(
|
| 361 |
+
torch.stack(self.get_tower_stats("mean_kl_loss"))
|
| 362 |
+
)
|
| 363 |
+
stats_dict["KL_Coeff"] = self.kl_coeff
|
| 364 |
+
|
| 365 |
+
return convert_to_numpy(stats_dict)
|
| 366 |
+
|
| 367 |
+
@override(TorchPolicyV2)
|
| 368 |
+
def extra_action_out(
|
| 369 |
+
self,
|
| 370 |
+
input_dict: Dict[str, TensorType],
|
| 371 |
+
state_batches: List[TensorType],
|
| 372 |
+
model: TorchModelV2,
|
| 373 |
+
action_dist: TorchDistributionWrapper,
|
| 374 |
+
) -> Dict[str, TensorType]:
|
| 375 |
+
return {SampleBatch.VF_PREDS: model.value_function()}
|
| 376 |
+
|
| 377 |
+
@override(TorchPolicyV2)
|
| 378 |
+
def postprocess_trajectory(
|
| 379 |
+
self,
|
| 380 |
+
sample_batch: SampleBatch,
|
| 381 |
+
other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
|
| 382 |
+
episode=None,
|
| 383 |
+
):
|
| 384 |
+
# Call super's postprocess_trajectory first.
|
| 385 |
+
# sample_batch = super().postprocess_trajectory(
|
| 386 |
+
# sample_batch, other_agent_batches, episode
|
| 387 |
+
# )
|
| 388 |
+
|
| 389 |
+
# Do all post-processing always with no_grad().
|
| 390 |
+
# Not using this here will introduce a memory leak
|
| 391 |
+
# in torch (issue #6962).
|
| 392 |
+
with torch.no_grad():
|
| 393 |
+
if not self.config["vtrace"]:
|
| 394 |
+
sample_batch = compute_gae_for_sample_batch(
|
| 395 |
+
self, sample_batch, other_agent_batches, episode
|
| 396 |
+
)
|
| 397 |
+
else:
|
| 398 |
+
# Add the SampleBatch.VALUES_BOOTSTRAPPED column, which we'll need
|
| 399 |
+
# inside the loss for vtrace calculations.
|
| 400 |
+
sample_batch = compute_bootstrap_value(sample_batch, self)
|
| 401 |
+
|
| 402 |
+
return sample_batch
|
| 403 |
+
|
| 404 |
+
@override(TorchPolicyV2)
|
| 405 |
+
def extra_grad_process(
|
| 406 |
+
self, optimizer: "torch.optim.Optimizer", loss: TensorType
|
| 407 |
+
) -> Dict[str, TensorType]:
|
| 408 |
+
return apply_grad_clipping(self, optimizer, loss)
|
| 409 |
+
|
| 410 |
+
@override(TorchPolicyV2)
|
| 411 |
+
def get_batch_divisibility_req(self) -> int:
|
| 412 |
+
return self.config["rollout_fragment_length"]
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/default_appo_rl_module.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from typing import Any, Dict, List, Tuple
|
| 3 |
+
|
| 4 |
+
from ray.rllib.algorithms.ppo.default_ppo_rl_module import DefaultPPORLModule
|
| 5 |
+
from ray.rllib.core.learner.utils import make_target_network
|
| 6 |
+
from ray.rllib.core.models.base import ACTOR
|
| 7 |
+
from ray.rllib.core.models.tf.encoder import ENCODER_OUT
|
| 8 |
+
from ray.rllib.core.rl_module.apis import (
|
| 9 |
+
TARGET_NETWORK_ACTION_DIST_INPUTS,
|
| 10 |
+
TargetNetworkAPI,
|
| 11 |
+
)
|
| 12 |
+
from ray.rllib.utils.typing import NetworkType
|
| 13 |
+
|
| 14 |
+
from ray.rllib.utils.annotations import (
|
| 15 |
+
override,
|
| 16 |
+
OverrideToImplementCustomLogic_CallToSuperRecommended,
|
| 17 |
+
)
|
| 18 |
+
from ray.util.annotations import DeveloperAPI
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@DeveloperAPI
|
| 22 |
+
class DefaultAPPORLModule(DefaultPPORLModule, TargetNetworkAPI, abc.ABC):
|
| 23 |
+
"""Default RLModule used by APPO, if user does not specify a custom RLModule.
|
| 24 |
+
|
| 25 |
+
Users who want to train their RLModules with APPO may implement any RLModule
|
| 26 |
+
(or TorchRLModule) subclass as long as the custom class also implements the
|
| 27 |
+
`ValueFunctionAPI` (see ray.rllib.core.rl_module.apis.value_function_api.py)
|
| 28 |
+
and the `TargetNetworkAPI` (see
|
| 29 |
+
ray.rllib.core.rl_module.apis.target_network_api.py).
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
@override(TargetNetworkAPI)
|
| 33 |
+
def make_target_networks(self):
|
| 34 |
+
self._old_encoder = make_target_network(self.encoder)
|
| 35 |
+
self._old_pi = make_target_network(self.pi)
|
| 36 |
+
|
| 37 |
+
@override(TargetNetworkAPI)
|
| 38 |
+
def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
|
| 39 |
+
return [
|
| 40 |
+
(self.encoder, self._old_encoder),
|
| 41 |
+
(self.pi, self._old_pi),
|
| 42 |
+
]
|
| 43 |
+
|
| 44 |
+
@override(TargetNetworkAPI)
|
| 45 |
+
def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
|
| 46 |
+
old_pi_inputs_encoded = self._old_encoder(batch)[ENCODER_OUT][ACTOR]
|
| 47 |
+
old_action_dist_logits = self._old_pi(old_pi_inputs_encoded)
|
| 48 |
+
return {TARGET_NETWORK_ACTION_DIST_INPUTS: old_action_dist_logits}
|
| 49 |
+
|
| 50 |
+
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
| 51 |
+
@override(DefaultPPORLModule)
|
| 52 |
+
def get_non_inference_attributes(self) -> List[str]:
|
| 53 |
+
# Get the NON inference-only attributes from the parent class
|
| 54 |
+
# `PPOTorchRLModule`.
|
| 55 |
+
ret = super().get_non_inference_attributes()
|
| 56 |
+
# Add the two (APPO) target networks to it (NOT needed in
|
| 57 |
+
# inference-only mode).
|
| 58 |
+
ret += ["_old_encoder", "_old_pi"]
|
| 59 |
+
return ret
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (204 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__pycache__/appo_torch_learner.cpython-311.pyc
ADDED
|
Binary file (9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__pycache__/appo_torch_rl_module.cpython-311.pyc
ADDED
|
Binary file (709 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/__pycache__/default_appo_torch_rl_module.cpython-311.pyc
ADDED
|
Binary file (818 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/appo_torch_learner.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Asynchronous Proximal Policy Optimization (APPO)
|
| 2 |
+
|
| 3 |
+
The algorithm is described in [1] (under the name of "IMPACT"):
|
| 4 |
+
|
| 5 |
+
Detailed documentation:
|
| 6 |
+
https://docs.ray.io/en/master/rllib-algorithms.html#appo
|
| 7 |
+
|
| 8 |
+
[1] IMPACT: Importance Weighted Asynchronous Architectures with Clipped Target Networks.
|
| 9 |
+
Luo et al. 2020
|
| 10 |
+
https://arxiv.org/pdf/1912.00167
|
| 11 |
+
"""
|
| 12 |
+
from typing import Dict
|
| 13 |
+
|
| 14 |
+
from ray.rllib.algorithms.appo.appo import (
|
| 15 |
+
APPOConfig,
|
| 16 |
+
LEARNER_RESULTS_CURR_KL_COEFF_KEY,
|
| 17 |
+
LEARNER_RESULTS_KL_KEY,
|
| 18 |
+
)
|
| 19 |
+
from ray.rllib.algorithms.appo.appo_learner import APPOLearner
|
| 20 |
+
from ray.rllib.algorithms.impala.torch.impala_torch_learner import IMPALATorchLearner
|
| 21 |
+
from ray.rllib.algorithms.impala.torch.vtrace_torch_v2 import (
|
| 22 |
+
make_time_major,
|
| 23 |
+
vtrace_torch,
|
| 24 |
+
)
|
| 25 |
+
from ray.rllib.core.columns import Columns
|
| 26 |
+
from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY
|
| 27 |
+
from ray.rllib.core.rl_module.apis import (
|
| 28 |
+
TARGET_NETWORK_ACTION_DIST_INPUTS,
|
| 29 |
+
TargetNetworkAPI,
|
| 30 |
+
ValueFunctionAPI,
|
| 31 |
+
)
|
| 32 |
+
from ray.rllib.utils.annotations import override
|
| 33 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 34 |
+
from ray.rllib.utils.numpy import convert_to_numpy
|
| 35 |
+
from ray.rllib.utils.typing import ModuleID, TensorType
|
| 36 |
+
|
| 37 |
+
torch, nn = try_import_torch()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class APPOTorchLearner(APPOLearner, IMPALATorchLearner):
|
| 41 |
+
"""Implements APPO loss / update logic on top of IMPALATorchLearner."""
|
| 42 |
+
|
| 43 |
+
@override(IMPALATorchLearner)
|
| 44 |
+
def compute_loss_for_module(
|
| 45 |
+
self,
|
| 46 |
+
*,
|
| 47 |
+
module_id: ModuleID,
|
| 48 |
+
config: APPOConfig,
|
| 49 |
+
batch: Dict,
|
| 50 |
+
fwd_out: Dict[str, TensorType],
|
| 51 |
+
) -> TensorType:
|
| 52 |
+
module = self.module[module_id].unwrapped()
|
| 53 |
+
assert isinstance(module, TargetNetworkAPI)
|
| 54 |
+
assert isinstance(module, ValueFunctionAPI)
|
| 55 |
+
|
| 56 |
+
# TODO (sven): Now that we do the +1ts trick to be less vulnerable about
|
| 57 |
+
# bootstrap values at the end of rollouts in the new stack, we might make
|
| 58 |
+
# this a more flexible, configurable parameter for users, e.g.
|
| 59 |
+
# `v_trace_seq_len` (independent of `rollout_fragment_length`). Separation
|
| 60 |
+
# of concerns (sampling vs learning).
|
| 61 |
+
rollout_frag_or_episode_len = config.get_rollout_fragment_length()
|
| 62 |
+
recurrent_seq_len = batch.get("seq_lens")
|
| 63 |
+
|
| 64 |
+
loss_mask = batch[Columns.LOSS_MASK].float()
|
| 65 |
+
loss_mask_time_major = make_time_major(
|
| 66 |
+
loss_mask,
|
| 67 |
+
trajectory_len=rollout_frag_or_episode_len,
|
| 68 |
+
recurrent_seq_len=recurrent_seq_len,
|
| 69 |
+
)
|
| 70 |
+
size_loss_mask = torch.sum(loss_mask)
|
| 71 |
+
|
| 72 |
+
values = module.compute_values(
|
| 73 |
+
batch, embeddings=fwd_out.get(Columns.EMBEDDINGS)
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
action_dist_cls_train = module.get_train_action_dist_cls()
|
| 77 |
+
target_policy_dist = action_dist_cls_train.from_logits(
|
| 78 |
+
fwd_out[Columns.ACTION_DIST_INPUTS]
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
old_target_policy_dist = action_dist_cls_train.from_logits(
|
| 82 |
+
module.forward_target(batch)[TARGET_NETWORK_ACTION_DIST_INPUTS]
|
| 83 |
+
)
|
| 84 |
+
old_target_policy_actions_logp = old_target_policy_dist.logp(
|
| 85 |
+
batch[Columns.ACTIONS]
|
| 86 |
+
)
|
| 87 |
+
behaviour_actions_logp = batch[Columns.ACTION_LOGP]
|
| 88 |
+
target_actions_logp = target_policy_dist.logp(batch[Columns.ACTIONS])
|
| 89 |
+
|
| 90 |
+
behaviour_actions_logp_time_major = make_time_major(
|
| 91 |
+
behaviour_actions_logp,
|
| 92 |
+
trajectory_len=rollout_frag_or_episode_len,
|
| 93 |
+
recurrent_seq_len=recurrent_seq_len,
|
| 94 |
+
)
|
| 95 |
+
target_actions_logp_time_major = make_time_major(
|
| 96 |
+
target_actions_logp,
|
| 97 |
+
trajectory_len=rollout_frag_or_episode_len,
|
| 98 |
+
recurrent_seq_len=recurrent_seq_len,
|
| 99 |
+
)
|
| 100 |
+
old_actions_logp_time_major = make_time_major(
|
| 101 |
+
old_target_policy_actions_logp,
|
| 102 |
+
trajectory_len=rollout_frag_or_episode_len,
|
| 103 |
+
recurrent_seq_len=recurrent_seq_len,
|
| 104 |
+
)
|
| 105 |
+
rewards_time_major = make_time_major(
|
| 106 |
+
batch[Columns.REWARDS],
|
| 107 |
+
trajectory_len=rollout_frag_or_episode_len,
|
| 108 |
+
recurrent_seq_len=recurrent_seq_len,
|
| 109 |
+
)
|
| 110 |
+
values_time_major = make_time_major(
|
| 111 |
+
values,
|
| 112 |
+
trajectory_len=rollout_frag_or_episode_len,
|
| 113 |
+
recurrent_seq_len=recurrent_seq_len,
|
| 114 |
+
)
|
| 115 |
+
assert Columns.VALUES_BOOTSTRAPPED not in batch
|
| 116 |
+
# Use as bootstrap values the vf-preds in the next "batch row", except
|
| 117 |
+
# for the very last row (which doesn't have a next row), for which the
|
| 118 |
+
# bootstrap value does not matter b/c it has a +1ts value at its end
|
| 119 |
+
# anyways. So we chose an arbitrary item (for simplicity of not having to
|
| 120 |
+
# move new data to the device).
|
| 121 |
+
bootstrap_values = torch.cat(
|
| 122 |
+
[
|
| 123 |
+
values_time_major[0][1:], # 0th ts values from "next row"
|
| 124 |
+
values_time_major[0][0:1], # <- can use any arbitrary value here
|
| 125 |
+
],
|
| 126 |
+
dim=0,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
# The discount factor that is used should be gamma except for timesteps where
|
| 130 |
+
# the episode is terminated. In that case, the discount factor should be 0.
|
| 131 |
+
discounts_time_major = (
|
| 132 |
+
1.0
|
| 133 |
+
- make_time_major(
|
| 134 |
+
batch[Columns.TERMINATEDS],
|
| 135 |
+
trajectory_len=rollout_frag_or_episode_len,
|
| 136 |
+
recurrent_seq_len=recurrent_seq_len,
|
| 137 |
+
).float()
|
| 138 |
+
) * config.gamma
|
| 139 |
+
|
| 140 |
+
# Note that vtrace will compute the main loop on the CPU for better performance.
|
| 141 |
+
vtrace_adjusted_target_values, pg_advantages = vtrace_torch(
|
| 142 |
+
target_action_log_probs=old_actions_logp_time_major,
|
| 143 |
+
behaviour_action_log_probs=behaviour_actions_logp_time_major,
|
| 144 |
+
discounts=discounts_time_major,
|
| 145 |
+
rewards=rewards_time_major,
|
| 146 |
+
values=values_time_major,
|
| 147 |
+
bootstrap_values=bootstrap_values,
|
| 148 |
+
clip_pg_rho_threshold=config.vtrace_clip_pg_rho_threshold,
|
| 149 |
+
clip_rho_threshold=config.vtrace_clip_rho_threshold,
|
| 150 |
+
)
|
| 151 |
+
pg_advantages = pg_advantages * loss_mask_time_major
|
| 152 |
+
|
| 153 |
+
# The policy gradients loss.
|
| 154 |
+
is_ratio = torch.clip(
|
| 155 |
+
torch.exp(behaviour_actions_logp_time_major - old_actions_logp_time_major),
|
| 156 |
+
0.0,
|
| 157 |
+
2.0,
|
| 158 |
+
)
|
| 159 |
+
logp_ratio = is_ratio * torch.exp(
|
| 160 |
+
target_actions_logp_time_major - behaviour_actions_logp_time_major
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
surrogate_loss = torch.minimum(
|
| 164 |
+
pg_advantages * logp_ratio,
|
| 165 |
+
pg_advantages
|
| 166 |
+
* torch.clip(logp_ratio, 1 - config.clip_param, 1 + config.clip_param),
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
if config.use_kl_loss:
|
| 170 |
+
action_kl = old_target_policy_dist.kl(target_policy_dist) * loss_mask
|
| 171 |
+
mean_kl_loss = torch.sum(action_kl) / size_loss_mask
|
| 172 |
+
else:
|
| 173 |
+
mean_kl_loss = 0.0
|
| 174 |
+
mean_pi_loss = -(torch.sum(surrogate_loss) / size_loss_mask)
|
| 175 |
+
|
| 176 |
+
# The baseline loss.
|
| 177 |
+
delta = values_time_major - vtrace_adjusted_target_values
|
| 178 |
+
vf_loss = 0.5 * torch.sum(torch.pow(delta, 2.0) * loss_mask_time_major)
|
| 179 |
+
mean_vf_loss = vf_loss / size_loss_mask
|
| 180 |
+
|
| 181 |
+
# The entropy loss.
|
| 182 |
+
mean_entropy_loss = (
|
| 183 |
+
-torch.sum(target_policy_dist.entropy() * loss_mask) / size_loss_mask
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# The summed weighted loss.
|
| 187 |
+
total_loss = (
|
| 188 |
+
mean_pi_loss
|
| 189 |
+
+ (mean_vf_loss * config.vf_loss_coeff)
|
| 190 |
+
+ (
|
| 191 |
+
mean_entropy_loss
|
| 192 |
+
* self.entropy_coeff_schedulers_per_module[
|
| 193 |
+
module_id
|
| 194 |
+
].get_current_value()
|
| 195 |
+
)
|
| 196 |
+
+ (mean_kl_loss * self.curr_kl_coeffs_per_module[module_id])
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
# Log important loss stats.
|
| 200 |
+
self.metrics.log_dict(
|
| 201 |
+
{
|
| 202 |
+
POLICY_LOSS_KEY: mean_pi_loss,
|
| 203 |
+
VF_LOSS_KEY: mean_vf_loss,
|
| 204 |
+
ENTROPY_KEY: -mean_entropy_loss,
|
| 205 |
+
LEARNER_RESULTS_KL_KEY: mean_kl_loss,
|
| 206 |
+
LEARNER_RESULTS_CURR_KL_COEFF_KEY: (
|
| 207 |
+
self.curr_kl_coeffs_per_module[module_id]
|
| 208 |
+
),
|
| 209 |
+
},
|
| 210 |
+
key=module_id,
|
| 211 |
+
window=1, # <- single items (should not be mean/ema-reduced over time).
|
| 212 |
+
)
|
| 213 |
+
# Return the total loss.
|
| 214 |
+
return total_loss
|
| 215 |
+
|
| 216 |
+
@override(APPOLearner)
|
| 217 |
+
def _update_module_kl_coeff(self, module_id: ModuleID, config: APPOConfig) -> None:
|
| 218 |
+
# Update the current KL value based on the recently measured value.
|
| 219 |
+
# Increase.
|
| 220 |
+
kl = convert_to_numpy(self.metrics.peek((module_id, LEARNER_RESULTS_KL_KEY)))
|
| 221 |
+
kl_coeff_var = self.curr_kl_coeffs_per_module[module_id]
|
| 222 |
+
|
| 223 |
+
if kl > 2.0 * config.kl_target:
|
| 224 |
+
# TODO (Kourosh) why not *2.0?
|
| 225 |
+
kl_coeff_var.data *= 1.5
|
| 226 |
+
# Decrease.
|
| 227 |
+
elif kl < 0.5 * config.kl_target:
|
| 228 |
+
kl_coeff_var.data *= 0.5
|
| 229 |
+
|
| 230 |
+
self.metrics.log_value(
|
| 231 |
+
(module_id, LEARNER_RESULTS_CURR_KL_COEFF_KEY),
|
| 232 |
+
kl_coeff_var.item(),
|
| 233 |
+
window=1,
|
| 234 |
+
)
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/appo_torch_rl_module.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Backward compat import.
|
| 2 |
+
from ray.rllib.algorithms.appo.torch.default_appo_torch_rl_module import ( # noqa
|
| 3 |
+
DefaultAPPOTorchRLModule as APPOTorchRLModule,
|
| 4 |
+
)
|
| 5 |
+
from ray.rllib.utils.deprecation import deprecation_warning
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
deprecation_warning(
|
| 9 |
+
old="ray.rllib.algorithms.appo.torch.appo_torch_rl_module.APPOTorchRLModule",
|
| 10 |
+
new="ray.rllib.algorithms.appo.torch.default_appo_torch_rl_module."
|
| 11 |
+
"DefaultAPPOTorchRLModule",
|
| 12 |
+
error=False,
|
| 13 |
+
)
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/torch/default_appo_torch_rl_module.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.algorithms.appo.default_appo_rl_module import DefaultAPPORLModule
|
| 2 |
+
from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
|
| 3 |
+
DefaultPPOTorchRLModule,
|
| 4 |
+
)
|
| 5 |
+
from ray.util.annotations import DeveloperAPI
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@DeveloperAPI
|
| 9 |
+
class DefaultAPPOTorchRLModule(DefaultPPOTorchRLModule, DefaultAPPORLModule):
|
| 10 |
+
pass
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/appo/utils.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
[1] IMPACT: Importance Weighted Asynchronous Architectures with Clipped Target Networks.
|
| 3 |
+
Luo et al. 2020
|
| 4 |
+
https://arxiv.org/pdf/1912.00167
|
| 5 |
+
"""
|
| 6 |
+
from collections import deque
|
| 7 |
+
import random
|
| 8 |
+
import threading
|
| 9 |
+
import time
|
| 10 |
+
|
| 11 |
+
from ray.rllib.models.catalog import ModelCatalog
|
| 12 |
+
from ray.rllib.models.modelv2 import ModelV2
|
| 13 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
POLICY_SCOPE = "func"
|
| 17 |
+
TARGET_POLICY_SCOPE = "target_func"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class CircularBuffer:
    """A circular batch-wise buffer as described in [1] for APPO.

    The buffer holds at most N batches, which are sampled at random (uniformly).
    If full and a new batch is added, the oldest batch is discarded. Also, each batch
    currently in the buffer can be sampled at most K times (after which it is also
    discarded).
    """

    def __init__(self, num_batches: int, iterations_per_batch: int):
        """Initializes a CircularBuffer instance.

        Args:
            num_batches: N from the paper (the buffer size in batches).
            iterations_per_batch: K ("replay coefficient") from the paper; the
                maximum number of times each batch may be sampled before it
                expires.
        """
        # N from the paper (buffer size).
        self.num_batches = num_batches
        # K ("replay coefficient") from the paper.
        self.iterations_per_batch = iterations_per_batch

        # Entries are mutable 2-item lists `[batch, k]`, where k counts how
        # often the batch has been sampled so far. Expired entries are
        # invalidated in place to `[None, None]`.
        self._buffer = deque(maxlen=self.num_batches)
        # Guards buffer access; `add` and `sample` may run on different threads.
        self._lock = threading.Lock()

        # The number of valid (not expired) entries in this buffer.
        self._num_valid_batches = 0

    def add(self, batch):
        """Adds a batch to the buffer, possibly evicting the oldest entry.

        Args:
            batch: The batch to add. Must support `env_steps()` (e.g. a
                SampleBatch-like object).

        Returns:
            The number of env timesteps dropped unused: the evicted batch's
            `env_steps()` times its remaining (K - k) iterations, or 0 if no
            still-valid batch was evicted.
        """
        dropped_entry = None
        dropped_ts = 0

        # Add batch and k=0 information to the deque.
        with self._lock:
            if len(self._buffer) == self.num_batches:
                # Deque is full -> the append below evicts the leftmost entry.
                dropped_entry = self._buffer[0]
            self._buffer.append([batch, 0])
            self._num_valid_batches += 1

        # A valid entry (w/ a batch whose k has not reached K yet) was dropped.
        if dropped_entry is not None and dropped_entry[0] is not None:
            dropped_ts += dropped_entry[0].env_steps() * (
                self.iterations_per_batch - dropped_entry[1]
            )
            self._num_valid_batches -= 1

        return dropped_ts

    def sample(self):
        """Samples a batch uniformly at random; blocks while the buffer is empty.

        Returns:
            A still-valid batch. Its sample count (k) is incremented; once k
            reaches K, the batch is invalidated in the buffer.
        """
        k = entry = batch = None

        while True:
            # Only initially, the buffer may be empty -> Just wait for some time.
            if len(self) == 0:
                time.sleep(0.001)
                continue
            # Sample a random buffer index.
            with self._lock:
                entry = self._buffer[random.randint(0, len(self._buffer) - 1)]
                batch, k = entry
            # Ignore batches that have already been invalidated.
            if batch is not None:
                break

        # Increase k += 1 for this batch.
        assert k is not None
        entry[1] += 1

        # This batch has been exhausted (k == K) -> Invalidate it in the buffer.
        if k == self.iterations_per_batch - 1:
            entry[0] = None
            entry[1] = None
            # Bug fix: an exhausted batch must DECREASE the number of valid
            # batches (the original `+= 1` made `len()` grow on expiry).
            self._num_valid_batches -= 1

        # Return the sampled batch.
        return batch

    def __len__(self) -> int:
        """Returns the number of actually valid (non-expired) batches in the buffer."""
        return self._num_valid_batches
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@OldAPIStack
def make_appo_models(policy) -> ModelV2:
    """Builds model and target model for APPO.

    Returns:
        ModelV2: The Model for the Policy to use.
        Note: The target model will not be returned, just assigned to
        `policy.target_model`.
    """
    model_config = policy.config["model"]

    # Determine the action-distribution output size used by both networks.
    _, num_outputs = ModelCatalog.get_action_dist(policy.action_space, model_config)

    def _build(scope_name):
        # One catalog call per network; only the variable scope/name differs.
        return ModelCatalog.get_model_v2(
            policy.observation_space,
            policy.action_space,
            num_outputs,
            model_config,
            name=scope_name,
            framework=policy.framework,
        )

    # Main (trained) model.
    policy.model = _build(POLICY_SCOPE)
    policy.model_variables = policy.model.variables()

    # Target model (assigned to the policy, not returned).
    policy.target_model = _build(TARGET_POLICY_SCOPE)
    policy.target_model_variables = policy.target_model.variables()

    # Return only the model (not the target model).
    return policy.model
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (672 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/dreamerv3.cpython-311.pyc
ADDED
|
Binary file (32.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/dreamerv3_catalog.cpython-311.pyc
ADDED
|
Binary file (3.65 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/dreamerv3_learner.cpython-311.pyc
ADDED
|
Binary file (1.93 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/__pycache__/dreamerv3_rl_module.cpython-311.pyc
ADDED
|
Binary file (8.11 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (213 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/actor_network.cpython-311.pyc
ADDED
|
Binary file (8.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/critic_network.cpython-311.pyc
ADDED
|
Binary file (8.85 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/disagree_networks.cpython-311.pyc
ADDED
|
Binary file (4.82 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/dreamer_model.cpython-311.pyc
ADDED
|
Binary file (25 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/__pycache__/world_model.cpython-311.pyc
ADDED
|
Binary file (20.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (224 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/cnn_atari.cpython-311.pyc
ADDED
|
Binary file (4.84 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/continue_predictor.cpython-311.pyc
ADDED
|
Binary file (4.91 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/conv_transpose_atari.cpython-311.pyc
ADDED
|
Binary file (7.44 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/dynamics_predictor.cpython-311.pyc
ADDED
|
Binary file (4.23 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/mlp.cpython-311.pyc
ADDED
|
Binary file (4.66 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/representation_layer.cpython-311.pyc
ADDED
|
Binary file (5.96 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/reward_predictor.cpython-311.pyc
ADDED
|
Binary file (5.52 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/reward_predictor_layer.cpython-311.pyc
ADDED
|
Binary file (4.96 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/sequence_model.cpython-311.pyc
ADDED
|
Binary file (6.51 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/__pycache__/vector_decoder.cpython-311.pyc
ADDED
|
Binary file (4.59 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/algorithms/dreamerv3/tf/models/components/continue_predictor.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
[1] Mastering Diverse Domains through World Models - 2023
|
| 3 |
+
D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
|
| 4 |
+
https://arxiv.org/pdf/2301.04104v1.pdf
|
| 5 |
+
"""
|
| 6 |
+
from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP
|
| 7 |
+
from ray.rllib.algorithms.dreamerv3.utils import (
|
| 8 |
+
get_gru_units,
|
| 9 |
+
get_num_z_classes,
|
| 10 |
+
get_num_z_categoricals,
|
| 11 |
+
)
|
| 12 |
+
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
|
| 13 |
+
|
| 14 |
+
_, tf, _ = try_import_tf()
|
| 15 |
+
tfp = try_import_tfp()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ContinuePredictor(tf.keras.Model):
    """The world-model network sub-component used to predict the `continue` flags.

    Predicted continue flags are used to produce "dream data" to learn the policy in.

    The continue flags are predicted via a linear output used to parameterize a
    Bernoulli distribution, from which simply the mode is used (no stochastic
    sampling!). In other words, if the sigmoid of the output of the linear layer is
    >0.5, we predict a continuation of the episode, otherwise we predict an episode
    terminal.
    """

    def __init__(self, *, model_size: str = "XS"):
        """Initializes a ContinuePredictor instance.

        Args:
            model_size: The "Model Size" used according to [1] Appendix B.
                Determines the exact size of the underlying MLP.
        """
        super().__init__(name="continue_predictor")
        self.model_size = model_size
        # Single-output MLP: its last (linear) layer produces the Bernoulli logit.
        self.mlp = MLP(model_size=model_size, output_layer_size=1)

        # Trace self.call with a fixed input signature so it is compiled once
        # (avoids retracing for every new batch size; batch dim stays dynamic).
        # Use the global mixed-precision compute dtype if one is set, else f32.
        dl_type = tf.keras.mixed_precision.global_policy().compute_dtype or tf.float32
        self.call = tf.function(
            input_signature=[
                # h: [B, dim(h)] deterministic GRU state.
                tf.TensorSpec(shape=[None, get_gru_units(model_size)], dtype=dl_type),
                # z: [B, num_categoricals, num_classes] stochastic representation.
                tf.TensorSpec(
                    shape=[
                        None,
                        get_num_z_categoricals(model_size),
                        get_num_z_classes(model_size),
                    ],
                    dtype=dl_type,
                ),
            ]
        )(self.call)

    def call(self, h, z):
        """Performs a forward pass through the continue predictor.

        Args:
            h: The deterministic hidden state of the sequence model. [B, dim(h)].
            z: The stochastic discrete representations of the original
                observation input. [B, num_categoricals, num_classes].

        Returns:
            A tuple of (continue_, bernoulli): the deterministic mode of the
            Bernoulli (1.0 = continue, 0.0 = terminal, as float32) and the
            Bernoulli distribution object itself.
        """
        # Flatten last two dims of z.
        assert len(z.shape) == 3
        z_shape = tf.shape(z)
        z = tf.reshape(z, shape=(z_shape[0], -1))
        assert len(z.shape) == 2
        out = tf.concat([h, z], axis=-1)
        # Re-assert the static feature size (lost by the dynamic reshape above)
        # so downstream layers can build against a known last dimension.
        out.set_shape(
            [
                None,
                (
                    get_num_z_categoricals(self.model_size)
                    * get_num_z_classes(self.model_size)
                    + get_gru_units(self.model_size)
                ),
            ]
        )
        # Send h-cat-z through MLP.
        out = self.mlp(out)
        # Remove the extra [B, 1] dimension at the end to get a proper Bernoulli
        # distribution. Otherwise, tfp will think that the batch dims are [B, 1]
        # where they should be just [B]. Cast to float32 in case the MLP ran
        # under a mixed-precision policy.
        logits = tf.cast(tf.squeeze(out, axis=-1), tf.float32)
        # Create the Bernoulli distribution object.
        bernoulli = tfp.distributions.Bernoulli(logits=logits, dtype=tf.float32)

        # Take the mode (greedy, deterministic "sample").
        continue_ = bernoulli.mode()

        # Return Bernoulli sample (whether to continue) OR (continue?, Bernoulli prob).
        return continue_, bernoulli
|